In [27]:
import os
import numpy as np
import json

from neuron import h
from neuron.units import mV, ms
h.load_file("stdrun.hoc")
h.load_file("stdlib.hoc")
h.load_file("import3d.hoc")
h.load_file("stdrun.hoc")

1.0

## Load data

In [3]:
data_dir = './data/state_reconstruct_morpho/'

In [24]:
class Poisson_Times:
    def __init__(self, _id, tau, interval, weight, rev_potential, event_times=None, max_time=1000000, number=99999999,
                 delay=0, start=0):
        '''
        :param _id:
        :param event_times: instead of auto generating, provide a list of event_times
        :param max_time: maximum time (simulation duration)
        :param number: maximum number of stimuli (typically inconsequential if max_time is reasonable)
        '''

        self._id = _id
        self.rev_potential = rev_potential
        self.max_time = max_time
        self.interval = interval
        self.weight = weight
        self.delay = delay
        self.tau = tau
        self.start = start
        self.number = number
        self.event_times = []

        if event_times:
            # load event times
            self.event_times = event_times
        else:
            # generate event times
            event_time = 0
            for i in range(number):
                event_time += np.random.exponential(self.interval)
                if event_time < max_time:
                    self.event_times.append(event_time)
                else:
                    break

    def write2file(self, path):
        stimuli_data = {
            '_id': self._id,
            'rev_potential': self.rev_potential,
            'max_time': self.max_time,
            'interval': self.interval,
            'weight': self.weight,
            'delay': self.delay,
            'tau': self.tau,
            'start': self.start,
            'number': self.number,
            'event_times': self.event_times
        }

        with open(path, 'w') as fout:
            fout.write(json.dumps(stimuli_data))
            
class Pyramidal:
    def __init__(self, record_spiking_histories=False):
        self.load_morphology()
        # do discretization, ion channels, etc
        for sec in self.all:
            sec.nseg = int(1 + 2 * (sec.L // 40))
        h.hh.insert(self.axon)
        h.hh.insert(self.soma)
        h.pas.insert(self.dend)  # passive leak
        h.pas.insert(self.apic)  # passive leak
        self.all_input_segments = []
        for morph in [self.apic, self.dend]:
            for part in morph:
                # self.all_input_segments.append(part)
                self.all_input_segments.extend([seg for seg in part])
        self._clear_cell(record_spiking_histories)

    def _clear_cell(self, record_spiking_histories):
        # storing input mechanisms
        self.syns = []
        self.net_stims = []
        self.netcons = []
        self.stims = []
        # recording
        self.v_apic = h.Vector().record(self.apic[100](0.5)._ref_v)
        self.v_soma = h.Vector().record(self.soma[0](0.5)._ref_v)
        self.v_axon = h.Vector().record(self.axon[0](0.5)._ref_v)
        self._t = h.Vector().record(h._ref_t)

        self.v = [h.Vector().record(seg._ref_v) for sec in self.all for seg in sec]
        self.hh_secs = [sec for sec in self.all if 'hh' in sec.psection()['density_mechs']]

        '''
        self.m = [h.Vector().record(seg.hh._ref_m) for sec in hh_secs for seg in sec]
        self.h = [h.Vector().record(seg.hh._ref_h) for sec in hh_secs for seg in sec]
        self.n = [h.Vector().record(seg.hh._ref_n) for sec in hh_secs for seg in sec]
        '''

        self.spike_detector = h.NetCon(self.axon[0](0.5)._ref_v, None, sec=self.axon[0])
        self.spike_times = h.Vector()
        self.spike_detector.record(self.spike_times)

        if record_spiking_histories:
            self.spiking_histories = []
            self.spike_detector2 = h.NetCon(self.axon[0](0.5)._ref_v, None, sec=self.axon[0])
            self.spike_detector2.record(self.save)

    def __repr__(self):
        return "pyr"

    def get_state(self):
        return {
            "v": [seg.v for sec in self.all for seg in sec],
            "m": [seg.hh.m for sec in self.hh_secs for seg in sec],
            "h": [seg.hh.h for sec in self.hh_secs for seg in sec],
            "n": [seg.hh.n for sec in self.hh_secs for seg in sec]}

    def save(self):
        self.spiking_histories.append(self.get_state())

    def set_initialize_state(self, state):
        self._initial_state = state
        self.fih = h.FInitializeHandler(self._do_initial)

    def _do_initial(self):
        # state: state dict from self.get_state()
        all_segs = [seg for sec in self.all for seg in sec]
        hh_segs = [seg for sec in self.hh_secs for seg in sec]
        for seg, v in zip(all_segs, self._initial_state["v"]):
            seg.v = v
        for seg, m, h, n in zip(hh_segs, self._initial_state["m"], self._initial_state["h"], self._initial_state["n"]):
            seg.hh.m = m
            seg.hh.n = n
            seg.hh.h = h

    def load_morphology(self):
        cell = h.Import3d_SWC_read()
        cell.input("./resources/neuron_nmo/amaral/CNG version/c91662.CNG.swc")
        i3d = h.Import3d_GUI(cell, False)
        i3d.instantiate(self)

    def connect_input(self, stimuli, seg):
        '''
        :param stimuli: Poisson_Times class object
        :param seg: NEURON simulation segment
        :return:
        '''
        syn = h.ExpSyn(seg)
        syn.tau = stimuli.tau
        syn.e = stimuli.rev_potential

        vec_stim_times = h.Vector(stimuli.event_times)
        vec_stim = h.VecStim()
        vec_stim.play(vec_stim_times)

        nc = h.NetCon(vec_stim, syn)
        nc.weight[0] = 1  # stimuli.weight
        nc.delay = stimuli.delay

        self.syns.append(syn)
        self.netcons.append(nc)

        netstims = [h.NetStim() for stim_time in stimuli.event_times]
        for netstim, event_time in zip(netstims, stimuli.event_times):
            netstim.number = 1
            netstim.start = event_time
            netcon = h.NetCon(netstim, syn)
            netcon.weight[0] = stimuli.weight
            netcon.delay = 0 * ms

            self.netcons.append(netcon)
        self.stims.extend(netstims)

    def load_stimuli_from_file(self, stimuli_file):
        with open(stimuli_file, 'r') as fin:
            stimuli_json = json.load(fin)
        for seg_ind in stimuli_json:
            stimuli = Poisson_Times(
                stimuli_json[seg_ind]['stim_type'],
                stimuli_json[seg_ind]['tau'],
                stimuli_json[seg_ind]['interval'],
                stimuli_json[seg_ind]['weight'],
                stimuli_json[seg_ind]['rev_potential'],
                event_times=stimuli_json[seg_ind]['event_times']
            )
            self.connect_input(stimuli, self.all_input_segments[int(seg_ind)])

In [25]:
reconstruction_duration = 100
initial_simulation_duration = 5000

In [29]:
ind = 0

# load spikes
with open(f'{data_dir}original_simulation_data/spikes_{ind}.txt', 'r') as fin:
    spikes = [float(x.strip()) for x in fin.readlines()]

# set up reconstruction cell
reconstruction_cell = Pyramidal()
reconstruction_cell.load_morphology()

# load stimuli
stimuli_files = os.listdir(f'{data_dir}original_simulation_data/stimuli_{ind}')
for stimuli_file in stimuli_files:    
    with open(f'{data_dir}original_simulation_data/stimuli_{ind}/{stimuli_file}', 'r') as f:
        stimuli_data = json.load(f)
        seg_ind = int(stimuli_file.split('.')[1])
        stimuli = Poisson_Times(
            _id = stimuli_data['_id'],
            tau = stimuli_data['tau'], 
            interval = stimuli_data['interval'], 
            weight = stimuli_data['weight'], 
            rev_potential = stimuli_data['rev_potential'], 
            event_times = stimuli_data['event_times']
        )
    reconstruction_cell.connect_input(stimuli, reconstruction_cell.all_input_segments[seg_ind])
state_vars = np.load(f'{data_dir}original_simulation_data/state_vars_{ind}.npy')

In [30]:
spikes

[66.02500000010644,
 83.20000000011035,
 104.5000000001152,
 153.10000000009768,
 222.27500000003477,
 238.6250000000199,
 280.6249999999817,
 299.04999999996494,
 316.9249999999487,
 349.9999999999186,
 383.5499999998881,
 434.14999999984207,
 449.8249999998278,
 497.0999999997848,
 515.5499999997842,
 538.9249999998692,
 553.3499999999217,
 586.8000000000434,
 607.6750000001193,
 646.5250000002607,
 703.2750000004671,
 740.2250000006015,
 763.2000000006851,
 803.5000000008317,
 856.8000000010256,
 903.7750000011965,
 956.6000000013887,
 995.6500000015308,
 1014.0750000015978,
 1029.8750000016553,
 1077.375000001828,
 1100.175000001911,
 1164.5750000021453,
 1204.4750000022905,
 1225.800000002368,
 1278.1500000025585,
 1313.0250000026854,
 1333.3750000027594,
 1366.8500000028812,
 1390.4750000029671,
 1421.4500000030798,
 1475.625000003277,
 1495.0000000033474,
 1590.6750000036955,
 1620.0750000038024,
 1652.100000003919,
 1714.4750000041458,
 1727.5750000041935,
 1745.025000004257,
 