In [1]:
import sys
sys.path.append("../../../motor-imagery-classification-2024/")
import os
import torch
from torch import nn
import optuna
import torchsummary
from datetime import datetime
import json
import numpy as np
from classification.classifiers import DeepClassifier
from classification.loaders import load_data
from classification.open_bci_loaders import OpenBCIDataset,OpenBCISubject,load_files
from models.unet.eeg_unets import UnetConfig,Unet,BottleNeckClassifier
import lightning as L
from lightning import Fabric
from pytorch_lightning.utilities.model_summary import ModelSummary

# Let float32 matmuls use a lower internal precision: 'medium' trades some
# accuracy for speed on hardware that supports it (see the PyTorch docs).
torch.set_float32_matmul_precision('medium')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from models.unet import base_eegnet

In [3]:
# Report the working directory, since every data path below is relative to it.
print(f"path {os.getcwd()}")

files = load_files("../../data/collected_data/")

# Two entries per split list; both entries refer to the same inner list.
train_split = [["train"]] * 2
test_split = [["test"]] * 2

save_path = os.path.join("processed", "raw")
csp_save_path = os.path.join("processed", "data/collected_data/csp")


def _make_dataset(splits):
    """Build an OpenBCIDataset for ``splits`` using the settings common to train and test.

    Every call constructs fresh argument objects (channel index array, path
    list, channel-name list), so the two datasets do not share mutable inputs.
    """
    return OpenBCIDataset(
        subject_splits=splits,
        dataset=files,
        save_paths=[csp_save_path],
        fake_data=None,
        dataset_type=OpenBCISubject,
        channels=np.arange(0, 2),
        subject_channels=["ch2", "ch5"],
        stride=25,
        epoch_length=512,
    )


train_csp_dataset = _make_dataset(train_split)
test_csp_dataset = _make_dataset(test_split)

path d:\Machine learning\MI SSL\motor-imagery-classification-2024\experiments\supervised_clf
Saving new data
(1984, 2, 512)
(1984,)
final data shape: (1984, 2, 512)
Saving new data
(992, 2, 512)
(992,)
final data shape: (992, 2, 512)


In [4]:
# Instantiate the EEGNet model that is summarised and trained in the cells below.
# NOTE(review): both arguments are positional and their meaning is not visible here.
# 224 appears to be the input width of out_proj (its 450 params in the summary
# output = 224*2 + 2) — confirm against base_eegnet.EEGNet's signature.
model = base_eegnet.EEGNet(2,224)

In [5]:
# Bare last expression so the notebook renders the per-layer parameter table.
ModelSummary(model)

   | Name       | Type        | Params
--------------------------------------------
0  | conv1      | Conv2d      | 2.1 K 
1  | batchnorm1 | BatchNorm2d | 32    
2  | padding1   | ZeroPad2d   | 0     
3  | conv2      | Conv2d      | 260   
4  | batchnorm2 | BatchNorm2d | 8     
5  | pooling2   | MaxPool2d   | 0     
6  | padding2   | ZeroPad2d   | 0     
7  | conv3      | Conv2d      | 516   
8  | batchnorm3 | BatchNorm2d | 8     
9  | pooling3   | MaxPool2d   | 0     
10 | out_proj   | Linear      | 450   
--------------------------------------------
3.3 K     Trainable params
0         Non-trainable params
3.3 K     Total params
0.013     Total estimated model params size (MB)

In [10]:
# Wrap the model in the project's DeepClassifier. No in-memory dataset is
# passed (dataset=None); the save path and splits are the same ones used in
# the dataset-construction cell above.
clf = DeepClassifier(
    model=model,
    # --- data source ---
    save_paths=[csp_save_path],
    train_split=train_split,
    test_split=test_split,
    dataset=None,
    dataset_type=OpenBCIDataset,
    subject_dataset_type=OpenBCISubject,
    # --- channel selection ---
    channels=np.arange(2),
    subject_channels=["ch2", "ch5"],
    # --- windowing ---
    stride=25,
    epoch_length=512,
    index_cutoff=512,
)

Loading saved data
(1984, 2, 512)
(1984,)
final data shape: (1984, 2, 512)
Loading saved data
(992, 2, 512)
(992,)
final data shape: (992, 2, 512)


In [11]:
# Sanity check: display the shape of one batch drawn from the classifier.
clf.sample_batch().shape

torch.Size([32, 2, 512])

In [12]:
# Optimiser hyper-parameters for this run.
lr, weight_decay = 4e-4, 2e-6

# Fabric handles device placement and bf16 mixed precision; requires a CUDA device.
FABRIC = Fabric(accelerator="cuda", precision="bf16-mixed")

# NOTE(review): the positional 32 is ambiguous from this cell alone — the log
# prints "Epoch [k/32]" but the sampled batch size is also 32. Confirm against
# DeepClassifier.fit before promoting it to a named constant.
clf.fit(FABRIC, 32, lr, weight_decay)

Using bfloat16 Automatic Mixed Precision (AMP)


checkpointing
Epoch [1/32], Training Loss: 0.747, Training Accuracy: 50.86%, Validation Loss: 0.734, Validation Accuracy: 46.77%
checkpointing
Epoch [2/32], Training Loss: 0.713, Training Accuracy: 53.88%, Validation Loss: 0.707, Validation Accuracy: 50.40%
checkpointing
Epoch [3/32], Training Loss: 0.719, Training Accuracy: 52.17%, Validation Loss: 0.691, Validation Accuracy: 54.64%
Min loss: 0.69091796875 vs 0.694580078125
Epoch [4/32], Training Loss: 0.694, Training Accuracy: 56.80%, Validation Loss: 0.695, Validation Accuracy: 55.44%
checkpointing
Epoch [5/32], Training Loss: 0.670, Training Accuracy: 60.13%, Validation Loss: 0.686, Validation Accuracy: 55.04%
checkpointing
Epoch [6/32], Training Loss: 0.669, Training Accuracy: 60.58%, Validation Loss: 0.683, Validation Accuracy: 57.66%
checkpointing
Epoch [7/32], Training Loss: 0.655, Training Accuracy: 61.90%, Validation Loss: 0.676, Validation Accuracy: 57.06%
Min loss: 0.676025390625 vs 0.677001953125
Epoch [8/32], Training Los

67.54032258064517