In [1]:
import numpy as np
import torch
from torch import nn
from torchvision import transforms
from PIL import Image

In [2]:
# Model definition
class Network(nn.Module):
    """Convolutional classifier mapping a 1-channel image to 27 class scores.

    The fully connected head takes 128 flattened features, which is what the
    three pooling stages leave for a 32x32 input (32 channels x 2 x 2).

    Attribute names and their creation order are kept stable because saved
    weights are matched to layers by attribute name.
    """

    def __init__(self):
        super().__init__()
        # Normalisation applied directly to the raw single-channel input.
        self.norm1 = nn.BatchNorm2d(num_features=1)

        # First convolutional stage; both convolutions keep the spatial size.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.drop1 = nn.Dropout(p=0.1)  # registered, but forward() never applies it
        self.norm2 = nn.BatchNorm2d(num_features=32)

        # First downsampling stage.
        self.avg1 = nn.AvgPool2d(kernel_size=4, stride=2)
        self.drop2 = nn.Dropout(p=0.5)
        self.norm3 = nn.BatchNorm2d(num_features=32)

        # Second convolutional stage; spatial size is again preserved.
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.norm4 = nn.BatchNorm2d(num_features=32)
        self.drop3 = nn.Dropout(p=0.1)  # registered, but forward() never applies it

        # Second downsampling stage.
        self.avg2 = nn.AvgPool2d(kernel_size=4, stride=2)
        self.norm5 = nn.BatchNorm2d(num_features=32)
        self.drop4 = nn.Dropout(p=0.5)

        # Third downsampling stage.
        self.avg3 = nn.AvgPool2d(kernel_size=4, stride=2)
        self.norm7 = nn.BatchNorm2d(num_features=32)

        # Classification head: 128 features -> 64 hidden units -> 27 classes.
        self.fc1 = nn.Linear(in_features=128, out_features=64)
        self.drop6 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(in_features=64, out_features=27)

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for a batch of images."""
        activate = nn.functional.relu

        # Input normalisation, two stacked convolutions, then norm + ReLU.
        x = activate(self.norm2(self.conv2(self.conv1(self.norm1(x)))))

        # Pool -> dropout -> norm -> ReLU.
        x = activate(self.norm3(self.drop2(self.avg1(x))))

        # Two more stacked convolutions, then norm + ReLU.
        x = activate(self.norm4(self.conv4(self.conv3(x))))

        # Pool -> norm -> dropout -> ReLU (dropout comes after the norm here).
        x = activate(self.drop4(self.norm5(self.avg2(x))))

        # Final pool -> norm -> ReLU.
        x = activate(self.norm7(self.avg3(x)))

        # Collapse channels and spatial positions into one feature vector per sample.
        _, channels, height, width = x.shape
        x = x.view(-1, channels * height * width)

        x = activate(self.drop6(self.fc1(x)))
        return self.fc2(x)

In [3]:
# Takes in PIL image and processes it in preparation for model
def process_img(img):
  """Turn a PIL image into the grayscale tensor the letter model consumes.

  The image is shrunk to fit inside 32x32, coerced to three channels,
  zero-padded to exactly 32x32 with the content centred, scaled by 1/256 and
  finally reduced to grayscale by the torchvision transform.

  Args:
    img: PIL image. It is copied before resizing, so the caller's image is
      left unmodified.

  Returns:
    The tensor produced by the Grayscale/Normalize pipeline, without a batch
    dimension.
  """
  transform = transforms.Compose([
    transforms.Grayscale(),
    # Mean 0 and std 1 leave the values as they are.
    transforms.Normalize(0, 1)
  ])

  # thumbnail() resizes the image it is called on, so operate on a copy to
  # avoid shrinking the caller's image as a side effect.
  img = img.copy()
  img.thumbnail((32, 32))
  img_arr = np.asarray(img)

  # Single-channel (2-D) input: replicate it into three identical channels.
  if len(img_arr.shape) < 3:
    img_arr = np.expand_dims(img_arr, 2)
    img_arr = np.repeat(img_arr, 3, 2)
  # More than three channels (e.g. an alpha channel): keep only the first three.
  if img_arr.shape[2] > 3:
    img_arr = img_arr[:,:,:3]
  # NOTE(review): arrays with exactly two channels pass through both checks
  # unchanged - confirm whether such inputs need to be supported.

  # Split the missing rows/columns between the two sides so the content is
  # centred; when the gap is odd, the extra pixel goes to the bottom/right.
  tpad = (32 - img_arr.shape[0]) // 2
  bpad = 32 - tpad - img_arr.shape[0]
  lpad = (32 - img_arr.shape[1]) // 2
  rpad = 32 - img_arr.shape[1] - lpad

  # np.pad defaults to constant zero padding.
  img_arr = np.pad(img_arr, ((tpad, bpad), (lpad, rpad), (0, 0)))
  # Height-width-channel -> channel-height-width, then scale by 1/256.
  img_arr = img_arr.transpose(2, 0, 1).astype(np.double) / 256
  img_tensor = transform(torch.Tensor(img_arr))
  return img_tensor

In [4]:
# Takes in PIL image and gives letter prediction. Returns None if prediction is a blank tile.
def get_prediction(img, model, blank_index=26):
  """Predict the letter shown in a PIL image.

  Args:
    img: PIL image of a single tile.
    model: Classifier whose output index ``i`` (0-25) stands for the letter
      ``chr(65 + i)``. The model is switched to eval mode as a side effect.
    blank_index: Class index treated as "blank tile". Defaults to 26, the
      only non-letter output of the 27-way network defined in this file.
      The previous check compared against 36, which a 27-way output can
      never produce, so a class-26 prediction came back as '[' instead of
      None. NOTE(review): confirm 26 matches the label order used in
      training.

  Returns:
    The predicted uppercase letter, or None when the blank class is predicted.
  """
  batch = process_img(img).unsqueeze(0)  # add the batch dimension
  model.eval()
  # Inference only: skip autograd bookkeeping.
  with torch.no_grad():
    scores = model(batch)
  # Softmax is monotonic, so taking argmax of the raw scores picks the same class.
  pred = int(scores.argmax(dim=1)[0])
  if pred == blank_index:
    return None
  return chr(ord("A") + pred)

In [5]:
# Build the network, then restore its pretrained weights onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load weight files from
# a trusted source (consider weights_only=True where the installed PyTorch
# supports it).
model = Network()
state_dict = torch.load(
    "letter_model_state.sav",
    map_location="cpu",
)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [6]:
# Read a sample tile image from disk and classify it
img = Image.open(fp="o.png")
get_prediction(img=img, model=model)

'O'