In [None]:
import torch
import clip
import os

import numpy as np

from src.imagenet_labels import lab_dict
from tqdm.notebook import tqdm
from src.dataloaders import imagenet_c_dataloader, imagenet_dataloader

In [None]:
# Encode images on the GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Load the CLIP ViT-B/32 model together with its matching image preprocessing transform.
# It is loaded on the CPU first because the class prompts are encoded there before the
# model is moved to `device`.
(model, transform) = clip.load(
    "ViT-B/32",
    device="cpu",
)

In [None]:
# Human-readable class names, one per class folder under ../data/imagenet.
# The folder names are SORTED: os.listdir returns entries in arbitrary filesystem order,
# whereas torchvision-style folder datasets assign label index i to the i-th sorted folder.
# Prompt row i must line up with label i for the accuracy numbers to be meaningful.
# Non-directory entries (e.g. stray files) are skipped rather than crashing on lab_dict.
# NOTE(review): assumes src.dataloaders indexes classes by sorted folder name — confirm.
class_labels = [
    lab_dict[folder].replace('_', ' ')
    for folder in sorted(os.listdir('../data/imagenet'))
    if os.path.isdir(os.path.join('../data/imagenet', folder))
]
# Tokenized zero-shot prompts, one row per class, kept on the CPU.
cls_names = torch.cat([clip.tokenize(f"a photo of a {c}") for c in class_labels]).to("cpu")

In [None]:
# Embed every class prompt once, up front, on the CPU.
# Gradients are never needed here. Running under no_grad stops the result from carrying
# an autograd graph through the whole text encoder for the rest of the notebook, since
# model parameters require grad by default.
with torch.no_grad():
    text_features = model.encode_text(cls_names)
    # L2-normalise so a dot product with normalised image features is a cosine similarity.
    text_features /= text_features.norm(dim=-1, keepdim=True)

In [None]:
# Move the CLIP weights to the evaluation device; `text_features` above stays on the CPU.
model = model.to(device=device)

In [None]:
def get_acc(gt, preds = None):
    """Top-1 accuracy of `preds` against integer labels `gt`.

    Args:
        gt: 1-D tensor of ground-truth class indices, on the same device as `preds`.
        preds: tensor of per-class scores with one row per sample; the predicted
            class is the arg-max along dim 1. Required despite the ``None`` default,
            which is kept only for signature compatibility.

    Returns:
        0-d NumPy array holding the fraction of rows whose arg-max equals `gt`.

    Raises:
        ValueError: if `preds` is None.
    """
    if preds is None:
        # Previously this fell through to `None.argmax` and raised an opaque AttributeError.
        raise ValueError("get_acc requires `preds`; got None")
    n_correct = (preds.argmax(1) == gt).sum()
    return (n_correct / len(preds)).cpu().numpy()
    

def get_test_acc(model, loader, device='cuda', class_features=None):
    """Zero-shot top-1 accuracy of a CLIP model over every sample in `loader`.

    Each image is assigned the class whose (normalised) text embedding has the
    highest dot product with the normalised image embedding.

    Args:
        model: CLIP model exposing `encode_image`; expected to already live on `device`.
        loader: iterable of `(images, labels)` batches.
        device: device the images are moved to before encoding.
        class_features: normalised class text embeddings on the CPU, one row per
            class. Defaults to the notebook-level `text_features`.

    Returns:
        Fraction of correctly classified samples as a float (NaN if `loader` is empty).
    """
    if class_features is None:
        class_features = text_features
    n_correct, n_seen = 0, 0
    for ims, labels in tqdm(loader, leave=False):
        labels = labels.to("cpu").view(-1)
        with torch.no_grad():
            image_features = model.encode_image(ims.to(device))
            image_features /= image_features.norm(dim=-1, keepdim=True)
            # Class embeddings live on the CPU, so the similarity is computed there.
            # Softmax is skipped: it is monotonic and cannot change the arg-max.
            logits = image_features.to("cpu") @ class_features.T
        # Accumulate hit counts instead of averaging per-batch accuracies. Averaging
        # would weight a smaller final batch as heavily as a full one.
        n_correct += (logits.argmax(dim=-1) == labels).sum().item()
        n_seen += labels.numel()
    return n_correct / n_seen if n_seen else float('nan')

In [None]:
# Zero-shot accuracy under Gaussian noise, one entry per ImageNet-C severity level 1..5.
gaussian_noise_acc = []
for level in tqdm(range(1, 6)):
    loader = imagenet_c_dataloader(
        corruption_name='gaussian_noise',
        severity=level,
        batch_size=256,
        transform=transform,
    )
    level_acc = get_test_acc(model, loader, device)
    gaussian_noise_acc.append(level_acc)

In [None]:
# Accuracies for Gaussian noise, ordered by severity 1..5.
gaussian_noise_acc

In [None]:
# Zero-shot accuracy under impulse noise, one entry per ImageNet-C severity level 1..5.
impulse_noise_acc = []
for level in tqdm(range(1, 6)):
    loader = imagenet_c_dataloader(
        corruption_name='impulse_noise',
        severity=level,
        batch_size=256,
        transform=transform,
    )
    level_acc = get_test_acc(model, loader, device)
    impulse_noise_acc.append(level_acc)

In [None]:
# Accuracies for impulse noise, ordered by severity 1..5.
impulse_noise_acc

In [None]:
# Zero-shot accuracy under shot noise, one entry per ImageNet-C severity level 1..5.
shot_noise_acc = []
for level in tqdm(range(1, 6)):
    loader = imagenet_c_dataloader(
        corruption_name='shot_noise',
        severity=level,
        batch_size=256,
        transform=transform,
    )
    level_acc = get_test_acc(model, loader, device)
    shot_noise_acc.append(level_acc)

In [None]:
# Accuracies for shot noise, ordered by severity 1..5.
shot_noise_acc

In [None]:
# Zero-shot accuracy under speckle noise, one entry per ImageNet-C severity level 1..5.
speckle_noise_acc = []
for level in tqdm(range(1, 6)):
    loader = imagenet_c_dataloader(
        corruption_name='speckle_noise',
        severity=level,
        batch_size=256,
        transform=transform,
    )
    level_acc = get_test_acc(model, loader, device)
    speckle_noise_acc.append(level_acc)

In [None]:
# Accuracies for speckle noise, ordered by severity 1..5.
speckle_noise_acc

In [None]:
# Baseline: zero-shot accuracy on the uncorrupted ImageNet loader.
loader = imagenet_dataloader(transform=transform, batch_size=256)
clean_acc = get_test_acc(model=model, loader=loader, device=device)

In [None]:
# Clean (uncorrupted) accuracy, the reference point for the corruption sweeps above.
clean_acc