In [1]:
import sys

sys.path.insert(0, '../')

from modules.VT_dataset import *
from modules.train_prep import *
from modules.plot_results import *

In [2]:
class EEGMLP(nn.Module):
    """Fully connected classifier applied to flattened EEG samples.

    Each sample is flattened to a 1-D feature vector and passed through three
    ReLU-activated hidden layers (512, 256 and 128 units) followed by a linear
    output layer that yields one raw score (logit) per class.

    Args:
        input_size: Number of values in one flattened sample. It must equal
            the product of all non-batch dimensions of the tensors given to
            ``forward``.
        num_classes: Number of scores produced per sample. Defaults to 4.
    """

    def __init__(self, input_size, num_classes=4):
        super().__init__()
        # The attribute names fc1..fc4 determine the state_dict keys, so they
        # are kept stable to stay compatible with previously saved weights.
        self.fc1 = nn.Linear(input_size, 512)   # first hidden layer
        self.fc2 = nn.Linear(512, 256)          # second hidden layer
        self.fc3 = nn.Linear(256, 128)          # third hidden layer
        self.fc4 = nn.Linear(128, num_classes)  # output layer

    def forward(self, x):
        """Compute class logits for a batch of samples.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must multiply to ``input_size``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding unnormalised
            scores (no softmax is applied).
        """
        # reshape() behaves like view() for contiguous input but also accepts
        # non-contiguous tensors (e.g. after transpose/permute), where view()
        # would raise a RuntimeError.
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        # No activation on the output layer: raw logits are returned.
        return self.fc4(x)

In [3]:
# Names used below that are not defined in this notebook (cuda_device,
# train_dir, train_behav_file, val_dir, val_behav_file, b_size, n_epochs,
# RawDataset, train_model, exp_lr_scheduler, ...) are expected to be provided
# by the star imports from modules.* in the first cell.

# ---- Tunable constants ------------------------------------------------------
SEED = 42
INPUT_SIZE = 129 * 1250  # flattened sample length; presumably channels x time points -- TODO confirm
NUM_CLASSES = 4
LEARNING_RATE = 0.0001
WEIGHT_DECAY = 1e-5
SPLITS = ['train', 'val']

# Seed the global torch and NumPy RNGs so that weight initialisation and the
# weighted sampling below can be repeated. Other sources of randomness (for
# example inside train_model) are not covered by this.
torch.manual_seed(SEED)
np.random.seed(SEED)

# ---- Device selection (resolved once: CUDA, then MPS, then CPU) --------------
if torch.cuda.is_available():
    torch.cuda.set_device(cuda_device)
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# ---- Datasets -----------------------------------------------------------------
data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
    ]),
}

dsets = {
    'train': RawDataset(train_dir, train_behav_file, data_transforms['train']),
    'val': RawDataset(val_dir, val_behav_file, data_transforms['val'])
}
dset_sizes = {split: len(dsets[split]) for split in SPLITS}

# ---- Loaders with inverse-class-frequency sampling ----------------------------
# NOTE(review): the 'val' split is also drawn through the weighted sampler with
# replacement, so validation metrics are computed on a class-rebalanced resample
# rather than on every validation sample exactly once -- confirm this is intended.
dset_loaders = {}
for split in SPLITS:
    targets = np.array([dsets[split].get_label(i) for i in range(len(dsets[split]))])
    class_counts = dsets[split].get_class_counts()
    # One weight per sample (not per class): the inverse of the frequency of
    # that sample's class, so that rarer classes are drawn more often.
    sample_weights = np.array([
        1.0 / class_counts[label] if class_counts[label] > 0 else 0
        for label in targets
    ])
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    dset_loaders[split] = torch.utils.data.DataLoader(
        dsets[split], batch_size=b_size, num_workers=0, sampler=sampler
    )
    print('done making loader:', split)

# ---- Model, loss and optimizer ------------------------------------------------
model_ft = EEGMLP(input_size=INPUT_SIZE, num_classes=NUM_CLASSES)
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.Adam(model_ft.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

model_ft.to(device)
criterion.to(device)
print(f"Training on {device}")

# ---- Training -----------------------------------------------------------------
model_ft, accuracies, losses, preds, labels = train_model(
    model_ft, criterion, optimizer_ft, exp_lr_scheduler,
    dset_loaders, dset_sizes, num_epochs=n_epochs,
)

100%|██████████| 192/192 [00:06<00:00, 27.63it/s]


Class counts:  {'HEHF': 2489, 'HELF': 238, 'LEHF': 1371, 'LELF': 878}


100%|██████████| 48/48 [00:01<00:00, 26.10it/s]


Class counts:  {'HEHF': 508, 'HELF': 101, 'LEHF': 535, 'LELF': 100}
done making loader: train
done making loader: val
Training on mps




----------
Epoch 0/4
----------
LR is set to 0.0001

In train phase:


 11%|█         | 270/2488 [00:13<01:28, 24.98it/s]

KeyboardInterrupt: 

 11%|█         | 270/2488 [00:30<01:28, 24.98it/s]

In [None]:
# Persist the trained weights first, so that a failure while reporting or
# plotting below cannot cost the result of the training run.
# Assumes the ../models directory already exists -- TODO confirm.
torch.save(model_ft.state_dict(), '../models/MLP_VT_best_model.pt')

# Output results
for split in ['train', 'val']:
    print(split, 'accuracies by epoch:', accuracies[split])
    print(split, 'losses by epoch:', losses[split])

# Convert the per-epoch accuracies to NumPy values for plotting. The original
# conversion called .cpu() on every entry, which presumes they are torch
# tensors -- TODO confirm against train_model. Only tensors are converted here;
# anything else passes through np.asarray, so re-running this cell after the
# values have already been converted no longer raises AttributeError.
accuracies = {
    split: [
        v.detach().cpu().numpy() if torch.is_tensor(v) else np.asarray(v)
        for v in values
    ]
    for split, values in accuracies.items()
}

plot_training_history(accuracies, losses, 'MLP_VT')
plot_confusion_matrix(labels, preds, 'MLP_VT')