In [1]:
import pickle
import numpy as np
import os
import sys
import plotly.graph_objs as go
import matplotlib.pyplot as plt

# Make the directory that contains the "GUI" folder importable (presumably the
# repository root holding the neurolib package). partition() keeps the text
# before the first os.sep + 'GUI', or the whole cwd when that is absent.
path = os.getcwd().partition(os.sep + 'GUI')[0]
if path not in sys.path:
    sys.path.append(path)

import neurolib.dashboard.layout as layout
import neurolib.dashboard.functions as functions
import neurolib.dashboard.data as data

In [2]:
# plot and save for all data points
from neurolib.models.aln import ALNModel
from neurolib.utils import plotFunctions as plotFunc
from neurolib.utils import costFunctions as cost

# Folder holding the pre-rendered plots. NOTE: this rebinds `path`, which the
# first cell used for the sys.path entry.
path = os.path.join(os.getcwd(), 'plots_')

# Single ALN model instance passed to the data helpers and click callbacks;
# data.set_parameters presumably configures it for this dashboard.
aln = ALNModel()
data.set_parameters(aln)
state_vars = aln.state_vars

In [3]:
##### LOAD BOUNDARIES
# Each boundary pickle holds a sequence whose entries 0 and 1 are used as the
# excitatory (`*_exc`) and inhibitory (`*_inh`) boundary coordinates.
# NOTE: pickle.load can execute arbitrary code -- only load trusted,
# locally generated files here.

def _load_boundary(filename):
    """Return entries 0 and 1 of the sequence pickled in ``filename``.

    Parameters
    ----------
    filename : str
        Path of the pickle file to read.

    Returns
    -------
    tuple
        ``(loaded[0], loaded[1])`` -- the exc and inh parts of the boundary.
    """
    with open(filename, 'rb') as fh:
        loaded = pickle.load(fh)
    return loaded[0], loaded[1]


boundary_bi_exc, boundary_bi_inh = _load_boundary('boundary_bi.pickle')
boundary_LC_exc, boundary_LC_inh = _load_boundary('boundary_LC.pickle')
boundary_LC_up_exc, boundary_LC_up_inh = _load_boundary('boundary_LCbi.pickle')

In [4]:
def setcase(obj, active):
    """Plotly ``on_change`` callback registered on the five case update menus.

    Both arguments supplied by plotly are ignored: set_case() re-reads every
    menu itself and compares the result with the module-level ``case``.
    """
    set_case(old_case=case)

def set_case(old_case):
    """Switch the dashboard to the case currently selected in the five update menus.

    The new five-character case string is built from ``button0.active`` ..
    ``button4.active``. When button4 goes from 1 to 0, entries 3-5 of button3
    are hidden (and button3 falls back to 0 if one of them was selected); when
    it goes from 0 to 1 they are shown again. Marker sizes are then reset, every
    annotation after the first nine and all layout images are removed, the case
    data is read from ``./data/<casepath>/`` (the case with its third digit
    forced to '1'), pushed into the figure, and the module-level state used by
    the click callbacks is rebound.

    Parameters
    ----------
    old_case : str
        The previously active five-character case string.

    Returns
    -------
    list
        ``[case, img_path, exc_, inh_, no_c_, both_c_, exc_1, inh_1, lenx_1,
        leny_1, ..., exc_4, inh_4, lenx_4, leny_4, cost1, cost2, cost3, cost4]``.
    """
    global case, img_path
    global exc_, inh_, no_c_, both_c_
    global exc_1, inh_1, lenx_1, leny_1, exc_2, inh_2, lenx_2, leny_2
    global exc_3, inh_3, lenx_3, leny_3, exc_4, inh_4, lenx_4, leny_4
    global cost1, cost2, cost3, cost4

    # Snapshot every selector before any widget is modified.
    selection = [menu.active for menu in (button0, button1, button2, button3, button4)]

    if selection[4] == 0 and old_case[4] == '1':
        for k in (3, 4, 5):
            button3.buttons[k].visible = False
        if old_case[3] in ('3', '4', '5'):
            # The selected option is about to disappear: fall back to option 0.
            # NOTE(review): assigning `active` may fire the registered on_change
            # callback (setcase) re-entrantly before this call finishes -- confirm.
            button3.active = 0
            selection[3] = 0
    elif selection[4] == 1 and old_case[4] == '0':
        for k in (3, 4, 5):
            button3.buttons[k].visible = True

    new_case = ''.join(str(value) for value in selection)

    for scatter in (scatter1, scatter2, scatter3, scatter4):
        functions.setdefaultmarkersize(layout.markersize, scatter)

    # Discard annotations/images added by click callbacks; the first 9
    # annotations are presumably the static labels added at figure build time.
    fig.layout.annotations = fig.layout.annotations[:9]
    fig.layout.images = []

    case = new_case
    # Data folders always carry '1' as the third digit, whatever button2 selects.
    casepath = new_case[:2] + '1' + new_case[3:5]
    readpath = os.sep.join(('.', 'data', casepath, ''))

    data_array = data.read_data(aln, readpath, new_case)
    exc_, inh_, both_c_, no_c_ = data_array[0:4]
    (exc_1, inh_1, lenx_1, leny_1,
     exc_2, inh_2, lenx_2, leny_2) = data_array[4:12]
    (exc_3, inh_3, lenx_3, leny_3,
     exc_4, inh_4, lenx_4, leny_4,
     cost1, cost2, cost3, cost4) = data_array[12:]
    data.update_data(fig, exc_1, inh_1, exc_2, inh_2, exc_3, inh_3, exc_4, inh_4)

    img_path = set_image_path(new_case)

    return [new_case, img_path, exc_, inh_, no_c_, both_c_,
            exc_1, inh_1, lenx_1, leny_1, exc_2, inh_2, lenx_2, leny_2,
            exc_3, inh_3, lenx_3, leny_3, exc_4, inh_4, lenx_4, leny_4,
            cost1, cost2, cost3, cost4]
        
def set_image_path(str_case):
    """Return the relative folder holding the pre-rendered images of a case.

    The result is ``'plots_' + os.sep + str_case + os.sep`` (note the
    trailing separator, so file names can be appended directly).
    """
    return os.sep.join(('plots_', str_case, ''))
    
def show_trace(trace, points, state):
    """Plotly ``on_click`` callback for the background scatter.

    Calls ``functions.setdefaultmarkersize(0, ...)`` on the clicked trace,
    enlarges the clicked marker to ``layout.background_markersize``, resets the
    four target scatters to ``layout.markersize`` and plots the time traces of
    the clicked point into ``fig.data[7]`` / ``fig.data[8]``.

    Parameters
    ----------
    trace : plotly trace
        The clicked trace (the background scatter).
    points : plotly Points
        Click payload; ``points.point_inds`` indexes into ``trace``.
    state : plotly InputDeviceState
        Unused.
    """
    ind = points.point_inds
    if len(ind) == 0:
        return
    clicked = ind[-1]

    functions.setdefaultmarkersize(0, trace)
    functions.setmarkersize(clicked, layout.background_markersize, trace)
    for scatter in (scatter1, scatter2, scatter3, scatter4):
        functions.setdefaultmarkersize(layout.markersize, scatter)

    # Read the coordinates from the clicked trace itself: point_inds index into
    # *this* trace. `data_background` is the object handed to the FigureWidget
    # constructor, not the figure's trace, so it may not reflect later updates
    # to the figure.
    data.plot_trace(aln, trace.x[clicked], trace.y[clicked], fig.data[7], fig.data[8])
    
def show_arrow_and_traces(trace, points, state):
    """Plotly ``on_click`` callback for the four target scatters (uids '1'-'4').

    For the clicked point it: hides the background markers and re-highlights
    only the clicked marker (2x ``layout.markersize``); drops annotations from
    earlier clicks and writes the point's cost into annotation 8; plots the
    point's time traces into ``fig.data[7]`` / ``fig.data[8]``; adds the x arrow
    for uids '1' and '3' and the y arrow for uids '2' and '3' (each with an
    extra rescale label when the returned reshape factor is not 1); and shows
    the precomputed image for the point. Clicks on a trace whose uid is not
    '1'-'4' are ignored.

    Parameters
    ----------
    trace : plotly trace
        The clicked scatter; its ``uid`` selects the data series.
    points : plotly Points
        Click payload; ``points.point_inds`` indexes into ``trace``.
    state : plotly InputDeviceState
        Unused.
    """
    ind = points.point_inds
    if len(ind) == 0:
        return

    # Look the per-target arrays up at call time: set_case() rebinds these
    # module-level names whenever the case changes.
    series_by_uid = {
        '1': (cost1, exc_1, inh_1, lenx_1, leny_1),
        '2': (cost2, exc_2, inh_2, lenx_2, leny_2),
        '3': (cost3, exc_3, inh_3, lenx_3, leny_3),
        '4': (cost4, exc_4, inh_4, lenx_4, leny_4),
    }
    if trace.uid not in series_by_uid:
        # Unknown trace: bail out before touching any widget state.
        return
    cost, exc, inh, lenx, leny = series_by_uid[trace.uid]
    clicked = ind[-1]

    functions.setdefaultmarkersize(0, scatter_background)
    for tr in fig.data[1:5]:
        functions.setdefaultmarkersize(layout.markersize, tr)
    functions.setmarkersize(clicked, 2. * layout.markersize, trace)

    # Remove arrows/labels from a previous click; keep the first 9 annotations.
    fig.layout.annotations = fig.layout.annotations[:9]
    # NOTE(review): index 8 is presumably the cost label (the 9th annotation
    # added when the figure is built) -- confirm.
    fig.layout.annotations[8].text = layout.change_cost_layout(cost[clicked], case)
    e_, i_, lx_, ly_ = exc[clicked], inh[clicked], lenx[clicked], leny[clicked]

    data.plot_trace(aln, e_, i_, fig.data[7], fig.data[8])

    if trace.uid in ('1', '3'):
        xarrow, reshapex = functions.get_x_arrow(e_, i_, lx_)
        fig.add_annotation(xarrow)
        if reshapex != 1.:
            fig.add_annotation(functions.get_x_rescale_annotation(reshapex, e_, i_, lx_))

    if trace.uid in ('2', '3'):
        yarrow, reshapey = functions.get_y_arrow(e_, i_, ly_)
        fig.add_annotation(yarrow)
        if reshapey != 1.:
            fig.add_annotation(functions.get_y_rescale_annotation(reshapey, e_, i_, ly_))

    # Image sub-folders are numbered by the trace uid.
    img = layout.get_img(img_path, int(trace.uid), clicked)
    fig.layout.images = []
    fig.add_layout_image(img)

In [5]:
# Module-level state shared with the callbacks. set_case() rebinds img_path,
# case, exc_/inh_/no_c_/both_c_, exc_k/inh_k/lenx_k/leny_k and cost1..cost4
# (k = 1..4). `global` statements are no-ops at module scope, so none are needed.

cmap = layout.getcolormap()
darkgrey, midgrey, lightgrey, color_bi_updown, color_LC, color_bi_uposc = layout.getcolors()

case = '00000'
# Same rule as set_case(): data folders always carry '1' as the third digit.
pathcase = case[0] + case[1] + '1' + case[3] + case[4]
# NOTE(review): set_case() builds img_path from the full case
# (set_image_path(case_)), while this initial value uses pathcase
# ('00100' rather than '00000') -- confirm which folder is intended.
img_path = 'plots_' + os.sep + pathcase + os.sep
readpath = '.' + os.sep + 'data' + os.sep + pathcase + os.sep

# readpath already ends with os.sep; os.path.join avoids a doubled separator.
# NOTE: pickle.load can execute arbitrary code -- only load trusted, locally
# generated files here.
with open(os.path.join(readpath, 'bi.pickle'), 'rb') as fh:
    load_array = pickle.load(fh)
ext_exc = load_array[0]
ext_inh = load_array[1]

data_array = data.read_data(aln, readpath, case)
exc_, inh_, both_c_, no_c_ = data_array[0:4]
exc_1, inh_1, lenx_1, leny_1, exc_2, inh_2, lenx_2, leny_2 = data_array[4:12]
exc_3, inh_3, lenx_3, leny_3, exc_4, inh_4, lenx_4, leny_4, cost1, cost2, cost3, cost4 = data_array[12:]

# Scatter traces: one per optimisation target plus the shared background grid.
data1, data2, data3, data4 = data.get_scatter_data(
    exc_1, inh_1, exc_2, inh_2, exc_3, inh_3, exc_4, inh_4)
data_background = data.get_data_background(
    exc_1, inh_1, exc_2, inh_2, exc_3, inh_3, exc_4, inh_4)

# Time-series panels; the click callbacks redraw fig.data[7] and fig.data[8].
trace00, trace01 = data.get_step_current_traces(aln)
trace10, trace11 = layout.get_empty_traces()

# Regime outlines, installed below as layout shapes.
bistable_regime = layout.get_bistable_paths(boundary_bi_exc, boundary_bi_inh)
oscillatory_regime = layout.get_osc_path(boundary_LC_exc, boundary_LC_inh)
LC_up_regime = layout.get_LC_up_path(boundary_LC_up_exc, boundary_LC_up_inh)

fig = go.FigureWidget([
    data_background,              # fig.data[0]
    data1, data2, data3, data4,   # fig.data[1:5]
    trace00, trace01,             # fig.data[5:7]
    trace10, trace11,             # fig.data[7:9]
])
fig.update_layout(layout.get_layout())
fig.update_layout(updatemenus=layout.get_updatemenus())
fig.update_layout(shapes=[bistable_regime, oscillatory_regime, LC_up_regime])

# Static labels, in order; the callbacks keep the first nine annotations.
for make_annotation in (
    layout.get_label_bistable,
    layout.get_label_osc,
    layout.get_label_osc_up,
    layout.get_label_down,
    layout.get_label_up,
    layout.get_info_text,
    layout.get_label_exc,
    layout.get_label_inh,
    layout.get_label_cost,
):
    fig.add_annotation(make_annotation())

fig.update_annotations()

# Handles on the background scatter and the four target scatters.
scatter_background, scatter1, scatter2, scatter3, scatter4 = (fig.data[k] for k in range(5))

# show_arrow_and_traces dispatches on these uids.
for uid, scatter in zip('1234', (scatter1, scatter2, scatter3, scatter4)):
    scatter.uid = uid

# The five case-selection update menus.
button0, button1, button2, button3, button4 = (fig.layout.updatemenus[k] for k in range(5))

for scatter in (scatter1, scatter2, scatter3, scatter4):
    scatter.on_click(show_arrow_and_traces)
scatter_background.on_click(show_trace)

# Any change of a menu's `active` entry re-selects the case.
for button in (button0, button1, button2, button3, button4):
    button.on_change(setcase, 'active')

display(fig)


plotly.graph_objs.Annotation is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.layout.Annotation
  - plotly.graph_objs.layout.scene.Annotation




FigureWidget({
    'data': [{'hoverinfo': 'x+y',
              'marker': {'color': 'rgb(100,100,100)',
       …

In [6]:
# Back-of-the-envelope size of the raw data set, in MB.
# NOTE(review): the factors (3, 17.7 + 53.4, 18, 2) are not explained here -- document their origin.
raw_data_mb = 3. * (17.7 + 53.4) * 18 * 2
print('Current raw data: ', raw_data_mb, ' MB')

Current raw data:  7678.799999999999  MB
