In [None]:
import re

import ipywxyz, tpot, ipywidgets, pandas, traitlets, dask.distributed, multiprocessing, sklearn
import numpy
import yellowbrick.features
from lime.lime_tabular import LimeTabularExplainer
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

In [None]:
# Render matplotlib figures inline; the yellowbrick correlation plot drawn inside the
# `correlation` Output widget relies on this.
%matplotlib inline

In [None]:
# Re-running this cell would otherwise start another local cluster while the previous
# one keeps running; closing a Client that started its own cluster shuts that cluster down.
if "client" in globals():
    client.close()

# One single-threaded worker per CPU core. The trainer cell passes use_dask=True to
# TPOT, presumably so the search runs on this client.
client = dask.distributed.Client(n_workers=multiprocessing.cpu_count(), threads_per_worker=1)
client

In [None]:
def _dashboard_root(dask_client):
    """Return the dask dashboard base URL without the trailing ``/status``, or "" if unknown.

    Prefers the public ``Client.dashboard_link``. It falls back to scraping the client's
    HTML repr only when that attribute is missing or empty.
    """
    link = getattr(dask_client, "dashboard_link", "") or ""
    if link:
        return link.rsplit("/status", 1)[0]
    matches = re.findall(r'http.*?(?=/status)', dask_client._repr_html_())
    return matches[0] if matches else ""


dashboard_url = _dashboard_root(client)
if dashboard_url:
    # Embed the dashboard's task-stream page in the dock.
    iframe_client = ipywidgets.HTML(f"""
<iframe src="{dashboard_url}/tasks" width="100%" height="300" border="0" style="border: 0"></iframe>
""")
else:
    # Previously this raised IndexError; show a placeholder so the dashboard still builds.
    iframe_client = ipywidgets.HTML("<p>Dask dashboard unavailable.</p>")

In [None]:
# Module-level state shared by the widget callbacks below. The loader, visualize and
# trainer cells rebind these names via `global`.
# `pandas.np` was deprecated in pandas 1.0 and removed in 2.0. Because module-level
# annotations are evaluated, `pandas.np.ndarray` raised AttributeError, so use numpy directly.
dataset = None  # return value of the selected sklearn.datasets.load_* function
df: pandas.DataFrame = None
train: numpy.ndarray = None
test: numpy.ndarray = None
labels_train: numpy.ndarray = None
labels_test: numpy.ndarray = None
# trainer() assigns a tpot.TPOTClassifier (not a sklearn Pipeline). The annotation is
# quoted so it is not evaluated when this cell runs.
model: "tpot.TPOTClassifier" = None
explainer = None  # LimeTabularExplainer built by the loader cell
visualizer = None  # yellowbrick Rank2D built by the visualize cell

In [None]:
# Data grids; the loader cell fills these with the current train/test split.
grid_train = ipywxyz.SelectGrid(description="Training Data")
grid_test = ipywxyz.SelectGrid(description="Test Data")
grid_labels_train = ipywxyz.SelectGrid(description="Training Labels")
grid_labels_test = ipywxyz.SelectGrid(description="Test Labels")

# Receives the feature-correlation plot drawn by the visualize cell.
correlation = ipywidgets.Output()

# Output has no `description` trait. Passing it to the constructor is not stored
# (traitlets only warns about the unrecognized kwarg), so attach it as a plain attribute.
train_output = ipywidgets.Output()
train_output.description = "Training Output"

# TPOT hyper-parameters, read by the trainer cell when the train button is clicked.
generations = ipywidgets.IntSlider(2, min=1, max=1000, description="generations")
population_size = ipywidgets.IntSlider(7, min=1, max=1000, description="population size")
# scikit-learn cross-validation needs at least 2 folds.
cv = ipywidgets.IntSlider(2, min=2, max=10, description="cross validation")
# numpy only accepts non-negative seeds, so negative values would make fitting fail.
random_state = ipywidgets.IntSlider(42, min=0, max=1_000_000, description="random state")
# TPOT defines verbosity levels 0 (silent) to 3 (everything).
verbosity = ipywidgets.IntSlider(0, min=0, max=3, description="verbosity")

In [None]:
with ipywidgets.Output():
    @ipywidgets.interact
    def loader(dataset_name=ipywidgets.SelectionSlider(options=["iris", "breast_cancer", "wine"], description="Dataset")):
        """Load the chosen scikit-learn dataset and publish a fresh train/test split.

        Rebinds the module-level ``dataset``, ``df``, ``train``, ``test``,
        ``labels_train``, ``labels_test`` and ``explainer`` globals, and pushes the
        split into the four data grids.

        Only classification datasets are offered. The trainer fits a
        ``TPOTClassifier``, which cannot learn the continuous targets of
        ``boston``/``diabetes``, and ``load_boston`` was removed in scikit-learn 1.2.
        """
        global dataset, df, train, test, labels_train, labels_test, explainer
        dataset = getattr(sklearn.datasets, f"load_{dataset_name}")()
        df = pandas.DataFrame(dataset.data, columns=dataset.feature_names)
        (
            train,
            test,
            labels_train,
            labels_test
        ) = sklearn.model_selection.train_test_split(
            dataset.data,
            dataset.target,
            train_size=0.80,
            # Seed the split from the same slider as TPOT so runs are reproducible.
            random_state=random_state.value,
        )
        grid_train.value = pandas.DataFrame(train)
        grid_test.value = pandas.DataFrame(test)
        grid_labels_train.value = pandas.DataFrame(labels_train)
        grid_labels_test.value = pandas.DataFrame(labels_test)
        # LIME samples perturbations from the training data's statistics.
        explainer = LimeTabularExplainer(
            train,
            feature_names=dataset.feature_names,
            class_names=getattr(dataset, "target_names", None),
            discretize_continuous=True
        )

In [None]:
with ipywidgets.Output():
    @ipywidgets.interact
    def visualize(algorithm=ipywidgets.SelectionSlider(options=['pearson', 'covariance', 'spearman'])):
        """Redraw the yellowbrick Rank2D plot of the training features into ``correlation``.

        Rebinds the global ``visualizer``. It uses the ``train``/``labels_train`` split that
        existed when the slider last changed, so the plot is not redrawn on a dataset change.
        """
        global visualizer
        visualizer = yellowbrick.features.Rank2D(features=dataset.feature_names, algorithm=algorithm)

        correlation.clear_output()
        with correlation:
            visualizer.fit(train, labels_train)
            # NOTE(review): the transform() result is discarded; presumably kept for older
            # yellowbrick releases. Confirm before removing.
            visualizer.transform(train)
            # yellowbrick 1.0 renamed poof() to show(); prefer show() and fall back on
            # releases that predate it.
            render = getattr(visualizer, "show", None) or visualizer.poof
            render()

In [None]:
def trainer(start_training=ipywidgets.ToggleButton):
    """Fit a TPOT classifier on the current split and print its test-set score.

    Registered with ``btn_train.on_click``, which passes the clicked button as the only
    argument. That argument is not used.

    Rebinds the global ``model``. Fit output and the score are written into
    ``train_output``.
    """
    global model
    model = tpot.TPOTClassifier(
        generations=generations.value,
        population_size=population_size.value,
        cv=cv.value,
        n_jobs=-1,
        random_state=random_state.value,
        # Honour the verbosity slider (previously hardcoded to 0), clamped to TPOT's 0-3.
        verbosity=min(max(verbosity.value, 0), 3),
        use_dask=True
    )
    with train_output:
        model.fit(train, labels_train)
        print(model.score(test, labels_test))

btn_train = ipywidgets.Button(description="🚂")
btn_train.on_click(trainer)
box_train = ipywidgets.VBox([
    generations,
    population_size,
    cv,
    random_state,
    verbosity,
    btn_train
])
# VBox has no `description` trait. Passing it to the constructor is not stored
# (traitlets only warns about the unrecognized kwarg), so attach it as a plain attribute.
box_train.description = "Parameters"

In [None]:
with ipywidgets.Output():
    @ipywidgets.interact
    def explain(instance=(0, len(test) - 1)):
        """Show a LIME explanation of the model's prediction for row ``instance`` of ``test``.

        The slider range is fixed from the test set that existed when this cell ran, so
        the index is re-checked against the current ``test``. Any error raised while
        explaining is printed with its type rather than raised, so the pane stays usable.
        """
        if model is None:
            print("No trained model yet — click the 🚂 button to train one.")
            return
        if not 0 <= instance < len(test):
            print(f"Row {instance} is outside the current test set (0-{len(test) - 1}).")
            return
        try:
            exp = explainer.explain_instance(test[instance], model.predict_proba)
            exp.show_in_notebook()
        except Exception as err:
            # Best-effort UI: report in this pane. The type name keeps empty-message errors readable.
            print(f"{type(err).__name__}: {err}")

In [None]:
# Scrolling the test grid drives the explanation slider (the interact's first child)
# from element 2 of the grid's viewport. Each label grid mirrors its data grid's scroll_y.
explain_slider = explain.widget.children[0]
traitlets.dlink((grid_test, "viewport"), (explain_slider, "value"), lambda viewport: viewport[2])
for data_grid, label_grid in [(grid_train, grid_labels_train), (grid_test, grid_labels_test)]:
    traitlets.link((data_grid, "scroll_y"), (label_grid, "scroll_y"))

In [None]:
# Every pane of the dashboard; DockBox.dock_layout addresses these by their child index.
pg = ipywxyz.DockBox([
    loader.widget,
    visualize.widget,
    correlation,
    grid_train,
    grid_labels_train,
    grid_test,
    grid_labels_test,
    box_train,
    train_output,
    iframe_client,
    explain.widget,
], layout=dict(height="100vh"))
# Output has no `description` trait, so the label is attached as a plain attribute.
train_output.description = "Training Output"


def _tab(widget):
    """Return a single-widget tab area for ``widget``, addressed by its index in ``pg.children``."""
    for index, child in enumerate(pg.children):
        if child is widget:
            return {'type': 'tab-area', 'widgets': [index], 'currentIndex': 0}
    raise ValueError(f"{widget!r} is not a child of the DockBox")


def _column(widgets, sizes):
    """Return a vertical split area holding one tab area per widget, top to bottom."""
    return {'type': 'split-area',
            'orientation': 'vertical',
            'children': [_tab(widget) for widget in widgets],
            'sizes': list(sizes)}


def layout(*_):
    """Apply the dock layout. Registered via ``pg.on_displayed``; its arguments are ignored.

    Panes are looked up by identity rather than hard-coded child indices, so reordering
    the DockBox children no longer silently scrambles the layout.
    """
    pg.dock_layout = {'type': 'split-area',
                      'orientation': 'horizontal',
                      'children': [
                          _column([grid_test, grid_train], [0.5, 0.5]),
                          _column([grid_labels_test, grid_labels_train], [0.5, 0.5]),
                          _column([loader.widget, correlation, visualize.widget],
                                  [0.3030839281152105, 0.41608264801846156, 0.2808334238663279]),
                          _column([iframe_client, box_train, train_output],
                                  [0.3703284415172773, 0.31483577924136136, 0.31483577924136136]),
                          _tab(explain.widget),
                      ],
                      'sizes': [0.08564685481230452,
                                0.09288380506162426,
                                0.25631036671974694,
                                0.30134697916301306,
                                0.26381199424331125]}


pg.on_displayed(layout)
pg.hide_tabs = True
pg