In [1]:
# Import the required packages
import numpy as np
# from numpy.linalg import svd as svd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
# from collections import OrderedDict
import matplotlib.pyplot as plt

# Define plot font sizes
label_font = 18
title_font = 24
legend_font = 12


# RNN Model Class and Optimizer

In [2]:
class Drosophila_RNN(nn.Module):
    def __init__(self, KC_size=200, MBON_size=20, DAN_size=20, FBN_size=60, ext_size=2, out_size=1, net_seed=1234):
        super(Drosophila_RNN, self).__init__()
        # Set the seeds
#         np.random.seed(net_seed)
#         torch.manual_seed(net_seed)
        # Set constants
        self.KC_MBON_min = 0. # Minimum synaptic weight
        self.KC_MBON_max = 0.05 # Maximum synaptic weight
        self.tau_w = 5 # Time scale of KC->MBON LTD/LTP (plasticity)
        self.tau_r = 1 # Time scale of output circuitry activity
        # Set the sizes of layers
        self.N_KC = KC_size
        self.N_MBON = MBON_size
        self.N_FBN = FBN_size
        self.N_DAN = DAN_size
        self.N_recur = MBON_size + FBN_size + DAN_size
        self.N_ext = ext_size
        self.N_out = out_size
        # Define updatable network parameters
#         seed_num = net_seed
        seed_num = None
        sqrt2 = torch.sqrt(torch.tensor(2, dtype=torch.float))
        mean_MBON = torch.zeros((self.N_recur, MBON_size))
        mean_FBN = torch.zeros((self.N_recur, FBN_size))
        mean_DAN = torch.zeros((self.N_recur, DAN_size))
        W_MBON = torch.normal(mean_MBON, torch.sqrt(1 / (sqrt2 * MBON_size)), generator=seed_num)
        W_FBN = torch.normal(mean_FBN, torch.sqrt(1 / (sqrt2 * FBN_size)), generator=seed_num)
        W_DAN = torch.normal(mean_DAN, torch.sqrt(1 / (sqrt2 * DAN_size)), generator=seed_num)
        self.W_recur = nn.Parameter(torch.cat((W_MBON, W_FBN, W_DAN), dim=1), requires_grad=True)
        self.W_ext = nn.Parameter(torch.randn(FBN_size, ext_size), requires_grad=True)
        mean_readout = torch.zeros((out_size, MBON_size))
        std_readout = 1 / torch.sqrt(torch.tensor(MBON_size, dtype=torch.float))
        self.W_readout = nn.Parameter(torch.normal(mean_readout, std_readout, generator=seed_num), requires_grad=True)
        self.bias = nn.Parameter(torch.ones(self.N_recur) * 0.1, requires_grad=True)
        
#         def KC_MBON_update(self, KC_trace):
#             """ Updates the synaptic weights from the Keyon cells to the output neurons.
#             Updates the synaptic connections between inputs (KCs) and output neurons (MBONs) using a LTP/LTD
#             rule (see Figure 1B of Jiang 2020). Models dopamine-gated neural plasticity on short time scale.
#             The KC->MBON weights are constrained to the range [0, 0.05].
            
#             Paramters
#                 wt = dynamic weight update variable (see Eq 4 of Jiang 2020)
#             """
            
#             self.W_KC_MBON = (-self.W_KC_MBON + wt) / self.tau_W
#             # Constrain weights to lie in [0, 0.05]
#             pass
            
            
    def forward(self, r_KC, r_ext, time, W0=None, batch_size=30):
        """ Defines the forward pass of the RNN

        Inputs
            r_KC = activity of the Kenyon cell neurons (representing odors)
            r_ext = context signals (representing the conditioning context)
            time = time vector for a single interval

        Returns
            r_recur torch.ndarray((N_MBON + N_FBN + N_DAN), time.size)
                = time series of activities in the output circuitry
            Wt torch.ndarray((N_MBON + N_FBN + N_DAN), time.size)
                = time series of KC->MBON weights (represent dopaminergic plasticity)
            readout torch.ndarray(time.size)
                = time series of valence readouts (represents behaviour)
        """

        # Define the time step of the simulation
        dt = np.diff(time)[0]

        # Initialize output circuit firing rates for each trial
        r_init = torch.ones((batch_size, self.N_recur)) * 0.1
        r_init[:self.N_MBON] = 0
        r_recur = torch.zeros((batch_size, self.N_recur, time.size()[0]))

        # Initialize the eligibility traces, dynamic weights and readout
        r_bar_KC = r_KC[:, :, 0]
        r_bar_DAN = r_recur[:, -self.N_DAN:, 0]
        wt = torch.zeros((batch_size, self.N_MBON, self.N_KC, time.size()[0]))
        W_KC_MBON = torch.zeros((batch_size, self.N_MBON, self.N_KC, time.size()[0]))
        if W0 is None:
            W_KC_MBON[:, :, :, 0] = W0
        else:
            W_KC_MBON[:, :, :, 0] = self.KC_MBON_max
        readout = torch.zeros((batch_size, time.size()[0]))

        # Set the weights DAN->MBON to zero
#         self.W_recur[:self.N_MBON, -self.N_DAN:] = 0

        # Update activity for each time step
        for t in range(time.size()[0] - 1):
            # Define the input to the output circuitry
            I_KC_MBON = torch.einsum('bmk, bk -> bm', W_KC_MBON[:, :, :, t], r_KC[:, :, t])
            I_FBN = torch.einsum('bfe, be -> bf', self.W_ext.repeat(batch_size, 1, 1), r_ext[:, :, t])
            I = torch.zeros((batch_size, self.N_recur))
            I[:, :self.N_MBON] = I_KC_MBON
            I[:, self.N_MBON:self.N_MBON + self.N_FBN] = I_FBN

            # Update the output circuitry activity
            Wr_prod = torch.einsum('bsr, br -> bs', self.W_recur.repeat(batch_size, 1, 1), r_recur[:, :, t])
            dr = F.relu(-r_recur[:, :, t] + Wr_prod + self.bias.repeat(batch_size, 1) + I) / self.tau_r
            r_recur[:, :, t+1] = r_recur[:, :, t] + dr * dt

            # Update KC->MBON plasticity variables
            # Calculate the eligibility traces
            r_bar_KC = (r_KC[:, :, t] - r_bar_KC) * dt / self.tau_w
            r_bar_DAN = (r_recur[:, -self.N_DAN:, t] - r_bar_DAN) * dt / self.tau_w
            # Update the dynamic weight variable
            prod1 = torch.einsum('bd, bk -> bdk', r_bar_DAN, r_KC[:, :, t])
            prod2 = torch.einsum('bd, bk -> bdk', r_recur[:, -self.N_DAN:, t], r_bar_KC)
            dw = (prod1 - prod2)
            wt[:, :, :, t+1] = wt[:, :, :, t] + dw * dt
            # Update the KC->MBON weights
            dW = (-W_KC_MBON[:, :, :, t] + wt[:, :, :, t+1]) / self.tau_w
            W_tp1 = W_KC_MBON[:, :, :, t] + dW * dt
            # Clip the KC->MBON weights to the range [0, 0.05]
            W_KC_MBON[:, :, :, t+1] = torch.clamp(W_tp1, self.KC_MBON_min, self.KC_MBON_max)

            # Calculate the readout
            readout[:, t] = torch.squeeze(torch.einsum('bom, bm -> bo', self.W_readout.repeat(batch_size, 1, 1), r_recur[:, :self.N_MBON, t]))

#             for b in range(n_batch):
#                 # Define the input to the output circuitry
#                 I_KC_MBON = W_KC_MBON[b, :, :, t] @ r_KC[b, :, t]
#                 I_FBN = self.W_ext @ r_ext[b, :, t]
#                 I = torch.zeros(self.N_recur)
#                 I[:self.N_MBON] = I_KC_MBON
#                 I[self.N_MBON:self.N_MBON + self.N_FBN] = I_FBN

#                 # Update the output circuitry activity
#                 Wr_prod = self.W_recur @ r_recur[b, :, t]
#                 dr = F.relu(-r_recur[b, :, t] + Wr_prod + self.bias + I) / self.tau_r
#                 r_recur[b, :, t+1] = r_recur[b, :, t] + dr * dt

#                 # Update KC->MBON plasticity variables
#                 # Calculate the eligibility traces
#                 r_bar_KC[b, :] = (r_KC[b, :, t] - r_bar_KC[b, :]) * dt / self.tau_w
#                 r_bar_DAN[b, :] = (r_recur[b, -self.N_DAN:, t] - r_bar_DAN[b, :]) * dt / self.tau_w
#                 # Update the dynamic weight variable
#                 prod1 = r_bar_DAN[b, :].view(-1, 1) @ r_KC[b, :, t].view(1, -1)
#                 prod2 = r_recur[b, -self.N_DAN:, t].view(-1, 1) @ r_bar_KC[b, :].view(1, -1)
#                 dw = (prod1 - prod2)
#                 wt[b, :, :, t+1] = wt[b, :, :, t] + dw * dt
#                 # Update the KC->MBON weights
#                 dW = (-W_KC_MBON[b, :, :, t] + wt[b, :, :, t+1]) / self.tau_w
#                 W_tp1 = W_KC_MBON[b, :, :, t] + dW * dt
#                 # Clip the KC->MBON weights to the range [0, 0.05]
#                 W_KC_MBON[b, :, :, t+1] = torch.clamp(W_tp1, self.KC_MBON_min, self.KC_MBON_max)

#                 # Calculate the readout
#                 readout[b, t] = self.W_readout @ r_recur[b, :self.N_MBON, t]

        return r_recur, W_KC_MBON, readout
            
        
# Clipping weights between [0, 0.05]
# https://discuss.pytorch.org/t/how-to-do-constrained-optimization-in-pytorch/60122
# https://discuss.pytorch.org/t/set-constraints-on-parameters-or-layers/23620
# https://discuss.pytorch.org/t/restrict-range-of-variable-during-gradient-descent/1933/4

# Setting DAN->MBON weights to zero
# https://pytorch.org/docs/stable/generated/torch.triu.html

# Broadcasting using einsum
# https://github.com/pytorch/pytorch/issues/15671


# Conditioning Tasks

In [3]:
# Initialize the network
classic_net = Drosophila_RNN()
for param in classic_net.parameters():
    print(param.shape)
#     print(param)
# print(classic_net.N_DAN)

# Define the model's optimizer
lr = 0.001
optimizer = optim.RMSprop(classic_net.parameters(), lr=lr)


torch.Size([100, 100])
torch.Size([60, 2])
torch.Size([1, 20])
torch.Size([100])


In [4]:
# Define the cost function for conditioning tasks
def cond_loss(vt, vt_opt, r_DAN, lam=0.1):
    """ Calculates the loss for conditioning tasks.
    Composed of an MSE cost based on the difference between output and
    target valence, and a regularization cost that penalizes excess
    dopaminergic activity. Reference Eq. (3) and (9) in Jiang 2020.
    
    Parameters
        vt = time dependent valence output of network
        vt_opt = target valence (must be a torch tensor)
        r_DAN = time series of dopaminergic neuron activities
        lam = regularization constant
    
    Returns
        loss_tot = scalar loss used in backprop
    """
    
    # Set the baseline DAN activity
    DAN_baseline = 0.1
    
    # Calculate the MSE loss of the valence
    v_sum = torch.mean((vt - vt_opt)**2, dim=1)
    v_loss = torch.mean(v_sum)
    
    # Calculate regularization term
    r_sum = torch.sum(F.relu(r_DAN-0.1)**2, dim=1)
    r_loss = lam * torch.mean(r_sum, dim=1)
    
    # Calculate the summed loss (size = n_batch)
    loss = v_loss + r_loss
    
    # Average the loss over all batches
    loss_tot = torch.mean(loss)
    
    return loss_tot


In [5]:
# Initialize the simulation constants
# n_batch = 30
n_batch = 10
# n_epochs = 2000
n_epochs = 5
T_int = 30
T_stim = 2
dt = 0.5
stim_wts = torch.ones(n_batch, 12)
stim_len = int(T_stim / dt)

# Odor (CS) constants
n_KC = 200
n_ones = int(n_KC * 0.1)
KC_wts = torch.ones(n_batch, n_KC)
# Context (US) constants
n_ext = 2
ext_wts = torch.ones(n_batch, n_ext)

# List to store losses
loss_hist = []


In [6]:
# Train first-order conditioning network
time_int = torch.arange(0, T_int + dt/10, dt)
n_ints = 2
for epoch in range(n_epochs):
    # Define the conditioned stimuli (CS) = odors
    r_KC_inds = torch.multinomial(KC_wts, n_ones)
    r_KC = torch.zeros(n_batch, n_KC)
    for b in range(n_batch):
        r_KC[b, r_KC_inds[b, :]] = 1
        
    # Randomly determine whether each trial is CS+ or CS- (0 = CS+, 1 = CS-)
    cs_trials = torch.multinomial(torch.ones(2), n_batch, replacement=True)
    # Define the unconditioned stimuli (US) = context
    r_ext = torch.multinomial(ext_wts, n_ext)
    r_ext[cs_trials == 0] = 0
    
    # Randomly determine the time of each stimulus presentation
    stim_times = torch.multinomial(stim_wts, n_ints, replacement=True) + 10
    
    # Make a list to store activities and valences
    r_outs = []
    vts = []
    
    # Initial KC->MBON weight value
    W_KC_MBON = 0.05
    
    # For each interval
    for i in range(n_ints):
        # Define a binary CS and US time series to mulitply the inputs by
        time_CS = torch.zeros(n_batch, time_int.size()[0])
        time_US = torch.zeros_like(time_CS)
        for b in range(n_batch):
            stim_inds = stim_times[b, i] + torch.arange(stim_len)
            time_CS[b, stim_inds] = 1
            if i == 0:
                stim_inds = stim_times[b, i] + torch.arange(stim_len) + stim_len
                time_US[b, stim_inds] = 1
        r_KCt = torch.einsum('bm, mbt -> bmt', r_KC, time_CS.repeat(n_KC, 1, 1))
        r_extt = torch.einsum('bm, mbt -> bmt', r_ext, time_US.repeat(n_ext, 1, 1))

        # Run the forward model
        r_int, W_KC_MBON, vt = classic_net(r_KCt, r_extt, time_int, W0=W_KC_MBON, batch_size=n_batch)
        r_outs.append(r_int)
        vts.append(vt)

    # Concatenate the activities and valences
    r_out_epoch = torch.cat((r_outs[0], r_outs[1]), dim=2)
    vt_epoch = torch.cat((vts[0], vts[1]), dim=1)

    # Define the target valence
    vt_opt = torch.cat((torch.zeros_like(time_CS), time_CS), dim=1)
    
    # Calculate the loss
    loss = cond_loss(vt_epoch, vt_opt, r_out_epoch[:, -classic_net.N_DAN:, :])
    print(loss)
    
    # Update the network parameters
    optimizer.zero_grad()
    torch.autograd.set_detect_anomaly(True)
    loss.backward()
    optimizer.step()
    
    # Print an update
    if epoch % 1000 == 0:
        print(epoch, loss.item())
    loss_hist.append(loss.item())
    
plt.plot(loss_hist)


tensor(31.0931, grad_fn=<MeanBackward0>)


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [10, 100, 1]], which is output 0 of ViewBackward, is at version 60; expected version 59 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!