In [1]:
import imp
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import os

from context import rf_pool

In [2]:
from rf_pool import modules, models, pool, ops
from rf_pool.utils import functions, datasets, stimuli

In [3]:
# experiment directories: create ./models for saved checkpoints.
# exist_ok=True makes the cell idempotent on re-run and avoids the
# check-then-create race of os.path.exists + os.mkdir.
os.makedirs('models', exist_ok=True)

**Load MNIST data**

In [4]:
# Get the MNIST training and test splits as tensors (ToTensor only, no normalisation).
# download=True fetches the files into data_root if they are not already there.
data_root = '../../data'
transform = transforms.Compose([transforms.ToTensor()])
trainset, testset = (
    torchvision.datasets.MNIST(root=data_root, train=is_train, download=True,
                               transform=transform)
    for is_train in (True, False)
)

In [5]:
# Seed torch's global RNG so that training-set shuffling (and the weight
# initialisation performed in the model cells below) is reproducible under
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Create train/test loaders. Only the training set needs shuffling. Shuffling
# the test set adds nondeterminism without any benefit for evaluation.
BATCH_SIZE = 10
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

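**Initialize the model**: two Conv→ReLU→MaxPool stages, followed by a 4×4 convolution that outputs 10 class scores reshaped to `(-1, 10)`.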

In [6]:
# Create the network container; its layers are added with model.append in the next cell.
FeedForwardNetwork = rf_pool.models.FeedForwardNetwork
model = FeedForwardNetwork()

In [None]:
# Append layers of the model. For 28x28 MNIST inputs the spatial size goes
# 28 -> 24 -> 12 (layer '0') -> 8 -> 4 (layer '1') -> 1 (layer '2'), so layer '2'
# emits 10 channels of 1x1. Layer '3' presumably reshapes that to (-1, 10);
# confirm against rf_pool.modules.FeedForward.
for layer_id, (in_channels, out_channels) in [('0', (1, 32)), ('1', (32, 64))]:
    # 5x5 conv (no padding) -> ReLU -> 2x2 max-pool, appended immediately after construction
    conv_block = rf_pool.modules.FeedForward(
        hidden=torch.nn.Conv2d(in_channels, out_channels, 5),
        activation=torch.nn.ReLU(),
        pool=torch.nn.MaxPool2d(2),
    )
    model.append(layer_id, conv_block)

# Final 4x4 conv maps the 64x4x4 feature map to 10 class scores. There is no
# activation, so the outputs are raw logits.
model.append('2', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(64, 10, 4)))
model.append('3', rf_pool.modules.FeedForward(input_shape=(-1, 10)))

In [None]:
# Loss: cross-entropy on the model's 10-way logits.
# Optimizer: SGD with momentum over all model parameters.
learning_rate, momentum = 0.001, 0.9
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

**Train the model**, then evaluate its accuracy on the test set and save it to `models/`.

In [None]:
# Train the model and keep the returned loss history. Layer '0' weights are
# passed via show_weights with a gray colormap; no lattice is monitored here.
# NOTE(review): the leading 10 is presumably the number of epochs and
# monitor=100 the logging/plotting interval in batches. Confirm against
# rf_pool.models.FeedForwardNetwork.train.
loss_history = model.train(10, trainloader, loss_fn, optim, monitor=100,
                           show_weights={'layer_id': '0', 'cmap': 'gray'})

In [None]:
# Compute test-set accuracy and end the cell with it so the value is displayed.
acc = model.get_accuracy(testloader)
acc

In [None]:
# Persist the trained model as a pickle under ./models.
# NOTE(review): the filename tags (rate_0.2, 10k, 3deg) do not correspond to
# anything configured in this notebook (lr=0.001, plain CNN). It looks copied
# from another experiment. Consider renaming it, but check for downstream
# loaders first.
model_path = 'models/MNIST_rate_0.2_10k_3deg.pkl'
model.save_model(model_path)