# Phase 1: RGB Latent Training with Mumford-Shah Loss on Google Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jtooates/blind_lm/blob/main/phase1_colab_training.ipynb)

This notebook trains a text encoder to produce **3-channel RGB** latents (32×32×3) with piecewise constant regions.

**NEW:** Uses a Mumford-Shah loss to encourage smooth regions with sharp boundaries, plus latent jittering to prevent speckles.

**Goal**: Create colored images with smooth homogeneous regions separated by boundaries.

**Training time**: ~2-3 hours on T4 GPU

---

## Setup Instructions

1. **Runtime → Change runtime type → T4 GPU**
2. Run all cells in order
3. Checkpoints save to Google Drive automatically
4. Results appear as RGB color images!
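Step 3 assumes Google Drive is mounted. When it is not (for example, when the notebook runs outside Colab), `train.py` fails with `FileNotFoundError` while creating `output_dir` under `/content/drive`. Below is a minimal sketch of a preflight check that falls back to a local directory. `resolve_output_dir` and the fallback location are hypothetical helpers, not part of the repository.

```python
import os
import tempfile

def resolve_output_dir(preferred: str, fallback: str) -> str:
    """Return `preferred` if its Drive root is mounted, else `fallback` (hypothetical helper).

    The chosen directory is created up front so a later mkdir in the training script cannot fail.
    """
    drive_root = "/content/drive/MyDrive"
    use_preferred = not preferred.startswith(drive_root) or os.path.isdir(drive_root)
    chosen = preferred if use_preferred else fallback
    os.makedirs(chosen, exist_ok=True)
    return chosen

# Outside Colab the Drive root is absent, so the local fallback is used.
local = os.path.join(tempfile.gettempdir(), "blind_lm_outputs", "phase1_rgb_mumford_shah")
out = resolve_output_dir("/content/drive/MyDrive/blind_lm_outputs/phase1_rgb_mumford_shah", local)
print(out)
```

If you use it, assign the result to `config['output_dir']` in section 3 before the config is written to `phase1/configs/phase1_colab.json`.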

## 1. Environment Setup

In [1]:
# Check GPU availability
import torch
print("="*70)
print("GPU Check")
print("="*70)
print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
else:
    print("⚠️  WARNING: No GPU found! Training will be very slow.")
    print("   Please go to Runtime → Change runtime type → T4 GPU")
print("="*70)



In [2]:
# Mount Google Drive to save checkpoints
from google.colab import drive
drive.mount('/content/drive')

# Create output directory on Drive
!mkdir -p /content/drive/MyDrive/blind_lm_outputs
print("✓ Google Drive mounted")
print("✓ Checkpoints will save to: /content/drive/MyDrive/blind_lm_outputs/")


In [3]:
# Clone or update the repository
import os

repo_dir = 'blind_lm'
repo_url = 'https://github.com/jtooates/blind_lm.git'

if os.path.exists(repo_dir):
    print("Repository already exists. Pulling latest changes...")
    %cd blind_lm
    !git pull origin main
    print("✓ Repository updated to latest version")
else:
    print("Cloning repository...")
    !git clone {repo_url}
    %cd blind_lm
    print("✓ Repository cloned successfully")

print("\n" + "="*70)
print("Code is ready!")
print("="*70)



In [4]:
# Install dependencies
print("Installing dependencies...")
!pip install -q transformers scipy tqdm matplotlib

# Suppress tokenizer warning
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print("✓ Dependencies installed")

Installing dependencies...
✓ Dependencies installed


## 2. Data Preparation

In [5]:
# Check if training data exists, generate if needed
import os

if not os.path.exists('train_sentences.txt'):
    print("Generating training data (10,000 sentences)...")
    !python generate_sentences.py --num 10000 --complexity 1 --seed 42 --output train_sentences.txt
    print("✓ Training data generated")
else:
    print("✓ Training data already exists")

if not os.path.exists('val_sentences.txt'):
    print("Generating validation data (1,000 sentences)...")
    !python generate_sentences.py --num 1000 --complexity 1 --seed 100 --output val_sentences.txt
    print("✓ Validation data generated")
else:
    print("✓ Validation data already exists")

# Show stats
print("\n" + "="*70)
print("Data Statistics")
print("="*70)
!wc -l train_sentences.txt val_sentences.txt

print("\nSample sentences:")
!head -5 train_sentences.txt


## 3. Configuration

In [6]:
# Create Colab-optimized config with Mumford-Shah loss
import json

config = {
    "description": "RGB latent with Mumford-Shah loss for piecewise constant regions",
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "output_dir": "/content/drive/MyDrive/blind_lm_outputs/phase1_rgb_mumford_shah",

    "model": {
        "vocab_size": 50257,
        "max_seq_len": 32,
        "hidden_size": 384,
        "num_layers": 6,
        "num_heads": 8,
        "ffn_size": 1536,
        "dropout": 0.1,
        "grid_size": 64,
        "num_channels": 3,  # RGB (3 channels)
        "use_rope": True,
        "use_smooth_head": False,
        "tokenizer_name": "gpt2"
    },

    "loss": {
        # Loss weights
        "lambda_recon": 1.0,                  # Reconstruction loss weight
        "lambda_magnitude": 1.0,              # Magnitude loss weight (prevent collapse)
        "lambda_mumford_shah": 1.0,           # Mumford-Shah loss (piecewise constant regions)
        
        # Magnitude parameters
        "min_magnitude": 0.3,                 # Minimum magnitude target
        
        # Mumford-Shah parameters
        "mumford_shah_alpha": 0.0,            # Within-region smoothness (L2 term)
        "mumford_shah_beta": 5.0              # Boundary sparsity (L1 term) - 0 for non-isotropic
    },

    "decoder": {
        "vocab_size": 50257,
        "max_seq_len": 32,
        "hidden_size": 384,
        "num_layers": 4,
        "num_heads": 8,
        "ffn_size": 1536,
        "dropout": 0.1,
        "use_rope": True
    },

    "training": {
        "batch_size": 64,  # Reduced for T4 GPU
        "lr": 2e-4,
        "beta1": 0.9,
        "beta2": 0.95,
        "weight_decay": 0.01,
        "warmup_steps": 500,
        "num_epochs": 1000,  # High limit - will stop at max_steps
        "max_steps": 10000,  # Shorter initial training
        "ema_decay": 0.999,
        "grad_clip": 1.0,
        "blur_sigma": 0.8,
        "blur_warmup_steps": 0,
        "jitter_std": 0.001  # Latent jittering to prevent speckles
    },

    "data": {
        "train_file": "../train_sentences.txt",
        "val_file": "../val_sentences.txt",
        "num_workers": 2,  # Colab-optimized
        "file_format": "txt"
    },

    "eval": {
        "eval_interval": 500,
        "save_interval": 2000,
        "num_fixed_sentences": 16
    }
}

# Save config
!mkdir -p phase1/configs
with open('phase1/configs/phase1_colab.json', 'w') as f:
    json.dump(config, f, indent=2)

print("Configuration created:")
print(f"  Device: {config['device']}")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Max steps: {config['training']['max_steps']}")
print(f"  Channels: {config['model']['num_channels']} (RGB)")
print(f"\nLoss components:")
print(f"  - Reconstruction: {config['loss']['lambda_recon']}")
print(f"  - Magnitude: {config['loss']['lambda_magnitude']}")
print(f"  - Mumford-Shah: {config['loss']['lambda_mumford_shah']}")
print(f"\nMumford-Shah settings:")
print(f"  - Alpha (smoothness): {config['loss']['mumford_shah_alpha']}")
print(f"  - Beta (boundary): {config['loss']['mumford_shah_beta']}")
print(f"\nJittering:")
print(f"  - Jitter std: {config['training']['jitter_std']}")
print(f"\nOutput: {config['output_dir']}")
print("\n✓ Ready to train with Mumford-Shah loss for smooth regions!")

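The `training` block sets `lr`, `warmup_steps` and `ema_decay`, but the notebook does not show how `train.py` uses them. Below is a minimal sketch of two common interpretations. It assumes linear warmup to the peak learning rate followed by a constant rate, and an exponential moving average (EMA) of the weights with decay `ema_decay`. The real schedule in `phase1/train.py` may differ; for example, it may also decay the rate after warmup.

```python
def warmup_lr(step: int, peak_lr: float = 2e-4, warmup_steps: int = 500) -> float:
    """Linear warmup from near 0 to peak_lr over warmup_steps, then constant (assumed schedule)."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    return peak_lr

def ema_update(ema: list[float], params: list[float], decay: float = 0.999) -> list[float]:
    """One EMA step, element-wise: ema <- decay * ema + (1 - decay) * params."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]

# Learning rate at a few steps of the assumed schedule
for step in (0, 249, 499, 5000):
    print(step, warmup_lr(step))

# EMA tracking a parameter that jumps from 0.0 to 1.0 and stays there
ema = [0.0]
for _ in range(1000):
    ema = ema_update(ema, [1.0])
print(round(ema[0], 3))  # 0.632, i.e. 1 - 0.999**1000
```

With `ema_decay = 0.999`, the average effectively spans about the last 1 / (1 - 0.999) = 1,000 steps. That is a tenth of the `max_steps = 10000` run.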


## 4. Training

This will take approximately **2-3 hours** on a T4 GPU.

The training loop will:
- Train for up to 10,000 steps (`max_steps`; `num_epochs` is set high, so the step limit ends training first)
- Evaluate every 500 steps
- Save checkpoints every 2,000 steps to Google Drive
- Display progress bars and loss values
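Because `num_epochs` is 1000, `max_steps` is what actually ends training. Below is a quick check of the schedule implied by the config. It assumes one optimizer step per batch of 64 of the 10,000 training sentences, and that a final partial batch still counts as a step (`train.py` may drop it instead). It also assumes evaluation and saving happen whenever the step count is a multiple of the interval.

```python
import math

train_sentences = 10_000
batch_size = 64
max_steps = 10_000
num_epochs = 1_000
eval_interval, save_interval = 500, 2_000

steps_per_epoch = math.ceil(train_sentences / batch_size)   # 157 if the last partial batch is kept
total_steps = min(max_steps, num_epochs * steps_per_epoch)  # max_steps is the binding limit
epochs_covered = total_steps / steps_per_epoch

eval_steps = list(range(eval_interval, total_steps + 1, eval_interval))
save_steps = list(range(save_interval, total_steps + 1, save_interval))

print(steps_per_epoch, total_steps, round(epochs_covered, 1))  # 157 10000 63.7
print(len(eval_steps), save_steps)  # 20 [2000, 4000, 6000, 8000, 10000]
```

Under these assumptions, each training sentence is seen roughly 64 times before the step limit is reached.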

In [7]:
# Run training
%cd phase1

print("="*70)
print("Starting Phase 1 Training")
print("="*70)
print("This will take approximately 2-3 hours on T4 GPU")
print("You can monitor progress below...")
print("="*70)
print()

!python train.py --config configs/phase1_colab.json


## 5. Sample Reconstructions

In [8]:
# Simple evaluation - just show some sample reconstructions from training data
import sys
import os

# Add phase1 directory to path
sys.path.insert(0, '/content/blind_lm/phase1')

from model import create_model
from decoder_nonar import create_decoder
from transformers import AutoTokenizer
import torch
import json
import random

print("Generating sample reconstructions...")

output_dir = '/content/drive/MyDrive/blind_lm_outputs/phase1_rgb_mumford_shah'  # must match config['output_dir'] from section 3

# Load config
with open(os.path.join(output_dir, 'config.json')) as f:
    config = json.load(f)

# Load models
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = create_model(config['model']).to(device)
decoder = create_decoder(config['decoder']).to(device)

# Load checkpoint
checkpoint_path = os.path.join(output_dir, 'checkpoint_latest.pt')
checkpoint = torch.load(checkpoint_path, map_location=device)
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])
encoder.eval()
decoder.eval()

print(f'Loaded checkpoint from step {checkpoint["step"]}')

# Create tokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Load test sentences from training data (use sentences the model has seen)
train_file = '/content/blind_lm/train_sentences.txt'
with open(train_file, 'r') as f:
    all_sentences = [line.strip() for line in f if line.strip()]

# Sample 5 random sentences from training data
random.seed(42)
test_sentences = random.sample(all_sentences, min(5, len(all_sentences)))

print("\n" + "="*70)
print("SAMPLE RECONSTRUCTIONS (from training data)")
print("="*70)

with torch.no_grad():
    for sentence in test_sentences:
        # Tokenize
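        # Pad/truncate to the same sequence length the model was configured with
        inputs = tokenizer(sentence, return_tensors='pt', padding='max_length',
                           truncation=True, max_length=config['model']['max_seq_len'])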
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        
        # Encode
        latent = encoder(input_ids, attention_mask)
        
        # Decode (non-autoregressive - ignores input_ids)
        logits = decoder(latent, input_ids, attention_mask)
        predicted_ids = torch.argmax(logits, dim=-1)
        
        # Decode text
        reconstruction = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
        
        print(f"\nOriginal:       {sentence}")
        print(f"Reconstruction: {reconstruction}")
        
        # Check exact match
        if sentence.strip() == reconstruction.strip():
            print("✓ EXACT MATCH")

print("\n" + "="*70)


## 6. Visualize RGB Latents

In [None]:
# Display evaluation results for RGB latents (using training data)
from IPython.display import Image, display
import matplotlib.pyplot as plt
import torch
import numpy as np
import json
import os
import random
import sys

# Add phase1 directory to path
sys.path.insert(0, '/content/blind_lm/phase1')

eval_dir = "/content/drive/MyDrive/blind_lm_outputs/phase1_rgb_mumford_shah/eval_report"

print("="*70)
print("RGB LATENT EVALUATION RESULTS")
print("="*70)

# Check if we can visualize the RGB latents directly
checkpoint_path = "/content/drive/MyDrive/blind_lm_outputs/phase1_rgb_mumford_shah/checkpoint_latest.pt"
if os.path.exists(checkpoint_path):
    # Load the checkpoint to get some sample latents
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    print(f"\nCheckpoint Info:")
    print(f"  Step: {checkpoint['step']}")
    print(f"  Epoch: {checkpoint['epoch']}")
    
    # Display loss components if available
    if 'metrics_history' in checkpoint and checkpoint['metrics_history']:
        latest = checkpoint['metrics_history'][-1]
        print(f"\nLatest Loss Components:")
        if 'loss_components' in latest:
            for name, value in latest['loss_components'].items():
                print(f"  {name}: {value:.4f}")

# Try to display RGB visualizations if they exist
print("\n" + "="*70)
print("RGB VISUALIZATIONS (from training data)")
print("="*70)

# Helper function to convert latent to RGB for display
def latent_to_rgb(latent_tensor):
    """Convert [H, W, 3] tensor to displayable RGB, normalized to [0, 1]"""
    rgb = latent_tensor.cpu().numpy()
    # Normalize from [-1.5, 1.5] to [0, 1]
    rgb = (rgb + 1.5) / 3.0
    rgb = np.clip(rgb, 0, 1)
    return rgb

# Generate a simple visualization of the RGB latents
try:
    from model import create_model
    from transformers import AutoTokenizer
    
    # Load model and tokenizer
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Load config
    config_path = "/content/drive/MyDrive/blind_lm_outputs/phase1_rgb_mumford_shah/config.json"
    if os.path.exists(config_path):
        with open(config_path) as f:
            config = json.load(f)
        
        encoder = create_model(config['model']).to(device)
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        encoder.eval()
        
        tokenizer = AutoTokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
        
        # Load test sentences from training data
        train_file = '/content/blind_lm/train_sentences.txt'
        with open(train_file, 'r') as f:
            all_sentences = [line.strip() for line in f if line.strip()]
        
        # Sample 4 random sentences from training data
        random.seed(42)
        test_sentences = random.sample(all_sentences, min(4, len(all_sentences)))
        
        # Create figure for RGB images
        fig, axes = plt.subplots(2, 2, figsize=(8, 8))
        axes = axes.flatten()
        
        with torch.no_grad():
            for i, sentence in enumerate(test_sentences):
                # Tokenize
                # Pad/truncate to the model's configured sequence length
                inputs = tokenizer(sentence, return_tensors='pt', padding='max_length',
                                   truncation=True, max_length=config['model']['max_seq_len'])
                input_ids = inputs['input_ids'].to(device)
                attention_mask = inputs['attention_mask'].to(device)
                
                # Generate latent
                latent = encoder(input_ids, attention_mask)  # [1, H, W, 3]
                
                # Convert to RGB for display
                rgb = latent_to_rgb(latent[0])
                
                # Display
                axes[i].imshow(rgb)
                axes[i].set_title(f'"{sentence[:30]}..."' if len(sentence) > 30 else f'"{sentence}"', 
                                 fontsize=8)
                axes[i].axis('off')
        
        plt.suptitle('RGB Latents (Training Data)', fontsize=12)
        plt.tight_layout()
        plt.savefig('/content/rgb_samples.png', dpi=150, bbox_inches='tight')
        plt.show()
        
        print("✓ RGB visualization generated!")
        
except Exception as e:
    print(f"Could not generate live visualization: {e}")
    print("This is normal if training hasn't completed yet.")

# Show any saved visualizations
if os.path.exists(eval_dir):
    viz_files = os.listdir(eval_dir)
    if viz_files:
        print(f"\nSaved visualizations in {eval_dir}:")
        for file in viz_files:
            if file.endswith('.png'):
                print(f"  - {file}")
                display(Image(os.path.join(eval_dir, file)))
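One hypothetical diagnostic (not part of the repository) puts a number on how piecewise constant a visualized latent is. It measures the fraction of neighboring pixel pairs whose color changes by more than a small threshold. Speckled latents score high; latents with a few large regions score low. Below is a minimal sketch on an `[H, W, 3]` NumPy array, such as `latent[0].cpu().numpy()`. The 0.1 threshold is an arbitrary assumption.

```python
import numpy as np

def boundary_fraction(rgb: np.ndarray, threshold: float = 0.1) -> float:
    """Fraction of horizontal/vertical neighbour pairs whose largest channel difference exceeds threshold."""
    dx = np.abs(np.diff(rgb, axis=1)).max(axis=-1)  # [H, W-1] horizontal neighbour differences
    dy = np.abs(np.diff(rgb, axis=0)).max(axis=-1)  # [H-1, W] vertical neighbour differences
    edges = np.count_nonzero(dx > threshold) + np.count_nonzero(dy > threshold)
    return edges / (dx.size + dy.size)

# Two flat regions split by one vertical boundary in the red channel
two_regions = np.zeros((32, 32, 3))
two_regions[:, 16:, 0] = 1.0
print(round(boundary_fraction(two_regions), 4))  # 0.0161
```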

## 7. Interpret Results

### What to Expect with RGB Mumford-Shah Training

With **3-channel RGB** output and the **Mumford-Shah loss**, you should see:

**RGB Color Images**:
- Full-color visualizations showing all 3 channels (R, G, B)
- Spatially coherent colored patterns (nearby regions have similar colors)
- Local smoothness with global diversity (different parts of the image can have different colors)
- Different sentences produce distinct colored patterns

**Mumford-Shah Effects**:
- The boundary term (`mumford_shah_beta`, an L1 term) keeps boundaries sparse, so color changes concentrate along a few sharp edges
- The within-region smoothness term (`mumford_shah_alpha`, an L2 term) is 0.0 in this config, so only the boundary term is active
- Latent jittering (`jitter_std`) discourages isolated speckles
- Together these create piecewise constant, blob-like colored regions instead of noisy, scattered pixels
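The config describes the Mumford-Shah loss only through two weights: `mumford_shah_alpha` for an L2 within-region smoothness term, and `mumford_shah_beta` for an L1 boundary-sparsity term. Below is a minimal sketch of how such terms could act. It assumes both terms operate on finite differences between horizontally and vertically adjacent latent pixels, which is a total-variation-style surrogate for the full Mumford-Shah functional. `mumford_shah_penalty` is a hypothetical name, and `phase1/` may compute the loss differently. The sketch shows why the L1 term prefers one sharp boundary over speckle.

```python
import numpy as np

def mumford_shah_penalty(latent: np.ndarray, alpha: float, beta: float) -> float:
    """TV-style Mumford-Shah surrogate for an [H, W, C] latent (hypothetical sketch).

    alpha weights squared (L2) neighbour differences; beta weights absolute (L1) ones.
    """
    dx = latent[:, 1:, :] - latent[:, :-1, :]  # horizontal neighbour differences
    dy = latent[1:, :, :] - latent[:-1, :, :]  # vertical neighbour differences
    l2 = np.mean(dx ** 2) + np.mean(dy ** 2)
    l1 = np.mean(np.abs(dx)) + np.mean(np.abs(dy))
    return float(alpha * l2 + beta * l1)

# Two flat regions with a single boundary vs. a speckle-like checkerboard
smooth = np.zeros((4, 4, 1))
smooth[:, 2:, :] = 1.0
checker = (np.indices((4, 4)).sum(axis=0) % 2).astype(float)[..., None]

print(mumford_shah_penalty(smooth, alpha=0.0, beta=5.0))   # about 1.667
print(mumford_shah_penalty(checker, alpha=0.0, beta=5.0))  # 10.0
```

With `alpha = 0.0` and `beta = 5.0`, as in this config, the two-region latent costs about 1.67 while the checkerboard costs 10.0. Both use the same two values, so the difference comes entirely from how many boundaries each contains.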

**Reconstruction Quality**:
- High reconstruction accuracy (target >40% exact match)
- Semantically similar sentences produce similar RGB patterns
- Different sentences produce distinct RGB patterns
- Text decoder can reconstruct original text from RGB latent
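The >40% exact-match target is easier to track as a rate over many sentences than by counting the `✓ EXACT MATCH` lines printed by the sample-reconstruction cell. Below is a minimal sketch that applies the same whitespace-stripped comparison that cell uses. `exact_match_rate` is a hypothetical helper, and the sentences are illustrative, not taken from the generated data.

```python
def exact_match_rate(originals: list[str], reconstructions: list[str]) -> float:
    """Fraction of pairs that match exactly after stripping surrounding whitespace."""
    if len(originals) != len(reconstructions):
        raise ValueError("originals and reconstructions must have the same length")
    if not originals:
        return 0.0
    hits = sum(o.strip() == r.strip() for o, r in zip(originals, reconstructions))
    return hits / len(originals)

originals = ["the red ball is on the box", "a blue cube sits left of the cone"]
recons = ["the red ball is on the box ", "a blue cube sits right of the cone"]
print(exact_match_rate(originals, recons))  # 0.5
```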

### Loss Components to Monitor

- **Reconstruction loss**: Should decrease steadily (target < 1.0)
  - Measures how well the text decoder reconstructs the input
  
- **Mumford-Shah loss**: Should stabilize after an initial decrease
  - Measures how far the RGB latent is from piecewise constant
  - Lower = smoother regions with fewer, sharper boundaries
  
- **Magnitude loss**: Should approach zero as training progresses
  - Ensures latent doesn't collapse to all zeros
  - Maintains meaningful signal strength
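The magnitude loss is described only as preventing collapse, with `min_magnitude = 0.3` as its target. One plausible form is a hinge on the mean absolute latent value. It is zero once the latent is strong enough and grows as the latent shrinks toward zero, which matches the expectation that this loss approaches zero during training. Below is a minimal sketch under that assumption. `magnitude_hinge` is hypothetical, and the actual loss in `phase1/` may use a different statistic, such as RMS or per-channel magnitude.

```python
import numpy as np

def magnitude_hinge(latent: np.ndarray, min_magnitude: float = 0.3) -> float:
    """Hypothetical collapse penalty: squared shortfall of mean |latent| below min_magnitude."""
    shortfall = max(0.0, min_magnitude - float(np.mean(np.abs(latent))))
    return shortfall ** 2

collapsed = np.zeros((32, 32, 3))      # all-zero latent: maximal penalty
healthy = np.full((32, 32, 3), 0.5)    # mean |latent| above target: no penalty
print(magnitude_hinge(collapsed), magnitude_hinge(healthy))  # about 0.09, then 0.0
```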

### Training Tips

- **If patterns are too noisy or speckled**: Increase `lambda_mumford_shah`, `mumford_shah_beta`, or `jitter_std`
- **If patterns are too uniform**: Decrease `lambda_mumford_shah` or `mumford_shah_beta`
- **If reconstruction is poor**: Increase `lambda_recon`
- **If latents collapse**: Increase `lambda_magnitude`
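The config enables latent jittering through `jitter_std` (0.001) to prevent speckles, but the notebook does not say exactly what is jittered. Below is a minimal sketch. It assumes zero-mean Gaussian noise with standard deviation `jitter_std` is added to the latent during training only. The intuition is that the decoder then cannot rely on single-pixel detail. `jitter_latent` is a hypothetical helper, not the repository's code.

```python
import numpy as np

def jitter_latent(latent: np.ndarray, jitter_std: float, training: bool,
                  rng: np.random.Generator) -> np.ndarray:
    """Add N(0, jitter_std^2) noise in training mode; return the latent unchanged otherwise (assumed behaviour)."""
    if not training or jitter_std <= 0.0:
        return latent
    return latent + rng.normal(0.0, jitter_std, size=latent.shape)

rng = np.random.default_rng(0)
latent = np.zeros((32, 32, 3))
noisy = jitter_latent(latent, jitter_std=0.001, training=True, rng=rng)
print(noisy.shape, round(float(noisy.std()), 4))  # shape preserved; std should be close to 0.001
```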

## 8. Download Checkpoints (Optional)

Download the configuration and latest checkpoint to your local machine as a zip archive.

In [None]:
# Create a zip file with important results
import shutil
import os

output_dir = "/content/drive/MyDrive/blind_lm_outputs/phase1_rgb_mumford_shah"
zip_path = "/content/phase1_rgb_mumford_shah_results.zip"

print("Creating results archive...")

# Create temporary directory
temp_dir = "/content/phase1_results_temp"
os.makedirs(temp_dir, exist_ok=True)

# Copy important files
files_to_include = [
    "config.json",
    "checkpoint_latest.pt"
]

for file in files_to_include:
    src = os.path.join(output_dir, file)
    if os.path.exists(src):
        dst_dir = os.path.join(temp_dir, os.path.dirname(file))
        os.makedirs(dst_dir, exist_ok=True)
        shutil.copy2(src, os.path.join(temp_dir, file))
        print(f"  ✓ {file}")
    else:
        print(f"  ✗ {file} not found")

# Create zip (make_archive appends the .zip extension to the base name itself)
shutil.make_archive(os.path.splitext(zip_path)[0], 'zip', temp_dir)

# Download
from google.colab import files
print("\nDownloading...")
files.download(zip_path)

print("\n✓ Download complete!")

## 9. Next Steps

After Phase 1 passes:

1. **Phase 2**: Add semantic meaning via contrastive learning
   - Paraphrases should produce similar latents
   - Counterfactuals should produce different latents

2. **Phase 3**: Spatial jitter robustness
   - Latents should be invariant to small shifts

3. **Phase 4**: Add text decoder
   - Reconstruct text from latent

4. **Phase 5**: Round-trip generation
   - Generate paraphrases without copying

---

**Questions or issues?** Check the [project documentation](https://github.com/jtooates/blind_lm)