In [17]:
# README: Run this notebook on a Jupyter kernel that has `plotly` installed.
import plotly.express as px
import plotly.graph_objects as go
import json
import glob
import pandas as pd

In [18]:
def save_plotly_to_html(fig, output_file, width='100%', height=700):
    """Render a figure to HTML and write it to disk.

    Args:
        fig: Object exposing ``to_html(include_plotlyjs=..., default_width=...,
            default_height=...)`` (e.g. a plotly figure).
        output_file: Destination path. ``.html`` is appended when the path
            does not already end with it.
        width: Passed to ``to_html`` as ``default_width``.
        height: Passed to ``to_html`` as ``default_height``.

    Returns:
        None. Prints the final path after the file has been written.
    """
    html_string = fig.to_html(
        include_plotlyjs=True,
        default_width=width,
        default_height=height
    )

    # Normalise the extension so the saved file is always ``*.html``.
    if not output_file.endswith('.html'):
        output_file += '.html'

    # Write as UTF-8 explicitly: the platform default encoding may not be able
    # to represent every character in the generated HTML.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(html_string)

    print(f"Saved figure to {output_file}")


def get_loss_from_checkpoint(checkpoint_dir):
    """Read the training-loss curve from the newest checkpoint in a directory.

    Looks for ``checkpoint-<N>`` entries under ``checkpoint_dir``, picks the
    largest ``N`` and loads ``trainer_state.json`` from it.

    Args:
        checkpoint_dir: Directory containing ``checkpoint-<N>`` sub-directories.

    Returns:
        ``(steps, losses)``: two lists of equal length built from the
        ``log_history`` entries that carry a ``loss`` key.

    Raises:
        FileNotFoundError: If no ``checkpoint-<N>`` entry with a numeric
            suffix exists, or the selected checkpoint has no
            ``trainer_state.json``.
    """
    # Keep only numeric suffixes so entries such as ``checkpoint-best`` do not
    # crash the int() conversion.
    checkpoint_numbers = []
    for path in glob.glob(f"{checkpoint_dir}/checkpoint-*"):
        suffix = path.rsplit('-', 1)[-1]
        if suffix.isdigit():
            checkpoint_numbers.append(int(suffix))

    if not checkpoint_numbers:
        raise FileNotFoundError(
            f"No checkpoint-<N> directories found under {checkpoint_dir}")

    max_checkpoint = max(checkpoint_numbers)
    print("max_checkpoint:", max_checkpoint)

    trainer_state = f"{checkpoint_dir}/checkpoint-{max_checkpoint}/trainer_state.json"
    with open(trainer_state, 'r', encoding='utf-8') as f:
        state = json.load(f)

    # Log entries without a 'loss' key are skipped; filtering once keeps the
    # two returned lists aligned.
    loss_entries = [entry for entry in state['log_history'] if 'loss' in entry]
    steps = [entry['step'] for entry in loss_entries]
    losses = [entry['loss'] for entry in loss_entries]

    return steps, losses


def plot_compare_loss(schemas, yrange=None):
    """Overlay several loss curves as line traces on one figure.

    Args:
        schemas: Iterable of mappings. Each must provide ``'steps'`` (x
            values), ``'loss'`` (y values) and ``'name'`` (legend label); an
            optional ``'color'`` sets the line colour of that trace.
        yrange: Optional y-axis range; applied only when not ``None``.

    Returns:
        The figure created via ``go.Figure()`` with one trace per schema.
    """
    figure = go.Figure()

    for entry in schemas:
        scatter_kwargs = dict(
            x=entry['steps'],
            y=entry["loss"],
            name=entry["name"],
            mode='lines',
        )
        # Only pass a line style when the caller picked a colour explicitly.
        if 'color' in entry:
            scatter_kwargs['line'] = dict(color=entry['color'])
        figure.add_trace(go.Scatter(**scatter_kwargs))

    # Horizontal legend, bottom-anchored at y=1.02.
    legend_style = dict(orientation="h", yanchor="bottom", y=1.02)
    figure.update_layout(
        title='Train Loss',
        xaxis_title='Step',
        yaxis_title='Loss',
        legend=legend_style,
    )

    if yrange is not None:
        figure.update_layout(yaxis=dict(range=yrange))

    return figure

In [31]:
# Loss curve loaded from a CSV; presumably exported from the run page below (confirm):
# https://www.internalfb.com/mlhub/pipelines/runs/fblearner/587797729?tab=Visualizations
# NOTE(review): the CSV file name contains f578009631 while the URL says 587797729 — confirm
# both refer to the same run.
df = (
    pd.read_csv(
        "/fsx_0/user/tranx/experiments/llm_mm_aligner/fb_reference_jobs/train_loss_f578009631_0809.csv",
        header=None,
        skiprows=1,  # skip the first line of the file and name the columns explicitly below
    )
    .set_axis(['step', 'loss'], axis=1)
)
df.head()

Unnamed: 0,step,loss
0,1,9.6633
1,20,6.3737
2,40,2.7365
3,50,2.4793
4,60,2.3727


In [33]:
# Load (steps, losses) from the highest-numbered checkpoint under lm_dir.
lm_dir = "/fsx_0/checkpoints/tranx/MM9-Pretrain-70B/Llama31_336px_128nodes_bz32_scratch"
lm31_steps, lm31_losses = get_loss_from_checkpoint(lm_dir)

# One dict per curve. plot_compare_loss reads "name", "loss", "steps" and the
# optional "color"; any other key is ignored by it.
schemes = [
    {
        "name": "f587797729_70B_Llama3.1_336px_128nodes_bz64",
        "loss": df['loss'],
        "steps": df['step'],
        "resume_step": 0,  # NOTE(review): not read by plot_compare_loss in this notebook
        # "step_scale": 10,  # 10 for bz=32, gradient_accumulation_step=4, nodes=128
        "color": "black"
    },

    {
        "name": "MM9_70B_Llama3.1_336px_128nodes_bz32",
        "loss": lm31_losses,
        "steps": lm31_steps,
        "color": "blue"
    },
]

# Restrict the y-axis to [0.8, 1.4] for the comparison.
fig = plot_compare_loss(schemes, yrange=[.8, 1.4])
fig.show()

max_checkpoint: 8600


In [21]:
# Publish: write the current `fig` to an HTML file under PUBLISH_DIR.
# NOTE(review): this cell shows In [21] while the cell that builds `fig` shows In [33];
# re-run the notebook top to bottom so the saved file matches the figure displayed above.
PUBLISH_DIR = "/fsx_0/user/tranx/aws_dropbox"
save_plotly_to_html(fig, f"{PUBLISH_DIR}/test.html")

Saved figure to /fsx_0/user/tranx/aws_dropbox/test.html
