In [None]:
import os
import random
import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path

import numpy as np
import PIL
import safetensors
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

import json
import SimpleITK as sitk
from skimage.transform import resize
import matplotlib.pyplot as plt

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    StableDiffusionImg2ImgPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available

In [None]:
class ValDataset(Dataset):
    """Paired (input, target) 2D slices plus a text prompt, for img2img evaluation.

    Each line of ``json_file`` is a JSON object with keys ``"preImg"`` (input
    slice name), ``"img"`` (target slice name) and ``"prompt"`` (text prompt).
    Slice names are resolved to ``<data_root>/<name>.nii.gz``.

    Args:
        json_file: Path to a JSON-Lines metadata file. Blank lines are ignored.
        data_root: Directory containing the ``.nii.gz`` slices.
        size: Stored as ``self.size`` but not applied; no resizing happens here.
            NOTE(review): slices are presumably already at this resolution — confirm.
    """

    def __init__(
            self,
            json_file,
            data_root,
            size=512,
    ):
        self.data_root = data_root
        self.size = size
        self.data = self._read_jsonl(json_file)

        # ToTensor turns an HxWxC ndarray into a CxHxW tensor (rescaling to [0, 1]
        # only for uint8 input); Normalize with mean=std=0.5 then maps x -> 2x - 1.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    @staticmethod
    def _read_jsonl(json_file):
        """Parse a JSON-Lines file into a list of dicts, skipping blank lines."""
        with open(json_file, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    def _load_slice(self, name):
        """Read ``<data_root>/<name>.nii.gz`` with SimpleITK and return it as a numpy array."""
        path = os.path.join(self.data_root, name + ".nii.gz")
        return sitk.GetArrayFromImage(sitk.ReadImage(path))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        """Return record ``i`` as a dict.

        Keys:
            "image": input slice repeated to 3 channels and passed through ``self.transform``.
            "text": the record's prompt.
            "target_image": target slice exactly as read from disk (no transform).
            "image_name" / "target_image_name": the record's ``preImg`` / ``img`` values.
        """
        item = self.data[i]
        image_name = item["preImg"]
        text = item["prompt"]

        # Disabled variant: use the subject's "_M00" (baseline) slice as input when it exists.
        # baseline_name = image_name.split("_")[0]+"_M00"
        # if os.path.isfile(os.path.join(self.data_root, baseline_name+".nii.gz")):
        #     image_name = baseline_name

        image = self._load_slice(image_name)
        # Replicate the slice along a new last axis to get 3 channels
        # (assumes the slice is 2D — presumably for the 3-channel VAE input).
        image = np.repeat(image[..., np.newaxis], 3, axis=-1)
        image = self.transform(image)

        target_image_name = item["img"]
        target_image = self._load_slice(target_image_name)

        return {
            "image": image,
            "text": text,
            "target_image": target_image,
            "image_name": image_name,
            "target_image_name": target_image_name,
        }

In [None]:
# Placeholder tokens registered with the tokenizer, in registration order:
# one "<{category}-{value}>" token per category value. Spellings (including
# "Hawwaiian") are kept verbatim — presumably they must match the tokens the
# checkpoint was trained with, so do not "correct" them here.
placeholder_tokens = [
    f"<{category}-{value}>"
    for category, values in (
        ("DX", ("CN", "MCI", "AD")),
        ("Age", ("-64", "65-74", "75-84", "85-")),
        ("Gender", ("Male", "Female")),
        ("Edu", ("6-12", "13-16", "17-18", "19-")),
        ("Race", ("White", "Black", "Asian", "MoreThanOne", "Unknown",
                  "Indian-Alaskan", "Hawwaiian-otherPI")),
        ("Marry", ("Married", "Widowed", "Divorced", "NeverMarried", "Unknown")),
        ("APOE4", ("0", "1", "2")),
    )
    for value in values
]

In [None]:
# Base model id and the fine-tuned checkpoint being evaluated.
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1-base"
checkpoint_path = "/inye/results/TextInversion_ADNI3_3/checkpoint-last"

weight_dtype = torch.float32
device = torch.device("cuda")

# Tokenizer and VAE are taken from the base model.
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

# Text encoder and UNet are taken from the fine-tuned checkpoint.
text_encoder = CLIPTextModel.from_pretrained(checkpoint_path, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")

In [None]:
# Register the placeholder tokens with the base tokenizer. New IDs are assigned in
# list order — presumably the same order used at training time (confirm).
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
print(num_added_tokens)  # tokens already in the vocabulary are not re-added, so re-runs print 0

# The checkpoint's text encoder must have an embedding row for every tokenizer ID;
# otherwise the placeholder IDs would index past the embedding table during encoding.
num_embeddings = text_encoder.get_input_embeddings().num_embeddings
if len(tokenizer) > num_embeddings:
    raise ValueError(
        f"Tokenizer has {len(tokenizer)} tokens but the text encoder has only "
        f"{num_embeddings} embeddings; the placeholder tokens do not match the checkpoint."
    )
if len(tokenizer) != num_embeddings:
    warnings.warn(
        f"Tokenizer size ({len(tokenizer)}) differs from text encoder embedding count "
        f"({num_embeddings}); placeholder tokens may map to the wrong embeddings."
    )

In [None]:
# Evaluation split (JSON-Lines metadata) and the directory of 2D slices it references.
test_data_json_file = "/inye/dataset/ADNI3_test_metadata2.jsonl"
data_dir = "/inye/dataset/T1_2D_slice_512/"
resolution = 512

val_dataset = ValDataset(test_data_json_file, data_dir, size=resolution)

# Build the img2img pipeline from the base model, substituting the components
# loaded above and disabling the safety checker.
pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
    pretrained_model_name_or_path,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    unet=unet,
    vae=vae,
    safety_checker=None,
    torch_dtype=weight_dtype,
)
# Swap in a DPMSolverMultistepScheduler built from the existing scheduler's config.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(device)
pipeline.set_progress_bar_config(disable=True)


In [None]:
# Img2img sampling settings. The output folder name is built from these values so
# it always describes the settings that were actually used.
seed = 42
strength = 0.75
guidance_scale = 3.0
num_inference_steps = 100

# A single generator for the whole run seeds both the VAE latent sampling and the
# pipeline noise, so each result depends on the seed AND on its position in the loop.
generator = torch.Generator(device=device).manual_seed(seed)

result_folder = (
    f"/inye/Inference/seed-{seed}_strength-{strength}"
    f"_guidance-{guidance_scale}_step-{num_inference_steps}"
)
os.makedirs(result_folder, exist_ok=True)

for i in tqdm(range(len(val_dataset))):
    val_example = val_dataset[i]
    img_name = val_example["image_name"]
    target_img_name = val_example["target_image_name"]
    input_image = val_example["image"].unsqueeze(0)  # add a batch dimension
    target_image = val_example["target_image"]
    text = val_example["text"]

    # Encode the input slice ourselves and hand the scaled latents to the pipeline as
    # `image` (NOTE(review): relies on the img2img pipeline accepting pre-encoded
    # latents — verify against the installed diffusers version). Sampling from the
    # latent distribution uses `generator` so the run is reproducible.
    with torch.no_grad():
        latent_dist = vae.encode(input_image.to(device, dtype=weight_dtype)).latent_dist
        input_latents = latent_dist.sample(generator=generator) * vae.config.scaling_factor

    with torch.autocast("cuda"):
        result_image = pipeline(
            prompt=text,
            image=input_latents,
            # The pipeline's step-count keyword is `num_inference_steps`; the previous
            # `num_steps=` keyword is not a pipeline parameter and never set the step count.
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=guidance_scale,
            output_type="np",
            generator=generator,
        ).images[0]

    # Keep only channel 0 of the HxWxC output and save it under the target's name.
    result_image = result_image[:, :, 0]
    result_image = sitk.GetImageFromArray(result_image)
    sitk.WriteImage(result_image, os.path.join(result_folder, target_img_name + ".nii.gz"))

In [None]:
# Visual check of the last example processed by the inference loop:
# input slice, ground-truth target, and generated output.
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

# vmin/vmax assume the input was normalized to [-1, 1] by the dataset transform.
ax[0].imshow(input_image[0, 0], cmap="gray", vmin=-1, vmax=1)
ax[0].set_title("Input Image")
ax[0].axis("off")

ax[1].imshow(target_image, cmap="gray")
ax[1].set_title("Target Image")
ax[1].axis("off")

# Convert into a separate name so `result_image` (the SimpleITK image written to disk)
# is left untouched and this cell can be re-run without failing.
output_array = sitk.GetArrayFromImage(result_image)
ax[2].imshow(output_array, cmap="gray", vmin=0, vmax=1)
ax[2].set_title("Output Image")
ax[2].axis("off")

print("Image Name:", img_name)
print("Target Image Name:", target_img_name)
print("Text:", text)

In [None]:
# NOTE(review): leftover cell — it only points `result_folder` at another directory and
# creates it. The inference cell reassigns `result_folder` itself, so no visible cell
# writes here. Confirm whether it is still needed.
result_folder = '/inye/Inference/strength_0.3_guidance_2.0'
Path(result_folder).mkdir(parents=True, exist_ok=True)

In [None]:
# Intensity range of the last generated slice. `result_image` is a SimpleITK image after
# the inference loop (or an array if an older plotting cell converted it). Either way it
# is already 2D, because the loop keeps only channel 0, so no channel index is applied.
output_values = (
    sitk.GetArrayFromImage(result_image)
    if isinstance(result_image, sitk.Image)
    else np.asarray(result_image)
)
output_values.min(), output_values.max()