In [4]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import pandas as pd
from PIL import Image
from tqdm import tqdm

# Imports required for TTA from the notebook
import albumentations as A
from albumentations.pytorch import ToTensorV2

# ==========================================
# 1. CONFIGURATION
# ==========================================
DEGRADED_TEST_DIR = "/kaggle/input/datasets/andreaspagnolo/visual-exam-dataset-2/visual_dataset/test_degradato"
MODEL_PATH = "/kaggle/input/models/andreaspagnolo/task2-denoisetestset/pytorch/default/1/best_model_task2_denoiseTestSet.pth"
BATCH_SIZE = 64
NUM_CLASSES = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Normalization constants used in the notebook
MEAN = [0.485, 0.456, 0.406]
STD  = [0.229, 0.224, 0.225]

print(f"Using device: {DEVICE}")

# ==========================================
# 2. MODEL DEFINITION
# ==========================================
def build_resnet18(num_classes: int) -> nn.Module:
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def load_model(path, num_classes, device):
    model = build_resnet18(num_classes)
    try:
        checkpoint = torch.load(path, map_location=device)
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        print("Model weights loaded successfully.")
    except Exception as e:
        print(f"[ERROR] Unable to load weights: {e}")
        return None
    model.to(device)
    model.eval()
    return model

# ==========================================
# 3. TRANSFORMATIONS (Standard vs TTA)
# ==========================================

# --- A. Standard PyTorch Transforms (No TTA) ---
test_transforms_standard = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

# --- B. TTA Transforms (Albumentations from Notebook) ---
def build_tta_transforms() -> list:
    """
    Returns the list of 10 augmentations defined in the input notebook.
    """
    norm = [A.Normalize(mean=MEAN, std=STD), ToTensorV2()]
    return [
        A.Compose([A.Resize(256, 256), A.CenterCrop(224, 224), *norm]),
        A.Compose([A.Resize(256, 256), A.CenterCrop(224, 224),
                   A.HorizontalFlip(p=1), *norm]),
        A.Compose([A.Resize(256, 256), A.CenterCrop(224, 224),
                   A.Sharpen(alpha=(0.3, 0.5), lightness=(0.8, 1.2), p=1), *norm]),
        A.Compose([A.Resize(256, 256), A.CenterCrop(224, 224),
                   A.HorizontalFlip(p=1),
                   A.Sharpen(alpha=(0.5, 0.7), lightness=(0.9, 1.1), p=1), *norm]),
        A.Compose([A.Resize(256, 256), A.CenterCrop(224, 224),
                   A.CLAHE(clip_limit=3.0, tile_grid_size=(8, 8), p=1), *norm]),
        A.Compose([A.Resize(256, 256), A.CenterCrop(224, 224),
                   A.HorizontalFlip(p=1),
                   A.CLAHE(clip_limit=3.0, tile_grid_size=(8, 8), p=1), *norm]),
        A.Compose([A.Resize(288, 288), A.CenterCrop(224, 224), *norm]),
        A.Compose([A.Resize(288, 288), A.CenterCrop(224, 224),
                   A.HorizontalFlip(p=1), *norm]),
        A.Compose([A.Resize(256, 256), A.RandomCrop(224, 224), *norm]),
        A.Compose([A.Resize(256, 256), A.CenterCrop(224, 224),
                   A.UnsharpMask(blur_limit=(3, 5), sigma_limit=(0.0, 1.0),
                                 alpha=(0.2, 0.5), p=1), *norm]),
    ]

@torch.no_grad()
def predict_tta(model, img_np: np.ndarray, tta_tfs: list, device) -> int:
    """
    Applies TTA transforms, aggregates probabilities, and returns best class index.
    """
    model.eval()
    probs_list = []
    
    # We use basic pytorch autocast if available for speed, though not strictly required for inference
    # Note: If autocast gives issues on CPU, it safely ignores usually, but here we check device
    use_amp = (device.type == 'cuda')
    
    for tf in tta_tfs:
        # Apply albumentations transform
        augmented = tf(image=img_np.copy())["image"]
        t = augmented.unsqueeze(0).to(device)
        
        with torch.amp.autocast('cuda', enabled=use_amp):
            # Model output
            output = model(t)
            # Softmax to get probabilities
            p = F.softmax(output, dim=1)
            
        probs_list.append(p.cpu().float().numpy())
    
    # Average predictions across all TTA versions
    mean_probs = np.mean(probs_list, axis=0)
    return int(np.argmax(mean_probs))

# ==========================================
# 4. METRICS & PLOTTING UTILS
# ==========================================
def calculate_and_print_metrics(targets, preds, class_names, title="Metrics"):
    print(f"\n--- {title} ---")
    
    report_dict = classification_report(
        targets, preds, target_names=class_names, output_dict=True, zero_division=0
    )
    
    accuracy = report_dict['accuracy']
    macro = report_dict['macro avg']
    weighted = report_dict['weighted avg']
    
    print(f"Overall Accuracy:   {accuracy:.4f} ({accuracy*100:.2f}%)")
    print("-" * 40)
    print("MACRO AVERAGE:")
    print(f"  Precision: {macro['precision']:.4f}")
    print(f"  Recall:    {macro['recall']:.4f}")
    print(f"  F1-Score:  {macro['f1-score']:.4f}")
    print("-" * 40)
    print("WEIGHTED AVERAGE:")
    print(f"  Precision: {weighted['precision']:.4f}")
    print(f"  Recall:    {weighted['recall']:.4f}")
    print(f"  F1-Score:  {weighted['f1-score']:.4f}")
    print("=" * 40)
    
    return report_dict

def plot_confusion_matrix(targets, preds, class_names, output_filename):
    cm = confusion_matrix(targets, preds)
    plt.figure(figsize=(20, 20))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, cbar=False)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix - {output_filename}')
    plt.tight_layout()
    plt.savefig(output_filename, dpi=100)
    print(f"Confusion matrix saved to {output_filename}")
    plt.close()

# ==========================================
# 5. EVALUATION PIPELINES
# ==========================================

def run_evaluation_standard(model, dataset, device):
    """Standard evaluation using DataLoader (No TTA)"""
    print("\n[Mode: Standard (No TTA)] Starting inference...")
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    
    all_preds = []
    all_targets = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Standard Eval"):
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(labels.numpy())
            
    return all_targets, all_preds

def run_evaluation_tta(model, dataset, device):
    """Evaluation using Test Time Augmentation (TTA)"""
    print("\n[Mode: TTA (Test Time Augmentation)] Starting inference...")
    
    # Initialize TTA transforms
    tta_transforms = build_tta_transforms()
    
    all_preds = []
    all_targets = []
    
    # Iterate through samples directly to apply custom TTA logic
    # dataset.samples contains list of (image_path, label_index)
    samples = dataset.samples
    
    for img_path, label in tqdm(samples, desc="TTA Eval"):
        # Load image as Numpy (RGB) to match Albumentations requirement
        # Note: PIL.Image -> np.array is what the notebook used
        image = np.array(Image.open(img_path).convert("RGB"))
        
        # Get prediction
        pred_idx = predict_tta(model, image, tta_transforms, device)
        
        all_preds.append(pred_idx)
        all_targets.append(label)
        
    return all_targets, all_preds

# ==========================================
# 6. MAIN EXECUTION
# ==========================================
def main():
    if not os.path.exists(DEGRADED_TEST_DIR):
        print(f"[ERROR] Directory not found: {DEGRADED_TEST_DIR}")
        return

    # 1. Load Model
    # We use the standard dataset class just to get class names and file paths easily
    # For standard eval, we pass the standard transform. 
    # For TTA, we ignore the transform in the dataset and load raw images.
    base_dataset = datasets.ImageFolder(DEGRADED_TEST_DIR, transform=test_transforms_standard)
    class_names = base_dataset.classes
    print(f"Classes: {len(class_names)}")
    print(f"Total Images: {len(base_dataset)}")

    model = load_model(MODEL_PATH, NUM_CLASSES, DEVICE)
    if model is None: return

    # ---------------------------------------------------------
    # 2. Run Standard Evaluation (No TTA)
    # ---------------------------------------------------------
    targets_std, preds_std = run_evaluation_standard(model, base_dataset, DEVICE)
    metrics_std = calculate_and_print_metrics(targets_std, preds_std, class_names, title="RESULTS: NO TTA")
    plot_confusion_matrix(targets_std, preds_std, class_names, "cm_no_tta.png")

    # ---------------------------------------------------------
    # 3. Run TTA Evaluation
    # ---------------------------------------------------------
    targets_tta, preds_tta = run_evaluation_tta(model, base_dataset, DEVICE)
    metrics_tta = calculate_and_print_metrics(targets_tta, preds_tta, class_names, title="RESULTS: WITH TTA")
    plot_confusion_matrix(targets_tta, preds_tta, class_names, "cm_with_tta.png")

    # ---------------------------------------------------------
    # 4. Final Comparison
    # ---------------------------------------------------------
    acc_no_tta = metrics_std['accuracy'] * 100
    acc_with_tta = metrics_tta['accuracy'] * 100
    diff = acc_with_tta - acc_no_tta
    
    print("\n" + "="*50)
    print("             FINAL COMPARISON")
    print("="*50)
    print(f"Accuracy (No TTA):   {acc_no_tta:.2f}%")
    print(f"Accuracy (With TTA): {acc_with_tta:.2f}%")
    print(f"Improvement:         {diff:+.2f}%")
    print("="*50)
    
    # Save comparison to CSV
    df_std = pd.DataFrame(metrics_std).transpose()
    df_std.to_csv("metrics_no_tta.csv")
    
    df_tta = pd.DataFrame(metrics_tta).transpose()
    df_tta.to_csv("metrics_with_tta.csv")
    print("\nMetrics saved to 'metrics_no_tta.csv' and 'metrics_with_tta.csv'")

if __name__ == "__main__":
    main()

Using device: cuda
Classes: 100
Total Images: 500
Model weights loaded successfully.

[Mode: Standard (No TTA)] Starting inference...


Standard Eval: 100%|██████████| 8/8 [00:01<00:00,  7.91it/s]



--- RESULTS: NO TTA ---
Overall Accuracy:   0.7880 (78.80%)
----------------------------------------
MACRO AVERAGE:
  Precision: 0.8181
  Recall:    0.7880
  F1-Score:  0.7796
----------------------------------------
WEIGHTED AVERAGE:
  Precision: 0.8181
  Recall:    0.7880
  F1-Score:  0.7796
Confusion matrix saved to cm_no_tta.png

[Mode: TTA (Test Time Augmentation)] Starting inference...


TTA Eval: 100%|██████████| 500/500 [00:26<00:00, 19.10it/s]



--- RESULTS: WITH TTA ---
Overall Accuracy:   0.8160 (81.60%)
----------------------------------------
MACRO AVERAGE:
  Precision: 0.8423
  Recall:    0.8160
  F1-Score:  0.8110
----------------------------------------
WEIGHTED AVERAGE:
  Precision: 0.8423
  Recall:    0.8160
  F1-Score:  0.8110
Confusion matrix saved to cm_with_tta.png

             FINAL COMPARISON
Accuracy (No TTA):   78.80%
Accuracy (With TTA): 81.60%
Improvement:         +2.80%

Metrics saved to 'metrics_no_tta.csv' and 'metrics_with_tta.csv'
