# Cross-validated logL from HMM object

In [4]:
import sys
import re
from contextlib import redirect_stderr
from io import StringIO
from tqdm import tqdm
import os
import glob
import h5py
import numpy as np
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from hmmlearn import hmm

In [5]:
# ─── Settings ────────────────────────────────────────────────────────────────────
INPUT_DIR     = "/home/jovyan/narratives-project/shirer_components"
DATASET       = "timeseries"
STATE_COUNTS  = [10, 14]
N_SPLITS      = 5
MAX_ITER      = 500
RANDOM_SEED   = 42
# ─────────────────────────────────────────────────────────────────────────────────

def load_data(input_dir):
    data = []
    for path in sorted(glob.glob(os.path.join(input_dir, "*.h5"))):
        with h5py.File(path, "r") as f:
            ts = f[DATASET][()].T
        ts = StandardScaler().fit_transform(ts)
        data.append(ts)
    return data

def compute_cv_loglik(data, k, n_splits=N_SPLITS):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_SEED)
    scores = []

    for fold, (train_idx, test_idx) in enumerate(kf.split(data), 1):
        train = [data[i] for i in train_idx]
        test  = [data[i] for i in test_idx]

        X_train = np.vstack(train)
        lengths_train = [d.shape[0] for d in train]

        print(f"\n  [Fold {fold}] Fitting HMM (k={k}) with verbose EM output:")
        model = hmm.GaussianHMM(n_components=k, covariance_type='full', n_iter=MAX_ITER, verbose=True)
        model.fit(X_train, lengths_train)

        X_test = np.vstack(test)
        lengths_test = [d.shape[0] for d in test]
        try:
            logL = model.score(X_test, lengths_test)
            scores.append(logL)
            print(f"  [Fold {fold}] logL = {logL:.1f}")
        except Exception as e:
            print(f"  [ERROR] Fold {fold} for k={k}: {e}")

    return np.mean(scores), np.std(scores)

def main():
    data = load_data(INPUT_DIR)
    print(f"✔ Loaded {len(data)} subjects")

    cv_results = []

    for k in STATE_COUNTS:
        print(f"\n=== Running CV for k={k} ===")
        mean_cvll, std_cvll = compute_cv_loglik(data, k)
        print(f"→ k={k}: Mean CVLL = {mean_cvll:.2f}, SD = {std_cvll:.2f}")
        cv_results.append({
            "k": k,
            "mean_cvll": mean_cvll,
            "std_cvll": std_cvll
        })

    # Sort and report best
    cv_results.sort(key=lambda x: x["mean_cvll"], reverse=True)
    print("\n=== Summary ===")
    for res in cv_results:
        print(f"k={res['k']:>2} | CVLL = {res['mean_cvll']:.2f} ± {res['std_cvll']:.2f}")

    best = cv_results[0]
    print(f"\n✅ Best model: k={best['k']} (CVLL = {best['mean_cvll']:.2f})")

if __name__ == "__main__":
    main()

✔ Loaded 75 subjects

=== Running CV for k=10 ===

  [Fold 1] Fitting HMM (k=10) with verbose EM output:


         1 -203903.04903179             +nan
         2 -194940.36320081   +8962.68583098
         3 -190116.54762887   +4823.81557195
         4 -187293.95693624   +2822.59069263
         5 -185731.32527972   +1562.63165652
         6 -184567.39696463   +1163.92831509
         7 -183652.10559544    +915.29136919
         8 -182922.11534528    +729.99025016
         9 -182332.61708124    +589.49826404
        10 -181857.70104805    +474.91603318
        11 -181465.45469748    +392.24635057
        12 -181116.63077082    +348.82392666
        13 -180774.02560360    +342.60516721
        14 -180524.28207249    +249.74353112
        15 -180343.33601562    +180.94605687
        16 -180224.62118144    +118.71483418
        17 -180138.15241336     +86.46876808
        18 -180073.51962185     +64.63279151
        19 -180024.18931378     +49.33030807
        20 -179983.77110268     +40.41821110
        21 -179941.11893609     +42.65216658
        22 -179885.44363742     +55.67529868
        23

  [Fold 1] logL = -48354.1

  [Fold 2] Fitting HMM (k=10) with verbose EM output:


         1 -205404.79379304             +nan
         2 -195049.69507165  +10355.09872139
         3 -190726.07149757   +4323.62357408
         4 -187840.13229471   +2885.93920287
         5 -186056.48308546   +1783.64920925
         6 -184754.04146352   +1302.44162193
         7 -183829.12390028    +924.91756325
         8 -183141.63998486    +687.48391542
         9 -182507.34970553    +634.29027933
        10 -181797.85038373    +709.49932181
        11 -180946.51192467    +851.33845906
        12 -180275.54997050    +670.96195417
        13 -179861.99444603    +413.55552447
        14 -179592.47885875    +269.51558728
        15 -179395.29578606    +197.18307269
        16 -179233.11876891    +162.17701715
        17 -179097.68642585    +135.43234305
        18 -178979.08287009    +118.60355576
        19 -178885.06257243     +94.02029766
        20 -178817.80608180     +67.25649063
        21 -178761.76941288     +56.03666892
        22 -178703.11601679     +58.65339610
        23

  [Fold 2] logL = -46650.9

  [Fold 3] Fitting HMM (k=10) with verbose EM output:


         1 -203218.24204011             +nan
         2 -191275.55045606  +11942.69158405
         3 -186508.21484382   +4767.33561225
         4 -183869.34482633   +2638.87001749
         5 -182294.73896790   +1574.60585843
         6 -181441.36078243    +853.37818547
         7 -180893.49749748    +547.86328495
         8 -180483.66495819    +409.83253929
         9 -180035.00625438    +448.65870381
        10 -179566.75915026    +468.24710412
        11 -179157.42755231    +409.33159795
        12 -178819.67033163    +337.75722068
        13 -178540.88339946    +278.78693217
        14 -178327.79623978    +213.08715968
        15 -178161.69806411    +166.09817567
        16 -178042.64933540    +119.04872871
        17 -177941.48136721    +101.16796819
        18 -177825.12855159    +116.35281562
        19 -177652.92514626    +172.20340533
        20 -177434.60129555    +218.32385071
        21 -177248.83620593    +185.76508962
        22 -177124.95311045    +123.88309548
        23

  [Fold 3] logL = -48225.8

  [Fold 4] Fitting HMM (k=10) with verbose EM output:


         1 -210376.69899800             +nan
         2 -200268.91544993  +10107.78354807
         3 -196964.87596948   +3304.03948045
         4 -194457.01201954   +2507.86394994
         5 -192803.97103545   +1653.04098409
         6 -191319.90470958   +1484.06632587
         7 -189970.14780004   +1349.75690953
         8 -188961.44752566   +1008.70027439
         9 -188181.89915650    +779.54836915
        10 -187611.35315037    +570.54600613
        11 -187167.12345471    +444.22969566
        12 -186785.08213557    +382.04131914
        13 -186412.11117478    +372.97096079
        14 -186048.65904640    +363.45212838
        15 -185729.08622314    +319.57282327
        16 -185482.86182325    +246.22439988
        17 -185296.37078225    +186.49104100
        18 -185171.30890339    +125.06187886
        19 -185089.58148556     +81.72741782
        20 -185026.23393005     +63.34755552
        21 -184956.15367458     +70.08025546
        22 -184879.78411739     +76.36955720
        23

  [Fold 4] logL = -42704.3

  [Fold 5] Fitting HMM (k=10) with verbose EM output:


         1 -207703.60409766             +nan
         2 -199012.92675082   +8690.67734684
         3 -192977.57944380   +6035.34730703
         4 -189322.49475588   +3655.08468791
         5 -187166.57429067   +2155.92046521
         6 -185779.36327843   +1387.21101224
         7 -184771.98811635   +1007.37516208
         8 -183882.48231451    +889.50580184
         9 -183082.31472967    +800.16758484
        10 -182585.41969224    +496.89503743
        11 -182262.81491757    +322.60477467
        12 -182033.82896318    +228.98595439
        13 -181837.13376942    +196.69519376
        14 -181614.86156646    +222.27220297
        15 -181370.41464091    +244.44692554
        16 -181174.30961852    +196.10502240
        17 -181022.24373872    +152.06587980
        18 -180902.05330459    +120.19043414
        19 -180792.41003420    +109.64327039
        20 -180690.95694040    +101.45309380
        21 -180588.09686839    +102.86007201
        22 -180487.41294731    +100.68392108
        23

  [Fold 5] logL = -44792.0
→ k=10: Mean CVLL = -46145.43, SD = 2151.04

=== Running CV for k=14 ===

  [Fold 1] Fitting HMM (k=14) with verbose EM output:


         1 -199828.49440263             +nan
         2 -187841.81296296  +11986.68143967
         3 -184359.21527390   +3482.59768906
         4 -182550.35245395   +1808.86281995
         5 -181104.80506710   +1445.54738685
         6 -179819.88161520   +1284.92345190
         7 -178920.57042009    +899.31119510
         8 -178215.68021112    +704.89020898
         9 -177545.27928086    +670.40093026
        10 -176946.70280767    +598.57647319
        11 -176451.25483358    +495.44797409
        12 -176090.72335070    +360.53148287
        13 -175818.37738668    +272.34596402
        14 -175566.64244895    +251.73493774
        15 -175337.15137654    +229.49107240
        16 -175099.14130554    +238.01007101
        17 -174808.52381868    +290.61748686
        18 -174495.11498852    +313.40883016
        19 -174237.54186725    +257.57312126
        20 -174033.50721976    +204.03464749
        21 -173871.81320964    +161.69401012
        22 -173737.06959357    +134.74361608
        23

  [Fold 1] logL = -48080.1

  [Fold 2] Fitting HMM (k=14) with verbose EM output:


         1 -208361.29419057             +nan
         2 -192641.82790711  +15719.46628345
         3 -187937.18875906   +4704.63914805
         4 -185585.97751282   +2351.21124624
         5 -184071.10531417   +1514.87219866
         6 -183028.03726656   +1043.06804760
         7 -182362.66981122    +665.36745534
         8 -181894.36615398    +468.30365724
         9 -181539.37260689    +354.99354709
        10 -181270.21435458    +269.15825231
        11 -181024.20605348    +246.00830110
        12 -180739.27620276    +284.92985071
        13 -180310.61287495    +428.66332781
        14 -179802.29433100    +508.31854395
        15 -179385.34712222    +416.94720878
        16 -179084.50572100    +300.84140121
        17 -178884.21638080    +200.28934020
        18 -178758.28943478    +125.92694602
        19 -178673.92846461     +84.36097016
        20 -178565.30188791    +108.62657670
        21 -178483.96431998     +81.33756793
        22 -178429.68652633     +54.27779365
        23

  [Fold 2] logL = -46353.4

  [Fold 3] Fitting HMM (k=14) with verbose EM output:


         1 -206365.60382580             +nan
         2 -192596.65312003  +13768.95070577
         3 -186927.23410214   +5669.41901789
         4 -183943.56033046   +2983.67377168
         5 -181984.52992170   +1959.03040876
         6 -180620.35387499   +1364.17604671
         7 -179737.86949904    +882.48437595
         8 -179246.61563506    +491.25386398
         9 -178918.78695619    +327.82867887
        10 -178586.60767268    +332.17928351
        11 -178235.16608335    +351.44158933
        12 -177926.71569806    +308.45038529
        13 -177653.66432184    +273.05137621
        14 -177326.88186512    +326.78245672
        15 -176864.26295001    +462.61891511
        16 -176282.50651656    +581.75643344
        17 -175767.75235327    +514.75416329
        18 -175425.45496160    +342.29739166
        19 -175216.17551245    +209.27944916
        20 -175078.83838286    +137.33712958
        21 -174967.56848743    +111.26989543
        22 -174879.67062830     +87.89785913
        23

  [Fold 3] logL = -47925.4

  [Fold 4] Fitting HMM (k=14) with verbose EM output:


         1 -212275.57676071             +nan
         2 -197777.44202843  +14498.13473228
         3 -192996.14775964   +4781.29426879
         4 -190575.03764117   +2421.11011847
         5 -188985.02024888   +1590.01739228
         6 -187833.32657498   +1151.69367391
         7 -186839.15419612    +994.17237886
         8 -185884.06349210    +955.09070402
         9 -184973.10575981    +910.95773229
        10 -184218.68486610    +754.42089371
        11 -183606.06457612    +612.62028998
        12 -183102.51748830    +503.54708782
        13 -182699.52584378    +402.99164452
        14 -182363.20369394    +336.32214984
        15 -182097.54823760    +265.65545634
        16 -181851.68670829    +245.86152931
        17 -181662.02321539    +189.66349291
        18 -181500.49706312    +161.52615227
        19 -181353.47991774    +147.01714537
        20 -181219.95857815    +133.52133960
        21 -181111.55673551    +108.40184263
        22 -181029.23084080     +82.32589471
        23

  [Fold 4] logL = -41755.9

  [Fold 5] Fitting HMM (k=14) with verbose EM output:


         1 -212340.76372195             +nan
         2 -196609.34342136  +15731.42030059
         3 -193367.52390911   +3241.81951224
         4 -190482.76877348   +2884.75513563
         5 -188227.20496554   +2255.56380794
         6 -186497.08507363   +1730.11989191
         7 -185251.97619537   +1245.10887825
         8 -184215.90604538   +1036.07014999
         9 -183411.67763110    +804.22841429
        10 -182733.28982577    +678.38780533
        11 -182166.98691388    +566.30291189
        12 -181703.49106301    +463.49585088
        13 -181334.97988370    +368.51117930
        14 -180982.34251623    +352.63736747
        15 -180740.09698271    +242.24553353
        16 -180543.71900628    +196.37797643
        17 -180392.71161911    +151.00738717
        18 -180270.96943499    +121.74218412
        19 -180171.76694529     +99.20248970
        20 -180069.53989890    +102.22704638
        21 -179970.74085390     +98.79904500
        22 -179877.63391003     +93.10694387
        23

  [Fold 5] logL = -44901.7
→ k=14: Mean CVLL = -45803.31, SD = 2331.69

=== Summary ===
k=14 | CVLL = -45803.31 ± 2331.69
k=10 | CVLL = -46145.43 ± 2151.04

✅ Best model: k=14 (CVLL = -45803.31)


       309 -176243.59397693      +0.00967549
