In [None]:
import imp
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from context import rf_pool

In [None]:
# Load the EMNIST 'byclass' training split (downloaded to ../data if missing).
# The transform is deterministic: the rotation range (90, 90) holds a single
# angle and the vertical flip is applied with probability 1, so every image is
# rotated by 90 degrees and flipped before being converted to a tensor.
# NOTE(review): presumably this undoes EMNIST's transposed image storage -- confirm.
transform = transforms.Compose([
    transforms.RandomRotation(degrees=(90, 90)),
    transforms.RandomVerticalFlip(p=1.0),
    transforms.ToTensor(),
])
trainset = torchvision.datasets.EMNIST(
    root='../data',
    split='byclass',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# Wrap the training set in a DataLoader that yields shuffled batches of a
# single sample, fetched by two worker processes.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=1,
    shuffle=True,
    num_workers=2,
)

In [None]:
# initialize model
# The network is instantiated without constructor arguments; its layers are
# added afterwards through model.append(...).
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# Register the layers on the model, each under a string id.
# Layers '0' and '1' each combine a 5x5 convolution with a ReLU activation and
# 2x2 max-pooling (1 -> 16 -> 32 channels).
model.append(
    '0',
    rf_pool.modules.FeedForward(
        hidden=torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5),
        activation=torch.nn.ReLU(),
        pool=torch.nn.MaxPool2d(kernel_size=2),
    ),
)
model.append(
    '1',
    rf_pool.modules.FeedForward(
        hidden=torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5),
        activation=torch.nn.ReLU(),
        pool=torch.nn.MaxPool2d(kernel_size=2),
    ),
)
# Layer '2' is a bare 4x4 convolution producing 64 channels.
model.append(
    '2',
    rf_pool.modules.FeedForward(
        hidden=torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4),
    ),
)
# Layer '3' feeds the 64 features (reshaped via input_shape=(-1, 64)) into two
# linear heads with 10 and 52 outputs respectively.
# NOTE(review): with cat_output=True the two head outputs are presumably
# concatenated into one 62-wide output -- confirm against rf_pool.modules.Branch.
branch0 = rf_pool.modules.FeedForward(
    hidden=torch.nn.Linear(in_features=64, out_features=10),
)
branch1 = rf_pool.modules.FeedForward(
    hidden=torch.nn.Linear(in_features=64, out_features=52),
)
model.append(
    '3',
    rf_pool.modules.Branch(
        input_shape=(-1, 64),
        branches=[branch0, branch1],
        cat_output=True,
    ),
)

In [None]:
# print output_shapes with branch output concatenated
# The argument is the input shape handed to the model; (1, 1, 28, 28) is
# presumably (batch, channels, height, width) -- confirm against rf_pool.
# Left as the bare last expression so the notebook displays the result.
model.output_shapes((1,1,28,28))

In [None]:
# Optimizer: SGD with momentum 0.9 and learning rate 1e-3 over all model parameters.
optim = torch.optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)
# Loss: cross-entropy between the model output and the integer class label.
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# Train the model and keep the value returned by model.train as loss_history.
# While training, the weights of layer '0' are shown with a gray colormap.
# NOTE(review): the leading 1 is presumably the number of epochs and
# monitor=100 the reporting interval in batches -- confirm against rf_pool.
loss_history = model.train(
    1, trainloader, loss_fn, optim,
    monitor=100,
    show_weights=dict(layer_id='0', cmap='gray'),
)

In [None]:
# set concatenate output to False
# This mutates the Branch layer registered under id '3' in place; the
# prediction cell below indexes the model output as outputs[0] / outputs[1].
# NOTE(review): presumably the model now returns one output per branch, so
# re-running the training cell after this one may no longer work -- confirm.
model.layers['3'].cat_output = False

In [None]:
# Print the true label and both branch predictions for one batch drawn from
# the training loader.
# Lookup tables: `letters` holds 'A'-'Z' then 'a'-'z' (ASCII 65-90 and 97-122);
# `digits_letters` prepends the digits 0-9 so it can be indexed by the label.
letters = [chr(n) for n in np.concatenate([np.arange(65,91), np.arange(97,123)])]
digits_letters = np.concatenate([np.arange(10), letters])
# Use the builtin next(): Python 3 iterators implement __next__, and the
# DataLoader iterator has no .next() method in current PyTorch releases.
data = next(iter(trainloader))
print('actual:', digits_letters[data[1]])
outputs = model(data[0])
# outputs[0] is read as the digit branch and outputs[1] as the letter branch.
print('predicted digit:', torch.argmax(outputs[0]).item())
# .item() converts the 0-dim tensor to a plain int before indexing the list,
# matching the digit line above.
print('predicted letter:', letters[torch.argmax(outputs[1]).item()])