# 🚀 Diffusion Language Model - Complete Training & Demo

This notebook provides:
1. PDF data processing pipeline
2. Model initialization with Flash Attention
3. Interactive training with progress visualization
4. Generation and evaluation

**Note**: Flash Attention is enabled by default for faster training!


## 📦 1. Setup & Dependencies


In [None]:
# Install Flash Attention (if not installed)
import subprocess
import sys

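try:
    import flash_attn
    print("✅ Flash Attention already installed")
except ImportError:
    print("Installing Flash Attention...")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"])
        print("✅ Flash Attention installed")
    except subprocess.CalledProcessError:
        # The build needs a CUDA toolchain; without it the notebook continues with standard attention
        print("⚠️ Flash Attention install failed; continuing with standard attention")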

# Import all dependencies
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import json
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.getcwd()))

from model.diffusion import DiffusionLM, DiffusionSchedule
from model.transformer import FLASH_AVAILABLE
from data.pdf_processor import PDFProcessor
from data.dataset import create_dataloader, DiffusionDataset
from training.config import TrainingConfig

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️ Using device: {device}")
print(f"⚡ Flash Attention available: {FLASH_AVAILABLE}")

if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name()}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
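
If you would rather check for Flash Attention before attempting an install, the standard library can report whether a package is importable without importing it. This is a minimal sketch; the helper name `has_module` is illustrative and not part of this project.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module can be imported, without importing it."""
    return importlib.util.find_spec(name) is not None


# Report availability instead of triggering a (possibly slow) build
print("flash_attn importable:", has_module("flash_attn"))
```

The check only looks the module up on the import path, so it is cheap to run at the top of the notebook.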


## 📄 2. Data Processing Pipeline


In [None]:
# Initialize PDF processor
pdf_processor = PDFProcessor(
    min_length=100,
    max_length=512,
    clean_text=True,
    use_pdfplumber=True  # Better extraction quality
)

# Configuration
PDF_FOLDER = "pdfs"  # Update this path to your PDF folder
DATA_OUTPUT = "data/processed"
CREATE_SAMPLE_DATA = True  # Set to False if you have PDFs

print("📚 PDF Processor initialized")


In [None]:
# Process PDFs or create sample data
if CREATE_SAMPLE_DATA:
    # Create sample training data for demo
    sample_texts = [
        "The diffusion language model learns to generate text by progressively denoising masked tokens through iterative refinement.",
        "Machine learning models can be trained using various optimization techniques including gradient descent and adaptive learning rates.",
        "Natural language processing has made significant advances with transformer architectures and attention mechanisms.",
        "Deep learning requires large amounts of data and computational resources for training complex neural networks.",
        "The attention mechanism allows models to focus on relevant parts of the input sequence for better understanding.",
        "Neural networks consist of interconnected layers that process information hierarchically through forward propagation.",
        "Text generation models can produce coherent and contextually relevant outputs using language modeling objectives.",
        "Pretrained language models can be fine-tuned for specific downstream tasks with transfer learning.",
        "The transformer architecture has revolutionized natural language understanding with self-attention mechanisms.",
        "Diffusion models iteratively refine noisy inputs to generate high-quality outputs through reverse diffusion.",
        "Gradient-based optimization methods help neural networks learn optimal parameters from training data.",
        "Recurrent neural networks process sequential data by maintaining hidden states across time steps.",
        "Convolutional neural networks excel at processing grid-like data such as images and spectrograms.",
        "Reinforcement learning agents learn optimal policies through interaction with environments and rewards.",
        "Generative adversarial networks consist of generator and discriminator networks in competition.",
        "Variational autoencoders learn latent representations through probabilistic encoding and decoding.",
        "Meta-learning algorithms enable models to quickly adapt to new tasks with limited examples.",
        "Contrastive learning methods learn representations by comparing positive and negative sample pairs.",
        "Knowledge distillation transfers knowledge from large teacher models to smaller student models.",
        "Multi-task learning shares representations across related tasks for improved generalization."
    ] * 5  # Repeat for more data
    
    # Save as training data
    os.makedirs('data', exist_ok=True)
    with open('data/train_data.json', 'w') as f:
        json.dump({'texts': sample_texts}, f, indent=2)
    
    print(f"✅ Created sample dataset with {len(sample_texts)} text samples")
    data_path = 'data/train_data.json'
    
else:
    # Process actual PDFs
    if os.path.exists(PDF_FOLDER):
        print(f"📂 Processing PDFs from {PDF_FOLDER}")
        docs = pdf_processor.process_folder(
            PDF_FOLDER, 
            output_dir=DATA_OUTPUT,
            save_format='json'
        )
        print(f"✅ Processed {len(docs)} documents")
        data_path = f"{DATA_OUTPUT}/all_documents.json"
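    else:
        # data_path would be undefined below, so stop with a clear message
        raise FileNotFoundError(
            f"PDF folder '{PDF_FOLDER}' not found. "
            "Set CREATE_SAMPLE_DATA = True and re-run this cell to use the sample data."
        )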

print(f"📍 Data path: {data_path}")
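
The `min_length` and `max_length` settings passed to `PDFProcessor` suggest that extracted text is cut into bounded segments. Its internals are not shown in this notebook, so the following is only a sketch of one way to do it, assuming lengths are counted in characters and that splitting on sentence boundaries is acceptable. The function name `chunk_text` is hypothetical.

```python
from typing import List


def chunk_text(text: str, min_length: int = 100, max_length: int = 512) -> List[str]:
    """Greedily pack sentences into chunks of at most max_length characters.

    Chunks shorter than min_length are dropped (assumption: lengths are characters).
    """
    sentences = [s.strip() for s in text.replace("\n", " ").split(". ") if s.strip()]
    chunks: List[str] = []
    current = ""
    for sentence in sentences:
        if not sentence.endswith("."):
            sentence += "."
        candidate = f"{current} {sentence}".strip()
        if len(candidate) <= max_length:
            current = candidate
        else:
            if len(current) >= min_length:
                chunks.append(current)
            # Start a new chunk; truncate a single over-long sentence
            current = sentence[:max_length]
    if len(current) >= min_length:
        chunks.append(current)
    return chunks


demo = "This is a sentence about diffusion models. " * 30
for chunk in chunk_text(demo):
    print(len(chunk), chunk[:40] + "...")
```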


## 🤖 3. Model Initialization with Flash Attention


In [None]:
# Training configuration
config = TrainingConfig(
    data_path=data_path,
    batch_size=4 if device.type == 'cuda' else 2,  # Smaller batch for CPU
    learning_rate=5e-4,
    num_epochs=10,  # Adjust based on your needs
    max_length=128,  # Sequence length
    num_timesteps=50,  # Diffusion steps
    device=str(device),
    
    # Model configuration
    vocab_size=30000,
    d_model=256,  # Model dimension
    n_heads=8,
    n_layers=4,  # Number of transformer layers
    d_ff=1024,
    dropout=0.1,
    
    # Logging
    log_interval=10,
    eval_interval=50,
    save_interval=100,
    
    # Generation
    num_generation_steps=30,
    generation_temperature=1.0,
    generation_top_p=0.95
)

print("⚙️ Configuration set")
print(f"  - Batch size: {config.batch_size}")
print(f"  - Learning rate: {config.learning_rate}")
print(f"  - Model dimension: {config.d_model}")
print(f"  - Layers: {config.n_layers}")
print(f"  - Flash Attention: {'enabled' if FLASH_AVAILABLE else 'unavailable, using standard attention'}")


In [None]:
# Initialize tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
if tokenizer.mask_token is None:
    tokenizer.add_special_tokens({'mask_token': '[MASK]'})

# Update vocab size
config.vocab_size = len(tokenizer)

# Initialize model with Flash Attention
model = DiffusionLM(
    vocab_size=config.vocab_size,
    d_model=config.d_model,
    n_heads=config.n_heads,
    n_layers=config.n_layers,
    d_ff=config.d_ff,
    max_seq_len=config.max_length,
    num_timesteps=config.num_timesteps,
    dropout=config.dropout,
    mask_token_id=tokenizer.mask_token_id,
    pad_token_id=tokenizer.pad_token_id,
    use_flash=True  # Enable Flash Attention
).to(device)

# Model statistics
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n🧠 Model initialized:")
print(f"  - Parameters: {num_params:,}")
print(f"  - Vocabulary size: {config.vocab_size:,}")
print(f"  - Mask token: {tokenizer.mask_token}")
print(f"  - Parameter memory: {num_params * 4 / 1e6:.2f} MB (FP32, weights only)")
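
To sanity-check the reported parameter count, you can estimate it by hand. The sketch below assumes a generic Transformer encoder layout (learned position embeddings, biased Q/K/V/output projections, two LayerNorms per layer, output head tied to the token embedding). `DiffusionLM` is not shown here and may add timestep embeddings or an untied head, so expect the real number to differ somewhat.

```python
def estimate_params(vocab_size: int, d_model: int, n_layers: int, d_ff: int,
                    max_seq_len: int, tie_output: bool = True) -> int:
    """Rough parameter count for a generic Transformer encoder (assumed layout)."""
    embed = vocab_size * d_model + max_seq_len * d_model      # token + position embeddings
    attn = 4 * (d_model * d_model + d_model)                  # Q, K, V, output projections
    ffn = d_model * d_ff + d_ff + d_ff * d_model + d_model    # two linear layers with biases
    norms = 2 * 2 * d_model                                   # two LayerNorms (weight + bias)
    head = 0 if tie_output else vocab_size * d_model          # untied output projection
    return embed + n_layers * (attn + ffn + norms) + head


# bert-base-uncased has 30,522 tokens; other values match the config above
total = estimate_params(vocab_size=30522, d_model=256, n_layers=4, d_ff=1024, max_seq_len=128)
print(f"Estimated parameters: {total:,}")
```

With a vocabulary this large and a small `d_model`, the embedding table accounts for most of the parameters.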


## 📊 4. Data Loading & Preparation


In [None]:
# Create dataloaders
print("📊 Loading datasets...")

train_dataloader = create_dataloader(
    config.data_path,
    batch_size=config.batch_size,
    tokenizer=tokenizer,
    max_length=config.max_length,
    shuffle=True,
    num_workers=0  # Set to 0 for Windows compatibility
)

# Validation loader reuses the training data (demo only), so validation metrics are not a held-out estimate
val_dataloader = create_dataloader(
    config.data_path,
    batch_size=config.batch_size,
    tokenizer=tokenizer,
    max_length=config.max_length,
    shuffle=False,
    num_workers=0
)

print(f"✅ Data loaded:")
print(f"  - Training batches: {len(train_dataloader)}")
print(f"  - Validation batches: {len(val_dataloader)}")
print(f"  - Batch size: {config.batch_size}")
print(f"  - Max sequence length: {config.max_length}")


## 🏋️ 5. Interactive Training with Live Progress


In [None]:
# Setup training
from model.utils import get_linear_schedule_with_warmup

# Initialize optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay
)

# Learning rate scheduler
num_training_steps = len(train_dataloader) * config.num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=min(100, num_training_steps // 10),
    num_training_steps=num_training_steps
)

# Training metrics storage
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

print(f"🎯 Training setup complete:")
print(f"  - Optimizer: AdamW")
print(f"  - Learning rate: {config.learning_rate}")
print(f"  - Total training steps: {num_training_steps}")
print(f"  - Warmup steps: {min(100, num_training_steps // 10)}")
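
`get_linear_schedule_with_warmup` comes from `model.utils`, which is not shown here. Assuming it follows the same shape as the helper of the same name in `transformers` (linear ramp from 0 to the base learning rate, then linear decay to 0), the multiplier applied at each step can be sketched as follows. The function name `linear_warmup_decay` is illustrative.

```python
def linear_warmup_decay(step: int, num_warmup_steps: int, num_training_steps: int) -> float:
    """Learning-rate multiplier: linear warmup, then linear decay to zero."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


# Example: 25 batches per epoch * 10 epochs = 250 steps, warmup = min(100, 250 // 10) = 25
base_lr = 5e-4
for step in (0, 12, 25, 125, 250):
    print(step, base_lr * linear_warmup_decay(step, 25, 250))
```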


In [None]:
# Interactive training loop with live visualization
import matplotlib.pyplot as plt
from IPython.display import clear_output

def train_epoch(model, dataloader, optimizer, scheduler, epoch, config):
    """Train for one epoch with progress tracking."""
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    num_batches = 0
    
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.num_epochs}")
    
    for batch_idx, batch in enumerate(progress_bar):
        # Move to device
        input_ids = batch['input_ids'].to(config.device)
        
        # Forward pass
        outputs = model(input_ids)
        
        # Compute loss
        loss = model.compute_loss(
            outputs['logits'],
            input_ids,
            outputs['mask']
        )
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
        
        # Optimizer step
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        
        # Calculate accuracy
        with torch.no_grad():
            mask = outputs['mask']
            if mask.any():
                preds = outputs['logits'][mask].argmax(dim=-1)
                targets = input_ids[mask]
                accuracy = (preds == targets).float().mean().item()
            else:
                accuracy = 0.0
        
        # Update metrics
        epoch_loss += loss.item()
        epoch_acc += accuracy
        num_batches += 1
        
        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{accuracy:.3f}',
            'lr': f'{scheduler.get_last_lr()[0]:.6f}'
        })
        
        # Store metrics for plotting
        if batch_idx % 5 == 0:
            train_losses.append(loss.item())
            train_accuracies.append(accuracy)
    
    return epoch_loss / num_batches, epoch_acc / num_batches


def evaluate(model, dataloader, config):
    """Evaluate model on validation set."""
    model.eval()
    total_loss = 0
    total_acc = 0
    num_batches = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            input_ids = batch['input_ids'].to(config.device)
            
            # Forward pass
            outputs = model(input_ids)
            
            # Compute loss
            loss = model.compute_loss(
                outputs['logits'],
                input_ids,
                outputs['mask']
            )
            
            # Calculate accuracy
            mask = outputs['mask']
            if mask.any():
                preds = outputs['logits'][mask].argmax(dim=-1)
                targets = input_ids[mask]
                accuracy = (preds == targets).float().mean().item()
            else:
                accuracy = 0.0
            
            total_loss += loss.item()
            total_acc += accuracy
            num_batches += 1
    
    return total_loss / num_batches, total_acc / num_batches


print("✅ Training functions defined")
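
Both functions above average the per-batch accuracy, so a batch with two masked tokens counts as much as one with a hundred. Because the number of masked tokens varies with the sampled timestep, a token-weighted average can be a steadier metric. A small sketch of the difference, using hypothetical `(num_correct, num_masked)` pairs rather than real model outputs:

```python
from typing import List, Tuple


def batch_mean_accuracy(batches: List[Tuple[int, int]]) -> float:
    """Average of per-batch accuracies (what the loops above compute)."""
    accs = [correct / masked if masked else 0.0 for correct, masked in batches]
    return sum(accs) / len(accs) if accs else 0.0


def token_weighted_accuracy(batches: List[Tuple[int, int]]) -> float:
    """Total correct predictions divided by total masked tokens."""
    correct = sum(c for c, _ in batches)
    masked = sum(m for _, m in batches)
    return correct / masked if masked else 0.0


# One lightly masked batch and one heavily masked batch
stats = [(1, 2), (90, 100)]
print("batch mean:    ", batch_mean_accuracy(stats))
print("token weighted:", token_weighted_accuracy(stats))
```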


In [None]:
# Start training with live progress visualization
print("🚀 Starting training with Flash Attention...")
print("=" * 50)

# Setup for live plotting
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
fig.suptitle('Training Progress', fontsize=14)

best_val_loss = float('inf')
os.makedirs('checkpoints', exist_ok=True)

for epoch in range(config.num_epochs):
    # Train for one epoch
    train_loss, train_acc = train_epoch(
        model, train_dataloader, optimizer, scheduler, epoch, config
    )
    
    # Evaluate
    val_loss, val_acc = evaluate(model, val_dataloader, config)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    
    # Clear the previous epoch's output first so the summary printed below
    # is not erased when the plot is redrawn
    clear_output(wait=True)

    # Print epoch summary
    print(f"\n📈 Epoch {epoch+1}/{config.num_epochs}")
    print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.3f}")
    print(f"  Val   - Loss: {val_loss:.4f}, Acc: {val_acc:.3f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'config': config.to_dict(),
            'val_loss': val_loss
        }, 'checkpoints/best_model.pt')
        print(f"  💾 Saved best model (val_loss: {val_loss:.4f})")
    
    # Update live plot
    
    # Plot losses
    axes[0].clear()
    if len(train_losses) > 0:
        axes[0].plot(train_losses, label='Train', alpha=0.7)
    if len(val_losses) > 0:
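        # Train points are logged every 5 batches, so place each validation
        # point at the end of its epoch on the same logged-point axis
        points_per_epoch = (len(train_dataloader) + 4) // 5
        axes[0].plot(np.arange(1, len(val_losses) + 1) * points_per_epoch,
                    val_losses, label='Val', marker='o')
    axes[0].set_xlabel('Logged steps (every 5 batches)')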
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Loss over Time')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Plot accuracies
    axes[1].clear()
    if len(train_accuracies) > 0:
        axes[1].plot(train_accuracies, label='Train', alpha=0.7)
    if len(val_accuracies) > 0:
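        # Same logged-point axis as the training curve (one point per 5 batches)
        points_per_epoch = (len(train_dataloader) + 4) // 5
        axes[1].plot(np.arange(1, len(val_accuracies) + 1) * points_per_epoch,
                    val_accuracies, label='Val', marker='o')
    axes[1].set_xlabel('Logged steps (every 5 batches)')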
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Accuracy over Time')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
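    fig.tight_layout()
    # With the default inline backend, plt.show() closes the figure after the
    # first call, so later epochs would draw nothing; display the figure itself.
    from IPython.display import display
    display(fig)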
    
    # Generate sample every 2 epochs
    if (epoch + 1) % 2 == 0:
        print("\n🎨 Generating sample...")
        model.eval()
        with torch.no_grad():
            generated = model.generate(
                prompt=None,
                max_length=50,
                num_steps=config.num_generation_steps,
                temperature=config.generation_temperature,
                top_p=config.generation_top_p
            )
            generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
            print(f"  Generated: {generated_text[:150]}...")
        model.train()

print("\n✅ Training complete!")
print(f"Best validation loss: {best_val_loss:.4f}")


## 🎨 6. Text Generation & Evaluation


In [None]:
# Load best model for generation
checkpoint = torch.load('checkpoints/best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✅ Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"  Validation loss: {checkpoint['val_loss']:.4f}")


In [None]:
# Generate text samples
print("🎨 Generating text samples...")
print("=" * 50)

num_samples = 5
generated_texts = []

for i in range(num_samples):
    with torch.no_grad():
        # Generate from scratch
        generated = model.generate(
            prompt=None,
            max_length=60,
            num_steps=30,
            temperature=0.9,
            top_p=0.95
        )
        
        generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
        generated_texts.append(generated_text)
        
        print(f"\n📝 Sample {i+1}:")
        print(f"  {generated_text}")

print("\n" + "=" * 50)
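
The `top_p` argument refers to nucleus sampling: keep the smallest set of most likely tokens whose cumulative probability reaches `top_p`, then renormalize and sample from that set. How `model.generate` applies it internally is not shown here, so the following is a plain-Python sketch of the standard technique over a list of probabilities.

```python
from typing import List


def top_p_filter(probs: List[float], top_p: float = 0.95) -> List[float]:
    """Zero out the low-probability tail and renormalize the kept tokens."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    cumulative = 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered


# Toy distribution over four tokens; the least likely one is dropped at top_p=0.9
print(top_p_filter([0.5, 0.3, 0.15, 0.05], top_p=0.9))
```

Lower `top_p` values cut more of the tail and make samples more conservative; `top_p=1.0` leaves the distribution unchanged.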


## 🔮 7. Text Infilling / Completion


In [None]:
# Text infilling demonstration
print("🔮 Text Infilling Demo")
print("=" * 50)

templates = [
    "The [MASK] [MASK] model can [MASK] text efficiently.",
    "Machine learning [MASK] are trained using [MASK] [MASK].",
    "[MASK] attention mechanisms allow [MASK] to focus on [MASK] parts.",
    "The transformer [MASK] has revolutionized [MASK] [MASK] processing.",
]

for template in templates:
    print(f"\n📝 Template: {template}")
    
    # Tokenize template
    inputs = tokenizer(
        template,
        return_tensors='pt',
        padding='max_length',
        max_length=30,
        truncation=True
    )
    input_ids = inputs['input_ids'].to(device)
    
    # Fill in the masks
    with torch.no_grad():
        filled = model.reverse_diffusion(
            input_ids,
            num_steps=20,
            temperature=0.8,
            confidence_threshold=0.7
        )
    
    filled_text = tokenizer.decode(filled[0], skip_special_tokens=True)
    print(f"✨ Filled:   {filled_text}")

print("\n" + "=" * 50)
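
The `confidence_threshold` argument suggests that `reverse_diffusion` commits a prediction only when the model is confident enough and leaves the rest masked for later steps. Its implementation is not shown, so the sketch below illustrates one common scheme of this kind with a hypothetical stand-in predictor (`toy_predict`) instead of the real model.

```python
from typing import Callable, List, Tuple

MASK = "[MASK]"


def toy_predict(tokens: List[str], position: int) -> Tuple[str, float]:
    """Hypothetical predictor: more confident when neighbours are already known."""
    known = sum(
        1 for j in (position - 1, position + 1)
        if 0 <= j < len(tokens) and tokens[j] != MASK
    )
    return f"w{position}", 0.5 + 0.25 * known


def iterative_unmask(tokens: List[str],
                     predict: Callable[[List[str], int], Tuple[str, float]],
                     num_steps: int = 20,
                     confidence_threshold: float = 0.7) -> List[str]:
    """Fill masked positions over several steps, most confident first."""
    tokens = list(tokens)
    for _ in range(num_steps):
        masked = [i for i, tok in enumerate(tokens) if tok == MASK]
        if not masked:
            break
        preds = {i: predict(tokens, i) for i in masked}
        accepted = [i for i in masked if preds[i][1] >= confidence_threshold]
        if not accepted:
            # Commit the single best guess so every step makes progress
            accepted = [max(masked, key=lambda i: preds[i][1])]
        for i in accepted:
            tokens[i] = preds[i][0]
    return tokens


print(iterative_unmask(["The", MASK, MASK, MASK, "model"], toy_predict))
```

In this toy run the two positions next to known words are filled first, and the middle position is filled on the following step once its neighbours are known.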


## 📈 8. Visualizations


In [None]:
# Visualize diffusion process
print("📊 Visualizing Diffusion Process")

# Sample text for visualization
sample_text = "The diffusion model learns to generate text."
inputs = tokenizer(sample_text, return_tensors='pt', padding='max_length', 
                   max_length=20, truncation=True)
input_ids = inputs['input_ids'].to(device)

# Track diffusion process
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
fig.suptitle('Forward Diffusion Process (Masking)', fontsize=14)

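# Six evenly spaced timesteps (one per subplot), derived from the configured schedule length
timesteps = np.linspace(0, config.num_timesteps - 1, 6, dtype=int).tolist()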
masks_history = []

for idx, t in enumerate(timesteps):
    ax = axes[idx // 3, idx % 3]
    
    # Apply forward diffusion
    t_tensor = torch.tensor([t], device=device)
    with torch.no_grad():
        noised_ids, mask = model.forward_diffusion(input_ids, t_tensor)
    
    masks_history.append(mask[0].cpu().numpy())
    
    # Visualize mask
    mask_visual = mask[0].cpu().numpy()[:20].reshape(1, -1)
    im = ax.imshow(mask_visual, cmap='RdYlBu_r', aspect='auto', vmin=0, vmax=1)
    
    # Decode text
    text = tokenizer.decode(noised_ids[0][:20], skip_special_tokens=False)
    
    ax.set_title(f't={t} ({mask.float().mean():.1%} masked)')
    ax.set_xlabel('Token Position')
    ax.set_yticks([])
    ax.set_xticks(range(0, 20, 5))
    
    # Add text below
    wrapped_text = text[:40] + "..." if len(text) > 40 else text
    ax.text(0.5, -0.15, wrapped_text, transform=ax.transAxes, 
            ha='center', fontsize=8, wrap=True)

plt.tight_layout()
plt.colorbar(im, ax=axes, label='Masked (1) vs Unmasked (0)', 
             orientation='horizontal', fraction=0.05, pad=0.15)
plt.show()

print("✅ Diffusion visualization complete")
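
The plots above show more tokens being masked as `t` grows. The actual rule lives in `model.forward_diffusion` and `DiffusionSchedule`, which are not shown here. As a rough mental model, the sketch below assumes a linear schedule in which each token is masked independently with probability `(t + 1) / num_timesteps`; the real schedule may differ (for example, cosine). The id `103` is the `[MASK]` id in `bert-base-uncased`.

```python
import random
from typing import List, Tuple


def forward_mask(token_ids: List[int], t: int, num_timesteps: int,
                 mask_token_id: int, rng: random.Random) -> Tuple[List[int], List[bool]]:
    """Mask each token independently with probability (t + 1) / num_timesteps (assumed)."""
    mask_prob = (t + 1) / num_timesteps
    mask = [rng.random() < mask_prob for _ in token_ids]
    noised = [mask_token_id if m else tok for tok, m in zip(token_ids, mask)]
    return noised, mask


rng = random.Random(0)
ids = list(range(1000, 1020))
for t in (0, 24, 49):
    noised, mask = forward_mask(ids, t, num_timesteps=50, mask_token_id=103, rng=rng)
    print(f"t={t}: {sum(mask)}/{len(mask)} tokens masked")
```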


## 💾 9. Save & Export


In [None]:
# Save final model and training history
print("💾 Saving model and results...")

# Save model with full configuration
model_save_path = 'checkpoints/final_model.pt'
torch.save({
    'model_state_dict': model.state_dict(),
    'config': config.to_dict(),
    'vocab_size': len(tokenizer),
    'training_history': {
        'train_losses': train_losses[-100:],  # Last 100 logged points (one per 5 batches)
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }
}, model_save_path)

print(f"✅ Model saved to {model_save_path}")

# Export configuration
config_path = 'checkpoints/config.json'
with open(config_path, 'w') as f:
    json.dump(config.to_dict(), f, indent=2)
print(f"✅ Configuration saved to {config_path}")

# Save generated samples
samples_path = 'checkpoints/generated_samples.txt'
with open(samples_path, 'w', encoding='utf-8') as f:
    for i, text in enumerate(generated_texts):
        f.write(f"Sample {i+1}:\n{text}\n\n")
print(f"✅ Generated samples saved to {samples_path}")

print("\n" + "=" * 50)
print("🎉 All done! Your model is trained and ready to use.")
print("\nTo load the model later:")
print("```python")
print("checkpoint = torch.load('checkpoints/final_model.pt')")
print("model.load_state_dict(checkpoint['model_state_dict'])")
print("```")


## 🎯 Summary

This notebook demonstrated:
1. ✅ **PDF Processing**: Extract and prepare text data from PDFs
2. ✅ **Flash Attention**: Enabled by default to speed up attention on supported GPUs
3. ✅ **Interactive Training**: Live progress visualization with plots
4. ✅ **Text Generation**: Generate text from scratch by iterative denoising
5. ✅ **Text Infilling**: Fill in masked tokens within a template
6. ✅ **Visualization**: See the diffusion process in action

### Next Steps:
- 📚 Add more training data (PDFs)
- 🔧 Tune hyperparameters in config
- 🚀 Scale up model size for better quality
- 🎨 Experiment with different generation settings
- 💡 Try conditional generation with prompts

**Flash Attention Note**: On a compatible GPU (Ampere or newer), Flash Attention can provide a significant speedup. The model automatically falls back to standard attention if Flash Attention is not available.
