In [None]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from context import rf_pool

**Load MNIST Data**

In [None]:
# Load the MNIST training split from ../data, downloading it on first use.
# ToTensor turns each PIL image into a float tensor (values scaled to [0, 1]).
transform = transforms.Compose([
    transforms.ToTensor(),
])
trainset = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# Seed every RNG before anything stochastic happens, so Restart & Run All is reproducible.
# This covers the shuffled DataLoader below, whose worker seeds derive from torch's
# global RNG. It also covers the random weight initialization of all layers built in
# later cells.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# create trainloader: one image per batch, reshuffled every epoch, loaded by 2 worker processes
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                          shuffle=True, num_workers=2)

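**Initialize model with RF layer**

The first 5x5 convolution maps each 28x28 MNIST image to a 24x24 feature map, so the receptive-field lattice below is laid out over a 24x24 image shape.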

In [None]:
# Top-level classifier container, constructed with no arguments.
# Its layers '0'-'3' are appended in a later cell, once the RF layer and the control
# network exist.
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# RF pooling layer: a uniform 3x3 lattice of receptive fields laid out over a 24x24
# image shape. That is the feature-map size a 5x5 conv produces from a 28x28 MNIST image.
# spacing/sigma_init set the initial lattice geometry; pool_fn/kernel_size presumably
# select 2x2 max pooling within each RF (confirm against rf_pool.pool.RF_Uniform).
img_shape = torch.tensor([24, 24])
rf_layer = rf_pool.pool.RF_Uniform(
    img_shape, (3, 3),
    spacing=6,
    sigma_init=3,
    pool_fn='max_pool',
    kernel_size=(2, 2),
)
rf_layer.show_lattice()

# n_kernels is the leading dimension of mu (presumably one row per receptive field).
# It sizes the control network's output branches.
n_kernels = rf_layer.mu.shape[0]
print(rf_layer.mu.shape)

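**Create control network to update RF locations and sizes**

The control network reads the 16-channel 24x24 feature map from the first convolution and outputs two named tensors per receptive field: `delta_mu` (location update, bounded by a Tanh) and `delta_sigma` (size update, an unbounded linear output).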

In [None]:
# Control network for the RF layer. It reads the 16-channel 24x24 feature map produced
# by the first model layer. It then emits two named per-RF outputs:
#   'delta_mu'    (branch shape (n_kernels, 2))
#   'delta_sigma' (branch shape (n_kernels, 1))
ctrl_channels = 32                             # output channels of the control conv
ctrl_hw = 24 - 10 + 1                          # 10x10 conv, stride 1, no padding: 24x24 -> 15x15
ctrl_flat = ctrl_channels * ctrl_hw * ctrl_hw  # flattened size fed to the linear layer (32*15*15)

control_net = rf_pool.models.FeedForwardNetwork()
control_net.append('0', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(16, ctrl_channels, 10),
                                                    activation=torch.nn.ReLU()))
control_net.append('1', rf_pool.modules.FeedForward(input_shape=(-1, ctrl_flat),
                                                    hidden=torch.nn.Linear(ctrl_flat, 128),
                                                    activation=torch.nn.ReLU()))
# mu_branch ends in Tanh, so each location update is bounded to (-1, 1).
# sigma_branch deliberately has no output activation: size updates are an unbounded
# linear output.
mu_branch = rf_pool.modules.FeedForward(hidden=torch.nn.Linear(128, n_kernels * 2),
                                        activation=torch.nn.Tanh())
sigma_branch = rf_pool.modules.FeedForward(hidden=torch.nn.Linear(128, n_kernels))
control_net.append('2', rf_pool.modules.Branch([mu_branch, sigma_branch],
                                               branch_shapes=[(n_kernels, 2), (n_kernels, 1)],
                                               output_names=['delta_mu', 'delta_sigma']))

In [None]:
# Assemble the classifier. Spatial sizes for a 1x28x28 input (all convs stride 1, no padding):
#   '0': 5x5 conv 1->16 + ReLU -> 16x24x24, then RF pooling through `rf_layer`, whose
#        lattice is driven by `control_net`. The remaining layers only reduce to a
#        64-vector if this pooled map is 16x12x12, consistent with the 2x2 pooling
#        kernel; verify with output_shapes.
#   '1': 5x5 conv 16->32 + ReLU -> 32x8x8, then 2x2 max-pool -> 32x4x4
#   '2': 4x4 conv 32->64 -> 64x1x1
#   '3': flatten to 64 -> linear to 10 class logits
# NOTE(review): layer '2' has no activation, so it composes linearly with layer '3'.
# Confirm this is intended.
model.append('0', rf_pool.modules.Control(
    hidden=torch.nn.Conv2d(1, 16, 5),
    activation=torch.nn.ReLU(),
    control=control_net,
    pool=rf_layer,
))
model.append('1', rf_pool.modules.FeedForward(
    hidden=torch.nn.Conv2d(16, 32, 5),
    activation=torch.nn.ReLU(),
    pool=torch.nn.MaxPool2d(2),
))
model.append('2', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(32, 64, 4)))
model.append('3', rf_pool.modules.FeedForward(
    input_shape=(-1, 64),
    hidden=torch.nn.Linear(64, 10),
))

In [None]:
# Sanity-check per-layer output shapes for a single-sample input to each network.
# The control network is fed the 16x24x24 map that model layer '0' produces.
# The `output` dict presumably asks for both sub-branches of its final Branch layer
# to be reported (TODO confirm the 'branch_N' naming in rf_pool).
print('Control network:')
control_shapes = control_net.output_shapes(
    (1, 16, 24, 24),
    output={'2': {'branch_0': [], 'branch_1': []}},
)
print(control_shapes)

print('Full model:')
model_shapes = model.output_shapes((1, 1, 28, 28))
print(model_shapes)

**Train model and control network to classify digits**

In [None]:
# Objective and optimizer: cross-entropy on the classifier's 10 logits, plus SGD with momentum.
# NOTE(review): the control network is only trained if rf_pool's Control module exposes
# `control_net` through `model.parameters()`. Confirm.
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.SGD(
    model.parameters(),
    lr=1e-3,
    momentum=0.9,
)

In [None]:
# Train on `trainloader` with `loss_fn`/`optim`, visualizing layer-'0' weights and the RF
# lattice in grayscale.
# The leading 1 is presumably the number of epochs, and monitor=100 presumably the
# reporting interval in batches (both TODO confirm in rf_pool).
loss_history = model.train(
    1,
    trainloader,
    loss_fn,
    optim,
    monitor=100,
    show_weights={'layer_id': '0', 'cmap': 'gray'},
    show_lattice={'cmap': 'gray'},
)