In [5]:
from modules.VT_dataset import *
from modules.train_prep import *

In [6]:
class EEGMLP(nn.Module):
    """Fully connected classifier for EEG samples that are flattened per sample.

    Layers: input_size -> 512 -> 256 -> 128 -> num_classes. ReLU follows each
    hidden layer. The output layer has no activation, so ``forward`` returns
    raw logits, which suits ``nn.CrossEntropyLoss``.

    Args:
        input_size: Number of values in one sample after all non-batch
            dimensions are flattened. It must equal the product of those dims.
        num_classes: Number of output logits (default 4).
    """

    def __init__(self, input_size, num_classes=4):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 512)   # first hidden layer
        self.fc2 = nn.Linear(512, 256)          # second hidden layer
        self.fc3 = nn.Linear(256, 128)          # third hidden layer
        self.fc4 = nn.Linear(128, num_classes)  # output layer (logits)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for a batch ``x``.

        ``x`` may have any number of trailing dimensions. Everything after
        dim 0 is flattened into one feature vector per sample.
        """
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. a
        # transposed batch, and copies only when a view is impossible.
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)  # no activation: raw logits

In [7]:
# Seed torch's global RNG. It drives both the WeightedRandomSampler draws and
# the Linear-layer initialisation, so re-runs are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Device selection, done once. `cuda_device` comes from one of the
# star-imported modules above.
if torch.cuda.is_available():
    torch.cuda.set_device(cuda_device)
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    mps_device = torch.device("mps")  # kept for any later cell that references it
    device = mps_device
else:
    device = torch.device("cpu")

data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
    ]),
}

# Setup datasets and loaders
dsets = {
    'train': RawDataset(train_dir, train_behav_file, data_transforms['train']),
    'val': RawDataset(val_dir, val_behav_file, data_transforms['val'])
}
dset_sizes = {split: len(dsets[split]) for split in ['train', 'val']}

dset_loaders = {}
for split in ['train', 'val']:
    if split == 'train':
        # Rebalance the training split only. Each sample is weighted by
        # 1 / (size of its class), so minority classes are drawn about as
        # often as majority ones.
        targets = np.array([dsets[split].get_label(i) for i in range(dset_sizes[split])])
        class_counts = dsets[split].get_class_counts()
        sample_weights = np.array([1.0 / class_counts[label] if class_counts[label] > 0 else 0 for label in targets])
        sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
        loader_kwargs = {'sampler': sampler}
    else:
        # Validation must see the real class distribution, each sample exactly
        # once and in a fixed order. Resampling here would make val metrics
        # noisy and unrepresentative.
        loader_kwargs = {'shuffle': False}
    dset_loaders[split] = torch.utils.data.DataLoader(dsets[split], batch_size=b_size, num_workers=0, **loader_kwargs)
    print('done making loader:', split)

# Initialize model and criterion, and move them to the device *before* building
# the optimizer, as recommended by the torch.optim docs.
# NOTE(review): 129*1250 is presumably channels x time points per sample;
# confirm against RawDataset. fc1 alone holds ~82.6M weights at this size.
model_ft = EEGMLP(input_size=129*1250, num_classes=4)
model_ft.to(device)
criterion = nn.CrossEntropyLoss()
criterion.to(device)
# NOTE(review): the training log reports "LR is set to 0.0001", so
# exp_lr_scheduler (star-imported) appears to override lr=0.01 below. Confirm.
optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.01, weight_decay=1e-5)
print(f"Training on {device}")

# Train model
model_ft, accuracies, losses = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, dset_loaders, dset_sizes, num_epochs=n_epochs)

# Output results
for split in ['train', 'val']:
    print(split, 'accuracies by epoch:', accuracies[split])
    print(split, 'losses by epoch:', losses[split])


100%|██████████| 192/192 [00:09<00:00, 20.51it/s]


Class counts:  {'HEHF': 2489, 'HELF': 238, 'LEHF': 1371, 'LELF': 878}


100%|██████████| 48/48 [00:02<00:00, 21.45it/s]


Class counts:  {'HEHF': 508, 'HELF': 101, 'LEHF': 535, 'LELF': 100}
done making loader: train
done making loader: val
Training on mps
----------
Epoch 0/9
----------
LR is set to 0.0001
Reached batch iteration 0
we are on batch 1
we are on batch 2
we are on batch 3
we are on batch 4
we are on batch 5
we are on batch 6
we are on batch 7
we are on batch 8
we are on batch 9
we are on batch 10
we are on batch 11
we are on batch 12
we are on batch 13
we are on batch 14
we are on batch 15
we are on batch 16
we are on batch 17
we are on batch 18
we are on batch 19
we are on batch 20
we are on batch 21
we are on batch 22
we are on batch 23
we are on batch 24
we are on batch 25
we are on batch 26
we are on batch 27
we are on batch 28
we are on batch 29
we are on batch 30
we are on batch 31
we are on batch 32
we are on batch 33
we are on batch 34
we are on batch 35
we are on batch 36
we are on batch 37
we are on batch 38
we are on batch 39
we are on batch 40
we are on batch 41
we are on batch 42

KeyboardInterrupt: 