In [3]:
# =================================================================
# Cell 1: Setup, Imports, and Configuration
# =================================================================
import os
import time
import json
import pickle
import random
from collections import defaultdict

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

from sklearn.metrics import accuracy_score
from tqdm import tqdm
import matplotlib.pyplot as plt

# -----------------------
# CONFIGURATION
# -----------------------
class Config:
    # --- Paths ---
    AWA2_DIR = "/kaggle/input/awa2-data/Animals_with_Attributes2"
    OOD_DIR = "/kaggle/input/cifar10-python/cifar-10-batches-py"
    OUTPUT_DIR = "/kaggle/working/"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Data & Model ---
    IMAGE_SIZE = 224
    BATCH_SIZE = 64
    NUM_WORKERS = 2
    NUM_AWA2_CLASSES = 50
    
    # [FAIR EVAL] New 3-way data split ratios for AWA2
    AWA2_TRAIN_RATIO = 0.7 # 70% for training
    AWA2_VAL_RATIO = 0.1   # 10% for validation
    # The remaining 20% will be for testing

    # [FAIR EVAL] Define sizes for OOD validation and test sets
    NUM_OOD_VAL_IMAGES = 2000
    NUM_OOD_TEST_IMAGES = 10000

    # --- Training ---
    EPOCHS = 10
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-4

    # --- NCI Evaluation ---
    NCI_ALPHA = 1e-3 # Filtering strength for the L1 norm component

cfg = Config()
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
print("Config:", {k: v for k, v in cfg.__dict__.items() if not k.startswith('__')})
cudnn.benchmark = True


Config: {}


In [4]:
# =================================================================
# Cell 2: Data Loading and 3-Way Split (Train/Val/Test)
# =================================================================

data_tfms = transforms.Compose([
    transforms.Resize((cfg.IMAGE_SIZE, cfg.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

class ImagePathDataset(Dataset):
    def __init__(self, paths, labels, transform):
        self.paths, self.labels, self.transform = paths, labels, transform
    def __len__(self): return len(self.paths)
    def __getitem__(self, idx):
        p, y = self.paths[idx], self.labels[idx]
        try: img = Image.open(p).convert("RGB")
        except: img = Image.new("RGB", (cfg.IMAGE_SIZE, cfg.IMAGE_SIZE))
        return self.transform(img), y

class ImageObjectDataset(Dataset):
    def __init__(self, images, labels, transform):
        self.images, self.labels, self.transform = images, labels, transform
    def __len__(self): return len(self.images)
    def __getitem__(self, idx):
        return self.transform(self.images[idx]), self.labels[idx]

# --- 1. Load AWA2 Data and perform a 3-way split ---
print("\nLoading and splitting AWA2 data...")
df_classes = pd.read_csv(os.path.join(cfg.AWA2_DIR, 'classes.txt'), sep='\t', header=None, names=['id', 'class_name'])
id_to_class_name = {i: row['class_name'] for i, row in df_classes.iterrows()}
img_root = os.path.join(cfg.AWA2_DIR, "JPEGImages")
all_paths, all_labels = [], []
for cid, class_name in sorted(id_to_class_name.items(), key=lambda x: x[0]):
    folder = os.path.join(img_root, class_name)
    if not os.path.isdir(folder): continue
    for fname in os.listdir(folder):
        if fname.lower().endswith((".jpg", ".jpeg", ".png")):
            all_paths.append(os.path.join(folder, fname))
            all_labels.append(cid)

full_dataset = ImagePathDataset(all_paths, all_labels, transform=data_tfms)

# Calculate split sizes
train_size = int(cfg.AWA2_TRAIN_RATIO * len(full_dataset))
val_size = int(cfg.AWA2_VAL_RATIO * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

# Split the dataset
train_ds, val_ds_awa2, test_ds_awa2 = random_split(
    full_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)
print(f"AWA2 data split: {len(train_ds)} training, {len(val_ds_awa2)} validation, {len(test_ds_awa2)} testing.")

# --- 2. Load and split CIFAR-10 OOD Data ---
def unpickle(file):
    with open(file, 'rb') as fo: return pickle.load(fo, encoding='bytes')
all_ood_images = []
non_animal_labels = {0, 1, 8, 9} # airplane, automobile, ship, truck
for i in range(1, 6):
    p = os.path.join(cfg.OOD_DIR, f"data_batch_{i}")
    if not os.path.exists(p): continue
    d = unpickle(p)
    for j, label in enumerate(d[b'labels']):
        if label in non_animal_labels:
            all_ood_images.append(Image.fromarray(d[b'data'][j].reshape(3, 32, 32).transpose(1, 2, 0)))
random.seed(42); random.shuffle(all_ood_images)

# Create separate OOD sets for validation and testing
val_ood_images = all_ood_images[:cfg.NUM_OOD_VAL_IMAGES]
test_ood_images = all_ood_images[cfg.NUM_OOD_VAL_IMAGES : cfg.NUM_OOD_VAL_IMAGES + cfg.NUM_OOD_TEST_IMAGES]

val_ds_ood = ImageObjectDataset(val_ood_images, [-1]*len(val_ood_images), data_tfms)
test_ds_ood = ImageObjectDataset(test_ood_images, [-1]*len(test_ood_images), data_tfms)
print(f"CIFAR-10 OOD data split: {len(val_ds_ood)} validation, {len(test_ds_ood)} testing.")

# --- 3. Create Final Dataloaders ---
train_loader = DataLoader(train_ds, batch_size=cfg.BATCH_SIZE, shuffle=True, num_workers=cfg.NUM_WORKERS)

# Validation loader (mixed AWA2 val + OOD val)
val_loader = DataLoader(
    torch.utils.data.ConcatDataset([val_ds_awa2, val_ds_ood]),
    batch_size=cfg.BATCH_SIZE, shuffle=False, num_workers=cfg.NUM_WORKERS
)

# Test loader (mixed AWA2 test + OOD test)
test_loader = DataLoader(
    torch.utils.data.ConcatDataset([test_ds_awa2, test_ds_ood]),
    batch_size=cfg.BATCH_SIZE, shuffle=False, num_workers=cfg.NUM_WORKERS
)
print(f"Final loader sizes: Train={len(train_loader.dataset)}, Val={len(val_loader.dataset)}, Test={len(test_loader.dataset)}")



Loading and splitting AWA2 data...
AWA2 data split: 26125 training, 3732 validation, 7465 testing.
CIFAR-10 OOD data split: 2000 validation, 10000 testing.
Final loader sizes: Train=26125, Val=5732, Test=17465


In [7]:
# =================================================================
# Cell 3: Model Training
# =================================================================

def train_classifier():
    print("\n--- Training AWA2 Classifier (Post-Hoc Method) ---")
    model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, cfg.NUM_AWA2_CLASSES)
    model = model.to(cfg.DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(1, cfg.EPOCHS + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{cfg.EPOCHS}")
        for x, y in pbar:
            x, y = x.to(cfg.DEVICE), y.to(cfg.DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(loss=f"{loss.item():.4f}")
            
    torch.save(model.state_dict(), os.path.join(cfg.OUTPUT_DIR, "nci_classifier.pth"))
    print("Finished training and saved model.")
    return model

model_path = os.path.join(cfg.OUTPUT_DIR, "nci_classifier.pth")
if os.path.exists(model_path):
    print("\nFound existing classifier. Loading it.")
    model = resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, cfg.NUM_AWA2_CLASSES)
    model.load_state_dict(torch.load(model_path, map_location=cfg.DEVICE))
    model.to(cfg.DEVICE)
else:
    model = train_classifier()



--- Training AWA2 Classifier (Post-Hoc Method) ---


Epoch 1/10: 100%|██████████| 409/409 [05:28<00:00,  1.25it/s, loss=0.3067]
Epoch 2/10: 100%|██████████| 409/409 [03:36<00:00,  1.89it/s, loss=0.0816]
Epoch 3/10: 100%|██████████| 409/409 [03:38<00:00,  1.87it/s, loss=0.0161]
Epoch 4/10: 100%|██████████| 409/409 [03:35<00:00,  1.90it/s, loss=0.2816]
Epoch 5/10: 100%|██████████| 409/409 [03:36<00:00,  1.89it/s, loss=0.0379]
Epoch 6/10: 100%|██████████| 409/409 [03:37<00:00,  1.88it/s, loss=0.0240]
Epoch 7/10: 100%|██████████| 409/409 [03:38<00:00,  1.87it/s, loss=0.0217]
Epoch 8/10: 100%|██████████| 409/409 [03:39<00:00,  1.87it/s, loss=0.0116]
Epoch 9/10: 100%|██████████| 409/409 [03:36<00:00,  1.89it/s, loss=0.3028]
Epoch 10/10: 100%|██████████| 409/409 [03:35<00:00,  1.89it/s, loss=0.0148]


Finished training and saved model.


In [8]:
# =================================================================
# Cell 4: NCI Setup (Post-Hoc Analysis)
# =================================================================
@torch.no_grad()
def setup_nci(model, train_loader):
    print("\n--- Setting up NCI detector (Post-Hoc Analysis) ---")
    model.eval()
    feature_extractor = nn.Sequential(*list(model.children())[:-1])
    
    all_train_feats = []
    # Use the train_loader which points to the correct training dataset
    for x, y in tqdm(train_loader, desc="Extracting training features"):
        x = x.to(cfg.DEVICE)
        feats = feature_extractor(x)
        all_train_feats.append(torch.flatten(feats, 1))
    all_train_feats = torch.cat(all_train_feats, dim=0)
    
    mu_G = all_train_feats.mean(dim=0)
    print(f"Calculated global feature mean (mu_G) from training data.")
    
    w_c = model.fc.weight
    print(f"Extracted weight vectors (w_c) from trained model.")
    
    return feature_extractor, mu_G, w_c

feature_extractor, mu_G, w_c = setup_nci(model, train_loader)



--- Setting up NCI detector (Post-Hoc Analysis) ---


Extracting training features: 100%|██████████| 409/409 [03:08<00:00,  2.16it/s]

Calculated global feature mean (mu_G) from training data.
Extracted weight vectors (w_c) from trained model.





In [9]:
# =================================================================
# Cell 5: Find Optimal Threshold using the VALIDATION Set
# =================================================================
@torch.no_grad()
def find_optimal_threshold(model, feature_extractor, mu_G, w_c, val_loader, alpha):
    print("\n--- Finding optimal threshold on validation set ---")
    model.eval()
    feature_extractor.eval()
    
    all_nci_scores, all_true_labels = [], []
    OOD_LABEL = -1
    
    for x, y_true in tqdm(val_loader, desc="Analyzing validation scores"):
        x = x.to(cfg.DEVICE)
        logits = model(x)
        h = torch.flatten(feature_extractor(x), 1)
        predicted_classes = torch.argmax(logits, dim=1)
        h_centered = h - mu_G
        w_c_batch = w_c[predicted_classes]
        
        cos_sim = F.cosine_similarity(h_centered, w_c_batch)
        w_c_norm = torch.norm(w_c_batch, p=2, dim=1)
        pScore = cos_sim * w_c_norm
        l1_norm = torch.norm(h, p=1, dim=1)
        nci_score = pScore + alpha * l1_norm
        
        all_nci_scores.extend(nci_score.cpu().numpy())
        all_true_labels.extend(y_true.numpy())

    all_nci_scores = np.array(all_nci_scores)
    all_true_labels = np.array(all_true_labels)
    
    id_mask = all_true_labels != OOD_LABEL
    ood_mask = all_true_labels == OOD_LABEL
    
    id_scores_mean = all_nci_scores[id_mask].mean()
    ood_scores_mean = all_nci_scores[ood_mask].mean()
    
    optimal_threshold = (id_scores_mean + ood_scores_mean) / 2.0
    print(f"\nValidation ID Scores Mean: {id_scores_mean:.4f}")
    print(f"Validation OOD Scores Mean: {ood_scores_mean:.4f}")
    print(f"==> Optimal Threshold Found: {optimal_threshold:.4f}")
    
    return optimal_threshold

# Find the best threshold using ONLY the validation data
optimal_threshold = find_optimal_threshold(model, feature_extractor, mu_G, w_c, val_loader, cfg.NCI_ALPHA)



--- Finding optimal threshold on validation set ---


Analyzing validation scores: 100%|██████████| 90/90 [00:57<00:00,  1.56it/s]


Validation ID Scores Mean: 1.6152
Validation OOD Scores Mean: 1.1511
==> Optimal Threshold Found: 1.3831





In [10]:
# =================================================================
# Cell 6: Final, Unbiased Evaluation on the TEST Set (Updated)
# =================================================================
@torch.no_grad()
def final_evaluation(model, feature_extractor, mu_G, w_c, test_loader, threshold, alpha):
    print("\n--- Final Evaluation on UNSEEN Test Set ---")
    model.eval()
    feature_extractor.eval()

    all_nci_scores, all_true_labels, all_initial_preds = [], [], []
    OOD_LABEL = -1

    for x, y_true in tqdm(test_loader, desc="Final evaluation"):
        x = x.to(cfg.DEVICE)
        logits = model(x)
        h = torch.flatten(feature_extractor(x), 1)
        predicted_classes = torch.argmax(logits, dim=1)
        h_centered = h - mu_G
        w_c_batch = w_c[predicted_classes]

        cos_sim = F.cosine_similarity(h_centered, w_c_batch)
        w_c_norm = torch.norm(w_c_batch, p=2, dim=1)
        pScore = cos_sim * w_c_norm
        l1_norm = torch.norm(h, p=1, dim=1)
        nci_score = pScore + alpha * l1_norm

        all_nci_scores.extend(nci_score.cpu().numpy())
        all_true_labels.extend(y_true.numpy())
        all_initial_preds.extend(predicted_classes.cpu().numpy())

    all_nci_scores = np.array(all_nci_scores)
    all_true_labels = np.array(all_true_labels)
    all_initial_preds = np.array(all_initial_preds)

    # --- Create final predictions based on the threshold ---
    final_preds = all_initial_preds.copy()
    ood_mask_pred = all_nci_scores < threshold
    final_preds[ood_mask_pred] = OOD_LABEL

    # --- Masks for separating true ID and OOD samples ---
    id_true_mask = all_true_labels != OOD_LABEL
    ood_true_mask = all_true_labels == OOD_LABEL

    # --- [NEW] Detailed Accuracy Calculation ---
    print("\n--- FINAL RESULTS ---")

    # 1. OOD Rejection Accuracy (same as before)
    ood_accuracy = accuracy_score(all_true_labels[ood_true_mask], final_preds[ood_true_mask])
    print(f"1. Out-of-Distribution (OOD) Rejection Accuracy: {ood_accuracy:.2%}")
    print("   (How well the model rejects non-animals)")

    # 2. ID Detection Accuracy (New Metric)
    # This checks how many true animals were correctly NOT labeled as OOD.
    id_preds_for_detection = final_preds[id_true_mask]
    # A correct prediction is anything that is NOT OOD (-1)
    id_detection_accuracy = np.mean(id_preds_for_detection != OOD_LABEL)
    print(f"\n2. In-Distribution (ID) Detection Accuracy: {id_detection_accuracy:.2%}")
    print("   (How well the model recognizes an animal is an animal)")

    # 3. ID Conditional Classification Accuracy (New Metric)
    # First, get only the animals that were correctly detected as ID
    correctly_detected_as_id_mask = (final_preds != OOD_LABEL) & (id_true_mask)
    # Then, calculate classification accuracy on just this subset
    conditional_class_accuracy = accuracy_score(
        all_true_labels[correctly_detected_as_id_mask],
        final_preds[correctly_detected_as_id_mask]
    )
    print(f"\n3. ID Conditional Classification Accuracy: {conditional_class_accuracy:.2%}")
    print("   (Of the animals it correctly identified as 'animal', how many did it classify correctly?)")


# --- Run the final evaluation using the optimal threshold found on the validation set ---
# Make sure 'optimal_threshold' has been calculated from the previous cell
final_evaluation(model, feature_extractor, mu_G, w_c, test_loader, optimal_threshold, cfg.NCI_ALPHA)




--- Final Evaluation on UNSEEN Test Set ---


Final evaluation: 100%|██████████| 273/273 [02:11<00:00,  2.07it/s]


--- FINAL RESULTS ---
1. Out-of-Distribution (OOD) Rejection Accuracy: 95.16%
   (How well the model rejects non-animals)

2. In-Distribution (ID) Detection Accuracy: 92.18%
   (How well the model recognizes an animal is an animal)

3. ID Conditional Classification Accuracy: 92.21%
   (Of the animals it correctly identified as 'animal', how many did it classify correctly?)



