# Implementation test

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from dataset import IAMLinesDataset, IAMLinesDatasetPyTorch
import importlib


In [9]:
class Config:
    def __init__(self):
        # Dataset parameters
        self.img_size = (64, 256)  # Height, Width - Adjusted for line-level segments
        self.batch_size = 64 # Reduced for potential memory issues on MPS, adjust as needed
        self.num_workers = 0 # Often set to 0 for MPS compatibility issues

        # VAE parameters
        self.latent_dim = 512  # Latent dimension for the VAE (flattened)
        self.vae_latent_channels = 256 # Channel dimension in VAE bottleneck (spatial)
        self.vae_lr = 1e-4
        self.vae_epochs = 3 # Keep low for testing

        # Diffusion model parameters
        self.timesteps = 1000  # Number of diffusion steps
        self.beta_start = 1e-4  # Starting noise schedule value
        self.beta_end = 2e-2  # Ending noise schedule value
        self.diffusion_lr = 1e-4
        self.diffusion_epochs = 5 # Keep low for testing

        # Style and content encoder parameters
        self.style_dim = 256
        self.content_dim = 256
        self.vocab_size = 128 # Example vocab size for ContentEncoder
        self.max_seq_len = 100 # Example max sequence length

        # Training parameters
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
            print("Using MPS device.")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Using CUDA device.")
        else:
            self.device = torch.device("cpu")
            print("Using CPU device.")
        self.save_dir = "./models"
        self.results_dir = "./results"

        # Create directories if they don't exist
        os.makedirs(self.save_dir, exist_ok=True)
        os.makedirs(self.results_dir, exist_ok=True)

In [10]:
# Dataset and preprocessing
class IAMDatasetWrapper:
    def __init__(self, config):
        self.config = config
        self.transform = transforms.Compose([
            transforms.Resize(config.img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
        ])
        
    def setup(self):
        """Load the IAM dataset and create train/val/test splits"""
        print("Loading IAM dataset...")
        # Assuming IAMLinesDatasetPyTorch is already implemented
        data_path = 'data/lines.tgz' 
        xml_path = 'data/xml.tgz'    
        iam_dataset = IAMLinesDataset(data_path, xml_path)
        full_dataset = IAMLinesDatasetPyTorch(iam_dataset=iam_dataset, transform=self.transform)
        
        # Split into train/val/test
        train_size = int(0.8 * len(full_dataset))
        val_size = int(0.1 * len(full_dataset))
        test_size = len(full_dataset) - train_size - val_size
        
        train_dataset, val_dataset, test_dataset = random_split(
            full_dataset, [train_size, val_size, test_size]
        )
        
        self.train_loader = DataLoader(
            train_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=True,
            num_workers=self.config.num_workers
        )
        
        self.val_loader = DataLoader(
            val_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=False,
            num_workers=self.config.num_workers
        )
        
        self.test_loader = DataLoader(
            test_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=False,
            num_workers=self.config.num_workers
        )
        
        print(f"Dataset loaded. Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
        return self.train_loader, self.val_loader, self.test_loader


In [11]:
# VAE Model
class VAE(nn.Module):
    def __init__(self, config):
        super(VAE, self).__init__()
        self.config = config
        self.latent_dim = config.latent_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        
        # Calculate the size of the encoder output
        h, w = config.img_size
        self.encoder_output_size = (h // 16, w // 16, 256)
        self.encoder_flattened_dim = self.encoder_output_size[0] * self.encoder_output_size[1] * self.encoder_output_size[2]
        
        # Mean and log variance layers
        self.fc_mu = nn.Linear(self.encoder_flattened_dim, self.latent_dim)
        self.fc_logvar = nn.Linear(self.encoder_flattened_dim, self.latent_dim)
        
        # Decoder input layer
        self.decoder_input = nn.Linear(self.latent_dim, self.encoder_flattened_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Output in range [-1, 1]
        )
    
    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten
        
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z
    
    def decode(self, z):
        z = self.decoder_input(z)
        z = z.view(-1, 256, self.encoder_output_size[0], self.encoder_output_size[1])
        x_recon = self.decoder(z)
        return x_recon
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

# Style Encoder
class StyleEncoder(nn.Module):
    def __init__(self, config):
        super(StyleEncoder, self).__init__()
        self.config = config
        
        # CNN backbone
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        
        # Projection head
        self.fc = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, config.style_dim)
        )
    
    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # Flatten
        style = self.fc(x)
        return style

# Content Encoder (for text)
class ContentEncoder(nn.Module):
    def __init__(self, config, embedding_dim=128):
        super(ContentEncoder, self).__init__()
        self.config = config

        # Embedding layer
        self.embedding = nn.Embedding(config.vocab_size, embedding_dim)

        # LSTM layers
        lstm_hidden_size = 256
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=lstm_hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )

        # Projection head (adjust input size based on bidirectional LSTM)
        self.fc = nn.Linear(lstm_hidden_size * 2, config.content_dim) # 2 for bidirectional

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: [batch_size, seq_len]
        batch_size = x.shape[0]
        embedded = self.embedding(x) # Shape: [B, seq_len, embedding_dim]

        # LSTM forward pass
        # Initialize hidden and cell states
        h0 = torch.zeros(2 * 2, batch_size, self.lstm.hidden_size).to(x.device) # num_layers * num_directions
        c0 = torch.zeros(2 * 2, batch_size, self.lstm.hidden_size).to(x.device)

        output, (hidden, _) = self.lstm(embedded, (h0, c0)) # hidden shape: [num_layers * num_directions, B, hidden_size]

        # Concatenate final forward and backward hidden states from the last layer
        # hidden[-2] is the last forward state, hidden[-1] is the last backward state
        hidden_cat = torch.cat([hidden[-2, :, :], hidden[-1, :, :]], dim=1) # Shape: [B, hidden_size * 2]

        # Project to content dimension
        content = self.fc(hidden_cat) # Shape: [B, content_dim]
        return content


# UNet model for diffusion
class DiffusionUNet(nn.Module):
    def __init__(self, config):
        super(DiffusionUNet, self).__init__()
        self.config = config

        # Calculate latent image dimensions (output of VAE encoder bottleneck)
        h, w = config.img_size
        self.latent_h = h // 16 # e.g., 64 // 16 = 4
        self.latent_w = w // 16 # e.g., 256 // 16 = 16

        # VAE bottleneck channels (should match VAE encoder output channels)
        vae_latent_channels = config.vae_latent_channels # e.g., 256

        # Conditioning embedding dimension (size after projecting time, style, content)
        cond_embed_dim = 128 # An intermediate dimension for embeddings

        # --- Time and Conditioning Embeddings ---
        self.time_embed = nn.Sequential(
            nn.Linear(1, cond_embed_dim * 4), # Project to a higher dim first
            nn.SiLU(),
            nn.Linear(cond_embed_dim * 4, cond_embed_dim), # Project back to target dim
        )
        # Use MLP for style/content embedding projection for potentially better capacity
        self.style_embed = nn.Sequential(
             nn.Linear(config.style_dim, cond_embed_dim * 2),
             nn.SiLU(),
             nn.Linear(cond_embed_dim * 2, cond_embed_dim)
        )
        self.content_embed = nn.Sequential(
             nn.Linear(config.content_dim, cond_embed_dim * 2),
             nn.SiLU(),
             nn.Linear(cond_embed_dim * 2, cond_embed_dim)
        )


        # --- Initial Convolution Layer ---
        # Input channels = VAE latent channels + conditioning embedding dimension
        initial_in_channels = vae_latent_channels + cond_embed_dim # e.g., 256 + 128 = 384
        initial_out_channels = 256 # Desired channels for the first down block

        # Add GroupNorm and SiLU after initial conv
        self.initial_conv = nn.Sequential(
             nn.Conv2d(initial_in_channels, initial_out_channels, kernel_size=3, padding=1),
             nn.GroupNorm(32, initial_out_channels), # Normalize after conv
             nn.SiLU() # Activation
        )


        # --- Down blocks ---
        # Structure assumes ResNet-like blocks where skip is taken before pooling
        # Let's define channels and create blocks
        down_channels = [initial_out_channels, 256, 512, 512] # Input channels for each level
        self.down_blocks = nn.ModuleList([])
        for i in range(len(down_channels) - 1):
             self.down_blocks.append(self._make_down_block(down_channels[i], down_channels[i+1]))

        # --- Middle block ---
        self.mid_block = nn.Sequential(
            nn.Conv2d(down_channels[-1], down_channels[-1], kernel_size=3, padding=1),
            nn.GroupNorm(32, down_channels[-1]),
            nn.SiLU(),
            nn.Conv2d(down_channels[-1], down_channels[-1], kernel_size=3, padding=1),
            nn.GroupNorm(32, down_channels[-1]),
            nn.SiLU(),
        )

        # --- Up blocks ---
        # Input channels need to account for skip connections
        up_channels = down_channels[::-1] # [512, 512, 256, 256] (Output channels of mid/down blocks)
        self.up_blocks = nn.ModuleList([])
        for i in range(len(up_channels) - 1):
            # Input channels = output channels from previous up-block + skip connection channels from corresponding down-block
            in_ch = up_channels[i] + down_channels[-1-i] # e.g., up1: 512(mid) + 512(d3_skip), up2: 512(u1) + 512(d2_skip) ...
            out_ch = up_channels[i+1] # Target output channels for this block
            self.up_blocks.append(self._make_up_block(in_ch, out_ch))

        # --- Output projection ---
        # Output should predict noise, matching the VAE latent channels
        self.out_proj = nn.Conv2d(up_channels[-1], vae_latent_channels, kernel_size=3, padding=1)

    def _make_down_block(self, in_channels, out_channels):
        # Simpler block: Conv -> Norm -> Act -> Conv -> Norm -> Act -> Pool
        # Returns the block and expects skip connection to be handled externally if needed
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.AvgPool2d(2) # Pool at the end
        )

    def _make_up_block(self, in_channels, out_channels):
        # Upsample -> Conv -> Norm -> Act -> Conv -> Norm -> Act
        return nn.Sequential(
             nn.Upsample(scale_factor=2, mode='nearest'),
             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), # Adjust channels after upsampling & concat
             nn.GroupNorm(32, out_channels),
             nn.SiLU(),
             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
             nn.GroupNorm(32, out_channels),
             nn.SiLU(),
         )

    def forward(self, x: torch.Tensor, t: torch.Tensor, style: torch.Tensor, content: torch.Tensor) -> torch.Tensor:
        # x input shape expected: [B, vae_latent_channels, H', W'] e.g., [B, 256, 4, 16]
        # t input shape: [B]
        # style input shape: [B, style_dim]
        # content input shape: [B, content_dim]

        # 1. Embeddings
        t = t.unsqueeze(-1).float() # [B, 1]
        t_emb = self.time_embed(t) # [B, cond_embed_dim]
        style_emb = self.style_embed(style) # [B, cond_embed_dim]
        content_emb = self.content_embed(content) # [B, cond_embed_dim]

        # Combine embeddings and expand spatially
        # Add embeddings element-wise
        cond_emb = t_emb + style_emb + content_emb # [B, cond_embed_dim]
        cond_emb = cond_emb.unsqueeze(-1).unsqueeze(-1) # [B, cond_embed_dim, 1, 1]
        cond_emb = cond_emb.expand(-1, -1, self.latent_h, self.latent_w) # [B, cond_embed_dim, H', W']

        # 2. Concatenate input x with spatial conditioning
        x_cat = torch.cat([x, cond_emb], dim=1) # [B, vae_latent_channels + cond_embed_dim, H', W'] e.g. [B, 384, 4, 16]

        # 3. Initial Convolution
        h = self.initial_conv(x_cat) # [B, initial_out_channels, H', W'] e.g. [B, 256, 4, 16]

        # 4. Down Path (Encoder) - Store outputs for skip connections
        skip_connections = [h] # Store initial conv output as first "skip"
        for block in self.down_blocks:
            h = block[:-1](h) # Apply convs/norms/activations
            skip_connections.append(h) # Store output *before* pooling
            h = block[-1](h) # Apply pooling
        # Now h is the output after the last down block's pooling

        # 5. Middle Path
        h = self.mid_block(h) # Apply middle block convolutions

        # 6. Up Path (Decoder) - Use skip connections
        for i, block in enumerate(self.up_blocks):
            skip = skip_connections.pop() # Get corresponding skip connection (last stored is first needed)
            # Upsample h (using the first layer of the up_block)
            h = block[0](h) # Apply Upsample layer

            # Check spatial alignment before concatenation (optional but good practice)
            if h.shape[2:] != skip.shape[2:]:
                 # If sizes don't match (e.g., due to odd dimensions + pooling), resize h
                 print(f"Warning: Resizing skip connection {skip.shape} to match upsampled {h.shape}")
                 h = F.interpolate(h, size=skip.shape[2:], mode='bilinear', align_corners=False)

            # Concatenate upsampled h with the skip connection
            h = torch.cat([h, skip], dim=1) # Concatenate along channel dimension

            # Apply the rest of the up_block layers (convs, norms, activations)
            h = block[1:](h)

        # 7. Output Projection
        output = self.out_proj(h) # [B, vae_latent_channels, H', W']

        # Ensure output shape matches input shape (important!)
        if output.shape != x.shape:
            raise ValueError(f"UNet output shape {output.shape} does not match input shape {x.shape}")

        return output # Shape: [B, vae_latent_channels, H', W'] e.g. [B, 256, 4, 16]

class LatentDiffusionModel:
    def __init__(self, config: Config):
        self.config = config
        self.device = config.device

        # Define beta schedule
        self.beta = torch.linspace(config.beta_start, config.beta_end, config.timesteps, device=self.device)
        self.alpha = 1. - self.beta
        self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)

        # Initialize models
        self.vae = VAE(config).to(self.device)
        self.style_encoder = StyleEncoder(config).to(self.device)
        # Pass config to ContentEncoder
        self.content_encoder = ContentEncoder(config).to(self.device)
        self.diffusion_model = DiffusionUNet(config).to(self.device)

        # Define optimizers
        self.vae_optimizer = optim.Adam(self.vae.parameters(), lr=config.vae_lr)
        # Group diffusion model parameters with style/content encoders for joint training
        diffusion_params = list(self.diffusion_model.parameters()) + \
                           list(self.style_encoder.parameters()) + \
                           list(self.content_encoder.parameters())
        self.diffusion_optimizer = optim.Adam(diffusion_params, lr=config.diffusion_lr)

    def train_vae(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int = None) -> None:
        """Train the VAE model"""
        if epochs is None:
            epochs = self.config.vae_epochs

        print("Training VAE...")

        for epoch in range(epochs):
            self.vae.train()
            train_loss = 0.0
            recon_loss_total = 0.0
            kl_loss_total = 0.0

            progress_bar = tqdm(train_loader, desc=f"VAE Epoch {epoch+1}/{epochs}", leave=False)
            for batch in progress_bar:
                self.vae_optimizer.zero_grad()

                images = batch['image'].to(self.device) # Shape: [B, 1, H, W]

                # Forward pass
                recon_images, mu, logvar = self.vae(images)

                # Compute loss
                # Ensure reconstruction loss is calculated correctly, maybe scale invariant
                recon_loss = F.mse_loss(recon_images, images, reduction='sum') / images.shape[0] # Per-sample MSE sum
                # KL divergence loss (ensure it's summed over latent dims and averaged over batch)
                kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1) # Sum over latent dim
                kl_loss = torch.mean(kl_loss) # Average over batch

                # Adjust KL weight if needed (can start small and increase)
                kl_weight = 0.0001 # Example weight, tune this
                loss = recon_loss + kl_weight * kl_loss

                # Backward pass
                loss.backward()
                # Gradient clipping (optional but often helpful)
                torch.nn.utils.clip_grad_norm_(self.vae.parameters(), max_norm=1.0)
                self.vae_optimizer.step()

                train_loss += loss.item()
                recon_loss_total += recon_loss.item()
                kl_loss_total += kl_loss.item()
                progress_bar.set_postfix(loss=loss.item(), recon=recon_loss.item(), kl=kl_loss.item())


            train_loss /= len(train_loader)
            recon_loss_total /= len(train_loader)
            kl_loss_total /= len(train_loader)


            # Validation
            val_loss, val_recon, val_kl = self._validate_vae(val_loader)

            print(f"VAE Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} (Recon: {recon_loss_total:.4f}, KL: {kl_loss_total:.4f}) | "
                  f"Val Loss: {val_loss:.4f} (Recon: {val_recon:.4f}, KL: {val_kl:.4f})")

            # Save checkpoint (example: save every epoch or best model)
            # if (epoch + 1) % 1 == 0: # Save every epoch for debugging
            self._save_checkpoint(f"vae_epoch_{epoch+1}.pt", model=self.vae)
            # Add logic here to save the best model based on validation loss

        print("VAE training completed.")

    def _validate_vae(self, val_loader: DataLoader) -> tuple[float, float, float]:
        """Validate the VAE model"""
        self.vae.eval()
        val_loss = 0.0
        recon_loss_total = 0.0
        kl_loss_total = 0.0
        kl_weight = 0.0001 # Use same weight as training for comparison

        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(self.device)

                # Forward pass
                recon_images, mu, logvar = self.vae(images)

                # Compute loss
                recon_loss = F.mse_loss(recon_images, images, reduction='sum') / images.shape[0]
                kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
                kl_loss = torch.mean(kl_loss)
                loss = recon_loss + kl_weight * kl_loss

                val_loss += loss.item()
                recon_loss_total += recon_loss.item()
                kl_loss_total += kl_loss.item()


        val_loss /= len(val_loader)
        recon_loss_total /= len(val_loader)
        kl_loss_total /= len(val_loader)
        return val_loss, recon_loss_total, kl_loss_total

    # --- REWRITTEN train_diffusion ---
    def train_diffusion(self, train_loader: DataLoader, val_loader: DataLoader, epochs: int = None) -> None:
        """
        Train the Diffusion UNet model along with Style and Content Encoders,
        using the frozen VAE to operate in the latent space.
        """
        if epochs is None:
            epochs = self.config.diffusion_epochs

        print("Training Diffusion Model (UNet + Style/Content Encoders)...")

        # Ensure VAE is in eval mode and its parameters are frozen
        self.vae.eval()
        for param in self.vae.parameters():
            param.requires_grad = False

        # Ensure other models are trainable
        self.diffusion_model.train()
        self.style_encoder.train()
        self.content_encoder.train()

        for epoch in range(epochs):
            train_loss = 0.0
            progress_bar = tqdm(train_loader, desc=f"Diffusion Epoch {epoch+1}/{epochs}", leave=False)

            for batch in progress_bar:
                self.diffusion_optimizer.zero_grad()

                # --- Get Data ---
                images = batch['image'].to(self.device) # Shape: [B, 1, H, W]
                text_list = batch['transcription'] # List of strings

                # --- TODO: Implement Proper Text Tokenization ---
                # This is a placeholder. You NEED to replace this with actual tokenization
                # based on your dataset's vocabulary and padding strategy.
                # Example: Use a Hugging Face tokenizer or a custom vocabulary.
                # The output should be padded sequences of token indices.
                current_batch_size = images.shape[0]
                text_indices = torch.randint(0, self.config.vocab_size,
                                             (current_batch_size, self.config.max_seq_len),
                                             dtype=torch.long, device=self.device) # Placeholder
                # --- End Placeholder ---

                # --- VAE Encoding to Spatial Latent ---
                with torch.no_grad():
                    mu, logvar = self.vae.encode(images)
                    # Sample from the latent distribution (flattened)
                    z_flat = self.vae.reparameterize(mu, logvar) # Shape: [B, latent_dim]
                    # Project and reshape z back to spatial latent dimensions
                    latent_spatial = self.vae.get_spatial_latent(z_flat)
                    # Expected shape: [B, vae_latent_channels, H/16, W/16] e.g., [64, 256, 4, 16]

                # --- Diffusion Forward Process ---
                # Sample random timesteps for each image in the batch
                t = torch.randint(0, self.config.timesteps, (current_batch_size,), device=self.device).long()

                # Sample noise matching the spatial latent shape
                noise = torch.randn_like(latent_spatial) # Shape: [B, C, H', W']

                # Get alpha_cumprod for sampled timesteps t and reshape for broadcasting
                alpha_cumprod_t = self.alpha_cumprod[t].view(-1, 1, 1, 1) # Shape: [B, 1, 1, 1]

                # Calculate noisy latents (forward diffusion formula)
                noisy_latents = torch.sqrt(alpha_cumprod_t) * latent_spatial + \
                                torch.sqrt(1.0 - alpha_cumprod_t) * noise
                # noisy_latents shape: [B, C, H', W']

                # --- Conditioning ---
                # Encode style from images (using the original image)
                style_features = self.style_encoder(images) # Shape: [B, style_dim]
                # Encode content from text indices
                content_features = self.content_encoder(text_indices) # Shape: [B, content_dim]

                # --- Diffusion UNet Prediction ---
                # Predict the noise added at timestep t, conditioned on style and content
                noise_pred = self.diffusion_model(noisy_latents, t, style_features, content_features)
                # noise_pred expected shape: [B, C, H', W'] (same as noise)

                # --- Loss Calculation ---
                # Calculate loss between the predicted noise and the actual noise added
                loss = F.mse_loss(noise_pred, noise) # Simple MSE loss is common

                # --- Backward Pass & Optimization ---
                loss.backward()
                # Optional: Gradient clipping for diffusion model parameters
                torch.nn.utils.clip_grad_norm_(self.diffusion_optimizer.param_groups[0]['params'], max_norm=1.0)
                self.diffusion_optimizer.step()

                train_loss += loss.item()
                progress_bar.set_postfix(loss=loss.item())

            train_loss /= len(train_loader)

            # --- Validation ---
            val_loss = self._validate_diffusion(val_loader)

            print(f"Diffusion Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

            # --- Save Checkpoint ---
            # Example: Save every few epochs or based on validation loss
            if (epoch + 1) % 5 == 0: # Save every 5 epochs
                self._save_checkpoint(
                    f"diffusion_epoch_{epoch+1}.pt",
                    models_dict={ # Save as a dictionary for clarity
                         'diffusion_model': self.diffusion_model,
                         'style_encoder': self.style_encoder,
                         'content_encoder': self.content_encoder
                     }
                )
            # Add logic here to save the best model based on validation loss


        print("Diffusion model training completed.")

    # --- REWRITTEN _validate_diffusion ---
    def _validate_diffusion(self, val_loader: DataLoader) -> float:
        """
        Validate the diffusion model (UNet + Style/Content Encoders) on the validation set.
        Mirrors the training loop logic for calculating the diffusion loss.
        """
        self.diffusion_model.eval()
        self.style_encoder.eval()
        self.content_encoder.eval()
        # VAE should already be in eval mode from the start of train_diffusion

        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                # --- Get Data ---
                images = batch['image'].to(self.device)
                text_list = batch['transcription']
                current_batch_size = images.shape[0]

                # --- TODO: Implement Proper Text Tokenization (Placeholder) ---
                text_indices = torch.randint(0, self.config.vocab_size,
                                             (current_batch_size, self.config.max_seq_len),
                                             dtype=torch.long, device=self.device) # Placeholder
                # --- End Placeholder ---

                # --- VAE Encoding to Spatial Latent ---
                # Use VAE mean for validation/inference, or reparameterize like in training
                # Using reparameterization mirrors training loss calculation better
                mu, logvar = self.vae.encode(images)
                z_flat = self.vae.reparameterize(mu, logvar) # Or just use mu for deterministic validation
                latent_spatial = self.vae.get_spatial_latent(z_flat) # Shape: [B, C, H', W']

                # --- Diffusion Forward Process ---
                t = torch.randint(0, self.config.timesteps, (current_batch_size,), device=self.device).long()
                noise = torch.randn_like(latent_spatial)
                alpha_cumprod_t = self.alpha_cumprod[t].view(-1, 1, 1, 1)
                noisy_latents = torch.sqrt(alpha_cumprod_t) * latent_spatial + \
                                torch.sqrt(1.0 - alpha_cumprod_t) * noise

                # --- Conditioning ---
                style_features = self.style_encoder(images)
                content_features = self.content_encoder(text_indices)

                # --- Diffusion UNet Prediction ---
                noise_pred = self.diffusion_model(noisy_latents, t, style_features, content_features)

                # --- Loss Calculation ---
                loss = F.mse_loss(noise_pred, noise)
                val_loss += loss.item()

        # Switch models back to train mode after validation if needed within the epoch loop
        # (Done automatically at the start of the next training epoch loop iteration)
        self.diffusion_model.train()
        self.style_encoder.train()
        self.content_encoder.train()

        return val_loss / len(val_loader)


    def _save_checkpoint(self, filename: str, model: nn.Module = None, models_dict: dict = None) -> None:
        """Save model checkpoint(s)"""
        save_path = os.path.join(self.config.save_dir, filename)

        if model is not None:
            torch.save(model.state_dict(), save_path)
            print(f"Saved single model checkpoint to {save_path}")
        elif models_dict is not None:
            save_state = {name: m.state_dict() for name, m in models_dict.items()}
            torch.save(save_state, save_path)
            print(f"Saved multiple model checkpoints to {save_path}")
        else:
             print("Warning: _save_checkpoint called without model or models_dict.")


    # --- Generation function (kept similar, check spatial latent generation) ---
    @torch.no_grad() # Decorator for no_grad context
    def generate_handwriting(self, style_image: torch.Tensor, text: list[str], guidance_scale: float = 7.5) -> torch.Tensor:
        """
        Generate handwriting using the trained models with DDPM sampling and optional classifier-free guidance.

        Args:
            style_image: A single preprocessed style image tensor [1, H, W].
            text: A list containing the text string to generate.
            guidance_scale: Scale for classifier-free guidance. 0 means no guidance.

        Returns:
            A tensor representing the generated handwriting image [1, 1, H, W].
        """
        # Ensure models are in eval mode
        self.vae.eval()
        self.style_encoder.eval()
        self.content_encoder.eval()
        self.diffusion_model.eval()

        # --- Prepare Inputs ---
        # Style Image
        style_image = style_image.unsqueeze(0).to(self.device) # Add batch dim: [1, 1, H, W]
        style_features = self.style_encoder(style_image) # [1, style_dim]

        # Text
        # TODO: Use the same tokenization as in training for the input text
        text_indices = torch.randint(0, self.config.vocab_size,
                                      (1, self.config.max_seq_len),
                                      dtype=torch.long, device=self.device) # Placeholder [1, seq_len]
        content_features = self.content_encoder(text_indices) # [1, content_dim]

        # --- Prepare Unconditional Inputs (for Classifier-Free Guidance) ---
        # Often uses zeroed or averaged embeddings, or requires training with dropped conditioning
        # Placeholder: Use zero embeddings for unconditional generation
        uncond_style_features = torch.zeros_like(style_features).to(self.device)
        uncond_content_features = torch.zeros_like(content_features).to(self.device)
        # If trained with conditioning dropout, use specific null embeddings

        # Combine conditional and unconditional features if using guidance
        if guidance_scale > 0:
            style_features = torch.cat([uncond_style_features, style_features], dim=0) # [2, style_dim]
            content_features = torch.cat([uncond_content_features, content_features], dim=0) # [2, content_dim]
            batch_size = 2
        else:
            batch_size = 1 # Only conditional generation

        # --- Sampling Loop (DDPM) ---
        # Start with random noise in the VAE latent space
        latent_h = self.config.img_size[0] // 16
        latent_w = self.config.img_size[1] // 16
        latent_channels = self.config.vae_latent_channels
        latent_shape = (batch_size, latent_channels, latent_h, latent_w)
        latents = torch.randn(latent_shape, device=self.device) # Start with noise z_T

        progress_bar = tqdm(range(self.config.timesteps - 1, -1, -1), desc="Generating", leave=False)
        for i in progress_bar:
            t = torch.full((batch_size,), i, device=self.device, dtype=torch.long) # Timestep tensor [batch_size]

            # Predict noise (handles both conditional and unconditional if batch_size=2)
            noise_pred = self.diffusion_model(latents, t, style_features, content_features)

            # Classifier-Free Guidance
            if guidance_scale > 0:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                # Combine predictions: noise = uncond + scale * (cond - uncond)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                # Now noise_pred has shape [1, C, H', W'] if guidance was applied

            # --- DDPM Sampling Step ---
            # Get coefficients for timestep t
            alpha_t = self.alpha[i]
            alpha_cumprod_t = self.alpha_cumprod[i]
            beta_t = self.beta[i]

            # Calculate mean prediction (denoised latent estimate)
            # Based on Ho et al. (2020) DDPM paper, Eq. 11 & 15 implementation details
            coeff1 = 1.0 / torch.sqrt(alpha_t)
            coeff2 = beta_t / torch.sqrt(1.0 - alpha_cumprod_t)
            mean_pred = coeff1 * (latents - coeff2 * noise_pred)

            # Sample noise for the step (unless it's the last step)
            if i > 0:
                noise = torch.randn_like(latents)
                # Calculate variance (fixed small or learned, using beta_t for simplicity)
                variance = torch.sqrt(beta_t) * noise
                latents = mean_pred + variance # Update latents z_{t-1}
            else:
                latents = mean_pred # Final denoised latent z_0 (no noise added)

            # If guidance was used, latents might have batch size 2, keep only the guided one
            if guidance_scale > 0 and latents.shape[0] == 2:
                 latents = latents[1:] # Keep the conditional prediction

        # --- Decode Final Latent ---
        # latents should now be the estimated clean latent [1, C, H', W']
        # Decode using the VAE decoder
        generated_image = self.vae.decode(latents) # Pass spatial latent directly

        # Post-process: Clamp or rescale output if needed (VAE uses Tanh -> [-1, 1])
        generated_image = (generated_image + 1.0) / 2.0 # Rescale to [0, 1]
        generated_image = generated_image.clamp(0, 1)

        return generated_image # Shape: [1, 1, H, W]


In [12]:
config = Config()

Using MPS device.


In [13]:
data_wrapper = IAMDatasetWrapper(config)

In [14]:
train_loader, val_loader, test_loader = data_wrapper.setup()

Loading IAM dataset...
Dataset loaded. Train: 10682, Val: 1335, Test: 1336


In [15]:
model = LatentDiffusionModel(config)

In [16]:
model.train_vae(train_loader, val_loader)

Training VAE...


                                                                                                        

VAE Epoch 1/3 | Train Loss: 1858.4385 (Recon: 1857.6365, KL: 8019.0288) | Val Loss: 1130.3636 (Recon: 1130.1324, KL: 2312.1105)
Saved single model checkpoint to ./models/vae_epoch_1.pt


                                                                                                        

VAE Epoch 2/3 | Train Loss: 1119.0259 (Recon: 1118.8272, KL: 1987.3379) | Val Loss: 1101.7568 (Recon: 1101.5337, KL: 2230.6409)
Saved single model checkpoint to ./models/vae_epoch_2.pt


                                                                                                        

VAE Epoch 3/3 | Train Loss: 1063.5880 (Recon: 1063.3850, KL: 2029.4460) | Val Loss: 1027.3713 (Recon: 1027.1740, KL: 1972.7895)
Saved single model checkpoint to ./models/vae_epoch_3.pt
VAE training completed.


In [17]:
model.train_diffusion(train_loader, val_loader) #mark-b

Training Diffusion Model (UNet + Style/Content Encoders)...


                                                                                                        

AttributeError: 'VAE' object has no attribute 'get_spatial_latent'

In [5]:

    
    # Then train diffusion model
    model.train_diffusion(train_loader, val_loader)
    
    # Generate some samples
    for i, batch in enumerate(test_loader):
        if i >= 5:  # Generate 5 samples
            break
        
        # Get a sample for style reference
        style_image = batch['image'][0]
        text = batch['transcription'][0]
        
        # Generate new handwriting with the same style but different text
        generated = model.generate_handwriting(style_image, "This is a generated sample.")
        
        # Plot the results
        plt.figure(figsize=(10, 4))
        plt.subplot(1, 2, 1)
        plt.title("Original")
        plt.imshow(style_image.squeeze().cpu().numpy(), cmap='gray')
        plt.axis('off')
        
        plt.subplot(1, 2, 2)
        plt.title("Generated")
        plt.imshow(generated.squeeze().cpu().numpy(), cmap='gray')
        plt.axis('off')
        
        plt.savefig(os.path.join(config.results_dir, f"sample_{i}.png"))
        plt.close()
