# Shape Loss Experiment: Sculpting Objects from Noise

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jtooates/blind_lm/blob/main/experiments/shape_loss_experiment.ipynb)

This notebook demonstrates how the object-forming losses can create shape-like patterns from pure Gaussian noise.

**Key Idea**: We optimize pixel values directly (no neural network) using only loss functions as guidance.

**What you'll see**: Random noise → distinct blob-like shapes with sharp boundaries

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f"PyTorch version: {torch.__version__}")
print(f"Using device: {device}")

## Loss Functions

These are the same losses used in the single-channel training:

In [None]:
def create_shape_losses(latent, threshold=0.5, target_ratio=0.2, temperature=0.1):
    """
    Create losses that encourage shape-like structures.

    Args:
        latent: [B, H, W] tensor to optimize
        threshold: value above which a pixel counts as "bright" for the
            object-size term
        target_ratio: desired fraction of bright pixels (0.2 = 20%)
        temperature: sharpness of the soft threshold used by the object-size
            term; smaller values approach a hard step

    Returns:
        dict of scalar loss tensors with keys 'sparsity', 'binary', 'tv',
        'object_size'

    Raises:
        ValueError: if `latent` is not 3-dimensional.
    """
    if latent.dim() != 3:
        raise ValueError(
            f"expected latent of shape [B, H, W], got {tuple(latent.shape)}"
        )
    losses = {}

    # 1. Sparsity - most pixels should be background (near 0)
    losses['sparsity'] = torch.mean(torch.abs(latent))

    # 2. Binary-ness - pixels should sit near -1 or +1.
    # NOTE: this pulls background toward -1 while the sparsity term pulls it
    # toward 0; the relative weights decide where background values settle.
    distance_from_binary = torch.min(
        (latent - 1.0) ** 2,  # Distance from 1
        (latent + 1.0) ** 2,  # Distance from -1
    )
    losses['binary'] = torch.mean(distance_from_binary)

    # 3. Smoothness within objects (anisotropic Total Variation)
    d_rows = torch.abs(latent[:, 1:, :] - latent[:, :-1, :])  # along H
    d_cols = torch.abs(latent[:, :, 1:] - latent[:, :, :-1])  # along W
    losses['tv'] = torch.mean(d_rows) + torch.mean(d_cols)

    # 4. Object size - encourage ~target_ratio of pixels to be bright.
    # A hard `(latent > threshold).float()` mask has zero gradient everywhere,
    # which would make this term invisible to the optimizer. A sigmoid soft
    # threshold keeps the same meaning but lets gradients flow.
    soft_mask = torch.sigmoid((latent - threshold) / temperature)
    bright_ratio = torch.mean(soft_mask)
    losses['object_size'] = (bright_ratio - target_ratio) ** 2

    return losses


def contrastive_shape_loss(latents):
    """
    Contrastive loss: different latents should look different.

    Computes the mean absolute cosine similarity over all ordered pairs of
    *distinct* samples in the batch (the diagonal is excluded from both the
    sum and the count, so the value does not depend on batch size).

    Args:
        latents: tensor whose first dimension is the batch; all remaining
            dimensions are flattened into one feature vector per sample.

    Returns:
        scalar tensor in [0, 1]. It is 0 when every pair is orthogonal and 1
        when every pair is identical up to sign. Batches with fewer than 2
        samples return 0 on the same device/dtype as `latents`.
    """
    B = latents.shape[0]
    if B < 2:
        return latents.new_zeros(())

    # Flatten and L2-normalize so the Gram matrix holds cosine similarities
    flat_norm = F.normalize(latents.reshape(B, -1), dim=1)
    sim_matrix = flat_norm @ flat_norm.T

    # Penalize similarity between different samples only
    off_diag = 1.0 - torch.eye(B, device=latents.device, dtype=sim_matrix.dtype)
    return (sim_matrix.abs() * off_diag).sum() / (B * (B - 1))

print("✓ Loss functions defined")

## Initialize Random Noise

We start from weak Gaussian noise plus a few random Gaussian blob seeds to break symmetry:

In [None]:
def initialize_latents(batch_size=6, image_size=(32, 32), device='cpu', num_blobs=2):
    """
    Initialize latents as weak Gaussian noise plus a few Gaussian blob seeds.

    Args:
        batch_size: number of latents to create
        image_size: (H, W) of each latent
        device: device the returned tensor is moved to
        num_blobs: number of blob seeds added to each latent (default 2)

    Returns:
        [batch_size, H, W] float tensor that is a leaf with requires_grad=True,
        ready to be handed to an optimizer.
    """
    H, W = image_size

    # Blob centres stay up to 5 px away from the border. For small images the
    # margin shrinks so np.random.randint always gets low < high; for sides
    # >= 11 this is exactly the original 5 px margin.
    margin_y = min(5, (H - 1) // 2)
    margin_x = min(5, (W - 1) // 2)

    # Coordinate axes are the same for every blob; build them once.
    rows = torch.arange(H)
    cols = torch.arange(W)

    latents = []
    for _ in range(batch_size):
        # Start with weak noise
        noise = torch.randn(H, W) * 0.1

        # Add a few random blobs to break symmetry
        for _ in range(num_blobs):
            y = np.random.randint(margin_y, H - margin_y)
            x = np.random.randint(margin_x, W - margin_x)
            size = np.random.randint(3, 7)

            yy, xx = torch.meshgrid(rows - y, cols - x, indexing='ij')
            blob = torch.exp(-(yy**2 + xx**2) / (2 * size**2))
            noise += blob * np.random.uniform(0.5, 1.5)

        latents.append(noise)

    latents = torch.stack(latents).to(device)
    latents.requires_grad_(True)

    return latents

# Initialize
batch_size = 6
latents = initialize_latents(batch_size=batch_size, device=device)

# Visualize initial state. The grid is derived from batch_size (up to 3
# columns), so changing batch_size neither overflows the axes array nor
# leaves empty framed panels.
ncols = max(1, min(3, batch_size))
nrows = max(1, -(-batch_size // ncols))  # ceiling division
fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i >= batch_size:
        ax.axis('off')  # unused panel in the last row
        continue
    ax.imshow(latents[i].detach().cpu().numpy(), cmap='gray', vmin=-1.5, vmax=1.5)
    ax.set_title(f'Initial Noise {i+1}')
    ax.axis('off')

plt.suptitle('Initial Random Noise (with blob seeds)')
plt.tight_layout()
plt.show()

print(f"✓ Initialized {batch_size} latents of shape {latents.shape[1:]}")

## Configure Loss Weights

**Try adjusting these!** Different weights create different types of patterns:

In [None]:
# Loss weights - EXPERIMENT WITH THESE! (dict order is the print order)
weights = dict(
    sparsity=0.5,      # Higher = more black background
    binary=0.3,        # Higher = sharper boundaries (less gray)
    tv=0.1,            # Higher = smoother objects
    object_size=1.0,   # Higher = enforces target size more strongly
    contrastive=2.0,   # Higher = more diverse patterns
)

# Optimization settings: number of Adam steps and step size
num_steps, learning_rate = 300, 0.02

print("Loss weights:")
print("\n".join(f"  {name:15s}: {weight:.2f}" for name, weight in weights.items()))
print(f"\nOptimization: {num_steps} steps @ lr={learning_rate}")

## Run Optimization

Watch the noise transform into shapes!

In [None]:
# Start again from fresh noise so re-running this cell restarts the experiment
latents = initialize_latents(batch_size=batch_size, device=device)

# The raw pixel values are the only parameters being optimized
optimizer = torch.optim.Adam([latents], lr=learning_rate)

# Total (weighted) loss recorded after every step
loss_history = []

print("Starting optimization...\n")

for step in range(num_steps):
    optimizer.zero_grad()

    shape_losses = create_shape_losses(latents)
    contrast_loss = contrastive_shape_loss(latents)

    # Weighted sum of all terms; shape terms absent from `weights` count 0.1,
    # the contrastive term falls back to 1.0.
    total_loss = sum(
        term * weights.get(term_name, 0.1)
        for term_name, term in shape_losses.items()
    )
    total_loss = total_loss + contrast_loss * weights.get('contrastive', 1.0)

    total_loss.backward()
    optimizer.step()

    # Project back into the displayed value range without recording gradients
    with torch.no_grad():
        latents.clamp_(-1.5, 1.5)

    loss_history.append(total_loss.item())

    # Periodic progress report
    if step % 50 != 0:
        continue
    print(f"Step {step:3d}/{num_steps}: Loss = {total_loss.item():.4f}")
    print(f"  sparsity: {shape_losses['sparsity'].item():.4f}")
    print(f"  binary: {shape_losses['binary'].item():.4f}")
    print(f"  tv: {shape_losses['tv'].item():.4f}")
    print(f"  object_size: {shape_losses['object_size'].item():.4f}")
    print(f"  contrastive: {contrast_loss.item():.4f}")
    print()

print("\n✓ Optimization complete!")

## Visualize Results

See the optimized shapes:

In [None]:
# Grid is derived from batch_size (up to 3 columns) so any batch size plots
# cleanly: no IndexError for >6 latents, no empty framed panels for <6.
ncols = max(1, min(3, batch_size))
nrows = max(1, -(-batch_size // ncols))  # ceiling division
fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i >= batch_size:
        ax.axis('off')  # unused panel in the last row
        continue
    img = latents[i].detach().cpu().numpy()
    ax.imshow(img, cmap='gray', vmin=-1.5, vmax=1.5)
    ax.set_title(f'Optimized Shape {i+1}', fontsize=12)
    ax.axis('off')

plt.suptitle('Gaussian Noise → Shapes (via losses only)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("Notice:")
print("  - Distinct blob-like shapes")
print("  - Sharp boundaries between black/white")
print("  - Sparse (mostly black background)")
print("  - Each shape is different (contrastive loss)")

## Plot Loss Curve

In [None]:
plt.figure(figsize=(10, 4))
plt.plot(loss_history, linewidth=2)
plt.xlabel('Optimization Step', fontsize=12)
plt.ylabel('Total Loss', fontsize=12)
plt.title('Loss During Optimization', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Guard the summary: num_steps is user-tunable and may be 0 (empty history),
# and a zero initial loss would make the relative reduction undefined.
if loss_history:
    print(f"Final loss: {loss_history[-1]:.4f}")
    print(f"Initial loss: {loss_history[0]:.4f}")
    if loss_history[0] != 0:
        print(f"Reduction: {(1 - loss_history[-1]/loss_history[0])*100:.1f}%")
else:
    print("No losses recorded (num_steps == 0?)")

## Experiment: Try Different Loss Weights

Go back to the "Configure Loss Weights" cell, change the weights, then re-run that cell and the cells below it. Try:

**More sparsity** (higher `sparsity` weight):
- Creates smaller, sparser objects
- More black background

**More binary** (higher `binary` weight):
- Sharper black/white contrast
- Fewer gray values

**More smoothness** (higher `tv` weight):
- Smoother blob boundaries
- Less texture within objects

**Less contrastive** (lower `contrastive` weight):
- Shapes may become more similar
- Could collapse to identical patterns

**More object_size** (higher `object_size` weight):
- Enforces target size (20%) more strongly
- More consistent object sizes across samples

## Key Takeaway

This experiment shows that **you don't need paired image-text data** to get image-like patterns.

The right combination of losses can **sculpt structure out of noise** purely through optimization pressure.

The text encoder in the full model is trained under the same pressures:
1. It maps text → visual latent space
2. It minimizes these object-forming losses
3. Those losses push its outputs toward interpretable, object-like patterns
4. Different text → different shapes (via contrastive loss)
5. A reconstruction loss (not used in this notebook) ensures information is preserved