# Setup

In [None]:
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
import json
import datetime
from itertools import combinations_with_replacement
import numpy as np
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy.stats import chi2
from sklearn.model_selection import KFold
import statsmodels.api as sm
from tqdm import tqdm

from data import *
from plotting import *
from regression import *
from utils import *

In [53]:
top_n = 4  # Filter to the top n models by training compute at time of release. Default: 4.
cutoff_date = '2019-01-01'  # When to start the regressions from. Default: '2018-01-01'.
top_n_cutoff_date = '2019-01-01'  # When to start the regressions from. Default: '2018-01-01'.
save = True  # Whether to save the plots. Default: True.

In [54]:
results_dir = 'results/compute/20250210_slowdown/'
os.makedirs(results_dir, exist_ok=True)
os.makedirs(results_dir + 'plot_data', exist_ok=True)

In [55]:
colors = {'Protein language model': 'blue', 'Specialized model': 'red', 'All': 'blue'}

# Data preparation

In [56]:
# Load data

def load_pcd_df():
    return pd.read_csv('data/biological_ai_models.csv')

pcd_df = load_pcd_df()
pcd_df.Task[pcd_df.Model == 'ESM3 (98B)'] = 'Protein language model' # ESM 3 is not correctly tagged
pcd_df


ChainedAssignmentError: behaviour will change in pandas 3.0!
You are setting values through chained assignment. Currently this works in certain cases, but when using Copy-on-Write (which will become the default behaviour in pandas 3.0) this will never work to update the original DataFrame or Series, because the intermediate object on which we are setting values will behave as a copy.
A typical example is when you are setting values in a column of a DataFrame, like:

df["col"][row_indexer] = value

Use `df.loc[row_indexer, "col"] = values` instead, to perform the assignment in a single step and ensure this keeps updating the original `df`.

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy




A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



Unnamed: 0,Model,Domain,Task,Organization,Authors,Publication date,Reference,Link,Citations,Notability criteria,...,Batch size,Batch size notes,Organization categorization,Training code accessibility,Accessibility notes,Organization categorization (from Organization),Possibly over 1e23 FLOP,Frontier model,Training power draw (W),Training compute estimation method
0,DNCON2,Biology,"Proteins,Protein folding prediction",University of Missouri,"Badri Adhikari, Jie Hou, Jianlin Cheng",2018-05-01,DNCON2: improved protein contact prediction us...,https://academic.oup.com/bioinformatics/articl...,173.0,,...,,,Academia,Open source,license: https://github.com/multicom-toolbox/D...,Academia,,,,Hardware
1,DNABERT,Biology,Protein or nucleotide language model (pLM/nLM),Northeastern University,"Yanrong Ji, Zhihan Zhou, Han Liu, Ramana V Dav...",2021-08-15,DNABERT: pre-trained Bidirectional Encoder Rep...,https://academic.oup.com/bioinformatics/articl...,479.0,SOTA improvement,...,,,Academia,Open source,"Apache 2.0, code and weights: https://github.c...",Academia,,,,"Hardware,Operation counting"
2,ProteinBERT,Biology,"Proteins,Protein generation","Hebrew University of Jerusalem,Ben-Gurion Univ...","Nadav Brandes, Dan Ofer, Yam Peleg, Nadav Rapp...",2022-02-10,ProteinBERT: a universal deep-learning model o...,https://academic.oup.com/bioinformatics/articl...,386.0,SOTA improvement,...,26008.0,"Supplementary materials: ""During pretraining w...","Academia,Academia,Industry",,,"Academia,Academia,Industry",,,511.396755,Hardware
3,DistilProtBert,Biology,"Proteins,Protein folding prediction",Bar-Ilan University,"Yaron Geffen, Yanay Ofran, Ron Unger",2022-09-18,DistilProtBert: a distilled protein language m...,https://academic.oup.com/bioinformatics/articl...,23.0,,...,,,Academia,,,Academia,,,,Hardware
4,BERT-RBP,Biology,"Proteins,Protein interaction prediction",Waseda University,"Keisuke Yamada, Michiaki Hamada",2022-04-07,Prediction of RNA–protein interactions using a...,https://academic.oup.com/bioinformaticsadvance...,36.0,SOTA improvement,...,,,Academia,Open (non-commercial),No clear license: https://github.com/kkyamada/...,Academia,,,,Hardware
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
345,CryoChains,Biology,Cryo-EM image reconstruction,"University of California Santa Barbara (UCSB),...","Bongjin Koo, Julien Martel, Ariana Peck, Axel ...",2023-07-15,CryoChains: Heterogeneous Reconstruction of Mo...,https://arxiv.org/abs/2306.07274,,,...,,,"Academia,Academia",,,"Academia,Academia",,,,
346,CryoDRGN,Biology,Cryo-EM image reconstruction,Massachusetts Institute of Technology (MIT),"Ellen D. Zhong, Tristan Bepler, Bonnie Berger,...",2021-02-04,CryoDRGN: reconstruction of heterogeneous cryo...,https://www.nature.com/articles/s41592-020-010...,,,...,,,Academia,,,Academia,,,,
347,AMPLIFY,Biology,Protein or nucleotide language model (pLM/nLM),"Chandar Research Lab,Mila - Quebec AI (origina...","Quentin Fournier, Robert M. Vernon, Almer van ...",2024-09-23,Protein Language Models: Is Scaling Necessary?,https://www.biorxiv.org/content/10.1101/2024.0...,3.0,,...,,,"Academia,Industry,Academia,Research collective",,,"Academia,Industry,Academia,Research collective",,,,Hardware
348,Evo 2 40B,Biology,Protein or nucleotide language model (pLM/nLM),"Arc Institute,Stanford University,NVIDIA,Liqui...","Garyk Brixi, Matthew G. Durrant, Jerome Ku, Mi...",2025-02-19,Genome modeling and design across all domains ...,https://arcinstitute.org/manuscripts/Evo2,,,...,,,"Academia,Industry,Industry,Academia,Academia,A...",,,"Academia,Industry,Industry,Academia,Academia,A...",checked,,,


In [57]:
def find_rolling_top_models(df, n):
    """Find the models which were in the top n FLOP values when they were released."""
    # This set will keep track of models that were ever in the top n at their release
    ever_in_top_n = set()

    # Iterate over each date in the DataFrame
    for current_date in df['date'].unique():
        # Get all entries up to the current date
        historical_data = df[df['date'] <= current_date]
        # Find top 10 models by flop count in this subset
        top_n_models = historical_data.nlargest(n, 'flop')['Model']
        # Update the set of models that were ever in top n
        ever_in_top_n.update(top_n_models)

    # Return DataFrame filtered to only include models that were ever in the top 10
    return df[df['Model'].isin(ever_in_top_n)]


def filter_top_models_within_category(df, top_n, cutoff_date, category):
    """Find the models which were in the top-n by compute when they were released,
    among models in the specified category. The top-n models in the specified category
    are seeded with the overall top-n models before the cutoff date.
    """
    # Filter top-n models within the category, but seeded with overall top-n models
    top_models_df = find_rolling_top_models(df, top_n)
    top_n_models_at_cutoff_date_df = top_models_df[top_models_df['date'] <= cutoff_date].nlargest(top_n, 'flop')
    category_df = df[df['category'] == category]

    # This set will keep track of models that were ever in the top 10 at their release
    ever_in_top_n = set()

    # Iterate over each date in the DataFrame
    for current_date in category_df['date'].unique():
        # Get all entries up to the current date
        category_since_cutoff = category_df[(category_df['date'] <= current_date) & (category_df['date'] > cutoff_date)]
        historical_data = pd.concat([category_since_cutoff, top_n_models_at_cutoff_date_df])
        # Find top 10 models by flop count in this subset
        top_n_models_df = historical_data.nlargest(top_n, 'flop')
        # Update the set of models that were ever in top n
        # Filter out the models that aren't in the category
        ever_in_top_n.update(top_n_models_df[top_n_models_df['category'].str.contains(category)]['Model'])

    # Return DataFrame filtered to only include models that were ever in the top 10
    new_df = df[df['Model'].isin(ever_in_top_n)]
    # Assign the category to the new DataFrame (overwrites cases with both US and China)
    # E.g. if a "USA,China" model is top-10 among models affiliated with China, then it's just "China"
    new_df.loc[:, 'category'] = category

    return new_df


def filter_top_models_in_both_categories(df, top_n, cutoff_date):
    # Get top models for Open and Closed categories
    top_plm = filter_top_models_within_category(df, top_n, cutoff_date, category='Protein language model')
    top_spec = filter_top_models_within_category(df, top_n, cutoff_date, category='Specialized model')
    # Combine the results
    df_filtered = pd.concat([top_plm, top_spec])
    # Sort the combined DataFrame by date
    df_filtered = df_filtered.sort_values('date')
    return df_filtered

In [58]:
df_filtered = (pcd_df[['Model', 'Training compute (FLOP)', 'Publication date', 'Task']]
    .rename(columns={'Training compute (FLOP)': 'flop', 'Publication date': 'date', 'Task': 'category'})
    .assign(date=lambda x: pd.to_datetime(x['date']), log_flop=lambda x: np.log10(x['flop']), category=lambda x: x['category'].apply(lambda y: 'Protein language model' if 'language model' in y.lower() else 'Specialized model'))
    .sort_values('date'))

#dff2 = df_filtered.copy()
#dff2.category = 'All'

#df_filtered = pd.concat([df_filtered, dff2])

In [59]:
df_filtered.dropna(subset=['flop'], inplace=True)
df_filtered = filter_top_models_in_both_categories(df_filtered, top_n, top_n_cutoff_date)
df_filtered.category.value_counts()
# df_filtered = find_rolling_top_models(df_filtered, 8)#top_n)

Unnamed: 0_level_0,count
category,Unnamed: 1_level_1
Specialized model,17
Protein language model,16


In [60]:
plm_df = df_filtered[df_filtered['category'] == 'Protein language model']
spec_df = df_filtered[df_filtered['category'] == 'Specialized model']

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=plm_df['date'],
    y=plm_df['log_flop'],
    mode='markers',
    marker=dict(color=colors['Protein language model'], opacity=0.5),
    text=plm_df['Model'],
    hoverinfo='text',
    name=f'Top-{top_n} Protein language model'
))

fig.add_trace(go.Scatter(
    x=spec_df['date'],
    y=spec_df['log_flop'],
    mode='markers',
    marker=dict(color=colors['Specialized model'], opacity=0.5),
    text=spec_df['Model'],
    hoverinfo='text',
    name=f'Top-{top_n} Specialized model'
))

fig.update_layout(
    width=800,
    height=400,
    xaxis_title='Date',
    yaxis_title='Log FLOP',
    title=f'Top-{top_n} models',
    margin=dict(t=50, l=60, r=60, b=50),
)

save_plot(fig, results_dir, f'top_{top_n}_models')

fig.show()

# Regression analysis

In [61]:
dep_var = 'log_flop'

In [62]:
@dataclass
class FitResult:
    df: pd.DataFrame
    p: int = None
    bic: float = None
    rss: float = None
    mse: float = None
    predict: Callable = None

@dataclass
class HyperbolicFitResult(FitResult):
    params: tuple[float] = None

@dataclass
class KinkedFitResult(FitResult):
    break_points: tuple[float] = None
    break_points_dt: float = None
    oom_year_slopes: tuple[float] = None

    # Model properties for each breakpoint combination
    # (for debugging)
    bics: tuple[float] = None
    rsss: tuple[float] = None
    mses: tuple[float] = None
    break_points_list: tuple[tuple[float]] = None
    break_points_dt_list: tuple[tuple[float]] = None

def fit_hyperbolic(df):
    def hyperbolic_model(t, A, B, k):
        return A / (1 + B * np.exp(-k * t))

    # Prepare data for curve fitting
    timestamp = pd.to_datetime(df['date']).apply(lambda date: date.toordinal()).values

    # Initial guess for the parameters
    # initial_guess = [0, 0, 0]
    initial_guess = [1.72373207e-02, -9.45447534e-01, -7.50101861e-08]  # Updated initial guess

    # Fit the model to the data
    try:
      params, covariance = curve_fit(hyperbolic_model, timestamp, df[dep_var], p0=initial_guess, maxfev=100000, ftol=1e-10)
    except RuntimeError as e:
      print("FATAL ERROR WHEN FITTING HYPERBOLIC")
      return None

    # Extracting parameters
    A, B, k = params

    # Compute predictions to calculate residuals
    predicted_log_y = hyperbolic_model(timestamp, *params)

    # Compute the Residual Sum of Squares (RSS)
    rss = np.sum((df[dep_var] - predicted_log_y) ** 2)

    # Number of observations (n)
    n = len(df[dep_var])

    # Number of parameters (p)
    p = len(params) + 1

    # Calculate log-likelihood under the assumption of normally distributed errors
    # log_likelihood = -0.5 * rss
    log_likelihood = -0.5 * n * (np.log(2 * np.pi * rss/n) + 1)

    # Compute bic_hyperbolic using the provided formula
    bic = p * np.log(n) - 2 * log_likelihood

    # Compute MSE
    mse = rss / n

    fit_result = HyperbolicFitResult(
        df=df,
        p=p,
        bic=bic,
        rss=rss,
        mse=mse,
        params=params,
        predict=lambda date: hyperbolic_model(date.apply(lambda d: d.toordinal()), *params)
    )

    return fit_result

def fit_n_phase_exponential(df, kink_count=0, allow_discontinuities=False, min_n_segment=None):
    if min_n_segment is None:
        if top_n == 1:
            # Top-1 has few points, so larger segments don't work
            # We've found that 4 is the highest number that works
            min_n_segment = 4
        else:
            min_n_segment = 10

    # Generate monthly breakpoints between 2010 and 2024
    one_month = pd.DateOffset(months=1)
    break_point_grid = pd.date_range(start=df['date'].min() - one_month, end=df['date'].max() - 4*one_month, freq='MS')
    break_point_grid = [x.toordinal() for x in break_point_grid]

    x = pd.to_datetime(df['date']).apply(lambda date: date.toordinal()).values
    y = df[dep_var].values

    break_points_list = []
    bics = []
    rsss = []
    mses = []
    models = []

    for break_points in combinations_with_replacement(break_point_grid, kink_count):
        # Model predictors

        intercept_change_points = (0,)
        if allow_discontinuities:
            intercept_change_points += break_points
        slope_change_points = (0,) + break_points

        predictors = np.zeros((len(x), len(intercept_change_points) + len(slope_change_points)))

        for i, intercept_point in enumerate(intercept_change_points):
            predictors[:, i] = (x >= intercept_point).astype(int)

        for i, break_point in enumerate(slope_change_points):
            predictors[:, len(intercept_change_points) + i] = np.maximum(x - break_point, 0)

        # Fit the model
        model = sm.OLS(y, predictors).fit()

        # Check for negative discontinuities if discontinuities are allowed
        invalid_discontinuity = False
        if allow_discontinuities and break_points:
            # For each breakpoint, compare the predicted value just before and after
            for break_point in break_points:
                # Create predictor matrices for points just before and after the breakpoint
                before_predictors = np.zeros((1, len(intercept_change_points) + len(slope_change_points)))
                after_predictors = np.zeros((1, len(intercept_change_points) + len(slope_change_points)))

                # Fill in the predictor matrices
                for i, intercept_point in enumerate(intercept_change_points):
                    before_predictors[0, i] = (break_point - 1 >= intercept_point)
                    after_predictors[0, i] = (break_point >= intercept_point)

                for i, slope_point in enumerate(slope_change_points):
                    before_predictors[0, len(intercept_change_points) + i] = max(0, break_point - 1 - slope_point)
                    after_predictors[0, len(intercept_change_points) + i] = max(0, break_point - slope_point)

                # Get predictions
                before_value = model.predict(before_predictors)[0]
                after_value = model.predict(after_predictors)[0]

                # Check if there's a negative discontinuity
                if after_value < before_value:
                    invalid_discontinuity = True
                    break

        if invalid_discontinuity:
            continue

        # Calculate BIC manually based on log-likelihood
        n = len(x) # Number of observations
        p = len(model.params) + 2*kink_count + 1 # Number of parameters

        # Calculate log-likelihood under the assumption of normally distributed errors
        # We have to iterate over all points to get their individual log-likelihoods
        log_likelihood = 0
        rss = 0
        invalid_model = False # Discard models with segments with less than 2 points
        for i, break_point in enumerate(slope_change_points):
            left_x = break_point
            right_x = slope_change_points[i + 1] if i + 1 < len(slope_change_points) else np.inf

            segment_predictors = predictors[(left_x <= x) & (x < right_x), :]
            segment_y = y[(left_x <= x) & (x < right_x)]
            segment_n = len(segment_y)

            assert min_n_segment > 2

            if segment_n < min_n_segment:
                invalid_model = True
                break

            y_pred = model.predict(segment_predictors)

            segment_rss = np.sum((y_pred - segment_y)**2)
            if segment_rss == 0:
                print(f"segment_rss={segment_rss}")
                print(f"y_pred={y_pred}")
                print(f"segment_y={segment_y}")
                invalid_model = True
                break
            segment_mse = segment_rss / segment_n

            segment_log_likelihood = -segment_n/2 * (np.log(2*np.pi) + np.log(segment_rss/segment_n) + 1)
            log_likelihood += segment_log_likelihood
            rss += segment_rss

        if invalid_model:
            continue

        # Compute BIC using the manual method based on the log-likelihood
        bic = p * np.log(n) - 2 * log_likelihood
        # bic = n*np.log(rss/n) + p*np.log(n)

        bics.append(bic)
        rsss.append(rss)
        mses.append(rss/len(df))
        models.append(model)
        break_points_list.append(break_points)

    if len(bics) == 0:
        return None

    # Prepare the result object
    best_bic = min(bics)
    best_idx = bics.index(best_bic)
    best_rss = rsss[best_idx]
    best_mse = mses[best_idx]
    best_model = models[best_idx]
    best_break_points = break_points_list[best_idx]

    p = len(best_model.params) + 2*kink_count + 1 # Number of parameters

    intercept_change_points = (0,)
    if allow_discontinuities:
        intercept_change_points += best_break_points
    slope_change_points = (0,) + best_break_points

    intercepts = best_model.params[:len(intercept_change_points)]
    oom_year_slopes = 365 * np.cumsum(best_model.params[len(intercepts):])

    def predict(date):
        if not isinstance(date, pd.Series):
            date = pd.Series(date)
        x = pd.to_datetime(date).apply(lambda date: date.toordinal()).values

        predictors = np.zeros((len(x), len(intercept_change_points) + len(slope_change_points)))

        for i, intercept_point in enumerate(intercept_change_points):
            predictors[:, i] = (x >= intercept_point).astype(int)

        for i, break_point in enumerate(slope_change_points):
            predictors[:, len(intercept_change_points) + i] = np.maximum(x - break_point, 0)

        return best_model.predict(predictors)

    fit_result = KinkedFitResult(
        df=df,
        p=p,
        bic=best_bic,
        rss=best_rss,
        mse=best_mse,
        break_points=best_break_points,
        predict=predict,
        break_points_dt=[pd.Timestamp.fromordinal(bp) for bp in best_break_points],
        bics=bics,
        rsss=rsss,
        mses=mses,
        oom_year_slopes=oom_year_slopes,
        break_points_list=break_points_list,
        break_points_dt_list=[[pd.Timestamp.fromordinal(bp) for bp in break_points] for break_points in break_points_list],
    )

    return fit_result

In [63]:
fit_em_all = lambda df_fit : {
    "Simple" : fit_n_phase_exponential(df_fit, kink_count=0, min_n_segment=3),
    "One kink" : fit_n_phase_exponential(df_fit, kink_count=1, min_n_segment=3),
    #"Discontinuity" : fit_n_phase_exponential(df_fit, kink_count=1, allow_discontinuities=True, min_n_segment=5),
    # "Hyperbolic": fit_hyperbolic(df_fit),
}

# Best model fits
print(f"Fitting PLM and specialized models")
regression_data = {
   'All': {},
   'Specialized model': {},
   'Protein language model': {},
}
regression_data['Specialized model']['models'] = fit_em_all(df_filtered[df_filtered['category'] == 'Specialized model'])
regression_data['Protein language model']['models'] = fit_em_all(df_filtered[df_filtered['category'] == 'Protein language model'])
regression_data['All']['models'] = fit_em_all(df_filtered)

Fitting PLM and specialized models


In [64]:
# K-Fold Cross Validation
def perform_cross_validation(df, k=10, random_state=42):
    kf = KFold(n_splits=k, shuffle=True, random_state=random_state)
    folds_mses = defaultdict(lambda : [])
    for train_index, test_index in kf.split(df):
        train_df, test_df = df.iloc[train_index], df.iloc[test_index]

        # Fit the models on the training set
        fold_models = fit_em_all(train_df)

        # Predict on the test set
        for name,model in fold_models.items():
            try:
                predicted_log_y = model.predict(test_df["date"])
            except AttributeError:
                continue
            test_rss = np.sum((predicted_log_y - test_df[dep_var])**2)
            test_mse = test_rss / len(test_df)
            folds_mses[name].append(test_mse)

    # Compute mean MSE
    folds_mses = {name: np.mean(folds_mses[name]) for name in folds_mses}

    return folds_mses

if top_n > 1:
    regression_data['Specialized model']['folds_mses'] = perform_cross_validation(df_filtered[df_filtered['category'] == 'Specialized model'])
    regression_data['Protein language model']['folds_mses'] = perform_cross_validation(df_filtered[df_filtered['category'] == 'Protein language model'])
    regression_data['All']['folds_mses'] = perform_cross_validation(df_filtered)
else:
    regression_data['Specialized model']['folds_mses'] = {}
    regression_data['Protein language model']['folds_mses'] = {}
    regression_data['All']['folds_mses'] = {}

In [65]:
# Bootstrap
bootstrap_sample_size = 1000

pred_start_date = df_filtered['date'].min()
pred_end_date = df_filtered['date'].max()

regression_data['Specialized model']['bootstrap_predictions'] = defaultdict(lambda : [])
regression_data['Protein language model']['bootstrap_predictions'] = defaultdict(lambda : [])
regression_data['All']['bootstrap_predictions'] = defaultdict(lambda : [])

regression_data['Specialized model']['bootstrap_bics'] = defaultdict(lambda : [])
regression_data['Protein language model']['bootstrap_bics'] = defaultdict(lambda : [])
regression_data['All']['bootstrap_bics'] = defaultdict(lambda : [])

regression_data['Specialized model']['bootstrap_mses'] = defaultdict(lambda : [])
regression_data['Protein language model']['bootstrap_mses'] = defaultdict(lambda : [])
regression_data['All']['bootstrap_mses'] = defaultdict(lambda : [])

regression_data['Specialized model']['bootstrap_bic_score_diff'] = defaultdict(lambda : [])
regression_data['Protein language model']['bootstrap_bic_score_diff'] = defaultdict(lambda : [])
regression_data['All']['bootstrap_bic_score_diff'] = defaultdict(lambda : [])

regression_data['Specialized model']['bootstrap_slopes'] = defaultdict(lambda : [])
regression_data['Protein language model']['bootstrap_slopes'] = defaultdict(lambda : [])
regression_data['All']['bootstrap_slopes'] = defaultdict(lambda : [])

regression_data['Specialized model']['bootstrap_breaks'] = defaultdict(lambda : [])
regression_data['Protein language model']['bootstrap_breaks'] = defaultdict(lambda : [])
regression_data['All']['bootstrap_breaks'] = defaultdict(lambda : [])

rng = np.random.default_rng(20250103)

from joblib import Parallel, delayed
from tqdm.notebook import tqdm

def bootstrap_iteration(bootstrap_index, category, df_filtered, pred_start_date, pred_end_date, rng_seed):
    if bootstrap_index == 0:
        # Use the original data as the first bootstrap sample
        sample = df_filtered.copy()
    else:
        sample = df_filtered.sample(len(df_filtered), replace=True, random_state=rng_seed)

    if category != 'All':
        sample = sample[sample['category'] == category]
    sample = sample.sort_values('date')

    # Compute BICs
    boot_models = fit_em_all(sample)
    if any(model is None for model in boot_models.values()):
        return None

    # Compute K fold validation
    if top_n > 1:
        boot_folds_mses = perform_cross_validation(sample)
    else:
        boot_folds_mses = {}

    # Initialize local storage
    local_bics = {}
    local_mses = {}
    local_bic_diff = {}
    local_slopes = {}
    local_breaks = {}
    local_predictions = {}

    # Store results
    for name, model in boot_models.items():
        local_bics[name] = model.bic
        local_mses[name] = boot_folds_mses.get(name, np.nan)
        local_bic_diff[name] = model.bic - boot_models.get("Simple", model.bic).bic

        if isinstance(model, KinkedFitResult):
            if len(model.oom_year_slopes) > 0:
                local_slopes[name] = 10**model.oom_year_slopes[-1]
            if len(model.break_points_dt) > 0:
                local_breaks[name] = model.break_points_dt[-1]

    # Store predictions for confidence intervals
    for name, model in boot_models.items():
        try:
            date_grid = pd.date_range(start=pred_start_date, end=pred_end_date, freq='D')
            pred = model.predict(pd.Series(date_grid))
            local_predictions[name] = pred
        except AttributeError:
            continue

    return (local_bics, local_mses, local_bic_diff, local_slopes, local_breaks, local_predictions)


def bootstrap_with_retry(bootstrap_index, category, df_filtered, pred_start_date, pred_end_date, max_retries=1):
    rng = np.random.default_rng(bootstrap_index)  # Deterministic seed per worker

    for retry in range(max_retries):
        try:
            result = bootstrap_iteration(
                bootstrap_index,
                category,
                df_filtered,
                pred_start_date,
                pred_end_date,
                rng.integers(0, 1e9)
            )
            if result is not None:
                return {
                    'success': True,
                    'result': result,
                    'retries': retry
                }
        except Exception as e:
            if retry == max_retries - 1:
                return {
                    'success': False,
                    'error': str(e),
                    'retries': retry + 1
                }
            continue

    return {
        'success': False,
        'error': 'Max retries exceeded',
        'retries': max_retries
    }


for category in ['All', 'Specialized model', 'Protein language model']:
    print(f"Bootstrapping {category} data")

    # Run parallel bootstrap with retries
    bootstrap_results = Parallel(n_jobs=-1)(
        delayed(bootstrap_with_retry)(
            i,
            category,
            df_filtered,
            pred_start_date,
            pred_end_date
        )
        for i in range(bootstrap_sample_size)
    )

    # Analyze results and retry statistics
    successful_results = [r['result'] for r in bootstrap_results if r['success']]
    total_retries = sum(r['retries'] for r in bootstrap_results)
    failed_bootstraps = sum(1 for r in bootstrap_results if not r['success'])

    print(f"Bootstrap statistics for {category}:")
    print(f"- Success rate: {(len(successful_results)/bootstrap_sample_size):.1%}")
    print(f"- Average retries: {total_retries/bootstrap_sample_size}")
    print(f"- Failed bootstraps: {failed_bootstraps}")

    # Process successful results
    for res in successful_results:
        local_bics, local_mses, local_bic_diff, local_slopes, local_breaks, local_predictions = res

        # Update storage as before
        for name, bic in local_bics.items():
            regression_data[category]['bootstrap_bics'][name].append(bic)
        # Update MSEs
        for name, mse in local_mses.items():
            regression_data[category]['bootstrap_mses'][name].append(mse)

        # Update BIC score differences
        for name, diff in local_bic_diff.items():
            regression_data[category]['bootstrap_bic_score_diff'][name].append(diff)

        # Update slopes
        for name, slope in local_slopes.items():
            regression_data[category]['bootstrap_slopes'][name].append(slope)

        # Update break points
        for name, break_pt in local_breaks.items():
            regression_data[category]['bootstrap_breaks'][name].append(break_pt)

        # Update predictions
        for name, pred in local_predictions.items():
            regression_data[category]['bootstrap_predictions'][name].append(pred)


Bootstrapping All data
Bootstrap statistics for All:
- Success rate: 100.0%
- Average retries: 0.0
- Failed bootstraps: 0
Bootstrapping Specialized model data
Bootstrap statistics for Specialized model:
- Success rate: 99.5%
- Average retries: 0.005
- Failed bootstraps: 5
Bootstrapping Protein language model data
Bootstrap statistics for Protein language model:
- Success rate: 99.1%
- Average retries: 0.009
- Failed bootstraps: 9


In [66]:
ci_width = 0.90
qs = [(1 - ci_width)/2, (1 + ci_width)/2]
bootstrap_preferred_percent = {}
bootstrap_summary_data = {
    'Specialized model': defaultdict(lambda: {}),
    'Protein language model': defaultdict(lambda: {}),
    'All': defaultdict(lambda: {})
}

for category in ['All', 'Specialized model', 'Protein language model']:
    for name in regression_data[category]['models']:
        bootstrap_summary_data[category]['bootstrap_preferred_percent'][name] = np.mean(np.array(regression_data[category]['bootstrap_bic_score_diff'][name])<0)
        bootstrap_summary_data[category]['bootstrap_bics'][name] = np.quantile(np.array(regression_data[category]['bootstrap_bics'][name]), qs)
        bootstrap_summary_data[category]['bootstrap_mses'][name] = np.quantile(np.array(regression_data[category]['bootstrap_mses'][name]), qs)
        bootstrap_summary_data[category]['bootstrap_bic_score_diff'][name] = np.quantile(np.array(regression_data[category]['bootstrap_bic_score_diff'][name]), qs)
        try:
            bootstrap_summary_data[category]['bootstrap_slopes'][name] = np.quantile(np.array(regression_data[category]['bootstrap_slopes'][name]), qs)
            bootstrap_summary_data[category]['bootstrap_breaks'][name] = np.quantile(np.array(regression_data[category]['bootstrap_breaks'][name]), qs)
        except IndexError:
            pass

# Models with lower BIC score / MSE are preferred.

results = {
    'Specialized model': [],
    'Protein language model': [],
    'All': []
}

for category in ['All', 'Specialized model', 'Protein language model']:
    for name, model in regression_data[category]['models'].items():
        param_count = model.p
        log_likelihood = (np.log(len(df_filtered))*param_count - model.bic)/2

        param_count_simple = regression_data[category]['models']['Simple'].p
        log_likelihood_simple = (np.log(len(df_filtered))*param_count_simple - regression_data[category]['models']['Simple'].bic)/2

        c2 = chi2.sf(2*(log_likelihood - log_likelihood_simple), df=(param_count - param_count_simple))

        result = {
            "Model": name,
            "BIC" : np.round(model.bic, 2),
            "BIC 90% CI" : np.round(bootstrap_summary_data[category]['bootstrap_bics'][name], 2),
            #"Parameter count": param_count,
            #"Log likelihood": np.round((np.log(len(df_filtered))*param_count - model.bic)/2),
            # "MSE" : model.mse,
            "BIC score diff": np.round(model.bic - regression_data[category]['models']['Simple'].bic, 2),
            "BIC score diff 90% CI": np.round(bootstrap_summary_data[category]['bootstrap_bic_score_diff'][name], 2),
            "Xi²": c2,
            "% times preferred over simple": f"{bootstrap_summary_data[category]['bootstrap_preferred_percent'][name]:.0%}",
            # "bayes factor over simple" : np.exp(-0.5 * (model.bic - models["simple"].bic)),
            "K-fold mean MSE" : np.round(regression_data[category]['folds_mses'].get(name, np.nan), 2),
            "K-fold mean MSE 90% CI" : np.round(bootstrap_summary_data[category]['bootstrap_mses'][name], 2),
        }

        try:
            result["Recent slope (Nx/year)"] = np.round(10**model.oom_year_slopes[-1], 2)
            result["Recent slope 90% CI"] = np.round(bootstrap_summary_data[category]['bootstrap_slopes'][name], 2)
            result["Break point"] = model.break_points_dt[-1].strftime('%Y-%m')
            result["Break point 90% CI"] = [date.strftime('%Y-%m') for date in bootstrap_summary_data[category]['bootstrap_breaks'][name]]
        except (AttributeError, IndexError):
            pass
        results[category].append(result)

results = {category: pd.DataFrame(results[category]) for category in results.keys()}

print("Bootstrapped regression results")
for category in ['All', 'Specialized model', 'Protein language model']:
    print(category)
    display(results[category])

Bootstrapped regression results
All


Unnamed: 0,Model,BIC,BIC 90% CI,BIC score diff,BIC score diff 90% CI,Xi²,% times preferred over simple,K-fold mean MSE,K-fold mean MSE 90% CI,Recent slope (Nx/year),Recent slope 90% CI,Break point,Break point 90% CI
0,Simple,97.47,"[83.09, 105.42]",0.0,"[0.0, 0.0]",,0%,0.89,"[0.59, 1.17]",8.78,"[6.21, 11.97]",,
1,One kink,98.56,"[-101.37, 101.98]",1.09,"[-197.28, 3.75]",0.024391,80%,0.83,"[0.51, 1.61]",3.63,"[0.95, 9.77]",2021-05,"[2018-03, 2024-07]"


Specialized model


Unnamed: 0,Model,BIC,BIC 90% CI,BIC score diff,BIC score diff 90% CI,Xi²,% times preferred over simple,K-fold mean MSE,K-fold mean MSE 90% CI,Recent slope (Nx/year),Recent slope 90% CI,Break point,Break point 90% CI
0,Simple,39.74,"[23.97, 50.9]",0.0,"[0.0, 0.0]",,0%,0.59,"[0.23, 0.76]",7.86,"[5.35, 10.41]",,
1,One kink,37.65,"[-165.72, 43.75]",-2.09,"[-208.83, 2.15]",0.005636,87%,0.41,"[0.16, 1.24]",2.24,"[0.28, 8.81]",2022-05,"[2018-03, 2023-11]"


Protein language model


Unnamed: 0,Model,BIC,BIC 90% CI,BIC score diff,BIC score diff 90% CI,Xi²,% times preferred over simple,K-fold mean MSE,K-fold mean MSE 90% CI,Recent slope (Nx/year),Recent slope 90% CI,Break point,Break point 90% CI
0,Simple,47.08,"[23.8, 61.52]",0.0,"[0.0, 0.0]",,0%,0.69,"[0.27, 1.65]",9.31,"[5.76, 20.67]",,
1,One kink,35.51,"[-145.42, 38.97]",-11.57,"[-190.95, -6.53]",6.3e-05,99%,0.79,"[0.1, 1.93]",3.7,"[0.09, 13.25]",2021-05,"[2019-06, 2024-02]"


In [None]:
# Find the best model for each category
simplicity_order = ['Simple', 'One kink', 'Discontinuity']
selected_model = {}
for category in ['All', 'Specialized model', 'Protein language model']:
    df = results[category]
    argmin_bic = df['BIC'].argmin()
    min_bic = df['BIC'].iloc[argmin_bic]
    min_bic_model = df['Model'].iloc[argmin_bic]
    min_bic_mse = df['K-fold mean MSE'].iloc[argmin_bic]
    # Iterate over the models in simplicity order
    # If a simpler model is not preferred, the selected model is the one with the lowest BIC
    for model in simplicity_order:
        # Check if the BICs are close
        if df[df['Model'] == model]['BIC'].iloc[0] - min_bic < 2:
            # Check if the MSEs are close
            if np.isnan(min_bic_mse):
                selected_model[category] = model
                break
            elif (df[df['Model'] == model]['K-fold mean MSE'].iloc[0] - min_bic_mse < 0.01):
                selected_model[category] = model
                break
    print(f"Best model for {category}: {selected_model[category]}")

Best model for All: Simple
Best model for Specialized model: One kink
Best model for Protein language model: One kink


# Plot predictions with bootstrapped CIs

In [None]:
def calculate_confidence_intervals(bootstrap_preds, percentile=90):
    lower_percentile = (100 - percentile) / 2
    upper_percentile = 100 - lower_percentile
    ci = {}
    for model, preds in bootstrap_preds.items():
        preds_array = np.array(preds)  # Shape: (bootstrap_samples, n_dates)
        lower = np.percentile(preds_array, lower_percentile, axis=0)
        upper = np.percentile(preds_array, upper_percentile, axis=0)
        ci[model] = {'lower': lower, 'upper': upper}
    return ci

In [None]:
# Calculate 90% Confidence Intervals
confidence_intervals = calculate_confidence_intervals(regression_data['Specialized model']['bootstrap_predictions'], percentile=90)
confidence_intervals

{'Simple': {'lower': array([15.9383474 , 15.94106474, 15.94378208, ..., 23.43340343,
         23.43558355, 23.43776366]),
  'upper': array([17.44017292, 17.44218831, 17.4442037 , ..., 24.46448843,
         24.46727753, 24.47006664])},
 'One kink': {'lower': array([14.49349355, 14.49811997, 14.50274639, ..., 22.13162282,
         22.13187989, 22.13213695]),
  'upper': array([16.59319816, 16.5956484 , 16.59809864, ..., 23.89171912,
         23.89419454, 23.89666996])}}

In [None]:
# Graph of the different model fits using plotly

# Use the selected models
#model_types = selected_model
# Or custom:
model_types = {
    'Protein language model': 'One kink',
    'Specialized model': 'One kink',
}
# Parameters for each model selection
model_params = {
    'Simple': {
        'kink_count': 0,
        'allow_discontinuities': False,
        'min_n_segment': 5,
    },
    'One kink': {
        'kink_count': 1,
        'allow_discontinuities': False,
        'min_n_segment': 5,
    },
    'Discontinuity': {
        'kink_count': 1,
        'allow_discontinuities': True,
        'min_n_segment': 5,
    }
}

def plot_model(df, model_types, model_params):
    fig = go.Figure(layout_xaxis_range=[datetime.date(2019,1,1), datetime.date(2025,3,1)],
                    layout_yaxis_range=[16,25.5])

    # Plot the original data points
    df_plm = df[df['category'] == 'Protein language model']
    df_spec = df[df['category'] == 'Specialized model']

    fig.add_trace(go.Scatter(
        x=df_plm['date'], y=df_plm['log_flop'],
        mode='markers', name='Protein language model', text=df_plm['Model'],
        marker=dict(color=colors['Protein language model'], opacity=0.1, size=10)
    ))
    fig.add_trace(go.Scatter(
        x=df_spec['date'], y=df_spec['log_flop'],
        mode='markers', name='Specialized model', text=df_spec['Model'],
        marker=dict(color=colors['Specialized model'], opacity=0.1, size=10)
    ))

    # fig.add_trace(go.Scatter(
    #     x=df_spec['date'], y=df_spec['log_flop'],
    #     mode='markers', name='Biological models', text=df_spec['Model'],
    #     marker=dict(color=colors['All'], opacity=0.1, size=10)
    # ))

    # Annotate China models (with invisible markers)
    # china_model_annotations = []#['Doubao-pro', 'ERNIE 3.0 Titan']
    # non_china_model_annotations = []#['Grok-2', 'GPT-4', 'GPT-3 175B (davinci)']
    # for model in china_model_annotations:
    #     fig.add_trace(go.Scatter(
    #         x=df_spec[df_spec['Model'] == model]['date'], y=df_spec[df_spec['Model'] == model]['log_flop'],
    #         mode='text', name=model, text=model,
    #     ))
    # for model in non_china_model_annotations:
    #     if model == 'GPT-3 175B (davinci)':
    #         text = 'GPT-3'
    #     else:
    #         text = model
    #     fig.add_trace(go.Scatter(
    #         x=df_plm[df_plm['Model'] == model]['date'], y=df_plm[df_plm['Model'] == model]['log_flop'],
    #         mode='text', name=model, text=text,
    #     ))

    # Show the export controls date
    # Convert 2022-10-07 to seconds since epoch
    # export_controls_date = pd.Timestamp('2022-10-07').value / 1e6
    # fig.add_vline(x=export_controls_date, line_color='black', line_width=1, line_dash='dot',
    #     annotation_text='October 2022<br>Export controls introduced', annotation_position='bottom right')

    date_grid = pd.date_range(start=df['date'].min(), end=df['date'].max(), freq='D')

    trend_dfs = {}
    ci_dfs = {}
    fit_results = {}
    for category, model_type in model_types.items():
        if category == 'All':
             continue
        # df.category = 'All'

        params = model_params[model_type]
        fit_result = fit_n_phase_exponential(df[df['category'] == category], **params)
        fit_results[category] = fit_result

        # Get the month of the first point, then use that to index the date_grid
        start_month = df[df['category'] == category]['date'].min().month + 1
        start_index = np.where(date_grid.month == start_month)[0][0]
        log_flop = fit_result.predict(pd.Series(date_grid))
        # To plot the bootstrapped mean prediction instead:
        # log_flop = np.mean(regression_data[category]['bootstrap_predictions'][model_type], axis=0)

        trend_dfs[category] = pd.DataFrame({
            'date': date_grid[start_index:],
            'log_flop': log_flop[start_index:],
        })

        # Get the confidence intervals
        ci_data = calculate_confidence_intervals(regression_data[category]['bootstrap_predictions'], percentile=90)
        ci_dfs[category] = pd.DataFrame({
            'date': date_grid[start_index:],
            'lower': ci_data[model_type]['lower'][start_index:],
            'upper': ci_data[model_type]['upper'][start_index:],
        })

        # Plot the best fit line with confidence intervals
        fig.add_trace(go.Scatter(
            x=date_grid[start_index:], y=log_flop[start_index:],
            mode='lines', name=f'{category} best fit line',
            line=dict(color=colors[category], width=1),
            showlegend=False,
        ))
        fig.add_trace(go.Scatter(
            x=date_grid[start_index:],
            y=ci_data[model_type]['lower'][start_index:],
            mode='lines',
            line=dict(color=colors[category], width=0),
            showlegend=False,
        ))
        fig.add_trace(go.Scatter(
            x=date_grid[start_index:],
            y=ci_data[model_type]['upper'][start_index:],
            mode='lines',
            fill='tonexty',
            fillcolor='rgba(255,0,0,0.1)' if category == 'Specialized model' else 'rgba(0,0,255,0.1)',
            line=dict(color=colors[category], width=0),
            name=f'{category} 90% CI',
            showlegend=False,
        ))

    # Add slope labels
    for category in ['Specialized model', 'Protein language model']:
        category_df = df[df['category'] == category]
        points = [category_df['date'].min()] + fit_results[category].break_points_dt + [category_df['date'].max()]
        model_type = model_types[category]
        best_slope = 10**regression_data[category]['models'][model_type].oom_year_slopes[-1]
        slopes = bootstrap_summary_data[category]['bootstrap_slopes'][model_type]
        slope_label = f'{best_slope:.1f}x/year<br>90% CI: {slopes[0]:.1f}-{slopes[1]:.1f}x/year'
        for i in range(len(points) - 2, len(points) - 1):
            mid = points[i] + (points[i+1] - points[i]) / 2
            if category == 'Specialized model':
                mid += pd.Timedelta(days=150)
            y = fit_results[category].predict(pd.Series([mid]))[0]
            fig.add_annotation(
                x=mid, y=y + 1.6 * (1 if category == 'Protein language model' else -1),
                text=slope_label,
                showarrow=False,
                font=dict(size=12, color=colors[category])
            )

    # Update layout
    title = f'Compute trends for the largest protein language models and specialized biological models'
    fig.update_layout(
        template='plotly_white',
        width=800,
        height=400,
        title=title,
        xaxis_title='Publication date',
        yaxis_title='Training compute (FLOP)',
        legend_title='',
        margin=dict(l=10, r=10, t=40, b=10),
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(int(df['log_flop'].min()), int(df['log_flop'].max())+2, 2)),
            ticktext=[f'10<sup>{i}</sup>' for i in range(int(df['log_flop'].min()), int(df['log_flop'].max())+2, 2)]
        )
    )

    if save:
        fname = f'compute_regression_spec={model_types["Specialized model"]}_plm={model_types["Protein language model"]}_top{top_n}_cutoff={cutoff_date}'
        save_plot(fig, results_dir, fname)

        slope_df = pd.DataFrame({
            'Category': ['Protein language model', 'Specialized model'],
            'Best fit slope': [10**regression_data[category]['models'][model_types[category]].oom_year_slopes[-1] for category in ['Protein language model', 'Specialized model']],
            '90% CI lower': [bootstrap_summary_data[category]['bootstrap_slopes'][model_types[category]][0] for category in ['Protein language model', 'Specialized model']],
            '90% CI upper': [bootstrap_summary_data[category]['bootstrap_slopes'][model_types[category]][1] for category in ['Protein language model', 'Specialized model']],
        })
        slope_df.to_csv(results_dir + f'plot_data/recent_slopes_{fname}.csv', index=False)

        df_plm[['Model', 'category', 'date', 'log_flop']].to_csv(results_dir + f'plot_data/plm_scatter_{fname}.csv', index=False)
        df_spec[['Model', 'category', 'date', 'log_flop']].to_csv(results_dir + f'plot_data/spec_scatter_{fname}.csv', index=False)
        trend_dfs['Protein language model'][['date', 'log_flop']].to_csv(results_dir + f'plot_data/plm_best_fit_line_{fname}.csv', index=False)
        trend_dfs['Specialized model'][['date', 'log_flop']].to_csv(results_dir + f'plot_data/spec_best_fit_line_{fname}.csv', index=False)
        ci_dfs['Protein language model'][['date', 'lower', 'upper']].to_csv(results_dir + f'plot_data/plm_ci_{fname}.csv', index=False)
        ci_dfs['Specialized model'][['date', 'lower', 'upper']].to_csv(results_dir + f'plot_data/spec_ci_{fname}.csv', index=False)

    fig.show()

    return fit_results

fit_results = plot_model(df_filtered, model_types, model_params)

In [None]:
df_filtered

Unnamed: 0,Model,flop,date,category,log_flop
28,SPIDER2,1.822e+16,2016-10-28,Specialized model,16.260548
0,DNCON2,9.5e+16,2018-05-01,Specialized model,16.977724
129,SSA,3.2e+18,2019-02-22,Specialized model,18.50515
57,UniRep,2.2e+19,2019-03-26,Protein language model,19.342423
7,TAPE Transformer,3e+19,2019-06-19,Protein language model,19.477121
48,UDSMProt,6.37e+17,2019-09-04,Protein language model,17.804139
23,SeqVec,4.1e+19,2019-12-17,Specialized model,19.612784
54,AlphaFold,1e+20,2020-01-15,Specialized model,20.0
34,ProGen,3.7e+20,2020-03-13,Protein language model,20.568202
24,ProBERTa,9.72e+18,2020-09-01,Specialized model,18.987666
