# Predict

In [1]:
import torch 
import torchvision
from torchvision import transforms

import numpy as np 
import pandas as pd 
import PIL
import os
import copy
import csv 

In [2]:
"""
ImageTransform class is called for all images for image transformation. 
Used for data augmentation, input image resizing, and converting images data into tensor floats. 
Note that data augmentation is only used in training data. Not validation and testing data.
"""
class ImageTransform():
    
    def __init__(self, size):
        self.interpolation_mode = transforms.InterpolationMode.BILINEAR
        self.data_transform = {
            'test': transforms.Compose([
                transforms.Resize(size, interpolation=self.interpolation_mode),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
        }
        
    def __call__(self, img, phase='test'):
        return self.data_transform[phase](img)

## Load

In [3]:
# Load the saved ensemble and switch every member to inference mode.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# trusted checkpoints. Newer PyTorch releases (2.6+) default to
# weights_only=True, which rejects a pickled container of full modules; if
# loading fails there, pass weights_only=False for this trusted file.
# map_location puts all tensors on the same device that `device` is set to
# in the next cell. A GPU-saved checkpoint then still loads on a CPU-only
# machine, and the models share a device with the inputs predict() sends.
_load_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load('saved_model/netv3_sr_xsuper', map_location=_load_device)
# eval() switches each member in place (it returns the same module), so no
# reassignment into the container is needed.
for member in model:
    member.eval()

In [4]:
# Every input image is resized to 32x32 before tensor conversion/normalization.
transform_func = ImageTransform(size=(32, 32))

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## Predict

In [6]:
import glob
import os

In [7]:
base_dir = None  # not referenced by any of the cells shown here
# NOTE: despite its name, this directory holds the images to be predicted.
train_dir = 'sr_test_shuffle'
# Every .jpg directly inside train_dir, in whatever order the filesystem yields.
test_list = list(glob.iglob(os.path.join(train_dir, '*.jpg')))

In [8]:
def predict(model, device, img_path, transform_func):
    """Classify one image by averaging the softmax outputs of an ensemble.

    Args:
        model: Sequence of networks. Each is called on a single-image batch
            and presumably returns logits with classes on dim 1. The networks
            must already live on ``device``; this function does not move them.
        device: ``torch.device`` the input batch is moved to.
        img_path: Path of the image file to classify.
        transform_func: Callable mapping a PIL image to a tensor (assumed
            ``(C, H, W)``); a leading batch dimension is added here.

    Returns:
        int: Index of the class with the highest mean probability.

    Raises:
        ValueError: If ``model`` is empty.
    """
    if len(model) == 0:
        # Averaging zero outputs would yield NaN and silently pick class 0.
        raise ValueError("predict() needs at least one model in the ensemble")

    # The context manager releases the file handle PIL keeps open for lazy
    # loading. Forcing RGB keeps grayscale/CMYK/RGBA files compatible with the
    # 3-channel normalization in the transform.
    with PIL.Image.open(img_path) as img:
        img_tensor = transform_func(img.convert('RGB'))
    batch = img_tensor.unsqueeze(0).to(device)

    # Inference only: no_grad avoids building (and holding) an autograd graph.
    with torch.no_grad():
        probas = [
            torch.nn.functional.softmax(net(batch), dim=1).cpu().numpy()
            for net in model
        ]

    # Mean over ensemble members, then the most probable class index.
    avg_outputs = np.mean(np.array(probas), axis=0)
    return int(np.argmax(avg_outputs))

In [9]:
# Class index (as returned by predict) -> label written to the output file.
mapping = {0: 'bird', 1: 'dog', 2: 'reptile'}
preds = {}
for img in test_list:
    # The numeric image id is the file name up to its first dot
    # (e.g. ".../123.jpg" -> 123). os.path.basename strips directories with
    # the platform's separator; on Windows, glob returns backslash paths
    # that a split on '/' would leave intact.
    img_id = int(os.path.basename(img).split('.')[0])
    if img_id in preds:
        # e.g. "7.jpg" and "007.jpg"; overwriting would silently drop a row.
        raise ValueError(f"duplicate image id {img_id} from {img!r}")
    result = predict(model, device, img, transform_func)
    preds[img_id] = mapping[result]

In [10]:
from collections import OrderedDict
# Re-key predictions in ascending image-id order. Plain dicts preserve
# insertion order (Python 3.7+), so no OrderedDict round-trip is needed.
preds = dict(sorted(preds.items()))
# One single-column row per image, in id order, for csv.writer.writerows.
output_list = [[label] for label in preds.values()]

In [11]:
import csv

# Write a one-column CSV with a 'predictions' header followed by the labels
# in output_list order. newline='' is what the csv module requires, so that
# it controls line endings itself (otherwise Windows gets blank rows).
with open('super_predictions.csv', 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerow(['predictions'])
    writer.writerows(output_list)