In [1]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount=False)

In [2]:
# Automatically reload edited project modules (e.g. cvae_train, TopoKMeans) before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline below each cell.
%matplotlib inline

In [3]:
# ! pip install ripser

In [4]:
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [5]:
import torch
from matplotlib import pyplot as plt
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import functional as F

from cvae_train import (PETsMRIDataset, BimodalCVAE, bimodal_cvae_loss_fn, invo_cvae_loss_fn, train_bimodal_cvae, 
                        predict_bimodal_latent, load_model)
from TopoKMeans import topo_kmeans

In [6]:
# Prefer the first CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'
print(f'Using device: {DEVICE}')

Using device: cuda:0


In [7]:
# Source folders of the AD-cohort PET and MRI scans on the local Windows machine.
# In this notebook they are only referenced by the commented-out one-time HDF5 creation cell.
AD_COHORT_PET_NORM_DIR = (r'C:\Users\keena\School\Georgia Tech\FA25 Classes\CS 8903\cVAE'
                          r'\ad_cohort_pet_norm')
AD_COHORT_MRI_NORM_DIR = (r'C:\Users\keena\School\Georgia Tech\FA25 Classes\CS 8903\cVAE'
                          r'\ad_cohort_mri_norm')

In [8]:
# Local Windows equivalents, kept for reference:
# pet_hdf5_path = r'C:\Users\keena\School\Georgia Tech\FA25 Classes\CS 8903\cVAE\processed_data\ad_cohort_pet.hdf5'
# mri_hdf5_path = r'C:\Users\keena\School\Georgia Tech\FA25 Classes\CS 8903\cVAE\processed_data\ad_cohort_mri.hdf5'
# all_cohorts_path = r'C:\Users\keena\School\Georgia Tech\FA25 Classes\CS 8903\cVAE\adni-tables\all_cohorts_pet_mri.csv'

# Active (Linux cluster) locations of the PET/MRI HDF5 files and the cohort table.
pet_hdf5_path, mri_hdf5_path = (
    '/home/hice1/khom9/scratch/ad_cohort_pet.hdf5',
    '/home/hice1/khom9/scratch/ad_cohort_mri.hdf5',
)
all_cohorts_path = '/home/hice1/khom9/CS8903/TopoDL_Hypometabolism/adni-tables/all_cohorts_pet_mri.csv'

# Size of the single latent code shared by both modalities (one z decodes to PET and MRI).
latent_dim = 20

In [9]:
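# Run only once, to create the HDF5 files.
# NOTE: create_hdf5_dataset is not imported above; import it (presumably from cvae_train — confirm) before uncommenting.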

# create_hdf5_dataset(AD_COHORT_PET_NORM_DIR, pet_hdf5_path, r'^norm_wwc_(\d+)-')
# create_hdf5_dataset(AD_COHORT_MRI_NORM_DIR, mri_hdf5_path, r'_I(\d+)')

In [10]:
# Paired PET/MRI dataset built from the two HDF5 files and the cohort CSV.
# in_memory=True presumably preloads the volumes into RAM — see PETsMRIDataset to confirm.
dataset = PETsMRIDataset(
    pet_hdf5_path,
    mri_hdf5_path,
    all_cohorts_path,
    in_memory=True,
)

In [11]:
# Seed torch's global RNG (all devices) so weight initialisation, and presumably any torch-driven
# shuffling/sampling inside training, is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Bimodal conditional VAE. The second argument (1) is presumably the condition size — the CDR
# condition is passed as a (1, 1) tensor when sampling. TODO confirm against BimodalCVAE.
model = BimodalCVAE(latent_dim, 1).to(DEVICE)
optimizer = Adam(model.parameters(), lr=1e-3)
epochs = 300
batch_size = 32
# Checkpoint file written during training; presumably None disables saving.
save_path = '/home/hice1/khom9/CS8903/TopoDL_Hypometabolism/cvae-models/cvae_bimodal3.pth'
# save_path = None
# save_path = r'C:\Users\keena\School\Georgia Tech\FA25 Classes\CS 8903\cVAE\saved_models\cvae_bimodal.pth'

In [None]:
# Training takes a long time. Set RETRAIN = False to skip it and load the checkpoint at `save_path`.
RETRAIN = True
if RETRAIN:
    train_bimodal_cvae(model, optimizer, dataset, bimodal_cvae_loss_fn, epochs, batch_size,
                       verbose=2, device=DEVICE, save_path=save_path, save_freq=5)
else:
    if save_path is None:
        raise ValueError('RETRAIN is False but save_path is None; there is no checkpoint to load.')
    model = load_model(model, save_path)

Learning rate: 0.001
No learning rate scheduling!
Training for 300 epochs, with batch size=32
Using device: cuda:0
Saving model every 5 epochs to /home/hice1/khom9/CS8903/TopoDL_Hypometabolism/cvae-models/cvae_bimodal3.pth

-----Epoch 1/300-----
Batch 8/24 | loss: 34.17454752326012 (4.441s) | recon: 1.4243674874305725 | KL: 32.750181160867214
Batch 16/24 | loss: 2.148728162050247 (3.683s) | recon: 1.371691808104515 | KL: 0.7770363390445709
Batch 24/24 | loss: 2.119994208216667 (3.491s) | recon: 1.3332230299711227 | KL: 0.7867712080478668

-----Epoch 2/300-----
Batch 8/24 | loss: 2.0262793600559235 (3.679s) | recon: 1.3043141812086105 | KL: 0.7219651937484741
Batch 16/24 | loss: 1.9614633619785309 (3.678s) | recon: 1.2704184502363205 | KL: 0.6910449266433716
Batch 24/24 | loss: 2.0306107699871063 (3.467s) | recon: 1.2415660917758942 | KL: 0.7890447229146957

-----Epoch 3/300-----
Batch 8/24 | loss: 1.9338074773550034 (3.681s) | recon: 1.2164357155561447 | KL: 0.7173717245459557
Batch 16

In [None]:
# Draw one random (PET, MRI, CDR) example and move each tensor onto the working device.
loader = DataLoader(dataset, batch_size=1, shuffle=True)
pet, mri, cdr = (tensor.to(DEVICE) for tensor in next(iter(loader)))

In [None]:
# Reconstruct the sampled example and report its loss components.
# This is a diagnostic pass only, so disable autograd: no computation graph is kept for the
# (large) volumes, and the outputs can be displayed/plotted directly.
with torch.no_grad():
    pet_hat, mri_hat, mu, logvar = model(pet, mri, cdr)
    loss, bce_loss, kld_loss = bimodal_cvae_loss_fn(pet, mri, pet_hat, mri_hat, mu, logvar, invo_cvae_loss_fn)
    # Per-modality reconstruction terms (mean BCE), computed directly for comparison.
    pet_bce_loss = F.binary_cross_entropy(pet_hat, pet, reduction='mean')
    mri_bce_loss = F.binary_cross_entropy(mri_hat, mri, reduction='mean')

print(f'PET loss: {pet_bce_loss}; MRI loss: {mri_bce_loss}, KLD loss: {kld_loss}')

In [None]:
# Inspect the posterior mean and log-variance for the sampled example.
(mu, logvar)

In [None]:
def plot_scan(img, rows=3, cols=5, cmap='viridis'):
    '''Show a rows x cols grid of slices taken along the last axis of a 3D volume.

    Parameters
    ----------
    img : torch.Tensor
        3D volume indexed as img[:, :, k]. Each slice is detached and moved to the CPU
        before plotting, and is drawn on a fixed [0, 1] colour scale.
    rows, cols : int
        Grid shape. rows * cols slices are shown, spaced every depth // (rows * cols) slices.
    cmap : str
        Matplotlib colormap name.
    '''
    n_panels = rows * cols
    depth = img.shape[2]
    # Spacing between displayed slices. If the volume is shallower than the grid, use a step
    # of 1 (not 0, which would repeat slice 0) and clamp to the last slice.
    step = max(depth // n_panels, 1)
    # squeeze=False keeps `axes` 2D even when rows or cols is 1.
    fig, axes = plt.subplots(rows, cols, figsize=(10, 5), squeeze=False)

    for i, ax in enumerate(axes.flat):
        k = min(i * step, depth - 1)
        ax.imshow(img[:, :, k].detach().cpu(), vmin=0, vmax=1, cmap=cmap)
        ax.axis('off')
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    

In [None]:
# Input PET volume of the sampled example.
plot_scan(img=pet.squeeze())

In [None]:
# PET reconstruction produced by the model.
plot_scan(img=pet_hat.squeeze())

In [None]:
# Input MRI volume of the sampled example.
plot_scan(img=mri.squeeze())

In [None]:
# MRI reconstruction produced by the model.
plot_scan(img=mri_hat.squeeze())

In [None]:
# Generate a synthetic PET/MRI pair from a random latent code with CDR condition 0.95.
# The condition gets its own name so the `cdr` of the real sampled example is not overwritten;
# otherwise, re-running the reconstruction cell would silently use this value.
z = torch.randn(1, latent_dim, device=DEVICE)
cdr_high = torch.tensor([[0.95]], device=DEVICE)

with torch.no_grad():
    pet_gen, mri_gen = model.decode(z, cdr_high)

In [None]:
# Generate a second synthetic PET/MRI pair from a fresh random latent code with CDR condition 0.04.
# The condition gets its own name so the `cdr` of the real sampled example is not overwritten.
z = torch.randn(1, latent_dim, device=DEVICE)
cdr_low = torch.tensor([[0.04]], device=DEVICE)

with torch.no_grad():
    pet_gen2, mri_gen2 = model.decode(z, cdr_low)

In [None]:
# Mean absolute voxel difference between the two generated samples: (MRI, PET).
torch.mean(torch.abs(mri_gen2 - mri_gen)), torch.mean(torch.abs(pet_gen2 - pet_gen))

In [None]:
# First generated PET sample.
plot_scan(img=pet_gen.squeeze())

In [None]:
# Second generated PET sample.
plot_scan(img=pet_gen2.squeeze())

In [None]:
# First generated MRI sample.
plot_scan(img=mri_gen.squeeze())

In [None]:
# Second generated MRI sample.
plot_scan(img=mri_gen2.squeeze())

In [None]:
# Visualise latent vectors in 2D with t-SNE, coloured by a precomputed cluster labelling.
def plot_tsne(X, clusters, random_state=0):
    '''Embed latent vectors in 2D with t-SNE and scatter-plot them coloured by cluster label.

    This function does not cluster anything itself: `clusters` must already hold one label
    per row of `X` (e.g. from KMeans or topo_kmeans).

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Latent representations to embed.
    clusters : array-like of shape (n_samples,)
        Per-sample label, used as the point colour.
    random_state : int or None, default 0
        Seed passed to TSNE so the layout is reproducible; None gives a different layout per run.
    '''
    points = TSNE(random_state=random_state).fit_transform(X)
    fig, ax = plt.subplots()
    ax.scatter(points[:, 0], points[:, 1], c=clusters, s=2, alpha=0.4)
    ax.set(title='t-SNE of latent representations', xlabel='t-SNE 1', ylabel='t-SNE 2')

In [None]:
# Encode the whole dataset into posterior parameters, then move both to the CPU
# (they are fed to scikit-learn clustering and t-SNE below).
all_mu, all_logvar = (t.cpu() for t in predict_bimodal_latent(model, dataset, device=DEVICE))

In [None]:
# Baseline: standard k-means on the latent means.
# Fix the seed (matching topo_kmeans' random_state=0) and set n_init explicitly, so the labels are
# reproducible and do not depend on the installed scikit-learn version's n_init default.
kmeans = KMeans(n_clusters=4, n_init=10, random_state=0)
clusters = kmeans.fit_predict(all_mu)

In [None]:
# t-SNE view of the latent means, coloured by the k-means labels.
plot_tsne(X=all_mu, clusters=clusters)

In [None]:
# Topological k-means on the latent means (its labels are then plotted with t-SNE).
# Tuning notes:
#   n_knn             - neighbourhood size; controls how local each persistence diagram is.
#                       Too small gives sparse/uninformative diagrams; too large blurs global
#                       structure. 20-50 is a reasonable starting range.
#   n_clust           - number of clusters.
#   sigma             - kernel scale; a sensible value depends on how the data are scaled.
#                       NOTE(review): the earlier note said a larger sigma makes the RBF fall off
#                       faster with differences in persistence values — confirm against the
#                       kernel definition in TopoKMeans.
#   power             - exponent that exaggerates distance differences.
#   null_dim          - True: use H0 persistence diagrams.
#   first_dim         - keep False; True would switch to H1 (loop) features, which would have to be
#                       computed separately (otherwise this is effectively Euclidean).
#   preserve_ordering - keep False; only needed if results must follow the samples' input order.
#   dist_matrix       - False: `data` holds feature vectors, not a precomputed distance matrix.
#   random_restarts   - 20-200 suggested, following the R/FCPS implementation. Use ~20 while
#                       exploring and ~200 (or 1k-5k if compute allows) for final runs. Only worth
#                       tuning if the silhouette score / cluster labels change a lot between runs.
t_kmeans = topo_kmeans(
    data=all_mu,
    n_knn=40,
    n_clust=4,
    sigma=20,
    power=15,
    null_dim=True,
    first_dim=False,
    preserve_ordering=False,
    dist_matrix=False,
    random_state=0,
    random_restarts=200,
)

In [None]:
# t-SNE view of the latent means, coloured by the topological k-means labels.
plot_tsne(X=all_mu, clusters=t_kmeans.labels)