In [13]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False
from incremental_learning.config import es_cloud_id, es_user, es_password
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

from elasticsearch import Elasticsearch


import eland as ed

import ipywidgets as widgets
from ipywidgets import interact, Layout, interactive_output

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
def plot_errors(sampling_mode, update_fraction, training_fraction, dataset_name):
    if sampling_mode:
        dataset = df1[(df1['config.sampling_mode'] == sampling_mode)]
    else:
        dataset = df1
    if update_fraction:
        dataset =  dataset[(dataset['config.update_fraction'] == update_fraction)]
    if training_fraction:
        dataset =  dataset[(dataset['config.training_fraction'] == training_fraction)]
    if dataset_name:
        dataset =  dataset[(dataset['config.dataset_name'] == dataset_name)]
        
    x_base = np.arange(10)+1
    
    fig = go.Figure()
    for row in dataset.iterrows():
        text = ["{}: {}".format(k, v) for k,v in row[1][['config.sampling_mode', 'config.update_fraction', 
                          'config.training_fraction', 'config.seed']].items()]
        text = '<br>'.join(text)
        x = x_base*row[1]['config.update_fraction'] + row[1]['config.training_fraction']
        if len(row[1]['run.result.updated_model.train_error.value']) == 10:
            fig.add_trace(go.Scatter(x=x, y=row[1]['run.result.updated_model.train_error.value'], text=text))
    fig.update_layout(title="Train error", xaxis_title='Fraction observed', 
            yaxis_title="MSE")
    fig.show()

    fig = go.Figure()
    for row in dataset.iterrows():
        text = ["{}: {}".format(k, v) for k,v in row[1][['config.sampling_mode', 'config.update_fraction', 
                          'config.training_fraction', 'config.seed']].items()]
        text = '<br>'.join(text)
        x = x_base*row[1]['config.update_fraction'] + row[1]['config.training_fraction']
        if len(row[1]['run.result.updated_model.test_error.value']) == 10:
            fig.add_trace(go.Scatter(x=x, y=row[1]['run.result.updated_model.test_error.value'], text=text))
    fig.update_layout(title="Test error", xaxis_title='Fraction observed', 
            yaxis_title="MSE")
    fig.show()

In [15]:
es = Elasticsearch(cloud_id=es_cloud_id,
                       http_auth=(es_user, es_password))

df = ed.DataFrame(es_client=es, es_index_pattern='experiment-multi-step-sampling', columns=['config.dataset_name', 'run.result.updated_model.train_error.value', 
                      'run.result.updated_model.test_error.value', 'config.sampling_mode', 'config.update_fraction', 
                      'config.training_fraction', 'config.seed'])

df1 = df.to_pandas().dropna()

In [18]:
# layout = Layout(display='flex', flex_flow='row', justify_content='space-between')
w_dataset_name = widgets.Dropdown(options=df1['config.dataset_name'].unique(), value='house', 
                                   description="Dataset_name", disabled=False)
w_sampling_mode = widgets.Dropdown(options=df1['config.sampling_mode'].unique(), value=None, 
                                   description="Sampling mode", disabled=False)
w_update_fraction = widgets.Dropdown(options=df1['config.update_fraction'].unique(), value=None, 
                                     description="Update fraction", disabled=False)
w_training_fraction = widgets.Dropdown(options=df1['config.training_fraction'].unique(), value=None, 
                                       description="Training fraction", disabled=False)
ui = widgets.HBox([w_dataset_name, w_sampling_mode, w_update_fraction, w_training_fraction])
out = interact(plot_errors, dataset_name = w_dataset_name, sampling_mode=w_sampling_mode, 
             update_fraction = w_update_fraction, training_fraction=w_training_fraction)


interactive(children=(Dropdown(description='Sampling mode', options=('nlargest', 'random'), value=None), Dropd…

In [17]:
df1[(df1["config.dataset_name"] == 'facebook')].dropna()

Unnamed: 0,config.dataset_name,run.result.updated_model.train_error.value,run.result.updated_model.test_error.value,config.sampling_mode,config.update_fraction,config.training_fraction,config.seed
NndFnXwBUYUoev8FgBGq,facebook,"[224.234803423985, 134.8236830908232, 134.8236...","[270.1305986033704, 316.8556287640425, 316.855...",nlargest,0.05,0.5,0
63cDpXwBUYUoev8FYhE3,facebook,"[246.8351697587024, 246.8351697587024, 156.772...","[290.4893250714229, 290.4893250714229, 269.521...",nlargest,0.2,0.1,0
83dipXwBUYUoev8FHRG_,facebook,"[230.6992257018419, 128.6263632156568, 128.626...","[270.1742478184926, 306.26815448443847, 306.26...",nlargest,0.1,0.5,0
TXcAqnwBUYUoev8FlxI9,facebook,"[304.6568239873566, 304.6568239873566, 304.656...","[372.6634242311355, 372.6634242311355, 372.663...",nlargest,0.1,0.25,1
S3fsqXwBUYUoev8FJBLE,facebook,"[324.5287687359223, 324.5287687359223, 324.528...","[315.5524872558988, 315.5524872558988, 315.552...",nlargest,0.1,0.25,0
Lnd3qHwBUYUoev8F9BIm,facebook,"[390.1605412852457, 390.1605412852457, 323.275...","[370.2498139849785, 370.2498139849785, 332.081...",random,0.1,0.1,0
-3fBpXwBUYUoev8FzhEs,facebook,"[395.8675035308161, 395.8675035308161, 351.868...","[376.6148687974317, 376.6148687974317, 378.441...",random,0.05,0.1,0
z3dPsHwBUYUoev8FDBK6,facebook,"[304.6568239873566, 304.6568239873566, 304.656...","[372.6634242311355, 372.6634242311355, 372.663...",nlargest,0.1,0.25,1
0HdPsHwBUYUoev8FfBKJ,facebook,"[304.6568239873566, 304.6568239873566, 304.656...","[372.6634242311355, 372.6634242311355, 372.663...",nlargest,0.1,0.25,1
YnegqnwBUYUoev8FNRK1,facebook,"[276.6931256332166, 144.67571911774428, 144.67...","[289.4603473715156, 277.9086951461986, 277.908...",nlargest,0.2,0.25,0
