In [None]:
from google3.learning.deepmind.xmanager2.client import xmanager_api
import matplotlib as mpl
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import collections

xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

# Font sizes. Only MEDIUM_SIZE is applied to the global style below;
# SMALL_SIZE and BIGGER_SIZE are available for ad-hoc use in later cells.
SMALL_SIZE = 16
MEDIUM_SIZE = 15
BIGGER_SIZE = 20

# Global matplotlib style, set once. Every key is assigned exactly one value so
# the effective style is obvious from reading this cell.
mpl.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': MEDIUM_SIZE,  # default text size
    'axes.titlesize': MEDIUM_SIZE,  # axes title
    'axes.labelsize': MEDIUM_SIZE,  # x and y labels
    'xtick.labelsize': MEDIUM_SIZE,  # x tick labels
    'ytick.labelsize': MEDIUM_SIZE,  # y tick labels
    'legend.fontsize': MEDIUM_SIZE,  # legend text
    'figure.titlesize': MEDIUM_SIZE,  # figure suptitle
    'axes.linewidth': 2,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
})

elegant_palette = sns.color_palette('muted')

In [None]:
def read_xm_metrics(example_xid, metric_name, unit_id, lowest=True, client=None):
  """Reads one measurement series of an XManager work unit.

  Args:
    example_xid: XManager experiment ID.
    metric_name: Label of the measurement series to read.
    unit_id: Work unit ID within the experiment (passed to `get_work_unit`).
    lowest: If True, return the minimum objective value of the series;
      otherwise return every objective value, in series order.
    client: Object exposing `get_experiment`. Defaults to the module-level
      `xm_client`; pass a stub to use the function without XManager access.

  Returns:
    The minimum objective value (lowest=True) or the list of objective values
    (lowest=False). Returns None when no series is labelled `metric_name`, and
    also when lowest=True but the matching series has no measurements.
  """
  if client is None:
    client = xm_client
  work_unit = client.get_experiment(example_xid).get_work_unit(unit_id)
  # Only the first series whose label matches is read.
  for series in work_unit.list_measurement_series():
    if series.label != metric_name:
      continue
    values = [measurement.objective_value for measurement in series.measurements]
    if lowest:
      # min() of an empty list raises ValueError; report "no data" as None
      # instead so one empty series does not abort a whole loading loop.
      return min(values, default=None)
    return values
  # No series carries this label.
  return None

In [None]:
# Full per-measurement curves (lowest=False) for the training-curve plots below.
metric_names = [
    'train_mean_absolute_error_all',
    'valid_mean_absolute_error_all',
    'train_mean_absolute_error_masked',
    'valid_mean_absolute_error_masked',
    'learning_rate',
    'l2_grads',
    'examples_seen'
]

xm_id = 117045959
experiment = xm_client.get_experiment(xm_id)
num_of_units = experiment.get_num_work_units()

# One record (dict) per work unit. Building the frame from records tolerates
# work units whose parameter sets differ (missing entries become NaN), whereas
# parallel per-column lists would end up with different lengths and fail.
records = []
# Work unit IDs are taken to be 1..num_of_units, as in the original loop.
for unit_id in range(1, num_of_units + 1):
  work_unit = experiment.get_work_unit(unit_id)
  # 'unit_id' stores the actual XManager work unit ID (the one passed to
  # get_work_unit / read_xm_metrics), not a 0-based loop index.
  record = {'unit_id': unit_id, 'xm_id': xm_id}
  record.update(work_unit.parameters)
  for metric in metric_names:
    record[metric] = read_xm_metrics(xm_id, metric, unit_id, lowest=False)
  records.append(record)

df = pd.DataFrame(records)
df.info(verbose=True)

In [None]:
# Panel grid: one column per base learning rate, one row per metric, and one
# curve per weight decay inside each panel.
lr_col = 'config.schedule.all.lr_configs.base_learning_rate'
wd_col = 'config.optimizer.weight_decay'

# Get unique learning rates (panel columns, in order of first appearance).
unique_lr = df[lr_col].unique()

# One Set2 color per weight decay. Request exactly as many colors as there are
# weight decays: zipping against the default-sized palette would silently drop
# the extra values, and the plot loop would then raise KeyError on them.
# seaborn cycles the palette when more colors are requested than it contains.
weight_decays = df[wd_col].unique()
palette = sns.color_palette("Set2", n_colors=max(len(weight_decays), 1))
color_mapping = dict(zip(weight_decays, palette))

# Metrics to plot (panel rows).
metrics = [
    'train_mean_absolute_error_all',
    'valid_mean_absolute_error_all',
    'train_mean_absolute_error_masked',
    'valid_mean_absolute_error_masked'
]

# Create subplots; squeeze=False keeps `axes` 2-D even with a single learning rate.
fig, axes = plt.subplots(
    len(metrics),
    len(unique_lr),
    figsize=(6 * len(unique_lr), 6 * len(metrics)),
    dpi=600,
    squeeze=False,
)

for row_idx, metric_name in enumerate(metrics):
    for i, lr in enumerate(unique_lr):
        ax = axes[row_idx, i]
        subset = df[df[lr_col] == lr]
        for _, row in subset.iterrows():
            y = row[metric_name]
            # read_xm_metrics returns None instead of a list when a run never
            # logged this metric; skip it rather than failing on len(None).
            if not isinstance(y, list):
                continue
            weight_decay = row[wd_col]
            # NOTE(review): x is the measurement index; the 'Steps (k)' label
            # assumes one measurement per 1k steps. Confirm against the config.
            ax.plot(range(len(y)), y, label=f"{weight_decay}", color=color_mapping[weight_decay])
        ax.set_title(f'LR: {lr}')
        # Only the bottom row carries an x label and only the left column a y label.
        ax.set_xlabel('Steps (k)' if row_idx == len(metrics) - 1 else '')
        ax.set_ylabel(metric_name if i == 0 else '')
        ax.legend(title='Weight Decay', frameon=False)

fig.tight_layout()
fig.suptitle('Ablation Study of Learning Rate and Weight Decay Across Metrics + 1M Dataset + Large ViT', y=1.02)
plt.show()


In [None]:
# Panel grid: one column per weight decay, one row per metric, and one curve
# per base learning rate inside each panel.
lr_col = 'config.schedule.all.lr_configs.base_learning_rate'
wd_col = 'config.optimizer.weight_decay'

# Unique weight decays (panel columns, in order of first appearance).
unique_wd = df[wd_col].unique()

# One Set2 color per learning rate. Request exactly as many colors as there are
# learning rates: zipping against the default-sized palette would silently drop
# the extra values, and the plot loop would then raise KeyError on them.
# seaborn cycles the palette when more colors are requested than it contains.
learning_rates = df[lr_col].unique()
palette = sns.color_palette("Set2", n_colors=max(len(learning_rates), 1))
color_mapping = dict(zip(learning_rates, palette))

# Metrics to plot (panel rows).
metrics = [
    'train_mean_absolute_error_all',
    'valid_mean_absolute_error_all',
    'train_mean_absolute_error_masked',
    'valid_mean_absolute_error_masked'
]

# Create subplots; squeeze=False keeps `axes` 2-D even with a single weight decay.
fig, axes = plt.subplots(
    len(metrics),
    len(unique_wd),
    figsize=(6 * len(unique_wd), 6 * len(metrics)),
    dpi=600,
    squeeze=False,
)

for row_idx, metric_name in enumerate(metrics):
    for i, wd in enumerate(unique_wd):
        ax = axes[row_idx, i]
        subset = df[df[wd_col] == wd]
        for _, row in subset.iterrows():
            y = row[metric_name]
            # read_xm_metrics returns None instead of a list when a run never
            # logged this metric; skip it rather than failing on len(None).
            if not isinstance(y, list):
                continue
            learning_rate = row[lr_col]
            # NOTE(review): x is the measurement index; the 'Steps (k)' label
            # assumes one measurement per 1k steps. Confirm against the config.
            ax.plot(range(len(y)), y, label=f"{learning_rate}", color=color_mapping[learning_rate])
        ax.set_title(f'WD: {wd}')
        # Only the bottom row carries an x label and only the left column a y label.
        ax.set_xlabel('Steps (k)' if row_idx == len(metrics) - 1 else '')
        ax.set_ylabel(metric_name if i == 0 else '')
        ax.legend(title='Learning Rate', frameon=False)

fig.tight_layout()
fig.suptitle('Ablation Study of Learning Rate and Weight Decay Across Metrics + 1M Dataset + Large ViT', y=1.02)
plt.show()


## Numerical Analysis

In [None]:
# Best (minimum over all measurements, lowest=True) validation value per run,
# used by the ranking and lookup cells below.
metric_names = [
    'valid_mean_absolute_error_all',
    'valid_mean_absolute_error_masked',
    'valid_mean_squared_error_all',
    'valid_mean_squared_error_masked',
]


xm_id = 117045959
experiment = xm_client.get_experiment(xm_id)
num_of_units = experiment.get_num_work_units()

# One record (dict) per work unit. Building the frame from records tolerates
# work units whose parameter sets differ (missing entries become NaN), whereas
# parallel per-column lists would end up with different lengths and fail.
records = []
# Work unit IDs are taken to be 1..num_of_units, as in the original loop.
for unit_id in range(1, num_of_units + 1):
  work_unit = experiment.get_work_unit(unit_id)
  # 'unit_id' stores the actual XManager work unit ID (the one passed to
  # get_work_unit / read_xm_metrics), not a 0-based loop index.
  record = {'unit_id': unit_id, 'xm_id': xm_id}
  record.update(work_unit.parameters)
  for metric in metric_names:
    # None when the run never logged the metric or logged no points; pandas
    # stores that as NaN, which rank() keeps as NaN.
    record[metric] = read_xm_metrics(xm_id, metric, unit_id, lowest=True)
  records.append(record)

df = pd.DataFrame(records)
df.info(verbose=True)

In [None]:
#@title Overall Ranking

# Validation metrics to rank. Lower error is better, so ranks are ascending;
# method='min' gives tied configurations the same (smallest) rank.
ranked_metrics = [
    'valid_mean_absolute_error_all',
    'valid_mean_absolute_error_masked',
    'valid_mean_squared_error_all',
    'valid_mean_squared_error_masked',
]
rank_columns = ['rank_' + metric for metric in ranked_metrics]

# Add one rank_<metric> column to df per validation metric.
for metric, rank_column in zip(ranked_metrics, rank_columns):
  df[rank_column] = df[metric].rank(method='min')

# Hyperparameters next to their ranks, ordered lexicographically by the rank
# columns (first metric first, later metrics break ties).
hyperparameter_columns = [
    'config.optimizer.weight_decay',
    'config.schedule.all.lr_configs.base_learning_rate',
]
rankings_df = df[hyperparameter_columns + rank_columns].sort_values(by=rank_columns)
rankings_df

In [None]:
fixed_weight_decay = 1e-4
# Number of best runs (lowest best-over-training validation MAE) to inspect.
top_k = 3

# Filter the DataFrame for the fixed weight decay.
# NOTE(review): exact float equality; assumes the stored parameter value is
# the same float as the literal above. Confirm if the filter comes back empty.
filtered_df = df[df['config.optimizer.weight_decay'] == fixed_weight_decay]
# Sort the filtered_df by the validation loss (ascending: best run first).
sorted_df = filtered_df.sort_values(by='valid_mean_absolute_error_all')

# Count learning rates within the top_k entries only. Counting over all of
# sorted_df would ignore the ranking entirely.
top_lr_counts = (
    sorted_df.head(top_k)['config.schedule.all.lr_configs.base_learning_rate']
    .value_counts()
)

if top_lr_counts.empty:
  most_common_learning_rate = None
  most_common_learning_rate_count = 0
  print(f'No runs found with weight decay {fixed_weight_decay}.')
else:
  # NOTE(review): with ties in the counts, idxmax returns the first label in
  # value_counts order.
  most_common_learning_rate = top_lr_counts.idxmax()
  most_common_learning_rate_count = top_lr_counts.max()
  # Display the most common learning rate and its count.
  print(
      f'The most common learning rate in the top {top_k} entries is'
      f' {most_common_learning_rate} and it appears'
      f' {most_common_learning_rate_count} times.'
  )
sorted_df

In [None]:
# Fixed base learning rate to inspect.
fixed_learning_rate = 5e-3
# Number of best runs (lowest best-over-training validation MAE) to inspect.
top_k = 3

# Filter the DataFrame for the fixed learning rate.
# NOTE(review): exact float equality; assumes the stored parameter value is
# the same float as the literal above. Confirm if the filter comes back empty.
filtered_df = df[
    df['config.schedule.all.lr_configs.base_learning_rate']
    == fixed_learning_rate
]

# Sort the filtered DataFrame by the validation loss (ascending: best run first).
sorted_df = filtered_df.sort_values(by='valid_mean_absolute_error_all')

# Count weight decays within the top_k entries only. Counting over all of
# sorted_df would ignore the ranking entirely.
top_wd_counts = (
    sorted_df.head(top_k)['config.optimizer.weight_decay'].value_counts()
)

if top_wd_counts.empty:
  most_common_weight_decay = None
  most_common_weight_decay_count = 0
  print(f'No runs found with learning rate {fixed_learning_rate}.')
else:
  # NOTE(review): with ties in the counts, idxmax returns the first label in
  # value_counts order.
  most_common_weight_decay = top_wd_counts.idxmax()
  most_common_weight_decay_count = top_wd_counts.max()
  # Display the most common weight decay and its count.
  print(
      f'The most common weight decay in the top {top_k} entries is'
      f' {most_common_weight_decay} and it appears'
      f' {most_common_weight_decay_count} times.'
  )
sorted_df