In [1]:
import sys 
sys.path.append('..')
from src.vit import get_image_features, define_model, feature_dim
from src.build_classifier import get_classifier
from src.train_clf import train

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as FT
import torch

In [3]:
sys.path.append('..')
from cifar.cifarRawCorrupted import get_original_loaders, get_corrupt_loaders

In [4]:
device = 'cuda:1'
model = define_model(device=device)

In [5]:
c_loader = get_corrupt_loaders(model_name='vit', severity=1)


In [6]:
u,kt = next(iter(c_loader))

In [7]:
u.shape

torch.Size([64, 3, 224, 224])

In [7]:
get_image_features(model, u.to(device)).shape

torch.Size([64, 192])

In [5]:
vit_clf = get_classifier(feature_dim, output_classes=10, n_layers=1).to(device)
train_loader, val_loader, test_loader = get_original_loaders(batch_size=1024, model_name='vit') 
test_corrupt_loader = get_corrupt_loaders(batch_size=1024, model_name='vit')

Files already downloaded and verified
Files already downloaded and verified


In [6]:
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(vit_clf.parameters(), lr=0.001)
n_epochs = 1

In [7]:
losses, accs, val_losses, val_accs = train(model, vit_clf, optim=optim, loss_fn=loss_fn,
                                           train_loader=train_loader, val_loader=val_loader,
                                           feature_fn=get_image_features, epochs=n_epochs, device=device) #TODO resize im in clip transforms

  input = module(input)


initial loss 2.297089695930481 and initial accuracy 0.10561623424291611
 train loss: 1.576060163974762, val loss: 1.5215713143348695, Train accuracy 0.8961426019668579, val accuracy 0.9419762492179871 


In [8]:
torch.save(vit_clf.state_dict(), '../saved_models/vit_clf_tt.pth')


In [9]:
def get_acc(gt, preds = None):
    if preds is not None: 
        return ((preds.argmax(1)==gt).sum()/len(preds)).cpu().numpy()
        
    
    return ((preds.argmax(1)==gt).sum()/len(preds)).cpu().numpy()
    

def get_test_acc(emb_model, model, test_loader, feature_fn, device='cuda'):
    eval_acc = []
    eval_losses = []
    for eval_batch in test_loader:
        if len(eval_batch)>2:
            _, ims, labels = eval_batch
        else: 
            ims, labels = eval_batch
        ims, labels = ims.to(device), labels.to(device)
        with torch.no_grad():
            features = feature_fn(emb_model, ims).squeeze()
            preds = model(features)
            val_acc = get_acc(labels.view(-1,), preds)
        
        eval_acc.append(val_acc)
    
    return np.mean(eval_acc)
            # 
test_acc_orig = racc =  get_test_acc(model, vit_clf, test_loader, get_image_features, device=device,)
 
print(test_acc_orig)

0.9419384


In [10]:
corrupts_dict = {}
corrupt_g_acc = []
for cr in ['gaussian_noise', 'speckle_noise', 'impulse_noise', 'shot_noise', ]:
    corrupts_dict[cr] = {}
    for sev in [1, 2, 3, 4, 5]:
        test_loader_corrupt = get_corrupt_loaders(batch_size=1024, corruption_type=cr, severity=sev, model_name='vit')
        acc =  get_test_acc(model, vit_clf, test_loader_corrupt, get_image_features, device=device,)
                                

        corrupts_dict[cr][sev]=acc
    print(corrupts_dict)

  input = module(input)


{'gaussian_noise': {1: 0.8157047, 2: 0.69214964, 3: 0.56577843, 4: 0.5023956, 5: 0.4608817}}
{'gaussian_noise': {1: 0.8157047, 2: 0.69214964, 3: 0.56577843, 4: 0.5023956, 5: 0.4608817}, 'speckle_noise': {1: 0.8630401, 2: 0.7723513, 3: 0.7180146, 4: 0.607099, 5: 0.51486963}}
{'gaussian_noise': {1: 0.8157047, 2: 0.69214964, 3: 0.56577843, 4: 0.5023956, 5: 0.4608817}, 'speckle_noise': {1: 0.8630401, 2: 0.7723513, 3: 0.7180146, 4: 0.607099, 5: 0.51486963}, 'impulse_noise': {1: 0.88996935, 2: 0.8419563, 3: 0.7927077, 4: 0.6654018, 5: 0.54996216}}
{'gaussian_noise': {1: 0.8157047, 2: 0.69214964, 3: 0.56577843, 4: 0.5023956, 5: 0.4608817}, 'speckle_noise': {1: 0.8630401, 2: 0.7723513, 3: 0.7180146, 4: 0.607099, 5: 0.51486963}, 'impulse_noise': {1: 0.88996935, 2: 0.8419563, 3: 0.7927077, 4: 0.6654018, 5: 0.54996216}, 'shot_noise': {1: 0.8587971, 2: 0.79785955, 3: 0.6485392, 4: 0.5892957, 5: 0.4926319}}


In [15]:
corrupts_dict #VITb16

{'gaussian_noise': {1: 0.797716,
  2: 0.6710878,
  3: 0.5398816,
  4: 0.47904974,
  5: 0.4321588},
 'speckle_noise': {1: 0.85576975,
  2: 0.75682396,
  3: 0.699432,
  4: 0.5807198,
  5: 0.485792},
 'impulse_noise': {1: 0.8809471,
  2: 0.8299845,
  3: 0.7735471,
  4: 0.6436564,
  5: 0.5237404},
 'shot_noise': {1: 0.8467933,
  2: 0.7868164,
  3: 0.63079756,
  4: 0.5654476,
  5: 0.4617885}}