In [None]:
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import itertools

from ipywidgets import interact, widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy

from . import nonparametric
from . import plotting
from . import semiparametric
from . import smoothing

# Scaling Laws

In [None]:
# Corners of the training-configuration grid that bound the compute sweep.
MIN_MODEL_SIZE, MAX_MODEL_SIZE = 4e6, 335e6
MIN_ITERS, MAX_ITERS = 1000, 128000
MIN_BATCH_SIZE, BIG_BATCH_SIZE = 1024, 2**20


def _training_flops(model_size, iters, batch_size):
  """Returns 6 * model_size * iters * batch_size * sequence length."""
  return (
      6 * model_size * iters * batch_size * nonparametric._SEQUENCE_LENGTH
  )


min_compute_budget = _training_flops(
    MIN_MODEL_SIZE, MIN_ITERS, MIN_BATCH_SIZE
)
max_compute_budget = _training_flops(
    MAX_MODEL_SIZE, MAX_ITERS, BIG_BATCH_SIZE
)

# 100 compute budgets, evenly spaced in log2 between the two extremes.
REASONABLE_COMPUTE_BUDGETS = np.logspace(
    start=np.log2(min_compute_budget),
    stop=np.log2(max_compute_budget),
    num=100,
    base=2,
)

## Analysis 1 - Diminishing Returns
For a fixed user budget and privacy budget, there is an optimal model size, even with unlimited compute.

In [None]:
PRIVACY_BUDGET = 8
USER_BUDGETS = [10**6, 10**7, 10**8, 10**9]

# Keep only the first row returned for every (compute budget, user budget)
# pair. The compute budget varies slowest, the user budget fastest.
results = pd.concat([
    nonparametric.scaling_law_query(
        compute_budget=compute_budget,
        user_budget=user_budget,
        privacy_budget=PRIVACY_BUDGET,
    ).head(n=1)
    for compute_budget, user_budget in itertools.product(
        REASONABLE_COMPUTE_BUDGETS, USER_BUDGETS
    )
]).reset_index(drop=True)

In [None]:
# Running minimum of the cross entropy over increasing compute budgets, with
# one column per data budget (relabelled as LaTeX powers of ten).
column_labels = {
    10**6: '$10^6$',
    10**7: '$10^7$',
    10**8: '$10^8$',
    10**9: '$10^9$',
}
data = (
    results.set_index(['Compute Budget', 'Data Budget'])['Cross Entropy']
    .unstack()
    .expanding()
    .min()
    .rename(columns=column_labels)
)
plotting.lineplot(
    data,
    logx=True,
    symbols=[None] * 10,
    linestyles=[':', '-.', '--', '-'],
)

# For every data budget, look up the (compute budget, model size) index entry
# with the lowest cross entropy and print the resulting array.
indexed_xent = results.set_index(
    ['Compute Budget', 'Model Size', 'Data Budget']
)[['Cross Entropy']]
model_sizes = indexed_xent.unstack().droplevel(0).idxmin().values
print(model_sizes)

# Hand-placed model-size annotations, keyed by the privacy budget in use.
annotations = {
    8: [('5.9M', (10**18.5, 4.3)), ('93M', (10**19.7, 2.8))],
    1: [('5.9M', (10**18.5, 4.2)), ('114M', (10**19.7, 2.8))],
}
for text, xy in annotations.get(PRIVACY_BUDGET, []):
  plt.annotate(text, xy, xytext=None, arrowprops=None)

# The y-axis label is only drawn for the privacy-budget-1 figure.
if PRIVACY_BUDGET == 1:
  plt.ylabel('Cross Entropy', fontsize='x-large')

plt.xlabel('Compute Budget (FLOPs)', fontsize='x-large')
plt.ylim(1, 7)

tmp_filename = 'vary_compute_and_user_privacy=%s.pdf' % PRIVACY_BUDGET
with open(tmp_filename, 'wb') as fh:
  plt.savefig(fh, format='pdf', dpi=300, bbox_inches='tight')

## Analysis 2 - Optimal Allocation

As a function of privacy budget and number of users, what is the:
1. saturating model size (the model size one should use even with infinite compute)
2. compute budget of diminishing returns (i.e., compute budget that is within 5% of optimal)

In [None]:
PRIVACY_BUDGETS = [1, 4, 16, 64]
USER_BUDGETS = [10**6, 10**7, 10**8, 10**9]

# Full query output for every (compute, user, privacy) budget combination.
# The compute budget varies slowest and the privacy budget fastest.
budget_grid = itertools.product(
    REASONABLE_COMPUTE_BUDGETS, USER_BUDGETS, PRIVACY_BUDGETS
)
results = pd.concat([
    nonparametric.scaling_law_query(
        compute_budget=compute_budget,
        user_budget=user_budget,
        privacy_budget=privacy_budget,
    )
    for compute_budget, user_budget, privacy_budget in budget_grid
])
results = results.reset_index(drop=True).rename(
    columns={'User Budget': 'Data Budget'}
)

In [None]:
def groupby_argmin(df, groupby_cols: list[str], min_col: str):
  """Selects, for every group, the row of `df` with the smallest `min_col`.

  Args:
    df: Frame to search.
    groupby_cols: Columns that define the groups.
    min_col: Column that is minimized within each group.

  Returns:
    The rows of `df` whose index labels are the per-group `idxmin` of
    `min_col`.
  """
  grouped_values = df.groupby(groupby_cols)[min_col]
  return df.loc[grouped_values.idxmin()]


def groupby_error_bounds(df, groupby_cols: list[str], min_col: str):
  """Computes min/max of key columns over each group's near-optimal rows.

  A row is near-optimal when its `min_col` value is at most 1.02 times the
  smallest `min_col` value in its group.

  Args:
    df: Frame holding the configurations to summarize.
    groupby_cols: Columns that define the groups.
    min_col: Column whose per-group minimum defines "optimal".

  Returns:
    A frame indexed by group, with two-level columns (reported column,
    'min' / 'max').
  """
  reported_cols = ['Iterations', 'Batch Size', 'Model Size', 'Cross Entropy']

  def _min_max_of_near_optimal(group):
    threshold = group[min_col].min() * 1.02
    near_optimal = group.loc[group[min_col] <= threshold]
    return near_optimal[reported_cols].describe().loc[['min', 'max']]

  per_group = df.groupby(groupby_cols).apply(_min_max_of_near_optimal)
  return per_group.unstack()


# Best configuration (lowest cross entropy) for each data / privacy budget.
budget_keys = ['Data Budget', 'Privacy Budget']
data = groupby_argmin(
    results, groupby_cols=budget_keys, min_col='Cross Entropy'
)
summary = data.set_index(budget_keys).loc[
    :,
    [
        'Compute Budget',
        'Cross Entropy',
        'Model Size',
        'Iterations',
        'Batch Size',
    ],
]

# Model size of the best configuration: data budgets as rows, privacy budgets
# as columns. The rounded table is the cell's displayed output.
table = summary['Model Size'].unstack()
table.round()

In [None]:
# Per (data, compute, privacy) budget: the single best row, and the min/max
# of each reported column over the near-optimal rows.
triple_keys = ['Data Budget', 'Compute Budget', 'Privacy Budget']

best_points = groupby_argmin(
    results, groupby_cols=triple_keys, min_col='Cross Entropy'
).set_index(triple_keys)
data = best_points

best_range = groupby_error_bounds(
    results, groupby_cols=triple_keys, min_col='Cross Entropy'
)

In [None]:
user_budget = 10**9
for ylabel in ('Cross Entropy', 'Model Size', 'Batch Size', 'Iterations'):
  # Everything except the cross entropy uses a log y-axis and gets `std`.
  logy = ylabel != 'Cross Entropy'

  data = best_points[ylabel].unstack().loc[user_budget][[1, 4, 16, 64]]
  std = None
  if logy:
    per_budget_bounds = best_range[ylabel].loc[user_budget]
    std = per_budget_bounds.stack().unstack('Privacy Budget').unstack()
    # Plot the 'min' edge of the near-optimal range instead of the argmin.
    data = std.xs('min', level=1, axis=1)

  plotting.lineplot(
      data,
      std=std,
      logx=True,
      logy=logy,
      symbols=[None] * 10,
      fill_between=True,
      linestyles=[':', '-.', '--', '-'],
  )
  plt.xlabel('Compute Budget')
  plt.ylabel(ylabel)
  if ylabel == 'Model Size':
    # Put ticks at the values stored in `nonparametric._MODEL_SIZES`.
    plt.yticks(list(nonparametric._MODEL_SIZES.values()))

  tmp_filename = 'optimal_%s_%dusers.pdf' % (ylabel, user_budget)
  with open(tmp_filename, 'wb') as fh:
    plt.savefig(fh, format='pdf', dpi=300, bbox_inches='tight')

## Analysis 3 - Saturating Compute Budget Table

In [None]:
PRIVACY_BUDGETS = [1, 4, 16, 64]
USER_BUDGETS = [10**5, 10**6, 10**7, 10**8, 10**9]

# First row of the query result for every (compute, user, privacy) budget
# combination. The compute budget varies slowest, the privacy budget fastest.
budget_grid = itertools.product(
    REASONABLE_COMPUTE_BUDGETS, USER_BUDGETS, PRIVACY_BUDGETS
)
results = pd.concat([
    nonparametric.scaling_law_query(
        compute_budget=compute_budget,
        user_budget=user_budget,
        privacy_budget=privacy_budget,
    ).head(n=1)
    for compute_budget, user_budget, privacy_budget in budget_grid
]).reset_index(drop=True)

In [None]:
# Find the minimum compute budget that gets within 1% of optimal xent


def find_saturating_config(df):
  """Returns the cheapest row whose cross entropy is within 1% of the best.

  Args:
    df: Frame with 'Cross Entropy' and 'Compute Budget' columns.

  Returns:
    The row (as a Series) with the smallest 'Compute Budget' among rows whose
    'Cross Entropy' is at most 1.01 times the minimum in `df`.
  """
  cutoff = 1.01 * df['Cross Entropy'].min()
  near_optimal = df.loc[df['Cross Entropy'] <= cutoff]
  cheapest_first = near_optimal.sort_values(by='Compute Budget')
  return cheapest_first.iloc[0]


# One saturating configuration per (data budget, privacy budget) pair.
budget_groups = results.groupby(['Data Budget', 'Privacy Budget'])
summary = budget_groups.apply(find_saturating_config)

# Tokens = batch size * iterations * sequence length; Ratio = tokens per
# unit of model size.
tokens = (
    summary['Batch Size']
    * summary['Iterations']
    * nonparametric._SEQUENCE_LENGTH
)
summary = summary.assign(Tokens=tokens, Ratio=tokens / summary['Model Size'])

cols = [
    'Compute Budget',
    'Cross Entropy',
    'Model Size',
    'Iterations',
    'Batch Size',
    'Tokens',
    'Ratio',
]
summary[cols]

# Analysis 4 - What is the optimal Token / Model Ratio

In [None]:
PRIVACY_BUDGETS = [1, 10, 100, 1_000, 10_000, 100_000, 1_000_000]
USER_BUDGETS = [10**7]
# 15 compute budgets from 1e16 to 1e23, i.e. two per decade.
COARSE_COMPUTE_BUDGETS = np.logspace(16, 23, num=15)

# Full query output for every budget combination. The compute budget varies
# slowest and the privacy budget fastest.
budget_grid = itertools.product(
    COARSE_COMPUTE_BUDGETS, USER_BUDGETS, PRIVACY_BUDGETS
)
results = pd.concat([
    nonparametric.scaling_law_query(
        compute_budget=compute_budget,
        user_budget=user_budget,
        privacy_budget=privacy_budget,
    )
    for compute_budget, user_budget, privacy_budget in budget_grid
]).reset_index(drop=True)

In [None]:
def groupby_argmin(df, groupby_cols: list[str], min_col: str):
  """Returns the row minimizing `min_col` within each `groupby_cols` group.

  Args:
    df: Frame to search.
    groupby_cols: Columns that define the groups.
    min_col: Column that is minimized within each group.

  Returns:
    The rows of `df` located at the per-group `idxmin` labels of `min_col`.
  """
  winning_labels = df.groupby(by=groupby_cols)[min_col].idxmin()
  best_rows = df.loc[winning_labels]
  return best_rows


def groupby_error_bounds(df, groupby_cols: list[str], min_col: str):
  """Computes min/max of key columns over each group's near-optimal rows.

  This version reports 'Token Model Ratio' in place of 'Cross Entropy'. A row
  is near-optimal when its `min_col` value is at most 1.02 times the smallest
  `min_col` value in its group.

  Args:
    df: Frame holding the configurations to summarize.
    groupby_cols: Columns that define the groups.
    min_col: Column whose per-group minimum defines "optimal".

  Returns:
    A frame indexed by group, with two-level columns (reported column,
    'min' / 'max').
  """
  reported_cols = [
      'Iterations',
      'Batch Size',
      'Model Size',
      'Token Model Ratio',
  ]

  def _min_max_of_near_optimal(group):
    threshold = group[min_col].min() * 1.02
    near_optimal = group.loc[group[min_col] <= threshold]
    return near_optimal[reported_cols].describe().loc[['min', 'max']]

  per_group = df.groupby(groupby_cols).apply(_min_max_of_near_optimal)
  return per_group.unstack()


group_keys = ['Data Budget', 'Privacy Budget', 'Compute Budget']

summary = groupby_argmin(
    results, groupby_cols=group_keys, min_col='Cross Entropy'
)
bounds = groupby_error_bounds(
    results, groupby_cols=group_keys, min_col='Cross Entropy'
)

# Drop groups whose near-optimal range reaches the maximum model size we
# evaluated.
below_largest_model = bounds['Model Size']['max'] < 3.316500e08
bounds = bounds[below_largest_model]

In [None]:
# Min/max of the near-optimal token-to-model ratio for the 10**7 data budget,
# reshaped so rows are compute budgets and columns are
# (privacy budget, 'min' / 'max').
std = (
    bounds['Token Model Ratio']
    .loc[10**7]
    .stack()
    .unstack('Privacy Budget')
    .unstack()
)

# Token-to-model ratio of the best configuration per compute budget, one
# column per privacy budget, restricted to model sizes below 3e8.
# (A relabelled `std.xs('min', ...)` used to be assigned to `data` here first;
# it was overwritten by this statement before being read, so it was dropped.)
data = (
    summary[summary['Model Size'] < 3e08]
    .set_index(['Privacy Budget', 'Compute Budget'])['Token Model Ratio']
    .unstack('Privacy Budget')
)

In [None]:
# Token-to-model ratio of the best configuration per compute budget, one
# column per privacy budget, restricted to model sizes below 3e8.
small_models = summary.loc[summary['Model Size'] < 3e08]
ratio_by_budget = small_models.set_index(
    ['Privacy Budget', 'Compute Budget']
)['Token Model Ratio']
data = ratio_by_budget.unstack('Privacy Budget')

In [None]:
# Relabel the privacy-budget level of the columns as LaTeX powers of ten.
# NOTE(review): the lambda only rewrites labels that are Python `int`s >= 100.
# numpy integer labels are not instances of `int` and would pass through
# unchanged, so confirm the label type produced upstream.
STD = std.rename(
    columns=lambda x: '$10^{%d}$' % np.log10(x).round()
    if isinstance(x, int) and x >= 100
    else x,
    level=0,
)

# Raw string: '\i' is not a recognized escape sequence, so the non-raw
# spelling triggers a SyntaxWarning on recent Python versions. The string
# value is unchanged.
chinchilla_label = r'$\infty$ (Chinchilla)'

# Constant reference series at a ratio of 20, with identical min and max.
STD[(chinchilla_label, 'min')] = STD[(chinchilla_label, 'max')] = 20
DATA = STD.xs('min', level=1, axis=1)
DATA[chinchilla_label] = 20

In [None]:
plotting.lineplot(DATA, STD, logx=True, logy=True, ncol=4, fill_between=True)
plt.xlabel('Compute Budget')
plt.ylabel('Tokens / Model Size')
plt.ylim(5, 5e6)

# Rebuild the legend with its eight entries interleaved (0, 4, 1, 5, ...),
# presumably so they read left to right across the four legend columns.
handles, labels = plt.gca().get_legend_handles_labels()
order = [0, 4, 1, 5, 2, 6, 3, 7]
legend_style = dict(
    title='Privacy Budget',
    frameon=True,
    handlelength=1,
    handletextpad=0.3,
    borderpad=0.5,
    fontsize='large',
    ncol=4,
)
plt.legend(
    [handles[pos] for pos in order],
    [labels[pos] for pos in order],
    **legend_style,
)

# NOTE(review): the filename says 1000000 users, but the cells above select
# the 10**7 data budget. Confirm which is intended before renaming the file.
tmp_filename = 'token_model_ratios_1000000users.pdf'
with open(tmp_filename, 'wb') as fh:
  plt.savefig(fh, format='pdf', dpi=300, bbox_inches='tight')

# Interactive Visualization

1. Cross Entropy vs. Iterations (minimizing over Batch Size / Model Size)
2. Cross Entropy vs. Batch Size (minimizing over Iterations / Model Size)
3. Cross Entropy vs. Model Size (minimizing over Iterations / Batch Size)

In [None]:
def update_plot(USER_BUDGET, PRIVACY_BUDGET, COMPUTE_BUDGET):
  """Draws cross entropy against iterations, batch size and model size.

  Queries the scaling law for the given budgets and draws one panel per
  quantity. Each panel shows the lowest cross entropy found at every value of
  that quantity.

  Args:
    USER_BUDGET: Forwarded to the query as `user_budget`.
    PRIVACY_BUDGET: Forwarded to the query as `privacy_budget`.
    COMPUTE_BUDGET: Forwarded to the query as `compute_budget`.
  """
  results = nonparametric.scaling_law_query(
      user_budget=USER_BUDGET,
      privacy_budget=PRIVACY_BUDGET,
      compute_budget=COMPUTE_BUDGET,
  )

  # Snap iteration counts to the nearest integer power of 2**0.25 so that
  # nearby values fall into the same group.
  base = 2**0.25
  log_steps = np.log(results['Iterations']) / np.log(base)
  results['Iterations'] = base ** log_steps.round()

  panels = ['Iterations', 'Batch Size', 'Model Size']
  _, axes = plt.subplots(1, 3, figsize=(15, 5))
  for position, (xlabel, ax) in enumerate(zip(panels, axes)):
    best_xent = results.groupby(xlabel)[['Cross Entropy']].min()
    plotting.lineplot(
        best_xent, ax=ax, symbols=[None], logx=True, legend=False
    )
    # Only the first panel ('Iterations') carries the y-axis label.
    if position == 0:
      ax.set_ylabel('Cross Entropy')
    ax.set_xlabel(xlabel)

  plt.show()


# Logarithmic sliders, one per budget. `min` / `max` are exponents of `base`.
user_budget_slider = widgets.FloatLogSlider(
    description='User Budget:', base=2, min=20, max=30, step=1, value=2**25
)
privacy_budget_slider = widgets.FloatLogSlider(
    description='Privacy Budget:', base=2, min=-6, max=6, step=0.1, value=1
)
compute_budget_slider = widgets.FloatLogSlider(
    description='Compute Budget:',
    base=10,
    min=16,
    max=23,
    step=0.1,
    value=10**21.1,
)

# Redraw the plot whenever any of the sliders changes.
slider_by_argument = {
    'USER_BUDGET': user_budget_slider,
    'PRIVACY_BUDGET': privacy_budget_slider,
    'COMPUTE_BUDGET': compute_budget_slider,
}
interact(update_plot, **slider_by_argument)

# Analysis 5 - Full 3x3 Grid

In [None]:
# Base point of the grid: each row of plots varies one of these budgets while
# the other two stay fixed.
user_budget, privacy_budget, compute_budget = 1e7, 4, 1e19


def format_legend(x):
  """Formats a budget value as a legend label, bolding the base-point value.

  Reads the module-level `lines`, `privacy_budget`, `user_budget` and
  `compute_budget`.

  Args:
    x: A value from one of the lists in `lines`.

  Returns:
    For privacy budgets, a bold LaTeX string when `x` equals `privacy_budget`,
    otherwise `x` itself, unchanged and not converted to a string. For data
    and compute budgets, '$10^{k}$' with k the rounded log10 of `x`, wrapped
    in bold when `x` equals the corresponding base-point budget.

  Raises:
    ValueError: If `x` is not in any of the lists in `lines`.
  """
  # Raw string: '\m' is not a recognized escape sequence, so the non-raw
  # spelling triggers a SyntaxWarning on recent Python versions. The string
  # value is unchanged.
  bold = r'$\mathbf{%s}$'

  if x in lines['Privacy Budget']:
    text = str(x)
    return bold % text if x == privacy_budget else x
  elif x in lines['Data Budget']:
    text = '$10^{%d}$' % np.log10(x).round()
    # text[1:-1] strips the surrounding '$' before re-wrapping in bold.
    return bold % text[1:-1] if x == user_budget else text
  elif x in lines['Compute Budget']:
    # float() first: the largest compute budget does not fit in 64-bit ints.
    text = '$10^{%d}$' % np.log10(float(x)).round()
    return bold % text[1:-1] if x == compute_budget else text
  else:
    raise ValueError(x)


# Keyword arguments of the base-point query; one entry is overridden per line.
inputs = dict(
    user_budget=user_budget,
    privacy_budget=privacy_budget,
    compute_budget=compute_budget,
)

# Columns of the grid (x-axis quantity) and rows of the grid (swept budget,
# with the values that become the individual lines).
xlabels = ['Iterations', 'Model Size', 'Batch Size']
lines = {
    'Privacy Budget': [1, 4, 16, 64],
    'Data Budget': [10**6, 10**7, 10**8, 10**9],
    'Compute Budget': [10**17, 10**19, 10**21, 10**23],
}

# Result column name -> keyword argument of `scaling_law_query`.
names = {
    'Privacy Budget': 'privacy_budget',
    'Data Budget': 'user_budget',
    'Compute Budget': 'compute_budget',
}

fig, axes = plt.subplots(
    nrows=3, ncols=3, figsize=(12 * 1.6, 12), sharey=True
)

for i, xlabel in enumerate(xlabels):
  for j, legend in enumerate(lines):
    ax = axes[j][i]
    swept_argument = names[legend]

    results = pd.concat([
        nonparametric.scaling_law_query(
            **dict(inputs, **{swept_argument: key})
        )
        for key in lines[legend]
    ])

    # Snap iteration counts to the nearest integer power of 2**0.25.
    base = 2**0.25
    log_steps = np.log(results['Iterations']) / np.log(base)
    results['Iterations'] = base ** log_steps.round()

    # Lowest cross entropy at each x value, one column per swept value.
    best_xent = results.groupby([xlabel, legend])['Cross Entropy'].min()
    data = best_xent.unstack().rename(columns=format_legend)

    plotting.lineplot(
        data,
        ax=ax,
        logx=True,
        legend=False,
        symbols=[None] * 10,
        linestyles=[':', '-.', '--', '-'],
    )
    if i == 0:
      ax.set(ylabel='Cross Entropy')
    ax.set(xlabel=xlabel)

    # Only the right-most column gets a legend, anchored outside the axes.
    if i == 2:
      ax.legend(
          title=data.columns.name,
          loc='upper left',
          bbox_to_anchor=(1, 1),
          ncol=1,
          frameon=True,
          handlelength=1,
          handletextpad=0.5,
          borderpad=0.5,
          fontsize='large',
      )

plt.tight_layout()
plt.ylim(2.5, 6.5)

tmp_filename = 'optimal_3x3.pdf'
with open(tmp_filename, 'wb') as fh:
  plt.savefig(fh, format='pdf', dpi=300, bbox_inches='tight')

# Comparison to Baseline

In [None]:
# Powers of two from 1 to 8192.
PRIVACY_BUDGETS = [2**exponent for exponent in range(14)]


# Baseline name -> [model size, iterations].
baselines = {
    'BertLarge': [335e6, 7500],
    'BertMedium': [41e6, 5000],
    'BertTiny': [4.6e6, 2500],
}

frames = {}

for model_name, (model_size, iterations) in baselines.items():
  # Batch size obtained by dividing 1e19 by the remaining compute factors.
  batch_size = (
      1e19 / model_size / iterations / nonparametric._SEQUENCE_LENGTH / 6
  )
  # One row per privacy budget; every other column is constant.
  configs = pd.DataFrame({'Privacy Budget': PRIVACY_BUDGETS}).assign(**{
      'Iterations': iterations,
      'Model Size': model_size,
      'Batch Size': batch_size,
      'Data Budget': 10**7,
      'Compute Budget': (
          6
          * nonparametric._SEQUENCE_LENGTH
          * model_size
          * iterations
          * batch_size
      ),
  })

  frames[model_name] = nonparametric.batched_scaling_law_query(
      compute_configs=configs
  )

# First row of the query result per privacy budget, at two fixed compute
# budgets with 10**7 users.
# The backslash in the first label is now escaped like the second one:
# '\c' is not a recognized escape sequence and triggers a SyntaxWarning on
# recent Python versions. Both dictionary keys keep their exact values.
for frame_label, flops in [
    ('$1 \\cdot 10^{19}$ FLOPs', 1e19),
    ('$2\\cdot 10^{18}$ FLOPs', 2e18),
]:
  frames[frame_label] = pd.concat([
      nonparametric.scaling_law_query(
          user_budget=10**7,
          privacy_budget=privacy_budget,
          compute_budget=flops,
      ).head(n=1)
      for privacy_budget in PRIVACY_BUDGETS
  ])

In [None]:
from matplotlib.lines import Line2D

# Cross entropy against privacy budget, one column per entry of `frames`.
data = pd.DataFrame()
for frame_label, frame in frames.items():
  # The three leading spaces are part of the legend label.
  data['   ' + frame_label] = frame.set_index('Privacy Budget')[
      'Cross Entropy'
  ]

plotting.lineplot(data, logx=True)
plt.xlabel('Privacy Budget (Epsilon)')
plt.ylabel('Cross Entropy')

# Rebuild the legend with an invisible 'Compute Optimal' entry inserted after
# the first three handles.
handles, labels = plt.gca().get_legend_handles_labels()
dummy = Line2D([0], [0], color='none', linewidth=0, label='Compute Optimal')
legend_handles = handles[:3] + [dummy] + handles[3:]
plt.legend(
    handles=legend_handles,
    title=data.columns.name,
    loc='upper right',
    ncol=2,
    frameon=True,
    handlelength=0.5,
    handletextpad=0,
    borderpad=0.5,
    fontsize='large',
)

tmp_filename = 'baseline_compare.pdf'
with open(tmp_filename, 'wb') as fh:
  plt.savefig(fh, format='pdf', dpi=300, bbox_inches='tight')

## Analysis 6 - Vector Field Visualization

Note: For some reason the plots visualized in this notebook have small arrows.  But the downloaded plots look correct.  When commenting out the download_file colab magic, and using plt.show(), the plots render correctly in colab.

In [None]:
dpsgd_sigmas = nonparametric.load_accounting_data()

# Every (records, batch size) pair of powers of two with
# batch size <= records < 2**30, plus the implied sampling probability.
data_batch_sizes = pd.DataFrame(
    data=[[2**k, 2**j] for k in range(30) for j in range(k + 1)],
    columns=['Records', 'Batch Size'],
).assign(**{
    'Sampling Probability': lambda pairs: (
        pairs['Batch Size'] / pairs['Records']
    ),
})

# Attach the concrete (records, batch size) pairs to the accounting rows that
# share their sampling probability, then derive the plotted quantities.
accounting = pd.merge(
    dpsgd_sigmas, data_batch_sizes, on='Sampling Probability'
).assign(**{
    'Noise Batch Ratio': lambda rows: (
        rows['Noise Multiplier'] / rows['Batch Size']
    ),
    'Epochs': lambda rows: (
        rows['Iterations'] * rows['Batch Size'] / rows['Records']
    ),
})

In [None]:
# Noise-batch ratio over (batch size, epsilon) at 16000 iterations and 2**24
# records, trimmed to a fixed window of rows and columns.
fixed_slice = (accounting.Iterations == 16000) & (
    accounting.Records == 2**24
)
ratio_grid = accounting[fixed_slice].set_index(['Batch Size', 'Epsilon'])[
    'Noise Batch Ratio'
]
table = ratio_grid.unstack().iloc[-15:-2, -13:]

plotting.plot_vector_field(table)
plt.xticks(
    [1 / 64, 1 / 8, 1, 8, 64],
    labels=['1/64', '1/8', '1', '8', '64'],
)
plt.xlabel('Privacy Budget (Epsilon)', fontsize='large')
plt.ylabel('Compute Budget (Batch Size)', fontsize='large')

tmp_filename = 'accounting_privacy_vs_compute.pdf'
with open(tmp_filename, 'wb') as fh:
  plt.savefig(fh, format='pdf', dpi=300, bbox_inches='tight')

In [None]:
# Noise-batch ratio over (batch size, records) at 16000 iterations and
# epsilon 4, trimmed to a fixed window of rows and columns.
fixed_slice = (accounting.Iterations == 16000) & (accounting.Epsilon == 4.0)
ratio_grid = accounting[fixed_slice].set_index(['Batch Size', 'Records'])[
    'Noise Batch Ratio'
]
table = ratio_grid.unstack().iloc[9:22, -13:]

plotting.plot_vector_field(table)

# Power-of-ten ticks: 1e5..1e8 on the x-axis, 1e3..1e6 on the y-axis.
x_exponents = range(5, 9)
y_exponents = range(3, 7)
plt.xticks(
    [10**k for k in x_exponents],
    labels=[f'$10^{k}$' for k in x_exponents],
)
plt.yticks(
    [10**k for k in y_exponents],
    labels=[f'$10^{k}$' for k in y_exponents],
)
plt.xlabel('Data Budget (Users)', fontsize='large')
plt.ylabel('Compute Budget (Batch Size)', fontsize='large')

tmp_filename = 'accounting_data_vs_compute.pdf'
with open(tmp_filename, 'wb') as fh:
  plt.savefig(fh, format='pdf', dpi=300, bbox_inches='tight')

In [None]:
# Noise-batch ratio over (epsilon, records) at 16000 iterations and a batch
# size of 65536, trimmed to the last 13 rows and columns.
fixed_slice = (accounting.Iterations == 16000) & (
    accounting['Batch Size'] == 65536
)
ratio_grid = accounting[fixed_slice].set_index(['Epsilon', 'Records'])[
    'Noise Batch Ratio'
]
table = ratio_grid.unstack().iloc[-13:, -13:]

plotting.plot_vector_field(table)

# NOTE(review): `grid` and `labels` are assigned but not read in this cell.
even_exponents = range(0, 9, 2)
grid = [10**k for k in even_exponents]
labels = [f'$10^{k}$' for k in even_exponents]

x_exponents = range(5, 9)
plt.xticks(
    [10**k for k in x_exponents],
    labels=[f'$10^{k}$' for k in x_exponents],
)
plt.yticks(
    [1 / 64, 1 / 8, 1, 8, 64],
    labels=['1/64', '1/8', '1', '8', '64'],
)
plt.xlabel('Data Budget (Users)', fontsize='large')
plt.ylabel('Privacy Budget (Epsilon)', fontsize='large')

tmp_filename = 'accounting_data_vs_privacy.pdf'
with open(tmp_filename, 'wb') as fh:
  plt.savefig(fh, format='pdf', dpi=300, bbox_inches='tight')

## Analysis 7 - Learning Rate Ablation

In [None]:
# Load the experiment results and snap every learning rate to the nearest
# integer power of two.
df = pd.read_csv(nonparametric._EXPERIMENT_PATH).assign(**{
    'Learning Rate': lambda raw: 2 ** np.log2(raw['Learning Rate']).round(),
})

In [None]:
def groupby_argmin(df, groupby_cols: list[str], min_col: str):
  """Picks the row with the lowest `min_col` from each group of `df`.

  Args:
    df: Frame to search.
    groupby_cols: Columns that define the groups.
    min_col: Column that is minimized within each group.

  Returns:
    The rows of `df` at the per-group `idxmin` labels of `min_col`.
  """
  return df.loc[df.groupby(groupby_cols)[min_col].idxmin()]

In [None]:
# Iteration counts to plot: 100 * floor(2**(k / 2)) for k = 6..20.
iters = [100 * int(2 ** (half_step / 2)) for half_step in range(6, 21)]
model = 'BertLarge'

for nbr in (20, 15, 10):
  # Rows for this model whose noise-batch ratio rounds to 2**-nbr.
  selected = (df.Model == model) & (
      np.log2(df['Noise Batch Ratio']).round() == -nbr
  )
  xent_by_rate = (
      df[selected]
      .set_index(['Iterations', 'Learning Rate'])['Cross Entropy']
      .unstack()
  )
  # One line per learning rate, labelled as a power of 0.5.
  results = xent_by_rate.loc[iters].rename(
      lambda x: '$0.5^{%d}$' % round(-np.log2(x)), axis=1
  )

  plotting.lineplot(results, logx=True)
  # Only the nbr == 20 figure carries the y-axis label.
  if nbr == 20:
    plt.ylabel('Cross Entropy')
  plt.xlabel('Iterations')

  tmp_filename = 'learning_rate_%s_%d.pdf' % (model, nbr)
  with open(tmp_filename, 'wb') as fh:
    plt.savefig(fh, format='pdf', dpi=300, bbox_inches='tight')