In [None]:
# neural network
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import models, transforms

from collections import OrderedDict

from PIL import Image

import helper

import os

In [None]:
# path
# Absolute path of the kernel's current working directory. The checkpoint and
# the sample image below are resolved relative to it, so the notebook must be
# started from the project directory.
PATH = os.path.abspath(os.curdir)
PATH

In [None]:
# load model
# Read the saved weights, remapping every tensor onto the CPU so a checkpoint
# written from a GPU session can be opened without CUDA, then list the
# parameter names it contains.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(
    f'{PATH}/model/cifar10_checkpoint.pth',
    map_location=torch.device('cpu'),
)
print(state_dict.keys())

In [None]:
# model architecture

class network(nn.Module):
    """Small two-convolution CNN producing 10 class logits for 32x32 RGB images.

    Shape trace for an input of shape (N, 3, 32, 32):
        conv1 (k=3)  -> (N, 64, 30, 30)  -> pool -> (N, 64, 15, 15)
        conv2 (k=5)  -> (N, 128, 11, 11) -> pool -> (N, 128, 5, 5)
        flatten      -> (N, 3200) -> fc1 (120) -> fc2 (60) -> output (10)

    The attribute names (conv1, conv2, fc1, fc2, output) determine the
    state-dict keys, so they must stay as-is for the saved checkpoint to load.
    """

    def __init__(self):
        super().__init__()

        # feature extractor
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5)

        # classifier head; 128 * 5 * 5 is the flattened conv2 output for 32x32 input
        self.fc1 = nn.Linear(128 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 60)

        self.output = nn.Linear(60, 10)

    def forward(self, x):
        """Return raw (un-softmaxed) logits of shape (N, 10) for a batch ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. Unlike
        # view(-1, 128 * 5 * 5) this can never silently split one image's
        # features into several fake batch rows; a mis-sized input is
        # rejected by fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        return self.output(x)

# complete CNN
model = network()
print(model)

In [None]:
# load state dict
# Reuse the checkpoint already read into `state_dict` by the "load model" cell
# instead of reading the same file from disk a second time. load_state_dict is
# strict by default, so any key or shape mismatch with `network` raises here.
model.load_state_dict(state_dict)
model.eval()

In [None]:
# define transformer, convert data to a normalized torch.FloatTensor
# ToTensor maps 8-bit pixels to [0, 1]; Normalize with mean 0.5 and std 0.5 on
# each channel then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5],
                         [0.5, 0.5, 0.5])
])

# load image

# Width in pixels of the enlarged preview shown as this cell's output.
basewidth = 300

# Force 3-channel RGB: greyscale, palette, CMYK or RGBA files would otherwise
# give a tensor whose channel count does not match conv1's 3 input channels.
# The context manager closes the file handle once the pixels are decoded.
with Image.open(PATH + '/katito_karla/02.jpg') as raw_img:
    img = raw_img.convert('RGB').resize((32, 32))

# Upscale the 32x32 network input for display only, keeping its aspect ratio.
# Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
wpercent = (basewidth / float(img.size[0]))
hsize = int((float(img.size[1]) * float(wpercent)))
img.resize((basewidth, hsize), Image.LANCZOS)

In [None]:
# Preprocess into a batch of one, presumably (1, 3, 32, 32) given the RGB
# 32x32 image prepared above. A new name is used so the PIL image in `img` is
# not overwritten and this cell can be re-run.
img_tensor = transform(img).unsqueeze(0)

# Inference only: no_grad avoids building an autograd graph.
with torch.no_grad():
    output = model(img_tensor)

# Index of the highest logit for each sample in the batch.
_, preds_tensor = torch.max(output, 1)

preds = preds_tensor.numpy()

# specify the image classes
# NOTE(review): presumably the standard CIFAR-10 label order; it must match the
# label indices the checkpoint was trained with.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

print('class predicted: ')
classes[preds[0]]