## VAE for Chest X-Ray Image Generation
ECE 285 - Deep Generative Models Assignment


In [None]:
# Install dependencies (run once).
# NOTE(review): `%pip` is IPython/Jupyter magic syntax, so this cell only runs
# inside a notebook kernel, not as a plain Python script. torch, torchvision and
# PIL are not installed here — presumably they come preinstalled on the Kaggle image.
%pip install -q scipy tqdm matplotlib


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms, models
from PIL import Image
import numpy as np
from scipy import linalg
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
import warnings
# Silence every warning category for the rest of the session. Note this also
# hides library deprecation notices; re-enable warnings when debugging.
warnings.filterwarnings('ignore')

# Choose the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'Using device: {device}')


## 1. Configuration


In [None]:
# Hyperparameters for data loading, the VAE, and the beta (KL weight) schedule.
CONFIG = dict(
    img_size=128,            # images are resized to img_size x img_size
    latent_dim=512,          # size of the latent vector z
    batch_size=32,
    epochs=100,
    learning_rate=1e-4,
    beta_start=0.0,          # KL weight at the start of annealing (reconstruction only)
    beta_end=1.0,            # KL weight once annealing has finished
    beta_warmup_epochs=70,   # epochs over which beta ramps linearly from start to end
    num_workers=2,           # DataLoader worker processes
)

def get_beta(epoch, config):
    """Return the KL weight (beta) for ``epoch`` under linear annealing.

    Beta moves linearly from ``config['beta_start']`` at epoch 0 to
    ``config['beta_end']`` at epoch ``config['beta_warmup_epochs']``, and stays
    at ``beta_end`` for every later epoch.
    """
    warmup = config['beta_warmup_epochs']
    start, end = config['beta_start'], config['beta_end']
    if epoch < warmup:
        return start + (end - start) * (epoch / warmup)
    return end

# Kaggle dataset root; the dataset class below reads its train/val/test folders.
DATA_DIR = '/kaggle/input/chest-xray-pneumonia/chest_xray'

# Create the output folders up front so later cells can write into them.
for _out_dir in ('/kaggle/working/checkpoints', '/kaggle/working/results'):
    os.makedirs(_out_dir, exist_ok=True)

print("Configuration:")
for _key, _value in CONFIG.items():
    print(f"  {_key}: {_value}")


## 2. Dataset


In [None]:
class ChestXRayDataset(Dataset):
    """Unlabelled grayscale chest X-ray images for one split of the dataset.

    Files are collected from ``<root_dir>/<split>/NORMAL`` and
    ``<root_dir>/<split>/PNEUMONIA``. The class label is not returned: each item
    is the image after resize, grayscale conversion and ``ToTensor``.

    Args:
        root_dir: Dataset root containing one sub-folder per split.
        img_size: Side length every image is resized to.
        split: Name of the split sub-folder to read (e.g. 'train').

    Raises:
        FileNotFoundError: If the split directory does not exist.
        ValueError: If no image files are found in the split.
    """

    # Supported extensions, in both cases for case-sensitive filesystems.
    _PATTERNS = ('*.jpeg', '*.jpg', '*.png', '*.JPEG', '*.JPG', '*.PNG')
    # Consecutive unreadable files __getitem__ skips before giving up.
    _MAX_LOAD_RETRIES = 10

    def __init__(self, root_dir, img_size=128, split='train'):
        self.img_size = img_size
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
        ])

        split_dir = os.path.join(root_dir, split)
        if not os.path.exists(split_dir):
            raise FileNotFoundError(f"Directory not found: {split_dir}")

        found = set()
        for category in ['NORMAL', 'PNEUMONIA']:
            category_path = os.path.join(split_dir, category)
            if os.path.exists(category_path):
                for pattern in self._PATTERNS:
                    found.update(glob.glob(os.path.join(category_path, pattern)))
        # De-duplicate: on a case-insensitive filesystem '*.jpeg' and '*.JPEG'
        # match the same files. Sort: glob order is arbitrary, and a fixed order
        # makes the unshuffled val/test loaders reproducible.
        self.image_paths = sorted(found)

        if len(self.image_paths) == 0:
            raise ValueError(f"No images found in {split_dir}")

        print(f"Found {len(self.image_paths)} images in {split} split")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the image tensor at ``idx``.

        An unreadable file is skipped in favour of the following index(es). Retries
        are bounded: unbounded recursion would raise RecursionError if many files
        were corrupt.

        Raises:
            IndexError: If ``idx`` is out of range. This keeps sequence iteration
                over the dataset terminating.
            RuntimeError: If ``_MAX_LOAD_RETRIES`` consecutive files fail to load.
        """
        n = len(self.image_paths)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        attempts = min(self._MAX_LOAD_RETRIES, n)
        for offset in range(attempts):
            img_path = self.image_paths[(idx + offset) % n]
            try:
                # `with` closes the file handle. convert() reads the pixel data
                # into a new image before the file is closed.
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
                return self.transform(image)
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
        raise RuntimeError(
            f"Could not load an image starting at index {idx} after {attempts} attempts"
        )


In [None]:
# One dataset per split; each constructor prints how many images it found.
train_dataset = ChestXRayDataset(DATA_DIR, CONFIG['img_size'], split='train')
val_dataset = ChestXRayDataset(DATA_DIR, CONFIG['img_size'], split='val')
test_dataset = ChestXRayDataset(DATA_DIR, CONFIG['img_size'], split='test')

# Options common to every loader. Only training shuffles and drops the ragged last batch.
_loader_opts = dict(
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")


In [None]:
# Visualize sample images: the first 8 images of one shuffled training batch in a 2x4 grid.
sample_batch = next(iter(train_loader))
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for ax, image in zip(axes.flat, sample_batch):
    ax.imshow(image.squeeze().numpy(), cmap='gray')
    ax.axis('off')
plt.suptitle('Sample Training Images', fontsize=14)
plt.tight_layout()
plt.savefig('/kaggle/working/results/sample_data.png', dpi=150)
plt.show()


## 3. VAE Model


In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of q(z | x).

    Five stride-2 convolutions (kernel 4, padding 1) each halve the spatial size:
    128 -> 64 -> 32 -> 16 -> 8 -> 4. The flattened size 512 * 4 * 4 used by the
    two linear heads therefore assumes 128x128 inputs.

    Args:
        latent_dim: Dimensionality of the latent code.
        img_channels: Number of input image channels.
    """

    def __init__(self, latent_dim=256, img_channels=1):
        super().__init__()
        channels = (img_channels, 32, 64, 128, 256, 512)
        # Register conv1..conv5 and then bn1..bn5 under the same names and in the
        # same order as explicit attributes would, so state_dict keys and
        # parameter order are unchanged.
        for idx in range(1, 6):
            setattr(self, f'conv{idx}',
                    nn.Conv2d(channels[idx - 1], channels[idx], 4, stride=2, padding=1))
        for idx in range(1, 6):
            setattr(self, f'bn{idx}', nn.BatchNorm2d(channels[idx]))

        flat_features = 512 * 4 * 4  # channels x height x width after conv5
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` for the image batch ``x``."""
        h = x
        for idx in range(1, 6):
            conv = getattr(self, f'conv{idx}')
            norm = getattr(self, f'bn{idx}')
            h = F.leaky_relu(norm(conv(h)), 0.2)
        h = h.view(h.size(0), -1)
        return self.fc_mu(h), self.fc_logvar(h)


class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent code to an image.

    A linear layer expands z into a 512 x 4 x 4 feature map. Five stride-2
    transposed convolutions then upsample it 4 -> 8 -> 16 -> 32 -> 64 -> 128.
    The final layer has no BatchNorm and is squashed into [0, 1] by a sigmoid.

    Args:
        latent_dim: Dimensionality of the latent code.
        img_channels: Number of output image channels.
    """

    def __init__(self, latent_dim=256, img_channels=1):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        channels = (512, 256, 128, 64, 32, img_channels)
        # Same attribute names and registration order as explicit attributes
        # (fc, deconv1..deconv5, bn1..bn4), so checkpoints stay compatible.
        for idx in range(1, 6):
            setattr(self, f'deconv{idx}',
                    nn.ConvTranspose2d(channels[idx - 1], channels[idx], 4, stride=2, padding=1))
        for idx in range(1, 5):
            setattr(self, f'bn{idx}', nn.BatchNorm2d(channels[idx]))

    def forward(self, z):
        """Decode the latent batch ``z`` into images with values in [0, 1]."""
        h = self.fc(z)
        h = h.view(h.size(0), 512, 4, 4)
        for idx in range(1, 5):
            upsample = getattr(self, f'deconv{idx}')
            norm = getattr(self, f'bn{idx}')
            h = F.relu(norm(upsample(h)))
        return torch.sigmoid(self.deconv5(h))


class VAE(nn.Module):
    """Variational autoencoder made of an ``Encoder`` and a ``Decoder``.

    Args:
        latent_dim: Dimensionality of the latent code z.
        img_channels: Number of image channels.
    """

    def __init__(self, latent_dim=256, img_channels=1):
        super().__init__()
        self.encoder = Encoder(latent_dim, img_channels)
        self.decoder = Decoder(latent_dim, img_channels)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, logvar):
        """Sample z = mu + std * eps with eps ~ N(0, I), keeping gradients w.r.t. mu/logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the image batch ``x``."""
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z)
        return recon, mu, logvar

    def generate(self, num_samples, device):
        """Decode ``num_samples`` latent vectors drawn from the N(0, I) prior.

        The decoder is switched to eval mode for sampling. In train mode its
        BatchNorm layers would normalise with the statistics of the random batch
        and overwrite their running mean/var. The decoder's previous mode is
        restored afterwards.
        """
        was_training = self.decoder.training
        self.decoder.eval()
        try:
            z = torch.randn(num_samples, self.latent_dim).to(device)
            with torch.no_grad():
                samples = self.decoder(z)
        finally:
            self.decoder.train(was_training)
        return samples


def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Negative beta-ELBO for a batch, summed (not averaged) over all elements.

    Args:
        recon_x: Decoder output; must be probabilities in [0, 1] for BCE.
        x: Target images in [0, 1], same shape as ``recon_x``.
        mu: Posterior means from the encoder.
        logvar: Posterior log-variances from the encoder.
        beta: Weight applied to the KL term.

    Returns:
        ``(bce + beta * kl, bce, kl)`` as scalar tensors.
    """
    # Reconstruction term: binary cross-entropy summed over every pixel.
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent dims.
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    total = bce + beta * kld
    return total, bce, kld


In [None]:
# Initialize model
model = VAE(latent_dim=CONFIG['latent_dim'], img_channels=1).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Smoke-test one forward pass in eval mode without autograd. In train mode the
# BatchNorm layers would fold the statistics of this random-noise input into
# their running mean/var before training even starts.
# The input size follows CONFIG rather than a hard-coded 128. The Encoder's linear
# heads are sized for 128x128, so a different img_size fails here, early.
_img_size = CONFIG['img_size']
test_input = torch.randn(2, 1, _img_size, _img_size).to(device)
model.eval()
with torch.no_grad():
    recon, mu, logvar = model(test_input)
model.train()
print(f"\nInput shape: {test_input.shape}")
print(f"Output shape: {recon.shape}")
print(f"Latent mu shape: {mu.shape}")


## 4. Training


In [None]:
# Training setup
optimizer = optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
# Halve the LR when the monitored loss has not improved for 5 consecutive steps.
# `verbose=True` is omitted: that argument is deprecated in recent PyTorch
# releases, and the global warnings filter hides the deprecation notice. Read
# optimizer.param_groups[0]['lr'] to see the current learning rate instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Loss tracking (one entry appended per epoch)
train_losses = []
recon_losses = []
kl_losses = []
val_losses = []


In [None]:
def train_epoch(model, train_loader, optimizer, device, beta):
    """Train ``model`` for one epoch and return per-sample average losses.

    Each batch's summed ``vae_loss`` is divided by that batch's size. The epoch
    value is the mean of these per-sample losses over the batches that trained
    successfully.

    Args:
        model: VAE whose forward returns ``(recon, mu, logvar)``.
        train_loader: Iterable of image batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.
        beta: KL weight passed to ``vae_loss``.

    Returns:
        ``(total_loss, recon_loss, kl_loss)`` averaged over successful batches,
        or three NaNs if no batch succeeded (e.g. an empty loader).
    """
    model.train()
    total_loss, total_recon, total_kl = 0.0, 0.0, 0.0
    n_ok = 0  # batches that completed an optimizer step

    pbar = tqdm(train_loader, desc='Training')
    for batch_idx, data in enumerate(pbar):
        try:
            data = data.to(device)
            optimizer.zero_grad()

            recon, mu, logvar = model(data)
            loss, recon_loss, kl_loss = vae_loss(recon, data, mu, logvar, beta)

            # Normalize by batch size
            batch_size = data.size(0)
            loss = loss / batch_size
            recon_loss = recon_loss / batch_size
            kl_loss = kl_loss / batch_size

            loss.backward()
            # Clip the global gradient norm to 1.0 before stepping.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            step_loss, step_recon, step_kl = loss.item(), recon_loss.item(), kl_loss.item()
        except Exception as e:
            # Best-effort: skip the failing batch, and keep it out of the averages.
            print(f"Error in batch {batch_idx}: {e}")
            continue

        total_loss += step_loss
        total_recon += step_recon
        total_kl += step_kl
        n_ok += 1
        pbar.set_postfix({'loss': f'{step_loss:.2f}', 'recon': f'{step_recon:.2f}', 'kl': f'{step_kl:.2f}'})

    if n_ok == 0:
        nan = float('nan')
        return nan, nan, nan
    return total_loss / n_ok, total_recon / n_ok, total_kl / n_ok


@torch.no_grad()
def validate(model, val_loader, device, beta):
    """Return the average per-sample validation loss (``vae_loss`` total).

    ``vae_loss`` sums over the batch, so the batch sums are accumulated and
    divided by the total number of samples. Averaging per-batch means would
    give a smaller final batch too much weight.

    Args:
        model: VAE whose forward returns ``(recon, mu, logvar)``.
        val_loader: Iterable of image batches.
        device: Device each batch is moved to.
        beta: KL weight passed to ``vae_loss``.

    Returns:
        Mean loss per sample, or NaN if the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0

    for data in val_loader:
        data = data.to(device)
        recon, mu, logvar = model(data)
        loss, _, _ = vae_loss(recon, data, mu, logvar, beta)
        total_loss += loss.item()
        n_samples += data.size(0)

    return total_loss / n_samples if n_samples else float('nan')


In [None]:
# Training loop with beta annealing
best_loss = float('inf')
beta_history = []  # Track beta values

for epoch in range(1, CONFIG['epochs'] + 1):
    # Epochs are numbered from 1 here, but get_beta() ramps from epoch 0, so pass
    # epoch - 1. The first epoch then trains with beta_start (no KL penalty, as
    # CONFIG intends), and beta_end is first used at epoch beta_warmup_epochs + 1.
    current_beta = get_beta(epoch - 1, CONFIG)
    beta_history.append(current_beta)
    # Once True, beta is fixed at beta_end, so epoch losses are directly comparable.
    warmup_done = epoch > CONFIG['beta_warmup_epochs']

    print(f"\n{'='*50}")
    print(f"Epoch {epoch}/{CONFIG['epochs']} | Beta: {current_beta:.4f}")
    print(f"{'='*50}")

    # Train with current beta
    train_loss, recon_loss, kl_loss = train_epoch(model, train_loader, optimizer, device, current_beta)
    train_losses.append(train_loss)
    recon_losses.append(recon_loss)
    kl_losses.append(kl_loss)

    # Validate with current beta
    val_loss = validate(model, val_loader, device, current_beta)
    val_losses.append(val_loss)

    print(f"Train Loss: {train_loss:.4f} | Recon: {recon_loss:.4f} | KL: {kl_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")

    # Learning rate scheduling, only once beta has stopped changing. While beta
    # ramps up, the total loss can rise simply because the KL weight grows.
    # ReduceLROnPlateau would read that as a plateau and keep halving the LR.
    if warmup_done:
        scheduler.step(val_loss)
    print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

    # Save best model (after warmup period for fair comparison)
    if warmup_done and val_loss < best_loss:
        best_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'beta_history': beta_history,
        }, '/kaggle/working/checkpoints/best_model.pt')
        print(f"Saved best model (val_loss: {val_loss:.4f})")

    # Save checkpoint every 10 epochs
    if epoch % 10 == 0:
        torch.save(model.state_dict(), f'/kaggle/working/checkpoints/checkpoint_epoch_{epoch}.pt')

        # Generate samples
        samples = model.generate(4, device)
        fig, axes = plt.subplots(1, 4, figsize=(10, 3))
        for i, ax in enumerate(axes):
            ax.imshow(samples[i].cpu().squeeze().numpy(), cmap='gray')
            ax.axis('off')
        plt.suptitle(f'Generated Samples - Epoch {epoch} (β={current_beta:.2f})')
        plt.savefig(f'/kaggle/working/results/samples_epoch_{epoch}.png', dpi=150)
        plt.show()

# Save final model
torch.save(model.state_dict(), '/kaggle/working/checkpoints/final_model.pt')
print("\nTraining complete!")


## 5. Loss Curves


In [None]:
# Plot training curves with beta annealing.
# Entry i of each per-epoch history list belongs to epoch i + 1. Curves are
# therefore plotted against explicit 1-based epoch numbers; plotting by list index
# shifts every curve one epoch left of the "Warmup End" line at x = beta_warmup_epochs.
def _epoch_axis(history):
    """Return the 1-based epoch numbers matching a per-epoch history list."""
    return range(1, len(history) + 1)

warmup_end = CONFIG['beta_warmup_epochs']
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Total loss
axes[0, 0].plot(_epoch_axis(train_losses), train_losses, label='Train', color='blue', linewidth=2)
axes[0, 0].plot(_epoch_axis(val_losses), val_losses, label='Validation', color='orange', linewidth=2)
axes[0, 0].axvline(x=warmup_end, color='gray', linestyle='--', alpha=0.7, label='Warmup End')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Total Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Reconstruction loss
axes[0, 1].plot(_epoch_axis(recon_losses), recon_losses, color='green', linewidth=2)
axes[0, 1].axvline(x=warmup_end, color='gray', linestyle='--', alpha=0.7)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].set_title('Reconstruction Loss')
axes[0, 1].grid(True, alpha=0.3)

# KL loss
axes[1, 0].plot(_epoch_axis(kl_losses), kl_losses, color='red', linewidth=2)
axes[1, 0].axvline(x=warmup_end, color='gray', linestyle='--', alpha=0.7)
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].set_title('KL Divergence Loss')
axes[1, 0].grid(True, alpha=0.3)

# Beta annealing schedule
axes[1, 1].plot(_epoch_axis(beta_history), beta_history, color='purple', linewidth=2)
axes[1, 1].axvline(x=warmup_end, color='gray', linestyle='--', alpha=0.7, label='Warmup End')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Beta (β)')
axes[1, 1].set_title('Beta Annealing Schedule')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('/kaggle/working/results/loss_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nFinal Train Loss: {train_losses[-1]:.4f}")
print(f"Final Val Loss: {val_losses[-1]:.4f}")
print(f"Best Val Loss: {best_loss:.4f}")
print(f"Final Beta: {beta_history[-1]:.4f}")


## 6. Generate Images


In [None]:
# Load the best checkpoint. It is only written once beta warmup has finished, so
# it is missing if training stopped before then; fall back to the final weights
# instead of crashing. map_location lets a GPU-saved checkpoint load on a CPU runtime.
best_path = '/kaggle/working/checkpoints/best_model.pt'
final_path = '/kaggle/working/checkpoints/final_model.pt'
if os.path.exists(best_path):
    checkpoint = torch.load(best_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model from epoch {checkpoint['epoch']}")
else:
    model.load_state_dict(torch.load(final_path, map_location=device))
    print(f"No best checkpoint found; loaded final model from {final_path}")
model.eval()


In [None]:
# Generate samples: decode 16 draws from the prior and lay them out as a 4x4 grid.
num_samples = 16
generated_images = model.generate(num_samples, device)

fig, axes = plt.subplots(4, 4, figsize=(10, 10))
for ax, image in zip(axes.flat, generated_images):
    ax.imshow(image.cpu().squeeze().numpy(), cmap='gray')
    ax.axis('off')
plt.suptitle('Generated Chest X-Ray Images', fontsize=16)
plt.tight_layout()
plt.savefig('/kaggle/working/results/generated_samples.png', dpi=150, bbox_inches='tight')
plt.show()


In [None]:
# Reconstructions: top row shows 8 test images, bottom row the model's reconstructions.
model.eval()
sample_batch = next(iter(test_loader))[:8].to(device)

with torch.no_grad():
    recon, _, _ = model(sample_batch)

fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for row, (title, images) in enumerate((('Original', sample_batch), ('Reconstructed', recon))):
    for col in range(8):
        cell = axes[row, col]
        cell.imshow(images[col].cpu().squeeze().numpy(), cmap='gray')
        cell.axis('off')
    # Only the first column of each row carries the row label.
    axes[row, 0].set_title(title, fontsize=10)

plt.tight_layout()
plt.savefig('/kaggle/working/results/reconstructions.png', dpi=150, bbox_inches='tight')
plt.show()


In [None]:
# Latent space interpolation: decode evenly spaced points on the straight line
# between two latent vectors drawn from the N(0, I) prior.
model.eval()
num_steps = 10

z1 = torch.randn(1, CONFIG['latent_dim']).to(device)
z2 = torch.randn(1, CONFIG['latent_dim']).to(device)

with torch.no_grad():
    interpolations = [
        model.decoder((1 - alpha) * z1 + alpha * z2)
        for alpha in np.linspace(0, 1, num_steps)
    ]

fig, axes = plt.subplots(1, num_steps, figsize=(num_steps * 1.5, 2))
for ax, frame in zip(axes, interpolations):
    ax.imshow(frame.cpu().squeeze().numpy(), cmap='gray')
    ax.axis('off')

plt.suptitle('Latent Space Interpolation', fontsize=12)
plt.tight_layout()
plt.savefig('/kaggle/working/results/interpolation.png', dpi=150, bbox_inches='tight')
plt.show()


## 7. Evaluation (FID & Inception Score)


In [None]:
class InceptionV3Features(nn.Module):
    """Globally average-pooled Inception-v3 features, used to compute FID.

    Wraps the convolutional trunk of torchvision's ImageNet-pretrained
    Inception-v3: every stage up to ``Mixed_7c``, followed by global average
    pooling. The trunk is run directly, bypassing ``Inception3.forward``, so the
    ``transform_input`` preprocessing that ``forward`` applies to pretrained
    models is skipped. This module therefore rescales ``[0, 1]`` pixels to
    ``[-1, 1]`` itself, which is the range the trunk sees after that step.

    NOTE(review): these are torchvision weights, not the TF-Inception weights of
    the reference FID implementation, so scores are only comparable with other
    runs of this notebook.

    Args:
        device: Device the network is moved to; ``get_activations`` sends
            inputs to ``self.device``.
    """

    def __init__(self, device):
        super().__init__()
        inception = models.inception_v3(pretrained=True)
        self.blocks = nn.Sequential(
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Conv2d_3b_1x1,
            inception.Conv2d_4a_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Mixed_5b,
            inception.Mixed_5c,
            inception.Mixed_5d,
            inception.Mixed_6a,
            inception.Mixed_6b,
            inception.Mixed_6c,
            inception.Mixed_6d,
            inception.Mixed_6e,
            inception.Mixed_7a,
            inception.Mixed_7b,
            inception.Mixed_7c,
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.to(device)
        self.eval()
        self.device = device

    @torch.no_grad()
    def forward(self, x):
        """Return one flattened feature row per image in ``x``.

        ``x`` is (N, C, H, W) with values in [0, 1]. Single-channel images are
        repeated to 3 channels, and all images are resized to 299x299.
        """
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)
        x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        # [0, 1] -> [-1, 1]: the scaling that torchvision's transform_input would apply.
        x = 2.0 * x - 1.0
        x = self.blocks(x)
        return x.view(x.size(0), -1)


In [None]:
def get_activations(images, model, batch_size=32):
    """Run ``model`` over ``images`` in mini-batches and stack the outputs.

    Args:
        images: Tensor of images batched along dim 0.
        model: Feature extractor with a ``device`` attribute (e.g.
            ``InceptionV3Features``); each batch is moved to that device.
        batch_size: Number of images per forward pass.

    Returns:
        NumPy array of the model outputs concatenated along axis 0, in input order.
    """
    model.eval()
    loader = DataLoader(TensorDataset(images), batch_size=batch_size, shuffle=False)
    chunks = [
        model(batch.to(model.device)).cpu().numpy()
        for (batch,) in tqdm(loader, desc='Computing activations')
    ]
    return np.concatenate(chunks, axis=0)


def calculate_fid(real_activations, fake_activations, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two sets of activations.

    FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 (S_r S_f)^(1/2)).

    With fewer samples than features, the covariance matrices are singular and
    ``sqrtm`` of their product can come back non-finite. In that case the square
    root is recomputed with ``eps`` added to both diagonals, which is the fallback
    used by the pytorch-fid reference implementation.

    Args:
        real_activations: Array of shape (n_samples, n_features).
        fake_activations: Array of shape (n_samples, n_features).
        eps: Diagonal offset, used only when the plain square root is not finite.

    Returns:
        The FID as a Python float, or ``inf`` if it cannot be computed.
    """
    mu_real = np.mean(real_activations, axis=0)
    mu_fake = np.mean(fake_activations, axis=0)
    # atleast_2d keeps the matrix algebra valid for single-feature inputs, where
    # np.cov returns a 0-d array.
    sigma_real = np.atleast_2d(np.cov(real_activations, rowvar=False))
    sigma_fake = np.atleast_2d(np.cov(fake_activations, rowvar=False))

    diff = mu_real - mu_fake

    try:
        covmean, _ = linalg.sqrtm(sigma_real @ sigma_fake, disp=False)
        if not np.isfinite(covmean).all():
            offset = np.eye(sigma_real.shape[0]) * eps
            covmean, _ = linalg.sqrtm((sigma_real + offset) @ (sigma_fake + offset), disp=False)
        # Round-off can leave a tiny imaginary component; only the real part is meaningful.
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        fid = float(diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean))
    except Exception as e:
        print(f"FID calculation error: {e}")
        fid = float('inf')

    return fid


def calculate_inception_score(images, device, batch_size=32, splits=10):
    """Inception Score of ``images`` using torchvision's ImageNet Inception-v3.

    IS = exp(E_x[KL(p(y|x) || p(y))]). It is computed on ``splits`` equal chunks
    (any remainder is dropped) and reported as mean and std over the chunks.

    NOTE(review): IS is defined over ImageNet classes, so on grayscale chest
    X-rays it is at best a rough diversity signal.

    Args:
        images: Tensor (N, C, H, W) with values in [0, 1]; 1-channel images are
            repeated to 3 channels.
        device: Device used for inference.
        batch_size: Images per forward pass.
        splits: Number of chunks the score is averaged over.

    Returns:
        ``(mean, std)`` of the per-split scores.

    Raises:
        ValueError: If there are fewer images than ``splits``. Otherwise every
            split would be empty and the score NaN.
    """
    if len(images) < splits:
        raise ValueError(f"Need at least {splits} images for {splits} splits, got {len(images)}")

    inception = models.inception_v3(pretrained=True).to(device)
    inception.eval()

    # The pretrained torchvision model applies `transform_input`, which expects
    # ImageNet-normalised input; raw [0, 1] pixels would skew the logits.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    preds = []
    loader = DataLoader(TensorDataset(images), batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        for (batch,) in tqdm(loader, desc='Computing IS'):
            batch = batch.to(device)
            if batch.size(1) == 1:
                batch = batch.repeat(1, 3, 1, 1)
            batch = F.interpolate(batch, size=(299, 299), mode='bilinear', align_corners=False)
            batch = (batch - mean) / std
            preds.append(F.softmax(inception(batch), dim=1).cpu().numpy())

    preds = np.concatenate(preds, axis=0)

    scores = []
    split_size = preds.shape[0] // splits

    for i in range(splits):
        part = preds[i * split_size:(i + 1) * split_size]
        py = np.mean(part, axis=0)
        # Per-image KL(p(y|x) || p(y)); the 1e-10 guards log(0).
        kl_divs = part * (np.log(part + 1e-10) - np.log(py + 1e-10))
        kl_div = np.mean(np.sum(kl_divs, axis=1))
        scores.append(np.exp(kl_div))

    return np.mean(scores), np.std(scores)


In [None]:
# Evaluate VAE
print("Evaluating VAE...")
NUM_EVAL_SAMPLES = 1000
GEN_BATCH_SIZE = CONFIG['batch_size']  # images decoded per generate() call

# Generate fake images in chunks, so peak decoder memory is bounded by
# GEN_BATCH_SIZE rather than by NUM_EVAL_SAMPLES.
print(f"Generating {NUM_EVAL_SAMPLES} images...")
model.eval()
fake_images = torch.cat([
    model.generate(min(GEN_BATCH_SIZE, NUM_EVAL_SAMPLES - start), device).cpu()
    for start in range(0, NUM_EVAL_SAMPLES, GEN_BATCH_SIZE)
], dim=0)

# Collect real images, counting the samples actually gathered (the last batch
# can be smaller than batch_size).
print("Collecting real images...")
real_images = []
n_real = 0
for batch in test_loader:
    real_images.append(batch)
    n_real += batch.size(0)
    if n_real >= NUM_EVAL_SAMPLES:
        break
real_images = torch.cat(real_images, dim=0)[:NUM_EVAL_SAMPLES]
if real_images.size(0) < NUM_EVAL_SAMPLES:
    print(f"Note: test split only has {real_images.size(0)} images; "
          f"metrics compare {real_images.size(0)} real vs {NUM_EVAL_SAMPLES} generated images.")

print(f"Real images shape: {real_images.shape}")
print(f"Fake images shape: {fake_images.shape}")


In [None]:
# Calculate FID between real test images and generated images (None if it fails).
print("\nCalculating FID Score...")
fid_score = None
try:
    inception_model = InceptionV3Features(device)
    real_acts = get_activations(real_images, inception_model)
    fake_acts = get_activations(fake_images, inception_model)
    fid_score = calculate_fid(real_acts, fake_acts)
except Exception as e:
    print(f"FID calculation failed: {e}")
    fid_score = None
else:
    print(f"FID Score: {fid_score:.4f}")


In [None]:
# Calculate Inception Score of the generated images (both values None if it fails).
print("\nCalculating Inception Score...")
is_mean, is_std = None, None
try:
    is_mean, is_std = calculate_inception_score(fake_images, device)
except Exception as e:
    print(f"IS calculation failed: {e}")
    is_mean, is_std = None, None
else:
    print(f"Inception Score: {is_mean:.4f} ± {is_std:.4f}")


In [None]:
# Summary. Missing metrics are tested with `is not None`, not truthiness: a
# legitimate score of 0.0 is falsy and would otherwise be shown as N/A.
print("\n" + "="*50)
print("EVALUATION SUMMARY")
print("="*50)
print(f"FID Score: {fid_score:.4f}" if fid_score is not None else "FID Score: N/A")
print(f"Inception Score: {is_mean:.4f} ± {is_std:.4f}" if is_mean is not None else "Inception Score: N/A")
print("="*50)


## 8. Save Results Summary


In [None]:
# Save hyperparameters and results.
# Optional metrics are pre-formatted because a conditional cannot go inside a
# format spec. In `{x:.4f if x else 'N/A'}`, everything after ':' is passed to
# format() verbatim and raises at runtime (ValueError for a float, TypeError for None).
fid_text = f"{fid_score:.4f}" if fid_score is not None else 'N/A'
is_text = f"{is_mean:.4f} ± {is_std:.4f}" if is_mean is not None else 'N/A'

results_summary = f"""
VAE Chest X-Ray Generation - Results Summary
=============================================

HYPERPARAMETERS:
- Image Size: {CONFIG['img_size']}x{CONFIG['img_size']}
- Latent Dimension: {CONFIG['latent_dim']}
- Batch Size: {CONFIG['batch_size']}
- Epochs: {CONFIG['epochs']}
- Learning Rate: {CONFIG['learning_rate']}

BETA ANNEALING:
- Beta Start: {CONFIG['beta_start']}
- Beta End: {CONFIG['beta_end']}
- Warmup Epochs: {CONFIG['beta_warmup_epochs']}
- Strategy: Linear annealing from {CONFIG['beta_start']} to {CONFIG['beta_end']} over {CONFIG['beta_warmup_epochs']} epochs

TRAINING RESULTS:
- Final Train Loss: {train_losses[-1]:.4f}
- Final Val Loss: {val_losses[-1]:.4f}
- Best Val Loss: {best_loss:.4f}
- Final Reconstruction Loss: {recon_losses[-1]:.4f}
- Final KL Loss: {kl_losses[-1]:.4f}

EVALUATION METRICS:
- FID Score: {fid_text}
- Inception Score: {is_text}

MODEL ARCHITECTURE:
- Encoder: 5 Conv layers (32->64->128->256->512) with BatchNorm
- Decoder: 5 TransposeConv layers (512->256->128->64->32->1)
- Total Parameters: {total_params:,}
"""

# Explicit UTF-8: the summary contains non-ASCII characters ("±").
with open('/kaggle/working/results/summary.txt', 'w', encoding='utf-8') as f:
    f.write(results_summary)

print(results_summary)


In [None]:
# List everything written to the checkpoint and results folders.
print("\nSaved files:")
for heading, folder in (("Checkpoints", '/kaggle/working/checkpoints'),
                        ("Results", '/kaggle/working/results')):
    print(f"\n{heading}:")
    for name in os.listdir(folder):
        print(f"  - {name}")
