In [None]:
from google3.learning.deepmind.xmanager2.client import xmanager_api
import matplotlib as mpl
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import collections

xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

# Font sizes. Only MEDIUM_SIZE is applied to the style below.
SMALL_SIZE = 16
MEDIUM_SIZE = 18
BIGGER_SIZE = 20

# Global matplotlib style, set once in a single place so the effective value of
# every key is unambiguous.
mpl.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': MEDIUM_SIZE,  # default text size
    'axes.titlesize': MEDIUM_SIZE,  # axes title
    'axes.labelsize': MEDIUM_SIZE,  # x and y axis labels
    'xtick.labelsize': MEDIUM_SIZE,  # x tick labels
    'ytick.labelsize': MEDIUM_SIZE,  # y tick labels
    'legend.fontsize': MEDIUM_SIZE,  # legend text
    'figure.titlesize': MEDIUM_SIZE,  # figure suptitle
    'axes.linewidth': 2,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
})

# NOTE(review): elegant_palette, SMALL_SIZE and BIGGER_SIZE are not referenced
# by the plotting cells in this notebook (they use palette='Blues').
elegant_palette = sns.color_palette('muted')

In [None]:
def read_xm_metrics(example_xid, metric_name, unit_id, lowest=True, client=None):
  """Reads the objective values of one metric from an XManager work unit.

  Args:
    example_xid: XManager experiment id.
    metric_name: Label of the measurement series to read.
    unit_id: Id of the work unit within the experiment, as passed to
      `get_work_unit`.
    lowest: If True, return only the minimum objective value of the series;
      otherwise return all objective values in the order
      `series.measurements` yields them.
    client: Object providing `get_experiment`. Defaults to the notebook-level
      `xm_client`; pass a different client (e.g. a fake) to avoid the global.

  Returns:
    The minimum value (``lowest=True``) or the list of values
    (``lowest=False``). Returns None when the work unit has no series labelled
    `metric_name`, or when ``lowest=True`` and that series is empty.
  """
  if client is None:
    client = xm_client
  work_unit = client.get_experiment(example_xid).get_work_unit(unit_id)
  for series in work_unit.list_measurement_series():
    if series.label != metric_name:
      continue
    values = [measurement.objective_value for measurement in series.measurements]
    # min() of an empty list raises ValueError; report "no data" as None, the
    # same way a missing series is reported below.
    return min(values, default=None) if lowest else values
  # No series with this label: make the implicit None explicit.
  return None



In [None]:
# @title Data Scaling: load results from XManager

# One XManager experiment per model size:
#   name -> [experiment id, parameter count in millions]
# The parameter count is what the model-scaling plots put on their x axis.
xm_id_dict = {
    'Large': [117136753, 328.13],
    'Base': [117137176, 110.74],
    'Small': [117137094, 24.59],
    'Tiny': [117137514, 2.22],
}

metric_names = [
    'valid_mean_absolute_error_all',
    'valid_mean_absolute_error_masked',
    'valid_mean_squared_error_all',
    'valid_mean_squared_error_masked',
]

# Build one record (dict) per work unit and create the DataFrame once at the
# end. Unlike parallel per-column lists, records tolerate work units whose
# hyperparameter keys differ: missing entries become NaN instead of making
# pd.DataFrame fail on columns of mismatched length.
records = []
for model_size, (xm_id, param_size) in xm_id_dict.items():
  experiment = xm_client.get_experiment(xm_id)
  # Work units are queried with ids 1..N.
  for unit_id in range(1, experiment.get_num_work_units() + 1):
    work_unit = experiment.get_work_unit(unit_id)
    record = {
        # The id actually passed to get_work_unit, so a row can be re-queried.
        'unit_id': unit_id,
        'xm_id': xm_id,
        'Param Size': param_size,
        'Model Size': model_size,
    }
    # One column per work-unit hyperparameter.
    record.update(work_unit.parameters)
    for metric in metric_names:
      # Full metric history (lowest=False); plotting cells take min() of it or
      # index into it.
      record[metric] = read_xm_metrics(xm_id, metric, unit_id, lowest=False)
    records.append(record)

df = pd.DataFrame(records)
df

In [None]:
metric_name = 'valid_mean_absolute_error_all'
metric_name_short = 'MAE All Patches'
sample_size = 5_000_000
compute_hours = [0, 7000, 12500, 18750, 25000, 50000, 75000, 100000]
# Length of a whole training run on the compute_hours scale; each entry of
# compute_hours is mapped to the same fraction of the run's logged history.
total_compute_hours = 100_000

# Three panels in a row: compute, data and model-size scaling.
fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=600)

# Panel 1: compute scaling -- metric at intermediate points of one Large run.
subset = df[
    (df['config.dataset_configs.train_num_samples'] == sample_size)
    & (df['Model Size'] == 'Large')
]
if subset.empty:
  raise ValueError(f'No Large run with train_num_samples == {sample_size}.')
# NOTE(review): if several runs match (e.g. a hyperparameter sweep), only the
# first one is plotted.
history = subset.iloc[0][metric_name]
num_of_logging = len(history)
# `h - 1` keeps h == total_compute_hours on the last entry instead of one past
# the end; the clamp keeps every index inside the history for any h.
logging_in_steps = [
    min(max(int((h - 1) / total_compute_hours * num_of_logging), 0),
        num_of_logging - 1)
    for h in compute_hours
]
# NOTE(review): tick labels are compute_hours / 250; confirm this factor
# converts to the "TPU v5 VLP core hours" unit named in the x label.
x = [int(h / 250) for h in compute_hours]
y = [history[idx] for idx in logging_in_steps]

sns.barplot(x=x, y=y, palette='Blues', ax=axes[0])
axes[0].set_title('Compute Scaling')
axes[0].set_xlabel(r'$\mathbf{Compute}$' + '\n TPU v5 VLP core hours')
axes[0].set_ylabel(metric_name_short)

# Panel 2: data scaling -- best (minimum) value of every Large run.
subset = df[df['Model Size'] == 'Large']
# NOTE(review): train_num_samples / 1e6 * 5 presumably assumes 5 hours per
# sample -- confirm.
x = (
    subset['config.dataset_configs.train_num_samples'] / 1_000_000 * 5
).tolist()
y = subset[metric_name].map(min).tolist()

sns.barplot(x=x, y=y, palette='Blues', ax=axes[1])
axes[1].set_title('Data Scaling')
axes[1].set_xlabel(r'$\mathbf{Data\ Size}$' + '\n(Million Hours)')

# Panel 3: model-size scaling -- best value of every run at sample_size.
subset = df[df['config.dataset_configs.train_num_samples'] == sample_size]
x = subset['Param Size'].tolist()
y = subset[metric_name].map(min).tolist()

sns.barplot(x=x, y=y, palette='Blues', ax=axes[2])
axes[2].set_title('Model Size Scaling')
axes[2].set_xlabel(r'$\mathbf{Model\ Size}$' + '\n(Million Params)')

plt.tight_layout()
plt.show()


In [None]:
metric_name = 'valid_mean_absolute_error_masked'
metric_name_short = 'MAE Masked Patches'
sample_size = 5_000_000
compute_hours = [0, 7000, 12500, 18750, 25000, 50000, 75000, 100000]
# Length of a whole training run on the compute_hours scale; each entry of
# compute_hours is mapped to the same fraction of the run's logged history.
total_compute_hours = 100_000

# Three panels in a row: compute, data and model-size scaling.
fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=600)

# Panel 1: compute scaling -- metric at intermediate points of one Large run.
subset = df[
    (df['config.dataset_configs.train_num_samples'] == sample_size)
    & (df['Model Size'] == 'Large')
]
if subset.empty:
  raise ValueError(f'No Large run with train_num_samples == {sample_size}.')
# NOTE(review): if several runs match (e.g. a hyperparameter sweep), only the
# first one is plotted.
history = subset.iloc[0][metric_name]
num_of_logging = len(history)
# `h - 1` keeps h == total_compute_hours on the last entry instead of one past
# the end; the clamp keeps every index inside the history for any h.
logging_in_steps = [
    min(max(int((h - 1) / total_compute_hours * num_of_logging), 0),
        num_of_logging - 1)
    for h in compute_hours
]
# NOTE(review): tick labels are compute_hours / 250; confirm this factor
# converts to the "TPU v5 VLP core hours" unit named in the x label.
x = [int(h / 250) for h in compute_hours]
y = [history[idx] for idx in logging_in_steps]

sns.barplot(x=x, y=y, palette='Blues', ax=axes[0])
axes[0].set_title('Compute Scaling')
axes[0].set_xlabel(r'$\mathbf{Compute}$' + '\n TPU v5 VLP core hours')
axes[0].set_ylabel(metric_name_short)

# Panel 2: data scaling -- best (minimum) value of every Large run.
subset = df[df['Model Size'] == 'Large']
# NOTE(review): train_num_samples / 1e6 * 5 presumably assumes 5 hours per
# sample -- confirm.
x = (
    subset['config.dataset_configs.train_num_samples'] / 1_000_000 * 5
).tolist()
y = subset[metric_name].map(min).tolist()

sns.barplot(x=x, y=y, palette='Blues', ax=axes[1])
axes[1].set_title('Data Scaling')
axes[1].set_xlabel(r'$\mathbf{Data\ Size}$' + '\n(Million Hours)')

# Panel 3: model-size scaling -- best value of every run at sample_size.
subset = df[df['config.dataset_configs.train_num_samples'] == sample_size]
x = subset['Param Size'].tolist()
y = subset[metric_name].map(min).tolist()

sns.barplot(x=x, y=y, palette='Blues', ax=axes[2])
axes[2].set_title('Model Size Scaling')
axes[2].set_xlabel(r'$\mathbf{Model\ Size}$' + '\n(Million Params)')

plt.tight_layout()
plt.show()


In [None]:
#@title Compute Scaling across data sizes and model sizes

metric_name = 'valid_mean_absolute_error_all'
metric_name_short = 'MAE All Patches'
compute_hours = [0, 7000, 12500, 18750, 25000, 50000, 75000, 100000]
# Length of a whole training run on the compute_hours scale; each entry of
# compute_hours is mapped to the same fraction of a run's logged history.
total_compute_hours = 100_000

unique_num_samples = df['config.dataset_configs.train_num_samples'].unique()
unique_model_sizes = df['Model Size'].unique()
n_rows = len(unique_num_samples)
n_cols = len(unique_model_sizes)

# One panel per (data size, model size). squeeze=False always returns a 2-D
# array of Axes; without it, a single row or column comes back as a 1-D array,
# which cannot be indexed as a grid.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(24, 6 * n_rows), dpi=600, squeeze=False
)

for i, num_samples in enumerate(unique_num_samples):
  for j, model_size in enumerate(unique_model_sizes):
    ax = axes[i, j]
    title = f'Data: {num_samples}\nModel Size: {model_size}'

    subset = df[
        (df['config.dataset_configs.train_num_samples'] == num_samples)
        & (df['Model Size'] == model_size)
    ]
    if subset.empty:
      # No run has this combination; leave a titled, empty panel instead of
      # failing on subset.iloc[0].
      ax.set_title(title)
      ax.set_axis_off()
      continue

    # NOTE(review): if several runs match, only the first one is plotted.
    history = subset.iloc[0][metric_name]
    num_of_logging = len(history)
    # `h - 1` keeps h == total_compute_hours on the last entry; the clamp keeps
    # every index inside the history for any h.
    logging_in_steps = [
        min(max(int((h - 1) / total_compute_hours * num_of_logging), 0),
            num_of_logging - 1)
        for h in compute_hours
    ]
    # NOTE(review): tick labels are compute_hours / 250; confirm the factor
    # matches the "TPU v5 VLP core hours" unit in the x label.
    x = [int(h / 250) for h in compute_hours]
    y = [history[idx] for idx in logging_in_steps]

    sns.barplot(x=x, y=y, palette='Blues', ax=ax)
    ax.set_title(title)
    # Only the first column carries the y label, only the last row the x label.
    ax.set_ylabel(metric_name_short if j == 0 else '')
    if i == n_rows - 1:
      ax.set_xlabel(r'$\mathbf{Compute}$' + '\n TPU v5 VLP core hours')
    else:
      ax.set_xlabel('')

plt.tight_layout()
plt.show()

In [None]:
#@title Data Scaling across model sizes

metric_name = 'valid_mean_absolute_error_all'
metric_name_short = 'MAE All Patches'

unique_model_sizes = df['Model Size'].unique()

# One panel per model size. squeeze=False keeps `axes` 2-D even when there is
# only one model size (plt.subplots would otherwise return a bare Axes).
fig, axes = plt.subplots(
    1, len(unique_model_sizes), figsize=(24, 6), dpi=600, squeeze=False
)

for j, model_size in enumerate(unique_model_sizes):
  ax = axes[0, j]

  subset = df[df['Model Size'] == model_size]
  # NOTE(review): train_num_samples / 1e6 * 5 presumably assumes 5 hours per
  # sample -- confirm. Because of the * 5 the values are not "million samples",
  # so the axis uses the same "(Million Hours)" unit as the overview figures.
  x = (
      subset['config.dataset_configs.train_num_samples'] / 1_000_000 * 5
  ).tolist()
  # Best (minimum) value over each run's logged history.
  y = subset[metric_name].map(min).tolist()

  sns.barplot(x=x, y=y, palette='Blues', ax=ax)
  ax.set_title(f'Data Scaling\nModel Size: {model_size}')
  # Only the leftmost panel carries the y label.
  ax.set_ylabel(metric_name_short if j == 0 else '')
  ax.set_xlabel(r'$\mathbf{Data\ Size}$' + '\n(Million Hours)')

plt.tight_layout()
plt.show()


In [None]:
#@title Model Scaling All across data sizes

metric_name = 'valid_mean_absolute_error_all'
metric_name_short = 'MAE All Patches'

unique_num_samples = df['config.dataset_configs.train_num_samples'].unique()

# One panel per data size. squeeze=False keeps `axes` 2-D even when there is
# only one data size (plt.subplots would otherwise return a bare Axes).
fig, axes = plt.subplots(
    1, len(unique_num_samples), figsize=(28, 6), dpi=600, squeeze=False
)

for j, num_samples in enumerate(unique_num_samples):
  ax = axes[0, j]

  subset = df[df['config.dataset_configs.train_num_samples'] == num_samples]
  x = subset['Param Size'].tolist()
  # Best (minimum) value over each run's logged history.
  y = subset[metric_name].map(min).tolist()

  sns.barplot(x=x, y=y, palette='Blues', ax=ax)
  ax.set_title(f'Model Size Scaling\nSamples: {num_samples}')
  # Only the leftmost panel carries the y label.
  ax.set_ylabel(metric_name_short if j == 0 else '')
  ax.set_xlabel(r'$\mathbf{Model\ Size}$' + '\n(Million Params)')

plt.tight_layout()
plt.show()
