In [1]:
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go

from icecream import ic

In [2]:
# Read the training-runs CSV export into a DataFrame and display it.
path = '../reports/wandb_training_runs.csv'
df = pd.read_csv(filepath_or_buffer=path)
df

Unnamed: 0.1,Unnamed: 0,name,final/%rmse_test,final/mape_train,_step,train time,final/radius,final/scaling,metrics/%rmse_val,final/%rmse_val,...,distance,batch_size,len_sequence,multiplicity,weight_decay,learning_rate,embedding_size,num_val_sequences,num_test_sequences,num_train_sequences
0,0,cnn_euclidean_128,0.879319,0.053720,143,2996.831978,,,1.013559,0.888428,...,euclidean,128,2368,11,0,0.001,128,100,150,7000
1,1,cnn_euclidean_128,0.846775,0.048216,281,5891.983670,,,0.852926,0.803828,...,euclidean,128,2368,11,0,0.001,128,100,150,7000
2,2,cnn_euclidean_128,0.856920,0.052726,289,6009.685542,,,0.940606,0.823363,...,euclidean,128,2368,11,0,0.001,128,100,150,7000
3,3,cnn_euclidean_128,0.812125,0.048806,193,3557.712142,,,0.890296,0.861829,...,euclidean,128,2368,11,0,0.001,128,100,150,7000
4,4,cnn_euclidean_64,0.782185,0.048415,239,4092.976777,,,0.860414,0.789621,...,euclidean,128,2368,11,0,0.001,64,100,150,7000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
58,58,cnn_hyperbolic_4,1.418519,0.091283,104,1523.274984,0.987652,0.024150,1.471109,1.446855,...,hyperbolic,128,2368,11,0,0.001,4,100,150,7000
59,59,cnn_hyperbolic_2,3.253555,0.354181,104,1526.040309,1.004899,0.012352,3.720729,3.366958,...,hyperbolic,128,2368,11,0,0.001,2,100,150,7000
60,60,cnn_hyperbolic_2,2.764740,0.302234,104,1836.038584,1.021302,0.012464,3.440546,2.970084,...,hyperbolic,128,2368,11,0,0.001,2,100,150,7000
61,61,cnn_hyperbolic_2,2.760559,0.350905,104,2300.138346,1.004763,0.012856,3.178160,3.000315,...,hyperbolic,128,2368,11,0,0.001,2,100,150,7000


In [3]:
# Aggregate the test %RMSE over all runs that share an embedding size and a
# distance function: one row per (embedding size, distance) pair holding the
# mean and the standard deviation, plus a rounded mean used for text labels.
df = (
    df[['distance', 'embedding_size', 'final/%rmse_test']]
    .groupby(['embedding_size', 'distance'])
    .agg(['mean', 'std'])
    .reset_index()
    # flatten the two-level column labels, e.g. ('final/%rmse_test', 'mean') -> 'final/%rmse_test mean'
    .pipe(lambda frame: frame.set_axis(
        [' '.join(levels).strip() for levels in frame.columns], axis=1
    ))
    .rename(columns={
        'embedding_size': 'Embedding Dimension',
        'final/%rmse_test mean': "Mean %RMSE",
        'distance': 'Distance',
        'final/%rmse_test std': '%RMSE std',
    })
    .assign(**{'Mean %RMSE rounded': lambda frame: frame['Mean %RMSE'].round(2)})
)

df

Unnamed: 0,Embedding Dimension,Distance,Mean %RMSE,%RMSE std,Mean %RMSE rounded
0,2,euclidean,5.408676,0.048291,5.41
1,2,hyperbolic,2.870587,0.256836,2.87
2,4,euclidean,3.148575,0.049351,3.15
3,4,hyperbolic,1.438522,0.045448,1.44
4,6,euclidean,2.297897,0.021261,2.3
5,6,hyperbolic,1.235643,0.025289,1.24
6,8,euclidean,1.854092,0.038878,1.85
7,8,hyperbolic,1.053523,0.020725,1.05
8,16,euclidean,1.206387,0.043491,1.21
9,16,hyperbolic,1.08132,0.029081,1.08


In [4]:
def _rgba_from_color(color, alpha):
    """Convert a '#rgb', '#rrggbb', 'rgb(...)' or 'rgba(...)' color into an 'rgba(r,g,b,alpha)' string."""
    text = str(color).strip()
    if text.startswith('#'):
        digits = text[1:]
        if len(digits) == 3:
            # Short hex form: each digit is doubled ('#abc' -> '#aabbcc').
            digits = ''.join(digit * 2 for digit in digits)
        if len(digits) not in (6, 8):
            raise ValueError(f"Cannot parse hex color {color!r}.")
        red, green, blue = (int(digits[i:i+2], 16) for i in (0, 2, 4))
    elif text.lower().startswith(('rgb(', 'rgba(')):
        channels = text[text.index('(') + 1:text.rindex(')')].split(',')
        if len(channels) < 3:
            raise ValueError(f"Cannot parse color {color!r}.")
        # Any alpha already present in the color is replaced by `alpha`.
        red, green, blue = (int(float(channel)) for channel in channels[:3])
    else:
        raise ValueError(f"Cannot derive a translucent fill from color {color!r}; use a hex or rgb() color.")
    return f"rgba({red},{green},{blue},{alpha})"


def line(error_y_mode=None, **kwargs):
    """
    Extension of `plotly.express.line` that can draw error bands instead of error bars.
    Source: https://stackoverflow.com/a/69587615/14773537.

    Parameters
    ----------
    error_y_mode : {'bar', 'bars', 'band', 'bands', None}
        'bar'/'bars'/None return the plain `px.line(**kwargs)` figure.
        'band'/'bands' replace the error bars by a translucent filled band
        around each line, in the line's own color.
    **kwargs
        Forwarded to `plotly.express.line`. In band mode 'error_y' is required;
        'error_y_minus' is used for the lower edge of the band when given.

    Returns
    -------
    The figure returned by `px.line`; in band mode it additionally holds one
    band trace per line, placed directly before the line it belongs to.

    Raises
    ------
    ValueError
        If `error_y_mode` is not one of the accepted values, if band mode is
        requested without 'error_y', or if a line color cannot be parsed.
    """
    ERROR_MODES = {'bar','band','bars','bands',None}
    if error_y_mode not in ERROR_MODES:
        raise ValueError(f"'error_y_mode' must be one of {ERROR_MODES}, received {repr(error_y_mode)}.")
    if error_y_mode in {'bar','bars',None}:
        return px.line(**kwargs)

    # Band mode. An explicit `error_y=None` is rejected as well: it would otherwise
    # only fail later, when the missing error array is added to the y values.
    if kwargs.get('error_y') is None:
        raise ValueError("If you provide argument 'error_y_mode' you must also provide 'error_y'.")

    # The figure with error bars is only a data source: plotly express resolves the
    # per-trace x, y, error arrays and colors for us. The returned figure is built
    # without any error-bar arguments.
    figure_with_error_bars = px.line(**kwargs)
    fig = px.line(**{arg: val for arg, val in kwargs.items() if arg not in ('error_y', 'error_y_minus')})
    for data in figure_with_error_bars.data:
        x = list(data['x'])
        error = data['error_y']
        # Symmetric band unless a separate lower error was supplied.
        lower_error = error['array'] if error['arrayminus'] is None else error['arrayminus']
        y_upper = list(data['y'] + error['array'])
        y_lower = list(data['y'] - lower_error)
        fig.add_trace(
            go.Scatter(
                # Closed polygon: along the upper edge, then back along the lower edge.
                x = x+x[::-1],
                y = y_upper+y_lower[::-1],
                fill = 'toself',
                fillcolor = _rgba_from_color(data['line']['color'], 0.3),
                line = dict(
                    color = 'rgba(255,255,255,0)'
                ),
                hoverinfo = "skip",
                showlegend = False,
                legendgroup = data['legendgroup'],
                xaxis = data['xaxis'],
                yaxis = data['yaxis'],
            )
        )
    # Reorder data as said here: https://stackoverflow.com/a/66854398/8849755
    # fig.data is [line_1..line_n, band_1..band_n]; interleave to
    # [band_1, line_1, ..., band_n, line_n] so every line is drawn over its band.
    n_lines = len(fig.data) // 2
    fig.data = tuple(
        trace
        for i in range(n_lines)
        for trace in (fig.data[n_lines + i], fig.data[i])
    )
    return fig

In [14]:
# Mean test %RMSE against embedding dimension on log-log axes: one line per
# distance function, with a shaded band built from the per-group standard deviation.
fig = line(
    data_frame=df,
    x='Embedding Dimension',
    y='Mean %RMSE',
    error_y='%RMSE std',
    error_y_mode='band',
    color='Distance',
    markers=True,  # boolean flag in plotly express, not a matplotlib-style marker code
    log_y=True,
    log_x=True,
    text='Mean %RMSE rounded'
    )

dimensions = df['Embedding Dimension'].unique()
title = 'Edit Distance Approximation:<br>%RMSE on Greengenes Dataset with CNN'

# Put the rounded-mean text labels underneath their markers.
fig.update_traces(textposition='bottom center')
fig.update_layout(
    yaxis_tickformat = "0.1r",
    legend=dict(y=0.95, x=0.8),
    # Place x ticks exactly on the embedding sizes present in the data.
    xaxis = dict(
        tickmode = 'array',
        tickvals = dimensions,
    ),
    title={'text': title, 'xanchor': 'center', 'x':0.5}
)

# Save a static copy next to the source CSV, then render the figure inline.
fig.write_image('../reports/training_runs.png')
fig.show()