In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import transforms, datasets

In [2]:
# Single project root instead of three hard-coded absolute paths; override it with
# the CIFAR10_PROJECT_ROOT environment variable to run from another checkout/machine.
PROJECT_ROOT = os.environ.get('CIFAR10_PROJECT_ROOT', '/home/jovyan/cifar10-practice')
DATA_DIR = os.path.join(PROJECT_ROOT, 'data', 'processed', 'cifar10_original_png')

train_dir = os.path.join(DATA_DIR, 'train')
test_dir = os.path.join(DATA_DIR, 'test')
model_output_path = os.path.join(PROJECT_ROOT, 'models', 'cifar10.pt')

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using PyTorch version:', torch.__version__, ' Device:', DEVICE)

Using PyTorch version: 1.10.0+cu102  Device: cpu


In [4]:
class CNN(nn.Module):
    """Small two-convolution classifier for 3x32x32 images with 10 classes.

    Pipeline: conv3x3(3->8) -> ReLU -> maxpool2 -> conv3x3(8->16) -> ReLU -> maxpool2
    -> flatten (16*8*8) -> fc(1024->64) -> ReLU -> fc(64->32) -> ReLU -> fc(32->10)
    -> log_softmax over the class dimension.

    The layer attribute names below are the keys of the saved ``state_dict``;
    renaming them breaks loading of existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding=1 keep the spatial size; each 2x2 pool halves it,
        # so a 32x32 input reaches the head as 16 channels x 8 x 8 = 1024 features.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(8 * 8 * 16, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self, x):
        """Return per-class log-probabilities.

        Args:
            x: float tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, 10); ``exp`` of each row sums to 1.

        Raises:
            RuntimeError: if the spatial size is not 32x32 (feature count
                no longer matches ``fc1``).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # Flatten per sample, keeping the batch dim. Unlike view(-1, 1024), a
        # wrongly sized input now fails in fc1 instead of being silently split
        # into extra fake "samples".
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Explicit class dim: implicit-dim log_softmax is deprecated and warns.
        return F.log_softmax(self.fc3(x), dim=1)

In [5]:
new_model_path = '/home/jovyan/cifar10-practice/models/cifar10-65.20.pt'

# Class index -> label. NOTE(review): assumes training assigned indices in this
# (alphabetical) order, e.g. ImageFolder over class-named folders — TODO confirm.
idx_to_class = {0: 'airplane',
 1: 'automobile',
 2: 'bird',
 3: 'cat',
 4: 'deer',
 5: 'dog',
 6: 'frog',
 7: 'horse',
 8: 'ship',
 9: 'truck'}
# map_location remaps stored tensors onto DEVICE, so a checkpoint saved on a GPU
# can be restored on a CPU-only host (and vice versa).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model_reloaded = torch.load(new_model_path, map_location=DEVICE)
# eval() selects inference behaviour for layers such as dropout/batchnorm; CNN has
# none today, but this keeps inference correct if such layers are added.
loaded_model = CNN().to(DEVICE).eval()
loaded_model.load_state_dict(model_reloaded)

<All keys matched successfully>

In [6]:
from PIL import Image

test_img_path = '/home/jovyan/cifar10-practice/data/processed/cifar10_original_png/test/bird/00067.png'
img_to_tensor = transforms.ToTensor()

# The context manager closes the file handle. convert('RGB') returns a loaded
# in-memory copy with exactly 3 channels (conv1 expects in_channels=3), even if
# the PNG is palette, greyscale or RGBA.
with Image.open(test_img_path) as raw_img:
    test_img = raw_img.convert('RGB')
test_img_tensor = img_to_tensor(test_img)                   # (3, H, W), floats in [0, 1]
test_img_stack = test_img_tensor.unsqueeze(0).to(DEVICE)    # add batch dim -> (1, 3, H, W)

loaded_model.eval()
with torch.no_grad():  # inference only: don't build an autograd graph
    proba_logsoftmax = loaded_model(test_img_stack).tolist()[0]
# The model outputs log-probabilities; exponentiate to get probabilities.
proba = [np.exp(val) for val in proba_logsoftmax]
proba_max = max(proba)
proba_max_idx = int(np.argmax(proba))
pred_class = idx_to_class[proba_max_idx]
print(pred_class, proba_max)

bird 0.6601915246381649


