In [None]:
from google3.learning.deepmind.xmanager2.client import xmanager_api
import matplotlib as mpl
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import collections
import re

xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

MEDIUM_SIZE = 18

# Global figure styling used by every plot in this notebook:
# - one font size for text, titles, axis labels, tick labels and figure titles;
# - a slightly smaller legend font;
# - a white axes background framed by 2pt grey spines.
# `plt.rcParams` is `mpl.rcParams`, so a single update covers both.
mpl.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': MEDIUM_SIZE,
    'axes.titlesize': MEDIUM_SIZE,
    'axes.labelsize': MEDIUM_SIZE,
    'axes.linewidth': 2,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
    'xtick.labelsize': MEDIUM_SIZE,
    'ytick.labelsize': MEDIUM_SIZE,
    'legend.fontsize': MEDIUM_SIZE - 5,
    'figure.titlesize': MEDIUM_SIZE,
})

elegant_palette = sns.color_palette('muted')

In [None]:
def read_xm_metrics(example_xid, metric_name, unit_id, lowest=True, client=None):
  """Reads one measurement series from an XManager work unit.

  Args:
    example_xid: XManager experiment id.
    metric_name: Label of the measurement series to read.
    unit_id: Work unit id, as accepted by `experiment.get_work_unit`.
    lowest: If True, return the minimum objective value of the series;
      otherwise return all objective values in series order.
    client: XManager API client to query. Defaults to the notebook-level
      `xm_client`.

  Returns:
    The minimum objective value (``lowest=True``) or the list of objective
    values (``lowest=False``) of the first series labelled `metric_name`.
    Returns None when no series has that label, and also when
    ``lowest=True`` and the series contains no measurements.
  """
  if client is None:
    client = xm_client
  work_unit = client.get_experiment(example_xid).get_work_unit(unit_id)
  for series in work_unit.list_measurement_series():
    if series.label != metric_name:
      continue
    values = [measurement.objective_value for measurement in series.measurements]
    if lowest:
      # min() of an empty list raises; report "no value" like a missing series.
      return min(values, default=None)
    return values
  return None


def add_min_columns(
    df,
    columns=(
        'valid_mean_absolute_error_all',
        'valid_mean_absolute_error_masked',
        'valid_mean_squared_error_all',
        'valid_mean_squared_error_masked',
    ),
):
  """Adds a `min_<column>` column holding the minimum of each list-valued cell.

  `df` is modified in place and also returned, so `df = add_min_columns(df)`
  keeps working.

  Args:
    df: DataFrame whose `columns` hold one sequence of values per row (or
      None when a metric series was not found).
    columns: Names of the list-valued columns to reduce. Defaults to the four
      whole-validation-set error metrics.

  Returns:
    `df`, with one new `min_<column>` column per entry in `columns`. Rows whose
    cell is None or an empty sequence get NaN.
  """

  def min_of_list(values):
    # A missing series (None) or a series without points has no minimum.
    if values is None or len(values) == 0:
      return float('nan')
    return min(values)

  for column in columns:
    df[f'min_{column}'] = df[column].apply(min_of_list)
  return df


def process_string_metric(input_string):
  """Shortens a metric name for use as an axis label.

  Abbreviates 'mean_absolute_error' to 'mae' and 'mean_squared_error' to
  'mse', removes every occurrence of 'valid_', and turns '/' into '_'. For
  example 'forecast_0.1_eval/valid_mean_absolute_error_all' becomes
  'forecast_0.1_eval_mae_all'.
  """
  shortened = input_string
  for long_name, abbreviation in (
      ('mean_absolute_error', 'mae'),
      ('mean_squared_error', 'mse'),
  ):
    # Plain substring replacement: these names contain no regex metacharacters.
    shortened = shortened.replace(long_name, abbreviation)
  without_split_prefix = shortened.replace('valid_', '')
  return without_split_prefix.replace('/', '_')

In [None]:
# @title Data Scaling

# XManager experiment id -> [model size name, parameter count (millions),
# patch size].
xm_id_dict = {
    # 122552298: ['Deb', 0.11, '10x5'],
    122552990: ['Tiny', 2.21, '10x5'],
    122552440: ['ExtraSmall', 7.3, '10x5'],
    122552749: ['Small', 24.6, '10x5'],
    122527956: ['Base', 110.74, '10x5'],
    # 122528949: ['Large', 328.13, '10x5'],
}


metric_names = [
    'valid_mean_absolute_error_all',
    'valid_mean_absolute_error_masked',
    'valid_mean_squared_error_all',
    'valid_mean_squared_error_masked',
    'forecast_0.1_eval/valid_mean_absolute_error_all',
    'forecast_0.1_eval/valid_mean_absolute_error_masked',
    'forecast_0.1_eval/valid_mean_squared_error_all',
    'forecast_0.1_eval/valid_mean_squared_error_masked',
    'forecast_0.2_eval/valid_mean_absolute_error_all',
    'forecast_0.2_eval/valid_mean_absolute_error_masked',
    'forecast_0.2_eval/valid_mean_squared_error_all',
    'forecast_0.2_eval/valid_mean_squared_error_masked',
    'forecast_0.4_eval/valid_mean_absolute_error_all',
    'forecast_0.4_eval/valid_mean_absolute_error_masked',
    'forecast_0.4_eval/valid_mean_squared_error_all',
    'forecast_0.4_eval/valid_mean_squared_error_masked',
    'imputation_0.1_eval/valid_mean_absolute_error_all',
    'imputation_0.1_eval/valid_mean_absolute_error_masked',
    'imputation_0.1_eval/valid_mean_squared_error_all',
    'imputation_0.1_eval/valid_mean_squared_error_masked',
    'imputation_0.2_eval/valid_mean_absolute_error_all',
    'imputation_0.2_eval/valid_mean_absolute_error_masked',
    'imputation_0.2_eval/valid_mean_squared_error_all',
    'imputation_0.2_eval/valid_mean_squared_error_masked',
    'imputation_0.4_eval/valid_mean_absolute_error_all',
    'imputation_0.4_eval/valid_mean_absolute_error_masked',
    'imputation_0.4_eval/valid_mean_squared_error_all',
    'imputation_0.4_eval/valid_mean_squared_error_masked',
]

# One record per work unit, containing:
# - identifying columns ('unit_id' is the id passed to `get_work_unit`);
# - every work-unit parameter;
# - the full list of logged values for each metric in `metric_names`
#   (None when the work unit has no series with that label).
# Building a list of dicts keeps rows aligned even if work units expose
# different parameter keys; missing keys become NaN in the DataFrame.
run_records = []
wanted_metrics = set(metric_names)
for xm_id, (model_size, param_size, patch_size) in xm_id_dict.items():
  experiment = xm_client.get_experiment(xm_id)
  # Work unit ids are 1-based.
  for unit_id in range(1, experiment.get_num_work_units() + 1):
    work_unit = experiment.get_work_unit(unit_id)
    record = {
        'unit_id': unit_id,
        'xm_id': xm_id,
        'Param Size': param_size,
        'Model Size': model_size,
        'Patch Size': patch_size,
    }
    for param_name in work_unit.parameters.keys():
      record[param_name] = work_unit.parameters[param_name]
    # Read the measurement series once per work unit (not once per metric),
    # keeping the first series seen for each wanted label.
    series_values = {}
    for series in work_unit.list_measurement_series():
      if series.label in wanted_metrics and series.label not in series_values:
        series_values[series.label] = [
            measurement.objective_value for measurement in series.measurements
        ]
    for metric in metric_names:
      record[metric] = series_values.get(metric)
    run_records.append(record)

df = pd.DataFrame(run_records)
df = add_min_columns(df)
df

##

In [None]:
# Compute / data / model-size scaling plots, one figure per metric, restricted
# to runs trained WITH train-time augmentations.
use_aug = True
use_last = True  # True: plot the last logged value; False: the minimum value.
sample_size = 1321235  # Not used by the plots below.
# Despite the name, these are used below as training-step counts: each one is
# mapped to a position in the logged series and converted to core hours.
compute_hours = [1000, 5000, 10000, 15000, 25000, 30000, 40000, 50000]
default_marker_color = '#1967d2'
# NOTE(review): these conversion constants are carried over from the original
# code; confirm them before publishing. The code assumes:
# - every series is logged uniformly over `total_train_steps` steps;
# - `steps_per_core_hour` steps cost one TPU core hour;
# - one training sample covers `hours_per_sample` hours of data.
total_train_steps = 100_000
steps_per_core_hour = 600
hours_per_sample = 5

num_samples_col = 'config.dataset_configs.train_num_samples'
selected_runs = df[df['config.use_train_augmentations'] == use_aug]
# The compute-panel x positions are the same for every run and metric.
compute_x = [int(step / steps_per_core_hour) for step in compute_hours]

for metric_name in metric_names:
  displayed_metric = process_string_metric(metric_name)
  # Drop runs whose series for this metric is missing (None) or empty: they
  # cannot be indexed or reduced below.
  has_series = (
      selected_runs[metric_name]
      .map(lambda values: isinstance(values, list) and len(values) > 0)
      .astype(bool)
  )
  runs = selected_runs.loc[has_series]

  fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=600)

  # Panel 1: compute scaling. Plots the metric at each checkpoint, one curve
  # per trained (data size, model size) pair.
  for data_size in runs[num_samples_col].unique():
    for model_size in runs['Model Size'].unique():
      subset = runs[
          (runs['Model Size'] == model_size)
          & (runs[num_samples_col] == data_size)
      ]
      if subset.empty:  # Not every (data size, model size) pair was trained.
        continue
      series = subset.iloc[0][metric_name]
      y = [
          series[int((step - 1) / total_train_steps * len(series))]
          for step in compute_hours
      ]
      sns.scatterplot(
          x=compute_x,
          y=y,
          color=default_marker_color,
          ax=axes[0],
          s=150,
          alpha=0.5,
      )
      sns.lineplot(
          x=compute_x, y=y, color=default_marker_color, ax=axes[0], alpha=0.2
      )
  axes[0].set_title('Compute Scaling')
  axes[0].set_xlabel(r'$\mathbf{Compute}$' + '\n TPU v5 VLP core hours')
  axes[0].set_ylabel(displayed_metric)

  # Panel 2: data scaling. One curve per model size over training-set size.
  for model_size in runs['Model Size'].unique():
    subset = runs[runs['Model Size'] == model_size]
    x = [
        round(num_samples / 1_000_000 * hours_per_sample, 2)
        for num_samples in subset[num_samples_col]
    ]
    y = [
        values[-1] if use_last else min(values)
        for values in subset[metric_name]
    ]
    sns.scatterplot(
        x=x, y=y, color=default_marker_color, ax=axes[1], s=150, alpha=0.5
    )
    sns.lineplot(x=x, y=y, color=default_marker_color, ax=axes[1], alpha=0.2)
  axes[1].set_title('Data Scaling')
  axes[1].set_xlabel(r'$\mathbf{Data\ Size}$' + '\n(Million Hours)')

  # Panel 3: model-size scaling. One curve per training-set size over
  # parameter count.
  for data_size in runs[num_samples_col].unique():
    subset = runs[runs[num_samples_col] == data_size]
    x = subset['Param Size'].tolist()
    y = [
        values[-1] if use_last else min(values)
        for values in subset[metric_name]
    ]
    sns.scatterplot(
        x=x, y=y, color=default_marker_color, ax=axes[2], s=150, alpha=0.5
    )
    sns.lineplot(x=x, y=y, color=default_marker_color, ax=axes[2], alpha=0.2)
  axes[2].set_title('Model Size Scaling')
  axes[2].set_xlabel(r'$\mathbf{Model\ Size}$' + '\n(Million of Params)')

  plt.tight_layout()
  plt.show()
  plt.close(fig)  # Release each 600-dpi figure; one is created per metric.

In [None]:
# Compute / data / model-size scaling plots, one figure per metric, restricted
# to runs trained WITHOUT train-time augmentations.
use_aug = False
use_last = True  # True: plot the last logged value; False: the minimum value.
sample_size = 1321235  # Not used by the plots below.
# Despite the name, these are used below as training-step counts: each one is
# mapped to a position in the logged series and converted to core hours.
compute_hours = [1000, 5000, 10000, 15000, 25000, 30000, 40000, 50000]
default_marker_color = '#1967d2'
# NOTE(review): these conversion constants are carried over from the original
# code; confirm them before publishing. The code assumes:
# - every series is logged uniformly over `total_train_steps` steps;
# - `steps_per_core_hour` steps cost one TPU core hour;
# - one training sample covers `hours_per_sample` hours of data.
total_train_steps = 100_000
steps_per_core_hour = 600
hours_per_sample = 5

num_samples_col = 'config.dataset_configs.train_num_samples'
selected_runs = df[df['config.use_train_augmentations'] == use_aug]
# The compute-panel x positions are the same for every run and metric.
compute_x = [int(step / steps_per_core_hour) for step in compute_hours]

for metric_name in metric_names:
  displayed_metric = process_string_metric(metric_name)
  # Drop runs whose series for this metric is missing (None) or empty: they
  # cannot be indexed or reduced below.
  has_series = (
      selected_runs[metric_name]
      .map(lambda values: isinstance(values, list) and len(values) > 0)
      .astype(bool)
  )
  runs = selected_runs.loc[has_series]

  fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=600)

  # Panel 1: compute scaling. Plots the metric at each checkpoint, one curve
  # per trained (data size, model size) pair.
  for data_size in runs[num_samples_col].unique():
    for model_size in runs['Model Size'].unique():
      subset = runs[
          (runs['Model Size'] == model_size)
          & (runs[num_samples_col] == data_size)
      ]
      if subset.empty:  # Not every (data size, model size) pair was trained.
        continue
      series = subset.iloc[0][metric_name]
      y = [
          series[int((step - 1) / total_train_steps * len(series))]
          for step in compute_hours
      ]
      sns.scatterplot(
          x=compute_x,
          y=y,
          color=default_marker_color,
          ax=axes[0],
          s=150,
          alpha=0.5,
      )
      sns.lineplot(
          x=compute_x, y=y, color=default_marker_color, ax=axes[0], alpha=0.2
      )
  axes[0].set_title('Compute Scaling')
  axes[0].set_xlabel(r'$\mathbf{Compute}$' + '\n TPU v5 VLP core hours')
  axes[0].set_ylabel(displayed_metric)

  # Panel 2: data scaling. One curve per model size over training-set size.
  for model_size in runs['Model Size'].unique():
    subset = runs[runs['Model Size'] == model_size]
    x = [
        round(num_samples / 1_000_000 * hours_per_sample, 2)
        for num_samples in subset[num_samples_col]
    ]
    y = [
        values[-1] if use_last else min(values)
        for values in subset[metric_name]
    ]
    sns.scatterplot(
        x=x, y=y, color=default_marker_color, ax=axes[1], s=150, alpha=0.5
    )
    sns.lineplot(x=x, y=y, color=default_marker_color, ax=axes[1], alpha=0.2)
  axes[1].set_title('Data Scaling')
  axes[1].set_xlabel(r'$\mathbf{Data\ Size}$' + '\n(Million Hours)')

  # Panel 3: model-size scaling. One curve per training-set size over
  # parameter count.
  for data_size in runs[num_samples_col].unique():
    subset = runs[runs[num_samples_col] == data_size]
    x = subset['Param Size'].tolist()
    y = [
        values[-1] if use_last else min(values)
        for values in subset[metric_name]
    ]
    sns.scatterplot(
        x=x, y=y, color=default_marker_color, ax=axes[2], s=150, alpha=0.5
    )
    sns.lineplot(x=x, y=y, color=default_marker_color, ax=axes[2], alpha=0.2)
  axes[2].set_title('Model Size Scaling')
  axes[2].set_xlabel(r'$\mathbf{Model\ Size}$' + '\n(Million of Params)')

  plt.tight_layout()
  plt.show()
  plt.close(fig)  # Release each 600-dpi figure; one is created per metric.