In [1]:
import pandas as pd
import neptune
import plotly.express as px
import plotly.graph_objects as go

In [2]:
from typing import Optional, Any, Tuple
from ast import literal_eval as make_tuple


# Read-only handle to the Neptune project whose runs are analysed below.
# The API token is deliberately NOT hard-coded: when `api_token` is omitted,
# neptune falls back to the NEPTUNE_API_TOKEN environment variable, so export
# that variable before starting the kernel.
# NOTE(review): the token that used to be committed here should be revoked.
project = neptune.init_project(
  project="pmtest/llm-random",
  mode="read-only",
)
# Run fields requested from Neptune by every `fetch_runs_table` call below.
columns = [
    "sys/tags",                # comma-separated tags, parsed by the infere_* helpers
    "loss_interval/100",       # renamed to "loss"
    "args/learning_rate",      # renamed to "lr"
    "sys/name",
    "args/grad_modif_params",  # "key=value,..." string, parsed by the infere_* helpers
    "step",
    "sys/id",
]

# Tags that identify where in the model the gradient modification is placed.
placements = {"post_attn_and_ff", "post_norm", "post_add", "all"}

def rename_to_common(df: pd.DataFrame) -> pd.DataFrame:
    """Return a new frame with the long Neptune column names shortened.

    ``loss_interval/100`` becomes ``loss``, ``args/learning_rate`` becomes
    ``lr`` and ``args/grad_modif_params`` becomes ``grad_modif_params``.
    Any other column is kept as is, and ``df`` itself is left untouched.
    """
    short_names = {
        "loss_interval/100": "loss",
        "args/learning_rate": "lr",
        "args/grad_modif_params": "grad_modif_params",
    }
    return df.rename(columns=short_names)


def infere_layer_type(tags: str) -> str:
    """Classify a run as ``"baseline"`` or ``"regular"`` from its tag string.

    ``tags`` is a comma-separated list of tags; the run counts as a baseline
    only when one of the tags is exactly ``true_baseline``.
    """
    is_baseline = "true_baseline" in tags.split(',')
    return "baseline" if is_baseline else "regular"

def infere_placement_type(tags: str) -> Optional[str]:
    """Return the first tag of a run that names a known placement.

    ``tags`` is a comma-separated list of tags. Tags are checked in order
    against the module-level ``placements`` set; ``None`` is returned when
    none of them matches.
    """
    return next((tag for tag in tags.split(',') if tag in placements), None)


def _infere_float_param(grad_modif_params: Any, name: str) -> Optional[float]:
    """Return the value stored under ``name`` in a ``key=value,...`` string.

    Returns ``None`` when ``name`` is absent or when the input is not a string
    (pandas represents a missing value as a float NaN). Fragments without an
    ``=`` are skipped instead of raising: they appear for an empty string and
    when a tuple value such as ``norm_dims=(0, 1)`` is split on its commas.

    Raises:
        ValueError: if the value found for ``name`` is not a valid float.
    """
    if not isinstance(grad_modif_params, str):
        return None
    for fragment in grad_modif_params.split(','):
        key, sep, val = fragment.partition('=')
        if sep and key.strip() == name:
            return float(val)
    return None


def infere_c(grad_modif_params: str) -> Optional[float]:
    """Return the ``c`` entry of a ``grad_modif_params`` string, or ``None`` if absent."""
    return _infere_float_param(grad_modif_params, "c")


def infere_eps(grad_modif_params: str) -> Optional[float]:
    """Return the ``eps`` entry of a ``grad_modif_params`` string, or ``None`` if absent."""
    return _infere_float_param(grad_modif_params, "eps")

def infere_norm_dims(grad_modif_params: str) -> Optional[Tuple[int, ...]]:
    """Extract the parenthesised ``norm_dims`` value from a parameter string.

    The first ``(...)`` group that follows the text ``norm_dims`` is evaluated
    as a Python literal and returned as a tuple.

    Returns:
        The dimensions as a tuple, or ``None`` when ``norm_dims`` does not
        occur in the string or the input is not a string (pandas represents a
        missing value as a float NaN).

    Raises:
        ValueError: if ``norm_dims`` is present but is not followed by a
            ``(...)`` group.
    """
    if not isinstance(grad_modif_params, str):
        return None

    start_idx = grad_modif_params.find("norm_dims")
    if start_idx == -1:
        return None

    open_idx = grad_modif_params.find("(", start_idx)
    close_idx = grad_modif_params.find(")", open_idx) if open_idx != -1 else -1
    if open_idx == -1 or close_idx == -1:
        raise ValueError(f"norm_dims has no parenthesised value in {grad_modif_params!r}")

    dims = make_tuple(grad_modif_params[open_idx:close_idx + 1])
    # "(0)" is a parenthesised int, not a one-element tuple, so wrap scalars.
    if isinstance(dims, (tuple, list)):
        return tuple(dims)
    return (dims,)

def common_line_plot(df: pd.DataFrame, x: str, y: str, color: str, title: str, trace: Optional[Any] = None) -> go.Figure:
    """Build the log-log line plot used throughout this notebook.

    Args:
        df: Data to plot.
        x: Column of ``df`` plotted on the (logarithmic) x axis.
        y: Column of ``df`` plotted on the (logarithmic) y axis.
        color: Column of ``df`` whose values split the data into coloured lines.
        title: Figure title.
        trace: Optional extra trace (e.g. a baseline ``go.Scatter``) added on
            top of the lines; nothing is added when it is ``None``.

    Returns:
        The figure, with markers enabled and both axes using power-style
        exponent tick labels.
    """
    fig = px.line(df, x=x, y=y, color=color, markers=True, log_x=True, log_y=True)

    # Test for None explicitly so the decision never depends on the
    # truthiness of the trace object itself.
    if trace is not None:
        fig.add_trace(trace)

    power_ticks = dict(showexponent='all', exponentformat='power')
    fig.update_layout(
        title=title,
        yaxis=dict(power_ticks),
        xaxis=dict(power_ticks),
    )

    return fig

Neptune project: https://app.neptune.ai/pmtest/llm-random/


#### Baseline loss vs lr categorised by eps, total_steps

In [3]:
# Runs tagged "true_baseline", with short column names, an `eps` column parsed
# from grad_modif_params, sorted by learning rate. Built as one chain so that
# no column is assigned on a filtered slice (SettingWithCopyWarning) and
# nothing is mutated in place.
baseline_df = (
    rename_to_common(
        project.fetch_runs_table(tag="true_baseline", columns=columns).to_pandas()
    )
    # Runs without grad_modif_params carry no eps to parse, so drop them.
    .dropna(subset=["grad_modif_params"])
    .assign(eps=lambda runs: runs["grad_modif_params"].apply(infere_eps))
    .sort_values("lr")
)

In [4]:
# One line per `step` value: baseline loss as a function of learning rate.
fig = common_line_plot(
    baseline_df,
    x="lr",
    y="loss",
    color="step",
    title="Baseline Loss vs LR. Categorised by eps, total_steps",
)
fig.show()

# Short Experiments (2k steps)

#### Loss vs LR categorised by baseline and sanity_checks for v1, v2 std norm

In [5]:
# Neptune tag of the short c/lr/placement grid for each std-norm variant.
grad_std_norm_type_tag = dict(
    std_v1="std_v1_c_lr_grid_placement_short",
    std_v2="std_v2_c_lr_grid_placement_short",
)

# Baseline runs whose `step` equals 2000; reused as the reference curve in the
# short-experiment plots.
baseline_step_2k = baseline_df.loc[baseline_df['step'] == 2000]

for gn_name, gn_tag in grad_std_norm_type_tag.items():
    # Runs of this variant that also carry the "c_0" tag (presumably c = 0,
    # which is why they serve as a sanity check against the baseline).
    sanity_check_df = rename_to_common(
        project.fetch_runs_table(owner="szysad", tag=["c_0", gn_tag], columns=columns).to_pandas()
    )
    sanity_check_df = sanity_check_df.assign(
        eps=sanity_check_df['grad_modif_params'].apply(infere_eps),
        placement=sanity_check_df['sys/tags'].apply(infere_placement_type),
    ).sort_values('lr')

    baseline_trace = go.Scatter(
        x=baseline_step_2k['lr'],
        y=baseline_step_2k['loss'],
        name='baseline',
        line=dict(color='black', width=2, dash='dash'),
    )
    fig = common_line_plot(
        sanity_check_df,
        x="lr",
        y="loss",
        color="placement",
        title=f"Sanity check Loss vs LR categorised by placement with baseline for grad norm '{gn_name}'",
        trace=baseline_trace,
    )
    fig.show()

#### Loss vs LR categorised by c for each placement (short run, v1 and v2 std norm)

In [6]:
plot_cnt = len(grad_std_norm_type_tag) * len(placements)

# `placements` is a set of strings, whose iteration order changes between
# kernel sessions. Sorting keeps the "(k/N)" numbering and the order of the
# figures reproducible.
ordered_placements = sorted(placements)

for i, (gn_name, gn_tag) in enumerate(grad_std_norm_type_tag.items()):
    # NOTE: despite its name, this frame holds the runs of whichever std-norm
    # variant (v1 or v2) the loop is currently on.
    std_v1_df = project.fetch_runs_table(tag=gn_tag, columns=columns).to_pandas()
    std_v1_df = rename_to_common(std_v1_df)
    std_v1_df['layer_type'] = std_v1_df['sys/tags'].apply(infere_layer_type)
    std_v1_df['c'] = std_v1_df['grad_modif_params'].apply(infere_c)
    std_v1_df['placement'] = std_v1_df['sys/tags'].apply(infere_placement_type)
    std_v1_df = std_v1_df.sort_values('lr')

    offset = i * len(ordered_placements)
    for j, p in enumerate(ordered_placements):
        df = std_v1_df[std_v1_df['placement'] == p]
        title = f"({offset + j + 1}/{plot_cnt}) Loss vs Learning Rate categorized by c, for '{p}' placement for grad norm '{gn_name}'"
        baseline_trace = go.Scatter(x=baseline_step_2k['lr'], y=baseline_step_2k['loss'], name='baseline', line=dict(color='black', width=2, dash='dash'))
        fig = common_line_plot(df, x="lr", y="loss", color="c", title=title, trace=baseline_trace)
        fig.show()

#### Loss vs LR categorised by c and norm_dims for each placement in 'scale_norm' grad norm

In [7]:
# Runs of the short scale_norm grid that also carry the "c_0" tag.
sanity_check_df = rename_to_common(
    project.fetch_runs_table(
        owner="szysad",
        tag=["c_0", "scale_norm_c_lr_grid_placement_short"],
        columns=columns,
    ).to_pandas()
)
sanity_check_df = sanity_check_df.assign(
    c=sanity_check_df['grad_modif_params'].apply(infere_c),
    placement=sanity_check_df['sys/tags'].apply(infere_placement_type),
    norm_dims=sanity_check_df['grad_modif_params'].apply(infere_norm_dims),
)
# One legend entry per (placement, c, norm_dims) combination.
sanity_check_df['category'] = sanity_check_df.apply(
    lambda row: f"placement={row.placement}, c={row.c}, norm_dims={row.norm_dims}",
    axis=1,
)
sanity_check_df = sanity_check_df.sort_values('lr')

baseline_trace = go.Scatter(
    x=baseline_step_2k['lr'],
    y=baseline_step_2k['loss'],
    name='baseline',
    line=dict(color='black', width=2, dash='dash'),
)
fig = common_line_plot(
    sanity_check_df,
    x="lr",
    y="loss",
    color="category",
    title="Sanity check Loss vs LR categorised by placement, c and norm_dims for grad norm scale_norm",
    trace=baseline_trace,
)
fig.show()

In [8]:
plot_cnt = len(placements)

# The runs table does not depend on the placement, so fetch and parse it once
# instead of repeating the network call on every loop iteration.
scale_norm_df = project.fetch_runs_table(tag="scale_norm_c_lr_grid_placement_short", columns=columns).to_pandas()
scale_norm_df = rename_to_common(scale_norm_df)
scale_norm_df['c'] = scale_norm_df['grad_modif_params'].apply(infere_c)
scale_norm_df['placement'] = scale_norm_df['sys/tags'].apply(infere_placement_type)
scale_norm_df['norm_dims'] = scale_norm_df['grad_modif_params'].apply(infere_norm_dims)
scale_norm_df['category'] = scale_norm_df.apply(lambda x: f"c={x.c}, norm_dims={x.norm_dims}", axis=1)
scale_norm_df = scale_norm_df.sort_values('lr')

# Sorted iteration keeps the figure order reproducible (set order is not).
# Counting from 1 makes the titles read "(1/N)" .. "(N/N)".
for i, place in enumerate(sorted(placements), start=1):
    df = scale_norm_df[scale_norm_df['placement'] == place]
    title = f"({i}/{plot_cnt}) Loss vs Learning Rate categorized by c and norm_dims for '{place}' placement."
    baseline_trace = go.Scatter(x=baseline_step_2k['lr'], y=baseline_step_2k['loss'], name='baseline', line=dict(color='black', width=2, dash='dash'))
    fig = common_line_plot(df, x="lr", y="loss", color="category", title=title, trace=baseline_trace)
    fig.show()

# Impact of grad normalization on gradients statistics

In [9]:
# TODO: this plot is not implemented yet. Planned layout:
# - y axis: mean of the std of the norms of the raw gradients
# - x axis: learning rate
# - color: c
# - one plot per placement

# Long Experiments (16k steps)

In [10]:
# Runs tagged "post_add_c_lr_grid_long", with the parsed layer type and c
# added and the columns used for plotting given short names.
table_long = project.fetch_runs_table(tag="post_add_c_lr_grid_long", columns=columns).to_pandas()
table_long = table_long.assign(
    layer_type=table_long['sys/tags'].apply(infere_layer_type),
    c=table_long['args/grad_modif_params'].apply(infere_c),
).rename(columns={
    'loss_interval/100': 'loss',
    'args/learning_rate': 'lr',
    'sys/id': 'id',
})

# Split into the baseline curve and everything else, each ordered for plotting.
is_baseline = table_long['layer_type'] == 'baseline'
baseline_long = table_long[is_baseline].sort_values(by='lr')
rest_long = table_long[~is_baseline].sort_values(by=['lr', 'c'])


In [11]:
# Dashed black reference curve built from the long baseline runs.
baseline_trace = go.Scatter(
    x=baseline_long['lr'],
    y=baseline_long['loss'],
    name='baseline',
    line=dict(color='black', width=2, dash='dash'),
)
# Title typo fixed ("Typles" -> "Types").
fig = common_line_plot(
    rest_long,
    x="lr",
    y="loss",
    color="c",
    title="Loss vs Learning Rate for All Layer Types and Placements",
    trace=baseline_trace,
)
fig.show()