In [2]:
%load_ext autoreload
%autoreload 2

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import glob
import json
from typing import Dict, List, Optional

def get_loss_from_checkpoint(checkpoint_dir):
    """Read the training-loss curve from the newest checkpoint of a run.

    Looks for ``checkpoint-<N>`` entries under ``checkpoint_dir``, selects the
    one with the largest numeric ``<N>``, and parses the ``log_history`` list
    stored in that checkpoint's ``trainer_state.json``.

    Args:
        checkpoint_dir: Directory that contains the ``checkpoint-<N>`` entries.

    Returns:
        A ``(steps, losses)`` tuple of two equal-length lists, taken from the
        ``log_history`` entries that contain both a ``'step'`` and a ``'loss'``
        key. Entries missing either key are skipped so the two lists always
        stay aligned.

    Raises:
        FileNotFoundError: If no ``checkpoint-<N>`` entry with a numeric suffix
            exists, or the selected checkpoint has no ``trainer_state.json``.
    """
    checkpoint_nums = []
    for path in glob.glob(f"{checkpoint_dir}/checkpoint-*"):
        # Only purely numeric suffixes identify a checkpoint; anything else
        # (e.g. "checkpoint-100-tmp") is ignored instead of crashing int().
        suffix = path.rpartition("checkpoint-")[2]
        if suffix.isdecimal():
            checkpoint_nums.append(int(suffix))

    if not checkpoint_nums:
        raise FileNotFoundError(
            f"No checkpoint-<N> entries found under {checkpoint_dir!r}")

    max_checkpoint = max(checkpoint_nums)
    print("max_checkpoint:", max_checkpoint)

    trainer_state = f"{checkpoint_dir}/checkpoint-{max_checkpoint}/trainer_state.json"
    with open(trainer_state, 'r') as f:
        state = json.load(f)

    # log_history may contain entries without a training loss (for example
    # evaluation or summary records); indexing them with ['loss'] would raise
    # KeyError, so keep only the entries that carry both fields.
    train_logs = [
        entry for entry in state['log_history']
        if 'step' in entry and 'loss' in entry
    ]
    steps = [entry['step'] for entry in train_logs]
    losses = [entry['loss'] for entry in train_logs]

    return steps, losses

def plot_compare_loss(schemas, yrange=None):
    """Overlay the loss curves of several runs on a single Plotly figure.

    Args:
        schemas: Iterable of dicts, one per curve. Each dict must provide
            ``'steps'`` (x values), ``'loss'`` (y values) and ``'name'``
            (legend label); an optional ``'color'`` sets the line colour.
        yrange: Optional y-axis range passed to the layout as
            ``yaxis=dict(range=yrange)``. When ``None`` the y-axis range is
            not set by this function.

    Returns:
        The ``go.Figure`` with one line trace per entry in ``schemas``.
    """
    figure = go.Figure()

    for schema in schemas:
        trace_kwargs = dict(
            x=schema['steps'],
            y=schema["loss"],
            name=schema["name"],
            mode='lines',
        )
        # Only pin the line colour when the caller supplied one.
        if 'color' in schema:
            trace_kwargs['line'] = dict(color=schema['color'])
        figure.add_trace(go.Scatter(**trace_kwargs))

    layout_kwargs = dict(
        title='Train Loss',
        xaxis_title='Step (normalized to nodes=128 x bz=32 x gradient_accumulation_step=4)',
        yaxis_title='Loss',
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
    )
    figure.update_layout(**layout_kwargs)

    if yrange is not None:
        figure.update_layout(yaxis=dict(range=yrange))

    return figure

# Losses

In [22]:
# Checkpoint directories of the runs being compared.
OUTPUT_MAIN = "/fsx_0/checkpoints/tranx/MM10-Pretrain-70B/MH21_70B_224px_0916"
OUTPUT_NORM_32 = "/fsx_0/checkpoints/kapilk/MM10-Pretrain-70B/MH21_70B_224px_0918_normUnifVariance"
OUTPUT_NORM_128 = "/fsx_0/checkpoints/kapilk/MM10-Pretrain-70B/MH21_70B_224px_0920_128n_normUnifVariance"
OUTPUT_NORM_MAIN_128 = "/fsx_0/checkpoints/mm10/MM10-Stage1-70B/MH21_70B_224px_norm_R2"

# Main run: bz=32, gradient_accumulation=4.
main_steps, main_losses = get_loss_from_checkpoint(OUTPUT_MAIN)

# 32-node norm run: bz=32, gradient_accumulation=8. With 32 nodes, one of its
# steps counts as half a step of the 128-node job, so its step axis is halved.
norm_32_steps, norm_32_losses = get_loss_from_checkpoint(OUTPUT_NORM_32)
norm_32_steps = [int(step / 2) for step in norm_32_steps]
max_norm_32_steps = norm_32_steps[-1]

# 128-node norm run: bz=32, gradient_accumulation=4. Its steps are shifted by
# the last (normalized) step of the 32-node run so the curve is plotted after it.
norm_128_steps, norm_128_losses = get_loss_from_checkpoint(OUTPUT_NORM_128)
norm_128_steps = [step + max_norm_32_steps for step in norm_128_steps]

# norm_main_128 run: shifted by the last step of the (already shifted)
# norm_128 curve so it is plotted after that one.
norm_main_128_steps, norm_main_128_losses = get_loss_from_checkpoint(OUTPUT_NORM_MAIN_128)
norm_main_128_steps = [step + norm_128_steps[-1] for step in norm_main_128_steps]

# One entry per curve, in legend order.
schemes = [
    {"name": label, "loss": loss_values, "steps": step_values}
    for label, loss_values, step_values in (
        ("main_128", main_losses, main_steps),
        ("norm_32", norm_32_losses, norm_32_steps),
        ("norm_128", norm_128_losses, norm_128_steps),
        ("norm_main_128 (+m2c2)", norm_main_128_losses, norm_main_128_steps),
    )
]

plot_compare_loss(schemes, yrange=[2, 3])

max_checkpoint: 3000
max_checkpoint: 1400
max_checkpoint: 1600
max_checkpoint: 3900
