# Convolutional Autoencoder (CAE) Training

This notebook trains a Convolutional Autoencoder for unsupervised anomaly detection on the MVTec AD dataset.

**Key Concepts:**
- Train on **normal images only**
- Detect anomalies via **reconstruction error**
- Higher error = likely defect

## 1. Setup & Imports

In [None]:
import os
import sys

# Make the project package (`src`) importable. The location defaults to the
# original checkout path but can be overridden per machine through the
# THESIS_ROOT environment variable, so the notebook is not tied to one PC.
PROJECT_ROOT = os.environ.get('THESIS_ROOT', 'F:/Thesis')
# Guard against stacking duplicate entries when this cell is re-run.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

from src.config import (
    DEVICE, BATCH_SIZE, NUM_EPOCHS, LEARNING_RATE,
    MODELS_DIR, FIGURES_DIR, MVTEC_CATEGORIES, ensure_dirs
)
from src.data import MVTecDataset, get_transforms
from src.data.transforms import denormalize
from src.models import create_cae
from src.training import (
    AutoencoderTrainer, EarlyStopping,
    get_optimizer, get_scheduler, get_loss_function
)

# Presumably creates the output directories (e.g. MODELS_DIR / FIGURES_DIR)
# that later cells write into — confirm in src.config.
ensure_dirs()
print(f"Device: {DEVICE}")
print(f"Available categories: {MVTEC_CATEGORIES}")

## 2. Configuration

In [None]:
# Run settings for this notebook, gathered in one place so later cells read
# them from CONFIG instead of repeating literals.
CONFIG = dict(
    category='bottle',                # MVTec category to train on
    batch_size=16,                    # mini-batch size for both data loaders
    num_epochs=100,                   # epoch count handed to trainer.fit
    learning_rate=1e-3,               # initial learning rate for the optimizer
    weight_decay=1e-5,                # L2 regularization
    channels=[32, 64, 128, 256],      # channel list handed to create_cae
    early_stopping_patience=15,       # patience handed to EarlyStopping
    save_every=20,                    # checkpoint interval, in epochs
)

# Echo every setting so the notebook output records what was used.
print("Training Configuration:")
print("\n".join(f"  {name}: {value}" for name, value in CONFIG.items()))

## 3. Load Dataset

In [None]:
# Datasets for the configured category. The test split is asked to return
# masks as well, so its items unpack as (image, mask, label) further down.
train_dataset = MVTecDataset(category=CONFIG['category'], split='train')
test_dataset = MVTecDataset(category=CONFIG['category'], split='test', return_mask=True)

# Options shared by both loaders; only the shuffle flag differs.
loader_options = {
    'batch_size': CONFIG['batch_size'],
    'num_workers': 0,
}

# Training batches are shuffled; test order is kept fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_options)

print(f"Category: {CONFIG['category']}")
print(f"Train samples: {len(train_dataset)} (all normal)")
print(f"Test samples: {len(test_dataset)}")

### Visualize Sample Data

In [None]:
# Show the first ten training images in a 2 x 5 grid and save the figure.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for number, axis in enumerate(axes.ravel(), start=1):
    sample, _ = train_dataset[number - 1]
    # Undo the input normalization, move the first axis last for imshow
    # (presumably channels-first -> channels-last), and clamp to [0, 1].
    picture = denormalize(sample).permute(1, 2, 0).numpy().clip(0, 1)
    axis.imshow(picture)
    axis.set_title(f'Normal #{number}')
    axis.axis('off')

plt.suptitle(f'Training Samples - {CONFIG["category"].title()} (Normal)', fontsize=14)
plt.tight_layout()
plt.savefig(FIGURES_DIR / f'cae_{CONFIG["category"]}_train_samples.png', dpi=150)
plt.show()

## 4. Create Model

In [None]:
# Create CAE model
model = create_cae(channels=CONFIG['channels'])

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model: Convolutional Autoencoder")
print(f"Channels: {CONFIG['channels']}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Sanity-check forward pass on one random 3 x 256 x 256 input.
# It runs in eval mode under no_grad so that it neither builds an autograd
# graph nor lets random noise touch any mode-dependent layers the model may
# contain (e.g. BatchNorm running statistics); the previous mode is restored
# afterwards. The input is created on the model's own device so the check
# works wherever create_cae placed the parameters.
was_training = model.training
model.eval()
test_input = torch.randn(1, 3, 256, 256, device=next(model.parameters()).device)
with torch.no_grad():
    test_output = model(test_input)
model.train(was_training)

print(f"\nInput shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")

## 5. Training Setup

In [None]:
# Reconstruction loss: alpha=0.8 weights the mix as 80% MSE, 20% SSIM.
# NOTE(review): the training cell passes nn.MSELoss() to trainer.fit, so this
# object is not what the model is optimized with, although the summary below
# announces it — confirm which loss is intended.
loss_fn = get_loss_function('cae', alpha=0.8)

# Early-stopping helper, later handed to trainer.fit.
early_stopping = EarlyStopping(patience=CONFIG['early_stopping_patience'])

# Adam optimizer over the model, with the configured learning rate and L2 penalty.
optimizer_options = {
    'lr': CONFIG['learning_rate'],
    'weight_decay': CONFIG['weight_decay'],
    'optimizer_type': 'adam',
}
optimizer = get_optimizer(model, **optimizer_options)

# 'plateau' scheduler with patience=5 and factor=0.5.
scheduler = get_scheduler(optimizer, 'plateau', patience=5, factor=0.5)

# The trainer bundles model, optimizer and scheduler on the target device.
trainer = AutoencoderTrainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    device=DEVICE,
)

# Human-readable summary of the setup.
for summary_line in (
    "Training setup complete!",
    f"Optimizer: Adam (lr={CONFIG['learning_rate']})",
    f"Scheduler: ReduceLROnPlateau",
    f"Loss: Combined MSE + SSIM",
):
    print(summary_line)

## 6. Train Model

In [None]:
# Run name: passed to trainer.fit and reused below for figure and model file names.
experiment_name = f'cae_{CONFIG["category"]}'

# Arguments for the training run, collected before the call.
# NOTE(review): the loss here is plain MSE, not the combined MSE + SSIM
# `loss_fn` built in the setup cell — confirm which one is intended.
fit_options = dict(
    train_loader=train_loader,
    val_loader=None,                  # no validation loader is supplied
    num_epochs=CONFIG['num_epochs'],
    loss_fn=nn.MSELoss(),
    early_stopping=early_stopping,
    save_every=CONFIG['save_every'],
    experiment_name=experiment_name,
    verbose=True,
)
history = trainer.fit(**fit_options)

print(f"\nTraining completed!")
print(f"Final training loss: {history.train_loss[-1]:.6f}")

## 7. Training Visualization

In [None]:
# Training-loss curve: one point per recorded epoch, saved next to the other figures.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(history.epochs, history.train_loss, 'b-', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title(f'CAE Training Loss - {CONFIG["category"].title()}')
ax.grid(True, alpha=0.3)
fig.savefig(FIGURES_DIR / f'{experiment_name}_loss_curve.png', dpi=150)
plt.show()

## 8. Reconstruction Visualization

In [None]:
# Switch to inference mode. No checkpoint is loaded here: the figure uses
# whatever weights the in-memory `model` holds after training.
model.eval()

# Pick every 10th test image, up to n_samples of them. Building the index
# list from the dataset length keeps a short test split from raising
# IndexError; with more than 40 test images this is exactly 0, 10, 20, 30, 40.
n_samples = 5
sample_stride = 10
sample_indices = list(range(0, len(test_dataset), sample_stride))[:n_samples]

# Rows: original image, reconstruction, error map. One column per sample.
fig, axes = plt.subplots(3, n_samples, figsize=(15, 9))

for i, sample_index in enumerate(sample_indices):
    img, _, label = test_dataset[sample_index]  # the mask is not needed here
    img_input = img.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        recon = model(img_input)
        error_map = model.get_anomaly_map(img_input)

    # Denormalize the input for display.
    img_np = denormalize(img).permute(1, 2, 0).numpy().clip(0, 1)
    # NOTE(review): the reconstruction is shown without denormalize(); that is
    # only right if the model already outputs values in [0, 1] — confirm.
    recon_np = recon[0].cpu().permute(1, 2, 0).numpy().clip(0, 1)
    error_np = error_map[0, 0].cpu().numpy()

    # Original
    axes[0, i].imshow(img_np)
    axes[0, i].set_title(f'Original ({"Defect" if label else "Normal"})')
    axes[0, i].axis('off')

    # Reconstruction
    axes[1, i].imshow(recon_np)
    axes[1, i].set_title('Reconstruction')
    axes[1, i].axis('off')

    # Error map
    axes[2, i].imshow(error_np, cmap='hot')
    axes[2, i].set_title('Error Map')
    axes[2, i].axis('off')

# Hide columns left empty when the test split yielded fewer than n_samples images.
for unused_ax in axes[:, len(sample_indices):].flat:
    unused_ax.axis('off')

plt.suptitle(f'CAE Reconstruction Results - {CONFIG["category"].title()}', fontsize=14)
plt.tight_layout()
plt.savefig(FIGURES_DIR / f'{experiment_name}_reconstructions.png', dpi=150)
plt.show()

## 9. Anomaly Detection Evaluation

In [None]:
from sklearn.metrics import roc_auc_score, roc_curve

# Score every test image by its mean reconstruction error and keep the
# matching labels; the masks in each batch are not used here.
all_scores, all_labels = [], []

model.eval()
with torch.no_grad():
    for batch_images, _, batch_labels in test_loader:
        batch_error = model.get_reconstruction_error(
            batch_images.to(DEVICE), reduction='mean'
        )
        all_scores.extend(batch_error.cpu().numpy())
        all_labels.extend(batch_labels.numpy())

all_scores = np.array(all_scores)
all_labels = np.array(all_labels)

# Image-level ROC statistics: the reconstruction error is the anomaly score.
auc = roc_auc_score(all_labels, all_scores)
fpr, tpr, thresholds = roc_curve(all_labels, all_scores)

print(f"Image-level ROC-AUC: {auc:.4f}")

# ROC curve with the diagonal drawn as the chance-level reference.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, 'b-', linewidth=2, label=f'CAE (AUC = {auc:.4f})')
ax.plot([0, 1], [0, 1], 'k--', linewidth=1)
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title(f'ROC Curve - {CONFIG["category"].title()}')
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
fig.savefig(FIGURES_DIR / f'{experiment_name}_roc_curve.png', dpi=150)
plt.show()

## 10. Save Final Model

In [None]:
# Save the final weights together with the run configuration and metrics.
save_path = MODELS_DIR / f'{experiment_name}_final.pth'

# Store the metrics as built-in floats. The AUC (and possibly the recorded
# loss) may be a NumPy scalar, and a pickled NumPy scalar makes the checkpoint
# unreadable with torch.load(..., weights_only=True).
final_loss = float(history.train_loss[-1])
final_auc = float(auc)

torch.save({
    'model_state_dict': model.state_dict(),
    'config': CONFIG,
    'auc': final_auc,
    'final_loss': final_loss,
}, save_path)

print(f"Model saved to: {save_path}")
print("\n=== Training Summary ===")
print(f"Category: {CONFIG['category']}")
print(f"Epochs trained: {len(history.epochs)}")
print(f"Final loss: {final_loss:.6f}")
print(f"ROC-AUC: {final_auc:.4f}")