In [None]:
import os
import re
import tempfile
import warnings
import collections
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

from google3.learning.deepmind.xmanager2.client import xmanager_api
from google3.pyglib import gfile
from google3.pyglib.function_utils import memoize
from matplotlib import font_manager

import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.lines import Line2D
from matplotlib.ticker import FixedLocator  # Import for the fix
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import LogLocator
from matplotlib.ticker import FuncFormatter

# Suppress specific warning
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

# Import Google font family
_GOOGLE_SANS_PATH = (
    'google3/third_party/googlefonts/api/googlerestricted/googlesans/'
)

@memoize.Memoize()
def import_google3_fonts(font_path: str) -> None:
  """Registers TrueType fonts stored in google3 with Matplotlib's font manager.

  The font files are first copied to a local temporary location and then added
  to `font_manager.fontManager`. The function is wrapped in
  `@memoize.Memoize()` (presumably so repeated calls with the same path do the
  work only once).

  Args:
    font_path: google3 path to either a directory that contains .ttf fonts or to
      a specific .ttf font file.
  """
  if not gfile.IsDirectory(font_path):
    # Not a directory: treat the path as a single font file. The
    # NamedTemporaryFile is deleted on close(); it is only used to reserve a
    # unique local filename that gfile.Copy then writes to.
    placeholder = tempfile.NamedTemporaryFile(suffix='.ttf')
    placeholder.close()
    gfile.Copy(font_path, placeholder.name)
    local_fonts = [placeholder.name]
  else:
    # Directory: mirror it into a fresh temp dir and collect every font in it.
    local_dir = tempfile.mkdtemp()
    gfile.RecursivelyCopyDir(font_path, local_dir, overwrite=True)
    local_fonts = font_manager.findSystemFonts(fontpaths=local_dir)

  for local_font in local_fonts:
    font_manager.fontManager.addfont(local_font)


def import_default_google_fonts() -> None:
  """Registers the Google Sans font family (read from the depot head) with Matplotlib."""
  # google3 paths are resolved against the head of the source depot.
  depot_root = '/google_src/head/depot'
  import_google3_fonts(os.path.join(depot_root, _GOOGLE_SANS_PATH))


# Register Google Sans with Matplotlib so it can be selected as font.family.
import_default_google_fonts()


# Client used by later cells to fetch XManager experiments and metrics.
xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

# Shared plot style. `mpl.rcParams` and `plt.rcParams` are the same object, so
# one update sets everything. Other fonts that were tried: 'DejaVu Sans',
# 'Arial', 'Helvetica', 'Times New Roman', 'Verdana', 'Georgia'.
MEDIUM_SIZE = 12
mpl.rcParams.update({
    'font.family': 'Google Sans',
    'font.size': MEDIUM_SIZE,            # default text size
    'axes.titlesize': MEDIUM_SIZE,       # axes title
    'axes.labelsize': MEDIUM_SIZE,       # x and y labels
    'axes.linewidth': 1,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
    'xtick.labelsize': MEDIUM_SIZE,      # tick labels
    'ytick.labelsize': MEDIUM_SIZE,
    'legend.fontsize': MEDIUM_SIZE - 5,  # legends are deliberately smaller
    'figure.titlesize': MEDIUM_SIZE,     # figure suptitle
})

elegant_palette = sns.color_palette('muted')

In [None]:
def read_xm_metrics(example_xid, metric_name, unit_id, lowest=True):
  """Reads one measurement series of an XManager work unit via `xm_client`.

  Args:
    example_xid: XManager experiment id.
    metric_name: label of the measurement series to read.
    unit_id: work unit id as accepted by `get_work_unit` (this notebook passes
      1-based ids).
    lowest: if True return only the minimum objective value, otherwise the
      full list of objective values in series order.

  Returns:
    The minimum value, the list of values, or None when no series carries the
    label `metric_name`. An empty matching series with `lowest=True` raises
    ValueError from `min`.
  """
  work_unit = xm_client.get_experiment(example_xid).get_work_unit(unit_id)
  # Only the first series whose label matches is used.
  series = next(
      (s for s in work_unit.list_measurement_series() if s.label == metric_name),
      None,
  )
  if series is None:
    return None
  values = [point.objective_value for point in series.measurements]
  return min(values) if lowest else values


def add_min_columns(df):
  """Adds a `min_<column>` column holding the minimum of each row's metric list.

  The four validation-error columns are expected to contain one list of values
  per row. `df` is modified in place and also returned.
  """
  list_columns = (
      'valid_mean_absolute_error_all',
      'valid_mean_absolute_error_masked',
      'valid_mean_squared_error_all',
      'valid_mean_squared_error_masked',
  )
  for column in list_columns:
    # A lambda (not the bare builtin) so pandas applies it element-wise.
    df['min_' + column] = df[column].apply(lambda values: min(values))
  return df


def process_string_metric(input_string):
  """Shortens a metric name for display.

  'mean_absolute_error' becomes 'mae', 'mean_squared_error' becomes 'mse',
  every 'valid_' is removed and every '/' becomes '_'.
  """
  abbreviated = (
      input_string
      .replace('mean_absolute_error', 'mae')
      .replace('mean_squared_error', 'mse')
  )
  return abbreviated.replace('valid_', '').replace('/', '_')


def generate_percentiled_numbers(max_value, percentiles):
  """Converts percentiles of `max_value` into zero-based integer indices.

  Each result is `round(max_value * p / 100) - 1`, so 100% maps to the last
  index `max_value - 1`. Results are clamped at 0. Without the clamp, a
  percentile that rounds to 0 would produce -1, which silently selects the
  *last* element when used as a Python index.

  Args:
    max_value (int): The maximum value (e.g. a sequence length) to scale.
    percentiles (list of float): Percentiles in the range 0-100.

  Returns:
    list of int: One non-negative index per percentile, in input order.
  """
  return [max(round(max_value * (p / 100)) - 1, 0) for p in percentiles]

In [None]:
# @title Data Scaling
# Fetch hyper-parameters and metric trajectories for every work unit of the
# experiments below into one DataFrame (one row per work unit).

xm_id_dict = {  # xm_id -> [Model Size, Param Size (millions), Patch Size]
    124248449: ['Tiny', 2.21, '10x5'],
    124248804: ['ExtraSmall', 7.3, '10x5'],
    # 124142001: ['Small', 24.6, '10x5'],
    124248847: ['Base', 110.74, '10x5'],
}

compute_metrics = [
    'core_hours_TPU v5 lite',
    'train_mean_absolute_error_all',
    'train_mean_absolute_error_masked',
    'train_mean_squared_error_all',
    'train_mean_squared_error_masked',
]


metric_names = [
    'valid_mean_squared_error_masked',
    'forecast_0.1_eval/valid_mean_squared_error_masked',
    'forecast_0.2_eval/valid_mean_squared_error_masked',
    'forecast_0.4_eval/valid_mean_squared_error_masked',
    'imputation_0.1_eval/valid_mean_squared_error_masked',
    'imputation_0.2_eval/valid_mean_squared_error_masked',
    'imputation_0.4_eval/valid_mean_squared_error_masked',
]

# One dict per work unit. Building the frame from records (instead of parallel
# per-column lists) tolerates work units whose parameter sets differ: missing
# entries become NaN instead of pandas raising on unequal column lengths.
records = []
for xm_id, (model_size, param_size, patch_size) in xm_id_dict.items():
  experiment = xm_client.get_experiment(xm_id)
  for unit_idx in range(experiment.get_num_work_units()):
    work_unit_id = unit_idx + 1  # get_work_unit is called with 1-based ids.
    work_unit = experiment.get_work_unit(work_unit_id)
    record = {
        'unit_id': unit_idx,  # stored 0-based, as before
        'xm_id': xm_id,
        'Param Size': param_size,
        'Model Size': model_size,
        'Patch Size': patch_size,
    }
    for param_name in work_unit.parameters.keys():
      record[param_name] = work_unit.parameters[param_name]
    for metric in metric_names + compute_metrics:
      # Full trajectory (list of values), or None if the series is missing.
      record[metric] = read_xm_metrics(xm_id, metric, work_unit_id, lowest=False)
    records.append(record)

df = pd.DataFrame(records)
# df = add_min_columns(df)
df

In [None]:
# @title Random Imputation (Val Loss) - Style Updated
from scipy.optimize import curve_fit
from scipy.stats import linregress
from scipy.spatial import ConvexHull


def filter_pairs(x, y, max_y=0.5, box_x=0.6, box_y=0.4):
  """Drops unwanted points from paired x/y sequences.

  A pair (xi, yi) is discarded when
    * `max_y` is not None and yi > max_y, or
    * it lies in the lower-left box xi < box_x and yi < box_y.
  The defaults reproduce the thresholds this notebook has always used.

  Args:
    x, y: paired sequences (extra elements of the longer one are ignored).
    max_y: upper bound on y; None disables this cut.
    box_x, box_y: corner of the excluded lower-left region.

  Returns:
    (new_x, new_y): lists of the surviving pairs, in input order.
  """
  new_x, new_y = [], []
  for xi, yi in zip(x, y):
    if max_y is not None and yi > max_y:
      continue
    if xi < box_x and yi < box_y:
      continue
    new_x.append(xi)
    new_y.append(yi)
  return new_x, new_y


# FuncFormatter callbacks: show log-scale tick values as plain decimals.
def log_float_formatter(y, pos):
  """Formats tick value `y` with two decimals (`pos` is unused)."""
  return '{:.2f}'.format(y)


def log_float_formatter_3(y, pos):
  """Formats tick value `y` with three decimals (`pos` is unused)."""
  return '{:.3f}'.format(y)


# Scaling-law models fitted with curve_fit.
def scaling_function_full(C, a, b, c, d):
  """Power law with an offset on the input: a + b * (C + d) ** c."""
  shifted = C + d
  return b * shifted ** c + a


# No saturation for lower end; remove d
def scaling_function(C, a, b, c):
  """Power law without the input offset: a + b * C ** c."""
  return b * C ** c + a


def filter_best_pairs(x, y):
  """Returns each distinct x (sorted, via np.unique) with the smallest y seen for it.

  Note: later `filter_best_pairs` definitions in this cell replace this one.
  """
  distinct_x = np.unique(x)
  lowest_y = []
  for candidate in distinct_x:
    matches = [y_val for x_val, y_val in zip(x, y) if x_val == candidate]
    lowest_y.append(np.min(matches))
  return distinct_x, np.array(lowest_y)


def filter_best_pairs(x, y):
  """Keeps the pairs whose y never exceeds the y of any kept pair at smaller x.

  Pairs are stable-sorted by x, then scanned left to right, keeping a pair
  when its y is <= the running minimum. Empty input raises ValueError
  (from unpacking `zip(*[])`). The next definition in this cell replaces this one.
  """
  ordered = sorted(zip(x, y), key=lambda pair: pair[0])
  xs, ys = zip(*ordered)

  kept = []
  best_so_far = float('inf')
  for pair in zip(xs, ys):
    if pair[1] <= best_so_far:
      kept.append(pair)
      best_so_far = pair[1]

  return np.array([p[0] for p in kept]), np.array([p[1] for p in kept])


def filter_best_pairs(x, y):
  """Returns the running-minimum lower envelope of (x, y) pairs.

  Pairs are stable-sorted by x. Scanning left to right, a pair is kept when
  its y is <= the smallest y kept so far, so ties are retained.

  Args:
    x, y: paired sequences (extra elements of the longer one are ignored).

  Returns:
    (valid_x, valid_y) as numpy arrays. Both are empty for empty input (this
    used to raise ValueError from unpacking `zip(*[])`).
  """
  sorted_pairs = sorted(zip(x, y), key=lambda pair: pair[0])

  valid_x, valid_y = [], []
  min_y = float('inf')
  for x_val, y_val in sorted_pairs:
    if y_val <= min_y:
      valid_x.append(x_val)
      valid_y.append(y_val)
      min_y = y_val

  return np.array(valid_x), np.array(valid_y)


def format_scientific_latex(number):
  r"""Formats `number` as LaTeX scientific notation, e.g. '1.23 \times 10^{-2}'.

  The mantissa/exponent split is taken from Python's '.2e' formatting. When
  the mantissa rounds up to 10 (e.g. 9.999), this yields '1.00 \times 10^{1}'
  rather than the incorrect '10.00 \times 10^{0}'. Zero formats as
  '0.00 \times 10^{0}'. Non-finite values (nan/inf) are returned as str(number)
  instead of raising.
  """
  if not np.isfinite(number):
    return str(number)
  mantissa, exponent = '{:.2e}'.format(number).split('e')
  return r'{} \times 10^{{{}}}'.format(mantissa, int(exponent))


# Fit and plot the scaling function
def fit_and_plot_custom_scaling(x, y, ax, color, p0=(0.22, 0.16, -0.79),
                                label=None, xlabel='C'):
  """Fits `scaling_function` (L = a + b * X**c) to (x, y) and draws it on `ax`.

  Args:
    x, y: data points to fit.
    ax: Matplotlib Axes to draw on.
    color: colour of the dashed fitted curve.
    p0: initial guess for (a, b, c) passed to curve_fit. This is a tuple, so
      the shared default cannot be mutated.
    label: optional legend label for the fitted curve.
    xlabel: symbol used for the independent variable in the equation box.

  Returns:
    The optimal parameters from `scipy.optimize.curve_fit`.

  Raises:
    RuntimeError: propagated from curve_fit when the fit does not converge.
  """
  params, _ = curve_fit(
      scaling_function, x, y, p0=list(p0), maxfev=100000
  )

  print(
      f'Optimized parameters: a = {params[0]}, b = {params[1]}, c = {params[2]}'
  )
  a, b, c = params[0], params[1], params[2]

  # Evaluate the fit on a dense grid spanning the data range.
  plot_x = np.linspace(min(x), max(x), num=10000)
  fitted_y = scaling_function(plot_x, *params)

  # Show a negative b as "a - |b|" rather than "a + -|b|".
  b_sign = '-' if b < 0 else '+'
  equation_text = r'$L = {:.2f} {} {:.2f} \cdot {}^{{{:.2f}}}$'.format(
      a, b_sign, abs(b), xlabel, c
  )
  ax.text(
      0.95,
      0.95,
      equation_text,
      transform=ax.transAxes,
      fontsize=11,
      verticalalignment='top',
      horizontalalignment='right',
      bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'),
  )

  ax.plot(plot_x, fitted_y, color=color, label=label, alpha=0.8, linestyle='--')
  return params


def fit_and_plot_linear_scaling(x, y, ax, color, label=None):
  """Least-squares fit of y = slope * x + intercept, drawn on `ax`.

  The fitted line is labelled with its slope (`label` is accepted but unused),
  and the equation is annotated in the upper part of the axes.

  Returns:
    (slope, intercept) from scipy.stats.linregress.
  """
  xs = np.array(x)
  ys = np.array(y)

  fit = linregress(xs, ys)
  slope, intercept = fit.slope, fit.intercept
  print(f'Linear fit: y = {slope} * x + {intercept}')

  ax.plot(xs, slope * xs + intercept, color=color,
          label=f'Fit: slope={slope:.2f}', alpha=0.8)

  annotation_box = dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5')
  ax.text(
      0.55, 0.95, r'$y = {:.2f}x + {:.2f}$'.format(slope, intercept),
      transform=ax.transAxes, fontsize=8, verticalalignment='top',
      horizontalalignment='left', bbox=annotation_box,
  )
  return slope, intercept


# Settings for the teaser figure below.
use_aug = True     # keep only runs whose config.use_train_augmentations matches
use_last = True    # True: last metric value of a run; False: its minimum
sample_size = 1321235
compute_hours_steps = [1, 5, 10, 20, 40, 80, 100]
# Earlier palettes: ['#d32f2f', '#388e3c', '#1976d2'], ['#465ece', '#bed2f6', '#f8ab8d'].
colors = list(map(plt.cm.tab20c, range(3)))

other_metric_names = ['valid_mean_squared_error_masked']
line_alpha = 1
circle_alpha = 1

# One row of three panels: compute, data and model scaling.
fig, axs = plt.subplots(1, 3, figsize=(12, 4), dpi=100)
ax1, ax2, ax3 = axs

# Scatter marker areas, keyed by model-size name and by training-set size.
marker_size_map = {
    'Tiny': 25, 'ExtraSmall': 80, 'Base': 150,
    100000: 25, 750000: 80, 1321235: 150,
}
# Tiny -> colors[2], ExtraSmall -> colors[1], Base -> colors[0].
color_map = dict(zip(['Tiny', 'ExtraSmall', 'Base'], reversed(colors)))

data_scaling_list, model_scaling_list, compute_scaling_list = [], [], []
metric_name = 'valid_mean_squared_error_masked'
displayed_metric = process_string_metric(metric_name)

#####################################################################################
# Figure 1: Compute Scaling
# For every (data size, model size) run, plot a subsampled validation-loss vs.
# compute trajectory and collect the plotted points for the envelope fit.

# Points from all runs, plus the index of the run each point came from.
x_all, y_all, line_idx_all = [], [], []

ax1.set(xscale="log", yscale="log")

line_idx = 0
for data_size in [100000, 750000, 1321235]:
  for model_size in ['ExtraSmall', 'Base', 'Tiny']:
    subset = df[
        (df['Model Size'] == model_size)
        & (df['config.dataset_configs.train_num_samples'] == data_size)
        & (df['config.use_train_augmentations'] == use_aug)
    ]

    if not subset.empty:
      # Only the first matching run is plotted.
      compute_series = subset.iloc[0]['core_hours_TPU v5 lite']
      metric_series = subset.iloc[0][metric_name]
      compute_length = len(compute_series)
      # The same index is used in both series (assumed to be step-aligned —
      # NOTE(review): confirm), so stay within the shorter one.
      usable_length = min(compute_length, len(metric_series))

      # Subsample the trajectory: dense early, sparse later, always ending on
      # the last usable point. Out-of-range and duplicate indices are dropped,
      # so trajectories shorter than 91 points no longer raise IndexError.
      candidate_idx = (
          list(range(10)) + list(range(10, 50, 10)) + list(range(50, 100, 20))
          + [usable_length - 1]
      )
      idx_range = sorted(
          {idx for idx in candidate_idx if 0 <= idx < usable_length}
      )

      x = [compute_series[idx] for idx in idx_range]
      y = [metric_series[idx] for idx in idx_range]
      x, y = filter_pairs(x, y)

      sns.scatterplot(
          x=x,
          y=y,
          color=color_map[model_size],
          ax=ax1,
          s=marker_size_map[data_size],
          alpha=circle_alpha,
          legend=False,
      )
      if x:
        x_all.extend(x)
        y_all.extend(y)
        line_idx_all.extend([line_idx] * len(x))
      line_idx += 1
      sns.lineplot(
          x=x,
          y=y,
          color=color_map[model_size],
          ax=ax1,
          linewidth=1,
          alpha=line_alpha,
      )


# Convex hull of all plotted points; its lower part is fitted below.
df_xy = pd.DataFrame([x_all, y_all, line_idx_all]).T
df_xy.columns = ['col_x', 'col_y', 'col_idx']
hull = ConvexHull(df_xy[['col_x', 'col_y']])
hull_points = df_xy.iloc[hull.vertices]

def get_lower(polygon):
  """Returns the vertices from the min-x vertex to the max-x vertex, inclusive.

  `polygon` is an (n, >=2) array of ordered vertices (here the output of
  ConvexHull, whose 2-D vertices are counter-clockwise, so this walk is the
  lower hull). The slice wraps past the end of the array when the max-x
  vertex precedes the min-x vertex. Ties resolve to the first occurrence.
  """
  start = int(np.argmin(polygon[:, 0]))
  stop = int(np.argmax(polygon[:, 0])) + 1
  if start < stop:
    return polygon[start:stop]
  return np.concatenate([polygon[start:], polygon[:stop]])
# Fit the compute scaling law to the lower envelope of the convex hull,
# using only the (col_x, col_y) coordinates of the hull vertices.
lower_curve = get_lower(hull_points[['col_x', 'col_y']].to_numpy())
# The right-most envelope point is excluded from the fit.
df_xy_fit = pd.DataFrame(lower_curve[:-1], columns=['col_x', 'col_y'])

fit_and_plot_custom_scaling(
    df_xy_fit['col_x'], df_xy_fit['col_y'], ax1, color='k', xlabel='C'
)



#####################################################################################
# Figure 2: Data Scaling
# Final (use_last) or best validation loss per run vs. training data size,
# one line per model size; the lower convex-hull envelope is then fitted.

x_all, y_all = [], []

# Marker area per x value. Every x produced below must be a key here,
# otherwise the marker-size lookup raises KeyError.
marker_size_map_datascaling = {
    5000: 25, 50000: 50, 500000: 80, 3750000: 120, 6606175: 150
}

ax2.set(xscale="log", yscale="log")

for model_size in ['ExtraSmall', 'Base', 'Tiny']:
  subset = df[
      (df['Model Size'] == model_size)
      # & (df['config.dataset_configs.train_num_samples'] == data_size)
      & (df['config.use_train_augmentations'] == use_aug)
  ]
  if not subset.empty:
    x = []
    y = []
    for _, row in subset.iterrows():
      # NOTE(review): the x5 factor presumably converts samples to hours (the
      # axis caption below says 'Hours') — confirm.
      x.append(round(row['config.dataset_configs.train_num_samples'] * 5, 2))
      y.append(row[metric_name][-1] if use_last else min(row[metric_name]))
    if metric_name == 'valid_mean_squared_error_masked':
      data_scaling_list.append((x, y))

    # here data size should be based on x axis
    scatter = sns.scatterplot(
        x=x,
        y=y,
        s=[marker_size_map_datascaling[i] for i in x],  # Use 's' instead of 'size' to set marker size directly
        color=color_map[model_size],
        ax=ax2,
        alpha=circle_alpha,
        legend=False,
    )

    if (len(x) > 0):
      x_all.extend(x)
      y_all.extend(y)
      sns.lineplot(
          x=x,
          y=y,
          color=color_map[model_size],
          ax=ax2,
          linewidth=1,
          alpha=line_alpha
      )


# Convex hull of every (data size, loss) point across model sizes.
df_xy = pd.DataFrame([x_all, y_all]).T
df_xy.columns = ['col_x', 'col_y']
hull = ConvexHull(df_xy[['col_x', 'col_y']])
hull_points = df_xy.iloc[hull.vertices]

# Lower envelope (min-x vertex to max-x vertex); the first (left-most)
# envelope point is excluded from the fit.
lower_curve = get_lower(np.array(hull_points)[:])
df_xy_fit = pd.DataFrame(lower_curve[1:])
df_xy_fit.columns = ['col_x', 'col_y']

# sns.lineplot(
#     x=df_xy_fit['col_x'].values,
#     y=df_xy_fit['col_y'].values,
#     color='green',
#     ax=ax2,
#     linewidth = 1,
#     alpha=0.8,
# )

fit_and_plot_custom_scaling(
    df_xy_fit['col_x'], df_xy_fit['col_y'], ax2, color='k', xlabel='D'
)

plt.tight_layout()



#####################################################################################
# Figure 3: Model Scaling
# Final (use_last) or best validation loss vs. parameter count (millions),
# one line per training-set size; the lower hull envelope is then fitted.
x_all, y_all = [], []

ax3.set(xscale="log", yscale="log")

for data_size in [750000, 1321235]:
  subset = df[
      (df['config.dataset_configs.train_num_samples'] == data_size)
      & (df['config.use_train_augmentations'] == use_aug)
  ]

  if not subset.empty:
    rows = [row for _, row in subset.iterrows()]
    x = [row['Param Size'] for row in rows]
    y = [
        row[metric_name][-1] if use_last else min(row[metric_name])
        for row in rows
    ]
    sizes = [marker_size_map[row['Model Size']] for row in rows]
    # Colour each point by its own row's model size. This works for any
    # number of rows in any order, instead of pairing points positionally
    # with color_map's three values.
    point_colors = [color_map[row['Model Size']] for row in rows]

    if metric_name == 'valid_mean_squared_error_masked':
      model_scaling_list.append((x, y))
    sns.scatterplot(
        x=x,
        y=y,
        s=marker_size_map[data_size],
        color=point_colors,
        ax=ax3,
        alpha=circle_alpha,
        legend=False,
    )
    sns.lineplot(x=x, y=y, color=colors[-1], ax=ax3, linewidth=1, alpha=line_alpha)
    x_all.extend(x)
    y_all.extend(y)

# Convex hull of all (params, loss) points; fit its lower envelope with the
# first and last envelope points excluded.
df_xy = pd.DataFrame([x_all, y_all]).T
df_xy.columns = ['col_x', 'col_y']
hull = ConvexHull(df_xy[['col_x', 'col_y']])
hull_points = df_xy.iloc[hull.vertices]

lower_curve = get_lower(np.array(hull_points))
df_xy_fit = pd.DataFrame(lower_curve[1:-1], columns=['col_x', 'col_y'])

fit_and_plot_custom_scaling(
    df_xy_fit['col_x'], df_xy_fit['col_y'], ax3, color='k', xlabel='N'
)

plt.tight_layout()

#####################################################################################
# Titles and labels: a bold quantity name on each x axis with a grey unit
# caption underneath, and plain-decimal tick labels on the log y axes.
ax1.set_ylabel('Mean Squared Error')
for panel, quantity, caption in (
    (ax1, r'$\mathbf{Compute}$ [C]', 'TPU v5e core hours'),
    (ax2, r'$\mathbf{Data\ Size}$ [D]', 'Hours'),
    (ax3, r'$\mathbf{Model\ Size}$ [N]', 'Million of Params'),
):
  panel.set_xlabel(quantity)
  panel.text(0.5, -0.24, caption, transform=panel.transAxes, color='gray', ha='center')

for panel, tick_formatter in (
    (ax1, log_float_formatter),
    (ax2, log_float_formatter),
    (ax3, log_float_formatter_3),
):
  panel.yaxis.set_major_locator(LogLocator(base=10.0, subs=np.arange(1, 10), numticks=10))
  panel.yaxis.set_major_formatter(FuncFormatter(tick_formatter))

# Legend proxies for the model-size marker scale. They are built as standalone
# Line2D artists: plt.scatter([], []) would also add empty collections to the
# current Axes (ax3) and to the exported figure.
marker_sizes = [
    marker_size_map['Tiny'],
    marker_size_map['ExtraSmall'],
    marker_size_map['Base'],
]
marker_labels = ['2M', '7M', '110M']

marker_handles = [
    # Scatter sizes are areas (points^2); Line2D markersize is a length.
    mlines.Line2D([], [], color='black', marker='o', linestyle='None',
                  markersize=np.sqrt(size))
    for size in marker_sizes
]

# Combine handles and labels
combined_handles = marker_handles
combined_labels = marker_labels

#####################################################################################
# Legend
# Two figure-level legends above the panels: marker size encodes data size,
# colour encodes model size.
marker_size_handle_labels = ['0.005M', '0.05M', '0.5M', '3.8M', '6.6M']
marker_size_handles = [
    # markersize is the square root of the scatter area used in Figure 2, so
    # legend markers match the plotted marker sizes.
    mlines.Line2D([], [], color='black', marker='o', linestyle=':',
                  markersize=np.sqrt(marker_size_map_datascaling[data_size]),
                  label=marker_size_handle_labels[i])
    for i, data_size in enumerate([5000, 50000, 500000, 3750000, 6606175])
]
color_handle_labels = ['ViT 2M', 'ViT 7M', 'ViT 110M']
color_handles = [
    mlines.Line2D([], [], color=color_map[model_size], marker='o', linestyle='-',
                  label=color_handle_labels[i])
    for i, model_size in enumerate(['Tiny', 'ExtraSmall', 'Base'])
]

# combined_handles = marker_size_handles + color_handles
# combined_labels = [h.get_label() for h in combined_handles]
# fig.legend(handles=combined_handles, labels=combined_labels, ncol=6, loc='upper center')
# Both legends are anchored above the figure (y=1.1); the bbox_inches='tight'
# saves below include them in the exported files.
legend1 = fig.legend(
    handles=marker_size_handles, labels=[h.get_label() for h in marker_size_handles],
    ncol=5, loc='upper left', bbox_to_anchor=(0.05, 1.1), fontsize=12, frameon=False)
legend2 = fig.legend(
    handles=color_handles, labels=[h.get_label() for h in color_handles],
    ncol=3, loc='upper right', bbox_to_anchor=(0.96, 1.1), fontsize=12, frameon=False)
# NOTE(review): the plt.tight_layout() call below recomputes subplot params and
# likely overrides this top=0.80 — confirm whether this line has any effect.
plt.subplots_adjust(top=0.80)

#####################################################################################
# Final plot
# Save as PDF and SVG under /tmp. %download_file is an environment-provided
# notebook magic (not standard IPython) that presumably downloads the file.
plt.tight_layout()
plt.savefig("/tmp/teaser.pdf", bbox_inches='tight', format="pdf")
%download_file /tmp/teaser.pdf
plt.savefig("/tmp/teaser.svg", bbox_inches='tight', format="svg")
%download_file /tmp/teaser.svg
plt.show()

In [None]:
# @title Random Imputation (Val Loss)
from scipy.optimize import curve_fit
from scipy.stats import linregress

def filter_pairs(x, y, max_y=None, box_x=0.6, box_y=0.4):
  """Drops unwanted points from paired x/y sequences.

  A pair (xi, yi) is discarded when
    * `max_y` is not None and yi > max_y, or
    * it lies in the lower-left box xi < box_x and yi < box_y.
  The defaults reproduce this cell's thresholds (no upper y cut).

  Args:
    x, y: paired sequences (extra elements of the longer one are ignored).
    max_y: optional upper bound on y; None disables this cut.
    box_x, box_y: corner of the excluded lower-left region.

  Returns:
    (new_x, new_y): lists of the surviving pairs, in input order.
  """
  new_x, new_y = [], []
  for xi, yi in zip(x, y):
    if max_y is not None and yi > max_y:
      continue
    if xi < box_x and yi < box_y:
      continue
    new_x.append(xi)
    new_y.append(yi)
  return new_x, new_y


# Scaling law with an input offset: a + b * (C + d) ** c.
def scaling_function(C, a, b, c, d):
  """Returns a + b * (C + d) ** c (elementwise for array inputs)."""
  shifted = C + d
  return b * shifted ** c + a


def filter_best_pairs(x, y):
  """Returns each distinct x (sorted, via np.unique) with the smallest y seen for it.

  Note: later `filter_best_pairs` definitions in this cell replace this one.
  """
  distinct_x = np.unique(x)
  lowest_y = []
  for candidate in distinct_x:
    matches = [y_val for x_val, y_val in zip(x, y) if x_val == candidate]
    lowest_y.append(np.min(matches))
  return distinct_x, np.array(lowest_y)


def filter_best_pairs(x, y):
  """Keeps the pairs whose y never exceeds the y of any kept pair at smaller x.

  Pairs are stable-sorted by x, then scanned left to right, keeping a pair
  when its y is <= the running minimum. Empty input raises ValueError
  (from unpacking `zip(*[])`). The next definition in this cell replaces this one.
  """
  ordered = sorted(zip(x, y), key=lambda pair: pair[0])
  xs, ys = zip(*ordered)

  kept = []
  best_so_far = float('inf')
  for pair in zip(xs, ys):
    if pair[1] <= best_so_far:
      kept.append(pair)
      best_so_far = pair[1]

  return np.array([p[0] for p in kept]), np.array([p[1] for p in kept])


def filter_best_pairs(x, y):
  """Returns the running-minimum lower envelope of (x, y) pairs.

  Pairs are stable-sorted by x. Scanning left to right, a pair is kept when
  its y is <= the smallest y kept so far, so ties are retained.

  Args:
    x, y: paired sequences (extra elements of the longer one are ignored).

  Returns:
    (valid_x, valid_y) as numpy arrays. Both are empty for empty input (this
    used to raise ValueError from unpacking `zip(*[])`).
  """
  sorted_pairs = sorted(zip(x, y), key=lambda pair: pair[0])

  valid_x, valid_y = [], []
  min_y = float('inf')
  for x_val, y_val in sorted_pairs:
    if y_val <= min_y:
      valid_x.append(x_val)
      valid_y.append(y_val)
      min_y = y_val

  return np.array(valid_x), np.array(valid_y)


def format_scientific_latex(number):
  r"""Formats `number` as LaTeX scientific notation, e.g. '1.23 \times 10^{-2}'.

  The mantissa/exponent split is taken from Python's '.2e' formatting. When
  the mantissa rounds up to 10 (e.g. 9.999), this yields '1.00 \times 10^{1}'
  rather than the incorrect '10.00 \times 10^{0}'. Zero formats as
  '0.00 \times 10^{0}'. Non-finite values (nan/inf) are returned as str(number)
  instead of raising.
  """
  if not np.isfinite(number):
    return str(number)
  mantissa, exponent = '{:.2e}'.format(number).split('e')
  return r'{} \times 10^{{{}}}'.format(mantissa, int(exponent))


# Fit and plot the scaling function
def fit_and_plot_custom_scaling(x, y, ax, color, label=None,
                                p0=(0.22, 0.16, -0.79, -0.07)):
  """Fits `scaling_function` (L = a + b * (x + d)**c) and draws it on `ax`.

  The fitted curve is drawn against log10(x), as before. NOTE(review): confirm
  that the data points on `ax` are drawn in the same log10 space.

  Args:
    x, y: data points to fit.
    ax: Matplotlib Axes to draw on.
    color: colour of the fitted curve.
    label: optional legend label for the fitted curve.
    p0: initial guess for (a, b, c, d) passed to curve_fit.

  Returns:
    The optimal parameters from `scipy.optimize.curve_fit`.

  Raises:
    RuntimeError: propagated from curve_fit when the fit does not converge.
  """
  params, _ = curve_fit(
      scaling_function, x, y, p0=list(p0), maxfev=10000
  )

  print(
      f'Optimized parameters: a = {params[0]}, b = {params[1]}, c ='
      f' {params[2]}, d = {params[3]}'
  )
  a, b, c, d = params[0], params[1], params[2], params[3]

  # Render signs explicitly so negative b/d read "- 7.00 \times 10^{-2}"
  # instead of "+ -7.00 \times 10^{-2}".
  b_sign = '-' if b < 0 else '+'
  d_sign = '-' if d < 0 else '+'
  formatted_b = format_scientific_latex(abs(b))
  formatted_d = format_scientific_latex(abs(d))
  equation_text = r'$L = {:.2f} {} {} \cdot (x {} {})^{{{:.2f}}}$'.format(
      a, b_sign, formatted_b, d_sign, formatted_d, c
  )
  ax.text(
      0.15,
      0.90,
      equation_text,
      transform=ax.transAxes,
      fontsize=8,
      verticalalignment='top',
      horizontalalignment='left',
      bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'),
  )

  # Sort x so the curve is drawn left to right instead of zig-zagging
  # between unsorted data points.
  x_sorted = np.sort(np.asarray(x, dtype=float))
  fitted_y = scaling_function(x_sorted, *params)
  ax.plot(np.log10(x_sorted), fitted_y, color=color, label=label, alpha=0.8)
  return params


def fit_and_plot_linear_scaling(x, y, ax, color, label=None):
  """Fits y = slope * x + intercept by least squares and draws it on `ax`.

  Callers in this notebook pass log10-transformed x, so the fit is linear in
  log space; the function itself applies no transformation.

  Args:
    x: Independent values (array-like).
    y: Dependent values aligned with `x`.
    ax: Matplotlib axes to draw the line and the equation annotation on.
    color: Colour of the fitted line.
    label: Label for the fitted line. When None, falls back to
      'Fit: slope=<slope>' (the label that used to be forced here).

  Returns:
    (slope, intercept) of the regression line.
  """
  x = np.asarray(x, dtype=float)
  y = np.asarray(y, dtype=float)

  fit = linregress(x, y)
  slope, intercept = fit.slope, fit.intercept

  # Print the linear fit parameters (slope and intercept).
  print(f'Linear fit: y = {slope} * x + {intercept}')

  # Evaluate the line on sorted x so the drawn segment is monotone in x.
  x_line = np.sort(x)
  fitted_y = slope * x_line + intercept

  # Honour the caller's label; previously it was silently ignored.
  line_label = label if label is not None else f'Fit: slope={slope:.2f}'
  ax.plot(x_line, fitted_y, color=color, label=line_label, alpha=0.8)

  # Annotate the plot with the fitted equation in LaTeX form.
  equation_text = r'$y = {:.2f}x + {:.2f}$'.format(slope, intercept)
  ax.text(
      0.55,
      0.95,
      equation_text,
      transform=ax.transAxes,
      fontsize=8,
      verticalalignment='top',
      horizontalalignment='left',
      bbox=dict(facecolor='white', alpha=0.8, boxstyle='round,pad=0.5'),
  )

  return slope, intercept


use_aug = True
use_last = True
sample_size = 1321235
compute_hours_steps = [1, 5, 10, 20, 40, 80, 100]
colors = ['#465ece', '#bed2f6', '#f8ab8d']  # Blue, light blue, salmon

other_metric_names = [
    'valid_mean_squared_error_masked',
]
line_alpha = 0.2
circle_alpha = 1

# One row of three panels: compute, data and model scaling.
fig, axs = plt.subplots(1, 3, figsize=(12, 4), dpi=100)
ax1, ax2, ax3 = axs

# Marker size per model size (currently identical for every model).
marker_size_map = {
    'Deb': 2,
    'Tiny': 2,
    'ExtraSmall': 2,
    'Small': 2,
    'Base': 2,
    'Large': 2,
}

data_scaling_list = []
model_scaling_list = []
compute_scaling_list = []
metric_name = 'valid_mean_squared_error_masked'
displayed_metric = process_string_metric(metric_name)

#####################################################################################
# Figure 1: Compute Scaling
# Every (data size, model size) run is drawn as a metric-vs-compute trajectory.
# Only the Tiny model trained on 100k samples is highlighted and fitted.
x_all, y_all = [], []

for data_size in [100000, 750000, 1321235]:
  for model_size in ['ExtraSmall', 'Base', 'Tiny']:
    subset = df[
        (df['Model Size'] == model_size)
        & (df['config.dataset_configs.train_num_samples'] == data_size)
        & (df['config.use_train_augmentations'] == use_aug)
    ]
    if subset.empty:
      continue

    run = subset.iloc[0]
    compute_history = run['core_hours_TPU v5 lite']
    metric_history = run[metric_name]
    # The compute and metric histories are not guaranteed to be the same
    # length; truncate both to the shorter one so the pairs stay aligned and
    # we never index past the end of the metric history.
    min_length = min(len(compute_history), len(metric_history))
    x, y = filter_pairs(
        list(compute_history[:min_length]), list(metric_history[:min_length])
    )
    log_x = np.log10(x)
    sns.scatterplot(
        x=log_x,
        y=y,
        color='black',
        ax=ax1,
        s=marker_size_map[model_size],
        alpha=circle_alpha,
        legend=False,
    )
    is_highlighted = model_size == 'Tiny' and data_size == 100000
    if is_highlighted:
      x_all.extend(x)
      y_all.extend(y)
    sns.lineplot(
        x=log_x,
        y=y,
        color='red' if is_highlighted else 'black',
        ax=ax1,
        alpha=line_alpha,
    )

fit_and_plot_custom_scaling(
    x_all, y_all, ax1, color='#217BFE', label='Best Fit'
)
sns.scatterplot(
    x=np.log10(x_all),
    y=y_all,
    color='#217BFE',
    ax=ax1,
    s=30,
    alpha=circle_alpha,
    legend=False,
)

#####################################################################################
# Figure 2: Data Scaling
# Final metric (or the best one when use_last is False) of the Base model for
# each training-set size.
# NOTE(review): unlike Figures 1 and 3 this subset is not filtered on
# `use_aug`; confirm whether runs with and without augmentation should mix.
x_all_data, y_all_data = [], []

for model_size in ['Base']:
  subset = df[df['Model Size'] == model_size]
  if subset.empty:
    continue

  x = []
  y = []
  for _, row in subset.iterrows():
    x.append(round(row['config.dataset_configs.train_num_samples'] * 5, 2))
    y.append(row[metric_name][-1] if use_last else min(row[metric_name]))
  log_x = np.log10(x)
  if metric_name == 'valid_mean_squared_error_masked':
    data_scaling_list.append((x, y))
  sns.scatterplot(
      x=log_x,
      y=y,
      s=marker_size_map[model_size],
      color='black',
      ax=ax2,
      alpha=circle_alpha,
      legend=False,
  )
  x_all_data.extend(x)
  y_all_data.extend(y)
  sns.lineplot(x=log_x, y=y, color='black', ax=ax2, alpha=line_alpha)

fit_and_plot_custom_scaling(
    x_all_data, y_all_data, ax2, color='#217BFE', label='Best Fit'
)
# Alternative: a straight-line fit in log space via
# fit_and_plot_linear_scaling(np.log10(x_all_data), y_all_data, ...).
sns.scatterplot(
    x=np.log10(x_all_data),
    y=y_all_data,
    color='#217BFE',
    ax=ax2,
    s=30,
    alpha=circle_alpha,
    legend=False,
)

#####################################################################################
# Figure 3: Model Scaling
# Final metric (or the best one when use_last is False) per model for the full
# dataset.
x_all, y_all = [], []
for data_size in [1321235]:
  subset = df[
      (df['config.dataset_configs.train_num_samples'] == data_size)
      & (df['config.use_train_augmentations'] == use_aug)
  ]
  if subset.empty:
    continue

  x = [row['Param Size'] for _, row in subset.iterrows()]
  y = [
      row[metric_name][-1] if use_last else min(row[metric_name])
      for _, row in subset.iterrows()
  ]
  sizes = [marker_size_map[row['Model Size']] for _, row in subset.iterrows()]
  if metric_name == 'valid_mean_squared_error_masked':
    model_scaling_list.append((x, y))
  x_log = np.log10(x)
  sns.scatterplot(
      x=x_log,
      y=y,
      size=sizes,
      sizes=(75, 150),
      color='black',
      ax=ax3,
      alpha=circle_alpha,
      legend=False,
  )
  sns.lineplot(x=x_log, y=y, color='black', ax=ax3, alpha=line_alpha)
  x_all.extend(x)
  y_all.extend(y)

fit_and_plot_linear_scaling(
    np.log10(x_all), y_all, ax3, color='#217BFE', label='Best Fit'
)
sns.scatterplot(
    x=np.log10(x_all),
    y=y_all,
    color='#217BFE',
    ax=ax3,
    s=30,
    alpha=circle_alpha,
    legend=False,
)

#####################################################################################
# Titles and labels
ax1.set_xlabel(r'$\mathbf{Compute}$' + '\n TPU v5 VLP core hours')
ax1.set_ylabel('Masked Mean Squared Error')
ax2.set_xlabel(r'$\mathbf{Data\ Size}$' + '\n(Hours)')
ax3.set_xlabel(r'$\mathbf{Model\ Size}$' + '\n(Million of Params)')

# The x data are log10 values, so integer ticks correspond to powers of ten.
for ax in (ax1, ax2, ax3):
  ax.xaxis.set_major_locator(MaxNLocator(integer=True, prune='both'))
  xticks = ax.get_xticks()
  # Freeze the positions so the labels set below stay aligned with them.
  ax.xaxis.set_major_locator(FixedLocator(xticks))
  # Braces make mathtext superscript the whole exponent (e.g. -1 or 10), and
  # round() keeps a tick such as 2.9999999999999996 from labelling as 10^2.
  ax.set_xticklabels([
      f'$10^{{{int(round(val))}}}$' if i == 0 or val != xticks[i - 1] else ''
      for i, val in enumerate(xticks)
  ])

marker_sizes = [
    marker_size_map['Tiny'],
    marker_size_map['ExtraSmall'],
    marker_size_map['Base'],
]
marker_labels = ['2M', '7M', '110M']

# Empty scatter artists usable as legend handles for the model sizes.
marker_handles = [
    plt.scatter([], [], s=size, color='black') for size in marker_sizes
]

combined_handles = marker_handles
combined_labels = marker_labels
plt.tight_layout()
plt.show()

In [None]:
# @title Forecasting

use_aug = True
use_last = True
sample_size = 1321235
compute_hours_steps = [1, 5, 10, 20, 40, 80, 100]
colors = ['#465ece', '#bed2f6', '#f8ab8d']  # Blue, light blue, salmon

other_metric_names = [
    'forecast_0.2_eval/valid_mean_squared_error_masked',
]
metric_name = 'forecast_0.2_eval/valid_mean_squared_error_masked'
line_alpha = 0.2
circle_alpha = 1

# One row of three panels: compute, data and model scaling.
fig, axs = plt.subplots(1, 3, figsize=(12, 4), dpi=100)
ax1, ax2, ax3 = axs

# Marker size per model size (currently identical for every model).
marker_size_map = {
    'Deb': 2,
    'Tiny': 2,
    'ExtraSmall': 2,
    'Small': 2,
    'Base': 2,
    'Large': 2,
}

data_scaling_list = []
model_scaling_list = []
compute_scaling_list = []
displayed_metric = process_string_metric(metric_name)

#####################################################################################
# Figure 1: Compute Scaling
# Metric-vs-compute trajectory of the Base model on the full dataset. No
# compute-scaling fit is drawn for this metric, so nothing is highlighted
# (the former Tiny/100k highlight branch could never match these loops).
for data_size in [1321235]:
  for model_size in ['Base']:
    subset = df[
        (df['Model Size'] == model_size)
        & (df['config.dataset_configs.train_num_samples'] == data_size)
        & (df['config.use_train_augmentations'] == use_aug)
    ]
    if subset.empty:
      continue

    run = subset.iloc[0]
    compute_history = run['core_hours_TPU v5 lite']
    metric_history = run[metric_name]
    # The histories can differ in length; truncate to the shorter one so the
    # (compute, metric) pairs stay aligned.
    min_length = min(len(compute_history), len(metric_history))
    x, y = filter_pairs(
        list(compute_history[:min_length]), list(metric_history[:min_length])
    )
    log_x = np.log10(x)
    sns.scatterplot(
        x=log_x,
        y=y,
        color='black',
        ax=ax1,
        s=marker_size_map[model_size],
        alpha=circle_alpha,
        legend=False,
    )
    sns.lineplot(x=log_x, y=y, color='black', ax=ax1, alpha=line_alpha)

#####################################################################################
# Figure 2: Data Scaling
# Final metric (or the best one when use_last is False) of the Base model for
# each training-set size.
# NOTE(review): unlike Figures 1 and 3 this subset is not filtered on
# `use_aug`; confirm whether runs with and without augmentation should mix.
x_all_data, y_all_data = [], []

for model_size in ['Base']:
  subset = df[df['Model Size'] == model_size]
  if subset.empty:
    continue

  x = []
  y = []
  for _, row in subset.iterrows():
    x.append(round(row['config.dataset_configs.train_num_samples'] * 5, 2))
    y.append(row[metric_name][-1] if use_last else min(row[metric_name]))
  log_x = np.log10(x)
  sns.scatterplot(
      x=log_x,
      y=y,
      s=marker_size_map[model_size],
      color='black',
      ax=ax2,
      alpha=circle_alpha,
      legend=False,
  )
  x_all_data.extend(x)
  y_all_data.extend(y)
  sns.lineplot(x=log_x, y=y, color='black', ax=ax2, alpha=line_alpha)

fit_and_plot_custom_scaling(
    x_all_data, y_all_data, ax2, color='#217BFE', label='Best Fit'
)
# Alternative: a straight-line fit in log space via
# fit_and_plot_linear_scaling(np.log10(x_all_data), y_all_data, ...).
sns.scatterplot(
    x=np.log10(x_all_data),
    y=y_all_data,
    color='#217BFE',
    ax=ax2,
    s=30,
    alpha=circle_alpha,
    legend=False,
)

#####################################################################################
# Figure 3: Model Scaling
# Final metric (or the best one when use_last is False) per model for the full
# dataset.
x_all, y_all = [], []
for data_size in [1321235]:
  subset = df[
      (df['config.dataset_configs.train_num_samples'] == data_size)
      & (df['config.use_train_augmentations'] == use_aug)
  ]
  if subset.empty:
    continue

  x = [row['Param Size'] for _, row in subset.iterrows()]
  y = [
      row[metric_name][-1] if use_last else min(row[metric_name])
      for _, row in subset.iterrows()
  ]
  sizes = [marker_size_map[row['Model Size']] for _, row in subset.iterrows()]
  x_log = np.log10(x)
  sns.scatterplot(
      x=x_log,
      y=y,
      size=sizes,
      sizes=(75, 150),
      color='black',
      ax=ax3,
      alpha=circle_alpha,
      legend=False,
  )
  sns.lineplot(x=x_log, y=y, color='black', ax=ax3, alpha=line_alpha)
  x_all.extend(x)
  y_all.extend(y)

fit_and_plot_linear_scaling(
    np.log10(x_all), y_all, ax3, color='#217BFE', label='Best Fit'
)
sns.scatterplot(
    x=np.log10(x_all),
    y=y_all,
    color='#217BFE',
    ax=ax3,
    s=30,
    alpha=circle_alpha,
    legend=False,
)

#####################################################################################
# Titles and labels
ax1.set_xlabel(r'$\mathbf{Compute}$' + '\n TPU v5 VLP core hours')
ax1.set_ylabel('Masked Mean Squared Error')
ax2.set_xlabel(r'$\mathbf{Data\ Size}$' + '\n(Hours)')
ax3.set_xlabel(r'$\mathbf{Model\ Size}$' + '\n(Million of Params)')

# The x data are log10 values, so integer ticks correspond to powers of ten.
for ax in (ax1, ax2, ax3):
  ax.xaxis.set_major_locator(MaxNLocator(integer=True, prune='both'))
  xticks = ax.get_xticks()
  # Freeze the positions so the labels set below stay aligned with them.
  ax.xaxis.set_major_locator(FixedLocator(xticks))
  # Braces make mathtext superscript the whole exponent (e.g. -1 or 10), and
  # round() keeps a tick such as 2.9999999999999996 from labelling as 10^2.
  ax.set_xticklabels([
      f'$10^{{{int(round(val))}}}$' if i == 0 or val != xticks[i - 1] else ''
      for i, val in enumerate(xticks)
  ])

marker_sizes = [
    marker_size_map['Tiny'],
    marker_size_map['ExtraSmall'],
    marker_size_map['Base'],
]
marker_labels = ['2M', '7M', '110M']

# Empty scatter artists usable as legend handles for the model sizes.
marker_handles = [
    plt.scatter([], [], s=size, color='black') for size in marker_sizes
]

combined_handles = marker_handles
combined_labels = marker_labels
plt.tight_layout()
plt.show()

In [None]:
# @title Imputation

use_aug = True
use_last = True
sample_size = 1321235
compute_hours_steps = [1, 5, 10, 20, 40, 80, 100]
colors = ['#465ece', '#bed2f6', '#f8ab8d']  # Blue, light blue, salmon

other_metric_names = [
    'imputation_0.2_eval/valid_mean_squared_error_masked',
]
metric_name = 'imputation_0.2_eval/valid_mean_squared_error_masked'
line_alpha = 0.2
circle_alpha = 1

# One row of three panels: compute, data and model scaling.
fig, axs = plt.subplots(1, 3, figsize=(12, 4), dpi=100)
ax1, ax2, ax3 = axs

# Marker size per model size (currently identical for every model).
marker_size_map = {
    'Deb': 2,
    'Tiny': 2,
    'ExtraSmall': 2,
    'Small': 2,
    'Base': 2,
    'Large': 2,
}

data_scaling_list = []
model_scaling_list = []
compute_scaling_list = []
displayed_metric = process_string_metric(metric_name)

#####################################################################################
# Figure 1: Compute Scaling
# Metric-vs-compute trajectory of the Base model on the full dataset. No
# compute-scaling fit is drawn for this metric, so nothing is highlighted
# (the former Tiny/100k highlight branch could never match these loops).
for data_size in [1321235]:
  for model_size in ['Base']:
    subset = df[
        (df['Model Size'] == model_size)
        & (df['config.dataset_configs.train_num_samples'] == data_size)
        & (df['config.use_train_augmentations'] == use_aug)
    ]
    if subset.empty:
      continue

    run = subset.iloc[0]
    compute_history = run['core_hours_TPU v5 lite']
    metric_history = run[metric_name]
    # The histories can differ in length; truncate to the shorter one so the
    # (compute, metric) pairs stay aligned.
    min_length = min(len(compute_history), len(metric_history))
    x, y = filter_pairs(
        list(compute_history[:min_length]), list(metric_history[:min_length])
    )
    log_x = np.log10(x)
    sns.scatterplot(
        x=log_x,
        y=y,
        color='black',
        ax=ax1,
        s=marker_size_map[model_size],
        alpha=circle_alpha,
        legend=False,
    )
    sns.lineplot(x=log_x, y=y, color='black', ax=ax1, alpha=line_alpha)

#####################################################################################
# Figure 2: Data Scaling
# Final metric (or the best one when use_last is False) of the Base model for
# each training-set size.
# NOTE(review): unlike Figures 1 and 3 this subset is not filtered on
# `use_aug`; confirm whether runs with and without augmentation should mix.
x_all_data, y_all_data = [], []

for model_size in ['Base']:
  subset = df[df['Model Size'] == model_size]
  if subset.empty:
    continue

  x = []
  y = []
  for _, row in subset.iterrows():
    x.append(round(row['config.dataset_configs.train_num_samples'] * 5, 2))
    y.append(row[metric_name][-1] if use_last else min(row[metric_name]))
  log_x = np.log10(x)
  sns.scatterplot(
      x=log_x,
      y=y,
      s=marker_size_map[model_size],
      color='black',
      ax=ax2,
      alpha=circle_alpha,
      legend=False,
  )
  x_all_data.extend(x)
  y_all_data.extend(y)
  sns.lineplot(x=log_x, y=y, color='black', ax=ax2, alpha=line_alpha)

fit_and_plot_custom_scaling(
    x_all_data, y_all_data, ax2, color='#217BFE', label='Best Fit'
)
# Alternative: a straight-line fit in log space via
# fit_and_plot_linear_scaling(np.log10(x_all_data), y_all_data, ...).
sns.scatterplot(
    x=np.log10(x_all_data),
    y=y_all_data,
    color='#217BFE',
    ax=ax2,
    s=30,
    alpha=circle_alpha,
    legend=False,
)

#####################################################################################
# Figure 3: Model Scaling
# Final metric (or the best one when use_last is False) per model for the full
# dataset.
x_all, y_all = [], []
for data_size in [1321235]:
  subset = df[
      (df['config.dataset_configs.train_num_samples'] == data_size)
      & (df['config.use_train_augmentations'] == use_aug)
  ]
  if subset.empty:
    continue

  x = [row['Param Size'] for _, row in subset.iterrows()]
  y = [
      row[metric_name][-1] if use_last else min(row[metric_name])
      for _, row in subset.iterrows()
  ]
  sizes = [marker_size_map[row['Model Size']] for _, row in subset.iterrows()]
  x_log = np.log10(x)
  sns.scatterplot(
      x=x_log,
      y=y,
      size=sizes,
      sizes=(75, 150),
      color='black',
      ax=ax3,
      alpha=circle_alpha,
      legend=False,
  )
  sns.lineplot(x=x_log, y=y, color='black', ax=ax3, alpha=line_alpha)
  x_all.extend(x)
  y_all.extend(y)

fit_and_plot_linear_scaling(
    np.log10(x_all), y_all, ax3, color='#217BFE', label='Best Fit'
)
sns.scatterplot(
    x=np.log10(x_all),
    y=y_all,
    color='#217BFE',
    ax=ax3,
    s=30,
    alpha=circle_alpha,
    legend=False,
)

#####################################################################################
# Titles and labels
ax1.set_xlabel(r'$\mathbf{Compute}$' + '\n TPU v5 VLP core hours')
ax1.set_ylabel('Masked Mean Squared Error')
ax2.set_xlabel(r'$\mathbf{Data\ Size}$' + '\n(Hours)')
ax3.set_xlabel(r'$\mathbf{Model\ Size}$' + '\n(Million of Params)')

# The x data are log10 values, so integer ticks correspond to powers of ten.
for ax in (ax1, ax2, ax3):
  ax.xaxis.set_major_locator(MaxNLocator(integer=True, prune='both'))
  xticks = ax.get_xticks()
  # Freeze the positions so the labels set below stay aligned with them.
  ax.xaxis.set_major_locator(FixedLocator(xticks))
  # Braces make mathtext superscript the whole exponent (e.g. -1 or 10), and
  # round() keeps a tick such as 2.9999999999999996 from labelling as 10^2.
  ax.set_xticklabels([
      f'$10^{{{int(round(val))}}}$' if i == 0 or val != xticks[i - 1] else ''
      for i, val in enumerate(xticks)
  ])

marker_sizes = [
    marker_size_map['Tiny'],
    marker_size_map['ExtraSmall'],
    marker_size_map['Base'],
]
marker_labels = ['2M', '7M', '110M']

# Empty scatter artists usable as legend handles for the model sizes.
marker_handles = [
    plt.scatter([], [], s=size, color='black') for size in marker_sizes
]

combined_handles = marker_handles
combined_labels = marker_labels
plt.tight_layout()
plt.show()