In [41]:
import numpy as np
import pandas as pd
import os
import scipy.io
from einops import rearrange
import matplotlib.pyplot as plt
import seaborn as sns
# import pywt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from lightning import Fabric
from pytorch_lightning.utilities.model_summary import ModelSummary
from lightning_fabric.utilities.seed import seed_everything
import lightning as L
from einops import repeat,reduce,rearrange
import argparse
import pickle
import wandb

import sys

sys.path.append("../../../motor-imagery-classification-2024/")

from classification.loaders import EEGDataset,load_data
from models.unet.eeg_unets import Unet,UnetConfig, BottleNeckClassifier, Unet1D
from classification.classifiers import DeepClassifier , SimpleCSP, k_fold_splits
from classification.loaders import subject_dataset
from ntd.networks import SinusoidalPosEmb
from ntd.diffusion_model import Diffusion
from ntd.utils.kernels_and_diffusion_utils import WhiteNoiseProcess

from u_net_diffusion import DiffusionUnet, DiffusionUnetConfig

# Allow float32 matmuls to use a faster, lower-precision internal format.
torch.set_float32_matmul_precision('medium')
# Fix the RNG seed so runs are reproducible.
seed_everything(0)

# NOTE(review): presumably the EEG sampling rate in Hz — confirm against the dataset.
FS = 250
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
sns.set_style("darkgrid")


# Set to True to shrink every training loop to a single epoch (smoke test).
DEBUG = False

if DEBUG:
    print("---\n---\nCurrently in debug mode\n---\n---")

# --- Diffusion process -------------------------------------------------
NUM_TIMESTEPS = 1000
SCHEDULE = "linear"
START_BETA = 1e-4
END_BETA = 8e-2

# --- Optimisation ------------------------------------------------------
DIFFUSION_LR = 6e-4
CNN_LR = 1e-4
DIFFUSION_BATCH_SIZE = 64
DIFFUSION_NUM_EPOCHS = 1 if DEBUG else 100
CLASSIFICATION_MAX_EPOCHS = 1 if DEBUG else 150

# --- Data --------------------------------------------------------------
# Indices of the channels handed to the dataset loaders.
CHANNELS = [0, 1, 2]

# Filesystem locations used by the rest of the notebook.
REAL_DATA = "../../data/2b_iv/raw"
SAVE_PATH = "../../saved_models"

# Load the train/test recordings of subjects 1..9, keyed as "subject_<n>".
dataset = {}
for subject_id in range(1, 10):
    train_mat, test_mat = load_data("../../data/2b_iv", subject_id)
    dataset[f"subject_{subject_id}"] = dict(train=train_mat, test=test_mat)



Seed set to 0


In [37]:
# Configuration of the 1-D diffusion U-Net.
# NOTE(review): the original wrote several values as `(512)`, `(1)`, ... —
# those are plain ints, not 1-tuples; they are spelled as ints here. Confirm
# that DiffusionUnetConfig expects ints for these fields.
UnetDiff1D = DiffusionUnetConfig(
    # conditioning embeddings
    time_dim=12,
    class_dim=12,
    num_classes=2,
    # input signal: 3 channels x 512 samples
    input_shape=512,
    input_channels=3,
    # layer building blocks (all 1-D variants)
    conv_op=nn.Conv1d,
    norm_op=nn.InstanceNorm1d,
    non_lin=nn.ReLU,
    pool_op=nn.AvgPool1d,
    up_op=nn.ConvTranspose1d,
    # channel widths
    starting_channels=32,
    max_channels=256,
    # convolution / pooling settings
    conv_group=1,
    conv_padding=1,
    conv_kernel=3,
    pool_fact=2,
    # transposed-convolution (upsampling) settings
    deconv_group=1,
    deconv_padding=0,
    deconv_kernel=2,
    deconv_stride=2,
    residual=True,
)

In [38]:
# Nine split specifications (presumably one per subject — confirm against the loader).
# Built with a comprehension rather than `9 * [[...]]` so every entry is its
# own list; list multiplication would alias one inner list nine times, and a
# mutation of any entry would silently change all of them.
train_split = [["train"] for _ in range(9)]
test_split = [["test"] for _ in range(9)]

In [39]:
# Echo the split specification used for this run.
print(f"train split: {train_split}")
print(f"test split: {test_split}")

# Training set built from the files under REAL_DATA, restricted to CHANNELS.
# NOTE(review): `length=2.05` is presumably a window length in seconds
# (2.05 s x 250 Hz ~ 512 samples, matching the printed data shape) — confirm.
train_set = EEGDataset(subject_splits=train_split,
				dataset=None,
				save_paths=[REAL_DATA],
				dataset_type=subject_dataset,
				channels=CHANNELS,
				sanity_check=False,
				length=2.05)

# Test set built the same way from the "test" splits.
# NOTE(review): unlike train_set, no `dataset_type=subject_dataset` is passed
# here, so EEGDataset's default is used — verify that the two are meant to differ.
test_set = EEGDataset(subject_splits=test_split,
					dataset=None,
					save_paths=[REAL_DATA],
					channels=CHANNELS,
					sanity_check=False,
					length=2.05)

# Classifier head attached to the U-Net.
# NOTE(review): (2048, 1024) are presumably layer widths — confirm against BottleNeckClassifier.
classifier = BottleNeckClassifier((2048,1024),)
# Diffusion U-Net built from the UnetDiff1D configuration plus the classifier above.
unet = DiffusionUnet(UnetDiff1D,classifier)

# NOTE(review): arguments are presumably (noise scale, signal length); 512
# matches the length of the loaded signals — confirm against WhiteNoiseProcess.
noise_sampler = WhiteNoiseProcess(1.0, 512)

# Wrap the network in the diffusion process. The same WhiteNoiseProcess object
# is passed as both the noise sampler and the `mal_dist_computer`.
diffusion_model = Diffusion(
	network=unet,
	diffusion_time_steps=NUM_TIMESTEPS,
	noise_sampler=noise_sampler,
	mal_dist_computer=noise_sampler,
	schedule=SCHEDULE,
	start_beta=START_BETA,
	end_beta=END_BETA,
)


train split: [['train'], ['train'], ['train'], ['train'], ['train'], ['train'], ['train'], ['train'], ['train']]
test split: [['test'], ['test'], ['test'], ['test'], ['test'], ['test'], ['test'], ['test'], ['test']]
(3026, 3, 512)
(3026,)
final data shape: (3026, 3, 512)
(2241, 3, 512)
(2241,)
final data shape: (2241, 3, 512)


In [66]:
def test_batch(self, 
                    batch,
                    conditions,
                    mask=None,
                    p_uncond=0):
        self.eval()

        losses = []

        for cond in conditions:
            batch_size = batch.shape[0]
            time_index = self.fixed_steps
            n = len(time_index)
            time_index = repeat(time_index,"n -> (n b)",b=batch_size)
            batch = repeat(batch,"b ... -> (n b) ...",n=n)
            train_alpha_bars = self.alpha_bars[time_index].unsqueeze(-1).unsqueeze(-1)
            noise = self.noise_sampler.sample(
                sample_shape=(
                    batch_size,
                    self.network.signal_channel,
                    self.network.signal_length,
                )
            )
            noise = repeat(noise,"b ... -> (n b) ...",n=n)
            if noise.shape != batch.shape:
                raise ValueError(f"shape mismatch between noise ({noise.shape}) and batch ({batch.shape})")
            
            noisy_sig = (
                torch.sqrt(train_alpha_bars) * batch
                + torch.sqrt(1.0 - train_alpha_bars) * noise
            )
            cond = repeat(cond,"t -> (n b) 1 t",n=n,b=batch_size)
            res = self.network.forward(noisy_sig, time_index, cond=cond)
            diff = rearrange(noise - res,"(n b) ... -> b n ...",n=n)
            norms = reduce((diff)**2,"b ... -> b",reduction="mean")
            losses.append(norms)
        return torch.stack(losses,-1)

In [67]:
# Dummy batch of uniform noise: 4 samples, 3 channels, 512 time points.
x = torch.rand(4, 3, 512)

In [68]:
# Display the dummy batch's dimensions.
x.size()

torch.Size([4, 3, 512])

In [69]:
# Constant condition vectors of length 512 (int32): one filled with 1, one with 2.
# NOTE(review): labels start at 1, presumably leaving 0 for the unconditional case — confirm.
cond_1 = torch.full((512,), 1, dtype=torch.int)
cond_2 = torch.full((512,), 2, dtype=torch.int)

In [70]:
# Candidate conditions, in class order.
conds = list((cond_1, cond_2))

In [71]:
# Display the length of a single condition vector.
cond_1.size()

torch.Size([512])