In [1]:
import sys
sys.path.append('../code')
# sys.path.append('../externals/mamba/')

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from sklearn.model_selection import ShuffleSplit
import pickle
import model_utils
from torch import nn
from mamba_ssm import Mamba


# Compute device for PyTorch: use the first CUDA GPU when PyTorch can see one,
# otherwise fall back to the CPU so the rest of the notebook can still run.
# (To force CPU execution, replace the expression below with torch.device('cpu').)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# Ask cuDNN to benchmark its algorithms and reuse the fastest one it finds.
torch.backends.cudnn.benchmark = True


In [2]:
# Fetch the marker-decoding tables for the selected noise fold and unpack them.
noise_fold = 0
data_dict = model_utils.get_marker_decode_dataframes(noise_fold=noise_fold)
wrist_df, task_neural_df, notask_neural_df, metadata, cv_dict = (
    data_dict[key]
    for key in ('wrist_df', 'task_neural_df', 'notask_neural_df', 'metadata', 'cv_dict')
)

# Unit labels are collected before the 'time' rows are removed below.
neuron_list = notask_neural_df['unit'].unique()

# Drop the rows labelled 'time' from each table; the masks stay available as globals.
notask_time_neural_mask = notask_neural_df['unit'].ne('time')
notask_neural_df = notask_neural_df.loc[notask_time_neural_mask]

task_time_neural_mask = task_neural_df['unit'].ne('time')
task_neural_df = task_neural_df.loc[task_time_neural_mask]

wrist_mask = wrist_df['name'].ne('time')
wrist_df = wrist_df.loc[wrist_mask]

In [3]:
# Decoding settings forwarded to run_rnn (and from there to
# model_utils.make_generators) by the experiment loop later in the notebook.
# NOTE(review): the unit of neural_offset (samples vs. ms) is not visible here — confirm.
neural_offset = 10 # try 50-150 ms offset
# Passed as `window_size` to make_generators — presumably the number of time
# steps per input window; TODO confirm against model_utils.
window_size = 70
# Passed as `label_col` to make_generators.
label_col = 'layout'

In [6]:
# Mamba (state-space) architecture for decoding kinematics
class model_mamba(nn.Module):
    """Sequence decoder made of one ``Mamba`` block followed by a linear readout.

    Parameters
    ----------
    input_size : int
        Number of input features. Stored on the instance; when ``cat_features``
        is given, the stored value is reduced by the number of flagged features.
    output_size : int
        Number of output features produced by the linear readout.
    d_model, d_state, d_conv, expand
        Forwarded unchanged to ``Mamba``. ``d_model`` is also the input width of
        the readout layer, so inputs are expected to carry ``d_model`` features
        in their last dimension (assumption based on how this class wires the
        layers — confirm against the ``mamba_ssm`` documentation).
    dropout : float
        Stored on the instance only; no dropout layer is applied by this class.
    device : torch.device
        Device that the ``Mamba`` block and the readout layer are moved to.
    cat_features : array-like of bool or None
        Optional mask flagging categorical input features. It is only used to
        compute ``num_cat_features`` and to adjust ``input_size``; flagged
        features are not routed around the ``Mamba`` block.
    """
    def __init__(self, input_size, output_size, d_model, d_state=16, d_conv=4, expand=2, dropout=0.2, device=device,
                 cat_features=None):
        super(model_mamba, self).__init__()

        # Defining some parameters
        self.device = device
        self.dropout = dropout
        self.cat_features = cat_features
        self.input_size = input_size

        if self.cat_features is not None:
            self.num_cat_features = np.sum(self.cat_features).astype(int)
            self.input_size = self.input_size - self.num_cat_features

        # The readout must exist on every code path: it used to be created only
        # when cat_features was None, so forward() raised AttributeError for any
        # model constructed with a categorical mask.
        self.fc = nn.Linear(in_features=d_model, out_features=output_size).to(device)
        self.mamba = Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand).to(device)

    def forward(self, x):
        """Run ``x`` through the Mamba block and the linear readout.

        Returns
        -------
        tuple
            ``(out, None, None)``. The two trailing ``None`` values are
            placeholders — presumably to match the return arity of the
            recurrent models in ``model_utils``; confirm against the callers.
        """
        out = self.mamba(x)
        out = self.fc(out)
        return out, None, None


In [13]:
def run_rnn(pred_df, neural_df, neural_offset, cv_dict, metadata, task_info=True,
            window_size=50, num_cat=0, label_col=None, flip_outputs=False, temperature=0.1, dropout=0.5):
    """Fit the ``model_utils.model_lstm`` decoder and score it on the train/test splits.

    Datasets and loaders come from ``model_utils.make_generators``; training is
    delegated to ``model_utils.train_validate_model`` with the ``model_utils.mse``
    criterion and an Adam optimizer; predictions are compared with the targets
    through ``model_utils.matrix_corr``.

    Parameters
    ----------
    pred_df, neural_df, neural_offset, cv_dict, metadata
        Forwarded positionally, unchanged, to ``model_utils.make_generators``.
    task_info, num_cat, temperature
        Accepted for call compatibility; the current body does not read them.
    window_size, label_col, flip_outputs
        Forwarded as keyword arguments to ``model_utils.make_generators``.
    dropout
        Dropout value handed to ``model_utils.model_lstm``.

    Returns
    -------
    tuple
        ``(model, res_dict)`` where ``res_dict`` has the keys ``'loss_dict'``,
        ``'train_pred'``, ``'test_pred'``, ``'train_corr'`` and ``'test_corr'``.
    """
    def _last_step_numpy(tensor):
        # Keep only the final index along axis 1 and convert to a NumPy array.
        return tensor[:, -1, :].detach().cpu().numpy()

    # No inputs are flagged for exclusion, and the loss is plain MSE.
    no_exclusions = None
    loss_fn = model_utils.mse

    splits, loaders = model_utils.make_generators(
        pred_df, neural_df, neural_offset, cv_dict, metadata, exclude_neural=no_exclusions,
        window_size=window_size, flip_outputs=flip_outputs, batch_size=1000, label_col=label_col)
    train_split, _val_split, test_split = splits
    train_loader, train_eval_loader, val_loader, test_loader = loaders

    train_inputs = _last_step_numpy(train_split[:][0])
    train_targets = _last_step_numpy(train_split[:][1])

    # The test inputs are extracted like the others, although nothing below reads them.
    test_inputs = _last_step_numpy(test_split[:][0])
    test_targets = _last_step_numpy(test_split[:][1])

    # Hyperparameters
    learning_rate, decay = 1e-3, 1e-4
    hidden_units, num_layers, epoch_cap = 600, 2, 1000
    n_inputs = train_inputs.shape[1]
    n_outputs = train_targets.shape[1]

    decoder = model_utils.model_lstm(
        n_inputs, n_outputs, hidden_units, num_layers, dropout, device,
        cat_features=no_exclusions).to(device)

    adam = torch.optim.Adam(decoder.parameters(), lr=learning_rate, weight_decay=decay)

    # Train, then predict on the training-evaluation and testing loaders.
    loss_history = model_utils.train_validate_model(
        decoder, adam, loss_fn, epoch_cap, train_loader, val_loader, device, 10, 5)

    train_pred = model_utils.evaluate_model(decoder, train_eval_loader, device)
    test_pred = model_utils.evaluate_model(decoder, test_loader, device)

    results = {
        'loss_dict': loss_history,
        'train_pred': train_pred,
        'test_pred': test_pred,
        'train_corr': model_utils.matrix_corr(train_pred, train_targets),
        'test_corr': model_utils.matrix_corr(test_pred, test_targets),
    }
    return decoder, results

In [14]:
# Decoders to run, keyed by name. Only the RNN decoder defined in this notebook is active.
# func_dict = {'wiener': contrastive_functions.run_wiener, 'rnn': contrastive_functions.run_rnn}
func_dict = {'rnn': run_rnn}

# Output directory; only referenced by the disabled save logic at the end of the loop.
fpath = '../data/neuron_num_results/'

# Not read by the active code in this cell.
num_repeats = 1

# Results container for the disabled save logic; currently only records the noise fold.
num_neuron_results_dict = {'noise_fold': noise_fold}

# Keep only the task_neural_df rows whose 'unit' contains 'layout'.
# NOTE: the filtered frame is not used by the active df_dict below.
layout_mask = task_neural_df['unit'].str.contains(pat='layout')
task_neural_df_filtered = task_neural_df[np.logical_or.reduce([layout_mask])].reset_index(drop=True)

# Input configurations to decode from. The two-entry version is disabled; only
# the 'notask' dataframe is decoded.
# df_dict = {'task': {'df': task_neural_df_filtered, 'task_info': True, 'num_cat': 4, 'flip_outputs': True},
#             'notask': {'df': notask_neural_df_filtered, 'task_info': False, 'num_cat': 0, 'flip_outputs': True}}

df_dict = {'notask': {'df': notask_neural_df, 'task_info': False, 'num_cat': 0, 'flip_outputs': True}}

# Run every decoder on every input configuration; results are stored as
# decode_results[func_name][df_type]. `model` and `res_dict` remain bound to the
# last iteration's values after the loop finishes.
decode_results = dict()
for func_name, func in func_dict.items():
    decode_results[func_name] = dict()
    for df_type, pred_df in df_dict.items():
        # print(f'{func_name}_{df_type} num_neurons: {num_neurons}; repeat {repeat_idx}')

        model, res_dict = func(wrist_df, pred_df['df'], neural_offset, cv_dict, metadata, task_info=pred_df['task_info'],
                                window_size=window_size, num_cat=pred_df['num_cat'], label_col=label_col, flip_outputs=pred_df['flip_outputs'])

        decode_results[func_name][df_type] = res_dict

        # Disabled: incremental saving of results and model weights. It refers to
        # `num_neurons` and `repeat_idx`, which are not defined in this cell.
        # # Save results on every loop in case early stop
        # num_neuron_results_dict[f'repeat_{repeat_idx}'][f'num_neuron_{num_neurons}'] = decode_results
        # # #Save metadata
        # output = open(f'{fpath}num_neuron_results.pkl', 'wb')
        # pickle.dump(num_neuron_results_dict, output)
        # output.close()

        # if func_name == 'rnn':
        #     torch.save(model.state_dict(), f'{fpath}models/{df_type}_neurons{num_neurons}_repeat{repeat_idx}.pt')





**

Exception ignored in: <function _releaseLock at 0x7f5074a09e10>
Traceback (most recent call last):
  File "/users/ntolley/.conda/envs/see/lib/python3.10/logging/__init__.py", line 228, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 


****

..

In [None]:
# Show the test-split correlations (model_utils.matrix_corr of test predictions
# vs. targets) stored in res_dict by the last decoder call of the loop above.
res_dict['test_corr']

array([0.78619208, 0.85263139, 0.78664979, 0.77634815, 0.85054791,
       0.80889051, 0.75773891, 0.84999737, 0.81096381, 0.7477906 ,
       0.84765538, 0.81446937, 0.32611803, 0.60442949, 0.5691553 ,
       0.70905636, 0.79873693, 0.65244833, 0.82744105, 0.87781194,
       0.78477901, 0.84170869, 0.88496196, 0.80491002, 0.79456252,
       0.8728516 , 0.83119905, 0.83054333, 0.87969224, 0.84003424,
       0.8473001 , 0.88312128, 0.84975987, 0.85022376, 0.88566084,
       0.85071061, 0.79987327, 0.86219309, 0.79723843, 0.81339499,
       0.86437358, 0.682533  , 0.8272171 , 0.8683661 , 0.52161455,
       0.82627647, 0.82650184, 0.35144037, 0.8014648 , 0.8519737 ,
       0.78406712, 0.82789316, 0.85126832, 0.80678569, 0.83949846,
       0.85967821, 0.81182206, 0.84359117, 0.86252679, 0.81172522,
       0.7609846 , 0.82993026, 0.71672999, 0.82295128, 0.880294  ,
       0.69096293])