# 1. Load pretrained model

**Prerequisites:**

- pytorch
- torchvision
- numpy



**Models to test:**

- AlexNet
- VGG-16
- GoogLeNet
- ResNet-50

In [1]:
import torch
from torchvision import datasets, models
import os
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
from probabilities_to_decision import ImageNetProbabilitiesTo16ClassesMapping

In [2]:
# alexnet 
# vgg_16 
# googlenet 
# resnet_50 

alexnet = models.alexnet(pretrained=True, progress=True)

# 2. Load test files


In [3]:
PATH_TO_IMAGES = "../stimuli/style-transfer-preprocessed-512/"
directories = [x for x in os.listdir(PATH_TO_IMAGES) if os.path.isdir(PATH_TO_IMAGES + x)]

## Deprecated:
## Each element in the `test_set` list will be a pair: `(true label, image url)`
# test_set = []
# for directory in directories:
#     urls = [x for x in os.listdir(PATH_TO_IMAGES + directory) if ".png" in x]
#     for url in urls:
#         abs_path = os.path.abspath(PATH_TO_IMAGES + directory + "/" + url)
#         test_set.append((directory, abs_path))
# print("Loaded test set of size:", len(test_set))


transform = transforms.Compose([
    transforms.ToTensor()
])
test_set = datasets.ImageFolder(PATH_TO_IMAGES, transform=transform)
test = DataLoader(test_set, batch_size=1, shuffle=False) # Load in batches of size 1

# 3. Setup pipeline

In [6]:
# Class that gi
def predict_for_image(model, image_tensors):
    # get softmax output
    softmax_output = torch.softmax(model(image_tensors),1) # replace with your favourite CNN
    # convert to numpy
    softmax_output_numpy = softmax_output.detach().numpy().flatten() # replace with conversion
    # create mapping
    mapping = ImageNetProbabilitiesTo16ClassesMapping()
    # obtain decision 
    decision_from_16_classes = mapping.probabilities_to_decision(softmax_output_numpy)
    return decision_from_16_classes


# Dictionary that stores class labels
class_labels = {v: k for k, v in test_set.class_to_idx.items()}

# Test run for first 100 images
for i, data in enumerate(test):
    print(i, ', url:', test_set.imgs[i][0])
    images, labels = data
    
    output = predict_for_image(alexnet, images)
    actual = class_labels[labels.item()]
    print('Predicted:', output,", Actual:", actual, '\n')
    if i == 100:
        break


0 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane1-bicycle2.png
Predicted: chair , Actual: airplane 

1 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane1-chair2.png
Predicted: chair , Actual: airplane 

2 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane1-clock1.png
Predicted: clock , Actual: airplane 

3 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane1-elephant1.png
Predicted: knife , Actual: airplane 

4 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane10-airplane1.png
Predicted: airplane , Actual: airplane 

5 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane10-bear3.png
Predicted: cat , Actual: airplane 

6 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane10-boat3.png
Predicted: boat , Actual: airplane 

7 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane10-car3.png
Predicted: chair , Actual: airplane 

8 , url: ../stimuli/style-transf

71 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane8-truck3.png
Predicted: chair , Actual: airplane 

72 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane9-airplane2.png
Predicted: chair , Actual: airplane 

73 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane9-boat1.png
Predicted: airplane , Actual: airplane 

74 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane9-car2.png
Predicted: bicycle , Actual: airplane 

75 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane9-clock1.png
Predicted: clock , Actual: airplane 

76 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane9-keyboard2.png
Predicted: clock , Actual: airplane 

77 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane9-knife3.png
Predicted: knife , Actual: airplane 

78 , url: ../stimuli/style-transfer-preprocessed-512/airplane/airplane9-oven2.png
Predicted: chair , Actual: airplane 

79 , url: ../stimuli/styl