In [2]:
%load_ext autoreload
%autoreload 2

import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import glob
import json
from typing import Dict, List, Optional

def get_loss_from_checkpoint(checkpoint_dir):
    """Read the training-loss curve stored in the latest checkpoint of a run.

    Scans ``checkpoint_dir`` for ``checkpoint-<step>`` subdirectories, picks the one
    with the highest numeric ``<step>`` and loads its ``trainer_state.json``.

    Args:
        checkpoint_dir: Directory containing ``checkpoint-<step>`` subdirectories.

    Returns:
        ``(steps, losses)``: two equal-length lists taken from the ``log_history``
        entries of the trainer state that carry a ``'loss'`` value.

    Raises:
        FileNotFoundError: If no ``checkpoint-<step>`` directory exists, or the
            selected checkpoint has no ``trainer_state.json``.
    """
    checkpoint_nums = []
    for path in glob.glob(f"{checkpoint_dir}/checkpoint-*"):
        suffix = path.rsplit('-', 1)[-1]
        # Skip names such as 'checkpoint-100.tmp' whose suffix is not a step number.
        if suffix.isdecimal():
            checkpoint_nums.append(int(suffix))
    if not checkpoint_nums:
        raise FileNotFoundError(
            f"No checkpoint-<step> directories found under {checkpoint_dir}")
    max_checkpoint = max(checkpoint_nums)

    print("max_checkpoint:", max_checkpoint)

    trainer_state = f"{checkpoint_dir}/checkpoint-{max_checkpoint}/trainer_state.json"
    with open(trainer_state, 'r') as f:
        state = json.load(f)

    # log_history may also contain entries without a 'loss' key (for Hugging Face
    # Trainer states, e.g. evaluation logs or the end-of-training summary). Keep only
    # entries that have a loss so the returned steps and losses stay aligned.
    loss_entries = [entry for entry in state['log_history'] if 'loss' in entry]
    steps = [entry['step'] for entry in loss_entries]
    losses = [entry['loss'] for entry in loss_entries]

    return steps, losses

def plot_compare_loss(schemas, yrange=None, title='Train Loss',
                      xaxis_title='Step (normalized to nodes=128 x bz=32 x gradient_accumulation_step=4)'):
    """Overlay several training-loss curves on a single plotly figure.

    Args:
        schemas: Iterable of dicts, one per curve, with keys ``'name'`` (legend
            label), ``'steps'`` (x values), ``'loss'`` (y values) and an optional
            ``'color'`` for the line.
        yrange: Optional ``[low, high]`` y-axis range; left unset when ``None``.
        title: Figure title.
        xaxis_title: X-axis label. The default describes the step normalization
            used by the loss cell of this notebook.

    Returns:
        The ``plotly.graph_objects.Figure`` with one line trace per schema.
    """
    fig = go.Figure()

    for s in schemas:
        trace_kwargs = {}
        # Only pin the line colour when one is given; otherwise plotly picks it.
        if 'color' in s:
            trace_kwargs['line'] = dict(color=s['color'])
        fig.add_trace(go.Scatter(
            x=s['steps'], y=s['loss'], name=s['name'], mode='lines', **trace_kwargs))

    fig.update_layout(
        title=title, xaxis_title=xaxis_title, yaxis_title='Loss',
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
    )

    if yrange is not None:
        fig.update_layout(yaxis=dict(range=yrange))

    return fig

# Losses

In [66]:
OUTPUT_MAIN = "/fsx_0/checkpoints/tranx/MM10-Pretrain-70B/MH21_70B_224px_0916"
OUTPUT_NORM_32 = "/fsx_0/checkpoints/kapilk/MM10-Pretrain-70B/MH21_70B_224px_0918_normUnifVariance"
OUTPUT_NORM_128 = "/fsx_0/checkpoints/kapilk/MM10-Pretrain-70B/MH21_70B_224px_0920_128n_normUnifVariance"
OUTPUT_NORM_MAIN_128 = "/fsx_0/checkpoints/mm10/MM10-Stage1-70B/MH21_70B_224px_norm_R2"

# Every curve's x-axis is expressed in steps of the reference job:
# nodes=128 x bz=32 x gradient_accumulation=4 samples per optimizer step.
REF_SAMPLES_PER_STEP = 128 * 32 * 4

# main_128: bz=32, gradient_accumulation=4 (already on the reference scale)
main_steps, main_losses = get_loss_from_checkpoint(OUTPUT_MAIN)

# norm_32: 32 nodes, bz=32, gradient_accumulation=8, so one of its steps covers
# half as many samples as a 128-node step (scale factor 0.5).
NORM_32_STEP_SCALE = (32 * 32 * 8) / REF_SAMPLES_PER_STEP
norm_32_steps, norm_32_losses = get_loss_from_checkpoint(OUTPUT_NORM_32)
norm_32_steps = [int(x * NORM_32_STEP_SCALE) for x in norm_32_steps]
max_norm_32_steps = norm_32_steps[-1]

# norm_128: bz=32, gradient_accumulation=4; plotted after norm_32's last logged step
norm_128_steps, norm_128_losses = get_loss_from_checkpoint(OUTPUT_NORM_128)
norm_128_steps = [x + max_norm_32_steps for x in norm_128_steps]

# norm_main_128: plotted after norm_128's last logged step
norm_main_128_steps, norm_main_128_losses = get_loss_from_checkpoint(OUTPUT_NORM_MAIN_128)
norm_main_128_steps = [x + norm_128_steps[-1] for x in norm_main_128_steps]

schemes = [
    {
        "name": "main_128 (m2c2 1B)",
        "loss": main_losses,
        "steps": main_steps
    },
    {
        "name": "norm_32 (m2c2 1B)",
        "loss": norm_32_losses,
        "steps": norm_32_steps
    },
    {
        "name": "norm_128 (m2c2 1B)",
        "loss": norm_128_losses,
        "steps": norm_128_steps
    },
    {
        "name": "norm_main_128 (m2c2 1B + metaclip 2B)",
        "loss": norm_main_128_losses,
        "steps": norm_main_128_steps
    }
]

plot_compare_loss(schemes, yrange=[2,3])

max_checkpoint: 3000
max_checkpoint: 1400
max_checkpoint: 1600
max_checkpoint: 5600


# Evals

In [41]:
def read_mmmu(file):
    """Load a headerless, tab-separated MMMU score file into a DataFrame.

    Each row holds two fields, which become the ``Step`` and ``MMMU`` columns
    of the returned frame.
    """
    column_names = ['Step', 'MMMU']
    return pd.read_csv(file, sep='\t', header=None, names=column_names)

In [42]:
# MMMU scores of the main (m2c2 1B) stage-1 run, one row per evaluated step.
mmmu_main_path = "/fsx_0/user/tranx/experiments/metrics/mm10/stage1_mmmu_main.csv"
df1 = read_mmmu(mmmu_main_path)
df1

Unnamed: 0,Step,MMMU
0,100,0.4078
1,200,0.5167
2,300,0.5011
3,400,0.4956
4,500,0.5122
5,600,0.5089
6,700,0.5333
7,800,0.5378
8,900,0.5389
9,1000,0.5244


In [45]:
# Shift the eval steps onto the shared x-axis.
# NOTE(review): 700 appears to equal the normalized length of the norm_32 run
# (checkpoint-1400 halved in the loss cell) — confirm.
df2a = (
    read_mmmu("/fsx_0/user/tranx/experiments/metrics/mm10/stage1_mmmu_norm1.csv")
    .assign(Step=lambda frame: frame['Step'] + 700)
)
df2a

Unnamed: 0,Step,MMMU
0,800,0.5444
1,900,0.5456
2,1000,0.54
3,1100,0.5422
4,1200,0.5189
5,1300,0.5533
6,1400,0.5589
7,1500,0.5156
8,1600,0.5589
9,1700,0.5356


In [47]:
# Shift the eval steps onto the shared x-axis.
# NOTE(review): 2300 appears to equal 700 (normalized norm_32 length) + 1600
# (last norm_128 checkpoint) — confirm.
df2b = (
    read_mmmu("/fsx_0/user/tranx/experiments/metrics/mm10/stage1_mmmu_norm2.csv")
    .assign(Step=lambda frame: frame['Step'] + 2300)
)
df2b

Unnamed: 0,Step,MMMU
0,2500,0.5489
1,2600,0.5611
2,2700,0.5444
3,2800,0.5478
4,2900,0.5533
5,3000,0.5478
6,3100,0.5411
7,3200,0.5411
8,3300,0.5489
9,3400,0.5367


In [55]:
# One (frame, legend label) pair per MMMU curve, drawn in this order.
mmmu_curves = [
    (df1, "main (m2c2 1B)"),
    (df2a, "norm (m2c2 1B)"),
    (df2b, "norm (m2c2 1B + metaclip 2B)"),
]

fig = go.Figure()
for frame, label in mmmu_curves:
    fig.add_trace(go.Scatter(x=frame.Step, y=frame.MMMU, name=label))

fig.update_layout(
    title='MMMU',
    xaxis_title='Step (normalized to nodes=128 x bz=32 x gradient_accumulation_step=4)',
    yaxis_title='Accuracy',
    legend=dict(orientation="h", yanchor="bottom", y=1.02),
)

fig.show()