In [1]:
import sys
sys.path.append('../code')

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from sklearn.model_selection import ShuffleSplit
import pickle

# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0")
#device = torch.device('cpu')

torch.backends.cudnn.benchmark = True

import contrastive_functions





In [2]:
data_dict = contrastive_functions.get_marker_decode_dataframes()
wrist_df = data_dict['wrist_df']
task_neural_df = data_dict['task_neural_df']
notask_neural_df = data_dict['notask_neural_df']
metadata = data_dict['metadata']

In [3]:
trial_ids = task_neural_df['trial'].unique()
num_trials_filtered = len(trial_ids)

#Generate cv_dict for regular train/test/validate split
cv_split = ShuffleSplit(n_splits=5, test_size=.25, random_state=3)
val_split = ShuffleSplit(n_splits=1, test_size=.25, random_state=3)
cv_dict = {}
for fold, (train_val_idx, test_idx) in enumerate(cv_split.split(trial_ids)):
    for t_idx, v_idx in val_split.split(train_val_idx): #No looping, just used to split train/validation sets
        cv_dict[fold] = {'train_idx':trial_ids[train_val_idx[t_idx]], 
                        'test_idx':trial_ids[test_idx], 
                        'validation_idx':trial_ids[train_val_idx[v_idx]]} 

In [4]:
neural_offset = 5 # try 50-150 ms offset
window_size = 70
label_col = 'layout'
func_dict = {'wiener': contrastive_functions.run_wiener, 'rnn': contrastive_functions.run_rnn}

fpath = '../data/SPK20220308/neuron_num_results/'

num_repeats = 5
# num_neuron_list = np.arange(2,51,4)
num_neuron_list = [2,6,10]


num_neuron_results_dict = {'num_neuron_list': num_neuron_list}
for repeat_idx in range(num_repeats):
    rng = np.random.default_rng(repeat_idx) # new set of shuffled neurons seeded by repeat_idx
    random_units = rng.choice(range(85), size=85).astype(str)

    num_neuron_results_dict[f'repeat_{repeat_idx}'] = {'random_units': random_units}
    for num_neurons in num_neuron_list:

        # Filter neural_df with task info to random subset of neurons
        task_unit_mask = np.in1d(task_neural_df['unit'].values, random_units[:num_neurons])
        layout_mask = task_neural_df['unit'].str.contains(pat='layout')

        task_neural_df_filtered = task_neural_df[np.logical_or.reduce([task_unit_mask, layout_mask])].reset_index(drop=True)

        # Filter neural_df without task info to random subset of neurons
        notask_unit_mask = np.in1d(notask_neural_df['unit'].values, random_units[:num_neurons])
        notask_neural_df_filtered = notask_neural_df[np.logical_or.reduce([notask_unit_mask])].reset_index(drop=True)

        df_dict = {'task': {'df': task_neural_df_filtered, 'task_info': True, 'num_cat': 4}, # num_cat = number of categorical features
                   'notask': {'df': notask_neural_df_filtered, 'task_info': False, 'num_cat': 0}}
        

        decode_results = dict()
        for func_name, func in func_dict.items():
            decode_results[func_name] = dict()
            for df_type, pred_df in df_dict.items():
                model, res_dict = func(wrist_df, pred_df['df'], neural_offset, cv_dict, metadata, task_info=pred_df['task_info'],
                                       window_size=window_size, num_cat=pred_df['num_cat'], label_col=label_col)

                decode_results[func_name][df_type] = res_dict

                # Save results on every loop in case early stop
                num_neuron_results_dict[f'repeat_{repeat_idx}'][f'num_neuron_{num_neurons}'] = decode_results
                #Save metadata
                output = open(f'{fpath}num_neuron_results.pkl', 'wb')
                pickle.dump(num_neuron_results_dict, output)
                output.close()

                if func_name == 'rnn':
                    torch.save(model.state_dict(), f'{fpath}models/{df_type}_neurons{num_neurons}_repeat{repeat_idx}.pt')





**********
Epoch: 10/1000 ... Train Loss: 1.4042  ... Validation Loss: 1.3218
*..**..*..
Epoch: 20/1000 ... Train Loss: 1.2944  ... Validation Loss: 1.2775
...*......
Epoch: 30/1000 ... Train Loss: 1.2592  ... Validation Loss: 1.2678
 Early Stop; Min Epoch: 24
**********
Epoch: 10/1000 ... Train Loss: 0.3868  ... Validation Loss: 0.2988
*.**.*****
Epoch: 20/1000 ... Train Loss: 0.2594  ... Validation Loss: 0.2403
.**..**...
Epoch: 30/1000 ... Train Loss: 0.2077  ... Validation Loss: 0.2239
**...... Early Stop; Min Epoch: 32
**********
Epoch: 10/1000 ... Train Loss: 1.1386  ... Validation Loss: 1.1898
...... Early Stop; Min Epoch: 10
**********
Epoch: 10/1000 ... Train Loss: 0.1518  ... Validation Loss: 0.1854
*...*...*.
Epoch: 20/1000 ... Train Loss: 0.1035  ... Validation Loss: 0.1793
*.....*...
Epoch: 30/1000 ... Train Loss: 0.0804  ... Validation Loss: 0.1816
... Early Stop; Min Epoch: 27
*********.
Epoch: 10/1000 ... Train Loss: 1.1335  ... Validation Loss: 1.2239
..... Early Stop;

In [9]:
task_neural_df_filtered

Unnamed: 0,rates,rates_video,unit,trial,layout,position,count
0,"[0.37068460707699996, 0.20849023637348546, 0.0...","[1.2519973356039417e-16, 7.502509854418641e-17...",5,0,1,4,288.134765
1,"[0.02535979805162989, 0.07479052105641254, 0.0...","[0.0, 2.250752956325592e-16, 1.636289717497565...",6,0,1,4,422.700333
2,"[-0.7415819161611027, -0.28529480749897473, 0....","[0.0, 0.0, 0.0, 0.0, 1.66289418077035e-15, 2.5...",7,0,1,4,727.621281
3,"[0.3378164474986023, 0.2633298055962907, 0.154...","[0.0, 0.0, 0.0, 1.351476007222962e-15, 0.0, 0....",11,0,1,4,2738.257947
4,"[4.7903802490916245, 10.012975291719446, 14.57...","[10.67262691202719, 31.225806882942088, 26.076...",14,0,1,4,1555.370019
...,...,...,...,...,...,...,...
8597,"[-3.0800097490467784, -0.5617975973316369, 1.5...","[0.0, 0.0, 6.487920009286787, 12.0518604459026...",83,217,2,2,1140.146988
8598,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",layout_1,217,2,2,0.000000
8599,"[1.0000000000000004, 1.0000000000000004, 1.000...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",layout_2,217,2,2,83.000000
8600,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",layout_3,217,2,2,0.000000
