In [None]:
# NOTE: `imp` is deprecated and was removed in Python 3.12; prefer importlib.
import imp
import importlib

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from context import rf_pool

In [None]:
from rf_pool import modules, models
from rf_pool.utils import functions, datasets

In [None]:
# Reload the project modules so that edits made on disk are picked up without
# restarting the kernel. importlib.reload is the supported replacement for
# imp.reload (the imp module is deprecated and was removed in Python 3.12).
importlib.reload(datasets)
importlib.reload(functions)
importlib.reload(modules)
importlib.reload(models)

In [None]:
# Load the MNIST training split (download=True fetches it into the data
# directory when it is not already there). The only preprocessing step is
# ToTensor, which converts each image to a tensor.
transform = transforms.Compose([
    transforms.ToTensor(),
])
trainset = torchvision.datasets.MNIST(
    root='../rf_pool/data',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# Wrap the training set in a DataLoader: one image per batch, shuffled,
# loaded by two worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=1,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Instantiate the main network container; its layers are added further down
# in this notebook via model.append(...).
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# Receptive-field (RF) pooling layer.
# The lattice is initialised around img_shape // 2 (the image centre); the
# positional arguments 4 and 7 are passed straight to init_uniform_lattice
# (presumably spacing and number of kernels per side -- TODO confirm against
# rf_pool.utils.lattice).
img_shape = torch.as_tensor((24, 24))
mu, sigma = rf_pool.utils.lattice.init_uniform_lattice(
    img_shape // 2,
    4,
    7,
    sigma_init=2.,
)
rf_layer = rf_pool.layers.RF_Pool(
    mu=mu,
    sigma=sigma,
    img_shape=img_shape,
    lattice_fn=rf_pool.utils.lattice.mask_kernel_lattice,
    pool_type='max',
    kernel_size=2,
)

# Visualise the initial lattice and report the shape of the centres.
rf_layer.show_lattice()
print(mu.shape)

# The first dimension of mu is used as the kernel count that sizes the
# control-network branches below.
n_kernels = mu.shape[0]

In [None]:
# Control network used for updating the RF lattice parameters (mu/sigma).
control_net = rf_pool.models.FeedForwardNetwork()

# Layer '0': 10x10 convolution over a 16-channel feature map. For the
# (1,16,24,24) input used elsewhere in this notebook, 24 - 10 + 1 = 15,
# which yields the 32*15*15 flattened size consumed by layer '1'.
control_net.append(
    '0',
    rf_pool.modules.FeedForward(
        hidden=torch.nn.Conv2d(16, 32, 10),
        activation=torch.nn.ReLU(),
    ),
)

# Layer '1': reshape to (-1, 32*15*15) and project to a 128-unit hidden layer.
control_net.append(
    '1',
    rf_pool.modules.FeedForward(
        input_shape=(-1, 32*15*15),
        hidden=torch.nn.Linear(32*15*15, 128),
        activation=torch.nn.ReLU(),
    ),
)

# Layer '2': two Tanh-bounded output heads on top of the 128-unit layer.
# branch0 emits two values per kernel and branch1 one value per kernel
# (presumably mu and sigma updates respectively -- TODO confirm).
branch0 = rf_pool.modules.FeedForward(
    hidden=torch.nn.Linear(128, n_kernels*2),
    activation=torch.nn.Tanh(),
)
branch1 = rf_pool.modules.FeedForward(
    hidden=torch.nn.Linear(128, n_kernels),
    activation=torch.nn.Tanh(),
)
control_net.append(
    '2',
    rf_pool.modules.Branch(
        [branch0, branch1],
        branch_shapes=[(-1, 1, n_kernels, 2), (-1, 1, n_kernels, 1)],
    ),
)

In [None]:
# Assemble the main model, layer by layer.

# Layer '0': 5x5 convolution (1 -> 16 channels) + ReLU, followed by the RF
# pooling layer whose lattice is driven by the control network.
model.append(
    '0',
    rf_pool.modules.Control(
        hidden=torch.nn.Conv2d(1, 16, 5),
        activation=torch.nn.ReLU(),
        control=control_net,
        pool=rf_layer,
    ),
)

# Layer '1': 5x5 convolution (16 -> 32 channels) + ReLU + 2x2 max pooling.
model.append(
    '1',
    rf_pool.modules.FeedForward(
        hidden=torch.nn.Conv2d(16, 32, 5),
        activation=torch.nn.ReLU(),
        pool=torch.nn.MaxPool2d(2),
    ),
)

# Layer '2': 4x4 convolution (32 -> 64 channels), no activation or pooling.
model.append(
    '2',
    rf_pool.modules.FeedForward(
        hidden=torch.nn.Conv2d(32, 64, 4),
    ),
)

# Layer '3': reshape to (-1, 64) and map to the 10 output classes.
# NOTE(review): this assumes layer '2' produces a 1x1 spatial map, which in
# turn depends on the RF pool's output size -- confirm.
model.append(
    '3',
    rf_pool.modules.FeedForward(
        input_shape=(-1, 64),
        hidden=torch.nn.Linear(64, 10),
    ),
)

In [None]:
# Classification loss and optimizer: cross-entropy, trained with SGD plus
# momentum over the parameters exposed by model.parameters().
loss_fn = torch.nn.CrossEntropyLoss()
optim = torch.optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
# Extra L1 penalty on the control network's outputs, with an all-zero target.
# control_net ends in two branches, so the target length is the total number
# of elements across both branch output shapes (the note in the original cell
# states the two outputs are flattened and concatenated).
# The shapes are computed for a single-sample input of shape (1,16,24,24).
control_out_shapes = control_net.output_shapes((1, 16, 24, 24))[-1]
n_branch_out = np.sum([np.prod(shp) for shp in control_out_shapes])
add_loss = {
    'loss_fn': torch.nn.L1Loss(reduction='sum'),
    'target': torch.zeros(n_branch_out),
    'layer_ids': ['0'],
    'module_name': 'control',
}

In [None]:
# Train the model while monitoring the layer-'0' weights and the RF lattice,
# with the extra control-output loss added to the objective.
# The leading 1 and monitor=100 are passed as-is (presumably the number of
# epochs and the monitoring interval in iterations -- TODO confirm against
# FeedForwardNetwork.train).
loss_history = model.train(
    1,
    trainloader,
    loss_fn,
    optim,
    monitor=100,
    show_weights={'layer_id': '0', 'cmap': 'gray'},
    show_lattice={},
    add_loss=add_loss,
)