In [1]:
import numpy as np
import pandas as pd
import os
import scipy.io
from einops import rearrange
import matplotlib.pyplot as plt
import seaborn as sns
# import pywt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from lightning import Fabric
from pytorch_lightning.utilities.model_summary import ModelSummary
import lightning as L
from einops import repeat



In [2]:
import sys

# Make the project's packages (classification, models, ntd, ...) importable.
# The path is relative, so it is resolved against the kernel's current working directory.
sys.path.append("../../../motor-imagery-classification-2024/")

# Project-local data loaders, classifiers and U-Net building blocks.
from classification.loaders import EEGDataset,load_data
from models.unet.eeg_unets import Unet,UnetConfig, BottleNeckClassifier, Unet1D
from classification.classifiers import DeepClassifier , SimpleCSP
from classification.loaders import subject_dataset
from classification.open_bci_loaders import OpenBCIDataset,OpenBCISubject, load_files
# Diffusion utilities from the `ntd` package.
from ntd.networks import SinusoidalPosEmb
from ntd.diffusion_model import Diffusion
from ntd.utils.kernels_and_diffusion_utils import WhiteNoiseProcess

from u_net_diffusion import DiffusionUnet, DiffusionUnetConfig
from models.unet import base_eegnet
# Global PyTorch setting: allow lower-precision internal float32 matmuls
# (trades some numerical precision for speed; see the PyTorch docs for 'medium').
torch.set_float32_matmul_precision('medium')

In [3]:
# Notebook-wide configuration.
FS = 256  # presumably the EEG sampling rate in Hz — TODO confirm
DEVICE = "cuda"  # device string used when moving tensors manually
sns.set_style("darkgrid")

# Directory where trained weights are written by the training cell.
SAVE_PATH = "../../saved_models/EEGNet"
# exist_ok=True creates the directory tree if needed and is a no-op when it
# already exists, avoiding the check-then-create race of isdir() + makedirs().
os.makedirs(SAVE_PATH, exist_ok=True)


In [4]:
print(f"path {os.getcwd()}")

# Recordings are read once and shared by the train and test datasets.
files = load_files("../../data/collected_data/")

# One split specification per entry; the same inner list object is repeated twice.
train_split = 2*[["train"]]
test_split = 2*[["test"]]

save_path = os.path.join("processed","raw")
csp_save_path = os.path.join("processed","data/collected_data/csp")


def _build_csp_dataset(subject_splits):
	"""Build an OpenBCIDataset for the given split specification.

	All settings other than `subject_splits` are identical for the train and
	test datasets: two channels ("ch2", "ch5"), windows of 512 samples taken
	with a stride of 25.

	NOTE(review): both datasets are given the same `save_paths` entry — confirm
	that the loader does not let one split overwrite the other's cached data.
	"""
	return OpenBCIDataset(
		subject_splits=subject_splits,
		dataset=files,
		save_paths=[csp_save_path],
		fake_data=None,
		dataset_type=OpenBCISubject,
		channels=np.arange(2),
		subject_channels=["ch2","ch5"],
		stride=25,
		epoch_length=512,
	)


train_csp_dataset = _build_csp_dataset(train_split)
test_csp_dataset = _build_csp_dataset(test_split)

path d:\Machine learning\MI SSL\motor-imagery-classification-2024\models\diffusion
Saving new data
(1984, 2, 512)
(1984,)
final data shape: (1984, 2, 512)
Saving new data
(992, 2, 512)
(992,)
final data shape: (992, 2, 512)


In [5]:
# Instantiate the EEG U-Net and print a per-layer parameter summary.
# NOTE(review): the meaning of the positional arguments (8, 8, 2, 512) is defined
# by base_eegnet.EEGUNet, which is not visible in this notebook; 2 and 512
# presumably match the channel count and epoch_length of the datasets — TODO confirm.
model = base_eegnet.EEGUNet(8,8,2,512)
print(ModelSummary(model))

   | Name        | Type             | Params
--------------------------------------------------
0  | time_embed  | SinusoidalPosEmb | 0     
1  | class_embed | Embedding        | 24    
2  | conv1       | Conv2d           | 2.1 K 
3  | conv2       | Conv2d           | 400   
4  | pooling2    | MaxPool2d        | 0     
5  | embed2      | Embed            | 288   
6  | embed3      | Embed            | 576   
7  | conv3       | Conv2d           | 724   
8  | bottle_conv | Conv1d           | 1.3 K 
9  | pooling3    | MaxPool2d        | 0     
10 | out_proj    | Linear           | 1.0 K 
11 | decode1     | DecoderBlock     | 13.4 K
12 | decode2     | DecoderBlock     | 52.5 K
13 | out_conv    | Conv1d           | 65    
--------------------------------------------------
72.4 K    Trainable params
0         Non-trainable params
72.4 K    Total params
0.290     Total estimated model params size (MB)


In [6]:
# --- Optimisation settings ---
lr = 6e-4
batch_size = 64
num_epochs = 180

# --- Network-related settings (not referenced by the cells shown here) ---
time_dim = 12
decay_min = 2
decay_max = 2
activation_type = "leaky_relu"

# --- Diffusion noise schedule ---
num_timesteps = 1000
schedule = "linear"
# If the schedule is not cosine, we need to test the end_beta
start_beta = 0.0001
end_beta = 0.08

# Module-level loader over the training windows (default ordering, no shuffling).
train_loader = DataLoader(train_csp_dataset, batch_size=batch_size)

In [7]:
def train(fabric, unet, train_dataset, num_epochs):
	"""Train a conditional diffusion model around `unet` and save its weights.

	Hyper-parameters that are not arguments (`num_timesteps`, `schedule`,
	`start_beta`, `end_beta`, `lr`, `batch_size`, `SAVE_PATH`) are read from
	notebook-level globals.

	Args:
		fabric: Lightning ``Fabric`` instance used for device placement,
			autocast and the backward pass.
		unet: Denoising network handed to ``Diffusion`` as ``network``.
		train_dataset: Dataset whose batches unpack to ``(signal, cue)``.
		num_epochs: Maximum number of passes over ``train_dataset``.

	Returns:
		The Fabric-wrapped diffusion model after training.

	Side effects:
		Prints the mean loss per epoch and writes ``unet_diff.pt`` and
		``unet_state_dict.pt`` into ``SAVE_PATH`` once training finishes.
	"""
	# Length (in samples) of the noise drawn by the white-noise process; must
	# match the window length of the signals in `train_dataset`.
	signal_length = 512
	noise_sampler = WhiteNoiseProcess(1.0, signal_length)

	diffusion_model = Diffusion(
		network=unet,
		diffusion_time_steps=num_timesteps,
		noise_sampler=noise_sampler,
		mal_dist_computer=noise_sampler,
		schedule=schedule,
		start_beta=start_beta,
		end_beta=end_beta,
	)
	optimizer = optim.AdamW(
		unet.parameters(),
		lr=lr,
	)

	# Shuffle every epoch: the windows are produced in order, so without
	# shuffling each mini-batch would contain highly correlated samples.
	train_loader = DataLoader(
		train_dataset,
		batch_size,
		shuffle=True,
	)

	diffusion_model, optimizer = fabric.setup(diffusion_model, optimizer)
	train_loader = fabric.setup_dataloaders(train_loader)

	# Use the device Fabric manages instead of a separately configured global,
	# so tensors always end up where the model lives.
	device = fabric.device

	loss_per_epoch = []
	best_loss = float("inf")

	# Early stopping: abort once the epoch loss has been at least
	# `min_delta` (relative) worse than the best loss for more than
	# `tolerance` consecutive epochs.
	stop_counter = 0
	min_delta = 0.05
	tolerance = 30

	for i in range(num_epochs):

		batch_losses = []
		for batch in train_loader:

			with fabric.autocast():
				signal, cue = batch
				# Shift labels by one (presumably so that 0 stays free for the
				# unconditional case — TODO confirm) and match the signal dtype.
				cue = (cue + 1).to(signal.dtype)
				# Repeat the cue along time to match the signal length:
				# (batch,) -> (batch, 1, length).
				cond = cue.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, signal.shape[-1]).to(device)

				loss = diffusion_model.train_batch(signal.to(device), cond=cond,
												   p_uncond=0.15)
			loss = torch.mean(loss)

			batch_losses.append(loss.item())

			fabric.backward(loss)
			optimizer.step()
			optimizer.zero_grad()

		epoch_loss = np.mean(batch_losses)
		loss_per_epoch.append(epoch_loss)
		# Running minimum over all epochs so far, including the current one.
		best_loss = min(best_loss, epoch_loss)
		degradation = epoch_loss - best_loss

		print(f"Epoch {i} loss: {epoch_loss}")

		print(f"diff: {degradation}")

		if degradation >= min_delta * best_loss:
			stop_counter += 1
		else:
			# Reset on any epoch that is close enough to the best loss, so that
			# isolated bad epochs spread over a long run cannot add up and stop
			# training prematurely.
			stop_counter = 0
		if stop_counter > tolerance:
			break

	torch.save(diffusion_model.state_dict(), os.path.join(SAVE_PATH, "unet_diff.pt"))
	torch.save(unet.state_dict(), os.path.join(SAVE_PATH, "unet_state_dict.pt"))
	return diffusion_model

In [8]:
# Shared Fabric handle: runs on CUDA with bfloat16 mixed precision.
FABRIC = Fabric(
	accelerator="cuda",
	precision="bf16-mixed",
)

Using bfloat16 Automatic Mixed Precision (AMP)


In [9]:
def train_fn(fabric):
	"""Entry point handed to Fabric: train `model` on the training windows."""
	# NOTE(review): the epoch budget is fixed at 100 here, not taken from the
	# `num_epochs` hyper-parameter — confirm this is intended.
	return train(fabric, model, train_csp_dataset, 100)


diffusion_model = FABRIC.launch(train_fn)



Epoch 0 loss: 1150.7096774193549
diff: 0.0
Epoch 1 loss: 1096.258064516129
diff: 0.0
Epoch 2 loss: 1070.967741935484
diff: 0.0
Epoch 3 loss: 1052.3870967741937
diff: 0.0
Epoch 4 loss: 1042.3225806451612
diff: 0.0
Epoch 5 loss: 1029.2903225806451
diff: 0.0
Epoch 6 loss: 1021.4193548387096
diff: 0.0
Epoch 7 loss: 1019.3548387096774
diff: 0.0
Epoch 8 loss: 1016.0
diff: 0.0
Epoch 9 loss: 1013.9354838709677
diff: 0.0
Epoch 10 loss: 1013.6774193548387
diff: 0.0
Epoch 11 loss: 1011.3548387096774
diff: 0.0
Epoch 12 loss: 1009.0322580645161
diff: 0.0
Epoch 13 loss: 1007.6129032258065
diff: 0.0
Epoch 14 loss: 1011.0967741935484
diff: 3.48387096774195
Epoch 15 loss: 1008.516129032258
diff: 0.9032258064515872
Epoch 16 loss: 1007.0967741935484
diff: 0.0
Epoch 17 loss: 1007.8709677419355
diff: 0.7741935483870748
Epoch 18 loss: 1006.7096774193549
diff: 0.0
Epoch 19 loss: 1006.8387096774194
diff: 0.12903225806451246
Epoch 20 loss: 1007.8709677419355
diff: 1.1612903225806122
Epoch 21 loss: 1005.8064516