# Fitting Power Laws

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pickle

from lsoc.powerlaw import fit, data, vis


# Data

Load the data into two DataFrames:
* `df_llc`
* `df_loss`

Both share the same layout:
* the index is the step
* the columns are the datasets


In [None]:
# Model being analysed. It is only referenced by the optional pickle cache
# snippet below.
model_name = "pythia-410m"

# "410m-dense" is the special case backed by densely spaced checkpoints;
# every other dataset key is backed by sparse checkpoints.
df_llc, df_loss = data.load_dfs("410m-dense")

# Optional: cache the frames so a later session can skip data.load_dfs.
# Save:
#     with open("example.pkl", "wb") as fh:
#         pickle.dump((df_llc, df_loss, model_name), fh)
# Restore (only unpickle files you wrote yourself -- unpickling can execute
# arbitrary code):
#     with open("example.pkl", "rb") as fh:
#         df_llc, df_loss, model_name = pickle.load(fh)

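# Check the analysis interval

Faded traces show each dataset's full data; the highlighted segments are the window that the fits below will use.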

In [None]:
tasks = df_llc.columns
colors = vis.assign_cols(tasks)  # colour lookup, indexed as colors[task] below

# Extra positional bounds handed to data.trim_trace to cut each trace down to
# the fitting window. Presumably (first_step, last_step), since the frames are
# indexed by step -- confirm against data.trim_trace. The fitting cell reuses it.
analysis_interval = (2000, 80000)

fig = go.Figure()

for task in tasks:
    color = colors[task]

    # Whole trace in a faded colour, as background context.
    whole_trace = data.trim_trace(df_llc, df_loss, task)
    vis.plot_data(fig, *whole_trace, color=vis.fade(color), legendgroup=task)

    # Only the analysis window, in full colour and labelled with the dataset
    # name. Both traces share legendgroup=task (presumably forwarded to the
    # plotly trace, so they toggle together in the legend).
    windowed_trace = data.trim_trace(df_llc, df_loss, task, *analysis_interval)
    vis.plot_data(fig, *windowed_trace, color=color, name=task, legendgroup=task)

# plotly's update_* methods return the figure, so the styling calls chain.
(
    fig.update_layout(width=800, height=600)
    .update_xaxes(title_text=vis.llc_desc)
    .update_yaxes(title_text=vis.loss_desc)
    .show()
)

# Fit power laws

In [None]:
fig = go.Figure()
report = []  # one dict of fitted parameters per dataset

for task in tasks:
    color = colors[task]

    # The same windowed trace that was highlighted in the previous cell.
    trace = data.trim_trace(df_llc, df_loss, task, *analysis_interval)
    llc, loss, step = trace

    # Fit loss as an OffsetPowerLaw function of llc over the window.
    result = fit.min_fit(llc, loss, fit.OffsetPowerLaw)

    # Plot transform: None draws raw values. Replace with result.params_dict
    # for Dan's transformed style (it must stay inside the loop because it
    # depends on this task's fit).
    shift = None

    vis.plot_data(fig, *trace, color=color, shift=shift)
    vis.plot_result(
        fig, llc, result,
        color=color, name=task, showlegend=True, shift=shift,
    )

    # "L*" reports the fitted "y*" parameter (loss is the y variable here);
    # fit_r2 measures how well the fitted curve reproduces the observed loss.
    fitted = result.params_dict
    report.append({
        "dataset": task,
        "L*": fitted["y*"],
        "r": fitted["r"],
        "fit_r2": fit.r2_score(loss, result.f(llc)),
    })

(
    fig.update_layout(width=800, height=600)
    .update_xaxes(title_text=vis.llc_desc)
    .update_yaxes(title_text=vis.loss_desc)
    .show()
)

# Last expression of the cell, so the notebook renders the summary table.
pd.DataFrame(report)