In [15]:
import sys
sys.path.append("../../../motor-imagery-classification-2024/")
import os
import torch
from torch import nn
import optuna
import torchsummary
from datetime import datetime
import json
import numpy as np
from classification.classifiers import DeepClassifier
from classification.loaders import load_data
from classification.open_bci_loaders import OpenBCIDataset,OpenBCISubject,load_files
from models.unet.eeg_unets import UnetConfig,Unet,BottleNeckClassifier
import lightning as L
from lightning import Fabric
from pytorch_lightning.utilities.model_summary import ModelSummary

# Trade float32 matmul precision for speed on GPU ('medium' per the PyTorch docs allows
# lower-precision internal math); results are therefore not bit-exact float32.
torch.set_float32_matmul_precision('medium')

In [16]:
from models.unet import base_eegnet

In [17]:
# Seed the RNGs through Lightning so training runs are repeatable; the returned seed
# is echoed as the cell's output.
L.seed_everything(42)

Seed set to 42


42

In [18]:
# Load the raw OpenBCI recordings and build the windowed train/test datasets.
# On first run this writes the processed cache under `csp_save_path`, which the
# DeepClassifier cells below reload ("Loading saved data").
print(f"path {os.getcwd()}")
files = load_files("../../data/collected_data/")
# One split entry per recording (2 here). Build independent inner lists:
# `2*[["train"]]` would put the *same* list object in both slots.
train_split = [["train"] for _ in range(2)]
test_split = [["test"] for _ in range(2)]
save_path = os.path.join("processed","raw")
csp_save_path = os.path.join("processed","data/collected_data/csp")


def make_openbci_dataset(subject_splits):
	"""Build an OpenBCIDataset for `subject_splits` with this notebook's windowing options.

	Train and test go through the same call, so their channels, stride and window
	length cannot drift apart. Fresh list/array arguments are created per call.
	"""
	return OpenBCIDataset(
		subject_splits=subject_splits,
		dataset=files,
		save_paths=[csp_save_path],
		fake_data=None,
		dataset_type=OpenBCISubject,
		channels=np.arange(0,2),
		subject_channels=["ch2","ch5"],
		stride=25,
		epoch_length=512,
	)


train_csp_dataset = make_openbci_dataset(train_split)
test_csp_dataset = make_openbci_dataset(test_split)

path d:\Machine learning\MI SSL\motor-imagery-classification-2024\experiments\supervised_clf
Saving new data
(1984, 2, 512)
(1984,)
final data shape: (1984, 2, 512)
Saving new data
(992, 2, 512)
(992,)
final data shape: (992, 2, 512)


In [19]:
# Baseline EEGNet classifier.
# NOTE(review): the meaning of the positional args (2, 224) is not visible here. The
# summary below lists out_proj with 450 params, consistent with Linear(224, 2), i.e.
# 2 output classes and 224 flattened features — confirm against base_eegnet.EEGNet.
model = base_eegnet.EEGNet(2,224)

In [20]:
# Layer-by-layer parameter count of the EEGNet baseline.
ModelSummary(model)

   | Name       | Type        | Params
--------------------------------------------
0  | conv1      | Conv2d      | 2.1 K 
1  | batchnorm1 | BatchNorm2d | 32    
2  | padding1   | ZeroPad2d   | 0     
3  | conv2      | Conv2d      | 260   
4  | batchnorm2 | BatchNorm2d | 8     
5  | pooling2   | MaxPool2d   | 0     
6  | padding2   | ZeroPad2d   | 0     
7  | conv3      | Conv2d      | 516   
8  | batchnorm3 | BatchNorm2d | 8     
9  | pooling3   | MaxPool2d   | 0     
10 | out_proj   | Linear      | 450   
--------------------------------------------
3.3 K     Trainable params
0         Non-trainable params
3.3 K     Total params
0.013     Total estimated model params size (MB)

In [21]:
# Wrap the EEGNet baseline in the project's training harness.
# No in-memory recordings are passed (dataset=None). The output shows the harness
# reloading the windows cached by the dataset cell, so the windowing options here
# must mirror that cell (TODO confirm that a mismatch would trigger a rebuild).
clf = DeepClassifier(
	model=model,
	dataset=None,
	dataset_type=OpenBCIDataset,
	subject_dataset_type=OpenBCISubject,
	save_paths=[csp_save_path],
	train_split=train_split,
	test_split=test_split,
	# windowing options — same values as the dataset cell
	channels=np.arange(0, 2),
	subject_channels=["ch2", "ch5"],
	stride=25,
	epoch_length=512,
	index_cutoff=512,
)

Loading saved data
(1984, 2, 512)
(1984,)
final data shape: (1984, 2, 512)
Loading saved data
(992, 2, 512)
(992,)
final data shape: (992, 2, 512)


In [22]:
# Sanity-check the input pipeline: one batch, presumably (batch, channel, time).
# 2 channels and 512 samples match the dataset options.
clf.sample_batch().shape

torch.Size([32, 2, 512])

In [23]:
# Optimiser hyper-parameters, reused by the U-Net run further down.
lr = 4E-4
weight_decay = 2E-6
# One Fabric instance (CUDA, bf16 mixed precision), also reused for the U-Net run.
FABRIC = Fabric(accelerator="cuda",precision="bf16-mixed")
# NOTE(review): positional args — 32 appears to be the epoch count (log prints
# "Epoch [k/32]"); the final printed value looks like test accuracy in %. Confirm
# against DeepClassifier.fit.
clf.fit(FABRIC,32,lr,weight_decay)

Using bfloat16 Automatic Mixed Precision (AMP)


checkpointing
Epoch [1/32], Training Loss: 0.737, Training Accuracy: 52.12%, Validation Loss: 0.700, Validation Accuracy: 52.62%
Min loss: 0.700439453125 vs 0.705322265625
Epoch [2/32], Training Loss: 0.707, Training Accuracy: 54.33%, Validation Loss: 0.705, Validation Accuracy: 52.42%
checkpointing
Epoch [3/32], Training Loss: 0.706, Training Accuracy: 55.34%, Validation Loss: 0.697, Validation Accuracy: 56.05%
checkpointing
Epoch [4/32], Training Loss: 0.681, Training Accuracy: 56.30%, Validation Loss: 0.693, Validation Accuracy: 56.25%
Min loss: 0.693115234375 vs 0.705078125
Epoch [5/32], Training Loss: 0.682, Training Accuracy: 57.91%, Validation Loss: 0.705, Validation Accuracy: 50.40%
Min loss: 0.693115234375 vs 0.6953125
Epoch [6/32], Training Loss: 0.682, Training Accuracy: 58.37%, Validation Loss: 0.695, Validation Accuracy: 53.23%
checkpointing
Epoch [7/32], Training Loss: 0.676, Training Accuracy: 57.86%, Validation Loss: 0.692, Validation Accuracy: 52.02%
Min loss: 0.691650

65.3225806451613

In [43]:
# 1-D U-Net over the two EEG channels, with a max-pooled bottleneck classification head.
# Each conv keeps sequence length (stride 1, padding = kernel // 2); each pool halves it.
unet_kernel = 7
config = UnetConfig(
	# NOTE(review): dataset windows are 512 samples (epoch_length=512); confirm 256 is intended.
	input_shape=256,
	input_channels=2,
	starting_channels=16,
	max_channels=64,
	conv_op=nn.Conv1d,
	norm_op=nn.InstanceNorm1d,
	non_lin=nn.ReLU,
	pool_op=nn.MaxPool1d,
	up_op=nn.ConvTranspose1d,
	conv_group=1,
	conv_kernel=unet_kernel,
	conv_padding=unet_kernel // 2,
	conv_pdrop=0.25,
	residual=True,
	pool_fact=2,
	deconv_group=1,
	deconv_kernel=2,
	deconv_stride=2,
	deconv_padding=0,
)

# Presumably [64] is the bottleneck width (== max_channels above); confirm in BottleNeckClassifier.
classifier = BottleNeckClassifier([64], pool="max")

unet = Unet(config, classifier)
unet.to("cuda")


Unet(
  (encoder): ModuleList(
    (0): Encode(
      (convdown): Convdown(
        (c1): Conv1d(2, 16, kernel_size=(7,), stride=(1,), padding=(3,))
        (c2): Conv1d(16, 16, kernel_size=(7,), stride=(1,), padding=(3,))
        (drop): Dropout(p=0.25, inplace=False)
        (instance_norm): InstanceNorm1d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (non_lin): ReLU()
      )
      (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Encode(
      (convdown): Convdown(
        (c1): Conv1d(16, 32, kernel_size=(7,), stride=(1,), padding=(3,))
        (c2): Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(3,))
        (drop): Dropout(p=0.25, inplace=False)
        (instance_norm): InstanceNorm1d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (non_lin): ReLU()
      )
      (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Enc

In [44]:
# Parameter breakdown of the U-Net, for comparison with the EEGNet summary above.
ModelSummary(unet)

  | Name          | Type                 | Params
-------------------------------------------------------
0 | encoder       | ModuleList           | 170 K 
1 | decoder       | ModuleList           | 315 K 
2 | auxiliary_clf | BottleNeckClassifier | 130   
3 | middle_conv   | Convdown             | 57.5 K
4 | output_conv   | Conv1d               | 34    
-------------------------------------------------------
543 K     Trainable params
0         Non-trainable params
543 K     Total params
2.176     Total estimated model params size (MB)

In [45]:
# Same training harness as the EEGNet baseline, now wrapping the U-Net.
# dataset=None again; the output shows the cached windows being reloaded. Keep the
# windowing options identical to the dataset cell so both models see the same data.
unet_clf = DeepClassifier(
	model=unet,
	dataset=None,
	dataset_type=OpenBCIDataset,
	subject_dataset_type=OpenBCISubject,
	save_paths=[csp_save_path],
	train_split=train_split,
	test_split=test_split,
	# windowing options — same values as the dataset cell
	channels=np.arange(0, 2),
	subject_channels=["ch2", "ch5"],
	stride=25,
	epoch_length=512,
	index_cutoff=512,
)

Loading saved data
(1984, 2, 512)
(1984,)
final data shape: (1984, 2, 512)
Loading saved data
(992, 2, 512)
(992,)
final data shape: (992, 2, 512)


In [46]:
# Train the U-Net with the same Fabric, epoch count and optimiser settings as the
# EEGNet run, for a like-for-like comparison.
unet_clf.fit(FABRIC,32,lr,weight_decay)

checkpointing
Epoch [1/32], Training Loss: 0.737, Training Accuracy: 50.96%, Validation Loss: 0.728, Validation Accuracy: 50.40%
checkpointing
Epoch [2/32], Training Loss: 0.724, Training Accuracy: 50.30%, Validation Loss: 0.705, Validation Accuracy: 53.23%
Min loss: 0.705322265625 vs 0.731201171875
Epoch [3/32], Training Loss: 0.715, Training Accuracy: 51.66%, Validation Loss: 0.731, Validation Accuracy: 48.99%
Min loss: 0.705322265625 vs 0.729248046875
Epoch [4/32], Training Loss: 0.708, Training Accuracy: 51.26%, Validation Loss: 0.729, Validation Accuracy: 47.78%
Min loss: 0.705322265625 vs 0.728759765625
Epoch [5/32], Training Loss: 0.709, Training Accuracy: 51.21%, Validation Loss: 0.729, Validation Accuracy: 46.77%
Min loss: 0.705322265625 vs 0.708740234375
Epoch [6/32], Training Loss: 0.697, Training Accuracy: 54.03%, Validation Loss: 0.709, Validation Accuracy: 49.80%
Min loss: 0.705322265625 vs 0.718994140625
Epoch [7/32], Training Loss: 0.693, Training Accuracy: 53.73%, Vali

61.08870967741935