In [10]:
# parse .out files
#! for register in HI ID IN IP NA OP ne dtp; do python3 ../get_losses.py ../logs/register-llama-${register}-1.8B-869*.out --tokens-per-step=2097152 > loss_files/${register}_losses_round2.tsv ; done

In [11]:
! pip -q install pandas plotly
! pip -q install nbformat>=4.2.0

In [12]:
! rm "=4.2.0"   # pip install version creates this *eyeroll*

In [13]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import glob

In [14]:
tsv_files = glob.glob("loss_files/*.tsv")


from collections import defaultdict

registers = ["HI", "ID", "IN", "IP", "NA", "OP", "ne", "dtp"]

def regroup_files(array):
    # Create a dictionary to store groups
    grouped = defaultdict(list)

    # Iterate through the array and group by register
    for string in array:
        for register in registers:
            if string.split("/")[1].startswith(register):
                grouped[register].append(string)
                break  # Stop checking other registers once a match is found

    # Convert grouped dictionary values to a list of lists
    #result = list(grouped.values())
    result = [sorted(group) for group in grouped.values()]
    return result


print(tsv_files)
ordered_files = regroup_files(tsv_files)
print(ordered_files)



['loss_files/IN_losses_round2.tsv', 'loss_files/HI_losses_round2.tsv', 'loss_files/NA_losses_round1.tsv', 'loss_files/OP_losses_round2.tsv', 'loss_files/ne_losses_round2.tsv', 'loss_files/HI_losses_round1.tsv', 'loss_files/dtp_losses_round2.tsv', 'loss_files/OP_losses_round1.tsv', 'loss_files/NA_losses_round2.tsv', 'loss_files/dtp_losses_round1.tsv', 'loss_files/ne_losses_round1.tsv', 'loss_files/IN_losses_round1.tsv', 'loss_files/IP_losses_round2.tsv', 'loss_files/IP_losses_round1.tsv', 'loss_files/ID_losses_round2.tsv', 'loss_files/ID_losses_round1.tsv']
[['loss_files/IN_losses_round1.tsv', 'loss_files/IN_losses_round2.tsv'], ['loss_files/HI_losses_round1.tsv', 'loss_files/HI_losses_round2.tsv'], ['loss_files/NA_losses_round1.tsv', 'loss_files/NA_losses_round2.tsv'], ['loss_files/OP_losses_round1.tsv', 'loss_files/OP_losses_round2.tsv'], ['loss_files/ne_losses_round1.tsv', 'loss_files/ne_losses_round2.tsv'], ['loss_files/dtp_losses_round1.tsv', 'loss_files/dtp_losses_round2.tsv'], ['

In [15]:


epoch_sizes = {"HI": 4e9, "ID":27e9, "IN":36e9, "IP":4e9, "NA":11e9, "OP":15e9, "ne":31e9, "dtp":33e9}
fig = go.Figure()
key="tokens"

for group in ordered_files:
    dfs = []
    for file in group:
        df = pd.read_csv(file, sep='\t', skipfooter=3)
        df = df.apply(pd.to_numeric)
        dfs.append(df)
        register = file.split("/")[1].split("_")[0] #file.replace('.tsv', '')  # Add a 'Parameter' column with the filename as the value
    df = pd.concat(dfs)
    fig.add_trace(go.Scatter(
                x=df[key],
                y=df["loss"],
                mode='lines',
                name=f'{register}',
                legendgroup=f'{register}'
                )
            )
    fig.add_vline(x=epoch_sizes[register], line_dash="dot", line_color="gray", annotation_text=f"epoch {register}", legendgroup=f'{register}')

fig.update_layout(
        title=f"Loss from first {len(ordered_files[0])} training runs per register",
        xaxis_title=key,
        yaxis_title="loss",
        legend_title="Legend",
        height=700
    )

#move HI annotation lower because it clashes with IP
fig.update_layout(annotations=[{**a, **{"y":.97}} if "HI" in a["text"] else {**a} for a in fig.to_dict()["layout"]["annotations"] ])
    
fig.show()




