## LSM Patch Scaling and Feature Order Analysis
##### Colab Kernel (Brainframe GPU)
##### Dataset (Electrodes)

Grants command for Access on Demand (AoD):

https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-electrodes-deid-colab-jobs&reason=b%2F314799341

### About This Notebook:
**Patch Scaling:** This notebook analyzes patch scaling experiments, where we test how changing the patch size along the time axis and the feature axis, independently, affects the overall validation loss and the loss on downstream reconstruction tasks (e.g. imputation and forecasting).

**Masking Ratio Sweep:** This notebook analyzes the masking ratio and its effect on downstream tasks (e.g. imputation and forecasting). We sweep over random masking (30% - 90%) and also over structured masking (bar masking along the feature axis and along the time axis) (30% - 90%).

**Randomized Feature Order:** Further, this notebook explores how ordered features (the default) compare against randomized features, for the same evaluation tasks as above and for the patch scaling. To specify: randomized features are a shuffling of the feature rows. The same shuffled order is used for ALL experiments. This feature order is:

`[10, 14, 22, 21, 8, 11, 5, 24, 20, 4, 17, 0, 15, 9, 7, 3, 18, 23, 2, 13, 25, 1, 6, 16, 19, 12]`

and equivalently:

`['hrvRRMean', 'hrvRMSSD', 'zeroCrossingStd', 'logEnergyRatio', 'hrvRR20thPercentile', 'hrvShannonEntropyRR', 'skinTempValue', 'axisMean', 'grok_covariance', 'sclSlope', 'jerkAuto', 'hrvPercentGood', 'hrvSDNN', 'hrvRRMedian', 'hrvRR80thPercentile', 'sclValue', 'stepCount', 'zeroCrossingAvg', 'altimStdNorm', 'hrvPNN30', 'grok_kurtosis', 'onWrist', 'hr', 'sleepCoefficient', 'logEnergy', 'hrvShannonEntropyRRDiffs']`

This notebook draws on results from the following xmanager experiments:
1. [Patch scaling (ordered features)](https://xmanager.corp.google.com/experiments/120664779)
2. [Patch scaling (randomized features)](https://xmanager.corp.google.com/experiments/120685218)

In [None]:
# @title Imports

from google3.learning.deepmind.xmanager2.client import xmanager_api
import matplotlib as mpl
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import collections
import numpy as np


In [None]:
# @title Plot Formatting

# Single font size applied to every text element of the figures.
MEDIUM_SIZE = 18

# All global matplotlib style settings in one place. `plt.rcParams` and
# `mpl.rcParams` refer to the same settings object, so one update is enough.
mpl.rcParams.update({
    # Fonts.
    'font.family': 'DejaVu Sans',
    'font.size': MEDIUM_SIZE,  # default text size
    # Axes text and frame.
    'axes.titlesize': MEDIUM_SIZE,
    'axes.labelsize': MEDIUM_SIZE,
    'axes.linewidth': 2,
    'axes.edgecolor': '#777777',
    'axes.facecolor': '#FFFFFF',
    # Ticks, legend and figure title.
    'xtick.labelsize': MEDIUM_SIZE,
    'ytick.labelsize': MEDIUM_SIZE,
    'legend.fontsize': MEDIUM_SIZE,
    'figure.titlesize': MEDIUM_SIZE,
})

elegant_palette = sns.color_palette('muted')

In [None]:
# @title Helpers

def read_xm_metrics(example_xid, metric_name, unit_id, lowest=True):
  """Reads one measurement series of an XManager work unit.

  Uses the module-level `xm_client` to look up the experiment.

  Args:
    example_xid: Experiment id passed to `xm_client.get_experiment`.
    metric_name: Label of the measurement series to read.
    unit_id: Work unit id passed to `experiment.get_work_unit`.
    lowest: If True, return only the smallest objective value of the series;
      otherwise return every objective value, in series order.

  Returns:
    The minimum value (`lowest=True`) or the list of all values
    (`lowest=False`) of the first series labelled `metric_name`. Returns None
    when no series carries that label, or when `lowest=True` and the series
    has no measurements.
  """
  experiment = xm_client.get_experiment(example_xid)
  work_unit = experiment.get_work_unit(unit_id)

  for series in work_unit.list_measurement_series():
    if series.label != metric_name:
      continue
    values = [point.objective_value for point in series.measurements]
    if not lowest:
      return values
    # `default=None` avoids a ValueError on a series without measurements.
    return min(values, default=None)

  # No series with this label exists for the work unit.
  return None


def add_min_columns(df):
  """Adds min / argmin / final summary columns for every error metric.

  For each column whose name contains 'error' and whose cells hold a sequence
  of values, three columns are added (in this order):
    * 'min_<col>':     smallest value of the sequence,
    * 'min_idx_<col>': position of that smallest value,
    * 'final_<col>':   last value of the sequence.

  Cells that are None, NaN or an empty sequence yield NaN in all three
  columns instead of raising. Columns previously produced by this function
  are skipped, so calling it twice on the same frame is safe.

  Args:
    df: DataFrame to extend. It is modified in place.

  Returns:
    The same DataFrame object, with the summary columns added.
  """
  derived_prefixes = ('min_', 'min_idx_', 'final_')
  nan = float('nan')

  def is_missing(values):
    # Metric absent for this row (None / NaN placeholder) or logged nothing.
    if values is None:
      return True
    if isinstance(values, float):
      return values != values  # True only for NaN.
    return len(values) == 0

  # Snapshot the column names: the loop adds columns to `df` as it goes.
  for col in list(df.columns):
    if 'error' not in col or col.startswith(derived_prefixes):
      continue
    series = df[col]
    df['min_' + col] = series.apply(
        lambda v: nan if is_missing(v) else min(v)
    )
    df['min_idx_' + col] = series.apply(
        lambda v: nan if is_missing(v) else np.argmin(v)
    )
    df['final_' + col] = series.apply(
        lambda v: nan if is_missing(v) else v[-1]
    )

  return df


def add_better_col_names(df):
  """Adds a readable 'patch_size' column such as '10x5'.

  The label is built from the first two entries of each
  'config.model.patches.size' value. If the frame has no such column it is
  returned unchanged.

  Args:
    df: DataFrame to extend. It is modified in place.

  Returns:
    The same DataFrame object.
  """
  size_col = 'config.model.patches.size'
  if size_col in df.columns:
    df['patch_size'] = df[size_col].apply(
        lambda size: f'{size[0]}x{size[1]}'
    )
  return df


def get_metrics_df(xm_dict):
  """Collects parameters and metric series of XManager experiments.

  One row is produced per work unit of every experiment in `xm_dict`. The
  module-level `xm_client` is used for all lookups and the module-level
  `data_field_names` lists the measurement series to read.

  Args:
    xm_dict: Mapping from experiment id to a dict that provides at least the
      keys 'model_size' and 'feature_order'.

  Returns:
    A DataFrame with the columns 'wid', 'xid', 'Model Size', 'Feature Order',
    one column per work-unit parameter and one column per entry of
    `data_field_names` (list of objective values, or None when the work unit
    has no series with that label), extended by `add_min_columns` and
    `add_better_col_names`.
  """
  # Rows are collected as dicts rather than as parallel per-column lists, so
  # work units exposing different parameter sets stay aligned (missing values
  # become NaN) instead of producing columns of unequal length.
  rows = []
  for xid, values in xm_dict.items():
    experiment = xm_client.get_experiment(xid)

    for wid in range(1, experiment.get_num_work_units() + 1):
      work_unit = experiment.get_work_unit(wid)

      # List the series once per work unit instead of once per metric. The
      # first series seen for a label wins.
      series_by_label = {}
      for series in work_unit.list_measurement_series():
        series_by_label.setdefault(series.label, series)

      row = {
          'wid': wid,
          'xid': xid,
          'Model Size': values['model_size'],
          'Feature Order': values['feature_order'],
      }
      for param_name in work_unit.parameters.keys():
        row[param_name] = work_unit.parameters[param_name]
      for metric in data_field_names:
        series = series_by_label.get(metric)
        row[metric] = (
            None
            if series is None
            else [point.objective_value for point in series.measurements]
        )
      rows.append(row)

  df = pd.DataFrame(rows)
  df = add_min_columns(df)
  df = add_better_col_names(df)

  return df


In [None]:
# Feature names in their default ("ordered") row order.
feats = [
    'hrvPercentGood', 'onWrist', 'altimStdNorm', 'sclValue', 'sclSlope',
    'skinTempValue', 'hr', 'hrvRR80thPercentile', 'hrvRR20thPercentile',
    'hrvRRMedian', 'hrvRRMean', 'hrvShannonEntropyRR',
    'hrvShannonEntropyRRDiffs', 'hrvPNN30', 'hrvRMSSD', 'hrvSDNN',
    'sleepCoefficient', 'jerkAuto', 'stepCount', 'logEnergy',
    'grok_covariance', 'logEnergyRatio', 'zeroCrossingStd', 'zeroCrossingAvg',
    'axisMean', 'grok_kurtosis',
]

# Shuffled order: entry k is the index into `feats` of the feature placed at
# row k.
reorder_idx = [
    10, 14, 22, 21, 8, 11, 5, 24, 20, 4, 17, 0, 15,
    9, 7, 3, 18, 23, 2, 13, 25, 1, 6, 16, 19, 12,
]

# Feature names in shuffled order.
reorder_feats = list(map(feats.__getitem__, reorder_idx))

print(reorder_feats)


In [None]:
# @title Metrics and Field Names

# Measurement series read for every work unit. Each evaluation scope (plain
# validation, then each imputation / forecast ratio) contributes four series:
# {absolute, squared} error x {all, masked} patches.
metric_names = [
    f'{scope}valid_mean_{kind}_error_{region}'
    for scope in (
        '',
        'imputation_0.1_eval/',
        'imputation_0.2_eval/',
        'imputation_0.4_eval/',
        'forecast_0.1_eval/',
        'forecast_0.2_eval/',
        'forecast_0.4_eval/',
    )
    for kind in ('absolute', 'squared')
    for region in ('all', 'masked')
]

# Non-error series that are read alongside the metrics.
meta_data_name = ['num_trainable_params', 'core_hours', 'examples_seen', 'gflops']

# Everything `get_metrics_df` fetches, meta data first.
data_field_names = meta_data_name + metric_names

In [None]:
# Setup XM Client
# Module-level XManager API client; the helpers `read_xm_metrics` and
# `get_metrics_df` read it as the global `xm_client`, so this cell must run
# before any of them is called.
xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

In [None]:
# @title Getting Base Model Default

# Experiment holding the default Base model runs.
default_baase_model_xm_id_dict = {
    124248847: {
        'model_size': 'Base',
        'feature_order': 'Ordered',
        'loss_only_masked_patches': True,
        'meta_data': 'Default base model',
    },
}

default_df = get_metrics_df(default_baase_model_xm_id_dict)
# Filter default df for models that train on the FULL dataset (1321235).
# `.copy()` detaches the result from the unfiltered frame so the column
# assignments below do not write to a slice (SettingWithCopyWarning).
default_df = default_df[
    default_df['config.dataset_configs.train_num_samples'] == 1321235
].copy()
# Set the default patch size on every remaining row. One list per row keeps
# the assignment valid regardless of how many rows pass the filter.
default_df['config.model.patches.size'] = [
    [10, 5] for _ in range(len(default_df))
]
# Default masking setting, in the same string format as the sweep parameter.
default_df['config.masked_feature_loss.token_mask_probability'] = 'constant_0.8'
default_df

In [None]:
# @title Plotting setup (metrics, colors, legend labels)

# Metrics drawn in the sweep plots: masked MSE for each imputation and
# forecast ratio, followed by the plain validation masked MSE.
plot_metric_names = [
    f'{task}_{ratio}_eval/valid_mean_squared_error_masked'
    for task in ('imputation', 'forecast')
    for ratio in ('0.1', '0.2', '0.4')
] + ['valid_mean_squared_error_masked']

# Column prefix selecting the last logged value of each metric.
prefix = 'final_'

# One color per entry of `plot_metric_names`, in the same order.
imputation_palette = ['lightsteelblue', 'cornflowerblue', 'steelblue']
forecast_palette = ['lightcoral', 'indianred', 'firebrick']
random_fill_pallete = ['lightgray']
color_wheel = [*imputation_palette, *forecast_palette, *random_fill_pallete]

# Patch sizes of the two sweeps; the first entry is plotted as time steps per
# patch and the second as features per patch.
swept_time_patch_sizes = [[steps, 5] for steps in (5, 10, 20, 30)]
swept_feature_patch_sizes = [[10, n_feats] for n_feats in (2, 5, 10, 26)]

# Legend labels, aligned with `plot_metric_names`.
legend_list = [
    f'{task} {ratio}'
    for task in ('imputation', 'forecast')
    for ratio in ('0.1', '0.2', '0.4')
] + ['random fill 0.8']

## Patch Size Sweep

In [None]:
# @title Get Patch Sweep XM Metrics

# One experiment per patch-size sweep. All of them share the same model size
# and feature order and differ only in their description.
patch_size_sweep_xm_id_dict = {
    xid: {
        'model_size': 'Base',
        'feature_order': 'Ordered',
        'loss_only_masked_patches': True,
        'meta_data': f'{sweep_label} patch size sweep',
    }
    for xid, sweep_label in (
        (125205101, 'Extra Small'),
        (125002746, 'Small'),
        (125002485, 'Medium'),
        (125001367, 'Large'),
    )
}

# One row per work unit across all four experiments.
patch_sweep_df = get_metrics_df(patch_size_sweep_xm_id_dict)
patch_sweep_df

In [None]:
# Patch sweep runs plus the default Base model run, re-indexed from 0.
sweep_df = pd.concat(
    [patch_sweep_df, default_df], ignore_index=True, sort=False
)
# Patch area = product of the two patch-size entries.
sweep_df['patch_area'] = sweep_df['config.model.patches.size'].apply(
    lambda size: size[0] * size[1]
)
sweep_df

In [None]:
# Compact view of the patch sweep: patch size, compute, and the final masked
# MAE / MSE of every imputation / forecast ratio and of plain validation.
# `.copy()` makes `sub_df` an independent frame, so the 'gflops' assignment
# below does not write to a slice of `sweep_df` (SettingWithCopyWarning).
sub_df = sweep_df[[
    'config.model.patches.size',
    'gflops',
    *[
        f'final_{task}_{ratio}_eval/valid_mean_{kind}_error_masked'
        for task in ('imputation', 'forecast')
        for kind in ('absolute', 'squared')
        for ratio in ('0.1', '0.2', '0.4')
    ],
    'final_valid_mean_squared_error_masked',
    'final_valid_mean_absolute_error_masked',
]].copy()
# 'gflops' holds a list of logged values per row; keep only the first one.
sub_df['gflops'] = sub_df['gflops'].apply(lambda logged: logged[0])
sub_df

In [None]:
# @title Temporal Size Sweep

# Runs whose patch size is part of the time-axis sweep.
temporal_sweep_df = sweep_df[sweep_df['config.model.patches.size'].isin(swept_time_patch_sizes)]
xticks = [v[0] for v in swept_time_patch_sizes]

# X positions (time steps per patch), marker sizes (first logged gflops value)
# and the sort order are identical for every metric, so compute them once
# instead of once per loop iteration.
patch_size = np.array([v[0] for v in temporal_sweep_df['config.model.patches.size']])
compute = np.array([v[0] for v in temporal_sweep_df['gflops']])
sorted_idx = np.argsort(patch_size)
patch_size = patch_size[sorted_idx]
compute = compute[sorted_idx]

plt.figure(figsize=(10, 5))
for i, m in enumerate(plot_metric_names):
  # Final value of the metric for each run, ordered by patch size.
  metric_vals = np.array(temporal_sweep_df[prefix + m])[sorted_idx]

  # Markers scale with compute; the line connects them. Only the scatter
  # carries a legend label so that each metric is listed once.
  plt.scatter(patch_size, metric_vals, s=compute*10, alpha=0.5, color=color_wheel[i], label=legend_list[i])
  plt.plot(patch_size, metric_vals, linestyle='-', color=color_wheel[i], linewidth=2, label='')

plt.ylabel('Mean Squared Error')
plt.xlabel('Time Steps / Patch')
plt.xticks(xticks)
plt.xlim(left=0)
plt.ylim(bottom=0)
plt.legend(loc='best', bbox_to_anchor=(1, 1.04));


In [None]:
# @title Feature Size Sweep

# Runs whose patch size is part of the feature-axis sweep.
feat_sweep_df = sweep_df[sweep_df['config.model.patches.size'].isin(swept_feature_patch_sizes)]
xticks = [v[1] for v in swept_feature_patch_sizes]

# X positions (features per patch), marker sizes (first logged gflops value)
# and the sort order are identical for every metric, so compute them once
# instead of once per loop iteration.
patch_size = np.array([v[1] for v in feat_sweep_df['config.model.patches.size']])
compute = np.array([v[0] for v in feat_sweep_df['gflops']])
sorted_idx = np.argsort(patch_size)
patch_size = patch_size[sorted_idx]
compute = compute[sorted_idx]

plt.figure(figsize=(10, 5))
for i, m in enumerate(plot_metric_names):
  # Final value of the metric for each run, ordered by patch size.
  metric_vals = np.array(feat_sweep_df[prefix + m])[sorted_idx]

  # Markers scale with compute; the line connects them. No legend is drawn
  # for this figure.
  plt.scatter(patch_size, metric_vals, s=compute*10, alpha=0.5, color=color_wheel[i])
  plt.plot(patch_size, metric_vals, linestyle='-', color=color_wheel[i], linewidth=2)

plt.ylabel('Mean Squared Error')
plt.xlabel('Features / Patch')
plt.xlim(left=0)
plt.ylim(bottom=0)
plt.xticks(xticks);


## Masking Ratio Sweep

In [None]:
# @title Random Sweep Masking Ratio

# Experiment sweeping the masking ratio of the Base model.
mask_ratio_sweep_xm_id_dict = {
    125004844: dict(
        model_size='Base',
        feature_order='Ordered',
        loss_only_masked_patches=True,
        meta_data='mask ratio sweep',
    ),
}

# One row per work unit of the sweep.
masking_sweep_df = get_metrics_df(mask_ratio_sweep_xm_id_dict)
masking_sweep_df

In [None]:
# @title Default Base Model XM

# Masking sweep runs plus the default Base model run, re-indexed from 0.
sweep_df = pd.concat(
    [masking_sweep_df, default_df], ignore_index=True, sort=False
)
# The ratio is parsed from the last three characters of the setting string
# (e.g. 'constant_0.8' -> 0.8).
sweep_df['mask_ratio'] = sweep_df[
    'config.masked_feature_loss.token_mask_probability'
].apply(lambda setting: float(setting[-3:]))
sweep_df

In [None]:
# Compact view of the masking sweep: mask ratio and the final masked MAE / MSE
# of every imputation / forecast ratio and of plain validation.
sub_df = sweep_df[[
    'mask_ratio',
    *[
        f'final_{task}_{ratio}_eval/valid_mean_{kind}_error_masked'
        for task in ('imputation', 'forecast')
        for kind in ('absolute', 'squared')
        for ratio in ('0.1', '0.2', '0.4')
    ],
    'final_valid_mean_squared_error_masked',
    'final_valid_mean_absolute_error_masked',
]]
sub_df

In [None]:
# @title Bar Mask Sweep

# Experiment sweeping the masking ratio with bar (structured) masks.
bar_mask_ratio_sweep_xm_id_dict = {
    125114553: dict(
        model_size='Base',
        feature_order='Ordered',
        loss_only_masked_patches=True,
        meta_data='bar mask ratio sweep',
    ),
}

# One row per work unit of the sweep.
bar_masking_sweep_df = get_metrics_df(bar_mask_ratio_sweep_xm_id_dict)
bar_masking_sweep_df

In [None]:
# Mask settings of the two bar-mask variants. The '_w' suffix marks the second
# variant.
# NOTE(review): `feat_bar_masks` has no 0.6 and 0.8 entries - confirm that
# these runs are intentionally excluded.
time_bar_masks = [
    'constant_0.3_bar', 'constant_0.4_bar', 'constant_0.5_bar',
    'constant_0.6_bar', 'constant_0.7_bar', 'constant_0.8_bar',
    'constant_0.9_bar',
]
feat_bar_masks = [
    'constant_0.3_bar_w', 'constant_0.4_bar_w', 'constant_0.5_bar_w',
    'constant_0.7_bar_w', 'constant_0.9_bar_w',
]

# `.copy()` makes each subset an independent frame, so adding 'mask_ratio'
# below does not write to a slice of `bar_masking_sweep_df`
# (SettingWithCopyWarning).
time_bar_mask_df = bar_masking_sweep_df[
    bar_masking_sweep_df['config.masked_feature_loss.token_mask_probability'].isin(time_bar_masks)
].copy()
feat_bar_mask_df = bar_masking_sweep_df[
    bar_masking_sweep_df['config.masked_feature_loss.token_mask_probability'].isin(feat_bar_masks)
].copy()

# Characters 9..11 of the setting string hold the ratio
# (e.g. 'constant_0.3_bar' -> 0.3).
time_bar_mask_df['mask_ratio'] = time_bar_mask_df[
    'config.masked_feature_loss.token_mask_probability'
].apply(lambda setting: float(setting[9:12]))
feat_bar_mask_df['mask_ratio'] = feat_bar_mask_df[
    'config.masked_feature_loss.token_mask_probability'
].apply(lambda setting: float(setting[9:12]))


In [None]:
# @title Mask Ratio Sweep

# Candidate metrics; the position in this list selects the color and legend
# label, so skipped metrics stay in the list.
plot_metric_names = [
    'imputation_0.1_eval/valid_mean_squared_error_masked',
    'imputation_0.2_eval/valid_mean_squared_error_masked',
    'imputation_0.4_eval/valid_mean_squared_error_masked',

    'forecast_0.1_eval/valid_mean_squared_error_masked',
    'forecast_0.2_eval/valid_mean_squared_error_masked',
    'forecast_0.4_eval/valid_mean_squared_error_masked',

    'valid_mean_squared_error_masked',
]

# Metrics left out of the figure. Only the 0.4 imputation and 0.4 forecast
# metrics are drawn.
skip_list = [
    'imputation_0.1_eval/valid_mean_squared_error_masked',
    'imputation_0.2_eval/valid_mean_squared_error_masked',

    'forecast_0.1_eval/valid_mean_squared_error_masked',
    'forecast_0.2_eval/valid_mean_squared_error_masked',

    'valid_mean_squared_error_masked',
]


def _plot_mask_ratio_curves(frame, linestyle, label_prefix):
  """Draws one curve per non-skipped metric of `frame` on the current figure.

  Args:
    frame: DataFrame with a 'mask_ratio' column and the `prefix`-ed metric
      columns.
    linestyle: Matplotlib line style of the curves.
    label_prefix: Text put in front of the legend label of each curve.
  """
  # The x-values and their sort order are the same for every metric.
  mask_ratio = np.array(frame['mask_ratio'])
  sorted_idx = np.argsort(mask_ratio)
  mask_ratio = mask_ratio[sorted_idx]

  for i, m in enumerate(plot_metric_names):
    if m in skip_list:
      continue
    col = prefix + m
    metric_vals = np.array(frame[col])[sorted_idx]

    # Prints the value at sorted position 5 - presumably the 0.8 ratio when
    # all ratios 0.3..0.9 are present; verify. Guarded so that a shorter
    # sweep does not abort the plot with an IndexError.
    if len(metric_vals) > 5:
      print(f'{col}, {metric_vals[5]}')

    # Only the line carries a legend label so each curve is listed once.
    plt.scatter(mask_ratio, metric_vals, s=100, alpha=0.5, color=color_wheel[i], label='')
    plt.plot(mask_ratio, metric_vals, linestyle=linestyle, color=color_wheel[i], linewidth=2, label=label_prefix + legend_list[i])


plt.figure(figsize=(10, 5))
# Random masking (solid) versus bar masking from `time_bar_mask_df` (dashed).
_plot_mask_ratio_curves(sweep_df, '-', 'rand_')
print()
_plot_mask_ratio_curves(time_bar_mask_df, '--', 'bar_')

plt.ylabel('Mean Squared Error')
plt.xlabel('Masking Ratio')
plt.ylim(bottom=0)
plt.xticks([0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
plt.legend(loc='best', bbox_to_anchor=(1, 1.04));


## Feature Order Ablation

In [None]:
# @title Random Sweep Masking Ratio
# NOTE(review): the cell title matches the masking-ratio cell, but this cell
# loads the feature-order experiments - consider renaming it.

# Experiments trained with a non-default feature order.
feats_order_xm_id_dict = {
    125296318: dict(
        model_size='Base',
        feature_order='Random',
        loss_only_masked_patches=True,
        meta_data='random feature order sweep',
    ),
    125246122: dict(
        model_size='Base',
        feature_order='Max Entropy',
        loss_only_masked_patches=True,
        meta_data='max entropy feature order',
    ),
}

# One row per work unit of both experiments.
feats_order_df = get_metrics_df(feats_order_xm_id_dict)
feats_order_df

In [None]:
# @title Default Base Model XM

# Feature-order runs plus the default (ordered) Base model run, reduced to the
# feature order and the final masked MAE / MSE columns.
sweep_df = pd.concat(
    [feats_order_df, default_df], ignore_index=True, sort=False
)
sweep_df = sweep_df[[
    'Feature Order',
    *[
        f'final_{task}_{ratio}_eval/valid_mean_{kind}_error_masked'
        for task in ('imputation', 'forecast')
        for kind in ('absolute', 'squared')
        for ratio in ('0.1', '0.2', '0.4')
    ],
    'final_valid_mean_absolute_error_masked',
    'final_valid_mean_squared_error_masked',
]]

# Rows of the randomized feature order only.
is_random_order = sweep_df['Feature Order'] == 'Random'
rand_order_df = sweep_df[is_random_order]

sweep_df

In [None]:
# Metric columns of the random-feature-order runs (label column removed).
rand_sub_df = rand_order_df.drop('Feature Order', axis='columns', inplace=False)

# Mean over all random-order runs. A bare expression in the middle of a cell
# is evaluated but never shown, so print it explicitly.
print(rand_sub_df.mean())

# Mean over the first four random-order runs (shown as the cell output).
rand_sub_df[0:4].mean()

## Get XM Metrics

In [None]:
# @title Get XM Metrics

# Setup XM client
xm_client = xmanager_api.XManagerApi(xm_deployment_env='alphabet')

# XM ID Dict: the Tiny-model patch scaling experiments, with ordered and with
# randomized features.
xm_id_dict = {
    120664779: {
        'model_size': 'Tiny',
        'feature_order': 'Ordered',
        'loss_only_masked_patches': True
    },
    120685218: {
        'model_size': 'Tiny',
        'feature_order': 'Random',
        'loss_only_masked_patches': True
    },
}

# Get metric names. Each evaluation scope (plain validation, then each
# imputation / forecast ratio) contributes four series:
# {absolute, squared} error x {all, masked} patches.
metric_names = [
    f'{scope}valid_mean_{kind}_error_{region}'
    for scope in (
        '',
        'imputation_0.1_eval/',
        'imputation_0.2_eval/',
        'imputation_0.4_eval/',
        'forecast_0.1_eval/',
        'forecast_0.2_eval/',
        'forecast_0.4_eval/',
    )
    for kind in ('absolute', 'squared')
    for region in ('all', 'masked')
]

meta_data_name = [
    'num_trainable_params',
    'core_hours',
    'examples_seen',
    'gflops',
]

data_field_names = meta_data_name + metric_names

# Get all metrics. `get_metrics_df` reads the `xm_client` and
# `data_field_names` globals set above and already applies `add_min_columns`
# and `add_better_col_names`, so the collection loop is not repeated here.
df = get_metrics_df(xm_id_dict)
df

## Sandbox

In [None]:
# @title Ablation on Patch Sizes (Ordered Features)

feat_order = 'Ordered'
metric_name = 'valid_mean_absolute_error_masked'
metric_name_short = 'MAE Masked Patches'

# Runs trained with the selected feature order.
subset = df[
    (df['Feature Order'] == feat_order)
]

time_patch_sizes = [[5, 5], [10, 5], [20, 5]]
feat_patch_sizes = [[10, 1], [10, 2], [10, 5], [10, 10]]
# Position of the time / feature dimension inside a 'TxF' patch-size label.
time_idx = 0
feat_idx = 1

time_subset = subset[
    (subset['config.model.patches.size'].isin(time_patch_sizes))
]
feat_subset = subset[
    (subset['config.model.patches.size'].isin(feat_patch_sizes))
]

# One single-axes bar chart per sweep: first the time-axis sweep, then the
# feature-axis sweep. Bars are ordered by the swept patch dimension.
for sweep_subset, axis_idx in ((time_subset, time_idx), (feat_subset, feat_idx)):
  fig, axes = plt.subplots(1, 1, figsize=(6, 6), dpi=600)
  x = list(sweep_subset['patch_size'])
  # Best (minimum) value of the metric series of each run.
  y = [min(series) for series in sweep_subset[metric_name]]
  idx = np.argsort([int(label.split('x')[axis_idx]) for label in x])
  x = [x[i] for i in idx]
  y = [y[i] for i in idx]

  sns.barplot(x=x, y=y, palette='Blues', ax=axes)
  axes.set_title('Patch Size Scaling')
  axes.set_xlabel('Patch Size')
  axes.set_ylabel(metric_name_short)
  plt.tight_layout()
  plt.show()


In [None]:
# @title Ablation on Patch Sizes (Random Features)

feat_order = 'Random'
metric_name = 'valid_mean_absolute_error_masked'
metric_name_short = 'MAE Masked Patches'

# Runs trained with the selected feature order.
subset = df[
    (df['Feature Order'] == feat_order)
]

time_patch_sizes = [[5, 5], [10, 5], [20, 5]]
feat_patch_sizes = [[10, 1], [10, 2], [10, 5], [10, 10]]
# Position of the time / feature dimension inside a 'TxF' patch-size label.
time_idx = 0
feat_idx = 1

time_subset = subset[
    (subset['config.model.patches.size'].isin(time_patch_sizes))
]
feat_subset = subset[
    (subset['config.model.patches.size'].isin(feat_patch_sizes))
]

# One single-axes bar chart per sweep: first the time-axis sweep, then the
# feature-axis sweep. Bars are ordered by the swept patch dimension.
for sweep_subset, axis_idx in ((time_subset, time_idx), (feat_subset, feat_idx)):
  fig, axes = plt.subplots(1, 1, figsize=(6, 6), dpi=600)
  x = list(sweep_subset['patch_size'])
  # Best (minimum) value of the metric series of each run.
  y = [min(series) for series in sweep_subset[metric_name]]
  idx = np.argsort([int(label.split('x')[axis_idx]) for label in x])
  x = [x[i] for i in idx]
  y = [y[i] for i in idx]

  sns.barplot(x=x, y=y, palette='Blues', ax=axes)
  axes.set_title('Patch Size Scaling')
  axes.set_xlabel('Patch Size')
  axes.set_ylabel(metric_name_short)
  plt.tight_layout()
  plt.show()


## Analysis

In [None]:
# @title Ablation on Patch Sizes (Ordered vs Random)

# Column with the per-run minimum of the masked validation MAE.
metric_name = 'min_valid_mean_absolute_error_masked'
metric_name_short = 'MAE Masked Patches (Min)'

time_patch_sizes = [[5, 5], [10, 5], [20, 5]]
feat_patch_sizes = [[10, 1], [10, 2], [10, 5], [10, 10]]
time_idx = 0
feat_idx = 1

# Runs of both feature orders that belong to each sweep.
time_subset = df[df['config.model.patches.size'].isin(time_patch_sizes)]
feat_subset = df[df['config.model.patches.size'].isin(feat_patch_sizes)]

# One grouped bar chart per sweep, ordered vs. random features side by side.
# Bars run from the largest to the smallest patch size.
for sweep_frame, sweep_sizes, sweep_title in (
    (time_subset, time_patch_sizes, 'Temporal Patch Size Scaling'),
    (feat_subset, feat_patch_sizes, 'Feature Patch Size Scaling'),
):
  plt.figure(figsize=(10, 6), dpi=500)
  barplot = sns.barplot(
      data=sweep_frame,
      x='patch_size',
      y=metric_name,
      hue='Feature Order',
      palette='Blues',
      order=[f'{steps}x{n_feats}' for steps, n_feats in reversed(sweep_sizes)],
      errorbar=('ci', 95),
      width=0.7,
  )
  plt.ylabel(metric_name_short)
  plt.title(sweep_title)
  plt.legend(frameon=False, ncol=1, loc='upper right', bbox_to_anchor=(1.3, 1))
  plt.show()

In [None]:
# @title Ablation on Patch Sizes (Imputation Features)

metric_name = 'valid_mean_absolute_error_masked'
metric_type = 'final'  # min or final
# The axis label names the aggregation that is actually plotted.
metric_name_short = f'MAE Masked Patches ({metric_type.capitalize()})'
task_horizons = [0.1, 0.2, 0.4]
feat_order = 'Ordered'
task = 'imputation'

time_patch_sizes = [[5, 5], [10, 5], [20, 5]]
feat_patch_sizes = [[10, 1], [10, 2], [10, 5], [10, 10]]

# Long-format table: one row per (run, horizon, task). Rows are collected in
# a list and converted once, instead of growing a DataFrame with `pd.concat`
# inside the loop (quadratic copying).
rows = []
for _, row in df.iterrows():
  for hrzn in task_horizons:
    for task_name in ('imputation', 'forecast'):
      rows.append({
          'task': task_name,
          'horizon': hrzn,
          'error': row[f'{metric_type}_{task_name}_{hrzn}_eval/{metric_name}'],
          'Model Size': row['Model Size'],
          'patch_size': row['patch_size'],
          'config.model.patches.size': row['config.model.patches.size'],
          'Feature Order': row['Feature Order'],
      })
subset = pd.DataFrame(rows)

# Get sweep subsets
subset = subset[
    (subset['task'] == task)
]
subset = subset[
    (subset['Feature Order'] == feat_order)
]
time_subset = subset[
    (subset['config.model.patches.size'].isin(time_patch_sizes))
]
feat_subset = subset[
    (subset['config.model.patches.size'].isin(feat_patch_sizes))
]

# One grouped bar chart per sweep, one bar per horizon. Bars run from the
# largest to the smallest patch size.
# NOTE(review): errorbar=('ci', 1) requests a 1% confidence interval -
# confirm this is intended.
for sweep_frame, sweep_sizes, sweep_title in (
    (time_subset, time_patch_sizes, 'Imputation: Temporal Patch Size Scaling'),
    (feat_subset, feat_patch_sizes, 'Imputation: Feature Patch Size Scaling'),
):
  plt.figure(figsize=(10, 6), dpi=500)
  barplot = sns.barplot(
      data=sweep_frame,
      x='patch_size',
      y='error',
      hue='horizon',
      palette='Blues',
      order=[f'{i[0]}x{i[1]}' for i in sweep_sizes[::-1]],
      errorbar=('ci', 1),
      width=0.7,
  )
  plt.ylabel(metric_name_short)
  plt.xlabel('Patch Size')
  plt.title(sweep_title)
  plt.legend(frameon=False, ncol=1, loc='upper right', bbox_to_anchor=(1.2, 1))
  plt.show()

In [None]:
# @title Ablation on Patch Sizes (Forecast)

metric_name = 'valid_mean_absolute_error_masked'
metric_type = 'final'  # min or final
# The axis label names the aggregation that is actually plotted.
metric_name_short = f'MAE Masked Patches ({metric_type.capitalize()})'
task_horizons = [0.1, 0.2, 0.4]
feat_order = 'Ordered'
task = 'forecast'

time_patch_sizes = [[5, 5], [10, 5], [20, 5]]
feat_patch_sizes = [[10, 1], [10, 2], [10, 5], [10, 10]]

# Long-format table: one row per (run, horizon, task). Rows are collected in
# a list and converted once, instead of growing a DataFrame with `pd.concat`
# inside the loop (quadratic copying).
rows = []
for _, row in df.iterrows():
  for hrzn in task_horizons:
    for task_name in ('imputation', 'forecast'):
      rows.append({
          'task': task_name,
          'horizon': hrzn,
          'error': row[f'{metric_type}_{task_name}_{hrzn}_eval/{metric_name}'],
          'Model Size': row['Model Size'],
          'patch_size': row['patch_size'],
          'config.model.patches.size': row['config.model.patches.size'],
          'Feature Order': row['Feature Order'],
      })
subset = pd.DataFrame(rows)

# Get sweep subsets
subset = subset[
    (subset['task'] == task)
]
subset = subset[
    (subset['Feature Order'] == feat_order)
]
time_subset = subset[
    (subset['config.model.patches.size'].isin(time_patch_sizes))
]
feat_subset = subset[
    (subset['config.model.patches.size'].isin(feat_patch_sizes))
]

# One grouped bar chart per sweep, one bar per horizon. Bars run from the
# largest to the smallest patch size.
# NOTE(review): errorbar=('ci', 1) requests a 1% confidence interval -
# confirm this is intended.
for sweep_frame, sweep_sizes, sweep_title in (
    (time_subset, time_patch_sizes, 'Forecast: Temporal Patch Size Scaling'),
    (feat_subset, feat_patch_sizes, 'Forecast: Feature Patch Size Scaling'),
):
  plt.figure(figsize=(10, 6), dpi=500)
  barplot = sns.barplot(
      data=sweep_frame,
      x='patch_size',
      y='error',
      hue='horizon',
      palette='Blues',
      order=[f'{i[0]}x{i[1]}' for i in sweep_sizes[::-1]],
      errorbar=('ci', 1),
      width=0.7,
  )
  plt.ylabel(metric_name_short)
  plt.xlabel('Patch Size')
  plt.title(sweep_title)
  plt.legend(frameon=False, ncol=1, loc='upper right', bbox_to_anchor=(1.2, 1))
  plt.show()