# let's validate

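So you trained a model and want to test its classification accuracy? This notebook defines a `ClassificationAccuracy` assay and runs it on several ImageNet-style validation sets.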



In [None]:
import torch
from torchvision import models
from torchvision import transforms

# Default to the CPU and switch to CUDA only when torch reports it as available.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
import os
import io
import gc
from collections import defaultdict
from contextlib import redirect_stdout
from functools import partial

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import torchattacks
import matplotlib.pyplot as plt
from fastprogress import master_bar, progress_bar
from torch.cuda.amp import autocast

from deep_analytics.assays.model_assay import ModelAssay
from deep_analytics.utils.bootstrap import bootstrap_multi_dim
from deep_analytics.utils.stats import AccumMetric

# from deep_analytics.assays.metrics import *

from pdb import set_trace

from types import SimpleNamespace

__all__ = ['ClassificationAccuracy']

class ClassificationAccuracy(ModelAssay):
    """Assay that collects per-sample classification results for a model.

    Calling an instance with a model (or a zero-argument callable that builds
    one) and a transform runs `validate` over the assay's dataloader and
    returns the resulting DataFrame, tagged with the model and dataset names.
    """

    # Alias -> (dataset name, split) pairs.
    # NOTE(review): presumably consumed by ModelAssay — confirm against the base class.
    datasets = dict(
        imagenette2=('imagenette2_s320_remap1k', 'val'),
        imagenet1k=('imagenet1k_s256', 'val'),
        imagenetV2_top_images=('imagenetV2', 'top-images'),
        imagenetV2_threshold07=('imagenetV2', 'threshold0.7'),
        imagenetV2_matched_frequency=('imagenetV2', 'matched-frequency')
    )

    def compute_metrics(self, df):
        """Not implemented for this assay; always raises NotImplementedError."""
        raise NotImplementedError("Subclasses of ModelAssay should implement `compute_metrics`.")

    def plot_results(self, df):
        """Not implemented for this assay; always raises NotImplementedError."""
        raise NotImplementedError("Subclasses of ModelAssay should implement `plot_results`.")

    def __call__(self, model_or_model_loader, transform):
        """Evaluate a model on this assay's dataset.

        Args:
            model_or_model_loader: an `nn.Module`, or a zero-argument callable
                that returns the model to evaluate.
            transform: passed to `self.get_dataloader` to build the dataloader.

        Returns:
            The DataFrame produced by `validate`, with added `model_name` and
            `dataset` columns.
        """
        self.dataloader = self.get_dataloader(transform)

        if isinstance(model_or_model_loader, nn.Module):
            model = model_or_model_loader
        else:
            model = model_or_model_loader()

        df = validate(model, self.dataloader)
        # getattr (rather than the instance __dict__) also finds a `model_name`
        # declared on the class or exposed as a property; falls back to the class name.
        df['model_name'] = getattr(model, "model_name", model.__class__.__name__)
        df['dataset'] = self.dataset_name

        # Drop the local reference and collect garbage *before* emptying the
        # CUDA cache, so memory held by unreachable objects can be released too.
        # A model passed in by the caller stays alive through the caller's reference.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return df
     
@torch.no_grad()
def validate(model, val_loader, print_freq=100, mb=None, store_outputs=False, set_eval=True):
    """Run `model` over `val_loader` and collect per-sample results.

    Args:
        model: module to evaluate; batches are moved to the device of its
            first parameter.
        val_loader: iterable of batches indexed as (images, targets, sample
            indices). `val_loader.dataset.imgs` must be a sequence of
            (filepath, _) pairs addressable by those sample indices.
        print_freq: unused; kept for interface compatibility.
        mb: optional parent forwarded to `progress_bar`.
        store_outputs: unused; kept for interface compatibility.
        set_eval: when true, `model.eval()` is called first (the model is
            left in eval mode afterwards).

    Returns:
        A DataFrame with one row per sample and the columns `index`,
        `filenames`, `correct_label`, `pred_label`, `loss`, `correct1`
        and `correct5`.
    """
    if set_eval:
        model.eval()
    device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss(reduction='none')
    # Keep only the last two path components (parent directory + file name).
    filepaths = [os.path.sep.join(path.split(os.path.sep)[-2:])
                 for path, _ in val_loader.dataset.imgs]

    results = defaultdict(list)
    for batch in progress_bar(val_loader, parent=mb):
        images = batch[0].to(device, non_blocking=True)
        target = batch[1].to(device, non_blocking=True)
        index = batch[2].tolist()

        with autocast():
            output = model(images)
        # Autocast may hand back reduced-precision logits; compute the
        # per-sample loss in float32 so it is not rounded to half precision.
        loss = criterion(output.float(), target)
        preds, correct1, correct5 = accuracy(output, target, topk=(1, 5))

        results['index'] += index
        results['filenames'] += [filepaths[idx] for idx in index]
        results['correct_label'] += target.tolist()
        results['pred_label'] += preds[0].tolist()  # row 0 holds the top-1 prediction
        results['loss'] += loss.tolist()
        results['correct1'] += correct1.tolist()
        results['correct5'] += correct5.tolist()

    return pd.DataFrame(results)

def accuracy(output, target, topk=(1,)):
    """Compute per-sample top-k correctness for each k in `topk`.

    Args:
        output: score tensor indexed as (sample, class).
        target: tensor of correct class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A tuple `(pred, correct_k1, correct_k2, ...)`. `pred` holds the
        predicted class indices with shape (maxk, batch), best prediction in
        row 0. Each `correct_k` is a float tensor of shape (batch,) that is
        1.0 where the target is among the sample's top-k predictions.
    """
    with torch.no_grad():
        # Never request more predictions than there are classes; `topk`
        # would raise otherwise.
        maxk = min(max(topk), output.size(1))

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        # A sample counts as correct@k if any of its first k predictions matches.
        corrects = [correct[:k].any(dim=0).reshape(-1).float() for k in topk]
        return pred, *corrects

In [None]:
# Load torchvision's AlexNet with the IMAGENET1K_V1 weights and move it to `device`.
model = models.alexnet(weights='IMAGENET1K_V1')
# Alternative backbone:
# model = models.resnet50(weights='IMAGENET1K_V1')
model.to(device)

In [None]:
from torchvision import transforms

# Evaluation preprocessing: Resize(256) -> CenterCrop(224) -> tensor ->
# per-channel normalization with the mean/std given below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
transform

In [None]:
# import timm
# from transformers import CLIPProcessor, CLIPModel

# # timm.list_models(pretrained=True)

# model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
# model.to(device)

In [None]:
# model.

In [None]:
# model.eval()
# with torch.no_grad():
#     images = np.random.randint(0, 255, (500, 500, 3))
#     inputs = processor(text=[""], images=image, return_tensors="pt", padding=True).to(device)
#     outputs = model(**inputs)
#     image_embeddings = outputs.image_embeds 
# image_embeddings.shape

In [None]:
# from r3m import load_r3m
# model = load_r3m("resnet50") # resnet18, resnet34
# model.to(device).eval()

In [None]:
# from PIL import Image

# # ENCODE IMAGE
# image = np.random.randint(0, 255, (500, 500, 3))
# preprocessed_image = transforms(Image.fromarray(image.astype(np.uint8))).reshape(-1, 3, 224, 224)
# preprocessed_image.to(device) 
# with torch.no_grad():
#     embedding = model(preprocessed_image * 255.0) ## R3M expects image input to be [0-255]
# embedding.shape

In [None]:
# transforms = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     lambda x: x * 255.0
# ])


In [None]:
# Build the assay on the ImageNetV2 matched-frequency split and show its dataset.
cls_assay = ClassificationAccuracy(dataset='imagenetV2', split='matched-frequency')
# cls_assay = ClassificationAccuracy(dataset='imagenet1k_s256', split='val')
cls_assay.dataset

In [None]:
# Inspect the first dataset item; it unpacks into (image, label, index).
img,label,index = cls_assay.dataset[0]
print(index,label)
img

In [None]:
# 63.15, 76.01

In [None]:
# Run the assay with the current model/transform; `results` is the per-sample DataFrame.
results = cls_assay(model, transform)
# Top-1 accuracy in percent.
results.correct1.mean() * 100

In [None]:
# Same model/transform on the imagenet1k_s256 validation split (rebinds `cls_assay`).
cls_assay = ClassificationAccuracy(dataset='imagenet1k_s256', split='val')
results2 = cls_assay(model, transform)
# Top-1 accuracy in percent.
results2.correct1.mean() * 100

In [None]:
# Top-1 gap in percentage points: `results2` minus `results`.
# NOTE(review): a later cell computes the gap with the opposite sign
# (`results` minus `results2`) — keep that in mind when comparing numbers.
results2.correct1.mean() * 100 - results.correct1.mean() * 100

In [None]:
import timm

# Select the timm checkpoint to evaluate. The second assignment overrides the
# first, so 'vgg11.tv_in1k' is the name actually used.
model_name = 'vit_large_patch14_clip_224.openai_ft_in1k'
model_name = 'vgg11.tv_in1k'
# cfg = timm.get_pretrained_cfg(model_name).__dict__
# cfg

In [None]:
# timm.list_models(pretrained=True)
# Create the pretrained timm model named by `model_name` (rebinds `model`)
# and move it to `device`.
model = timm.create_model(model_name, pretrained=True)
model.to(device);

In [None]:
# Resolve the data configuration timm associates with this model; it is
# unpacked into `create_transform` and read for 'input_size' in later cells.
data_config = timm.data.resolve_model_data_config(model)
# data_config['input_size'] = (3,224,224)
# data_config['crop_pct'] = 1.0
data_config

In [None]:
# Display only: the transform timm builds with is_training=True (not assigned to anything).
timm.data.create_transform(**data_config, is_training=True)

In [None]:
# Transform built with is_training=False for the timm model; rebinds `transform`.
transform = timm.data.create_transform(**data_config, is_training=False)
transform

In [None]:
# Smoke test: push a random batch of 10 images through the model and show
# the input/output shapes.
model.eval()
# Last two entries of 'input_size', used below as the spatial dims of the batch.
img_size = data_config['input_size'][-2:]
with torch.no_grad():
    x = torch.rand(10,3,*img_size).to(device)
    output = model(x)
x.shape, output.shape

In [None]:
# img_size = cfg["test_input_size"][-1] if "test_input_size" in cfg and cfg["test_input_size"] else cfg["input_size"][-1]
# transform = timm.data.transforms_factory.transforms_imagenet_eval(
#     img_size=img_size,
#     interpolation=cfg["interpolation"],
#     mean=cfg["mean"],
#     std=cfg["std"],
#     crop_pct=cfg["crop_pct"],
#     crop_mode=cfg.get("crop_mode", None)
# )
# transform

In [None]:
# transform = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
# ])
# transform

In [None]:
# Rebuild the assay on the ImageNetV2 matched-frequency split and show its dataset.
cls_assay = ClassificationAccuracy(dataset='imagenetV2', split='matched-frequency')
# cls_assay = ClassificationAccuracy(dataset='imagenet1k_s256', split='val')
cls_assay.dataset

In [None]:
# Run the assay with the current `model`/`transform`; overwrites the earlier `results`.
results = cls_assay(model, transform)
# Top-1 accuracy in percent.
results.correct1.mean() * 100

In [None]:
# Same model/transform on the imagenet1k_s256 validation split; overwrites `results2`.
cls_assay = ClassificationAccuracy(dataset='imagenet1k_s256', split='val')
results2 = cls_assay(model, transform)
# Top-1 accuracy in percent.
results2.correct1.mean() * 100

In [None]:
# Top-1 gap in percentage points: `results` minus `results2`.
# NOTE(review): this is the opposite sign of the earlier gap cell
# (`results2` minus `results`).
results.correct1.mean() * 100 - results2.correct1.mean() * 100

In [None]:
# -13.202

In [None]:
# Display the model's `forward_head` attribute (raises AttributeError if the
# current model does not define one).
model.forward_head

In [None]:
# from cleanvision import Imagelab

# data_path = '/n/alvarez_lab_tier1/Lab/cache/torch/data/imagenetv2-matched-frequency-format-val-5fbc2174/imagenetv2-matched-frequency-format-val'
# imagelab = Imagelab(data_path=data_path)

# imagelab.find_issues()
# imagelab.report()

In [None]:
# Evaluate on the imagenette2_s320_remap1k validation split and show
# top-1 / top-5 accuracy in percent.
cls_assay = ClassificationAccuracy(dataset='imagenette2_s320_remap1k', split='val')
res = cls_assay(model, transform)
res.correct1.mean() * 100, res.correct5.mean() * 100

In [None]:
# Mean top-1 correctness (as a fraction, not percent) per ground-truth label.
res.groupby(by=['correct_label']).correct1.mean()

In [None]:
# These imports rebind `models` and `transforms`, which earlier cells already
# imported from torchvision.
import torchvision.models as models
import torchvision.transforms as transforms

# Load the pretrained AlexNet model
alexnet_weights = models.AlexNet_Weights.DEFAULT
alexnet = models.alexnet(weights=alexnet_weights)

# Get the transforms used for the pretrained model
weights_transform = alexnet_weights.transforms()
# Print the architecture and the preset preprocessing for inspection.
print(alexnet)
print(weights_transform)

In [None]:
# Rebuild the weights' preset preprocessing from its own attributes:
# resize to `resize_size` first, then center-crop to `crop_size`, then
# convert to a tensor and normalize with the preset's mean/std.
# Resize must take `resize_size` and CenterCrop `crop_size`; with the two
# swapped the image is resized to the crop size and then "cropped" to a
# larger size.
transform = transforms.Compose([
    transforms.Resize(weights_transform.resize_size, interpolation=weights_transform.interpolation),
    transforms.CenterCrop(weights_transform.crop_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=weights_transform.mean,std=weights_transform.std)
])
transform