In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,TensorDataset
import copy
from sklearn.model_selection import train_test_split
import torchvision
import matplotlib.pyplot as plt
from IPython import display
# Render inline matplotlib figures as SVG (vector graphics) rather than the default raster format.
# NOTE(review): IPython.display.set_matplotlib_formats is deprecated in recent IPython releases;
# matplotlib_inline.backend_inline.set_matplotlib_formats is the suggested replacement — confirm
# against the IPython version in use.
FIGURE_FORMAT = 'svg'
display.set_matplotlib_formats(FIGURE_FORMAT)

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load the EMNIST 'letters' split from ./emnist, downloading it on first use.
cdata = torchvision.datasets.EMNIST(
    root='emnist',
    split='letters',
    download=True,
)


In [None]:
# Convert the raw pixel data to a float tensor laid out as (N, 1, 28, 28), the
# (batch, channel, height, width) layout nn.Conv2d expects. N is inferred (-1)
# instead of hard-coding the sample count (previously 124800), so this cell keeps
# working if the split or dataset size changes.
images = cdata.data.reshape(-1, 1, 28, 28).float()
# every image must still line up with a target (labels are derived from cdata.targets)
assert len(images) == len(cdata.targets), 'image/label count mismatch'
# normalize the images: scale so the largest pixel value becomes 1
images /= torch.max(images)
print('\nTensor data:')
print(images.shape)


In [None]:
# The network below has 26 outputs; the first entry of cdata.classes is dropped,
# presumably a non-letter placeholder — confirm against torchvision's EMNIST classes.
letterCategories = cdata.classes[1:]
# Shift targets down by one so class ids start at 0, as CrossEntropyLoss requires.
# The subtraction allocates a fresh tensor, so cdata.targets is never modified.
labels = cdata.targets - 1

In [None]:
# Hold out 10% of the images as a test set. A fixed random_state makes the split
# reproducible under Restart & Run All; stratify keeps each class's share of the
# data the same in both partitions.
SPLIT_SEED = 42
train_images, test_images, train_labels, test_labels = train_test_split(
    images, labels, test_size=.1, random_state=SPLIT_SEED, stratify=np.asarray(labels))

# wrap the tensors as datasets of (image, label) pairs
train_data = TensorDataset(train_images, train_labels)
test_data  = TensorDataset(test_images, test_labels)

batchsize    = 32
# reshuffle every epoch; drop_last discards an incomplete final batch
train_loader = DataLoader(train_data, batch_size=batchsize, shuffle=True, drop_last=True)
# the whole test set is served as a single batch
test_loader  = DataLoader(test_data, batch_size=len(test_data))

In [None]:
def makeTheNet(numchans=(6,6), lr=.001, numclasses=26):
  """Build a small two-conv-layer CNN classifier for 1x28x28 images.

  Args:
    numchans: (output channels of conv1, output channels of conv2).
    lr: Adam learning rate (default .001, the previously hard-coded value).
    numclasses: number of output logits (default 26).

  Returns:
    (net, lossfun, optimizer): the untrained model, a CrossEntropyLoss, and an
    Adam optimizer over the model's parameters.
  """

  class emnistnet(nn.Module):
    def __init__(self,numchans,numclasses=26):
      super().__init__()

      # 3x3 convs with padding=1 preserve height/width; forward() max-pools after each
      self.conv1  = nn.Conv2d(1,numchans[0],3,padding=1)
      self.bnorm1 = nn.BatchNorm2d(numchans[0])
      self.conv2  = nn.Conv2d(numchans[0],numchans[1],3,padding=1)
      self.bnorm2 = nn.BatchNorm2d(numchans[1])
      # two 2x2 pools shrink a 28x28 input to 14x14 then 7x7, hence 7*7 per channel
      self.fc1 = nn.Linear(7*7*numchans[1],50)
      self.fc2 = nn.Linear(50,numclasses)

    def forward(self,x):
      # x is expected as (batch, 1, 28, 28) to match fc1's input size
      x = F.max_pool2d(self.conv1(x),2)
      x = F.leaky_relu(self.bnorm1(x))
      x = F.max_pool2d(self.conv2(x),2)
      x = F.leaky_relu(self.bnorm2(x))
      # flatten all but the batch axis; unlike numel()/shape[0] this needs no
      # float division and cannot divide by zero on an empty batch
      x = torch.flatten(x,start_dim=1)
      x = F.leaky_relu(self.fc1(x))
      x = self.fc2(x)   # raw logits; CrossEntropyLoss applies log-softmax itself
      return x

  net = emnistnet(numchans,numclasses)
  lossfun = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(net.parameters(),lr=lr)

  return net,lossfun,optimizer