In [1]:
%matplotlib notebook


from neuron import h
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import h5py

from cell_inference.utils.plotting.plot_results import plot_lfp_traces, plot_lfp_heatmap
from cell_inference.utils.feature_extractors.SummaryStats2D import calculate_stats, build_lfp_grid
from cell_inference.cells.stylizedcell import CellTypes
from cell_inference.cells.activecell import ActiveCell
from cell_inference.cells.passivecell import PassiveCell
from cell_inference.cells.simulation import Simulation
from cell_inference.utils.currents.recorder import Recorder
from cell_inference.config import paths, params
from cell_inference.utils.feature_extractors.parameterprediction import ClassifierTypes, ClassifierBuilder

# Cell model used throughout this notebook.
cell_type = CellTypes.ACTIVE

# Load the compiled library into NEURON before any cell is built.
h.nrn_load_dll(paths.COMPILED_LIBRARY)

# Global NEURON integration settings, taken from the project configuration.
h.dt = params.DT
h.tstop = params.TSTOP

# Reference geometry table, indexed by its 'id' column.
geo_standard = pd.read_csv(paths.GEO_STANDARD, index_col='id')

In [2]:
# Seeded generator so every parameter draw in this cell is reproducible.
rng = np.random.default_rng(12345)

# Names of the parameters that are randomised; all others keep a fixed value.
inf_list = ['r_s', 'l_t', 'r_t']
SUMM_STAT_PATH = 'cell_inference/resources/geo_summ_stats.npy'
LFP_PATH = 'cell_inference/resources/geo_lfp.npy'
MEM_VOLT_PATH = 'cell_inference/resources/geo_mem_volt.npy'
run_flag = True
number_samples = 10000


def _param_column(name, fixed, draw):
    """Return a (number_samples, 1) column for parameter `name`.

    The column is drawn with `draw(shape)` when `name` is listed in
    `inf_list`; otherwise it is filled with the constant `fixed`. `draw` is
    only called when needed, so the RNG stream is consumed solely by the
    parameters that are actually randomised.
    """
    shape = (number_samples, 1)
    if name in inf_list:
        return draw(shape)
    return np.full(shape, fixed)


def _uniform(low, high):
    """Sampler drawing uniformly from [low, high)."""
    return lambda shape: rng.uniform(low=low, high=high, size=shape)


def _lognormal(lower, upper):
    """Sampler for a log-normal whose +/- 2 sigma range in log space is [lower, upper].

    The mean and sigma are expressed in log space, so the draw must come from
    `rng.lognormal`; a plain `rng.normal` with these arguments would return
    log-values (e.g. negative "radii") rather than values in the intended range.
    """
    mean = (np.log(lower) + np.log(upper)) / 2
    sigma = (np.log(upper) - np.log(lower)) / 4
    return lambda shape: rng.lognormal(mean=mean, sigma=sigma, size=shape)


# Location parameters; column order of loc_param is (xs, ys, zs, alphas, hs, phis).
xs = _param_column('xs', 0., _uniform(-50, 50))
ys = _param_column('ys', 0., _uniform(-2000, 5000))
zs = _param_column('zs', 50., _uniform(-50, 50))
alphas = _param_column('alphas', np.pi / 4, _uniform(-(np.pi / 3), (np.pi / 3)))
hs = _param_column('hs', 1., _uniform(-1., 1.))
phis = _param_column('phis', 0., _uniform(0, np.pi))
loc_param = np.concatenate((xs, ys, zs, alphas, hs, phis), axis=1)

# Geometry parameters; column order of geo_param is (r_s, l_t, r_t, r_d, r_tu, l_d).
r_s = _param_column('r_s', 8.0, _uniform(7, 12))
l_t = _param_column('l_t', 600.0, _uniform(20., 800.))
r2_s = np.square(r_s)  # squared r_s is the feature fed to the gmax predictor below
r_t = _param_column('r_t', 1.25, _lognormal(0.6, 1.8))
r_d = _param_column('r_d', .28, _lognormal(0.2, 1.0))
r_tu = _param_column('r_tu', .28, _lognormal(0.2, 1.0))
l_d = _param_column('l_d', 200.0, _lognormal(100, 300))

# Predict gmax for every sample from (r_s^2, l_t, r_t) with the stored model.
clf = ClassifierBuilder()
clf.load_clf(paths.RESOURCES_ROOT + "gmax_lin_reg_classifier.joblib")
gmax = clf.predict(np.column_stack((r2_s, l_t, r_t)))

geo_param = np.concatenate((r_s, l_t, r_t, r_d, r_tu, l_d), axis=1)

# geo_param = [8,600.,1.25,.28,.28,200.]
# geo_param = np.tile(geo_param, (number_samples, 1))

# Labels are the randomised geometry parameters, in the same order as inf_list.
labels = np.concatenate((r_s, l_t, r_t), axis=1)#np.concatenate((loc_param, geo_param), axis=1)
np.set_printoptions(suppress=True)
print(loc_param.shape)
print(geo_param.shape)
print(labels.shape)

(10000, 6)
(10000, 6)
(10000, 3)


In [None]:
from typing import Union, List, Tuple
from matplotlib.figure import Figure
from matplotlib.axes import Axes

import os.path
# Build the simulation when the cached outputs are missing or a re-run is forced.
if not os.path.isfile(SUMM_STAT_PATH) or not os.path.isfile(LFP_PATH) or run_flag:
    # The context manager closes the HDF5 file even if reading 'data' raises.
    with h5py.File(paths.INVIVO_DATA_FILE, 'r') as hf:
        groundtruth_lfp = np.array(hf.get('data'))

    # One simulated cell per sampled parameter row.
    sim = Simulation(geometry = geo_standard, 
                     electrodes = params.ELECTRODE_POSITION, 
                     cell_type = CellTypes.ACTIVE, 
                     loc_param = loc_param, 
                     geo_param = geo_param,
                     spike_threshold = -20, 
                     gmax = gmax, 
                     scale = 1., 
                     ncell = number_samples)
    # sim = Simulation(geo_standard,params.ELECTRODE_POSITION,loc_param,geo_param=geo_param,gmax=.001,scale=1000, ncell=ncell)  # gmax 0.001 -0.012
    sec_list = sim.cells[0].all

def plot_v(sim: Simulation, cell_idx: np.ndarray = np.array([0, 1]),
           figsize: Union[List[float],Tuple[float]] = (6,2)) -> Tuple[Figure, Axes]:
    """Plot the membrane voltage of selected cells, one subplot per cell.

    Parameters
    ----------
    sim : Simulation
        Simulation providing `t()` and `v('all')`.
    cell_idx : array-like of int
        Indices of the cells to plot; each one selects a row of `sim.v('all')`.
    figsize : (width, height)
        Figure width and per-subplot height in inches; the total figure
        height is `height * number_of_cells`.

    Returns
    -------
    fig, axs
        The figure and a 1-D array holding one Axes per requested cell.
    """
    # Accept scalars and lists as well as arrays.
    cell_idx = np.atleast_1d(np.asarray(cell_idx, dtype=int)).ravel()
    t = sim.t()
    v = sim.v('all')
    # squeeze=False always returns a 2-D Axes array, so a single cell works too.
    fig, axs = plt.subplots(nrows=cell_idx.size, ncols=1, squeeze=False)
    axs = axs[:, 0]
    fig.set_size_inches(figsize[0], figsize[1] * cell_idx.size)
    for idx, ax in zip(cell_idx, axs):
        # Plot the requested cell, not simply the first `cell_idx.size` rows.
        ax.plot(t, v[idx, :])
        ax.set_ylabel('Vm (mV)')
    axs[0].set_title('Membrane Voltage vs Time')
    axs[-1].set_xlabel('Time (ms)')
    plt.show()
    return fig, axs

def valid_count(sim):
    """Return how many cells have a spike number equal to exactly one.

    `sim` must provide `get_spike_number('all')`, whose result is compared
    element-wise against 1.
    """
    spike_counts = sim.get_spike_number('all')
    fired_once = spike_counts == 1
    return np.count_nonzero(fired_once)

In [None]:
from cell_inference.utils.feature_extractors.SummaryStats2D import build_lfp_grid, calculate_stats
from cell_inference.utils.spike_window import first_pk_tr, get_spike_window

import os.path
def _existing_cache(path):
    """Return the file on disk that holds the archive saved for `path`, or None.

    `np.savez` appends '.npz' to a file name that does not already end with
    it, so an archive saved under `path` may actually live at `path + '.npz'`.
    """
    for candidate in (path, path + '.npz'):
        if os.path.isfile(candidate):
            return candidate
    return None


def _load_cached_x(path):
    """Load the array stored under key 'x' (or a plain .npy array) from `path`."""
    loaded = np.load(path)
    if hasattr(loaded, 'files'):
        # .npz archive: extract the array and close the underlying file.
        with loaded:
            return loaded['x']
    return loaded


summ_stat_cache = _existing_cache(SUMM_STAT_PATH)
lfp_cache = _existing_cache(LFP_PATH)

if run_flag or summ_stat_cache is None or lfp_cache is None:
    sim.run_neuron_sim()
    print("Number of valid sample: %d" % (valid_count(sim)))
    _ = plot_v(sim)

    # The filter design does not depend on the sample, so build it once.
    filt_b, filt_a = signal.butter(params.BUTTERWORTH_ORDER,
                                   params.FILTER_CRITICAL_FREQUENCY,
                                   params.BANDFILTER_TYPE,
                                   fs=params.FILTER_SAMPLING_RATE)

    lfp_list = []
    summ_stat_list = []
    for i in range(number_samples):
        # Filter each transposed LFP along axis 0 before extracting statistics.
        lfp_filtered = signal.lfilter(filt_b, filt_a, sim.get_lfp(i).T, axis=0)
        lfp_list.append(lfp_filtered)

        fst_idx = first_pk_tr(lfp_filtered)
        # NOTE(review): the window bounds are computed but never applied to the
        # LFP below - confirm whether the trace should be cropped to [start, end).
        start, end = get_spike_window(lfp_filtered, win_size=params.WINDOW_SIZE, align_at=fst_idx-10)
        g_lfp, grid = build_lfp_grid(lfp_filtered, params.ELECTRODE_POSITION, params.ELECTRODE_GRID)
        summ_stat_list.append(calculate_stats(g_lfp, grid))

    t = sim.t()
    # Stack samples on a new last axis, reverse the axes so samples come first,
    # then insert a singleton axis at position 1. This follows the actual
    # sample/array sizes instead of a hard-coded (1000, 1, 384, 801).
    lfp_list = np.expand_dims(np.transpose(np.stack(lfp_list, axis=-1)), axis=1)
    summ_stat_list = np.transpose(np.stack(summ_stat_list, axis=-1))
    np.savez(SUMM_STAT_PATH, x=summ_stat_list, y=labels)
    np.savez(LFP_PATH, x=lfp_list, y=labels)
    np.savez(MEM_VOLT_PATH, sim.v('all'))
else:
    # Reuse cached results; only the 'x' arrays are needed downstream.
    lfp_list = _load_cached_x(lfp_cache)
    summ_stat_list = _load_cached_x(summ_stat_cache)

In [None]:
# Alias the summary statistics as `data` and report their shape as a sanity check.
data = summ_stat_list
print(data.shape)