https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-fitbit-prod-research-deid-eng-team:r&reason=%22b%2F285178698%22

Colab kernel: Fitbit Prod Research Colab. Please follow these steps:

-   Use the Fitbit prod kernel;
-   Restart the session;
-   Add `import tensorflow_datasets as tfds` to the top;
-   Run the ad_hoc import.

In [None]:
# @title Imports and Utils
import datetime
from typing import List
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from google3.fitbit.research.sensing.common.colab import metadata_database_helpers
from google3.fitbit.research.sensing.common.infra.transforms import data_loading
from google3.fitbit.research.sensing.common.infra.utils import data_intermediates
from google3.fitbit.research.sensing.common.proto import data_key_pb2
from google3.fitbit.research.sensing.kereru.utils import data_loader
# from google3.medical.waveforms.modelling.lsm.datasets.lsm import sensor_constants
from google3.pyglib import gfile


# Per-feature de-normalization parameters stored as [offset, scale].
# plot_all_features maps model-space values back to feature units with
# `x * scale + offset`. The numbers look like per-feature [mean, std]
# statistics — TODO(review): confirm against the data pipeline.
# Some keys (e.g. 'step_count', 'hrv_percentage_of_nn_30',
# 'ceda_magnitude_real_micro_siemens', 'msa_probability', 'hrv_rr_mean',
# 'skin_temperature_magnitude') are not in FEATURES_TO_INCLUDE and are not
# used by the visible code.
NORMALIZATION_PARAMETERS = {
    'HR': [81.77487101546053, 14.071882463657419],
    'eda_level_real': [4.610644695036417, 4.038640434725605],
    'leads_contact_counts': [232.12929124219374, 47.452852397806325],
    'steps': [8.040939885304397, 18.620454054986762],
    'jerk_auto': [202.14300725778924, 35.148804019774296],
    'step_count': [10.972891950490821, 16.380615326908973],
    'log_energy': [57.67806502240965, 43.5115141491905],
    'covariance': [44.4438237732618, 12.932594067855076],
    'log_energy_ratio': [43.869075905081964, 22.12416232697863],
    'zero_crossing_std': [159.64183077769096, 28.980474108337734],
    'zero_crossing_avg': [50.14393512016058, 35.18255260603251],
    'axis_mean': [117.5897743871761, 26.153304870045442],
    'altim_std': [0.003743016795787982, 0.04998379208940402],
    'kurtosis': [105.52495686705635, 61.87037974939132],
    'sleep_coefficient': [8.607589305861987, 3.9187832962818043],
    'wrist_temperatures': [30.80471706762349, 2.899243085851182],
    'hrv_shannon_entropy_rr': [3.2953304951132885, 0.464777365409023],
    'hrv_shannon_entropy_rrd': [2.9810634995051184, 0.48817021363471297],
    'hrv_percentage_of_nn_30': [0.35201734277287905, 0.1902607735053669],
    'ceda_magnitude_real_micro_siemens': [4.743574484899, 12.913499081063],
    'ceda_slope_real_micro_siemens': [3.2444288158784063, 1.821951365148186],
    'rmssd_percentile_0595': [34.1895276671302, 23.783359512525266],
    'sdnn_percentile_0595': [45.335573726241854, 24.38601160405501],
    'msa_probability': [48.1172590194961, 14.292898676874556],
    'hrv_percent_good': [0.2714810920080538, 0.2762414786979745],
    'hrv_rr_80th_percentile_mean': [828.8905850347666, 108.40428688789727],
    'hrv_rr_20th_percentile_mean': [734.5942838543058, 88.41269789220864],
    'hrv_rr_median': [780.925540250376, 94.86837708152842],
    'hrv_rr_mean': [785.7749142736874, 90.44585649648346],
    'hr_at_rest_mean': [82.86923994290905, 10.867752252500274],
    'skin_temperature_magnitude': [31.469650973107296, 1.7002512792231534],
    'skin_temperature_slope': [0.2655571148317653, 17.266512596820043],
}


# Ordered names of the model's feature channels. Position i names column i of
# the (time, feature) sample arrays: visualize_features uses this list as the
# heatmap's y-axis labels, and plot_all_features selects the column for each
# feature by its position here. The length (26) should match
# NUMBER_OF_FEATURE.
FEATURES_TO_INCLUDE = [
    'HR',
    'eda_level_real',
    'leads_contact_counts',
    'steps',
    'jerk_auto',
    'log_energy',
    'covariance',
    'log_energy_ratio',
    'zero_crossing_std',
    'zero_crossing_avg',
    'axis_mean',
    'altim_std',
    'kurtosis',
    'sleep_coefficient',
    'wrist_temperatures',
    'hrv_shannon_entropy_rr',
    'hrv_shannon_entropy_rrd',
    'ceda_slope_real_micro_siemens',
    'rmssd_percentile_0595',
    'sdnn_percentile_0595',
    'hrv_percent_good',
    'hrv_rr_80th_percentile_mean',
    'hrv_rr_20th_percentile_mean',
    'hrv_rr_median',
    'hr_at_rest_mean',
    'skin_temperature_slope',
]


# Metadata database that load_user_data queries for a user's DataKey.
DATABASE_PATH = '/span/nonprod/consumer-health-research:fitbit-prod-research'
# DataKey type used when looking up per-user records.
DATA_KEY_TYPE = 'TIER2_SITE_DATA'
# Data Storage tables that plot_one_sample loads for each user (via
# load_user_data). Of these, only 'f_user_activity_daily' is read by the
# visible code.
DATA_STORAGE_KEYS_TO_LOAD = [
    'steps_compact',
    'steps',
    'sleep_score',
    'daily_typical_sleep_periods',
    'daily_resting_heart_rate',
    'daily_time_in_sleep_stages',
    'd_user',
    'daily_sleep',
    'f_user_activity_daily',
]


## PLOTTING IMPORTS AND UTILS
def visualize_features(array_feature, title: str, feature_names=None):
  """Shows a (time x feature) array as a heatmap with one row per feature.

  The array is transposed so that features run down the y-axis and time runs
  along the x-axis. The x-axis is labelled in minutes, with ticks and faint
  grid lines every 60 steps, plus one faint grid line per feature row.

  Args:
    array_feature: 2-D array indexed as [time_step, feature]. Boolean masks
      and float signals are both accepted.
    title: Title drawn above the heatmap.
    feature_names: Optional y-axis labels, one per feature column. Defaults to
      FEATURES_TO_INCLUDE, which keeps the previous behavior.
  """
  if feature_names is None:
    feature_names = FEATURES_TO_INCLUDE
  group = array_feature
  n_steps, n_features = group.shape[0], group.shape[1]
  hour_ticks = np.arange(0, n_steps, 60)

  plt.figure(figsize=(20, 7))
  ax = plt.subplot2grid((1, 12), (0, 0), colspan=12)
  sns.heatmap(
      group.T,
      cmap='Reds',
      cbar=True,
      linewidths=0.0,
      linecolor='black',
      alpha=0.8,
      ax=ax,
      yticklabels=True,
  )

  ax.set_xticks(hour_ticks)
  ax.set_xticklabels(hour_ticks)
  for tick in ax.get_xticklabels():
    tick.set_fontname('Ubuntu')
    tick.set_style('italic')
  ax.tick_params(axis='x', labelsize=10.5, labelrotation=45)
  ax.set_xlabel('Minutes')

  ax.set_yticklabels(feature_names)
  for tick in ax.get_yticklabels():
    tick.set_fontname('Ubuntu')
  ax.tick_params(axis='y', labelsize=10.5, labelrotation=0)
  ax.set_ylabel('Feature', fontname='Ubuntu', fontsize=14)
  ax.set_title(title, fontname='Ubuntu', fontsize=16)

  # Solid outer frame around the heatmap.
  ax.axhline(y=0, color='k', linewidth=1, alpha=1)
  ax.axhline(y=n_features, color='k', alpha=1, linewidth=1)
  ax.axvline(x=0, color='k', linewidth=1, alpha=1)
  ax.axvline(x=n_steps, color='k', alpha=1, linewidth=1)

  # Faint grid: one vertical line per hour, one horizontal line per feature.
  for x in hour_ticks:
    ax.axvline(x=x, color='k', alpha=0.4, linewidth=1)
  for y in np.arange(0, n_features, 1):
    ax.axhline(y=y, color='k', alpha=0.4, linewidth=1)

  # Lay out only after every label and title exists, so none of them gets
  # clipped (the y-label used to be added after tight_layout()).
  plt.tight_layout()
  plt.show()


def load_user_data(
    data_key_type: str,
    user_id: str,
    data_storage_keys_to_load: list[str],
) -> data_intermediates.DataKeyAndKeyValues:
  """Looks up one user's DataKey and loads the requested Data Storage tables.

  Each user is one DataKey (type + session_id) in DATABASE_PATH. The Data
  Storage entries attached to that key point at the capacitor files with the
  imported raw data, which LoadDataDoFn then reads.

  Args:
    data_key_type: DataKey type to query, e.g. DATA_KEY_TYPE.
    user_id: User identifier, used as the DataKey's session_id.
    data_storage_keys_to_load: Names of the Data Storage tables to load. All
      tables that could be loaded end up in the result's `data` mapping.

  Returns:
    The first element yielded by LoadDataDoFn.process. It is described as a
    DataKeyAndKeyValuesWithData (exposing `.data` and
    `.data_key_and_key_values`), even though the return annotation names
    DataKeyAndKeyValues.

  Raises:
    IndexError: If LoadDataDoFn.process yields nothing.
  """
  user_key = data_key_pb2.DataKey(type=data_key_type, session_id=user_id)
  key_and_values = metadata_database_helpers.get_database_data_for_data_key(
      DATABASE_PATH, user_key
  )
  load_fn = data_loading.LoadDataDoFn(
      data_storage_keys_to_load, data_loader.get_data_loader()
  )
  loaded = list(load_fn.process(key_and_values))
  return loaded[0]


def describe_data_key_key_value(dkkv):
  """Prints a DataKeyAndKeyValues' DataKey and its storage/process dicts."""
  key = dkkv.data_key
  lines = (
      f'DataKey: {key}',
      f'\ttype: {key.type}',
      f'\tsession_id (i.e., user): {key.session_id}',
      f'\tdata_storage_dict: {dkkv.data_storage_dict}',
      f'\tprocess_data_dict: {dkkv.process_data_dict}',
  )
  for line in lines:
    print(line)


def describe_loaded_data(
    dkkvwd: data_intermediates.DataKeyAndKeyValues,
) -> None:
  """Prints a summary of a loaded DataKeyAndKeyValuesWithData.

  Shows the DataKey's type and session_id, the element count of every loaded
  table in `dkkvwd.data`, and the sizes of the pipeline-metadata,
  data-storage and process-data dicts of the wrapped DataKeyAndKeyValues.
  """
  key_values = dkkvwd.data_key_and_key_values
  data_key = key_values.data_key

  print('Loaded data for DataKey:')
  print('\ttype:                   ' f' {data_key.type}')
  print('\tsession_id (i.e., user):' f' {data_key.session_id}\t')

  print('Loaded data:')
  for table_name, rows in dkkvwd.data.items():
    print(f'\t{table_name}: {len(rows)} elements')

  print('Metadata table data:')
  metadata_tables = (
      ('\tPipeline Metadata:', key_values.pipeline_metadata_dict),
      ('\tData Storage:     ', key_values.data_storage_dict),
      ('\tProcess Data:     ', key_values.process_data_dict),
  )
  for label, table in metadata_tables:
    print(f'{label} {len(table)} items')


def convert_days_to_datetime(days, start_date=datetime.datetime(1970, 1, 1)):
  """Shifts `start_date` forward by a (possibly fractional) number of days.

  Args:
    days: Day offset to apply. Anything `datetime.timedelta(days=...)`
      accepts works, including floats and negative values.
    start_date: Datetime to shift. Defaults to the naive Unix epoch,
      1970-01-01 00:00.

  Returns:
    `start_date + timedelta(days=days)`.
  """
  offset = datetime.timedelta(days=days)
  shifted = start_date + offset
  return shifted

In [None]:
# @title Plotting


def calculate_metrics(y_true, y_pred, selected_mask, mape_epsilon=1e-3):
  """Calculates MSE, MAE, MAPE and the sample count over selected positions.

  Args:
    y_true: Ground-truth values (array-like).
    y_pred: Predicted values; must have the same shape as `y_true`.
    selected_mask: Boolean mask, indexable into `y_true`/`y_pred`, that
      selects which positions enter the metrics.
    mape_epsilon: Selected positions with |y_true| <= this value are left out
      of MAPE only, to avoid dividing by (near) zero. MSE and MAE still use
      them.

  Returns:
    A tuple (mse, mae, mape, number_of_sample), with MAPE in percent. A metric
    is NaN when no position qualifies for it: an empty selection gives NaN for
    all three, and a selection whose |y_true| are all <= mape_epsilon gives
    NaN MAPE. NaN is returned directly instead of averaging an empty array,
    which would warn.

  Raises:
    ValueError: If `y_true` and `y_pred` have different shapes.
  """
  y_true = np.asarray(y_true)
  y_pred = np.asarray(y_pred)
  if y_true.shape != y_pred.shape:
    raise ValueError(
        f'y_true shape {y_true.shape} != y_pred shape {y_pred.shape}'
    )
  y_true = y_true[selected_mask]
  y_pred = y_pred[selected_mask]
  number_of_sample = y_true.shape[0]
  if number_of_sample == 0:
    return np.nan, np.nan, np.nan, 0

  error = y_true - y_pred
  mse = np.mean(np.square(error))
  mae = np.mean(np.abs(error))

  non_zero_mask = np.abs(y_true) > mape_epsilon
  if np.any(non_zero_mask):
    mape = np.mean(np.abs(error[non_zero_mask] / y_true[non_zero_mask])) * 100
  else:
    mape = np.nan

  return mse, mae, mape, number_of_sample


def _true_runs(mask):
  """Returns (start, stop) index pairs, stop exclusive, for each truthy run."""
  padded = np.concatenate(([False], np.asarray(mask, dtype=bool), [False]))
  edges = np.flatnonzero(np.diff(padded.astype(np.int8)))
  return list(zip(edges[0::2].tolist(), edges[1::2].tolist()))


def plot_all_features(
    features_to_include,
    reconstruction_merged,
    original,
    normalization_parameters,
    subset_mask,
    figure_title: str,
):
  """Plots reconstruction vs. ground truth for every feature in a 4-column grid.

  Each feature gets one subplot, and column i of the arrays belongs to
  features_to_include[i]. Both series are de-normalized with
  `x * params[1] + params[0]` before plotting and scoring. Each subplot's
  title shows MSE, MAE, MAPE and N, computed over the positions selected by
  `subset_mask`, and those positions are shaded red. Only the first subplot
  shows a legend. The grid keeps its 8x4 minimum and grows by rows when there
  are more than 32 features; unused cells are hidden.

  Args:
    features_to_include: Ordered list of feature names, one per array column.
    reconstruction_merged: (time, feature) numpy array of reconstructed data,
      in normalized units.
    original: (time, feature) numpy array of original data, in normalized
      units.
    normalization_parameters: Dict mapping feature name to an [offset, scale]
      pair used for de-normalization.
    subset_mask: (time, feature) boolean array of positions that are NOT
      missing in the original data but are hidden by the additional masking.
    figure_title: Title of the whole figure.
  """
  n_cols = 4
  # Ceiling division keeps every feature on the grid; 8 rows is the minimum.
  n_rows = max(8, -(-len(features_to_include) // n_cols))
  _, axes = plt.subplots(n_rows, n_cols, figsize=(24, 3 * n_rows))
  axes = axes.flatten()

  for feature_index, feature in enumerate(features_to_include):
    ax = axes[feature_index]

    ax.set_facecolor('xkcd:white')
    for side in ('top', 'bottom', 'left', 'right'):
      ax.spines[side].set_color('black')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    for axis in (ax.xaxis, ax.yaxis):
      axis.set_tick_params(which='major', size=10, width=2, direction='in')
      axis.set_tick_params(which='minor', size=7, width=2, direction='in')

    params = normalization_parameters[feature]
    offset, scale = params[0], params[1]
    reconstructed_data = reconstruction_merged[:, feature_index] * scale + offset
    original_data = original[:, feature_index] * scale + offset

    ax.plot(reconstructed_data, label='Reconstruction from LSM', alpha=0.6)
    ax.plot(original_data, label='Original Data (Ground-Truth)', alpha=0.6)
    ax.set_xlabel('Time (minutes from midnight)', labelpad=10)
    ax.set_ylabel(feature, labelpad=10)

    selected_mask = subset_mask[:, feature_index]
    mse, mae, mape, n_sample = calculate_metrics(
        original_data, reconstructed_data, selected_mask
    )
    ax.set_title(
        f'MSE: {mse:.2f}, MAE: {mae:.2f}, MAPE: {mape:.2f}%, N: {n_sample}',
        fontsize=12,
    )
    # One span per contiguous run of scored steps, covering the same
    # [j - 0.5, j + 0.5] cells as one span per step would.
    for start, stop in _true_runs(selected_mask):
      ax.axvspan(start - 0.5, stop - 0.5, facecolor='red', alpha=0.3)

    if feature_index == 0:
      ax.legend()

  # Hide grid cells that have no feature.
  for ax in axes[len(features_to_include):]:
    ax.set_axis_off()

  plt.suptitle(figure_title, fontsize=14, y=0.98)  # Close to the subplots.
  plt.tight_layout(rect=[0, 0, 1, 0.985])  # Small gap below the suptitle.
  plt.show()


def _load_npz_arrays(file_path: str, names) -> dict:
  """Reads the named members of an .npz file (any gfile path) into memory.

  `np.load` on an .npz returns a lazy NpzFile that reads members from the
  underlying handle on access. Every needed member is therefore materialized
  here, while the gfile handle is still open.
  """
  with gfile.Open(file_path, 'rb') as f:
    npz = np.load(f, allow_pickle=True)
    return {name: npz[name] for name in names}


def _find_activity_row(rows, date_str: str):
  """Returns the first activity row whose activity_dt falls on date_str.

  `activity_dt` is interpreted as days since 1970-01-01 via
  convert_days_to_datetime. `date_str` is formatted as '%Y-%m-%d'. Returns
  None when no row matches.
  """
  for row in rows:
    row_date = convert_days_to_datetime(row.activity_dt).strftime('%Y-%m-%d')
    if row_date == date_str:
      return row
  return None


def _expand_token_mask(token_mask, patch_size, shape):
  """Upsamples a per-patch token mask to a per-element boolean mask.

  Token (i, j) covers rows [i*ph, (i+1)*ph) and columns [j*pw, (j+1)*pw).
  A token equal to 1 marks its whole patch. Elements outside the patch grid
  stay False.
  """
  mask = np.zeros(shape, dtype=bool)
  tiles = np.repeat(
      np.repeat(token_mask == 1, patch_size[0], axis=0), patch_size[1], axis=1
  )
  mask[: tiles.shape[0], : tiles.shape[1]] = tiles
  return mask


def _unpatchify(patches, num_patches, patch_size, shape):
  """Reassembles row-major (n_patches, ph, pw) patches into a 2-D array."""
  n_rows, n_cols = num_patches
  ph, pw = patch_size
  image = np.zeros(shape)
  grid = (
      patches[: n_rows * n_cols]
      .reshape(n_rows, n_cols, ph, pw)
      .transpose(0, 2, 1, 3)
      .reshape(n_rows * ph, n_cols * pw)
  )
  image[: n_rows * ph, : n_cols * pw] = grid
  return image


def plot_one_sample(
    file_path: str,
    plotOn: bool,
    model_name: str,
    task_name: str,
    patch_size=(10, 1),
):
  """Loads one saved eval sample, checks the user's daily activity, and plots.

  The .npz at `file_path` must contain 'input_signal', 'imputation_mask',
  'token_mask', 'eval_plot_logits', 'user_id' and 'key'. The function returns
  early, after printing a message, when the user has no
  'f_user_activity_daily' table or no row for the sample's date (or that
  row's step_cnt is -1).

  Args:
    file_path: Path of the saved eval sample (.npz).
    plotOn: If True, print the sample identity and draw the mask, signal and
      reconstruction heatmaps plus the per-feature comparison grid.
    model_name: Model name shown in the comparison figure title.
    task_name: Task/config name shown in the comparison figure title.
    patch_size: (time, feature) size of one model token patch.
  """
  shape = (WINDOW_LENGTH, NUMBER_OF_FEATURE)
  arrays = _load_npz_arrays(
      file_path,
      (
          'input_signal',
          'imputation_mask',
          'token_mask',
          'eval_plot_logits',
          'user_id',
          'key',
      ),
  )

  signal = arrays['input_signal'][0, 0, :, :, 0]
  original = np.reshape(signal, shape).copy()
  model_input = np.reshape(signal, shape).copy()
  # Force bool so indexing acts as a mask; an int 0/1 array would be treated
  # as row indices instead.
  imputation_mask = np.reshape(
      arrays['imputation_mask'][0, 0, :, :, 0], shape
  ).astype(bool)
  model_input[imputation_mask] = np.nan

  user_id = arrays['user_id'][0, 0].decode('utf-8')
  key = arrays['key'][0, 0].decode('utf-8')
  print('**************************************')
  print('Getting data for: ', user_id, ' at: ', key)
  print('**************************************')
  dkkvwd = load_user_data(
      data_key_type=DATA_KEY_TYPE,
      user_id=user_id,
      data_storage_keys_to_load=DATA_STORAGE_KEYS_TO_LOAD,
  )

  table = 'f_user_activity_daily'
  if table not in dkkvwd.data:
    print('Table not in dkkvwd.data.keys()')
    return
  # The first 10 characters of `key` are compared against a '%Y-%m-%d' date.
  activity = _find_activity_row(dkkvwd.data[table], key[:10])
  if activity is None or activity.step_cnt == -1:
    print('Step Count is -1')
    return

  num_patches = (
      WINDOW_LENGTH // patch_size[0],
      NUMBER_OF_FEATURE // patch_size[1],
  )
  token_mask = arrays['token_mask'][:, 0, :].reshape(num_patches)
  additional_mask = _expand_token_mask(token_mask, patch_size, shape)
  model_input[additional_mask] = np.nan

  logits = arrays['eval_plot_logits']
  patches = logits[0, 0, :, :].reshape(
      logits.shape[2], patch_size[0], patch_size[1]
  )
  reconstruction = _unpatchify(patches, num_patches, patch_size, shape)

  # Observed inputs stay; every masked position takes the model's output.
  reconstruction_merged = model_input.copy()
  reconstruction_merged[imputation_mask] = reconstruction[imputation_mask]
  reconstruction_merged[additional_mask] = reconstruction[additional_mask]

  if not plotOn:
    return
  print('********* USER ID *********')
  print('ID: ', user_id)
  print('Key: ', key)
  # visualize_features shows each figure itself.
  for array, title in (
      (imputation_mask, 'Missing Mask'),
      (original, 'Original Signal'),
      (additional_mask, 'Additional Mask'),
      (model_input, 'Input with Additional Mask'),
      (reconstruction, 'RECONSTRUCTION (ALL)'),
      (reconstruction_merged, 'RECONSTRUCTION (MERGED)'),
  ):
    visualize_features(array, title)
  # Score only positions that were observed but hidden by additional masking.
  subset_mask = (~imputation_mask) & additional_mask
  plot_all_features(
      FEATURES_TO_INCLUDE,
      reconstruction_merged,
      original,
      NORMALIZATION_PARAMETERS,
      subset_mask,
      f'Model: {model_name}, Task: {task_name}',
  )

In [None]:
# @title Find the right file


def get_config_file_paths(
    root_path: str, dataset_name: str, config_type: str
) -> List[str]:
  """Returns files under `root_path/dataset_name` whose names start with config_type.

  Args:
    root_path: Root directory of the experiment dumps.
    dataset_name: Sub-directory of `root_path` to search; in this notebook it
      is an EXPERIMENT_ID_DICT value.
    config_type: File-name prefix to match, e.g. 'imputation_0.04167'.

  Returns:
    Matching paths, sorted. Glob order is not guaranteed, so sorting makes a
    slice of the result (e.g. `[:LIMIT_SAMPLE]`) pick the same files on every
    run.
  """
  return sorted(gfile.Glob(f'{root_path}/{dataset_name}/{config_type}*'))

In [None]:
# @title Experiment ID Dict
# Maps each model name offered by the MODEL_NAME form field to its experiment
# dump directory under ROOT_FILE_PATH. The plotting cell uses (10, 2) patches
# for names containing '10_by_2' and (10, 1) for all other names.
EXPERIMENT_ID_DICT = {
    'random_10_by_1': (
        '25_3_20_nai_1rand8_200kstep_lsm_v2_missing_balanced_20250301_valid_dataset_0.3_only_xid_155551157_wid_1_20250325140044'
    ),
    'random_10_by_2': (
        '25_3_20_nai_1rand8_10by2_200kstep_lsm_v2_missing_balanced_20250301_valid_dataset_0.3_only_xid_155551076_wid_1_20250325140733'
    ),
    'random_10_by_1_inherited': (
        '25_3_20_inh_1rand8_200kstep_lsm_v2_missing_balanced_20250301_valid_dataset_0.3_only_xid_155551505_wid_1_20250325163411'
    ),
    'mix_10_by_1': (
        '25_3_20_nai_1rand8_1fbar4_sharedemb384_200kstep_lsm_v2_missing_balanced_20250301_valid_dataset_0.3_only_xid_155553088_wid_1_20250325135212'
    ),
}

In [None]:
# @title Constants Set up
# Directory holding one sub-directory of eval dumps per EXPERIMENT_ID_DICT
# entry.
ROOT_FILE_PATH = '/namespace/fitbit-medical-sandboxes/jg/partner/encrypted/chr-ards-fitbit-prod-research/deid/exp/dmcduff/ttl=52w/lsm_v2/exp_dumps'
# Maximum number of sample files plotted per run of the plotting cell.
LIMIT_SAMPLE = 5
# Time steps per sample window. The plots label this axis in minutes, so 1440
# steps is one day.
WINDOW_LENGTH = 1440
# Feature channels per sample; should equal len(FEATURES_TO_INCLUDE).
NUMBER_OF_FEATURE = 26
# NOTE(review): not referenced by the visible code; the plotting cell passes
# patch sizes to plot_one_sample explicitly.
PATCH_SIZE = [10, 1]

In [None]:
SELECTED_CONFIG = 'imputation_0.04167' # @param ['forecast_0.00695', 'forecast_0.02084', 'forecast_0.04167', 'forecast_0.125', 'forecast_0.25', 'forecast_0.5', 'imputation_0.00695', 'imputation_0.02084', 'imputation_0.04167', 'imputation_0.125', 'imputation_0.25', 'imputation_0.5', 'random_imputation_0.2', 'random_imputation_0.5', 'random_imputation_0.8']
MODEL_NAME = 'random_10_by_1_inherited' # @param ['mix_10_by_1', 'random_10_by_1', 'random_10_by_2', 'random_10_by_1_inherited']
# Plot up to LIMIT_SAMPLE saved eval samples of the chosen model and config.
selected_paths = get_config_file_paths(
    ROOT_FILE_PATH, EXPERIMENT_ID_DICT[MODEL_NAME], SELECTED_CONFIG
)
# '10_by_2' models are plotted with (10, 2) patches; all others use (10, 1).
sample_patch_size = (10, 2) if '10_by_2' in MODEL_NAME else (10, 1)
for file_path in selected_paths[:LIMIT_SAMPLE]:
  plot_one_sample(
      file_path,
      plotOn=True,
      model_name=MODEL_NAME,
      task_name=SELECTED_CONFIG,
      patch_size=sample_patch_size,
  )

In [None]:
# List the distinct config prefixes available for the selected model: the
# part of each file name before '_eval' (presumably the values SELECTED_CONFIG
# can take). The sorted list is the cell's displayed output.
selected_path = gfile.Glob(
    f'{ROOT_FILE_PATH}/{EXPERIMENT_ID_DICT[MODEL_NAME]}/*'
)
sorted({path.split('/')[-1].split('_eval')[0] for path in selected_path})