# Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
import json
from itertools import combinations_with_replacement
import numpy as np
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy.stats import chi2
from sklearn.model_selection import KFold
import statsmodels.api as sm
from tqdm import tqdm

from data import *
from plotting import *
from regression import *
from utils import *

# Parameters

In [3]:
assign_country_method = 'inclusive'  # ['inclusive', 'exclusive']. Default: 'inclusive'.
# 'external': Filter to the top n models overall
# 'internal': Filter to the top n models within 'Non-China' and 'China' categories
# 'disabled': No filtering
frontier_selection = 'external'  # ['disabled', 'internal', 'external']. Default: 'external'.
top_n = 10  # Filter to the top n models by training compute at time of release. Default: 10.
model_selection = 'Language models'  # ['All models', 'Language models', 'Google DeepMind models', 'OpenAI models', 'Meta AI models']. Default: 'Language models'.
filter_alphago_outliers = True  # Whether to filter out AlphaGo Master and AlphaGo Zero. Default: True.
filter_finetuned_models = True  # Whether to filter out separate finetuned models (base + finetuned models are still included if there is no separate base model in our dataset). Default: True.
include_speculative_compute = True  # Whether to include speculative compute estimates that rely on benchmark imputation and rough guesses. Default: True.
cutoff_date = '2018-01-01'  # When to start the regressions from. Default: '2018-01-01'.
top_n_cutoff_date = '1950-01-01'  # When to split the top-n filtering into Non-China and China categories - set to e.g. 2010 to turn off the "kickstarting". Default: '1950-01-01'.
save = True  # Whether to save the plots. Default: True.

In [4]:
# Default: no models excluded
exclude_models = []

# Early China models that are not representative of current trends
# exclude_models = [
#     'genCNN + dyn eval',
#     'R-FCN',
#     'ResNet-200',
#     '2-layer-LSTM+Deep-Gradient-Compression',
# ]

# Key China models around the breakpoint
# exclude_models = [
#     'ERNIE 3.0 Titan',
#     'Yuan 1.0',
# ]

# Largest China model
# exclude_models = [
#     'GLM-4 (0116)',
# ]

# All large BlueLMs - it's not clear they were ever released
# exclude_models = [
#     'BlueLM 70B',
#     'BlueLM 130B',
#     'BlueLM 175B',
# ]

In [5]:
results_dir = 'results/compute/20250107_safe_bootstrap/'
os.makedirs(results_dir, exist_ok=True)
os.makedirs(results_dir + 'plot_data', exist_ok=True)

In [6]:
colors = {'Non-China': 'blue', 'China': 'red'}


# Data preparation

In [7]:
# Load data
pcd_df = load_pcd_df()

In [8]:
pcd_df

Unnamed: 0,Model,Domain,Task,Authors,Notability criteria,Notability criteria notes,Model accessibility,Link,Citations,Reference,...,Hardware type,Training compute estimation method,Biological model safeguards,Hardware utilization (temp),BenchmarkHub-v1,Hugging Face developer id,Post-training compute (FLOP),Post-training compute notes,Hardware maker,benchmarks/models
0,babbage-002,Language,Language modelling,,,,,,,,...,,,,,,,,,,
1,tts-1,Speech,Text-to-speech,,,,,,,,...,,,,,,,,,,
2,tts-1-hd,Speech,Text-to-speech,,,,,,,,...,,,,,,,,,,
3,LM-Design,Biology,Protein design,"Zaixiang Zheng, Yifan Deng, Dongyu Xue, Yi Zho...",,,,https://proceedings.mlr.press/v202/zheng23a.html,46.0,Structure-informed Language Models Are Protein...,...,,,LM-Design,,,,,,,
4,Deep-LDA,,,,,,,,,Optimization and redevelopment of single-cell ...,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2326,F5-TTS,Speech,"Speech synthesis,Translation","Yushen Chen, Zhikang Niu, Ziyang Ma, Keqi Deng...",,,Open weights (non-commercial),https://arxiv.org/abs/2410.06885,,F5-TTS: A Fairytaler that Fakes Fluent and Fai...,...,,Hardware,,,,SWivid,,,NVIDIA,
2327,Veo 2,"Video,Vision",Video generation,,SOTA improvement,"""Veo has achieved state of the art results in ...",API access,https://deepmind.google/technologies/veo/veo-2/,,Our state-of-the-art video generation model,...,,,,,,,,,,
2328,DeepSeek-V3,Language,"Language modelling/generation,Code generation,...",,,,Open weights (restricted use),https://github.com/deepseek-ai/DeepSeek-V3/blo...,,DeepSeek-V3 Technical Report,...,,Operation counting,,,,deepseek-ai,,,NVIDIA,
2329,OLMo 2 Furious 7B,Language,"Language modelling/generation,Question answering","Team OLMo, Pete Walsh, Luca Soldaini, Dirk Gro...",,,Open weights (unrestricted),https://arxiv.org/abs/2501.00656,,2 OLMo 2 Furious,...,,"Reported,Operation counting",,,,allenai,,,NVIDIA,


In [9]:
pcd_df = pcd_df[~pcd_df['Model'].isin(exclude_models)]

In [10]:
print(pcd_df.loc[pcd_df['Model'] == 'Megatron-BERT']['Country (from Organization)'])
print(pcd_df.loc[pcd_df['Model'] == 'Yi-34B']['Country (from Organization)'])


753    United States of America
Name: Country (from Organization), dtype: object
1641    China
Name: Country (from Organization), dtype: object


In [11]:
country_df = pcd_df.dropna(subset=['Publication date', 'Country (from Organization)'])
len(country_df)

2202

In [12]:
country_df['Country (from Organization)'].unique()


array(['United States of America',
       'United States of America,United States of America', 'Italy',
       'New Zealand',
       'United Kingdom of Great Britain and Northern Ireland',
       'Switzerland', 'Japan', 'Multinational', 'Netherlands', 'Finland',
       'Canada', 'Japan,United States of America', 'Spain',
       'Denmark,United Kingdom of Great Britain and Northern Ireland',
       'India', 'Germany', 'France',
       'United Kingdom of Great Britain and Northern Ireland,United States of America',
       'Taiwan',
       'United States of America,United States of America,United States of America',
       'United Kingdom of Great Britain and Northern Ireland,Canada',
       'United States of America,Germany', 'Korea (Republic of)',
       'United States of America,United Kingdom of Great Britain and Northern Ireland',
       'Mexico', 'Switzerland,Germany',
       'United States of America,United Kingdom of Great Britain and Northern Ireland,United States of America',
  

In [13]:
country_df[country_df['Country (from Organization)'].str.contains('China')][['Model', 'Country (from Organization)']]

Unnamed: 0,Model,Country (from Organization)
324,AdaRNN,China
328,SPPNet,"United States of America,China,China"
334,ACF-WIDER,China
348,Cascaded LNet-ANet,"Hong Kong,China"
360,CRF-RNN,United Kingdom of Great Britain and Northern I...
...,...,...
2311,QwQ,China
2318,Hunyuan Video,"Multinational,China"
2320,Hailuo I2V-01-Live,"China,Singapore"
2326,F5-TTS,"China,United Kingdom of Great Britain and Nort..."


In [14]:
country_df[~country_df['Country (from Organization)'].str.contains('China')]

Unnamed: 0,Model,Domain,Task,Authors,Notability criteria,Notability criteria notes,Model accessibility,Link,Citations,Reference,...,Hardware type,Training compute estimation method,Biological model safeguards,Hardware utilization (temp),BenchmarkHub-v1,Hugging Face developer id,Post-training compute (FLOP),Post-training compute notes,Hardware maker,benchmarks/models
21,Theseus,Robotics,Maze solving,Claude Shannon,Historical significance,,,https://www.technologyreview.com/2018/12/19/13...,0.0,Mighty Mouse,...,,,,,,,,,,
22,SNARC,Robotics,Maze solving,Marvin Minsky,Historical significance,,,https://en.wikipedia.org/wiki/Stochastic_neura...,33.0,A Neural-Analogue Calculator Based upon a Prob...,...,,,,,,,,,,
23,Genetic algorithm,Mathematics,Numerical simulation,NA Barricelli,Historical significance,Possibly first computer simulation of a geneti...,,https://link.springer.com/article/10.1007/BF01...,266.0,Numerical testing of evolution theories,...,,,,,,,,,,
24,Sequence-based pattern recognition,Vision,Character recognition,O. G. Selfridge,Historical significance,,,https://dl.acm.org/doi/10.1145/1455292.1455310,290.0,Pattern recognition and modern computers,...,,,,,,,,,,
25,Self Organizing System,Other,Pattern recognition,W. A. Clark and B. G. Farley,Historical significance,,,https://dl.acm.org/doi/10.1145/1455292.1455309,93.0,Generalization of pattern recognition in a sel...,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2324,Gemini 2.0 Flash,"Language,Vision,Audio,Speech,Video,Multimodal","Language modelling/generation,Question answeri...",,,,API access,https://blog.google/technology/google-deepmind...,,Introducing Gemini 2.0: our new AI model for t...,...,,,,,,,,,Google,
2325,Phi-4,Language,"Language modelling/generation,Question answeri...","Marah Abdin, Jyoti Aneja, Harkirat Behl, Sébas...",,,Open weights (non-commercial),https://arxiv.org/abs/2412.08905,,Phi-4 Technical Report,...,,Operation counting,,,,,,,,
2327,Veo 2,"Video,Vision",Video generation,,SOTA improvement,"""Veo has achieved state of the art results in ...",API access,https://deepmind.google/technologies/veo/veo-2/,,Our state-of-the-art video generation model,...,,,,,,,,,,
2329,OLMo 2 Furious 7B,Language,"Language modelling/generation,Question answering","Team OLMo, Pete Walsh, Luca Soldaini, Dirk Gro...",,,Open weights (unrestricted),https://arxiv.org/abs/2501.00656,,2 OLMo 2 Furious,...,,"Reported,Operation counting",,,,allenai,,,NVIDIA,


Check if the country is listed.

TODO: try other methods of reducing multiple countries to one country.
- Use the first country listed
- Mutually exclusive (e.g. China but NOT Non-China)

In [15]:
# Excluding Taiwan, since it has a different status w.r.t. the US and export controls
china_countries = ['China', 'Hong Kong']

# Including China
def assign_country_inclusively(row):
    if any([country in china_countries for country in row['Country (from Organization)'].split(',')]):
        return 'China'
    else:
        return 'Non-China'

# Exclusively China
def assign_country_exclusively(row):
    countries = row['Country (from Organization)'].split(',')
    # Allow multinational if no non-China country is also listed
    # This applies to Tencent, for example
    if all([(country in china_countries) or (country == 'Multinational') for country in countries]):
        return 'China'
    elif any([country in china_countries for country in countries]):
        # Not exclusively China - discard because it's an ambiguous case
        return np.nan
    else:
        # Exclusively Non-China
        return 'Non-China'

# First country listed
assign_country = assign_country_inclusively if assign_country_method == 'inclusive' else assign_country_exclusively
country_df.loc[:, 'Country'] = country_df.apply(assign_country, axis=1)

display(country_df[country_df['Country'] == 'China'][['Model', 'Country']])
display(country_df[country_df['Country'] == 'Non-China'][['Model', 'Country']])
display(country_df[country_df['Country'].isna()][['Model', 'Country']])


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  country_df.loc[:, 'Country'] = country_df.apply(assign_country, axis=1)


Unnamed: 0,Model,Country
324,AdaRNN,China
328,SPPNet,China
334,ACF-WIDER,China
348,Cascaded LNet-ANet,China
360,CRF-RNN,China
...,...,...
2311,QwQ,China
2318,Hunyuan Video,China
2320,Hailuo I2V-01-Live,China
2326,F5-TTS,China


Unnamed: 0,Model,Country
21,Theseus,Non-China
22,SNARC,Non-China
23,Genetic algorithm,Non-China
24,Sequence-based pattern recognition,Non-China
25,Self Organizing System,Non-China
...,...,...
2324,Gemini 2.0 Flash,Non-China
2325,Phi-4,Non-China
2327,Veo 2,Non-China
2329,OLMo 2 Furious 7B,Non-China


Unnamed: 0,Model,Country


In [16]:
for cat in country_df['Country'].unique():
    if pd.isna(cat):
        print(cat, len(country_df.loc[country_df['Country'].isna()]))
    else:
        print(cat, len(country_df.loc[country_df['Country'] == cat]))

Non-China 1732
China 470


In [17]:
df = country_df

In [18]:
def find_top_models_up_to_release(df, top_n):
    """Find the models which were in the top n by compute when they were released."""
    # This set will keep track of models that were ever in the top 10 at their release
    ever_in_top_n = set()

    # Iterate over each date in the DataFrame
    for current_date in df['date'].unique():
        # Get all entries up to the current date
        historical_data = df[df['date'] <= current_date]
        # Find top 10 models by flop count in this subset
        top_n_models = historical_data.nlargest(top_n, 'flop')['Model']
        # Update the set of models that were ever in top n
        ever_in_top_n.update(top_n_models)

    # Return DataFrame filtered to only include models that were ever in the top 10
    return df[df['Model'].isin(ever_in_top_n)]


def filter_top_models_within_category(df, top_n, cutoff_date, category):
    """Find the models which were in the top-n by compute when they were released,
    among models in the specified category. The top-n models in the specified category
    are seeded with the overall top-n models before the cutoff date.
    """
    # Filter top-n models within the category, but seeded with overall top-n models
    top_models_df = find_top_models_up_to_release(df, top_n)
    top_n_models_at_cutoff_date_df = top_models_df[top_models_df['date'] <= cutoff_date].nlargest(top_n, 'flop')
    category_df = df[df['category'] == category]

    # This set will keep track of models that were ever in the top 10 at their release
    ever_in_top_n = set()

    # Iterate over each date in the DataFrame
    for current_date in category_df['date'].unique():
        # Get all entries up to the current date
        category_since_cutoff = category_df[(category_df['date'] <= current_date) & (category_df['date'] > cutoff_date)]
        historical_data = pd.concat([category_since_cutoff, top_n_models_at_cutoff_date_df])
        # Find top 10 models by flop count in this subset
        top_n_models_df = historical_data.nlargest(top_n, 'flop')
        # Update the set of models that were ever in top n
        # Filter out the models that aren't in the category
        ever_in_top_n.update(top_n_models_df[top_n_models_df['category'].str.contains(category)]['Model'])

    # Return DataFrame filtered to only include models that were ever in the top 10
    new_df = df[df['Model'].isin(ever_in_top_n)]
    # Assign the category to the new DataFrame (overwrites cases with both US and China)
    # E.g. if a "USA,China" model is top-10 among models affiliated with China, then it's just "China"
    new_df['category'] = category
    
    return new_df


def filter_top_models_in_both_categories(df, top_n, cutoff_date):
    # Get top models for Open and Closed categories
    top_us_models = filter_top_models_within_category(df, top_n, cutoff_date, category='Non-China')
    top_china_models = filter_top_models_within_category(df, top_n, cutoff_date, category='China')
    # Combine the results
    df_filtered = pd.concat([top_us_models, top_china_models])
    # Sort the combined DataFrame by date
    df_filtered = df_filtered.sort_values('date')
    return df_filtered

In [19]:
df_filtered = (df[['Model', 'Training compute (FLOP)', 'Publication date', 'Organization', 'Notability criteria', 'Domain', 'Base model', 'Country', 'Training hardware']]
    .rename(columns={'Training compute (FLOP)': 'flop', 'Publication date': 'date', 'Country': 'category'})
    .assign(date=lambda x: pd.to_datetime(x['date']), log_flop=lambda x: np.log10(x['flop']))
    .sort_values('date'))

In [20]:
list(df_filtered[df_filtered['Base model'].notna()]['Model'])

['BatchNorm',
 'CompACT-Deep',
 'Template Adaptation\n',
 'LRR-4X',
 'CMS-RCNN',
 'Order embeddings with layer norm',
 'Layer Normalization: The Attentive Reader',
 'Layer Normalization: Skip Thoughts',
 'Layer Normalization: Draw',
 'Layer Normalization: Handwriting sequence generation',
 'DLDL',
 'HR-ResNet101',
 'ULM-FiT',
 'ADP-FAIRSEQ + NGRAMRES',
 'Fine-tuned-AWD-LSTM-DOC (fin)',
 'Cross-lingual alignment',
 'Theseus 6/768',
 'UnifiedQA',
 'LUKE',
 'GPT-Neo-2.7B (finetuned on PTB)',
 'GPT-Neo-2.7B (finetuned)',
 'Unicorn',
 'Multitask Unified Model (MUM)',
 'Codex',
 '$\\infty$-former (SM)',
 'FLAN 137B',
 'AlphaFold-Multimer',
 'T0-XXL',
 'GPT-2 (AMPS)',
 'Masked Autoencoders ViT-H',
 'ViT-G/14 (LiT)',
 'Engine-XL(NE)',
 'HSO',
 'Contriever',
 'Vespa',
 'OntoProtein',
 'InstructGPT 175B',
 'InstructGPT 6B',
 'InstructGPT 1.3B',
 'InstructGPT 350M',
 'BERT-RBP',
 'Flamingo',
 'Jurassic-X',
 'DeBERTaV3large + KEAR',
 'SimCSE',
 'CogVideo',
 'Minerva (540B)',
 'Delphi',
 'Transform

In [21]:
# Add speculative compute estimates based on benchmark imputation and rough guesses
if include_speculative_compute:
    speculative_compute_estimates = {
        "Claude 3.5 Sonnet": 4.72e25,
        "Claude 3 Opus": 1.59e25,
        "Claude 3 Sonnet": 5.51e24,
        "GPT-4o": 3.98e25,
        "Gemini 1.0 Pro": 1.85e24,
        "Gemini 1.5 Pro": 1.60e25,
        "Mistral Large 2": 2.01e25,
        "GPT-4 Turbo": 2.1e25,  # rough guess matching GPT-4
        "GPT-4V": 2.1e25,  # rough guess matching GPT-4
        "Claude 2": 4.33e24,
        "Claude 2.1": 4.33e24,  # rough guess matching Claude 2
    }
    for model, compute in speculative_compute_estimates.items():
        df_filtered.loc[df_filtered['Model'] == model, "flop"] = compute
        df_filtered.loc[df_filtered['Model'] == model, "log_flop"] = np.log10(compute)

df_filtered.dropna(subset=['flop'], inplace=True)

# Drop Alpha Go Master / Zero
if filter_alphago_outliers:
    mask = (df_filtered['Model'] == 'AlphaGo Master') | (df_filtered['Model'] == 'AlphaGo Zero')
    df_filtered = df_filtered[~mask]

# Drop finetuned models
if filter_finetuned_models:
    mask = df_filtered['Base model'].isna()
    df_filtered = df_filtered[mask]

top_models_df = find_top_models_up_to_release(df_filtered, top_n)  # For reference

if frontier_selection == 'external':
    # Filter top models before other filters
    df_filtered = filter_top_models_in_both_categories(df_filtered, top_n, top_n_cutoff_date)

if model_selection == 'Language models':
    re = 'Language|Multimodal'
    mask = df_filtered['Domain'].str.contains(re, na=False)
    df_filtered = df_filtered[mask]

if frontier_selection == 'internal':
    # Filter top models after other filters
    df_filtered = filter_top_models_in_both_categories(df_filtered, top_n, top_n_cutoff_date)

# Filter for models after the cutoff date
df_filtered = df_filtered[df_filtered['date'] > cutoff_date]

print(f"{len(df_filtered)}{' top' if frontier_selection != 'disabled' else ''} {top_n} {model_selection} models found")
print(f"They span {df_filtered['date'].min().strftime('%B %Y')} to {df_filtered['date'].max().strftime('%B %Y')}")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  new_df['category'] = category


116 top 10 Language models models found
They span August 2018 to December 2024


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  new_df['category'] = category


In [22]:
if top_n == 1:
    # Remove BIDAF outlier
    df_filtered = df_filtered[df_filtered['Model'] != 'BIDAF']

In [23]:
non_china_df = df_filtered[df_filtered['category'] == 'Non-China']
china_df = df_filtered[df_filtered['category'] == 'China']
recent_top_models_df = top_models_df[top_models_df['date'] > pd.to_datetime('2010-01-01')]

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=non_china_df['date'],
    y=non_china_df['log_flop'],
    mode='markers',
    marker=dict(color=colors['Non-China'], opacity=0.5),
    text=non_china_df['Model'],
    hoverinfo='text',
    name=f'Top-{top_n} Non-China'
))

fig.add_trace(go.Scatter(
    x=china_df['date'],
    y=china_df['log_flop'],
    mode='markers',
    marker=dict(color=colors['China'], opacity=0.5),
    text=china_df['Model'],
    hoverinfo='text',
    name=f'Top-{top_n} China'
))

fig.add_trace(go.Scatter(
    x=recent_top_models_df['date'],
    y=recent_top_models_df['log_flop'],
    mode='markers',
    marker=dict(color='grey', opacity=0.5),
    text=recent_top_models_df['Model'],
    hoverinfo='text',
    name=f'Top-{top_n} Overall'
))

fig.update_layout(
    width=800,
    height=400,
    xaxis_title='Date',
    yaxis_title='Log FLOP',
    title=f'Top-{top_n} models',
    margin=dict(t=50, l=60, r=60, b=50),
)

save_plot(fig, results_dir, f'top_{top_n}_models_without_kickstarting')

fig.show()

In [24]:
top_models_since_cutoff = top_models_df[top_models_df['date'] >= pd.to_datetime(cutoff_date)]
top_models_set = set(top_models_since_cutoff['Model'])
non_china_top_models_set = set(non_china_df['Model'])
china_top_models_set = set(china_df['Model'])

frac_non_china_top_models = len(non_china_top_models_set.intersection(top_models_set)) / len(top_models_set)
frac_china_top_models = len(china_top_models_set.intersection(top_models_set)) / len(top_models_set)
print(f"Fraction of overall top-{top_n} models that are Non-China: {frac_non_china_top_models*100:.1f}%")
print(f"Fraction of overall top-{top_n} models that are China: {frac_china_top_models*100:.1f}%")


Fraction of overall top-10 models that are Non-China: 70.0%
Fraction of overall top-10 models that are China: 10.0%


# Training hardware analysis

In [25]:
# Models with training hardware
china_num_hardware = china_df[china_df['Training hardware'].notna()].shape[0]
print(f"Number of top-{top_n} {model_selection} China models with training hardware: {china_num_hardware} of {china_df.shape[0]}")
# Print models grouped by training hardware
china_df.groupby('Training hardware')['Model'].apply(lambda x: ', '.join(x))

Number of top-10 Language models China models with training hardware: 18 of 56


Training hardware
Huawei Ascend 910                                                     PanGu-α, CodeGeeX, PanGu-Σ
Huawei Ascend 910,NVIDIA Tesla V100 DGXS 32 GB                                   ERNIE 3.0 Titan
NVIDIA A100                                       PLUG, MegaScale (530B), MegaScale (Production)
NVIDIA A100 SXM4 40 GB                                                                  GLM-130B
NVIDIA A100 SXM4 80 GB                                                       JIANG, CodeFuse-13B
NVIDIA A800                                                                FLM-101B, Skywork-13B
NVIDIA H100 SXM5 80GB                                                               Yi-Lightning
NVIDIA H800 SXM5                                             DeepSeek-V2 (MoE-236B), DeepSeek-V3
NVIDIA Tesla V100 DGXS 32 GB                                                                M6-T
NVIDIA V100                                                                 CPM-Large, ERNIE 3.0
Name: Model,

In [26]:
export_controls_cutoff_date = pd.to_datetime('2022-10-07')  # actual date of controls is 2022-10-07
pre_export_controls = china_df[china_df['date'] < export_controls_cutoff_date]
post_export_controls = china_df[china_df['date'] >= export_controls_cutoff_date]
pre_export_controls.groupby('Training hardware')['Model'].apply(lambda x: ', '.join(x))

Training hardware
Huawei Ascend 910                                    PanGu-α, CodeGeeX
Huawei Ascend 910,NVIDIA Tesla V100 DGXS 32 GB         ERNIE 3.0 Titan
NVIDIA A100                                                       PLUG
NVIDIA A100 SXM4 40 GB                                        GLM-130B
NVIDIA Tesla V100 DGXS 32 GB                                      M6-T
NVIDIA V100                                       CPM-Large, ERNIE 3.0
Name: Model, dtype: object

In [27]:
post_export_controls.groupby('Training hardware')['Model'].apply(lambda x: ', '.join(x))


Training hardware
Huawei Ascend 910                                          PanGu-Σ
NVIDIA A100               MegaScale (530B), MegaScale (Production)
NVIDIA A100 SXM4 80 GB                         JIANG, CodeFuse-13B
NVIDIA A800                                  FLM-101B, Skywork-13B
NVIDIA H100 SXM5 80GB                                 Yi-Lightning
NVIDIA H800 SXM5               DeepSeek-V2 (MoE-236B), DeepSeek-V3
Name: Model, dtype: object

In [28]:
non_china_num_hardware = non_china_df[non_china_df['Training hardware'].notna()].shape[0]
print(f"Number of top-{top_n} {model_selection} Non-China models with training hardware: {non_china_num_hardware} of {non_china_df.shape[0]}")
# Print models grouped by training hardware
non_china_df.groupby('Training hardware')['Model'].apply(lambda x: ', '.join(x))


Number of top-10 Language models Non-China models with training hardware: 41 of 60


Training hardware
Google TPU v3                        XLNet, T5-11B, T5-3B, Meena, GShard (dense), S...
Google TPU v4                        GLaM, AlphaCode, PaLM (540B), PaLM 2, Gemini 1...
Google TPU v4,Google TPU v3                                                 Chinchilla
NVIDIA A100                                    HyperCLOVA 82B, LLaMA-65B, Amazon Titan
NVIDIA A100 SXM4 40 GB                       GPT-NeoX-20B, GPT-3.5, GPT-4, Falcon-180B
NVIDIA A100 SXM4 80 GB               Megatron-Turing NLG 530B, OPT-175B, BLOOM-176B...
NVIDIA A100,NVIDIA H100 SXM5 80GB                                            Reka Core
NVIDIA H100 SXM5 80GB                Inflection-1, Inflection-2, Mistral Large, Inf...
NVIDIA Tesla V100 DGXS 16 GB                      Big Transformer for Back-Translation
NVIDIA Tesla V100 DGXS 32 GB         RoBERTa Large, Megatron-LM (8.3B), XLM-RoBERTa...
NVIDIA Tesla V100S PCIe 32 GB                                            Megatron-BERT
Name: Model, dtype: objec

# Regression analysis

In [29]:
dep_var = 'log_flop'

In [30]:
#@markdown Analysis of best fit to the data

@dataclass
class FitResult:
    df: pd.DataFrame
    p: int = None
    bic: float = None
    rss: float = None
    mse: float = None
    predict: Callable = None

@dataclass
class HyperbolicFitResult(FitResult):
    params: tuple[float] = None

@dataclass
class KinkedFitResult(FitResult):
    break_points: tuple[float] = None
    break_points_dt: float = None
    oom_year_slopes: tuple[float] = None

    # Model properties for each breakpoint combination
    # (for debugging)
    bics: tuple[float] = None
    rsss: tuple[float] = None
    mses: tuple[float] = None
    break_points_list: tuple[tuple[float]] = None
    break_points_dt_list: tuple[tuple[float]] = None

def fit_hyperbolic(df):
    def hyperbolic_model(t, A, B, k):
        return A / (1 + B * np.exp(-k * t))

    # Prepare data for curve fitting
    timestamp = pd.to_datetime(df['date']).apply(lambda date: date.toordinal()).values

    # Initial guess for the parameters
    # initial_guess = [0, 0, 0]
    initial_guess = [1.72373207e-02, -9.45447534e-01, -7.50101861e-08]  # Updated initial guess

    # Fit the model to the data
    try:
      params, covariance = curve_fit(hyperbolic_model, timestamp, df[dep_var], p0=initial_guess, maxfev=100000, ftol=1e-10)
    except RuntimeError as e:
      print("FATAL ERROR WHEN FITTING HYPERBOLIC")
      return None

    # Extracting parameters
    A, B, k = params

    # Compute predictions to calculate residuals
    predicted_log_y = hyperbolic_model(timestamp, *params)

    # Compute the Residual Sum of Squares (RSS)
    rss = np.sum((df[dep_var] - predicted_log_y) ** 2)

    # Number of observations (n)
    n = len(df[dep_var])

    # Number of parameters (p)
    p = len(params) + 1

    # Calculate log-likelihood under the assumption of normally distributed errors
    # log_likelihood = -0.5 * rss
    log_likelihood = -0.5 * n * (np.log(2 * np.pi * rss/n) + 1)

    # Compute bic_hyperbolic using the provided formula
    bic = p * np.log(n) - 2 * log_likelihood

    # Compute MSE
    mse = rss / n

    fit_result = HyperbolicFitResult(
        df=df,
        p=p,
        bic=bic,
        rss=rss,
        mse=mse,
        params=params,
        predict=lambda date: hyperbolic_model(date.apply(lambda d: d.toordinal()), *params)
    )

    return fit_result

def fit_n_phase_exponential(df, kink_count=0, allow_discontinuities=False, min_n_segment=10):
    # Generate monthly breakpoints between 2010 and 2024
    one_month = pd.DateOffset(months=1)
    break_point_grid = pd.date_range(start=df['date'].min() - one_month, end=df['date'].max() - 4*one_month, freq='MS')
    break_point_grid = [x.toordinal() for x in break_point_grid]

    x = pd.to_datetime(df['date']).apply(lambda date: date.toordinal()).values
    y = df[dep_var].values

    break_points_list = []
    bics = []
    rsss = []
    mses = []
    models = []

    for break_points in combinations_with_replacement(break_point_grid, kink_count):
        # Model predictors

        intercept_change_points = (0,)
        if allow_discontinuities:
            intercept_change_points += break_points
        slope_change_points = (0,) + break_points

        predictors = np.zeros((len(x), len(intercept_change_points) + len(slope_change_points)))

        for i, intercept_point in enumerate(intercept_change_points):
            predictors[:, i] = (x >= intercept_point).astype(int)

        for i, break_point in enumerate(slope_change_points):
            predictors[:, len(intercept_change_points) + i] = np.maximum(x - break_point, 0)

        # Fit the model
        model = sm.OLS(y, predictors).fit()

        # Check for negative discontinuities if discontinuities are allowed
        invalid_discontinuity = False
        if allow_discontinuities and break_points:
            # For each breakpoint, compare the predicted value just before and after
            for break_point in break_points:
                # Create predictor matrices for points just before and after the breakpoint
                before_predictors = np.zeros((1, len(intercept_change_points) + len(slope_change_points)))
                after_predictors = np.zeros((1, len(intercept_change_points) + len(slope_change_points)))
                
                # Fill in the predictor matrices
                for i, intercept_point in enumerate(intercept_change_points):
                    before_predictors[0, i] = (break_point - 1 >= intercept_point)
                    after_predictors[0, i] = (break_point >= intercept_point)
                
                for i, slope_point in enumerate(slope_change_points):
                    before_predictors[0, len(intercept_change_points) + i] = max(0, break_point - 1 - slope_point)
                    after_predictors[0, len(intercept_change_points) + i] = max(0, break_point - slope_point)
                
                # Get predictions
                before_value = model.predict(before_predictors)[0]
                after_value = model.predict(after_predictors)[0]
                
                # Check if there's a negative discontinuity
                if after_value < before_value:
                    invalid_discontinuity = True
                    break

        if invalid_discontinuity:
            continue

        # Calculate BIC manually based on log-likelihood
        n = len(x) # Number of observations
        p = len(model.params) + 2*kink_count + 1 # Number of parameters

        # Calculate log-likelihood under the assumption of normally distributed errors
        # We have to iterate over all points to get their individual log-likelihoods
        log_likelihood = 0
        rss = 0
        invalid_model = False # Discard models with segments with less than 2 points
        for i, break_point in enumerate(slope_change_points):
            left_x = break_point
            right_x = slope_change_points[i + 1] if i + 1 < len(slope_change_points) else np.inf

            segment_predictors = predictors[(left_x <= x) & (x < right_x), :]
            segment_y = y[(left_x <= x) & (x < right_x)]
            segment_n = len(segment_y)

            assert min_n_segment > 2

            if segment_n < min_n_segment:
                invalid_model = True
                break

            y_pred = model.predict(segment_predictors)

            segment_rss = np.sum((y_pred - segment_y)**2)
            if segment_rss == 0:
                print(f"segment_rss={segment_rss}")
                print(f"y_pred={y_pred}")
                print(f"segment_y={segment_y}")
                invalid_model = True
                break
            segment_mse = segment_rss / segment_n

            segment_log_likelihood = -segment_n/2 * (np.log(2*np.pi) + np.log(segment_rss/segment_n) + 1)
            log_likelihood += segment_log_likelihood
            rss += segment_rss

        if invalid_model:
            continue

        # Compute BIC using the manual method based on the log-likelihood
        bic = p * np.log(n) - 2 * log_likelihood
        # bic = n*np.log(rss/n) + p*np.log(n)

        bics.append(bic)
        rsss.append(rss)
        mses.append(rss/len(df))
        models.append(model)
        break_points_list.append(break_points)

    if len(bics) == 0:
        return None

    # Prepare the result object
    best_bic = min(bics)
    best_idx = bics.index(best_bic)
    best_rss = rsss[best_idx]
    best_mse = mses[best_idx]
    best_model = models[best_idx]
    best_break_points = break_points_list[best_idx]

    p = len(best_model.params) + 2*kink_count + 1 # Number of parameters

    intercept_change_points = (0,)
    if allow_discontinuities:
        intercept_change_points += best_break_points
    slope_change_points = (0,) + best_break_points

    intercepts = best_model.params[:len(intercept_change_points)]
    oom_year_slopes = 365 * np.cumsum(best_model.params[len(intercepts):])

    def predict(date):
        if not isinstance(date, pd.Series):
            date = pd.Series(date)
        x = pd.to_datetime(date).apply(lambda date: date.toordinal()).values

        predictors = np.zeros((len(x), len(intercept_change_points) + len(slope_change_points)))

        for i, intercept_point in enumerate(intercept_change_points):
            predictors[:, i] = (x >= intercept_point).astype(int)

        for i, break_point in enumerate(slope_change_points):
            predictors[:, len(intercept_change_points) + i] = np.maximum(x - break_point, 0)

        return best_model.predict(predictors)

    fit_result = KinkedFitResult(
        df=df,
        p=p,
        bic=best_bic,
        rss=best_rss,
        mse=best_mse,
        break_points=best_break_points,
        predict=predict,
        break_points_dt=[pd.Timestamp.fromordinal(bp) for bp in best_break_points],
        bics=bics,
        rsss=rsss,
        mses=mses,
        oom_year_slopes=oom_year_slopes,
        break_points_list=break_points_list,
        break_points_dt_list=[[pd.Timestamp.fromordinal(bp) for bp in break_points] for break_points in break_points_list],
    )

    return fit_result

def calculate_lag(df, fit_results, date=None):
    if date is None:
        date = df['date'].max()

    # Get the predictions for the two categories
    y_non_china = fit_results['Non-China'].predict(pd.Series([date]))[0]
    y_china = fit_results['China'].predict(pd.Series([date]))[0]
    
    # Get the final slope for the 'China' category
    slope_non_china = fit_results['Non-China'].oom_year_slopes[-1]
    
    # Calculate lag
    lag = (y_non_china - y_china) / slope_non_china
    
    return lag


## Model selection

In [31]:
fit_em_all = lambda df_fit : {
    "Simple" : fit_n_phase_exponential(df_fit, kink_count=0),
    "One kink" : fit_n_phase_exponential(df_fit, kink_count=1),
    "Discontinuity" : fit_n_phase_exponential(df_fit, kink_count=1, allow_discontinuities=True),
    # "Hyperbolic": fit_hyperbolic(df_fit),
}

# Best model fits
print(f"Fitting China and Non-China models")
regression_data = {
    'China': {},
    'Non-China': {},
}
regression_data['China']['models'] = fit_em_all(df_filtered[df_filtered['category'] == 'China'])
regression_data['Non-China']['models'] = fit_em_all(df_filtered[df_filtered['category'] == 'Non-China'])


Fitting China and Non-China models


In [32]:
# K-Fold Cross Validation
def perform_cross_validation(df, k=10, random_state=42):
    kf = KFold(n_splits=k, shuffle=True, random_state=random_state)
    folds_mses = defaultdict(lambda : [])
    for train_index, test_index in kf.split(df):
        train_df, test_df = df.iloc[train_index], df.iloc[test_index]

        # Fit the models on the training set
        fold_models = fit_em_all(train_df)

        # Predict on the test set
        for name,model in fold_models.items():
            try:
                predicted_log_y = model.predict(test_df["date"])
            except AttributeError:
                continue
            test_rss = np.sum((predicted_log_y - test_df[dep_var])**2)
            test_mse = test_rss / len(test_df)
            folds_mses[name].append(test_mse)

    # Compute mean MSE
    folds_mses = {name: np.mean(folds_mses[name]) for name in folds_mses}

    return folds_mses

regression_data['China']['folds_mses'] = perform_cross_validation(df_filtered[df_filtered['category'] == 'China'])
regression_data['Non-China']['folds_mses'] = perform_cross_validation(df_filtered[df_filtered['category'] == 'Non-China'])

In [33]:
state = None
for i in range(1000):
    try:
        # Save the state of the original RNG's bit generator
        state = DEFAULT_RNG.bit_generator.state
        sample = df_filtered.sample(len(df_filtered), replace=True, random_state=DEFAULT_RNG)
        fit_em_all(sample[sample['category'] == 'Non-China'])
    except Exception as e:
        print(i)
        print(DEFAULT_RNG)
        print(e)
        break

In [34]:
error_state = state.copy()
# Initialize a new RNG
new_rng = np.random.default_rng()
# Set the new RNG's bit generator state to the saved state
new_rng.bit_generator.state = state
# Reproduce the error
sample = df_filtered.sample(len(df_filtered), replace=True, random_state=new_rng)
fit_em_all(sample[sample['category'] == 'Non-China'])

{'Simple': KinkedFitResult(df=                 Model          flop       date              Organization  \
 704        Grover-Mega  5.700000e+21 2019-05-29  University of Washington   
 1700  Gemini 1.0 Ultra  5.000000e+25 2023-12-06           Google DeepMind   
 803              Meena  1.120000e+23 2020-01-28              Google Brain   
 1657       GPT-4 Turbo  2.100000e+25 2023-11-06                    OpenAI   
 927             Switch  8.220000e+22 2021-01-11                    Google   
 ...                ...           ...        ...                       ...   
 1021  Jurassic-1-Jumbo  3.700000e+23 2021-08-11                 AI21 Labs   
 1124             LaMDA  3.550000e+23 2022-02-10                    Google   
 667       GPT-2 (1.5B)  1.920000e+21 2019-02-14                    OpenAI   
 927             Switch  8.220000e+22 2021-01-11                    Google   
 1179        BIG-G 137B  5.600000e+23 2022-06-09                    Google   
 
                  Notability crit

In [35]:
# Bootstrap
bootstrap_sample_size = 1000

pred_start_date = df_filtered['date'].min()
pred_end_date = df_filtered['date'].max()

regression_data['China']['bootstrap_predictions'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_predictions'] = defaultdict(lambda : [])

regression_data['China']['bootstrap_bics'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_bics'] = defaultdict(lambda : [])

regression_data['China']['bootstrap_mses'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_mses'] = defaultdict(lambda : [])

regression_data['China']['bootstrap_bic_score_diff'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_bic_score_diff'] = defaultdict(lambda : [])

regression_data['China']['bootstrap_slopes'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_slopes'] = defaultdict(lambda : [])

regression_data['China']['bootstrap_breaks'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_breaks'] = defaultdict(lambda : [])

rng = np.random.default_rng(20250103)

from joblib import Parallel, delayed
from tqdm.notebook import tqdm

def bootstrap_iteration(bootstrap_index, category, df_filtered, pred_start_date, pred_end_date, rng_seed):
    if bootstrap_index == 0:
        # Use the original data as the first bootstrap sample
        sample = df_filtered.copy()
    else:
        sample = df_filtered.sample(len(df_filtered), replace=True, random_state=rng_seed)
    sample = sample[sample['category'] == category]
    sample = sample.sort_values('date')

    # Compute BICs
    boot_models = fit_em_all(sample)
    if any(model is None for model in boot_models.values()):
        return None

    # Compute K fold validation
    boot_folds_mses = perform_cross_validation(sample)

    # Initialize local storage
    local_bics = {}
    local_mses = {}
    local_bic_diff = {}
    local_slopes = {}
    local_breaks = {}
    local_predictions = {}

    # Store results
    for name, model in boot_models.items():
        local_bics[name] = model.bic
        local_mses[name] = boot_folds_mses.get(name, np.nan)
        local_bic_diff[name] = model.bic - boot_models.get("Simple", model.bic).bic

        if isinstance(model, KinkedFitResult):
            if len(model.oom_year_slopes) > 0:
                local_slopes[name] = 10**model.oom_year_slopes[-1]
            if len(model.break_points_dt) > 0:
                local_breaks[name] = model.break_points_dt[-1]

    # Store predictions for confidence intervals
    for name, model in boot_models.items():
        try:
            date_grid = pd.date_range(start=pred_start_date, end=pred_end_date, freq='MS')
            pred = model.predict(pd.Series(date_grid))
            local_predictions[name] = pred
        except AttributeError:
            continue

    return (local_bics, local_mses, local_bic_diff, local_slopes, local_breaks, local_predictions)


def bootstrap_with_retry(bootstrap_index, category, df_filtered, pred_start_date, pred_end_date, max_retries=10):
    rng = np.random.default_rng(bootstrap_index)  # Deterministic seed per worker
    
    for retry in range(max_retries):
        try:
            result = bootstrap_iteration(
                bootstrap_index,
                category,
                df_filtered,
                pred_start_date,
                pred_end_date,
                rng.integers(0, 1e9)
            )
            if result is not None:
                return {
                    'success': True,
                    'result': result,
                    'retries': retry
                }
        except Exception as e:
            if retry == max_retries - 1:
                return {
                    'success': False,
                    'error': str(e),
                    'retries': retry + 1
                }
            continue
    
    return {
        'success': False,
        'error': 'Max retries exceeded',
        'retries': max_retries
    }


for category in ['China', 'Non-China']:
    print(f"Bootstrapping {category} data")
    
    # Run parallel bootstrap with retries
    bootstrap_results = Parallel(n_jobs=-1)(
        delayed(bootstrap_with_retry)(
            i,
            category,
            df_filtered,
            pred_start_date,
            pred_end_date
        )
        for i in range(bootstrap_sample_size)
    )
    
    # Analyze results and retry statistics
    successful_results = [r['result'] for r in bootstrap_results if r['success']]
    total_retries = sum(r['retries'] for r in bootstrap_results)
    failed_bootstraps = sum(1 for r in bootstrap_results if not r['success'])
    
    print(f"Bootstrap statistics for {category}:")
    print(f"- Success rate: {(len(successful_results)/bootstrap_sample_size):.1%}")
    print(f"- Average retries: {total_retries/bootstrap_sample_size}")
    print(f"- Failed bootstraps: {failed_bootstraps}")
    
    # Process successful results
    for res in successful_results:
        local_bics, local_mses, local_bic_diff, local_slopes, local_breaks, local_predictions = res
        
        # Update storage as before
        for name, bic in local_bics.items():
            regression_data[category]['bootstrap_bics'][name].append(bic)
        # Update MSEs
        for name, mse in local_mses.items():
            regression_data[category]['bootstrap_mses'][name].append(mse)

        # Update BIC score differences
        for name, diff in local_bic_diff.items():
            regression_data[category]['bootstrap_bic_score_diff'][name].append(diff)

        # Update slopes
        for name, slope in local_slopes.items():
            regression_data[category]['bootstrap_slopes'][name].append(slope)

        # Update break points
        for name, break_pt in local_breaks.items():
            regression_data[category]['bootstrap_breaks'][name].append(break_pt)

        # Update predictions
        for name, pred in local_predictions.items():
            regression_data[category]['bootstrap_predictions'][name].append(pred)

Bootstrapping China data
Bootstrap statistics for China:
- Success rate: 100.0%
- Average retries: 0.0
- Failed bootstraps: 0
Bootstrapping Non-China data
Bootstrap statistics for Non-China:
- Success rate: 100.0%
- Average retries: 0.007
- Failed bootstraps: 0


In [44]:
ci_width = 0.90
qs = [(1 - ci_width)/2, (1 + ci_width)/2]
bootstrap_preferred_percent = {}
bootstrap_summary_data = {
    'China': defaultdict(lambda: {}),
    'Non-China': defaultdict(lambda: {}),
}
for category in ['China', 'Non-China']:
    for name in regression_data[category]['models']:
        bootstrap_summary_data[category]['bootstrap_preferred_percent'][name] = np.mean(np.array(regression_data[category]['bootstrap_bic_score_diff'][name])<0)
        bootstrap_summary_data[category]['bootstrap_bics'][name] = np.quantile(np.array(regression_data[category]['bootstrap_bics'][name]), qs)
        bootstrap_summary_data[category]['bootstrap_mses'][name] = np.quantile(np.array(regression_data[category]['bootstrap_mses'][name]), qs)
        bootstrap_summary_data[category]['bootstrap_bic_score_diff'][name] = np.quantile(np.array(regression_data[category]['bootstrap_bic_score_diff'][name]), qs)
        try:
            bootstrap_summary_data[category]['bootstrap_slopes'][name] = np.quantile(np.array(regression_data[category]['bootstrap_slopes'][name]), qs)
            bootstrap_summary_data[category]['bootstrap_breaks'][name] = np.quantile(np.array(regression_data[category]['bootstrap_breaks'][name]), qs)
        except IndexError:
            pass

# Models with lower BIC score / MSE are preferred.

results = {
    'China': [],
    'Non-China': [],
}

for category in ['China', 'Non-China']:
    for name, model in regression_data[category]['models'].items():
        param_count = model.p
        log_likelihood = (np.log(len(df_filtered))*param_count - model.bic)/2

        param_count_simple = regression_data[category]['models']['Simple'].p
        log_likelihood_simple = (np.log(len(df_filtered))*param_count_simple - regression_data[category]['models']['Simple'].bic)/2

        c2 = chi2.sf(2*(log_likelihood - log_likelihood_simple), df=(param_count - param_count_simple))

        result = {
            "Model": name,
            "BIC" : np.round(model.bic, 2),
            "BIC 90% CI" : np.round(bootstrap_summary_data[category]['bootstrap_bics'][name], 2),
            #"Parameter count": param_count,
            #"Log likelihood": np.round((np.log(len(df_filtered))*param_count - model.bic)/2),
            # "MSE" : model.mse,
            "BIC score diff": np.round(model.bic - regression_data[category]['models']['Simple'].bic, 2),
            "BIC score diff 90% CI": np.round(bootstrap_summary_data[category]['bootstrap_bic_score_diff'][name], 2),
            "Xi²": c2,
            "% times preferred over simple": f"{bootstrap_summary_data[category]['bootstrap_preferred_percent'][name]:.0%}",
            # "bayes factor over simple" : np.exp(-0.5 * (model.bic - models["simple"].bic)),
            "K-fold mean MSE" : np.round(regression_data[category]['folds_mses'][name], 2),
            "K-fold mean MSE 90% CI" : np.round(bootstrap_summary_data[category]['bootstrap_mses'][name], 2),
        }

        try:
            result["Recent slope (Nx/year)"] = np.round(10**model.oom_year_slopes[-1], 2)
            result["Recent slope 90% CI"] = np.round(bootstrap_summary_data[category]['bootstrap_slopes'][name], 2)
            result["Break point"] = model.break_points_dt[-1].strftime('%Y-%m')
            result["Break point 90% CI"] = [date.strftime('%Y-%m') for date in bootstrap_summary_data[category]['bootstrap_breaks'][name]]
        except (AttributeError, IndexError):
            pass
        results[category].append(result)

results = {category: pd.DataFrame(results[category]) for category in ['China', 'Non-China']}

lag_results = []
for name, model in regression_data['Non-China']['models'].items():
    for name_china, model_china in regression_data['China']['models'].items():
        lags = []
        for i in range(bootstrap_sample_size):
            last_pred_china = regression_data['China']['bootstrap_predictions'][name_china][i][-1]
            last_pred_non_china = regression_data['Non-China']['bootstrap_predictions'][name][i][-1]
            lag = (last_pred_non_china - last_pred_china) / model.oom_year_slopes[-1]
            lags.append(lag)
        lag_result = {
            "Non-China model": name,
            "China model": name_china,
            "Point estimate": np.round(lags[0], 2),  # First bootstrap sample is the original data
            "Mean": np.round(np.mean(lags), 2),
            "90% CI": np.round(np.quantile(lags, qs), 2),
        }
        lag_results.append(lag_result)
lag_results = pd.DataFrame(lag_results)
# bayes_factor = np.exp(-0.5 * (kinked_fit.bic - simple_fit.bic))

print("Bootstrapped regression results")
for category in ['China', 'Non-China']:
    print(category)
    display(results[category])
print("Lag results (lag values in years)")
display(lag_results)

Bootstrapped regression results
China


Unnamed: 0,Model,BIC,BIC 90% CI,BIC score diff,BIC score diff 90% CI,Xi²,% times preferred over simple,K-fold mean MSE,K-fold mean MSE 90% CI,Recent slope (Nx/year),Recent slope 90% CI,Break point,Break point 90% CI
0,Simple,175.1,"[136.36, 202.34]",0.0,"[0.0, 0.0]",,0%,1.29,"[0.74, 1.57]",16.24,"[10.6, 22.71]",,
1,One kink,135.21,"[97.7, 152.65]",-39.89,"[-70.87, -19.36]",1.041998e-11,100%,0.69,"[0.34, 0.83]",3.03,"[2.02, 5.88]",2021-09,"[2021-04, 2022-01]"
2,Discontinuity,133.06,"[97.24, 151.22]",-42.04,"[-74.31, -20.32]",1.741736e-12,100%,0.73,"[0.33, 0.78]",3.31,"[2.4, 4.9]",2021-04,"[2021-02, 2021-06]"


Non-China


Unnamed: 0,Model,BIC,BIC 90% CI,BIC score diff,BIC score diff 90% CI,Xi²,% times preferred over simple,K-fold mean MSE,K-fold mean MSE 90% CI,Recent slope (Nx/year),Recent slope 90% CI,Break point,Break point 90% CI
0,Simple,70.93,"[46.83, 87.57]",0.0,"[0.0, 0.0]",,0%,0.17,"[0.11, 0.21]",4.78,"[4.29, 5.27]",,
1,One kink,69.47,"[38.17, 81.7]",-1.46,"[-23.61, 4.35]",0.001292,75%,0.17,"[0.11, 0.21]",4.32,"[3.83, 39.82]",2019-12,"[2019-10, 2024-03]"
2,Discontinuity,82.05,"[44.32, 90.08]",11.12,"[-15.61, 12.53]",0.095571,34%,0.18,"[0.11, 0.22]",5.27,"[0.97, 29.16]",2023-11,"[2019-10, 2024-02]"


Lag results (lag values in years)


Unnamed: 0,Non-China model,China model,Point estimate,Mean,90% CI
0,Simple,Simple,0.59,0.6,"[0.13, 1.02]"
1,Simple,One kink,1.71,1.66,"[1.2, 2.07]"
2,Simple,Discontinuity,1.67,1.66,"[1.25, 2.01]"
3,One kink,Simple,0.52,0.9,"[0.15, 1.8]"
4,One kink,One kink,1.72,2.03,"[1.26, 2.93]"
5,One kink,Discontinuity,1.67,2.03,"[1.32, 2.86]"
6,Discontinuity,Simple,0.73,0.65,"[-0.0, 1.43]"
7,Discontinuity,One kink,1.79,1.64,"[0.98, 2.46]"
8,Discontinuity,Discontinuity,1.74,1.64,"[1.02, 2.4]"


In [45]:
# Find the best model for each category
simplicity_order = ['Simple', 'One kink', 'Discontinuity']
selected_model = {}
for category in ['China', 'Non-China']:
    df = results[category]
    argmin_bic = df['BIC'].argmin()
    min_bic = df['BIC'].iloc[argmin_bic]
    min_bic_model = df['Model'].iloc[argmin_bic]
    min_bic_mse = df['K-fold mean MSE'].iloc[argmin_bic]
    # Iterate over the models in simplicity order
    # If a simpler model is not preferred, the selected model is the one with the lowest BIC
    for model in simplicity_order:
        # Check if the BICs are close
        if df[df['Model'] == model]['BIC'].iloc[0] - min_bic < 2:
            # Check if the MSEs are close
            if df[df['Model'] == model]['K-fold mean MSE'].iloc[0] - min_bic_mse < 0.01:
                selected_model[category] = model
                break
    print(f"Best model for {category}: {selected_model[category]}")


Best model for China: Discontinuity
Best model for Non-China: Simple


In [46]:
# Save results_df
for category in ['China', 'Non-China']:
    regression_fname = f'compute_regression_analysis_{category}_{model_selection}_frontier={frontier_selection}_top_n={top_n}_cutoff={cutoff_date}.csv'
    results[category].to_csv(os.path.join(results_dir, regression_fname), index=False)
lag_results.to_csv(os.path.join(results_dir, f'lags_{model_selection}_frontier={frontier_selection}_top_n={top_n}_cutoff={cutoff_date}.csv'), index=False)


# Plot predictions with bootstrapped CIs

In [47]:
regression_data['China']['bootstrap_predictions']

defaultdict(<function __main__.<lambda>()>,
            {'Simple': [array([17.62591311, 17.72540789, 17.82821917, 17.92771395, 18.03052523,
                     18.13333651, 18.2261983 , 18.32900958, 18.42850436, 18.53131564,
                     18.63081042, 18.7336217 , 18.83643298, 18.93592776, 19.03873903,
                     19.13823382, 19.24104509, 19.34385637, 19.44003466, 19.54284594,
                     19.64234072, 19.745152  , 19.84464678, 19.94745806, 20.05026933,
                     20.14976412, 20.25257539, 20.35207018, 20.45488145, 20.55769273,
                     20.65055453, 20.7533658 , 20.85286059, 20.95567186, 21.05516665,
                     21.15797792, 21.2607892 , 21.36028398, 21.46309526, 21.56259004,
                     21.66540132, 21.7682126 , 21.86107439, 21.96388567, 22.06338045,
                     22.16619173, 22.26568651, 22.36849779, 22.47130907, 22.57080385,
                     22.67361513, 22.77310991, 22.87592119, 22.97873246, 23.07159426,


In [48]:
def calculate_confidence_intervals(bootstrap_preds, percentile=90):
    lower_percentile = (100 - percentile) / 2
    upper_percentile = 100 - lower_percentile
    ci = {}
    for model, preds in bootstrap_preds.items():
        preds_array = np.array(preds)  # Shape: (bootstrap_samples, n_dates)
        lower = np.percentile(preds_array, lower_percentile, axis=0)
        upper = np.percentile(preds_array, upper_percentile, axis=0)
        ci[model] = {'lower': lower, 'upper': upper}
    return ci

In [49]:
# Calculate 90% Confidence Intervals
confidence_intervals = calculate_confidence_intervals(regression_data['China']['bootstrap_predictions'], percentile=90)
confidence_intervals

{'Simple': {'lower': array([16.90518134, 17.01567433, 17.12985041, 17.24034339, 17.35451947,
         17.46869555, 17.57182234, 17.68599842, 17.79488018, 17.90955742,
         18.0191574 , 18.13241476, 18.24567212, 18.35589622, 18.4729625 ,
         18.58467604, 18.69886222, 18.81304839, 18.91986772, 19.03311638,
         19.14378086, 19.25865164, 19.36720196, 19.47972201, 19.59512287,
         19.7062548 , 19.82060974, 19.92982785, 20.04331324, 20.15763141,
         20.26193519, 20.37470188, 20.4855036 , 20.59970864, 20.70917228,
         20.82000491, 20.93177517, 21.04055331, 21.15383218, 21.25942205,
         21.37239536, 21.48507535, 21.58451403, 21.69510825, 21.80112611,
         21.9124499 , 22.02243812, 22.13484547, 22.24436458, 22.35160125,
         22.46285489, 22.5697687 , 22.68023443, 22.78713226, 22.88527947,
         22.99610069, 23.10186848, 23.21192889, 23.3169383 , 23.4176634 ,
         23.51999177, 23.61782137, 23.72042913, 23.81531564, 23.91634945,
         24.0124107

In [50]:
# Graph of the different model fits using plotly
model_types = selected_model
# Custom:
# model_types = {
#     'Non-China': 'Simple',
#     'China': 'Discontinuity',
# }
# Parameters for each model selection
model_params = {
    'Simple': {
        'kink_count': 0,
        'allow_discontinuities': False,
    },
    'One kink': {
        'kink_count': 1,
        'allow_discontinuities': False,
    },
    'Discontinuity': {
        'kink_count': 1,
        'allow_discontinuities': True,
    }
}

def plot_model(df, model_types, model_params):
    fig = go.Figure()

    # Plot the original data points
    df_non_china = df[df['category'] == 'Non-China']
    df_china = df[df['category'] == 'China']

    fig.add_trace(go.Scatter(
        x=df_non_china['date'], y=df_non_china['log_flop'],
        mode='markers', name='Not developed in China', text=df_non_china['Model'],
        marker=dict(color=colors['Non-China'], opacity=0.1, size=10)
    ))
    fig.add_trace(go.Scatter(
        x=df_china['date'], y=df_china['log_flop'],
        mode='markers', name='Developed in China', text=df_china['Model'],
        marker=dict(color=colors['China'], opacity=0.1, size=10)
    ))

    # Show the export controls date
    # Convert 2022-10-07 to seconds since epoch
    export_controls_date = pd.Timestamp('2022-10-07').value / 1e6
    fig.add_vline(x=export_controls_date, line_color='black', line_width=1, line_dash='dot', 
        annotation_text='October 2022<br>Export controls introduced', annotation_position='bottom right')

    # non_china_date_grid = pd.date_range(start=non_china_df['date'].min(), end=non_china_df['date'].max(), freq='MS')
    # china_date_grid = pd.date_range(start=china_df['date'].min(), end=china_df['date'].max(), freq='MS')
    date_grid = pd.date_range(start=df['date'].min(), end=df['date'].max(), freq='MS')

    trend_dfs = {}
    ci_dfs = {}
    fit_results = {}
    for category, model_type in model_types.items():
        params = model_params[model_type]
        fit_result = fit_n_phase_exponential(df[df['category'] == category], **params)
        fit_results[category] = fit_result

        # Get the month of the first point, then use that to index the date_grid
        start_month = df[df['category'] == category]['date'].min().month + 1
        start_index = np.where(date_grid.month == start_month)[0][0]
        log_flop = fit_result.predict(pd.Series(date_grid))
        # To plot the bootstrapped mean prediction instead:
        # log_flop = np.mean(regression_data[category]['bootstrap_predictions'][model_type], axis=0)

        trend_dfs[category] = pd.DataFrame({
            'date': date_grid[start_index:],
            'log_flop': log_flop[start_index:],
        })

        # Get the confidence intervals
        ci_data = calculate_confidence_intervals(regression_data[category]['bootstrap_predictions'], percentile=90)
        ci_dfs[category] = pd.DataFrame({
            'date': date_grid[start_index:],
            'lower': ci_data[model_type]['lower'][start_index:],
            'upper': ci_data[model_type]['upper'][start_index:],
        })

        # Plot the best fit line with confidence intervals
        fig.add_trace(go.Scatter(
            x=date_grid[start_index:], y=log_flop[start_index:],
            mode='lines', name=f'{category} best fit line',
            line=dict(color=colors[category], width=1),
            showlegend=False,
        ))
        fig.add_trace(go.Scatter(
            x=date_grid[start_index:],
            y=ci_data[model_type]['lower'][start_index:],
            mode='lines',
            line=dict(color=colors[category], width=0),
            showlegend=False,
        ))
        fig.add_trace(go.Scatter(
            x=date_grid[start_index:],
            y=ci_data[model_type]['upper'][start_index:],
            mode='lines',
            fill='tonexty',
            fillcolor='rgba(0,0,255,0.1)' if category == 'Non-China' else 'rgba(255,0,0,0.1)',
            line=dict(color=colors[category], width=0),
            name=f'{category} 90% CI',
            showlegend=False,
        ))

    # Add slope labels
    for category in ['Non-China', 'China']:
        category_df = df[df['category'] == category]
        points = [category_df['date'].min()] + fit_results[category].break_points_dt + [category_df['date'].max()]
        model_type = model_types[category]
        best_slope = 10**regression_data[category]['models'][model_type].oom_year_slopes[-1]
        slopes = bootstrap_summary_data[category]['bootstrap_slopes'][model_type]
        slope_label = f'{best_slope:.1f}x/year<br>90% CI: {slopes[0]:.1f}-{slopes[1]:.1f}x/year'
        for i in range(len(points) - 2, len(points) - 1):
            mid = points[i] + (points[i+1] - points[i]) / 2
            if category == 'China':
                mid += pd.Timedelta(days=150)
            y = fit_results[category].predict(pd.Series([mid]))[0]
            fig.add_annotation(
                x=mid, y=y + 1.8 * (1 if category == 'Non-China' else -1),
                text=slope_label,
                showarrow=False,
                font=dict(size=12, color=colors[category])
            )

    # Update layout
    title = f'Frontiers of training compute for LLMs inside and outside China'
    fig.update_layout(
        template='plotly_white',
        width=800,
        height=400,
        title=title,
        xaxis_title='Publication date',
        yaxis_title='Training compute (FLOP)',
        legend_title='',
        margin=dict(l=10, r=10, t=40, b=10),
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(int(df['log_flop'].min()), int(df['log_flop'].max())+2, 2)),
            ticktext=[f'10<sup>{i}</sup>' for i in range(int(df['log_flop'].min()), int(df['log_flop'].max())+2, 2)]
        )
    )

    if save:
        fname = f'compute_regression_china_model={model_types["China"]}_non_china_model={model_types["Non-China"]}_model_selection={model_selection}_frontier={frontier_selection}_top{top_n}_cutoff={cutoff_date}'
        save_plot(fig, results_dir, fname)
        
        slope_df = pd.DataFrame({
            'Category': ['Non-China', 'China'],
            'Best fit slope': [10**regression_data[category]['models'][model_types[category]].oom_year_slopes[-1] for category in ['Non-China', 'China']],
            '90% CI lower': [bootstrap_summary_data[category]['bootstrap_slopes'][model_types[category]][0] for category in ['Non-China', 'China']],
            '90% CI upper': [bootstrap_summary_data[category]['bootstrap_slopes'][model_types[category]][1] for category in ['Non-China', 'China']],
        })
        slope_df.to_csv(results_dir + f'plot_data/recent_slopes_{fname}.csv', index=False)

        df_non_china[['Model', 'Organization', 'date', 'log_flop']].to_csv(results_dir + f'plot_data/non_china_scatter_{fname}.csv', index=False)
        df_china[['Model', 'Organization', 'date', 'log_flop']].to_csv(results_dir + f'plot_data/china_scatter_{fname}.csv', index=False)
        trend_dfs['Non-China'][['date', 'log_flop']].to_csv(results_dir + f'plot_data/non_china_best_fit_line_{fname}.csv', index=False)
        trend_dfs['China'][['date', 'log_flop']].to_csv(results_dir + f'plot_data/china_best_fit_line_{fname}.csv', index=False)
        ci_dfs['Non-China'][['date', 'lower', 'upper']].to_csv(results_dir + f'plot_data/non_china_ci_{fname}.csv', index=False)
        ci_dfs['China'][['date', 'lower', 'upper']].to_csv(results_dir + f'plot_data/china_ci_{fname}.csv', index=False)


    for category in ['Non-China', 'China']:
        print(category)
        fit_result = fit_results[category]
        simple_fit = fit_n_phase_exponential(df[df['category'] == category], 0)
        bayes_factor = np.exp(-0.5 * (fit_result.bic - simple_fit.bic))
        unadjusted_bayes_factor = np.exp(-0.5 * (fit_result.bic - (simple_fit.bic + 2*np.log(len(df[df['category'] == category])))))

        print(f"BIC score: {fit_result.bic}")
        bic_score_difference = fit_result.bic - simple_fit.bic
        if bic_score_difference > 0:
            print(f"The simple exponential is preferred over this fit by a BIC score difference of {fit_result.bic - simple_fit.bic}")
        if bic_score_difference < 0:
            print(f"This fit is preferred over a simple exponential by a BIC score difference of {-(fit_result.bic - simple_fit.bic)}")

    fig.show()

    return fit_result

fit_result = plot_model(df_filtered, model_types, model_params)

Non-China
BIC score: 70.92834709845775
China
BIC score: 133.06409979988965
This fit is preferred over a simple exponential by a BIC score difference of 42.039603834285714


# Trends in the cumulative maximum

In [51]:
def get_frontier(df, x_col, y_col):
    # Sort the dataframe by x_col ascending and y_col descending
    df_sorted = df.sort_values([x_col, y_col], ascending=[True, False])
    frontier = []
    max_y = 0.3  # Excludes models below random chance level
    for _, row in df_sorted.iterrows():
        if row[y_col] > max_y:
            frontier.append(row)
            max_y = row[y_col]
    return pd.DataFrame(frontier)

fig = go.Figure()
for category in ['Non-China', 'China']:
    frontier = get_frontier(df_filtered[df_filtered['category'] == category], 'date', 'log_flop')
    fig.add_trace(go.Scatter(
        x=frontier['date'],
        y=frontier['log_flop'],
        mode='lines', 
        text=frontier['Model'],
        name=f'{category} frontier',
        line=dict(color=colors[category], width=1),
    ))
fig.show()