## LSM v2 Naive Baselines
##### Colab Kernel (Brainframe CPU)
##### Dataset (Anything)

Grants command for Access on Demand (AoD):

https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-electrodes-deid-colab-jobs&reason=b%2F314799341

### About This Notebook:
This notebook implements and evaluates naive baselines to compare against the LSM ViT MAE method. These baselines are evaluated on the validation set of the electrodes dataset. These baselines include:
1. Mean fill
2. Linear interpolation
3. Nearest neighbor
4. MICE (as described here: TODO)

To run and visualize examples of these baselines, run all setup cells and then run the `Plot Naive Baseline Examples` cell.

To run the naive baseline eval across the `validation` set, set the `TO SET` values in the `Run Eval` cell and then run that cell. This takes ~1.5 hrs to iterate over the ~650K examples.


# Setup

In [None]:
# @title Imports

import functools
from typing import Any, Callable, Dict, Iterator, Tuple, Optional, Type, Union
import time

import collections
from collections import Counter

from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform

import flax
from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.profiler

import copy
import pandas as pd
import pickle
import ml_collections
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import tqdm

from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.ensemble import RandomForestRegressor
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingRegressor

from colabtools import adhoc_import
with adhoc_import.Google3():
  from scenic.dataset_lib import dataset_utils
  from scenic.google.xm import xm_utils
  from scenic.model_lib.base_models import base_model
  from scenic.model_lib.base_models import model_utils
  from scenic.model_lib.layers import nn_ops
  from scenic.train_lib import optax as scenic_optax
  from scenic.train_lib import pretrain_utils
  from scenic.train_lib import train_utils

  from scenic.projects.multimask.models import model_utils as mm_model_utils

  from google3.experimental.largesensormodels.scenic.datasets import get_dataset
  from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
  from google3.experimental.largesensormodels.scenic.models import lsm_vit as lsm_vit_mae
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_utils as lsm_model_utils
  from google3.experimental.largesensormodels.scenic.trainers import lsm_mae_trainer
  from google3.experimental.largesensormodels.scenic.trainers import lsm_mae_utils

  from google3.learning.deepmind.xmanager2.client import xmanager_api
  from google3.pyglib import gfile

  import ml_collections
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils.patcher_config import Patcher_Config
  from google3.experimental.largesensormodels.scenic.trainers.masking.masker_config import MaskStrategy_Config, Masker_Config
  from google3.experimental.largesensormodels.scenic.utils import config_constants
  from google3.experimental.largesensormodels.scenic.utils import predefined_configs


# Shared type aliases.
# A batch: string keys mapped to arrays.
Batch = Dict[str, jnp.ndarray]
# Metric callable: takes two arrays and a dict of arrays, returns a dict of
# (float, int) tuples keyed by name.
MetricFn = Callable[
    [jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]],
    Dict[str, Tuple[float, int]],
]
# Loss callable: takes an array, a Batch, an optional array and an array, and
# returns a float.
LossFn = Callable[
    [jnp.ndarray, Batch, Optional[jnp.ndarray], jnp.ndarray], float
]
# Named array -> array callables (presumably learning-rate schedules, going by
# the alias name -- confirm against the trainer).
LrFns = Dict[str, Callable[[jnp.ndarray], jnp.ndarray]]
# A 2-tuple or 3-tuple of ints.
Patch = Union[Tuple[int, int], Tuple[int, int, int]]

In [None]:
# @title Re-import from max's CL

# NOTE: This is currently (14 May 2025) needed to run sensor imputation baselines

import functools
from typing import Any, Callable, Dict, Iterator, Tuple, Optional, Type, Union
import time

import collections
from collections import Counter

from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform

import flax
from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.profiler

import copy
import pandas as pd
import pickle
import ml_collections
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import tqdm

from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.ensemble import RandomForestRegressor
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingRegressor

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.data import Dataset

x = Dataset


from colabtools import adhoc_import
with adhoc_import.Google3():
  from scenic.dataset_lib import dataset_utils
  from scenic.google.xm import xm_utils
  from scenic.model_lib.base_models import base_model
  from scenic.model_lib.base_models import model_utils
  from scenic.model_lib.layers import nn_ops
  from scenic.train_lib import optax as scenic_optax
  from scenic.train_lib import pretrain_utils
  from scenic.train_lib import train_utils
  from scenic.projects.multimask.models import model_utils as mm_model_utils


with adhoc_import.Google3CitcClient(
    'lsm_mixwmod_25_4_30', username='xumax', behavior='preferred'
):

  from google3.experimental.largesensormodels.scenic.datasets import get_dataset
  from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
  from google3.experimental.largesensormodels.scenic.models import lsm_vit as lsm_vit_mae
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_utils as lsm_model_utils
  from google3.experimental.largesensormodels.scenic.trainers import lsm_mae_trainer
  from google3.experimental.largesensormodels.scenic.trainers import lsm_mae_utils

  # from google3.learning.deepmind.xmanager2.client import xmanager_api
  from google3.pyglib import gfile

  import ml_collections
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils.patcher_config import Patcher_Config
  from google3.experimental.largesensormodels.scenic.trainers.masking.masker_config import MaskStrategy_Config, Masker_Config
  # from google3.experimental.largesensormodels.scenic.utils import config_constants
  from google3.experimental.largesensormodels.scenic.utils import predefined_configs


# Shared type aliases (re-declared by this re-import cell).
# A batch: string keys mapped to arrays.
Batch = Dict[str, jnp.ndarray]
# Metric callable: takes two arrays and a dict of arrays, returns a dict of
# (float, int) tuples keyed by name.
MetricFn = Callable[
    [jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]],
    Dict[str, Tuple[float, int]],
]
# Loss callable: takes an array, a Batch, an optional array and an array, and
# returns a float.
LossFn = Callable[
    [jnp.ndarray, Batch, Optional[jnp.ndarray], jnp.ndarray], float
]
# Named array -> array callables (presumably learning-rate schedules, going by
# the alias name -- confirm against the trainer).
LrFns = Dict[str, Callable[[jnp.ndarray], jnp.ndarray]]
# A 2-tuple or 3-tuple of ints.
Patch = Union[Tuple[int, int], Tuple[int, int, int]]





In [None]:
# @title Patching Helper Functions

def patch_img(img, ph, pw):
  """Cuts a batch of images into a grid of non-overlapping patches.

  Args:
    img: Array of shape [b, h, w, c].
    ph: Patch height.
    pw: Patch width.

  Returns:
    Array of shape [b, h // ph, w // pw, ph, pw, c]; entry [:, r, s] is the
    patch at grid row r and grid column s.
  """
  batch, height, width, channels = img.shape
  grid_rows, grid_cols = height // ph, width // pw  # patches per axis

  # Split each spatial axis into (grid, within-patch), then move both grid
  # axes in front of both within-patch axes.
  split = jnp.reshape(img, (batch, grid_rows, ph, grid_cols, pw, channels))
  return jnp.transpose(split, (0, 1, 3, 2, 4, 5))


def get_pixel_mask(token_mask, ph, pw, img_shape):
  """Expands a per-token (patch-level) mask into a per-pixel mask.

  Args:
    token_mask: Patch-level mask. It is broadcast along its leading axes to
      [b, n_tokens, ph * pw * c] (presumably shaped [b, n_tokens] -- confirm
      against the caller).
    ph: Patch height.
    pw: Patch width.
    img_shape: Image shape as (b, h, w, c).

  Returns:
    Array of shape [b, (h // ph) * ph, (w // pw) * pw, c] in which every pixel
    carries the weight of the token it belongs to.
  """
  batch, height, width, channels = img_shape
  grid_rows, grid_cols = height // ph, width // pw  # patches per axis

  # One entry per pixel, grouped by token: [b, n_tokens, ph * pw * c]. The
  # array is all ones, so it can be created directly in its final layout.
  token_pixels = jnp.ones((batch, grid_rows * grid_cols, ph * pw * channels))

  # Stretch the token mask over every pixel of its token and apply it.
  pixel_weights = jax.lax.broadcast_in_dim(
      token_mask,
      shape=token_pixels.shape,
      broadcast_dimensions=tuple(range(token_mask.ndim)),
  )
  weighted = model_utils.apply_weights(token_pixels, pixel_weights)

  # Invert the patching: [b, nh, nw, ph, pw, c] -> [b, nh, ph, nw, pw, c]
  # -> [b, nh * ph, nw * pw, c].
  weighted = jnp.reshape(
      weighted, (batch, grid_rows, grid_cols, ph, pw, channels)
  )
  weighted = jnp.transpose(weighted, (0, 1, 3, 2, 4, 5))
  return jnp.reshape(
      weighted, (batch, grid_rows * ph, grid_cols * pw, channels)
  )


def patch_and_mask_img(x, dropout_rng, config):
  """Builds token-level and pixel-level masks for a batch of images.

  The masking recipe is read from
  `config.masked_feature_loss.token_mask_probability`, an underscore-separated
  string. Field 1 is the mask probability, optional field 2 the strategy
  ('random', 'forecast', 'imputation' or 'partialbar'; default 'random'),
  optional field 3 the masked dimension ('h'/'time' or
  'w'/'feature'/'sensor'; default 'h'). For 'partialbar', fields 4 and 5 are
  the masked fractions along the mask dimension and the other dimension.
  Field 0 is not read here.

  Args:
    x: Image batch of shape [b, h, w, c].
    dropout_rng: PRNG key handed to the randomised mask generators.
    config: Config providing `model.patches.size` and
      `masked_feature_loss.token_mask_probability`.

  Returns:
    Dict with keys 'mask_indices', 'unmasked_indices', 'token_mask' and
    'pixel_mask'.

  Raises:
    ValueError: If the mask dimension or the masking strategy is unsupported.
  """
  n_batch, _, _, _ = x.shape
  ph, pw = config.model.patches.size

  # The patch grid fixes how many tokens exist along each axis.
  patch_grid = patch_img(x, ph, pw)
  height, width = patch_grid.shape[1], patch_grid.shape[2]

  # Parse the masking recipe.
  spec_fields = config.masked_feature_loss.token_mask_probability.split('_')
  mask_probability = float(spec_fields[1])
  masking_strategy = spec_fields[2] if len(spec_fields) >= 3 else 'random'

  # Resolve the masked dimension (validated even if the strategy ignores it).
  mask_dim = spec_fields[3] if len(spec_fields) >= 4 else 'h'
  if mask_dim in ('h', 'time'):
    mask_dim, mask_dim_len, mask_offdim_len = 'h', height, width
  elif mask_dim in ('w', 'feature', 'sensor'):
    mask_dim, mask_dim_len, mask_offdim_len = 'w', width, height
  else:
    raise ValueError(f'Unsupported mask_dim: {mask_dim}')

  if masking_strategy == 'random':  # Random Mask
    n_tokens = height * width
    # TODO(girishvn): switch the rng each iteration.
    generated = mm_model_utils.get_mask_indices(
        n_batch, n_tokens, int(mask_probability * n_tokens), dropout_rng
    )
  elif masking_strategy == 'forecast':  # Forecast
    generated = lsm_model_utils.get_forecast_mask_indices(
        n_batch=n_batch, n_h=height, n_w=width,
        n_dim_masked=int(mask_probability * mask_dim_len),
        mask_dim=mask_dim,
    )
  elif masking_strategy == 'imputation':  # Imputation
    generated = lsm_model_utils.get_imputation_mask_indices(
        n_batch=n_batch, n_h=height, n_w=width,
        n_dim_masked=int(mask_probability * mask_dim_len),
        mask_dim=mask_dim,
        rng=dropout_rng,
    )
  elif masking_strategy == 'partialbar':  # Structured Bar
    mask_dim_prob = float(spec_fields[4])
    mask_offdim_prob = float(spec_fields[5])
    generated = lsm_model_utils.get_random_partial_bar_mask_indices(
        n_batch=n_batch, n_h=height, n_w=width,
        n_dim_masked=int(mask_dim_prob * mask_dim_len),
        n_offdim_masked=int(mask_offdim_prob * mask_offdim_len),
        mask_dim=mask_dim,
        rng=dropout_rng,
    )
  else:
    raise ValueError(f'Unsupported masking strategy: {masking_strategy}')

  mask_indices, unmasked_indices, token_mask = generated

  # Derive the pixel-level mask from the patch-level mask.
  return {
      'mask_indices': mask_indices,
      'unmasked_indices': unmasked_indices,
      'token_mask': token_mask,
      'pixel_mask': get_pixel_mask(token_mask, ph, pw, x.shape),
  }



In [None]:
# @title Naive Baseline Functions

def fit_linear_interp(x, mask_info, config):
  """Fills masked pixels by per-feature linear interpolation along the h axis.

  NOTE: this function reads the notebook-global `dataset` (not a parameter) to
  find out which feature columns are valid.

  Args:
    x: Input batch of shape [b, h, w, c]. Only channel 0 is written.
    mask_info: Dict with 'token_mask' and 'pixel_mask' (1 = masked pixel).
      Other keys are not needed.
    config: Config providing `model.patches.size`.

  Returns:
    Tuple of (predictions of shape [b, n_tokens, ph * pw * c],
    {'token_mask': token_mask}).
  """
  # Only these two keys are read. Requiring 'mask_indices' and
  # 'unmasked_indices' (which were never used) raised KeyError for mask_info
  # dicts that carry just the token and pixel masks.
  token_mask = mask_info['token_mask']
  pixel_mask = mask_info['pixel_mask']

  # Loop-invariant lookup of the per-feature validity flags.
  valid_feats = dataset.meta_data['input_valid_feats']

  # Calculate prediction
  x_pred = jnp.array(x.copy())
  b, h, w, _ = x_pred.shape
  for i in range(b):  # iterate through batch
    for j in range(w):  # iterate through features
      if not valid_feats[j]:
        continue  # invalid features keep their input values

      feat_vals = x[i, :, j, :]
      masked_feat_idx = jnp.where(pixel_mask[i, :, j, :] == 1)[0]
      unmasked_feat_idx = jnp.where(pixel_mask[i, :, j, :] == 0)[0]
      unmasked_feat_vals = jnp.ravel(feat_vals[unmasked_feat_idx])

      if unmasked_feat_idx.size == 0:
        # Everything is masked - nothing to interpolate from, predict zeros.
        x_pred = x_pred.at[i, :, j, 0].set(0)
      elif masked_feat_idx.size != 0:
        # Interpolate masked positions from the observed ones.
        masked_feat_vals_interp = jnp.interp(
            x=masked_feat_idx, xp=unmasked_feat_idx, fp=unmasked_feat_vals,
        )
        x_pred = x_pred.at[
            i, masked_feat_idx, j, 0
        ].set(masked_feat_vals_interp)
      # else: nothing is masked - keep the input untouched.

  # Re-patch x_pred into token layout.
  ph, pw = config.model.patches.size
  x_pred = patch_img(x_pred, ph, pw)
  b, nh, nw, ph, pw, c = x_pred.shape
  x_pred = jnp.reshape(x_pred, (b, nh*nw, ph*pw*c))

  return x_pred, {'token_mask': token_mask}


def fit_mean_fill(x, mask_info, config):
  """Fills masked pixels of each feature with the mean of its observed pixels.

  NOTE: this function reads the notebook-global `dataset` (not a parameter) to
  find out which feature columns are valid.

  Args:
    x: Input batch of shape [b, h, w, c]. Only channel 0 is written.
    mask_info: Dict with 'token_mask' and 'pixel_mask' (1 = masked pixel).
      Other keys are not needed.
    config: Config providing `model.patches.size`.

  Returns:
    Tuple of (predictions of shape [b, n_tokens, ph * pw * c],
    {'token_mask': token_mask}).
  """
  # Only these two keys are read. Requiring 'mask_indices' and
  # 'unmasked_indices' (which were never used) raised KeyError for mask_info
  # dicts that carry just the token and pixel masks.
  token_mask = mask_info['token_mask']
  pixel_mask = mask_info['pixel_mask']

  # Loop-invariant lookup of the per-feature validity flags.
  valid_feats = dataset.meta_data['input_valid_feats']

  # Calculate prediction
  x_pred = jnp.array(x.copy())
  b, h, w, _ = x_pred.shape
  for i in range(b):  # iterate through batch
    for j in range(w):  # iterate through features
      if not valid_feats[j]:
        continue  # invalid features keep their input values

      feat_vals = x[i, :, j, :]
      masked_feat_idx = jnp.where(pixel_mask[i, :, j, :] == 1)[0]
      unmasked_feat_idx = jnp.where(pixel_mask[i, :, j, :] == 0)[0]
      unmasked_feat_vals = jnp.ravel(feat_vals[unmasked_feat_idx])

      if unmasked_feat_idx.size == 0:
        # Everything is masked - no observed values to average, predict zeros.
        x_pred = x_pred.at[i, :, j, 0].set(0)
      elif masked_feat_idx.size != 0:
        # The fill value must come from the OBSERVED pixels. Averaging the
        # masked positions (as was done before) leaks the ground truth that
        # the baseline is supposed to predict.
        mean_val = jnp.mean(unmasked_feat_vals)
        x_pred = x_pred.at[
            i, masked_feat_idx, j, 0
        ].set(mean_val)
      # else: nothing is masked - keep the input untouched.

  # Re-patch x_pred into token layout.
  ph, pw = config.model.patches.size
  x_pred = patch_img(x_pred, ph, pw)
  b, nh, nw, ph, pw, c = x_pred.shape
  x_pred = jnp.reshape(x_pred, (b, nh*nw, ph*pw*c))

  return x_pred, {'token_mask': token_mask}


def pandas_naive_baselines(x, mask_info, config):
  """Computes linear, nearest-neighbour and mean-fill baselines for a batch.

  Masked pixels are replaced by NaN and every (example, feature, channel)
  series becomes one DataFrame column, so each baseline is a single pandas
  call over the whole batch. Values that a method cannot fill (e.g. a fully
  masked series) end up as 0.

  Args:
    x: Input batch of shape [b, h, w, c]; h is the interpolation axis.
    mask_info: Dict with 'token_mask' and 'pixel_mask' (1 = masked pixel).
    config: Config providing `model.patches.size`.

  Returns:
    Tuple of (dict with keys 'linear', 'nn' and 'mean_fill', each of shape
    [b, n_tokens, ph * pw * c], and {'token_mask': token_mask}).
  """
  token_mask = mask_info['token_mask']
  pixel_mask = mask_info['pixel_mask']

  # Get inputs.
  x = jnp.asarray(x)
  b, h, w, c = x.shape

  # Hide masked pixels behind NaN so pandas treats them as missing.
  x_nan_masked = x.at[jnp.where(pixel_mask == 1)].set(jnp.nan)
  # [b, h, w, c] -> [h, b, w, c] -> [h, b*w*c]. The channel axis is part of
  # the column count; reshaping to (h, b*w) only worked when c == 1.
  x_batch_masked = jnp.transpose(x_nan_masked, (1, 0, 2, 3))
  x_batch_masked = jnp.reshape(x_batch_masked, (h, b * w * c))

  # One column per series - enables single-call interpolation on the batch.
  x_df = pd.DataFrame(x_batch_masked)

  def _fill_leftovers(df):
    # Interpolation can leave NaNs at the edges or in fully masked columns.
    return df.bfill().ffill().fillna(0)

  # A spline / cubic baseline was prototyped here previously and is disabled.
  filled_frames = {
      'linear': _fill_leftovers(
          x_df.interpolate(method='linear', limit_direction='both', axis=0)
      ),
      'nn': _fill_leftovers(
          x_df.interpolate(method='nearest', limit_direction='both', axis=0)
      ),
      # DataFrame.mean skips NaN, i.e. it averages the observed values only.
      'mean_fill': x_df.fillna(x_df.mean()).fillna(0),
  }

  ph, pw = config.model.patches.size
  baseline_dict = {}
  for name, frame in filled_frames.items():
    pred = jnp.asarray(frame.to_numpy())
    # [h, b*w*c] -> [h, b, w, c] -> [b, h, w, c]
    pred = jnp.transpose(jnp.reshape(pred, (h, b, w, c)), (1, 0, 2, 3))
    # Re-patch into token layout: [b, nh*nw, ph*pw*c].
    pred = patch_img(pred, ph, pw)
    n_b, nh, nw, p_h, p_w, n_c = pred.shape
    baseline_dict[name] = jnp.reshape(pred, (n_b, nh * nw, p_h * p_w * n_c))

  # Return baseline logits, and mask
  return baseline_dict, {'token_mask': token_mask}


def fit_MICE_baselines(
    model, x, mask_info, config, estimator_name=None, max_iter=10
):
  """Imputes masked pixels of a single example with MICE (IterativeImputer).

  Args:
    model: Unused; kept for call-site compatibility.
    x: Input of shape [1, h, w, c].
    mask_info: Dict with 'token_mask' and 'pixel_mask' (1 = masked pixel).
    config: Config providing `model.patches.size`.
    estimator_name: 'LinearRegression', 'RandomForestRegressor', or anything
      else (including None) for IterativeImputer's default estimator.
    max_iter: Maximum number of imputation rounds.

  Returns:
    Tuple of ({'MICE': predictions of shape [1, n_tokens, ph * pw * c]},
    {'token_mask': token_mask}).

  Raises:
    ValueError: If the batch size is not 1.
  """
  token_mask = mask_info['token_mask']
  pixel_mask = mask_info['pixel_mask']

  # Get inputs.
  x = jnp.asarray(x)
  b, h, w, c = x.shape

  if b != 1:
    raise ValueError('Batch size must be 1 for MICE')

  # Hide masked pixels behind NaN; rows are time steps, columns are
  # (feature, channel) pairs so that c > 1 is handled as well.
  n_cols = w * c
  x_nan_masked = x.at[jnp.where(pixel_mask == 1)].set(jnp.nan)
  x_nan_masked = jnp.reshape(x_nan_masked, (h, n_cols))
  x_df = pd.DataFrame(x_nan_masked)

  # Pick the per-feature regressor used inside MICE.
  if estimator_name == 'LinearRegression':
    # Imported here because the notebook's import cells do not provide it.
    from sklearn.linear_model import LinearRegression
    estimator = LinearRegression()
  elif estimator_name == 'RandomForestRegressor':
    estimator = RandomForestRegressor(
        n_estimators=100,
        max_depth=10,
        random_state=42,
    )
  else:
    estimator = None

  imp = IterativeImputer(
      estimator=estimator,
      max_iter=max_iter,
      random_state=42,
  )
  imputed = np.asarray(imp.fit_transform(x_df))

  # IterativeImputer discards columns that are entirely missing, which would
  # break the reshape below whenever a whole sensor is masked. Put such
  # columns back as zeros (the same fallback the other naive baselines use).
  if imputed.shape[1] != n_cols:
    observed_cols = x_df.notna().any(axis=0).to_numpy()
    restored = np.zeros((h, n_cols), dtype=imputed.dtype)
    restored[:, observed_cols] = imputed
    imputed = restored

  # Recover batch, feature and channel dims, then re-patch to token layout.
  interp_out = jnp.reshape(jnp.asarray(imputed), (b, h, w, c))
  ph, pw = config.model.patches.size
  interp_out = patch_img(interp_out, ph, pw)

  b, nh, nw, ph, pw, c = interp_out.shape
  interp_out = jnp.reshape(interp_out, (b, nh*nw, ph*pw*c))

  # Return baseline logits, and mask
  baseline_dict = {
      'MICE': interp_out,
  }
  return baseline_dict, {'token_mask': token_mask}


In [None]:
# @title Naive Evaluation Pipeline

def naive_eval_step(
    batch: Batch,
    *,
    metrics_fn: MetricFn,
    config: ml_collections.ConfigDict,
    debug: Optional[bool] = False,
    rng: Optional[jax.random.PRNGKey] = None,
) -> Tuple[Dict[str, Tuple[float, int]], jnp.ndarray, Dict[str, Any]]:
  """Runs the pandas-based naive baselines on one batch and scores them.

  The batch is modified in place: the leading (process) axis is stripped from
  the inputs and 'targets' / 'patched_imputationmask' are added. The mask is
  taken from batch['token_mask'] rather than generated here.

  NOTE(review): the patch size comes from the notebook globals `PH` / `PW`,
  not from `config` -- confirm they agree with `config.model.patches.size`.

  Args:
    batch: Batch with 'input_signal', 'batch_mask', 'token_mask' and
      'imputation_mask'.
    metrics_fn: Called as metrics_fn(predictions, token_mask, batch).
    config: Experiment config.
    debug: Not used by this function.
    rng: Not used by this function.

  Returns:
    Tuple of (metrics per baseline name, predictions per baseline name, aux
    dict holding 'token_mask').
  """
  # Flatten out the process dimension.
  for key in ('input_signal', 'batch_mask', 'token_mask', 'imputation_mask'):
    batch[key] = batch[key][0]

  # Add prediction targets and the patch-level imputation mask.
  batch['targets'] = lsm_mae_utils.get_targets(batch, config)
  batch['patched_imputationmask'] = lsm_mae_utils.patchify_imputationmask(
      batch, config
  )

  # Expand the supplied token mask to pixel resolution.
  signal = batch['input_signal']
  mask_info = {
      'token_mask': batch['token_mask'],
      'pixel_mask': get_pixel_mask(batch['token_mask'], PH, PW, signal.shape),
  }

  # Calculate baselines and score each of them against the targets.
  logits_dict, aux = pandas_naive_baselines(signal, mask_info, config)
  metrics_dict = {
      method: metrics_fn(logits_dict[method], aux['token_mask'], batch)
      for method in ('linear', 'nn', 'mean_fill')
  }

  return metrics_dict, logits_dict, aux


def MICE_eval_step(
    model,
    batch: Batch,
    *,
    metrics_fn: MetricFn,
    config: ml_collections.ConfigDict,
    debug: Optional[bool] = False,
    rng: Optional[jax.random.PRNGKey] = None,
) -> Tuple[Dict[str, Tuple[float, int]], jnp.ndarray, Dict[str, Any]]:
  """MICE evaluation step for one batch -- currently disabled.

  The function raises immediately; everything after the `raise` is
  unreachable and is kept as the starting point for the LSMv2 port.

  Args:
    model: Forwarded to fit_MICE_baselines by the (unreachable) body.
    batch: Batch with 'input_signal' and 'batch_mask'.
    metrics_fn: Called as metrics_fn(predictions, token_mask, batch).
    config: Experiment config; `config.rng_seed` seeds the mask when `rng` is
      None.
    debug: Not used by this function.
    rng: Optional PRNG key for mask generation.

  Returns:
    Nothing at present. The unreachable body would return (metrics dict,
    predictions dict, aux dict).

  Raises:
    ValueError: Always.
  """

  # Deliberate hard stop until this step is ported to LSMv2.
  raise ValueError('NEEDS TO BE UPDATED TO LSMv2')

  # ---- Unreachable code below. ----
  # Flatten out process dimension
  batch['input_signal'] = batch['input_signal'][0]
  batch['batch_mask'] = batch['batch_mask'][0]

  # Add prediction targets
  batch['targets'] = lsm_mae_utils.get_targets(batch, config)

  if rng is None:
    # Always use the same seed, so that eval is as consistent as possible
    rng = jax.random.PRNGKey(config.rng_seed)

  # Patch and mask img.
  mask_info = patch_and_mask_img(batch['input_signal'], rng, config)

  # Calculate Baselines
  logits_dict, aux = fit_MICE_baselines(
      model, batch['input_signal'], mask_info, config
  )

  metrics_dict = dict()
  metrics_dict['MICE'] = metrics_fn(
      logits_dict['MICE'], aux['token_mask'], batch
  )

  return metrics_dict, logits_dict, aux

# OLD
# def regression_metrics_function(
#     predictions: jnp.ndarray,
#     prediction_masks: jnp.ndarray,
#     batch: base_model.Batch,
#     metrics: base_model.MetricNormalizerFnDict,
#     axis_name: Union[str, Tuple[str, ...]] = 'batch',
# ) -> Dict[str, Tuple[float, int]]:
#   """Calculate metrics for the regression task.

#   Currently we assume each metric_fn has the API:
#     ```metric_fn(predictions, targets, weights)```
#   and returns an array of shape [batch,]. We also assume that to compute
#   the aggregate metric, one should sum across all batches, then divide by the
#   total samples seen. In this way we currently only support metrics of the 1/N
#   sum f(inputs, targets). Note, the caller is responsible for dividing by
#   the normalizer when computing the mean of each metric.

#   Args:
#    predictions: Output of model in shape [batch, length, channels].
#    prediction_masks: Masks used for masked modeling, shape [batch, length]
#    batch: Batch (dict) with keys 'targets' and optionally 'batch_mask'.
#    metrics: The regression metrics to evaluate. The key is the name of the
#      metric, and the value is the metrics function, normalizer, and a bool
#      indicating whether to apply prediction_masks.
#    axis_name: List of axes on which we run the pmsum.

#   Returns:
#     A dict of metrics, in which keys are metrics name and values are tuples of
#     (metric, normalizer).
#   """
#   targets = batch['targets']
#   batch_weights = batch.get('batch_mask')
#   weights = jnp.expand_dims(batch_weights, axis=-1) * prediction_masks
#   evaluated_metrics = {}
#   for key, val in metrics.items():
#     curr_weights = weights if val[2] else batch_weights

#     val0 = val[0](
#         targets,
#         predictions,  # pytype: disable=wrong-arg-types  # jax-ndarray
#         curr_weights,
#     )
#     val1 = val[1](
#         targets,
#         predictions,  # pytype: disable=wrong-arg-types  # jax-ndarray
#         batch_weights,
#     )
#     evaluated_metrics[key] = (jnp.sum(val0), jnp.sum(val1))

#   return evaluated_metrics  # pytype: disable=bad-return-type  # jax-ndarray


# NEW
def regression_metrics_function(
    predictions: jnp.ndarray,
    prediction_masks: jnp.ndarray,
    batch: base_model.Batch,
    metrics: base_model.MetricNormalizerFnDict,
    axis_name: Union[str, Tuple[str, ...]] = 'batch',
) -> Dict[str, Tuple[float, int]]:
  """Evaluates every configured regression metric on one batch.

  Each metric entry is indexed as
    [0] metric function, called as fn(targets, predictions, weights),
    [1] normalizer function, called as fn(targets, predictions, batch_mask),
    [2] flag: restrict the weights to masked tokens (prediction_masks),
    [3] flag: drop elements flagged in batch['patched_imputationmask'].
  Both results are summed. Dividing the summed metric by the summed
  normalizer is left to the caller.

  Args:
    predictions: Predictions, [batch_size, num_patches, patch_size].
    prediction_masks: Token mask, [batch_size, num_patches].
    batch: Batch with 'targets', optionally 'batch_mask', and
      'patched_imputationmask' when any metric sets flag [3].
    metrics: Mapping from metric name to the 4-entry spec described above.
    axis_name: Not used by this implementation.

  Returns:
    Dict mapping each metric name to (summed metric, summed normalizer).
  """
  targets = batch['targets']
  batch_weights = batch.get('batch_mask')

  results = {}
  for name, spec in metrics.items():
    # Start with every target element counted, then narrow the weights down.
    weights = jnp.ones(targets.shape)

    if spec[2]:  # LOSS_ONLY_MASKED_TOKENS
      weights = jnp.expand_dims(prediction_masks, axis=-1) * weights
    if spec[3]:  # LOSS_IGNORE_IMPUTATION
      # See loss_function of ViTMAESingleChannelModel for a similar step.
      weights = jnp.logical_not(batch['patched_imputationmask']) * weights

    metric_vals = spec[0](
        targets,
        predictions,  # pytype: disable=wrong-arg-types  # jax-ndarray
        weights,
    )
    normalizer_vals = spec[1](
        targets,
        predictions,  # pytype: disable=wrong-arg-types  # jax-ndarray
        batch_weights,
    )
    results[name] = (jnp.sum(metric_vals), jnp.sum(normalizer_vals))

  return results  # pytype: disable=bad-return-type  # jax-ndarray


def naive_evaluate(
    *,
    rng: jnp.ndarray,
    config: ml_collections.ConfigDict,
    dataset: dataset_utils.Dataset,
) -> list:
  """Evaluates the naive baselines over the validation data.

  Args:
    rng: PRNG key; split once per step and forwarded to naive_eval_step.
    config: Experiment config. Reads `batch_size` and the optional
      `eval_batch_size`, `steps_per_eval` and `debug_eval` entries.
    dataset: Dataset exposing `valid_iter` (an iterator or a dict of
      iterators) and `meta_data['num_val_examples']`.

  Returns:
    List with one metrics dict per evaluated batch, as returned by
    naive_eval_step. If `valid_iter` is a dict with several splits, only the
    metrics of the last split are returned. An empty dict yields [].
  """
  metrics_fn = functools.partial(
      regression_metrics_function,
      metrics=lsm_vit_mae._REGRESSION_METRICS,
  )

  valid_iter = dataset.valid_iter
  num_valid_ex = dataset.meta_data['num_val_examples']
  if not isinstance(valid_iter, dict):  # Only on validation set.
    valid_iter, num_valid_ex = {'valid': valid_iter}, {'valid': num_valid_ex}

  # Loop-invariant settings. `debug_eval` is optional because the step does
  # not use it.
  eval_batch_size = config.get('eval_batch_size', config.batch_size)
  debug = config.get('debug_eval', False)

  eval_metrics = []  # defined even when there is no split to iterate
  for val_name, val_iter in valid_iter.items():
    # Ceil rounding such that we include the last incomplete batch.
    total_eval_steps = int(np.ceil(num_valid_ex[val_name] / eval_batch_size))
    steps_per_eval = config.get('steps_per_eval') or total_eval_steps

    # Metrics restart for every split (existing behaviour).
    # NOTE(review): confirm whether splits should be kept separately instead.
    eval_metrics = []
    for _ in tqdm.tqdm(range(steps_per_eval)):
      rng, mask_rng = jax.random.split(rng)
      eval_batch = next(val_iter)

      # Naive Baselines
      e_metrics, _, _ = naive_eval_step(
          eval_batch,
          metrics_fn=metrics_fn,
          config=config,
          debug=debug,
          rng=mask_rng,
      )
      eval_metrics.append(e_metrics)

  return eval_metrics


def MICE_evaluate(
    *,
    rng: jnp.ndarray,
    config: ml_collections.ConfigDict,
    dataset: dataset_utils.Dataset,
) -> list:
  """Evaluates the MICE baseline over the validation data.

  NOTE: MICE_eval_step currently raises ValueError unconditionally (it has
  not been ported to LSMv2), so this function cannot complete a step yet.

  Args:
    rng: PRNG key; split once per step and forwarded to MICE_eval_step.
    config: Experiment config. Reads `batch_size` and the optional
      `eval_batch_size`, `steps_per_eval` and `debug_eval` entries.
    dataset: Dataset exposing `valid_iter` (an iterator or a dict of
      iterators) and `meta_data['num_val_examples']`.

  Returns:
    List with one metrics dict per evaluated batch. If `valid_iter` is a dict
    with several splits, only the metrics of the last split are returned. An
    empty dict yields [].
  """
  metrics_fn = functools.partial(
      regression_metrics_function,
      metrics=lsm_vit_mae._REGRESSION_METRICS,
  )

  valid_iter = dataset.valid_iter
  num_valid_ex = dataset.meta_data['num_val_examples']
  if not isinstance(valid_iter, dict):  # Only on validation set.
    valid_iter, num_valid_ex = {'valid': valid_iter}, {'valid': num_valid_ex}

  # Loop-invariant settings. `debug_eval` is optional because the step does
  # not use it.
  eval_batch_size = config.get('eval_batch_size', config.batch_size)
  debug = config.get('debug_eval', False)

  eval_metrics = []  # defined even when there is no split to iterate
  for val_name, val_iter in valid_iter.items():
    # Ceil rounding such that we include the last incomplete batch.
    total_eval_steps = int(np.ceil(num_valid_ex[val_name] / eval_batch_size))
    steps_per_eval = config.get('steps_per_eval') or total_eval_steps

    # Metrics restart for every split (existing behaviour).
    # NOTE(review): confirm whether splits should be kept separately instead.
    eval_metrics = []
    for _ in tqdm.tqdm(range(steps_per_eval)):
      rng, mask_rng = jax.random.split(rng)
      eval_batch = next(val_iter)

      # MICE Baseline. MICE_eval_step takes (model, batch); passing only the
      # batch bound it to `model` and failed with a missing-argument
      # TypeError. The MICE path never uses the model, so None is passed.
      e_metrics, _, _ = MICE_eval_step(
          None,
          eval_batch,
          metrics_fn=metrics_fn,
          config=config,
          debug=debug,
          rng=mask_rng,
      )
      eval_metrics.append(e_metrics)

  return eval_metrics



In [None]:
# @title Calculate Metrics Function

def calc_metrics(
    eval_metrics, method_list=('linear', 'nn', 'mean_fill')
):
  """Aggregates per-batch metric sums into per-method means and prints them.

  Args:
    eval_metrics: Iterable of per-batch dicts, shaped as
      {method: {metric_name: (summed_value, example_count)}}. The keys
      'mean_absolute_error_masked_ignoreimp_mean' and
      'mean_squared_error_masked_ignoreimp_mean' are read.
    method_list: Methods to aggregate.

  Returns:
    Dict mapping each method to (mean MAE, mean MSE, example count). The
    example count is taken from the MAE entries. When no examples were seen
    the means are NaN instead of raising a division-by-zero error.
  """
  mae_key = 'mean_absolute_error_masked_ignoreimp_mean'
  mse_key = 'mean_squared_error_masked_ignoreimp_mean'

  calculated_metrics = dict()
  for k in method_list:
    mae_masked = 0
    mse_masked = 0
    ex_count = 0
    for v in eval_metrics:
      mae_masked += v[k][mae_key][0]
      mse_masked += v[k][mse_key][0]
      ex_count += v[k][mae_key][1]

    if ex_count == 0:
      # Nothing was evaluated - report "undefined" rather than crash.
      mae_mean = mse_mean = float('nan')
    else:
      mae_mean = mae_masked / ex_count
      mse_mean = mse_masked / ex_count

    print(f'{k}')
    print('MAE Masked:', mae_mean)
    print('MSE Masked:', mse_mean)
    print('Ex Count:', ex_count)
    print()

    calculated_metrics[k] = (mae_mean, mse_mean, ex_count)

  return calculated_metrics


def dict_to_df(data, mask_strat):
  """Flattens a per-method metrics mapping into a DataFrame.

  Args:
    data: Mapping from method name to an indexable of metric values, where
      index 0 is the masked MAE and index 1 is the masked MSE. Any further
      elements are ignored.
    mask_strat: Label written to the `mask_strategy` column of every row.

  Returns:
    A `pd.DataFrame` with one row per method and the columns `mask_strategy`,
    `method`, `mean_absolute_error_masked_ignoreimp_mean` and
    `mean_squared_error_masked_ignoreimp_mean`.
  """
  records = [
      {
          'mask_strategy': mask_strat,
          'method': name,
          'mean_absolute_error_masked_ignoreimp_mean': values[0],
          'mean_squared_error_masked_ignoreimp_mean': values[1],
      }
      for name, values in data.items()
  ]
  return pd.DataFrame(records)


# Run Examples and Naive Baseline Eval

In [None]:
# @title Sample Config

# To set constants.
# These module-level values are read by `get_config` (defined below) at call
# time, so rebinding one of them in a later cell changes the next config built.
# NOTE: `get_config` also reads BATCH_SIZE, PH and PW, which are not defined in
# this cell; they are set by the cells that call `get_config`.

# 1) Dataset variables.
# List-valued settings (TRAIN_DATASET_NAME, LOSS_IGNORE_IMPUTATION,
# TRAIN_DATA_SIZES, USE_TRAIN_AUGMENTATIONS, LRS, WEIGHT_DECAYS): `get_config`
# only reads element [0] of each.
TRAIN_DATASET_NAME = [
    # 'lsm_v2_pretraining_sessions_-1_windowsize_1440_sensorfeatures_26_validonly_False_missingratio_0.2_timestamp_202504110407_doublethreshold_False',
    # NOTE(review): the active entry below has no trailing comma; if another
    # string entry is enabled after it, Python will silently concatenate the
    # two names into one string.
    'lsm_v2_pretraining_sessions_-1_windowsize_1440_sensorfeatures_26_validonly_False_missingratio_0.5_timestamp_202504110538_doublethreshold_False'
    # 'lsm_v2_pretraining_sessions_-1_windowsize_1440_sensorfeatures_26_validonly_False_missingratio_0.8_timestamp_202504110551_doublethreshold_False',
]
VALID_DATASET_NAME = 'lsm_v2_missing_balanced_20250301_valid_dataset'
LSM_PREDEFINED_CONFIGS = predefined_configs.LSM_PREDEFINED_CONFIGS
LOSS_IGNORE_IMPUTATION = [True]
CACHE_DATASET = True
TRAIN_DATA_SIZES = [1_000, 10_000, 100_000, 1_000_000, 1_601_088]
USE_DATETIME_FEATURES = False
USE_TRAIN_AUGMENTATIONS = [False]
TRAIN_AUGMENTATIONS = ['stretch', 'flip', 'noise']
SHUFFLE_SEED = 42
SHUFFLE_BUFFER_SIZE = 250_000

# 2) Training / eval variables.
NUM_TRAIN_STEPS = 100_000
LRS = [5e-3]
WEIGHT_DECAYS = [1e-4]

# 3) Logging variables.
# NOTE: `/` is true division, so the three LOG_* step values below are floats
# (e.g. 10000.0); only MAX_NUM_CHECKPOINTS is converted to int.
LOG_EVAL_SUMMARY_STEPS = NUM_TRAIN_STEPS / 10  # STEPS_PER_EPOCH
LOG_CHECKPOINT_STEPS = NUM_TRAIN_STEPS / 10  # LOG_EVAL_SUMMARY_STEPS * 5
LOG_TRAIN_SUMMARY_STEPS = NUM_TRAIN_STEPS / 100
MAX_NUM_CHECKPOINTS = int(NUM_TRAIN_STEPS / LOG_CHECKPOINT_STEPS)
ENABLE_DUMP_MODE = False

# Model variant
# Used by `get_config` as the key into the `model_constants` lookup tables
# (replaced by 'Deb' when `runlocal` is set).
VARIANT = 'S'

LOSS_ONLY_MASKED_TOKENS = True

# Downstream Tasks.

# Linear probe eval.
# NOTE(review): these two values are not read by `get_config` below.
LINEAR_PROBE_USE_TRAIN_AUGMENTATIONS = False
LINEAR_PROBE_TRAIN_AUGMENTATIONS = ['noise']

# THINGS THAT CHANGE
# Stored as `config.model.patcher_config` by `get_config`.
PATCHER_CONFIG = Patcher_Config(
    hidden_size=384,
    kernel_size=(10, 1),
    groups=1,
    mode='2d',
)

# Default masking setup, stored as `config.masker_config` by `get_config`.
# The example / eval cells below rebind MASKER_CONFIG before building a config.
MASKER_CONFIG = Masker_Config(
    maskstrategy_list=[
        MaskStrategy_Config(
            strategy='random',
            mask_probability=0.8,
            weight=1,
            mask_dim='time',
            inherited_depend=True,
        ),
    ],
    on_cpu=True,
    inherited=True,
)


def get_config(runlocal=''):
  """Returns the ViT experiment configuration.

  Builds a fresh `ml_collections.ConfigDict` from the module-level constants
  (dataset names, PATCHER_CONFIG, MASKER_CONFIG, BATCH_SIZE, PH, PW, ...).
  The globals are read when the function is called, so rebinding them before a
  call changes the returned config.

  Args:
    runlocal: Any value; only its truthiness is used. When truthy, small debug
      settings are selected: batch size 8, 100 training steps, shuffle buffer
      256, 64 eval samples, the 'Deb' model variant, and `count_flops=False`.
      The default `''` selects the full settings.

  Returns:
    The populated `ml_collections.ConfigDict`.
  """

  runlocal = bool(runlocal)

  # Experiment.
  config = ml_collections.ConfigDict()
  if runlocal:
    config.runlocal = True
  else:
    config.runlocal = False

  # Only the first entry of each list-valued constant is used.
  config.experiment_name = f'LSM V2-{TRAIN_DATASET_NAME[0]}'
  config.shuffle_seed = SHUFFLE_SEED
  config.loss_ignore_imputation = LOSS_IGNORE_IMPUTATION[0]

  # Dataset.
  config.data_dtype_str = 'float32'
  config.dataset_configs = ml_collections.ConfigDict()
  config.dataset_configs.dataset = TRAIN_DATASET_NAME[0]
  config.dataset_configs.valid_dataset = VALID_DATASET_NAME
  config.dataset_configs.num_classes = None
  config.dataset_configs.train_split = 'train'  # train data split
  config.dataset_configs.train_num_samples = TRAIN_DATA_SIZES[0]  # train sample
  config.dataset_configs.eval_split = 'valid'
  config.dataset_configs.eval_num_samples = 64 if runlocal else None
  config.dataset_configs.cache_dataset = CACHE_DATASET
  config.dataset_configs.prefetch_to_device = 2
  # Shuffle_buffer_size is per host, so small-ish is ok.
  config.dataset_configs.shuffle_buffer_size = (
      256 if runlocal else SHUFFLE_BUFFER_SIZE
  )
  config.enable_dump_mode = ENABLE_DUMP_MODE
  # Model.
  version = VARIANT

  # `version` is the lookup key into the `model_constants` tables below.
  version = 'Deb' if runlocal else version
######################## paste this for gen_eval !!!!! ########################
  config.model_name = 'lsm_vit_mae'
  config.model = ml_collections.ConfigDict()
  config.model.patcher_config = PATCHER_CONFIG
  config.model.num_heads = model_constants.NUM_HEADS[version]
  config.model.mlp_dim = model_constants.MLP_DIMS[version]
  config.model.num_layers = model_constants.NUM_LAYERS[version]
  config.model.dropout_rate = 0.0
  config.model.classifier = 'none'  # Has to be "none" for the autoencoder
  config.model.representation_size = None
  config.model.positional_embedding = 'sinusoidal_2d'
  config.model.positional_embedding_decoder = 'sinusoidal_2d'

  # Patch size comes from the PH / PW globals set by the calling cell.
  config.model.patches = ml_collections.ConfigDict()
  config.model.patches.size = (PH, PW)

  # decoder
  config.model.decoder_config = ml_collections.ConfigDict()
  config.model.decoder_config.hidden_size = (
      model_constants.DECODER_HIDDEN_SIZES[version]
  )
  config.model.decoder_config.mlp_dim = model_constants.DECODER_MLP_DIMS[
      version
  ]
  config.model.decoder_config.num_layers = model_constants.DECODER_NUM_LAYERS[
      version
  ]
  config.model.decoder_config.num_heads = model_constants.DECODER_NUM_HEADS[
      version
  ]
  config.model.decoder_config.dropout_rate = 0.0
  config.model.decoder_config.attention_dropout_rate = 0.0

  config.masked_feature_loss = ml_collections.ConfigDict()
  config.masked_feature_loss.targets_type = 'rgb'
  config.masked_feature_loss.loss_only_masked_tokens = LOSS_ONLY_MASKED_TOKENS
  config.masked_feature_loss.loss_type = 'squared'  # 'squared' or 'absolute'

  # Masking setup: whatever MASKER_CONFIG is bound to at call time.
  config.masker_config = MASKER_CONFIG
  # Datetime features.
  config.use_datetime_features = USE_DATETIME_FEATURES
######################## paste this for gen_eval !!!!! ########################

  # Training.
  config.trainer_name = 'lsm_mae_trainer'
  config.batch_size = 8 if runlocal else BATCH_SIZE
  config.num_training_steps = 100 if runlocal else NUM_TRAIN_STEPS
  config.log_eval_steps = LOG_EVAL_SUMMARY_STEPS
  config.log_summary_steps = LOG_TRAIN_SUMMARY_STEPS
  config.rng_seed = 42
  config.use_train_augmentations = USE_TRAIN_AUGMENTATIONS[0]
  config.train_augmentations = TRAIN_AUGMENTATIONS
  # Learning-rate schedule; `sched.re = '(.*)'` is a match-everything regex.
  sched = ml_collections.ConfigDict()
  sched.re = '(.*)'
  sched.lr_configs = ml_collections.ConfigDict()
  sched.lr_configs.learning_rate_schedule = 'compound'
  sched.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
  # NOTE: the schedule always uses NUM_TRAIN_STEPS, even when `runlocal`
  # reduces `config.num_training_steps` to 100.
  sched.lr_configs.total_steps = NUM_TRAIN_STEPS
  sched.lr_configs.steps_per_cycle = sched.lr_configs.total_steps
  sched.lr_configs.warmup_steps = int(NUM_TRAIN_STEPS * 0.05)
  sched.lr_configs.base_learning_rate = LRS[0]
  config.schedule = ml_collections.ConfigDict({'all': sched})

  # *Single* optimizer.
  optim = ml_collections.ConfigDict()
  optim.optax_name = 'scale_by_adam'
  # optim.optax = dict(mu_dtype='bfloat16')
  optim.optax_configs = ml_collections.ConfigDict({  # Optimizer settings.
      'b1': 0.9,
      'b2': 0.95,
  })
  # NOTE(review): this is set on `config`, not on `optim` as in the
  # commented-out line above - confirm which one the trainer reads.
  config.optax = dict(mu_dtype='bfloat16')
  optim.max_grad_norm = 1.0

  optim.weight_decay = WEIGHT_DECAYS[0]
  optim.weight_decay_decouple = True
  config.optimizer = optim

  # Downstream Tasks.
  # TODO(girishvn): These (0, 1) need to be adapted to LSM datasets
  # 0) Linear Probing.
  # 1) Fewshot.

  # 2) Reconstruction Eval Tasks (Forecast and Imputation).
  config.forecast = LSM_PREDEFINED_CONFIGS['eval_fore_1day']
  config.imputation = LSM_PREDEFINED_CONFIGS['eval_imp_1day']
  config.random_imputation = LSM_PREDEFINED_CONFIGS['eval_randimp_1day']

  # Logging.
  config.write_summary = True
  config.xprof = True  # Profile using xprof.
  config.checkpoint = True  # Do checkpointing.
  config.checkpoint_steps = LOG_CHECKPOINT_STEPS
  config.debug_train = False  # Debug mode during training.
  config.debug_eval = False  # Debug mode during eval.
  config.max_checkpoints_to_keep = MAX_NUM_CHECKPOINTS
  # BEGIN GOOGLE-INTERNAL
  if runlocal:
    # Current implementation fails with UPTC.
    config.count_flops = False
  # END GOOGLE-INTERNAL

  return config



In [None]:
# @title Plot Naive Baseline Examples


# Things to set: batch size and patch height / width. `get_config` reads these
# (and MASKER_CONFIG) from module-level globals.
BATCH_SIZE, PH, PW = 8, 10, 1

# 'bar' masking along the modality dimension with probability 2/26.
MASKER_CONFIG = Masker_Config(
    maskstrategy_list=[
        MaskStrategy_Config(
            strategy='bar',
            mask_probability=2/26,
            weight=1,
            mask_dim='modality',
            inherited_depend=False,
        ),
    ],
    on_cpu=True,
    inherited=True,
    strictmaskperc=0.0,
)

# Build the config and dataset, then pull one validation batch.
config = get_config(runlocal=False)
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = get_dataset.get_dataset(config, data_rng)

batch_x = next(dataset.valid_iter)
x = jnp.asarray(batch_x['input_signal'][0])
b, h, w, c = x.shape

dropout_rng, rng = jax.random.split(rng)
# Disabled alternative: derive the mask with
# `patch_and_mask_img(x, dropout_rng, config)` instead of reading the token
# mask that ships with the batch.

token_mask = batch_x['token_mask'][0]
pixel_mask = get_pixel_mask(token_mask, PH, PW, x.shape)
mask_info = dict(token_mask=token_mask, pixel_mask=pixel_mask)

logits_dict, aux = pandas_naive_baselines(x, mask_info, config)

# NOTE: Similar code is implemented in the function `pandas_naive_baselines`;
# it is repeated here so the intermediate arrays can be plotted.
# Replace every masked pixel with NaN so pandas treats it as missing.
x_nan_masked = x.at[jnp.where(pixel_mask == 1)].set(jnp.nan)

# (b, h, w, c) -> (h, b, w, c) -> (h, b * w): rows index the h axis.
# The reshape to (h, b * w) drops the channel axis, so it requires c == 1.
x_batch_masked = jnp.reshape(
    jnp.transpose(x_nan_masked, (1, 0, 2, 3)), (h, b * w)
)
x_df = pd.DataFrame(x_batch_masked)

# Each baseline fills NaNs down the rows (axis 0); anything still missing
# afterwards falls back to 0.
linear_interp_df = x_df.interpolate(
    method='linear', limit_direction='both', axis=0
).fillna(0)
nn_interp_df = (
    x_df.interpolate(method='nearest', limit_direction='both', axis=0)
    .bfill()
    .ffill()
    .fillna(0)
)
mean_fill_interp_df = x_df.fillna(x_df.mean()).fillna(0)


def _restore_batch_layout(frame):
  """Converts an (h, b * w) DataFrame back to a (b, h, w, c) jnp array."""
  grid = jnp.reshape(jnp.asarray(frame.to_numpy()), (h, b, w, c))
  return jnp.transpose(grid, (1, 0, 2, 3))


linear_interp = _restore_batch_layout(linear_interp_df)
nn_interp = _restore_batch_layout(nn_interp_df)
mean_fill_interp = _restore_batch_layout(mean_fill_interp_df)

# (title, batched image array, whether to pin the color range to the input's).
panels = (
    ('Input', x, True),
    ('Mask', pixel_mask, False),
    ('Masked Input', x_nan_masked, True),
    ('Linear Interpolation', linear_interp, True),
    ('Nearest Interpolation', nn_interp, True),
    ('Mean Fill Interpolation', mean_fill_interp, True),
)

# Plot every example of the batch, one figure per panel.
for idx in range(b):
  vmin = jnp.min(x[idx])
  vmax = jnp.max(x[idx])
  for title, images, use_input_range in panels:
    plt.figure(figsize=(20, 20))
    color_range = {'vmin': vmin, 'vmax': vmax} if use_input_range else {}
    # Swap the first two axes of the example before display.
    plt.imshow(jnp.transpose(images[idx], (1, 0, 2)), **color_range)
    plt.title(title)

  plt.show()
  print('\n\n\n')


print('BATCH KEYS\n', batch_x.keys())


### Run Single Eval

In [None]:
# TO SET: batch size, patch height / width and the masking setup. All four are
# read by `get_config` through module-level globals.
BATCH_SIZE, PH, PW = 8, 10, 1

MASKER_CONFIG = Masker_Config(
    maskstrategy_list=[
        MaskStrategy_Config(
            strategy='random', mask_probability=0.8, inherited_depend=True
        ),
    ],
    on_cpu=True,
    inherited=True,
)

# Start pipeline: config -> seeded RNGs -> dataset.
config = get_config(runlocal=False)
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = get_dataset.get_dataset(config, data_rng)

# Run the naive baselines over the validation data.
eval_metrics = naive_evaluate(rng=rng, config=config, dataset=dataset)

# Aggregate and print the metrics.
calc_metrics(eval_metrics)

### Run Eval Sweep

In [None]:
# Constants read by `get_config` through module-level globals.
BATCH_SIZE, PH, PW = 8, 10, 1

# --------------------------------------
# Sweep definition. Each entry is '<task>_<value>'; the trailing value is
# parsed as a float and the task name selects the masking setup.
# Other sweeps supported by the branches below:
# --------------------------------------

# RANDOM_SWEEP = ['random_0.3', 'random_0.5', 'random_0.8']
# IMPUTATION_SWEEP = ['imp_0.00695', 'imp_0.02084', 'imp_0.04167', 'imp_0.125']
# FORECAST_SWEEP = ['forecast_0.00695', 'forecast_0.02084', 'forecast_0.04167', 'forecast_0.125']
# MASK_SWEEP = RANDOM_SWEEP + IMPUTATION_SWEEP + FORECAST_SWEEP

MASK_SWEEP = ['modality_2', 'modality_6', 'modality_12']

TOTAL_MODALITIES = 26

# Sweep entries with a dedicated validation dataset; every other entry uses
# the default one.
_VALID_DATASET_BY_TASK = {
    'random_0.3': 'lsm_v2_missing_balanced_20250502_valid_dataset_bounded_30p',
    'random_0.5': 'lsm_v2_missing_balanced_20250301_valid_dataset_bounded_50p',
}
_DEFAULT_VALID_DATASET = 'lsm_v2_missing_balanced_20250301_valid_dataset'


def _masker_config_for(task, prob):
  """Returns the Masker_Config for one sweep entry.

  The task family is matched by substring, in the order
  'random', 'imp', 'forecast', 'modality'.

  Args:
    task: Sweep entry name, e.g. 'modality_2'.
    prob: Float parsed from the suffix of `task`.

  Raises:
    ValueError: If `task` matches none of the known families.
  """
  if 'random' in task:
    return Masker_Config(
        maskstrategy_list=[
            MaskStrategy_Config(
                strategy='random',
                mask_probability=prob,
                weight=1,
                mask_dim='time',
                inherited_depend=True,
            ),
        ],
        on_cpu=True,
        inherited=True,
    )

  if 'imp' in task:
    return Masker_Config(
        maskstrategy_list=[
            MaskStrategy_Config(
                strategy='bar',
                mask_probability=prob,
                mask_dim='time',
                mask_dim_contiguous=True,
                inherited_depend=False,
            )
        ],
        on_cpu=True,
        inherited=True,
        strictmaskperc=0.0,
    )

  if 'forecast' in task:
    return Masker_Config(
        maskstrategy_list=[
            MaskStrategy_Config(
                strategy='bar',
                mask_probability=prob,
                mask_dim='time',
                mask_dim_contiguous=True,
                mask_dim_forecasting=True,
                inherited_depend=False,
            )
        ],
        on_cpu=True,
        inherited=True,
        strictmaskperc=0.0,
    )

  if 'modality' in task:
    # The suffix is divided by TOTAL_MODALITIES to obtain the probability.
    return Masker_Config(
        maskstrategy_list=[
            MaskStrategy_Config(
                strategy='bar',
                mask_probability=prob / TOTAL_MODALITIES,
                weight=1,
                mask_dim='modality',
                inherited_depend=False,
            ),
        ],
        on_cpu=True,
        inherited=True,
        strictmaskperc=0.0,
    )

  raise ValueError(f'Unknown masking strategy: {task}')


# --------------------------------------
# RUN SWEEP PIPELINE.
# NOTE. This may take hours / days depending on the sweep size.
# Make sure to re-up AoD grant every 20 hours.
# --------------------------------------
count = 0
metrics_list = []
df_list = []
print('Running naive baseline sweep...')
for t in MASK_SWEEP:

  prob = float(t.split('_')[-1])

  # Rebind the globals that `get_config` reads; this must happen at cell
  # level (not inside a helper) so the module-level names are updated.
  VALID_DATASET_NAME = _VALID_DATASET_BY_TASK.get(t, _DEFAULT_VALID_DATASET)
  MASKER_CONFIG = _masker_config_for(t, prob)

  # Start pipeline: config -> seeded RNGs -> dataset.
  config = get_config(runlocal=False)
  data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
  dataset = get_dataset.get_dataset(config, data_rng)

  count += 1
  print('\nIteration:', count)
  print('Mask Strat:', t)

  # Run eval, then aggregate and print the metrics.
  eval_metrics = naive_evaluate(rng=rng, config=config, dataset=dataset)
  metrics = calc_metrics(eval_metrics)

  # Keep the results both as a DataFrame and as the raw dict; the dict is
  # tagged with the sweep entry after the DataFrame has been built.
  df_list.append(dict_to_df(metrics, t))
  metrics['masking_strategy'] = t
  metrics_list.append(metrics)

  print('Metrics:\n', metrics)
  print('\n')

print('\nDone.')

In [None]:
# Flatten `metrics_list` (one dict per sweep entry) into a table with one row
# per (sweep entry, naive baseline).
naive_baselines = ['mean_fill', 'nn', 'linear']

# Every key of a sweep entry that is not a baseline result (e.g.
# 'masking_strategy') is copied onto each of that entry's rows.
cols = [key for key in metrics_list[0] if key not in naive_baselines]
result_cols = [
    'naive_baseline',
    'min_valid_mean_absolute_error_masked',
    'min_valid_mean_squared_error_masked',
    'example_count',
]

# Collect plain dicts and build the DataFrame once: calling `pd.concat` inside
# the loop copies the whole frame on every iteration and, starting from an
# empty frame, triggers the pandas empty-entry concat deprecation warning.
rows = []
for m in metrics_list:

  # Each baseline entry is a (mae, mse, example_count) tuple.
  for nb in naive_baselines:
    row = {
        'naive_baseline': nb,
        'min_valid_mean_absolute_error_masked': float(m[nb][0]),
        'min_valid_mean_squared_error_masked': float(m[nb][1]),
        'example_count': int(m[nb][2]),
    }
    row.update((col, m[col]) for col in cols)
    rows.append(row)

# Shared columns first, then the per-baseline result columns.
df = pd.DataFrame(rows, columns=cols + result_cols)

pd.set_option('display.max_rows', None)
df