In [1]:
from models.model import ResNet50, ViTBase, CLIPModel
from models.utils import train_step, eval_step, DataLoaders, CIFAR100_Splits,DataLoader
from models.transforms import transforms_resnet,transforms_clip_vit
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10, CIFAR100
from tqdm import tqdm
import pandas as pd

### Train all models on CIFAR10 to prepare them for the concept-shift evaluation on the CIFAR100 splits

In [None]:
# Hyperparameters reused by the later training and evaluation cells.
BATCH_SIZE = 64
N_EPOCHS = 3

# CIFAR10 train and test sets, both run through the ResNet transform pipeline.
train = CIFAR10(root='data', train=True, transform=transforms_resnet, download=True)
test = CIFAR10(root='data', train=False, transform=transforms_resnet, download=True)

# Wrap both sets in the project's DataLoaders helper and unpack the two loaders.
# NOTE(review): the positional `True` is presumably a shuffle flag - confirm in models.utils.
dl = DataLoaders(train, test, 'resnet', BATCH_SIZE, True, 'cifar10')
train_loader, test_loader = dl.get_loaders()

def train_and_eval(train_loader,test_loader,model,loss_fn,optimizer,device,modeltype):
    """Run N_EPOCHS rounds of training, evaluating after every round.

    Each epoch calls ``train_step`` on ``train_loader`` and then ``eval_step``
    on ``test_loader`` (with ``data="cifar10"``), recording the loss and
    accuracy that each call returns.

    Args:
        train_loader: loader handed to ``train_step``.
        test_loader: loader handed to ``eval_step``.
        model: model being trained and evaluated.
        loss_fn: loss callable forwarded to both steps.
        optimizer: optimizer forwarded to ``train_step``.
        device: device identifier forwarded to both steps.
        modeltype: model-type tag forwarded to both steps.

    Returns:
        A ``(train_metrics, test_metrics)`` tuple. Each element is a dict with
        the keys ``"Accuracy"`` and ``"Loss"``, holding one value per epoch.
    """
    train_history = {"Accuracy": [], "Loss": []}
    test_history = {"Accuracy": [], "Loss": []}

    # The epoch count comes from the notebook-level N_EPOCHS constant.
    for _epoch in tqdm(range(N_EPOCHS)):
        train_loss, train_acc = train_step(
            model, train_loader, loss_fn, optimizer, device, modeltype
        )
        test_loss, test_acc = eval_step(
            model, test_loader, loss_fn, device, modeltype, data="cifar10"
        )

        # Record this epoch's numbers for both splits.
        epoch_results = (
            (train_history, train_acc, train_loss),
            (test_history, test_acc, test_loss),
        )
        for history, accuracy, loss in epoch_results:
            history["Accuracy"].append(accuracy)
            history["Loss"].append(loss)

    return train_history, test_history

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ResNet50 configured for 10 classes, trained with cross-entropy and Adam
# (default optimizer settings).
resnet_model = ResNet50(n_classes=10)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet_model.parameters())

tr, ts = train_and_eval(
    train_loader, test_loader, resnet_model, loss_func, optimizer, device, 'resnet50'
)

# Report the metrics recorded for the last epoch.
for metric_label, metric_value in (
    ("Final Train Accuracy:", tr["Accuracy"][-1]),
    ("Final Test Accuracy:", ts["Accuracy"][-1]),
    ("Final Train Loss:", tr["Loss"][-1]),
    ("Final Test Loss:", ts["Loss"][-1]),
):
    print(metric_label, metric_value)

In [None]:
# Train a ViT-Base classifier on CIFAR10 and report its final metrics.
#
# Both splits go through the same transform pipeline. The test set previously
# used `transforms_resnet` while the training set used `transforms_clip_vit`,
# so the model was evaluated on inputs preprocessed differently from the ones
# it was trained on.
train = CIFAR10(root='data', train=True, transform=transforms_clip_vit, download=True)
test = CIFAR10(root='data', train=False, transform=transforms_clip_vit, download=True)
dl = DataLoaders(train, test, 'vit', BATCH_SIZE, True, 'cifar10')
train_loader, test_loader = dl.get_loaders()

# ViT-Base configured for 10 classes, trained with cross-entropy and Adam
# (default optimizer settings).
vit_model = ViTBase(n_classes=10)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(vit_model.parameters())
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


tr, ts = train_and_eval(train_loader, test_loader, vit_model, loss_func, optimizer, device, 'vit')

# Metrics recorded for the last epoch.
print("Final Train Accuracy:", tr["Accuracy"][-1])
print("Final Test Accuracy:", ts["Accuracy"][-1])
print("Final Train Loss:", tr["Loss"][-1])
print("Final Test Loss:", ts["Loss"][-1])

In [None]:
# CIFAR10 train and test sets with no torchvision transform attached.
# NOTE(review): presumably the 'clip' path of DataLoaders does its own
# preprocessing - confirm in models.utils.
train = CIFAR10(root='data', train=True, download=True)
test = CIFAR10(root='data', train=False, download=True)
dl = DataLoaders(train, test, 'clip', BATCH_SIZE, True, 'cifar10')
train_loader, test_loader = dl.get_loaders()

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): 256 is passed positionally; its meaning is defined by
# CLIPModel in models.model - confirm and consider a named constant.
clip_model = CLIPModel(256)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(clip_model.parameters())

tr, ts = train_and_eval(
    train_loader, test_loader, clip_model, loss_func, optimizer, device, 'clip'
)

# Report the metrics recorded for the last epoch.
for metric_label, metric_value in (
    ("Final Train Accuracy:", tr["Accuracy"][-1]),
    ("Final Test Accuracy:", ts["Accuracy"][-1]),
    ("Final Train Loss:", tr["Loss"][-1]),
    ("Final Test Loss:", ts["Loss"][-1]),
):
    print(metric_label, metric_value)

### Load in the CIFAR100 dataset and make the splits according to the paper

In [None]:
# CIFAR100 test set; only its raw `.data` and `.targets` are used below.
test = CIFAR100(root='data', train=False, download=True)

# One list of split datasets per model group, filled in the order
# split index (outer loop) then group (inner loop).
groups_datasets = {name: [] for name in ('resnet', 'vit', 'clip')}

for i in range(10):
    for group in groups_datasets:
        split_dataset = CIFAR100_Splits(test.data, test.targets, i, group)
        groups_datasets[group].append(split_dataset)

In [None]:
# Evaluate each trained model on every dataset in its `groups_datasets` entry
# and write the accuracies to a CSV with one column per model group.

# Lookup table instead of an if/elif chain for choosing the model of a group.
models_by_group = {'resnet': resnet_model, 'vit': vit_model, 'clip': clip_model}
loss_func = nn.CrossEntropyLoss()

# Each group's results must be a list: the loop appends one accuracy per
# dataset. The values were previously dicts, so `.append` raised
# AttributeError on the first iteration.
metrics = {'resnet': [], 'vit': [], 'clip': []}
for group in groups_datasets.keys():
    model = models_by_group[group]
    for dataset in groups_datasets[group]:
        dl = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
        # NOTE(review): eval_step is told data="cifar10" although these are
        # CIFAR100 split datasets - confirm this is intended.
        ts_loss, ts_acc = eval_step(model, dl, loss_func, device, group, data="cifar10")
        metrics[group].append(ts_acc)

pd.DataFrame(metrics).to_csv('cifar100_splits_test_result.csv')

In [None]:
# CIFAR100 training set; only its raw `.data` and `.targets` are used below.
train = CIFAR100(root='data', train=True, download=True)

# Rebuild the per-group lists of split datasets, this time from the training
# data, in the order split index (outer loop) then group (inner loop).
groups_datasets = {name: [] for name in ('resnet', 'vit', 'clip')}

for i in range(10):
    for group in groups_datasets:
        split_dataset = CIFAR100_Splits(train.data, train.targets, i, group)
        groups_datasets[group].append(split_dataset)

In [None]:
# Evaluate each trained model on every dataset in its `groups_datasets` entry
# and write the accuracies to a CSV with one column per model group.

# Lookup table instead of an if/elif chain for choosing the model of a group.
models_by_group = {'resnet': resnet_model, 'vit': vit_model, 'clip': clip_model}
loss_func = nn.CrossEntropyLoss()

# Each group's results must be a list: the loop appends one accuracy per
# dataset. The values were previously dicts, so `.append` raised
# AttributeError on the first iteration.
metrics = {'resnet': [], 'vit': [], 'clip': []}
for group in groups_datasets.keys():
    model = models_by_group[group]
    for dataset in groups_datasets[group]:
        dl = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
        # NOTE(review): eval_step is told data="cifar10" although these are
        # CIFAR100 split datasets - confirm this is intended.
        ts_loss, ts_acc = eval_step(model, dl, loss_func, device, group, data="cifar10")
        metrics[group].append(ts_acc)

pd.DataFrame(metrics).to_csv('cifar100_splits_train_result.csv')