# Interactive Plotting for Results


During training, data can be saved and 2D and 3D visuals can be created with the functions presented in the demo visualizations notebook.
This notebook shows how to build an interactive plot of the results with Plotly.

#### Training
The metrics for each epoch have to be stored during training. For example, a handler could be added to the validation engine in the script `run_ar_tc.py` that calls the visualization functions at the end of each validation epoch.
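A minimal sketch of the saving side is given here. `save_epoch_metrics` is a hypothetical helper, not part of `run_ar_tc.py`. It assumes the per-class scores are ordered `[background, tropical cyclones, atmospheric rivers]`, an ordering inferred only from the indices the loading cell below reads (`values[0][1]` for TC, `values[0][2]` for AR). If the training script uses PyTorch Ignite, a function like this could be called from a handler registered with `evaluator.add_event_handler(Events.EPOCH_COMPLETED, ...)`.

```python
import os

import numpy as np


def save_epoch_metrics(epoch, class_scores, out_dir="."):
    """Hypothetical helper: save one epoch's per-class mAP scores to disk.

    class_scores is ASSUMED to be ordered [background, tropical cyclones,
    atmospheric rivers] so that values[0][1] is the TC mAP and
    values[0][2] is the AR mAP when the file is loaded again.
    """
    # Wrap the scores in an outer list so the file holds a 2D array with one row
    values = np.array([class_scores], dtype=float)
    path = os.path.join(out_dir, f"visualization_results_epoch_{epoch}.npy")
    np.save(path, values)
    return path
```

Calling `save_epoch_metrics(epoch, scores)` once per validation epoch, with epochs numbered from 1, produces exactly the file names the loading cell expects.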

#### Visualizations
After training, the results can be explored over time through the interactive plots and images generated in this notebook. For each epoch, one trace is created for the tropical cyclones (TC) and one for the atmospheric rivers (AR). Each of these traces plots the mean average precision (mAP) only up to that epoch, so moving through the epochs replays how the metrics evolved. Likewise, the predicted labels and the ground-truth labels are shown as images for the selected epoch only.
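The per-epoch bookkeeping can be sketched in plain Python. The helper names `visibility_mask` and `history_up_to` are hypothetical; they assume, as in the plotting cell below, that four traces are added per epoch (TC curve, AR curve, prediction image, ground-truth image), so slider step `k` owns traces `4k` to `4k + 3`.

```python
def visibility_mask(step, num_steps, traces_per_step=4):
    """Return the 'visible' flag of every trace when slider step `step` is active.

    Traces are assumed to be added in consecutive groups of `traces_per_step`
    per epoch, so only the group belonging to `step` is switched on.
    """
    mask = [False] * (num_steps * traces_per_step)
    start = step * traces_per_step
    mask[start:start + traces_per_step] = [True] * traces_per_step
    return mask


def history_up_to(values, step):
    """Metric values from the first epoch through slider step `step` (inclusive)."""
    return values[:step + 1]
```

For example, `visibility_mask(1, 3)` returns four `False`, four `True`, then four `False` values, so only the second epoch's traces are displayed, while `history_up_to(tc_values, 1)` gives the TC curve for the first two epochs.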

#### Running the code
In this example, the metrics for each epoch are saved in a NumPy file named `visualization_results_epoch_X.npy`, where X is the epoch number (1 to 30). Similarly, the images of the model predictions and of the ground truths are named `image_prediction_epoch_X.png` and `image_truth_epoch_X.png`. Adapt these names to match your own files.
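Before running the plotting cells, it can help to check that every expected file is present. The sketch below is an assumption-laden convenience, not part of the original notebook: the template constants and the helpers `epoch_file_names` and `missing_files` are hypothetical names, and the templates are the only place that needs changing if your files are named differently.

```python
import os

# Name templates used in this example; edit them to match your own files
METRICS_TEMPLATE = "visualization_results_epoch_{}.npy"
PREDICTION_TEMPLATE = "image_prediction_epoch_{}.png"
TRUTH_TEMPLATE = "image_truth_epoch_{}.png"


def epoch_file_names(epoch):
    """Return the (metrics, prediction, truth) file names for a 1-indexed epoch."""
    return (METRICS_TEMPLATE.format(epoch),
            PREDICTION_TEMPLATE.format(epoch),
            TRUTH_TEMPLATE.format(epoch))


def missing_files(num_epochs, directory="."):
    """List every expected file for epochs 1..num_epochs that does not exist."""
    missing = []
    for epoch in range(1, num_epochs + 1):
        for name in epoch_file_names(epoch):
            if not os.path.exists(os.path.join(directory, name)):
                missing.append(name)
    return missing
```

An empty result from `missing_files(30)` means all 90 files for the 30 epochs are in the working directory.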

#### The Plot
Below are static renderings of the interactive plot. The slider at the bottom moves between epochs, i.e. through training time.

![Results of the model after 4 epochs of training](../images/interactiveplot_epoch4.png)

![Results of the model after 28 epochs of training](../images/interactiveplot_epoch28.png)


```python
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from plotly.offline import init_notebook_mode
from plotly.subplots import make_subplots
from PIL import Image

# Initialize notebook mode (plotly.js loaded from a CDN) and render figures in an iframe
init_notebook_mode(connected=True)
pio.renderers.default = 'iframe'
```

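```python
# Collect the per-epoch mAP values saved during training (epochs are numbered 1..30)
NUM_EPOCHS = 30
tc_values = []  # tropical cyclone mAP per epoch
ar_values = []  # atmospheric river mAP per epoch
for epoch in range(1, NUM_EPOCHS + 1):
    values = np.load(f'visualization_results_epoch_{epoch}.npy', allow_pickle=True)
    # In this example, values[0][1] holds the TC mAP and values[0][2] the AR mAP
    tc_values.append(values[0][1])
    ar_values.append(values[0][2])
```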

```python
num_epochs = len(tc_values)
epochs = np.arange(1, num_epochs + 1)  # 1-indexed, matching the file names
TRACES_PER_STEP = 4  # TC curve, AR curve, prediction image, ground-truth image

# 5x2 grid: the mAP curves span rows 1-2, the images span rows 3-5
fig = make_subplots(rows=5, cols=2,
                    shared_yaxes=True,
                    subplot_titles=("Tropical Cyclones Mean Average Precision",
                                    "Atmospheric Rivers Mean Average Precision",
                                    "Prediction of Extreme Events",
                                    "Ground Truth of Extreme Events"),
                    specs=[[{"rowspan": 2}, {"rowspan": 2}],
                           [None, None],
                           [{"rowspan": 3}, {"rowspan": 3}],
                           [None, None],
                           [None, None]],
                    vertical_spacing=0.1)


def load_rgb(path):
    """Load a PNG as an RGB array (dropping any alpha channel) for go.Image."""
    with Image.open(path) as img:
        return np.array(img.convert("RGB"))


# Add one group of four traces per slider step (epoch)
for i, epoch in enumerate(epochs):
    # Tropical cyclones: mAP from epoch 1 up to the current epoch
    fig.add_trace(
        go.Scatter(
            visible=False,
            line=dict(color="#008000", width=6),
            name=f"TC mAP: epoch {epoch}",
            x=epochs[:i + 1],
            y=tc_values[:i + 1]),
        row=1, col=1)

    # Atmospheric rivers: mAP from epoch 1 up to the current epoch
    fig.add_trace(
        go.Scatter(
            visible=False,
            line=dict(color="#00008b", width=6),
            name=f"AR mAP: epoch {epoch}",
            x=epochs[:i + 1],
            y=ar_values[:i + 1]),
        row=1, col=2)

    # Predicted labels for this epoch
    fig.add_trace(
        go.Image(visible=False, z=load_rgb(f"image_prediction_epoch_{epoch}.png")),
        row=3, col=1)

    # Ground-truth labels for this epoch
    fig.add_trace(
        go.Image(visible=False, z=load_rgb(f"image_truth_epoch_{epoch}.png")),
        row=3, col=2)

# Show the four traces of the first epoch initially
for trace in fig.data[:TRACES_PER_STEP]:
    trace.visible = True

# Axis labels for the metric plots; no grid or ticks on the image panels
for col in (1, 2):
    fig.update_xaxes(title_text="Epoch", row=1, col=col)
    fig.update_yaxes(title_text="Mean Average Precision", row=1, col=col)
    fig.update_xaxes(showgrid=False, showticklabels=False, row=3, col=col)
    fig.update_yaxes(showgrid=False, showticklabels=False, row=3, col=col)

# One slider step per epoch: each step makes only that epoch's four traces visible
steps = []
for i, epoch in enumerate(epochs):
    visible = [False] * len(fig.data)
    start = i * TRACES_PER_STEP
    visible[start:start + TRACES_PER_STEP] = [True] * TRACES_PER_STEP
    steps.append(dict(method="restyle",
                      args=["visible", visible],
                      label=str(epoch)))

sliders = [dict(
    active=0,
    currentvalue={"prefix": "Epoch: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(
    height=1250, width=2000,
    title_text="Extreme Event Detection Results",
    sliders=sliders
)

pio.show(fig)
```

