In [1]:
import numpy as np
import pandas as pd
import os
import scipy.io
from scipy.signal import butter, filtfilt, iirnotch, cheby2
from einops import rearrange
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from lightning import Fabric
import yaml
import lightning
from pytorch_lightning.utilities.model_summary import ModelSummary
from einops import repeat
from typing import Tuple
from torch.nn import functional as F

from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from mne.decoding import CSP
from sklearn.metrics import accuracy_score

import sys

sys.path.append("../../")

from classification.loaders import EEGDataset,load_data
from models.unet.eeg_unets import Unet,UnetConfig, BottleNeckClassifier, Unet1D
from classification.classifiers import DeepClassifier
from classification.loaders import subject_dataset, CSP_subject_dataset
from classification.open_bci_loaders import OpenBCIDataset,OpenBCISubject,load_files,CSPOpenBCISubject
from models.unet import base_eegnet
from ntd.networks import LongConv
from ntd.diffusion_model import Diffusion
from ntd.utils.kernels_and_diffusion_utils import WhiteNoiseProcess
from classification_heads import ClassificationHead, EEGNetHead, DiffusionClf
import json



In [2]:
torch.set_float32_matmul_precision("medium")
FABRIC = Fabric(accelerator="cuda",precision="bf16-mixed")
lightning.seed_everything(0)

Using bfloat16 Automatic Mixed Precision (AMP)
Seed set to 0


0

## Datasets

### Constants

In [3]:
c = np.split(np.arange(3*9),3)
c = np.concatenate([c[0],c[2]])


In [4]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SIGNAL_LENGTH = 512
N_CSP_CHANNELS = 18
N_EEG_CHANNELS = 2

CONF_PATH = "../diffusion/conf"

FS_BCI_COMP = 250
DATA_PATH_BCI_COMP = "../../data/2b_iv"
SAVE_PATH_BCI_COMP = "../../saved_models/raw_eeg"

REAL_DATA_BCI_COMP = "../../data/2b_iv/raw"
CSP_DATA_BCI_COMP = "../../data/2b_iv/csp"

TRAIN_SPLIT_BCI_COMP = 9*[["train"]]
TEST_SPLIT_BCI_COMP = 9*[["test"]]

CHANNELS_BCI_COMP = [0,2]

c = np.split(np.arange(3*9),3)
c = np.concatenate([c[0],c[2]])

CSP_CHANNELS_BCI_COMP = c


train_split = 2*[["train"]]
test_split = 2*[["test"]]
save_path = os.path.join("processed","raw")
small_stride = os.path.join("processed","data/collected_data/small_stride")


In [5]:
TRAIN_SPLIT_OPENBCI = 2*[["train"]]
TEST_SPLIT_OPENBCI = 2*[["test"]]

SAVE_PATH_OPENBCI = os.path.join("processed","raw")
CSP_SAVE_PATH_OPENBCI = os.path.join("processed","data/collected_data/csp")

CSP_CHANNELS_OPENBCI = np.arange(2*9)
CHANNELS_OPENBCI = np.arange(0,2)

In [6]:
EEGNET_LR = 1E-3
EEGNET_WEIGHT_DECAY = 1E-4

In [7]:
DIFFUSION_LR = 6E-4
DIFFUSION_N_EPOCHS = 250
DIFFUSION_BATCH_SIZE = 64

TIME_DIM = 12
HIDDEN_CHANNEL = 64
KERNEL_SIZE = 65
NUM_SCALES = 2
DECAY_MIN = 2
DECAY_MAX = 2
ACTIVATION_TYPE = "leaky_relu"

USE_FFT_CONV = KERNEL_SIZE * (2**(NUM_SCALES - 1)) >= 100
NUM_TIMESTEPS = 1024
SCHEDULE = "linear"
START_BETA = 1E-4
END_BETA = 8E-2

CLF_EPOCHS = 50

DEBUG = False

In [8]:
with open(os.path.join(CONF_PATH, "train.yaml"), "r") as f:
    train_yaml = yaml.safe_load(f)
    
with open(os.path.join(CONF_PATH, "classifier.yaml"), "r") as f:
    classifier_yaml = yaml.safe_load(f)
    
with open(os.path.join(CONF_PATH, "network.yaml"), "r") as f:
    network_yaml = yaml.safe_load(f)
    
with open(os.path.join(CONF_PATH, "diffusion.yaml"), "r") as f:
    diffusion_yaml = yaml.safe_load(f)

### BCI Competition

#### Normal

In [9]:
dataset = {}
for i in range(1,10):
    mat_train,mat_test = load_data("../../data/2b_iv",i)
    dataset[f"subject_{i}"] = {"train":mat_train,"test":mat_test}

train_dataset_bci_comp = EEGDataset(subject_splits=TRAIN_SPLIT_BCI_COMP,
                    dataset=None,
                    save_paths=[REAL_DATA_BCI_COMP],
                    subject_dataset_type=subject_dataset,
                    channels=CHANNELS_BCI_COMP,
                    sanity_check=False,
                    length=2.05)

test_dataset_bci_comp = EEGDataset(subject_splits=TEST_SPLIT_BCI_COMP,
                    dataset=None,
                    save_paths=[REAL_DATA_BCI_COMP],
                    channels=CHANNELS_BCI_COMP,
                    sanity_check=False,
                    length=2.05)

(3026, 2, 512)
(3026,)
final data shape: (3026, 2, 512)
(2241, 2, 512)
(2241,)
final data shape: (2241, 2, 512)


#### CSP

In [10]:
train_csp_dataset_bci_comp = EEGDataset(subject_splits=TRAIN_SPLIT_BCI_COMP,
                    dataset=None,
                    save_paths=[CSP_DATA_BCI_COMP],
                    subject_dataset_type=CSP_subject_dataset,
                    channels=CSP_CHANNELS_BCI_COMP,
                    sanity_check=False,
                    length=2.05)

test_csp_dataset_bci_comp = EEGDataset(subject_splits=TEST_SPLIT_BCI_COMP,
                    dataset=None,
                    save_paths=[CSP_DATA_BCI_COMP],
                    channels=CSP_CHANNELS_BCI_COMP,
                    sanity_check=False,
                    length=2.05)

(3026, 18, 512)
(3026,)
final data shape: (3026, 18, 512)
(2241, 18, 512)
(2241,)
final data shape: (2241, 18, 512)


### OpenBCI

#### Normal

In [11]:
print(f"path {os.getcwd()}")
files = load_files("../../data/collected_data/")

train_dataset_openbci = OpenBCIDataset(
	subject_splits=TRAIN_SPLIT_OPENBCI,
	dataset=files,
	save_paths=[SAVE_PATH_OPENBCI],
	fake_data=None,
	dataset_type=OpenBCISubject,
	channels=np.arange(0,2),
	subject_channels=["ch2","ch5"],
	stride=128,
	epoch_length=512
)

test_dataset_openbci = OpenBCIDataset(
	subject_splits=TEST_SPLIT_OPENBCI,
	save_paths=[SAVE_PATH_OPENBCI],
	fake_data=None,
	dataset_type=OpenBCISubject,
	channels=np.arange(0,2),
	subject_channels=["ch2","ch5"],
	stride=128,
	epoch_length=512
)

path d:\Machine learning\MI SSL\motor-imagery-classification-2024\models\diffusion
Saving new data
(416, 2, 512)
(416,)
final data shape: (416, 2, 512)
Loading saved data
(208, 2, 512)
(208,)
final data shape: (208, 2, 512)


In [12]:
print(f"path {os.getcwd()}")
files = load_files("../../data/collected_data/")

train_s25_dataset_openbci = OpenBCIDataset(
	subject_splits=TRAIN_SPLIT_OPENBCI,
	dataset=files,
	save_paths=[small_stride],
	fake_data=None,
	dataset_type=OpenBCISubject,
	channels=np.arange(0,2),
	subject_channels=["ch2","ch5"],
	stride=25,
	epoch_length=512
)

test_s25_dataset_openbci = OpenBCIDataset(
	subject_splits=TEST_SPLIT_OPENBCI,
	save_paths=[small_stride],
	fake_data=None,
	dataset_type=OpenBCISubject,
	channels=np.arange(0,2),
	subject_channels=["ch2","ch5"],
	stride=25,
	epoch_length=512
)

path d:\Machine learning\MI SSL\motor-imagery-classification-2024\models\diffusion
Saving new data
(1984, 2, 512)
(1984,)
final data shape: (1984, 2, 512)
Loading saved data
(992, 2, 512)
(992,)
final data shape: (992, 2, 512)


#### CSP

In [13]:

train_csp_dataset_openbci = OpenBCIDataset(
	subject_splits=TRAIN_SPLIT_OPENBCI,
	dataset=files,
	save_paths=[CSP_SAVE_PATH_OPENBCI],
	fake_data=None,
	dataset_type=CSPOpenBCISubject,
	channels=CSP_CHANNELS_OPENBCI,
	subject_channels=["ch2","ch5"],
	stride=128,
	epoch_length=512
)

test_csp_dataset_openbci = OpenBCIDataset(
	subject_splits=TEST_SPLIT_OPENBCI,
	save_paths=[CSP_SAVE_PATH_OPENBCI],
	fake_data=None,
	dataset_type=CSPOpenBCISubject,
	channels=CSP_CHANNELS_OPENBCI,
	subject_channels=["ch2","ch5"],
	stride=128,
	epoch_length=512
)

Saving new data
(416, 18, 512)
(416,)
final data shape: (416, 18, 512)
Loading saved data
(208, 18, 512)
(208,)
final data shape: (208, 18, 512)


---
## CSP

In [14]:
def csp_train_test(train_dset,test_dset):

	x_train,y_train = train_dset.data
	x_test,y_test = test_dset.data

	x_train,y_train = np.float64(x_train),np.float64(y_train)
	x_test,y_test = np.float64(x_test),np.float64(y_test)

	csp = CSP(n_components=x_train.shape[1],reg=None,log=True,norm_trace=False)
	svm = SVC(C=1)

	clf = Pipeline(steps=[("csp",csp),
						("classification",svm)])

	clf.fit(x_train,y_train)

	y_train_pred = clf.predict(x_train)
	y_test_pred = clf.predict(x_test)
	train_acc = accuracy_score(y_train,y_train_pred)
	test_acc = accuracy_score(y_test,y_test_pred)

	print(f"train accuracy: {train_acc}")
	print(f"test accuracy: {test_acc}")
	return test_acc


### BCI Competition

In [15]:
csp_bci_comp_acc = csp_train_test(train_csp_dataset_bci_comp,test_csp_dataset_bci_comp)

Computing rank from data with rank=None


    Using tolerance 9.6 (2.2e-16 eps * 18 dim * 2.4e+15  max singular value)
    Estimated rank (mag): 18
    MAG: rank 18 computed from 18 data channels with 0 projectors
Reducing data rank from 18 -> 18
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 9.6 (2.2e-16 eps * 18 dim * 2.4e+15  max singular value)
    Estimated rank (mag): 18
    MAG: rank 18 computed from 18 data channels with 0 projectors
Reducing data rank from 18 -> 18
Estimating covariance using EMPIRICAL
Done.
train accuracy: 0.7699933906146729
test accuracy: 0.7175368139223561


### OpenBCI

In [16]:
csp_openbci_acc = csp_train_test(train_csp_dataset_openbci,test_csp_dataset_openbci)

Computing rank from data with rank=None
    Using tolerance 2.7 (2.2e-16 eps * 18 dim * 6.6e+14  max singular value)
    Estimated rank (mag): 18
    MAG: rank 18 computed from 18 data channels with 0 projectors
Reducing data rank from 18 -> 18
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank=None


    Using tolerance 2.6 (2.2e-16 eps * 18 dim * 6.5e+14  max singular value)
    Estimated rank (mag): 18
    MAG: rank 18 computed from 18 data channels with 0 projectors
Reducing data rank from 18 -> 18
Estimating covariance using EMPIRICAL
Done.
train accuracy: 0.7932692307692307
test accuracy: 0.6105769230769231


---
## EEGNet

#### Setup

In [17]:
EEGNET_LR = 1E-3
EEGNET_WEIGHT_DECAY = 1E-4

In [18]:
def eeg_net_loop(
		real_path,
		train_split,
		test_split,
		subject_dataset_type,
		channels,
		fake_paths=None,
		**dset_kwargs
):

	model = base_eegnet.EEGNet(2,224)

	clf = DeepClassifier(
		model=model,
		save_paths=[real_path],
		train_split=train_split,
		test_split=test_split,
		dataset=None,
		subject_dataset_type=subject_dataset_type,
		channels=channels,
		fake_data=fake_paths,
		length=2.05,
		index_cutoff=512,
		**dset_kwargs,
	)

	return clf.fit(FABRIC,50,EEGNET_LR,EEGNET_WEIGHT_DECAY)

### BCI Competition

In [19]:
eegnet_bci_comp_acc = eeg_net_loop(
	real_path=REAL_DATA_BCI_COMP,
	train_split=TRAIN_SPLIT_BCI_COMP,
	test_split=TEST_SPLIT_BCI_COMP,
	subject_dataset_type=subject_dataset,
	channels=CHANNELS_BCI_COMP,
)

(3026, 2, 512)
(3026,)
final data shape: (3026, 2, 512)
(2241, 2, 512)
(2241,)
final data shape: (2241, 2, 512)


checkpointing
Epoch [1/50], Training Loss: 0.742, Training Accuracy: 53.44%, Validation Loss: 0.656, Validation Accuracy: 58.93%
checkpointing
Epoch [2/50], Training Loss: 0.650, Training Accuracy: 62.99%, Validation Loss: 0.539, Validation Accuracy: 72.14%
checkpointing
Epoch [3/50], Training Loss: 0.612, Training Accuracy: 65.20%, Validation Loss: 0.507, Validation Accuracy: 74.73%
Min loss: 0.507421875 vs 0.564453125
Epoch [4/50], Training Loss: 0.587, Training Accuracy: 67.58%, Validation Loss: 0.564, Validation Accuracy: 68.93%
Min loss: 0.507421875 vs 0.53046875
Epoch [5/50], Training Loss: 0.577, Training Accuracy: 68.47%, Validation Loss: 0.530, Validation Accuracy: 69.82%
Min loss: 0.507421875 vs 0.5166852678571429
Epoch [6/50], Training Loss: 0.572, Training Accuracy: 69.96%, Validation Loss: 0.517, Validation Accuracy: 71.07%
Min loss: 0.507421875 vs 0.5082589285714286
Epoch [7/50], Training Loss: 0.565, Training Accuracy: 68.94%, Validation Loss: 0.508, Validation Accuracy:

### OpenBCI

In [20]:
eeg_net_openbci_acc = eeg_net_loop(
	real_path=small_stride,
	train_split=TRAIN_SPLIT_OPENBCI,
	test_split=TEST_SPLIT_OPENBCI,
	subject_dataset_type=OpenBCISubject,
	channels=CHANNELS_OPENBCI,
	dataset_type=OpenBCIDataset,
	subject_channels=["ch2","ch5"],
	stride=25,
	epoch_length=512,
)

Loading saved data
(1984, 2, 512)
(1984,)
final data shape: (1984, 2, 512)
Loading saved data
(992, 2, 512)
(992,)
final data shape: (992, 2, 512)
checkpointing
Epoch [1/50], Training Loss: 0.760, Training Accuracy: 50.15%, Validation Loss: 0.696, Validation Accuracy: 53.63%
checkpointing
Epoch [2/50], Training Loss: 0.711, Training Accuracy: 53.02%, Validation Loss: 0.693, Validation Accuracy: 53.63%
checkpointing
Epoch [3/50], Training Loss: 0.704, Training Accuracy: 55.85%, Validation Loss: 0.659, Validation Accuracy: 60.08%
Min loss: 0.658935546875 vs 0.679443359375
Epoch [4/50], Training Loss: 0.692, Training Accuracy: 58.01%, Validation Loss: 0.679, Validation Accuracy: 58.67%
Min loss: 0.658935546875 vs 0.669677734375
Epoch [5/50], Training Loss: 0.673, Training Accuracy: 58.87%, Validation Loss: 0.670, Validation Accuracy: 57.86%
checkpointing
Epoch [6/50], Training Loss: 0.644, Training Accuracy: 62.10%, Validation Loss: 0.648, Validation Accuracy: 60.69%
checkpointing
Epoch [

---
## EEGNet on bands

In [21]:
def eeg_net_bands_loop(
		real_path,
		train_split,
		test_split,
		subject_dataset_type,
		channels,
		fake_paths=None,
		**dset_kwargs
):

	model = EEGNetHead(18,256)

	clf = DeepClassifier(
		model=model,
		save_paths=[real_path],
		train_split=train_split,
		test_split=test_split,
		dataset=None,
		subject_dataset_type=subject_dataset_type,
		channels=channels,
		fake_data=fake_paths,
		length=2.05,
		index_cutoff=512,
		**dset_kwargs,
	)

	x = clf.sample_batch()
	print(x.shape)

	return clf.fit(FABRIC,50,EEGNET_LR,EEGNET_WEIGHT_DECAY)

In [22]:
eegnet_bands_bci_comp_acc = eeg_net_bands_loop(
	real_path=CSP_DATA_BCI_COMP,
	train_split=TRAIN_SPLIT_BCI_COMP,
	test_split=TEST_SPLIT_BCI_COMP,
	subject_dataset_type=CSP_subject_dataset,
	channels=CSP_CHANNELS_BCI_COMP,
)

(3026, 18, 512)
(3026,)
final data shape: (3026, 18, 512)
(2241, 18, 512)
(2241,)
final data shape: (2241, 18, 512)
torch.Size([32, 18, 512])
checkpointing
Epoch [1/50], Training Loss: 0.747, Training Accuracy: 50.50%, Validation Loss: 0.734, Validation Accuracy: 51.25%
checkpointing
Epoch [2/50], Training Loss: 0.709, Training Accuracy: 53.44%, Validation Loss: 0.708, Validation Accuracy: 51.70%
checkpointing
Epoch [3/50], Training Loss: 0.703, Training Accuracy: 53.07%, Validation Loss: 0.687, Validation Accuracy: 55.45%
checkpointing
Epoch [4/50], Training Loss: 0.671, Training Accuracy: 58.43%, Validation Loss: 0.629, Validation Accuracy: 63.57%
checkpointing
Epoch [5/50], Training Loss: 0.648, Training Accuracy: 60.71%, Validation Loss: 0.612, Validation Accuracy: 60.80%
checkpointing
Epoch [6/50], Training Loss: 0.626, Training Accuracy: 64.14%, Validation Loss: 0.563, Validation Accuracy: 67.68%
Min loss: 0.5627232142857143 vs 0.5780133928571428
Epoch [7/50], Training Loss: 0.61

In [23]:
eeg_net_bands_openbci_acc = eeg_net_bands_loop(
	real_path=CSP_SAVE_PATH_OPENBCI,
	train_split=TRAIN_SPLIT_OPENBCI,
	test_split=TEST_SPLIT_OPENBCI,
	subject_dataset_type=CSPOpenBCISubject,
	channels=CSP_CHANNELS_OPENBCI,
	dataset_type=OpenBCIDataset,
	subject_channels=["ch2","ch5"],
	stride=25,
	epoch_length=512,
)

Loading saved data
(416, 18, 512)
(416,)
final data shape: (416, 18, 512)
Loading saved data
(208, 18, 512)
(208,)
final data shape: (208, 18, 512)
torch.Size([32, 18, 512])
checkpointing
Epoch [1/50], Training Loss: 0.765, Training Accuracy: 52.64%, Validation Loss: 0.697, Validation Accuracy: 50.96%
Min loss: 0.697265625 vs 0.8056640625
Epoch [2/50], Training Loss: 0.725, Training Accuracy: 53.85%, Validation Loss: 0.806, Validation Accuracy: 38.46%
Min loss: 0.697265625 vs 0.791015625
Epoch [3/50], Training Loss: 0.716, Training Accuracy: 54.09%, Validation Loss: 0.791, Validation Accuracy: 41.35%
Min loss: 0.697265625 vs 0.7890625
Epoch [4/50], Training Loss: 0.714, Training Accuracy: 57.69%, Validation Loss: 0.789, Validation Accuracy: 52.88%
Min loss: 0.697265625 vs 0.7646484375
Epoch [5/50], Training Loss: 0.730, Training Accuracy: 54.57%, Validation Loss: 0.765, Validation Accuracy: 49.04%
Min loss: 0.697265625 vs 0.7275390625
Epoch [6/50], Training Loss: 0.703, Training Accura

---
## SLC + EEGNet

#### Constants and setup

In [24]:
DIFFUSION_LR = 6E-4
DIFFUSION_N_EPOCHS = 250
DIFFUSION_BATCH_SIZE = 64

TIME_DIM = 12
HIDDEN_CHANNEL = 64
KERNEL_SIZE = 65
NUM_SCALES = 2
DECAY_MIN = 2
DECAY_MAX = 2
ACTIVATION_TYPE = "leaky_relu"

USE_FFT_CONV = KERNEL_SIZE * (2**(NUM_SCALES - 1)) >= 100
NUM_TIMESTEPS = 1024
SCHEDULE = "linear"
START_BETA = 1E-4
END_BETA = 8E-2

CLF_EPOCHS = 50

In [25]:
def diff_train(fabric,
			   diffusion_model,
			   train_loader):
	optimizer = optim.AdamW(
        diffusion_model.parameters(),
        lr=DIFFUSION_LR,
    )

	diffusion_model,optimizer = fabric.setup(diffusion_model,optimizer)

	train_loader = fabric.setup_dataloaders(train_loader)
	for i in range(DIFFUSION_N_EPOCHS):
			
			epoch_loss = []
			for batch in train_loader:
				
				with fabric.autocast():

					signal,cue = batch
					signal = signal.to(torch.bfloat16)
					cue = cue.to(torch.bfloat16)
					cond = cue.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, SIGNAL_LENGTH).to(DEVICE)
					loss = diffusion_model.train_batch(signal.to(DEVICE),
										 cond=cond)
				loss = torch.mean(loss)
				
				epoch_loss.append(loss.item())
				
				fabric.backward(loss)
				optimizer.step()
				optimizer.zero_grad()
				
			epoch_loss = np.mean(epoch_loss)

			print(f"Epoch {i} loss: {epoch_loss}")

fabric_train = lambda model,loader : diff_train(FABRIC,model,loader)

In [26]:
DEBUG = False

In [27]:
def generate_samples(fabric,
                     diffusion_model, 
					 condition,
                     batch_size=200,
                     n_iter=20,
                     w=0):
    # it's a bit hard to predict memory consumption so splitting in mini-batches to be safe
    num_samples = batch_size
    cond = 0
    if (condition == 0):
        cond = torch.zeros(num_samples, 1, SIGNAL_LENGTH).to(DEVICE)
    elif (condition == 1):
        cond = torch.ones(num_samples, 1, SIGNAL_LENGTH).to(DEVICE)
    
    diffusion_model.eval()

    print(f"Generating samples: cue {condition}")
    k = 1 if DEBUG else n_iter
    complete_samples = []
    with fabric.autocast():
        with torch.no_grad():
            for i in range(k):
                samples, _ = diffusion_model.sample(num_samples, cond=cond,w=w)
                samples = samples.cpu().numpy()
                print(samples.shape)
                complete_samples.append(samples)
    complete_samples = np.float32(np.concatenate(complete_samples))
    if DEBUG:
        complete_samples = repeat(complete_samples,"n ... -> (n k) ...",k=n_iter)
    print(complete_samples.shape)
    return complete_samples


In [28]:
def diffusion_loop(
		diffusion_dset,
		dset_path,
		train_split,
		test_split,
		save_folder,
		name,
		clf_head,
		train=True,
		pre_train=True,
		frozen=False,
		generate=False,
		gen_batch_size=200,
		gen_iters=20,
		use_gen=False,
		**head_kwargs):
	
	if not os.path.isdir(save_folder):
		os.makedirs(save_folder)

	network = LongConv(
			signal_length=SIGNAL_LENGTH,
			signal_channel=2, # The CSP classifier components
			time_dim=TIME_DIM,
			cond_channel=network_yaml["cond_channel"], # The cond channel will contain the cue (0 or 1)
			hidden_channel=HIDDEN_CHANNEL,
			in_kernel_size=KERNEL_SIZE,
			out_kernel_size=KERNEL_SIZE,
			slconv_kernel_size=KERNEL_SIZE,
			num_scales=NUM_SCALES,
			decay_min=DECAY_MIN,
			decay_max=DECAY_MAX,
			heads=network_yaml["heads"],
			activation_type=ACTIVATION_TYPE,
			use_fft_conv=USE_FFT_CONV,
		)
	
	noise_sampler = WhiteNoiseProcess(1.0, SIGNAL_LENGTH)

	diffusion_model = Diffusion(
		network=network,
		diffusion_time_steps=NUM_TIMESTEPS,
		noise_sampler=noise_sampler,
		mal_dist_computer=noise_sampler,
		schedule=SCHEDULE,
		start_beta=START_BETA,
		end_beta=END_BETA,
	)

	train_loader = DataLoader(
		diffusion_dset,
		DIFFUSION_BATCH_SIZE,
	)

	fabric_train = lambda fabric : diff_train(fabric,diffusion_model,train_loader)

	diff_path = os.path.join(save_folder,name)
	ones_path = os.path.join(save_folder,"ones.npy")
	zeros_path = os.path.join(save_folder,"zeros.npy")

	if train:
		FABRIC.launch(fabric_train)

		torch.save(diffusion_model.state_dict(),diff_path)

	if not train:
		diffusion_model = FABRIC.setup(diffusion_model)

	if pre_train:
		diffusion_model.load_state_dict(torch.load(diff_path))

	if generate and pre_train:
		ones = generate_samples(FABRIC,diffusion_model,1,
						  gen_batch_size,gen_iters)
		zeros = generate_samples(FABRIC,diffusion_model,0,
						  gen_batch_size,gen_iters)
		
		np.save(ones_path,ones)
		np.save(zeros_path,zeros)

	head = clf_head(**head_kwargs)

	clf = DiffusionClf(diffusion_model,head,freeze=frozen)

	clf.to(DEVICE)

	fake_paths = [ones_path,zeros_path] if use_gen else None

	slc_clf = DeepClassifier(
		model=clf.to(DEVICE),
		save_paths=[dset_path],
		fake_data=fake_paths,
		train_split=train_split,
		test_split=test_split,
		**head_kwargs
	)

	optimizer = optim.AdamW([
		{"params":slc_clf.model.model.parameters(),"lr":2E-5,"weight_decay":1E-4},
		{"params":slc_clf.model.clf.parameters(),"lr":1E-3,"weight_decay":1E-4}
	])

	test_acc = slc_clf.fit(fabric=FABRIC,
			 num_epochs=CLF_EPOCHS,
			 lr=1E-3,
			 weight_decay=1E-4,
			 verbose=True,
			 optimizer=optimizer)
	
	print(f"\n\nTest acc: {test_acc}")

	return test_acc

### BCI Competition

In [29]:
diff_bci_comp_acc = diffusion_loop(
	diffusion_dset=train_dataset_bci_comp,
	dset_path=REAL_DATA_BCI_COMP,
	train_split=TRAIN_SPLIT_BCI_COMP,
	test_split=TEST_SPLIT_BCI_COMP,
	save_folder="results/saved_models/bci_comp",
	name="slc_eegnet.pt",
	clf_head=EEGNetHead,
	c_in=128,
	d_out=256,
	dataset=None,
	subject_dataset_type=subject_dataset,
	length=2.05,
	index_cutoff=512,
	channels=CHANNELS_BCI_COMP,
	train=False,
	frozen=True,
	generate=False,
	use_gen=True,
	gen_batch_size=800,
	gen_iters=5,
)

(3026, 2, 512)
(3026,)
we have fake data
final data shape: (4538, 2, 512)
(2241, 2, 512)
(2241,)
final data shape: (2241, 2, 512)
using specified optimizer
checkpointing
Epoch [1/50], Training Loss: 0.734, Training Accuracy: 51.41%, Validation Loss: 0.669, Validation Accuracy: 60.00%
checkpointing
Epoch [2/50], Training Loss: 0.684, Training Accuracy: 57.25%, Validation Loss: 0.611, Validation Accuracy: 66.43%
Min loss: 0.6109933035714286 vs 0.6124441964285714
Epoch [3/50], Training Loss: 0.675, Training Accuracy: 58.48%, Validation Loss: 0.612, Validation Accuracy: 65.71%
checkpointing
Epoch [4/50], Training Loss: 0.654, Training Accuracy: 60.53%, Validation Loss: 0.596, Validation Accuracy: 69.82%
checkpointing
Epoch [5/50], Training Loss: 0.649, Training Accuracy: 62.38%, Validation Loss: 0.576, Validation Accuracy: 71.88%
checkpointing
Epoch [6/50], Training Loss: 0.629, Training Accuracy: 63.86%, Validation Loss: 0.552, Validation Accuracy: 69.55%
checkpointing
Epoch [7/50], Train

### OpenBCI

In [30]:
diff_openbci_acc = diffusion_loop(
	diffusion_dset=train_dataset_openbci,
	dset_path=SAVE_PATH_OPENBCI,
	train_split=TRAIN_SPLIT_OPENBCI,
	test_split=TEST_SPLIT_OPENBCI,
	save_folder="results/saved_models/openbci",
	name="slc_eegnet.pt",
	clf_head=EEGNetHead,
	c_in=128,
	d_out=256,
	dataset=None,
	dataset_type = OpenBCIDataset,
	subject_dataset_type=OpenBCISubject,
	subject_channels=["ch2","ch5"],
	length=2.0,
	epoch_length=512,
	index_cutoff=512,
	channels=CHANNELS_OPENBCI,
	train=False,
	frozen=True,
	generate=False,
	use_gen=True,
	gen_batch_size=800,
	gen_iters=5,
)

Loading saved data
(416, 2, 512)
(416,)
we have fake data
final data shape: (624, 2, 512)
Loading saved data
(208, 2, 512)
(208,)
final data shape: (208, 2, 512)
using specified optimizer
checkpointing
Epoch [1/50], Training Loss: 0.749, Training Accuracy: 50.32%, Validation Loss: 0.692, Validation Accuracy: 52.88%
checkpointing
Epoch [2/50], Training Loss: 0.722, Training Accuracy: 50.16%, Validation Loss: 0.691, Validation Accuracy: 51.92%
Min loss: 0.69140625 vs 0.7119140625
Epoch [3/50], Training Loss: 0.731, Training Accuracy: 47.28%, Validation Loss: 0.712, Validation Accuracy: 54.81%
checkpointing
Epoch [4/50], Training Loss: 0.713, Training Accuracy: 51.92%, Validation Loss: 0.679, Validation Accuracy: 54.81%
checkpointing
Epoch [5/50], Training Loss: 0.707, Training Accuracy: 46.15%, Validation Loss: 0.677, Validation Accuracy: 53.85%
Min loss: 0.6767578125 vs 0.6875
Epoch [6/50], Training Loss: 0.715, Training Accuracy: 48.88%, Validation Loss: 0.688, Validation Accuracy: 57.

---
## SLC from scratch

### BCI Competition

In [31]:
slc_bci_comp_acc = diffusion_loop(
	diffusion_dset=train_dataset_bci_comp,
	dset_path=REAL_DATA_BCI_COMP,
	train_split=TRAIN_SPLIT_BCI_COMP,
	test_split=TEST_SPLIT_BCI_COMP,
	save_folder="results/saved_models/bci_comp",
	name="slc_eegnet.pt",
	clf_head=EEGNetHead,
	c_in=128,
	d_out=256,
	dataset=None,
	subject_dataset_type=subject_dataset,
	length=2.05,
	index_cutoff=512,
	channels=CHANNELS_BCI_COMP,
	train=False,
	pre_train=False,
	frozen=False
)

(3026, 2, 512)
(3026,)
final data shape: (3026, 2, 512)
(2241, 2, 512)
(2241,)
final data shape: (2241, 2, 512)
using specified optimizer
checkpointing
Epoch [1/50], Training Loss: 0.743, Training Accuracy: 50.93%, Validation Loss: 0.720, Validation Accuracy: 48.30%
checkpointing
Epoch [2/50], Training Loss: 0.701, Training Accuracy: 52.51%, Validation Loss: 0.706, Validation Accuracy: 49.91%
checkpointing
Epoch [3/50], Training Loss: 0.687, Training Accuracy: 55.58%, Validation Loss: 0.694, Validation Accuracy: 53.04%
checkpointing
Epoch [4/50], Training Loss: 0.666, Training Accuracy: 58.72%, Validation Loss: 0.670, Validation Accuracy: 58.30%
checkpointing
Epoch [5/50], Training Loss: 0.631, Training Accuracy: 65.00%, Validation Loss: 0.597, Validation Accuracy: 65.89%
checkpointing
Epoch [6/50], Training Loss: 0.596, Training Accuracy: 67.42%, Validation Loss: 0.584, Validation Accuracy: 65.71%
checkpointing
Epoch [7/50], Training Loss: 0.552, Training Accuracy: 71.08%, Validation 

### OpenBCI

In [32]:
slc_openbci_acc = diffusion_loop(
	diffusion_dset=train_s25_dataset_openbci,
	dset_path=SAVE_PATH_OPENBCI,
	train_split=TRAIN_SPLIT_OPENBCI,
	test_split=TEST_SPLIT_OPENBCI,
	save_folder="results/saved_models/openbci",
	name="slc_eegnet.pt",
	clf_head=EEGNetHead,
	c_in=128,
	d_out=256,
	dataset=None,
	dataset_type = OpenBCIDataset,
	subject_dataset_type=OpenBCISubject,
	subject_channels=["ch2","ch5"],
	length=2.0,
	epoch_length=512,
	index_cutoff=512,
	channels=CHANNELS_OPENBCI,
	train=False,
	pre_train=False,
	frozen=False
)

Loading saved data
(416, 2, 512)
(416,)
final data shape: (416, 2, 512)
Loading saved data
(208, 2, 512)
(208,)
final data shape: (208, 2, 512)
using specified optimizer
checkpointing
Epoch [1/50], Training Loss: 0.801, Training Accuracy: 48.56%, Validation Loss: 0.692, Validation Accuracy: 50.00%
Min loss: 0.6923828125 vs 0.708984375
Epoch [2/50], Training Loss: 0.721, Training Accuracy: 55.29%, Validation Loss: 0.709, Validation Accuracy: 50.96%
Min loss: 0.6923828125 vs 0.734375
Epoch [3/50], Training Loss: 0.668, Training Accuracy: 61.06%, Validation Loss: 0.734, Validation Accuracy: 44.23%
Min loss: 0.6923828125 vs 0.724609375
Epoch [4/50], Training Loss: 0.661, Training Accuracy: 58.89%, Validation Loss: 0.725, Validation Accuracy: 45.19%
Min loss: 0.6923828125 vs 0.751953125
Epoch [5/50], Training Loss: 0.644, Training Accuracy: 62.50%, Validation Loss: 0.752, Validation Accuracy: 43.27%
Min loss: 0.6923828125 vs 0.7880859375
Epoch [6/50], Training Loss: 0.624, Training Accuracy

---
## EEGNet with synthetic

### BCI competition

In [33]:
bci_comp_fake = [os.path.join("results/saved_models/bci_comp",i) for i in ["ones.npy","zeros.npy"]]
eegnet_synth_bci_comp_acc = eeg_net_loop(
	real_path=REAL_DATA_BCI_COMP,
	train_split=TRAIN_SPLIT_BCI_COMP,
	test_split=TEST_SPLIT_BCI_COMP,
	subject_dataset_type=subject_dataset,
	channels=CHANNELS_BCI_COMP,
	fake_paths=bci_comp_fake
)

(3026, 2, 512)
(3026,)
we have fake data
final data shape: (4538, 2, 512)
(2241, 2, 512)
(2241,)
final data shape: (2241, 2, 512)
checkpointing
Epoch [1/50], Training Loss: 0.708, Training Accuracy: 56.21%, Validation Loss: 0.568, Validation Accuracy: 70.00%
checkpointing
Epoch [2/50], Training Loss: 0.622, Training Accuracy: 64.59%, Validation Loss: 0.562, Validation Accuracy: 72.23%
checkpointing
Epoch [3/50], Training Loss: 0.601, Training Accuracy: 66.90%, Validation Loss: 0.501, Validation Accuracy: 74.38%
checkpointing
Epoch [4/50], Training Loss: 0.590, Training Accuracy: 68.66%, Validation Loss: 0.494, Validation Accuracy: 75.09%
checkpointing
Epoch [5/50], Training Loss: 0.581, Training Accuracy: 68.47%, Validation Loss: 0.476, Validation Accuracy: 76.61%
Min loss: 0.476171875 vs 0.4763950892857143
Epoch [6/50], Training Loss: 0.587, Training Accuracy: 68.31%, Validation Loss: 0.476, Validation Accuracy: 75.80%
Min loss: 0.476171875 vs 0.48660714285714285
Epoch [7/50], Trainin

### OpenBCI

In [34]:
openbci_fake = [os.path.join("results/saved_models/openbci",i) for i in ["ones.npy","zeros.npy"]]
eegnet_synth_openbci_acc = eeg_net_loop(
	real_path=small_stride,
	train_split=TRAIN_SPLIT_OPENBCI,
	test_split=TEST_SPLIT_OPENBCI,
	subject_dataset_type=OpenBCISubject,
	channels=CHANNELS_OPENBCI,
	fake_paths=openbci_fake,
	dataset_type=OpenBCIDataset,
	subject_channels=["ch2","ch5"],
	stride=128,
	epoch_length=512,
)

Loading saved data
(1984, 2, 512)
(1984,)
we have fake data
final data shape: (2976, 2, 512)
Loading saved data
(992, 2, 512)
(992,)
final data shape: (992, 2, 512)
checkpointing
Epoch [1/50], Training Loss: 0.723, Training Accuracy: 48.32%, Validation Loss: 0.692, Validation Accuracy: 53.23%
Min loss: 0.691650390625 vs 0.69287109375
Epoch [2/50], Training Loss: 0.711, Training Accuracy: 48.49%, Validation Loss: 0.693, Validation Accuracy: 51.81%
Min loss: 0.691650390625 vs 0.692626953125
Epoch [3/50], Training Loss: 0.705, Training Accuracy: 49.29%, Validation Loss: 0.693, Validation Accuracy: 50.81%
Min loss: 0.691650390625 vs 0.696533203125
Epoch [4/50], Training Loss: 0.698, Training Accuracy: 51.31%, Validation Loss: 0.697, Validation Accuracy: 47.18%
Min loss: 0.691650390625 vs 0.693115234375
Epoch [5/50], Training Loss: 0.700, Training Accuracy: 50.47%, Validation Loss: 0.693, Validation Accuracy: 48.59%
Min loss: 0.691650390625 vs 0.69677734375
Epoch [6/50], Training Loss: 0.69

---
## Results

In [35]:
print(f"""

BCI Competition
--------
CSP: {csp_bci_comp_acc}
EEGNet: {eegnet_bci_comp_acc}
EEGNet bands: {eegnet_bands_bci_comp_acc}
Pre-trained: {diff_bci_comp_acc}
slc: {slc_bci_comp_acc}
EEGNet synth: {eegnet_synth_bci_comp_acc}

OpenBCI
--------
CSP: {csp_openbci_acc}
EEGNet: {eeg_net_openbci_acc}
EEGNet bands: {eeg_net_bands_openbci_acc}
Pre-trained: {diff_openbci_acc}
slc: {slc_openbci_acc}
EEGNet synth: {eegnet_synth_openbci_acc}
""")



BCI Competition
--------
CSP: 0.7175368139223561
EEGNet: 79.28571428571429
EEGNet bands: 75.53571428571429
Pre-trained: 77.76785714285714
slc: 72.14285714285714
EEGNet synth: 80.17857142857143

OpenBCI
--------
CSP: 0.6105769230769231
EEGNet: 65.7258064516129
EEGNet bands: 67.3076923076923
Pre-trained: 58.65384615384615
slc: 62.5
EEGNet synth: 54.03225806451613

