# Diffusion LLM Conversion Quick Start

This notebook demonstrates basic usage of the diffusion conversion pipeline.

## Overview

This pipeline converts pre-trained autoregressive language models (such as GPT-2 or LLaMA) into diffusion-based language models that generate text through an iterative denoising process.

### Key Components:
- **DiffusionTransformer**: Wrapper that converts causal attention to bidirectional
- **MaskedDiffusionScheduler**: Handles noise addition and denoising schedule
- **Multi-stage training**: Gradual conversion from autoregressive to diffusion model

In [None]:
# Setup
import sys
sys.path.append('../src')

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from models.diffusion_transformer import DiffusionTransformer, DiffusionTransformerConfig
from models.noise_scheduler import MaskedDiffusionScheduler
import matplotlib.pyplot as plt
import numpy as np

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Load Base Model

We'll start with a small GPT-2 model for demonstration purposes.

In [None]:
# Load a small model for testing
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(model_name)

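# GPT-2 ships without a mask token, so add one and resize the embeddings to match
tokenizer.add_special_tokens({'mask_token': '[MASK]'})
base_model.resize_token_embeddings(len(tokenizer))

# GPT-2 has no pad token either; reuse EOS so that padding=True works later
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token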

print(f"Model: {model_name}")
print(f"Vocab size: {len(tokenizer)}")
print(f"Parameters: {sum(p.numel() for p in base_model.parameters()) / 1e6:.1f}M")
print(f"Mask token: {tokenizer.mask_token} (ID: {tokenizer.mask_token_id})")

## 2. Create Diffusion Model

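Wrap the autoregressive base model in a `DiffusionTransformer`. The base model starts frozen (`freeze_base_model=True`), and the cell prints the total and trainable parameter counts.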

In [None]:
# Create diffusion config
config = DiffusionTransformerConfig(
    base_model_name=model_name,
    hidden_size=base_model.config.hidden_size,
    num_timesteps=100,  # Fewer timesteps for demo
    mask_token_id=tokenizer.mask_token_id,
    use_bidirectional_attention=True,
    freeze_base_model=True  # Start with frozen base model
)

# Create diffusion model
diffusion_model = DiffusionTransformer(base_model, config)
diffusion_model.to(device)

print(f"Diffusion model created with {sum(p.numel() for p in diffusion_model.parameters()) / 1e6:.1f}M parameters")
print(f"Trainable parameters: {sum(p.numel() for p in diffusion_model.parameters() if p.requires_grad) / 1e6:.1f}M")

## 3. Create Noise Scheduler

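The noise scheduler controls how much of the input is masked at each timestep. The cell below builds a cosine schedule and plots the noise level over the timesteps.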
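To build some intuition for what a cosine schedule looks like, here is a minimal standalone sketch. The function `cosine_mask_prob` and its formula are assumptions chosen for illustration; the actual `MaskedDiffusionScheduler` may define its cosine schedule differently.

```python
import math

def cosine_mask_prob(t: int, num_timesteps: int) -> float:
    """Hypothetical cosine schedule: fraction of tokens masked at timestep t.

    This is one common shape, not necessarily the formula used by
    MaskedDiffusionScheduler.
    """
    if num_timesteps < 2:
        raise ValueError("num_timesteps must be at least 2")
    progress = t / (num_timesteps - 1)  # 0.0 at t=0, 1.0 at the last timestep
    return 1.0 - math.cos(0.5 * math.pi * progress)

# Evaluate the sketch for the 100 timesteps used in this notebook
levels = [cosine_mask_prob(t, 100) for t in range(100)]
print(f"t=0: {levels[0]:.3f}, t=50: {levels[50]:.3f}, t=99: {levels[99]:.3f}")
```

Under this assumed formula the masking rate rises slowly at small timesteps and faster near the end, so early timesteps leave most of the sequence intact.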

In [None]:
# Create noise scheduler
scheduler = MaskedDiffusionScheduler(
    num_timesteps=config.num_timesteps,
    schedule_type="cosine",
    mask_token_id=tokenizer.mask_token_id
)

print(f"Scheduler created with {scheduler.num_timesteps} timesteps")
print(f"Schedule type: {scheduler.schedule_type}")

# Visualize noise schedule
timesteps = np.arange(scheduler.num_timesteps)
noise_levels = [scheduler.get_noise_level(t) for t in timesteps]

plt.figure(figsize=(10, 4))
plt.plot(timesteps, noise_levels)
plt.xlabel('Timestep')
plt.ylabel('Noise Level')
plt.title('Noise Schedule')
plt.grid(True)
plt.show()

## 4. Test the Forward (Noising) Process

Tokenize a sample sentence and apply the scheduler's noise at several timesteps to see how the input is corrupted.
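The noising step can be pictured as replacing each token with the mask token independently with some probability. The sketch below uses plain Python lists; `mask_tokens` is a hypothetical helper, and independent per-token masking is an assumption about how `scheduler.add_noise` behaves, not a description of its implementation.

```python
import random

def mask_tokens(token_ids, mask_prob, mask_token_id, seed=None):
    """Replace each token with mask_token_id independently with probability mask_prob.

    Returns the noisy sequence and a boolean list marking the masked positions.
    """
    rng = random.Random(seed)
    noisy, mask = [], []
    for tok in token_ids:
        is_masked = rng.random() < mask_prob
        noisy.append(mask_token_id if is_masked else tok)
        mask.append(is_masked)
    return noisy, mask

tokens = [464, 2068, 7586, 21831]  # arbitrary illustrative token IDs
MASK_ID = 50257                    # assumed ID of a [MASK] token appended to GPT-2's vocabulary

for prob in (0.0, 0.5, 1.0):
    noisy, mask = mask_tokens(tokens, prob, MASK_ID, seed=0)
    print(f"mask_prob={prob}: {noisy} ({sum(mask)} masked)")
```

With `mask_prob=0.0` the sequence is returned unchanged, and with `mask_prob=1.0` every position is masked, which corresponds to the two ends of the diffusion process.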

In [None]:
# Test with a simple input
test_text = "The quick brown fox jumps over the lazy dog."
inputs = tokenizer(test_text, return_tensors="pt", padding=True, truncation=True)
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)

print(f"Input text: {test_text}")
print(f"Input shape: {input_ids.shape}")
print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

# Add noise at different timesteps
timesteps = torch.tensor([0, 25, 50, 75], device=device)

print("\nNoise examples:")
for t in timesteps:
    noisy_ids, mask = scheduler.add_noise(input_ids, t.unsqueeze(0))
    masked_text = tokenizer.decode(noisy_ids[0], skip_special_tokens=False)
    print(f"t={t.item():2d}: {masked_text}")

## 5. Forward Pass Through Diffusion Model

Run the noised input through the diffusion model at several timesteps and inspect the prediction at the first masked position.

In [None]:
# Test forward pass
diffusion_model.eval()

with torch.no_grad():
    # Test at different timesteps
    for t in [10, 50, 90]:
        timestep = torch.tensor([t], device=device)
        
        # Add noise
        noisy_ids, mask = scheduler.add_noise(input_ids, timestep)
        
        # Forward pass
        outputs = diffusion_model(
            input_ids=noisy_ids,
            timesteps=timestep,
            attention_mask=attention_mask
        )
        
        logits = outputs['logits']
        
        print(f"\nTimestep {t}:")
        print(f"  Input shape: {noisy_ids.shape}")
        print(f"  Output shape: {logits.shape}")
        print(f"  Masked positions: {mask.sum().item()}")
        
        # Show prediction for first masked token
        if mask.any():
            first_mask_pos = mask[0].nonzero()[0].item()
            pred_logits = logits[0, first_mask_pos]
            pred_id = pred_logits.argmax().item()
            pred_token = tokenizer.decode([pred_id])
            original_token = tokenizer.decode([input_ids[0, first_mask_pos].item()])
            print(f"  Prediction for pos {first_mask_pos}: '{pred_token}' (original: '{original_token}')")

## 6. Test Bidirectional Attention

Compare the top-5 predictions at a masked position with bidirectional attention and with causal attention.
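The difference between the two modes comes down to which positions each token may attend to. The following sketch builds both masks as nested lists (1 = may attend, 0 = blocked). The helper names are hypothetical, and the way `DiffusionTransformer` actually switches masks internally is not shown here.

```python
def causal_mask(seq_len):
    """Row i may attend only to positions 0..i (lower-triangular mask)."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

def bidirectional_mask(seq_len):
    """Every position may attend to every other position."""
    return [[1] * seq_len for _ in range(seq_len)]

seq_len, mask_pos = 8, 5  # illustrative values: an 8-token input with the mask at index 5
visible_causal = sum(causal_mask(seq_len)[mask_pos])
visible_bidirectional = sum(bidirectional_mask(seq_len)[mask_pos])
print(f"Causal: position {mask_pos} sees {visible_causal} of {seq_len} tokens")
print(f"Bidirectional: position {mask_pos} sees {visible_bidirectional} of {seq_len} tokens")
```

With a causal mask, the masked position cannot use the tokens to its right, which is why right-hand context such as "mat." can only influence the prediction in bidirectional mode.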

In [None]:
# Test bidirectional vs causal attention
test_input = "The cat sat on the [MASK] mat."
test_tokens = tokenizer(test_input.replace('[MASK]', tokenizer.mask_token), 
                       return_tensors="pt", padding=True)
test_ids = test_tokens['input_ids'].to(device)
test_mask = test_tokens['attention_mask'].to(device)

print(f"Test input: {test_input}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(test_ids[0])}")

# Find mask position
mask_pos = (test_ids[0] == tokenizer.mask_token_id).nonzero().item()
print(f"Mask position: {mask_pos}")

timestep = torch.tensor([50], device=device)

# Test with bidirectional attention
# (assumes the model reads this config flag on every forward pass)
diffusion_model.config.use_bidirectional_attention = True
with torch.no_grad():
    outputs_bi = diffusion_model(
        input_ids=test_ids,
        timesteps=timestep,
        attention_mask=test_mask
    )
    
    # Get top predictions
    logits_bi = outputs_bi['logits'][0, mask_pos]
    top_ids_bi = logits_bi.topk(5).indices
    top_tokens_bi = [tokenizer.decode([tok_id.item()]) for tok_id in top_ids_bi]
    
print(f"\nBidirectional attention - Top 5 predictions:")
for i, token in enumerate(top_tokens_bi):
    print(f"  {i+1}. '{token}'")

# Test with causal attention
diffusion_model.config.use_bidirectional_attention = False
with torch.no_grad():
    outputs_causal = diffusion_model(
        input_ids=test_ids,
        timesteps=timestep,
        attention_mask=test_mask
    )
    
    # Get top predictions
    logits_causal = outputs_causal['logits'][0, mask_pos]
    top_ids_causal = logits_causal.topk(5).indices
    top_tokens_causal = [tokenizer.decode([tok_id.item()]) for tok_id in top_ids_causal]
    
print(f"\nCausal attention - Top 5 predictions:")
for i, token in enumerate(top_tokens_causal):
    print(f"  {i+1}. '{token}'")

# Reset to bidirectional
diffusion_model.config.use_bidirectional_attention = True

## 7. Simple Generation Test

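Run a simple iterative unmasking loop to check that generation works end to end. The model has not been trained as a diffusion model yet, so the output is unlikely to be coherent; the goal here is only to confirm that the sampling loop runs.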
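The generation cell reveals masked positions gradually: at each step it updates `max(1, remaining_masked // steps_left)` positions. The sketch below simulates only that counting rule, without a model, to show how the budget is spread across steps. `unmask_budget` is a hypothetical helper written for this illustration.

```python
def unmask_budget(num_masked, num_steps):
    """Return how many masked positions are revealed at each step.

    Mirrors the rule max(1, remaining // steps_left), where steps_left
    counts the current step and all steps after it.
    """
    revealed = []
    remaining = num_masked
    for step in range(num_steps):
        if remaining == 0:
            revealed.append(0)  # nothing left to reveal
            continue
        steps_left = num_steps - step
        k = min(remaining, max(1, remaining // steps_left))
        revealed.append(k)
        remaining -= k
    return revealed

# 15 positions minus an assumed 4-token prompt leaves 11 masks, revealed over 5 steps
print(unmask_budget(11, 5))  # [2, 2, 2, 2, 3]
# 12 masks with no prompt, revealed over 8 steps
print(unmask_budget(12, 8))  # [1, 1, 1, 1, 2, 2, 2, 2]
```

Because the divisor shrinks to 1 on the final step, all remaining masked positions are updated by the end of the loop.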

In [None]:
# Simple generation test
def simple_generate(model, scheduler, tokenizer, prompt, max_length=20, num_steps=10):
    """Simple generation using the diffusion model."""
    model.eval()
    device = next(model.parameters()).device

    # Start from a fully masked sequence
    input_ids = torch.full((1, max_length), tokenizer.mask_token_id, dtype=torch.long, device=device)

    # Copy the prompt (truncated to max_length) into the leading positions
    if prompt:
        prompt_tokens = tokenizer(prompt, return_tensors="pt")['input_ids'].to(device)
        prompt_len = min(prompt_tokens.shape[1], max_length)
        input_ids[0, :prompt_len] = prompt_tokens[0, :prompt_len]

    # Evenly spaced timesteps from the noisiest (T-1) down to 0
    timesteps = torch.linspace(scheduler.num_timesteps - 1, 0, num_steps, device=device).long()

    with torch.no_grad():
        for step, t in enumerate(timesteps):
            mask = (input_ids == tokenizer.mask_token_id)
            if not mask.any():
                break  # nothing left to denoise

            # Get predictions
            outputs = model(input_ids=input_ids, timesteps=t.unsqueeze(0))

            # Greedy decoding: take the top prediction at every position
            predictions = outputs['logits'].argmax(dim=-1)

            # Only update some masked positions (gradual denoising);
            # on the final step the divisor is 1, so all remaining masks are updated
            masked_positions = mask.nonzero(as_tuple=False)
            num_to_update = max(1, len(masked_positions) // (num_steps - step))

            # Update a random subset of masked positions
            perm = torch.randperm(len(masked_positions), device=device)[:num_to_update]
            rows, cols = masked_positions[perm, 0], masked_positions[perm, 1]
            input_ids[rows, cols] = predictions[rows, cols]

    return input_ids

# Test generation
print("Testing generation...")
prompt = "The weather today is"
generated = simple_generate(diffusion_model, scheduler, tokenizer, prompt, max_length=15, num_steps=5)
generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)

print(f"\nPrompt: '{prompt}'")
print(f"Generated: '{generated_text}'")

# Test without prompt
print("\nTesting generation without prompt...")
generated_no_prompt = simple_generate(diffusion_model, scheduler, tokenizer, "", max_length=12, num_steps=8)
generated_text_no_prompt = tokenizer.decode(generated_no_prompt[0], skip_special_tokens=True)
print(f"Generated: '{generated_text_no_prompt}'")

## 8. Model Analysis

Inspect the model's configuration, its parameter count per component, and the approximate memory needed for its weights.
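The memory estimate in the cell below is simple arithmetic: the number of parameters times the bytes per parameter. A small standalone sketch of that calculation follows; `weight_memory_mb` is a hypothetical helper, and the 124M figure is an approximate parameter count for GPT-2 small used purely for illustration.

```python
# Bytes needed to store one parameter in each numeric format
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}

def weight_memory_mb(num_params, dtype="float32"):
    """Approximate weight storage in MB (1 MB = 1e6 bytes). Weights only."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e6

num_params = 124_000_000  # roughly GPT-2 small; illustrative, not measured here
for dtype in ("float32", "float16", "int8"):
    print(f"{dtype}: {weight_memory_mb(num_params, dtype):.1f} MB")
```

This covers the weights only. Activations, gradients, and optimizer state add to the total during training.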

In [None]:
# Model analysis
print("Model Analysis:")
print(f"  Base model: {diffusion_model.base_model.__class__.__name__}")
print(f"  Hidden size: {diffusion_model.config.hidden_size}")
print(f"  Vocab size: {diffusion_model.base_config.vocab_size}")
print(f"  Timesteps: {diffusion_model.config.num_timesteps}")
print(f"  Bidirectional: {diffusion_model.config.use_bidirectional_attention}")
print(f"  Frozen base: {diffusion_model.config.freeze_base_model}")

# Count parameters by component
base_params = sum(p.numel() for p in diffusion_model.base_model.parameters())
time_params = sum(p.numel() for p in diffusion_model.time_embed.parameters()) + \
              sum(p.numel() for p in diffusion_model.time_mlp.parameters())
head_params = sum(p.numel() for p in diffusion_model.denoising_head.parameters())

print(f"\nParameter breakdown:")
print(f"  Base model: {base_params / 1e6:.1f}M")
print(f"  Time components: {time_params / 1e6:.1f}M")
print(f"  Denoising head: {head_params / 1e6:.1f}M")
print(f"  Total: {(base_params + time_params + head_params) / 1e6:.1f}M")

# Memory usage estimate (weights only; excludes activations, gradients, and optimizer state)
model_size_mb = (base_params + time_params + head_params) * 4 / 1e6  # 4 bytes per float32
print(f"\nMemory estimate: {model_size_mb:.1f}MB (float32)")

## 9. Next Steps

This notebook demonstrates the basic setup and testing of the diffusion model. For actual training:

1. **Use the training script**: `python scripts/train_diffusion.py --model_config configs/model_configs/gpt2_diffusion.yaml --training_config configs/training_configs/default.yaml --hardware_config configs/hardware_configs/gpu.yaml`

2. **Try different models**: The pipeline supports various architectures (GPT-2, GPT-Neo, LLaMA, etc.)

3. **Experiment with hyperparameters**: Adjust timesteps, noise schedules, training stages

4. **Evaluate on benchmarks**: Test perplexity, generation quality, etc.

5. **Fine-tune for specific tasks**: Adapt the model for particular applications

The conversion from autoregressive to diffusion models opens up new possibilities for text generation and understanding!