In [None]:
import os
import glob
import copy
from tqdm import tqdm
from typing import List

import torch
from torch.utils.data import Dataset, DataLoader

from source.constants import RESULTS_PATH
from source.utils.seeding import fix_seeds
from source.networks.resnet import get_resnet18_d
from source.data.cifar10_c import get_cifar10_c, corruptions

In [2]:
# --- reproducibility ----------------------------------------------------------
seed = 42

# --- task / architecture ------------------------------------------------------
n_class = 10                          # CIFAR-10 label count
models = ["resnet18"]                 # architectures the loading cell knows how to build
model = models[0]                     # architecture evaluated in this run

# --- inference ----------------------------------------------------------------
device = "cuda:0"
batch_size = 2048

# --- MC dropout ---------------------------------------------------------------
p_drop = 0.2                          # passed to the network builder; also part of the results folder name
n_models = 10                         # passes per network: pass 0 in eval mode, passes 1.. with dropout active

print(model)

resnet18


In [3]:
def _enable_mc_dropout(network: torch.nn.Module) -> None:
    """Switch `network` to train mode (dropout active) while keeping every BatchNorm layer in eval mode."""
    network.train()
    for module in network.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.eval()


@torch.no_grad()
def evaluate(networks: List, ds: Dataset, n_models: int = n_models, seed: int = seed) -> torch.Tensor:
    """Collect softmax outputs of every network under deterministic and MC-dropout inference.

    For each network and each batch of ``ds`` (iterated in order, no shuffling), pass 0
    runs in eval mode and passes 1..n_models-1 run with dropout active but BatchNorm kept
    in eval mode.

    Args:
        networks: modules to evaluate; each is called as ``network(x)`` and is assumed to
            return (batch, n_class) logits -- TODO confirm.
        ds: dataset whose items are 2-tuples; the second element (presumably the label)
            is ignored.
        n_models: number of forward passes per network (1 deterministic + n_models-1 dropout).
        seed: handed to ``fix_seeds`` (presumably reseeds torch's RNG) before every batch, so
            the dropout draws restart from the same state for every batch and every network.

    Also reads the notebook globals ``batch_size`` and ``device``.

    Returns:
        CPU tensor of shape (len(ds), len(networks), n_models, n_class).

    Raises:
        ValueError: if ``networks`` is empty or ``n_models < 1``.

    Every network is left in eval mode on return, even if an error interrupts evaluation.
    """
    if not networks:
        raise ValueError("evaluate() needs at least one network")
    if n_models < 1:
        raise ValueError(f"n_models must be >= 1, got {n_models}")

    probits = list()
    for network in tqdm(networks):
        probits_ = list()
        try:
            for x, _ in DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False):
                x = x.to(device)

                # for consistent dropout masks
                fix_seeds(seed)

                probits__ = list()
                # first pass is the normal (deterministic) model ...
                network.eval()
                for n in range(n_models):
                    # ... all later passes are dropout models; switching modes once is enough
                    # because train()/eval() do not touch the RNG.
                    if n == 1:
                        _enable_mc_dropout(network)
                    probits__.append(torch.softmax(network(x), dim=1).cpu())
                # (batch, n_models, n_class)
                probits_.append(torch.stack(probits__, dim=1))
        finally:
            # don't leave dropout switched on for whoever uses the network next
            network.eval()

        # (len(ds), n_models, n_class)
        probits.append(torch.concat(probits_, dim=0))
    # (len(ds), len(networks), n_models, n_class)
    return torch.stack(probits, dim=1)

In [4]:
# All artefacts for this (model, dropout rate, seed) configuration live under one directory:
# trained checkpoints are read from `models/`, per-corruption predictions are written to `corruptions/`.
path = os.path.join(RESULTS_PATH, f"cifar10_{model}_dropout{p_drop}_seed{seed}")

results_path = os.path.join(path, "corruptions")
os.makedirs(results_path, exist_ok=True)

# Fail fast with a clear message instead of a NameError / empty-stack error further down.
if model != "resnet18":
    raise ValueError(f"Unsupported model {model!r}; only 'resnet18' can be built here")
models_dir = os.path.join(path, "models")
model_files = sorted(glob.glob(os.path.join(models_dir, "*.pt")))
if not model_files:
    raise FileNotFoundError(f"No '*.pt' checkpoints found in {models_dir}")

# load networks (sorted, so the network axis of the saved probits has a stable order)
networks = list()
for model_file in tqdm(model_files, desc="Loading models"):
    network = get_resnet18_d(num_classes=n_class, p_drop=p_drop)
    # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
    network.load_state_dict(torch.load(model_file, map_location=device))
    # a fresh module is built per file, so it can be stored directly without copying
    networks.append(network.to(device))

# Evaluate every network on the first 15 entries of `corruptions`, each at severities 1-5.
for c in range(15):
    for s in range(1, 6):
        dataset_name = f"{corruptions[c]}_{s}"
        print(f"> Evaluating on {dataset_name}")

        dataset = get_cifar10_c(corruption=c, severity=s)

        # evaluate; probits are stored as float16 to reduce file size
        probits = evaluate(networks, dataset, n_models=n_models, seed=seed)
        torch.save(probits.to(torch.float16), os.path.join(results_path, f"{dataset_name}_probits.pt"))

Loading models: 5it [00:01,  3.38it/s]


> Evaluating on brightness_1


100%|██████████| 5/5 [00:31<00:00,  6.32s/it]


> Evaluating on brightness_2


100%|██████████| 5/5 [00:30<00:00,  6.03s/it]


> Evaluating on brightness_3


100%|██████████| 5/5 [00:30<00:00,  6.07s/it]


> Evaluating on brightness_4


100%|██████████| 5/5 [00:30<00:00,  6.08s/it]


> Evaluating on brightness_5


100%|██████████| 5/5 [00:30<00:00,  6.10s/it]


> Evaluating on contrast_1


100%|██████████| 5/5 [00:30<00:00,  6.11s/it]


> Evaluating on contrast_2


100%|██████████| 5/5 [00:30<00:00,  6.11s/it]


> Evaluating on contrast_3


100%|██████████| 5/5 [00:31<00:00,  6.23s/it]


> Evaluating on contrast_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on contrast_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on defocus_blur_1


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on defocus_blur_2


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on defocus_blur_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on defocus_blur_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on defocus_blur_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on elastic_transform_1


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on elastic_transform_2


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on elastic_transform_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on elastic_transform_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on elastic_transform_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on fog_1


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on fog_2


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on fog_3


100%|██████████| 5/5 [00:31<00:00,  6.34s/it]


> Evaluating on fog_4


100%|██████████| 5/5 [00:42<00:00,  8.53s/it]


> Evaluating on fog_5


100%|██████████| 5/5 [00:38<00:00,  7.75s/it]


> Evaluating on frost_1


100%|██████████| 5/5 [00:57<00:00, 11.58s/it]


> Evaluating on frost_2


100%|██████████| 5/5 [00:57<00:00, 11.51s/it]


> Evaluating on frost_3


100%|██████████| 5/5 [00:30<00:00,  6.17s/it]


> Evaluating on frost_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on frost_5


100%|██████████| 5/5 [00:42<00:00,  8.58s/it]


> Evaluating on gaussian_blur_1


100%|██████████| 5/5 [00:57<00:00, 11.49s/it]


> Evaluating on gaussian_blur_2


100%|██████████| 5/5 [00:51<00:00, 10.33s/it]


> Evaluating on gaussian_blur_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on gaussian_blur_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on gaussian_blur_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on gaussian_noise_1


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on gaussian_noise_2


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on gaussian_noise_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on gaussian_noise_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on gaussian_noise_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on impulse_noise_1


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on impulse_noise_2


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on impulse_noise_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on impulse_noise_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on impulse_noise_5


100%|██████████| 5/5 [00:30<00:00,  6.16s/it]


> Evaluating on jpeg_compression_1


100%|██████████| 5/5 [00:30<00:00,  6.17s/it]


> Evaluating on jpeg_compression_2


100%|██████████| 5/5 [00:30<00:00,  6.14s/it]


> Evaluating on jpeg_compression_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on jpeg_compression_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on jpeg_compression_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on motion_blur_1


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on motion_blur_2


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on motion_blur_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on motion_blur_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on motion_blur_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on pixelate_1


100%|██████████| 5/5 [00:30<00:00,  6.14s/it]


> Evaluating on pixelate_2


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on pixelate_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on pixelate_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on pixelate_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on shot_noise_1


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on shot_noise_2


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on shot_noise_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on shot_noise_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on shot_noise_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on snow_1


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on snow_2


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on snow_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on snow_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on snow_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on zoom_blur_1


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on zoom_blur_2


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on zoom_blur_3


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on zoom_blur_4


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]


> Evaluating on zoom_blur_5


100%|██████████| 5/5 [00:30<00:00,  6.13s/it]
