# Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
import json
from itertools import combinations_with_replacement
import numpy as np
import os
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy.stats import chi2
from sklearn.model_selection import KFold
import statsmodels.api as sm
from tqdm import tqdm

from data import *
from plotting import *
from regression import *
from utils import *

# Parameters

In [3]:
assign_country_method = 'inclusive'  # ['inclusive', 'exclusive']. Default: 'inclusive'.
# 'external': Filter to the top n models overall
# 'internal': Filter to the top n models within 'Non-China' and 'China' categories
# 'disabled': No filtering
frontier_selection = 'external'  # ['disabled', 'internal', 'external']. Default: 'external'.
top_n = 10  # Filter to the top n models by training compute at time of release. Default: 10.
model_selection = 'Language models'  # ['All models', 'Language models', 'Google DeepMind models', 'OpenAI models', 'Meta AI models']. Default: 'Language models'.
filter_alphago_outliers = True  # Whether to filter out AlphaGo Master and AlphaGo Zero. Default: True.
filter_finetuned_models = True  # Whether to filter out separate finetuned models (base + finetuned models are still included if there is no separate base model in our dataset). Default: True.
include_speculative_compute = True  # Whether to include speculative compute estimates that rely on benchmark imputation and rough guesses. Default: True.
cutoff_date = '2018-01-01'  # When to start the regressions from. Default: '2018-01-01'.
top_n_cutoff_date = '1950-01-01'  # When to split the top-n filtering into Non-China and China categories - set to e.g. 2010 to turn off the "kickstarting". Default: '1950-01-01'.
save = True  # Whether to save the plots. Default: True.

In [4]:
# Default: no models excluded
exclude_models = []

# Early China models that are not representative of current trends
# exclude_models = [
#     'genCNN + dyn eval',
#     'R-FCN',
#     'ResNet-200',
#     '2-layer-LSTM+Deep-Gradient-Compression',
# ]

# Key China models around the breakpoint
# exclude_models = [
#     'ERNIE 3.0 Titan',
#     'Yuan 1.0',
# ]

# Largest China model
# exclude_models = [
#     'GLM-4 (0116)',
# ]

# All large BlueLMs - it's not clear they were ever released
# exclude_models = [
#     'BlueLM 70B',
#     'BlueLM 130B',
#     'BlueLM 175B',
# ]

In [5]:
results_dir = 'results/compute/03Dec_update_plot_data/'
os.makedirs(results_dir, exist_ok=True)
os.makedirs(results_dir + 'plot_data', exist_ok=True)

In [6]:
colors = {'Non-China': 'blue', 'China': 'red'}


# Data preparation

In [7]:
# Load data
pcd_df = load_pcd_df()

In [8]:
pcd_df

Unnamed: 0,Model,Domain,Task,Authors,Notability criteria,Notability criteria notes,Model accessibility,Link,Citations,Reference,...,Assumed hardware FLOP/s,Hardware type,Compute estimate method,Training compute estimation method,Biological model safeguards,Hardware utilization (temp),BenchmarkHub-v1,Field 90,Post-training compute (FLOP),Post-training compute notes
0,babbage-002,Language,Language modelling,,,,,,,,...,,,,,,,,,,
1,tts-1,Speech,Text-to-speech,,,,,,,,...,,,,,,,,,,
2,tts-1-hd,Speech,Text-to-speech,,,,,,,,...,,,,,,,,,,
3,LM-Design,Biology,Protein design,"Zaixiang Zheng, Yifan Deng, Dongyu Xue, Yi Zho...",,,,https://proceedings.mlr.press/v202/zheng23a.html,46.0,Structure-informed Language Models Are Protein...,...,,,,,LM-Design,,,,,
4,Genie (bio),Biology,,,,,,https://arxiv.org/abs/2301.12485,,"Generating Novel, Designable, and Diverse Prot...",...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2075,π0 (pi-zero),"Robotics,Vision",Robotic manipulation,"Kevin Black, Noah Brown, Danny Driess, Adnan E...",,,Unreleased,https://www.physicalintelligence.company/downl...,,π0: Our First Generalist Policy,...,,,,,,,,,,
2076,Hunyuan-Large,Language,"Language modelling/generation,Question answeri...","Xingwu Sun, Yanfeng Chen, Yiqing Huang, Ruobin...",,,Open weights (restricted use),https://arxiv.org/abs/2411.02265,,Hunyuan-Large: An Open-Source MoE Model with 5...,...,,,,"Operation counting,Other",,,,,,
2077,Qwen2.5-Coder (32B),Language,"Language modelling/generation,Code generation","Binyuan Hui, Jian Yang, Zeyu Cui, Jiaxi Yang, ...",,,Open weights (unrestricted),https://arxiv.org/abs/2409.12186,,Qwen2.5-Coder Technical Report,...,,,,Operation counting,,,,,,
2078,Tulu 3 (Tülu 3) 70B,Language,"Language modelling/generation,Protein question...","Nathan Lambert, Jacob Morrison, Valentina Pyat...",,,Open weights (restricted use),https://allenai.org/papers/tulu-3-report.pdf,,TÜLU 3: Pushing Frontiers in Open Language Mod...,...,,,,,,,,,,


In [9]:
pcd_df = pcd_df[~pcd_df['Model'].isin(exclude_models)]

In [10]:
print(pcd_df.loc[pcd_df['Model'] == 'Megatron-BERT']['Country (from Organization)'])
print(pcd_df.loc[pcd_df['Model'] == 'Yi-34B']['Country (from Organization)'])


802    United States of America
Name: Country (from Organization), dtype: object
1630    China
Name: Country (from Organization), dtype: object


In [11]:
country_df = pcd_df.dropna(subset=['Publication date', 'Country (from Organization)'])
len(country_df)

1921

In [12]:
country_df['Country (from Organization)'].unique()


array(['United States of America',
       'United States of America,United States of America', 'Italy',
       'New Zealand',
       'United Kingdom of Great Britain and Northern Ireland',
       'Switzerland', 'Japan', 'Multinational', 'Netherlands', 'Finland',
       'Canada', 'Japan,United States of America', 'Spain',
       'Denmark,United Kingdom of Great Britain and Northern Ireland',
       'India', 'Germany', 'France',
       'United Kingdom of Great Britain and Northern Ireland,United States of America',
       'Taiwan',
       'United States of America,United States of America,United States of America',
       'United Kingdom of Great Britain and Northern Ireland,Canada',
       'United States of America,Germany', 'Korea (Republic of)',
       'United States of America,United Kingdom of Great Britain and Northern Ireland',
       'Mexico', 'Switzerland,Germany', 'France,Canada',
       'France,United States of America,France', 'Canada,Singapore',
       'Finland,Multinational

In [13]:
country_df[country_df['Country (from Organization)'].str.contains('China')][['Model', 'Country (from Organization)']]

Unnamed: 0,Model,Country (from Organization)
397,AdaRNN,China
401,SPPNet,"United States of America,China,China"
419,Cascaded LNet-ANet,"Hong Kong,China"
430,CRF-RNN,United Kingdom of Great Britain and Northern I...
435,genCNN + dyn eval,"China,China,Ireland"
...,...,...
2067,Janus 1.3B,"China,Hong Kong,China,China"
2068,Yi-Lightning,China
2074,Pro-PRIME,"China,China,China"
2076,Hunyuan-Large,"Multinational,China"


In [14]:
country_df[~country_df['Country (from Organization)'].str.contains('China')]

Unnamed: 0,Model,Domain,Task,Authors,Notability criteria,Notability criteria notes,Model accessibility,Link,Citations,Reference,...,Assumed hardware FLOP/s,Hardware type,Compute estimate method,Training compute estimation method,Biological model safeguards,Hardware utilization (temp),BenchmarkHub-v1,Field 90,Post-training compute (FLOP),Post-training compute notes
94,Theseus,Robotics,Maze solving,Claude Shannon,Historical significance,,,https://www.technologyreview.com/2018/12/19/13...,0.0,Mighty Mouse,...,,,,,,,,,,
95,SNARC,Robotics,Maze solving,Marvin Minsky,Historical significance,,,https://en.wikipedia.org/wiki/Stochastic_neura...,33.0,A Neural-Analogue Calculator Based upon a Prob...,...,,,,,,,,,,
96,Genetic algorithm,Mathematics,Numerical simulation,NA Barricelli,Historical significance,Possibly first computer simulation of a geneti...,,https://link.springer.com/article/10.1007/BF01...,266.0,Numerical testing of evolution theories,...,,,,,,,,,,
97,Sequence-based pattern recognition,Vision,Character recognition,O. G. Selfridge,Historical significance,,,https://dl.acm.org/doi/10.1145/1455292.1455310,290.0,Pattern recognition and modern computers,...,,,,,,,,,,
98,Self Organizing System,Other,Pattern recognition,W. A. Clark and B. G. Farley,Historical significance,,,https://dl.acm.org/doi/10.1145/1455292.1455309,93.0,Generalization of pattern recognition in a sel...,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2072,Aya Expanse 32B,Language,"Language modelling/generation,Translation",,,,Open weights (restricted use),https://cohere.com/blog/aya-expanse-connecting...,,"Cohere For AI launches Aya Expanse, a state-of...",...,,,,Operation counting,,,,,,
2073,Aya Expanse 8B,Language,"Language modelling/generation,Translation",,,,Open weights (restricted use),https://cohere.com/blog/aya-expanse-connecting...,,"Cohere For AI launches Aya Expanse, a state-of...",...,,,,Operation counting,,,,,,
2075,π0 (pi-zero),"Robotics,Vision",Robotic manipulation,"Kevin Black, Noah Brown, Danny Driess, Adnan E...",,,Unreleased,https://www.physicalintelligence.company/downl...,,π0: Our First Generalist Policy,...,,,,,,,,,,
2078,Tulu 3 (Tülu 3) 70B,Language,"Language modelling/generation,Protein question...","Nathan Lambert, Jacob Morrison, Valentina Pyat...",,,Open weights (restricted use),https://allenai.org/papers/tulu-3-report.pdf,,TÜLU 3: Pushing Frontiers in Open Language Mod...,...,,,,,,,,,,


Check if the country is listed.

TODO: try other methods of reducing multiple countries to one country.
- Use the first country listed
- Mutually exclusive (e.g. China but NOT Non-China)

In [15]:
# Including China
china_countries = ['China', 'Hong Kong', 'Taiwan']

def assign_country_inclusively(row):
    if any([country in china_countries for country in row['Country (from Organization)'].split(',')]):
        return 'China'
    else:
        return 'Non-China'

# Exclusively China
def assign_country_exclusively(row):
    countries = row['Country (from Organization)'].split(',')
    # Allow multinational if no non-China country is also listed
    # This applies to Tencent, for example
    if all([(country in china_countries) or (country == 'Multinational') for country in countries]):
        return 'China'
    elif any([country in china_countries for country in countries]):
        # Not exclusively China - discard because it's an ambiguous case
        return np.nan
    else:
        # Exclusively Non-China
        return 'Non-China'

# First country listed
assign_country = assign_country_inclusively if assign_country_method == 'inclusive' else assign_country_exclusively
country_df.loc[:, 'Country'] = country_df.apply(assign_country, axis=1)

display(country_df[country_df['Country'] == 'China'][['Model', 'Country']])
display(country_df[country_df['Country'] == 'Non-China'][['Model', 'Country']])
display(country_df[country_df['Country'].isna()][['Model', 'Country']])


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  country_df.loc[:, 'Country'] = country_df.apply(assign_country, axis=1)


Unnamed: 0,Model,Country
195,RNN for speech,China
397,AdaRNN,China
401,SPPNet,China
419,Cascaded LNet-ANet,China
430,CRF-RNN,China
...,...,...
2067,Janus 1.3B,China
2068,Yi-Lightning,China
2074,Pro-PRIME,China
2076,Hunyuan-Large,China


Unnamed: 0,Model,Country
94,Theseus,Non-China
95,SNARC,Non-China
96,Genetic algorithm,Non-China
97,Sequence-based pattern recognition,Non-China
98,Self Organizing System,Non-China
...,...,...
2072,Aya Expanse 32B,Non-China
2073,Aya Expanse 8B,Non-China
2075,π0 (pi-zero),Non-China
2078,Tulu 3 (Tülu 3) 70B,Non-China


Unnamed: 0,Model,Country


In [16]:
for cat in country_df['Country'].unique():
    if pd.isna(cat):
        print(cat, len(country_df.loc[country_df['Country'].isna()]))
    else:
        print(cat, len(country_df.loc[country_df['Country'] == cat]))

Non-China 1559
China 362


In [17]:
df = country_df

In [18]:
def find_top_models_up_to_release(df, top_n):
    """Find the models which were in the top n by compute when they were released."""
    # This set will keep track of models that were ever in the top 10 at their release
    ever_in_top_n = set()

    # Iterate over each date in the DataFrame
    for current_date in df['date'].unique():
        # Get all entries up to the current date
        historical_data = df[df['date'] <= current_date]
        # Find top 10 models by flop count in this subset
        top_n_models = historical_data.nlargest(top_n, 'flop')['Model']
        # Update the set of models that were ever in top n
        ever_in_top_n.update(top_n_models)

    # Return DataFrame filtered to only include models that were ever in the top 10
    return df[df['Model'].isin(ever_in_top_n)]


def filter_top_models_within_category(df, top_n, cutoff_date, category):
    """Find the models which were in the top-n by compute when they were released,
    among models in the specified category. The top-n models in the specified category
    are seeded with the overall top-n models before the cutoff date.
    """
    # Filter top-n models within the category, but seeded with overall top-n models
    top_models_df = find_top_models_up_to_release(df, top_n)
    top_n_models_at_cutoff_date_df = top_models_df[top_models_df['date'] <= cutoff_date].nlargest(top_n, 'flop')
    category_df = df[df['category'] == category]

    # This set will keep track of models that were ever in the top 10 at their release
    ever_in_top_n = set()

    # Iterate over each date in the DataFrame
    for current_date in category_df['date'].unique():
        # Get all entries up to the current date
        category_since_cutoff = category_df[(category_df['date'] <= current_date) & (category_df['date'] > cutoff_date)]
        historical_data = pd.concat([category_since_cutoff, top_n_models_at_cutoff_date_df])
        # Find top 10 models by flop count in this subset
        top_n_models_df = historical_data.nlargest(top_n, 'flop')
        # Update the set of models that were ever in top n
        # Filter out the models that aren't in the category
        ever_in_top_n.update(top_n_models_df[top_n_models_df['category'].str.contains(category)]['Model'])

    # Return DataFrame filtered to only include models that were ever in the top 10
    new_df = df[df['Model'].isin(ever_in_top_n)]
    # Assign the category to the new DataFrame (overwrites cases with both US and China)
    # E.g. if a "USA,China" model is top-10 among models affiliated with China, then it's just "China"
    new_df['category'] = category
    
    return new_df


def filter_top_models_in_both_categories(df, top_n, cutoff_date):
    # Get top models for Open and Closed categories
    top_us_models = filter_top_models_within_category(df, top_n, cutoff_date, category='Non-China')
    top_china_models = filter_top_models_within_category(df, top_n, cutoff_date, category='China')
    # Combine the results
    df_filtered = pd.concat([top_us_models, top_china_models])
    # Sort the combined DataFrame by date
    df_filtered = df_filtered.sort_values('date')
    return df_filtered

In [19]:
df_filtered = (df[['Model', 'Training compute (FLOP)', 'Publication date', 'Organization', 'Notability criteria', 'Domain', 'Base model', 'Country', 'Training hardware']]
    .rename(columns={'Training compute (FLOP)': 'flop', 'Publication date': 'date', 'Country': 'category'})
    .assign(date=lambda x: pd.to_datetime(x['date']), log_flop=lambda x: np.log10(x['flop']))
    .sort_values('date'))

In [20]:
list(df_filtered[df_filtered['Base model'].notna()]['Model'])

['BatchNorm',
 'Layer Normalization: Handwriting sequence generation',
 'Layer Normalization: Draw',
 'Order embeddings with layer norm',
 'Layer Normalization: The Attentive Reader',
 'Layer Normalization: Skip Thoughts',
 'ULM-FiT',
 'ADP-FAIRSEQ + NGRAMRES',
 'Fine-tuned-AWD-LSTM-DOC (fin)',
 'Cross-lingual alignment',
 'Theseus 6/768',
 'UnifiedQA',
 'LUKE',
 'GPT-Neo-2.7B (finetuned on PTB)',
 'GPT-Neo-2.7B (finetuned)',
 'Unicorn',
 'Multitask Unified Model (MUM)',
 '$\\infty$-former (SM)',
 'FLAN 137B',
 'AlphaFold-Multimer',
 'T0-XXL',
 'GPT-2 (AMPS)',
 'Masked Autoencoders ViT-H',
 'ViT-G/14 (LiT)',
 'Engine-XL(NE)',
 'HSO',
 'Contriever',
 'Vespa',
 'OntoProtein',
 'InstructGPT',
 'BERT-RBP',
 'Flamingo',
 'Jurassic-X',
 'DeBERTaV3large + KEAR',
 'SimCSE',
 'CogVideo',
 'Minerva (540B)',
 'Delphi',
 'Transformer-XL + RMT',
 'GPT-NeoX-Japanese',
 'BlenderBot 3',
 'PaLM-SayCan',
 'Sparrow',
 'NMST+GPT-2',
 'Decaying Fast Weights Transformer (WT-103)',
 "Instruct-GPT + Mind's Ey

In [21]:
# Add speculative compute estimates based on benchmark imputation and rough guesses
if include_speculative_compute:
    speculative_compute_estimates = {
        "Claude 3.5 Sonnet": 4.72e25,
        "Claude 3 Opus": 1.59e25,
        "Claude 3 Sonnet": 5.51e24,
        "GPT-4o": 3.98e25,
        "Gemini 1.0 Pro": 1.85e24,
        "Gemini 1.5 Pro": 1.60e25,
        "Mistral Large 2": 2.01e25,
        "GPT-4 Turbo": 2.1e25,  # rough guess matching GPT-4
        "GPT-4V": 2.1e25,  # rough guess matching GPT-4
        "Claude 2": 4.33e24,
        "Claude 2.1": 4.33e24,  # rough guess matching Claude 2
    }
    for model, compute in speculative_compute_estimates.items():
        df_filtered.loc[df_filtered['Model'] == model, "flop"] = compute
        df_filtered.loc[df_filtered['Model'] == model, "log_flop"] = np.log10(compute)

df_filtered.dropna(subset=['flop'], inplace=True)

# Drop Alpha Go Master / Zero
if filter_alphago_outliers:
    mask = (df_filtered['Model'] == 'AlphaGo Master') | (df_filtered['Model'] == 'AlphaGo Zero')
    df_filtered = df_filtered[~mask]

# Drop finetuned models
if filter_finetuned_models:
    mask = df_filtered['Base model'].isna()
    df_filtered = df_filtered[mask]

top_models_df = find_top_models_up_to_release(df_filtered, top_n)  # For reference

if frontier_selection == 'external':
    # Filter top models before other filters
    df_filtered = filter_top_models_in_both_categories(df_filtered, top_n, top_n_cutoff_date)

if model_selection == 'Language models':
    re = 'Language|Multimodal'
    mask = df_filtered['Domain'].str.contains(re, na=False)
    df_filtered = df_filtered[mask]

if frontier_selection == 'internal':
    # Filter top models after other filters
    df_filtered = filter_top_models_in_both_categories(df_filtered, top_n, top_n_cutoff_date)

# Filter for models after the cutoff date
df_filtered = df_filtered[df_filtered['date'] > cutoff_date]

print(f"{len(df_filtered)}{' top' if frontier_selection != 'disabled' else ''} {top_n} {model_selection} models found")
print(f"They span {df_filtered['date'].min().strftime('%B %Y')} to {df_filtered['date'].max().strftime('%B %Y')}")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  new_df['category'] = category


114 top 10 Language models models found
They span August 2018 to November 2024


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  new_df['category'] = category


In [22]:
if top_n == 1:
    # Remove BIDAF outlier
    df_filtered = df_filtered[df_filtered['Model'] != 'BIDAF']

In [23]:
non_china_df = df_filtered[df_filtered['category'] == 'Non-China']
china_df = df_filtered[df_filtered['category'] == 'China']
recent_top_models_df = top_models_df[top_models_df['date'] > pd.to_datetime('2010-01-01')]

fig = go.Figure()

fig.add_trace(go.Scatter(
    x=non_china_df['date'],
    y=non_china_df['log_flop'],
    mode='markers',
    marker=dict(color=colors['Non-China'], opacity=0.5),
    text=non_china_df['Model'],
    hoverinfo='text',
    name=f'Top-{top_n} Non-China'
))

fig.add_trace(go.Scatter(
    x=china_df['date'],
    y=china_df['log_flop'],
    mode='markers',
    marker=dict(color=colors['China'], opacity=0.5),
    text=china_df['Model'],
    hoverinfo='text',
    name=f'Top-{top_n} China'
))

fig.add_trace(go.Scatter(
    x=recent_top_models_df['date'],
    y=recent_top_models_df['log_flop'],
    mode='markers',
    marker=dict(color='grey', opacity=0.5),
    text=recent_top_models_df['Model'],
    hoverinfo='text',
    name=f'Top-{top_n} Overall'
))

fig.update_layout(
    width=800,
    height=400,
    xaxis_title='Date',
    yaxis_title='Log FLOP',
    title=f'Top-{top_n} models',
    margin=dict(t=50, l=60, r=60, b=50),
)

save_plot(fig, results_dir, f'top_{top_n}_models_without_kickstarting')

fig.show()

In [24]:
top_models_since_cutoff = top_models_df[top_models_df['date'] >= pd.to_datetime(cutoff_date)]
top_models_set = set(top_models_since_cutoff['Model'])
non_china_top_models_set = set(non_china_df['Model'])
china_top_models_set = set(china_df['Model'])

frac_non_china_top_models = len(non_china_top_models_set.intersection(top_models_set)) / len(top_models_set)
frac_china_top_models = len(china_top_models_set.intersection(top_models_set)) / len(top_models_set)
print(f"Fraction of overall top-{top_n} models that are Non-China: {frac_non_china_top_models*100:.1f}%")
print(f"Fraction of overall top-{top_n} models that are China: {frac_china_top_models*100:.1f}%")


Fraction of overall top-10 models that are Non-China: 71.1%
Fraction of overall top-10 models that are China: 7.9%


# Training hardware analysis

In [25]:
# Models with training hardware
china_num_hardware = china_df[china_df['Training hardware'].notna()].shape[0]
print(f"Number of top-{top_n} {model_selection} China models with training hardware: {china_num_hardware} of {china_df.shape[0]}")
# Print models grouped by training hardware
china_df.groupby('Training hardware')['Model'].apply(lambda x: ', '.join(x))

Number of top-10 Language models China models with training hardware: 17 of 56


Training hardware
Huawei Ascend 910                                                     PanGu-α, CodeGeeX, PanGu-Σ
Huawei Ascend 910,NVIDIA Tesla V100 DGXS 32 GB                                   ERNIE 3.0 Titan
NVIDIA A100                                       PLUG, MegaScale (530B), MegaScale (Production)
NVIDIA A100 SXM4 40 GB                                                                  GLM-130B
NVIDIA A100 SXM4 80 GB                                                       JIANG, CodeFuse-13B
NVIDIA A800                                                                FLM-101B, Skywork-13B
NVIDIA H100 SXM5 80GB                                                               Yi-Lightning
NVIDIA H800                                                                          DeepSeek-V2
NVIDIA Tesla V100 DGXS 32 GB                                                                M6-T
NVIDIA V100                                                                 CPM-Large, ERNIE 3.0
Name: Model,

In [26]:
export_controls_cutoff_date = pd.to_datetime('2022-10-07')  # actual date of controls is 2022-10-07
pre_export_controls = china_df[china_df['date'] < export_controls_cutoff_date]
post_export_controls = china_df[china_df['date'] >= export_controls_cutoff_date]
pre_export_controls.groupby('Training hardware')['Model'].apply(lambda x: ', '.join(x))

Training hardware
Huawei Ascend 910                                    PanGu-α, CodeGeeX
Huawei Ascend 910,NVIDIA Tesla V100 DGXS 32 GB         ERNIE 3.0 Titan
NVIDIA A100                                                       PLUG
NVIDIA A100 SXM4 40 GB                                        GLM-130B
NVIDIA Tesla V100 DGXS 32 GB                                      M6-T
NVIDIA V100                                       CPM-Large, ERNIE 3.0
Name: Model, dtype: object

In [27]:
post_export_controls.groupby('Training hardware')['Model'].apply(lambda x: ', '.join(x))


Training hardware
Huawei Ascend 910                                          PanGu-Σ
NVIDIA A100               MegaScale (530B), MegaScale (Production)
NVIDIA A100 SXM4 80 GB                         JIANG, CodeFuse-13B
NVIDIA A800                                  FLM-101B, Skywork-13B
NVIDIA H100 SXM5 80GB                                 Yi-Lightning
NVIDIA H800                                            DeepSeek-V2
Name: Model, dtype: object

In [28]:
non_china_num_hardware = non_china_df[non_china_df['Training hardware'].notna()].shape[0]
print(f"Number of top-{top_n} {model_selection} Non-China models with training hardware: {non_china_num_hardware} of {non_china_df.shape[0]}")
# Print models grouped by training hardware
non_china_df.groupby('Training hardware')['Model'].apply(lambda x: ', '.join(x))


Number of top-10 Language models Non-China models with training hardware: 39 of 58


Training hardware
Google TPU v3                        XLNet, T5-11B, Meena, GShard (dense), Switch, ...
Google TPU v4                        GLaM, AlphaCode, PaLM (540B), PaLM 2, Gemini 1...
Google TPU v4,Google TPU v3                                                 Chinchilla
NVIDIA A100                                    HyperCLOVA 82B, LLaMA-65B, Amazon Titan
NVIDIA A100 SXM4 40 GB               GPT-NeoX-20B, GPT-3.5 (text-davinci-003), GPT-...
NVIDIA A100 SXM4 80 GB               Megatron-Turing NLG 530B, OPT-175B, BLOOM-176B...
NVIDIA A100,NVIDIA H100 SXM5 80GB                                            Reka Core
NVIDIA H100 SXM5 80GB                Inflection-1, Inflection-2, Mistral Large, Inf...
NVIDIA Tesla V100 DGXS 16 GB                      Big Transformer for Back-Translation
NVIDIA Tesla V100 DGXS 32 GB         RoBERTa Large, Megatron-LM (8.3B), XLM-RoBERTa...
NVIDIA Tesla V100S PCIe 32 GB                                            Megatron-BERT
Name: Model, dtype: objec

# Regression analysis

In [29]:
dep_var = 'log_flop'

In [30]:
#@markdown Analysis of best fit to the data

@dataclass
class FitResult:
    df: pd.DataFrame
    p: int = None
    bic: float = None
    rss: float = None
    mse: float = None
    predict: Callable = None

@dataclass
class HyperbolicFitResult(FitResult):
    params: tuple[float] = None

@dataclass
class KinkedFitResult(FitResult):
    break_points: tuple[float] = None
    break_points_dt: float = None
    oom_year_slopes: tuple[float] = None

    # Model properties for each breakpoint combination
    # (for debugging)
    bics: tuple[float] = None
    rsss: tuple[float] = None
    mses: tuple[float] = None
    break_points_list: tuple[tuple[float]] = None
    break_points_dt_list: tuple[tuple[float]] = None

def fit_hyperbolic(df):
    def hyperbolic_model(t, A, B, k):
        return A / (1 + B * np.exp(-k * t))

    # Prepare data for curve fitting
    timestamp = pd.to_datetime(df['date']).apply(lambda date: date.toordinal()).values

    # Initial guess for the parameters
    # initial_guess = [0, 0, 0]
    initial_guess = [1.72373207e-02, -9.45447534e-01, -7.50101861e-08]  # Updated initial guess

    # Fit the model to the data
    try:
      params, covariance = curve_fit(hyperbolic_model, timestamp, df[dep_var], p0=initial_guess, maxfev=100000, ftol=1e-10)
    except RuntimeError as e:
      print("FATAL ERROR WHEN FITTING HYPERBOLIC")
      return None

    # Extracting parameters
    A, B, k = params

    # Compute predictions to calculate residuals
    predicted_log_y = hyperbolic_model(timestamp, *params)

    # Compute the Residual Sum of Squares (RSS)
    rss = np.sum((df[dep_var] - predicted_log_y) ** 2)

    # Number of observations (n)
    n = len(df[dep_var])

    # Number of parameters (p)
    p = len(params) + 1

    # Calculate log-likelihood under the assumption of normally distributed errors
    # log_likelihood = -0.5 * rss
    log_likelihood = -0.5 * n * (np.log(2 * np.pi * rss/n) + 1)

    # Compute bic_hyperbolic using the provided formula
    bic = p * np.log(n) - 2 * log_likelihood

    # Compute MSE
    mse = rss / n

    fit_result = HyperbolicFitResult(
        df=df,
        p=p,
        bic=bic,
        rss=rss,
        mse=mse,
        params=params,
        predict=lambda date: hyperbolic_model(date.apply(lambda d: d.toordinal()), *params)
    )

    return fit_result

def fit_n_phase_exponential(df, kink_count=0, allow_discontinuities=False, min_n_segment=10):
    # Generate monthly breakpoints between 2010 and 2024
    one_month = pd.DateOffset(months=1)
    break_point_grid = pd.date_range(start=df['date'].min() - one_month, end=df['date'].max() - 4*one_month, freq='MS')
    break_point_grid = [x.toordinal() for x in break_point_grid]

    x = pd.to_datetime(df['date']).apply(lambda date: date.toordinal()).values
    y = df[dep_var].values

    break_points_list = []
    bics = []
    rsss = []
    mses = []
    models = []

    for break_points in combinations_with_replacement(break_point_grid, kink_count):
        # Model predictors

        intercept_change_points = (0,)
        if allow_discontinuities:
            intercept_change_points += break_points
        slope_change_points = (0,) + break_points

        predictors = np.zeros((len(x), len(intercept_change_points) + len(slope_change_points)))

        for i, intercept_point in enumerate(intercept_change_points):
            predictors[:, i] = (x >= intercept_point).astype(int)

        for i, break_point in enumerate(slope_change_points):
            predictors[:, len(intercept_change_points) + i] = np.maximum(x - break_point, 0)

        # Fit the model
        model = sm.OLS(y, predictors).fit()

        # Calculate BIC manually based on log-likelihood
        n = len(x) # Number of observations
        p = len(model.params) + 2*kink_count + 1 # Number of parameters

        # Calculate log-likelihood under the assumption of normally distributed errors
        # We have to iterate over all points to get their individual log-likelihoods
        log_likelihood = 0
        rss = 0
        invalid_model = False # Discard models with segments with less than 2 points
        for i, break_point in enumerate(slope_change_points):
            left_x = break_point
            right_x = slope_change_points[i + 1] if i + 1 < len(slope_change_points) else np.inf

            segment_predictors = predictors[(left_x <= x) & (x < right_x), :]
            segment_y = y[(left_x <= x) & (x < right_x)]
            segment_n = len(segment_y)

            assert min_n_segment > 2

            if segment_n < min_n_segment:
                invalid_model = True
                break

            y_pred = model.predict(segment_predictors)

            segment_rss = np.sum((y_pred - segment_y)**2)
            if segment_rss == 0:
                print(f"segment_rss={segment_rss}")
                print(f"y_pred={y_pred}")
                print(f"segment_y={segment_y}")
                invalid_model = True
                break
            segment_mse = segment_rss / segment_n

            segment_log_likelihood = -segment_n/2 * (np.log(2*np.pi) + np.log(segment_rss/segment_n) + 1)
            log_likelihood += segment_log_likelihood
            rss += segment_rss

        if invalid_model:
            continue

        # Compute BIC using the manual method based on the log-likelihood
        bic = p * np.log(n) - 2 * log_likelihood
        # bic = n*np.log(rss/n) + p*np.log(n)

        bics.append(bic)
        rsss.append(rss)
        mses.append(rss/len(df))
        models.append(model)
        break_points_list.append(break_points)

    # Prepare the result object
    best_bic = min(bics)
    best_idx = bics.index(best_bic)
    best_rss = rsss[best_idx]
    best_mse = mses[best_idx]
    best_model = models[best_idx]
    best_break_points = break_points_list[best_idx]

    p = len(best_model.params) + 2*kink_count + 1 # Number of parameters

    intercept_change_points = (0,)
    if allow_discontinuities:
        intercept_change_points += best_break_points
    slope_change_points = (0,) + best_break_points

    intercepts = best_model.params[:len(intercept_change_points)]
    oom_year_slopes = 365 * np.cumsum(best_model.params[len(intercepts):])

    def predict(date):
        if not isinstance(date, pd.Series):
            date = pd.Series(date)
        x = pd.to_datetime(date).apply(lambda date: date.toordinal()).values

        predictors = np.zeros((len(x), len(intercept_change_points) + len(slope_change_points)))

        for i, intercept_point in enumerate(intercept_change_points):
            predictors[:, i] = (x >= intercept_point).astype(int)

        for i, break_point in enumerate(slope_change_points):
            predictors[:, len(intercept_change_points) + i] = np.maximum(x - break_point, 0)

        return best_model.predict(predictors)

    fit_result = KinkedFitResult(
        df=df,
        p=p,
        bic=best_bic,
        rss=best_rss,
        mse=best_mse,
        break_points=best_break_points,
        predict=predict,
        break_points_dt=[pd.Timestamp.fromordinal(bp) for bp in best_break_points],
        bics=bics,
        rsss=rsss,
        mses=mses,
        oom_year_slopes=oom_year_slopes,
        break_points_list=break_points_list,
        break_points_dt_list=[[pd.Timestamp.fromordinal(bp) for bp in break_points] for break_points in break_points_list],
    )

    return fit_result

def calculate_lag(df, fit_results, date=None):
    if date is None:
        date = df['date'].max()

    # Get the predictions for the two categories
    y_non_china = fit_results['Non-China'].predict(pd.Series([date]))[0]
    y_china = fit_results['China'].predict(pd.Series([date]))[0]
    
    # Get the final slope for the 'China' category
    slope_non_china = fit_results['Non-China'].oom_year_slopes[-1]
    
    # Calculate lag
    lag = (y_non_china - y_china) / slope_non_china
    
    return lag


## Model selection

In [31]:
fit_em_all = lambda df_fit : {
    "Simple" : fit_n_phase_exponential(df_fit, kink_count=0),
    "One kink" : fit_n_phase_exponential(df_fit, kink_count=1),
    "Discontinuity" : fit_n_phase_exponential(df_fit, kink_count=1, allow_discontinuities=True),
    # "Hyperbolic": fit_hyperbolic(df_fit),
}

# Best model fits
print(f"Fitting China and Non-China models")
regression_data = {
    'China': {},
    'Non-China': {},
}
regression_data['China']['models'] = fit_em_all(df_filtered[df_filtered['category'] == 'China'])
regression_data['Non-China']['models'] = fit_em_all(df_filtered[df_filtered['category'] == 'Non-China'])


Fitting China and Non-China models


In [32]:
# K-Fold Cross Validation
def perform_cross_validation(df, k=10, random_state=42):
    kf = KFold(n_splits=k, shuffle=True, random_state=random_state)
    folds_mses = defaultdict(lambda : [])
    for train_index, test_index in kf.split(df):
        train_df, test_df = df.iloc[train_index], df.iloc[test_index]

        # Fit the models on the training set
        fold_models = fit_em_all(train_df)

        # Predict on the test set
        for name,model in fold_models.items():
            try:
                predicted_log_y = model.predict(test_df["date"])
            except AttributeError:
                continue
            test_rss = np.sum((predicted_log_y - test_df[dep_var])**2)
            test_mse = test_rss / len(test_df)
            folds_mses[name].append(test_mse)

    # Compute mean MSE
    folds_mses = {name: np.mean(folds_mses[name]) for name in folds_mses}

    return folds_mses

regression_data['China']['folds_mses'] = perform_cross_validation(df_filtered[df_filtered['category'] == 'China'])
regression_data['Non-China']['folds_mses'] = perform_cross_validation(df_filtered[df_filtered['category'] == 'Non-China'])

In [33]:
# Bootstrap
bootstrap_sample_size = 1000

pred_start_date = df_filtered['date'].min()
pred_end_date = df_filtered['date'].max()

regression_data['China']['bootstrap_predictions'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_predictions'] = defaultdict(lambda : [])

regression_data['China']['bootstrap_bics'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_bics'] = defaultdict(lambda : [])

regression_data['China']['bootstrap_mses'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_mses'] = defaultdict(lambda : [])

regression_data['China']['bootstrap_bic_score_diff'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_bic_score_diff'] = defaultdict(lambda : [])

regression_data['China']['bootstrap_slopes'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_slopes'] = defaultdict(lambda : [])

regression_data['China']['bootstrap_breaks'] = defaultdict(lambda : [])
regression_data['Non-China']['bootstrap_breaks'] = defaultdict(lambda : [])

for category in ['China', 'Non-China']:
    print(f"Bootstrapping {category} data")
    for bootstrap_index in tqdm(range(bootstrap_sample_size)):
        if bootstrap_index == 0:
            # Use the original data as the first bootstrap sample
            sample = df_filtered.copy()
        else:
            sample = df_filtered.sample(len(df_filtered), replace=True, random_state=DEFAULT_RNG)
        sample = sample[sample['category'] == category]
        sample = sample.sort_values('date')

        # Compute BICs
        boot_models = fit_em_all(sample)

        # Compute K fold validation
        boot_folds_mses = perform_cross_validation(sample)

        # Store results
        for name, model in boot_models.items():
            # It might be None if the hyperbolic fails to fit
            if model is None: continue

            regression_data[category]['bootstrap_bics'][name].append(model.bic)
            regression_data[category]['bootstrap_mses'][name].append(boot_folds_mses[name])
            regression_data[category]['bootstrap_bic_score_diff'][name].append(model.bic - boot_models["Simple"].bic)

            if isinstance(model, KinkedFitResult):
                if (len(model.oom_year_slopes) > 0): regression_data[category]['bootstrap_slopes'][name].append(10**model.oom_year_slopes[-1])
                if (len(model.break_points_dt) > 0): regression_data[category]['bootstrap_breaks'][name].append(model.break_points_dt[-1])

        # Store predictions for confidence intervals
        predictions = {}
        for name, model in boot_models.items():
            try:
                date_grid = pd.date_range(start=pred_start_date, end=pred_end_date, freq='MS')
                pred = model.predict(pd.Series(date_grid))
                regression_data[category]['bootstrap_predictions'][name].append(pred)
            except AttributeError:
                continue

Bootstrapping China data


100%|██████████| 1000/1000 [02:10<00:00,  7.64it/s]


Bootstrapping Non-China data


100%|██████████| 1000/1000 [02:21<00:00,  7.07it/s]


In [34]:
ci_width = 0.90
qs = [(1 - ci_width)/2, (1 + ci_width)/2]
bootstrap_preferred_percent = {}
bootstrap_summary_data = {
    'China': defaultdict(lambda: {}),
    'Non-China': defaultdict(lambda: {}),
}
for category in ['China', 'Non-China']:
    for name in regression_data[category]['models']:
        bootstrap_summary_data[category]['bootstrap_preferred_percent'][name] = np.mean(np.array(regression_data[category]['bootstrap_bic_score_diff'][name])<0)
        bootstrap_summary_data[category]['bootstrap_bics'][name] = np.quantile(np.array(regression_data[category]['bootstrap_bics'][name]), qs)
        bootstrap_summary_data[category]['bootstrap_mses'][name] = np.quantile(np.array(regression_data[category]['bootstrap_mses'][name]), qs)
        bootstrap_summary_data[category]['bootstrap_bic_score_diff'][name] = np.quantile(np.array(regression_data[category]['bootstrap_bic_score_diff'][name]), qs)
        try:
            bootstrap_summary_data[category]['bootstrap_slopes'][name] = np.quantile(np.array(regression_data[category]['bootstrap_slopes'][name]), qs)
            bootstrap_summary_data[category]['bootstrap_breaks'][name] = np.quantile(np.array(regression_data[category]['bootstrap_breaks'][name]), qs)
        except IndexError:
            pass

# Models with lower BIC score / MSE are preferred.

results = {
    'China': [],
    'Non-China': [],
}

for category in ['China', 'Non-China']:
    for name, model in regression_data[category]['models'].items():
        param_count = model.p
        log_likelihood = (np.log(len(df_filtered))*param_count - model.bic)/2

        param_count_simple = regression_data[category]['models']['Simple'].p
        log_likelihood_simple = (np.log(len(df_filtered))*param_count_simple - regression_data[category]['models']['Simple'].bic)/2

        c2 = chi2.sf(2*(log_likelihood - log_likelihood_simple), df=(param_count - param_count_simple))

        result = {
            "Model": name,
            "BIC" : np.round(model.bic, 2),
            "BIC 90% CI" : np.round(bootstrap_summary_data[category]['bootstrap_bics'][name], 2),
            #"Parameter count": param_count,
            #"Log likelihood": np.round((np.log(len(df_filtered))*param_count - model.bic)/2),
            # "MSE" : model.mse,
            "BIC score diff": np.round(model.bic - regression_data[category]['models']['Simple'].bic, 2),
            "BIC score diff 90% CI": np.round(bootstrap_summary_data[category]['bootstrap_bic_score_diff'][name], 2),
            "Xi²": c2,
            "% times preferred over simple": f"{bootstrap_summary_data[category]['bootstrap_preferred_percent'][name]:.0%}",
            # "bayes factor over simple" : np.exp(-0.5 * (model.bic - models["simple"].bic)),
            "K-fold mean MSE" : np.round(regression_data[category]['folds_mses'][name], 2),
            "K-fold mean MSE 90% CI" : np.round(bootstrap_summary_data[category]['bootstrap_mses'][name], 2),
        }

        try:
            result["Recent slope (Nx/year)"] = np.round(10**model.oom_year_slopes[-1], 2)
            result["Recent slope 90% CI"] = np.round(bootstrap_summary_data[category]['bootstrap_slopes'][name], 2)
            result["Break point"] = model.break_points_dt[-1].strftime('%Y-%m')
            result["Break point 90% CI"] = [date.strftime('%Y-%m') for date in bootstrap_summary_data[category]['bootstrap_breaks'][name]]
        except (AttributeError, IndexError):
            pass
        results[category].append(result)

results = {category: pd.DataFrame(results[category]) for category in ['China', 'Non-China']}

lag_results = []
for name, model in regression_data['Non-China']['models'].items():
    for name_china, model_china in regression_data['China']['models'].items():
        lags = []
        for i in range(bootstrap_sample_size):
            last_pred_china = regression_data['China']['bootstrap_predictions'][name_china][i][-1]
            last_pred_non_china = regression_data['Non-China']['bootstrap_predictions'][name][i][-1]
            lag = (last_pred_non_china - last_pred_china) / model.oom_year_slopes[-1]
            lags.append(lag)
        lag_result = {
            "Non-China model": name,
            "China model": name_china,
            "Point estimate": np.round(lags[0], 2),  # First bootstrap sample is the original data
            "Mean": np.round(np.mean(lags), 2),
            "90% CI": np.round(np.quantile(lags, qs), 2),
        }
        lag_results.append(lag_result)
lag_results = pd.DataFrame(lag_results)
# bayes_factor = np.exp(-0.5 * (kinked_fit.bic - simple_fit.bic))

print("Bootstrapped regression results")
for category in ['China', 'Non-China']:
    print(category)
    display(results[category])
display(lag_results)

Bootstrapped regression results
China


Unnamed: 0,Model,BIC,BIC 90% CI,BIC score diff,BIC score diff 90% CI,Xi²,% times preferred over simple,K-fold mean MSE,K-fold mean MSE 90% CI,Recent slope (Nx/year),Recent slope 90% CI,Break point,Break point 90% CI
0,Simple,179.66,"[142.74, 208.87]",0.0,"[0.0, 0.0]",,0%,1.33,"[0.81, 1.64]",17.45,"[11.77, 24.48]",,
1,One kink,133.17,"[93.65, 154.3]",-46.5,"[-78.7, -26.06]",4.157691e-13,100%,0.69,"[0.33, 0.93]",2.38,"[1.76, 5.34]",2021-10,"[2021-05, 2022-01]"
2,Discontinuity,130.29,"[89.72, 149.02]",-49.37,"[-84.77, -30.76]",5.141604e-14,100%,0.77,"[0.31, 0.92]",5.35,"[2.71, 15.69]",2022-06,"[2021-02, 2023-01]"


Non-China


Unnamed: 0,Model,BIC,BIC 90% CI,BIC score diff,BIC score diff 90% CI,Xi²,% times preferred over simple,K-fold mean MSE,K-fold mean MSE 90% CI,Recent slope (Nx/year),Recent slope 90% CI,Break point,Break point 90% CI
0,Simple,72.46,"[49.48, 89.48]",0.0,"[0.0, 0.0]",,0%,0.18,"[0.12, 0.22]",4.75,"[4.19, 5.28]",,
1,One kink,76.89,"[40.5, 87.22]",4.43,"[-24.38, 7.39]",0.020549,56%,0.19,"[0.12, 0.22]",4.2,"[3.61, 30.8]",2020-02,"[2019-10, 2024-02]"
2,Discontinuity,68.28,"[32.11, 83.29]",-4.18,"[-34.37, 6.0]",0.00012,74%,0.18,"[0.11, 0.22]",4.65,"[2.6, 24.66]",2020-02,"[2019-11, 2024-02]"


Unnamed: 0,Non-China model,China model,Point estimate,Mean,90% CI
0,Simple,Simple,0.64,0.63,"[0.16, 1.04]"
1,Simple,One kink,1.94,1.8,"[1.39, 2.24]"
2,Simple,Discontinuity,1.54,1.58,"[0.97, 2.02]"
3,One kink,Simple,0.56,0.82,"[0.14, 1.58]"
4,One kink,One kink,1.97,2.09,"[1.46, 2.86]"
5,One kink,Discontinuity,1.53,1.85,"[1.04, 2.69]"
6,Discontinuity,Simple,0.61,0.72,"[0.15, 1.34]"
7,Discontinuity,One kink,1.93,1.91,"[1.37, 2.56]"
8,Discontinuity,Discontinuity,1.53,1.68,"[0.97, 2.37]"


In [35]:
# Save results_df
for category in ['China', 'Non-China']:
    regression_fname = f'compute_regression_analysis_{category}_{model_selection}_frontier={frontier_selection}_top_n={top_n}_cutoff={cutoff_date}.csv'
    results[category].to_csv(os.path.join(results_dir, regression_fname), index=False)
lag_results.to_csv(os.path.join(results_dir, f'lags_{model_selection}_frontier={frontier_selection}_top_n={top_n}_cutoff={cutoff_date}.csv'), index=False)


# Plot predictions with bootstrapped CIs

In [36]:
regression_data['China']['bootstrap_predictions']

defaultdict(<function __main__.<lambda>()>,
            {'Simple': [array([17.44316028, 17.54522786, 17.6506977 , 17.75276529, 17.85823512,
                     17.96370496, 18.05896804, 18.16443788, 18.26650546, 18.3719753 ,
                     18.47404288, 18.57951272, 18.68498256, 18.78705014, 18.89251998,
                     18.99458757, 19.1000574 , 19.20552724, 19.30419257, 19.40966241,
                     19.51172999, 19.61719983, 19.71926742, 19.82473725, 19.93020709,
                     20.03227468, 20.13774451, 20.2398121 , 20.34528194, 20.45075177,
                     20.54601485, 20.65148469, 20.75355228, 20.85902211, 20.9610897 ,
                     21.06655954, 21.17202937, 21.27409696, 21.3795668 , 21.48163438,
                     21.58710422, 21.69257406, 21.78783713, 21.89330697, 21.99537456,
                     22.10084439, 22.20291198, 22.30838182, 22.41385165, 22.51591924,
                     22.62138908, 22.72345666, 22.8289265 , 22.93439634, 23.02965941,


In [37]:
def calculate_confidence_intervals(bootstrap_preds, percentile=90):
    lower_percentile = (100 - percentile) / 2
    upper_percentile = 100 - lower_percentile
    ci = {}
    for model, preds in bootstrap_preds.items():
        preds_array = np.array(preds)  # Shape: (bootstrap_samples, n_dates)
        lower = np.percentile(preds_array, lower_percentile, axis=0)
        upper = np.percentile(preds_array, upper_percentile, axis=0)
        ci[model] = {'lower': lower, 'upper': upper}
    return ci

In [38]:
# Calculate 90% Confidence Intervals
confidence_intervals = calculate_confidence_intervals(regression_data['China']['bootstrap_predictions'], percentile=90)
confidence_intervals

{'Simple': {'lower': array([16.72781994, 16.84125344, 16.95803438, 17.06811269, 17.18186028,
         17.298764  , 17.40645664, 17.52568777, 17.64107273, 17.75673939,
         17.86868894, 17.98396673, 18.10175475, 18.21391049, 18.33107373,
         18.44336478, 18.55939887, 18.67548788, 18.78409326, 18.90031484,
         19.01260656, 19.12864134, 19.24108081, 19.35717644, 19.47540207,
         19.59160634, 19.71214968, 19.82649177, 19.94682429, 20.05999771,
         20.16194414, 20.27495342, 20.38399685, 20.49676903, 20.61389627,
         20.73515328, 20.85775288, 20.97177462, 21.08484158, 21.19297022,
         21.30501055, 21.41726924, 21.51952359, 21.63714073, 21.74848331,
         21.86218082, 21.97149192, 22.08828958, 22.20103653, 22.30886388,
         22.41801138, 22.52344688, 22.63466706, 22.74921949, 22.85369934,
         22.96333767, 23.07303366, 23.18468489, 23.28741338, 23.39087907,
         23.49657006, 23.60196771, 23.70459849, 23.8021621 , 23.90553924,
         24.0024957

In [48]:
# Graph of the different model fits using plotly
model_types = {
    'Non-China': 'Simple',
    'China': 'One kink',
}
# Parameters for each model selection
model_params = {
    'Simple': {
        'kink_count': 0,
        'allow_discontinuities': False,
    },
    'One kink': {
        'kink_count': 1,
        'allow_discontinuities': False,
    },
    'Discontinuity': {
        'kink_count': 1,
        'allow_discontinuities': True,
    }
}

def plot_model(df, model_types, model_params):
    fig = go.Figure()

    # Plot the original data points
    df_non_china = df[df['category'] == 'Non-China']
    df_china = df[df['category'] == 'China']

    fig.add_trace(go.Scatter(
        x=df_non_china['date'], y=df_non_china['log_flop'],
        mode='markers', name='Not developed in China', text=df_non_china['Model'],
        marker=dict(color=colors['Non-China'], opacity=0.1, size=10)
    ))
    fig.add_trace(go.Scatter(
        x=df_china['date'], y=df_china['log_flop'],
        mode='markers', name='Developed in China', text=df_china['Model'],
        marker=dict(color=colors['China'], opacity=0.1, size=10)
    ))

    # Show the export controls date
    # Convert 2022-10-07 to seconds since epoch
    export_controls_date = pd.Timestamp('2022-10-07').value / 1e6
    fig.add_vline(x=export_controls_date, line_color='black', line_width=1, line_dash='dot', 
        annotation_text='October 2022<br>Export controls introduced', annotation_position='bottom right')

    # non_china_date_grid = pd.date_range(start=non_china_df['date'].min(), end=non_china_df['date'].max(), freq='ME')
    # china_date_grid = pd.date_range(start=china_df['date'].min(), end=china_df['date'].max(), freq='ME')
    date_grid = pd.date_range(start=df['date'].min(), end=df['date'].max(), freq='ME')

    trend_dfs = {}
    ci_dfs = {}
    fit_results = {}
    for category, model_type in model_types.items():
        params = model_params[model_type]
        fit_result = fit_n_phase_exponential(df[df['category'] == category], **params)
        fit_results[category] = fit_result

        # Get the month of the first point, then use that to index the date_grid
        start_month = df[df['category'] == category]['date'].min().month
        start_index = np.where(date_grid.month == start_month)[0][0]
        log_flop = fit_result.predict(pd.Series(date_grid))
        # To plot the bootstrapped mean prediction instead:
        # log_flop = np.mean(regression_data[category]['bootstrap_predictions'][model_type], axis=0)

        trend_dfs[category] = pd.DataFrame({
            'date': date_grid[start_index:],
            'log_flop': log_flop[start_index:],
        })
 
        # Get the confidence intervals
        ci_data = calculate_confidence_intervals(regression_data[category]['bootstrap_predictions'], percentile=90)
        ci_dfs[category] = pd.DataFrame({
            'date': date_grid[start_index:],
            'lower': ci_data[model_type]['lower'][start_index:],
            'upper': ci_data[model_type]['upper'][start_index:],
        })

        # Plot the best fit line with confidence intervals
        fig.add_trace(go.Scatter(
            x=date_grid[start_index:], y=log_flop[start_index:],
            mode='lines', name=f'{category} best fit line',
            line=dict(color=colors[category], width=1),
            showlegend=False,
        ))
        fig.add_trace(go.Scatter(
            x=date_grid[start_index:],
            y=ci_data[model_type]['lower'][start_index:],
            mode='lines',
            line=dict(color=colors[category], width=0),
            showlegend=False,
        ))
        fig.add_trace(go.Scatter(
            x=date_grid[start_index:],
            y=ci_data[model_type]['upper'][start_index:],
            mode='lines',
            fill='tonexty',
            fillcolor='rgba(0,0,255,0.1)' if category == 'Non-China' else 'rgba(255,0,0,0.1)',
            line=dict(color=colors[category], width=0),
            name=f'{category} 90% CI',
            showlegend=False,
        ))

    # Add slope labels
    for category in ['Non-China', 'China']:
        category_df = df[df['category'] == category]
        points = [category_df['date'].min()] + fit_results[category].break_points_dt + [category_df['date'].max()]
        model_type = model_types[category]
        best_slope = 10**regression_data[category]['models'][model_type].oom_year_slopes[-1]
        slopes = bootstrap_summary_data[category]['bootstrap_slopes'][model_type]
        slope_label = f'{best_slope:.1f}x/year<br>90% CI: {slopes[0]:.1f}-{slopes[1]:.1f}x/year'
        for i in range(len(points) - 2, len(points) - 1):
            mid = points[i] + (points[i+1] - points[i]) / 2
            if category == 'China':
                mid += pd.Timedelta(days=150)
            y = fit_results[category].predict(pd.Series([mid]))[0]
            fig.add_annotation(
                x=mid, y=y + 1.8 * (1 if category == 'Non-China' else -1),
                text=slope_label,
                showarrow=False,
                font=dict(size=12, color=colors[category])
            )

    # Update layout
    title = f'Frontiers of training compute for LLMs inside and outside China'
    fig.update_layout(
        template='plotly_white',
        width=800,
        height=400,
        title=title,
        xaxis_title='Publication date',
        yaxis_title='Training compute (FLOP)',
        legend_title='',
        margin=dict(l=10, r=10, t=40, b=10),
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(int(df['log_flop'].min()), int(df['log_flop'].max())+2, 2)),
            ticktext=[f'10<sup>{i}</sup>' for i in range(int(df['log_flop'].min()), int(df['log_flop'].max())+2, 2)]
        )
    )

    if save:
        fname = f'compute_regression_{model_selection}_frontier={frontier_selection}_top{top_n}_cutoff={cutoff_date}'
        save_plot(fig, results_dir, fname)
        
        slope_df = pd.DataFrame({
            'Category': ['Non-China', 'China'],
            'Best fit slope': [10**regression_data[category]['models'][model_types[category]].oom_year_slopes[-1] for category in ['Non-China', 'China']],
            '90% CI lower': [bootstrap_summary_data[category]['bootstrap_slopes'][model_types[category]][0] for category in ['Non-China', 'China']],
            '90% CI upper': [bootstrap_summary_data[category]['bootstrap_slopes'][model_types[category]][1] for category in ['Non-China', 'China']],
        })
        slope_df.to_csv(results_dir + f'plot_data/recent_slopes_{fname}.csv', index=False)

        df_non_china[['Model', 'date', 'log_flop']].to_csv(results_dir + f'plot_data/non_china_scatter_{fname}.csv', index=False)
        df_china[['Model', 'date', 'log_flop']].to_csv(results_dir + f'plot_data/china_scatter_{fname}.csv', index=False)
        trend_dfs['Non-China'][['date', 'log_flop']].to_csv(results_dir + f'plot_data/non_china_best_fit_line_{fname}.csv', index=False)
        trend_dfs['China'][['date', 'log_flop']].to_csv(results_dir + f'plot_data/china_best_fit_line_{fname}.csv', index=False)
        ci_dfs['Non-China'][['date', 'lower', 'upper']].to_csv(results_dir + f'plot_data/non_china_ci_{fname}.csv', index=False)
        ci_dfs['China'][['date', 'lower', 'upper']].to_csv(results_dir + f'plot_data/china_ci_{fname}.csv', index=False)


    for category in ['Non-China', 'China']:
        print(category)
        fit_result = fit_results[category]
        simple_fit = fit_n_phase_exponential(df[df['category'] == category], 0)
        bayes_factor = np.exp(-0.5 * (fit_result.bic - simple_fit.bic))
        unadjusted_bayes_factor = np.exp(-0.5 * (fit_result.bic - (simple_fit.bic + 2*np.log(len(df[df['category'] == category])))))

        print(f"BIC score: {fit_result.bic}")
        bic_score_difference = fit_result.bic - simple_fit.bic
        if bic_score_difference > 0:
            print(f"The simple exponential is preferred over this fit by a BIC score difference of {fit_result.bic - simple_fit.bic}")
        if bic_score_difference < 0:
            print(f"This fit is preferred over a simple exponential by a BIC score difference of {-(fit_result.bic - simple_fit.bic)}")

    fig.show()

    return fit_result

fit_result = plot_model(df_filtered, model_types, model_params)

Non-China
BIC score: 72.46027132121029
China
BIC score: 133.16957453021683
This fit is preferred over a simple exponential by a BIC score difference of 46.495291021118874
