In [1]:
import torch
from ssl_transfer.temperature_scaling import *

from torchvision import models
from torchvision import datasets, transforms
from datasets import Split_Dataset
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import Subset
import numpy as np

# --- Configuration -----------------------------------------------------------
IMAGENET_ROOT = '/gpfs/u/locker/200/CADS/datasets/ImageNet'
CALIB_SPLIT_FILE = './calib_splits/am_imagenet_5percent_val.txt'
BATCH_SIZE = 256
NUM_WORKERS = 20
SEED = 0  # fixes the shuffled order of test_loader so results are reproducible

# Evaluation-time preprocessing: resize to 256, center-crop 224, then
# normalize with the ImageNet channel mean/std.
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

# ImageNet validation images, used here as the calibration test set.
test_dataset = datasets.ImageFolder(f'{IMAGENET_ROOT}/val', transform=val_transforms)

# Calibration split listed in a text file (file format is defined by
# Split_Dataset -- not visible here).
val_dataset = Split_Dataset(IMAGENET_ROOT,
                            CALIB_SPLIT_FILE,
                            transform=val_transforms)

# shuffle=True is kept on purpose: ImageFolder yields samples grouped by class,
# so an unshuffled loader would presumably give class-skewed folds if the
# calibration helper splits sequentially (TODO confirm in temperature_scaling).
# The seeded generator makes that shuffled order identical on every run.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True,
)

In [2]:
def _load_resnet50_checkpoint(ckpt_name, ckpt_dir="./dist_models",
                              filename="checkpoint_best.pth",
                              prefix="members.0."):
    """Build a ResNet-50 from `<ckpt_dir>/<ckpt_name>/<filename>`, on GPU, in eval mode.

    The weights are read from the checkpoint's 'model' entry. The leading
    `prefix` is stripped from each key so the keys match a plain torchvision
    ResNet-50. The prefix presumably comes from a wrapper that stores networks
    under `members`; confirm against the training code.
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    sd = torch.load(f"{ckpt_dir}/{ckpt_name}/{filename}", map_location="cpu")
    # Strip the prefix only at the start of the key; the previous str.replace
    # would also have removed it from the middle of a key.
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in sd['model'].items()
    }
    model = models.resnet50()
    model.load_state_dict(state_dict)
    return model.cuda().eval()


def load_3_models(list_ckpts, ckpt_dir="./dist_models"):
    """Load the first three checkpoints named in `list_ckpts` as ResNet-50 models.

    Args:
        list_ckpts: sequence of run names. Each is resolved to
            `<ckpt_dir>/<name>/checkpoint_best.pth`. Only the first three are used.
        ckpt_dir: directory containing one sub-directory per run.

    Returns:
        (model1, model2, model3): CUDA models in eval mode, in the order given.

    Raises:
        ValueError: if fewer than three checkpoint names are supplied.
    """
    if len(list_ckpts) < 3:
        raise ValueError(f"expected at least 3 checkpoint names, got {len(list_ckpts)}")
    model1, model2, model3 = (
        _load_resnet50_checkpoint(name, ckpt_dir) for name in list_ckpts[:3]
    )
    return model1, model2, model3

In [3]:
# Three fine-tuned runs to compare; each name is a sub-directory of ./dist_models.
ckpts = [
    'ft_base31_cos_lr0.003_bs256',
    'ft_eq42_cos_lr0.003_bs256',
    'ft_inv_cos_lr0.003_bs256',
]
loaded_models = load_3_models(ckpts)
model1, model2, model3 = loaded_models

In [4]:
# ECE of model1 before and after temperature scaling. The third argument's
# meaning is defined in ssl_transfer.temperature_scaling (TODO confirm).
# NOTE(review): the saved output of this cell ends in a KeyboardInterrupt and
# shows raw "before:/after:" probability dumps -- re-run before trusting it.
model1_calibration = cross_validate_temp_scaling(model1, test_loader, 256)
ece, scaled_ece = model1_calibration

Computing model calibration
ECE: 0.035
before:  (40000, 1000) [[4.7876875e-10 7.4484615e-07 3.3853475e-08 ... 1.1475749e-07
  8.9384940e-07 1.3120692e-04]
 [3.5335674e-08 8.5207237e-09 2.0189219e-09 ... 1.7574538e-07
  1.6651375e-06 1.3615895e-08]
 [1.5400020e-07 7.9633908e-09 2.6656147e-03 ... 1.2136996e-09
  5.9379901e-10 1.1630232e-09]
 ...
 [2.9817897e-08 1.9462689e-07 5.9196173e-06 ... 6.7637750e-04
  2.9525464e-07 1.0579722e-09]
 [1.1466370e-05 4.9280385e-07 6.7082482e-08 ... 3.5087287e-05
  3.2921543e-04 3.6790135e-07]
 [4.8998063e-11 2.7785073e-11 1.2037255e-11 ... 2.2684603e-08
  1.0562255e-10 5.5400156e-13]]
after:  (40000, 1000) [[6.30369223e-09 3.67929465e-06 2.52564405e-07 ... 7.27494012e-07
  4.30921773e-06 3.25145171e-04]
 [3.47838153e-07 1.01400396e-07 2.91137727e-08 ... 1.39675296e-06
  9.80442564e-06 1.52214369e-07]
 [1.15259991e-06 8.84810518e-08 5.42767951e-03 ... 1.73316739e-08
  9.32786204e-09 1.67027778e-08]
 ...
 [2.91478727e-07 1.48134700e-06 2.85706137e-05 ...

KeyboardInterrupt: 

In [None]:
# Calibrate the remaining two models. Previously the second call overwrote
# `ece, scaled_ece` from the first, so model2's results were lost. Keep every
# result, keyed by checkpoint name.
calibration_results = {}
for ckpt_name, model in [(ckpts[1], model2), (ckpts[2], model3)]:
    calibration_results[ckpt_name] = cross_validate_temp_scaling(model, test_loader, 256)

# Preserve the original bindings (the last model evaluated) for later cells.
ece, scaled_ece = calibration_results[ckpts[2]]

Computing model calibration
ECE: 0.027
Cross validation fold 0, temperature scaled ECE: 0.016
Cross validation fold 1, temperature scaled ECE: 0.015
Cross validation fold 2, temperature scaled ECE: 0.018
