In [None]:
import torch
from diffusers import StableDiffusionPipeline
from datasets import load_dataset
import matplotlib.pyplot as plt

In [None]:
# Rows of the training split whose stored embeddings are used as conditioning below.
idxs = [31500, 32500, 36500]
dataset = load_dataset("Cilem/histopathology")
train_split = dataset['train']

# Look each row up once and reuse it: a row lookup returns every column (image included),
# so indexing the split separately for the image and for the embedding builds each row twice.
rows = [train_split[i] for i in idxs]
image_list = [row['image'] for row in rows]
# unsqueeze(0) adds a leading batch axis. The stored vector is presumably already shaped
# (seq_len, hidden_dim) as `prompt_embeds` expects — TODO confirm against the training script.
embeddings = [torch.tensor(row['embedding_vector']).unsqueeze(0) for row in rows]

# Show every selected original, however many indices are listed above.
display(*image_list)

In [None]:
# Fine-tuned embedding-to-image checkpoint, relative to the notebook's working directory.
MODEL_DIR = "trained_models/256x256/histopathology-diffusion-e2i-256"
# from_pretrained loads onto the CPU; move to the GPU when one is available,
# otherwise multi-step sampling in the next cell runs (slowly) on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# safety_checker=None skips loading/running the safety-checker model.
pipeline = StableDiffusionPipeline.from_pretrained(MODEL_DIR, safety_checker=None).to(DEVICE)

In [None]:
# Sampling settings. A fixed seed makes the generated panels reproducible on Restart & Run All.
SEED = 0
NUM_IMAGES_PER_PROMPT = 4
NUM_INFERENCE_STEPS = 40

outs = []
for embed, original in zip(embeddings, image_list):
    # Re-seed for every prompt so each original's samples do not depend on how many
    # prompts were sampled before it. A CPU generator is what diffusers recommends for
    # noise that is reproducible regardless of the device the pipeline runs on.
    generator = torch.Generator(device="cpu").manual_seed(SEED)
    generated = pipeline(
        prompt_embeds=embed,
        num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        # guidance_scale <= 1 disables classifier-free guidance, so the conditioning
        # embedding alone drives generation (no unconditional/negative embedding needed).
        guidance_scale=0,
        generator=generator,
    ).images

    outs.append({
        'original': original,
        'generated': generated,
    })

In [None]:
# One row per original: the original in column 0, then its generated samples.
# The grid size is derived from `outs`, so changing the sample indices or images-per-prompt
# does not require editing this cell.
assert outs, "`outs` is empty — run the generation cell first"
n_rows = len(outs)
n_cols = 1 + max(len(sample['generated']) for sample in outs)

# squeeze=False keeps `axs` 2-D even for a single row, so axs[i, j] indexing always works.
# 3 inches per panel matches the original 15x9 figure for a 3x5 grid.
fig, axs = plt.subplots(
    n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), dpi=300, squeeze=False
)
# Hide every axis up front; this also blanks any panels left unused by a shorter row.
for ax in axs.flat:
    ax.axis('off')

for i, sample in enumerate(outs):
    axs[i, 0].imshow(sample['original'])
    axs[i, 0].set_title("Original")
    for j, img in enumerate(sample['generated'], start=1):
        axs[i, j].imshow(img)
        axs[i, j].set_title(f"Generated {j}")

fig.tight_layout()
plt.show()
