In [1]:
import dash
import dash_core_components as dcc
import dash_html_components as html
import pandas as pd
import plotly
import plotly.graph_objs as go
import plotly_express as px
from plotly import tools
from plotly.subplots import make_subplots

In [2]:
# Load each model's validation-metric and training-loss tables.
# Every model has two CSVs in DATA_DIR, "<prefix>_validationResults.csv" and
# "<prefix>_lossResults.csv". Both are read with their first column as the index.
DATA_DIR = "data"


def load_results(prefix, data_dir=DATA_DIR):
    """Read the validation and loss CSVs for one model.

    Parameters
    ----------
    prefix : str
        File-name prefix of the model's result files, e.g. ``"knrm"``.
    data_dir : str
        Directory that contains the CSV files.

    Returns
    -------
    tuple of (pd.DataFrame, pd.DataFrame)
        ``(validation, loss)``. In the loss frame, column ``"0"`` is renamed
        to ``"Loss"``.
    """
    val = pd.read_csv(f"{data_dir}/{prefix}_validationResults.csv", index_col=0)
    loss = pd.read_csv(f"{data_dir}/{prefix}_lossResults.csv", index_col=0)
    # The loss CSV's only value column is headed "0"; give it a readable name.
    return val, loss.rename(columns={"0": "Loss"})


knrm_val, knrm_loss = load_results("knrm")
mp_val, mp_loss = load_results("match_pyramid")
conv_knrm_val, conv_knrm_loss = load_results("conv_knrm")

In [3]:
# Join each model's loss and validation tables on their shared index, tag the
# rows with the model name, and stack all models into one long frame.
def merge_results(loss, val, model_name):
    """Inner-join a model's loss and validation frames on their index.

    Parameters
    ----------
    loss, val : pd.DataFrame
        The model's loss table and validation-metric table.
    model_name : str
        Label written to the new ``"Model"`` column.

    Returns
    -------
    pd.DataFrame
        A new frame with the rows whose index appears in both inputs, plus a
        ``"Model"`` column. The inputs are not modified.
    """
    merged = loss.merge(val, left_index=True, right_index=True, how="inner")
    # An empty inner join means the two files share no index values (most
    # likely a mismatched pair of CSVs). Fail loudly instead of plotting nothing.
    assert not merged.empty, f"{model_name}: loss and validation results share no index values"
    return merged.assign(Model=model_name)


knrm = merge_results(knrm_loss, knrm_val, "KNRM")
mp = merge_results(mp_loss, mp_val, "MatchPyramid")
conv_knrm = merge_results(conv_knrm_loss, conv_knrm_val, "CONV-KNRM")
res_df = pd.concat([knrm, mp, conv_knrm])

In [6]:
# Select the metric to visualize (e.g. "MRR@10" or "Loss"; any column of
# res_df) and the models to plot. Each model gets its own subplot, stacked top
# to bottom in the order listed.
metric = "MRR@10"
models = ["KNRM", "MatchPyramid", "CONV-KNRM"]

# Fail early with a readable message instead of a bare KeyError further down.
assert metric in res_df.columns, f"unknown metric {metric!r}; available: {list(res_df.columns)}"

# Keep only the columns the plot needs, in a new frame. res_df is left intact,
# so this cell can be re-run with a different metric without rebuilding res_df.
plot_df = res_df[[metric, "Model"]]

scatter_fig = make_subplots(
    rows=len(models),
    cols=1,
    subplot_titles=models,
    shared_xaxes=True,
    shared_yaxes=False,
    vertical_spacing=0.1,
)
# The x-axes are shared, so only the bottom row gets an axis title.
scatter_fig.update_xaxes(title_text="Iterations", row=len(models), col=1)
scatter_fig.update_yaxes(title_text=metric)

trace = []
for row, model in enumerate(models, start=1):
    model_df = plot_df[plot_df["Model"] == model]
    model_trace = go.Scatter(
        x=model_df.index,
        y=model_df[metric],
        mode="lines+markers",
        name=model,
    )
    trace.append(model_trace)
    # add_trace(row=, col=) replaces append_trace, which plotly 4 deprecated.
    scatter_fig.add_trace(model_trace, row=row, col=1)

scatter_fig.update_layout(
    colorway=["#fdae61", "#800080", "#2c7bb6"],
    margin=dict(l=50, r=20, t=30, b=40),
    showlegend=False,
    paper_bgcolor="#f2f5fa",
)
scatter_fig.show()