In [None]:
# Mount Google Drive so checkpoints and figures persist across Colab sessions.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
from torch import device
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import os
from sklearn.metrics import roc_curve, precision_recall_curve

# Seed torch and NumPy so runs are reproducible.
SEED = 47
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(SEED)

In [None]:
class ResUnit(nn.Module):
    """Two-convolution residual block.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    Shortcut: identity, or conv1x1(stride) -> BN when the stride or the channel
    count changes, so both paths produce tensors of the same shape.
    Output: ReLU(main + shortcut).

    The attribute names (c1, b1, c2, b2, shortcut) and their registration order
    define the state_dict keys and parameter order of saved checkpoints, so they
    must not change.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.c1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.b1 = nn.BatchNorm2d(planes)
        self.c2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.b2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        # Project the input only when the main path changes resolution or width.
        changes_shape = stride != 1 or in_planes != planes
        if changes_shape:
            projection = [
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        hidden = self.relu(self.b1(self.c1(x)))
        hidden = self.b2(self.c2(hidden))
        return self.relu(hidden + self.shortcut(x))

class ResNet18(nn.Module):
    """CIFAR-style ResNet-18 that returns both penultimate features and logits.

    The stem is a single 3x3 stride-1 convolution with no max-pool, suited to
    32x32 inputs. There are four stages of two `ResUnit` blocks each, with
    64/128/256/512 channels; stages 2-4 downsample by 2 in their first block.

    forward(x) -> (feat, logits): `feat` is the globally average-pooled,
    flattened 512-d feature vector, and `logits` is `fc(feat)`.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.in_planes = 64

        # Stem (modified for 32x32 inputs such as CIFAR).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # (channels, first-block stride) per stage. Names layer1..layer4 are part
        # of the checkpoint state_dict keys.
        stage_cfg = [(64, 1), (128, 2), (256, 2), (512, 2)]
        for stage_idx, (planes, first_stride) in enumerate(stage_cfg, start=1):
            setattr(self, f"layer{stage_idx}", self._make_layer(planes, 2, stride=first_stride))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        """Build one stage: the first block uses `stride`, the remaining ones stride 1."""
        blocks = []
        for block_stride in [stride, *([1] * (num_blocks - 1))]:
            blocks.append(ResUnit(self.in_planes, planes, block_stride))
            self.in_planes = planes
        return nn.Sequential(*blocks)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            hidden = stage(hidden)
        feat = torch.flatten(self.avgpool(hidden), 1)
        return feat, self.fc(feat)

# Training the model

In [None]:
# Optimization hyperparameters.
lr = 0.01
batch_size = 128
epochs = 400
weight_decay = 5e-4

# Loss function and compute device.
crit = nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Model, SGD (momentum 0.9) and a cosine learning-rate schedule spanning all epochs.
resnet = ResNet18(num_classes=100).to(device)
opti = torch.optim.SGD(
    resnet.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opti, T_max=epochs)

In [None]:
# Per-channel normalization shared by the training and evaluation pipelines.
# NOTE(review): these values look like the commonly used CIFAR-10 statistics rather
# than CIFAR-100's. They are kept as-is because the saved checkpoints were trained
# with exactly this preprocessing.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: pad-and-crop plus horizontal-flip augmentation.
transf = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: no augmentation, same normalization.
tr = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(NORM_MEAN, NORM_STD),
])

In [None]:
# CIFAR-100: augmented train split, and the test split used for validation.
train_data = torchvision.datasets.CIFAR100(
    root="./data",
    train=True,
    transform=transf,
    download=True,
)
val_data = torchvision.datasets.CIFAR100(
    root="./data",
    train=False,
    transform=tr,
    download=True,
)

In [None]:
# Only the training stream is shuffled; validation keeps a fixed order.
train_dataloader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True, num_workers=2
)
val_dataloader = torch.utils.data.DataLoader(
    val_data, batch_size=batch_size, shuffle=False, num_workers=2
)

In [None]:
import os

# Destination directory for the periodic training checkpoints.
MODELS_DIR = "/content/drive/MyDrive/models"
os.makedirs(MODELS_DIR, exist_ok=True)

In [None]:
def _train_one_epoch(model, crit, opti, loader, device):
    """Run one optimization pass over `loader`; return (mean per-batch loss, accuracy)."""
    model.train()
    running_loss, correct, total, n_batches = 0.0, 0, 0, 0
    for inputs, labels in tqdm(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        _, outputs = model(inputs)
        loss = crit(outputs, labels)

        opti.zero_grad()
        loss.backward()
        opti.step()

        running_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
        n_batches += 1
    if n_batches == 0:
        raise ValueError("train_dataloader yielded no batches")
    return running_loss / n_batches, correct / total


@torch.no_grad()
def _evaluate(model, crit, loader, device):
    """Evaluate `model` on `loader` without gradients; return (mean per-batch loss, accuracy)."""
    model.eval()
    running_loss, correct, total, n_batches = 0.0, 0, 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        _, outputs = model(inputs)
        running_loss += crit(outputs, labels).item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
        n_batches += 1
    if n_batches == 0:
        raise ValueError("val_dataloader yielded no batches")
    return running_loss / n_batches, correct / total


def train(resnet, crit, opti, scheduler, train_dataloader, val_dataloader, epochs, save=True,
          save_dir="/content/drive/MyDrive/models", save_every=10):
    """Train `resnet`, validating after every epoch and checkpointing periodically.

    Args:
        resnet: model whose forward returns ``(features, logits)``. Batches are moved
            to the device of its parameters.
        crit: loss applied to ``(logits, labels)``.
        opti: optimizer over ``resnet``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        train_dataloader: iterable of ``(inputs, labels)`` training batches.
        val_dataloader: iterable of ``(inputs, labels)`` validation batches.
        epochs: number of epochs to run.
        save: if True, write checkpoints to ``save_dir``.
        save_dir: directory for ``resnet_{epoch}_epoch.pth`` files (created if missing).
        save_every: write a checkpoint every ``save_every`` epochs. The last epoch is
            always checkpointed too, so the final weights are never lost.

    Returns:
        ``(losses, accuracies, val_losses, val_accuracies)``: per-epoch lists of the
        mean per-batch loss and the fraction of correct predictions.

    Raises:
        ValueError: if either dataloader yields no batches.
    """
    # Use the model's own device rather than a notebook-global `device`.
    device = next(resnet.parameters()).device

    losses, accuracies = [], []
    val_losses, val_accuracies = [], []

    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(resnet, crit, opti, train_dataloader, device)
        scheduler.step()
        losses.append(train_loss)
        accuracies.append(train_acc)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")

        val_loss_avg, val_acc = _evaluate(resnet, crit, val_dataloader, device)
        val_losses.append(val_loss_avg)
        val_accuracies.append(val_acc)
        print(f"Validation Loss: {val_loss_avg:.4f}, Validation Accuracy: {val_acc:.4f}")

        is_last_epoch = epoch + 1 == epochs
        if save and ((epoch + 1) % save_every == 0 or is_last_epoch):
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f"resnet_{epoch+1}_epoch.pth")
            scheduler_state = scheduler.state_dict()
            torch.save({
                'model_state_dict': resnet.state_dict(),
                'optimizer_state_dict': opti.state_dict(),
                # 'sheduler' (sic) is the key existing loaders read; the correctly
                # spelled key is stored as well for new code.
                'sheduler_state_dict': scheduler_state,
                'scheduler_state_dict': scheduler_state,
                'epoch': epoch + 1,
                'loss': train_loss,
                'train_losses': losses,
                'train_accuracies': accuracies,
                'val_losses': val_losses,
                'val_accuracies': val_accuracies,
            }, save_path)
            print(f"Successfully saved checkpoint to {save_path}")

    return losses, accuracies, val_losses, val_accuracies

Uncomment the call below to train the model from scratch.

In [None]:
#losses, accuracies, val_losses, val_accuracies = train(resnet,crit,opti,scheduler,train_dataloader,val_dataloader,epochs)

If you did not train the model, load the final checkpoint instead.

In [None]:
# Rebuild the network and restore the final checkpoint written by `train`
# (the filename encodes the epoch count, so it follows the `epochs` hyperparameter).
resnet = ResNet18(num_classes=100).to(device)
checkpoint_path = f"/content/drive/MyDrive/models/resnet_{epochs}_epoch.pth"
checkpoint = torch.load(checkpoint_path, map_location=device)
resnet.load_state_dict(checkpoint['model_state_dict'])

# Rebind `opti` (the name passed to `train`) to an optimizer over the *new* model's
# parameters. The old `opti` still steps the previous ResNet18 instance, so resuming
# training with it would silently leave `resnet` untouched. Constructor hyperparameters
# mirror the training cell; load_state_dict restores the saved ones anyway.
opti = torch.optim.SGD(resnet.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
opti.load_state_dict(checkpoint['optimizer_state_dict'])
optimizer = opti  # alias kept for code that refers to the previous name

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opti, T_max=epochs)
# 'sheduler' (sic) is the key under which the checkpoints store the scheduler state.
scheduler.load_state_dict(checkpoint['sheduler_state_dict'])

# Training curves stored alongside the weights.
losses = checkpoint['train_losses']
accuracies = checkpoint['train_accuracies']
val_losses = checkpoint['val_losses']
val_accuracies = checkpoint['val_accuracies']

In [None]:
def plot_metrics(losses, accuracies, val_losses, val_accuracies,path = "/content/drive/MyDrive/metrics/training_metrics"):
    """Plot per-epoch train/test accuracy and loss side by side, then save and show it.

    Args:
        losses, accuracies: per-epoch training loss and accuracy.
        val_losses, val_accuracies: per-epoch test-set loss and accuracy.
        path: output file. When it has no extension, matplotlib appends the
            `savefig.format` extension; the printed path includes it. Its parent
            directory is created if needed.
    """
    epoch_axis = range(1, len(losses) + 1)
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # Accuracy
    acc_ax.plot(epoch_axis, accuracies, 'b', label='Training Acc')
    acc_ax.plot(epoch_axis, val_accuracies, 'r', label='Test Acc')
    acc_ax.set_title('Training and Test Accuracy')
    acc_ax.set_xlabel('Epochs')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()
    acc_ax.grid(True)

    # Loss on a log scale: late-training losses span several orders of magnitude.
    loss_ax.semilogy(epoch_axis, losses, 'b', label='Training Loss')
    loss_ax.semilogy(epoch_axis, val_losses, 'r', label='Test Loss')
    loss_ax.set_title('Training and Test Loss (Log Scale)')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    fig.tight_layout()

    # Report the file matplotlib actually writes (it appends the default extension).
    if not os.path.splitext(path)[1]:
        path = f"{path}.{plt.rcParams['savefig.format']}"

    # os.makedirs('') raises, so only create a directory when the path has one.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    fig.savefig(path, dpi=300, bbox_inches='tight')
    print(f"Metrics plot saved to: {path}")
    plt.show()

In [None]:
plot_metrics(losses, accuracies, val_losses, val_accuracies)


# OOD SCORES

In [None]:
@torch.no_grad()
def extract_features_and_logits(model, dataloader, device):
    """Run `model` over `dataloader` and collect features, logits and labels.

    Assumes `model(images)` returns `(features, logits)`, as `ResNet18.forward` does.
    Side effects: puts `model` in eval mode and moves it to `device`.

    Returns:
        (features, logits, labels) as NumPy arrays concatenated over all batches,
        in the order the dataloader yields them.
    """
    model.eval()
    model.to(device)

    feature_batches, logit_batches, label_batches = [], [], []
    for images, labels in tqdm(dataloader, desc="Extracting"):
        batch_features, batch_logits = model(images.to(device))
        feature_batches.append(batch_features.detach().cpu().numpy())
        logit_batches.append(batch_logits.detach().cpu().numpy())
        label_batches.append(labels.numpy())

    return tuple(
        np.concatenate(chunks)
        for chunks in (feature_batches, logit_batches, label_batches)
    )

In [None]:
from scipy.special import logsumexp, softmax
from scipy.linalg import pinv
from sklearn.covariance import EmpiricalCovariance
from sklearn.metrics import roc_auc_score, roc_curve, average_precision_score

class OODScorer:
    """Post-hoc OOD scores computed from penultimate features and logits.

    For every method, a higher score means "more in-distribution" (`evaluate_ood`
    negates the scores so that OOD is the positive class).

    Args:
        W: last-layer weight matrix, shape (num_classes, d), as `resnet.fc.weight`.
        b: last-layer bias, shape (num_classes,).
        num_classes: number of ID classes; labels are assumed to be 0..num_classes-1.
    """

    def __init__(self, W, b, num_classes):
        self.W = W
        self.b = b
        self.num_classes = num_classes

    def fit(self, train_features, train_logits, train_labels, vim_dim=64):
        """Estimate the Mahalanobis and ViM statistics from ID training data.

        Args:
            train_features: (N, d) penultimate features.
            train_logits: (N, num_classes) logits.
            train_labels: (N,) integer labels.
            vim_dim: dimension D of ViM's principal subspace. The residual subspace
                is the remaining d - D directions, so D must satisfy 0 < D < d.

        Raises:
            ValueError: if `vim_dim` is outside (0, d).
        """
        train_features = np.asarray(train_features, dtype=np.float64)
        train_logits = np.asarray(train_logits)
        train_labels = np.asarray(train_labels)
        d = train_features.shape[1]
        # Without this check, vim_dim > d gives a negative slice and a wrong subspace.
        if not 0 < vim_dim < d:
            raise ValueError(f"vim_dim must be in (0, {d}), got {vim_dim}")

        # Mahalanobis: class means plus a tied covariance of class-centered features.
        self.means = np.array([train_features[train_labels == c].mean(axis=0) for c in range(self.num_classes)])
        centered_features = train_features - self.means[train_labels]
        # Same as sklearn EmpiricalCovariance(assume_centered=True): the second moment
        # and its (symmetric) pseudo-inverse.
        within_cov = centered_features.T @ centered_features / centered_features.shape[0]
        self.precision = np.linalg.pinv(within_cov, hermitian=True)

        # ViM: shift the origin to u = -W^+ b. The principal and residual subspaces are
        # defined from the second moment of (h - u), i.e. X^T X, not from the
        # mean-centered covariance that np.cov would compute.
        self.u = -pinv(self.W) @ self.b
        X = train_features - self.u
        second_moment = X.T @ X / X.shape[0]
        _, eig_vecs = np.linalg.eigh(second_moment)  # eigenvalues in ascending order

        # The d - D smallest-eigenvalue directions form the residual subspace.
        self.residual_subspace = eig_vecs[:, : d - vim_dim]

        res_norms = np.linalg.norm(X @ self.residual_subspace, axis=-1)
        self.alpha = np.mean(np.max(train_logits, axis=-1)) / np.mean(res_norms)

    def score(self, features, logits):
        """Compute all scores for a batch.

        Args:
            features: (N, d) penultimate features.
            logits: (N, num_classes) logits.

        Returns:
            dict mapping 'msp', 'mls', 'energy', 'mahalanobis', 'vim' to (N,) arrays.
        """
        features = np.asarray(features, dtype=np.float64)
        logits = np.asarray(logits)
        scores = {}

        # 1. MSP: maximum softmax probability
        scores['msp'] = np.max(softmax(logits, axis=1), axis=1)

        # 2. MLS: maximum logit
        scores['mls'] = np.max(logits, axis=1)

        # 3. Energy: log-sum-exp of the logits
        scores['energy'] = logsumexp(logits, axis=1)

        # 4. Mahalanobis: -min_c (x - mu_c)^T P (x - mu_c), vectorized through
        #    x^T P x - x^T (P + P^T) mu_c + mu_c^T P mu_c instead of a per-sample loop.
        P = self.precision
        quad_x = np.sum((features @ P) * features, axis=1)
        cross = features @ (P + P.T) @ self.means.T
        quad_mu = np.sum((self.means @ P) * self.means, axis=1)
        dists = quad_x[:, None] - cross + quad_mu[None, :]
        scores['mahalanobis'] = -dists.min(axis=1)

        # 5. ViM: energy minus the scaled norm of the residual-space component.
        res_norms = np.linalg.norm((features - self.u) @ self.residual_subspace, axis=-1)
        scores['vim'] = scores['energy'] - self.alpha * res_norms

        return scores

In [None]:
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve

def evaluate_ood(id_scores, ood_scores, device=None):
    """Standard OOD-detection metrics for one scoring method.

    Input scores follow the convention "higher = more in-distribution", as produced
    by `OODScorer.score` and `neco_scores`. They are negated internally so that OOD
    is the positive class.

    Args:
        id_scores: 1-D array-like of scores on ID samples.
        ood_scores: 1-D array-like of scores on OOD samples.
        device: unused; kept so existing callers that pass it keep working.

    Returns:
        dict with "AUROC", "FPR95", "AUPR_In" and "AUPR_Out", all in percent.
        FPR95 is the fraction of ID samples flagged as OOD at the first ROC
        threshold where at least 95% of OOD samples are detected.
    """
    # Accept lists/tuples too: unary minus below needs an ndarray.
    id_scores = np.asarray(id_scores, dtype=float)
    ood_scores = np.asarray(ood_scores, dtype=float)

    # Labels: 0 for ID, 1 for OOD
    labels = np.concatenate([np.zeros(len(id_scores)), np.ones(len(ood_scores))])

    # Negate so that higher values mean "more OOD" (OOD is the positive class).
    scores = np.concatenate([-id_scores, -ood_scores])

    # AUROC (symmetric in the choice of positive class)
    auroc = roc_auc_score(labels, scores)

    # AUPR-Out: precision-recall area with OOD as the positive class
    aupr_out = average_precision_score(labels, scores)

    # AUPR-In: precision-recall area with ID as the positive class (original scores)
    aupr_in = average_precision_score(1 - labels, -scores)

    # FPR95: ID false-positive rate at the first threshold reaching TPR >= 95% on OOD
    fpr, tpr, _ = roc_curve(labels, scores)
    fpr95 = fpr[np.argmax(tpr >= 0.95)]

    return {
        "AUROC": auroc * 100,
        "FPR95": fpr95 * 100,
        "AUPR_In": aupr_in * 100,
        "AUPR_Out": aupr_out * 100
    }

In [None]:
from sklearn.metrics import precision_recall_curve, auc

def plot_comparisons(id_scores_dict, ood_scores_dict, metrics_dict,
                     save_path="/content/drive/MyDrive/metrics/full_ood_benchmark.png"):
    """Grid comparison of OOD scoring methods, one column per method.

    Rows:
        1. ID vs OOD score histograms.
        2. ROC curve (OOD positive, as in `evaluate_ood`) with the FPR95 point.
        3. PR-In curve (ID positive).
        4. PR-Out curve (OOD positive).

    Args:
        id_scores_dict / ood_scores_dict: method name -> 1-D scores (higher = more ID).
        metrics_dict: method name -> dict from `evaluate_ood` (values in percent).
        save_path: output file; its parent directory is created if missing.
    """
    methods = list(id_scores_dict.keys())
    n_methods = len(methods)

    # squeeze=False keeps `axes` 2-D even when a single method is plotted.
    fig, axes = plt.subplots(4, n_methods, figsize=(4 * n_methods, 16), squeeze=False)
    fig.suptitle("Full OOD Benchmark: ID vs OOD Perspective", fontsize=22, y=1.02)

    for i, method in enumerate(methods):
        id_s = np.asarray(id_scores_dict[method])
        ood_s = np.asarray(ood_scores_dict[method])
        met = metrics_dict[method]

        # ID-centric labels/scores (ID = 1, higher score = more ID)
        labels_in = np.concatenate([np.ones(len(id_s)), np.zeros(len(ood_s))])
        scores_in = np.concatenate([id_s, ood_s])

        # OOD-centric labels/scores (OOD = 1), the convention of evaluate_ood
        labels_out = 1 - labels_in
        scores_out = -scores_in

        # Row 1: score distributions
        ax = axes[0, i]
        ax.hist(id_s, bins=50, alpha=0.5, label='ID', density=True, color='skyblue')
        ax.hist(ood_s, bins=50, alpha=0.5, label='OOD', density=True, color='salmon')
        ax.set_title(f"{method.upper()}\nDistributions")
        if i == 0:
            ax.legend()

        # Row 2: ROC curve. FPR95 is stored in percent, so rescale it to the [0, 1] axis.
        ax = axes[1, i]
        fpr, tpr, _ = roc_curve(labels_out, scores_out)
        ax.plot(fpr, tpr, color='blue', label=f"AUROC: {met['AUROC']:.3f}")
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.2)
        ax.scatter(met['FPR95'] / 100, 0.95, color='red', label=f"FPR95: {met['FPR95']:.3f}")
        ax.set_xlabel("FPR")
        ax.set_ylabel("TPR")
        ax.set_title(f"{method.upper()} ROC")
        ax.legend(loc="lower right")

        # Row 3: PR-In (ID is the positive class)
        ax = axes[2, i]
        prec_in, rec_in, _ = precision_recall_curve(labels_in, scores_in)
        ax.plot(rec_in, prec_in, color='green', label=f"AUPR-In: {met['AUPR_In']:.3f}")
        ax.set_title("PR-In (ID=Pos)")
        ax.set_xlabel("Recall (ID)")
        ax.set_ylabel("Precision (ID)")
        ax.legend(loc="lower left")

        # Row 4: PR-Out (OOD is the positive class)
        ax = axes[3, i]
        prec_out, rec_out, _ = precision_recall_curve(labels_out, scores_out)
        ax.plot(rec_out, prec_out, color='purple', label=f"AUPR-Out: {met['AUPR_Out']:.3f}")
        ax.set_title("PR-Out (OOD=Pos)")
        ax.set_xlabel("Recall (OOD)")
        ax.set_ylabel("Precision (OOD)")
        ax.legend(loc="lower left")

    fig.tight_layout()

    directory = os.path.dirname(save_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # bbox_inches='tight' keeps the suptitle (placed at y=1.02, above the figure) in the file.
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# We remove augmentation on train_data for the continuation of our analysis
train_data = torchvision.datasets.CIFAR100(
    root="./data", train=True, download=True, transform=tr
)
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# SVHN test split as the OOD set, preprocessed exactly like the ID evaluation data.
svhn_test = torchvision.datasets.SVHN(
    root="./data",
    split="test",
    transform=tr,
    download=True,
)

In [None]:
print(*(len(dataset) for dataset in (val_data, svhn_test)))

In [None]:
# We want the same number of id and ood for better interpretation of metrics
# Subsample SVHN to the size of the ID test set so ID and OOD are balanced, which
# makes AUPR-In and AUPR-Out directly comparable.
from torch.utils.data import Subset

N = min(len(val_data), len(svhn_test))
g = torch.Generator().manual_seed(42)
indices = torch.randperm(len(svhn_test), generator=g)[:N]
svhn_subset = Subset(svhn_test, indices)

In [None]:
# Fixed order for the OOD subset, mirroring the ID validation loader.
svhn_test_loader = torch.utils.data.DataLoader(
    svhn_subset, batch_size=batch_size, num_workers=2, shuffle=False
)

In [None]:
# Features and logits of the final model for the ID train set (used to fit the
# scorers), the ID test set and the OOD test subset.
print("Extracting ID Train Statistics (CIFAR-100)...")
train_features, train_logits, train_labels = extract_features_and_logits(
    model=resnet, dataloader=train_dataloader, device=device
)

print("Extracting ID Test Data (CIFAR-100)...")
id_features, id_logits, _ = extract_features_and_logits(
    model=resnet, dataloader=val_dataloader, device=device
)

print("Extracting OOD Test Data (SVHN)...")
ood_features, ood_logits, _ = extract_features_and_logits(
    model=resnet, dataloader=svhn_test_loader, device=device
)

In [None]:
print("Fitting the OOD Scorer...")
W, b = (param.detach().cpu().numpy() for param in (resnet.fc.weight, resnet.fc.bias))
scorer = OODScorer(W, b, num_classes=100)
scorer.fit(train_features, train_logits, train_labels)

print("Calculating Scores...")
id_scores = scorer.score(id_features, id_logits)
ood_scores = scorer.score(ood_features, ood_logits)

print("Evaluating Scores...")
all_metrics_evaluated = {}
for method_name, method_id_scores in id_scores.items():
    current_metrics = evaluate_ood(method_id_scores, ood_scores[method_name], device)
    all_metrics_evaluated[method_name] = current_metrics
    print(f"Metrics for {method_name.upper()}: {current_metrics}")

In [None]:
plot_comparisons(id_scores_dict=id_scores, ood_scores_dict=ood_scores, metrics_dict=all_metrics_evaluated)

In [None]:
def save_individual_plots(id_scores_dict, ood_scores_dict, metrics_dict,
                          save_dir="/content/drive/MyDrive/metrics"):
    """Save one 2x2 diagnostic figure per OOD scoring method.

    Panels: score histograms, ROC curve (OOD positive, as in `evaluate_ood`) with
    the FPR95 point, PR-In (ID positive) and PR-Out (OOD positive).
    Figures are written to `save_dir/<method>_analysis.png` and closed afterwards.

    Args:
        id_scores_dict / ood_scores_dict: method name -> 1-D scores (higher = more ID).
        metrics_dict: method name -> dict from `evaluate_ood` (values in percent).
        save_dir: output directory (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)

    for method in id_scores_dict:
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle(f"OOD Benchmark: {method.upper()}", fontsize=18)

        id_s = np.asarray(id_scores_dict[method])
        ood_s = np.asarray(ood_scores_dict[method])
        met = metrics_dict[method]

        # ID-centric labels/scores (ID = 1, higher score = more ID)
        labels_in = np.concatenate([np.ones(len(id_s)), np.zeros(len(ood_s))])
        scores_in = np.concatenate([id_s, ood_s])

        # OOD-centric labels/scores (OOD = 1), the convention of evaluate_ood
        labels_out = 1 - labels_in
        scores_out = -scores_in

        # Top left: score distributions
        ax = axes[0, 0]
        ax.hist(id_s, bins=50, alpha=0.5, label='ID', density=True, color='skyblue')
        ax.hist(ood_s, bins=50, alpha=0.5, label='OOD', density=True, color='salmon')
        ax.set_title("Score Distributions")
        ax.legend()

        # Top right: ROC with OOD positive, so the FPR95 point (stored in percent,
        # rescaled to [0, 1]) lies on the plotted curve.
        ax = axes[0, 1]
        fpr, tpr, _ = roc_curve(labels_out, scores_out)
        ax.plot(fpr, tpr, color='blue', label=f"AUROC: {met['AUROC']:.3f}")
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.2)
        ax.scatter(met['FPR95'] / 100, 0.95, color='red', label=f"FPR95: {met['FPR95']:.3f}")
        ax.set_xlabel("FPR")
        ax.set_ylabel("TPR")
        ax.set_title("ROC Curve")
        ax.legend(loc="lower right")

        # Bottom left: PR-In (ID is the positive class)
        ax = axes[1, 0]
        prec_in, rec_in, _ = precision_recall_curve(labels_in, scores_in)
        ax.plot(rec_in, prec_in, color='green', label=f"AUPR-In: {met['AUPR_In']:.3f}")
        ax.set_xlabel("Recall (ID)")
        ax.set_ylabel("Precision (ID)")
        ax.set_title("PR-In (ID = Positive)")
        ax.legend(loc="lower left")

        # Bottom right: PR-Out (OOD is the positive class)
        ax = axes[1, 1]
        prec_out, rec_out, _ = precision_recall_curve(labels_out, scores_out)
        ax.plot(rec_out, prec_out, color='purple', label=f"AUPR-Out: {met['AUPR_Out']:.3f}")
        ax.set_xlabel("Recall (OOD)")
        ax.set_ylabel("Precision (OOD)")
        ax.set_title("PR-Out (OOD = Positive)")
        ax.legend(loc="lower left")

        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save, then close to free memory when many figures are generated.
        file_path = os.path.join(save_dir, f"{method.lower()}_analysis.png")
        fig.savefig(file_path, dpi=300)
        plt.close(fig)
        print(f"Saved: {file_path}")

In [None]:
save_individual_plots(id_scores, ood_scores, all_metrics_evaluated)

# NC Analysis

Variability collapse: the within-class variation converges towards 0,
$$
\Sigma_W \to 0.
$$
In practice we measure
$$
\mathrm{NC1} = \frac{1}{C}\operatorname{Tr}\big[\Sigma_W \Sigma_B^{\dagger}\big] \to 0
$$
with

- $\Sigma_W$ (Within-class): $\frac{1}{n} \sum_{c=1}^C \sum_{i=1}^{n_c} (h_{i,c} - \mu_c)(h_{i,c} - \mu_c)^\top$
- $\Sigma_B$ (Between-class): $\frac{1}{C} \sum_{c=1}^C (\mu_c - \mu_G)(\mu_c - \mu_G)^\top$
- $\Sigma_B^{\dagger}$: the Moore–Penrose pseudo-inverse of $\Sigma_B$; $\mu_G$ is the global feature mean.


In [None]:
def verify_nc1(features, labels, logits = None):
    """NC1 (within-class variability collapse): Tr(Sigma_W Sigma_B^+) / C.

    Args:
        features: (N, d) tensor (or array) of penultimate features.
        labels: (N,) tensor (or array) of integer class labels.
        logits: unused; accepted so that all `verify_nc*` share a call signature.

    Returns:
        The NC1 value as a Python float (also printed).

    Statistics are accumulated in float64. Sigma_B has rank at most C, far below d
    for d > C, and its pseudo-inverse must tell truly-zero eigenvalues apart from
    rounding noise. Float32 accumulation over many samples makes that fragile.
    """
    features = torch.as_tensor(features).to(torch.float64)
    labels = torch.as_tensor(labels, device=features.device)
    N, d = features.shape

    classes, class_idx = torch.unique(labels, return_inverse=True)
    C = classes.numel()

    # Class means via one scatter-add instead of a boolean mask per class.
    counts = torch.bincount(class_idx, minlength=C).to(features.dtype)
    class_sums = torch.zeros((C, d), dtype=features.dtype, device=features.device)
    class_sums.index_add_(0, class_idx, features)
    class_means = class_sums / counts.unsqueeze(1)
    mu_g = features.mean(dim=0, keepdim=True)

    # Within-class covariance: deviations of each sample from its class mean.
    within = features - class_means[class_idx]
    sigma_w = within.t() @ within / N

    # Between-class covariance: deviations of class means from the global mean.
    between = class_means - mu_g
    sigma_b = between.t() @ between / C

    sigma_b_pinv = torch.linalg.pinv(sigma_b, hermitian=True)
    nc1_val = torch.trace(sigma_w @ sigma_b_pinv) / C
    print(f"NC1 Value: {nc1_val.item()}")
    return nc1_val.item()

In [None]:
# we make sure that we removed augmentation on the train_dataset
print(train_dataloader.dataset.transform)


In [None]:
features, logits, labels = extract_features_and_logits(
    model=resnet, dataloader=train_dataloader, device=device
)

In [None]:
id_features, id_logits, id_labels = extract_features_and_logits(
    model=resnet, dataloader=val_dataloader, device=device
)

In [None]:
ood_features, ood_logits, ood_labels = extract_features_and_logits(
    model=resnet, dataloader=svhn_test_loader, device=device
)

In [None]:
nc1 = verify_nc1(features=torch.from_numpy(features), labels=torch.from_numpy(labels))

- The NC1 value is very small, which shows that we have reached the terminal phase of training: the within-class clusters have shrunk significantly.


Now we are going to visualize the values of NC 1 across epochs.


In [None]:
def NC_values(num,f_nc, epoch_list=None):
    """Evaluate an NC metric on every saved training checkpoint.

    For each epoch in `epoch_list` (default 10, 20, ..., 400), loads
    `/content/drive/MyDrive/models/resnet_{epoch}_epoch.pth` and extracts train-set
    features/logits with that checkpoint. It then calls `f_nc` as follows:
      - num < 4:  f_nc(features, labels)
      - num == 4: f_nc(features, labels, logits)
      - num >= 5: f_nc(features, labels, ood_features), where the OOD features are
        re-extracted from `svhn_test_loader` with the same checkpoint.

    Every quantity comes from the same checkpoint. For NC3 the global `resnet`
    (read by `verify_nc3` for the classifier weights) temporarily points at the
    checkpoint's model and is restored afterwards.

    Returns:
        List of `f_nc` results, one per checkpoint.
    """
    global resnet
    if epoch_list is None:
        epoch_list = range(10, 401, 10)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    nc_values = []
    for epoch in epoch_list:
        print(f'Computing nc{num} value for epoch {epoch}...')
        epoch_model = ResNet18(num_classes=100).to(device)
        checkpoint_path = f"/content/drive/MyDrive/models/resnet_{epoch}_epoch.pth"
        checkpoint = torch.load(checkpoint_path, map_location=device)
        epoch_model.load_state_dict(checkpoint['model_state_dict'])

        feats, epoch_logits, epoch_labels = extract_features_and_logits(epoch_model, train_dataloader, device)
        feats_t = torch.from_numpy(feats)
        labels_t = torch.from_numpy(epoch_labels)

        if num == 3:
            # NC3 compares the features with the classifier of the *same* checkpoint.
            final_model = resnet
            resnet = epoch_model
            try:
                values = f_nc(feats_t, labels_t)
            finally:
                resnet = final_model
        elif num < 4:
            values = f_nc(feats_t, labels_t)
        elif num == 4:
            values = f_nc(feats_t, labels_t, torch.from_numpy(epoch_logits))
        else:
            # OOD features must come from the same checkpoint as the ID features.
            epoch_ood_feats, _, _ = extract_features_and_logits(epoch_model, svhn_test_loader, device)
            values = f_nc(feats_t, labels_t, torch.from_numpy(epoch_ood_feats))

        nc_values.append(values)

        del epoch_model
        torch.cuda.empty_cache()
    return nc_values

def plot_NC(num,nc_values, epoch_list=None):
    """Plot an NC metric against checkpoint epoch, save it to Drive and show it.

    Args:
        num: NC index, used in the title, the y-label and the file name.
        nc_values: one value per checkpoint, as returned by `NC_values`.
        epoch_list: epochs the values belong to; defaults to 10, 20, ..., 400,
            the checkpoints `NC_values` evaluates by default.

    Raises:
        ValueError: if `nc_values` and the epoch list differ in length.
    """
    epochs = list(range(10, 401, 10)) if epoch_list is None else list(epoch_list)
    if len(epochs) != len(nc_values):
        raise ValueError(
            f"got {len(nc_values)} NC values for {len(epochs)} epochs; pass the matching epoch_list"
        )

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epochs, nc_values, marker='o', linestyle='-', color='#1f77b4', linewidth=2, label='NC Value')

    ax.set_title(f'Evolution of NC {num} across epochs', fontsize=14)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(f'NC {num}', fontsize=12)
    ax.grid(True, which="both", linestyle='--', alpha=0.5)
    ax.legend()

    save_path = f"/content/drive/MyDrive/NC/nc{num}_evolution_curve.png"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Analysis complete. Plot saved at: {save_path}")

In [None]:
nc1_val = NC_values(num=1, f_nc=verify_nc1)

In [None]:
plot_NC(num=1, nc_values=nc1_val)

## NC 2

As training progresses, the centered class means become equinorm, and the standard deviation of their pairwise cosines approaches zero,
indicating equiangularity:

$$
\left| \|\mu_c - \mu_G\|_2 - \|\mu_{c'} - \mu_G\|_2 \right| \to 0 \quad \forall c, c'
$$
$$
    \langle \tilde{\mu}_c, \tilde{\mu}_{c'} \rangle \to \frac{C}{C-1} \delta_{c,c'} - \frac{1}{C-1} \quad \forall c, c'
$$
where $\tilde{\mu}_c = (\mu_c - \mu_G)/\|\mu_c - \mu_G\|_2$. In practice we measure
$$
EN_{\text{class-means}} = \frac{\operatorname{std}_c\{ \|\mu_c - \mu_G\|_2 \}}{\operatorname{avg}_c\{ \|\mu_c - \mu_G\|_2 \}} \to 0
$$

and
$$
\text{Equiangularity}_{\text{class-means}} = \operatorname{avg}_{c \neq c'} \left| \frac{ \langle \mu_c-\mu_G, \mu_{c'} - \mu_G \rangle }{\|\mu_c - \mu_G\|_2 \, \|\mu_{c'} - \mu_G\|_2} + \frac{1}{C-1} \right| \to 0
$$

In [None]:
def verify_nc2(features, labels, logits = None):
    """NC2: equinormality and equiangularity of the centered class means.

    Equinormality = std_c ||mu_c - mu_G|| / mean_c ||mu_c - mu_G|| (population std).
    Equiangularity = mean over ordered pairs c != c' of |cos(mu_c - mu_G, mu_c' - mu_G) + 1/(C-1)|,
    i.e. the mean deviation from the simplex-ETF cosine -1/(C-1).

    Args:
        features: (N, d) tensor (or array) of features.
        labels: (N,) tensor (or array) of integer labels.
        logits: unused; kept for a uniform `verify_nc*` signature.

    Returns:
        (equinormality, equiangularity) as Python floats (both also printed).

    Raises:
        ValueError: if fewer than two classes are present (the ETF angle is undefined).
    """
    features = torch.as_tensor(features)
    labels = torch.as_tensor(labels)

    classes = torch.unique(labels)
    C = len(classes)
    if C < 2:
        raise ValueError("NC2 needs at least two classes")

    class_means = torch.stack([features[labels == c].mean(0) for c in classes])
    centered_means = class_means - features.mean(0)
    norms = torch.norm(centered_means, dim=1)

    # Equinormality
    en_val = torch.std(norms, unbiased=False) / torch.mean(norms)

    # Equiangularity: compare off-diagonal cosines with the simplex-ETF value.
    eps = 1e-12
    normed_means = centered_means / (norms.unsqueeze(1) + eps)
    cos_sim_matrix = normed_means @ normed_means.t()
    off_diag_mask = ~torch.eye(C, dtype=torch.bool, device=features.device)
    off_diag_cos = cos_sim_matrix[off_diag_mask]

    ideal_cos = -1.0 / (C - 1)
    equiangularity_val = (off_diag_cos - ideal_cos).abs().mean()
    print(f'Equiangularity: {equiangularity_val.item()}')
    print(f'Equinormality : {en_val.item()}')

    return en_val.item(), equiangularity_val.item()


In [None]:
results = verify_nc2(features=torch.from_numpy(features), labels=torch.from_numpy(labels))

As expected, the values are very small. This validates the NC2 property.

Now we will visualize these values across different epochs and we will look at the cosine similarity heatmap and the distributions of pairwise angles.

In [None]:
nc2_val = NC_values(num=2, f_nc=verify_nc2)

In [None]:
# Split the per-checkpoint (equinormality, equiangularity) pairs from verify_nc2.
en_values = [item[0] for item in nc2_val]
equi_values = [item[1] for item in nc2_val]
# Epochs of the checkpoints evaluated by NC_values. Deliberately not named `epochs`:
# that would overwrite the training-epoch count defined in the hyperparameter cell.
checkpoint_epochs = list(range(10, 401, 10))

plt.figure(figsize=(10, 6))

# Plot Equinormality
plt.plot(checkpoint_epochs, en_values, marker='o', linestyle='-', color='blue', linewidth=2, label='Equinormality (EN)')

# Plot Equiangularity
plt.plot(checkpoint_epochs, equi_values, marker='s', linestyle='-', color='red', linewidth=2, label='Equiangularity')

plt.yscale('log')
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Metric Value (Log Scale)', fontsize=12)
plt.title('Evolution of NC2: Equinormality & Equiangularity', fontsize=14)

plt.grid(True, which="both", linestyle='--', alpha=0.5)
plt.legend()

# Save and show
save_path = "/content/drive/MyDrive/NC/nc2_evolution_curve.png"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"NC2 plot saved to {save_path}")

Heatmap for NC2

In [None]:
import seaborn as sns

def plot_nc2_diagnostics(features, labels):
    """Visualize NC2 for one set of features, then save and show the figure.

    Left: heatmap of the pairwise cosine similarities between centered class means.
    Right: histogram of the off-diagonal cosines, with the simplex-ETF value
    -1/(C-1) marked as a reference line.
    """
    classes = torch.unique(labels)
    num_classes = len(classes)

    # Centered class means, normalized to unit length.
    centered = torch.stack([features[labels == c].mean(0) for c in classes]) - features.mean(0)
    unit = centered / (torch.norm(centered, dim=1, keepdim=True) + 1e-12)
    cos_sim = (unit @ unit.T).cpu().numpy()

    fig, (heat_ax, hist_ax) = plt.subplots(1, 2, figsize=(18, 7))

    # Left panel: heatmap
    sns.heatmap(cos_sim, cmap='coolwarm', center=0, vmin=-0.1, vmax=1.0, ax=heat_ax)
    heat_ax.set_title("Cosine Similarity Heatmap")
    heat_ax.set_xlabel("Class Index")
    heat_ax.set_ylabel("Class Index")

    # Right panel: off-diagonal cosines against the ETF reference value
    off_diag = cos_sim[~np.eye(num_classes, dtype=bool)]
    ideal_val = -1.0 / (num_classes - 1)
    hist_ax.hist(off_diag, bins=100, color='skyblue', edgecolor='black', alpha=0.7)
    hist_ax.axvline(ideal_val, color='red', linestyle='--', linewidth=2, label=f'Ideal: {ideal_val:.4f}')

    hist_ax.set_title("Distribution of Pairwise Angles")
    hist_ax.set_xlabel("Cosine Similarity")
    hist_ax.set_ylabel("Frequency")
    hist_ax.legend()
    hist_ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    save_path = "/content/drive/MyDrive/NC/nc2_diagnostics_epoch.png"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
plot_nc2_diagnostics(features=torch.from_numpy(features), labels=torch.from_numpy(labels))

## NC3
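This is the convergence-to-self-duality property: the normalized classifier weights align with the normalized centered class means,

$$ \left\| \frac{W^\top}{\|W\|_F} - \frac{\dot M}{\|\dot M\|_F}\right\|_F \to 0, $$
where $\dot M = [\mu_1 - \mu_G, \dots, \mu_C - \mu_G]$. The code below measures a per-class variant: the average over classes of $\left\| w_c/\|w_c\|_2 - (\mu_c - \mu_G)/\|\mu_c - \mu_G\|_2 \right\|_2$.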

In [None]:
def verify_nc3(features, labels, logits = None, classifier_weights=None):
    """NC3 (self-duality): distance between normalized classifier rows and the
    normalized centered class means.

    For each class c present in `labels`, computes
    || w_c / ||w_c|| - (mu_c - mu_G) / ||mu_c - mu_G|| ||_2 and averages over classes.

    Args:
        features: (N, d) tensor (or array) of features.
        labels: (N,) integer labels; classifier rows are indexed by label id.
        logits: unused; kept for a uniform `verify_nc*` signature.
        classifier_weights: (num_classes, d) last-layer weights. Defaults to the
            global `resnet.fc.weight`. Pass the weights of the model that produced
            `features` when they come from a different checkpoint.

    Returns:
        The NC3 value as a Python float (also printed).
    """
    features = torch.as_tensor(features)
    labels = torch.as_tensor(labels)
    if classifier_weights is None:
        classifier_weights = resnet.fc.weight

    classes = torch.unique(labels)
    weights = torch.as_tensor(classifier_weights).detach().cpu()
    # Align classifier rows with the classes present, in the same order as the means.
    weights = weights[classes.cpu()].to(device=features.device, dtype=features.dtype)

    # Centered class means
    class_means = torch.stack([features[labels == c].mean(0) for c in classes])
    centered_means = class_means - features.mean(0)

    eps = 1e-12
    # Normalize weights and centered means row-wise
    norm_w = weights / (torch.norm(weights, dim=1, keepdim=True) + eps)
    norm_m = centered_means / (torch.norm(centered_means, dim=1, keepdim=True) + eps)

    # Mean per-class L2 distance between the normalized vectors
    alignment_dist = torch.norm(norm_w - norm_m, dim=1).mean()
    print(f"NC3 Value: {alignment_dist.item()}")
    return alignment_dist.item()

In [None]:
nc3_deviation = verify_nc3(features=torch.from_numpy(features), labels=torch.from_numpy(labels))

The NC3 value is small, but not as close to 0 as the theory predicts. Training for more epochs might reduce it further; we plot NC3 across epochs to check this hypothesis.

In [None]:
nc3_val = NC_values(num=3, f_nc=verify_nc3)

In [None]:
plot_NC(num=3, nc_values=nc3_val)

## NC 4
Simplification to the nearest class center (NCC) classifier:
$$ \arg\max_{c'} \left( \left< w_{c'}, h \right> + b_{c'} \right) \to \arg\min_{c'} \|h - \mu_{c'}\|_2 $$

The NC4 Metric measures the "disagreement" between the actual model predictions and the NCM predictions. As collapse occurs, we have : $$NC4 = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1} \left( \text{model}(x_i) \neq \arg\min_c \|h(x_i) - \mu_c\|_2 \right) \to 0$$

In [None]:
def verify_nc4(features,labels,logits = None):
    """NC4: disagreement rate between the network's prediction (argmax of the
    logits) and the nearest-class-mean (NCM) prediction.

    Args:
        features: (N, d) tensor or array of features.
        labels: (N,) integer labels, used only to build the class means.
        logits: (N, num_classes) logits of the same samples. Required.

    Returns:
        Fraction of samples where the two predictions differ, as a Python float
        (also printed).

    Raises:
        ValueError: if `logits` is None.
    """
    if logits is None:
        raise ValueError("verify_nc4 needs the model logits")

    features = torch.as_tensor(features)
    logits = torch.as_tensor(logits)
    labels = torch.as_tensor(labels)

    classes = torch.unique(labels)

    # Class means, in the order of `classes`
    class_means = torch.stack([features[labels == c].mean(0) for c in classes])

    # Model predictions (class ids)
    model_preds = torch.argmax(logits, dim=1)

    # NCM predictions: argmin is a row index into `classes`. Map it back to the class
    # id so it is comparable with the logit argmax (the two coincide only when the
    # labels are exactly 0..C-1).
    nearest_row = torch.argmin(torch.cdist(features, class_means), dim=1)
    ncm_preds = classes[nearest_row]

    # Disagreement rate
    disagreement = (model_preds != ncm_preds).float().mean()
    print(f"NC4 Value: {disagreement.item()}")
    return disagreement.item()

In [None]:
agr = verify_nc4(features, labels, logits=logits)

As expected, the NC4 value is very small. This validates property NC4.

Now we observe the evolution of NC4 values across epochs.

In [None]:
nc4_val = NC_values(num=4, f_nc=verify_nc4)

In [None]:
plot_NC(num=4, nc_values=nc4_val)

## NC 5

As training progresses, the OOD clusters become increasingly orthogonal to the ETF subspace of the ID data. We have

$$
\text{OrthoDev}_\text{classes−OOD} = \text{Avg}_c \big | \frac{ \langle {\mu}_c, {\mu}_{G}^{OOD} \rangle } {\|\mu_c  \|_2\| {\mu}_{G}^{OOD} \|_2}     \big | \to 0
$$

In [None]:
def verify_nc5(id_features, id_labels, ood_features):
    """NC5 (ID/OOD orthogonality): mean absolute cosine between each ID class mean
    and the OOD global mean, both centered on the ID global mean.

    Args:
        id_features: (N, d) tensor or array of ID features.
        id_labels: (N,) integer ID labels.
        ood_features: (M, d) tensor or array of OOD features.

    Returns:
        The NC5 value as a Python float (also printed).
    """
    id_features = torch.as_tensor(id_features)
    id_labels = torch.as_tensor(id_labels)
    ood_features = torch.as_tensor(ood_features)

    # Everything is expressed relative to the ID global mean.
    origin = id_features.mean(0)
    class_dirs = torch.stack(
        [id_features[id_labels == c].mean(0) for c in torch.unique(id_labels)]
    ) - origin
    ood_dir = ood_features.mean(0) - origin

    # Unit vectors, so the dot products below are cosine similarities.
    class_dirs = class_dirs / (class_dirs.norm(dim=1, keepdim=True) + 1e-12)
    ood_dir = ood_dir / (ood_dir.norm() + 1e-12)

    ortho_dev = (class_dirs @ ood_dir).abs().mean()
    print(f"NC5: {ortho_dev.item()}")
    return ortho_dev.item()

In [None]:
nc5 = verify_nc5(id_features=features, id_labels=labels, ood_features=ood_features)

As expected, the NC5 value is very small and this validates the property NC5. We will then look at the evolution of NC5 values across epochs.

In [None]:
nc5_val = NC_values(num=5, f_nc=verify_nc5)

In [None]:
plot_NC(num=5, nc_values=nc5_val)

The NC5 value decreases towards 0 but seems to plateau around 0.088 as the number of epochs increases.
This can be explained by the fact that the OOD dataset may share low-level features (textures, edges) with the ID classes, creating a "floor" below which NC5 cannot go.

## NECO

The objective is to use a score-based strategy to perform OOD detection.

We will implement the NECO score and compare it to the previous OOD scores.
Here is the formula of the NECO score:

$$\text{NECO}(x) = \frac{\|P h_{\omega}(x)\|}{\|h_{\omega}(x)\|}$$
- $h_{\omega}(x)$: the raw feature vector of an input $x$ (in the code below, a row of `test_features`).
- $P$: the orthogonal projection onto the subspace spanned by the top principal components of the ID training features. Under neural collapse, this subspace approximately coincides with the span of the centered class means.

In [None]:
from sklearn.decomposition import PCA

def neco_scores(test_features,train_features, num_classes=100):
    """NECO score: fraction of a feature's norm that lies in the ID principal subspace.

    The subspace is spanned by the top (num_classes - 1) principal directions of the
    ID training features. Test and training features are centered on the training
    mean, so that numerator and denominator refer to the same vector:

        score(x) = ||P (x - m)|| / ||x - m||   in [0, 1], higher = more in-distribution.

    NOTE(review): the reference NECO implementation also standardizes each feature
    dimension before PCA; that is not done here.

    Args:
        test_features: (M, d) array-like of features to score.
        train_features: (N, d) array-like of ID training features defining the subspace.
        num_classes: number of ID classes; the subspace dimension is num_classes - 1.

    Returns:
        (M,) NumPy array of scores.

    Raises:
        ValueError: if num_classes - 1 is not in [1, d].
    """
    test = np.asarray(test_features, dtype=np.float64)
    train = np.asarray(train_features, dtype=np.float64)
    n_components = num_classes - 1
    if not 0 < n_components <= train.shape[1]:
        raise ValueError(f"num_classes - 1 must be in [1, {train.shape[1]}], got {n_components}")

    # Exact PCA through the eigendecomposition of the d x d training covariance.
    mean = train.mean(axis=0)
    centered_train = train - mean
    cov = centered_train.T @ centered_train / max(len(train) - 1, 1)
    _, eig_vecs = np.linalg.eigh(cov)                   # eigenvalues in ascending order
    principal = eig_vecs[:, ::-1][:, :n_components]     # (d, k), largest variance first

    centered_test = test - mean
    numerator = np.linalg.norm(centered_test @ principal, axis=1)
    denominator = np.linalg.norm(centered_test, axis=1)
    return numerator / (denominator + 1e-12)

In [None]:
# Score the ID *test* set (id_features) and the OOD set against the principal subspace
# of the ID *training* features (`features`, extracted from train_dataloader).
# Fitting the subspace on OOD features, or scoring the training split, would not
# measure ID-vs-OOD separation and would not be comparable with the other scores.
neco_id = neco_scores(id_features, features)
neco_ood = neco_scores(ood_features, features)

In [None]:
met = evaluate_ood(id_scores=neco_id, ood_scores=neco_ood)
met

In [None]:
# Register NECO alongside the other scorers so it appears in the comparison plots.
id_scores.update(neco=neco_id)
ood_scores.update(neco=neco_ood)
all_metrics_evaluated.update(neco=met)

In [None]:
plot_comparisons(id_scores_dict=id_scores, ood_scores_dict=ood_scores, metrics_dict=all_metrics_evaluated)

On the SVHN dataset, the NECO score outperforms every other OOD score.