# HMM Model Selection 

Now that we have network-level, temporally trimmed time series for all of our subjects, we can fit the HMM! First we're going to try 10 states to see how that looks (note: the settings cell below currently fits k = 16)...
Note: before fitting, I also confirmed that the networks are ordered identically across all subjects.

In [2]:
import torch

# Report GPU availability. The availability check must come before any call
# to get_device_name(0), which itself raises when no CUDA device exists.
# NOTE(review): `device` is not used anywhere below, and the HMM fitting uses
# hmmlearn rather than torch — confirm whether this check is still needed.
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
if not cuda_ok:
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError("No GPU detected!")
print("Device:", torch.cuda.get_device_name(0))
device = torch.device("cuda")
print(f"Using GPU: {torch.cuda.get_device_name(0)}")

CUDA available: True
Device: Tesla T4


In [2]:
import os
import glob
import h5py
import numpy as np
import pickle
from sklearn.preprocessing import StandardScaler
from hmmlearn import hmm
import traceback
from tqdm import tqdm
from joblib import Parallel, delayed
from scipy.optimize import linear_sum_assignment
from scipy.stats import pearsonr

In [None]:
# === Settings ===
INPUT_DIR       = "/home/jovyan/narratives-project/shirer_components"
DATASET         = "timeseries"  # name of the HDF5 dataset read from each file
STATE_COUNTS    = [16]          # numbers of hidden states (k) to fit
N_REPEATS       = 15            # independent HMM fits per k; the consensus run is kept
N_PERMUTATIONS  = 100           # null fits per k (only when RUN_PERMUTATIONS is True)
MAX_ITER        = 500           # EM iteration cap passed to hmmlearn (n_iter)
N_JOBS          = 4             # joblib workers for permutations (-1 would use all cores)
# Output filename is derived from STATE_COUNTS so it stays accurate if k changes.
PKL_OUT         = (
    "/home/jovyan/narratives-project/hmm-objects/"
    f"hmmlearn_consensus_results_k{'_'.join(str(k) for k in STATE_COUNTS)}.pkl"
)
RUN_PERMUTATIONS = False  # set to True to run the (slow) permutation test

def load_data(input_dir, dataset=None):
    """
    Load and z-score every subject's time series from ``input_dir``.

    Every ``*.h5`` file in the directory is read in sorted path order, so
    subject order is deterministic. The array stored under ``dataset`` is
    transposed on read so that rows index timepoints. Presumably it is saved
    as (networks, timepoints); TODO confirm. Each column is then standardized
    to zero mean and unit variance within the subject.

    Parameters
    ----------
    input_dir : str
        Directory containing one ``.h5`` file per subject.
    dataset : str, optional
        HDF5 dataset name. Defaults to the module-level ``DATASET``, which is
        looked up at call time.

    Returns
    -------
    list of numpy.ndarray
        One standardized (timepoints, features) array per file.

    Raises
    ------
    FileNotFoundError
        If ``input_dir`` contains no ``.h5`` files. An empty list would
        otherwise surface later as an opaque ``np.vstack`` error.
    """
    if dataset is None:
        dataset = DATASET
    paths = sorted(glob.glob(os.path.join(input_dir, "*.h5")))
    if not paths:
        raise FileNotFoundError(f"No .h5 files found in {input_dir!r}")

    data = []
    for path in paths:
        with h5py.File(path, "r") as f:
            ts = f[dataset][()].T
        data.append(StandardScaler().fit_transform(ts))
    return data

def compute_num_parameters(n_states, n_features, covariance_type="full"):
    """
    Count the free parameters of a Gaussian HMM (used for AIC/BIC).

    Parameters
    ----------
    n_states : int
        Number of hidden states (k).
    n_features : int
        Dimensionality of each observation (d).
    covariance_type : {"full", "diag", "spherical", "tied"}
        Emission covariance structure. Defaults to "full", matching the
        ``covariance_type='full'`` models fit in this notebook.

    Returns
    -------
    int
        (k - 1) start probabilities + k(k - 1) transition probabilities
        + k*d means + the covariance parameters for ``covariance_type``.

    Raises
    ------
    ValueError
        If ``covariance_type`` is not one of the supported values.
    """
    k, d = n_states, n_features
    p_init = k - 1              # start probabilities sum to 1
    p_trans = k * (k - 1)       # each transition-matrix row sums to 1
    p_means = k * d
    per_matrix = d * (d + 1) // 2   # unique entries of a symmetric d x d matrix
    cov_params = {
        "full": k * per_matrix,     # one full covariance per state
        "tied": per_matrix,         # one covariance shared by all states
        "diag": k * d,              # one variance per feature per state
        "spherical": k,             # one variance per state
    }
    try:
        p_cov = cov_params[covariance_type]
    except KeyError:
        raise ValueError(f"Unknown covariance_type: {covariance_type!r}") from None
    return p_init + p_trans + p_means + p_cov

def fractional_occupancy(state_seqs, k):
    """
    Return the fraction of all timepoints spent in each of ``k`` states.

    State counts from every sequence (one per subject) are pooled and divided
    by the total number of timepoints. For non-empty input the result sums
    to 1, and a value near 0 flags an effectively unused state.
    """
    seqs = list(state_seqs)
    # Start the sum from a float zero vector so the counts accumulate as floats.
    counts = sum((np.bincount(s, minlength=k) for s in seqs), np.zeros(k))
    n_total = sum(len(s) for s in seqs)
    return counts / n_total

def circular_shift(ts, rng=None):
    """
    Roll a time series forward in time by a random, non-zero number of rows.

    Used to generate null data for the permutation test. All columns are
    rolled together along axis 0, so within-timepoint relationships between
    columns are preserved and only the position in time changes.

    Parameters
    ----------
    ts : array-like
        Series with time on axis 0.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness, for reproducible nulls. Defaults to NumPy's
        global RNG, which is the original behavior, so ``np.random.seed``
        still controls it.

    Returns
    -------
    numpy.ndarray
        A rolled copy of ``ts``. A series with fewer than 2 rows has no
        non-zero shift and is returned as an unshifted copy.
    """
    n = np.shape(ts)[0]
    if n < 2:
        # randint(1, n) would raise ValueError for n <= 1.
        return np.array(ts, copy=True)
    if rng is None:
        shift = np.random.randint(1, n)
    elif hasattr(rng, "integers"):   # numpy.random.Generator
        shift = rng.integers(1, n)
    else:                            # legacy numpy.random.RandomState
        shift = rng.randint(1, n)
    return np.roll(ts, shift, axis=0)

def fit_hmm_prestore(X, lengths, k, max_iter=MAX_ITER, random_state=None, verbose=True):
    """
    Fit one full-covariance Gaussian HMM to pre-stacked subject data.

    Parameters
    ----------
    X : numpy.ndarray
        All subjects' time series stacked along axis 0.
    lengths : sequence of int
        Number of rows of ``X`` belonging to each subject, in stacking order.
        hmmlearn treats each as a separate sequence.
    k : int
        Number of hidden states.
    max_iter : int
        EM iteration cap, passed as ``n_iter``.
    random_state : int, numpy.random.RandomState or None
        Passed to ``GaussianHMM``. ``None`` (default) keeps the original
        unseeded behavior, so repeated calls can give different fits.
    verbose : bool
        Passed to ``GaussianHMM``. ``True`` (default) prints per-iteration
        convergence output, as before.

    Returns
    -------
    model, logL, paths
        The fitted model, ``model.score(X, lengths)``, and a list with one
        decoded state sequence (``model.predict``) per subject.
    """
    model = hmm.GaussianHMM(
        n_components=k,
        covariance_type="full",
        n_iter=max_iter,
        random_state=random_state,
        verbose=verbose,
    )
    model.fit(X, lengths)
    logL = model.score(X, lengths)
    hidden = model.predict(X, lengths)
    # Split the concatenated state sequence back into per-subject paths.
    paths = np.split(hidden, np.cumsum(lengths)[:-1])
    return model, logL, paths

def fit_best_model(X, lengths, k, n_repeats):
    """
    Fit ``n_repeats`` HMMs and return the most representative one.

    Each repeat is fit independently. A failed fit is reported and skipped.
    "Most representative" means the successful run whose state means agree
    best with those of all other successful runs, after one-to-one state
    matching (Hungarian method, see ``align_and_score``). This selects a
    consensus run, not necessarily the run with the highest log-likelihood.

    Returns
    -------
    tuple
        ``(model, logL, fo, paths)`` for the chosen run, where ``fo`` is its
        fractional occupancy per state, or ``(None, None, None, None)`` if
        every repeat failed.
    """
    all_models = []

    for run in range(1, n_repeats + 1):
        try:
            model, logL, paths = fit_hmm_prestore(X, lengths, k)
        except Exception as e:
            # Best-effort: one failed fit should not abort the remaining repeats.
            print(f"  [ERROR] run {run} k={k}: {e}")
            continue
        all_models.append({
            "run": run,   # 1-based repeat number, kept for accurate reporting
            "model": model,
            "logL": logL,
            "means": model.means_,
            "paths": paths,
        })
        print(f"  [RUN {run}/{n_repeats}] logL = {logL:.1f}")

    if not all_models:
        return None, None, None, None

    if len(all_models) == 1:
        # Nothing to compare against; scoring would average an empty list.
        best_idx = 0
    else:
        scores = align_and_score([m["means"] for m in all_models])
        best_idx = int(np.argmax(scores))
    best = all_models[best_idx]

    fo = fractional_occupancy(best["paths"], k)
    # best_idx counts only successful runs, so report the stored repeat number.
    print(f"  [BEST] run={best['run']}/{n_repeats} k={k} logL={best['logL']:.1f}, FO min={fo.min():.3f}")
    return best["model"], best["logL"], fo, best["paths"]

def run_permutations(X, lengths, data, k, logL_real):
    """
    Permutation test of the real log-likelihood against time-shifted data.

    For each of ``N_PERMUTATIONS`` iterations, every subject's series in
    ``data`` is rolled in time (``circular_shift``) and re-stacked, and one
    ``k``-state HMM is fit. Iterations run in parallel with ``N_JOBS``
    workers. ``X`` is accepted for interface compatibility but is not used.

    Returns
    -------
    null_logLs : list of float
        Log-likelihoods of the null fits that succeeded.
    pval : float
        Fraction of null fits with logL >= ``logL_real``, or NaN if no null
        fit succeeded.
    """
    # NOTE(review): each subject is passed to hmmlearn as its own sequence
    # (via ``lengths``), and all of a subject's columns are rolled together.
    # A shift therefore only changes where each sequence starts and adds one
    # wrap-around jump, so null logLs may sit very close to the real one
    # whatever the structure. Also, the real logL comes from a consensus run
    # chosen out of several fits, while each null is a single fit. Confirm
    # that this null answers the intended question.
    def one_perm(_):
        shifted = [circular_shift(ts) for ts in data]
        Xs = np.vstack(shifted)
        try:
            _, logL, _ = fit_hmm_prestore(Xs, lengths, k)
        except Exception:
            # Drop a failed null fit rather than abort the whole test. Unlike
            # the former bare `except:`, this lets KeyboardInterrupt through.
            return None
        return logL

    nulls = Parallel(n_jobs=N_JOBS)(
        delayed(one_perm)(i) for i in range(N_PERMUTATIONS)
    )
    null_logLs = [nl for nl in nulls if nl is not None]
    if not null_logLs:
        # np.mean([]) would emit a RuntimeWarning and return NaN; be explicit.
        return null_logLs, float("nan")
    pval = np.mean([logL_real <= nl for nl in null_logLs])
    return null_logLs, pval

def align_and_score(all_means):
    """
    Score how consistent each run's state means are with every other run.

    For each pair of runs, states are matched one-to-one with the Hungarian
    algorithm (``linear_sum_assignment``) to maximize the summed Pearson
    correlation between matched mean vectors. The pair's similarity is the
    mean correlation over matched states, and a run's score is its average
    similarity to all other runs.

    Parameters
    ----------
    all_means : sequence of 2-D arrays
        One (n_states, n_features) array of state means per run.

    Returns
    -------
    numpy.ndarray
        One score per run. With fewer than two runs there is nothing to
        compare, so all-zero scores are returned.
    """
    n_runs = len(all_means)
    scores = np.zeros(n_runs)
    if n_runs < 2:
        return scores

    def _zscore_rows(m):
        # Row-wise standardization: a product of z-scored rows divided by
        # n_features gives Pearson r, vectorizing the per-pair pearsonr calls.
        m = np.asarray(m, dtype=float)
        centered = m - m.mean(axis=1, keepdims=True)
        return centered / centered.std(axis=1, keepdims=True)

    z = [_zscore_rows(m) for m in all_means]

    # Matching (j, i) is the transpose of matching (i, j) and has the same
    # optimal value, so each unordered pair is solved once.
    sims = np.zeros((n_runs, n_runs))
    for i in range(n_runs):
        for j in range(i + 1, n_runs):
            corr = z[i] @ z[j].T / z[i].shape[1]
            rows, cols = linear_sum_assignment(corr, maximize=True)
            sims[i, j] = sims[j, i] = corr[rows, cols].mean()

    # The diagonal (self-comparison) is zero and excluded from the average.
    scores[:] = sims.sum(axis=1) / (n_runs - 1)
    return scores

def main():
    """
    Run HMM model selection end to end.

    Loads all subjects, fits the consensus HMM for each k in STATE_COUNTS,
    computes AIC/BIC (and optionally a permutation p-value), and pickles the
    list of per-k result dicts to PKL_OUT.
    """
    # Fail fast: create the output directory before the (long) fitting, rather
    # than losing every fitted model to a missing directory at pickle time.
    out_dir = os.path.dirname(PKL_OUT)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    data = load_data(INPUT_DIR)
    if not data:
        raise FileNotFoundError(f"No .h5 files found in {INPUT_DIR!r}")
    # Pre-stack once; hmmlearn takes the concatenation plus per-subject lengths.
    X = np.vstack(data)
    lengths = [d.shape[0] for d in data]

    print(f"[DEBUG] Loaded {len(data)} subjects")
    for i, d in enumerate(data):
        print(f"  Subject {i}: shape={d.shape}, mean={np.mean(d):.4f}, std={np.std(d):.4f}")

    results = []
    for k in STATE_COUNTS:
        print(f"Fitting HMM for k={k}...")
        model, logL, fo, paths = fit_best_model(X, lengths, k, N_REPEATS)
        if model is None:
            print(f"  [WARN] No model converged for k={k}")
            continue

        # Model-selection metrics; BIC uses the total number of timepoints.
        n_params = compute_num_parameters(k, data[0].shape[1])
        aic = 2 * n_params - 2 * logL
        bic = np.log(sum(lengths)) * n_params - 2 * logL
        print(f"  Done. logL={logL:.1f}, AIC={aic:.1f}, FO min={fo.min():.3f}")

        null_logLs, pval = [], None
        if RUN_PERMUTATIONS:
            print(f"  Running {N_PERMUTATIONS} permutations for k={k}...")
            null_logLs, pval = run_permutations(X, lengths, data, k, logL)
            print(f"  Null model p-value: {pval:.4f}")
        else:
            print(f"  [SKIP] Permutations skipped for k={k}")

        results.append({
            "k": k,
            "model": model,
            "logL": logL,
            "AIC": aic,
            "BIC": bic,
            "FO": fo,
            "subject_paths": paths,
            "null_logLs": null_logLs,
            "pval": pval,
        })

    with open(PKL_OUT, "wb") as f:
        pickle.dump(results, f)
    print(f"Saved HMMlearn permutation results to {PKL_OUT}")

if __name__ == "__main__":
    main()


[DEBUG] Loaded 75 subjects
  Subject 0: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 1: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 2: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 3: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 4: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 5: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 6: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 7: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 8: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 9: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 10: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 11: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 12: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 13: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 14: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 15: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 16: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 17: shape=(279, 14), mean=-0.00

         1 -258999.37136508             +nan
         2 -242449.50509618  +16549.86626890
         3 -236818.83650203   +5630.66859415
         4 -232713.31614581   +4105.52035621
         5 -230150.16473463   +2563.15141118
         6 -228452.90406133   +1697.26067330
         7 -227045.83908906   +1407.06497227
         8 -225899.47566315   +1146.36342591
         9 -225049.78035267    +849.69531048
        10 -224481.02396943    +568.75638324
        11 -224109.86676200    +371.15720743
        12 -223822.76232547    +287.10443653
        13 -223572.18031853    +250.58200694
        14 -223346.31883946    +225.86147908
        15 -223134.19801486    +212.12082460
        16 -222927.39879949    +206.79921537
        17 -222731.60244504    +195.79635445
        18 -222541.12375505    +190.47868999
        19 -222336.11756887    +205.00618619
        20 -222146.30746059    +189.81010827
        21 -221977.06741026    +169.24005033
        22 -221816.54849577    +160.51891449
        23

  [RUN 1/15] logL = -215128.9


         1 -259571.64615571             +nan
         2 -239659.56593164  +19912.08022407
         3 -233903.47688818   +5756.08904346
         4 -231256.84205339   +2646.63483479
         5 -229825.12832303   +1431.71373036
         6 -228805.68814810   +1019.44017493
         7 -228039.08329296    +766.60485514
         8 -227410.18075540    +628.90253756
         9 -226858.13475292    +552.04600248
        10 -226350.01585175    +508.11890117
        11 -225789.42427104    +560.59158071
        12 -225134.91013053    +654.51414051
        13 -224432.82364353    +702.08648700
        14 -223808.53580599    +624.28783754
        15 -223321.70002832    +486.83577767
        16 -222904.79738545    +416.90264286
        17 -222593.74278644    +311.05459901
        18 -222360.82782377    +232.91496267
        19 -222150.40926551    +210.41855825
        20 -221930.31316388    +220.09610163
        21 -221709.44157430    +220.87158958
        22 -221503.48136932    +205.96020499
        23

  [RUN 2/15] logL = -214607.7


         1 -265194.15927196             +nan
         2 -243198.13216783  +21996.02710413
         3 -235603.00642357   +7595.12574426
         4 -231622.34290033   +3980.66352324
         5 -229307.72716571   +2314.61573462
         6 -227784.71046420   +1523.01670151
         7 -226716.66029719   +1068.05016701
         8 -225873.25926041    +843.40103678
         9 -225229.13399244    +644.12526797
        10 -224665.51923323    +563.61475921
        11 -224158.67761260    +506.84162063
        12 -223730.77441943    +427.90319317
        13 -223349.46301713    +381.31140230
        14 -222998.10351340    +351.35950373
        15 -222668.55126373    +329.55224967
        16 -222373.03462188    +295.51664185
        17 -222105.49804304    +267.53657885
        18 -221857.97629388    +247.52174915
        19 -221600.42763444    +257.54865945
        20 -221317.54561963    +282.88201481
        21 -220995.66222580    +321.88339382
        22 -220628.36200092    +367.30022489
        23

  [RUN 3/15] logL = -216057.9


         1 -258056.10611525             +nan
         2 -242345.26633920  +15710.83977605
         3 -237804.15551727   +4541.11082193
         4 -234430.02439604   +3374.13112123
         5 -231538.58881892   +2891.43557713
         6 -229606.53136885   +1932.05745007
         7 -228365.04621185   +1241.48515700
         8 -227467.74639679    +897.29981506
         9 -226837.92863159    +629.81776519
        10 -226379.82762947    +458.10100212
        11 -226025.27584851    +354.55178096
        12 -225700.23919169    +325.03665682
        13 -225344.15639577    +356.08279592
        14 -224937.48796978    +406.66842598
        15 -224491.98271802    +445.50525177
        16 -224091.17238289    +400.81033513
        17 -223788.53567393    +302.63670896
        18 -223548.62295178    +239.91272215
        19 -223359.62217395    +189.00077783
        20 -223206.95843136    +152.66374260
        21 -223073.55461542    +133.40381593
        22 -222955.93682088    +117.61779454
        23

  [RUN 4/15] logL = -216783.4


         1 -257021.60665074             +nan
         2 -240166.87383939  +16854.73281136
         3 -234674.24321949   +5492.63061990
         4 -231332.19488426   +3342.04833523
         5 -229150.56755208   +2181.62733218
         6 -227653.13051740   +1497.43703468
         7 -226571.77786961   +1081.35264778
         8 -225694.64112172    +877.13674789
         9 -224842.09204539    +852.54907633
        10 -224121.78808331    +720.30396208
        11 -223543.02774424    +578.76033907
        12 -223024.56833769    +518.45940655
        13 -222568.15382411    +456.41451358
        14 -222174.35725577    +393.79656833
        15 -221813.44669508    +360.91056069
        16 -221490.17388813    +323.27280695
        17 -221209.90027056    +280.27361757
        18 -220963.85441059    +246.04585997
        19 -220759.12251219    +204.73189841
        20 -220583.44450465    +175.67800754
        21 -220392.30879930    +191.13570535
        22 -220194.80718743    +197.50161187
        23

  [RUN 5/15] logL = -216699.3


         1 -262592.76381075             +nan
         2 -242994.29443891  +19598.46937183
         3 -236265.63519045   +6728.65924847
         4 -231819.68720346   +4445.94798698
         5 -229342.37558371   +2477.31161975
         6 -227940.06442226   +1402.31116145
         7 -226964.03045998    +976.03396228
         8 -226231.90451751    +732.12594247
         9 -225645.41413511    +586.49038240
        10 -225123.37287507    +522.04126003
        11 -224690.81935577    +432.55351931
        12 -224332.85343410    +357.96592167
        13 -223981.22336204    +351.63007206
        14 -223603.47228980    +377.75107223
        15 -223208.72207015    +394.75021965
        16 -222831.12156630    +377.60050386
        17 -222540.00957985    +291.11198644
        18 -222290.35928520    +249.65029465
        19 -222080.94297223    +209.41631297
        20 -221927.74797978    +153.19499246
        21 -221814.01029719    +113.73768258
        22 -221713.49459690    +100.51570029
        23

  [RUN 6/15] logL = -217081.2


         1 -257954.26526660             +nan
         2 -240317.21536455  +17637.04990205
         3 -235393.98613500   +4923.22922955
         4 -232540.23556175   +2853.75057324
         5 -230598.51579798   +1941.71976377
         6 -229067.94508336   +1530.57071463
         7 -227699.67130212   +1368.27378124
