In [None]:
import os
import re
import tempfile
import warnings
import collections
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

from google3.learning.deepmind.xmanager2.client import xmanager_api
from google3.pyglib import gfile
from google3.pyglib.function_utils import memoize
from matplotlib import font_manager

import matplotlib.patches as mpatches
import matplotlib.lines as mlines
from matplotlib.lines import Line2D
from matplotlib.ticker import FixedLocator  # Import for the fix
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import LogLocator
from matplotlib.ticker import FuncFormatter

In [None]:
#@title Google Sans Import

# Import Google font family
_GOOGLE_SANS_PATH = (
    'google3/third_party/googlefonts/api/googlerestricted/googlesans/'
)

@memoize.Memoize()
def import_google3_fonts(font_path: str) -> None:
  """Import fonts stored in google3 into Matplotlib for use in Colab.

  The font file(s) are copied into a fresh local temp directory and each copy
  is registered with Matplotlib's default ``font_manager.fontManager``. The temp
  directory is intentionally not deleted: Matplotlib keeps the registered file
  paths and reads them again when rendering text. The ``memoize.Memoize()``
  decorator presumably makes repeated calls with the same path a no-op.

  Args:
    font_path: google3 path to either a directory that contains .ttf fonts or to
      a specific .ttf font file.
  """
  # Private destination directory for both cases (created with mkdtemp, so no
  # other process can race us for the file names inside it).
  tmp_dir = tempfile.mkdtemp()
  if gfile.IsDirectory(font_path):
    # Copy the whole directory, then collect every font file found under it.
    gfile.RecursivelyCopyDir(font_path, tmp_dir, overwrite=True)
    font_files = font_manager.findSystemFonts(fontpaths=tmp_dir)
  else:
    # Assume the path points to a single font file. Copy it into the private
    # temp dir under its own basename (keeping its extension) instead of
    # reusing the name of an already-deleted NamedTemporaryFile.
    local_name = os.path.basename(font_path) or 'font.ttf'
    local_file = os.path.join(tmp_dir, local_name)
    gfile.Copy(font_path, local_file)
    font_files = [local_file]

  # Add fonts to default font manager.
  for font_file in font_files:
    font_manager.fontManager.addfont(font_file)


def import_default_google_fonts(
    depot_root: str = '/google_src/head/depot',
) -> None:
  """Register the default Google font family (Google Sans) with Matplotlib.

  Only the Google Sans directory (``_GOOGLE_SANS_PATH``) is imported.

  Args:
    depot_root: Local mount point of the google3 depot that the google3-relative
      ``_GOOGLE_SANS_PATH`` is resolved against. Defaults to
      '/google_src/head/depot'.
  """
  # Prepend the depot root to the google3-relative font path.
  import_google3_fonts(os.path.join(depot_root, _GOOGLE_SANS_PATH))


# Import and register Google fonts with Matplotlib so we can use them.
# This must run before any figure is drawn with font.family = 'Google Sans'
# (set in the plot-settings cell); otherwise Matplotlib falls back to a default font.
import_default_google_fonts()

In [None]:
#@title Set up Plot Settings

pd.set_option('display.max_rows', None)  # Show all rows
pd.set_option('display.max_columns', None)  # Show all columns
# Suppress specific warning
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

# XManager API client for the 'alphabet' deployment environment.
# NOTE(review): not referenced in any cell shown here; presumably meant for
# looking up the xm/... experiments cited next to the results. Remove if unused.
xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')
# Global figure styling. pyplot's rcParams is the same object as
# matplotlib.rcParams, so every setting is applied once, in one place.
MEDIUM_SIZE = 12
mpl.rcParams.update({
    'font.size': MEDIUM_SIZE,            # default text size
    'axes.titlesize': MEDIUM_SIZE,       # axes title
    'axes.labelsize': MEDIUM_SIZE,       # x and y axis labels
    'xtick.labelsize': MEDIUM_SIZE,      # x tick labels
    'ytick.labelsize': MEDIUM_SIZE,      # y tick labels
    'legend.fontsize': MEDIUM_SIZE - 5,  # legend text
    'figure.titlesize': MEDIUM_SIZE,     # figure suptitle
    'axes.linewidth': 1,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
    # Requires the Google Sans fonts registered by the font-import cell.
    'font.family': 'Google Sans',
})

elegant_palette = sns.color_palette('muted')

# Tick-label formatter for log-scaled axes (use via ticker.FuncFormatter).
# NOTE: a later plotting cell redefines this name with a 2-decimal version.
def log_float_formatter(x, pos):
    """Render tick value ``x`` with one decimal place; ``pos`` is ignored."""
    return f'{x:.1f}'


In [None]:
# Validation loss vs. number of subjects for the two data regimes.
# NOTE(review): the first shuffled loss here (0.66) differs from the 0.63 used in
# the "Data vs. Subject Scaling" cell for the same run (xm/124248847) — confirm.
results = {
    'shuffled_subjects': [1000, 10000, 100000],
    'unshuffled_subjects': [100, 1000, 10000],
    'shuffled_validation_loss': [0.66, 0.38, 0.22], # xm/124248847
    'unshuffled_validation_loss': [0.62, 0.37, 0.22], #xm/126315392
}
df = pd.DataFrame(results)


# NOTE: this redefines the 1-decimal log_float_formatter from the settings cell.
# Once this cell has run, later cells that use log_float_formatter get this
# 2-decimal version.
def log_float_formatter(y, pos):
    """Render tick value ``y`` with two decimal places; ``pos`` is ignored."""
    return f'{y:.2f}'

# Subject counts are spaced by exactly 10x, so np.gradient's unit index spacing
# is a step of 1 in log10(subjects): these values are d(loss)/d(log10 N).
print('Gradient of Shuffled Loss: ', np.gradient(df['shuffled_validation_loss'].to_numpy()))
print('Gradient of Unshuffled Loss: ', np.gradient(df['unshuffled_validation_loss'].to_numpy()))

# Create the figure and a single log-log axes.
fig, axes = plt.subplots(1, 1, figsize=(3, 3), sharex=True, dpi=100)
axes.set(xscale="log", yscale="log")

# Custom markers for each data point (same shape order for both series).
shuffled_markers = ['s', 'o', 'P']
unshuffled_markers = ['s', 'o', 'P']

# (subjects, losses, color, per-point markers) for each series, in draw order.
series = [
    (results['shuffled_subjects'], results['shuffled_validation_loss'], '#3182BD', shuffled_markers),
    (results['unshuffled_subjects'], results['unshuffled_validation_loss'], '#9ECAE1', unshuffled_markers),
]

# Dashed connecting lines first (markersize=0 hides their markers) ...
for subjects, losses, color, _point_markers in series:
    axes.plot(subjects, losses, color=color, linestyle='--', marker='o', markersize=0)

# ... then one visible marker per data point, so each point gets its own shape.
for subjects, losses, color, point_markers in series:
    for marker, x, y in zip(point_markers, subjects, losses):
        axes.plot(x, y, marker=marker, color=color, markersize=7)

# Log-spaced ticks at every integer multiple of each decade; 2-decimal y labels.
axes.yaxis.set_major_locator(LogLocator(base=10.0, subs=np.arange(1, 10), numticks=10))
axes.yaxis.set_major_formatter(FuncFormatter(log_float_formatter))
axes.xaxis.set_major_locator(LogLocator(base=10.0, subs=np.arange(1, 10), numticks=10))
axes.set_ylabel('Test Loss')
axes.set_xlabel('Number of Subjects')

# Line-only legend entries (no markers) identifying the two series.
legend_elements = [
    Line2D([0], [0], color='#3182BD', marker='', markersize=7, label='5 H / subj.'),
    Line2D([0], [0], color='#9ECAE1', marker='', markersize=7, label='50 H / subj.')
]
axes.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.02, 1), frameon=False, fontsize=10)

fig.savefig("/tmp/scaling_subjects_hours.pdf", bbox_inches='tight', format="pdf")
plt.show()
%download_file /tmp/scaling_subjects_hours.pdf

In [None]:
# Display the raw shuffled validation-loss values as a NumPy array.
df.loc[:, 'shuffled_validation_loss'].to_numpy()

In [None]:
#@title Data vs. Subject Scaling

results = {
    'shuffled_subjects': [1000, 10000, 100000],
    'unshuffled_subjects': [100, 1000, 10000],
    # NOTE(review): 0.63 here vs. 0.66 in the earlier scaling cell for the same
    # run (xm/124248847) — confirm which value is correct.
    'shuffled_validation_loss': [0.63, 0.38, 0.22], # xm/124248847
    'unshuffled_validation_loss': [0.62, 0.37, 0.22], #xm/126315392
}
df = pd.DataFrame(results)

# (x column, y column, color, legend label) for each data regime, in draw order.
series = [
    ('shuffled_subjects', 'shuffled_validation_loss', '#ea4335', '5+ Hours Per Subject'),
    ('unshuffled_subjects', 'unshuffled_validation_loss', '#4285f4', '50 Hours Per Subject'),
]

# Create the figure and a single log-log axes.
fig, axes = plt.subplots(1, 1, figsize=(4.5, 4), sharex=True, dpi=100)
axes.set(xscale="log", yscale="log")

# Marker shape per data-point index; the marker legend below maps shapes to hours.
markers = ['o', 's', 'D', 'P']

# Individual markers first: one single-point lineplot per data point, so each
# point can get its own shape. These are excluded from the legend.
for x_key, y_key, color, _label in series:
    for i, (n_subjects, loss) in enumerate(zip(results[x_key], results[y_key])):
        sns.lineplot(
            x=[n_subjects],
            y=[loss],
            color=color,
            marker=markers[i % len(markers)],
            markersize=7,
            linestyle='--',
            label='_nolegend_',  # Exclude from legend
            ax=axes,
        )

# Then the dashed lines connecting each series' points (no markers).
for x_key, y_key, color, label in series:
    sns.lineplot(
        data=results,
        x=x_key,
        y=y_key,
        color=color,
        marker=None,
        linestyle='--',
        ax=axes,
        label=label,
    )

# Log-spaced y ticks at every integer multiple of each decade, with 2-decimal
# labels. The formatter is defined here instead of using the global
# log_float_formatter, which this notebook defines twice (1 and 2 decimals), so
# the labels would otherwise depend on which cells ran before this one.
axes.yaxis.set_major_locator(LogLocator(base=10.0, subs=np.arange(1, 10), numticks=10))
axes.yaxis.set_major_formatter(FuncFormatter(lambda y, pos: f'{y:.2f}'))
# NOTE(review): the earlier scaling cell labels this axis 'Test Loss'.
axes.set_ylabel('Validation Loss')
axes.set_xlabel('Number of Subjects')

# Marker-shape legend (drawn in black), one entry per shape in `markers`.
# NOTE(review): each series has only 3 points (shapes o, s, D), so the
# 'P' / '3.8M' entry has no matching plotted point — confirm it should stay.
marker_labels = ['0.005M', '0.05M', '0.5M', '3.8M']
legend_handles_shapes = [
    Line2D([0], [0], marker=shape, color='black', label=text, markersize=7, linestyle='None')
    for shape, text in zip(markers, marker_labels)
]

# Line-style legend: the dashed red and blue series lines.
legend_handles_lines = [
    Line2D([0], [0], color=line_color, linestyle='--', label=line_label)
    for _x_key, _y_key, line_color, line_label in series
]

# Add the custom marker legend below the plot (replaces seaborn's axes legend).
legend_shapes = axes.legend(
    handles=legend_handles_shapes,
    loc='lower center',
    bbox_to_anchor=(0.5, -0.3),
    ncol=4,
    fontsize=10,
    frameon=False
)

# Add the line legend at the top right of the figure.
legend_lines = fig.legend(
    handles=legend_handles_lines,
    loc='upper right',
    bbox_to_anchor=(0.9, 0.95),  # Position it slightly inside the plot
    fontsize=MEDIUM_SIZE-4,
    frameon=False
)
# Figure-level caption; presumably describes the marker-legend labels (0.005M, ...).
fig.text(0.5, 0.03, "Number of Hours Million", ha='center', fontsize=MEDIUM_SIZE-2)

# Adjust the layout so the legends are properly displayed.
fig.tight_layout()

# bbox_inches='tight' keeps the out-of-axes legends inside the saved image.
fig.savefig("plot_with_legends.png", bbox_inches='tight')
# NOTE: same path as the earlier scaling cell's PDF, so this overwrites it.
fig.savefig("/tmp/scaling_subjects_hours.pdf", bbox_inches='tight', format="pdf")
plt.show()
%download_file /tmp/scaling_subjects_hours.pdf
