<a href="https://colab.research.google.com/github/elliottabe/RF_Workshop/blob/main/Workshop_notebook.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Installing repo and dependencies if using colab. Skip to imports if running locally
# Colab setup: upgrade matplotlib, clone the workshop repository into ./RF_workshop
# and install the packages listed in its requirements.txt.
# Command output is sent to /dev/null, so a failed install is silent; remove the
# redirect when debugging.
!pip install -U matplotlib &> /dev/null
!git clone https://github.com/elliottabe/RF_workshop.git &> /dev/null
!pip install -r ./RF_workshop/requirements.txt &> /dev/null
# Alternative (disabled): install the repository directly with pip instead of cloning it.
# !pip install git+https://github.com/elliottabe/RF_workshop.git &> /dev/null

In [None]:
import os

import gdown

file_id = '1AUYAmfQp3Hh25uf_mohaT3N3qLXKlMeo' # File id to example data
output_file = 'data.h5'

# Download the example data from Google Drive only when it is not already on disk,
# so that re-running the notebook does not fetch the same file again.
# Delete data.h5 to force a fresh download.
if not os.path.exists(output_file):
    gdown.download(f"https://drive.google.com/uc?id={file_id}", output_file)

# Import modules

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
# Ask cuDNN to benchmark its available algorithms and pick the fastest (GPU runs only)
torch.backends.cudnn.benchmark = True
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Show which device was selected
print(device)
import RF_workshop.io_dict_to_hdf5 as ioh5

##### Plotting settings ######
import matplotlib as mpl

# Figure style shared by every plot in the notebook, applied in one call below.
_rc_overrides = {
    # fonts
    'font.size':         10,
    'xtick.labelsize':   10,
    'ytick.labelsize':   10,
    # axes: thicker frame, no top/right spines
    'axes.linewidth':    2,
    'axes.spines.right': False,
    'axes.spines.top':   False,
    # tick marks
    'xtick.major.size':  3,
    'xtick.major.width': 2,
    'ytick.major.size':  3,
    'ytick.major.width': 2,
    # output
    'pdf.fonttype':      42,
    'figure.facecolor':  'white',
}
mpl.rcParams.update(_rc_overrides)


In [None]:
# Read the downloaded HDF5 file with the workshop's io_dict_to_hdf5 helper
data = ioh5.load('./data.h5')


In [None]:
# List the entries stored in the loaded file
data.keys()

In [None]:
# Regression targets: axis 0 is time and axis 1 is units (used below for nT and Num_units)
model_nsp = data['model_nsp']
# Regression inputs: indexed by time along axis 0, followed by two image axes
model_vid_sm = data['model_vid_sm']

In [None]:
# Compare the shapes of the targets and the inputs
model_nsp.shape, model_vid_sm.shape

# Data prep

Receptive field (RF) mapping is classically done with reverse correlation (spike triggered averages). The basics can be done with simple linear algebra, but the reverse correlation becomes computationally expensive when dealing with high dimensional inputs. In this workshop, we will cover how to map RFs with a simple neural network. 

In [None]:
# Import train/test split functions
from sklearn.model_selection import train_test_split, GroupShuffleSplit

Due to the temporal correlations in visual data, we do a group shuffle split: the recording is divided into non-overlapping chunks that each span 10\% of the data, and whole chunks are randomly assigned to either the train or the test dataset. 

In [None]:
NKfold = 1 # Number of shuffle splits to generate
train_size = 0.8 # Fraction of the groups used for the training set
frac = 0.1 # Fraction of the recording covered by each contiguous chunk (group)
gss = GroupShuffleSplit(n_splits=NKfold, train_size=train_size, random_state=42)
nT = model_nsp.shape[0] # Number of timepoints

# Label every timepoint with the contiguous chunk it belongs to (labels 1..n_chunks).
# The last chunk edge is pinned to nT so that every timepoint receives a label even
# when 1/frac is not a whole number; for frac=0.1 the labels are unchanged.
n_chunks = max(1, int(round(1/frac)))
chunk_edges = [int((frac*i)*nT) for i in range(n_chunks+1)]
chunk_edges[-1] = nT
groups = np.repeat(np.arange(1, n_chunks+1, dtype=float), np.diff(chunk_edges)) # defining groups
assert len(groups) == nT, 'every timepoint needs exactly one group label'

# Create list of train and test indices (one entry per split)
splits = list(gss.split(np.arange(nT), groups=groups))
train_idx_list = [tr for tr, _ in splits]
test_idx_list = [te for _, te in splits]

# Defining train and test datasets, with option to crop images.
cropn = 0 # Number of pixels removed from every image border (0 disables cropping)
train_idx = train_idx_list[0]
test_idx = test_idx_list[0]
if cropn>0:
    xtrain = model_vid_sm[train_idx][:,cropn:-cropn,cropn:-cropn]
    xtest = model_vid_sm[test_idx][:,cropn:-cropn,cropn:-cropn]
else:
    xtrain = model_vid_sm[train_idx]
    xtest = model_vid_sm[test_idx]
im_size = xtrain.shape[1:] # image dimensions, kept to reshape the weights back into images
# Flatten every frame into one feature vector
xtrain = xtrain.reshape(len(train_idx),-1)
xtest = xtest.reshape(len(test_idx),-1)
ytrain = model_nsp[train_idx]
ytest = model_nsp[test_idx]

xtrain.shape, ytrain.shape, xtest.shape, ytest.shape

In [None]:
# PyTorch works on tensors: convert each numpy array to a float tensor and move it
# to the selected device (the GPU when one is available).
xtr, xte, ytr, yte = (
    torch.from_numpy(arr).float().to(device)
    for arr in (xtrain, xtest, ytrain, ytest)
)

# Dimensions needed to build the model and to reshape its weights later
input_size, output_size = xtr.shape[1], ytr.shape[1]
Num_units = model_nsp.shape[1]

# Create Pytorch Model

PyTorch is a library for building and training deep neural networks. Here we use its predefined layers to create a simple generalized linear model (GLM). 

A model in pytorch is defined in a couple of different ways. The simplest example, used here, is with the nn.Sequential function. This function constructs a model based on predefined operations and pushes data through them sequentially. 

In this case we construct a single linear layer with an output ReLU nonlinearity. 

To train the model parameters, PyTorch uses automatic differentiation (autograd) to compute the gradient of the loss with respect to each parameter. The rule that turns those gradients into parameter updates is defined using the ```torch.optim``` module. A commonly used optimizer is the Adam algorithm. 

In this simple case with a single linear layer, the input/output function is defined as: 

$y = f(Wx + b)$, where x is the inputs, y is the outputs, and W,b are learnable parameters. $f$ is a nonlinear function, in this case ReLU. 

W is a weight matrix which after training represents the receptive fields. 

In [None]:
# One linear layer followed by a ReLU, i.e. y = ReLU(Wx + b)
model = nn.Sequential(
    nn.Linear(in_features=input_size, out_features=output_size),
    nn.ReLU(),
)
model = model.to(device)

# The optimizer receives the parameters to learn. The learning rate (lr) sets how large a
# step is taken along the gradient; weight_decay is the strength of the L2 penalty.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=0.1)

# Loss function to minimize: mean-squared error (MSE) between prediction and target.
criterion = nn.MSELoss()


In [None]:
# Printing the model shows that it is built from a single Linear layer followed by a ReLU activation function.
print(model)

In [None]:
# 'model' holds every parameter used to learn the mapping from inputs to outputs.
# List each learnable tensor by name together with its shape.
for param_name, param in model.named_parameters():
    print(f'{param_name}: {param.shape}')

In [None]:
# Inspecting the weights further, we see the flag requires_grad=True, meaning every operation on them is tracked for the gradient calculation.
print(model[0].weight)
# Before the backward pass has been run there is no gradient information stored on the parameters
print(model[0].weight.grad)

If we want to add regularization into the model we can calculate additional terms and add them to the loss. L1 and L2 regularization are common in regression. L2 regularization (ridge regression) is already implemented in the optimizers and is used by adding a weight_decay value. 

In [None]:
# Forward pass: model prediction for the training inputs
yhat = model(xtr)
print(yhat.shape)
# Mean-squared error between prediction and target
loss_value = criterion(yhat, ytr)
print(loss_value)

# L2 regularization is handled by the optimizer through weight_decay
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=0.1)
# L1 regularization is added to the loss by hand
l1_alpha = 1e-4 # strength of L1 regularization
l1_penalty = l1_alpha * torch.norm(model[0].weight, p=1)
loss_value = criterion(yhat, ytr) + l1_penalty


In [None]:
# To update the parameters of our model we must first run the backward pass.

# Clear any previously stored gradients before calculating the backward pass.
optimizer.zero_grad()
# Backward pass: compute the gradient of the loss with respect to the parameters
loss_value.backward()

# Now the parameters have a gradient value
print('Gradient of weights:',model[0].weight.grad)
print('Parameters before update:',model[0].weight)
# Update the parameters using the stored gradients
optimizer.step()
print('Parameters after update:',model[0].weight)




As we can see, the weights of the model have changed. Now we can place these operations within a for loop and iterate through our data multiple times. 

Terminology:
- Epoch: A single run through the entire dataset
- batch (minibatch): when not all of the data can be loaded onto the GPU at the same time, it is processed in chunks. A batch is one of these chunks. 
- batch size: the number of samples in each batch, which are processed in parallel. For example, data is often of the shape (batch_size, time, features)

## Full Training loop

In [None]:
Nepochs = 2000 # Number of epochs (each epoch is one full-batch pass over the training set)
l2_lambda_list = [.05,.1,1] # List of L2 regularization strengths to iterate over
l1_alpha = 0.0001 # Strength of L1 regularization
min_loss = np.inf # Lowest validation loss seen so far
best_model_path = './RF_l2_min.pt' # Where the parameters of the best model are stored
criterion = nn.MSELoss() # Mean-squared error between prediction and target

##### Use tqdm to visualize progress #####
with tqdm(initial=0,total=len(l2_lambda_list), dynamic_ncols=False, miniters=1) as tq:
    ##### Loop through the different L2 values
    for l2_lambda in l2_lambda_list:
        ##### For each L2 value we define a new model to train #####
        model = nn.Sequential(nn.Linear(input_size,output_size),
                              nn.ReLU()).to(device)
        ##### Define the optimizer for learning (weight_decay applies the L2 penalty) #####
        optimizer = torch.optim.Adam(model.parameters(), lr=.001, weight_decay=l2_lambda)

        ##### Training loop ####
        model.train()
        for epoch in tqdm(range(Nepochs),leave=False):
            optimizer.zero_grad()
            yhat = model(xtr)
            train_loss = criterion(yhat, ytr) + l1_alpha*torch.norm(model[0].weight,p=1)
            train_loss.backward()
            optimizer.step()

        ##### Check on validation data ####
        # Evaluation needs no gradients: switch to eval mode and disable autograd tracking
        # so no graph is built, then keep the loss as a plain Python float.
        model.eval()
        with torch.no_grad():
            yhat = model(xte)
            ##### Calculate validation loss #####
            val_loss = (criterion(yhat, yte) + l1_alpha*torch.norm(model[0].weight,p=1)).item()

        #### If validation loss is a new minimum, remember it and save the model parameters ####
        if val_loss < min_loss:
            l2_lambda_min = l2_lambda
            torch.save(model.state_dict(),best_model_path)
            min_loss = val_loss

        ##### Update visualization progress ####
        tq.set_postfix(val_loss='{:05.3f}'.format(val_loss),train_loss='{:05.3f}'.format(train_loss.item()),min_loss='{:05.3f}'.format(min_loss))
        tq.update()

##### load best model #####
# Without a finite validation loss nothing was saved in this run; stop here rather than
# silently loading a checkpoint left over from an earlier run.
if not np.isfinite(min_loss):
    raise RuntimeError('No model reached a finite validation loss; nothing to load.')
load_model = torch.load(best_model_path, map_location=device)
model.load_state_dict(load_model)

In [None]:
##### The weights of our model represent the visual RFs, so let's extract them. #####

# Detach the weights from the autograd graph, move them to the cpu and convert them from
# tensor to numpy; then fold every unit's weight vector back into an image.
rf_shape = (Num_units, im_size[0], im_size[1])
RF = model[0].weight.detach().cpu().numpy().reshape(rf_shape)
RF.shape

Now that we have the RFs, let's plot them to see what they look like.

In [None]:
# One panel per unit, 10 panels per row; the number of rows follows the number of units
# so the grid always has room for every RF.
n_rf = RF.shape[0]
n_cols = 10
n_rows = max(1, int(np.ceil(n_rf/n_cols)))
fig, axs = plt.subplots(n_rows,n_cols,figsize=(20,20),squeeze=False)
flat_axs = axs.flatten()
for n in range(n_rf):
    ax = flat_axs[n]
    # Symmetric color limits so that zero maps to the middle of the diverging colormap
    cmax = np.max(np.abs(RF[n]))
    ax.imshow(RF[n],cmap='RdBu_r',vmin=-cmax,vmax=cmax)
    ax.axis('off')
    ax.set_title(f'{n}')
# Hide the panels of the grid that are not used by any unit
for ax in flat_axs[n_rf:]:
    ax.axis('off')

In the data provided, we also have the calculated RFs from Parker, Abe, et al. 2022 and we can compare. 

Note: In the paper we thresholded out some neurons due to firing rate and duplication. The data we used to map the receptive fields today have not been filtered. 

In [None]:
# Reference RFs stored in the data file, used for comparison with the RFs fitted above
RF_vis = data['RF_vis']

# Inspect the array layout; the plot below indexes it as RF_vis[unit, 2]
RF_vis.shape

In [None]:
# One panel per unit, 10 panels per row; the number of rows follows the number of units
# so the grid always has room for every RF.
n_rf_vis = RF_vis.shape[0]
n_cols = 10
n_rows = max(1, int(np.ceil(n_rf_vis/n_cols)))
fig, axs = plt.subplots(n_rows,n_cols,figsize=(20,20),squeeze=False)
flat_axs = axs.flatten()
for n in range(n_rf_vis):
    ax = flat_axs[n]
    # NOTE(review): the color limits use every entry of RF_vis[n], although only index 2
    # along axis 1 is displayed - confirm this shared scale is intended.
    cmax = np.max(np.abs(RF_vis[n]))
    ax.imshow(RF_vis[n,2],cmap='RdBu_r',vmin=-cmax,vmax=cmax)
    ax.axis('off')
    ax.set_title(f'{n}')
# Hide the panels of the grid that are not used by any unit
for ax in flat_axs[n_rf_vis:]:
    ax.axis('off')

# Traditional STA

In [None]:
# make sure there are no nans in the data
xtr[torch.isnan(xtr)] = 0

# Compute the STA: for every unit, the sum of the input frames weighted by its response...
sta = xtr.T @ ytr
# ...normalized by the unit's summed response.
resp_sum = torch.sum(ytr,dim=0,keepdim=True)
# A summed response of exactly zero would make the division return NaN/inf for that
# unit; divide by 1 instead so its STA stays finite.
resp_sum = torch.where(resp_sum == 0, torch.ones_like(resp_sum), resp_sum)
sta = sta/resp_sum

# Reshape for visualization: one image per unit, moved to the cpu as a numpy array
sta_all = sta.T.reshape(Num_units,im_size[0],im_size[1]).cpu().numpy()


In [None]:
# Plotting the RFs
# One panel per unit, 10 panels per row; the number of rows follows the number of units
# so the grid always has room for every STA.
n_sta = sta_all.shape[0]
n_cols = 10
n_rows = max(1, int(np.ceil(n_sta/n_cols)))
fig, axs = plt.subplots(n_rows,n_cols,figsize=(20,20),squeeze=False)
flat_axs = axs.flatten()
for n in range(n_sta):
    ax = flat_axs[n]
    # Symmetric color limits so that zero maps to the middle of the diverging colormap
    cmax = np.max(np.abs(sta_all[n]))
    ax.imshow(sta_all[n],cmap='RdBu_r',vmin=-cmax,vmax=cmax)
    ax.axis('off')
    ax.set_title(f'{n}')
# Hide the panels of the grid that are not used by any unit
for ax in flat_axs[n_sta:]:
    ax.axis('off')

# Small extensions

So far we have only calculated an RF for a single time point. To create a spatio-temporal receptive field we can add additional time-delayed inputs to our data. 

In [None]:
# Rebuild the train/test split with the same settings (and random_state) as in the Data prep section.
NKfold = 1 # Number of shuffle splits to generate
train_size = 0.8 # Fraction of the groups used for the training set
frac = 0.1 # Fraction of the recording covered by each contiguous chunk (group)
gss = GroupShuffleSplit(n_splits=NKfold, train_size=train_size, random_state=42)
nT = model_nsp.shape[0] # Number of timepoints

# Label every timepoint with the contiguous chunk it belongs to (labels 1..n_chunks).
# The last chunk edge is pinned to nT so that every timepoint receives a label even
# when 1/frac is not a whole number; for frac=0.1 the labels are unchanged.
n_chunks = max(1, int(round(1/frac)))
chunk_edges = [int((frac*i)*nT) for i in range(n_chunks+1)]
chunk_edges[-1] = nT
groups = np.repeat(np.arange(1, n_chunks+1, dtype=float), np.diff(chunk_edges)) # defining groups
assert len(groups) == nT, 'every timepoint needs exactly one group label'

# Create list of train and test indices (one entry per split)
splits = list(gss.split(np.arange(nT), groups=groups))
train_idx_list = [tr for tr, _ in splits]
test_idx_list = [te for _, te in splits]

# Defining train and test datasets, with option to crop images.
cropn = 0 # Number of pixels removed from every image border (0 disables cropping)
train_idx = train_idx_list[0]
test_idx = test_idx_list[0]
if cropn>0:
    xtrain = model_vid_sm[train_idx][:,cropn:-cropn,cropn:-cropn]
    xtest = model_vid_sm[test_idx][:,cropn:-cropn,cropn:-cropn]
else:
    xtrain = model_vid_sm[train_idx]
    xtest = model_vid_sm[test_idx]
im_size = xtrain.shape[1:] # image dimensions, kept to reshape the weights back into images
# Flatten every frame into one feature vector
xtrain = xtrain.reshape(len(train_idx),-1)
xtest = xtest.reshape(len(test_idx),-1)
ytrain = model_nsp[train_idx]
ytest = model_nsp[test_idx]

xtrain.shape, ytrain.shape, xtest.shape, ytest.shape

In [None]:
lag_list = [-1,0,1] # Frame shifts to stack as additional inputs

# Build the lagged inputs from the continuous video and only then pick the train/test
# rows. Shifting the already-split arrays would pair frames across the boundaries of
# non-adjacent chunks, and re-running the cell would stack the lags a second time.
vid_full = np.asarray(model_vid_sm)
if cropn>0:
    vid_full = vid_full[:,cropn:-cropn,cropn:-cropn]
vid_flat = vid_full.reshape(vid_full.shape[0],-1)

# np.roll(a, k, axis=0) moves every row k positions down, so row t of the shifted copy
# holds frame t-k. Rows shifted past either end wrap around to the other end of the recording.
vid_lagged = np.hstack([np.roll(vid_flat, nframes, axis=0) for nframes in lag_list])
xtrain = vid_lagged[train_idx]
xtest = vid_lagged[test_idx]

In [None]:
# PyTorch works on tensors: convert each numpy array to a float tensor and move it
# to the selected device (the GPU when one is available).
xtr, xte, ytr, yte = (
    torch.from_numpy(arr).float().to(device)
    for arr in (xtrain, xtest, ytrain, ytest)
)

# Dimensions needed to build the model and to reshape its weights later
input_size, output_size = xtr.shape[1], ytr.shape[1]
Num_units = model_nsp.shape[1]

In [None]:
Nepochs = 2000 # Number of epochs (each epoch is one full-batch pass over the training set)
l2_lambda_list = [.1] # List of L2 regularization strengths to iterate over
l1_alpha = 0.0001 # Strength of L1 regularization
min_loss = np.inf # Lowest validation loss seen so far
best_model_path = './RF_l2_min_spatiotemporal.pt' # Where the parameters of the best model are stored
criterion = nn.MSELoss() # Mean-squared error between prediction and target

##### Use tqdm to visualize progress #####
with tqdm(initial=0,total=len(l2_lambda_list), dynamic_ncols=False, miniters=1) as tq:
    ##### Loop through the different L2 values
    for l2_lambda in l2_lambda_list:
        ##### For each L2 value we define a new model to train #####
        model = nn.Sequential(nn.Linear(input_size,output_size),
                              nn.ReLU()).to(device)
        ##### Define the optimizer for learning (weight_decay applies the L2 penalty) #####
        optimizer = torch.optim.Adam(model.parameters(), lr=.001, weight_decay=l2_lambda)

        ##### Training loop ####
        model.train()
        for epoch in tqdm(range(Nepochs),leave=False):
            optimizer.zero_grad()
            yhat = model(xtr)
            train_loss = criterion(yhat, ytr) + l1_alpha*torch.norm(model[0].weight,p=1)
            train_loss.backward()
            optimizer.step()

        ##### Check on validation data ####
        # Evaluation needs no gradients: switch to eval mode and disable autograd tracking
        # so no graph is built, then keep the loss as a plain Python float.
        model.eval()
        with torch.no_grad():
            yhat = model(xte)
            ##### Calculate validation loss #####
            val_loss = (criterion(yhat, yte) + l1_alpha*torch.norm(model[0].weight,p=1)).item()

        #### If validation loss is a new minimum, remember it and save the model parameters ####
        if val_loss < min_loss:
            l2_lambda_min = l2_lambda
            torch.save(model.state_dict(),best_model_path)
            min_loss = val_loss

        ##### Update visualization progress ####
        tq.set_postfix(val_loss='{:05.3f}'.format(val_loss),train_loss='{:05.3f}'.format(train_loss.item()),min_loss='{:05.3f}'.format(min_loss))
        tq.update()

##### load best model #####
# Without a finite validation loss nothing was saved in this run; stop here rather than
# silently loading a checkpoint left over from an earlier run.
if not np.isfinite(min_loss):
    raise RuntimeError('No model reached a finite validation loss; nothing to load.')
load_model = torch.load(best_model_path, map_location=device)
model.load_state_dict(load_model)

In [None]:
# Grabbing RFs from model: detach the weights, move them to the cpu as a numpy array and
# fold every unit's weight vector back into one image per entry of lag_list.
rf_st_shape = (Num_units, len(lag_list), im_size[0], im_size[1])
RF_ST = model[0].weight.detach().cpu().numpy().reshape(rf_st_shape)
RF_ST.shape

In [None]:
n_neurons = [4,7,63,120,125] # Indices of the units to display
# One row per selected unit, one column per entry of lag_list
fig, axs = plt.subplots(len(n_neurons),len(lag_list),figsize=(15,12),squeeze=False)
for row, unit in enumerate(n_neurons):
    # Index RF_ST with the unit id, not the row counter, so that the image shown matches
    # the unit named in the title. The color scale is shared across all lags of a unit.
    cmax = np.max(np.abs(RF_ST[unit]))
    for k, lag in enumerate(lag_list):
        ax = axs[row,k]
        ax.imshow(RF_ST[unit,k],cmap='RdBu_r',vmin=-cmax,vmax=cmax)
        ax.axis('off')
        ax.set_title('Unit:{},t={}'.format(unit,lag))