In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms

from context import rf_pool

**Load MNIST**

In [None]:
# Load the MNIST train and test splits from ../data (download=True lets
# torchvision fetch them). Images only go through ToTensor -- no
# normalization transform is applied.
transform = transforms.Compose([transforms.ToTensor()])
mnist_kwargs = dict(root='../data', download=True, transform=transform)
trainset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
testset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

In [None]:
# Wrap both splits in DataLoaders: batches of 10 for training, 100 for testing.
# The test loader is shuffled too. Metrics.test_acc scores only the first batch
# it draws, so shuffling (presumably intentional) makes that batch a random
# subset of the test set rather than always the same 100 images.
loader_opts = dict(shuffle=True, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=10, **loader_opts)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, **loader_opts)

**Build Model**

In [None]:
# Construct the rf_pool feed-forward container with no arguments; its layers
# are added with model.append(...) in the next cell, so re-run this cell
# before re-running that one.
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# Build the network layer by layer (keys '0'..'3'). For a 28x28 input with
# these stride-1, unpadded convolutions the spatial size goes
# 28 -> 24 -> 12 (conv5 + pool) -> 8 -> 4 (conv5 + pool) -> 1 (conv4),
# leaving 64 channels of 1x1 that layer '3' reshapes to (-1, 64) and maps to
# 10 class scores.
# NOTE(review): layer '2' is given no activation, so it feeds the Linear
# readout directly (two consecutive affine maps) -- confirm this is intended.
# Re-run the model-construction cell before re-running this one.
conv_pool_stages = [(1, 16), (16, 32)]
for layer_idx, (in_ch, out_ch) in enumerate(conv_pool_stages):
    model.append(str(layer_idx), rf_pool.modules.FeedForward(
        hidden=torch.nn.Conv2d(in_ch, out_ch, 5),
        activation=torch.nn.ReLU(),
        pool=torch.nn.MaxPool2d(2),
    ))
model.append('2', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(32, 64, 4)))
model.append('3', rf_pool.modules.FeedForward(input_shape=(-1, 64),
                                              hidden=torch.nn.Linear(64, 10)))

In [None]:
# Print each layer's output shape for one MNIST-sized input
# (batch 1, 1 channel, 28x28) as a sanity check on the architecture.
probe_shape = (1, 1, 28, 28)
model.output_shapes(probe_shape)

In [None]:
# Cross-entropy on the 10 outputs of the final Linear layer. Layer '3' is given
# no activation, so these are presumably raw logits, which is what
# CrossEntropyLoss expects. SGD with momentum optimizes every model parameter.
loss_fn = torch.nn.CrossEntropyLoss()
sgd_settings = {'lr': 0.001, 'momentum': 0.9}
optim = torch.optim.SGD(model.parameters(), **sgd_settings)

**Set Up TensorBoard**

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard
# magic used further down.
%load_ext tensorboard

In [None]:
# Event files for this run go under runs/, the directory the %tensorboard
# cell watches.
# NOTE(review): re-running with the same log_dir mixes events from several
# runs in one directory -- change the name for a fresh experiment.
writer = SummaryWriter(log_dir='runs/mnist_experiment1')

In [None]:
# Open TensorBoard inside the notebook, watching runs/ (the parent of the
# writer's runs/mnist_experiment1 directory). Come back to this output once
# training has started to follow the logged values.
%tensorboard --logdir runs

**Define Metrics and Train the Model**

In [None]:
class Metrics(object):
    """Evaluation hooks handed to ``model.train_model`` via ``metrics=``.

    In this notebook ``test_acc`` is requested with
    ``test_acc={'dataloader': testloader}``, so its keyword arguments
    presumably come from that dict (NOTE(review): confirm how rf_pool's
    ``train_model`` dispatches metric calls).
    """

    def __init__(self, model=None):
        """Optionally bind the network to evaluate.

        Args:
            model: callable mapping a batch of inputs to class scores. When
                ``None`` (the default, as in ``Metrics()``), the notebook-level
                ``model`` global is looked up at call time, as before.
        """
        self.model = model

    def test_acc(self, dataloader):
        """Accuracy of the model on a single batch drawn from ``dataloader``.

        Only the first batch of a fresh iterator is scored. With a shuffled
        loader, each call therefore measures a random subset rather than the
        whole test set.

        Args:
            dataloader: iterable yielding ``(inputs, labels)`` pairs.

        Returns:
            float: fraction of samples in the batch whose arg-max over the last
            dimension of the model output equals the label.
        """
        net = self.model if self.model is not None else model
        # next(iter(...)) instead of iter(...).next(): Python 3 iterators only
        # define __next__, and newer PyTorch DataLoader iterators no longer
        # provide the legacy .next alias (AttributeError).
        inputs, labels = next(iter(dataloader))
        # Evaluation only -- don't build an autograd graph for this forward pass.
        with torch.no_grad():
            pred = net(inputs).max(-1)[1]
        return torch.mean((pred == labels).float()).item()

In [None]:
# Train the model (first positional argument 1 -- presumably the number of
# epochs) and stream monitoring data to the TensorBoard writer:
# the loss, the layer-'0' weights rendered in grayscale, and
# Metrics.test_acc evaluated on one test batch. monitor=100 presumably sets
# how many training batches pass between monitoring steps.
weight_display = {'layer_id': '0', 'cmap': 'gray'}
acc_kwargs = {'dataloader': testloader}
loss_history = model.train_model(
    1, trainloader, loss_fn, optim,
    monitor=100,
    tensorboard=writer,
    show_weights=weight_display,
    metrics=Metrics(),
    test_acc=acc_kwargs,
)

In [None]:
# Close the SummaryWriter created for this run now that training is finished.
writer.close()