In [None]:
import pytc
import plotly.plotly as py
import ipywidgets as widgets

from IPython.display import display, clear_output
from IPython.html.widgets import interact, interactive
from plotly.widgets import GraphWidget
import seaborn as sns

In [None]:
class Sliders():
    """
    Widget bundle controlling one fit parameter.

    Holds a guess slider, a "fixed" checkbox with a value box that is shown
    in place of the slider while the parameter is fixed, min/max bound boxes
    and two dropdowns used to link the parameter to a global variable.
    Subclasses implement ``update_bounds`` and set the slider limits before
    calling ``build_sliders``.
    """
    def __init__(self, exp, fitter, param_name, global_list):
        """
        Parameters
        ----------
        exp : experiment object passed through to the fitter calls.
        fitter : fitter object; this class calls its update_fixed,
            update_guess, link_to_global and unlink_from_global methods.
        param_name : str
            name of the parameter these widgets control.
        global_list : list
            global variable names installed as the 'link' choice of the
            link dropdown (see create_local).
        """
        # Dropdown labels map to lists; the first element of the selected
        # list tells whether the parameter is linked ('unlink' = not linked).
        self._var_opt = {'link': ['blank'], 'unlink': ['unlink']}

        self._loc_link = widgets.Dropdown(value = self._var_opt['unlink'], options = self._var_opt)
        self._glob_link = widgets.Dropdown(value = self._loc_link.value[0], options = self._loc_link.value)
        self._fixed_check = widgets.Checkbox(value = False)
        self._slider = widgets.FloatSlider()
        # FloatText has no `display` trait; visibility is controlled through
        # its layout below (hidden until the parameter is marked fixed).
        self._fixed_int = widgets.FloatText()
        self._s_min = widgets.FloatText()
        self._s_max = widgets.FloatText()
        self._exp = exp
        self._fitter = fitter
        self._param_name = param_name
        self._global_list = global_list

        self._s_min.layout.width = '100px'
        self._s_max.layout.width = '100px'
        self._fixed_int.layout.width = '50px'
        self._fixed_int.layout.display = 'none'

    def logic(self):
        """
        Register the observers that connect the widgets to the handlers below.
        """
        self._fixed_check.observe(self.check_change, 'value')
        self._fixed_int.observe(self.fixed_change, 'value')
        self._loc_link.observe(self.link_change, 'value')
        self._s_min.observe(self.min_change, 'value')
        self._s_max.observe(self.max_change, 'value')
        self._slider.observe(self.param_change, 'value')

    # The min/max boxes set fit *bounds*; the slider range follows them.
    def min_change(self, min_val):
        """
        Observer for the lower-bound box: move the slider minimum and push
        the new bounds through update_bounds.
        """
        self._slider.min = min_val['new']
        self.update_bounds(self._slider.min, self._slider.max)

    def max_change(self, max_val):
        """
        Observer for the upper-bound box: move the slider maximum and push
        the new bounds through update_bounds.
        """
        self._slider.max = max_val['new']
        self.update_bounds(self._slider.min, self._slider.max)

    def update_bounds(self, s_min, s_max):
        """
        Hook for subclasses: forward new [s_min, s_max] bounds to the fitter.
        The base implementation does nothing.
        """
        pass

    def check_change(self, val):
        """
        Observer for the "fixed" checkbox.

        Checked: hide the slider, show the fixed-value box and tell the
        fitter to fix the parameter at the box's value. Unchecked: hide the
        fixed-value box and show the slider only if the parameter is not
        linked to a global variable.
        """
        if val['new']:
            self._slider.layout.display = 'none'
            self._fixed_int.layout.display = ''
            self._fitter.update_fixed(self._param_name, self._fixed_int.value, self._exp)
            return

        # The 'link' choice's list starts with 'blank' (or the first entry of
        # global_list), never the literal string 'link', so a linked state
        # must be detected as "anything but 'unlink'" -- the same test
        # link_change uses.
        linked = self._loc_link.value[0] != 'unlink'
        self._slider.layout.display = 'none' if linked else ''
        self._fixed_int.layout.display = 'none'
        # NOTE(review): unchecking does not tell the fitter to unfix the
        # parameter -- confirm whether update_fixed should be called here.

    def fixed_change(self, val):
        """
        Observer for the fixed-value box: while the parameter is marked
        fixed, forward the newly typed value to the fitter.
        """
        if self._fixed_check.value:
            self._fitter.update_fixed(self._param_name, val['new'], self._exp)

    def link_change(self, select):
        """
        Observer for the link dropdown: show the slider (or the fixed box,
        when fixed) for unlinked parameters and hide both when linked.
        """
        if select['new'][0] == 'unlink' and self._fixed_check.value == False:
            self._slider.layout.display = ''
            self._fixed_int.layout.display = 'none'
        elif select['new'][0] == 'unlink' and self._fixed_check.value:
            self._slider.layout.display = 'none'
            self._fixed_int.layout.display = ''
        else:
            self._slider.layout.display = 'none'
            self._fixed_int.layout.display = 'none'

    def create_global(self, glob_var):
        """
        interactive() callback for the global-variable dropdown: link this
        parameter to glob_var unless it is a placeholder ('unlink'/'blank').
        """
        if glob_var != 'unlink' and glob_var != 'blank':
            try:
                self._fitter.link_to_global(self._exp, self._param_name, glob_var)
            except Exception:
                # best effort: linking may fail before an experiment exists
                pass

    def create_local(self, loc_var):
        """
        interactive() callback for the link dropdown.

        Refreshes the 'link' choice with the current global_list. For
        ['unlink'] the parameter is unlinked from any global in the fitter;
        otherwise loc_var becomes the option list of the global dropdown.
        """
        self._var_opt['link'] = self._global_list
        self._loc_link.options = self._var_opt

        if loc_var[0] == 'unlink':
            try:
                self._fitter.unlink_from_global(self._exp, self._param_name)
            except Exception:
                # best effort: the parameter may not be linked yet
                pass
        else:
            self._glob_link.options = loc_var

    def param_change(self, param_val):
        """
        Observer for the guess slider: send the new guess to the fitter.
        """
        guess = param_val['new']

        self._fitter.update_guess(self._param_name, guess, self._exp)

    def build_sliders(self):
        """
        Register the observers and assemble the widget rows.

        Returns
        -------
        (min_max, box) : (HBox, HBox)
            the bounds row and the main parameter row.
        """
        self.logic()

        bounds_label = widgets.Label(value = "bounds: ")
        name_label = widgets.Label(value = "{}: ".format(self._param_name))
        loc_inter = widgets.interactive(self.create_local, loc_var = self._loc_link)
        glob_inter = widgets.interactive(self.create_global, glob_var = self._glob_link)

        box = widgets.HBox(children = [name_label, self._fixed_check, self._slider, loc_inter, glob_inter, self._fixed_int])
        box.layout.width = '90%'

        min_max = widgets.HBox(children = [bounds_label, self._s_min, self._s_max])

        return min_max, box
        
class LocalSliders(Sliders):
    """
    Sliders for a parameter belonging to a single experiment; ranges and
    bounds are read from ``exp.model``.
    """
    def __init__(self, exp, fitter, param_name, global_list):
        super().__init__(exp, fitter, param_name, global_list)

    def update_bounds(self, s_min, s_max):
        """
        Push [s_min, s_max] as this parameter's bounds to the fitter and, if
        the experiment's guess range now falls outside the bounds, set the
        range to the bounds as well.
        """
        bounds = [s_min, s_max]
        self._fitter.update_bounds(self._param_name, bounds, self._exp)

        # check if bounds are smaller than range, then update.
        curr_range = self._exp.model.param_guess_ranges[self._param_name]
        curr_bounds = self._exp.model.bounds[self._param_name]

        if curr_range[0] < curr_bounds[0] or curr_range[1] > curr_bounds[1]:
            self._fitter.update_range(self._param_name, bounds, self._exp)

    def build_sliders(self):
        """
        Set the slider limits from the experiment's guess range, then build
        the widgets.

        Returns the (min_max, box) pair from Sliders.build_sliders.
        """
        exp_range = self._exp.model.param_guess_ranges[self._param_name]
        lo, hi = exp_range[0], exp_range[1]

        # ipywidgets' slider validators reject min > max, so update the
        # limit that keeps the pair valid first.
        if lo > self._slider.max:
            self._slider.max = hi
            self._slider.min = lo
        else:
            self._slider.min = lo
            self._slider.max = hi

        return super().build_sliders()
        
class GlobalSliders(Sliders):
    """
    Sliders for a global parameter; ranges and bounds are read from the
    fitter's global (index 0) entries.
    """
    def __init__(self, exp, fitter, param_name, global_list):
        super().__init__(exp, fitter, param_name, global_list)

    def update_bounds(self, s_min, s_max):
        """
        Push [s_min, s_max] as this parameter's bounds to the fitter and, if
        the current global guess range now falls outside the bounds, set the
        range to the bounds as well.
        """
        bounds = [s_min, s_max]
        self._fitter.update_bounds(self._param_name, bounds, self._exp)

        # check if bounds are smaller than range, then update.
        curr_range = self._fitter.param_ranges[0][self._param_name]
        curr_bounds = self._fitter.param_bounds[0][self._param_name]

        if curr_range[0] < curr_bounds[0] or curr_range[1] > curr_bounds[1]:
            self._fitter.update_range(self._param_name, bounds, self._exp)

    def build_sliders(self):
        """
        Set the slider limits from the fitter's global guess range, then
        build the widgets.

        Returns the (min_max, box) pair from Sliders.build_sliders.
        """
        exp_range = self._fitter.param_ranges[0][self._param_name]
        lo, hi = exp_range[0], exp_range[1]

        # ipywidgets' slider validators reject min > max, so update the
        # limit that keeps the pair valid first.
        if lo > self._slider.max:
            self._slider.max = hi
            self._slider.min = lo
        else:
            self._slider.min = lo
            self._slider.max = hi

        return super().build_sliders()

In [None]:
class ParamCollect():
    """
    Base class for one row of experiment widgets and the pytc experiment
    created from it.
    """
    def __init__(self, gui, container, fitter):
        """
        Parameters
        ----------
        gui : Interface-like object with remove_experiment().
        container : list
            shared list of [ParamCollect, widget row] entries this object
            removes itself from in remove_exp.
        fitter : fitter object used by subclasses.
        """
        self._gui = gui
        self._exp_id = ''
        self._exp_val = ''
        self._widgets = []
        self._fitter = fitter
        self._container = container
        self._parameters = {}
        self._models = {"blank" : pytc.models.Blank,
          "single site" : pytc.models.SingleSite,
          "single site competitor" : pytc.models.SingleSiteCompetitor,
          "binding polynomial" : pytc.models.BindingPolynomial}
        self._current_model = ''

    def remove_exp(self, b):
        """
        Button callback: remove this experiment from the analysis, close its
        widget row and drop its entry from the shared container.

        b : the clicked Button (unused).
        """
        try:
            self._gui.remove_experiment(self._exp_id)
        except Exception:
            clear_output()
            print("no experiment linked")

        # gen_exp appends the row HBox as the 4th widget
        if len(self._widgets) > 3:
            self._widgets[3].close()

        # Remove only this object's entry. Matching on exp_id would also hit
        # every other not-yet-created row (they all have exp_id ''), and
        # breaking right away avoids mutating the list while iterating it.
        for entry in self._container:
            if entry[0] is self:
                self._container.remove(entry)
                break

    def parameters(self):
        """
        get parameters for experiment (implemented by subclasses).
        """
        pass

    def create_exp(self):
        """
        Create a new pytc.ITCExperiment from the text field (data file) and
        the model dropdown; prints a message if no data was given.
        """
        self._exp_val = self._widgets[0].value
        if self._exp_val:
            self._current_model = self._widgets[1].value
            self._exp_id = pytc.ITCExperiment(self._exp_val, self._current_model)
        else:
            clear_output()
            print("no exp data given")

    @property
    def exp_id(self):
        """
        The experiment created by create_exp, or '' before one exists.
        """
        return self._exp_id

    def gen_sliders(self):
        """
        generate sliders for each experiment, give option to link to global
        (implemented by subclasses).
        """
        pass

    def gen_exp(self):
        """
        generate widgets for experiment (implemented by subclasses).
        """
        pass
    
class Experiments(ParamCollect):
    """
    create experiment object and generate widgets
    """
    def __init__(self, gui, container, fitter):
        super().__init__(gui, container, fitter)
        self._sliders = []

    @property
    def parameters(self):
        """
        Local parameter names, guesses, ranges and fixed values as reported
        by the fitter (presumably one entry per experiment -- confirm).
        """
        global_param, local_param = self._fitter.param_names
        global_guesses, local_guesses = self._fitter.param_guesses
        global_ranges, local_ranges = self._fitter.param_ranges
        global_fixed, local_fixed = self._fitter.fixed_param

        self._parameters = {"name": local_param,
                            "value": local_guesses,
                            "ranges": local_ranges,
                            "fixed": local_fixed}

        return self._parameters

    def gen_sliders(self, global_list=None):
        """
        Create one LocalSliders per parameter of this experiment.

        Parameters
        ----------
        global_list : list, optional
            global variable names offered by each slider's link dropdown;
            defaults to the notebook-level ``global_var`` list.

        Returns
        -------
        list of LocalSliders
        """
        if global_list is None:
            global_list = global_var

        parameters = self.exp_id.param_values

        # Rebuild rather than append so repeated "Analyze" clicks do not
        # accumulate duplicate slider sets.
        self._sliders = [LocalSliders(self.exp_id, self._fitter, p, global_list)
                         for p in parameters]

        return self._sliders

    def gen_exp(self):
        """
        Build the experiment row: data-file text field, model dropdown and a
        "remove experiment" button. Returns the row HBox.
        """
        exp_field = widgets.Text(description = "exp: ")
        model_drop = widgets.Dropdown(options = self._models, value = self._models["blank"])

        rm_exp = widgets.Button(description = "remove experiment")
        rm_exp.on_click(self.remove_exp)

        exp_box = widgets.HBox(children = [exp_field, model_drop, rm_exp])
        self._widgets.extend([exp_field, model_drop, rm_exp, exp_box])

        return exp_box
    
class GlobalExp(ParamCollect):
    """
    Widget row and sliders for one global variable.
    """
    def __init__(self, gui, container, fitter, v_name):
        """
        v_name : str
            name of the global variable this object represents.
        """
        super().__init__(gui, container, fitter)
        self._v_name = v_name

    def parameters(self):
        """
        Global parameter guesses, ranges and fixed values as reported by the
        fitter, labelled with this variable's name.
        """
        global_param, local_param = self._fitter.param_names
        global_guesses, local_guesses = self._fitter.param_guesses
        global_ranges, local_ranges = self._fitter.param_ranges
        global_fixed, local_fixed = self._fitter.fixed_param

        self._parameters = {"name": self._v_name,
                            "value": global_guesses,
                            "ranges": global_ranges,
                            "fixed": global_fixed}

        return self._parameters

    def gen_sliders(self):
        """
        Create the sliders for this global variable.

        Returns
        -------
        list of GlobalSliders
            a one-element list, the same shape Experiments.gen_sliders
            returns, so callers can call build_sliders() on each element.
        """
        # A global variable is not tied to one experiment, so exp is None.
        # NOTE(review): the link dropdown only offers this variable itself;
        # confirm the intended option list.
        return [GlobalSliders(None, self._fitter, self._v_name, [self._v_name])]

    def gen_exp(self):
        """
        Build the row: text field, model dropdown and a "remove experiment"
        button. Returns the row HBox.
        """
        exp_field = widgets.Text(description = "exp: ")
        model_drop = widgets.Dropdown(options = self._models, value = self._models["blank"])

        rm_exp = widgets.Button(description = "remove experiment")
        rm_exp.on_click(self.remove_exp)

        exp_box = widgets.HBox(children = [exp_field, model_drop, rm_exp])
        self._widgets.extend([exp_field, model_drop, rm_exp, exp_box])

        return exp_box

In [None]:
%matplotlib inline

class Interface:
    """
    Glue between the notebook widgets and a fitter object (the notebook
    passes a ``pytc.GlobalFit``).

    Keeps the registered experiments and the slider/result widgets created
    by ``build_interface`` so that ``reset_sliders`` can close them.
    """

    def __init__(self, fitter):
        """
        fitter : object providing add_experiment, remove_experiment, fit,
            plot, update_guess and the param_names / param_guesses /
            param_ranges / fit_param / fit_error attributes used below.
        """
        self._global_sliders = {}   # global parameter name -> FloatSlider
        self._local_sliders = []    # one {param name: FloatSlider} dict per experiment
        self._fitter = fitter
        self._experiments = []      # experiments registered via add_experiment
        self._param = []            # FloatText widgets holding local fit values

    def view_exp(self):
        """Return the list of registered experiments."""
        return self._experiments

    def fit(self):
        """Run the fitter."""
        self._fitter.fit()

    def name(self):
        """Return the fitter's (global, local) parameter names."""
        return self._fitter.param_names

    def add_experiment(self, expt):
        """
        Register expt with the fitter and with this interface, then call its
        initialize_param().
        """
        self._fitter.add_experiment(expt)
        self._experiments.append(expt)
        expt.initialize_param()

    def remove_experiment(self, expt):
        """Remove expt from the fitter and from this interface."""
        self._fitter.remove_experiment(expt)
        self._experiments.remove(expt)

    def reset_sliders(self):
        """
        Close every slider and result widget created by build_interface and
        forget them.
        """
        for slider in self._global_sliders.values():
            slider.close()
        self._global_sliders.clear()

        for slider_dict in self._local_sliders:
            for slider in slider_dict.values():
                slider.close()
        self._local_sliders.clear()

        for p in self._param:
            p.close()
        self._param = []

    def build(self):
        """
        Ask every registered experiment for its sliders and build them.

        NOTE(review): each element of self._experiments must provide
        gen_sliders(); the notebook's gen_exp registers exp.exp_id (the
        pytc.ITCExperiment made in create_exp) -- confirm those objects
        provide it. The widgets returned by build_sliders() are discarded.
        """
        for e in self._experiments:
            for s in e.gen_sliders():
                s.build_sliders()

    def build_interface(self):
        """
        Create a FloatSlider per global parameter and per local parameter of
        each experiment, display them through widgets.interactive bound to
        _update, and create FloatText widgets with the local fit values.
        """
        global_param, local_param = self._fitter.param_names
        global_guesses, local_guesses = self._fitter.param_guesses
        global_ranges, local_ranges = self._fitter.param_ranges

        all_widgets = {}

        for p in global_param:
            g_min, g_max = global_ranges[p][0], global_ranges[p][1]
            self._global_sliders[p] = widgets.FloatSlider(min=g_min, max=g_max, value=global_guesses[p])
            all_widgets["{}".format(p)] = self._global_sliders[p]

        for i in range(len(self._experiments)):
            self._local_sliders.append({})

            for p in local_param[i]:
                l_min, l_max = local_ranges[i][p][0], local_ranges[i][p][1]
                self._local_sliders[-1][p] = widgets.FloatSlider(min=l_min, max=l_max, value=local_guesses[i][p])
                # "name,index" keys are decoded again in _update
                all_widgets["{},{}".format(p, i)] = self._local_sliders[-1][p]

        # Built once (not once per experiment) and under a distinct name so
        # local_param above is not shadowed between loop iterations.
        if self._experiments:
            _, local_fit = self._fitter.fit_param
            for fit_dict in local_fit:
                for k, v in fit_dict.items():
                    self._param.append(widgets.FloatText(value = v, description = "{}: ".format(k)))

        w = widgets.interactive(self._update, **all_widgets)

        display(w)

    def _update(self, **kwargs):
        """
        Callback for the interactive sliders: push every value to the fitter
        as a guess, refit, print fitted values with errors, then plot.

        Keys without a comma are global parameter names; "name,i" keys are
        local parameter ``name`` of ``self._experiments[i]``.
        """
        for key, value in kwargs.items():
            parts = key.split(",")
            if len(parts) == 1:
                self._fitter.update_guess(key, value)
            else:
                expt = self._experiments[int(parts[1])]
                self._fitter.update_guess(parts[0], value, expt)

        # Fit first so the printed values belong to the current guesses.
        self._fitter.fit()

        global_param, local_param = self._fitter.fit_param
        global_error, local_error = self._fitter.fit_error

        for param, error in zip(local_param, local_error):
            for p in param:
                print(p, ': ', param[p], error[p])

        # global results are a single name -> value mapping
        for p in global_param:
            print(p, ': ', global_param[p], global_error[p])

        self._fitter.plot()

    def get_param(self):
        """Return (fit_param, fit_error) from the fitter."""
        return self._fitter.fit_param, self._fitter.fit_error

In [None]:
# Notebook-level state shared by the widget callbacks below.
global_var = ['blank']   # global variable names; 'blank' is a placeholder Sliders.create_global skips
exp_w = []               # [Experiments, widget row] pairs, one per experiment row
f = pytc.GlobalFit()
gui = Interface(f)

def rm_last(b):
    """
    "Remove Last Experiment" button callback: close the most recently added
    experiment row, remove its experiment from the fit and drop the row.

    b : the clicked Button (unused).
    """
    if not exp_w:
        return

    last_exp = exp_w[-1]
    last_exp[1].close()
    try:
        gui.remove_experiment(last_exp[0].exp_id)
    except Exception:
        # best effort: e.g. no experiment was created for this row yet
        pass

    exp_w.pop()

def gen_exp(b):
    """
    "Analyze" button callback: create a pytc experiment for every row that
    does not have one yet, register it with the GUI/fitter, then build the
    sliders.

    b : the clicked Button (unused).
    """
    gui.reset_sliders()
    clear_output()

    for row in exp_w:
        try:
            exp = row[0]
            if exp.exp_id:
                continue  # already created on an earlier click

            exp.create_exp()
            gui.add_experiment(exp.exp_id)
        except Exception:
            print("no data added.")

    gui.build()

def clear_exp(b):
    """
    "Clear" button callback: close every experiment row, remove every
    experiment from the fit, then reset the sliders and the output area.

    b : the clicked Button (unused).
    """
    # Iterate over a snapshot: removing from exp_w while looping over it
    # skipped every other row.
    for row in list(exp_w):
        try:
            row[1].close()
            gui.remove_experiment(row[0].exp_id)
        except Exception:
            # best effort: the row may not have an experiment yet
            pass
    exp_w.clear()

    gui.reset_sliders()
    clear_output()

    
def gen_exp_list(exp_list):
    """
    Fill exp_list with 'exp <n>' -> experiment id for every row of exp_w.

    Mutates exp_list in place and returns it.
    """
    for idx, row in enumerate(exp_w):
        exp_list['exp {}'.format(idx)] = row[0].exp_id
    return exp_list


def add_field(b):
    """
    "Add an Experiment" button callback: append and display a new, empty
    experiment row.

    b : the clicked Button (unused).
    """
    clear_output()
    gui.reset_sliders()

    new_exp = Experiments(gui, exp_w, f)
    row_widget = new_exp.gen_exp()
    exp_w.append([new_exp, row_widget])

    display(row_widget)
    
def create_global(b):
    """
    "Add Global Variable" button callback: register the name typed in
    global_field if it is non-empty and new, then clear the field.

    b : the clicked Button (unused).
    """
    name = global_field.value
    if not name or name in global_var:
        return

    global_var.append(name)
    global_field.value = ''

def remove_global(b):
    """
    "Remove Global Variable" button callback: remove the global variable
    named in global_field from the fitter and from global_var.

    Unknown names and the 'blank' placeholder are reported rather than
    raising ValueError inside the widget callback.

    b : the clicked Button (unused).
    """
    name = global_field.value
    if name == 'blank' or name not in global_var:
        print("no global variable named '{}'".format(name))
        return

    f.remove_global(name)
    global_var.remove(name)
        
ENTRY_W = '200px'  # width of the control buttons


def _make_button(description, callback, width=None):
    """
    Create a Button labelled `description` that calls `callback` when
    clicked, optionally setting its layout width.
    """
    button = widgets.Button(description = description)
    if width is not None:
        button.layout.width = width
    button.on_click(callback)
    return button


global_field = widgets.Text(description = "Global :")

global_add = _make_button("Add Global Variable", create_global, ENTRY_W)
global_remove = _make_button("Remove Global Variable", remove_global, ENTRY_W)
add_exp_field_b = _make_button("Add an Experiment", add_field, ENTRY_W)
rmv_last_field = _make_button("Remove Last Experiment", rm_last, ENTRY_W)

# the first experiment row, shown when the notebook starts
exp_object = Experiments(gui, exp_w, f)
show = exp_object.gen_exp()

exp_w.append([exp_object, show])

analyze_widget = _make_button("Analyze", gen_exp)
# Button has no `value` trait, so none is passed.
clear_widget = _make_button("Clear", clear_exp)

experiments_layout = widgets.Layout(display = "flex",
                                    flex_flow = "row",
                                    align_items = "stretch")

glob_box = widgets.Box(children = [global_field, global_add, global_remove],
                       layout = experiments_layout)

experiments = widgets.Box(children = [add_exp_field_b, rmv_last_field],
                          layout = experiments_layout)
parent = widgets.Box(children = [analyze_widget, clear_widget, experiments, glob_box, show])

display(parent)


# TODO: inspect.argspec for experiment class and widgets generated