In [None]:
import sys
import os
import numpy as np
import torch
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets
from datetime import datetime

from network_definitions import VGG16

In [None]:
class VariableCorruptionDatasetLoader():
    """Builds ImageFolder data loaders for a raw dataset and its corrupted variants.

    The corrupted dataset is expected to be laid out as
    ``<corrupt_dataset_path>/<corruption>/<severity>/``, where every severity
    directory name parses as an integer.

    Attributes:
        corrupt_dataset_path: Root directory of the corrupted datasets.
        raw_dataset_path: Root directory of the uncorrupted dataset.
        corruptions_list: Sorted names of the sub-directories of
            ``corrupt_dataset_path``.
        severities: Sorted integer severity levels, read from the first
            corruption's directory (empty when there are no corruptions).
    """

    def __init__(self, corrupt_dataset_path, raw_dataset_path):
        """Scan ``corrupt_dataset_path`` for corruption names and severity levels.

        Raises:
            ValueError: If a severity directory name is not an integer.
        """
        self.corrupt_dataset_path = corrupt_dataset_path
        self.raw_dataset_path = raw_dataset_path

        # Sorted so that the benchmarking order does not depend on the
        # arbitrary ordering returned by os.listdir.
        self.corruptions_list = sorted(
            entry for entry in os.listdir(corrupt_dataset_path)
            if os.path.isdir(os.path.join(corrupt_dataset_path, entry))
        )

        if self.corruptions_list:
            # os.path.join works whether or not the root has a trailing separator.
            first_corruption_dir = os.path.join(corrupt_dataset_path, self.corruptions_list[0])
            self.severities = sorted(
                int(entry) for entry in os.listdir(first_corruption_dir)
                if os.path.isdir(os.path.join(first_corruption_dir, entry))
            )
        else:
            # No corruption directories: nothing to read severities from.
            self.severities = []

    def _dataset_dir(self, mode, corruption, severity):
        """Return the directory to load for the given mode/corruption/severity."""
        if mode == "raw":
            return self.raw_dataset_path
        if mode == "corrupt":
            return os.path.join(self.corrupt_dataset_path, corruption, str(severity))
        raise ValueError("mode must be 'raw' or 'corrupt', got %r" % (mode,))

    def get_dataloader(self,mode,corruption,severity,batch_size=50,num_workers=4):
        """Create a non-shuffling DataLoader over the selected image folder.

        Args:
            mode: ``"raw"`` to load ``raw_dataset_path`` (``corruption`` and
                ``severity`` are then ignored) or ``"corrupt"`` to load
                ``<corrupt_dataset_path>/<corruption>/<severity>``.
            corruption: Name of the corruption directory.
            severity: Severity level; converted with ``str`` to form the path.
            batch_size: Number of images per batch.
            num_workers: Number of DataLoader worker processes.

        Returns:
            A ``torch.utils.data.DataLoader`` wrapping a torchvision ``ImageFolder``.

        Raises:
            ValueError: If ``mode`` is neither ``"raw"`` nor ``"corrupt"``.
        """
        # Resolve (and validate) the path first so bad arguments fail fast.
        loading_path = self._dataset_dir(mode, corruption, severity)

        # NOTE(review): these look like the usual ImageNet channel statistics - confirm.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        preprocessing = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        return torch.utils.data.DataLoader(
            datasets.ImageFolder(loading_path, preprocessing),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=False,
        )

def save_accuracy(path,m_type,w_f,c_n,s,a,m="Local"):
    """Append one benchmarking result row to ``imagenette_c_inference_metrics.csv``.

    The CSV lives at ``path + "imagenette_c_inference_metrics.csv"`` (``path``
    is used as a plain string prefix, so include the trailing separator). The
    header line is written only when the file does not exist yet.

    Args:
        path: Prefix (normally a directory ending in a separator) for the CSV file.
        m_type: Model type, written to the ``model_type`` column.
        w_f: Weights file name. It is split on ``"_"``; field ``[-5]`` is
            written as ``last_training_epoch`` and field ``[-3]`` as
            ``accuracy_from_training``, so it needs at least five fields.
        c_n: Corruption name; only the part before the first ``"."`` is stored.
        s: Severity, stored via ``str``.
        a: Benchmarking accuracy, stored via ``str``.
        m: Name of the machine the benchmark ran on.

    Raises:
        IndexError: If ``w_f`` has fewer than five ``"_"``-separated fields.
    """
    file_path = path + "imagenette_c_inference_metrics.csv"

    # Build the whole row BEFORE touching the file: if the weights file name
    # cannot be parsed, no header-only file is left behind and no handle leaks.
    weight_fields = w_f.split('_')
    row = ','.join([
        m_type,
        weight_fields[-5],
        c_n.split('.')[0],
        str(s),
        str(a),
        weight_fields[-3],
        datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
        m,
    ])

    needs_header = not os.path.isfile(file_path)
    # Append mode creates the file when it is missing; the context manager
    # guarantees the handle is closed even if a write fails.
    with open(file_path, "a") as f:
        if needs_header:
            f.write("model_type,last_training_epoch,corruption_name,severity,average_benchmarking_accuracy,accuracy_from_training,date_time,machine" + "\n")
        f.write(row + "\n")

def show_model_params(m):
    """Print the entries of ``m.state_dict()``.

    First prints one line per entry with its name, a tab and its size, then
    prints the entry names alone, one per line.

    Args:
        m: Object exposing ``state_dict()`` whose values have a ``size()`` method.
    """
    # Build the state dict once instead of rebuilding it on every iteration.
    state = m.state_dict()
    for name, tensor in state.items():
        print(name, "\t", tensor.size())
    for name in state:
        print(name)
def show_dict_params(d):
    """Print every key of ``d["model"]`` on its own line, then the entry count.

    Args:
        d: Mapping with a ``"model"`` entry that is iterable and supports ``len``.
    """
    key_lines = [str(key) for key in d["model"]]
    if key_lines:
        # One print call emits the same text as printing each key separately.
        print("\n".join(key_lines))
    entry_count = len(d["model"])
    print("length", entry_count)
def show_image(i):
    """Draw ``i`` with matplotlib using the gray colormap and show the figure.

    Args:
        i: Image data accepted by ``plt.imshow``.
    """
    render_options = {"cmap": 'gray'}
    plt.imshow(i, **render_options)
    plt.show()

def _model_device(model):
    """Return the device of the model's first parameter, else CUDA if available, else CPU."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def benchmark(model,loader,device=None):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    The accuracy is computed over all samples (``100 * correct / total``), so
    a smaller final batch does not get more weight than the full batches.
    The forward passes run under ``torch.no_grad()``. The model's train/eval
    mode is not changed here; callers should call ``model.eval()`` themselves.

    Args:
        model: Callable mapping a batch of inputs to per-class scores; the
            prediction is the index of the maximum along dimension 1.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CUDA if available when the model has no
            parameters).

    Returns:
        The accuracy as a float percentage.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    if device is None:
        device = _model_device(model)

    num_correct = 0
    num_total = 0
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for inputs, labels in tqdm(loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)

            _, predicted = outputs.max(1)
            num_correct += predicted.eq(labels).sum().item()
            num_total += labels.size(0)

    if num_total == 0:
        raise ValueError("benchmark() received a loader that yields no samples")
    return 100 * num_correct / num_total

In [None]:
def run():
    """Benchmark the global ``model`` on the raw data and on every corruption/severity pair.

    Reads the module-level names ``corrupt_dataset_path``, ``raw_dataset_path``,
    ``model``, ``model_type``, ``weights_file``, ``base_path`` and
    ``running_machine``. Each accuracy is printed and appended to the metrics
    CSV through ``save_accuracy``.
    """
    data_source = VariableCorruptionDatasetLoader(corrupt_dataset_path, raw_dataset_path)

    # Baseline: the uncorrupted test set, recorded as corruption "none", severity 0.
    print("Benchmarking on raw testing data (severity 0):")
    raw_loader = data_source.get_dataloader("raw","none",0)
    raw_accuracy = benchmark(model,raw_loader)
    print("model =",model_type,"corruption = none severity = 0 accuracy =",raw_accuracy)
    save_accuracy(base_path + "/",model_type,weights_file,"none",0,raw_accuracy,m=running_machine)

    # Every corruption is evaluated at every severity level.
    print("Benchmarking on corrupted testing data:")
    corruption_grid = (
        (name, level)
        for name in data_source.corruptions_list
        for level in data_source.severities
    )
    for corruption, severity in corruption_grid:
        corrupt_accuracy = benchmark(model,data_source.get_dataloader("corrupt", corruption, severity))

        print("model =",model_type,"corruption =",corruption,"severity =",severity,"accuracy =",corrupt_accuracy)
        save_accuracy(base_path + "/",model_type,weights_file,corruption,severity,corrupt_accuracy,m=running_machine)

In [None]:
# --- Configuration: adjust these to the local dataset / checkpoint layout ---
base_path = '/home/user/data'
# The dataset paths keep their trailing "/" on purpose.
corrupt_dataset_path = base_path + "/imagenette_C/"
raw_dataset_path = base_path + "/imagenet/test_val/"
# save_accuracy() splits this name on "_" and reads fields [-5] and [-3],
# so the real file name needs at least five "_"-separated fields.
weights_file = "/path/to/trained/model"
model_type = "VGG16"
running_machine = "Local"

# --- Model loading ---
model = VGG16().to("cuda")
# Named `checkpoint` rather than `dict` so the builtin `dict` is not shadowed
# for the rest of the kernel session. map_location="cpu" lets a checkpoint
# saved on any device load; load_state_dict then copies into the CUDA model.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(weights_file, map_location="cpu")
model.load_state_dict(checkpoint["model_state"])
model.eval()
print("Loaded model state from file:",weights_file)

run()