Goal: fit a two-state hierarchical HMM to the data to distinguish fatigued from non-fatigued states.

Data to feed in: subject_data_for_HMM.json

Data format:
84 subjects

subject_data = {
    '1': {
    'epoch_accuracy': [0.8, 0.7, 0.9, ...],  # 30 values (one value per epoch: 10 blocks x 3 epochs)
    'post_epoch_post_cue_rest_duration': [2.1, 3.5, 1.8, ...],  # 30 values (one value per epoch)
    'rest_cue_type': ['switch' or 'stay', ...],  # whether the cue before the rest period was 'switch' or 'stay' (same information as epoch_follows_task_switch, shifted up one row)
    'epoch_follows_task_switch': [NA, 1 or 0, ...],  # 1 if this epoch followed a task switch, 0 if it followed a stay; the first row is NA because it followed neither
        'block_number': [1,1,1, 2,2,2, 3,3,3, ..., 10,10,10],
        'epoch_within_block': [1,2,3, 1,2,3, 1,2,3, ..., 1,2,3],
        'overall_epoch': [1,2,3,4,5,6,7,8,9, ..., 28,29,30],
        'game_type': ['A','A','A', 'B','B','B', ...],
    'pre_epoch_rest_duration': [NA, ...],  # rest duration preceding each epoch; the first row is NA because no rest period preceded it
    },
    # ... for all 84 subjects
}

In [5]:
# imports
import json
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt


In [3]:
# load the data
# Pass the encoding explicitly: without it, open() uses the platform's locale
# encoding (e.g. cp1252 on Windows), whereas JSON files are UTF-8.
with open('subject_data_for_HMM.json', 'r', encoding='utf-8') as f:
    subject_data = json.load(f)

In [6]:
# Prepare data: list of arrays, one per subject, in the order subjects appear in the JSON.
# Subject IDs are kept alongside so per-subject results can be mapped back to the raw data.
subject_ids = list(subject_data.keys())
subjects = list(subject_data.values())


def _to_float_array(values):
    """Return ``values`` as a 1-D float array, coercing None / 'NA'-style entries to NaN.

    Plain ``np.array`` on a list that mixes numbers with None or 'NA' yields an
    object or string array, which cannot be used as a numeric covariate.
    """
    return pd.to_numeric(pd.Series(values), errors='coerce').to_numpy(dtype=float)


# emissions (what it's predicting)
# dtype=float forces numeric arrays: JSON nulls become NaN, and a non-numeric
# string raises here instead of surfacing later inside the model.
epoch_accuracies = [np.asarray(subj['epoch_accuracy'], dtype=float) for subj in subjects]
rest_durations = [
    np.asarray(subj['post_epoch_post_cue_rest_duration'], dtype=float) for subj in subjects
]

# covariates (depends on these)
# later model some of these as Predictors of Transition Probabilities
rest_cue_type = [np.array(subj['rest_cue_type']) for subj in subjects]
# The first entry is documented as NA; coerce it to NaN so the array stays numeric.
epoch_follows_task_switch = [_to_float_array(subj['epoch_follows_task_switch']) for subj in subjects]
overall_epoch = [np.array(subj['overall_epoch']) for subj in subjects]
game_type = [np.array(subj['game_type']) for subj in subjects]

# data about the dataset
n_subjects = len(epoch_accuracies)
lengths = [len(x) for x in epoch_accuracies]

# The two emission series must line up epoch-by-epoch. Name the offending subject
# rather than letting np.column_stack fail with a generic shape error.
for sid, acc, rest in zip(subject_ids, epoch_accuracies, rest_durations):
    if len(acc) != len(rest):
        raise ValueError(
            f"subject {sid!r}: {len(acc)} epoch_accuracy values but "
            f"{len(rest)} post_epoch_post_cue_rest_duration values"
        )

# concatenate all of the relevant vars as observations:
# one (n_epochs, 2) array per subject, columns = [accuracy, rest duration]
all_obs = [np.column_stack((acc, rest)) for acc, rest in zip(epoch_accuracies, rest_durations)]

n_states = 2  # two latent states: high fatigue, low fatigue