In [None]:
import os
import re
import tempfile
import warnings
import collections
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

from google3.learning.deepmind.xmanager2.client import xmanager_api
from google3.pyglib import gfile
from google3.pyglib.function_utils import memoize
from matplotlib import font_manager

import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.lines import Line2D
from matplotlib.ticker import FixedLocator  # Import for the fix
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import LogLocator
from matplotlib.ticker import FuncFormatter

In [None]:
#@title Google Sans Import

# google3-relative directory containing the Google Sans font files. It is
# joined onto the /google_src/head/depot root in import_default_google_fonts().
_GOOGLE_SANS_PATH = (
    'google3/third_party/googlefonts/api/googlerestricted/googlesans/'
)

@memoize.Memoize()
def import_google3_fonts(font_path: str) -> None:
  """Register google3-hosted fonts with Matplotlib's default font manager.

  The fonts are first copied to a local temporary location; every resulting
  font file is then added to `font_manager.fontManager`.

  Args:
    font_path: google3 path of a directory holding .ttf fonts, or of a single
      .ttf font file.
  """
  if not gfile.IsDirectory(font_path):
    # Anything that is not a directory is treated as a single font file.
    # NamedTemporaryFile is only used to reserve a unique '*.ttf' name:
    # closing it removes the placeholder, and gfile.Copy then writes the font
    # to that name.
    placeholder = tempfile.NamedTemporaryFile(suffix='.ttf')
    placeholder.close()
    gfile.Copy(font_path, placeholder.name)
    local_fonts = [placeholder.name]
  else:
    # Mirror the whole google3 directory into a fresh temp directory and let
    # Matplotlib discover the font files inside it.
    local_dir = tempfile.mkdtemp()
    gfile.RecursivelyCopyDir(font_path, local_dir, overwrite=True)
    local_fonts = font_manager.findSystemFonts(fontpaths=local_dir)

  # Make each copied font available through the default font manager.
  for local_font in local_fonts:
    font_manager.fontManager.addfont(local_font)


def import_default_google_fonts() -> None:
  """Register the Google Sans font family with Matplotlib.

  Resolves `_GOOGLE_SANS_PATH` against the `/google_src/head/depot` root and
  hands the resulting path to `import_google3_fonts`.
  """
  depot_root = '/google_src/head/depot'
  google_sans_dir = os.path.join(depot_root, _GOOGLE_SANS_PATH)
  import_google3_fonts(google_sans_dir)


# Register the Google fonts with Matplotlib up front so that later cells can
# select them (e.g. through rcParams['font.family']).
import_default_google_fonts()

In [None]:
#@title Set up Plot Settings

# Pandas display: never truncate rows or columns when rendering a DataFrame.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
# Silence UserWarnings that originate from matplotlib modules.
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

# Client used by the cells below to read experiments and their measurements.
xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

# Base font size shared by every text element except the legend.
MEDIUM_SIZE = 12

# One consolidated rcParams update; mpl.rcParams and plt.rcParams are the same
# object, so this yields the same final settings as setting them one by one.
mpl.rcParams.update({
    'font.family': 'Google Sans',
    'font.size': MEDIUM_SIZE,  # default text size
    'axes.labelsize': MEDIUM_SIZE,  # x and y axis labels
    'axes.titlesize': MEDIUM_SIZE,  # axes title
    'axes.linewidth': 1,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
    'xtick.labelsize': MEDIUM_SIZE,  # x tick labels
    'ytick.labelsize': MEDIUM_SIZE,  # y tick labels
    'legend.fontsize': MEDIUM_SIZE - 5,  # legend text
    'figure.titlesize': MEDIUM_SIZE,  # figure title
})

elegant_palette = sns.color_palette('muted')

def log_float_formatter(y, pos):
  """Format a tick value as a fixed-point string with two decimals.

  The (value, position) signature is the one expected by
  `matplotlib.ticker.FuncFormatter`.

  Args:
    y: Tick value to format.
    pos: Tick position; accepted but not used.

  Returns:
    `y` rendered with exactly two digits after the decimal point.
  """
  return format(y, '.2f')


In [None]:
def read_xm_metrics(example_xid, metric_name, unit_id, lowest=True):
  """Read one measurement series of an XManager work unit.

  Looks up work unit `unit_id` of experiment `example_xid` through the
  module-level `xm_client` and collects the objective values of the first
  measurement series whose label equals `metric_name`.

  Args:
    example_xid: Experiment id passed to `xm_client.get_experiment`.
    metric_name: Label of the measurement series to read.
    unit_id: Work unit id passed to `experiment.get_work_unit`.
    lowest: If True, return only the smallest objective value; otherwise
      return the list of all objective values in series order.

  Returns:
    The minimum value (`lowest=True`) or the list of all values
    (`lowest=False`). None when no series carries the requested label, or when
    `lowest=True` and the matching series has no measurements.
  """
  experiment = xm_client.get_experiment(example_xid)
  work_unit = experiment.get_work_unit(unit_id)
  for series in work_unit.list_measurement_series():
    if series.label != metric_name:
      continue
    values = [
        measurement.objective_value for measurement in series.measurements
    ]
    if lowest:
      # default=None: an empty series reports "no value" instead of letting
      # min() raise ValueError on an empty sequence.
      return min(values, default=None)
    return values
  # No series matched `metric_name`.
  return None


def add_min_columns(df):
  """Attach the per-row minimum of each validation-error history column.

  For each of the four `valid_mean_{absolute,squared}_error_{all,masked}`
  columns, a sibling column named `min_<column>` is written that holds
  `min(cell)` for every cell of the source column.

  Args:
    df: DataFrame containing the four source columns. It is modified in place.

  Returns:
    The same DataFrame object, with the four `min_*` columns added.
  """
  source_columns = (
      'valid_mean_absolute_error_all',
      'valid_mean_absolute_error_masked',
      'valid_mean_squared_error_all',
      'valid_mean_squared_error_masked',
  )
  for column in source_columns:
    # Apply min() cell by cell: each cell holds a sequence of values.
    df['min_' + column] = df[column].apply(lambda history: min(history))
  return df


def process_string_metric(input_string):
  """Shorten a metric label into a compact identifier.

  Three rewrites are applied in order: the long error names are abbreviated
  ('mean_absolute_error' -> 'mae', 'mean_squared_error' -> 'mse'), every
  'valid_' is removed, and every '/' becomes '_'.

  Args:
    input_string: Metric label, e.g.
      'forecast_0.2_eval/valid_mean_squared_error_masked'.

  Returns:
    The shortened label, e.g. 'forecast_0.2_eval_mse_masked'.
  """
  abbreviations = (
      ('mean_absolute_error', 'mae'),
      ('mean_squared_error', 'mse'),
  )
  shortened = input_string
  # The long names contain no regex metacharacters, so a literal replace is
  # sufficient.
  for verbose_name, short_name in abbreviations:
    shortened = shortened.replace(verbose_name, short_name)

  without_prefix = shortened.replace('valid_', '')
  return without_prefix.replace('/', '_')


def generate_percentiled_numbers(max_value, percentiles):
  """Map percentiles of `max_value` to integers, each shifted down by one.

  Every percentile p produces `round(max_value * (p / 100)) - 1`. Because of
  the trailing `- 1`, a percentile of 100 yields `max_value - 1` and a
  percentile of 0 yields -1.

  Args:
    max_value (int): The value the percentages are taken of.
    percentiles (list of float): Percentiles (0-100) to evaluate.

  Returns:
    list of int: One shifted, rounded value per entry of `percentiles`, in the
    same order.
  """
  shifted_values = []
  for percentile in percentiles:
    fraction = percentile / 100
    shifted_values.append(round(max_value * fraction) - 1)
  return shifted_values

# Custom formatter function to display y-ticks as floats.
# NOTE: a function with this name and identical behavior is also defined at the
# end of the "Set up Plot Settings" cell; this later definition shadows it.
def log_float_formatter(y, pos):
  """Format a tick value as a fixed-point string with two decimals.

  Args:
    y: Tick value to format.
    pos: Tick position; accepted but not used.

  Returns:
    `y` rendered with exactly two digits after the decimal point.
  """
  return '{:.2f}'.format(y)


In [None]:
# @title Data Scaling

# Experiments to load, keyed by XManager experiment id. Each value is
# [model size name, parameter size, patch size]; the parameter size appears to
# be in millions (compare custom_model_params_map) -- TODO confirm.
# Commented-out entries are excluded from loading.
xm_id_dict = {  # Model Size, ParamSize, PatchSize
    124248449: ['Tiny', 2.21, '10x5'],
    124248804: ['ExtraSmall', 7.3, '10x5'],
    # 124142001: ['Small', 24.6, '10x5'],
    124248847: ['Base', 110.74, '10x5'],
    # 125769633: ['Large', 328.1, '10x5'],
}

# Labels of the compute-related measurement series read for every work unit.
compute_metrics = [
    'core_hours_TPU v5 lite',
    'examples_seen'
]


# Labels of the validation measurement series read for every work unit.
# Only the masked MSE variants are enabled; the MAE variants are commented out.
metric_names = [
    # 'valid_mean_absolute_error_masked',
    'valid_mean_squared_error_masked',
    # 'forecast_0.2_eval/valid_mean_absolute_error_masked',
    'forecast_0.2_eval/valid_mean_squared_error_masked',
    # 'imputation_0.2_eval/valid_mean_absolute_error_masked',
    'imputation_0.2_eval/valid_mean_squared_error_masked',

]

# Collect one row per work unit: experiment metadata, the unit's parameters,
# and the full history (lowest=False) of every requested measurement series.
xm_exp_dict = collections.defaultdict(list)
for xm_id, (model_size, param_size, patch_size) in xm_id_dict.items():
  experiment = xm_client.get_experiment(xm_id)
  num_of_units = experiment.get_num_work_units()
  # `unit_index` is zero-based and is what gets stored in the 'unit_id' column;
  # the work unit itself is fetched with id `unit_index + 1`. The loop variable
  # is deliberately not called `id`, which would shadow the builtin for every
  # later cell of the notebook.
  for unit_index in range(num_of_units):
    real_id = unit_index + 1
    work_unit = experiment.get_work_unit(real_id)
    xm_exp_dict['unit_id'].append(unit_index)
    xm_exp_dict['xm_id'].append(xm_id)
    xm_exp_dict['Param Size'].append(param_size)
    xm_exp_dict['Model Size'].append(model_size)
    xm_exp_dict['Patch Size'].append(patch_size)
    for param_name in work_unit.parameters.keys():
      xm_exp_dict[param_name].append(work_unit.parameters[param_name])
    for metric in metric_names + compute_metrics:
      xm_exp_dict[metric].append(
          read_xm_metrics(xm_id, metric, real_id, lowest=False)
      )
# NOTE(review): pd.DataFrame needs all column lists to have the same length,
# so this assumes every work unit exposes the same parameter keys -- confirm.
default_df = pd.DataFrame(xm_exp_dict)

In [None]:
# Keep only the runs whose training set has 1,321,235 samples. `.copy()` makes
# `df` an independent frame, so adding columns to it in later cells neither
# touches `default_df` nor triggers pandas' SettingWithCopyWarning.
df = default_df[
    default_df['config.dataset_configs.train_num_samples'] == 1321235
].copy()

In [None]:
# Display the filtered runs for a visual sanity check.
df

In [None]:
# Rounded parameter count for each model size name; used as the numeric key
# for the legend labels and line colors of the plot.
custom_model_params_map = {
    'Base': 110000000,
    'ExtraSmall': 7000000,
    'Tiny': 2000000,
    'Large': 328000000,
    'Small': 24000000,  # 24.6M parameters, rounded like the other entries.
}
# assign() returns a new frame instead of writing into a possibly sliced `df`,
# which keeps this cell idempotent and free of SettingWithCopyWarning.
df = df.assign(model_params=df['Model Size'].map(custom_model_params_map))
df

In [None]:
example_seen_list = df['examples_seen'].values
valid_mse_list = df['valid_mean_squared_error_masked'].values
model_params = df['model_params'].values
model_names = {
  2000000: 'ViT-2M',
  7000000: 'ViT-7M',
  110000000: 'ViT-110M'
}
# 3182BD, 6BAED6, 9ECAE1
colors = {
  110000000: '#3182BD',
  7000000: '#6BAED6',
  2000000: '#9ECAE1'
}



fig, axes = plt.subplots(1, 1, figsize=(3.3, 3), sharex=True, dpi=100)
axes.set(xscale="log", yscale="log")
# Define four distinct colors for the lines
# colors = plt.cm.tab20c()

# Loop through each row in the DataFrame and plot its corresponding line
for i in range(len(df)):
    samples_have_seen = np.array(example_seen_list[i], dtype=float)
    valid_mse = np.array(valid_mse_list[i], dtype=float)
    # Use modulo to cycle through the colors if there are more than 4 lines
    color = plt.cm.tab20c(i)
    plt.plot(
        samples_have_seen[:],
        valid_mse[1:],  # Skip the first point if needed
        # label=f'{model_params[i]}',
        label=model_names[model_params[i]],
        color=colors[model_params[i]],
    )

# Set labels and title
axes.yaxis.set_major_locator(LogLocator(base=10.0, subs=np.arange(1, 10), numticks=10))
axes.yaxis.set_major_formatter(FuncFormatter(log_float_formatter))
axes.xaxis.set_major_locator(LogLocator(base=10.0, subs=np.arange(1, 10), numticks=10))
plt.xlabel('Samples Have Seen')
plt.ylabel('Test Loss')
plt.legend(frameon=False, fontsize=MEDIUM_SIZE-2)
# Add the legend
plt.tight_layout()
plt.savefig("/tmp/model_efficiency_lsm.pdf", bbox_inches='tight', format="pdf")
plt.show()
%download_file /tmp/model_efficiency_lsm.pdf

In [None]:
# Inspect the raw per-run `examples_seen` series used as x-values in the plot.
example_seen_list