In [None]:
%config InlineBackend.figure_formats = ['svg']
%matplotlib inline
from wxyz.datagrid.widget_selectgrid import SelectGrid
from wxyz.lab.widget_dock import DockBox
import re, yellowbrick.features, tpot, ipywidgets as W, numpy as np, pandas as pd, traitlets as T, dask.distributed, multiprocessing, sklearn, warnings
from lime.lime_tabular import LimeTabularExplainer 
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

In [None]:
%%html
<style>
/* Let sliders, progress bars and inline widget labels take their natural
   width (width: auto) rather than a fixed one. */
.widget-hslider, .widget-hprogress, .widget-inline-hbox .widget-label{
    width: auto;
}
/* Keep the SVG figures (the inline figure format configured in the first
   cell) from overflowing the panel that contains them. */
.jp-RenderedSVG img {
    max-width: 100%;
    max-height: 100%;
}
</style>

In [None]:
def _optional(klass):
    """Trait holding an instance of ``klass``, or ``None`` (the default)."""
    return T.Instance(klass, allow_none=True)


class App(T.HasTraits):
    """Observable state shared by all panels of the AutoML dashboard."""

    # whatever the chosen sklearn ``load_*`` function returned
    dataset = T.Any()
    # Dask client, set once the cluster cell has run
    dask = _optional(dask.distributed.Client)
    # the dataset's feature matrix as a DataFrame
    df = _optional(pd.DataFrame)
    # feature rows of the train / test split
    train = _optional(np.ndarray)
    test = _optional(np.ndarray)
    # targets matching ``train`` / ``test``
    labels_train = _optional(np.ndarray)
    labels_test = _optional(np.ndarray)
    # TPOT classifier created by the most recent Train click
    model = _optional(tpot.TPOTClassifier)
    # LIME explainer built from the training rows
    explainer = _optional(LimeTabularExplainer)


app = App()

In [None]:
# Start (or restart) a local Dask cluster: one single-threaded worker per CPU.
if app.dask is not None:
    # Re-running this cell would otherwise leave the previous cluster's
    # worker processes alive; closing a Client() that started its own
    # LocalCluster also shuts that cluster down.
    app.dask.close()
app.dask = dask.distributed.Client(n_workers=multiprocessing.cpu_count(), threads_per_worker=1)

# Use the public dashboard link instead of scraping the HTML repr. The link
# ends in "/status" by default; strip it so other dashboard pages can be addressed.
dashboard_url = re.sub(r"/status/?$", "", app.dask.dashboard_link)
iframe_client = W.HTML(f"""
<iframe src="{dashboard_url}/tasks" width="100%" height="300" border="0" style="border: 0"></iframe>
""")

In [None]:
# Grids showing the current train/test split (filled in by `loader`).
grid_train = SelectGrid(description="Training Data")
grid_test = SelectGrid(description="Test Data")
grid_labels_train = SelectGrid(description="Training Labels")
grid_labels_test = SelectGrid(description="Test Labels")

# NOTE(review): not referenced by any of the visible cells.
correlation = W.Output()
# Read-only score display, driven by the Train button handler.
fitness = W.FloatSlider(0, min=0, max=1, description="🎯 Score", disabled=True)

train_output = W.Output(description="Training Output")

# TPOT hyper-parameters, read by the Train button handler.
generations = W.IntSlider(10, min=1, max=100, description="👴 Generations", layout=dict(flex="1"))
population_size = W.IntSlider(7, min=1, max=100, description="👶 Population", layout=dict(flex="1"))
# k-fold cross-validation needs at least 2 folds; cv=1 makes fitting raise.
cv = W.IntSlider(2, min=2, max=10, description="❌ Cross Validation", layout=dict(flex="1"))
random_state = W.IntSlider(42, min=1, max=1_000_000, description="🎰 Random", layout=dict(flex="1"))
# TPOT only distinguishes verbosity levels 0-3.
verbosity = W.IntSlider(2, min=0, max=3, description="📣 Verbosity", layout=dict(flex="1"))

In [None]:
def _available_datasets(names):
    """Return the entries of ``names`` whose ``sklearn.datasets.load_<name>`` exists.

    ``load_boston`` is gone from recent scikit-learn, and accessing it there
    raises ImportError rather than AttributeError. Probing keeps the slider
    from offering a choice that would crash the loader.
    """
    available = []
    for name in names:
        try:
            getattr(sklearn.datasets, f"load_{name}")
        except (AttributeError, ImportError):
            continue
        available.append(name)
    return available


@W.interact
def loader(
    dataset_name=W.SelectionSlider(
        options=_available_datasets(["iris", "breast_cancer", "boston", "diabetes", "wine"]),
        description="🗄 Dataset",
    ),
    test_size=W.FloatSlider(0.8, min=0.000001, max=0.99999, step=0.01, description="⚖ Train/Test"),
    algorithm=W.SelectionSlider(options=['pearson', 'covariance', 'spearman'], description="🌡️ Correlation"),
):
    """Load a sklearn dataset, split it, and refresh every dependent view.

    Parameters
    ----------
    dataset_name : str
        Suffix of the ``sklearn.datasets.load_*`` function to call.
    test_size : float
        Fraction of rows held out as the test set; the rest are training rows.
    algorithm : str
        Ranking algorithm passed to yellowbrick's ``Rank2D``.

    Side effects: updates ``app`` (dataset, df, split arrays, explainer), fills
    the four data grids, and draws the feature-correlation plot.
    """
    app.dataset = getattr(sklearn.datasets, f"load_{dataset_name}")()
    feature_names = list(app.dataset.feature_names)
    app.df = pd.DataFrame(app.dataset.data, columns=feature_names)
    (
        app.train,
        app.test,
        app.labels_train,
        app.labels_test
    ) = sklearn.model_selection.train_test_split(
        app.dataset.data,
        app.dataset.target,
        # test_size alone fixes the split (training gets the complement). Passing
        # train_size=1 - test_size as well made sklearn size each side separately
        # (ceil vs floor) and rely on float rounding for them to add up.
        test_size=test_size,
        # seed the split so a fresh Run All reproduces it
        random_state=random_state.value,
    )
    grid_train.value = pd.DataFrame(app.train, columns=feature_names)
    grid_test.value = pd.DataFrame(app.test, columns=feature_names)
    grid_labels_train.value = pd.DataFrame(app.labels_train)
    grid_labels_test.value = pd.DataFrame(app.labels_test)
    app.explainer = LimeTabularExplainer(
        app.train,
        feature_names=app.dataset.feature_names,
        # regression-style datasets carry no target_names
        class_names=getattr(app.dataset, "target_names", None),
        discretize_continuous=True
    )
    visualizer = yellowbrick.features.Rank2D(features=app.dataset.feature_names, algorithm=algorithm)
    visualizer.fit(app.train, app.labels_train)
    visualizer.transform(app.train)
    # yellowbrick 1.0 renamed poof() to show(); fall back for older releases.
    render = getattr(visualizer, "show", None) or visualizer.poof
    render()

In [None]:
btn_train = W.Button(description="🚂 Train")


def _report_generation_scores(model):
    """Push ``model``'s best internal score into ``fitness`` after each generation.

    Wraps TPOT's private ``_check_periodic_pipeline`` hook and reads the private
    ``_optimized_pipeline_score`` attribute, so this may break across TPOT
    versions.
    """
    check_periodic_pipeline = model._check_periodic_pipeline

    def update(gen):
        check_periodic_pipeline(gen)
        fitness.value = model._optimized_pipeline_score

    model._check_periodic_pipeline = update


def trainer(*args, **kwargs):
    """Fit a fresh TPOTClassifier on the current split and show its test score.

    Used as the ``btn_train`` click handler; its arguments are ignored. The
    hyper-parameters come from the slider widgets. Training logs and any
    traceback go to ``train_output`` (the Output context displays and swallows
    exceptions). ``fitness`` follows per-generation progress, then the score
    on the held-out test set.
    """
    train_output.clear_output()
    # Without this, clicks made during a long fit each queue another full TPOT run.
    btn_train.disabled = True
    try:
        with train_output, warnings.catch_warnings():
            warnings.simplefilter("ignore")
            fitness.value = 0
            app.model = tpot.TPOTClassifier(
                generations=generations.value,
                population_size=population_size.value,
                cv=cv.value,
                n_jobs=-1,
                random_state=random_state.value,
                verbosity=verbosity.value,
                use_dask=True
            )
            _report_generation_scores(app.model)
            app.model.fit(app.train, app.labels_train)
            fitness.value = app.model.score(app.test, app.labels_test)
    finally:
        btn_train.disabled = False


btn_train.on_click(trainer)

box_train = W.VBox([
    generations,
    population_size,
    cv,
    random_state,
    verbosity,
    btn_train,
    fitness,
], description="Parameters", layout=dict(display="flex"))
display(box_train)
display(train_output)

In [None]:
with W.Output():
    # Build the interact inside a throwaway Output so it doesn't render here;
    # its widget is placed in the DockBox instead.
    @W.interact
    def explain(instance=W.IntSlider(0, min=0, max=100, description="🤔 'Splain")):
        """Render a LIME explanation of test row ``instance`` for the trained model.

        Also clamps the slider's range to the current test set. Does nothing
        until a split exists and a fitted model whose pipeline supports
        ``predict_proba`` is available.
        """
        if app.test is None:
            return
        n_rows = app.test.shape[0]
        try:
            slider = explain.widget.children[0]
        except (NameError, AttributeError):
            # `interact` calls this once while still decorating, before the name
            # `explain` (and its `.widget`) is bound.
            slider = None
        if slider is not None:
            slider.max = max(n_rows - 1, 0)
        if instance >= n_rows:
            # the slider still points past the end of a smaller, newly loaded test set
            return
        # TPOTClassifier always defines predict_proba, but it only works once
        # fitting has produced a pipeline that itself supports it.
        pipeline = getattr(app.model, "fitted_pipeline_", None)
        if app.explainer is not None and hasattr(pipeline, "predict_proba"):
            app.explainer.explain_instance(app.test[instance], app.model.predict_proba).show_in_notebook()

In [None]:
# Scroll each feature grid together with its label grid (linked client-side).
for features_grid, labels_grid in [(grid_train, grid_labels_train), (grid_test, grid_labels_test)]:
    W.jslink((features_grid, "scroll_y"), (labels_grid, "scroll_y"))

In [None]:
# Order matters: the dock layout set in fix_layout refers to panels by index.
dock_children = [
    loader.widget,      # 0
    grid_train,         # 1
    grid_labels_train,  # 2
    grid_test,          # 3
    grid_labels_test,   # 4
    box_train,          # 5
    train_output,       # 6
    iframe_client,      # 7
    explain.widget,     # 8
]
pg = DockBox(dock_children, layout={"height": "100vh"})
# NOTE(review): W.Output has no `description` trait, so this sets a plain Python
# attribute; confirm DockBox actually uses it for the tab title.
train_output.description = "Training Output"

In [None]:
@pg.on_displayed
def fix_layout(*args):
    """Hide tab bars and arrange the dock panels once ``pg`` is displayed.

    ``dock_layout`` refers to panels by their position in ``pg.children``:
    0 loader, 1 grid_train, 2 grid_labels_train, 3 grid_test,
    4 grid_labels_test, 5 box_train, 6 train_output, 7 iframe_client,
    8 explain.widget.
    """
    def tab(index):
        # a tab area holding a single panel
        return {'type': 'tab-area', 'widgets': [index], 'currentIndex': 0}

    def column(children, sizes):
        # a vertical stack of areas with the given relative heights
        return {'type': 'split-area', 'orientation': 'vertical', 'children': children, 'sizes': sizes}

    # (column, relative width) pairs, left to right.
    columns = [
        (column([tab(8), tab(0), tab(1)],
                [0.3288275105446067, 0.15804757115893375, 0.5131249182964596]),
         0.25792394844255556),
        (column([tab(2), tab(4)], [0.5, 0.5]), 0.16167643973552365),
        (column([tab(3), tab(5)], [0.5, 0.5]), 0.11422049816337114),
        (column([tab(6), tab(7)], [0.30860784292804, 0.69139215707196]), 0.24536800382401186),
    ]
    # The previous layout also had a fifth area holding panel index 9, but pg has
    # only nine children (0-8). That area is dropped and the remaining widths are
    # rescaled to sum to 1.
    total_width = sum(width for _, width in columns)

    pg.hide_tabs = True
    pg.dock_layout = {
        'type': 'split-area',
        'orientation': 'horizontal',
        'children': [area for area, _ in columns],
        'sizes': [width / total_width for _, width in columns],
    }

In [None]:
# Display the dashboard; displaying it triggers the on_displayed layout hook.
pg