In [1]:
import sys 
sys.path.append('..')
from src.resnet import get_image_features, define_model, feature_dim
from src.build_classifier import get_classifier
from src.train_clf import train

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as FT
import torch

In [3]:
sys.path.append('..')
from cifar.cifarRawCorrupted import get_original_loaders, get_corrupt_loaders

In [4]:
device = 'cuda:0'
model = define_model(device=device)



In [5]:
c_loader = get_corrupt_loaders(model_name='imagebind', severity=1)


In [6]:
u,kt = next(iter(c_loader))

In [7]:
get_image_features(model, u.to(device)).shape

torch.Size([64, 2048])

In [8]:
resnet_clf = get_classifier(feature_dim, output_classes=10, n_layers=1).to(device)
train_loader, val_loader, test_loader = get_original_loaders(batch_size=1024, model_name='resnet') 
test_corrupt_loader = get_corrupt_loaders(batch_size=1024, model_name='resnet')

Files already downloaded and verified
Files already downloaded and verified


In [9]:
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(resnet_clf.parameters(), lr=0.001)
n_epochs = 15

In [10]:
losses, accs, val_losses, val_accs = train(model, resnet_clf, optim=optim, loss_fn=loss_fn,
                                           train_loader=train_loader, val_loader=val_loader,
                                           feature_fn=get_image_features, epochs=n_epochs, device=device) #TODO resize im in clip transforms

  input = module(input)


initial loss 2.3026732921600344 and initial accuracy 0.0623864009976387
 train loss: 1.742794719338417, val loss: 1.627689814567566, Train accuracy 0.760498046875, val accuracy 0.8436185121536255 
 train loss: 1.612250429391861, val loss: 1.618070137500763, Train accuracy 0.85986328125, val accuracy 0.8526087999343872 
 train loss: 1.6003514289855958, val loss: 1.6120728492736816, Train accuracy 0.8692626953125, val accuracy 0.8552913665771484 
 train loss: 1.5913942873477935, val loss: 1.5991934895515443, Train accuracy 0.876782238483429, val accuracy 0.8681361079216003 
 train loss: 1.5837869733572005, val loss: 1.5896828651428223, Train accuracy 0.884326159954071, val accuracy 0.8753547668457031 
 train loss: 1.5784279495477676, val loss: 1.5901500821113586, Train accuracy 0.8881591558456421, val accuracy 0.8752790689468384 
 train loss: 1.5752985626459122, val loss: 1.5899160265922547, Train accuracy 0.890795886516571, val accuracy 0.8753467798233032 
 train loss: 1.569368073344230

In [11]:
torch.save(resnet_clf.state_dict(), '../saved_models/resnet_clf_50.pth')


In [12]:
def get_acc(gt, preds = None):
    if preds is not None: 
        return ((preds.argmax(1)==gt).sum()/len(preds)).cpu().numpy()
        
    
    return ((preds.argmax(1)==gt).sum()/len(preds)).cpu().numpy()
    

def get_test_acc(emb_model, model, test_loader, feature_fn, device='cuda'):
    eval_acc = []
    eval_losses = []
    for eval_batch in test_loader:
        if len(eval_batch)>2:
            _, ims, labels = eval_batch
        else: 
            ims, labels = eval_batch
        ims, labels = ims.to(device), labels.to(device)
        with torch.no_grad():
            features = feature_fn(emb_model, ims).squeeze()
            preds = model(features)
            val_acc = get_acc(labels.view(-1,), preds)
        
        eval_acc.append(val_acc)
    
    return np.mean(eval_acc)
            # 
test_acc_orig = racc =  get_test_acc(model, resnet_clf, test_loader, get_image_features, device=device,)
 
print(test_acc_orig)

0.8799585


In [13]:
corrupts_dict = {}
corrupt_g_acc = []
for cr in ['gaussian_noise', 'speckle_noise', 'impulse_noise', 'shot_noise', ]:
    corrupts_dict[cr] = {}
    for sev in [1, 2, 3, 4, 5]:
        test_loader_corrupt = get_corrupt_loaders(batch_size=1024, corruption_type=cr, severity=sev, model_name='resnet')
        acc =  get_test_acc(model, resnet_clf, test_loader_corrupt, get_image_features, device=device,)
                                

        corrupts_dict[cr][sev]=acc

In [14]:
corrupts_dict

{'gaussian_noise': {1: 0.72055566,
  2: 0.57129705,
  3: 0.41699418,
  4: 0.35500437,
  5: 0.30824298},
 'speckle_noise': {1: 0.7930066,
  2: 0.65925944,
  3: 0.58779496,
  4: 0.4652503,
  5: 0.3694854},
 'impulse_noise': {1: 0.64765424,
  2: 0.45254502,
  3: 0.36398277,
  4: 0.2950833,
  5: 0.26612923},
 'shot_noise': {1: 0.79162747,
  2: 0.69628704,
  3: 0.50337017,
  4: 0.43933955,
  5: 0.3334443}}