In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal, stats
import torch
from omegaconf import OmegaConf
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import joblib

# Switch the working directory to the BrainBERT checkout on the Linux machine
os.chdir('/home/vineetreddy/Dropbox/CZW_MIT/BrainBERT')

# Resolve the parent of the (new) working directory and make it importable,
# appending it to sys.path only when it is not already listed
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

from demo_brainbert_annotated import *  # importing custom functions
from create_labels import create_labels #import custom function to create labels for seizure data
import models  # Import custom models (user-defined)

# Function to load pre-trained model weights and configuration. 
# Note that ckpt_path is the path to the pretrained weights of BrainBERT.
def load_brainbert_model(ckpt_path, device=None):
    """
    Loads the BrainBERT model with pre-trained weights.

    Args:
        ckpt_path (str): Path to the checkpoint file.
        device (str or torch.device, optional): Device the model is moved to
            and the checkpoint tensors are mapped onto. Defaults to 'cuda'
            when a GPU is available, otherwise 'cpu'.

    Returns:
        torch.nn.Module: BrainBERT model with loaded weights.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # build_model only receives the checkpoint path through the config
    cfg = OmegaConf.create({"upstream_ckpt": ckpt_path})
    brainbert_model = build_model(cfg)  # Build the model with the given configuration
    brainbert_model.to(device)  # Move the model to the selected device

    # map_location keeps the checkpoint tensors on the same device as the
    # model, so a checkpoint saved on a different device still loads here.
    init_state = torch.load(ckpt_path, map_location=device)

    # The trailing False is forwarded unchanged; its meaning is defined by
    # load_model_weights, which is not visible in this file.
    load_model_weights(brainbert_model, init_state['model'], False)
    return brainbert_model

# Function to generate BrainBERT embeddings from example waveforms
def generate_brainbert_embeddings(model, example_wavs, fs=2048):
    """
    Generates one time-averaged BrainBERT embedding per example.

    Args:
        model (torch.nn.Module): BrainBERT model.
        example_wavs (np.array): Array of example waveforms (one example per
            entry; per the calling script, each is a 5 second window).
        fs (int, optional): Value passed to get_stft as its second argument
            (the sampling rate, per the original TODO). Defaults to 2048,
            the value that was previously hard-coded.

    Returns:
        np.array: Array of BrainBERT embeddings, one row per example,
            averaged over axis 1 of the model output.

    Raises:
        ValueError: If example_wavs contains no examples.
    """
    # Run on whatever device the model already lives on instead of assuming
    # CUDA; fall back to the previous hard-coded device for parameter-less models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    # torch.no_grad() only disables gradient tracking. Any dropout layers stay
    # stochastic unless the module is in eval mode, so switch to eval for the
    # duration of the extraction and restore the caller's mode afterwards.
    was_training = model.training
    model.eval()

    pooled_outs = []
    try:
        for example_wav in example_wavs:
            # Get the Short-Time Fourier Transform (STFT) of the signal
            _, _, linear = get_stft(example_wav, fs, clip_fs=25, nperseg=400, noverlap=350,
                                    normalizing="zscore", return_onesided=True)
            # Add a leading batch axis and swap the last two axes before sending to the model
            inputs = torch.FloatTensor(linear).unsqueeze(0).transpose(1, 2).to(device)
            # All-False boolean mask over the first two input axes
            mask = torch.zeros(inputs.shape[:2]).bool().to(device)
            with torch.no_grad():
                out = model.forward(inputs, mask, intermediate_rep=True)  # Get the model output
            # Average over axis 1 per example, so examples with differing
            # numbers of frames can still be stacked and only one vector per
            # example is kept in memory.
            pooled_outs.append(out.cpu().numpy().mean(axis=1))
    finally:
        model.train(was_training)

    if not pooled_outs:
        raise ValueError("example_wavs is empty; there are no embeddings to generate.")

    return np.concatenate(pooled_outs, axis=0)

# Load Pre-Trained Model Weights and Configuration
# NOTE(review): the model is loaded even when cached embeddings exist below;
# kept as-is because `brainbert_model` remains available to the rest of the notebook.
ckpt_path = "/home/vineetreddy/Dropbox/CZW_MIT/stft_large_pretrained_256hz.pth"  # path to pre-trained weights for model
brainbert_model = load_brainbert_model(ckpt_path)  # Load the model

# Paths to save/load embeddings and labels
embeddings_save_path = "/home/vineetreddy/brainbert_embeddings.npy"
labels_save_path = "/home/vineetreddy/brainbert_labels.npy"

# Reuse the cached arrays only when BOTH files exist; otherwise rebuild them
if os.path.exists(embeddings_save_path) and os.path.exists(labels_save_path):
    all_brainbert_outs = np.load(embeddings_save_path)
    all_labels = np.load(labels_save_path)
    print("Loaded embeddings and labels from saved files.")
else:
    # Process each file in the directory
    directory = '/home/vineetreddy/edf numpy out/'
    events_dir = '/home/vineetreddy/edf events'  # Directory containing the events .tsv files
    all_brainbert_outs = []
    all_labels = []

    # os.listdir returns entries in arbitrary order; sorting makes the row
    # order of the cached arrays (and any seeded split of them) reproducible.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".npy"):
            continue
        file_path = os.path.join(directory, filename)

        # Generate labels first so files without labels are skipped before
        # their waveforms are read from disk
        labels = create_labels(file_path, events_dir)
        if labels.size == 0:
            continue

        # Load in channel array. Each channel array is organized such that each row is a 5-second window
        # and the columns are the time series data
        example_wavs = np.load(file_path)

        # Generate BrainBERT embeddings for each example
        brainbert_outs_arr = generate_brainbert_embeddings(brainbert_model, example_wavs)

        # Embeddings and labels are concatenated independently below, so a
        # per-file length mismatch would silently misalign every later row.
        if len(labels) != len(brainbert_outs_arr):
            raise ValueError(
                f"{filename}: got {len(labels)} labels for {len(brainbert_outs_arr)} embedded windows"
            )

        all_brainbert_outs.append(brainbert_outs_arr)
        all_labels.append(labels)

    if not all_brainbert_outs:
        raise ValueError(f"No labelled .npy files found in {directory!r}; nothing to embed.")

    # Combine all the data
    all_brainbert_outs = np.concatenate(all_brainbert_outs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Save the embeddings and labels
    np.save(embeddings_save_path, all_brainbert_outs)
    np.save(labels_save_path, all_labels)
    print("Saved embeddings and labels to files.")

# Logistic Regression
model_save_path = "/home/vineetreddy/save model/logistic_model.joblib"

# Split once, up front: the same seeded 80/20 split is needed whether the
# model is loaded from disk or trained from scratch.
brainbert_outs_arr_train, brainbert_outs_arr_test, labels_train, labels_test = train_test_split(
    all_brainbert_outs, all_labels, test_size=0.2, random_state=42
)

# Keep the try body to the one call whose FileNotFoundError means "no saved model yet"
try:
    logistic_model = joblib.load(model_save_path)
except FileNotFoundError:
    print(f"Model not found at {model_save_path}, training a new model.")

    logistic_model = LogisticRegression()
    logistic_model.fit(brainbert_outs_arr_train, labels_train)  # Fit the logistic regression model

    # Save the logistic regression model. Create the target directory first:
    # a missing directory is one way to end up in this branch, and joblib.dump
    # would otherwise fail with a second FileNotFoundError.
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
    joblib.dump(logistic_model, model_save_path)
    print(f"Model saved to {model_save_path}")

    # Predict the labels for the training set
    predictions_train = logistic_model.predict(brainbert_outs_arr_train)
    acc_train = np.mean(predictions_train == labels_train)
    print(f'Training Accuracy: {acc_train}')
else:
    print(f"Loaded existing model from {model_save_path}")

# Predict the labels for the test set
predictions = logistic_model.predict(brainbert_outs_arr_test)  # Predict the labels
acc = np.mean(predictions == labels_test)
print(f"Test Accuracy: {acc}")

# Save predictions and true labels for ROC/AUC plot
os.makedirs("/home/vineetreddy/roc_logreg", exist_ok=True)  # np.save does not create directories
np.save("/home/vineetreddy/roc_logreg/predictions.npy", predictions)
np.save("/home/vineetreddy/roc_logreg/labels_test.npy", labels_test)

# Hard class predictions yield only a single ROC operating point; also store
# the per-class probabilities (columns ordered as logistic_model.classes_) so
# a full ROC curve can be drawn from the saved files.
probabilities = logistic_model.predict_proba(brainbert_outs_arr_test)
np.save("/home/vineetreddy/roc_logreg/probabilities.npy", probabilities)

In [None]:
# Import necessary libraries
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import joblib

# Assuming all required functions and model loading are defined in the notebook

# Read the cached embeddings and their labels back from disk
embeddings_save_path = "/home/vineetreddy/brainbert_embeddings.npy"
labels_save_path = "/home/vineetreddy/brainbert_labels.npy"

all_brainbert_outs = np.load(embeddings_save_path)
all_labels = np.load(labels_save_path)

# Project the embeddings onto their first two principal components
pca = PCA(n_components=2)
reduced_embeddings = pca.fit_transform(all_brainbert_outs)

# Report the share of variance captured by each of the two components
variance_explained = pca.explained_variance_ratio_
for ordinal, ratio in zip(("first", "second"), variance_explained):
    print(f"Variance explained by the {ordinal} component: {ratio:.2f}")

# Scatter plot of the 2-D projection, coloured by label
fig, ax = plt.subplots(figsize=(10, 6))
points = ax.scatter(
    reduced_embeddings[:, 0],
    reduced_embeddings[:, 1],
    c=all_labels,
    cmap='viridis',
    s=5,
)
ax.set_title('PCA of BrainBERT Embeddings')
ax.set_xlabel('Principal Component 1')
ax.set_ylabel('Principal Component 2')
fig.colorbar(points, ax=ax, label='Label')
plt.show()

# Load the previously saved logistic regression model
model_save_path = "/home/vineetreddy/save model/logistic_model.joblib"
logistic_model = joblib.load(model_save_path)

# Recreate the seeded 80/20 split used for evaluation
brainbert_outs_arr_train, brainbert_outs_arr_test, labels_train, labels_test = train_test_split(
    all_brainbert_outs,
    all_labels,
    test_size=0.2,
    random_state=42,
)

# Accuracy = fraction of predictions that match the true labels
train_hits = logistic_model.predict(brainbert_outs_arr_train) == labels_train
test_hits = logistic_model.predict(brainbert_outs_arr_test) == labels_test
train_accuracy = np.mean(train_hits)
test_accuracy = np.mean(test_hits)

# Bar chart comparing the two accuracies
fig, ax = plt.subplots(figsize=(6, 4))
bars = ax.bar(
    ['Training Accuracy', 'Test Accuracy'],
    [train_accuracy, test_accuracy],
    color=['blue', 'green'],
)
ax.set_ylim(0, 1)
ax.set_ylabel('Accuracy')
ax.set_title('Model Performance')

# Write each accuracy value just above the top of its bar
for bar, accuracy in zip(bars, (train_accuracy, test_accuracy)):
    x_center = bar.get_x() + bar.get_width() / 2
    ax.text(x_center, bar.get_height() + 0.01, f'{accuracy:.2f}', ha='center', va='bottom')

plt.show()

# Confusion Matrix
predictions = logistic_model.predict(brainbert_outs_arr_test)

# Pin the matrix to the classes the model was trained on. Without `labels`,
# confusion_matrix infers them from the test split, so a split missing one
# class would produce a matrix that no longer matches the two display labels.
conf_matrix = confusion_matrix(labels_test, predictions, labels=logistic_model.classes_)

# Specify the labels explicitly (order follows logistic_model.classes_;
# assumes the first class is non-seizure and the second is seizure - TODO confirm)
disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=['Non-Seizure', 'Seizure'])
disp.plot(cmap='Blues')
plt.title('Confusion Matrix')
plt.show()

# ROC curve from the predicted probability in column 1 of predict_proba
positive_scores = logistic_model.predict_proba(brainbert_outs_arr_test)[:, 1]
fpr, tpr, _ = roc_curve(labels_test, positive_scores)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
# Dashed diagonal as a reference line
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.legend(loc='lower right')
plt.show()
