# In [None]:
# Feature extraction and model evaluation for BIS prediction using DE_SQ (all bands), spike detection (all bands), and PE in delta band with bandpass filtering (60s ahead)

import numpy as np
import pandas as pd
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, t, entropy
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from scipy.signal import butter, filtfilt

# ----- Band Definitions -----
# EEG frequency bands as (low edge, high edge) pairs, expressed in the same
# frequency units as the sampling rate handed to bandpass_filter.
# Insertion order matters: band_features walks this mapping to lay out its
# per-band feature columns.
bands = dict(
    delta=(0.5, 4),
    theta=(4, 8),
    alpha=(8, 13),
    beta=(13, 30),
    gamma=(30, 45),
)


# ----- Bandpass Filter Function -----
def bandpass_filter(signal, fs, lowcut, highcut, order=4):
    """Band-pass ``signal`` with a Butterworth filter run forwards and backwards.

    Args:
        signal: Samples to filter.
        fs: Sampling rate of ``signal``.
        lowcut: Lower band edge, in the same units as ``fs``.
        highcut: Upper band edge, in the same units as ``fs``.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filtered samples as returned by ``scipy.signal.filtfilt``.
    """
    nyquist = 0.5 * fs
    # butter() expects the band edges as fractions of the Nyquist frequency.
    normalised_edges = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, normalised_edges, btype='band')
    return filtfilt(numerator, denominator, signal)


# ----- Permutation Entropy Function -----
def permutation_entropy(time_series, order=3, delay=1):
    """Return the unnormalised permutation entropy of a 1-D series, in bits.

    The series is cut into overlapping embedding vectors of ``order`` samples
    spaced ``delay`` apart. Each vector is reduced to its ordinal pattern (the
    argsort of its values) and the base-2 Shannon entropy of the pattern
    frequencies is returned.

    Args:
        time_series: 1-D sequence of samples.
        order: Number of samples per embedding vector (>= 1).
        delay: Spacing, in samples, between consecutive vector elements (>= 1).

    Returns:
        The entropy as a float; 0.0 when the series is too short to form a
        single embedding vector.

    Raises:
        ValueError: If ``order`` or ``delay`` is smaller than 1.
    """
    if order < 1 or delay < 1:
        raise ValueError("order and delay must both be >= 1")

    series = np.asarray(time_series)
    # One embedding vector spans (order - 1) * delay + 1 samples, so that is
    # the minimum usable length (not order * delay, which rejects valid
    # series whenever delay > 1).
    span = (order - 1) * delay
    n_vectors = len(series) - span
    if n_vectors < 1:
        return 0.0

    # Row i holds series[i], series[i + delay], ..., series[i + span].
    offsets = np.arange(order) * delay
    embedded = series[np.arange(n_vectors)[:, None] + offsets]
    ordinal_patterns = np.argsort(embedded, axis=1)
    _, counts = np.unique(ordinal_patterns, axis=0, return_counts=True)
    probs = counts / counts.sum()
    return float(-np.sum(probs * np.log2(probs)))


# ----- Spike Detection Function -----
def spike_detection(signal, threshold=3):
    """Count samples that deviate strongly from the signal mean.

    A sample counts as a spike when its absolute distance from the mean is
    strictly greater than ``threshold`` times the signal's standard deviation.

    Args:
        signal: Samples to inspect.
        threshold: Multiplier applied to the standard deviation.

    Returns:
        Number of spike samples as an int.
    """
    deviation = np.abs(signal - np.mean(signal))
    cutoff = threshold * np.std(signal)
    return int(np.count_nonzero(deviation > cutoff))


# ----- Feature Functions -----
def band_features(signal, fs):
    """Build the feature vector for one EEG window.

    The window is band-pass filtered once per entry of the module-level
    ``bands`` mapping (in its iteration order). For every band two values are
    computed:

    * the square of ``0.5 * ln(2 * pi * e * var)``, i.e. the squared
      differential entropy of a Gaussian with the band's variance;
    * the spike count from ``spike_detection`` with its default threshold.

    The permutation entropy of the delta-band signal is appended last.

    Args:
        signal: Samples of one window.
        fs: Sampling rate of ``signal``.

    Returns:
        A flat list: squared entropies for all bands, then spike counts for
        all bands, then the delta-band permutation entropy.

    Raises:
        KeyError: If ``bands`` has no ``'delta'`` entry.
    """
    # Filter each band exactly once; the delta result is reused for the
    # permutation entropy instead of filtering the window a second time.
    filtered_by_band = {
        name: bandpass_filter(signal, fs, lowcut, highcut)
        for name, (lowcut, highcut) in bands.items()
    }

    de_sq_list = []
    spike_list = []
    for filtered in filtered_by_band.values():
        variance = np.var(filtered)
        if variance <= 0:
            # Keep the logarithm finite for a flat (zero-variance) band.
            variance = 1e-10
        de = 0.5 * np.log(2 * np.pi * np.e * variance)
        de_sq_list.append(de ** 2)
        spike_list.append(spike_detection(filtered))

    # Only PE in delta band
    pe_delta = permutation_entropy(filtered_by_band['delta'])

    return de_sq_list + spike_list + [pe_delta]


# ----- Feature Extraction -----
def extract_features(eeg_values, fs, window_size_seconds, step_size_seconds):
    """Slide a window over the EEG and compute ``band_features`` for each one.

    Windows are evaluated in parallel through joblib (``n_jobs=-1``). A window
    is rejected when it is all zeros or when any of its features is NaN or
    infinite.

    Note: rejected windows are simply left out of the result, and their
    positions are not reported. Row ``i`` of the output is therefore the
    ``i``-th *kept* window, which equals the ``i``-th window start only when
    nothing was rejected.

    Args:
        eeg_values: EEG samples.
        fs: Sampling rate of ``eeg_values``.
        window_size_seconds: Window length in seconds.
        step_size_seconds: Hop between consecutive window starts in seconds.

    Returns:
        ``np.array`` built from the list of kept feature vectors (an empty
        array when no window was kept).
    """
    samples_per_window = int(window_size_seconds * fs)
    samples_per_step = int(step_size_seconds * fs)
    last_start = len(eeg_values) - samples_per_window

    def _features_for(start):
        segment = eeg_values[start:start + samples_per_window]
        # Skip truncated segments and segments that are entirely zero.
        if len(segment) < samples_per_window or np.all(segment == 0):
            return None
        return band_features(segment, fs)

    jobs = (delayed(_features_for)(start)
            for start in range(0, last_start + 1, samples_per_step))
    results = Parallel(n_jobs=-1)(jobs)

    # Keep only windows that produced a fully finite feature vector.
    kept = [row for row in results
            if row is not None and np.all(np.isfinite(row))]
    return np.array(kept)


# ----- Remove leading/trailing zero BIS values and align EEG -----
def trim_zero_ends(eeg, bis, fs_eeg=128, fs_bis=1):
    """Cut leading/trailing zero BIS samples and the matching EEG span.

    The first and last non-zero BIS samples delimit the kept BIS range; their
    indices are converted to time via ``fs_bis`` and then to EEG sample
    indices via ``fs_eeg``. Values are compared with ``!= 0``, so NaN counts
    as non-zero.

    Args:
        eeg: EEG samples.
        bis: BIS samples.
        fs_eeg: EEG sampling rate.
        fs_bis: BIS sampling rate.

    Returns:
        ``(trimmed_eeg, trimmed_bis)`` as NumPy arrays; two empty arrays when
        ``bis`` contains no non-zero value.
    """
    bis = np.array(bis)
    eeg = np.array(eeg)

    nonzero_positions = np.flatnonzero(bis != 0)
    if nonzero_positions.size == 0:
        return np.array([]), np.array([])

    bis_first = nonzero_positions[0]
    bis_stop = nonzero_positions[-1] + 1  # exclusive end of the kept range

    # BIS index -> seconds -> EEG index.
    eeg_first = int(bis_first / fs_bis * fs_eeg)
    eeg_stop = int(bis_stop / fs_bis * fs_eeg)

    return eeg[eeg_first:eeg_stop], bis[bis_first:bis_stop]


# ----- Training and Testing Code -----
# Experiment configuration.
fs = 128                   # EEG sampling rate (samples per second)
fs_bis = 1                 # BIS sampling rate handed to trim_zero_ends
window_size_seconds = 56   # EEG window length per feature row
step_size_seconds = 1      # hop between consecutive window starts
advance_seconds = 60       # label offset relative to the window start
train_patients = [9, 411, 760, 770, 2576, 3204, 3324, 3365, 3387, 3413, 3617, 4183, 4276, 4450, 4827]
test_patients = [106, 4968, 5604, 798, 804, 909]
linkpath = r'C:\Users\user\OneDrive - UniSQ\UK-Aus\Aus\Research Project - Startup\MSC6001_2023_2_ Literature review assignment submission Link\PhD\Gold patient csv'

features_list = []
bis_values_list = []

for patient in train_patients:
    # Load both recordings and fill gaps by linear interpolation.
    eeg = pd.read_csv(f'{linkpath}/EEG{patient}.csv')['EEG'].interpolate('linear')
    bis = pd.read_csv(f'{linkpath}/BIS{patient}.csv')['BIS'].interpolate('linear')
    eeg, bis = trim_zero_ends(eeg.values, bis.values, fs_eeg=fs, fs_bis=fs_bis)
    if len(eeg) < fs * window_size_seconds:
        continue  # not even one full window
    feats = extract_features(eeg, fs, window_size_seconds, step_size_seconds)
    if feats.size == 0:
        continue

    # Feature row i describes the window starting at i * step_size_seconds;
    # its label is the BIS sample advance_seconds after that start. Indexing
    # BIS by time (instead of by row + offset) keeps rows and labels aligned
    # for any step size, and is identical to the row-based shift when the
    # step is 1 s.
    # NOTE(review): the horizon is measured from the window START, so with a
    # 56 s window and a 60 s advance the label lies 4 s after the window end.
    # Confirm this is the intended "60 s ahead" definition.
    # NOTE(review): extract_features drops rejected windows without reporting
    # which ones, so this alignment assumes none were dropped.
    target_idx = np.round(
        (np.arange(len(feats)) * step_size_seconds + advance_seconds) * fs_bis
    ).astype(int)
    in_range = target_idx < len(bis)
    feats = feats[in_range]
    bis_shifted = bis[target_idx[in_range]]
    if len(feats) == 0 or len(bis_shifted) == 0:
        continue
    features_list.append(feats)
    bis_values_list.append(bis_shifted)

if not features_list:
    raise ValueError(
        "No training patient produced usable feature windows; "
        "check linkpath and the input recordings."
    )

X_train = np.vstack(features_list)
y_train = np.concatenate(bis_values_list)

# Standardise the features; the fitted scaler is reused for the test patients.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)

model = SVR(kernel='rbf')
model.fit(X_train_scaled, y_train)

# ----- Testing and Combined Evaluation -----
all_preds = []
all_truth = []

fs_bis = 1  # BIS sampling rate handed to trim_zero_ends
# Prediction horizon expressed in feature rows (loop-invariant).
advance_steps = int(advance_seconds / step_size_seconds)

for pid in test_patients:
    # Load both recordings and fill gaps by linear interpolation.
    eeg = pd.read_csv(f'{linkpath}/EEG{pid}.csv')['EEG'].interpolate('linear')
    bis = pd.read_csv(f'{linkpath}/BIS{pid}.csv')['BIS'].interpolate('linear')
    eeg, bis = trim_zero_ends(eeg.values, bis.values, fs_eeg=fs, fs_bis=fs_bis)
    if len(eeg) < fs * window_size_seconds:
        continue  # not even one full window
    feats = extract_features(eeg, fs, window_size_seconds, step_size_seconds)
    if feats.size == 0:
        continue

    # Feature row i describes the window starting at i * step_size_seconds;
    # its label is the BIS sample advance_seconds after that start. Indexing
    # BIS by time keeps rows and labels aligned for any step size and is
    # identical to the row-based shift when the step is 1 s.
    # NOTE(review): extract_features drops rejected windows without reporting
    # which ones, so this alignment assumes none were dropped.
    target_idx = np.round(
        (np.arange(len(feats)) * step_size_seconds + advance_seconds) * fs_bis
    ).astype(int)
    in_range = target_idx < len(bis)
    feats = feats[in_range]
    bis_shifted = bis[target_idx[in_range]]
    if len(feats) == 0 or len(bis_shifted) == 0:
        continue

    X_test = scaler.transform(feats)
    y_test = bis_shifted
    y_pred = model.predict(X_test)
    all_preds.extend(y_pred)
    all_truth.extend(y_test)

    # Plot time series for this patient with arrows for 60s prediction
    time_axis = np.arange(len(y_test)) * step_size_seconds
    plt.figure(figsize=(10, 4))
    plt.plot(time_axis, y_test, label='Actual BIS')
    plt.plot(time_axis, y_pred, label='Predicted BIS')

    # Every 200th row, draw an arrow from the prediction to the actual value
    # advance_steps rows later (the offset is in rows, not seconds, so it
    # stays correct when step_size_seconds != 1).
    # NOTE(review): y_test is already shifted by the horizon, so y_pred[i]
    # and y_test[i] refer to the same instant; confirm the extra offset on
    # the arrow head is intended.
    for i in range(0, len(time_axis) - advance_steps, 200):
        head = i + advance_steps
        plt.annotate('', xy=(time_axis[head], y_test[head]),
                     xytext=(time_axis[i], y_pred[i]),
                     arrowprops=dict(arrowstyle='->', color='red', linestyle='--'))

    plt.title(f'Patient {pid} - BIS Prediction 60s Ahead')
    plt.xlabel('Time (s)')
    plt.ylabel('BIS')
    plt.legend()
    plt.tight_layout()
    plt.show()

# Overall evaluation: regress actual BIS on predicted BIS over all test
# patients and build the 95% confidence band of the fitted line.
preds_arr = np.asarray(all_preds, dtype=float)
truth_arr = np.asarray(all_truth, dtype=float)
n = len(preds_arr)
if n < 3:
    raise ValueError(
        "Need at least 3 test predictions to fit and evaluate the regression line."
    )

coeffs = np.polyfit(preds_arr, truth_arr, 1)
poly = np.poly1d(coeffs)
x_fit = np.linspace(preds_arr.min(), preds_arr.max(), 100)
y_fit = poly(x_fit)
residuals = truth_arr - poly(preds_arr)
# Residual standard error with n - 2 degrees of freedom (slope and intercept
# were estimated), matching the t quantile below.
std_err = np.std(residuals, ddof=2)
t_val = t.ppf(0.975, df=n - 2)
pred_mean = preds_arr.mean()
# Half-width of the 95% confidence band of the fitted line at each x_fit.
# Kept under its own name so the scalar summary below cannot overwrite it.
ci_band = t_val * std_err * np.sqrt(
    1 / n + (x_fit - pred_mean) ** 2 / np.sum((preds_arr - pred_mean) ** 2))

# Overall Evaluation
mse = mean_squared_error(truth_arr, preds_arr)
rmse = np.sqrt(mse)
r2 = r2_score(truth_arr, preds_arr)
r, _ = pearsonr(truth_arr, preds_arr)
# Normal-approximation 95% half-width for the mean prediction error.
ci = 1.96 * np.std(preds_arr - truth_arr) / np.sqrt(n)
print(f"\nOverall: MSE={mse:.2f}, RMSE={rmse:.2f}, R\u00b2={r2:.2f}, r={r:.2f}, 95% CI={ci:.2f}")

plt.figure(figsize=(10, 6))
plt.scatter(preds_arr, truth_arr, label='Data', alpha=0.6)
plt.plot(x_fit, y_fit, label='Linear Fit', color='blue')
# Plot the actual confidence band (no ad-hoc scaling). The 1/n term makes it
# narrow when there are many samples.
plt.plot(x_fit, y_fit + ci_band, color='blue', linestyle='--', linewidth=1, label='95% CI of fit')
plt.plot(x_fit, y_fit - ci_band, color='blue', linestyle='--', linewidth=1)
plt.xlabel('Predicted BIS (60s ahead)')
plt.ylabel('Actual BIS')
plt.title('All Patients: 95% Confidence Interval for 60s Future BIS Prediction')
plt.legend()
plt.tight_layout()
plt.show()