In [1]:
import numpy as np
import pandas as pd
import os
import scipy.io
from scipy.signal import butter, filtfilt, iirnotch, cheby2
from einops import rearrange
import matplotlib.pyplot as plt
import seaborn as sns
# import pywt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from lightning import Fabric



In [2]:
import sys

sys.path.append("../../motor-imagery-classification-2024/")

from loaders import EEGDataset,load_data
from models.unet.eeg_unets import Unet,UnetConfig, BottleNeckClassifier, Unet1D
from classifiers import DeepClassifier
from loaders import subject_dataset

In [3]:
FS = 250
sns.set_style("darkgrid")

In [4]:
dataset = {}
for i in range(1,10):
	mat_train,mat_test = load_data("../data/2b_iv",i)
	dataset[f"subject_{i}"] = {"train":mat_train,"test":mat_test}

In [5]:
classifier = BottleNeckClassifier((2048,1024))
unet_1d = Unet(Unet1D,classifier)

In [6]:
train_split = 6*[["train","test"]] + 3*[["train"]]
test_split = 6*[[]] + 3* [["test"]]

In [7]:
ones = "../saved_models/raw_eeg/generated_ones.npy"
zeros = "../saved_models/raw_eeg/generated_zeros.npy"
fake_paths = (ones,zeros)

In [12]:
deep_clf = DeepClassifier(
	model=unet_1d,
	save_paths=["../data/2b_iv/raw/"],
	fake_data=fake_paths,
	train_split=train_split,
	test_split=test_split,
	dataset=None,
	dataset_type=subject_dataset,
	length=2.05
)

(4560, 3, 512)
(4560,)
we have fake data
final data shape: (8560, 3, 512)
(707, 3, 512)
(707,)
final data shape: (707, 3, 512)


In [13]:
fabric = Fabric(accelerator="cuda",precision="bf16-mixed")

Using bfloat16 Automatic Mixed Precision (AMP)


In [14]:
deep_clf.fit(fabric=fabric,
			 num_epochs=150,
			 lr=1E-3,
			 weight_decay=1E-4,
			 verbose=True)

Epoch [1/150], Training Loss: 0.672, Training Accuracy: 59.05%, Validation Loss: 0.640, Validation Accuracy: 68.03%
Epoch [2/150], Training Loss: 0.662, Training Accuracy: 61.03%, Validation Loss: 0.615, Validation Accuracy: 68.46%
Epoch [3/150], Training Loss: 0.655, Training Accuracy: 61.06%, Validation Loss: 0.607, Validation Accuracy: 65.63%
Epoch [4/150], Training Loss: 0.661, Training Accuracy: 61.31%, Validation Loss: 0.588, Validation Accuracy: 69.59%
Epoch [5/150], Training Loss: 0.654, Training Accuracy: 61.95%, Validation Loss: 0.632, Validation Accuracy: 66.34%
Epoch [6/150], Training Loss: 0.655, Training Accuracy: 62.15%, Validation Loss: 0.611, Validation Accuracy: 68.46%
Epoch [7/150], Training Loss: 0.651, Training Accuracy: 62.08%, Validation Loss: 0.599, Validation Accuracy: 69.17%
Epoch [8/150], Training Loss: 0.652, Training Accuracy: 61.34%, Validation Loss: 0.597, Validation Accuracy: 70.16%
Epoch [9/150], Training Loss: 0.648, Training Accuracy: 61.95%, Validati

79.4908062234795

In [15]:
deep_clf.setup_dataloaders(use_fake=False)

(4560, 3, 512)
(4560,)
final data shape: (4560, 3, 512)
(707, 3, 512)
(707,)
final data shape: (707, 3, 512)


In [16]:
deep_clf.fit(fabric=fabric,
			 num_epochs=150,
			 lr=1E-3,
			 weight_decay=1E-4,
			 verbose=True)

Epoch [1/150], Training Loss: 0.657, Training Accuracy: 62.15%, Validation Loss: 0.648, Validation Accuracy: 61.95%
Epoch [2/150], Training Loss: 0.646, Training Accuracy: 63.05%, Validation Loss: 0.614, Validation Accuracy: 68.60%
Epoch [3/150], Training Loss: 0.624, Training Accuracy: 65.75%, Validation Loss: 0.591, Validation Accuracy: 68.18%
Epoch [4/150], Training Loss: 0.622, Training Accuracy: 64.28%, Validation Loss: 0.580, Validation Accuracy: 70.44%
Epoch [5/150], Training Loss: 0.611, Training Accuracy: 66.62%, Validation Loss: 0.652, Validation Accuracy: 67.89%
Epoch [6/150], Training Loss: 0.615, Training Accuracy: 66.12%, Validation Loss: 0.598, Validation Accuracy: 69.45%
Epoch [7/150], Training Loss: 0.606, Training Accuracy: 67.17%, Validation Loss: 0.588, Validation Accuracy: 71.29%
Epoch [8/150], Training Loss: 0.595, Training Accuracy: 67.46%, Validation Loss: 0.562, Validation Accuracy: 72.14%
Epoch [9/150], Training Loss: 0.586, Training Accuracy: 68.97%, Validati

78.64214992927865