# Imports and Setup

In [1]:
# Imports for Tensor
import csv
import itertools
import math
import numpy as np
import os
import pandas as pd
import shutil
import sys
from collections import OrderedDict
from datetime import datetime
from tempfile import TemporaryDirectory
from typing import Tuple

from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.utils.tensorboard import SummaryWriter
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

from torchvision import transforms

from diffusers import StableDiffusionImg2ImgPipeline
from datasets import load_dataset

sys.path.append("../")

%load_ext autoreload
%autoreload 2

In [2]:
from common.dog import DoG, LDoG, PDoG
from models import CNNClassifier
from common.eeg_datasets import GeneratedSpectrumDataset, ParkinsonsPreprocessor, ParkinsonsDataset
from common.eeg_datasets import parkinsons_class_labels
from training import train_class, evaluate_class, TrainingConfig
from visualization import *

In [3]:
# Root directory of the Parkinson's data set on the shared drive.
data_path = "/data/shared/signal-diffusion/parkinsons/"
# Sub-directory the rest of the notebook reads the dataset splits and train-metadata.csv from.
stft_path = os.path.join(data_path, "stfts")

In [4]:
import matplotlib.pyplot as plt
from PIL import Image

def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    Args:
        images: numpy array with values nominally in [-1, 1]. Accepted
            layouts are a single 2-D grayscale image (H, W), a single 3-D
            image (channel-first or channel-last), or a 4-D batch, which is
            assumed to already be channel-last (N, H, W, C).

    Returns:
        A list with one PIL image per image in the batch.
    """
    if images.ndim == 2:
        # Bare grayscale image: add the batch and channel axes the code below expects.
        images = images[None, ..., None]
    elif images.ndim == 3:
        # Heuristic: a leading axis shorter than the next one is treated as the
        # channel axis (C, H, W) and moved to the end.
        if images.shape[0] < images.shape[1]:
            images = np.moveaxis(images, 0, 2)
        images = images[None, ...]
    # Map [-1, 1] -> [0, 255]. Clip before the cast so values slightly outside
    # the nominal range saturate instead of wrapping around in uint8.
    images = np.clip((0.5 + images / 2) * 255, 0, 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images

# Data Preprocessing (run once)

In [5]:
# Second positional argument of the preprocessor; presumably the number of
# samples per segment - TODO confirm against ParkinsonsPreprocessor.
nsamps = 2000
preprocessor = ParkinsonsPreprocessor(data_path, nsamps)
# One-off preprocessing steps, left commented out so that re-running the
# notebook does not repeat them.
# %time preprocessor.preprocess()
# preprocessor.make_tvt_splits()  # Can run separately to regen splits efficiently

In [6]:
# Metadata table for the training split; the cells below read its age, gender and health columns.
df = pd.read_csv(stft_path + "/train-metadata.csv")

In [7]:
# Histogram of the `age` column of the training metadata.
df.age.hist()

In [8]:
# Number of "M" and "F" rows in the training metadata, then the male fraction.
m = (df["gender"] == "M").sum()
f = (df["gender"] == "F").sum()
print(m, f)
print(m / (m + f))


In [9]:
# Same gender breakdown, restricted to rows whose health column is "PD".
pdf = df.loc[df["health"] == "PD"]
m = (pdf["gender"] == "M").sum()
f = (pdf["gender"] == "F").sum()
print(m, f)
print(m / (m + f))

# Models and DataLoaders

In [42]:
# Training split whose transform calls `.convert("RGB")` on each item.
d = ParkinsonsDataset(stft_path, split="train", transform=lambda x: x.convert("RGB"))
# Inspect the first sample (image first, label second) and the dataset size.
print(d[0][0])
# print(d[0][0].shape)
print(d[0][1])
print(len(d))

In [41]:
# Display the image of the first training sample.
d[0][0]

In [43]:
# Caption the dataset produces for sample 0.
d.caption(0)

In [10]:
# `d` is built with a transform that returns PIL images (`.convert("RGB")`).
# The default collate function cannot batch those, and the statistics below
# need tensors. Use a dataset with the default transform, exactly as the
# training loaders further down do.
tensor_set = ParkinsonsDataset(stft_path, split="train")
dl = torch.utils.data.DataLoader(tensor_set, batch_size=16, shuffle=True)
# Pull a single batch and report its shape and basic statistics.
S, y = next(iter(dl))
print(S.shape, y.shape)
print(S.max(), S.min(), S.mean(), S.std())

# Test img2img

In [16]:
# Directory of the diffusion checkpoint to load; the commented line is an alternative checkpoint.
model_path = "../data/stft-full.parkinsons.0/"
# model_path = "../data/stft-post-ft.parkinsons.0/"
# Image-to-image pipeline in half precision, moved to the GPU with progress bars disabled.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_path, 
    safety_checker=lambda images, **kwargs: (images, False),  # Disable safety checker - spectrograms won't be NSFW
    torch_dtype=torch.float16,
)
pipe.set_progress_bar_config(disable=True)
pipe.to("cuda")

# Classifier (constructed with arguments 1 and 4) used to score generated
# images: load saved weights, switch to eval mode and disable gradients.
model = CNNClassifier(1, 4)
model.load_state_dict(torch.load("models/best_model-cnnclass_AdamW_decay0.0001.pt"))
model.to("cuda")
model.eval()
model.requires_grad_(False)

## Flip gender

In [90]:
# Number of images generated per strength value.
nout = 10
# Source sample: its image is the img2img starting point, its caption is printed for reference.
idx = 0
img, y = d[idx]
caption = d.caption(idx)
# Target prompt for this experiment (requests a healthy male subject).
new_caption = "an EEG spectrogram of a 72 year old, healthy, male subject"
print(caption)
print("NEW:", new_caption)
# img2img strengths to sweep.
strengths = np.arange(0.1, 1.0, 0.05)
# Seeded generator so the sweep, and the accuracy curve derived from it, is
# reproducible. Its state advances between calls, so the `nout` samples per
# strength still differ from one another.
generator = torch.Generator(device="cpu").manual_seed(0)
# gen_imgs[i][j] is the j-th sample generated at strengths[i].
gen_imgs = [[pipe(prompt=[new_caption],
                  image=[img], strength=strength, guidance_scale=7.5,
                  num_inference_steps=50, num_images_per_prompt=1,
                  generator=generator).images[0]
             for _ in range(nout)] for strength in strengths]
gen_imgs[0][0]

In [96]:
# Preprocessing applied before classification: tensor conversion, then
# normalisation with mean 0.5 and std 0.5 (images are converted to "L" first).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
# Predicted class index for every generated image; one row per strength value.
yhats = [
    [np.argmax(model(transform(im.convert("L")).to("cuda")).cpu().numpy()) for im in gen_str]
    for gen_str in gen_imgs
]
# Same grid, mapped to class labels.
classes = [[parkinsons_class_labels[label] for label in row] for row in yhats]
print(caption)
classes

In [98]:
# Class requested by the new caption (healthy male). Label 0 follows the
# 2 * (not healthy) + (female) encoding used by make_prompt later in this
# notebook - assumes the classifier's labels use the same encoding; confirm
# against parkinsons_class_labels.
target_class = 0
ayhats = np.array(yhats)
# Fraction of samples per strength classified as the requested class. Taking
# the mean keeps this correct if `nout` changes (it was a hard-coded 10).
accs = np.mean(ayhats == target_class, axis=1)
accs

In [99]:
# Accuracy against the img2img `strength` swept above. The guidance scale is
# fixed at 7.5 in the generation cell, so the axis is labelled as strength.
plt.plot(strengths, accs)
plt.xlabel("Img2img strength")
plt.ylabel("Requested class accuracy")
plt.grid(True)

## Flip health

In [90]:
# Number of images generated per strength value.
nout = 10
# Source sample: its image is the img2img starting point, its caption is printed for reference.
idx = 0
img, y = d[idx]
caption = d.caption(idx)
# Target prompt for this experiment (requests a Parkinson's-diagnosed female subject).
new_caption = "an EEG spectrogram of a 72 year old, parkinsons disease diagnosed, female subject"
print(caption)
print("NEW:", new_caption)
# img2img strengths to sweep.
strengths = np.arange(0.1, 1.0, 0.05)
# Seeded generator so the sweep, and the accuracy curve derived from it, is
# reproducible. Its state advances between calls, so the `nout` samples per
# strength still differ from one another.
generator = torch.Generator(device="cpu").manual_seed(0)
# gen_imgs[i][j] is the j-th sample generated at strengths[i].
gen_imgs = [[pipe(prompt=[new_caption],
                  image=[img], strength=strength, guidance_scale=7.5,
                  num_inference_steps=50, num_images_per_prompt=1,
                  generator=generator).images[0]
             for _ in range(nout)] for strength in strengths]
gen_imgs[0][0]

In [96]:
# Preprocessing applied before classification: tensor conversion, then
# normalisation with mean 0.5 and std 0.5 (images are converted to "L" first).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
# Predicted class index for every generated image; one row per strength value.
yhats = [
    [np.argmax(model(transform(im.convert("L")).to("cuda")).cpu().numpy()) for im in gen_str]
    for gen_str in gen_imgs
]
# Same grid, mapped to class labels.
classes = [[parkinsons_class_labels[label] for label in row] for row in yhats]
print(caption)
classes

In [98]:
# Class requested by the new caption (Parkinson's-diagnosed female). Under the
# 2 * (not healthy) + (female) encoding used by make_prompt later in this
# notebook that is label 3. The previous comparison against 0 scored "healthy
# male", which this experiment does not request. Assumes the classifier's
# labels use the same encoding; confirm against parkinsons_class_labels.
target_class = 3
ayhats = np.array(yhats)
# Fraction of samples per strength classified as the requested class. Taking
# the mean keeps this correct if `nout` changes (it was a hard-coded 10).
accs = np.mean(ayhats == target_class, axis=1)
accs

In [99]:
# Accuracy against the img2img `strength` swept above. The guidance scale is
# fixed at 7.5 in the generation cell, so the axis is labelled as strength.
plt.plot(strengths, accs)
plt.xlabel("Img2img strength")
plt.ylabel("Requested class accuracy")
plt.grid(True)

# Train on Real Data

In [10]:
# Parameters
BATCH_SIZE = 64
SHUFFLE = True          # applied to training loaders only
NUM_WORKERS = 16
N_TOKENS = 128
RESOLUTION = 512
HOP_LENGTH = 80
# persistent_workers is only valid when worker processes are used.
persistent = NUM_WORKERS > 0

# Datasets (real recordings). The generated-data training set is built in the
# "Generative-Trained Classifier" section further down.
val_set = ParkinsonsDataset(stft_path, split="val")
test_set = ParkinsonsDataset(stft_path, split="test")
real_train_set = ParkinsonsDataset(stft_path, split="train")

# Keyword arguments shared by every loader in this cell.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                      pin_memory=True, persistent_workers=persistent)

# Evaluation loaders keep a fixed order: shuffling adds nothing to validation
# or test metrics and makes runs harder to compare.
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **_loader_kwargs)
real_train_loader = torch.utils.data.DataLoader(real_train_set, shuffle=SHUFFLE,
                                                **_loader_kwargs)

In [11]:
# define hyperparameters
OUTPUT_DIM = 4
DROPOUT = 0.25
BATCH_FIRST = True # True: (batch, seq, feature). False: (seq, batch, feature)

# Device selection: prefer CUDA, then MPS, otherwise fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# Enable the cuDNN benchmark flag only when running on CUDA.
torch.backends.cudnn.benchmark = (device.type == "cuda")

# Loss function
criterion = nn.CrossEntropyLoss()

In [20]:
# Runtime training parameters
opt, decay, restart, max_eta, decouple = (torch.optim.AdamW, 0.0001, 0, None, True)

# Create model instance
model = CNNClassifier(1, OUTPUT_DIM).to(device)
if opt is torch.optim.AdamW:
    optimizer = opt(model.parameters(), lr=1e-3, weight_decay=decay)
else:
    # Any other optimizer is constructed with the max_eta /
    # decouple_weight_decay keywords (presumably the DoG variants imported
    # at the top of the notebook).
    optimizer = opt(model.parameters(), weight_decay=decay, max_eta=max_eta,
                    decouple_weight_decay=decouple)

# Create training configuration
ARGS = TrainingConfig(epochs=15, val_every_epochs=10, opt_restart_every=restart)

# Run name: optimizer class name plus the settings that distinguish the run.
postfix = ""
if isinstance(optimizer, DoG):
    postfix = f"_restart{restart}_etamax{max_eta}_decouple{str(int(decouple))}"
comment = f"cnnclass_{type(optimizer).__name__}_decay{decay}{postfix}"
tbsw = SummaryWriter(log_dir="./tensorboard_logs/cnn/" + comment + "-" +
                     datetime.now().isoformat(sep='_'),
                     comment=comment)
print("#" * 80)
print("Training", comment)

try:
    # Training loop
    losses, accs, val_accs = train_class(
        ARGS, model,
        real_train_loader, val_loader,
        optimizer, criterion,
        device, tbsw
    )

    # load best model and evaluate on test set
    model.load_state_dict(torch.load("best_model.pt", map_location=device))
    test_loss, test_acc = evaluate_class(model, test_loader, criterion, device,
                                         tbsw, ARGS.epochs * len(real_train_loader) + 1)
finally:
    # Flush and release the TensorBoard event file even if training fails.
    tbsw.close()
print(f'Test loss={test_loss:.3f}; test accuracy={test_acc:.3f}')

# Copy model to unique filename
os.makedirs("models", exist_ok=True)
shutil.copyfile("best_model.pt", f"models/best_model-{comment}.pt")
shutil.copyfile("last_model.pt", f"models/last_model-{comment}.pt")
print(f"Copied best model to models/best_model-{comment}.pt")


In [16]:
# Reload the real-data classifier from the checkpoint written by the training
# cell above (decay=0.0001). The previous path pointed at a decay0.01 file
# that this notebook never produces.
model = CNNClassifier(1, 4)
model.load_state_dict(torch.load("models/best_model-cnnclass_AdamW_decay0.0001.pt",
                                 map_location=device))
# Use `device`, which the evaluation cells below also pass to class_confusion,
# and switch to eval mode because this model is only used for evaluation.
model.to(device).eval()

In [21]:
# Run class_confusion for `model` over the real training loader and title the resulting plot.
cf, p = class_confusion(model, real_train_loader, parkinsons_class_labels, device)
plt.title("Training")

In [22]:
# Run class_confusion for `model` over the validation loader and title the resulting plot.
cf, p = class_confusion(model, val_loader, parkinsons_class_labels, device)
plt.title("Validation")

In [23]:
# Run class_confusion for `model` over the test loader and title the resulting plot.
cf, p = class_confusion(model, test_loader, parkinsons_class_labels, device)
plt.title("Test")

In [20]:
# Class prevalence over the real training loader; the trailing semicolon hides the call's return value.
class_prevalence(real_train_loader, parkinsons_class_labels);

# Generate a fake dataset

In [5]:
# Attribute pools the random prompts are drawn from.
genders = ["female", "male"]
ages = np.arange(60, 80)
healths = ["healthy", "parkinsons disease diagnosed"]
def make_prompt():
    """Draw one random caption and its class label.

    Gender and health are sampled with probabilities 0.33 / 0.67 (first /
    second list entry); age is sampled uniformly from `ages`.

    Returns:
        (prompt, y): the caption string and a label in 0..3 computed as
        2 * (health is not "healthy") + (gender is "female").
    """
    # Sampling order (gender, age, health) is kept so seeded runs match.
    gender = np.random.choice(genders, p=[.33, .67])
    age = np.random.choice(ages)
    health = np.random.choice(healths, p=[.33, .67])
    is_diagnosed = health != "healthy"
    is_female = gender == "female"
    label = is_diagnosed * 2 + is_female
    caption = f"an EEG spectrogram of a {age} year old, {health}, {gender} subject"
    return caption, label

def make_n_prompts(n):
    """Lazily yield `n` (prompt, label) pairs, each drawn with make_prompt()."""
    yield from (make_prompt() for _ in range(n))

# Show one example pair.
print(make_prompt())

def im2tensor(image):
    """Convert an image to a normalised single-channel tensor.

    The image is converted to mode "L", turned into a tensor with ToTensor
    and normalised with mean 0.5 and std 0.5.
    """
    grayscale = image.convert("L")
    to_normalised_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    return to_normalised_tensor(grayscale)


In [26]:
# The text-to-image pipeline is not among the notebook's top-level imports
# (only the img2img variant is), so bring it into scope here.
from diffusers import StableDiffusionPipeline

# Directory of the diffusion checkpoint to load; the commented line is an alternative checkpoint.
model_path = "../data/stft-full.parkinsons.0/"
# model_path = "../data/stft-post-ft.parkinsons.0/"
# Text-to-image pipeline in half precision, on the GPU, with progress bars disabled.
pipe = StableDiffusionPipeline.from_pretrained(
    model_path,
    safety_checker=lambda images, **kwargs: (images, False),  # Disable safety checker - spectrograms won't be NSFW
    torch_dtype=torch.float16,
)
pipe.set_progress_bar_config(disable=True)
pipe.to("cuda")

# Classifier restored from the most recent training run's best checkpoint,
# put in eval mode with gradients disabled.
model = CNNClassifier(1, 4)
model.load_state_dict(torch.load('best_model.pt'))
model.to("cuda")
model.eval()
model.requires_grad_(False)

In [27]:
N = 6000      # total number of images to generate in this run
batch = 8     # prompts per pipeline call
offset = 0    # index of the first image; set > 0 to append to an existing set

gendir = "/data/shared/signal-diffusion/parkinsons/gen-stft/"
os.makedirs(gendir, exist_ok=True)

files = []
ys = []

# Start a new metadata file for a fresh run, append when resuming at an offset.
mode = "a" if offset > 0 else "w"
# newline="" is what the csv module expects for files it writes to.
with open(os.path.join(gendir, "metadata.csv"), mode, newline="") as f:
    writer = csv.writer(f)
    if offset == 0:
        writer.writerow(["file", "y"])
    for start in tqdm(range(0, N, batch)):
        # The final batch may be smaller, so exactly N images are produced.
        count = min(batch, N - start)
        prompts, yy = zip(*make_n_prompts(count))
        images = pipe(list(prompts), num_inference_steps=50, guidance_scale=7.5).images
        for j, (image, label) in enumerate(zip(images, yy)):
            imname = f"gen-{start + j + offset}.png"
            image.convert("L").save(os.path.join(gendir, imname))
            # Record the label together with the image so an interrupted run
            # does not leave images without metadata.
            writer.writerow([imname, label])
            files.append(imname)
            ys.append(label)
        f.flush()

# Generated Spectrograms

In [6]:
# Directory the generation cell above writes the generated images and metadata.csv to.
gendir = "/data/shared/signal-diffusion/parkinsons/gen-stft/"


In [7]:
# Open the generated data set and sanity-check its size and first item.
gen_set = GeneratedSpectrumDataset(gendir)
print(len(gen_set))
print(gen_set[0])

## Generative-Trained Classifier

In [11]:
# Parameters

# Datasets
# Training data for the generative-trained classifier: the generated
# spectrograms, loaded with the same batch/worker settings as the real-data loaders.
train_set = GeneratedSpectrumDataset(gendir)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, 
                                           shuffle=SHUFFLE, num_workers=NUM_WORKERS,
                                           pin_memory=True, persistent_workers=persistent)

In [12]:
# define hyperparameters
OUTPUT_DIM = 4
DROPOUT = 0.25
BATCH_FIRST = True # True: (batch, seq, feature). False: (seq, batch, feature)
WEIGHT_DECAY = 0.0001

# Device selection: prefer CUDA, then MPS, otherwise fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
# Enable the cuDNN benchmark flag only when running on CUDA.
torch.backends.cudnn.benchmark = (device.type == "cuda")

# Loss function
criterion = nn.CrossEntropyLoss()

In [52]:
# Runtime training parameters
opt, decay, restart, max_eta, decouple = (torch.optim.AdamW, 0.001, 0, None, True)

# Create model instance
model = CNNClassifier(1, OUTPUT_DIM).to(device)
if opt is torch.optim.AdamW:
    optimizer = opt(model.parameters(), lr=1e-3, weight_decay=decay)
else:
    # Any other optimizer is constructed with the max_eta /
    # decouple_weight_decay keywords (presumably the DoG variants imported
    # at the top of the notebook).
    optimizer = opt(model.parameters(), weight_decay=decay, max_eta=max_eta,
                    decouple_weight_decay=decouple)

# Create training configuration
ARGS = TrainingConfig(epochs=15, val_every_epochs=10, opt_restart_every=restart)

# Run name: optimizer class name plus the settings that distinguish the run.
postfix = ""
if isinstance(optimizer, DoG):
    postfix = f"_restart{restart}_etamax{max_eta}_decouple{str(int(decouple))}"
comment = f"cnnclass_gen_{type(optimizer).__name__}_decay{decay}{postfix}"
tbsw = SummaryWriter(log_dir="./tensorboard_logs/cnn/" + comment + "-" +
                     datetime.now().isoformat(sep='_'),
                     comment=comment)
print("#" * 80)
print("Training", comment)

try:
    # Training loop: train on generated data, validate on real data.
    losses, accs, val_accs = train_class(
        ARGS, model,
        train_loader, val_loader,
        optimizer, criterion,
        device, tbsw
    )

    # load best model and evaluate on test set
    model.load_state_dict(torch.load("best_model.pt", map_location=device))
    test_loss, test_acc = evaluate_class(model, test_loader, criterion, device,
                                         tbsw, ARGS.epochs * len(train_loader) + 1)
finally:
    # Flush and release the TensorBoard event file even if training fails.
    tbsw.close()
print(f'Test loss={test_loss:.3f}; test accuracy={test_acc:.3f}')

# Copy model to unique filename
os.makedirs("models", exist_ok=True)
shutil.copyfile("best_model.pt", f"models/best_model-{comment}.pt")
shutil.copyfile("last_model.pt", f"models/last_model-{comment}.pt")
print(f"Copied best model to models/best_model-{comment}.pt")


In [53]:
# Keep a separate handle on the classifier trained on generated data; `model` is the result of the training cell above.
gmodel = model

### Evaluate on generative data

In [37]:
# Generative-trained classifier evaluated on the generated training loader.
cf, p = class_confusion(gmodel, train_loader, parkinsons_class_labels, device)
plt.title("Gen-train, Gen-test")

### Evaluate on real data

In [36]:
# Generative-trained classifier evaluated on real data.
# NOTE(review): this uses the real *training* split although the title says
# "Real-test"; the summary figure at the end of the notebook uses test_loader
# for the panel with the same title - confirm which split is intended.
cf, p = class_confusion(gmodel, real_train_loader, parkinsons_class_labels, device)
plt.title("Gen-train, Real-test")

In [16]:
# Class prevalence over the generated training loader; the trailing semicolon hides the call's return value.
class_prevalence(train_loader, parkinsons_class_labels);

## Real-trained classifier on generative data

In [33]:
# Classifier trained on real data, restored from its saved checkpoint.
rmodel = CNNClassifier(1, OUTPUT_DIM,)
rmodel.load_state_dict(torch.load("models/best_model-cnnclass_AdamW_decay0.0001.pt",
                                  map_location=device))
# Place the model on `device`, which the class_confusion calls below also
# receive, and switch to eval mode because it is only used for evaluation.
rmodel.to(device).eval()

### Evaluate on generative data

In [38]:
# Real-trained classifier evaluated on the generated training loader.
cf, p = class_confusion(rmodel, train_loader, parkinsons_class_labels, device)
plt.title("Real-train, Gen-test")

### Evaluate on real data

In [39]:
# Real-trained classifier evaluated on the real test loader.
cf, p = class_confusion(rmodel, test_loader, parkinsons_class_labels, device)
plt.title("Real-train, Real-test")

In [51]:
# 2x2 summary: rows are the classifier (real-trained, generative-trained),
# columns are the evaluation data (real test loader, generated loader).
# Create an empty figure. plt.subplots() would also add a full-size Axes,
# which recent Matplotlib versions no longer remove when the 2x2 panels are
# placed on top of it.
fig = plt.figure(figsize=(10, 10))

# Top-left: real-trained classifier on real test data.
plt.subplot(2, 2, 1)
cf, _ = class_confusion(rmodel, test_loader, parkinsons_class_labels, device, fig=fig)
plt.title("Real-train, Real-test")
plt.yticks(rotation=60)
plt.xticks([])
# Accuracy = diagonal sum of `cf` over its total.
plt.xlabel("Accuracy: {:.2%}".format(np.trace(cf) / np.sum(cf)))
plt.ylabel("True Class")

# Top-right: real-trained classifier on generated data.
plt.subplot(2, 2, 2)
cf, _ = class_confusion(rmodel, train_loader, parkinsons_class_labels, device, fig=fig)
plt.xticks([])
plt.xlabel("Accuracy: {:.2%}".format(np.trace(cf) / np.sum(cf)))
plt.yticks([])
plt.title("Real-train, Gen-test")

# Bottom-left: generative-trained classifier on real test data.
plt.subplot(2, 2, 3)
cf, _ = class_confusion(gmodel, test_loader, parkinsons_class_labels, device, fig=fig)
plt.title("Gen-train, Real-test")
plt.xlabel("Predicted Class\nAccuracy: {:.2%}".format(np.trace(cf) / np.sum(cf)))
plt.yticks(rotation=60)
plt.xticks(rotation=60)

# Bottom-right: generative-trained classifier on generated data.
plt.subplot(2, 2, 4)
cf, _ = class_confusion(gmodel, train_loader, parkinsons_class_labels, device, fig=fig)
plt.yticks([])
plt.xticks(rotation=60)
plt.xlabel("Predicted Class\nAccuracy: {:.2%}".format(np.trace(cf) / np.sum(cf)))
plt.title("Gen-train, Gen-test")

# Leave room at the bottom for the rotated tick labels and two-line x labels.
plt.subplots_adjust(bottom=0.2)