In [None]:
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from transformers import CLIPTextModel, CLIPTokenizer
import os
import argparse
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

# Run on the GPU when torch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Generating images on device: {device}")

In [None]:
class DiffusionConfigs:
    """Settings for the image-generation run, exposed as plain class attributes."""

    # Where the fine-tuned model lives and where results are written.
    path_to_model_dir = '/home/haicu/sophia.wagner/projects/dinov2/diffusion/model_0'
    output_dir = './outputs'

    # Random seed; None means "not set here".
    seed = None

    # Precision setting and sampling parameters.
    mixed_precision = 'bf_16'
    num_inference_steps = 100
    batch_size = 1
    num_images_per_prompt = 1

    # Example prompt in the format the notebook's prompts follow.
    prompts = "Histopathological image of a single cell of type lymphocyte. The image is from site LDWBC."


# Single shared configuration object used by the cells below.
args = DiffusionConfigs()

In [None]:
# Fall back to a fixed default seed when the config does not provide one.
if args.seed is None:
    args.seed = 42

if args.output_dir is not None:
    os.makedirs(args.output_dir, exist_ok=True)


def _resolve_torch_dtype(precision):
    """Turn a precision setting into a ``torch.dtype``.

    Args:
        precision: ``None``, a ``torch.dtype``, or a string such as ``'no'``,
            ``'fp32'``, ``'fp16'``, ``'bf16'`` or ``'bf_16'`` (case, ``_`` and
            ``-`` are ignored).

    Returns:
        The matching ``torch.dtype``, or ``None`` when ``precision`` is ``None``.

    Raises:
        ValueError: if the string does not name a known precision.
    """
    if precision is None or isinstance(precision, torch.dtype):
        return precision
    known = {
        'no': torch.float32,
        'fp32': torch.float32,
        'float32': torch.float32,
        'fp16': torch.float16,
        'float16': torch.float16,
        'half': torch.float16,
        'bf16': torch.bfloat16,
        'bfloat16': torch.bfloat16,
    }
    key = str(precision).lower().replace('_', '').replace('-', '')
    if key not in known:
        raise ValueError(
            f"Unknown mixed_precision setting {precision!r}; expected one of {sorted(known)}"
        )
    return known[key]


# Load the individual model components from their sub-folders of the model directory.
unet = UNet2DConditionModel.from_pretrained(
    args.path_to_model_dir, subfolder="unet"
)
tokenizer = CLIPTokenizer.from_pretrained(
    args.path_to_model_dir, subfolder="tokenizer"
)
text_encoder = CLIPTextModel.from_pretrained(
    args.path_to_model_dir, subfolder="text_encoder"
)
vae = AutoencoderKL.from_pretrained(
    args.path_to_model_dir, subfolder="vae"
)

# Assemble the pipeline from the components loaded above, without a safety checker.
# torch_dtype must be a real torch.dtype, so the configured string is converted first.
# NOTE(review): the components above are loaded without a dtype argument, so
# torch_dtype presumably only affects what the pipeline loads itself - confirm
# against the diffusers version in use.
pipeline = StableDiffusionPipeline.from_pretrained(
    args.path_to_model_dir,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    safety_checker=None,
    torch_dtype=_resolve_torch_dtype(args.mixed_precision),
)

pipeline = pipeline.to(device)
pipeline.set_progress_bar_config(disable=True)

# Seeded generator on the target device, passed to the pipeline for sampling.
generator = torch.Generator(device=device).manual_seed(args.seed)

In [None]:
# Cell-type labels listed per domain (site). Keys are the domain names.
domain_class_dict = {
    "beluga": [
        "metamyelocyte", "miscellaneous", "neutrophil segmented", "lymphocyte reactive",
        "monocyte", "neutrophil band", "eosinophil", "myeloblast",
        "basophil", "promyelocyte atypical", "lymphocyte", "mononucleose",
        "lymphocyte neoplastic", "lymphocyte large", "smudge cell", "plasma cell",
        "normoblast", "promyelocyte", "myelocyte", "hairy cell",
    ],
    "Marr": [
        "eosinophil", "monoblast", "erythroblast", "neutrophil segmented",
        "lymphocyte", "promyelocyte", "metamyelocyte", "myeloblast",
        "myelocyte", "smudge", "neutrophil band", "basophil",
        "promyelocyte bilobed", "monocyte", "lymphocyte atypical",
    ],
    "Matek": [
        "plasma cell", "basophil", "neutrophil band", "monocyte",
        "hairy cell", "promyelocyte", "neutrophil segmented", "normoblast",
        "lymphocyte large", "myeloblast", "smudge cell", "lymphocyte",
        "myelocyte", "promyelocyte atypical", "metamyelocyte", "lymphocyte neoplastic",
        "lymphocyte reactive", "eosinophil",
    ],
    "LDWBC": [
        "lymphocyte", "eosinophil", "monocyte", "neutrophil", "basophil",
    ],
    "Bodzas": [
        "neutrophil segmented", "lymphocyte", "eosinophil", "monocyte",
        "basophil", "lymphoblast", "neutrophil", "myeloblast",
        "normoblast",
    ],
}

In [None]:
# for prompt in args.prompts:

# Sites (domains) that prompts are generated for.
domains = ["beluga", "Marr", "Matek", "LDWBC", "Bodzas"]

# Cell-type labels that prompts are generated for.
labels = [
    "metamyelocyte", "miscellaneous", "neutrophil segmented", "lymphocyte reactive",
    "monocyte", "neutrophil band", "eosinophil", "myeloblast",
    "basophil", "promyelocyte atypical", "lymphocyte", "mononucleose",
    "lymphocyte neoplastic", "lymphocyte large", "smudge cell", "plasma cell",
    "normoblast", "promyelocyte", "myelocyte", "hairy cell",
]


In [None]:

# One prompt per (domain, label) pair, in domain-major order: every label for
# the first domain, then every label for the next domain, and so on.
prompts = [
    f"Histopathological image of a single cell of type {label}. The image is from site {domain}."
    for domain in domains
    for label in labels
]

# Generated images for all prompts, appended in prompt order.
images = []
for prompt in tqdm(prompts):
    print(f"Generating images for:\n {prompt}")
    for start in range(0, args.num_images_per_prompt, args.batch_size):
        # Cap the last batch so that exactly num_images_per_prompt images are
        # produced even when that count is not a multiple of batch_size.
        current_batch = min(args.batch_size, args.num_images_per_prompt - start)
        with torch.autocast(device):
            imgs = pipeline(
                prompt,
                num_images_per_prompt=current_batch,
                num_inference_steps=args.num_inference_steps,
                generator=generator,
            ).images
        images.extend(imgs)

In [None]:
# Display how many images are currently held in `images`.
len(images)

In [None]:
# Save each generated image under a file name built from the label and domain
# parsed out of its prompt.
# Images and prompts are paired by position, which is only correct when there
# is exactly one image per prompt. zip() would silently drop or mislabel images
# otherwise, so fail loudly instead of writing wrongly named files.
if len(images) != len(prompts):
    raise ValueError(
        f"Cannot pair images with prompts: got {len(images)} images "
        f"for {len(prompts)} prompts"
    )

label_marker = 'Histopathological image of a single cell of type '
domain_marker = 'The image is from site '
# total= lets tqdm show real progress; a bare zip() has no length.
for img, prompt in tqdm(zip(images, prompts), total=len(prompts)):
    # Text between the marker and the next '.' is the label / domain name.
    label = prompt.split(label_marker)[-1].split('.')[0]
    domain = prompt.split(domain_marker)[-1].split('.')[0]
    # NOTE: this path pattern is hard-coded (directory and "seed42") and is the
    # same pattern the image-loading cell reads from; change both together.
    img.save(f'./outputs/generated_seed42_{label.replace(" ", "-")}_{domain}.png')

In [None]:
# Reload the saved images from disk in label-major order: for each label, one
# image per domain.
# NOTE: this rebinds `images`. The prompts list is in domain-major order, so
# re-running the save cell after this one would pair images with the wrong prompts.
images = []
for label in labels:
    for domain in domains:
        path = f'./outputs/generated_seed42_{label.replace(" ", "-")}_{domain}.png'
        # Open inside a context manager so the file handle is released, and keep
        # an in-memory copy that stays usable after the file is closed.
        with Image.open(path) as img:
            images.append(img.copy())

In [None]:
# Overview grid of the images: one row per label, one column per domain.
n_rows, n_cols = len(labels), len(domains)
# squeeze=False keeps `axs` two-dimensional even with a single label or a
# single domain, so axs[i, j] is always valid.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 100), squeeze=False)
for i, label in enumerate(labels):
    for j, domain in enumerate(domains):
        # `images` is indexed label-major: all domains of label 0, then label 1, ...
        img = images[i * n_cols + j]
        ax = axs[i, j]
        ax.imshow(img)
        ax.axis('off')
        ax.set_title(f'{label} - {domain}')

plt.savefig('./outputs/generated_images_overview.png')