# Complete Analysis & Visualization Notebook

Comprehensive visualizations for anomaly detection research:
- Dataset exploration
- Model comparison
- ROC/PR curves
- Anomaly heatmaps
- Score distributions
- Per-category analysis

In [None]:
import sys
sys.path.insert(0, 'F:/Thesis')

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.metrics import roc_auc_score

from src.config import DEVICE, MODELS_DIR, FIGURES_DIR, MVTEC_CATEGORIES, ensure_dirs
from src.data import MVTecDataset
from src.data.transforms import denormalize
from src.models import create_cae, create_vae, create_denoising_ae
from src.evaluation import (
    set_style, plot_reconstruction_grid, plot_anomaly_heatmap_overlay,
    plot_roc_curves, plot_precision_recall_curves, plot_score_distribution,
    plot_category_comparison, plot_latent_space_2d, plot_training_curves
)

# One-time notebook setup using the project helpers imported above.
# NOTE(review): presumably ensure_dirs() creates the project's output
# directories (e.g. FIGURES_DIR / MODELS_DIR) — confirm in src.config.
ensure_dirs()
# NOTE(review): presumably set_style() applies a shared matplotlib style for
# all figures below — confirm in src.evaluation.
set_style()
# Report the device imported from src.config; inputs are moved to it later.
print(f'Device: {DEVICE}')

## 1. Dataset Exploration

In [None]:
# Visualize samples from each MVTec category: the first training image of
# each category is drawn into one panel of a 3x5 grid.
n_rows, n_cols = 3, 5
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 9))
panels = axes.ravel()
shown_categories = MVTEC_CATEGORIES[:n_rows * n_cols]

for ax, cat in zip(panels, shown_categories):
    try:
        ds = MVTecDataset(category=cat, split='train')
        img, _ = ds[0]
        # permute(1, 2, 0) moves the first axis last for imshow
        # (assumes img is channel-first (C, H, W) — TODO confirm).
        img_np = denormalize(img).permute(1, 2, 0).numpy().clip(0, 1)
        ax.imshow(img_np)
        ax.set_title(cat.title(), fontsize=10)
    except Exception as e:
        # Best-effort overview: keep going, but report why a category is
        # missing instead of silently drawing a placeholder.
        print(f'Skipping {cat}: {type(e).__name__}: {e}')
        ax.text(0.5, 0.5, 'N/A', ha='center', va='center',
                transform=ax.transAxes)
    ax.axis('off')

# Hide panels that received no category (fewer than 15 categories available),
# otherwise they would show up as empty framed axes.
for ax in panels[len(shown_categories):]:
    ax.axis('off')

plt.suptitle('MVTec AD Categories (15 Types)', fontsize=14)
plt.tight_layout()
plt.savefig(FIGURES_DIR / 'mvtec_categories_overview.png', dpi=150)
plt.show()

In [None]:
# Dataset statistics: per-category image counts for the train and test
# splits, shown as a table and as a bar chart.
import pandas as pd

stat_columns = ['category', 'train', 'test', 'normal', 'defect']
stats = []
for cat in MVTEC_CATEGORIES:
    try:
        train_ds = MVTecDataset(category=cat, split='train')
        test_ds = MVTecDataset(category=cat, split='test')
        # Labels are compared against 1 elsewhere in this notebook, so their
        # sum is the number of defective test images.
        n_defect = int(sum(test_ds.labels))
        stats.append({'category': cat, 'train': len(train_ds), 'test': len(test_ds),
                      'normal': len(test_ds) - n_defect, 'defect': n_defect})
    except Exception as exc:
        # Best-effort: skip categories that cannot be loaded, but say so
        # (and let KeyboardInterrupt / SystemExit propagate).
        print(f'Skipping {cat}: {type(exc).__name__}: {exc}')

# Passing the columns explicitly keeps the frame well-formed even when no
# category could be loaded.
df_stats = pd.DataFrame(stats, columns=stat_columns)
display(df_stats)

if df_stats.empty:
    print('No dataset statistics available; skipping the bar plot.')
else:
    # Bar plot: training images vs. defective test images per category
    fig, ax = plt.subplots(figsize=(14, 5))
    x = np.arange(len(df_stats))
    ax.bar(x - 0.2, df_stats['train'], 0.4, label='Train (Normal)')
    ax.bar(x + 0.2, df_stats['defect'], 0.4, label='Test (Defect)', color='red')
    ax.set_xticks(x)
    ax.set_xticklabels(df_stats['category'], rotation=45, ha='right')
    ax.set_ylabel('Count')
    ax.set_title('MVTec AD Dataset Statistics')
    ax.legend()
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'mvtec_statistics.png', dpi=150)
    plt.show()

## 2. Defect Type Visualization

In [None]:
# Show different defect types for a category: one column per defect type with
# the image (row 0), its mask (row 1) and the mask overlaid on the image (row 2).
category = 'bottle'
test_ds = MVTecDataset(category=category, split='test', return_mask=True)

# Group by defect type: remember the first sample index seen for each type
# (dicts keep first-seen order, which fixes the column order).
defect_samples = {}
for i in range(len(test_ds)):
    defect_samples.setdefault(test_ds.get_defect_type(i), i)

n_types = len(defect_samples)
# squeeze=False keeps `axes` 2-D even when there is only one defect type, so
# the axes[row, col] indexing below always works.
fig, axes = plt.subplots(3, n_types, figsize=(4*n_types, 10), squeeze=False)

for col, (dtype, idx) in enumerate(defect_samples.items()):
    img, mask, label = test_ds[idx]
    # permute(1, 2, 0) moves the first axis last for imshow
    # (assumes img is channel-first (C, H, W) — TODO confirm).
    img_np = denormalize(img).permute(1, 2, 0).numpy().clip(0, 1)
    # First entry along the mask's leading axis is used as the 2-D mask.
    mask_np = mask[0].numpy()

    axes[0, col].imshow(img_np)
    axes[0, col].set_title(dtype.replace('_', ' ').title(), fontsize=11)
    axes[0, col].axis('off')

    axes[1, col].imshow(mask_np, cmap='Reds')
    axes[1, col].set_title('Defect Mask')
    axes[1, col].axis('off')

    # Overlay
    axes[2, col].imshow(img_np)
    axes[2, col].imshow(mask_np, cmap='Reds', alpha=0.5)
    axes[2, col].set_title('Overlay')
    axes[2, col].axis('off')

plt.suptitle(f'{category.title()} - Defect Types', fontsize=14)
plt.tight_layout()
plt.savefig(FIGURES_DIR / f'{category}_defect_types.png', dpi=150)
plt.show()

## 3. Model Comparison (Load Trained Models)

In [None]:
# Load trained models (if available) into `models`, keyed by display name.
category = 'bottle'
models = {}


def _load_trained_model(factory, checkpoint_path):
    """Build a model and restore its weights from a checkpoint file.

    Args:
        factory: Zero-argument callable returning a fresh model.
        checkpoint_path: Path to a checkpoint whose 'model_state_dict' entry
            holds the weights.

    Returns:
        The model on DEVICE in eval mode, or None if the file does not exist.
    """
    if not checkpoint_path.exists():
        return None
    model = factory()
    # NOTE(review): torch.load unpickles the file — only load checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Later cells feed the model inputs moved to DEVICE, so the model must
    # live there too (a no-op if the factory already placed it on DEVICE).
    model.to(DEVICE)
    model.eval()
    return model


# Display name -> (model factory, checkpoint filename prefix)
model_specs = {
    'CAE': (create_cae, 'cae'),
    'VAE': (create_vae, 'vae'),
    'Denoising AE': (create_denoising_ae, 'denoising_ae'),
}

for name, (factory, prefix) in model_specs.items():
    model = _load_trained_model(factory, MODELS_DIR / f'{prefix}_{category}_final.pth')
    if model is not None:
        models[name] = model
        print(f'Loaded {name}')

if not models:
    print('No trained models found. Please train models first using the training notebooks.')

In [None]:
# If models exist, compute scores and create comparison plots.
# `results` maps model name -> (labels, scores). It is created unconditionally
# (possibly empty) so that cells which test `if results:` do not raise a
# NameError when no trained model was found.
results = {}
if models:
    test_ds = MVTecDataset(category=category, split='test', return_mask=True)
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=16, shuffle=False)

    for name, model in models.items():
        scores, labels = [], []
        with torch.no_grad():
            # The mask is returned by the dataset but not needed for scoring.
            for img, _mask, label in test_loader:
                img = img.to(DEVICE)
                # The VAE is scored through get_anomaly_score; the other
                # models through their mean reconstruction error.
                if name == 'VAE':
                    score = model.get_anomaly_score(img)
                else:
                    score = model.get_reconstruction_error(img, reduction='mean')
                scores.extend(score.cpu().numpy())
                labels.extend(label.numpy())
        results[name] = (np.array(labels), np.array(scores))

    # ROC curves comparison
    plot_roc_curves(results, title=f'Model Comparison - {category.title()}',
                    save_path=FIGURES_DIR / f'{category}_roc_comparison.png')

    # Precision-Recall curves
    plot_precision_recall_curves(results, title=f'Precision-Recall - {category.title()}',
                                  save_path=FIGURES_DIR / f'{category}_pr_comparison.png')

## 4. Anomaly Heatmap Visualization

In [None]:
# Detailed heatmap analysis for one model: plot the anomaly map of the first
# loaded model for up to five defective test samples.
if models:
    # First loaded model (dicts preserve insertion order)
    model_name = next(iter(models))
    model = models[model_name]

    # Get samples
    test_ds = MVTecDataset(category=category, split='test', return_mask=True)

    # Find defect samples (label 1), keeping at most five
    defect_indices = [i for i, lbl in enumerate(test_ds.labels) if lbl == 1][:5]

    for idx in defect_indices:
        img, mask, label = test_ds[idx]
        # Add a leading batch axis of size 1 before moving to DEVICE
        img_input = img.unsqueeze(0).to(DEVICE)

        # get_anomaly_map is called the same way for every model type; [0]
        # takes the map of the single image in the batch.
        with torch.no_grad():
            error_map = model.get_anomaly_map(img_input)[0]

        plot_anomaly_heatmap_overlay(
            img, error_map, mask,
            title=f'{model_name} - {test_ds.get_defect_type(idx)}',
            save_path=FIGURES_DIR / f'{category}_heatmap_{idx}.png'
        )

## 5. Score Distribution Analysis

In [None]:
# Plot normal vs. anomalous score distributions per model, marking the
# threshold (out of 100 evenly spaced candidates) with the highest F1 score.
from sklearn.metrics import f1_score

# `results` may not exist if the scoring cell did not run or created nothing;
# fall back to an empty mapping instead of failing with a NameError.
try:
    results
except NameError:
    results = {}

if results:
    for name, (labels, scores) in results.items():
        normal_scores = scores[labels == 0]
        anomaly_scores = scores[labels == 1]

        # Find optimal threshold: a sample is predicted anomalous when its
        # score is strictly greater than the candidate threshold.
        thresholds = np.linspace(scores.min(), scores.max(), 100)
        f1s = [f1_score(labels, scores > t) for t in thresholds]
        # argmax returns the first (lowest) threshold among ties
        best_thresh = thresholds[np.argmax(f1s)]

        plot_score_distribution(
            normal_scores, anomaly_scores, threshold=best_thresh,
            title=f'{name} Score Distribution - {category.title()}',
            save_path=FIGURES_DIR / f'{category}_{name.lower().replace(" ", "_")}_scores.png'
        )

## 6. Latent Space Visualization (VAE)

In [None]:
# Project the VAE's encodings of the test set to 2-D (PCA and t-SNE) and
# plot them together with the test labels.
if 'VAE' in models:
    vae = models['VAE']
    test_ds = MVTecDataset(category=category, split='test', return_mask=True)
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=16, shuffle=False)

    latent_vecs, labels = [], []
    with torch.no_grad():
        # The mask is returned by the dataset but not needed here.
        for img, _mask, label in test_loader:
            z = vae.encode(img.to(DEVICE))
            # NOTE(review): VAE encoders often return a tuple such as
            # (mu, logvar) instead of a single tensor. If so, use the first
            # element; a plain tensor is left untouched. Confirm against the
            # encode() implementation in src.models.
            if isinstance(z, (tuple, list)):
                z = z[0]
            latent_vecs.append(z.cpu())
            labels.extend(label.numpy())

    # Stack the per-batch encodings along the sample axis
    latent_vecs = torch.cat(latent_vecs, 0).numpy()
    labels = np.array(labels)

    # PCA visualization
    plot_latent_space_2d(latent_vecs, labels, method='pca',
                         title=f'VAE Latent Space (PCA) - {category.title()}',
                         save_path=FIGURES_DIR / f'{category}_vae_latent_pca.png')

    # t-SNE visualization (takes longer)
    plot_latent_space_2d(latent_vecs, labels, method='tsne',
                         title=f'VAE Latent Space (t-SNE) - {category.title()}',
                         save_path=FIGURES_DIR / f'{category}_vae_latent_tsne.png')

## 7. Multi-Category Analysis

In [None]:
# Evaluate CAE on multiple categories (if models trained): compute the
# image-level ROC-AUC per category and plot the values as a bar chart.
category_results = {}

for cat in MVTEC_CATEGORIES[:5]:  # First 5 categories
    model_path = MODELS_DIR / f'cae_{cat}_final.pth'
    if not model_path.exists():
        continue

    model = create_cae()
    model.load_state_dict(torch.load(model_path, map_location=DEVICE)['model_state_dict'])
    # Inputs are moved to DEVICE below, so the model must live there too
    # (a no-op if create_cae() already placed it on DEVICE).
    model.to(DEVICE)
    model.eval()

    test_ds = MVTecDataset(category=cat, split='test', return_mask=True)
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=16)

    scores, labels = [], []
    with torch.no_grad():
        # The mask is returned by the dataset but not needed for scoring.
        for img, _mask, label in test_loader:
            error = model.get_reconstruction_error(img.to(DEVICE), reduction='mean')
            scores.extend(error.cpu().numpy())
            labels.extend(label.numpy())

    # roc_auc_score raises ValueError when only one class is present; skip
    # such a category instead of aborting the whole comparison.
    if np.unique(labels).size < 2:
        print(f'{cat}: skipped (test labels contain a single class)')
        continue

    auc = roc_auc_score(labels, scores)
    category_results[cat] = auc
    print(f'{cat}: AUC = {auc:.4f}')

if category_results:
    # Plot category results
    fig, ax = plt.subplots(figsize=(12, 5))
    cats = list(category_results.keys())
    aucs = list(category_results.values())
    # Colour bands match the reference lines drawn below
    colors = ['green' if a > 0.8 else 'orange' if a > 0.6 else 'red' for a in aucs]

    ax.bar(cats, aucs, color=colors)
    ax.axhline(0.8, color='green', linestyle='--', alpha=0.5, label='Good (>0.8)')
    ax.axhline(0.6, color='orange', linestyle='--', alpha=0.5, label='Fair (>0.6)')
    ax.set_ylabel('ROC-AUC')
    ax.set_title('CAE Performance by Category')
    ax.legend()
    ax.set_ylim(0, 1)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'cae_category_comparison.png', dpi=150)
    plt.show()

## 8. Summary of Generated Figures and Models

In [None]:
# Print a closing summary: a banner, then the PNG figures found in
# FIGURES_DIR and the .pth checkpoints (with size in MB) found in MODELS_DIR.
banner = '=' * 60
print(banner)
print('ANALYSIS SUMMARY')
print(banner)

# Figures, sorted by path
print('\nGenerated Figures:')
figure_paths = sorted(FIGURES_DIR.glob('*.png'))
for figure_path in figure_paths:
    print(f'  - {figure_path.name}')

# Checkpoints, sorted by path, with file size converted from bytes to MB
print('\nTrained Models:')
checkpoint_paths = sorted(MODELS_DIR.glob('*.pth'))
for checkpoint_path in checkpoint_paths:
    n_bytes = checkpoint_path.stat().st_size
    size_mb = n_bytes / 1024 / 1024
    print(f'  - {checkpoint_path.name} ({size_mb:.1f} MB)')