In [37]:
import copy
import glob
import os
import shutil

import numpy as np
import torch
from torch.utils.data import Subset, DataLoader, TensorDataset
from tqdm import tqdm

from laplace import Laplace

from source.constants import RESULTS_PATH, CIFAR_MEAN, CIFAR_STD
from source.networks.resnet import get_resnet18
from source.networks.densenet import get_densenet169
from source.networks.regnet import get_regnet_y_800mf
from utils import load_train_dataset, load_test_dataset

In [38]:
# seed is part of the results directory name and is also used to seed the RNGs below
seed = 42

# dataset_names and n_classes are parallel lists: n_classes[i] belongs to dataset_names[i]
dataset_names = ["cifar10", "cifar100", "svhn", "tin", "lsun"]
n_classes = [10, 100, 10, 200, 10]
models = ["resnet18", "densenet169", "regnet"]

# a length mismatch would silently assign the wrong class count to a dataset
assert len(dataset_names) == len(n_classes), "dataset_names and n_classes must align"

dataset_name = dataset_names[2]    # select dataset
model = models[0]                   # select model

# infer number of classes from dataset
n_class = dict(zip(dataset_names, n_classes))[dataset_name]

# number of trained models to process
n_models = 5

# fall back to the CPU so the notebook still runs on machines without a GPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 256
num_workers = 4

# how many samples to draw from the posterior
n_samples = 10

# seed the global RNGs so that a fresh Restart & Run All is repeatable
np.random.seed(seed)
torch.manual_seed(seed)

print(dataset_name, model)

svhn resnet18


In [None]:
# create laplace results directory (makedirs also creates new_path itself)
orig_path = os.path.join(RESULTS_PATH, f"{dataset_name}_{model}_seed{seed}")
new_path = os.path.join(RESULTS_PATH, f"{dataset_name}_{model}_seed{seed}_laplace")
os.makedirs(os.path.join(new_path, "models"), exist_ok=True)

# copy models from original directory to new directory
# sort once so that run_id -> checkpoint is deterministic
model_files = sorted(glob.glob(os.path.join(orig_path, "models", "*.pt")))
if len(model_files) < n_models:
    raise FileNotFoundError(
        f"expected at least {n_models} model files in "
        f"{os.path.join(orig_path, 'models')}, found {len(model_files)}"
    )

copied_files = list()
for run_id, model_file in enumerate(model_files[:n_models]):
    print("moving", model_file)
    target_file = os.path.join(new_path, "models", f"model_{run_id}.pt")
    # shutil.copy is shell-free: safe for paths with spaces and raises on failure
    shutil.copy(model_file, target_file)
    copied_files.append(target_file)
shutil.copy(os.path.join(orig_path, "val_inds.pt"), os.path.join(new_path, "val_inds.pt"))

# continue with exactly the files copied above; a glob could pick up stale
# checkpoints left in new_path by an earlier run with a larger n_models
model_files = copied_files

In [40]:
# load networks
# map the configured model name to its constructor
network_builders = {
    "resnet18": get_resnet18,
    "densenet169": get_densenet169,
    "regnet": get_regnet_y_800mf,
}
if model not in network_builders:
    raise ValueError("Model not implemented")

networks = list()
# iterating the sorted list directly lets tqdm show the total number of models
for model_file in tqdm(sorted(model_files), desc="Loading models"):
    # assumes each builder call returns a new module instance, so no extra
    # deepcopy is needed to keep the loaded networks independent -- TODO confirm
    network = network_builders[model](num_classes=n_class)
    network.load_state_dict(torch.load(model_file, map_location=device))
    network.to(device)
    networks.append(network)

Loading models: 5it [00:00,  7.99it/s]


In [None]:
# second return value of load_train_dataset is not needed here
full_train, _ = load_train_dataset(dataset_name)

# val_inds.pt holds the indices held out for validation (presumably written by
# the original training run); every remaining index is treated as training data
val_inds_file = os.path.join(orig_path, "val_inds.pt")
val_inds = torch.load(val_inds_file)
all_inds = np.arange(len(full_train))
train_inds = np.delete(all_inds, val_inds)

print(len(train_inds), len(val_inds))

# only the training portion is wrapped in a loader in this cell
train_ds = Subset(full_train, indices=train_inds)

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

In [42]:
def evaluate(la, network, test_loader):
    """Collect class probabilities of the network and of Laplace posterior samples.

    For every batch the softmax output of ``network`` is placed first along
    dim 1, followed by ``n_samples - 1`` predictions obtained from
    ``la._nn_predictive_samples``. The module-level ``device`` and
    ``n_samples`` settings are read from the notebook namespace.

    Args:
        la: fitted Laplace object exposing ``_nn_predictive_samples``.
        network: the network whose own prediction is stored first; its output
            is soft-maxed over dim 1.
        test_loader: iterable yielding ``(x, y)`` batches.

    Returns:
        Tuple ``(probits, ys)`` of CPU tensors concatenated over all batches.
        ``probits`` stacks the predictions along dim 1; this assumes
        ``_nn_predictive_samples`` returns (samples, batch, classes) -- TODO confirm.
    """
    # predict
    probits_, ys_ = list(), list()
    # pure inference: disable autograd so no graph is kept alive per batch
    with torch.no_grad():
        for x, y in tqdm(test_loader):
            x, y = x.to(device), y.to(device)
            # prediction of the network itself, with a singleton dim for concatenation
            probs_model = torch.softmax(network(x), dim=1).cpu().unsqueeze(1)
            # sampled predictions, permuted so the batch dimension comes first
            probs = la._nn_predictive_samples(x, n_samples=n_samples - 1).permute(1, 0, 2).cpu()
            probits_.append(torch.cat([probs_model, probs], dim=1))
            ys_.append(y.cpu())
    probits_ = torch.cat(probits_, dim=0)
    ys_ = torch.cat(ys_, dim=0)
    return probits_, ys_

In [None]:
def _collect_adversarial_runs(adv_dir):
    """Map each attack name to the sorted run ids with a ``{attack}_{run_id}.pt`` file.

    Files whose name contains "probits" (outputs written at the end of this
    cell) and files that do not follow the naming scheme are ignored.
    """
    runs = dict()
    for adv_file in sorted(glob.glob(os.path.join(adv_dir, "*.pt"))):
        # splitext removes the real suffix; str.strip(".pt") would strip characters
        stem = os.path.splitext(os.path.basename(adv_file))[0]
        attack, _, run = stem.partition("_")
        if "probits" in stem or not run.isdigit():
            continue
        runs.setdefault(attack, list()).append(int(run))
    return {attack: sorted(run_ids) for attack, run_ids in runs.items()}


adv_path = os.path.join(new_path, "adversarial_examples")

# discover the available attacks once instead of once per model
adv_ds_dict = _collect_adversarial_runs(adv_path) if os.path.exists(adv_path) else dict()

# ds_probits[d] / atk_probits[a] collect one tensor per model
ds_probits = [list() for _ in dataset_names]
atk_probits = [list() for _ in adv_ds_dict]

# normalisation constants, shaped to broadcast over (N, 3, H, W) image batches
cifar_mean = torch.tensor(CIFAR_MEAN).reshape(1, 3, 1, 1)
cifar_std = torch.tensor(CIFAR_STD).reshape(1, 3, 1, 1)

for run_id, network in enumerate(networks):
    print("Laplace Approximation for model", run_id)

    # NOTE(review): the network is put in train mode before fitting; whether
    # Laplace.fit switches back to eval mode internally is not visible here --
    # confirm that normalisation-layer statistics are not altered by the fit.
    network.train()
    # define laplace approximation
    la = Laplace(network,
                 likelihood='classification',
                 subset_of_weights='last_layer',
                 hessian_structure='kron')

    la.fit(train_loader)
    la.optimize_prior_precision(method="marglik")
    print(la.prior_precision, torch.log(la.prior_precision))

    network.eval()
    la.model.eval()

    # evaluate on test datasets
    # the loop variable must not be called dataset_name: that would overwrite the
    # configured dataset used in the message below and by other cells
    for d, test_dataset_name in enumerate(dataset_names):
        print(f"> Evaluating on {test_dataset_name}")

        dataset = load_test_dataset(test_dataset_name)
        test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        probits_, ys_ = evaluate(la, network, test_loader)
        ds_probits[d].append(probits_)

        # save labels per dataset (written once, during the first run)
        if run_id == 0:
            torch.save(ys_, os.path.join(new_path, f"{test_dataset_name}_ys.pt"))

    # Evaluate adversarial examples if available
    if os.path.exists(adv_path):
        for a, adv_atk in enumerate(adv_ds_dict):
            print(f"> Evaluating on {adv_atk}")

            # load adversarial samples and map to [0, 1]
            images = torch.load(os.path.join(adv_path, f"{adv_atk}_{run_id}.pt"), map_location="cpu").to(torch.float32) / 255

            # apply cifar normalization
            images = (images - cifar_mean) / cifar_std

            # labels are zero placeholders; evaluate()'s label output is discarded below
            dataset = TensorDataset(images, torch.zeros(size=(len(images), )).long())
            test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

            # evaluate
            probits_, _ = evaluate(la, network, test_loader)
            atk_probits[a].append(probits_)
    else:
        print("No adversarial examples found for dataset", dataset_name)

# stack the per-model results along dim 1 and store them in half precision
for d, test_dataset_name in enumerate(dataset_names):
    probits = torch.stack(ds_probits[d], dim=1)
    print(probits.shape)
    torch.save(probits.to(torch.float16), os.path.join(new_path, f"{test_dataset_name}_probits.pt"))

# adv_ds_dict is empty when no adversarial examples were found
for a, adv_atk in enumerate(adv_ds_dict):
    probits = torch.stack(atk_probits[a], dim=1)
    print(probits.shape)
    torch.save(probits.to(torch.float16), os.path.join(adv_path, f"{adv_atk}_probits.pt"))