In [1]:
import numpy as np
import pandas as pd
import os
import scipy.io
from einops import rearrange
import matplotlib.pyplot as plt
import seaborn as sns
# import pywt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from lightning import Fabric
from pytorch_lightning.utilities.model_summary import ModelSummary
import lightning as L
from einops import repeat



In [2]:
import sys

sys.path.append("../../../motor-imagery-classification-2024/")

from classification.loaders import EEGDataset,load_data
from models.unet.eeg_unets import Unet,UnetConfig, BottleNeckClassifier, Unet1D
from classification.classifiers import DeepClassifier , SimpleCSP
from classification.loaders import subject_dataset
from ntd.networks import SinusoidalPosEmb
from ntd.diffusion_model import Diffusion
from ntd.utils.kernels_and_diffusion_utils import WhiteNoiseProcess

In [3]:
FS = 250
DEVICE = "cuda"
sns.set_style("darkgrid")

## Loading Data

In [4]:
dataset = {}
for i in range(1,10):
    mat_train,mat_test = load_data("../../data/2b_iv",i)
    dataset[f"subject_{i}"] = {"train":mat_train,"test":mat_test}

REAL_DATA = "../../data/2b_iv/raw"

TRAIN_SPLIT = 3*[[]] + 6*[["train","test"]] + 3*[[]]
TEST_SPLIT = 3*[["test","train"]]+ 3*[[]] +3*[["test","train"]]
SAVE_PATH = "../../saved_models/unet"
if not os.path.isdir(SAVE_PATH):
	os.makedirs(SAVE_PATH)

CHANNELS = [0,1,2]

train_dataset = EEGDataset(subject_splits=TRAIN_SPLIT,
                    dataset=None,
                    save_paths=[REAL_DATA],
                    dataset_type=subject_dataset,
                    channels=CHANNELS,
                    sanity_check=False,
                    length=2.05)

test_dataset = EEGDataset(subject_splits=TEST_SPLIT,
                    dataset=None,
                    save_paths=[REAL_DATA],
                    channels=CHANNELS,
                    sanity_check=False,
                    length=2.05)

KeyboardInterrupt: 

## Defining diffusion model

In [None]:
@torch.jit.script
def double_inputs(x:torch.Tensor,
				  t:torch.Tensor,
				  cond:torch.Tensor):
	
	x = torch.cat([x,x],0)
	t = torch.cat([t,t],0)
	cond = torch.cat([cond,0*cond],0)
	return x,t,cond

In [None]:
@torch.jit.script
def dedouble_outputs(x:torch.Tensor,
					 w:float):
	
	conditional = x[0:len(x)//2]
	unconditinoal = x[len(x)//2:]
	return (1+w)*conditional-w*unconditinoal

In [None]:
from lightning import LightningDataModule


from models.unet.eeg_unets import (UnetConfig,
								   Encode,
								   Decode,
								   ConvConfig,
								   EncodeConfig,
								   DecodeConfig,
								   Convdown)


class DiffusionUnet(L.LightningModule):

	"""
	base Unet model with adaptable topology and dimension in nnUnet style.
	
	Attributes:
		config: configuration for Unet
	
	"""

	def __init__(self,
				 config: UnetConfig,
				 classifier: L.LightningDataModule,
				 time_dim=12,
				 ):
		
		super().__init__()

		self.time_dim = 12
		self.time_embbeder = SinusoidalPosEmb(time_dim)
		self.class_dim = 1
		self.signal_length = config.input_shape
		self.input_features = [config.starting_channels]
		self.signal_channel = config.input_channels
		size = torch.tensor(config.input_shape)

		self.class_embed_product = nn.Linear(1,config.input_channels,
									   bias=True)
		self.class_embed_addition = nn.Linear(1,config.input_channels,
									   bias=False)

		"""
		possible input shapes:
		1. N x D x L
			starting with N examples D=2 (2 electrodes) with length L (time)
			going to N x 32 x L after 1st layer
			
		2. N X D x C x L
			starting with N examples D = 1 (one feature) with 2 electrodes with lenght L
			going to N x 32 x 2 x L

		3. N x D x C x F x L
			N examples D =1 feature 2 electrodes x frequency x L

		4. N x D x F x L
			we don't do the distinction between channels and features but add a frequency dimension
		"""

		# can't divice 0-d tensor
		get_min = lambda x: min(x) if len(size.shape)>0 else x

		while (get_min(size) > 8) & (len(self.input_features)<=6):
			if 2*self.input_features[-1] < config.max_channels:
				self.input_features.append(2*self.input_features[-1])
			else:
				self.input_features.append(config.max_channels)
			size = size/config.pool_fact

		self.encoder = nn.ModuleList()
		self.decoder = nn.ModuleList()
		self.out_shape = [get_min(size),self.input_features[-1]]

		# input_channels = [32, 64, 128, 256,256....]

		self.auxiliary_clf = classifier

		self.base_conv_config = ConvConfig(
			input_channels=1,
			output_channels=1,
			conv_op=config.conv_op,
			norm_op=config.norm_op,
			non_lin=config.non_lin,
			groups=config.conv_group,
			padding=config.conv_padding,
			kernel_size=config.conv_kernel,
			pool_fact=config.pool_fact,
			pool_op=config.pool_op,
			residual=config.residual,
			p_drop=config.conv_pdrop
		)

		self.base_decode_config = DecodeConfig(
			x_channels=1,
			g_channels=1,
			output_channels=1,
			up_conv=config.up_op,
			groups=config.deconv_group,
			padding=config.deconv_padding,
			kernel_size=config.deconv_kernel,
			stride=config.deconv_stride,
			conv_config=self.base_conv_config
		)

		input_channels = config.input_channels + self.time_dim

		for idx,i in enumerate(self.input_features[:-1]):
			encode_config = self.base_conv_config.new_shapes(input_channels,i)
			self.encoder.append(Encode(encode_config))
			input_channels = i

		bottleneck_conv_config = self.base_conv_config.new_shapes(
			input_channels=self.input_features[-2],
			ouput_channels=self.input_features[-1]
		)

		self.middle_conv = Convdown(bottleneck_conv_config)

		output_features = self.input_features[::-1]

		for i in range(len(self.input_features)-1):

			decode_config = self.base_decode_config.new_shapes(
				x_channels=output_features[i+1],
				g_channels=output_features[i],
				output_channels=output_features[i+1]
			)

			self.decoder.append(Decode(decode_config))

		self.output_conv = config.conv_op(config.starting_channels,config.input_channels,1)

	def forward(self,
			 x,
			 t,
			 cond):

		"""
		Full U-net forward pass to get the reconstructed datas
		"""
		
		factor = repeat(self.class_embed_product(cond[...,0]),"b d -> b d l",l=x.shape[-1])
		bias = repeat(self.class_embed_addition(cond[...,0]),"b d -> b d l",l=x.shape[-1])
		x = x*factor+bias
		if self.time_embbeder is not None:
			time_emb = self.time_embbeder(t)  # (-1, time_dim)
			time_emb_repeat = repeat(time_emb, "b t -> b t l", l=x.shape[2])
			x = torch.cat([x, time_emb_repeat], dim=1)

		skip_connections = []
		for encode in self.encoder:
			x,skip = encode(x)
			skip_connections.append(skip)

		x = self.middle_conv(x)

		for decode,skip in zip(self.decoder,reversed(skip_connections)):
			x = decode(skip,x)

		x = self.output_conv(x)
		return x
	
	def conditional_forward(self,
			 x,
			 t,
			 cond,
			 w):

		"""
		Full U-net forward pass to get the reconstructed datas
		"""

		n = len(x)
		x,t,cond = double_inputs(x,t,cond)
		x = self.forward(x,t,cond)
		x = dedouble_outputs(x,w)
		return x
	
	def classify(self,x):

		cond = torch.zeros((x.shape[0],1,x.shape[-1]),device=x.device)
		factor = repeat(self.class_embed_product(cond[...,0]),"b d -> b d l",l=x.shape[-1])
		bias = repeat(self.class_embed_addition(cond[...,0]),"b d -> b d l",l=x.shape[-1])
		x = x*factor+bias
		t = torch.zeros((len(x)),device=x.device)
		if self.time_embbeder is not None:
			time_emb = self.time_embbeder(t)
			time_emb_repeat = repeat(time_emb, "b t -> b t l", l=x.shape[2])
			x = torch.cat([x, time_emb_repeat], dim=1)

		skip_connections = []
		for encode in self.encoder:
			x,skip = encode(x)
			skip_connections.append(skip)

		x = self.middle_conv(x)
		y = self.auxiliary_clf(x)

		return y

In [None]:
Unet1D = UnetConfig(
	input_shape=(512),
	input_channels=3,
	conv_op=nn.Conv1d,
	norm_op=nn.InstanceNorm1d,
	non_lin=nn.ReLU,
	pool_op=nn.AvgPool1d,
	up_op=nn.ConvTranspose1d,
	starting_channels=32,
	max_channels=256,
	conv_group=1,
	conv_padding=(1),
	conv_kernel=(3),
	pool_fact=2,
	deconv_group=1,
	deconv_padding=(0),
	deconv_kernel=(2),
	deconv_stride=(2),
	residual=True
)

In [None]:
classifier = BottleNeckClassifier((2048,1024))
unet_1d = DiffusionUnet(Unet1D,classifier)

## Training

In [None]:
lr = 6E-4
num_epochs = 180
time_dim = 12
decay_min = 2
decay_max = 2
activation_type = "leaky_relu"
num_timesteps = 250
schedule = "linear"
batch_size = 64
# If the schedule is not cosine, we need to test the end_beta
start_beta = 0.0001
end_beta = 0.08
		
train_loader = DataLoader(
	train_dataset,
	batch_size,
)

In [None]:
unet_1d.signal_length

512

In [None]:
# Optimizer (also testing learning rate here)
def train(fabric,
		  num_epochs):
		
	noise_sampler = WhiteNoiseProcess(1.0, 512)

	diffusion_model = Diffusion(
		network=unet_1d,
		diffusion_time_steps=num_timesteps,
		noise_sampler=noise_sampler,
		mal_dist_computer=noise_sampler,
		schedule=schedule,
		start_beta=start_beta,
		end_beta=end_beta,
	)
	optimizer = optim.AdamW(
		unet_1d.parameters(),
		lr=lr,
	)

	train_loader = DataLoader(
		train_dataset,
		batch_size,
	)
		
	diffusion_model,optimizer = fabric.setup(diffusion_model,optimizer)
	train_loader = fabric.setup_dataloaders(train_loader)

	loss_per_epoch = []

	stop_counter = 0
	min_delta = 0.05
	tolerance = 30
			
		# Train model
	for i in range(num_epochs):
		
		epoch_loss = []
		for batch in train_loader:
			
			with fabric.autocast():
			# Repeat the cue signal to match the signal length
				# print(batch["signal"].shape)
				signal,cue = batch
				cue = (cue + 1).to(signal.dtype)
				cond = cue.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 512).to(DEVICE)
				
				loss = diffusion_model.train_batch(signal.to(DEVICE), cond=cond,
									   p_uncond=0.15)
			loss = torch.mean(loss)
			
			epoch_loss.append(loss.item())
			
			fabric.backward(loss)
			# loss.backward()
			optimizer.step()
			optimizer.zero_grad()
			
		epoch_loss = np.mean(epoch_loss)
		loss_per_epoch.append(epoch_loss)
		
		print(f"Epoch {i} loss: {epoch_loss}")

		print(f"diff: {epoch_loss - min(loss_per_epoch)}")

		if epoch_loss - min(loss_per_epoch) >= min_delta*min(loss_per_epoch):
			stop_counter += 1
		if stop_counter > tolerance:
			break
	torch.save(diffusion_model.state_dict(),os.path.join(SAVE_PATH,"unet_diff.pt"))
	torch.save(unet_1d.state_dict(),os.path.join(SAVE_PATH,"unet_state_dict.pt"))
	return diffusion_model

In [None]:
FABRIC = Fabric(accelerator="cuda",precision="bf16-mixed")

Using bfloat16 Automatic Mixed Precision (AMP)


## Diffusion

In [None]:
train_fn = lambda x:train(x,150)
diffusion_model = FABRIC.launch(train_fn)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Epoch 0 loss: 1400.421052631579
diff: 0.0
Epoch 1 loss: 1122.5263157894738
diff: 0.0
Epoch 2 loss: 909.6140350877193
diff: 0.0
Epoch 3 loss: 790.0350877192982
diff: 0.0
Epoch 4 loss: 720.421052631579
diff: 0.0
Epoch 5 loss: 704.3508771929825
diff: 0.0
Epoch 6 loss: 670.4561403508771
diff: 0.0
Epoch 7 loss: 650.5263157894736
diff: 0.0
Epoch 8 loss: 631.0175438596491
diff: 0.0
Epoch 9 loss: 628.5964912280701
diff: 0.0
Epoch 10 loss: 609.6491228070175
diff: 0.0
Epoch 11 loss: 619.6491228070175
diff: 10.0
Epoch 12 loss: 608.7368421052631
diff: 0.0
Epoch 13 loss: 596.0701754385965
diff: 0.0
Epoch 14 loss: 594.140350877193
diff: 0.0
Epoch 15 loss: 589.7894736842105
diff: 0.0
Epoch 16 loss: 588.1754385964912
diff: 0.0
Epoch 17 loss: 565.8245614035088
diff: 0.0
Epoch 18 loss: 567.6491228070175
diff: 1.82456140350871
Epoch 19 loss: 566.4561403508771
diff: 0.6315789473683253
Epoch 20 loss: 573.0877192982456
diff: 7.263157894736764
Epoch 21 loss: 565.8947368421053
diff: 0.07017543859649322
Epoch 

In [None]:
def generate_samples(fabric,
                     diffusion_model, 
					 condition, 
                     n_iter=20,
                     w=0):
    # it's a bit hard to predict memory consumption so splitting in mini-batches to be safe
    num_samples = 200
    cond = 0
    if (condition == 0):
        cond = (torch.zeros(num_samples, 1, 512)+1).to(dtype=torch.float16,
                                                       device=DEVICE)
    elif (condition == 1):
        cond = (torch.ones(num_samples, 1, 512)+1).to(dtype=torch.float16,
                                                      device=DEVICE)
    
    diffusion_model.eval()

    print(f"Generating samples: cue {condition + 1}")
    complete_samples = []
    with fabric.autocast():
        with torch.no_grad():
            for i in range(n_iter):
                samples, _ = diffusion_model.sample(num_samples, cond=cond,w=w)
                samples = samples.cpu().numpy()
                print(samples.shape)
                complete_samples.append(samples)
    complete_samples = np.float32(np.concatenate(complete_samples))
    print(complete_samples.shape)
    return complete_samples

In [None]:
# diffusion_model.load_state_dict(torch.load(os.path.join(SAVE_PATH,"unet_diff.pt")))

In [None]:
# unet_1d.load_state_dict(torch.load(os.path.join(SAVE_PATH,"unet_state_dict.pt")))

In [None]:
generated_signals_zero = generate_samples(FABRIC,diffusion_model, condition=0,n_iter=10,w=8)
generated_signals_one = generate_samples(FABRIC,diffusion_model, condition=1,n_iter=10,w=8)
np.save(os.path.join(SAVE_PATH,"generated_zeros.npy"),generated_signals_zero)
np.save(os.path.join(SAVE_PATH,"generated_ones.npy"),generated_signals_one)

Generating samples: cue 1
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(2000, 3, 512)
Generating samples: cue 2
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(2000, 3, 512)


In [None]:
generated_signals_zero = np.load(os.path.join(SAVE_PATH,"generated_zeros.npy"))
generated_signals_one = np.load(os.path.join(SAVE_PATH,"generated_ones.npy"))

In [None]:
def check():

	accuracies = []
		
	test_classifier = SimpleCSP(train_split=TRAIN_SPLIT,
								test_split=TEST_SPLIT,
								dataset=None,
								save_paths=[REAL_DATA],
								channels=CHANNELS,
								length=2.05)

	full_x,full_y = test_classifier.get_train()

	print(f"full x shape: {full_x.shape}")

	acc = test_classifier.fit()

	print(f"reaching an accuracy of {acc} without fake data")

	for real_fake_split in range(15, 46, 15):
		
		# Train new classifier with a mix of generated and real data
		
		# Change real_fake_split percent of the test_classifier data to generated signals
		n = int(len(full_x) * real_fake_split / 100)

		shuffling = np.random.permutation(full_x.shape[0])

		split_x = full_x[shuffling]
		split_y = full_y[shuffling]
		split_x[0:n//2] = generated_signals_one[0:n//2]
		split_y[0:n//2] = 1

		split_x[n//2:n] = generated_signals_zero[0:n//2]
		split_y[n//2:n] = 0

		print(f"split x shape: {split_x.shape}")

		acc = test_classifier.fit((split_x,split_y))

		accuracies.append(acc)
					
		print(f"Reaching an accuracy of {acc} using {real_fake_split}% fake data")

In [None]:
x = torch.rand((2,3,512),device=unet_1d.device)

In [None]:
unet_1d.classify(x).shape

torch.Size([2, 2])

## Classification

In [None]:
ones = os.path.join(SAVE_PATH,"generated_ones.npy")
zeros = os.path.join(SAVE_PATH,"generated_zeros.npy")
fake_paths = (ones,zeros)

In [None]:
deep_clf = DeepClassifier(
	model=unet_1d,
	save_paths=["../../data/2b_iv/raw/"],
	fake_data=fake_paths,
	train_split=TRAIN_SPLIT,
	test_split=TEST_SPLIT,
	dataset=None,
	dataset_type=subject_dataset,
	length=2.05,
	index_cutoff=512
)

(3615, 3, 512)
(3615,)
we have fake data
final data shape: (5421, 3, 512)
(3360, 3, 512)
(3360,)
final data shape: (3360, 3, 512)


- We can reach about 78% accuracy without any kind of pre-training or fake data

In [None]:
ModelSummary(unet_1d)

  | Name                 | Type                 | Params
--------------------------------------------------------------
0 | time_embbeder        | SinusoidalPosEmb     | 0     
1 | class_embed_product  | Linear               | 6     
2 | class_embed_addition | Linear               | 3     
3 | encoder              | ModuleList           | 1.2 M 
4 | decoder              | ModuleList           | 2.4 M 
5 | auxiliary_clf        | BottleNeckClassifier | 2.1 M 
6 | middle_conv          | Convdown             | 393 K 
7 | output_conv          | Conv1d               | 99    
--------------------------------------------------------------
6.1 M     Trainable params
0         Non-trainable params
6.1 M     Total params
24.477    Total estimated model params size (MB)

In [None]:
fine_tune = [unet_1d.encoder,
			unet_1d.decoder,
			unet_1d.middle_conv,
			unet_1d.output_conv,
			unet_1d.class_embed_addition,
			unet_1d.class_embed_product]

to_optimize = [{"params":i.parameters(),
	"lr":1E-3,
	"weight_decay":1E-4} for i in fine_tune]

to_optimize.append({"params":unet_1d.auxiliary_clf.parameters(),
     "lr":1E-3,
 	"weight_decay":1E-4})

optimizer = optim.AdamW(to_optimize)

In [None]:
deep_clf.fit(fabric=FABRIC,
			 num_epochs=150,
			 lr=1E-3,
			 weight_decay=1E-4,
			 verbose=True,
			 optimizer=optimizer)

using specified optimizer


Epoch [1/150], Training Loss: 0.734, Training Accuracy: 51.47%, Validation Loss: 0.698, Validation Accuracy: 50.42%
Epoch [2/150], Training Loss: 0.688, Training Accuracy: 54.18%, Validation Loss: 0.704, Validation Accuracy: 50.68%
Epoch [3/150], Training Loss: 0.645, Training Accuracy: 59.53%, Validation Loss: 0.771, Validation Accuracy: 48.96%
Epoch [4/150], Training Loss: 0.593, Training Accuracy: 63.00%, Validation Loss: 0.747, Validation Accuracy: 50.09%
Epoch [5/150], Training Loss: 0.563, Training Accuracy: 64.80%, Validation Loss: 0.708, Validation Accuracy: 50.60%
Epoch [6/150], Training Loss: 0.540, Training Accuracy: 65.60%, Validation Loss: 0.732, Validation Accuracy: 49.61%
Epoch [7/150], Training Loss: 0.527, Training Accuracy: 64.53%, Validation Loss: 0.694, Validation Accuracy: 50.48%
Epoch [8/150], Training Loss: 0.500, Training Accuracy: 66.19%, Validation Loss: 0.710, Validation Accuracy: 50.86%
Epoch [9/150], Training Loss: 0.488, Training Accuracy: 68.29%, Validati

54.970238095238095

In [None]:
deep_clf.setup_dataloaders(use_fake=False)
deep_clf.fit(fabric=FABRIC,
			 num_epochs=150,
			 lr=1E-3,
			 weight_decay=1E-4,
			 verbose=True,
			 optimizer=optimizer)

(3615, 3, 512)
(3615,)
final data shape: (3615, 3, 512)
(3360, 3, 512)
(3360,)
final data shape: (3360, 3, 512)
using specified optimizer
Epoch [1/150], Training Loss: 0.802, Training Accuracy: 49.85%, Validation Loss: 0.695, Validation Accuracy: 50.15%
Epoch [2/150], Training Loss: 0.694, Training Accuracy: 49.35%, Validation Loss: 0.694, Validation Accuracy: 49.61%
Epoch [3/150], Training Loss: 0.694, Training Accuracy: 50.84%, Validation Loss: 0.693, Validation Accuracy: 50.51%
Epoch [4/150], Training Loss: 0.694, Training Accuracy: 50.35%, Validation Loss: 0.693, Validation Accuracy: 50.27%
Epoch [5/150], Training Loss: 0.694, Training Accuracy: 50.51%, Validation Loss: 0.693, Validation Accuracy: 50.09%
Epoch [6/150], Training Loss: 0.693, Training Accuracy: 50.04%, Validation Loss: 0.693, Validation Accuracy: 50.24%
Epoch [7/150], Training Loss: 0.693, Training Accuracy: 49.68%, Validation Loss: 0.693, Validation Accuracy: 50.42%
Epoch [8/150], Training Loss: 0.693, Training Accu

50.80357142857143

In [None]:
54.97-50.80

4.170000000000002

In [None]:
67.97-66.39

1.5799999999999983

In [None]:
60.59-51.75

8.840000000000003

In [None]:
((54.97-50.80)+(67.97-66.39)+60.59-51.75)/3

4.863333333333334

- 67.97 - 66.39

- 83.73% with fine-tuning

- 68.7