In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from hmmlearn.hmm import GaussianHMM
import matplotlib.pyplot as plt

In [None]:
# Load dataset
DATA_PATH = 'your_dataset.csv'  # Replace with your actual file
df = pd.read_csv(DATA_PATH)

# Example features (replace with your actual brain metric columns)
features = ['feature1', 'feature2', 'feature3']

# Define age bins (change if needed). Each (low, high) pair is inclusive on both
# ends; bins are expected in ascending order and must not overlap.
age_bins = [(5, 8), (9, 12), (13, 16)]

# Fail fast with a clear message if the CSV lacks a column used later on.
missing_cols = sorted({'age', *features} - set(df.columns))
if missing_cols:
    raise KeyError(f"Dataset is missing required columns: {missing_cols}")

# Bin assignment picks the first bin that contains an age, so unordered or
# overlapping bins would silently send ages to the wrong bin.
for bin_low, bin_high in age_bins:
    if bin_low > bin_high:
        raise ValueError(f"Invalid age bin ({bin_low}, {bin_high}): low must not exceed high")
for (_, prev_high), (next_low, _) in zip(age_bins, age_bins[1:]):
    if next_low <= prev_high:
        raise ValueError(f"Age bins must be ascending and non-overlapping: {age_bins}")

def assign_bin(age, bins=None):
    """Return the index of the first age bin that contains ``age``.

    Parameters
    ----------
    age : number
        Age of a single subject.
    bins : sequence of (low, high) pairs, optional
        Inclusive age ranges, tested in order. Defaults to the notebook-level
        ``age_bins``, looked up at call time so later edits to it are honoured.

    Returns
    -------
    int or None
        Index of the matching bin, or ``None`` when ``age`` lies outside every
        bin. This includes non-integer ages that fall in a gap between integer
        bounds (e.g. 8.5 with bins (5, 8) and (9, 12)) and NaN ages, since every
        comparison against NaN is False.
    """
    if bins is None:
        bins = age_bins
    for i, (low, high) in enumerate(bins):
        if low <= age <= high:
            return i
    return None

df['age_bin'] = df['age'].apply(assign_bin)

# Keep only subjects whose age falls inside one of the bins. Rows with missing
# feature values are dropped too: StandardScaler keeps NaNs through transform,
# and they would make the HMM fit fail later. .copy() makes the filtered frame
# an independent object before we assign into it.
df = df.dropna(subset=['age_bin'] + features).copy()
# assign_bin's None results leave the column as float; restore integer bin ids.
df['age_bin'] = df['age_bin'].astype(int)

# Normalize brain features (zero mean, unit variance per feature column)
scaler = StandardScaler()
df[features] = scaler.fit_transform(df[features])

In [None]:
def generate_pseudo_sequences(df, features, age_bins, num_sequences=1000, random_state=None):
    """Build pseudo-longitudinal sequences by sampling one subject per age bin.

    Each sequence has one step per entry of ``age_bins``. Step ``k`` is the
    ``features`` row of a subject drawn uniformly at random (with replacement)
    from the rows whose ``age_bin`` equals ``k``. Every step is an independent
    draw, so a sequence stitches together *different* subjects and carries no
    within-subject longitudinal information.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain an ``age_bin`` column plus every column in ``features``.
    features : list of str
        Columns copied into each sequence step, in this order.
    age_bins : sequence
        Only its length is used: bins ``0 .. len(age_bins) - 1`` are visited.
    num_sequences : int, default 1000
        Number of sequences to generate.
    random_state : None, int or numpy.random.Generator, optional
        Seed or generator for the sampling, passed to
        ``numpy.random.default_rng``. Use an int for reproducible output.

    Returns
    -------
    list of list of numpy.ndarray
        ``num_sequences`` sequences. Each is a list of 1-D feature arrays, one per
        bin. An empty list is returned if any bin has no rows, because no
        complete sequence can then be built.
    """
    # Filter each bin once up front rather than once per sequence per bin.
    bin_values = [
        df.loc[df['age_bin'] == bin_id, features].to_numpy()
        for bin_id in range(len(age_bins))
    ]
    if any(len(values) == 0 for values in bin_values):
        return []

    rng = np.random.default_rng(random_state)
    sequences = []
    for _ in range(num_sequences):
        sequences.append([values[rng.integers(len(values))] for values in bin_values])
    return sequences

# Create sequences
sequences = generate_pseudo_sequences(df, features, age_bins, num_sequences=1000)

# generate_pseudo_sequences yields no sequences when any age bin has no
# subjects; np.vstack([]) would then fail with an opaque
# "need at least one array to concatenate" error, so fail with a useful one.
if not sequences:
    raise ValueError(
        "No pseudo-sequences were generated: at least one age bin contains no "
        "subjects. Check `age_bins` against the ages present in the data."
    )

# Stack data for HMM: model.fit below takes every observation stacked row-wise
# plus the length of each consecutive sequence (one step per age bin).
X = np.vstack(sequences)
lengths = [len(age_bins)] * len(sequences)

In [None]:
# Fit a 3-state Gaussian HMM with diagonal covariances on the stacked
# pseudo-sequences. State indices are unsupervised, arbitrary labels; any
# interpretation such as Mild / Moderate / Severe has to be assigned afterwards
# by inspecting what each state learned (e.g. its mean feature values).
model = GaussianHMM(
    n_components=3,
    covariance_type='diag',
    n_iter=100,
    random_state=42,
)
model.fit(X, lengths)

In [None]:
# Decode the first few pseudo-sequences with model.decode and print the
# predicted hidden-state path across the age bins (numbered from 1).
for seq_no, seq in enumerate(sequences[:5], start=1):  # Show first 5 sequences
    _, state_path = model.decode(np.array(seq))  # log-probability is not needed here
    print(f"Sequence {seq_no} → Predicted states: {state_path}")

In [None]:
# Plot the learned mean of every feature for each hidden state. Features were
# standardized before fitting, so values are in standard-deviation units.
fig, ax = plt.subplots()
for state_idx, state_mean in enumerate(model.means_):
    ax.plot(state_mean, marker='o', label=f'State {state_idx}')
# Columns of model.means_ follow the order of `features`, because each sequence
# step was taken from df[features]; label the ticks so the figure stands alone.
ax.set_xticks(range(len(features)))
ax.set_xticklabels(features)
ax.set_title("Mean Feature Values by Hidden State")
ax.set_xlabel("Feature")
ax.set_ylabel("Standardized Value")
ax.legend()
ax.grid(True)
plt.show()