In [5]:
%matplotlib inline
# from ipywidgets import interactive
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import time
from scipy import special
from scipy import optimize as opt
from scipy import integrate as scint
from scipy.stats import norm
from scipy.stats import binom
import scipy
import tqdm as tqdm
trange = lambda *x: tqdm.tqdm(range(*x))
import ipywidgets as widgets
from IPython.display import display


# Global matplotlib font size applied to every figure drawn below.
font = dict(size=10)
plt.rc('font', **font)


def drug(EC50,h, dose):
    """Return the fraction of activity remaining at ``dose`` (inhibitory Hill curve).

    Computes ``1 / (1 + (dose / EC50) ** h)``: 1 at zero dose, 0.5 at
    ``dose == EC50``, and falling towards 0 as the dose grows (for h > 0).

    Args:
        EC50: dose giving the half-maximal effect (same units as ``dose``).
        h: Hill slope.
        dose: drug concentration.

    Returns:
        The remaining-activity factor; 1 for ``dose == 0``, and 0.0 when
        ``(dose / EC50) ** h`` is too large to represent as a float.

    Raises:
        ZeroDivisionError: if ``EC50 == 0`` and ``dose != 0``.
    """
    if dose == 0:
        # Explicit branch: for h == 0, 0 ** 0 == 1 would otherwise give 0.5.
        return 1
    try:
        return 1/(1+pow(dose/EC50,h))
    except OverflowError:
        # (dose/EC50)**h exceeds the float range (very high dose or steep
        # Hill slope); 1/(1+huge) is below float resolution, i.e. 0.
        return 0.0

#neigbourhood matrix, k is average connection amount per neuron
def make_matrix(N,M,muee,muei,muie,muii,k):
    """Build a random sparse coupling matrix for N excitatory + M inhibitory units.

    Every one of the (N+M)**2 entries is kept independently with probability
    ``k`` (a probability, not a count). Kept entries get a block-dependent
    weight divided by N. Rows are targets and columns are sources, as used in
    ``J @ s``:

        J[:N, :N] = muee   (E <- E)      J[:N, N:] = muei   (E <- I)
        J[N:, :N] = muie   (I <- E)      J[N:, N:] = muii   (I <- I)

    Returns:
        (J, counts): the (N+M, N+M) weight matrix and, per row, the number of
        kept entries (float array).
    """
    size = N + M
    keep = np.random.binomial(1, k, (size, size)) / N
    # Start from the E<-E weight everywhere, then overwrite the other blocks.
    block_weights = np.full((size, size), muee, dtype=float)
    block_weights[:N, N:] = muei
    block_weights[N:, :N] = muie
    block_weights[N:, N:] = muii
    per_row = (keep * N).sum(axis=1)
    return block_weights * keep, per_row



    
def grad(c1,c2,mix=0):
    """Blend colour ``c1`` (mix=0) linearly into colour ``c2`` (mix=1).

    Both colours may be anything ``matplotlib.colors.to_rgb`` accepts; the
    blend is done per RGB channel and returned as a ``'#rrggbb'`` string.
    """
    rgb_start, rgb_end = (np.array(mpl.colors.to_rgb(colour)) for colour in (c1, c2))
    weight_start = 1 - mix
    return mpl.colors.to_hex(weight_start * rgb_start + mix * rgb_end)
def interpolate(a,b,p=0.5):
    """Linear interpolation: ``a`` at p=0, ``b`` at p=1 (midpoint by default)."""
    span = b - a
    return a + span * p
    

def run(b):
    """Callback of the "Run" button: run the simulation(s) configured in the GUI.

    ``b`` is the clicked button (unused). Parameters are read from the
    module-level widgets; all text and figures go into the ``output`` widget.

    * "Run many runs?" checked: for every k in "Average connection counts",
      sweep the drug concentration from start to end in
      ``concentration_steps`` steps. Run ``k_reps`` simulations per step and
      report/plot the mean drugged/baseline activity ratio.
    * unchecked: one simulation with the first k and the start concentration.

    TODO (original author): analyse again for redundant vars.
    """
    with output:
        # Parse inside the output context so that errors from malformed input
        # (e.g. a non-integer k) are shown with the rest of the output.
        k_array = [abs(int(ele)) for ele in k_list.value.split()]
        if not k_array:
            print("Please enter at least one average connection count.")
            return

        if isauto.value:
            for current_k in k_array:
                _run_concentration_sweep(current_k)
        else:
            _run_single(k_array[0])


def _run_concentration_sweep(current_k):
    """Sweep the drug concentration for one connection count k; print and plot the results."""
    print("calculating for k = " + str(current_k))

    drug_concentrations = []
    activity_before_drug_i = []
    activity_before_drug_e = []
    activity_after_drug_i = []
    activity_after_drug_e = []
    activity_decrease_e = []
    activity_decrease_i = []
    # Connection counts of the most recent concentration step (histogram below);
    # initialised so an empty sweep cannot leave the name undefined.
    con_counts = []

    n_steps = concentration_steps.value
    for concentration_step in range(0, n_steps+1):
        p = concentration_step/n_steps if n_steps != 0 else 0
        current_concentration = interpolate(concentration_start.value,concentration_end.value,p)
        drug_concentrations.append(current_concentration)

        print("k: ",current_k," from: ",concentration_start.value," to ", concentration_end.value, " current: ",current_concentration)

        times = None
        phiemeans = []
        phiimeans = []
        workingemean = []
        workingimean = []
        drugemean = []
        drugimean = []
        con_counts = []

        for _ in range(0,k_reps.value):
            # The nonlinearity plot (first argument) is not shown in sweep mode.
            times,pem,pim,wem,wim,dem,dim,ccount = sim(False,exc_amount.value,inh_amount.value,current_k,dec50.value,dhs.value ,current_concentration,experimenttime.value,cutouttime.value,drugtime_start.value,drugtime_end.value)
            phiemeans.append(pem)
            phiimeans.append(pim)
            workingemean.append(wem)
            workingimean.append(wim)
            drugemean.append(dem)
            drugimean.append(dim)
            con_counts.append(ccount)

        if display_each_plot.value:
            _plot_repetitions(times,phiemeans,phiimeans,workingemean,workingimean,drugemean,drugimean,con_counts)

        baseline_i = np.average(workingimean)
        baseline_e = np.average(workingemean)
        drugged_i = np.average(drugimean)
        drugged_e = np.average(drugemean)
        activity_before_drug_i.append(baseline_i)
        activity_before_drug_e.append(baseline_e)
        activity_after_drug_i.append(drugged_i)
        activity_after_drug_e.append(drugged_e)
        # Ratio of mean drugged to mean baseline activity. The excitatory ratio
        # previously divided the list `drugemean` by a float (TypeError).
        activity_decrease_i.append(drugged_i/baseline_i)
        activity_decrease_e.append(drugged_e/baseline_e)

    print("concentrations:")
    print(drug_concentrations)

    print("inhibitory activity before drug introduction:")
    print(activity_before_drug_i)
    print("inhibitory activity after drug introduction:")
    print(activity_after_drug_i)
    print("inhibitory activity decrease:")
    print(activity_decrease_i)

    print("excitatory activity before drug introduction:")
    print(activity_before_drug_e)
    print("excitatory activity after drug introduction:")
    print(activity_after_drug_e)
    print("excitatory activity decrease:")
    print(activity_decrease_e)

    plt.title("numbers of connections")
    plt.hist(con_counts)
    plt.show()

    plt.title("k = " + str(current_k))
    plt.plot(drug_concentrations,activity_decrease_e,color="#0000ff")
    plt.plot(drug_concentrations,activity_decrease_i,color="#FF0000")
    plt.ylabel('activity %')
    plt.xlabel('drug concentration (nM)')
    plt.ylim((0, 1))
    plt.legend(['Excitatory','Inhibitory'],bbox_to_anchor=(1.7, 1.00))
    plt.ticklabel_format(style='plain', axis='x')
    plt.show()


def _plot_repetitions(times,phiemeans,phiimeans,workingemean,workingimean,drugemean,drugimean,con_counts):
    """Print per-repetition summary values and plot every repetition's E/I activity trace."""
    n_reps = len(phiemeans)
    for i in range(n_reps):
        print(workingemean[i],workingimean[i],drugemean[i],drugimean[i],np.average(con_counts[i]))
    # Excitatory traces fade blue->cyan, inhibitory traces red->yellow across repetitions.
    for i in range(n_reps):
        plt.plot(times,phiemeans[i],color = grad("#0000FF", "#00FFFF", i/n_reps))
    for i in range(n_reps):
        plt.plot(times,phiimeans[i],color = grad("#FF0000", "#FFFF00", i/n_reps))
    legends = ['Excitatory'+ str(i+1) for i in range(n_reps)]
    legends += ['Inhibitory'+ str(i+1) for i in range(n_reps)]
    plt.legend(legends,bbox_to_anchor=(1.7, 1.00))
    plt.show()


def _run_single(current_k):
    """Run one simulation (showing the nonlinearity) for k = current_k and report it."""
    times,phiemeans,phiimeans,workingemean,workingimean,drugemean,drugimean,con_count = sim(True,exc_amount.value,inh_amount.value,current_k,dec50.value,dhs.value ,concentration_start.value,experimenttime.value,cutouttime.value,drugtime_start.value,drugtime_end.value)

    print("E: ", exc_amount.value," I: ", inh_amount.value," k approx: ",np.average(con_count))
    plt.hist(con_count)
    plt.show()
    print("healthy excitatory activity: ",workingemean)
    print("healthy inhibitory activity: ",workingimean)
    print("drugged excitatory activity: ",drugemean)
    print("drugged inhibitory activity: ",drugimean)

    echange = (drugemean/workingemean)
    ichange = (drugimean/workingimean)

    print("drugged excitatory activity:",echange*100,"%")
    print("drugged inhibitory activity:",ichange*100,"%")
    plt.plot(times,phiemeans,'b')
    plt.plot(times,phiimeans,'r')
    plt.ylim((0, 1))
    plt.ylabel('activity')
    plt.xlabel('time')
    plt.legend(['Excitatory','Inhibitory'],bbox_to_anchor=(1.7, 1.00))
    plt.ticklabel_format(style='plain', axis='x')
    plt.show()


def sim(shownonlin,N,M,k,drugec50,drughill,c,simexperimenttime,simstabletime,simdrugtime_start,simdrugtime_end):
        """Simulate one random excitatory/inhibitory rate network with a drug window.

        N excitatory and M inhibitory units are coupled by ``make_matrix`` with
        connection probability k/(N+M). Unit states x are advanced with an
        explicit Euler step of size dt plus Gaussian noise of standard deviation
        sigma*sqrt(dt). Activities are s = nonlin(x). During the drug window
        every activity is multiplied by ``drug(drugec50, drughill, c)``.

        Args:
            shownonlin: if true, first plot the (excitatory) nonlinearity.
            N, M: numbers of excitatory / inhibitory units.
            k: expected number of connections per unit.
            drugec50, drughill, c: EC50, Hill slope and concentration for ``drug``.
            simexperimenttime: total simulated time (same time units as dt).
            simstabletime: start of the baseline averaging window; earlier
                samples are excluded.
            simdrugtime_start, simdrugtime_end: drug window; it is also the
                window averaged for the drugged activity.

        Returns:
            times, phiemeans, phiimeans (per-step mean E/I activity),
            workingemean, workingimean (baseline means),
            drugemean, drugimean (means over the drug window),
            connection_count (per-unit count returned by ``make_matrix``).
        """
        p = k/(N+M)  # connection probability -> k expected connections per row

        # Constant external input; zero because m_0 = 0.
        m_0 = 0
        w_0E = 1.0
        w_0I = 0.8
        I_E = w_0E*p*m_0
        I_I = w_0I*p*m_0

        sigma_0E = .75 #.5 for both #low for high ignition
        sigma_0I = .75

        theta_E = 1.0 #.8 for both
        theta_I = 1.0
        taui = 3.3 # inhibitory integration time constant (excitatory one is 1); high -> cycles
        g = 10 # non linearity gain
        dt = .03

        nonlin = lambda x,thresh,g : 1./(1+np.exp(g*(thresh-x)))

        if shownonlin:
            sampXs = np.linspace(-1,2,301)
            plt.title('Nonlinearity, Threshold 1, Gain 10')
            plt.plot(sampXs,nonlin(sampXs,theta_E,g))
            plt.show()

        # Mean block weights; row = target population, column = source (0 = E, 1 = I).
        const = 1.0
        wBarAB = np.array([[10*const, -7.5*const],
                           [7.5*const, -0.5*const]])

        J, connection_count = make_matrix(N,M,wBarAB[0,0],wBarAB[0,1],wBarAB[1,0],wBarAB[1,1],p)

        x = np.random.rand(N+M,1)
        # Per-unit thresholds so both populations go through a single nonlin call.
        thresholds = np.full((N+M,1), theta_E)
        thresholds[N:] = theta_I

        steps = int(simexperimenttime/dt)
        times = np.arange(steps)*dt

        # Step indices of the analysis windows. The drug is applied on exactly
        # the steps of the drug window, so the drugged means never include
        # undrugged samples.
        stable_idx = int(simstabletime/dt)
        drug_start_idx = int(simdrugtime_start/dt)
        drug_end_idx = int(simdrugtime_end/dt)
        drug_factor = drug(drugec50,drughill,c)  # constant, so computed once

        noise_sd_E = sigma_0E*np.sqrt(dt)
        noise_sd_I = sigma_0I*np.sqrt(dt)

        phiemeans = np.zeros(steps)
        phiimeans = np.zeros(steps)

        for i in trange(steps):
            noise_E = np.random.normal(0,noise_sd_E,size=(N,1))
            noise_I = np.random.normal(0,noise_sd_I,size=(M,1))

            s = nonlin(x,thresholds,g)
            if drug_start_idx <= i < drug_end_idx:
                s = s*drug_factor

            u = J @ s
            u[:N] += I_E
            u[N:] += I_I

            # Excitatory units relax with time constant 1, inhibitory with taui.
            x[:N] = (1-dt)*x[:N] + dt*u[:N] + noise_E
            x[N:] = (1-dt/taui)*x[N:] + (dt/taui)*u[N:] + (1/taui)*noise_I

            phiemeans[i] = np.mean(s[:N])
            phiimeans[i] = np.mean(s[N:])

        workingemean = np.mean(phiemeans[stable_idx:drug_start_idx])
        workingimean = np.mean(phiimeans[stable_idx:drug_start_idx])
        drugemean = np.mean(phiemeans[drug_start_idx:drug_end_idx])
        drugimean = np.mean(phiimeans[drug_start_idx:drug_end_idx])

        return times,phiemeans,phiimeans,workingemean,workingimean,drugemean,drugimean,connection_count
    
        

# GUI
# All widgets below share one description width and one Layout instance.
style = {'description_width': '200px'}
layout = widgets.Layout(width='400px') #set width and height

isauto = widgets.Checkbox(
    value=False,
    description='Run many runs?',
    disabled=False,
    indent=False,
    style = style,
    layout=layout,
)
display_each_plot = widgets.Checkbox(
    value=True,
    description='Display each plot?',
    disabled=False,
    indent=False,
    style = style,
    layout=layout,
)

# Integer inputs use BoundedIntText: plain IntText has no min/max traits, so
# the bounds given here were previously ignored (e.g. negative neuron counts
# or repetition counts could be entered).
exc_amount = widgets.BoundedIntText(
    value=1500,
    min = 0,
    max = 10e20,
    description='Excitatory neurons',tooltip='Excitatory neurons:',
    disabled=False,
    style = style,
    layout=layout
)
inh_amount = widgets.BoundedIntText(
    value=1500,
    min = 0,
    max = 10e20,
    description='Inhibitory neurons:',tooltip='Inhibitory neurons:',
    disabled=False,
    style = style,
    layout=layout
)

# Whitespace-separated list of k values; run() parses it with int().
k_list = widgets.Text(
    value='3 6 10 100 400',
    placeholder='1',
    description='Average connection counts:',
    style = style,
    layout=layout
)

k_reps = widgets.BoundedIntText(
    value=10,
    min = 1,
    max = 10e20,
    step=1,
    description='Simulation repetitions:',tooltip='Simulation repetitions:',
    disabled=False,
    style = style,
    layout=layout
)

experimenttime = widgets.BoundedFloatText(
    value=60.0,
    step=0.1,
    min=0.1,
    max=3600,
    description='Experiment duration (s):', tooltip='Experiment duration:',
    disabled=False,
    style = style,
    layout=layout
)

# NOTE: the min/max bounds below copy other widgets' values once, at creation
# time; they do not follow later edits of those widgets.
drugtime_start = widgets.BoundedFloatText(
    value=30.0,
    step=0.1,
    min=0,
    max = experimenttime.value,
    description='Drug introduction time (s):', tooltip='Drug introduction time:',
    disabled=False,
    style = style,
    layout=layout
)
drugtime_end = widgets.BoundedFloatText(
    value=experimenttime.value,
    step=0.1,
    min=drugtime_start.value,
    max = experimenttime.value,
    description='Drug washout time (s): ', tooltip='Drug washout time:',
    disabled=False,
    style = style,
    layout=layout
)

cutouttime = widgets.BoundedFloatText(
    value=10.0,
    step=0.1,
    min=0,
    max = drugtime_start.value,
    description='Activity cutout time (s):', tooltip='Activity cutout time:',
    disabled=False,
    style = style,
    layout=layout
)

# NOTE: `disabled=isauto.value` is evaluated once (isauto starts unchecked),
# so these fields start enabled and are not toggled by the checkbox.
concentration_start = widgets.BoundedFloatText(
    value=0.0,
    min=0,
    max=10e20,
    step=0.01,
    description='Drug concentration start (nM) :',    tooltip='Drug concentration start (nM) :',
    disabled=isauto.value,
    style = style,
    layout=layout
)
concentration_end = widgets.BoundedFloatText(
    value=0.0,
    min=0,
    max=10e20,
    step=0.01,
    description='Drug concentration end (nM) :',    tooltip='Drug concentration end (nM) :',
    disabled=isauto.value,
    style = style,
    layout=layout
)
concentration_steps = widgets.BoundedIntText(
    value=1,
    min=1,
    max=10e20,
    step=1,
    description='Drug concentration steps:',    tooltip='Drug concentration steps :',
    disabled=isauto.value,
    style = style,
    layout=layout
)
dhs = widgets.BoundedFloatText(
    value=1,
    step=0.001,
    max=100,
    min=0,
    description='Drug Hill slope :',  tooltip='Drug Hill slope :',
    style = style,
    layout=layout
)
dec50 = widgets.BoundedFloatText(
    value=1,
    step=0.1,
    min=0,
    max=10e20,
    description='Drug EC50 (nM):', tooltip='Drug EC50 :',
    style = style,
    layout=layout
)


enterbutton = widgets.Button(
    description='Run',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''

    icon='check' # (FontAwesome names without the `fa-` prefix)
)
clearbutton = widgets.Button(
    description='Clear',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click me',
    icon='check' # (FontAwesome names without the `fa-` prefix)
)


# Everything printed or plotted by run() is captured here.
output = widgets.Output()

def clear(b):
    """Callback of the "Clear" button: erase everything shown in ``output``.

    ``b`` is the clicked button (unused).
    """
    target = output
    target.clear_output()



# Reference Hill parameters for a few drugs, to copy into the EC50 / Hill
# slope fields (EC50 presumably in nM like the widgets -- TODO confirm).
print("Tetrodotoxin:        EC50 = 10.400111, Hill Slope = 0.9690177" )
print("Tetraethylammonium:  EC50 =  73139000, Hill Slope = 0.7834354" )
print("Cocaine:             EC50 = 27027.600, Hill Slope = 0.9484492" )


tab = widgets.Tab()
drug_tab = widgets.VBox([experimenttime,cutouttime,drugtime_start,drugtime_end,dec50,dhs,concentration_start,concentration_end,concentration_steps])
run_tab = widgets.VBox([isauto,display_each_plot,exc_amount,inh_amount,k_list,k_reps])
tab.children =[ drug_tab,run_tab]
# set_title works on ipywidgets 7 and 8; the `titles` trait exists only in 8.
for tab_index, tab_title in enumerate(['Drug Parameters', 'Simulation Parameters']):
    tab.set_title(tab_index, tab_title)

display(tab)
display(widgets.HBox([enterbutton, clearbutton]))
display(output)


enterbutton.on_click(run)
clearbutton.on_click(clear)


Tetradoxin:          EC50 = 10.400111, Hill Slope = 0.9690177
Tetraethyloammonium: EC50 =  73139000, Hill Slope = 0.7834354
Cocaine:             EC50 = 27027.600, Hill Slope = 0.9484492


Tab(children=(VBox(children=(BoundedFloatText(value=60.0, description='Experiment duration (s):', layout=Layou…

HBox(children=(Button(description='Run', icon='check', style=ButtonStyle()), Button(description='Clear', icon=…

Output()