In [None]:
import numpy as np
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
from google3.pyglib import gfile
from matplotlib import pyplot as plt

In [None]:
# Number of examples kept (via `take`) from each source dataset.
SAMPLES_PER_DATASET = 2000
# Root directory: source datasets are globbed here and the new one is written here.
DEFAULT_DATA_ROOT = (
    '/namespace/fitbit-medical-sandboxes/jg/partner/encrypted/'
    'chr-ards-fitbit-prod-research/deid/exp/dmcduff/ttl=52w/'
    'lsm_v2/datasets/tfds_test'
)
# Sub-directory of DEFAULT_DATA_ROOT that receives the concatenated dataset.
DATASET_NAME = 'lsm_v2_missing_balanced_20250301_valid_dataset_bounded_50p'
# Builder name handed to `tfds.load` when reading the saved dataset back.
DATA_CLASS = 'LsmMissingBalanced'

In [None]:
# @title Creating concatenated_dataset

def load_dataset(
    dataset_path: str,
    use_shuffle: bool,
    sample_size: int,
    shuffle_buffer_size: int = 100000,
    shuffle_seed=None,
):
  """Loads the 'valid' split of the `lsm` TFDS dataset and keeps a prefix of it.

  Args:
    dataset_path: TFDS `data_dir` that contains the `lsm` dataset to read.
    use_shuffle: If True, shuffle examples before keeping `sample_size` of them.
    sample_size: Number of examples to keep (passed to `Dataset.take`).
    shuffle_buffer_size: Buffer size for `Dataset.shuffle`. Only used when
      `use_shuffle` is True.
    shuffle_seed: Optional seed for `Dataset.shuffle`, so the selected subset is
      reproducible. None keeps TensorFlow's default (unseeded) behavior.

  Returns:
    A `tf.data.Dataset` with at most `sample_size` examples from the 'valid'
    split.
  """
  data_valid = tfds.load(
      'lsm',
      data_dir=dataset_path,
      split='valid',
      shuffle_files=False,
  )
  if use_shuffle:
    data_valid = data_valid.shuffle(shuffle_buffer_size, seed=shuffle_seed)
  return data_valid.take(sample_size)


# How many matching source datasets to concatenate.
NUM_SOURCE_DATASETS = 3

# Sorted so that the selected subset, and `all_paths[-1]` (read by the builder
# class for metadata), do not depend on the order the filesystem lists files.
all_paths = sorted(
    gfile.Glob(
        f'{DEFAULT_DATA_ROOT}/lsm_v2_test_sessions_-1_windowsize_1440_sensorfeatures_26_validonly_True_missingratio*'
    )
)[:NUM_SOURCE_DATASETS]
if not all_paths:
  # Fail with a clear message instead of an IndexError on all_tf_ds[0] below.
  raise FileNotFoundError(
      'No source datasets matched the missingratio* pattern under '
      f'{DEFAULT_DATA_ROOT}'
  )

all_tf_ds = []
for path in all_paths:
  print(path)
  all_tf_ds.append(load_dataset(path, False, SAMPLES_PER_DATASET))

# Chain the per-source datasets end to end into a single dataset.
concatenated_dataset = all_tf_ds[0]
for ds in all_tf_ds[1:]:
  concatenated_dataset = concatenated_dataset.concatenate(ds)
print('Concated Length: ', len(concatenated_dataset))

In [None]:
# @title Save TFDS by inheriting dataset from prior datasets

# TFDS data_dir that the concatenated dataset is written under.
output_dir = '/'.join([DEFAULT_DATA_ROOT, DATASET_NAME])


class LsmMissingBalanced(tfds.core.GeneratorBasedBuilder):
  """TFDS builder that serializes an in-memory `tf.data.Dataset` as 'valid'.

  Description, features, homepage, supervised keys and citation are copied
  from the source `lsm` dataset at `all_paths[-1]`, so the written dataset
  uses the same feature schema as its sources.
  """

  VERSION = tfds.core.Version('2.0.0')

  def __init__(self, *args, dataset=None, **kwargs):
    """Creates the builder.

    Args:
      *args: Forwarded to `tfds.core.GeneratorBasedBuilder`.
      dataset: Optional `tf.data.Dataset` to serialize. When None, the
        module-level `concatenated_dataset` is used.
      **kwargs: Forwarded to `tfds.core.GeneratorBasedBuilder` (e.g.
        `data_dir`).
    """
    # Assigned before super().__init__, since the base constructor may already
    # build `self.info`, and `_info` reads this attribute.
    self._lsm_source_dataset = (
        concatenated_dataset if dataset is None else dataset
    )
    super().__init__(*args, **kwargs)

  def _info(self):
    """Builds DatasetInfo from the source `lsm` dataset's metadata."""
    original_info = tfds.builder('lsm', data_dir=all_paths[-1]).info

    # Declare a single 'valid' split sized to the dataset being written.
    # NOTE(review): num_bytes is copied from one source dataset and does not
    # describe the concatenated data. Presumably download_and_prepare
    # recomputes split statistics after generation; confirm.
    updated_split_info = tfds.core.SplitInfo(
        name='valid',
        shard_lengths=[len(self._lsm_source_dataset)],
        num_bytes=original_info.splits['valid'].num_bytes,
    )

    return tfds.core.DatasetInfo(
        builder=self,
        description=original_info.description,
        features=original_info.features,
        homepage=original_info.homepage,
        split_dict={'valid': updated_split_info},
        supervised_keys=original_info.supervised_keys,
        citation=original_info.citation,
    )

  def _split_generators(self, dl_manager):
    """Returns one 'valid' split generated from the source dataset."""
    del dl_manager  # Nothing is downloaded; examples come from memory.
    return [
        tfds.core.SplitGenerator(
            name='valid',
            gen_kwargs={'dataset': self._lsm_source_dataset},
        ),
    ]

  def _generate_examples(self, dataset):
    """Yields (index, example) pairs, keyed by position in `dataset`."""
    for idx, example in enumerate(dataset):
      example_dict = {
          key: self._convert_tensor_to_native(val)
          for key, val in example.items()
      }
      yield idx, example_dict

  def _convert_tensor_to_native(self, tensor):
    """Converts every `tf.Tensor` in `tensor` to a NumPy value.

    `tensor` may be a single value or a nested dict/list/tuple. The structure
    is traversed with `tf.nest`, so nested feature dicts are converted too.
    Leaves that are not `tf.Tensor` (NumPy arrays, Python scalars, ...) are
    returned unchanged.
    """
    return tf.nest.map_structure(
        lambda leaf: leaf.numpy() if isinstance(leaf, tf.Tensor) else leaf,
        tensor,
    )


# Function to save the concatenated dataset using the custom builder
def save_tfds_dataset(output_dir):
  """Writes the dataset to disk in TFDS format using `LsmMissingBalanced`.

  Instantiates the builder rooted at `output_dir`, runs
  `download_and_prepare`, and then prints the destination.

  Args:
    output_dir: TFDS `data_dir` to write the dataset under.
  """
  LsmMissingBalanced(data_dir=output_dir).download_and_prepare()
  print(f'Dataset saved at: {output_dir}')


# Materialize the concatenated dataset on disk under output_dir.
save_tfds_dataset(output_dir=output_dir)

In [None]:
# Y-axis names for the rows of each heatmap drawn by `visualize`, in row order.
labels = [
    'HR', 'eda_level_real', 'leads_contact_counts', 'steps',
    'jerk_auto', 'log_energy', 'covariance', 'log_energy_ratio',
    'zero_crossing_std', 'zero_crossing_avg', 'axis_mean', 'altim_std',
    'kurtosis', 'sleep_coefficient', 'wrist_temperatures',
    'hrv_shannon_entropy_rr', 'hrv_shannon_entropy_rrd',
    'ceda_slope_real_micro_siemens', 'rmssd_percentile_0595',
    'sdnn_percentile_0595', 'hrv_percent_good',
    'hrv_rr_80th_percentile_mean', 'hrv_rr_20th_percentile_mean',
    'hrv_rr_median', 'hr_at_rest_mean', 'skin_temperature_slope',
]

def minutes_to_time(x):
  """Formats a minute count as a zero-padded 'HH:MM' string.

  The value is split into whole hours and leftover minutes. Hours are not
  wrapped at 24, so e.g. 1500 becomes '25:00'. Fractional minutes are
  truncated.

  Args:
    x: Number of minutes (int, float, or NumPy scalar).

  Returns:
    The 'HH:MM' representation of `x`.
  """
  whole_hours, leftover_minutes = divmod(x, 60)
  return '{:02d}:{:02d}'.format(int(whole_hours), int(leftover_minutes))


def inspect_dataset(
    root_data_path: str,
    dataset_name: str,
    data_class: str,
    num_samples: int = 5,
):
  """Loads a saved TFDS dataset, prints split sizes, and plots a few samples.

  Loads `data_class` from `{root_data_path}/{dataset_name}` with `tfds.load`.
  The 'train' split is optional: if it cannot be loaded, a short notice is
  printed and inspection continues. The 'valid' split is required. For each of
  the first `num_samples` validation examples, the serialized 'mask' and
  'input_signal' tensors are parsed, transposed, and drawn with `visualize`.

  Args:
    root_data_path: Root directory that contains the dataset directory.
    dataset_name: Name of the dataset directory under `root_data_path`.
    data_class: TFDS builder name passed to `tfds.load`.
    num_samples: Number of validation examples to visualize (default 5).

  Returns:
    None. This function prints information and displays figures.
  """
  print('Dataset Name:', dataset_name)
  data_dir = f'{root_data_path}/{dataset_name}'

  # Datasets written by LsmMissingBalanced only have a 'valid' split, so a
  # missing 'train' split is reported rather than aborting the inspection.
  try:
    data_train = tfds.load(
        data_class, data_dir=data_dir, split='train', shuffle_files=False
    )
    print('Number of Train Data Sample:', len(data_train))
  except Exception as e:  # pylint: disable=broad-except
    print(f'No train split loaded ({type(e).__name__}: {e})')

  data_valid = tfds.load(
      data_class, data_dir=data_dir, split='valid', shuffle_files=False
  )
  print('Number of Valid Data Sample:', len(data_valid))

  for example in data_valid.take(num_samples):
    # Both fields are serialized tensors. They are transposed before plotting;
    # presumably they are stored as (time, feature). TODO confirm.
    mask = tf.io.parse_tensor(example['mask'], out_type=tf.bool).numpy().T
    signal = (
        tf.io.parse_tensor(example['input_signal'], out_type=tf.double)
        .numpy()
        .T
    )

    visualize(mask, cmap='Greys')
    visualize(signal)

    print('--------------------------')


def visualize(
    sample_signal_input,
    figsize=(20, 5),
    title='',
    cmap='cool',
    dim=None,
    cbar=True,
    disabletext=False,
):
  """Visualizes a 2-D sample signal as a heatmap with hour/day grid lines.

  Each row is one feature channel, labeled from the module-level `labels`
  list. Each column is treated as one minute: x ticks are placed every 4 hours
  and formatted with `minutes_to_time`. Thin vertical lines mark every hour,
  and thick lines mark every 24 hours.

  Args:
    sample_signal_input: 2-D NumPy array of shape (rows, columns).
    figsize: Width and height of the figure (default: (20, 5)).
    title: Title of the plot (default: '').
    cmap: Colormap for the heatmap (default: 'cool').
    dim: Optional row index. If given, only that row is plotted and y tick
      labels are hidden (default: None, plots all rows).
    cbar: Whether to draw the colorbar (default: True).
    disabletext: If True, the title is not drawn (default: False).

  Returns:
    None. This function displays the heatmap.
  """
  if dim is not None:
    sample_signal_input = sample_signal_input[[dim], :]
    labels_temp = []
  else:
    labels_temp = labels

  num_rows = sample_signal_input.shape[0]
  num_cols = sample_signal_input.shape[1]

  fig, ax1 = plt.subplots(figsize=figsize)
  sns.heatmap(
      sample_signal_input,
      cmap=cmap,
      cbar=cbar,
      linewidths=0.0,
      linecolor='black',
      alpha=0.8,
      ax=ax1,
      yticklabels=labels_temp,
  )

  # X ticks every 4 hours (one column per minute).
  tick_interval = 4 * 60
  xticks = np.arange(0, num_cols, tick_interval)
  ax1.set_xticks(xticks)
  ax1.set_xticklabels(
      [minutes_to_time(x) for x in xticks], rotation=45, ha='right'
  )

  # Style tick labels only after the x labels above are set, so the final
  # labels (not the ones seaborn created) receive the font.
  for tick in ax1.get_xticklabels() + ax1.get_yticklabels():
    tick.set_fontname('Ubuntu')
  ax1.tick_params(axis='x', labelsize=10.5)
  ax1.tick_params(axis='y', labelsize=10.5)

  if not disabletext:
    # Target the heatmap axes explicitly rather than pyplot's "current" axes.
    ax1.set_title(title)

  # Hour guides; day boundaries (every 1440 columns) are drawn thicker/opaque.
  for i in np.arange(0, num_cols, 60):
    if i % (60 * 24) == 0:
      tempwidth, tempalpha = 2, 1
    else:
      tempwidth, tempalpha = 1, 0.4
    ax1.axvline(x=i, color='k', alpha=tempalpha, linewidth=tempwidth)

  # Horizontal separators between rows (including top and bottom edges).
  for i in np.arange(0, num_rows + 1, 1):
    ax1.axhline(y=i, color='k', alpha=0.4, linewidth=1)

  fig.tight_layout()
  plt.show()


inspect_dataset(
    root_data_path=DEFAULT_DATA_ROOT,
    dataset_name=DATASET_NAME,
    data_class=DATA_CLASS,
)

In [None]:
# @title Check the distribution of missing ratio

data_valid = tfds.load(
    DATA_CLASS,
    data_dir=f'{DEFAULT_DATA_ROOT}/{DATASET_NAME}',
    split='valid',
    shuffle_files=False,
)
print('Number of Valid Data Sample:', len(data_valid))

# Per-example missingness ratio over the whole validation split.
missing_ratios_all = [
    example['missingness_ratio'].numpy() for example in data_valid
]

# Edges 0.2, 0.3, ..., 0.8. np.arange(0.2, 0.8, 0.1) yields edges slightly
# above these values (e.g. 0.5000000000000001), which pushes a ratio of
# exactly 0.5 into the 0.4-0.5 bin. linspace gives the intended edges.
bins = np.linspace(0.2, 0.8, 7)
ax = sns.histplot(missing_ratios_all, bins=bins, kde=True)
ax.set(
    xlabel='missingness_ratio',
    ylabel='Count',
    title='Missing-ratio distribution (valid split)',
)
plt.show()