In [1]:
import sys
from pathlib import Path

sys.path.insert(0, '../')

from modules.VT_dataset import *
from modules.train_prep import *
from modules.plot_results import *

In [2]:
class EEGMLP(nn.Module):
    """
    Fully connected classifier for flattened EEG samples.

    Architecture: input_size -> 512 -> 256 -> 128 -> num_classes, with a ReLU
    after each of the three hidden layers and raw (un-normalised) logits at the output.

    Args:
    - input_size: Number of features in one sample after flattening every
      non-batch dimension (e.g. number of EEG channels * number of time points).
    - num_classes: Number of output classes (default 4).
    """

    def __init__(self, input_size, num_classes=4):
        super().__init__()
        # The attribute names fc1..fc4 become the state_dict keys
        # ("fc1.weight", "fc1.bias", ...), so they must stay stable for saved checkpoints.
        self.fc1 = nn.Linear(input_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, num_classes)

    def forward(self, x):
        """
        Map a batch of samples to class logits.

        Args:
        - x: Batch tensor whose first dimension is the batch. All remaining
          dimensions are flattened and must multiply out to input_size.

        Returns:
        - Tensor of shape (batch, num_classes) with unnormalised logits. No
          softmax is applied here; nn.CrossEntropyLoss applies log-softmax itself.
        """
        batch = x.size(0)
        hidden = x.view(batch, -1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)


In [None]:
# Pick the compute device once: CUDA, then Apple Metal (MPS), then CPU.
# `cuda_device` is not defined in this notebook; it presumably comes from the
# star imports in the first cell (modules.train_prep?) — TODO confirm.
if torch.cuda.is_available():
    torch.cuda.set_device(cuda_device)
    device = torch.device("cuda")  # plain "cuda" resolves to the device selected just above
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Per-split input transforms. ToTensor converts each sample to a torch tensor; it only
# rescales values to [0, 1] for uint8 inputs (PIL images / uint8 arrays), so float EEG
# data is presumably passed through unscaled — verify against RawDataset.
data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
    ]),
}

# Datasets for each split. train_dir / val_dir / *_behav_file are expected to come
# from the star imports in the first cell.
dsets = {
    'train': RawDataset(train_dir, train_behav_file, data_transforms['train']),
    'val': RawDataset(val_dir, val_behav_file, data_transforms['val'])
}
dset_sizes = {split: len(dsets[split]) for split in ['train', 'val']}

dset_loaders = {}

# Train loader: draw samples with probability inversely proportional to their class
# frequency (with replacement) to counter class imbalance during training.
train_set = dsets['train']
targets = np.array([train_set.get_label(i) for i in range(len(train_set))])
class_counts = train_set.get_class_counts()
# One weight per *sample* (not per class): 1 / count of that sample's class.
class_weights = np.array([1.0 / class_counts[label] if class_counts[label] > 0 else 0 for label in targets])
sampler = WeightedRandomSampler(class_weights, num_samples=len(class_weights), replacement=True)
dset_loaders['train'] = torch.utils.data.DataLoader(train_set, batch_size=b_size, num_workers=0, sampler=sampler)
print('done making loader:', 'train')

# Validation loader: a plain sequential pass so every validation sample is scored exactly
# once. Resampling the validation set with replacement would duplicate some samples and
# skip others, distorting the reported accuracy, loss and confusion matrix.
dset_loaders['val'] = torch.utils.data.DataLoader(dsets['val'], batch_size=b_size, num_workers=0, shuffle=False)
print('done making loader:', 'val')

# Flattened sample size expected by the MLP; presumably 129 EEG channels x 1250 time
# points per sample — TODO confirm against RawDataset's sample shape.
INPUT_SIZE = 129 * 1250
model_ft = EEGMLP(input_size=INPUT_SIZE, num_classes=4)
criterion = nn.CrossEntropyLoss()  # expects raw logits; applies log-softmax internally

# Move the model to the device before building the optimizer, as recommended by torch.optim.
model_ft.to(device)
criterion.to(device)
optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.0001, weight_decay=1e-5)
print(f"Training on {device}")

# `exp_lr_scheduler` and `n_epochs` are not defined in this notebook; they are expected
# from the star imports in the first cell — confirm they match train_model's expectations.
model_ft, accuracies, losses, preds, labels = train_model(
    model_ft, criterion, optimizer_ft, exp_lr_scheduler, dset_loaders, dset_sizes, num_epochs=n_epochs
)

 22%|██▏       | 545/2488 [00:40<01:12, 26.68it/s]

In [None]:
# Report per-epoch metrics for each split returned by train_model.
for split in ['train', 'val']:
    print(split, 'accuracies by epoch:', accuracies[split])
    print(split, 'losses by epoch:', losses[split])

# Convert accuracy entries to NumPy arrays for plotting. Entries that are not tensors
# (e.g. already converted by a previous run of this cell) are kept as-is, so re-running
# the cell no longer fails with AttributeError on `.cpu()`.
for key, value in accuracies.items():
    accuracies[key] = [v.cpu().numpy() if torch.is_tensor(v) else v for v in value]

# Save the trained weights. torch.save does not create missing directories, so make sure
# the target folder exists first.
MODEL_PATH = Path('../../models/MLP_VT_best_model.pt')
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model_ft.state_dict(), str(MODEL_PATH))

# plot_training_history and plot_confusion_matrix are expected to come from the
# modules.plot_results star import in the first cell.

# Accuracy / loss curves over all epochs for both splits.
plot_training_history(accuracies, losses, 'MLP_VT')

# Confusion matrix of the predictions returned by train_model against the true labels.
plot_confusion_matrix(labels, preds, 'MLP_VT')
