# Comparison notebook
Compares the loss of several Neural Networks Transformers.

# Imports

In [39]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import matplotlib.colors as mplc

# Load losses

In [2]:
# number of tokens in a block
n_block = 64
# number of block in a batch
n_batch = 32
# number of tokens in the whole tinystories dataset
n_token_full_ds = 3625009 * 128

In [None]:
# model names and description
plotted_models = {"6262": {"description" : "24M params, 100% TinyStories DS",
                            "n_block" : 64,
                            "n_batch" : 32,
                            "n_iter" : 100000,
                            "dataset_fraction" : 1.0}, 
                    "9174": {"description" : "24M params, 10% TinyStories DS",
                            "n_block" : 64,
                            "n_batch" : 32,
                            "n_iter" : 100000,
                            "dataset_fraction" : 0.1}, 
                    "8767": {"description" : "24M params, 1% TinyStories DS",
                            "n_block" : 64,
                            "n_batch" : 32,
                            "n_iter" : 100000,
                            "dataset_fraction" : 0.01}}

n_plotted_models = len(plotted_models)

In [55]:
for model in plotted_models.keys():
    # how many token the model has seen
    tokens_seen = plotted_models[model]["n_block"] * plotted_models[model]["n_batch"] * plotted_models[model]["n_iter"]
    
    # total available tokens for this training
    tokens_available = n_token_full_ds * plotted_models[model]["dataset_fraction"]
    
    # how many times a token was seen on average
    token_times_seen = max(tokens_seen / tokens_available, 1) # a token cant be seen less than once
    print(f"model {model} has seen {tokens_seen}, in average {token_times_seen} times")

    # append result to dict
    plotted_models[model]["tokens_seen"] = tokens_seen
    plotted_models[model]["token_times_seen"] = token_times_seen





model 6262 has seen 204800000, in average 1 times
model 9174 has seen 204800000, in average 2.2068910725463025 times
model 8767 has seen 204800000, in average 44.13782145092605 times


In [5]:
# load model training csv output to pandas df
path = "./Training_Results/"
losses_files_dict = dict(zip(plotted_models.keys(), [m+"_losses.csv" for m in plotted_models.keys()]))
print(losses_files_dict)
losses_df = {}
for model, file in losses_files_dict.items() :
    losses_df[model] = pd.read_csv(path + file)

{'6262': '6262_losses.csv', '9174': '9174_losses.csv', '8767': '8767_losses.csv'}


# Pre-plot

### Global settings

In [28]:
# define custom colors
colors_distinct_named = ["mediumpurple", "royalblue", "firebrick", 
                        "gold", "darkseagreen", "tomato", 
            "orange", "gold", "yellowgreen", "darkseagreen", 
            "limegreen", "mediumaquamarine", "turquoise", 
            "deepskyblue", "royalblue", "navy", "darkmagenta", 
            "orchid", "crimson"]
# convert colors to hex
colors_distinct_hex = [mplc.cnames[c] for c in colors_distinct_named]
# some other nice colors
graph_distinct_summer = ["#27187E","#758BFD","#AEB8FE","#CC2936", "#FF8600"]
color_gradient_summer = ["#057bb0","#390099", "#750090", "#9e0059","#880d1e","#ff5400","#ff7d10","#ffbd00"]

### Utilities

In [7]:
# function for plotting a line plot of an array x depending on y
# Args : 
#   y : numerical values to plot
#   ylabel : name of y value or y unit
#   x can be none or X axis
# Returns the corresponding html file

def plot_line(x=None, xlabel="x", y=None, ylabel="y", color="blue", width=800, height=500, x_int_only=True, plot=True, full_plotly_offline=False):
    n = len(y)
    if y is None :
        print("y is None")
        return None
    if x is None :
        x = np.arange(n)
    # Create df
    df = pd.DataFrame({
                    xlabel: x,
                    ylabel: y,
                
    })
    # Plot
    fig = px.line(
            df,
            x=xlabel,
            y=ylabel,
            template="plotly_white",
            width=width,
            height=height,
            color=[color]*n, 
            color_discrete_map="identity"
    )
    fig.update_xaxes(title=xlabel)
    if x_int_only : fig.update_xaxes(tickformat='d')
    fig.update_yaxes(title=ylabel)

    if plot : fig.show() 
    if full_plotly_offline:
        return fig.to_html(full_html=False)
    else :
        return fig.to_html(full_html=False, include_plotlyjs=False)




In [8]:

# function for plotting multiple lines from a dataframe. All columns will be plotted except the one specified as x_col.
# Args : 
#   df : input dataframe
#   x_col : column name for x axis
#   startstop : list of 2 integers, start and end index of columns to be plotted. If None, the whole column is plotted.
#   width, height : plot dimensions
#   x_int_only : format x axis as integer
# Returns the corresponding html file

def multiline_plot(df, x_col, startstop=None, colors=colors_distinct_hex, width=800, height=500, x_int_only=True, plot=True, full_plotly_offline=False):
    if x_col not in df.columns:
        print(f"{x_col} not in dataframe columns")
        return None
    
    y_cols = [col for col in df.columns if col != x_col]
    if not y_cols:
        print("No columns to plot on y axis")
        return None
    if startstop is None :
        startstop = [0, len(df)]

    fig = px.line(
        df.iloc[startstop[0]:startstop[1]],
        x=x_col,
        y=y_cols,
        template="plotly_white",
        width=width,
        height=height,
        color_discrete_map=dict(zip(y_cols, colors))
    )
    #fig.update_xaxes(title=xlabel)
    if x_int_only:
        fig.update_xaxes(tickformat='d')
    fig.update_yaxes(title="Value")
    if plot:
        fig.show()
    if full_plotly_offline:
        return fig.to_html(full_html=False)
    else:
        return fig.to_html(full_html=False, include_plotlyjs=False)





In [9]:

def multiline_plot_dashoption(
    df,
    x_col,
    dot_pattern=None,
    startstop=None,
    colors=None,
    width=800,
    height=500,
    x_int_only=True,
    plot=True,
    full_plotly_offline=False):
    if x_col not in df.columns:
        print(f"{x_col} not in dataframe columns")
        return None

    y_cols = [col for col in df.columns if col != x_col]
    if not y_cols:
        print("No columns to plot on y axis")
        return None

    if startstop is None:
        startstop = [0, len(df)]

    # Determine line dash style for each column
    line_dash = {}
    for col in y_cols:
        if dot_pattern and dot_pattern in col:
            line_dash[col] = 'dot'
        else:
            line_dash[col] = 'solid'

    fig = px.line(
        df.iloc[startstop[0]:startstop[1]],
        x=x_col,
        y=y_cols,
        template="plotly_white",
        width=width,
        height=height,
        color_discrete_map=dict(zip(y_cols, colors)) if colors else None
    )

    # Update line dash for each trace
    for i, col in enumerate(y_cols):
        fig.update_traces(
            selector={'name': col},
            line_dash=line_dash[col]
        )

    if x_int_only:
        fig.update_xaxes(tickformat='d')
    fig.update_yaxes(title="Value")

    if plot:
        fig.show()
    if full_plotly_offline:
        return fig.to_html(full_html=True)
    else:
        return fig.to_html(full_html=False, include_plotlyjs=False)

### Compute dataframes

In [10]:
dict_test_losses = {"iteration" : losses_df[list(plotted_models.keys())[0]]["iteration"]}
dict_train_losses = {"iteration" : losses_df[list(plotted_models.keys())[0]]["iteration"]}



for m in plotted_models.keys():
    print(f"Working for model {m}...")
    dict_test_losses[m+"_test_loss"] = losses_df[m]["test_loss"].to_numpy()
    dict_train_losses[m+"_train_loss"] = losses_df[m]["train_loss"].to_numpy()

df_test_losses = pd.DataFrame(dict_test_losses)
df_train_losses = pd.DataFrame(dict_train_losses)


df_test_losses.head()
df_train_losses.head()

Working for model 6262...
Working for model 9174...
Working for model 8767...


Unnamed: 0,iteration,6262_train_loss,9174_train_loss,8767_train_loss
0,0.0,10.86224,10.894934,10.882893
1,500.0,6.317381,6.287441,6.294331
2,1000.0,4.543505,4.568794,4.560696
3,1500.0,3.821358,3.845598,3.777266
4,2000.0,3.421336,3.443779,3.428401


In [16]:
# convert iter to tokens
iter_to_token_seen = n_block * n_batch

df_test_losses["training tokens"] = iter_to_token_seen * df_test_losses["iteration"]
df_train_losses["training tokens"] = iter_to_token_seen * df_train_losses["iteration"]

# remove iteration column
df_test_losses = df_test_losses.drop(columns=["iteration"])
df_train_losses = df_train_losses.drop(columns=["iteration"])

# Plot

### Compare test losses

In [29]:

multiline_plot(df=df_test_losses, x_col="training tokens", 
                startstop=None, colors=colors_distinct_hex, 
                width=800, height=500, 
                x_int_only=False, plot=True, full_plotly_offline=False);

### Compare train losses

In [19]:
multiline_plot(df=df_train_losses, x_col="training tokens", 
                startstop=None, colors=colors_distinct_hex, 
                width=800, height=500, 
                x_int_only=False, plot=True, full_plotly_offline=False);

### Compare test & train losses

In [21]:
df_both_losses = pd.concat((df_train_losses, df_test_losses[[m+"_test_loss" for m in plotted_models.keys()]] ), axis=1)
df_both_losses.head()

Unnamed: 0,6262_train_loss,9174_train_loss,8767_train_loss,training tokens,6262_test_loss,9174_test_loss,8767_test_loss
0,10.86224,10.894934,10.882893,0.0,10.862324,10.891705,10.880738
1,6.317381,6.287441,6.294331,1024000.0,6.328369,6.296029,6.321471
2,4.543505,4.568794,4.560696,2048000.0,4.546401,4.56493,4.595034
3,3.821358,3.845598,3.777266,3072000.0,3.801347,3.807614,3.815633
4,3.421336,3.443779,3.428401,4096000.0,3.447,3.453033,3.460869


In [22]:
colors_compare = [colors_distinct_hex[k] for k in range(n_plotted_models)]* 2
multiline_plot_dashoption( df=df_both_losses, x_col="training tokens", 
                            dot_pattern="train", 
                            startstop=None, colors=colors_compare, 
                            width=800, height=500, x_int_only=False, 
                            plot=True, full_plotly_offline=False);

### Plot final loss unique tokens

In [37]:
final_losses = []
models = []
unicityfract = []
for model in plotted_models.keys():
    models.append(model)
    final_losses.append(df_test_losses[f"{model}_test_loss"].to_numpy()[-1])
    unicityfract.append( 1/plotted_models[model]["token_times_seen"] )

df_loss_vs_unicityfract = pd.DataFrame({"model":models, 
                                        "final_losses":final_losses,
                                        "fraction of unique tokens":unicityfract})
                                    

# add reference "ideal" loss as being the one with the most unique tokens
most_unique_token_line = df_loss_vs_unicityfract["fraction of unique tokens"].idxmax()
reference_loss = df_loss_vs_unicityfract.loc[most_unique_token_line, "final_losses"]
df_loss_vs_unicityfract["reference_loss"] = reference_loss

# show head
df_loss_vs_unicityfract.head()

Unnamed: 0,model,final_losses,fraction of unique tokens,reference_loss
0,6262,1.915381,1.0,1.915381
1,9174,1.913745,0.226563,1.915381
2,8767,2.746964,0.022656,1.915381


In [47]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=df_loss_vs_unicityfract["fraction of unique tokens"], 
                         y=df_loss_vs_unicityfract["reference_loss"],
                    mode='lines',
                    line=dict(width=2, color='red', dash="dot"),
                    name='final loss for 100% unique tokens'))

fig.add_trace(go.Scatter(x=df_loss_vs_unicityfract["fraction of unique tokens"], 
                         y=df_loss_vs_unicityfract["final_losses"],
                    mode='markers',
                    marker=dict(size=16, color='blue'),
                    name='final_losses'))


fig.update_layout(
    title='Final losses vs fraction of unique tokens',
    xaxis_title='Fraction of unique tokens',
    yaxis_title='Final Losses',
    template="plotly_white",
    showlegend=True,
    yaxis_range=[0,4]
)

fig.show()