In [None]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
%pip install -e ~/coding/diffae

In [None]:
import argparse

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path

from training.data.glioma_public import PublicGliomaDataset
from training.data.mri import extract_slices_from_volume
from training.experiments.cls import ClsModel
from training.experiments.rep import LitModel
from training.templates.templates import gliomapublic_autoenc
from training.templates.templates_cls import gliomapublic_autoenc_cls

In [None]:
# Project root: the parent of the notebook's working directory. Checkpoint and
# output paths below are resolved relative to it.
CWD = Path.cwd().parent
CWD


In [None]:
# Fix the global RNG state of numpy (used to pick the sample below) and torch.
SEEED = 0
for seed_rng in (np.random.seed, torch.manual_seed):
    seed_rng(SEEED)
print(f"seed = {SEEED}")

In [None]:
def plot_tensor(t, ax, cmap="gray", *args, **kwargs):
    """Show a channel-first image tensor on a matplotlib axis.

    Args:
        t: tensor laid out as (C, H, W); it is permuted to (H, W, C) for imshow.
        ax: matplotlib Axes to draw on.
        cmap: colormap name, default "gray".
        *args, **kwargs: forwarded to ``ax.imshow`` (e.g. ``alpha``).

    Returns:
        The AxesImage returned by ``ax.imshow``.
    """
    # detach so tensors that still carry an autograd graph can be converted to numpy
    return ax.imshow(t.detach().permute(1, 2, 0).cpu(), cmap=cmap, *args, **kwargs)


In [None]:
# Run configuration that would normally come from the command line; consumed by
# the config templates in the next cell.
args = argparse.Namespace(
    clf_mode="multi_class",
    manipulate_znormalize=False,
    # manipulate_cls="12",
    model_name="beatgans_autoenc",
    version="5",  # autoencoder checkpoint version: "2" or "5"
    style_ch="512",
    use_healthy=True,
)
args

In [None]:
# Load the diffusion autoencoder (EMA weights) and build the matching classifier config.
device = 'cuda'

# The end of this cell overwrites args.version with the classifier version.
# Recover the autoencoder version first so re-running the cell loads the same
# checkpoint instead of a wrong one (or failing in the remapping below).
AE_TO_CLS_VERSION = {"2": "0", "5": "1"}
CLS_TO_AE_VERSION = {cls_v: ae_v for ae_v, cls_v in AE_TO_CLS_VERSION.items()}
args.version = CLS_TO_AE_VERSION.get(args.version, args.version)

conf = gliomapublic_autoenc(args=args, is_debugging=False)

state = torch.load(CWD / f'{conf.logdir}/last.ckpt', map_location='cpu')
# match sample_size to the x_T buffer stored in the checkpoint
conf.sample_size = state["state_dict"]["x_T"].shape[0]
conf.manipulate_znormalize = False
print(conf.name)
model = LitModel(conf)
# strict=False tolerates key mismatches; report them so a wrong checkpoint is not accepted silently
load_result = model.load_state_dict(state['state_dict'], strict=False)
print(f"autoencoder: {len(load_result.missing_keys)} missing / "
      f"{len(load_result.unexpected_keys)} unexpected state_dict keys")
model.ema_model.eval()
model.ema_model.to(device)
args.pretrain_path = CWD / f"checkpoints/gliomapublic_seq-all/version_{args.version}/last.ckpt"
print("version setup for healthy visualization")
# autoencoder version -> version of the classifier trained on top of it
args.version = AE_TO_CLS_VERSION[args.version]
cls_conf = gliomapublic_autoenc_cls(is_debugging=False, args=args)
print()

In [None]:
# Load the latent classifier that operates on the autoencoder's semantic code.
cls_model = ClsModel(cls_conf)
cls_state = torch.load(CWD / f'{cls_conf.logdir}/last.ckpt', map_location='cpu')
print('latent step:', cls_state['global_step'])
# strict=False tolerates key mismatches; report them so a wrong checkpoint is not accepted silently
cls_load_result = cls_model.load_state_dict(cls_state['state_dict'], strict=False)
print(f"classifier: {len(cls_load_result.missing_keys)} missing / "
      f"{len(cls_load_result.unexpected_keys)} unexpected state_dict keys")
cls_model.to(device)
# inference only: switch off training-mode behaviour (dropout, batch-norm statistics updates)
cls_model.eval()
print()


In [None]:
# Shape of the stored classifier weight (presumably (n_outputs, latent_dim) for a linear head — confirm).
classifier_weight = cls_state["state_dict"]["classifier.weight"]
classifier_weight.shape

In [None]:
# Build the dataset split from which the volume to edit is sampled.
split = "train"
dataset_kwargs = dict(
    data_dir=conf.data_path,
    img_size=conf.img_size,
    mri_sequences=conf.mri_sequences,
    mri_crop=conf.mri_crop,
    train_mode=conf.train_mode,
    filter_class_labels=True,
    split_ratio=conf.split_ratio,
    split=split,
    manipulate_cls=conf.manipulate_cls,
    use_healthy=conf.use_healthy,
)
ds = PublicGliomaDataset(**dataset_kwargs)
n_classes, n_seq = ds.num_classes, ds.n_seq
print(f"{n_classes = }, {n_seq = }")

In [None]:
# Pick a random volume from the dataset, flip its class label, and plot its slices
# (optionally with the segmentation overlaid).
i_data = np.random.randint(0, len(ds))

print(f"index in dataset: {i_data}")
sample_dict = ds[i_data]
# add a leading batch axis to every tensor entry and move it to the device (in place)
tensor_keys = [key for key, val in sample_dict.items() if isinstance(val, torch.Tensor)]
for key in tensor_keys:
    sample_dict[key] = sample_dict[key].to(device).unsqueeze(0)

edit_img = sample_dict['img']
# flip class label
is_label_one = bool(sample_dict["cls_labels"] == 1)
cls_label = torch.tensor(0 if is_label_one else 1)
print(f"flipped class label to {cls_label.item()}")
com = sample_dict["com"]

# only the non-multi-class, non-healthy setup needs the inverse label map
use_inverse_map = conf.clf_mode != "multi_class" and not conf.use_healthy
if use_inverse_map:
    og_class_label = ds.inv_cls_label_map[cls_label.item()]
else:
    og_class_label = cls_label.item()

if conf.use_healthy:
    og_class_label_name = {1: "healthy", 0: "tumor"}[og_class_label]
else:
    og_class_label_name = ds.cls_to_name[og_class_label]

print(
    f"img has class {og_class_label} ({og_class_label_name}), binary cls label: {cls_label.item()}"
)
# show the sampled volume from the dataset
img_slices = extract_slices_from_volume(edit_img, com)
seg_volume = sample_dict['seg'].repeat(1, ds.n_seq, 1, 1, 1)
seg_slices = extract_slices_from_volume(seg_volume, com)

with_seg_map = True

fig = plt.figure(figsize=(5, 8))
for panel in range(3 * n_seq):
    ax = fig.add_subplot(n_seq, 3, panel + 1)
    plot_tensor(img_slices[panel], ax)
    if with_seg_map:
        # overlay is transparent wherever the segmentation is zero
        seg_mask = seg_slices[panel][0].detach().cpu().numpy() > 0
        plot_tensor(seg_slices[panel], ax, cmap="jet", alpha=0.2 * seg_mask)
    ax.axis("off")
plt.tight_layout(pad=0)

In [None]:
# Number of diffusion steps passed as T to encode_stochastic / render:
# T_fast for quick previews, T_slow for the final run.
T_fast, T_slow = 10, 200
T = T_slow

In [None]:
# Encode the selected volume: `cond` is the semantic code from model.encode,
# `xT` the stochastic code from encode_stochastic using T diffusion steps.
# Inference only: disable autograd so no computation graph stays alive across cells.
with torch.no_grad():
    cond = model.encode(edit_img)
    print("cond size:", cond.size())
    xT = model.encode_stochastic(edit_img, cond, T=T)

In [None]:
# Compare the first extracted slice of the input volume with the same slice of its stochastic code xT.
img_slice = extract_slices_from_volume(edit_img, com)
xT_slice = extract_slices_from_volume(xT, com)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for a, t, title in zip(ax, [img_slice[0], xT_slice[0]], ["input", "xT"]):
    plot_tensor(t, a)
    a.set_title(title)
    a.axis("off")


In [None]:
# Create a new semantic code by stepping along the classifier's unit-norm weight
# direction, scaled by sqrt(conf.style_ch), to push the sample toward the other class.
l_cond = 0.2
# because we need to transfer the label to the negative class.
l_cond = -1 * l_cond if cls_label.item() == 1 else l_cond
if cls_conf.manipulate_znormalize:
    cond2 = cls_model.normalize(cond)
else:
    cond2 = cond
clf_hyperplane = F.normalize(cls_model.classifier.weight, dim=1)

# style_ch is given as a string in args ("512"); cast so np.sqrt also works if it is not converted upstream
step = l_cond * np.sqrt(int(conf.style_ch))
cond2 = cond2 + step * clf_hyperplane
if cls_conf.manipulate_znormalize:
    cond2 = cls_model.denormalize(cond2)

In [None]:
# Sanity check: the classifier's sign on the original code should agree with cls_label
# (positive <-> label 1); also report the score after editing.
pred = cls_model.classifier(cond)
pred_is_positive = bool(pred > 0)
if pred_is_positive != (cls_label.item() == 1):
    print("WARNING: classifier gave wrong prediction!")
edit_pred = cls_model.classifier(cond2)
print(f"pred: {pred.item():.2f}, edit_pred: {edit_pred.item():.2f}")

In [None]:
# Classifier decision on the edited code, mapped back to the dataset's label space.
edit_logit = cls_model.classifier(cond2)
edit_cls_label = (edit_logit > 0).int()
if conf.use_healthy:
    og_edit_cls_label = edit_cls_label.item()
else:
    og_edit_cls_label = ds.inv_cls_label_map[edit_cls_label.item()]

assert og_edit_cls_label != og_class_label, "class label should be different"

print(
    f"binary class label, og: {cls_label.item()}, new: {edit_cls_label.item()}")
if conf.use_healthy:
    og_edit_class_label_name = {0: "tumor", 1: "healthy"}[og_edit_cls_label]
else:
    og_edit_class_label_name = ds.cls_to_name[og_edit_cls_label]

print(
    f"original class label: {og_class_label} ({og_class_label_name}), new: {og_edit_cls_label} ({og_edit_class_label_name})"
)


In [None]:
# Decode the stochastic code xT with the edited semantic code cond2 to get the edited volume.
# Inference only: disable autograd so the T-step render does not build a computation graph.
with torch.no_grad():
    edit_img = model.render(xT, cond2, T=T)
edit_img_slice = extract_slices_from_volume(edit_img, com)

In [None]:
# shape of the rendered (edited) volume
edit_img.shape

In [None]:
# Output folder for the comparison figures and NIfTI volumes written below.
manip_img_dir = CWD / "imgs_manipulated" / "mri"
manip_img_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Side-by-side grid of original / edited / |edited - original| slices per MRI sequence.
# n_slice = 1 keeps one slice per sequence, n_slice = 3 keeps all three extracted slices.
n_slice = 1
assert n_slice in (1, 3), "n_slice must be 1 or 3"

stride = 1 if n_slice == 3 else 3
# only show axial slices of the mri (stride 3 keeps the first of each sequence's three slices)
img_slice_strided = img_slice[::stride]
edit_img_slice_strided = edit_img_slice[::stride]
diff_images = (edit_img_slice_strided - img_slice_strided).abs()

# interleave (original, edit, diff) per slice so each grid row shows one triple
imgs = torch.stack([img_slice_strided, edit_img_slice_strided, diff_images],
                   dim=1).view(-1, *img_slice_strided[0].size())

n_row = n_seq * n_slice
fig, axs = plt.subplots(n_row,
                        imgs.size(0) // n_row,
                        figsize=(imgs.size(0) // n_row, n_row + 1))

# NOTE(review): assumes the volume's sequence channels are ordered T1, T1CE, T2, FLAIR — confirm against conf.mri_sequences
seqs = ["T1", "T1CE", "T2", "FLAIR"]
img_mode = ["", "EDIT", "DIFF"]

og_class_label = ds.inv_cls_label_map[
    cls_label.item()] if not conf.use_healthy else cls_label.item()
edit_og_class_label = ds.inv_cls_label_map[
    edit_cls_label.item()] if not conf.use_healthy else edit_cls_label.item()
fig.suptitle(f"Class {og_class_label} -> {edit_og_class_label}")

for i, (img, ax) in enumerate(zip(imgs, axs.flatten())):
    plot_tensor(img, ax)
    ax.axis("off")
    # one (orig, edit, diff) triple per row; each sequence spans n_slice consecutive rows
    row = i // len(img_mode)
    cur_seq = seqs[row // n_slice]
    cur_mode = img_mode[i % len(img_mode)]
    title = f"{cur_seq} {cur_mode}"
    ax.title.set_text(title)

plt.tight_layout(h_pad=0, w_pad=1)
manipulation_str = f"manipulate_{og_class_label}_to_{edit_og_class_label}"
fp = manip_img_dir / f'compare_mri_{manipulation_str}_{split}{"_healthy" if conf.use_healthy else ""}{args.version}.png'
plt.savefig(fp)
print(f"saved to {fp}")
plt.show()

## Save edited volumes as NIfTI


In [None]:
from monai.transforms import SaveImage, SaveImaged

In [None]:
# value range and shape of the edited volume, to choose the uint8 conversion below
(edit_img.max(), edit_img.min(), edit_img.shape)

In [None]:
# Convert the edited volume to uint8 and give each sequence channel its own output filename.
# NOTE(review): the mapping below assumes edit_img lies in [-1, 1]; check the printed
# min/max above — if render() returns values in [0, 1], intensities end up in [127, 255].
# round() avoids the downward bias of a truncating cast.
edit_img_byte = edit_img.clamp(-1, 1).add(1).div(2).mul(255).round().to(torch.uint8)
print(
    f"edit_img_byte: {edit_img_byte.max()}, {edit_img_byte.min()}, {edit_img_byte.size()}"
)

# one volume per MRI sequence, keyed by sequence name (the keys SaveImaged writes)
save_img_dict = dict(zip(conf.mri_sequences, edit_img_byte[0]))

base_filename = edit_img_byte.meta["filename_or_obj"]
# Determine the sequence name used in the source filename. Prefer the longest match so a
# name that is a prefix of another (e.g. "t1" inside "t1ce") is not picked by mistake.
matching_seqs = [c for c in conf.mri_sequences if c in base_filename]
if not matching_seqs:
    raise ValueError(
        f"none of {conf.mri_sequences} found in filename {base_filename!r}")
cur_seq_name = max(matching_seqs, key=len)
# Give every channel its own meta dict whose filename carries that channel's sequence name.
for seq, img in save_img_dict.items():
    img.meta = {**img.meta, "filename_or_obj": base_filename.replace(cur_seq_name, seq)}


In [None]:
# Write each sequence of the edited volume as a NIfTI file (Nibabel backend) to
# <manip_img_dir>/<subject_id>/<editing_type>/.
# NOTE(review): assumes ds.subject_dirs is indexed like the dataset items (one volume per subject) — confirm.
subject_id = ds._make_patient_id(ds.subject_dirs[i_data])
editing_type = ""
if conf.use_healthy:
    # label 1 denotes "healthy" in this notebook's label-name mappings
    editing_type = "healthy_to_tumor" if og_class_label == 1 else "tumor_to_healthy"
else:
    raise NotImplementedError("only healthy vs tumor implemented")

output_dir = manip_img_dir / subject_id / editing_type
saver = SaveImaged(keys=conf.mri_sequences,
                   output_dir=output_dir,
                   output_postfix='',
                   output_ext=".nii.gz",
                   output_dtype=np.uint8,
                   resample=False,
                   squeeze_end_dims=True,
                   writer="NibabelWriter",
                   separate_folder=False)

saver(save_img_dict)
# report the directory the files were actually written to
print(f"saved to {output_dir}")