In [3]:
import torch
import numpy as np
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import os
from PIL import Image
import cv2 as cv

In [4]:
# Every category the model knows, in the same order as the class indices the
# model was trained with: `predict` uses the argmax index of the model output
# to look up `class_names[index]`, so this order must not change.
# The names are in plain sorted (code-point) order — presumably because the
# training data was loaded from one folder per class, sorted by name (TODO
# confirm against the training notebook).
# NOTE(review): 'Dark Sidious' looks like a typo for 'Darth Sidious'; it is
# kept verbatim since it presumably mirrors the training label name.
class_names = [
    'Admiral Ackbar', 'Admiral Piett', 'Anakin Skywalker', 'BB-8',
    'Bail Organa', 'Bib Fortuna', 'Boba Fett', 'Bodhi Rook',
    'C-3PO', 'Captain Phasma', 'Cassian Andor', 'Chewbacca',
    'Dark Sidious', 'Darth Maul', 'Darth Vader', 'Finn (FN-2187)',
    'General Grievous', 'General Hux', 'Grand Moff Tarkin', 'Greedo',
    'Han Solo', 'Jabba the Hutt', 'Jango Fett', 'Jar Jar Binks',
    'Jyn Erso', 'K-2SO', 'Kenobi', 'Kylo Ren',
    'Lando Calrissian', 'Luke Skywalker', 'Mace Windu', 'Maz Kanata',
    'Nien Nunb', 'Obi-Wan', 'Orson Krennic', 'Padme Amidala',
    'Poe Dameron', 'Princess Leia Organa', "Qi'ra", 'Qui-Gon Jinn',
    'R2-D2', 'Rey', 'Rose Tico', 'Saw Gerrera',
    'Supreme Leader Snoke', 'Tobias Beckett', 'Vice-Admiral Holdo', 'Watto',
    'Wedge Antilles', 'Wicket W. Warrick', 'Yoda',
]

In [None]:
# Load the trained model.
# The model file must be in the same directory as this notebook.
MODEL_PATH = "entire_model.pt"
# map_location="cpu": the checkpoint may have been saved from a GPU. Mapping it
# to the CPU lets it load on machines without CUDA and keeps it on the same
# device as the CPU tensors produced by `image_loader`.
# weights_only=False: the file holds a whole pickled model object (`.eval()` is
# called on it below), not just a state_dict, and newer PyTorch versions refuse
# to unpickle arbitrary objects by default.
# SECURITY: unpickling can execute arbitrary code - only load trusted files.
try:
    model_ft = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
except TypeError:
    # Older PyTorch releases do not accept the `weights_only` argument.
    model_ft = torch.load(MODEL_PATH, map_location="cpu")
# Switch to inference mode; the trailing ';' suppresses the long module repr.
model_ft.eval();

In [18]:
imsize = 256
# Preprocessing applied to every input image before inference:
# - Resize(imsize): with an int size, rescale so the shorter side is `imsize`
#   pixels, keeping the aspect ratio.
# - ToTensor(): convert the PIL image to a channels-first float tensor in [0, 1].
# `transforms.Resize` replaces the deprecated `transforms.Scale` alias, which
# later torchvision releases removed; for an int size the behavior is the same.
# NOTE(review): there is no CenterCrop/Normalize step here - confirm this
# matches the transforms used when the model was trained.
loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])

def image_loader(image_name: str) -> torch.Tensor:
    """
    Load an image from disk and preprocess it for the model.

    image_name: path to the image file
    return: float tensor of shape (1, 3, H, W) - the RGB image passed through
        `loader`, with a leading batch dimension of size 1
    """
    # Force 3-channel RGB. Grayscale, palette or RGBA files would otherwise
    # produce 1 or 4 channels, while the model presumably expects 3 (RGB).
    # The context manager closes the file once the converted copy exists.
    with Image.open(image_name) as pil_image:
        rgb_image = pil_image.convert("RGB")
    image = loader(rgb_image).float()
    # Mark the input as requiring gradients (replaces the deprecated
    # `torch.autograd.Variable(..., requires_grad=True)` wrapper). Plain
    # inference does not need this; it is kept so callers see the same tensor
    # flags as before.
    image.requires_grad_(True)
    # Add the batch dimension: models take (N, C, H, W) input, whatever the
    # architecture.
    return image.unsqueeze(0)

def predict(model, img, img_show):
    """
    Predict the character in a picture and plot the picture with the prediction.

    model: PyTorch model; the argmax of its output (dim 1) is used as an index
        into `class_names`
    img: preprocessed image batch used for inference (e.g. from `image_loader`);
        only the first image of the batch is reported
    img_show: image array to plot along with the predicted category. It is shown
        as-is, so pass RGB data (OpenCV's `imread` returns BGR).
    return: the predicted class name (a string)
    """
    was_training = model.training
    model.eval()
    try:
        # Run the input on the same device as the model's weights, so a model
        # living on the GPU also works with a CPU tensor from `image_loader`.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            img = img.to(first_param.device)
        with torch.no_grad():
            output = model(img)
    finally:
        # Restore the caller's train/eval mode even if inference fails.
        model.train(mode=was_training)

    # Index of the highest score for the first image in the batch. `.item()`
    # works on CPU and CUDA tensors alike, unlike `.numpy()`.
    prediction = int(output.argmax(dim=1)[0].item())
    predicted_label = class_names[prediction]

    _, ax = plt.subplots()
    ax.axis('off')
    ax.set_title('predicted: {}'.format(predicted_label))
    ax.imshow(img_show)
    return predicted_label

In [None]:
# Pick an image from the 'pred' folder, preprocess it and predict.
# Create a folder named 'pred' in the same directory as this notebook and put
# the image(s) there first.
PRED_DIR = "pred"
# Candidates are regular, non-hidden files that are not notebooks. This skips
# the '.ipynb_checkpoints' directory, '.DS_Store' and similar entries. The list
# is sorted so the chosen file is deterministic (os.listdir order is arbitrary).
candidates = sorted(
    name for name in os.listdir(PRED_DIR)
    if not name.startswith(".")
    and not name.endswith(".ipynb")
    and os.path.isfile(os.path.join(PRED_DIR, name))
)
if not candidates:
    raise FileNotFoundError("No image files found in {!r}".format(PRED_DIR))
path = candidates[0]
image_path = os.path.join(PRED_DIR, path)

# OpenCV returns None (instead of raising) when it cannot read a file, and it
# loads colour images in BGR channel order, whereas matplotlib expects RGB.
imagen = cv.imread(image_path)
if imagen is None:
    raise ValueError("OpenCV could not read {!r} as an image".format(image_path))
imagen = cv.cvtColor(imagen, cv.COLOR_BGR2RGB)

image = image_loader(image_path)
predict(model_ft, image, imagen)
# Optional: uncomment to delete the image after predicting, so the next run
# picks the next file in the folder.
# os.remove(image_path)