# Notebook setup

In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
import torch.nn.utils.prune as prune

In [2]:
# make sure that you have a GPU: this cell should print 'cuda'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


# Building sparse recurrent neural networks (RNNs)

In [3]:
# to ensure reproducibility, we should set the torch seed
torch.manual_seed(42)

<torch._C.Generator at 0x2ba73f26d350>

In [4]:
## make the model - a recurrent network with one hidden layer, and a fully connected output layer
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__() # initialise the nn.Module
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity = 'relu', batch_first = True)
        self.fc = nn.Linear(hidden_size, num_classes)


    def forward(self, x):
        # set initial hidden state
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # pass input throught the recurrent layer
        output, hidden = self.rnn(x, h0) # output shape (batch_size, seq_length, hidden_size)
        # reshape the output so it fits into the fully connected layer (get the last output from the RNN)
        output = output[:, -1, :]
        # pass it to the linear layer to get the classification
        output = self.fc(output)

        return output

In [5]:
input_size = 28 # the size of the input at each timestep
hidden_size = 10000 # how many nodes do we want in the hidden layer?
num_layers = 1 # how many hidden recurrent layers 
num_classes = 10 # this is defined by the dataset, which has 10 classes (digits 0-9)

In [6]:
model = RNN(input_size, hidden_size, num_layers, num_classes)

In [7]:
print(model)

RNN(
  (rnn): RNN(28, 10000, batch_first=True)
  (fc): Linear(in_features=10000, out_features=10, bias=True)
)


In [8]:
# let's look at what makes up the recurrent layer
print(list(model.rnn.named_parameters()))

[('weight_ih_l0', Parameter containing:
tensor([[ 7.6454e-03,  8.3001e-03, -2.3427e-03,  ..., -4.6101e-03,
         -2.8237e-03, -6.0127e-03],
        [ 9.4383e-04, -9.8768e-03,  9.0311e-03,  ...,  8.2054e-03,
          2.8803e-03,  4.1421e-03],
        [ 3.1626e-03, -1.7396e-04,  7.8261e-03,  ..., -6.8171e-03,
          5.3058e-03, -4.0420e-03],
        ...,
        [ 6.5973e-03, -2.2318e-03,  2.7703e-03,  ..., -4.4835e-03,
          6.6070e-03, -8.1006e-03],
        [ 2.9607e-03,  6.2242e-03,  4.3185e-03,  ..., -4.0194e-04,
          2.5944e-05,  9.5441e-03],
        [ 1.1639e-03,  6.6203e-03,  4.5857e-03,  ..., -7.5040e-03,
          6.9919e-03, -4.0599e-03]], requires_grad=True)), ('weight_hh_l0', Parameter containing:
tensor([[ 0.0070,  0.0043, -0.0034,  ...,  0.0003, -0.0029,  0.0076],
        [ 0.0007, -0.0073, -0.0014,  ...,  0.0099, -0.0019, -0.0050],
        [ 0.0044,  0.0034,  0.0088,  ..., -0.0020,  0.0070, -0.0095],
        ...,
        [ 0.0080, -0.0074, -0.0059,  ...,  0

**weight_ih_l0** - weights from the input layer to the recurrent layer

**weight_hh_l0** - weights between nodes in the recurrent layer, by default each node is connected to all other nodes!

**bias_ih_l0, bias_hh_l0** - the bias modifies the activation of each node

In [9]:
# here, the strength of the connection between each pair of nodes is specified
model.rnn.weight_hh_l0.shape

torch.Size([10000, 10000])

In [10]:
# these weights are randomly initialised by default - this is what we want to change to add our connectivity constraints
model.rnn.weight_hh_l0

Parameter containing:
tensor([[ 0.0070,  0.0043, -0.0034,  ...,  0.0003, -0.0029,  0.0076],
        [ 0.0007, -0.0073, -0.0014,  ...,  0.0099, -0.0019, -0.0050],
        [ 0.0044,  0.0034,  0.0088,  ..., -0.0020,  0.0070, -0.0095],
        ...,
        [ 0.0080, -0.0074, -0.0059,  ...,  0.0088, -0.0020,  0.0023],
        [-0.0012, -0.0076, -0.0048,  ...,  0.0091, -0.0004,  0.0088],
        [-0.0040,  0.0074,  0.0059,  ...,  0.0025,  0.0009, -0.0042]],
       requires_grad=True)

## Generate sparse connection matrix

Let's make the connections in the recurrent layer randomly sparse. First, we need to decide which connections should remain. To do this, we generate a connectivity matrix, which has the shape (n_nodes x n_nodes) and contains 1s and 0s (1 = the nodes are connected by a trainable weight, 0 = they are not connected)

In [11]:
def generate_con_matrix(n, p, seed = 42):
    np.random.seed(seed) # make things reproducible
    con = np.random.rand(n,n) # the connectivity matrix has shape n_nodes x n_nodes
    con = np.array(con < p).astype(int) # threshold according to the connection probability
    
    return torch.tensor(con)

In [12]:
# p is the connection probability (how likely is it that any two nodes are connected?) - here it is 10%
a = generate_con_matrix(10, 0.1) 
a

tensor([[0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])

In [13]:
a.sum() # we would expect about 10% of the weights to be nonzero (10% of 100 = 10)

tensor(13)

In [14]:
# we can also generate a fully connected connectivity matrix using this method - that's good for controls
b = generate_con_matrix(10, 1.0) 
b

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

But how do we apply this to our RNN? We just saw that by default, the recurrent layer is fully connected. Enter: torch prune. This lets us remove specified weights:

In [15]:
# randomly initialise weights from a uniform distribution
model.rnn.weight_hh_l0 = torch.nn.init.uniform_(model.rnn.weight_hh_l0, a=-0.001, b=0.001)

# use the torch prune method to remove the weights which have a zero in our connectivity matrix - they won't be trainable
# you only need to do this once, when initialising the network. 
con = generate_con_matrix(10000, 0.1)
prune.custom_from_mask(model.rnn, 'weight_hh_l0', con)

RNN(28, 10000, batch_first=True)

Under the hood, torch prune inserts what's called a forward_hook. This is a function that is run every time before the model performs a forward pass. This means that all weights will actually get changed when weights are updated by the optimiser, but then prune re-zeroes our sparse weights before the network does anything else.

In [16]:
model.rnn._forward_pre_hooks

OrderedDict([(0, <torch.nn.utils.prune.CustomFromMask at 0x2ba740f72af0>)])

In [17]:
# our connectivity matrix is stored here by the network, to be applied at every forward pass
print(list(model.rnn.named_buffers())) 

[('weight_hh_l0_mask', tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        ...,
        [1., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.]]))]


In [18]:
con # should be the same

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 1, 0, 0],
        ...,
        [1, 1, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 1, 0,  ..., 0, 0, 0]])

## Load dataset

In [19]:
# basic datasets like MNIST are built into pytorch, so we can just import them:
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

In [20]:
data_dir = '/gpfs/soma_fs/scratch/rfruengel/data/MNIST' # set this to wherever you want to keep the dataset

RNNs accept input as a timeseries - in order to do image recognition with them, we have to split the images up. We present one row of the image at each timestep, so for MNIST, that's 28 timesteps, each with 28 input pixels. For MNIST, which only has one colour channel, this works without further adjustment. The images are 28x28: the RNN will interpret this as 28 timesteps, each with 28 inputs. 

In [21]:
# set up torch dataloaders to handle batching the data for training and testing
train_data = datasets.MNIST(root = data_dir, train = True, transform = ToTensor())
test_data = datasets.MNIST(root = data_dir, train = False, transform = ToTensor())

loaders = {
'train': DataLoader(train_data, batch_size = 100, shuffle = True, num_workers = 1),
'test': DataLoader(test_data, batch_size = 100, shuffle = False, num_workers = 1)}

## Train it

In [22]:
# here, we use cross entropy loss, because our outputs are categorical (classes)
loss_func = nn.CrossEntropyLoss()

In [23]:
# we use the Adam optimiser
optimiser = optim.Adam(model.parameters(), lr = 0.001)

In [24]:
def train(num_epochs, model, loaders):
    print('------------------------------')
    print('num_epochs: {}'.format(num_epochs))
    print('model: {}'.format(model))
    print('loaders: {}'.format(loaders))
    print('training on {}'.format(device))
    
    total_step = len(loaders['train'])
    
    print('------------------------------')
    for epoch in range(num_epochs):
        print('Started epoch: {}'.format(epoch+1))
        for i, (images, labels) in enumerate(loaders['train']):
            images = images.reshape(-1, 28, 28).to(device)
            labels = labels.to(device)

            # forward pass
            outputs = model(images)
            
            # pass output to loss function
            loss = loss_func(outputs, labels)
            
            # clear gradients from previous epoch
            optimiser.zero_grad()
            
            # backpropagation
            loss.backward()
            
            # update weights
            optimiser.step()
            
            # print progress
            if (i+1)%100 == 0:
                print('Step [{} / {}], Loss: {}'.format(i+1, total_step, loss.item()))
        
        # at the end of each epoch:
        model.eval() # put the model in evaluation mode
        running_test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad(): # get test loss/accuracy over whole test dataset
            for t, (images, labels) in enumerate(loaders['test']):
                images = images.reshape(-1, 28, 28).to(device)
                labels = labels.to(device)
                # get model predictions
                outputs = model(images)
                # get the label with the strongest activation, this is the model's class prediction
                _, predicted = torch.max(outputs.data, 1)
                # compute test loss
                loss = loss_func(outputs, labels)
                running_test_loss += loss.item()
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_loss = running_test_loss / t+1
        accuracy = 100 * correct/float(total)
        model.train() # put it back in training mode
        print('TESTING...')
        print('Loss: {}, accuracy: {}%'.format(test_loss, accuracy))

In [52]:
# everything needed for training must be sent to the GPU
model.to(device)

RNN(
  (rnn): RNN(28, 10000, batch_first=True)
  (fc): Linear(in_features=10000, out_features=10, bias=True)
)

In [53]:
num_epochs = 10
train(num_epochs, model, loaders)

------------------------------
num_epochs: 10
model: RNN(
  (rnn): RNN(28, 10000, batch_first=True)
  (fc): Linear(in_features=10000, out_features=10, bias=True)
)
loaders: {'train': <torch.utils.data.dataloader.DataLoader object at 0x2b148fd0cd90>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x2b14d6e0d730>}
training on cuda
------------------------------
Started epoch: 1
Step [100 / 600], Loss: 1.2321842908859253
Step [200 / 600], Loss: 0.800067663192749
Step [300 / 600], Loss: 0.3912740647792816
Step [400 / 600], Loss: 0.5160369873046875
Step [500 / 600], Loss: 0.24377603828907013
Step [600 / 600], Loss: 0.2154218703508377
TESTING...
Loss: 1.2641936365627882, accuracy: 92.28%
Started epoch: 2
Step [100 / 600], Loss: 0.11632823944091797
Step [200 / 600], Loss: 0.31488898396492004
Step [300 / 600], Loss: 0.39791491627693176
Step [400 / 600], Loss: 0.3656487762928009
Step [500 / 600], Loss: 0.11766573041677475
Step [600 / 600], Loss: 0.10287356376647949
TESTING...
Loss: 1

## Parallelised training on MNIST, CIFAR10 or Sleep-EDF

Code for Fig. 2

In [None]:
# make a folder so we can record stats about our different networks while they train
outdir = '/path/to/output/folder'

Below are delayed dask functions for training sparse RNNs on MNIST, CIFAR10 or Sleep-EDF. They also allow you to record the weight matrix at each epoch, and to record the activations from the hidden layer to the output layer on the test set (warning: this takes up a lot of disk space!)

In [None]:
import dask # we used dask for parallel computing

In [None]:
@dask.delayed
def make_model_and_train(n_cells, con_p, num_epochs, trackfile_format = '{}_{}', outdir = None, 
                         torch_seed = 0, early_stop = True, patience = 20,
                         weight_track = None, activ_track = None):
    '''
    Makes a (sparse) RNN, trains it on MNIST and records loss and accuracy and, if set, hidden layer weights and activations.
    
    n_cells: int, number of nodes in recurrent hidden layer.
    con_p: float between 0 and 1, connection probability between any two nodes in the hidden layer, used for generating connectivity matrix.
    num_epochs: int, number of epochs to train the model for
    trackfile_format: str, base for filename. Will be formatted with .format(n_cells, con_p)
    outdir: str, folder path
    torch_seed: int, seeds torch and numpy for reproducibility
    early_stop: bool, enables early stopping if loss stops decreasing.
    patience: int, number of epochs without loss improvement before training is stopped. Only relevant if early_stop = True.
    weight_track: None or str, folder path. If None, weights will not be saved.
    activ_track: None or str, folder path. If None, activations will not be saved.
    
    '''
    
    def record_test_stats(running_train_loss, i, epoch):
        model.eval()
        train_loss = running_train_loss / i+1
        model.eval()
        running_test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad(): # get test loss over whole test dataset
            for t, (images, labels) in enumerate(loaders['test']):
                images = images.reshape(-1, 28, 28).to(device)
                labels = labels.to(device)
                # get model predictions
                outputs = model(images, i=t, epoch=epoch)
                # get the label with the strongest activation
                _, predicted = torch.max(outputs.data, 1)
                # compute test loss
                loss = loss_func(outputs, labels)
                running_test_loss += loss.item()
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_loss = running_test_loss / t+1
        # dump it somewhere
        with open(trackfile, 'a') as f:
            f.write('{}_{}_{}_{}\n'.format(epoch, train_loss, test_loss, 100 * correct/float(total)))
        model.train()

        return test_loss


    device = torch.device('cuda')
    torch.manual_seed(torch_seed)
    train_data = datasets.MNIST(root = data_dir, train = True, transform = ToTensor(), download = False)
    test_data = datasets.MNIST(root = data_dir, train = False, transform = ToTensor())

    loaders = {
    'train': DataLoader(train_data, batch_size = 100, shuffle = True, num_workers = 0),
    'test': DataLoader(test_data, batch_size = 100, shuffle = True, num_workers = 0)}

    ## make the model
    class RNN(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, num_classes):
            super(RNN, self).__init__() # initialise the nn.Module
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity = 'relu', batch_first = True)
            self.fc = nn.Linear(hidden_size, num_classes)
            self.con_p = con_p # just for writing the track file
            self.activ_track = activ_track


        def forward(self, x, i=None, epoch=None):
            # set initial hidden state
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
            # forward run model
            output, hidden = self.rnn(x, h0) # output shape (batch_size, seq_length, hidden_size)
            # reshape the output so it fits into the fully connected layer (get the last output from the RNN)
            output = output[:, -1, :]
            # pass it to the linear layer
            output = self.fc(output)
            # if tracking activations:
            if self.activ_track is not None and i is not None: 
                with open(self.activ_track.join(trackfile_format.format(self.hidden_size, self.con_p) + '_activations_epoch_{}_{}'.format(epoch, i)), 'wb') as f:
                    np.savez(f, recurrent_layer = output.cpu().detach().numpy(), fc_layer = fc_output.cpu().detach().numpy())

            return output

    model = RNN(28, n_cells, 1, 10)

    con = generate_con_matrix(n_cells, con_p, seed = torch_seed)
    rnn_layer = model.rnn
    rnn_layer.weight_hh_l0 = torch.nn.init.uniform_(rnn_layer.weight_hh_l0, a=-0.001, b=0.001) # randomly initialise weights
    prune.custom_from_mask(rnn_layer, 'weight_hh_l0', con) # remove connections not present in the anatomical model

    model.to(device)
    loss_func = nn.CrossEntropyLoss()
    optimiser = optim.Adam(model.parameters(), lr = 0.001)

    trackfile = outdir.join(trackfile_format.format(n_cells, con_p))

    # record test loss before starting
    test_loss = record_test_stats(np.nan, np.nan, 'before') 

    if early_stop:
        test_loss_history = test_loss
        early_stop_counter = 0
    early_stop_flag = False

    for epoch in range(num_epochs):
        print('Started epoch: {}'.format(epoch))
        running_train_loss = 0.0

        for i, (images, labels) in enumerate(loaders['train']):

            images = images.reshape(-1, 28, 28).to(device)
            labels = labels.to(device)

            # forward pass
            outputs = model(images)

            # pass output to loss function
            loss = loss_func(outputs, labels)
            # backpropagation
            loss.backward()
            running_train_loss += loss.item()

            # update weights
            optimiser.step()
            # clear gradients from previous epoch
            optimiser.zero_grad(set_to_none = True)

            # print progress

            if i+1 == len(loaders['train']): # at the end of each epoch record stats
                if weight_track is not None: # save weights to file
                    weights = model.rnn.weight_hh_l0.cpu().detach().numpy()
                    weights_fc = model.fc.weight.cpu().detach().numpy()
                    assert weights is not None
                    with open(weight_track.join(trackfile_format.format(n_cells, con_p) + '_weights_epoch_{}'.format(epoch)), 'wb') as f:
                        np.savez(f, recurrent_layer = weights, fc_layer = weights_fc) 
                test_loss = record_test_stats(running_train_loss, i, epoch)
                if early_stop:
                    if test_loss < test_loss_history: # if improved
                        test_loss_history = test_loss
                        early_stop_counter = 0
                    else:
                        early_stop_counter += 1
                        if early_stop_counter >= patience:
                            early_stop_flag = True
        if early_stop_flag:
            break

    with open(trackfile, 'a') as f:
        f.write('training completed successfully!\n')

In [None]:
@dask.delayed
def make_model_and_train(n_cells, con_p, num_epochs, trackfile_format = '{}_{}', outdir = None, 
                         torch_seed = 0, early_stop = True, patience = 20,
                         weight_track = None, activ_track = None):
    '''
    Makes a (sparse) RNN, trains it on CIFAR10 and records loss and accuracy and, if set, hidden layer weights and activations.
    
    n_cells: int, number of nodes in recurrent hidden layer.
    con_p: float between 0 and 1, connection probability between any two nodes in the hidden layer, used for generating connectivity matrix.
    num_epochs: int, number of epochs to train the model for
    trackfile_format: str, base for filename. Will be formatted with .format(n_cells, con_p)
    outdir: str, folder path
    torch_seed: int, seeds torch and numpy for reproducibility
    early_stop: bool, enables early stopping if loss stops decreasing.
    patience: int, number of epochs without loss improvement before training is stopped. Only relevant if early_stop = True.
    weight_track: None or str, folder path. If None, weights will not be saved.
    activ_track: None or str, folder path. If None, activations will not be saved.
    
    '''
    
    def record_test_stats(running_train_loss, i, epoch):
        model.eval()
        train_loss = running_train_loss / i+1
        model.eval()
        running_test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad(): # get test loss over whole test dataset
            for t, (images, labels) in enumerate(loaders['test']):
                images = torch.tensor(np.concatenate((images[:,0], images[:,1], images[:,2]), axis = 2)).to(device)
                labels = labels.to(device)
                # get model predictions
                outputs = model(images, i=t, epoch=epoch)
                # get the label with the strongest activation
                _, predicted = torch.max(outputs.data, 1)
                # compute test loss
                loss = loss_func(outputs, labels)
                running_test_loss += loss.item()
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_loss = running_test_loss / t+1
        # dump it somewhere
        with open(trackfile, 'a') as f:
            f.write('{}_{}_{}_{}\n'.format(epoch, train_loss, test_loss, 100 * correct/float(total)))
        model.train()

        return test_loss


    device = torch.device('cuda')
    torch.manual_seed(torch_seed)
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_data = datasets.CIFAR10(root = data_dir, train = True, transform = transform, download = False)
    test_data = datasets.CIFAR10(root = data_dir, train = False, transform = transform)

    loaders = {
    'train': DataLoader(train_data, batch_size = 100, shuffle = True, num_workers = 0),
    'test': DataLoader(test_data, batch_size = 100, shuffle = True, num_workers = 0)}

    ## make the model
    class RNN(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, num_classes):
            super(RNN, self).__init__() # initialise the nn.Module
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity = 'relu', batch_first = True)
            self.fc = nn.Linear(hidden_size, num_classes)
            self.con_p = con_p # just for writing the track file
            self.activ_track = activ_track


        def forward(self, x, i=None, epoch=None):
            # set initial hidden state
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
            # forward run model
            output, hidden = self.rnn(x, h0) # output shape (batch_size, seq_length, hidden_size)
            # reshape the output so it fits into the fully connected layer (get the last output from the RNN)
            output = output[:, -1, :]
            # pass it to the linear layer
            output = self.fc(output)
            # if tracking activations:
            if self.activ_track is not None and i is not None: 
                with open(self.activ_track.join(trackfile_format.format(self.hidden_size, self.con_p) + '_activations_epoch_{}_{}'.format(epoch, i)), 'wb') as f:
                    np.savez(f, recurrent_layer = output.cpu().detach().numpy(), fc_layer = fc_output.cpu().detach().numpy())

            return output

    model = RNN(96, n_cells, 1, 10)

    con = generate_con_matrix(n_cells, con_p, seed = torch_seed)
    rnn_layer = model.rnn
    rnn_layer.weight_hh_l0 = torch.nn.init.uniform_(rnn_layer.weight_hh_l0, a=-0.001, b=0.001) # randomly initialise weights
    prune.custom_from_mask(rnn_layer, 'weight_hh_l0', con) # remove connections not present in the anatomical model

    model.to(device)
    loss_func = nn.CrossEntropyLoss()
    optimiser = optim.Adam(model.parameters(), lr = 0.001)

    trackfile = outdir.join(trackfile_format.format(n_cells, con_p))

    # record test loss before starting
    test_loss = record_test_stats(np.nan, np.nan, 'before') 

    if early_stop:
        test_loss_history = test_loss
        early_stop_counter = 0
    early_stop_flag = False

    for epoch in range(num_epochs):
        print('Started epoch: {}'.format(epoch))
        running_train_loss = 0.0

        for i, (images, labels) in enumerate(loaders['train']):

            images = torch.tensor(np.concatenate((images[:,0], images[:,1], images[:,2]), axis = 2)).to(device)
            labels = labels.to(device)

            # forward pass
            outputs = model(images)

            # pass output to loss function
            loss = loss_func(outputs, labels)
            # backpropagation
            loss.backward()
            running_train_loss += loss.item()

            # update weights
            optimiser.step()
            # clear gradients from previous epoch
            optimiser.zero_grad(set_to_none = True)

            # print progress

            if i+1 == len(loaders['train']): # at the end of each epoch record stats
                if weight_track is not None: # save weights to file
                    weights = model.rnn.weight_hh_l0.cpu().detach().numpy()
                    weights_fc = model.fc.weight.cpu().detach().numpy()
                    assert weights is not None
                    with open(weight_track.join(trackfile_format.format(n_cells, con_p) + '_weights_epoch_{}'.format(epoch)), 'wb') as f:
                        np.savez(f, recurrent_layer = weights, fc_layer = weights_fc) 
                test_loss = record_test_stats(running_train_loss, i, epoch)
                if early_stop:
                    if test_loss < test_loss_history: # if improved
                        test_loss_history = test_loss
                        early_stop_counter = 0
                    else:
                        early_stop_counter += 1
                        if early_stop_counter >= patience:
                            early_stop_flag = True
        if early_stop_flag:
            break

    with open(trackfile, 'a') as f:
        f.write('training completed successfully!\n')

In [None]:
# import torchtime which bundles the Sleep-EDF dataset in a preprocessed format
import torchtime
from torchtime.data import UEA

In [None]:
@dask.delayed
def make_model_and_train(n_cells, con_p, num_epochs, trackfile_format = '{}_{}', outdir = None, 
                         torch_seed = 0, early_stop = True, patience = 20,
                         weight_track = None, activ_track = None):
    '''
    Makes a (sparse) RNN, trains it on Sleep-EDF and records loss and accuracy and, if set, hidden layer weights and activations.
    
    n_cells: int, number of nodes in recurrent hidden layer.
    con_p: float between 0 and 1, connection probability between any two nodes in the hidden layer, used for generating connectivity matrix.
    num_epochs: int, number of epochs to train the model for
    trackfile_format: str, base for filename. Will be formatted with .format(n_cells, con_p)
    outdir: str, folder path
    torch_seed: int, seeds torch and numpy for reproducibility
    early_stop: bool, enables early stopping if loss stops decreasing.
    patience: int, number of epochs without loss improvement before training is stopped. Only relevant if early_stop = True.
    weight_track: None or str, folder path. If None, weights will not be saved.
    activ_track: None or str, folder path. If None, activations will not be saved.
    
    '''
    
    def record_test_stats(running_train_loss, i, epoch):
        model.eval()
        train_loss = running_train_loss / i+1
        model.eval()
        running_test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad(): # get test loss over whole test dataset
            for t, (images, labels) in enumerate(loaders['test']):
                images = dat['X'][:, :, 1].unsqueeze(axis = 2).to(device)
                labels = dat['y'].to(device)
                # get model predictions
                outputs = model(images, i=t, epoch=epoch)
                # get the label with the strongest activation
                _, predicted = torch.max(outputs.data, 1)
                # compute test loss
                loss = loss_func(outputs, labels)
                running_test_loss += loss.item()
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_loss = running_test_loss / t+1
        # dump it somewhere
        with open(trackfile, 'a') as f:
            f.write('{}_{}_{}_{}\n'.format(epoch, train_loss, test_loss, 100 * correct/float(total)))
        model.train()

        return test_loss


    device = torch.device('cuda')
    torch.manual_seed(torch_seed)
    test_data = UEA(
        dataset="Sleep",
        split="val",
        train_prop=0.7,
        seed=123,  # for reproducibility
        path = '/gpfs/soma_fs/scratch/rfruengel/data/sleep_EEG_torchtime'
    )
    train_data = UEA(
        dataset="Sleep",
        split="train",
        train_prop=0.7,
        seed=123,  # for reproducibility
        path = '/gpfs/soma_fs/scratch/rfruengel/data/sleep_EEG_torchtime'
    )

    loaders = {
    'train': DataLoader(train_data, batch_size = 100, shuffle = True, num_workers = 0),
    'test': DataLoader(test_data, batch_size = 100, shuffle = True, num_workers = 0)}

    ## make the model
    class RNN(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, num_classes):
            super(RNN, self).__init__() # initialise the nn.Module
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity = 'relu', batch_first = True)
            self.fc = nn.Linear(hidden_size, num_classes)
            self.con_p = con_p # just for writing the track file
            self.activ_track = activ_track


        def forward(self, x, i=None, epoch=None):
            # set initial hidden state
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
            # forward run model
            output, hidden = self.rnn(x, h0) # output shape (batch_size, seq_length, hidden_size)
            # reshape the output so it fits into the fully connected layer (get the last output from the RNN)
            output = output[:, -1, :]
            # pass it to the linear layer
            output = self.fc(output)
            # if tracking activations:
            if self.activ_track is not None and i is not None: 
                with open(self.activ_track.join(trackfile_format.format(self.hidden_size, self.con_p) + '_activations_epoch_{}_{}'.format(epoch, i)), 'wb') as f:
                    np.savez(f, recurrent_layer = output.cpu().detach().numpy(), fc_layer = fc_output.cpu().detach().numpy())

            return output

    model = RNN(1, n_cells, 1, 5)

    con = generate_con_matrix(n_cells, con_p, seed = torch_seed)
    rnn_layer = model.rnn
    rnn_layer.weight_hh_l0 = torch.nn.init.uniform_(rnn_layer.weight_hh_l0, a=-0.001, b=0.001) # randomly initialise weights
    prune.custom_from_mask(rnn_layer, 'weight_hh_l0', con) # remove connections not present in the anatomical model

    model.to(device)
    loss_func = nn.CrossEntropyLoss()
    optimiser = optim.Adam(model.parameters(), lr = 0.001)

    trackfile = outdir.join(trackfile_format.format(n_cells, con_p))

    # record test loss before starting
    test_loss = record_test_stats(np.nan, np.nan, 'before') 

    if early_stop:
        test_loss_history = test_loss
        early_stop_counter = 0
    early_stop_flag = False

    for epoch in range(num_epochs):
        print('Started epoch: {}'.format(epoch))
        running_train_loss = 0.0

        for i, (images, labels) in enumerate(loaders['train']):

            images = dat['X'][:, :, 1].unsqueeze(axis = 2).to(device)
            labels = dat['y'].to(device)

            # forward pass
            outputs = model(images)

            # pass output to loss function
            loss = loss_func(outputs, labels)
            # backpropagation
            loss.backward()
            running_train_loss += loss.item()

            # update weights
            optimiser.step()
            # clear gradients from previous epoch
            optimiser.zero_grad(set_to_none = True)

            # print progress

            if i+1 == len(loaders['train']): # at the end of each epoch record stats
                if weight_track is not None: # save weights to file
                    weights = model.rnn.weight_hh_l0.cpu().detach().numpy()
                    weights_fc = model.fc.weight.cpu().detach().numpy()
                    assert weights is not None
                    with open(weight_track.join(trackfile_format.format(n_cells, con_p) + '_weights_epoch_{}'.format(epoch)), 'wb') as f:
                        np.savez(f, recurrent_layer = weights, fc_layer = weights_fc) 
                test_loss = record_test_stats(running_train_loss, i, epoch)
                if early_stop:
                    if test_loss < test_loss_history: # if improved
                        test_loss_history = test_loss
                        early_stop_counter = 0
                    else:
                        early_stop_counter += 1
                        if early_stop_counter >= patience:
                            early_stop_flag = True
        if early_stop_flag:
            break

    with open(trackfile, 'a') as f:
        f.write('training completed successfully!\n')

## Equivalent code for training feedforward networks

In [None]:
@dask.delayed
def make_model_and_train(hidden_size, con_p, num_layers, num_epochs, trackfile_format = '{}_{}', 
                         outdir = None, torch_seed = 0, early_stop = True, patience = 20):
    '''
    Makes a (sparse) feedforward ANN, trains it on MNIST and records loss and accuracy
    
    hidden_size: int, number of nodes in each hidden layer.
    con_p: float between 0 and 1, connection probability from a hidden node in one layer to the next layer
    num_layers: int, number of hidden layers (2 was used in the paper)
    num_epochs: int, number of epochs to train the model for
    trackfile_format: str, base for filename. Will be formatted with .format(n_cells, con_p)
    outdir: str, folder path
    torch_seed: int, seeds torch and numpy for reproducibility
    early_stop: bool, enables early stopping if loss stops decreasing.
    patience: int, number of epochs without loss improvement before training is stopped. Only relevant if early_stop = True.

    
    '''
    
    def record_test_stats(running_train_loss, i, epoch):
        model.eval()
        train_loss = running_train_loss / i+1
        model.eval()
        running_test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad(): # get test loss over whole test dataset
            for t, (images, labels) in enumerate(loaders['test']):
                images = images.reshape(-1, 28*28).to(device)
                labels = labels.to(device)
                # get model predictions
                outputs = model(images)
                # get the label with the strongest activation
                _, predicted = torch.max(outputs.data, 1)
                # compute test loss
                loss = loss_func(outputs, labels)
                running_test_loss += loss.item()
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_loss = running_test_loss / t+1
        # dump it somewhere
        with open(trackfile, 'a') as f:
            f.write('{}_{}_{}_{}\n'.format(epoch, train_loss, test_loss, 100 * correct/float(total)))
        model.train()

        return test_loss

    device = torch.device('cuda')
    torch.manual_seed(torch_seed)
    train_data = datasets.MNIST(root = data_dir, train = True, transform = ToTensor(), download = False)
    test_data = datasets.MNIST(root = data_dir, train = False, transform = ToTensor())

    loaders = {
    'train': DataLoader(train_data, batch_size = 100, shuffle = True, num_workers = 0),
    'test': DataLoader(test_data, batch_size = 100, shuffle = True, num_workers = 0)}

    ## make the model
    class feedforward(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, num_classes):
            super(feedforward, self).__init__() # initialise the nn.Module
            self.hidden_size = hidden_size
            self.num_layers = num_layers

            self.l1 = nn.Linear(input_size, hidden_size)
            self.relu = nn.ReLU()
            self.linears = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for i in range(num_layers-1)])

            self.fc = nn.Linear(hidden_size, num_classes)


        def forward(self, x):
            # forward run model
            out = self.l1(x)
            out = self.relu(out)

            for lin in self.linears:
                out = lin(out)
                out = self.relu(out)

            # pass it to the linear layer
            output = self.fc(out)

            return output

    model = feedforward(28*28, hidden_size, num_layers, 10)

    for l in range(num_layers-1):
        con = generate_con_matrix(hidden_size, con_p)
        lin_layer = model.linears[l]
        lin_layer.weight = torch.nn.init.uniform_(lin_layer.weight, a=-0.001, b=0.001) # randomly initialise weights
        prune.custom_from_mask(lin_layer, 'weight', con) # remove connections to obtain sparsity
    model.to(device)
    loss_func = nn.CrossEntropyLoss()
    optimiser = optim.Adam(model.parameters(), lr = 0.001)

    trackfile = outdir.join(trackfile_format.format(hidden_size, num_layers, con_p))

    # record test loss before starting
    test_loss = record_test_stats(np.nan, np.nan, 'before') 

    if early_stop:
        test_loss_history = test_loss
        early_stop_counter = 0
    early_stop_flag = False

    for epoch in range(num_epochs):
        print('Started epoch: {}'.format(epoch))
        running_train_loss = 0.0

        for i, (images, labels) in enumerate(loaders['train']):

            images = images.reshape(-1, 28*28).to(device)
            labels = labels.to(device)

            # forward pass
            outputs = model(images)

            # pass output to loss function
            loss = loss_func(outputs, labels)
            # backpropagation
            loss.backward()
            running_train_loss += loss.item()

            # update weights
            optimiser.step()
            # clear gradients from previous epoch
            optimiser.zero_grad(set_to_none = True)

            # print progress

            if i+1 == len(loaders['train']): # at the end of each epoch record stats
                test_loss = record_test_stats(running_train_loss, i, epoch)
                if early_stop:
                    if test_loss < test_loss_history: # if improved
                        test_loss_history = test_loss
                        early_stop_counter = 0
                    else:
                        early_stop_counter += 1
                        if early_stop_counter >= patience:
                            early_stop_flag = True
        if early_stop_flag:
            break

    with open(trackfile, 'a') as f:
        f.write('training completed successfully!\n')

In [None]:
@dask.delayed
def make_model_and_train(hidden_size, con_p, num_layers, num_epochs, trackfile_format = '{}_{}', 
                         outdir = None, torch_seed = 0, early_stop = True, patience = 20):
    '''
    Makes a (sparse) feedforward ANN, trains it on CIFAR10 and records loss and accuracy
    
    hidden_size: int, number of nodes in each hidden layer.
    con_p: float between 0 and 1, connection probability from a hidden node in one layer to the next layer
    num_layers: int, number of hidden layers (2 was used in the paper)
    num_epochs: int, number of epochs to train the model for
    trackfile_format: str, base for filename. Will be formatted with .format(n_cells, con_p)
    outdir: str, folder path
    torch_seed: int, seeds torch and numpy for reproducibility
    early_stop: bool, enables early stopping if loss stops decreasing.
    patience: int, number of epochs without loss improvement before training is stopped. Only relevant if early_stop = True.

    
    '''
    
    def record_test_stats(running_train_loss, i, epoch):
        model.eval()
        train_loss = running_train_loss / i+1
        model.eval()
        running_test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad(): # get test loss over whole test dataset
            for t, (images, labels) in enumerate(loaders['test']):
                images = images.reshape(-1, 32*32*3).to(device)
                labels = labels.to(device)
                # get model predictions
                outputs = model(images)
                # get the label with the strongest activation
                _, predicted = torch.max(outputs.data, 1)
                # compute test loss
                loss = loss_func(outputs, labels)
                running_test_loss += loss.item()
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_loss = running_test_loss / t+1
        # dump it somewhere
        with open(trackfile, 'a') as f:
            f.write('{}_{}_{}_{}\n'.format(epoch, train_loss, test_loss, 100 * correct/float(total)))
        model.train()

        return test_loss

    device = torch.device('cuda')
    torch.manual_seed(torch_seed)
    train_data = datasets.CIFAR10(root = data_dir, train = True, transform = transform, download = False)
    test_data = datasets.CIFAR10(root = data_dir, train = False, transform = transform)

    loaders = {
    'train': DataLoader(train_data, batch_size = 100, shuffle = True, num_workers = 0),
    'test': DataLoader(test_data, batch_size = 100, shuffle = True, num_workers = 0)}

    ## make the model
    class feedforward(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, num_classes):
            super(feedforward, self).__init__() # initialise the nn.Module
            self.hidden_size = hidden_size
            self.num_layers = num_layers

            self.l1 = nn.Linear(input_size, hidden_size)
            self.relu = nn.ReLU()
            self.linears = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for i in range(num_layers-1)])

            self.fc = nn.Linear(hidden_size, num_classes)


        def forward(self, x):
            # forward run model
            out = self.l1(x)
            out = self.relu(out)

            for lin in self.linears:
                out = lin(out)
                out = self.relu(out)

            # pass it to the linear layer
            output = self.fc(out)

            return output

    model = feedforward(32*32*3, hidden_size, num_layers, 10)

    for l in range(num_layers-1):
        con = generate_con_matrix(hidden_size, con_p)
        lin_layer = model.linears[l]
        lin_layer.weight = torch.nn.init.uniform_(lin_layer.weight, a=-0.001, b=0.001) # randomly initialise weights
        prune.custom_from_mask(lin_layer, 'weight', con) # remove connections to obtain sparsity
    model.to(device)
    loss_func = nn.CrossEntropyLoss()
    optimiser = optim.Adam(model.parameters(), lr = 0.001)

    trackfile = outdir.join(trackfile_format.format(hidden_size, num_layers, con_p))

    # record test loss before starting
    test_loss = record_test_stats(np.nan, np.nan, 'before') 

    if early_stop:
        test_loss_history = test_loss
        early_stop_counter = 0
    early_stop_flag = False

    for epoch in range(num_epochs):
        print('Started epoch: {}'.format(epoch))
        running_train_loss = 0.0

        for i, (images, labels) in enumerate(loaders['train']):

            images = images.reshape(-1, 32*32*3).to(device)
            labels = labels.to(device)

            # forward pass
            outputs = model(images)

            # pass output to loss function
            loss = loss_func(outputs, labels)
            # backpropagation
            loss.backward()
            running_train_loss += loss.item()

            # update weights
            optimiser.step()
            # clear gradients from previous epoch
            optimiser.zero_grad(set_to_none = True)

            # print progress

            if i+1 == len(loaders['train']): # at the end of each epoch record stats
                test_loss = record_test_stats(running_train_loss, i, epoch)
                if early_stop:
                    if test_loss < test_loss_history: # if improved
                        test_loss_history = test_loss
                        early_stop_counter = 0
                    else:
                        early_stop_counter += 1
                        if early_stop_counter >= patience:
                            early_stop_flag = True
        if early_stop_flag:
            break

    with open(trackfile, 'a') as f:
        f.write('training completed successfully!\n')

In [None]:
@dask.delayed
def make_model_and_train(hidden_size, con_p, num_layers, num_epochs, trackfile_format = '{}_{}', 
                         outdir = None, torch_seed = 0, early_stop = True, patience = 20):
    '''
    Makes a (sparse) feedforward ANN, trains it on Sleep-EDF and records loss and accuracy
    
    hidden_size: int, number of nodes in each hidden layer.
    con_p: float between 0 and 1, connection probability from a hidden node in one layer to the next layer
    num_layers: int, number of hidden layers (2 was used in the paper)
    num_epochs: int, number of epochs to train the model for
    trackfile_format: str, base for filename. Will be formatted with .format(n_cells, con_p)
    outdir: str, folder path
    torch_seed: int, seeds torch and numpy for reproducibility
    early_stop: bool, enables early stopping if loss stops decreasing.
    patience: int, number of epochs without loss improvement before training is stopped. Only relevant if early_stop = True.

    
    '''
    
    def record_test_stats(running_train_loss, i, epoch):
        model.eval()
        train_loss = running_train_loss / i+1
        model.eval()
        running_test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad(): # get test loss over whole test dataset
            for t, (images, labels) in enumerate(loaders['test']):
                images = dat['X'][:, :, 1].to(device)
                labels = dat['y'].to(device)
                # get model predictions
                outputs = model(images)
                # get the label with the strongest activation
                _, predicted = torch.max(outputs.data, 1)
                # compute test loss
                loss = loss_func(outputs, labels)
                running_test_loss += loss.item()
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_loss = running_test_loss / t+1
        # dump it somewhere
        with open(trackfile, 'a') as f:
            f.write('{}_{}_{}_{}\n'.format(epoch, train_loss, test_loss, 100 * correct/float(total)))
        model.train()

        return test_loss

    device = torch.device('cuda')
    torch.manual_seed(torch_seed)
    test_data = UEA(
        dataset="Sleep",
        split="val",
        train_prop=0.7,
        seed=123,  # for reproducibility
        path = '/gpfs/soma_fs/scratch/rfruengel/data/sleep_EEG_torchtime'
    )
    train_data = UEA(
        dataset="Sleep",
        split="train",
        train_prop=0.7,
        seed=123,  # for reproducibility
        path = '/gpfs/soma_fs/scratch/rfruengel/data/sleep_EEG_torchtime'
    )

    loaders = {
    'train': DataLoader(train_data, batch_size = 100, shuffle = True, num_workers = 0),
    'test': DataLoader(test_data, batch_size = 100, shuffle = True, num_workers = 0)}

    ## make the model
    class feedforward(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, num_classes):
            super(feedforward, self).__init__() # initialise the nn.Module
            self.hidden_size = hidden_size
            self.num_layers = num_layers

            self.l1 = nn.Linear(input_size, hidden_size)
            self.relu = nn.ReLU()
            self.linears = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for i in range(num_layers-1)])

            self.fc = nn.Linear(hidden_size, num_classes)


        def forward(self, x):
            # forward run model
            out = self.l1(x)
            out = self.relu(out)

            for lin in self.linears:
                out = lin(out)
                out = self.relu(out)

            # pass it to the linear layer
            output = self.fc(out)

            return output

    model = feedforward(178, n_cells, 2, 5)

    for l in range(num_layers-1):
        con = generate_con_matrix(hidden_size, con_p)
        lin_layer = model.linears[l]
        lin_layer.weight = torch.nn.init.uniform_(lin_layer.weight, a=-0.001, b=0.001) # randomly initialise weights
        prune.custom_from_mask(lin_layer, 'weight', con) # remove connections to obtain sparsity
    model.to(device)
    loss_func = nn.CrossEntropyLoss()
    optimiser = optim.Adam(model.parameters(), lr = 0.001)

    trackfile = outdir.join(trackfile_format.format(hidden_size, num_layers, con_p))

    # record test loss before starting
    test_loss = record_test_stats(np.nan, np.nan, 'before') 

    if early_stop:
        test_loss_history = test_loss
        early_stop_counter = 0
    early_stop_flag = False

    for epoch in range(num_epochs):
        print('Started epoch: {}'.format(epoch))
        running_train_loss = 0.0

        for i, (images, labels) in enumerate(loaders['train']):

            images = dat['X'][:, :, 1].to(device)
            labels = dat['y'].to(device)

            # forward pass
            outputs = model(images)

            # pass output to loss function
            loss = loss_func(outputs, labels)
            # backpropagation
            loss.backward()
            running_train_loss += loss.item()

            # update weights
            optimiser.step()
            # clear gradients from previous epoch
            optimiser.zero_grad(set_to_none = True)

            # print progress

            if i+1 == len(loaders['train']): # at the end of each epoch record stats
                test_loss = record_test_stats(running_train_loss, i, epoch)
                if early_stop:
                    if test_loss < test_loss_history: # if improved
                        test_loss_history = test_loss
                        early_stop_counter = 0
                    else:
                        early_stop_counter += 1
                        if early_stop_counter >= patience:
                            early_stop_flag = True
        if early_stop_flag:
            break

    with open(trackfile, 'a') as f:
        f.write('training completed successfully!\n')

# Reduced datasets

Code for Fig. 4

In [None]:
def get_training_subset(train_data, samples_per_class = 1000):
    c = defaultdict(list)
    for i, t in enumerate(train_data.targets.tolist()):
        c[t].append(i)
    
    for t in sorted(list(set(train_data.targets.tolist()))):
        examples = c[t]
        np.random.shuffle(examples)
        c[t] = examples[:samples_per_class]
    
    indices = [int(i) for x in c.values() for i in x]
    
    assert len(indices) == samples_per_class * len(list(set(train_data.targets.tolist())))
    
    tr_reduced = torch.utils.data.Subset(train_data, indices)
    return tr_reduced

In [None]:
# applied to the training dataset as follows, all remaining model and training code unchanged
train_data = datasets.MNIST(root = data_dir, train = True, transform = ToTensor(), download = False)
test_data = datasets.MNIST(root = data_dir, train = False, transform = ToTensor())

train_partial = get_training_subset(train_data, samples_per_class = samples_per_class) ## 

loaders = {
'train': DataLoader(train_partial, batch_size = 100, shuffle = True, num_workers = 0), ## 
'test': DataLoader(test_data, batch_size = 100, shuffle = True, num_workers = 0)}

# Node dropout noise

Code for Fig. 5D&E

In [None]:
# modify the forward pass. We set dropout = True only during the test set
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__() # initialise the nn.Module
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity = 'relu', batch_first = True)
        self.fc = nn.Linear(hidden_size, num_classes)


    def forward(self, x, dropout = False):
        # set initial hidden state
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # forward run model
        output, hidden = self.rnn(x, h0) # output shape (batch_size, seq_length, hidden_size)
        # reshape the output so it fits into the fully connected layer (get the last output from the RNN)
        output = output[:, -1, :]
        if dropout:
            random_mask = np.random.rand(100, n_cells) # 100 is the batch size -> a different subset of nodes is dropped for each image
            random_mask = np.array(random_mask > dropout_fraction).astype(int)
            random_mask = torch.tensor(random_mask)
            random_mask = random_mask.to(device)
            with torch.no_grad():
                output = output*random_mask

        # pass it to the linear layer
        output = self.fc(output)

        return output

# Dale's principle networks

Code for Fig. 6

Below is our implementation of Dale's networks. First, randomly choose a proportion (11.5% and 50% were reported in the paper) of nodes to be inhibitory when initialising the ANN. Inhibitory nodes should have all negative outgoing weights, while excitatory nodes should have all positive outgoing weights. Then, whenever the weights of these nodes get updated during training, check to see if they changed sign - if they did, set the weight to zero. 

In [56]:
# to ensure reproducibility, we should set the torch and numpy seed
torch.manual_seed(42)
np.random.seed(42)

In [64]:
## make the model - a recurrent network with one hidden layer, and a fully connected output layer
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__() # initialise the nn.Module
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity = 'relu', batch_first = True)
        self.fc = nn.Linear(hidden_size, num_classes)


    def forward(self, x):
        # set initial hidden state
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        # pass input throught the recurrent layer
        output, hidden = self.rnn(x, h0) # output shape (batch_size, seq_length, hidden_size)
        # reshape the output so it fits into the fully connected layer (get the last output from the RNN)
        output = output[:, -1, :]
        # pass it to the linear layer to get the classification
        output = self.fc(output)

        return output

In [65]:
input_size = 28 # the size of the input at each timestep
hidden_size = 10000 # how many nodes do we want in the hidden layer?
num_layers = 1 # how many recurrent layers 
num_classes = 10 # this is defined by the dataset, which has 10 classes (digits 0-9)

In [66]:
model = RNN(input_size, hidden_size, num_layers, num_classes)

In [67]:
con = generate_con_matrix(hidden_size, 0.1)

In [68]:
inh_percent = 0.115 # the proportion of inhibitory cells in somatosensory cortex https://doi.org/10.1073/pnas.1113648108

In [71]:
# select the nodes which we will make inhibitory
inh_nodes = np.random.choice(list(range(hidden_size)), size=int(inh_percent*hidden_size))

In [72]:
# randomly initialise weights from a uniform distribution, this time all positive
model.rnn.weight_hh_l0 = torch.nn.init.uniform_(model.rnn.weight_hh_l0, a=0, b=0.001) 

# now, make all the weights of our inhibitory nodes negative by multiplying them by -1
with torch.no_grad():
    model.rnn.weight_hh_l0[:, inh_nodes]*=-1 

# remove connections to make it sparse
prune.custom_from_mask(model.rnn, 'weight_hh_l0', con) 

RNN(28, 10000, batch_first=True)

In [73]:
loss_func = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters(), lr = 0.001)

In [74]:
def train(num_epochs, model, loaders):
    print('------------------------------')
    print('num_epochs: {}'.format(num_epochs))
    print('model: {}'.format(model))
    print('loaders: {}'.format(loaders))
    print('training on {}'.format(device))
    
    total_step = len(loaders['train'])
    
    # NEW: record which weights started as positive
    pweights = model.rnn.weight_hh_l0 >= 0
    pweights = pweights.to(device)
    
    print('------------------------------')
    for epoch in range(num_epochs):
        print('Started epoch: {}'.format(epoch))
        for i, (images, labels) in enumerate(loaders['train']):
            images = images.reshape(-1, 28, 28).to(device)
            labels = labels.to(device)

            # forward pass
            outputs = model(images)
            
            # pass output to loss function
            loss = loss_func(outputs, labels)
            
            # clear gradients from previous epoch
            optimiser.zero_grad()
            
            # backpropagation
            loss.backward()
            
            # update weights
            optimiser.step()
            
            # NEW: check which weights changed sign, set them to zero                
            with torch.no_grad(): # if we want to make manual changes to the weights like setting them to zero, we have to use torch.no_grad
                model.rnn.weight_hh_l0[pweights & (~(model.rnn.weight_hh_l0 > 0))] = 0
                model.rnn.weight_hh_l0[~pweights & (model.rnn.weight_hh_l0 > 0)] = 0
            
            # print progress
            if (i+1)%100 == 0:
                print('Step [{} / {}], Loss: {}'.format(i+1, total_step, loss.item()))
        
        # at the end of each epoch:
        model.eval() # put the model in evaluation mode. This freezes certain things that might be going on during training, e.g. dropout. We don't use it, but it's good practice to include anyway.
        running_test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad(): # get test loss over whole test dataset
            for t, (images, labels) in enumerate(loaders['test']):
                images = images.reshape(-1, 28, 28).to(device)
                labels = labels.to(device)
                # get model predictions
                outputs = model(images)
                # get the label with the strongest activation, we will say this is the model's class prediction
                _, predicted = torch.max(outputs.data, 1)
                # compute test loss
                loss = loss_func(outputs, labels)
                running_test_loss += loss.item()
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_loss = running_test_loss / t+1
        accuracy = 100 * correct/float(total)
        model.train() # put it back in training mode
        print('TESTING...')
        print('Loss: {}, accuracy: {}%'.format(test_loss, accuracy))

In [75]:
model.to(device)

RNN(
  (rnn): RNN(28, 10000, batch_first=True)
  (fc): Linear(in_features=10000, out_features=10, bias=True)
)

In [76]:
num_epochs = 10
train(num_epochs, model, loaders)

------------------------------
num_epochs: 10
model: RNN(
  (rnn): RNN(28, 10000, batch_first=True)
  (fc): Linear(in_features=10000, out_features=10, bias=True)
)
loaders: {'train': <torch.utils.data.dataloader.DataLoader object at 0x2b148fd0cd90>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x2b14d6e0d730>}
training on cuda
------------------------------
Started epoch: 0
Step [100 / 600], Loss: 1.4706289768218994
Step [200 / 600], Loss: 0.8616660833358765
Step [300 / 600], Loss: 0.7347720861434937
Step [400 / 600], Loss: 0.24932155013084412
Step [500 / 600], Loss: 0.22504915297031403
Step [600 / 600], Loss: 0.22148172557353973
TESTING...
Loss: 1.3064008157587412, accuracy: 91.05%
Started epoch: 1
Step [100 / 600], Loss: 0.36837124824523926
Step [200 / 600], Loss: 0.20446786284446716
Step [300 / 600], Loss: 0.17891612648963928
Step [400 / 600], Loss: 0.2340857833623886
Step [500 / 600], Loss: 0.14661064743995667
Step [600 / 600], Loss: 0.24163366854190826
TESTING...
Loss