<a href="https://colab.research.google.com/github/chiara01712/CV-OOD-9/blob/main/Untitled1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
from torchvision import models, transforms
from torchvision.datasets import SVHN, CIFAR10, DTD, Places365
from sklearn.covariance import EmpiricalCovariance
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, precision_score, recall_score, f1_score, accuracy_score
import numpy as np
from PIL import Image
import os
import time
import copy
from torch.cuda.amp import GradScaler, autocast

from tqdm import tqdm

# ==================== CONFIGURATION ====================
# Module-level settings shared by the dataset, transform, training and
# evaluation code below.
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Output size of the replacement classification head built in main().
num_classes = 101
# Batch size used for every DataLoader created in main().
batch_size = 64
# Per-channel statistics passed to transforms.Normalize in both pipelines.
# NOTE(review): these look like the usual ImageNet mean/std -- confirm.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Image folder and split files handed to Food101Dataset in main().
data_root = "./food-101/images"
train_meta = "./food-101/meta/train.txt"
test_meta = "./food-101/meta/test.txt"
# Number of independent training/evaluation runs performed by main().
num_seeds = 3  # For reproducibility and statistical significance

# ==================== DATASET CLASS ====================
class Food101Dataset(Dataset):
    """Map-style dataset over class-per-folder images listed in a split file.

    Every non-blank line of ``meta_file`` must look like
    ``<class>/<image_id>``; the corresponding image is read from
    ``<root_dir>/<class>/<image_id>.jpg``.

    Args:
        root_dir: Directory containing one sub-directory per class.
        meta_file: Text file listing the samples of the split, one per line.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned. It is a plain attribute and may be reassigned
            after construction.

    Attributes:
        samples: ``(image_path, label)`` tuples in the order of the file.
        class_to_idx: Class name -> integer label.

    Raises:
        ValueError: If a non-blank line does not contain exactly one ``/``.
    """

    def __init__(self, root_dir, meta_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        entries = []
        with open(meta_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # Tolerate trailing/blank lines instead of failing to unpack.
                    continue
                cls, img_id = line.split('/')
                entries.append((cls, img_id))

        # Labels follow the alphabetical order of the class names so that two
        # splits with the same classes always agree on the numbering, no matter
        # in which order their lines are written. For a file whose classes
        # already appear alphabetically this equals first-appearance numbering.
        class_names = sorted({cls for cls, _ in entries})
        self.class_to_idx = {cls: idx for idx, cls in enumerate(class_names)}
        self.samples = [
            (os.path.join(root_dir, cls, img_id + '.jpg'), self.class_to_idx[cls])
            for cls, img_id in entries
        ]

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        path, label = self.samples[idx]
        # convert() yields an independent in-memory image, so the underlying
        # file can be closed right away instead of waiting for garbage collection.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# ==================== TRANSFORMS ====================
def get_train_transforms(image_size, aug_strength):
    """Assemble the training-time preprocessing pipeline.

    The image is resized to a square of side ``image_size``, randomly
    flipped horizontally, optionally augmented further, converted to a
    tensor and normalised with the module-level ``mean`` / ``std``.

    Args:
        image_size: Side length (pixels) of the square output image.
        aug_strength: ``'medium'`` adds colour jitter; ``'strong'`` adds
            colour jitter and a random rotation of up to 15 degrees. Any
            other value keeps only the horizontal flip.

    Returns:
        A ``transforms.Compose`` with the steps in the order described.
    """
    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
    ]

    wants_jitter = aug_strength == 'medium' or aug_strength == 'strong'
    if wants_jitter:
        steps.append(transforms.ColorJitter(0.2, 0.2, 0.2, 0.1))

    if aug_strength == 'strong':
        steps.append(transforms.RandomRotation(15))

    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean, std))
    return transforms.Compose(steps)

def get_test_transforms(image_size):
    """Assemble the deterministic evaluation preprocessing pipeline.

    Args:
        image_size: Side length (pixels) of the square output image.

    Returns:
        A ``transforms.Compose`` that resizes to a square, converts to a
        tensor and normalises with the module-level ``mean`` / ``std``.
    """
    target_shape = (image_size, image_size)
    pipeline = [
        transforms.Resize(target_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(pipeline)

# ==================== MODEL TRAINING ====================

def test_model(model, dataloaders, criterion):
    """Evaluate ``model`` on one data loader and print/return its metrics.

    Args:
        model: Network to evaluate. It is switched to eval mode and left
            in that mode.
        dataloaders: Despite the plural name, a single iterable yielding
            ``(data, target)`` batches.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.

    Returns:
        Tuple ``(avg_loss, accuracy, precision, recall, f1)`` where
        ``avg_loss`` is the mean of the per-batch losses, ``accuracy`` is a
        percentage rounded to two decimals, and the remaining three are
        class-weighted scores rounded to two decimals.

    Raises:
        ValueError: If the loader yields no batches (previously this surfaced
            as a ``ZeroDivisionError``).
    """
    # Eval mode disables dropout and freezes batch-norm statistics.
    model.eval()
    test_loss = 0.0
    batch_count = 0
    target_chunks = []
    predicted_chunks = []

    with torch.no_grad():
        for data, target in dataloaders:
            data, target = data.to(device), target.to(device)

            output = model(data)
            test_loss += criterion(output, target).item()

            _, predicted = torch.max(output, 1)
            target_chunks.append(target.cpu().numpy())
            predicted_chunks.append(predicted.cpu().numpy())
            batch_count += 1

    if batch_count == 0:
        raise ValueError('test_model received a data loader that yields no batches')

    target_all = np.concatenate(target_chunks)
    predicted_all = np.concatenate(predicted_chunks)

    # Count what was actually evaluated rather than trusting len(dataset),
    # which disagrees with the loop when a sampler or drop_last is in use.
    total = int(target_all.size)
    correct = int((target_all == predicted_all).sum())

    avg_loss = test_loss / batch_count
    accuracy = round((accuracy_score(target_all, predicted_all) * 100), 2)
    # zero_division=0 keeps the value sklearn would use anyway for classes
    # that are never predicted, without emitting a warning on every call.
    precision = round(precision_score(target_all, predicted_all, average='weighted', zero_division=0), 2)
    recall = round(recall_score(target_all, predicted_all, average='weighted', zero_division=0), 2)
    f1 = round(f1_score(target_all, predicted_all, average='weighted', zero_division=0), 2)

    print(f'Validation set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({accuracy}%)')
    print(f'Weighted precision: {precision}, recall: {recall}, f1: {f1}')

    return avg_loss, accuracy, precision, recall, f1


def train_model(model, dataloaders, criterion, optimizer, num_epochs):
    """Train ``model`` and return it with its best-validation weights loaded.

    Every epoch performs one pass over ``dataloaders['train']`` (with
    automatic mixed precision when running on CUDA) followed by an
    evaluation on ``dataloaders['test']`` through ``test_model``. The
    weights of the epoch with the highest validation accuracy are kept.

    Args:
        model: Network to optimise; expected to live on the module-level
            ``device`` already.
        dataloaders: Mapping with the keys ``'train'`` and ``'test'``.
        criterion: Loss callable ``criterion(outputs, labels)``.
        optimizer: Optimiser bound to the parameters of ``model``.
        num_epochs: Number of train/validate rounds to run.

    Returns:
        The same ``model`` object, carrying the best weights seen. With
        ``num_epochs == 0`` the initial weights are restored unchanged.
    """
    train_loader = dataloaders['train']
    num_batches = len(train_loader)

    # Mixed precision is only meaningful on CUDA; with enabled=False both
    # helpers turn into transparent no-ops, so the same loop works on CPU.
    use_amp = device.type == 'cuda'
    # One scaler for the whole run, so its adaptive loss scale carries over
    # between epochs instead of being reset.
    scaler = GradScaler(enabled=use_amp)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        model.train()
        running_loss = 0.0
        running_corrects = 0
        samples_seen = 0

        for batch_count, (inputs, labels) in enumerate(tqdm(train_loader, desc='train'), start=1):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            # Single forward/backward/step per batch.
            with autocast(enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()
            samples_seen += inputs.size(0)

            if batch_count % 30 == 0:
                # tqdm.write prints without tearing the progress bar.
                tqdm.write(f'Batch {batch_count}/{num_batches} - Loss: {loss.item():.4f}')

        # Guard the averages against an empty training loader.
        denominator = max(samples_seen, 1)
        epoch_loss = running_loss / denominator
        epoch_acc = running_corrects / denominator
        print(f'train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Validation phase; test_model reports accuracy as a percentage.
        val_loss, val_acc, val_prec, val_rec, val_f1 = test_model(model, dataloaders['test'], criterion)

        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    print(f'Best val Acc: {best_acc:.2f}%')
    model.load_state_dict(best_model_wts)
    return model

# ==================== OOD DETECTION METHODS ====================
class OODDetector:
    """Collection of post-hoc out-of-distribution scores for one classifier.

    The wrapped model must expose an indexable ``features`` container of
    layers producing 4-D feature maps. A forward hook on
    ``model.features[-1]`` stores the most recent output of that layer in
    ``self.features``; ``cores_score`` additionally reads
    ``model.features[4]`` and ``model.features[8]``.

    Args:
        model: The classifier to analyse. Its train/eval mode is never
            changed here; callers are expected to put it in eval mode.
    """

    def __init__(self, model):
        self.model = model
        self.features = []
        self.hook = model.features[-1].register_forward_hook(self._hook_fn)

    def _hook_fn(self, module, input, output):
        # Keep only the latest activation, detached from autograd. Appending
        # on every forward pass would grow without bound (and pin autograd
        # graphs) whenever the model runs outside _get_features.
        self.features[:] = [output.detach()]

    def _input_device(self, x):
        """Return the device of the model's parameters (or of ``x`` if it has none)."""
        first_param = next(self.model.parameters(), None)
        return x.device if first_param is None else first_param.device

    def close(self):
        """Detach the persistent forward hook from the model."""
        self.hook.remove()

    def _get_features(self, x):
        """Return globally average-pooled last-layer features, shape ``(N, C)``.

        The batch axis is always kept, including for a single-sample batch.
        """
        self.features.clear()
        _ = self.model(x)
        pooled = torch.nn.functional.adaptive_avg_pool2d(self.features[-1], 1)
        return pooled.flatten(1)

    def energy_score(self, logits):
        """Return ``-logsumexp(logits)`` per row; larger values indicate OOD."""
        return -torch.logsumexp(logits, dim=1)

    def mahalanobis_score(self, x, class_means, precision):
        """Return each sample's squared Mahalanobis distance to the nearest class mean.

        Args:
            x: Input batch, already on the model's device.
            class_means: Mapping ``class -> mean feature vector``.
            precision: Inverse covariance matrix, shape ``(C, C)``.

        Returns:
            1-D NumPy array with one distance per sample; larger means
            further from every class, i.e. more OOD-like.

        Raises:
            ValueError: If ``class_means`` is empty.
        """
        feats = np.atleast_2d(self._get_features(x).cpu().numpy())
        per_class = []
        for mu in class_means.values():
            diff = feats - mu
            # Row-wise diff^T P diff for the whole batch with one matmul.
            per_class.append(((diff @ precision) * diff).sum(axis=1))
        # The minimum over classes is the distance to the closest class;
        # taking the maximum instead would measure the farthest one.
        return np.stack(per_class, axis=1).min(axis=1)

    def gradient_score(self, x):
        """Return the L2 norm of d(max logit)/d(input) for every sample.

        Works even when called inside ``torch.no_grad()`` and never changes
        ``requires_grad`` on the caller's tensor.
        """
        with torch.enable_grad():
            # detach() yields a fresh leaf so the flag below stays local.
            x = x.detach().to(self._input_device(x)).requires_grad_(True)
            outputs = self.model(x)
            targets = outputs.max(1)[0]
            gradients = torch.autograd.grad(outputs=targets.sum(), inputs=x,
                                            create_graph=False, retain_graph=False)[0]
        return gradients.flatten(1).norm(p=2, dim=1).cpu().numpy()

    def cores_score(self, x):
        """Return the summed mean-squared pooled activation of selected layers.

        For ``features[4]``, ``features[8]`` and ``features[-1]`` the feature
        map is averaged over its spatial axes, squared and averaged over
        channels; the per-layer values are added up per sample.
        """
        # features[8] and features[-1] are the same module when the container
        # has nine layers; hook each distinct layer once so that it is not
        # counted twice.
        layers = []
        for layer in (self.model.features[4], self.model.features[8], self.model.features[-1]):
            if all(layer is not seen for seen in layers):
                layers.append(layer)

        captured = []
        hooks = [
            layer.register_forward_hook(lambda m, i, o: captured.append(o.detach()))
            for layer in layers
        ]
        try:
            with torch.no_grad():
                _ = self.model(x)
        finally:
            # Always remove the temporary hooks, even if the forward pass fails.
            for h in hooks:
                h.remove()

        score = 0
        for feat in captured:
            pooled = feat.mean(dim=[2, 3])  # Global average pooling
            score = score + (pooled ** 2).mean(dim=1)  # Mean squared activation

        return score.cpu().numpy()

# ==================== CALIBRATION METRICS ====================
def expected_calibration_error(model, dataloader, bins=15):
    """Compute the expected calibration error (ECE) of ``model``.

    Predictions are grouped into ``bins`` equal-width confidence intervals
    ``(lower, upper]`` over ``[0, 1]``. The ECE is the sum over bins of
    ``|mean confidence - accuracy|`` weighted by the fraction of samples in
    the bin.

    Args:
        model: Classifier returning logits; its train/eval mode is not
            changed here.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        bins: Number of confidence bins.

    Returns:
        The ECE as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    bin_boundaries = torch.linspace(0, 1, bins + 1)
    bin_lowers = bin_boundaries[:-1].tolist()
    bin_uppers = bin_boundaries[1:].tolist()

    confidence_chunks = []
    correct_chunks = []
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            probabilities = torch.softmax(model(x), dim=1)
            confidence, preds = torch.max(probabilities, 1)
            confidence_chunks.append(confidence.float().cpu())
            correct_chunks.append(preds.eq(y).float().cpu())

    if not confidence_chunks:
        raise ValueError('expected_calibration_error received a data loader that yields no batches')

    # Concatenate whole batches instead of building a tensor from a Python
    # list of 0-d tensors, which is needlessly slow.
    confidences = torch.cat(confidence_chunks)
    accuracies = torch.cat(correct_chunks)

    # Accumulate as a plain float so the result is well defined even when
    # no bin is populated (the accumulator then simply stays 0.0).
    ece = 0.0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        in_bin = confidences.gt(bin_lower) & confidences.le(bin_upper)
        prop_in_bin = in_bin.float().mean().item()
        if prop_in_bin > 0:
            accuracy_in_bin = accuracies[in_bin].mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin).item() * prop_in_bin
    return ece

# ==================== MAIN EXPERIMENT ====================
def _collect_scores(model, detector, loader, method_name, score_fn):
    """Score every batch of ``loader`` with one OOD method; return a 1-D array."""
    # Only the gradient-based score needs autograd; everything else runs
    # without building graphs.
    needs_grad = method_name == "Gradient"
    collected = []
    with torch.set_grad_enabled(needs_grad):
        for x, _ in loader:
            x = x[0].to(device) if isinstance(x, list) else x.to(device)
            if method_name == "Energy":
                scores = score_fn(model(x)).cpu().numpy()
            else:
                scores = score_fn(x)
            collected.append(np.atleast_1d(scores))
            # Drop activations cached by the detector's forward hook so they
            # cannot pile up across batches.
            detector.features.clear()
    return np.concatenate(collected)


def _ood_metrics(id_scores, ood_scores, higher_is_ood):
    """Return ``(auroc, aupr, fpr95)`` with OOD samples as the positive class."""
    labels = np.concatenate([np.zeros(len(id_scores)), np.ones(len(ood_scores))])
    scores = np.concatenate([id_scores, ood_scores])
    if not higher_is_ood:
        scores = -scores

    auroc = roc_auc_score(labels, scores)
    precision_pr, recall_pr, _ = precision_recall_curve(labels, scores)
    aupr = auc(recall_pr, precision_pr)
    # 95% of the OOD (positive) scores lie at or above their 5th percentile,
    # so this is the threshold with a 95% true-positive rate.
    threshold = np.percentile(scores[labels == 1], 5)
    fpr95 = np.mean(scores[labels == 0] >= threshold)
    return auroc, aupr, fpr95


def _fit_class_gaussians(detector, loader):
    """Estimate per-class feature means and a shared precision matrix."""
    feat_chunks = []
    label_chunks = []
    with torch.no_grad():
        for x, y in loader:
            feats = detector._get_features(x.to(device)).cpu().numpy()
            feat_chunks.append(np.atleast_2d(feats))
            label_chunks.append(y.numpy())

    train_feats = np.vstack(feat_chunks)
    train_labels = np.concatenate(label_chunks)

    classes, inverse = np.unique(train_labels, return_inverse=True)
    means = np.stack([train_feats[inverse == i].mean(axis=0) for i in range(len(classes))])
    class_means = {c: means[i] for i, c in enumerate(classes)}

    # The shared covariance of class-conditional Gaussians is the scatter
    # around each sample's own class mean, not around the global mean.
    train_feats -= means[inverse]
    precision = EmpiricalCovariance(assume_centered=True).fit(train_feats).precision_
    return class_means, precision


def main():
    """Run the full experiment: train, calibrate and evaluate OOD detection.

    For each of ``num_seeds`` seeds, an EfficientNetV2-M is fine-tuned with
    a progressive image-size schedule, saved to disk, checked for
    calibration (ECE) and then scored with four OOD detectors against
    several OOD datasets. Mean and standard deviation of AUROC, AUPR and
    FPR95 per method are printed at the end.
    """
    # Seed for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Load datasets; their transforms are assigned per training stage.
    train_dataset = Food101Dataset(data_root, train_meta, transform=None)
    test_dataset = Food101Dataset(data_root, test_meta, transform=None)

    # Initialize test transform
    test_transform = get_test_transforms(224)

    # Un-augmented view of the training split, used only to fit the
    # Mahalanobis statistics on deterministic inputs.
    stats_dataset = Food101Dataset(data_root, train_meta, transform=test_transform)

    # OOD datasets
    ood_datasets = {
        "SVHN": SVHN(root='./data', split='test', download=True, transform=test_transform),
        "CIFAR10": CIFAR10(root='./data', train=False, download=True, transform=test_transform),
        "DTD": DTD(root='./data', split='test', download=True, transform=test_transform),
        "Places365": Places365(root='./data', split='val', download=True, transform=test_transform),
        "Gaussian Noise": TensorDataset(torch.randn(1000, 3, 224, 224), torch.zeros(1000))
    }

    # Progressive training schedule
    prog_schedule = [
        {'image_size': 128, 'epochs': 5, 'aug_strength': 'light', 'dropout': 0.2},
        {'image_size': 160, 'epochs': 5, 'aug_strength': 'medium', 'dropout': 0.2},
        {'image_size': 224, 'epochs': 10, 'aug_strength': 'strong', 'dropout': 0.5},
    ]

    # Results storage
    all_results = {name: [] for name in ["Energy", "Mahalanobis", "Gradient", "CORES"]}

    for seed in range(num_seeds):
        print(f"\n=== Experiment Seed {seed + 1}/{num_seeds} ===")
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Initialize model with a frozen backbone and a fresh head.
        model = models.efficientnet_v2_m(weights=models.EfficientNet_V2_M_Weights.DEFAULT)
        for param in model.parameters():
            param.requires_grad = False

        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        model = model.to(device)

        # Progressive training
        for stage in prog_schedule:
            print(f"\nStage: {stage['image_size']}px")

            # Update transforms. Food101Dataset holds its transform directly
            # (it is not a Subset, so there is no ``.dataset`` indirection).
            train_dataset.transform = get_train_transforms(
                stage['image_size'], stage['aug_strength'])
            test_dataset.transform = get_test_transforms(stage['image_size'])

            # Update dropout
            model.classifier[0].p = stage['dropout']

            # Dataloaders
            dataloaders = {
                'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True),
                'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
            }

            # Unfreeze layers in later stages
            if stage['image_size'] >= 160:
                for param in model.features[-3:].parameters():
                    param.requires_grad = True

            # Hand the optimiser only what is actually trainable.
            trainable = [p for p in model.parameters() if p.requires_grad]
            optimizer = optim.Adam(trainable, lr=0.001)
            criterion = nn.CrossEntropyLoss()

            model = train_model(model, dataloaders, criterion, optimizer, stage['epochs'])

        # Save model
        torch.save(model.state_dict(), f'efficientnetv2_food101_seed{seed}.pth')

        # Everything below is inference: make the mode explicit.
        model.eval()

        # Evaluate calibration
        ece = expected_calibration_error(model, dataloaders['test'])
        print(f"\nExpected Calibration Error: {ece:.4f}")

        # OOD Detection Setup
        detector = OODDetector(model)

        # Compute class means and covariance for Mahalanobis
        print("Computing class statistics...")
        stats_loader = DataLoader(stats_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
        class_means, precision = _fit_class_gaussians(detector, stats_loader)

        # Each entry is (score function, whether larger scores mean OOD).
        # energy_score returns -logsumexp(logits), which grows as the logits
        # shrink, so larger values point to OOD for it as well.
        methods = {
            "Energy": (detector.energy_score, True),
            "Mahalanobis": (lambda x: detector.mahalanobis_score(x, class_means, precision), True),
            "Gradient": (detector.gradient_score, True),
            "CORES": (detector.cores_score, True)
        }

        # In-distribution scores do not depend on the OOD dataset: compute
        # them once per method instead of once per (method, OOD set) pair.
        test_loader = dataloaders['test']
        id_scores = {
            method_name: _collect_scores(model, detector, test_loader, method_name, score_fn)
            for method_name, (score_fn, _) in methods.items()
        }

        # Evaluate OOD detection
        for ood_name, ood_data in ood_datasets.items():
            print(f"\nEvaluating on {ood_name}...")
            ood_loader = DataLoader(ood_data, batch_size=batch_size)

            for method_name, (score_fn, higher_is_ood) in methods.items():
                ood_scores = _collect_scores(model, detector, ood_loader, method_name, score_fn)
                auroc, aupr, fpr95 = _ood_metrics(id_scores[method_name], ood_scores, higher_is_ood)

                print(f"{method_name}: AUROC={auroc:.4f}, AUPR={aupr:.4f}, FPR95={fpr95:.4f}")
                all_results[method_name].append((auroc, aupr, fpr95))

        # The model of this seed is finished; release the detector's hook.
        detector.hook.remove()

    # Print final results (aggregated over seeds and OOD datasets)
    print("\n=== FINAL RESULTS ===")
    for method, results in all_results.items():
        aurocs = [r[0] for r in results]
        auprs = [r[1] for r in results]
        fpr95s = [r[2] for r in results]

        print(f"\n{method}:")
        print(f"AUROC: {np.mean(aurocs):.4f} ± {np.std(aurocs):.4f}")
        print(f"AUPR: {np.mean(auprs):.4f} ± {np.std(auprs):.4f}")
        print(f"FPR95: {np.mean(fpr95s):.4f} ± {np.std(fpr95s):.4f}")

# Run the full experiment when this file is executed directly; importing it
# as a module only defines the names above without starting a run.
if __name__ == "__main__":
    main()