In [None]:
import os
import re
import tempfile
import warnings
import collections
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

from google3.learning.deepmind.xmanager2.client import xmanager_api
from google3.pyglib import gfile
from google3.pyglib.function_utils import memoize
from matplotlib import font_manager

import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.lines import Line2D
from matplotlib.ticker import FixedLocator  # Import for the fix
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import LogLocator
from matplotlib.ticker import FuncFormatter

In [None]:
#@title Google Sans Import

# Import Google font family
# Depot-relative (google3/...) location of the Google Sans fonts. It is joined
# onto '/google_src/head/depot' by import_default_google_fonts before copying.
_GOOGLE_SANS_PATH = (
    'google3/third_party/googlefonts/api/googlerestricted/googlesans/'
)

# NOTE(review): presumably memoize caches by font_path so repeated calls skip
# the copy + registration; confirm against pyglib memoize semantics.
@memoize.Memoize()
def import_google3_fonts(font_path: str) -> None:
  """Import fonts stored in google3 into Matplotlib for use in Colab.

  The font(s) are copied with gfile into local temporary storage, and every
  local copy is then registered with Matplotlib's global
  `font_manager.fontManager`. The temporary copies are not deleted by this
  function.

  Args:
    font_path: google3 path to either a directory that contains .ttf fonts or to
      a specific .ttf font file.
  """
  if gfile.IsDirectory(font_path):
    # Create a temp directory as a destination for copied font files.
    tmp_dir = tempfile.mkdtemp()
    # Copy font files from google3 to temp dir.
    gfile.RecursivelyCopyDir(font_path, tmp_dir, overwrite=True)
    # Collect the font files found under the temp dir; passing fontpaths makes
    # findSystemFonts search only that location.
    font_files = font_manager.findSystemFonts(fontpaths=tmp_dir)
  else:
    # Assume the path points to a file if it's not a directory.
    # Copy ttf file from google3 to temp location.
    # NamedTemporaryFile deletes its file on close() (delete=True by default),
    # so this only reserves a unique '.ttf' name that gfile.Copy then writes to.
    # NOTE(review): tempfile.mkstemp would avoid the window between close()
    # and Copy in which another process could claim the same name.
    tmp_file = tempfile.NamedTemporaryFile(suffix='.ttf')
    tmp_file.close()
    gfile.Copy(font_path, tmp_file.name)
    font_files = [tmp_file.name]

  # Add fonts to default font manager.
  for font_file in font_files:
    font_manager.fontManager.addfont(font_file)


def import_default_google_fonts() -> None:
  """Registers the Google Sans font family with Matplotlib.

  `_GOOGLE_SANS_PATH` is depot-relative, so it is resolved against the
  '/google_src/head/depot' checkout root before being handed to
  `import_google3_fonts`.
  """
  depot_root = '/google_src/head/depot'
  google_sans_location = os.path.join(depot_root, _GOOGLE_SANS_PATH)
  import_google3_fonts(google_sans_location)


# Import and register Google fonts with Matplotlib so we can use them.
# Registration must happen before any figure using the 'Google Sans' family
# (selected via rcParams in the plot-settings cell) is drawn.
import_default_google_fonts()

In [None]:
#@title Set up Plot Settings

# pandas display: never truncate rows or columns of displayed frames.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

# Hide UserWarnings emitted from matplotlib modules.
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

# XManager client used by read_xm_metrics and the data-collection cell.
xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

MEDIUM_SIZE = 12

# All text uses MEDIUM_SIZE except legends, which are 5 points smaller.
# (plt.rcParams / plt.rc write to this same mpl.rcParams mapping.)
mpl.rcParams.update({
    'font.size': MEDIUM_SIZE,
    'axes.titlesize': MEDIUM_SIZE,
    'axes.labelsize': MEDIUM_SIZE,
    'xtick.labelsize': MEDIUM_SIZE,
    'ytick.labelsize': MEDIUM_SIZE,
    'legend.fontsize': MEDIUM_SIZE - 5,
    'figure.titlesize': MEDIUM_SIZE,
    # Thin grey axes frame on a white background.
    'axes.linewidth': 1,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
})

elegant_palette = sns.color_palette('muted')
mpl.rcParams['font.family'] = 'Google Sans'

In [None]:
def read_xm_metrics(example_xid, metric_name, unit_id, lowest=True):
  """Reads one measurement series from an XManager work unit.

  Args:
    example_xid: XManager experiment id.
    metric_name: Label of the measurement series to read.
    unit_id: Work unit id passed to `experiment.get_work_unit`.
    lowest: If True, return only the minimum objective value; otherwise
      return every objective value in the order `series.measurements`
      yields them.

  Returns:
    The minimum value (lowest=True) or the list of all values
    (lowest=False). Returns None when no series is labelled `metric_name`,
    and also when lowest=True but the series has no measurements (a bare
    min() would raise ValueError there).
  """
  experiment = xm_client.get_experiment(example_xid)
  work_unit = experiment.get_work_unit(unit_id)
  for series in work_unit.list_measurement_series():
    if series.label != metric_name:
      continue
    values = [point.objective_value for point in series.measurements]
    if not lowest:
      return values
    # Treat "series exists but has no points yet" like "series not found".
    return min(values) if values else None
  return None


def add_min_columns(df, columns=None):
  """Adds a `min_<name>` column holding the minimum of each row's list.

  Args:
    df: DataFrame whose `columns` entries are sized sequences of numbers
      (lists, tuples, arrays) or None.
    columns: Names of the list-valued columns to summarize. Defaults to the
      four validation MAE/MSE columns (all / masked).

  Returns:
    `df` itself, modified in place with one `min_<name>` column appended per
    entry of `columns` (in that order). None or empty entries map to None
    instead of raising.
  """
  if columns is None:
    columns = (
        'valid_mean_absolute_error_all',
        'valid_mean_absolute_error_masked',
        'valid_mean_squared_error_all',
        'valid_mean_squared_error_masked',
    )

  def min_or_none(values):
    # A bare min() raises on None (TypeError) or on an empty list (ValueError).
    if values is None or len(values) == 0:
      return None
    return min(values)

  for name in columns:
    df[f'min_{name}'] = df[name].apply(min_or_none)
  return df


def process_string_metric(input_string):
  """Abbreviates a metric name.

  'mean_absolute_error' becomes 'mae' and 'mean_squared_error' becomes 'mse';
  then every 'valid_' is dropped and every '/' becomes '_'.
  """
  abbreviations = (
      ('mean_absolute_error', 'mae'),
      ('mean_squared_error', 'mse'),
  )
  shortened = input_string
  for long_name, short_name in abbreviations:
    # Literal substring replacement; the names contain no regex metacharacters.
    shortened = shortened.replace(long_name, short_name)
  without_prefix = shortened.replace('valid_', '')
  return without_prefix.replace('/', '_')


def generate_percentiled_numbers(max_value, percentiles):
  """Maps each percentile of `max_value` to an integer, offset by -1.

  Args:
    max_value: The value the percentiles are taken of.
    percentiles: Iterable of percentiles (0-100).

  Returns:
    A list containing `round(max_value * (p / 100)) - 1` for each `p`, in
    input order. Python's round() rounds halves to even (2.5 -> 2).
    NOTE(review): the -1 looks like a conversion to a 0-based index, but a
    percentile that rounds to 0 yields -1 (the *last* element if used as an
    index) — confirm with callers.
  """
  results = []
  for pct in percentiles:
    fraction = pct / 100
    results.append(round(max_value * fraction) - 1)
  return results

# Tick-label formatter for FuncFormatter: renders the tick value with two decimals.
def log_float_formatter(y, pos):
  """Returns `y` formatted with exactly two decimal places.

  FuncFormatter passes the tick value and its position; `pos` is accepted
  for that calling convention but ignored.
  """
  return format(y, '.2f')


In [None]:
# @title Data Scaling

# Collect per-work-unit hyperparameters and metrics from each XManager experiment.


# XM experiment id -> [Model Size, Param Size, Patch Size]. Param Size appears
# to be millions of parameters (cf. the 'ViT - 2M' legend labels) — TODO confirm.
xm_id_dict = {
    124248449: ['Tiny', 2.21, '10x5'],
    124248804: ['ExtraSmall', 7.3, '10x5'],
    # 124142001: ['Small', 24.6, '10x5'],
    124248847: ['Base', 110.74, '10x5'],
}

# Compute-usage series read for every work unit.
compute_metrics = [
    'core_hours_TPU v5 lite',
]

# Validation metrics read for every work unit. Order matters: the plotting
# cells pair the derived '_last' columns with task names by position.
metric_names = [
    # 'valid_mean_absolute_error_masked',
    'valid_mean_squared_error_masked',
    # 'forecast_0.2_eval/valid_mean_absolute_error_masked',
    'forecast_0.2_eval/valid_mean_squared_error_masked',
    # 'imputation_0.2_eval/valid_mean_absolute_error_masked',
    'imputation_0.2_eval/valid_mean_squared_error_masked',
]

# Build one record (dict) per work unit and create the DataFrame once.
# Appending to per-column lists breaks as soon as work units expose different
# parameter names (pd.DataFrame rejects columns of unequal length); with
# records, a parameter absent from some work units simply becomes NaN there.
records = []
for xm_id, values in xm_id_dict.items():
  model_size, param_size, patch_size = values[:3]
  experiment = xm_client.get_experiment(xm_id)
  for unit_index in range(experiment.get_num_work_units()):
    work_unit_id = unit_index + 1  # Work units are fetched by 1-based id.
    work_unit = experiment.get_work_unit(work_unit_id)
    record = {
        # NOTE(review): stores the 0-based index, not work_unit_id.
        'unit_id': unit_index,
        'xm_id': xm_id,
        'Param Size': param_size,
        'Model Size': model_size,
        'Patch Size': patch_size,
    }
    for param_name in work_unit.parameters.keys():
      record[param_name] = work_unit.parameters[param_name]
    # Full per-measurement lists (lowest=False); None if a series is absent.
    for metric in metric_names + compute_metrics:
      record[metric] = read_xm_metrics(
          xm_id, metric, work_unit_id, lowest=False
      )
    records.append(record)

df = pd.DataFrame(records)

In [None]:
#@title Prepare the dataset

def extract_min_last(df):
  """Summarizes list-valued columns by their minimum and last element.

  For every column `<col>` of `df`, the result has `<col>_min` and
  `<col>_last`:
    * If the column holds lists of numbers (entries may also be None/NaN,
      e.g. a metric series that was not found), `_min` is each list's
      minimum and `_last` its final element. Empty lists and missing entries
      map to None.
    * Otherwise the column is copied unchanged into both outputs.

  Args:
    df: Input DataFrame; it is not modified.

  Returns:
    A new DataFrame with two columns per input column, in input-column order,
    aligned to `df`'s index.
  """
  import math  # Only needed for the NaN check in is_missing.

  def is_numeric_list(value):
    return isinstance(value, list) and all(
        isinstance(item, (float, int)) for item in value
    )

  def is_missing(value):
    return value is None or (isinstance(value, float) and math.isnan(value))

  def process_column(col):
    is_list = col.apply(is_numeric_list)
    # Summarize when every entry is a numeric list, or when numeric lists are
    # mixed only with missing entries. Otherwise a single absent measurement
    # would leave the whole column as raw lists.
    summarize = is_list.all() or (
        is_list.any() and (is_list | col.apply(is_missing)).all()
    )
    if not summarize:
      return col, col
    min_values = col.apply(
        lambda x: min(x) if is_numeric_list(x) and x else None
    )
    last_values = col.apply(
        lambda x: x[-1] if is_numeric_list(x) and x else None
    )
    return min_values, last_values

  summary_columns = {}
  for column in df.columns:
    min_col, last_col = process_column(df[column])
    summary_columns[f'{column}_min'] = min_col
    summary_columns[f'{column}_last'] = last_col
  # Construct the frame in one step. Inserting columns one at a time into an
  # empty DataFrame fragments it and triggers pandas PerformanceWarnings.
  return pd.DataFrame(summary_columns)

# Reduce each list-valued (per-measurement) column of `df` to '<col>_min' and
# '<col>_last'; non-list columns are copied into both suffixed columns.
new_df = extract_min_last(df)



In [None]:
# Display the raw per-work-unit table. NOTE(review): display.max_rows is set to
# None in the plot-settings cell, so every row is rendered; consider df.head().
df

In [None]:
#@title Group by Model Size and Data Size
# Plot the last recorded value of each metric ('<metric>_last' from
# extract_min_last), in metric_names order.
new_metrics_names = [f'{name}_last' for name in metric_names]
group_keys = ['config.dataset_configs.train_num_samples_min', 'Model Size_min']
grouped_df = new_df.groupby(group_keys)[new_metrics_names]
# Average over all work units sharing a (train_num_samples, model size) pair.
metrics_summary = grouped_df.mean()
metrics_summary

In [None]:
# @title Data Scaling across Generative Tasks

# Define the task names for subtitles
task_names = [
    'Random Imputation 80% \n (Test Loss)',
    'Temporal Extrapolation 60 Minutes \n (Zero Shot)',
    'Temporal Interpolation 60 Minutes \n (Zero Shot)',
    # '8-Class Activity Recognition',
]  # Modify as needed

# Reset the index if needed
metrics_summary_reset = metrics_summary.reset_index()

# Number of subplots (one per metric)
n_metrics = 3

# Custom colors for different model sizes
model_colors = {
    'Base': '#3182BD',
    'ExtraSmall': '#6BAED6',
    'Tiny': '#9ECAE1',
}

# Custom markers (shapes) for different model sizes
model_shapes = {
    'Base': 'o',  # Circle
    'ExtraSmall': 'o',  # Square
    'Tiny': 'o',  # Triangle
}

# Custom legend labels for different model sizes
custom_legend_labels = {
    'Base': 'ViT - 110M',
    'ExtraSmall': 'ViT - 7M',
    'Tiny': 'ViT - 2M',
}

data_size_marker_map = {
    1000*5: 25,
    10000*5: 50,
    100000*5: 80,
    750000*5: 120,
    1321235*5: 150
}
# Create the figure and axes for the subplots
fig, axes = plt.subplots(1, n_metrics, figsize=(10, 3), sharex=True, dpi=100)

# Loop through each metric to plot the data
for i, ax in enumerate(axes):
  ax.set(xscale="log", yscale="log")
  metric_last = new_metrics_names[i]  # Get the corresponding '_last' metric name

  # Multiply x values by 5 and convert to log10 scale
  # metrics_summary_reset['log_scaled_data_size'] = np.log10(
  #     metrics_summary_reset['config.dataset_configs.train_num_samples_min'] * 5
  # )
  metrics_summary_reset['data_size_hours'] = metrics_summary_reset['config.dataset_configs.train_num_samples_min'] * 5

  for model_size in model_colors.keys():
      # Filter data for the current model size
      subset = metrics_summary_reset[metrics_summary_reset['Model Size_min'] == model_size]

      # Plot the line without markers first
      sns.lineplot(
          data=subset,
          x='data_size_hours',
          y=metric_last,
          ax=ax,
          color=model_colors[model_size],
          linestyle='--',  # Dotted line
          label=custom_legend_labels[model_size],  # Custom legend label
          linewidth=1.0
      )

      # Add scatter plot for the markers with different sizes
      sns.scatterplot(
          data=subset,
          x='data_size_hours',
          y=metric_last,
          ax=ax,
          color=model_colors[model_size],
          marker=model_shapes[model_size],
          s=subset['data_size_hours'].map(data_size_marker_map),  # Map data size to marker size
          legend=False  # Avoid adding duplicate legend entries
      )

  # Set the subtitle (task name) with light bold font
  ax.set_title(task_names[i])  # Set the title to the task name
  ax.set_xlabel('Data Size (Hours)')  # Update x-axis label to reflect log scaling
  if i == 0:
      ax.set_ylabel('Mean Squared Error')  # Set the y-axis label
  else:
      ax.set_ylabel('')
  ax.yaxis.set_major_locator(LogLocator(base=10.0, subs=np.arange(1, 10), numticks=10))
  ax.yaxis.set_major_formatter(FuncFormatter(log_float_formatter))
  ax.get_legend().remove()

# Create a list of handles for the custom legend (same shape but different colors)
legend_handles = [
    Line2D([0], [0], marker='o', color=model_colors[model_size], label=custom_legend_labels[model_size],
           markerfacecolor=model_colors[model_size], markersize=10, markeredgecolor=None)
    for model_size in model_colors.keys()
]

# Move the legend to the bottom of the figure and only show the custom colored circles
fig.legend(handles=legend_handles[::-1], loc='lower center', bbox_to_anchor=(0.5, -0.2), ncol=3, fontsize=MEDIUM_SIZE-2, frameon=False)

# Adjust layout for better display
plt.tight_layout()
plt.subplots_adjust(bottom=0.1)  # Make space for the legend at the bottom
plt.savefig("/tmp/lsm_gen_task_data_scaling.pdf", bbox_inches='tight', format="pdf")
plt.show()
%download_file /tmp/lsm_gen_task_data_scaling.pdf


In [None]:
# @title Model Scaling across Generative Tasks

# Define the task names for subtitles
task_names = [
    'Random Imputation 80% \n (Test Loss)',
    'Temporal Extrapolation 60 Minutes \n (Zero Shot)',
    'Temporal Interpolation 60 Minutes \n (Zero Shot)',
    # '8-Class Activity Recognition',
]  # Modify as needed

# Reset the index if needed
metrics_summary_reset = metrics_summary.reset_index()

# Number of subplots (one per metric)
n_metrics = 3


model_colors = {
    'Base': '#3182BD',
    'ExtraSmall': '#6BAED6',
    'Tiny': '#9ECAE1',
    'Deb': 'purple',
    'Small': 'orange',
    'Large': 'cyan',
}

# Custom colors for different model sizes
data_colors = {
    1000: '#3182BD',
    10000: '#3182BD',
    100000: '#3182BD',
    750000: '#3182BD',
    1321235: '#3182BD',
}

# Custom markers (shapes) for different model sizes
data_shapes = {
    1000: 'o',  # Circle
    10000: 'o',  # Square
    100000: 'o',  # Triangle
    750000: 'o',  # Plus
    1321235: 'o',  # Diamond
}

# Custom legend labels for different model sizes
custom_legend_labels = {
    1000: '0.005 M',  # Circle
    10000: '0.05 M',  # Square
    100000: '0.5 M',  # Triangle
    750000: '3.8 M',  # Plus
    1321235: '6.6 M',  # Diamond
}

data_size_marker_map = {
    1000: 25,
    10000: 50,
    100000: 80,
    750000: 120,
    1321235: 150
}


custom_model_params_map = {
    'Base': 110000000,
    'ExtraSmall': 7000000,
    'Tiny': 2000000,
}

metrics_summary_reset['Model Params'] = metrics_summary_reset[
    'Model Size_min'
].map(custom_model_params_map)

# Create the figure and axes for the subplots
fig, axes = plt.subplots(1, n_metrics, figsize=(10, 3), sharex=True, dpi=100)

# Loop through each metric to plot the data
for i, ax in enumerate(axes):
  ax.set(xscale="log", yscale="log")
  metric_last = new_metrics_names[
      i
  ]  # Get the corresponding '_last' metric name
  # Loop through model sizes to plot each one with its own color and shape
  for data_size in data_colors.keys():
    # Filter data for the current model size
    subset = metrics_summary_reset[
        metrics_summary_reset['config.dataset_configs.train_num_samples_min']
        == data_size
    ]

    # Plot the data with custom color and shape
    sns.lineplot(
        data=subset,
        x='Model Params',
        y=metric_last,
        ax=ax,
        color=data_colors[data_size],
        label=custom_legend_labels[data_size],  # Custom legend label
        linestyle='--',  # Dotted line
        linewidth=1.0
    )

    sns.scatterplot(
          data=subset,
          x='Model Params',
          y=metric_last,
          ax=ax,
          color=subset['Model Size_min'].map(model_colors),
          marker=data_shapes[data_size],
          s=subset['config.dataset_configs.train_num_samples_min'].map(data_size_marker_map),  # Map data size to marker size
          legend=False  # Avoid adding duplicate legend entries
    )

    # Set the subtitle (task name) with light bold font
    ax.set_title(task_names[i], fontweight='medium')  # Set the title to the task name
    ax.set_xlabel('Model Parameters')  # Update x-axis label to reflect log scaling
    if i == 0:
        ax.set_ylabel('Mean Squared Error')  # Set the y-axis label
    else:
        ax.set_ylabel('')

  ax.yaxis.set_major_locator(LogLocator(base=10.0, subs=np.arange(1, 10), numticks=10))
  ax.yaxis.set_major_formatter(FuncFormatter(log_float_formatter))
  ax.get_legend().remove()

# Create a list of handles for the custom legend (same color but different sizes)
legend_handles = [
    Line2D([0], [0], marker='o', color='black', label=custom_legend_labels[data_size],
           markerfacecolor='black', markersize=np.sqrt(data_size_marker_map[data_size]*0.7),
           markeredgecolor=None) for data_size in data_colors.keys()
]

# Move the legend to the bottom of the figure and only show the custom circle sizes legend
fig.legend(handles=legend_handles, loc='lower center', bbox_to_anchor=(0.5, -0.2), ncol=5, fontsize=MEDIUM_SIZE-2, frameon=False)

# Adjust layout for better display
plt.tight_layout()
plt.subplots_adjust(bottom=0.1)  # Make space for the legend at the bottom
plt.savefig("/tmp/lsm_gen_task_model_scaling.pdf", bbox_inches='tight', format="pdf")
plt.show()
%download_file /tmp/lsm_gen_task_model_scaling.pdf
