In [None]:
#!pip install ipympl
%load_ext autoreload
%autoreload 2
import fitting as fit
import data_utils as dat
import vis
import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots
import pandas as pd

In [None]:
# --- Configuration -------------------------------------------------------
# Model checkpoint family and the data subsets whose traces are analysed.
msize = "410m-dense"
tasks = ["github", "stackexchange", "arxiv", "pile-cc"]  # TODO: also the full dataset?

# Training-step window. Traces are cropped to [step_start, step_end];
# dat.split uses step_cutoff to separate the train / holdout parts of a trace.
step_start = 2000  # TODO: are we cropping too early?
step_end = 80000
step_cutoff = 20000

# --- Load data -----------------------------------------------------------
df_llc, df_loss = dat.load_dfs(msize, data_path="data")
colors = vis.assign_cols(df_llc.columns)  # one plot colour per column/task

# Goodness-of-fit metrics; each is called as metric(y_true, y_pred).
metrics = dict(
    R2_log=fit.logspace_r2,
    R2_lin=fit.r2_score,
    # RMSE=fit.rmse,
)


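# Parametric loss over time
Fit candidate parametric curves of loss against training step for each task, both on the full (cropped) trace and on a train/holdout split.

**TODO:** should we use RMSE (currently commented out of `metrics`) for every comparison instead of R²?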

In [None]:
# Step 1: fit loss vs. training step with each candidate function -- once on
# the full (cropped) trace, and once on the train split only, scored on the
# held-out remainder.

# Candidate functions: (function class, keyword arguments for fit.min_fit).
functions = [
    (fit.ShiftedPowerLaw2, dict(rel_noise=True)),
    (fit.ShiftedExponential, dict(rel_noise=True, par0=[1., 1., -.1])),
]

# Grid layout: one row per function; columns are (full fit, holdout fit).
titles = []
for f in functions:
    for s in ["Fit", "Holdout"]:
        titles.append(f"{s} loss - {f[0].name}")

fig = make_subplots(
    rows=len(functions), cols=2,
    subplot_titles=titles,
    horizontal_spacing=0.1,
    vertical_spacing=0.1,
)

# Short function names used as dict keys. The "Shifted" prefix is removed
# explicitly: str.strip("Shifted ") strips any of those *characters* from both
# ends of the name, which can also eat the end of a name.
fnames = {
    f.name: (f.name[len("Shifted"):].lstrip() if f.name.startswith("Shifted") else f.name)
    for f, _ in functions
}

# Predictions keyed by short function name, then task: {fname: {task: y_pred}}.
loss_preds = {name: {} for name in fnames.values()}
loss_holdouts = {name: {} for name in fnames.values()}
report = []

xscale = 1000.  # fits see x = steps / xscale; the vis helpers are given xscale too

for task in tasks:
    llc, loss, steps = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)
    trace = dat.Trace(steps / xscale, loss, steps)
    train, test = dat.split(trace, step_cutoff)
    color = colors[task]

    for f_ind, (function, args) in enumerate(functions):
        fname = fnames[function.name]

        # Full fit (column 1)
        result = fit.min_fit(trace.x, trace.y, function, **args)
        subplot = dict(row=f_ind + 1, col=1)
        vis.plot_data(fig, trace.x, trace.y, color=color, showlegend=False, size=6, xscale=xscale, subplot=subplot)
        vis.plot_result(fig, trace.x, result, name=task, xscale=xscale, color=color, showlegend=f_ind == 0, subplot=subplot)
        fig.update_xaxes(title_text="Step", type="log", **subplot)
        fig.update_yaxes(title_text="Loss", **subplot)

        # Goodness-of-fit on the full trace
        row = {"Dataset": task, "Function": function.name}
        y_pred = result.f(trace.x)
        row.update({k: v(trace.y, y_pred) for k, v in metrics.items()})

        # Holdout fit (column 2): fit on the train split only; the held-out
        # points are drawn in a second colour.
        subplot = dict(row=f_ind + 1, col=2)
        color2 = vis.add_color(color)
        v_result = fit.min_fit(train.x, train.y, function, **args)
        vis.plot_data(fig, train.x, train.y, color=color, showlegend=False, size=6, xscale=xscale, subplot=subplot)
        vis.plot_data(fig, test.x, test.y, color=color2, showlegend=False, size=6, xscale=xscale, subplot=subplot)
        vis.plot_result(fig, trace.x, v_result, name=task, xscale=xscale, color=color, showlegend=False, subplot=subplot)
        fig.update_xaxes(title_text="Step", type="log", **subplot)
        fig.update_yaxes(title_text="Loss", **subplot)

        # Score the holdout fit on the test split, then record the complete row.
        y_pred_v = v_result.f(test.x)
        row.update({k + "_test": v(test.y, y_pred_v) for k, v in metrics.items()})
        report.append(row)

        # Store per task: assigning loss_preds[fname] directly would replace the
        # per-task dict with the most recent task's array.
        loss_preds[fname][task] = y_pred
        loss_holdouts[fname][task] = y_pred_v

report = pd.DataFrame(report)

fig.update_layout(
    title="",
    width=1000,
    height=450 * len(functions),
    showlegend=True,
    legend=dict(
        yanchor="middle",
        y=0.5,
        xanchor="right",
        x=1.2
    ),
)

fig.show()
display(report)
# The exponential loss fit isn't needed for the LLC-vs-loss comparison below.
del loss_preds["Exponential (3 param)"]
# TODO: save loss vs time

In [None]:
# Step 2: fit LLC vs. training step with each candidate function -- once on
# the full (cropped) trace, and once on the train split only, scored on the
# held-out remainder.

# Candidate functions: (function class, keyword arguments for fit.min_fit).
functions = [
    (fit.ShiftedPowerLaw, dict(rel_noise=True, par0=[10., -10., .1])),
    (fit.ShiftedLogarithm, dict(rel_noise=True)),
]

# Grid layout: one row per function; columns are (full fit, holdout fit).
titles = []
for f in functions:
    for s in ["Fit", "Holdout"]:
        titles.append(f"LLC - {f[0].name} ({s})")

fig = make_subplots(
    rows=len(functions), cols=2,
    subplot_titles=titles,
    horizontal_spacing=0.1,
    vertical_spacing=0.1,
)

# Short function names used as dict keys. The "Shifted" prefix is removed
# explicitly: str.strip("Shifted ") strips any of those *characters* from both
# ends of the name, which can also eat the end of a name.
fnames = {
    f.name: (f.name[len("Shifted"):].lstrip() if f.name.startswith("Shifted") else f.name)
    for f, _ in functions
}

# Predictions keyed by short function name, then task: {fname: {task: y_pred}}.
llc_preds = {name: {} for name in fnames.values()}
llc_holdouts = {name: {} for name in fnames.values()}
report = []

xscale = 1000.  # fits see x = steps / xscale; the vis helpers are given xscale too

for task in tasks:
    llc, loss, steps = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)
    trace = dat.Trace(steps / xscale, llc, steps)
    train, test = dat.split(trace, step_cutoff)
    color = colors[task]

    for f_ind, (function, args) in enumerate(functions):
        fname = fnames[function.name]

        # Full fit (column 1)
        result = fit.min_fit(trace.x, trace.y, function, **args)
        subplot = dict(row=f_ind + 1, col=1)
        vis.plot_data(fig, trace.x, trace.y, color=color, showlegend=False, size=6, xscale=xscale, subplot=subplot)
        vis.plot_result(fig, trace.x, result, name=task, xscale=xscale, color=color, showlegend=f_ind == 0, subplot=subplot)
        fig.update_xaxes(title_text="Step", type="log", **subplot)
        fig.update_yaxes(title_text="LLC", **subplot)

        # Goodness-of-fit on the full trace
        row = {"Dataset": task, "Function": function.name}
        y_pred = result.f(trace.x)
        row.update({k: v(trace.y, y_pred) for k, v in metrics.items()})

        # Holdout fit (column 2): fit on the train split only; the held-out
        # points are drawn in a second colour.
        subplot = dict(row=f_ind + 1, col=2)
        color2 = vis.add_color(color)
        v_result = fit.min_fit(train.x, train.y, function, **args)
        vis.plot_data(fig, train.x, train.y, color=color, showlegend=False, size=6, xscale=xscale, subplot=subplot)
        vis.plot_data(fig, test.x, test.y, color=color2, showlegend=False, size=6, xscale=xscale, subplot=subplot)
        vis.plot_result(fig, trace.x, v_result, name=task, xscale=xscale, color=color, showlegend=False, subplot=subplot)
        fig.update_xaxes(title_text="Step", type="log", **subplot)
        fig.update_yaxes(title_text="LLC", **subplot)

        # Score the holdout fit on the test split, then record the complete row.
        y_pred_v = v_result.f(test.x)
        row.update({k + "_test": v(test.y, y_pred_v) for k, v in metrics.items()})
        report.append(row)

        llc_preds[fname][task] = y_pred
        llc_holdouts[fname][task] = y_pred_v

report = pd.DataFrame(report)

fig.update_layout(
    title="",
    width=1000,
    height=450 * len(functions),
    showlegend=True,
    legend=dict(
        yanchor="middle",
        y=0.5,
        xanchor="right",
        x=1.25
    ),
)

fig.show()
display(report)
# TODO: save LLC vs time

# Step 3: LLC vs loss

In [None]:
# Step 3 (work in progress): relate LLC and loss.
#  - "Direct" rows fit each candidate function straight to the trimmed trace.
#  - "Parametric" rows will combine a loss-vs-step fit (loss_preds) with an
#    LLC-vs-step fit (llc_preds) from the two cells above -- not implemented yet.

# Candidate functions for the direct fit: (function class, kwargs for fit.min_fit).
functions = [
    (fit.ShiftedPowerLaw, dict(rel_noise=True)),
    (fit.ShiftedExponential, dict(rel_noise=True)),
]

# Grid layout: direct fits first, then one row per (loss fn, LLC fn) pair.
titles = []
for f in functions:
    for s in ["Fit", "Holdout"]:
        titles.append(f"Direct - {f[0].name} ({s})")

for loss_name in loss_preds:
    a = loss_name.replace(" ", "").lower()
    for llc_name in llc_preds:
        b = llc_name.replace(" ", "").lower()
        for s in ["Fit", "Holdout"]:
            titles.append(f"Parametric {a}-{b} ({s})")

n_cols = 2
n_rows = len(titles) // n_cols
fig = make_subplots(
    rows=n_rows, cols=n_cols,
    subplot_titles=titles,
    horizontal_spacing=0.1,
    vertical_spacing=0.1,
)

report = []
# Short function names used as dict keys (explicit prefix removal; str.strip
# would remove a character set from both ends instead).
fnames = {
    f.name: (f.name[len("Shifted"):].lstrip() if f.name.startswith("Shifted") else f.name)
    for f, _ in functions
}
# Direct-fit predictions get their own dicts so that llc_preds / llc_holdouts
# (the LLC-vs-step results the parametric rows need) are not overwritten.
direct_preds = {name: {} for name in fnames.values()}
direct_holdouts = {name: {} for name in fnames.values()}

xscale = 1000.

for task in tasks:
    # NOTE(review): `trace` is the raw return value of trim_trace, so this relies
    # on it exposing .x/.y (as used below and by dat.split). Which quantity lands
    # on x vs. y isn't visible here, and the axis titles were copied from the
    # step-fit cells -- confirm both.
    llc, loss, steps = trace = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)
    train, test = dat.split(trace, step_cutoff)
    color = colors[task]

    # Plot the direct functions
    for f_ind, (function, args) in enumerate(functions):
        fname = fnames[function.name]

        # Full fit (column 1)
        result = fit.min_fit(trace.x, trace.y, function, **args)
        subplot = dict(row=f_ind + 1, col=1)
        vis.plot_data(fig, trace.x, trace.y, color=color, showlegend=False, size=6, xscale=xscale, subplot=subplot)
        vis.plot_result(fig, trace.x, result, name=task, xscale=xscale, color=color, showlegend=f_ind == 0, subplot=subplot)
        fig.update_xaxes(title_text="Step", type="log", **subplot)
        fig.update_yaxes(title_text="Loss", **subplot)

        # Goodness-of-fit on the full trace
        row = {"Dataset": task, "Function": function.name}
        y_pred = result.f(trace.x)
        row.update({k: v(trace.y, y_pred) for k, v in metrics.items()})

        # Holdout fit (column 2)
        subplot = dict(row=f_ind + 1, col=2)
        color2 = vis.add_color(color)
        v_result = fit.min_fit(train.x, train.y, function, **args)
        vis.plot_data(fig, train.x, train.y, color=color, showlegend=False, size=6, xscale=xscale, subplot=subplot)
        vis.plot_data(fig, test.x, test.y, color=color2, showlegend=False, size=6, xscale=xscale, subplot=subplot)
        vis.plot_result(fig, trace.x, v_result, name=task, xscale=xscale, color=color, showlegend=False, subplot=subplot)
        fig.update_xaxes(title_text="Step", type="log", **subplot)
        fig.update_yaxes(title_text="LLC", **subplot)

        # Score the holdout fit on the test split, then record the complete row.
        y_pred_v = v_result.f(test.x)
        row.update({k + "_test": v(test.y, y_pred_v) for k, v in metrics.items()})
        report.append(row)

        direct_preds[fname][task] = y_pred
        direct_holdouts[fname][task] = y_pred_v

    # Plot the indirect (parametric) functions.
    # Deliberate work-in-progress stop: halts the cell on the first task.
    assert False, "up to here! plot the indirect matches (how to evaluate them? - interpolate?)"

    for loss_name in loss_preds:
        a = loss_name.replace(" ", "").lower()
        for llc_name in llc_preds:
            b = llc_name.replace(" ", "").lower()

report = pd.DataFrame(report)

fig.update_layout(
    title="",
    width=1000,
    height=400 * n_rows,
    showlegend=True,
    legend=dict(
        yanchor="middle",
        y=0.5,
        xanchor="right",
        x=1.25
    ),
)

fig.show()
display(report)
# TODO: save LLC vs loss

In [None]:
# Step 4: fit a direct LLC-vs-loss relationship for one task, and overlay the
# parametric curves obtained by composing loss(step) and LLC(step) fits.

# Select and trim the task explicitly rather than relying on `task`, `llc`,
# `loss` and `steps` left over from an earlier cell's loop.
task = tasks[0]
llc, loss, steps = dat.trim_trace(df_llc, df_loss, task, step_start, step_end)

# Parametric fits vs. step, in the same kilo-step units and with the same
# function/argument choices as the step-fit cells above.
# NOTE(review): these previously came from deleted/commented-out code; using
# ShiftedPowerLaw2 for loss mirrors the loss cell's first candidate -- confirm.
xscale = 1000.
x = steps / xscale
loss_result = fit.min_fit(x, loss, fit.ShiftedPowerLaw2, rel_noise=True)
result_pwr = fit.min_fit(x, llc, fit.ShiftedPowerLaw, rel_noise=True, par0=[10., -10., .1])
result_log = fit.min_fit(x, llc, fit.ShiftedLogarithm, rel_noise=True)

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=loss, y=llc, mode='markers',
    marker=dict(color='blue', size=8),
    customdata=steps, hovertemplate='Step: %{customdata}<extra></extra>',
    name='Observed points'))

# Direct fits of LLC against loss. Loss is on the x axis here, so these use
# fit.odr_fit rather than fit.min_fit.
d_result = fit.odr_fit(loss, llc, fit.DoubleShiftedPowerLaw, rel_noise=True)
vis.plot_result(fig, loss, d_result, name="Direct (double shifted)")

p_result = fit.odr_fit(loss, llc, fit.XShiftedPowerLaw2, rel_noise=True)
vis.plot_result(fig, loss, p_result, name="Direct (Power Law)")

# Parametric relationship: evaluate both step fits on a finer step grid and
# plot LLC(step) against loss(step).
q = vis.x_plot(x)  # resample at higher resolution
q_loss = loss_result.f(q)

q_llc = result_pwr.f(q)
fig.add_trace(go.Scatter(
    x=q_loss, y=q_llc, mode='lines',
    name='Parametric (pwr-pwr)'))

u_llc = result_log.f(q)
fig.add_trace(go.Scatter(
    x=q_loss, y=u_llc, mode='lines',
    name='Parametric (pwr-log)'))

fig.update_xaxes(title_text="loss")
fig.update_yaxes(title_text="LLC")
fig.update_layout(width=800, height=600, title=f"LLC vs Loss {task}@{msize}")
fig.show()
