In [None]:
import os

import matplotlib.pyplot as plt
import mne
import numpy as np
from scipy.signal import periodogram, welch
# Use a 14 pt default font for every matplotlib figure made in this notebook.
plt.rcParams['font.size'] = 14
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
# EEG directory for subject 001 of the Parkinson's dataset. The trailing slash
# lets file names be appended to it directly.
datadir = "/mnt/d/data/signal-diffusion/parkinsons/sub-001/eeg/"
os.listdir(path=datadir)

In [None]:
# Open the EEGLAB recording (.set) for the resting-state task.
set_path = os.path.join(datadir, "sub-001_task-Rest_eeg.set")
data = mne.io.read_raw_eeglab(set_path)

In [None]:
# Plot a window of the raw traces (no scrollbars/scalebars) and save it to disk.
fig = data.plot(show_scalebars=False, show_scrollbars=False)
fig.tight_layout()
fig.savefig("parkinsons-snippet.png");

In [None]:
# First 20000 samples of channel 0. Ask MNE for just that channel/segment instead
# of materialising every channel (the raw is not preloaded) and slicing afterwards.
ch0 = data.get_data(picks=[0], start=0, stop=20000)[0]
ch0.shape

In [None]:
# Power spectrum of channel 0.
# - periodogram(..., nfft=1024) silently truncates its input to the first 1024
#   samples; welch with nperseg=1024 keeps the same 1024-point frequency grid but
#   averages Hann-windowed segments over all of ch0.
# - The sample rate comes from the recording itself rather than a hard-coded 250 Hz.
f, psd = welch(ch0, fs=data.info['sfreq'], window='hann', nperseg=1024,
               scaling="spectrum")
plt.figure()
plt.plot(f, 10 * np.log10(psd))
# scaling="spectrum" yields power per frequency bin, not a density (per Hz).
plt.ylabel("Power (dB)")
plt.xlabel("Frequency (Hz)");
plt.tight_layout()
plt.savefig('eeg-spec.png');


In [None]:
# Load the FM capture (12-bit samples, judging by the filename) and show its shape.
fm = np.load("kpfa_12bit.npy")
np.shape(fm)

In [None]:
from scipy import signal

In [None]:
fs = 960000        # input sample rate (Hz)
cutoff = 100000.0  # low-pass cutoff (Hz)
n_taps = 128       # FIR length
nfft = 1024        # FFT length used to plot the filter response
decim = 4          # keep every `decim`-th sample after filtering

# Hann-windowed FIR low-pass, applied before decimation.
h = signal.firwin(n_taps, cutoff, fs=fs, window='hann')
H = np.fft.fftshift(np.fft.fft(h, n=nfft))
# Frequency axis (kHz) matching the fftshift-ed response, derived from fs and nfft.
w = np.fft.fftshift(np.fft.fftfreq(nfft, d=1 / fs)) / 1000
fig, ax = plt.subplots(figsize=(8, 2))
ax.semilogy(w, np.abs(H))
ax.set_xlabel('Frequency [kHz]')
ax.set_ylabel('|H(f)|')
ax.set_title('Magnitude frequency response of the low-pass filter')
# Full convolution (includes start-up/tail transients), then decimate.
data_f = signal.oaconvolve(fm, h)[::decim]

In [None]:
# Power spectrum of the filtered, decimated signal (sample rate fs/4).
# periodogram(..., nfft=1024) would silently analyse only the first 1024 samples
# (which include the filter start-up transient); welch with nperseg=1024 keeps
# the same frequency grid but averages over the whole record.
f, psd = welch(data_f, fs=fs/4, window='hann', nperseg=1024,
               return_onesided=True, scaling='spectrum')
N = len(f)
# Keep the first half of the frequency bins.
# NOTE(review): for real input this spans 0 .. a quarter of the decimated rate.
# If data_f is complex, scipy ignores return_onesided and returns a two-sided
# spectrum, and this keeps the non-negative frequencies. Confirm which is intended.
f = f[:N//2]
psd = psd[:N//2]
plt.figure()
plt.plot(f/1000, 10 * np.log10(psd))
# scaling='spectrum' yields power per frequency bin, not a density (per Hz).
plt.ylabel("Power (dB)")
plt.xlabel("Frequency (kHz)");
plt.tight_layout()
plt.savefig('fm-spec.png');

In [None]:
# Raw SEED EEG recordings (.cnt). This rebinds `datadir`, which earlier cells used
# for the Parkinson's data; the trailing slash lets file names be appended directly.
datadir = "/mnt/d/data/signal-diffusion/seed/EEG_raw/"

In [None]:
# Open one SEED recording; this rebinds `data`, replacing the Parkinson's raw object.
cnt_path = os.path.join(datadir, "10_1_20180507.cnt")
data = mne.io.read_raw_cnt(cnt_path)

In [None]:
# Plot a window of the SEED raw traces (no scrollbars/scalebars) and save it.
fig = data.plot(show_scalebars=False, show_scrollbars=False)
fig.tight_layout()
fig.savefig("seed-snippet.png");

# VAEs

In [None]:
import torch
import torchvision.transforms.v2 as v2
from diffusers import AutoencoderKL
from datasets import load_dataset, Features, Image
from data_processing.general_dataset import GeneralDataset
from data_processing.seed import SEEDDataset
from data_processing.parkinsons import ParkinsonsDataset
from PIL import Image as pImage
import huggingface_hub as hfh

In [None]:
# Directories of STFT images, relative to the notebook's working directory.
seeddir, parkdir = "./seed/stfts", "./parkinsons/stfts"
# Pretrained VAE checkpoint. Alternative: "CompVis/stable-diffusion-v1-4", a full
# pipeline repo whose VAE lives under its "vae" subfolder.
model = "stabilityai/sdxl-vae"

In [None]:
# seed_ds = SEEDDataset("/mnt/d/data/signal-diffusion/seed/stfts")
# park_ds = ParkinsonsDataset("/mnt/d/data/signal-diffusion/parkinsons/stfts")
# gen_ds = GeneralDataset([seed_ds, park_ds])

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"  # uncomment to force CPU

# Resize to 256x256, convert to float32 in [0, 1], then map to [-1, 1] via (x - 0.5) / 0.5.
transform = v2.Compose([
    v2.Resize((256, 256)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.5], [0.5]),
])


def transform_fn(x):
    """Apply `transform` to the 'image' field of `x` (in place) and return `x`.

    NOTE(review): `with_transform` hands batch dicts to this function, so
    x['image'] is presumably a list of PIL images; torchvision v2 transforms
    accept such containers and transform each element.
    """
    x['image'] = transform(x['image'])
    return x


def collate_fn(examples):
    """Stack the per-example image tensors into one batch tensor."""
    return torch.stack([x['image'] for x in examples])


# `with_transform` installs a custom formatter that replaces any previously set
# format, so a preceding `.with_format("torch")` would have no effect.
dataset = load_dataset("imagefolder", data_dir=seeddir,
                       features=Features({"image": Image(mode="RGB")}))
dataset = dataset.with_transform(transform_fn)
dataloader = torch.utils.data.DataLoader(
    dataset['train'], batch_size=1, shuffle=False,
    collate_fn=collate_fn,
    # Pinned host memory only speeds up host->GPU copies; on CPU it only warns.
    pin_memory=device.type == "cuda",
    num_workers=4)


In [None]:
# Repos whose id contains "vae" hold the autoencoder at their root; otherwise the
# VAE is read from the "vae" subfolder of a full pipeline repo.
vae_kwargs = {} if "vae" in model else {"subfolder": "vae"}
vae = AutoencoderKL.from_pretrained(model, **vae_kwargs).to(device)
vae.eval()


In [None]:

def array_to_pil(arr):
    """Convert an (H, W, C) float array in [-1, 1] to an 8-bit PIL image.

    Values are clipped to [-1, 1] and rounded, not truncated, so that
    round-tripped pixels such as 127.9999 map back to 128.
    """
    scaled = (np.clip(arr, -1, 1) * 0.5 + 0.5) * 255
    return pImage.fromarray(np.rint(scaled).astype(np.uint8))


n_batches = 1  # how many batches to reconstruct and save
count = 0
for batch_idx, batch in enumerate(dataloader):
    images = batch.to(device)
    with torch.no_grad():
        recon = vae(images)
    # NCHW -> NHWC so each element is an (H, W, C) image for PIL.
    recon = recon.sample.permute(0, 2, 3, 1).cpu().float().numpy()
    orig = images.permute(0, 2, 3, 1).cpu().float().numpy()
    # Save each original next to its reconstruction under the same index, advancing
    # the index per image so originals in a multi-image batch don't overwrite each other.
    for im_orig, im_recon in zip(orig, recon):
        array_to_pil(im_orig).save(f"sdxl_vae_orig_{count}.jpg")
        array_to_pil(im_recon).save(f"sdxl_vae_recon_{count}.jpg")
        count += 1
    if batch_idx + 1 >= n_batches:
        break