# Milestone Plan

## Here:
* Raw data plots of L(t), lambda(t) and L(lambda) (below): loss and LLC vs step as two side-by-side subplots per model, then loss vs LLC in a grid

## fit_stats.ipynb
* tables of fits for each dataset for each model with fit parameters, R^2 in log space and other relevant info
* plots of fits (linear and log) for both models
* (maybe) visualisations of loss landscape

## validation.ipynb
* held-out cross-validation
* Dan's shifted log plots

## investigate_powerlaw.ipynb
* Make a plot and table comparing power law to exp fit on loss(t), show power law is clear winner
* Make a plot and table comparing power law to (shifted?) log for llc(t), show they're both pretty damn good.
* Using those estimates, plot the resulting algebraic expression for loss(llc) under both candidate llc(t) curves (keeping the loss(t) curve fixed). Do they look pretty damn good? (And are they just as good as the original power law fits we were getting?)

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import fitting as fit
import data_utils as dat
import vis
import plotly.graph_objects as go
from plotly.subplots import make_subplots

## Raw Data Plots

In [None]:
# Plot configuration shared by the raw-data cells
msizes = ["14m", "31m", "70m", "160m", "410m", "1b"]  # model sizes to load
tasks = cols = None  # filled in from the first dataframe that gets loaded
normalise = False  # when True, each trace is passed through rescale() before plotting
normalised = "Scaled " if normalise else ""  # prefix used in the subplot titles

# Named subplot positions, so I don't flub these
step_loss = {"row": 1, "col": 1}
step_llc = {"row": 1, "col": 2}

def rescale(v):
    """Return ``v`` min-max rescaled so its values span [0, 1].

    The input is not modified; a new object is returned, so callers can
    write ``x = rescale(x)``.

    Args:
        v: Array-like exposing ``.min()`` / ``.max()`` and elementwise
            arithmetic (e.g. a NumPy array).

    Returns:
        ``(v - v.min()) / (v.max() - v.min())`` as floats. When every
        element of ``v`` is equal the result is all zeros instead of NaN.
    """
    shifted = v - v.min()
    span = shifted.max()
    if span == 0:
        # Constant trace: avoid 0/0 and report a flat line at zero.
        return shifted * 1.0
    # True division also promotes integer input to float.
    return shifted / span
    

# One two-panel figure per model size: loss vs step (left), LLC vs step (right).
# A single big array of subplots would be REALLY big, hence multiple figures.
for msize in msizes:
    desc = f"Pythia-{msize}"

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=(
            f'{desc} {normalised}Loss vs Step',
            f'{desc} {normalised}LLC vs Step',
        ),
        horizontal_spacing=0.1,
    )
    # Both panels use a log-scaled step axis; only the y label differs.
    for panel, ylabel in ((step_loss, "Loss"), (step_llc, "LLC")):
        fig.update_xaxes(title_text="Step", type="log", **panel)
        fig.update_yaxes(title_text=ylabel, **panel)

    df_llc, df_loss = dat.load_dfs(msize, data_path="data")
    if tasks is None:
        # The first model loaded fixes the task list and the per-task colours,
        # which are then reused for every later model.
        tasks = df_llc.columns
        cols = vis.assign_cols(tasks)

    for task, col in cols.items():
        llc, loss, step = dat.trim_trace(df_llc, df_loss, task, 0, 150000)
        if normalise:
            llc, loss = rescale(llc), rescale(loss)

        lineprop = dict(color=col)
        # Everything the two traces of a task have in common.
        shared = dict(
            x=step,
            mode='lines+markers',
            line=lineprop,
            name=f"{task}",
            marker=dict(size=4),
        )

        fig.add_trace(go.Scatter(y=loss, **shared), **step_loss)
        # The LLC trace is kept out of the legend so each task is listed once.
        fig.add_trace(go.Scatter(y=llc, showlegend=False, **shared), **step_llc)

    fig.update_layout(
        title="",
        width=1000,
        height=500,
        showlegend=True,
        legend=dict(yanchor="middle", y=0.5, xanchor="right", x=1.3),
    )

    fig.show()




In [None]:
# Also show loss vs llc (in a grid): one panel per model size, two panels per row
n_rows = (len(msizes) + 1) // 2

names = [f"Pythia-{s}" for s in msizes]

fig = make_subplots(
    rows=n_rows,
    cols=2,
    horizontal_spacing=0.1,
    subplot_titles=names,
)
axtype = "log"  # or "linear"
tasks = None  # autofill
cols = None
showlegend = True
for i, msize in enumerate(msizes):
    # Fill the grid left-to-right, top-to-bottom (plotly rows/cols are 1-based).
    row, col = divmod(i, 2)
    row += 1
    col += 1

    desc = names[i]
    fig.update_xaxes(title_text="LLC", type=axtype, row=row, col=col)
    fig.update_yaxes(title_text="Loss", type=axtype, row=row, col=col)
    fig.update_annotations({"text": desc}, row=row, col=col)

    df_llc, df_loss = dat.load_dfs(msize, data_path="data")

    if i == 0:
        # Make a consistent colouring across tasks
        cols = vis.assign_cols(df_llc.columns)

    for task in df_llc.columns:
        llc, loss, step = dat.trim_trace(df_llc, df_loss, task, 0, 150000)
        if normalise:
            llc, loss = rescale(llc), rescale(loss)

        lineprop = dict(color=cols[task])
        trace = go.Scatter(
            x=llc,
            y=loss,
            mode='lines+markers',
            line=lineprop,
            name=f"{task}",
            marker=dict(size=4),
            showlegend=showlegend,
        )
        fig.add_trace(trace, row=row, col=col)

    # Only the first panel's traces get legend entries.
    showlegend = False

fig.update_layout(
    title="",
    width=1000,
    height=500 * n_rows,
    showlegend=True,
    legend=dict(yanchor="middle", y=0.5, xanchor="right", x=1.3),
)

fig.show()