In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import fitting as fit
import data_utils as dat
import vis
import plotly.graph_objects as go

# Train/Test Splits

Problem: how do we extract the "interesting interval" consistently across different models and tasks?

In [None]:
from pathlib import Path

from plotly.subplots import make_subplots
import pandas as pd

# Configuration ----------------------------------------------------------------
fit_method = fit.min_fit  # super robust; fit.odr_fit is also an option
mcode = '410m-dense'
step_start = 2000  # are we cropping too early?
step_end = 80000
step_cutoff = 20000  # TODO: investigate if this point is in the train or the test set

functions = [
    fit.DoubleOffsetPowerLaw,  # The first model sets the shift parameters for the other models
    fit.OffsetPowerLaw,
    fit.OffsetExponential,
    fit.Cubic,
]
# Held-out metrics. "RMSE" must stay enabled: it is used below to pick each task's winner.
metrics = {
    # "R2_log": fit.logspace_r2,
    # "R2_lin": fit.r2_score,
    "RMSE": fit.rmse,
}

df_llc, df_loss = dat.load_dfs(mcode, data_path="data")
msize = mcode.split("-")[0]
tasks = df_llc.columns
colors = vis.assign_cols(tasks)  # deal with wikipedia

# Layout: one row per function, two columns. Column 1 plots the raw data; column 2
# plots it shifted by the parameters of the first function's full-trace fit.
titles = [f.name for f in functions for _ in ("linear", "log")]

fig = make_subplots(
    rows=len(functions), cols=2,
    subplot_titles=titles,
    horizontal_spacing=0.1,
    vertical_spacing=0.1,
)

report_rows = []  # one dict per (task, function) pair; becomes `report` below

for task in tasks:
    color = colors[task]
    color_heldout = vis.add_color(color)  # lighter shade for unseen (test) data
    trace = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)
    train, test = dat.split(trace, step_cutoff)

    # Pass 1: fit every function first, so the best score is known before plotting.
    scores = []   # held-out RMSE per function
    results = []  # train-only fit result per function
    for f_ind, function in enumerate(functions):
        # NOTE: the oracle is fitted on the full trace, including held-out points.
        # It provides the plotting shift and the optimiser's initial guess (par0).
        # The latter leaks some test-set information into the train fit's starting point.
        oracle = fit_method(trace.x, trace.y, function)
        if f_ind == 0:
            # All subplots of this task share the first function's shift.
            shift = oracle.params_dict

        result = fit_method(train.x, train.y, function, par0=oracle.params)

        y_pred = result.f(test.x)
        measures = {k: v(test.y, y_pred) for k, v in metrics.items()}
        report_rows.append({"Dataset": task, "Function": function.name, **measures})
        scores.append(measures["RMSE"])
        results.append(result)

    # Pass 2: plot, bolding the winner's score in the legend.
    best_score = min(scores)  # lower RMSE is better
    for f_ind, (function, result, score) in enumerate(zip(functions, results, scores)):
        if score == best_score:
            score_repr = f"RMSE=<b>{score:.4f}</b> {task}"
        else:
            score_repr = f"RMSE={score:.4f} {task}"

        for col in (1, 2):
            subplot = dict(row=f_ind + 1, col=col)
            use_shift = shift if col == 2 else None
            vis.plot_data(fig, train.x, train.y, train.s, color=color,
                          showlegend=False, subplot=subplot, size=5, shift=use_shift)
            vis.plot_data(fig, test.x, test.y, test.s, color=color_heldout,
                          showlegend=False, subplot=subplot, size=5, shift=use_shift)
            vis.plot_result(fig, trace.x, result, name=score_repr, color=color,
                            subplot=subplot, shift=use_shift,
                            showlegend=col == 2, legendgroup=function.name)

# Axis titles are identical for every task, so set them once per subplot.
for f_ind in range(len(functions)):
    for col in (1, 2):
        subplot = dict(row=f_ind + 1, col=col)
        fig.update_xaxes(title_text=r"$\text{Estimated and transformed LLC }\,\frac{1}{100}\hat{\lambda}$", **subplot)
        y_title = r"$\text{Loss }L$" if col == 1 else r"$\text{Loss }L - L^*$"
        fig.update_yaxes(title_text=y_title, **subplot)

# Compile the report.
report = pd.DataFrame(report_rows)
report.index.name = "Pythia" + msize

fig.update_layout(
    title="",
    width=1200,
    height=410 * len(functions),
    showlegend=True,
    legend_tracegroupgap=180,  # annoying - have to eyeball this
)
fname = f"plots/holdout_{msize}.pdf"
# Create the output directory so write_image does not fail on a fresh checkout.
Path(fname).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(fname)
print(f"Done. See {fname}")

In [None]:
def format_value(val, is_winner, is_rmse=False, decimals=3):
    r"""Format a metric value for a LaTeX table cell.

    Args:
        val: Numeric value to format.
        is_winner: If True, wrap the formatted value in ``\textbf{...}``.
        is_rmse: Currently unused. Kept so existing call sites remain valid.
        decimals: Number of decimal places (default 3).

    Returns:
        ``val`` as a string with ``decimals`` decimal places, bolded when it won.
    """
    formatted = f"{val:.{decimals}f}"
    if is_winner:
        return f"\\textbf{{{formatted}}}"
    return formatted


def highlight_winners(df, directions=None, decimals=3):
    r"""Return a copy of ``df`` with metric columns formatted and per-dataset winners bolded.

    Args:
        df: Report with a ``"Dataset"`` column and one column per metric.
        directions: Mapping of metric column -> ``"min"`` or ``"max"``. It says which
            value wins within each dataset. Defaults to ``{"RMSE": "min"}``. Pass e.g.
            ``{"RMSE": "min", "R2_log": "max", "R2_lin": "max"}`` to also handle R^2.
        decimals: Decimal places passed through to ``format_value``.

    Returns:
        A new DataFrame whose formatted columns hold strings. ``df`` itself is not
        modified. Tied winners are all bolded.

    Raises:
        KeyError: if ``"Dataset"`` or a column named in ``directions`` is missing.
        ValueError: if a direction is not ``"min"`` or ``"max"``.
    """
    if directions is None:
        directions = {"RMSE": "min"}

    result = df.copy()  # never modify the caller's frame
    for col, how in directions.items():
        if how not in ("min", "max"):
            raise ValueError(f"direction for {col!r} must be 'min' or 'max', got {how!r}")
        # Best value of this metric within each row's dataset.
        best = df.groupby("Dataset")[col].transform(how)
        winners = df[col] == best
        # Replace the whole column. Writing strings into a float column through
        # .loc[mask, col] is deprecated in recent pandas and will become an error.
        result[col] = [
            format_value(val, win, is_rmse=(col == "RMSE"), decimals=decimals)
            for val, win in zip(df[col], winners)
        ]
    return result

In [None]:
# Bold the best (lowest-RMSE) function within each dataset, then render the table.
report2 = highlight_winners(df=report)
display(report2)



In [None]:
# Export the highlighted table as LaTeX.
# - escape=False keeps the \textbf{...} markup intact on every pandas version
#   (some versions, e.g. pandas 1.x, escape special characters by default).
# - Underscores in dataset/column names are escaped by hand afterwards.
# - R^2 headers are put in math mode so they compile.
latex_table = (
    report2.to_latex(index=False, escape=False)
    .replace("R2_log", "$R^2$ (logspace)")
    .replace("R2_lin", "$R^2$")
    .replace("_", "\\_")
)
print(latex_table)

# Plot each function on the same task

In [None]:
# NOTE(review): this cell uses fit.ShiftedPowerLaw / fit.ShiftedExponential, while the
# holdout cell uses fit.OffsetPowerLaw / fit.OffsetExponential. Confirm that both names exist in `fit`.
msize = '410m-dense'
step_start = 2000  # are we cropping too early?
step_end = 80000
step_cutoff = 20000
# TODO: the original note ("much better for ...") was cut off. Record why ODR is preferred here.
fit_method = fit.odr_fit
# Legend name -> function to fit.
functions = {
    "Power Law": fit.ShiftedPowerLaw,
    # "Power Law (4P)": fit.DoubleShiftedPowerLaw,
    "Exponential": fit.ShiftedExponential,
    "Cubic": fit.Cubic,
}

df_llc, df_loss = dat.load_dfs(msize, data_path="data")
tasks = df_llc.columns

for task in tasks:
    trace = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)
    train, test = dat.split(trace, step_cutoff)

    # Use a fresh figure per task. Reusing one figure made every plot accumulate
    # the traces of all earlier tasks.
    fig = go.Figure()
    vis.plot_split(fig, train, test)

    # Fit each function on the training split only and overlay its curve on the full trace.
    for model_name, model in functions.items():
        result = fit_method(train.x, train.y, model)
        vis.plot_result(fig, trace.x, result, name=model_name)

    fig.update_layout(
        title=f'Predictions after step {step_cutoff} - {task} @ {msize}',
        xaxis_title='LLC / 100',
        yaxis_title='Loss',
        width=900, height=600,
    )
    fig.show()

# The plots in shifted space don't look so good

Plotting $\log(y - y^*)$ makes the fit look excellent, but only because the transform is built from the fitted function's own parameters. Of course the fit looks good in that space.
The same transform can make the held-out (unseen) data look worse.

In [None]:
# Plots in shifted space only make sense for the power law
msizes = ['160m', '410m']
tasks = ["arxiv", "wikipedia_en"]
function = fit.ShiftedPowerLaw2  # the form that is actually fitted and plotted below

# TODO: we need this for every model (and maybe for every model/task combination)
step_start = 512  # are we cropping too early?
step_end = 80000
step_cutoff = 10000

for msize in msizes:
    df_llc, df_loss = dat.load_dfs(msize, data_path="data")

    for task in tasks:
        trace = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)
        train, test = dat.split(trace, step_cutoff)
        result = fit.min_fit(train.x, train.y, function)

        # Plot everything shifted by the fitted parameters (contains y*).
        shift = result.params_dict
        fig = go.Figure()
        vis.plot_split(fig, train, test, shift=shift)
        vis.plot_result(fig, trace.x, result, name="Power law fit", shift=shift)

        # As we're using shifted powerlaw, we should go back to logspace.
        # LLC is on x and loss on y, as in the other plots. The old titles had
        # the axes swapped and were then overwritten by unshifted layout titles.
        # NOTE(review): the x title assumes the shift also offsets LLC by LLC*;
        # confirm against ShiftedPowerLaw2's parameters.
        fig.update_xaxes(title_text="LLC - LLC*", type="log")
        fig.update_yaxes(title_text="L - L*", type="log")

        fig.update_layout(
            title=f'Predictions after step {step_cutoff} - {task} @ {msize}',
            width=900, height=600,
        )
        fig.show()

# Do fit methods make a difference?

### Answer: a little, but not enough to change which functional form is preferred.


In [None]:
msize = "410m"
# Set the crop/split explicitly, so this cell does not silently depend on whichever
# cell ran last. These are the values the preceding cell sets.
step_start = 512
step_end = 80000
step_cutoff = 10000

df_llc, df_loss = dat.load_dfs(msize, data_path="data")
tasks = ["arxiv", "wikipedia_en", "full"]
functions = {
    "Power Law": fit.ShiftedPowerLaw2,
    "Exponential": fit.ShiftedExponential,
}
fit_methods = {
    "minimize": fit.min_fit,
    "ODR": fit.odr_fit,
}
# Legend suffix -> value passed as `rel_noise` to the fit method.
hk_noise = {
    "rel_noise": True,
    "": False,
}

for task in tasks:
    trace = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)
    train, test = dat.split(trace, step_cutoff)

    fig = go.Figure()
    vis.plot_split(fig, train, test)

    # Every (function, fit method, noise model) combination on the same axes.
    for model_name, model in functions.items():
        for fit_name, fit_fn in fit_methods.items():
            for rel_name, rel_noise in hk_noise.items():
                result = fit_fn(train.x, train.y, model, rel_noise=rel_noise)
                vis.plot_result(
                    fig, trace.x, result,
                    name=f"{model_name} {fit_name} {rel_name}"
                )

    fig.update_layout(
        title=f'Predictions after step {step_cutoff} - {task} @ {msize}',
        xaxis_title='LLC / 100',
        yaxis_title='Loss',
        width=900, height=600,
    )
    fig.show()

# Can we use pcov to estimate uncertainty?
Yes, but it is only an approximation.

In [None]:
# Available sizes: ['14m', '31m', '70m', '160m', '410m', '1b']
msizes = ['160m', '410m']
tasks = ["arxiv", "wikipedia_en"]
functions = {
    "Power Law": fit.ShiftedPowerLaw,
    "Exponential": fit.ShiftedExponential,
    # "Power Law (4P)": fit.DoubleShiftedPowerLaw,
    # "ExpExp": fit.DoubleExponential,
}

# TODO: we need this for every model (and maybe for every model/task combination)
step_start = 512  # are we cropping too early?
step_end = 80000
step_cutoff = 10000

n_samples = 30  # number of samples passed to result.sample for the predictive mean/std
seed = 0
# NOTE(review): assumes result.sample / vis.sample_result draw from NumPy's global RNG.
# Confirm this; if they use their own generator, seed that instead.
np.random.seed(seed)

colors = vis.assign_cols(functions)

for msize in msizes:
    df_llc, df_loss = dat.load_dfs(msize, data_path="data")

    for task in tasks:
        trace = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)
        train, test = dat.split(trace, step_cutoff)

        fig = go.Figure()
        vis.plot_split(fig, train, test)

        # Fit each function on the training split and score the held-out losses.
        for model_name, model in functions.items():
            result = fit.min_fit(train.x, train.y, model)

            # Log-likelihood of the held-out losses given the sampled predictive mean/std.
            _, y_test_mu, y_test_std = result.sample(test.x, n_samples)
            llh = fit.normal_log_likelihood(test.y, y_test_mu, y_test_std)

            vis.sample_result(fig, trace.x, result, colors[model_name],
                              name=f"{model_name}: LLH~={llh:.1f}")

        fig.update_layout(
            title=f'Predictions after step {step_cutoff} - {task} @ {msize}',
            xaxis_title='LLC / 100',
            yaxis_title='Loss',
            width=900, height=600,
        )
        fig.show()