In [None]:
import os
import glob
import copy
from tqdm import tqdm
from typing import List

import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

from source.utils.seeding import fix_seeds
from source.constants import RESULTS_PATH, CIFAR_MEAN, CIFAR_STD
from source.networks.resnet import get_resnet18_d
from utils import load_test_dataset

In [16]:
seed = 42

# Candidate datasets and their class counts. Kept as parallel lists because
# later cells iterate over `dataset_names`.
dataset_names = ["cifar10", "cifar100", "svhn", "tin", "lsun"]
n_classes = [10, 100, 10, 200, 10]
assert len(dataset_names) == len(n_classes), "dataset_names and n_classes must be aligned"
models = ["resnet18"]

train_dataset = dataset_names[3]    # select dataset the ensemble was trained on
model = models[0]                   # select model architecture

p_drop = 0.2                        # dropout probability
# Forward passes per network in `evaluate`: pass 0 is deterministic (eval mode),
# passes 1..n_models-1 are stochastic MC-dropout passes.
n_models = 10

# infer number of classes from dataset
n_class = dict(zip(dataset_names, n_classes))[train_dataset]

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 2048

path = os.path.join(RESULTS_PATH, f"{train_dataset}_{model}_dropout{p_drop}_seed{seed}")

print(train_dataset, model)

tin resnet18


In [17]:
@torch.no_grad()
def evaluate(networks: List, ds: Dataset, n_models: int = n_models, seed: int = seed):
    """Collect softmax outputs of every network on `ds`, including MC-dropout passes.

    For each network and each batch, ``n_models`` forward passes are made.
    Pass 0 runs in eval mode (deterministic). Passes 1..n_models-1 run in train
    mode with every BatchNorm layer kept in eval mode, so only dropout is
    stochastic. ``fix_seeds(seed)`` is called at the start of every batch so the
    dropout masks are reproducible and shared across networks.

    Args:
        networks: models to evaluate. Each one is left in eval mode on return,
            even if an exception is raised.
        ds: dataset yielding ``(x, y)`` pairs; it is iterated in order
            (``shuffle=False``).
        n_models: forward passes per network per batch (1 deterministic + the
            rest with dropout).
        seed: seed passed to ``fix_seeds`` before each batch.

    Returns:
        ``(probits, ys)``. ``probits`` has shape
        ``(len(ds), len(networks), n_models, n_outputs)`` (softmax over dim 1 of
        the network output). ``ys`` is the concatenation of the dataset labels,
        on CPU.

    Raises:
        ValueError: if ``networks`` is empty.
    """
    if len(networks) == 0:
        raise ValueError("evaluate() needs at least one network")

    # Built once: the loader is identical for every network.
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False)

    per_network = list()
    labels = list()
    for net_idx, network in enumerate(tqdm(networks)):
        batches = list()
        try:
            for x, y in loader:
                x = x.to(device)
                # shuffle=False, so labels are the same for every network:
                # collect them during the first network's pass only.
                if net_idx == 0:
                    labels.append(y.cpu())

                # for consistent dropout masks (same masks for every network)
                fix_seeds(seed)

                passes = list()
                for n in range(n_models):
                    # first pass is the deterministic model, others are dropout samples
                    if n == 0:
                        network.eval()
                    else:
                        network.train()
                        # keep batchnorm in eval mode so only dropout is stochastic
                        for m in network.modules():
                            if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
                                m.eval()

                    passes.append(torch.softmax(network(x), dim=1).cpu())
                batches.append(torch.stack(passes, dim=1))
        finally:
            # Don't leave dropout active for whoever uses the network next.
            network.eval()

        per_network.append(torch.cat(batches, dim=0))
    return torch.stack(per_network, dim=1), torch.cat(labels, dim=0)

In [18]:
# load networks: one checkpoint per ensemble member
if model != "resnet18":
    raise NotImplementedError("Model not supported")

models_dir = os.path.join(path, "models")
# NOTE(review): lexicographic order puts "x_10.pt" before "x_2.pt". The
# adversarial cell slices `networks` by position, so confirm that checkpoint
# names sort in the intended order.
model_files = sorted(glob.glob(os.path.join(models_dir, "*.pt")))
assert model_files, f"No model checkpoints found in {models_dir}"

networks = list()
for model_file in tqdm(model_files, desc="Loading models"):
    network = get_resnet18_d(num_classes=n_class, p_drop=p_drop)
    # NOTE: torch.load unpickles arbitrary objects. Only load trusted checkpoints
    # (consider `weights_only=True` on torch >= 1.13).
    network.load_state_dict(torch.load(model_file, map_location=device))
    # A fresh network is built every iteration, so no deepcopy is needed.
    networks.append(network.to(device))

Loading models: 5it [00:00,  7.22it/s]


In [None]:
# Run the ensemble on the test split of every dataset and cache the predictions
# (as float16 to save space) next to the corresponding labels.
for dataset_name in dataset_names:
    print(f"> Evaluating on {dataset_name}")
    dataset = load_test_dataset(dataset_name)
    probits, ys = evaluate(networks, dataset)

    outputs = {"probits": probits.to(torch.float16), "ys": ys}
    for suffix, tensor in outputs.items():
        torch.save(tensor, os.path.join(path, f"{dataset_name}_{suffix}.pt"))

In [19]:
# evaluate on adversarial samples if available
adv_path = os.path.join(path, "adversarial_examples")

runs = 5    #! number of runs: `networks` is split into `runs` equal, contiguous groups


def _parse_adv_file(file_path):
    """Return ``(attack, runid)`` for files named ``<attack>_<runid>.pt``, else ``None``.

    Files that don't match, such as the ``<attack>_probits.pt`` outputs written
    by this cell, are ignored.
    """
    stem = os.path.splitext(os.path.basename(file_path))[0]
    attack, sep, run = stem.rpartition("_")
    if not sep or not attack or not run.isdigit():
        return None
    return attack, int(run)


if os.path.exists(adv_path):
    # Group adversarial files by attack name. Run ids are sorted numerically
    # (10 after 9), and attack order follows the sorted file listing.
    adv_ds_dict = dict()
    for adv_file in sorted(glob.glob(os.path.join(adv_path, "*.pt"))):
        parsed = _parse_adv_file(adv_file)
        if parsed is not None:
            attack, runid = parsed
            adv_ds_dict.setdefault(attack, []).append((runid, adv_file))
    for run_files in adv_ds_dict.values():
        run_files.sort()

    # Run `runid` is evaluated with its own contiguous slice of `networks`.
    nets_per_run = len(networks) // runs
    assert nets_per_run > 0, f"Need at least {runs} networks, got {len(networks)}"

    # NOTE(review): CIFAR statistics are applied regardless of `train_dataset`.
    # Confirm this matches the normalization the networks were trained with.
    adv_mean = torch.tensor(CIFAR_MEAN).reshape(1, 3, 1, 1)
    adv_std = torch.tensor(CIFAR_STD).reshape(1, 3, 1, 1)

    for adv_atk, run_files in adv_ds_dict.items():
        print(f"> Evaluating on {adv_atk}")

        probits = list()
        for runid, adv_file in run_files:
            assert 0 <= runid < runs, f"{adv_file}: run id {runid} outside [0, {runs})"

            # load adversarial samples and map to [0, 1]
            # (assumes they are stored on a 0-255 scale — TODO confirm with generator)
            images = torch.load(adv_file, map_location="cpu").to(torch.float32) / 255

            # apply cifar normalization
            images = (images - adv_mean) / adv_std

            # Labels are unused here; zeros satisfy evaluate's (x, y) interface.
            dataset = TensorDataset(images, torch.zeros(size=(len(images), )).long())

            # evaluate
            run_networks = networks[runid * nets_per_run:(runid + 1) * nets_per_run]
            probits_, _ = evaluate(run_networks, dataset)
            probits.append(probits_)
        probits = torch.cat(probits, dim=1)
        print(probits.shape)
        torch.save(probits.to(torch.float16), os.path.join(adv_path, f"{adv_atk}_probits.pt"))
else:
    print("No adversarial examples found, use `generate_adversarial_examples.py` to generate adversarial examples.")

> Evaluating on fgsm


100%|██████████| 1/1 [00:06<00:00,  6.04s/it]
100%|██████████| 1/1 [00:06<00:00,  6.06s/it]
100%|██████████| 1/1 [00:06<00:00,  6.02s/it]
100%|██████████| 1/1 [00:06<00:00,  6.01s/it]
100%|██████████| 1/1 [00:06<00:00,  6.08s/it]


torch.Size([10000, 5, 10, 200])
> Evaluating on linfpgd


100%|██████████| 1/1 [00:06<00:00,  6.02s/it]
100%|██████████| 1/1 [00:06<00:00,  6.03s/it]
100%|██████████| 1/1 [00:06<00:00,  6.09s/it]
100%|██████████| 1/1 [00:06<00:00,  6.05s/it]
100%|██████████| 1/1 [00:06<00:00,  6.05s/it]


torch.Size([10000, 5, 10, 200])
