# Setup (run this, no need to read)

## General imports

In [None]:
import pandas as pd
import numpy as np
from copy import deepcopy
import random
import seaborn as sns
import re

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
!pip install matplotlib-label-lines
import labellines
import plotly.express as px

Collecting matplotlib-label-lines
  Downloading matplotlib_label_lines-0.7.0-py3-none-any.whl (12 kB)
Installing collected packages: matplotlib-label-lines
Successfully installed matplotlib-label-lines-0.7.0


In [None]:
from datetime import datetime
date = datetime.today().strftime('%Y-%m-%d')

In [None]:
from scipy.optimize import minimize, minimize_scalar
from scipy.stats import norm
from sklearn.metrics import r2_score

import json
from google.colab import files

In [None]:
np.seterr(over='ignore') # models 14 and 15

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}

## Data

In [None]:
# load primary algorithmic progress dataset
# path = "https://docs.google.com/spreadsheets/d/11m8O_mU0cUkOB_5wluPne4PNsuvsKNbbVAzbYNy-NXY/edit#gid=91564213"
# path = "https://docs.google.com/spreadsheets/d/1NQh5XJdMjuoZ7Brb39rORCaBveWTRGT0TVcTYWUo6tU/edit#gid=91564213"
path = "https://docs.google.com/spreadsheets/d/11m8O_mU0cUkOB_5wluPne4PNsuvsKNbbVAzbYNy-NXY/edit#gid=2087221150"
path = path.replace("edit#", "export?") + "&format=csv"
df = pd.read_csv(path, parse_dates=True)

# remove unnecessary columns
columns_to_remove = ['Author(s)', 'Link', 'Hardware', 'Base Model', 'GitHub', 'Comments', 'Organizations', 'Organization Categorization', 'Comments']
df = df.drop(columns=columns_to_remove)

# rename columns
column_renames = {
    'Publication date': 'publication_date',
    'Parameters': 'param',
    'Perplexity (WT103)': 'ppl_wt103',
    'Perplexity (WT2)': 'ppl_wt2',
    'Perplexity (PTB)': 'ppl_ptb',
    'Dataset Size': 'dataset',
    'System': 'system',
    'Epoch': 'epoch',
    'Include?': 'include',
    'Zero-shot?': 'zero_shot',
    'Citations': 'cites',
    'Peer reviewed?': 'peer_reviewed',
    'Outlier?': 'outlier'
}
df = df.rename(columns=column_renames)

# convert publication date to datetime format
def convert_to_fraction_of_year(date):
    return date.year + (date.dayofyear - 1) / 365.0
df['publication_date'] = pd.to_datetime(df['publication_date'], format='%Y/%m/%d', errors='coerce')
df['publication_date'] = df['publication_date'].apply(convert_to_fraction_of_year)

# convert param and data to numeric values
df['param'] = pd.to_numeric(df['param'], errors='coerce')
df['dataset'] = pd.to_numeric(df['dataset'], errors='coerce')

# drop rows when 0 or NaN
df.dropna(subset=['param', 'dataset'], inplace=True)
df = df.loc[(df['param'] > 0) & (df['dataset'] > 0)]
df = df.loc[(df['include'] != 0)]
df = df.loc[(df['outlier'] != 1)]

df.dropna(subset=['ppl_wt103', 'ppl_wt2', 'ppl_ptb'], how='all', inplace=True)
df.reset_index(drop=True, inplace=True) # reset the index of the dataframe after dropping rows

# perplexity data
df['ppl'] = df['ppl_wt103'].fillna(df['ppl_wt2']).fillna(df['ppl_ptb']) # single ppl column

def safe_log(column): # modified logarithm for handling possible zero values
    return np.where(column != 0, np.log(column), np.nan)
for col in ['param', 'dataset', 'ppl_wt103', 'ppl_wt2', 'ppl_ptb', 'ppl']:
    df[f'log_{col}'] = safe_log(df[col])

Prepare data for algorithmic progress analysis. This is where we exclude models where the data is particularly uncertain or are clear outliers.

In [None]:
# filter models and form new df
df_wt103 = df.dropna(subset=['log_ppl_wt103'])
df_wt2 = df.dropna(subset=['log_ppl_wt2'])
df_ptb = df.dropna(subset=['log_ppl_ptb'])
df1 = pd.concat([
    df_wt103.assign(dataset_name='wt103'),
    df_wt2.assign(dataset_name='wt2'),
    df_ptb.assign(dataset_name='ptb')
])
len_before_filter = len(df1)

# name systems based on dataset
df1['system_dataset'] = df1['system'].str.cat(df1['dataset_name'], sep='_')
df1['dataset_name'] = df1['dataset_name'].astype('category') # convert to categorical
dataset_idx = df1['dataset_name'].cat.codes.values

# create dummies for benchmarks
df1['ptb_dummy'] = (df1['dataset_name'] == 'ptb').astype(int)
df1['wt2_dummy'] = (df1['dataset_name'] == 'wt2').astype(int)

# drop models
df1= df1[(df1['uncertain']==0)]
columns_to_check = ['log_param', 'log_dataset', 'publication_date', 'ppl']
df1 = df1[~df1[columns_to_check].replace([np.inf, -np.inf], np.nan).isnull().any(axis=1)]

# sort by reference and ppl
df1 = df1.sort_values(['Reference', 'ppl'], ascending=[True, True])

In [None]:
mask = df1['system'] == "GPT3-6.7B + muP"
index_to_drop = df1[mask].index
df1 = df1.drop(index_to_drop)

mask = df1['system'] == "LLaMA-65B (LoRA finetuned)"
index_to_drop = df1[mask].index
df1 = df1.drop(index_to_drop)

mask = df1['system'] == "LLaMA-13B (LoRA finetuned)"
index_to_drop = df1[mask].index
df1 = df1.drop(index_to_drop)

mask = df1['system'] == "LLaMA-7B (LoRA finetuned)"
index_to_drop = df1[mask].index
df1 = df1.drop(index_to_drop)

# Print original 'param' values for 'Gopher (280B)' and 'Gopher (7.1B)'
print("Original 'param' values:")
print("Gopher (280B):", df1.loc[df1['system'] == 'Gopher (280B)', 'param'].values)
print("Gopher (7.1B):", df1.loc[df1['system'] == 'Gopher (7.1B)', 'param'].values)

# Apply the changes
df1.loc[df1['system'] == 'Gopher (280B)', 'param'] = 280e9
df1.loc[df1['system'] == 'Gopher (7.1B)', 'param'] = 7.1e9

# Print updated 'param' values for 'Gopher (280B)' and 'Gopher (7.1B)'
print("\nUpdated 'param' values:")
print("Gopher (280B):", df1.loc[df1['system'] == 'Gopher (280B)', 'param'].values)
print("Gopher (7.1B):", df1.loc[df1['system'] == 'Gopher (7.1B)', 'param'].values)

Original 'param' values:
Gopher (280B): [2.8e+11]
Gopher (7.1B): [7.1e+09]

Updated 'param' values:
Gopher (280B): [2.8e+11]
Gopher (7.1B): [7.1e+09]


One problem that we encounter is managing correlation between the models, due to multiple models coming from the same paper. To try and fix this issue we consider two possible approaches:
1. Keeping only the top 3 models per paper
2. Keeping track of the clusters and adjusting for this

The code below collects the data for this. We primarily use the first approach, and we show that our results are unchanged when we use the second approach in the appendix.

In [None]:
# keep only the top 3 models per paper
df_head = df1.copy(deep=True)
df_head = df_head.groupby('Reference').head(3)
df_head.reset_index(drop=True, inplace=True)
print("Total dropped rows (head):", len_before_filter - len(df_head))

# cluster the models by paper and number the clusters
df_cluster = df1.copy(deep=True)
df_cluster['cluster'] = df_cluster.groupby('Reference').ngroup()
df_cluster.reset_index(drop=True, inplace=True)
print("Total dropped rows (cluster):", len_before_filter - len(df_cluster))

Total dropped rows (head): 50
Total dropped rows (cluster): 20


## Useful functions

In [None]:
def doubling_to_x_per_year(doubling):
    """
    convert doubling times to Nx/year,
    e.g. if doubling time = 1 year, this returns 2x/year

    doubling time argument should be in years
    """
    if doubling is None: return None
    return 2 ** np.reciprocal(doubling)

def doubling_to_oom(doubling):
    """
    convert doubling times to OOM/year
    """
    if doubling is None: return None
    return np.reciprocal(doubling) * np.log10(2)

def oom_to_doubling(oom_per_year):
    """
    convert from OOM/year to doubling time in years
    """
    return np.log10(2) / oom_per_year

def data_filter(values, confidence_interval):
    """
    filter values to the given confidence interval
    """
    values = np.array(values)
    lower, upper = np.percentile(values, [50 - confidence_interval/2, 50 + confidence_interval/2])
    mask = (values > lower) & (values < upper)
    return values[mask]

def variance(weights, vector1, vector2, vector3):
    """
    function for aggregating doubling times from three different models (one array per model)
    chooses weights in a convex combination to obtain a pooled score vector,
    then calculate the variance of this vector
    """
    pool = weights[0] * vector1 + weights[1] * vector2 + weights[2] * vector3
    return np.var(pool)

def prime(param, param_ptb, param_wt2, category_ptb, category_wt2):
    return param + param_ptb * category_ptb + param_wt2 * category_wt2

def log_diff(value, constant):
    return np.log(value) - np.log(constant)

def build_model(alpha_terms, beta_terms):
    return np.exp(alpha_terms) + np.exp(beta_terms)

def print_stats(data, ci):

  if len(data) == 2:
    nt, t = data
    print("Non-transformer", np.percentile(nt, [50-ci/2, 50, 50+ci/2]))
    print("Transformer", np.percentile(t, [50-ci/2, 50, 50+ci/2]))

  elif len(data) == 3:
    wt103, ptb, wt2 = data
    print("WT103", np.percentile(wt103, [50-ci/2, 50, 50+ci/2]))
    print("PTB", np.percentile(ptb, [50-ci/2, 50, 50+ci/2]))
    print("WT2", np.percentile(wt2, [50-ci/2, 50, 50+ci/2]))

  else:
    print(np.percentile(data, [50-ci/2, 50, 50+ci/2]))

def doubling_times_plot(model_doubling_times, ci, ax=None, num_bootstraps=100, fontsize=12, legend=True, **kwargs):

  title = kwargs.get('title',"")
  xlabel = kwargs.get('xlabel', 'Doubling Time (Years)')
  ylabel = kwargs.get('ylabel', 'Density')
  label = kwargs.get('label', None)
  xlim = kwargs.get('xlim', None)
  xscale = kwargs.get('xscale', 'linear')

  if model_doubling_times is None:
    # sns.kdeplot(np.zeros(num_bootstraps))
    pass

  elif len(model_doubling_times) == 2:
    model_doubling_nt, model_doubling_t = model_doubling_times
    sns.kdeplot(data_filter(model_doubling_nt, ci), label=f'nt', ax=ax)
    sns.kdeplot(data_filter(model_doubling_t, ci), label=f't', ax=ax)
    if legend:
      if ax != None: ax.legend()

  elif len(model_doubling_times) == 3:
    model_doubling_wt103, model_doubling_ptb, model_doubling_wt2 = model_doubling_times
    sns.kdeplot(data_filter(model_doubling_wt103, ci), label=f'wt103', ax=ax)
    sns.kdeplot(data_filter(model_doubling_ptb, ci), label=f'ptb', ax=ax)
    sns.kdeplot(data_filter(model_doubling_wt2, ci), label=f'wt2', ax=ax)
    if legend:
      if ax != None: ax.legend()

  else:
    sns.kdeplot(data_filter(model_doubling_times, ci), ax=ax, label=label)

  if ax != None:
    ax.set_title(title, fontsize=fontsize)
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.set_xscale(xscale)
    ax.legend()
    if xlim is not None:
      x_min, x_max = xlim
      ax.set_xlim(x_min, x_max)

  else:
    plt.title(title, fontsize=fontsize)
    plt.xlabel(xlabel, fontsize=fontsize)
    plt.ylabel(ylabel, fontsize=fontsize)
    plt.xscale(xscale)
    plt.legend()
    if xlim is not None:
      x_min, x_max = xlim
      plt.xlim(x_min, x_max)

In [None]:
def estimate_doubling_times(model_num, bootstrap_results):
    param_names = PARAMS_MAPPING[model_num]
    bootstrap_array = np.array(bootstrap_results).T
    pred_params = {name: bootstrap_array[i] for i, name in enumerate(param_names)}

    def doubling_times(scale_exponent, year_exponent):
      return scale_exponent / year_exponent * np.log(2)

    param_doubling, data_doubling, compute_doubling = None, None, None

    if model_num in {1, 16}:
      param_doubling = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"])
      data_doubling = doubling_times(pred_params["beta_data"], pred_params["beta_year"])
      compute_doubling = (1/param_doubling + 1/data_doubling) ** (-1)

    elif model_num in {2, 17}:
      data_doubling = doubling_times(pred_params["beta_data"], pred_params["beta_year"])
      compute_doubling = data_doubling

    elif model_num in {3, 18}:
      param_doubling = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"])
      compute_doubling = param_doubling

    elif model_num in {4, 19}:
      param_doubling_wt103 = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"])
      param_doubling_ptb = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"] + pred_params["alpha_year_ptb"])
      param_doubling_wt2 = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"] + pred_params["alpha_year_wt2"])
      param_doubling = (param_doubling_wt103, param_doubling_ptb, param_doubling_wt2)

      data_doubling = doubling_times(pred_params["beta_data"], pred_params["beta_year"])

      compute_doubling_wt103 = (1/param_doubling_wt103 + 1/data_doubling) ** (-1)
      compute_doubling_ptb = (1/param_doubling_ptb + 1/data_doubling) ** (-1)
      compute_doubling_wt2 = (1/param_doubling_wt2 + 1/data_doubling) ** (-1)
      compute_doubling = (compute_doubling_wt103, compute_doubling_ptb, compute_doubling_wt2)

    elif model_num in {5, 20}:
      param_doubling = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"])

      data_doubling_wt103 = doubling_times(pred_params["beta_data"], pred_params["beta_year"])
      data_doubling_ptb = doubling_times(pred_params["beta_data"], pred_params["beta_year"] + pred_params["beta_year_ptb"])
      data_doubling_wt2 = doubling_times(pred_params["beta_data"], pred_params["beta_year"] + pred_params["beta_year_wt2"])
      data_doubling = (data_doubling_wt103, data_doubling_ptb, data_doubling_wt2)

      compute_doubling_wt103 = (1/param_doubling + 1/data_doubling_wt103) ** (-1)
      compute_doubling_ptb = (1/param_doubling + 1/data_doubling_ptb) ** (-1)
      compute_doubling_wt2 = (1/param_doubling + 1/data_doubling_wt2) ** (-1)
      compute_doubling = (compute_doubling_wt103, compute_doubling_ptb, compute_doubling_wt2)

    elif model_num in {6, 21}:
      param_doubling_wt103 = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"])
      param_doubling_ptb = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"] + pred_params["alpha_year_ptb"])
      param_doubling_wt2 = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"] + pred_params["alpha_year_wt2"])
      param_doubling = (param_doubling_wt103, param_doubling_ptb, param_doubling_wt2)

      data_doubling_wt103 = doubling_times(pred_params["beta_data"], pred_params["beta_year"])
      data_doubling_ptb = doubling_times(pred_params["beta_data"], pred_params["beta_year"] + pred_params["beta_year_ptb"])
      data_doubling_wt2 = doubling_times(pred_params["beta_data"], pred_params["beta_year"] + pred_params["beta_year_wt2"])
      data_doubling = (data_doubling_wt103, data_doubling_ptb, data_doubling_wt2)

      compute_doubling_wt103 = (1/param_doubling_wt103 + 1/data_doubling_wt103) ** (-1)
      compute_doubling_ptb = (1/param_doubling_ptb + 1/data_doubling_ptb) ** (-1)
      compute_doubling_wt2 = (1/param_doubling_wt2 + 1/data_doubling_wt2) ** (-1)
      compute_doubling = (compute_doubling_wt103, compute_doubling_ptb, compute_doubling_wt2)

    elif model_num in {7, 22}:
      param_doubling = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"])
      data_doubling = doubling_times(pred_params["beta_data"], pred_params["beta_year"])
      compute_doubling = (1/param_doubling + 1/data_doubling) ** (-1)

    elif model_num in {8, 23}:
      data_doubling = doubling_times(pred_params["beta_data"], pred_params["beta_year"])
      compute_doubling = data_doubling

    elif model_num in {9, 24}:
      param_doubling = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"])
      compute_doubling = param_doubling

    elif model_num in {10, 25}:
      param_doubling_wt103 = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"])
      param_doubling_ptb = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"] + pred_params["alpha_year_ptb"])
      param_doubling_wt2 = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"] + pred_params["alpha_year_wt2"])
      param_doubling = (param_doubling_wt103, param_doubling_ptb, param_doubling_wt2)

      data_doubling_wt103 = doubling_times(pred_params["beta_data"], pred_params["beta_year"])
      data_doubling_ptb = doubling_times(pred_params["beta_data"], pred_params["beta_year"] + pred_params["beta_year_ptb"])
      data_doubling_wt2 = doubling_times(pred_params["beta_data"], pred_params["beta_year"] + pred_params["beta_year_wt2"])
      data_doubling = (data_doubling_wt103, data_doubling_ptb, data_doubling_wt2)

      compute_doubling_wt103 = (1/param_doubling_wt103 + 1/data_doubling_wt103) ** (-1)
      compute_doubling_ptb = (1/param_doubling_ptb + 1/data_doubling_ptb) ** (-1)
      compute_doubling_wt2 = (1/param_doubling_wt2 + 1/data_doubling_wt2) ** (-1)
      compute_doubling = (compute_doubling_wt103, compute_doubling_ptb, compute_doubling_wt2)

    elif model_num in {11, 26}:
      param_doubling_wt103 = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"])
      param_doubling_ptb = doubling_times(pred_params["alpha_param"] + pred_params["alpha_param_ptb"], pred_params["alpha_year"] + pred_params["alpha_year_ptb"])
      param_doubling_wt2 = doubling_times(pred_params["alpha_param"] + pred_params["alpha_param_wt2"], pred_params["alpha_year"] + pred_params["alpha_year_wt2"])
      param_doubling = (param_doubling_wt103, param_doubling_ptb, param_doubling_wt2)

      data_doubling_wt103 = doubling_times(pred_params["beta_data"], pred_params["beta_year"])
      data_doubling_ptb = doubling_times(pred_params["beta_data"] + pred_params["beta_data_ptb"], pred_params["beta_year"] + pred_params["beta_year_ptb"])
      data_doubling_wt2 = doubling_times(pred_params["beta_data"] + pred_params["beta_data_wt2"], pred_params["beta_year"] + pred_params["beta_year_wt2"])
      data_doubling = (data_doubling_wt103, data_doubling_ptb, data_doubling_wt2)

      compute_doubling_wt103 = (1/param_doubling_wt103 + 1/data_doubling_wt103) ** (-1)
      compute_doubling_ptb = (1/param_doubling_ptb + 1/data_doubling_ptb) ** (-1)
      compute_doubling_wt2 = (1/param_doubling_wt2 + 1/data_doubling_wt2) ** (-1)
      compute_doubling = (compute_doubling_wt103, compute_doubling_ptb, compute_doubling_wt2)

    elif model_num in {12, 27}:
      param_doubling = doubling_times(pred_params["alpha_param"], pred_params["alpha_year"])
      data_doubling = doubling_times(pred_params["beta_data"], pred_params["alpha_year"])
      compute_doubling = (1/param_doubling + 1/data_doubling) ** (-1)

    elif model_num in {13, 28}:
      param_doubling_nt = pred_params['alpha_param'] / pred_params['alpha_year'] * np.log(2)
      param_doubling_t = pred_params['alpha_param_t'] / pred_params['alpha_year'] * np.log(2)
      param_doubling = (param_doubling_nt, param_doubling_t)

      data_doubling_nt = pred_params['beta_data'] / pred_params['beta_year'] * np.log(2)
      data_doubling_t = pred_params['beta_data_t'] / pred_params['beta_year'] * np.log(2)
      data_doubling = (data_doubling_nt, data_doubling_t)

      compute_doubling_nt = (1/param_doubling_nt + 1/data_doubling_nt) ** (-1)
      compute_doubling_t = (1/param_doubling_t + 1/data_doubling_t) ** (-1)
      compute_doubling = (compute_doubling_nt, compute_doubling_t)

    return param_doubling, data_doubling, compute_doubling

In [None]:
def compute_doubling_numerical(params, model_num, year, category_ptb=0, category_wt2=0, category_transformer=1, compute=1e25):
    """
    estimates effective compute doubling times for all models using a numerical approach
    """
    # param_names = PARAMS_MAPPING[model_num]
    # bootstrap_array = np.array(bootstrap_results).T
    # pred_params = {name: bootstrap_array[i] for i, name in enumerate(param_names)}

    compute_doubling = None
    C1, C2 = compute, 2 * compute
    # params = params_optimized.values

    def CEL(x, params, year):
        log_param, log_data = x
        if model_num in {13, 17}:
            # print(model_name(params, year, np.exp(log_param), np.exp(log_data), category_ptb, category_wt2, category_transformer))
            return model_name(params, year, np.exp(log_param), np.exp(log_data), category_ptb, category_wt2, category_transformer)
        else:
            return model_name(params, year, np.exp(log_param), np.exp(log_data), category_ptb, category_wt2)

    def constraint_eq(x, C):
        """
        physical compute constraint C = 6ND
        """
        log_param, log_data = x
        return log_param + log_data + np.log(6) - np.log(float(C))

    def optimize_and_evaluate(C, params, year=2023):
        """
        minimize cross entropy loss given constraint
        """
        initial_guess = [np.log(C**0.5), np.log(C**0.5)]
        con = {'type': 'eq', 'fun': lambda x: constraint_eq(x, C)}
        result = minimize(CEL, initial_guess, args=(params, year), constraints=con)
        return np.exp(result.x), CEL(result.x, params, year)

    # Estimate optimal param, data and loss after doubling of compute budget
    optimized_nd1, loss_C1 = optimize_and_evaluate(C1, params, year)
    optimized_nd2, loss_C2 = optimize_and_evaluate(C2, params, year)
    # print(optimized_nd1, loss_C1)
    # print(optimized_nd2, loss_C2)

    def optimal_scaling_doubling_time(opt_params, params, target_loss):
        """
        determine years of algorithmic progress required to obtain equivalent loss reduction
        """
        def to_minimize(delta):
            new_year = year + delta
            return np.abs(CEL(opt_params, params, year=new_year) - target_loss)

        result = minimize_scalar(to_minimize)
        return result.x

    # print(np.log(optimized_nd1), params)
    doubling_time = optimal_scaling_doubling_time(np.log(optimized_nd1), params, loss_C2)
    # print(f"Doubling time is: {doubling_time} years")
    return doubling_time

In [None]:
def bootstrap_to_latex(text, param_names, pval_df):
    # match bootstrap output text
    param_pattern = r"([\w_]+):"
    num_pattern = r"([-+]?\d*\.\d+e*[-+]?\d*)"
    pattern = f"{param_pattern}\s*{num_pattern}\s*CI:\s*\[{num_pattern}\s+{num_pattern}\s*\]\s*SE:\s*{num_pattern}"
    matches = re.findall(pattern, text)

    if len(matches) != len(param_names):
        raise ValueError("The number of matches does not match the number of parameter names.")

    formatted_rows = []
    for match_item in matches:
        param, estimate, ci_lower, ci_upper, se = match_item
        pattern = r"([\w_]+)_opt"
        param_name = re.findall(pattern, param)[0]

        num_stars = sum(pval_df.loc[param_name, ["*", "**", "***"]])
        if num_stars == 0:
          star_sign = r"\nosign"
        elif num_stars == 1:
          star_sign = r"\sign"
        elif num_stars == 2:
          star_sign = r"\signn"
        elif num_stars == 3:
          star_sign = r"\signnn"

        formatted_row = f"    ${param_names.pop(0)}$ & \\begin{{tabular}}[c]{{@{{}}c@{{}}}}$\\underset{{({float(se):.3f})}}{{{float(estimate):.3f}}}$ {star_sign} \\end{{tabular}} & ${float(ci_lower):.3f}, {float(ci_upper):.3f}$ \\\\"
        formatted_rows.append(formatted_row)

    # create LaTeX table
    formatted_table = "\\begin{tabular}{@{}lcc@{}}\n    \\toprule\n     & Estimate & 95\\% CI \\\\ \\midrule\n" + "\n".join(formatted_rows) + "\n    \\bottomrule\n    \\end{tabular}"

    return formatted_table

## Model definitions

In [None]:
PARAMS_MAPPING = {
    1: ('alpha_const', 'alpha_year', 'alpha_param', 'beta_const', 'beta_year', 'beta_data'),
    2: ('alpha_const', 'alpha_param', 'beta_const', 'beta_year', 'beta_data'),
    3: ('alpha_const', 'alpha_year', 'alpha_param', 'beta_const', 'beta_data'),
    4: ('alpha_const', 'alpha_year', 'alpha_year_ptb', 'alpha_year_wt2', 'alpha_param', \
        'beta_const', 'beta_year', 'beta_data'),
    5: ('alpha_const', 'alpha_year', 'alpha_param', 'beta_const', 'beta_year', 'beta_year_ptb', \
        'beta_year_wt2', 'beta_data'),
    6: ('alpha_const', 'alpha_year', 'alpha_year_ptb', 'alpha_year_wt2', 'alpha_param', \
        'beta_const', 'beta_year', 'beta_year_ptb', 'beta_year_wt2', 'beta_data'),
    7: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_year', 'alpha_param', \
        'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_year', 'beta_data'),
    8: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_param', \
        'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_year', 'beta_data'),
    9: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_year', 'alpha_param', \
        'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_data'),
    10: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_year', 'alpha_year_ptb', 'alpha_year_wt2', 'alpha_param', \
          'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_year', 'beta_year_ptb', 'beta_year_wt2', 'beta_data'),
    11: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_year', 'alpha_year_ptb', 'alpha_year_wt2', 'alpha_param', 'alpha_param_ptb', 'alpha_param_wt2', \
          'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_year', 'beta_year_ptb', 'beta_year_wt2', 'beta_data', 'beta_data_ptb', 'beta_data_wt2'),
    12: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_year', 'alpha_param', \
          'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_data'),
    13: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_year', 'alpha_param', 'alpha_param_t', \
          'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_year', 'beta_data', 'beta_data_t'),
    14: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_param', \
          'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_data', 'alpha_rate', 'beta_rate'),
    15: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_year', 'alpha_param', \
          'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_year', 'beta_data', 'alpha_rate', 'beta_rate'),
    16: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_param', \
        'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_data'),
    17: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_param', 'alpha_param_t', \
          'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_data', 'beta_data_t'),
    18: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_year', 'alpha_compute'),
    19: ('gamma', 'alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_year', 'alpha_param', \
        'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_year', 'beta_data'),
    20: ('alpha_const', 'alpha_const_ptb', 'alpha_const_wt2', 'alpha_year', 'alpha_param', \
        'beta_const', 'beta_const_ptb', 'beta_const_wt2', 'beta_year', 'beta_data'),
}

# Extracts the parameters based on model number
def extract_params(model_num, params):
    param_names = PARAMS_MAPPING[model_num]
    return [params[i] for i, name in enumerate(param_names)]

def model_1(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_year, alpha_param, beta_const, beta_year, beta_data = extract_params(1, params)
    alpha_terms = (alpha_const - alpha_year * (year - year_const) - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const - beta_year * (year - year_const) - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_2(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_param, beta_const, beta_year, beta_data = extract_params(2, params)
    alpha_terms = (alpha_const - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const - beta_year * (year - year_const) - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_3(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_year, alpha_param, beta_const, beta_data = extract_params(3, params)
    alpha_terms = (alpha_const - alpha_year * (year - year_const) - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_4(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_year, alpha_year_ptb, alpha_year_wt2, alpha_param, beta_const, beta_year, beta_data = extract_params(4, params)
    alpha_year_prime = prime(alpha_year, alpha_year_ptb, alpha_year_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const - alpha_year_prime * (year - year_const) - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const - beta_year * (year - year_const) - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_5(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_year, alpha_param, beta_const, beta_year, beta_year_ptb, beta_year_wt2, beta_data = extract_params(5, params)
    beta_year_prime = prime(beta_year, beta_year_ptb, beta_year_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const - alpha_year * (year - year_const) - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const - beta_year_prime * (year - year_const) - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_6(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_year, alpha_year_ptb, alpha_year_wt2, alpha_param, beta_const, beta_year, beta_year_ptb, beta_year_wt2, beta_data = extract_params(6, params)
    alpha_year_prime = prime(alpha_year, alpha_year_ptb, alpha_year_wt2, category_ptb, category_wt2)
    beta_year_prime = prime(beta_year, beta_year_ptb, beta_year_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const - alpha_year_prime * (year - year_const) - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const - beta_year_prime * (year - year_const) - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_7(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_param, beta_const, beta_const_ptb, beta_const_wt2, beta_year, beta_data = extract_params(7, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const_prime - alpha_year * (year - year_const) - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_year * (year - year_const) - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_8(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_param, beta_const, beta_const_ptb, beta_const_wt2, beta_year, beta_data = extract_params(8, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const_prime - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_year * (year - year_const) - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_9(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_param, beta_const, beta_const_ptb, beta_const_wt2, beta_data = extract_params(9, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const_prime - alpha_year * (year - year_const) - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_10(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_year_ptb, alpha_year_wt2, alpha_param, beta_const, beta_const_ptb, beta_const_wt2, beta_year, beta_year_ptb, beta_year_wt2, beta_data = extract_params(10, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    alpha_year_prime = prime(alpha_year, alpha_year_ptb, alpha_year_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    beta_year_prime = prime(beta_year, beta_year_ptb, beta_year_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const_prime - alpha_year_prime * (year - year_const) - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_year_prime * (year - year_const) - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_11(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_year_ptb, alpha_year_wt2, alpha_param, alpha_param_ptb, alpha_param_wt2, beta_const, beta_const_ptb, beta_const_wt2, beta_year, beta_year_ptb, beta_year_wt2, beta_data, beta_data_ptb, beta_data_wt2 = extract_params(11, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    alpha_year_prime = prime(alpha_year, alpha_year_ptb, alpha_year_wt2, category_ptb, category_wt2)
    alpha_param_prime = prime(alpha_param, alpha_param_ptb, alpha_param_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    beta_year_prime = prime(beta_year, beta_year_ptb, beta_year_wt2, category_ptb, category_wt2)
    beta_data_prime = prime(beta_data, beta_data_ptb, beta_data_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const_prime - alpha_year_prime * (year - year_const) - alpha_param_prime * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_year_prime * (year - year_const) - beta_data_prime * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_12(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_param, beta_const, beta_const_ptb, beta_const_wt2, beta_data = extract_params(12, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const_prime - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_data * log_diff(dataset, dataset_const))
    combined_exp = np.exp(-alpha_year * (year - year_const))
    return build_model(alpha_terms, beta_terms) * combined_exp

def model_13(params, year, param, dataset, category_ptb, category_wt2, category_transformer):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_param, alpha_param_t, beta_const, beta_const_ptb, beta_const_wt2, beta_year, beta_data, beta_data_t = extract_params(13, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    alpha_param_prime = alpha_param * (1 - category_transformer) + alpha_param_t * category_transformer
    beta_data_prime = beta_data * (1 - category_transformer) + beta_data_t * category_transformer

    alpha_terms = (alpha_const_prime - alpha_year * (year - year_const) - alpha_param_prime * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_year * (year - year_const) - beta_data_prime * log_diff(dataset, dataset_const))

    return build_model(alpha_terms, beta_terms)

def model_14(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_param, \
    beta_const, beta_const_ptb, beta_const_wt2, beta_data, alpha_rate, beta_rate = extract_params(14, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)

    alpha_param_prime = alpha_param + alpha_rate * np.log(year) #np.log(year - year_const + 1)
    beta_data_prime = beta_data + beta_rate * np.log(year) #np.log(year - year_const + 1)

    param_terms = alpha_const_prime - alpha_param_prime * (np.log(param) - np.log(param_const))
    data_terms = beta_const_prime - beta_data_prime * (np.log(dataset) - np.log(dataset_const))

    return build_model(param_terms, data_terms)

def model_15(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_param, \
    beta_const, beta_const_ptb, beta_const_wt2, beta_year, beta_data, alpha_rate, beta_rate = extract_params(15, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)

    alpha_param_prime = alpha_param + alpha_rate * np.log(year)
    beta_data_prime = beta_data + beta_rate * np.log(year)

    param_terms = alpha_const_prime - alpha_year * (year - year_const) - alpha_param_prime * (np.log(param) - np.log(param_const))
    data_terms = beta_const_prime - beta_year * (year - year_const) - beta_data_prime * (np.log(dataset) - np.log(dataset_const))

    return build_model(param_terms, data_terms)

def model_16(params, year, param, dataset, category_ptb, category_wt2):
    """
    chinchilla with benchmark-specific coefficients. no algorithmic progress
    """
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_param, beta_const, beta_const_ptb, beta_const_wt2, beta_data = extract_params(16, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const_prime - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def model_17(params, year, param, dataset, category_ptb, category_wt2, category_transformer):
    """
    different scaling exponents for transformers vs non-transformer, and no other algorithmimc progress
    """
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_param, alpha_param_t, beta_const, beta_const_ptb, beta_const_wt2, beta_data, beta_data_t = extract_params(17, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    alpha_param_prime = alpha_param * (1 - category_transformer) + alpha_param_t * category_transformer
    beta_data_prime = beta_data * (1 - category_transformer) + beta_data_t * category_transformer

    alpha_terms = (alpha_const_prime - alpha_param_prime * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_data_prime * log_diff(dataset, dataset_const))

    return build_model(alpha_terms, beta_terms)

def model_18(params, year, param, dataset, category_ptb, category_wt2):
    """
    compute only
    """
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_compute = extract_params(18, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)

    compute = 6 * param * dataset
    compute_const = np.min(compute)
    alpha_terms = (alpha_const_prime - alpha_year * (year - year_const) - alpha_compute * log_diff(compute, compute_const))
    return np.exp(alpha_terms)

def model_19(params, year, param, dataset, category_ptb, category_wt2, vocab):
    """
    vocab size
    """
    gamma, alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_param, beta_const, beta_const_ptb, beta_const_wt2, beta_year, beta_data = extract_params(19, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const_prime - alpha_year * (year - year_const) - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_year * (year - year_const) - beta_data * log_diff(dataset, dataset_const))
    return gamma * np.log(vocab) + build_model(alpha_terms, beta_terms)

def model_20(params, year, param, dataset, category_ptb, category_wt2):
    """
    same as model 7, but with imputed epochs
    """
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_param, beta_const, beta_const_ptb, beta_const_wt2, beta_year, beta_data = extract_params(20, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)
    beta_const_prime = prime(beta_const, beta_const_ptb, beta_const_wt2, category_ptb, category_wt2)
    alpha_terms = (alpha_const_prime - alpha_year * (year - year_const) - alpha_param * log_diff(param, param_const))
    beta_terms = (beta_const_prime - beta_year * (year - year_const) - beta_data * log_diff(dataset, dataset_const))
    return build_model(alpha_terms, beta_terms)

def residuals(params, model_func, year, param, dataset, category_ptb, category_wt2, log_ppl, delta, category_transformer=None, vocab=None):
    if model_func in {model_13, model_17}:
        residuals_val = log_ppl - model_func(params, year, param, dataset, category_ptb, category_wt2, category_transformer)
    elif model_func == model_19:
        residuals_val = log_ppl - model_func(params, year, param, dataset, category_ptb, category_wt2, vocab)
    else:
        residuals_val = log_ppl - model_func(params, year, param, dataset, category_ptb, category_wt2)

    l1_reg = delta * np.sum(np.abs(params))
    # print("residual start")
    val = np.mean(np.square(residuals_val)) + l1_reg
    # print("residual end")
    return val

def num_params_in_model(model_num):
    return len(PARAMS_MAPPING[model_num])

model_numbers = np.arange(1, len(PARAMS_MAPPING) + 1, 1)
for model_num in model_numbers:
    length = num_params_in_model(model_num)
    print(f"Model {model_num}: {length} parameters")

Model 1: 6 parameters
Model 2: 5 parameters
Model 3: 5 parameters
Model 4: 8 parameters
Model 5: 8 parameters
Model 6: 10 parameters
Model 7: 10 parameters
Model 8: 9 parameters
Model 9: 9 parameters
Model 10: 14 parameters
Model 11: 18 parameters
Model 12: 9 parameters
Model 13: 12 parameters
Model 14: 10 parameters
Model 15: 12 parameters
Model 16: 8 parameters
Model 17: 10 parameters
Model 18: 5 parameters
Model 19: 11 parameters
Model 20: 10 parameters


# Cross validation

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import re
from sklearn.metrics import r2_score

# random.seed(123)
random.seed(42)

# Initialize dictionaries to store all predicted and actual values
all_preds = {}
all_actuals = {}

# Initialize a dictionary to store the last optimized parameters for each model and delta
warm_start_params = {}

df4 = df_head.copy(deep=True)
df4['transformer'] = df4['Architecture'].apply(lambda x: 1 if x == 'Transformer' else 0).astype(int)

epoch_num = 1 # impute this number of epochs in lieu of NaN
df4['epoch'] = pd.to_numeric(df4['epoch'], errors='coerce')
df4['epoch'].fillna(epoch_num, inplace=True)

def effective_epochs(num_epochs, rd_prime=2.9157):
  """
  estimate of "effective data" based on https://arxiv.org/pdf/2305.16264.pdf
  2.9157 from table 1, only decay D
  """
  return 1 + rd_prime * (1 - np.exp(-num_epochs / rd_prime))

def impute_vocabulary(row):
  if not pd.isna(row['Vocabulary']):
      return row['Vocabulary']
  elif row['ptb_dummy'] == 0 and row['wt2_dummy'] == 0:
      return 268000
  elif row['wt2_dummy'] == 1:
      return 33278
  elif row['wt2_dummy'] == 0 and row['ptb_dummy'] == 1:
      return 10000

df4["Vocabulary"].replace("?", np.nan, inplace=True)
df4['Vocabulary'] = df4.apply(impute_vocabulary, axis=1)
df4['Vocabulary'] = pd.to_numeric(df4['Vocabulary'], errors='coerce')

# Data preprocessing: Train-test split
df_cv, val_data = train_test_split(df4, test_size=0.2, random_state=1)
df_cv = df_cv.assign(dataset_ptb=(df_cv['dataset_name'] == 'ptb').astype(int),
                 dataset_wt2=(df_cv['dataset_name'] == 'wt2').astype(int))
val_data = val_data.assign(dataset_ptb=(val_data['dataset_name'] == 'ptb').astype(int),
                           dataset_wt2=(val_data['dataset_name'] == 'wt2').astype(int))
df_cv.reset_index(drop=True, inplace=True)
val_data.reset_index(drop=True, inplace=True)

# Hyperparameters for LOOCV
selected_models = [model_1, model_2, model_3, model_4, model_5, model_6, model_7, model_8, model_9, model_10, \
                   model_11, model_12, model_13, model_14, model_15, model_16, model_17, model_18, model_19, model_20]
# selected_models = [model_20]

delta_range = [0, 0.001, 0.0025, 0.005, 0.01, 0.02]
# delta_range = [0]

param_const = np.min(df_cv['param'])
dataset_const = np.min(df_cv['dataset'])
year_const = np.min(df_cv['publication_date'])
n = len(df_cv)

# Initialize mean_mse based on actual selected_models list
mean_mse = {'MSE_model{}_delta{}'.format(re.search(r"\d+", model.__name__).group(), delta): [] for model in selected_models for delta in delta_range}
val_mean_mse = {'Val_MSE_model{}_delta{}'.format(re.search(r"\d+", model.__name__).group(), delta): [] for model in selected_models for delta in delta_range}

# Initialize a set to store worst-performing model-delta combinations
excluded_model_deltas = set()

all_param_estimates = []

for i in range(n):
    train_data = df_cv.drop(i).copy()
    test_data = df_cv.iloc[i].copy()
    param_estimates = []

    for model_idx, model in enumerate(selected_models):

        model_num_str = re.search(r'\d+', model.__name__).group()
        model_num = int(model_num_str)

        if model_num == 20:
            train_data.loc[:,"dataset"] = train_data["dataset"] * effective_epochs(train_data["epoch"])
            test_data.loc["dataset"] = test_data["dataset"] * effective_epochs(test_data["epoch"])

        for delta in delta_range:
            model_delta_key = f'model{model_num_str}_delta{delta}'
            num_params = num_params_in_model(model_num)

            """
            Optimization
            """
            # Check if warm-start parameters exist for this model and delta
            p0 = warm_start_params.get(f'model{model_num_str}_delta{delta}', np.zeros(num_params))
            # p0 = np.zeros(num_params)


            res = minimize(residuals, p0, args=(model, train_data["publication_date"], train_data["param"], train_data["dataset"], \
                      train_data["dataset_ptb"], train_data["dataset_wt2"], train_data["log_ppl"], delta, train_data["transformer"], train_data["Vocabulary"]), method='SLSQP')
            param_estimates.append(res.x)

            # Store the optimized parameters for warm-starting the next iteration
            warm_start_params[f'model{model_num_str}_delta{delta}'] = res.x

            if any(np.isnan(res.x)):
                print(f"Optimization failed for model_{model_num} at iteration {i}")
                continue


            """
            Predictions and evaluation
            """
            # Test prediction
            if model_num in {13, 17}: # separates transformer vs non-transformer
                pred_value = model(res.x, test_data["publication_date"], test_data["param"], \
                               test_data["dataset"], test_data["dataset_ptb"], test_data["dataset_wt2"], test_data["transformer"])
            elif model_num == 19:
                pred_value = model(res.x, test_data["publication_date"], test_data["param"], \
                               test_data["dataset"], test_data["dataset_ptb"], test_data["dataset_wt2"], test_data["Vocabulary"])
            else:
                pred_value = model(res.x, test_data["publication_date"], test_data["param"], \
                               test_data["dataset"], test_data["dataset_ptb"], test_data["dataset_wt2"])

            if np.isnan(pred_value):
                print(f"Prediction failed for model_{model_num} at iteration {i}")
                continue

            mse_value = mean_squared_error([test_data["log_ppl"]], [pred_value])
            if not np.isnan(mse_value):
                mean_mse[f'MSE_model{model_num_str}_delta{delta}'].append(mse_value)
            else:
                print(f"Computed MSE is NaN for model_{model_idx+1} at iteration {i}")

            # Validation prediction
            if model_num in {13, 17}:
                val_pred = model(res.x, val_data["publication_date"], val_data["param"], \
                                val_data["dataset"], val_data["dataset_ptb"], val_data["dataset_wt2"], val_data["transformer"])
            elif model_num == 19:
                val_pred = model(res.x, val_data["publication_date"], val_data["param"], \
                               val_data["dataset"], val_data["dataset_ptb"], val_data["dataset_wt2"], val_data["Vocabulary"])
            else:
                val_pred = model(res.x, val_data["publication_date"], val_data["param"], \
                                val_data["dataset"], val_data["dataset_ptb"], val_data["dataset_wt2"])

            val_mse = mean_squared_error(val_data["log_ppl"], val_pred)

            if not np.isnan(val_mse):
                val_mean_mse[f'Val_MSE_model{model_num_str}_delta{delta}'].append(val_mse)
            else:
                print(f"Validation MSE is NaN for model_{model_idx+1} at iteration {i}")

            # Store the predicted and actual values
            key = f'model{model_num_str}_delta{delta}'
            all_preds.setdefault(key, []).append(pred_value)
            all_actuals.setdefault(key, []).append(test_data["log_ppl"])

    all_param_estimates.append(param_estimates)
    if i % 5 == 0:
        temp_mean_mse = {key: np.mean(val) for key, val in mean_mse.items() if len(val) > 0}
        sorted_mse = sorted(temp_mean_mse.items(), key=lambda x: x[1])
        print(f"Intermediate rankings after iteration {i}: {sorted_mse}")

mean_mse = {key: np.mean(val) for key, val in mean_mse.items() if len(val) > 0}

# Compute the mean validation MSE for each model, delta pair
val_mean_mse = {key: np.mean(val) for key, val in val_mean_mse.items() if len(val) > 0}

# Initialize a dictionary to store R2 scores
r2_scores = {}

# Loop through all the keys in all_preds and all_actuals
for key in all_preds.keys():
    # Extract the predicted and actual values
    pred_values = np.array(all_preds[key])
    actual_values = np.array(all_actuals[key])

    # Compute R2 score using scikit-learn's r2_score function
    r2 = r2_score(actual_values, pred_values)

    # Store the R2 score
    r2_scores[key] = r2

# You can sort the models based on their R2 scores if needed
sorted_r2_scores = {k: v for k, v in sorted(r2_scores.items(), key=lambda item: item[1], reverse=True)}

# Print or return the R2 scores
print(sorted_r2_scores)

Intermediate rankings after iteration 0: [('MSE_model18_delta0.02', 0.005892928407613434), ('MSE_model18_delta0.01', 0.00826688999196741), ('MSE_model18_delta0.005', 0.009595673512236008), ('MSE_model18_delta0.0025', 0.010022540296258254), ('MSE_model18_delta0.001', 0.010484490486921348), ('MSE_model18_delta0', 0.010799664654454989), ('MSE_model10_delta0.001', 0.07258591546386357), ('MSE_model10_delta0', 0.07292631502555823), ('MSE_model19_delta0.02', 0.0729948207597783), ('MSE_model6_delta0.001', 0.08844367092467867), ('MSE_model2_delta0.02', 0.08851355832977954), ('MSE_model6_delta0.0025', 0.08998004653241032), ('MSE_model5_delta0', 0.09144996417959148), ('MSE_model6_delta0.005', 0.0928014919657351), ('MSE_model5_delta0.001', 0.09445996098304352), ('MSE_model6_delta0', 0.09819018129465867), ('MSE_model8_delta0', 0.09935251083484095), ('MSE_model5_delta0.0025', 0.09996775670772981), ('MSE_model13_delta0.001', 0.10023879809848275), ('MSE_model7_delta0', 0.10118107703229398), ('MSE_mode

In [None]:
# Sort the val_mean_mse dictionary by the MSE values in ascending order
sorted__mse = {k: v for k, v in sorted(mean_mse.items(), key=lambda item: item[1])}

# Extract the top 5 model, percentile pairs
top_5_models = list(sorted__mse.keys())[:5]

# Print out the top 5 models
print("Top 5 models based on mean squared error:")
for model in top_5_models:
    print(f"{model} with MSE: {sorted__mse[model]}")

Top 5 models based on mean squared error:
MSE_model10_delta0.001 with MSE: 0.04847642103308974
MSE_model7_delta0.0025 with MSE: 0.048563819600523625
MSE_model19_delta0.02 with MSE: 0.04859473435181232
MSE_model15_delta0.01 with MSE: 0.04862315855821884
MSE_model8_delta0.005 with MSE: 0.04869930867235716


In [None]:
for model in sorted__mse.keys():
    print(f"{model} with MSE: {sorted__mse[model]}")

MSE_model10_delta0.001 with MSE: 0.04847642103308974
MSE_model7_delta0.0025 with MSE: 0.048563819600523625
MSE_model19_delta0.02 with MSE: 0.04859473435181232
MSE_model15_delta0.01 with MSE: 0.04862315855821884
MSE_model8_delta0.005 with MSE: 0.04869930867235716
MSE_model15_delta0.005 with MSE: 0.048837041442145145
MSE_model7_delta0.01 with MSE: 0.04891918602256304
MSE_model12_delta0.02 with MSE: 0.048973521894055565
MSE_model10_delta0.0025 with MSE: 0.04900021080864785
MSE_model19_delta0.01 with MSE: 0.04901942297035777
MSE_model10_delta0.005 with MSE: 0.04912598962168982
MSE_model12_delta0.005 with MSE: 0.04919516440616078
MSE_model7_delta0.02 with MSE: 0.049196912806226865
MSE_model15_delta0.02 with MSE: 0.049208871605081556
MSE_model9_delta0.02 with MSE: 0.0492496779731775
MSE_model15_delta0.0025 with MSE: 0.0493659178353991
MSE_model12_delta0.0025 with MSE: 0.04939151143709678
MSE_model20_delta0.01 with MSE: 0.049393034364751905
MSE_model12_delta0 with MSE: 0.04945719102884973
MSE

In [None]:
r2_scores = {}

# Loop through all the keys in all_preds and all_actuals
for key in all_preds.keys():
    # Extract the predicted and actual values
    pred_values = np.array(all_preds[key])
    actual_values = np.array(all_actuals[key])

    # Compute R2 score using scikit-learn's r2_score function
    r2 = r2_score(actual_values, pred_values)

    # Store the R2 score
    r2_scores[key] = r2

# You can sort the models based on their R2 scores if needed
sorted_r2_scores = {k: v for k, v in sorted(r2_scores.items(), key=lambda item: item[1], reverse=True)}

# Print or return the R2 scores
print(sorted_r2_scores)

best_model = next(iter(sorted_r2_scores))

def extract_model_and_delta(s):
    # Regular expression pattern to capture the model name and number separately, and then the delta value
    pattern = r'(model)(\d+)_delta([\d.]+)'

    # Matching the pattern with the input string
    match = re.match(pattern, s)
    if match:
        model_prefix = match.group(1)  # 'model'
        model_number = match.group(2)  # e.g., '13'
        delta_value = match.group(3)   # e.g., '0.005'

        # Concatenating model prefix and number with an underscore
        model_name = f"{model_prefix}_{model_number}"

        return model_name, float(delta_value)
    else:
        raise ValueError("String does not match the expected pattern")

model_name, delta_value = extract_model_and_delta(best_model)
model_num = int(re.search("\d+", model_name)[0])
model_name = globals()[model_name]

{'model10_delta0.001': 0.9108407867274986, 'model7_delta0.0025': 0.9106800407947857, 'model19_delta0.02': 0.9106231815043373, 'model15_delta0.01': 0.9105709028949208, 'model8_delta0.005': 0.9104308454376562, 'model15_delta0.005': 0.91017752340739, 'model7_delta0.01': 0.9100264407571329, 'model12_delta0.02': 0.909926504675807, 'model10_delta0.0025': 0.9098774176644842, 'model19_delta0.01': 0.9098420821098616, 'model10_delta0.005': 0.909646081691692, 'model12_delta0.005': 0.9095188534592795, 'model7_delta0.02': 0.909515637752116, 'model15_delta0.02': 0.9094936427888979, 'model9_delta0.02': 0.9094185905552881, 'model15_delta0.0025': 0.909204798892337, 'model12_delta0.0025': 0.9091577264116639, 'model20_delta0.01': 0.9091549253997759, 'model12_delta0': 0.9090369266371766, 'model7_delta0.005': 0.9089219669152078, 'model8_delta0.001': 0.908894218746116, 'model9_delta0.01': 0.9087123408934826, 'model10_delta0': 0.908703121599614, 'model19_delta0.005': 0.9086110134911495, 'model15_delta0.001':

In [None]:
all_param_estimates_to_save = [
    [param_estimates.tolist() if isinstance(param_estimates, np.ndarray) else param_estimates for param_estimates in loocv_loop]
    for loocv_loop in all_param_estimates
]

with open("all_param_estimates.json", 'w') as f:
      json.dump(all_param_estimates_to_save, f)

files.download("all_param_estimates.json")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
text = input()

MSE_model10_delta0.001 with MSE: 0.04847642103308974 MSE_model7_delta0.0025 with MSE: 0.048563819600523625 MSE_model19_delta0.02 with MSE: 0.04859473435181232 MSE_model15_delta0.01 with MSE: 0.04862315855821884 MSE_model8_delta0.005 with MSE: 0.04869930867235716 MSE_model15_delta0.005 with MSE: 0.048837041442145145 MSE_model7_delta0.01 with MSE: 0.04891918602256304 MSE_model12_delta0.02 with MSE: 0.048973521894055565 MSE_model10_delta0.0025 with MSE: 0.04900021080864785 MSE_model19_delta0.01 with MSE: 0.04901942297035777 MSE_model10_delta0.005 with MSE: 0.04912598962168982 MSE_model12_delta0.005 with MSE: 0.04919516440616078 MSE_model7_delta0.02 with MSE: 0.049196912806226865 MSE_model15_delta0.02 with MSE: 0.049208871605081556 MSE_model9_delta0.02 with MSE: 0.0492496779731775 MSE_model15_delta0.0025 with MSE: 0.0493659178353991 MSE_model12_delta0.0025 with MSE: 0.04939151143709678 MSE_model20_delta0.01 with MSE: 0.049393034364751905 MSE_model12_delta0 with MSE: 0.04945719102884973 MSE

In [None]:
import re

# Parse the data to extract model numbers, delta values, and MSE values
parsed_data = {}
data = re.findall("model\d+_delta\d\.*\d*\swith\sMSE:\s\d.\d+", text)

for line in data:
    match = re.search(r'model(\d+)_delta([0-9.]+) with MSE: ([0-9.]+)', line)
    if match:
        model_number = int(match.group(1))
        delta_value = float(match.group(2))
        mse_value = round(float(match.group(3)), 5)  # Rounding to 3 significant figures

        # Initialize dictionary structure if needed
        if model_number not in parsed_data:
            parsed_data[model_number] = {}

        parsed_data[model_number][delta_value] = mse_value

# Initialize LaTeX table
latex_table = "\\begin{tabular}{lrrrrrr}\n"
latex_table += "Model/Delta & 0 & 0.001 & 0.0025 & 0.005 & 0.01 & 0.02 \\\\\n"

# Fill in the table with the parsed data
for model in range(1, 21):
    row = [f"{model}"]
    for delta in [0, 0.001, 0.0025, 0.005, 0.01, 0.02]:
        mse_value = parsed_data.get(model, {}).get(delta, '-')
        row.append(str(mse_value))
    latex_table += " & ".join(row) + " \\\\\n"

# Close the table environment
latex_table += "\\end{tabular}"

# Print the LaTeX table code
print(latex_table)


\begin{tabular}{lrrrrrr}
Model/Delta & 0 & 0.001 & 0.0025 & 0.005 & 0.01 & 0.02 \\
1 & 0.05304 & 0.05309 & 0.05308 & 0.05295 & 0.05281 & 0.05231 \\
2 & 0.05371 & 0.0537 & 0.05393 & 0.05852 & 0.0586 & 0.05921 \\
3 & 0.05294 & 0.0529 & 0.05262 & 0.05249 & 0.05235 & 0.05222 \\
4 & 0.05118 & 0.0512 & 0.05114 & 0.05134 & 0.05227 & 0.05199 \\
5 & 0.05011 & 0.0501 & 0.05004 & 0.05015 & 0.05007 & 0.05113 \\
6 & 0.05034 & 0.05021 & 0.05018 & 0.0503 & 0.05049 & 0.05162 \\
7 & 0.05049 & 0.05028 & 0.04856 & 0.04952 & 0.04892 & 0.0492 \\
8 & 0.05006 & 0.04953 & 0.04975 & 0.0487 & 0.05208 & 0.0525 \\
9 & 0.0505 & 0.04996 & 0.05097 & 0.05042 & 0.04963 & 0.04925 \\
10 & 0.04964 & 0.04848 & 0.049 & 0.04913 & 0.05009 & 0.05105 \\
11 & 0.05281 & 0.05186 & 0.05196 & 0.0523 & 0.05101 & 0.05127 \\
12 & 0.04946 & 0.04997 & 0.04939 & 0.0492 & 0.05012 & 0.04897 \\
13 & 0.05227 & 0.05267 & 0.0507 & 0.05028 & 0.05005 & 0.05018 \\
14 & 0.06314 & 0.06377 & 0.0652 & 0.06359 & 0.06302 & 0.06336 \\
15 & 0.05068 & 0.0

## Double checking the compute model result

In [None]:
def compute_model(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_compute = extract_params(18, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)

    compute = 6 * param * dataset
    compute_const = np.min(compute)
    alpha_terms = (alpha_const_prime - alpha_year * (year - year_const) - alpha_compute * log_diff(compute, compute_const))
    return np.exp(alpha_terms)

def compute_residuals(params, year, param, dataset, category_ptb, category_wt2, log_ppl, delta):
    residuals_val = log_ppl - compute_model(params, year, param, dataset, category_ptb, category_wt2)
    l1_reg = delta * np.sum(np.abs(params))
    val = np.mean(np.square(residuals_val)) + l1_reg
    return val

df_main = df_head.copy(deep=True)
np.random.seed(0)

df_main['ptb_dummy'] = (df_main['dataset_name'] == 'ptb').astype(int)
df_main['wt2_dummy'] = (df_main['dataset_name'] == 'wt2').astype(int)
df_main['transformer'] = df_main['Architecture'].apply(lambda x: 1 if x == 'Transformer' else 0).astype(int)

# Initialize variables for LOOCV
mse_values = []

for i in range(len(df_main)):
    # Splitting the dataset into training and test set for LOOCV
    test_data = df_main.iloc[i]
    train_data = df_main.drop(i)

    # Extracting training data
    year_train = train_data["publication_date"]
    param_train = train_data["param"]
    dataset_train = train_data["dataset"]
    log_ppl_train = train_data["log_ppl"]
    ptb_dummy_train = train_data['ptb_dummy']
    wt2_dummy_train = train_data['wt2_dummy']
    transformer_dummy_train = train_data['transformer']

    # Constants for the training data
    param_const = np.min(param_train)
    dataset_const = np.min(dataset_train)
    year_const = np.min(year_train)

    # Number of parameters and initial guess
    num_params = num_params_in_model(model_num)
    p0 = np.zeros(num_params)

    # Optimization
    res = minimize(compute_residuals, p0, args=(year_train, param_train, dataset_train, ptb_dummy_train, wt2_dummy_train, log_ppl_train, delta), method='SLSQP')

    # Predicting for the test data
    predicted_log_ppl = compute_model(res.x, test_data["publication_date"], test_data["param"], test_data["dataset"], test_data['ptb_dummy'], test_data['wt2_dummy'])

    # Calculating MSE for the prediction
    mse = mean_squared_error([test_data["log_ppl"]], [predicted_log_ppl])
    mse_values.append(mse)

# Calculating the average MSE over all LOOCV iterations
average_mse = np.mean(mse_values)
print(f"Average Mean Squared Error (LOOCV): {average_mse}")


Average Mean Squared Error (LOOCV): 0.6686916447247993


Try accounting for epochs to see if it makes much of a difference

In [None]:
df_main = df_head.copy(deep=True)

np.random.seed(0)

num_params = num_params_in_model(model_num)
p0 = np.zeros(num_params)
epoch_num = 1 # impute this number of epochs is NaN

df_main['ptb_dummy'] = (df_main['dataset_name'] == 'ptb').astype(int)
df_main['wt2_dummy'] = (df_main['dataset_name'] == 'wt2').astype(int)
df_main['transformer'] = df_main['Architecture'].apply(lambda x: 1 if x == 'Transformer' else 0).astype(int)
df_main['epoch'] = pd.to_numeric(df_main['epoch'], errors='coerce')
df_main['epoch'].fillna(epoch_num, inplace=True)

def effective_epochs(num_epochs, rd_prime=2.9157):
  """
  estimate of "effective data" based on https://arxiv.org/pdf/2305.16264.pdf
  2.9157 from table 1, only decay D
  """
  return 1 + rd_prime * (1 - np.exp(-num_epochs / rd_prime))

year = df_main["publication_date"]
param = df_main["param"]
dataset = df_main["dataset"] * effective_epochs(df_main["epoch"])
log_ppl = df_main["log_ppl"]
ptb_dummy = df_main['ptb_dummy']
wt2_dummy = df_main['wt2_dummy']
transformer_dummy = df_main['transformer']

param_const = np.min(param)
dataset_const = np.min(dataset)
year_const = np.min(year)

res = minimize(residuals, p0, args=(model_name, year, param, dataset, ptb_dummy, wt2_dummy, log_ppl, delta, transformer_dummy), method='SLSQP')

print(res)

print("\n")
params_optimized = pd.Series(res.x, index=PARAMS_MAPPING[model_num])
print("Optimized parameters:")
print(params_optimized)

In [None]:
def compute_model(params, year, param, dataset, category_ptb, category_wt2):
    alpha_const, alpha_const_ptb, alpha_const_wt2, alpha_year, alpha_compute = extract_params(18, params)
    alpha_const_prime = prime(alpha_const, alpha_const_ptb, alpha_const_wt2, category_ptb, category_wt2)

    compute = 6 * param * dataset
    compute_const = np.min(compute)
    alpha_terms = (alpha_const_prime - alpha_year * (year - year_const) - alpha_compute * log_diff(compute, compute_const))
    return np.exp(alpha_terms)

def compute_residuals(params, year, param, dataset, category_ptb, category_wt2, log_ppl, delta):
    residuals_val = log_ppl - compute_model(params, year, param, dataset, category_ptb, category_wt2)
    l1_reg = delta * np.sum(np.abs(params))
    val = np.mean(np.square(residuals_val)) + l1_reg
    return val

df_main = df_head.copy(deep=True)
np.random.seed(0)

df_main['ptb_dummy'] = (df_main['dataset_name'] == 'ptb').astype(int)
df_main['wt2_dummy'] = (df_main['dataset_name'] == 'wt2').astype(int)
df_main['transformer'] = df_main['Architecture'].apply(lambda x: 1 if x == 'Transformer' else 0).astype(int)
df_main['epoch'] = pd.to_numeric(df_main['epoch'], errors='coerce')
df_main['epoch'].fillna(epoch_num, inplace=True)
df_main['effective_data'] = df_main["dataset"] * effective_epochs(df_main["epoch"])

def effective_epochs(num_epochs, rd_prime=2.9157):
  """
  estimate of "effective data" based on https://arxiv.org/pdf/2305.16264.pdf
  2.9157 from table 1, only decay D
  """
  return 1 + rd_prime * (1 - np.exp(-num_epochs / rd_prime))

# Initialize variables for LOOCV
mse_values = []

for i in range(len(df_main)):
    # Splitting the dataset into training and test set for LOOCV
    test_data = df_main.iloc[i]
    train_data = df_main.drop(i)

    # Extracting training data
    year_train = train_data["publication_date"]
    param_train = train_data["param"]
    dataset_train = train_data["effective_data"]
    log_ppl_train = train_data["log_ppl"]
    ptb_dummy_train = train_data['ptb_dummy']
    wt2_dummy_train = train_data['wt2_dummy']
    transformer_dummy_train = train_data['transformer']

    # Constants for the training data
    param_const = np.min(param_train)
    dataset_const = np.min(dataset_train)
    year_const = np.min(year_train)

    # Number of parameters and initial guess
    num_params = num_params_in_model(model_num)
    p0 = np.zeros(num_params)

    # Optimization
    res = minimize(compute_residuals, p0, args=(year_train, param_train, dataset_train, ptb_dummy_train, wt2_dummy_train, log_ppl_train, delta), method='SLSQP')

    # Predicting for the test data
    predicted_log_ppl = compute_model(res.x, test_data["publication_date"], test_data["param"], test_data["effective_data"], test_data['ptb_dummy'], test_data['wt2_dummy'])

    # Calculating MSE for the prediction
    mse = mean_squared_error([test_data["log_ppl"]], [predicted_log_ppl])
    mse_values.append(mse)

# Calculating the average MSE over all LOOCV iterations
average_mse = np.mean(mse_values)
print(f"Average Mean Squared Error (LOOCV): {average_mse}")


Average Mean Squared Error (LOOCV): 0.7397449268180392
