In [None]:
import matplotlib.pyplot as plt
from dalle2 import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenClip
import pickle as pkl
import numpy as np
from dalle2 import DALLE2


# QuiltNet CLIP model pulled from the Hugging Face hub; the same instance is shared by the
# decoder and the diffusion prior built below.
clip = OpenClip(pretrained=None, name='hf-hub:wisdomik/QuiltNet-B-32')

In [None]:
from diffusers import DiffusionPipeline
import torch

# Stable Diffusion v1.5 baseline, loaded in half precision and moved to the GPU so its samples
# can be compared against the DALL-E 2 model.
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipeline.to("cuda")

In [None]:
# Directory holding the trained checkpoints (state dicts only).
# NOTE(review): absolute, machine-specific path — adjust for your environment.
CHECKPOINT_DIR = '/data/ekvall/wandb'

# map_location='cpu' loads every tensor onto the host regardless of the device it was saved from.
# This keeps the cell working on machines with a different GPU layout and avoids holding a second
# GPU copy of every weight; load_state_dict later copies the values into the .cuda() modules.
# torch.load unpickles the file — only load checkpoints from a trusted source.
diffusion_prior_state_dict = torch.load(
    f'{CHECKPOINT_DIR}/clip_prior_epoch_16_step_9000_state_dict.pt',
    map_location='cpu',
)
decoder_state_dict = torch.load(
    f'{CHECKPOINT_DIR}/clip_decoder_epoch_1_step_48000_state_dict.pt',
    map_location='cpu',
)

In [None]:
# Single U-Net for the decoder (image_embed_dim / text_embed_dim = 512).
# NOTE(review): these hyperparameters must match those used to train the decoder checkpoint,
# otherwise loading its state dict will fail.
unet1 = Unet(
    channels = 3,
    dim = 128,
    dim_mults = (1, 2, 4, 8),
    cond_dim = 128,
    image_embed_dim = 512,
    text_embed_dim = 512,
    # cond_on_text_encodings is left at its library default here; it must be set to True for any
    # U-Net that should be conditioned on text encodings (e.g. the first U-Net in a cascade).
).cuda()

# One-stage decoder producing 224x224 images, using the shared CLIP model, 1000 timesteps and
# separate dropout probabilities for the image and text conditioning.
decoder = Decoder(
    unet = unet1,
    clip = clip,
    image_sizes = [224],
    timesteps = 1000,
    text_cond_drop_prob = 0.5,
    image_cond_drop_prob = 0.1,
).cuda()

In [None]:
# Prior network: width 512, depth 6, 8 heads of 64 dims each.
# NOTE(review): these must match the hyperparameters of the prior checkpoint loaded later.
prior_network = DiffusionPriorNetwork(
    depth = 6,
    heads = 8,
    dim_head = 64,
    dim = 512,
).cuda()

# Diffusion prior wrapping the network above; shares the CLIP model with the decoder and uses
# 100 timesteps with a 0.2 conditioning-dropout probability.
diffusion_prior = DiffusionPrior(
    clip = clip,
    net = prior_network,
    cond_drop_prob = 0.2,
    timesteps = 100,
).cuda()

In [None]:
# Restore the trained decoder weights into the freshly built (CUDA) decoder.
# Assuming Decoder keeps torch.nn.Module's default strict=True, any key mismatch between the
# checkpoint and the architecture defined above raises rather than being silently ignored.
decoder.load_state_dict(decoder_state_dict)


In [None]:
# Restore the trained prior weights into the freshly built (CUDA) diffusion prior.
# Assuming DiffusionPrior keeps torch.nn.Module's default strict=True, a checkpoint/architecture
# mismatch raises here.
diffusion_prior.load_state_dict(diffusion_prior_state_dict)

In [None]:
# End-to-end text-to-image model combining the trained prior and decoder
# (presumably prior: text -> image embedding, decoder: embedding -> pixels).
dalle2 = DALLE2(
    decoder = decoder,
    prior = diffusion_prior,
)

# Text can be passed as plain strings (the generation loop below does this) or, presumably,
# as pre-tokenized ids when using a custom tokenizer — see the dalle2 package docs.


In [None]:
# Human-readable tissue names used to build prompts, in alphabetical order.
tissues = [
    'Adrenal gland',
    'Bile duct',
    'Bladder',
    'Breast',
    'Cervix',
    'Colon',
    'Esophagus',
    'Head and Neck',
    'Kidney',
    'Liver',
    'Lung',
    'Ovarian',
    'Pancreatic',
    'Prostate',
    'Skin',
    'Stomach',
    'Testis',
    'Thyroid',
    'Uterus',
]

# Only the first three tissues are used for the (expensive) generation loop.
tissues_small = tissues[0:3]

In [None]:
# PanNuke images and their per-image tissue-type labels (path suggests the Kaggle release, Part 1).
# NOTE(review): absolute, machine-specific path — adjust for your environment.
PANNUKE_DIR = '/data/ekvall/kaggle/Part_1/Images'

# Pass the path itself: np.load then opens and closes the file. The previous
# np.load(open(...)) left both file handles open for the lifetime of the kernel.
pan_nuke_images = np.load(f'{PANNUKE_DIR}/images.npy')
types = np.load(f'{PANNUKE_DIR}/types.npy')

# `types` is used as a boolean mask over `pan_nuke_images`, so they must align one-to-one.
assert len(pan_nuke_images) == len(types), "images.npy and types.npy have different lengths"

In [None]:
N_SAMPLES = 5  # images per row (per source) in each comparison figure
ROW_TITLES = ["DALL-E Generated Images", "Stable Diffusion Generated Images", "PanNuke Images"]

# `set(types)` iterates in hash order (randomised per process for strings), so zipping it with
# `tissues_small` paired each tissue name with an arbitrary PanNuke type. `np.unique` returns the
# sorted unique labels, giving a deterministic order that lines up with the alphabetical `tissues`.
# NOTE(review): assumes sorted PanNuke type labels correspond to `tissues` order — confirm.
for t, tissue in zip(np.unique(types), tissues_small):
    selected_images = pan_nuke_images[types == t][:N_SAMPLES]

    prompt = f'{tissue} H&E stained tissue sample'
    texts = [prompt] * N_SAMPLES

    # Inference only: no autograd graph (also required for .numpy() if the output tracks grad).
    with torch.no_grad():
        images = dalle2(texts)  # presumably (N_SAMPLES, 3, 224, 224) given decoder image_sizes
    images = images.permute(0, 2, 3, 1).cpu().numpy()  # channels-last for imshow
    stable_diffusion_images = pipeline(texts).images

    fig, axs = plt.subplots(3, N_SAMPLES, figsize=(4 * N_SAMPLES, 15), squeeze=False)
    fig.suptitle(tissue, fontsize=24, y=1.05)

    # One row per source; PanNuke pixels are scaled by 1/255 for imshow.
    rows = [images, stable_diffusion_images, selected_images / 255]
    for row_idx, (row_title, row_images) in enumerate(zip(ROW_TITLES, rows)):
        for col in range(N_SAMPLES):
            ax = axs[row_idx, col]
            ax.axis('off')
            # Some tissue types may have fewer than N_SAMPLES PanNuke images; leave those slots empty.
            if col < len(row_images):
                ax.imshow(row_images[col])
        # Title on the middle axis of each row so tight_layout reserves space for it
        # (fixed fig.text coordinates are ignored by tight_layout and can overlap the images).
        axs[row_idx, N_SAMPLES // 2].set_title(row_title, fontsize=20)

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release figure memory when generating many figures in a loop
    