## Import Modules

In [None]:
import cv2
import json
import numpy as np
import os
import shutil
import socket
import time
from IPython.display import display
from PIL import Image
from matplotlib import pyplot as plt
from pathlib import Path
from rich import inspect

# from torch.utils.tensorboard import SummaryWriter
import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# rcg
from config import RCGConfiguration
from engine_mage import gen_img
from pixel_generator.mage import models_mage
from rdm.util import instantiate_from_config
import util.misc as misc

%matplotlib inline


In [None]:
# rtk
from rtk._datasets import create_transforms
from rtk.config import *
from rtk.datasets import instantiate_image_dataset
from rtk.mlflow import prepare_mlflow
from rtk.repl import prepare_console
from rtk.utils import get_logger, hydra_instantiate, _strip_target

# Create the console helper used for printing in later cells
# (`console.print(...)`). `ws` is the first value returned by the helper and
# is not used in the cells shown here.
# NOTE(review): presumably `show_locals` / `_traceback` control traceback
# rendering — confirm against rtk.repl.prepare_console.
ws, console = prepare_console(show_locals=False, _traceback=False)

In [None]:
# Inputs for the Hydra composition performed in the next cell.
# NOTE(review): the config directory is an absolute, user-specific path and
# the device override is fixed to "3" — adjust both on another machine.
config_name = "rdm"
init_method_kwargs = dict(config_dir="/home/nicoleg/workspaces/rcg/config/")
overrides = [
    # value override
    "device=3",
    # dataset value overrides (dotted keys)
    "datasets.preprocessing.positive_class=null",
    "datasets.target=class_conditioned_labels",
    # dataset config-group selections (slash keys)
    "datasets/encoding=class-conditioned-encoding",
    "datasets/transforms=rdm-transforms",
]

In [None]:
# Compose the configuration named `config_name`, applying `overrides` and the
# `init_method_kwargs` defined in the previous cell, and bind it to `args`.
# NOTE(review): `set_hydra_configuration` is not imported by name in this
# notebook; presumably it is provided by `from rtk.config import *` — confirm.
args: RCGConfiguration = set_hydra_configuration(
    config_name,
    ConfigurationInstance=RCGConfiguration,
    init_method_kwargs=init_method_kwargs,
    overrides=overrides,
)
# Print the configuration object for inspection.
console.print(args)

## Load pre-trained encoder, RDM and MAGE

In [None]:
from omegaconf import OmegaConf
from rdm.models.diffusion.ddpm import RDM

# Initialize RCG-L
class_cond = False

# Pre-trained RDM checkpoint and its config file, both handed to the MAGE
# constructor below.
# rdm_ckpt_path = "outputs/rdm/2024-02-01/16-00-35/outputs/checkpoint-last.pth"
rdm_ckpt_path = "/home/nicoleg/workspaces/rcg/outputs/rdm/2024-02-01/16-00-35/outputs/checkpoint-0.pth"
rdm_cfg = "config/rdm/mocov3vitb_simplemlp_l12_w1536_classcond.yaml"

model = models_mage.mage_vit_large_patch16(
    # --- masking ratio ---
    mask_ratio_min=0.5,
    mask_ratio_max=1.0,
    mask_ratio_mu=0.75,
    mask_ratio_std=0.25,
    # --- tokenizer checkpoint ---
    vqgan_ckpt_path="pretrained_enc_ckpts/vqgan_jax_strongaug.ckpt",
    # --- conditioning ---
    use_class_label=False,
    use_rep=True,
    rep_dim=256,
    rep_drop_prob=0.1,
    # --- pre-trained encoder ---
    pretrained_enc_arch="mocov3_vit_base",
    pretrained_enc_path="pretrained_enc_ckpts/mocov3/vitb.pth.tar",
    pretrained_enc_withproj=True,
    pretrained_enc_proj_dim=256,
    # --- pre-trained RDM ---
    pretrained_rdm_cfg=rdm_cfg,
    pretrained_rdm_ckpt=rdm_ckpt_path,
)
# config = OmegaConf.load(args.config)
# model: RDM = instantiate_from_config(config.model)

# Move the model to the GPU; the trailing semicolon suppresses the notebook's
# echo of the module repr.
# NOTE(review): no `model.eval()` call is made here and the checkpoint
# load/eval cell that follows is commented out — confirm this is intended.
model.cuda();

In [None]:
# checkpoint = torch.load(args.pretrained_rdm_ckpt, map_location='cpu')
# model.load_state_dict(checkpoint['model'], strict=True)
# model.eval();

## Image Generation

In [None]:
from rtk.datasets import set_labels_from_encoding

# Seed NumPy's and PyTorch's global random number generators from the
# configured seed.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Sampling settings forwarded to `model.gen_image(...)` by the generation cell.
n_image_to_gen = 1  # number of gen_image calls per class
rdm_steps, rdm_eta = 250, 1.0  # passed as `rdm_steps=` and `eta=`
mage_steps, mage_temp = 20, 11.0  # passed as `num_iter=` and `choice_temperature=`
cfg = 6.0  # passed as `cfg=`

In [None]:
# Dataset section of the composed configuration.
dataset_cfg = args.datasets
labels = set_labels_from_encoding(args)

# Invert the configured encoding: each encoded value becomes a key that maps
# back to the name it was configured under.
class_encoding = dict(
    (encoded, name) for name, encoded in dataset_cfg.encoding.items()
)
# Echo the inverted mapping as the cell output.
class_encoding

In [None]:
def viz_torchimage(image: torch.Tensor):
    """Display a single image tensor inline in the notebook.

    Values are clamped to [0, 1], scaled to 0-255 and truncated to ``uint8``
    before being handed to PIL.

    Args:
        image: Either a channel-first ``(C, H, W)`` tensor or a 2-D ``(H, W)``
            tensor. A single-channel ``(1, H, W)`` tensor is shown as a
            grayscale image.
    """
    image = torch.clamp(image, 0, 1)
    image_np = image.detach().cpu().numpy()
    if image_np.ndim == 3:
        # PIL expects channel-last layout.
        image_np = image_np.transpose([1, 2, 0])
        if image_np.shape[-1] == 1:
            # PIL cannot build an image from an (H, W, 1) array; drop the
            # singleton channel so it is treated as grayscale.
            image_np = image_np[..., 0]
    display(Image.fromarray(np.uint8(image_np * 255)))

In [None]:
images = []
labels = dataset_cfg.labels

# Generate `n_image_to_gen` samples for every entry of `class_encoding` and
# show each one.
# NOTE(review): `visualize_scan` is not imported by name in this notebook;
# presumably it comes from the `from rtk.config import *` star import — confirm.
# Sampling is inference-only, so run it without autograd bookkeeping (the
# encoder pass later in this notebook does the same).
with torch.no_grad():
    for class_index, label in class_encoding.items():
        console.print("Generating: '{}'...".format(label))
        # Class id as a one-element int64 CUDA tensor, as passed to gen_image.
        class_label = class_index * torch.ones(1).cuda().long()
        for _ in range(n_image_to_gen):
            # sampled_rep=None: no precomputed representation is supplied.
            gen_images, _ = model.gen_image(
                1,
                num_iter=mage_steps,
                choice_temperature=mage_temp,
                sampled_rep=None,
                rdm_steps=rdm_steps,
                eta=rdm_eta,
                cfg=cfg,
                class_label=class_label,
            )
            visualize_scan(scan=gen_images[0], title=label)
            # images.append(img)

# fig, axs = plt.subplots(2, 2, figsize=(10, 10))

# for i, ax in enumerate(axs.flat):
#     ax.imshow(images[i])
#     ax.axis("off")  # to hide the axes

In [None]:
# Instantiate the image datasets described by the configuration and keep the
# last element as `test_dataset`.
# NOTE(review): this rebinds `datasets`, which the import cell bound to
# `torchvision.datasets`; the torchvision module is no longer reachable under
# that name once this cell has run.
# NOTE(review): presumably the returned sequence ends with the test split —
# confirm against rtk.datasets.instantiate_image_dataset.
datasets = instantiate_image_dataset(args)
test_dataset = datasets[-1]

In [None]:
from collections import Counter

# Count how often each value occurs in `test_dataset.labels` (cell output).
Counter(test_dataset.labels)

In [None]:
from monai.data import ThreadDataLoader

# Shuffled loader over the test dataset that yields one sample per batch.
loader = ThreadDataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=12,
)
# Keep one iterator so later cells can pull successive batches with next().
iter_loader = iter(loader)

In [None]:
# Pull the next batch from the loader iterator; this rebinds `labels` to the
# second element of that batch.
scans, labels = next(iter_loader)

In [None]:
# Echo the batch shape together with its labels (cell output).
scans.shape, labels

## GT Representation Reconstruction

### Generate Image from GT Representation

In [None]:
# Sampling settings for reconstruction from a ground-truth representation.
# They mirror the generation cell except that `cfg` is 0.0.
n_image_to_gen = 1
rdm_steps = 250
rdm_eta = 1.0
mage_temp = 11.0
mage_steps = 20
cfg = 0.0

images, _ = next(iter_loader)

images = images.cuda()
print("Ground Truth Image:")
# matplotlib cannot read CUDA memory, and a colormap needs a 2-D array: copy
# the first sample to the host and, if it still carries a leading channel
# axis, show its first channel.
gt_preview = images[0].detach().cpu().numpy()
if gt_preview.ndim == 3:
    gt_preview = gt_preview[0]
plt.imshow(gt_preview, cmap="bone")
# Render now so the figure appears directly under its caption.
plt.show()

with torch.no_grad():
    # The usual ImageNet per-channel mean/std, shaped (1, 3, 1, 1) so they
    # broadcast against the image batch.
    mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(1, 3, 1, 1)
    x_normalized = (images - mean) / std
    # Resize to 224 before feeding the pre-trained encoder.
    x_normalized = torch.nn.functional.interpolate(x_normalized, 224, mode="bicubic")
    rep = model.pretrained_encoder.forward_features(x_normalized)
    if model.pretrained_enc_withproj:
        rep = model.pretrained_encoder.head(rep)
    # Standardize each representation along dim 1 (zero mean, unit std).
    rep_std = torch.std(rep, dim=1, keepdim=True)
    rep_mean = torch.mean(rep, dim=1, keepdim=True)
    rep = (rep - rep_mean) / rep_std

print("Reconstructed Images:")
recon_image_list = []
# NOTE(review): gen_image is asked for 12 samples while `rep` is computed from
# a single loader batch (the loader uses batch_size=1) — confirm gen_image
# accepts this combination.
# NOTE(review): `visualize_scan` is not imported by name in this notebook;
# presumably it comes from the `from rtk.config import *` star import — confirm.
# Generation is inference-only, so keep autograd disabled here as well.
with torch.no_grad():
    for _ in range(n_image_to_gen):
        recon_images, _ = model.gen_image(
            12,
            num_iter=mage_steps,
            choice_temperature=mage_temp,
            sampled_rep=rep,
            rdm_steps=rdm_steps,
            eta=rdm_eta,
            cfg=cfg,
            class_label=None,
        )
        visualize_scan(scan=recon_images[0])