# state_reconstruction_morpho
State variable reconstruction and model visualization

In [1]:
from neuron import h
from neuron.units import mV, ms
import numpy as np
import plotly
import matplotlib.pyplot as plt
from matplotlib import cm
import queue
import json
import gc
import os
import time
import pickle as pkl

import sys
sys.path.insert(1, "../utils/")
import Morpho, Stimuli

# NEURON hoc libraries: stdrun (run control used by h.continuerun), stdlib, and
# import3d (3-D morphology import, e.g. the SWC file used below).
# stdrun.hoc was previously loaded twice; loading it once is sufficient.
h.load_file("stdrun.hoc")
h.load_file("stdlib.hoc")
h.load_file("import3d.hoc")

1.0

In [3]:
np.random.randint(low=0, high=250, size=2)  # scratch: two random ints in [0, 250); unseeded, displayed only (not stored)

array([238,  43])

In [5]:
data_dir = "../data/state_reconstruct_morpho/"  # root for the pickled stimuli sets and spiking histories

In [19]:
# Load a previously saved stimuli set by index. In the captured output below it
# is a list of {segment index: Stimuli.Poisson_Times} dicts.
# NOTE(review): this reads `stimuli_set_{ind}.pkl` (with underscore), whereas the
# generation cell writes `stimuli_set{N}.pkl` (no underscore) and stores a single
# dict. Confirm which naming/format is current.
# Only unpickle files you produced yourself: pickle.load can execute arbitrary code.
ind = 6
stimuli_path = f'{data_dir}stimuli_sets/stimuli_set_{ind}.pkl'
with open(stimuli_path, 'rb') as stimuli_file:
    stimuli = pkl.load(stimuli_file)
stimuli

[{282: <Stimuli.Poisson_Times at 0x7fe12ace92d0>,
  493: <Stimuli.Poisson_Times at 0x7fe12a9d1420>,
  143: <Stimuli.Poisson_Times at 0x7fe12aa25600>,
  623: <Stimuli.Poisson_Times at 0x7fe12aa27850>,
  361: <Stimuli.Poisson_Times at 0x7fe12aa279d0>,
  98: <Stimuli.Poisson_Times at 0x7fe12aa27160>,
  638: <Stimuli.Poisson_Times at 0x7fe12aa27220>,
  286: <Stimuli.Poisson_Times at 0x7fe12aa27d90>,
  753: <Stimuli.Poisson_Times at 0x7fe12aa27f10>,
  716: <Stimuli.Poisson_Times at 0x7fe12aa27b50>,
  680: <Stimuli.Poisson_Times at 0x7fe12a803ee0>,
  230: <Stimuli.Poisson_Times at 0x7fe12a8033a0>,
  688: <Stimuli.Poisson_Times at 0x7fe12acd56c0>,
  141: <Stimuli.Poisson_Times at 0x7fe12acd7850>},
 {754: <Stimuli.Poisson_Times at 0x7fe12a8166b0>,
  132: <Stimuli.Poisson_Times at 0x7fe12ab47bb0>,
  6: <Stimuli.Poisson_Times at 0x7fe12a846c20>,
  249: <Stimuli.Poisson_Times at 0x7fe12aa2d8a0>,
  41: <Stimuli.Poisson_Times at 0x7fe12aa2c040>,
  369: <Stimuli.Poisson_Times at 0x7fe12aa2c100>,
  1

In [20]:
# Load the spiking histories saved under the same index `ind` as the stimuli set.
# In the captured output below this is a list of dicts, each with a 'v' entry.
histories_path = f'{data_dir}spiking_histories/spiking_histories_{ind}.pkl'
with open(histories_path, 'rb') as histories_file:
    spiking_histories = pkl.load(histories_file)
spiking_histories

[{'v': [9.619071851730356,
   18.74706652775226,
   26.350919003211466,
   29.726364498164724,
   31.731804131579917,
   33.401255919616524,
   35.0071992271711,
   36.523522204810504,
   37.793067288232045,
   38.556658925655356,
   38.407606656763036,
   36.707479035652554,
   32.548816594785805,
   24.992381324631786,
   13.816752090595138,
   0.2861952353895729,
   -13.352352227039805,
   -25.337106587227385,
   -35.01110244121645,
   -42.46069763466118,
   -48.051772466903294,
   -52.183719059696855,
   -55.2010633242599,
   -57.37467037914697,
   -58.90707014078377,
   -59.94375388541288,
   -60.58403097086235,
   -60.88946663412421,
   8.970541715676832,
   7.882875018735711,
   7.120193951358717,
   4.739317650068679,
   1.1407050799671394,
   -7.756597132135848,
   -22.560519929534895,
   -31.5482335370834,
   -35.47113280373646,
   -38.35116579040035,
   -40.24069325147669,
   -41.17613157200416,
   -32.33322133521127,
   -37.40329918202062,
   -41.07772911148077,
   -43.4646

## create spiking stimuli sets
We randomly place stimuli along the morphology of the cell. Not every random placement makes the cell spike, so we simulate each placement and save only those that produce enough spikes (more than 20).

In [3]:
duration = 5000 # ms; any longer than 10000 is a bit fraught, there are lots of state variables to keep track of
min_spikes = 20  # a stimuli set is kept only if it drives more than this many spikes

# The next file index is the number of files already in the directory.
num_stim_sets = len(os.listdir(f'{data_dir}stimuli_sets/'))

start_time = time.time()
print(f'num_stim_sets: {num_stim_sets}')
pyr = Morpho.Pyramidal("../morphology/c91662.CNG.swc", record_spiking_histories=True)
stim_params = Stimuli.MorphoStimParams(pyr)
stimuli = Stimuli.place_stims_along_morphology(stim_params.stim_scaffold, duration=duration)
for seg in stimuli:
    pyr.connect_input(stimuli[seg], pyr.all_input_segments[seg])

print('running sim')
h.finitialize(-65)
h.continuerun(duration)
print(f'sim done after {round(time.time() - start_time, 3)} seconds')

print(f'sim done, spike count: {len(pyr.spike_times)}')
if len(pyr.spike_times) > min_spikes:
    # NOTE(review): saved as `stimuli_set{N}.pkl` (no underscore), while the loading
    # cell near the top reads `stimuli_set_{ind}.pkl`. Confirm which naming is current.
    # Mode 'xb' refuses to overwrite an existing file. A listdir-based index can
    # collide with an existing name if files were ever removed.
    with open(f'{data_dir}stimuli_sets/stimuli_set{num_stim_sets}.pkl', 'xb') as handle:
        pkl.dump(stimuli, handle, protocol=pkl.HIGHEST_PROTOCOL)
    num_stim_sets += 1
    print('stim set saved')


num_stim_sets: 20
running sim
sim done after 7.485 seconds
sim done, spike count: 153
stim set saved


In [5]:
# Number of spiking-history snapshots recorded during the run above (record_spiking_histories=True).
# The displayed value (153) matches the spike count printed by that cell.
len(pyr.spiking_histories)

153

In [None]:
# Build a second cell and wire it to stimuli set `ind`, using the naming written by
# the generation cell (`stimuli_set{N}.pkl`). pyr2 is wired here but not simulated
# anywhere in this notebook.
pyr2 = Morpho.Pyramidal("../morphology/c91662.CNG.swc", record_spiking_histories=True)

with open(f'{data_dir}stimuli_sets/stimuli_set{ind}.pkl', 'rb') as handle:
    stimuli2 = pkl.load(handle)

for seg, stim in stimuli2.items():
    pyr2.connect_input(stim, pyr2.all_input_segments[seg])

In [None]:
# Snapshot the reference run before any reconstruction re-simulates `pyr`:
# spike times as a plain Python list ...
spikes = [*pyr.spike_times]

# ... and each recorded membrane-voltage trace as one row of a 2-D array.
# Converting the hoc Vectors element by element is the slow part of this cell.
obv_v = np.array(list(map(list, pyr.v)))

In [None]:
# Flatten every stimulus event into a tuple
#   (_id, event_time, tau, rev_potential, seg_ind, weight)
# and order them chronologically. sorted() is stable, so ties keep insertion order.
stim_events = sorted(
    (
        (stim._id, event_time, stim.tau, stim.rev_potential, seg_ind, stim.weight)
        for seg_ind, stim in stimuli.items()
        for event_time in stim.event_times
    ),
    key=lambda event: event[1],
)

In [None]:
def event_sim_reconstruct(events, start_time, sim_duration, history=None):
    '''
    Re-simulate the global cell `pyr` in isolation, driven only by `events`.

    One ExpSyn is placed on each input segment that receives at least one event.
    Each event is delivered by its own single-shot NetStim at `event.t`, in ms
    after the start of the reconstruction window.

    :param events: iterable of Morpho.Event-like objects exposing `t`, `tau`,
        `rev_potential`, `seg_ind` and `weight`. It may be empty; the cell then
        runs with no synaptic input (previously this raised ValueError).
    :param start_time: absolute window start in the original simulation.
        Currently unused; kept for call-site compatibility.
    :param sim_duration: length of the reconstruction run passed to h.continuerun.
    :param history: optional state snapshot forwarded to
        `pyr.set_initialize_state` before initialization. Falsy values skip it.
    :return: `list(pyr.v)`, the cell's recorded voltage vectors.
        NOTE(review): these are presumably the same hoc Vector objects the next
        run re-records into. Callers should copy them (e.g. to numpy) first.
    '''
    events = list(events)  # iterated twice below; accept any iterable
    pyr._clear_cell(False)

    # One synapse per input segment. Stimuli are keyed by segment (one stimulus per
    # segment), so all events on a segment share tau and reversal potential.
    syns = {}
    for event in events:
        if event.seg_ind not in syns:
            syn = h.ExpSyn(pyr.all_input_segments[event.seg_ind])
            syn.tau = event.tau
            syn.e = event.rev_potential
            syns[event.seg_ind] = syn

    # initialize cell with history
    if history:
        pyr.set_initialize_state(history)

    # One single-shot NetStim per event. The list keeps the NetStims referenced
    # for the whole run.
    netstims = []
    for event in events:
        netstim = h.NetStim()
        netstim.number = 1
        netstim.start = event.t
        netcon = h.NetCon(netstim, syns[event.seg_ind])
        netcon.weight[0] = event.weight
        netcon.delay = 0 * ms
        pyr.netcons.append(netcon)
        netstims.append(netstim)

    # run simulation
    h.finitialize(-65)
    h.continuerun(sim_duration)
    return list(pyr.v)

## V reconstruction

## Troubleshooting

In [None]:
# troubleshoot single instance
reconstruct_duration = 100  # ms
samples_per_ms = 40  # observed traces are indexed at 40 samples per ms (dt = 0.025 ms)
ind = 15
starting_spike = spikes[ind]
# stimulus events arriving within the reconstruction window after this spike
subsequent_stims = [
    stim for stim in stim_events
    if starting_spike < stim[1] <= starting_spike + reconstruct_duration
]

# event times are re-expressed relative to the spike that opens the window
events = [
    Morpho.Event(
        _id=stim[0],
        t=stim[1] - starting_spike,
        tau=stim[2],
        rev_potential=stim[3],
        seg_ind=stim[4],
        weight=stim[5],
    )
    for stim in subsequent_stims
]

# get the comparable window of the observed traces
start_ind = starting_spike * samples_per_ms
stop_ind = (starting_spike + reconstruct_duration) * samples_per_ms
comparable_obvs_window = obv_v[:, round(start_ind):round(stop_ind + 1)]

# `spiking_history` was never defined. Use the saved state for this spike index.
# NOTE(review): `spiking_histories` must come from the same simulation as `spikes`.
re_df = event_sim_reconstruct(events, starting_spike, reconstruct_duration, history=spiking_histories[ind])
re_df = np.array([list(vec) for vec in re_df])

print(comparable_obvs_window.shape)
print(re_df.shape)

# 0 .. reconstruct_duration ms inclusive. linspace fixes the point count, whereas
# np.arange with a float stop can gain or lose the endpoint.
_t = np.linspace(0, reconstruct_duration, round(reconstruct_duration * samples_per_ms) + 1)

seg_to_plot = 100
event_times = [event.t for event in events]

fig, axes = plt.subplots(2, 1, figsize=(15, 6), sharex=True)
axes[0].plot(_t, comparable_obvs_window[seg_to_plot])
axes[0].vlines(event_times, -80, -70)  # ticks mark stimulus arrival times
axes[0].set_title('observed')

axes[1].plot(_t, re_df[seg_to_plot])
axes[1].vlines(event_times, -80, -70)
axes[1].set_title('reconstructed')
axes[1].set_xlabel('time since spike (ms)')
for ax in axes:
    ax.set_ylabel('Vm (mV)')

## running experiment

In [None]:
reconstruct_duration = 100  # ms
samples_per_ms = 40  # observed traces are indexed at 40 samples per ms (dt = 0.025 ms)
n = 100 # number of spikes to reconstruct

# `max_time` was never defined. Derive it from the recorded traces so that every
# selected window fits entirely inside the observation.
max_time = (obv_v.shape[1] - 1) / samples_per_ms
viable_spikes = [spike for spike in spikes if spike < max_time - reconstruct_duration]
# np.random.choice(replace=False) cannot draw more items than are available
n = min(n, len(viable_spikes))
spike_inds = np.random.choice(np.arange(0, len(viable_spikes)), n, replace=False)

# Shared time axis for the window plots further down (0 .. reconstruct_duration ms inclusive).
_t = np.linspace(0, reconstruct_duration, round(reconstruct_duration * samples_per_ms) + 1)

reconstructed_dfs = []
comparable_obvs_dfs = []

print('______PROGRESS______')
progress = 0.0
for i, ind in enumerate(spike_inds):
    # assumes `spikes` is in ascending time order, so `viable_spikes` is a prefix of
    # it and `ind` indexes both lists (and `spiking_histories`) consistently
    starting_spike = spikes[ind]
    subsequent_stims = [
        stim for stim in stim_events
        if starting_spike < stim[1] <= starting_spike + reconstruct_duration
    ]

    # `Event` alone was undefined; the class lives in the Morpho module
    events = [
        Morpho.Event(
            _id=stim[0],
            t=stim[1] - starting_spike,
            tau=stim[2],
            rev_potential=stim[3],
            seg_ind=stim[4],
            weight=stim[5],
        )
        for stim in subsequent_stims
    ]

    # get the comparable window of the observed traces
    start_ind = starting_spike * samples_per_ms
    stop_ind = (starting_spike + reconstruct_duration) * samples_per_ms
    comparable_obvs_window = obv_v[:, round(start_ind):round(stop_ind + 1)]

    # `spiking_history` was never defined. Use the saved state for this spike index.
    # NOTE(review): `spiking_histories` must come from the same simulation as `spikes`.
    re_df = event_sim_reconstruct(events, starting_spike, reconstruct_duration, history=spiking_histories[ind])
    re_df = np.array([list(vec) for vec in re_df])

    reconstructed_dfs.append(re_df)
    comparable_obvs_dfs.append(comparable_obvs_window)

    # print one '=' roughly every 5% of the iterations
    if i/n > progress:
        print('=',end='')
        progress+=0.05

In [None]:
reconstructed_dfs[0][0].shape  # samples in one reconstructed trace (first window, first row)

In [None]:
comparable_obvs_dfs[0][0].shape  # samples in one observed trace (first window, first row)

## Compare observed and reconstructed Vm for one window

In [None]:
# Overlay observed vs reconstructed Vm of row 0 for one reconstructed window.
ind = 10
fig, ax = plt.subplots()
ax.plot(_t, comparable_obvs_dfs[ind][0, :], label='observed')
ax.plot(_t, reconstructed_dfs[ind][0, :], label='reconstructed')
ax.set_xlabel('time since spike (ms)')
ax.set_ylabel('Vm (mV)')
ax.legend()
plt.show()

## Get the differences (observed − reconstructed Vm)

In [None]:
# Element-wise observed-minus-reconstructed Vm, one array per reconstructed window.
difs = []
for ind in range(n):
    observed, reconstructed = comparable_obvs_dfs[ind], reconstructed_dfs[ind]
    difs.append(observed - reconstructed)
difs[0].shape

In [None]:
# TODO: this but ordered by distance from soma
first_window_difs = difs[0]
# For a 2-D input, matplotlib draws one line per column, with the row index on x.
plt.plot(first_window_difs)

plt.show()

## average difs by time

In [None]:
# Every segment of the cell, in section order. Each group holds (global index, segment)
# pairs selected by a substring of the section name. The indices address axis 1 of abs_difs.
all_segs = [seg for sec in pyr.all for seg in sec]
indexed_segs = list(enumerate(all_segs))
soma_segs, axon_segs, dend_segs, apic_segs = (
    [(i, seg) for i, seg in indexed_segs if compartment in str(seg.sec)]
    for compartment in ('soma', 'axon', 'dend', 'apic')
)

In [None]:
# Stack the per-window differences into one array (windows, rows, time samples) ...
difs = np.stack(difs)
print(difs.shape)

# ... and take element-wise absolute errors.
abs_difs = np.absolute(difs)

In [None]:
mean_difs_by_time = abs_difs.mean(axis=1)  # average over segments: one error trace per window
mean_difs_by_time.shape

In [None]:
mean_difs_by_seg = abs_difs.mean(axis=2)  # average over time: one error value per segment per window
mean_difs_by_seg.shape

In [None]:
# Display the raw absolute-error array (numpy abbreviates large arrays in its repr).
abs_difs

In [None]:
figures_dir = "./figures/"  # destination for the (commented-out) figure exports below

In [None]:
fig, ax = plt.subplots(figsize=(6.5, 4))

# All segments: per-window mean over segments, then averaged over windows.
ax.plot(_t, np.mean(mean_difs_by_time, axis=0), color='C0', label='all segments')

# The same curve restricted to each compartment type. Explicit cycle colours
# (C1..C4) keep each group's colour fixed even if a group is empty and skipped;
# averaging over zero segments would otherwise yield an all-NaN trace.
compartment_groups = [
    ('soma', soma_segs, 'C1'),
    ('axon', axon_segs, 'C2'),
    ('dendrites', dend_segs, 'C3'),
    ('apical dendrites', apic_segs, 'C4'),
]
for label, seg_group, colour in compartment_groups:
    if not seg_group:
        continue
    group_difs = abs_difs[:, [seg_ind for seg_ind, _seg in seg_group], :]
    ax.plot(_t, np.mean(np.mean(group_difs, axis=1), axis=0), color=colour, label=label)

ax.set_ylabel('mean absolute Vm reconstruct error (mV)')
ax.set_xlabel('time (ms)')
ax.set_xlim(0, 25)

ax.legend()
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# fig.savefig(f'{figures_dir}reconstruction_morpho.svg', format='svg')

plt.show()

In [None]:
# Legend-only figure: five single-point lines (no markers, so nothing visible is
# drawn) give the legend one handle per entry in the default colour cycle.
for _handle_idx in range(5):
    plt.plot(1, 2)
plt.legend(['all segments', 'soma', 'axon', 'dendrites', 'apical dendrites'])

In [None]:
def seg2color(seg):
    """Map a segment to a colour name by the first matching section-name keyword.

    Keywords are checked in order: 'soma' -> 'orange', 'axon' -> 'green',
    'dend' -> 'red', 'apic' -> 'purple'. Anything else returns the string 'error'.
    """
    seg_name = str(seg)
    keyword_colours = (
        ('soma', 'orange'),
        ('axon', 'green'),
        ('dend', 'red'),
        ('apic', 'purple'),
    )
    for keyword, colour in keyword_colours:
        if keyword in seg_name:
            return colour
    return 'error'


In [None]:
# Per-segment reconstruction error at one time point vs path distance to the soma.
t = 10  # ms after the start of the reconstruction window
samples_per_ms = 40  # traces are sampled at 40 samples per ms

plt.figure(figsize=(2.5, 2))
plt.scatter(
    [h.distance(seg, pyr.soma[0](0.5)) for seg in all_segs],
    np.mean(abs_difs[:, :, t * samples_per_ms], axis=0),  # MAE over windows, one value per segment
    color=[seg2color(seg) for seg in all_segs],
    s=5,
    alpha=0.5,
)
plt.ylim(0, 5.5)

plt.title(f'reconstruct error per\nsegment after {t} ms')
plt.ylabel('Vm reconstruct\nMAE (mV)')
# h.distance returns path length in micrometres; the axis was previously labelled mm.
plt.xlabel('path distance to soma (µm)')

# plt.savefig(f'{figures_dir}reconstruction_morpho_inlay10.svg', format='svg')

plt.show()

In [None]:
# Per-segment reconstruction error at one time point vs path distance to the soma.
t = 20  # ms after the start of the reconstruction window
samples_per_ms = 40  # traces are sampled at 40 samples per ms

plt.figure(figsize=(2.5, 2))
plt.scatter(
    [h.distance(seg, pyr.soma[0](0.5)) for seg in all_segs],
    np.mean(abs_difs[:, :, t * samples_per_ms], axis=0),  # MAE over windows, one value per segment
    color=[seg2color(seg) for seg in all_segs],
    s=5,
    alpha=0.5,
)
plt.ylim(0, 5.5)

plt.title(f'reconstruct error per\nsegment after {t} ms')
plt.ylabel('Vm reconstruct\nMAE (mV)')
# h.distance returns path length in micrometres; the axis was previously labelled mm.
plt.xlabel('path distance to soma (µm)')

# plt.savefig(f'{figures_dir}reconstruction_morpho_inlay20.svg', format='svg')

plt.show()

In [None]:
abs_difs[:, :, 800].mean(axis=0).max()  # largest per-segment MAE at sample 800 (20 ms)

In [None]:
abs_difs[:, :, 800].mean(axis=0).shape  # one MAE value per segment

In [None]:
np.shape(abs_difs)  # (windows, segments, time samples)

## 3D visualization

In [None]:
from neuron import rxd

# Per-segment MAE at sample 800 (20 ms at 40 samples per ms), averaged over
# windows. Computed once here; it used to be recomputed over the whole array
# twice per segment inside the loop.
error_at_800 = np.mean(abs_difs[:, :, 800], axis=0)

# An rxd Parameter over all sections holds one value per segment, which the
# PlotShape below colours by.
cyt = rxd.Region(pyr.all)
error = rxd.Parameter(cyt)

vals = []
# same segment order as all_segs, i.e. axis 1 of abs_difs
segs_in_order = (seg for sec in pyr.all for seg in sec)
for i, seg in enumerate(segs_in_order):
    error.nodes(seg).value = error_at_800[i]
    vals.append(error_at_800[i])

In [None]:
abs_difs[..., 800].mean(axis=0).shape  # one value per segment; should match the rxd node count

In [None]:
tuple(abs_difs.shape)  # (windows, segments, time samples)

In [None]:
ps = h.PlotShape(False)
ps.variable(error)
ps.scale(0, 1)  # colour-scale range, in mV of MAE

ps2 = ps.plot(plotly, cmap=cm.cool)

# `event_type2color` is not defined anywhere in this notebook (NameError). Give
# each distinct stimulus `_id` a fixed colour from matplotlib's tab10 palette.
# Assumes `_id` values are hashable; sorted by str() for a deterministic order.
stim_ids = sorted({stim._id for stim in stimuli.values()}, key=str)
stim_id2color = {
    stim_id: '#{:02x}{:02x}{:02x}'.format(*(round(255 * c) for c in cm.tab10(k % 10)[:3]))
    for k, stim_id in enumerate(stim_ids)
}

for seg_ind in stimuli:
    seg = pyr.all_input_segments[int(seg_ind)]
    ps2.mark(seg, marker_size=3, marker_color=stim_id2color[stimuli[seg_ind]._id], marker_opacity=.9)

ps2.show()

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt

# Stand-alone colorbar for the 3-D error plot: same 'cool' colormap and 0-1 range
# as the PlotShape's ps.scale(0, 1).
fig, ax = plt.subplots(1, 1)

norm = mpl.colors.Normalize(vmin=0, vmax=1)
cbar = fig.colorbar(
    mpl.cm.ScalarMappable(norm=norm, cmap='cool'),
    ax=ax, extend='both', label='Vm reconstruct MAE (mV)')

ax.axis('off')
plt.show()