# LSM XM Metrics Reader

This notebook provides a generic interface for fetching an XManager (XM) experiment and its associated metrics, and converting them into tabular form.

In [None]:
# @title Imports

from google3.learning.deepmind.xmanager2.client import xmanager_api
import matplotlib as mpl
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import collections
import numpy as np


In [None]:
# @title Plot Formatting

MEDIUM_SIZE = 18

# A single update call: every text element shares one font size, and the axes
# get thicker gray spines on a white background.
mpl.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': MEDIUM_SIZE,  # default text size
    'axes.titlesize': MEDIUM_SIZE,  # axes title
    'axes.labelsize': MEDIUM_SIZE,  # x / y labels
    'xtick.labelsize': MEDIUM_SIZE,  # tick labels
    'ytick.labelsize': MEDIUM_SIZE,
    'legend.fontsize': MEDIUM_SIZE,
    'figure.titlesize': MEDIUM_SIZE,  # figure suptitle
    'axes.linewidth': 2,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
})

elegant_palette = sns.color_palette('muted')

In [None]:
# @title Set Up

# XManager API client shared by the helper functions in this notebook.
xm_client = xmanager_api.XManagerApi(
    xm_deployment_env='alphabet',
)

In [None]:
# @title Metrics

# Field names that are read from XM for every work unit.

# Metric series labels.
metric_names = [
    # Generic metrics.
    'train_loss',
    'valid_loss',
] + [
    # Classification metrics: all train_* names first, then all valid_* names.
    f'{split}_{metric}'
    for split in ('train', 'valid')
    for metric in ('accuracy', 'balanced_accuracy', 'f1_score', 'mAP', 'AP')
]
# Generative task metrics.
# TODO(girishvn): Add generative metrics here (e.g. random impute / impute / forecast / etc. MAE / MSE / etc.)

# Metadata series labels.
meta_data_name = [
    'num_trainable_params',
    'core_hours',
    'examples_seen',
    'gflops',
]

# Everything to read per work unit: metadata first, then metrics.
data_field_names = [*meta_data_name, *metric_names]

In [None]:
# @title Helpers

def read_xm_series_metrics(
    example_xid, metric_name, unit_id, lowest=False
):
  """Read one measurement series from an XM work unit.

  Args:
    example_xid: Experiment ID (XID) to read from.
    metric_name: Label of the measurement series to look up.
    unit_id: Work unit ID (WID) within the experiment.
    lowest: If True, return only the smallest value and the step at which it
      was recorded, instead of the full series.

  Returns:
    A `(values, steps, label)` tuple. With `lowest=False`, `values` and `steps`
    are lists in the order XM returns the measurements. With `lowest=True`,
    they are the minimum `objective_value` and its step. If no series has the
    given label (or, with `lowest=True`, the series has no measurements), the
    result is `(None, None, metric_name)`.
  """
  experiment = xm_client.get_experiment(example_xid)
  work_unit = experiment.get_work_unit(unit_id)

  # Only the first series whose label matches is used.
  series = next(
      (s for s in work_unit.list_measurement_series()
       if s.label == metric_name),
      None,
  )
  if series is None:
    return None, None, metric_name

  values = []
  steps = []
  for measurement in series.measurements:
    values.append(measurement.objective_value)
    steps.append(measurement.step)

  if not lowest:
    return values, steps, series.label

  # np.argmin raises on an empty sequence. Treat an empty series like a
  # missing one so callers can keep relying on the `is None` check.
  if not values:
    return None, None, metric_name
  min_idx = int(np.argmin(values))
  return values[min_idx], steps[min_idx], series.label


def add_min_columns(df):
  """Add min / argmin / final summary columns for each `*error*` metric column.

  Every column whose name contains 'error' gets three new columns. The
  matching `steps_*` columns are skipped because they hold step numbers, not
  metric values. The new columns are:
    - `min_<col>`: smallest value in the cell's sequence.
    - `min_idx_<col>`: position of that smallest value (np.argmin).
    - `final_<col>`: last value in the sequence.
  Cells that are not a non-empty list/tuple/ndarray (e.g. NaN for a series a
  work unit never logged) produce None instead of raising.

  Args:
    df: DataFrame whose metric cells hold sequences of values.

  Returns:
    The same DataFrame, modified in place and returned for chaining.
  """

  def _if_nonempty(fn):
    # Apply fn only to non-empty sequences; anything else maps to None.
    def wrapped(values):
      if isinstance(values, (list, tuple, np.ndarray)) and len(values) > 0:
        return fn(values)
      return None

    return wrapped

  # Snapshot the matching columns before adding new ones.
  error_cols = [
      col
      for col in df.columns
      if isinstance(col, str)
      and 'error' in col
      and not col.startswith('steps_')
  ]

  for col in error_cols:
    df['min_' + col] = df[col].apply(_if_nonempty(min))
    df['min_idx_' + col] = df[col].apply(_if_nonempty(np.argmin))
    df['final_' + col] = df[col].apply(_if_nonempty(lambda v: v[-1]))

  return df


def add_better_col_names(df):
  """Add human-readable columns derived from raw config columns.

  If `config.model.patches.size` is present, a `patch_size` column is added.
  It joins the first two elements of each cell as '<first>x<second>'.

  Args:
    df: DataFrame to augment (modified in place).

  Returns:
    The same DataFrame, for chaining.
  """
  source_col = 'config.model.patches.size'
  if source_col in df.columns:
    df['patch_size'] = df[source_col].apply(
        lambda size: f'{size[0]}x{size[1]}'
    )
  return df


def get_metrics_df(xm_dict, data_field_names):
  """Collect parameters and metric series for every work unit into a DataFrame.

  Args:
    xm_dict: Mapping from experiment ID (XID) to a dict of constant values
      (e.g. {'method': 'LSM_v2'}). These become columns on every row of that
      experiment.
    data_field_names: Labels of the XM measurement series to read for each
      work unit.

  Returns:
    A DataFrame with one row per work unit. Its columns are:
      - `xid`, `wid`, the per-XID constants, and the work unit parameters.
      - For every series found: `<label>` (list of values) and
        `steps_<label>` (list of steps).
      - `steps`: the sorted union of all steps of that work unit.
    A series or parameter missing for some work units is NaN in those rows.
    Lists of length one are collapsed to their single element.
  """
  rows = []
  for xid, xid_added_constants in xm_dict.items():
    experiment = xm_client.get_experiment(xid)
    num_of_units = experiment.get_num_work_units()

    # Work unit IDs are enumerated as 1..num_of_units.
    for wid in range(1, num_of_units + 1):
      # Build one dict per row. Constants are stored per row, so each XID
      # keeps its own values; a shared dict.update would let the last XID
      # win. Keys a work unit lacks become NaN instead of producing ragged
      # column lists that pd.DataFrame rejects.
      row = {'xid': xid, 'wid': wid}
      row.update(xid_added_constants)

      work_unit = experiment.get_work_unit(wid)

      # Work unit parameters (often the swept hyperparameters).
      params = work_unit.parameters
      for param_name in params.keys():
        row[param_name] = params[param_name]

      # List the series once per work unit, rather than once per metric.
      # Keep the first series seen for each label.
      series_by_label = {}
      for series in work_unit.list_measurement_series():
        series_by_label.setdefault(series.label, series)

      all_steps = set()
      for metric in data_field_names:
        series = series_by_label.get(metric)
        if series is None:
          continue
        measurements = list(series.measurements)
        metric_steps = [m.step for m in measurements]
        row[metric] = [m.objective_value for m in measurements]
        row[f'steps_{metric}'] = metric_steps
        all_steps.update(metric_steps)

      row['steps'] = sorted(all_steps)
      rows.append(row)

  # Generate dataframe
  df = pd.DataFrame(rows)
  df = add_min_columns(df)
  df = add_better_col_names(df)

  # If a column list is of length 1, simply use the scalar value.
  df = df.map(lambda x: x[0] if isinstance(x, list) and len(x) == 1 else x)

  return df


def get_step_df(df, step, data_field_names):
  """Reduce a metrics DataFrame to the values logged at a single step.

  The target step is resolved against the `steps` entry of the first row:
  `step == -1` selects its last step, `step == 0` its first step, and any
  other value must appear in that list. Each row's metric list is then
  indexed using that row's own `steps_<metric>` list. This keeps values
  aligned even when work units logged at different steps.

  Args:
    df: Metrics DataFrame with `<metric>`, `steps_<metric>` and `steps`
      columns, such as the output of `get_metrics_df`.
    step: Step to extract, or 0 / -1 for the first / last step.
    data_field_names: Metric column names to reduce.

  Returns:
    A copy of `df` in which:
      - Each metric column holds the value at `step`, or None where a row did
        not log that step.
      - Metric cells whose steps are not a list (e.g. a metric logged once and
        collapsed to a scalar) are kept unchanged.
      - The `steps` column and all `steps_*` columns are removed.

  Raises:
    ValueError: If `df` is empty, or `step` is not among the first row's steps.
  """
  if df.empty:
    raise ValueError('Cannot select a step from an empty DataFrame.')

  step_df = df.copy()

  # Positional access, so this also works after the caller filtered rows.
  ref_steps = step_df['steps'].iloc[0]
  if not isinstance(ref_steps, list):
    # A work unit that logged a single step has it collapsed to a scalar.
    ref_steps = [ref_steps]

  # Edge cases
  if step == -1:
    step = ref_steps[-1]
  elif step == 0:
    step = ref_steps[0]

  # Ensure step exist
  if step not in ref_steps:
    raise ValueError(f'Step {step} not in {ref_steps}')

  def _value_at_step(values, value_steps):
    # Not a per-step series (scalar or missing): leave the cell untouched.
    if not isinstance(value_steps, list):
      return values
    if not isinstance(values, list) or step not in value_steps:
      return None
    return values[value_steps.index(step)]

  # dict.fromkeys de-duplicates while preserving order, so no column is
  # reduced twice.
  for col in dict.fromkeys(data_field_names):
    steps_col = f'steps_{col}'
    if col not in step_df.columns or steps_col not in step_df.columns:
      continue
    step_df[col] = [
        _value_at_step(values, value_steps)
        for values, value_steps in zip(step_df[col], step_df[steps_col])
    ]

  # Drop only the step bookkeeping columns. Other columns that merely contain
  # "steps" (e.g. a `config.total_steps` hyperparameter) are kept.
  bookkeeping_cols = [
      col
      for col in step_df.columns
      if col == 'steps' or (isinstance(col, str) and col.startswith('steps_'))
  ]
  return step_df.drop(columns=bookkeeping_cols)


# Metric Retrieval / Formatting

In [None]:
# @title Get All Metrics

# Experiment ID (XID) -> constant columns attached to every row of that XID.
XID_WID_DICT = {157571037: {'method': 'LSM_v2'}}

metrics_df = get_metrics_df(XID_WID_DICT, data_field_names)
metrics_df.head()

In [None]:
# @title Get Single Step Metrics

# Step to extract. Use 0 for the first logged step and -1 for the last one;
# any other value must be an exact logged step.
step = 26

df_step = get_step_df(metrics_df, step, data_field_names)
df_step