In [None]:
import os
import numpy as np
import glob
import copy
from tqdm import tqdm

import torch
from torch.utils.data import Subset, DataLoader

from laplace import Laplace

from source.constants import RESULTS_PATH
from source.networks.resnet import get_resnet18
from source.data.cifar10_c import get_cifar10_c, corruptions
from utils import load_train_dataset

In [2]:
# Experiment configuration.
seed = 42                # identifies the saved run (used in the results path below)

n_class = 10             # number of CIFAR-10 classes
models = ["resnet18"]    # supported architectures

model = models[0]        # select model

# Fall back to CPU when no GPU is present so checkpoint loading still works.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 256
num_workers = 4

# Predictions kept per network and input: 1 MAP prediction + (n_models - 1) Laplace samples.
n_models = 10

print(model)

resnet18


In [3]:
# Locate the run directory, prepare the output folder and load every saved MAP network.
path = os.path.join(RESULTS_PATH, f"cifar10_{model}_seed{seed}_laplace")

results_path = os.path.join(path, "corruptions")
os.makedirs(results_path, exist_ok=True)

if model != "resnet18":
    # Without this guard `network` would be undefined below (NameError).
    raise ValueError(f"unsupported model: {model!r}")

models_dir = os.path.join(path, "models")
model_files = sorted(glob.glob(os.path.join(models_dir, "*.pt")))
if not model_files:
    # Fail here instead of silently producing no predictions in later cells.
    raise FileNotFoundError(f"no model checkpoints (*.pt) found in {models_dir}")

# load networks
networks = list()
for model_file in tqdm(model_files, desc="Loading models"):
    network = get_resnet18(num_classes=n_class)
    network.load_state_dict(torch.load(model_file, map_location=device))
    # A fresh module is built every iteration, so no deepcopy is needed;
    # Module.to() moves it in place and returns it.
    networks.append(network.to(device))

Loading models: 0it [00:00, ?it/s]

Loading models: 5it [00:14,  2.83s/it]


In [4]:
# Rebuild the training split of the saved run: the stored validation indices
# are removed from the full CIFAR-10 training set.
full_train, _ = load_train_dataset("cifar10")

# Load the indices onto the CPU so a tensor saved from a GPU can still be
# converted to a NumPy index array.
val_inds = torch.load(os.path.join(path, "val_inds.pt"), map_location="cpu")
train_inds = np.delete(np.arange(len(full_train)), np.asarray(val_inds))
# np.delete drops duplicate indices only once; make sure the two splits partition the data.
assert len(train_inds) + len(val_inds) == len(full_train), "val_inds must be unique indices into full_train"

print(len(train_inds), len(val_inds))

# Only the training split is needed here: it is used to fit the Laplace approximation.
train_ds = Subset(full_train, indices=train_inds)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)

Files already downloaded and verified
41667 8333


In [None]:
# Fit a last-layer KFAC Laplace approximation around every MAP network and
# collect predictions on all CIFAR-10-C (corruption, severity) splits.
N_CORRUPTIONS = 15  # corruption indices 0..14 passed to get_cifar10_c
N_SEVERITIES = 5    # severities 1..5


def _predict_with_laplace(network, la, loader, n_samples, device):
    """Predict every input of ``loader`` with the MAP network and Laplace samples.

    Returns a CPU tensor of shape (N, 1 + n_samples, C). Index 0 along dim 1 is
    the softmax output of ``network``. The remaining ``n_samples`` entries come
    from ``la._nn_predictive_samples``; these are presumably already
    softmax-normalised by laplace for a classification likelihood (confirm for
    the installed laplace version).
    """
    batches = list()
    with torch.no_grad():  # inference only: do not keep activations for autograd
        for x, _ in loader:
            x = x.to(device)
            probs_map = torch.softmax(network(x), dim=1).cpu().unsqueeze(1)
            # Samples come back as (n_samples, B, C); move the sample axis to dim 1.
            probs_la = la._nn_predictive_samples(x, n_samples=n_samples).permute(1, 0, 2).cpu()
            batches.append(torch.cat([probs_map, probs_la], dim=1))
    return torch.cat(batches, dim=0)


assert networks, "no networks loaded"

# ds_probits[cs][n] holds the predictions of network n on split
# cs = corruption * N_SEVERITIES + (severity - 1).
ds_probits = [list() for _ in range(N_CORRUPTIONS * N_SEVERITIES)]

for n, network in enumerate(networks):
    print("Laplace Approximation for model", n)
    # Seed per network so posterior sampling (and the shuffled fit) is reproducible.
    torch.manual_seed(seed + n)

    # Build and fit the approximation with the network in eval mode, so
    # BatchNorm layers are not in training mode during the fit.
    network.eval()
    la = Laplace(network,
                 likelihood='classification',
                 subset_of_weights='last_layer',
                 hessian_structure='kron')

    la.fit(train_loader)
    la.optimize_prior_precision(method="marglik")
    print(la.prior_precision, torch.log(la.prior_precision))

    # Make sure both the MAP network and the Laplace model predict in eval mode.
    network.eval()
    la.model.eval()

    for c in tqdm(range(N_CORRUPTIONS)):
        for s in range(1, N_SEVERITIES + 1):
            cs = c * N_SEVERITIES + (s - 1)

            dataset = get_cifar10_c(corruption=c, severity=s)
            test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

            ds_probits[cs].append(
                _predict_with_laplace(network, la, test_loader, n_models - 1, device)
            )

In [13]:
# Write one file per CIFAR-10-C split. Per-network predictions are stacked
# along dim 1 and stored as float16 to halve the file size.
for split_idx in range(15 * 5):
    corruption_idx, severity_offset = divmod(split_idx, 5)
    split_name = f"{corruptions[corruption_idx]}_{severity_offset + 1}"
    stacked = torch.stack(ds_probits[split_idx], dim=1).to(torch.float16)
    out_file = os.path.join(results_path, f"{split_name}_probits.pt")
    torch.save(stacked, out_file)