In [7]:
import os
import glob
import copy
from typing import List, Optional

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import tqdm

from source.constants import RESULTS_PATH, CIFAR_MEAN, CIFAR_STD
from source.networks.resnet import get_resnet18
from source.networks.densenet import get_densenet169
from source.networks.regnet import get_regnet_y_800mf
from utils import load_test_dataset

In [8]:
seed = 42

# Parallel lists: n_classes[i] is the number of classes of dataset_names[i].
dataset_names = ["cifar10", "cifar100", "svhn", "tin", "lsun"]
n_classes = [10, 100, 10, 200, 10]
models = ["resnet18", "densenet169", "regnet"]
assert len(dataset_names) == len(n_classes), "dataset_names and n_classes must align"

train_dataset = dataset_names[2]    # select dataset
model = models[2]                   # select model

# infer number of classes from dataset
n_class = n_classes[dataset_names.index(train_dataset)]

# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU
# instead of failing on the first `.to(device)` / `map_location=device`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 2048

print(train_dataset, model)

svhn regnet


In [9]:
@torch.no_grad()
def evaluate(networks: List, ds: Dataset, batch_size: Optional[int] = None, device: Optional[str] = None):
    """Run each network on ``ds`` and collect its softmax outputs.

    Args:
        networks: Models to evaluate; each is put into ``eval()`` mode.
        ds: Dataset yielding ``(x, y)`` pairs; iterated in order (no shuffling).
        batch_size: DataLoader batch size. Defaults to the notebook-level
            ``batch_size`` so existing two-argument calls keep working.
        device: Device the inputs are moved to. Defaults to the notebook-level
            ``device``.

    Returns:
        ``(probits, ys)``: ``probits`` is each network's softmax output (taken
        over dim 1) stacked along a new dim 1 -- ``(len(ds), len(networks), C)``
        assuming the networks return ``(batch, C)`` logits. ``ys`` holds the
        labels concatenated in dataset order, on CPU.

    Raises:
        ValueError: if ``networks`` is empty.
    """
    if len(networks) == 0:
        raise ValueError("evaluate() requires at least one network")

    # Resolve defaults from the notebook namespace only when not given explicitly.
    bs = batch_size if batch_size is not None else globals()["batch_size"]
    dev = device if device is not None else globals()["device"]

    # shuffle=False iterates sequentially, so the same loader can be reused for
    # every network and all outputs line up row-by-row with the labels.
    loader = DataLoader(ds, batch_size=bs, shuffle=False, drop_last=False)

    probits = []
    ys = None
    for network in tqdm(networks):
        network.eval()
        batch_probits, batch_labels = [], []
        for x, y in loader:
            # Call the module itself (not .forward) so registered hooks still run.
            batch_probits.append(torch.softmax(network(x.to(dev)), dim=1).cpu())
            if ys is None:
                # Labels do not depend on the network; collect them only once.
                batch_labels.append(y.cpu())
        probits.append(torch.cat(batch_probits, dim=0))
        if ys is None:
            ys = torch.cat(batch_labels, dim=0)
    return torch.stack(probits, dim=1), ys

In [10]:
path = os.path.join(RESULTS_PATH, f"{train_dataset}_{model}_seed{seed}")

# Map each supported architecture to its constructor so an unknown `model`
# fails loudly instead of silently reusing a `network` left over from a prior run.
model_builders = {
    "resnet18": get_resnet18,
    "densenet169": get_densenet169,
    "regnet": get_regnet_y_800mf,
}
if model not in model_builders:
    raise ValueError(f"Unknown model {model!r}; expected one of {list(model_builders)}")
build_network = model_builders[model]

# load networks
model_dir = os.path.join(path, "models")
model_files = glob.glob(os.path.join(model_dir, "*.pt"))
if not model_files:
    raise FileNotFoundError(f"No model checkpoints (*.pt) found in {model_dir}")

networks = list()
for m in tqdm(range(len(model_files)), desc="Loading models"):
    # Checkpoints are expected to be named model_0.pt ... model_{n-1}.pt;
    # only the *count* of *.pt files is taken from the glob above.
    model_file = os.path.join(model_dir, f"model_{m}.pt")

    network = build_network(num_classes=n_class)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    network.load_state_dict(torch.load(model_file, map_location=device))
    network.to(device)
    # A fresh instance is built every iteration, so no deepcopy is needed
    # (copying would temporarily hold two copies of each model on the device).
    networks.append(network)

Loading models:   0%|          | 0/50 [00:00<?, ?it/s]

Loading models: 100%|██████████| 50/50 [00:08<00:00,  6.13it/s]


In [None]:
# Evaluate the whole ensemble on every test split and write the results under `path`.
for dataset_name in dataset_names:
    print(f"> Evaluating on {dataset_name}")

    dataset = load_test_dataset(dataset_name)
    probits, ys = evaluate(networks, dataset)

    # Softmax outputs are cast to float16 before saving; labels are saved as returned.
    probits_half = probits.to(torch.float16)
    torch.save(probits_half, os.path.join(path, f"{dataset_name}_probits.pt"))
    torch.save(ys, os.path.join(path, f"{dataset_name}_ys.pt"))

In [12]:
# evaluate on adversarial samples if available
adv_path = os.path.join(path, "adversarial_examples")

runs = 5    #! number of runs; `networks` is split into `runs` equal consecutive groups


def _index_adversarial_runs(files):
    """Group adversarial-example files named ``{attack}_{runid}.pt`` by attack.

    Files whose name contains "probits" (outputs written by this cell) are
    skipped. Returns a dict mapping attack name -> run ids sorted numerically.

    Raises:
        ValueError: if a file name does not end in ``_{integer}.pt``.
    """
    index = {}
    for file in files:
        stem = os.path.splitext(os.path.basename(file))[0]
        if "probits" in stem:
            continue
        # Split on the LAST underscore so attack names may contain "_".
        attack, sep, run = stem.rpartition("_")
        if not sep or not run.isdigit():
            raise ValueError(f"Unexpected adversarial example file name: {file}")
        index.setdefault(attack, []).append(int(run))
    # Sort numerically: a lexicographic sort would put run 10 before run 2 and
    # misalign the concatenated probits with the ensemble-member order.
    return {attack: sorted(ids) for attack, ids in sorted(index.items())}


if os.path.exists(adv_path):
    # get all adversarial datasets
    adv_datasets = sorted(glob.glob(os.path.join(adv_path, "*.pt")))
    adv_ds_dict = _index_adversarial_runs(adv_datasets)

    # Normalization constants are identical for every run, so build them once.
    # NOTE(review): CIFAR statistics are applied whatever `train_dataset` is -- confirm intended.
    mean = torch.tensor(CIFAR_MEAN).reshape(1, 3, 1, 1)
    std = torch.tensor(CIFAR_STD).reshape(1, 3, 1, 1)

    group_size = len(networks) // runs
    if adv_ds_dict and group_size == 0:
        raise ValueError(f"Need at least {runs} networks to split into {runs} runs, got {len(networks)}")

    for adv_atk, run_ids in adv_ds_dict.items():
        print(f"> Evaluating on {adv_atk}")

        probits = list()

        for runid in run_ids:
            if runid >= runs:
                raise ValueError(f"{adv_atk}_{runid}.pt refers to run {runid}, but only {runs} runs are configured")

            # load adversarial samples and map to [0, 1] (assumes stored as 0-255 values -- confirm)
            images = torch.load(os.path.join(adv_path, f"{adv_atk}_{runid}.pt"), map_location="cpu").to(torch.float32) / 255

            # apply cifar normalization
            images = (images - mean) / std

            # labels are not used here; zeros only satisfy evaluate()'s (x, y) interface
            dataset = TensorDataset(images, torch.zeros(size=(len(images), )).long())

            # evaluate with the group of networks belonging to this run
            run_networks = networks[runid * group_size:(runid + 1) * group_size]
            probits_, _ = evaluate(run_networks, dataset)
            probits.append(probits_)
        probits = torch.cat(probits, dim=1)
        print(probits.shape)
        torch.save(probits.to(torch.float16), os.path.join(adv_path, f"{adv_atk}_probits.pt"))
else:
    print("No adversarial examples found, use `generate_adversarial_examples.py` to generate adversarial examples.")

No adversarial examples found, use `generate_adversarial_examples.py` to generate adversarial examples.
