# HMM Model Selection 

Now that we have network-level signals and temporally trimmed data for all our subjects, we can fit the HMM! First we're going to try 10 states to see how that looks...
Random note: Before doing this I also confirmed that all the networks were ordered correctly across subjects.

In [2]:
import torch
# Sanity-check that a CUDA GPU is visible to PyTorch.
# The availability check runs *before* any device-specific call:
# torch.cuda.get_device_name(0) raises its own, less readable error on a
# CPU-only machine, which previously hid the "No GPU detected!" message.
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
if not cuda_ok:
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError("No GPU detected!")

# NOTE(review): nothing in the visible code below uses `device`; confirm that
# later cells actually need it.
device = torch.device("cuda")
gpu_name = torch.cuda.get_device_name(0)
print("Device:", gpu_name)
print(f"Using GPU: {gpu_name}")

CUDA available: True
Device: Tesla T4


In [2]:
import os
import glob
import h5py
import numpy as np
import pickle
from sklearn.preprocessing import StandardScaler
from hmmlearn import hmm
import traceback
from tqdm import tqdm
from joblib import Parallel, delayed
from scipy.optimize import linear_sum_assignment
from scipy.stats import pearsonr

In [None]:
# === Settings ===
# Directory that load_data() searches for per-subject "*.h5" files.
INPUT_DIR       = "/home/jovyan/narratives-project/shirer_components"
# Name of the HDF5 dataset read from each file.
DATASET         = "timeseries"
# Candidate numbers of hidden states; main() fits one consensus model per entry.
STATE_COUNTS    = [14]
# Number of independent HMM fits per state count (passed to fit_best_model).
N_REPEATS       = 15
# Number of circularly shifted null fits performed by run_permutations().
N_PERMUTATIONS  = 100
# Default iteration cap handed to the HMM fit (see fit_hmm_prestore).
MAX_ITER        = 500
N_JOBS          = 4    # joblib workers for the permutation fits (joblib uses -1 for "all cores")
# Pickle file that main() writes the list of per-k result dicts to.
PKL_OUT         = "/home/jovyan/narratives-project/hmm-objects/hmmlearn_consensus_results_k14.pkl"
RUN_PERMUTATIONS = False  # True runs the permutation test; False skips it

def load_data(input_dir):
    """
    Load and standardise one time series per ``*.h5`` file in ``input_dir``.

    Files are visited in sorted path order so the subject order is
    reproducible. For each file the ``DATASET`` array is read in full,
    transposed, and passed through a fresh ``StandardScaler`` (each column is
    scaled using that file's own statistics only).

    Returns a list with one 2-D array per file; the list is empty when no
    file matches.
    """
    # NOTE(review): the transpose presumably turns (networks, timepoints) into
    # (timepoints, networks) -- confirm against the files on disk.
    pattern = os.path.join(input_dir, "*.h5")
    subjects = []
    for h5_path in sorted(glob.glob(pattern)):
        with h5py.File(h5_path, "r") as handle:
            raw = handle[DATASET][()]
        subjects.append(StandardScaler().fit_transform(raw.T))
    return subjects

def compute_num_parameters(n_states, n_features):
    """
    Count the free parameters of a Gaussian HMM with full covariance matrices.

    The total is the sum of:
      * initial state probabilities: ``n_states - 1``
      * transition matrix:           ``n_states * (n_states - 1)``
      * state means:                 ``n_states * n_features``
      * state covariances:           ``n_states * n_features * (n_features + 1) / 2``
        (one symmetric matrix per state)

    Used by the AIC/BIC computation in ``main``.
    """
    # Number of unique entries in one symmetric (n_features x n_features) matrix.
    cov_per_state = n_features * (n_features + 1) // 2
    counts = (
        n_states - 1,
        n_states * (n_states - 1),
        n_states * n_features,
        n_states * cov_per_state,
    )
    return sum(counts)

def fractional_occupancy(state_seqs, k):
    """
    Return the fraction of all timepoints spent in each of the ``k`` states.

    Visits are counted per state over every sequence in ``state_seqs`` (one
    sequence of integer state labels per subject) and divided by the total
    number of timepoints, so the result is pooled over subjects rather than
    averaged per subject. A value of 0 marks a state that was never used.
    """
    seqs = list(state_seqs)
    # Start from a float zero vector so the pooled counts are floats.
    visits = sum((np.bincount(seq, minlength=k) for seq in seqs), np.zeros(k))
    n_timepoints = sum(len(seq) for seq in seqs)
    return visits / n_timepoints

def circular_shift(ts):
    """
    Return a copy of ``ts`` rotated along axis 0 by a random offset.

    The offset is drawn from NumPy's global random generator, uniformly from
    1 to ``ts.shape[0] - 1`` inclusive, so the result is never the unshifted
    series. Rows that fall off the end wrap around to the start. Used to
    build null data for the permutation test.
    """
    n_trs = ts.shape[0]
    offset = np.random.randint(low=1, high=n_trs)  # `high` is exclusive
    return np.roll(ts, offset, axis=0)

def fit_hmm_prestore(X, lengths, k, max_iter=MAX_ITER):
    """
    Fit one full-covariance Gaussian HMM with ``k`` states to pre-stacked data.

    Parameters
    ----------
    X : stacked observations of all subjects (rows are timepoints).
    lengths : number of rows of ``X`` belonging to each subject, in order.
    k : number of hidden states.
    max_iter : iteration cap passed to the model as ``n_iter``.

    Returns
    -------
    (model, logL, paths) : the fitted model, its score on ``X``, and the
    decoded state sequence cut back into one array per subject.

    The model is created with ``verbose=True``, so a per-iteration
    convergence log is printed while fitting.
    """
    gaussian_hmm = hmm.GaussianHMM(
        n_components=k,
        covariance_type='full',
        n_iter=max_iter,
        verbose=True,
    )
    gaussian_hmm.fit(X, lengths)
    log_likelihood = gaussian_hmm.score(X, lengths)
    decoded = gaussian_hmm.predict(X, lengths)

    # Cumulative boundaries: subject i owns decoded[bounds[i]:bounds[i + 1]].
    bounds = [0]
    for n_rows in lengths:
        bounds.append(bounds[-1] + n_rows)
    per_subject = [decoded[lo:hi] for lo, hi in zip(bounds, bounds[1:])]
    return gaussian_hmm, log_likelihood, per_subject

def fit_best_model(X, lengths, k, n_repeats):
    """
    Fit the HMM ``n_repeats`` times and keep the most representative run.

    Each run calls ``fit_hmm_prestore``; a run that raises is reported and
    skipped. The surviving runs are compared through their state-mean
    matrices with ``align_and_score`` (Hungarian state matching), and the run
    with the highest average similarity to all other runs is selected.
    Selection is therefore by consensus, not by highest log-likelihood.

    Parameters
    ----------
    X, lengths : stacked observations and per-subject row counts.
    k : number of hidden states.
    n_repeats : number of independent fits to attempt.

    Returns
    -------
    (model, logL, fo, paths) for the selected run, where ``fo`` is the
    fractional occupancy of each state and ``paths`` the per-subject state
    sequences. Returns ``(None, None, None, None)`` when every run failed.
    """
    all_models = []

    for run in range(1, n_repeats + 1):
        try:
            model, logL, paths = fit_hmm_prestore(X, lengths, k)
            all_models.append({
                "run": run,  # real run number, so the report survives failed runs
                "model": model,
                "logL": logL,
                "means": model.means_,
                "paths": paths
            })
            print(f"  [RUN {run}/{n_repeats}] logL = {logL:.1f}")
        except Exception as e:
            # Best effort: one failed restart must not abort the remaining ones.
            print(f"  [ERROR] run {run} k={k}: {e}")

    if not all_models:
        return None, None, None, None

    if len(all_models) == 1:
        # Nothing to align against; scoring a single run would only produce
        # a NaN score (mean of an empty list) and a RuntimeWarning.
        best_idx = 0
    else:
        # Align all mean matrices and compute similarity
        scores = align_and_score([m["means"] for m in all_models])
        best_idx = int(np.argmax(scores))
    best = all_models[best_idx]

    fo = fractional_occupancy(best["paths"], k)
    # Report the stored run number: the list index is off whenever an
    # earlier run failed and was skipped.
    print(f"  [BEST] run={best['run']}/{n_repeats} k={k} logL={best['logL']:.1f}, FO min={fo.min():.3f}")
    return best["model"], best["logL"], fo, best["paths"]

def run_permutations(X, lengths, data, k, logL_real):
    """
    Build a null distribution of log-likelihoods from circularly shifted data.

    For each of ``N_PERMUTATIONS`` iterations every subject's series in
    ``data`` is rotated by an independent random offset (``circular_shift``),
    the shifted series are stacked, and a single HMM with ``k`` states is fit
    via ``fit_hmm_prestore``. Iterations run in parallel on ``N_JOBS`` joblib
    workers.

    Parameters
    ----------
    X : unused; kept so existing callers keep working.
    lengths : per-subject row counts (shifting does not change them).
    data : list of per-subject arrays to shift.
    k : number of hidden states.
    logL_real : log-likelihood of the model fit to the real data.

    Returns
    -------
    (null_logLs, pval) : the log-likelihoods of the successful null fits and
    the fraction of them that are >= ``logL_real``. ``pval`` is NaN when no
    null fit succeeded.
    """
    # NOTE(review): each subject is rotated as one multivariate block, so its
    # internal temporal structure is preserved apart from the wrap-around
    # point -- confirm this is the intended null hypothesis.
    def one_perm(perm_idx):
        shifted = [circular_shift(ts) for ts in data]
        Xs = np.vstack(shifted)
        try:
            _, logL, _ = fit_hmm_prestore(Xs, lengths, k)
            return logL
        except Exception as e:
            # Report the failure instead of hiding it. Unlike a bare
            # `except:`, KeyboardInterrupt/SystemExit still propagate.
            print(f"  [ERROR] permutation {perm_idx} k={k}: {e}")
            return None

    nulls = Parallel(n_jobs=N_JOBS)(
        delayed(one_perm)(i) for i in range(N_PERMUTATIONS)
    )
    null_logLs = [nl for nl in nulls if nl is not None]
    if not null_logLs:
        # No usable null fits: the p-value is undefined. Return NaN
        # explicitly rather than through a mean-of-empty-list warning.
        print(f"  [WARN] All {N_PERMUTATIONS} permutations failed for k={k}")
        return null_logLs, float("nan")
    pval = np.mean([logL_real <= nl for nl in null_logLs])
    return null_logLs, pval

def align_and_score(all_means):
    """
    Score each run by how well its states match those of every other run.

    ``all_means`` holds one state-mean matrix per run (rows are states). For
    every ordered pair of runs, the Pearson correlation between each pair of
    state rows is computed. States are then matched one-to-one with the
    Hungarian method so that the summed correlation is maximal, and the mean
    matched correlation is the pair's similarity.

    Returns an array with one entry per run: its mean similarity to all other
    runs. With a single run there is nothing to compare against and the score
    is NaN.
    """
    def _pair_similarity(ref, other):
        # The solver minimises, so store negated correlations as costs.
        neg_corr = np.zeros((ref.shape[0], other.shape[0]))
        for a, ref_state in enumerate(ref):
            for b, other_state in enumerate(other):
                r, _ = pearsonr(ref_state, other_state)
                neg_corr[a, b] = -r
        rows, cols = linear_sum_assignment(neg_corr)
        return np.mean([-neg_corr[r, c] for r, c in zip(rows, cols)])

    n_runs = len(all_means)
    scores = np.zeros(n_runs)
    for i, ref in enumerate(all_means):
        sims = [
            _pair_similarity(ref, other)
            for j, other in enumerate(all_means)
            if j != i
        ]
        scores[i] = np.mean(sims)
    return scores

def main():
    """
    Fit a consensus HMM for every entry of ``STATE_COUNTS`` and pickle the results.

    Steps:
      1. Load and standardise all subjects from ``INPUT_DIR`` and stack them.
      2. For each ``k``: pick the most representative of ``N_REPEATS`` fits
         (``fit_best_model``), compute AIC and BIC, and optionally run the
         permutation test (``RUN_PERMUTATIONS``).
      3. Write the list of per-``k`` result dicts to ``PKL_OUT``.

    Raises
    ------
    ValueError
        If no input file is found in ``INPUT_DIR``.
    """
    # Create the output directory *before* the long-running fits, so a
    # missing folder cannot discard hours of computation at the very end.
    out_dir = os.path.dirname(PKL_OUT)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    data = load_data(INPUT_DIR)
    if not data:
        # np.vstack([]) would raise a ValueError anyway; give it a useful message.
        raise ValueError(f"No .h5 files found in {INPUT_DIR}")

    # pre-stack once
    X       = np.vstack(data)
    lengths = [d.shape[0] for d in data]

    print(f"[DEBUG] Loaded {len(data)} subjects")
    for i, d in enumerate(data):
        print(f"  Subject {i}: shape={d.shape}, mean={np.mean(d):.4f}, std={np.std(d):.4f}")

    results = []
    for k in STATE_COUNTS:
        print(f"Fitting HMM for k={k}...")
        model, logL, fo, paths = fit_best_model(X, lengths, k, N_REPEATS)
        if model is None:
            print(f"  [WARN] No model converged for k={k}")
            continue

        # model selection metrics
        # The feature count is taken from the first subject.
        n_params = compute_num_parameters(k, data[0].shape[1])
        aic = 2 * n_params - 2 * logL
        # BIC uses the total number of timepoints across subjects as sample size.
        bic = np.log(sum(lengths)) * n_params - 2 * logL
        print(f"  Done. logL={logL:.1f}, AIC={aic:.1f}, FO min={fo.min():.3f}")

        null_logLs, pval = [], None
        if RUN_PERMUTATIONS:
            print(f"  Running {N_PERMUTATIONS} permutations for k={k}...")
            null_logLs, pval = run_permutations(X, lengths, data, k, logL)
            print(f"  Null model p-value: {pval:.4f}")
        else:
            print(f"  [SKIP] Permutations skipped for k={k}")

        results.append({
            "k": k,
            "model": model,
            "logL": logL,
            "AIC": aic,
            "BIC": bic,
            "FO": fo,
            "subject_paths": paths,
            "null_logLs": null_logLs,
            "pval": pval
        })

    with open(PKL_OUT, "wb") as f:
        pickle.dump(results, f)
    print(f"Saved HMMlearn permutation results to {PKL_OUT}")

# Run the pipeline only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()


[DEBUG] Loaded 75 subjects
  Subject 0: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 1: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 2: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 3: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 4: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 5: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 6: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 7: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 8: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 9: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 10: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 11: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 12: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 13: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 14: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 15: shape=(279, 14), mean=0.0000, std=1.0000
  Subject 16: shape=(279, 14), mean=-0.0000, std=1.0000
  Subject 17: shape=(279, 14), mean=-0.00

         1 -258424.31463245             +nan
         2 -241722.64672110  +16701.66791136
         3 -236481.72258943   +5240.92413167
         4 -233337.51570176   +3144.20688768
         5 -231243.84153685   +2093.67416491
         6 -229660.60118937   +1583.24034748
         7 -228508.98463939   +1151.61654998
         8 -227783.03000312    +725.95463627
         9 -227275.41915151    +507.61085161
        10 -226899.59770916    +375.82144235
        11 -226576.62235497    +322.97535419
        12 -226274.01977699    +302.60257798
        13 -226023.59954190    +250.42023508
        14 -225788.35113547    +235.24840643
        15 -225521.22650837    +267.12462710
        16 -225205.77352828    +315.45298009
        17 -224855.93167235    +349.84185593
        18 -224519.32391545    +336.60775690
        19 -224241.77140358    +277.55251186
        20 -224009.96351335    +231.80789023
        21 -223820.56107387    +189.40243948
        22 -223673.10934801    +147.45172586
        23

  [RUN 1/15] logL = -216805.2


         1 -258180.57421610             +nan
         2 -242477.54841158  +15703.02580452
         3 -235912.55712108   +6564.99129050
         4 -232188.26818374   +3724.28893735
         5 -229713.32666250   +2474.94152123
         6 -228161.44570390   +1551.88095860
         7 -227224.65775936    +936.78794454
         8 -226608.92369754    +615.73406182
         9 -226159.99384094    +448.92985661
        10 -225811.65649595    +348.33734499
        11 -225498.11430669    +313.54218926
        12 -225155.70437468    +342.40993201
        13 -224793.25404653    +362.45032816
        14 -224437.17720326    +356.07684327
        15 -224165.29313617    +271.88406709
        16 -223906.12227327    +259.17086290
        17 -223593.98961384    +312.13265943
        18 -223254.29428309    +339.69533075
        19 -222958.62216419    +295.67211890
        20 -222718.34592403    +240.27624016
        21 -222521.08557029    +197.26035374
        22 -222364.03472181    +157.05084847
        23

  [RUN 2/15] logL = -218694.0


         1 -262708.90671612             +nan
         2 -246382.73540267  +16326.17131346
         3 -240287.30340465   +6095.43199801
         4 -237150.31828491   +3136.98511975
         5 -235134.91425724   +2015.40402767
         6 -233722.28312298   +1412.63113426
         7 -232616.13789421   +1106.14522877
         8 -231459.32604609   +1156.81184812
         9 -230463.99459075    +995.33145534
        10 -229732.71999595    +731.27459481
        11 -229056.96159352    +675.75840243
        12 -228474.09460817    +582.86698536
        13 -227944.32905291    +529.76555525
        14 -227502.42183658    +441.90721633
        15 -227163.72853141    +338.69330517
        16 -226869.14643216    +294.58209925
        17 -226506.17079366    +362.97563850
        18 -226035.76198530    +470.40880837
        19 -225567.67850453    +468.08348077
        20 -225231.61026087    +336.06824365
        21 -224977.54816117    +254.06209970
        22 -224758.86186795    +218.68629322
        23