# Language Emergence Lab Demo

This notebook demonstrates how artificial agents learn to communicate about objects through referential games.

## What is Language Emergence?

The Language Emergence Lab studies how communication protocols emerge from interaction between neural agents:

1. **Speaker-Listener Architecture**: Neural agents that learn to communicate about visual objects
2. **Discrete Communication**: Agents use discrete tokens (like words) for communication  
3. **Emergent Languages**: Communication protocols develop through interaction and reinforcement
4. **Language Analysis**: We can analyze the systematicity and efficiency of emergent languages

**Key Features:**
- Multi-agent referential games with visual objects
- Contrastive learning achieving 80%+ accuracy vs ~20% baseline
- Population dynamics and cultural transmission studies
- Comprehensive analysis of emergent language patterns
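
To make the game structure concrete, here is a minimal sketch of a referential game that uses only the Python standard library. It is not the `langlab` implementation: the `(color, shape)` toy world, the hand-written speaker and listener functions, and the choice of `k=5` candidates are all assumptions made for illustration. It shows that a listener who receives an uninformative message succeeds about `1/k` of the time, which would correspond to a ~20% baseline if scenes contain 5 candidates.

```python
import random

# Hypothetical toy world: each object is a (color, shape) pair.
COLORS = ["red", "green", "blue"]
SHAPES = ["circle", "square", "triangle"]
OBJECTS = [(c, s) for c in COLORS for s in SHAPES]


def play_round(speak, listen, k, rng):
    """Play one round; return True if the listener picks the target."""
    scene = rng.sample(OBJECTS, k)      # k distinct candidate objects
    target_idx = rng.randrange(k)       # the candidate the speaker must describe
    message = speak(scene[target_idx])  # the speaker sees only the target
    guess = listen(message, scene)      # the listener sees message + candidates
    return guess == target_idx


def success_rate(speak, listen, k, rounds=5000, seed=0):
    """Fraction of rounds in which the listener identifies the target."""
    rng = random.Random(seed)
    return sum(play_round(speak, listen, k, rng) for _ in range(rounds)) / rounds


# Uninformative protocol: constant message, listener guesses uniformly.
guess_rng = random.Random(1)
silent_speaker = lambda obj: 0
random_listener = lambda msg, scene: guess_rng.randrange(len(scene))

# Hand-written protocol: one token per object, listener looks the object up.
lookup_speaker = lambda obj: OBJECTS.index(obj)
lookup_listener = lambda msg, scene: scene.index(OBJECTS[msg])

chance_k5 = success_rate(silent_speaker, random_listener, k=5)
perfect_k5 = success_rate(lookup_speaker, lookup_listener, k=5)
print(f"chance (k=5): {chance_k5:.2f}, lookup protocol: {perfect_k5:.2f}")
```

The neural agents in this notebook have to discover something like the lookup protocol on their own, since no token meanings are given to them in advance.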

## Setup and Installation

In [None]:
! [ ! -d "emergent" ] && git clone https://github.com/bangyen/emergent.git
! cd emergent && pip install -e .

print("Setup complete!")

In [None]:
import os

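# Only change directory once, so re-running this cell does not raise an error
if os.path.basename(os.getcwd()) != "emergent":
    os.chdir("./emergent")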
print(f"Current working directory: {os.getcwd()}")

## Imports and Configuration

In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import the Language Emergence Lab components
from langlab import Speaker, Listener, sample_scene, CommunicationConfig

# Set device and random seeds for reproducibility
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)

In [None]:

# Configuration for our referential game
config = CommunicationConfig(
    vocabulary_size=10,      # Size of the message vocabulary
    message_length=1,        # Length of messages (1 token for simplicity)
    hidden_size=64,          # Hidden layer size
)

# Learning parameters
learning_rate = 0.001
batch_size = 32  # Defined for reference; the demo loop below trains on one scene per step

print("Configuration:")
print(f"Vocabulary size: {config.vocabulary_size}")
print(f"Message length: {config.message_length}")
print(f"Hidden size: {config.hidden_size}")
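
The two message settings above bound how many distinct messages the Speaker can send: `vocabulary_size ** message_length`. The sketch below enumerates that space with the standard library only; the helper name `message_space` is hypothetical and is not part of `langlab`.

```python
from itertools import product


def message_space(vocabulary_size, message_length):
    """Enumerate every possible message as a tuple of token ids."""
    return list(product(range(vocabulary_size), repeat=message_length))


single = message_space(vocabulary_size=10, message_length=1)
double = message_space(vocabulary_size=10, message_length=2)

print(len(single))  # 10 messages with one token from a 10-token vocabulary
print(len(double))  # 100 messages once a second token is allowed
```

With the configuration used here, the Speaker has 10 messages available. If the world contains more than 10 distinct objects, some objects must share a token, and the Listener can only tell them apart when the other candidates in the scene rule out the alternatives.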

## Quick Demo: Language Emergence

Let's run a simple referential game where Speaker and Listener agents learn to communicate about objects:

In [None]:
# Initialize Speaker and Listener agents
speaker = Speaker(config).to(device)
listener = Listener(config).to(device)

# Set up optimizers
speaker_optimizer = torch.optim.Adam(speaker.parameters(), lr=learning_rate)
listener_optimizer = torch.optim.Adam(listener.parameters(), lr=learning_rate)

print("Agents initialized!")
print(f"Speaker parameters: {sum(p.numel() for p in speaker.parameters()):,}")
print(f"Listener parameters: {sum(p.numel() for p in listener.parameters()):,}")

# Sample a scene to see what objects look like
scene_objects, target_idx = sample_scene(k=3, seed=42)
print(f"\nSample scene with {len(scene_objects)} objects:")
for i, obj in enumerate(scene_objects):
    marker = " ← TARGET" if i == target_idx else ""
    print(f"  Object {i}: {obj}{marker}")

In [None]:
# Simple training loop for referential game
from langlab.data.world import encode_object

def train_step(speaker, listener, scene_objects, target_idx, speaker_opt, listener_opt):
    """Single training step for the referential game."""
    speaker.train()
    listener.train()
    
    # Encode the scene objects
    scene_tensor = torch.stack([encode_object(obj) for obj in scene_objects]).to(device)
    target_tensor = torch.tensor([target_idx], dtype=torch.long).to(device)  # Add batch dimension
    
    # Speaker generates message about target object
    target_object = scene_tensor[target_idx:target_idx+1]
    message_logits, message_tokens, gesture_logits, gesture_tokens = speaker(target_object)
    
    # Add batch dimension to scene_tensor for Listener (batch_size=1, num_candidates=3, object_dim)
    scene_tensor_batched = scene_tensor.unsqueeze(0)  # Shape: (1, 3, object_dim)
    
    # Listener tries to identify target from message and scene
    listener_logits = listener(message_tokens, scene_tensor_batched)
    
    # Compute loss (cross-entropy)
    loss = torch.nn.functional.cross_entropy(listener_logits, target_tensor)
    
    # Backpropagation
    speaker_opt.zero_grad()
    listener_opt.zero_grad()
    loss.backward()
    speaker_opt.step()
    listener_opt.step()
    
    return loss.item()

# Train for a few steps
print("Training agents...")
losses = []
accuracies = []

for step in tqdm(range(100), desc="Training"):
    # Sample a new scene each step
    scene_objects, target_idx = sample_scene(k=3, seed=step)
    
    # Training step
    loss = train_step(speaker, listener, scene_objects, target_idx, 
                     speaker_optimizer, listener_optimizer)
    losses.append(loss)
    
    # Evaluate every 10 steps on 50 held-out scenes (seeds outside the training range).
    # Accuracy measured on a single scene can only be 0% or 100%.
    if step % 10 == 0:
        speaker.eval()
        listener.eval()
        correct = 0
        with torch.no_grad():
            for eval_seed in range(10_000, 10_050):
                eval_objects, eval_target = sample_scene(k=3, seed=eval_seed)
                scene_tensor = torch.stack([encode_object(obj) for obj in eval_objects]).to(device)
                target_object = scene_tensor[eval_target:eval_target+1]
                _, message_tokens, _, _ = speaker(target_object)
                # Add batch dimension to scene_tensor for Listener
                scene_tensor_batched = scene_tensor.unsqueeze(0)  # Shape: (1, 3, object_dim)
                listener_logits = listener(message_tokens, scene_tensor_batched)
                correct += int(torch.argmax(listener_logits, dim=1).item() == eval_target)
        accuracies.append(correct / 50)

print("\nTraining complete!")
print(f"Final loss: {losses[-1]:.4f}")
print(f"Final accuracy: {accuracies[-1]:.2%}")

In [None]:
# Visualize training progress
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# Plot raw loss
ax.plot(losses, alpha=0.3, color='lightblue', label='Raw Loss')

# Calculate and plot smoothed loss
window_size = 10
if len(losses) >= window_size:
    smoothed_losses = []
    for i in range(len(losses)):
        start_idx = max(0, i - window_size + 1)
        smoothed_loss = sum(losses[start_idx:i+1]) / (i - start_idx + 1)
        smoothed_losses.append(smoothed_loss)
    
    ax.plot(smoothed_losses, color='blue', linewidth=2, label=f'Smoothed Loss (window={window_size})')

ax.set_title('Training Loss')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.grid(True)
ax.legend()

plt.tight_layout()
plt.show()

print(f"Random baseline accuracy: {1/3:.2%} (3 objects)")
print(f"Final accuracy: {accuracies[-1]:.2%}")
print(f"Improvement over random: {(accuracies[-1] - 1/3)*100:.1f} percentage points")
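
When reading the loss curve, a useful reference point is the loss of a Listener that guesses uniformly: for `k` candidates the cross-entropy is `ln(k)`, which is about 1.10 for the 3-object scenes used here. The sketch below re-implements cross-entropy with the standard library only (it is not the `torch.nn.functional.cross_entropy` call used in training) to show that reference value; the example logits are made up for illustration.

```python
import math


def cross_entropy(logits, target_idx):
    """Cross-entropy of softmax(logits) against a single correct index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target_idx]


# A listener with no preference among 3 candidates.
uniform_loss = cross_entropy([0.0, 0.0, 0.0], target_idx=1)

# A listener that strongly prefers the correct candidate.
confident_loss = cross_entropy([0.0, 4.0, 0.0], target_idx=1)

print(f"uniform guess: {uniform_loss:.4f}")  # ln(3), about 1.0986
print(f"confident and correct: {confident_loss:.4f}")
```

A smoothed loss that stays near 1.10 suggests the Listener is still guessing, while a loss well below that value suggests the messages carry information about the target.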


## Key Takeaways

This demo shows the core concepts of the Language Emergence Lab:

1. **Multi-Agent Communication**: Speaker and Listener agents learn to communicate about visual objects
2. **Emergent Languages**: Communication protocols develop through repeated interaction; in this demo, both agents are optimized against the Listener's cross-entropy loss
3. **Discrete Messages**: Agents use discrete tokens (like words) for communication
4. **Language Analysis**: We can analyze the systematicity and efficiency of emergent languages
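
As one possible starting point for the analysis mentioned in the last item, the sketch below looks at how tokens are used. It is not `langlab`'s analysis tooling: the `observations` log is hypothetical toy data, and the two helper functions are illustrative names. It computes the entropy of token usage (a rough measure of how much of the vocabulary is in use) and lists tokens that refer to more than one object.

```python
import math
from collections import Counter, defaultdict

# Hypothetical log of (object, token) pairs, as one might collect by
# recording the speaker's token for each target object.
observations = [
    ("red circle", 3), ("red circle", 3),
    ("red square", 3),
    ("blue circle", 7), ("blue circle", 7),
    ("blue square", 1),
]


def token_entropy_bits(pairs):
    """Shannon entropy (in bits) of the token usage distribution."""
    counts = Counter(token for _, token in pairs)
    total = sum(counts.values())
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


def ambiguous_tokens(pairs):
    """Map each token used for more than one distinct object to those objects."""
    meanings = defaultdict(set)
    for obj, token in pairs:
        meanings[token].add(obj)
    return {t: sorted(objs) for t, objs in meanings.items() if len(objs) > 1}


entropy = token_entropy_bits(observations)
shared = ambiguous_tokens(observations)

print(f"token entropy: {entropy:.3f} bits (max for 10 tokens: {math.log2(10):.3f})")
print(f"tokens with several meanings: {shared}")
```

In this toy log, token 3 covers both red objects, which could indicate that it encodes color rather than a single object. Checking patterns like this across many scenes is one informal way to probe whether a protocol is systematic or holistic.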

### Next Steps

- **Full Training**: Use `langlab train --steps 5000 --k 5 --v 10` for longer training
- **Population Studies**: Try `langlab pop-train` for cultural transmission experiments  
- **Interactive Dashboard**: Launch `langlab dash` for visualization
- **Advanced Features**: Explore contrastive learning, sequence models, and grid worlds

### Research Applications

- **Language Evolution**: Study how communication protocols change over time
- **Compositionality**: Investigate systematic vs. holistic language emergence
- **Cultural Transmission**: Model language learning across agent populations
- **Pragmatic Inference**: Explore context-dependent communication strategies

For more information, visit the [GitHub repository](https://github.com/bangyen/emergent).
