## Refactoring Odes Function for Reduced ICs
> The goal of this notebook (NB) is to reduce the number of initial conditions (ICs) that must be passed to the ODE function, so that sigma, v0, and tau, which should be constants, can no longer change during a run.  Sigma approaching zero is responsible for __.
> A second goal of this NB is to update the necessary functions so that the synaptic-coupling dictionary is always fully populated, with zeros wherever two nodes are not connected.  That is, each node now has NumNodes connection entries instead of NumNodes-1.

In [2]:
from scipy.integrate import odeint
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import os.path

COLOR = 'grey'
# Render all text, axis labels and tick labels in the same grey.
mpl.rcParams.update({
    'text.color': COLOR,
    'axes.labelcolor': COLOR,
    'xtick.color': COLOR,
    'ytick.color': COLOR,
})

In [3]:
def I(t=0, vI=(0, 0, 0)):
    '''
    Applied (stimulus) current for the HH model: a square pulse.

    Params:
        t: time (scalar or numpy array)
        vI: sequence [t_start, t_end, amplitude]. The default (0, 0, 0) gives
            zero current; it is a tuple so no mutable default is shared
            between calls.

    Returns:
        amplitude when t_start <= t <= t_end, otherwise 0 (element-wise when t
        is an array).
    '''
    t_start = vI[0]
    t_end = vI[1]
    amplitude = vI[2]

    # Comparisons yield booleans, which act as 0/1 when multiplied, so this is
    # `amplitude` inside the closed window [t_start, t_end] and 0 outside it.
    return (t >= t_start) * (t <= t_end) * amplitude

In [4]:
def plot_vnmh(x,t,numNodes,gatingVars=False,node_length=None):
    '''
    Plot the membrane voltage of every node and, optionally, its gating variables.

    Params:
        x: 2D array with one row per state variable and one column per time
           point (e.g. the transposed output of odeint). Node k's variables
           start at row k*node_length and begin with V, n, m, h.
        t: time vector used for the x axis (len(t) == x.shape[1])
        numNodes: number of nodes stored in x
        gatingVars: if True, draw a second figure with n, m and h for every node
        node_length: number of rows per node. Defaults to x.shape[0] // numNodes,
           which is correct whenever every node has the same number of state
           variables. (This used to be hard-coded to 21+(numNodes-1), which only
           matched one particular state layout.)
    '''
    if node_length is None:
        node_length = x.shape[0] // numNodes

    #Font sizes
    title_font = 20
    label_font = 15

    #Extract V, n, m, h for each node (the first four rows of each node block)
    V = np.zeros((numNodes, len(t)))
    n = np.zeros((numNodes, len(t)))
    m = np.zeros((numNodes, len(t)))
    h = np.zeros((numNodes, len(t)))
    for node in range(numNodes):
        base = node_length*node
        V[node] = x[base,:]
        n[node] = x[base+1,:]
        m[node] = x[base+2,:]
        h[node] = x[base+3,:]

    #Voltage figure
    plt.figure(figsize=(15,10))
    plt.title("Voltage as a Function of Time", fontsize=title_font, fontweight="bold")
    plt.ylabel("Volts (mV)", fontsize=label_font)
    plt.xlabel("Time (ms)", fontsize=label_font)
    for node in range(numNodes):
        plt.plot(t,V[node],label='Neuron ' + str(node+1))
    leg = plt.legend(loc='upper right')
    for text in leg.get_texts():
        text.set_color('black')
    plt.show()

    if gatingVars:
        plt.figure(figsize=(15,10))
        plt.title("Gating Variables as a Function of Time", fontsize=title_font, fontweight="bold")
        # Gating variables are dimensionless (they relax towards a sigmoid
        # x_inf in [0, 1]), so the previous "Volts (mV)" label was wrong.
        plt.ylabel("Gating variable (dimensionless)", fontsize=label_font)
        plt.xlabel("Time (ms)", fontsize=label_font)

        #Should probably just plot n's together or something... otherwise too many curves on one plot
        for node in range(numNodes):
            # n, m and h all belong to the same neuron, so they share its
            # 1-based number (previously m and h were labelled node+2 / node+3).
            suffix = str(node+1)
            plt.plot(t,n[node],label='n' + suffix)
            plt.plot(t,m[node],label='m' + suffix)
            plt.plot(t,h[node],label='h' + suffix)

        leg = plt.legend(loc='upper right')
        for text in leg.get_texts():
            text.set_color('black')
        plt.show()

In [5]:
#Functions replacing non-functionalized calls from odes_sigmoid

def tau_func(V, mean):
    '''
    Time constant of a gating variable.

    The true tau(V) curves have a voltage-dependent "bump"; that dependence is
    ignored for now and the precomputed mean time constant is returned as-is.
    V is accepted only so the call signature matches inf_func.

    A possible refinement is a Gaussian tau(V), which would need a centre, a
    width and an upward offset instead of a single mean.
    '''
    constant_tau = mean  # V is deliberately unused
    return constant_tau


def inf_func(V, v0, sigma):
    '''
    Approximate the quasi-steady-state value x_inf(V) with a sigmoid:

        x_inf = 1 / (1 + exp(-(V - v0) / sigma))

    Params:
        V: membrane voltage (scalar or numpy array)
        v0: centre of the sigmoid; x_inf(v0) == 0.5
        sigma: width of the sigmoid (1/sigma is the steepness). A large |sigma|
            gives a gradual transition; a negative sigma gives a decreasing curve.

    Returns:
        x_inf in [0, 1]. Special cases:
        * v0 == 0 and sigma == 0: returns 0, as before (all-zero parameters
          appear to mark an unused slot -- confirm with callers).
        * sigma == 0 otherwise: returns the sigma -> 0+ limit, a step that is 0
          below v0, 1 above it and 0.5 at V == v0, instead of dividing by zero.
    '''
    if v0 == 0 and sigma == 0:
        return 0
    if sigma == 0:
        # Previously only the v0 == 0 case was guarded, so any other v0 divided
        # by zero here.
        return np.heaviside(np.subtract(V, v0), 0.5)
    # A very small |sigma| makes the exponent huge; exp overflows to inf, which
    # still gives the correct x_inf == 0, so the overflow warning is silenced.
    with np.errstate(over='ignore'):
        x_inf = 1 / (1 + np.exp(-(V - v0) / sigma))
    return x_inf


def diffEQ_func(tau, x_inf, x):
    '''
    Right-hand side for one gating variable (dn/dt, dm/dt, dh/dt).

    Relaxation form: dx/dt = (x_inf - x) / tau, so x moves towards its
    steady-state value x_inf with time constant tau. When tau is 0 the
    derivative is returned as 0 (x is held fixed) rather than dividing by zero.
    '''
    if tau != 0:
        return (1/tau)*(x_inf - x)
    # tau == 0: skip the division and leave x unchanged
    return 0

In [None]:
def odes_RF(x,t,I,vI):
    '''
    Right-hand side for odeint: a network of Hodgkin-Huxley-style nodes with
    sigmoid steady states and synaptic couplings (refactor in progress).

    Params::
        x: flat state vector for all nodes; reshaped below to (numNodes, numParams),
           one row per node laid out as [V, gating variables..., synaptic variables...]
        t: current time; forwarded to I(t, vI) for the applied current
        I: callable I(t, vI) returning the applied current (e.g. a square pulse)
        vI: parameters forwarded to I (e.g. [t_start, t_end, amplitude])

    Returns:: list of derivatives, flattened row by row from the (numNodes, ...)
        derivative matrix

    NOTE(review): as written this function appears unable to run end to end; see
    the NOTE(review) comments below. The odeint calls later in this notebook use
    odes_progen, not odes_RF.
    '''
    
    ######### DEFINITIONS #########
    
    '''
    Things to change:
    1. maxParams will no longer exist
    '''
    
    numGVs = 3 #This is "Manual" for now, could update the functionality to vary this
    # Solve numNodes**2 + (1+numGVs)*numNodes - len(x) = 0 for numNodes, i.e. this
    # assumes len(x) == numNodes*(1+numGVs+numNodes).
    # NOTE(review): relies on np.roots placing the positive root at index 1 (the
    # ordering is not documented), and int() truncates, so 1.9999999 becomes 1 --
    # consider picking the positive root explicitly and rounding.
    p = [1,1+numGVs,-len(x)]
    roots = np.roots(p)
    numNodes = int(roots[1])
    
    # get_parameters reads a flat list of *strings* from disk.
    # NOTE(review): this is file I/O on every right-hand-side evaluation.
    param_list = get_parameters(numGVs)

    # NOTE(review): the root formula above assumes numNodes synaptic slots per node,
    # but numSC is numNodes-1, so numNodes*numParams != len(x) and the reshape
    # below raises ValueError.
    numSC = numNodes-1 
    ES = [0] * numSC #ENa #mV #Excitatory Neuron (by this definition)
    node_current = [0]*numNodes
    
    #Convert list input into a matrix
    numParams = 1+numGVs+numSC
    x = np.reshape(x,(numNodes,numParams))
    
    # NOTE(review): x is always 2D after the reshape above, so this branch never runs.
    if len(x.shape)==1:
        #Ie if it is 1D, expand the dim to be (1,X) so that we can index as if it were 2D
        x = np.expand_dims(x,0)
        
    # Column of the first synaptic variable in each node's row (after V and the gating variables)
    firstSIndex = 1+numGVs
    
    infs = np.zeros((numNodes,numGVs))
    taus = np.zeros((numNodes,numGVs))
    # NOTE(review): len(x) is numNodes here (number of rows), not numParams, so
    # dxdts[nodeTemp,param] goes out of range for param >= numNodes and the
    # flattened output would not match the length of the input state.
    dxdts = np.zeros((numNodes,len(x)))

    ######### CONSTANTS #########
    CmBase = 0.01 #uF
    #Altering the current for the nodes, essentially weighting them
    #In the test trials in NB 10, we used 1x for node1, and 0x for node2
    
    #Current weighting matrix
    node_current[0] = 1
    if numNodes>1:
        #Do something more advanced later
        node_current[1] = 0
    # Only node 1 receives the applied current (scaled by 1/CmBase); all other entries stay 0.
    node_current = np.array(node_current)*(1/CmBase)
    
    #For the synaptic coupling: 
    #^20 was a "good value" (ie one for which it functioned) in the previous NB
    gbars = [20] * numSC
    
    ######### LOOPING #########
    
    for nodeTemp in range(numNodes):  
        #Weighting matrix
        #WM = np.random.randint(low=998,high=1002,size=maxParams*4+1)/1000
        #Always breaks when doing it the above method
        # NOTE(review): WM[0]..WM[8] are read below, but len(x) is numNodes after the
        # reshape, so this list is too short (IndexError) whenever numNodes < 9.
        WM = [1]*(len(x)) #get_WM(numNodes)
        
        #Source: https://www.math.mcgill.ca/gantumur/docs/reps/RyanSicilianoHH.pdf
        #reversal potentials
        ENa = 55.17*WM[0] #mV
        EK = -72.14*WM[1] #mV
        EL = -57.99*WM[2] #-49.42 #mV
        #E4
        #E5
        #membrane capacitance
        Cm = CmBase*WM[5] #uF/cm^2
        
        #conductances (S is Siemens)
        gbarK = 0.36*WM[6] #mS/cm2
        gbarNa = 1.2*WM[7] #mS/cm2
        gbarL = 0.003*WM[8] #mS/cm2
        #So really there should be other values here for the 4th and 5th gating variables should those get used
        #gbar?(4) = ___
        #gbar?(5) = ___

        #For now, defining every neuron the same, but can change the "n,m,h" values...
        # NOTE(review): param_list is a flat Python list, so param_list[nodeTemp,1]
        # raises TypeError. The n**4 and m**3*h factors presumably should use this
        # node's state gating variables x[nodeTemp,1], x[nodeTemp,2], x[nodeTemp,3]
        # -- confirm.
        gK = (1/Cm)*gbarK*(param_list[nodeTemp,1]**4) #Why did I have a +maxParams after the 1?
        gNa = (1/Cm)*gbarNa*(param_list[nodeTemp,2]**3)*param_list[nodeTemp,3]
        gL = (1/Cm)*gbarL
        #Again, would need to add something for the 4th and 5th gating variables should they exist
        #...
        
        #Simplification: check connection in external function
        vnode_couplings = []
        svars = []   
        if numNodes==1:
            svars = [0]*numSC #1 neuron, so no connections.  Just return 0s
        else:
            # NOTE(review): read_SC selects row node-1 (it treats node as 1-based) but
            # nodeTemp is 0-based, so node 1 reads the *last* row of the matrix.
            SC_repo = read_SC(nodeTemp,numNodes) #RETURNS A DICTIONARY

            if SC_repo is not None and len(SC_repo)>0:
                for idx in range(numNodes):
                    idx += 1 #Adjust for offset, first node is 1 not 0
                    if SC_repo.get(idx) is not None:
                        adj_idx = int(idx)-1
                        vnode_couplings.append(adj_idx) #This is the node that it is coupled to
                        sval = float(SC_repo.get(idx))
                        svars.append(sval) #This is the s val, to be used in gs

                # Left-pad with zeros so both lists can be indexed by state column
                # (synaptic columns start at firstSIndex).
                my_zeros = [0]*(firstSIndex)
                svars_full = my_zeros+svars
                vnode_couplings_full = my_zeros+vnode_couplings
            else:
                # NOTE(review): execution continues, so svars_full / vnode_couplings_full
                # may be undefined (or stale from a previous node) when used below.
                print("Error: SC_repo returned None")
        # NOTE(review): read_SC fills every key 1..numNodes, so svars has numNodes
        # entries while gbars has numSC = numNodes-1; the product does not broadcast
        # for numNodes > 2. The s values also come from the file, not from the state
        # columns of x, so the integrated synaptic variables never feed back here.
        gs = np.array(gbars) * np.array(svars) #* (1/Cm)

        #Define the steady state (inf) values, and the time constants (tau)
        #Note that this must be completed first, so that all the taus, infs are defined before we can make the ODEs
        for param in range(numGVs):
            # NOTE(review): these column offsets (v0/sigma at 1,2 / 4,5 / 7,8 and tau at
            # 3 / 6 / 9) look like they come from an older layout in which v0, sigma
            # and tau were part of the state. In the current layout those columns hold
            # the gating/synaptic variables (or are out of range). Per this notebook's
            # goal they should presumably come from param_list -- confirm its ordering.
            #inf(V, v0, sigma)
            infs[nodeTemp,param] = inf_func(x[nodeTemp,0], x[nodeTemp,param*numGVs+1], x[nodeTemp,param*numGVs+2])
            #tau(V, mean)
            taus[nodeTemp,param] = tau_func(x[nodeTemp,0], x[nodeTemp,param*numGVs+3])

        #Define each ODE    
        for param in range(numParams): 
            if param==0:
                #dVdt = (gNa*(ENa-V) + gK*(EK-V) + gL*(EL-V) + gs1*(ES1-Va) + ... + gsn*(ESn-Vz) + I(t,vI))
                
                #TERM 1
                ionChannels = gNa*(ENa-x[nodeTemp,0]) + gK*(EK-x[nodeTemp,0]) + gL*(EL-x[nodeTemp,0])
                
                #TERM 2
                synCoups = 0
                for idx in range(numSC):
                    synCoups += gs[idx]*(ES[idx]-x[nodeTemp,0])
                    #^ Terms: (conducance gs_x) * (ES - V)

                #TERM 3
                appliedCurrent = I(t,vI)*node_current[nodeTemp]

                dxdts[nodeTemp,param] = ionChannels + synCoups + appliedCurrent
            elif param<(numGVs+1): #ie the gating variables
                #Note we use [nodeTemp,param-1] because there is no tau/inf for V so the matrix tau starts with n @ index 0
                dxdts[nodeTemp,param] = diffEQ_func(taus[nodeTemp,param-1], infs[nodeTemp,param-1], x[nodeTemp,param]) 
            elif param>=firstSIndex and sum(svars_full)==0: #ie if all the s vars are equal to zero
                break #ie just leave them as zero, and we are done with the loop so we can just break
            elif param>=firstSIndex:
                
                #STILL MANUAL
                tau_s = 30 #ms
                s_inf = 10.0/11.0 
                beta_s = 1/tau_s
                
                if numNodes==1:
                    pass #ie dsdts remain zeros
                else:
                    s = svars_full[param]
                    if s==0:
                        #Not sure if this will ever happen
                        #Goal is to get around cases where no s is passed in so s is 0
                        #^Maybe initilize that array to something biophysically impossible (e.g. can s be neagtive?)
                        dxdts[nodeTemp,param] = 0 
                    else:
                        coupled_neuron = vnode_couplings_full[param]
                        #coupled_neuron-1 because Neuron 1 corresponds to row 0
                        # NOTE(review): vnode_couplings already stores 0-based indices
                        # (adj_idx = idx-1), so subtracting 1 again picks the wrong row
                        # (row -1, i.e. the last node, when coupled_neuron is 0).
                        # Presynaptic activity is detected as V > 0 on that row.
                        alpha_s = (x[coupled_neuron-1,0]>0)*(1/tau_s)

                        dxdts[nodeTemp,param] = alpha_s*(1-s)-beta_s*s 
                
                #This should be the last case, it will run once the way it currently configured
                #break
            else:
                # NOTE(review): unreachable -- firstSIndex == numGVs+1, so the branches
                # above already cover every param value.
                pass
                #Should just be equal to zero, so leave as is (matrix is initialized as zeros)
    
    #Rearrange dxdts into a list so that we can unpack it
    flat_array = dxdts.flatten()
    ODE_list = flat_array.tolist()
    
    return ODE_list

In [31]:
def _parameter_file(numGVs):
    # Path of the parameter file for a given number of gating variables.
    # The old code built os.getcwd() + r"\\NetworkCouplingParams" + "<n>GVs.txt".
    # There was no separator before the filename, so on Windows this should
    # resolve to "NetworkCouplingParams<n>GVs.txt" directly in the working
    # directory (repeated backslashes collapse). On POSIX the backslashes were
    # ordinary filename characters, so the file landed next to, not inside, the
    # working directory under an odd name. os.path.join keeps the Windows
    # location, so existing files are still found, and is portable.
    # NOTE(review): this was probably meant to live inside the
    # NetworkCouplingParams directory used by read_SC -- confirm before moving it.
    return os.path.join(os.getcwd(), "NetworkCouplingParams" + str(numGVs) + "GVs.txt")


def set_parameters(numGVs,*args):
    '''
    Write the fixed sigmoid parameters (v0, sigma, tau triples) to a text file.

    Params:
        numGVs: number of gating variables; selects the file name and the
            number of lines written
        args[0]: flat sequence of at least 3*numGVs values. Line k of the file
            holds items 3k, 3k+1 and 3k+2, comma separated.

    get_parameters(numGVs) reads the same values back, as strings, in the same
    order.

    NOTE(review): the example list [-49,-36,-55, 18,10,-8, 2.785,0.258,2.810]
    looks grouped by field (three v0s, then three sigmas, then three taus)
    rather than per gating variable, so a line may not correspond to one
    gating variable -- confirm the ordering the ODE function expects.
    '''
    values = args[0]
    with open(_parameter_file(numGVs), "w") as file_object:
        for GV in range(numGVs):
            # Stride 3: one (v0, sigma, tau) triple per line. The previous stride
            # of numGVs only worked for numGVs == 3; otherwise it overlapped or
            # skipped values and could index past the end of the list.
            v0 = values[3*GV]
            sigma = values[3*GV+1]
            tau = values[3*GV+2]
            file_object.write(str(v0) + ", " + str(sigma) + ", " + str(tau) + "\n")


def get_parameters(numGVs):
    '''
    Read back the parameters written by set_parameters(numGVs, ...).

    Returns:
        flat list of strings, one per comma-separated field, with surrounding
        whitespace and newlines stripped, in file order. Callers must convert
        the values to numbers themselves.
    '''
    with open(_parameter_file(numGVs), "r") as file_object:
        # Flatten every field of every line; strip() drops the padding spaces
        # and the trailing newline.
        return [item.strip() for line in file_object for item in line.split(',')]
    
    
def read_SC(node, numNodes, filename="x10_0_5.npy"):
    '''
    Read one row of the saved synaptic-coupling matrix as a {node_number: value} dict.

    Params:
        node: 1-based node number; row node-1 of the saved matrix is used
            (so node=0 selects the last row through negative indexing).
        numNodes: number of entries to return. Keys are 1..numNodes and every
            key is always present, zeros included.
        filename: .npy file inside ./NetworkCouplingParams (the default is the
            previously hard-coded "x10_0_5.npy").

    Returns:
        dict mapping 1-based node numbers to the coupling value as a string
        (callers convert with float()).

    Raises:
        IndexError if the selected row has fewer than numNodes entries.
    '''
    # os.path.join instead of hand-written "\\" separators, so the path also
    # resolves outside Windows.
    sc_path = os.path.join(os.getcwd(), "NetworkCouplingParams", filename)
    row = np.load(sc_path)[node - 1]
    # Index the row directly rather than parsing str(row): numpy's printed form
    # is rounded to 8 significant digits and elides long rows with "...".
    return {key: str(row[key - 1]) for key in range(1, numNodes + 1)}

In [32]:
# Save the fixed sigmoid parameters for 3 gating variables, then read them
# back to check that the file round-trips.
numGVs = 3
input_list = [-49, -36, -55,
              18, 10, -8,
              2.785, 0.258, 2.810]
set_parameters(numGVs, input_list)

output_list = get_parameters(numGVs)
print(output_list)

['-49', '-36', '-55', '18', '10', '-8', '2.785', '0.258', '2.81']


In [None]:
# Initial state of node 1: V, n, m, h followed by two zero-initialised entries.
# NOTE(review): V1, n1, m1 and h1 are not defined in the visible cells; they must
# come from an earlier cell.
node1Base = [V1, n1, m1, h1] + [0, 0]

In [None]:
# --- Burn-in run: no applied current ---
print("Burn in run, looking for NO oscillations")
t = np.linspace(0, 50, 2000)  # time vector
vI = [0, 0, 0]  # [t_start, t_end, amplitude] for I(): no current

numNodes = 1  # number of nodes in the network
node1 = np.array(node1Base + [0] * (numNodes - 1))
network_params = node1

#####################################################################################

# NOTE(review): odes_progen is not defined in the visible cells; this notebook
# defines odes_RF -- confirm which right-hand side is intended.
x = odeint(odes_progen, network_params, t, args=(I, vI)).T
plot_vnmh(x, t, numNodes, gatingVars=True)

# The final burn-in state becomes the initial condition of the next run.
network_params_BI = x[:, -1]

# --- Actual run: square current pulse of amplitude 0.1 from t=5 to t=7 ---
print("Actual run, using a current of 0.1 A @ 5 ms.  No oscillations expected")
vI = [5, 7, 0.1]
x = odeint(odes_progen, network_params_BI, t, args=(I, vI)).T
plot_vnmh(x, t, numNodes, gatingVars=True)