<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [None]:
if False:
    !pip install -U pandas numpy process_improve plotly IPython

In [None]:
import sys
import os
import pathlib
# Working directory of the notebook, as a Path (previously this name was bound to the
# `pathlib` module itself, which was never usable as a directory).
cwd = pathlib.Path.cwd()
# Put the directory two levels above the working directory at the front of the import
# path -- presumably so the local checkout of `process_improve` is picked up; confirm.
sys.path.insert(0, str(cwd.parents[1]))

from process_improve.batch.preprocessing import find_reference_batch, batch_dtw
from process_improve.batch.data_input import melted_to_dict
from process_improve.batch.plotting import plot_all_batches_per_tag, plot_multitags
import plotly.graph_objs as go
from plotly.offline import iplot, init_notebook_mode
from IPython.display import display, HTML
import pandas as pd
import ipywidgets as widgets
# Initialise plotly's notebook mode before any figure is drawn with `iplot`.
init_notebook_mode(connected=True)
# Inject a CSS rule so the notebook container uses the full width of the page.
display(
    HTML("<style>.container { width:100% !important; }</style>")
)

In [None]:
# Tags (columns) used for the alignment; ideally more than one. Each of them must be
# present in the data frame of every batch, and there should be NO missing data in any columns.
columns_to_align = [
    "AgitatorPower",
    "AgitatorTorque",
    "JacketTemperature",
    "DryerTemp",
]
# Tag shown in the single-tag overview plot: the last entry of the list above.
tag_to_plot = columns_to_align[-1]

In [None]:
# Import the data: a dictionary of dataframes, keyed by the value in the "batch_id" column
import process_improve.datasets.batch as batch_ds

# Locate "dryer.csv" in the first directory on the package's search path. Iterating
# `__path__` works whether it is a plain list (regular package) or a namespace-package
# path object; the private `_recalculate()` helper used before exists only on the latter.
dataset_dir = pathlib.Path(next(iter(batch_ds.__path__)))
dryer_raw = pd.read_csv(dataset_dir / "dryer.csv")
df_dict = melted_to_dict(dryer_raw, batch_id_col="batch_id")
df_dict.keys()

In [None]:
# First look at the raw data: the selected tag, drawn for all batches against "ClockTime"
iplot(
    plot_all_batches_per_tag(
        df_dict=df_dict,
        tag=tag_to_plot,
        time_column="ClockTime",
        x_axis_label="Time [hours]",
    )
)

In [None]:
# Second look at the raw data: the multi-tag plot against "ClockTime", with nrows=2 in its settings
iplot(
    plot_multitags(
        df_dict=df_dict,
        time_column="ClockTime",
        settings={"nrows": 2},
    )
)

In [None]:
# Find a batch that is a good candidate for all other batches to be aligned on
good_reference_candidate = find_reference_batch(
    df_dict,
    columns_to_align=columns_to_align,
    settings=dict(robust=False),
)
# Show the selected candidate as the cell's output
good_reference_candidate

In [None]:
# Run the alignment of every batch against the reference candidate, on the chosen tags/columns.
aligned_out = batch_dtw(
    df_dict,
    columns_to_align=columns_to_align,
    reference_batch=good_reference_candidate,
    settings=dict(
        robust=False,
        # A high tolerance (around 1.0) means the alignment is run only once.
        # A typical value is 0.1, which gives at least 2 or 3 alignment iterations.
        tolerance=0.05,
        # Progress output: the total "distance" of each batch relative to the reference
        show_progress=True,
    ),
)
    

In [None]:
# Tabulate the weight history returned by the alignment, one column per aligned tag
print("Weight history (the higher the weight, the greater the importance of that tag in the alignment process):")
display(
    pd.DataFrame(data=aligned_out["weight_history"], columns=columns_to_align)
)

In [None]:
# The aligned batches, as stored under "aligned_batch_dfdict" in the alignment output
synced = aligned_out["aligned_batch_dfdict"]

In [None]:
# Multi-tag plot of the aligned (synced) batches; no time column is passed here
iplot(
    plot_multitags(
        df_dict=synced,
        settings={"nrows": 2, "x_axis_label": "Normalized duration"},
    )
)

In [None]:
# Dimensions of the subplot grid, read by `base_figure` and `plot_all_tags` below
settings = {"ncols": 6, "nrows": 2}
# The first batch in the dictionary supplies the list of tag (column) names
batch1 = df_dict[list(df_dict)[0]]
tag_list = [*batch1.columns]
def base_figure():
    """
    Create an empty figure with a grid of scatter subplots.

    The grid size comes from the notebook-level `settings` ("nrows" x "ncols"),
    and the subplot titles from the notebook-level `tag_list`. All subplots
    share their x-axis; y-axes are independent.

    Returns
    -------
    The new `go.Figure`, with subplots configured but no traces added.
    """
    n_rows = settings["nrows"]
    n_cols = settings["ncols"]

    # One {"type": "scatter"} spec per cell of the grid
    scatter_specs = [
        [{"type": "scatter"} for _ in range(int(n_cols))]
        for _ in range(int(n_rows))
    ]

    figure = go.Figure()
    figure.set_subplots(
        rows=n_rows,
        cols=n_cols,
        specs=scatter_specs,
        subplot_titles=tag_list,
        start_cell="top-left",
        shared_xaxes="all",
        shared_yaxes=False,
        # Spacing shrinks as the grid gets denser
        vertical_spacing=0.2 / n_rows,
        horizontal_spacing=0.2 / n_cols,
    )
    return figure
    

def plot_all_tags(df_dict, fig, time_column=None, show=True):
    """
    Add one line trace per (batch, tag) pair to the subplot grid of `fig`.

    Tags come from the notebook-level `tag_list`, excluding `time_column`. They are
    placed left-to-right, top-to-bottom on a grid that is `settings["ncols"]` wide:
    the first tag goes to row 1 / column 1, and so on.

    Parameters
    ----------
    df_dict : dict
        Maps a batch identifier to the data frame of that batch.
    fig : go.Figure
        Figure whose subplot grid already exists (see `base_figure`); modified in place.
    time_column : str, optional
        Column used for the x-axis values. For a batch that does not have this column
        (or when it is None), the row number 0, 1, 2, ... is used instead.
    show : bool, optional
        If True (the default), render the figure with `fig.show()` before returning.
        Pass False when the returned figure is displayed by the caller, to avoid
        rendering it twice.

    Returns
    -------
    The same `fig` object, with traces added and the layout updated.
    """
    # Filter into a local list: `tag_list` is shared notebook state (it also provides
    # the subplot titles in `base_figure`) and must not be modified from in here.
    tags = [tag for tag in tag_list if tag != time_column]
    n_cols = int(settings["ncols"])

    for batch_id, batch_df in df_dict.items():

        # Time axis values
        if time_column is not None and time_column in batch_df.columns:
            time_data = batch_df[time_column]
        else:
            time_data = list(range(batch_df.shape[0]))

        for position, tag in enumerate(tags):
            # Zero-based position in the tag list -> one-based (row, col) in the grid
            grid_row, grid_col = divmod(position, n_cols)
            trace = go.Scatter(
                x=time_data,
                y=batch_df[tag],
                name=batch_id,
                mode="lines",
                # Plotly hover text uses <br> for a line break; "\n" is not rendered as one
                hovertemplate="Time: %{x}<br>y: %{y}",
                # line=colour_assignment[batch_id],  # <---- update
                # All traces of a batch share one legend group, so only the trace of
                # the first tag needs to show a legend entry for that batch.
                legendgroup=batch_id,
                showlegend=(position == 0),
            )
            fig.add_trace(trace, row=grid_row + 1, col=grid_col + 1)

    fig.update_layout(
        title="To add still",
        margin=dict(l=10, r=10, b=5, t=80),  # Defaults: l=80, r=80, t=100, b=80,
        hovermode="closest",
        showlegend=True,
        legend=dict(
            orientation="h",
            traceorder="normal",
            font=dict(family="sans-serif", size=12, color="#000"),
            bordercolor="#DDDDDD",
            borderwidth=1,
        ),
        autosize=False,
        xaxis=dict(
            title="TO ADD STILL",
            gridwidth=1,
            mirror=True,  # ticks are mirrored at the top of the frame also
            showspikes=True,
            visible=True,
        ),
        yaxis=dict(
            gridwidth=2,
            type="linear",
            autorange=True,
            showspikes=True,
            visible=True,
            showline=True,  # show a separating line
            side="left",  # axis is drawn on the left-hand side
        ),
        width=1900,
        height=800
    )
    if show:
        fig.show()
    return fig


# Output widget; created here but not placed in any container in this cell
out = widgets.Output()

# Draw all tags of all batches on a fresh subplot grid and wrap the result in a FigureWidget
g = go.FigureWidget(plot_all_tags(df_dict, fig=base_figure()))
# Disabled handles onto individual traces of the widget:
#   scatter_points = g.data[0]
#   prediction_line = g.data[1]  (this is the last element drawn in the plot)

# Row-oriented flex container, full width, holding the figure widget
box_layout = widgets.Layout(
    display="inline-flex",
    flex_flow="row",
    align_items="stretch",
    width="100%",
)
box_auto = widgets.Box(children=[g], layout=box_layout)
display(widgets.VBox([box_auto]));


In [None]:
# Slider labelled "Output metric": runs from 600 to 1898 in steps of 2, starts at 600,
# and sends updates continuously while it is being dragged
output_value = widgets.FloatSlider(
    value=600,
    min=600,
    max=1898,
    step=2,
    description="Output metric",
    readout_format="d",
    continuous_update=True,
)

def update_plot(change):
    """
    Slider callback: print the change notification it receives.

    This is still a placeholder; it does not modify the figure yet.

    Parameters
    ----------
    change
        The change object passed in by `observe`; printed as-is.
    """
    print(change)
    # TODO: update the figure widget `g` here, inside `with g.batch_update():`, e.g.
    #   new_x = np.array([X.min().values[0], X.max().values[0]]).reshape(-1, 1)
    #   scatter_points['x'] = 1
    #   prediction_line['x'] = 2
    #   prediction_line['y'] = 3

# Call `update_plot` whenever the slider's "value" changes, then show the slider
output_value.observe(update_plot, names="value")
display(output_value)


#display(widgets.VBox([out]));