In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset, ConcatDataset
import numpy as np
import matplotlib.pyplot as plt
import yaml
from tqdm import tqdm
import torchvision.transforms.v2 as v2
from copy import deepcopy
import sys
sys.path.append("/n/home11/sambt/phlab-neurips25")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

from models.litmodels import SimCLRModel
from models.networks import CustomResNet, MLP
from data.datasets import CIFAR10Dataset
from data.cifar import CIFAR5MDataset
import data.data_utils as dutils

from sklearn.metrics import roc_auc_score, top_k_accuracy_score
from utils.plotting import make_corner

# Evaluate the pre-trained model and train a classifier on its embeddings

In [None]:
# CIFAR-10 train/test splits with ResNet-50 preprocessing.
cifar = CIFAR10Dataset("resnet50", num_workers=2, batch_size=1024)
cifar_train_dataset, cifar_test_dataset = cifar.train_dataset, cifar.test_dataset

# Grayscale CIFAR-5m slice (args [0], [(None, 50_000)]) used as the shifted domain.
cifar5m_full = CIFAR5MDataset("resnet50", [0], [(None, 50_000)], grayscale=True)

# Pre-trained SimCLR model, moved to `device` and put in inference mode.
model = (
    SimCLRModel.load_from_checkpoint('/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.5/lightning_logs/uj88ngsb/checkpoints/epoch=14-step=735.ckpt')
    .to(device)
    .eval()
)

In [None]:
def _encode_pretrained(dataset, batch_size=1024):
    """Embed every sample of `dataset`, in order, with the pre-trained `model.encoder`.

    Returns a pair of numpy arrays (embeddings, labels), each concatenated over batches.
    """
    embed_chunks, label_chunks = [], []
    for ims, labs in tqdm(DataLoader(dataset, batch_size=batch_size, shuffle=False)):
        with torch.no_grad():
            embed_chunks.append(model.encoder(ims.to(device)).cpu().numpy())
            label_chunks.append(labs.numpy())
    return np.concatenate(embed_chunks), np.concatenate(label_chunks)


cifar_embeds, cifar_labels = _encode_pretrained(cifar_test_dataset)
cifar_train_embeds, cifar_train_labels = _encode_pretrained(cifar_train_dataset)
cifar5m_embeds, cifar5m_labels = _encode_pretrained(cifar5m_full)

In [None]:
# Corner plot of the pre-trained CIFAR-10 test-set embeddings together with their class labels.
make_corner(cifar_embeds,cifar_labels)

In [None]:
# Overlay the class-`sel_label` embeddings of CIFAR-10 (test) and CIFAR-5m to visualise the domain shift.
sel_label = 0
regular = cifar_embeds[cifar_labels == sel_label]
shift = cifar5m_embeds[cifar5m_labels == sel_label]
make_corner(np.vstack([regular, shift]),
            labels=np.repeat([0.0, 1.0], [len(regular), len(shift)]),
            label_names={0:"cifar10",1:"cifar5m"})

In [None]:
from models.networks import MLP

# Small classifier trained on the frozen pre-trained embeddings (4-D input, 10 classes).
classifier = MLP(4,[10,10],10,dropout=0.2,activation='tanh')
classifier = classifier.to(device)
optimizer = torch.optim.AdamW(classifier.parameters(),lr=1e-3)

loader = DataLoader(TensorDataset(torch.tensor(cifar_train_embeds),torch.tensor(cifar_train_labels)),batch_size=512,shuffle=True,num_workers=2)
# Order is irrelevant for validation; not shuffling keeps the epoch loss deterministic.
val_loader = DataLoader(TensorDataset(torch.tensor(cifar_embeds),torch.tensor(cifar_labels)),batch_size=512,shuffle=False,num_workers=2)

n_epoch = 20
best_state = None
best_loss = float('inf')
pbar = tqdm(range(n_epoch))
train_losses = []
val_losses = []
for i in pbar:
    # Training pass: dropout active.
    classifier.train()
    losses = []
    for x,y in loader:
        out = classifier(x.to(device))
        loss = F.cross_entropy(out,y.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    train_losses.append(np.mean(losses))

    # Validation pass: eval mode disables dropout so the loss reflects the model used at inference.
    classifier.eval()
    losses = []
    with torch.no_grad():
        for x,y in val_loader:
            out = classifier(x.to(device))
            loss = F.cross_entropy(out,y.to(device))
            losses.append(loss.item())
    losses = np.mean(losses)
    val_losses.append(losses)
    if losses < best_loss:
        best_loss = losses
        # state_dict() returns references to the live parameters; copy them so later
        # optimizer steps cannot overwrite the "best" snapshot.
        best_state = deepcopy(classifier.state_dict())

    pbar.set_postfix_str(f"train:{train_losses[-1]:.5f}, val:{val_losses[-1]:.5f}")

classifier.load_state_dict(best_state)
classifier.eval()

plt.figure(figsize=(8,6))
x = np.arange(1,n_epoch+1)
plt.plot(x,train_losses,label='train')
plt.plot(x,val_losses,label='val')
plt.xlabel('epoch')
plt.ylabel('cross-entropy loss')
plt.legend()

In [None]:
classifier.eval()


def _classify(embeds, chunk_size=4096):
    """Run `classifier` over `embeds` in chunks; return (logits, softmax probabilities) as numpy arrays."""
    logit_chunks, prob_chunks = [], []
    for chunk in tqdm(torch.split(torch.tensor(embeds), chunk_size)):
        logits = classifier(chunk.to(device)).cpu()
        logit_chunks.append(logits.numpy())
        prob_chunks.append(F.softmax(logits, dim=1).numpy())
    return np.concatenate(logit_chunks), np.concatenate(prob_chunks)


with torch.no_grad():
    preds_cifar, probs_cifar = _classify(cifar_embeds)
    preds_cifar5m, probs_cifar5m = _classify(cifar5m_embeds)

In [None]:
print("Embedding space classifier metrics for CIFAR10 test set")
# Multi-class ROC AUC from the softmax probabilities: one-vs-rest, then one-vs-one.
for auc_header, auc_mode in (("OVR auc = ", 'ovr'), ("OVO auc = ", 'ovo')):
    auc = roc_auc_score(cifar_labels, probs_cifar, multi_class=auc_mode)
    print(auc_header, auc)
# Top-k accuracy for k = 1..5.
for k in range(1, 6):
    topk = top_k_accuracy_score(cifar_labels, probs_cifar, k=k)
    print(f"Top {k} acc = ", topk)

In [None]:
print("Embedding space classifier metrics for CIFAR5m tiny set")
# Multi-class ROC AUC from the softmax probabilities: one-vs-rest, then one-vs-one.
for auc_header, auc_mode in (("OVR auc = ", 'ovr'), ("OVO auc = ", 'ovo')):
    auc = roc_auc_score(cifar5m_labels, probs_cifar5m, multi_class=auc_mode)
    print(auc_header, auc)
# Top-k accuracy for k = 1..5.
for k in range(1, 6):
    topk = top_k_accuracy_score(cifar5m_labels, probs_cifar5m, k=k)
    print(f"Top {k} acc = ", topk)

In [None]:
# Classifier confidence (max softmax probability) for each CIFAR-5m sample.
top_probs_cifar5m = probs_cifar5m.max(axis=1)
plt.figure(figsize=(8,6))
h, bins, _ = plt.hist(top_probs_cifar5m, bins=np.linspace(0,1,100), histtype='step', density=False)

In [None]:
# Empirical CDF of the confidence histogram from the previous cell, evaluated at bin centres.
bin_centers = 0.5 * (bins[1:] + bins[:-1])
plt.plot(bin_centers, np.cumsum(h) / h.sum())

In [None]:
# Keep only CIFAR-5m samples the classifier is confident about (max softmax prob > 0.5),
# then split them with random_split(0.8) (presumably 80% train / 20% test) for fine-tuning.
confident_mask = top_probs_cifar5m > 0.5
cifar5m_finetune_dset = cifar5m_full.subselection(confident_mask)
cifar5m_finetune_train, cifar5m_finetune_test = cifar5m_finetune_dset.random_split(0.8)

In [None]:
# Normalised class-label distributions of the fine-tuning train and test splits.
# NOTE(review): assumes `.dataset[1]` holds the class labels — confirm against CIFAR5MDataset.
for finetune_split in (cifar5m_finetune_train, cifar5m_finetune_test):
    h = plt.hist(finetune_split.dataset[1], bins=np.arange(-0.5,10.5), histtype='step', density=True)

# Fine-tune v1: backbone only

In [None]:
from models.finetune import FineTuner
from models.revgrad import GradientReversal

# Rebuild encoder and projector from the same pre-trained SimCLR checkpoint loaded into `model`,
# so fine-tuning starts from the pre-trained weights without modifying `model` itself.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
weights = torch.load('/n/home11/sambt/phlab-neurips25/runs/cifar10_simCLR_ResNet50_T0.5/lightning_logs/uj88ngsb/checkpoints/epoch=14-step=735.ckpt',
                    map_location='cpu')
sd = weights['state_dict']
# Split the Lightning state dict by substring match on the key and strip the sub-module prefix.
encoder_weights = {k.replace("encoder.",""):v for k,v in sd.items() if "encoder" in k}
projector_weights = {k.replace("projector.",""):v for k,v in sd.items() if "projector" in k}

# Backbone; constructor args must match the checkpoint's architecture for the strict load to succeed.
encoder = CustomResNet("resnet50",[512,256,128],4)
encoder.load_state_dict(encoder_weights)
# Make every encoder parameter trainable for fine-tuning.
for param in encoder.parameters():
    param.requires_grad = True

projector = MLP(4,[4],4)
projector.load_state_dict(projector_weights)

# Domain classifier (one logit) on the 4-D embedding, placed behind a gradient-reversal layer;
# presumably this pushes the encoder towards domain-invariant embeddings — confirm in models.revgrad.
corrector = MLP(4,[16,4],1,activation='relu')
corrector = nn.Sequential(GradientReversal(alpha=1.0),corrector)
#corrector = None

tuner = FineTuner(encoder,projector,corrector).to(device)

In [None]:
# Embeddings of the held-out sets BEFORE fine-tuning: baseline for the comparisons below.
tuner = tuner.eval()

cifar_embeds_pretune, cifar_labels_pretune = [], []
with torch.no_grad():
    for batch in tqdm(DataLoader(cifar_test_dataset,batch_size=1024,shuffle=False,num_workers=2)):
        x, labels = batch
        cifar_embeds_pretune.append(tuner.encoder(x.to(device)).cpu().numpy())
        cifar_labels_pretune.append(labels.numpy())
cifar_embeds_pretune = np.concatenate(cifar_embeds_pretune)
cifar_labels_pretune = np.concatenate(cifar_labels_pretune)

cifar5m_test_embed_pretune, cifar5m_test_labels_pretune = [], []
with torch.no_grad():
    for batch in tqdm(DataLoader(cifar5m_finetune_test,batch_size=1024,shuffle=False,num_workers=2)):
        x, labels = batch
        cifar5m_test_embed_pretune.append(tuner.encoder(x.to(device)).cpu().numpy())
        cifar5m_test_labels_pretune.append(labels.numpy())
cifar5m_test_embed_pretune = np.concatenate(cifar5m_test_embed_pretune)
cifar5m_test_labels_pretune = np.concatenate(cifar5m_test_labels_pretune)

In [None]:
# Release leftovers of a previous fine-tuning run before (re)training, then free cached GPU memory.
# Only names that currently exist are removed, so this cell also works on a fresh kernel, where the
# training-loop temporaries are not defined yet. `tuner` is deliberately kept: the next cell builds
# its optimizer from `tuner.parameters()`.
for _stale_name in ("x", "dset_label", "batch", "labels", "domain_labels", "preds", "pos_mask", "optimizer", "loss"):
    globals().pop(_stale_name, None)
del _stale_name
torch.cuda.empty_cache()

In [None]:
from models.losses import SupervisedSimCLRLoss

num_epoch = 10
# Early-stopping patience in epochs.
# NOTE(review): larger than num_epoch, so early stopping can never trigger as configured.
patience_thresh = 100
criterion = SupervisedSimCLRLoss(temperature=0.5)
# Optimises encoder, projector and corrector together (all of tuner.parameters()).
optimizer = torch.optim.AdamW(tuner.parameters(),lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=num_epoch,eta_min=1e-5)

best_state = None
best_loss = float('inf')
patience = 0
train_losses = []
val_losses = []
train_simclr_losses = []
train_class_losses = []
val_simclr_losses = []
val_class_losses = []

# Batches mix CIFAR-10 (dataset label 0) and the confident CIFAR-5m subset (dataset label 1);
# each item unpacks as ((image, class_label), dataset_label).
train_loader = DataLoader(dutils.ConcatWithLabels([cifar_train_dataset,cifar5m_finetune_train],[0,1]),
                          batch_size=512,shuffle=True,num_workers=2,drop_last=True)
val_loader = DataLoader(dutils.ConcatWithLabels([cifar_test_dataset,cifar5m_finetune_test],[0,1]),
                        batch_size=512,shuffle=True,num_workers=2,drop_last=True)

lambda_class = 1.0  # weight of the domain-classification loss relative to the SimCLR loss
for i in range(num_epoch):
    # Training mode for the optimisation pass (BatchNorm batch statistics, dropout on).
    tuner = tuner.train()
    losses = []
    losses_simclr = []
    losses_class = []
    for batch in tqdm(train_loader):
        x,dset_label = batch
        x,labels = x
        h = tuner.encoder(x.to(device))
        preds = tuner.corrector(h)  # one domain logit per sample
        domain_labels = (dset_label==1).float().to(device).unsqueeze(1)
        # (B, B) positive-pair mask: same class label AND same domain.
        pos_mask = (labels.unsqueeze(1) == labels.unsqueeze(1).T).to(device) & (domain_labels == domain_labels.T)
        z = tuner.projector(h)
        z = F.normalize(z,dim=1).unsqueeze(1) # normalize the projection for simclr loss
        loss_simclr = criterion(z, mask=pos_mask)
        loss_class = F.binary_cross_entropy_with_logits(preds,domain_labels)
        loss = loss_simclr + lambda_class*loss_class
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        losses_simclr.append(loss_simclr.item())
        losses_class.append(loss_class.item())
    train_losses.append(np.mean(losses))
    train_simclr_losses.append(np.mean(losses_simclr))
    train_class_losses.append(np.mean(losses_class))

    # Eval mode for validation: BatchNorm must use (and must not update) its running statistics,
    # otherwise validation batches would leak into the model state and bias the validation loss.
    tuner = tuner.eval()
    losses = []
    losses_simclr = []
    losses_class = []
    aucs = []
    with torch.no_grad():
        for batch in tqdm(val_loader):
            x,dset_label = batch
            x,labels = x
            h = tuner.encoder(x.to(device))
            preds = tuner.corrector(h)
            domain_labels = (dset_label==1).float().to(device).unsqueeze(1)
            pos_mask = (labels.unsqueeze(1) == labels.unsqueeze(1).T).to(device) & (domain_labels == domain_labels.T)
            z = tuner.projector(h)
            z = F.normalize(z,dim=1).unsqueeze(1) # normalize the projection for simclr loss
            loss_simclr = criterion(z, mask=pos_mask)
            loss_class = F.binary_cross_entropy_with_logits(preds,domain_labels)
            loss = loss_simclr + lambda_class*loss_class
            losses.append(loss.item())
            losses_simclr.append(loss_simclr.item())
            losses_class.append(loss_class.item())
            # Per-batch domain AUC of the corrector (0.5 = domains indistinguishable).
            aucs.append(roc_auc_score(domain_labels.cpu().numpy()[:,0],preds.cpu().numpy()[:,0]))
    losses = np.mean(losses)
    val_losses.append(losses)
    val_simclr_losses.append(np.mean(losses_simclr))
    val_class_losses.append(np.mean(losses_class))
    if losses < best_loss:
        best_loss = losses
        # Snapshot a copy: state_dict() tensors alias the live parameters and would otherwise
        # keep changing with every later optimizer step, making the reload below a no-op.
        best_state = deepcopy(tuner.state_dict())
        patience = 0
    else:
        patience += 1
        if patience == patience_thresh:
            print(f"{patience_thresh} epochs of no improvement, stopping")
            break

    scheduler.step()

    print(f"Epoch {i+1}, Train: {train_losses[-1]:.5f}, Val:{val_losses[-1]:.5f}, Val auc: {np.mean(aucs):.5f}")
    print(f"\t Train (simclr): {train_simclr_losses[-1]:.5f}, Val (simclr):{val_simclr_losses[-1]:.5f}")
    print(f"\t Train (class): {train_class_losses[-1]:.5f}, Val (class):{val_class_losses[-1]:.5f}")

# Restore the best-validation weights and leave the model in inference mode.
tuner.load_state_dict(best_state)
tuner = tuner.eval()

plt.figure(figsize=(8,6))
x = np.arange(1,len(train_losses)+1)
plt.plot(x,train_losses,label='train',color="C0")
plt.plot(x,val_losses,label='val',color="C1")
plt.plot(x,train_simclr_losses,label='train (simclr)',color="C0",linestyle='--')
plt.plot(x,val_simclr_losses,label='val (simclr)',color="C1",linestyle='--')
plt.plot(x,train_class_losses,label='train (class)',color="C0",linestyle=':')
plt.plot(x,val_class_losses,label='val (class)',color="C1",linestyle=':')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()

In [None]:
# Reload the best-validation weights recorded by the training cell and switch to inference mode.
# Repeats the reload at the end of that cell, so this cell can also be re-run on its own.
tuner.load_state_dict(best_state)
tuner=tuner.eval()

In [None]:
# Embeddings and domain-classifier probabilities AFTER fine-tuning, on held-out data: the CIFAR-10
# test set and the CIFAR-5m fine-tuning *test* split (the same split as the pre-tuning baseline
# cell, and the one the `cifar5m_test_*` names refer to).
tuner = tuner.eval()

cifar_embeds_tune = []
cifar_labels_tune = []
cifar_domainProbs_tune = []
for batch in tqdm(DataLoader(cifar_test_dataset,batch_size=1024,shuffle=False,num_workers=2)):
    x,labels = batch
    with torch.no_grad():
        embed = tuner.encoder(x.to(device))
        cifar_embeds_tune.append(embed.cpu().numpy())
        cifar_labels_tune.append(labels.numpy())
        # sigmoid(corrector logit): probability assigned to dataset label 1 (CIFAR-5m).
        cifar_domainProbs_tune.append(torch.sigmoid(tuner.corrector(embed)).cpu().numpy())
cifar_embeds_tune = np.concatenate(cifar_embeds_tune)
cifar_labels_tune = np.concatenate(cifar_labels_tune)
cifar_domainProbs_tune = np.concatenate(cifar_domainProbs_tune)


cifar5m_test_embed_tune = []
cifar5m_test_labels_tune = []
cifar5m_test_domainProbs_tune = []
for batch in tqdm(DataLoader(cifar5m_finetune_test,batch_size=1024,shuffle=False,num_workers=2)):
    x,labels = batch
    with torch.no_grad():
        embed = tuner.encoder(x.to(device))
        cifar5m_test_embed_tune.append(embed.cpu().numpy())
        cifar5m_test_labels_tune.append(labels.numpy())
        cifar5m_test_domainProbs_tune.append(torch.sigmoid(tuner.corrector(embed)).cpu().numpy())
cifar5m_test_embed_tune = np.concatenate(cifar5m_test_embed_tune)
cifar5m_test_labels_tune = np.concatenate(cifar5m_test_labels_tune)
cifar5m_test_domainProbs_tune = np.concatenate(cifar5m_test_domainProbs_tune)

In [None]:
# Distribution of the fine-tuned domain classifier's output for each dataset
# (overlapping curves mean the corrector cannot tell the domains apart).
plt.figure(figsize=(8,6))
bins = np.linspace(0,1,100)
h = plt.hist(cifar_domainProbs_tune,bins=bins,density=True,histtype='step',label='cifar10')
h = plt.hist(cifar5m_test_domainProbs_tune,bins=bins,density=True,histtype='step',label='cifar5m')
plt.xlabel('domain classifier output (sigmoid)')
plt.ylabel('density')
plt.legend()

In [None]:
from utils.plotting import make_corner
# Class-`sel_label` embeddings after fine-tuning: CIFAR-10 test vs CIFAR-5m (`cifar5m_test_embed_tune`).
sel_label = 2
regular = cifar_embeds_tune[cifar_labels_tune == sel_label]
shift = cifar5m_test_embed_tune[cifar5m_test_labels_tune == sel_label]
make_corner(np.vstack([regular, shift]),
            labels=np.repeat([0.0, 1.0], [len(regular), len(shift)]),
            label_names={0:"cifar10",1:"cifar5m"})

In [None]:
from utils.plotting import make_corner
# Class-`sel_label` CIFAR-10 test embeddings: fine-tuned encoder vs the same encoder before fine-tuning.
# Both groups are CIFAR-10, so the legend names the encoder state rather than the dataset.
sel_label=2
regular = cifar_embeds_tune[cifar_labels_tune==sel_label]
shift = cifar_embeds_pretune[cifar_labels_pretune==sel_label]
make_corner(np.concatenate([regular,shift],axis=0),
            labels=np.concatenate([np.zeros(len(regular)),np.ones(len(shift))]),
            label_names={0:"cifar10 (tuned)",1:"cifar10 (pretune)"})

In [None]:
# Independent CIFAR-5m slice (args [1], [(None, 100_000)]); presumably disjoint from the [0]
# slice used for fine-tuning — confirm against CIFAR5MDataset.
cifar5m_indpt = CIFAR5MDataset("resnet50",[1],[(None,100_000)],grayscale=True)


def _encode_tuned(dataset, batch_size=1024):
    """Embed `dataset`, in order, with the fine-tuned `tuner.encoder`.

    Returns a pair of numpy arrays (embeddings, labels), each concatenated over batches.
    """
    embed_chunks, label_chunks = [], []
    for ims, labs in tqdm(DataLoader(dataset, batch_size=batch_size, shuffle=False)):
        with torch.no_grad():
            embed_chunks.append(tuner.encoder(ims.to(device)).cpu().numpy())
            label_chunks.append(labs.numpy())
    return np.concatenate(embed_chunks), np.concatenate(label_chunks)


cifar_embeds_tuned, cifar_labels_tuned = _encode_tuned(cifar_test_dataset)
cifar_train_embeds_tuned, cifar_train_labels_tuned = _encode_tuned(cifar_train_dataset)
cifar5m_embeds_tuned, cifar5m_labels_tuned = _encode_tuned(cifar5m_indpt)

In [None]:
from utils.plotting import make_corner
# Class-`sel_label` embeddings from the fine-tuned encoder: CIFAR-10 test vs the independent CIFAR-5m slice.
sel_label = 0
cifar10_sel = cifar_embeds_tuned[cifar_labels_tuned == sel_label]
cifar5m_sel = cifar5m_embeds_tuned[cifar5m_labels_tuned == sel_label]
make_corner(np.concatenate([cifar10_sel, cifar5m_sel], axis=0),
            labels=np.concatenate([np.zeros(len(cifar10_sel)), np.ones(len(cifar5m_sel))]),
            label_names={0:"cifar10",1:"cifar5m"})

In [None]:
from models.networks import MLP

# Re-train the same embedding-space classifier, now on the FINE-TUNED embeddings
# (CIFAR-10 train for fitting, CIFAR-10 test for model selection).
classifier = MLP(4,[10,10],10,dropout=0.2,activation='tanh')
classifier = classifier.to(device)
optimizer = torch.optim.AdamW(classifier.parameters(),lr=1e-3)

loader = DataLoader(TensorDataset(torch.tensor(cifar_train_embeds_tuned),torch.tensor(cifar_train_labels_tuned)),batch_size=512,shuffle=True,num_workers=2)
# Order is irrelevant for validation; not shuffling keeps the epoch loss deterministic.
val_loader = DataLoader(TensorDataset(torch.tensor(cifar_embeds_tuned),torch.tensor(cifar_labels_tuned)),batch_size=512,shuffle=False,num_workers=2)

n_epoch = 20
best_state = None
best_loss = float('inf')
pbar = tqdm(range(n_epoch))
train_losses = []
val_losses = []
for i in pbar:
    # Training pass: dropout active.
    classifier.train()
    losses = []
    for x,y in loader:
        out = classifier(x.to(device))
        loss = F.cross_entropy(out,y.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    train_losses.append(np.mean(losses))

    # Validation pass: eval mode disables dropout so the loss reflects the model used at inference.
    classifier.eval()
    losses = []
    with torch.no_grad():
        for x,y in val_loader:
            out = classifier(x.to(device))
            loss = F.cross_entropy(out,y.to(device))
            losses.append(loss.item())
    losses = np.mean(losses)
    val_losses.append(losses)
    if losses < best_loss:
        best_loss = losses
        # state_dict() returns references to the live parameters; copy them so later
        # optimizer steps cannot overwrite the "best" snapshot.
        best_state = deepcopy(classifier.state_dict())

    pbar.set_postfix_str(f"train:{train_losses[-1]:.5f}, val:{val_losses[-1]:.5f}")

classifier.load_state_dict(best_state)
classifier.eval()

plt.figure(figsize=(8,6))
x = np.arange(1,n_epoch+1)
plt.plot(x,train_losses,label='train')
plt.plot(x,val_losses,label='val')
plt.xlabel('epoch')
plt.ylabel('cross-entropy loss')
plt.legend()

In [None]:
# Evaluation inputs for the post-fine-tuning classifier: fine-tuned embeddings of the CIFAR-10 test
# set and of the independent CIFAR-5m slice. These names are also read by the metric cells below.
cifar_test_embed, cifar_test_labels = cifar_embeds_tuned, cifar_labels_tuned
cifar5m_test_embed, cifar5m_test_labels = cifar5m_embeds_tuned, cifar5m_labels_tuned

val_loader = DataLoader(TensorDataset(torch.tensor(cifar_test_embed),torch.tensor(cifar_test_labels)),batch_size=512,shuffle=False,num_workers=2)
val_loader_5m = DataLoader(TensorDataset(torch.tensor(cifar5m_test_embed),torch.tensor(cifar5m_test_labels)),batch_size=512,shuffle=False,num_workers=2)

preds_cifar = []
preds_cifar5m = []

# Raw logits (no softmax), in dataset order since the loaders do not shuffle.
classifier.eval()
with torch.no_grad():
    for x,y in val_loader:
        preds_cifar.append(classifier(x.to(device)).cpu().numpy())
    for x,y in val_loader_5m:
        preds_cifar5m.append(classifier(x.to(device)).cpu().numpy())
preds_cifar = np.concatenate(preds_cifar)
preds_cifar5m = np.concatenate(preds_cifar5m)

In [None]:
print("Embedding space classifier metrics for CIFAR10 test set")
# Macro-averaged per-class ROC AUC from one-hot targets and raw classifier logits.
onehot_targets = F.one_hot(torch.tensor(cifar_test_labels)).numpy()
auc = roc_auc_score(onehot_targets, preds_cifar)
print("auc = ", auc)
# Top-k accuracy for k = 1..5.
for k in range(1, 6):
    topk = top_k_accuracy_score(cifar_test_labels, preds_cifar, k=k)
    print(f"Top {k} acc = ", topk)

In [None]:
print("Embedding space classifier metrics for CIFAR5m tiny set")
# Macro-averaged per-class ROC AUC from one-hot targets and raw classifier logits.
onehot_targets_5m = F.one_hot(torch.tensor(cifar5m_test_labels)).numpy()
auc = roc_auc_score(onehot_targets_5m, preds_cifar5m)
print("auc = ", auc)
# Top-k accuracy for k = 1..5.
for k in range(1, 6):
    topk = top_k_accuracy_score(cifar5m_test_labels, preds_cifar5m, k=k)
    print(f"Top {k} acc = ", topk)