# LSM Pretrain Dataset Explorer
##### Colab Kernel (Electrodes)
##### Dataset (Electrodes)

Grants command for Access on Demand (AoD):

https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-electrodes-deid-colab-jobs&reason=b%2F314799341

### About This Notebook:
This notebook explores the pretrain (unlabeled) datasets for the LSM project. It loads and prints dataset metadata, and plots sample data for the following datasets:

**Actively Used:**
1. Balanced Pretrain Dataset (10 Samples Per 165k Subjects): `lsm_300min_pretraining_165K_n10`
2. Combined 8M Pretrain Dataset: `lsm_300min_pretraining_165K_n10` + `lsm_300min_10M_impute`

# Setup

In [None]:
# @title Imports

import io
import functools
from typing import Any, Callable, Dict, Iterator, Tuple, Optional, Type, Union
import time
from collections import Counter

from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform

import flax
from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.profiler

import pandas as pd
import ml_collections
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline

from colabtools import adhoc_import
with adhoc_import.Google3():
  from scenic.dataset_lib import dataset_utils
  from scenic.google.xm import xm_utils
  from scenic.model_lib.base_models import base_model
  from scenic.model_lib.base_models import model_utils
  from scenic.model_lib.layers import nn_ops
  from scenic.model_lib.layers import nn_layers
  from scenic.projects.baselines import vit
  from scenic.train_lib import optax as scenic_optax
  from scenic.train_lib import pretrain_utils
  from scenic.train_lib import train_utils

  from scenic.projects.multimask.models import model_utils as mm_model_utils

  from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
  from google3.experimental.largesensormodels.scenic.datasets import lsm_activity_subset_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_mood_vs_activity_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_tiny_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_combined_pretrain_dataset

  from google3.experimental.largesensormodels.scenic.models import lsm_vit as lsm_vit_mae
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_utils as lsm_model_utils
  from google3.experimental.largesensormodels.scenic.trainers import lsm_mae_trainer

  from google3.pyglib import gfile


# A batch is a mapping from string keys to arrays.
Batch = Dict[str, jnp.ndarray]

# Metric callable: takes two arrays plus a dict of arrays and returns, per
# metric name, a (float, int) pair.
MetricFn = Callable[
    [jnp.ndarray, jnp.ndarray, Batch], Dict[str, Tuple[float, int]]
]

# Loss callable: takes an array, a batch, an optional array and another array,
# and returns a float.
LossFn = Callable[
    [jnp.ndarray, Batch, Optional[jnp.ndarray], jnp.ndarray],
    float,
]

# Named learning-rate functions, each mapping an array to an array.
LrFns = Dict[str, Callable[[jnp.ndarray], jnp.ndarray]]

# Patch size given as either two or three ints.
Patch = Union[
    Tuple[int, int],
    Tuple[int, int, int],
]


In [None]:
# @title Helper Functions


def explore_sample_data(dataset, ohe_labels, offset=65536):
  """Prints the fields of one validation batch and plots one of its examples.

  Pulls the next batch from `dataset.valid_iter` (which advances that
  iterator), prints its keys and log fields, and draws the sensor and datetime
  signals of a single example as images in two new matplotlib figures.

  Args:
    dataset: Object exposing a `valid_iter` iterator that yields dicts with the
      keys 'input_signal', 'datetime_signal', 'exercise_log', 'log_value',
      'mood_log' and, when `ohe_labels` is truthy, 'label'.
    ohe_labels: If truthy, the batch's 'label' entry is printed as well.
    offset: Value added to 'log_value' before printing. Defaults to 65536,
      the same default offset used by `dataset_split_information`.
  """
  # Indices of the single example that is plotted. The signals are indexed as
  # [p_idx][batch_idx]; presumably (device, batch) - TODO confirm.
  p_idx = 0
  batch_idx = 0
  x = next(dataset.valid_iter)

  # Parse data example.
  # Both signals are 5-D; axes 2 and 3 are swapped for plotting.
  sensor_signal = x['input_signal']  # Sensor signals
  plt_input = jnp.transpose(sensor_signal, (0, 1, 3, 2, 4))

  datetime_signal = x['datetime_signal']  # Datetime signals
  plt_datetime = jnp.transpose(datetime_signal, (0, 1, 3, 2, 4))

  print(f'Example keys {list(x.keys())}\n')
  print('Exercise Log', x['exercise_log'])  # Exercise log
  print('Log Value', x['log_value'] + offset)  # Log value
  print('\nMood Log', x['mood_log'])  # Mood log

  if ohe_labels:
    print(f"\nLabel {x['label']}\n")  # Stress label

  # Plot input signal
  plt.figure(figsize=(15, 4))
  plt.imshow(plt_input[p_idx][batch_idx])
  plt.title('Sensor Inputs')

  # Plot datetime signal
  plt.figure(figsize=(15, 4))
  plt.imshow(plt_datetime[p_idx][batch_idx])
  plt.title('Datetime Inputs')


def dataset_split_information(ds, event, offset=65536):
  """Given a dataset split, returns a dictionary of stats/metadata.

  Iterates once over `ds`, counting examples via 'batch_mask' and collecting
  one log value per selected example.

  Args:
    ds: Iterable of batch dicts. Every batch needs 'batch_mask'. For
      'exercise' and 'mood' it also needs 'log_value' and the matching
      'exercise_log' / 'mood_log' entry; for 'stress' it needs 'label'.
    event: One of 'exercise', 'mood' or 'stress'.
    offset: Value added to every non-zero exercise/mood log value. It is not
      applied to stress labels.

  Returns:
    A dict with:
      'num_examples': int, the sum of 'batch_mask' over all batches.
      'example_subj_ids': list, currently always empty because subject-ID
        collection is disabled.
      'log_value': list of Python ints. For 'exercise'/'mood' these are the
        non-zero 'log_value' entries where the log flag is set, plus `offset`.
        For 'stress' they are the argmax of 'label' over the last axis where
        'batch_mask' is set.

  Raises:
    ValueError: If `event` is not one of the supported names.
  """
  if event == 'exercise':
    log_key = 'exercise_log'
  elif event in ('mood', 'stress'):
    log_key = 'mood_log'
  else:
    raise ValueError('event must be exercise, mood, or stress')

  # Progress is reported each time the running count crosses a multiple of
  # `report_every`; an exact-multiple check would almost never fire because the
  # count grows by a whole batch at a time.
  report_every = 100000
  next_report = report_every

  subj_ids = []  # Subject-ID collection is disabled, so this stays empty.
  ex_count = 0
  log_values = []
  for d in ds:
    # Total example count.
    ex_count += int(jnp.sum(d['batch_mask']))
    if ex_count >= next_report:
      print(f'running example count: {ex_count}')
      next_report = (ex_count // report_every + 1) * report_every

    # Get label
    if event == 'stress':
      class_idx = jnp.argmax(d['label'], axis=-1)
      # Keep only the positions where the batch mask is set.
      batch_values = class_idx[jnp.where(d['batch_mask'])].tolist()
    else:
      # Log values at the positions where the exercise/mood log is set.
      logged = d['log_value'][jnp.where(d[log_key])].flatten()
      # Drop zero entries, then add the offset; tolist() yields Python ints,
      # matching what the stress branch returns.
      batch_values = (logged[logged != 0] + offset).tolist()

    log_values += batch_values

  info = {'num_examples': ex_count,
          'example_subj_ids': subj_ids,
          'log_value': log_values}

  return info



In [None]:
# @title Sample Model Config

r"""A config to train a Tiny ViT MAE on LSM 5M dataset.

Forked from google3/third_party/py/scenic/projects/multimask/configs/mae_cifar10_tiny.py

To run on XManager:
gxm third_party/py/scenic/google/xm/launch_xm.py -- \
--binary //experimental/largesensormodels/scenic:main \
--config=experimental/largesensormodels/scenic/configs/mae_lsm_tiny.py \
--platform=vlp_4x4 \
--exp_name=lsm_mae_tier2_TinyShallow_10_5_res \
--workdir=/cns/dz-d/home/xliucs/lsm/xm/\{xid\} \
--xm_resource_alloc=group:mobile-dynamic/h2o-ai-gqm-quota \
--priority=200

To run locally:
./third_party/py/scenic/google/runlocal.sh \
--uptc="" \
--binary=//experimental/largesensormodels/scenic:main \
--config=$(pwd)/experimental/largesensormodels/scenic/configs/mae_lsm_tiny.py:runlocal
"""


# To set constants.
# NOTE: these are module-level globals. get_config() reads DATASET_NAME,
# TRAIN_DATA_SIZE and BATCH_SIZE when it is called, whereas the derived
# constants below are computed once, when this cell runs.
DATASET_NAME = 'lsm_300min_10M_impute'
CACHE_DATASET = False
TRAIN_DATA_SIZE = 1000000  # 1M train samples (capped to the dataset size below)
BATCH_SIZE = 8
NUMBER_OF_EPOCH = 100
REPEAT_DATA = False

# Model variant / patch H (time steps) / patch W (features)
VARIANT = 'TiShallow/1/1'
LRS = [1e-3]  # Learning rates; get_config() uses LRS[0], get_hyper() sweeps all.
TOKEN_MASK_PROB = 'constant_0.8'
LOSS_ONLY_MASKED_TOKENS = True
USE_DATETIME_FEATURES = True
USE_TRAIN_AUGMENTATIONS = True
TRAIN_AUGMENTATIONS = ['stretch', 'flip', 'noise']
OHE_LABELS = True

# Derived constants.
# Never request more train examples than the dataset constants report.
TRAIN_DATA_SIZE = min(
    TRAIN_DATA_SIZE,
    dataset_constants.lsm_dataset_constants[DATASET_NAME]['num_train_examples']
)

# At least one step per epoch, even when the data is smaller than one batch.
STEPS_PER_EPOCH = max(1, int(TRAIN_DATA_SIZE / BATCH_SIZE))
NUM_TRAIN_STEPS = int(NUMBER_OF_EPOCH * STEPS_PER_EPOCH)

# Evaluate / write summaries once per epoch; checkpoint every five epochs.
LOG_EVAL_SUMMARY_STEPS = STEPS_PER_EPOCH
LOG_CHECKPOINT_STEPS = LOG_EVAL_SUMMARY_STEPS * 5
# One slot per checkpoint interval that fits in the training run.
MAX_NUM_CHECKPOINTS = int(NUM_TRAIN_STEPS / LOG_CHECKPOINT_STEPS)


def get_config_common_few_shot(
    batch_size: Optional[int] = None,
    target_resolution: int = 224,
    resize_resolution: int = 256,
) -> ml_collections.ConfigDict:
  """Builds a standard-ish configuration for fewshot evaluation.

  Copied from
  third_party/py/scenic/projects/baselines/configs/google/common/common_fewshot.py

  Args:
    batch_size: Batch size used for fewshot evaluation.
    target_resolution: Side length passed to the central crop step.
    resize_resolution: Side length passed to the resize step.

  Returns:
    A ConfigDict holding the fewshot evaluation settings.
  """
  # Train and eval use the exact same preprocessing spec, so build it once.
  preprocess_spec = '|'.join([
      'decode',
      f'resize({resize_resolution})',
      f'central_crop({target_resolution})',
      'value_range(-1,1)',
  ])

  # Short name -> (dataset, train split, test split).
  fewshot_datasets = {
      'birds': ('caltech_birds2011', 'train', 'test'),
      'caltech': ('caltech101', 'train', 'test'),
      'cars': ('cars196:2.1.0', 'train', 'test'),
      'cifar100': ('cifar100', 'train', 'test'),
      'col_hist': ('colorectal_histology', 'train[:2000]', 'train[2000:]'),
      'dtd': ('dtd', 'train', 'test'),
      'imagenet': ('imagenet2012_subset/10pct', 'train', 'validation'),
      'pets': ('oxford_iiit_pet', 'train', 'test'),
      'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
  }

  fewshot = ml_collections.ConfigDict()
  fewshot.batch_size = batch_size
  fewshot.representation_layer = 'pre_logits'
  fewshot.log_eval_steps = 25_000
  fewshot.datasets = fewshot_datasets
  fewshot.pp_train = preprocess_spec
  fewshot.pp_eval = preprocess_spec
  fewshot.shots = [1, 5, 10, 25]
  # L2 regularisation strengths: powers of two from 2**-10 up to 2**19.
  fewshot.l2_regs = [2.0**exponent for exponent in range(-10, 20)]
  fewshot.walk_first = ('imagenet', 10)

  return fewshot


def get_config(runlocal=''):
  """Returns the ViT MAE experiment configuration.

  The configuration is assembled from the module-level constants. DATASET_NAME,
  TRAIN_DATA_SIZE, BATCH_SIZE and VARIANT are read when this function is
  called, while NUM_TRAIN_STEPS, STEPS_PER_EPOCH and the logging/checkpoint
  intervals keep whatever values they had when they were computed at module
  level.

  Args:
    runlocal: Any value; its truthiness selects the local-debug settings
      (64-example eval split, 'Deb' model version, batch size 8, flop counting
      disabled).

  Returns:
    An ml_collections.ConfigDict with dataset, model, masked-feature loss,
    schedule, optimizer, fewshot and logging settings.

  Raises:
    ValueError: If VARIANT does not have two or three '/'-separated parts.
  """

  runlocal = bool(runlocal)

  # Experiment.
  config = ml_collections.ConfigDict()
  config.experiment_name = f'electrodes-mae-{DATASET_NAME}-{TRAIN_DATA_SIZE}'
  config.dataset_name = f'lsm_prod/{DATASET_NAME}'

  # Dataset.
  config.data_dtype_str = 'float32'
  config.dataset_configs = ml_collections.ConfigDict()
  config.dataset_configs.dataset = f'lsm_prod/{DATASET_NAME}'
  # config.dataset_configs.num_classes = NUM_CLASSES
  config.dataset_configs.train_split = 'train'  # train data split
  config.dataset_configs.train_num_samples = TRAIN_DATA_SIZE  # train sample
  # eval data split - note: this split is used for validation and test.
  config.dataset_configs.eval_split = 'test[:64]' if runlocal else 'test'
  config.dataset_configs.cache_dataset = CACHE_DATASET
  config.dataset_configs.prefetch_to_device = 2
  # Shuffle_buffer_size is per host, so small-ish is ok.
  config.dataset_configs.shuffle_buffer_size = 250_000
  config.dataset_configs.repeat_data = REPEAT_DATA
  config.dataset_configs.ohe_labels = OHE_LABELS

  # Model.
  # VARIANT is '<version>/<patch H>/<patch W>', or '<version>/<patch>' for a
  # square patch.
  variant_parts = VARIANT.split('/')
  if len(variant_parts) == 3:
    version, patch_h, patch_w = variant_parts
  elif len(variant_parts) == 2:
    version, patch_h = variant_parts
    patch_w = patch_h  # same size along both patch dimensions
  else:
    raise ValueError(f'Invalid model variant: {VARIANT}')

  version = 'Deb' if runlocal else version
  config.model_name = 'lsm_vit_mae'
  config.model = ml_collections.ConfigDict()
  # encoder
  config.model.hidden_size = model_constants.HIDDEN_SIZES[version]
  config.model.patches = ml_collections.ConfigDict()
  config.model.patches.size = (int(patch_h), int(patch_w))
  config.model.num_heads = model_constants.NUM_HEADS[version]
  config.model.mlp_dim = model_constants.MLP_DIMS[version]
  config.model.num_layers = model_constants.NUM_LAYERS[version]
  config.model.dropout_rate = 0.0
  config.model.classifier = 'none'  # Has to be "none" for the autoencoder
  config.model.representation_size = None
  config.model.positional_embedding = 'sinusoidal_2d'
  config.model.positional_embedding_decoder = 'sinusoidal_2d'
  # decoder
  config.model.decoder_config = ml_collections.ConfigDict()
  config.model.decoder_config.hidden_size = (
      model_constants.DECODER_HIDDEN_SIZES[version]
  )
  config.model.decoder_config.mlp_dim = model_constants.DECODER_MLP_DIMS[
      version
  ]
  config.model.decoder_config.num_layers = model_constants.DECODER_NUM_LAYERS[
      version
  ]
  config.model.decoder_config.num_heads = model_constants.DECODER_NUM_HEADS[
      version
  ]
  config.model.decoder_config.dropout_rate = 0.0
  config.model.decoder_config.attention_dropout_rate = 0.0

  config.masked_feature_loss = ml_collections.ConfigDict()
  config.masked_feature_loss.targets_type = 'rgb'
  config.masked_feature_loss.token_mask_probability = TOKEN_MASK_PROB
  config.masked_feature_loss.loss_only_masked_tokens = LOSS_ONLY_MASKED_TOKENS
  config.masked_feature_loss.loss_type = 'squared'  # 'squared' or 'absolute'

  # Datetime features.
  config.use_datetime_features = USE_DATETIME_FEATURES

  # Training.
  config.trainer_name = 'lsm_mae_trainer'
  config.batch_size = 8 if runlocal else BATCH_SIZE
  config.num_training_steps = NUM_TRAIN_STEPS
  config.log_eval_steps = LOG_EVAL_SUMMARY_STEPS
  config.log_summary_steps = LOG_EVAL_SUMMARY_STEPS
  config.rng_seed = 42
  config.use_train_augmentations = USE_TRAIN_AUGMENTATIONS
  config.train_augmentations = TRAIN_AUGMENTATIONS
  sched = ml_collections.ConfigDict()
  sched.re = '(.*)'
  sched.lr_configs = ml_collections.ConfigDict()
  sched.lr_configs.learning_rate_schedule = 'compound'
  sched.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
  sched.lr_configs.total_steps = NUM_TRAIN_STEPS
  sched.lr_configs.steps_per_cycle = sched.lr_configs.total_steps
  # Warm up over 5% of the steps; note this value is a float.
  sched.lr_configs.warmup_steps = STEPS_PER_EPOCH * NUMBER_OF_EPOCH * 0.05
  sched.lr_configs.base_learning_rate = LRS[0]
  config.schedule = ml_collections.ConfigDict({'all': sched})

  # *Single* optimizer.
  optim = ml_collections.ConfigDict()
  optim.optax_name = 'scale_by_adam'
  # optim.optax = dict(mu_dtype='bfloat16')
  optim.optax_configs = ml_collections.ConfigDict({  # Optimizer settings.
      'b1': 0.9,
      'b2': 0.95,
  })
  # NOTE(review): this is set on `config`, not on `optim`; confirm that the
  # trainer actually reads `config.optax`.
  config.optax = dict(mu_dtype='bfloat16')
  optim.max_grad_norm = 1.0

  optim.weight_decay = 1e-4
  optim.weight_decay_decouple = True
  config.optimizer = optim

  # Fewshot.
  # TODO(girishvn): This needs to be adapted to electrode dataset
  config.fewshot = get_config_common_few_shot(batch_size=config.batch_size)
  # The copied image datasets do not apply here, so they are cleared.
  config.fewshot.datasets = {}
  config.fewshot.walk_first = ()
  config.fewshot.representation_layer = 'pre_logits'
  config.fewshot.log_eval_steps = LOG_EVAL_SUMMARY_STEPS

  # Logging.
  config.write_summary = True
  config.xprof = True  # Profile using xprof.
  config.checkpoint = True  # Do checkpointing.
  config.checkpoint_steps = LOG_CHECKPOINT_STEPS
  config.debug_train = False  # Debug mode during training.
  config.debug_eval = False  # Debug mode during eval.
  config.max_checkpoints_to_keep = MAX_NUM_CHECKPOINTS
  # BEGIN GOOGLE-INTERNAL
  if runlocal:
    # Current implementation fails with UPTC.
    config.count_flops = False
  # END GOOGLE-INTERNAL

  return config


# BEGIN GOOGLE-INTERNAL
def get_hyper(hyper):
  """Builds the grid-search sweep: one value per learning rate in LRS.

  Args:
    hyper: Object providing `sweep(name, values)` and `product(sweeps)`.

  Returns:
    The result of `hyper.product` over the single learning-rate sweep.
  """
  base_lr_sweep = hyper.sweep(
      'config.schedule.all.lr_configs.base_learning_rate', LRS
  )
  return hyper.product([base_lr_sweep])


In [None]:
# @title Data Dir

# Directory holding the LSM TFDS datasets; split over two literals for
# readability (they concatenate into a single path).
data_dir = (
    '/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes'
    '/deid/exp/dmcduff/ttl=6w/msa_1_5/lsm_tfds_datasets'
)
print(data_dir)

In [None]:
# @title Dataset Feature Columns

# CSV whose header row lists the dataset's feature names.
feature_names_csv = (
    '/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes'
    '/deid/raw/datasets/msa_1_5/lsm_tfds_datasets/lsm_prod'
    '/lsm_300min_100K_unimpute/1.0.0/Dataset_FeatureNames.csv'
)
with gfile.Open(feature_names_csv, 'r') as f:
  df = pd.read_csv(f)

# The feature names are the column labels of the loaded table.
features = df.columns
print(list(features))

# Balanced Pre-train Dataset
### Dataset Name: lsm_300min_pretraining_165K_n10


In [None]:
# @title Load Dataset and Visualize Sample

# Override the globals that get_config() reads at call time. Derived constants
# such as NUM_TRAIN_STEPS are not recomputed by this cell.
DATASET_NAME = 'lsm_300min_pretraining_165K_n10'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 16

config = get_config(runlocal=False)  # must be false to get full dataset

# Split the seeded key: one half feeds the data pipeline, the other is kept.
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)

print('Processed Dataset Meta Data:')
for meta_key, meta_value in dataset.meta_data.items():
  print(meta_key, meta_value)

print('\nData Sample Information:')
explore_sample_data(dataset, False)

In [None]:
# @title Metadata From Data Constants

print('Data Constants Meta Data:')
# The key into the constants table is the last path component of the
# configured dataset name.
dataset_name = config.dataset_name.rsplit('/', 1)[-1]
# Left as the final bare expression so the notebook displays the entry.
dataset_constants.lsm_dataset_constants[dataset_name]

# Combined 8M Pre-train Dataset
### Dataset name: lsm_300min_pretraining_8M_combined

In [None]:
# @title Load Dataset and Visualize Sample

# Override the globals that get_config() reads at call time. Derived constants
# such as NUM_TRAIN_STEPS are not recomputed by this cell.
DATASET_NAME = 'lsm_300min_pretraining_8M_combined'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 16

config = get_config(runlocal=False)  # must be false to get full dataset

# Split the seeded key: one half feeds the data pipeline, the other is kept.
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_combined_pretrain_dataset.get_dataset(config, data_rng)

print('Processed Dataset Meta Data:')
for meta_key, meta_value in dataset.meta_data.items():
  print(meta_key, meta_value)

print('\nData Sample Information:')
explore_sample_data(dataset, False)

In [None]:
# @title Metadata From Data Constants

print('Data Constants Meta Data:')
# The key into the constants table is the last path component of the
# configured dataset name.
dataset_name = config.dataset_name.rsplit('/', 1)[-1]
# Left as the final bare expression so the notebook displays the entry.
dataset_constants.lsm_dataset_constants[dataset_name]