# Comparison notebook
Compares the losses of several Transformer neural networks.

# Imports

In [None]:
import numpy as np
import plotly.express as px
import pandas as pd
import matplotlib.colors as mplc

# Load losses

In [None]:
# Directory prefix prepended verbatim to every file name below; it must end
# with a path separator (or be empty) because plain concatenation is used.
path = ""
# Names of the CSV files holding the losses to compare.
losses_files = []
# One DataFrame per entry of `losses_files`, in the same order.
# pandas exposes the CSV reader as `pd.read_csv`; there is no `pd.from_csv`.
losses_df = [pd.read_csv(path + loss_file) for loss_file in losses_files]

# Plot

### Global settings

In [None]:
# Palette of named CSS colors chosen to be visually distinct from each other.
colors_distinct_named = [
    "darkmagenta", "firebrick", "tomato",
    "orange", "gold", "yellowgreen", "darkseagreen",
    "limegreen", "mediumaquamarine", "turquoise",
    "deepskyblue", "royalblue", "navy", "mediumpurple",
    "orchid", "crimson",
]
# Map every color name of the palette to its hex code, looked up in
# matplotlib's named-color table.
colors_distinct_hex = {c: mplc.cnames[c] for c in colors_distinct_named}

# Hand-picked hex palettes: a short list of distinct colors and a gradient.
graph_distinct_summer = ["#27187E", "#758BFD", "#AEB8FE", "#CC2936", "#FF8600"]
color_gradient_summer = [
    "#057bb0", "#390099", "#750090", "#9e0059",
    "#880d1e", "#ff5400", "#ff7d10", "#ffbd00",
]

### Utilities

In [None]:
def plot_line(x=None, xlabel="x", y=None, ylabel="y", color="blue", width=800, height=500, x_int_only=True, plot=True, full_plotly_offline=False):
    """Draw a single line plot of ``y`` against ``x`` with plotly express.

    Args:
        x: values for the x axis. When None, the indices ``0 .. len(y) - 1``
            are used.
        xlabel: name of the x values; used as column name and axis title.
        y: numerical values to plot. When None, a message is printed and
            nothing is plotted.
        ylabel: name (or unit) of the y values; used as column name and
            axis title.
        color: color of the line.
        width, height: figure dimensions passed to plotly.
        x_int_only: when true, x tick labels are formatted as integers.
        plot: when true, the figure is displayed with ``fig.show()``.
        full_plotly_offline: when true, the returned HTML embeds plotly.js;
            otherwise the library is left out of the snippet.

    Returns:
        The figure rendered as an HTML snippet (``full_html=False``), or
        None when ``y`` is None.
    """
    # Validate y before measuring it: len(None) would raise TypeError.
    if y is None:
        print("y is None")
        return None
    n = len(y)
    if x is None:
        x = np.arange(n)

    # Create df
    df = pd.DataFrame({
        xlabel: x,
        ylabel: y,
    })

    # Plot
    fig = px.line(
        df,
        x=xlabel,
        y=ylabel,
        template="plotly_white",
        width=width,
        height=height,
        # One identical color entry per point combined with the "identity"
        # map makes plotly use `color` itself as the line color.
        color=[color] * n,
        color_discrete_map="identity",
    )
    fig.update_xaxes(title=xlabel)
    if x_int_only:
        fig.update_xaxes(tickformat='d')
    fig.update_yaxes(title=ylabel)

    if plot:
        fig.show()
    if full_plotly_offline:
        return fig.to_html(full_html=False)
    return fig.to_html(full_html=False, include_plotlyjs=False)




In [None]:

def multiline_plot(df, x_col, startstop=None, colors=None, width=800, height=500, x_int_only=True, plot=True, full_plotly_offline=False):
    """Plot every column of ``df`` except ``x_col`` as a line against ``x_col``.

    Args:
        df: input dataframe.
        x_col: name of the column used for the x axis.
        startstop: sequence ``[start, stop]`` of row positions; only rows
            ``start`` (inclusive) to ``stop`` (exclusive) are plotted. When
            None, all rows are plotted.
        colors: colors assigned to the plotted columns in order. Either a
            sequence of colors or a dict whose values are the colors. When
            None, the notebook palette ``colors_distinct_hex`` is used.
            If there are fewer colors than columns, the extra columns get no
            entry in the color map handed to plotly.
        width, height: figure dimensions passed to plotly.
        x_int_only: when true, x tick labels are formatted as integers.
        plot: when true, the figure is displayed with ``fig.show()``.
        full_plotly_offline: when true, the returned HTML embeds plotly.js;
            otherwise the library is left out of the snippet.

    Returns:
        The figure rendered as an HTML snippet (``full_html=False``), or
        None when ``x_col`` is missing or there is no other column to plot.
    """
    if x_col not in df.columns:
        print(f"{x_col} not in dataframe columns")
        return None

    y_cols = [col for col in df.columns if col != x_col]
    if not y_cols:
        print("No columns to plot on y axis")
        return None

    # Resolve the palette at call time so that defining this function does
    # not depend on the palette cell having been executed first.
    if colors is None:
        colors = colors_distinct_hex
    # The notebook palette is a {name: hex} dict; iterate over the hex codes
    # rather than over the keys.
    if isinstance(colors, dict):
        colors = list(colors.values())

    rows = df if startstop is None else df.iloc[startstop[0]:startstop[1]]

    fig = px.line(
        rows,
        x=x_col,
        y=y_cols,
        template="plotly_white",
        width=width,
        height=height,
        color_discrete_map=dict(zip(y_cols, colors)),
    )
    if x_int_only:
        fig.update_xaxes(tickformat='d')
    fig.update_yaxes(title="Value")
    if plot:
        fig.show()
    if full_plotly_offline:
        return fig.to_html(full_html=False)
    return fig.to_html(full_html=False, include_plotlyjs=False)





### Compare test losses

### Compare train losses

### Compare test & train losses