# Predict

In [1]:
import torch 
import torchvision
from torchvision import transforms

import numpy as np 
import pandas as pd 
import PIL
import os
import copy
import csv 

In [2]:
class SGNet(torch.nn.Module):
    """Two-branch classifier built on slices of an existing network.

    The first 21 children of ``supernet`` form a shared trunk.  Its output is
    fed to two parallel branches:

    * ``scb`` -- children 21..29 of ``supernet``, reused as-is;
    * ``fcb`` -- a new MLP (Linear -> BatchNorm1d -> ReLU, four times) that
      widens 1024 input features to 8192.

    Both branch outputs are concatenated along dim 1 and mapped to 90 logits.
    The output layer expects 16384 features, so ``scb`` has to contribute the
    8192 that ``fcb`` does not.

    Args:
        supernet: module container that supports slicing (for example a
            ``torch.nn.Sequential``) with at least 30 children.
            NOTE(review): presumably a pretrained backbone whose trunk emits
            flat 1024-feature vectors -- confirm against the training code.
    """

    def __init__(self, supernet):
        super().__init__()

        # Shared trunk taken straight from the donor network.
        self.cnn_block = supernet[:21]

        # Branch 1: the donor network's next nine children.
        self.scb = supernet[21:30]

        # Branch 2: MLP described by consecutive (in, out) feature widths.
        # Layers are appended in Linear/BatchNorm1d/ReLU order so the
        # Sequential indices (and therefore state-dict keys) stay 0..11.
        widths = (1024, 2048, 4096, 8192, 8192)
        mlp_layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            mlp_layers.append(torch.nn.Linear(fan_in, fan_out))
            mlp_layers.append(torch.nn.BatchNorm1d(fan_out))
            mlp_layers.append(torch.nn.ReLU(inplace=True))
        self.fcb = torch.nn.Sequential(*mlp_layers)

        # Output layer: concatenated branch features -> 90 class scores.
        self.output = torch.nn.Sequential(torch.nn.Linear(16384, 90))

    def forward(self, inputs):
        """Return the raw (un-normalised) class scores for ``inputs``."""
        features = self.cnn_block(inputs)
        branch_outputs = (self.scb(features), self.fcb(features))
        return self.output(torch.cat(branch_outputs, dim=1))

In [3]:
class ImageTransform:
    """Callable that turns an input image into a normalised float tensor.

    Only a ``'test'`` pipeline is defined here: resize to ``size`` with
    bilinear interpolation, convert to a tensor, then normalise each of the
    three channels.  No data augmentation is applied, which is what
    validation and test data call for; augmentation belongs to training only.

    Args:
        size: target size forwarded unchanged to ``transforms.Resize``.
    """

    def __init__(self, size):
        self.interpolation_mode = transforms.InterpolationMode.BILINEAR

        # Per-channel mean / std; these are the widely used ImageNet values.
        channel_mean = (0.485, 0.456, 0.406)
        channel_std = (0.229, 0.224, 0.225)

        test_pipeline = transforms.Compose([
            transforms.Resize(size, interpolation=self.interpolation_mode),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

        # Keyed by phase name so __call__ can select a pipeline by string.
        self.data_transform = {'test': test_pipeline}

    def __call__(self, img, phase='test'):
        """Apply the pipeline registered for ``phase`` to ``img``.

        Raises:
            KeyError: if ``phase`` has no registered pipeline.
        """
        pipeline = self.data_transform[phase]
        return pipeline(img)

## Load

In [4]:
# SECURITY NOTE: torch.load un-pickles arbitrary Python objects, so this must
# only ever be pointed at a checkpoint file that you trust.
#
# map_location places the ensemble on the same device the notebook selects for
# its inputs ("cuda:0" when CUDA is available, otherwise the CPU).  Without it
# the models stay on whatever device they were saved from, and prediction
# fails with a device mismatch on any machine that differs from the one used
# for training.
sgnetv2_sr = torch.load(
    'saved_model/sgnetv2_sub',
    map_location='cuda:0' if torch.cuda.is_available() else 'cpu',
)

# Put every ensemble member into inference mode.  eval() changes the module in
# place (and returns it), so there is nothing to assign back.
for member in sgnetv2_sr:
    member.eval()

In [5]:
# Preprocessing callable used for every prediction; (8, 8) is the size that
# ImageTransform hands to transforms.Resize.
transform_func = ImageTransform((8, 8))

In [6]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device  # bare expression: echoes the selected device as the cell output

device(type='cuda', index=0)

## Predict

In [7]:
import glob
import os

In [8]:
base_dir = None  # not referenced by any code visible in this notebook
# NOTE(review): despite the name, `train_dir` holds the images to run
# prediction on, not training data.
train_dir = 'test_shuffle'
# Paths of all .jpg files directly inside that directory.  glob returns them
# in arbitrary order; the predictions are sorted by image id further down.
test_list = glob.glob(os.path.join(train_dir, '*.jpg'))

In [9]:
def predict(model, device, img_path, transform_func, thresh, novel_class=89):
    """Classify one image with an ensemble, rejecting low-confidence results.

    Each ensemble member produces softmax probabilities for the image; the
    probabilities are averaged over the members.  If the largest averaged
    probability is strictly greater than ``thresh`` its class index is
    returned, otherwise the image is reported as the novel class.

    Args:
        model: non-empty sequence of callables (the ensemble members), each
            mapping a batch of one preprocessed image to a row of class scores.
        device: torch device the input tensor is moved to before inference.
            The members are expected to live on this device already.
        img_path: path of the image file to classify.
        transform_func: callable turning a PIL image into a tensor.
        thresh: confidence threshold for accepting the top class.
        novel_class: index returned when the confidence does not exceed
            ``thresh``.  Defaults to 89.

    Returns:
        int: the predicted class index, or ``novel_class``.

    Raises:
        ValueError: if ``model`` is empty (there is nothing to average).
    """
    if len(model) == 0:
        raise ValueError("predict() needs at least one model in the ensemble")

    # Decode and preprocess the image once: the result does not depend on the
    # ensemble member.  The context manager closes the file handle promptly,
    # and convert('RGB') guarantees the three channels the normalisation step
    # expects, even for greyscale or CMYK JPEGs.
    with PIL.Image.open(img_path) as img:
        batch = transform_func(img.convert('RGB')).unsqueeze(0).to(device)

    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        member_probas = [
            torch.nn.functional.softmax(member(batch), dim=1).cpu().numpy()
            for member in model
        ]

    # Average over the ensemble axis.
    avg_probas = np.mean(np.array(member_probas), axis=0)

    if np.max(avg_probas) > thresh:
        return int(np.argmax(avg_probas))
    return novel_class

In [10]:
# Lookup table: integer class index (CSV column 0) -> value written to the
# predictions file for that class (CSV column 1).
map_dict = {}

# newline='' is what the csv module requires for files handed to csv.reader,
# so that line endings inside quoted fields are parsed correctly.
with open('sub_classes_mapping.csv', newline='') as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    next(csv_reader, None)  # skip the header; tolerate a completely empty file
    # csv.reader yields [] for blank lines (e.g. a trailing newline at the end
    # of the file); skip them instead of failing on row[0].
    map_dict.update((int(row[0]), row[1]) for row in csv_reader if row)

In [11]:
# image id -> mapped prediction for every test image.
preds = {}
for img_path in test_list:
    # File names are expected to look like "<integer id>.<ext>".  basename()
    # strips the directory using the platform's own separator rules, so this
    # also works where os.path.join produced backslashes.
    img_id = int(os.path.basename(img_path).split('.')[0])

    # 0.7 is the confidence threshold below which an image is treated as the
    # novel class.
    result = predict(sgnetv2_sr, device, img_path, transform_func, 0.7)
    preds[img_id] = map_dict[result]

In [12]:
from collections import OrderedDict  # import kept so the notebook namespace is unchanged

# Re-key the predictions in ascending image-id order; a plain dict preserves
# insertion order, so building it from the sorted items is enough.
preds = dict(sorted(preds.items()))

# One single-column row per image, in the same sorted order.
output_list = [[label] for label in preds.values()]

In [13]:
# Write the predictions as a one-column CSV with a header row.
# newline='' is required by the csv module for files handed to csv.writer;
# without it, platforms that translate "\n" on write (Windows) end up with an
# extra blank line after every row.
with open('sub_predictions_th070.csv', 'w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(['predictions'])
    writer.writerows(output_list)