In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from configs import *
import pandas as pd
from data import *
from analysis import *
from paper_figures import *
from paper_tables import *

In [None]:
# Global switch forwarded as `save=` to most of the figure calls below.
save_figs = True

In [None]:
# Read the xz-compressed pickle of experiment results and post-process a copy of it.
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
df = process_big_df(
    pd.read_pickle('data/experiment_results.pickle.xz', compression='xz').copy()
)

In [None]:
# Processed view of the rows whose `hparams` column equals 'sweep'.
df_sweep = process_sweep_df(df.query("hparams=='sweep'"))

In [None]:
# 90 evenly spaced values on [2, 20]; handed to process_sweep_df as `trunc=`.
# `np` is not imported explicitly in this notebook — presumably it arrives via one
# of the star imports at the top; verify.
multipliers = np.linspace(start=2, stop=20, num=90)
df_sweep_mults = process_sweep_df(
    df.query("hparams=='sweep'"),
    trunc=multipliers,
)

## Processing dfs - main analysis

In [None]:
# Main analysis: one summary frame per group of configs.
# Takes about 28 minutes in total, mainly due to loss curve fitting. Skip this cell
# if you want to save time, and load the precomputed results in the next cell instead.
#
# Each frame is pickled immediately after it is computed, so that a failure late in
# the run does not throw away the results that were already finished.
summary_df = perform_main_analysis(df, FIGURE1_CONFIGS)
summary_df.to_pickle('data/summary_df.pickle.xz', compression='xz')

summary_df_att = perform_main_analysis(df, ATTENTION_ACCOUNTING_CONFIGS)
summary_df_att.to_pickle('data/summary_df_att.pickle.xz', compression='xz')

summary_df_kaplan_tuned_hparams = perform_main_analysis(df, [('rw', 'tuned', 'long', 'const', 'kaplan', 'train')])
summary_df_kaplan_tuned_hparams.to_pickle('data/summary_df_kaplan_tuned_hparams.pickle.xz', compression='xz')

summary_df_owt2 = perform_main_analysis(df, FIGURE1_CONFIGS_OWT2)
summary_df_owt2.to_pickle('data/summary_df_owt2.pickle.xz', compression='xz')


In [None]:
# Load the cached summary frames (use this instead of running the main-analysis cell above).
# The order of the paths must match the order of the names on the left-hand side.
summary_df, summary_df_att, summary_df_kaplan_tuned_hparams, summary_df_owt2 = [
    pd.read_pickle(cache_path, compression='xz')
    for cache_path in (
        'data/summary_df.pickle.xz',
        'data/summary_df_att.pickle.xz',
        'data/summary_df_kaplan_tuned_hparams.pickle.xz',
        'data/summary_df_owt2.pickle.xz',
    )
]

# Figures

## Figure 1

In [None]:
# Figure 1, built from the main-analysis summary frame.
figure1(summary_df, save=save_figs)
plt.show()

## Warmup evidence

In [None]:
# Warmup-evidence figure from the main-analysis summary (no explicit plt.show() in this cell).
warm_evidence_figure(summary_df, save=save_figs)

## IsoFLOP curves

In [None]:
# IsoFLOP loss curves for the main summary frame.
# configs_to_show=None presumably selects the helper's default set of configs — verify in the helper.
isoflop_loss_figure(summary_df, save=save_figs, configs_to_show=None)
plt.show()

In [None]:
# IsoFLOP loss curves for the OWT2 summary frame, with an explicit output path and y-limits.
isoflop_loss_figure(
    summary_df_owt2,
    save=save_figs,
    configs_to_show=FIGURE1_CONFIGS_OWT2,
    save_path='../paper/figures/IsoFLOP-curves-owt2.pdf',
    ylim=[2.6, 5.7],
)
plt.show()

## Different datasets and FLOP counts 

In [None]:
# Full-results figure restricted to the OWT2 configs.
full_results_figure(summary_df_owt2, save=save_figs, configs_to_show=FIGURE1_CONFIGS_OWT2)
plt.show()

In [None]:
# Full-results figure for the main summary frame, using the helper's default configs.
full_results_figure(summary_df, save=save_figs)
plt.show()

In [None]:
# Figure built from the summary computed with ATTENTION_ACCOUNTING_CONFIGS.
opt_N_with_attention_figure(summary_df_att,save=save_figs)
plt.show()

In [None]:
# Full-results figure for the single Kaplan-style config (same tuple that was used
# to compute summary_df_kaplan_tuned_hparams).
kaplan_tuned_config = ('rw', 'tuned', 'long', 'const', 'kaplan', 'train')
full_results_figure(
    summary_df_kaplan_tuned_hparams,
    configs_to_show=[kaplan_tuned_config],
    save=save_figs,
    kaplan_adjusted=True,
)
plt.show()

## Accuracy vs. compute

In [None]:
# Varying-compute analysis for a single config; takes about 2 minutes.
config_compute = ('rw', 'tuned', 'short', 'const', 'standard', 'val')
summary_compute = perform_varying_compute_analysis(
    df,
    [2.56e19, 5.76e23],  # presumably the FLOP values/bounds to analyse — confirm in the helper
    config_compute,
)

In [None]:
# Accuracy-vs-compute figure from the varying-compute summary.
# NOTE(review): this is the only figure call with a hard-coded save=False rather than
# save=save_figs — confirm that not saving this figure is intentional.
accuracy_vs_compute_figure(summary_compute, save=False)
plt.show()

## Power laws for loss 

In [None]:
# Loss power-law figure; bootstrap_num=0 presumably disables bootstrapping — verify in the helper.
opt_loss_figure(summary_df, save=save_figs, bootstrap_num=0)
plt.show()

In [None]:
# Extended loss power-law figure, called with bootstrap_num=200.
opt_loss_extended_figure(summary_df, save=save_figs, bootstrap_num=200)
plt.show()

## Hyperparameters sweep results

In [None]:
# Interpolate the sweep results; the returned frame and fit are reused in later cells
# (df_sweep_opt_eta_and_bs feeds the params -> lr / bs maps further down).
df_sweep_opt_eta_and_bs, fit = get_interpolated_hparams_dfs(df_sweep)

# Plot the sweep together with the interpolated values and the fit.
hparams_fit_figure(df_sweep, df_sweep_opt_eta_and_bs, fit, save=save_figs)
plt.show()

In [None]:
# Same hyperparameter-fit pipeline as above, restricted to sweep rows with beta2 == 0.95.
sweep_rows_beta2_095 = df.query("hparams=='sweep'").query('beta2==0.95').copy()
df_sweep_beta2_095 = process_sweep_df(sweep_rows_beta2_095)

df_sweep_beta2_095_opt_eta_and_bs, fit_beta095 = get_interpolated_hparams_dfs(df_sweep_beta2_095)

hparams_fit_figure(
    df_sweep_beta2_095,
    df_sweep_beta2_095_opt_eta_and_bs,
    fit_beta095,
    save=save_figs,
    save_path='../paper/figures/hparams_fit_0.95.pdf',
)
plt.show()


In [None]:
# Full sweep figure, drawn from a pivoted form of the processed sweep frame.
full_sweep_figure(create_pivot_df(df_sweep), save=save_figs)

## Estimation of ideal tuning

In [None]:
# Build params -> lr and params -> bs lookups from the interpolated sweep results.
# NOTE(review): the next cell rebinds both names to hard-coded dicts, so these values
# are not what the downstream analysis ends up using — confirm that is intended.
sweep_param_counts = df_sweep_opt_eta_and_bs['params']
params_lr_map = dict(zip(sweep_param_counts, df_sweep_opt_eta_and_bs['lr']))
params_bs_map = dict(zip(sweep_param_counts, df_sweep_opt_eta_and_bs['bs']))

In [None]:
# Hard-coded per-model hyperparameters, keyed by the value of the `params` column:
#   params -> (bs, lr)
# NOTE(review): these replace the maps derived from df_sweep_opt_eta_and_bs in the
# previous cell — presumably hand-rounded versions of those values; confirm.
tuned_hparams_by_params = {
    5173248: (20, 1.30e-02),
    7503872: (28, 1.15e-02),
    9809920: (32, 1.05e-02),
    15597568: (44, 9.00e-03),
    22487040: (56, 8.00e-03),
    28672000: (64, 5.10e-03),
    37060608: (80, 5.10e-03),
    57384960: (104, 4.70e-03),
    84787200: (128, 5.10e-03),
    108462080: (160, 4.70e-03),
    149045248: (192, 4.30e-03),
    220872704: (256, 3.80e-03),
    347078656: (320, 3.20e-03),
    455311360: (448, 3.00e-03),
    611958784: (512, 2.70e-03),
    901726208: (640, 2.40e-03),
}
params_bs_map = {n_params: bs for n_params, (bs, _) in tuned_hparams_by_params.items()}
params_lr_map = {n_params: lr for n_params, (_, lr) in tuned_hparams_by_params.items()}

# Attach the per-model values as columns (rows whose `params` is not in the table get NaN),
# then run the multiplier analysis on the annotated frame.
df_sweep_mults['lr_star'] = df_sweep_mults['params'].map(params_lr_map)
df_sweep_mults['bs_star'] = df_sweep_mults['params'].map(params_bs_map)
df_sweep_extended, fits_extended = hparams_other_multipliers(df_sweep_mults, multipliers)

In [None]:
# Fits computed from the extended sweep frame (smoothed variant); consumed by the next cell.
fits_N_vs_L_diff = fit_l_star_vs_N_for_M(df_sweep_extended, smoothed=True)

In [None]:
# Combine the LAST row of summary_df with the sweep fits.
# NOTE(review): `.iloc[-1]` depends on the row order of summary_df — confirm which config that row is.
# `preform_...` (sic) is the helper name as it is called here; presumably that is how it is spelled where it is defined.
tuning_excess_df, optimal_pairs, fit_results = preform_analysis_with_sweep_data(summary_df.iloc[-1], fits_N_vs_L_diff)

In [None]:
# Ideal-tuning figure: uses the same last summary row as the analysis cell above,
# plus the varying-compute summary and the extended sweep results.
ideal_tuning_figure(
    summary_df.iloc[-1],
    summary_compute,
    df_sweep_extended,
    tuning_excess_df,
    optimal_pairs,
    fit_results,
    save=save_figs,
    save_path='../paper/figures/ideal_tuning.pdf',
    flop_vals_tuning=[1.25e16, 1.6e18],
)

## Seed variance plot

In [None]:
# Rows whose `hparams` column equals 'seed'.
seed_df = df.query("hparams=='seed'")

In [None]:
# Run the seed-variance analysis, then plot its result.
seed_var_summary = perform_seed_var_analysis(seed_df)
seed_noise_figure(
    seed_var_summary,
    save=save_figs,
    save_path='../paper/figures/seed_noise.pdf',
)

## Train loss curves

In [None]:
# Training loss curves from the full processed frame, using the helper's default configs.
loss_curves_figure(df, save=save_figs, save_path='../paper/figures/loss-curves-rw.pdf')

In [None]:
# Training loss curves for the OWT2 configs, with explicit y-limits.
loss_curves_figure(
    df,
    save=save_figs,
    save_path='../paper/figures/loss-curves-owt2.pdf',
    configs_to_show=FIGURE1_CONFIGS_OWT2,
    ylim=[2.64, 5],
)

# Tables

In [None]:
# Results table over the main and OWT2 summaries stacked together.
combined_summary = pd.concat([summary_df, summary_df_owt2])
results_table_df = results_table(combined_summary, flop_vals=FLOP_VALS, validation='all')

In [None]:
# Extend the results table in place. The statements below depend on their exact order.
# NOTE: this cell is not idempotent — re-running it without first re-running the cell
# that creates results_table_df adds the same rows again.

# Add the first row of the Kaplan-tuned-hparams results table under the label len(results_table_df)
# (i.e. after the existing rows, assuming the default 0..n-1 index — TODO confirm).
results_table_df.loc[len(results_table_df)] = results_table(summary_df_kaplan_tuned_hparams, flop_vals=FLOP_VALS, validation='all').iloc[0]
# Add two reference rows under temporary negative labels; each list has five entries,
# so the table is expected to have five columns.
results_table_df.loc[-1] = ["Kaplan Law", "WebText2", "0.88", "", ""]
results_table_df.loc[-2] = ["Hoffmann Law", "MassiveText", "0.5", "", ""]
# Shift all labels up by 2 (-2 -> 0, -1 -> 1) and sort by label, which puts
# "Hoffmann Law" first and "Kaplan Law" second, ahead of the other rows.
results_table_df.index = results_table_df.index + 2
results_table_df = results_table_df.sort_index()


In [None]:
# Display the assembled results table.
results_table_df

In [None]:
# Tuned-hyperparameters table for the rw / tuned / short-warmup runs, ordered by batch size.
rw_tuned_short_rows = df.query('dataset=="rw" and hparams=="tuned" and warmup=="short"')
tuned_hparams(rw_tuned_short_rows).sort_values('Batch size')

In [None]:
# Architectures table for the same rw / tuned / short-warmup selection used for the tuned-hparams table.
archs_table_df = archs_table(df.query('dataset=="rw" and hparams=="tuned" and warmup=="short"'))

In [None]:
# Display the architectures table.
archs_table_df