# Readme

Computes the initial states $y_0$ of all single STG neurons.

If a neuron has a steady state in the absence of input current, this steady state is found using a grid search over clamping voltages, followed by a bounded local minimization of $|dv/dt|$.

If there is no such steady state, the neuron is simulated without clamping. The notebook then finds the lowest (smoothed) membrane potential within the longest interval between two consecutive spikes; the start of the simulation counts as a spike time.
The state at this point is used as the initial state.

In [None]:
import numpy as np
from matplotlib import pyplot as plt

In [None]:
from sys import path as sys_path
from os.path import abspath as os_path_abspath
sys_path.append(os_path_abspath('..'))
import addpaths

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import stg_model
import stg_parameters
import ode_solver
import data_utils
import stim_utils
import math_utils

# Create neuron

In [None]:
def get_ydot_eq(v_clamp, plot=False, return_y_eq=False, tmax=10):
    """Return |dv/dt| of the global ``neuron`` after relaxing at a clamped voltage.

    The module-level ``neuron`` is voltage-clamped at ``v_clamp`` and simulated
    so that its remaining state variables can relax. The clamp is then
    released and the voltage derivative (state index 0) is evaluated at the
    final state. A value close to zero means ``v_clamp`` is (close to) a
    steady state of the unclamped neuron.

    Parameters
    ----------
    v_clamp : float or array-like of size 1
        Clamping voltage. Length-1 arrays, as passed by
        ``scipy.optimize.minimize``, are reduced to a float.
    plot : bool
        If True, show solver progress, plot the clamped solution and print
        the resulting dv/dt.
    return_y_eq : bool
        If True, also return the final state of the clamped simulation.
    tmax : float
        Simulation duration; multiplied by 1e3 before being passed to the solver.

    Returns
    -------
    abs_vdot or (abs_vdot, y_eq)
        ``|dv/dt|`` at the relaxed state, optionally with that state vector.
    """
    # minimize() hands over x as a length-1 array; embedding that in the state
    # template below would create a ragged/object array, so use a scalar.
    v_clamp = float(np.squeeze(v_clamp))

    neuron.set_v_clamped(1)
    try:
        # Start from the values returned by eval_yinf_and_yf at the clamp
        # voltage (presumably the steady-state values y_inf -- confirm in stg_model).
        y0, _ = neuron.cmodel.eval_yinf_and_yf(
            t=0.0, y=np.array([v_clamp, 0.01, 0.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.8, 0.0, 0.0])
        )

        # NOTE(review): state index 1 is explicitly reset to 0.01; its meaning
        # is defined in stg_model -- confirm why it must not use the value above.
        y0[1] = 0.01

        solver = ode_solver.get_solver(
            odefun=neuron.eval_yinf_and_yf, t0=0.0, y0=y0, h0=1e-6, method='EE', adaptive=True
        )

        sol = solver.solve(tmax=tmax*1e3, show_progress=plot)
        if plot:
            sol.plot(max_nx_sb=5)
    finally:
        # Always release the clamp, even if the solver fails, so the global
        # neuron is not left in clamped mode for later computations.
        neuron.set_v_clamped(0)

    y_eq = sol.ys[-1, :]

    # Evaluate the unclamped dynamics at the relaxed state; index 0 is voltage.
    vdot_eq = neuron.eval_ydot(t=0, y=y_eq)[0]

    if plot:
        print(vdot_eq)

    if return_y_eq:
        return np.abs(vdot_eq), y_eq
    return np.abs(vdot_eq)

In [None]:
def get_v0_and_bounds(plot=True):
    """Grid-search the clamping voltage that minimizes |dv/dt|.

    ``get_ydot_eq`` is evaluated on 101 voltages evenly spaced in [-80, -30].
    The grid voltage with the smallest value is returned, together with a
    search interval for a subsequent local minimization:

    * if the best point is the first or last grid voltage, the interval is
      the adjacent grid cell;
    * otherwise each bound is the grid voltage with the smallest |dv/dt|
      on the respective side of the best point.

    Parameters
    ----------
    plot : bool
        If True, plot |dv/dt| over the grid and mark the point and bounds.

    Returns
    -------
    x0 : float
        Best grid voltage.
    (lb, ub) : tuple
        Lower and upper search bounds, with ``lb <= x0 <= ub``.
    """
    v_grid = np.linspace(-80, -30, 101)
    abs_vdots = [get_ydot_eq(v) for v in v_grid]
    i_best = np.argmin(abs_vdots)
    x0 = v_grid[i_best]

    # Grid values are distinct, so comparing indices is the same as comparing
    # x0 to the edge voltages.
    if i_best == 0:
        lb, ub = v_grid[0], v_grid[1]
    elif i_best == len(v_grid) - 1:
        lb, ub = v_grid[-2], v_grid[-1]
    else:
        left_vals = abs_vdots[:i_best]
        right_vals = abs_vdots[i_best + 1:]
        lb = v_grid[np.argmin(left_vals)]
        ub = v_grid[i_best + 1 + np.argmin(right_vals)]

    if plot:
        axis_labels = {'xlabel': 'Voltage', 'ylabel': '|dvdt|'}
        _, ax = plt.subplots(figsize=(7,1.5), subplot_kw=axis_labels)
        ax.semilogy(v_grid, abs_vdots, 'x-')
        ax.axvline(x0, c='k', alpha=0.6)
        ax.axvline(lb, c='r', alpha=0.6, ls='--', zorder=-10)
        ax.axvline(ub, c='c', alpha=0.6, ls='--', zorder=-10)
        ax.set_title('Search between red and cyan line. Black is lowest so far.')
        plt.show()

    assert lb <= x0 <= ub

    return x0, (lb, ub)

In [None]:
from scipy.optimize import minimize
from scipy.ndimage.filters import gaussian_filter
import metric_utils

def _interspike_minimum_idx(sol, thresh=20, sigma=20):
    """Return (t_eq_idx, idx0, idx1) for the smoothed voltage minimum of the longest inter-spike interval."""
    spike_times = metric_utils.find_spike_times_in_trace(
        ts=sol.ts, vs=sol.get_ys(yidx=0), thresh=thresh
    )
    # Treat t=0 as a pseudo-spike so the interval before the first spike counts too.
    spike_times = np.append(0, spike_times)

    if spike_times.size < 2:
        # The voltage rose above -20 but never crossed the spike threshold:
        # there is no inter-spike interval, so search the whole trace instead
        # of failing on argmax of an empty sequence.
        idx0, idx1 = 0, len(sol.ts) - 1
    else:
        idxspike = np.argmax(np.diff(spike_times))
        idx0 = np.argmin(np.abs(sol.ts - spike_times[idxspike]))
        idx1 = np.argmin(np.abs(sol.ts - spike_times[idxspike + 1]))

    # Smooth the voltage so the minimum is not picked from numerical jitter.
    t_eq_idx = idx0 + np.argmin(gaussian_filter(sol.ys[idx0:idx1, 0], sigma))
    return t_eq_idx, idx0, idx1


def _plot_y_eq(sol, t_eq_idx, has_eq, idx0=None, idx1=None):
    """Plot the voltage trace and mark the selected initial state (and search interval)."""
    v_eq = sol.ys[t_eq_idx, 0]

    fig, ax = plt.subplots(figsize=(7,1.5))
    ax.set_title(f"{'EQ. found' if has_eq else 'no EQ. found'}, v0={v_eq:.1f}")
    ax.plot(sol.ts, sol.ys[:,0], label='voltage')
    ax.set_ylabel('v(t)')
    # The solver runs to tmax*1e3, i.e. its time axis is in milliseconds.
    ax.set_xlabel('Time (ms)')

    ax.axhline(v_eq, c='r', alpha=0.6, label='v_EQ')
    ax.axvline(sol.ts[t_eq_idx], c='darkred', alpha=0.6, label='t_EQ')

    if not has_eq:
        ax.axvline(sol.ts[idx0], c='c', alpha=0.6, zorder=-10, ls=':')
        ax.axvline(sol.ts[idx1], c='c', alpha=0.6, zorder=-10, ls=':')

    ax.legend(loc='best')
    plt.show()


def find_y_eq(x0, bounds, plot=False, tmax=2):
    """Find an initial state y0 for the module-level ``neuron``.

    1. Refine the clamping voltage by minimizing |dv/dt| (``get_ydot_eq``)
       within ``bounds``, starting from ``x0``.
    2. Release the clamp and simulate the neuron freely for ``tmax * 1e3``
       solver time units, starting from the relaxed clamped state.
    3. If the voltage stays below -20 for the whole trace, the neuron is
       considered to have a steady state and the initial state of that
       simulation is returned. Otherwise, the state at the smoothed voltage
       minimum of the longest inter-spike interval is returned.

    Parameters
    ----------
    x0 : float
        Initial guess for the clamping voltage.
    bounds : tuple
        ``(lb, ub)`` bounds for the clamping voltage.
    plot : bool
        If True, plot the free simulation and the selected state.
    tmax : float
        Duration of the free simulation; multiplied by 1e3 for the solver.

    Returns
    -------
    y_eq : ndarray
        Selected state vector.
    """
    eq_sol = minimize(get_ydot_eq, x0=x0, bounds=[bounds], tol=0.01)

    # minimize returns x as a length-1 array; pass a plain float on.
    best_v_clamp = float(np.ravel(eq_sol.x)[0])
    _, y_eq = get_ydot_eq(v_clamp=best_v_clamp, plot=False, return_y_eq=True)

    neuron.set_v_clamped(0)

    solver = ode_solver.get_solver(
        odefun=neuron.eval_ydot, t0=0.0, y0=y_eq, h0=0.01, method='RKDP',
        # 'rtol' was previously given twice (so 'atol' was never set); the
        # second tolerance is the absolute one.
        adaptive=1, adaptive_params={'max_step': 0.1, 'rtol': 1e-5, 'atol': 1e-5},
    )
    sol = solver.solve(
        tmax=tmax*1e3, return_vars=['ys', 'ydots', 'failed_steps'],
        t_eval=math_utils.t_arange(0.0, tmax*1e3, 0.1)
    )

    # No excursion above -20 anywhere in the trace -> treat as a steady state.
    has_eq = np.all(sol.ys[:, 0] < -20)

    idx0 = idx1 = None
    if has_eq:
        t_eq_idx = 0
    else:
        t_eq_idx, idx0, idx1 = _interspike_minimum_idx(sol)

    y_eq = sol.ys[t_eq_idx, :]

    if plot:
        _plot_y_eq(sol, t_eq_idx, has_eq, idx0, idx1)

    return y_eq

# Get y_eq for all cells

In [None]:
# Filled in below: neuron name -> selected initial state vector y0.
neuron2y0 = dict()

In [None]:
# Stimulus with zero amplitude: the neurons receive no input current.
zero_stim = stim_utils.Istim(
    onset=0,
    offset=0,
    Iamp=0,
)

In [None]:
np.random.seed(42)

# Neuron names in reverse order of stg_parameters.neuron2gs (np.flip keeps
# the original element type used as dictionary keys).
neuron_names = np.flip(list(stg_parameters.neuron2gs))

for name in neuron_names:
    # Rebinds the module-level `neuron` that the helper functions above read.
    neuron = stg_model.stg_model(n_neurons=1, g_params=name, stim=zero_stim)
    print(name)

    x0, bounds = get_v0_and_bounds(plot=True)
    neuron2y0[name] = y_eq = find_y_eq(x0, bounds, plot=True)

    print('------------------------')

In [None]:
# Store the name -> initial-state mapping for later use.
data_utils.save_var(
    neuron2y0,
    '../../models/stg_neuron2y0.pkl',
)