1. Please request access via [AoD](https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-fitbit-prod-research-deid-eng-team:r&reason=%22b%2F285178698%22) before running this Colab.
2. Use any Borg runtime kernel whose name includes `Fitbit Prod Research`.

In [None]:
import tensorflow_datasets as tfds
from google3.pyglib import gfile
import json
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from tqdm.notebook import tqdm
from scipy.signal import spectrogram
from collections import defaultdict
from statsmodels.tsa.stattools import acf
import pandas as pd
import gc

In [None]:
# Root directory that contains one TFDS data_dir per dataset variant; the
# cells below load `f'{DEFAULT_DATA_ROOT}/{dataset_name}'`. Editable via the
# Colab form.
DEFAULT_DATA_ROOT = '/namespace/fitbit-medical-sandboxes/jg/partner/encrypted/chr-ards-fitbit-prod-research/deid/exp/dmcduff/ttl=52w/lsm_v2/datasets/tfds'  # @param {type:"string"}

In [None]:
#@title Check the paths

# Entries directly under the data root; the bare name on the last line makes
# the notebook display the list as the cell output.
latest_dataset_paths = gfile.Glob(DEFAULT_DATA_ROOT + '/*')
latest_dataset_paths

In [None]:
# @title Utils

# For the V2 dataset, these are the modality labels, in channel (row) order.
labels = [
    'HR',
    'eda_level_real',
    'leads_contact_counts',
    'steps',
    'jerk_auto',
    'step_count',
    'log_energy',
    'covariance',
    'log_energy_ratio',
    'zero_crossing_std',
    'zero_crossing_avg',
    'axis_mean',
    'altim_std',
    'kurtosis',
    'sleep_coefficient',
    'wrist_temperatures',
    'hrv_shannon_entropy_rr',
    'hrv_shannon_entropy_rrd',
    'hrv_percentage_of_nn_30',
    'ceda_magnitude_real_micro_siemens',
    'ceda_slope_real_micro_siemens',
    'rmssd_percentile_0595',
    'sdnn_percentile_0595',
    'msa_probability',
    'hrv_percent_good',
    'hrv_rr_80th_percentile_mean',
    'hrv_rr_20th_percentile_mean',
    'hrv_rr_median',
    'hrv_rr_mean',
    'hr_at_rest_mean',
    'skin_temperature_magnitude',
    'skin_temperature_slope',
]

# These are the missingness groups for the v2 dataset.
# Each key is one (arbitrarily chosen) member of its group and is used as the
# representative channel when collecting per-group statistics; the value holds
# the group's display `name` and its `members` (the key is always a member).
# Some members (e.g. 'is_on_wrist', 'eda_level_imaginary') are not in
# `labels` above.
missgroup_dict = {
    'HR': {'name': 'HR', 'members': ['HR']},
    'steps': {'name': 'steps', 'members': ['steps']},
    'wrist_temperatures': {
        'name': 'wrist_temperatures',
        'members': ['wrist_temperatures'],
    },
    'sleep_coefficient': {
        'name': 'sleep_coefficient',
        'members': ['sleep_coefficient', 'is_on_wrist'],
    },
    'eda_level_real': {
        'name': 'EDA Sensor',
        'members': [
            'eda_level_real',
            'eda_level_imaginary',
            'eda_slope_real',
            'eda_slope_imaginary',
            'leads_contact_counts',
        ],
    },
    'jerk_auto': {
        'name': 'ACC Sensor',
        'members': [
            'jerk_auto',
            'step_count',
            'log_energy',
            'covariance',
            'log_energy_ratio',
            'zero_crossing_std',
            'zero_crossing_avg',
            'axis_mean',
            'altim_std',
            'kurtosis',
        ],
    },
    'hrv_shannon_entropy_rr': {
        'name': 'HRV',
        'members': [
            'hrv_shannon_entropy_rr',
            'hrv_shannon_entropy_rrd',
            'hrv_percentage_of_nn_30',
            'rmssd_percentile_0595',
            'sdnn_percentile_0595',
            'hrv_rr_80th_percentile_mean',
            'hrv_rr_20th_percentile_mean',
            'hrv_rr_median',
            'hrv_rr_mean',
        ],
    },
    'ceda_magnitude_real_micro_siemens': {
        'name': 'CEDA Magnitude',
        'members': [
            'ceda_magnitude_real_micro_siemens',
            'hrv_percent_good',
            'skin_temperature_magnitude',
        ],
    },
    'ceda_slope_real_micro_siemens': {
        'name': 'ceda_slope_real_micro_siemens',
        # The key must be a member of its own group; 'msa_probability' has
        # its own group below.
        'members': [
            'ceda_slope_real_micro_siemens',
        ],
    },
    'msa_probability': {
        'name': 'msa_probability',
        'members': ['msa_probability'],
    },
    'hr_at_rest_mean': {
        'name': 'hr_at_rest_mean',
        'members': ['hr_at_rest_mean'],
    },
    'skin_temperature_slope': {
        'name': 'skin_temperature_slope',
        'members': ['skin_temperature_slope'],
    },
}


def check_dataset_length(root_data_path: str, dataset_name: str) -> None:
  """Prints a dataset's name and the number of paths matched under it.

  The printed count is the number of paths matching
  `{root_data_path}/{dataset_name}/lsm/*/*`.
  NOTE(review): under the usual TFDS layout this is presumably
  `lsm/<version>/<file>`, i.e. a count of shard/metadata files rather than of
  individual samples -- confirm before reading it as a sample count.

  Args:
    root_data_path: The root directory where the dataset is located.
    dataset_name: The name of the dataset to check.

  Returns:
    None. The result is only printed.
  """
  print('Dataset Name:', dataset_name)
  matched_paths = gfile.Glob(f'{root_data_path}/{dataset_name}/lsm/*/*')
  print('Number of Data Sample:', len(matched_paths))


def inspect_dataset(
    root_data_path: str, dataset_name: str, num_samples: int = 5
):
  """Loads a TFDS dataset and plots the mask and signal of its first samples.

  The 'lsm' builder is loaded from `{root_data_path}/{dataset_name}` (train
  split, unshuffled). Its length is printed, and then each of the first
  `num_samples` examples is shown as two heatmaps via `visualize`: the
  missingness mask (grey scale) and the input signal.

  Args:
    root_data_path: The root directory where the dataset is located.
    dataset_name: The name of the dataset to inspect.
    num_samples: How many leading examples to visualize (default 5).

  Returns:
    None. This function prints information and displays visualizations.
  """
  print('Dataset Name:', dataset_name)
  data = tfds.load(
      'lsm',
      data_dir=f'{root_data_path}/{dataset_name}',
      split='train',
      shuffle_files=False,
  )
  print('Number of Data Sample:', len(data))
  for example in data.take(num_samples):
    # Tensors are stored serialized; transpose so rows are channels (matching
    # the `labels` used as y tick labels by `visualize`).
    mask = tf.io.parse_tensor(example['mask'], out_type=tf.bool).numpy().T
    signal = (
        tf.io.parse_tensor(example['input_signal'], out_type=tf.double)
        .numpy()
        .T
    )

    visualize(mask, cmap='Greys')
    visualize(signal)

    print('--------------------------')


def visualize(
    sample_signal_input,
    figsize=(20, 5),
    title='',
    cmap='cool',
    dim=None,
    cbar=True,
    disabletext=False,
):
  """Visualizes a (channels x time) array as a heatmap with time-of-day ticks.

  Rows are labelled with the module-level `labels` list. The x-axis is
  labelled every 4 hours, assuming one column per minute. Thin vertical lines
  mark every hour and thick ones every 24 hours; horizontal lines separate
  rows.

  Args:
    sample_signal_input: 2-D NumPy array; rows are channels, columns are time
      steps (minutes).
    figsize: Tuple specifying the width and height of the figure (default: (20,
      5)).
    title: The title of the plot (default: '').
    cmap: The colormap to use for the heatmap (default: 'cool').
    dim: Optional row index. If given, only that row is plotted and it is
      labelled with `labels[dim]` when such a label exists (default: None,
      displays all rows).
    cbar: Whether to display the colorbar (default: True).
    disabletext: Whether to disable the plot title (default: False).

  Returns:
    None. This function displays the heatmap visualization.
  """
  if dim is not None:
    sample_signal_input = sample_signal_input[[dim], :]
    # Keep the plotted channel identifiable instead of dropping all labels.
    try:
      labels_temp = [labels[dim]]
    except IndexError:
      labels_temp = []
  else:
    labels_temp = labels

  plt.figure(figsize=figsize)
  ax1 = plt.subplot2grid((1, 12), (0, 0), colspan=12)
  ax1 = sns.heatmap(
      sample_signal_input,
      cmap=cmap,
      cbar=cbar,
      linewidths=0.0,
      linecolor='black',
      alpha=0.8,
      ax=ax1,
      yticklabels=labels_temp,
  )

  for tick in ax1.get_xticklabels() + ax1.get_yticklabels():
    tick.set_fontname('Ubuntu')
  ax1.tick_params(axis='both', labelsize=10.5)

  # Set x-axis ticks every 4 hours (columns are minutes).
  tick_interval = 4 * 60
  xticks = np.arange(0, sample_signal_input.shape[1], tick_interval)
  xtick_labels = [minutes_to_time(x) for x in xticks]
  ax1.set_xticks(xticks)
  ax1.set_xticklabels(xtick_labels, rotation=45, ha='right')

  if not disabletext:
    plt.title(title)

  # Vertical guides: thin every hour, thick at each day boundary.
  for i in np.arange(0, sample_signal_input.shape[1], 60):
    if i % (60 * 24) == 0:
      tempwidth, tempalpha = 2, 1
    else:
      tempwidth, tempalpha = 1, 0.4
    ax1.axvline(x=i, color='k', alpha=tempalpha, linewidth=tempwidth)

  # Horizontal separators between rows (channels).
  for i in np.arange(0, sample_signal_input.shape[0] + 1, 1):
    ax1.axhline(y=i, color='k', alpha=0.4, linewidth=1)

  # Single layout pass after the title is set, so it is accounted for.
  plt.tight_layout()
  plt.show()


def consecutive_ones_lengths(mask):
  """Returns the length of every maximal run of 1s in a 1-D binary mask.

  The mask is padded with a 0 at both ends before taking its first
  difference. Each run of 1s then shows up as a +1 step where it begins and a
  -1 step just past where it ends, and the distance between paired steps is
  the run length. In this notebook the mask is True where a value is missing,
  so the runs are missingness gaps.

  Args:
    mask: A 1-D NumPy array of 0/1 (or boolean) values.

  Returns:
    A 1-D integer NumPy array with one entry per run of 1s, in order of
    appearance (empty when the mask contains no 1s).
  """
  steps = np.diff(mask, prepend=0, append=0)
  rising_edges = np.nonzero(steps == 1)[0]
  falling_edges = np.nonzero(steps == -1)[0]
  return falling_edges - rising_edges


def minutes_to_time(x):
  """Formats a number of minutes as a zero-padded "HH:MM" string.

  Hours are not wrapped at 24 (e.g. 1500 -> "25:00"). Both parts are passed
  through `int()`, so fractional minutes are dropped.

  Args:
    x: Number of minutes (Python or NumPy int/float).

  Returns:
    The duration formatted as "HH:MM".
  """
  whole_hours, leftover_minutes = divmod(x, 60)
  return '{:02d}:{:02d}'.format(int(whole_hours), int(leftover_minutes))


class StopExecution(Exception):
  """Exception that quietly halts the rest of a notebook cell.

  Raising it ends the current cell. Because `_render_traceback_` returns no
  lines, the notebook shows no traceback for it, which keeps the output clean
  when a cell is intentionally skipped.
  """

  def _render_traceback_(self):
    """Returns an empty list of traceback lines, suppressing the traceback."""
    no_lines = []
    return no_lines

In [None]:
# @title Check the Data Length

# One pre-training dataset per window length: 5 hours, 1 day, 1 week.
for window_dataset_name in (
    'lsm_v2_pretraining_n_200000_300m',
    'lsm_v2_pretraining_n_200000_1440m',
    'lsm_v2_pretraining_n_200000_10080m',
):
  check_dataset_length(DEFAULT_DATA_ROOT, window_dataset_name)

In [None]:
# @title Inspect the 5-hour Data

inspect_dataset(
    root_data_path=DEFAULT_DATA_ROOT,
    dataset_name='lsm_v2_pretraining_n_200000_300m',
)

In [None]:
# @title Inspect the Daily Data

inspect_dataset(
    root_data_path=DEFAULT_DATA_ROOT,
    dataset_name='lsm_v2_pretraining_n_200000_1440m',
)

In [None]:
#@title Inspect the Weekly Data

inspect_dataset(
    root_data_path=DEFAULT_DATA_ROOT,
    dataset_name='lsm_v2_pretraining_n_200000_10080m',
)

# Collect Aggregate Statistics for Visualization

In [None]:
# Dataset to aggregate statistics over, and how many samples to read from it.
DATASET_NAME = "lsm_v2_pretraining_n_200000_10080m" #@param
NUM_PPL = 1_000 #@param
# Length of one sample in minutes (7 days). The aggregation cell reshapes each
# sample to (channels, 7, TOTAL_TIME_TIMESERIES // 7), so this must match the
# window length of DATASET_NAME (10080m here).
TOTAL_TIME_TIMESERIES = 24*60*7


In [None]:
# Toggles for which aggregate statistics are collected (and later plotted):
# missingness gap lengths, per-day spectrograms, day/week autocorrelations,
# missingness correlations between channels, and mean data / missingness
# heatmaps.
MISSGAP_VIZ = True #@param
SPECTROGRAM_VIZ = True #@param
ACF_VIZ = True #@param
MISSCORR_VIZ = True #@param
FULLDATA_VIZ = True #@param

In [None]:
##### Can be killed early if you want to #####
# NOTE: if interrupted, the accumulators only contain the samples seen so far,
# while the normalisation cell that follows still divides by NUM_PPL.

NUM_CHANNELS = len(labels)
TIME_LEN = TOTAL_TIME_TIMESERIES
DAYS_INWEEK = 7
# Samples per day (1440 for the one-week, one-sample-per-minute windows).
DAY_LEN = TIME_LEN // DAYS_INWEEK

# Per-day spectrogram settings (fs is left at its default of 1, so the
# returned `t` is in samples, i.e. minutes).
SPEC_NPERSEG = 60
SPEC_NOVERLAP = 30
# scipy.signal.spectrogram (one-sided, nfft = nperseg, no padding) returns
# nperseg // 2 + 1 frequency bins and
# (DAY_LEN - noverlap) // (nperseg - noverlap) segments -> (31, 47) for 1440.
SPEC_SHAPE = (
    SPEC_NPERSEG // 2 + 1,
    (DAY_LEN - SPEC_NOVERLAP) // (SPEC_NPERSEG - SPEC_NOVERLAP),
)

# dataset creation
print('Dataset Name:', DATASET_NAME)
data = tfds.load(
    'lsm',
    data_dir=f'{DEFAULT_DATA_ROOT}/{DATASET_NAME}',
    split='train',
    shuffle_files=False,
)

# Representative channels: gap lengths are only collected for group keys.
missgroupkeys_set = set(missgroup_dict.keys())


def _finite_acf(series):
  """Returns (lags, acf) up to half the series length, dropping non-finite ACF values.

  Non-finite values appear e.g. for a zero-variance (constant) series.
  """
  max_lag = series.shape[-1] // 2
  acf_values = acf(series, nlags=max_lag, fft=True)
  finite = np.isfinite(acf_values)
  return np.arange(0, max_lag + 1)[finite], acf_values[finite]


# aggregate statistics
missgaps_all = defaultdict(list)  # group key -> list of per-day gap arrays
Sxx_all = np.zeros((NUM_CHANNELS,) + SPEC_SHAPE)  # summed daily spectrograms
acfday_all = defaultdict(list)  # label -> [lags, acf, lags, acf, ...]
acfweek_all = defaultdict(list)  # label -> [lags, acf, lags, acf, ...]
misscorr_all = np.zeros((len(labels), len(labels)))
validcorrcounts_all = np.zeros((len(labels), len(labels)))
sum_all = np.zeros((NUM_CHANNELS, TIME_LEN))  # sum of observed values
notmissamt_all = np.zeros((NUM_CHANNELS, TIME_LEN))  # count of observed values


for sample in tqdm(data.take(NUM_PPL)):
  # After `.T` rows are channels; split each channel's week into days.
  mask = (
      tf.io.parse_tensor(sample['mask'], out_type=tf.bool)
      .numpy()
      .T.reshape(NUM_CHANNELS, DAYS_INWEEK, DAY_LEN)
  )
  signal = (
      tf.io.parse_tensor(sample['input_signal'], out_type=tf.double)
      .numpy()
      .T.reshape(NUM_CHANNELS, DAYS_INWEEK, DAY_LEN)
  )
  for ch in range(NUM_CHANNELS):
    for day in range(DAYS_INWEEK):
      ### Calculate Missingness Gap Lengths for specific missgroup
      if MISSGAP_VIZ and labels[ch] in missgroupkeys_set:
        missgaps_all[labels[ch]].append(
            consecutive_ones_lengths(mask[ch, day, :])
        )

      signal_channel = signal[ch, day, :]

      ### Calculate Spectrogram for specific channel
      if SPECTROGRAM_VIZ:
        # `f` and `t` are module-level on purpose: the spectrogram plotting
        # cell reuses them as axes.
        f, t, Sxx = spectrogram(
            signal_channel, nperseg=SPEC_NPERSEG, noverlap=SPEC_NOVERLAP
        )
        Sxx_all[ch] += Sxx

      ### Calculate ACF for day for specific channel
      if ACF_VIZ:
        acfday_all[labels[ch]].extend(_finite_acf(signal_channel))

    ### Calculate ACF for week for specific channel
    if ACF_VIZ:
      acfweek_all[labels[ch]].extend(_finite_acf(signal[ch].flatten()))

  ### Calculate Missingness Correlations
  if MISSCORR_VIZ:
    misscorr_temp = (
        pd.DataFrame(mask.reshape(NUM_CHANNELS, -1).T).corr().to_numpy()
    )
    # Channels with a constant mask give NaN correlations; skip those entries
    # and count how many samples contributed to each entry.
    valid_mask = np.isfinite(misscorr_temp)
    misscorr_all[valid_mask] += misscorr_temp[valid_mask]
    validcorrcounts_all[valid_mask] += 1

  ### Calculate Full Data and Missingness Heatmap
  if FULLDATA_VIZ:
    # Missing positions contribute 0 to the sum and nothing to the count.
    signal_temp = np.where(mask, 0, signal)
    sum_all += signal_temp.reshape(NUM_CHANNELS, -1)
    notmissamt_all += ~mask.reshape(NUM_CHANNELS, -1)

In [None]:
if MISSCORR_VIZ:
  # Calculate Missingness Correlations (Continued)
  # Average each entry over the samples where it was finite; entries that
  # were never finite (count 0) become NaN/inf here and are replaced by 0.
  with np.errstate(divide='ignore', invalid='ignore'):
    misscorr_all = np.divide(misscorr_all, validcorrcounts_all)
  misscorr_all[~np.isfinite(misscorr_all)] = 0
  # Ensure diagonals are 1 (self-correlation is always 1)
  np.fill_diagonal(misscorr_all, 1)

if FULLDATA_VIZ:
  ### Calculate Full Data and Missingness Heatmap (Continued)
  # Mean observed value per (channel, minute). Positions missing in every
  # sample give 0/0 = NaN, which the heatmap leaves blank.
  with np.errstate(divide='ignore', invalid='ignore'):
    dataheatmap_all = sum_all / notmissamt_all
  # notmissamt_all counts non-missing samples, so the missing fraction is
  # 1 - count / NUM_PPL (values in [0, 1]).
  # NOTE(review): assumes all NUM_PPL samples were processed (the dataset has
  # at least NUM_PPL samples and the aggregation cell was not interrupted).
  missheatmap_all = 1 - notmissamt_all / NUM_PPL

# Visualizations

In [None]:
# @title Heatmaps of total missingness and total data (while ignoring missing)

# Missingness heatmap first (grey scale), then the mean of observed values.
if FULLDATA_VIZ:
  visualize(missheatmap_all, cmap="Greys")
  visualize(dataheatmap_all)
else:
  print("No full data visualization requested")
  raise StopExecution

In [None]:
# @title Correlation Visualizations

if not MISSCORR_VIZ:
  print("No missingness correlation visualization requested")
  raise StopExecution


def _plot_misscorr(highlight_ones=False):
  """Draws the averaged missingness-correlation matrix as a heatmap.

  Args:
    highlight_ones: If True, cells whose value is exactly 1 (e.g. the
      diagonal, or channel pairs whose masks always agree) are painted black
      over the colormap.
  """
  plt.figure(figsize=(10, 8))
  ax = sns.heatmap(
      misscorr_all,
      xticklabels=labels,
      yticklabels=labels,
      cmap="coolwarm",
      cbar=True,
      square=True,
      linewidths=0.5,
  )
  plt.title("Correlation Matrix Visualization")
  plt.xticks(rotation=45, ha="right")  # Rotate x-axis labels for readability
  plt.yticks(rotation=0)  # Keep y-axis labels horizontal
  plt.tight_layout()
  if highlight_ones:
    # Heatmap cell (row r, column c) spans [c, c+1] x [r, r+1] in data units.
    for r, c in zip(*np.nonzero(misscorr_all == 1)):
      ax.add_patch(plt.Rectangle((c, r), 1, 1, color="black"))
  plt.show()


# Correlation visualization
_plot_misscorr()
# Redo correlation visualization with perfect correlation of 1 to be black
_plot_misscorr(highlight_ones=True)

In [None]:
# @title Autocorrelation Function (ACF) Visualization

if not ACF_VIZ:
  print("No ACF visualization requested")
  raise StopExecution


def _plot_acf_heatmap(acf_pairs, max_lag, title):
  """Plots a 2-D histogram of many ACF curves (lag vs. autocorrelation).

  Args:
    acf_pairs: List alternating [lags, acf_values, lags, acf_values, ...], as
      accumulated by the aggregation cell.
    max_lag: Largest lag in minutes shown on the x-axis.
    title: Plot title.
  """
  x_vals = np.concatenate(acf_pairs[::2])
  y_vals = np.concatenate(acf_pairs[1::2])
  plt.figure(figsize=(10, 3))
  plt.hist2d(x_vals, y_vals, bins=[max_lag, 50], cmap='Blues', cmin=1)
  # Label every 360 minutes (6 hours) for readability
  label_positions = np.arange(0, max_lag + 1, 360)
  plt.xticks(label_positions, [minutes_to_time(m) for m in label_positions])
  # Emphasize gridlines every 24 hours
  for day_start in range(0, max_lag + 1, 24 * 60):
    plt.axvline(x=day_start, color='gray', linestyle='--', linewidth=2)
  plt.xlabel('Time Shifted [HH:MM]')
  plt.ylabel('Autocorrelation')
  plt.title(title)
  plt.colorbar(label='Frequency')  # Colorbar to indicate density
  plt.grid(axis='x')
  plt.tight_layout()
  plt.show()


# ACF heatmap for week data (lags up to half a week)
_plot_acf_heatmap(
    acfweek_all['HR'], TIME_LEN // 2, 'ACF Heatmap for Week-long HR'
)

# ACF heatmap for day data (lags up to half a day)
_plot_acf_heatmap(
    acfday_all['HR'], 24 * 60 // 2, 'ACF Heatmap for Day-long HR'
)

In [None]:
#@title Channel Specific Aggregate Spectrogram

if not SPECTROGRAM_VIZ:
  print("No spectrogram visualization requested")
  raise StopExecution

# One colour scale shared by all channels so panels are comparable.
spec_vmax = np.max(Sxx_all)

# Time-of-day ticks every 3 hours, excluding midnight at either end. The
# spectrogram was computed on 1440-sample days with the default fs=1, so `t`
# is already in minutes and tick positions are plain minutes of the day.
timemax = 24 * 60
tickfreq = 180
hourly_ticks = np.arange(tickfreq, timemax, tickfreq)
hourly_labels = [minutes_to_time(x) for x in hourly_ticks]

for i in range(len(labels)):
  plt.figure(figsize=(10, 2))

  c = plt.pcolormesh(t, f, Sxx_all[i, :], shading='gouraud', vmax=spec_vmax)
  plt.colorbar(c)

  # Hide the primary axes' ticks; time of day goes on a secondary x-axis.
  plt.yticks([])
  plt.xticks([])

  ax = plt.gca()
  ax2 = ax.secondary_xaxis('bottom')
  ax2.set_xticks(hourly_ticks)
  ax2.set_xticklabels(hourly_labels)
  ax2.set_xlabel('Time of Day [HH:MM]')

  plt.title(f"{labels[i]} Spectrogram")
  plt.tight_layout()
  plt.show()

In [None]:
#@title Missingness Gaps for Each Missingness Group

if not MISSGAP_VIZ:
  print("No missingness gap length visualization requested")
  raise StopExecution

for miss_key, miss_group in missgroup_dict.items():
  # Histogram of gap lengths (minutes) pooled over all samples and days for
  # the group's representative (key) channel.
  plt.figure(figsize=(8, 3))
  gaps = np.concatenate(missgaps_all[miss_key])
  _, bins, _ = plt.hist(gaps, bins=1440, color='blue')
  # Roughly 8 ticks across the observed range; the step is clamped to at
  # least 1 minute so a short maximum gap cannot produce a zero step.
  tick_step = max(bins[-1] // 8, 1)
  xticks = np.arange(0, bins[-1] + 1, tick_step)
  plt.xticks(xticks, [minutes_to_time(x) for x in xticks], rotation=45)

  # Add labels and show the plot
  plt.xlabel("Time [HH:MM]")
  plt.ylabel("Frequency")
  plt.title(f"{miss_group['name']} Missingness Gap Length Histograms")
  plt.tight_layout()
  plt.show()

In [None]:
# Zoomed view of the missingness gap-length histograms: the first hour only.

if not MISSGAP_VIZ:
  print("No missingness gap length visualization requested")
  raise StopExecution

ZOOM_MAX_MINUTES = 60
# Fixed 5-minute ticks across the zoom window, independent of the longest gap.
zoom_xticks = np.arange(0, ZOOM_MAX_MINUTES + 1, 5)
zoom_xtick_labels = [minutes_to_time(x) for x in zoom_xticks]

for miss_key, miss_group in missgroup_dict.items():
  plt.figure(figsize=(8, 3))
  gaps = np.concatenate(missgaps_all[miss_key])
  plt.hist(gaps, bins=1440, color='blue')
  plt.xticks(zoom_xticks, zoom_xtick_labels, rotation=45)
  plt.xlim(0, ZOOM_MAX_MINUTES)

  # Add labels and show the plot
  plt.xlabel("Time [HH:MM]")
  plt.ylabel("Frequency")
  plt.title(f"{miss_group['name']} Missingness Gap Length Histograms")
  plt.tight_layout()
  plt.show()

# Sanity Checking

In [None]:
# @title CLT for Aggregate Statistics
DATASET_NAME = 'lsm_v2_pretraining_n_200000_10080m'  # @param
NUM_TRIALS = 5
NUMPPL_PERTRIAL = np.array([10, 100, 1_000, 10_000, 100_000])

data = tfds.load(
    'lsm',
    data_dir=f'{DEFAULT_DATA_ROOT}/{DATASET_NAME}',
    split='train',
    shuffle_files=False,
)

# Every trial uses its own disjoint slice of samples, so this many are needed.
num_needed = int(np.sum(NUMPPL_PERTRIAL) * NUM_TRIALS)

# Step 1: Compute `mean_missamt` (fraction of missing entries) per sample.
mean_missamt_list = []
for sample in tqdm(data.take(num_needed), desc='Processing Data'):
  mask = tf.io.parse_tensor(sample['mask'], out_type=tf.bool).numpy()
  mean_missamt_list.append(np.mean(mask))

# Convert the list to a NumPy array for efficient slicing
mean_missamt_array = np.array(mean_missamt_list)
num_available = len(mean_missamt_array)
if num_available < num_needed:
  print(
      f'Warning: only {num_available} of {num_needed} required samples are'
      ' available; trials that cannot be filled completely are recorded as'
      ' NaN.'
  )

# Step 2: Perform NUM_TRIALS trials for each sample size in NUMPPL_PERTRIAL.
# sample_stats stays aligned with np.repeat(NUMPPL_PERTRIAL, NUM_TRIALS).
sample_stats = []

current_index = 0  # To keep track of the data index
for n in tqdm(NUMPPL_PERTRIAL, desc='Processing trials'):
  for _ in range(NUM_TRIALS):
    stop = current_index + n
    if stop <= num_available:
      sample_stats.append(np.mean(mean_missamt_array[current_index:stop]))
    else:
      # A truncated (or empty) slice would be mislabelled as size n.
      sample_stats.append(np.nan)
    current_index = stop

In [None]:
# Trial sample sizes, aligned one-to-one with sample_stats.
num_ppl_list_repeat = np.repeat(NUMPPL_PERTRIAL, repeats=NUM_TRIALS)
# Mean over every loaded sample: the reference value the trial means should
# converge to. (Values are fractions of missing entries, in [0, 1].)
overall_mean_missamt = np.mean(mean_missamt_list)

plt.figure(figsize=(8, 6))
plt.scatter(
    num_ppl_list_repeat,
    sample_stats,
    alpha=0.25,
    color="blue",
    label="Sample Statistics",
)

# Set x-axis to log scale
plt.xscale("log")

# Add labels, title, and legend
plt.xlabel("Number of People (log scale)")
plt.ylabel("Mean % of Missingness ")
plt.title("Mean % of Missingness vs. Sample Size (Log Scale)")
plt.legend()
plt.axhline(overall_mean_missamt, c="red")
plt.grid(axis="y", which="both")
plt.minorticks_on()

print(f"True Mean % of Missingness {overall_mean_missamt}")
plt.show()

In [None]:
# @title Z-norm statistics per channel
DATASET_NAME = 'lsm_v2_pretraining_n_200000_10080m'  # @param
NUM_PPL = 100_000
NUM_CHANNELS = len(labels)

data = tfds.load(
    'lsm',
    data_dir=f'{DEFAULT_DATA_ROOT}/{DATASET_NAME}',
    split='train',
    shuffle_files=False,
)

channelmean_all = []
channelstd_all = []

for sample in tqdm(data.take(NUM_PPL)):
  # `input_signal` is parsed as float64, the same dtype the other cells use.
  # After `.T` rows are channels and columns are time steps.
  signal = (
      tf.io.parse_tensor(sample['input_signal'], out_type=tf.double).numpy().T
  )
  mask = tf.io.parse_tensor(sample['mask'], out_type=tf.bool).numpy().T

  # With z-norm happening before imputation, missing entries are excluded.
  signal[mask] = np.nan
  # Per-channel statistics over time for this sample: shape (NUM_CHANNELS,).
  channelmean_all.append(np.nanmean(signal, axis=1))
  channelstd_all.append(np.nanstd(signal, axis=1))
  # If z-norm happened after imputation instead, the unmasked np.mean /
  # np.std over the same axis would be the statistics to collect.

  gc.collect()

# Stack to shape (NUM_CHANNELS, num_samples) so channelmean_all[c] and
# channelstd_all[c] hold the per-sample statistics of channel c, which is how
# the plotting cells below index them.
channelmean_all = np.stack(channelmean_all, axis=1)
channelstd_all = np.stack(channelstd_all, axis=1)

In [None]:
# Summaries of the per-channel means collected above.
# NOTE(review): assumes channelmean_all[idx] holds the per-sample means of
# channel idx -- confirm against the collection cell.
for idx, channel_name in enumerate(labels):
  per_channel_means = channelmean_all[idx]
  print(f"Mean of means {channel_name}: {np.nanmean(per_channel_means)}")

for idx, channel_name in enumerate(labels):
  per_channel_means = channelmean_all[idx]
  plt.figure(figsize=(5, 2))
  plt.hist(per_channel_means, bins=50)
  plt.title(f"Histogram of Mean of {channel_name}")

In [None]:
# Summaries of the per-channel standard deviations collected above.
# NOTE(review): assumes channelstd_all[idx] holds the per-sample stds of
# channel idx -- confirm against the collection cell.
for idx, channel_name in enumerate(labels):
  per_channel_stds = channelstd_all[idx]
  print(f"Mean of Stds {channel_name}: {np.nanmean(per_channel_stds)}")

for idx, channel_name in enumerate(labels):
  per_channel_stds = channelstd_all[idx]
  plt.figure(figsize=(5, 2))
  plt.hist(per_channel_stds, bins=50)
  plt.title(f"Histogram of Std of {channel_name}")