# Fit LLC vs Loss, power law vs exponential

Perhaps a grid with combinations of *function* (power law, exponential) × *plot space* (linear, [shifted?] log space).
Then also a table with datasets as rows and columns of $R^2_{FF, Sp}$.
Don't use ODR yet (it's unreliable for some methods).

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import fitting as fit
import data_utils as dat
import pandas as pd
import vis
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
# --- Experiment configuration -------------------------------------------

# Which model checkpoint family to analyse; the size is the part before "-".
m_code = "410m-dense"
msize = m_code.partition("-")[0]

# Only training steps inside this window are passed to the fits.
start_step, end_step = 2000, 80000

# Fitting routine. The ODR fit doesn't work with the polynomial
# (and is not so good with 4 parameters), so min_fit is used instead.
fit_method = fit.min_fit

# Candidate functional forms, one subplot row each.
# Order matters: the first model sets the shift parameters for the others.
functions = [
    fit.DoubleOffsetPowerLaw,
    fit.OffsetPowerLaw,
    fit.OffsetExponential,
    fit.Cubic,
]

# Goodness-of-fit measures, keyed by the column name used in the report.
# RMSE (fit.rmse) is left out on purpose: RMSE and R2 are saying the same
# thing with different normalisation.
metrics = dict(
    R2_log=fit.logspace_r2,
    R2_lin=fit.r2_score,
)

# --- Containers for the collated output ----------------------------------
reports = []   # the set of reports
report = []    # rows of the report for this model
colors = None  # kept consistent across models once assigned

# Load the LLC and loss traces for the chosen model.
df_llc, df_loss = dat.load_dfs(m_code, data_path="data")
msize = msize.partition("-")[0]  # in case it has -sparse or -dense

# One colour per column of df_llc, assigned only the first time round.
if colors is None:
    colors = vis.assign_cols(df_llc.columns)

# Two panels (linear, log) per function; both panels carry the function name.
# (A longer alternative title would be f"Pythia-{msize} {name} ({space})".)
titles = [fn.name for fn in functions for _ in ("linear", "log")]

# Grid: one row per function, linear panel on the left, log panel on the right.
fig = make_subplots(
    rows=len(functions),
    cols=2,
    subplot_titles=titles,
    horizontal_spacing=0.1,
    vertical_spacing=0.1,
)
# `shift` is not initialised here; the per-task loop below sets it from the
# first model's fit.

# For every column (task) of df_llc: fit each candidate function, record its
# metrics in `report`, then draw data and fitted curve into both panels of
# that function's subplot row.
for task in df_llc:
    color = colors[task]
    x, y, s = dat.trim_trace(df_llc, df_loss, task, start_step, end_step)

    # ---- Pass 1: fit and score every candidate function ----
    results = []  # fit result per function
    scores = []   # primary score (linear-space R2) per function

    for f_ind, function in enumerate(functions):
        result = fit_method(x, y, function)

        if f_ind == 0:
            # The first model sets the logspace shift for the plots
            assert function is fit.DoubleOffsetPowerLaw, "Check your axis labels"
            shift = result.params_dict

        # Score the fit with every configured metric.
        y_pred = result.f(x)
        measures = {label: metric(y, y_pred) for label, metric in metrics.items()}

        # Optional extras that could be merged into the row:
        #   result.params_dict, result.pcov_diagnostics()
        report.append({"Dataset": task, "Function": function.name, **measures})
        scores.append(measures["R2_lin"])
        results.append(result)

    # ---- Pass 2: plot, bolding the best-scoring function in the legend ----
    best_score = max(scores, default=None)

    for grid_row, (function, result, score) in enumerate(
        zip(functions, results, scores), start=1
    ):
        score_repr = (
            f"R2=<b>{score:.4f}</b> {task}"  # show the highest
            if score == best_score
            else f"R2={score:.4f} {task}"
        )

        # Plot the results twice: column 1 unshifted, column 2 with the shift.
        for col in (1, 2):
            is_shifted_panel = col == 2
            subplot = dict(row=grid_row, col=col)
            use_shift = shift if is_shifted_panel else None

            vis.plot_data(
                fig, x, y, s,
                color=color, name=task, showlegend=False,
                subplot=subplot, size=5, shift=use_shift,
            )
            # Only the second-column trace gets a legend entry, grouped by function.
            vis.plot_result(
                fig, x, result,
                name=score_repr, color=color, subplot=subplot, shift=use_shift,
                showlegend=is_shifted_panel, legendgroup=function.name,
            )

            fig.update_xaxes(
                title_text=r"$\text{Estimated and transformed LLC }\,\frac{1}{100}\hat{\lambda} - \lambda^*$",
                **subplot,
            )
            y_title = r"$\text{Loss }L$" if col == 1 else r"$\text{Loss }L - L^*$"
            fig.update_yaxes(title_text=y_title, **subplot)

        # A separate legend per subplot row (legend=f"legend{grid_row}" plus a
        # matching fig.update_layout entry) was tried and is not used here.
        
            
from pathlib import Path

# Turn the collected rows (one per dataset/function pair) into a table.
# The default RangeIndex is the same index the rows would get from
# range(len(report)), so no explicit index is passed.
report = pd.DataFrame(report)
report.index.name = "Pythia" + msize

# Figure-level layout: fixed width, height scaled by the number of rows.
fig.update_layout(
    title="",
    width=1200,
    height=410 * len(functions),
    showlegend=True,
    legend_tracegroupgap=180,  # annoying - have to eyeball this
)

# Export the figure. Create the output directory first so that a missing
# "plots/" folder does not make the export fail after all fits have run.
fname = f"plots/fitting_{msize}.pdf"
Path(fname).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(fname)
# fig.show()  # uncomment for interactive inspection
print(f"Done. See {fname}")

In [None]:
Quality of fit metrics (R2, logspace R2 and RMSE) for candidate functional forms on different Pile data subsets for Pythia 410m.