In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from diffusers import StableDiffusionPipeline
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from utils.dataset_loader import CustomDatasetFromSlide
from utils.latent_extractor import ImageEmbeddingExtractor, TextEmbeddingExtractor

In [None]:
# Preprocessing for every source tile: resize to IMAGE_SIZE x IMAGE_SIZE (the "-256"
# resolution of the trained models loaded below), then convert the PIL image to a tensor.
IMAGE_SIZE = 256

_preprocess_steps = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

In [None]:
# Experiment configuration.
# MODE selects what is fed to the diffusion model as `prompt_embeds`:
#   "T2I" - CLIP text embedding of a prompt templated from the sample's organ label
#   "E2I" - the per-sample `embedding` field returned by the dataset, used as-is
#   "I2I" - CLIP image embedding of the source tile
# Each mode loads its own fine-tuned checkpoint from ../trained_models.
MODE = "T2I"
device = "cuda:2"  # NOTE(review): hard-coded GPU index; adjust for the machine you run on

if MODE == "T2I":
    text_encoder_base_name = "openai/clip-vit-large-patch14"
    model_name = "../trained_models/histopathology-diffusion-t2i-256"
    text_embedding_extractor = TextEmbeddingExtractor(text_encoder_name=text_encoder_base_name, device=device)

elif MODE == "E2I":
    # No encoder needed: the conditioning vector comes straight from the dataset.
    model_name = "../trained_models/histopathology-diffusion-e2i-256"

elif MODE == "I2I":
    # The tile preprocessing `transform` from the earlier cell is reused unchanged.
    image_encoder_base_name = "openai/clip-vit-large-patch14"
    model_name = "../trained_models/histopathology-diffusion-i2i-256"
    image_embedding_extractor = ImageEmbeddingExtractor(img_encoder_name=image_encoder_base_name, device=device)

else:
    # Fail fast instead of silently running the I2I path on a typo.
    raise ValueError(f"Unknown MODE {MODE!r}; expected one of 'T2I', 'E2I', 'I2I'.")

In [None]:
# Load the Hugging Face dataset and wrap it in CustomDatasetFromSlide, which presumably
# reads each sample's image from the slides under `slide_dir` and applies `transform`.
dataset_name = "Cilem/mixed-histopathology-512"
hf_dataset = load_dataset(dataset_name)

# Set the SLIDE_DIR environment variable to point at a different slide directory
# without editing the notebook; the original location is kept as the default.
slide_dir = os.environ.get("SLIDE_DIR", "/home/cilem/Lfstorage/wsis")

dataset = CustomDatasetFromSlide(hf_dataset, slide_dir=slide_dir, transform=transform)
# shuffle=False keeps sample order stable, so the indices plotted in later cells are reproducible.
data_loader = DataLoader(dataset, batch_size=64, shuffle=False)

In [None]:
# Load the fine-tuned Stable Diffusion checkpoint chosen in the configuration cell
# (safety checker disabled) and place it on the configured device.
pipeline = StableDiffusionPipeline.from_pretrained(
    model_name,
    safety_checker=None,
).to(device)

In [None]:
# Generate images for the first N_SAMPLES dataset samples, conditioning the pipeline on
# the embedding selected by MODE. Each entry of `outs` pairs the source tile (converted
# back to PIL) with the image generated for it.
N_SAMPLES = 40            # later cells index up to outs[39]
NUM_INFERENCE_STEPS = 40
SEED = 0                  # fixed seed so reruns produce the same generated images
assert N_SAMPLES > 0, "N_SAMPLES must be positive"

to_pil = transforms.ToPILImage()
generator = torch.Generator(device=device).manual_seed(SEED)

outs = []
for data in data_loader:
    # The loader's batch size (64) can exceed the remaining budget; trim the batch so
    # no GPU time is spent on images that would never be used.
    n_take = min(len(data["image"]), N_SAMPLES - len(outs))
    image = data["image"][:n_take]
    organ = data["organ"][:n_take]
    google_embedding_vector = data["embedding"]

    if MODE == "T2I":
        # `text` stays a global: the visualization cell shows text[0] in its title.
        text = [f"histopathology image of {o}" for o in organ]
        embedding = text_embedding_extractor.extract_text_embedding(text=text)
        embedding = torch.from_numpy(embedding).to(device)

    elif MODE == "E2I":
        # NOTE(review): assumes data["embedding"] is a batched tensor with batch as the first axis.
        embedding = google_embedding_vector[:n_take]

    else:
        embedding = image_embedding_extractor.extract_image_embedding(image=image)
        # Add a length-1 sequence axis for prompt_embeds
        # (assumes the extractor returns (batch, dim) — TODO confirm).
        embedding = torch.from_numpy(embedding).to(device).unsqueeze(1)

    # guidance_scale=0.0 disables classifier-free guidance in diffusers, so only the
    # conditional embeddings are required.
    output = pipeline(
        prompt_embeds=embedding,
        guidance_scale=0.0,
        num_inference_steps=NUM_INFERENCE_STEPS,
        output_type="pil",
        generator=generator,
    ).images

    outs.extend(
        {"original": to_pil(img), "generated": gen}
        for img, gen in zip(image, output)
    )

    if len(outs) >= N_SAMPLES:
        break

In [None]:
# Visualize a 2 x N_COLS grid of samples from `outs`.
# - E2I / I2I: top row = source tiles, bottom row = the images generated from them.
# - T2I: there is no per-image source to compare against, so both rows show generated images.
N_COLS = 10
PAIR_OFFSET = 30  # first index of `outs` shown in the paired (non-T2I) view


def _plot_two_rows(top, bottom, top_title, bottom_title, suptitle):
    """Draw two rows of images in a 2 x len(top) grid and display it.

    Args:
        top: images (anything `imshow` accepts) for the first row.
        bottom: images for the second row; extra axes are left empty if shorter.
        top_title: title put above every panel of the first row.
        bottom_title: title put above every panel of the second row.
        suptitle: figure-level title.
    """
    n_cols = len(top)
    # squeeze=False keeps `axs` 2-D even when n_cols == 1.
    fig, axs = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), dpi=300, squeeze=False)
    fig.suptitle(suptitle, fontsize=20)

    for row, (images, title) in enumerate(((top, top_title), (bottom, bottom_title))):
        for ax, img in zip(axs[row], images):
            ax.imshow(img)
            ax.axis("off")
            ax.set_title(title)

    plt.show()


if MODE != "T2I":
    assert len(outs) >= PAIR_OFFSET + N_COLS, (
        f"need at least {PAIR_OFFSET + N_COLS} samples in `outs`, got {len(outs)}"
    )
    shown = outs[PAIR_OFFSET:PAIR_OFFSET + N_COLS]
    _plot_two_rows(
        [o["original"] for o in shown],
        [o["generated"] for o in shown],
        "Original",
        "Generated",
        "Original and Generated Images, Mode: {}".format(MODE),
    )

else:
    assert len(outs) >= 2 * N_COLS, (
        f"need at least {2 * N_COLS} samples in `outs`, got {len(outs)}"
    )
    # NOTE(review): text[0] is only the first sample's prompt; other panels may have been
    # generated for different organs.
    _plot_two_rows(
        [o["generated"] for o in outs[:N_COLS]],
        [o["generated"] for o in outs[N_COLS:2 * N_COLS]],
        "Generated",
        "Generated",
        "Generated Images, Mode: {}\n Prompt: {}".format(MODE, text[0]),
    )



In [None]:
# Show the first 25 source tiles from `outs` in a 5x5 grid, each labelled with its index.
# NOTE(review): these are the tiles fed to generation; whether they are the training
# split depends on how CustomDatasetFromSlide indexes the loaded dataset.
fig, axs = plt.subplots(5, 5, figsize=(10, 10), dpi=100)
fig.suptitle("Training Images", fontsize=20)

# axs.flat walks the grid row-major, matching (idx // 5, idx % 5).
for idx, ax in enumerate(axs.flat):
    ax.imshow(outs[idx]['original'])
    ax.axis('off')
    ax.set_title("image {}".format(idx))

plt.show()