In [1]:
import numpy as np
import pandas as pd
import os
import scipy.io
from scipy.signal import butter, filtfilt, iirnotch, cheby2
from einops import rearrange
import matplotlib.pyplot as plt
import seaborn as sns
# import pywt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from lightning import Fabric
import yaml
import lightning
from pytorch_lightning.utilities.model_summary import ModelSummary
from einops import repeat
from typing import Tuple



In [2]:
import sys

sys.path.append("../../")

from classification.loaders import EEGDataset,load_data
from models.unet.eeg_unets import Unet,UnetConfig, BottleNeckClassifier, Unet1D
from classification.classifiers import DeepClassifier
from classification.loaders import subject_dataset, CSP_subject_dataset
from ntd.networks import LongConv
from ntd.diffusion_model import Diffusion
from ntd.utils.kernels_and_diffusion_utils import WhiteNoiseProcess
import json

In [3]:
# --- Notebook-wide configuration ---------------------------------------------
FS = 250  # sampling-frequency constant; presumably Hz -- TODO confirm
sns.set_style("darkgrid")

DATA_PATH = "../../data/2b_iv"
SAVE_PATH = "../../saved_models/raw_eeg"
CONF_PATH = "../diffusion/conf"

# exist_ok=True creates the directory only when it is missing and avoids the
# check-then-create race of a separate os.path.isdir() test.
os.makedirs(SAVE_PATH, exist_ok=True)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


In [4]:
def _load_yaml_conf(file_name):
    """Parse one YAML file located in CONF_PATH and return the loaded object."""
    with open(os.path.join(CONF_PATH, file_name), "r") as f:
        return yaml.safe_load(f)


# One configuration object per YAML file in the conf directory.
train_yaml = _load_yaml_conf("train.yaml")
classifier_yaml = _load_yaml_conf("classifier.yaml")
network_yaml = _load_yaml_conf("network.yaml")
diffusion_yaml = _load_yaml_conf("diffusion.yaml")


In [5]:
# Load the stored parameter set; the path is relative to the working directory.
params_file = "params_2024_05_06_03_41.json"
with open(params_file, "r") as f:
    best_params = json.load(f)

In [65]:
# Load the train/test objects returned by load_data for subjects 1-9.
# NOTE(review): `dataset` is not handed to the EEGDataset calls below (they
# receive dataset=None and read from REAL_DATA) -- confirm it is still needed.
dataset = {}
for i in range(1, 10):
    mat_train, mat_test = load_data(DATA_PATH, i)
    dataset[f"subject_{i}"] = {"train": mat_train, "test": mat_test}

REAL_DATA = f"{DATA_PATH}/raw"

# Nine split entries (presumably one per subject, in subject order -- TODO
# confirm). The first six entries put both "train" and "test" into the training
# set; the last three put "train" into training and "test" into the test set.
TRAIN_SPLIT = 6 * [["train", "test"]] + 3 * [["train"]]
TEST_SPLIT = 6 * [[]] + 3 * [["test"]]

CHANNELS = [0, 1, 2]

# Window length handed to EEGDataset; 2.05 produced 512-sample windows in the
# recorded run (presumably seconds at FS = 250 -- TODO confirm).
WINDOW_LENGTH = 2.05

train_dataset = EEGDataset(
    subject_splits=TRAIN_SPLIT,
    dataset=None,
    save_paths=[REAL_DATA],
    dataset_type=subject_dataset,
    channels=CHANNELS,
    sanity_check=False,
    length=WINDOW_LENGTH,
)

# dataset_type is passed explicitly so the test set is built the same way as
# the training set instead of relying on EEGDataset's default.
test_dataset = EEGDataset(
    subject_splits=TEST_SPLIT,
    dataset=None,
    save_paths=[REAL_DATA],
    dataset_type=subject_dataset,
    channels=CHANNELS,
    sanity_check=False,
    length=WINDOW_LENGTH,
)

# Size the network from the data: axis 1 is used as the channel count and the
# last axis as the signal length.
train_shape = train_dataset.data[0].shape
print(train_shape)
network_yaml["signal_length"] = train_shape[-1]
network_yaml["signal_channel"] = train_shape[1]
print(network_yaml["signal_length"])

(4560, 3, 512)
(4560,)
final data shape: (4560, 3, 512)
(707, 3, 512)
(707,)
final data shape: (707, 3, 512)
(4560, 3, 512)
512


In [92]:
# --- Optimisation settings ---
lr = 6e-4
num_epochs = 180

# --- Network hyper-parameters (three of them are read from best_params) ---
time_dim = 12
hidden_channel = best_params["hidden_channel"]
kernel_size = best_params["kernel_size"]
num_scales = best_params["num_scales"]
decay_min = 2
decay_max = 2
activation_type = "leaky_relu"

# The FFT convolution is enabled once kernel_size * 2**(num_scales - 1)
# reaches 100.
scaled_kernel_size = kernel_size * 2 ** (num_scales - 1)
use_fft_conv = scaled_kernel_size >= 100

# --- Diffusion schedule ---
num_timesteps = 250
schedule = "linear"
# If the schedule is not cosine, we need to test the end_beta
start_beta = 0.0001
end_beta = 0.08

In [93]:
# Denoising network: sizes come from network_yaml, the remaining settings from
# the hyper-parameter cell.
network = LongConv(
    # Signal dimensions as stored in network_yaml.
    signal_length=network_yaml["signal_length"],
    signal_channel=network_yaml["signal_channel"],
    # The cond channel will contain the cue (0 or 1)
    cond_channel=network_yaml["cond_channel"],
    heads=network_yaml["heads"],
    time_dim=time_dim,
    hidden_channel=hidden_channel,
    # One kernel size is shared by the input, output and SLConv layers.
    in_kernel_size=kernel_size,
    out_kernel_size=kernel_size,
    slconv_kernel_size=kernel_size,
    num_scales=num_scales,
    decay_min=decay_min,
    decay_max=decay_max,
    activation_type=activation_type,
    use_fft_conv=use_fft_conv,
)

# The same WhiteNoiseProcess instance is used both as the noise sampler and as
# the mal_dist_computer of the diffusion model.
noise_sampler = WhiteNoiseProcess(1.0, network_yaml["signal_length"])

diffusion_model = Diffusion(
    network=network,
    noise_sampler=noise_sampler,
    mal_dist_computer=noise_sampler,
    diffusion_time_steps=num_timesteps,
    schedule=schedule,
    start_beta=start_beta,
    end_beta=end_beta,
)

In [94]:
# map_location="cpu" lets the checkpoint deserialize even if it was saved from
# a GPU and none is available here; load_state_dict then copies the tensors
# into the model's existing parameters.
diffusion_model.load_state_dict(
    torch.load(os.path.join(SAVE_PATH, "best_model.pt"), map_location="cpu")
)

<All keys matched successfully>

In [95]:
class ClassificationHead(lightning.LightningModule):
    """MLP head that maps a feature tensor to two class logits.

    The stack is ``Linear(channels[i], channels[i + 1])`` followed by ``ReLU``
    for every consecutive pair in ``channels``, then a final
    ``Linear(channels[-1], 2)``.

    Args:
        channels: Layer widths; ``channels[0]`` is the input feature size.
        pool: Optional callable that reduces the input to one feature vector
            per sample. Its output is flattened from axis 1 onwards before the
            MLP. When ``None`` (the default) the last index along the final
            axis of the input is used instead.
    """

    def __init__(self,
                 channels: Tuple[int, ...],
                 pool=None) -> None:
        super().__init__()

        self.mlp = nn.ModuleList()
        for in_features, out_features in zip(channels[:-1], channels[1:]):
            self.mlp.append(nn.Linear(in_features, out_features))
            self.mlp.append(nn.ReLU())
        # Final projection onto the two output logits.
        self.mlp.append(nn.Linear(channels[-1], 2))
        self.pool = pool

    def forward(self, x):
        """Reduce ``x`` to one vector per sample and run it through the MLP."""
        if self.pool is None:
            # Default reduction: keep only the last position of the final axis.
            x = x[..., -1]
        else:
            # Previously `pool` was stored but silently ignored; apply it and
            # flatten whatever trailing axes it leaves (e.g. a size-1 axis).
            x = self.pool(x).flatten(1)
        for layer in self.mlp:
            x = layer(x)
        return x

In [96]:
class DiffusionClf(lightning.LightningModule):
    """Classifier that reuses a diffusion model's network as feature extractor.

    Args:
        model: Object exposing a ``network`` attribute. That network must
            provide ``time_embbeder`` and a sliceable ``conv_pool``.
        clf: Module applied to the extracted features in :meth:`classify`.
        freeze: When ``True``, every parameter of the network gets
            ``requires_grad = False``.
    """

    def __init__(self,
                 model,
                 clf,
                 freeze=True):

        super().__init__()
        self.model = model.network
        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False
        self.clf = clf

    def forward(self,
                x):
        """Return the output of all ``conv_pool`` stages except the last one.

        ``x`` is extended along axis 1 with a constant all-ones conditioning
        channel and with the time embedding of step 0, repeated along the
        last axis.
        """
        signal_length = x.shape[-1]

        # Allocate the auxiliary tensors on the input's device so they can
        # always be concatenated with `x`, however the module itself was moved.
        cond = torch.ones((x.shape[0], 1, signal_length), device=x.device)
        x = torch.cat([x, cond], 1)

        # Diffusion step 0 for every sample in the batch.
        t = torch.zeros(len(x), device=x.device)
        time_embed = self.model.time_embbeder(t)
        time_embed = repeat(time_embed, "b t -> b t l", l=signal_length)
        x = torch.cat([x, time_embed], 1)

        # Skip the final conv_pool stage and return the features before it.
        return self.model.conv_pool[0:-1](x)

    def classify(self, x):
        """Extract features with :meth:`forward` and pass them to ``clf``."""
        features = self.forward(x)
        return self.clf(features)

In [100]:
# Layer widths 96 -> 2048 -> 1024; the class appends the final 2-logit layer.
head = ClassificationHead(channels=(96, 2048, 1024))

In [101]:
# freeze=False leaves the diffusion network's parameters trainable.
clf = DiffusionClf(model=diffusion_model, clf=head, freeze=False)

In [102]:
# Show the parameter counts of the combined classifier's submodules.
ModelSummary(clf)

  | Name  | Type               | Params
---------------------------------------------
0 | model | LongConv           | 429 K 
1 | clf   | ClassificationHead | 2.3 M 
---------------------------------------------
2.7 M     Trainable params
0         Non-trainable params
2.7 M     Total params
10.914    Total estimated model params size (MB)

In [103]:
# Random batch shaped like the data (16 samples, 3 channels, 512 time steps),
# used for a forward-pass smoke test.
x = torch.rand(16, 3, 512, device=clf.device)

In [104]:
# Check the output shape: one pair of logits per sample in the batch.
clf.classify(x).shape


torch.Size([16, 2])

In [105]:
# Inspect the raw logits produced with the freshly initialised head.
clf.classify(x)

tensor([[0.0821, 0.0669],
        [0.0717, 0.0726],
        [0.0799, 0.0768],
        [0.0814, 0.0790],
        [0.0716, 0.0674],
        [0.0899, 0.0723],
        [0.0714, 0.0618],
        [0.0791, 0.0774],
        [0.0791, 0.0589],
        [0.0927, 0.0806],
        [0.0801, 0.0686],
        [0.0771, 0.0742],
        [0.0952, 0.0835],
        [0.0860, 0.0671],
        [0.0871, 0.0705],
        [0.0830, 0.0720]], grad_fn=<AddmmBackward0>)

In [106]:
# Locations of the generated samples for the two classes.
# NOTE(review): these paths start at "../saved_models" whereas SAVE_PATH points
# to "../../saved_models/raw_eeg" -- confirm which directory is intended.
generated_dir = "../saved_models/raw_eeg"
ones = f"{generated_dir}/generated_ones.npy"
zeros = f"{generated_dir}/generated_zeros.npy"
fake_paths = (ones, zeros)

In [107]:
# Wrap the classifier (moved to DEVICE) together with the real-data splits; no
# generated data is used (fake_data=None).
# NOTE(review): the recorded run loaded 500-sample windows here, while the
# dataset used to size the network had 512 samples -- confirm whether a window
# length / channel selection should be forwarded as well.
slc_clf = DeepClassifier(
    model=clf.to(DEVICE),
    dataset=None,
    dataset_type=subject_dataset,
    save_paths=["../../data/2b_iv/raw/"],
    train_split=TRAIN_SPLIT,
    test_split=TEST_SPLIT,
    fake_data=None,
)

(4560, 3, 500)
(4560,)
final data shape: (4560, 3, 500)
(707, 3, 500)
(707,)
final data shape: (707, 3, 500)


In [108]:
# Follow DEVICE instead of hard-coding "cuda", so the notebook also runs on a
# host where the configuration cell fell back to the CPU.
fabric = Fabric(accelerator=DEVICE.type, precision="bf16-mixed")

Using bfloat16 Automatic Mixed Precision (AMP)


In [109]:
# Two parameter groups with the same weight decay: the diffusion network
# (weights loaded from the checkpoint) gets the smaller learning rate, the
# classification head the larger one.
param_groups = [
    {
        "params": slc_clf.model.model.parameters(),
        "lr": 2e-5,
        "weight_decay": 1e-4,
    },
    {
        "params": slc_clf.model.clf.parameters(),
        "lr": 1e-4,
        "weight_decay": 1e-4,
    },
]
optimizer = optim.AdamW(param_groups)

In [110]:
# Train with the pre-built optimizer.
# NOTE(review): the run logs "using specified optimizer", so lr / weight_decay
# given here are presumably ignored in favour of the optimizer's parameter
# groups -- confirm in DeepClassifier.fit.
slc_clf.fit(
    fabric=fabric,
    optimizer=optimizer,
    num_epochs=150,
    lr=1E-3,
    weight_decay=1E-4,
    verbose=True,
)

using specified optimizer
Epoch [1/150], Training Loss: 0.697, Training Accuracy: 50.20%, Validation Loss: 0.690, Validation Accuracy: 50.78%
Epoch [2/150], Training Loss: 0.691, Training Accuracy: 52.74%, Validation Loss: 0.691, Validation Accuracy: 53.47%
Epoch [3/150], Training Loss: 0.688, Training Accuracy: 54.19%, Validation Loss: 0.684, Validation Accuracy: 52.05%
Epoch [4/150], Training Loss: 0.687, Training Accuracy: 55.46%, Validation Loss: 0.683, Validation Accuracy: 53.32%
Epoch [5/150], Training Loss: 0.682, Training Accuracy: 56.01%, Validation Loss: 0.702, Validation Accuracy: 51.06%
Epoch [6/150], Training Loss: 0.681, Training Accuracy: 56.86%, Validation Loss: 0.683, Validation Accuracy: 55.16%
Epoch [7/150], Training Loss: 0.675, Training Accuracy: 57.81%, Validation Loss: 0.679, Validation Accuracy: 55.45%
Epoch [8/150], Training Loss: 0.674, Training Accuracy: 57.68%, Validation Loss: 0.692, Validation Accuracy: 52.90%
Epoch [9/150], Training Loss: 0.668, Training 

KeyboardInterrupt: 