In [1]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

import movi
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from collections import OrderedDict
import itertools
from functools import partial
import math

In [2]:
# NOTE(review): hard-coded absolute path; presumably the repository root, so
# that the local `mouse_model` package imported in a later cell can be found.
%cd /home/yuchen/2023-neurips-mouse

/home/yuchen/2023-neurips-mouse


In [6]:
from torch.utils.data import Subset, DataLoader, ConcatDataset
from mouse_model.data_utils import MouseDatasetSeg
from mouse_model.evaluation import cor_in_time


def load_train_val_ds(train_ratio=0.8):
    """Build chronological train/validation datasets from the training segments.

    For each of the first ``args.split_range`` training segments of
    ``args.file_id`` (read from the global ``args``), the first ``train_ratio``
    fraction of items becomes training data and the remainder validation data.

    NOTE(review): this uses ``vid_type="vid_shift_mean"`` and
    ``behav_prod=True``, while ``load_test_ds`` uses ``behav_prod=False``.
    With the ``MouseDatasetSeg`` defined in this notebook, the 'behavior' /
    'pupil_center' fields therefore differ between train/val and test items.
    Confirm this is intended.

    Args:
        train_ratio: fraction of each segment used for training (previously
            hard-coded as 0.8; the default keeps that value).

    Returns:
        ``(train_ds, val_ds)`` as ``ConcatDataset`` objects.
    """
    ds_list = [MouseDatasetSeg(file_id=args.file_id, seg_idx=i, data_split="train", vid_type="vid_shift_mean",
                               seq_len=args.seq_len, predict_offset=1, smoothing=None, behav_prod=True)
               for i in range(args.split_range)]
    train_parts, val_parts = [], []
    for ds in ds_list:
        # Contiguous split: the earliest items train, the latest validate.
        n_train = int(len(ds) * train_ratio)
        train_parts.append(Subset(ds, np.arange(0, n_train, 1)))
        val_parts.append(Subset(ds, np.arange(n_train, len(ds), 1)))
    return ConcatDataset(train_parts), ConcatDataset(val_parts)

def load_test_ds():
    """Concatenate the test splits of the first ``args.split_range`` segments.

    Reads ``file_id``, ``seq_len`` and ``split_range`` from the global
    ``args``; always uses the "vid_shift_mean" video with behavioural
    products disabled.
    """
    segments = []
    for seg_idx in range(args.split_range):
        segment = MouseDatasetSeg(file_id=args.file_id, seg_idx=seg_idx, data_split="test",
                                  vid_type="vid_shift_mean", seq_len=args.seq_len,
                                  predict_offset=1, smoothing=None, behav_prod=False)
        segments.append(segment)
    return ConcatDataset(segments)

def normalize_movie(movie):
    """Rescale a movie's gray levels linearly onto [0, 1].

    The minimum value maps to 0 and the maximum to 1. If the movie is constant
    (range close to zero) the shifted, all-zero movie is returned without
    dividing. The input array is left untouched; a float64 copy is returned.
    """
    rescaled = movie.astype(float)
    rescaled = rescaled - rescaled.min()
    peak = rescaled.max()
    if np.isclose(peak, 0):
        return rescaled
    return rescaled / peak

def normalize_movie_neg_pos_1(movie):
    """Rescale a movie's gray levels linearly onto [-1, 1].

    The movie is first mapped onto [0, 1] with ``normalize_movie`` and then
    shifted and stretched, so the minimum becomes -1 and the maximum +1.
    A constant movie comes out as all -1, because ``normalize_movie`` returns
    zeros for it.
    """
    unit_range = normalize_movie(movie)
    centred = unit_range - 0.5
    return centred * 2

def smoothing_with_np_conv(nsp, size=int(2000/48), return_fr=False, bin_size=0.016):
    """Smooth each column of ``nsp`` with a centred boxcar moving average.

    Every column is convolved with ``np.ones(size) / size`` using zero padding
    at the edges. The centring is the same as ``np.convolve(..., mode="same")``.
    Unlike ``mode="same"``, the output always has exactly ``nsp.shape[0]`` rows,
    even when ``size`` is larger than the number of rows.

    Args:
        nsp: 2-D array whose columns are smoothed independently (rows are
            presumably time bins).
        size: window length in rows. The default ``int(2000/48)`` (= 41) is
            presumably a ~2 s window of 48 ms bins.
        return_fr: if True, divide the smoothed values by ``bin_size``.
        bin_size: divisor used when ``return_fr`` is True. The default 0.016
            is the value that used to be hard-coded.
            NOTE(review): 0.016 s does not match the 48 ms bins suggested by
            the default ``size``; confirm the intended bin width.

    Returns:
        Float array with the same shape as ``nsp``.
    """
    nsp = np.asarray(nsp)
    kernel = np.ones(size) / size
    n_rows = nsp.shape[0]
    # Offset of np.convolve's "same" window inside the "full" result. Slicing
    # "full" ourselves keeps exactly n_rows outputs even when size > n_rows
    # ("same" would return max(size, n_rows) values).
    start = (size - 1) // 2
    smoothed = np.empty(nsp.shape, dtype=np.result_type(nsp, kernel))
    for col in range(nsp.shape[1]):
        full = np.convolve(nsp[:, col], kernel, mode="full")
        smoothed[:, col] = full[start:start + n_rows]
    if return_fr:
        return smoothed / bin_size
    return smoothed

# https://stackoverflow.com/questions/52201081/compute-product-of-all-combinations-of-a-list
def get_all_comb_prod(array):
    """Return the products of every non-empty combination of ``array``'s elements.

    Combinations are enumerated by increasing size r = 1 .. len(array). Within
    each size they follow ``itertools.combinations`` order, so ``[a, b, c]``
    yields ``[a, b, c, a*b, a*c, b*c, a*b*c]`` (2**n - 1 values). Each product
    is computed with ``np.prod``. An empty input gives an empty list.
    """
    return [
        np.prod(combo)
        for r in range(1, len(array) + 1)
        for combo in itertools.combinations(array, r)
    ]

class MouseDatasetSeg(Dataset):
    """Sliding-window dataset over one pre-split segment of a recording.

    For ``file_id`` / ``data_split`` / ``seg_idx`` it loads:
      * neural responses ``{split}_nsp_seg_{idx}.npy`` (time x neuron), with
        the neurons listed in ``_BAD_NEURON_INDEX[file_id]`` removed;
      * the video ``{split}_{vid_type}_seg_{idx}.npy``, rescaled to [0, 1]
        with ``normalize_movie`` and given a singleton axis at position 1;
      * the behavioural variables eyerad, speed, th and phi.

    Pupil radius is converted to area (pi * r**2), and its bin-to-bin change
    is inserted as an extra variable. ``self.behavior_var`` therefore has the
    columns [area, d_area, speed, th, phi], each divided by its standard
    deviation.

    Item ``idx`` covers frames ``idx .. idx + seq_len - 1`` and targets the
    responses ``predict_offset`` bins after the last frame. ``__getitem__``
    moves its tensors to CUDA, so a GPU is required to draw items.

    ``smoothing`` is accepted for interface compatibility but is unused.
    The data root is read from the module-level ``ROOT_DIR_48ms_3_segment``.
    NOTE: this class shadows the ``MouseDatasetSeg`` imported from
    ``mouse_model.data_utils`` earlier in the notebook.
    """

    # Neurons excluded per recording (firing rate below 3 Hz).
    _BAD_NEURON_INDEX = {"070921_J553RT": [1,  20,  21,  25,  27,  30,  31,  35,  39,  45,  46,  49,  50,
                                           51,  52,  53,  54,  55,  57,  58,  59,  64,  65,  68,  75,  77,
                                           79,  81,  83,  84,  87,  88,  90,  91,  94,  95,  97, 104, 105, 107],
                         "101521_J559NC": [1,  4,  9, 22, 25, 27, 28, 29, 43, 45, 46, 47, 51, 53],
                         "110421_J569LT": [0,  1,  2,  5,  6,  8, 11, 12, 13, 16, 18, 22, 24, 26, 28, 30, 41,
                                           44, 47, 49]}

    def __init__(self, file_id, seg_idx, data_split="train", vid_type="vid_shift_mean",
                 seq_len=1, predict_offset=1, smoothing=None, behav_prod=False):

        self.seq_len = seq_len
        self.predict_offset = predict_offset
        self.behav_prod = behav_prod

        data_dir = "{}/{}".format(ROOT_DIR_48ms_3_segment, file_id)
        print(data_dir)

        self.nsp = self._load_good_neurons(data_dir, data_split, seg_idx, file_id)

        self.images = np.load("{}/{}_{}_seg_{}.npy".format(data_dir, data_split, vid_type, seg_idx))
        self.images = normalize_movie(self.images)
        self.images = np.expand_dims(self.images, axis=1)

        self.behavior_var = self._load_behavior(data_dir, data_split, seg_idx)

    @classmethod
    def _load_good_neurons(cls, data_dir, data_split, seg_idx, file_id):
        """Load the (time, neuron) response matrix and drop the bad neurons."""
        all_nsp = np.load("{}/{}_nsp_seg_{}.npy".format(data_dir, data_split, seg_idx))
        bad = set(cls._BAD_NEURON_INDEX[file_id])
        keep = [i for i in range(all_nsp.shape[1]) if i not in bad]
        return all_nsp[:, keep]

    @staticmethod
    def _load_behavior(data_dir, data_split, seg_idx):
        """Load, derive and scale the behavioural variables.

        Returns the variables stacked column-wise in the order
        [area, d_area, speed, th, phi].
        """
        # Full list of available keys: ['speed', 'gz', 'pitch', 'roll', 'phi', 'th', 'eyerad']
        behav_key_list = ['eyerad', 'speed', 'th', 'phi']
        behavior_var_list = [np.load("{}/{}_{}_seg_{}.npy".format(data_dir, data_split, behav_key, seg_idx))
                             for behav_key in behav_key_list]

        # Pupil radius -> pupil area; NaN radii are treated as 0.
        radius = [0 if math.isnan(r) else r for r in behav_list_first(behavior_var_list)] if False else \
            [0 if math.isnan(r) else r for r in behavior_var_list[0]]
        area = [(r ** 2) * np.pi for r in radius]
        # Bin-to-bin change in area; the first bin has no predecessor -> 0.
        # (The /48 scale is cancelled by the std normalisation below.)
        d_area = [0 if i == 0 else (area[i] - area[i - 1]) / 48 for i in range(len(area))]
        behavior_var_list[0] = area
        behavior_var_list.insert(1, d_area)

        scaled = []
        for values in behavior_var_list:
            values = np.asarray([0 if math.isnan(x) else x for x in values])
            std = np.std(values)
            # A constant variable has std 0. Dividing would give inf/NaN, which
            # nan_to_num turns into ~1.8e308 garbage, so leave it unscaled.
            scaled.append(values / std if std > 0 else values.astype(float))

        return np.nan_to_num(np.stack(scaled, axis=1))

    def __len__(self):
        # Each item needs seq_len input frames plus predict_offset bins ahead.
        len_block = self.seq_len + self.predict_offset
        return self.images.shape[0] - len_block + 1

    def __getitem__(self, idx):
        current_frame = self.images[idx:(idx + self.seq_len)]

        if self.behav_prod:
            # Per frame: drop d_area (column 1) and expand the remaining four
            # variables into the products of all their non-empty combinations.
            current_behavior_var = torch.tensor(
                [get_all_comb_prod(np.delete(self.behavior_var[i], 1))
                 for i in range(idx, idx + self.seq_len)],
                dtype=torch.float)
        else:
            current_behavior_var = torch.tensor(self.behavior_var[idx:(idx + self.seq_len)], dtype=torch.float)

        current_behavior_var = current_behavior_var.squeeze().cuda()
        neural_spikes = torch.tensor(self.nsp[idx + self.seq_len - 1 + self.predict_offset],
                                     dtype=torch.float).cuda()

        # Vision + behaviour input: the frame plus one constant plane per each
        # of the first three behavioural values.
        # NOTE: squeeze() assumes seq_len == 1, so the frame becomes 2-D.
        current_frame = current_frame.squeeze()
        behavior_planes = [np.full(current_frame.shape, current_behavior_var[c].item()) for c in range(3)]
        current_frame = torch.tensor(np.array([current_frame] + behavior_planes), dtype=torch.float).cuda()
        return {"images": current_frame, "responses": neural_spikes,
                'behavior': current_behavior_var[:3], 'pupil_center': current_behavior_var[3:]}

        # Vision-only variant:
        # current_frame = np.array([current_frame])
        # current_frame = torch.tensor(current_frame, dtype=torch.float).cuda()
        # return {"images": current_frame, "responses": neural_spikes, 'pupil_center':current_behavior_var[3:]}
    
    
def get_dataloaders_one_file(train_ratio=0.8, batch_size=256, n_test_segments=10):
    """Build nnfabrik-style train/validation/test dataloaders for one recording.

    Reads ``file_id``, ``vid_ty``, ``seq_len`` and ``split_range`` from the
    global ``args``. Each of the first ``args.split_range`` training segments
    is split contiguously: the first ``train_ratio`` fraction of its items
    trains, and the rest validates. The test set is built from the first
    ``n_test_segments`` test segments, independently of ``args.split_range``.

    Args:
        train_ratio: fraction of each training segment used for training.
        batch_size: batch size of all three loaders.
        n_test_segments: number of test segments to load. The default 10 is
            the previously hard-coded value.

    Returns:
        ``OrderedDict`` mapping 'train' / 'validation' / 'test' to an
        ``OrderedDict`` ``{args.file_id: DataLoader}``. Only the training
        loader shuffles.
    """
    ds_list = [MouseDatasetSeg(file_id=args.file_id, seg_idx=i, data_split="train", vid_type=args.vid_ty,
                               seq_len=args.seq_len, predict_offset=1, smoothing=None, behav_prod=False)
               for i in range(args.split_range)]
    print('current split_range is: ' + str(args.split_range))
    print('current vid_type is: ' + str(args.vid_ty))

    train_parts, val_parts = [], []
    for ds in ds_list:
        # Use the caller's train_ratio; it was previously shadowed here by a
        # hard-coded 0.8.
        n_train = int(len(ds) * train_ratio)
        train_parts.append(Subset(ds, np.arange(0, n_train, 1)))
        val_parts.append(Subset(ds, np.arange(n_train, len(ds), 1)))
    train_ds = ConcatDataset(train_parts)
    val_ds = ConcatDataset(val_parts)

    test_ds = ConcatDataset([MouseDatasetSeg(file_id=args.file_id, seg_idx=i, data_split="test",
                                             vid_type=args.vid_ty, seq_len=args.seq_len, predict_offset=1,
                                             smoothing=None, behav_prod=False)
                             for i in range(n_test_segments)])

    dataloaders = OrderedDict()
    dataloaders['train'] = OrderedDict()
    dataloaders['train'][args.file_id] = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    dataloaders['validation'] = OrderedDict()
    dataloaders['validation'][args.file_id] = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    dataloaders['test'] = OrderedDict()
    dataloaders['test'][args.file_id] = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    return dataloaders


Sensorium baseline (stacked core + full Gaussian readout)

In [5]:
# Train one Sensorium baseline per (subject, number of training segments, seed).
ROOT_DIR_48ms_3_segment = "/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms"


class Args:
    """Run configuration, read through the global ``args`` by the data helpers."""
    seed = 0
    file_id = None  # set per subject in the loop below
    epochs = 100
    batch_size = 256
    l1_weight_behav = 6.2591
    l1_weight_comb = 3.9773
    seq_len = 1
    split_range = 10
    vid_ty = 'vid_mean'
    best_val_path = None
    best_train_path = None


model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,
                'stack': -1,
                'layers': 4,
                'input_kern': 9,
                'gamma_input': 6.3831,
                'gamma_readout': 0.0076,
                'hidden_kern': 7,
                'hidden_channels': 64,
                'depth_separable': True,
                'grid_mean_predictor': {'type': None,
                                        'input_dimensions': 2,
                                        'hidden_layers': 1,
                                        'hidden_features': 30,
                                        'final_tanh': True},
                'init_sigma': 0.1,
                'init_mu_range': 0.3,
                'gauss_type': 'full',
                'shifter': False,
                }

trainer_fn = "sensorium.training.standard_trainer"
trainer_config = {'max_iter': 100,
                  'verbose': False,
                  'lr_decay_steps': 4,
                  'avg_loss': False,
                  'lr_init': 0.009,
                  }

# All subjects: ['110421_J569LT', '101521_J559NC', '070921_J553RT']
for subj in ['110421_J569LT']:
    for i in [10]:  # number of training segments -> args.split_range
        for seed in [0]:
            args = Args()
            args.file_id = subj
            # Use the loop values rather than the hard-coded Args defaults.
            args.split_range = i
            args.seed = seed

            print(i)
            print(subj)
            print('seed: ' + str(args.seed))

            dataloaders = get_dataloaders_one_file(batch_size=args.batch_size)

            model = get_model(model_fn=model_fn,
                              model_config=model_config,
                              dataloaders=dataloaders,
                              seed=args.seed)

            trainer = get_trainer(trainer_fn=trainer_fn,
                                  trainer_config=trainer_config)

            validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=args.seed)
            print(validation_score)

            # weight = "/hdd/yuchen/new_data_sensorium_sota_model_" + str(subj) + "_split" + str(i) + '_'+ str(args.vid_ty) +"seed" + str(args.seed) + "head.pth"
            # torch.save(model.state_dict(), weight)

10
110421_J569LT
seed: 0
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
current split_range is: 10
current vid_type is: vid_mean
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J

KeyboardInterrupt: 

In [12]:
# Evaluate the saved "...vision.pth" models on each subject's test split.
# NOTE: `get_correlations` is defined in a cell further down this notebook;
# run that cell first (or move it above) when starting from a fresh kernel.
from sklearn.metrics import r2_score, mean_squared_error


class Args:
    """Run configuration, read through the global ``args`` by the data helpers."""
    seed = 0
    file_id = None  # set per subject in the loop below
    epochs = 100
    batch_size = 256
    l1_weight_behav = 6.2591
    l1_weight_comb = 3.9773
    seq_len = 1
    split_range = 10
    vid_ty = 'vid_mean'
    best_val_path = None
    best_train_path = None


args = Args()

ROOT_DIR_48ms_3_segment = "/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms"

model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
# NOTE(review): shifter=True here, unlike the training cell above; it must
# match the configuration the loaded weights were trained with.
model_config = {'pad_input': False,
                'stack': -1,
                'layers': 4,
                'input_kern': 9,
                'gamma_input': 6.3831,
                'gamma_readout': 0.0076,
                'hidden_kern': 7,
                'hidden_channels': 64,
                'depth_separable': True,
                'grid_mean_predictor': {'type': None,
                                        'input_dimensions': 2,
                                        'hidden_layers': 1,
                                        'hidden_features': 30,
                                        'final_tanh': True},
                'init_sigma': 0.1,
                'init_mu_range': 0.3,
                'gauss_type': 'full',
                'shifter': True,
                }

for subj in ['110421_J569LT', '101521_J559NC', '070921_J553RT']:
    args.file_id = subj
    dataloaders = get_dataloaders_one_file()

    model = get_model(model_fn=model_fn,
                      model_config=model_config,
                      dataloaders=dataloaders,
                      seed=0,)
    pat = "/hdd/yuchen/new_data_sensorium_sota_model_{}_split10_vid_meanseed0vision.pth".format(args.file_id)
    model.load_state_dict(torch.load(pat))

    model.eval()
    _, label, pred = get_correlations(model, dataloaders, tier='test', device="cuda:0",
                                      as_dict=False, per_neuron=False)
    # Smooth predictions and targets identically before scoring.
    pred = smoothing_with_np_conv(pred)
    label = smoothing_with_np_conv(label)
    # NOTE(review): r2_score on the transposes treats neurons as samples and
    # time bins as outputs; confirm this (vs. r2_score(label, pred)) is intended.
    print("R2", "{:.6f}".format(r2_score(label.T, pred.T)))
    print("MSE", "{:.6f}".format(mean_squared_error(label, pred)))
    cor_array = cor_in_time(pred, label)
    print("mean corr, {:.3f}+-{:.3f}".format(np.mean(cor_array), np.std(cor_array)))
    print("max corr", "{:.6f}".format(np.max(cor_array)))
    print("min corr", "{:.6f}".format(np.min(cor_array)))

/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
current split_range is: 10
current vid_type is: vid_mean
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-data-10-segment-split-70-30-48ms/110421_J569LT
/hdd/aiwenxu/mouse-

In [11]:
from neuralpredictors.training import eval_state, device_state
from neuralpredictors.measures.np_functions import corr, fev

def model_predictions(model, dataloader, data_key, device="cpu"):
    """Run ``model`` over ``dataloader`` and collect targets and predictions.

    Batches may be dicts (read via the "images" / "responses" keys) or
    namedtuples whose first two fields are images and responses. The whole
    batch is also forwarded to the model as keyword arguments, so extra fields
    reach it.

    Args:
        model: called as ``model(images, data_key=data_key, **batch)``.
        dataloader: iterable of batches as described above.
        data_key: forwarded to the model.
        device: device the model and images are moved to for inference.

    Returns:
        ``(target, output)`` numpy arrays concatenated along the batch axis.
        If the loader yields no batches, both are empty float32 arrays of
        shape (0,).
    """
    targets, outputs = [], []
    # Enter no_grad / device_state once for the whole pass instead of once per
    # batch. The per-batch version re-entered device_state every iteration,
    # which presumably moved the model between devices each time.
    with torch.no_grad(), device_state(model, device):
        for batch in dataloader:
            if isinstance(batch, dict):
                images, responses = batch["images"], batch["responses"]
                batch_kwargs = batch
            else:
                images, responses = batch[:2]
                batch_kwargs = batch._asdict()
            prediction = model(images.to(device), data_key=data_key, **batch_kwargs)
            outputs.append(prediction.detach().cpu())
            targets.append(responses.detach().cpu())

    if not outputs:
        return torch.empty(0).numpy(), torch.empty(0).numpy()
    # A single concatenation avoids re-copying the accumulated tensors per batch.
    return torch.cat(targets, dim=0).numpy(), torch.cat(outputs, dim=0).numpy()


def get_correlations(
    model, dataloaders, tier=None, device="cpu", as_dict=False, per_neuron=True, **kwargs
):
    """Correlate model predictions with recorded responses for each data key.

    Args:
        model: passed to ``model_predictions``.
        dataloaders: ``{tier: {data_key: DataLoader}}``, or directly
            ``{data_key: DataLoader}`` when ``tier`` is None.
        tier: which tier of ``dataloaders`` to evaluate (e.g. 'test').
        device: device used for inference.
        as_dict: if True, return the per-key correlation dict unchanged.
        per_neuron: when not ``as_dict``, return the horizontally stacked
            correlations (True) or their mean (False).
        **kwargs: accepted for compatibility; unused.

    Returns:
        ``(correlations, target, output)``. ``correlations`` holds the
        axis-0 (column-wise, presumably per-neuron) correlations, with NaNs
        replaced by 0. ``target`` / ``output`` are the arrays of the LAST data
        key iterated, not of all keys.

    Raises:
        ValueError: if the selected tier contains no dataloaders.
    """
    dl = dataloaders[tier] if tier is not None else dataloaders
    if not dl:
        raise ValueError("no dataloaders found for tier {!r}".format(tier))

    correlations = {}
    for data_key, loader in dl.items():
        target, output = model_predictions(
            dataloader=loader, model=model, data_key=data_key, device=device
        )
        key_corr = corr(target, output, axis=0)
        nan_mask = np.isnan(key_corr)
        if np.any(nan_mask):
            warnings.warn(
                "{}% NaNs , NaNs will be set to Zero.".format(nan_mask.mean() * 100)
            )
        key_corr[nan_mask] = 0
        correlations[data_key] = key_corr

    if not as_dict:
        stacked = np.hstack(list(correlations.values()))
        correlations = stacked if per_neuron else np.mean(stacked)
    return correlations, target, output

from sklearn.metrics import r2_score, mean_squared_error
# Score the `model` / `dataloaders` left by an earlier cell (the last ones
# assigned there). Add 'train' and/or 'validation' to the list to score
# those tiers as well.
model.eval()
for tier in ['test']:
    _, label, pred = get_correlations(model, dataloaders, tier=tier, device="cuda:0",
                                      as_dict=False, per_neuron=False)
    # Smooth predictions and targets identically before scoring.
    pred = smoothing_with_np_conv(pred)
    label = smoothing_with_np_conv(label)
    # NOTE(review): r2_score on the transposes treats neurons as samples and
    # time bins as outputs; confirm this (vs. r2_score(label, pred)) is intended.
    print("R2", "{:.6f}".format(r2_score(label.T, pred.T)))
    print("MSE", "{:.6f}".format(mean_squared_error(label, pred)))
    cor_array = cor_in_time(pred, label)
    print("mean corr, {:.3f}+-{:.3f}".format(np.mean(cor_array), np.std(cor_array)))
    print("max corr", "{:.6f}".format(np.max(cor_array)))
    print("min corr", "{:.6f}".format(np.min(cor_array)))

R2 0.358379
MSE 0.100978
mean corr, 0.441+-0.181
max corr 0.769153
min corr 0.027994
