In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Subset, DataLoader, ConcatDataset
from mouse_model.data_utils_new import MouseDatasetSegNewBehav
import numpy as np
from mouse_model.evaluation import cor_in_time
from sklearn.metrics import r2_score, mean_squared_error
import random
import os
import time
import torch.nn.init as init
from torch.nn import functional as F
from kornia.geometry.transform import get_affine_matrix2d, warp_affine
import matplotlib.pyplot as plt

The model: a shifter network, a CNN visual encoder, a behavior encoder, and a GRU that combines them.

In [2]:
class Shifter(nn.Module):
    """Small MLP that maps per-frame input features to three scaled shift parameters.

    The input is flattened to ``(-1, input_dim)``, passed through a
    BatchNorm/Linear/Tanh stack whose last activation is ``Tanh``, offset by a
    learnable 3-element bias and rescaled per channel:

    * channel 0: ``* 80 / 5.5``
    * channel 1: ``* 60 / 5.5``
    * channel 2: ``* 180 / 4``

    ``LSTMPerNeuronCombiner`` consumes channels 0-1 as a translation and
    channel 2 as a rotation angle.

    Args:
        input_dim: number of features per frame.
        output_dim: width of the last linear layer; also used for the final
            reshape. The bias and the rescaling only handle three channels.
        hidden_dim: width of the two hidden layers.
        seq_len: number of frames per sample, used to restore the time axis.

    Returns (from ``forward``):
        Tensor of shape ``(-1, seq_len, output_dim)``.
    """

    def __init__(self, input_dim=4, output_dim=3, hidden_dim=256, seq_len=8):
        super().__init__()
        self.seq_len = seq_len
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Layer order is significant: it fixes the state_dict keys
        # (layers.0 ... layers.8) and the order in which weights are initialised.
        stack = [
            nn.BatchNorm1d(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*stack)
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x):
        # Collapse every leading dimension so BatchNorm1d sees (N, input_dim).
        activations = self.layers(x.reshape(-1, self.input_dim))

        # (numerator, denominator) of the gain applied to each output channel.
        gains = ((80, 5.5), (60, 5.5), (180, 4))
        scaled = [
            (activations[..., idx] + self.bias[idx]) * num / den
            for idx, (num, den) in enumerate(gains)
        ]

        shifts = torch.stack(scaled, dim=-1)
        return shifts.reshape(-1, self.seq_len, self.output_dim)

In [3]:
# Instantiate with the default arguments so the notebook displays the module layout.
Shifter()

Shifter(
  (layers): Sequential(
    (0): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Linear(in_features=4, out_features=256, bias=True)
    (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Tanh()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Tanh()
    (7): Linear(in_features=256, out_features=3, bias=True)
    (8): Tanh()
  )
)

In [4]:
class PrintLayer(nn.Module):
    """Pass-through debugging layer: prints the shape of its input and returns it unchanged."""

    def forward(self, x):
        print(x.shape)
        return x
    
def size_helper(in_length, kernel_size, padding=0, dilation=1, stride=1):
    """Return the output length of a convolution along one spatial dimension.

    Implements the output-size formula from the Conv2d documentation:
    ``floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``.

    Args:
        in_length: input length along the dimension.
        kernel_size: kernel length along the dimension.
        padding: padding added on each side.
        dilation: spacing between kernel elements.
        stride: convolution stride.

    Returns:
        The floored output length as returned by ``np.floor`` (a NumPy float,
        not an int); callers cast it with ``int()`` where needed.
    """
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
    dilated_extent = dilation * (kernel_size - 1)
    numerator = in_length + 2 * padding - dilated_extent - 1
    return np.floor(numerator / stride + 1)

# CNN, the last fully connected layer maps to output_dim
class VisualEncoder(nn.Module):
    """Three-stage strided CNN followed by a linear read-out.

    Each stage is Conv2d (stride 2, no padding) -> BatchNorm2d -> ReLU ->
    Dropout(0.5), with 128, 64 and 32 output channels respectively. The final
    feature map is flattened and mapped to ``output_dim`` by a linear layer.

    Args:
        output_dim: size of the encoder output per image.
        input_shape: ``(height, width)`` of the single-channel input images.
        k1, k2, k3: kernel sizes of the three convolutions.

    Attributes:
        input_shape: the ``(height, width)`` the encoder was built for.
        output_shape: ``(height, width)`` of the last convolutional feature map.
    """

    def __init__(self, output_dim, input_shape=(60, 80), k1=7, k2=7, k3=7):

        super().__init__()

        # Previously hard-coded to (60, 80) regardless of the argument.
        self.input_shape = tuple(input_shape)

        # Propagate both spatial dimensions through the three stride-2 convolutions.
        out_height, out_width = input_shape
        for kernel_size in (k1, k2, k3):
            out_height = size_helper(in_length=out_height, kernel_size=kernel_size, stride=2)
            out_width = size_helper(in_length=out_width, kernel_size=kernel_size, stride=2)
        self.output_shape = (int(out_height), int(out_width)) # shape of the final feature map

        # Size the read-out from the computed feature map instead of a literal 480.
        # For the defaults this is 32 * 3 * 5 = 480, so existing checkpoints still load.
        last_channels = 32
        flat_features = last_channels * self.output_shape[0] * self.output_shape[1]

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=k1, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=k2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(in_channels=64, out_channels=last_channels, kernel_size=k3, stride=2),
            nn.BatchNorm2d(last_channels),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Flatten(),
            nn.Linear(flat_features, output_dim)
        )

    def forward(self, x):
        """Encode a batch of images; the first conv layer expects one input channel."""
        return self.layers(x)

    
# may consider adding an activation after linear
class BehavEncoder(nn.Module):
    """Batch-normalise the behavioral variables and project them linearly.

    Args:
        behav_dim: number of behavioral input features.
        output_dim: number of output features.

    The linear layer sits at index 1 of ``self.layers``; ``train_model``
    looks up its weight under the name ``behav_encoder.layers.1.weight``.
    """

    def __init__(self, behav_dim, output_dim):
        super().__init__()
        normalise = nn.BatchNorm1d(behav_dim)
        project = nn.Linear(behav_dim, output_dim)
        self.layers = nn.Sequential(normalise, project)

    def forward(self, x):
        return self.layers(x)

class LSTMPerNeuronCombiner(nn.Module):
    """Predict per-neuron activity from an image sequence and behavioral variables.

    Pipeline (per sample):
      1. Optionally warp every frame with an affine transform whose translation
         and rotation are predicted by ``Shifter`` from four behavioral columns.
      2. For each time step, encode the frame (``VisualEncoder``) and the
         behavior (``BehavEncoder``) into ``num_neurons`` features each, form
         their element-wise product, stack the three and batch-normalise them.
      3. Run the flattened per-step features through a single-layer GRU
         (stored as ``lstm_net`` for checkpoint compatibility), take the
         output at the last time step, and map it through a linear layer and
         Softplus.

    Args:
        num_neurons: number of output units (and encoder feature width).
        behav_dim: number of behavioral features per time step.
        k1, k2, k3: kernel sizes forwarded to ``VisualEncoder``.
        seq_len: number of time steps per sample.
        hidden_size: GRU hidden size.
        use_shifter: ``True``/``False`` to force the shifter on or off for this
            model. ``None`` (default) keeps the original behaviour of reading
            the notebook-level ``args.shifter`` on every forward pass.
    """

    def __init__(self, num_neurons, behav_dim, k1, k2, k3, seq_len, hidden_size=512, use_shifter=None):

        super().__init__()

        self.seq_len = seq_len
        self.num_neurons = num_neurons
        # Plain attribute (not a parameter/buffer), so checkpoints are unaffected.
        self.use_shifter = use_shifter
        self.shifter = Shifter(seq_len = seq_len)
        self.visual_encoder = VisualEncoder(output_dim=num_neurons, k1=k1, k2=k2, k3=k3)
        self.behav_encoder = BehavEncoder(behav_dim=behav_dim, output_dim=num_neurons)
        self.bn = nn.BatchNorm1d(3) # apply bn to vis_feats, beh_feats, prod
        self.lstm_net = nn.GRU(input_size=num_neurons*3, hidden_size=hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_neurons)
        self.softplus = nn.Softplus() # we could also do relu or elu offset by 1

    def _shift_images(self, images, behav):
        """Warp every frame by the translation/rotation predicted from behavior.

        The frames are reshaped to ``(-1, 1, 60, 80)`` for warping and restored
        to ``(batch, seq_len, 1, 60, 80)`` afterwards.
        """
        bs = images.size()[0]
        behav_shifter = torch.concat((behav[...,4].unsqueeze(-1),   # theta
                                      behav[...,3].unsqueeze(-1),   # phi
                                      behav[...,1].unsqueeze(-1),   # pitch
                                      behav[...,2].unsqueeze(-1),   # roll
                                      ), dim=-1)
        shift_param = self.shifter(behav_shifter)
        shift_param = shift_param.reshape(-1, 3)

        # Unit scale: only translation and rotation are applied.
        scale_param = torch.ones_like(shift_param[..., 0:2])

        # NOTE(review): the Shifter scales its first two outputs by 80 and 60,
        # i.e. (width, height) order, while the centre is given as (30, 40).
        # kornia presumably expects (x, y) here - confirm whether (40, 30) was
        # intended. Left unchanged so the model matches the training runs.
        center = torch.tensor([[30, 40]], dtype=torch.float,
                              device=shift_param.device).repeat(bs * self.seq_len, 1)

        affine_mat = get_affine_matrix2d(translations=shift_param[..., 0:2],
                                         scale=scale_param,
                                         center=center,
                                         angle=shift_param[..., 2])
        affine_mat = affine_mat[:, :2, :]  # drop the homogeneous row; warp_affine takes 2x3
        warped = warp_affine(images.reshape(-1,1,60,80), affine_mat, dsize=(60,80))
        return warped.reshape(bs, self.seq_len,1,60,80)

    def forward(self, images, behav):
        """Return the Softplus-activated prediction of shape ``(batch, num_neurons)``.

        ``images`` is indexed as ``images[:, i, :, :, :]`` and ``behav`` as
        ``behav[:, i, :]`` for each of the ``seq_len`` time steps.
        """
        # Explicit per-model flag wins; otherwise fall back to the global config.
        use_shifter = args.shifter if self.use_shifter is None else self.use_shifter
        if use_shifter:
            images = self._shift_images(images, behav)

        # get visual behavioral features in time
        vis_beh_feats = []
        for i in range(self.seq_len):
            v = self.visual_encoder(images[:, i, :, :, :])
            b = self.behav_encoder(behav[:, i, :])
            vb = v * b
            vis_beh_feat = torch.stack([v, b, vb], dim=1)
            vis_beh_feat = self.bn(vis_beh_feat)
            vis_beh_feats.append(vis_beh_feat)
        vis_beh_feats = torch.stack(vis_beh_feats, dim=1)

        # flatten features to (batch_size, seq_len, num_neurons*3)
        vis_beh_feats = torch.flatten(vis_beh_feats, start_dim=2)

        # recurrent pass; keep only the output at the last time step
        output, _ = self.lstm_net(vis_beh_feats)
        output = output[:, -1, :]

        # fully connected layer and activation function
        output = self.fc(output)
        pred_spikes = self.softplus(output)

        return pred_spikes

In [5]:
class Args:
    """Notebook-wide configuration; fields left as None are filled in per run by the driver cells."""

    # reproducibility
    seed = 0

    # optimisation
    epochs = 50
    batch_size = 256
    learning_rate = 0.0002
    l1_reg_w = 1          # weight of the L1 penalty added to the training loss

    # model
    hidden_size = 512
    shifter = True        # read by LSTMPerNeuronCombiner.forward
    seq_len = None
    num_neurons = None

    # data
    file_id = None
    vid_type = "vid_mean"
    segment_num = 10
    behav_mode = None
    behav_dim = None

    # checkpoint locations
    best_train_path = None
    best_val_path = None


args = Args()

def set_random_seed(seed: int, deterministic: bool = True):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: value passed to every random number generator.
        deterministic: when True, also disable cuDNN benchmarking and switch
            cuDNN to deterministic mode.
    """
    # from nnfabrik package
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)  # this sets both CPU and CUDA seeds for PyTorch
    np.random.seed(seed)
    random.seed(seed)

seed = args.seed

# set_random_seed() seeds `random`, NumPy and PyTorch (torch.manual_seed covers
# the CUDA generators as well), so those calls are not repeated here.
set_random_seed(seed)

# Exported for any process launched from this kernel; the hash seed of the
# already-running interpreter is fixed at start-up and is not affected.
os.environ['PYTHONHASHSEED'] = str(seed)

torch.cuda.empty_cache()
print(torch.cuda.is_available())

True


In [6]:
def load_train_val_ds():
    """Build the training and validation datasets for the current ``args``.

    One ``MouseDatasetSegNewBehav`` (``data_split="train"``) is created per
    segment. Within each segment the first 80% of the samples go to training
    and the remainder to validation; the per-segment subsets are then
    concatenated. The two resulting lengths are printed.

    Returns:
        ``(train_ds, val_ds)`` as ``ConcatDataset`` objects.
    """
    train_ratio = 0.8
    train_parts, val_parts = [], []

    for seg_idx in range(args.segment_num):
        segment = MouseDatasetSegNewBehav(
            file_id=args.file_id,
            segment_num=args.segment_num,
            seg_idx=seg_idx,
            data_split="train",
            vid_type=args.vid_type,
            seq_len=args.seq_len,
            predict_offset=1,
            behav_mode=args.behav_mode,
            norm_mode="01",
        )
        total = len(segment)
        split_at = int(total * train_ratio)
        # Contiguous split: earlier samples train, later samples validate.
        train_parts.append(Subset(segment, np.arange(0, split_at, 1)))
        val_parts.append(Subset(segment, np.arange(split_at, total, 1)))

    train_ds = ConcatDataset(train_parts)
    val_ds = ConcatDataset(val_parts)
    print(len(train_ds), len(val_ds))
    return train_ds, val_ds

In [7]:
def load_test_ds():
    """Build the test dataset for the current ``args``.

    Creates one ``MouseDatasetSegNewBehav`` with ``data_split="test"`` per
    segment and concatenates them in segment order.

    Returns:
        A ``ConcatDataset`` over all test segments.
    """
    segments = []
    for seg_idx in range(args.segment_num):
        segments.append(
            MouseDatasetSegNewBehav(
                file_id=args.file_id,
                segment_num=args.segment_num,
                seg_idx=seg_idx,
                data_split="test",
                vid_type=args.vid_type,
                seq_len=args.seq_len,
                predict_offset=1,
                behav_mode=args.behav_mode,
                norm_mode="01",
            )
        )
    return ConcatDataset(segments)

In [8]:
def _save_checkpoint(module, path):
    """Write ``module.state_dict()`` to ``path``, creating the parent directory if needed."""
    parent = os.path.dirname(path)
    if parent:  # an empty string means the current directory; nothing to create
        os.makedirs(parent, exist_ok=True)
    torch.save(module.state_dict(), path)


def train_model(patience=5, num_workers=8):
    """Train the notebook-level ``model`` with early stopping on validation loss.

    Reads its configuration from the global ``args`` and uses the global
    ``model`` and ``device``. The training objective is the Poisson negative
    log-likelihood plus ``args.l1_reg_w`` times the mean absolute value of the
    behavior encoder's linear weight. Validation uses the Poisson NLL only.

    Checkpoints:
        * ``args.best_train_path`` - saved whenever the epoch's mean training
          Poisson loss improves.
        * ``args.best_val_path`` - saved whenever the epoch's mean validation
          loss improves.

    Args:
        patience: stop after this many consecutive epochs without a new best
            validation loss.
        num_workers: worker processes for both data loaders.

    Returns:
        ``(train_loss_list, val_loss_list)``: per-epoch mean total training
        loss (including the L1 term) and per-epoch mean validation loss.
    """
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)

    train_ds, val_ds = load_train_val_ds()

    train_dataloader = DataLoader(dataset=train_ds, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    val_dataloader = DataLoader(dataset=val_ds, batch_size=args.batch_size, shuffle=False, num_workers=num_workers)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Resolve the L1-regularised weight once instead of scanning all named
    # parameters on every batch. The Parameter object is updated in place by
    # the optimizer, so this reference stays valid. If the model has no such
    # parameter the penalty is simply zero (previously a division by zero).
    l1_target = dict(model.named_parameters()).get("behav_encoder.layers.1.weight")

    best_train_spike_loss = np.inf
    best_val_spike_loss = np.inf
    train_loss_list = []
    val_loss_list = []

    # start training
    ct = 0  # consecutive epochs without validation improvement

    for epoch in range(args.epochs):

        print("Start epoch", epoch)

        model.train()

        epoch_train_loss, epoch_train_spike_loss = 0, 0

        for (image, behav, spikes) in train_dataloader:

            image, behav, spikes = image.to(device), behav.to(device), spikes.to(device)

            pred = model(image, behav)

            spike_loss = nn.functional.poisson_nll_loss(pred, spikes, reduction='mean', log_input=False)

            if l1_target is not None:
                # mean absolute weight, written as sum / count to match the original arithmetic
                l1_reg = l1_target.abs().sum() / l1_target.numel()
            else:
                l1_reg = 0.0

            total_loss = spike_loss + args.l1_reg_w * l1_reg

            epoch_train_loss += total_loss.item()
            epoch_train_spike_loss += spike_loss.item()

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        epoch_train_loss = epoch_train_loss / len(train_dataloader)
        epoch_train_spike_loss = epoch_train_spike_loss / len(train_dataloader)

        train_loss_list.append(epoch_train_loss)

        print("Epoch {} train loss: {}".format(epoch, epoch_train_loss))

        if epoch_train_spike_loss < best_train_spike_loss:

            print("save train model at epoch", epoch)
            _save_checkpoint(model, args.best_train_path)
            best_train_spike_loss = epoch_train_spike_loss

        model.eval()

        epoch_val_spike_loss = 0

        with torch.no_grad():

            for (image, behav, spikes) in val_dataloader:

                image, behav, spikes = image.to(device), behav.to(device), spikes.to(device)

                pred = model(image, behav)

                loss = nn.functional.poisson_nll_loss(pred, spikes, reduction='mean', log_input=False)

                epoch_val_spike_loss += loss.item()

        epoch_val_spike_loss = epoch_val_spike_loss / len(val_dataloader)

        val_loss_list.append(epoch_val_spike_loss)

        print("Epoch {} val loss: {}".format(epoch, epoch_val_spike_loss))

        if epoch_val_spike_loss < best_val_spike_loss:
            ct = 0

            print("save val model at epoch", epoch)
            _save_checkpoint(model, args.best_val_path)
            best_val_spike_loss = epoch_val_spike_loss
        else:
            ct += 1
            if ct >= patience:
                print('stop training')
                break

        print("End epoch", epoch)

    return train_loss_list, val_loss_list

In [9]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Train one model per (recording, behavior mode, sequence length) combination.
# Each recording is listed with the number of neurons used as the model's output size.
for file_id, num_neurons in [("070921_J553RT", 68), ("101521_J559NC", 49), ("110421_J569LT", 32)]:
    for behav_mode, behav_dim in [("all_prod", 66)]:
        for seq_len in range(1, 9):
            print(file_id, behav_mode, seq_len)

            # Publish the run configuration through the shared `args` object,
            # which the dataset loaders, the model and train_model() all read.
            args.shifter = True
            args.file_id, args.num_neurons = file_id, num_neurons
            args.behav_mode, args.behav_dim = behav_mode, behav_dim
            args.seq_len = seq_len

            run_tag = (args.file_id, args.behav_mode, args.seq_len)
            args.best_train_path = "weights_cnn_gru_shifter/train_{}_{}_seq_{}.pth".format(*run_tag)
            args.best_val_path = "weights_cnn_gru_shifter/val_{}_{}_seq_{}.pth".format(*run_tag)

            # A fresh model for every configuration; train_model() picks it up as a global.
            model = LSTMPerNeuronCombiner(
                num_neurons=args.num_neurons,
                behav_dim=args.behav_dim,
                k1=7, k2=7, k3=7,
                seq_len=args.seq_len,
                hidden_size=args.hidden_size,
            ).to(device)

            train_loss_list, val_loss_list = train_model()

070921_J553RT all_prod 1
30120 7540
Start epoch 0
Epoch 0 train loss: 0.7431484688136537
save train model at epoch 0
Epoch 0 val loss: 0.6559784253438313
save val model at epoch 0
End epoch 0
Start epoch 1
Epoch 1 train loss: 0.676811802690312
save train model at epoch 1
Epoch 1 val loss: 0.6420734246571859
save val model at epoch 1
End epoch 1
Start epoch 2
Epoch 2 train loss: 0.6523936643438825
save train model at epoch 2
Epoch 2 val loss: 0.6371280570824941
save val model at epoch 2
End epoch 2
Start epoch 3
Epoch 3 train loss: 0.6362312577538571
save train model at epoch 3
Epoch 3 val loss: 0.6290084004402161
save val model at epoch 3
End epoch 3
Start epoch 4
Epoch 4 train loss: 0.6236386743642516
save train model at epoch 4
Epoch 4 val loss: 0.6207810123761495
save val model at epoch 4
End epoch 4
Start epoch 5
Epoch 5 train loss: 0.6161649297859709
save train model at epoch 5
Epoch 5 val loss: 0.6180520315965017
save val model at epoch 5
End epoch 5
Start epoch 6
Epoch 6 train l

Epoch 14 train loss: 0.578990152831805
save train model at epoch 14
Epoch 14 val loss: 0.5963725646336874
save val model at epoch 14
End epoch 14
Start epoch 15
Epoch 15 train loss: 0.5767980652340388
save train model at epoch 15
Epoch 15 val loss: 0.5945991237958272
save val model at epoch 15
End epoch 15
Start epoch 16
Epoch 16 train loss: 0.5747399436215223
save train model at epoch 16
Epoch 16 val loss: 0.5952917257944743
End epoch 16
Start epoch 17
Epoch 17 train loss: 0.5733452893919864
save train model at epoch 17
Epoch 17 val loss: 0.5987995445728302
End epoch 17
Start epoch 18
Epoch 18 train loss: 0.5714053819745274
save train model at epoch 18
Epoch 18 val loss: 0.5939799507459005
save val model at epoch 18
End epoch 18
Start epoch 19
Epoch 19 train loss: 0.569688734866805
save train model at epoch 19
Epoch 19 val loss: 0.5932040015856425
save val model at epoch 19
End epoch 19
Start epoch 20
Epoch 20 train loss: 0.5682554441993519
save train model at epoch 20
Epoch 20 val lo

Epoch 18 val loss: 0.5924500584602356
save val model at epoch 18
End epoch 18
Start epoch 19
Epoch 19 train loss: 0.5598759757260144
save train model at epoch 19
Epoch 19 val loss: 0.5920013944307964
save val model at epoch 19
End epoch 19
Start epoch 20
Epoch 20 train loss: 0.5574767670388949
save train model at epoch 20
Epoch 20 val loss: 0.5945514837900797
End epoch 20
Start epoch 21
Epoch 21 train loss: 0.5558825231204598
save train model at epoch 21
Epoch 21 val loss: 0.594162768125534
End epoch 21
Start epoch 22
Epoch 22 train loss: 0.5538186228881448
save train model at epoch 22
Epoch 22 val loss: 0.5940540770689646
End epoch 22
Start epoch 23
Epoch 23 train loss: 0.5519520044326782
save train model at epoch 23
Epoch 23 val loss: 0.5938971916834513
End epoch 23
Start epoch 24
Epoch 24 train loss: 0.5501341254024182
save train model at epoch 24
Epoch 24 val loss: 0.5956300477186839
stop training
070921_J553RT all_prod 5
30090 7530
Start epoch 0
Epoch 0 train loss: 0.7205607249575

Epoch 1 val loss: 0.6295513908068339
save val model at epoch 1
End epoch 1
Start epoch 2
Epoch 2 train loss: 0.635275269969035
save train model at epoch 2
Epoch 2 val loss: 0.6196413139502207
save val model at epoch 2
End epoch 2
Start epoch 3
Epoch 3 train loss: 0.6192219903913595
save train model at epoch 3
Epoch 3 val loss: 0.6134775300820668
save val model at epoch 3
End epoch 3
Start epoch 4
Epoch 4 train loss: 0.607691914348279
save train model at epoch 4
Epoch 4 val loss: 0.6089808483918507
save val model at epoch 4
End epoch 4
Start epoch 5
Epoch 5 train loss: 0.599547598321559
save train model at epoch 5
Epoch 5 val loss: 0.6051565845807393
save val model at epoch 5
End epoch 5
Start epoch 6
Epoch 6 train loss: 0.5934569204257707
save train model at epoch 6
Epoch 6 val loss: 0.6022655745347341
save val model at epoch 6
End epoch 6
Start epoch 7
Epoch 7 train loss: 0.5883054723173885
save train model at epoch 7
Epoch 7 val loss: 0.5993504981199901
save val model at epoch 7
End 

Epoch 12 train loss: 0.6663903457572661
save train model at epoch 12
Epoch 12 val loss: 0.6859872852052961
End epoch 12
Start epoch 13
Epoch 13 train loss: 0.6642881653395044
save train model at epoch 13
Epoch 13 val loss: 0.6812135633968172
save val model at epoch 13
End epoch 13
Start epoch 14
Epoch 14 train loss: 0.6627669779651136
save train model at epoch 14
Epoch 14 val loss: 0.6842233013539087
End epoch 14
Start epoch 15
Epoch 15 train loss: 0.660872571080564
save train model at epoch 15
Epoch 15 val loss: 0.6818091088817233
End epoch 15
Start epoch 16
Epoch 16 train loss: 0.6597490364528564
save train model at epoch 16
Epoch 16 val loss: 0.6854045831021809
End epoch 16
Start epoch 17
Epoch 17 train loss: 0.6579938997705299
save train model at epoch 17
Epoch 17 val loss: 0.6788630740983146
save val model at epoch 17
End epoch 17
Start epoch 18
Epoch 18 train loss: 0.6567348658320415
save train model at epoch 18
Epoch 18 val loss: 0.6778451587472644
save val model at epoch 18
End

Epoch 12 train loss: 0.6466147472341377
save train model at epoch 12
Epoch 12 val loss: 0.6729027572132292
save val model at epoch 12
End epoch 12
Start epoch 13
Epoch 13 train loss: 0.6440404420157513
save train model at epoch 13
Epoch 13 val loss: 0.6743664202235994
End epoch 13
Start epoch 14
Epoch 14 train loss: 0.6415298448269626
save train model at epoch 14
Epoch 14 val loss: 0.6754869818687439
End epoch 14
Start epoch 15
Epoch 15 train loss: 0.6392319952867117
save train model at epoch 15
Epoch 15 val loss: 0.6714543671835036
save val model at epoch 15
End epoch 15
Start epoch 16
Epoch 16 train loss: 0.6368268448186208
save train model at epoch 16
Epoch 16 val loss: 0.6718713385718209
End epoch 16
Start epoch 17
Epoch 17 train loss: 0.6345041483999735
save train model at epoch 17
Epoch 17 val loss: 0.6721480332669758
End epoch 17
Start epoch 18
Epoch 18 train loss: 0.6326008525239416
save train model at epoch 18
Epoch 18 val loss: 0.675529108161018
End epoch 18
Start epoch 19
Ep

Epoch 9 train loss: 0.6489110782922033
save train model at epoch 9
Epoch 9 val loss: 0.673529421999341
End epoch 9
Start epoch 10
Epoch 10 train loss: 0.64550879956728
save train model at epoch 10
Epoch 10 val loss: 0.6723382714248839
save val model at epoch 10
End epoch 10
Start epoch 11
Epoch 11 train loss: 0.6416267188916723
save train model at epoch 11
Epoch 11 val loss: 0.6829407385417393
End epoch 11
Start epoch 12
Epoch 12 train loss: 0.6385589457419981
save train model at epoch 12
Epoch 12 val loss: 0.6754292774768103
End epoch 12
Start epoch 13
Epoch 13 train loss: 0.6350795170628881
save train model at epoch 13
Epoch 13 val loss: 0.6738817521503994
End epoch 13
Start epoch 14
Epoch 14 train loss: 0.6318697498505375
save train model at epoch 14
Epoch 14 val loss: 0.6740118023895082
End epoch 14
Start epoch 15
Epoch 15 train loss: 0.6289995948234236
save train model at epoch 15
Epoch 15 val loss: 0.6731084599381402
stop training
101521_J559NC all_prod 7
42360 10600
Start epoch 

Epoch 16 train loss: 0.6906477568685546
save train model at epoch 16
Epoch 16 val loss: 0.7331881487008297
End epoch 16
Start epoch 17
Epoch 17 train loss: 0.6888295860253564
save train model at epoch 17
Epoch 17 val loss: 0.7312970414306178
End epoch 17
Start epoch 18
Epoch 18 train loss: 0.6881138169488241
save train model at epoch 18
Epoch 18 val loss: 0.7231675350304806
save val model at epoch 18
End epoch 18
Start epoch 19
Epoch 19 train loss: 0.6861006178597148
save train model at epoch 19
Epoch 19 val loss: 0.7225732586600564
save val model at epoch 19
End epoch 19
Start epoch 20
Epoch 20 train loss: 0.6845957858617916
save train model at epoch 20
Epoch 20 val loss: 0.7265421964905479
End epoch 20
Start epoch 21
Epoch 21 train loss: 0.6834238698316175
save train model at epoch 21
Epoch 21 val loss: 0.7228349754304597
End epoch 21
Start epoch 22
Epoch 22 train loss: 0.6821787070858386
save train model at epoch 22
Epoch 22 val loss: 0.7374747052337184
End epoch 22
Start epoch 23
E

Epoch 4 val loss: 0.7245057351661451
save val model at epoch 4
End epoch 4
Start epoch 5
Epoch 5 train loss: 0.7006916717965473
save train model at epoch 5
Epoch 5 val loss: 0.7219607089505051
save val model at epoch 5
End epoch 5
Start epoch 6
Epoch 6 train loss: 0.6945571973342304
save train model at epoch 6
Epoch 6 val loss: 0.7189166455557852
save val model at epoch 6
End epoch 6
Start epoch 7
Epoch 7 train loss: 0.6889236726502116
save train model at epoch 7
Epoch 7 val loss: 0.7181993740977664
save val model at epoch 7
End epoch 7
Start epoch 8
Epoch 8 train loss: 0.6845214611800142
save train model at epoch 8
Epoch 8 val loss: 0.718849929896268
End epoch 8
Start epoch 9
Epoch 9 train loss: 0.6801352759664373
save train model at epoch 9
Epoch 9 val loss: 0.7161482048757148
save val model at epoch 9
End epoch 9
Start epoch 10
Epoch 10 train loss: 0.6774860066036845
save train model at epoch 10
Epoch 10 val loss: 0.71550114588304
save val model at epoch 10
End epoch 10
Start epoch 

Epoch 18 train loss: 0.6512138834295347
save train model at epoch 18
Epoch 18 val loss: 0.7138871297691808
End epoch 18
Start epoch 19
Epoch 19 train loss: 0.6480577130650365
save train model at epoch 19
Epoch 19 val loss: 0.7186040751861803
End epoch 19
Start epoch 20
Epoch 20 train loss: 0.6452609041864558
save train model at epoch 20
Epoch 20 val loss: 0.7175055507457617
End epoch 20
Start epoch 21
Epoch 21 train loss: 0.6427546536275582
save train model at epoch 21
Epoch 21 val loss: 0.7186246153080103
End epoch 21
Start epoch 22
Epoch 22 train loss: 0.640399574771408
save train model at epoch 22
Epoch 22 val loss: 0.7161332007610437
stop training
110421_J569LT all_prod 7
32890 8230
Start epoch 0
Epoch 0 train loss: 0.8086218196292256
save train model at epoch 0
Epoch 0 val loss: 0.7455958326657613
save val model at epoch 0
End epoch 0
Start epoch 1
Epoch 1 train loss: 0.7589819551438324
save train model at epoch 1
Epoch 1 val loss: 0.732901004227725
save val model at epoch 1
End e

In [10]:
# Display the layer layout of whichever model the `model` variable currently holds.
model

LSTMPerNeuronCombiner(
  (shifter): Shifter(
    (layers): Sequential(
      (0): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Linear(in_features=4, out_features=256, bias=True)
      (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Tanh()
      (4): Linear(in_features=256, out_features=256, bias=True)
      (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): Tanh()
      (7): Linear(in_features=256, out_features=3, bias=True)
      (8): Tanh()
    )
  )
  (visual_encoder): VisualEncoder(
    (layers): Sequential(
      (0): Conv2d(1, 128, kernel_size=(7, 7), stride=(2, 2))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Dropout(p=0.5, inplace=False)
      (4): Conv2d(128, 64, kernel_size=(7, 7), stride=(2, 2))
      (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tra

Evaluation

In [8]:
# default is smoothing with 2 second, 48 ms per frame
def smoothing_with_np_conv(nsp, size=int(2000/48)):
    """Smooth every column of a 2-D array with a moving-average (box) filter.

    Args:
        nsp: 2-D array; each column ``nsp[:, i]`` is filtered independently.
        size: length of the averaging window in samples.

    Returns:
        Array with one smoothed column per input column, obtained with
        ``np.convolve(..., mode="same")``.
    """
    box_kernel = np.ones(size) / size
    smoothed_columns = [
        np.convolve(nsp[:, col], box_kernel, mode="same")
        for col in range(nsp.shape[1])
    ]
    # Columns were collected as rows; transpose back.
    return np.array(smoothed_columns).T

In [9]:
def evaluate_model(model, weights_path, dataset, device, batch_size=256, num_workers=4):
    """Load saved weights into ``model`` and collect its predictions on ``dataset``.

    Args:
        model: module called as ``model(image, behav)``; it is expected to
            already live on ``device`` (this function does not move it).
        weights_path: path of a ``state_dict`` file written by ``torch.save``.
        dataset: dataset yielding ``(image, behav, spikes)`` samples.
        device: device the input batches are moved to.
        batch_size: evaluation batch size.
        num_workers: number of DataLoader worker processes.

    Returns:
        ``(predictions, ground_truth)`` as NumPy arrays concatenated along the
        first axis, in dataset order.
    """
    dl = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # map_location lets a checkpoint written on a GPU be loaded on a machine
    # where that device is not available (e.g. CPU-only evaluation).
    model.load_state_dict(torch.load(weights_path, map_location=device))

    ground_truth_all = []
    pred_all = []

    model.eval()

    with torch.no_grad():

        for (image, behav, spikes) in dl:

            image = image.to(device)
            behav = behav.to(device)

            pred = model(image, behav)

            # targets never leave the CPU; predictions are copied back
            ground_truth_all.append(spikes.numpy())
            pred_all.append(pred.cpu().numpy())

    return np.concatenate(pred_all, axis=0), np.concatenate(ground_truth_all, axis=0)

In [13]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Evaluate the best-validation checkpoint of every (recording, behavior mode,
# sequence length) combination on the test split.
for file_id, num_neurons in  [("070921_J553RT", 68), ("110421_J569LT", 32), ("101521_J559NC", 49)]:

    for behav_mode, behav_dim in [("all_prod", 66)]:

        for seq_len in range(1, 9):

            print(file_id, behav_mode, seq_len)

            args.file_id = file_id
            args.num_neurons = num_neurons
            # The checkpoints were trained with the shifter enabled; set it
            # explicitly instead of relying on whatever value `args` still holds.
            args.shifter = True

            args.behav_mode = behav_mode
            args.behav_dim = behav_dim

            args.seq_len = seq_len

            args.best_val_path = "weights_cnn_gru_shifter/val_{}_{}_seq_{}.pth".format(
                args.file_id, args.behav_mode, args.seq_len)

            model = LSTMPerNeuronCombiner(num_neurons=args.num_neurons,
                                          behav_dim=args.behav_dim,
                                          k1=7, k2=7, k3=7,
                                          seq_len=args.seq_len,
                                          hidden_size=args.hidden_size).to(device)

            # Not used for scoring below; kept because it prints the train/val
            # split sizes and refreshes the notebook-level train_ds / val_ds.
            train_ds, val_ds = load_train_val_ds()
            test_ds = load_test_ds()

            pred, label = evaluate_model(model, weights_path=args.best_val_path, dataset=test_ds, device=device)

            # Metrics are reported on smoothed traces only; the earlier
            # correlation on the raw traces was computed but never read.
            pred = smoothing_with_np_conv(pred)
            label = smoothing_with_np_conv(label)
            print("MSE", "{:.6f}".format(mean_squared_error(label, pred)))
            cor_array = cor_in_time(pred, label)
            print("mean corr, {:.3f}+-{:.3f}".format(np.mean(cor_array), np.std(cor_array)))

070921_J553RT all_prod 1
30120 7540
MSE 0.054336
mean corr, 0.646+-0.136
070921_J553RT all_prod 2
30120 7530
MSE 0.052839
mean corr, 0.649+-0.139
070921_J553RT all_prod 3
30110 7530
MSE 0.052778
mean corr, 0.653+-0.139
070921_J553RT all_prod 4
30100 7530
MSE 0.052547
mean corr, 0.650+-0.142
070921_J553RT all_prod 5
30090 7530
MSE 0.055567
mean corr, 0.645+-0.144
070921_J553RT all_prod 6
30080 7530
MSE 0.052630
mean corr, 0.654+-0.142
070921_J553RT all_prod 7
30080 7520
MSE 0.053443
mean corr, 0.644+-0.148
070921_J553RT all_prod 8
30070 7520
MSE 0.054729
mean corr, 0.646+-0.146
110421_J569LT all_prod 1
32940 8240
MSE 0.091719
mean corr, 0.508+-0.166
110421_J569LT all_prod 2
32930 8240
MSE 0.089837
mean corr, 0.506+-0.174
110421_J569LT all_prod 3
32920 8240
MSE 0.084313
mean corr, 0.528+-0.160
110421_J569LT all_prod 4
32920 8230
MSE 0.079940
mean corr, 0.566+-0.169
110421_J569LT all_prod 5
32910 8230
MSE 0.093304
mean corr, 0.519+-0.177
110421_J569LT all_prod 6
32900 8230
MSE 0.082298
me

Save the per-neuron correlations for future plotting.

In [11]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# np.save does not create missing directories.
os.makedirs("corr_cnn_gru_shifter", exist_ok=True)

# For every recording, score the seq_len = 1 best-validation checkpoint on the
# test split and store the correlation array returned by cor_in_time.
for file_id, num_neurons in  [("070921_J553RT", 68), ("110421_J569LT", 32), ("101521_J559NC", 49)]:

    for behav_mode, behav_dim in [("all_prod", 66)]:

        seq_len = 1

        print(file_id, behav_mode, seq_len)

        args.file_id = file_id
        args.num_neurons = num_neurons
        # The checkpoints were trained with the shifter enabled; set it
        # explicitly instead of relying on whatever value `args` still holds.
        args.shifter = True

        args.behav_mode = behav_mode
        args.behav_dim = behav_dim

        args.seq_len = seq_len

        args.best_val_path = "weights_cnn_gru_shifter/val_{}_{}_seq_{}.pth".format(
            args.file_id, args.behav_mode, args.seq_len)

        model = LSTMPerNeuronCombiner(num_neurons=args.num_neurons,
                                      behav_dim=args.behav_dim,
                                      k1=7, k2=7, k3=7,
                                      seq_len=args.seq_len,
                                      hidden_size=args.hidden_size).to(device)

        # Not used for scoring below; kept because it prints the train/val
        # split sizes and refreshes the notebook-level train_ds / val_ds.
        train_ds, val_ds = load_train_val_ds()
        test_ds = load_test_ds()

        pred, label = evaluate_model(model, weights_path=args.best_val_path, dataset=test_ds, device=device)

        # Metrics are reported on smoothed traces only; the earlier
        # correlation on the raw traces was computed but never read.
        pred = smoothing_with_np_conv(pred)
        label = smoothing_with_np_conv(label)
        print("MSE", "{:.6f}".format(mean_squared_error(label, pred)))
        cor_array = cor_in_time(pred, label)
        print("mean corr, {:.3f}+-{:.3f}".format(np.mean(cor_array), np.std(cor_array)))
        print(cor_array.shape)
        np.save("corr_cnn_gru_shifter/{}_seq_{}.npy".format(file_id, seq_len), cor_array)

070921_J553RT all_prod 1
30120 7540
MSE 0.054336
mean corr, 0.646+-0.136
(1, 68)
110421_J569LT all_prod 1
32940 8240
MSE 0.091719
mean corr, 0.508+-0.166
(1, 32)
101521_J559NC all_prod 1
42410 10610
MSE 0.080114
mean corr, 0.607+-0.132
(1, 49)
