# Exploring Reproducible Randomness
##### Colab Kernel (Electrodes)
##### Dataset (Electrodes)

Grants command for Access on Demand (AoD):

https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-electrodes-deid-colab-jobs&reason=b%2F314799341


# Setup

In [None]:
# @title Imports

import io
import functools
from typing import Any, Callable, Dict, Iterator, Tuple, Optional, Type, Union
import time
from collections import Counter

from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform

import flax
from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.profiler

import pandas as pd
import ml_collections
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline

from colabtools import adhoc_import
with adhoc_import.Google3():
  from scenic.dataset_lib import dataset_utils
  from scenic.google.xm import xm_utils
  from scenic.model_lib.base_models import base_model
  from scenic.model_lib.base_models import model_utils
  from scenic.model_lib.layers import nn_ops
  from scenic.model_lib.layers import nn_layers
  from scenic.projects.baselines import vit
  from scenic.train_lib import optax as scenic_optax
  from scenic.train_lib import pretrain_utils
  from scenic.train_lib import train_utils

  from scenic.projects.multimask.models import model_utils as mm_model_utils

  from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
  from google3.experimental.largesensormodels.scenic.datasets import lsm_tiny_dataset
  from google3.experimental.largesensormodels.scenic.models import lsm_vit as lsm_vit_mae
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_utils as lsm_model_utils
  from google3.experimental.largesensormodels.scenic.trainers import lsm_mae_trainer

  from google3.pyglib import gfile


# Type aliases shared by the training / evaluation helpers.

# Mapping from field name to array.
Batch = Dict[str, jnp.ndarray]

# (array, array, dict of arrays) -> {str: (float, int)}.
# NOTE(review): argument roles are not visible here (presumably
# logits, targets, batch/weights) — confirm at call sites.
MetricFn = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]],
                    Dict[str, Tuple[float, int]]]

# (array, Batch, optional array, array) -> float.
LossFn = Callable[[jnp.ndarray, Batch, Optional[jnp.ndarray], jnp.ndarray],
                  float]

# Named functions mapping an array (presumably the step) to an array.
LrFns = Dict[str, Callable[[jnp.ndarray], jnp.ndarray]]

# A 2-D or 3-D patch size.
Patch = Union[Tuple[int, int], Tuple[int, int, int]]


In [None]:
# @title Sample Model Config

r"""A config to train a Tiny ViT MAE on LSM 5M dataset.

Forked from google3/third_party/py/scenic/projects/multimask/configs/mae_cifar10_tiny.py

To run on XManager:
gxm third_party/py/scenic/google/xm/launch_xm.py -- \
--binary //experimental/largesensormodels/scenic:main \
--config=experimental/largesensormodels/scenic/configs/mae_lsm_tiny.py \
--platform=vlp_4x4 \
--exp_name=lsm_mae_tier2_TinyShallow_10_5_res \
--workdir=/cns/dz-d/home/xliucs/lsm/xm/\{xid\} \
--xm_resource_alloc=group:mobile-dynamic/h2o-ai-gqm-quota \
--priority=200

To run locally:
./third_party/py/scenic/google/runlocal.sh \
--uptc="" \
--binary=//experimental/largesensormodels/scenic:main \
--config=$(pwd)/experimental/largesensormodels/scenic/configs/mae_lsm_tiny.py:runlocal
"""


# To set constants.
DATASET_NAME = 'lsm_300min_10M_impute'
CACHE_DATASET = False
TRAIN_DATA_SIZE = 1000000  # 1M train samples (capped to the dataset size below)
BATCH_SIZE = 8
NUMBER_OF_EPOCH = 100
REPEAT_DATA = False

# Model variant / patch H (time steps) / patch W (features)
VARIANT = 'TiShallow/10/5'
LRS = [1e-3]
TOKEN_MASK_PROB = 'constant_0.8'
LOSS_ONLY_MASKED_TOKENS = True
USE_DATETIME_FEATURES = True
USE_TRAIN_AUGMENTATIONS = True
TRAIN_AUGMENTATIONS = ['stretch', 'flip', 'noise']
OHE_LABELS = True

# Derived constants.
# Never request more training examples than the dataset provides.
TRAIN_DATA_SIZE = min(
    TRAIN_DATA_SIZE,
    dataset_constants.lsm_dataset_constants[DATASET_NAME]['num_train_examples']
)

# Integer division; at least one step per epoch even for tiny datasets.
STEPS_PER_EPOCH = max(1, int(TRAIN_DATA_SIZE // BATCH_SIZE))
NUM_TRAIN_STEPS = int(NUMBER_OF_EPOCH * STEPS_PER_EPOCH)

# Evaluate/log once per epoch, checkpoint every 5 epochs.
LOG_EVAL_SUMMARY_STEPS = STEPS_PER_EPOCH
LOG_CHECKPOINT_STEPS = LOG_EVAL_SUMMARY_STEPS * 5
# Keep every checkpoint written during training. Clamped to >= 1 so short runs
# (fewer than 5 epochs) do not end up with max_checkpoints_to_keep == 0.
MAX_NUM_CHECKPOINTS = max(1, int(NUM_TRAIN_STEPS // LOG_CHECKPOINT_STEPS))


def get_config_common_few_shot(
    batch_size: Optional[int] = None,
    target_resolution: int = 224,
    resize_resolution: int = 256,
) -> ml_collections.ConfigDict:
  """Builds a standard-ish few-shot evaluation configuration.

  Copied from
  third_party/py/scenic/projects/baselines/configs/google/common/common_fewshot.py

  Args:
    batch_size: Batch size used for few-shot evaluation (None leaves it unset).
    target_resolution: Side length of the central crop.
    resize_resolution: Side length images are resized to before cropping.

  Returns:
    A ConfigDict describing the few-shot datasets, preprocessing strings,
    shot counts and L2 regularisation sweep.
  """
  # Train and eval share one preprocessing pipeline string.
  preprocessing = (
      f'decode|resize({resize_resolution})'
      f'|central_crop({target_resolution})|value_range(-1,1)'
  )
  # name -> (tfds dataset, train split, eval split)
  fewshot_datasets = {
      'birds': ('caltech_birds2011', 'train', 'test'),
      'caltech': ('caltech101', 'train', 'test'),
      'cars': ('cars196:2.1.0', 'train', 'test'),
      'cifar100': ('cifar100', 'train', 'test'),
      'col_hist': ('colorectal_histology', 'train[:2000]', 'train[2000:]'),
      'dtd': ('dtd', 'train', 'test'),
      'imagenet': ('imagenet2012_subset/10pct', 'train', 'validation'),
      'pets': ('oxford_iiit_pet', 'train', 'test'),
      'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
  }

  fewshot = ml_collections.ConfigDict()
  fewshot.batch_size = batch_size
  fewshot.representation_layer = 'pre_logits'
  fewshot.log_eval_steps = 25_000
  fewshot.datasets = fewshot_datasets
  fewshot.pp_train = preprocessing
  fewshot.pp_eval = preprocessing
  fewshot.shots = [1, 5, 10, 25]
  fewshot.l2_regs = [2.0**exponent for exponent in range(-10, 20)]
  fewshot.walk_first = ('imagenet', 10)
  return fewshot


def get_config(runlocal=''):
  """Returns the ViT MAE experiment configuration.

  Most values are taken from the module-level constants of this cell
  (DATASET_NAME, VARIANT, BATCH_SIZE, NUM_TRAIN_STEPS, LRS, ...).

  Args:
    runlocal: If truthy, build a small local-debug configuration: the 'Deb'
      model variant, batch size 8, a 'test[:64]' eval split, and FLOP counting
      disabled.

  Returns:
    An ml_collections.ConfigDict with the experiment configuration.

  Raises:
    ValueError: If VARIANT does not have 2 or 3 '/'-separated parts.
  """

  runlocal = bool(runlocal)

  # Experiment.
  config = ml_collections.ConfigDict()
  config.experiment_name = f'electrodes-mae-{DATASET_NAME}-{TRAIN_DATA_SIZE}'
  config.dataset_name = f'lsm_prod/{DATASET_NAME}'

  config.shuffle_seed = 42

  # Dataset.
  config.data_dtype_str = 'float32'
  config.dataset_configs = ml_collections.ConfigDict()
  config.dataset_configs.dataset = f'lsm_prod/{DATASET_NAME}'
  # config.dataset_configs.num_classes = NUM_CLASSES
  config.dataset_configs.train_split = 'train'  # train data split
  config.dataset_configs.train_num_samples = TRAIN_DATA_SIZE  # train sample
  # eval data split - note: this split is used for validation and test.
  config.dataset_configs.eval_split = 'test[:64]' if runlocal else 'test'
  config.dataset_configs.cache_dataset = CACHE_DATASET
  config.dataset_configs.prefetch_to_device = 2
  # Shuffle_buffer_size is per host, so small-ish is ok.
  config.dataset_configs.shuffle_buffer_size = 250_000
  config.dataset_configs.repeat_data = REPEAT_DATA
  config.dataset_configs.ohe_labels = OHE_LABELS

  # Model.
  # VARIANT is 'Version/PatchH/PatchW', or 'Version/Patch' for square patches,
  # where H is the time axis and W is the feature axis.
  variant_parts = VARIANT.split('/')
  if len(variant_parts) == 3:
    version, patch_h, patch_w = variant_parts
  elif len(variant_parts) == 2:
    version, patch_h = variant_parts
    patch_w = patch_h
  else:
    raise ValueError(f'Invalid model variant: {VARIANT}')

  version = 'Deb' if runlocal else version
  config.model_name = 'lsm_vit_mae'
  config.model = ml_collections.ConfigDict()
  # encoder
  config.model.hidden_size = model_constants.HIDDEN_SIZES[version]
  config.model.patches = ml_collections.ConfigDict()
  config.model.patches.size = tuple([int(patch_h), int(patch_w)])
  config.model.num_heads = model_constants.NUM_HEADS[version]
  config.model.mlp_dim = model_constants.MLP_DIMS[version]
  config.model.num_layers = model_constants.NUM_LAYERS[version]
  config.model.dropout_rate = 0.0
  config.model.classifier = 'none'  # Has to be "none" for the autoencoder
  config.model.representation_size = None
  config.model.positional_embedding = 'sinusoidal_2d'
  config.model.positional_embedding_decoder = 'sinusoidal_2d'
  # decoder
  config.model.decoder_config = ml_collections.ConfigDict()
  config.model.decoder_config.hidden_size = (
      model_constants.DECODER_HIDDEN_SIZES[version]
  )
  config.model.decoder_config.mlp_dim = model_constants.DECODER_MLP_DIMS[
      version
  ]
  config.model.decoder_config.num_layers = model_constants.DECODER_NUM_LAYERS[
      version
  ]
  config.model.decoder_config.num_heads = model_constants.DECODER_NUM_HEADS[
      version
  ]
  config.model.decoder_config.dropout_rate = 0.0
  config.model.decoder_config.attention_dropout_rate = 0.0

  config.masked_feature_loss = ml_collections.ConfigDict()
  config.masked_feature_loss.targets_type = 'rgb'
  config.masked_feature_loss.token_mask_probability = TOKEN_MASK_PROB
  config.masked_feature_loss.loss_only_masked_tokens = LOSS_ONLY_MASKED_TOKENS
  config.masked_feature_loss.loss_type = 'squared'  # 'squared' or 'absolute'

  # Datetime features.
  config.use_datetime_features = USE_DATETIME_FEATURES

  # Training.
  config.trainer_name = 'lsm_mae_trainer'
  config.batch_size = 8 if runlocal else BATCH_SIZE
  config.num_training_steps = NUM_TRAIN_STEPS
  config.log_eval_steps = LOG_EVAL_SUMMARY_STEPS
  config.log_summary_steps = LOG_EVAL_SUMMARY_STEPS
  config.rng_seed = 42
  config.use_train_augmentations = USE_TRAIN_AUGMENTATIONS
  config.train_augmentations = TRAIN_AUGMENTATIONS
  sched = ml_collections.ConfigDict()
  sched.re = '(.*)'
  sched.lr_configs = ml_collections.ConfigDict()
  sched.lr_configs.learning_rate_schedule = 'compound'
  sched.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
  sched.lr_configs.total_steps = NUM_TRAIN_STEPS
  sched.lr_configs.steps_per_cycle = sched.lr_configs.total_steps
  # Warm up for 5% of all training steps (note: this is a float).
  sched.lr_configs.warmup_steps = STEPS_PER_EPOCH * NUMBER_OF_EPOCH * 0.05
  sched.lr_configs.base_learning_rate = LRS[0]
  config.schedule = ml_collections.ConfigDict({'all': sched})

  # *Single* optimizer.
  optim = ml_collections.ConfigDict()
  optim.optax_name = 'scale_by_adam'
  # optim.optax = dict(mu_dtype='bfloat16')
  optim.optax_configs = ml_collections.ConfigDict({  # Optimizer settings.
      'b1': 0.9,
      'b2': 0.95,
  })
  # NOTE(review): the commented line above sets `optim.optax`, but this sets
  # the top-level `config.optax`; confirm which one the optimizer builder
  # actually reads for mu_dtype.
  config.optax = dict(mu_dtype='bfloat16')
  optim.max_grad_norm = 1.0

  optim.weight_decay = 1e-4
  optim.weight_decay_decouple = True
  config.optimizer = optim

  # Fewshot.
  # TODO(girishvn): This needs to be adapted to electrode dataset
  config.fewshot = get_config_common_few_shot(batch_size=config.batch_size)
  config.fewshot.datasets = {}
  config.fewshot.walk_first = ()
  config.fewshot.representation_layer = 'pre_logits'
  config.fewshot.log_eval_steps = LOG_EVAL_SUMMARY_STEPS

  # Logging.
  config.write_summary = True
  config.xprof = True  # Profile using xprof.
  config.checkpoint = True  # Do checkpointing.
  config.checkpoint_steps = LOG_CHECKPOINT_STEPS
  config.debug_train = False  # Debug mode during training.
  config.debug_eval = False  # Debug mode during eval.
  config.max_checkpoints_to_keep = MAX_NUM_CHECKPOINTS
  # BEGIN GOOGLE-INTERNAL
  if runlocal:
    # Current implementation fails with UPTC.
    config.count_flops = False
  # END GOOGLE-INTERNAL

  return config


# BEGIN GOOGLE-INTERNAL
def get_hyper(hyper):
  """Returns the grid-search sweep: one axis over the base learning rates in LRS."""
  learning_rate_axis = hyper.sweep(
      'config.schedule.all.lr_configs.base_learning_rate', LRS
  )
  return hyper.product([learning_rate_axis])


In [None]:
# @title Data Dir

# Root directory of the TFDS-formatted Electrodes datasets.
data_dir = (
    '/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes'
    '/deid/exp/dmcduff/ttl=6w/msa_1_5/lsm_tfds_datasets'
)

In [None]:
# @title Dataset

"""Electrodes dataset preprocessor and loader.

Adapted from a combination of the following files:
google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py
google3/third_party/py/scenic/dataset_lib/dataset_utils.py
"""

import functools
from typing import Any, Optional

from absl import logging
import jax.numpy as jnp
import jax.profiler
import ml_collections  # pylint: disable=unused-import
from scenic.dataset_lib import dataset_utils
import tensorflow as tf
import tensorflow_datasets as tfds

from google3.experimental.largesensormodels.scenic.datasets import dataset_constants


def get_height_crop_width_pad(
    feat_shape: tuple[int, int, int], patch_size: tuple[int, int]
) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int, int]]:
  """Gets H crop and W pad values that let an image be patched evenly.

  NOTE: This assumes that the image is of the shape [H, W, C],
  where H is the time axis, and W is the feature axis.

  Args:
    feat_shape: tuple; Shape of the image (H, W, C).
    patch_size: tuple; Size of the patches to extract from the image (H, W).

  Returns:
    A 3-tuple `(crop_h, pad_w, feat_shape_new)`:
      crop_h: tuple (int, int); Rows to crop from the (top, bottom) of the H
        axis. Only the top is cropped, so the second entry is always 0.
      pad_w: tuple (int, int); Columns to pad on the (left, right) of the W
        axis. An odd total puts the extra column on the right.
      feat_shape_new: tuple (int, int, int); Shape (H, W, C) after cropping
        and padding; H and W are multiples of the patch height and width.
  """
  height, width, channels = feat_shape
  patch_h, patch_w = patch_size

  # Crop H (time): drop the remainder rows that do not fill a whole patch.
  crop_h = height % patch_h

  # Pad W (features): smallest padding that makes W a multiple of patch_w,
  # split as evenly as possible between left and right.
  pad_total = -width % patch_w
  pad_left = pad_total // 2
  pad_right = pad_total - pad_left

  feat_shape_new = (height - crop_h, pad_left + width + pad_right, channels)
  return (crop_h, 0), (pad_left, pad_right), feat_shape_new


def patch_compatible_resize_example(
    example: dict[str, Any],
    patch_size: tuple[int, int],
) -> dict[str, Any]:
  """Crops and pads features to allow for an integer number of patches.

  NOTE: This assumes that the image is in the shape [H, W, C], where H is the
  Time axis which can be cropped, and W is the feature axis which can be padded.

  NOTE: This should be applied AFTER augmentations so that noise is not
    applied to zero-padding.

  Args:
    example: A dictionary of inputs containing at least the 'input_signal',
      'labels', and possibly 'datetime_signal' fields. It is updated in place.
    patch_size: tuple; Size of the patches to extract from the image (H, W).

  Returns:
    The same dictionary, with 'input_signal' (and 'datetime_signal' when it is
    present and not None) cropped from the top along H and zero-padded along W.
  """

  def _crop_and_pad(tensor):
    # Crop leading time rows and zero-pad the feature axis so both H and W
    # are multiples of the patch size (uses the tensor's static shape).
    crop_h, pad_w, _ = get_height_crop_width_pad(tensor.shape, patch_size)
    tensor = tensor[crop_h[0] :, :, :]
    return tf.pad(
        tensor,
        paddings=[[0, 0], pad_w, [0, 0]],
        mode='CONSTANT',
        constant_values=0,
    )

  example['input_signal'] = _crop_and_pad(example['input_signal'])

  # Datetime features are optional; leave the entry untouched when absent/None.
  time_features = example.get('datetime_signal')
  if time_features is not None:
    example['datetime_signal'] = _crop_and_pad(time_features)

  return example


def preprocess_example(example, dataset_name, dtype=tf.float32):
  """Preprocesses a raw Electrodes example.

  Adapted from google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py

  The raw 'input_signal' holds sensor features and datetime features side by
  side along axis 1. They are separated using the 'datetime_features' ->
  'indices' entry of the dataset's constants.

  Args:
    example: dict; Raw example with 'input_signal', 'label', and a 'metadata'
      dict containing 'exercise_log', 'mood_log', and 'log_value'.
    dataset_name: str; Name of the dataset, optionally prefixed (e.g.
      'lsm_prod/<name>'); only the last '/'-separated part is used to look up
      dataset_constants.lsm_dataset_constants.
    dtype: Tensorflow data type; Data type 'input_signal' is cast to.

  Returns:
    A dict whose fields depend on the dataset name:
      * 'activities' / 'mood': 'input_signal', 'datetime_signal', a one-hot
        'label' over the dataset's 'log_values', 'exercise_log', 'mood_log',
        'log_value'.
      * 'stress': 'input_signal', 'datetime_signal', a 2-class one-hot 'label'.
      * otherwise (pretraining): 'input_signal', 'datetime_signal',
        'stress_label', 'exercise_log', 'mood_log', 'log_value'.

  Raises:
    ValueError: If the dataset's constants define no 'datetime_features'.

  NOTE: This assumes that the image is in the shape [H, W, C],
    where H is the Time axis, and W is the feature axis.
  """
  dataset_name = dataset_name.split('/')[-1]
  dataset_info = dataset_constants.lsm_dataset_constants[dataset_name]
  features = tf.cast(example['input_signal'], dtype=dtype)

  datetime_config = dataset_info.get('datetime_features', None)
  if datetime_config is None:
    raise ValueError(
        f'Dataset {dataset_name} defines no datetime_features in '
        'dataset_constants.lsm_dataset_constants.'
    )

  # Split the input into sensor features and datetime features (axis 1).
  # Sensor indices are built in ascending order explicitly (instead of via
  # set iteration, whose order is an implementation detail) so the channel
  # order is deterministic.
  time_feature_indices = list(datetime_config['indices'])
  time_index_set = set(time_feature_indices)
  feature_indices = [
      idx for idx in range(features.shape[1]) if idx not in time_index_set
  ]
  time_features = tf.gather(
      features, tf.convert_to_tensor(time_feature_indices), axis=1
  )
  features = tf.gather(features, tf.convert_to_tensor(feature_indices), axis=1)

  # Stress / Mood / Activity Labels:
  # A) Binary label of stress (0/1).
  stress_label = tf.cast(example['label'], dtype=tf.int32)
  # B) Boolean logs (True/False) of a logged exercise or mood event
  # (exercise and mood events are mutually exclusive).
  exercise_log = example['metadata']['exercise_log']
  mood_log = example['metadata']['mood_log']
  # C) The log value (int64 log code) for an exercise or mood event.
  log_value = tf.cast(example['metadata']['log_value'], tf.int32)

  # A) Activities or mood dataset: the log value is mapped to a class index
  # in [0, n_classes) and one-hot encoded as the label.
  if 'activities' in dataset_name or 'mood' in dataset_name:
    # Offset applied to log_value - an artifact of dataset creation.
    log_value_offset = tf.cast(dataset_info['log_value_offset'], tf.int32)
    # Possible labels (log_values) for this dataset, shifted by the offset.
    log_value_label_list = (
        tf.convert_to_tensor(dataset_info['log_values']) - log_value_offset
    )
    n_classes = len(dataset_info['log_values'])
    # Label map: shifted log value -> class index (-1 if unknown).
    lookup_initializer = tf.lookup.KeyValueTensorInitializer(
        keys=log_value_label_list, values=tf.range(n_classes)
    )
    label_map = tf.lookup.StaticHashTable(lookup_initializer, default_value=-1)
    label_idx = label_map.lookup(log_value)
    return {
        'input_signal': features,
        'datetime_signal': time_features,
        'label': tf.one_hot(label_idx, n_classes),
        'exercise_log': exercise_log,
        'mood_log': mood_log,
        'log_value': log_value,
    }

  # B) Stress dataset: the (already int32) stress label is one-hot encoded.
  if 'stress' in dataset_name:
    return {
        'input_signal': features,
        'datetime_signal': time_features,
        'label': tf.one_hot(stress_label, 2),
    }

  # C) Pretraining datasets: all raw labels are passed through.
  return {
      'input_signal': features,
      'datetime_signal': time_features,
      'stress_label': stress_label,
      'exercise_log': exercise_log,
      'mood_log': mood_log,
      'log_value': log_value,
  }


def augment_example(example, augmentations, seed=0):
  """Applies augmentations (stretch, flip, noise) to the features.

  Each augmentation named in `augmentations` is applied with probability 0.5
  (a uniform draw >= 0.5), using tf.random ops seeded with `seed` plus a fixed
  per-op offset (seed, seed+1, seed+3, seed+4, seed+5, seed+6).

  Args:
    example: dict; Must contain 'input_signal', a rank-3 tensor whose static
      shape is unpacked as (height, width, channels). Updated in place.
    augmentations: Collection of augmentation names; 'stretch', 'flip' and
      'noise' are recognized, other names are ignored.
    seed: Base op-level seed for the tf.random ops.

  Returns:
    `example` with 'input_signal' replaced by the augmented features, plus
    debugging fields 'aug_apply_noise', 'aug_noise_std', 'aug_apply_flip',
    'aug_apply_stretch', 'aug_stretch'. The 'aug_apply_*' fields hold the
    uniform draw, or -1.0 when that augmentation is not requested; the
    'aug_noise_std' / 'aug_stretch' fields stay -1.0 unless the augmentation
    was actually applied.

  NOTE(review): these are stateful tf.random ops with only op-level seeds.
  Their output also depends on the global seed and on how many times each op
  has run, so when this function is mapped with num_parallel_calls > 1 the
  assignment of random draws to examples may differ between runs. For
  per-example reproducibility consider tf.random.stateless_* ops with
  per-example seeds — TODO confirm.
  """

  augmented_feat = example['input_signal']
  height, width, _ = augmented_feat.shape

  # Sentinel defaults so every debug field written at the end is defined even
  # when an augmentation is skipped.
  # NOTE(review): under AutoGraph (tf.data map), variables assigned inside a
  # tensor-dependent `if` presumably must also exist before it — confirm
  # before removing these.
  # TODO REMOVE THIS
  apply_noise = -1.0
  noise_std = -1.0
  apply_flip = -1.0
  apply_stretch = -1.0
  stretch = -1.0
  # TODO REMOVE THIS



  # Stretch (along time/height axis).
  if 'stretch' in augmentations:
    apply_stretch = tf.random.uniform([], minval=0, maxval=1, seed=seed)
    # apply_stretch is a tensor: this `if` is presumably converted to a
    # tf.cond by AutoGraph inside the tf.data map — confirm.
    if apply_stretch >= 0.5:
      # Stretch factor in [1.0, 1.5), so stretched_height >= height.
      stretch = tf.random.uniform([], minval=1.0, maxval=1.5, seed=seed+1)
      stretched_height = int(height * stretch)
      augmented_feat = tf.image.resize(
          augmented_feat, size=[int(stretched_height), int(width)]
      )
      # Keep the last `height` rows of the stretched signal.
      offset_height = stretched_height - height
      augmented_feat = tf.image.crop_to_bounding_box(
          image=augmented_feat,
          target_height=height,
          target_width=width,
          offset_height=offset_height,
          offset_width=0,
      )

      # TODO(girishvn): apply translate?
      # Redundant with the crop above (already `height` rows); kept as is.
      augmented_feat = augmented_feat[
          -1 * height :, :, :
      ]  # crop to original size

  # Flip (along time/height axis).
  if 'flip' in augmentations:
    apply_flip = tf.random.uniform([], minval=0, maxval=1, seed=seed+3)
    if apply_flip >= 0.5:
      augmented_feat = tf.image.flip_up_down(augmented_feat)

  # Noise (gaussian).
  if 'noise' in augmentations:
    apply_noise = tf.random.uniform([], minval=0, maxval=1, seed=seed+4)
    if apply_noise >= 0.5:
      # Noise standard deviation drawn from [0.0, 0.5).
      noise_std = tf.random.uniform([], minval=0.0, maxval=0.5, seed=seed+5)
      noise = tf.random.normal(
          shape=tf.shape(augmented_feat),
          mean=0.0,
          stddev=noise_std,
          seed=seed+6
      )
      augmented_feat += noise

  # Debug outputs used to inspect which random draws each example received.
  #TODO REMOVE
  example['aug_apply_noise'] = apply_noise
  example['aug_noise_std'] = noise_std
  example['aug_apply_flip'] = apply_flip
  example['aug_apply_stretch'] = apply_stretch
  example['aug_stretch'] = stretch
  #TODO REMOVE

  example['input_signal'] = augmented_feat
  return example


def update_metadata(metadata, dataset_name, patch_size):
  """Returns metadata updates reflecting datetime splitting and patch padding.

  NOTE: This assumes that the original 'input_signal' field has sensor
  features concatenated with datetime features along the feature (W) axis.

  Args:
    metadata: dict; Must contain 'input_shape' of the raw 'input_signal'
      as (-1, H, W, C).
    dataset_name: str; Dataset name, optionally prefixed (e.g. 'lsm_prod/...');
      only the last '/'-separated part is used to look up dataset constants.
    patch_size: tuple; (patch_h, patch_w) used for cropping/padding.

  Returns:
    A dict of entries to merge into `metadata`:
      'input_shape', 'input_valid_feats': sensor-feature shape after
        crop/pad, and a mask that is 1 for real and 0 for padded columns.
      'datetime_input_shape', 'datetime_valid_feats': the same for datetime
        features (None when the dataset has none). The value is also stored
        under the legacy misspelled key 'datime_valid_feats' for existing
        readers.
      'target_is_onehot', 'num_classes': only for activities/mood/stress
        datasets.
  """
  metadata_update = {}
  dataset_name = dataset_name.split('/')[-1]
  dataset_info = dataset_constants.lsm_dataset_constants[dataset_name]
  time_features = dataset_info.get('datetime_features', None)
  feature_shape = list(metadata['input_shape'][1:])  # (H, W, C)

  # Split sensor features from datetime features along W.
  if time_features is not None:
    time_feature_indices = list(time_features['indices'])
    num_sensor_feats = len(
        set(range(feature_shape[1])) - set(time_feature_indices)
    )
    time_feature_shape = feature_shape.copy()
    time_feature_shape[1] = len(time_feature_indices)
    feature_shape[1] = num_sensor_feats
  else:
    time_feature_shape = None

  # Padding: update shapes to reflect padding for perfect patching.
  # 1. Sensor features.
  _, pad_w, feat_shape_new = get_height_crop_width_pad(
      tuple(feature_shape), patch_size
  )
  metadata_update['input_shape'] = tuple([-1] + list(feat_shape_new))
  metadata_update['input_valid_feats'] = tuple(
      [0] * pad_w[0] + [1] * feature_shape[1] + [0] * pad_w[1]
  )

  # 2. Datetime features.
  if time_features is not None:
    _, time_pad_w, time_feature_shape_new = get_height_crop_width_pad(
        tuple(time_feature_shape), patch_size
    )
    datetime_input_shape = tuple([-1] + list(time_feature_shape_new))
    datetime_valid_feats = tuple(
        [0] * time_pad_w[0]
        + [1] * time_feature_shape[1]
        + [0] * time_pad_w[1]
    )
  else:
    datetime_input_shape = None
    datetime_valid_feats = None
  metadata_update['datetime_input_shape'] = datetime_input_shape
  # 'datime_valid_feats' is a legacy misspelling that downstream code may
  # read; keep it and also expose the correctly spelled key.
  metadata_update['datime_valid_feats'] = datetime_valid_feats
  metadata_update['datetime_valid_feats'] = datetime_valid_feats

  # Record whether the dataset's labels are one-hot encoded.
  if 'activities' in dataset_name or 'mood' in dataset_name:
    metadata_update['target_is_onehot'] = True
    metadata_update['num_classes'] = len(dataset_info['log_values'])
  elif 'stress' in dataset_name:
    metadata_update['target_is_onehot'] = True
    metadata_update['num_classes'] = 2

  return metadata_update


def _finalize_eval_ds(ds, crop_and_pad_fn, batch_size, repeat):
  """Crops/pads, batches (keeping the remainder), optionally repeats, prefetches."""
  ds = ds.map(  # crop/pad for perfect patching
      crop_and_pad_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  ds = ds.batch(batch_size, drop_remainder=False)
  if repeat:
    ds = ds.repeat()
  return ds.prefetch(tf.data.experimental.AUTOTUNE)


def _to_sharded_numpy_iter(ds, pad_fn, shard_fn):
  """Iterates `ds` as numpy batches, padded by `pad_fn` and sharded by `shard_fn`."""
  batches = map(dataset_utils.tf_to_numpy, iter(ds))
  batches = map(pad_fn, batches)
  return map(shard_fn, batches)


def get_electrodes_dataset(
    *,
    config,
    num_shards,
    batch_size,
    eval_batch_size=None,
    dtype_str='float32',
    shuffle_seed=0,
    rng=None,
    shuffle_buffer_size=None,
    dataset_service_address: Optional[str] = None,
    dataset_name=None,  # 'lsm_prod/lsm_300min_10M_impute'
    data_dir='/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/exp/dmcduff/ttl=6w/msa_1_5/lsm_tfds_datasets',
):
  """Gets and formats the Electrodes dataset.

  Adapted from:
  google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py and
  google3/third_party/py/scenic/dataset_lib/dataset_utils.py.

  Args:
    config: ml_collections.ConfigDict; Config for the experiment.
    num_shards: int; Number of shards to split each batch into.
    batch_size: int; Batch size for training.
    eval_batch_size: int; Batch size for the validation and test splits.
      Defaults to `batch_size`.
    dtype_str: str; Data type of the image.
    shuffle_seed: int; Seed for shuffling the training set.
    rng: jax.random.PRNGKey; Key from which the per-process augmentation seed
      is derived. Defaults to PRNGKey(config.rng_seed).
    shuffle_buffer_size: int; Size of the shuffle buffer
      (defaults to 8 * batch_size).
    dataset_service_address: str; Address of the tf.data service.
    dataset_name: str; Name of the dataset, used when
      config.dataset_configs has no 'dataset' entry.
    data_dir: str; Directory of the dataset.

  Returns:
    A dataset_utils.Dataset of (train_iter, valid_iter, test_iter, meta_data).

  Raises:
    ValueError: If a dataset service address is given together with a
      non-None shuffle seed.
  """

  # Setup: General
  if rng is None:
    rng = jax.random.PRNGKey(config.rng_seed)

  # 1. Process information.
  p_idx = jax.process_index()  # current process index
  p_cnt = jax.process_count()  # process count (number of processes)

  # One augmentation key per process (not per device); each process takes
  # the key at its own index.
  aug_rng = jax.random.split(rng, p_cnt)[p_idx]
  # A JAX key is a uint32 array; tf.random op seeds are Python integers, so
  # convert its first word explicitly.
  tf_aug_rng = int(aug_rng[0])
  del rng

  # 2. dataset and data type information.
  dataset_configs = config.dataset_configs  # get dataset configurations.
  dataset_name = dataset_configs.get('dataset', dataset_name)  # get ds name
  dtype = getattr(tf, dtype_str)  # data dtype
  if eval_batch_size is None:  # set eval batch size
    eval_batch_size = batch_size

  # Setup: Mapping functions.
  # 1. Preprocessing function.
  preprocess_fn = functools.partial(
      preprocess_example, dataset_name=dataset_name, dtype=dtype
  )
  # 2. Augmentation function.
  augment_fn = functools.partial(
      augment_example,
      augmentations=config.get('train_augmentations', []),
      seed=tf_aug_rng,
  )

  print('BASE SEED IS', tf_aug_rng)

  # 3. Crop and pad features and time features to be patch size compatible.
  crop_and_pad_fn = functools.partial(
      patch_compatible_resize_example, patch_size=config.model.patches.size
  )

  # Setup: Data splits.
  # 1. Train split: Get the entire or a subset of the training set.
  train_split_name = dataset_configs.get('train_split', 'train')
  num_train_samples = dataset_configs.get('train_num_samples', None)
  if num_train_samples:
    train_split = f'{train_split_name}[:{num_train_samples}]'
  else:
    train_split = train_split_name

  # 2. Validation / Test splits: Split the eval split into validation and
  # test sets (50% - 50%).
  eval_split_name = dataset_configs.get('eval_split', 'test')
  val_split, test_split = tfds.even_splits(split=eval_split_name, n=2)

  # 3. Per-process split: split each split evenly across processes.
  train_split_range = tfds.even_splits(split=train_split, n=p_cnt)[p_idx]
  val_split_range = tfds.even_splits(split=val_split, n=p_cnt)[p_idx]
  test_split_range = tfds.even_splits(split=test_split, n=p_cnt)[p_idx]

  # 4. Load dataset splits.
  train_ds = tfds.load(
      dataset_name,
      data_dir=data_dir,
      split=train_split_range,
      shuffle_files=False,  # NOTE: train shuffle is done below.
  )
  val_ds = tfds.load(
      dataset_name,
      data_dir=data_dir,
      split=val_split_range,
      shuffle_files=False,
  )
  test_ds = tfds.load(
      dataset_name,
      data_dir=data_dir,
      split=test_split_range,
      shuffle_files=False,
  )
  logging.info(  # pylint:disable=logging-fstring-interpolation
      f'Loaded train, val, and test split {p_idx}/{p_cnt} from {dataset_name}.'
  )

  # Data processing and preparation.
  # 1. Enable multi threaded workers.
  options = tf.data.Options()
  options.threading.private_threadpool_size = 48
  train_ds = train_ds.with_options(options)
  val_ds = val_ds.with_options(options)
  test_ds = test_ds.with_options(options)

  # 2. Preprocessing: Applied before `ds.cache()` to re-use it.
  train_ds = train_ds.map(
      preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  val_ds = val_ds.map(
      preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  test_ds = test_ds.map(
      preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )

  # 3. Cache datasets: This can significantly speed up training.
  if dataset_configs.cache_dataset:
    train_ds = train_ds.cache()
    val_ds = val_ds.cache()
    test_ds = test_ds.cache()

  # 4. Data preparation (repetition, shuffling, augmentations, batching, etc.).
  repeat_ds = dataset_configs.get('repeat_data', True)

  # 4a. Train: repeat, augment, crop/pad, shuffle, and batch.
  if repeat_ds:
    train_ds = train_ds.repeat()  # repeat
  # NOTE: Augmentations run after cache()/repeat() so cached elements are not
  # replayed with the same random augmentation.
  if config.use_train_augmentations:
    train_ds = train_ds.map(  # train data augmentations
        augment_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
  train_ds = train_ds.map(  # crop/pad for perfect patching
      crop_and_pad_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  shuffle_buffer_size = shuffle_buffer_size or (8 * batch_size)
  train_ds = train_ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)  # shuffle
  train_ds = train_ds.batch(batch_size, drop_remainder=True)  # batch
  train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)  # prefetch

  # 4b/4c. Validation and test: crop/pad, batch, and repeat. Batched with
  # eval_batch_size so batches match the eval padding function below.
  val_ds = _finalize_eval_ds(val_ds, crop_and_pad_fn, eval_batch_size, repeat_ds)
  test_ds = _finalize_eval_ds(
      test_ds, crop_and_pad_fn, eval_batch_size, repeat_ds
  )

  if dataset_service_address:
    if shuffle_seed is not None:
      raise ValueError(
          'Using dataset service with a random seed causes each '
          'worker to produce exactly the same data. Add '
          'config.shuffle_seed = None to your config if you '
          'want to run with dataset service.'
      )
    train_ds = dataset_utils.distribute(train_ds, dataset_service_address)
    logging.info('Using the tf.data service at %s', dataset_service_address)

  # Other mappings
  # 1. Batch padding: if batch remainders are NOT dropped, partial batches are
  # padded up to the full batch size.
  maybe_pad_batches_train = functools.partial(
      dataset_utils.maybe_pad_batch,
      train=True,
      batch_size=batch_size,
      inputs_key='input_signal',
  )
  maybe_pad_batches_eval = functools.partial(
      dataset_utils.maybe_pad_batch,
      train=False,
      batch_size=eval_batch_size,
      inputs_key='input_signal',
  )
  # 2. Batch sharding: split each batch across the local devices.
  shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards)

  # 3. Apply other mappings and iterate the datasets.
  train_iter = _to_sharded_numpy_iter(
      train_ds, maybe_pad_batches_train, shard_batches
  )
  val_iter = _to_sharded_numpy_iter(val_ds, maybe_pad_batches_eval, shard_batches)
  test_iter = _to_sharded_numpy_iter(
      test_ds, maybe_pad_batches_eval, shard_batches
  )

  # Save meta data
  info = tfds.builder(dataset_name, data_dir=data_dir, try_gcs=True).info
  input_shape = tuple([-1] + list(info.features['input_signal'].shape))
  meta_data = {
      'input_shape': input_shape,
      'num_train_examples': dataset_utils.get_num_examples(
          dataset=dataset_name, split=train_split, data_dir=data_dir
      ),
      'num_val_examples': dataset_utils.get_num_examples(
          dataset=dataset_name, split=val_split, data_dir=data_dir
      ),
      'num_test_examples': dataset_utils.get_num_examples(
          dataset=dataset_name, split=test_split, data_dir=data_dir
      ),
      'input_dtype': getattr(jnp, dtype_str),
      # The following two fields are set as defaults and may be updated in the
      # update_metadata function below.
      'target_is_onehot': False,
      'num_classes': None,
  }
  # Update metadata to reflect preprocessing and padding
  # (changes in shape and features).
  meta_data.update(
      update_metadata(
          meta_data,
          dataset_name=dataset_name,
          patch_size=config.model.patches.size,
      )
  )

  # Return dataset structure.
  return dataset_utils.Dataset(train_iter, val_iter, test_iter, meta_data)


def get_dataset(
    config: Any,
    data_rng: jnp.ndarray,
    *,
    num_local_shards: Optional[int] = None,
    dataset_service_address: Optional[str] = None,
    **kwargs: Any,
) -> dataset_utils.Dataset:
  """Builds the dataset iterators for the dataset named in `config`.

  Adapted from: google3/third_party/py/scenic/train_lib/train_utils.py.

  Args:
    config: Experiment config. This function reads
      `config.dataset_configs.dataset`, `config.batch_size`,
      `config.data_dtype_str` and the optional `shuffle_seed` /
      `shuffle_buffer_size` entries; the whole config is also forwarded to the
      dataset builder.
    data_rng: Random key forwarded to the dataset builder as `rng`.
    num_local_shards: Number of shards each local batch is split into.
      Falls back to `jax.local_device_count()` when None (or 0).
    dataset_service_address: Optional dataset service address, forwarded to
      the dataset builder. Cannot be combined with `config.shuffle_seed`.
    **kwargs: Extra keyword arguments forwarded to the dataset builder.

  Returns:
    The `dataset_utils.Dataset` returned by the dataset builder.

  Raises:
    ValueError: If the dataset is not supported, if `config.batch_size` is not
      divisible by the global device count, or if a shuffle seed is set while
      a dataset service address is given.
  """
  # Log the device topology.
  device_count = jax.device_count()
  logging.info('device_count: %d', device_count)
  logging.info('num_hosts : %d', jax.process_count())
  logging.info('host_id : %d', jax.process_index())

  # Select the dataset builder. The second '/'-separated component of the
  # configured dataset name must be one of the registered LSM dataset names.
  supported_dataset_names = [
      x['dataset_name']
      for x in dataset_constants.lsm_dataset_constants.values()
  ]
  dataset_name = config.dataset_configs.dataset
  name_parts = dataset_name.split('/')
  # Check the length first so a name without '/' is reported with the same
  # ValueError as any other unsupported name, instead of an IndexError.
  if len(name_parts) < 2 or name_parts[1] not in supported_dataset_names:
    raise ValueError(f'Dataset {dataset_name} is not supported.')
  dataset_builder = get_electrodes_dataset

  # The global batch must split evenly across all devices.
  batch_size = config.batch_size
  if batch_size % device_count > 0:
    raise ValueError(
        f'Batch size ({batch_size}) must be divisible by the '
        f'number of devices ({device_count})'
    )

  # Per-host and per-device batch sizes.
  local_batch_size = batch_size // jax.process_count()
  device_batch_size = batch_size // device_count
  logging.info('local_batch_size : %d', local_batch_size)
  logging.info('device_batch_size : %d', device_batch_size)

  # Optional shuffle seed (None when the config does not define one). A fixed
  # seed combined with a dataset service makes every worker produce the same
  # data, so that combination is rejected.
  shuffle_seed = config.get('shuffle_seed', None)
  if dataset_service_address and shuffle_seed is not None:
    raise ValueError(
        'Using dataset service with a random seed causes each '
        'worker to produce exactly the same data. Add '
        'config.shuffle_seed = None to your config if you want '
        'to run with dataset service.'
    )

  # Optional shuffle buffer size (None when not configured).
  shuffle_buffer_size = config.get('shuffle_buffer_size', None)
  # Local shard count.
  num_local_shards = num_local_shards or jax.local_device_count()

  return dataset_builder(
      config=config,
      num_shards=num_local_shards,
      batch_size=local_batch_size,
      dtype_str=config.data_dtype_str,
      shuffle_seed=shuffle_seed,
      rng=data_rng,
      shuffle_buffer_size=shuffle_buffer_size,
      dataset_service_address=dataset_service_address,
      **kwargs,
  )



# Check Dataset Reproducibility
### Takeaways:
- Training augmentations ARE reproducible if `tf.keras.utils.set_random_seed` is called AND `tf.config.experimental.enable_op_determinism()` is enabled.
- Reproducible training also requires a fixed dataset shuffle seed, which can be set via `config.shuffle_seed`.


In [None]:
DATASET_NAME = 'lsm_300min_600_activities_balanced'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 8

# Augmentation parameters read from every training batch, in the order they
# are unpacked below.
AUG_PARAM_KEYS = (
    'aug_apply_noise',
    'aug_noise_std',
    'aug_apply_flip',
    'aug_apply_stretch',
    'aug_stretch',
)
# Number of training batches drawn per trial.
NUM_CHECK_BATCHES = 100


def _build_seeded_dataset(seed=1):
  """Seeds TF, enables op determinism and builds a fresh dataset.

  Args:
    seed: Value passed to `tf.keras.utils.set_random_seed`.

  Returns:
    Tuple (config, rng, data_rng, dataset).
  """
  tf.keras.utils.set_random_seed(seed)
  tf.config.experimental.enable_op_determinism()
  trial_config = get_config(runlocal=False)  # must be false to get full dataset
  trial_rng = jax.random.PRNGKey(trial_config.rng_seed)
  trial_data_rng, trial_rng = jax.random.split(trial_rng)
  trial_dataset = get_dataset(trial_config, trial_data_rng)
  return trial_config, trial_rng, trial_data_rng, trial_dataset


def _collect_aug_params(train_iter, num_batches=NUM_CHECK_BATCHES):
  """Draws `num_batches` batches and gathers their augmentation parameters.

  Returns:
    Tuple with one entry per key in AUG_PARAM_KEYS: the per-batch values
    concatenated along the last axis, then indexed with [0].
  """
  collected = {key: [] for key in AUG_PARAM_KEYS}
  for _ in range(num_batches):
    batch = next(train_iter)
    for key in AUG_PARAM_KEYS:
      collected[key].append(batch[key])
  return tuple(
      tf.concat(collected[key], axis=-1)[0] for key in AUG_PARAM_KEYS
  )


def _tensors_equal(first, second):
  """Returns True when the two tensors are element-wise equal."""
  return bool(tf.reduce_all(tf.math.equal(first, second)))


# TEST 1
print('Trial 1:')
config, rng, data_rng, dataset = _build_seeded_dataset(seed=1)
(
    apply_noise_arr1,
    noise_std_arr1,
    apply_flip_arr1,
    apply_stretch_arr1,
    stretch_arr1,
) = _collect_aug_params(dataset.train_iter)
print(apply_noise_arr1[0:5])


# TEST 2: identical seeding, so every parameter should match Trial 1.
print('\nTrial 2:')
config, rng, data_rng, dataset = _build_seeded_dataset(seed=1)
(
    apply_noise_arr2,
    noise_std_arr2,
    apply_flip_arr2,
    apply_stretch_arr2,
    stretch_arr2,
) = _collect_aug_params(dataset.train_iter)
print(apply_noise_arr2[0:5])

print('\nTrial 1 equals Trial 2?:')
for label, first_arr, second_arr in (
    ('apply_noise', apply_noise_arr1, apply_noise_arr2),
    ('noise_std', noise_std_arr1, noise_std_arr2),
    ('apply_flip', apply_flip_arr1, apply_flip_arr2),
    ('apply_stretch', apply_stretch_arr1, apply_stretch_arr2),
    ('stretch', stretch_arr1, stretch_arr2),
):
  print(f'{label}:', _tensors_equal(first_arr, second_arr))


# Check seeded random functions vs unseeded random functions
### Experiment information:
- Assuming the global seed is set, do random ops with the same op-level seed produce the same sequence of numbers?
### Takeaways:
- Random ops must be given different op-level seeds if different sequences of random numbers are required.


In [None]:
# @title Baseline: Global Seed Set - Function Seeds Set (All Same)

# Reset the global seed (with deterministic ops), then draw three scalars in a
# row from tf.random.uniform, always using op-level seed 0.
tf.keras.utils.set_random_seed(1)
tf.config.experimental.enable_op_determinism()

for _draw_idx in range(3):
  x = tf.random.uniform(
      shape=(), minval=0, maxval=1, dtype=tf.float32, seed=0
  )
  print(x)

In [None]:
# @title Experiment 1: Global Seed Set - Function Seeds Set (All Same)

# Question:
# With the global seed and op-level seeds set, does drawing from a differently
# seeded op (seed=1) between draws change the values produced by the seed=0
# op? Compare the printed values with the Baseline cell.

tf.keras.utils.set_random_seed(1)
tf.config.experimental.enable_op_determinism()

for _draw_idx in range(3):
  if _draw_idx > 0:
    # Interleaved draw from a different op-level seed; its value is discarded.
    x = tf.random.uniform(
        shape=(), minval=0, maxval=1, dtype=tf.float32, seed=1
    )
  x = tf.random.uniform(
      shape=(), minval=0, maxval=1, dtype=tf.float32, seed=0
  )
  print(x)

In [None]:
# @title Experiment 2: Global Seed Set - Function Seeds Set (All Different)

# Explanation:
# Shows that tf.random.uniform ops with different op-level seeds (0, 100, 1000)
# produce different numbers. As in Experiment 1, a seed=1 draw is interleaved
# before every draw except the first.

tf.keras.utils.set_random_seed(1)
tf.config.experimental.enable_op_determinism()

for _draw_idx, _op_seed in enumerate((0, 100, 1000)):
  if _draw_idx:
    # Interleaved draw from op-level seed 1; its value is discarded.
    x = tf.random.uniform(
        shape=(), minval=0, maxval=1, dtype=tf.float32, seed=1
    )
  x = tf.random.uniform(
      shape=(), minval=0, maxval=1, dtype=tf.float32, seed=_op_seed
  )
  print(x)