1.  Request an
    [AoD](https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-fitbit-prod-research-deid-eng-team:r&reason=%22b%2F285178698%22)
    grant before running this colab.
2.  Run `experimental/health_foundation_models/colab/colab_launch_borg_cpu.sh`
3.  Select `health_foundation_models_cpu` as the Colab kernel.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from google3.medical.waveforms.modelling.lsm.datasets.lsm import sensor_constants
from google3.pyglib import gfile

In [None]:
# Root directory that the dataset-check and inspection cells below read from
# (exposed as an editable Colab form field via the `@param` annotation).
DEFAULT_DATA_ROOT = '/namespace/fitbit-medical-sandboxes/jg/partner/encrypted/chr-ards-fitbit-prod-research/deid/exp/dmcduff/ttl=52w/lsm_v2/datasets/tfds'  # @param {type:"string"}

In [None]:
# @title Constants
# Heatmap row labels: `visualize` passes these as `yticklabels` when it draws
# every row of a signal (i.e. when `dim` is None).
labels = sensor_constants.FEATURES_TO_INCLUDE

In [None]:
# @title Check the paths

# List every entry directly under the data root. Leaving the list as the
# cell's final expression makes Colab render it as the output.
dataset_glob = f'{DEFAULT_DATA_ROOT}/*'
latest_dataset_paths = gfile.Glob(dataset_glob)
latest_dataset_paths

In [None]:
# @title Utils


def check_dataset_length(
    root_data_path: str, dataset_name: str, builder_dir: str = 'lsm'
) -> None:
  """Prints a dataset's name and how many paths its glob check matches.

  Counts the paths matched by
  `{root_data_path}/{dataset_name}/{builder_dir}/*/*` with `gfile.Glob`, i.e.
  every entry two directory levels below the builder directory.

  NOTE(review): in the usual TFDS on-disk layout (`<builder>/<version>/...`)
  those entries are files (record shards plus metadata such as
  `dataset_info.json`), not individual examples. The printed value is
  therefore presumably a file count; `inspect_dataset` prints per-split
  example counts.

  Args:
    root_data_path: The root directory where the dataset is located.
    dataset_name: The name of the dataset directory to check.
    builder_dir: Name of the builder directory inside the dataset directory.
      Defaults to 'lsm', the previously hard-coded value. For a CamelCase
      builder this is presumably its snake_case form (e.g.
      'lsm_missing_balanced' for 'LsmMissingBalanced'); verify on disk.

  Returns:
    None. This function only prints.
  """
  pattern = f'{root_data_path}/{dataset_name}/{builder_dir}/*/*'
  print('Dataset Name:', dataset_name)
  print('Number of Data Sample:', len(gfile.Glob(pattern)))


def inspect_dataset(
    root_data_path: str,
    dataset_name: str,
    data_class: str = 'lsm',
    num_samples: int = 5,
):
  """Loads a TFDS dataset, prints its split sizes, and plots a few examples.

  The 'train' split is reported on a best-effort basis. Presumably some
  datasets (e.g. the '*_valid_dataset' test set) have no train split, so a
  failure there is reported and skipped. The 'valid' split is required. For
  the first `num_samples` 'valid' examples, the serialized 'mask' and
  'input_signal' tensors are parsed, transposed, and drawn with `visualize`.

  Args:
    root_data_path: The root directory where the dataset is located.
    dataset_name: Name of the dataset directory under `root_data_path`. The
      joined path is passed to `tfds.load` as `data_dir`.
    data_class: TFDS builder name passed to `tfds.load` (default: 'lsm').
    num_samples: Number of 'valid' examples to visualize (default: 5).

  Returns:
    None. This function prints information and displays visualizations.
  """
  print('Dataset Name:', dataset_name)
  data_dir = f'{root_data_path}/{dataset_name}'

  try:
    train_data = tfds.load(
        data_class,
        data_dir=data_dir,
        split='train',
        shuffle_files=False,
    )
    print('Number of Train Data Sample:', len(train_data))
  except Exception as err:  # pylint: disable=broad-exception-caught
    # Best-effort: continue with the 'valid' split. Unlike the previous bare
    # `except: pass`, this reports the reason and lets KeyboardInterrupt and
    # SystemExit propagate.
    print('Skipping train split:', err)

  valid_data = tfds.load(
      data_class,
      data_dir=data_dir,
      split='valid',
      shuffle_files=False,
  )
  print('Number of Valid Data Sample:', len(valid_data))

  for example in valid_data.take(num_samples):
    # Both features are stored as serialized tensors. They are transposed so
    # that axis 0 indexes the rows `visualize` labels and axis 1 the time axis
    # it ticks in minutes. NOTE(review): this assumes the stored layout is
    # (time, feature); confirm against the dataset builder.
    mask = tf.io.parse_tensor(example['mask'], out_type=tf.bool).numpy().T
    signal = (
        tf.io.parse_tensor(example['input_signal'], out_type=tf.double)
        .numpy()
        .T
    )

    visualize(mask, cmap='Greys')
    visualize(signal)

    print('--------------------------')


def visualize(
    sample_signal_input,
    figsize=(20, 5),
    title='',
    cmap='cool',
    dim=None,
    cbar=True,
    disabletext=False,
    xtick_interval=4 * 60,
):
  """Draws a 2-D signal as a heatmap with hour and day grid lines.

  Rows of `sample_signal_input` become heatmap rows. When `dim` is None they
  are labelled with the module-level `labels`. Columns are treated as minutes:
  x ticks are placed every `xtick_interval` columns and formatted with
  `minutes_to_time`. A faint vertical line is drawn every 60 columns, a
  thicker opaque one every 60 * 24 columns, and a horizontal line at every
  row boundary.

  Args:
    sample_signal_input: 2-D NumPy array to draw (rows x columns/minutes).
    figsize: Tuple specifying the width and height of the figure (default: (20,
      5)).
    title: The title of the plot (default: '').
    cmap: The colormap to use for the heatmap (default: 'cool').
    dim: Optional row index to draw on its own (default: None, which draws all
      rows). The selected row is shown without a y label.
    cbar: Whether to display the colorbar (default: True).
    disabletext: Whether to disable the plot title (default: False).
    xtick_interval: Spacing of the labelled x ticks, in columns/minutes
      (default: 240, i.e. every 4 hours).

  Returns:
    None. This function displays the heatmap visualization.
  """
  if dim is not None:
    # Index with a list so the selection stays 2-D (1 x columns).
    sample_signal_input = sample_signal_input[[dim], :]
    row_labels = []
  else:
    row_labels = labels

  num_rows = sample_signal_input.shape[0]
  num_cols = sample_signal_input.shape[1]

  plt.figure(figsize=figsize)
  ax1 = plt.subplot2grid((1, 12), (0, 0), colspan=12)
  ax1 = sns.heatmap(
      sample_signal_input,
      cmap=cmap,
      cbar=cbar,
      linewidths=0.0,
      linecolor='black',
      alpha=0.8,
      ax=ax1,
      yticklabels=row_labels,
  )

  # Replace seaborn's column-index ticks with HH:MM labels.
  xticks = np.arange(0, num_cols, xtick_interval)
  ax1.set_xticks(xticks)
  ax1.set_xticklabels(
      [minutes_to_time(x) for x in xticks], rotation=45, ha='right'
  )

  # Style the tick labels only after the final x labels are in place.
  for tick in [*ax1.get_xticklabels(), *ax1.get_yticklabels()]:
    tick.set_fontname('Ubuntu')
  ax1.tick_params(axis='both', labelsize=10.5)

  if not disabletext:
    # Target the heatmap axes explicitly rather than pyplot's current axes.
    ax1.set_title(title)

  minutes_per_day = 60 * 24
  for col in range(0, num_cols, 60):
    is_day_boundary = col % minutes_per_day == 0
    ax1.axvline(
        x=col,
        color='k',
        alpha=1 if is_day_boundary else 0.4,
        linewidth=2 if is_day_boundary else 1,
    )

  for row in range(num_rows + 1):
    ax1.axhline(y=row, color='k', alpha=0.4, linewidth=1)

  plt.tight_layout()
  plt.show()


def consecutive_ones_lengths(mask):
  """Returns the length of every run of consecutive 1s in a binary mask.

  The mask is padded with a 0 on each end and differenced. Each run of 1s
  therefore produces exactly one +1 (at the run's first index) and one -1 (one
  past the run's last index), and the run length is the distance between the
  paired transitions. Useful for measuring gaps or contiguous stretches in
  mask-style data.

  Args:
    mask: A 1-D array-like of 0/1 (or boolean) values.

  Returns:
    A 1-D integer NumPy array with one length per run, in order of appearance.
    It is empty when the mask contains no 1s.
  """
  edges = np.diff(mask, prepend=0, append=0)
  rise_positions = np.nonzero(edges == 1)[0]
  fall_positions = np.nonzero(edges == -1)[0]
  return fall_positions - rise_positions


def minutes_to_time(x):
  """Formats a minute count as a zero-padded 'HH:MM' string.

  Hours are not wrapped at 24, so 1500 minutes is rendered as '25:00'.

  Args:
    x: Number of minutes (int, float, or NumPy scalar).

  Returns:
    The count rendered as '<hours>:<minutes>', each padded to two digits.
  """
  whole_hours, leftover_minutes = divmod(x, 60)
  return '{:02d}:{:02d}'.format(int(whole_hours), int(leftover_minutes))


class StopExecution(Exception):
  """Exception for halting a notebook cell without printing a traceback.

  IPython calls an exception's `_render_traceback_` method, when defined, to
  get the lines it displays. Returning no lines makes the interruption silent,
  which keeps the output uncluttered when a cell stops on purpose.
  """

  def _render_traceback_(self):
    """Provides an empty traceback so nothing is rendered."""
    no_lines = list()
    return no_lines

In [None]:
# @title Check the Data Length
# Run the glob-based size check on the 20%, 50% and 80% missing-ratio
# pretraining datasets, in that order.
for missing_ratio_dataset_name in (
    'lsm_v2_pretrain_sessions_-1_windowsize_1440_sensorfeatures_26_validonly_False_missingratio_0.2_timestamp_202503080218',
    'lsm_v2_pretrain_sessions_-1_windowsize_1440_sensorfeatures_26_validonly_False_missingratio_0.5_timestamp_202503090320',
    'lsm_v2_pretrain_sessions_-1_windowsize_1440_sensorfeatures_26_validonly_False_missingratio_0.8_timestamp_202503091557',
):
  check_dataset_length(DEFAULT_DATA_ROOT, missing_ratio_dataset_name)

In [None]:
# @title Inspect the 20% Missing Daily Data

inspect_dataset(
    root_data_path=DEFAULT_DATA_ROOT,
    dataset_name='lsm_v2_pretrain_sessions_-1_windowsize_1440_sensorfeatures_26_validonly_False_missingratio_0.2_timestamp_202503080218',
    data_class='lsm',
)

In [None]:
# @title Inspect the 50% Missing Daily Data

inspect_dataset(
    root_data_path=DEFAULT_DATA_ROOT,
    dataset_name='lsm_v2_pretrain_sessions_-1_windowsize_1440_sensorfeatures_26_validonly_False_missingratio_0.5_timestamp_202503090320',
    data_class='lsm',
)

In [None]:
# @title Inspect the 80% Missing Daily Data

inspect_dataset(
    root_data_path=DEFAULT_DATA_ROOT,
    dataset_name='lsm_v2_pretrain_sessions_-1_windowsize_1440_sensorfeatures_26_validonly_False_missingratio_0.8_timestamp_202503091557',
    data_class='lsm',
)

In [None]:
# @title Inspect the perfectly balanced test dataset

# This dataset lives under a separate `tfds_test` root and uses its own
# builder name rather than 'lsm'.
inspect_dataset(
    root_data_path='/namespace/fitbit-medical-sandboxes/jg/partner/encrypted/chr-ards-fitbit-prod-research/deid/exp/dmcduff/ttl=52w/lsm_v2/datasets/tfds_test',
    dataset_name='lsm_v2_missing_balanced_20250301_valid_dataset',
    data_class='LsmMissingBalanced',
)