In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from collections import defaultdict
from datetime import datetime
import kaleido  # needed for saving plots
import numpy as np
import os
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.model_selection import KFold
from tqdm import tqdm

from regression import *
from plotting import *

# Parameters

In [3]:
# Benchmark columns drawn in the 2x2 overview figure; the list also supplies the subplot titles.
benchmarks_to_analyze = ['MMLU', 'GPQA', 'GSM1k', 'BBH']
# Per-benchmark boolean flag. Presumably True means the score is an accuracy rather than
# an Elo/SEAL-style score; it is not read anywhere in the visible cells. TODO confirm usage.
bench_is_accuracy = {'MMLU': True, 'BBH': True, 'GSM1k': True, 'GPQA': True, 'LMSys Elo': False, 'SEAL Coding': False, 'SEAL Math': False}
# NOTE: the filtering cells test `non_suspects_only` first and `trusted_only` in an `elif`,
# so `trusted_only` only takes effect when `non_suspects_only` is False.
non_suspects_only = True  # Whether to only include not-suspicious benchmark scores in the analysis
trusted_only = False  # Whether to only include actively trusted benchmark scores in the analysis (more strict)
save = True  # Whether to save plots and results to disk

In [4]:
# Output directory handed to `save_plot` by the plotting cells.
# `exist_ok=True` makes re-running this cell harmless when the folder already exists.
results_dir = 'results/initial/'
os.makedirs(results_dir, exist_ok=True)

In [5]:
# Seeded NumPy generator; used as `random_state` when drawing bootstrap resamples
# in `boostrapped_regression_with_results`.
rng = np.random.default_rng(seed=42)

# Prepare data

In [6]:
# Load the main benchmark table from a local CSV file.
# The commented-out line is an alternative source: a Google Sheets CSV export URL.
# data_path = "https://docs.google.com/spreadsheets/d/1etu9rXcME0uUA-S2ANA8bsfQbIZgNu-8NxqFGQdDIzQ/export?format=csv&gid=1305280917#gid=1305280917"
data_path = "data/benchmarks_with_model_accessibility.csv"
df = pd.read_csv(data_path)

# Remove pandas' column-count and width limits so wide frames are displayed in full.
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)

# Preview the first 30 rows.
df.head(30)

Unnamed: 0,System,Model size (parameters),Active Parameters,Dataset size,Date,Open/Closed,Training compute (FLOP),Training compute notes,BBH,GPQA,MMLU,HELM MMLU,SEAL Coding,SEAL Instruction Following,SEAL Math,LMSys Elo,LMSys Elo Notes,LMSys Elo 95% CI,BBH Notes,GPQA Notes,MMLU Notes,HELM MMLU Notes,Trust in benchmark results,Trust notes
0,BLOOM-176B,176000000000.0,176000000000.0,390000000000.0,2022-11-09,Open,4.12e+23,,0.4491,,0.3913,,,,,,,,,,,,0,
1,BloombergGPT,50000000000.0,50000000000.0,708000000000.0,2023-03-30,Closed,2.12e+23,,0.4197,,0.3918,,,,,,,,,,,,0,
2,Camelidae-8x34B,,,,2024-01-05,Open,,,,,0.756,,,,,,,,,,,,0,
3,ChatGLM-6B,6000000000.0,6000000000.0,,2023-03-01,Open,,,0.1873,,,,,,,880.0,,,,,,,0,
4,ChatGLM2-12B-base,12000000000.0,12000000000.0,,2023-06-25,Open,,,0.3602,,,,,,,,,,,,,,0,
5,ChatGLM2-6B-base,6000000000.0,6000000000.0,,2023-06-25,Open,,,0.3368,,,,,,,924.0,,,,,,,0,
6,ChatGLM3-6B,6000000000.0,6000000000.0,,2023-10-27,Open,5.04e+22,,0.661,,,,,,,955.0,,,,,,,0,
7,Chinchilla 70B,,70000000000.0,,2022-03-29,Closed,5.76e+23,,,,0.675,,,,,,,,,,,,0,
8,Claude 2,,,,2023-07-11,Closed,,,,0.353,0.785,,,,,1132.0,,,,Epoch evaluation,"Actually CoT, so probably an overestimate. HEL...",,0,
9,Claude 2.1,,,,2023-11-21,Closed,,,,0.361,,,,,,,,,,Epoch evaluation,,,0,Doesn't perform worse on GSM1k relative to GSM8k


In [7]:
# Load the GSM8k/GSM1k score table from a local CSV file.
# The commented-out line is an alternative source: a Google Sheets CSV export URL.
# gsm1k_data_path = "https://docs.google.com/spreadsheets/d/1KYp4h3urj-698IE9bR7n1ctuH1iyCAQ5pTZIqQ_qs9g/export?format=csv"
gsm1k_data_path = "data/gsm1k_with_model_accessibility.csv"
gsm1k_df = pd.read_csv(gsm1k_data_path)
gsm1k_df

Unnamed: 0,System,Date,GSM8k,GSM1k,Training compute (FLOP),Speculative Compute,Active Parameters,Open/Closed
0,claude-2.1,2023-07-11,0.887,0.894,3.800000e+24,,,Closed
1,claude-3-haiku-20240307,2024-03-04,0.785,0.785,,,,Closed
2,claude-3-opus-20240229,2024-03-04,0.802,0.825,,4.000000e+25,,Closed
3,claude-3-sonnet-20240229,2024-03-04,0.719,0.744,,,,Closed
4,codegemma-7b,2024-04-09,0.479,0.416,3.330000e+23,,7000000000,Open
...,...,...,...,...,...,...,...,...
66,vicuna-33b-v1.3,2023-06-22,0.379,0.341,,,33000000000,Open
67,Xwin-Math-13B-V1.0,2024-03-07,0.631,0.529,,,13000000000,Open
68,Xwin-Math-7B-V1.0,2024-03-07,0.529,0.428,,,7000000000,Open
69,Yi-34B-Chat,2023-11-02,0.641,0.569,6.100000e+23,,34000000000,Open


In [8]:
# Concatenate dfs: stack the GSM1k rows on top of the main benchmark rows.
# - axis=0 appends rows; systems are NOT matched on `System`, so a model present in
#   both tables ends up as two separate rows.
# - join='outer' keeps the union of columns; values a table does not have become NaN.
# - ignore_index=True renumbers the combined rows from 0.
df = pd.concat([gsm1k_df, df], axis=0, join='outer', ignore_index=True)
df

Unnamed: 0,System,Date,GSM8k,GSM1k,Training compute (FLOP),Speculative Compute,Active Parameters,Open/Closed,Model size (parameters),Dataset size,Training compute notes,BBH,GPQA,MMLU,HELM MMLU,SEAL Coding,SEAL Instruction Following,SEAL Math,LMSys Elo,LMSys Elo Notes,LMSys Elo 95% CI,BBH Notes,GPQA Notes,MMLU Notes,HELM MMLU Notes,Trust in benchmark results,Trust notes
0,claude-2.1,2023-07-11,0.887,0.894,3.800000e+24,,,Closed,,,,,,,,,,,,,,,,,,,
1,claude-3-haiku-20240307,2024-03-04,0.785,0.785,,,,Closed,,,,,,,,,,,,,,,,,,,
2,claude-3-opus-20240229,2024-03-04,0.802,0.825,,4.000000e+25,,Closed,,,,,,,,,,,,,,,,,,,
3,claude-3-sonnet-20240229,2024-03-04,0.719,0.744,,,,Closed,,,,,,,,,,,,,,,,,,,
4,codegemma-7b,2024-04-09,0.479,0.416,3.330000e+23,,7000000000,Open,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
198,XVerse-7B,2023-09-26,,,,,7000000000,Open,,,,,,,,,,,,,,,,,,0.0,
199,Yi-1.5-34B,2024-05-10,,,,,,Open,,,,,0.060,,,,,,,,,,Epoch evaluation,,,0.0,
200,Yi-34B,2023-11-02,,,6.120000e+23,,34000000000,Open,3.400000e+10,3.000000e+12,,0.543,0.165,0.7635,,,,,1111.0,chat,,,Epoch evaluation,,,-1.0,MMLU-GPQA performance difference is relatively...
201,Yi-6B,2023-11-02,,,1.080000e+23,,6000000000,Open,6.000000e+09,3.000000e+12,,0.428,,0.6385,,,,,,,,,,,,0.0,


In [9]:
# Set the GSM1k score of the 'Random chance' reference row(s) to 0.0.
# That row is read later as the chance-level baseline for the selected benchmark.
df.loc[df['System'] == 'Random chance', 'GSM1k'] = 0.0

In [10]:
# Parse `Date` into datetimes. format='mixed' lets pandas infer the format of each
# element individually (this option requires pandas >= 2.0).
df['Date'] = pd.to_datetime(df['Date'], format='mixed')
# Same display options as set when the main table was loaded (repeated here).
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
df

Unnamed: 0,System,Date,GSM8k,GSM1k,Training compute (FLOP),Speculative Compute,Active Parameters,Open/Closed,Model size (parameters),Dataset size,Training compute notes,BBH,GPQA,MMLU,HELM MMLU,SEAL Coding,SEAL Instruction Following,SEAL Math,LMSys Elo,LMSys Elo Notes,LMSys Elo 95% CI,BBH Notes,GPQA Notes,MMLU Notes,HELM MMLU Notes,Trust in benchmark results,Trust notes
0,claude-2.1,2023-07-11,0.887,0.894,3.800000e+24,,,Closed,,,,,,,,,,,,,,,,,,,
1,claude-3-haiku-20240307,2024-03-04,0.785,0.785,,,,Closed,,,,,,,,,,,,,,,,,,,
2,claude-3-opus-20240229,2024-03-04,0.802,0.825,,4.000000e+25,,Closed,,,,,,,,,,,,,,,,,,,
3,claude-3-sonnet-20240229,2024-03-04,0.719,0.744,,,,Closed,,,,,,,,,,,,,,,,,,,
4,codegemma-7b,2024-04-09,0.479,0.416,3.330000e+23,,7000000000,Open,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
198,XVerse-7B,2023-09-26,,,,,7000000000,Open,,,,,,,,,,,,,,,,,,0.0,
199,Yi-1.5-34B,2024-05-10,,,,,,Open,,,,,0.060,,,,,,,,,,Epoch evaluation,,,0.0,
200,Yi-34B,2023-11-02,,,6.120000e+23,,34000000000,Open,3.400000e+10,3.000000e+12,,0.543,0.165,0.7635,,,,,1111.0,chat,,,Epoch evaluation,,,-1.0,MMLU-GPQA performance difference is relatively...
201,Yi-6B,2023-11-02,,,1.080000e+23,,6000000000,Open,6.000000e+09,3.000000e+12,,0.428,,0.6385,,,,,,,,,,,,0.0,


In [11]:
# Filter out finetuned systems

# Explicit list of system names to exclude from the analysis.
finetuned_systems = [
 'Layer Normalization: Handwriting sequence generation',
 'ULM-FiT',
 'ADP-FAIRSEQ + NGRAMRES',
 'Cross-lingual alignment',
 'UnifiedQA',
 '$\\infty$-former (SM)',
 'FLAN 137B',
 'AlphaFold-Multimer',
 'Masked Autoencoders',
 'Contriever',
 'BERT-RBP',
 'Minerva',
 'BlenderBot 3',
 'PaLM-SayCan',
 'NMST+GPT-2',
 'Decaying Fast Weights Transformer (WT-103)',
 'GPT-2 + Progressive LRD',
 'U-PaLM',
 'Flan-T5 11B',
 'Flan-PaLM 540B',
 'Taiyi-Stable Diffusion',
 'OPT-IML (175B)',
 'SparseOPT-175B',
 'DiT-XL/2',
 'VideoMAE V2',
 'Segment Anything Model',
 'gLM',
 'MOSS-Moon-003',
 'WizardLM-7B',
 'InstructBLIP',
 'Guanaco-65B',
 'WizardCoder-15.5B',
 'Code Llama-34B',
 'Code Llama-7B',
 'TigerBot-70B',
 'MiniGPT4 (Vicuna finetune)',
 'LLaMA-7B (protein-oriented instructions finetuned)',
 'FinGPT-13B',
 'LLaVA 1.5',
 'CogVLM',
 'Volcano 13B',
 'SPHINX (Llama 2 13B)',
 'Orca 2-13B',
 'Llama Guard',
 'FunSearch',
 'Elyza',
 'Code Llama-70B',
 'Swallow'
]

# Step 1: drop every system named in the explicit list.
df = df[~df['System'].isin(finetuned_systems)]
# Step 2: drop any remaining system whose name contains "flan" (case-insensitive).
# `na=False` treats a missing system name as "no match". Without it the mask would
# contain NaN for such rows and the `~` inversion would raise instead of keeping them.
df = df[~df['System'].str.contains('Flan', case=False, na=False)]

## Merge SEAL Math with GSM1k

In [12]:
# Wherever a SEAL Math score exists it replaces the GSM1k value; all other rows keep
# their existing GSM1k value. `where` keeps GSM1k only where SEAL Math is missing.
# This is done column-wise (no row-by-row `iterrows` loop) and through `assign`,
# which returns a new frame instead of writing cell by cell into a filtered frame.
df = df.assign(GSM1k=df['GSM1k'].where(df['SEAL Math'].isna(), df['SEAL Math']))
df

Unnamed: 0,System,Date,GSM8k,GSM1k,Training compute (FLOP),Speculative Compute,Active Parameters,Open/Closed,Model size (parameters),Dataset size,Training compute notes,BBH,GPQA,MMLU,HELM MMLU,SEAL Coding,SEAL Instruction Following,SEAL Math,LMSys Elo,LMSys Elo Notes,LMSys Elo 95% CI,BBH Notes,GPQA Notes,MMLU Notes,HELM MMLU Notes,Trust in benchmark results,Trust notes
0,claude-2.1,2023-07-11,0.887,0.894,3.800000e+24,,,Closed,,,,,,,,,,,,,,,,,,,
1,claude-3-haiku-20240307,2024-03-04,0.785,0.785,,,,Closed,,,,,,,,,,,,,,,,,,,
2,claude-3-opus-20240229,2024-03-04,0.802,0.825,,4.000000e+25,,Closed,,,,,,,,,,,,,,,,,,,
3,claude-3-sonnet-20240229,2024-03-04,0.719,0.744,,,,Closed,,,,,,,,,,,,,,,,,,,
4,codegemma-7b,2024-04-09,0.479,0.416,3.330000e+23,,7000000000,Open,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
198,XVerse-7B,2023-09-26,,,,,7000000000,Open,,,,,,,,,,,,,,,,,,0.0,
199,Yi-1.5-34B,2024-05-10,,,,,,Open,,,,,0.060,,,,,,,,,,Epoch evaluation,,,0.0,
200,Yi-34B,2023-11-02,,,6.120000e+23,,34000000000,Open,3.400000e+10,3.000000e+12,,0.543,0.165,0.7635,,,,,1111.0,chat,,,Epoch evaluation,,,-1.0,MMLU-GPQA performance difference is relatively...
201,Yi-6B,2023-11-02,,,1.080000e+23,,6000000000,Open,6.000000e+09,3.000000e+12,,0.428,,0.6385,,,,,,,,,,,,0.0,


## Compute

In [13]:
# 2x2 grid of scatter plots: benchmark score (in %) against training compute,
# one subplot per entry of `benchmarks_to_analyze`.
fig = make_subplots(rows=2, cols=2, subplot_titles=benchmarks_to_analyze, vertical_spacing=0.15)

# Define x limits for each subplot as [min, max] training compute (FLOP).
# Only the entries for benchmarks in `benchmarks_to_analyze` are read below.
x_limits = {
    'MMLU': [1e20, 1e26],
    'GSM1k': [1e20, 1e26],
    'GPQA': [1e23, 1e26],
    'BBH': [1e20, 1e26],
    'SEAL Math': [1e23, 1e26],
    'SEAL Coding': [1e23, 1e26],
    'LMSys Elo': [1e22, 1e26],
}

for i, bench in enumerate(benchmarks_to_analyze):
    # 1-based grid position of this benchmark; the grid is filled row by row.
    subplot_row, subplot_col = i // 2 + 1, i % 2 + 1

    # The 'Random chance' reference row is not a real system, so it is never plotted.
    plot_df = df[~(df['System'] == 'Random chance')]
    # NOTE(review): `trusted_only` is only consulted when `non_suspects_only` is False,
    # and the non-suspect filter is applied to GPQA and MMLU only; other benchmarks
    # are plotted unfiltered. Confirm this precedence is intended.
    if non_suspects_only:
        if bench == 'GPQA':
            # GPQA was released November 20, 2023: systems dated before the release
            # are kept regardless of trust, later ones only if trust is >= 0.
            old_df = plot_df[plot_df['Date'] < pd.to_datetime('2023-11-20')]
            new_df = plot_df[plot_df['Date'] >= pd.to_datetime('2023-11-20')]
            new_df = new_df[new_df['Trust in benchmark results'] >= 0]
            plot_df = pd.concat([old_df, new_df])
        elif bench == 'MMLU':
            plot_df = plot_df[plot_df['Trust in benchmark results'] >= 0]
    elif trusted_only:
        plot_df = plot_df[plot_df['Trust in benchmark results'] > 0]

    # `add_trace(..., row=, col=)` is the supported replacement for the deprecated
    # `append_trace`; the trace lands in the same subplot.
    fig.add_trace(
        go.Scatter(
            x=plot_df['Training compute (FLOP)'],
            y=100 * plot_df[bench],  # scores are scaled by 100 to match the "%" axis label
            mode='markers',
            text=plot_df['System'],
            name=bench,
            showlegend=(i == 0)  # only the first trace gets a legend entry
        ),
        row=subplot_row, col=subplot_col
    )

    # Update x and y axes for this subplot
    fig.update_xaxes(
        title_text="Training compute (FLOP)" if subplot_row == 2 else None,  # label the bottom row only
        type='log',
        range=[np.log10(x_limits[bench][0]), np.log10(x_limits[bench][1])],  # log axes take limits as exponents
        tickmode='linear',
        dtick=2,  # This sets ticks at every two powers of 10
        row=subplot_row,
        col=subplot_col
    )

    # Label the y axis on the left column only
    if subplot_col == 1:
        fig.update_yaxes(title_text="Accuracy (%)", row=subplot_row, col=subplot_col)

# Improve the layout
fig.update_layout(
    template='plotly_white',
    width=600,
    height=400,
    # legend_title="Model accessibility",
    font=dict(size=12),
    hovermode="closest",
)

# Margins
fig.update_layout(
    margin=dict(l=0, r=0, t=20, b=0)
)

# Save the plot
if save:
    save_plot(fig, results_dir, 'benchmark_compute')

# Show the plot
fig.show()

In [14]:
# Benchmark column analysed by the single-benchmark regression cells that follow.
bench = 'MMLU'

In [15]:
# Rows eligible for the regression on `bench`: every system except the
# 'Random chance' reference row, optionally narrowed by the trust flags.
reg_df = df.loc[df['System'] != 'Random chance']
if non_suspects_only and bench == 'GPQA':
    # GPQA was released November 20, 2023: systems dated before the release are
    # kept as-is, systems dated on/after it must have a trust value of at least 0.
    old_df = reg_df.loc[reg_df['Date'] < pd.to_datetime('2023-11-20')]
    new_df = reg_df.loc[
        (reg_df['Date'] >= pd.to_datetime('2023-11-20'))
        & (reg_df['Trust in benchmark results'] >= 0)
    ]
    reg_df = pd.concat([old_df, new_df])
elif non_suspects_only and bench == 'MMLU':
    reg_df = reg_df.loc[reg_df['Trust in benchmark results'] >= 0]
elif not non_suspects_only and trusted_only:
    # Stricter variant; only reached when `non_suspects_only` is switched off.
    reg_df = reg_df.loc[reg_df['Trust in benchmark results'] > 0]

In [16]:
# Chance-level score for this benchmark, read from the 'Random chance' reference row.
random_chance_level = df.loc[df["System"] == "Random chance", bench].to_numpy()[0]
# Keep only models scoring more than 0.05 above chance (a heuristic to find the
# changepoint), then add the two regression variables:
#   log_compute       = log10(training compute in FLOP)
#   <bench>_log_error = -ln(1 - score)
# Rows where either variable is missing are dropped.
filtered_reg_df = (
    reg_df[reg_df[bench] > random_chance_level + 0.05]
    .assign(**{
        'log_compute': lambda frame: np.log10(frame['Training compute (FLOP)']),
        bench + '_log_error': lambda frame: -np.log(1 - frame[bench]),
    })
    .dropna(subset=['log_compute', bench + '_log_error'])
    .copy()  # independent frame, so later column assignments do not act on a slice
)

In [17]:
# Add a numeric `float_date` column computed from `Date` by the project helper
# `datetime_to_float_year` (presumably a fractional calendar year; confirm in the helper).
# It is used as a regressor and as the "Year" axis in the cells below.
filtered_reg_df.loc[:, 'float_date'] = datetime_to_float_year(filtered_reg_df['Date'])

In [18]:
# Single fit for all data
performance_model = fit_ols_regression(filtered_reg_df, ['float_date', 'log_compute'], bench + '_log_error')
performance_model.summary()

0,1,2,3
Dep. Variable:,MMLU_log_error,R-squared:,0.891
Model:,OLS,Adj. R-squared:,0.886
Method:,Least Squares,F-statistic:,187.3
Date:,"Wed, 23 Oct 2024",Prob (F-statistic):,7.87e-23
Time:,18:40:37,Log-Likelihood:,20.446
No. Observations:,49,AIC:,-34.89
Df Residuals:,46,BIC:,-29.22
Df Model:,2,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,-337.8448,58.754,-5.750,0.000,-456.111,-219.579
float_date,0.1623,0.029,5.566,0.000,0.104,0.221
log_compute,0.4413,0.029,15.368,0.000,0.384,0.499

0,1,2,3
Omnibus:,1.17,Durbin-Watson:,1.401
Prob(Omnibus):,0.557,Jarque-Bera (JB):,0.502
Skew:,0.185,Prob(JB):,0.778
Kurtosis:,3.33,Cond. No.,5060000.0


In [19]:
# Show the coefficient names of the fitted model.
performance_model.params.index

Index(['Intercept', 'float_date', 'log_compute'], dtype='object')

In [20]:
# Display the model's predictions for the same rows it was fitted on (in-sample values).
get_predictions(performance_model, filtered_reg_df, ['float_date', 'log_compute'])

array([0.86809836, 0.80417934, 0.83301391, 1.10795706, 0.7344945 ,
       1.34588002, 1.28059877, 0.89900253, 1.96231806, 1.14080943,
       0.82425726, 0.80058925, 1.31763294, 1.6783322 , 0.36626722,
       1.40823777, 1.64743564, 1.69486758, 1.62472506, 2.01194255,
       1.66751951, 0.59635087, 0.83924538, 0.96930727, 0.47770478,
       0.79415586, 0.97842154, 1.11682709, 0.67550977, 1.68978575,
       1.90143795, 0.81918711, 0.50983512, 1.85120689, 1.85120689,
       0.44764354, 1.11906349, 0.70587887, 1.50168741, 0.74232561,
       0.67790121, 0.62812   , 0.41936146, 1.10795706, 1.22814821,
       0.7560468 , 0.93330941, 0.77066472, 1.39592097])

In [21]:
# Contour plot of the performance model: predicted -ln(1 - accuracy) over a
# (year, training compute) grid, with the observed systems drawn on top.

# Define the range for log_compute and float_date
float_date_min, float_date_max = filtered_reg_df['float_date'].min(), filtered_reg_df['float_date'].max()
log_compute_min, log_compute_max = filtered_reg_df['log_compute'].min(), filtered_reg_df['log_compute'].max()

# Create a grid of values: 100 points per axis, extending one year / one order of
# magnitude beyond the observed data on each side
log_compute_vals = np.linspace(log_compute_min - 1, log_compute_max + 1, 100)
float_date_vals = np.linspace(float_date_min - 1, float_date_max + 1, 100)
X_grid, Y_grid = np.meshgrid(float_date_vals, log_compute_vals)

# Prepare the grid for prediction (one row per grid point, columns named like the regressors)
X_pred = pd.DataFrame({
    'float_date': X_grid.ravel(),
    'log_compute': Y_grid.ravel()
})

# Generate predictions and fold them back into the grid shape
Z_pred = performance_model.predict(X_pred)
Z_pred = Z_pred.values.reshape(X_grid.shape)

# Predicted range; reused below so data markers share the contour's colour range
min_performance = Z_pred.min()
max_performance = Z_pred.max()

# Plot the contour using Plotly (the y axis is shown in FLOP, hence 10**log_compute)
fig = go.Figure(data=go.Contour(
    x=float_date_vals,
    y=10**log_compute_vals,
    z=Z_pred,
    colorscale='Viridis',
    colorbar=dict(title='-ln(1 - accuracy)'),
    contours=dict(
        coloring='heatmap',
        showlabels=True,
        labelfont=dict(size=12, color='white')
    )
))

# Add the actual data points with the Viridis colorscale, coloured by their observed value
fig.add_trace(go.Scatter(
    x=filtered_reg_df['float_date'],
    y=10**filtered_reg_df['log_compute'],
    mode='markers',
    marker=dict(
        color=filtered_reg_df[bench + '_log_error'],
        colorscale='Viridis',
        cmin=min_performance,
        cmax=max_performance,
        line=dict(width=1, color='black')
    ),
    text=filtered_reg_df['System'],
    name='Data'
))

fig.update_yaxes(type='log')

# NOTE(review): neither trace references a shared `coloraxis`, so the
# `coloraxis_colorbar` title below likely has no visible effect; the contour's own
# colorbar title is set above. Confirm.
fig.update_layout(
    title=f'Regression Model Predictions for {bench}',
    xaxis_title='Year',
    yaxis_title='Training compute (FLOP)',
    coloraxis_colorbar=dict(title=f'Predicted {bench}_log_error'),
    width=800,
    height=400
)

if save:
    save_plot(fig, results_dir, f'{bench}_predictions_isoperformance_contour')

fig.show()

In [22]:
# Show the fitted coefficients (intercept, date, log compute).
performance_model.params

Intercept     -337.844845
float_date       0.162291
log_compute      0.441317
dtype: float64

In [23]:
# Slope of an iso-performance line in the (year, log10 compute) plane.
# Holding the predicted value fixed means b_date * d(year) + b_compute * d(log10 C) = 0,
# so d(log10 C) / d(year) = -b_date / b_compute.
isoperformance_slope = -performance_model.params['float_date'] / performance_model.params['log_compute']
# 10**slope is the factor by which the required compute changes per year; its
# reciprocal is reported as "x less compute per year".
print(f'It costs {1/(10**isoperformance_slope):.2f}x less training compute each year to keep {bench} performance fixed.')

It costs 2.33x less training compute each year to keep MMLU performance fixed.


In [24]:
# Reverse-direction fit: regress the date on <bench>_log_error and log10 training
# compute, then display the fit summary. Used for the iso-time contour plot.
time_model = fit_ols_regression(filtered_reg_df, [bench + '_log_error', 'log_compute'], 'float_date')
time_model.summary()

0,1,2,3
Dep. Variable:,float_date,R-squared:,0.477
Model:,OLS,Adj. R-squared:,0.455
Method:,Least Squares,F-statistic:,21.0
Date:,"Wed, 23 Oct 2024",Prob (F-statistic):,3.32e-07
Time:,18:40:38,Log-Likelihood:,-46.356
No. Observations:,49,AIC:,98.71
Df Residuals:,46,BIC:,104.4
Df Model:,2,,
Covariance Type:,nonrobust,,

0,1,2,3,4,5,6
,coef,std err,t,P>|t|,[0.025,0.975]
Intercept,2041.9042,5.434,375.748,0.000,2030.966,2052.843
MMLU_log_error,2.4800,0.446,5.566,0.000,1.583,3.377
log_compute,-0.8863,0.245,-3.611,0.001,-1.380,-0.392

0,1,2,3
Omnibus:,10.447,Durbin-Watson:,1.477
Prob(Omnibus):,0.005,Jarque-Bera (JB):,10.055
Skew:,-1.041,Prob(JB):,0.00655
Kurtosis:,3.77,Cond. No.,1420.0


In [25]:
# Contour plot of the time model: predicted year over a
# (-ln(1 - accuracy), training compute) grid, with the observed systems drawn on top.

# Define the range for performance and log_compute
performance_min, performance_max = filtered_reg_df[bench + '_log_error'].min(), filtered_reg_df[bench + '_log_error'].max()
log_compute_min, log_compute_max = filtered_reg_df['log_compute'].min(), filtered_reg_df['log_compute'].max()

# Create a grid of values: 100 points per axis, padded by 0.5 (performance) and by
# one order of magnitude (compute) beyond the observed data
performance_vals = np.linspace(performance_min - 0.5, performance_max + 0.5, 100)
log_compute_vals = np.linspace(log_compute_min - 1, log_compute_max + 1, 100)
X_grid, Y_grid = np.meshgrid(performance_vals, log_compute_vals)

# Prepare the grid for prediction (one row per grid point, columns named like the regressors)
X_pred = pd.DataFrame({
    bench + '_log_error': X_grid.ravel(),
    'log_compute': Y_grid.ravel()
})

# Generate predictions and fold them back into the grid shape
Z_pred = time_model.predict(X_pred)
Z_pred = Z_pred.values.reshape(X_grid.shape)

# Predicted range; reused below so data markers share the contour's colour range
min_year = Z_pred.min()
max_year = Z_pred.max()

# Plot the contour using Plotly (the y axis is shown in FLOP, hence 10**log_compute)
fig = go.Figure(data=go.Contour(
    x=performance_vals,
    y=10**log_compute_vals,
    z=Z_pred,
    colorscale='Viridis',
    colorbar=dict(title='Year'),
    contours=dict(
        coloring='heatmap',
        showlabels=True,
        labelfont=dict(size=12, color='white')
    )
))

# Add the actual data points with the Viridis colorscale, coloured by their actual date
fig.add_trace(go.Scatter(
    x=filtered_reg_df[bench + '_log_error'],
    y=10**filtered_reg_df['log_compute'],
    mode='markers',
    marker=dict(
        color=filtered_reg_df['float_date'],
        colorscale='Viridis',
        cmin=min_year,
        cmax=max_year,
        line=dict(width=1, color='black')
    ),
    text=filtered_reg_df['System'],
    name='Data'
))

fig.update_yaxes(type='log')

# NOTE(review): neither trace references a shared `coloraxis`, so the
# `coloraxis_colorbar` title below likely has no visible effect; the contour's own
# colorbar title is set above. Confirm.
fig.update_layout(
    title=f'Regression Model Predictions for {bench}',
    xaxis_title='Performance (negative log of error rate)',
    yaxis_title='Training compute (FLOP)',
    coloraxis_colorbar=dict(title=f'Predicted Year'),
    width=800,
    height=400
)

if save:
    save_plot(fig, results_dir, f'{bench}_predictions_isotime_contour')

fig.show()

In [26]:
def no_split(df, filter_threshold=None):
  """Trivial split: put every row into a single group labelled 'All'.

  Args:
    df: the frame to "split"; returned as-is (same object, no copy).
    filter_threshold: ignored; accepted so the signature matches the other
      split functions, which are all called as `filter_fn(df, filter_threshold)`.

  Returns:
    A one-entry dict `{'All': df}`.
  """
  groups = dict()
  groups['All'] = df
  return groups


def open_closed_split(df, filter_threshold=None):
  """Split a frame by the value of its 'Open/Closed' column.

  Args:
    df: frame with an 'Open/Closed' column.
    filter_threshold: ignored; accepted so the signature matches the other
      split functions.

  Returns:
    Dict with keys 'Open' and 'Closed' (in that order), each holding the rows
    whose 'Open/Closed' value equals the key. Rows with any other value
    (including missing) appear in neither group.
  """
  accessibility = df['Open/Closed']
  return {label: df[accessibility == label] for label in ('Open', 'Closed')}


def new_old_split(df, date):
  """Split a frame into rows dated before `date` and rows dated on or after it.

  Args:
    df: frame with a 'Date' column comparable to `date`.
    date: cut-off; rows with `Date >= date` count as 'After'.

  Returns:
    Dict with keys 'Before' and 'After' (in that order). Rows whose date
    compares as neither (e.g. a missing date) appear in neither group.
  """
  dates = df['Date']
  groups = {'Before': df[dates < date]}
  groups['After'] = df[dates >= date]
  return groups


def combined_rsquared(xs, ys, models):
    """R-squared of several sub-models evaluated jointly on their own groups.

    Every model predicts its own group; the predictions and the targets are
    then pooled and a single R^2 = 1 - SS_res / SS_tot is computed against the
    mean of the pooled targets.

    Args:
        xs: per-group feature tables; `xs[i]` is passed to `models[i]`.
        ys: per-group target values, in the same group order as `xs`.
        models: fitted models exposing `predict`.

    Returns:
        The pooled R-squared as a float.

    Note:
        Relies on a module-level `sm` (used for `sm.add_constant`). The
        notebook's import cell does not import it directly, so it presumably
        arrives through `from regression import *` (confirm).
    """
    observed = np.concatenate(ys)
    grand_mean = np.mean(observed)

    # Each model only ever sees the features of its own group.
    fitted = np.concatenate([
        model.predict(sm.add_constant(xs[idx]))
        for idx, model in enumerate(models)
    ])

    total_ss = np.sum((observed - grand_mean)**2)
    residual_ss = np.sum((observed - fitted)**2)
    return 1 - (residual_ss / total_ss)


def combined_bic(xs, ys, models):
  """BIC of several sub-models treated as one combined model.

  Predictions of all sub-models are pooled, the log-likelihood is computed
  from the pooled residual sum of squares as
  `-n/2 * (1 + ln(2*pi) + ln(rss/n))`, and the penalty counts the parameters
  of every sub-model: `BIC = -2*ll + total_params * ln(n)`.

  Args:
    xs: per-group feature tables; `xs[i]` is passed to `models[i]`.
    ys: per-group target values, in the same group order as `xs`.
    models: fitted models exposing `predict` and `params`.

  Returns:
    The combined BIC as a float.

  Note:
    Relies on a module-level `sm` (used for `sm.add_constant`). The notebook's
    import cell does not import it directly, so it presumably arrives through
    `from regression import *` (confirm).
  """
  observed = np.concatenate(ys)
  per_group_predictions = []
  n_params = 0
  for idx, model in enumerate(models):
    per_group_predictions.append(model.predict(sm.add_constant(xs[idx])))
    n_params += len(model.params)
  fitted = np.concatenate(per_group_predictions)

  n_obs = len(observed)
  residual_ss = np.sum((observed - fitted)**2)
  log_likelihood = -n_obs/2 * (1 + np.log(2*np.pi) + np.log(residual_ss/n_obs))
  return -2 * log_likelihood + n_params * np.log(n_obs)


# K-Fold Cross Validation
def perform_cross_validation(df, filter_fn, features, bench, k=10, random_state=42, filter_threshold=None):
  """K-fold cross-validation of a (possibly split) regression.

  For every fold, the training and test rows are each split into groups with
  `filter_fn`, one model per group is fitted on the training rows, and each
  test group is scored with the model of the same group.

  Args:
    df: frame holding the feature columns and the `<bench>_log_error` target.
    filter_fn: split function called as `filter_fn(frame, filter_threshold)`;
      returns a dict mapping group name to sub-frame.
    features: feature column names passed to the fitting/prediction helpers.
    bench: benchmark name; the target column is `bench + '_log_error'`.
    k: number of folds.
    random_state: seed for the shuffled `KFold`.
    filter_threshold: forwarded unchanged to `filter_fn`.

  Returns:
    NumPy array with one mean squared error per fold.
  """
  target = bench + '_log_error'
  splitter = KFold(n_splits=k, shuffle=True, random_state=random_state)
  fold_errors = []
  for train_idx, test_idx in splitter.split(df):
    train_groups = filter_fn(df.iloc[train_idx], filter_threshold)
    test_groups = filter_fn(df.iloc[test_idx], filter_threshold)

    # One model per group, fitted on that group's training rows only.
    fitted = {
        name: fit_ols_regression(part, features, target)
        for name, part in train_groups.items()
    }

    # Residuals (prediction minus truth) of each test group under its own model.
    per_group = [
        get_predictions(fitted[name], part, features) - part[target]
        for name, part in test_groups.items()
    ]
    fold_residuals = np.concatenate(per_group)
    fold_errors.append(np.mean(fold_residuals**2))

  return np.array(fold_errors)


def regression_with_results(df, filter_fn, features, bench, filter_threshold=None):
  """Fit the (possibly split) regression on all rows and collect fit statistics.

  Args:
    df: frame holding the feature columns and the `<bench>_log_error` target.
    filter_fn: split function called as `filter_fn(frame, filter_threshold)`;
      returns a dict mapping group name to sub-frame.
    features: feature column names.
    bench: benchmark name; the target column is `bench + '_log_error'`.
    filter_threshold: forwarded unchanged to `filter_fn`.

  Returns:
    Dict with
      'mses':      per-fold cross-validation errors (`perform_cross_validation`),
      'bic':       BIC of the sub-models combined (`combined_bic`),
      'rsquared':  pooled R-squared (`combined_rsquared`),
      'submodels': dict mapping group name to the model fitted on that group.
  """
  target = bench + '_log_error'
  groups = filter_fn(df, filter_threshold)

  submodels = {}
  for name, part in groups.items():
    submodels[name] = fit_ols_regression(part, features, target)

  mses = perform_cross_validation(df, filter_fn, features, bench, filter_threshold=filter_threshold)

  # Per-group inputs, all in the same (dict) order so they line up position by position.
  group_features = [part[features] for part in groups.values()]
  group_targets = [part[target] for part in groups.values()]
  group_models = [submodels[name] for name in groups]

  bic = combined_bic(group_features, group_targets, group_models)
  rsquared = combined_rsquared(group_features, group_targets, group_models)
  return {'mses': mses, 'bic': bic, 'rsquared': rsquared, 'submodels': submodels}


def boostrapped_regression_with_results(df, filter_fn, features, bench, filter_threshold=None,
                                        n_bootstrap=1000, random_state=None):
  """Run `regression_with_results` on bootstrap resamples of `df`.

  Each iteration draws `len(df)` rows with replacement and repeats the full
  fit (sub-models, cross-validation, BIC, R-squared) on that resample.

  Args:
    df: frame holding the feature columns and the `<bench>_log_error` target.
    filter_fn: split function forwarded to `regression_with_results`.
    features: feature column names.
    bench: benchmark name; the target column is `bench + '_log_error'`.
    filter_threshold: forwarded unchanged to `filter_fn`.
    n_bootstrap: number of resamples to draw (default 1000, the previously
      hard-coded value).
    random_state: optional seed or `numpy.random.Generator`. When None (the
      default) the notebook-level `rng` is used, as before.

  Returns:
    List with one `regression_with_results` result dict per resample.
  """
  # Build ONE generator up front. Passing a plain integer seed straight to
  # `DataFrame.sample` on every iteration would return the same resample each time.
  generator = rng if random_state is None else np.random.default_rng(random_state)

  bootstrap_results = []
  for _ in tqdm(range(n_bootstrap)):
    resampled_df = df.sample(frac=1, replace=True, random_state=generator)
    bootstrap_results.append(regression_with_results(resampled_df, filter_fn, features, bench, filter_threshold))
  return bootstrap_results
