In [7]:
import ipywxyz, tpot, ipywidgets, pandas, traitlets, dask.distributed, multiprocessing, sklearn
from lime.lime_tabular import LimeTabularExplainer 
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import re

In [8]:
# Start a local dask cluster with one single-threaded worker per CPU core,
# unless this kernel already has a default client (e.g. the cell is being
# re-run) — in that case reuse it instead of spawning a second cluster.
try:
    client = dask.distributed.default_client()
except ValueError:  # raised when no client has been created in this kernel yet
    client = dask.distributed.Client(
        n_workers=multiprocessing.cpu_count(),
        threads_per_worker=1,
    )
client

Port 8787 is already in use. 
Perhaps you already have a cluster running?
Hosting the diagnostics dashboard on a random port instead.


0,1
Client  Scheduler: tcp://127.0.0.1:42333  Dashboard: http://127.0.0.1:43761/status,Cluster  Workers: 8  Cores: 8  Memory: 16.72 GB


In [24]:
# Base URL of the dask dashboard, taken from the public `dashboard_link`
# property rather than scraped out of the client's private HTML repr.
# The trailing "/status" page is stripped so other dashboard pages can be
# addressed below.
dashboard_url = re.sub(r"/status/?$", "", client.dashboard_link)
# Embed the dashboard's task-stream page so worker activity is visible
# while TPOT trains.
iframe_client = ipywidgets.HTML(f"""
<iframe src="{dashboard_url}/tasks" width="100%" height="300" border="0" style="border: 0"></iframe>
""")
iframe_client

HTML(value='\n<iframe src="http://127.0.0.1:43761/tasks" width="100%" height="300" border="0" style="border: 0…

In [28]:
iris = load_iris()
# The iris feature matrix as a DataFrame, one named column per measurement.
df = pandas.DataFrame(data=iris.data, columns=iris.feature_names)

In [29]:
# Hold out 20% of the samples for testing. A fixed random_state makes the
# split — and therefore every downstream score and explanation —
# reproducible on Restart & Run All.
train, test, labels_train, labels_test = train_test_split(
    iris.data,
    iris.target,
    train_size=0.80,
    random_state=42,
)

In [30]:
def _data_grid(values, description):
    """Wrap `values` in a DataFrame and show it in a titled SelectGrid."""
    frame = pandas.DataFrame(values)
    return ipywxyz.SelectGrid(value=frame, description=description)


grid_train = _data_grid(train, "Training Data")
grid_test = _data_grid(test, "Test Data")
grid_labels_train = _data_grid(labels_train, "Training Labels")
grid_labels_test = _data_grid(labels_test, "Test Labels")

In [31]:
# Keep each feature grid and its label grid scrolled together, so a row of
# features always lines up with its label.
traitlets.link(
    source=(grid_train, "scroll_y"),
    target=(grid_labels_train, "scroll_y"),
)
traitlets.link(
    source=(grid_test, "scroll_y"),
    target=(grid_labels_test, "scroll_y"),
)

<traitlets.traitlets.link at 0x7f88c70be278>

In [32]:
# Placeholder for the fitted TPOT model: `trainer` rebinds this global on each
# button click and `explain` reads it to get `predict_proba`.
model = None

In [33]:
train_output = ipywidgets.Output(description="Training Output")
generations = ipywidgets.IntSlider(2, min=1, max=1000, description="generations")
population_size = ipywidgets.IntSlider(7, min=1, max=1000, description="population size")
# k-fold cross-validation needs at least two folds, so 1 is not offered.
cv = ipywidgets.IntSlider(2, min=2, max=10, description="cross validation")
# numpy seeds must be non-negative integers, so the range starts at 0.
random_state = ipywidgets.IntSlider(42, min=0, max=1_000_000, description="random state")
# TPOT accepts verbosity 0 (silent) to 3; the default of 0 matches the value
# that was previously hard-coded in `trainer`.
verbosity = ipywidgets.IntSlider(0, min=0, max=3, description="verbosity")


def trainer(start_training=ipywidgets.ToggleButton):
    """Fit a TPOT classifier using the current slider values and print its test score.

    Registered as the click handler of ``btn_train``; ``start_training`` receives
    the clicked button and is otherwise unused. The fitted estimator is stored in
    the global ``model`` so the LIME explainer can use it. Fit output and the test
    accuracy are written into ``train_output``.
    """
    global model
    model = tpot.TPOTClassifier(
        generations=generations.value,
        population_size=population_size.value,
        cv=cv.value,
        n_jobs=-1,
        random_state=random_state.value,
        verbosity=verbosity.value,  # previously hard-coded to 0, ignoring the slider
        use_dask=True,
    )
    with train_output:
        model.fit(train, labels_train)
        print(model.score(test, labels_test))


btn_train = ipywidgets.Button(description="🚂")
btn_train.on_click(trainer)
box_train = ipywidgets.VBox([
    generations,
    population_size,
    cv,
    random_state,
    verbosity,
    btn_train
], description="Parameters")
box_train

VBox(children=(IntSlider(value=2, description='generations', max=1000, min=1), IntSlider(value=7, description=…

In [34]:
explainer = LimeTabularExplainer(
    train,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    discretize_continuous=True
)

# `interact` displays its controls as soon as it is applied. Running it inside
# an Output widget that is never shown keeps the controls from rendering here,
# so that `explain.widget` can be placed in the dock layout instead.
with ipywidgets.Output():
    @ipywidgets.interact
    def explain(instance=(0, len(test) - 1)):
        """Show a LIME explanation of the trained model's prediction for test row `instance`."""
        if model is None:
            # Nothing to explain until `trainer` has produced a model.
            print("No model trained yet: click the train button first.")
            return
        try:
            exp = explainer.explain_instance(test[instance], model.predict_proba)
            exp.show_in_notebook()
        except Exception as err:
            # Report the failure in the interact output instead of raising.
            print(err)

In [41]:
def _viewport_to_instance(viewport):
    # NOTE(review): assumes element 2 of SelectGrid.viewport is the test-grid
    # row to explain — confirm against ipywxyz.
    return viewport[2]


# Drive the `instance` slider of `explain` (its only interact control) from
# the test grid's viewport.
traitlets.dlink(
    source=(grid_test, "viewport"),
    target=(explain.widget.children[0], "value"),
    transform=_viewport_to_instance,
)

<traitlets.traitlets.directional_link at 0x7f88c532c5c0>

In [38]:
# Child order matters: dock_layout below refers to widgets by these indices.
pg = ipywxyz.DockBox([
    grid_train,         # 0
    grid_labels_train,  # 1
    grid_test,          # 2
    grid_labels_test,   # 3
    box_train,          # 4
    train_output,       # 5
    iframe_client,      # 6
    explain.widget,     # 7
], layout=dict(height="100vh"))
train_output.description = "Training Output"


def _stacked_tabs(top, bottom, sizes=(1, 1)):
    """Build a vertical split-area holding one tab-area above another, by child index."""
    return {
        'type': 'split-area',
        'orientation': 'vertical',
        'children': [
            {'type': 'tab-area', 'widgets': [top], 'currentIndex': 0},
            {'type': 'tab-area', 'widgets': [bottom], 'currentIndex': 0},
        ],
        'sizes': list(sizes),
    }


def layout(*_):
    """Arrange the dock as four side-by-side columns; called once the DockBox is displayed."""
    columns = [
        _stacked_tabs(2, 0),                # test data over training data
        _stacked_tabs(3, 1),                # test labels over training labels
        _stacked_tabs(6, 4),                # dask dashboard over parameters
        _stacked_tabs(5, 7, sizes=(2, 8)),  # training output over LIME explanation
    ]
    pg.dock_layout = {
        'type': 'split-area',
        'orientation': 'horizontal',
        'children': columns,
        'sizes': [2, 1, 2, 2],
    }


pg.on_displayed(layout)
pg.hide_tabs = True
pg

DockBox(children=(SelectGrid(description='Training Data', hover_row=14, selection=(0, 0, 0, 0)), SelectGrid(de…

In [36]:
# Inspect the layout as currently held by the widget (sizes come back as
# fractions) — presumably for copying back into `layout` after panes have
# been resized by hand.
pg.dock_layout

{'type': 'split-area',
 'orientation': 'horizontal',
 'children': [{'type': 'split-area',
   'orientation': 'vertical',
   'children': [{'type': 'tab-area', 'widgets': [2], 'currentIndex': 0},
    {'type': 'tab-area', 'widgets': [0], 'currentIndex': 0}],
   'sizes': [0.5, 0.5]},
  {'type': 'split-area',
   'orientation': 'vertical',
   'children': [{'type': 'tab-area', 'widgets': [3], 'currentIndex': 0},
    {'type': 'tab-area', 'widgets': [1], 'currentIndex': 0}],
   'sizes': [0.5, 0.5]},
  {'type': 'split-area',
   'orientation': 'vertical',
   'children': [{'type': 'tab-area', 'widgets': [6], 'currentIndex': 0},
    {'type': 'tab-area', 'widgets': [4], 'currentIndex': 0}],
   'sizes': [0.5, 0.5]},
  {'type': 'split-area',
   'orientation': 'vertical',
   'children': [{'type': 'tab-area', 'widgets': [5], 'currentIndex': 0},
    {'type': 'tab-area', 'widgets': [7], 'currentIndex': 0}],
   'sizes': [0.14538806039154445, 0.8546119396084556]}],
 'sizes': [0.2718889439293394,
  0.12519024