In [1]:
%load_ext autoreload
%autoreload 2

import plotly.express as px
import plotly.graph_objects as go
import re
import time

from typing import Dict, List, Optional

def read_file_lines(file_path: str):
    """Return the lines of a text file, or None when it cannot be read.

    The file is decoded as UTF-8 and undecodable bytes are replaced rather
    than raising, so logs containing stray binary output can still be parsed.
    Any exception raised while opening or reading is printed and reported to
    the caller as ``None``.
    """
    try:
        with open(file_path, mode='r', encoding='utf-8', errors='replace') as handle:
            return handle.readlines()
    except Exception as err:
        print(err)
    return None
    
def get_loss_values(job_log):
    """Extract the loss values logged in a job's output file.

    Every line of ``job_log`` is searched for the first ``'loss': <number>``
    entry; the numbers are returned in file order. The pattern requires the
    opening quote directly before ``loss``, so keys such as ``'eval_loss'``
    are not picked up.

    Args:
        job_log: Path of the log file to scan.

    Returns:
        A list of floats, one per matching line. The list is empty when no
        line matches or when the file could not be read.
    """
    lines = read_file_lines(job_log)
    if lines is None:
        # read_file_lines has already printed the reason; report "no data"
        # instead of failing with a TypeError while iterating over None.
        return []

    # Accept plain decimals, integers and exponent notation (e.g. 1.5e-05);
    # matching only "digits.digits" would cut an exponent-form value down to
    # its mantissa.
    loss_pattern = re.compile(r"'loss': ([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")

    loss_values = []
    for line in lines:
        match = loss_pattern.search(line)
        if match:
            loss_values.append(float(match.group(1)))

    return loss_values

def plot_train_loss(
    loss_dict: Dict[str, List[float]],
    title: Optional[str] = None,
    skip_steps: Optional[Dict[str, int]] = None,
    log_interval: int = 10,
):
    """Plot one training-loss curve per entry of ``loss_dict`` and show it.

    Args:
        loss_dict: Maps a legend label to its sequence of loss values::

            {
                name: [loss array],
                name2: [loss array2],
                ...
            }

        title: Text appended to the figure title as ``Train Loss: <title>``.
            When omitted, the title is just ``Train Loss``.
        skip_steps: Optional per-label x offset; a curve whose label is
            missing from the mapping starts at step 0.
        log_interval: Spacing on the x axis between consecutive loss values.
    """
    if skip_steps is None:
        skip_steps = {}

    fig = go.Figure()

    for name, loss_array in loss_dict.items():
        skip = skip_steps.get(name, 0)
        steps = [skip + log_interval * i for i in range(len(loss_array))]
        fig.add_trace(go.Scatter(x=steps, y=loss_array, name=name))

    # The title comes from the argument, not from the loop variable, so it is
    # well defined for any number of curves (including none).
    fig.update_layout(title=f'Train Loss: {title}' if title else 'Train Loss',
                      xaxis_title='Step',
                      yaxis_title='Loss')
    fig.show()

# MM8 Tests

In [None]:
# Legend label -> Slurm log file of that run. The commented-out entries are
# other runs that can be switched back on for comparison.
log_dict = {
    # "70b_8node_yury_6717": "/data/home/yastashonok/logs/output_6717.txt",
    # "70b_64node_harry_6733": "/fsx_0/user/tranx/output/slurm_logs/output_6733.txt",
    # "70b_128node_harry_6743": "/fsx_0/user/tranx/output/slurm_logs/output_6743.txt",
    # "70b_8node_harry_6815_w_BroadcastDataset": "/fsx_0/user/tranx/output/slurm_logs/output_6815.txt",
    # "70b_128node_harry_6839_w_BroadcastDataset": "/fsx_0/user/tranx/output/slurm_logs/output_6839.txt",
    # "70b_160node_harry_6840_w_BroadcastDataset": "/fsx_0/user/tranx/output/slurm_logs/output_6840.txt"
    "70B_160node_AWS": "/fsx_0/user/tranx/output/slurm_logs/output_6840.txt"
}

# Parse every log, echoing each label and path as it is read.
loss_values = {}
for name in log_dict:
    log = log_dict[name]
    print(name, log)
    loss_values[name] = get_loss_values(log)

fig = go.Figure()
for name in loss_values:
    losses = loss_values[name]
    # The x axis is the index of the logged value.
    steps = list(range(len(losses)))

    # This run is drawn shifted right by the number of values logged by the
    # "70b_128node_harry_6743" run, which must be enabled above for it.
    if name == "70b_128node_harry_6839_w_BroadcastDataset":
        gap = len(loss_values["70b_128node_harry_6743"])
        steps = list(range(gap, gap + len(losses)))

    fig.add_trace(go.Scatter(x=steps, y=losses, name=name))

fig.show()

# 70B - fixed pipeline

In [None]:
# Legend label -> Slurm log file, listed in the order the runs are chained.
runs = {
    "70B_160node_AWS-1": "/fsx_0/user/tranx/output/slurm_logs/output_6840.txt",
    "70B_256node_AWS-2": "/fsx_0/user/tranx/output/slurm_logs/output_6926.txt",
    "70B_256node_AWS-3": "/fsx_0/user/tranx/output/slurm_logs/output_6939.txt"
}

# Each curve after the first starts at the previous curve's last step rounded
# down to a multiple of this value.
# NOTE(review): presumably checkpoints are written every 200 steps and each run
# resumed from the latest one — confirm against the training config.
checkpoint_interval = 200

loss_values = {}
for name, log in runs.items():
    print(name, log)
    loss_values[name] = get_loss_values(log)

fig = go.Figure()
resume_step = 0

for name, losses in loss_values.items():
    if not losses:
        # A log without any loss line has no last step to resume from; skip
        # it and leave the resume point where the previous run put it.
        print(f"no loss values found for {name}; skipping")
        continue

    # Consecutive loss values are placed 10 steps apart.
    steps = [resume_step + 10*i for i in range(len(losses))]
    resume_step = checkpoint_interval * (steps[-1] // checkpoint_interval)
    fig.add_trace(go.Scatter(x=steps, y=losses, name=name))

fig.update_layout(title='Loss Values',
                  xaxis_title='Steps',
                  yaxis_title='Loss')
fig.show()

In [None]:
# Logs whose loss values are concatenated, in order, into one AWS curve.
runs = [
    "/fsx_0/user/tranx/output/slurm_logs/output_6840.txt",
    "/fsx_0/user/tranx/output/slurm_logs/output_6926.txt",
    "/fsx_0/user/tranx/output/slurm_logs/output_6939.txt"
]

aws_loss = []
for log in runs:
    aws_loss += get_loss_values(log)

# One x value per loss value, spaced 10 apart and starting at 0.
aws_steps = list(range(0, 10 * len(aws_loss), 10))

In [None]:
# pandas is imported here because this cell runs before the notebook's other
# pandas import when the notebook is executed top to bottom.
import pandas as pd

# Reference loss curve exported as CSV; the columns read are "step" and "loss".
df = pd.read_csv("~/train_loss_f578009631.csv")
fbl_steps = df["step"]
fbl_loss = df["loss"]

In [None]:
# Hand-copied FBL points, kept as an alternative to the CSV-backed series:
# fbl_steps = [10, 70, 800, 1000, 2000]  # , 4000]
# fbl_loss = [9.581, 2.498, 1.348, 1.225, 1.121]  # , 1.065]

fig = go.Figure()

# One line trace per platform, drawn in this order.
for label, xs, ys in (("AWS", aws_steps, aws_loss), ("FBL", fbl_steps, fbl_loss)):
    fig.add_trace(go.Scatter(x=xs, y=ys, name=label))
# Marker-only variant for the FBL series:
# fig.add_trace(go.Scatter(x=fbl_steps, y=fbl_loss, name="FBL",
#               mode="markers", marker_symbol="star", marker_size=10))

# Horizontal legend anchored just above the plot area.
legend_above_plot = {"orientation": "h", "yanchor": "bottom", "y": 1.02}
fig.update_layout(
    title='AWS vs. FBL comparison: 70B MM8 train loss',
    xaxis_title='Steps',
    yaxis_title='Loss',
    legend=legend_above_plot,
)
fig.show()

# MM9

## Initial Tests

In [None]:
# Legend label -> Slurm log file of the run to plot.
log_dict = {
    "MM9_70B_2nodes": "/fsx_0/user/tranx/output/slurm_logs/output_6955.txt"
}

loss_values = {}
for name in log_dict:
    log = log_dict[name]
    print(name, log)
    loss_values[name] = get_loss_values(log)

fig = go.Figure()
for name in loss_values:
    losses = loss_values[name]
    # Consecutive loss values are placed 10 steps apart, starting at 0.
    steps = list(range(0, 10 * len(losses), 10))
    fig.add_trace(go.Scatter(x=steps, y=losses, name=name))

fig.update_layout(
    title='Train Loss',
    xaxis_title='Step',
    yaxis_title='Loss',
)
fig.show()

In [None]:
# Slurm job ids whose logs are concatenated, in order, into a single curve.
runs = [6973, 6977]

losses = []
for j in runs:
    log = f"/fsx_0/user/tranx/output/slurm_logs/output_{j}.txt"
    print(f"reading losses from {log}")
    losses += get_loss_values(log)

name = "MM9_70B_Llama3.1"
# Consecutive loss values are placed 10 steps apart, starting at 0.
steps = list(range(0, 10 * len(losses), 10))

fig = go.Figure()
fig.add_trace(go.Scatter(x=steps, y=losses, name="MM9_70B_Llama3.1"))
fig.update_layout(
    title=f'Train Loss: {name}',
    xaxis_title='Step',
    yaxis_title='Loss',
)
fig.show()

## MM9_70B_Llama3.1_336px_128nodes

In [None]:
# FBL config

# Slurm job ids whose logs are concatenated, in order, into a single curve.
runs = [7024]

losses = []
for j in runs:
    log = f"/fsx_0/user/tranx/output/slurm_logs/output_{j}.txt"
    print(f"reading losses from {log}")
    losses += get_loss_values(log)

plot_train_loss(
    {"MM9_70B_Llama3.1": losses},
    title="MM9_70B_Llama3.1_336px_128nodes",
)

## MM9 - all runs

In [2]:
import pandas as pd

# https://www.internalfb.com/mlhub/pipelines/runs/fblearner/587797729?tab=Visualizations
# Other exports that were read here previously:
#   "/data/home/tranx/run_516421604076231-4-train_loss.csv"
#   "/data/home/tranx/train_loss_f578009631_0806.csv"
#   "train_loss_f578009631_0806.csv"
reference_csv = "/fsx_0/user/tranx/experiments/llm_mm_aligner/fb_reference_jobs/train_loss_f578009631_0809.csv"

# The first row of the file is skipped and the columns are named explicitly.
df = pd.read_csv(reference_csv, skiprows=1, header=None)
df.columns = ['step', 'loss']
df.head()

Unnamed: 0,step,loss
0,1,9.6633
1,20,6.3737
2,40,2.7365
3,50,2.4793
4,60,2.3727


In [45]:
import json
import glob

In [55]:
def get_loss_from_checkpoint(checkpoint_dir):
    """Read the logged loss history from the highest-numbered checkpoint.

    Looks for ``checkpoint-<N>`` entries under ``checkpoint_dir``, picks the
    largest ``N`` (0 if none is found), prints it, and loads
    ``checkpoint-<N>/trainer_state.json``.

    Args:
        checkpoint_dir: Directory containing the ``checkpoint-<N>`` folders.

    Returns:
        ``(steps, losses)``: two lists of equal length taken from the
        ``log_history`` entries that carry both a ``step`` and a ``loss``.

    Raises:
        FileNotFoundError: If the selected ``trainer_state.json`` is missing.
    """
    max_checkpoint = 0
    for path in glob.glob(f"{checkpoint_dir}/checkpoint-*"):
        suffix = path.split('-')[-1]
        # Skip entries such as "checkpoint-tmp" whose suffix is not a number.
        if not suffix.isdecimal():
            continue
        max_checkpoint = max(int(suffix), max_checkpoint)

    print("max_checkpoint:", max_checkpoint)

    trainer_state = f"{checkpoint_dir}/checkpoint-{max_checkpoint}/trainer_state.json"
    with open(trainer_state, 'r') as f:
        state = json.load(f)

    # Filter once so the two lists stay aligned even when some log entries
    # have no "loss" key.
    entries = [e for e in state['log_history'] if 'step' in e and 'loss' in e]
    steps = [e['step'] for e in entries]
    losses = [e['loss'] for e in entries]

    return steps, losses


# Checkpoint directory whose logged loss history is plotted as one of the curves below.
lm_dir = "/fsx_0/checkpoints/tranx/MM9-Pretrain-70B/Llama31_336px_128nodes_bz32_scratch"
# Steps and losses are read from the highest-numbered checkpoint's trainer_state.json.
lm31_steps, lm31_losses = get_loss_from_checkpoint(lm_dir)

max_checkpoint: 4200


In [57]:
def plot_scheme_loss(scheme, yrange=None):
    """Draw one loss curve per entry of ``scheme`` on a shared figure.

    Args:
        scheme: Iterable of dicts, one per curve. Keys read here:

            * ``"name"``: legend label.
            * ``"jobs"`` (optional): Slurm job ids; their logs are parsed and
              concatenated, in order, into ``"loss"``.
            * ``"loss"``: loss values, required when ``"jobs"`` is absent.
            * ``"steps"`` (optional): explicit x values.
            * ``"resume_step"`` / ``"step_scale"``: start and spacing of the
              generated x values, used only when ``"steps"`` is absent.
            * ``"color"`` (optional): line colour.

        yrange: Optional ``[min, max]`` range for the y axis.

    Side effects:
        Every entry that has ``"jobs"`` gets its ``"loss"`` key overwritten
        with the values parsed from the logs.
    """
    fig = go.Figure()

    # Iterate the argument, not a module-level variable, so the function
    # plots whatever list the caller passes in.
    for s in scheme:
        # read loss from jobs
        if "jobs" in s:
            s["loss"] = []
            for j in s["jobs"]:
                log = f"/fsx_0/user/tranx/output/slurm_logs/output_{j}.txt"
                print(f"Reading loss from {log}")
                s["loss"].extend(get_loss_values(log))

        if 'steps' in s:
            steps = s['steps']
        else:
            steps = [s["resume_step"] + i*s["step_scale"]
                     for i in range(len(s["loss"]))]

        trace_kwargs = dict(x=steps, y=s["loss"], name=s["name"], mode='lines')
        if 'color' in s:
            trace_kwargs["line"] = dict(color=s['color'])
        fig.add_trace(go.Scatter(**trace_kwargs))

    fig.update_layout(
        title='Train Loss', xaxis_title='Step', yaxis_title='Loss',
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
    )

    if yrange is not None:
        fig.update_layout(yaxis=dict(range=yrange))

    fig.show()


# One dict per curve passed to plot_scheme_loss.
#   "name"         legend label; several entries may share one label
#   "name_comment" descriptive tag only; plot_scheme_loss does not read it
#   "jobs"         Slurm job ids whose logs supply the loss values
#   "loss"/"steps" explicit y/x values, used when no "jobs" are given
#   "resume_step"  first x value when "steps" is absent
#   "step_scale"   x spacing between consecutive values when "steps" is absent
#   "color"        line colour
schemes = [
    {
        "name": "f587797729_70B_Llama3.1_336px_128nodes_bz64",
        "name_comment": "#5 f587797729_70B_Llama3.1_336px_128nodes_bz64",
        "loss": df['loss'],
        "steps": df['step'],
        "resume_step": 0,
        "step_scale": 10,  # 10 for bz=32, gradient_accumulation_step=4, nodes=128
        "color": "black"
    },
    {
        "name": "MM9_70B_Llama3.1_336px_128nodes_bz32",
        "name_comment": "#1.2 MM9_70B_Llama3.1_336px_128nodes_bz32_retrain",
        # something missing between 8231 and 8390
        # "jobs": [7260, 7499, 8231, 8390, 8407],
        # "jobs": [7260, 7499],
        "loss": lm31_losses,
        "steps": lm31_steps,
        "resume_step": 0,
        "step_scale": 10,  # 10 for bz=32, gradient_accumulation_step=4, nodes=128
        "color": "blue"
    },
    {
        "name": "MM9_70B_MH19_336px",
        "name_comment": "#3.2 MM9_70B_MH19_336px_128nodes_fixed",
        "jobs": [7144],
        "resume_step": 3300,
        "step_scale": 10,
        "color": "red"
    },
    # {
    #     "name": "#3.3 MM9_70B_MH19_336px_128nodes_retrain",
    #     "jobs": [7261],
    #     "resume_step": 0,
    #     "step_scale": 10,
    #     "color": "red"
    # },
    # {
    #     "name": "#3.5 MM9_70B_MH19_336px_384nodes",
    #     "jobs": [7343],
    #     "resume_step": 2100,
    #     "step_scale": 30,
    #     "color": "red"
    # },
    {
        "name": "MM9_70B_MH19_336px",
        "name_comment": "#3.6 MM9_70B_MH19_336px_256nodes",
        "jobs": [7465],
        "resume_step": 7500,
        "step_scale": 20,
        "color": "red"
    },
    {
        "name": "MM9_70B_MH19_336px",
        "name_comment": "#3.7 MM9_70B_MH19_336px_128nodes_resume",
        "jobs": [7700, 8351],
        "resume_step": 8500,
        "step_scale": 10,
        "color": "red"
    },

]

# Restrict the y axis to the 0.9-1.4 band.
plot_scheme_loss(schemes, yrange=[0.9, 1.4])

Reading loss from /fsx_0/user/tranx/output/slurm_logs/output_7144.txt
Reading loss from /fsx_0/user/tranx/output/slurm_logs/output_7465.txt
Reading loss from /fsx_0/user/tranx/output/slurm_logs/output_7700.txt
Reading loss from /fsx_0/user/tranx/output/slurm_logs/output_8351.txt


In [None]:
def plot_train_loss2(
    loss_dict: Dict[str, List[float]],
    title: Optional[str] = None,
    skip_steps: Optional[Dict[str, int]] = None
):
    """Plot training-loss curves whose x spacing depends on the run.

    Args:
        loss_dict: Maps a legend label to its sequence of loss values::

            {
                name: [loss array],
                name2: [loss array2],
                ...
            }

        title: Text appended to the figure title as ``Train Loss: <title>``.
            When omitted, the title is just ``Train Loss``.
        skip_steps: Optional per-label x offset; a curve whose label is
            missing from the mapping starts at step 0.

    The y axis is fixed to the range 0.8-1.4.
    """
    if skip_steps is None:
        skip_steps = {}

    # x spacing between consecutive loss values for the labels that do not
    # use the default of 10.
    log_step_by_name = {
        "MM9_70B_Llama3.1_336px_256nodes": 20,
        "#3.6 MM9_70B_MH19_336px_256nodes_bz32": 20,
        "#3.5 MM9_70B_MH19_336px_384nodes_bz32": 30,
    }

    fig = go.Figure()

    for name, loss_array in loss_dict.items():
        skip = skip_steps.get(name, 0)
        log_step = log_step_by_name.get(name, 10)

        steps = [skip + log_step*i for i in range(len(loss_array))]
        fig.add_trace(go.Scatter(
            x=steps, y=loss_array, name=name, mode='lines'))

    fig.update_layout(
        title=f'Train Loss: {title}' if title else 'Train Loss',
        xaxis_title='Step',
        yaxis_title='Loss',
        # Horizontal legend anchored just above the plot area.
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
        yaxis=dict(range=[0.8, 1.4]),
    )
    fig.show()


# Legend label -> Slurm job ids whose logs are concatenated for that curve.
# Commented-out entries are runs currently left out of the comparison.
runs = {
    # "#1 MM9_70B_Llama3.1_336px_128nodes": [7024],
    # "#1.1 MM9_70B_Llama3.1_336px_128nodes_fixed": [7204],
    "#1.2 MM9_70B_Llama3.1_336px_128nodes_bz32_retrain": [7260],
    # "#1.3 MM9_70B_Llama3.1_336px_128nodes_fixed_norm_loss": [7304],
    # "#2 MM9_70B_Llama3.1_504px_128nodes": [7047],
    # "#3 MM9_70B_MH19_336px_128nodes": [7044, 7118],
    # "#3.1 MM9_70B_MH19_336px_128nodes_bz64": [7119],
    "#3.2 MM9_70B_MH19_336px_128nodes_fixed": [7144],
    "#3.3 MM9_70B_MH19_336px_128nodes_bz32_retrain": [7261],
    # "#3.4 MM9_70B_MH19_336px_128nodes_bz48_retrain": [7339],
    "#3.5 MM9_70B_MH19_336px_384nodes_bz32": [7343],
    "#3.6 MM9_70B_MH19_336px_256nodes_bz32": [7465],

    # "#4 MM9_70B_Llama3.1_336px_256nodes": [7083]

    # "#1.3 MM9_70B_Llama3.1_336px_128nodes_fixed_norm_loss": [7304],
}

# Legend label -> x offset of that curve's first value. Labels that are not
# listed start at 0; entries for labels absent from `runs` have no effect.
skip_steps = {
    "#1.1 MM9_70B_Llama3.1_336px_128nodes_fixed": 1200,
    "#3.1 MM9_70B_MH19_336px_128nodes_bz64": 2300,
    "#3.2 MM9_70B_MH19_336px_128nodes_fixed": 3300,  # 1000
    "#4 MM9_70B_Llama3.1_336px_256nodes": 3000,
    "#3.5 MM9_70B_MH19_336px_384nodes_bz32": 2100,
    # "#1.2 MM9_70B_Llama3.1_336px_128nodes__retrain": -100,
    # "#3.3 MM9_70B_MH19_336px_128nodes_fixed_retrain": -100,
    "#3.6 MM9_70B_MH19_336px_256nodes_bz32": 7500
}


# Build label -> concatenated loss values, reading each job's log in order.
loss_dict = {}
for name, job_ids in runs.items():
    collected = []
    for j in job_ids:
        log = f"/fsx_0/user/tranx/output/slurm_logs/output_{j}.txt"
        print(f"Reading loss from {log}")
        collected += get_loss_values(log)
    loss_dict[name] = collected

# add fb reference loss
loss_dict["#5 f587797729_70B_Llama3.1_336px_128nodes_bz64"] = df['loss']

plot_train_loss2(loss_dict, skip_steps=skip_steps)

# Train Throughput

In [None]:
import os


def get_checkpoint_time(dir):
    """Print the time elapsed between consecutive checkpoint folders.

    Starting from the timestamp of ``<dir>/runs``, the ``checkpoint-<step>``
    folders are visited in increasing step order and, for each one, the step
    and the seconds elapsed since the previous timestamp are printed.

    Timestamps come from ``os.path.getctime``.
    NOTE(review): on Unix this is the last metadata-change time rather than
    the creation time — confirm it is the intended clock.

    Args:
        dir: Directory holding a ``runs`` entry and the checkpoint folders.
            (The name shadows the builtin; it is kept for existing callers.)
    """
    previous_ts = os.path.getctime(os.path.join(dir, "runs"))

    checkpoints = []
    for entry in os.listdir(dir):
        if not entry.startswith('checkpoint'):
            continue
        suffix = entry.split('-')[-1]
        # Ignore entries without a numeric step suffix instead of failing
        # on int().
        if suffix.isdecimal():
            checkpoints.append((int(suffix), entry))

    # Sort by step number so elapsed times follow training order.
    for step, folder in sorted(checkpoints):
        ts = os.path.getctime(os.path.join(dir, folder))
        elapsed_ts = ts - previous_ts
        previous_ts = ts

        print(step, elapsed_ts)

In [None]:
# Checkpoint directories of the two runs whose checkpoint timings are printed.
dir160 = '/fsx_0/checkpoints/tranx/Aligner-Pretrain-70B/output_n160_retrain'
dir256 = '/fsx_0/checkpoints/tranx/Aligner-Pretrain-70B/output_n256'

for checkpoint_root in (dir160, dir256):
    get_checkpoint_time(checkpoint_root)

In [None]:
import numpy as np

# Three timing samples per configuration, averaged and truncated to an int.
# NOTE(review): presumably seconds between checkpoints, copied from the
# get_checkpoint_time printout — confirm.
samples_160 = [6737.0, 6805.0, 6970.0]
samples_256 = [8526.0, 8497, 8668]

t160 = int(np.mean(samples_160))
t256 = int(np.mean(samples_256))

t160, t256

In [None]:
# Throughput of the 160-node configuration, in millions of images per day.
nodes = 160
time_per_200_steps = t160

# presumably steps x per-GPU batch size x nodes x GPUs per node — TODO confirm
num_images_per_200_steps = 200 * 32 * nodes * 8

# Scale the 200-step image count to a full day of 86400 seconds.
num_images_per_day = (num_images_per_200_steps * 86400) / time_per_200_steps
mm_images_per_day = num_images_per_day / 1e6
int(mm_images_per_day)

In [None]:
# Throughput of the 256-node configuration, in millions of images per day.
nodes = 256
time_per_200_steps = t256

# presumably steps x per-GPU batch size x nodes x GPUs per node — TODO confirm
num_images_per_200_steps = 200 * 32 * nodes * 8

# Scale the 200-step image count to a full day of 86400 seconds.
num_images_per_day = (num_images_per_200_steps * 86400) / time_per_200_steps
mm_images_per_day = num_images_per_day / 1e6
int(mm_images_per_day)

In [None]:
# fbl per 7 days
steps = 11840
days = 7.768

# Average steps per day over the whole period, converted to images per day.
steps_per_day = steps / days
# presumably per-GPU batch size x nodes x GPUs per node — TODO confirm
num_images_per_day = ((steps_per_day * 32) * 160) * 8
mm_images_per_day = num_images_per_day / 1e6
int(mm_images_per_day)

In [None]:
# fbl per 1st day
steps = 1700
days = 1

# Steps completed in the first day, converted to images per day.
steps_per_day = steps / days
# presumably per-GPU batch size x nodes x GPUs per node — TODO confirm
num_images_per_day = ((steps_per_day * 32) * 160) * 8
mm_images_per_day = num_images_per_day / 1e6
int(mm_images_per_day)

In [None]:
# Images processed between step 19500 and step 45739.
steps_trained = 45739 - 19500
num_images = steps_trained * 32 * 64 * 8  # step x batch_size x nodes x GPUs

trained_time = 24 + 14  # hours
trained_days = trained_time/24

# Images per day, then the same figure in millions.
tput_per_day = num_images / trained_days
tput_M_per_day = tput_per_day / 1e6
print(tput_M_per_day)