This notebook trains excitatory/inhibitory (E/I) continuous-time recurrent neural networks to perform a change-detection task.

In [1]:
# import libraries
import math
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn import functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Fix the global random state of PyTorch and NumPy so that weight
# initialisation, recurrent noise and trial ordering repeat across runs.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# How many independently trained networks to produce
num_nets = 50

Set up a function that generates the supervised dataset for the change-detection task, and define the E/I continuous-time RNN (modified from code supplied with Yang & Wang (2020), https://www.cell.com/neuron/fulltext/S0896-6273(20)30705-4).

In [2]:
# change detection environment
def ChangeEnvGen(direction):

    trial_length = 50

    # inputs
    inputs = np.zeros([2,trial_length])

    if direction == 0:
        inputs[0,:] += np.linspace(1,0,trial_length).T
        inputs[1,:] += np.linspace(0,1,trial_length).T
    else:
        inputs[0,:] += np.linspace(0,1,trial_length).T
        inputs[1,:] += np.linspace(1,0,trial_length).T
        
    # trial info
    labels = np.zeros([1,trial_length])
    if direction == 0:
        labels[:,:25] = 0; labels[:,25:] = 1
    else:
        labels[:,:25] = 1; labels[:,25:] = 0

    return inputs, labels, direction

# EI CTRNN

class PosWLinear(nn.Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    Same as nn.Linear, except that weight matrix is constrained to be non-negative
    """
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=False):
        super(PosWLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # weight is non-negative
        return F.linear(input, torch.abs(self.weight), self.bias)
    
    
class EIRecLinear(nn.Module):
    r"""Recurrent E-I Linear transformation.
    
    Args:
        hidden_size: int, layer size
        e_prop: float between 0 and 1, proportion of excitatory units
    """
    __constants__ = ['bias', 'hidden_size', 'e_prop']

    def __init__(self,hidden_size, e_prop, diag, bias=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.e_prop = e_prop
        self.e_size = int(e_prop * hidden_size)
        self.i_size = hidden_size - self.e_size
        self.weight = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        mask = np.tile([1]*self.e_size+[-1]*self.i_size, (hidden_size, 1))
        if diag == 1:
            np.fill_diagonal(mask, 0)
        self.mask = torch.tensor(mask, dtype=torch.float32)
        if bias:
            self.bias = nn.Parameter(torch.Tensor(hidden_size))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        
        # Scale E weight by E-I ratio
        self.weight.data[:, :self.e_size] /= (self.e_size/self.i_size)
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
    
    def effective_weight(self):
        return torch.abs(self.weight) * self.mask

    def forward(self, input):
        # weight is non-negative
        return F.linear(input, self.effective_weight(), self.bias)


class EIRNN(nn.Module):
    """E-I RNN.
    
    Reference:
        Song, H.F., Yang, G.R. and Wang, X.J., 2016.
        Training excitatory-inhibitory recurrent neural networks
        for cognitive tasks: a simple and flexible framework.
        PLoS computational biology, 12(2).

    Args:
        input_size: Number of input neurons
        hidden_size: Number of hidden neurons

    Inputs:
        input: (seq_len, batch, input_size)
        hidden: (batch, hidden_size)
        e_prop: float between 0 and 1, proportion of excitatory neurons
    """

    def __init__(self,input_size, hidden_size, dt=20,
                 e_prop=0.8, sigma_rec=0.1, sigma_gain = 1,diag = 1, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.e_size = int(hidden_size * e_prop)
        self.i_size = hidden_size - self.e_size
        self.num_layers = 1
        self.tau = 100
        if dt is None:
            alpha = 1
        else:
            alpha = dt / self.tau
        self.alpha = alpha
        self.oneminusalpha = 1 - alpha
        # Recurrent noise
        self._sigma_rec = np.sqrt(dt) * sigma_rec
        self.diag = diag;
        self.input2h = PosWLinear(input_size, hidden_size)
        # self.input2h = nn.Linear(input_size, hidden_size)
        self.h2h = EIRecLinear(hidden_size, e_prop=e_prop,diag=diag)
        self.gain = sigma_gain
 
    def init_hidden(self, input):
        batch_size = 1
        return (torch.zeros(batch_size, self.hidden_size).to(input.device),
                torch.zeros(batch_size, self.hidden_size).to(input.device))

    def recurrence(self, input, hidden):
        """Recurrence helper."""
        
        state, output = hidden # x(t) , r(t)
        total_input = self.input2h(input) + self.h2h(output) # W_in * input + W_r * r
        state = state * self.oneminusalpha + total_input * self.alpha # x(1-alpha) + # W_in * input + W_r * r
        state += self._sigma_rec * torch.randn_like(state) # + noise
        output = torch.sigmoid(self.gain*state)
        return state, output
        

    def forward(self, input, hidden=None):
        """Propogate input through the network."""
        if hidden is None:
            hidden = self.init_hidden(input)

        output = []
        steps = range(input.size(0))
        for i in steps:
            hidden = self.recurrence(input[i], hidden)
            output.append(hidden[1])

        output = torch.stack(output, dim=0)
        return output, hidden


class Net(nn.Module):
    """Recurrent network model.

    Args:
        input_size: int, input size
        hidden_size: int, hidden size
        output_size: int, output size
        rnn: str, type of RNN, lstm, rnn, ctrnn, or eirnn
    """
    def __init__(self, input_size, hidden_size, output_size, **kwargs):
        super().__init__()

        # Excitatory-inhibitory RNN
        self.rnn = EIRNN(input_size, hidden_size, **kwargs)
        self.fc = PosWLinear(self.rnn.e_size, output_size)
        # self.fc = nn.Linear(self.rnn.e_size, output_size)

    def forward(self, x):
        rnn_activity, _ = self.rnn(x)
        rnn_e = rnn_activity[:, :, :self.rnn.e_size]
        out = self.fc(rnn_e)
        return out, rnn_activity

With everything set up, we can now train the networks.

In [None]:
# Instantiate the network and print information
hidden_size = 40; input_size = 2; output_size = 2
e_size = .8; 

accuracy_cross_net = np.zeros([50])

training_iterations = 1000
for randseed in range(num_nets):
    
    net = Net(input_size=input_size, hidden_size=hidden_size, output_size=output_size,e_prop=e_size,diag=0)
    
    # Use Adam optimizer
    optimizer = optim.Adam(net.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    
    # trials
    trl_type = np.tile([0,1],int(training_iterations/2))
    np.random.shuffle(trl_type)
    
    running_loss = 0; running_acc = 0
    print('Training network:',randseed)
    
    for i in range(training_iterations):
        
        trialtype = trl_type[i]
        
        # Generate input and target, convert to pytorch tensor
        ob, gt, trial_type = ChangeEnvGen(trialtype)
        
        inputs = ob.T; labels = gt.T
    
        inputs = torch.from_numpy(inputs).type(torch.float)
        labels = torch.from_numpy(labels.flatten()).type(torch.long)
    
        # boiler plate pytorch training:
        optimizer.zero_grad() # zero the gradient buffers
        output, _ = net(inputs)
        print(output.shape)
        output = torch.squeeze(output)
        print(output.shape)
        print(labels.shape)
        # output = output.view(-1, output_size) # Reshape to (SeqLen x Batch, OutputSize)
    
        # cross entropy loss function
        plt.plot(output.detach().numpy())
        plt.plot(labels.detach().numpy())
        loss = criterion(output, labels)

        loss.backward()
        optimizer.step() # Does the update
        
        # grab choice and compare to ground truth
        choices = torch.argmax(output, dim=1) 
        acc = torch.mean(torch.eq(choices, labels).type(torch.float32)) # torch.eq computs elementwise equality
    
        # compute running accuracy and loss
        running_loss = loss.item()
        running_acc = acc.item()
        
        # Compute the running loss every 100 steps
        if i % 100 == 99:
            print('Step {}, Loss {:f}, Acc {:f}'.format(i+1, running_loss, running_acc))
     
    accuracy_cross_net[randseed] = running_acc
    # save network weights    
    varname = '/Users/chriswhyte/Documents/University/Projects/CurrentProjects/ChangeDetectionRNN/DaleNet/changedetection_weights_EI%d_%d.pth' % (randseed,hidden_size)
    torch.save(net.state_dict(), varname)


print('Training finished')

mean_acc = np.mean(accuracy_cross_net)
std_acc = np.std(accuracy_cross_net)
print('Accuracy:  {:f} ± {:f}'. format(mean_acc, std_acc))


Training network: 0
torch.Size([50, 1, 2])
torch.Size([50, 2])
torch.Size([50])


Convert the trained weights from tensors to NumPy arrays, and save the effective (sign-constrained) weight matrices as comma-separated text files.

In [7]:
print('start weights conversion')
for randseed in range(num_nets):

    # Instantialise the network
    hidden_size = 40; input_size = 2; output_size = 2
    net = Net(input_size=input_size, hidden_size=hidden_size, output_size=output_size, dt=20)
        
    e_prop = 1;
    e_size = int(e_prop * hidden_size)
    i_size = hidden_size - e_size
    mask = np.tile([1]*e_size+[-1]*i_size, (hidden_size, 1))
    
    # import pretrained networks
    varname = '/Users/chriswhyte/Documents/University/Projects/CurrentProjects/ChangeDetectionRNN/DaleNet/ChangeDetectionWeightsEI/changedetection_weights_EI%d_%d.pth' % (randseed,hidden_size)
    net.load_state_dict(torch.load(varname))
    
    # weights
    W_in_EI = np.abs(np.squeeze(net.rnn.input2h.weight.detach().numpy()))
    W_res_EI = net.rnn.h2h.weight.detach().numpy()
    W_res_EI = np.abs(W_res_EI)*mask
    W_out_EI = np.abs(net.fc.weight.detach().numpy())
    
    # save weights 
    np.savetxt("/Users/chriswhyte/Documents/University/Projects/CurrentProjects/ChangeDetectionRNN/DaleNet/EffectiveWeights/W_in_EI%d_%d" % (randseed,hidden_size), W_in_EI, delimiter=",")
    np.savetxt("/Users/chriswhyte/Documents/University/Projects/CurrentProjects/ChangeDetectionRNN/DaleNet/EffectiveWeights/W_res_EI%d_%d" % (randseed,hidden_size), W_res_EI, delimiter=",")
    np.savetxt("/Users/chriswhyte/Documents/University/Projects/CurrentProjects/ChangeDetectionRNN/DaleNet/EffectiveWeights/W_out_EI%d_%d" % (randseed,hidden_size), W_out_EI, delimiter=",")

print('weight conversion finished')

# plt.imshow(W_res_EI, cmap='hot', interpolation='nearest')
# plt.show()

start weights conversion
weight conversion finished
