In [None]:
import csv
import os
import sys
from datetime import datetime

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from network_definitions import VGG16

"""
This notebook runs inference with models trained on upsampled CIFAR-10 against the generalisation dataset (CIFAR-10G) and appends the per-category accuracies to a CSV file.
"""

In [None]:
def test(self, model, testloader):
    # Fuction performing testing and returning testing accuracy
    correct = 0
    total = 0
    accuracy = 0
    model.train(False)
    with torch.no_grad():
        for i,(images,labels)in enumerate(tqdm(testloader)):
            if torch.cuda.is_available():
                images = images.cuda()
                labels = labels.cuda()
            outputs = model(Variable(images.cuda()))
            labels = Variable(labels.cuda())

            _,predicted = outputs.max(1)
            correct = predicted.eq(labels).sum().item()
            total = labels.size(0)
            accuracy+=100*(correct/total)
    return accuracy/len(testloader)

def save_accuracy(path, m_type, w_f, cat, a, m="Local"):
    """Append one inference result row to ``generalisation_inference.csv`` in ``path``.

    A header row is written first when the CSV is new or empty.

    Args:
        path: directory that holds the CSV (a trailing separator is optional).
        m_type: model type label, e.g. ``"VGG16"``.
        w_f: weights-file path. The file name (without extension) is split on
            ``_``. Counting from the end, field -5 is logged as the last training
            epoch, -3 as the original training accuracy and -1 as the original
            testing accuracy. This assumes the checkpoint naming scheme of the
            training script — TODO confirm.
        cat: generalisation category (dataset sub-folder) that was evaluated.
        a: inference accuracy to record.
        m: label of the machine the inference ran on.

    Raises:
        ValueError: if the weights-file name has fewer than five ``_``-separated fields.
    """
    file_path = os.path.join(path, "generalisation_inference.csv")

    # Parse only the file name minus its extension. This way a dot inside the last
    # field is kept ("..._85.3.pt" -> "85.3"), and directory names never count as fields.
    stem = os.path.splitext(os.path.basename(w_f))[0]
    fields = stem.split("_")
    if len(fields) < 5:
        raise ValueError(
            f"cannot parse epoch/accuracies from weights file name {w_f!r}; "
            "expected at least 5 '_'-separated fields"
        )

    with open(file_path, "a", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        # Append mode positions at end-of-file, so tell() == 0 means new or empty file.
        if f.tell() == 0:
            writer.writerow([
                "model_type", "last_training_epoch", "category", "inference_accuracy",
                "original_training_accuracy", "original_testing_accuracy", "date_time", "machine",
            ])
        writer.writerow([
            m_type, fields[-5], cat, a, fields[-3], fields[-1],
            datetime.now().strftime("%d/%m/%Y %H:%M:%S"), m,
        ])

In [None]:
def run(categories=None, batch_size=50, num_workers=4):
    """Evaluate the loaded ``model`` on each CIFAR-10G category and log the results.

    Reads these notebook-level globals from the configuration cell: ``model``,
    ``testing_path``, ``base_path``, ``model_type``, ``weights_file``,
    ``running_machine`` and, by default, ``generalisation_categories``.

    Args:
        categories: sub-folders of ``testing_path`` to evaluate; defaults to
            ``generalisation_categories``.
        batch_size: images per inference batch.
        num_workers: number of DataLoader worker processes.

    Returns:
        dict: category -> inference accuracy (percent), in evaluation order.
    """
    if categories is None:
        categories = generalisation_categories

    # Grayscale CIFAR-10 normalisation statistics.
    # NOTE(review): these look like 0-255 pixel-scale values, but ToTensor() rescales
    # pixels to [0, 1] before Normalize, so inputs end up near -2. Left unchanged because
    # it must match the training-time transform — TODO confirm against the training code.
    mean_cifar10 = 122.61385345458984
    std_cifar10 = 60.87860107421875

    cifar_transforms = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean_cifar10, std_cifar10),
    ])

    results = {}
    for folder in categories:
        testset = datasets.ImageFolder(os.path.join(testing_path, folder), transform=cifar_transforms)
        inference_loader = torch.utils.data.DataLoader(
            testset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=False,
        )

        # do inference
        accuracy = test(model, inference_loader)
        save_accuracy(base_path + "/", model_type, weights_file, folder, accuracy, running_machine)
        print('model:', model_type, '| category:', folder, '| accuracy:', accuracy, '%')
        results[folder] = accuracy
    return results

In [None]:
# load last trained model checkpoint
base_path = '/home/user/data'
weights_file = "/path/to/trained/model"
model_type = "VGG16"
running_machine = "Local"
testing_path = base_path + "/CIFAR-10G/224x224/"  # trailing "/" kept: category names are appended to it

# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VGG16().to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(weights_file, map_location=device)
model.load_state_dict(checkpoint["model_state"])

generalisation_categories = ["contours", "contours_inverted", "line_drawings", "line_drawings_inverted", "silhouettes", "silhouettes_inverted"]

run()