In [1]:
import pandas as pd
import json
import plotly.express as px
import plotly.graph_objects as go
from pathlib import Path
from functools import partial
import numpy as np


In [2]:
from neptune.internal.utils.logger import get_logger
get_logger().setLevel('ERROR')

In [3]:
from typing import Optional, Any, Tuple, Dict, List, Iterable, Union
from ast import literal_eval as make_tuple
from functools import reduce

import neptune.api
import neptune.attributes


NEPTUNE_PROJECT = "pmtest/llm-random"
STD_TAIL_QUANTILE = 0.9
STD_TAIL_KEY = f"std_q{STD_TAIL_QUANTILE}"


project = neptune.init_project(
  project=NEPTUNE_PROJECT,
  mode="read-only"
)
columns = [
    "sys/tags",
    "loss_interval/100",
    "args/learning_rate",
    "sys/name",
    "args/grad_modif_params",
    "step",
    "sys/id",
    "sys/failed"
]

placements = set([
    "post_attn_and_ff",
    "post_norm",
    "post_add",
    "all",
])


class SeriesCache:
    def __init__(self, index_path: Path, data_dir: Path):
        self.index_path = index_path
        self.data_dir = data_dir
        self.index = {}
        if not self.index_path.exists():
            with open(self.index_path, 'w') as f:
                json.dump({}, f)
        else:
            with open(self.index_path, 'r') as f:
                self.index = json.load(f)

        if not self.data_dir.exists():
            data_dir.mkdir(parents=True)
    
    def add_entry(self, path: List[str], series: pd.DataFrame, overwrite: bool = False):
        sub_data = self.index
        for p in path[:-1]:
            if p not in sub_data:
                sub_data[p] = {}
            sub_data = sub_data[p]
        
        if path[-1] in sub_data and not overwrite:
            return

        fname = '-'.join(path) + ".pkl"
        series.to_pickle(self.data_dir / fname)
        sub_data[path[-1]] = fname
    
    def get_entry(self, path: List[str]) -> pd.DataFrame:
        sub_data = self.index
        for p in path[:-1]:
            if p not in sub_data:
                return None
            sub_data = sub_data[p]
        
        if path[-1] not in sub_data:
            return None
        
        fname = sub_data[path[-1]]
        return pd.read_pickle(self.data_dir / fname)
    
    def save(self):
        with open(self.index_path, 'w') as f:
            json.dump(self.index, f)

series_cache = SeriesCache(Path("./grad_series_index.json"), Path("./grad_series_data"))

def rename_to_common(df: pd.DataFrame):
    return df.rename(columns={"loss_interval/100": "loss", "args/learning_rate": "lr", "args/grad_modif_params": "grad_modif_params"}, inplace=False)


def infere_layer_type(tags: str):
    tags = tags.split(',')
    if "true_baseline" in tags:
        return "baseline"
    else:
        return "regular"

def infere_k(grad_modif_params: str):
    params = grad_modif_params.split(',')
    for p in params:
        key, val = p.split('=')
        if key == "k":
            if isinstance(val, str):
                return val
            return float(val)
    return None

def infere_placement_type(tags: str):
    tags = tags.split(',')
    for t in tags:
        if t in placements:
            return t


def infere_c(grad_modif_params: str):
    params = grad_modif_params.split(',')
    for p in params:
        key, val = p.split('=')
        if key == "c":
            return float(val)
    return None

def infere_eps(grad_modif_params: str):
    params = grad_modif_params.split(',')
    for p in params:
        key, val = p.split('=')
        if key == "eps":
            return float(val)
    return None

def infere_norm_dims(grad_modif_params: str) -> Tuple[int, ...]:
    start_idx = grad_modif_params.find("norm_dims")
    if start_idx == -1:
        return None

    next_opening_bracket_idx = grad_modif_params.find("(", start_idx)
    next_closing_bracket_idx = grad_modif_params.find(")", start_idx)

    return tuple(make_tuple(grad_modif_params[next_opening_bracket_idx:next_closing_bracket_idx+1]))

def extract_structure_metrics(structure: Dict[str, Any], key: str, path: List[str] = []) -> Iterable[Tuple[str, str, Any]]:
    for k, v in structure.items():
        if k == key:
            yield '.'.join(path), k, v
        elif isinstance(v, dict):
            yield from extract_structure_metrics(v, key, path + [k])


def fetch_run_mean_grad_stds_series(run_id: str) -> pd.DataFrame:
    with neptune.init_run(project=NEPTUNE_PROJECT, with_id=run_id, mode="read-only") as run:
        structure = run.get_structure()
        grad_norms = extract_structure_metrics(structure, "raw_grad_norms")
        dfs = []
        for path, _, series in grad_norms:
            series_df = series['std'].fetch_values(include_timestamp=False)
            series_df.rename(columns={'value': path}, inplace=True)
            dfs.append(series_df)
        
        df_merged = reduce(lambda left, right: pd.merge(left, right, on=['step'], how='outer'), dfs)
        df_merged['mean_stds'] = df_merged.drop(columns=['step']).mean(axis=1)
        return df_merged[['step', 'mean_stds']]


def fetch_and_save_run_grad_series(run_id: str, series_cache: SeriesCache, overwrite: bool = False):
    if run_id in series_cache.index and not overwrite:
        return
    with neptune.init_run(project=NEPTUNE_PROJECT, with_id=run_id, mode="read-only") as run:
        structure = run.get_structure()
        grad_norms = extract_structure_metrics(structure, "raw_grad_norms")
        statistics = ['std', 'mean']
        dfs = dict(zip(statistics, [list() for _ in statistics]))
        for path, _, series in grad_norms:
            for statistic in statistics:
                series_df = series[statistic].fetch_values(include_timestamp=False)
                series_df.rename(columns={'value': path}, inplace=True)
                dfs[statistic].append(series_df)
        
        for statistic in statistics:
            df_merged = reduce(lambda left, right: pd.merge(left, right, on=['step'], how='outer'), dfs[statistic])
            series_cache.add_entry([run_id, statistic], df_merged, overwrite=overwrite)
        
        series_cache.save()
    
def get_run_mean_grad_std_quantile_diff(run_id: str, series_cache: SeriesCache, qbase: float, qref: float) -> pd.DataFrame:
    if run_id not in series_cache.index:
        fetch_and_save_run_grad_series(run_id, series_cache)
    grad_std_series = series_cache.get_entry([run_id, 'std'])
    std_norm_qbase = grad_std_series.drop(columns=['step']).quantile(q=qbase, axis=1)
    std_norm_qref = grad_std_series.drop(columns=['step']).quantile(q=qref, axis=1)
    std_norm_quantile = (std_norm_qref - std_norm_qbase) / std_norm_qbase
    return std_norm_quantile.to_list()

def get_run_std_quantile_series(run_id: str, series_cache: SeriesCache, q: float) -> pd.DataFrame:
    if run_id not in series_cache.index:
        fetch_and_save_run_grad_series(run_id, series_cache)
    grad_std_series = series_cache.get_entry([run_id, 'std'])
    std_norm_q = grad_std_series.drop(columns=['step']).quantile(q=q, axis=1)
    return std_norm_q.to_list()

def get_run_mean_grad_std_steps(run_id: str, series_cache: SeriesCache) -> pd.DataFrame:
    if run_id not in series_cache.index:
        fetch_and_save_run_grad_series(run_id, series_cache)
    grad_std_series = series_cache.get_entry([run_id, 'std'])
    return grad_std_series['step'].to_list()

def fetch_run_mean_grad_last_std_series(run_id: str) -> pd.DataFrame:
    with neptune.init_run(project=NEPTUNE_PROJECT, with_id=run_id, mode="read-only") as run:
        structure = run.get_structure()
        grad_norms = extract_structure_metrics(structure, "raw_grad_norms")
        last_stds = []
        for _, _, series in grad_norms:
            last_std = series['std'].fetch_last()
            last_stds.append(last_std)
        
        return sum(last_stds) / len(last_stds)

def common_line_plot(df: pd.DataFrame, x: str, y: str, color: str, title: str, trace: Optional[Any] = None) -> go.Figure:
    fig = px.line(df, x=x, y=y, color=color, markers=True, log_x=True, log_y=True)

    if trace:
        fig.add_trace(trace)

    fig.update_layout(
        title=title,
        yaxis = dict(
            showexponent = 'all',
            exponentformat = 'power'
        ),
        xaxis = dict(
            showexponent = 'all',
            exponentformat = 'power'
        )
    )

    return fig

def common_plot_nested_list(df: pd.DataFrame, color: str, x: Union[str, List[float]], y: Union[str, List[float]], title: str, trace: Optional[Any] = None) -> go.Figure:
    fig = go.Figure()
    runs = df[color].unique()
    for run in runs:
        run_df = df[df[color] == run]
        assert len(run_df) == 1
        x_vals = run_df[x].values[0] if isinstance(x, str) else x
        y_vals = run_df[y].values[0] if isinstance(y, str) else y
        fig.add_trace(go.Scatter(x=x_vals, y=y_vals, mode='lines', name=run, hovertext=run))
    
    if trace:
        fig.add_trace(trace)

    fig.update_layout(
        title=title,
        yaxis = dict(
            showexponent = 'all',
            exponentformat = 'power',
            type='log'
        ),
        #xaxis = dict(
        #    showexponent = 'all',
        #    exponentformat = 'power'
        #)
    )

    return fig

#### Baseline loss vs lr categorised by eps, total_steps

In [4]:
baseline_df = project.fetch_runs_table(tag="true_baseline", columns=columns).to_pandas()
baseline_df = rename_to_common(baseline_df)
baseline_df = baseline_df[baseline_df['grad_modif_params'].notna()]
baseline_df['eps'] = baseline_df['grad_modif_params'].apply(infere_eps)
baseline_df['sys/id'].apply(partial(fetch_and_save_run_grad_series, series_cache=series_cache, overwrite=False))

baseline_df.sort_values('lr', inplace=True)

In [5]:
fig = common_line_plot(baseline_df, x="lr", y="loss", color="step", title="Baseline Loss vs LR. Categorised by eps, total_steps")
fig.show()

In [6]:
baseline_df_optimal_lr = baseline_df[baseline_df['lr'] == 1e-3].copy()
baseline_df_optimal_lr['steps'] = baseline_df_optimal_lr['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
baseline_df_optimal_lr[STD_TAIL_KEY] = baseline_df_optimal_lr['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
fig = common_plot_nested_list(baseline_df_optimal_lr, color='sys/id', x="steps", y=STD_TAIL_KEY, title="Baseline std_tail_change vs steps")
fig.show()

baseline_grad_norm_std_long = go.Scatter(x=baseline_df_optimal_lr['steps'].values[0], y=baseline_df_optimal_lr[STD_TAIL_KEY].values[0], mode='lines', name='baseline', line=dict(color='black', width=2, dash='dash'))
baseline_grad_norm_std_short = go.Scatter(x=baseline_df_optimal_lr['steps'].values[1], y=baseline_df_optimal_lr[STD_TAIL_KEY].values[1], mode='lines', name='baseline', line=dict(color='black', width=2, dash='dash'))

# Short Experiments (2k steps)

#### Loss vs LR categorised by baseline and sanity_checks for v1, v2 std norm

In [59]:
grad_std_norm_type_tag = {
    "std_v1": "std_v1_c_lr_grid_placement_short",
    "std_v2": "std_v2_c_lr_grid_placement_short",
    "std_v4": "std_v4_c_lr_grid_placement_short"
}

baseline_step_2k = baseline_df[baseline_df['step'] == 2000]

for gn_name, gn_tag in grad_std_norm_type_tag.items():
    sanity_check_df = project.fetch_runs_table(owner="szysad", tag=["c_0", gn_tag], columns=columns).to_pandas()
    sanity_check_df = rename_to_common(sanity_check_df)
    sanity_check_df['eps'] = sanity_check_df['grad_modif_params'].apply(infere_eps)
    sanity_check_df['placement'] = sanity_check_df['sys/tags'].apply(infere_placement_type)
    sanity_check_df.sort_values('lr', inplace=True)

    baseline_trace = go.Scatter(x=baseline_step_2k['lr'], y=baseline_step_2k['loss'], name='baseline', line=dict(color='black', width=2, dash='dash'))
    fig = common_line_plot(sanity_check_df, x="lr", y="loss", color="placement", title=f"Sanity check Loss vs LR categorised by placement with baseline for grad norm '{gn_name}'", trace=baseline_trace)
    fig.show()

    optimal_lr_df = sanity_check_df[sanity_check_df['lr'] == 1e-3].copy()
    optimal_lr_df['steps'] = optimal_lr_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
    optimal_lr_df[STD_TAIL_KEY] = optimal_lr_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
    fig = common_plot_nested_list(optimal_lr_df, color='placement', x="steps", y=STD_TAIL_KEY, title=f"Sanity check gradient norm std 90th quantile vs steps for grad norm '{gn_name}'", trace=baseline_grad_norm_std_short)
    fig.show()

#### Loss vs LR categorised by c for each placement (short run, v1 std norm)

In [60]:
plot_cnt = len(grad_std_norm_type_tag) * len(placements)

for i, (gn_name, gn_tag) in enumerate(grad_std_norm_type_tag.items()):
    std_v1_df = project.fetch_runs_table(tag=gn_tag, columns=columns).to_pandas()
    std_v1_df = rename_to_common(std_v1_df)
    std_v1_df['layer_type'] = std_v1_df['sys/tags'].apply(infere_layer_type)
    std_v1_df['c'] = std_v1_df['grad_modif_params'].apply(infere_c)
    std_v1_df = std_v1_df[std_v1_df['c'] != 0]
    std_v1_df['placement'] = std_v1_df['sys/tags'].apply(infere_placement_type)
    std_v1_df.sort_values('lr', inplace=True)

    offset = i * len(placements)
    for j, p in enumerate(placements):
        df = std_v1_df[std_v1_df['placement'] == p]
        title = f"({offset + j + 1}/{plot_cnt}) Loss vs Learning Rate categorized by c, for '{p}' placement for grad norm '{gn_name}'"
        baseline_trace = go.Scatter(x=baseline_step_2k['lr'], y=baseline_step_2k['loss'], name='baseline', line=dict(color='black', width=2, dash='dash'))
        fig = common_line_plot(df, x="lr", y="loss", color="c", title=title, trace=baseline_trace)
        fig.show()
    
        optimal_lr_df = std_v1_df[(std_v1_df['lr'] == 1e-3) & (std_v1_df['sys/failed'] == False) & (std_v1_df['placement'] == p)].copy()
        optimal_lr_df[STD_TAIL_KEY] = optimal_lr_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
        optimal_lr_df['steps'] = optimal_lr_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
        optimal_lr_df['color'] = optimal_lr_df['c'].astype(str)
        optimal_lr_df.sort_values('color', inplace=True)
        #title = f"({i+1}/{len(grad_std_norm_type_tag)}) gradient norm std 90th quantile vs steps for '{gn_name}'" # GNQR - Grad Norm Quantile Ratio
        title = f"{offset + j + 1}/{plot_cnt} gradient norm std 90th quantile vs steps for '{p}' placement for grad norm '{gn_name}'"
        fig = common_plot_nested_list(optimal_lr_df, color='color', x="steps", y=STD_TAIL_KEY, title=title, trace=baseline_grad_norm_std_short)
        fig.show() 
    

Fetching table...: 0 [00:00, ?/s]

#### Loss vs LR categorised by c and norm_dims for each placement in 'norm_scale' grad norm

In [61]:
sanity_check_df = project.fetch_runs_table(owner="szysad", tag=["c_0", "scale_norm_c_lr_grid_placement_short"], columns=columns).to_pandas()
sanity_check_df = rename_to_common(sanity_check_df)
sanity_check_df['k'] = sanity_check_df['grad_modif_params'].apply(infere_k)
#sanity_check_df = sanity_check_df[sanity_check_df['k'] == "auto"]
sanity_check_df['c'] = sanity_check_df['grad_modif_params'].apply(infere_c)
sanity_check_df['placement'] = sanity_check_df['sys/tags'].apply(infere_placement_type)
sanity_check_df['norm_dims'] = sanity_check_df['grad_modif_params'].apply(infere_norm_dims)
sanity_check_df['category'] = sanity_check_df.apply(lambda x: f"placement={x.placement}, c={x.c}, k={x.k}, norm_dims={x.norm_dims}", axis=1)
sanity_check_df.sort_values('lr', inplace=True)

baseline_trace = go.Scatter(x=baseline_step_2k['lr'], y=baseline_step_2k['loss'], name='baseline', line=dict(color='black', width=2, dash='dash'))
fig = common_line_plot(sanity_check_df, x="lr", y="loss", color="category", title=f"Sanity check Loss vs LR categorised by placement, c and norm_dims for grad norm scale_norm", trace=baseline_trace)
fig.show()

In [62]:
plot_cnt = len(placements)

for i, place in enumerate(placements):
    scale_norm_df = project.fetch_runs_table(tag="scale_norm_c_lr_grid_placement_short", columns=columns).to_pandas()
    scale_norm_df = rename_to_common(scale_norm_df)
    scale_norm_df['k'] = scale_norm_df['grad_modif_params'].apply(infere_k)
    scale_norm_df = scale_norm_df[scale_norm_df['k'] == "auto"]
    scale_norm_df['c'] = scale_norm_df['grad_modif_params'].apply(infere_c)
    scale_norm_df = scale_norm_df[scale_norm_df['c'] != 0]
    scale_norm_df['placement'] = scale_norm_df['sys/tags'].apply(infere_placement_type)
    scale_norm_df['norm_dims'] = scale_norm_df['grad_modif_params'].apply(infere_norm_dims)
    scale_norm_df['category'] = scale_norm_df.apply(lambda x: f"c={x.c}, k={x.k}, norm_dims={x.norm_dims}", axis=1)
    scale_norm_df.sort_values('lr', inplace=True)


    df = scale_norm_df[scale_norm_df['placement'] == place]
    title = f"({i}/{plot_cnt}) Loss vs Learning Rate categorized by c and norm_dims for '{place}' placement."
    baseline_trace = go.Scatter(x=baseline_step_2k['lr'], y=baseline_step_2k['loss'], name='baseline', line=dict(color='black', width=2, dash='dash'))
    fig = common_line_plot(df, x="lr", y="loss", color="category", title=title, trace=baseline_trace)
    fig.show()

    optimal_lr_df = df[scale_norm_df['lr'] == 1e-3].copy()
    optimal_lr_df = optimal_lr_df[optimal_lr_df['sys/failed'] == False]
    optimal_lr_df[STD_TAIL_KEY] = optimal_lr_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
    optimal_lr_df['steps'] = optimal_lr_df['sys/id'].apply(partial(get_run_mean_grad_std_steps, series_cache=series_cache))
    optimal_lr_df['color'] = optimal_lr_df['placement'].astype(str) + "-" + optimal_lr_df['k'].astype(str) + "-" + optimal_lr_df['c'].astype(str) + "-" + optimal_lr_df['norm_dims'].astype(str)
    optimal_lr_df.sort_values('color', inplace=True)
    title = f"({i+1}/{plot_cnt}) gradient norm std 90th quantile vs steps for scale_norm for '{place}' placement"
    fig = common_plot_nested_list(optimal_lr_df, color='color', x="steps", y=STD_TAIL_KEY, title=title, trace=baseline_grad_norm_std_short)
    fig.show()

Fetching table...: 0 [00:00, ?/s]


Boolean Series key will be reindexed to match DataFrame index.



Fetching table...: 0 [00:00, ?/s]


Boolean Series key will be reindexed to match DataFrame index.



Fetching table...: 0 [00:00, ?/s]


Boolean Series key will be reindexed to match DataFrame index.



Fetching table...: 0 [00:00, ?/s]


Boolean Series key will be reindexed to match DataFrame index.



## Relation of loss to gradient std

In [89]:
dfs = []
std_norm = {
    "std_v1": "(0, 1, 2)",
    "std_v2": "(2)",
    "std_v4": "(1, 2)"
}
for gn_name, gn_tag in grad_std_norm_type_tag.items():
    _df = project.fetch_runs_table(tag=gn_tag, columns=columns).to_pandas()
    _df['type'] = "std_norm"
    _df['std_v'] = gn_name
    _df['norm_dims'] = std_norm[gn_name]
    dfs.append(_df)

scale_norm_df = project.fetch_runs_table(tag="scale_norm_c_lr_grid_placement_short", columns=columns).to_pandas()
scale_norm_df['type'] = "scale_norm"
dfs.append(scale_norm_df)

all_df = pd.concat(dfs, ignore_index=True)
all_df = rename_to_common(all_df)
all_df['k'] = all_df['grad_modif_params'].apply(infere_k)
all_df['c'] = all_df['grad_modif_params'].apply(infere_c)
all_df = all_df[(all_df['c'] != 0) & (all_df['lr'] == 1e-3) & (all_df['sys/failed'] == False)]

all_df['placement'] = all_df['sys/tags'].apply(infere_placement_type)
all_df['grad_std_q_series'] = all_df['sys/id'].apply(partial(get_run_std_quantile_series, series_cache=series_cache, q=STD_TAIL_QUANTILE))
all_df['grad_std_q_mean'] = all_df['grad_std_q_series'].apply(np.mean)
all_df['color'] = all_df['placement'].astype(str) + "-" + all_df['norm_dims'].astype(str)



Fetching table...: 0 [00:00, ?/s]

Fetching table...: 0 [00:00, ?/s]

In [100]:
log_axis_layout = dict(
    showexponent = 'all',
    exponentformat = 'power',
    type='log'
)
corr_df = all_df[all_df['loss'] < 5]
corr_df_std = corr_df[corr_df['type'] == "std_norm"]
corr_df_scale_norm = corr_df[corr_df['type'] == "scale_norm"]
corr_df_scale_norm['norm_dims'] = corr_df_scale_norm['grad_modif_params'].apply(infere_norm_dims)
baseline_trace = go.Scatter(x=[np.mean(baseline_df_optimal_lr[STD_TAIL_KEY].values[1])], y=[baseline_df_optimal_lr['loss'].values[1]], mode='markers', name='baseline', marker=dict(symbol='x', color='black', size=8))



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



In [105]:
fig = px.scatter(corr_df_std, x='grad_std_q_mean', y='loss', color='color', hover_name='sys/id', hover_data=['c', 'placement', 'norm_dims', 'std_v'], title="Loss vs mean of 90th quantile of gradient norm std for std normalizations", log_x=True, log_y=True)
fig.add_trace(baseline_trace)
fig.update_layout(yaxis = log_axis_layout, xaxis = log_axis_layout)
fig.show()

In [103]:
fig = px.scatter(corr_df_scale_norm, x='grad_std_q_mean', y='loss', color='placement', hover_name='sys/id', hover_data=['c', 'k', 'placement', 'norm_dims'], title="Loss vs mean of 90th quantile of gradient norm std for scale_norm", log_x=True, log_y=True)
fig.add_trace(baseline_trace)
fig.update_layout(yaxis = log_axis_layout, xaxis = log_axis_layout)
fig.show()

# Long experiments

### std_v1 params
|     placement    	| c    	|
|:----------------:	|------	|
| post_add         	| 1e-4 	|
| post_attn_and_ff 	| 1e-4 	|
| post_norm        	| 1e-3 	|
| all              	| 1e-5 	|

### std_v2 params
|     placement    	| c    	|
|:----------------:	|------	|
| post_attn_and_ff 	| 1e-5 	|
| post_add         	| 1e-4 	|
| post_add         	| 1e-5 	|
| post_attn_and_ff 	| 1e-4 	|

### std_v4 params
|     placement    	|   c  	|
|:----------------:	|:----:	|
|     post_add     	| 1e-4 	|
|     post_norm    	| 1e-3 	|
| post_attn_and_ff 	| 1e-4 	|
|        all       	| 1e-5 	|

### scale_norm params
|     placement    	| c    	| k    	| norm_dims 	|
|:----------------:	|------	|------	|-----------	|
| post_add         	| 1e-4 	| 1    	| 0,1,2     	|
| post_add         	| 1e-4 	| 1    	| 1,2       	|
| post_add         	| 1    	| auto 	| 0,1,2     	|
| post_add         	| 1    	| auto 	| 1,2       	|
| all              	| 1e-4 	| 1    	| 0,1,2     	|
| all              	| 1e-5 	| 1    	| 0,1,2     	|
| post_norm        	| 1e-4 	| 1    	| 1,2       	|
| post_norm        	| 1e-4 	| 1    	| 0,1,2     	|
| post_attn_and_ff 	| 1e-4 	| 1    	| 1,2       	|
| post_attn_and_ff 	| 1e-3 	| 1    	| 1,2       	|


# Long Experiments (16k steps)

In [94]:
table_long = project.fetch_runs_table(tag="post_add_c_lr_grid_long", columns=columns).to_pandas()
table_long['layer_type'] = table_long['sys/tags'].apply(infere_layer_type)
table_long['c'] = table_long['args/grad_modif_params'].apply(infere_c)
table_long.rename(columns={'loss_interval/100': 'loss', 'args/learning_rate': 'lr', 'sys/id': 'id'}, inplace=True)

baseline_long = table_long[table_long['layer_type'] == 'baseline'].sort_values(by='lr')
rest_long = table_long[table_long['layer_type'] != 'baseline'].sort_values(by=['lr', 'c'])


In [95]:
baseline_trace = go.Scatter(x=baseline_long['lr'], y=baseline_long['loss'], name='baseline', line=dict(color='black', width=2, dash='dash'))
fig = common_line_plot(rest_long, x="lr", y="loss", color="c", title="Loss vs Learning Rate for All Layer Typles and Placements", trace=baseline_trace)
fig.show()