In [None]:
import os
import glob
import copy
from tqdm import tqdm
from typing import List

import torch
from torch.utils.data import Dataset, DataLoader

from source.constants import RESULTS_PATH
from source.networks.resnet import get_resnet18
from source.networks.densenet import get_densenet169
from source.networks.regnet import get_regnet_y_800mf
from source.data.cifar10_c import get_cifar10_c, corruptions

In [2]:
# --- Experiment configuration (read by the cells below) ---

# Training run to evaluate: the seed is part of the run's results directory name.
seed = 42
# Number of output classes every network is built with.
n_class = 10

# Supported architectures; choose one by index.
models = ["resnet18", "densenet169", "regnet"]
model = models[0]

# Inference settings for evaluation.
batch_size = 2048
device = "cuda:0"

print(model)

cnn


In [3]:
@torch.no_grad()
def evaluate(networks: List, ds: Dataset):
    """Collect the softmax outputs ("probits") of every network on a dataset.

    Each network is put in eval mode and run over ``ds`` in a fixed order
    (no shuffling, last partial batch kept). The notebook-level ``batch_size``
    and ``device`` settings are used for batching and placement.

    Args:
        networks: Models to evaluate. Assumed to already live on ``device``
            (the loading cell moves them there).
        ds: Dataset yielding ``(input, label)`` pairs; labels are ignored.

    Returns:
        A CPU tensor that stacks each network's concatenated softmax outputs
        along dim 1. For 2-D logits this is
        ``(num_samples, num_networks, num_classes)``.

    Raises:
        ValueError: if ``networks`` is empty. ``torch.stack`` would otherwise
            fail with an opaque error after the dataset was already loaded.
    """
    if len(networks) == 0:
        raise ValueError("evaluate() needs at least one network")

    # The loader does not depend on the network, so build it once and
    # re-iterate it for every model.
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False)

    probits = list()
    for network in tqdm(networks):
        network.eval()
        batch_probits = list()
        for x, _ in loader:
            x = x.to(device)
            # Call the module itself, not .forward(), so registered hooks run.
            batch_probits.append(torch.softmax(network(x), dim=1).cpu())
        probits.append(torch.concat(batch_probits, dim=0))
    return torch.stack(probits, dim=1)

In [4]:
# Constructors for each supported architecture; all take `num_classes`.
network_builders = {
    "resnet18": get_resnet18,
    "densenet169": get_densenet169,
    "regnet": get_regnet_y_800mf,
}
# Fail fast on an unknown architecture. Otherwise `network` would be unbound,
# or silently reused from stale kernel state.
if model not in network_builders:
    raise ValueError(f"Unknown model {model!r}; expected one of {sorted(network_builders)}")

# Run directory: checkpoints are read from <path>/models, results go to <path>/corruptions.
path = os.path.join(RESULTS_PATH, f"cifar10_{model}_seed{seed}")

results_path = os.path.join(path, "corruptions")
os.makedirs(results_path, exist_ok=True)

# load networks (sorted so the network order along dim 1 of the saved probits is stable)
model_files = sorted(glob.glob(os.path.join(path, "models", "*.pt")))
if not model_files:
    raise FileNotFoundError(f"No *.pt checkpoints found in {os.path.join(path, 'models')}")

networks = list()
for model_file in tqdm(model_files, desc="Loading models"):
    network = network_builders[model](num_classes=n_class)
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # Consider weights_only=True if the installed torch version supports it.
    network.load_state_dict(torch.load(model_file, map_location=device))
    network.to(device)
    # A fresh network is built for every file, so no deepcopy is needed before storing it.
    networks.append(network)

# Evaluate all networks on the first 15 entries of `corruptions` at severities 1..5.
# Probits are saved as float16 (to reduce file size) to <corruption>_<severity>_probits.pt.
for c in range(15):
    for s in range(1, 6):
        dataset_name = f"{corruptions[c]}_{s}"
        print(f"> Evaluating on {dataset_name}")

        dataset = get_cifar10_c(corruption=c, severity=s)

        # evaluate
        probits = evaluate(networks, dataset)
        torch.save(probits.to(torch.float16), os.path.join(results_path, f"{dataset_name}_probits.pt"))

Loading models: 50it [00:01, 36.47it/s]


> Evaluating on brightness_1


100%|██████████| 50/50 [00:06<00:00,  8.03it/s]


> Evaluating on brightness_2


100%|██████████| 50/50 [00:04<00:00, 10.03it/s]


> Evaluating on brightness_3


100%|██████████| 50/50 [00:04<00:00, 10.09it/s]


> Evaluating on brightness_4


100%|██████████| 50/50 [00:05<00:00,  9.99it/s]


> Evaluating on brightness_5


100%|██████████| 50/50 [00:05<00:00, 10.00it/s]


> Evaluating on contrast_1


100%|██████████| 50/50 [00:05<00:00,  9.95it/s]


> Evaluating on contrast_2


100%|██████████| 50/50 [00:05<00:00,  9.92it/s]


> Evaluating on contrast_3


100%|██████████| 50/50 [00:05<00:00,  9.94it/s]


> Evaluating on contrast_4


100%|██████████| 50/50 [00:05<00:00,  9.98it/s]


> Evaluating on contrast_5


100%|██████████| 50/50 [00:05<00:00,  9.97it/s]


> Evaluating on defocus_blur_1


100%|██████████| 50/50 [00:05<00:00,  9.97it/s]


> Evaluating on defocus_blur_2


100%|██████████| 50/50 [00:05<00:00,  9.94it/s]


> Evaluating on defocus_blur_3


100%|██████████| 50/50 [00:05<00:00,  9.90it/s]


> Evaluating on defocus_blur_4


100%|██████████| 50/50 [00:05<00:00,  9.90it/s]


> Evaluating on defocus_blur_5


100%|██████████| 50/50 [00:04<00:00, 10.04it/s]


> Evaluating on elastic_transform_1


100%|██████████| 50/50 [00:07<00:00,  7.01it/s]


> Evaluating on elastic_transform_2


100%|██████████| 50/50 [00:05<00:00,  8.58it/s]


> Evaluating on elastic_transform_3


100%|██████████| 50/50 [00:07<00:00,  7.00it/s]


> Evaluating on elastic_transform_4


100%|██████████| 50/50 [00:06<00:00,  7.90it/s]


> Evaluating on elastic_transform_5


100%|██████████| 50/50 [00:05<00:00,  9.67it/s]


> Evaluating on fog_1


100%|██████████| 50/50 [00:05<00:00,  9.76it/s]


> Evaluating on fog_2


100%|██████████| 50/50 [00:05<00:00,  9.87it/s]


> Evaluating on fog_3


100%|██████████| 50/50 [00:05<00:00,  9.94it/s]


> Evaluating on fog_4


100%|██████████| 50/50 [00:05<00:00,  9.95it/s]


> Evaluating on fog_5


100%|██████████| 50/50 [00:05<00:00,  9.99it/s]


> Evaluating on frost_1


100%|██████████| 50/50 [00:05<00:00,  9.97it/s]


> Evaluating on frost_2


100%|██████████| 50/50 [00:05<00:00,  9.93it/s]


> Evaluating on frost_3


100%|██████████| 50/50 [00:05<00:00,  9.91it/s]


> Evaluating on frost_4


100%|██████████| 50/50 [00:05<00:00,  9.94it/s]


> Evaluating on frost_5


100%|██████████| 50/50 [00:04<00:00, 10.02it/s]


> Evaluating on gaussian_blur_1


100%|██████████| 50/50 [00:04<00:00, 10.03it/s]


> Evaluating on gaussian_blur_2


100%|██████████| 50/50 [00:05<00:00,  9.97it/s]


> Evaluating on gaussian_blur_3


100%|██████████| 50/50 [00:05<00:00,  9.95it/s]


> Evaluating on gaussian_blur_4


100%|██████████| 50/50 [00:05<00:00,  9.96it/s]


> Evaluating on gaussian_blur_5


100%|██████████| 50/50 [00:04<00:00, 10.01it/s]


> Evaluating on gaussian_noise_1


100%|██████████| 50/50 [00:05<00:00,  9.97it/s]


> Evaluating on gaussian_noise_2


100%|██████████| 50/50 [00:04<00:00, 10.01it/s]


> Evaluating on gaussian_noise_3


100%|██████████| 50/50 [00:05<00:00,  9.83it/s]


> Evaluating on gaussian_noise_4


100%|██████████| 50/50 [00:04<00:00, 10.03it/s]


> Evaluating on gaussian_noise_5


100%|██████████| 50/50 [00:05<00:00,  9.95it/s]


> Evaluating on impulse_noise_1


100%|██████████| 50/50 [00:05<00:00,  9.95it/s]


> Evaluating on impulse_noise_2


100%|██████████| 50/50 [00:05<00:00,  9.99it/s]


> Evaluating on impulse_noise_3


100%|██████████| 50/50 [00:05<00:00,  9.96it/s]


> Evaluating on impulse_noise_4


100%|██████████| 50/50 [00:05<00:00,  9.94it/s]


> Evaluating on impulse_noise_5


100%|██████████| 50/50 [00:05<00:00,  9.99it/s]


> Evaluating on jpeg_compression_1


100%|██████████| 50/50 [00:05<00:00, 10.00it/s]


> Evaluating on jpeg_compression_2


100%|██████████| 50/50 [00:04<00:00, 10.06it/s]


> Evaluating on jpeg_compression_3


100%|██████████| 50/50 [00:05<00:00,  9.95it/s]


> Evaluating on jpeg_compression_4


100%|██████████| 50/50 [00:04<00:00, 10.04it/s]


> Evaluating on jpeg_compression_5


100%|██████████| 50/50 [00:05<00:00,  9.95it/s]


> Evaluating on motion_blur_1


100%|██████████| 50/50 [00:04<00:00, 10.07it/s]


> Evaluating on motion_blur_2


100%|██████████| 50/50 [00:05<00:00,  9.95it/s]


> Evaluating on motion_blur_3


100%|██████████| 50/50 [00:04<00:00, 10.04it/s]


> Evaluating on motion_blur_4


100%|██████████| 50/50 [00:05<00:00, 10.00it/s]


> Evaluating on motion_blur_5


100%|██████████| 50/50 [00:04<00:00, 10.02it/s]


> Evaluating on pixelate_1


100%|██████████| 50/50 [00:05<00:00,  9.97it/s]


> Evaluating on pixelate_2


100%|██████████| 50/50 [00:05<00:00,  9.99it/s]


> Evaluating on pixelate_3


100%|██████████| 50/50 [00:05<00:00,  9.70it/s]


> Evaluating on pixelate_4


100%|██████████| 50/50 [00:05<00:00,  9.98it/s]


> Evaluating on pixelate_5


100%|██████████| 50/50 [00:05<00:00,  9.96it/s]


> Evaluating on shot_noise_1


100%|██████████| 50/50 [00:05<00:00,  9.85it/s]


> Evaluating on shot_noise_2


100%|██████████| 50/50 [00:05<00:00,  9.88it/s]


> Evaluating on shot_noise_3


100%|██████████| 50/50 [00:05<00:00, 10.00it/s]


> Evaluating on shot_noise_4


100%|██████████| 50/50 [00:05<00:00,  9.90it/s]


> Evaluating on shot_noise_5


100%|██████████| 50/50 [00:05<00:00, 10.00it/s]


> Evaluating on snow_1


100%|██████████| 50/50 [00:05<00:00,  9.89it/s]


> Evaluating on snow_2


100%|██████████| 50/50 [00:05<00:00,  9.85it/s]


> Evaluating on snow_3


100%|██████████| 50/50 [00:05<00:00,  9.96it/s]


> Evaluating on snow_4


100%|██████████| 50/50 [00:04<00:00, 10.05it/s]


> Evaluating on snow_5


100%|██████████| 50/50 [00:05<00:00,  9.96it/s]


> Evaluating on zoom_blur_1


100%|██████████| 50/50 [00:05<00:00,  9.85it/s]


> Evaluating on zoom_blur_2


100%|██████████| 50/50 [00:05<00:00,  9.77it/s]


> Evaluating on zoom_blur_3


100%|██████████| 50/50 [00:05<00:00,  9.99it/s]


> Evaluating on zoom_blur_4


100%|██████████| 50/50 [00:05<00:00,  9.91it/s]


> Evaluating on zoom_blur_5


100%|██████████| 50/50 [00:04<00:00, 10.01it/s]
