In [1]:
# Imports
import autograd.numpy as np
from autograd import grad
import pandas as pd
from scipy.optimize import curve_fit, minimize
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import r2_score
from itertools import product
import pandas as pd
# import numpy as np
import matplotlib.pyplot as plt
from os import CLD_CONTINUED
import warnings
import numpy as np
from scipy.optimize import basinhopping
from sklearn.metrics._plot.confusion_matrix import unique_labels
import scipy.interpolate

from scipy.optimize import OptimizeWarning
np.seterr(over='ignore')
np.seterr(invalid='ignore')
warnings.filterwarnings("ignore", category=OptimizeWarning)

In [3]:
# dataset, hparams, warmup, decay, param_count, val = config
# from https://github.com/formll/resolving-scaling-law-discrepancies

FIGURE1_CONFIGS = [
    ('rw', 'base', 'long', 'kaplan', 'kaplan', 'train'),
    ('rw', 'base', 'long', 'kaplan', 'standard', 'val'),
    ('rw', 'base', 'short', 'kaplan', 'standard', 'val'),
    ('rw', 'base', 'short', 'chinchilla', 'standard', 'val'),
    ('rw', 'tuned', 'short', 'const', 'standard', 'val')
]

CONFIG_DICT_LABEL = {
    ('rw', 'base', 'long', 'kaplan', 'kaplan', 'train'): 'Reproducing Kaplan et al.',
    ('rw', 'base', 'long', 'kaplan', 'standard', 'val'): 'Counting last layer FLOPs',
    ('rw', 'base', 'short', 'kaplan', 'standard', 'val'): 'Correcting warmup',
    ('rw', 'base', 'short', 'chinchilla', 'standard', 'val'): 'Cosine decay', # original Chinchilla?
    ('rw', 'tuned', 'short', 'const', 'standard', 'val'): 'Optimizer tuning (no decay)',
    ('rw', 'tuned', 'short', 'const', 'standard', 'train'): 'Optimizer tuning (no decay) - train',
    ('rw', 'tuned', 'long', 'const', 'kaplan', 'train'): '', # kaplan tuned
    ('rw', 'base', 'long', 'kaplan', 'attention', 'train'): 'Counting last layer\nand attention FLOPs',
    ('rw', 'base', 'short', 'kaplan', 'attention', 'train'): 'Correcting warmup',
    ('rw', 'base', 'short', 'chinchilla', 'attention', 'train'): 'Cosine decay',
    ('rw', 'tuned', 'short', 'const', 'attention', 'train'): 'Optimizer tuning (no decay)',
}

ISOFLOP_ARGS = {
    ('kaplan', 'train'): dict(loss_key='train/loss_smoothed', flop_per_token_key='flops_per_token_no_att_no_embed', n_key='params_no_embed'),
    ('standard', 'val'):  dict(loss_key='val/loss', flop_per_token_key='flops_per_token', n_key='params'),
    ('standard', 'train'): dict(loss_key='train/loss_smoothed', flop_per_token_key='flops_per_token', n_key='params'),
    ('attention', 'train'): dict(loss_key='train/loss_smoothed', flop_per_token_key='flops_per_token_att', n_key='eff_params_att'),
}

# LABEL_TO_CONFIG_DICT = {
#     n: c for c, n in CONFIG_DICT_LABEL.items()
# }

def fetch_flop(df, flop, loss_key='train/loss_smoothed', warmup_remove_factor=1e-12, n_key='params', 
               seq_len=2048, bs_key='bs', keep_bs_lr_keys=False,
               flop_per_token_key='flops_per_token', flop_tolerance=0.1):
    out = []
    for _, row in df.iterrows():
        if len(row[loss_key]) == 0:
            continue
        loss_vals = row[loss_key].dropna().groupby(level=0).mean().sort_index()
        step_vals = loss_vals.index
        mask = step_vals >= ((warmup_remove_factor * row.warmup_tokens) / row.bs / row.seq_len)
        loss_vals = loss_vals[mask]
        loss_vals.index = loss_vals.index.astype(float) * seq_len * row[bs_key] * row[flop_per_token_key]
        flop_vals = loss_vals.index
        
        if len(loss_vals) == 0:
            continue        
        flop_ind = loss_vals.index.searchsorted(flop)
        if flop_ind > 0:
            flop_ind += -1 + np.abs(np.log(flop_vals[flop_ind-1:flop_ind+1]/flop)).argmin()
        rel_err = np.exp(np.abs(np.log(flop_vals[flop_ind]/flop))) - 1
        if rel_err > flop_tolerance:
            continue

        if len(flop_vals) > 1:
            flop_slice = flop_vals[max(0,flop_ind-5):flop_ind+5]
            loss_slide = loss_vals.iloc[max(0,flop_ind-5):flop_ind+5]
            loss_interp = np.exp(np.interp(np.log(flop), np.log(flop_slice), np.log(loss_slide)))
            out.append(dict(n=row[n_key], t=flop / row[flop_per_token_key], loss=loss_interp))
        else:
            out.append(dict(n=row[n_key], t=loss_vals.index[flop_ind] / row[flop_per_token_key], loss=loss_vals.iloc[flop_ind]))
        if keep_bs_lr_keys:
            out[-1].update({k: row[k] for k in [bs_key, 'lr']})

    return pd.DataFrame(out)


# def power_law_fit(df, x, y, weighted=False):
#     if isinstance(y, (list, tuple)):
#         out = {}
#         for yy in y:
#             if 'loss' not in yy:
#                 out.update(power_law_fit(df, x, yy, weighted=weighted))
#             else:
#                 df = df.copy()
#                 out.update(fit_loss_with_saturation(df, weighted=weighted))
#         return out
#     else:
#         X_data = np.log(df.dropna()[x].values).reshape(-1, 1)
#         y_data = np.log(df.dropna()[y].values)
#         std_key = f'{y}_star_std'
#         if weighted and std_key in df.columns:
#             y_data_std = df.dropna()[std_key].values
#             w = 1 / y_data_std ** 2
#         else:
#             w = None

#         clf = LinearRegression().fit(X_data, y_data, sample_weight=w)
#         return {f'{y}_exponent': clf.coef_.item(),
#                 f'{y}_coef': np.exp(clf.intercept_),
#                 f'{y}_r2': clf.score(X_data, y_data)}


# def fit_compute_optimal_power_laws(optimal_pairs, bootstrap_data, bootstrap_num=None, bootstrap_num_loss=200, fit_loss=True):
#     keys_to_fit = ['n', 't', 'multiplier']
#     if fit_loss:
#         keys_to_fit.append('loss')
#     out = {'basic': power_law_fit(optimal_pairs.reset_index(), 'flops', keys_to_fit),
#            'weighted': power_law_fit(optimal_pairs.reset_index(), 'flops', keys_to_fit, weighted=True)}
#     bootstrap_samples = bootstrap_data.dropna().set_index('flops')[
#         ['n_stars', 't_stars', 'multiplier_stars', 'loss_stars', 'n_star_std', 't_star_std', 'loss_star_std']].rename(
#         columns=lambda x: x.replace('_stars', ''))
#     if bootstrap_num is None:
#         bootstrap_num = bootstrap_samples[['n', 't', 'multiplier']].applymap(len).min().min()

#     for name, is_weighted in dict(bootstrap=False, bootstrap_weighted=True).items():
#         bs_smaples_arr = [
#             power_law_fit(bootstrap_samples.applymap(lambda x: maybe_get_item(x, i)).reset_index(),
#             'flops', ['loss'], weighted=is_weighted)
#             for i in range(bootstrap_num_loss)
#             ] if fit_loss else []
#         bs_smaples_arr.extend([power_law_fit(
#             bootstrap_samples.applymap(lambda x: maybe_get_item(x, i)).reset_index(),
#             'flops', ['n', 't', 'multiplier'], weighted=is_weighted)
#             for i in range(bootstrap_num)])
#         out[name] = bs_smaples_arr
#     bootstrap_medians = bootstrap_samples.applymap(np.median)
#     out.update({
#         'bs_median': power_law_fit(bootstrap_medians.reset_index(), 'flops', keys_to_fit),
#         'bs_median_weighted': power_law_fit(bootstrap_medians.reset_index(), 'flops', keys_to_fit, weighted=True)})
#     return out


def get_noise_for_loss(loss, bootstrap_iters, noise_low=0.005, noise_high=0.1, l_threshold_high=6, l_threshold_low=3):
    basic_noise = np.random.normal(0, 1, (bootstrap_iters, len(loss) // bootstrap_iters))
    noise_adjusted_losses = np.zeros((bootstrap_iters, len(loss) // bootstrap_iters))

    for i in range(len(loss) // bootstrap_iters):
        if np.log(loss[i]) >= l_threshold_high:
            log_noise = np.log(noise_high)
        elif np.log(loss[i]) <= l_threshold_low:
            log_noise = np.log(noise_low)
        else:
            log_noise = np.interp(np.log(loss[i]), [np.log(l_threshold_low), np.log(l_threshold_high)], [np.log(noise_low), np.log(noise_high)])
        noise_factor = np.exp(log_noise)
        noise_adjusted_losses[:, i] = loss[i] + noise_factor * basic_noise[:, i]
        
    return noise_adjusted_losses.flatten()


def vectorized_interp_with_seed_noise(df, n_interp_, bootstrap_iters, seed_noise=None,
                                      min_std_factor=0.33, tok_or_n='n'):
    if seed_noise is None:
        seed_noise = {}
    interp_num = len(n_interp_)
    stacked_df = pd.concat([df] * bootstrap_iters).reset_index(drop=True)
    stacked_df['loss'] = get_noise_for_loss(stacked_df.loss, bootstrap_iters=bootstrap_iters, **seed_noise)

    batch_ids = np.repeat(np.arange(bootstrap_iters), len(df))
    stacked_df['batch_id'] = batch_ids
    stacked_df.sort_values(by=['batch_id', tok_or_n], inplace=True)

    def batch_interp(batch):
        interp = scipy.interpolate.Akima1DInterpolator(np.log(batch[tok_or_n]), np.log(batch['loss']))
        return np.exp(interp(np.log(n_interp_)))

    interpolated_values = stacked_df.groupby('batch_id').apply(batch_interp)

    # Find the index of the minimum interpolated loss value per batch
    min_indices = interpolated_values.apply(np.argmin)
    results = [n_interp_[idx] if idx != 0 and idx != interp_num - 1 else None for idx in min_indices]
    valid_results_loss = [interpolated_values[i][idx] for i, idx in enumerate(min_indices) if idx != 0 and idx != interp_num - 1]
    # Filter None values and calculate statistics
    valid_results = [result for result in results if result is not None]
    if len(valid_results) < bootstrap_iters // 2:
        return None, 0, None, None, None
    else:
        n_star_std_ = np.std(np.log(valid_results))
        min_std = min_std_factor * np.log(n_interp_[1] / n_interp_[0])  # this assumes a roughly uniform grid
        n_star_std_ = max(n_star_std_, min_std) * (bootstrap_iters / len(valid_results))
        loss_star_std_ = np.std(np.log(valid_results_loss))
        min_std_loss = min_std_factor * min([np.log(df.loss.iloc[i+1] / df.loss.iloc[i]) for i in range(len(df) - 1)])
        loss_star_std_ = max(loss_star_std_, min_std_loss) * (bootstrap_iters / len(valid_results_loss))
        return n_star_std_, None, valid_results, valid_results_loss, loss_star_std_


def interpolation(df_, interp_num, bootstrap_iters, seed_noise, min_std_factor, interp_num_multiplier, std_method, col):
    interp_ = np.geomspace(df_[col].min(), df_[col].max(), interp_num)
    df_ = df_.sort_values(col)
    interpolator = scipy.interpolate.Akima1DInterpolator(np.log(df_[col]), np.log(df_.loss))
    loss_interp_ = np.exp(interpolator(np.log(interp_)))
    star_ind_ = loss_interp_.argmin()

    if std_method == 'add_seed_noise':
        star_std_, _, noised_stars_, noised_loss, loss_star_std = vectorized_interp_with_seed_noise(
            df_, interp_, bootstrap_iters, seed_noise, min_std_factor * interp_num_multiplier, tok_or_n=col)
    else:
        star_std_ = None
        noised_stars_ = []

    return star_ind_, star_std_, noised_stars_, interp_, loss_interp_, noised_loss, loss_star_std


def interp_flop(big_df, loss_key, flop_vals=[8e16, 3e17, 6e17, 3e18, 6e18, 1e19], groupby_action='min',
                warmup_remove_factor=1e-12,
                interp_num_multiplier=25,
                n_key='params', n_star_std_method='add_seed_noise', t_star_std_method='add_seed_noise',
                bootstrap_iters=1000,
                min_std_factor=0.33,
                seed_noise=None, flop_tolerance=0.1,
                flop_per_token_key='flops_per_token',
                bs_median_as_obs=True,
                keep_bs_lr_keys=False,
                ):
    out = []
    optimal_pairs = []
    max_loss, min_loss = 0, 1e12

    for c in flop_vals:
        df_ = fetch_flop(big_df, c, loss_key=loss_key, 
                         warmup_remove_factor=warmup_remove_factor, n_key=n_key, 
                         flop_per_token_key=flop_per_token_key,
                         flop_tolerance=flop_tolerance, keep_bs_lr_keys=keep_bs_lr_keys)

        if len(df_) < 3:
            out.append(dict(n_interp=None, loss_interp=None, t_interp=None, 
                            loss_interp_tok=None, opt_ind=None, opt_tok_ind=None, flops=c))
            continue
        if 'bs' in df_.columns and 'lr' in df_.columns:
            df_sweep_opt_eta = df_.groupby(['n','bs']).apply(minimize_with_interp).drop(['bs', 'n'], axis=1).reset_index()
            df_sweep_opt_eta_and_bs = df_sweep_opt_eta.groupby(['n']).apply(lambda x: minimize_with_interp(x, x_key='bs')).drop('n', axis=1).reset_index()
            df_ = df_sweep_opt_eta_and_bs[['n']]
            df_['t'] 
            print(df_.iloc[0].loss)
        elif groupby_action == 'min':
            df_ = df_.loc[df_.groupby(['n']).loss.idxmin()]
        elif groupby_action == 'mean':
            df_ = df_.groupby('n').mean()
        else:
            raise ValueError(f'Unknown groupby_action {groupby_action}')
        df_ = df_.reset_index()

        interp_num = (len(df_) - 1) * interp_num_multiplier

        max_loss, min_loss = max(max_loss, df_.loss.max()), min(min_loss, df_.loss.min())
        
        n_star_ind_, n_star_std_, noised_n_stars_, n_interp_, loss_interp_, noised_loss, loss_star_std = interpolation(
            df_, interp_num, bootstrap_iters, seed_noise, min_std_factor, interp_num_multiplier, n_star_std_method, 'n')

        t_star_ind_, t_star_std_, noised_t_stars_, t_interp_, loss_interp_tok_, noised_loss, _ = interpolation(
            df_, interp_num, bootstrap_iters, seed_noise, min_std_factor, interp_num_multiplier, t_star_std_method, 't')
        
        if n_star_ind_ != 0 and n_star_ind_ != interp_num -1 and noised_n_stars_ is not None:
            optimal_pairs.append(
                dict(flops=c, n=n_interp_[n_star_ind_], t=t_interp_[t_star_ind_], multiplier=c / 6 / (n_interp_[n_star_ind_]**2),
                     loss=loss_interp_.min(), loss_t=loss_interp_tok_.min(),
                     n_vals=df_.n.values, t_vals=df_.t.values, loss_vals=df_.loss
                    )
            )
        else:
            optimal_pairs.append(
                dict(flops=c, n=None, t=None, loss=None, loss_t=None,
                        n_vals=df_.n.values, t_vals=df_.t.values, loss_vals=df_.loss
                    )
            )
        out.append(
            dict(n_interp=n_interp_, loss_interp=loss_interp_, 
                 t_interp=t_interp_, loss_interp_tok=loss_interp_tok_, 
                 opt_ind=n_star_ind_, opt_tok_ind=t_star_ind_, flops=c, 
                 orig_n=df_.n, orig_t=df_.t, orig_loss=df_.loss)
            )
        if n_star_std_method == 'add_seed_noise':
            out[-1]['n_star_std'] = n_star_std_
            out[-1]['n_stars'] = noised_n_stars_
            out[-1]['multiplier_stars'] = (c / (6 * np.array(noised_n_stars_)**2)) if noised_n_stars_ is not None else None
            optimal_pairs[-1]['n_star_std'] = n_star_std_ 
            
            out[-1]['multiplier_star_std'] = 2 * n_star_std_ if n_star_std_ is not None else None
            optimal_pairs[-1]['multiplier_star_std'] = 2 * n_star_std_ if n_star_std_ is not None else None

            out[-1]['t_star_std'] = t_star_std_
            out[-1]['t_stars'] = noised_t_stars_
            optimal_pairs[-1]['t_star_std'] = t_star_std_

            out[-1]['loss_stars'] = noised_loss
            out[-1]['loss_star_std'] = loss_star_std 
            optimal_pairs[-1]['loss_star_std'] = loss_star_std

    out_df = pd.DataFrame(out)
    optimal_pairs_df = pd.DataFrame(optimal_pairs)

    if bs_median_as_obs:
        for ind, row in optimal_pairs_df.iterrows():
            if row['n'] is None or np.isnan(row['n']):
                continue
            flop = row['flops']
            data_row = out_df.set_index('flops').loc[flop]
            for key in ['n', 't', 'multiplier', 'loss']:
                optimal_pairs_df.at[ind, key] = np.median(data_row[key + '_stars']) if data_row[key + '_stars'] is not None else row[key]

    return out_df, optimal_pairs_df, max_loss, min_loss


In [4]:


def get_rsld_data(config_number=0):
    big_df = pd.read_pickle('/fsx-onellm/margaretli/code/scaling_tomer/resolving-scaling-law-discrepancies/data/experiment_results.pickle.xz', compression='xz')
    big_df = process_big_df(big_df.copy())
    config = FIGURE1_CONFIGS[config_number]
    dataset, hparams, warmup, decay, param_count, val = config
    df = big_df.query(f"dataset=='{dataset}' and hparams=='{hparams}' and warmup=='{warmup}' and decay=='{decay}'")
    
    for loss_key in ['val/loss', 'train/loss']:
        for ind, row in df.iterrows():
            row[f"{loss_key}"].index = row[f"{loss_key}"].index.astype(float) * row['seq_len'] * row['bs']
            df.loc[ind, f'last_{loss_key}_C'] = row[loss_key].index[-1] * row['flops_per_token']
            df.loc[ind, f'last_{loss_key}_D'] = row[loss_key].index[-1] 
            df.loc[ind, f'last_{loss_key}'] = row[loss_key].iloc[-1]


    return df


def get_rsld_data_interp(config_number=0,
                          flop_vals=None,
                          seed=42, seed_noise_args=None, 
                          keep_bs_lr_keys=False
                          ):
    big_df = pd.read_pickle('/fsx-onellm/margaretli/code/scaling_tomer/resolving-scaling-law-discrepancies/data/experiment_results.pickle.xz', compression='xz')
    big_df = process_big_df(big_df.copy())
    config = FIGURE1_CONFIGS[config_number]
    
    np.random.seed(seed)

    if flop_vals is None:
        flop_vals = FLOP_VALS
    if seed_noise_args is None:
        seed_noise_args = SEED_ARGS
    df = big_df.copy()
    out = []
    dataset, hparams, warmup, decay, param_count, val = config
    show_df = df.query(f"dataset=='{dataset}' and hparams=='{hparams}' and warmup=='{warmup}' and decay=='{decay}'")

    if len(show_df) == 0:
        continue
    data, optimal_pairs, max_loss, min_loss = interp_flop(
        show_df, seed_noise = seed_noise_args[config], 
        flop_vals=flop_vals, **ISOFLOP_ARGS[config[-2:]],
        keep_bs_lr_keys=keep_bs_lr_keys,
    )

    
    # if len(show_df) == 0:
    #     continue
    
    # data, optimal_pairs, max_loss, min_loss = interp_flop(
    #     show_df, seed_noise = seed_noise_args[config], 
    #     flop_vals=flop_vals, **ISOFLOP_ARGS[config[-2:]],
    #     keep_bs_lr_keys=keep_bs_lr_keys,
    # )

    # fit_results = fit_compute_optimal_power_laws(optimal_pairs, data, fit_loss=True)

#         out.append(dict(dataset=dataset, hparams=hparams, warmup=warmup, decay=decay, param_count=param_count, val=val, 
#                         optimal_pairs=optimal_pairs, fit_results=fit_results,
#                         data=data, max_loss=max_loss, min_loss=min_loss,))
# return pd.DataFrame(out)

In [5]:


def scaling_law_chinch(N, D, params):
    if len(params) == 6:
        params = params[:-1]
    a, b, e, alpha, beta = params
        
    A = np.exp(a)
    B = np.exp(b)
    E = np.exp(e)
    
    L = E + (A / (N**alpha)) + (B /(D**beta))
    
    return L

def scaling_law_kaplan(N, D, params):
    if len(params) == 6:
        params = params[:-1]
    a, b, e, alpha, beta = params
        
    A = np.exp(a)
    B = np.exp(b)
    E = np.exp(e)
    
    L = ((A/N)**(alpha/beta) + B/D)**beta
    
    return L


# def scaling_law_kaplan_chinch(N, D, params):
#     if len(params) == 6:
#         params = params[:-1]
#     a, b, e, alpha, beta = params
        
#     A = np.exp(a)
#     B = np.exp(b)
#     E = np.exp(e)
    
#     L = E + ((A/N)**(alpha/beta) + B/D)**beta
    
#     return L


def opt_N_D(C, G, opt_a, opt_b):
    opt_N = G*(C/6)**opt_a
    opt_D = (1/G)*(C/6)**opt_b
    return opt_N, opt_D

def print_opts(best_params):
    if len(best_params) == 6:
        best_params = best_params[:-1]
    opt_alpha = best_params[-2]
    opt_beta = best_params[-1]

    opt_a =  opt_beta / (opt_alpha+opt_beta)
    opt_b =  opt_alpha / (opt_alpha+opt_beta)

    A = np.exp(best_params[0])
    B = np.exp(best_params[1])
    G = ((opt_alpha*A)/(opt_beta*B))**(1/(opt_alpha+opt_beta))

    scaling = []

    for C in [1.25E+18, 5.01E+18, 1.98E+19, 1E21, 1E23]:
        N, D = opt_N_D(C, G, opt_a, opt_b)
        scaling.append(
            {"compute": f"{C:e}",
            "parameters (B)": f"{N/1e9:.2f}",
            "tokens (B)": f"{D/1e9:.2f}",
            "ratio": f"{D/N:.2f}",
            "predicted loss (Chinchilla)": f"{scaling_law_chinch(N, D, best_params):.2f}"
            "predicted loss (Kaplan)": f"{scaling_law_kaplan(N, D, best_params):.2f}"
            }
        )
    print("Scaling: \n", pd.DataFrame(scaling))
        

In [6]:
# from https://github.com/formll/resolving-scaling-law-discrepancies
def precise_flops_per_token_chinchilla(width, depth):
    seq_len = 2048
    vocab_size = 50432
    num_heads = 4
    width = width.astype(float)
    depth = depth.astype(float)

    embeddings = 2 * seq_len * width

    attention = 2 * 3 * seq_len * (width ** 2)
    kq_logits = 2 * seq_len * seq_len * width
    softmax = 3 * num_heads * seq_len * seq_len
    softmax_q_red = 2 * seq_len * seq_len * width
    final_linear = 2 * seq_len * (width ** 2)
    attention += kq_logits + softmax + softmax_q_red + final_linear

    ffw_size = 4 * width # check this, in the paper it is 4 * width
    dense_block = 4 * seq_len * width * ffw_size
    final_logits = 2 * seq_len * width * vocab_size
    forward_pass = embeddings + depth * attention + depth * dense_block + final_logits
    backward_pass = 2 * forward_pass
    return (forward_pass + backward_pass) / seq_len
    
def precise_param_count_open_lm(width, depth, vocab_size=50432):
    d_ff = 256 * (((2 * 4 * width / 3).astype(int) + 256 - 1) // 256)
    return (4 * width + 3 * d_ff) * width * depth + vocab_size * width
    
def apply_smoothing_filter(df, filter_func, compensate_for_logging_delay=True, key='train/loss', **filter_args):
    out = []
    for _, row in df.iterrows():
        if len(row[key]) == 0:
            out.append(None)
            continue
        filtered = filter_func(row[key].dropna(), **filter_args)
        if compensate_for_logging_delay:
            filtered.index = filtered.index - np.diff(filtered.index, prepend=0)/2
        out.append(filtered)
    return out

def proportional_sliding_window_filter(x, p=0.05):
    # assert that the index of x has constant increments?
    x_cumsum = x.cumsum().values
    x_cumsum_pad = np.concatenate([[0], x_cumsum])
    inds = np.arange(len(x))
    inds_up = np.minimum(inds + np.floor(p * inds).astype(int), len(x)-1)
    inds_down = np.maximum(0, inds - np.floor(p * inds).astype(int))
    inds_new = (inds_up + inds_down)/2
    index_new = np.interp(inds_new, inds, x.index)
    try:
        x_series = pd.Series((x_cumsum[inds_up] - x_cumsum_pad[inds_down]) / (inds_up - inds_down+1), 
                     index=index_new, name=x.name + '_smoothed')
    except:
        print(x)
        x_series = pd.Series((x_cumsum[inds_up] - x_cumsum_pad[inds_down]) / (inds_up - inds_down+1), 
                     index=index_new, name="" + '_smoothed')
    return x_series
    
def process_big_df(big_df):
    big_df = big_df.copy()

    # Counting parameters
    big_df['params_active'] = (12 * (big_df.width**2) * big_df.depth + big_df.vocab_size * big_df.width).astype(float)
    big_df['params_active_precise'] = precise_param_count_open_lm(big_df.width, big_df.depth)
    big_df['params_no_embed'] = precise_param_count_open_lm(big_df.width, big_df.depth, vocab_size=0)
    big_df['params_all'] = 12 * (big_df.width**2) * big_df.depth + (big_df.seq_len + 2 * big_df.vocab_size) * big_df.width
    
    # Counting FLOPs
    big_df['flops_per_token_att_no_embed'] = 6 * big_df['params_no_embed'] + 6 * big_df.seq_len * big_df.width * big_df.depth
    big_df['flops_per_token_att'] = 6 * big_df['params_active_precise']  + 6 * big_df.seq_len * big_df.width * big_df.depth
    big_df['flops_per_token_cc'] = precise_flops_per_token_chinchilla(big_df['width'], big_df['depth'])
    big_df['flops_per_token_no_att'] = 6 * big_df['params_active_precise']
    big_df['flops_per_token_no_att_no_embed'] = 6 * big_df['params_no_embed']
    big_df['flops_per_token'] = big_df['flops_per_token_no_att']

    big_df['params'] = big_df['flops_per_token'] / 6 
    big_df['eff_params_att'] = big_df['flops_per_token_att'] / 6


    big_df['train/loss_smoothed'] = apply_smoothing_filter(big_df, proportional_sliding_window_filter, compensate_for_logging_delay=True, key='train/loss')
    for k in big_df: 
        if k.startswith('train/') and k.endswith('_loss'):
            big_df[k + '_smoothed'] = apply_smoothing_filter(big_df, proportional_sliding_window_filter, compensate_for_logging_delay=False, key=k)
    return big_df

In [2]:
DEFAULT_PROJECT_DIR = "/fsx-onellm/margaretli/env_srcs/xlf/xlformers_n/scaling/data"

col_names = ['C', 'D', 'N', 'lr', 'Avg Train Loss', 'Max Train Loss', 'C4 Eval PPL', 'Wiki Eval PPL', 'C4 Eval Loss', 'Wiki Eval Loss']

def read_local_data(csv_file, loss_name='C4 Eval Loss', col_names=col_names):
    mins_only = []
    df = pd.read_csv(csv_file, usecols=col_names,)
    df.dropna(subset=[loss_name], inplace=True)

    if 'lr' in df.columns:
        df = df.loc[(df['lr'] >= 0)]
    if 'D' not in df.columns:
        df['D'] = df['C'] / (df['N'] * 6)
        
    n_vals = df['N'].unique()
    d_vals = sorted(df['D'].unique())
    for n in n_vals:
        for d in d_vals:
            cd_df = df[(df['N'] == n) & (df['D'] == d)]
            if cd_df.empty:
                continue
            min_index = cd_df[loss_name].idxmin()
            # print(min_index)
            # print(cd_df)
            # print(cd_df.loc[min_index])
            mins_only.append(cd_df.loc[min_index])

    mins_only_df = pd.DataFrame(mins_only)

    print(mins_only_df)
    # df.rename(columns={})
    return mins_only_df

In [7]:
def get_data(use_data, rsld_config_number=0):
    local_dir = os.path.dirname(os.path.realpath(__file__))
    
    n_name = "N"
    d_name = "D"
    
    if use_data == "ours":
        loss_name = "C4 Eval Loss"
        csv_file = f"{local_dir}/data/data.csv"
        training_df = read_local_data(csv_file=csv_file, loss_name=loss_name, col_names=None)
    elif use_data == "epoch_ai":
        loss_name = "loss"
        csv_file = f"{local_dir}/data/epoch_ai.csv"
        training_df = read_local_data(csv_file=csv_file, loss_name=loss_name, col_names=None)
    elif use_data == "rsld":
        n_name = "params_no_embed"
        d_name = "last_val/loss_D"
        loss_name = "last_val/loss"
        training_df = get_rsld_data(config_number=0)
    elif use_data == "rsld_isoflop":
        n_name = "params_no_embed"
        d_name = "last_val/loss_D"
        loss_name = "last_val/loss"
        training_df = get_rsld_data_isoflop(config_number=0)
    else
        raise Error("type of data to use not recognized")

    N = training_df[n_name].values
    D = training_df[d_name].values
    losses = training_df[loss_name].values
    bootstraps = 4000
    nr_of_models_excluded = 0

    sorted_losses = sorted(losses)
    if nr_of_models_excluded == 0:
        indices = list(range(len(N)))
    else:
        sorted_losses = sorted(losses)
        indices = [i for i in range(len(N)) if losses[i] < sorted_losses[-nr_of_models_excluded]]

    np.random.seed(42)
    random_indices = [np.random.choice(indices, size=len(indices), replace=True) for _ in range(bootstraps)]

    return N, D, losses, indices, random_indices

In [8]:
import autograd.numpy as np
from autograd.scipy.stats import norm
from scipy.optimize import minimize
from scipy.special import erf

true_params = np.array([np.log(406.4), np.log(410.7), np.log(1.69), 0.34, 0.28])

# Define the log-sum-exp function
def log_sum_exp_chinch(a, b, e, alpha, beta, N, D):
    return np.log(np.exp(a - alpha * np.log(N)) + np.exp(b - beta * np.log(D)) + np.exp(e))

def log_sum_exp_kaplan(a, b, e, alpha, beta, N, D):
    return np.log(np.exp((alpha / beta) * (a - np.log(N))) + np.exp(b - np.log(D))) * beta

## TODO @margaretli
def log_sum_exp_kaplan_chinch(a, b, e, alpha, beta, N, D):
    return np.log(np.exp((alpha / beta) * (a - np.log(N))) + np.exp(b - np.log(D)) + np.exp(e)) * beta

# Define the Huber loss function
def custom_huber_loss(y_true, y_pred, delta=1e-3):
    # Calculate the difference
    diff = y_true - y_pred
    # Calculate the condition for Huber loss
    cond = np.abs(diff) <= delta
    # Apply Huber loss formula
    loss = np.where(cond, 0.5 * diff**2, delta * (np.abs(diff) - 0.5 * delta))
    return np.sum(loss)

def huber_normalizing_factor(delta=1e-3):
    return np.sqrt(2*np.pi) * (1 - 2*norm.sf(delta)) + 2 * np.exp(-0.5*delta**2)/delta

def huber_logpdf(x, delta=1e-3, loc=0, scale=1):
    x = (x-loc)/scale

    cond = np.abs(x) <= delta
    loss = np.where(cond, 0.5 * x**2, delta * (np.abs(x) - 0.5 * delta))
    return -loss - np.log(huber_normalizing_factor(delta=delta)) - np.log(scale)

def huber_pdf(x, delta=1e-3, loc=0, scale=1):
    return np.exp(huber_logpdf(x, delta=delta, loc=loc, scale=scale))

# Define the objective function to be minimized
def objective(params, N, D, losses, log_sum_exp_fn=log_sum_exp_chinch):
    a, b, e, alpha, beta, sigma = params
    predictions = log_sum_exp_fn(a, b, e, alpha, beta, N, D)
    return -np.sum(huber_logpdf(np.log(losses), loc=predictions, scale=np.exp(sigma), delta=1e-3))
    # return custom_huber_loss(np.log(losses), predictions, delta=1e-3)

def scale_objective(sigma, params, N, D, losses, log_sum_exp_fn=log_sum_exp_chinch):
    a, b, e, alpha, beta = params
    predictions = log_sum_exp_fn(a, b, e, alpha, beta, N, D)
    return -np.sum(huber_logpdf(np.log(losses), loc=predictions, scale=np.exp(sigma), delta=1e-3))
    # return custom_huber_loss(np.log(losses), predictions, delta=1e-3)

def constant_term_objective(params, a, b, alpha, beta, N, D, losses, log_sum_exp_fn=log_sum_exp_chinch):
    e, sigma = params
    predictions = log_sum_exp_fn(a, b, e, alpha, beta, N, D)
    return -np.sum(huber_logpdf(np.log(losses), loc=predictions, scale=np.exp(sigma), delta=1e-3))

def huber_loss_objective(params, N, D, losses, log_sum_exp_fn=log_sum_exp_chinch):
    a, b, e, alpha, beta = params
    predictions = log_sum_exp_fn(a, b, e, alpha, beta, N, D)
    return custom_huber_loss(np.log(losses), predictions, delta=1e-3)

# Define the parameter untransform
def untransform_params(param_array):
    if len(np.shape(param_array)) == 2:
      return np.hstack((np.exp(param_array[:, :3]), param_array[:, 3:]))
    else:
      return np.hstack((np.exp(param_array[:3]), param_array[3:]))

# Define the Huber loss function on residuals
def huber_loss(residuals, delta=1e-3):
    # Calculate the difference
    diff = residuals
    # Calculate the condition for Huber loss
    cond = np.abs(diff) <= delta
    # Apply Huber loss formula
    loss = np.where(cond, 0.5 * diff**2, delta * (np.abs(diff) - 0.5 * delta))
    return loss

In [9]:
# import numpy as np
import heapq

class PQItem(object):
    def __init__(self, loss, params):
        self.loss = loss
        self.params = params

    def __lt__(self, other):
        return self.loss > other.loss # reversed because we want to retain lower loss params


def fit_from_scratch(N, D, losses, indices, obj=huber_loss_objective, method='BFGS', use_grad=True, add_sigma=False):
    # Set up the grid for initial parameter values
    # alpha_vals = np.arange(0, 2.5, 0.5)
    # beta_vals = np.arange(0, 2.5, 0.5)
    # e_vals = np.arange(-1, 1.5, 0.5)
    # a_vals = np.arange(0, 30, 5)
    # b_vals = np.arange(0, 30, 5)

    a_vals = np.arange(0, 30, 1) 
    b_vals = np.arange(0, 30, 1) 
    e_vals = np.arange(-1, 1.5, 0.1) 
    alpha_vals = np.arange(0, 2.5, 0.1)
    beta_vals = np.arange(0, 2.5, 0.1)

    # Perform the optimization using L-BFGS over the grid of initial values
    best_loss = np.inf
    best_params = None
    pq = []
    i = 0

    grid = np.array(np.meshgrid(
        a_vals, b_vals, e_vals, alpha_vals, beta_vals
    )).T.reshape(-1, 5)
    # pbar = tqdm(total=len(grid))
    i = 0
    np.random.shuffle(grid)
    
    from itertools import product
    results_dict = {}
    for init_params in grid:
    # for alpha, beta, e, a, b in product(alpha_vals, beta_vals, e_vals, a_vals, b_vals):
        # init_params = [a, b, e, alpha, beta]
        if add_sigma:
            init_params = init_params + [0]
        result = minimize(obj, init_params, args=(N[indices], D[indices], losses[indices]), method=method, jac=grad(obj) if use_grad else None)
        results_dict[tuple(init_params)] = {'params': result.x, 'loss': result.fun}
        if result.success and result.fun < best_loss:
            best_loss = result.fun
            best_params = result.x

        if len(pq) < 100:
            heapq.heappush(pq, PQItem(result.fun, result.x))
        elif result.fun < pq[0].loss:
            heapq.heappushpop(pq, PQItem(result.fun, result.x))

        # pbar.update(1)
        # if pbar.n % 1000 == 0:
        #     print("best loss: {} : best params : {}".format(best_loss, best_params))
        # if pbar.n == 100_000:
        #     break
        i += 1
        if i % 1000 == 0:
            print("best loss: {} : best params : {}".format(best_loss, best_params))
        if i == 10_000:
            break

    largest = heapq.nlargest(100, pq)
    for item in largest:
        print(round(item.loss, 6), [round(n, 2) for n in item.params])

    # Transform the fitted parameters a, b, e to A, B, E
    if best_params is not None:
        A = np.exp(best_params[0])
        B = np.exp(best_params[1])
        E = np.exp(best_params[2])
        alpha = best_params[3]
        beta = best_params[4]
        print(f"Best fit parameters: A={A}, B={B}, E={E}, alpha={alpha}, beta={beta}")
        print_opts(best_params)
        print(f"Best loss: {best_loss}")
    else:
        print("Optimization failed to converge.")

def fit_from_chinchilla_random(N, D, losses, random_indices, obj=huber_loss_objective, method='BFGS', use_grad=True, add_sigma=False):
    # Set up the grid for initial parameter values
    param_list = []

    for num, indices in enumerate(random_indices):
    # Perform the optimization using BFGS
        best_loss = np.inf
        best_params = None

        init_params = list(true_params)
            
        if add_sigma:
            init_params = init_params + [0]

        result = minimize(obj, init_params, args=(N[indices], D[indices], losses[indices]), \
                            jac=grad(obj) if use_grad else None, method=method)

        best_loss = result.fun
        best_params = result.x
        #print(f"New best loss: {best_loss}")
        #print(f"Best params: {best_params}")

        if num % 1000 == 999:
            print("Bootstrap step %d completed" % (num+1))

        param_list.append(result.x)

    param_list = np.array(param_list)
    cov_matrix = np.cov(np.transpose(param_list))

    if best_params is not None:
        A = np.exp(best_params[0])
        B = np.exp(best_params[1])
        E = np.exp(best_params[2])
        alpha = best_params[3]
        beta = best_params[4]
        print(f"Best fit parameters: A={A}, B={B}, E={E}, alpha={alpha}, beta={beta}")
        print_opts(best_params)
        print(f"Best loss: {best_loss}")
    else:
        print("Optimization failed to converge.")

def fit_from_init(N, D, losses, indices, obj=huber_loss_objective, method='BFGS', use_grad=True, add_sigma=False, init_params=None):

    if add_sigma:
        init_params = init_params + [0]
        
    result = minimize(obj, init_params, args=(N[indices], D[indices], losses[indices]), method=method, jac=grad(obj) if use_grad else None)

    print(result)
    print(result.x)
    best_loss = result.fun

    estimated_params = result.x[:5]
    # best_params = untransform_params(estimated_params)
    best_params = estimated_params
    if best_params is not None:
        A = np.exp(best_params[0])
        B = np.exp(best_params[1])
        E = np.exp(best_params[2])
        alpha = best_params[3]
        beta = best_params[4]
        print(f"Best fit parameters: A={A}, B={B}, E={E}, alpha={alpha}, beta={beta}")
        print_opts(best_params)
        print(f"Best loss: {best_loss}")
    else:
        print("Optimization failed to converge.")

In [10]:
# import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from itertools import product
import heapq

class PQItem(object):
    def __init__(self, loss, params):
        self.loss = loss
        self.params = params

    def __lt__(self, other):
        return self.loss > other.loss # reversed because we want to retain lower loss params

# custom_huber_loss(np.log(losses), predictions, delta=1e-3)

def fit_init_only_with_grid_search(
        N, D, data_losses,
        obj_func=log_sum_exp_chinch, loss_fn=(lambda x,y : np.sum((x - y)**2)),
        a=None, b=None, e=None, alpha=None, beta=None
    ):
    # a_vals = np.arange(0, 30, 0.3) if a is None else a
    # b_vals = np.arange(0, 30, 0.3) if b is None else b
    # e_vals = np.arange(-1, 1.5, 0.025) if e is None else e
    # alpha_vals = np.arange(0, 2.5, 0.025) if alpha is None else alpha
    # beta_vals = np.arange(0, 2.5, 0.025) if beta is None else beta

    a_vals = np.arange(0, 30, 1) if a is None else a
    b_vals = np.arange(0, 30, 1) if b is None else b
    e_vals = np.arange(-1, 1.5, 0.1) if e is None else e
    alpha_vals = np.arange(0, 2.5, 0.1) if alpha is None else alpha
    beta_vals = np.arange(0, 2.5, 0.1) if beta is None else beta

    best_params = None
    best_loss = np.inf
    
    grid = np.array(np.meshgrid(
        a_vals, b_vals, e_vals, alpha_vals, beta_vals
    )).T.reshape(-1, 5)
    # pbar = tqdm(total=len(grid))
    i = 0
    np.random.shuffle(grid)

    pq = []

    for params in grid:
        pred_losses = np.exp(obj_func(*params, N, D))
        loss = loss_fn(pred_losses, data_losses)

        if loss < best_loss:
            best_loss = loss
            best_params = params

        if len(pq) < 100:
            heapq.heappush(pq, PQItem(loss, params))
        elif loss < pq[0].loss:
            heapq.heappushpop(pq, PQItem(loss, params))

        i += 1
        if i % 1000 == 0:
            print("best loss: {} : best params : {}".format(best_loss, best_params))
            # pass
        if i == 1_000_000:
            break

    largest = heapq.nlargest(100, pq)
    for item in largest:
        print(round(item.loss, 4), [round(n, 2) for n in item.params])

    return pq, best_loss, best_params



In [11]:
N, D, losses, indices, random_indices = get_data("epoch_ai")
# N, D, losses, indices, random_indices = get_data("rsld", rsld_config_number=3)


             x          y    color             N             C hex_color  \
0    154.03592  140.63364  #faebdd  6.795600e+09  9.993853e+18   #faebdd   
1    151.60375  175.51709  #f8d1b8  2.979521e+09  9.227541e+18   #f8d1b8   
6    223.29648  175.51709  #931c5b  2.979521e+09  9.691017e+19   #931c5b   
2    153.29233  180.65761  #f47a54  2.638631e+09  9.753047e+18   #f47a54   
7    223.74805  180.65761  #8c1d5b  2.638631e+09  9.835628e+19   #8c1d5b   
..         ...        ...      ...           ...           ...       ...   
226  254.57913  244.27495  #781f59  5.865994e+08  2.703956e+20   #781f59   
227  256.05204  257.94760  #811e5a  4.246106e+08  2.837799e+20   #811e5a   
234  234.10686  252.01338  #871e5b  4.885467e+08  1.381549e+20   #871e5b   
236  248.18209  239.03398  #7a1f59  6.639580e+08  2.192159e+20   #7a1f59   
244  372.54603  140.63355  #34193d  6.795615e+09  1.295602e+22   #34193d   

         loss             D  
0    5.005582  2.451060e+08  
1    4.665232  5.161647e+08

In [13]:
print(D)

[4.73169920e+09 1.93173914e+10 6.97250611e+09 2.86271734e+10
 9.37112371e+09 1.22935050e+10 6.16615117e+09 6.13731533e+09
 6.08226509e+09 6.44349952e+09 5.66021325e+09 1.14677514e+10
 4.33533747e+09 7.29913754e+09 1.69413181e+10 2.16635802e+10]


In [14]:
pq, best_loss, best_params = fit_init_only_with_grid_search(N, D, losses)

best loss: 0.7975132762123402 : best params : [ 6.  15.   1.1  0.4  1.9]
best loss: 0.7975132762123402 : best params : [ 6.  15.   1.1  0.4  1.9]
best loss: 0.7660174186969138 : best params : [ 7.  18.   1.2  0.5  2.1]
best loss: 0.7660174186969138 : best params : [ 7.  18.   1.2  0.5  2.1]
best loss: 0.5926721001750532 : best params : [ 3.  11.   1.1  0.2  1.6]
best loss: 0.5926721001750532 : best params : [ 3.  11.   1.1  0.2  1.6]
best loss: 0.5926721001750532 : best params : [ 3.  11.   1.1  0.2  1.6]
best loss: 0.5926721001750532 : best params : [ 3.  11.   1.1  0.2  1.6]
best loss: 0.5926721001750532 : best params : [ 3.  11.   1.1  0.2  1.6]
best loss: 0.5926721001750532 : best params : [ 3.  11.   1.1  0.2  1.6]
best loss: 0.5926721001750532 : best params : [ 3.  11.   1.1  0.2  1.6]
best loss: 0.5926721001750532 : best params : [ 3.  11.   1.1  0.2  1.6]
best loss: 0.5926721001750532 : best params : [ 3.  11.   1.1  0.2  1.6]
best loss: 0.5926721001750532 : best params : [ 3. 

In [28]:
fit_from_scratch(N, D, losses, indices, obj=huber_loss_objective, method='L-BFGS-B', use_grad=False)


best loss: 3.787945545510745e-05 : best params : [ 2.55680649 15.01163762  0.57347052  0.11964638  0.74169355]
best loss: 3.787945545510745e-05 : best params : [ 2.55680649 15.01163762  0.57347052  0.11964638  0.74169355]
best loss: 3.786205707473094e-05 : best params : [ 2.58213038 13.97855658  0.58821181  0.12227579  0.69480374]
best loss: 3.784929045224317e-05 : best params : [ 2.60025232 13.96955753  0.60120201  0.12415799  0.69483638]
best loss: 3.784929045224317e-05 : best params : [ 2.60025232 13.96955753  0.60120201  0.12415799  0.69483638]
best loss: 3.784929045224317e-05 : best params : [ 2.60025232 13.96955753  0.60120201  0.12415799  0.69483638]
best loss: 3.784929045224317e-05 : best params : [ 2.60025232 13.96955753  0.60120201  0.12415799  0.69483638]
best loss: 3.784929045224317e-05 : best params : [ 2.60025232 13.96955753  0.60120201  0.12415799  0.69483638]
best loss: 3.784929045224317e-05 : best params : [ 2.60025232 13.96955753  0.60120201  0.12415799  0.69483638]
b

In [16]:
# np.array([np.log(406.4), np.log(410.7), np.log(1.69), 0.34, 0.28])
# array([6.0073379 , 6.01786302, 0.52472853, 0.34      , 0.28      ])
init_params = [6.0073379 , 6.01786302, 0.52472853, 0.34, 0.28]
# fit_from_init(obj=huber_loss_objective, method='L-BFGS-B', use_grad=False, init_params=list(true_params))
fit_from_init(N, D, losses, indices, obj=huber_loss_objective, method='L-BFGS-B', use_grad=False, init_params=init_params)

  message: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
  success: True
   status: 0
      fun: 3.9903721089100134e-05
        x: [ 2.687e+00  5.409e+00  5.737e-01  1.329e-01  2.900e-01]
      nit: 90
      jac: [ 6.143e-06  2.748e-07  7.069e-06 -1.003e-04 -2.327e-05]
     nfev: 858
     njev: 143
 hess_inv: <5x5 LbfgsInvHessProduct with dtype=float64>
[2.68682667 5.40857773 0.57373204 0.13289562 0.28997104]
Best fit parameters: A=14.685001525287666, B=223.31374881887865, E=1.7748786232868297, alpha=0.1328956236915945, beta=0.2899710402320098
Scaling: 
         compute parameters (B) tokens (B) ratio predicted loss
0  1.250000e+18           0.19       1.10  5.76           3.47
1  5.010000e+18           0.49       1.69  3.44           3.27
2  1.980000e+19           1.26       2.61  2.06           3.10
3  1.000000e+21          18.62       8.95  0.48           2.70
4  1.000000e+23         438.04      38.05  0.09           2.38
Best loss: 3.9903721089100134e-05


In [17]:
huber_loss_objective(true_params, N, D, losses)

0.0021415626088716454

In [18]:
fit_from_scratch(N, D, losses, indices, obj=huber_loss_objective, method='L-BFGS-B', use_grad=False)

best loss: 3.789252564543738e-05 : best params : [ 2.6334663  14.6751309   0.62756515  0.12769534  0.72791943]
best loss: 3.788544711694731e-05 : best params : [ 2.62356022 13.08068485  0.61289141  0.12651906  0.65439012]
best loss: 3.785837821408434e-05 : best params : [ 2.61180955 14.01415087  0.60943915  0.12535571  0.69711318]


best loss: 3.785837821408434e-05 : best params : [ 2.61180955 14.01415087  0.60943915  0.12535571  0.69711318]
best loss: 3.785837821408434e-05 : best params : [ 2.61180955 14.01415087  0.60943915  0.12535571  0.69711318]
best loss: 3.785837821408434e-05 : best params : [ 2.61180955 14.01415087  0.60943915  0.12535571  0.69711318]
best loss: 3.785837821408434e-05 : best params : [ 2.61180955 14.01415087  0.60943915  0.12535571  0.69711318]
best loss: 3.785262501411214e-05 : best params : [ 2.60607582 14.01419296  0.60550674  0.12475928  0.69704812]
best loss: 3.785262501411214e-05 : best params : [ 2.60607582 14.01419296  0.60550674  0.12475928  0.69704812]
best loss: 3.785262501411214e-05 : best params : [ 2.60607582 14.01419296  0.60550674  0.12475928  0.69704812]
3.8e-05 [2.61, 14.01, 0.61, 0.12, 0.7]
3.8e-05 [2.59, 14.62, 0.6, 0.12, 0.72]
3.8e-05 [2.61, 14.01, 0.61, 0.13, 0.7]
3.8e-05 [2.57, 15.01, 0.58, 0.12, 0.74]
3.8e-05 [2.62, 13.64, 0.61, 0.13, 0.68]
3.8e-05 [2.61, 13.34, 0.61

In [19]:
fit_from_scratch(N, D, losses, indices, obj=huber_loss_objective, method='L-BFGS-B', add_sigma=True, use_grad=False)

best loss: 3.787865285492144e-05 : best params : [ 2.61536334 13.17836074  0.60802206  0.12569826  0.65869289]
best loss: 3.7868732772649915e-05 : best params : [ 2.61919501 13.48844606  0.61195846  0.12609469  0.67303807]
best loss: 3.785211688583612e-05 : best params : [ 2.59429889 13.88903306  0.59651292  0.12353386  0.69097174]
best loss: 3.7851041770112685e-05 : best params : [ 2.6030215  14.00963747  0.60357835  0.12445792  0.69678976]
best loss: 3.7851041770112685e-05 : best params : [ 2.6030215  14.00963747  0.60357835  0.12445792  0.69678976]
best loss: 3.7851041770112685e-05 : best params : [ 2.6030215  14.00963747  0.60357835  0.12445792  0.69678976]
best loss: 3.7851041770112685e-05 : best params : [ 2.6030215  14.00963747  0.60357835  0.12445792  0.69678976]
best loss: 3.784998907087201e-05 : best params : [ 2.60202834 14.05801008  0.60305319  0.12436498  0.69892796]
best loss: 3.784998907087201e-05 : best params : [ 2.60202834 14.05801008  0.60305319  0.12436498  0.698927

In [20]:
fit_from_chinchilla_random(N, D, losses, random_indices, obj=huber_loss_objective, method='BFGS', use_grad=False)

Bootstrap step 1000 completed
Bootstrap step 2000 completed
Bootstrap step 3000 completed
Bootstrap step 4000 completed
Best fit parameters: A=10.80565184461938, B=303.218332151444, E=1.1907012653301834, alpha=0.09982740395511044, beta=0.2855258906813433
Scaling: 
         compute parameters (B) tokens (B)  ratio predicted loss
0  1.250000e+18           0.08       2.68  34.57           3.57
1  5.010000e+18           0.22       3.85  17.71           3.34
2  1.980000e+19           0.60       5.49   9.13           3.13
3  1.000000e+21          10.99      15.16   1.38           2.64
4  1.000000e+23         333.40      49.99   0.15           2.22
Best loss: 4.745822146499232e-05


In [21]:
fit_from_scratch(N, D, losses, indices, obj=objective, method='L-BFGS-B', add_sigma=True, use_grad=False)

ValueError: not enough values to unpack (expected 6, got 5)

In [12]:
fit_from_chinchilla_random(N, D, losses, random_indices, obj=objective, method='BFGS', add_sigma=True, use_grad=True)