In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import numpy as np
from tqdm.notebook import tqdm
from functools import wraps
import warnings
from matplotlib import pyplot as plt

def gen_D_binary(S, M, key=None):
    """
    Build a random decoder matrix D of shape (M, S) with D[m, s] = P(m|s).

    Every column spreads its probability mass uniformly over a randomly
    chosen, non-empty subset of the M messages, so each column sums to 1.

    Parameters:
    -----------
    S : int
        Number of symbols (columns of D)
    M : int
        Number of messages (rows of D)
    key : jax.random.PRNGKey, optional
        Random key; PRNGKey(0) is used when omitted

    Returns:
    --------
    D : jnp.ndarray
        Decoder matrix of shape (M, S)
    """
    rng = jax.random.PRNGKey(0) if key is None else key

    decoder = jnp.zeros((M, S))
    for symbol in range(S):
        count_key, draw_key, rng = jax.random.split(rng, 3)
        # Draw how many samples to keep (between 1 and M inclusive) ...
        keep = jax.random.randint(count_key, (), 1, M+1)
        # ... from M message indices sampled with replacement
        draws = jax.random.choice(draw_key, M, shape=(M,), replace=True)
        # Duplicates collapse, so the support may be smaller than `keep`
        support = jnp.unique(draws[:keep])
        uniform_mass = 1.0 / len(support)
        decoder = decoder.at[support, symbol].set(uniform_mass)
    return decoder

def gen_optimal_encoder(D):
    """
    Generate the optimal encoder E given decoder D via Bayes' rule.

    Each encoder column is the decoder row for that message, re-weighted by
    the symbol weights P(s) and normalized over symbols:

        E(s|m) = D(m|s) P(s) / sum_s' D(m|s') P(s')

    A message that no symbol ever decodes to (an all-zero row of D) has a
    zero normalizer; instead of producing a NaN column, that message is
    encoded uniformly over the S symbols so E stays column-stochastic.

    Parameters:
    -----------
    D : jnp.ndarray
        Decoder matrix, shape (M, S), where D[m,s] = P(m|s)

    Returns:
    --------
    E : jnp.ndarray
        Encoder matrix, shape (S, M), where E[s,m] = P(s|m)
    """
    M, S = D.shape

    # Symbol weights obtained by summing D over messages with P(m) = 1/M
    P_m = 1.0 / M
    P_s = jnp.sum(D * P_m, axis=0)  # (S,)

    # Bayes rule, vectorized over all (s, m) pairs
    numerator = D.T * P_s[:, jnp.newaxis]  # (S, M)
    denominator = jnp.sum(D * P_s[jnp.newaxis, :], axis=1)  # (M,)

    # Guard the division: messages with zero total mass would give 0/0 = NaN.
    # Dividing by a placeholder 1.0 keeps the discarded branch finite.
    has_mass = denominator > 0
    safe_denominator = jnp.where(has_mass, denominator, 1.0)
    E = jnp.where(
        has_mass[jnp.newaxis, :],
        numerator / safe_denominator[jnp.newaxis, :],
        1.0 / S,
    )

    return E

@jit
def calculate_mutual_information_jax(E_i, D_j, P_m=None):
    """
    JAX-compatible calculation of normalized mutual information.

    The sender encodes message m into symbol s with E_i, the receiver decodes
    s into m' with D_j; the result is I(M; M') / H(M) for the composite
    channel C = D_j @ E_i under the message prior P_m.

    Zero-probability cells are masked out explicitly (0 * log 0 := 0), so
    both the value and its gradient stay finite when the joint distribution
    or the prior contains exact zeros.

    Parameters:
    -----------
    E_i : jnp.ndarray
        Encoder matrix for agent i, shape (|S|, |M|) where E_i[s,m] = P(s|m)
    D_j : jnp.ndarray
        Decoder matrix for agent j, shape (|M|, |S|) where D_j[m',s] = P(m'|s)
    P_m : jnp.ndarray, optional
        Prior distribution over messages, shape (|M|,). If None, assumes uniform.

    Returns:
    --------
    I_ij_normalized : float
        Normalized mutual information I(M_i; M_j')/H(M). Returns 0 when
        H(M) is 0 (e.g. a single message), where the ratio is undefined.
    """
    num_messages = E_i.shape[1]

    # Set uniform prior if not provided
    if P_m is None:
        P_m = jnp.ones(num_messages) / num_messages

    # Composite channel: C_ij[m', m] = P(m'|m)
    C_ij = D_j @ E_i

    # Joint distribution P(m', m) and the receiver-side marginal P(m')
    P_joint = C_ij * P_m[jnp.newaxis, :]
    P_m_prime = jnp.sum(P_joint, axis=1)

    # Product of marginals P(m') P(m), aligned with P_joint
    P_marginal_product = jnp.outer(P_m_prime, P_m)

    # Only cells with positive joint mass contribute. Masked cells are
    # replaced by 1.0 *before* the log so that neither the forward pass nor
    # the backward pass ever evaluates log(0) or 0/0.
    has_mass = P_joint > 0
    safe_joint = jnp.where(has_mass, P_joint, 1.0)
    # The tiny floor only matters if the product underflows to 0 while the
    # joint is still positive; it keeps the log finite in that case.
    safe_product = jnp.where(
        has_mass, jnp.maximum(P_marginal_product, 1e-30), 1.0
    )
    # Masked cells evaluate to 1 * (log 1 - log 1) = 0, i.e. no contribution
    mutual_info = jnp.sum(
        safe_joint * (jnp.log(safe_joint) - jnp.log(safe_product))
    )

    # Entropy of the prior, again treating 0 * log 0 as 0
    prior_positive = P_m > 0
    safe_prior = jnp.where(prior_positive, P_m, 1.0)
    H_M = -jnp.sum(safe_prior * jnp.log(safe_prior))

    # Normalize by H(M); a degenerate prior (H = 0) carries no information
    entropy_positive = H_M > 0
    safe_entropy = jnp.where(entropy_positive, H_M, 1.0)
    I_ij_normalized = jnp.where(
        entropy_positive, mutual_info / safe_entropy, 0.0
    )

    return I_ij_normalized

def construct_matrix_G(I, regularization=1e-6):
    """
    Construct matrix G from the mutual information matrix by solving:

        G = (Id - MI/N + regularization * Id)^(-1)

    where N is the number of agents, taken from the size of the (square)
    input matrix rather than from any global variable.

    Uses CPU (NumPy) computation to avoid Metal/GPU compatibility issues.

    Parameters:
    -----------
    I : numpy.ndarray or jnp.ndarray
        Square (N, N) mutual information matrix
    regularization : float
        Small value added to the diagonal for numerical stability

    Returns:
    --------
    jnp.ndarray
        The constructed (N, N) matrix G. If the system is singular the
        pseudoinverse is returned; if that also fails, the identity matrix
        is returned and a RuntimeWarning is emitted.

    Raises:
    -------
    ValueError
        If the input is not a square 2-D matrix
    """
    # Convert to numpy so the linear solve runs on the CPU backend
    I_np = np.asarray(I, dtype=float)
    if I_np.ndim != 2 or I_np.shape[0] != I_np.shape[1]:
        raise ValueError(
            f"Expected a square 2-D matrix, got shape {I_np.shape}"
        )

    n = I_np.shape[0]
    identity = np.eye(n)

    # Scale by the number of agents and regularize the diagonal
    matrix_to_invert = identity - I_np / n + regularization * identity

    try:
        G_np = np.linalg.solve(matrix_to_invert, identity)
    except np.linalg.LinAlgError:
        # Singular system: fall back to the Moore-Penrose pseudoinverse
        try:
            G_np = np.linalg.pinv(matrix_to_invert)
        except np.linalg.LinAlgError:
            # Last resort (SVD did not converge): behave as if no gossip
            # structure were available, but make the degradation visible.
            warnings.warn(
                "construct_matrix_G: solve and pinv both failed; "
                "returning the identity matrix",
                RuntimeWarning,
                stacklevel=2,
            )
            G_np = identity

    # Convert back to JAX array
    return jnp.array(G_np)

@jit
def mutual_info_sum(D, E_arr, v=None):
    """
    Weighted total of the normalized mutual information between one decoder
    and a collection of encoders.

    Parameters:
    -----------
    D : jnp.ndarray
        Decoder matrix, shape (|M|, |S|)
    E_arr : list of jnp.ndarray
        Encoder matrices, each of shape (|S|, |M|)
    v : jnp.ndarray, optional
        One weight per encoder; uniform weights (1/n each) when omitted

    Returns:
    --------
    weighted_sum : float
        sum_k v[k] * MI(E_arr[k], D)
    """
    num_encoders = len(E_arr)
    weights = jnp.ones(num_encoders) / num_encoders if v is None else v

    def mi_against_decoder(E):
        # MI of a single encoder paired with the fixed decoder D
        return calculate_mutual_information_jax(E, D)

    # Batch the encoders along a leading axis: (n, |S|, |M|)
    batched_encoders = jnp.stack(E_arr, axis=0)
    per_encoder_mi = vmap(mi_against_decoder)(batched_encoders)

    return jnp.sum(weights * per_encoder_mi)

@jit
def info_grad(D, E_arr, v=None):
    """
    Gradient of the weighted mutual information sum with respect to D.

    Parameters:
    -----------
    D : jnp.ndarray
        Decoder matrix, shape (|M|, |S|), where D[m',s] = P(m'|s)
    E_arr : list[jnp.ndarray] of length n
        Encoder matrices, each of shape (|S|, |M|), where E[s,m] = P(s|m)
    v : jnp.ndarray, optional
        n-vector weighting the mutual information of each channel;
        uniform weights are used when omitted

    Returns:
    --------
    grad_D : jnp.ndarray
        d(mutual_info_sum)/dD, same shape as D
    """
    # Differentiate only with respect to the first argument (the decoder)
    return grad(mutual_info_sum, argnums=0)(D, E_arr, v)

def normalize_decoder(D):
    """
    Rescale a decoder matrix so that every column sums to 1.

    A constant 1e-10 is added to every entry first, so an all-zero column
    becomes uniform instead of dividing by zero.
    """
    smoothed = D + 1e-10
    column_totals = jnp.sum(smoothed, axis=0, keepdims=True)
    return smoothed / column_totals

def update(D, E_arr, learning_rate=0.01, v=None):
    """
    Take one gradient-ascent step on decoder D to increase the weighted
    mutual information with the encoders in E_arr.

    Parameters:
    -----------
    D : jnp.ndarray
        Current decoder matrix
    E_arr : list[jnp.ndarray]
        Encoder matrices the decoder should learn to understand
    learning_rate : float
        Step size for the ascent step
    v : jnp.ndarray, optional
        Per-encoder weights forwarded to the gradient computation

    Returns:
    --------
    D_new : jnp.ndarray
        Updated decoder: stepped, clipped at zero, and column-normalized
    """
    ascent_direction = info_grad(D, E_arr, v)

    # Move uphill, then remove any negative entries the step produced
    stepped = D + learning_rate * ascent_direction
    non_negative = jnp.maximum(stepped, 0.0)

    # Renormalize the columns so each is a probability distribution again
    return normalize_decoder(non_negative)

# Additional utilities for the full system

def initialize_agents(N, S, M, key=None):
    """
    Create N agents, each with a random decoder and its matching optimal encoder.

    Parameters:
    -----------
    N : int
        Number of agents
    S : int
        Size of shared symbol space
    M : int
        Size of message space
    key : jax.random.PRNGKey, optional
        Random key for initialization; PRNGKey(0) when omitted

    Returns:
    --------
    decoders : list of jnp.ndarray
        One column-normalized (M, S) decoder per agent
    encoders : list of jnp.ndarray
        The optimal encoder derived from each agent's decoder
    """
    rng = jax.random.PRNGKey(0) if key is None else key

    decoders, encoders = [], []
    for _ in range(N):
        rng, draw_key = jax.random.split(rng)

        # Uniform random entries, normalized column-wise into a decoder
        decoder = normalize_decoder(jax.random.uniform(draw_key, (M, S)))
        decoders.append(decoder)

        # The encoder is derived deterministically from the decoder
        encoders.append(gen_optimal_encoder(decoder))

    return decoders, encoders

def _pairwise_mutual_information(decoders, encoders):
    """Return the (N, N) matrix with entry [i, j] = MI(encoders[j], decoders[i]); the diagonal stays 0."""
    n_agents = len(decoders)
    mi_matrix = jnp.zeros((n_agents, n_agents))
    for i in range(n_agents):
        for j in range(n_agents):
            if i != j:
                mi_matrix = mi_matrix.at[i, j].set(
                    calculate_mutual_information_jax(encoders[j], decoders[i])
                )
    return mi_matrix


def simulate_convergence(
    N, S, M, num_iterations=100, learning_rate=0.01, key=None, 
    home_planet=None, gossip=None):
    """
    Simulate the convergence of the multi-agent language system.

    Every iteration each agent takes one gradient step on its decoder to
    better understand the other agents' current encoders, then re-derives
    its own optimal encoder. All agents are updated synchronously from the
    previous iteration's encoders.

    Parameters:
    -----------
    N : int
        Number of agents
    S : int
        Size of shared symbol space
    M : int
        Size of message space
    num_iterations : int
        Number of update iterations
    learning_rate : float
        Learning rate for decoder updates
    key : jax.random.PRNGKey
        Random key for initialization
    home_planet : int, float or None
        If not None, each agent also optimizes against its own initial
        encoder. An int/float value is used as the (pre-normalization)
        weight of that extra encoder; any other non-None value leaves the
        weight at 1.
    gossip : float or None
        If not None, the weights of the N-1 other agents' encoders are set
        to gossip times the column sums of G = construct_matrix_G(mi_matrix),
        with the agent's own column removed.

    Returns:
    --------
    history : dict
        Keys 'mutual_info', 'decoders' and 'encoders', each a list with the
        initial state followed by one entry per iteration

    Raises:
    -------
    ValueError
        If the encoder weights for an agent sum to zero and therefore
        cannot be normalized (e.g. gossip=0 without a home planet)
    """
    # Initialize agents
    decoders, encoders = initialize_agents(N, S, M, key)

    # Each agent's original language, used for the home-planet term
    home_encoders = encoders

    mi_matrix = _pairwise_mutual_information(decoders, encoders)

    # Track history
    history = {
        'mutual_info': [mi_matrix],
        'decoders': [decoders],
        'encoders': [encoders]
    }

    for _ in tqdm(range(num_iterations)):
        new_decoders = []
        new_encoders = []

        if gossip is not None:
            # G is derived from the MI matrix of the previous iteration
            G = construct_matrix_G(mi_matrix)

        for i in range(N):
            # Get encoders from all other agents
            other_encoders = [encoders[j] for j in range(N) if j != i]

            if home_planet is not None:
                # Also optimize for mutual info with original language
                other_encoders.append(home_encoders[i])

            v = np.ones(len(other_encoders))

            if isinstance(home_planet, (int, float)):
                v[-1] = home_planet  # Weight according to passed-in value

            if gossip is not None:
                # Column sums of G over the other agents (column i dropped)
                v[:N-1] = gossip * G[:, jnp.arange(N) != i].sum(0)

            total_weight = v.sum()
            if total_weight == 0:
                # Dividing would yield NaN weights and silently poison
                # every decoder from here on
                raise ValueError(
                    f"Encoder weights for agent {i} sum to zero; "
                    "cannot normalize (check home_planet / gossip)"
                )
            v /= total_weight

            # Update decoder
            D_new = update(decoders[i], other_encoders, learning_rate, v=v)
            new_decoders.append(D_new)

            # Compute new optimal encoder
            new_encoders.append(gen_optimal_encoder(D_new))

        decoders = new_decoders
        encoders = new_encoders

        # Record mutual information matrix for the updated population
        mi_matrix = _pairwise_mutual_information(decoders, encoders)

        history['mutual_info'].append(mi_matrix)
        history['decoders'].append(decoders)
        history['encoders'].append(encoders)

    return history

In [None]:

# Example usage and testing
# Set parameters
N = 10 # Number of agents
S = 100 # Symbol space size
M = 100 # Message space size
home_planet = None
# The simulation below was previously called with a hard-coded gossip=1 while
# this variable was None; the variable is now the single source of truth so
# that anything labelled with `gossip` (e.g. output filenames) matches the run.
gossip = 1
learning_rate=0.1
iterations = 1000

# Initialize
key = jax.random.PRNGKey(42)
decoders, encoders = initialize_agents(N, S, M, key)

# Test gradient computation: agent 0's decoder against all other agents' encoders
D_test = decoders[0]
E_test = encoders[1:]

# Compute gradient
grad_D = info_grad(D_test, E_test)
print(f"Gradient shape: {grad_D.shape}")
print(f"Gradient norm: {jnp.linalg.norm(grad_D):.4f}")

# Test update
D_new = update(D_test, E_test, learning_rate=learning_rate)
print(f"Decoder still normalized: {jnp.allclose(jnp.sum(D_new, axis=0), 1.0)}")

# Run the simulation with the parameters declared above
history = simulate_convergence(N, S, M, num_iterations=iterations, learning_rate=learning_rate, key=key, 
    home_planet=home_planet, gossip=gossip)

# NOTE: these means run over the full N x N matrix, including the zero diagonal
print(f"\nInitial average mutual information: {jnp.mean(history['mutual_info'][0]):.4f}")
print(f"Final average mutual information: {jnp.mean(history['mutual_info'][-1]):.4f}")

In [None]:
# Plot the mean pairwise mutual information over the course of the single run
plt.figure(figsize=(6, 4))

# Mean MI per recorded iteration, averaged over strictly positive entries only
# (the diagonal of every recorded MI matrix is zero, so this filter drops it)
mean_mi = []

for mi_matrix in history['mutual_info']:
    positive_entries = mi_matrix[mi_matrix > 0]
    mean_mi.append(jnp.mean(positive_entries))

iterations_range = range(len(mean_mi))

# Mean curve
plt.plot(iterations_range, mean_mi, 'blue', linewidth=2, label='Mean Mutual Information')

# Axis labels, title and cosmetics
plt.xlabel('Iteration')
plt.ylabel('Mutual Information')
plt.title(f'Convergence of Multi-Agent Language System\nN={N}, S={S}, M={M}, Learning Rate={learning_rate}')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()

# Write the figure to disk; the filename encodes the run parameters
figure_name = f"Single_simulation_S{S}_M{M}_N{N}_I{iterations}_LR{learning_rate}_HP{home_planet}_G{gossip}.png"
plt.savefig(figure_name, dpi=300, bbox_inches='tight')
plt.show()

# Summary statistics: first vs. last recorded mean
initial_mean = mean_mi[0]
final_mean = mean_mi[-1]
print(f"Initial mean MI: {initial_mean:.4f}")
print(f"Final mean MI: {final_mean:.4f}")
print(f"Total improvement: {final_mean - initial_mean:.4f}")
print(f"Percentage improvement: {(final_mean - initial_mean)/initial_mean*100:.1f}%")

In [None]:
# Home-planet sweep: run the same simulation under several home-planet weights
# Set parameters
N = 10 # Number of agents
S = 100 # Symbol space size
M = 100 # Message space size
home_planet = 1  # not read by the sweep below, which uses home_planet_values
gossip = 0  # not read by the sweep below; simulate_convergence gets its default
learning_rate=0.05
iterations = 100

# Home-planet weights to compare, and one plot color per weight
home_planet_values = [0, 1, 10, 100, 1000]
colors = ['blue', 'orange', 'green', 'red', 'purple']

# Results keyed by home-planet weight
all_results = {}

for hp_value in home_planet_values:
    print(f"\nRunning simulation with home_planet = {hp_value}")

    # Same seed for every run so all runs start from identical agents
    sim_key = jax.random.PRNGKey(42)

    # A weight of 0 is mapped to None, i.e. the home-planet term is disabled
    history_hp = simulate_convergence(
        N, S, M,
        num_iterations=iterations,
        learning_rate=learning_rate,
        key=sim_key,
        home_planet=hp_value if hp_value > 0 else None,
    )

    # Per-iteration mean and quartiles of the pairwise MI values
    mean_mi_hp, p25_hp, p75_hp = [], [], []

    for mi_matrix in history_hp['mutual_info']:
        # Keep strictly positive, off-diagonal entries
        off_diagonal = ~jnp.eye(mi_matrix.shape[0], dtype=bool)
        mask = (mi_matrix > 0) & off_diagonal
        valid_mi = mi_matrix[mask]

        mean_mi_hp.append(jnp.mean(valid_mi))
        p25_hp.append(jnp.percentile(valid_mi, 25))
        p75_hp.append(jnp.percentile(valid_mi, 75))

    all_results[hp_value] = {
        'mean_mi': mean_mi_hp,
        'p25': p25_hp,
        'p75': p75_hp,
        'history': history_hp
    }

    print(f"Final mean MI: {mean_mi_hp[-1]:.4f}")

In [None]:
from matplotlib import pyplot as plt


# Comparison plot: one mean-MI curve (with quartile band) per home-planet weight
plt.figure(figsize=(6, 4))

for i, hp_value in enumerate(home_planet_values):
    run = all_results[hp_value]
    mean_mi_data = run['mean_mi']
    p25_data = run['p25']
    p75_data = run['p75']

    # Share of the total weight given to the home encoder, as a whole percent
    # (N - 1 other agents each carry weight 1)
    hp_pct = int(hp_value / (N - 1 + hp_value) * 100)

    if hp_value > 0:
        label = f'Home Planet Bias = {hp_pct}%'
    else:
        label = 'No Home Planet'
    iterations_range = range(len(mean_mi_data))
    curve_color = colors[i]

    # Mean curve
    plt.plot(iterations_range, mean_mi_data, 
             color=curve_color, linewidth=2, label=label)

    # Band between the 25th and 75th percentiles
    plt.fill_between(iterations_range, p25_data, p75_data, 
                     color=curve_color, alpha=0.2)

plt.xlabel('Iteration')
plt.ylabel('Mean Mutual Information')
plt.title('Convergence Comparison: Effect of Home Planet Bias\n(Shaded regions show 25th-75th percentiles)')
plt.legend()
plt.grid(True, alpha=0.3)
comparison_figure_name = f"Convergence_comparison_S{S}_M{M}_N{N}_I{iterations}_LR{learning_rate}.png"
plt.savefig(comparison_figure_name, dpi=300, bbox_inches='tight')
plt.show()

# Summary table: initial vs. final mean MI for every home-planet weight
banner = "="*60
print("\n" + banner)
print("SUMMARY STATISTICS")
print(banner)
for hp_value in home_planet_values:
    mean_mi_data = all_results[hp_value]['mean_mi']
    initial_mi = mean_mi_data[0]
    final_mi = mean_mi_data[-1]
    improvement = final_mi - initial_mi

    if hp_value > 0:
        label = f"HP={hp_value}"
    else:
        label = "No HP"
    print(f"{label:8s}: Initial={initial_mi:.4f}, Final={final_mi:.4f}, "
          f"Improvement={improvement:.4f} ({improvement/initial_mi*100:.1f}%)")