<a href="https://colab.research.google.com/github/elliottabe/RF_Workshop/blob/main/Workshop_notebook.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Installing repo and dependencies if using colab. Skip to imports if running locally
!pip install -U matplotlib &> /dev/null
!git clone https://github.com/elliottabe/RF_workshop.git &> /dev/null
!pip install -r ./RF_workshop/requirements.txt &> /dev/null
# !pip install git+https://github.com/elliottabe/RF_workshop.git &> /dev/null

In [None]:
import os
import gdown

file_id = '1AUYAmfQp3Hh25uf_mohaT3N3qLXKlMeo' # File id to example data
output_file = 'data.h5'

# Fetch the example data from Google Drive only if it is not already on disk,
# so re-running the notebook does not download it again.
if not os.path.exists(output_file):
    gdown.download(f"https://drive.google.com/uc?id={file_id}", output_file)

# Import modules

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
import io_dict_to_hdf5 as ioh5

##### Plotting settings ######
import matplotlib as mpl

mpl.rcParams.update({'font.size':         10,
                     'axes.linewidth':    2,
                     'xtick.major.size':  3,
                     'xtick.major.width': 2,
                     'ytick.major.size':  3,
                     'ytick.major.width': 2,
                     'axes.spines.right': False,
                     'axes.spines.top':   False,
                     'pdf.fonttype':      42,
                     'xtick.labelsize':   10,
                     'ytick.labelsize':   10,
                     'figure.facecolor': 'white'

                    })


In [None]:
# Read the example file fetched by the download cell (saved as data.h5).
DATA_PATH = './data.h5'
data = ioh5.load(DATA_PATH)


In [None]:
# List the entries stored in the example file.
data.keys()

In [None]:
# Pull out the two arrays the rest of the notebook works with: responses
# (`model_nsp`, presumably spike counts per unit) and the video (`model_vid_sm`).
model_nsp, model_vid_sm = data['model_nsp'], data['model_vid_sm']

In [None]:
# The train/test split indexes responses and video with the same row indices,
# so both arrays must have the same length along axis 0.
assert model_nsp.shape[0] == model_vid_sm.shape[0], 'model_nsp and model_vid_sm differ in length along axis 0'
model_nsp.shape, model_vid_sm.shape

# Data prep

In [None]:
from sklearn.model_selection import train_test_split, GroupShuffleSplit

In [None]:
NKfold = 1             # number of random group-splits to draw (only the first is used)
test_train_size = 0.8  # fraction of groups assigned to the training set
frac = 0.1             # each group spans this fraction of the frames
gss = GroupShuffleSplit(n_splits=NKfold, train_size=test_train_size, random_state=42)
nT = model_nsp.shape[0]
n_groups = int(1/frac)
# Label the frames with contiguous chunk ids 1..n_groups, so whole chunks of
# consecutive frames (not individual frames) are assigned to train or test.
groups = np.hstack([i*np.ones(int((frac*i)*nT) - int((frac*(i-1))*nT)) for i in range(1, n_groups+1)])
assert len(groups) == nT, f'{len(groups)} group labels for {nT} frames; choose frac so that 1/frac is an integer'

train_idx_list = []
test_idx_list = []
for train_idx, test_idx in gss.split(np.arange(nT), groups=groups):
    train_idx_list.append(train_idx)
    test_idx_list.append(test_idx)

cropn = 0  # pixels trimmed from each edge of axes 1 and 2 (0 = no cropping)
train_idx = train_idx_list[0]
test_idx = test_idx_list[0]
if cropn > 0:
    xtrain = model_vid_sm[train_idx][:, cropn:-cropn, cropn:-cropn]
    xtest = model_vid_sm[test_idx][:, cropn:-cropn, cropn:-cropn]
else:
    xtrain = model_vid_sm[train_idx]
    xtest = model_vid_sm[test_idx]
im_size = xtrain.shape[1:]
# Flatten each frame into one row: (frames, pixels).
xtrain = xtrain.reshape(len(train_idx), -1)
xtest = xtest.reshape(len(test_idx), -1)
# Zero out NaN pixels now: otherwise they propagate into the regression loss and
# its gradients (train) and into the validation loss used for model selection (test).
xtrain = np.where(np.isnan(xtrain), 0, xtrain)
xtest = np.where(np.isnan(xtest), 0, xtest)
ytrain = model_nsp[train_idx]
ytest = model_nsp[test_idx]

xtrain.shape, ytrain.shape, xtest.shape, ytest.shape

In [None]:
def _to_device_tensor(arr):
    """Return `arr` as a float32 torch tensor on the selected `device`."""
    return torch.from_numpy(arr).float().to(device)

# Train/test inputs and targets, converted in the order xtrain, xtest, ytrain, ytest.
xtr, xte, ytr, yte = (_to_device_tensor(a) for a in (xtrain, xtest, ytrain, ytest))

In [None]:
# Model dimensions: flattened pixels in, one output column per response column.
input_size, output_size = xtr.shape[1], ytr.shape[1]
Num_units = model_nsp.shape[1]

# Create PyTorch model

In [None]:
# Linear layer from flattened pixels to units, followed by a ReLU. The rows of
# the linear layer's weight matrix are what get plotted as receptive fields.
# The optimizer is created in the training cell, one per regularisation strength.
model = nn.Sequential(nn.Linear(input_size, output_size),
                      nn.ReLU()).to(device)
model

In [None]:
# Fit one model per L2 strength, keep the weights with the lowest held-out MSE
# (checkpointed to ./RF_l2_min.pt), then load them back into `model`.
Nepochs = 2000
l2_lambda_list = [.05, .1, 1]   # Adam `weight_decay` (L2 penalty) values to compare
l1_alpha = 0.0001               # strength of the explicit L1 penalty on the weights
mse = nn.MSELoss()
min_loss = np.inf
l2_lambda_min = None
with tqdm(initial=0, total=len(l2_lambda_list), dynamic_ncols=False, miniters=1) as tq:
    for l2_lambda in l2_lambda_list:
        model = nn.Sequential(nn.Linear(input_size, output_size),
                              nn.ReLU()).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=.001, weight_decay=l2_lambda)

        # Full-batch optimisation of MSE + L1 (sum of absolute weights).
        model.train()
        for epoch in tqdm(range(Nepochs), leave=False):
            optimizer.zero_grad()
            yhat = model(xtr)
            train_loss = mse(yhat, ytr) + l1_alpha * model[0].weight.abs().sum()
            train_loss.backward()
            optimizer.step()

        # Score on held-out data without building an autograd graph. The L1 term is
        # a training regulariser, so only prediction error is used for selection.
        # NOTE(review): lambda is selected on the test split, so this loss is not an
        # unbiased estimate of generalisation error.
        model.eval()
        with torch.no_grad():
            yhat = model(xte)
            val_loss = mse(yhat, yte).item()

        if val_loss < min_loss:
            l2_lambda_min = l2_lambda
            torch.save(model.state_dict(), './RF_l2_min.pt')
            min_loss = val_loss

        tq.set_postfix(val_loss='{:05.3f}'.format(val_loss),
                       train_loss='{:05.3f}'.format(train_loss.item()),
                       min_loss='{:05.3f}'.format(min_loss))
        tq.update()

# A NaN validation loss never compares smaller than min_loss, so no checkpoint is written.
if l2_lambda_min is None:
    raise RuntimeError('No finite validation loss was obtained; check the inputs for NaNs.')

# Every candidate shares the same architecture, so the best weights load into `model`.
load_model = torch.load('./RF_l2_min.pt', map_location=device)
model.load_state_dict(load_model)
l2_lambda_min

In [None]:
# Each row of the linear layer's weight matrix holds one unit's weights over the
# flattened pixels; fold it back into (units, height, width).
layer_weights = model[0].weight.detach().cpu().numpy()
RF = layer_weights.reshape(Num_units, im_size[0], im_size[1])
RF.shape

In [None]:
# One panel per unit, 10 per row, each on a symmetric diverging colour scale
# (RdBu_r: red = positive, blue = negative, white = 0).
n_units = RF.shape[0]
ncols = 10
nrows = max(1, int(np.ceil(n_units / ncols)))
fig, axs = plt.subplots(nrows, ncols, figsize=(2 * ncols, 1.5 * nrows), squeeze=False)
for n, ax in enumerate(axs.flat):
    ax.axis('off')
    if n >= n_units:
        continue  # spare panel in the last row
    cmax = np.max(np.abs(RF[n]))
    ax.imshow(RF[n], cmap='RdBu_r', vmin=-cmax, vmax=cmax)
    ax.set_title(f'{n}')
fig.suptitle('Regression receptive fields')
plt.show()

In [None]:
# Receptive fields stored in the example file, presumably for comparison with the
# regression estimates. NOTE(review): the plotting code indexes RF_vis[n, 2], so
# this is presumably (units, frames/lags, H, W) — confirm against the data source.
RF_vis = data['RF_vis']

RF_vis.shape

In [None]:
# One panel per unit, 10 per row, showing a single slice RF_vis[n, vis_frame_idx].
# NOTE(review): axis 1 is presumably a time-lag axis — confirm which lag index 2 is.
vis_frame_idx = 2
n_units = RF_vis.shape[0]
ncols = 10
nrows = max(1, int(np.ceil(n_units / ncols)))
fig, axs = plt.subplots(nrows, ncols, figsize=(2 * ncols, 1.5 * nrows), squeeze=False)
for n, ax in enumerate(axs.flat):
    ax.axis('off')
    if n >= n_units:
        continue  # spare panel in the last row
    frame = RF_vis[n, vis_frame_idx]
    # Scale the colour map to the slice actually displayed.
    cmax = np.max(np.abs(frame))
    ax.imshow(frame, cmap='RdBu_r', vmin=-cmax, vmax=cmax)
    ax.set_title(f'{n}')
fig.suptitle(f'RF_vis[:, {vis_frame_idx}]')
plt.show()

# Traditional spike-triggered average (STA)

In [None]:
# Spike-triggered average for every unit at once: sta = X^T y / sum(y), with X the
# training frames (frames, pixels) and y the unit's training responses.
# Use a NaN-free copy rather than editing `xtr` in place, so the tensor the
# regression was fit on is left unchanged.
xtr_clean = torch.where(torch.isnan(xtr), torch.zeros_like(xtr), xtr)

# (pixels, units): response-weighted sum of frames for each unit.
weighted_sum = xtr_clean.T @ ytr
nsp = ytr.sum(dim=0)
# Normalise by the total response; units whose total is not positive keep the
# unnormalised sum.
denom = torch.where(nsp > 0, nsp, torch.ones_like(nsp))
sta_all = ((weighted_sum / denom).T
           .reshape(Num_units, im_size[0], im_size[1])
           .cpu().numpy().astype(np.float64))
del xtr_clean, weighted_sum  # free the temporary copies (they may live on the GPU)

In [None]:
# One STA panel per unit, 10 per row, each on a symmetric diverging colour scale.
n_units = sta_all.shape[0]
ncols = 10
nrows = max(1, int(np.ceil(n_units / ncols)))
fig, axs = plt.subplots(nrows, ncols, figsize=(2 * ncols, 1.5 * nrows), squeeze=False)
for n, ax in enumerate(axs.flat):
    ax.axis('off')
    if n >= n_units:
        continue  # spare panel in the last row
    cmax = np.max(np.abs(sta_all[n]))
    ax.imshow(sta_all[n], cmap='RdBu_r', vmin=-cmax, vmax=cmax)
    ax.set_title(f'{n}')
fig.suptitle('Spike-triggered averages')
plt.show()