# VAE for Chest X-Ray Generation
**ECE 285 - Deep Generative Models**

Optimized implementation with:
- Spatial latent space for preserving structure
- Perceptual loss (VGG) for sharp images
- GroupNorm + SiLU for stability
- Low KL weight (beta linearly warmed up from 0 to 0.001 in Phase 2) for reconstruction quality
- Two-phase training strategy

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms, models
from torchvision.transforms import functional as TF
from PIL import Image
import numpy as np
from scipy import linalg
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
import cv2
import warnings
# Silence every warning for the rest of the session (keeps notebook output short).
warnings.simplefilter('ignore')

# Prefer the GPU whenever PyTorch can see one.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

## Configuration

In [None]:
# Hyper-parameters shared by every later cell.
CONFIG = dict(
    img_size=128,               # images are resized to img_size x img_size
    latent_spatial=(4, 4),      # spatial size of the latent grid
    latent_channels=64,
    batch_size=32,
    epochs_phase1=50,           # reconstruction only (beta = 0)
    epochs_phase2=100,          # reconstruction + beta * KL
    learning_rate=1e-3,
    weight_decay=1e-5,
    beta_start=0.0,
    beta_end=0.001,
    beta_warmup=50,             # phase-2 epochs over which beta ramps up
    num_workers=2,
    l1_weight=10.0,
    perceptual_weight=1.0,
    ssim_weight=1.0,
)

DATA_DIR = '/kaggle/input/chest-xray-pneumonia/chest_xray'
for out_dir in ('/kaggle/working/checkpoints', '/kaggle/working/results'):
    os.makedirs(out_dir, exist_ok=True)

print('Configuration:')
print('\n'.join(f'  {key}: {value}' for key, value in CONFIG.items()))

## Dataset

In [None]:
def apply_clahe(img_np, clip_limit=2.0, tile_size=(8, 8)):
    """Apply OpenCV CLAHE (contrast-limited adaptive histogram equalisation).

    Args:
        img_np: single-channel image array. ``uint8`` arrays are used as-is;
            any other dtype is assumed to hold intensities in [0, 1] and is
            rescaled to [0, 255], with out-of-range values clipped.
        clip_limit: CLAHE contrast limit (``clipLimit``).
        tile_size: CLAHE tile grid size (``tileGridSize``).

    Returns:
        The equalised image as a ``uint8`` array.
    """
    if img_np.dtype != np.uint8:
        # Clip before casting: a bare astype(np.uint8) on values outside
        # [0, 255] wraps around or is undefined (e.g. 1.2 * 255 -> 50),
        # which would show up as speckle artefacts.
        img_np = np.clip(img_np * 255.0, 0, 255).astype(np.uint8)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_size)
    return clahe.apply(img_np)

class ChestXRayDataset(Dataset):
    """Grayscale chest X-rays from ``<root_dir>/<split>/{NORMAL,PNEUMONIA}``.

    Each item is the tensor produced by ``TF.to_tensor`` after CLAHE and
    resizing to ``img_size x img_size``; items of the ``'train'`` split are
    horizontally flipped with probability 0.5. Unreadable files are replaced
    by the next readable one.
    """

    def __init__(self, root_dir, img_size=128, split='train'):
        self.img_size = img_size
        self.is_training = (split == 'train')
        self.image_paths = []

        split_dir = os.path.join(root_dir, split)
        for category in ['NORMAL', 'PNEUMONIA']:
            cat_path = os.path.join(split_dir, category)
            if os.path.exists(cat_path):
                for ext in ['*.jpeg', '*.jpg', '*.png']:
                    # glob order is filesystem-dependent; sorting keeps the
                    # index -> file mapping reproducible across machines.
                    self.image_paths.extend(sorted(glob.glob(os.path.join(cat_path, ext))))

        print(f'{split}: {len(self.image_paths)} images')

    def __len__(self):
        return len(self.image_paths)

    def _load(self, path):
        """Read one file and return its preprocessed (possibly flipped) tensor."""
        with Image.open(path) as raw:
            img = raw.convert('L')
        img = Image.fromarray(apply_clahe(np.array(img)))
        img = TF.resize(img, (self.img_size, self.img_size))
        # PyTorch seeds its own RNG per DataLoader worker, whereas other
        # libraries' RNG state (e.g. NumPy) may be duplicated across workers.
        if self.is_training and torch.rand(1).item() < 0.5:
            img = TF.hflip(img)
        return TF.to_tensor(img)

    def __getitem__(self, idx):
        n = len(self)
        last_err = None
        # Try idx, idx+1, ... (wrapping) at most once per file. The previous
        # bare-except recursion swallowed KeyboardInterrupt and could hit
        # RecursionError when many consecutive files were unreadable.
        for offset in range(n):
            path = self.image_paths[(idx + offset) % n]
            try:
                return self._load(path)
            except Exception as err:  # corrupt/unsupported file: try the next one
                last_err = err
        raise RuntimeError(f'No readable image found among {n} files') from last_err

# One dataset per split, built in train -> val -> test order.
train_dataset, val_dataset, test_dataset = (
    ChestXRayDataset(DATA_DIR, CONFIG['img_size'], split)
    for split in ('train', 'val', 'test')
)

# Training: shuffled, incomplete last batch dropped.
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
    drop_last=True,
)
# Evaluation: fixed order, every image kept.
val_loader, test_loader = (
    DataLoader(
        dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=CONFIG['num_workers'],
        pin_memory=True,
    )
    for dataset in (val_dataset, test_dataset)
)

In [None]:
# Visualize samples: the first eight images of one training batch.
sample = next(iter(train_loader))
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for idx, panel in enumerate(axes.flat):
    panel.imshow(sample[idx].squeeze(), cmap='gray')
    panel.axis('off')
plt.suptitle('Sample Training Images')
plt.tight_layout()
plt.savefig('/kaggle/working/results/sample_data.png', dpi=150)
plt.show()

batch_mean, batch_std = sample.mean(), sample.std()
print(f'Stats - Mean: {batch_mean:.3f}, Std: {batch_std:.3f}')

## Model Architecture

In [None]:
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + GroupNorm layers, SiLU activations.

    ``stride`` is applied by the first convolution (and by the projection
    shortcut), so ``stride=2`` halves the spatial size. The shortcut is an
    identity when the shape is unchanged, otherwise a 1x1 conv + GroupNorm.
    ``out_ch`` must be divisible by ``groups`` (a GroupNorm requirement).
    """

    def __init__(self, in_ch, out_ch, stride=1, groups=8):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
        self.gn1 = nn.GroupNorm(num_groups=groups, num_channels=out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.gn2 = nn.GroupNorm(num_groups=groups, num_channels=out_ch)

        needs_projection = stride != 1 or in_ch != out_ch
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride),
                nn.GroupNorm(num_groups=groups, num_channels=out_ch),
            )
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        residual = self.shortcut(x)
        hidden = F.silu(self.gn1(self.conv1(x)))
        hidden = self.gn2(self.conv2(hidden))
        return F.silu(hidden + residual)

class AttentionBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Every spatial position attends to every position of the same feature
    map; the attended values go through a 1x1 projection and are added back
    onto the input, so the output shape equals the input shape.
    """

    def __init__(self, channels, groups=8):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=groups, num_channels=channels)
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)  # fused Q/K/V projection
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        batch, ch, height, width = x.shape
        tokens = height * width
        # (B, 3C, H, W) -> three (B, C, H*W) tensors: query, key, value.
        query, key, value = self.qkv(self.gn(x)).reshape(batch, 3, ch, tokens).unbind(dim=1)
        # weights[b, i, j]: how strongly position i attends to position j.
        weights = torch.bmm(query.transpose(1, 2), key) / (ch ** 0.5)
        weights = weights.softmax(dim=-1)
        attended = torch.bmm(value, weights.transpose(1, 2)).reshape(batch, ch, height, width)
        return x + self.proj(attended)

class Encoder(nn.Module):
    """Convolutional encoder from a 1-channel image to a spatial Gaussian posterior.

    Five stride-2 stages (the 7x7 stem plus four ResBlocks) shrink H and W by
    32, e.g. 128x128 -> 4x4. Self-attention follows the 256-channel stage.
    ``forward`` returns ``(mu, logvar)``, each with ``latent_channels``
    channels at the downsampled resolution.
    """

    def __init__(self, latent_channels=64):
        super().__init__()
        # Attribute names double as state_dict keys; keep them unchanged.
        self.init = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3),
            nn.GroupNorm(8, 32),
            nn.SiLU(),
        )
        self.down1 = ResBlock(32, 64, stride=2)
        self.down2 = ResBlock(64, 128, stride=2)
        self.down3 = nn.Sequential(ResBlock(128, 256, stride=2), AttentionBlock(256))
        self.down4 = ResBlock(256, 512, stride=2)

        # Two parallel heads read the same 512-channel feature map.
        self.mu_conv = nn.Conv2d(512, latent_channels, kernel_size=3, stride=1, padding=1)
        self.logvar_conv = nn.Conv2d(512, latent_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        features = x
        for stage in (self.init, self.down1, self.down2, self.down3, self.down4):
            features = stage(features)
        return self.mu_conv(features), self.logvar_conv(features)

class Decoder(nn.Module):
    """Maps a spatial latent back to a 1-channel image bounded by a Sigmoid.

    Five bilinear x2 upsampling stages (four ResBlock stages plus the output
    head) grow H and W by 32, mirroring the Encoder. Self-attention follows
    the 128-channel stage.
    """

    def __init__(self, latent_channels=64):
        super().__init__()
        self.init = nn.Conv2d(latent_channels, 512, kernel_size=3, stride=1, padding=1)

        def upsample():
            return nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        # Attribute names double as state_dict keys; keep them unchanged.
        self.up1 = nn.Sequential(upsample(), ResBlock(512, 256, 1))
        self.up2 = nn.Sequential(upsample(), ResBlock(256, 128, 1), AttentionBlock(128))
        self.up3 = nn.Sequential(upsample(), ResBlock(128, 64, 1))
        self.up4 = nn.Sequential(upsample(), ResBlock(64, 32, 1))
        self.final = nn.Sequential(
            upsample(),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, 16),
            nn.SiLU(),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        out = self.init(z)
        for stage in (self.up1, self.up2, self.up3, self.up4, self.final):
            out = stage(out)
        return out

class SpatialVAE(nn.Module):
    """VAE whose latent is a ``(latent_channels, H, W)`` feature map.

    Args:
        latent_channels: channel count of the latent grid.
        latent_spatial: ``(H, W)`` of the latent grid used when sampling from
            the prior in ``generate``. It must equal the encoder's output
            size at the training resolution (4x4 for 128x128 inputs, since
            the encoder downsamples by 32). Defaults to the previously
            hard-coded ``(4, 4)``.
    """

    def __init__(self, latent_channels=64, latent_spatial=(4, 4)):
        super().__init__()
        self.encoder = Encoder(latent_channels)
        self.decoder = Decoder(latent_channels)
        self.latent_channels = latent_channels
        # Plain attribute (not a buffer), so state_dict keys are unchanged.
        self.latent_spatial = tuple(latent_spatial)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch of images."""
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

    @torch.no_grad()
    def generate(self, n, device, batch_size=None):
        """Decode ``n`` latents drawn from the standard-normal prior.

        Args:
            n: number of images to generate.
            device: device on which ``z`` is sampled (should match the model).
            batch_size: if set, decode in chunks of at most this many samples
                to bound peak memory; ``None`` decodes everything at once.

        Returns:
            Decoder output for the ``n`` sampled latents, concatenated along
            the batch dimension.
        """
        h, w = self.latent_spatial
        if not batch_size or batch_size >= n:
            # Sample directly on the target device instead of CPU + copy.
            z = torch.randn(n, self.latent_channels, h, w, device=device)
            return self.decoder(z)
        chunks = []
        for start in range(0, n, batch_size):
            count = min(batch_size, n - start)
            z = torch.randn(count, self.latent_channels, h, w, device=device)
            chunks.append(self.decoder(z))
        return torch.cat(chunks, dim=0)

## Loss Functions

In [None]:
class PerceptualLoss(nn.Module):
    """Summed L1 distance between frozen VGG16 activations of two images.

    Inputs must be single-channel ``(B, 1, H, W)`` tensors: they are
    replicated to three channels and normalised with ImageNet statistics
    (which presumes intensities in [0, 1] — TODO confirm for new callers),
    then compared after the relu1_2, relu2_2 and relu3_3 stages.
    """

    def __init__(self, device):
        super().__init__()
        features = models.vgg16(pretrained=True).features.eval().to(device)
        features.requires_grad_(False)  # frozen feature extractor
        # Consecutive slices ending at relu1_2, relu2_2 and relu3_3.
        self.slices = nn.ModuleList([features[:4], features[4:9], features[9:16]])
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer('mean', imagenet_mean)
        self.register_buffer('std', imagenet_std)

    def forward(self, x, y):
        # The buffers may live on another device than the inputs.
        mean = self.mean.to(x.device)
        std = self.std.to(x.device)

        def to_vgg_input(img):
            return (img.repeat(1, 3, 1, 1) - mean) / std

        feat_x, feat_y = to_vgg_input(x), to_vgg_input(y)
        total = 0
        for stage in self.slices:
            feat_x, feat_y = stage(feat_x), stage(feat_y)
            total += F.l1_loss(feat_x, feat_y)
        return total

class SSIMLoss(nn.Module):
    """``1 - mean(SSIM map)`` between two image batches, Gaussian-windowed.

    Works for any channel count: the window is applied to each channel
    independently (depthwise convolution), so channels never mix; for a
    single channel this is exactly a plain 2-D convolution. Borders are
    zero-padded so the SSIM map keeps the input's spatial size. The
    constants ``C1 = 0.01**2`` and ``C2 = 0.03**2`` presume intensities in
    [0, 1].

    Args:
        window_size: side length of the Gaussian window.
        sigma: standard deviation of the Gaussian, in pixels.
    """

    def __init__(self, window_size=11, sigma=1.5):
        super().__init__()
        self.window_size = window_size
        coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g = g / g.sum()
        # Outer product of the 1-D Gaussian -> (1, 1, k, k) kernel.
        window = (g.unsqueeze(0) * g.unsqueeze(1)).unsqueeze(0).unsqueeze(0)
        self.register_buffer('window', window)

    def forward(self, x, y):
        C1, C2 = 0.01 ** 2, 0.03 ** 2
        pad = self.window_size // 2
        channels = x.shape[1]
        # One window copy per channel + groups=channels -> depthwise conv.
        # Casting to x's dtype also allows non-float32 inputs.
        window = self.window.to(device=x.device, dtype=x.dtype).repeat(channels, 1, 1, 1)

        def blur(t):
            return F.conv2d(t, window, padding=pad, groups=channels)

        mu_x = blur(x)
        mu_y = blur(y)
        mu_x_sq, mu_y_sq, mu_xy = mu_x ** 2, mu_y ** 2, mu_x * mu_y

        sigma_x_sq = blur(x * x) - mu_x_sq
        sigma_y_sq = blur(y * y) - mu_y_sq
        sigma_xy = blur(x * y) - mu_xy

        ssim = ((2 * mu_xy + C1) * (2 * sigma_xy + C2)) / \
               ((mu_x_sq + mu_y_sq + C1) * (sigma_x_sq + sigma_y_sq + C2))
        return 1 - ssim.mean()

class VAELoss(nn.Module):
    """Weighted reconstruction loss (L1 + VGG perceptual + SSIM) plus beta * KL.

    Args:
        device: device for the VGG feature extractor.
        l1_w, perc_w, ssim_w: weights of the three reconstruction terms.

    ``forward`` returns ``(total, recon_total, kl, l1, perceptual, ssim)``.
    The KL term is the closed-form KL(N(mu, exp(logvar)) || N(0, 1)) averaged
    over every latent element (``torch.mean``), not summed.
    """

    def __init__(self, device, l1_w=1.0, perc_w=0.1, ssim_w=0.5):
        super().__init__()
        self.perceptual = PerceptualLoss(device)
        self.ssim = SSIMLoss()
        self.l1_w, self.perc_w, self.ssim_w = l1_w, perc_w, ssim_w

    def forward(self, recon, target, mu, logvar, beta=0.001):
        pixel_l1 = F.l1_loss(recon, target, reduction='mean')
        vgg_term = self.perceptual(recon, target)
        ssim_term = self.ssim(recon, target)

        kl_term = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        weighted_recon = (
            self.l1_w * pixel_l1
            + self.perc_w * vgg_term
            + self.ssim_w * ssim_term
        )
        return weighted_recon + beta * kl_term, weighted_recon, kl_term, pixel_l1, vgg_term, ssim_term

## Initialize Model

In [None]:
model = SpatialVAE(latent_channels=CONFIG['latent_channels']).to(device)
loss_fn = VAELoss(device, CONFIG['l1_weight'], CONFIG['perceptual_weight'], CONFIG['ssim_weight'])

total_params = sum(p.numel() for p in model.parameters())
print(f'Model: SpatialVAE with GroupNorm + SiLU')
print(f'Total parameters: {total_params:,}')
print(f'Latent shape: {CONFIG["latent_spatial"]} x {CONFIG["latent_channels"]}')

# Shape smoke test at the configured resolution. no_grad: the outputs are
# only inspected, so there is no reason to build (and keep alive through the
# global `recon`) an autograd graph.
img_size = CONFIG['img_size']
test_input = torch.rand(2, 1, img_size, img_size, device=device)
with torch.no_grad():
    recon, mu, logvar = model(test_input)
print(f'Input: {test_input.shape} -> Output: {recon.shape}')
print(f'Latent mu: {mu.shape}')
print(f'Output range: [{recon.min():.3f}, {recon.max():.3f}]')

# CONFIG['latent_spatial'] is what gets reported as the latent shape; make
# sure it matches what the encoder actually produces at this resolution.
encoder_grid = tuple(mu.shape[-2:])
if encoder_grid != tuple(CONFIG['latent_spatial']):
    raise ValueError(
        f'Encoder latent grid {encoder_grid} != CONFIG latent_spatial {CONFIG["latent_spatial"]}'
    )

## Training Functions

In [None]:
def get_beta(epoch, phase, config):
    """Return the KL weight for ``epoch`` of training ``phase``.

    Phase 1 always returns 0.0 (pure reconstruction). Any other phase ramps
    beta linearly from ``beta_start`` (at epoch 0) to ``beta_end`` at
    ``beta_warmup`` and holds it there afterwards. Because the training
    loops count epochs from 1, the first phase-2 epoch already sits one step
    into the ramp.
    """
    if phase == 1:
        return 0.0
    warmup = config['beta_warmup']
    if epoch < warmup:
        lo, hi = config['beta_start'], config['beta_end']
        return lo + (hi - lo) * (epoch / warmup)
    return config['beta_end']

def train_epoch(model, loader, optimizer, loss_fn, device, beta):
    """Run one optimisation pass over ``loader``.

    Gradients are clipped to a global norm of 1.0 before every step; a tqdm
    bar shows the current total and reconstruction losses.

    Returns:
        ``(mean total loss, mean reconstruction loss, mean KL)`` over batches.
    """
    model.train()
    running_total = running_recon = running_kl = 0

    progress = tqdm(loader, desc='Training')
    for batch in progress:
        batch = batch.to(device)
        optimizer.zero_grad()

        recon, mu, logvar = model(batch)
        loss, recon_loss, kl_loss, _l1, _perc, _ssim = loss_fn(recon, batch, mu, logvar, beta)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_value = loss.item()
        recon_value = recon_loss.item()
        running_total += loss_value
        running_recon += recon_value
        running_kl += kl_loss.item()

        progress.set_postfix({'loss': f'{loss_value:.2f}', 'recon': f'{recon_value:.2f}'})

    num_batches = len(loader)
    return running_total / num_batches, running_recon / num_batches, running_kl / num_batches

@torch.no_grad()
def validate(model, loader, loss_fn, device, beta):
    """Return the mean total loss over ``loader`` at the given ``beta``.

    Runs in eval mode without gradient tracking; the model is left in eval
    mode afterwards.
    """
    model.eval()
    batch_losses = []
    for batch in loader:
        batch = batch.to(device)
        recon, mu, logvar = model(batch)
        total = loss_fn(recon, batch, mu, logvar, beta)[0]
        batch_losses.append(total.item())
    return sum(batch_losses) / len(loader)

## Phase 1: Reconstruction Training

In [None]:
optimizer = optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)
# Cosine annealing with warm restarts: the first cycle lasts 10 epochs and
# each following cycle is twice as long as the previous one.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# Per-epoch histories; Phase 2 keeps appending to these same lists.
train_losses, recon_losses, kl_losses, val_losses, beta_history = ([] for _ in range(5))

print('=' * 60)
print('PHASE 1: Reconstruction-Only Training (beta=0)')
print('=' * 60)

for epoch in range(1, CONFIG['epochs_phase1'] + 1):
    beta = get_beta(epoch, 1, CONFIG)  # get_beta returns 0.0 for phase 1
    beta_history.append(beta)

    train_loss, recon_loss, kl_loss = train_epoch(model, train_loader, optimizer, loss_fn, device, beta)
    val_loss = validate(model, val_loader, loss_fn, device, beta)

    for history, value in ((train_losses, train_loss), (recon_losses, recon_loss),
                           (kl_losses, kl_loss), (val_losses, val_loss)):
        history.append(value)
    scheduler.step()

    print(f'Epoch {epoch}/{CONFIG["epochs_phase1"]} | Train: {train_loss:.3f} | Val: {val_loss:.3f}')

    if epoch % 10 != 0:
        continue

    # Every 10 epochs: decode four prior samples for a visual check.
    samples = model.generate(4, device)
    fig, axes = plt.subplots(1, 4, figsize=(10, 3))
    for ax, img in zip(axes, samples):
        ax.imshow(img.cpu().squeeze(), cmap='gray')
        ax.axis('off')
    plt.suptitle(f'Phase 1 - Epoch {epoch}')
    plt.tight_layout()
    plt.show()

torch.save(model.state_dict(), '/kaggle/working/checkpoints/phase1_final.pt')
print('Phase 1 complete!')

## Phase 2: VAE Training with KL

In [None]:
optimizer = optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'] * 0.5, weight_decay=CONFIG['weight_decay'])
# Halve the LR after 10 epochs without improvement. Keyword arguments and no
# `verbose` flag (deprecated in recent PyTorch); LR drops are printed below.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

print('\n' + '='*60)
print('PHASE 2: VAE Training with KL Regularization')
print('='*60)

# Checkpoint selection and LR scheduling use the validation loss at a FIXED
# beta (the final one). The per-epoch `val_loss` is computed at the current,
# still-rising beta, so comparing it across epochs would favour early
# low-beta epochs and make the plateau scheduler react to beta growing
# rather than to the model. `val_loss` is still recorded for the curves.
selection_beta = CONFIG['beta_end']
best_loss = float('inf')
phase2_start = CONFIG['epochs_phase1']

for epoch in range(1, CONFIG['epochs_phase2'] + 1):
    global_epoch = phase2_start + epoch
    beta = get_beta(epoch, 2, CONFIG)
    beta_history.append(beta)

    train_loss, recon_loss, kl_loss = train_epoch(model, train_loader, optimizer, loss_fn, device, beta)
    val_loss = validate(model, val_loader, loss_fn, device, beta)
    # After warm-up beta already equals selection_beta: reuse val_loss.
    if beta == selection_beta:
        select_loss = val_loss
    else:
        select_loss = validate(model, val_loader, loss_fn, device, selection_beta)

    train_losses.append(train_loss)
    recon_losses.append(recon_loss)
    kl_losses.append(kl_loss)
    val_losses.append(val_loss)

    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(select_loss)
    lr_after = optimizer.param_groups[0]['lr']

    print(f'Epoch {global_epoch} | β: {beta:.5f} | Train: {train_loss:.3f} | Val: {val_loss:.3f} | KL: {kl_loss:.3f}')
    if lr_after < lr_before:
        print(f'  -> LR reduced to {lr_after:.2e}')

    if select_loss < best_loss:
        best_loss = select_loss
        torch.save({
            'epoch': global_epoch,
            'model_state_dict': model.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'config': CONFIG,
            'selection_beta': selection_beta,
        }, '/kaggle/working/checkpoints/best_model.pt')
        print(f'  -> Saved best model (val_loss @ beta={selection_beta}: {select_loss:.3f})')

    if epoch % 10 == 0:
        samples = model.generate(8, device)
        fig, axes = plt.subplots(2, 4, figsize=(12, 6))
        for i, ax in enumerate(axes.flat):
            ax.imshow(samples[i].cpu().squeeze(), cmap='gray')
            ax.axis('off')
        plt.suptitle(f'Phase 2 - Epoch {global_epoch} (β={beta:.5f})')
        plt.tight_layout()
        plt.show()

torch.save(model.state_dict(), '/kaggle/working/checkpoints/final_model.pt')
print('\nTraining complete!')

## Loss Curves & Results

In [None]:
# 2x2 overview of the training histories. The dashed red line sits at index
# epochs_phase1, i.e. the first Phase 2 entry (histories are 0-indexed).
phase_boundary = CONFIG['epochs_phase1']
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

total_ax = axes[0, 0]
total_ax.plot(train_losses, label='Train', linewidth=2)
total_ax.plot(val_losses, label='Val', linewidth=2)
total_ax.axvline(phase_boundary, color='red', linestyle='--', alpha=0.5, label='Phase 2')
total_ax.set_xlabel('Epoch')
total_ax.set_ylabel('Loss')
total_ax.set_title('Total Loss')
total_ax.legend()
total_ax.grid(alpha=0.3)

for panel, series, color, ylabel, title in (
    (axes[0, 1], recon_losses, 'green', 'Loss', 'Reconstruction Loss'),
    (axes[1, 0], kl_losses, 'red', 'Loss', 'KL Divergence'),
    (axes[1, 1], beta_history, 'purple', 'Beta', 'Beta Schedule'),
):
    panel.plot(series, linewidth=2, color=color)
    panel.axvline(phase_boundary, color='red', linestyle='--', alpha=0.5)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(ylabel)
    panel.set_title(title)
    panel.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('/kaggle/working/results/loss_curves.png', dpi=150)
plt.show()

for label, value in (('Final Train Loss', train_losses[-1]),
                     ('Final Val Loss', val_losses[-1]),
                     ('Best Val Loss', best_loss)):
    print(f'{label}: {value:.3f}')

## Generate Images

In [None]:
# map_location: restore tensors onto the current `device` even if the
# checkpoint was written from a different one (e.g. saved on GPU and
# reloaded in a CPU-only session, where a plain torch.load fails).
checkpoint = torch.load('/kaggle/working/checkpoints/best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f'Loaded best model from epoch {checkpoint["epoch"]}')

# Generated samples: 16 latents from the prior, decoded.
generated = model.generate(16, device)

fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for i, ax in enumerate(axes.flat):
    ax.imshow(generated[i].cpu().squeeze(), cmap='gray')
    ax.axis('off')
plt.suptitle('Generated Chest X-Ray Images', fontsize=16)
plt.tight_layout()
plt.savefig('/kaggle/working/results/generated_samples.png', dpi=150)
plt.show()

# Reconstructions: top row = test inputs, bottom row = model outputs.
test_batch = next(iter(test_loader))[:8].to(device)
with torch.no_grad():
    recon, _, _ = model(test_batch)

fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i in range(8):
    axes[0, i].imshow(test_batch[i].cpu().squeeze(), cmap='gray')
    axes[0, i].axis('off')
    axes[1, i].imshow(recon[i].cpu().squeeze(), cmap='gray')
    axes[1, i].axis('off')
plt.tight_layout()
plt.savefig('/kaggle/working/results/reconstructions.png', dpi=150)
plt.show()

## Summary

In [None]:
# Plain-text run summary; all figures are read from CONFIG and the loss
# histories so the report cannot drift from the actual configuration.
summary = f"""
VAE Chest X-Ray Generation - Results Summary
{'='*60}

ARCHITECTURE:
- Model: Spatial VAE (GroupNorm + SiLU)
- Latent: {CONFIG['latent_spatial']} x {CONFIG['latent_channels']} channels
- Parameters: {total_params:,}

LOSS FUNCTION:
- L1 Loss (weight: {CONFIG['l1_weight']})
- Perceptual Loss (VGG, weight: {CONFIG['perceptual_weight']})
- SSIM Loss (weight: {CONFIG['ssim_weight']})
- KL Divergence (beta: {CONFIG['beta_start']}->{CONFIG['beta_end']})

TRAINING:
- Phase 1: {CONFIG['epochs_phase1']} epochs (reconstruction only)
- Phase 2: {CONFIG['epochs_phase2']} epochs (with KL)
- Learning Rate: {CONFIG['learning_rate']}

RESULTS:
- Final Train Loss: {train_losses[-1]:.3f}
- Final Val Loss: {val_losses[-1]:.3f}
- Best Val Loss: {best_loss:.3f}

KEY IMPROVEMENTS:
1. Spatial latent space (preserves structure)
2. Very low beta ({CONFIG['beta_end']} for sharp images)
3. Perceptual loss (VGG features)
4. L1 loss instead of MSE
5. Two-phase training strategy
6. GroupNorm + SiLU for stability
"""

with open('/kaggle/working/results/summary.txt', 'w', encoding='utf-8') as f:
    f.write(summary)

print(summary)

print('\n' + '='*60)
print('Training Complete!')
print('='*60)

## Evaluation - FID & Inception Score

In [None]:
class InceptionFeatures(nn.Module):
    """Pooled Inception-v3 features (input to the FID computation).

    Runs torchvision's pretrained Inception-v3 up to the global average
    pool (classifier head dropped). Single-channel inputs are replicated to
    3 channels and everything is resized to 299x299.

    Note: these are torchvision ImageNet weights, not the TF "FID" weights
    used by reference FID implementations, so scores are only comparable
    with other runs of this same pipeline.
    """

    def __init__(self, device):
        super().__init__()
        inception = models.inception_v3(pretrained=True)
        self.blocks = nn.Sequential(
            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
            nn.MaxPool2d(3, 2), inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
            nn.MaxPool2d(3, 2), inception.Mixed_5b, inception.Mixed_5c, inception.Mixed_5d,
            inception.Mixed_6a, inception.Mixed_6b, inception.Mixed_6c, inception.Mixed_6d,
            inception.Mixed_6e, inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.to(device).eval()
        self.device = device

    @torch.no_grad()
    def forward(self, x):
        """Map ``(B, C, H, W)`` images in [0, 1] (C = 1 or 3) to ``(B, features)``."""
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)
        x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        # Calling the sub-blocks directly skips Inception3's transform_input
        # step, which (for the pretrained weights) maps ImageNet-normalised
        # input to (raw - 0.5) / 0.5. For raw [0, 1] pixels that combined
        # mapping is 2x - 1, i.e. the [-1, 1] range the weights expect.
        x = x * 2 - 1
        return self.blocks(x).view(x.size(0), -1)

def get_activations(images, feat_model, batch_size=32):
    """Run ``feat_model`` over ``images`` in order, batch by batch.

    ``feat_model`` must expose a ``device`` attribute (each batch is moved
    there) and return one row per input image. Returns the stacked outputs
    as a NumPy array.
    """
    batches = DataLoader(TensorDataset(images), batch_size=batch_size, shuffle=False)
    chunks = [
        feat_model(image_batch.to(feat_model.device)).cpu().numpy()
        for (image_batch,) in tqdm(batches, desc='Extracting features')
    ]
    return np.concatenate(chunks, axis=0)

def calculate_fid(real_acts, fake_acts, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two activation sets.

    ``FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 (S_r S_f)^{1/2})``.

    Args:
        real_acts, fake_acts: ``(N, D)`` arrays, one feature row per image.
        eps: diagonal offset added to both covariances when the matrix
            square root is not finite, e.g. when the covariances are
            singular because there are fewer samples than features.

    Returns:
        The FID as a Python float.
    """
    real_acts = np.asarray(real_acts, dtype=np.float64)
    fake_acts = np.asarray(fake_acts, dtype=np.float64)
    mu_r, mu_f = real_acts.mean(axis=0), fake_acts.mean(axis=0)
    # atleast_2d: np.cov of a single feature column returns a 0-d array,
    # which would break the matrix product below.
    sigma_r = np.atleast_2d(np.cov(real_acts, rowvar=False))
    sigma_f = np.atleast_2d(np.cov(fake_acts, rowvar=False))

    diff = mu_r - mu_f
    # Default call (no `disp` flag) returns just the square root.
    covmean = linalg.sqrtm(sigma_r @ sigma_f)
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma_r.shape[0]) * eps
        covmean = linalg.sqrtm((sigma_r + offset) @ (sigma_f + offset))
    if np.iscomplexobj(covmean):
        # Tiny imaginary parts are numerical noise from the product of two
        # symmetric PSD matrices.
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_r) + np.trace(sigma_f) - 2 * np.trace(covmean)
    return float(fid)

def calculate_inception_score(images, device, batch_size=32, splits=10):
    """Inception Score ``exp(E_x KL(p(y|x) || p(y)))`` of ``images``.

    Args:
        images: ``(N, C, H, W)`` tensor, C = 1 or 3; assumed to hold
            intensities in [0, 1] (TODO confirm for new callers).
        device: device for the Inception network.
        batch_size: images per forward pass.
        splits: number of equal, disjoint subsets the score is averaged
            over (any remainder is dropped). Clamped to ``[1, N]`` so no
            subset is empty.

    Returns:
        ``(mean, std)`` of the per-split scores.
    """
    inception = models.inception_v3(pretrained=True).to(device).eval()
    # The pretrained model is built with transform_input=True, which expects
    # ImageNet-normalised input; normalise the raw [0, 1] pixels first.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    preds = []

    loader = DataLoader(TensorDataset(images), batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for batch in tqdm(loader, desc='Computing Inception Score'):
            x = batch[0].to(device)
            if x.size(1) == 1:
                x = x.repeat(1, 3, 1, 1)
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
            x = (x - mean) / std
            pred = F.softmax(inception(x), dim=1)
            preds.append(pred.cpu().numpy())

    preds = np.concatenate(preds, axis=0)
    # More splits than samples would yield empty splits and NaN scores.
    splits = max(1, min(splits, preds.shape[0]))
    split_size = preds.shape[0] // splits
    scores = []

    for i in range(splits):
        part = preds[i * split_size:(i + 1) * split_size]
        py = np.mean(part, axis=0)
        kl = part * (np.log(part + 1e-10) - np.log(py + 1e-10))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))

    return np.mean(scores), np.std(scores)

print('Evaluation functions loaded')

In [None]:
# Prepare evaluation data
NUM_EVAL = 1000
print(f'Preparing {NUM_EVAL} images for evaluation...')

# Collect real images first so the generated set can be matched to its size.
print('Collecting real test images...')
real_imgs_list = []
n_collected = 0
for batch in test_loader:
    real_imgs_list.append(batch)
    n_collected += batch.size(0)
    if n_collected >= NUM_EVAL:
        break
real_imgs = torch.cat(real_imgs_list, dim=0)[:NUM_EVAL]

# FID depends on sample size: if the test split holds fewer than NUM_EVAL
# images, evaluate both sets at the same (smaller) size.
num_fake = real_imgs.size(0)
if num_fake < NUM_EVAL:
    print(f'  Only {num_fake} real test images available; generating the same number.')

# Generate fake images in chunks to bound GPU memory.
print('\nGenerating synthetic images...')
model.eval()
gen_chunk = CONFIG['batch_size']
fake_imgs = torch.cat(
    [model.generate(min(gen_chunk, num_fake - start), device).cpu()
     for start in range(0, num_fake, gen_chunk)],
    dim=0,
)

print(f'\nReal images: {real_imgs.shape}, range [{real_imgs.min():.3f}, {real_imgs.max():.3f}]')
print(f'Fake images: {fake_imgs.shape}, range [{fake_imgs.min():.3f}, {fake_imgs.max():.3f}]')

In [None]:
# Calculate FID Score
print('\n' + '='*60)
print('Computing FID Score')
print('='*60)

try:
    inception_model = InceptionFeatures(device)
    print('Extracting features from real images...')
    real_acts = get_activations(real_imgs, inception_model)
    print('Extracting features from generated images...')
    fake_acts = get_activations(fake_imgs, inception_model)

    fid_score = calculate_fid(real_acts, fake_acts)
    print(f'\n✓ FID Score: {fid_score:.2f}')

    # Interpretation (heuristic bands, not dataset-calibrated)
    if fid_score < 50:
        print('  → Excellent quality!')
    elif fid_score < 100:
        print('  → Good quality')
    elif fid_score < 200:
        print('  → Moderate quality')
    else:
        print('  → Needs improvement')

except Exception as e:
    # Notebook-level boundary: report and continue with the other metrics.
    print(f'✗ FID calculation failed: {e}')
    fid_score = None

In [None]:
# Calculate Inception Score
print('\n' + '='*60)
print('Computing Inception Score')
print('='*60)

try:
    is_mean, is_std = calculate_inception_score(fake_imgs, device)
    print(f'\n✓ Inception Score: {is_mean:.3f} ± {is_std:.3f}')

    # Interpretation (heuristic bands, not dataset-calibrated)
    if is_mean > 5:
        print('  → Excellent diversity and quality!')
    elif is_mean > 3:
        print('  → Good diversity')
    elif is_mean > 2:
        print('  → Moderate diversity')
    else:
        print('  → Low diversity')

except Exception as e:
    # Notebook-level boundary: report and continue with the summary.
    print(f'✗ Inception Score calculation failed: {e}')
    is_mean, is_std = None, None

In [None]:
# Final Evaluation Summary
print('\n' + '='*70)
print(' '*20 + 'EVALUATION SUMMARY')
print('='*70)

print('\nMODEL ARCHITECTURE:')
print(f'  • Total Parameters: {total_params:,}')
print(f'  • Latent Space: {CONFIG["latent_spatial"]} × {CONFIG["latent_channels"]} channels')
print(f'  • Architecture: Spatial VAE with GroupNorm + SiLU + Attention')

print('\nTRAINING RESULTS:')
print(f'  • Best Validation Loss: {best_loss:.4f}')
print(f'  • Final Train Loss: {train_losses[-1]:.4f}')
print(f'  • Final Val Loss: {val_losses[-1]:.4f}')
print(f'  • Total Epochs: {len(train_losses)}')

print('\nGENERATIVE METRICS:')
# `is not None` rather than truthiness: a valid score of 0.0 is falsy and
# would otherwise be reported as N/A.
if fid_score is not None:
    print(f'  • FID Score: {fid_score:.2f}')
    fid_status = 'Excellent' if fid_score < 50 else 'Good' if fid_score < 100 else 'Moderate' if fid_score < 200 else 'Poor'
    print(f'    Status: {fid_status}')
else:
    print('  • FID Score: N/A')

if is_mean is not None:
    print(f'  • Inception Score: {is_mean:.3f} ± {is_std:.3f}')
    is_status = 'Excellent' if is_mean > 5 else 'Good' if is_mean > 3 else 'Moderate' if is_mean > 2 else 'Low'
    print(f'    Status: {is_status}')
else:
    print('  • Inception Score: N/A')

print('\nLOSS CONFIGURATION:')
print(f'  • L1 Weight: {CONFIG["l1_weight"]}')
print(f'  • Perceptual Weight: {CONFIG["perceptual_weight"]}')
print(f'  • SSIM Weight: {CONFIG["ssim_weight"]}')
print(f'  • Beta Range: {CONFIG["beta_start"]} → {CONFIG["beta_end"]}')

print('\n' + '='*70)
print('Evaluation Complete!')
print('='*70)

# Save results to file
latent_h, latent_w = CONFIG['latent_spatial']
results_dict = {
    'model_params': total_params,
    'latent_dim': CONFIG['latent_channels'] * latent_h * latent_w,
    'best_val_loss': float(best_loss),
    'final_train_loss': float(train_losses[-1]),
    'final_val_loss': float(val_losses[-1]),
    # A std of exactly 0.0 (e.g. a single split) is a real value, not missing.
    'fid_score': float(fid_score) if fid_score is not None else None,
    'inception_mean': float(is_mean) if is_mean is not None else None,
    'inception_std': float(is_std) if is_std is not None else None,
    'config': CONFIG
}

import json
with open('/kaggle/working/results/evaluation_metrics.json', 'w', encoding='utf-8') as f:
    json.dump(results_dict, f, indent=2)

print('\n✓ Saved metrics to: /kaggle/working/results/evaluation_metrics.json')

## 🔧 Reconstruction Quality Fixes Applied

**Key Changes:**
1. **Loss Normalization**: Changed from `sum` to `mean` reduction
2. **Perceptual Loss**: Added ImageNet normalization
3. **Learning Rate**: Increased from `1e-4` to `1e-3` for Phase 1
4. **Loss Weights**: 
   - L1: `1.0` → `10.0`
   - Perceptual: `0.1` → `1.0`
   - SSIM: `0.5` → `1.0`
5. **Training Loop**: Removed redundant division

**Expected Results:**
- Loss values should be in range [0.1 - 2.0] instead of [800-900]
- Sharper reconstructions with better contrast
- Faster convergence in Phase 1


In [None]:
# Diagnostic: Test reconstruction quality before training
print('='*60)
print('DIAGNOSTIC TEST - Reconstruction Quality')
print('='*60)

# A freshly initialised model/loss pair, independent of the trained `model`.
model_test = SpatialVAE(latent_channels=CONFIG['latent_channels']).to(device)
loss_fn_test = VAELoss(device, CONFIG['l1_weight'], CONFIG['perceptual_weight'], CONFIG['ssim_weight'])

# Get a sample batch
sample_batch = next(iter(train_loader))
sample_batch = sample_batch.to(device)

# Forward pass
with torch.no_grad():
    recon, mu, logvar = model_test(sample_batch)
    loss, recon_loss, kl_loss, l1, perc, ssim = loss_fn_test(recon, sample_batch, mu, logvar, beta=0.0)

print(f'\nUntrained Model Loss Check:')
print(f'  Total Loss: {loss.item():.4f}')
print(f'  Reconstruction Loss: {recon_loss.item():.4f}')
print(f'  L1 Loss: {l1.item():.4f}')
print(f'  Perceptual Loss: {perc.item():.4f}')
print(f'  SSIM Loss: {ssim.item():.4f}')

print(f'\nReconstruction Stats:')
print(f'  Input range: [{sample_batch.min():.3f}, {sample_batch.max():.3f}]')
print(f'  Output range: [{recon.min():.3f}, {recon.max():.3f}]')
print(f'  Mean absolute diff: {(sample_batch - recon).abs().mean():.4f}')

# Visualize
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i in range(4):
    axes[0, i].imshow(sample_batch[i].cpu().squeeze(), cmap='gray')
    axes[0, i].set_title('Original')
    axes[0, i].axis('off')
    axes[1, i].imshow(recon[i].cpu().squeeze(), cmap='gray')
    axes[1, i].set_title(f'Recon (untrained)')
    axes[1, i].axis('off')
plt.suptitle('Reconstruction Test - Untrained Model')
plt.tight_layout()
plt.show()

# Evaluate the normalisation check instead of asserting it unconditionally.
if loss.item() < 5.0:
    print('\n✓ Loss values are now normalized (should be < 5.0)')
    print('✓ Ready to start training with improved configuration!')
else:
    print(f'\n✗ Total loss {loss.item():.4f} is not < 5.0 - check loss reduction/weights')
print('='*60)

del model_test, loss_fn_test
