# LSM Down-Stream-Task Dataset Explorer
##### Colab Kernel (Electrodes)
##### Dataset (Electrodes)

Grants command for Access on Demand (AoD):

https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-electrodes-deid-colab-jobs&reason=b%2F314799341

### About This Notebook:
This notebook explores downstream-task-specific datasets for the LSM project. It loads and prints dataset metadata, plots sample data, and counts label occurrences for the following datasets:

**Actively Used:**
1. Mood dataset: `lsm_300min_2000_mood_balanced`
2. Stress dataset: `lsm_300min_2000_stress_balanced`
3. Activity dataset: `lsm_300min_600_activities_balanced_v4`
4. Exercise dataset: `lsm_300min_2000_mood_balanced` and `lsm_300min_600_activities_balanced`
5. Biological sex dataset: derived from `lsm_300min_600_activities_balanced_v4`
6. Age dataset: derived from `lsm_300min_600_activities_balanced_v4`
7. Subject dependent mood dataset: derived from `lsm_300min_2000_mood_balanced`

**Deprecated Datasets:**
1. Exercise dataset 1: `lsm_300min_600_activities_balanced`
2. Exercise dataset 2: `lsm_300min_300_activities_balanced`
3. Exercise dataset 3 (subset of the above dataset): `lsm_300min_600_activities_9class_subset`

The results from this notebook are manually copied into the following Google Sheets file:

https://docs.google.com/spreadsheets/d/1-crJrg0XhedN5ayRpNd-7AUGr7PRxoQKxL__9gLlB5M/edit?usp=sharing&resourcekey=0-fv0vnqQuas9QsAJVcUOoGA

# Setup

In [None]:
# @title Imports

import io
import functools
from typing import Any, Callable, Dict, Iterator, Tuple, Optional, Type, Union
import time
from collections import Counter
import time
import random

from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform

import flax
from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.profiler

import pandas as pd
import ml_collections
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline

from colabtools import adhoc_import
with adhoc_import.Google3():
  from scenic.dataset_lib import dataset_utils
  from scenic.google.xm import xm_utils
  from scenic.model_lib.base_models import base_model
  from scenic.model_lib.base_models import model_utils
  from scenic.model_lib.layers import nn_ops
  from scenic.model_lib.layers import nn_layers
  from scenic.projects.baselines import vit
  from scenic.train_lib import optax as scenic_optax
  from scenic.train_lib import pretrain_utils
  from scenic.train_lib import train_utils

  from scenic.projects.multimask.models import model_utils as mm_model_utils

  # LSM Dataset Imports
  from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
  from google3.experimental.largesensormodels.scenic.datasets import lsm_activity_subset_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_mood_vs_activity_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_tiny_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_combined_pretrain_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_fewshot_mood_vs_activity_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_fewshot_remapped_activity_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_mood_subj_dependent_preprocessed_40sps_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_biological_sex_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_binned_age_dataset

  from google3.experimental.largesensormodels.scenic.models import lsm_vit as lsm_vit_mae
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_utils as lsm_model_utils
  from google3.experimental.largesensormodels.scenic.trainers import lsm_mae_trainer

  from google3.pyglib import gfile


# Type aliases for annotations.
Batch = Dict[str, jnp.ndarray]  # string key -> array
MetricFn = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]],
                    Dict[str, Tuple[float, int]]]
LossFn = Callable[[jnp.ndarray, Batch, Optional[jnp.ndarray], jnp.ndarray],
                  float]
LrFns = Dict[str, Callable[[jnp.ndarray], jnp.ndarray]]
# Either a 2-tuple or a 3-tuple of ints.
Patch = Union[Tuple[int, int], Tuple[int, int, int]]


In [None]:
# @title Helper Functions


def explore_sample_data(dataset, ohe_labels, log_value_offset=65536):
  """Prints one validation batch's metadata and plots one of its examples.

  Pulls the next batch from `dataset.valid_iter`, prints its keys, the exercise
  log, the offset log value and the mood log, optionally prints the label, and
  shows the sensor and datetime signals as images.

  Args:
    dataset: Object with a `valid_iter` iterator yielding batch dicts that have
      'input_signal', 'datetime_signal', 'exercise_log', 'log_value',
      'mood_log' and (when `ohe_labels` is True) 'label'.
    ohe_labels: If True, also print the batch's 'label' entry.
    log_value_offset: Added to the raw 'log_value' before printing. Defaults to
      65536, the value that used to be hard-coded here.
  """
  p_idx = 0
  batch_idx = 0
  batch = next(dataset.valid_iter)

  # Parse the example. Both signals are 5-D (the transpose permutation has five
  # axes); axes 2 and 3 are swapped so imshow draws them transposed.
  # NOTE(review): presumably (device, batch, time, feature, channel) -- confirm.
  sensor_img = jnp.transpose(batch['input_signal'], (0, 1, 3, 2, 4))
  datetime_img = jnp.transpose(batch['datetime_signal'], (0, 1, 3, 2, 4))

  print(f'Example keys {list(batch.keys())}\n')
  print('Exercise Log', batch['exercise_log'])
  print('Log Value', batch['log_value'] + log_value_offset)
  print('\nMood Log', batch['mood_log'])

  if ohe_labels:
    print(f"\nLabel {batch['label']}\n")

  # One figure per signal, showing the first example of the first slice.
  for img, title in ((sensor_img, 'Sensor Inputs'),
                     (datetime_img, 'Datetime Inputs')):
    plt.figure(figsize=(15, 4))
    plt.imshow(img[p_idx][batch_idx])
    plt.title(title)


def dataset_split_information(ds, event, offset=65536):
  """Given a dataset split, returns a dictionary of stats/metadata.

  Args:
    ds: Iterable of batch dicts. Every batch needs 'batch_mask'. For
      event='exercise'/'mood' a batch also needs the matching boolean
      'exercise_log'/'mood_log' array and 'log_value'. For event='stress' or
      None it needs 'label', whose last axis is argmax-ed into a class index.
    event: One of 'exercise', 'mood', 'stress' or None.
    offset: Added to every non-zero log value in the 'exercise'/'mood' modes;
      ignored for 'stress' and None.

  Returns:
    Dict with
      'num_examples': sum of 'batch_mask' over all batches (a Python int),
      'example_subj_ids': always an empty list (subject-ID extraction is
        disabled; the key is kept for backward compatibility),
      'log_value': flat list of Python ints, one per selected example.

  Raises:
    ValueError: If `event` is not one of the supported values.
  """
  if event not in ('exercise', 'mood', 'stress', None):
    raise ValueError(
        f'event must be exercise, mood, stress, or None; got {event!r}'
    )

  report_every = 100000
  subj_ids = []  # Never populated; see Returns.
  ex_count = 0
  log_values = []
  for d in ds:
    # Total example count.
    prev_count = ex_count
    ex_count += int(np.sum(np.asarray(d['batch_mask'])))
    # Report whenever the running count crosses a multiple of `report_every`;
    # an exact-multiple check is skipped over when batches hold > 1 example.
    if ex_count // report_every > prev_count // report_every:
      print(f'running example count: {ex_count}')

    if event in ('exercise', 'mood'):
      # Log values of examples whose exercise/mood flag is set, zeros dropped,
      # shifted by `offset`. Converting to Python int first keeps the values
      # hashable (JAX arrays are not) and avoids fixed-width integer overflow
      # when the offset is added.
      flagged = np.nonzero(np.asarray(d[f'{event}_log']))
      raw_values = np.asarray(d['log_value'])[flagged].flatten()
      log_values += [int(v) + offset for v in raw_values if v != 0]
    else:
      # 'stress' or None: class index of every example selected by batch_mask.
      class_idx = np.argmax(np.asarray(d['label']), axis=-1)
      in_mask = np.nonzero(np.asarray(d['batch_mask']))
      log_values += class_idx[in_mask].tolist()

  return {
      'num_examples': ex_count,
      'example_subj_ids': subj_ids,
      'log_value': log_values,
  }



In [None]:
# @title Exercise Log Value Map

# Lookup table from numeric log value ids to human-readable names, parsed from
# the inline CSV below (rows are `name,id`; there is no header row). Some names
# occur with several ids (e.g. Yoga: 52001 and 52000), and a few rows are mood
# labels rather than activities (Content=7, Frustrated=10, Excited=11,
# Calm=12, Stressed=13).
log_value_map = pd.read_csv(io.StringIO('''
adidas Train,59001
Zumba,56001
"Yoga, Vinyasa",52005
"Yoga, Hatha",52003
Yoga,52001
Yoga,52000
Wrestling,15730
Workout,3000
Weights,2131
Weights,2050
Weightlifting,91043
"Weight lifting (free, nautilus or universal-type), light or moderate effort, light workout, general",2130
Weeding,11360
Water volleyball,18365
Water skiing,18150
Water polo,18360
Water jogging,18366
Water aerobics,2120
Water aerobics,18355
Warm It Up,3102
Walking for pleasure,17160
Walk,90013
Volleyball,15711
Ultimate frisbee,15250
"Treadmill, 15% Incline",90022
Treadmill,20049
Treadmill,90019
Trampoline,15700
Tennis,15675
Tai chi,15670
Table tennis,15660
Tabata Workout,20055
TRX,20056
"Swimming, sidestroke, general",18320
"Swimming, leisurely, not lap swimming, general",18310
"Swimming, lake, ocean, river",18300
"Swimming laps, freestyle, fast, vigorous effort",18230
Swim,90024
"Surfing, body or board",18220
Surfing,91056
Stretching,2100
Stressed,13
Strength training,91042
"Standing; moderate/heavy (lifting more than 50 lbs, masonry, painting, paper hanging)",11630
Stairclimber,2065
Squash,15650
Sport,15000
Spinning,55001
Spinning,90002
Softball,15640
Soccer,15605
Snowshoeing,19190
Snowboarding,91051
Snorkeling,18210
Skiing,19160
Skating,91052
Skateboarding,15580
Shooting,4130
Scuba diving,18200
"Scrubbing floors, on hands and knees, scrubbing bathroom, bathtub",5130
Sailing,18120
Run,90009
Rugby,15560
Rowing machine,91041
Rowing Machine,90003
Rowing,90014
Rollerblading,91054
Roller skating,15590
Roller blading,15591
Rock climbing,17120
Rock climbing,15535
Racquetball,15530
Race walking,17110
Powerlifting,91044
Polo,15510
Pilates,53001
Pilates,53000
Paddleboarding,91057
Outdoor Workout,1072
Outdoor Bike,1071
Orienteering,15480
"Mowing lawn, walk, power mower",8120
"Mowing lawn, walk, hand mower",8110
"Mowing lawn, riding mower",8100
Mowing lawn,8095
Mountain Bike,20048
Motorcycle,16030
Meditating,7075
Martial Arts,15430
Lying quietly and watching television,7010
"Laundry, fold or hang clothes, put clothes in washer or dryer, packing suitcase",5090
Lacrosse,15460
Kickboxing,55002
Kettlebell,20053
Kayaking,18100
Jumping rope,15551
"Jogging, general",12020
Jog/walk combination (jogging component of less than 10 minutes),12010
Ironing,5070
Interval Workout,20057
Indoor climbing,91055
"Implied walking/standing -picking up yard, light, picking flowers or vegetables",8250
Ice skating,19030
Hunting,4100
Household Chores,90006
"Horseback riding, saddling horse, grooming horse",15380
Horseback riding,15370
Hockey,15360
"Hiking, cross country",17080
Hike,90012
"Health club exercise, general",2060
Handball,15320
HIIT,91040
Gymnastics,15300
"Golf, walking and pulling clubs",15285
Golf,15255
"General cleaning, moderate effort",11125
Gardening,8245
Frustrated,10
Football,15210
Fitstar: Personal Trainer,3104
Fitstar: Personal Trainer,3103
Fishing,4001
Field Hockey,15350
Fencing,15200
Excited,11
"Elliptical, Low Resistance",90016
Elliptical,20047
Elliptical,90017
"Driving heavy truck, tractor, bus",16050
Driving,16010
Diving,18090
"Digging, spading, filling garden, composting",8050
Dancing,3031
Curling,15170
CrossFit,20050
CrossFit,91045
Cross-country ski,91053
Cross Country Skiing,90015
Cricket,15150
Core training,91046
Cooking or food preparation,5052
Content,7
"Cleaning, house or cabin, general",5030
"Cleaning sink and toilet, light effort",11122
Cleaning,5020
Circuit Training,2040
Carpentry,11040
Cardio Sculpt,20051
Cardio Kickboxing,20054
Canoeing,18080
Canoeing,91059
Calm,12
"Calisthenics, home exercise, light or moderate effort, general (example: back exercises), going up and down from floor",2030
Calisthenics,2020
Boxing,15100
Bootcamp,55003
Bike,90001
Basketball,15040
Baseball,15620
Barre Class,20052
Ballet,3010
Badminton,15020
Archery,15010
Aerobics,90005
Aerobics,91047
"Aerobic, general",3015
Aerobic Workout,3001
7 Minute Workout,3100
10 Minute Abs,3101
'''), header=None)

log_value_map.columns = ['name', 'id']

# Map each numeric log value id to its display name (for a repeated id, the
# last row would win). Built with dict(zip(...)) instead of a comprehension
# whose loop variable shadowed the builtin `id`.
log_value_map_dict = dict(zip(log_value_map['id'], log_value_map['name']))

# Show Map
log_value_map

In [None]:
# @title Sample Model Config

r"""A config to train a Tiny ViT MAE on LSM 5M dataset.

Forked from google3/third_party/py/scenic/projects/multimask/configs/mae_cifar10_tiny.py

To run on XManager:
gxm third_party/py/scenic/google/xm/launch_xm.py -- \
--binary //experimental/largesensormodels/scenic:main \
--config=experimental/largesensormodels/scenic/configs/mae_lsm_tiny.py \
--platform=vlp_4x4 \
--exp_name=lsm_mae_tier2_TinyShallow_10_5_res \
--workdir=/cns/dz-d/home/xliucs/lsm/xm/\{xid\} \
--xm_resource_alloc=group:mobile-dynamic/h2o-ai-gqm-quota \
--priority=200

To run locally:
./third_party/py/scenic/google/runlocal.sh \
--uptc="" \
--binary=//experimental/largesensormodels/scenic:main \
--config=$(pwd)/experimental/largesensormodels/scenic/configs/mae_lsm_tiny.py:runlocal
"""


# To set constants.
DATASET_NAME = 'lsm_300min_pretraining_165K_n10'
CACHE_DATASET = False
TRAIN_DATA_SIZE = 1_000_000  # Upper bound (1M); clipped to the dataset below.
BATCH_SIZE = 8
NUMBER_OF_EPOCH = 100
REPEAT_DATA = False

# Model variant / patch H (time steps) / patch W (features)
VARIANT = 'TiShallow/10/5'
LRS = [1e-3]
TOKEN_MASK_PROB = 'constant_0.8'
LOSS_ONLY_MASKED_TOKENS = True
USE_DATETIME_FEATURES = True
USE_TRAIN_AUGMENTATIONS = True
TRAIN_AUGMENTATIONS = ['stretch', 'flip', 'noise']
OHE_LABELS = True

# Derived constants.
# Never use more examples than the dataset's train split provides.
TRAIN_DATA_SIZE = min(
    TRAIN_DATA_SIZE,
    dataset_constants.lsm_dataset_constants[DATASET_NAME]['num_train_examples']
)

STEPS_PER_EPOCH = max(1, int(TRAIN_DATA_SIZE / BATCH_SIZE))
NUM_TRAIN_STEPS = int(NUMBER_OF_EPOCH * STEPS_PER_EPOCH)

# Evaluate/summarize once per epoch; checkpoint every 5 epochs.
LOG_EVAL_SUMMARY_STEPS = STEPS_PER_EPOCH
LOG_CHECKPOINT_STEPS = LOG_EVAL_SUMMARY_STEPS * 5
# Keep every checkpoint of the run, but at least one: for runs shorter than one
# checkpoint interval the int() division would otherwise round down to 0.
MAX_NUM_CHECKPOINTS = max(1, int(NUM_TRAIN_STEPS / LOG_CHECKPOINT_STEPS))


def get_config_common_few_shot(
    batch_size: Optional[int] = None,
    target_resolution: int = 224,
    resize_resolution: int = 256,
) -> ml_collections.ConfigDict:
  """Builds the generic (image-benchmark) few-shot evaluation config.

  A copy of
  third_party/py/scenic/projects/baselines/configs/google/common/common_fewshot.py;
  the dataset list and preprocessing below are image specific.

  Args:
    batch_size: Batch size for few-shot evaluation.
    target_resolution: Central-crop size used by the preprocessing strings.
    resize_resolution: Resize size applied before the central crop.

  Returns:
    A ConfigDict holding the few-shot evaluation settings.
  """
  # Train and eval share the same preprocessing pipeline string.
  preprocess = (
      f'decode|resize({resize_resolution})'
      f'|central_crop({target_resolution})|value_range(-1,1)'
  )

  config = ml_collections.ConfigDict()
  config.batch_size = batch_size
  config.representation_layer = 'pre_logits'
  config.log_eval_steps = 25_000
  config.datasets = {
      'birds': ('caltech_birds2011', 'train', 'test'),
      'caltech': ('caltech101', 'train', 'test'),
      'cars': ('cars196:2.1.0', 'train', 'test'),
      'cifar100': ('cifar100', 'train', 'test'),
      'col_hist': ('colorectal_histology', 'train[:2000]', 'train[2000:]'),
      'dtd': ('dtd', 'train', 'test'),
      'imagenet': ('imagenet2012_subset/10pct', 'train', 'validation'),
      'pets': ('oxford_iiit_pet', 'train', 'test'),
      'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
  }
  config.pp_train = preprocess
  config.pp_eval = preprocess
  config.shots = [1, 5, 10, 25]
  # L2 regularization grid: 2**-10 ... 2**19.
  config.l2_regs = [2.0**exponent for exponent in range(-10, 20)]
  config.walk_first = ('imagenet', 10)
  return config


def get_config(runlocal=''):
  """Returns the ViT MAE experiment configuration.

  The config is built from the module-level constants (DATASET_NAME,
  TRAIN_DATA_SIZE, BATCH_SIZE, VARIANT, LRS, ...) as they are at call time, so
  the notebook can re-assign them and call this function again.

  Args:
    runlocal: Any truthy value selects a small local-debug setup: eval on
      'test[:64]', the 'Deb' model variant, batch size 8 and no FLOP counting.

  Returns:
    An ml_collections.ConfigDict.

  Raises:
    ValueError: If VARIANT is not of the form 'Name/P' or 'Name/H/W'.
  """

  runlocal = bool(runlocal)

  # Experiment.
  config = ml_collections.ConfigDict()
  config.experiment_name = f'electrodes-mae-{DATASET_NAME}-{TRAIN_DATA_SIZE}'
  config.dataset_name = f'lsm_prod/{DATASET_NAME}'

  # Dataset.
  config.data_dtype_str = 'float32'
  config.dataset_configs = ml_collections.ConfigDict()
  config.dataset_configs.dataset = f'lsm_prod/{DATASET_NAME}'
  # config.dataset_configs.num_classes = NUM_CLASSES
  config.dataset_configs.train_split = 'train'  # train data split
  config.dataset_configs.train_num_samples = TRAIN_DATA_SIZE  # train sample
  # eval data split - note: this split is used for validation and test.
  config.dataset_configs.eval_split = 'test[:64]' if runlocal else 'test'
  config.dataset_configs.cache_dataset = CACHE_DATASET
  config.dataset_configs.prefetch_to_device = 2
  # Shuffle_buffer_size is per host, so small-ish is ok.
  config.dataset_configs.shuffle_buffer_size = 250_000
  config.dataset_configs.repeat_data = REPEAT_DATA
  config.dataset_configs.ohe_labels = OHE_LABELS

  # Model. VARIANT is 'Name/H/W' (patch height / patch width) or 'Name/P'
  # (square P x P patches).
  variant_parts = VARIANT.split('/')
  if len(variant_parts) == 3:
    version, patch_h, patch_w = variant_parts
  elif len(variant_parts) == 2:
    version, patch_h = variant_parts
    patch_w = patch_h
  else:
    raise ValueError(f'Invalid model variant: {VARIANT}')

  version = 'Deb' if runlocal else version
  config.model_name = 'lsm_vit_mae'
  config.model = ml_collections.ConfigDict()
  # encoder
  config.model.hidden_size = model_constants.HIDDEN_SIZES[version]
  config.model.patches = ml_collections.ConfigDict()
  config.model.patches.size = tuple([int(patch_h), int(patch_w)])
  config.model.num_heads = model_constants.NUM_HEADS[version]
  config.model.mlp_dim = model_constants.MLP_DIMS[version]
  config.model.num_layers = model_constants.NUM_LAYERS[version]
  config.model.dropout_rate = 0.0
  config.model.classifier = 'none'  # Has to be "none" for the autoencoder
  config.model.representation_size = None
  config.model.positional_embedding = 'sinusoidal_2d'
  config.model.positional_embedding_decoder = 'sinusoidal_2d'
  # decoder
  config.model.decoder_config = ml_collections.ConfigDict()
  config.model.decoder_config.hidden_size = (
      model_constants.DECODER_HIDDEN_SIZES[version]
  )
  config.model.decoder_config.mlp_dim = model_constants.DECODER_MLP_DIMS[
      version
  ]
  config.model.decoder_config.num_layers = model_constants.DECODER_NUM_LAYERS[
      version
  ]
  config.model.decoder_config.num_heads = model_constants.DECODER_NUM_HEADS[
      version
  ]
  config.model.decoder_config.dropout_rate = 0.0
  config.model.decoder_config.attention_dropout_rate = 0.0

  config.masked_feature_loss = ml_collections.ConfigDict()
  config.masked_feature_loss.targets_type = 'rgb'
  config.masked_feature_loss.token_mask_probability = TOKEN_MASK_PROB
  config.masked_feature_loss.loss_only_masked_tokens = LOSS_ONLY_MASKED_TOKENS
  config.masked_feature_loss.loss_type = 'squared'  # 'squared' or 'absolute'

  # Datetime features.
  config.use_datetime_features = USE_DATETIME_FEATURES

  # Training.
  config.trainer_name = 'lsm_mae_trainer'
  config.batch_size = 8 if runlocal else BATCH_SIZE
  config.num_training_steps = NUM_TRAIN_STEPS
  config.log_eval_steps = LOG_EVAL_SUMMARY_STEPS
  config.log_summary_steps = LOG_EVAL_SUMMARY_STEPS
  config.rng_seed = 42
  config.use_train_augmentations = USE_TRAIN_AUGMENTATIONS
  config.train_augmentations = TRAIN_AUGMENTATIONS
  sched = ml_collections.ConfigDict()
  sched.re = '(.*)'
  sched.lr_configs = ml_collections.ConfigDict()
  sched.lr_configs.learning_rate_schedule = 'compound'
  sched.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
  sched.lr_configs.total_steps = NUM_TRAIN_STEPS
  sched.lr_configs.steps_per_cycle = sched.lr_configs.total_steps
  # Linear warmup over the first 5% of all training steps (a float value).
  sched.lr_configs.warmup_steps = STEPS_PER_EPOCH * NUMBER_OF_EPOCH * 0.05
  sched.lr_configs.base_learning_rate = LRS[0]
  config.schedule = ml_collections.ConfigDict({'all': sched})

  # *Single* optimizer.
  optim = ml_collections.ConfigDict()
  optim.optax_name = 'scale_by_adam'
  # optim.optax = dict(mu_dtype='bfloat16')
  optim.optax_configs = ml_collections.ConfigDict({  # Optimizer settings.
      'b1': 0.9,
      'b2': 0.95,
  })
  # NOTE(review): this sets the top-level `config.optax`, not `optim.optax`
  # (commented out above). Confirm which key the trainer reads; if it only
  # reads `config.optimizer.optax`, this bfloat16 mu_dtype has no effect.
  config.optax = dict(mu_dtype='bfloat16')
  optim.max_grad_norm = 1.0

  optim.weight_decay = 1e-4
  optim.weight_decay_decouple = True
  config.optimizer = optim

  # Fewshot.
  # TODO(girishvn): This needs to be adapted to electrode dataset
  config.fewshot = get_config_common_few_shot(batch_size=config.batch_size)
  config.fewshot.datasets = {}
  config.fewshot.walk_first = ()
  config.fewshot.representation_layer = 'pre_logits'
  config.fewshot.log_eval_steps = LOG_EVAL_SUMMARY_STEPS

  # Logging.
  config.write_summary = True
  config.xprof = True  # Profile using xprof.
  config.checkpoint = True  # Do checkpointing.
  config.checkpoint_steps = LOG_CHECKPOINT_STEPS
  config.debug_train = False  # Debug mode during training.
  config.debug_eval = False  # Debug mode during eval.
  config.max_checkpoints_to_keep = MAX_NUM_CHECKPOINTS
  # BEGIN GOOGLE-INTERNAL
  if runlocal:
    # Current implementation fails with UPTC.
    config.count_flops = False
  # END GOOGLE-INTERNAL

  return config


# BEGIN GOOGLE-INTERNAL
def get_hyper(hyper):
  """Returns the grid-search space: one sweep over the learning rates in LRS."""
  lr_sweep = hyper.sweep(
      'config.schedule.all.lr_configs.base_learning_rate', LRS
  )
  return hyper.product([lr_sweep])


In [None]:
# @title Data Dir

# Root directory passed as `data_dir` to the tfds.load calls below.
data_dir = (
    '/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/'
    'deid/exp/dmcduff/ttl=6w/msa_1_5/lsm_tfds_datasets'
)

# Mood / Stress Datasets

## 2000 Mood Dataset Explorer

In [None]:
# @title Load Dataset

DATASET_NAME = 'lsm_300min_2000_mood_balanced'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

# runlocal must stay False: a truthy runlocal makes get_config evaluate on
# 'test[:64]' only, switch to the 'Deb' variant and use batch size 8.
config = get_config(runlocal=False)
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)

print('Processed Dataset Meta Data:\n')
for key, value in dataset.meta_data.items():
  print(key, value)


In [None]:
# @title Dataset Information

dataset_name = f'lsm_prod/{DATASET_NAME}'
# Without `split`, tfds.load returns every split (unshuffled) in `ds`; only
# `info` is printed in this cell.
ds, info = tfds.load(
    dataset_name,
    data_dir=data_dir,
    shuffle_files=False,
    with_info=True,
)
print('Dataset Information:\n')
print(info)

# Print metadata of, and plot, one batch from the processed valid iterator.
print('\n\nExample Sample Data:\n')
explore_sample_data(dataset, OHE_LABELS)

In [None]:
# @title Mood Label Breakdown

# Rebuild the dataset: the cell above already pulled a batch from valid_iter
# (presumably get_dataset returns fresh iterators -- confirm).
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)
train_info = dataset_split_information(
    dataset.train_iter, event='mood', offset=65536)
valid_info = dataset_split_information(
    dataset.valid_iter, event='mood', offset=65536)

train_counter = Counter(train_info['log_value'])
valid_counter = Counter(valid_info['log_value'])
test_counter = valid_counter  # The eval split serves as both valid and test.

print('Num Examples:')
print('Train Samples', train_info['num_examples'])
print('Valid Samples', valid_info['num_examples'])

print('\nDataset Breakdown:')
train_df = pd.DataFrame.from_dict({
    'log_value': list(train_counter.keys()),
    'train_count': list(train_counter.values()),
})
test_df = pd.DataFrame.from_dict({
    'log_value': list(test_counter.keys()),
    'test_count': list(test_counter.values()),
})
mood_data2000_df = train_df.merge(test_df, on='log_value', how='outer')
mood_data2000_df.sort_values(by=['log_value'], ascending=True)


## 2000 Stress Dataset Explorer

In [None]:
# @title Load Dataset

DATASET_NAME = 'lsm_300min_2000_stress_balanced'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

# Build the config from the constants above; runlocal=False keeps the full
# 'test' eval split.
config = get_config(runlocal=False)
base_rng = jax.random.PRNGKey(config.rng_seed)
data_rng, rng = jax.random.split(base_rng)
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)

print('Processed Dataset Meta Data:\n')
for name in dataset.meta_data:
  print(name, dataset.meta_data[name])

In [None]:
# @title Stress Label Breakdown

dataset = lsm_tiny_dataset.get_dataset(config, data_rng)
# For 'stress' the values are argmax class indices of the label, so no offset.
train_info = dataset_split_information(dataset.train_iter, event='stress', offset=0)
valid_info = dataset_split_information(dataset.valid_iter, event='stress', offset=0)

train_counter, valid_counter = (
    Counter(train_info['log_value']), Counter(valid_info['log_value']))
test_counter = valid_counter  # The eval split serves as both valid and test.

print('Num Examples:')
for split_name, split_info in (('Train', train_info), ('Valid', valid_info)):
  print(f'{split_name} Samples', split_info['num_examples'])

print('\nDataset Breakdown:')
train_df = pd.DataFrame.from_dict({
    'log_value': list(train_counter),
    'train_count': list(train_counter.values()),
})
test_df = pd.DataFrame.from_dict({
    'log_value': list(test_counter),
    'test_count': list(test_counter.values()),
})
stress_data2000_df = pd.merge(train_df, test_df, on='log_value', how='outer')
stress_data2000_df


# Activity Datasets

In [None]:
# @title Load Dataset

DATASET_NAME = 'lsm_300min_600_activities_balanced_v4'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

config = get_config(runlocal=False)  # runlocal=False -> full 'test' eval split
seed_rng = jax.random.PRNGKey(config.rng_seed)
data_rng, rng = jax.random.split(seed_rng)
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)

print('Processed Dataset Meta Data:\n')
for meta_key, meta_value in dataset.meta_data.items():
  print(meta_key, meta_value)

In [None]:
# @title Dataset Information
OHE_LABELS = True  # Also print the sample batch's 'label' below.
dataset_name = f'lsm_prod/{DATASET_NAME}'
# All splits, unshuffled; only `info` is printed here.
ds, info = tfds.load(
    dataset_name,
    data_dir=data_dir,
    shuffle_files=False,
    with_info=True,
)
print('Dataset Information:\n')
print(info)

# One batch from the processed valid iterator.
print('\n\nExample Sample Data:\n')
explore_sample_data(dataset, OHE_LABELS)

In [None]:
# @title Exercise Label Breakdown

# Rebuild the dataset so the iterators are not partially consumed by the cell
# above (presumably get_dataset returns fresh iterators -- confirm).
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)
split_kwargs = dict(event='exercise', offset=65536)
train_info = dataset_split_information(dataset.train_iter, **split_kwargs)
valid_info = dataset_split_information(dataset.valid_iter, **split_kwargs)

train_counter = Counter(train_info['log_value'])
valid_counter = Counter(valid_info['log_value'])

print('Num Examples:')
print('Train Samples', train_info['num_examples'])
print('Valid Samples', valid_info['num_examples'])

print('\nDataset Breakdown:')
train_df = pd.DataFrame.from_dict({
    'log_value': list(train_counter.keys()),
    'train_count': list(train_counter.values()),
})
test_df = pd.DataFrame.from_dict({
    'log_value': list(valid_counter.keys()),
    'test_count': list(valid_counter.values()),
})
data_activity_large_df = train_df.merge(test_df, on='log_value', how='outer')
data_activity_large_df


# Few-shot Exercise Detection (Activity vs. Mood)

In [None]:
DATASET_NAME = 'fewshot_lsm_300min_mood_vs_activity'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

config = get_config(runlocal=False)  # runlocal=False -> full 'test' eval split
config.update({'fewshot_samples_per_class': 2})
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_fewshot_mood_vs_activity_dataset.get_dataset(config, data_rng)

print('Processed Dataset Meta Data:\n')
for key in dataset.meta_data:
  print(key, dataset.meta_data[key])


In [None]:
# NOTE(review): presumably a cap large enough to load every example per class
# (the previous cell used 2) -- confirm against the dataset module.
config.update({'fewshot_samples_per_class': 6200})

# `get_dataset` is not defined at notebook level (the original call raised
# NameError); use the fewshot dataset module, as the loading cell does.
dataset = lsm_fewshot_mood_vs_activity_dataset.get_dataset(config, data_rng)
# event=None counts argmax class indices of 'label'; `offset` is unused there.
train_info = dataset_split_information(dataset.train_iter, event=None, offset=65536)
valid_info = dataset_split_information(dataset.valid_iter, event=None, offset=65536)

train_counter = Counter(train_info['log_value'])
valid_counter = Counter(valid_info['log_value'])
test_counter = valid_counter  # The eval split serves as both valid and test.

print('Num Examples:')
print('Train Samples', train_info['num_examples'])
print('Valid Samples', valid_info['num_examples'])

print('\nDataset Breakdown:')
train_df = pd.DataFrame.from_dict({'log_value': train_counter.keys(), 'train_count': train_counter.values()})
test_df = pd.DataFrame.from_dict({'log_value': test_counter.keys(), 'test_count': test_counter.values()})
# NOTE(review): this overwrites `mood_data2000_df` from the mood section; use a
# distinct name if both tables are needed later.
mood_data2000_df = pd.merge(train_df, test_df, on='log_value', how='outer')
mood_data2000_df.sort_values(by=['log_value'], ascending=True)

# Biological Sex Dataset (Derived from Activity Dataset)

In [None]:
# @title Get Dataset

DATASET_NAME = 'lsm_300min_600_biological_sex'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

# runlocal must be False; otherwise only 'test[:64]' is used for eval.
config = get_config(runlocal=False)

data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))

# Time how long building the dataset takes.
start_t = time.time()
dataset = lsm_biological_sex_dataset.get_dataset(config, data_rng)
end_t = time.time()
print('Dataset Time', end_t - start_t)

print('\nProcessed Dataset Meta Data:\n')
for key, value in dataset.meta_data.items():
  print(key, value)

In [None]:
# @title Train Sample  Breakdown

state_t = time.time()

# Class index (argmax of the one-hot 'label') of every train example whose
# batch_mask entry is 1.
label_list = []
batch_count = 0
for d in dataset.train_iter:
  if batch_count % 1000 == 0:
    print(batch_count, time.time())  # Progress: batches seen, wall-clock time.
  batch_count += 1

  valid = np.where(d['batch_mask'] == 1)
  label_list += jnp.argmax(d['label'], axis=-1)[valid].tolist()

end_t = time.time()

print('Time', end_t - state_t)
print('\nTrain Data Splits:')
# Examples per label, in first-seen order.
mood_counter = Counter(label_list)
for k, count in mood_counter.items():
  print(k, count)

In [None]:
# @title Valid Sample Breakdown

state_t = time.time()

# Class index (argmax of the one-hot 'label') of every validation example whose
# batch_mask entry is 1.
label_list = []
batch_count = 0
for d in dataset.valid_iter:
  if batch_count % 1000 == 0:
    print(batch_count, time.time())  # Progress: batches seen, wall-clock time.
  batch_count += 1

  valid = np.where(d['batch_mask'] == 1)
  label_list += jnp.argmax(d['label'], axis=-1)[valid].tolist()

end_t = time.time()

print('Time', end_t - state_t)
print('\nValid Data Splits:')
# Examples per label, in first-seen order.
mood_counter = Counter(label_list)
for k, count in mood_counter.items():
  print(k, count)

# Age Dataset (Derived from Activity Dataset)

In [None]:
# @title Get Dataset

# NOTE(review): 'binnned' (three n's) must match the registered dataset name;
# confirm it is intentional.
DATASET_NAME = 'lsm_300min_600_binnned_age'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

config = get_config(runlocal=False)  # runlocal=False -> full 'test' eval split

seed_rng = jax.random.PRNGKey(config.rng_seed)
data_rng, rng = jax.random.split(seed_rng)

# Time how long building the dataset takes.
start_t = time.time()
dataset = lsm_binned_age_dataset.get_dataset(config, data_rng)
end_t = time.time()
print('Dataset Time', end_t - start_t)

print('\nProcessed Dataset Meta Data:\n')
for meta_key in dataset.meta_data:
  print(meta_key, dataset.meta_data[meta_key])

In [None]:
# @title Train Sample  Breakdown

state_t = time.time()

# Class index (argmax of the one-hot 'label') of every train example whose
# batch_mask entry is 1.
label_list = []
batch_count = 0
for d in dataset.train_iter:
  if batch_count % 1000 == 0:
    print(batch_count, time.time())  # Progress: batches seen, wall-clock time.
  batch_count += 1

  valid = np.where(d['batch_mask'] == 1)
  label_list += jnp.argmax(d['label'], axis=-1)[valid].tolist()

end_t = time.time()

print('Time', end_t - state_t)
print('\nTrain Data Splits:')
# Examples per label (age bin), in first-seen order.
mood_counter = Counter(label_list)
for k, count in mood_counter.items():
  print(k, count)

In [None]:
# @title Test Sample  Breakdown

state_t = time.time()

# Class index (argmax of the one-hot 'label') of every example of the eval
# split (served through valid_iter) whose batch_mask entry is 1.
label_list = []
batch_count = 0
for d in dataset.valid_iter:
  if batch_count % 1000 == 0:
    print(batch_count, time.time())  # Progress: batches seen, wall-clock time.
  batch_count += 1

  valid = np.where(d['batch_mask'] == 1)
  label_list += jnp.argmax(d['label'], axis=-1)[valid].tolist()

end_t = time.time()

print('Time', end_t - state_t)
# This cell iterates valid_iter, so label the counts accordingly (previously it
# was mislabelled as the train split).
print('\nValid Data Splits:')
# Examples per label (age bin), in first-seen order.
mood_counter = Counter(label_list)
for k, count in mood_counter.items():
  print(k, count)

# Subject Dependent Mood Dataset

## BUILD DATASET

In [None]:
# @title HELPERS

def _bytestring_feature(list_of_bytestrings) -> tf.train.Feature:
  """Wraps a list of byte strings in a `tf.train.Feature` (bytes_list)."""
  bytes_list = tf.train.BytesList(value=list_of_bytestrings)
  return tf.train.Feature(bytes_list=bytes_list)


def _int_feature(list_of_ints) -> tf.train.Feature:
  """Wraps a list of ints in a `tf.train.Feature` (int64_list)."""
  int64_list = tf.train.Int64List(value=list_of_ints)
  return tf.train.Feature(int64_list=int64_list)


def _float_feature(list_of_floats) -> tf.train.Feature:
  """Wraps a list of floats in a `tf.train.Feature` (float_list)."""
  float_list = tf.train.FloatList(value=list_of_floats)
  return tf.train.Feature(float_list=float_list)


def get_useful_fields(example):
  """Flattens an example into the five fields that `serialize_example` takes.

  'input_signal' and 'label' are copied from the top level of `example`;
  'exercise_log', 'mood_log' and 'log_value' are lifted out of
  `example['metadata']`.
  """
  metadata = example['metadata']
  fields = {
      'input_signal': example['input_signal'],
      'label': example['label'],
  }
  for key in ('exercise_log', 'mood_log', 'log_value'):
    fields[key] = metadata[key]
  return fields


def serialize_example(input_signal, label, exercise_log, mood_log, log_value):
  """Serializes one example to a `tf.train.Example` byte string.

  Called through `tf.py_function` (see `tf_serialize_example`), so
  `input_signal` must provide `.numpy()`; it is flattened into a float feature.
  The other four arguments are each stored as a one-element int64 feature.
  """
  features = tf.train.Features(feature={
      'input_signal': _float_feature(input_signal.numpy().flatten().tolist()),
      'label': _int_feature([label]),
      'exercise_log': _int_feature([exercise_log]),
      'mood_log': _int_feature([mood_log]),
      'log_value': _int_feature([log_value]),
  })
  return tf.train.Example(features=features).SerializeToString()


# Wrap `serialize_example` so it can be used inside `tf.data.Dataset.map`.
def tf_serialize_example(example):
  """Serializes a flattened example from graph code via `tf.py_function`.

  Args:
    example: A dict as produced by `get_useful_fields`.

  Returns:
    A scalar `tf.string` tensor holding the serialized `tf.train.Example`.
  """
  field_order = ('input_signal', 'label', 'exercise_log', 'mood_log', 'log_value')
  # tf.py_function receives the fields positionally, in the order that
  # `serialize_example` declares its parameters.
  return tf.py_function(
      serialize_example,
      [example[name] for name in field_order],
      tf.string,
  )


def add_key(example):
  """Adds a `key` field of the form '<ID>_<DT>' to `example`.

  The key format matches the '<subject id>_<datetime>' strings built for the
  train/test split lists, so `filter_by_datetime_key` can match on it.

  Args:
    example: Dataset element whose `metadata` holds a string `ID` and a
      numeric `DT`.

  Returns:
    The same `example` dict with the added `key` entry.
  """
  metadata = example['metadata']
  dt_as_string = tf.strings.as_string(metadata['DT'])
  example['key'] = tf.strings.join([metadata['ID'], dt_as_string], separator="_")
  return example


def filter_by_datetime_key(example, allowed_keys):
  """`Dataset.filter` predicate: True iff `example['key']` is in `allowed_keys`."""
  matches = tf.math.equal(example['key'], allowed_keys)
  return tf.reduce_any(matches)


In [None]:
# @title LOAD DATASET

dataset_name = 'lsm_prod/lsm_300min_2000_mood_balanced'
data_dir = (
    '/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/'
    'deid/exp/dmcduff/ttl=6w/msa_1_5/lsm_tfds_datasets'
)

# Load both splits in one call (a list of splits yields a list of datasets);
# `info` describes the dataset as a whole.
(train_ds, test_ds), info = tfds.load(
    dataset_name,
    data_dir=data_dir,
    split=['train', 'test'],
    shuffle_files=False,
    with_info=True,
)

# Pool the original splits; the subject-dependent split is rebuilt below.
ds = train_ds.concatenate(test_ds)


# DECODE SUBJECT ID

def subj_decode_example(example):
  """Round-trips `metadata['ID']` and `metadata['DT']` through Python.

  `ID` is decoded from bytes to a UTF-8 `str` and handed back to TensorFlow
  with declared output dtype `tf.string`. `DT` is read out with `.numpy()`
  and handed back with declared output dtype `tf.int64`. For UTF-8 IDs the ID
  values are unchanged by the round trip.
  NOTE(review): the net effect appears to be a UTF-8 check on ID and pinning
  DT's dtype to int64 (presumably it already is int64 — confirm). Outputs of
  `tf.py_function` lose static shape information and execute in the Python
  interpreter; confirm this step is still needed.

  Args:
    example: Dataset element with a nested `metadata` dict holding `ID` and
      `DT` tensors.

  Returns:
    The same `example` dict with `metadata['ID']` and `metadata['DT']`
    replaced by the `tf.py_function` outputs.
  """
  subj_id = example['metadata']['ID']
  dt = example['metadata']['DT']

  def decode_id(id_tensor):
    # bytes -> str; tf.py_function converts the result back to a tf.string.
    return id_tensor.numpy().decode('utf-8')

  def decode_dt(dt_tensor):
    # Materialize the value; tf.py_function re-wraps it as tf.int64.
    return dt_tensor.numpy()

  example['metadata']['ID'] = tf.py_function(decode_id, [subj_id], tf.string)
  example['metadata']['DT'] = tf.py_function(decode_dt, [dt], tf.int64)
  return example

ds = ds.map(subj_decode_example)



In [None]:
# @title GET DATASET STATS

# Builds `event_dict`, grouping each subject's samples into "events". A sample
# joins the first existing event whose anchor timestamp is within +/- `t_wind`
# of its own DT; otherwise it starts a new event anchored at its own DT.
start_t = time.time()
t_wind = 30 * 60  # 30 minutes, assuming DT is in seconds -- TODO confirm units
event_dict = {}

for d in ds:
  md = d['metadata']
  # Named `subj_id` rather than `id` so the builtin `id` is not shadowed in the
  # notebook's global namespace.
  subj_id = md['ID'].numpy().decode('utf-8')
  dt = md['DT'].numpy()

  if subj_id not in event_dict:
    # First sample of this subject. 'subj_count' is always the first key,
    # followed by one key per event anchor timestamp.
    event_dict[subj_id] = {'subj_count': 1, dt: [dt]}
    continue

  subj_events = event_dict[subj_id]
  subj_events['subj_count'] += 1

  # Attach the sample to the first event whose anchor timestamp is within the
  # window around this sample's timestamp.
  for anchor_dt, event_dts in subj_events.items():
    if anchor_dt == 'subj_count':
      continue
    if anchor_dt - t_wind <= dt <= anchor_dt + t_wind:
      event_dts.append(dt)
      break
  else:
    # No nearby event exists: start a new one anchored at this timestamp.
    subj_events[dt] = [dt]

end_t = time.time()
print('Time', end_t - start_t)


# Stats in the form of:
# SUBJ ID 1
#     subj_count: X
#     DT1: []
#     DT2: []
#     ...
#     DTY: []
#
# ...
#
# SUBJ ID N
#


In [None]:
# @title SPLIT INTO TRAIN AND TEST SAMPLES

# Event-level split: for each subject with enough samples and events, shuffle
# the subject's events and assign whole events to train or test. Samples that
# belong to the same event therefore never straddle the split.
random.seed(42)
min_subj_sample_count = 30
min_subj_event_count = 5

idx = 0
total_samples = 0

p_train = 0.8
p_test = 0.2  # Informational only: the test share is the rest of the events.

train_dt_list = []
test_dt_list = []

# Iterate through subjects in insertion order (keeps the seeded shuffle stable).
for subj_id, event_info in event_dict.items():
  sample_count = event_info['subj_count']  # number of samples for the subject
  event_count = len(event_info) - 1  # every key except 'subj_count' is an event

  if sample_count < min_subj_sample_count or event_count < min_subj_event_count:
    continue

  # Event anchor timestamps: all keys except the sample counter.
  events_dt_list = list(event_info)
  events_dt_list.remove('subj_count')

  # Shuffle events and split them.
  num_train = int(p_train * len(events_dt_list))
  random.shuffle(events_dt_list)
  train_keys = events_dt_list[:num_train]
  test_keys = events_dt_list[num_train:]

  # Every sample timestamp in an event inherits the event's split. The keys
  # use the '<ID>_<DT>' format that `add_key` produces.
  for k in train_keys:
    train_dt_list += [subj_id + '_' + str(dt) for dt in event_info[k]]
  for k in test_keys:
    test_dt_list += [subj_id + '_' + str(dt) for dt in event_info[k]]

  # Logging
  print(f'{idx}: {subj_id}, {sample_count}, {event_count}')
  total_samples += sample_count
  idx += 1

print('\nTotal Samples', total_samples)
print('Num Train Sample', len(train_dt_list))
print('Num Test Sample', len(test_dt_list))

# Sanity check: no key may appear in both splits. Membership is tested
# against a set (the list version is quadratic), iterating train keys in order
# so any overlaps are printed in the same order as before.
test_key_set = set(test_dt_list)
count = 0
for t in train_dt_list:
  if t in test_key_set:
    print(t)
    count += 1
print('Train Test Overlap Samples', count)

In [None]:
# @title Filter and Split into Train Test
# Attach the '<ID>_<DT>' key used by the split lists.
ds = ds.map(add_key)

# Keep only the examples whose key was assigned to each split.
train_ds = ds.filter(
    functools.partial(filter_by_datetime_key, allowed_keys=train_dt_list)
)
test_ds = ds.filter(
    functools.partial(filter_by_datetime_key, allowed_keys=test_dt_list)
)

# Reduce every example to the fields that get serialized.
train_ds, test_ds = (
    train_ds.map(get_useful_fields),
    test_ds.map(get_useful_fields),
)

# Count samples per split (each count is a full pass over the data).
train_count = sum(1 for _ in train_ds)
test_count = sum(1 for _ in test_ds)
print('Train Count', train_count)
print('Test Count', test_count)

In [None]:
# @title Save Dataset Out In Example Form

# Lazily map every example to its serialized `tf.train.Example` bytes. Nothing
# is computed until the datasets are iterated while writing the TFRecords.
serialized_train_ds, serialized_test_ds = (
    split_ds.map(tf_serialize_example) for split_ds in (train_ds, test_ds)
)

In [None]:
# @title Save Examples to TF Records

# NOTE: Use a dedicated name for the output directory. Reassigning `data_dir`
# here would clobber the TFDS source directory that later cells pass to
# `tfds.load(..., data_dir=data_dir)`.
output_dir = '/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/exp/girishvn/ttl=6w/lsm_processed_datasets/'
train_fname = 'processed_subj_dependent_mood_train.tfrecord'
test_fname = 'processed_subj_dependent_mood_test.tfrecord'

train_fpath = os.path.join(output_dir, train_fname)
val_fpath = os.path.join(output_dir, test_fname)


def _write_tfrecords(serialized_ds, fpath, log_every=100):
  """Writes serialized examples to a TFRecord file and returns the count.

  Args:
    serialized_ds: Iterable of scalar string tensors (serialized Examples).
    fpath: Output TFRecord path.
    log_every: Print progress after every `log_every` written examples.

  Returns:
    The number of examples written.
  """
  written = 0
  write_start = time.time()
  with tf.io.TFRecordWriter(fpath) as writer:
    for serialized_example in serialized_ds:
      writer.write(serialized_example.numpy())
      written += 1
      if written % log_every == 0:
        print(f'Processed {written} examples in {time.time() - write_start} s.')
  return written


# Write out train data
start_t = time.time()
train_count = _write_tfrecords(serialized_train_ds, train_fpath)
end_t = time.time()
print('Train Time', end_t - start_t)

# Write out test data
start_t = time.time()
val_count = _write_tfrecords(serialized_test_ds, val_fpath)
end_t = time.time()
print('Test Time', end_t - start_t)
print('Train Example Count', train_count)
print('Test Example Count', val_count)


In [None]:
# Re-print the summary from the most recent TFRecord write.
for caption, value in (
    ('Test Time', end_t - start_t),
    ('Train Example Count', train_count),
    ('Test Example Count', val_count),
):
  print(caption, value)

## LOAD PREPROCESSED DATA

In [None]:
DATASET_NAME = 'lsm_300min_2000_mood_subject_dependent_preprocessed_40sps'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

# runlocal must be False to load the full dataset.
config = get_config(runlocal=False)
# Optional override (disabled):
# config.dataset_configs.update({'samples_per_subject': 40, 'repeat': False})

data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))

start_t = time.time()
dataset = (
    lsm_mood_subj_dependent_preprocessed_40sps_dataset.get_preprocessed_dataset(
        config, data_rng
    )
)
end_t = time.time()
print('Dataset Time', end_t - start_t)

print('\nProcessed Dataset Meta Data:\n')
for meta_key, meta_value in dataset.meta_data.items():
  print(meta_key, meta_value)


In [None]:
# @title Train Sample  Breakdown

state_t = time.time()

label_list = []
batch_count = 0
for d in dataset.train_iter:
  if not batch_count % 1000:
    print(batch_count, time.time())
  batch_count += 1

  # Rows with batch_mask == 1 are real samples; the rest is padding.
  valid = np.where(d['batch_mask'] == 1)
  # One-hot labels -> class indices, restricted to the valid rows.
  label_list.extend(jnp.argmax(d['label'], axis=-1)[valid].tolist())

end_t = time.time()

print('Time', end_t - state_t)
print('\nTrain Data Splits:')
mood_counter = Counter(label_list)
for label_value, num_samples in mood_counter.items():
  print(label_value, num_samples)

In [None]:
# @title Test Sample  Breakdown

state_t = time.time()

label_list = []
batch_count = 0
for d in dataset.valid_iter:
  if batch_count % 1000 == 0:
    print(batch_count, time.time())
  batch_count += 1

  bmask = d['batch_mask']
  valid = np.where(bmask == 1)

  # Labels are one-hot encoded: recover class indices, then keep only the real
  # (non-padded) rows selected by the batch mask. `np.argmax` keeps this on
  # the host instead of dispatching a JAX op for every batch.
  label = np.argmax(d['label'], axis=-1)
  label_list += label[valid].tolist()

end_t = time.time()

print('Time', end_t - state_t)
# This cell walks the validation/test iterator, so label the output as such.
print('\nTest Data Splits:')
mood_counter = Counter(label_list)

for k, v in mood_counter.items():
  print(k, v)

# DEPRECATED DATASETS

## DEPRECATED Exercise Datasets

### 600 Sample / Activity Dataset

In [None]:
# @title Load Dataset

DATASET_NAME = 'lsm_300min_600_activities_balanced'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

config = get_config(runlocal=False)  # runlocal=False loads the full dataset.
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)

print('Processed Dataset Meta Data:\n')
for meta_key, meta_value in dataset.meta_data.items():
  print(meta_key, meta_value)

In [None]:
# @title Dataset Information

dataset_name = f'lsm_prod/{DATASET_NAME}'
# No `split` argument, so `ds` maps split names to datasets. File shuffling is
# off so the order is reproducible.
# NOTE(review): relies on the notebook-global `data_dir` set by an earlier cell.
ds, info = tfds.load(
    dataset_name, data_dir=data_dir, shuffle_files=False, with_info=True
)
print('Dataset Information:\n')
print(info)

# Show a few samples from the processed dataset.
print('\n\nExample Sample Data:\n')
explore_sample_data(dataset, OHE_LABELS)

In [None]:
# @title Exercise Label Breakdown

# Rebuild the dataset, presumably so every split iterator starts from the
# beginning (the previous cell may have consumed samples).
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)
info_kwargs = dict(event='exercise', offset=65536)
train_info = dataset_split_information(dataset.train_iter, **info_kwargs)
valid_info = dataset_split_information(dataset.valid_iter, **info_kwargs)
test_info = dataset_split_information(dataset.test_iter, **info_kwargs)

train_counter = Counter(train_info['log_value'])
valid_counter = Counter(valid_info['log_value'])
# Validation and test are reported together as the held-out split.
test_counter = Counter(test_info['log_value']) + valid_counter

print('Num Examples:')
print('Train Samples', train_info['num_examples'])
print('Valid Samples', valid_info['num_examples'] + test_info['num_examples'])

print('\nDataset Breakdown:')
train_df = pd.DataFrame(
    list(train_counter.items()), columns=['log_value', 'train_count']
)
test_df = pd.DataFrame(
    list(test_counter.items()), columns=['log_value', 'test_count']
)
data600_df = train_df.merge(test_df, on='log_value', how='outer')
data600_df


### 300 Sample / Activity Dataset

In [None]:
# @title Load Dataset

DATASET_NAME = 'lsm_300min_300_activities_balanced'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

config = get_config(runlocal=False)  # runlocal=False loads the full dataset.
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)

print('Processed Dataset Meta Data:\n')
for meta_key, meta_value in dataset.meta_data.items():
  print(meta_key, meta_value)

In [None]:
# @title Dataset Information

dataset_name = f'lsm_prod/{DATASET_NAME}'
# No `split` argument, so `ds` maps split names to datasets. File shuffling is
# off so the order is reproducible.
# NOTE(review): relies on the notebook-global `data_dir` set by an earlier cell.
ds, info = tfds.load(
    dataset_name, data_dir=data_dir, shuffle_files=False, with_info=True
)
print('Dataset Information:\n')
print(info)

# Show a few samples from the processed dataset.
print('\n\nExample Sample Data:\n')
explore_sample_data(dataset, OHE_LABELS)

In [None]:
# @title Exercise Label Breakdown

# Rebuild the dataset, presumably so every split iterator starts from the
# beginning (the previous cell may have consumed samples).
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)
info_kwargs = dict(event='exercise', offset=65536)
train_info = dataset_split_information(dataset.train_iter, **info_kwargs)
valid_info = dataset_split_information(dataset.valid_iter, **info_kwargs)
test_info = dataset_split_information(dataset.test_iter, **info_kwargs)

train_counter = Counter(train_info['log_value'])
valid_counter = Counter(valid_info['log_value'])
# Validation and test are reported together as the held-out split.
test_counter = Counter(test_info['log_value']) + valid_counter

print('Num Examples:')
print('Train Samples', train_info['num_examples'])
print('Valid Samples', valid_info['num_examples'] + test_info['num_examples'])

print('\nDataset Breakdown:')
train_df = pd.DataFrame(
    list(train_counter.items()), columns=['log_value', 'train_count']
)
test_df = pd.DataFrame(
    list(test_counter.items()), columns=['log_value', 'test_count']
)
data300_df = train_df.merge(test_df, on='log_value', how='outer')
data300_df


### 600 / 300 Dataset Breakdowns

In [None]:
# Side-by-side per-activity counts for the 600- and 300-sample datasets.
# Columns are renamed explicitly before merging instead of relying on the
# positional order of pandas' `_x`/`_y` suffixed columns.
data_df = pd.merge(
    data600_df.rename(columns={'train_count': 'train600', 'test_count': 'test600'}),
    data300_df.rename(columns={'train_count': 'train300', 'test_count': 'test300'}),
    on='log_value',
    how='outer',
)
# Look up display names (raises KeyError for an unmapped log value).
data_df['activity_name'] = [
    log_value_map_dict[log_value] for log_value in data_df['log_value']
]
data_df = data_df.sort_values(by=['log_value'], ascending=True)

data_df[['activity_name', 'log_value', 'train600', 'test600', 'train300', 'test300']]

In [None]:
# Log values of the nine activity classes (presumably those kept in the
# 9-class subset dataset loaded below).
label_subset = [90013, 56001, 90014, 90019, 90017, 52000, 90005, 90024, 90009]

in_subset = data_df['log_value'].isin(label_subset)
data_9class_df = data_df.loc[in_subset]

train600_sum, test600_sum = (
    data_9class_df[column].sum() for column in ('train600', 'test600')
)
print('Train 600: ', train600_sum)
print('Test 600: ', test600_sum)

data_9class_df


### 600 Sample / Activity Data Subset / 9 Classes

In [None]:
# @title Load Dataset and Explore Sample Data

DATASET_NAME = 'lsm_300min_600_activities_9class_subset'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1
REPEAT_DATA = False

config = get_config(runlocal=False)  # runlocal=False loads the full dataset.
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_activity_subset_dataset.get_dataset(config, data_rng)

print('Processed Dataset Meta Data:\n')
for meta_key, meta_value in dataset.meta_data.items():
  print(meta_key, meta_value)

# Show a few samples from the processed dataset.
print('\n\nExample Sample Data:\n')
explore_sample_data(dataset, OHE_LABELS)

In [None]:
# @title Exercise Label Breakdown

# Rebuild the dataset, presumably so the split iterators start from the
# beginning (the previous cell may have consumed samples).
dataset = lsm_activity_subset_dataset.get_dataset(config, data_rng)
info_kwargs = dict(event='exercise', offset=65536)
train_info = dataset_split_information(dataset.train_iter, **info_kwargs)
valid_info = dataset_split_information(dataset.valid_iter, **info_kwargs)

train_counter = Counter(train_info['log_value'])
valid_counter = Counter(valid_info['log_value'])
# Only the validation split is used here; it is reported in the test column.
test_counter = valid_counter

print('Num Examples:')
print('Train Samples', train_info['num_examples'])
print('Valid Samples', valid_info['num_examples'])

print('\nDataset Breakdown:')
train_df = pd.DataFrame(
    list(train_counter.items()), columns=['log_value', 'train_count']
)
test_df = pd.DataFrame(
    list(test_counter.items()), columns=['log_value', 'test_count']
)
data600_9class_df = train_df.merge(test_df, on='log_value', how='outer')
data600_9class_df

## DEPRECATED Subject Dependent Mood Dataset

### Count Mood Samples By User

In [None]:
dataset_name = 'lsm_prod/lsm_300min_2000_mood_balanced'
data_dir = '/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/exp/dmcduff/ttl=6w/msa_1_5/lsm_tfds_datasets'
batch_size = 16

# Pool train and test: this section counts samples per subject over the whole
# dataset. File shuffling is disabled for BOTH splits so the traversal order
# is reproducible (the order does not affect the counts).
train_ds, info = tfds.load(
    dataset_name,
    data_dir=data_dir,
    split='train',
    shuffle_files=False,
    with_info=True,
)

val_ds, info = tfds.load(
    dataset_name,
    data_dir=data_dir,
    split='test',
    shuffle_files=False,
    with_info=True,
)

ds = train_ds.concatenate(val_ds)

# Preprocess dataset: enable a multi-threaded input pipeline.
options = tf.data.Options()
options.threading.private_threadpool_size = 48
ds = ds.with_options(options)

ds = ds.batch(batch_size, drop_remainder=False)  # keep the final partial batch
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)  # prefetch
# Eval-style padding (train=False) for the final partial batch. The traversal
# below keeps only rows whose 'batch_mask' is 1.
maybe_pad_batches = functools.partial(
    dataset_utils.maybe_pad_batch,
    train=False,
    batch_size=batch_size,
    inputs_key='input_signal',
)
# NOTE: `ds` becomes a one-shot Python iterator; re-run this cell before
# re-running the traversal cell.
ds = iter(ds)
ds = map(dataset_utils.tf_to_numpy, ds)
ds = map(maybe_pad_batches, ds)



In [None]:
state_t = time.time()

subj_list = []
label_list = []

batch_count = 0
for d in ds:
  if not batch_count % 100:
    print(batch_count, time.time())
  batch_count += 1

  # Restrict to the real (unpadded) rows of this batch.
  valid = np.where(d['batch_mask'] == 1)
  metadata = d['metadata']
  subjs = metadata['ID'][valid].tolist()
  log_vals = metadata['log_value'][valid].tolist()

  subj_list.extend(subjs)
  label_list.extend(log_vals)

end_t = time.time()

print('\nTotal Items:', len(subj_list))
print('Traversal Time', end_t - state_t)


In [None]:
df = pd.DataFrame({'Subject': subj_list, 'Class': label_list})
# Subjects x classes matrix of sample counts (0 where a class never occurs).
count_table = df.groupby(['Subject', 'Class']).size().unstack(fill_value=0)
class_columns = list(count_table.columns)
count_table['total_labels'] = count_table[class_columns].sum(axis=1)
# True when the subject has at least one sample of every class. Computed over
# the class columns only, so derived columns cannot influence it.
count_table['all_classes'] = (count_table[class_columns] > 0).all(axis=1)
count_table = count_table.sort_values(by=['all_classes', 'total_labels'], ascending=[False, False])

pd.set_option('display.max_rows', 500)
count_table

### Subject Dependent Mood Dataset

In [None]:
# @title Dataset

"""Electrodes dataset data preprocesser and loader.

Adapted from a combination of the following files:
google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py
google3/third_party/py/scenic/dataset_lib/dataset_utils.py
"""

import collections
import functools
from typing import Any, Optional

from absl import logging
import jax.numpy as jnp
import jax.profiler
import ml_collections  # pylint: disable=unused-import
from scenic.dataset_lib import dataset_utils
import tensorflow as tf
import tensorflow_datasets as tfds

from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
from google3.experimental.largesensormodels.scenic.datasets import lsm_tiny_dataset


def filter_allowed_subjects(example, allowed_subjects):
  """`Dataset.filter` predicate keeping examples from allowed subjects.

  Args:
    example: Dataset element whose `metadata['ID']` holds the subject ID.
    allowed_subjects: Subject IDs to keep.

  Returns:
    A boolean tensor: True iff the example's subject ID is in
    `allowed_subjects`.
  """
  is_allowed = tf.math.equal(example['metadata']['ID'], allowed_subjects)
  return tf.reduce_any(is_allowed)


def update_metadata(
    metadata, dataset_name, patch_size, dataset_configs
):
  """Computes metadata overrides that reflect preprocessing of the input.

  Starting from the raw `metadata['input_shape']`, this:
    1. separates datetime feature columns (when the dataset constants declare
       `datetime_features`) from the sensor feature columns,
    2. pads the feature width to be compatible with `patch_size` (via
       `lsm_tiny_dataset.get_height_crop_width_pad`) and records which
       columns are real (1) vs padding (0),
    3. sets one-hot / number-of-classes info for activity, mood and stress
       datasets,
    4. records the dataset's log values and log value names, and
    5. shrinks the time (h) axis of `input_shape` according to
       `dataset_configs.relative_time_window`.

  Args:
    metadata: dict with an 'input_shape' entry whose first element is the
      batch dimension (it is dropped via `[1:]`).
    dataset_name: Dataset name; only the part after the last '/' is used to
      index `dataset_constants.lsm_dataset_constants`.
    patch_size: Patch size; element 0 is used as the patch height (time axis).
    dataset_configs: Config with an optional `relative_time_window`
      (start, end) pair of fractions; missing values default to (0, 1).

  Returns:
    A dict of entries to merge into `metadata`. NOTE: the key
    'datime_valid_feats' is spelled that way in the code; consumers may depend
    on the exact spelling.
  """
  # Setup: Get dataset name, feature shape, and possible datetime features.
  metadata_update = dict()
  dataset_name = dataset_name.split('/')[-1]
  time_features = dataset_constants.lsm_dataset_constants[dataset_name].get(
      'datetime_features', None
  )
  feature_shape = list(metadata['input_shape'][1:])
  feature_indices = list(range(feature_shape[1]))

  # Split sensor features from datetime features.
  # NOTE: This assumes that the original 'input_signal' field has sensor
  # features concatenated to datetime features along the feature (w) dimension.
  if time_features is not None:
    # Get datetime indices
    time_feature_indices = list(time_features['indices'])
    # Remove datetime indices from feature indices. (Set difference does not
    # preserve order; only the length is used below.)
    feature_indices = list(set(feature_indices) - set(time_features['indices']))
    # Get updated feature and datetime feature shapes.
    time_feature_shape = feature_shape.copy()  # update time feature shape
    time_feature_shape[1] = len(time_feature_indices)
    feature_shape[1] = len(feature_indices)  # update feature shape
  else:
    time_feature_shape = None

  # Padding: Update shape to reflect padding (for perfect patching).
  # valid_feats arrays denote which features are valid (1) vs padded (0).
  # 1. Update for sensor features
  _, pad_w, feat_shape_new = lsm_tiny_dataset.get_height_crop_width_pad(
      tuple(feature_shape), patch_size
  )
  valid_feat_mask = [0] * pad_w[0] + [1] * feature_shape[1] + [0] * pad_w[1]
  # Provisional value; overwritten by the time-cropped shape at the end.
  metadata_update['input_shape'] = tuple([-1] + list(feat_shape_new))
  metadata_update['input_valid_feats'] = tuple(valid_feat_mask)

  # 2. Update for datetime features
  if time_features is not None:
    _, time_pad_w, time_feature_shape_new = (
        lsm_tiny_dataset.get_height_crop_width_pad(
            tuple(time_feature_shape), patch_size
        )
    )
    valid_time_feat_mask = (
        [0] * time_pad_w[0] + [1] * time_feature_shape[1] + [0] * time_pad_w[1]
    )
    metadata_update['datetime_input_shape'] = tuple(
        [-1] + list(time_feature_shape_new)
    )
    metadata_update['datime_valid_feats'] = tuple(valid_time_feat_mask)

  else:
    metadata_update['datetime_input_shape'] = None
    metadata_update['datime_valid_feats'] = None

  # 3. Update whether the target is one-hot encoded. Other dataset families
  # leave 'target_is_onehot' / 'num_classes' untouched.
  if 'activities' in dataset_name or 'mood' in dataset_name:
    metadata_update['target_is_onehot'] = True
    metadata_update['num_classes'] = len(
        dataset_constants.lsm_dataset_constants[dataset_name]['log_values']
    )
  elif 'stress' in dataset_name:
    metadata_update['target_is_onehot'] = True
    metadata_update['num_classes'] = 2

  # 4. Add dataset log values and log value names.
  log_values = dataset_constants.lsm_dataset_constants[dataset_name].get(
      'log_values', None
  )
  log_value_names = dataset_constants.lsm_dataset_constants[dataset_name].get(
      'log_value_names', None
  )
  metadata_update['log_values'] = log_values
  metadata_update['log_value_names'] = log_value_names

  # 5. Update time cropping. Missing window bounds default to the full range.
  start, end = dataset_configs.get('relative_time_window', (None, None))
  if end is None:
    end = 1
  if start is None:
    start = 0

  # Time-crop based on the window: snap start/end to whole patches along the
  # time axis (h), then keep only rows [start_idx, end_idx).
  p_h = patch_size[0]
  h = feat_shape_new[0]
  n_h = h // p_h
  start_idx = int(start * n_h) * p_h
  end_idx = int(end * n_h) * p_h
  metadata_update['input_shape'] = tuple(
      [-1] + [end_idx - start_idx] + list(feat_shape_new)[1:]
  )

  return metadata_update


def get_subject_dependent_mood_dataset(
    *,
    config,
    num_shards,
    batch_size,
    eval_batch_size=None,
    dtype_str='float32',
    shuffle_seed=0,
    rng=None,
    shuffle_buffer_size=None,
    dataset_service_address: Optional[str] = None,
    data_dir='/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/exp/dmcduff/ttl=6w/msa_1_5/lsm_tfds_datasets',
):
  """Gets and formats the Subject Dependent Mood Dataset.

  Data preparation:
    1. Pool the train and test splits of `lsm_300min_2000_mood_balanced`.
    2. Keep only subjects with at least
       `config.dataset_configs.samples_per_subject` samples.
    3. Re-split each subject's samples: the first 80% (in dataset order) go
       to train and the rest go to validation.

  Adapted from:
  google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py and
  google3/third_party/py/scenic/dataset_lib/dataset_utils.py.

  Args:
    config: ml_collections.ConfigDict; Config for the experiment.
    num_shards: int; Number of shards to split the dataset into.
    batch_size: int; Batch size for training.
    eval_batch_size: int; Batch size for evaluation (defaults to batch_size).
    dtype_str: str; Data type of the input signal.
    shuffle_seed: int; Seed for shuffling the dataset.
    rng: jax.random.PRNGKey; Random number generator key.
    shuffle_buffer_size: int; Size of the shuffle buffer.
    dataset_service_address: str; Address of the dataset service.
    data_dir: str; Directory of the dataset.

  Returns:
    A dataset_utils.Dataset object with train and validation iterators (no
    test iterator) and the associated meta data.

  Raises:
    ValueError: If a shuffle seed is set while using the tf.data service, or
      if no subject has enough samples.
  """

  # Setup: General
  if rng is None:
    rng = jax.random.PRNGKey(config.rng_seed)

  # Fail fast, before the expensive dataset traversals below, when the tf.data
  # service is combined with a fixed shuffle seed.
  if dataset_service_address and shuffle_seed is not None:
    raise ValueError(
        'Using dataset service with a random seed causes each '
        'worker to produce exactly the same data. Add '
        'config.shuffle_seed = None to your config if you '
        'want to run with dataset service.'
    )

  # 1. Process information.
  p_idx = jax.process_index()  # current process index
  p_cnt = jax.process_count()  # process count (number of processes)

  aug_rngs = jax.random.split(rng, p_cnt)  # per-device augmentation seeds
  aug_rng = aug_rngs[p_idx]  # device augmentation seed
  tf_aug_rng = aug_rng[0]  # jax random seeds are arrays, tf expects an int.
  del rng

  # 2. dataset and data type information.
  dataset_configs = config.dataset_configs  # get dataset configurations.
  dtype = getattr(tf, dtype_str)  # data dtype
  if eval_batch_size is None:  # set eval batch size
    eval_batch_size = batch_size

  # 3. Used dataset name.
  used_dataset_name = 'lsm_prod/lsm_300min_2000_mood_balanced'

  # 4. Repeat dataset.
  repeat_ds = dataset_configs.get('repeat_data', True)

  # Setup: Mapping functions.
  # Preprocessing function.
  preprocess_fn = functools.partial(
      lsm_tiny_dataset.preprocess_example,
      dataset_name=used_dataset_name,
      dtype=dtype
  )
  # Augmentation function.
  augment_fn = functools.partial(
      lsm_tiny_dataset.augment_example,
      augmentations=config.get('train_augmentations', []),
      seed=tf_aug_rng,
  )
  # Crop and pad features and time features to be patch size compatible.
  crop_and_pad_fn = functools.partial(
      lsm_tiny_dataset.patch_compatible_resize_example,
      patch_size=config.model.patches.size
  )

  # Time crop data input (only when a window bound is configured).
  start, end = dataset_configs.get('relative_time_window', (None, None))
  time_crop_examples = (start is not None) or (end is not None)
  time_crop_fn = functools.partial(
      lsm_tiny_dataset.time_crop_example,
      patch_size=config.model.patches.size,
      start=start,
      end=end
  )

  # Setup: Data splits.
  # Load dataset splits.
  train_ds = tfds.load(
      used_dataset_name,
      data_dir=data_dir,
      split='train',
      shuffle_files=False,  # NOTE: train shuffle is done below.
  )
  val_ds = tfds.load(
      used_dataset_name,
      data_dir=data_dir,
      split='test',
      shuffle_files=False,
  )
  logging.info(  # pylint:disable=logging-fstring-interpolation
      'Loaded combined train + val split '
      f'{p_idx}/{p_cnt} from {used_dataset_name}.'
  )
  ds = train_ds.concatenate(val_ds)

  # Data processing and preparation.
  # 0. Enable multi threaded workers.
  options = tf.data.Options()
  options.threading.private_threadpool_size = 48
  ds = ds.with_options(options)

  # 1. Per-subject split.
  # Count samples per subject (one full pass over the pooled data).
  subj_label_counts = collections.Counter()
  for d in ds:
    subj = d['metadata']['ID']
    subj_label_counts[subj.numpy().decode('utf-8')] += 1

  # Filter down to subjects with at least N samples.
  allowed_subjs = [
      subj for subj, count in subj_label_counts.items()
      if count >= dataset_configs.samples_per_subject
  ]
  if not allowed_subjs:
    raise ValueError(
        'No subject has at least '
        f'{dataset_configs.samples_per_subject} samples in '
        f'{used_dataset_name}.'
    )
  filter_fn = functools.partial(
      filter_allowed_subjects, allowed_subjects=allowed_subjs
  )
  ds = ds.filter(filter_fn)

  # Split the data into train and val splits.
  # Splits each participant's data 80/20 between train and val.
  def filter_by_subj(subj):
    return lambda x: tf.equal(
        x['metadata']['ID'], subj
    )

  train_subj_splits, valid_subj_splits = [], []
  num_train_samples, num_val_samples = 0, 0
  for subj in allowed_subjs:
    subj_ds = ds.filter(filter_by_subj(subj))
    # The per-subject size is already known from the counting pass above.
    # Reusing it avoids one full traversal of the dataset per subject.
    size_subj_ds = subj_label_counts[subj]

    train_size = int(0.8 * size_subj_ds)
    num_train_samples += train_size
    num_val_samples += size_subj_ds - train_size

    train_subj_splits.append(subj_ds.take(train_size))
    valid_subj_splits.append(subj_ds.skip(train_size))

  # Concat per-subject datasets.
  train_ds = train_subj_splits[0]
  val_ds = valid_subj_splits[0]
  for subj_train_split, subj_val_split in zip(
      train_subj_splits[1:], valid_subj_splits[1:]
  ):
    train_ds = train_ds.concatenate(subj_train_split)
    val_ds = val_ds.concatenate(subj_val_split)

  # Get samples per class in the train split.
  spc = collections.Counter()
  for d in train_ds:
    log_val = int(d['metadata']['log_value'])
    spc[log_val] += 1

  spc_labels = tf.convert_to_tensor(list(spc.keys()))
  spc_label_counts = tf.convert_to_tensor(list(spc.values()))

  # Get mood log values
  dataset_key = used_dataset_name.split('/')[-1]
  offset = tf.cast(
      dataset_constants.lsm_dataset_constants[dataset_key]['log_value_offset'],
      tf.int32
  )
  log_val_list = tf.convert_to_tensor(
      dataset_constants.lsm_dataset_constants[dataset_key]['log_values']
  )
  log_val_list = log_val_list - offset  # offset value

  # Align train counts to the order of the configured `log_values` by rank.
  # argsort(spc_labels)[r] is the index (in spc order) of the r-th smallest
  # observed label; argsort(argsort(log_val_list))[i] is the rank of the i-th
  # configured log value.
  # NOTE(review): assumes every configured class appears in train (same number
  # of distinct labels on both sides) — confirm.
  sorted_label_indices = tf.argsort(spc_labels)
  log_val_ranks = tf.argsort(tf.argsort(log_val_list))
  mapping = tf.gather(sorted_label_indices, log_val_ranks)
  label_counts = tf.gather(spc_label_counts, mapping)

  # Split dataset over host devices.
  train_ds = train_ds.shard(p_cnt, p_idx)
  val_ds = val_ds.shard(p_cnt, p_idx)

  # 2. Preprocessing: Applied before `ds.cache()` to re-use it.
  train_ds = train_ds.map(
      preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  val_ds = val_ds.map(
      preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )

  # 3. Cache datasets: This can significantly speed up training.
  if dataset_configs.cache_dataset:
    train_ds = train_ds.cache()
    val_ds = val_ds.cache()

  # 4 Train repeats and augmentations.
  if repeat_ds:
    train_ds = train_ds.repeat()  # repeat
  # NOTE: Train augmentations are done after repeat for true randomness.
  if config.use_train_augmentations:
    train_ds = train_ds.map(  # train data augmentations
        augment_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )

  # 5. Crop and pad for perfect patching.
  train_ds = train_ds.map(  # crop/pad for perfect patching
      crop_and_pad_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  val_ds = val_ds.map(  # crop/pad for perfect patching
      crop_and_pad_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )

  # 6. Time crop input data.
  if time_crop_examples:
    train_ds = train_ds.map(
        time_crop_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    val_ds = val_ds.map(
        time_crop_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )

  # 7. Data preparation (shuffling, batching, eval repeat, etc.).
  # 7a. Train: Shuffle, batch, prefetch
  shuffle_buffer_size = shuffle_buffer_size or (8 * batch_size)
  train_ds = train_ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)  # shuffle
  train_ds = train_ds.batch(batch_size, drop_remainder=True)  # batch
  train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)  # prefetch

  # 7b. Validation: Batch, Repeat, Prefetch. Batch with eval_batch_size so the
  # batches match the size that `maybe_pad_batches_eval` pads to.
  val_ds = val_ds.batch(eval_batch_size, drop_remainder=False)  # batch
  if repeat_ds:
    val_ds = val_ds.repeat()  # repeat
  val_ds = val_ds.prefetch(tf.data.experimental.AUTOTUNE)  # prefetch

  # Distribute via the tf.data service (the seed was validated above).
  if dataset_service_address:
    train_ds = dataset_utils.distribute(train_ds, dataset_service_address)
    logging.info('Using the tf.data service at %s', dataset_service_address)

  # Other mappings
  # 1. Set up batch padding: If batch remainders are NOT dropped, batches may
  # be padded up to the full batch size so that all samples are contained.
  maybe_pad_batches_train = functools.partial(
      dataset_utils.maybe_pad_batch,
      train=True,
      batch_size=batch_size,
      inputs_key='input_signal',
  )
  maybe_pad_batches_eval = functools.partial(
      dataset_utils.maybe_pad_batch,
      train=False,
      batch_size=eval_batch_size,
      inputs_key='input_signal',
  )

  # 2. Set up batch sharding: Shard batches to be processed by multiple devices.
  shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards)

  # 3. Apply other mappings and Iter dataset
  train_iter = iter(train_ds)
  train_iter = map(dataset_utils.tf_to_numpy, train_iter)
  train_iter = map(maybe_pad_batches_train, train_iter)
  train_iter = map(shard_batches, train_iter)

  val_iter = iter(val_ds)
  val_iter = map(dataset_utils.tf_to_numpy, val_iter)
  val_iter = map(maybe_pad_batches_eval, val_iter)
  val_iter = map(shard_batches, val_iter)

  # Save meta data
  info = tfds.builder(used_dataset_name, data_dir=data_dir, try_gcs=True).info
  input_shape = tuple([-1] + list(info.features['input_signal'].shape))
  meta_data = {
      'input_shape': input_shape,
      'num_train_examples': num_train_samples,
      'num_val_examples': num_val_samples,
      'num_test_examples': 0,
      'input_dtype': getattr(jnp, dtype_str),
      'label_counts': label_counts,
      # The following two fields are set as defaults and may be updated in the
      # update_metadata function below.
      'target_is_onehot': False,
      'num_classes': None,
  }

  # Update metadata to reflect preprocessing, and paddings
  # (Changes in shape, and features).
  meta_data.update(
      update_metadata(
          meta_data,
          dataset_name=used_dataset_name,
          patch_size=config.model.patches.size,
          dataset_configs=dataset_configs,
      )
  )

  # Return dataset structure.
  return dataset_utils.Dataset(train_iter, val_iter, None, meta_data)


def get_dataset(
    config: Any,
    data_rng: jnp.ndarray,
    *,
    num_local_shards: Optional[int] = None,
    dataset_service_address: Optional[str] = None,
    **kwargs: Any,
) -> dataset_utils.Dataset:
  """Builds the subject-dependent mood dataset selected by `config`.

  Adapted from: google3/third_party/py/scenic/train_lib/train_utils.py.

  Args:
    config: Experiment config. Reads `dataset_configs.dataset`, `batch_size`,
      `data_dtype_str` and, optionally, `shuffle_seed` and
      `shuffle_buffer_size`.
    data_rng: JAX PRNG key forwarded to the dataset builder.
    num_local_shards: Number of local devices to shard batches over. Defaults
      to `jax.local_device_count()`.
    dataset_service_address: Optional tf.data service address.
    **kwargs: Extra keyword arguments forwarded to the dataset builder.

  Returns:
    Whatever `get_subject_dependent_mood_dataset` returns for these arguments.

  Raises:
    ValueError: if the dataset is not supported, if `config.batch_size` is not
      divisible by the global device count, or if a shuffle seed is combined
      with the tf.data service.
  """
  # Get device count
  device_count = jax.device_count()
  logging.info('device_count: %d', device_count)
  logging.info('num_hosts : %d', jax.process_count())
  logging.info('host_id : %d', jax.process_index())

  # Only the subject-dependent mood dataset is supported. Compare the last
  # '/'-separated component (as the other helpers in this notebook do) so an
  # unprefixed name raises the ValueError below instead of an IndexError.
  dataset_name = config.dataset_configs.dataset
  supported_datasets = ('lsm_300min_2000_mood_subject_dependant',)
  if dataset_name.split('/')[-1] not in supported_datasets:
    raise ValueError(f'Dataset {dataset_name} is not supported.')
  dataset_builder = get_subject_dependent_mood_dataset

  # The global batch must split evenly across all devices.
  batch_size = config.batch_size
  if batch_size % device_count > 0:
    raise ValueError(
        f'Batch size ({batch_size}) must be divisible by the '
        f'number of devices ({device_count})'
    )

  local_batch_size = batch_size // jax.process_count()
  device_batch_size = batch_size // device_count
  logging.info('local_batch_size : %d', local_batch_size)
  logging.info('device_batch_size : %d', device_batch_size)

  # A fixed shuffle seed would make every tf.data service worker emit the
  # same data, so the two options are mutually exclusive.
  shuffle_seed = config.get('shuffle_seed', None)
  if dataset_service_address and shuffle_seed is not None:
    raise ValueError(
        'Using dataset service with a random seed causes each '
        'worker to produce exactly the same data. Add '
        'config.shuffle_seed = None to your config if you want '
        'to run with dataset service.'
    )

  shuffle_buffer_size = config.get('shuffle_buffer_size', None)
  num_local_shards = num_local_shards or jax.local_device_count()

  # Build the dataset
  ds = dataset_builder(
      config=config,
      num_shards=num_local_shards,
      batch_size=local_batch_size,
      dtype_str=config.data_dtype_str,
      shuffle_seed=shuffle_seed,
      rng=data_rng,
      shuffle_buffer_size=shuffle_buffer_size,
      dataset_service_address=dataset_service_address,
      **kwargs,
  )

  return ds


In [None]:
DATASET_NAME = 'lsm_300min_2000_mood_subject_dependant'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

config = get_config(runlocal=False)  # must be false to get full dataset
# Repetition must be disabled so that the sample-breakdown loops below finish.
# The preprocessed loader in this notebook reads the flag as `repeat_data`;
# which key this builder reads is not visible here, so both are set
# (an unused extra key is harmless).
config.dataset_configs.update(
    {'samples_per_subject': 40, 'repeat': False, 'repeat_data': False}
)

rng = jax.random.PRNGKey(config.rng_seed)
data_rng, rng = jax.random.split(rng)

start_t = time.time()
dataset = get_dataset(config, data_rng)
end_t = time.time()

print('Dataset Time', end_t - start_t)

print('\nProcessed Dataset Meta Data:\n')
for k, v in dataset.meta_data.items():
  print(k, v)

In [None]:
# @title Train Sample  Breakdown

state_t = time.time()

# Collect the log value of every real (non-padded) train example.
# NOTE(review): this loop only ends if the train iterator is finite, i.e. the
# builder did not repeat the dataset.
label_list = []
batch_count = 0
for batch_idx, d in enumerate(dataset.train_iter):
  if batch_idx % 1000 == 0:
    print(batch_idx, time.time())
  batch_count = batch_idx + 1

  # batch_mask marks real examples with 1; padded slots are dropped.
  valid = np.where(d['batch_mask'] == 1)
  label_list.extend(d['log_value'][valid].tolist())

end_t = time.time()

print('Time', end_t - state_t)
print('\nTrain Data Splits:')
mood_counter = Counter(label_list)
for label, count in mood_counter.items():
  print(label, count)

In [None]:
# @title Valid Sample  Breakdown

state_t = time.time()

# Collect the log value of every real (non-padded) validation example.
# NOTE(review): this loop only ends if the validation iterator is finite.
label_list = []
batch_count = 0
for batch_idx, d in enumerate(dataset.valid_iter):
  if batch_idx % 1000 == 0:
    print(batch_idx, time.time())
  batch_count = batch_idx + 1

  # batch_mask marks real examples with 1; padded slots are dropped.
  valid = np.where(d['batch_mask'] == 1)
  label_list.extend(d['log_value'][valid].tolist())

end_t = time.time()

print('Time', end_t - state_t)
print('\nValid Data Splits:')
mood_counter = Counter(label_list)
for label, count in mood_counter.items():
  print(label, count)

### Save and Load Processed Mood Dataset
This is required because traversing the dataset on XM is extremely slow and causes jobs to be killed.

### Preprocess and Save Dataset

In [None]:
# @title Format To-Save Processed Mood Dataset

"""Electrodes dataset data preprocesser and loader.

Adapted from a combination of the following files:
google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py
google3/third_party/py/scenic/dataset_lib/dataset_utils.py
"""

import collections
import functools
from typing import Any, Optional

from absl import logging
import jax.numpy as jnp
import jax.profiler
import ml_collections  # pylint: disable=unused-import
from scenic.dataset_lib import dataset_utils
import tensorflow as tf
import tensorflow_datasets as tfds

from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
from google3.experimental.largesensormodels.scenic.datasets import lsm_tiny_dataset


def _bytestring_feature(list_of_bytestrings) -> tf.train.Feature:
  """Wraps a list of byte strings in a `bytes_list` `tf.train.Feature`."""
  bytes_list = tf.train.BytesList(value=list_of_bytestrings)
  return tf.train.Feature(bytes_list=bytes_list)


def _int_feature(list_of_ints) -> tf.train.Feature:
  """Wraps a list of integers in an `int64_list` `tf.train.Feature`."""
  int64_list = tf.train.Int64List(value=list_of_ints)
  return tf.train.Feature(int64_list=int64_list)


def _float_feature(list_of_floats) -> tf.train.Feature:
  """Wraps a list of floats in a `float_list` `tf.train.Feature`."""
  float_list = tf.train.FloatList(value=list_of_floats)
  return tf.train.Feature(float_list=float_list)


def serialize_example(input_signal, label, exercise_log, mood_log, log_value):
  """Encodes one example as a serialized `tf.train.Example` proto.

  Meant to run inside `tf.py_function` (see `tf_serialize_example`), so the
  arguments are eager tensors; `input_signal` is flattened to a float list.

  Args:
    input_signal: Eager tensor with the input signal.
    label: Eager tensor with the example label.
    exercise_log: Eager tensor with the exercise-log flag.
    mood_log: Eager tensor with the mood-log flag.
    log_value: Eager tensor with the log value.

  Returns:
    The serialized proto as a bytes string.
  """
  flat_signal = input_signal.numpy().flatten().tolist()
  feature_map = {
      'input_signal': _float_feature(flat_signal),
      'label': _int_feature([label]),
      'exercise_log': _int_feature([exercise_log]),
      'mood_log': _int_feature([mood_log]),
      'log_value': _int_feature([log_value]),
  }
  proto = tf.train.Example(features=tf.train.Features(feature=feature_map))
  return proto.SerializeToString()

def tf_serialize_example(example):
  """Serializes one dataset element into a `tf.string` tensor.

  `serialize_example` needs eager tensors (it calls `.numpy()`), so it is
  wrapped in `tf.py_function` to be usable from `tf.data.Dataset.map`.

  Args:
    example: dict with 'input_signal', 'label', 'exercise_log', 'mood_log'
      and 'log_value' entries.

  Returns:
    A `tf.string` tensor holding the serialized `tf.train.Example`.
  """
  ordered_fields = (
      'input_signal', 'label', 'exercise_log', 'mood_log', 'log_value'
  )
  return tf.py_function(
      serialize_example,
      [example[field] for field in ordered_fields],
      tf.string,
  )


def filter_allowed_subjects(example, allowed_subjects):
  """Keeps only examples whose subject ID is in `allowed_subjects`.

  Intended as a `tf.data.Dataset.filter` predicate.

  Args:
    example: Dataset element; the subject ID is read from
      `example['metadata']['ID']` (a string tensor).
    allowed_subjects: Sequence (or string tensor) of subject IDs to keep.

  Returns:
    A boolean scalar tensor, True iff the example's ID is allowed.
  """
  # An explicit string dtype keeps the comparison well-typed even when
  # `allowed_subjects` is empty (an untyped empty list converts to float32,
  # which cannot be compared against a string tensor).
  allowed = tf.convert_to_tensor(allowed_subjects, dtype=tf.string)
  subj = example['metadata']['ID']
  return tf.reduce_any(tf.math.equal(subj, allowed))


def get_useful_fields(example):
  """Keeps only the fields needed downstream and flattens `metadata`.

  Args:
    example: dict with 'input_signal', 'label' and a 'metadata' dict that
      holds 'exercise_log', 'mood_log' and 'log_value'.

  Returns:
    A flat dict with keys 'input_signal', 'label', 'exercise_log',
    'mood_log' and 'log_value' (in that order).
  """
  metadata = example['metadata']
  useful = {
      'input_signal': example['input_signal'],
      'label': example['label'],
  }
  for key in ('exercise_log', 'mood_log', 'log_value'):
    useful[key] = metadata[key]
  return useful


def get_subject_dependent_mood_dataset(
    *,
    config,
    num_shards,
    batch_size,
    eval_batch_size=None,
    dtype_str='float32',
    shuffle_seed=0,
    rng=None,
    shuffle_buffer_size=None,
    dataset_service_address: Optional[str] = None,
    data_dir='/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/exp/dmcduff/ttl=6w/msa_1_5/lsm_tfds_datasets',
    train_fraction=0.8,
):
  """Builds subject-dependent train / val splits of the mood dataset.

  The TFDS 'train' and 'test' splits of
  `lsm_prod/lsm_300min_2000_mood_balanced` are concatenated. Subjects with
  fewer than `config.dataset_configs.samples_per_subject` examples are
  dropped. Each remaining subject's examples are then split in dataset order:
  the first `int(train_fraction * n)` go to train, the rest to val. No
  batching, shuffling or preprocessing is applied.

  Adapted from:
  google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py and
  google3/third_party/py/scenic/dataset_lib/dataset_utils.py.

  Args:
    config: ml_collections.ConfigDict; Config for the experiment. Only
      `config.dataset_configs.samples_per_subject` is read.
    num_shards: Unused; kept for interface compatibility with other builders.
    batch_size: Unused; kept for interface compatibility.
    eval_batch_size: Unused; kept for interface compatibility.
    dtype_str: Unused; kept for interface compatibility.
    shuffle_seed: Unused; kept for interface compatibility.
    rng: Unused; kept for interface compatibility.
    shuffle_buffer_size: Unused; kept for interface compatibility.
    dataset_service_address: Unused; kept for interface compatibility.
    data_dir: str; TFDS data directory of the source dataset.
    train_fraction: float; Fraction of each subject's examples put in train.

  Returns:
    A `(train_ds, val_ds)` tuple of `tf.data.Dataset`s whose elements are
    dicts with 'input_signal', 'label', 'exercise_log', 'mood_log' and
    'log_value' (see `get_useful_fields`).

  Raises:
    ValueError: if no subject has at least `samples_per_subject` examples.
  """
  del num_shards, batch_size, eval_batch_size, dtype_str, shuffle_seed
  del rng, shuffle_buffer_size, dataset_service_address

  p_idx = jax.process_index()  # current process index
  p_cnt = jax.process_count()  # process count (number of processes)
  dataset_configs = config.dataset_configs
  used_dataset_name = 'lsm_prod/lsm_300min_2000_mood_balanced'

  # Load both TFDS splits; they are re-split per subject below.
  train_ds = tfds.load(
      used_dataset_name,
      data_dir=data_dir,
      split='train',
      shuffle_files=False,
  )
  val_ds = tfds.load(
      used_dataset_name,
      data_dir=data_dir,
      split='test',
      shuffle_files=False,
  )
  logging.info(
      'Loaded combined train + val split %d/%d from %s.',
      p_idx, p_cnt, used_dataset_name,
  )
  ds = train_ds.concatenate(val_ds)

  # Enable multi-threaded workers.
  options = tf.data.Options()
  options.threading.private_threadpool_size = 48
  ds = ds.with_options(options)

  # Single pass to count examples per subject. The counts are reused for the
  # split sizes below, so the (slow) per-subject re-traversal is avoided.
  subj_counts = collections.Counter(
      d['metadata']['ID'].numpy().decode('utf-8') for d in ds
  )

  # Keep subjects with at least `samples_per_subject` examples.
  min_samples = dataset_configs.samples_per_subject
  allowed_subjs = [s for s, c in subj_counts.items() if c >= min_samples]
  if not allowed_subjs:
    raise ValueError(
        f'No subject in {used_dataset_name} has at least {min_samples} '
        'samples.'
    )
  ds = ds.filter(
      functools.partial(filter_allowed_subjects, allowed_subjects=allowed_subjs)
  )

  def filter_by_subj(subj):
    return lambda x: tf.equal(x['metadata']['ID'], subj)

  # Split each subject's examples (in order) between train and val.
  train_splits, val_splits = [], []
  num_train_samples, num_val_samples = 0, 0
  for subj in allowed_subjs:
    subj_ds = ds.filter(filter_by_subj(subj))
    subj_size = subj_counts[subj]
    train_size = int(train_fraction * subj_size)
    num_train_samples += train_size
    num_val_samples += subj_size - train_size
    train_splits.append(subj_ds.take(train_size))
    val_splits.append(subj_ds.skip(train_size))

  train_ds = functools.reduce(lambda a, b: a.concatenate(b), train_splits)
  val_ds = functools.reduce(lambda a, b: a.concatenate(b), val_splits)
  logging.info(
      'Subject-dependent split: %d train / %d val examples from %d subjects.',
      num_train_samples, num_val_samples, len(allowed_subjs),
  )

  # Only keep the fields that are serialized downstream.
  return train_ds.map(get_useful_fields), val_ds.map(get_useful_fields)


def get_dataset2(
    config: Any,
    data_rng: jnp.ndarray,
    *,
    num_local_shards: Optional[int] = None,
    dataset_service_address: Optional[str] = None,
    **kwargs: Any,
) -> tuple[tf.data.Dataset, tf.data.Dataset]:
  """Builds the raw subject-dependent `(train_ds, val_ds)` pair for saving.

  Adapted from: google3/third_party/py/scenic/train_lib/train_utils.py.

  Args:
    config: Experiment config. Reads `dataset_configs.dataset`, `batch_size`,
      `data_dtype_str` and, optionally, `shuffle_seed` and
      `shuffle_buffer_size`.
    data_rng: JAX PRNG key forwarded to the dataset builder.
    num_local_shards: Number of local devices. Defaults to
      `jax.local_device_count()`.
    dataset_service_address: Optional tf.data service address.
    **kwargs: Extra keyword arguments forwarded to the dataset builder.

  Returns:
    The `(train_ds, val_ds)` tuple returned by
    `get_subject_dependent_mood_dataset`.

  Raises:
    ValueError: if the dataset is not supported, if `config.batch_size` is not
      divisible by the global device count, or if a shuffle seed is combined
      with the tf.data service.
  """
  device_count = jax.device_count()
  logging.info('device_count: %d', device_count)
  logging.info('num_hosts : %d', jax.process_count())
  logging.info('host_id : %d', jax.process_index())

  # Compare the last '/'-separated component (as the other helpers in this
  # notebook do) so an unprefixed name raises ValueError, not IndexError.
  dataset_name = config.dataset_configs.dataset
  supported_datasets = ('lsm_300min_2000_mood_subject_dependant',)
  if dataset_name.split('/')[-1] not in supported_datasets:
    raise ValueError(f'Dataset {dataset_name} is not supported.')
  dataset_builder = get_subject_dependent_mood_dataset

  batch_size = config.batch_size
  if batch_size % device_count > 0:
    raise ValueError(
        f'Batch size ({batch_size}) must be divisible by the '
        f'number of devices ({device_count})'
    )

  local_batch_size = batch_size // jax.process_count()
  device_batch_size = batch_size // device_count
  logging.info('local_batch_size : %d', local_batch_size)
  logging.info('device_batch_size : %d', device_batch_size)

  # A fixed shuffle seed would make every tf.data service worker emit the
  # same data, so the two options are mutually exclusive.
  shuffle_seed = config.get('shuffle_seed', None)
  if dataset_service_address and shuffle_seed is not None:
    raise ValueError(
        'Using dataset service with a random seed causes each '
        'worker to produce exactly the same data. Add '
        'config.shuffle_seed = None to your config if you want '
        'to run with dataset service.'
    )

  shuffle_buffer_size = config.get('shuffle_buffer_size', None)
  num_local_shards = num_local_shards or jax.local_device_count()

  train_ds, val_ds = dataset_builder(
      config=config,
      num_shards=num_local_shards,
      batch_size=local_batch_size,
      dtype_str=config.data_dtype_str,
      shuffle_seed=shuffle_seed,
      rng=data_rng,
      shuffle_buffer_size=shuffle_buffer_size,
      dataset_service_address=dataset_service_address,
      **kwargs,
  )

  return train_ds, val_ds


In [None]:
DATASET_NAME = 'lsm_300min_2000_mood_subject_dependant'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1

# runlocal must be False to get the full dataset.
config = get_config(runlocal=False)
config.dataset_configs.update({'samples_per_subject': 40, 'repeat': False})

data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))

start_t = time.time()
dataset = get_dataset2(config, data_rng)
train_ds, val_ds = dataset
end_t = time.time()

print('Dataset Time', end_t - start_t)

# Serialize every example so it can be written to TFRecord files.
serialized_train_ds, serialized_val_ds = (
    split.map(tf_serialize_example) for split in (train_ds, val_ds)
)

In [None]:
# @title Save Examples to TF Records

# `os` is not imported by the notebook header or the earlier cells in this
# section; importing it here makes the cell self-sufficient (re-import is a
# no-op if it was already imported).
import os

data_dir='/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/exp/girishvn/ttl=6w/lsm_processed_datasets/'
train_fname = 'processed_mood_2000_train.tfrecord'
test_fname = 'processed_mood_2000_test.tfrecord'

train_fpath = os.path.join(data_dir, train_fname)
val_fpath = os.path.join(data_dir, test_fname)


def write_serialized_examples(serialized_ds, fpath, log_every=100):
  """Writes serialized examples to a single TFRecord file.

  Args:
    serialized_ds: Iterable of `tf.string` tensors, e.g. the output of
      `dataset.map(tf_serialize_example)`.
    fpath: Destination TFRecord path.
    log_every: Print progress every `log_every` examples.

  Returns:
    A `(count, elapsed_seconds)` tuple.
  """
  count = 0
  start = time.time()
  with tf.io.TFRecordWriter(fpath) as writer:
    for serialized_example in serialized_ds:
      writer.write(serialized_example.numpy())
      count += 1
      if count % log_every == 0:
        print(f'Processed {count} examples in {time.time() - start} s.')
  return count, time.time() - start


# Write out train data
train_count, train_time = write_serialized_examples(
    serialized_train_ds, train_fpath
)
print('Train Time', train_time)

# Write out test data
val_count, val_time = write_serialized_examples(serialized_val_ds, val_fpath)
print('Test Time', val_time)


print('DONE DONE DONE')


In [None]:
# @title Example Counts
for split_label, split_count in (
    ('Train Count', train_count),
    ('Test Count', val_count),
):
  print(split_label, split_count)

### Load Preprocessed and Saved Data

In [None]:
# @title Data Loader to Load this Saved Data

"""Electrodes dataset data preprocesser and loader.

Adapted from a combination of the following files:
google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py
google3/third_party/py/scenic/dataset_lib/dataset_utils.py

NOTE: This dataset is a HACKY implementation of the
lsm_mood_subj_dependent_dataset which specifically loads a preprpocessed dataset
where subject dependent splits are already created, and where only subjects with
40+ samples are included.

This was created as part of the ICLR '25 Rebuttal for the LSM paper.
This hacky implementation is necessary as lsm_mood_subj_dependent_dataset.py is
made extremely slow on XM (causing idle failures) due to the need to traverse
the dataset to create the subject dependent splits.

If you are interested in using this dataset, please consider using the
lsm_mood_subj_dependent_dataset.py instead, and / or re-implemting this.
"""


import functools
import os
from typing import Any, Optional

from absl import logging
import jax.numpy as jnp
import jax.profiler
import ml_collections  # pylint: disable=unused-import
from scenic.dataset_lib import dataset_utils
import tensorflow as tf
import tensorflow_datasets as tfds

from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
from google3.experimental.largesensormodels.scenic.datasets import lsm_tiny_dataset


def preprocess_example(example, dataset_name, dtype=tf.float32):
  """Splits the input signal into sensor/datetime features and builds labels.

  Adapted from /largesensormodels/scenic/datasets/lsm_tiny_dataset.py.
  This version works on the pre-processed TFRecord dataset, whose examples
  carry the label fields at the top level (there is no 'metadata' field).

  Args:
    example: dict; Parsed example (see `parse_tfexample_fn`) with keys
      'input_signal', 'label', 'exercise_log', 'mood_log' and 'log_value'.
    dataset_name: str; Dataset name. Only its last '/'-separated component is
      used to look up `dataset_constants.lsm_dataset_constants`.
    dtype: Tensorflow data type; Data type the input signal is cast to.

  Returns:
    A dict with 'input_signal' (sensor features), 'datetime_signal'
    (datetime features), 'label' (one-hot over the dataset's log values),
    'exercise_log', 'mood_log' and 'log_value'.

  Raises:
    ValueError: if the dataset constants define no 'datetime_features'.

  NOTE: This assumes that the signal is in the shape [H, W, C],
    where H is the Time axis, and W is the feature axis.
  """
  dataset_name = dataset_name.split('/')[-1]
  constants = dataset_constants.lsm_dataset_constants[dataset_name]
  features = tf.cast(example['input_signal'], dtype=dtype)

  # Datetime features are required: they are split out of 'input_signal'.
  time_features = constants.get('datetime_features', None)
  if time_features is None:
    raise ValueError(
        f'Dataset {dataset_name} defines no datetime_features in '
        'dataset_constants.'
    )

  # Split the input along the feature (W, axis 1) dimension into datetime
  # features and sensor features. Sensor columns keep ascending index order.
  time_feature_indices = list(time_features['indices'])
  excluded = set(time_feature_indices)
  feature_indices = [i for i in range(features.shape[1]) if i not in excluded]
  time_features = tf.gather(
      features, tf.convert_to_tensor(time_feature_indices), axis=1
  )
  features = tf.gather(features, tf.convert_to_tensor(feature_indices), axis=1)

  # Logged events; each field is parsed with shape [1], so take the scalar.
  # (Exercise and mood events are mutually exclusive.)
  exercise_log = example['exercise_log'][0]
  mood_log = example['mood_log'][0]
  log_value = tf.cast(example['log_value'][0], tf.int32)

  # Map the raw log value to a class index in [0, n_classes) and one-hot it.
  # The label list is shifted by 'log_value_offset' (an artifact of dataset
  # creation) before building the lookup table.
  log_value_offset = tf.cast(constants['log_value_offset'], tf.int32)
  log_value_label_list = (
      tf.convert_to_tensor(constants['log_values']) - log_value_offset
  )
  n_classes = len(constants['log_values'])
  lookup_initializer = tf.lookup.KeyValueTensorInitializer(
      keys=log_value_label_list, values=tf.range(n_classes)
  )
  label_map = tf.lookup.StaticHashTable(lookup_initializer, default_value=-1)
  # Unknown log values map to -1, which tf.one_hot encodes as an all-zero row.
  label_idx = label_map.lookup(log_value)
  return {
      'input_signal': features,
      'datetime_signal': time_features,
      'label': tf.one_hot(label_idx, n_classes),
      'exercise_log': exercise_log,
      'mood_log': mood_log,
      'log_value': log_value,
  }


def parse_tfexample_fn(example, input_shape=(300, 30, 1)):
  """Parses one serialized `tf.train.Example` from the pre-processed TFRecords.

  The feature spec mirrors the fields written by `serialize_example`.

  Args:
    example: `tf.string` tensor holding a serialized `tf.train.Example`.
    input_shape: Shape of the 'input_signal' feature. Defaults to the
      (300, 30, 1) shape of the pre-processed mood dataset.

  Returns:
    A dict with 'input_signal' (float32 of shape `input_shape`), 'label' and
    'log_value' (int64, shape [1]), and 'exercise_log' / 'mood_log' (bool,
    shape [1]).
  """
  feature_spec = {
      'input_signal': tf.io.FixedLenFeature(
          shape=list(input_shape), dtype=tf.float32
      ),
      'label': tf.io.FixedLenFeature(shape=1, dtype=tf.int64),
      'exercise_log': tf.io.FixedLenFeature(shape=1, dtype=tf.int64),
      'mood_log': tf.io.FixedLenFeature(shape=1, dtype=tf.int64),
      'log_value': tf.io.FixedLenFeature(shape=1, dtype=tf.int64),
  }
  parsed_example = tf.io.parse_single_example(example, feature_spec)
  # The log flags were written as int64 features (tf.train.Feature has no
  # bool type); restore their boolean dtype.
  for key in ('exercise_log', 'mood_log'):
    parsed_example[key] = tf.cast(parsed_example[key], tf.bool)
  return parsed_example


def update_metadata(
    metadata, dataset_name, patch_size, dataset_configs
):
  """Update metadata to reflect resizing and addition of datetime features.

  Computes the metadata entries that change once the input signal is split
  into sensor / datetime features, padded for perfect patching and
  (optionally) time cropped.

  Args:
    metadata: dict; Must contain 'input_shape', whose first entry is the
      batch dimension and whose remaining entries are the raw feature shape
      (presumably [H, W, C] with W the feature axis — see NOTE below).
    dataset_name: str; Dataset name. Only the last '/'-separated component is
      used to look up `dataset_constants.lsm_dataset_constants`.
    patch_size: Patch size passed to
      `lsm_tiny_dataset.get_height_crop_width_pad`; `patch_size[0]` is used as
      the patch height along the time axis.
    dataset_configs: Config; an optional 'relative_time_window' `(start, end)`
      pair is read from it (presumably fractions of the time axis in [0, 1]).

  Returns:
    A dict with 'input_shape', 'input_valid_feats', 'datetime_input_shape',
    'datime_valid_feats', 'target_is_onehot', 'num_classes', 'log_values' and
    'log_value_names', to be merged into `metadata` by the caller.
  """
  # Setup: Get dataset name, feature shape, and possible datetime features.
  metadata_update = dict()
  dataset_name = dataset_name.split('/')[-1]
  time_features = dataset_constants.lsm_dataset_constants[dataset_name].get(
      'datetime_features', None
  )
  # Drop the leading batch dimension of 'input_shape'.
  feature_shape = list(metadata['input_shape'][1:])
  feature_indices = list(range(feature_shape[1]))

  # Split features from time series features
  # NOTE: This assumes that the original 'input_signal' field has sensor
  # features concatenated to datetime features along the feature (w) dimension.
  if time_features is not None:
    # Get datetime indices
    time_feature_indices = list(time_features['indices'])
    # Remove datetime indices from feature indices (only the count is used
    # below, so set ordering does not matter here).
    feature_indices = list(set(feature_indices) - set(time_features['indices']))
    # Get updated feature and datetime feature shapes.
    time_feature_shape = feature_shape.copy()  # update time feature shape
    time_feature_shape[1] = len(time_feature_indices)
    feature_shape[1] = len(feature_indices)  # update feature shape
  else:
    time_feature_shape = None

  # Padding: Update shape to reflect padding (for perfect patching).
  # valid_feats arrays denote which features are valid (1) vs padded (0).
  # 1. Update for sensor features
  # pad_w is indexed as a (before, after) pair of widths along W.
  _, pad_w, feat_shape_new = lsm_tiny_dataset.get_height_crop_width_pad(
      tuple(feature_shape), patch_size
  )
  valid_feat_mask = [0] * pad_w[0] + [1] * feature_shape[1] + [0] * pad_w[1]
  metadata_update['input_shape'] = tuple([-1] + list(feat_shape_new))
  metadata_update['input_valid_feats'] = tuple(valid_feat_mask)

  # 2. Update for datetime features
  if time_features is not None:
    _, time_pad_w, time_feature_shape_new = (
        lsm_tiny_dataset.get_height_crop_width_pad(
            tuple(time_feature_shape), patch_size
        )
    )
    valid_time_feat_mask = (
        [0] * time_pad_w[0] + [1] * time_feature_shape[1] + [0] * time_pad_w[1]
    )
    metadata_update['datetime_input_shape'] = tuple(
        [-1] + list(time_feature_shape_new)
    )
    # NOTE(review): the key is spelled 'datime_valid_feats' (sic). Consumers
    # presumably read this exact key; confirm before renaming it.
    metadata_update['datime_valid_feats'] = tuple(valid_time_feat_mask)

  else:
    metadata_update['datetime_input_shape'] = None
    metadata_update['datime_valid_feats'] = None

  # 3. Labels are one-hot encoded over the dataset's log values.
  metadata_update['target_is_onehot'] = True
  metadata_update['num_classes'] = len(
      dataset_constants.lsm_dataset_constants[dataset_name]['log_values']
  )

  # 4. Add dataset log values and log value names.
  log_values = dataset_constants.lsm_dataset_constants[dataset_name].get(
      'log_values', None
  )
  log_value_names = dataset_constants.lsm_dataset_constants[dataset_name].get(
      'log_value_names', None
  )
  metadata_update['log_values'] = log_values
  metadata_update['log_value_names'] = log_value_names

  # 5. Update time cropping: defaults keep the full time axis.
  start, end = dataset_configs.get('relative_time_window', (None, None))
  if end is None:
    end = 1
  if start is None:
    start = 0

  # Time Crop image based on horizon.
  # Get number of patches along time axis (h), then keep only whole patches
  # between the relative start and end positions. This overrides the
  # 'input_shape' set in step 1.
  p_h = patch_size[0]
  h = feat_shape_new[0]
  n_h = h // p_h
  start_idx = int(start * n_h) * p_h
  end_idx = int(end * n_h) * p_h
  metadata_update['input_shape'] = tuple(
      [-1] + [end_idx - start_idx] + list(feat_shape_new)[1:]
  )

  return metadata_update


def get_subject_dependent_mood_dataset(
    *,
    config,
    num_shards,
    batch_size,
    eval_batch_size=None,
    dtype_str='float32',
    shuffle_seed=0,
    rng=None,
    shuffle_buffer_size=None,
    dataset_service_address: Optional[str] = None,
    data_dir='/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/exp/dmcduff/ttl=6w/msa_1_5/lsm_tfds_datasets',
):
  """Gets and formats the pre-processed Subject Dependent Mood Dataset.

  Reads TFRecords in which the subject-dependent train / test splits were
  already created, then builds batched, padded and device-sharded iterators.

  Adapted from:
  google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py and
  google3/third_party/py/scenic/dataset_lib/dataset_utils.py.

  Args:
    config: ml_collections.ConfigDict; Config for the experiment.
    num_shards: int; Number of shards to split the dataset into.
    batch_size: int; Batch size for training.
    eval_batch_size: int; Batch size for evaluation. Defaults to `batch_size`.
    dtype_str: str; Data type of the image.
    shuffle_seed: int; Seed for shuffling the dataset.
    rng: jax.random.PRNGKey; Random number generator key.
    shuffle_buffer_size: int; Size of the shuffle buffer.
    dataset_service_address: str; Address of the dataset service.
    data_dir: str; TFDS directory of the reference dataset
      `lsm_prod/lsm_300min_2000_mood_balanced`. Only used to read the
      'input_signal' feature shape for the metadata.

  Returns:
    A dataset_utils.Dataset object (train iterator, val iterator, no test
    iterator, metadata).

  Raises:
    ValueError: if `dataset_service_address` is set together with a
      non-None `shuffle_seed`.
  """

  # START HARDCODED SECTION
  # As explained in the file header: this is hardcoded to read from a
  # pre-processed version of the subject dependent mood dataset.

  processed_data_dir = (
      '/namespace/fitbit-medical-sandboxes/partner/encrypted/'
      'chr-ards-electrodes/deid/exp/girishvn/ttl=6w/lsm_processed_datasets/'
  )
  train_fname = 'processed_mood_2000_train.tfrecord'
  test_fname = 'processed_mood_2000_test.tfrecord'

  # Reference dataset name, used to query for dataset_constants.
  used_dataset_name = 'lsm_prod/lsm_300min_2000_mood_balanced'

  # Pre-computed label counts per class.
  label_counts = [739, 472, 434, 787, 1138]

  # Pre-computed train and test sample counts.
  num_train_samples = 3570
  num_val_samples = 912
  # END HARDCODED SECTION

  train_fpath = os.path.join(processed_data_dir, train_fname)
  val_fpath = os.path.join(processed_data_dir, test_fname)

  # Setup: General
  if rng is None:
    rng = jax.random.PRNGKey(config.rng_seed)

  # 1. Process information.
  p_idx = jax.process_index()  # current process index
  p_cnt = jax.process_count()  # process count (number of processes)

  aug_rngs = jax.random.split(rng, p_cnt)  # per-process augmentation seeds
  aug_rng = aug_rngs[p_idx]  # this process's augmentation seed
  tf_aug_rng = aug_rng[0]  # jax random seeds are arrays, tf expects an int.
  del rng

  # 2. Dataset and data type information.
  dataset_configs = config.dataset_configs  # get dataset configurations.
  dtype = getattr(tf, dtype_str)  # data dtype
  if eval_batch_size is None:  # set eval batch size
    eval_batch_size = batch_size

  # 3. Repeat dataset.
  repeat_ds = dataset_configs.get('repeat_data', True)

  # Setup: Mapping functions.
  # a) Preprocessing (feature split and one-hot labels).
  preprocess_fn = functools.partial(
      preprocess_example,
      dataset_name=used_dataset_name,
      dtype=dtype
  )
  # b) Augmentation function.
  augment_fn = functools.partial(
      lsm_tiny_dataset.augment_example,
      augmentations=config.get('train_augmentations', []),
      seed=tf_aug_rng,
  )
  # c) Crop and pad features and time features to be patch size compatible.
  crop_and_pad_fn = functools.partial(
      lsm_tiny_dataset.patch_compatible_resize_example,
      patch_size=config.model.patches.size
  )

  # d) Time crop data input (only if a relative time window is configured).
  start, end = dataset_configs.get('relative_time_window', (None, None))
  time_crop_examples = (start is not None) or (end is not None)
  time_crop_fn = functools.partial(
      lsm_tiny_dataset.time_crop_example,
      patch_size=config.model.patches.size,
      start=start,
      end=end
  )

  # Setup: Data splits.
  # Load dataset splits.
  train_ds = tf.data.TFRecordDataset(train_fpath)
  val_ds = tf.data.TFRecordDataset(val_fpath)
  train_ds = train_ds.map(parse_tfexample_fn)
  val_ds = val_ds.map(parse_tfexample_fn)

  # Data processing and preparation.
  # 0. Enable multi threaded workers.
  options = tf.data.Options()
  options.threading.private_threadpool_size = 48
  train_ds = train_ds.with_options(options)
  val_ds = val_ds.with_options(options)

  # 1. Split dataset over host processes.
  train_ds = train_ds.shard(p_cnt, p_idx)
  val_ds = val_ds.shard(p_cnt, p_idx)

  # 2. Preprocessing: Applied before `ds.cache()` to re-use it.
  train_ds = train_ds.map(
      preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  val_ds = val_ds.map(
      preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )

  # 3. Cache datasets: This can significantly speed up training.
  if dataset_configs.cache_dataset:
    train_ds = train_ds.cache()
    val_ds = val_ds.cache()

  # 4. Train repeats and augmentations.
  if repeat_ds:
    train_ds = train_ds.repeat()  # repeat
  # NOTE: Train augmentations are done after repeat for true randomness.
  if config.use_train_augmentations:
    train_ds = train_ds.map(  # train data augmentations
        augment_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )

  # 5. Crop and pad for perfect patching.
  train_ds = train_ds.map(  # crop/pad for perfect patching
      crop_and_pad_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  val_ds = val_ds.map(  # crop/pad for perfect patching
      crop_and_pad_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )

  # 6. Time crop input data.
  if time_crop_examples:
    train_ds = train_ds.map(
        time_crop_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    val_ds = val_ds.map(
        time_crop_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )

  # 7. Data preparation (shuffling, batching, eval repeat, etc.).
  # 7a. Train: Shuffle, batch, prefetch
  shuffle_buffer_size = shuffle_buffer_size or (8 * batch_size)
  train_ds = train_ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)  # shuffle
  train_ds = train_ds.batch(batch_size, drop_remainder=True)  # batch
  train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)  # prefetch

  # 7b. Validation: Batch, Repeat, Prefetch.
  # Batch with eval_batch_size so batches match what maybe_pad_batches_eval
  # pads them to below.
  val_ds = val_ds.batch(eval_batch_size, drop_remainder=False)  # batch
  if repeat_ds:
    val_ds = val_ds.repeat()  # repeat
  val_ds = val_ds.prefetch(tf.data.experimental.AUTOTUNE)  # prefetch

  # Ensure that no seed is set if dataset_service_address is defined.
  if dataset_service_address:
    if shuffle_seed is not None:
      raise ValueError(
          'Using dataset service with a random seed causes each '
          'worker to produce exactly the same data. Add '
          'config.shuffle_seed = None to your config if you '
          'want to run with dataset service.'
      )
    train_ds = dataset_utils.distribute(train_ds, dataset_service_address)
    logging.info('Using the tf.data service at %s', dataset_service_address)

  # Other mappings
  # 1. Set up batch padding: if batch remainders are NOT dropped, partial
  # batches are padded up to the full batch size.
  maybe_pad_batches_train = functools.partial(
      dataset_utils.maybe_pad_batch,
      train=True,
      batch_size=batch_size,
      inputs_key='input_signal',
  )
  maybe_pad_batches_eval = functools.partial(
      dataset_utils.maybe_pad_batch,
      train=False,
      batch_size=eval_batch_size,
      inputs_key='input_signal',
  )

  # 2. Set up batch sharding: Shard batches to be processed by multiple devices.
  shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards)

  # 3. Apply other mappings and Iter dataset
  train_iter = iter(train_ds)
  train_iter = map(dataset_utils.tf_to_numpy, train_iter)
  train_iter = map(maybe_pad_batches_train, train_iter)
  train_iter = map(shard_batches, train_iter)

  val_iter = iter(val_ds)
  val_iter = map(dataset_utils.tf_to_numpy, val_iter)
  val_iter = map(maybe_pad_batches_eval, val_iter)
  val_iter = map(shard_batches, val_iter)

  # Save meta data
  info = tfds.builder(used_dataset_name, data_dir=data_dir, try_gcs=True).info
  input_shape = tuple([-1] + list(info.features['input_signal'].shape))
  meta_data = {
      'input_shape': input_shape,
      'num_train_examples': num_train_samples,
      'num_val_examples': num_val_samples,
      'num_test_examples': 0,
      'input_dtype': getattr(jnp, dtype_str),
      'label_counts': label_counts,
      # The following two fields are set as defaults and may be updated in the
      # update_metadata function below.
      'target_is_onehot': False,
      'num_classes': None,
  }

  # Update metadata to reflect preprocessing, and paddings
  # (Changes in shape, and features).
  meta_data.update(
      update_metadata(
          meta_data,
          dataset_name=used_dataset_name,
          patch_size=config.model.patches.size,
          dataset_configs=dataset_configs,
      )
  )

  # Return dataset structure.
  return dataset_utils.Dataset(train_iter, val_iter, None, meta_data)


def get_preprocessed_dataset(
    config: Any,
    data_rng: jnp.ndarray,
    *,
    num_local_shards: Optional[int] = None,
    dataset_service_address: Optional[str] = None,
    **kwargs: Any,
) -> dataset_utils.Dataset:
  """Validates the config and builds the preprocessed dataset.

  Adapted from: google3/third_party/py/scenic/train_lib/train_utils.py.

  Args:
    config: Experiment config. Reads `dataset_configs.dataset`, `batch_size`,
      `data_dtype_str`, and optionally `shuffle_seed` and
      `shuffle_buffer_size`. The whole config is also forwarded to the
      dataset builder.
    data_rng: JAX PRNG key, forwarded to the dataset builder as `rng`.
    num_local_shards: Number of shards per host, forwarded as `num_shards`.
      Defaults to `jax.local_device_count()` when None (or 0).
    dataset_service_address: Optional dataset service address, forwarded to
      the builder. Must not be combined with a non-None `shuffle_seed`.
    **kwargs: Extra keyword arguments forwarded verbatim to the builder.

  Returns:
    The value returned by the dataset builder (annotated as
    `dataset_utils.Dataset`).

  Raises:
    ValueError: If `config.dataset_configs.dataset` is not of the form
      '<prefix>/<name>' with a supported <name>, if `config.batch_size` is not
      divisible by the global device count, or if a shuffle seed is set
      together with a dataset service address.
  """
  # Device / host topology.
  device_count = jax.device_count()
  logging.info('device_count: %d', device_count)
  logging.info('num_hosts : %d', jax.process_count())
  logging.info('host_id : %d', jax.process_index())

  # Select the dataset builder. Only supported, non-deprecated datasets are
  # listed here.
  dataset_name = config.dataset_configs.dataset
  supported_datasets = (
      'lsm_300min_2000_mood_subject_dependent_preprocessed_40spc',
  )
  # Names are expected as '<prefix>/<name>' and matched on the component after
  # the first '/'. A name with no '/' is rejected with the same ValueError
  # rather than crashing with an IndexError.
  name_parts = dataset_name.split('/')
  if len(name_parts) < 2 or name_parts[1] not in supported_datasets:
    raise ValueError(f'Dataset {dataset_name} is not supported.')
  dataset_builder = get_subject_dependent_mood_dataset

  # The global batch must split evenly across all devices.
  batch_size = config.batch_size
  if batch_size % device_count > 0:
    raise ValueError(
        f'Batch size ({batch_size}) must be divisible by the '
        f'number of devices ({device_count})'
    )

  # Per-host batch (handed to the builder) and per-device batch (logged only).
  local_batch_size = batch_size // jax.process_count()
  device_batch_size = batch_size // device_count
  logging.info('local_batch_size : %d', local_batch_size)
  logging.info('device_batch_size : %d', device_batch_size)

  # A fixed shuffle seed is incompatible with the dataset service: every
  # worker would produce exactly the same data.
  shuffle_seed = config.get('shuffle_seed', None)
  if dataset_service_address and shuffle_seed is not None:
    raise ValueError(
        'Using dataset service with a random seed causes each '
        'worker to produce exactly the same data. Add '
        'config.shuffle_seed = None to your config if you want '
        'to run with dataset service.'
    )

  shuffle_buffer_size = config.get('shuffle_buffer_size', None)
  num_local_shards = num_local_shards or jax.local_device_count()

  return dataset_builder(
      config=config,
      num_shards=num_local_shards,
      batch_size=local_batch_size,
      dtype_str=config.data_dtype_str,
      shuffle_seed=shuffle_seed,
      rng=data_rng,
      shuffle_buffer_size=shuffle_buffer_size,
      dataset_service_address=dataset_service_address,
      **kwargs,
  )



In [None]:
DATASET_NAME = 'lsm_300min_2000_mood_subject_dependent_preprocessed_40spc'
TRAIN_DATA_SIZE = None
BATCH_SIZE = 1
# NOTE(review): DATASET_NAME, TRAIN_DATA_SIZE and BATCH_SIZE are not applied to
# `config` below; the dataset name and batch size used come from get_config().

# runlocal must be False, otherwise only a reduced dataset is loaded.
config = get_config(runlocal=False)
config.dataset_configs.update({'samples_per_subject': 40, 'repeat': False})

# Derive the data key from the experiment seed; `rng` keeps the other half.
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))

start_t = time.time()
dataset = get_preprocessed_dataset(config, data_rng)
end_t = time.time()
print('Dataset Time', end_t - start_t)

# Dump every metadata entry, one "key value" pair per line.
print('\nProcessed Dataset Meta Data:\n')
for meta_key, meta_value in dataset.meta_data.items():
  print(meta_key, meta_value)

In [None]:
# @title Train Sample Breakdown

# NOTE: dataset.train_iter appears to be a one-shot iterator (a `map` over
# `iter(train_ds)`); re-running this cell without rebuilding `dataset`
# presumably yields no batches.
start_t = time.time()

label_list = []
batch_count = 0
for d in dataset.train_iter:
  if batch_count % 1000 == 0:
    # Progress: batches processed so far and seconds elapsed.
    print(batch_count, round(time.time() - start_t, 1))
  batch_count += 1

  # Keep only entries whose batch_mask is 1; batch_mask presumably flags
  # padding added by maybe_pad_batch.
  valid = np.where(d['batch_mask'] == 1)
  label_list.extend(d['log_value'][valid].tolist())

end_t = time.time()

print('Time', end_t - start_t)
print('\nTrain Data Splits:')
# Count occurrences per label; sorted so the printout order is stable.
mood_counter = Counter(label_list)
for label in sorted(mood_counter):
  print(label, mood_counter[label])

In [None]:
# @title Valid Sample Breakdown

# NOTE: dataset.valid_iter appears to be a one-shot iterator (a `map` over
# `iter(val_ds)`); re-running this cell without rebuilding `dataset`
# presumably yields no batches.
start_t = time.time()

label_list = []
batch_count = 0
for d in dataset.valid_iter:
  if batch_count % 1000 == 0:
    # Progress: batches processed so far and seconds elapsed.
    print(batch_count, round(time.time() - start_t, 1))
  batch_count += 1

  # Keep only entries whose batch_mask is 1; batch_mask presumably flags
  # padding added by maybe_pad_batch.
  valid = np.where(d['batch_mask'] == 1)
  label_list.extend(d['log_value'][valid].tolist())

end_t = time.time()

print('Time', end_t - start_t)
print('\nValid Data Splits:')
# Count occurrences per label; sorted so the printout order is stable.
mood_counter = Counter(label_list)
for label in sorted(mood_counter):
  print(label, mood_counter[label])