# EVAL-158
## Does predicting loss of later checkpoints from LLC reliably improve over the prediction from checkpoint index?

In [None]:
%load_ext autoreload
%autoreload 2
import fitting as fit
import data_utils as dat
import vis
import plotly.graph_objects as go
import numpy as np
#from plotly.subplots import make_subplots
import pandas as pd

In [None]:
# Load some example data and set the analysis window shared by the cells below.
# At 160m, the regular loss vs time seems better
mcode = "1b"  # alternatives tried: "160m", "410m-dense"
msize = mcode.split("-")[0]  # part of the model code before the first "-"
df_llc, df_loss = dat.load_dfs(mcode, data_path="data")
# tasks = ["github", "stackexchange", "arxiv", "pile-cc", "pubmed_abstracts"]  # and full?
tasks = df_llc.columns  # one task per column of the LLC frame
step_start = 2000  # first step kept -- are we cropping too early?
step_cutoff = 20000   # train/holdout split step. How much do we get to observe? about 10% of training?
step_end = 200000 # last step kept -- 80000 (end of regime), or end of training?

scale = 1000.  # rescale steps for the time-fits only
colors = vis.assign_cols(tasks)  # indexed by task name in the plotting cells





## Loss vs time

In [None]:
# Model for loss: fit loss vs. time per task on the points before
# `step_cutoff` and overlay the fitted curve on the whole trimmed trace.
fig = go.Figure()

for task in tasks:
    llc, loss, steps = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)

    # Prepare holdout data (fixed cutoff for now); x is the step count
    # divided by `scale`, the raw steps are carried along as the third field.
    trace = dat.Trace(steps/scale, loss, steps)
    train, test = dat.split(trace, step_cutoff)
    loss_model = fit.min_fit(train.x, train.y, fit.OffsetPowerLaw2)

    # Plot
    color = colors[task]
    color2 = vis.add_color(color)  # show faded holdout data
    # NOTE(review): other cells call vis.plot_data(fig, trace, ...) with a
    # single trace argument -- confirm that this (x, y, s) form is still valid.
    vis.plot_data(fig, train.x, train.y, train.s, color=color, showlegend=False, size=5, xscale=scale)
    vis.plot_data(fig, test.x, test.y, test.s, color=color2, showlegend=False, size=5, xscale=scale)
    vis.plot_result(fig, trace.x, loss_model, name=task, xscale=scale, color=color, showlegend=True, res=600)

# The axis settings do not depend on the task, so apply them once.
fig.update_xaxes(title_text="Step", type="log")
fig.update_yaxes(title_text=r"Loss L")

fig.update_layout(
    title="",
    width=800,
    height=600,
    showlegend=True,

)
fig.show()

## What's going on in late training? It's almost like a new exponent
## Actually this isn't fitting that well; it's just that the line is very flat

In [None]:
# Repeat the loss-vs-time fit on a later window with a logarithmic model.
step_start2 = 12000
step_cutoff2 = 40000
step_end2 = 200000
fig = go.Figure()

for task in tasks:
    llc, loss, steps = dat.trim_trace(df_llc, df_loss, task, step_start2, step_end2)

    # Prepare holdout data (fixed cutoff for now)
    trace = dat.Trace(steps/scale, loss, steps)
    train, test = dat.split(trace, step_cutoff2)
    loss_model = fit.min_fit(train.x, train.y, fit.OffsetLogarithm)  # alternative: fit.OffsetPowerLaw2

    # Evaluate the metrics on the holdout points.
    # NOTE: RMSE is overwritten on every iteration and is not reported in
    # this cell; only the last task's value survives the loop.
    y_pred = loss_model.f(test.x)
    RMSE = fit.rmse(test.y, y_pred)

    # Plot
    color = colors[task]
    color2 = vis.add_color(color)  # show faded holdout data
    # NOTE(review): other cells call vis.plot_data(fig, trace, ...) with a
    # single trace argument -- confirm that this (x, y, s) form is still valid.
    vis.plot_data(fig, train.x, train.y, train.s, color=color, showlegend=False, size=5, xscale=scale)
    vis.plot_data(fig, test.x, test.y, test.s, color=color2, showlegend=False, size=5, xscale=scale)
    vis.plot_result(fig, trace.x, loss_model, name=task, xscale=scale, color=color, showlegend=True, res=600)

# The axis settings do not depend on the task, so apply them once.
fig.update_xaxes(title_text="Step", type="log")
fig.update_yaxes(title_text=r"Loss L")

fig.update_layout(
    title="",
    width=800,
    height=600,
    showlegend=True,

)
fig.show()

## LLC vs time

In [None]:
# Basically a cut'n'paste job: fit LLC vs. time per task with every candidate
# function on the points before `step_cutoff` and overlay all fits.


# Candidate functions: (model, extra min_fit kwargs, label, line colour)
functions = [
    # Powerlaw needs some help because the initial conditions are very different
    (fit.OffsetPowerLaw, dict(par0=[10., -10., .1]), "Powerlaw", "red"),
    (fit.OffsetLogarithm, {}, "Logarithm", "blue"),
]

# Titles for a per-function Fit/Holdout grid layout.
# NOTE: not used by the single figure created below.
titles = []
for f in functions:
    for s in ["Fit", "Holdout"]:
        titles.append(f"{f[2]} - {s}")

fig = go.Figure()

for task in tasks:
    # Load
    llc, loss, steps = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)
    trace = dat.Trace(steps/scale, llc, steps)  # llc vs step
    train, test = dat.split(trace, step_cutoff)

    # Plot the data
    color = colors[task]
    color2 = vis.add_color(color)  # for heldout data
    vis.plot_data(fig, train, color=color, showlegend=False, size=5, xscale=scale)
    vis.plot_data(fig, test, color=color2, showlegend=False, size=5, xscale=scale)

    for function, args, fname, fcol in functions:

        # Fit on the training window only
        llc_model = fit.min_fit(train.x, train.y, function, **args)

        # Plot the fit over the whole trace so the extrapolation is visible
        vis.plot_result(fig, trace.x, llc_model, name=f"{task}-{fname}",
                        xscale=scale, color=fcol, showlegend=True)

# The axis settings do not depend on the task or function, so apply them once.
fig.update_xaxes(title_text="Step", type="log")
fig.update_yaxes(title_text="Scaled LLC")

fig.update_layout(
    title="",
    width=800,
    height=600,
    showlegend=True,

)
fig.show()


Thoughts - both are pretty good, but I prefer the logarithm
* the LLC curves upward, and the logarithm predicts higher values, so it holds out a tiny bit longer
* the logarithm is far simpler
* Both **fall apart** around step 80k (the end of the analysis interval)

## LLC vs loss

We can use the function `result.model.inverse(x, result.params)`.

In [None]:
# Fit loss as a function of LLC per task. The trimmed trace unpacks as
# (llc, loss, steps), so its x field is the LLC and its y field is the loss.
fig = go.Figure()

for task in tasks:
    llc, loss, steps = trace = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)
    train, test = dat.split(trace, step_cutoff)
    loss_model = fit.min_fit(train.x, train.y, fit.OffsetPowerLaw)
    # y_pred = loss_model.f(test.x)
    # RMSE = fit.rmse(test.y, y_pred)

    # Plot
    color = colors[task]
    color2 = vis.add_color(color)  # show faded holdout data
    # NOTE(review): xscale=scale was carried over from the step-based plots;
    # here x is the LLC, not step/scale -- confirm the rescaling is intended.
    vis.plot_data(fig, train, color=color, showlegend=False, size=5, xscale=scale)
    vis.plot_data(fig, test, color=color2, showlegend=False, size=5, xscale=scale)
    vis.plot_result(fig, trace.x, loss_model, name=task, xscale=scale, color=color, showlegend=True, res=600)

# The x values in this figure are LLC values, not training steps, so the axis
# is labelled accordingly (same label as the LLC axis of the LLC-vs-time plot).
# Axis settings are task-independent and applied once.
fig.update_xaxes(title_text="Scaled LLC", type="log")
fig.update_yaxes(title_text=r"Loss L")

fig.update_layout(
    title="",
    width=800,
    height=600,
    showlegend=True,

)
fig.show()

# Comparing loss(time) vs loss(llc(time))
- Comparing L(T) with L(LLC(T))

In [None]:
def get_y(x, y, at):
    """Look up the ``y`` values at the exact ``x`` positions listed in ``at``.

    Args:
        x: Sorted (ascending) 1-D sequence of positions; ``np.searchsorted``
            is used for the lookup, so unsorted input gives wrong results.
        y: Values aligned element-wise with ``x``.
        at: Position or sequence of positions to look up.

    Returns:
        A float NumPy array with one entry per element of ``at``: the
        matching ``y`` value where the position occurs in ``x``, ``NaN``
        where it does not (including when ``x`` is empty).
    """
    # Convert to plain arrays so indexing is always positional, even when the
    # caller passes pandas objects with a non-default index.
    x = np.asarray(x)
    y = np.asarray(y)
    at = np.atleast_1d(np.asarray(at))

    if x.size == 0:
        return np.full(at.shape, np.nan)

    # Positions past the end are clamped to the last element; the equality
    # test below then rejects them unless they really match.
    idx = np.minimum(np.searchsorted(x, at), x.size - 1)
    return np.where(x[idx] == at, y[idx], np.nan)



In [None]:
# Compare three predictors of late-training loss, all fitted on the points
# before `step_cutoff`:
#   L(T)      - offset power law fitted to loss vs. time
#   L*(T)     - fit.Modron fitted to loss vs. time
#   L(LLC(T)) - loss-vs-LLC fit composed with the LLC-vs-time fit
eval_at = np.array([50000, 70000, 143000])
eval_scaled = eval_at / scale
columns = []
row_names = ["L(T)", "L*(T)", "L(LLC(T))"]
col_names = []
n_eval = len(eval_at)
n_col = n_eval * len(tasks)


fig = go.Figure()
end = len(tasks) - 1  # the fit legends are only shown for the last task
table = np.zeros((len(row_names), n_col))
col = 0

for i, task in enumerate(tasks):
    color = colors[task]
    color2 = vis.add_color(color)  # show faded holdout data
    # main.x is the LLC, main.y the loss and main.s the step
    main = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)

    # Loss vs time
    loss_t = dat.Trace(main.s / scale, main.y, main.s)  # loss vs time
    train, test = dat.split(loss_t, step_cutoff)
    loss_model = fit.min_fit(train.x, train.y, fit.OffsetPowerLaw)

    # Fit with same functional form as the composition
    comp_model = fit.min_fit(train.x, train.y, fit.Modron)

    # Plot loss vs time
    vis.plot_data(fig, train, color=color, showlegend=True, size=5, xscale=scale, name=task)
    vis.plot_data(fig, test, color=color2, showlegend=False, size=5, xscale=scale)

    # Fit loss vs LLC
    loss_llc = main
    train, _ = dat.split(loss_llc, step_cutoff)
    loss_ = fit.min_fit(train.x, train.y, fit.OffsetPowerLaw)  # loss vs LLC

    # Fit LLC vs time:
    llc_t = dat.Trace(main.s / scale, main.x, main.s)  # llc vs time
    train, _ = dat.split(llc_t, step_cutoff)
    llc_ = fit.min_fit(train.x, train.y, fit.OffsetLogarithm)  # LLC vs time

    # Plot loss(llc(time)):

    def composed(x):
        """Predict loss at scaled time ``x`` via the LLC(t) and loss(LLC) fits."""
        return loss_.f(llc_.f(x))

    vis.plot_result(fig, loss_t.x, loss_model, name="L(T)", xscale=scale, color="blue", showlegend=(i==end), res=600)
    vis.plot_result(fig, loss_t.x, comp_model, name="comp(T)", xscale=scale, color="green", showlegend=(i==end), res=600)
    vis.plot_result(fig, llc_t.x, composed, name="L(LLC(t))", xscale=scale, color="red", showlegend=(i==end), res=600)

    # Collate prediction errors (prediction - truth) at the evaluation steps.
    # get_y yields NaN for a step that is not present in main.s, which makes
    # that column of the table NaN.
    truth = get_y(main.s, main.y, eval_at)
    # The slice width follows len(eval_at) so that changing the evaluation
    # steps cannot desynchronise the table layout.
    table[0, col:col+n_eval] = loss_model.f(eval_scaled) - truth
    table[1, col:col+n_eval] = comp_model.f(eval_scaled) - truth
    table[2, col:col+n_eval] = composed(eval_scaled) - truth
    col += n_eval
    for a in eval_at:
        col_names.append(f"{a//1000}k@{task}")


fig.update_xaxes(title_text="Step", type="log")
fig.update_yaxes(title_text=r"Loss L")
fig.update_layout(
    title="Loss vs Time",
    width=800,
    height=600,
    showlegend=True,
)
fig.show()

# Conclusion - seems better than a simple power law...
## At least until ~80k, when the model for LLC(t) falls apart
## But does seeing more data actually help? (And does it have more parameters?)

In [None]:
# Tabulate the prediction error (prediction - truth) at every step in eval_at:
# one row per predictor, one column per "<step>k@<task>" label.
result = pd.DataFrame(
    table,
    index=row_names,
    columns=col_names,
)
# Order the columns by label so entries for the same step sit together.
result = result.sort_index(axis=1)


# Assuming df is your DataFrame
def highlight_min_magnitude(row):
    """Return per-entry CSS that bolds the smallest-magnitude value of ``row``.

    Intended for ``Styler.apply``. NaN entries are ignored when searching for
    the minimum; if ``row`` is empty or entirely NaN nothing is highlighted.

    Args:
        row: 1-D numeric sequence (e.g. a pandas Series).

    Returns:
        A list of CSS strings, one per entry: ``'font-weight: bold'`` for the
        entry with the smallest absolute value, ``''`` for all others.
    """
    magnitudes = np.abs(np.asarray(row, dtype=float))
    styles = [''] * len(magnitudes)

    # An all-NaN input has no meaningful minimum; leave it unstyled rather
    # than bolding an arbitrary entry.
    if magnitudes.size and not np.isnan(magnitudes).all():
        styles[int(np.nanargmin(magnitudes))] = 'font-weight: bold'

    return styles

def highlight_min_magnitude_in_columns(col):
    """Return per-entry CSS that bolds the smallest-magnitude value of ``col``.

    Intended for ``Styler.apply(..., axis=0)``. NaN entries are ignored when
    searching for the minimum; if ``col`` is empty or entirely NaN nothing is
    highlighted.

    Args:
        col: 1-D numeric sequence (e.g. one DataFrame column as a Series).

    Returns:
        A list of CSS strings, one per entry: ``'font-weight: bold'`` for the
        entry with the smallest absolute value, ``''`` for all others.
    """
    magnitudes = np.abs(np.asarray(col, dtype=float))
    styles = [''] * len(magnitudes)

    # A column is entirely NaN when no ground-truth value was available for
    # it; leave such a column unstyled rather than bolding an arbitrary entry.
    if magnitudes.size and not np.isnan(magnitudes).all():
        styles[int(np.nanargmin(magnitudes))] = 'font-weight: bold'

    return styles


# Show one styled table per evaluation step: pick the columns whose label
# starts with that step's "<step>k" prefix and bold, within each column, the
# predictor with the smallest-magnitude error.
columns = result.columns
for a in eval_at:
    start = f"{a//1000}k"  # same prefix format used when the labels were built
    ex = [c for c in columns if c.startswith(start)]
    extract = result[ex]
    styled = extract.style.apply(highlight_min_magnitude_in_columns, axis=0)
    display(styled)

In [None]:
## Interestingly, having more parameters isn't causing overfitting...