In [2]:
import numpy as np
import pandas as pd
import os
import scipy.io
from einops import rearrange
import matplotlib.pyplot as plt
import seaborn as sns
# import pywt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from lightning import Fabric
from pytorch_lightning.utilities.model_summary import ModelSummary
import lightning as L
from einops import repeat



In [3]:
import sys

sys.path.append("../../../motor-imagery-classification-2024/")

from classification.loaders import EEGDataset,load_data
from models.unet.eeg_unets import Unet,UnetConfig, BottleNeckClassifier, Unet1D
from classification.classifiers import DeepClassifier , SimpleCSP
from classification.loaders import subject_dataset
from ntd.networks import SinusoidalPosEmb
from ntd.diffusion_model import Diffusion
from ntd.utils.kernels_and_diffusion_utils import WhiteNoiseProcess

from u_net_diffusion import DiffusionUnet, DiffusionUnetConfig

# "medium" lets float32 matmuls use faster, lower-precision internal math
# (bfloat16 per the PyTorch docs) on hardware that supports it.
torch.set_float32_matmul_precision(precision="medium")

In [3]:
# Seaborn theme applied to every plot drawn afterwards.
sns.set_style("darkgrid")
# Presumably the EEG sampling frequency in Hz (BCI IV 2b is recorded at 250 Hz) — confirm.
FS = 250
# Device that train() / generate_samples() move conditioning tensors and signals to.
DEVICE = "cuda"

## Loading data

In [6]:
# Raw per-subject recordings. Kept for inspection only: the datasets built
# below are given dataset=None and read from REAL_DATA instead.
dataset = {}
for i in range(1, 10):
    mat_train, mat_test = load_data("../../data/2b_iv", i)
    dataset[f"subject_{i}"] = {"train": mat_train, "test": mat_test}

REAL_DATA = "../../data/2b_iv/raw"

# One entry per subject: the first 8 subjects contribute their "train" and
# "test" splits to training, and only the last subject's "test" split is held
# out (presumably index k <-> subject k+1).
# Comprehensions create an independent inner list per subject; `8*[[...]]`
# would make all entries alias the same list object.
TRAIN_SPLIT = [["train", "test"] for _ in range(8)]
TEST_SPLIT = [[] for _ in range(8)] + [["test"]]

SAVE_PATH = "../../saved_models/unet"
os.makedirs(SAVE_PATH, exist_ok=True)

CHANNELS = [0, 1, 2]

train_dataset = EEGDataset(subject_splits=TRAIN_SPLIT,
                           dataset=None,
                           save_paths=[REAL_DATA],
                           dataset_type=subject_dataset,
                           channels=CHANNELS,
                           sanity_check=False,
                           length=2.05)

# NOTE(review): unlike train_dataset, no dataset_type is passed here, so the
# EEGDataset default is used — confirm that is intended.
test_dataset = EEGDataset(subject_splits=TEST_SPLIT,
                          dataset=None,
                          save_paths=[REAL_DATA],
                          channels=CHANNELS,
                          sanity_check=False,
                          length=2.05)

(4705, 3, 512)
(4705,)
final data shape: (4705, 3, 512)
(245, 3, 512)
(245,)
final data shape: (245, 3, 512)


In [5]:
# 1-D U-Net configuration for the diffusion network.
# All "shape-like" sizes are plain ints (this is the Conv1d variant); note
# that a parenthesised literal such as (512) is just 512, not a tuple.
UnetDiff1D = DiffusionUnetConfig(
	# conditioning embeddings
	time_dim=12,
	class_dim=12,
	num_classes=2,
	# input geometry
	input_shape=512,
	input_channels=3,
	# layer types
	conv_op=nn.Conv1d,
	norm_op=nn.InstanceNorm1d,
	non_lin=nn.ReLU,
	pool_op=nn.AvgPool1d,
	up_op=nn.ConvTranspose1d,
	# channel widths
	starting_channels=32,
	max_channels=256,
	# encoder convolutions and pooling
	conv_group=1,
	conv_kernel=3,
	conv_padding=1,
	pool_fact=2,
	# decoder transposed convolutions
	deconv_group=1,
	deconv_kernel=2,
	deconv_stride=2,
	deconv_padding=0,
	residual=True,
)

In [6]:
# Auxiliary classification head attached to the diffusion U-Net.
# NOTE(review): (2048, 1024) are presumably the flattened bottleneck size and
# the hidden width — confirm against BottleNeckClassifier.
classifier = BottleNeckClassifier((2048, 1024))
unet = DiffusionUnet(UnetDiff1D, classifier)

print(ModelSummary(unet))

  | Name          | Type                 | Params
-------------------------------------------------------
0 | time_embed    | SinusoidalPosEmb     | 0     
1 | class_embed   | Embedding            | 36    
2 | encoder       | ModuleList           | 1.2 M 
3 | decoder       | ModuleList           | 2.5 M 
4 | auxiliary_clf | BottleNeckClassifier | 2.1 M 
5 | middle_conv   | EmbedConvdown        | 400 K 
6 | output_conv   | Conv1d               | 99    
-------------------------------------------------------
6.2 M     Trainable params
0         Non-trainable params
6.2 M     Total params
24.706    Total estimated model params size (MB)


## Training

In [7]:
# Diffusion / optimisation hyper-parameters. train() reads lr, batch_size,
# num_timesteps, schedule, start_beta and end_beta as globals.
lr = 6E-4
# NOTE(review): the launch cell calls train(x, 100), so this value is not the
# epoch budget actually used there.
num_epochs = 180
# NOTE(review): time_dim, decay_min, decay_max and activation_type are not read
# by train(); they look like leftovers from another setup.
time_dim = 12
decay_min = 2
decay_max = 2
activation_type = "leaky_relu"
num_timesteps = 250
schedule = "linear"
batch_size = 64
# end_beta only matters for non-cosine schedules (e.g. "linear"); tune it there.
start_beta = 0.0001
end_beta = 0.08

# Shuffled so consecutive batches are not drawn in dataset order. train()
# builds its own loader; this one is handy for inspecting batches.
train_loader = DataLoader(
	train_dataset,
	batch_size,
	shuffle=True,
)

In [8]:
# Sanity check: the network's signal length should match the 512-sample trials loaded above.
unet.signal_length

512

In [9]:
def train(fabric,
		  num_epochs,
		  min_delta=0.05,
		  tolerance=30):
	"""Train the conditional diffusion model that wraps the global ``unet``.

	Relies on notebook globals: ``unet``, ``train_dataset``, ``batch_size``,
	``lr``, ``num_timesteps``, ``schedule``, ``start_beta``, ``end_beta``,
	``DEVICE`` and ``SAVE_PATH``.

	Args:
		fabric: Lightning ``Fabric`` used for device placement, mixed
			precision (``autocast``) and ``backward``.
		num_epochs: maximum number of epochs.
		min_delta: relative margin; an epoch counts as "bad" when its mean
			loss is at least ``min_delta * best_loss`` above the best so far.
		tolerance: training stops once more than ``tolerance`` bad epochs
			have accumulated since the last new best loss.

	Returns:
		The Fabric-wrapped ``Diffusion`` model. Its state dict and the state
		dict of ``unet`` are also written to ``SAVE_PATH``.
	"""
	# Use the network's own signal length instead of a hard-coded 512.
	noise_sampler = WhiteNoiseProcess(1.0, unet.signal_length)

	diffusion_model = Diffusion(
		network=unet,
		diffusion_time_steps=num_timesteps,
		noise_sampler=noise_sampler,
		mal_dist_computer=noise_sampler,
		schedule=schedule,
		start_beta=start_beta,
		end_beta=end_beta,
	)
	optimizer = optim.AdamW(
		unet.parameters(),
		lr=lr,
	)

	# Shuffle so batches are not drawn in the same (presumably
	# subject-by-subject) dataset order every epoch.
	train_loader = DataLoader(
		train_dataset,
		batch_size,
		shuffle=True,
	)

	diffusion_model, optimizer = fabric.setup(diffusion_model, optimizer)
	train_loader = fabric.setup_dataloaders(train_loader)
	# generate_samples() calls .eval() on the model sharing ``unet``; make sure
	# a re-run of this function trains in train mode.
	diffusion_model.train()

	loss_per_epoch = []
	best_loss = float("inf")
	stop_counter = 0

	for i in range(num_epochs):
		epoch_loss = []
		for signal, cue in train_loader:
			with fabric.autocast():
				# Labels are shifted by +1; presumably 0 is reserved for the
				# unconditional token used with p_uncond — TODO confirm.
				cue = (cue + 1).to(signal.dtype)
				# (batch,) -> (batch, 1, signal_length): broadcast the cue over time.
				cond = cue.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, signal.shape[-1]).to(DEVICE)

				loss = diffusion_model.train_batch(signal.to(DEVICE), cond=cond,
												   p_uncond=0.15)
			loss = torch.mean(loss)
			epoch_loss.append(loss.item())

			fabric.backward(loss)
			optimizer.step()
			optimizer.zero_grad()

		epoch_loss = float(np.mean(epoch_loss))
		loss_per_epoch.append(epoch_loss)

		if epoch_loss < best_loss:
			# New best: restart the patience window.
			best_loss = epoch_loss
			stop_counter = 0
		elif epoch_loss - best_loss >= min_delta * best_loss:
			# Clearly worse than the best epoch so far.
			stop_counter += 1

		print(f"Epoch {i} loss: {epoch_loss}")
		print(f"diff: {epoch_loss - best_loss}")

		if stop_counter > tolerance:
			break

	torch.save(diffusion_model.state_dict(), os.path.join(SAVE_PATH, "unet_diff.pt"))
	torch.save(unet.state_dict(), os.path.join(SAVE_PATH, "unet_state_dict.pt"))
	return diffusion_model

In [10]:
# Fabric on CUDA with bfloat16 automatic mixed precision; shared by the
# training launch, sample generation and deep_clf.fit below.
FABRIC = Fabric(precision="bf16-mixed", accelerator="cuda")

Using bfloat16 Automatic Mixed Precision (AMP)


In [11]:
def train_fn(fabric):
	"""Entry point for ``Fabric.launch``: train for at most 100 epochs."""
	# NOTE(review): the hard-coded 100 overrides num_epochs = 180 set in the
	# hyper-parameter cell.
	return train(fabric, 100)


diffusion_model = FABRIC.launch(train_fn)

Epoch 0 loss: 1170.8333333333333
diff: 0.0
Epoch 1 loss: 794.4166666666666
diff: 0.0
Epoch 2 loss: 712.4166666666666
diff: 0.0
Epoch 3 loss: 696.0833333333334
diff: 0.0
Epoch 4 loss: 656.7916666666666
diff: 0.0
Epoch 5 loss: 651.5416666666666
diff: 0.0
Epoch 6 loss: 652.7916666666666
diff: 1.25
Epoch 7 loss: 642.625
diff: 0.0
Epoch 8 loss: 615.7916666666666
diff: 0.0
Epoch 9 loss: 600.3333333333334
diff: 0.0
Epoch 10 loss: 604.75
diff: 4.416666666666629
Epoch 11 loss: 614.9583333333334
diff: 14.625
Epoch 12 loss: 587.6666666666666
diff: 0.0
Epoch 13 loss: 589.1666666666666
diff: 1.5
Epoch 14 loss: 589.75
diff: 2.0833333333333712
Epoch 15 loss: 592.2083333333334
diff: 4.5416666666667425
Epoch 16 loss: 591.2916666666666
diff: 3.625
Epoch 17 loss: 567.0416666666666
diff: 0.0
Epoch 18 loss: 585.625
diff: 18.58333333333337
Epoch 19 loss: 572.0
diff: 4.958333333333371
Epoch 20 loss: 563.9166666666666
diff: 0.0
Epoch 21 loss: 586.7083333333334
diff: 22.791666666666742
Epoch 22 loss: 552.95833

In [12]:
# Restore the checkpoint written at the end of train().
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state_dict contains, and avoids executing arbitrary pickled code.
diffusion_model.load_state_dict(
	torch.load(os.path.join(SAVE_PATH, "unet_diff.pt"), weights_only=True)
)

<All keys matched successfully>

In [13]:
# Restore the U-Net weights saved by train().
# NOTE(review): diffusion_model wraps unet (network=unet), so reloading
# unet_diff.pt presumably restores these weights already.
# weights_only=True limits unpickling to tensors/containers (enough for a state_dict).
unet.load_state_dict(
	torch.load(os.path.join(SAVE_PATH, "unet_state_dict.pt"), weights_only=True)
)

<All keys matched successfully>

In [14]:
def generate_samples(fabric,
                     diffusion_model,
                     condition,
                     batch_size=200,
                     n_iter=20,
                     w=0,
                     signal_length=512):
    """Sample synthetic trials for one class from a trained diffusion model.

    Args:
        fabric: Lightning ``Fabric`` whose ``autocast`` context wraps sampling.
        diffusion_model: trained model exposing ``eval()`` and
            ``sample(num_samples, cond=..., w=...)`` returning a pair whose
            first element is the sample tensor.
        condition: class label, 0 or 1. It is shifted by +1 to match the
            ``cue + 1`` conditioning used during training.
        batch_size: samples drawn per ``sample`` call. Sampling is split into
            mini-batches because memory use is hard to predict.
        n_iter: number of ``sample`` calls; the total is ``batch_size * n_iter``.
        w: forwarded to ``diffusion_model.sample`` (presumably the
            classifier-free guidance weight).
        signal_length: time length of the constant conditioning tensor.

    Returns:
        float32 numpy array with the ``batch_size * n_iter`` samples stacked
        along axis 0.

    Raises:
        ValueError: if ``condition`` is not 0 or 1, or ``n_iter`` < 1.
    """
    # Validate up front: previously an unknown condition silently left
    # cond = 0 (a plain int) and n_iter = 0 crashed in np.concatenate.
    if condition not in (0, 1):
        raise ValueError(f"condition must be 0 or 1, got {condition!r}")
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter!r}")

    num_samples = batch_size
    # Constant conditioning channel: (label + 1) broadcast over the whole signal.
    # NOTE(review): float16 although FABRIC runs bf16-mixed; kept from the original.
    cond = torch.full((num_samples, 1, signal_length), float(condition + 1),
                      dtype=torch.float16, device=DEVICE)

    diffusion_model.eval()

    print(f"Generating samples: cue {condition + 1}")
    complete_samples = []
    with fabric.autocast(), torch.no_grad():
        for _ in range(n_iter):
            samples, _aux = diffusion_model.sample(num_samples, cond=cond, w=w)
            samples = samples.cpu().numpy()
            print(samples.shape)
            complete_samples.append(samples)
    complete_samples = np.concatenate(complete_samples).astype(np.float32)
    print(complete_samples.shape)
    return complete_samples

In [15]:
# 10 sampling rounds x 200 = 2000 synthetic trials per class, with w=3
# forwarded to diffusion_model.sample. Labels 0 and 1 are generated in that order.
generated_signals_zero, generated_signals_one = (
	generate_samples(FABRIC, diffusion_model, condition=label,
					 n_iter=10, batch_size=200, w=3)
	for label in (0, 1)
)

# Persist both classes so check() / DeepClassifier can reload them.
np.save(os.path.join(SAVE_PATH, "generated_zeros.npy"), generated_signals_zero)
np.save(os.path.join(SAVE_PATH, "generated_ones.npy"), generated_signals_one)

Generating samples: cue 1
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(2000, 3, 512)
Generating samples: cue 2
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(200, 3, 512)
(2000, 3, 512)


In [16]:
def check(fake_percentages=(15, 30, 45), seed=None):
	"""Compare CSP accuracy with and without generated trials in the training set.

	Loads the generated trials saved under ``SAVE_PATH`` and fits a
	``SimpleCSP`` on real data only. Then, for each percentage, it replaces that
	share of the shuffled real training trials with generated ones (half
	labelled 1, half labelled 0) and refits.

	Args:
		fake_percentages: percentages of training trials to replace with
			generated data. The default equals the original ``range(15, 46, 15)``.
		seed: optional seed for the shuffling. ``None`` keeps the previous
			behaviour of drawing from the global ``np.random`` state.

	Returns:
		List of accuracies, one per entry of ``fake_percentages``. The
		real-data-only baseline is printed but not included.

	Raises:
		ValueError: if there are not enough generated samples for a split.
	"""
	generated_signals_zero = np.load(os.path.join(SAVE_PATH, "generated_zeros.npy"))
	generated_signals_one = np.load(os.path.join(SAVE_PATH, "generated_ones.npy"))
	rng = np.random if seed is None else np.random.default_rng(seed)

	accuracies = []

	test_classifier = SimpleCSP(train_split=TRAIN_SPLIT,
								test_split=TEST_SPLIT,
								dataset=None,
								save_paths=[REAL_DATA],
								channels=CHANNELS,
								length=2.05)

	full_x, full_y = test_classifier.get_train()
	print(f"full x shape: {full_x.shape}")

	acc = test_classifier.fit()
	print(f"reaching an accuracy of {acc} without fake data")

	available = min(len(generated_signals_one), len(generated_signals_zero))
	for real_fake_split in fake_percentages:
		n = int(len(full_x) * real_fake_split / 100)
		half = n // 2
		if half > available:
			raise ValueError(
				f"{real_fake_split}% fake data needs {half} generated trials per class, "
				f"only {available} available")

		# Fancy indexing returns copies, so full_x / full_y stay untouched.
		shuffling = rng.permutation(full_x.shape[0])
		split_x = full_x[shuffling]
		split_y = full_y[shuffling]

		# Overwrite the first `half` shuffled trials with label-1 fakes and the
		# next `half` with label-0 fakes (assumes real labels are also 0/1).
		split_x[0:half] = generated_signals_one[0:half]
		split_y[0:half] = 1
		split_x[half:2 * half] = generated_signals_zero[0:half]
		split_y[half:2 * half] = 0

		print(f"split x shape: {split_x.shape}")

		acc = test_classifier.fit((split_x, split_y))
		accuracies.append(acc)

		print(f"Reaching an accuracy of {acc} using {real_fake_split}% fake data")

	return accuracies

In [17]:
check()

(3026, 3, 512)
(3026,)
final data shape: (3026, 3, 512)
(2241, 3, 512)
(2241,)
final data shape: (2241, 3, 512)
full x shape: (3026, 3, 512)
input shape: (3026, 3, 512)
Computing rank from data with rank=None
    Using tolerance 3.4 (2.2e-16 eps * 3 dim * 5.2e+15  max singular value)
    Estimated rank (mag): 3
    MAG: rank 3 computed from 3 data channels with 0 projectors
Reducing data rank from 3 -> 3
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 3.4 (2.2e-16 eps * 3 dim * 5.1e+15  max singular value)
    Estimated rank (mag): 3
    MAG: rank 3 computed from 3 data channels with 0 projectors
Reducing data rank from 3 -> 3
Estimating covariance using EMPIRICAL
Done.
reaching an accuracy of 0.5885765283355645 without fake data
split x shape: (3026, 3, 512)
input shape: (3026, 3, 512)
Computing rank from data with rank=None
    Using tolerance 4.2 (2.2e-16 eps * 3 dim * 6.3e+15  max singular value)
    Estimated rank (mag): 3
  

## Classification

In [18]:
# Random smoke-test batch: 2 trials x 3 channels x 512 samples, on the model's device.
x = torch.rand(2, 3, 512, device=unet.device)

In [20]:
# Smoke test of the classification path: for the random batch this returns a
# (2, 2) tensor of per-class scores (see output below).
unet.classify(x)

tensor([[-0.1628, -0.0399],
        [ 0.2175,  0.2106]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [21]:
# Files written by the generation cell, in the (ones, zeros) order passed to
# DeepClassifier's fake_data argument.
fake_paths = tuple(
	os.path.join(SAVE_PATH, file_name)
	for file_name in ("generated_ones.npy", "generated_zeros.npy")
)
ones, zeros = fake_paths

In [22]:
# Deep classifier built around the trained diffusion U-Net, trained on real
# trials plus the generated ones listed in fake_paths.
# NOTE(review): save_paths repeats REAL_DATA (with a trailing slash) as a literal.
deep_clf = DeepClassifier(
	model=unet,
	# data sources
	save_paths=["../../data/2b_iv/raw/"],
	dataset=None,
	dataset_type=subject_dataset,
	fake_data=fake_paths,
	# same subject splits as the CSP baseline in check()
	train_split=TRAIN_SPLIT,
	test_split=TEST_SPLIT,
	# trial window; index_cutoff presumably truncates trials to 512 samples — confirm
	length=2.05,
	index_cutoff=512,
)

(3026, 3, 512)
(3026,)
we have fake data
final data shape: (4538, 3, 512)
(2241, 3, 512)
(2241,)
final data shape: (2241, 3, 512)


In [23]:
# Re-display the architecture summary; the parameter counts match the summary
# printed right after construction.
ModelSummary(unet)

  | Name          | Type                 | Params
-------------------------------------------------------
0 | time_embed    | SinusoidalPosEmb     | 0     
1 | class_embed   | Embedding            | 36    
2 | encoder       | ModuleList           | 1.2 M 
3 | decoder       | ModuleList           | 2.5 M 
4 | auxiliary_clf | BottleNeckClassifier | 2.1 M 
5 | middle_conv   | EmbedConvdown        | 400 K 
6 | output_conv   | Conv1d               | 99    
-------------------------------------------------------
6.2 M     Trainable params
0         Non-trainable params
6.2 M     Total params
24.706    Total estimated model params size (MB)

In [26]:
# Discriminative learning rates: a small LR for the diffusion-trained U-Net
# body and a larger one for the auxiliary classification head.
# NOTE(review): time_embed and class_embed belong to no group, so this
# optimizer never updates them.
fine_tune = [
	unet.encoder,
	unet.decoder,
	unet.middle_conv,
	unet.output_conv,
]

to_optimize = [
	{"params": module.parameters(), "lr": 2E-5, "weight_decay": 1E-4}
	for module in fine_tune
] + [
	{"params": unet.auxiliary_clf.parameters(), "lr": 1E-3, "weight_decay": 1E-4},
]

optimizer = optim.AdamW(to_optimize)

In [27]:
# lr / weight_decay are passed as well, but the log reports "using specified
# optimizer": the per-group settings of `optimizer` are the ones in effect.
deep_clf.fit(
	fabric=FABRIC,
	optimizer=optimizer,
	num_epochs=150,
	lr=1E-3,
	weight_decay=1E-4,
	verbose=True,
)

using specified optimizer
Epoch [1/150], Training Loss: 0.803, Training Accuracy: 49.69%, Validation Loss: 0.689, Validation Accuracy: 53.73%
Epoch [2/150], Training Loss: 0.704, Training Accuracy: 48.55%, Validation Loss: 0.697, Validation Accuracy: 50.38%
Epoch [3/150], Training Loss: 0.701, Training Accuracy: 50.64%, Validation Loss: 0.691, Validation Accuracy: 54.71%
Epoch [4/150], Training Loss: 0.702, Training Accuracy: 50.37%, Validation Loss: 0.694, Validation Accuracy: 51.18%
Epoch [5/150], Training Loss: 0.700, Training Accuracy: 50.31%, Validation Loss: 0.714, Validation Accuracy: 51.85%
Epoch [6/150], Training Loss: 0.700, Training Accuracy: 50.20%, Validation Loss: 0.709, Validation Accuracy: 51.49%
Epoch [7/150], Training Loss: 0.698, Training Accuracy: 51.94%, Validation Loss: 0.696, Validation Accuracy: 51.45%
Epoch [8/150], Training Loss: 0.699, Training Accuracy: 50.02%, Validation Loss: 0.696, Validation Accuracy: 50.56%
Epoch [9/150], Training Loss: 0.699, Training 

KeyboardInterrupt: 

In [4]:
from classification.classifiers import k_fold_splits

In [5]:
# Inspect the first fold of the leave-one-subject-out splits: a pair of
# (train split, test split) lists with one entry per participant.
k_fold_splits(k=9,n_participants=9,leave_out=True)[0]

[[['train', 'test'],
  ['train', 'test'],
  ['train', 'test'],
  ['train', 'test'],
  [],
  ['train', 'test'],
  ['train', 'test'],
  ['train', 'test'],
  ['train', 'test']],
 [[], [], [], [], ['test'], [], [], [], []]]