In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Subset, DataLoader, ConcatDataset
from mouse_model.data_utils_new import MouseDatasetSegNewBehav
import numpy as np
from mouse_model.evaluation import cor_in_time
from sklearn.metrics import r2_score, mean_squared_error
import random
import os
import time
import torch.nn.init as init
from torch.nn import functional as F
from kornia.geometry.transform import get_affine_matrix2d, warp_affine
import matplotlib.pyplot as plt

the model

In [2]:
class Shifter(nn.Module):
    def __init__(self, input_dim=4, output_dim=3, hidden_dim=256, seq_len=8):
        super().__init__()
        self.seq_len = seq_len
        self.input_dim = input_dim
        self.output_dim = output_dim
        
        self.layers = nn.Sequential(
            nn.BatchNorm1d(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh(),
        )
        self.bias = nn.Parameter(torch.zeros(3))
    def forward(self, x):
        x = x.reshape(-1,self.input_dim )
        x = self.layers(x)
        x0 = (x[...,0] + self.bias[0]) * 80/5.5
        x1 = (x[...,1] + self.bias[1]) * 60/5.5
        x2 = (x[...,2] + self.bias[2]) * 180/4
        x = torch.stack([x0, x1, x2], dim=-1)
        x = x.reshape(-1,self.seq_len,self.output_dim)
        return x

In [8]:
class PrintLayer(nn.Module):
    
    def __init__(self):
        super(PrintLayer, self).__init__()
    
    def forward(self, x):
        print(x.shape)
        return x
    
def size_helper(in_length, kernel_size, padding=0, dilation=1, stride=1):
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d
    res = in_length + 2 * padding - dilation * (kernel_size - 1) - 1
    res /= stride
    res += 1
    return np.floor(res)

# CNN, the last fully connected layer maps to output_dim
class VisualEncoder(nn.Module):
    
    def __init__(self, output_dim, input_shape=(60, 80), k1=7, k2=7, k3=7):
        
        super().__init__()
        
        self.input_shape = (60, 80)
        out_shape_0 = size_helper(in_length=input_shape[0], kernel_size=k1, stride=2)
        out_shape_0 = size_helper(in_length=out_shape_0, kernel_size=k2, stride=2)
        out_shape_0 = size_helper(in_length=out_shape_0, kernel_size=k3, stride=2)
        out_shape_1 = size_helper(in_length=input_shape[1], kernel_size=k1, stride=2)
        out_shape_1 = size_helper(in_length=out_shape_1, kernel_size=k2, stride=2)
        out_shape_1 = size_helper(in_length=out_shape_1, kernel_size=k3, stride=2)
        self.output_shape = (int(out_shape_0), int(out_shape_1)) # shape of the final feature map
        
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=128, kernel_size=k1, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=k2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=k3, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Flatten(),
            nn.Linear(480, output_dim)
        )
        
    def forward(self, x):

        x = self.layers(x)

        return x

    
# may consider adding an activation after linear
class BehavEncoder(nn.Module):
    
    def __init__(self, behav_dim, output_dim):
        
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.BatchNorm1d(behav_dim),
            nn.Linear(behav_dim, output_dim),
        )
        
    def forward(self, x):

        x = self.layers(x)

        return x

class LSTMPerNeuronCombiner(nn.Module):
    
    def __init__(self, num_neurons, behav_dim, k1, k2, k3, seq_len, hidden_size=512):
        
        super().__init__()
        
        self.seq_len = seq_len
        self.behav_dim = behav_dim
        self.num_neurons = num_neurons
        self.shifter = Shifter(seq_len = seq_len)
        self.visual_encoder = VisualEncoder(output_dim=num_neurons, k1=k1, k2=k2, k3=k3)
        self.behav_encoder = BehavEncoder(behav_dim=behav_dim, output_dim=num_neurons)
        self.bn = nn.BatchNorm1d(3) # apply bn to vis_feats, beh_feats, prod
        self.lstm_net = nn.GRU(input_size=num_neurons*3, hidden_size=hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_neurons)
        self.softplus = nn.Softplus() # we could also do relu or elu offset by 1
        
    def forward(self, images, behav):
        if args.shifter:
            bs = images.size()[0]
            behav_shifter = torch.concat((behav[...,4].unsqueeze(-1),   # theta
                                          behav[...,3].unsqueeze(-1),   # phi
                                          behav[...,1].unsqueeze(-1),  # pitch
                                         behav[...,2].unsqueeze(-1),  # roll
                                         ), dim=-1)  
            shift_param = self.shifter(behav_shifter)  
            shift_param = shift_param.reshape(-1,3)
            scale_param = torch.ones_like(shift_param[..., 0:2]).to(shift_param.device)
            affine_mat = get_affine_matrix2d(
                                            translations=shift_param[..., 0:2] ,
                                             scale = scale_param, 
                                             center =torch.repeat_interleave(torch.tensor([[30,40]], dtype=torch.float), 
                                                                            bs*self.seq_len, dim=0).to(shift_param.device), 
                                             angle=shift_param[..., 2])
            affine_mat = affine_mat[:, :2, :]
            images = warp_affine(images.reshape(-1,1,60,80), affine_mat, dsize=(60,80)).reshape(bs, self.seq_len,1,60,80)

        if self.behav_dim == 3:
            behav = torch.concat((behav[...,0].unsqueeze(-1),  
                                behav[...,5].unsqueeze(-1), 
                                behav[...,-1].unsqueeze(-1),
                            ), dim=-1)  
        # get visual behavioral features in time
        vis_beh_feats = []
        for i in range(self.seq_len):
            v = self.visual_encoder(images[:, i, :, :, :])
            b = self.behav_encoder(behav[:, i, :])
            vb = v * b
            vis_beh_feat = torch.stack([v, b, vb], axis=1)
            vis_beh_feat = self.bn(vis_beh_feat)
            vis_beh_feats.append(vis_beh_feat)
        vis_beh_feats = torch.stack(vis_beh_feats, axis=1)
        
        # flatten features to (batch_size, seq_len, num_neurons*3)
        vis_beh_feats = torch.flatten(vis_beh_feats, start_dim=2)
        
        # get LSTM output
        output, _ = self.lstm_net(vis_beh_feats)
        output = output[:, -1, :] # extract the last hidden state
        
        # fully connected layer and activation function
        output = self.fc(output)
        pred_spikes = self.softplus(output)

        return pred_spikes

In [9]:
class Args:
    
    seed = 0
    file_id = None
    epochs = 50
    batch_size = 256
    learning_rate = 0.0002
    l1_reg_w = 1
    seq_len = None
    num_neurons = None
    behav_mode = None
    behav_dim = None
    best_val_path = None
    best_train_path = None
    vid_type = "vid_mean"
    segment_num = 10
    hidden_size = 512
    shifter = True
    
args=Args()

seed = args.seed
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

torch.cuda.empty_cache()
print(torch.cuda.is_available())

True


In [10]:
def load_train_val_ds():
    ds_list = [MouseDatasetSegNewBehav(file_id=args.file_id, segment_num=args.segment_num, seg_idx=i, data_split="train", 
                               vid_type=args.vid_type, seq_len=args.seq_len, predict_offset=1, 
                                       behav_mode=args.behav_mode, norm_mode="01") 
               for i in range(args.segment_num)]
    train_ds, val_ds = [], []
    for ds in ds_list:
        train_ratio = 0.8
        train_ds_len = int(len(ds) * train_ratio)
        train_ds.append(Subset(ds, np.arange(0, train_ds_len, 1)))
        val_ds.append(Subset(ds, np.arange(train_ds_len, len(ds), 1)))
    train_ds = ConcatDataset(train_ds)
    val_ds = ConcatDataset(val_ds)
    print(len(train_ds), len(val_ds))
    return train_ds, val_ds

In [11]:
def load_test_ds():
    test_ds = [MouseDatasetSegNewBehav(file_id=args.file_id, segment_num=args.segment_num, seg_idx=i, data_split="test", 
                               vid_type=args.vid_type, seq_len=args.seq_len, predict_offset=1, 
                                       behav_mode=args.behav_mode, norm_mode="01") 
               for i in range(args.segment_num)]
    test_ds = ConcatDataset(test_ds)
    return test_ds

In [12]:
def train_model():
    
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
    
    train_ds, val_ds = load_train_val_ds()

    train_dataloader = DataLoader(dataset=train_ds, batch_size=args.batch_size, shuffle=True, num_workers=8)
    val_dataloader = DataLoader(dataset=val_ds, batch_size=args.batch_size, shuffle=False, num_workers=8)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    best_train_spike_loss = np.inf
    best_val_spike_loss = np.inf
    train_loss_list = []
    val_loss_list = []

    # start training
    ct = 0
    
    for epoch in range(args.epochs):

        print("Start epoch", epoch)

        model.train()

        epoch_train_loss, epoch_train_spike_loss = 0, 0

        for (image, behav, spikes) in train_dataloader:

            image, behav, spikes = image.to(device), behav.to(device), spikes.to(device)
            
            pred = model(image, behav)

            spike_loss = nn.functional.poisson_nll_loss(pred, spikes, reduction='mean', log_input=False)
            
            l1_reg, l1_reg_num_param = 0.0, 0
            for name, param in model.named_parameters():
                if name == "behav_encoder.layers.1.weight":
                    l1_reg += param.abs().sum()
                    l1_reg_num_param += param.shape[0]*param.shape[1]
            l1_reg /= l1_reg_num_param

            total_loss = spike_loss + args.l1_reg_w * l1_reg

            epoch_train_loss += total_loss.item()
            epoch_train_spike_loss += spike_loss.item()

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        epoch_train_loss = epoch_train_loss / len(train_dataloader)
        epoch_train_spike_loss = epoch_train_spike_loss / len(train_dataloader)

        train_loss_list.append(epoch_train_loss)
        
        print("Epoch {} train loss: {}".format(epoch, epoch_train_loss))

        if epoch_train_spike_loss < best_train_spike_loss:

            print("save train model at epoch", epoch)
            torch.save(model.state_dict(), args.best_train_path)
            best_train_spike_loss = epoch_train_spike_loss

        model.eval()

        epoch_val_spike_loss = 0

        with torch.no_grad():      

            for (image, behav, spikes) in val_dataloader:

                image, behav, spikes = image.to(device), behav.to(device), spikes.to(device)

                pred = model(image, behav)

                loss = nn.functional.poisson_nll_loss(pred, spikes, reduction='mean', log_input=False)

                epoch_val_spike_loss += loss.item()

        epoch_val_spike_loss = epoch_val_spike_loss / len(val_dataloader)

        val_loss_list.append(epoch_val_spike_loss)
        
        print("Epoch {} val loss: {}".format(epoch, epoch_val_spike_loss))
        
        if epoch_val_spike_loss < best_val_spike_loss:
            ct = 0

            print("save val model at epoch", epoch)
            torch.save(model.state_dict(), args.best_val_path)
            best_val_spike_loss = epoch_val_spike_loss
        else:
            ct += 1
            if ct >=5:
                print('stop training')
                break

        print("End epoch", epoch)
        
    return train_loss_list, val_loss_list

In [13]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')    

for file_id, num_neurons in  [("070921_J553RT", 68), ("101521_J559NC", 49) , ("110421_J569LT", 32)]:
    for behav_mode, behav_dim in [("all", 3)]:  
    # pass all behav variables
    # take 4 (head&eye movement) for shifter
    # take 3 (the sensorium original) variable as model main component's input
        for seq_len in range(1, 2): 
            print(file_id, behav_mode, seq_len)
            
            args.file_id = file_id
            args.vid_type = "vid_mean"
            args.num_neurons = num_neurons
            args.shifter=True

            args.behav_mode = behav_mode
            args.behav_dim = behav_dim
            
            args.seq_len = seq_len

            args.best_train_path = "/hdd/yuchen/train_baseline_{}_{}_seq_{}.pth".format(
                args.file_id,  'sens_orig', args.seq_len)
            args.best_val_path = "/hdd/yuchen/val_baseline_{}_{}_seq_{}.pth".format(
                args.file_id, 'sens_orig' , args.seq_len)

            model = LSTMPerNeuronCombiner(num_neurons=args.num_neurons, 
                                          behav_dim=args.behav_dim, 
                                          k1=7, k2=7, k3=7, 
                                          seq_len=args.seq_len,
                                          hidden_size=args.hidden_size).to(device)

            train_loss_list, val_loss_list = train_model()

070921_J553RT all 1
30120 7540
Start epoch 0
Epoch 0 train loss: 0.9536194478051138
save train model at epoch 0
Epoch 0 val loss: 0.657522960503896
save val model at epoch 0
End epoch 0
Start epoch 1
Epoch 1 train loss: 0.8925990654250323
save train model at epoch 1
Epoch 1 val loss: 0.6512650827566783
save val model at epoch 1
End epoch 1
Start epoch 2
Epoch 2 train loss: 0.8599661346209251
save train model at epoch 2
Epoch 2 val loss: 0.639069265127182
save val model at epoch 2
End epoch 2
Start epoch 3
Epoch 3 train loss: 0.8299451116788186
save train model at epoch 3
Epoch 3 val loss: 0.6344329416751862
save val model at epoch 3
End epoch 3
Start epoch 4
Epoch 4 train loss: 0.8037741032697386
save train model at epoch 4
Epoch 4 val loss: 0.6259421209494272
save val model at epoch 4
End epoch 4
Start epoch 5
Epoch 5 train loss: 0.7812813040563615
save train model at epoch 5
Epoch 5 val loss: 0.6260844190915426
End epoch 5
Start epoch 6
Epoch 6 train loss: 0.760532091734773
save trai

In [None]:
model

eval

In [14]:
# default is smoothing with 2 second, 48 ms per frame
def smoothing_with_np_conv(nsp, size=int(2000/48)):
    np_conv_res = []
    for i in range(nsp.shape[1]):
        np_conv_res.append(np.convolve(nsp[:, i], np.ones(size)/size, mode="same"))        
    np_conv_res = np.transpose(np.array(np_conv_res))
    return np_conv_res

In [15]:
def evaluate_model(model, weights_path, dataset, device):

    dl = DataLoader(dataset=dataset, batch_size=256, shuffle=False, num_workers=4)
    
    model.load_state_dict(torch.load(weights_path))

    ground_truth_all = []
    pred_all = []
    
    model.eval()
    
    with torch.no_grad():      
        
        for (image, behav, spikes) in dl:
            
            image = image.to(device)
            behav = behav.to(device)
            
            pred = model(image, behav)
            
            ground_truth_all.append(spikes.numpy())
            pred_all.append(pred.cpu().numpy())
    
    return np.concatenate(pred_all, axis=0), np.concatenate(ground_truth_all, axis=0)

In [16]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')    

# for file_id, num_neurons in [("070921_J553RT", 68)]:
for file_id, num_neurons in  [("070921_J553RT", 68), ("101521_J559NC", 49) , ("110421_J569LT", 32)]:
    

    for behav_mode, behav_dim in [("all", 3)]:  
    # pass all behav variables
    # take 4 (head&eye movement) for shifter
    # take 3 (the sensorium original) variable as model main component's input

        for seq_len in range(1, 2): 

            print(file_id, behav_mode, seq_len)
            
            args.file_id = file_id
            args.num_neurons = num_neurons

            args.behav_mode = behav_mode
            args.behav_dim = behav_dim
            
            args.seq_len = seq_len

            args.best_train_path = "/hdd/yuchen/train_baseline_{}_{}_seq_{}.pth".format(
                args.file_id,  'sens_orig', args.seq_len)
            args.best_val_path = "/hdd/yuchen/val_baseline_{}_{}_seq_{}.pth".format(
                args.file_id, 'sens_orig' , args.seq_len)

            model = LSTMPerNeuronCombiner(num_neurons=args.num_neurons, 
                                          behav_dim=args.behav_dim, 
                                          k1=7, k2=7, k3=7, 
                                          seq_len=args.seq_len,
                                          hidden_size=args.hidden_size).to(device)
            
            train_ds, val_ds = load_train_val_ds()
            test_ds = load_test_ds()
            
            pred, label = evaluate_model(model, weights_path=args.best_val_path, dataset=test_ds, device=device)
            cor_array = cor_in_time(pred, label)
        #     print("best val model on test dataset, {:.3f}+-{:.3f}".format(np.mean(cor_array), np.std(cor_array)))
            pred = smoothing_with_np_conv(pred)
            label = smoothing_with_np_conv(label)
            # print("R2", "{:.6f}".format(r2_score(label.T, pred.T)))
            print("MSE", "{:.6f}".format(mean_squared_error(label, pred)))
            cor_array = cor_in_time(pred, label)
            print("mean corr, {:.3f}+-{:.3f}".format(np.mean(cor_array), np.std(cor_array)))
            # print("max corr", "{:.6f}".format(np.max(cor_array)))
            # print("min corr", "{:.6f}".format(np.min(cor_array)))

070921_J553RT all 1
30120 7540
MSE 0.055057
mean corr, 0.623+-0.142
101521_J559NC all 1
42410 10610
MSE 0.082836
mean corr, 0.579+-0.135
110421_J569LT all 1
32940 8240
MSE 0.091109
mean corr, 0.498+-0.154
