# Toy Experiment: Class Label Latents

**Goal**: Test whether loss functions can prevent speckles while encoding information.

**Setup**:
- 100 learnable 32×32×3 RGB latents (one per class)
- Tiny MLP decoder: latent → class prediction
- Same regularization losses as real training (InfoNCE, batch InfoNCE, magnitude), plus a cross-entropy classification loss

**Question**: Can we encode 100 classes smoothly, or do speckles emerge?

In [None]:
import sys
sys.path.insert(0, '../phase1')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Import loss functions from phase1
from infonce_losses import InfoNCEPatchLoss, MagnitudeLoss
from batch_infonce_loss import BatchInfoNCELoss

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Device: {device}')

## 1. Setup: Learnable Latents and Decoder

In [None]:
# Configuration
num_classes = 100
latent_h = 32
latent_w = 32
latent_c = 3

# Learnable latents (one per class), laid out as [num_classes, H, W, C].
# Initialized with small random values (standard normal scaled by 0.1).
# The tensor is moved to `device` *before* being wrapped in nn.Parameter, so the
# Parameter itself lives on `device`. Wrapping a CPU tensor and calling
# `latents.to(device)` later would, on CUDA, produce a separate non-leaf copy
# that the optimizer never updates.
latents = nn.Parameter(
    (torch.randn(num_classes, latent_h, latent_w, latent_c) * 0.1).to(device)
)
print(f'Latents shape: {latents.shape}')
print(f'Latents parameters: {latents.numel():,}')

In [None]:
# Tiny MLP decoder: flattened latent -> ReLU hidden layer -> class logits
class TinyClassifier(nn.Module):
    """Two-layer MLP mapping a batch of latent images to class logits."""

    def __init__(self, input_dim=32*32*3, hidden_dim=256, num_classes=100):
        super().__init__()
        # fc1 is constructed (and randomly initialized) before fc2.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, latents_bhwc):
        """Classify a batch of latent images.

        Args:
            latents_bhwc: [B, H, W, C] latent images; H*W*C must equal the
                input_dim the module was built with.
        Returns:
            logits: [B, num_classes] unnormalized class scores.
        """
        flat = latents_bhwc.reshape(latents_bhwc.shape[0], -1)
        hidden = self.fc1(flat).relu()
        return self.fc2(hidden)

# Size the decoder from the latent configuration above so that changing
# num_classes or the latent H/W/C keeps input and output layers in sync.
decoder = TinyClassifier(
    input_dim=latent_h * latent_w * latent_c,
    num_classes=num_classes,
).to(device)
print(f'\nDecoder parameters: {sum(p.numel() for p in decoder.parameters()):,}')

## 2. Loss Functions

In [None]:
# Loss weights and hyperparameters, read by the loss constructors and the
# training loop.
config = dict(
    # Loss weights
    lambda_class=1.0,            # Classification (must encode class ID)
    lambda_infonce=2.0,          # Spatial smoothness within latent
    lambda_batch_infonce=1.0,    # Different classes look different
    lambda_magnitude=5.0,        # Prevent collapse

    # InfoNCE (within-latent) parameters
    patch_size=3,
    num_samples=25,
    temperature=1.0,
    positive_radius=3.0,
    negative_radius=11.0,

    # Batch InfoNCE (across-latent) parameters
    batch_infonce_temperature=0.5,
    batch_infonce_cross_image_radius=2.0,
    batch_infonce_num_cross_images=8,

    # Magnitude
    min_magnitude=0.3,
)

# Build every loss module from `config` (constructed in the same order as before).
class_loss_fn = nn.CrossEntropyLoss()

# Within-latent contrastive loss
infonce_kwargs = dict(
    patch_size=config['patch_size'],
    num_samples=config['num_samples'],
    temperature=config['temperature'],
    positive_radius=config['positive_radius'],
    negative_radius=config['negative_radius'],
)
infonce_loss_fn = InfoNCEPatchLoss(**infonce_kwargs).to(device)

# Across-latent contrastive loss (uses the batch_infonce_* settings)
batch_infonce_kwargs = dict(
    patch_size=config['patch_size'],
    num_samples=config['num_samples'],
    temperature=config['batch_infonce_temperature'],
    cross_image_radius=config['batch_infonce_cross_image_radius'],
    num_cross_images=config['batch_infonce_num_cross_images'],
)
batch_infonce_loss_fn = BatchInfoNCELoss(**batch_infonce_kwargs).to(device)

magnitude_loss_fn = MagnitudeLoss(min_magnitude=config['min_magnitude']).to(device)

print('Loss functions initialized')
for weight_name in ('lambda_class', 'lambda_infonce', 'lambda_batch_infonce', 'lambda_magnitude'):
    print(f"  {weight_name}: {config[weight_name]}")

## 3. Training Loop

In [None]:
# Optimizer: two parameter groups so the per-class latents take larger steps
# (lr 1e-2) than the decoder weights (lr 1e-3). All other AdamW settings,
# including its default decoupled weight decay, apply to both groups.
latent_group = {'params': [latents], 'lr': 1e-2}
decoder_group = {'params': decoder.parameters(), 'lr': 1e-3}
optimizer = AdamW([latent_group, decoder_group])

# Training settings
num_steps, batch_size, log_interval = 2000, 32, 100

# Per-step metric traces that the training loop appends to (one fresh list per key)
history = {
    metric: []
    for metric in ('total', 'class', 'infonce', 'batch_infonce', 'magnitude', 'accuracy')
}

print(f'Training for {num_steps} steps with batch size {batch_size}...')
print()

In [None]:
# Training loop
# Move the latent Parameter onto `device` *in place*. On CUDA,
# `latents.to(device)` returns a separate copy. The optimizer keeps updating the
# original CPU Parameter, so that copy (used for training and for the
# visualizations below) would stay frozen at its initial values. Reassigning
# `.data` keeps the exact Parameter object registered with the optimizer and
# only changes where its storage lives. It is a no-op if it is already there.
latents.data = latents.data.to(device)
latents_param = latents  # name used by the visualization/analysis cells below

pbar = tqdm(range(num_steps))
for step in pbar:
    optimizer.zero_grad()

    # Sample random batch of class IDs (with replacement)
    class_ids = torch.randint(0, num_classes, (batch_size,), device=device)

    # Look up corresponding latents: [batch_size, H, W, C]
    batch_latents = latents_param[class_ids]

    # Decode to class predictions: [batch_size, num_classes]
    logits = decoder(batch_latents)

    # 1. Classification loss (latent must encode its class ID)
    loss_class = class_loss_fn(logits, class_ids)

    # 2. InfoNCE loss (spatial smoothness within each latent)
    loss_infonce = infonce_loss_fn(batch_latents)

    # 3. Batch InfoNCE loss (different classes look different); needs >1 latent
    if batch_size > 1:
        loss_batch_infonce = batch_infonce_loss_fn(batch_latents)
    else:
        loss_batch_infonce = torch.tensor(0.0, device=device)

    # 4. Magnitude loss
    loss_magnitude = magnitude_loss_fn(batch_latents)

    # Total loss (weights are re-read from config every step)
    total_loss = (
        config['lambda_class'] * loss_class +
        config['lambda_infonce'] * loss_infonce +
        config['lambda_batch_infonce'] * loss_batch_infonce +
        config['lambda_magnitude'] * loss_magnitude
    )

    # Backward
    total_loss.backward()
    optimizer.step()

    # Batch accuracy, from the logits computed before this optimizer step
    with torch.no_grad():
        accuracy = (logits.argmax(dim=1) == class_ids).float().mean()

    # Convert each scalar to a Python float once (.item() copies to the host)
    metrics = {
        'total': total_loss.item(),
        'class': loss_class.item(),
        'infonce': loss_infonce.item(),
        'batch_infonce': loss_batch_infonce.item(),
        'magnitude': loss_magnitude.item(),
        'accuracy': accuracy.item(),
    }
    for name, value in metrics.items():
        history[name].append(value)

    # Update progress bar
    if step % log_interval == 0:
        pbar.set_postfix({
            'loss': f"{metrics['total']:.3f}",
            'acc': f"{metrics['accuracy']:.3f}",
            'class': f"{metrics['class']:.3f}",
            'info': f"{metrics['infonce']:.3f}",
            'batch': f"{metrics['batch_infonce']:.3f}",
            'mag': f"{metrics['magnitude']:.3f}"
        })

print('\nTraining complete!')

## 4. Visualization: Loss Curves

In [None]:
# One panel per tracked metric: (row, column, history key, panel title)
loss_panels = [
    (0, 0, 'total', 'Total Loss'),
    (0, 1, 'accuracy', 'Classification Accuracy'),
    (0, 2, 'class', 'Classification Loss'),
    (1, 0, 'infonce', 'InfoNCE Loss (smoothness)'),
    (1, 1, 'batch_infonce', 'Batch InfoNCE Loss (diversity)'),
    (1, 2, 'magnitude', 'Magnitude Loss'),
]

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
for r, c, metric_key, panel_title in loss_panels:
    panel = axes[r, c]
    panel.plot(history[metric_key])
    panel.set_title(panel_title)
    panel.set_xlabel('Step')
    if metric_key == 'accuracy':
        panel.set_ylim([0, 1.05])  # accuracy is a fraction; show the full range
    panel.grid(True)

plt.tight_layout()
plt.show()

print(f'Final accuracy: {history["accuracy"][-1]:.3f}')

## 5. Visualization: Learned Latents

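Show all 100 learned latent images and look for speckles. Each tile is min-max normalized independently, so brightness and contrast are not comparable across tiles.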

In [None]:
def latent_to_rgb(latent_tensor):
    """Return a NumPy copy of an [H, W, 3] tensor rescaled to [0, 1] for imshow.

    Min-max normalization is applied per call, so each latent is stretched to
    its own full range. A constant tensor maps to all zeros (the 1e-8 term
    avoids division by zero).
    """
    pixels = latent_tensor.detach().cpu().numpy()
    lo, hi = pixels.min(), pixels.max()
    stretched = (pixels - lo) / (hi - lo + 1e-8)
    return np.clip(stretched, 0, 1)

# Show grid of all latents, 10 per row. The row count is derived from
# num_classes, so the grid still fits if the number of classes changes.
num_cols = 10
num_rows = -(-num_classes // num_cols)  # ceiling division

# squeeze=False keeps `axes` 2-D even when there is only a single row
fig, axes = plt.subplots(num_rows, num_cols, figsize=(2 * num_cols, 2 * num_rows), squeeze=False)

for i, ax in enumerate(axes.flat):
    if i < num_classes:
        ax.imshow(latent_to_rgb(latents_param[i]))
        ax.set_title(f'{i}', fontsize=8)
    ax.axis('off')  # hide ticks/frame; also blanks spare cells in the last row

plt.suptitle(f'Learned Latents for {num_classes} Classes', fontsize=16)
plt.tight_layout()
plt.show()

print('\nExamine the images above:')
print('  - Are they smooth or speckled?')
print('  - Do different classes look different?')
print('  - Is there visible structure?')

## 6. Analysis: Speckle Detection

Quantify high-frequency content (speckles) in the latents, measured as the mean absolute difference between neighboring pixels.

In [None]:
def compute_high_freq_energy(latent):
    """Return the mean absolute finite difference of one latent image.

    Averages the neighbor differences along dim 0 (vertical) and dim 1
    (horizontal). Larger values mean more high-frequency content (speckles).
    Expects a single [H, W, C] latent, i.e. one entry of the per-class latents.
    Returns a 0-dim tensor.
    """
    vertical = torch.diff(latent, dim=0).abs().mean()
    horizontal = torch.diff(latent, dim=1).abs().mean()
    return (vertical + horizontal) / 2

# Score every learned latent with the finite-difference speckle metric
with torch.no_grad():
    high_freq_energies = [
        compute_high_freq_energy(latents_param[class_idx]).item()
        for class_idx in range(num_classes)
    ]
energy_mean = np.mean(high_freq_energies)
energy_std = np.std(high_freq_energies)

plt.figure(figsize=(10, 4))
plt.hist(high_freq_energies, bins=30)
plt.xlabel('High-Frequency Energy (gradient magnitude)')
plt.ylabel('Count')
plt.title('Distribution of High-Frequency Content Across Latents')
plt.axvline(energy_mean, color='r', linestyle='--', label=f'Mean: {energy_mean:.4f}')
plt.legend()
plt.grid(True)
plt.show()

print(f'Average high-frequency energy: {energy_mean:.4f}')
print(f'Std dev: {energy_std:.4f}')
print('\nHigher values = more speckles/high-frequency content')

## 7. Experiment: Test Different Loss Configurations

Compare results across different loss settings to see what prevents speckles.

In [None]:
# Save current results so later configurations can be compared against them
baseline_latents = latents_param.detach().clone()
baseline_energy = np.mean(high_freq_energies)

# Accuracy over all num_classes latents. The last entry of history['accuracy']
# comes from a single training batch of batch_size IDs drawn with replacement,
# so it is noisy and need not cover every class.
with torch.no_grad():
    all_class_ids = torch.arange(num_classes, device=device)
    all_logits = decoder(baseline_latents.to(device))
    baseline_accuracy = (all_logits.argmax(dim=1) == all_class_ids).float().mean().item()

print('Baseline (current) results saved:')
print(f'  High-freq energy: {baseline_energy:.4f}')
print(f'  Final accuracy: {baseline_accuracy:.3f}')
print(f'  Config: lambda_infonce={config["lambda_infonce"]}, positive_radius={config["positive_radius"]}')

### Experiment: Higher InfoNCE Weight

Test whether increasing `lambda_infonce` reduces speckles.

In [None]:
# To try a different smoothness weight, change config['lambda_infonce'] and train
# again from a fresh start. Re-running only the training-loop cell would resume
# from the already-trained latents, decoder and optimizer, and would keep
# appending to `history` (which is created in the optimizer/settings cell). The
# comparison with the saved baseline would then not be like-for-like.
print('To experiment:')
print('1. Modify config["lambda_infonce"] (try 5.0, 10.0)')
print('2. Re-run the latents, decoder and optimizer/history setup cells (fresh start)')
print('3. Re-run training loop')
print('4. Compare learned latents')
print('5. Compare high-frequency energy')

## 8. Future: Test L2 Distance vs Cosine Similarity

Modify InfoNCEPatchLoss to use L2 distance instead of cosine similarity.
This would require editing the loss function to avoid normalization.

In [None]:
# Planned follow-up (not implemented in this notebook): an L2-distance InfoNCE variant
for note in (
    'Next step: Implement L2-based InfoNCE variant',
    'Expected: L2 distance constrains both direction AND magnitude',
    'Hypothesis: This should prevent magnitude-based speckles',
):
    print(note)