In [1]:
import torch
import os

import numpy as np

from src.imagebind import define_model, get_transform
from imagebind import data
from tqdm.notebook import tqdm
from src.imagenet_labels import lab_dict
from imagebind.models.imagebind_model import ModalityType
from src.dataloaders import imagenet_dataloader, imagenet_c_dataloader



In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [3]:
# Build the model on the selected device and fetch the matching image transform.
# NOTE(review): the recorded run of this cell failed with a torch.distributed
# "env://" rendezvous error (environment variable RANK not set). Presumably
# define_model initializes a process group — confirm in src.imagebind and either
# export the distributed env vars or use a single-process path before re-running.
model = define_model(device)
transform = get_transform()

ValueError: Error initializing torch.distributed using env:// rendezvous: environment variable RANK expected, but not set

In [None]:
# One text prompt ("a <class name>") per entry found under ../data/imagenet:
# each entry name is mapped through lab_dict and underscores become spaces.
# NOTE(review): os.listdir returns entries in arbitrary order; presumably the
# prompt order must line up with the label indices the dataloaders emit —
# confirm against src.dataloaders (and sort here if they use sorted order).
cls_names = [
    f"a {lab_dict[folder].replace('_', ' ')}"
    for folder in os.listdir('../data/imagenet')
]
text = data.load_and_transform_text(cls_names, device)

In [None]:
# Embed every class prompt once, without tracking gradients; the resulting
# text_features are reused for all image batches scored later on.
inputs = {ModalityType.TEXT: text}
with torch.no_grad():
    embeddings = model(inputs)
text_features = embeddings[ModalityType.TEXT]

In [None]:
def get_acc(gt, preds=None):
    """Return the top-1 accuracy of `preds` measured against `gt`.

    Args:
        gt: ground-truth class indices, compared element-wise with the
            arg-max of `preds` (assumes a 1-D torch tensor of length N, as
            passed by get_test_acc — TODO confirm for other callers).
        preds: per-class scores; the predicted class is `preds.argmax(1)`.
            The parameter keeps its `None` default for backward
            compatibility, but a value is required.

    Returns:
        The fraction of matching predictions, as produced by
        `.cpu().numpy()` on the resulting tensor.

    Raises:
        ValueError: if `preds` is None. Previously this path fell through to
            a duplicate `preds.argmax(...)` call and failed with an opaque
            AttributeError.
    """
    if preds is None:
        raise ValueError("get_acc() requires `preds`; got None")
    n_correct = (preds.argmax(1) == gt).sum()
    return (n_correct / len(preds)).cpu().numpy()
    

def get_test_acc(model, loader, device='cuda'):
    """Zero-shot top-1 accuracy of `model` over every batch in `loader`.

    Each image batch is embedded through the VISION modality and scored
    against the notebook-level `text_features`; the predicted class is the
    row of `text_features` with the highest score.

    Args:
        model: callable that takes a ``{ModalityType: tensor}`` dict and
            returns a dict of embeddings keyed by modality.
        loader: iterable yielding ``(ims, labels)`` pairs. `ims` may be an
            already-batched tensor or a list/tuple of per-image tensors
            (the loader's batch format is not visible here, so both are
            accepted).
        device: device the batch is moved to before the forward pass.

    Returns:
        Accuracy over all samples as a NumPy float, with each batch weighted
        by its size so a short final batch is not over-counted. NaN when the
        loader yields no samples.
    """
    batch_accs, batch_sizes = [], []
    for ims, labels in tqdm(loader, leave=False):
        # torch.stack needs a sequence of tensors and a list has no `.to`,
        # so stack first (only when needed) and move to the device after.
        if isinstance(ims, (list, tuple)):
            ims = torch.stack(ims, dim=0)
        ims, labels = ims.to(device), labels.to(device)

        with torch.no_grad():
            embeddings = model({ModalityType.VISION: ims})
        image_features = embeddings[ModalityType.VISION]
        probs = torch.softmax(image_features @ text_features.T, dim=-1)

        batch_accs.append(get_acc(labels.view(-1,), probs))
        batch_sizes.append(len(probs))

    if sum(batch_sizes) == 0:
        return np.float64(np.nan)
    # Weight by batch size: a plain mean of per-batch accuracies is biased
    # whenever the last batch is smaller than the others.
    return np.average(batch_accs, weights=batch_sizes)

In [None]:
# Accuracy on the ImageNet-C 'gaussian_noise' corruption, one entry per
# severity level from 1 to 5.
gaussian_noise_acc = []
for sev in tqdm(range(1, 6)):
    loader = imagenet_c_dataloader(
        corruption_name='gaussian_noise',
        severity=sev,
        batch_size=256,
        transform=transform,
    )
    sev_acc = get_test_acc(model, loader, device)
    gaussian_noise_acc.append(sev_acc)

In [None]:
# Display the gaussian-noise accuracies, ordered by severity 1 through 5.
gaussian_noise_acc

In [None]:
# Accuracy on the ImageNet-C 'impulse_noise' corruption, one entry per
# severity level from 1 to 5.
impulse_noise_acc = []
for sev in tqdm(range(1, 6)):
    loader = imagenet_c_dataloader(
        corruption_name='impulse_noise',
        severity=sev,
        batch_size=256,
        transform=transform,
    )
    sev_acc = get_test_acc(model, loader, device)
    impulse_noise_acc.append(sev_acc)

In [None]:
# Display the impulse-noise accuracies, ordered by severity 1 through 5.
impulse_noise_acc

In [None]:
# Accuracy on the ImageNet-C 'shot_noise' corruption, one entry per
# severity level from 1 to 5.
shot_noise_acc = []
for sev in tqdm(range(1, 6)):
    loader = imagenet_c_dataloader(
        corruption_name='shot_noise',
        severity=sev,
        batch_size=256,
        transform=transform,
    )
    sev_acc = get_test_acc(model, loader, device)
    shot_noise_acc.append(sev_acc)

In [None]:
# Display the shot-noise accuracies, ordered by severity 1 through 5.
shot_noise_acc

In [None]:
# Accuracy on the ImageNet-C 'speckle_noise' corruption, one entry per
# severity level from 1 to 5.
speckle_noise_acc = []
for sev in tqdm(range(1, 6)):
    loader = imagenet_c_dataloader(
        corruption_name='speckle_noise',
        severity=sev,
        batch_size=256,
        transform=transform,
    )
    sev_acc = get_test_acc(model, loader, device)
    speckle_noise_acc.append(sev_acc)

In [None]:
# Display the speckle-noise accuracies, ordered by severity 1 through 5.
speckle_noise_acc

In [None]:
# Baseline: accuracy on the uncorrupted ImageNet loader, using the same batch
# size and transform as the ImageNet-C sweeps.
loader = imagenet_dataloader(
    batch_size=256,
    transform=transform,
)
clean_acc = get_test_acc(model, loader, device)

In [None]:
# Display the accuracy measured on the uncorrupted ImageNet loader.
clean_acc