In [11]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from math import floor
import os
import torchvision.models
import time
from torchvision.models import resnet50, ResNet50_Weights

torch.manual_seed(106)
data_transforms = transforms.Compose([
                    transforms.Resize((224,224)),
                    transforms.ToTensor(),
                    transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
                    ])


import cv2
import pandas as pd

pokemon_names = ['Abra', 'Aerodactyl', 'Aipom', 'Alakazam', 'Ampharos', 'Arbok', 'Arcanine', 'Ariados', 'Articuno', 'Azumarill', 'Bayleef', 'Beedrill', 'Bellossom', 'Bellsprout', 'Blastoise', 'Blissey', 'Bulbasaur', 'Butterfree', 'Caterpie', 'Celebi', 'Chansey', 'Charizard', 'Charmander', 'Charmeleon', 'Chikorita', 'Chinchou', 'Clefable', 'Clefairy', 'Cleffa', 'Cloyster', 'Corsola', 'Crobat', 'Croconaw', 'Cubone', 'Cyndaquil', 'Delibird', 'Dewgong', 'Diglett', 'Ditto', 'Dodrio', 'Doduo', 'Donphan', 'Dragonair', 'Dragonite', 'Dratini', 'Drowzee', 'Dugtrio', 'Dunsparce', 'Eevee', 'Ekans', 'Electabuzz', 'Electrode', 'Elekid', 'Entei', 'Espeon', 'Exeggcute', 'Exeggutor', 'Farfetch\'d', 'Fearow', 'Feraligatr', 'Flaaffy', 'Flareon', 'Forretress', 'Furret', 'Gastly', 'Gengar', 'Geodude', 'Girafarig', 'Gligar', 'Gloom', 'Golbat', 'Goldeen', 'Golduck', 'Golem', 'Granbull', 'Graveler', 'Grimer', 'Growlith', 'Gyarados', 'Haunter', 'Heracross', 'Hitmonchan', 'Hitmonlee', 'Hitmontop', 'Ho-oh', 'Hoothoot', 'Hoppip', 'Horsea', 'Houndoom', 'Houndour', 'Hypno', 'Igglybuff', 'Ivysaur', 'Jigglypuff', 'Jolteon', 'Jumpluff', 'Jynx', 'Kabuto', 'Kabutops', 'Kadabra', 'Kakuna', 'Kangaskhan', 'Kingdra', 'Kingler', 'Koffing', 'Krabby', 'Lanturn', 'Lapras', 'Larvitar', 'Ledian', 'Ledyba', 'Lickitung', 'Lugia', 'Machamp', 'Machoke', 'Machop', 'Magby', 'Magcargo', 'Magikarp', 'Magmar', 'Magnemite', 'Magneton', 'Mankey', 'Mantine', 'Mareep', 'Marill', 'Marowak', 'Meganium', 'Meowth', 'Metapod', 'Mew', 'Mewtwo', 'Miltank', 'Misdreavus', 'Moltres', 'Mr. Mime', 'Muk', 'Murkrow', 'Natu', 'Nidoking', 'Nidoqueen', 'Nidorina', 'Nidorino', 'Ninetales', 'Noctowl', 'Octillery', 'Oddish', 'Omanyte', 'Omastar', 'Onix', 'Paras', 'Parasect', 'Persian', 'Phanpy', 'Pichu', 'Pidgeot', 'Pidgeotto', 'Pidgey', 'Pikachu', 'Piloswine', 'Pineco', 'Pinsir', 'Politoed', 'Poliwag', 'Poliwhirl', 'Poliwrath', 'Ponyta', 'Porygon', 'Porygon2', 'Primeape', 'Psyduck', 'Pupitar', 'Quagsire', 'Quilava', 'Qwilfish', 'Raichu', 'Raikou', 'Rapidash', 'Raticate', 'Rattata', 'Remoraid', 'Rhydon', 'Rhyhorn', 'Sandshrew', 'Sandslash', 'Scizor', 'Scyther', 'Seadra', 'Seaking', 'Seel', 'Sentret', 'Shellder', 'Shuckle', 'Skarmory', 'Skiploom', 'Slowbro', 'Slowking', 'Slowpoke', 'Slugma', 'Smeargle', 'Smoochum', 'Sneasel', 'Snorlax', 'Snubbull', 'Spearow', 'Spinarak', 'Squirtle', 'Stantler', 'Starmie', 'Staryu', 'Steelix', 'Sudowoodo', 'Suicune', 'Sunflora', 'Sunkern', 'Swinub', 'Tangela', 'Tauros', 'Teddiursa', 'Tentacool', 'Tentacruel', 'Togepi', 'Togetic', 'Totodile', 'Typhlosion', 'Tyranitar', 'Tyrogue', 'Umbreon', 'Unown', 'Ursaring', 'Vaporeon', 'Venomoth', 'Venonat', 'Venusaur', 'Victreebel', 'Vileplume', 'Voltorb', 'Vulpix', 'Wartortle', 'Weedle', 'Weepinbell', 'Weezing', 'Wigglytuff', 'Whooper', 'Wobbuffet', 'Xatu', 'Yanma', 'Zapdos', 'Zubat']

In [12]:
model_test_path = "E:\Files\Github Files\Pokemon_Image_Recognition\DataSet\Model Test\Test"

model_test_datasets = torchvision.datasets.ImageFolder(root=model_test_path, transform = data_transforms)
model_test_loader = torch.utils.data.DataLoader(model_test_datasets, batch_size=1, num_workers=1, shuffle=False, collate_fn=None, pin_memory=False)

# Model Testing

In [13]:
best_model_path = 'result/saved_torch/model_{0}_bs{1}_lr{2}_epoch{3}'.format('ResNet50_SGD', 32, 0.001, 15)
num_classes = 249  # specify the number of classes in your saved model
best_model = resnet50(pretrained=False)  # create a new model instance
best_model.fc = torch.nn.Linear(2048, num_classes)  # replace the fully connected layer
best_model.load_state_dict(torch.load(best_model_path))
best_model = best_model.cuda()  # move the model to the GPU



In [14]:
testing_names = os.listdir(model_test_path)
correct_index = []
for i in testing_names:
    correct_index.append(pokemon_names.index(i))
print(correct_index)
def convert_to_real_index(input):
    integer = int(input)
    return correct_index[integer]

[2, 23, 43, 112, 146, 158]


In [15]:
def get_accuracy(model, test_loader):
    data = test_loader

    correct = 0
    total = 0
    for imgs, labels in data:
        if use_cuda == True:
          imgs = imgs.cuda()
          labels = labels.cuda()
        
        output = model(imgs)

        #select index with maximum prediction score
        pred = output.max(1, keepdim=True)[1]

        pred_numpy = pred.cpu().numpy()
        labels_numpy = labels.cpu().numpy()
        # print(labels_numpy[0])
        
        for pred, label in zip(pred_numpy, labels_numpy):
            true_label = convert_to_real_index(label)
            print("Predicted Pokemon:" + str(pokemon_names[pred[0]]))
            print("Actual Pokemon:" + str(pokemon_names[true_label])+ '\n')


In [16]:
use_cuda = True
test_accuracy = get_accuracy(best_model, model_test_loader)

Predicted Pokemon:Aipom
Actual Pokemon:Aipom

Predicted Pokemon:Charmeleon
Actual Pokemon:Charmeleon

Predicted Pokemon:Dragonite
Actual Pokemon:Dragonite

Predicted Pokemon:Tentacruel
Actual Pokemon:Lugia

Predicted Pokemon:Oddish
Actual Pokemon:Oddish

Predicted Pokemon:Pikachu
Actual Pokemon:Pikachu



<img src='result/ResNet50_SGD_batch_size_32_learning_rate_0.001_num_epochs_16_acc_curve.png' width='800' align='left' />

In [18]:
result

Unnamed: 0,epochs,train_acc,valid_acc,loss,time
0,1,0.857681,0.782974,0.043403,80.950003
1,2,0.960392,0.867706,0.009298,80.597164
2,3,0.98755,0.906075,0.003828,80.202336
3,4,0.993173,0.908873,0.001862,80.785075
4,5,0.999046,0.917266,0.005402,80.236235
5,6,0.998946,0.922462,0.001225,80.399127
6,7,0.999347,0.923661,0.001278,80.348723
7,8,0.999297,0.928457,0.000789,80.989134
8,9,0.999749,0.934452,0.001839,80.448843
9,10,0.999598,0.927258,0.000739,80.467639


In [17]:
# import result/ResNet50_SGD_batch_size_32_learning_rate_0.001_num_epochs_16.csv as pandas dataframe
result = pd.read_csv('result/ResNet50_SGD_batch_size_32_learning_rate_0.001_num_epochs_16.csv')