In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
import tqdm.notebook
from training.data.glioma_public import PublicGliomaDataset

from experiment import LitModel
from experiment_classifier import ClsModel
from mri_utils import *
from templates import gliomapublic_autoenc
from templates_cls import *

ImportError: cannot import name 'd2c_crop' from partially initialized module 'training.data.celeb' (most likely due to a circular import) (/home/daniel/coding/diffae/notebooks/training/data/celeb.py)

In [None]:
# Fix the global NumPy and PyTorch RNGs so sampling and encoding are repeatable.
SEEED = 0
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(SEEED)
print(f"seed = {SEEED}")

In [None]:
def plot_tensor(t, ax, cmap="gray", *args, **kwargs):
    """Draw a channel-first image tensor on a matplotlib axis.

    Args:
        t: rank-3 tensor with the channel axis first; it is permuted to
            channel-last for ``imshow``.
        ax: axis to draw on.
        cmap: colormap passed to ``ax.imshow``.
        *args, **kwargs: forwarded unchanged to ``ax.imshow``.

    Returns:
        Whatever ``ax.imshow`` returns.
    """
    # detach first: imshow converts to NumPy, which fails on tensors requiring grad
    image = t.detach().permute(1, 2, 0).cpu()
    return ax.imshow(image, *args, cmap=cmap, **kwargs)

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
conf = gliomapublic_autoenc()
print(conf.name)


In [None]:
# Build the model from the config and restore weights from the last checkpoint.
model = LitModel(conf)
# NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which may reject
# checkpoints that pickle non-tensor objects — confirm with the installed torch.
state = torch.load(f'{conf.logdir}/last.ckpt', map_location="cpu")
# strict=False tolerates key mismatches; report them instead of hiding them.
load_result = model.load_state_dict(state['state_dict'], strict=False)
if load_result is not None and (load_result.missing_keys or load_result.unexpected_keys):
    print(f"missing keys: {load_result.missing_keys}")
    print(f"unexpected keys: {load_result.unexpected_keys}")
# only the EMA copy of the network is put in eval mode and moved to `device`
model.ema_model.eval()
model.ema_model.to(device)
print(f"global_step: {state['global_step']}")

In [None]:
# Single-sample loader (batch_size=1) used for the per-volume figures below.
conf.num_workers = 2
conf.batch_size = 1

dataset_kwargs = dict(
    img_size=conf.img_size,
    mri_sequences=conf.mri_sequences,
    mri_crop=conf.mri_crop,
    train_mode=conf.train_mode,
    filter_class_labels=True,
)
data = PublicGliomaDataset(conf.data_path, **dataset_kwargs)
loader = torch.utils.data.DataLoader(
    data,
    batch_size=conf.batch_size,
    shuffle=False,
    num_workers=conf.num_workers,
)


In [None]:
# Persistent iterator over `loader`: the next cell pulls from it with next(),
# so each re-run of that cell continues where the previous run stopped.
it = iter(loader)

In [None]:
# For each of the three classes, take the next sample of that class from `it`,
# optionally reconstruct it with the diffusion autoencoder, and save a slice grid.
do_reconstruction = True
n_steps = 100
only_axial = True  # plot every 3rd extracted slice only (see `stride` below)
remaining_classes = list(range(3))
while remaining_classes:
    b = next(it, None)
    if b is None:
        # `it` is shared across runs of this cell; re-create it to start over.
        print(f"Data iterator exhausted; classes not found: {remaining_classes}")
        break
    b_cls = b["cls_labels"][0].item()
    if b_cls not in remaining_classes:
        continue
    remaining_classes.remove(b_cls)
    print(f"Class: {b_cls}")
    b_slice = extract_slices_from_volume(b["img"], b["com"])
    if do_reconstruction:
        print("Reconstructing...")
        img = b["img"].to(device)
        # inference only: don't build an autograd graph over n_steps of sampling
        with torch.no_grad():
            cond = model.encode(img)
            print("encoded...")
            xT = model.encode_stochastic(img, cond, T=n_steps)
            print("encoded stochastic...")
            rec_img = model.render(xT, cond, T=n_steps).detach().cpu()
            print("rendered...")
        # rec_img is on the CPU, so slice it with the CPU copy of `com` as well
        rec_img_slice = extract_slices_from_volume(rec_img, b["com"])
        b_slice = torch.cat([b_slice, rec_img_slice], dim=0)

    n_cols = 4 if only_axial else 12
    n_rows = 1 + int(do_reconstruction)
    fig, axs = plt.subplots(n_rows, n_cols, squeeze=True,
                            gridspec_kw={'wspace': 0, 'hspace': 0})

    stride = 3 if only_axial else 1
    for img, ax in zip(b_slice[::stride], axs.flatten()):
        # min-max normalise to [0, 1]; clamp avoids 0/0 on constant slices
        img = (img - img.min()) / (img.max() - img.min()).clamp_min(1e-8)
        plot_tensor(img, ax)
        ax.axis("off")

    rec_suffix = "_rec" if do_reconstruction else ""
    img_dir = f"./imgs/axial_slices{rec_suffix}/"
    os.makedirs(img_dir, exist_ok=True)

    fig.savefig(os.path.join(img_dir, f"{b_cls}.png"), bbox_inches="tight", pad_inches=0)
    plt.show()
    plt.close(fig)

In [None]:
from sklearn.manifold import TSNE
from umap import UMAP
from itertools import islice    

In [None]:
def pca_reduction(x: torch.Tensor, d_low: int = 2):
    """Project the rows of ``x`` onto their top ``d_low`` principal components.

    Args:
        x: 2-D tensor, one sample per row.
        d_low: number of principal components to keep.

    Returns:
        Tensor with one row per sample and ``min(d_low, *x.shape)`` columns,
        holding the coordinates of the mean-centred samples in the PCA basis.
    """
    # PCA operates on mean-centred data. The previous version used the uncentred
    # second-moment matrix x.T @ x (whose top direction points at the mean) and then
    # passed that matrix to torch.pca_lowrank, which re-centres its own input.
    x_centered = x - x.mean(dim=0, keepdim=True)
    # rows of vh are the principal directions, sorted by decreasing singular value
    _, _, vh = torch.linalg.svd(x_centered, full_matrices=False)
    return x_centered @ vh[:d_low].T

def tsne_reduction(x: torch.Tensor, d_low: int = 2):
    """Embed the rows of ``x`` into ``d_low`` dimensions with t-SNE.

    The perplexity is capped at 50 and at ``n_samples - 1``, presumably to satisfy
    scikit-learn's requirement that perplexity be smaller than the sample count.
    Returns the NumPy embedding produced by ``TSNE.fit_transform``.
    """
    perplexity = min(x.size(0) - 1, 50.0)
    reducer = TSNE(n_components=d_low, random_state=SEEED, perplexity=perplexity)
    samples = x.detach().cpu().numpy()
    return reducer.fit_transform(samples)

def umap_reduction(x: torch.Tensor, d_low: int = 2):
    """Embed the rows of ``x`` into ``d_low`` dimensions with UMAP.

    Returns the embedding produced by ``UMAP.fit_transform``.
    """
    # Hand UMAP a host NumPy copy: the tensor itself may live on the GPU or
    # require grad (previously this copy was computed but the raw tensor was used).
    x_np = x.detach().cpu().numpy()
    reducer = UMAP(n_components=d_low, init='spectral', random_state=SEEED, low_memory=False)
    return reducer.fit_transform(x_np)

In [None]:
# Restore previously saved latents/labels if both files exist; otherwise start
# with empty lists. The lists are extended batch-wise below and concatenated later.
latents_path = f"{conf.logdir}/latents.pt"
cls_labels_path = f"{conf.logdir}/cls_labels.pt"
try:
    cached_latents = torch.load(latents_path)
    cached_cls_labels = torch.load(cls_labels_path)
except OSError as e:
    print(e)
    latents_list, cls_labels_list = [], []
else:
    latents_list, cls_labels_list = [cached_latents], [cached_cls_labels]
i_batch = 0

In [None]:

# Larger batches (batch_size=16) for encoding the whole dataset into latents.
conf.num_workers = 4
conf.batch_size = 16

dataset_kwargs = dict(
    img_size=conf.img_size,
    mri_sequences=conf.mri_sequences,
    mri_crop=conf.mri_crop,
    train_mode=conf.train_mode,
    filter_class_labels=True,
)
data = PublicGliomaDataset(conf.data_path, **dataset_kwargs)
loader = torch.utils.data.DataLoader(
    data,
    batch_size=conf.batch_size,
    shuffle=False,
    num_workers=conf.num_workers,
)


In [None]:
# Choose how many batches to encode and wrap that slice of the loader in a
# notebook progress bar (`tqdm.notebook` must be imported explicitly).
use_all_data = True
if use_all_data:
    n = len(loader)
else:
    n_sample_pts = 10
    # ceil so at least n_sample_pts samples are covered, capped at the number of
    # batches the loader can actually yield (keeps the progress-bar total honest)
    n = min(len(loader), max(1, math.ceil(n_sample_pts / conf.batch_size)))
print(f"{n = }")
loader_n = islice(loader, n)
batch_iter = tqdm.notebook.tqdm(loader_n, total=n)


In [None]:
# Encode the selected batches into the model's latent space.
if i_batch == 0 and latents_list:
    # The lists already hold latents restored from disk and nothing has been encoded
    # in this session. Encoding again would append a second copy of the same samples,
    # which the save cell would write back, doubling the cache on every rerun.
    n_cached = sum(t.size(0) for t in latents_list)
    print(f"{n_cached} cached latents loaded; skipping encoding.")
else:
    # Iterate the progress-bar wrapper rather than the bare islice. Both draw from
    # the same `loader_n` iterator, so re-running after an interrupt continues
    # from where it stopped.
    for batch in batch_iter:
        with torch.no_grad():
            latent = model.encode(batch["img"].to(device))
        # append labels only after encoding succeeded so the two lists stay aligned
        latents_list.append(latent.detach().cpu())
        cls_labels_list.append(batch["cls_labels"])
        i_batch += 1


In [None]:
# Stitch the per-batch results into single tensors.
latents = torch.cat(latents_list, dim=0)
cls_labels = torch.cat(cls_labels_list)
# every latent row needs exactly one class label for the plots/clustering below
assert latents.size(0) == cls_labels.size(0), (latents.shape, cls_labels.shape)

In [None]:
# Persist the encoded latents and their labels in the log directory for later reuse.
for _obj, _path in (
    (latents, f"{conf.logdir}/latents.pt"),
    (cls_labels, f"{conf.logdir}/cls_labels.pt"),
):
    torch.save(_obj, _path)

In [None]:
cls_labels.shape  # same torch.Size as cls_labels.size()

In [None]:
# 2-D embeddings of the latents with each reduction method, coloured by class.
dim_red_fns = {
    "pca": pca_reduction,
    "tsne": tsne_reduction,
    "umap": umap_reduction
}
cls_names = ["Astrocytoma", "Glioblastoma", "Oligodendroglioma"]
# class label to leave out of the embedding (e.g. 1); None keeps every sample
exclude_cls = None
img_dir = "imgs/plots/"
os.makedirs(img_dir, exist_ok=True)

for dim_red_fn_name, dim_red_fn in dim_red_fns.items():
    if exclude_cls is None:
        sample_mask = torch.ones_like(cls_labels, dtype=torch.bool)
    else:
        sample_mask = cls_labels != exclude_cls
    # PCA returns a tensor while t-SNE/UMAP return NumPy arrays; unify for plotting.
    # After the loop, `latents_low`/`cls_labels_low` hold the last method's result.
    latents_low = np.asarray(dim_red_fn(latents[sample_mask], 2))
    cls_labels_low = cls_labels[sample_mask]

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_title(f"{dim_red_fn_name}")
    for i, cls_name in enumerate(cls_names):
        cls_mask = (cls_labels_low == i).numpy()
        ax.scatter(latents_low[cls_mask, 1],
                   latents_low[cls_mask, 0],
                   label=cls_name,
                   alpha=0.5)
    ax.set_xlabel("component 1")
    ax.set_ylabel("component 0")
    ax.legend()
    fig.savefig(f"{img_dir}/{dim_red_fn_name}_reduction.png",
                bbox_inches="tight",
                pad_inches=0)
    plt.show()
    plt.close(fig)


In [None]:
from sklearn.cluster import KMeans, DBSCAN

In [None]:
# Cluster the most recent 2-D embedding (`latents_low` from the last reduction
# method above) and count the class labels that fall in each cluster.
clustering = DBSCAN(metric='euclidean', min_samples=10, eps=3, n_jobs=-1).fit(latents_low)
# clustering = KMeans(n_clusters=3, n_init=1).fit(latents_low)
c_labels = clustering.labels_
c_i = 0
c_mask = c_labels == c_i

fig, (ax1, ax2) = plt.subplots(1, 2)
# left: all cluster ids; right: membership of cluster `c_i` only
for panel, colors in ((ax1, c_labels), (ax2, c_mask)):
    panel.scatter(latents_low[:, 1], latents_low[:, 0], c=colors)
plt.show()

# per-cluster class histogram; DBSCAN noise (label -1) is not included
cluster_to_cls_labels = {}
for i_cluster in range(c_labels.max() + 1):
    members = cls_labels_low[c_labels == i_cluster]
    values, counts = members.unique(return_counts=True)
    cluster_to_cls_labels[i_cluster] = dict(zip(values.tolist(), counts.tolist()))
cluster_to_cls_labels

In [None]:
def split_given_size(a, size):
    """Split ``a`` into consecutive chunks of ``size`` items (the last may be shorter).

    Returns the list of sub-arrays produced by ``np.split``.
    """
    cut_points = np.arange(size, len(a), size)
    return np.split(a, indices_or_sections=cut_points)

In [None]:
# One volume per batch so each sample can be paired with its entry in `c_mask`.
loader = torch.utils.data.DataLoader(
    data,
    shuffle=False,
    batch_size=1,
    num_workers=conf.num_workers,
)

In [None]:
# Walk samples in dataset order (the loader above is unshuffled) together with
# their cluster-membership flag from `c_mask`; consumed incrementally below.
img_iter = enumerate(zip(loader, c_mask))

In [None]:
# Advance `img_iter` to the next sample whose cluster-mask flag is set.
i_b, (b, b_m) = next(item for item in img_iter if item[1][1])

In [None]:
# Show every extracted slice of the selected sample plus its segmentation labels.
print(i_b + 1, "/", len(loader))
if not b_m:
    print("not in cluster")
else:
    # concatenate the segmentation along dim 1 (presumably the channel axis)
    # so it is sliced together with the image
    b_img = torch.cat([b["img"], b["seg_labels"]], dim=1)
    b_com = b["com"]
    img_slice = extract_slices_from_volume(b_img, b_com)

    n_slice_rows = img_slice.size(0) // 3
    fig, axs = plt.subplots(n_slice_rows, 3)
    fig.suptitle(f'Label: {b["cls_labels"][0].item()}')
    for img, ax in zip(img_slice, axs.flatten()):
        ax.axis("off")
        plot_tensor(img, ax)
    plt.tight_layout()
    plt.show()