Goal: fit a 2-state non-hierarchical HMM on the data to detect fatigue vs. not fatigue.

Data to feed in: subject_data_for_HMM.json

Data format:
84 subjects

subject_data = {
    '1': {
        'epoch_accuracy': [0.8, 0.7, 0.9, ...],  # 30 values (one value per block)
        'post_epoch_post_cue_rest_duration': [2.1, 3.5, 1.8, ...],  # 30 values (one value per block)
        'rest_cue_type': ['switch' or 'stay'] # if the cue before the rest period was switch or stay (basically this is the same as epoch_follows_task_switch but one row up)
        'epoch_follows_task_switch': [1 if switch, 0 if stay], # if this block was a product of a task switch (first row has to be NA because it was neither)
        'block_number': [1,1,1, 2,2,2, 3,3,3, ..., 10,10,10],
        'epoch_within_block': [1,2,3, 1,2,3, 1,2,3, ..., 1,2,3],
        'overall_epoch': [1,2,3,4,5,6,7,8,9, ..., 28,29,30],
        'game_type': ['A','A','A', 'B','B','B', ...],
        'pre_epoch_rest_duration': [previous rest duration for each epoch] # first row has to be NA because there was no previous rest duration
    },
    # ... for all 84 subjects
}

In [1]:
# imports
import json
import numpy as np
import pandas as pd
import pymc as pm;print(pm.__version__)
import pytensor as pt
import pytensor.tensor as pt
from pytensor.scan import scan



5.12.0


In [2]:
# load the data
# JSON text is UTF-8; pin the encoding so the read does not depend on the
# platform's default locale encoding.
with open('subject_data_for_HMM.json', 'r', encoding='utf-8') as f:
    subject_data = json.load(f)

In [6]:
# Prepare data: every variable below is a list with one entry per subject,
# in the iteration order of `subject_data`.
subjects = list(subject_data.values())


def _per_subject(field):
    """Return a list holding, for each subject, `field` as a NumPy array."""
    return [np.array(subj[field]) for subj in subjects]


# emissions (what the HMM is predicting)
epoch_accuracies = _per_subject('epoch_accuracy')
rest_durations = _per_subject('post_epoch_post_cue_rest_duration')

# covariates (the emissions depend on these)
overall_epoch = _per_subject('overall_epoch')
game_type = _per_subject('game_type')
# binary coding of the game: 0 for 'digit_span', 1 for anything else
game_type_numeric = [
    np.array([int(g != 'digit_span') for g in subj['game_type']])
    for subj in subjects
]

# kept for later use as predictors of the transition probabilities
rest_cue_type = _per_subject('rest_cue_type')
epoch_follows_task_switch = _per_subject('epoch_follows_task_switch')

# data about the dataset
n_subjects = len(epoch_accuracies)
lengths = list(map(len, epoch_accuracies))

# observations (the emissions): one (n_timepoints, 2) array per subject with
# columns (epoch accuracy, rest duration)
all_obs = [np.column_stack(pair) for pair in zip(epoch_accuracies, rest_durations)]

n_states = 2 # high fatigue, low fatigue

In [9]:
def hmm_marginal_logp(obs, game, time, base_mu, beta_game, beta_time, sigma, pi, A, n_states):
    """Marginal log-likelihood of one observed sequence under a Gaussian HMM.

    The hidden states are summed out with the forward algorithm, carried out
    in log space.

    Parameters
    ----------
    obs : array of shape (n_timepoints, 2)
        Observed emissions; column 0 is accuracy, column 1 is rest duration.
    game, time : arrays of shape (n_timepoints,)
        Covariates entering the emission means linearly.
    base_mu, beta_game, beta_time, sigma : tensors of shape (n_states, 2)
        Per-state, per-channel intercept, covariate slopes and emission
        standard deviation.
    pi : tensor of shape (n_states,)
        Initial state distribution.
    A : tensor of shape (n_states, n_states)
        Row-stochastic transition matrix: ``A[i, j]`` is the probability of
        moving from state ``i`` to state ``j``.
    n_states : int
        Number of hidden states.

    Returns
    -------
    Scalar tensor with ``log p(obs | parameters)``.

    Raises
    ------
    ValueError
        If ``obs`` has no timepoints.
    """
    obs = np.asarray(obs)
    n_timepoints = obs.shape[0]
    if n_timepoints == 0:
        raise ValueError("cannot evaluate the HMM likelihood of an empty sequence")

    value = pt.as_tensor_variable(obs)
    game = pt.as_tensor_variable(np.asarray(game))
    time = pt.as_tensor_variable(np.asarray(time))

    # mu[t, k, d]: emission mean at timepoint t, in state k, for channel d
    mu = (base_mu[None, :, :] +
          beta_game[None, :, :] * game[:, None, None] +
          beta_time[None, :, :] * time[:, None, None])

    # The two channels are modelled as independent Normals given the state,
    # so their log-densities add.
    per_state = []
    for k in range(n_states):
        logp_acc = pm.logp(pm.Normal.dist(mu=mu[:, k, 0], sigma=sigma[k, 0]), value[:, 0])
        logp_rest = pm.logp(pm.Normal.dist(mu=mu[:, k, 1], sigma=sigma[k, 1]), value[:, 1])
        per_state.append(logp_acc + logp_rest)
    logp_states = pt.stack(per_state, axis=1)  # (n_timepoints, n_states)

    alpha_0 = pt.log(pi) + logp_states[0]
    if n_timepoints == 1:
        # No transitions to integrate over; scan over an empty sequence
        # would leave nothing to index with [-1].
        return pt.logsumexp(alpha_0)

    log_A = pt.log(A)  # loop-invariant, so computed once outside the scan

    def forward_step(logp_t, prev_alpha, log_A):
        # alpha_t[j] = logsumexp_i(alpha_{t-1}[i] + log A[i, j]) + log p(y_t | j)
        # prev_alpha must vary along the rows (source state i) and the
        # reduction must run over axis 0, because each ROW of A sums to one.
        return pt.logsumexp(prev_alpha[:, None] + log_A, axis=0) + logp_t

    alphas, _ = scan(
        fn=forward_step,
        sequences=logp_states[1:],
        outputs_info=alpha_0,
        non_sequences=log_A,
    )
    return pt.logsumexp(alphas[-1])


# One posterior per subject, keyed by position in `all_obs`, so earlier fits
# are not overwritten by later ones.
traces = {}

for subj_idx in range(n_subjects):
    print(f"working on {subj_idx}")

    obs = all_obs[subj_idx]  # shape (n_timepoints, 2) (epoch_accuracies, rest_durations)
    game = game_type_numeric[subj_idx]
    time = overall_epoch[subj_idx]

    with pm.Model() as model:
        # Per-subject parameters: every subject gets an independent model, so
        # nothing here is shared (pooled) across subjects.
        # NOTE(review): `time` enters unscaled, so beta_time's N(0, 1) prior is
        # per raw epoch unit - consider standardising the covariate.
        # NOTE(review): nothing orders the two states, so chains may swap
        # labels - confirm before averaging across chains.
        base_mu = pm.Normal('base_mu', mu=0, sigma=1, shape=(n_states, 2))
        beta_game = pm.Normal('beta_game', mu=0, sigma=1, shape=(n_states, 2))
        beta_time = pm.Normal('beta_time', mu=0, sigma=1, shape=(n_states, 2))
        sigma = pm.HalfNormal('sigma', sigma=1, shape=(n_states, 2))
        pi = pm.Dirichlet('pi', a=np.ones(n_states))
        # Dirichlet normalises the last axis: each row of A sums to one.
        A = pm.Dirichlet('A', a=np.ones((n_states, n_states)), shape=(n_states, n_states))

        # HMM likelihood for this subject only. The data is fixed, so the
        # marginal log-likelihood is added to the model as a potential term
        # rather than as a distribution over a free variable.
        obs_i = pm.Potential(
            f'obs_{subj_idx}',
            hmm_marginal_logp(
                obs, game, time,
                base_mu, beta_game, beta_time, sigma, pi, A,
                n_states,
            ),
        )

        trace = pm.sample(500, tune=500, target_accept=0.95, progressbar=True)

    traces[subj_idx] = trace
        # Save or analyze trace here

working on 0


KeyboardInterrupt: 