# Implementation test

In [11]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from dataset import IAMLinesDataset, IAMLinesDatasetPyTorch
import importlib


In [65]:
# Configuration
class Config:
    def __init__(self):
        # Dataset parameters
        self.img_size = (64, 256)  # Height, Width - Adjusted for line-level segments
        self.batch_size = 64
        self.num_workers = 4
        
        # VAE parameters
        self.latent_dim = 512  # Latent dimension for the VAE
        self.vae_lr = 1e-4
        self.vae_epochs = 2 #orig: 50
        
        # Diffusion model parameters
        self.timesteps = 1000  # Number of diffusion steps
        self.beta_start = 1e-4  # Starting noise schedule value
        self.beta_end = 2e-2  # Ending noise schedule value
        self.diffusion_lr = 1e-4
        self.diffusion_epochs = 5 #orig: 100
        
        # Style and content encoder parameters
        self.style_dim = 256
        self.content_dim = 256
        
        # Training parameters
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
            print("Using mac gpu")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Using cuda gpu")
        else:
            self.device = torch.device("cpu")
            print("Using cpu")
        self.save_dir = "./models"
        self.results_dir = "./results"
        
        # Create directories if they don't exist
        os.makedirs(self.save_dir, exist_ok=True)
        os.makedirs(self.results_dir, exist_ok=True)

In [66]:
# Dataset and preprocessing
class IAMDatasetWrapper:
    def __init__(self, config):
        self.config = config
        self.transform = transforms.Compose([
            transforms.Resize(config.img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
        ])
        
    def setup(self):
        """Load the IAM dataset and create train/val/test splits"""
        print("Loading IAM dataset...")
        # Assuming IAMLinesDatasetPyTorch is already implemented
        data_path = 'data/lines.tgz' 
        xml_path = 'data/xml.tgz'    
        iam_dataset = IAMLinesDataset(data_path, xml_path)
        full_dataset = IAMLinesDatasetPyTorch(iam_dataset=iam_dataset, transform=self.transform)
        
        # Split into train/val/test
        train_size = int(0.8 * len(full_dataset))
        val_size = int(0.1 * len(full_dataset))
        test_size = len(full_dataset) - train_size - val_size
        
        train_dataset, val_dataset, test_dataset = random_split(
            full_dataset, [train_size, val_size, test_size]
        )
        
        self.train_loader = DataLoader(
            train_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=True,
            num_workers=self.config.num_workers
        )
        
        self.val_loader = DataLoader(
            val_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=False,
            num_workers=self.config.num_workers
        )
        
        self.test_loader = DataLoader(
            test_dataset, 
            batch_size=self.config.batch_size, 
            shuffle=False,
            num_workers=self.config.num_workers
        )
        
        print(f"Dataset loaded. Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
        return self.train_loader, self.val_loader, self.test_loader


In [106]:
# VAE Model
class VAE(nn.Module):
    def __init__(self, config):
        super(VAE, self).__init__()
        self.config = config
        self.latent_dim = config.latent_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        
        # Calculate the size of the encoder output
        h, w = config.img_size
        self.encoder_output_size = (h // 16, w // 16, 256)
        self.encoder_flattened_dim = self.encoder_output_size[0] * self.encoder_output_size[1] * self.encoder_output_size[2]
        
        # Mean and log variance layers
        self.fc_mu = nn.Linear(self.encoder_flattened_dim, self.latent_dim)
        self.fc_logvar = nn.Linear(self.encoder_flattened_dim, self.latent_dim)
        
        # Decoder input layer
        self.decoder_input = nn.Linear(self.latent_dim, self.encoder_flattened_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Output in range [-1, 1]
        )
    
    def encode(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)  # Flatten
        
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z
    
    def decode(self, z):
        z = self.decoder_input(z)
        z = z.view(-1, 256, self.encoder_output_size[0], self.encoder_output_size[1])
        x_recon = self.decoder(z)
        return x_recon
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

# Style Encoder
class StyleEncoder(nn.Module):
    def __init__(self, config):
        super(StyleEncoder, self).__init__()
        self.config = config
        
        # CNN backbone
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        
        # Projection head
        self.fc = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, config.style_dim)
        )
    
    def forward(self, x):
        x = self.cnn(x)
        x = x.view(x.size(0), -1)  # Flatten
        style = self.fc(x)
        return style

# Content Encoder (for text)
class ContentEncoder(nn.Module):
    def __init__(self, config, vocab_size=128, embedding_dim=128):
        super(ContentEncoder, self).__init__()
        self.config = config
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        
        # Projection head
        self.fc = nn.Linear(512, config.content_dim)
    
    def forward(self, x):
        # x shape: [batch_size, seq_len]
        embedded = self.embedding(x)
        
        # LSTM forward pass
        output, (hidden, _) = self.lstm(embedded)
        
        # Concatenate forward and backward hidden states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        
        # Project to content dimension
        content = self.fc(hidden)
        return content

# UNet model for diffusion
class DiffusionUNet(nn.Module):
    # def __init__(self, config):
    #     super(DiffusionUNet, self).__init__()
    #     self.config = config
        
    #     # Calculate latent image dimensions
    #     h, w = config.img_size
    #     self.latent_h, self.latent_w = h // 16, w // 16
        
    #     # Time embedding
    #     self.time_embed = nn.Sequential(
    #         nn.Linear(1, 128),
    #         nn.SiLU(),
    #         nn.Linear(128, 128),
    #     )
        
    #     # Conditioning embeddings
    #     self.style_embed = nn.Linear(config.style_dim, 128)
    #     self.content_embed = nn.Linear(config.content_dim, 128)
        
    #     # Down blocks
    #     self.down1 = self._make_down_block(256, 256)
    #     self.down2 = self._make_down_block(256, 512)
    #     self.down3 = self._make_down_block(512, 512)
        
    #     # Middle block
    #     self.mid = nn.Sequential(
    #         nn.Conv2d(512, 512, kernel_size=3, padding=1),
    #         nn.GroupNorm(32, 512),
    #         nn.SiLU(),
    #         nn.Conv2d(512, 512, kernel_size=3, padding=1),
    #         nn.GroupNorm(32, 512),
    #         nn.SiLU(),
    #     )
        
    #     # Up blocks
    #     self.up1 = self._make_up_block(512, 512)
    #     self.up2 = self._make_up_block(512, 256)
    #     self.up3 = self._make_up_block(256, 256)
        
    #     # Output projection
    #     self.out = nn.Conv2d(256, 256, kernel_size=3, padding=1)
    def __init__(self, config):
        super(DiffusionUNet, self).__init__()
        self.config = config
        
        # Calculate latent image dimensions
        h, w = config.img_size
        # Ensure latent dimensions match VAE bottleneck: H/16, W/16
        self.latent_h, self.latent_w = h // 16, w // 16 
        
        # VAE bottleneck channels
        vae_latent_channels = 256 # Should match the VAE's output channels before flattening
        
        # Conditioning embedding dimension (after combining time, style, content)
        cond_channels = 128 # Based on time_embed, style_embed, content_embed outputs
        
        # --- Time and Conditioning Embeddings (Keep as is) ---
        self.time_embed = nn.Sequential(
            nn.Linear(1, 128),
            nn.SiLU(),
            nn.Linear(128, 128),
        )
        self.style_embed = nn.Linear(config.style_dim, 128)
        self.content_embed = nn.Linear(config.content_dim, 128)
        
        # --- ADD Initial Convolution Layer ---
        # Input channels = VAE latent channels + conditioning channels
        initial_in_channels = vae_latent_channels + cond_channels # 256 + 128 = 384
        initial_out_channels = 256 # Desired channels for the first down block
        self.initial_conv = nn.Conv2d(initial_in_channels, initial_out_channels, kernel_size=3, padding=1)
        
        # --- Down blocks ---
        # Make sure the first down block starts with initial_out_channels
        self.down1 = self._make_down_block(initial_out_channels, 256) 
        self.down2 = self._make_down_block(256, 512)
        self.down3 = self._make_down_block(512, 512)
        
        # --- Middle block (Keep as is, input is 512 channels) ---
        self.mid = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.GroupNorm(32, 512),
            nn.SiLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.GroupNorm(32, 512),
            nn.SiLU(),
        )
        
        # --- Up blocks ---
        # Adjust input channels based on skip connections
        self.up1 = self._make_up_block(512 + 512, 512) # mid output (512) + d3 skip (512)
        self.up2 = self._make_up_block(512 + 512, 256) # u1 output (512) + d2 skip (512)
        self.up3 = self._make_up_block(256 + 256, 256) # u2 output (256) + d1 skip (256)
        
        # --- Output projection ---
        # Output should predict noise, matching the VAE latent channels
        self.out = nn.Conv2d(256, vae_latent_channels, kernel_size=3, padding=1) 

    
    def _make_down_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.AvgPool2d(2)
        )
    
    def _make_up_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(32, out_channels),
            nn.SiLU(),
        )
    
    # def forward(self, x, t, style, content):
    #     # Reshape inputs
    #     t = t.unsqueeze(-1).float()  # [B, 1]
        
    #     # Time embedding
    #     t_emb = self.time_embed(t)  # [B, 128]
        
    #     # Style and content embeddings
    #     style_emb = self.style_embed(style)  # [B, 128]
    #     content_emb = self.content_embed(content)  # [B, 128]
        
    #     # Combine embeddings
    #     cond_emb = (t_emb + style_emb + content_emb).unsqueeze(-1).unsqueeze(-1)  # [B, 128, 1, 1]
    #     # cond_emb = cond_emb.expand(-1, -1, self.latent_h, self.latent_w)  # [B, 128, H, W]
        
    #     cond_emb = cond_emb.expand(-1, -1, x.shape[2], x.shape[3])
    #     # _batch_size, _channels, H, W = x.shape 

    #     # Expand cond_emb to match the actual H and W of x
    #     # cond_emb = cond_emb.expand(-1, -1, H, W) # NEW LINE
    #     print(f"x shape: {x.shape}, cond_emb shape: {cond_emb.shape}") # mark-a
        
    #     # Initial projection
    #     x = torch.cat([x, cond_emb], dim=1)  # [B, 256+128, H, W]
        
    #     # Encoder path
    #     d1 = self.down1(x)
    #     d2 = self.down2(d1)
    #     d3 = self.down3(d2)
        
    #     # Middle
    #     mid = self.mid(d3)
        
    #     # Decoder path with skip connections
    #     u1 = self.up1(mid + d3)
    #     u2 = self.up2(u1 + d2)
    #     u3 = self.up3(u2 + d1)
        
    #     # Output
    #     return self.out(u3)
    def forward(self, x, t, style, content):
        # x input shape expected: [B, vae_latent_channels, H', W'] e.g., [64, 256, 4, 16]
        
        t = t.unsqueeze(-1).float() # [B, 1]
        
        # Time embedding
        t_emb = self.time_embed(t) # [B, 128]
        
        # Style and content embeddings
        style_emb = self.style_embed(style) # [B, 128]
        content_emb = self.content_embed(content) # [B, 128]
        
        # Combine embeddings and expand to spatial dimensions
        cond_emb = (t_emb + style_emb + content_emb).unsqueeze(-1).unsqueeze(-1) # [B, 128, 1, 1]
        cond_emb = cond_emb.expand(-1, -1, self.latent_h, self.latent_w) # [B, 128, 4, 16]

        # print(f"x shape: {x.shape}, cond_emb shape: {cond_emb.shape}") # Should now be [64, 256, 4, 16] and [64, 128, 4, 16]
        
        # Concatenate along channel dimension
        x_cat = torch.cat([x, cond_emb], dim=1) # [B, 256+128, H', W'] -> [64, 384, 4, 16]
        
        # --- Apply Initial Convolution ---
        h = self.initial_conv(x_cat) # Project channels: [64, 384, 4, 16] -> [64, 256, 4, 16]
        
        # Encoder path (pass projected 'h' to down1)
        d1 = self.down1(h) # Input shape [64, 256, 4, 16]
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        
        # Middle
        mid = self.mid(d3)
        
        # Decoder path with skip connections - Ensure channel counts match in _make_up_block inputs
        # Note: Skip connections should come from BEFORE the pooling in down blocks if using standard UNet blocks
        # If _make_down_block applies pooling last, d1, d2, d3 have reduced spatial size but correct channels.
        # If skip connections are added *before* upsampling in the up block, channel counts need adjustment.
        # Assuming standard ResNet-like blocks in _make_down/up_block where skip happens before pooling/after upsampling:
        # Let's assume skips d1, d2, d3 have shapes matching the outputs of down blocks *before pooling*
        # And that u1, u2, u3 are outputs of up blocks *after convolution*.
        # The provided _make_up_block applies Upsample first, then Conv.
        # The provided _make_down_block applies Conv first, then Pool.
        # Need to carefully handle skip connection shapes and channel dimensions.
        # --> Re-evaluating the skip connection logic based on provided blocks:
        # d1 output shape (after pool): [B, 256, 2, 8]
        # d2 output shape (after pool): [B, 512, 1, 4]
        # d3 output shape (after pool): [B, 512, H'/8, W'/8] -> error in block structure likely, should downsample H/W
        # mid output shape: [B, 512, H'/8, W'/8]
        # Let's assume down blocks work correctly spatially.
        
        # Simplified skip connections (concatenating output of upsample+conv with down block output):
        # Check channel dimensions carefully based on your specific _make_up_block implementation
        u1_input = torch.cat([mid, d3], dim=1) # Shape: [B, 512+512, H/8, W/8] - Requires up1 input channels 1024
        u1 = self.up1(u1_input)           # Output shape: [B, 512, H/4, W/4]
        
        u2_input = torch.cat([u1, d2], dim=1)  # Shape: [B, 512+512, H/4, W/4] - Requires up2 input channels 1024
        u2 = self.up2(u2_input)           # Output shape: [B, 256, H/2, W/2]
        
        u3_input = torch.cat([u2, d1], dim=1)  # Shape: [B, 256+256, H/2, W/2] - Requires up3 input channels 512
        u3 = self.up3(u3_input)           # Output shape: [B, 256, H, W]

        # Output - Predict noise matching VAE latent channels
        return self.out(u3) # Output shape: [B, 256, 4, 16]

# Diffusion Model
class LatentDiffusionModel:
    def __init__(self, config):
        self.config = config
        self.device = config.device
        
        # Define beta schedule
        self.beta = torch.linspace(config.beta_start, config.beta_end, config.timesteps).to(self.device)
        self.alpha = 1. - self.beta
        self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)
        
        # Initialize models
        self.vae = VAE(config).to(self.device)
        self.style_encoder = StyleEncoder(config).to(self.device)
        self.content_encoder = ContentEncoder(config).to(self.device)
        self.diffusion_model = DiffusionUNet(config).to(self.device)
        
        # Define optimizers
        self.vae_optimizer = optim.Adam(self.vae.parameters(), lr=config.vae_lr)
        self.style_optimizer = optim.Adam(self.style_encoder.parameters(), lr=config.diffusion_lr)
        self.content_optimizer = optim.Adam(self.content_encoder.parameters(), lr=config.diffusion_lr)
        self.diffusion_optimizer = optim.Adam(self.diffusion_model.parameters(), lr=config.diffusion_lr)
    
    def train_vae(self, train_loader, val_loader, epochs=None):
        """Train the VAE model"""
        if epochs is None:
            epochs = self.config.vae_epochs
        
        print("Training VAE...")
        
        for epoch in range(epochs):
            self.vae.train()
            train_loss = 0
            
            for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
                self.vae_optimizer.zero_grad()
                
                # Get handwriting images
                images = batch['image'].to(self.device)
                
                # Forward pass
                recon_images, mu, logvar = self.vae(images)
                
                # Compute loss
                recon_loss = F.mse_loss(recon_images, images)
                kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                loss = recon_loss + 0.001 * kl_loss
                
                # Backward pass
                loss.backward()
                self.vae_optimizer.step()
                
                train_loss += loss.item()
            
            train_loss /= len(train_loader)
            
            # Validation
            val_loss = self._validate_vae(val_loader)
            
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            
            # Save checkpoint
            if (epoch + 1) % 10 == 0:
                self._save_checkpoint(f"vae_epoch_{epoch+1}.pt", model=self.vae)
        
        print("VAE training completed.")
    
    def _validate_vae(self, val_loader):
        """Validate the VAE model"""
        self.vae.eval()
        val_loss = 0
        
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(self.device)
                
                # Forward pass
                recon_images, mu, logvar = self.vae(images)
                
                # Compute loss
                recon_loss = F.mse_loss(recon_images, images)
                kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                loss = recon_loss + 0.001 * kl_loss
                
                val_loss += loss.item()
        
        return val_loss / len(val_loader)
    
    # def train_diffusion(self, train_loader, val_loader, epochs=None):
    #     """Train the diffusion model"""
    #     if epochs is None:
    #         epochs = self.config.diffusion_epochs
        
    #     print("Training Diffusion Model...")
        
    #     # Ensure VAE is in eval mode and frozen
    #     self.vae.eval()
    #     for param in self.vae.parameters():
    #         param.requires_grad = False
        
    #     for epoch in range(epochs):
    #         # Training
    #         self.diffusion_model.train()
    #         self.style_encoder.train()
    #         self.content_encoder.train()
            
    #         train_loss = 0
            
    #         for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
    #             self.diffusion_optimizer.zero_grad()
    #             self.style_optimizer.zero_grad()
    #             self.content_optimizer.zero_grad()
                
    #             # Get data
    #             images = batch['image'].to(self.device)
    #             text = batch['transcription']  # This would need preprocessing
                
    #             # Encode text to indices (simplified - you'd need proper text tokenization)
    #             text_indices = torch.zeros(len(text), 100).long().to(self.device)
                
    #             # Encode images to latent space using VAE
    #             with torch.no_grad():
    #                 mu, _ = self.vae.encode(images)
    #                 latent_images = mu
    #                 # latent_images = mu.reshape(-1, 256, 4, 16) #change-c-1
                
    #             # Encode style and content
    #             style_features = self.style_encoder(images)
    #             content_features = self.content_encoder(text_indices)
                
    #             # Sample a random timestep for each image
    #             t = torch.randint(0, self.config.timesteps, (images.shape[0],)).to(self.device)
                
    #             # Add noise to latent images
    #             noise = torch.randn_like(latent_images)
    #             alpha_cumprod_t = self.alpha_cumprod[t].view(-1, 1, 1, 1)
    #             noisy_latents = torch.sqrt(alpha_cumprod_t) * latent_images + torch.sqrt(1 - alpha_cumprod_t) * noise
                
    #             # Predict the noise
    #             noise_pred = self.diffusion_model(noisy_latents, t, style_features, content_features)
                
    #             # Compute loss
    #             loss = F.mse_loss(noise_pred, noise)
                
    #             # Backward pass
    #             loss.backward()
    #             self.diffusion_optimizer.step()
    #             self.style_optimizer.step()
    #             self.content_optimizer.step()
                
    #             train_loss += loss.item()
            
    #         train_loss /= len(train_loader)
            
    #         # Validation
    #         val_loss = self._validate_diffusion(val_loader)
            
    #         print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            
    #         # Save checkpoint
    #         if (epoch + 1) % 10 == 0:
    #             self._save_checkpoint(
    #                 f"diffusion_epoch_{epoch+1}.pt", 
    #                 models=[self.diffusion_model, self.style_encoder, self.content_encoder]
    #             )
        
    #     print("Diffusion model training completed.")

    def train_diffusion(self, train_loader, val_loader, epochs=None):
        """Train the diffusion model"""
        if epochs is None:
            epochs = self.config.diffusion_epochs
        
        print("Training Diffusion Model...")
        
        # Ensure VAE is in eval mode and frozen
        self.vae.eval()
        for param in self.vae.parameters():
            param.requires_grad = False
        
        for epoch in range(epochs):
            # Training
            self.diffusion_model.train()
            self.style_encoder.train()
            self.content_encoder.train()
            
            train_loss = 0
            
            for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
                self.diffusion_optimizer.zero_grad()
                self.style_optimizer.zero_grad()
                self.content_optimizer.zero_grad()
                
                # Get data
                images = batch['image'].to(self.device)
                text = batch['transcription']  # This would need preprocessing
                
                # Encode text to indices (simplified - you'd need proper text tokenization)
                text_indices = torch.zeros(len(text), 100).long().to(self.device)
                
                # Encode images to latent space using VAE
                with torch.no_grad():
                    # Get mean and log variance from the encoder
                    mu, logvar = self.vae.encode(images) 
                    # Reparameterize to get the latent vector z (flattened)
                    z_flat = self.vae.reparameterize(mu, logvar) 
                    # Project and reshape z back to spatial latent dimensions [B, C, H', W']
                    # (Mimicking the start of the VAE's decode method)
                    latent_spatial = self.vae.decoder_input(z_flat)
                    latent_spatial = latent_spatial.view(
                        -1,  # Batch size
                        256, # VAE bottleneck channels
                        self.config.img_size[0] // 16, # Latent Height (4)
                        self.config.img_size[1] // 16  # Latent Width (16) - Assuming img_width is 256
                    ) # Expected shape: [64, 256, 4, 16]
    
                # Sample a random timestep for each image
                t = torch.randint(0, self.config.timesteps, (images.shape[0],)).to(self.device)
                
                # Add noise to the *spatial* latent images
                noise = torch.randn_like(latent_spatial) # Noise shape: [64, 256, 4, 16]
                alpha_cumprod_t = self.alpha_cumprod[t].view(-1, 1, 1, 1) # Shape: [64, 1, 1, 1]
                
                # Calculate noisy latents correctly using the spatial representation
                noisy_latents = torch.sqrt(alpha_cumprod_t) * latent_spatial + torch.sqrt(1 - alpha_cumprod_t) * noise
                # noisy_latents expected shape: [64, 256, 4, 16]
                
                # Encode style and content (Keep this part as is)
                style_features = self.style_encoder(images)
                content_features = self.content_encoder(text_indices)
                
                # Predict the noise using the Diffusion UNet
                # Pass the correctly shaped noisy_latents
                noise_pred = self.diffusion_model(noisy_latents, t, style_features, content_features)
                
                # Compute loss (noise_pred should also have shape [64, 256, 4, 16])
                loss = F.mse_loss(noise_pred, noise) # Match between predicted and actual noise
                
                
                # Backward pass
                loss.backward()
                self.diffusion_optimizer.step()
                self.style_optimizer.step()
                self.content_optimizer.step()
                
                train_loss += loss.item()
            
            train_loss /= len(train_loader)
            
            # Validation
            val_loss = self._validate_diffusion(val_loader)
            
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            
            # Save checkpoint
            if (epoch + 1) % 10 == 0:
                self._save_checkpoint(
                    f"diffusion_epoch_{epoch+1}.pt", 
                    models=[self.diffusion_model, self.style_encoder, self.content_encoder]
                )
        
        print("Diffusion model training completed.")
    
    def _validate_diffusion(self, val_loader):
        """Validate the diffusion model"""
        self.diffusion_model.eval()
        self.style_encoder.eval()
        self.content_encoder.eval()
        
        val_loss = 0
        
        with torch.no_grad():
            for batch in val_loader:
                # Get data
                images = batch['image'].to(self.device)
                text = batch['transcription']
                
                # Encode text to indices (simplified)
                text_indices = torch.zeros(len(text), 100).long().to(self.device)
                
                # Encode images to latent space using VAE
                mu, _ = self.vae.encode(images)
                latent_images = mu
                
                # Encode style and content
                style_features = self.style_encoder(images)
                content_features = self.content_encoder(text_indices)
                
                # Sample a random timestep for each image
                t = torch.randint(0, self.config.timesteps, (images.shape[0],)).to(self.device)
                
                # Add noise to latent images
                noise = torch.randn_like(latent_images)
                alpha_cumprod_t = self.alpha_cumprod[t].view(-1, 1, 1, 1)
                noisy_latents = torch.sqrt(alpha_cumprod_t) * latent_images + torch.sqrt(1 - alpha_cumprod_t) * noise
                
                # Predict the noise
                noise_pred = self.diffusion_model(noisy_latents, t, style_features, content_features)
                
                # Compute loss
                loss = F.mse_loss(noise_pred, noise)
                
                val_loss += loss.item()
        
        return val_loss / len(val_loader)
    
    def _save_checkpoint(self, filename, model=None, models=None):
        """Save model checkpoint"""
        save_path = os.path.join(self.config.save_dir, filename)
        
        if model is not None:
            torch.save(model.state_dict(), save_path)
        elif models is not None:
            save_dict = {}
            for i, m in enumerate(models):
                save_dict[f"model_{i}"] = m.state_dict()
            torch.save(save_dict, save_path)
    
    def generate_handwriting(self, style_image, text, steps=50):
        """Generate handwriting with a given style and text"""
        # Ensure models are in eval mode
        self.vae.eval()
        self.style_encoder.eval()
        self.content_encoder.eval()
        self.diffusion_model.eval()
        
        with torch.no_grad():
            # Preprocess style image
            style_image = style_image.unsqueeze(0).to(self.device)
            
            # Encode style
            style_features = self.style_encoder(style_image)
            
            # Encode text (simplified)
            text_indices = torch.zeros(1, 100).long().to(self.device)  # Placeholder
            content_features = self.content_encoder(text_indices)
            
            # Start with random noise in the latent space
            latent_shape = (1, 256, self.config.img_size[0] // 16, self.config.img_size[1] // 16)
            latent = torch.randn(latent_shape).to(self.device)
            
            # Denoise gradually
            for t in tqdm(range(self.config.timesteps - 1, -1, -1), desc="Generating"):
                # Get the timestep
                timestep = torch.tensor([t], device=self.device)
                
                # Predict noise
                noise_pred = self.diffusion_model(latent, timestep, style_features, content_features)
                
                # Get alpha values for current timestep
                alpha = self.alpha[t]
                alpha_cumprod = self.alpha_cumprod[t]
                beta = self.beta[t]
                
                # No noise at timestep 0
                if t > 0:
                    noise = torch.randn_like(latent)
                else:
                    noise = torch.zeros_like(latent)
                
                # Update latent
                latent = (1 / torch.sqrt(alpha)) * (
                    latent - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * noise_pred
                ) + torch.sqrt(beta) * noise
            
            # Decode latent to image
            generated_image = self.vae.decode(latent)
            
            return generated_image


In [107]:
config = Config()

Using mac gpu


In [108]:
data_wrapper = IAMDatasetWrapper(config)

In [109]:
train_loader, val_loader, test_loader = data_wrapper.setup()

Loading IAM dataset...
Dataset loaded. Train: 10682, Val: 1335, Test: 1336


In [110]:
model = LatentDiffusionModel(config)

In [111]:
model.train_vae(train_loader, val_loader)

Training VAE...


Epoch 1/2: 100%|██████████████████████████████████████████████████████| 167/167 [00:31<00:00,  5.38it/s]


Epoch 1/2, Train Loss: 0.4188, Val Loss: 0.1309


Epoch 2/2: 100%|██████████████████████████████████████████████████████| 167/167 [00:31<00:00,  5.38it/s]


Epoch 2/2, Train Loss: 0.1192, Val Loss: 0.1092
VAE training completed.


In [112]:
model.train_diffusion(train_loader, val_loader) #mark-b

Training Diffusion Model...


Epoch 1/5:   0%|                                                                | 0/167 [00:16<?, ?it/s]


RuntimeError: Given input size: (512x1x4). Calculated output size: (512x0x2). Output size is too small

In [5]:

    
    # Then train diffusion model
    model.train_diffusion(train_loader, val_loader)
    
    # Generate some samples
    for i, batch in enumerate(test_loader):
        if i >= 5:  # Generate 5 samples
            break
        
        # Get a sample for style reference
        style_image = batch['image'][0]
        text = batch['transcription'][0]
        
        # Generate new handwriting with the same style but different text
        generated = model.generate_handwriting(style_image, "This is a generated sample.")
        
        # Plot the results
        plt.figure(figsize=(10, 4))
        plt.subplot(1, 2, 1)
        plt.title("Original")
        plt.imshow(style_image.squeeze().cpu().numpy(), cmap='gray')
        plt.axis('off')
        
        plt.subplot(1, 2, 2)
        plt.title("Generated")
        plt.imshow(generated.squeeze().cpu().numpy(), cmap='gray')
        plt.axis('off')
        
        plt.savefig(os.path.join(config.results_dir, f"sample_{i}.png"))
        plt.close()
