In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import matplotlib.pyplot as plt
from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import (
    exponential_moving_standardize, preprocess, Preprocessor
)
from scipy.signal import welch
import jax.numpy as jnp
from jax import random

from src.idbn import ImprovedDynamicBayesianNetwork

In [2]:
# Load and preprocess the dataset
subject_id = 1
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])

# NOTE(review): after resampling to 50 Hz the Nyquist frequency is 25 Hz, so the
# 38 Hz low-pass edge lies above what the resampled signal can represent, and the
# beta band (13-30 Hz) in extract_features can only use Welch bins up to 25 Hz.
preprocessors = [
    Preprocessor('pick_types', eeg=True, meg=False, stim=False),
    Preprocessor(lambda data: np.multiply(data, 1e6)),  # MNE stores EEG in volts; scale to microvolts
    Preprocessor('filter', l_freq=4, h_freq=38),
    Preprocessor('resample', sfreq=50)  # Reduced from 100 to 50 Hz
]
preprocess(dataset, preprocessors)

# Extract data and per-sample labels.
# raw.annotations holds one entry per trial (onset, duration, description), not one
# per sample. Each annotation is therefore expanded over the samples it covers, and
# samples outside every annotation are labelled 'rest'. After concatenation, y[k]
# is the label of column k of X, which is what later per-window label lookups expect.
X, y = [], []
for base_ds in dataset.datasets:
    raw = base_ds.raw
    X.append(raw.get_data())

    # dtype=object so long descriptions are not truncated to the width of 'rest'.
    sample_labels = np.full(raw.n_times, 'rest', dtype=object)
    annotations = raw.annotations
    sfreq_raw = raw.info['sfreq']
    # Annotation onsets include the recording's first_time offset; subtract it to get
    # sample indices relative to the first column returned by raw.get_data().
    starts = np.round((annotations.onset - raw.first_time) * sfreq_raw).astype(int)
    # A zero-duration annotation still marks its onset sample.
    lengths = np.maximum(np.round(annotations.duration * sfreq_raw).astype(int), 1)
    for start, length, description in zip(starts, lengths, annotations.description):
        lo = max(int(start), 0)
        hi = min(int(start) + int(length), raw.n_times)
        sample_labels[lo:hi] = description
    y.append(sample_labels)

X = np.concatenate(X, axis=1)
y = np.concatenate(y)

# Function to extract features (spectral power in different frequency bands)
def extract_features(eeg_data, sfreq, window_size=2, bands=((4, 8), (8, 13), (13, 30))):
    """Mean Welch PSD per frequency band and channel over non-overlapping windows.

    Parameters
    ----------
    eeg_data : array-like, shape (n_channels, n_samples)
        Multichannel signal, one row per channel.
    sfreq : float
        Sampling frequency in Hz. Non-integer values are accepted; the window
        length in samples is rounded to the nearest integer.
    window_size : float, default 2
        Window length in seconds. Trailing samples that do not fill a whole
        window are dropped.
    bands : sequence of (low, high), default theta/alpha/beta
        Half-open frequency ranges ``[low, high)`` in Hz.

    Returns
    -------
    ndarray, shape (n_windows, n_channels * len(bands))
        Row ``i`` holds, channel by channel, the mean PSD in each band, i.e.
        column ``ch * len(bands) + b``. A band that contains no Welch frequency
        bin (for example one above ``sfreq / 2``) yields NaN.

    Raises
    ------
    ValueError
        If ``window_size * sfreq`` rounds to fewer than one sample.
    """
    eeg_data = np.asarray(eeg_data)
    n_channels, n_samples = eeg_data.shape
    n_bands = len(bands)

    # Integer sample count so slicing and nperseg work for float sfreq too.
    samples_per_window = int(round(window_size * sfreq))
    if samples_per_window < 1:
        raise ValueError("window_size * sfreq must correspond to at least one sample")

    n_windows = n_samples // samples_per_window
    if n_windows == 0:
        return np.zeros((0, n_channels * n_bands))

    # (n_channels, n_windows, samples_per_window): window i covers samples
    # [i * samples_per_window, (i + 1) * samples_per_window).
    windows = eeg_data[:, :n_windows * samples_per_window].reshape(
        n_channels, n_windows, samples_per_window
    )
    # One Welch call over the last axis replaces the per-window, per-channel loop.
    freqs, psd = welch(windows, fs=sfreq, nperseg=samples_per_window, axis=-1)

    # (n_channels, n_windows, n_bands)
    band_power = np.stack(
        [psd[..., (freqs >= low) & (freqs < high)].mean(axis=-1) for low, high in bands],
        axis=-1,
    )
    # Reorder to (n_windows, n_channels, n_bands) and flatten channel-major.
    return band_power.transpose(1, 0, 2).reshape(n_windows, n_channels * n_bands)

# Extract features. The sampling rate is read back from the preprocessed data so it
# always matches the 'resample' preprocessing step instead of being duplicated here.
sfreq = int(round(dataset.datasets[0].raw.info['sfreq']))
window_size = 2  # seconds per non-overlapping feature window
features = extract_features(X, sfreq=sfreq, window_size=window_size)

# Use only a subset of features and data points to keep fitting fast.
# Columns are channel-major with 3 band powers (theta, alpha, beta) per channel.
n_windows_used = 500
n_channels_used = 2
n_bands = 3
features = features[:n_windows_used, :n_channels_used * n_bands]

num_states = 2
max_parents = 1
dbn = ImprovedDynamicBayesianNetwork(num_features=features.shape[1], num_states=num_states, max_parents=max_parents)

# Convert numpy array to jax array
features_jax = jnp.array(features)

# Fit the model
# NOTE(review): rng_key is created but never passed to dbn.fit; confirm whether fit
# accepts a key, otherwise this key does not seed the fit.
rng_key = random.PRNGKey(0)
samples = dbn.fit(features_jax)

# Predict cognitive states
predicted_states = dbn.predict(features_jax)

# Evaluate model performance
def evaluate_model(true_labels, predicted_states, align_labels=True):
    """Fraction of samples whose predicted state agrees with the true label.

    Unsupervised state indices (0, 1, ...) carry no inherent meaning and usually
    have a different type from the class labels (e.g. strings). Comparing them
    directly therefore always scores 0. With ``align_labels=True`` each predicted
    state is first mapped to the true label that occurs most often among the
    samples assigned to it (cluster purity). Ties go to the label that sorts first.

    Parameters
    ----------
    true_labels : array-like, shape (n,)
        Ground-truth label per sample.
    predicted_states : array-like, shape (n,)
        Predicted state index per sample.
    align_labels : bool, default True
        If False, compare the two arrays element-wise without any mapping.

    Returns
    -------
    float
        Accuracy in [0, 1], or NaN when both inputs are empty.

    Raises
    ------
    ValueError
        If the two inputs have different shapes.
    """
    true_labels = np.asarray(true_labels)
    predicted_states = np.asarray(predicted_states)
    if true_labels.shape != predicted_states.shape:
        raise ValueError(
            f"shape mismatch: true_labels {true_labels.shape} vs predicted_states {predicted_states.shape}"
        )
    if true_labels.size == 0:
        return float('nan')
    if not align_labels:
        return float(np.mean(true_labels == predicted_states))

    mapped = np.empty(true_labels.shape, dtype=true_labels.dtype)
    for state in np.unique(predicted_states):
        mask = predicted_states == state
        labels, counts = np.unique(true_labels[mask], return_counts=True)
        mapped[mask] = labels[np.argmax(counts)]
    return float(np.mean(mapped == true_labels))

# Analyze state transitions
def analyze_transitions(states, n_states=None):
    """Empirical first-order transition matrix of a discrete state sequence.

    Parameters
    ----------
    states : sequence of int
        State index per time step. Values must lie in ``[0, n_states)``.
    n_states : int, optional
        Size of the square matrix. Defaults to the notebook-level ``num_states``.

    Returns
    -------
    ndarray, shape (n_states, n_states)
        Entry ``[i, j]`` is the fraction of steps leaving state ``i`` that go to
        state ``j``. A row is all zeros (not NaN) when state ``i`` has no
        outgoing transition: it was never visited, or only at the last step.
    """
    if n_states is None:
        n_states = num_states
    states = np.asarray(states, dtype=int).ravel()
    counts = np.zeros((n_states, n_states))
    if states.size > 1:
        # np.add.at accumulates repeated (from, to) pairs; a plain fancy-index
        # `+=` would count each distinct pair only once.
        np.add.at(counts, (states[:-1], states[1:]), 1)
    row_sums = counts.sum(axis=1, keepdims=True)
    # Explicitly skip 0/0 rows instead of producing NaN.
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)

# Analyze dwell times
def analyze_dwell_times(states, n_states=None):
    """Lengths of consecutive runs of each state in a discrete state sequence.

    Parameters
    ----------
    states : sequence of int
        State index per time step. Values must lie in ``[0, n_states)``.
    n_states : int, optional
        Number of states. Defaults to the notebook-level ``num_states``.

    Returns
    -------
    list of list of int
        ``result[s]`` lists, in order of occurrence, the length of every
        maximal run of state ``s``. A state that never occurs gets an empty
        list, and an empty ``states`` gives all-empty lists.
    """
    if n_states is None:
        n_states = num_states
    dwell_times = [[] for _ in range(n_states)]
    states = np.asarray(states).ravel()
    if states.size == 0:
        return dwell_times

    # Positions where the state differs from the previous step start a new run.
    change_points = np.flatnonzero(states[1:] != states[:-1]) + 1
    run_starts = np.concatenate(([0], change_points))
    run_ends = np.concatenate((change_points, [states.size]))
    for start, end in zip(run_starts, run_ends):
        dwell_times[int(states[start])].append(int(end - start))
    return dwell_times


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.3s


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.2s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.4s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)

NOTE: pick_types() is a legacy function. New code should use inst.pick(...).


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 38 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 38.00 Hz
- Upper transition bandwidth: 9.50 Hz (-6 dB cutoff frequency: 42.75 Hz)
- Filter length: 413 samples (1.652 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


In [None]:

# Assign each feature window the true label found at its first sample.
# Window i of extract_features starts at sample i * window_size * sfreq.
# NOTE(review): assumes `y` holds one label per sample (column of X); if it holds one
# entry per annotation/trial instead, the indices below run past its end.
samples_per_window = 2 * 50  # window_size (s) * sfreq (Hz) passed to extract_features
window_starts = np.arange(len(predicted_states)) * samples_per_window
if len(window_starts) and window_starts[-1] >= len(y):
    raise ValueError(
        f"y has {len(y)} entries but window {len(window_starts) - 1} starts at "
        f"sample {window_starts[-1]}; y must hold one label per sample of X."
    )
y_windows = np.asarray(y)[window_starts]
accuracy = evaluate_model(y_windows, predicted_states)
print(f"Model Accuracy: {accuracy:.2f}")

# Plot results as a self-contained figure. Under the inline backend each cell's
# figure is rendered and closed when the cell finishes, so a multi-panel layout
# cannot be continued from a later cell.
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(np.asarray(predicted_states))
ax.set_title("Predicted Cognitive States")
ax.set_xlabel("Time")
ax.set_ylabel("State")
plt.show()

In [None]:

# State posterior over time, drawn in its own figure. The inline backend has already
# rendered and closed the previous cell's figure, so plt.subplot(2, 1, 2) here would
# start a new figure whose top half stays empty.
state_probabilities = dbn.get_state_probabilities(features_jax)
fig, ax = plt.subplots(figsize=(15, 5))
# Transposed so rows are states and columns are time windows.
# NOTE(review): assumes get_state_probabilities returns (n_windows, n_states); confirm.
im = ax.imshow(state_probabilities.T, aspect='auto', cmap='hot', interpolation='nearest')
ax.set_title("State Probabilities Over Time")
ax.set_xlabel("Time")
ax.set_ylabel("State")
fig.colorbar(im, ax=ax)

fig.tight_layout()
plt.show()

In [None]:
# Empirical transition matrix estimated from the predicted state sequence.
transition_matrix = analyze_transitions(predicted_states)
fig, ax = plt.subplots(figsize=(8, 6))
# Entries are row-normalised counts and therefore lie in [0, 1]. Pin the colour scale
# so that autoscaling does not stretch small differences across the whole colormap.
im = ax.imshow(transition_matrix, cmap='coolwarm', vmin=0, vmax=1)
ax.set_xticks(range(transition_matrix.shape[1]))
ax.set_yticks(range(transition_matrix.shape[0]))
ax.set_title("State Transition Probabilities")
ax.set_xlabel("To State")
ax.set_ylabel("From State")
fig.colorbar(im, ax=ax)
plt.show()

In [None]:

# Dwell-time distribution per state (number of consecutive windows spent in a state).
dwell_times = analyze_dwell_times(predicted_states)
fig, ax = plt.subplots(figsize=(10, 6))
ax.boxplot(dwell_times)
# boxplot places box k at x = k + 1; label the ticks with the 0-based state indices
# used everywhere else (e.g. "State 0" in the summary below).
ax.set_xticks(range(1, len(dwell_times) + 1))
ax.set_xticklabels(range(len(dwell_times)))
ax.set_title("Dwell Times for Each State")
ax.set_xlabel("State")
ax.set_ylabel("Dwell Time (windows)")
plt.show()

# Print summary statistics
print("Summary Statistics:")
for i, dwells in enumerate(dwell_times):
    print(f"State {i}:")
    if not dwells:
        # A state that is never predicted has no dwell times; np.max/np.min would
        # raise ValueError on the empty list.
        print("  Never visited")
        print()
        continue
    print(f"  Mean dwell time: {np.mean(dwells):.2f}")
    print(f"  Median dwell time: {np.median(dwells):.2f}")
    print(f"  Max dwell time: {np.max(dwells)}")
    print(f"  Min dwell time: {np.min(dwells)}")
    print()

# Get edge probabilities
edge_probabilities = dbn.get_edge_probabilities()

# Print the top 10 strongest edges
sorted_edges = sorted(edge_probabilities.items(), key=lambda x: x[1], reverse=True)[:10]
print("Top 10 strongest edges:")
for edge, prob in sorted_edges:
    print(f"{edge}: {prob:.4f}")

# Visualize the model's transition probabilities.
# NOTE(review): assumes state-to-state edges are keyed like "State_<i>->State_<j>";
# this is inferred from the parsing below, so confirm against ImprovedDynamicBayesianNetwork.
transition_matrix = np.zeros((num_states, num_states))
for edge, prob in edge_probabilities.items():
    if not (isinstance(edge, str) and edge.startswith('State_') and '->' in edge):
        continue
    src, _, dst = edge.replace('State_', '').partition('->')
    # Skip edges whose endpoints are not both plain state indices (e.g. a state-to-
    # feature edge) or fall outside the matrix, instead of crashing on int()/indexing.
    if not (src.isdigit() and dst.isdigit()):
        continue
    i, j = int(src), int(dst)
    if i < num_states and j < num_states:
        transition_matrix[i, j] = prob

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(transition_matrix, cmap='coolwarm')
ax.set_xticks(range(num_states))
ax.set_yticks(range(num_states))
ax.set_title("State Transition Probabilities")
ax.set_xlabel("To State")
ax.set_ylabel("From State")
fig.colorbar(im, ax=ax)
plt.show()