In [None]:
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots

# Notebook-wide default Plotly theme, used by any figure that does not set its own template.
pio.templates.default = "plotly_white"

In [None]:
# Summary results for the two feature sets (excluding / including ICOHP stats),
# indexed on the header-less column pandas labels 'Unnamed: 0'.
ex_icohp, inc_icohp = (
    pd.read_csv(csv_path, index_col='Unnamed: 0')
    for csv_path in ('exc_icohp/summary_results.csv', 'inc_icohp/summary_results.csv')
)

#### Including ICOHP stats

In [None]:
# Summary results for the feature set that includes ICOHP stats (inc_icohp/summary_results.csv).
inc_icohp

#### Excluding ICOHP stats (only Magpie features with sine Coulomb matrix)

In [None]:
# Summary results for the feature set that excludes ICOHP stats (exc_icohp/summary_results.csv).
ex_icohp

## Plot test-set metrics for both models

In [None]:
def get_summary_stats_plot(model_1, model_2, model_names=('1', '2'),
                           colors=('#1878b6', '#f57f1f'), filename='Metrics',
                           width=1000, height=600):
    """
    Plot and save test-set performance metrics (MAE, RMSE, max error) for two models.

    The figure has one subplot per metric. Each model appears as a marker at
    ``<metric>_test_mean`` with a ``<metric>_test_std`` error bar. The figure is
    written to ``<filename>.pdf`` and ``<filename>.html`` (MathJax loaded from a
    CDN so the LaTeX axis labels render) and is also returned.

    Parameters
    ----------
    model_1, model_2 : pandas.DataFrame
        Summary tables containing the columns ``mae_test_mean``, ``mae_test_std``,
        ``rmse_test_mean``, ``rmse_test_std``, ``max_error_test_mean`` and
        ``max_error_test_std``.
        NOTE(review): every trace uses the single x category ``''``, so
        presumably each table holds one row -- confirm against the CSVs.
    model_names : tuple of str, optional
        Legend labels for ``model_1`` and ``model_2``.
    colors : tuple of str, optional
        Marker colours for ``model_1`` and ``model_2``.
    filename : str, optional
        Output path stem (without extension) for the PDF and HTML exports.
    width, height : int, optional
        Figure and exported image size in pixels.

    Returns
    -------
    plotly.graph_objects.Figure
        The 1 x 3 subplot figure.
    """
    # (column prefix, y-axis label) for each subplot, left to right.
    metrics = [
        ('mae', "$\\text{MAE } (cm^{-1})$"),
        ('rmse', "$\\text{RMSE } (cm^{-1})$"),
        ('max_error', "$\\text{Max error } (cm^{-1})$"),
    ]
    models = (model_1, model_2)

    fig = make_subplots(rows=1, cols=len(metrics), shared_xaxes=False, shared_yaxes=False,
                        horizontal_spacing=0.15, x_title='Metrics')

    # Identical framed-axis styling (black lines, mirrored frame, inward ticks) on every subplot.
    axis_style = dict(title_font=dict(size=24), color='black', tickfont=dict(size=22),
                      showline=True, linewidth=1, linecolor='black', mirror=True,
                      autorange=True, ticks="inside", tickwidth=1, tickcolor='black',
                      ticklen=5)
    fig.update_xaxes(**axis_style)
    fig.update_yaxes(**axis_style)

    for col, (metric, y_label) in enumerate(metrics, start=1):
        # Only the last subplot adds legend entries, so each model is listed once.
        show_legend = col == len(metrics)
        for model, name, color in zip(models, model_names, colors):
            fig.add_trace(go.Scatter(x=[''],
                                     y=model[f'{metric}_test_mean'],
                                     marker=dict(size=10, color=color),
                                     error_y=dict(
                                         type='data',  # error bar values given in data coordinates
                                         array=model[f'{metric}_test_std'],
                                         visible=True),
                                     mode='markers+lines',
                                     showlegend=show_legend,
                                     name=name), row=1, col=col)
        fig.update_yaxes(title_text=y_label, row=1, col=col, title_standoff=0)

    fig.update_layout(template='simple_white', width=width, height=height)

    # NOTE: static PDF export needs a Plotly image-export engine (e.g. kaleido) installed.
    fig.write_image(f"{filename}.pdf", format='pdf', width=width, height=height)
    fig.write_html(f"{filename}.html", include_mathjax='cdn')

    return fig

In [None]:
# First model: feature set without ICOHP stats; second model: feature set with ICOHP stats.
get_summary_stats_plot(ex_icohp, inc_icohp)