# validate knn

Sometimes we want to run validation on a self-supervised model without training a linear probe.

To do so, we can use one of two nearest-neighbor classifiers:
1) get the embeddings for all training images, compare each val image to each training image, choose the K nearest neighbors (default = 200), and perform similarity-weighted voting for the target class,
2) use the training images to compute a prototype for each class, then simply assign each val image the class of its nearest prototype.

# ipcl code vs. deep_analytics

To make sure the code was ported correctly, let's first get reference scores with the original IPCL code.

In [None]:
import os
import torch
from torch.utils.data import DataLoader

from deep_analytics.datasets.folder import ImageNetIndex
from lib.knn import run_kNN_chunky

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# IPython help lookup: displays the signature/docstring of run_kNN_chunky.
run_kNN_chunky?

In [None]:
# Load the model together with its preprocessing transform from torch.hub.
# force_reload=True re-fetches the hub repo instead of using the local cache.
model, transform = torch.hub.load("harvard-visionlab/open_ipcl", "alexnetgn_ipcl_ref01",
                                  trust_repo=True, force_reload=True)
model.to(device)
# Display the transform returned by the hub entry point.
transform

In [None]:
# Build the train and val datasets from the same root, both using the
# model's own transform.
root = '/n/alvarez_lab_tier1/Users/alvarez/datasets/imagenet1k-256'
train_dataset = ImageNetIndex(root, split='train', transform=transform)
print(train_dataset)
test_dataset = ImageNetIndex(root, split='val', transform=transform)
print(test_dataset)

In [None]:
batch_size = 256
# One worker per CPU this process is allowed to run on.
num_workers=len(os.sched_getaffinity(0))
# shuffle=False / drop_last=False: every item is visited once, in dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, 
                          shuffle=False, pin_memory=True, drop_last=False)

test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, 
                         shuffle=False, pin_memory=True, drop_last=False)

In [None]:
# Reference run using run_kNN_chunky imported from lib.knn on a single layer.
# NOTE(review): presumably this is the original IPCL implementation referred
# to in the section text — confirm.
layer_name = 'fc8'
top1, top5 = run_kNN_chunky(model,
                            train_loader,
                            test_loader,
                            layer_name,
                            K=200,
                            sigma=0.07,
                            num_chunks=10,
                            out_device=None)
top1, top5

In [None]:
# ({'fc8': 39.173999428749084}, {'fc8': 61.40999794006348})

In [None]:
# Same evaluation with the deep_analytics implementation, over several layers.
from deep_analytics.assays.cls_accuracy.knn import run_kNN

layer_names = ['ave_pool', 'fc6', 'fc7', 'fc8']
results = run_kNN(model, train_loader, test_loader, layer_names)

# prototypes

In [None]:
from deep_analytics.assays.cls_accuracy.knn import run_kNN

layer_names = ['fc7', 'fc8']
# NOTE(review): this unpacks five return values, whereas the run_kNN defined
# later in this notebook returns only (top1, top5) — confirm that the imported
# deep_analytics version returns prototypes, features and labels as well.
top1, top5, prototypes, features, labels = run_kNN(model, train_loader, test_loader, 
                                                   layer_names)

In [None]:
import os

# torch.save does not create missing parent directories, so make sure the
# output folder exists before writing the results bundle.
out_path = './results/prototypes.pth.tar'
os.makedirs(os.path.dirname(out_path), exist_ok=True)
torch.save(dict(
    top1=top1,
    top5=top5,
    prototypes=prototypes,
    features=features,
    labels=labels
), out_path)

In [None]:
def mahalanobis_distance(x, y, covariance):
    """
    Compute the Mahalanobis distance between a test item and a prototype.

    Parameters:
    x: Tensor of shape (num_features,) representing the test item.
    y: Tensor of shape (num_features,) representing the prototype mean.
    covariance: Tensor of shape (num_features, num_features) representing the covariance matrix of the prototype.

    Returns:
    The Mahalanobis distance as a scalar (Python float).

    Raises:
    ValueError: if the covariance matrix is singular.
    """
    diff = (x - y).view(-1, 1)
    # Solve covariance @ z = diff instead of testing det == 0 and inverting:
    # the determinant underflows to exactly 0 for perfectly invertible
    # high-dimensional matrices, and solving is more stable than an explicit
    # inverse. The solver itself reports a singular matrix.
    try:
        solved = torch.linalg.solve(covariance, diff)
    except torch.linalg.LinAlgError as err:
        raise ValueError("Covariance matrix is not invertible.") from err
    # diff^T @ covariance^{-1} @ diff
    distance_squared = torch.matmul(diff.view(1, -1), solved)
    return torch.sqrt(distance_squared).item()

def compute_mahalanobis(test_means, proto_means, proto_vars):
    """
    Pairwise Mahalanobis distances under per-prototype diagonal covariances.

    Parameters:
    test_means: Tensor of shape (num_tests, num_features)
    proto_means: Tensor of shape (num_protos, num_features)
    proto_vars: Tensor of shape (num_protos, num_features), the per-feature
        variances of each prototype

    Returns:
    distances: Tensor of shape (num_tests, num_protos)
    """
    num_tests, num_features = test_means.shape
    num_protos, _ = proto_means.shape
    pair_shape = (num_tests, num_protos, num_features)

    # With a diagonal covariance the precision matrix is just 1 / variance.
    inv_var = (1.0 / proto_vars).unsqueeze(0).expand(*pair_shape)

    # Offset of every test item from every prototype mean.
    tests = test_means.unsqueeze(1).expand(*pair_shape)
    protos = proto_means.unsqueeze(0).expand(*pair_shape)
    offsets = tests - protos

    # sqrt( sum_f offset_f^2 / var_f )
    weighted_sq = offsets ** 2 * inv_var
    return weighted_sq.sum(dim=2).sqrt()

# ClassificationKNN

In [None]:
import torch
from torchvision import models
from torchvision import transforms

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
import os
import io
import gc
from collections import defaultdict
from contextlib import redirect_stdout
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import torchattacks
import matplotlib.pyplot as plt
from fastprogress import master_bar, progress_bar
from torch.cuda.amp import autocast

from deep_analytics.assays.model_assay import ModelAssay
from deep_analytics.utils.bootstrap import bootstrap_multi_dim
from deep_analytics.utils.stats import AccumMetric
# from deep_analytics.assays.metrics import *
from deep_analytics.utils.feature_extractor import FeatureExtractor

from pdb import set_trace

from types import SimpleNamespace

__all__ = ['ClassificationNearestNeighbors', 'ClassificationNearestPrototype']

class ClassificationNearestNeighbors(ModelAssay):
    """Nearest-neighbor classification assay.

    Calling an instance with a model (or a zero-argument model loader) and a
    transform builds the dataloader, evaluates the model and returns the
    resulting dataframe annotated with the model and dataset names.

    NOTE(review): ``__call__`` uses a ``validate`` function that is not
    defined in the visible part of this file — confirm it is in scope.
    """

    # Short key -> pair of strings; presumably (dataset name, split) — confirm
    # against ModelAssay.
    datasets = {
        'imagenette2': ('imagenette2_s320_remap1k', 'val'),
        'imagenet1k': ('imagenet1k_s256', 'val'),
        'imagenetV2_top_images': ('imagenetV2', 'top-images'),
        'imagenetV2_threshold07': ('imagenetV2', 'threshold0.7'),
        'imagenetV2_matched_frequency': ('imagenetV2', 'matched-frequency'),
    }

    def compute_metrics(self, df):
        """Not implemented for this assay yet."""
        raise NotImplementedError("Subclasses of ModelAssay should implement `compute_metrics`.")

    def plot_results(self, df):
        """Not implemented for this assay yet."""
        raise NotImplementedError("Subclasses of ModelAssay should implement `plot_results`.")

    def __call__(self, model_or_model_loader, transform):
        """Evaluate a model and return the annotated results dataframe.

        :param model_or_model_loader: an ``nn.Module``, or a callable that
            returns one when called without arguments.
        :param transform: passed to ``self.get_dataloader``.
        """
        self.dataloader = self.get_dataloader(transform)

        # Accept either a ready model or a factory that builds it lazily.
        model = (model_or_model_loader
                 if isinstance(model_or_model_loader, nn.Module)
                 else model_or_model_loader())

        df = validate(model, self.dataloader)
        fallback_name = model.__class__.__name__
        df['model_name'] = model.__dict__.get("model_name", fallback_name)
        df['dataset'] = self.dataset_name

        # Drop the local reference and release cached GPU memory before
        # handing the results back.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        return df
     
@torch.no_grad()
def run_kNN(model, train_loader, test_loader, layer_names, num_classes=1000, 
            K=200, sigma=.07, num_chunks=10, out_device=None):
    '''
        Weighted kNN classification accuracy for one or more layers.

        We compute the full testFeatures, testLabels, then iterate over the
        training set in groups of `num_chunks` batches (the value is passed to
        `gen_features` as `num_batches`). For every group we compute the
        similarity between each test item and each train item of the group
        and keep only the running top-K most similar train items per test
        item, so the full test x train similarity matrix is never held at once.

        Finally the predicted class is obtained through similarity-weighted
        voting amongst the retained top-K, separately for each layer.

        :param model: model to extract features from.
        :param train_loader: loader yielding (images, targets, indexes) for train items.
        :param test_loader: loader yielding (images, targets, indexes) for test items.
        :param layer_names: a layer name or a list of layer names.
        :param num_classes: number of classes used for the vote.
        :param K: number of neighbors; clamped to the number of train items seen
                  so far when fewer than K are available.
        :param sigma: temperature of the exp(similarity / sigma) vote weights.
        :param num_chunks: number of train batches grouped per comparison step.
        :param out_device: device for features and similarities
                           (defaults to cuda when available, else cpu).
        :return: (top1_acc, top5_acc), dicts mapping layer name -> accuracy in percent.
    '''
    
    if isinstance(layer_names, str):
        layer_names = [layer_names]
        
    if out_device is None:
        out_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        
    print("==> extracting test features...")
    testFeatures, testLabels, indexes = get_features(model, test_loader, layer_names,
                                                     out_device=out_device)
    num_test = len(testLabels)
    
    print("==> extracting/comparing to train features...")
    topk_distances = defaultdict(lambda: torch.tensor([], device=out_device))
    trainLabels = defaultdict(lambda: torch.tensor([], device=out_device, dtype=torch.int64))
    # train indexes are tracked alongside the labels; they are not part of the
    # returned accuracies
    trainIndexes = defaultdict(lambda: torch.tensor([], device=out_device, dtype=torch.int64))
    
    generator = gen_features(model, train_loader, layer_names, num_batches=num_chunks,
                             out_device=out_device)
    for trn_feat, trn_labels, trn_indexes in generator:
        # labels/indexes of this group, broadcast to one row per test item
        # (numTestImgs x numTrainImagesThisGroup); identical for every layer
        candidate_labels = trn_labels.view(1,-1).expand(num_test, -1)
        candidate_indexes = trn_indexes.view(1,-1).expand(num_test, -1)
        
        for layer_name in layer_names:
            
            # similarity between testFeatures and current train features
            d = torch.mm(testFeatures[layer_name], 
                         trn_feat[layer_name].T).to(out_device)
            
            # pool the new candidates with the retained top-K
            pooled_d = torch.cat([topk_distances[layer_name], d], dim=1)
            pooled_labels = torch.cat([trainLabels[layer_name], 
                                       candidate_labels], dim=1)
            pooled_indexes = torch.cat([trainIndexes[layer_name], 
                                        candidate_indexes], dim=1)
        
            # keep the top K similarities and matching labels/indexes; clamp K
            # because topk raises when fewer than K candidates are available
            k = min(K, pooled_d.size(1))
            yd, yi = pooled_d.topk(k, dim=1, largest=True, sorted=True)
            topk_distances[layer_name] = yd
            trainLabels[layer_name] = torch.gather(pooled_labels, 1, yi)
            trainIndexes[layer_name] = torch.gather(pooled_indexes, 1, yi)
    
    # After iterating through the full training set, we have retained
    # the similarities, labels and indexes of the top-K most similar
    # training items for each individual test item.
    print("==> computing top1,top5 accurcy: ...")
    top1_acc = dict()
    top5_acc = dict()
    
    for layer_name in progress_bar(layer_names):
        distances = topk_distances[layer_name]
        train_labels = trainLabels[layer_name]
        
        pred, top1, top5 = compute_knn_accuracy(distances, 
                                                train_labels, 
                                                testLabels, 
                                                num_classes=num_classes, 
                                                sigma=sigma)
        
        # per-item hit indicators -> percentage of test items that were hit
        top1 = top1.float().sum(dim=1).mean().item() * 100
        top5 = top5.float().sum(dim=1).mean().item() * 100
        
        print(f"kNN accuracy {layer_name}: top1={top1}, top5={top5}")
        top1_acc[layer_name] = top1
        top5_acc[layer_name] = top5
        
    return top1_acc, top5_acc

def compute_knn_accuracy(distances, train_labels, test_labels, num_classes, sigma):
    """
    Computes the k-NN classification accuracy by weighted voting.

    Each retained neighbor votes for its class with weight exp(distance / sigma),
    so larger values in `distances` count as "closer" (i.e. they act as
    similarities).

    :param distances: Tensor of distances between test and training features (num_Test x topK_Train)
    :param train_labels: Labels corresponding to the training data (num_Test x topK_Train), int64
    :param test_labels: Labels corresponding to the test data.
    :param num_classes: Total number of classes.
    :param sigma: Scaling parameter for distance transformation.
    :return: Tuple of (predictions, top1, top5), all on the CPU:
             predictions is (num_Test x num_classes), class ids sorted by descending vote;
             top1 is a boolean (num_Test x 1) hit indicator;
             top5 is a boolean (num_Test x min(5, num_classes)) hit indicator.
    """
    num_test_images, K = train_labels.shape

    # vote weight of every retained neighbor
    weights = distances.div(sigma).exp().cpu()

    # Accumulate the votes straight into a (num_Test x num_classes) buffer.
    # This gives the same sums as multiplying a one-hot
    # (num_Test x K x num_classes) tensor by the weights, without ever
    # allocating that K-times-larger tensor.
    vote_dtype = torch.promote_types(weights.dtype, torch.float32)
    probs = torch.zeros(num_test_images, num_classes, dtype=vote_dtype)
    probs.scatter_add_(1, train_labels.cpu(), weights.to(vote_dtype))
    _, predictions = probs.sort(1, True)

    # Find which predictions match the target
    correct = predictions.eq(test_labels.view(-1,1).cpu())

    top1 = correct.narrow(1,0,1)
    
    # Handle the case where the number of predictions is less than 5
    top_k = min(5, predictions.size(1))
    top5 = correct.narrow(1,0,top_k)
    
    return predictions, top1, top5

@torch.no_grad()
def get_features(model, dataloader, layer_names, device=None, out_device=None):    
    """
    Extract L2-normalized, flattened features for every item in `dataloader`.

    :param model: model to run; put in eval mode and moved to `device`.
    :param dataloader: yields (images, targets, indexes) batches.
    :param layer_names: a layer name or a list of layer names.
    :param device: device the model runs on (defaults to the device of the
                   model's first parameter).
    :param out_device: device the outputs are stored on (defaults to cuda
                       when available, else cpu).
    :return: (features, labels, indexes); features maps each layer name to a
             (num_items x num_flattened_features) tensor, labels and indexes
             are the concatenated targets and indexes of all batches.
    """
    if isinstance(layer_names, str):
        layer_names = [layer_names]
    device = next(model.parameters()).device if device is None else device
    if out_device is None:
        out_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model.eval()
    model.to(device)

    per_layer = defaultdict(list)
    label_batches, index_batches = [], []

    with FeatureExtractor(model, layer_names, device=out_device) as extractor:
        for imgs, targs, idxs in progress_bar(dataloader):
            activations = extractor(imgs.to(device, non_blocking=True))

            for name, act in activations.items():
                # one unit-length row vector per item
                unit = F.normalize(act.flatten(start_dim=1), dim=1)
                per_layer[name].append(unit.to(out_device))

            label_batches.append(targs.to(out_device))
            index_batches.append(idxs.to(out_device))

    # stitch the per-batch pieces together
    for name in layer_names:
        per_layer[name] = torch.cat(per_layer[name])

    return per_layer, torch.cat(label_batches), torch.cat(index_batches)

@torch.no_grad()
def gen_features(model, dataloader, layer_names, num_batches=10, device=None, out_device=None):    
    """
    Generator version of feature extraction: yields features in groups of
    `num_batches` dataloader batches instead of holding the whole dataset.

    :param model: model to run; put in eval mode and moved to `device`.
    :param dataloader: yields (images, targets, indexes) batches.
    :param layer_names: a layer name or a list of layer names.
    :param num_batches: number of dataloader batches merged into each yielded group.
    :param device: device the model runs on (defaults to the device of the
                   model's first parameter).
    :param out_device: device the outputs are stored on (defaults to cuda
                       when available, else cpu).
    :return: yields (features, targets, indexes) per group; features maps each
             layer name to a (num_items_in_group x num_flattened_features)
             tensor with L2-normalized rows, targets and indexes hold one
             entry per item in the group. A final, smaller group is yielded
             for any leftover batches.
    """
    
    if isinstance(layer_names, str):
        layer_names = [layer_names]
        
    if device is None:
        device = next(model.parameters()).device
        
    if out_device is None:
        out_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def _merge(features, targets, indexes):
        # concatenate the buffered batches into one tensor per layer / field
        for layer_name in layer_names:
            features[layer_name] = torch.cat(features[layer_name])
        return features, torch.cat(targets), torch.cat(indexes)
                
    model.eval()
    model.to(device)
    features,targets,indexes = defaultdict(list), [], []
    for imgs,targs,idxs in progress_bar(dataloader):
        imgs = imgs.to(device, non_blocking=True)      
        with FeatureExtractor(model, layer_names, device=out_device) as extractor:
            feat = extractor(imgs)
        
        # normalize and aggregate features
        for layer_name,X in feat.items():
            X = X.flatten(start_dim=1) # flatten from dim1 onward
            X = F.normalize(X, dim=1)  # normalize across features
            features[layer_name].append(X.to(out_device))
        
        # targets/indexes are recorded once per batch (not once per layer) so
        # they stay aligned with the rows of every layer's features
        targets.append(targs.to(out_device))
        indexes.append(idxs.to(out_device))
        
        # len(targets) is the number of batches buffered so far
        if len(targets)==num_batches:
            yield _merge(features, targets, indexes)
            features,targets,indexes = defaultdict(list), [], []
    
    # yield any remaining features (nothing to do for an empty remainder,
    # including an empty dataloader)
    if targets:
        yield _merge(features, targets, indexes)

In [None]:
# Run the notebook-local run_kNN on a single layer and display the result.
layer_names = ['fc8']
results = run_kNN(model, train_loader, test_loader, layer_names=layer_names)
results

In [None]:
# Same run with a different layer name.
# NOTE(review): 'classifier.6' matches the torchvision AlexNet naming used
# further down — confirm which model is loaded when this cell is executed.
layer_names = ['classifier.6']
results = run_kNN(model, train_loader, test_loader, layer_names=layer_names)
results

In [None]:
# Scratch check: a defaultdict whose factory is torch.Tensor hands out a new
# tensor for a missing key, which is then concatenated with a 1-D tensor.
import torch
k = torch.tensor([1.2, 1.3])
dist = defaultdict(torch.Tensor)
torch.cat([dist['testing'], torch.tensor([1.2, 1.3])])

In [None]:
# Instantiate the assay defined above.
# NOTE(review): the constructor arguments are handled by ModelAssay, which is
# not visible here — confirm the expected dataset/split values.
knn_assay = ClassificationNearestNeighbors(dataset='imagenette2_s320_remap1k', split='val')
knn_assay.dataset

In [None]:
import torch
from torchvision import models

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torchvision AlexNet with the IMAGENET1K_V1 weights; replaces any `model`
# defined by earlier cells.
model = models.alexnet(weights='IMAGENET1K_V1')
model.to(device)

In [None]:
# Layer names to extract features from (torchvision AlexNet module names).
layer_names = ['avgpool', 'classifier.2', 'classifier.5', 'classifier.6']

In [None]:
from torchvision import transforms

# Resize to 256, center-crop to 224, convert to a tensor, then normalize each
# channel with the given mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])

In [None]:
# Build the assay's dataloader with the transform defined above.
dataloader = knn_assay.get_dataloader(transform)
dataloader

In [None]:
# Extract features for every requested layer; outputs are stored on the GPU.
testFeatures, testLabels, indexes = get_features(model, dataloader, layer_names, 
                                                 out_device='cuda')

In [None]:
# Features are stored per layer ...
for layer_name in layer_names:
    print(testFeatures[layer_name].shape, testFeatures[layer_name].dtype)
# ... whereas get_features returns labels and indexes as plain tensors shared
# by all layers, so they cannot be indexed by layer name.
print(testLabels.shape, testLabels.dtype)
print(indexes.shape, indexes.dtype)

In [None]:
# Smoke test for gen_features: group every 3 batches and print the shapes of
# each yielded group.
generator = gen_features(model, dataloader, layer_names, num_batches=3, out_device='cuda')
num_total = 0
for batch_num, (feat, lab, ind) in enumerate(generator):
    print(batch_num, lab.shape, ind.shape)    
    for layer_name, X in feat.items():
        print(layer_name, X.shape)
    # X is the feature tensor of the last layer, left over from the inner loop
    num_total += X.shape[0]
# Compare the number of items yielded with the dataset size.
num_total, len(dataloader.dataset)

In [None]:
# Switch back to the IPCL model (and its transform) from torch.hub; this
# rebinds `model` and `transform` defined by the cells above.
model, transform = torch.hub.load("harvard-visionlab/open_ipcl", "alexnetgn_ipcl_ref01",
                                  trust_repo=True, force_reload=True)
model.to(device)
transform

In [None]:
# Display the model structure.
model

In [None]:
# from torchvision.datasets import ImageNet
from deep_analytics.datasets.folder import ImageNetIndex

# Training split, preprocessed with the transform of the currently loaded model.
root = '/n/alvarez_lab_tier1/Users/alvarez/datasets/imagenet1k-256'
train_dataset = ImageNetIndex(root, split='train', transform=transform)
train_dataset

In [None]:
from torch.utils.data import DataLoader

batch_size = 256
# One worker per CPU this process is allowed to run on.
num_workers=len(os.sched_getaffinity(0))
# shuffle=False: items are visited in dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, 
                          shuffle=False, pin_memory=True)

In [None]:
train_loader = train_loader
# The assay dataloader built above serves as the test set.
test_loader = dataloader
# torchvision AlexNet layer names; overridden by the assignment on the next line
layer_names = ['avgpool', 'classifier.2', 'classifier.5', 'classifier.6']
# NOTE(review): presumably the layer names of the IPCL model loaded above — confirm
layer_names = ['ave_pool', 'fc6', 'fc7', 'fc8', 'l2norm']
top1, top5 = run_kNN(model, train_loader, test_loader, layer_names, 
                     num_classes=1000, K=200, sigma=.07, num_chunks=10, 
                     out_device=None)

In [None]:
# Display cells for the results of the run above; the same expressions are
# repeated across several cells.
top1

In [None]:
top1

In [None]:
top1

In [None]:
top5

In [None]:
top5

In [None]:
import nethook

def run_kNN_orig(model, train_loader, test_loader, layer_name, K=200, sigma=.07, num_chunks=200, out_device=None):
    """
    Reference kNN evaluation: extract all train and test features for one
    layer, then classify the test items chunk by chunk with `do_kNN`.

    :param model: model to extract features from.
    :param train_loader: loader yielding (images, targets, indexes) for train items.
    :param test_loader: loader yielding (images, targets, indexes) for test items.
    :param layer_name: name of the layer whose activations are used.
    :param K: number of neighbors.
    :param sigma: temperature of the vote weights.
    :param num_chunks: number of chunks the test set is split into.
    :param out_device: device the extracted features are stored on.
    :return: (top1, top5) accuracy in percent, as floats.
    """
    print("extracting training features...")
    trainFeatures, trainLabels, indexes = get_features_orig(model, train_loader, layer_name, out_device=out_device)
    
    print("extracting test features...")
    testFeatures, testLabels, indexes = get_features_orig(model, test_loader, layer_name, out_device=out_device)
    
    print("running kNN test...")
    
    # split test features into chunks to avoid out-of-memory error:
    chunkFeatures = torch.chunk(testFeatures, num_chunks, dim=0)
    chunkLabels = torch.chunk(testLabels, num_chunks, dim=0)

    # number of classes, as a plain int
    C = int(trainLabels.max()) + 1
    top1, top5, total = 0., 0., 0.
    # torch.chunk may return fewer chunks than requested, so size the
    # progress bar from what was actually produced
    for features, labels in progress_bar(zip(chunkFeatures, chunkLabels), total=len(chunkFeatures)):
        top1_, top5_, total_ = do_kNN(trainFeatures, trainLabels, features, labels, C, K, sigma)
        # do_kNN returns per-item boolean hit indicators (not percentages):
        # count the items whose label is the top / among the top-5 predictions
        top1 += top1_.sum().item()
        top5 += top5_.any(dim=1).sum().item()
        total += total_
    top1 = top1 / total * 100
    top5 = top5 / total * 100
    
    print(f"run_kNN accuracy: top1={top1}, top5={top5}")
    
    return top1, top5

def do_kNN(trainFeatures, trainLabels, testFeatures, testLabels, C, K, sigma, device=None, out_device=None):
    '''
        Weighted kNN vote for one chunk of test items.

        trainFeatures: [nTrainSamples, nFeatures]
        trainLabels: [nTrainSamples]
        
        testFeatures: [nTestSamples, nFeatures]
        testLabels: [nTestSamples]

        C: number of classes
        K: number of neighbors (clamped to nTrainSamples)
        sigma: temperature; each neighbor votes with weight exp(similarity / sigma)
        device: device the similarity matrix is moved to
                (defaults to cuda when available, else cpu)
        out_device: unused; kept for interface compatibility

        returns (top1, top5, total), where top1 is a boolean
        [nTestSamples, 1] hit indicator, top5 is a boolean
        [nTestSamples, min(5, C)] hit indicator and total is nTestSamples.
    '''
    
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'            
    
    num_classes = int(C)
    dist = torch.mm(testFeatures, trainFeatures.T).to(device)
    
    batchSize = len(testLabels)
    
    # topk raises when asked for more neighbors than there are train items
    k = min(K, dist.size(1))
    yd, yi = dist.topk(k, dim=1, largest=True, sorted=True)
    
    # labels of the retained neighbors; look them up on the device holding the
    # neighbor indices so labels and similarities may live on different devices
    retrieval = trainLabels.view(-1).to(yi.device)[yi]
    
    # accumulate the weighted votes directly into [nTestSamples, C]; same sums
    # as the one-hot [nTestSamples, K, C] formulation without the K-times
    # larger intermediate
    weights = yd.div(sigma).exp().cpu()
    vote_dtype = torch.promote_types(weights.dtype, torch.float32)
    probs = torch.zeros(batchSize, num_classes, dtype=vote_dtype)
    probs.scatter_add_(1, retrieval.cpu(), weights.to(vote_dtype))
    _, predictions = probs.sort(1, True)
    
    # Find which predictions match the target
    correct = predictions.eq(testLabels.view(-1,1).cpu())
    
    total = correct.size(0)
    top1 = correct.narrow(1,0,1)
    # fewer than 5 classes: use all of them
    top5 = correct.narrow(1,0,min(5, correct.size(1)))
    
    return top1, top5, total

def get_features_orig(model, dataloader, layer_name, device=None, out_device=None):    
    """
    Reference feature extraction for a single layer using nethook.

    :param model: model to run; wrapped in nethook.InstrumentedModel if it is
                  not one already.
    :param dataloader: yields (images, targets, indexes) batches.
    :param layer_name: name of the layer whose output is retained.
    :param device: device the model runs on (defaults to cuda when available, else cpu).
    :param out_device: device the outputs are stored on (same default).
    :return: (features, labels, indexes) concatenated over all batches.
    """
    if not isinstance(model, nethook.InstrumentedModel):
        model = nethook.InstrumentedModel(model)
    model.retain_layers([layer_name])

    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = default_device if device is None else device
    out_device = default_device if out_device is None else out_device

    model.to(device)
    model.eval()

    feature_batches, label_batches, index_batches = [], [], []
    with torch.no_grad():
        for imgs, targs, idxs in progress_bar(dataloader):
            # the forward pass is run only so that the layer output is retained
            model(imgs.to(device))
            acts = model.retained_layer(layer_name)
            # NOTE: normalization happens along dim 1 *before* flattening, so
            # for activations with more than 2 dims the flattened rows are not
            # unit length
            acts = F.normalize(acts, dim=1)
            acts = acts.view(acts.shape[0], -1)
            feature_batches.append(acts.to(out_device))
            label_batches.append(targs.to(out_device))
            index_batches.append(idxs.to(out_device))

    return (torch.cat(feature_batches),
            torch.cat(label_batches),
            torch.cat(index_batches))

In [None]:
# Reference numbers for 'fc8' from the run_kNN_orig path defined above.
top1_, top5_ = run_kNN_orig(model, train_loader, test_loader, 'fc8', K=200, sigma=.07, 
                            num_chunks=200, out_device=None)

In [None]:
# Step-by-step version of run_kNN_orig: extract the train and test features
# once so the kNN step below can be re-run without re-extracting.
layer_name = 'fc8'
out_device = None

print("extracting training features...")
trainFeatures, trainLabels, indexes = get_features_orig(model, train_loader, layer_name, out_device=out_device)
# trainFeatures = trainFeatures[layer_name]

print("extracting test features...")
testFeatures, testLabels, indexes = get_features_orig(model, test_loader, layer_name, out_device=out_device)
# testFeatures = testFeatures[layer_name]

In [None]:
# split test features into chunks to avoid out-of-memory error:
num_chunks = 200
K = 200
sigma = .07
chunkFeatures = torch.chunk(testFeatures, num_chunks, dim=0)
chunkLabels = torch.chunk(testLabels, num_chunks, dim=0)

# number of classes = largest train label + 1
C = trainLabels.max() + 1
# do_kNN returns per-item boolean hit tensors; collect them chunk by chunk
top1, top5 = [], []
for features, labels in progress_bar(zip(chunkFeatures, chunkLabels), total=num_chunks):
    top1_, top5_, total_ = do_kNN(trainFeatures, trainLabels, features, labels, C, K, sigma)
    top1.append(top1_)
    top5.append(top5_)    

In [None]:
# Overall top-1 accuracy in percent: mean of the per-item hit indicators.
torch.cat(top1).float().mean().item() * 100

# ClassificationNearestPrototype