In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

from models import FeedForwardModel
from layers import RF_Pool

In [2]:
# Seed torch before the first random draw so Restart & Run All reproduces the
# RF parameters, the network weights and the random test input below.
SEED = 0
torch.manual_seed(SEED)

# Receptive-field pooling layer with randomly initialised mu (5, 2) and
# sigma (5, 1); presumably 5 RFs with 2-D centres and one width each — TODO confirm.
# img_shape must match the feature map it pools: the first conv below maps the
# 28x28 input with a 9x9 kernel to 28 - 9 + 1 = 20 (assuming stride 1, no padding).
rf_layer = RF_Pool(mu=torch.rand(5,2), sigma=torch.rand(5,1), img_shape=(20,20))
rf_layer.init_rfs()

In [3]:
# Build the model wrapper from its loss and optimizer names.
loss_type, optimizer_type = 'cross_entropy', 'adam'
model = FeedForwardModel(loss_type, optimizer_type)

In [4]:
# Feed-forward network spec: each list holds one entry per layer (2 layers).
# Layer 0: conv (k=9, 10 channels), ReLU, pooled by rf_layer (pool ksize 2).
# Layer 1: conv (k=5, 10 channels), no activation / pooling / dropout.
data_shape = (10, 3, 28, 28)  # presumably (batch, channels, H, W) — TODO confirm
layer_types = ['conv', 'conv']
output_channels = [10, 10]
kernel_sizes = [9, 5]
act_types = [torch.nn.ReLU(), None]
pool_types = [rf_layer, None]
pool_ksizes = [2, None]
dropout_types = [None, None]

model.ff_network(
    data_shape=data_shape,
    layer_types=layer_types,
    output_channels=output_channels,
    kernel_sizes=kernel_sizes,
    act_types=act_types,
    pool_types=pool_types,
    pool_ksizes=pool_ksizes,
    dropout_types=dropout_types,
)

n_params = len(model.get_trainable_params())
print("Number of 'on' params from ff_net: ", n_params)

In [8]:
# Attach a control network to the model. The two target shapes end in (5, 2)
# and (5, 1), mirroring rf_layer's mu and sigma; presumably the control net
# predicts those RF parameters for layer 0 — TODO confirm in models.py.
model.control_network(
    [0],
    ['conv', 'fc'],
    [(-1, 1, 5, 2), (-1, 1, 5, 1)],
    output_channels=[10, 128],
    kernel_sizes=[5, None],
)

n_params = len(model.get_trainable_params())
print("Number of 'on' params from ff_net + control_net: ", n_params)

In [11]:
# Freeze the ff_net hidden-layer parameters (requires_grad=False), then count
# which parameters are still trainable.
model.set_requires_grad("hidden", False)
trainable_params = model.get_trainable_params()
n_params = len(trainable_params)
print("Number of 'on' params after removing ff_net: ", n_params)

Number of 'on' params after removing ff_net:  8


In [12]:
# Smoke-test a forward pass on uniform random data shaped like data_shape
# and display the resulting output shape.
random_input = torch.rand(data_shape)
model_output = model(random_input)
model_output.shape

torch.Size([10, 10, 6, 6])

In [None]:
# Train the model once a real torch.utils.data.DataLoader is supplied.
# While train_loader is None the cell is a no-op, so Restart & Run All
# still completes without error.
epochs = 2
train_loader = None  # replace with a torch.utils.data.DataLoader to train
if train_loader is not None:
    model.train_model(epochs, train_loader)