In [None]:
from google3.learning.deepmind.xmanager2.client import xmanager_api
import matplotlib as mpl
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import collections
import re
from matplotlib.lines import Line2D
import numpy as np
import warnings
from matplotlib.ticker import FixedLocator  # Import for the fix
from matplotlib.ticker import MaxNLocator

# Ignore UserWarning messages whose originating module name matches
# "matplotlib".
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

# XManager client shared by the helper functions and the loading cell below.
xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

# Base font size; every text size below is derived from it.
MEDIUM_SIZE = 12

# All matplotlib defaults in one place. `mpl.rcParams` and `plt.rcParams` are
# the same object, so a single update covers both.
mpl.rcParams.update({
    # Typeface and axes frame.
    'font.family': 'DejaVu Sans',
    'axes.linewidth': 1,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
    # Text sizes.
    'font.size': MEDIUM_SIZE,  # default text size
    'axes.titlesize': MEDIUM_SIZE,  # axes title
    'axes.labelsize': MEDIUM_SIZE,  # x and y axis labels
    'xtick.labelsize': MEDIUM_SIZE,  # x tick labels
    'ytick.labelsize': MEDIUM_SIZE,  # y tick labels
    'legend.fontsize': MEDIUM_SIZE - 5,  # legend entries
    'figure.titlesize': MEDIUM_SIZE,  # figure title
})

elegant_palette = sns.color_palette('muted')

In [None]:
def read_xm_metrics(example_xid, metric_name, unit_id, lowest=True):
  """Reads the objective values of one measurement series of a work unit.

  Args:
    example_xid: Experiment id passed to `xm_client.get_experiment`.
    metric_name: Label of the measurement series to read.
    unit_id: Work unit id passed to `experiment.get_work_unit`.
    lowest: If True, return only the smallest recorded value; otherwise return
      every recorded value in the order the series yields them.

  Returns:
    For the first series whose label equals `metric_name`: the minimum
    objective value when `lowest` is True, or the list of all objective values
    when `lowest` is False. None when no series has that label, or when
    `lowest` is True and the matching series has no measurements.
  """
  experiment = xm_client.get_experiment(example_xid)
  work_unit = experiment.get_work_unit(unit_id)
  for series in work_unit.list_measurement_series():
    if series.label != metric_name:
      continue
    values = [point.objective_value for point in series.measurements]
    if not lowest:
      return values
    # A bare min() raises ValueError on an empty series; report "no value" the
    # same way a missing label is reported instead.
    return min(values, default=None)
  # No series carries the requested label.
  return None


def add_min_columns(df):
  """Adds a `min_<column>` column for each of four validation-error columns.

  Each cell of the source columns is passed to `min`, so the cells are expected
  to be non-empty sequences of comparable values.

  Args:
    df: DataFrame containing the four `valid_mean_*_error_*` columns listed
      below. It is modified in place.

  Returns:
    The same DataFrame object, with the four new columns appended.
  """
  source_columns = (
      'valid_mean_absolute_error_all',
      'valid_mean_absolute_error_masked',
      'valid_mean_squared_error_all',
      'valid_mean_squared_error_masked',
  )
  for column in source_columns:
    # Per-row minimum of the values stored in this cell.
    df[f'min_{column}'] = df[column].apply(lambda values: min(values))
  return df


def process_string_metric(input_string):
  """Shortens a metric name into a compact, slash-free label.

  Three rewrites are applied in order: the long error names are abbreviated
  ('mean_absolute_error' -> 'mae', 'mean_squared_error' -> 'mse'), every
  occurrence of 'valid_' is removed, and every '/' becomes '_'.

  Args:
    input_string: Metric name to shorten.

  Returns:
    The shortened metric name.
  """
  shortened = input_string
  abbreviations = (
      ('mean_absolute_error', 'mae'),
      ('mean_squared_error', 'mse'),
  )
  for long_name, short_name in abbreviations:
    shortened = shortened.replace(long_name, short_name)
  without_prefix = shortened.replace('valid_', '')
  return without_prefix.replace('/', '_')


def generate_percentiled_numbers(max_value, percentiles):
  """Generate zero-based integer positions at given percentiles of a length.

  Each percentile `p` is mapped to `round(max_value * p / 100) - 1`, i.e. the
  zero-based position of the element found `p` percent of the way through a
  sequence of `max_value` items. Results are clamped at 0: without the clamp a
  small percentile of a short sequence yields -1, which Python indexing treats
  as the LAST element rather than the first.

  Parameters:
  max_value (int): The sequence length to base the percentages on.
  percentiles (list of float): A list of percentiles (0-100) to calculate.

  Returns:
  list of int: One non-negative position per entry of `percentiles`, in the
  same order.
  """
  return [max(0, round(max_value * (p / 100)) - 1) for p in percentiles]

In [None]:
# @title Data Loading

# Builds `df` with one row per work unit of every experiment listed below.
# Each row holds the experiment-level model description, the work unit's
# parameters, and the full list of recorded values for every metric in
# `metric_names` and `compute_metrics`.

# Experiments to load, keyed by XManager experiment id.
xm_id_dict = {  # Model Size, ParamSize, PatchSize
    124248449: ['Tiny', 2.21, '10x5'],
    124248804: ['ExtraSmall', 7.3, '10x5'],
    # 124142001: ['Small', 24.6, '10x5'], #ignore this for now.
    124248847: ['Base', 110.74, '10x5'],
}

compute_metrics = [
    'core_hours_TPU v5 lite',
    'train_mean_absolute_error_all',
    'train_mean_absolute_error_masked',
    'train_mean_squared_error_all',
    'train_mean_squared_error_masked',
]


metric_names = [
    'valid_mean_squared_error_masked',
    'forecast_0.1_eval/valid_mean_squared_error_masked',
    'forecast_0.2_eval/valid_mean_squared_error_masked',
    'forecast_0.4_eval/valid_mean_squared_error_masked',
    'imputation_0.1_eval/valid_mean_squared_error_masked',
    'imputation_0.2_eval/valid_mean_squared_error_masked',
    'imputation_0.4_eval/valid_mean_squared_error_masked',
]

xm_exp_dict = collections.defaultdict(list)
for xm_id, (model_size, param_size, patch_size) in xm_id_dict.items():
  experiment = xm_client.get_experiment(xm_id)
  num_of_units = experiment.get_num_work_units()
  # `unit_index` (not `id`) so the builtin `id` is not shadowed for the rest of
  # the notebook.
  for unit_index in range(num_of_units):
    # Work units are fetched with ids starting at 1.
    real_id = unit_index + 1
    work_unit = experiment.get_work_unit(real_id)
    # NOTE(review): the 'unit_id' column stores the zero-based loop index, not
    # the id used to fetch the work unit (`real_id`). Kept as is; confirm this
    # is what downstream consumers expect.
    xm_exp_dict['unit_id'].append(unit_index)
    xm_exp_dict['xm_id'].append(xm_id)
    xm_exp_dict['Param Size'].append(param_size)
    xm_exp_dict['Model Size'].append(model_size)
    xm_exp_dict['Patch Size'].append(patch_size)
    for param_name in work_unit.parameters.keys():
      xm_exp_dict[param_name].append(work_unit.parameters[param_name])

    # List the measurement series once per work unit instead of re-fetching
    # the experiment, the work unit and the series list for every metric.
    # `setdefault` keeps the first series seen for a label.
    series_by_label = {}
    for series in work_unit.list_measurement_series():
      series_by_label.setdefault(series.label, series)

    for metric in metric_names + compute_metrics:
      series = series_by_label.get(metric)
      # None marks a metric the work unit did not record; otherwise keep every
      # recorded value so later cells can pick the last or the minimum.
      xm_exp_dict[metric].append(
          None
          if series is None
          else [point.objective_value for point in series.measurements]
      )
df = pd.DataFrame(xm_exp_dict)
df

In [None]:
# @title Plot Main Scaling

use_aug = True
use_last = True
sample_size = 1321235
compute_hours_steps = [1, 5, 10, 20, 40, 80, 100]
colors = ['#465ece', '#bed2f6', '#f8ab8d']
other_metric_names = [
    'forecast_0.2_eval/valid_mean_squared_error_masked',
    'imputation_0.2_eval/valid_mean_squared_error_masked',
    'valid_mean_squared_error_masked',
]
line_alpha = 0.2
circle_alpha = 1
# Create a figure with a custom layout
fig, axs = plt.subplots(1, 3, figsize=(12, 4), dpi=100)

# Unpack the axes for easier access
ax1, ax2, ax3 = axs

# Define marker sizes based on model sizes
marker_size_map = {
    'Deb': 50,
    'Tiny': 75,
    'ExtraSmall': 100,
    'Small': 125,
    'Base': 150,
    'Large': 175,
}

data_scaling_list = []
model_scaling_list = []
compute_scaling_list = []

for idx, metric_name in enumerate(other_metric_names):
  displayed_metric = process_string_metric(metric_name)
  color = colors[idx]

  # Figure 1: Compute Scaling
  if metric_name in other_metric_names:
    # for data_size in df['config.dataset_configs.train_num_samples'].unique():
    #   for model_size in df['Model Size'].unique():
    for data_size in [100000, 750000, 1321235]:
      for model_size in ['ExtraSmall', 'Base']:
        subset = df[
            (df['Model Size'] == model_size)
            & (df['config.dataset_configs.train_num_samples'] == data_size)
            & (df['config.use_train_augmentations'] == use_aug)
        ]

        if not subset.empty:
          x_idx = generate_percentiled_numbers(
              len(subset.iloc[0]['core_hours_TPU v5 lite']), compute_hours_steps
          )
          x = [subset.iloc[0]['core_hours_TPU v5 lite'][idx] for idx in x_idx]
          y_idx = generate_percentiled_numbers(
              len(subset.iloc[0][metric_name]), compute_hours_steps
          )
          y = [subset.iloc[0][metric_name][idx] for idx in y_idx]
          y = y[: len(x)]
          if (
              metric_name == 'valid_mean_squared_error_masked'
              and model_size == 'Base'
              and data_size == 1321235
          ):
            compute_scaling_list.append((x, y))
          x = np.log10(x)
          sns.scatterplot(
              x=x,
              y=y,
              color=color,
              ax=ax1,
              s=marker_size_map[model_size],
              alpha=circle_alpha,
              legend=False,  # Turn off the legend for markers
          )
          # ax1.set_ylim(0, 0.85)
          sns.lineplot(
              x=x,
              y=y[: len(x)],
              color=color,
              ax=ax1,
              alpha=line_alpha,
          )

  # Figure 2: Data Scaling
  for model_size in df['Model Size'].unique():
    subset = df[
        (df['Model Size'] == model_size)
        & (df['config.use_train_augmentations'] == use_aug)
    ]

    if not subset.empty:
      x = []
      y = []
      for _, row in subset.iterrows():
        x.append(round(row['config.dataset_configs.train_num_samples'] * 5, 2))
        y.append(row[metric_name][-1] if use_last else min(row[metric_name]))
      x = np.log10(x)
      if metric_name == 'valid_mean_squared_error_masked':
        data_scaling_list.append((x, y))
      # Set the marker size based on the model size
      marker_size = marker_size_map[model_size]

      scatter = sns.scatterplot(
          x=x,
          y=y,
          s=marker_size,  # Use 's' instead of 'size' to set marker size directly
          color=color,
          ax=ax2,
          alpha=circle_alpha,
          legend=False,
      )
      # ax2.set_ylim(0, 0.85)
      sns.lineplot(x=x, y=y, color=color, ax=ax2, alpha=line_alpha)

  # Figure 3: Model Scaling
  if metric_name in other_metric_names:
    for data_size in [100000, 750000, 1321235]:
      subset = df[
          (df['config.dataset_configs.train_num_samples'] == data_size)
          & (df['config.use_train_augmentations'] == use_aug)
      ]

      if not subset.empty:
        x = [row['Param Size'] for _, row in subset.iterrows()]
        y = [
            row[metric_name][-1] if use_last else min(row[metric_name])
            for _, row in subset.iterrows()
        ]
        sizes = [
            marker_size_map[row['Model Size']] for _, row in subset.iterrows()
        ]
        if metric_name == 'valid_mean_squared_error_masked':
          model_scaling_list.append((x, y))
        x = np.log10(x)
        sns.scatterplot(
            x=x,
            y=y,
            size=sizes,
            sizes=(75, 150),
            color=color,
            ax=ax3,
            alpha=circle_alpha,
            legend=False,
        )
        # ax3.set_ylim(0, 0.65)
        sns.lineplot(x=x, y=y, color=color, ax=ax3, alpha=line_alpha)


# Titles and labels
ax1.set_xlabel(r'$\mathbf{Compute}$' + '\n TPU v5 VLP core hours')
ax1.set_ylabel('Masked Mean Squared Error')
ax2.set_xlabel(r'$\mathbf{Data\ Size}$' + '\n(Hours)')
ax3.set_xlabel(r'$\mathbf{Model\ Size}$' + '\n(Million of Params)')

# Set the number of ticks and ensure unique tick labels
for ax in [ax1, ax2, ax3]:
  ax.xaxis.set_major_locator(
      MaxNLocator(integer=True, prune='both')
  )  # Adjust to avoid repetitive ticks
  xticks = ax.get_xticks()
  ax.xaxis.set_major_locator(FixedLocator(xticks))
  ax.set_xticklabels([
      f'$10^{int(val)}$' if i == 0 or val != xticks[i - 1] else ''
      for i, val in enumerate(xticks)
  ])

marker_sizes = [
    marker_size_map['Tiny'],
    marker_size_map['ExtraSmall'],
    marker_size_map['Base'],
]
marker_labels = ['2M', '7M', '110M']

marker_handles = [
    plt.scatter([], [], s=size, color='black') for size in marker_sizes
]

# Task legend
task_handles = [
    plt.Line2D([0], [0], color=color, linestyle='-', linewidth=2)
    for color in colors
]
task_labels = ['Forecasting', 'Imputation', 'Random Filling']

# Combine handles and labels
combined_handles = marker_handles + task_handles
combined_labels = marker_labels + task_labels

# Add combined legend at the top center
fig.legend(
    combined_handles,
    combined_labels,
    title='',
    loc='upper center',
    bbox_to_anchor=(0.5, 1.06),  # Center the combined legend
    ncol=len(combined_labels),
    frameon=False,
    fontsize=MEDIUM_SIZE,
    handletextpad=0.4,
)
plt.tight_layout()
plt.savefig("/tmp/lsm_main_scaling.png", bbox_inches='tight', format="png")
plt.show()
%download_file /tmp/lsm_main_scaling.png
plt.savefig("/tmp/lsm_main_scaling.pdf", bbox_inches='tight', format="pdf")
%download_file /tmp/lsm_main_scaling.pdf

