In [31]:
# Import the required packages
import numpy as np
# from numpy.linalg import svd as svd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
# from collections import OrderedDict
import matplotlib.pyplot as plt

# Font sizes shared by the figures in this notebook: axis labels, titles, legends
label_font, title_font, legend_font = 18, 24, 12


# RNN Model Class and Optimizer

In [32]:
class Drosophila_RNN(nn.Module):
    """Rate-based RNN model of the Drosophila mushroom-body output circuitry.

    The recurrent population is ordered [MBONs | FBNs | DANs]. Kenyon-cell
    (KC) input reaches the MBONs through KC->MBON synapses that are NOT
    trained by backprop: within a trial they follow a dopamine-gated
    plasticity rule driven by KC and DAN eligibility traces (Jiang 2020).
    The context signal drives the FBNs. The trainable parameters are the
    recurrent weights, the context weights, the MBON readout and the biases.

    Parameters
        KC_size = number of Kenyon cells (odor inputs)
        MBON_size = number of mushroom-body output neurons
        DAN_size = number of dopaminergic neurons; the plasticity rule pairs
            DAN i with MBON i, so forward requires DAN_size == MBON_size
        FBN_size = number of feedback neurons (receive the context input)
        ext_size = dimension of the context signal
        out_size = dimension of the valence readout
        batch_size = unused; kept for backward compatibility (forward takes
            the batch size from its arguments/inputs)
        net_seed = unused; initialization is currently unseeded
    """

    def __init__(self, KC_size=200, MBON_size=20, DAN_size=20, FBN_size=60, ext_size=2, out_size=1, batch_size=30, net_seed=1234):
        super(Drosophila_RNN, self).__init__()
        # Model constants
        self.KC_MBON_min = 0.  # Minimum KC->MBON synaptic weight
        self.KC_MBON_max = 0.05  # Maximum KC->MBON synaptic weight
        self.tau_w = 5  # Time scale of KC->MBON LTD/LTP (plasticity)
        self.tau_r = 1  # Time scale of output circuitry activity
        # Layer sizes
        self.N_KC = KC_size
        self.N_MBON = MBON_size
        self.N_FBN = FBN_size  # needed by forward to place the context input
        self.N_DAN = DAN_size
        self.N_recur = FBN_size + MBON_size + DAN_size
        self.N_ext = ext_size
        self.N_out = out_size

        # Trainable parameters. Seeding is disabled (generator=None); pass a
        # seeded torch.Generator here to make the initialization reproducible.
        generator = None
        sqrt2 = torch.sqrt(torch.tensor(2, dtype=torch.float))
        W_MBON = torch.normal(torch.zeros((self.N_recur, MBON_size)), torch.sqrt(1 / (sqrt2 * MBON_size)), generator=generator)
        W_FBN = torch.normal(torch.zeros((self.N_recur, FBN_size)), torch.sqrt(1 / (sqrt2 * FBN_size)), generator=generator)
        W_DAN = torch.normal(torch.zeros((self.N_recur, DAN_size)), torch.sqrt(1 / (sqrt2 * DAN_size)), generator=generator)
        # Columns are source neurons in [MBON | FBN | DAN] order, rows are targets
        W_recur = torch.cat((W_MBON, W_FBN, W_DAN), dim=1)
        # No direct DAN->MBON connections. Zero them before wrapping in a
        # Parameter (in-place edits of a grad-requiring leaf raise); forward
        # also masks this block so it never receives a gradient.
        W_recur[:MBON_size, MBON_size + FBN_size:] = 0
        self.W_recur = nn.Parameter(W_recur, requires_grad=True)
        self.W_ext = nn.Parameter(torch.randn(FBN_size, ext_size), requires_grad=True)
        std_readout = 1 / torch.sqrt(torch.tensor(MBON_size, dtype=torch.float))
        W_readout = torch.normal(torch.zeros((out_size, MBON_size)), std_readout, generator=generator)
        self.W_readout = nn.Parameter(W_readout, requires_grad=True)
        self.bias = nn.Parameter(torch.ones(self.N_recur) * 0.1, requires_grad=True)
        # Unused placeholder, kept for compatibility with existing callers
        self.r_recur = []

    @staticmethod
    def _input_at(x, t):
        """Return input x at step t; x is static (batch, N) or time-varying (batch, N, T)."""
        return x[:, :, t] if x.dim() == 3 else x

    def forward(self, r_KC, r_ext, time, W0=None, batch_size=None):
        """Simulate the output circuitry for one batch of trials.

        Inputs
            r_KC = Kenyon-cell activity (odors): static, shape (batch, N_KC),
                or time-varying, shape (batch, N_KC, T)
            r_ext = context signal: static (batch, N_ext) or time-varying
                (batch, N_ext, T)
            time = 1-D time vector with T >= 2 uniformly spaced points
            W0 = initial KC->MBON weight, scalar or broadcastable to
                (batch, N_MBON, N_KC); defaults to self.KC_MBON_max
            batch_size = number of trials; defaults to r_KC.shape[0]

        Returns
            r_recur (batch, N_recur, T)
                = activities of the output circuitry, ordered [MBON | FBN | DAN]
            W_KC_MBON (batch, N_MBON, N_KC, T)
                = KC->MBON weights over time (dopaminergic plasticity)
            readout (batch, T) if N_out == 1, else (batch, N_out, T)
                = valence readout from the MBONs at every time step

        Raises
            ValueError if DAN_size != MBON_size or time has fewer than 2 points
        """
        if self.N_DAN != self.N_MBON:
            raise ValueError('DAN_size must equal MBON_size: each DAN gates the plasticity of one MBON')
        dtype, device = self.W_recur.dtype, self.W_recur.device
        r_KC = torch.as_tensor(r_KC, dtype=dtype, device=device)
        r_ext = torch.as_tensor(r_ext, dtype=dtype, device=device)
        time = torch.as_tensor(time)
        n_steps = time.numel()
        if n_steps < 2:
            raise ValueError('time must contain at least two points')
        dt = float(time[1] - time[0])  # assumes a uniform time grid
        if batch_size is None:
            batch_size = r_KC.shape[0]
        if W0 is None:
            W0 = self.KC_MBON_max

        # Index ranges of each population inside the recurrent layer
        mbon = slice(0, self.N_MBON)
        dan = slice(self.N_MBON + self.N_FBN, self.N_recur)

        # Effective recurrent weights with the DAN->MBON block masked out, so
        # those entries contribute nothing and get zero gradient.
        mask = torch.ones_like(self.W_recur)
        mask[mbon, dan] = 0
        W_recur = self.W_recur * mask

        # Initial rates: 0.1 everywhere except the MBONs, which start silent
        r = torch.full((batch_size, self.N_recur), 0.1, dtype=dtype, device=device)
        r[:, mbon] = 0
        # Eligibility traces start at the initial KC / DAN activity
        r_bar_KC = self._input_at(r_KC, 0).clone()
        r_bar_DAN = r[:, dan].clone()
        # Dynamic weight variable and KC->MBON weights
        w = torch.zeros((batch_size, self.N_DAN, self.N_KC), dtype=dtype, device=device)
        W = torch.zeros((batch_size, self.N_MBON, self.N_KC), dtype=dtype, device=device) + torch.as_tensor(W0, dtype=dtype, device=device)
        I_DAN = torch.zeros((batch_size, self.N_DAN), dtype=dtype, device=device)  # DANs get no external input

        # Histories are collected in lists and stacked at the end: writing into
        # preallocated tensors in place would invalidate tensors saved for backward.
        r_hist, W_hist = [r], [W]
        for t in range(n_steps - 1):
            r_KC_t = self._input_at(r_KC, t)
            r_ext_t = self._input_at(r_ext, t)

            # External input: KC drive onto MBONs, context drive onto FBNs
            I_MBON = torch.einsum('bmk,bk->bm', W, r_KC_t)
            I_FBN = r_ext_t @ self.W_ext.T
            I = torch.cat((I_MBON, I_FBN, I_DAN), dim=1)

            # Rate dynamics: tau_r dr/dt = -r + relu(W_recur r + b + I).
            # The relu acts on the drive only; wrapping -r in it would make
            # dr/dt >= 0 so rates could never decay.
            drive = r @ W_recur.T + self.bias + I
            r_next = r + (-r + F.relu(drive)) * dt / self.tau_r

            # Eligibility traces: low-pass filtered KC and DAN rates (forward Euler)
            r_DAN = r[:, dan]
            r_bar_KC = r_bar_KC + (r_KC_t - r_bar_KC) * dt / self.tau_w
            r_bar_DAN = r_bar_DAN + (r_DAN - r_bar_DAN) * dt / self.tau_w
            # Dynamic weight update: r_bar_DAN x r_KC - r_DAN x r_bar_KC
            dw = torch.einsum('bd,bk->bdk', r_bar_DAN, r_KC_t) - torch.einsum('bd,bk->bdk', r_DAN, r_bar_KC)
            w = w + dw * dt
            # KC->MBON weights relax toward w, then are clipped to [min, max]
            W = W + (-W + w) * dt / self.tau_w
            W = torch.clamp(W, self.KC_MBON_min, self.KC_MBON_max)

            r = r_next
            r_hist.append(r)
            W_hist.append(W)

        r_recur = torch.stack(r_hist, dim=-1)
        W_KC_MBON = torch.stack(W_hist, dim=-1)
        # Linear valence readout from the MBONs at every time step
        readout = torch.einsum('om,bmt->bot', self.W_readout, r_recur[:, mbon, :])
        if self.N_out == 1:
            readout = readout[:, 0, :]
        return r_recur, W_KC_MBON, readout
            
        
    
    
# torch.einsum("abc, cd -> abd", ht, w)
# a: batch size
# b: time steps (not there in for loops)
# c: dims
# d: dims (==c if recurrent weights)

# Clipping weights between [0, 0.05]
# https://discuss.pytorch.org/t/how-to-do-constrained-optimization-in-pytorch/60122
# https://discuss.pytorch.org/t/set-constraints-on-parameters-or-layers/23620
# https://discuss.pytorch.org/t/restrict-range-of-variable-during-gradient-descent/1933/4

# Setting DAN->MBON weights to zero
# https://pytorch.org/docs/stable/generated/torch.triu.html

# Broadcasting using einsum
# https://github.com/pytorch/pytorch/issues/15671


In [36]:
# Demo: binary masks in which each entry is 1 with probability p = 0.1
# NumPy version
A = np.random.sample((5, 5))
W = np.where(A < 0.1, 1.0, 0.0)
print(W)

# PyTorch version
A = torch.rand((5, 5))
W = torch.zeros_like(A).masked_fill(A < 0.1, 1.0)
print(W)

# Demo: clamping weights into a fixed range
A = torch.rand((8, 8)) - 0.3
W = A.clamp(0, 0.5)
print(W)

# Demo: odors (CS in the paper) with exactly 10% of KCs active; multinomial
# samples without replacement, so each row holds distinct active-KC indices
n_KC = 200
n_ones = int(n_KC * 0.1)
n_batches = 30
weights = torch.ones(n_batches, n_KC)
r_ints = torch.multinomial(weights, n_ones)


[[0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
tensor([[0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.]])
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0829, 0.0160, 0.4969],
        [0.1757, 0.1152, 0.0000, 0.0122, 0.0454, 0.0000, 0.4526, 0.1721],
        [0.2412, 0.0461, 0.0000, 0.5000, 0.2354, 0.4359, 0.4684, 0.0000],
        [0.0000, 0.0942, 0.0000, 0.0000, 0.5000, 0.3302, 0.4791, 0.1849],
        [0.0000, 0.0000, 0.2742, 0.1962, 0.0555, 0.0341, 0.1114, 0.2181],
        [0.0314, 0.0000, 0.1716, 0.0000, 0.3559, 0.2366, 0.0000, 0.5000],
        [0.3959, 0.1571, 0.0987, 0.3999, 0.1285, 0.5000, 0.0554, 0.5000],
        [0.3156, 0.0000, 0.4285, 0.5000, 0.3415, 0.0000, 0.5000, 0.4421]])


# Conditioning Tasks

In [38]:
# Initialize the network.
# batch_size is left at its default: n_batch is only defined in the
# simulation-constants cell further down, so passing it here raised a
# NameError on Restart & Run All. The training loop passes the batch size
# to the forward call explicitly.
classic_net = Drosophila_RNN()
for name, param in classic_net.named_parameters():
    print(name, tuple(param.shape))

# Define the model's optimizer
lr = 0.001
optimizer = optim.RMSprop(classic_net.parameters(), lr=lr)


torch.Size([100, 100])
torch.Size([60, 2])
torch.Size([1, 20])
torch.Size([100])
20


In [35]:
# Define the cost function for conditioning tasks
def cond_loss(vt, vt_opt, r_DAN, lam=0.1, DAN_baseline=0.1):
    """ Calculates the loss for conditioning tasks.
    Composed of an MSE cost on the difference between the network's valence
    output and the target valence, plus a regularization cost that penalizes
    dopaminergic activity above baseline. Reference Eq. (3) and (9) in Jiang 2020.

    Parameters
        vt = time-dependent valence output of the network, expected (batch, T)
        vt_opt = target valence (torch tensor broadcastable to vt)
        r_DAN = time series of DAN activities, expected (batch, N_DAN, T)
        lam = regularization constant
        DAN_baseline = activity level above which DAN firing is penalized

    Returns
        loss_tot = scalar loss; it keeps its autograd graph so that
            loss_tot.backward() reaches the network parameters
    """
    # MSE of the valence, averaged over time -> one value per trial
    v_loss = torch.mean((vt - vt_opt) ** 2, dim=1)

    # Squared excess DAN activity, summed over neurons and averaged over time
    r_excess = F.relu(r_DAN - DAN_baseline) ** 2
    r_loss = lam * torch.mean(torch.sum(r_excess, dim=1), dim=1)

    # Average the per-trial loss over the batch. Do not touch requires_grad:
    # it cannot be changed on a non-leaf tensor, and disabling it would stop
    # backprop.
    return torch.mean(v_loss + r_loss)

# https://discuss.pytorch.org/t/solved-what-is-the-correct-way-to-implement-custom-loss-function/3568/8
# https://discuss.pytorch.org/t/build-your-own-loss-function-in-pytorch/235/41
# https://discuss.pytorch.org/t/custom-loss-functions/29387


In [37]:
# Initialize the simulation constants.
# Batch and epoch sizes come first because the sampling weights below are
# sized by n_batch (full-scale run: n_batch = 30, n_epochs = 2000).
n_batch = 2
n_epochs = 5
T_int = 30  # Length of one trial interval, in the same time units as dt
T_stim = 2  # Stimulus duration (not yet used by the training loop)
dt = 0.5  # Simulation time step
# Uniform sampling weights over 12 candidate stimulus onsets per trial
stim_wts = torch.ones(n_batch, 12)

# Odor (CS) constants: 10% of the KCs are active for each odor
n_KC = 200
n_ones = int(n_KC * 0.1)
KC_wts = torch.ones(n_batch, n_KC)
# Context (US) constants
ext_wts = torch.ones(n_batch, 2)

# List to store losses
loss_hist = []


In [None]:
# Train first-order conditioning network
# NOTE(review): this cell is unfinished and does not run yet:
#   - the `r_KCt = ` and `vt_opt = ` assignments below are incomplete
#     (SyntaxError); the stimulus time courses and the target valence still
#     need to be defined;
#   - `time` and `Cond_Net` are not defined in any visible cell
#     (presumably `time_int` and `classic_net`) — TODO confirm and fix.
# Two intervals of length T_int; adding dt/10 makes arange include the endpoint 2*T_int
time_int = torch.arange(0, 2*T_int + dt/10, dt)
n_ints = 2
for epoch in range(n_epochs):
    # Randomly determine whether each trial is CS+ or CS- (0 = CS+, 1 = CS-)
    # FIXME: torch.multinomial samples WITHOUT replacement by default, so with
    # n_batch == 2 every batch is exactly one CS+ and one CS- trial, and
    # n_batch > 2 raises; pass replacement=True for independent draws.
    cs_trials = torch.multinomial(torch.ones(2), n_batch)
    # Define the conditioned stimuli (CS) = odors
    # FIXME: this yields n_ones KC *indices* per trial, shape (n_batch, n_ones),
    # not an activity vector over all KCs; convert to a 0/1 pattern of length
    # n_KC before passing it to the network.
    r_KC = torch.multinomial(KC_wts, n_ones)
    # Define the unconditioned stimuli (US) = context
    # FIXME: this yields a permutation of {0, 1} per trial (integer indices),
    # not a context signal.
    r_ext = torch.multinomial(ext_wts, 2)
    # NOTE(review): zeroes the US on CS+ trials (0 = CS+ above); if the US is
    # meant to accompany the CS+, this mask looks inverted — confirm.
    r_ext[cs_trials == 0] = 0
    # Randomly determine the time of each stimulus presentation
    # (n_ints onset indices drawn from the columns of stim_wts, offset by 10)
    stim_times = torch.multinomial(stim_wts, n_ints) + 10
    # Define the stimulus time vectors
    # FIXME: on a torch tensor `.size` is a method, so `time_int.size` is not an
    # int; use time_int.numel() or len(time_int).
    time_KC = torch.zeros(n_batch, time_int.size)
    # TODO: incomplete assignment (SyntaxError) — build the time-varying odor input here
    r_KCt = 
    
#     X, Xcontext, Y, mu= gen_batch(batch_size)
#     X = torch.cat((X, Xcontext),-1)
    
    # Define the target valence
    # TODO: incomplete assignment (SyntaxError) — define the target valence per trial
    vt_opt = 
    # Run the forward model
#     Yhat, _ = rnn(X)
    # FIXME: `time` is undefined here; presumably `time_int` was intended
    r_out, W_KC_MBON, vt = classic_net(r_KC, r_ext, time, batch_size=n_batch)
    
    # Calculate the loss
#     loss = loss_fn(Y, Yhat)
    # FIXME: `Cond_Net` is undefined; presumably `classic_net` was intended.
    # The slice selects the last N_DAN recurrent units (the DAN population).
    loss = cond_loss(vt, vt_opt, r_out[:, -Cond_Net.N_DAN:, :])
    
    # Update the network parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # Print an update
    # NOTE(review): with n_epochs = 5 this prints only epoch 0; scale the
    # interval with n_epochs.
    if epoch % 1000 == 0:
        print(epoch, loss.item())
    loss_hist.append(loss.item())
    
# TODO: add axis labels/title (label_font, title_font are defined for this)
plt.plot(loss_hist)
