## Setup

In [1]:
from neuron import h
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import h5py
import json
import os
from scipy import signal
from typing import Union, List, Tuple, Dict
from tqdm import tqdm

from cell_inference.config import paths, params
from cell_inference.cells.simulation import Simulation
from cell_inference.cells.stylizedcell import CellTypes
from cell_inference.utils.random_parameter_generator import Random_Parameter_Generator
from cell_inference.utils.transform.geometry_transformation import pol2cart, cart2pol
from cell_inference.utils.spike_window import first_pk_tr, get_spike_window
from cell_inference.utils.plotting.plot_results import plot_lfp_traces, plot_lfp_heatmap
from cell_inference.utils.feature_extractors.SummaryStats2D import calculate_stats, build_lfp_grid

# Load NEURON's standard run library (stdrun.hoc) into the hoc interpreter.
h.load_file('stdrun.hoc')
# Load the compiled library located at paths.COMPILED_LIBRARY_REDUCED_ORDER.
h.nrn_load_dll(paths.COMPILED_LIBRARY_REDUCED_ORDER)
# Simulation stop time and time step, both taken from the project params module.
h.tstop = params.TSTOP
h.dt = params.DT
# Keep steps_per_ms equal to the reciprocal of dt.
h.steps_per_ms = 1/h.dt

# Geometry table read from CSV, using its 'id' column as the row index.
geo_standard = pd.read_csv(paths.GEO_REDUCED_ORDER, index_col='id')

## Set up configuration

#### Set batch

In [2]:
# Batch id of this run; leave as None for a single batch with un-suffixed output files.
batchid = None
# Output file names get a '_<batchid>' suffix only when a batch id was given.
batch_suf = '' if batchid is None else '_%d' % batchid
if batchid is None:
    batchid = 0

#### Trial configurations

In [3]:
# Number of cells and number of locations per cell; their product is the sample count.
number_cells, number_locs = 3, 2
number_samples = number_cells * number_locs
rand_seed = 12345

# Parameters to be inferred; location may be given as (d, theta) in place of (x, z).
inference_list = ['y', 'd', 'theta', 'h', 'phi', 'l_t']
# All randomized parameters: those randomized but not inferred come first,
# followed by every inferred parameter.
randomized_list = ['alpha'] + inference_list
# Any parameter in neither list stays fixed at its default value.

#### Simulation configurations

In [4]:
# Names of the location and geometry parameters, in the column order used when
# the sampled values are stacked into arrays.
loc_param_list = ['x', 'y', 'z', 'alpha', 'h', 'phi']
geo_param_list = ['l_t']

# Default values, used for parameters that are not randomized.
loc_param_default = {
    'x': 0.,
    'y': 0.,
    'z': 50.,
    'alpha': np.pi/4,
    'h': 1.,
    'phi': 0.,
}
# Add the polar form (d, theta) of the default (x, z).
d_default, theta_default = cart2pol(loc_param_default['x'], loc_param_default['z'])
loc_param_default['d'] = d_default
loc_param_default['theta'] = theta_default
geo_param_default = {'l_t': 814.}

# Sampling range of each parameter.
loc_param_range = {
    'x': (-50, 50),
    'y': (-1200, 1200),
    'z': (50., 200.),
    'alpha': (0, np.pi),
    'h': (.7071, 1.),
    'phi': (-np.pi, np.pi),
    'd': (50., 200.),
    'theta': (-np.pi/3, np.pi/3),
}
geo_param_range = {'l_t': (15., 865.)}

# Distribution name of each parameter ('unif' or 'norm').
loc_param_dist = {
    'x': 'unif',
    'y': 'unif',
    'z': 'unif',
    'alpha': 'unif',
    'h': 'unif',
    'phi': 'unif',
    'd': 'unif',
    'theta': 'norm',
}
geo_param_dist = {'l_t': 'unif'}


#### Synapse parameters

In [5]:
# Fixed gmax, not using gmax mapping file
# The same value is passed to Simulation, recorded in config_dict (with 'gmax_mapping'
# set to None), and stored in the saved .npz files.
gmax = 0.013

#### Fixed biophysical parameters

In [6]:
# Location of the JSON file holding the full set of biophysical parameters.
filepath = './cell_inference/resources/biophys_parameters/ReducedOrderL5.json'
with open(filepath) as biophys_file:
    full_biophys = json.load(biophys_file)

# Both are left empty here; they are later passed to Simulation as `biophys`
# and `biophys_comm` and recorded in config_dict.
biophys_param, biophys_comm = [], {}

### Create configuration dictionary

In [7]:
# Trial-level settings: sample counts, seed, and which parameters are inferred/randomized.
trial_parameters = {
    'number_cells': number_cells,
    'number_locs': number_locs,
    'number_samples': number_samples,
    'rand_seed': rand_seed,
    'inference_list': inference_list,
    'randomized_list': randomized_list,
}

# Simulation-level settings. 'full_biophys' records the path of the JSON file,
# not its loaded content.
simulation_parameters = {
    'loc_param_list': loc_param_list,
    'geo_param_list': geo_param_list,
    'loc_param_default': loc_param_default,
    'geo_param_default': geo_param_default,
    'loc_param_range': loc_param_range,
    'geo_param_range': geo_param_range,
    'loc_param_dist': loc_param_dist,
    'geo_param_dist': geo_param_dist,
    'stim_param': params.STIM_PARAM,
    'gmax': gmax,
    'gmax_mapping': None,
    'full_biophys': filepath,
    'biophys_param': biophys_param,
    'biophys_comm': biophys_comm,
}

# Full configuration, later written to config.json.
config_dict = {
    'Trial_Parameters': trial_parameters,
    'Simulation_Parameters': simulation_parameters,
}

## Generate random samples

In [8]:
# Generator used for all parameter draws in the cells below, seeded with rand_seed.
# NOTE(review): n_sigma presumably sets how many standard deviations the range spans
# for 'norm'-distributed parameters — confirm against Random_Parameter_Generator.
rpg = Random_Parameter_Generator(seed=rand_seed, n_sigma=3)

#### Location parameters

In [9]:
# Sample the location in polar form (d, theta) instead of Cartesian (x, z)
# when both polar coordinates are randomized.
use_polar = 'd' in randomized_list and 'theta' in randomized_list

loc_param_gen = list(loc_param_list)
if use_polar:
    for cartesian_name, polar_name in (('x', 'd'), ('z', 'theta')):
        loc_param_gen[loc_param_gen.index(cartesian_name)] = polar_name

loc_param_samples = rpg.generate_parameters(
    number_samples, loc_param_gen, randomized_list,
    loc_param_default, loc_param_range, loc_param_dist)

if use_polar:
    # Convert the polar draws back so that 'x' and 'z' are available below.
    x_samples, z_samples = pol2cart(loc_param_samples['d'], loc_param_samples['theta'])
    loc_param_samples['x'] = x_samples
    loc_param_samples['z'] = z_samples

# One column per entry of loc_param_list, then reshaped into ncell-by-nloc-by-nparam.
loc_columns = [loc_param_samples[name] for name in loc_param_list]
loc_param = np.column_stack(loc_columns).reshape(number_cells, number_locs, -1)

#### Geometry parameters

In [10]:
# Geometry parameters are drawn once per cell (number_cells draws).
geo_param_samples = rpg.generate_parameters(
    number_cells, geo_param_list, randomized_list,
    geo_param_default, geo_param_range, geo_param_dist)

# One column per entry of geo_param_list.
geo_columns = [geo_param_samples[name] for name in geo_param_list]
geo_param = np.column_stack(geo_columns)

# Repeat every per-cell value number_locs times so that the stored samples
# have number_samples entries, matching the location samples.
for name in list(geo_param_samples):
    geo_param_samples[name] = np.repeat(geo_param_samples[name], number_locs)

### Get parameters to be inferred as labels

In [11]:
# Merge geometry and location samples into one lookup; on a key clash the
# location samples take precedence.
samples = {**geo_param_samples, **loc_param_samples}

# Labels: one column per parameter to be inferred, in inference_list order.
labels = np.column_stack([samples[key] for key in inference_list])

# randomized_list is built as [randomized-only parameters] + inference_list, so the
# randomized-only names are its leading entries. Count them explicitly instead of
# slicing with [:-len(inference_list)], which silently yields [] when the length is 0.
n_rand_only = len(randomized_list) - len(inference_list)
rand_only_keys = randomized_list[:n_rand_only]
if rand_only_keys:
    rand_param = np.column_stack([samples[key] for key in rand_only_keys])
else:
    # np.column_stack raises ValueError on an empty list; keep an array with one
    # row per sample and zero columns so downstream code still works.
    rand_param = np.empty((labels.shape[0], 0))

np.set_printoptions(suppress=True)
print(loc_param.shape)
print(geo_param.shape)
print(labels.shape)

(3, 2, 6)
(3, 1)
(6, 6)


## Create simulation and run

In [12]:
# Build the simulation from the standard geometry, the sampled location/geometry
# parameters and the fixed settings defined above, then run it.
sim = Simulation(
    geometry=geo_standard,
    electrodes=params.ELECTRODE_POSITION,
    full_biophys=full_biophys,
    cell_type=CellTypes.REDUCED_ORDER,
    biophys=biophys_param,
    biophys_comm=biophys_comm,
    loc_param=loc_param,
    geo_param=geo_param,
    spike_threshold=params.SPIKE_THRESHOLD,
    gmax=gmax,
    stim_param=params.STIM_PARAM,
    min_distance=params.MIN_DISTANCE,
    ncell=number_cells,
)

sim.run_neuron_sim()

In [13]:
# Inspect the last entry of sim.lfp: compute its `im` and its transfer resistance,
# reusing the move_cell / scale / min_distance settings stored on that object.
ecp = sim.lfp[-1]
im = ecp.calc_im()
tr = ecp.calc_transfer_resistance(
    move_cell=ecp.move_cell,
    scale=ecp.scale,
    min_distance=ecp.min_distance,
)

In [None]:
# Matrix product of the transfer resistance and `im`, displayed as the cell output.
# NOTE(review): presumably this reproduces the extracellular potential of the last
# entry of sim.lfp — confirm.
tr @ im

## Get LFPs
#### Choose what to save

In [13]:
# Switches read by the cells below: save_lfp gates writing the windowed LFP file;
# save_stats gates both the summary-statistics computation and writing its file.
save_lfp = True
save_stats = True

#### Reshape LFP array. Filter each channel. Get window of spike for each sample.

In [None]:
# Collect the LFP of every cell and location, then flatten cells and locations
# into a single sample axis and put time before channels.
lfp = sim.get_lfp('all', multiple_position=True)  # (cells x locs x channels x time)
lfp = lfp.reshape((-1,)+lfp.shape[-2:]).transpose((0,2,1))  # -> (samples x channels x time) -> (samples x time x channels)

# Butterworth filter coefficients; all settings come from the project params module.
filt_b, filt_a = signal.butter(params.BUTTERWORTH_ORDER,
                               params.FILTER_CRITICAL_FREQUENCY,
                               params.BANDFILTER_TYPE,
                               fs=params.FILTER_SAMPLING_RATE)

start_idx = int(np.ceil(params.STIM_PARAM['start']/h.dt)) # ignore signal before
filtered_lfp = signal.lfilter(filt_b, filt_a, lfp[:,start_idx:,:], axis=1)  # filter along time axis

# Cut a spike window out of each sample; samples for which get_spike_window
# raises ValueError are recorded in bad_indices and skipped.
lfp_list = []
bad_indices = []
for i in tqdm(range(number_samples)):
    try:
        start, end = get_spike_window(filtered_lfp[i], win_size=params.WINDOW_SIZE, align_at=params.PK_TR_IDX_IN_WINDOW) # 24*0.025=0.6 ms
        lfp_list.append(filtered_lfp[i,start:end,:])
    except ValueError:
        bad_indices.append(i)

t = sim.t()[:params.WINDOW_SIZE]
windowed_lfp = np.stack(lfp_list, axis=0)  # (samples x time window x channels)

# Drop the rejected samples from every per-sample array so that labels and
# rand_param stay row-aligned with windowed_lfp. The row-count guards make the
# cell safe to re-run: rows are only deleted from arrays that still hold all
# number_samples rows.
if labels.shape[0] == number_samples:
    labels = np.delete(labels, bad_indices, axis=0)
if rand_param.shape[0] == number_samples:
    rand_param = np.delete(rand_param, bad_indices, axis=0)
print('%d bad samples.' % len(bad_indices))

In [None]:
if save_stats:
    # Column of `labels` that holds y, or None when y is not inferred.
    try:
        y_idx = inference_list.index('y')
    except ValueError:
        y_idx = None

    # Note: bad_indices is re-created here; from now on it indexes rows of
    # windowed_lfp (and the already-filtered labels), not the original samples.
    summ_stats, bad_indices, yshift = [], [], []
    electrode_xy = params.ELECTRODE_POSITION[:, :2]
    for i in tqdm(range(windowed_lfp.shape[0])):
        try:
            g_lfp, _, y_c = build_lfp_grid(windowed_lfp[i], electrode_xy, y_window_size=960.0)
        except ValueError:
            # No grid could be built for this sample; record it and move on.
            bad_indices.append(i)
        else:
            summ_stats.append(calculate_stats(g_lfp))
            if y_idx is not None:
                # Difference between the y returned by build_lfp_grid and the labelled y.
                yshift.append(y_c - labels[i, y_idx])

    summ_stats, yshift = np.array(summ_stats), np.array(yshift)
    print('%d bad samples.' % len(bad_indices))

## Save configurations and simulation data

In [None]:
DATA_PATH = 'cell_inference/resources/simulation_data'
TRIAL_PATH = os.path.join(DATA_PATH, 'Reduced_Order_trunklength_Loc5_restrict_h')


def _in_trial_dir(filename):
    """Return the path of `filename` inside TRIAL_PATH."""
    return os.path.join(TRIAL_PATH, filename)


CONFIG_PATH = _in_trial_dir('config.json')  # trial configuration
LFP_PATH = _in_trial_dir('lfp' + batch_suf)  # LFP and labels
STATS_PATH = _in_trial_dir('summ_stats' + batch_suf)  # summary statistics
MEM_VOLT_PATH = _in_trial_dir('mem_volt' + batch_suf)  # membrane voltage and spike times

# Create the data directory first, then the trial directory, announcing each
# one that did not exist yet.
for directory, message in ((DATA_PATH, "The new data directory is created!"),
                           (TRIAL_PATH, "The new trial directory is created!")):
    if not os.path.exists(directory):
        os.makedirs(directory)
        print(message)

In [None]:
if save_lfp:
    # Windowed LFP (x) with its labels (y), time axis, randomized-only parameters and gmax.
    np.savez(LFP_PATH, t=t, x=windowed_lfp, y=labels, rand_param=rand_param, gmax=gmax)
if save_stats:
    # At this point bad_indices comes from the summary-statistics cell and indexes
    # rows of windowed_lfp / labels; drop those rows so that y lines up with summ_stats.
    stats_labels = np.delete(labels, bad_indices, axis=0)
    # Drop the same rows from rand_param, but only when it is row-aligned with
    # labels; otherwise the indices would not refer to the same samples.
    if rand_param.shape[0] == labels.shape[0]:
        stats_rand_param = np.delete(rand_param, bad_indices, axis=0)
    else:
        stats_rand_param = rand_param
    np.savez(STATS_PATH, t=t, x=summ_stats, y=stats_labels,
             rand_param=stats_rand_param, gmax=gmax, ys=yshift, bad_indices=bad_indices)
# np.savez(MEM_VOLT_PATH, v=mem_volt, spk=tspk)
# Always record the configuration that produced this data.
with open(CONFIG_PATH, 'w') as fout:
    json.dump(config_dict, fout, indent=2)

## Verify LFPs
We visually inspect the LFP plots generated from the simulated data.

In [None]:
%matplotlib inline

cell_idx = 0

ix = 1
ylim = [-1900,1900]
x_dist = np.unique(params.ELECTRODE_POSITION[:,0])
e_idx = ((params.ELECTRODE_POSITION[:,0]==x_dist[ix]) & 
         (params.ELECTRODE_POSITION[:,1]>=ylim[0]) & 
         (params.ELECTRODE_POSITION[:,1]<=ylim[1]))

_ = plot_lfp_heatmap(t, params.ELECTRODE_POSITION[e_idx, 1],
                              windowed_lfp[cell_idx][:,e_idx], vlim='auto',
                              fontsize=15, labelpad=-10, ticksize=12, tick_length=5, nbins=5)

_ = plot_lfp_traces(t, windowed_lfp[cell_idx], fontsize=15, labelpad=-10, ticksize=12, tick_length=5, nbins=5)