In [None]:
# This is going to be our main.py file: it imports the preprocessing, data-input, and machine-learning pieces.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import logging
from torch.utils.tensorboard import SummaryWriter
import os
from torch.utils.data import random_split
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
from torchmetrics.classification import MulticlassCalibrationError  # For calibration metrics
import torch.nn.functional as F
import mne

class Args:
    """Static configuration namespace for the model and the run.

    Every setting is a plain class attribute, so it can be read either from
    the class itself or from an instance (``Args()``).

    Side effect: the results directory (``dir_result``) is created when the
    class body is executed, i.e. when this definition is first run.
    """

    # --- Run / optimisation settings ---
    seed_list = [0, 1004, 911, 2021, 119]
    seed = 10
    project_name = "test_project"
    checkpoint = False
    epochs = 10
    batch_size = 32
    optim = 'adam'
    lr_scheduler = "Single"
    lr_init = 1e-3
    lr_max = 4e-3
    t_0 = 5
    t_mult = 2
    t_up = 1
    gamma = 0.5
    momentum = 0.9
    weight_decay = 1e-6
    task_type = 'binary'
    log_iter = 10
    best = True
    last = False
    test_type = "test"
    device = 0  # GPU device number to use

    binary_target_groups = 2
    output_dim = 4  # Size of the classifier's final linear layer

    # Manually set paths here instead of using a YAML config file
    data_path = '/Volumes/SDCARD/v2.0.3/edf'  # Set the data path directly
    dir_root = os.getcwd()  # Root directory is the current working directory
    dir_result = os.path.join(dir_root, 'results')  # Where results are written

    # Create the results directory if needed. exist_ok=True avoids the
    # check-then-create race of testing os.path.exists() first.
    os.makedirs(dir_result, exist_ok=True)
    reset = False  # Set reset flag to False to avoid overwriting existing results

    # --- Model settings ---
    num_layers = 2
    # NOTE(review): CNN2D_LSTM_V8_4 in this file sets its own hidden size and
    # does not read this value - confirm which one is intended.
    hidden_dim = 512  # Number of features in the hidden state of the LSTM
    dropout = 0.1  # Dropout rate for regularization
    num_channel = 20  # Number of data channels (e.g., EEG channels)
    sincnet_bandnum = 20  # SincNet configuration
    enc_model = 'raw'  # Encoder model for feature extraction

    window_shift_label = 1
    window_size_label = 4
    requirement_target = None
    sincnet_kernel_size = 81


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a skip path.

    Args:
        in_planes: Number of input feature maps.
        planes: Number of output feature maps.
        stride: Stride of the first convolution (and of the skip projection).
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()

        # Main path; only the first convolution may change resolution.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip path: an empty Sequential (identity) unless the output shape
        # differs from the input, in which case a 1x1 conv + BN projects it.
        shape_changes = stride != 1 or in_planes != planes
        if shape_changes:
            projection = [
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = self.shortcut(x)
        features = F.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))
        return F.relu(features + residual)

class CNN2D_LSTM_V8_4(nn.Module):
    """Residual 2-D CNN encoder feeding an LSTM and a small MLP classifier.

    Three stages of ``BasicBlock`` (64, 128 and 256 planes) are followed by
    global average pooling, so every sample is reduced to a single 256-value
    vector. That vector is handed to the LSTM as a length-1 sequence and the
    LSTM output is mapped to ``args.output_dim`` logits.

    Args:
        args: Configuration object. ``num_layers``, ``dropout``,
            ``num_channel``, ``sincnet_bandnum``, ``enc_model``,
            ``batch_size`` and ``output_dim`` are read from it.
        device: Device on which the ``self.hidden`` placeholder is allocated.
    """

    def __init__(self, args, device):
        super().__init__()
        self.args = args

        # Model parameters
        self.num_layers = args.num_layers  # Number of LSTM layers
        # NOTE(review): hard-coded; args.hidden_dim is not consulted here.
        # Left unchanged so existing checkpoints keep their shapes - confirm.
        self.hidden_dim = 256  # Features in the LSTM hidden state
        self.dropout = args.dropout  # Dropout rate
        self.num_data_channel = args.num_channel  # Data channels (e.g., EEG channels)
        self.sincnet_bandnum = args.sincnet_bandnum  # SincNet configuration
        self.feature_extractor = args.enc_model  # Feature extraction method
        self.in_planes = 1  # Input planes for ResNet; updated by _make_layer

        # Activation functions (only `activation` is used by the classifier;
        # the full dict is kept so parameter names stay stable).
        activation = 'relu'  # Default activation function
        self.activations = nn.ModuleDict({
            'lrelu': nn.LeakyReLU(),
            'prelu': nn.PReLU(),
            'relu': nn.ReLU(inplace=True),
            'tanh': nn.Tanh(),
            'sigmoid': nn.Sigmoid(),
            'leaky_relu': nn.LeakyReLU(0.2),
            'elu': nn.ELU()
        })

        # Placeholder hidden state; forward() builds its own per call.
        self.init_state(device)

        # Define ResNet layers using the BasicBlock
        self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)

        # Adaptive average pooling down to 1x1 per feature map
        self.agvpool = nn.AdaptiveAvgPool2d((1, 1))

        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=self.hidden_dim,
            num_layers=self.num_layers,
            batch_first=True,
            dropout=self.dropout
        )

        # Fully connected classifier
        self.classifier = nn.Sequential(
            nn.Linear(in_features=self.hidden_dim, out_features=64, bias=True),
            nn.BatchNorm1d(64),
            self.activations[activation],
            nn.Linear(in_features=64, out_features=args.output_dim, bias=True)
        )

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        ``self.in_planes`` is advanced to ``planes`` so the next stage is
        wired to this stage's output width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch.

        Args:
            x: 3-D tensor with the batch on axis 0. The last two axes are
                swapped and a singleton channel axis is inserted before the
                convolutional stages.

        Returns:
            Tensor of shape ``(batch, output_dim)`` holding unnormalised logits.
        """
        batch_size = x.size(0)

        # (B, A, C) -> (B, 1, C, A): single-plane "image" for the 2-D convs
        x = x.permute(0, 2, 1).unsqueeze(1)

        # Pass through ResNet layers
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        # Global pooling, then flatten to (B, 256) and add a length-1 time axis
        x = self.agvpool(x)
        x = torch.flatten(x, 1).unsqueeze(1)

        # Fresh zero state sized to the actual batch. new_zeros matches both
        # the device AND the dtype of x; the LSTM rejects a float32 state
        # when the model/input run in another dtype (e.g. float64).
        h0 = x.new_zeros(self.num_layers, batch_size, self.hidden_dim)
        c0 = x.new_zeros(self.num_layers, batch_size, self.hidden_dim)

        # LSTM forward pass; keep the output of the last time step
        output, _ = self.lstm(x, (h0, c0))
        output = output[:, -1, :]

        # Classification
        return self.classifier(output)

    def init_state(self, device):
        """Reset ``self.hidden`` to zeros sized for ``args.batch_size``.

        ``forward`` does not read ``self.hidden``; it is kept as an attribute
        for code that accesses it directly.
        """
        state_shape = (self.num_layers, self.args.batch_size, self.hidden_dim)
        self.hidden = (
            torch.zeros(state_shape, device=device),
            torch.zeros(state_shape, device=device)
        )

# Function to calculate metrics
def calculate_metrics(labels, predictions, probabilities, num_classes):
    """Summarise classification quality for one set of predictions.

    Args:
        labels: Tensor of true class indices.
        predictions: Tensor of predicted class indices, same length as ``labels``.
        probabilities: Per-class probabilities passed to the module-level
            ``ece_criterion`` together with ``labels``.
        num_classes: Number of classes; fixes the confusion-matrix size.

    Returns:
        dict with ``accuracy`` (percent), the class-averaged ``avg_precision``,
        ``avg_recall``, ``avg_FPR``, ``avg_FNR``, and ``ECE``.
    """
    # Accuracy in percent
    n_correct = (predictions == labels).sum().item()
    n_total = labels.size(0)
    accuracy = 100 * n_correct / n_total

    # Confusion matrix over the full, fixed set of class ids
    class_ids = list(range(num_classes))
    cm = confusion_matrix(
        labels.cpu().numpy(), predictions.cpu().numpy(), labels=class_ids
    )

    # Per-class counts derived from the matrix
    true_pos = np.diag(cm)
    false_pos = cm.sum(axis=0) - true_pos
    false_neg = cm.sum(axis=1) - true_pos
    true_neg = cm.sum() - (false_pos + false_neg + true_pos)

    # Per-class rates; the small constant keeps empty classes from dividing by zero
    precision = true_pos / (true_pos + false_pos + 1e-8)
    recall = true_pos / (true_pos + false_neg + 1e-8)
    fpr = false_pos / (false_pos + true_neg + 1e-8)
    fnr = false_neg / (false_neg + true_pos + 1e-8)

    # Unweighted means over classes
    summary = {
        'accuracy': accuracy,
        'avg_precision': np.mean(precision),
        'avg_recall': np.mean(recall),
        'avg_FPR': np.mean(fpr),
        'avg_FNR': np.mean(fnr),
    }

    # Calibration error from the shared metric object
    summary['ECE'] = ece_criterion(probabilities, labels).item()
    return summary

# Function to log metrics
def log_metrics(metrics, epoch, phase):
    """Log a metrics dict and mirror each value to TensorBoard.

    Args:
        metrics: dict with the keys ``accuracy``, ``avg_precision``,
            ``avg_recall``, ``avg_FPR``, ``avg_FNR`` and ``ECE``.
        epoch: Value used in the log header and as the TensorBoard step.
        phase: Prefix for the log header and for every TensorBoard tag.
    """
    # (log title, metrics key) for the values printed with four decimals
    four_decimal_rows = (
        ("Precision", 'avg_precision'),
        ("Recall", 'avg_recall'),
        ("False Positive Rate", 'avg_FPR'),
        ("False Negative Rate", 'avg_FNR'),
        ("Expected Calibration Error", 'ECE'),
    )
    # (TensorBoard tag suffix, metrics key)
    scalar_rows = (
        ('accuracy', 'accuracy'),
        ('precision', 'avg_precision'),
        ('recall', 'avg_recall'),
        ('FPR', 'avg_FPR'),
        ('FNR', 'avg_FNR'),
        ('ECE', 'ECE'),
    )

    logging.info(f'{phase} Metrics after Epoch {epoch}:')
    logging.info(f"Accuracy: {metrics['accuracy']:.2f}%")
    for title, key in four_decimal_rows:
        logging.info(f"{title}: {metrics[key]:.4f}")

    # Write to TensorBoard via the module-level writer
    for suffix, key in scalar_rows:
        writer.add_scalar(f'{phase} {suffix}', metrics[key], epoch)

# Function to evaluate the model
def evaluate_model(model, data_loader, device, temperature, num_classes):
    """Run the model over a loader and return its aggregate metrics.

    Args:
        model: Module to evaluate; switched to eval mode.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        temperature: Divisor applied to the logits before softmax.
        num_classes: Forwarded to ``calculate_metrics``.

    Returns:
        The metrics dict produced by ``calculate_metrics``.
    """
    model.eval()
    label_batches, prediction_batches, probability_batches = [], [], []

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Temperature-scaled logits -> class probabilities
            logits = model(images)
            probs = torch.softmax(logits / temperature, dim=1)

            label_batches.append(labels)
            prediction_batches.append(probs.max(dim=1).indices)
            probability_batches.append(probs)

        # Merge the per-batch tensors and score them in one go
        return calculate_metrics(
            torch.cat(label_batches),
            torch.cat(prediction_batches),
            torch.cat(probability_batches),
            num_classes,
        )

# Function to test on a single data point
def test_single_data_point(model, data_point, label, device, temperature, class_names=None):
    """Classify one EDF recording and log / plot the class probabilities.

    The recording is read with MNE, truncated to ``max_time_points`` samples,
    padded with zero rows or cut to ``target_channels`` rows, reduced to one
    randomly placed window of ``segment_length`` samples and resampled by
    linear interpolation to ``target_points`` samples before being fed to
    the model as a batch of one.

    Args:
        model: Module to run; switched to eval mode.
        data_point: Path of the EDF file to read.
        label: Expected class index, used only for the correct/incorrect report.
        device: Device the input tensor is moved to; should be the device
            the model lives on.
        temperature: Divisor applied to the logits before softmax.
        class_names: Optional sequence of display names indexed by class id.

    Side effects:
        Writes to the log, stores each probability in the module-level
        ``classProbs`` list and adds a bar chart to the module-level
        TensorBoard ``writer``. Nothing is returned.
    """
    model.eval()
    target_channels = 40
    target_points = 5000
    segment_length = 500
    max_time_points = 100000

    with torch.no_grad():
        data = mne.io.read_raw_edf(data_point, preload=True).get_data()

        # Limit the number of time points if they exceed max_time_points
        if data.shape[1] > max_time_points:
            data = data[:, :max_time_points]

        # Pad (with zero rows) or trim channels to match target_channels
        if data.shape[0] < target_channels:
            padding = np.zeros((target_channels - data.shape[0], data.shape[1]))
            data = np.vstack((data, padding))
        elif data.shape[0] > target_channels:
            data = data[:target_channels, :]

        # Randomly select a segment. randint's upper bound is exclusive, so
        # "+ 1" is required for the last valid window to be selectable.
        if data.shape[1] > segment_length:
            start = np.random.randint(0, data.shape[1] - segment_length + 1)
            segment = data[:, start:start + segment_length]
        else:
            segment = data

        # Interpolate or compress every channel to target_points samples
        if segment.shape[1] != target_points:
            new_grid = np.linspace(0, 1, target_points)
            old_grid = np.linspace(0, 1, segment.shape[1])
            segment = np.array([np.interp(new_grid, old_grid, channel)
                                for channel in segment])

        # Batch of one, matched to the model's dtype and placed on `device`;
        # without the device move the forward pass fails whenever the model
        # is on CUDA/MPS.
        model_dtype = next(model.parameters()).dtype
        data = torch.tensor(segment, dtype=torch.float32).unsqueeze(0)
        data = data.to(device=device, dtype=model_dtype)

        output = model(data)
        # Apply temperature scaling
        scaled_output = output / temperature

        # Apply softmax to get probabilities
        probabilities = torch.nn.functional.softmax(scaled_output, dim=1)
        probabilities = probabilities.cpu().numpy()[0]  # Convert to numpy array

        # Get predicted label
        predicted_label = np.argmax(probabilities)
        actual_label = label

        # Calculate accuracy (100 if correct, 0 if incorrect)
        is_correct = int(predicted_label == actual_label)
        accuracy = 100 * is_correct

        logging.info('Single Data Point Test:')
        if class_names:
            logging.info(f'Actual Label: {class_names[actual_label]}')
            logging.info(f'Predicted Label: {class_names[predicted_label]}')
        else:
            logging.info(f'Actual Label: {actual_label}')
            logging.info(f'Predicted Label: {predicted_label}')
        logging.info(f'Accuracy: {accuracy}%')
        logging.info('Class Probabilities:')
        for i, prob in enumerate(probabilities):
            class_label = class_names[i] if class_names else i
            logging.info(f'Class {class_label}: {prob*100:.2f}%')
            # Module-level result list; it must have one slot per class.
            classProbs[i] = prob

        # Plot the probabilities as a bar graph
        fig, ax = plt.subplots()
        classes = class_names if class_names else list(range(len(probabilities)))
        ax.bar(classes, probabilities)
        ax.set_xlabel('Class')
        ax.set_ylabel('Probability')
        ax.set_title('Class Probabilities')
        plt.xticks(rotation=45)
        plt.tight_layout()

        # Add the figure to TensorBoard
        writer.add_figure('Class Probabilities', fig)

### CHANGE THIS FOR CUSTOM FILES AND MODEL PATHS ###
model_path = r"C:\Users\dalto\Box Sync\full_model.pth"
single_file = r"C:\Users\dalto\Box Sync\aaaaaaaa_s001_t000.edf"

# List to store the probabilities of each class (filled by test_single_data_point)
classProbs = [0,0,0,0]
# NOTE(review): classNames is not passed to test_single_data_point below, so
# the log shows numeric class ids - confirm whether that is intended.
classNames = ['Healthy', 'Epilepsy', 'Stroke', 'Concussion']

args = Args()

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

# Device configuration
if torch.backends.mps.is_available():
    device = torch.device('mps')
    logging.info("Using MPS device (Apple Silicon GPU).")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    logging.info("Using CUDA device.")
else:
    device = torch.device('cpu')
    logging.info("Using CPU.")

# Hyperparameters
num_epochs = 5
batch_size = 100
learning_rate = 0.001
temperature = 2.0  # Temperature parameter for temperature scaling
# checkpoint_dir = './checkpoints'
# os.makedirs(checkpoint_dir, exist_ok=True)

# model = CNN2D_LSTM_V8_4(args,device).to(device)  # Initialize the model
# The file holds a fully pickled module (not a state dict), so it cannot be
# read in weights-only mode; the flag is explicit because its default
# differs between PyTorch releases.
# NOTE(security): unpickling executes arbitrary code - only load model files
# from a trusted source.
model = torch.load(model_path, map_location=device, weights_only=False)
model = model.float()  # Convert model to float
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Calibration metric; the class count must match the model's output size
ece_criterion = MulticlassCalibrationError(num_classes=args.output_dim, n_bins=15).to(device)

# TensorBoard writer
writer = SummaryWriter('runs/simple_cnn')

try:
    # Test on a single data point: the EDF recording named in single_file
    single_label = 0
    test_single_data_point(model, single_file, single_label, device, temperature)

    logging.info("Training and evaluation completed.")
finally:
    # Close the TensorBoard writer even if inference raised
    writer.close()

2024-12-05 11:22:44,240 Using CPU.


Extracting EDF parameters from C:\Users\dalto\Box Sync\aaaaaaaa_s001_t000.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 323839  =      0.000 ...  1264.996 secs...


2024-12-05 11:22:47,559 Single Data Point Test:
2024-12-05 11:22:47,560 Actual Label: 0
2024-12-05 11:22:47,561 Predicted Label: 2
2024-12-05 11:22:47,562 Accuracy: 0%
2024-12-05 11:22:47,564 Class Probabilities:
2024-12-05 11:22:47,566 Class 0: 20.23%
2024-12-05 11:22:47,568 Class 1: 26.36%
2024-12-05 11:22:47,570 Class 2: 28.94%
2024-12-05 11:22:47,571 Class 3: 24.46%
2024-12-05 11:22:47,839 Training and evaluation completed.


In [2]:
# Installs torchmetrics into the kernel's environment. The first cell imports
# MulticlassCalibrationError from it, so run this cell first on a fresh setup.
%pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.6.0-py3-none-any.whl (926 kB)
                                              0.0/926.4 kB ? eta -:--:--
     ----------                             256.0/926.4 kB 7.9 MB/s eta 0:00:01
     ------------------------------------  921.6/926.4 kB 14.5 MB/s eta 0:00:01
     -------------------------------------- 926.4/926.4 kB 9.7 MB/s eta 0:00:00
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl (28 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.9 torchmetrics-1.6.0
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.1.2 -> 24.3.1
[notice] To update, run: C:\Users\dalto\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip
