# Fit LLC vs Loss, power law vs exponential

* Use our "preferred model" (currently the three-parameter offset power law).
* Rows: one per model.
* Columns: linear view, log view, and r vs. L* (parameter introspection) with text labels.


In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import fitting as fit
import data_utils as dat
import pandas as pd
import vis
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
# Fit windows per model, kept deliberately narrow for a first pass:
# (model code, first step, last step). The bounds are handed to
# dat.trim_trace; the original note marked them as inclusive.
m_codes = [
    ('14m',        1000,  20000),
    ('31m',        1000,  20000),
    ('70m',        1000,  20000),
    ('160m',       1000,  30000),
    ('410m-dense', 2000, 100000),
    ('1b',         1000,  30000),
]

function = fit.OffsetPowerLaw  # model family fitted in this cell
colors = None                  # task -> colour, built from the first model loaded

# One figure per model: a wide grey context trace, the coloured fit window,
# and the fitted curve, all drawn with `shift` set to the fitted parameters.
for m_code, start_step, end_step in m_codes:
    msize = m_code.partition("-")[0]  # e.g. "410m-dense" -> "410m"
    df_llc, df_loss = dat.load_dfs(m_code, data_path="data")
    tasks = df_llc.columns
    # Assumes the first model has every task, so its colour map covers the rest.
    colors = vis.assign_cols(df_llc.columns) if colors is None else colors

    fig = go.Figure()
    for task in tasks:
        color = colors[task]
        # Context window runs from start_step // 5 up to end_step * 5.
        full = dat.trim_trace(df_llc, df_loss, task, start_step // 5, end_step * 5)
        trace = dat.trim_trace(df_llc, df_loss, task, start_step, end_step)

        # Fit only the window of interest, then draw everything shifted by that fit.
        result = fit.min_fit(trace.x, trace.y, function)
        shift = result.params_dict
        vis.plot_data(fig, *full, color='lightgray', shift=shift, mode='lines+markers')
        vis.plot_data(fig, *trace, color=color, showlegend=True, name=task, shift=shift)
        vis.plot_result(fig, trace.x, result, shift=shift, color=color)

    fig.update_layout(title=msize, width=800, height=600)
    fig.show()

# To save a figure instead of showing it, something like:
# fig.write_image(f"plots/fitting_{msize}.pdf")

# Plot only the input data (no fits) for Pythia-410m


In [None]:
# Select the 410m-dense entry by name rather than by list position, so that
# reordering or extending `m_codes` cannot silently plot a different model
# (a missing entry raises KeyError naming it).
m_code, start_step, end_step = {c[0]: c for c in m_codes}["410m-dense"]
msize = m_code.split("-")[0]
df_llc, df_loss = dat.load_dfs(m_code, data_path="data")
tasks = df_llc.columns

fig = go.Figure()

# First pass: every task over a wide range (trim bounds 32..150000) in grey,
# drawn first so the coloured traces added below render on top of it.
for task in tasks:
    full = dat.trim_trace(df_llc, df_loss, task, 32, 150000)
    vis.plot_data(fig, *full, color='lightgray', mode='lines+markers')

# Second pass: the interval of interest for each task, in its assigned colour.
# `colors` is built in the first cell and must already be populated.
for task in tasks:
    trace = dat.trim_trace(df_llc, df_loss, task, start_step, end_step)
    vis.plot_data(fig, *trace, color=colors[task], mode='lines+markers', name=task, showlegend=True)

fig.update_layout(title=f"Pythia-{msize} interval of interest", width=800, height=600)
fig.update_xaxes(title_text=r"$\text{Estimated and transformed LLC }\,\frac{1}{100}\hat{\lambda}$")
fig.update_yaxes(title_text=r"$\text{Loss }L$")
fname = f"plots/interval_{msize}.pdf"
fig.write_image(fname)
print(f"Done. See {fname}")

# What does the parameter distribution look like?

In [None]:

# Scatter of fitted (r, L*) per task: one subplot per model, two models per row.
# The grid is sized from `m_codes` so adding a model cannot overflow it.
n_cols = 2
n_rows = (len(m_codes) + n_cols - 1) // n_cols  # ceiling division

titles = [
    f"Parameters for Pythia-{m_code.split('-')[0]} on [{start_step}, {end_step}]"
    for m_code, start_step, end_step in m_codes
]
fig = make_subplots(
    rows=n_rows, cols=n_cols,
    subplot_titles=titles,
    horizontal_spacing=0.1,
    vertical_spacing=0.1,
)
fig.update_layout(width=1200, height=410 * n_rows)

for idx, (m_code, start_step, end_step) in enumerate(m_codes):
    msize = m_code.split("-")[0]
    df_llc, df_loss = dat.load_dfs(m_code, data_path="data")
    tasks = df_llc.columns

    # Fill the grid row by row; plotly subplot indices are 1-based.
    grid_row, grid_col = divmod(idx, n_cols)
    subplot = dict(row=grid_row + 1, col=grid_col + 1)

    for task in tasks:
        trace = dat.trim_trace(df_llc, df_loss, task, start_step, end_step)
        result = fit.min_fit(trace.x, trace.y, fit.OffsetPowerLaw)
        pars = result.params_dict
        fig.add_trace(go.Scatter(
            x=[pars["r"]],
            y=[pars["y*"]],
            name=task,
            marker=dict(
                color=colors[task],
                size=10,
            ),
            # One legend entry per task: only the first subplot contributes.
            showlegend=(idx == 0),
            mode='markers+text',
            text=task,
            textposition='top center',
            textfont=dict(size=5),
        ), **subplot)
    fig.update_xaxes(title_text=r"Exponent r", **subplot)
    fig.update_yaxes(title_text=r"Loss Offset L*", **subplot)


fname = "plots/parameters_grid.pdf"
fig.write_image(fname)
print(f"Done. See {fname}")

In [None]:
# Is there any pattern over model size?
# Collect each task's fitted (r, y*) across the models in m_codes[1:]
# (the first model is skipped), then draw one line per task.
from collections import defaultdict

# Keyed lazily by the tasks the models actually contain, so this cell no
# longer depends on a `tasks` variable left over from a previous cell.
task_xs = defaultdict(list)   # task -> fitted exponent r, in model order
task_ys = defaultdict(list)   # task -> fitted offset y*, in model order
task_siz = defaultdict(list)  # task -> model size label, used for hover text

for m_code, start_step, end_step in m_codes[1:]:
    msize = m_code.split("-")[0]
    df_llc, df_loss = dat.load_dfs(m_code, data_path="data")
    tasks = df_llc.columns

    for task in tasks:
        trace = dat.trim_trace(df_llc, df_loss, task, start_step, end_step)
        result = fit.min_fit(trace.x, trace.y, fit.OffsetPowerLaw)
        pars = result.params_dict
        task_xs[task].append(pars["r"])
        task_ys[task].append(pars["y*"])
        task_siz[task].append(msize)


fig = go.Figure()

for task in task_xs:
    if task == "dm_mathematics":
        continue  # excluded: its fits are messed up
    fig.add_trace(go.Scatter(
        x=task_xs[task],
        y=task_ys[task],
        customdata=task_siz[task],
        marker=dict(
            color=colors[task],
            size=6,
        ),
        mode='markers+lines',
        name=task,
        hovertemplate="Model: %{customdata}<br><extra></extra>",
    ))
fig.update_xaxes(title_text="r")
fig.update_yaxes(title_text="y*")
fig.update_layout(title="Fit parameters - Trend 31m to 1B", width=800, height=600)

# NOTE: `msize` here is the size of the last model processed in the loop above.
fname = f"plots/interesting_{msize}.pdf"
fig.write_image(fname)
print(f"Done. See {fname}")
fig.show()

# Final figure grid: all the fits

In [None]:
# Final figures, one PDF per batch of models:
#   Rows: one per model in the batch.
#   Column 1: data and fit on raw axes (no shift).
#   Column 2: the same data and fit with shift=params (y-axis L - L*).
function = fit.OffsetPowerLaw  # our chosen model

batches = [
    ("plots/models1.pdf", m_codes[:3]),
    ("plots/models2.pdf", m_codes[3:]),
]
llc_axis_title = r"$\text{Estimated and transformed LLC }\,\frac{1}{100}\hat{\lambda}$"

for fname, codes in batches:
    n_rows = len(codes)  # size the grid from the batch instead of hard-coding 3

    # Both columns of a row share the same subplot title.
    titles = []
    for m_code, start_step, end_step in codes:
        msize = m_code.split("-")[0]
        titles += [f"Pythia-{msize} on [{start_step}, {end_step}]"] * 2

    fig = make_subplots(
        rows=n_rows, cols=2,
        subplot_titles=titles,
        horizontal_spacing=0.1,
        vertical_spacing=0.1,
    )
    fig.update_layout(
        width=1200,
        height=410 * n_rows,
        legend_tracegroupgap=120,  # annoying - have to eyeball this
    )

    for row, (m_code, start_step, end_step) in enumerate(codes, start=1):
        msize = m_code.split("-")[0]
        df_llc, df_loss = dat.load_dfs(m_code, data_path="data")
        tasks = df_llc.columns

        for task in tasks:
            color = colors[task]
            trace = dat.trim_trace(df_llc, df_loss, task, start_step, end_step)
            result = fit.min_fit(trace.x, trace.y, function)
            # Legend text: fitted parameters, with "y*" renamed to "L*" (to match
            # the axis labels) and the space after each colon removed.
            func_desc = vis.dict2txt(result.params_dict, ",").replace("y*", "L*").replace(": ", ":")
            r2 = fit.r2_score(trace.y, result.f(trace.x))

            for col, shift in ((1, None), (2, result.params_dict)):
                subplot = dict(row=row, col=col)
                vis.plot_data(fig, *trace, color=color, shift=shift, subplot=subplot, showlegend=False)
                vis.plot_result(
                    fig, trace.x, result, color=color, shift=shift, subplot=subplot,
                    showlegend=col == 1,  # one legend entry per task, from column 1
                    legendgroup=msize,
                    name=f"{task} {func_desc}: R2={r2:.4f}",
                )

        # Axis titles depend only on the subplot, so set them once per row
        # rather than once per task.
        for col, y_title in ((1, r"$\text{Loss }L$"), (2, r"$\text{Loss }L - L^*$")):
            fig.update_xaxes(title_text=llc_axis_title, row=row, col=col)
            fig.update_yaxes(title_text=y_title, row=row, col=col)

    fig.write_image(fname)
    print(f"Saved {fname}")