# Variational Autoencoder (VAE) Training

This notebook trains a Variational Autoencoder (VAE) on a single MVTec category and evaluates it for anomaly detection.

**Key differences from the convolutional autoencoder (CAE):**
- Learns a **probabilistic latent space**
- Uses **KL divergence** as regularization
- Can **generate** new samples

## 1. Setup & Imports

In [None]:
import sys
sys.path.insert(0, 'F:/Thesis')

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

from src.config import DEVICE, MODELS_DIR, FIGURES_DIR, MVTEC_CATEGORIES, ensure_dirs
from src.data import MVTecDataset
from src.data.transforms import denormalize
from src.models import create_vae
from src.training import AutoencoderTrainer, EarlyStopping, get_optimizer, get_scheduler

# Run the project's directory-setup helper before any figure or checkpoint is
# written (presumably creates MODELS_DIR / FIGURES_DIR — confirm in src.config).
ensure_dirs()
# Show the DEVICE value imported from src.config.
print(f"Device: {DEVICE}")

## 2. Configuration

In [None]:
# Hyper-parameters and run settings for this experiment, gathered in a single
# dict so they can be printed here and stored with the saved model later on.
CONFIG = dict(
    category='bottle',
    batch_size=16,
    num_epochs=100,
    learning_rate=1e-3,
    weight_decay=1e-5,
    channels=[32, 64, 128, 256],
    latent_dim=128,
    beta=1.0,  # KL weight (beta-VAE)
    early_stopping_patience=15,
    save_every=20,
)

# Header line followed by one indented "key: value" line per setting.
print(
    "Training Configuration:",
    *(f"  {name}: {value}" for name, value in CONFIG.items()),
    sep="\n",
)

## 3. Load Dataset

In [None]:
# Build the train and test splits for the configured category; only the test
# split is asked to return masks as well.
category = CONFIG['category']
train_dataset = MVTecDataset(category=category, split='train')
test_dataset = MVTecDataset(category=category, split='test', return_mask=True)

# Both loaders share batch size and worker count; only the training loader
# shuffles its samples.
DataLoader = torch.utils.data.DataLoader
shared_loader_args = {'batch_size': CONFIG['batch_size'], 'num_workers': 0}
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **shared_loader_args)

n_train, n_test = len(train_dataset), len(test_dataset)
print(f"Train samples: {n_train}")
print(f"Test samples: {n_test}")

## 4. Create Model

In [None]:
# Instantiate the VAE from the architecture settings in CONFIG.
vae_settings = {
    'channels': CONFIG['channels'],
    'latent_dim': CONFIG['latent_dim'],
    'beta': CONFIG['beta'],
}
model = create_vae(**vae_settings)

# Total number of elements across all model parameters.
n_parameters = 0
for param in model.parameters():
    n_parameters += param.numel()

print("VAE Model")
print(f"Latent dimension: {CONFIG['latent_dim']}")
print(f"Beta (KL weight): {CONFIG['beta']}")
print(f"Parameters: {n_parameters:,}")

## 5. Custom VAE Training Loop

In [None]:
# Move the model to the target device BEFORE building the optimizer, so the
# optimizer is constructed over parameters that already live on DEVICE. This
# is the ordering recommended by the torch.optim documentation.
model = model.to(DEVICE)

# Optimizer, LR scheduler and early-stopping helper come from src.training.
optimizer = get_optimizer(model, lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
# 'plateau' scheduler: the training loop steps it with the epoch's average loss.
# (presumably scales the LR by `factor` after `patience` epochs without
# improvement — confirm in src.training.get_scheduler)
scheduler = get_scheduler(optimizer, 'plateau', patience=5, factor=0.5)
early_stopping = EarlyStopping(patience=CONFIG['early_stopping_patience'])

# Training history: one entry per completed epoch in each list.
history = {'epochs': [], 'train_loss': [], 'recon_loss': [], 'kl_loss': []}

# Prefix used for every figure and checkpoint file written by this notebook.
experiment_name = f'vae_{CONFIG["category"]}'

In [None]:
from tqdm import tqdm

# Fail fast on an empty loader: the per-epoch averages below divide by the
# number of batches and would otherwise raise ZeroDivisionError.
n_batches = len(train_loader)
if n_batches == 0:
    raise ValueError("train_loader is empty; cannot train the VAE.")

# CONFIG declares 'save_every'; honour it by writing a periodic checkpoint.
# A missing, None or 0 value disables periodic checkpoints.
save_every = CONFIG.get('save_every') or 0

for epoch in tqdm(range(1, CONFIG['num_epochs'] + 1), desc='Training'):
    model.train()
    # Running sums of the per-batch losses for this epoch.
    epoch_loss = 0.0
    epoch_recon = 0.0
    epoch_kl = 0.0

    for batch in train_loader:
        # The first element of each batch is the image tensor.
        images = batch[0].to(DEVICE)

        optimizer.zero_grad()

        # Forward pass
        recon, mu, logvar = model(images)

        # Compute losses; the model returns a dict with the total loss and
        # its reconstruction / KL components.
        losses = model.loss_function(images, recon, mu, logvar)
        loss = losses['loss']

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_recon += losses['recon_loss'].item()
        epoch_kl += losses['kl_loss'].item()

    # Average losses (mean over batches)
    avg_loss = epoch_loss / n_batches
    avg_recon = epoch_recon / n_batches
    avg_kl = epoch_kl / n_batches

    # Record history
    history['epochs'].append(epoch)
    history['train_loss'].append(avg_loss)
    history['recon_loss'].append(avg_recon)
    history['kl_loss'].append(avg_kl)

    # Scheduler step (driven by the training loss; no validation split is used here)
    scheduler.step(avg_loss)

    # Periodic checkpoint, so a long run can be inspected or resumed.
    if save_every and epoch % save_every == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'config': CONFIG,
            'history': history,
        }, MODELS_DIR / f'{experiment_name}_epoch_{epoch:03d}.pth')

    # Early stopping (also monitors the training loss)
    if early_stopping(avg_loss):
        print(f'\nEarly stopping at epoch {epoch}')
        break

    if epoch % 10 == 0:
        print(f'Epoch {epoch}: Loss={avg_loss:.4f}, Recon={avg_recon:.4f}, KL={avg_kl:.4f}')

print("\nTraining completed!")

## 6. Training Visualization

In [None]:
# One panel per tracked loss: (history key, line format, text used for both
# the y-axis label and the panel title).
loss_panels = (
    ('train_loss', 'b-', 'Total Loss'),
    ('recon_loss', 'g-', 'Reconstruction Loss'),
    ('kl_loss', 'r-', 'KL Divergence'),
)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for panel_ax, (history_key, line_format, panel_text) in zip(axes, loss_panels):
    panel_ax.plot(history['epochs'], history[history_key], line_format, linewidth=2)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(panel_text)
    panel_ax.set_title(panel_text)
    panel_ax.grid(True, alpha=0.3)

# Overall title, then save the figure under the experiment prefix and show it.
plt.suptitle(f'VAE Training Losses - {CONFIG["category"].title()}', fontsize=14)
plt.tight_layout()
loss_figure_path = FIGURES_DIR / f'{experiment_name}_loss_curves.png'
plt.savefig(loss_figure_path, dpi=150)
plt.show()

## 7. Reconstruction & Anomaly Maps

In [None]:
model.eval()

n_samples = 5
fig, axes = plt.subplots(4, n_samples, figsize=(15, 12))

# One column per sample, taking every 10th test item (indices 0, 10, ..., 40).
# Rows: input, reconstruction, error map, ground-truth mask.
for col in range(n_samples):
    img, mask, label = test_dataset[col * 10]
    img_input = img.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        recon, mu, logvar = model(img_input)
        error_map = model.get_anomaly_map(img_input)

    # (array to display, panel title, colormap) for each row of this column.
    # NOTE(review): the input is denormalized for display while the
    # reconstruction is only clipped to [0, 1] — confirm the model output is
    # already in display range.
    column_panels = (
        (denormalize(img).permute(1, 2, 0).numpy().clip(0, 1),
         "Defect" if label else "Normal", None),
        (recon[0].cpu().permute(1, 2, 0).numpy().clip(0, 1),
         'Reconstruction', None),
        (error_map[0, 0].cpu().numpy(), 'Error Map', 'hot'),
        (mask[0].numpy(), 'Ground Truth', 'gray'),
    )

    for panel_ax, (picture, panel_title, colormap) in zip(axes[:, col], column_panels):
        panel_ax.imshow(picture, cmap=colormap)
        panel_ax.set_title(panel_title)
        panel_ax.axis('off')

plt.suptitle(f'VAE Reconstruction Results - {CONFIG["category"].title()}', fontsize=14)
plt.tight_layout()
reconstruction_figure_path = FIGURES_DIR / f'{experiment_name}_reconstructions.png'
plt.savefig(reconstruction_figure_path, dpi=150)
plt.show()

## 8. Latent Space Visualization

In [None]:
# Encode all test samples
# Encode all test samples
latent_vectors = []
labels = []

model.eval()
with torch.no_grad():
    for img, mask, label in test_loader:
        z = model.encode(img.to(DEVICE))
        # encode()'s return type is not visible in this notebook. Accept either
        # a single tensor or a sequence such as (mu, logvar), in which case the
        # first element is used as the latent code.
        if isinstance(z, (tuple, list)):
            z = z[0]
        # PCA needs a 2-D (n_samples, n_features) array, so collapse any extra
        # dimensions per sample; this leaves an already 2-D tensor unchanged.
        latent_vectors.append(z.reshape(z.shape[0], -1).cpu())
        labels.extend(label.numpy())

latent_vectors = torch.cat(latent_vectors, dim=0).numpy()
labels = np.array(labels)

# PCA for visualization: project the latent codes onto their first two
# principal components.
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
latent_2d = pca.fit_transform(latent_vectors)

# Scatter plot coloured by the image-level label.
plt.figure(figsize=(10, 8))
scatter = plt.scatter(latent_2d[:, 0], latent_2d[:, 1], c=labels, cmap='coolwarm', alpha=0.7)
plt.colorbar(scatter, label='Anomaly (1) / Normal (0)')
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.title(f'VAE Latent Space (PCA) - {CONFIG["category"].title()}')
plt.grid(True, alpha=0.3)
plt.savefig(FIGURES_DIR / f'{experiment_name}_latent_space.png', dpi=150)
plt.show()

## 9. Evaluation

In [None]:
from sklearn.metrics import roc_auc_score, roc_curve

# Compute anomaly scores: collect per-batch scores and labels over the test
# set, then flatten them into one array each.
score_batches = []
label_batches = []

model.eval()
with torch.no_grad():
    for img, mask, label in test_loader:
        batch_scores = model.get_anomaly_score(img.to(DEVICE))
        score_batches.append(batch_scores.cpu().numpy())
        label_batches.append(label.numpy())

all_scores = np.array([score for batch in score_batches for score in batch])
all_labels = np.array([target for batch in label_batches for target in batch])

# Image-level ROC: area under the curve plus the curve points for plotting.
auc = roc_auc_score(all_labels, all_scores)
fpr, tpr, _ = roc_curve(all_labels, all_scores)

print(f"Image-level ROC-AUC: {auc:.4f}")

# ROC curve with the chance diagonal drawn as a dashed reference line.
curve_label = f'VAE (AUC = {auc:.4f})'
roc_figure_path = FIGURES_DIR / f'{experiment_name}_roc_curve.png'

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, 'b-', linewidth=2, label=curve_label)
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(f'ROC Curve - VAE - {CONFIG["category"].title()}')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig(roc_figure_path, dpi=150)
plt.show()

## 10. Save Model

In [None]:
# Final checkpoint: weights plus the configuration, training history and the
# image-level ROC-AUC obtained in the evaluation step.
save_path = MODELS_DIR / f'{experiment_name}_final.pth'

torch.save({
    'model_state_dict': model.state_dict(),
    'config': CONFIG,
    'history': history,
    # Store a plain Python float: roc_auc_score may return a NumPy scalar,
    # which torch.load(..., weights_only=True) refuses to unpickle.
    'auc': float(auc),
}, save_path)

print(f"Model saved to: {save_path}")
print("\n=== Training Summary ===")
print(f"Category: {CONFIG['category']}")
# Number of epochs actually run (can be fewer than num_epochs if early stopping fired).
print(f"Epochs: {len(history['epochs'])}")
print(f"Final loss: {history['train_loss'][-1]:.6f}")
print(f"ROC-AUC: {auc:.4f}")