In [1]:
import os.path as op
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import hnn_core
from hnn_core import calcium_model, simulate_dipole, read_params
from hnn_core.network_models import add_erp_drives_to_jones_model
from sklearn.model_selection import ShuffleSplit
from sklearn.linear_model import LinearRegression
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import utils
# Device string handed to the training helper further down in this notebook.
device = 'cpu'

--No graphics will be displayed.


In [22]:
#LSTM/GRU architecture for decoding
class model_lstm(nn.Module):
    """Stacked LSTM followed by a linear read-out applied at every time step.

    Parameters
    ----------
    input_size : int
        Number of features per time step of the input sequence.
    output_size : int
        Number of features produced per time step by the linear layer.
    hidden_dim : int
        Size of the LSTM hidden state (per direction).
    n_layers : int
        Number of stacked LSTM layers.
    dropout : float
        Dropout probability forwarded to ``nn.LSTM``.
    device : str or torch.device
        Device on which the initial hidden/cell states are allocated.
    bidirectional : bool
        If True the LSTM runs in both directions and the read-out layer
        consumes ``2 * hidden_dim`` features.

    Notes
    -----
    ``self.n_layers`` stores ``n_layers * num_directions``, i.e. the size of
    the first axis of the hidden/cell state tensors, not the number of
    stacked layers.
    """

    def __init__(self, input_size, output_size, hidden_dim=10, n_layers=2, dropout=0.1, device='cpu', bidirectional=False):
        super(model_lstm, self).__init__()

        # Multiplier for every size that scales with the number of directions.
        num_directions = 2 if bidirectional else 1

        # Defining some parameters
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers * num_directions
        self.device = device
        self.dropout = dropout
        self.bidirectional = bidirectional

        # LSTM layer. `bidirectional` must be forwarded here: the read-out
        # layer and the initial states below are already sized for
        # `num_directions`, so omitting it breaks bidirectional=True.
        self.lstm = nn.LSTM(input_size, hidden_dim, n_layers, batch_first=True,
                            dropout=dropout, bidirectional=bidirectional)

        # Fully connected read-out layer
        self.fc = nn.Linear(hidden_dim * num_directions, output_size)

    def forward(self, x):
        """Run the LSTM over ``x`` (batch, time, input_size) from a zero state.

        Returns a tensor of shape (batch, time, output_size).
        """
        batch_size = x.size(0)
        # Fresh zero-valued hidden/cell state for every forward pass
        hidden = self.init_hidden(batch_size)
        out, hidden = self.lstm(x, hidden)

        # nn.Linear acts on the last axis, so it is applied per time step
        out = out.contiguous()
        out = self.fc(out)
        return out

    def init_hidden(self, batch_size):
        """Return zero-initialised ``(h_0, c_0)`` for a batch of ``batch_size``.

        Each tensor has shape (n_layers * num_directions, batch_size,
        hidden_dim), uses the dtype of the model parameters and lives on
        ``self.device``.
        """
        # Borrow the dtype from an existing parameter without copying it.
        reference = next(self.parameters())
        state_shape = (self.n_layers, batch_size, self.hidden_dim)

        hidden = (reference.new_zeros(state_shape, device=self.device),
                  reference.new_zeros(state_shape, device=self.device))

        return hidden


def add_noise(net):
    """Attach one distal and one proximal Poisson drive to ``net``.

    Both drives share the same rate constant and AMPA weight for every
    targeted cell type and are added through ``net.add_poisson_drive``.
    The distal drive targets L2/L5 pyramidal cells and L2 baskets; the
    proximal drive additionally targets L5 baskets. ``net`` is modified in
    place and nothing is returned.
    """
    rate = 10
    weight = 0.01

    distal_targets = ('L2_pyramidal', 'L5_pyramidal', 'L2_basket')
    proximal_targets = distal_targets + ('L5_basket',)

    # (drive name == location, targeted cell types, event seed, connectivity seed)
    drive_specs = (
        ('distal', distal_targets, 1, 2),
        ('proximal', proximal_targets, 3, 4),
    )

    for location, targets, event_seed, conn_seed in drive_specs:
        net.add_poisson_drive(
            name=location,
            tstart=0,
            tstop=None,
            rate_constant=dict.fromkeys(targets, rate),
            location=location,
            n_drive_cells='n_cells',
            cell_specific=True,
            weights_ampa=dict.fromkeys(targets, weight),
            weights_nmda=None,
            space_constant=1e50,
            synaptic_delays=0.0,
            probability=1.0,
            event_seed=event_seed,
            conn_seed=conn_seed,
        )

In [23]:
%%capture
# Build the network from the default parameter file that ships inside the
# installed hnn_core package, add the Poisson noise drives and simulate it.
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
# NOTE(review): presumably shrinks the pyramidal grid to 1 x 1 so only a few cells are simulated — confirm.
params.update({'N_pyr_x': 1, 'N_pyr_y': 1})
net = calcium_model(params)
add_noise(net)
# Request recordings from all sections; the next cell reads them from net.cell_response.vsec / .isec.
dpl = simulate_dipole(net, dt=0.05, tstop=1000, record_vsec='all', record_isec='all')

In [24]:
# Collect the recorded section voltages and currents of every cell of one type.
cell_type = 'L5_pyramidal'
cell_gids = net.gid_ranges[cell_type]

# Only the first trial (index 0) of the simulation is used.
vsec_trial = net.cell_response.vsec[0]
isec_trial = net.cell_response.isec[0]

cell_voltage_list = []
cell_current_list = []
for gid in cell_gids:
    # One voltage trace per recorded section
    voltage_list = list(vsec_trial[gid].values())
    cell_voltage_list.append(voltage_list)

    # Flatten the per-section current traces into a single list
    current_list = []
    for sec_name, sec_currents in isec_trial[gid].items():
        current_list += list(sec_currents.values())
    cell_current_list.append(current_list)

# Recordings stored in shape (num_cells x num_rec_sites x time)
cell_voltages = np.stack(cell_voltage_list)
cell_currents = np.stack(cell_current_list)

# The first n_voltages rows along axis 1 of sim_data are the prediction targets
n_voltages = cell_voltages.shape[1]

# Voltages first, then currents, along the recording-site axis
sim_data = np.concatenate([cell_voltages, cell_currents], axis=1)


In [25]:
class MultiComp_Dataset(torch.utils.data.Dataset):
    """Sliding-window dataset for one-step-ahead prediction of recordings.

    Parameters
    ----------
    sim_data : np.ndarray
        Array of shape (n_cells, n_rec_sites, n_times). Only the first cell
        (``sim_data[0]``) is turned into samples.
    n_voltages : int
        The first ``n_voltages`` recording sites are used as targets.
    data_step_size : int
        Stride, in time steps, between consecutive windows.
    window_size : int
        Number of time steps in each input/target window.
    scaler : object or None
        Pre-fitted scaler to store on the dataset. If None, a
        ``StandardScaler`` is fit on the recordings of all cells
        (rows = time steps, columns = recording sites).
    device : str or torch.device
        Stored on the instance; not used by the dataset itself.

    Notes
    -----
    The scaler is stored/fitted and exposed as ``self.scaler`` but is not
    applied to the windows returned by ``__getitem__``.
    """

    def __init__(self, sim_data, n_voltages, data_step_size=1, window_size=10, scaler=None, device='cpu'):
        self.sim_data = sim_data
        self.n_voltages = n_voltages
        self.data_step_size = data_step_size
        self.window_size = window_size
        self.device = device
        # Store the caller-supplied scaler; without this assignment the
        # `self.scaler is None` check below raises AttributeError.
        self.scaler = scaler

        self.n_cells, self.n_rec_sites, self.n_times = sim_data.shape
        self.sim_data_unfolded = self.process_data(sim_data)

        if self.scaler is None:
            # Fit on every cell's recordings stacked along time:
            # (n_cells * n_times, n_rec_sites).
            self.scaler = StandardScaler()
            self.scaler.fit(np.vstack([cell_data.T for cell_data in sim_data]))

        # X is one step behind y
        self.X_tensor = self.sim_data_unfolded[:, :-1, :]
        self.y_tensor = self.sim_data_unfolded[:, 1:, :n_voltages]
        assert self.X_tensor.shape[0] == self.y_tensor.shape[0]
        self.num_samples = self.X_tensor.shape[0]

    def __len__(self):
        """Return the total number of windows."""
        return self.num_samples

    def __getitem__(self, slice_index):
        """Return ``(X, y)`` for an index or slice over the window axis."""
        return self.X_tensor[slice_index, :, :], self.y_tensor[slice_index, :, :]

    def process_data(self, sim_data):
        """Cut the first cell's recordings into overlapping windows.

        Returns a tensor of shape (n_windows, window_size + 1, n_rec_sites);
        the extra time step lets inputs and targets be offset by one.
        """
        # (n_rec_sites, n_times) -> (n_times, n_rec_sites)
        sim_data_tensor = torch.from_numpy(sim_data[0]).transpose(0, 1)
        # unfold appends the window axis last; move it before the site axis
        sim_data_unfolded = sim_data_tensor.unfold(0, self.window_size + 1, self.data_step_size).transpose(1, 2)
        return sim_data_unfolded
        
        

In [26]:
# Number of leading time samples (60% of the recording) reserved for training.
n_time_samples = sim_data.shape[2]
train_size = int(n_time_samples * 0.6)

In [27]:
# Split along the time axis: the first `train_size` samples train the model,
# the remainder is held out for validation.
training_set = MultiComp_Dataset(sim_data[:, :, :train_size], n_voltages)
validation_set = MultiComp_Dataset(sim_data[:, :, train_size:], n_voltages)

# The feature (last) axis of the inputs and targets sets the model sizes.
X_all, y_all = training_set[:]
input_size = X_all.shape[-1]
output_size = y_all.shape[-1]

In [28]:
# Instantiate the network with the feature sizes taken from the training set;
# every other constructor argument keeps its default.
model = model_lstm(
    input_size=input_size,
    output_size=output_size,
)

In [29]:
# Optimisation settings
max_epochs = 1000
lr = 1e-2
weight_decay = 1e-5

# Mean-squared error between predicted and recorded traces
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

In [32]:
# Batch loaders for the training and validation datasets.
batch_size = 10
num_cores = 32

train_params = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_cores,
    pin_memory=False,
)
training_generator = torch.utils.data.DataLoader(training_set, **train_params)

# The validation loader uses the same settings, held in its own dict.
validation_params = dict(train_params)
validation_generator = torch.utils.data.DataLoader(validation_set, **validation_params)

In [33]:
#Train model
# NOTE(review): the trailing positional arguments 10 and 5 go straight to utils.train_validate_model and
# their meaning is not visible in this notebook; the logged output suggests a print interval and an
# early-stopping patience. Confirm against utils and prefer passing them by keyword.
loss_dict = utils.train_validate_model(model, optimizer, criterion, max_epochs, training_generator, validation_generator, device, 10, 5)


****......
Epoch: 10/1000 ... Train Loss: 79.3591  ... Validation Loss: 76.7248
 Early Stop; Min Epoch: 4


In [None]:
# Evaluate the trained model on the training set. A dedicated, unshuffled
# loader is used so the predictions come back in dataset order and can be
# compared against the targets sample by sample.
training_eval_generator = torch.utils.data.DataLoader(
    training_set, batch_size=batch_size, shuffle=False, num_workers=num_cores, pin_memory=False)
train_pred = utils.evaluate_model(model, training_eval_generator, device)

In [None]:
#Evaluate trained model
# NOTE(review): this cell cannot run in this notebook as written. `mocap_functions` is not imported
# in the imports cell, and `model_ann`, `training_eval_generator`, `testing_generator`, `y_train_data`
# and `y_test_data` are not defined in any visible cell. It looks carried over from another notebook;
# port it to this notebook's `utils`, `model` and data loaders, or remove it — TODO confirm.
ann_train_pred = mocap_functions.evaluate_model(model_ann, training_eval_generator, device)
ann_test_pred = mocap_functions.evaluate_model(model_ann, testing_generator, device)

#Compute decoding performance
# Correlation between the predictions and the corresponding targets
ann_train_corr = mocap_functions.matrix_corr(ann_train_pred,y_train_data)
ann_test_corr = mocap_functions.matrix_corr(ann_test_pred,y_test_data)



In [46]:
# #Generate cv_dict for regular train/test/validate split (no rolling window)
# cv_split = ShuffleSplit(n_splits=5, test_size=.25, random_state=0)
# val_split = ShuffleSplit(n_splits=1, test_size=.25, random_state=0)
# cv_dict = {}
# for fold, (train_val_idx, test_idx) in enumerate(cv_split.split(np.arange(num_trials))):
#     for t_idx, v_idx in val_split.split(train_val_idx): #No looping, just used to split train/validation sets
#         cv_dict[fold] = {'train_idx':train_val_idx[t_idx], 
#                          'test_idx':test_idx, 
#                          'validation_idx':train_val_idx[v_idx]} 