In [1]:
import sys 
sys.path.append('..')
from src.clip import get_image_features, define_model, feature_dim
from src.build_classifier import get_classifier
from src.train_clf import train

In [2]:
import clip
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as FT
import torch

In [4]:
sys.path.append('..')
from cifar.cifarRawCorrupted import get_original_loaders, get_corrupt_loaders

In [5]:
device = 'cuda:1'
model, _ = define_model(device=device)

In [6]:
c_loader = get_corrupt_loaders(model_name='clip', severity=1)

In [7]:
u,kt = next(iter(c_loader))

In [8]:
u.shape

torch.Size([64, 3, 288, 288])

In [9]:
get_image_features(model, u.to(device)).shape

torch.Size([64, 640])

In [10]:
clip_clf = get_classifier(640, output_classes=10, n_layers=1).to(device)
train_loader, val_loader, test_loader = get_original_loaders(batch_size=1024, model_name='clip') 
test_corrupt_loader = get_corrupt_loaders(batch_size=1024, model_name='clip')

Files already downloaded and verified
Files already downloaded and verified


In [11]:
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(clip_clf.parameters(), lr=0.001)
n_epochs = 4

In [12]:
losses, accs, val_losses, val_accs = train(model, clip_clf, optim=optim, loss_fn=loss_fn,
                                           train_loader=train_loader, val_loader=val_loader,
                                           feature_fn=get_image_features, epochs=n_epochs, device=device) #TODO resize im in clip transforms

  input = module(input)


initial loss 2.302623963356018 and initial accuracy 0.0981186181306839
 train loss: 2.0446340799331666, val loss: 1.7381766200065614, Train accuracy 0.6327880620956421, val accuracy 0.8375996351242065 
 train loss: 1.6625311434268952, val loss: 1.6256305932998658, Train accuracy 0.857861340045929, val accuracy 0.8709203600883484 
 train loss: 1.6100441753864287, val loss: 1.60384703874588, Train accuracy 0.880444347858429, val accuracy 0.8813077807426453 
 train loss: 1.5956449300050735, val loss: 1.5912055134773255, Train accuracy 0.8869873285293579, val accuracy 0.8912906646728516 


In [14]:
torch.save(clip_clf.state_dict(), '../saved_models/clip_clf_resnet_4.pth')


In [15]:
def get_acc(gt, preds = None):
    if preds is not None: 
        return ((preds.argmax(1)==gt).sum()/len(preds)).cpu().numpy()
        
    
    return ((preds.argmax(1)==gt).sum()/len(preds)).cpu().numpy()
    

def get_test_acc(emb_model, model, test_loader, feature_fn, device='cuda'):
    eval_acc = []
    eval_losses = []
    for eval_batch in test_loader:
        if len(eval_batch)>2:
            _, ims, labels = eval_batch
        else: 
            ims, labels = eval_batch
        ims, labels = ims.to(device), labels.to(device)
        with torch.no_grad():
            features = feature_fn(emb_model, ims).squeeze()
            preds = model(features)
            val_acc = get_acc(labels.view(-1,), preds)
        
        eval_acc.append(val_acc)
    
    return np.mean(eval_acc)
            # 
test_acc_orig = racc =  get_test_acc(model, clip_clf, test_loader, get_image_features, device=device,)
 
print(test_acc_orig)

0.88372725


In [20]:
corrupts_dict = {}
corrupt_g_acc = []
for cr in ['gaussian_noise', 'speckle_noise', 'impulse_noise', 'shot_noise', ]:
    corrupts_dict[cr] = {}
    for sev in [1, 2, 3, 4, 5]:
        test_loader_corrupt = get_corrupt_loaders(batch_size=1024, corruption_type=cr, severity=sev, model_name='clip')
        acc =  get_test_acc(model, clip_clf, test_loader_corrupt, get_image_features, device=device,)
                                

        corrupts_dict[cr][sev]=acc

In [21]:
corrupts_dict

{'gaussian_noise': {1: 0.7270528,
  2: 0.5826969,
  3: 0.44952568,
  4: 0.3876395,
  5: 0.33660913},
 'speckle_noise': {1: 0.7966956,
  2: 0.6757573,
  3: 0.6117407,
  4: 0.49331355,
  5: 0.39814055},
 'impulse_noise': {1: 0.7782346,
  2: 0.700562,
  3: 0.6367865,
  4: 0.51759607,
  5: 0.42521125},
 'shot_noise': {1: 0.78822345,
  2: 0.71658164,
  3: 0.53617865,
  4: 0.4664541,
  5: 0.36796674}}

In [15]:
corrupts_dict # Resnet

{'gaussian_noise': {1: 0.66822386,
  2: 0.46558714,
  3: 0.28862005,
  4: 0.23345824,
  5: 0.18863998},
 'speckle_noise': {1: 0.7678073,
  2: 0.61615515,
  3: 0.52337176,
  4: 0.36995178,
  5: 0.25821707},
 'impulse_noise': {1: 0.77613604,
  2: 0.6875359,
  3: 0.59512913,
  4: 0.43106666,
  5: 0.30344787},
 'shot_noise': {1: 0.76686865,
  2: 0.6564692,
  3: 0.41788703,
  4: 0.3300522,
  5: 0.22840402}}

In [15]:
corrupts_dict # vitb16

{'gaussian_noise': {1: 0.8287767,
  2: 0.68462014,
  3: 0.52642095,
  4: 0.45319477,
  5: 0.39358658},
 'speckle_noise': {1: 0.8855748,
  2: 0.7826431,
  3: 0.71402866,
  4: 0.575851,
  5: 0.4634108},
 'impulse_noise': {1: 0.92388994,
  2: 0.8797334,
  3: 0.8381437,
  4: 0.7255361,
  5: 0.60639346},
 'shot_noise': {1: 0.881543,
  2: 0.8156689,
  3: 0.6336296,
  4: 0.5540637,
  5: 0.42462334}}

In [16]:
corrupts_dict

{'gaussian_noise': {1: 0.8200175,
  2: 0.6643335,
  3: 0.5043985,
  4: 0.42891026,
  5: 0.37114358},
 'speckle_noise': {1: 0.8763732,
  2: 0.7686045,
  3: 0.69716597,
  4: 0.55485094,
  5: 0.4382573},
 'impulse_noise': {1: 0.92210215,
  2: 0.8777124,
  3: 0.8338388,
  4: 0.7114158,
  5: 0.5900191},
 'shot_noise': {1: 0.8765625,
  2: 0.803388,
  3: 0.6125757,
  4: 0.53303176,
  5: 0.39992028}}