# HMM Model Selection 

Now that we have network-level, temporally trimmed signals for all of our subjects, we can fit the HMM! As a first pass, we'll try 10 states to see how that looks...

In [3]:
import os
import glob
import h5py
import numpy as np
import pickle
from sklearn.preprocessing import StandardScaler
from hmmlearn import hmm
import traceback
from tqdm import tqdm
from joblib import Parallel, delayed
from scipy.optimize import linear_sum_assignment
from scipy.stats import pearsonr

In [None]:
# === Settings ===
# Input: one HDF5 file per subject, each holding the dataset named by DATASET.
INPUT_DIR = "/home/jovyan/narratives-project/shirer_components"
DATASET = "timeseries"

# HMM fitting
STATE_COUNTS = [10]   # numbers of hidden states to try
N_REPEATS = 15        # independent fits per state count
MAX_ITER = 500        # cap on EM iterations per fit (hmmlearn n_iter)

# Permutation testing
RUN_PERMUTATIONS = False  # True runs the permutation null; False skips it
N_PERMUTATIONS = 100
N_JOBS = -1               # joblib workers; -1 uses all cores

# Output
# NOTE(review): file name hard-codes "k10"; update it if STATE_COUNTS changes.
PKL_OUT = "/home/jovyan/narratives-project/hmm-objects/hmmlearn_consensus_results_k10.pkl"

def load_data(input_dir):
    """
    Load and z-score every subject's time series found in ``input_dir``.

    Each ``*.h5`` file (visited in sorted filename order) must contain the
    dataset named by the module-level ``DATASET`` constant. The stored array is
    transposed and then every column is standardised independently per file
    with ``StandardScaler``. The transpose presumably turns a stored
    (features, time) layout into (time, features) — TODO confirm against the
    files that produced them.

    Returns a list with one array per file, in sorted filename order
    (empty if no ``*.h5`` files are present).
    """
    pattern = os.path.join(input_dir, "*.h5")
    subjects = []
    for h5_path in sorted(glob.glob(pattern)):
        with h5py.File(h5_path, "r") as handle:
            raw = handle[DATASET][()]
        subjects.append(StandardScaler().fit_transform(raw.T))
    return subjects

def compute_num_parameters(n_states, n_features, covariance_type="full"):
    """
    Count the free parameters of a Gaussian HMM (used for AIC/BIC).

    The count follows hmmlearn's GaussianHMM parameterisation:
      * start probabilities: ``n_states - 1`` (they sum to 1)
      * transition matrix:   ``n_states * (n_states - 1)`` (each row sums to 1)
      * state means:         ``n_states * n_features``
      * covariances:         depends on ``covariance_type``

    Parameters
    ----------
    n_states : int
        Number of hidden states.
    n_features : int
        Dimensionality of each observation.
    covariance_type : {"full", "diag", "spherical", "tied"}, default "full"
        Covariance structure of the emission model. ``"full"`` is what
        ``fit_hmm_prestore`` fits and is the previous (only) behaviour.

    Returns
    -------
    int
        Total number of free parameters.

    Raises
    ------
    ValueError
        If ``covariance_type`` is not one of the supported values.
    """
    p_init = n_states - 1
    p_trans = n_states * (n_states - 1)
    p_means = n_states * n_features
    # A symmetric d x d covariance matrix has d(d+1)/2 unique entries.
    per_full_cov = n_features * (n_features + 1) // 2
    cov_params = {
        "full": n_states * per_full_cov,
        "tied": per_full_cov,           # one full covariance shared by all states
        "diag": n_states * n_features,  # one variance per feature per state
        "spherical": n_states,          # one variance per state
    }
    if covariance_type not in cov_params:
        raise ValueError(
            f"Unsupported covariance_type {covariance_type!r}; "
            f"expected one of {sorted(cov_params)}"
        )
    return p_init + p_trans + p_means + cov_params[covariance_type]

def fractional_occupancy(state_seqs, k):
    """
    Return the fraction of all timepoints spent in each of ``k`` states.

    State visits are pooled over every sequence in ``state_seqs`` (one decoded
    state path per subject) and divided by the total number of timepoints, so
    longer sequences carry proportionally more weight. Used to spot states the
    model barely (or never) uses.
    """
    visits = np.zeros(k)
    n_timepoints = 0
    for path in state_seqs:
        n_timepoints = n_timepoints + len(path)
        visits += np.bincount(path, minlength=k)
    return visits / n_timepoints

def circular_shift(ts, rng=None):
    """
    Roll a time series forward along axis 0 by a random number of TRs.

    Used to generate null data for the permutation testing.

    Parameters
    ----------
    ts : np.ndarray
        Array whose first axis is time (TRs).
    rng : np.random.Generator, optional
        Source of randomness. ``None`` (default) uses NumPy's global RNG, as
        before; pass a seeded Generator for reproducible permutations.

    Returns
    -------
    np.ndarray
        A shifted copy of ``ts``. The shift is drawn uniformly from
        ``1 .. n_TRs - 1``, so it is never the identity shift. Series with fewer
        than two TRs have no non-trivial shift and come back as an unshifted copy.
    """
    n_trs = ts.shape[0]
    if n_trs < 2:
        # randint(1, n_trs) would raise "low >= high" here.
        return np.array(ts, copy=True)
    if rng is None:
        shift = np.random.randint(1, n_trs)
    else:
        shift = rng.integers(1, n_trs)
    return np.roll(ts, shift, axis=0)

def fit_hmm_prestore(X, lengths, k, max_iter=MAX_ITER, random_state=None, verbose=True):
    """
    Fit one full-covariance Gaussian HMM and decode per-subject state paths.

    Parameters
    ----------
    X : np.ndarray
        All subjects' observations stacked along axis 0.
    lengths : sequence of int
        Number of rows of ``X`` belonging to each subject, in stacking order.
    k : int
        Number of hidden states.
    max_iter : int
        Maximum number of EM iterations (hmmlearn ``n_iter``).
    random_state : int, np.random.RandomState or None, optional
        Forwarded to ``GaussianHMM`` to make the random initialisation
        reproducible. ``None`` (default) keeps the previous unseeded behaviour.
    verbose : bool, optional
        Forwarded to ``GaussianHMM``. ``True`` (default, the previous
        behaviour) prints the per-iteration log-likelihood; pass ``False`` to
        silence it.

    Returns
    -------
    model : hmm.GaussianHMM
        The fitted model.
    logL : float
        Log-likelihood of ``X`` (with ``lengths``) under the fitted model.
    paths : list of np.ndarray
        Most likely state sequence (``model.predict``) for each subject,
        split back according to ``lengths``.
    """
    model = hmm.GaussianHMM(
        n_components=k,
        covariance_type="full",
        n_iter=max_iter,
        random_state=random_state,
        verbose=verbose,
    )
    model.fit(X, lengths)
    logL = model.score(X, lengths)
    hidden = model.predict(X, lengths)

    # predict() returns one concatenated path; cut it at the subject boundaries.
    bounds = np.cumsum([0, *lengths])
    paths = [hidden[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]
    return model, logL, paths

def fit_best_model(X, lengths, k, n_repeats):
    """
    Fit ``n_repeats`` HMMs and return the most representative one.

    Every successful run's state means are aligned against every other run's
    (Hungarian matching on state-mean correlations, via ``align_and_score``).
    The run with the highest average aligned correlation is selected. This
    picks the most *typical* solution, which is not necessarily the one with
    the highest log-likelihood.

    Returns
    -------
    tuple
        ``(model, logL, fo, paths)`` for the selected run, where ``fo`` is its
        per-state fractional occupancy; ``(None, None, None, None)`` if every
        run failed.
    """
    all_models = []

    for run in range(1, n_repeats + 1):
        try:
            model, logL, paths = fit_hmm_prestore(X, lengths, k)
        except Exception as e:
            # One failed fit should not abort the whole sweep.
            print(f"  [ERROR] run {run} k={k}: {e}")
            continue
        all_models.append({
            "run": run,  # real run number; runs can be skipped on error
            "model": model,
            "logL": logL,
            "means": model.means_,
            "paths": paths,
        })
        print(f"  [RUN {run}/{n_repeats}] logL = {logL:.1f}")

    if not all_models:
        return None, None, None, None

    if len(all_models) == 1:
        # Nothing to align against; the alignment score would be undefined (NaN).
        best = all_models[0]
    else:
        scores = align_and_score([m["means"] for m in all_models])
        best = all_models[int(np.argmax(scores))]

    fo = fractional_occupancy(best["paths"], k)
    print(f"  [BEST] run={best['run']}/{n_repeats} k={k} logL={best['logL']:.1f}, FO min={fo.min():.3f}")
    return best["model"], best["logL"], fo, best["paths"]

def run_permutations(X, lengths, data, k, logL_real):
    """
    Permutation test of the real model's log-likelihood against a null.

    For each of ``N_PERMUTATIONS`` permutations, every subject's series is
    circularly shifted by its own random offset (``circular_shift``). The
    shifted series are re-stacked and a fresh k-state HMM is fitted.
    Permutations run in parallel on ``N_JOBS`` joblib workers.

    Parameters
    ----------
    X : np.ndarray
        Stacked real data. Not used here (the null data are rebuilt from
        ``data``); kept for interface compatibility.
    lengths : sequence of int
        Per-subject row counts; circular shifting does not change them.
    data : list of np.ndarray
        Per-subject time series, rows = timepoints.
    k : int
        Number of hidden states.
    logL_real : float
        Log-likelihood of the model fitted to the real data.

    Returns
    -------
    null_logLs : list of float
        Log-likelihoods of the permutations that fitted successfully.
    pval : float
        ``(1 + #{null >= real}) / (1 + n_null)``. This is the add-one
        permutation p-value, which can never be exactly 0. ``nan`` if no
        permutation fitted.
    """
    # NOTE(review): circular_shift rolls all columns of a subject together, so
    # within-subject co-activation patterns survive in the null data and null
    # logLs may sit very close to the real one. Rolling each column
    # independently would be a stricter null — confirm intent.
    # NOTE(review): circular_shift draws from NumPy's global RNG inside the
    # workers; this appears to rely on joblib's default (loky) backend giving
    # each worker independent RNG state — verify if the backend is changed.

    def one_perm(_):
        shifted = [circular_shift(ts) for ts in data]
        Xs = np.vstack(shifted)
        try:
            _, logL, _ = fit_hmm_prestore(Xs, lengths, k)
        except Exception as e:
            # Best effort: drop failed fits from the null instead of aborting.
            print(f"  [ERROR] permutation fit k={k}: {e}")
            return None
        return logL

    nulls = Parallel(n_jobs=N_JOBS)(
        delayed(one_perm)(i) for i in range(N_PERMUTATIONS)
    )
    null_logLs = [nl for nl in nulls if nl is not None]
    if not null_logLs:
        return null_logLs, float("nan")

    n_as_extreme = sum(nl >= logL_real for nl in null_logLs)
    pval = (1 + n_as_extreme) / (1 + len(null_logLs))
    return null_logLs, pval

def align_and_score(all_means):
    """
    Score how representative each run's state means are of the other runs.

    For every pair of runs, states are matched one-to-one with the Hungarian
    algorithm (``linear_sum_assignment``) to maximise the total Pearson
    correlation between matched state-mean vectors. The pair's similarity is
    the mean correlation over the matched states. A run's score is its
    average similarity to all other runs.

    Parameters
    ----------
    all_means : list of np.ndarray
        One 2-D array of state means per run, rows = states.

    Returns
    -------
    np.ndarray
        One score per run (higher = more typical). With fewer than two runs
        there is nothing to compare against, and every score is 0.
    """
    n_runs = len(all_means)
    if n_runs < 2:
        return np.zeros(n_runs)

    pair_sim = np.zeros((n_runs, n_runs))
    for i in range(n_runs):
        for j in range(i + 1, n_runs):
            a, b = all_means[i], all_means[j]
            n_a = a.shape[0]
            # np.corrcoef treats rows as variables; the off-diagonal block is
            # the Pearson r between every state of run i and every state of run j.
            corr = np.corrcoef(a, b)[:n_a, n_a:]
            rows, cols = linear_sum_assignment(corr, maximize=True)
            # The optimal matching value is symmetric in (i, j): compute once.
            pair_sim[i, j] = pair_sim[j, i] = corr[rows, cols].mean()
    # The diagonal is 0, so dividing by (n_runs - 1) averages over the other runs.
    return pair_sim.sum(axis=1) / (n_runs - 1)

def main():
    """
    Fit, score and save HMMs for every state count in ``STATE_COUNTS``.

    Loads and z-scores all subjects from ``INPUT_DIR``. For each k it picks a
    representative model out of ``N_REPEATS`` fits (``fit_best_model``),
    computes AIC/BIC and, if ``RUN_PERMUTATIONS`` is set, runs the permutation
    test. The list of per-k result dicts is pickled to ``PKL_OUT``.

    Raises
    ------
    FileNotFoundError
        If ``INPUT_DIR`` contains no ``*.h5`` files.
    """
    data = load_data(INPUT_DIR)
    if not data:
        # Fail early with a clear message instead of np.vstack's
        # "need at least one array to concatenate".
        raise FileNotFoundError(f"No *.h5 files found in {INPUT_DIR}")

    # Stack once; `lengths` tells hmmlearn where each subject's sequence ends.
    X = np.vstack(data)
    lengths = [d.shape[0] for d in data]

    print(f"[DEBUG] Loaded {len(data)} subjects")
    for i, d in enumerate(data):
        print(f"  Subject {i}: shape={d.shape}, mean={np.mean(d):.4f}, std={np.std(d):.4f}")

    n_samples = sum(lengths)
    n_features = data[0].shape[1]

    results = []
    for k in STATE_COUNTS:
        print(f"Fitting HMM for k={k}...")
        model, logL, fo, paths = fit_best_model(X, lengths, k, N_REPEATS)
        if model is None:
            print(f"  [WARN] No model converged for k={k}")
            continue

        # Model-selection metrics. NOTE: logL belongs to the most representative
        # run chosen by fit_best_model, not necessarily the highest-likelihood run.
        n_params = compute_num_parameters(k, n_features)
        aic = 2 * n_params - 2 * logL
        bic = np.log(n_samples) * n_params - 2 * logL
        print(f"  Done. logL={logL:.1f}, AIC={aic:.1f}, BIC={bic:.1f}, FO min={fo.min():.3f}")

        null_logLs, pval = [], None
        if RUN_PERMUTATIONS:
            print(f"  Running {N_PERMUTATIONS} permutations for k={k}...")
            null_logLs, pval = run_permutations(X, lengths, data, k, logL)
            print(f"  Null model p-value: {pval:.4f}")
        else:
            print(f"  [SKIP] Permutations skipped for k={k}")

        results.append({
            "k": k,
            "model": model,
            "logL": logL,
            "AIC": aic,
            "BIC": bic,
            "FO": fo,
            "subject_paths": paths,
            "null_logLs": null_logLs,
            "pval": pval
        })

    # Make sure the output directory exists before writing the pickle.
    out_dir = os.path.dirname(PKL_OUT)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(PKL_OUT, "wb") as f:
        pickle.dump(results, f)
    print(f"Saved HMMlearn permutation results to {PKL_OUT}")

if __name__ == "__main__":
    main()


[DEBUG] Loaded 75 subjects
  Subject 0: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 1: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 2: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 3: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 4: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 5: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 6: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 7: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 8: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 9: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 10: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 11: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 12: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 13: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 14: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 15: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 16: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 17: shape=(279, 14), mean=-0.00

         1 -257173.50734666             +nan
         2 -244615.90706294  +12557.60028371
         3 -240057.36896258   +4558.53810037
         4 -236185.55705496   +3871.81190762
         5 -232950.56146077   +3234.99559418
         6 -230726.78234883   +2223.77911194
         7 -229238.84528581   +1487.93706302
         8 -228124.37213293   +1114.47315288
         9 -227268.71751282    +855.65462011
        10 -226624.09050739    +644.62700544
        11 -226132.56186824    +491.52863914
        12 -225700.79734208    +431.76452616
        13 -225298.50104528    +402.29629680
        14 -224962.20271557    +336.29832971
        15 -224654.33642904    +307.86628653
        16 -224369.02615863    +285.31027041
        17 -224079.31811468    +289.70804395
        18 -223809.90094921    +269.41716547
        19 -223576.20223399    +233.69871522
        20 -223385.65698306    +190.54525093
        21 -223224.39210088    +161.26488218
        22 -223089.80559221    +134.58650867
        23

  [RUN 1/15] logL = -221275.7


         1 -256919.72811587             +nan
         2 -245216.32912288  +11703.39899299
         3 -239817.49398277   +5398.83514011
         4 -236234.85883399   +3582.63514877
         5 -233816.64805046   +2418.21078353
         6 -232226.91633503   +1589.73171543
         7 -231206.56124919   +1020.35508583
         8 -230395.14305987    +811.41818933
         9 -229666.95878217    +728.18427770
        10 -229162.58380125    +504.37498092
        11 -228805.53705253    +357.04674872
        12 -228494.37274448    +311.16430805
        13 -228188.41693218    +305.95581230
        14 -227883.01637217    +305.40056001
        15 -227599.71569346    +283.30067871
        16 -227365.49724895    +234.21844451
        17 -227180.26954365    +185.22770530
        18 -227033.49401491    +146.77552874
        19 -226918.15128652    +115.34272839
        20 -226815.65315289    +102.49813363
        21 -226699.63331180    +116.01984110
        22 -226527.66353104    +171.96978076
        23

  [RUN 2/15] logL = -224147.8


         1 -262654.54621498             +nan
         2 -247039.89149570  +15614.65471928
         3 -241943.99695815   +5095.89453755
         4 -238698.28962627   +3245.70733188
         5 -235933.43276605   +2764.85686022
         6 -233652.89606885   +2280.53669720
         7 -232165.30852279   +1487.58754606
         8 -231054.58844267   +1110.72008012
         9 -230172.60888742    +881.97955525
        10 -229491.64455608    +680.96433134
        11 -228937.10580016    +554.53875592
        12 -228472.21986008    +464.88594008
        13 -228156.39188035    +315.82797973
        14 -227945.28686682    +211.10501353
        15 -227789.96499359    +155.32187323
        16 -227660.16607470    +129.79891889
        17 -227549.55435310    +110.61172160
        18 -227449.99954669     +99.55480641
        19 -227351.16593773     +98.83360896
        20 -227236.51890993    +114.64702780
        21 -227089.65997361    +146.85893631
        22 -226906.12477361    +183.53520000
        23

  [RUN 3/15] logL = -224247.6


         1 -264376.79705727             +nan
         2 -246935.82184212  +17440.97521515
         3 -240174.60331475   +6761.21852737
         4 -235909.95206331   +4264.65125143
         5 -233081.71231068   +2828.23975264
         6 -231359.03375458   +1722.67855609
         7 -230493.32524329    +865.70851129
         8 -230073.22797236    +420.09727093
         9 -229802.33746739    +270.89050497
        10 -229599.08262476    +203.25484263
        11 -229403.16476832    +195.91785645
        12 -229205.03672007    +198.12804824
        13 -228946.70466238    +258.33205769
        14 -228600.01445205    +346.69021033
        15 -228179.69779017    +420.31666188
        16 -227728.76560246    +450.93218771
        17 -227314.05496517    +414.71063729
        18 -226991.39976844    +322.65519673
        19 -226757.91602775    +233.48374069
        20 -226581.22752991    +176.68849784
        21 -226443.87880028    +137.34872962
        22 -226328.63938390    +115.23941639
        23