In [None]:
import warnings
# NOTE(review): this silences *every* warning for the whole kernel, including
# deprecation warnings from torchvision / timm / scikit-learn that announce API
# removals. Consider narrowing the filter (by category or module) while developing.
warnings.filterwarnings("ignore")

In [None]:
# ! pip install numpy==1.24.3 scipy==1.10.1

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import json
import time
import os
import random
from PIL import Image
import cv2
import ipywidgets as widgets
from IPython.display import display, clear_output
import seaborn as sns
import pandas as pd

# Code 1: ViT-Tiny model (vit_tiny_patch16_224)

In [None]:
# Seed torch (CPU generator and the current CUDA device) so weight init and shuffling repeat.
torch.manual_seed(48)
torch.cuda.manual_seed(48)

# Shared preprocessing: resize to 224x224 (the input size in the `_224` model name),
# convert to a tensor, and normalise with ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Output directory for checkpoints, logs and exported bundles.
exp_name = "expt_1"
os.makedirs(exp_name, exist_ok=True)

# One ImageFolder per split under NSD/{train,val,test}.
data_dir = "NSD"
train_dataset, val_dataset, test_dataset = (
    datasets.ImageFolder(os.path.join(data_dir, split), transform=transform)
    for split in ("train", "val", "test")
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader, val_loader, test_loader = (
    DataLoader(ds, batch_size=32, shuffle=(ds is train_dataset), num_workers=4)
    for ds in (train_dataset, val_dataset, test_dataset)
)

class ViTMAEModel(pl.LightningModule):
    """LightningModule that fine-tunes timm's ``vit_tiny_patch16_224`` for 9-way classification.

    Despite the class name, this block trains a plain supervised ViT classifier with
    cross-entropy loss and Adam (lr=3e-4); no masked-autoencoder objective appears here.
    """

    def __init__(self):
        super().__init__()
        # Pretrained backbone; timm builds a new 9-output classification head.
        self.model = timm.create_model("vit_tiny_patch16_224", pretrained=True, num_classes=9)
        self.criterion = nn.CrossEntropyLoss()
        self.save_hyperparameters()

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        return self.model(x)

    def _loss_and_accuracy(self, batch):
        """Return ``(cross-entropy loss, top-1 accuracy)`` for one labelled batch."""
        images, targets = batch
        logits = self(images)
        loss = self.criterion(logits, targets)
        accuracy = logits.argmax(dim=1).eq(targets).float().mean()
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        images, targets = batch
        loss = self.criterion(self(images), targets)
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self._loss_and_accuracy(batch)
        self.log("val_loss", loss, on_epoch=True)
        self.log("val_acc", accuracy, on_epoch=True)

    def test_step(self, batch, batch_idx):
        loss, accuracy = self._loss_and_accuracy(batch)
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_acc", accuracy, on_epoch=True)

    def configure_optimizers(self):
        """Single Adam optimiser over all parameters (full fine-tuning)."""
        return torch.optim.Adam(self.parameters(), lr=0.0003)

# Keep only the best epoch (highest validation accuracy) on disk.
checkpoint_callback = ModelCheckpoint(
    dirpath=exp_name,
    filename="best_vit_tiny_patch16_224",
    monitor="val_acc",
    mode="max",
    save_top_k=1
)
logger = CSVLogger(save_dir=exp_name, name="logs")

model = ViTMAEModel()
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    callbacks=[checkpoint_callback],
    logger=logger,
    default_root_dir=exp_name,
    enable_progress_bar=True
)
trainer.fit(model, train_loader, val_loader)
# Report test metrics for the best checkpoint -- the weights exported below --
# rather than for whatever the final epoch happened to produce.
trainer.test(model, test_loader, ckpt_path="best")

# Ask the callback where it actually wrote the file. If a checkpoint with the same name
# already exists in `exp_name` (e.g. from an earlier run), Lightning saves the new one with
# a version suffix such as "-v1", so a hard-coded path would silently load stale weights.
best_checkpoint = checkpoint_callback.best_model_path
if best_checkpoint and os.path.exists(best_checkpoint):
    vit_model = ViTMAEModel.load_from_checkpoint(best_checkpoint).model
    # Self-contained export: architecture name, weights, preprocessing and head size.
    vit_bundle = {
        "model_name": "vit_tiny_patch16_224",
        "state_dict": vit_model.state_dict(),
        "transform": transform,
        "num_classes": 9
    }
    torch.save(vit_bundle, f"{exp_name}/vit_tiny_model_bundle.pth")
    print(f"Saved bundle to {exp_name}/vit_tiny_model_bundle.pth")
else:
    raise FileNotFoundError(f"Checkpoint not found at {best_checkpoint}")

# Code 2: ResNet-50 model

In [None]:
# Re-seed so this run is reproducible regardless of how many random draws earlier cells
# consumed (matches the seeding at the top of the ViT and MobileViT cells).
# NOTE(review): this cell reuses `transform` and the train/val/test loaders defined in
# the ViT cell, so that cell must run first.
torch.manual_seed(48)
torch.cuda.manual_seed(48)

exp_name = "expt_1"
os.makedirs(exp_name, exist_ok=True)

class ResNetModel(pl.LightningModule):
    """LightningModule that fine-tunes torchvision's ResNet-50 for 9-way classification.

    Training starts from ImageNet weights. The final ``fc`` layer is replaced with a new
    9-output linear head, and all parameters are optimised with Adam (lr=3e-4) under
    cross-entropy loss.
    """

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated since torchvision 0.13. IMAGENET1K_V1 is exactly
        # the weight set it loaded, so behaviour is unchanged.
        self.model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.model.fc = nn.Linear(self.model.fc.in_features, 9)
        self.criterion = nn.CrossEntropyLoss()
        self.save_hyperparameters()

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        """Single Adam optimiser over all parameters (full fine-tuning)."""
        return torch.optim.Adam(self.parameters(), lr=0.0003)

# Keep only the best epoch (highest validation accuracy) on disk.
checkpoint_callback = ModelCheckpoint(
    dirpath=exp_name,
    filename="best_resnet50",
    monitor="val_acc",
    mode="max",
    save_top_k=1
)
# Log to CSV under `exp_name`/logs like the ViT and MobileViT runs, instead of
# falling back to Lightning's default logger.
logger = CSVLogger(save_dir=exp_name, name="logs")

model = ResNetModel()
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    callbacks=[checkpoint_callback],
    logger=logger,
    default_root_dir=exp_name,
    enable_progress_bar=True
)
trainer.fit(model, train_loader, val_loader)
# Report test metrics for the best checkpoint -- the weights exported below --
# rather than for whatever the final epoch happened to produce.
trainer.test(model, test_loader, ckpt_path="best")

# Ask the callback where it actually wrote the file. If a checkpoint with the same name
# already exists in `exp_name` (e.g. from an earlier run), Lightning saves the new one with
# a version suffix such as "-v1", so a hard-coded path would silently load stale weights.
best_checkpoint = checkpoint_callback.best_model_path
if best_checkpoint and os.path.exists(best_checkpoint):
    resnet_model = ResNetModel.load_from_checkpoint(best_checkpoint).model
    # Self-contained export: architecture name, weights, preprocessing and head size.
    resnet_bundle = {
        "model_name": "resnet50",
        "state_dict": resnet_model.state_dict(),
        "transform": transform,
        "num_classes": 9
    }
    torch.save(resnet_bundle, f"{exp_name}/resnet50_model_bundle.pth")
    print(f"Saved bundle to {exp_name}/resnet50_model_bundle.pth")
else:
    raise FileNotFoundError(f"Checkpoint not found at {best_checkpoint}")

# Code 3: MobileViT-XS model

In [None]:
# Re-seed torch so this cell's init and shuffling do not depend on earlier cells.
torch.manual_seed(48)
torch.cuda.manual_seed(48)

# Same preprocessing as the other models: 224x224 resize + ImageNet normalisation.
_resize = transforms.Resize((224, 224))
_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
transform = transforms.Compose([_resize, transforms.ToTensor(), _normalize])

exp_name = "expt_1"
os.makedirs(exp_name, exist_ok=True)

data_dir = "NSD"
_split_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), transform=transform)
    for split in ("train", "val", "test")
}
train_dataset = _split_datasets["train"]
val_dataset = _split_datasets["val"]
test_dataset = _split_datasets["test"]

# Shared loader settings; only training data is shuffled.
_loader_kwargs = dict(batch_size=32, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

class MobileViTModel(pl.LightningModule):
    """LightningModule that fine-tunes timm's ``mobilevit_xs`` for 9-way classification.

    The pretrained backbone gets a new 9-class head, and all parameters are optimised
    with Adam (lr=3e-4) under cross-entropy loss.
    """

    def __init__(self):
        super().__init__()
        self.model = timm.create_model("mobilevit_xs", pretrained=True, num_classes=9)
        self.criterion = nn.CrossEntropyLoss()
        self.save_hyperparameters()

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        return self.model(x)

    def _score_batch(self, batch):
        """Compute ``(loss, accuracy)`` on one labelled batch for validation/testing."""
        inputs, labels = batch
        logits = self(inputs)
        predicted = torch.argmax(logits, dim=1)
        return self.criterion(logits, labels), (predicted == labels).float().mean()

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        train_loss = self.criterion(self(inputs), labels)
        self.log("train_loss", train_loss, on_epoch=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        batch_loss, batch_acc = self._score_batch(batch)
        self.log("val_loss", batch_loss, on_epoch=True)
        self.log("val_acc", batch_acc, on_epoch=True)

    def test_step(self, batch, batch_idx):
        batch_loss, batch_acc = self._score_batch(batch)
        self.log("test_loss", batch_loss, on_epoch=True)
        self.log("test_acc", batch_acc, on_epoch=True)

    def configure_optimizers(self):
        """Adam over every parameter of the network."""
        return torch.optim.Adam(self.parameters(), lr=0.0003)

# Keep only the best epoch (highest validation accuracy) on disk.
checkpoint_callback = ModelCheckpoint(
    dirpath=exp_name,
    filename="best_mobilevit_xs",
    monitor="val_acc",
    mode="max",
    save_top_k=1
)
logger = CSVLogger(save_dir=exp_name, name="logs")

model = MobileViTModel()
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    callbacks=[checkpoint_callback],
    logger=logger,
    default_root_dir=exp_name,
    enable_progress_bar=True
)
trainer.fit(model, train_loader, val_loader)
# Report test metrics for the best checkpoint -- the weights exported below --
# rather than for whatever the final epoch happened to produce.
trainer.test(model, test_loader, ckpt_path="best")

# Ask the callback where it actually wrote the file. If a checkpoint with the same name
# already exists in `exp_name` (e.g. from an earlier run), Lightning saves the new one with
# a version suffix such as "-v1", so a hard-coded path would silently load stale weights.
best_checkpoint = checkpoint_callback.best_model_path
if best_checkpoint and os.path.exists(best_checkpoint):
    mobilevit_model = MobileViTModel.load_from_checkpoint(best_checkpoint).model
    # Self-contained export: architecture name, weights, preprocessing and head size.
    mobilevit_bundle = {
        "model_name": "mobilevit_xs",
        "state_dict": mobilevit_model.state_dict(),
        "transform": transform,
        "num_classes": 9
    }
    torch.save(mobilevit_bundle, f"{exp_name}/mobilevit_xs_model_bundle.pth")
    print(f"Saved bundle to {exp_name}/mobilevit_xs_model_bundle.pth")
else:
    raise FileNotFoundError(f"Checkpoint not found at {best_checkpoint}")

# Code 3.5: Test-set evaluation of the saved model bundles

In [None]:
# Fix the NumPy and torch global RNGs so this evaluation cell is repeatable.
np.random.seed(48)
torch.manual_seed(48)

def load_model_bundle(bundle_path, model_name):
    """Rebuild an eval-mode classifier from a bundle written by the training cells.

    Args:
        bundle_path: path to a ``*_model_bundle.pth`` file.
        model_name: architecture the caller expects; must equal ``bundle["model_name"]``.

    Returns:
        ``(model, transform)``: the model in eval mode (on CPU, as loaded) and the
        preprocessing transform stored in the bundle.

    Raises:
        ValueError: if the bundle holds a different architecture than ``model_name``.
    """
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects (required because the
    # bundle stores a torchvision transform). Only load bundles you produced yourself.
    bundle = torch.load(bundle_path, map_location="cpu", weights_only=False)
    stored_name = bundle["model_name"]
    if stored_name != model_name:
        raise ValueError(f"Expected {model_name}, got {bundle['model_name']}")

    num_classes = bundle["num_classes"]
    if model_name != "resnet50":
        # Every non-ResNet bundle is a timm architecture.
        model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
    else:
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    model.load_state_dict(bundle["state_dict"])
    model.eval()
    return model, bundle["transform"]

def evaluate(model, loader, device, desc=""):
    """Run ``model`` over ``loader`` and compute classification metrics.

    Args:
        model: classifier whose outputs hold per-class scores along dim 1.
        loader: iterable of ``(images, labels)`` tensor batches.
        device: device the images are moved to (the model must already live there).
        desc: free-form label for the run; currently unused, kept for call compatibility.

    Returns:
        ``(preds, labels, metrics)``: flat lists of predicted and true class indices, and a
        dict with ``accuracy`` (percent), macro ``precision`` / ``recall`` / ``f1``, and
        ``inference_time`` (seconds of wall-clock time for the whole loop).
    """
    model.eval()
    preds, labels = [], []
    # perf_counter is monotonic and high-resolution; time.time() can jump with clock changes.
    start_time = time.perf_counter()
    with torch.no_grad():
        for images, lbls in loader:
            outputs = model(images.to(device))
            # .cpu() waits for the device to finish, so the timing includes GPU work.
            preds.extend(outputs.argmax(dim=1).cpu().numpy())
            labels.extend(lbls.numpy())
    inference_time = time.perf_counter() - start_time
    metrics = {
        "accuracy": accuracy_score(labels, preds) * 100,
        # zero_division=0 yields the same value sklearn's default ("warn") falls back to for
        # never-predicted classes, minus the UndefinedMetricWarning.
        "precision": precision_score(labels, preds, average="macro", zero_division=0),
        "recall": recall_score(labels, preds, average="macro", zero_division=0),
        "f1": f1_score(labels, preds, average="macro", zero_division=0),
        "inference_time": inference_time,
    }
    return preds, labels, metrics

exp_name = "expt_1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_dir = "NSD"
# The transform is attached per model inside the loop (each bundle carries its own).
test_dataset = datasets.ImageFolder(os.path.join(data_dir, "test"), transform=None)

models_config = [
    {"name": "ViT-Tiny", "model_name": "vit_tiny_patch16_224", "bundle_path": f"{exp_name}/vit_tiny_model_bundle.pth"},
    {"name": "ResNet-50", "model_name": "resnet50", "bundle_path": f"{exp_name}/resnet50_model_bundle.pth"},
    {"name": "MobileViT-XS", "model_name": "mobilevit_xs", "bundle_path": f"{exp_name}/mobilevit_xs_model_bundle.pth"}
]

all_metrics = {}
# ImageNet statistics, used to undo Normalize() when displaying images.
# NOTE(review): assumes every bundle's transform normalises with exactly these values.
mean, std = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])
os.makedirs(exp_name, exist_ok=True)

for config in models_config:
    name, model_name, bundle_path = config["name"], config["model_name"], config["bundle_path"]
    print(f"\nTesting {name}...")

    model, transform = load_model_bundle(bundle_path, model_name)
    model = model.to(device)

    test_dataset.transform = transform
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

    _, _, metrics = evaluate(model, test_loader, device, "Balanced Test")
    all_metrics[name] = metrics

    print(f"{name} Results:")
    for k, v in metrics.items():
        unit = "s" if "time" in k else ""
        print(f"  {k.capitalize()}: {v:.2f}{unit}")

    json_path = f"{exp_name}/{model_name}_test_results.json"
    with open(json_path, "w") as f:
        json.dump(metrics, f, indent=4)
    print(f"Saved metrics to {json_path}")

    # Show predictions for the first five images of the (unshuffled) test set.
    sample_images, sample_labels = next(iter(test_loader))
    with torch.no_grad():
        sample_preds = model(sample_images.to(device)).argmax(dim=1).cpu().numpy()
    display_images = sample_images[:5]
    display_preds = sample_preds[:5]
    display_labels = sample_labels[:5].numpy()

    fig, axs = plt.subplots(1, 5, figsize=(15, 3))
    for ax, img, pred, true in zip(axs, display_images, display_preds, display_labels):
        img = np.clip(img.permute(1, 2, 0).numpy() * std + mean, 0, 1)
        ax.imshow(img)
        # Human-readable class names instead of bare indices.
        ax.set_title(f"Pred: {test_dataset.classes[pred]}\nTrue: {test_dataset.classes[true]}")
        ax.axis("off")
    fig.tight_layout()
    img_path = f"{exp_name}/{model_name}_prediction_examples.png"
    fig.savefig(img_path)
    print(f"Saved predictions to {img_path}")
    # Display and release the figure now so figures don't pile up across loop iterations.
    plt.show()

combined_json_path = f"{exp_name}/combined_test_results.json"
with open(combined_json_path, "w") as f:
    json.dump(all_metrics, f, indent=4)
print(f"\nSaved combined metrics to {combined_json_path}")

# Code 4: Feature visualization (full vs. imbalanced test set)

In [None]:
# Inputs: model bundles exported by the training cells, plus the held-out test split.
_bundle_dir = "expt_1"
vit_bundle_path = f"{_bundle_dir}/vit_tiny_model_bundle.pth"
resnet_bundle_path = f"{_bundle_dir}/resnet50_model_bundle.pth"
mobilevit_bundle_path = f"{_bundle_dir}/mobilevit_xs_model_bundle.pth"
data_dir = "NSD/test"

# Every figure produced in this section is written here.
output_dir = f"{_bundle_dir}/feature_visualization"
os.makedirs(output_dir, exist_ok=True)

def set_seed(seed=48):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA devices).

    Args:
        seed: integer applied to every generator (default 48, as in the training cells).
    """
    # `random` was previously left unseeded, so anything drawing from it was not reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers every visible GPU, including the current one.
        torch.cuda.manual_seed_all(seed)

set_seed()

def get_feature_extractor(model, model_type):
    """Return a callable that maps an image batch to pooled, pre-classifier embeddings.

    Args:
        model: an eval-mode classifier, as returned by ``load_model_bundle``.
        model_type: ``"vit"`` or ``"mobilevit"`` (timm models) or ``"resnet"``
            (torchvision ResNet).

    Returns:
        A function ``f(x) -> Tensor[batch, num_features]`` giving the vector that the
        model's final classification layer consumes.

    Raises:
        ValueError: for an unknown ``model_type`` (previously an opaque UnboundLocalError).
    """
    if model_type in ("vit", "mobilevit"):
        # timm models expose forward_features() + forward_head(pre_logits=True), which
        # together return exactly the pooled features fed to the classifier head.
        # - For ViT this is the normalised [CLS] token, as the old hand-written path computed,
        #   but without re-implementing timm's positional-embedding internals.
        # - For MobileViT it includes the model's final 1x1 conv stage, which the old
        #   stem -> stages -> pool path skipped, so the features did not match the head's input.
        def extract_features(x):
            return model.forward_head(model.forward_features(x), pre_logits=True)

    elif model_type == "resnet":
        # torchvision ResNet has no forward_features(); replay its forward() up to,
        # but excluding, the final fc layer.
        def extract_features(x):
            x = model.conv1(x)
            x = model.bn1(x)
            x = model.relu(x)
            x = model.maxpool(x)

            x = model.layer1(x)
            x = model.layer2(x)
            x = model.layer3(x)
            x = model.layer4(x)

            x = model.avgpool(x)
            return torch.flatten(x, 1)

    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    return extract_features

def load_model_bundle(bundle_path):
    """Load a model bundle and return ``(model, transform, model_type)``.

    NOTE(review): this redefines the two-argument ``load_model_bundle`` from the testing
    cell with a different signature; after this cell runs, only this version exists.

    Args:
        bundle_path: path to a ``*_model_bundle.pth`` file written by a training cell.

    Returns:
        The eval-mode model (on CPU), the stored preprocessing transform, and one of
        ``"vit"``, ``"mobilevit"`` or ``"resnet"`` for ``get_feature_extractor``.

    Raises:
        ValueError: if the bundle's ``model_name`` is not a supported architecture.
    """
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects (needed for the stored
    # transform). Only load bundles you produced yourself.
    bundle = torch.load(bundle_path, map_location="cpu", weights_only=False)
    model_name = bundle["model_name"]
    num_classes = bundle["num_classes"]

    if "vit" in model_name:  # matches both "vit_*" and "mobilevit_*" timm names
        model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
        model_type = "mobilevit" if "mobilevit" in model_name else "vit"
    elif "resnet" in model_name:
        # `pretrained=False` is deprecated in torchvision >= 0.13; weights=None is equivalent.
        model = models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        model_type = "resnet"
    else:
        raise ValueError(f"Unsupported model type: {model_name}")

    model.load_state_dict(bundle["state_dict"])
    model.eval()

    return model, bundle["transform"], model_type

def extract_features(model, feature_extractor, data_loader, device):
    """Run ``feature_extractor`` over every batch and stack the results.

    Args:
        model: the module behind ``feature_extractor``; it is put in eval mode first.
        feature_extractor: callable mapping an image batch to a 2-D feature tensor.
        data_loader: iterable of ``(images, targets)`` tensor batches.
        device: device the images are moved to before extraction.

    Returns:
        ``(features, labels)`` as NumPy arrays concatenated along axis 0, in loader order.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # The extractor calls submodules directly, so make sure dropout/BN are in inference
    # mode even if the caller left the model in training mode.
    model.eval()
    feature_batches, label_batches = [], []

    with torch.no_grad():
        for images, targets in data_loader:
            batch_features = feature_extractor(images.to(device))
            feature_batches.append(batch_features.cpu().numpy())
            label_batches.append(targets.numpy())

    if not feature_batches:
        raise ValueError("data_loader yielded no batches; nothing to extract")

    return np.concatenate(feature_batches, axis=0), np.concatenate(label_batches, axis=0)

def calculate_separation_metrics(features, labels):
    """Summarise how well classes separate in a feature space.

    Computes:
    - the average intra-class distance (mean Euclidean distance of samples to their class
      centroid, averaged over classes),
    - the average inter-class distance (mean Euclidean distance between all pairs of
      class centroids),
    - the separation ratio (inter / intra).

    Args:
        features: array of shape ``(n_samples, n_features)``.
        labels: class labels of length ``n_samples``. They need not be contiguous
            (e.g. a class may be missing from a subset).

    Returns:
        dict with ``avg_intra_class_distance``, ``avg_inter_class_distance`` and
        ``separation_ratio``. The inter-class distance (and the ratio) is NaN when fewer
        than two classes are present.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    features = np.asarray(features)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    if classes.size == 0:
        raise ValueError("no samples given")

    # Row k of `centers` belongs to classes[k]. Index by position, never by label value,
    # so non-contiguous labels map to the right centroid (the old centers[int(c)] did not).
    centers = np.stack([features[labels == c].mean(axis=0) for c in classes])

    intra_class_distances = [
        np.linalg.norm(features[labels == c] - center, axis=1).mean()
        for c, center in zip(classes, centers)
    ]
    avg_intra_class_distance = np.mean(intra_class_distances)

    # All unordered centroid pairs (i < j).
    row, col = np.triu_indices(len(centers), k=1)
    pair_distances = np.linalg.norm(centers[row] - centers[col], axis=1)
    avg_inter_class_distance = pair_distances.mean() if pair_distances.size else np.nan

    separation_ratio = avg_inter_class_distance / avg_intra_class_distance

    return {
        'avg_intra_class_distance': avg_intra_class_distance,
        'avg_inter_class_distance': avg_inter_class_distance,
        'separation_ratio': separation_ratio
    }

def plot_class_distribution(labels, class_names, title, filename=None):
    """Bar-plot the number of samples per class.

    Args:
        labels: non-negative integer class indices.
        class_names: names for indices ``0..len(class_names)-1``; classes with no samples
            still get a (zero-height) bar.
        title: figure title.
        filename: if given, save the figure there and close it; otherwise show it.
    """
    class_counts = np.bincount(labels, minlength=len(class_names))
    positions = np.arange(len(class_names))

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(positions, class_counts)
    ax.set_xticks(positions)
    # Right-align rotated labels so each one ends under its own bar.
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_xlabel('Class')
    ax.set_ylabel('Count')
    ax.set_title(title)
    # Reserve room for the rotated labels; without this they are clipped in the saved PNG.
    fig.tight_layout()

    if filename:
        fig.savefig(filename)
        plt.close(fig)
    else:
        plt.show()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def _load_for_features(bundle_path):
    """Load a bundle, move its model to `device`, and build the matching feature extractor."""
    loaded_model, loaded_transform, loaded_type = load_model_bundle(bundle_path)
    loaded_model = loaded_model.to(device)
    return loaded_model, loaded_transform, loaded_type, get_feature_extractor(loaded_model, loaded_type)


print("Loading ViT model...")
vit_model, vit_transform, vit_model_type, vit_feature_extractor = _load_for_features(vit_bundle_path)

print("Loading ResNet model...")
resnet_model, resnet_transform, resnet_model_type, resnet_feature_extractor = _load_for_features(resnet_bundle_path)

print("Loading MobileViT model...")
mobilevit_model, mobilevit_transform, mobilevit_model_type, mobilevit_feature_extractor = _load_for_features(mobilevit_bundle_path)

print("Loading full dataset...")
# NOTE(review): every model below is fed images preprocessed with the ViT bundle's
# transform. That is correct only while all bundles store the same transform (the
# training cells above all save an identical resize + ImageNet-normalise pipeline).
full_dataset = datasets.ImageFolder(data_dir, transform=vit_transform)
class_names = full_dataset.classes
class_counts = np.bincount([label for _, label in full_dataset.samples])
print("Class distribution:", dict(zip(class_names, class_counts)))

print("Creating imbalanced dataset...")
min_class = np.argmin(class_counts)
print(f"Minority class: {class_names[min_class]} with {class_counts[min_class]} samples")
min_indices = [i for i, (_, lbl) in enumerate(full_dataset.samples) if lbl == min_class]
# Keep half of the minority class, but at least one sample, so the class never
# disappears from the subset (a 1-sample class used to give size=0).
min_keep_count = max(1, len(min_indices) // 2)
min_keep_indices = np.random.choice(min_indices, size=min_keep_count, replace=False)
other_indices = [i for i, (_, lbl) in enumerate(full_dataset.samples) if lbl != min_class]
imbalanced_indices = other_indices + list(min_keep_indices)
imbalanced_dataset = Subset(full_dataset, imbalanced_indices)

# Unshuffled loaders: label order is then identical for every model's features.
batch_size = 32
full_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
imbalanced_loader = DataLoader(imbalanced_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# (model, extractor) pairs in a fixed order: ViT, ResNet, MobileViT.
_feature_models = [
    (vit_model, vit_feature_extractor),
    (resnet_model, resnet_feature_extractor),
    (mobilevit_model, mobilevit_feature_extractor),
]

print("Extracting features from full dataset...")
(
    (vit_features_full, vit_labels_full),
    (resnet_features_full, resnet_labels_full),
    (mobilevit_features_full, mobilevit_labels_full),
) = [extract_features(m, fx, full_loader, device) for m, fx in _feature_models]

print("Extracting features from imbalanced dataset...")
(
    (vit_features_imbalanced, vit_labels_imbalanced),
    (resnet_features_imbalanced, resnet_labels_imbalanced),
    (mobilevit_features_imbalanced, mobilevit_labels_imbalanced),
) = [extract_features(m, fx, imbalanced_loader, device) for m, fx in _feature_models]

# The loaders are unshuffled, so every model sees the same label order; the ViT copies
# stand in for all three.
for _dist_labels, _dist_title, _dist_file in (
    (vit_labels_full, "Full Dataset Class Distribution", "full_class_distribution.png"),
    (vit_labels_imbalanced, "Imbalanced Dataset Class Distribution", "imbalanced_class_distribution.png"),
):
    plot_class_distribution(_dist_labels, class_names, _dist_title, os.path.join(output_dir, _dist_file))

def _pca_2d(features):
    """Fit a 2-component PCA on `features`; return ``(fitted_pca, projected_coords)``."""
    pca = PCA(n_components=2)
    return pca, pca.fit_transform(features)


def _tsne_2d(features):
    """Fit a seeded 2-D t-SNE on `features`; return ``(fitted_tsne, embedding)``.

    The iteration count is left at scikit-learn's default of 1000, which is what the
    previous ``n_iter=1000`` requested. ``n_iter`` was renamed ``max_iter`` in
    scikit-learn 1.5 and removed in 1.7, so passing it raises TypeError on current
    releases (and the notebook's global warning filter hid the deprecation notice).
    """
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    return tsne, tsne.fit_transform(features)


# Apply PCA to full and imbalanced datasets
print("Applying PCA to full dataset features...")
vit_pca_full, vit_pca_full_result = _pca_2d(vit_features_full)
resnet_pca_full, resnet_pca_full_result = _pca_2d(resnet_features_full)
mobilevit_pca_full, mobilevit_pca_full_result = _pca_2d(mobilevit_features_full)

print("Applying PCA to imbalanced dataset features...")
vit_pca_imbalanced, vit_pca_imbalanced_result = _pca_2d(vit_features_imbalanced)
resnet_pca_imbalanced, resnet_pca_imbalanced_result = _pca_2d(resnet_features_imbalanced)
mobilevit_pca_imbalanced, mobilevit_pca_imbalanced_result = _pca_2d(mobilevit_features_imbalanced)

# Apply t-SNE to full and imbalanced datasets
print("Applying t-SNE to full dataset features (this may take a while)...")
vit_tsne_full, vit_tsne_full_result = _tsne_2d(vit_features_full)
resnet_tsne_full, resnet_tsne_full_result = _tsne_2d(resnet_features_full)
mobilevit_tsne_full, mobilevit_tsne_full_result = _tsne_2d(mobilevit_features_full)

print("Applying t-SNE to imbalanced dataset features (this may take a while)...")
vit_tsne_imbalanced, vit_tsne_imbalanced_result = _tsne_2d(vit_features_imbalanced)
resnet_tsne_imbalanced, resnet_tsne_imbalanced_result = _tsne_2d(resnet_features_imbalanced)
mobilevit_tsne_imbalanced, mobilevit_tsne_imbalanced_result = _tsne_2d(mobilevit_features_imbalanced)

# Plot PCA results for full vs imbalanced.
# 2x3 grid: columns are models (ViT, ResNet, MobileViT); top row = full, bottom = imbalanced.
# Entries are (subplot slot, 2-D coords, labels, title), drawn in this order.
_pca_panels = [
    (1, vit_pca_full_result, vit_labels_full, 'PCA: ViT features - Full Dataset'),
    (4, vit_pca_imbalanced_result, vit_labels_imbalanced, 'PCA: ViT features - Imbalanced Dataset'),
    (2, resnet_pca_full_result, resnet_labels_full, 'PCA: ResNet features - Full Dataset'),
    (5, resnet_pca_imbalanced_result, resnet_labels_imbalanced, 'PCA: ResNet features - Imbalanced Dataset'),
    (3, mobilevit_pca_full_result, mobilevit_labels_full, 'PCA: MobileViT features - Full Dataset'),
    (6, mobilevit_pca_imbalanced_result, mobilevit_labels_imbalanced, 'PCA: MobileViT features - Imbalanced Dataset'),
]
# Colour-bar ticks sit on integer class indices, with bin edges halfway between them.
_pca_cbar_bounds = np.arange(len(class_names) + 1) - 0.5
_pca_cbar_ticks = np.arange(len(class_names))

plt.figure(figsize=(30, 10))
for _slot, _coords, _point_labels, _panel_title in _pca_panels:
    plt.subplot(2, 3, _slot)
    scatter = plt.scatter(_coords[:, 0], _coords[:, 1], c=_point_labels, cmap='viridis', alpha=0.7)
    plt.colorbar(scatter, boundaries=_pca_cbar_bounds).set_ticks(_pca_cbar_ticks)
    plt.title(_panel_title, fontsize=14)
    plt.xlabel('Principal Component 1', fontsize=12)
    plt.ylabel('Principal Component 2', fontsize=12)
    plt.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'pca_comparison_imbalanced.png'))
plt.close()

# Plot t-SNE results for full vs imbalanced, laid out like the PCA figure
# (columns = models, top row = full dataset, bottom row = imbalanced subset).
_tsne_panel_specs = (
    (1, vit_tsne_full_result, vit_labels_full, 'ViT', 'Full Dataset'),
    (4, vit_tsne_imbalanced_result, vit_labels_imbalanced, 'ViT', 'Imbalanced Dataset'),
    (2, resnet_tsne_full_result, resnet_labels_full, 'ResNet', 'Full Dataset'),
    (5, resnet_tsne_imbalanced_result, resnet_labels_imbalanced, 'ResNet', 'Imbalanced Dataset'),
    (3, mobilevit_tsne_full_result, mobilevit_labels_full, 'MobileViT', 'Full Dataset'),
    (6, mobilevit_tsne_imbalanced_result, mobilevit_labels_imbalanced, 'MobileViT', 'Imbalanced Dataset'),
)
_n_classes = len(class_names)

plt.figure(figsize=(30, 10))
for _pos, _emb, _emb_labels, _model_label, _split_label in _tsne_panel_specs:
    plt.subplot(2, 3, _pos)
    scatter = plt.scatter(_emb[:, 0], _emb[:, 1], c=_emb_labels, cmap='viridis', alpha=0.7)
    # One colour-bar bin per class index.
    plt.colorbar(scatter, boundaries=np.arange(_n_classes + 1) - 0.5).set_ticks(np.arange(_n_classes))
    plt.title(f't-SNE: {_model_label} features - {_split_label}', fontsize=14)
    plt.xlabel('t-SNE Dimension 1', fontsize=12)
    plt.ylabel('t-SNE Dimension 2', fontsize=12)
    plt.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'tsne_comparison_imbalanced.png'))
plt.close()

min_class_name = class_names[min_class]

def add_minority_flag(labels, min_class):
    """Tag every label as belonging to the minority class or not.

    Args:
        labels: iterable of class labels (list or 1-D array).
        min_class: label value that identifies the minority class.

    Returns:
        A list, parallel to ``labels``, holding 'Minority Class' where the
        label equals ``min_class`` and 'Other Classes' elsewhere.
    """
    flags = []
    for label in labels:
        if label == min_class:
            flags.append('Minority Class')
        else:
            flags.append('Other Classes')
    return flags

def _tsne_frame(points, point_labels):
    """Build the plotting DataFrame for one 2-D t-SNE embedding.

    Columns: ``x``/``y`` (the two embedding coordinates), ``class`` (the class
    name looked up in ``class_names``) and ``minority_flag`` (the
    ``add_minority_flag`` tag relative to ``min_class``).
    """
    return pd.DataFrame({
        'x': points[:, 0],
        'y': points[:, 1],
        'class': [class_names[idx] for idx in point_labels],
        'minority_flag': add_minority_flag(point_labels, min_class),
    })


# Full dataset
vit_df_full = _tsne_frame(vit_tsne_full_result, vit_labels_full)
resnet_df_full = _tsne_frame(resnet_tsne_full_result, resnet_labels_full)
mobilevit_df_full = _tsne_frame(mobilevit_tsne_full_result, mobilevit_labels_full)

# Imbalanced dataset
vit_df_imbalanced = _tsne_frame(vit_tsne_imbalanced_result, vit_labels_imbalanced)
resnet_df_imbalanced = _tsne_frame(resnet_tsne_imbalanced_result, resnet_labels_imbalanced)
mobilevit_df_imbalanced = _tsne_frame(mobilevit_tsne_imbalanced_result, mobilevit_labels_imbalanced)

# One standalone, class-coloured t-SNE figure per (model, dataset) pair:
# (plotting frame, figure title, output file name).
_labeled_tsne_figures = [
    (vit_df_full, 't-SNE: ViT features - Full Dataset', 'vit_tsne_full_labeled.png'),
    (vit_df_imbalanced, 't-SNE: ViT features - Imbalanced Dataset', 'vit_tsne_imbalanced_labeled.png'),
    (resnet_df_full, 't-SNE: ResNet features - Full Dataset', 'resnet_tsne_full_labeled.png'),
    (resnet_df_imbalanced, 't-SNE: ResNet features - Imbalanced Dataset', 'resnet_tsne_imbalanced_labeled.png'),
    (mobilevit_df_full, 't-SNE: MobileViT features - Full Dataset', 'mobilevit_tsne_full_labeled.png'),
    (mobilevit_df_imbalanced, 't-SNE: MobileViT features - Imbalanced Dataset', 'mobilevit_tsne_imbalanced_labeled.png'),
]

for _frame, _title, _filename in _labeled_tsne_figures:
    plt.figure(figsize=(14, 12))
    sns.scatterplot(
        data=_frame,
        x='x', y='y',
        hue='class',
        palette=sns.color_palette("hls", len(class_names)),
        legend="full",
        alpha=0.8,
    )
    plt.title(_title, fontsize=18)
    plt.xlabel('t-SNE Dimension 1', fontsize=14)
    plt.ylabel('t-SNE Dimension 2', fontsize=14)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, _filename))
    # Close each figure so the loop does not accumulate open figures.
    plt.close()

def _draw_minority_panel(ax, frame, title):
    """Scatter one t-SNE frame on ``ax``: minority-class points red, all others grey."""
    sns.scatterplot(
        data=frame,
        x='x', y='y',
        hue='minority_flag',
        palette={'Minority Class': 'red', 'Other Classes': 'gray'},
        alpha=0.8,
        ax=ax,
    )
    ax.set_title(title, fontsize=16)
    ax.set_xlabel('t-SNE Dimension 1', fontsize=12)
    ax.set_ylabel('t-SNE Dimension 2', fontsize=12)


# Figure 1: ViT (Full and Imbalanced) next to MobileViT (Full), minority class highlighted.
fig, axes = plt.subplots(1, 3, figsize=(30, 10))
for _ax, (_frame, _title) in zip(axes, [
    (vit_df_full, 'ViT: Full Dataset\nRed: Minority Class'),
    (vit_df_imbalanced, 'ViT: Imbalanced Dataset\nRed: Minority Class'),
    (mobilevit_df_full, 'MobileViT: Full Dataset\nRed: Minority Class'),
]):
    _draw_minority_panel(_ax, _frame, _title)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'vit_mobilevit_minority_class_comparison.png'))
plt.close()

# Figure 2: ResNet (Full and Imbalanced) next to MobileViT (Imbalanced).
fig, axes = plt.subplots(1, 3, figsize=(30, 10))
for _ax, (_frame, _title) in zip(axes, [
    (resnet_df_full, 'ResNet: Full Dataset\nRed: Minority Class'),
    (resnet_df_imbalanced, 'ResNet: Imbalanced Dataset\nRed: Minority Class'),
    (mobilevit_df_imbalanced, 'MobileViT: Imbalanced Dataset\nRed: Minority Class'),
]):
    _draw_minority_panel(_ax, _frame, _title)
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'resnet_mobilevit_minority_class_comparison.png'))
plt.close()


# (metrics key, display label, embedded features, labels) for every
# combination, ordered dataset -> projection -> model; this order defines
# both the metrics dict insertion order and the table's row order.
_separation_runs = [
    ('vit_full_pca', 'ViT (Full, PCA)', vit_pca_full_result, vit_labels_full),
    ('resnet_full_pca', 'ResNet (Full, PCA)', resnet_pca_full_result, resnet_labels_full),
    ('mobilevit_full_pca', 'MobileViT (Full, PCA)', mobilevit_pca_full_result, mobilevit_labels_full),
    ('vit_full_tsne', 'ViT (Full, t-SNE)', vit_tsne_full_result, vit_labels_full),
    ('resnet_full_tsne', 'ResNet (Full, t-SNE)', resnet_tsne_full_result, resnet_labels_full),
    ('mobilevit_full_tsne', 'MobileViT (Full, t-SNE)', mobilevit_tsne_full_result, mobilevit_labels_full),
    ('vit_imbalanced_pca', 'ViT (Imbalanced, PCA)', vit_pca_imbalanced_result, vit_labels_imbalanced),
    ('resnet_imbalanced_pca', 'ResNet (Imbalanced, PCA)', resnet_pca_imbalanced_result, resnet_labels_imbalanced),
    ('mobilevit_imbalanced_pca', 'MobileViT (Imbalanced, PCA)', mobilevit_pca_imbalanced_result, mobilevit_labels_imbalanced),
    ('vit_imbalanced_tsne', 'ViT (Imbalanced, t-SNE)', vit_tsne_imbalanced_result, vit_labels_imbalanced),
    ('resnet_imbalanced_tsne', 'ResNet (Imbalanced, t-SNE)', resnet_tsne_imbalanced_result, resnet_labels_imbalanced),
    ('mobilevit_imbalanced_tsne', 'MobileViT (Imbalanced, t-SNE)', mobilevit_tsne_imbalanced_result, mobilevit_labels_imbalanced),
]

metrics = {}
for _key, _label, _points, _point_labels in _separation_runs:
    metrics[_key] = calculate_separation_metrics(_points, _point_labels)

# Table column -> field of each calculate_separation_metrics result.
_separation_columns = {
    'Intra-class Distance': 'avg_intra_class_distance',
    'Inter-class Distance': 'avg_inter_class_distance',
    'Separation Ratio': 'separation_ratio',
}
_separation_table = {'Model': [label for _, label, _, _ in _separation_runs]}
for _column, _field in _separation_columns.items():
    _separation_table[_column] = [metrics[key][_field] for key, _, _, _ in _separation_runs]
metrics_df = pd.DataFrame(_separation_table)

_separation_csv = os.path.join(output_dir, 'imbalance_feature_separation_metrics.csv')
metrics_df.to_csv(_separation_csv, index=False)
print(f"Saved feature separation metrics to {_separation_csv}")

# Long format (Model, Metric, Value): each metric becomes one hue of a grouped bar chart.
metrics_df_plot = metrics_df.melt(
    id_vars=['Model'],
    value_vars=['Intra-class Distance', 'Inter-class Distance', 'Separation Ratio'],
    var_name='Metric',
    value_name='Value',
)

plt.figure(figsize=(20, 10))
ax = sns.barplot(data=metrics_df_plot, x='Model', y='Value', hue='Metric')
ax.set_title('Feature Space Separation Metrics Comparison', fontsize=18)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
# Legend outside the axes so it never hides bars.
ax.legend(title='Metric', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'imbalance_separation_metrics_comparison.png'))
plt.close()

# Per-architecture and per-projection slices of metrics_df for the comparison plots.
# Architecture labels are matched at the start of the 'Model' string. A plain
# substring search for 'ViT' also hits 'MobileViT (...)' rows, and the former
# lookahead 'ViT(?! Mobile)' did not help: in 'MobileViT (Full, PCA)' the text
# after 'ViT' is ' (', not ' Mobile'. That put MobileViT rows in the ViT panel.
vit_metrics = metrics_df[metrics_df['Model'].str.startswith('ViT ')]
resnet_metrics = metrics_df[metrics_df['Model'].str.startswith('ResNet ')]
mobilevit_metrics = metrics_df[metrics_df['Model'].str.startswith('MobileViT ')]
# Projection names are literal substrings; disable regex interpretation.
pca_metrics = metrics_df[metrics_df['Model'].str.contains('PCA', regex=False)]
tsne_metrics = metrics_df[metrics_df['Model'].str.contains('t-SNE', regex=False)]

def _melt_separation(frame):
    """Reshape rows of a separation-metrics table to long form (Model, Metric, Value)."""
    return pd.melt(frame, id_vars=['Model'],
                   value_vars=['Intra-class Distance', 'Inter-class Distance', 'Separation Ratio'],
                   var_name='Metric', value_name='Value')


# Comparisons by model type: one stacked panel per architecture, Full vs Imbalanced rows.
fig, axes = plt.subplots(3, 1, figsize=(14, 24))
vit_metrics_plot = _melt_separation(vit_metrics)
resnet_metrics_plot = _melt_separation(resnet_metrics)
mobilevit_metrics_plot = _melt_separation(mobilevit_metrics)

_model_panels = [
    (vit_metrics_plot, 'ViT Feature Space Metrics: Full vs Imbalanced'),
    (resnet_metrics_plot, 'ResNet Feature Space Metrics: Full vs Imbalanced'),
    (mobilevit_metrics_plot, 'MobileViT Feature Space Metrics: Full vs Imbalanced'),
]
for _ax, (_long_frame, _title) in zip(axes, _model_panels):
    sns.barplot(data=_long_frame, x='Model', y='Value', hue='Metric', ax=_ax)
    _ax.set_title(_title, fontsize=16)
    _ax.set_xticklabels(_ax.get_xticklabels(), rotation=45, ha='right')
    _ax.legend(title='Metric', bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'model_comparison_metrics.png'))
plt.close()

# Calculate minority class-specific metrics
def calculate_minority_class_metrics(features, labels, min_class):
    """
    Calculate metrics specific to the minority class.

    - ``minority_compactness``: mean Euclidean distance from each minority
      point to the minority-class centroid.
    - ``avg_distance_to_others``: mean Euclidean distance from the minority
      centroid to every other class centroid.
    - ``minority_separation_ratio``: ``avg_distance_to_others / minority_compactness``,
      or 0.0 when the compactness is 0 (e.g. a single minority sample).

    Args:
        features: array-like of per-sample coordinates, indexed by sample along
            axis 0 (presumably the 2-D PCA / t-SNE embeddings; confirm for other inputs).
        labels: array-like (list or ndarray) of class labels, one per sample.
        min_class: label value of the minority class.

    Returns:
        dict with the three keys above. If ``min_class`` has no samples in
        ``labels``, all three are 0.0.
    """
    features = np.asarray(features)
    # asarray so plain-list labels still produce an element-wise boolean mask.
    labels = np.asarray(labels)

    minority_mask = labels == min_class
    if not minority_mask.any():
        # No minority samples, so no metric is defined. Return the same 0.0
        # sentinel used below instead of failing with KeyError on a missing centroid.
        return {
            'minority_compactness': 0.0,
            'avg_distance_to_others': 0.0,
            'minority_separation_ratio': 0.0,
        }

    # Compactness: average distance from each minority point to its centroid.
    min_features = features[minority_mask]
    min_center = min_features.mean(axis=0)
    min_compactness = float(np.linalg.norm(min_features - min_center, axis=1).mean())

    # Distance from the minority centroid to each other class centroid.
    distances_to_other_centers = [
        np.linalg.norm(min_center - features[labels == c].mean(axis=0))
        for c in np.unique(labels)
        if c != min_class
    ]
    avg_distance_to_others = (
        float(np.mean(distances_to_other_centers)) if distances_to_other_centers else 0.0
    )

    # Higher means the minority class sits far from the others relative to its own spread.
    min_separation = avg_distance_to_others / min_compactness if min_compactness != 0 else 0.0

    return {
        'minority_compactness': min_compactness,
        'avg_distance_to_others': avg_distance_to_others,
        'minority_separation_ratio': min_separation
    }

# Minority-class metrics for every (model, dataset, projection) combination:
# (metrics key, display label, embedded features, labels), ordered
# dataset -> projection -> model so rows align with the separation table.
_minority_runs = [
    ('vit_full_pca', 'ViT (Full, PCA)', vit_pca_full_result, vit_labels_full),
    ('resnet_full_pca', 'ResNet (Full, PCA)', resnet_pca_full_result, resnet_labels_full),
    ('mobilevit_full_pca', 'MobileViT (Full, PCA)', mobilevit_pca_full_result, mobilevit_labels_full),
    ('vit_full_tsne', 'ViT (Full, t-SNE)', vit_tsne_full_result, vit_labels_full),
    ('resnet_full_tsne', 'ResNet (Full, t-SNE)', resnet_tsne_full_result, resnet_labels_full),
    ('mobilevit_full_tsne', 'MobileViT (Full, t-SNE)', mobilevit_tsne_full_result, mobilevit_labels_full),
    ('vit_imbalanced_pca', 'ViT (Imbalanced, PCA)', vit_pca_imbalanced_result, vit_labels_imbalanced),
    ('resnet_imbalanced_pca', 'ResNet (Imbalanced, PCA)', resnet_pca_imbalanced_result, resnet_labels_imbalanced),
    ('mobilevit_imbalanced_pca', 'MobileViT (Imbalanced, PCA)', mobilevit_pca_imbalanced_result, mobilevit_labels_imbalanced),
    ('vit_imbalanced_tsne', 'ViT (Imbalanced, t-SNE)', vit_tsne_imbalanced_result, vit_labels_imbalanced),
    ('resnet_imbalanced_tsne', 'ResNet (Imbalanced, t-SNE)', resnet_tsne_imbalanced_result, resnet_labels_imbalanced),
    ('mobilevit_imbalanced_tsne', 'MobileViT (Imbalanced, t-SNE)', mobilevit_tsne_imbalanced_result, mobilevit_labels_imbalanced),
]

minority_metrics = {}
for _key, _label, _points, _point_labels in _minority_runs:
    minority_metrics[_key] = calculate_minority_class_metrics(_points, _point_labels, min_class)

# Table column -> field of each calculate_minority_class_metrics result.
_minority_columns = {
    'Minority Compactness': 'minority_compactness',
    'Avg Distance to Others': 'avg_distance_to_others',
    'Minority Separation Ratio': 'minority_separation_ratio',
}
_minority_table = {'Model': [label for _, label, _, _ in _minority_runs]}
for _column, _field in _minority_columns.items():
    _minority_table[_column] = [minority_metrics[key][_field] for key, _, _, _ in _minority_runs]
minority_metrics_df = pd.DataFrame(_minority_table)

_minority_csv = os.path.join(output_dir, 'minority_class_metrics.csv')
minority_metrics_df.to_csv(_minority_csv, index=False)
print(f"Saved minority class metrics to {_minority_csv}")

# Long format (Model, Metric, Value): each minority metric becomes one hue.
minority_metrics_plot = minority_metrics_df.melt(
    id_vars=['Model'],
    value_vars=['Minority Compactness', 'Avg Distance to Others', 'Minority Separation Ratio'],
    var_name='Metric',
    value_name='Value',
)

plt.figure(figsize=(20, 10))
_minority_ax = sns.barplot(data=minority_metrics_plot, x='Model', y='Value', hue='Metric')
_minority_ax.set_title('Minority Class Metrics Comparison', fontsize=18)
plt.setp(_minority_ax.get_xticklabels(), rotation=45, ha='right')
_minority_ax.legend(title='Metric', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'minority_class_metrics_comparison.png'))
plt.close()

def _percent_change(new, old):
    """Return the percent change from ``old`` to ``new``: ``(new / old - 1) * 100``.

    Returns 0.0 when ``old`` is 0, where the change is undefined. Both columns
    below use this same guard; previously only the minority column had it,
    so the overall-separation column could divide by zero.
    """
    return (new / old - 1) * 100 if old != 0 else 0.0


# (row label, model key prefix, projection key), in table row order.
_impact_rows = [
    ('ViT (PCA)', 'vit', 'pca'),
    ('ResNet (PCA)', 'resnet', 'pca'),
    ('MobileViT (PCA)', 'mobilevit', 'pca'),
    ('ViT (t-SNE)', 'vit', 'tsne'),
    ('ResNet (t-SNE)', 'resnet', 'tsne'),
    ('MobileViT (t-SNE)', 'mobilevit', 'tsne'),
]

# Imbalance impact: % change of each separation ratio from the full to the imbalanced dataset.
imbalance_impact = pd.DataFrame({
    'Model': [label for label, _, _ in _impact_rows],
    'Separation Ratio Change (%)': [
        _percent_change(metrics[f'{model}_imbalanced_{proj}']['separation_ratio'],
                        metrics[f'{model}_full_{proj}']['separation_ratio'])
        for _, model, proj in _impact_rows
    ],
    'Minority Separation Change (%)': [
        _percent_change(minority_metrics[f'{model}_imbalanced_{proj}']['minority_separation_ratio'],
                        minority_metrics[f'{model}_full_{proj}']['minority_separation_ratio'])
        for _, model, proj in _impact_rows
    ],
})

imbalance_impact.to_csv(os.path.join(output_dir, 'imbalance_impact.csv'), index=False)
print(f"Saved imbalance impact analysis to {os.path.join(output_dir, 'imbalance_impact.csv')}")

# Plot
plt.figure(figsize=(16, 8))
imbalance_impact_plot = pd.melt(imbalance_impact, id_vars=['Model'], 
                              value_vars=['Separation Ratio Change (%)', 'Minority Separation Change (%)'],
                              var_name='Metric', value_name='Percent Change')

sns.barplot(x='Model', y='Percent Change', hue='Metric', data=imbalance_impact_plot)
plt.title('Impact of Class Imbalance on Feature Space Metrics', fontsize=18)
# Zero line: bars above it improved under imbalance, bars below it degraded.
plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)
plt.ylabel('Percent Change (%)')
plt.legend(title='Metric', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'imbalance_impact.png'))
plt.close()

# Head-to-head summary on the imbalanced data: one row per projection (PCA, then t-SNE).
_projection_keys = ('pca', 'tsne')


def _per_projection(table, model_prefix, field, dataset='imbalanced'):
    """Collect ``table[f'{model_prefix}_{dataset}_{proj}'][field]`` for PCA then t-SNE."""
    return [table[f'{model_prefix}_{dataset}_{proj}'][field] for proj in _projection_keys]


def _impact_on_separation(model_prefix):
    """Percent change of the overall separation ratio (full -> imbalanced), PCA then t-SNE."""
    imbalanced_values = _per_projection(metrics, model_prefix, 'separation_ratio')
    full_values = _per_projection(metrics, model_prefix, 'separation_ratio', dataset='full')
    return [(imb / full - 1) * 100 for imb, full in zip(imbalanced_values, full_values)]


model_comparison = pd.DataFrame({
    'Technique': ['PCA', 't-SNE'],
    'ViT Overall Separation Ratio': _per_projection(metrics, 'vit', 'separation_ratio'),
    'ResNet Overall Separation Ratio': _per_projection(metrics, 'resnet', 'separation_ratio'),
    'MobileViT Overall Separation Ratio': _per_projection(metrics, 'mobilevit', 'separation_ratio'),
    'ViT Minority Separation Ratio': _per_projection(minority_metrics, 'vit', 'minority_separation_ratio'),
    'ResNet Minority Separation Ratio': _per_projection(minority_metrics, 'resnet', 'minority_separation_ratio'),
    'MobileViT Minority Separation Ratio': _per_projection(minority_metrics, 'mobilevit', 'minority_separation_ratio'),
    'ViT Imbalance Impact (%)': _impact_on_separation('vit'),
    'ResNet Imbalance Impact (%)': _impact_on_separation('resnet'),
    'MobileViT Imbalance Impact (%)': _impact_on_separation('mobilevit'),
})

print("\nDirectly Comparing ViT, ResNet, and MobileViT on Imbalanced Data:")
print(model_comparison)

_comparison_csv = os.path.join(output_dir, 'vit_resnet_mobilevit_imbalance.csv')
model_comparison.to_csv(_comparison_csv, index=False)
print(f"Saved ViT, ResNet, and MobileViT comparison to {_comparison_csv}")

# Grouped bars: imbalance impact per model, one group per projection technique.
_impact_long = pd.melt(model_comparison, id_vars=['Technique'],
                       value_vars=['ViT Imbalance Impact (%)', 'ResNet Imbalance Impact (%)', 'MobileViT Imbalance Impact (%)'])
plt.figure(figsize=(16, 8))
sns.barplot(data=_impact_long, x='Technique', y='value', hue='variable')
plt.title('ViT vs ResNet vs MobileViT: Impact of Class Imbalance on Separation', fontsize=18)
plt.axhline(y=0, color='r', linestyle='-', alpha=0.3)
plt.ylabel('Percent Change in Separation Ratio (%)')
plt.legend(title='Model')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'vit_resnet_mobilevit_imbalance_impact.png'))
plt.close()

# One row per (model, dataset): overall and minority-class separation ratios
# under both projections, for a faceted bar chart.
comprehensive_df = pd.DataFrame({
    'Model': ['ViT', 'ResNet', 'MobileViT', 'ViT', 'ResNet', 'MobileViT'],
    'Dataset': ['Full', 'Full', 'Full', 'Imbalanced', 'Imbalanced', 'Imbalanced'],
    'Overall Separation (PCA)': [
        metrics['vit_full_pca']['separation_ratio'],
        metrics['resnet_full_pca']['separation_ratio'],
        metrics['mobilevit_full_pca']['separation_ratio'],
        metrics['vit_imbalanced_pca']['separation_ratio'],
        metrics['resnet_imbalanced_pca']['separation_ratio'],
        metrics['mobilevit_imbalanced_pca']['separation_ratio']
    ],
    'Overall Separation (t-SNE)': [
        metrics['vit_full_tsne']['separation_ratio'],
        metrics['resnet_full_tsne']['separation_ratio'],
        metrics['mobilevit_full_tsne']['separation_ratio'],
        metrics['vit_imbalanced_tsne']['separation_ratio'],
        metrics['resnet_imbalanced_tsne']['separation_ratio'],
        metrics['mobilevit_imbalanced_tsne']['separation_ratio']
    ],
    'Minority Separation (PCA)': [
        minority_metrics['vit_full_pca']['minority_separation_ratio'],
        minority_metrics['resnet_full_pca']['minority_separation_ratio'],
        minority_metrics['mobilevit_full_pca']['minority_separation_ratio'],
        minority_metrics['vit_imbalanced_pca']['minority_separation_ratio'],
        minority_metrics['resnet_imbalanced_pca']['minority_separation_ratio'],
        minority_metrics['mobilevit_imbalanced_pca']['minority_separation_ratio']
    ],
    'Minority Separation (t-SNE)': [
        minority_metrics['vit_full_tsne']['minority_separation_ratio'],
        minority_metrics['resnet_full_tsne']['minority_separation_ratio'],
        minority_metrics['mobilevit_full_tsne']['minority_separation_ratio'],
        minority_metrics['vit_imbalanced_tsne']['minority_separation_ratio'],
        minority_metrics['resnet_imbalanced_tsne']['minority_separation_ratio'],
        minority_metrics['mobilevit_imbalanced_tsne']['minority_separation_ratio']
    ]
})

plot_data = pd.melt(comprehensive_df, 
                  id_vars=['Model', 'Dataset'], 
                  value_vars=['Overall Separation (PCA)', 'Overall Separation (t-SNE)', 
                            'Minority Separation (PCA)', 'Minority Separation (t-SNE)'],
                  var_name='Metric', value_name='Value')

# sns.catplot creates and owns its figure. Do not pre-create one with
# plt.figure(): that extra figure was never drawn on and was never closed.
g = sns.catplot(
    data=plot_data, kind="bar",
    x="Dataset", y="Value", hue="Model", col="Metric",
    height=5, aspect=0.8, sharey=False, col_wrap=2
)

g.fig.suptitle('Comprehensive Feature Space Analysis: ViT vs ResNet vs MobileViT', fontsize=20, y=1.05)
g.set_titles(col_template="{col_name}")
g.set_axis_labels("Dataset", "Separation Ratio")
g.tight_layout()
# FacetGrid.savefig defaults to bbox_inches="tight", so the suptitle placed
# above the figure's top edge (y=1.05) is kept; plain plt.savefig clipped it.
g.savefig(os.path.join(output_dir, 'comprehensive_comparison.png'))
# Close exactly the catplot figure.
plt.close(g.fig)

print(f"\nAll visualizations and analyses saved to {output_dir}")

# Class Imbalance and Out-of-Distribution (OOD) Exploration

In [None]:
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from tqdm.notebook import tqdm

# Inputs for this section: the three trained model bundles and the test split.
vit_bundle_path, resnet_bundle_path, mobilevit_bundle_path = (
    "expt_1/vit_tiny_model_bundle.pth",
    "expt_1/resnet50_model_bundle.pth",
    "expt_1/mobilevit_xs_model_bundle.pth",
)
data_dir = "NSD/test"

# All figures of this section land here; create the folder before anything saves.
output_dir = "expt_1/imbalance_ood_results"
os.makedirs(output_dir, exist_ok=True)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def set_seed(seed=48):
    """Seed NumPy's global RNG and PyTorch (plus CUDA when present) for reproducibility.

    Args:
        seed: Integer seed applied to every generator.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed()

def load_model_bundle(bundle_path, model_name):
    """Load a saved model bundle and rebuild its model ready for inference.

    Args:
        bundle_path: Path to a ``torch.save``-d dict with keys ``model_name``,
            ``num_classes``, ``state_dict`` and ``transform``.
        model_name: Architecture name the caller expects the bundle to contain.

    Returns:
        ``(model, transform)``: the model moved to the global ``device`` and in eval
        mode, plus the preprocessing transform stored in the bundle.

    Raises:
        ValueError: if the bundle's ``model_name`` differs from ``model_name``.
    """
    # weights_only=False is needed because the bundle pickles a transform object.
    # Unpickling can execute arbitrary code, so only load bundles from a trusted source.
    bundle = torch.load(bundle_path, map_location="cpu", weights_only=False)
    if bundle["model_name"] != model_name:
        raise ValueError(f"Expected {model_name}, got {bundle['model_name']}")
    if "resnet" in model_name:
        # Local import so that rebinding the notebook's global name `models`
        # (e.g. to a list of display names) cannot break model construction on a re-run.
        from torchvision import models as tv_models
        model = tv_models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, bundle["num_classes"])
    else:
        model = timm.create_model(model_name, pretrained=False, num_classes=bundle["num_classes"])
    model.load_state_dict(bundle["state_dict"])
    model = model.to(device)
    model.eval()
    return model, bundle["transform"]

# Rebuild the three trained classifiers; each bundle also returns its stored
# preprocessing transform.
print("Loading ViT model...")
vit_model, vit_transform = load_model_bundle(vit_bundle_path, "vit_tiny_patch16_224")
print("Loading ResNet model...")
resnet_model, resnet_transform = load_model_bundle(resnet_bundle_path, "resnet50")
print("Loading MobileViT model...")
mobilevit_model, mobilevit_transform = load_model_bundle(mobilevit_bundle_path, "mobilevit_xs")

print("Loading dataset...")
# NOTE(review): this single dataset uses the ViT bundle's transform and feeds all three
# models below; resnet_transform / mobilevit_transform are loaded but unused here. If the
# bundles store different preprocessing, the ResNet/MobileViT numbers may be skewed.
# TODO confirm the three transforms are identical.
dataset = datasets.ImageFolder(data_dir, transform=vit_transform)
class_names = dataset.classes
# Sample count per class label, taken from the (path, label) pairs in `samples`.
# NOTE(review): no minlength, so a trailing class with zero samples would be missing.
class_counts = np.bincount([label for _, label in dataset.samples])
print("Class distribution:", dict(zip(class_names, class_counts)))

print("Creating imbalanced dataset...")
# Make the rarest class rarer still: keep a random half (floor) of its samples, drawn
# without replacement from the global NumPy RNG seeded by set_seed(); every other
# class is kept in full.
min_class = np.argmin(class_counts)
print(f"Minority class: {class_names[min_class]} with {class_counts[min_class]} samples")
min_indices = [i for i, (_, lbl) in enumerate(dataset.samples) if lbl == min_class]
min_keep_indices = np.random.choice(min_indices, size=len(min_indices) // 2, replace=False)
other_indices = [i for i, (_, lbl) in enumerate(dataset.samples) if lbl != min_class]
imbalanced_indices = other_indices + list(min_keep_indices)
imbalanced_dataset = Subset(dataset, imbalanced_indices)

print("Creating OOD datasets...")
# Split the test set on class index 0: "OOD train" holds every other class, and
# "OOD test" holds class 0 only.
# NOTE(review): no model is retrained on the "OOD train" subset here. Both subsets are
# only evaluated, so this compares per-subset accuracy rather than true
# out-of-distribution behaviour. Confirm that this is the intent.
ood_class = 0
print(f"OOD class: {class_names[ood_class]}")
ood_train_indices = [i for i, (_, lbl) in enumerate(dataset.samples) if lbl != ood_class]
ood_test_indices = [i for i, (_, lbl) in enumerate(dataset.samples) if lbl == ood_class]
ood_train_dataset = Subset(dataset, ood_train_indices)
ood_test_dataset = Subset(dataset, ood_test_indices)

# Evaluation-only loaders; shuffle=False keeps batches in dataset order.
batch_size = 32
full_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
imbalanced_loader = DataLoader(imbalanced_dataset, batch_size=batch_size, shuffle=False)
ood_train_loader = DataLoader(ood_train_dataset, batch_size=batch_size, shuffle=False)
ood_test_loader = DataLoader(ood_test_dataset, batch_size=batch_size, shuffle=False)

def plot_class_distribution(dataset_subset, title, save_path=None):
    """Draw a bar chart of per-class sample counts and optionally save it.

    Args:
        dataset_subset: Either a ``Subset`` (labels read from its parent's ``targets``)
            or the dataset itself (labels read from ``samples``).
        title: Figure title.
        save_path: Output image path; nothing is written when falsy.
    """
    if not isinstance(dataset_subset, Subset):
        labels = [target for _, target in dataset_subset.samples]
    else:
        parent_targets = dataset_subset.dataset.targets
        labels = [parent_targets[idx] for idx in dataset_subset.indices]

    n_classes = len(class_names)
    counts = np.bincount(labels, minlength=n_classes)
    positions = range(n_classes)

    plt.figure(figsize=(10, 6))
    plt.bar(positions, counts)
    plt.xticks(positions, class_names, rotation=45)
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.title(title)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.close()

def plot_confusion_matrix(y_true, y_pred, title, save_path=None):
    """Draw a confusion-matrix heatmap over every class in ``class_names``; optionally save it.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        title: Figure title.
        save_path: Output image path; nothing is written when falsy.
    """
    # Pin the label set to every class index. Without `labels=`, sklearn builds the
    # matrix only from labels present in y_true/y_pred. On a subset that lacks some
    # classes (e.g. the single-class OOD test subset) the matrix would shrink while the
    # tick labels still list every class name, so counts would appear under wrong classes.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(title)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.close()

def evaluate_model(model, dataloader, model_name, dataset_name):
    """Run ``model`` over ``dataloader``, print metrics and save a confusion matrix.

    Args:
        model: Classifier whose output is scored along dim 1 (argmax = predicted class).
        dataloader: Yields ``(images, labels)`` batches with labels as CPU tensors.
        model_name: Display name used in progress text, printed output and file names.
        dataset_name: Display name of the evaluated subset. In the saved file name it is
            lower-cased and spaces become underscores.

    Returns:
        ``(metrics, all_preds, all_labels)``. ``metrics`` holds ``accuracy`` (percent),
        macro ``precision``/``recall``/``f1`` (0-1) and wall-clock ``inference_time``
        (seconds). An empty loader yields all-zero metrics and no confusion-matrix file.
    """
    model.eval()
    all_preds = []
    all_labels = []
    start_time = time.time()
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc=f"Evaluating {model_name} on {dataset_name}"):
            images = images.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
    inference_time = time.time() - start_time

    if all_preds:
        metrics = {
            "accuracy": accuracy_score(all_labels, all_preds) * 100,
            "precision": precision_score(all_labels, all_preds, average="macro", zero_division=0),
            "recall": recall_score(all_labels, all_preds, average="macro", zero_division=0),
            "f1": f1_score(all_labels, all_preds, average="macro", zero_division=0),
            "inference_time": inference_time
        }
    else:
        # An empty subset has nothing to score; report zeros instead of calling the
        # sklearn metrics on empty arrays.
        metrics = {
            "accuracy": 0,
            "precision": 0,
            "recall": 0,
            "f1": 0,
            "inference_time": 0
        }

    print(f"{model_name} on {dataset_name}:")
    for k, v in metrics.items():
        unit = "s" if "time" in k else ""
        print(f"  {k.capitalize()}: {v:.2f}{unit}")

    if all_preds:
        # Keep spaces out of file names ("Full Dataset" -> "full_dataset").
        file_stub = dataset_name.lower().replace(' ', '_')
        plot_confusion_matrix(
            all_labels,
            all_preds,
            f"{model_name} - {dataset_name} Confusion Matrix",
            f"{output_dir}/{model_name.lower()}_{file_stub}_confusion_matrix.png"
        )
    return metrics, all_preds, all_labels

print("\nClass distribution analysis:")
# (subset, plot title, output file name) for each view of the test split.
for subset_view, plot_title, plot_file in (
    (dataset, "Full Dataset Class Distribution", "full_class_distribution.png"),
    (imbalanced_dataset, "Imbalanced Dataset Class Distribution", "imbalanced_class_distribution.png"),
    (ood_train_dataset, "OOD Training Dataset Class Distribution", "ood_train_class_distribution.png"),
    (ood_test_dataset, "OOD Test Dataset Class Distribution", "ood_test_class_distribution.png"),
):
    plot_class_distribution(subset_view, plot_title, f"{output_dir}/{plot_file}")

results = {}

# Each round: banner lines printed first, the loader, the result-key suffix and the
# dataset label passed to evaluate_model.
evaluation_rounds = (
    (("\n===== Full Dataset Evaluation =====",), full_loader, "full", "Full Dataset"),
    (("\n===== Imbalanced Dataset Evaluation =====",), imbalanced_loader, "imbalanced", "Imbalanced Dataset"),
    (("\n===== OOD Evaluation =====", "Training dataset (class excluded):"), ood_train_loader, "ood_train", "OOD Train"),
    (("Test dataset (excluded class only):",), ood_test_loader, "ood_test", "OOD Test"),
)
# (result-key prefix, network, display name), listed in evaluation order.
evaluated_networks = (
    ("vit", vit_model, "ViT"),
    ("resnet", resnet_model, "ResNet"),
    ("mobilevit", mobilevit_model, "MobileViT"),
)
for banner_lines, round_loader, key_suffix, round_label in evaluation_rounds:
    for banner in banner_lines:
        print(banner)
    for key_prefix, network, display_label in evaluated_networks:
        results[f"{key_prefix}_{key_suffix}"] = evaluate_model(network, round_loader, display_label, round_label)

# One summary row per (model, dataset) pair evaluated above. The `results` keys use
# short dataset suffixes ("vit_full", "vit_imbalanced", "vit_ood_train", ...), so map
# each display label to its suffix explicitly. Deriving the suffix from the label
# ("full_dataset") never matched, which silently dropped the Full and Imbalanced rows.
summary = []
result_key_suffix = {
    "Full Dataset": "full",
    "Imbalanced Dataset": "imbalanced",
    "OOD Train": "ood_train",
    "OOD Test": "ood_test",
}
for model_name in ["ViT", "ResNet", "MobileViT"]:
    model_key = model_name.lower()
    for dataset_name, dataset_key in result_key_suffix.items():
        key = f"{model_key}_{dataset_key}"
        if key not in results:
            continue
        run_metrics, _, _ = results[key]
        summary.append({
            "Model": model_name,
            "Dataset": dataset_name,
            "Accuracy": run_metrics["accuracy"],
            "Precision": run_metrics["precision"],
            "Recall": run_metrics["recall"],
            "F1 Score": run_metrics["f1"],
            "Inference Time": run_metrics["inference_time"]
        })

# Deliberately NOT named `models` / `datasets` / `metrics`. `models` and `datasets`
# refer to torchvision's modules, which this cell uses (`datasets.ImageFolder(...)`,
# `models.resnet50(...)`). Rebinding them to lists makes those calls fail with
# AttributeError the next time the cell, or any later one, runs.
model_labels = ["ViT", "ResNet", "MobileViT"]
dataset_labels = ["Full Dataset", "Imbalanced Dataset", "OOD Train", "OOD Test"]
metric_names = ["Accuracy", "Precision", "Recall", "F1 Score", "Inference Time"]

# Index the summary rows once instead of rescanning the list for every grid cell
# (the last row for a pair wins, as before).
summary_by_pair = {(row["Model"], row["Dataset"]): row for row in summary}

# One (models x datasets) grid per metric; pairs without a summary row stay 0.
plot_data = {}
for metric_name in metric_names:
    grid = np.zeros((len(model_labels), len(dataset_labels)))
    for i, model_label in enumerate(model_labels):
        for j, dataset_label in enumerate(dataset_labels):
            row = summary_by_pair.get((model_label, dataset_label))
            if row is not None:
                grid[i, j] = row[metric_name]
    plot_data[metric_name] = grid

bar_width = 0.25
x = np.arange(len(dataset_labels))
for metric_name in metric_names:
    fig, ax = plt.subplots(figsize=(12, 8))
    for i, model_label in enumerate(model_labels):
        # Centre the group of bars on each tick, whatever the number of models.
        offset = bar_width * (i - (len(model_labels) - 1) / 2)
        ax.bar(x + offset, plot_data[metric_name][i], bar_width, label=model_label)
    ax.set_xlabel('Dataset')
    ax.set_ylabel(metric_name)
    ax.set_title(f'Model Comparison - {metric_name}')
    ax.set_xticks(x)
    ax.set_xticklabels(dataset_labels)
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"{output_dir}/comparison_{metric_name.lower().replace(' ', '_')}.png")
    plt.close(fig)

print(f"All results saved to {output_dir}")

In [None]:
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import torchvision.models as models
import torchvision.datasets as datasets
import timm
import time
from tqdm.notebook import tqdm

# Inputs for this section: the three trained model bundles and the test split.
vit_bundle_path, resnet_bundle_path, mobilevit_bundle_path = (
    "expt_1/vit_tiny_model_bundle.pth",
    "expt_1/resnet50_model_bundle.pth",
    "expt_1/mobilevit_xs_model_bundle.pth",
)
data_dir = "NSD/test"

# All figures of this section land here; create the folder before anything saves.
output_dir = "expt_1/imbalance_ood_results"
os.makedirs(output_dir, exist_ok=True)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def set_seed(seed=48):
    """Seed NumPy's global RNG and PyTorch (plus CUDA when present) for reproducibility.

    Args:
        seed: Integer seed applied to every generator.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed()

def load_model_bundle(bundle_path, model_name):
    """Load a saved model bundle and rebuild its model ready for inference.

    Args:
        bundle_path: Path to a ``torch.save``-d dict with keys ``model_name``,
            ``num_classes``, ``state_dict`` and ``transform``.
        model_name: Architecture name the caller expects the bundle to contain.

    Returns:
        ``(model, transform)``: the model moved to the global ``device`` and in eval
        mode, plus the preprocessing transform stored in the bundle.

    Raises:
        ValueError: if the bundle's ``model_name`` differs from ``model_name``.
    """
    # weights_only=False is needed because the bundle pickles a transform object.
    # Unpickling can execute arbitrary code, so only load bundles from a trusted source.
    bundle = torch.load(bundle_path, map_location="cpu", weights_only=False)
    if bundle["model_name"] != model_name:
        raise ValueError(f"Expected {model_name}, got {bundle['model_name']}")
    if "resnet" in model_name:
        # Local import so that rebinding the notebook's global name `models`
        # (e.g. to a list of display names) cannot break model construction.
        from torchvision import models as tv_models
        model = tv_models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, bundle["num_classes"])
    else:
        model = timm.create_model(model_name, pretrained=False, num_classes=bundle["num_classes"])
    model.load_state_dict(bundle["state_dict"])
    model = model.to(device)
    model.eval()
    return model, bundle["transform"]

# Rebuild the three trained classifiers; each bundle also returns its stored
# preprocessing transform.
print("Loading ViT model...")
vit_model, vit_transform = load_model_bundle(vit_bundle_path, "vit_tiny_patch16_224")
print("Loading ResNet model...")
resnet_model, resnet_transform = load_model_bundle(resnet_bundle_path, "resnet50")
print("Loading MobileViT model...")
mobilevit_model, mobilevit_transform = load_model_bundle(mobilevit_bundle_path, "mobilevit_xs")

print("Loading dataset...")
# NOTE(review): this single dataset uses the ViT bundle's transform and feeds all three
# models below; resnet_transform / mobilevit_transform are loaded but unused here. If the
# bundles store different preprocessing, the ResNet/MobileViT numbers may be skewed.
# TODO confirm the three transforms are identical.
dataset = datasets.ImageFolder(data_dir, transform=vit_transform)
class_names = dataset.classes
# Sample count per class label, taken from the (path, label) pairs in `samples`.
# NOTE(review): no minlength, so a trailing class with zero samples would be missing.
class_counts = np.bincount([label for _, label in dataset.samples])
print("Class distribution:", dict(zip(class_names, class_counts)))

print("Creating imbalanced dataset...")
# Make the rarest class rarer still: keep a random half (floor) of its samples, drawn
# without replacement from the global NumPy RNG seeded by set_seed(); every other
# class is kept in full.
min_class = np.argmin(class_counts)
print(f"Minority class: {class_names[min_class]} with {class_counts[min_class]} samples")
min_indices = [i for i, (_, lbl) in enumerate(dataset.samples) if lbl == min_class]
min_keep_indices = np.random.choice(min_indices, size=len(min_indices) // 2, replace=False)
other_indices = [i for i, (_, lbl) in enumerate(dataset.samples) if lbl != min_class]
imbalanced_indices = other_indices + list(min_keep_indices)
imbalanced_dataset = Subset(dataset, imbalanced_indices)

print("Creating OOD datasets...")
# Split the test set on class index 0: "OOD train" holds every other class, and
# "OOD test" holds class 0 only.
# NOTE(review): no model is retrained on the "OOD train" subset here. Both subsets are
# only evaluated, so this compares per-subset accuracy rather than true
# out-of-distribution behaviour. Confirm that this is the intent.
ood_class = 0
print(f"OOD class: {class_names[ood_class]}")
ood_train_indices = [i for i, (_, lbl) in enumerate(dataset.samples) if lbl != ood_class]
ood_test_indices = [i for i, (_, lbl) in enumerate(dataset.samples) if lbl == ood_class]
ood_train_dataset = Subset(dataset, ood_train_indices)
ood_test_dataset = Subset(dataset, ood_test_indices)

# Evaluation-only loaders; shuffle=False keeps batches in dataset order.
batch_size = 32
full_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
imbalanced_loader = DataLoader(imbalanced_dataset, batch_size=batch_size, shuffle=False)
ood_train_loader = DataLoader(ood_train_dataset, batch_size=batch_size, shuffle=False)
ood_test_loader = DataLoader(ood_test_dataset, batch_size=batch_size, shuffle=False)

def plot_class_distribution(dataset_subset, title, save_path=None):
    """Draw a bar chart of per-class sample counts and optionally save it.

    Args:
        dataset_subset: Either a ``Subset`` (labels read from its parent's ``targets``)
            or the dataset itself (labels read from ``samples``).
        title: Figure title.
        save_path: Output image path; nothing is written when falsy.
    """
    if not isinstance(dataset_subset, Subset):
        labels = [target for _, target in dataset_subset.samples]
    else:
        parent_targets = dataset_subset.dataset.targets
        labels = [parent_targets[idx] for idx in dataset_subset.indices]

    n_classes = len(class_names)
    counts = np.bincount(labels, minlength=n_classes)
    positions = range(n_classes)

    plt.figure(figsize=(10, 6))
    plt.bar(positions, counts)
    plt.xticks(positions, class_names, rotation=45)
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.title(title)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.close()

def plot_confusion_matrix(y_true, y_pred, title, save_path=None):
    """Draw a confusion-matrix heatmap over every class in ``class_names``; optionally save it.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        title: Figure title.
        save_path: Output image path; nothing is written when falsy.
    """
    # Pin the label set to every class index. Without `labels=`, sklearn builds the
    # matrix only from labels present in y_true/y_pred. On a subset that lacks some
    # classes (e.g. the single-class OOD test subset) the matrix would shrink while the
    # tick labels still list every class name, so counts would appear under wrong classes.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(title)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.close()

def evaluate_model(model, dataloader, model_name, dataset_name):
    """Score ``model`` on ``dataloader``, print the metrics and save a confusion matrix.

    Args:
        model: Classifier whose dim-1 argmax is the predicted class.
        dataloader: Yields ``(images, labels)`` batches with labels as CPU tensors.
        model_name: Display name for progress text, printout and the output file name.
        dataset_name: Display name of the subset; lower-cased with spaces replaced by
            underscores in the output file name.

    Returns:
        ``(metrics, predictions, labels)``. ``metrics`` holds accuracy in percent, macro
        precision/recall/F1 and elapsed seconds; all zeros when the loader is empty, in
        which case no confusion matrix is written.
    """
    model.eval()
    predictions, targets = [], []
    t0 = time.time()
    with torch.no_grad():
        progress = tqdm(dataloader, desc=f"Evaluating {model_name} on {dataset_name}")
        for batch_images, batch_labels in progress:
            logits = model(batch_images.to(device))
            predictions.extend(logits.max(dim=1).indices.cpu().numpy())
            targets.extend(batch_labels.numpy())
    inference_time = time.time() - t0

    has_predictions = len(predictions) > 0
    if has_predictions:
        metrics = {
            "accuracy": accuracy_score(targets, predictions) * 100,
            "precision": precision_score(targets, predictions, average="macro", zero_division=0),
            "recall": recall_score(targets, predictions, average="macro", zero_division=0),
            "f1": f1_score(targets, predictions, average="macro", zero_division=0),
            "inference_time": inference_time
        }
    else:
        # Nothing was scored: report zeros for every metric.
        metrics = dict.fromkeys(["accuracy", "precision", "recall", "f1", "inference_time"], 0)

    print(f"{model_name} on {dataset_name}:")
    for k, v in metrics.items():
        unit = "s" if "time" in k else ""
        print(f"  {k.capitalize()}: {v:.2f}{unit}")

    if has_predictions:
        file_stub = dataset_name.lower().replace(' ', '_')
        plot_confusion_matrix(
            targets,
            predictions,
            f"{model_name} - {dataset_name} Confusion Matrix",
            f"{output_dir}/{model_name.lower()}_{file_stub}_confusion_matrix.png"
        )

    return metrics, predictions, targets

print("\nClass distribution analysis:")
# (subset, plot title, output file name) for each view of the test split.
for subset_view, plot_title, plot_file in (
    (dataset, "Full Dataset Class Distribution", "full_class_distribution.png"),
    (imbalanced_dataset, "Imbalanced Dataset Class Distribution", "imbalanced_class_distribution.png"),
    (ood_train_dataset, "OOD Training Dataset Class Distribution", "ood_train_class_distribution.png"),
    (ood_test_dataset, "OOD Test Dataset Class Distribution", "ood_test_class_distribution.png"),
):
    plot_class_distribution(subset_view, plot_title, f"{output_dir}/{plot_file}")

results = {}

# Each round: banner lines printed first, the loader, the result-key suffix and the
# dataset label passed to evaluate_model.
evaluation_rounds = (
    (("\n===== Full Dataset Evaluation =====",), full_loader, "full_dataset", "Full Dataset"),
    (("\n===== Imbalanced Dataset Evaluation =====",), imbalanced_loader, "imbalanced_dataset", "Imbalanced Dataset"),
    (("\n===== OOD Evaluation =====", "Training dataset (class excluded):"), ood_train_loader, "ood_train", "OOD Train"),
    (("Test dataset (excluded class only):",), ood_test_loader, "ood_test", "OOD Test"),
)
# (result-key prefix, network, display name), listed in evaluation order.
evaluated_networks = (
    ("vit", vit_model, "ViT"),
    ("resnet", resnet_model, "ResNet"),
    ("mobilevit", mobilevit_model, "MobileViT"),
)
for banner_lines, round_loader, key_suffix, round_label in evaluation_rounds:
    for banner in banner_lines:
        print(banner)
    for key_prefix, network, display_label in evaluated_networks:
        results[f"{key_prefix}_{key_suffix}"] = evaluate_model(network, round_loader, display_label, round_label)

# Flatten every evaluate_model() result into one row per (model, dataset) pair.
summary = []
for key, (metrics, _, _) in results.items():
    # Keys look like "<model>_<dataset>", e.g. "resnet_ood_test".
    prefix = key.split("_")[0]
    if prefix == "vit":
        model_name = "ViT"
    elif prefix == "resnet":
        model_name = "ResNet"
    else:
        model_name = "MobileViT"

    # The first matching key fragment wins; anything unrecognised becomes "Unknown".
    dataset_name = next(
        (
            label
            for fragment, label in (
                ("full_dataset", "Full Dataset"),
                ("imbalanced_dataset", "Imbalanced Dataset"),
                ("ood_train", "OOD Train"),
                ("ood_test", "OOD Test"),
            )
            if fragment in key
        ),
        "Unknown",
    )

    summary.append({
        "Model": model_name,
        "Dataset": dataset_name,
        "Accuracy": metrics["accuracy"],
        "Precision": metrics["precision"],
        "Recall": metrics["recall"],
        "F1 Score": metrics["f1"],
        "Inference Time": metrics["inference_time"]
    })

# Deliberately NOT named `models` / `datasets` / `metrics`. This cell imports
# torchvision's `models` and `datasets` modules under those names, and later cells
# call them (e.g. `datasets.ImageFolder(...)`). Rebinding them to lists makes those
# calls fail with AttributeError.
model_labels = ["ViT", "ResNet", "MobileViT"]
dataset_labels = ["Full Dataset", "Imbalanced Dataset", "OOD Train", "OOD Test"]
metric_names = ["Accuracy", "Precision", "Recall", "F1 Score", "Inference Time"]

model_index = {label: i for i, label in enumerate(model_labels)}
dataset_index = {label: j for j, label in enumerate(dataset_labels)}

# One (models x datasets) grid per metric; pairs without a summary row stay 0.
plot_data = {name: np.zeros((len(model_labels), len(dataset_labels))) for name in metric_names}
for entry in summary:
    i = model_index.get(entry["Model"])
    j = dataset_index.get(entry["Dataset"])
    if i is None or j is None:
        # e.g. a result key the summary could not map ("Unknown" dataset): report it
        # and keep going rather than aborting the whole cell.
        print(f"Skipping unplottable summary row: {entry['Model']} / {entry['Dataset']}")
        continue
    for name in metric_names:
        plot_data[name][i, j] = entry[name]

# One grouped bar chart per metric.
bar_width = 0.25
x = np.arange(len(dataset_labels))
for metric_name in metric_names:
    fig, ax = plt.subplots(figsize=(12, 8))
    for i, model_label in enumerate(model_labels):
        # Centre the group of bars on each tick, whatever the number of models.
        offset = bar_width * (i - (len(model_labels) - 1) / 2)
        ax.bar(x + offset, plot_data[metric_name][i], bar_width, label=model_label)
    ax.set_xlabel('Dataset')
    ax.set_ylabel(metric_name)
    ax.set_title(f'Model Comparison - {metric_name}')
    ax.set_xticks(x)
    ax.set_xticklabels(dataset_labels)
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"{output_dir}/comparison_{metric_name.lower().replace(' ', '_')}.png")
    plt.close(fig)

print(f"All results saved to {output_dir}")

# Interactive Grad-CAM Viewer

In [None]:
# Grad-CAM viewer inputs (trained bundles, test images) and its output folder.
vit_bundle_path, resnet_bundle_path, mobilevit_bundle_path = (
    "expt_1/vit_tiny_model_bundle.pth",
    "expt_1/resnet50_model_bundle.pth",
    "expt_1/mobilevit_xs_model_bundle.pth",
)
data_dir = "NSD/test"

output_dir = "gradcam_visualization"
os.makedirs(output_dir, exist_ok=True)

class ResNetGradCAM:
    """Grad-CAM for a torchvision-style ResNet, hooked on ``model.layer4``.

    A forward hook stores the layer's output and a full backward hook stores the
    gradient flowing into that output; ``generate_cam`` combines the two into a heat
    map resized to the input image.
    """

    def __init__(self, model, target_layer_name="layer4"):
        self.model = model
        self.model.eval()

        # Only the last residual stage is supported.
        if target_layer_name != "layer4":
            raise ValueError(f"Target layer {target_layer_name} not found in model")
        self.target_layer = model.layer4

        self.gradients = None    # gradient w.r.t. the target layer's output
        self.activations = None  # target layer output of the latest forward pass

        self.hooks = []
        self.register_hooks()

    def register_hooks(self):
        """Attach the activation-capturing and gradient-capturing hooks."""
        def _save_activation(module, inputs, output):
            self.activations = output

        def _save_gradient(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        layer = self.target_layer
        self.hooks.extend([
            layer.register_forward_hook(_save_activation),
            layer.register_full_backward_hook(_save_gradient),
        ])

    def remove_hooks(self):
        """Detach every hook this object registered."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []

    def generate_cam(self, input_image, target_class=None):
        """Return ``(cam, target_class)`` for sample 0 of ``input_image``.

        ``target_class`` defaults to the argmax of the model output (``.item()`` is
        used, so the batch must hold a single image in that case). ``cam`` is
        ReLU-ed, scaled to a max of 1 and resized to the input's spatial size.
        """
        logits = self.model(input_image)

        if target_class is None:
            target_class = logits.argmax(dim=1).item()

        # Back-propagate only the chosen class score of sample 0.
        self.model.zero_grad()
        selector = torch.zeros_like(logits)
        selector[0, target_class] = 1
        logits.backward(gradient=selector, retain_graph=True)

        grads = self.gradients.detach().cpu().numpy()[0]
        acts = self.activations.detach().cpu().numpy()[0]
        # Grad-CAM channel weights: spatially averaged gradients.
        channel_weights = grads.mean(axis=(1, 2))

        heatmap = np.zeros(acts.shape[1:], dtype=np.float32)
        for weight, channel in zip(channel_weights, acts):
            heatmap += weight * channel

        heatmap = np.maximum(heatmap, 0)
        peak = heatmap.max()
        if peak > 0:
            heatmap = heatmap / peak

        height, width = input_image.shape[2], input_image.shape[3]
        heatmap = cv2.resize(heatmap, (width, height))
        return heatmap, target_class


class ViTGradCAM:
    """Grad-CAM for a timm-style Vision Transformer, hooked on ``model.blocks[-1].norm1``.

    The hooked layer emits one embedding per token. Leading non-patch tokens (the
    class token) are dropped and the remaining patch tokens are read as a square
    grid. That allows the standard Grad-CAM recipe: each embedding channel is
    weighted by the target-score gradient averaged over all patches, and the map is
    the weighted channel sum at every patch.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()
        # Pre-attention LayerNorm of the last transformer block.
        self.target_layer = model.blocks[-1].norm1
        self.gradients = None    # gradient w.r.t. the target layer's output
        self.activations = None  # target layer output of the latest forward pass
        self.hooks = []
        self.register_hooks()

    def register_hooks(self):
        """Capture the target layer's forward output and the gradient flowing into it."""
        def forward_hook(module, input, output):
            self.activations = output

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        self.hooks.append(self.target_layer.register_forward_hook(forward_hook))
        self.hooks.append(self.target_layer.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach every hook this object registered."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    @staticmethod
    def _tokens_to_cam(activations, gradients, num_prefix_tokens=1):
        """Combine per-token activations and gradients into a normalised square CAM.

        Args:
            activations: (tokens, channels) array for one sample, prefix tokens first.
            gradients: Gradient of the target score w.r.t. ``activations``; same shape.
            num_prefix_tokens: Number of leading non-patch tokens to drop.

        Returns:
            float32 (side, side) array, ReLU-ed and scaled so its max is 1
            (all zeros when nothing is positive).

        Raises:
            ValueError: if the patch tokens do not form a non-empty square grid.
        """
        patch_acts = activations[num_prefix_tokens:]
        patch_grads = gradients[num_prefix_tokens:]
        num_patches = patch_acts.shape[0]
        side = int(round(np.sqrt(num_patches)))
        if num_patches == 0 or side * side != num_patches:
            raise ValueError(f"Cannot arrange {num_patches} patch tokens into a square grid")

        # Grad-CAM: per-channel weight = gradient averaged over all spatial positions
        # (patches); the map is the weighted sum of activation channels per patch.
        weights = patch_grads.mean(axis=0)
        cam = (patch_acts @ weights).reshape(side, side)

        cam = np.maximum(cam, 0)
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam.astype(np.float32)

    def generate_cam(self, input_image, target_class=None):
        """Compute a Grad-CAM heat map for sample 0 of ``input_image``.

        Args:
            input_image: Batched image tensor, assumed (N, C, H, W).
            target_class: Class index to explain. Defaults to the argmax of the model
                output; ``.item()`` is used, so the batch must then hold one image.

        Returns:
            ``(cam, target_class)`` where ``cam`` is a float32 map in [0, 1] resized to (H, W).
        """
        model_output = self.model(input_image)

        if target_class is None:
            target_class = torch.argmax(model_output, dim=1).item()

        # Back-propagate only the chosen class score of sample 0.
        self.model.zero_grad()
        one_hot = torch.zeros_like(model_output)
        one_hot[0, target_class] = 1
        model_output.backward(gradient=one_hot, retain_graph=True)

        gradients = self.gradients.detach().cpu().numpy()[0]
        activations = self.activations.detach().cpu().numpy()[0]
        # Recent timm ViTs report how many class/distillation tokens lead the sequence;
        # fall back to a single class token when the attribute is absent.
        num_prefix_tokens = getattr(self.model, "num_prefix_tokens", 1)
        cam = self._tokens_to_cam(activations, gradients, num_prefix_tokens)

        cam = cv2.resize(cam, (input_image.shape[3], input_image.shape[2]))
        return cam, target_class


class MobileViTGradCAM:
    """Grad-CAM for a timm-style MobileViT, hooked on the output of ``model.stages[-1]``.

    A forward hook stores the stage's feature map and a full backward hook stores
    the gradient flowing into it; ``generate_cam`` turns them into a heat map
    resized to the input image.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()

        # Last feature stage of the network.
        self.target_layer = model.stages[-1]
        self.gradients = None    # gradient w.r.t. the stage output
        self.activations = None  # stage output of the latest forward pass
        self.hooks = []
        self.register_hooks()

    def register_hooks(self):
        """Attach the activation-capturing and gradient-capturing hooks."""
        def _save_activation(module, inputs, output):
            self.activations = output

        def _save_gradient(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        layer = self.target_layer
        self.hooks.extend([
            layer.register_forward_hook(_save_activation),
            layer.register_full_backward_hook(_save_gradient),
        ])

    def remove_hooks(self):
        """Detach every hook this object registered."""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []

    def generate_cam(self, input_image, target_class=None):
        """Return ``(cam, target_class)`` for sample 0 of ``input_image``.

        ``target_class`` defaults to the argmax of the model output (``.item()`` is
        used, so the batch must hold a single image in that case). ``cam`` is
        ReLU-ed, scaled to a max of 1 and resized to the input's spatial size.
        """
        logits = self.model(input_image)

        if target_class is None:
            target_class = logits.argmax(dim=1).item()

        # Back-propagate only the chosen class score of sample 0.
        self.model.zero_grad()
        selector = torch.zeros_like(logits)
        selector[0, target_class] = 1
        logits.backward(gradient=selector, retain_graph=True)

        grads = self.gradients.detach().cpu().numpy()[0]
        acts = self.activations.detach().cpu().numpy()[0]
        # Grad-CAM channel weights: spatially averaged gradients.
        channel_weights = grads.mean(axis=(1, 2))

        heatmap = np.zeros(acts.shape[1:], dtype=np.float32)
        for weight, channel in zip(channel_weights, acts):
            heatmap += weight * channel

        heatmap = np.maximum(heatmap, 0)
        peak = heatmap.max()
        if peak > 0:
            heatmap = heatmap / peak

        height, width = input_image.shape[2], input_image.shape[3]
        heatmap = cv2.resize(heatmap, (width, height))
        return heatmap, target_class

def load_model_bundle(bundle_path):
    """
    Load a model bundle and return the model, transform, and model type.

    Args:
        bundle_path: Path to a ``torch.save``-d dict with keys ``model_name``,
            ``num_classes``, ``state_dict`` and ``transform``.

    Returns:
        ``(model, transform, model_type)``: the rebuilt model in eval mode on CPU
        (callers move it to their device), the stored transform, and one of
        ``"vit"``, ``"resnet"`` or ``"mobilevit"``.

    Raises:
        ValueError: if ``model_name`` matches none of the supported families.
    """
    # weights_only=False is required because the bundle pickles a transform object.
    # Unpickling can execute arbitrary code, so only load bundles from a trusted source.
    bundle = torch.load(bundle_path, map_location="cpu", weights_only=False)
    model_name = bundle["model_name"]

    # "mobilevit" contains "vit", so exclude it explicitly from the plain-ViT branch.
    if "vit" in model_name and "mobilevit" not in model_name:
        model = timm.create_model(model_name, pretrained=False, num_classes=bundle["num_classes"])
        model_type = "vit"
    elif "resnet" in model_name:
        # Local import so this keeps working even if the notebook's global `models`
        # has been rebound to something else (e.g. a list of display names).
        from torchvision import models as tv_models
        # `weights=None` replaces the deprecated `pretrained=False` (random init);
        # the trained weights are loaded from the bundle below.
        model = tv_models.resnet50(weights=None)
        model.fc = nn.Linear(model.fc.in_features, bundle["num_classes"])
        model_type = "resnet"
    elif "mobilevit" in model_name:
        model = timm.create_model(model_name, pretrained=False, num_classes=bundle["num_classes"])
        model_type = "mobilevit"
    else:
        raise ValueError(f"Unsupported model type: {model_name}")

    model.load_state_dict(bundle["state_dict"])
    model.eval()

    return model, bundle["transform"], model_type

def show_cam_on_image(img, mask, alpha=0.5):
    """Overlay a CAM mask on an RGB image as a JET-coloured heat map.

    Args:
        img: Float RGB image expected in [0, 1], broadcast-compatible with the
            (H, W, 3) heat map; values outside [0, 1] are clipped.
        mask: Float 2-D map expected in [0, 1]; values outside are clipped.
        alpha: Weight of the heat map in the blend (the image gets ``1 - alpha``).

    Returns:
        uint8 RGB blend of heat map and image.
    """
    # Clip before casting. np.uint8() of a value outside [0, 255] wraps around
    # (negative floats are platform-dependent), which turns slightly out-of-range
    # pixels (e.g. after undoing mean/std normalisation) into bright speckle.
    mask_u8 = np.uint8(255 * np.clip(mask, 0, 1))
    heatmap = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    # OpenCV colour maps come out as BGR; convert to match the RGB image.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    img_uint8 = np.uint8(255 * np.clip(img, 0, 1))
    cam_img = heatmap * alpha + img_uint8 * (1 - alpha)
    cam_img = np.uint8(cam_img)
    return cam_img

def get_feature_map(model, image, layer_name, model_type):
    """Return a channel-averaged feature map from the model's last stage for one image.

    Args:
        model: Network whose last stage is hooked: ``model.blocks[-1]`` for "vit",
            ``model.layer4`` for "resnet", ``model.stages[-1]`` for "mobilevit".
        image: Single input tensor without a batch dimension; a batch axis is added.
        layer_name: Unused; kept for call compatibility.
        model_type: One of "vit", "resnet", "mobilevit".

    Returns:
        numpy array.
        - "resnet"/"mobilevit": the hooked output of sample 0 averaged over axis 0
          (channels).
        - "vit": leading prefix tokens dropped and the embedding axis averaged; the
          result is reshaped to a square grid when the patch count is a perfect
          square, and returned 1-D otherwise.

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    # Resolve the hook target lazily so only the attribute for this model type is read.
    hook_targets = {
        "vit": lambda: model.blocks[-1],
        "resnet": lambda: model.layer4,
        "mobilevit": lambda: model.stages[-1],
    }
    if model_type not in hook_targets:
        raise ValueError(f"Unsupported model type: {model_type}")

    model.eval()
    features = []

    def hook(module, inputs, output):
        features.append(output.detach())

    handle = hook_targets[model_type]().register_forward_hook(hook)
    try:
        with torch.no_grad():
            model(image.unsqueeze(0))
    finally:
        # Always detach the hook, even when the forward pass raises.
        handle.remove()

    feature_map = features[0][0].cpu().numpy()
    if model_type == "vit":
        # Drop leading non-patch tokens. Recent timm ViTs report their count via
        # `num_prefix_tokens`; fall back to a single class token otherwise.
        num_prefix = getattr(model, "num_prefix_tokens", 1)
        feature_map = feature_map[num_prefix:].mean(axis=1)
        side = int(round(np.sqrt(feature_map.shape[0])))
        if side > 0 and side * side == feature_map.shape[0]:
            feature_map = feature_map.reshape(side, side)
    else:
        feature_map = feature_map.mean(axis=0)

    return feature_map

def compute_accuracy(model, test_loader, device):
    """Return the fraction (0-1) of samples ``model`` classifies correctly over ``test_loader``.

    Args:
        model: Classifier whose dim-1 argmax is the predicted class.
        test_loader: Iterable of ``(images, labels)`` batches.
        device: Device the images are moved to before the forward pass.
    """
    model.eval()
    predicted, expected = [], []
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            scores = model(batch_x.to(device))
            predicted.extend(scores.max(1)[1].cpu().numpy())
            expected.extend(batch_y.cpu().numpy())
    return accuracy_score(expected, predicted)

def interactive_gradcam_viewer():
    """Interactively compare Grad-CAM overlays and feature maps of three model bundles.

    Steps:
    - Loads the bundles at the notebook-level paths ``vit_bundle_path``,
      ``resnet_bundle_path`` and ``mobilevit_bundle_path``.
    - Wraps each model in its Grad-CAM helper.
    - Prints each model's accuracy on ``datasets.ImageFolder(data_dir)``.
    - Shows an ``IntSlider`` that browses one random batch of 64 images from that folder.

    For the selected image a 2x4 grid is drawn:
    - top row: the original image and each model's Grad-CAM overlay;
    - bottom row: the original image and each model's last-stage feature map.

    A per-model prediction summary is printed under the grid.

    Returns:
        A zero-argument callable that detaches the Grad-CAM hooks. It is safe to call
        more than once. Call it when you are finished with the viewer; slider moves
        after that point no longer produce valid Grad-CAMs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading ViT model...")
    vit_model, vit_transform, _ = load_model_bundle(vit_bundle_path)
    vit_model = vit_model.to(device)

    print("Loading ResNet model...")
    resnet_model, resnet_transform, _ = load_model_bundle(resnet_bundle_path)
    resnet_model = resnet_model.to(device)

    print("Loading MobileViT model...")
    mobilevit_model, mobilevit_transform, _ = load_model_bundle(mobilevit_bundle_path)
    mobilevit_model = mobilevit_model.to(device)

    vit_gradcam = ViTGradCAM(vit_model)
    resnet_gradcam = ResNetGradCAM(resnet_model)
    mobilevit_gradcam = MobileViTGradCAM(mobilevit_model)

    hooks_removed = False

    def remove_gradcam_hooks():
        """Detach the Grad-CAM hooks from all three models (idempotent)."""
        nonlocal hooks_removed
        if hooks_removed:
            return
        hooks_removed = True
        vit_gradcam.remove_hooks()
        resnet_gradcam.remove_hooks()
        mobilevit_gradcam.remove_hooks()

    try:
        # NOTE(review): an earlier cell sets data_dir = "NSD", whose children are the
        # train/val/test split folders. If that is still its value here, ImageFolder
        # treats the splits themselves as classes — confirm data_dir points at the
        # intended (test) split before trusting the accuracies below.
        vit_dataset = datasets.ImageFolder(data_dir, transform=vit_transform)
        resnet_dataset = datasets.ImageFolder(data_dir, transform=resnet_transform)
        mobilevit_dataset = datasets.ImageFolder(data_dir, transform=mobilevit_transform)
        class_names = vit_dataset.classes

        batch_size = 32
        vit_loader = DataLoader(vit_dataset, batch_size=batch_size, shuffle=False)
        resnet_loader = DataLoader(resnet_dataset, batch_size=batch_size, shuffle=False)
        mobilevit_loader = DataLoader(mobilevit_dataset, batch_size=batch_size, shuffle=False)

        vit_acc = compute_accuracy(vit_model, vit_loader, device)
        resnet_acc = compute_accuracy(resnet_model, resnet_loader, device)
        mobilevit_acc = compute_accuracy(mobilevit_model, mobilevit_loader, device)
        print(f"ViT Test Accuracy: {vit_acc:.4f}")
        print(f"ResNet Test Accuracy: {resnet_acc:.4f}")
        print(f"MobileViT Test Accuracy: {mobilevit_acc:.4f}")

        # One random (unseeded) batch of candidates for the slider.
        img_loader = DataLoader(vit_dataset, batch_size=64, shuffle=True)
        images, labels = next(iter(img_loader))

        # ImageNet statistics used to undo normalization for display.
        # NOTE(review): assumes vit_transform normalizes with exactly these values.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])

        def update_view(idx=0):
            """Redraw the comparison grid for image ``idx`` of the sampled batch."""
            clear_output(wait=True)

            # NOTE(review): every model receives the vit_transform image. If
            # resnet_transform / mobilevit_transform differ, those models' predictions
            # and CAMs here use the wrong preprocessing.
            img = images[idx].to(device)
            label = labels[idx].item()
            batch = img.unsqueeze(0)

            with torch.no_grad():
                vit_pred = torch.argmax(vit_model(batch), dim=1).item()
                resnet_pred = torch.argmax(resnet_model(batch), dim=1).item()
                mobilevit_pred = torch.argmax(mobilevit_model(batch), dim=1).item()

            vit_cam, _ = vit_gradcam.generate_cam(batch, target_class=None)
            resnet_cam, _ = resnet_gradcam.generate_cam(batch, target_class=None)
            mobilevit_cam, _ = mobilevit_gradcam.generate_cam(batch, target_class=None)

            vit_feature = get_feature_map(vit_model, img, "blocks", model_type="vit")
            resnet_feature = get_feature_map(resnet_model, img, "layer4", model_type="resnet")
            mobilevit_feature = get_feature_map(mobilevit_model, img, "stages", model_type="mobilevit")

            # CHW normalized tensor -> HWC image in [0, 1] for display.
            img_np = img.cpu().numpy().transpose(1, 2, 0) * std + mean
            img_np = np.clip(img_np, 0, 1)

            vit_overlay = show_cam_on_image(img_np, vit_cam)
            resnet_overlay = show_cam_on_image(img_np, resnet_cam)
            mobilevit_overlay = show_cam_on_image(img_np, mobilevit_cam)

            fig, axes = plt.subplots(2, 4, figsize=(20, 10))
            panels = [
                (axes[0, 0], img_np, f"Original - Class: {class_names[label]}", None),
                (axes[0, 1], vit_overlay, f"ViT GradCAM - Pred: {class_names[vit_pred]}", None),
                (axes[0, 2], resnet_overlay, f"ResNet GradCAM - Pred: {class_names[resnet_pred]}", None),
                (axes[0, 3], mobilevit_overlay, f"MobileViT GradCAM - Pred: {class_names[mobilevit_pred]}", None),
                (axes[1, 0], img_np, "Original Image", None),
                (axes[1, 1], vit_feature, "ViT Feature Map", 'viridis'),
                (axes[1, 2], resnet_feature, "ResNet Feature Map", 'viridis'),
                (axes[1, 3], mobilevit_feature, "MobileViT Feature Map", 'viridis'),
            ]
            for ax, picture, title, cmap in panels:
                ax.imshow(picture, cmap=cmap)
                ax.set_title(title)
                ax.axis('off')

            plt.tight_layout()
            plt.show()
            # Each slider move creates a new figure; release it once shown.
            plt.close(fig)

            print(f"True Class: {class_names[label]}")
            print(f"ViT Prediction: {class_names[vit_pred]} {'✅' if vit_pred == label else '❌'}")
            print(f"ResNet Prediction: {class_names[resnet_pred]} {'✅' if resnet_pred == label else '❌'}")
            print(f"MobileViT Prediction: {class_names[mobilevit_pred]} {'✅' if mobilevit_pred == label else '❌'}")

        image_slider = widgets.IntSlider(min=0, max=len(images)-1, step=1, value=0, description='Image:')
        widgets.interact(update_view, idx=image_slider)
    except BaseException:
        # Setup failed, so no viewer will ever use the hooks: detach them, then re-raise.
        remove_gradcam_hooks()
        raise

    # widgets.interact returns right after the first render; update_view is called again
    # on every later slider move. Removing the hooks here (as before) silently broke
    # every CAM after the first one, so hand cleanup back to the caller instead.
    return remove_gradcam_hooks

if __name__ == "__main__":
    # Inside a Jupyter kernel __name__ is "__main__", so running this cell launches
    # the viewer; importing the code as a module does not.
    interactive_gradcam_viewer()