# HMM Model Selection 

Now that we have network-level, temporally trimmed time series for all of our subjects, we can fit the HMM! First we're going to try 10 states to see how that looks...
Random note: Before doing this I also confirmed that all the networks were ordered correctly across subjects.

In [2]:
import torch

# Check CUDA availability *before* querying the device name, so that a
# machine without a GPU fails with the explicit assertion message instead of
# an error raised from inside torch.cuda.get_device_name().
# NOTE(review): nothing in the visible cells uses `device`, and the model
# fitting below goes through hmmlearn -- confirm this GPU requirement is needed.
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
assert cuda_ok, "No GPU detected!"
print("Device:", torch.cuda.get_device_name(0))
device = torch.device("cuda")
print(f"Using GPU: {torch.cuda.get_device_name(0)}")

CUDA available: True
Device: Tesla T4


In [2]:
import os
import glob
import h5py
import numpy as np
import pickle
from sklearn.preprocessing import StandardScaler
from hmmlearn import hmm
import traceback
from tqdm import tqdm
from joblib import Parallel, delayed
from scipy.optimize import linear_sum_assignment
from scipy.stats import pearsonr

In [3]:
# === Settings ===
# Directory that load_data() globs for "*.h5" files.
INPUT_DIR       = "/home/jovyan/narratives-project/shirer_components"
# Name of the HDF5 dataset read from every file.
DATASET         = "timeseries"
# Numbers of hidden states (k) that main() loops over.
STATE_COUNTS    = [10]
# Independent HMM fits per k; fit_best_model() keeps the most representative one.
N_REPEATS       = 15
# Number of circular-shift null fits per k in run_permutations().
N_PERMUTATIONS  = 100
# Default iteration cap (n_iter) passed to the GaussianHMM in fit_hmm_prestore().
MAX_ITER        = 500
N_JOBS          = 4    # joblib workers for the permutation fits (a fixed count of 4, not "all cores")
# Pickle file that main() writes the list of per-k result dicts to.
PKL_OUT         = "/home/jovyan/narratives-project/hmm-objects/hmmlearn_consensus_results_k10perm.pkl"
RUN_PERMUTATIONS = True  # set to False to skip permutations

def load_data(input_dir):
    """
    Load one standardized array per ``*.h5`` file found in ``input_dir``.

    Files are visited in sorted path order. For each file the ``DATASET``
    entry is read in full and transposed, then passed through a freshly
    fitted ``StandardScaler`` (so every file is scaled on its own).

    Returns a list of arrays, one per file, in the same order as the files.
    The list is empty when no file matches.
    """
    h5_paths = sorted(glob.glob(os.path.join(input_dir, "*.h5")))
    subjects = []
    for h5_path in h5_paths:
        with h5py.File(h5_path, "r") as handle:
            raw = handle[DATASET][()]
        # presumably (timepoints, networks) after the transpose -- the run log
        # shows (279, 14) per subject; verify against the files.
        subjects.append(StandardScaler().fit_transform(raw.T))
    return subjects

def compute_num_parameters(n_states, n_features):
    """
    Count the free parameters used for the AIC/BIC penalty.

    The total is the sum of:
      * initial-state probabilities:  n_states - 1
      * transition probabilities:     n_states * (n_states - 1)
      * state means:                  n_states * n_features
      * state covariances:            n_states * n_features * (n_features + 1) / 2
        (one symmetric matrix per state)
    """
    cov_per_state = n_features * (n_features + 1) // 2
    terms = (
        n_states - 1,
        n_states * (n_states - 1),
        n_states * n_features,
        n_states * cov_per_state,
    )
    return sum(terms)

def fractional_occupancy(state_seqs, k):
    """
    Fraction of all timepoints spent in each of the ``k`` states.

    State visits are tallied over every sequence in ``state_seqs`` and divided
    by the combined length of all sequences, giving a length-``k`` array.
    A value near zero flags a state that is (almost) never used.
    """
    visits = np.zeros(k)
    n_timepoints = 0
    for path in state_seqs:
        # minlength keeps the tally k-long even if high-numbered states never occur
        np.add(visits, np.bincount(path, minlength=k), out=visits)
        n_timepoints = n_timepoints + len(path)
    return visits / n_timepoints

def circular_shift(ts, rng=None):
    """
    Roll a time series along axis 0 by a random, non-zero number of rows.
    Used to generate null data for the permutation testing.

    Parameters
    ----------
    ts : np.ndarray
        Array whose first axis is rolled; all columns move together.
    rng : np.random.Generator, optional
        Source of randomness. When omitted, NumPy's global random state is
        used (the original behavior), so ``np.random.seed`` still applies.

    Returns
    -------
    np.ndarray
        A new array of the same shape as ``ts``. When ``ts`` has fewer than
        two rows no non-zero shift exists, so an unshifted copy is returned.
    """
    n_rows = ts.shape[0]
    if n_rows < 2:
        # randint(1, n_rows) would raise ValueError (empty range) here.
        return ts.copy()
    if rng is None:
        shift = np.random.randint(1, n_rows)
    else:
        shift = int(rng.integers(1, n_rows))
    return np.roll(ts, shift, axis=0)

def fit_hmm_prestore(X, lengths, k, max_iter=MAX_ITER, random_state=None, verbose=True):
    """
    Fit one full-covariance Gaussian HMM on pre-stacked data and decode it.

    Parameters
    ----------
    X : np.ndarray
        All sequences stacked row-wise (main() builds it with np.vstack).
    lengths : sequence of int
        Number of rows of ``X`` belonging to each sequence, in order.
    k : int
        Number of hidden states.
    max_iter : int
        Upper bound on EM iterations (``n_iter`` of the model).
    random_state : int or RandomState, optional
        Forwarded to the model so a fit can be reproduced. The default
        ``None`` keeps the previous unseeded behavior.
    verbose : bool
        Forwarded to the model; ``True`` (the previous hard-coded value)
        prints one line per EM iteration.

    Returns
    -------
    (model, logL, paths)
        The fitted model, the value of ``model.score(X, lengths)``, and a
        list with one decoded state array per entry of ``lengths``.
    """
    model = hmm.GaussianHMM(
        n_components=k,
        covariance_type='full',
        n_iter=max_iter,
        verbose=verbose,
        random_state=random_state,
    )
    model.fit(X, lengths)
    logL = model.score(X, lengths)
    hidden = model.predict(X, lengths)
    # Cut the concatenated state sequence back into per-sequence paths.
    bounds = np.cumsum([0] + list(lengths))
    paths = [hidden[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]
    return model, logL, paths

def fit_best_model(X, lengths, k, n_repeats):
    """
    Fit the HMM ``n_repeats`` times and keep the most representative fit.

    Each repetition calls fit_hmm_prestore(); a repetition that raises is
    logged and skipped. The surviving fits are compared with
    align_and_score(), which matches states between runs and scores each run
    by its mean similarity to the others. The run with the highest score is
    returned -- selection is by representativeness, not by log-likelihood.

    Returns
    -------
    (model, logL, fo, paths)
        The chosen model, its log-likelihood, its fractional occupancy per
        state, and its per-sequence state paths. All four are ``None`` when
        every repetition failed.
    """
    all_models = []

    for run in range(1, n_repeats + 1):
        try:
            model, logL, paths = fit_hmm_prestore(X, lengths, k)
            all_models.append({
                "run": run,  # original repetition number, kept for logging
                "model": model,
                "logL": logL,
                "means": model.means_,
                "paths": paths
            })
            print(f"  [RUN {run}/{n_repeats}] logL = {logL:.1f}")
        except Exception as e:
            print(f"  [ERROR] run {run} k={k}: {e}")
            continue

    if not all_models:
        return None, None, None, None

    if len(all_models) == 1:
        # Nothing to align against; scoring would average an empty list.
        best_idx = 0
    else:
        # Align all mean matrices and compute similarity
        means_all = [m["means"] for m in all_models]
        scores = align_and_score(means_all)
        best_idx = int(np.argmax(scores))
    best = all_models[best_idx]

    fo = fractional_occupancy(best["paths"], k)
    # Report the repetition number, not the index into the list of successful
    # fits: the two differ whenever an earlier repetition failed.
    print(f"  [BEST] run={best['run']}/{n_repeats} k={k} logL={best['logL']:.1f}, FO min={fo.min():.3f}")
    return best["model"], best["logL"], fo, best["paths"]

def run_permutations(X, lengths, data, k, logL_real):
    """
    Build a null distribution of log-likelihoods from circularly shifted data.

    For each of ``N_PERMUTATIONS`` permutations, every array in ``data`` is
    passed through circular_shift(), the results are stacked, and one HMM is
    fit with fit_hmm_prestore(). Permutations run in ``N_JOBS`` joblib workers.

    Parameters
    ----------
    X : np.ndarray
        Unused here; kept so existing callers keep working.
    lengths : sequence of int
        Per-sequence lengths matching ``data``.
    data : list of np.ndarray
        The unshifted per-subject arrays.
    k : int
        Number of hidden states.
    logL_real : float
        Log-likelihood of the model fit on the real data.

    Returns
    -------
    (null_logLs, pval)
        The log-likelihoods of the permutations that fit successfully, and
        the fraction of them that are >= ``logL_real``. ``pval`` is NaN when
        no permutation succeeded.

    NOTE(review): each subject's array is rolled as a whole (all columns by
    the same offset) and subjects are fit as separate sequences, so the
    shifted data differ from the real data only at the wrap-around point.
    Confirm this is the intended null.
    """
    def one_perm(perm_idx):
        shifted = [circular_shift(ts) for ts in data]
        Xs = np.vstack(shifted)
        try:
            _, logL, _ = fit_hmm_prestore(Xs, lengths, k)
            return logL
        except Exception as e:
            # Catch Exception only: a bare except would also swallow
            # KeyboardInterrupt/SystemExit and make the run hard to stop.
            print(f"  [ERROR] permutation {perm_idx} k={k}: {e}")
            return None

    nulls = Parallel(n_jobs=N_JOBS)(
        delayed(one_perm)(i) for i in range(N_PERMUTATIONS)
    )
    null_logLs = [nl for nl in nulls if nl is not None]
    if not null_logLs:
        # Avoid averaging an empty list (NaN plus a RuntimeWarning).
        print(f"  [WARN] All permutations failed for k={k}")
        return null_logLs, float("nan")
    pval = np.mean([logL_real <= nl for nl in null_logLs])
    return null_logLs, pval

def _matched_similarity(ref_means, cmp_means):
    """Mean Pearson r over the optimal one-to-one pairing of rows of two matrices."""
    n_ref, n_cmp = ref_means.shape[0], cmp_means.shape[0]
    # linear_sum_assignment minimizes, so store negated correlations.
    neg_corr = np.zeros((n_ref, n_cmp))
    for a, b in np.ndindex(n_ref, n_cmp):
        r, _ = pearsonr(ref_means[a], cmp_means[b])
        neg_corr[a, b] = -r
    rows, cols = linear_sum_assignment(neg_corr)
    return np.mean(-neg_corr[rows, cols])


def align_and_score(all_means):
    """
    Score every run by how similar its state means are to all other runs.

    ``all_means`` is a list with one (states x features) mean matrix per run.
    For each pair of runs the states are matched one-to-one so that the summed
    correlation is maximal, and the mean correlation of the matched pairs is
    taken as the pair's similarity. A run's score is the average of its
    similarities to every other run.

    Returns an array with one score per run (NaN for a lone run, which has
    nothing to be compared with).
    """
    n_runs = len(all_means)
    scores = np.zeros(n_runs)
    for i, curr in enumerate(all_means):
        sims = [
            _matched_similarity(curr, other)
            for j, other in enumerate(all_means)
            if j != i
        ]
        scores[i] = np.mean(sims)
    return scores

def main():
    """
    Run the full model-selection pipeline and pickle the results.

    Steps: load and stack the per-subject data, then for every k in
    ``STATE_COUNTS`` pick a representative fit with fit_best_model(), compute
    AIC/BIC from its log-likelihood, optionally run the permutation test, and
    collect everything in a result dict. The list of dicts is written to
    ``PKL_OUT``.

    Raises
    ------
    ValueError
        If ``INPUT_DIR`` yields no data.
    """
    data = load_data(INPUT_DIR)
    if not data:
        # np.vstack([]) would raise an opaque ValueError; say what is wrong.
        raise ValueError(f"No .h5 files found in {INPUT_DIR}")

    # Make sure the output location exists *before* the long-running fits, so
    # a missing directory cannot discard the results at the very end.
    out_dir = os.path.dirname(PKL_OUT)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # pre-stack once
    X       = np.vstack(data)
    lengths = [d.shape[0] for d in data]

    print(f"[DEBUG] Loaded {len(data)} subjects")
    for i, d in enumerate(data):
        print(f"  Subject {i}: shape={d.shape}, mean={np.mean(d):.4f}, std={np.std(d):.4f}")

    results = []
    for k in STATE_COUNTS:
        print(f"Fitting HMM for k={k}...")
        model, logL, fo, paths = fit_best_model(X, lengths, k, N_REPEATS)
        if model is None:
            print(f"  [WARN] No model converged for k={k}")
            continue

        # model selection metrics; BIC uses the total number of stacked rows
        n_params = compute_num_parameters(k, data[0].shape[1])
        aic = 2 * n_params - 2 * logL
        bic = np.log(sum(lengths)) * n_params - 2 * logL
        print(f"  Done. logL={logL:.1f}, AIC={aic:.1f}, FO min={fo.min():.3f}")

        null_logLs, pval = [], None
        if RUN_PERMUTATIONS:
            print(f"  Running {N_PERMUTATIONS} permutations for k={k}...")
            null_logLs, pval = run_permutations(X, lengths, data, k, logL)
            print(f"  Null model p-value: {pval:.4f}")
        else:
            print(f"  [SKIP] Permutations skipped for k={k}")

        results.append({
            "k": k,
            "model": model,
            "logL": logL,
            "AIC": aic,
            "BIC": bic,
            "FO": fo,
            "subject_paths": paths,
            "null_logLs": null_logLs,
            "pval": pval
        })

    with open(PKL_OUT, "wb") as f:
        pickle.dump(results, f)
    print(f"Saved HMMlearn permutation results to {PKL_OUT}")

# Run the pipeline when this code is executed as the main module.
if __name__ == "__main__":
    main()


[DEBUG] Loaded 75 subjects
  Subject 0: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 1: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 2: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 3: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 4: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 5: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 6: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 7: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 8: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 9: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 10: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 11: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 12: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 13: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 14: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 15: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 16: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 17: shape=(279, 14), mean=-0.00

         1 -261773.45078789             +nan
         2 -243146.74264435  +18626.70814354
         3 -237303.65747338   +5843.08517097
         4 -234604.49431960   +2699.16315378
         5 -232967.14240323   +1637.35191637
         6 -231913.55922860   +1053.58317463
         7 -231257.16188108    +656.39734751
         8 -230785.06232089    +472.09956020
         9 -230388.50810904    +396.55421184
        10 -230014.01146869    +374.49664036
        11 -229698.91408528    +315.09738341
        12 -229444.37448840    +254.53959687
        13 -229234.42148868    +209.95299972
        14 -229050.44304858    +183.97844010
        15 -228903.42965319    +147.01339538
        16 -228793.74347897    +109.68617422
        17 -228703.66688843     +90.07659054
        18 -228621.55956273     +82.10732570
        19 -228540.56614218     +80.99342056
        20 -228458.54194027     +82.02420191
        21 -228370.00373741     +88.53820285
        22 -228275.38310135     +94.62063606
        23

  [RUN 1/15] logL = -221495.7


         1 -254021.61436221             +nan
         2 -243594.02459014  +10427.58977207
         3 -238911.65136839   +4682.37322175
         4 -235753.54347905   +3158.10788934
         5 -233799.70604685   +1953.83743220
         6 -232593.77978567   +1205.92626118
         7 -231777.74634063    +816.03344504
         8 -231042.44561464    +735.30072599
         9 -230221.94800631    +820.49760833
        10 -229441.71615256    +780.23185375
        11 -228751.00944854    +690.70670401
        12 -228120.80311608    +630.20633247
        13 -227630.15057449    +490.65254159
        14 -227313.19749876    +316.95307573
        15 -227131.63218734    +181.56531142
        16 -227009.28183505    +122.35035229
        17 -226904.24139988    +105.04043517
        18 -226797.42294580    +106.81845409
        19 -226669.68365552    +127.73929028
        20 -226520.70270411    +148.98095141
        21 -226349.13250522    +171.57019888
        22 -226180.49183031    +168.64067492
        23

KeyboardInterrupt: 