# In [5]:
# Orthogonalization and Optimal State Number Computation

# Import necessary libraries
import os
import itertools
import numpy as np
import mne
from mne_connectivity import symmetric_orth
from hmmlearn import hmm
from scipy.signal import hilbert, resample, butter, lfilter
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import networkx as nx
import seaborn as sns
from scipy.optimize import fminbound
import time
import cupy as cp

# Ask OpenMP-backed libraries to use a single thread.
# NOTE(review): this assignment runs after numpy, scipy, hmmlearn, etc. have
# already been imported above; if a runtime only reads the variable when it is
# loaded, the setting may need to move before those imports — confirm.
os.environ['OMP_NUM_THREADS'] = '1'

# Define input and output directories
# main() joins each of these roots with "<subject>/<mode>" to build the
# per-run input and output directories.
files_in = '../data/in/subjects/'
files_out = '../data/out/subjects/'

def downsample_with_filtering(data, original_fs, target_fs):
    """Resample ``data`` along axis 2 from ``original_fs`` to ``target_fs``.

    When the target rate is lower than the original rate, a 4th-order
    Butterworth low-pass with its cutoff at the new Nyquist frequency is
    applied (causally, via ``lfilter``) before resampling. When the target
    rate is equal to or higher than the original rate there is nothing to
    anti-alias, so the filter is skipped; previously this case failed inside
    ``butter`` because the normalized cutoff was >= 1.

    Parameters
    ----------
    data : array_like
        Array with at least three dimensions; samples run along axis 2.
    original_fs : float
        Sampling rate of ``data``. Must be positive.
    target_fs : float
        Desired sampling rate. Must be positive.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``data`` except that axis 2 has
        ``int(data.shape[2] * target_fs / original_fs)`` samples.

    Raises
    ------
    ValueError
        If either sampling rate is not strictly positive.
    """
    if original_fs <= 0 or target_fs <= 0:
        raise ValueError(
            f"Sampling rates must be positive (got original_fs={original_fs}, "
            f"target_fs={target_fs})."
        )

    data = np.asarray(data)

    if target_fs < original_fs:
        # Cutoff at the new Nyquist frequency, expressed as a fraction of the
        # original Nyquist frequency as required by ``butter``.
        nyq_rate = original_fs / 2.0
        cutoff_freq = target_fs / 2.0
        normalized_cutoff = cutoff_freq / nyq_rate
        b, a = butter(4, normalized_cutoff, btype='low')
        filtered_data = lfilter(b, a, data, axis=2)
    else:
        filtered_data = data

    # Multiply before dividing so the sample count is derived from a single
    # rounded division rather than two chained floating-point roundings.
    new_num_samples = int(data.shape[2] * target_fs / original_fs)
    return resample(filtered_data, new_num_samples, axis=2)

def apply_orthogonalization(data):
    """Compute amplitude envelopes of ``data`` and orthogonalize them.

    The analytic signal is taken along axis 2 with ``hilbert``, its magnitude
    gives the amplitude envelope, and the envelope array is passed to
    ``symmetric_orth``. Before that, a rank check prints a warning when the
    envelopes are linearly dependent.

    NOTE(review): orthogonalization is applied to the envelopes rather than
    to the input signals themselves — confirm this ordering is intended.

    Parameters
    ----------
    data : numpy.ndarray
        Array with at least three dimensions; samples run along axis 2.

    Returns
    -------
    numpy.ndarray
        The output of ``symmetric_orth`` reshaped to the envelope's shape.
    """
    analytic_signal = hilbert(data, axis=2)
    amplitude_envelope = np.abs(analytic_signal)

    # One row per signal (all leading axes pooled), one column per sample.
    # NOTE(review): if the leading axis indexes epochs, a per-epoch rank
    # check may be more appropriate — confirm against the data layout.
    signals = amplitude_envelope.reshape(-1, amplitude_envelope.shape[-1])
    n_signals = signals.shape[0]

    # Only the triangular factor is needed to estimate the rank.
    r_factor = np.linalg.qr(signals.T, mode='r')
    rank = np.linalg.matrix_rank(r_factor)

    # Collinearity means fewer independent directions than signals. The
    # previous comparison against the number of time samples triggered the
    # warning whenever there were fewer signals than samples.
    if rank < n_signals:
        print("Warning: Signals appear to be collinear.")

    orthogonalized_data = symmetric_orth(amplitude_envelope)
    return orthogonalized_data.reshape(amplitude_envelope.shape)

def process_participant(subject, mode, dir_in, dir_out):
    """Downsample and orthogonalize one participant's label time courses.

    Looks for ``<subject>_label_time_courses.npy`` inside ``dir_out`` (the
    ``dir_in`` argument is accepted but not used here), resamples it from
    513 to 250 with ``downsample_with_filtering``, runs
    ``apply_orthogonalization`` on the result and saves it as ``orth.npy``
    in ``dir_out``.

    Returns
    -------
    The orthogonalized array, or ``None`` when the input file does not exist
    or any processing step raises (the error is printed, not propagated).
    """
    source_path = os.path.join(dir_out, f"{subject}_label_time_courses.npy")

    # Guard clause: nothing to do without the input file.
    if not os.path.exists(source_path):
        print(f"File not found: {source_path}")
        return None

    try:
        time_courses = np.load(source_path)
        print(f"Loaded data for {subject} in mode {mode}")

        result = apply_orthogonalization(
            downsample_with_filtering(time_courses, 513, 250)
        )

        target_path = os.path.join(dir_out, "orth.npy")
        np.save(target_path, result)
        print(f"File saved successfully for participant {subject}, mode {mode} at {target_path}")
    except Exception as e:
        # Best-effort: report the failure and let the caller move on.
        print(f"Error processing {subject} in {mode}: {e}")
        return None

    return result

def _read_criterion_records(path):
    """Return ``(n_states, aic, bic)`` tuples parsed from ``path``, skipping any other line."""
    records = []
    try:
        with open(path, "r") as f:
            for line in f:
                parts = line.split()
                # Blank lines and the "Optimal state ..." summary lines that
                # earlier runs appended do not have exactly three fields.
                if len(parts) != 3:
                    continue
                try:
                    records.append((int(parts[0]), float(parts[1]), float(parts[2])))
                except ValueError:
                    continue
    except FileNotFoundError:
        pass
    return records


def determine_optimal_states(orthogonalized_data, subject, mode, start_state=3):
    """Fit Gaussian HMMs for a range of state counts and pick one by AIC/BIC.

    The data are flattened to a single column, passed through PCA
    (``n_components=0.99``) and standardized. For every state count from
    ``start_state`` to 16 a ``GaussianHMM`` is fitted and its AIC/BIC are
    appended to ``aic_bic_<subject>_<mode>.txt`` (relative to the current
    working directory). All numeric records in that file, including those
    left by earlier runs, are then read back to choose the optimum, so a run
    can be resumed from a later ``start_state``.

    Parameters
    ----------
    orthogonalized_data : numpy.ndarray
        Data to model; it is reshaped with ``reshape(-1, 1)``.
    subject, mode : str
        Used in log messages and in the results file name.
    start_state : int, optional
        First number of hidden states to evaluate (default 3).

    Returns
    -------
    int
        Truncated mean of the state counts that minimize AIC and BIC.

    Raises
    ------
    ValueError
        If no AIC/BIC record is available (e.g. ``start_state`` > 16 and no
        results file from a previous run).
    """
    # NOTE(review): reshape(-1, 1) yields a single feature column, so PCA and
    # the HMM see one-dimensional observations — confirm this is intended.
    reshaped_data = orthogonalized_data.reshape(-1, 1)
    pca_data = PCA(n_components=0.99).fit_transform(reshaped_data)
    pca_data = StandardScaler().fit_transform(pca_data)
    n_samples, n_features = pca_data.shape

    results_path = f"aic_bic_{subject}_{mode}.txt"
    participant_start_time = time.time()

    for n_states in range(start_state, 17):
        state_start_time = time.time()
        print(f"Processing state: {n_states} | Subject: {subject} | Mode: {mode}")

        # hmmlearn is given the NumPy array directly; the former copy to the
        # GPU and straight back did not change the values being fitted.
        # NOTE(review): params='st' appears to restrict training updates to
        # the start and transition probabilities — verify against hmmlearn.
        model = hmm.GaussianHMM(n_components=n_states, n_iter=50, covariance_type='full', tol=1e-7, verbose=False, params='st', init_params='stmc')
        model.fit(pca_data)

        log_likelihood = model.score(pca_data)
        # NOTE(review): this parameter count does not obviously match the
        # free parameters of a full-covariance Gaussian HMM — confirm.
        n_params = n_states * (2 * n_features - 1)
        aic = 2 * n_params - 2 * log_likelihood
        bic = np.log(n_samples) * n_params - 2 * log_likelihood

        state_elapsed_time = (time.time() - state_start_time) / 60
        print(f"Time taken for state {n_states}: {state_elapsed_time:.2f} minutes")

        with open(results_path, "a") as f:
            f.write(f"{n_states}\t{aic}\t{bic}\n")

    # Find the optimal number of states from every numeric record on file.
    records = _read_criterion_records(results_path)
    if not records:
        raise ValueError(
            f"No AIC/BIC records available in {results_path}; "
            f"check that start_state ({start_state}) is at most 16."
        )

    state_numbers = [record[0] for record in records]
    aics = [record[1] for record in records]
    bics = [record[2] for record in records]

    optimal_state_aic = state_numbers[int(np.argmin(aics))]
    optimal_state_bic = state_numbers[int(np.argmin(bics))]
    optimal_states = int((optimal_state_aic + optimal_state_bic) / 2)

    # The file is closed after reading, so appending the summary is safe.
    with open(results_path, "a") as f:
        f.write(f"\nOptimal state (AIC): {optimal_state_aic}\n")
        f.write(f"Optimal state (BIC): {optimal_state_bic}\n")
        f.write(f"Optimal state (Average): {optimal_states}\n")

    print(f"Optimal number of states based on AIC/BIC: {optimal_states}")
    participant_elapsed_time = (time.time() - participant_start_time) / 60
    print(f"Total time taken for participant {subject} ({mode}): {participant_elapsed_time:.2f} minutes")

    return optimal_states

def main():
    """Interactively run orthogonalization and state selection per subject.

    Prompts for a starting participant, condition and state count, reads the
    participant list from ``./names.txt`` and processes every participant
    from the starting one onward. For each condition (``EC`` then ``EO``) it
    calls ``process_participant`` and, when data are returned,
    ``determine_optimal_states``.

    The starting condition and starting state count only apply to the point
    being resumed: choosing ``EO`` skips ``EC`` for the starting participant,
    and the requested state count is used for the first subject/condition
    processed, after which the search restarts from 3.

    Raises
    ------
    ValueError
        If the state count is not an integer or the starting participant is
        not listed in ``./names.txt``.
    """
    # Get user input for starting point
    start_subject = input("Enter the starting participant number (e.g., 401): ").strip()
    start_mode = input("Enter the starting condition (EC or EO): ").strip().upper()
    start_state = int(input("Enter the starting state number (3 to 16): "))

    # Load subject list, ignoring blank lines and stray whitespace such as a
    # trailing newline or carriage returns.
    with open("./names.txt", "r") as names:
        subject_list = [line.strip() for line in names if line.strip()]

    if start_subject not in subject_list:
        raise ValueError(f"Participant {start_subject!r} not found in ./names.txt")
    start_index = subject_list.index(start_subject)

    # Main processing loop
    for subject in subject_list[start_index:]:
        if subject == start_subject and start_mode == "EO":
            modes = ["EO"]
        else:
            modes = ["EC", "EO"]

        for mode in modes:
            dir_in = os.path.join(files_in, subject, mode)
            dir_out = os.path.join(files_out, subject, mode)

            # Ensure output directory exists
            os.makedirs(dir_out, exist_ok=True)

            # Process participant data
            orthogonalized_data = process_participant(subject, mode, dir_in, dir_out)

            if orthogonalized_data is not None:
                # Determine optimal number of states
                optimal_states = determine_optimal_states(orthogonalized_data, subject, mode, start_state)
                print(f"Optimal number of states for {subject} in {mode}: {optimal_states}")

            # The requested starting state belongs to the resumed
            # subject/condition only; every later run (including the other
            # condition of the same subject) must cover the full range.
            start_state = 3

# Start the interactive pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()