# Fitting Powerlaws

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import fitting as fit
import data_utils as dat
import pandas as pd
import vis
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pickle

# Data

Load your data into two dataframes:
* `df_llc`
* `df_loss`

Both dataframes use the same layout:
* the index is the step
* each column is a dataset


In [None]:
# Choose the model once and derive both the display name and the loader key
# from it, so the two cannot drift apart when switching models.
model_size = "410m"
model_name = f"pythia-{model_size}"

# Notebook-level frames read by the cells below
# (index is step, columns are datasets).
df_llc, df_loss = dat.load_dfs(model_size, data_path="data")



# Confirm the analysis interval

In [None]:
# Plot the full loss-vs-LLC curve for each model size, highlight the candidate
# analysis interval(s) on top of it, and add an offset power-law fit for each
# interval (fits start hidden and can be toggled from the legend).

# Model size -> list of intervals forwarded to dat.trim_trace.
# NOTE(review): intervals look like (first_step, last_step) bounds -- confirm
# against dat.trim_trace.
models = {
    '160m': [(1000, 100000)],
    '1b': [(1000, 20000)], #, (70000, 110000)],
    '1.4b': [(4000, 50000)],
    '2.8b': [(1000, 60000)],
    '6.9b': [(1000, 60000)],
}
colors = vis.assign_cols(models)
task = "full"

fig = go.Figure()

# Only the last model is drawn initially; the others start as "legendonly".
last_model = list(models)[-1]

for model, analysis_intervals in models.items():
    # NOTE(review): this rebinds the notebook-level df_llc/df_loss, so cells
    # run after this one see the last model loaded here -- confirm intended.
    df_llc, df_loss = dat.load_dfs(model, data_path="data")

    visible = True if model == last_model else "legendonly"

    # Look up this model's colour BEFORE deriving the faded variant. The
    # previous order used `color` before it was assigned for this iteration
    # (NameError on a fresh kernel, otherwise a stale colour).
    color = colors[model]
    faded = vis.fade(color)

    # Plot the raw data
    full_curve = dat.trim_trace(df_llc, df_loss, task)
    full_llc = full_curve[0]  # LLC component; not used further in this cell
    vis.plot_data(fig, *full_curve, color=faded, mode="lines+markers",
                  showlegend=False, legendgroup=model, visible=visible)

    for i, interval in enumerate(analysis_intervals):
        analysis_curve = dat.trim_trace(df_llc, df_loss, task, *interval)
        llc, loss, step = analysis_curve
        # Only the first interval of a model gets a legend entry.
        vis.plot_data(fig, *analysis_curve, color=color,
                      showlegend=i == 0, name=model, legendgroup=model, visible=visible)

        # Try fitting some powerlaws?
        result = fit.min_fit(llc, loss, fit.OffsetPowerLaw)
        # Second and later intervals get an "<index>?" suffix in the legend.
        desc = f"{model} powerlaw" + (f"{i}?" if i > 0 else "")
        # expand_llc = dat.trim_trace(df_llc, df_loss, task, interval[0]/3., interval[1]*2.)[0]
        # vis.plot_result(fig, expand_llc, result, color=faded,  showlegend=False, visible="legendonly", legendgroup=desc)
        vis.plot_result(fig, llc, result, color=color, name=desc, showlegend=True, visible="legendonly", legendgroup=desc)


llc_desc = r"$\text{Estimated and transformed LLC }\,\frac{1}{100}\hat{\lambda}$"
loss_desc = r"$\text{Loss }L$"


fig.update_layout(width=800, height=600, title='Loss vs LLC on Full-Pile subset')
fig.update_xaxes(title_text=llc_desc)
fig.update_yaxes(title_text=loss_desc)
fig.show()

# MathJax is loaded from the CDN so the LaTeX axis titles render in the export.
fig.write_html("explore_full.html", include_mathjax="cdn")

In [None]:
# Inspect the most recently computed full curve, i.e. whatever `full_curve`
# was last bound to by a plotting loop (depends on cell execution order).
full_curve

In [None]:
# Per-dataset view: for every dataset column, draw the whole loss-vs-LLC trace
# in a faded colour, then the candidate analysis window on top in full colour.

# Axis titles (LaTeX).
llc_desc = r"$\text{Estimated and transformed LLC }\,\frac{1}{100}\hat{\lambda}$"
loss_desc = r"$\text{Loss }L$"

# Bounds forwarded to dat.trim_trace for the highlighted segment.
analysis_interval = (2000, 80000)

# One trace pair (and one colour) per column of the currently loaded frames.
tasks = df_llc.columns
colors = vis.assign_cols(tasks)

fig = go.Figure()

for task in tasks:
    color = colors[task]

    # Whole trace, de-emphasised.
    full_curve = dat.trim_trace(df_llc, df_loss, task)
    faded = vis.fade(color)
    vis.plot_data(fig, *full_curve, color=faded)

    # Analysis window, emphasised and listed in the legend.
    analysis_curve = dat.trim_trace(df_llc, df_loss, task, *analysis_interval)
    vis.plot_data(fig, *analysis_curve, color=color, showlegend=True)

# The plotly update_* methods return the figure, so the calls can be chained.
(
    fig.update_layout(width=800, height=600)
    .update_xaxes(title_text=llc_desc)
    .update_yaxes(title_text=loss_desc)
    .show()
)

# Fit powerlaws

In [None]:
# Fit an offset power law to each dataset's analysis window, draw the data and
# the fitted curve together, and tabulate the fitted parameters per dataset.
fig = go.Figure()
report = []  # one dict of fit parameters per dataset

for task in tasks:
    color = colors[task]
    trace = dat.trim_trace(df_llc, df_loss, task, *analysis_interval)
    llc, loss, step = trace

    result = fit.min_fit(llc, loss, fit.OffsetPowerLaw)
    shift = None  # set to result.params_dict if you want Dan's transformed style

    vis.plot_data(fig, *trace, color=color, shift=shift)
    vis.plot_result(
        fig, llc, result,
        color=color, name=task, showlegend=True, shift=shift,
    )

    params = result.params_dict
    report_row = {
        "dataset": task,
        "L*": params["y*"],
        "r": params["r"],
        # R^2 of the fitted curve evaluated at llc, scored against the loss.
        "fit_r2": fit.r2_score(loss, result.f(llc)),
    }
    report.append(report_row)

# The plotly update_* methods return the figure, so the calls can be chained.
(
    fig.update_layout(width=800, height=600)
    .update_xaxes(title_text=llc_desc)
    .update_yaxes(title_text=loss_desc)
    .show()
)

# Last expression of the cell: displays the parameter table.
pd.DataFrame(report)