In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from datasets import load_dataset

from diffusers import StableDiffusionPipeline
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPTokenizer, CLIPTextModel, CLIPFeatureExtractor

In [None]:
# Pretrained components. The VAE and the safety checker are read from
# subfolders of the Stable Diffusion v1.5 repository; the tokenizer, text
# encoder and feature extractor are read from the CLIP ViT-L/14 repository.
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="vae",
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="safety_checker",
)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")

# Noise schedule used for training (1000 diffusion steps).
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# Conditional UNet built from an explicit configuration (no pretrained weights
# are loaded for it here).
unet = UNet2DConditionModel(
    # Input/output geometry: 4-channel latents at 32x32.
    # Original note: generated samples are 256x256 (presumably via the VAE's
    # 8x spatial upsampling -- confirm against the VAE config).
    sample_size=32,
    in_channels=4,
    out_channels=4,
    center_input_sample=False,
    # Encoder/decoder layout: three cross-attention stages plus one plain stage
    # on the way down, mirrored on the way up.
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    block_out_channels=(320, 640, 1280, 1280),
    layers_per_block=2,
    downsample_padding=1,
    mid_block_scale_factor=1,
    # Attention / conditioning: the conditioning vectors fed to cross-attention
    # must have a last dimension of 384.
    attention_head_dim=8,
    cross_attention_dim=384,
    # Activation, normalization and timestep-embedding settings.
    act_fn="silu",
    norm_num_groups=32,
    norm_eps=1e-5,
    flip_sin_to_cos=True,
    freq_shift=0,
)

In [None]:
# Pick the compute device.
# GPU index 2 is preferred (as in the original setup), but "cuda:2" is only
# valid when at least three GPUs are visible; on smaller machines fall back to
# the first GPU instead of failing with an invalid device ordinal. Without
# CUDA, everything runs on the CPU.
if torch.cuda.is_available():
    gpu_index = 2 if torch.cuda.device_count() > 2 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Move the models that take part in training onto the chosen device.
# The trailing semicolon keeps the notebook from printing the full module repr.
vae.to(device)
unet.to(device);

In [4]:
class CustomDataset(Dataset):
    """Map-style dataset over the "train" split of a dataset mapping.

    Each item is a dict with two entries:
      * "image": the row's "image" value, passed through ``transform`` when
        one was supplied.
      * "embedding": the row's "embedding_vector" converted to a float32
        tensor.

    Args:
        dataset: Mapping with a "train" entry. That entry must support
            ``len()`` and integer indexing, and each row must be a mapping
            with "image" and "embedding_vector" keys.
        transform: Optional callable applied to the image before it is
            returned.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset["train"]
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the "train" split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "embedding": float32 tensor}`` for row ``idx``."""
        # Index the split only once per item: every lookup materializes the
        # whole row (which may be costly, e.g. if the backend decodes the image
        # on access), so the row is reused for both fields.
        row = self.dataset[idx]

        image = row["image"]
        if self.transform is not None:
            image = self.transform(image)

        embedding = torch.tensor(row["embedding_vector"], dtype=torch.float32)
        return {"image": image, "embedding": embedding}

In [5]:
# Training hyperparameters.
num_epochs = 10
batch_size = 4

# AdamW over the UNet's parameters only, so the UNet is the sole model that
# gets updated by the optimizer.
optimizer = optim.AdamW(params=unet.parameters(), lr=1e-5)

In [6]:
# Image preprocessing applied to every sample before VAE encoding.
#
# Images are resized to 256x256 so that training happens at the resolution the
# UNet is configured for (sample_size=32 latents, i.e. 256x256 samples as noted
# on the UNet config). The previous 224x224 resize trained the model at a
# different latent size than the one it is set up to generate at.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5: rescales ToTensor's [0, 1] output to [-1, 1].
    transforms.Normalize([0.5], [0.5]),
])

In [None]:
# Fetch the histopathology dataset and wrap its "train" split so that each item
# yields a transformed image together with its embedding tensor.
dataset = load_dataset("Cilem/histopathology")
custom_dataset = CustomDataset(dataset, transform=transform)

# Shuffled mini-batches for training.
dataloader = DataLoader(
    dataset=custom_dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [None]:
# Train the UNet to predict the noise that was added to VAE latents,
# conditioned on the per-image embedding vectors from the dataset.
# Only the UNet is optimized (the optimizer was built from unet.parameters());
# the VAE acts as a fixed encoder.
unet.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        images = batch["image"].to(device)
        embeddings = batch["embedding"].to(device)

        # Cross-attention consumes conditioning shaped (batch, seq_len, dim).
        # A flat (batch, dim) embedding is treated as a length-1 sequence;
        # embeddings that already carry a sequence axis are left untouched.
        if embeddings.dim() == 2:
            embeddings = embeddings.unsqueeze(1)

        # The VAE is not being trained, so encode without building an autograd
        # graph through it (saves memory and compute).
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            # Scale latents by the VAE's configured factor
            # (presumably 0.18215 for this SD v1.5 VAE, the value previously
            # hard-coded here -- confirm against the VAE config).
            latents = latents * vae.config.scaling_factor

        # Sample the target noise and one random timestep per example.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=device,
        )

        # Forward diffusion: corrupt the clean latents at the sampled timesteps.
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=embeddings).sample

        loss = F.mse_loss(noise_pred, noise)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than only the last batch's
    # loss; this is less noisy and stays well-defined when the dataloader
    # yields no batches.
    epoch_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")

In [None]:
# Bundle the trained UNet with the remaining components into a Stable Diffusion
# pipeline object and write everything to a local directory.
# NOTE(review): the UNet was built with cross_attention_dim=384 and trained on
# dataset embeddings, while the pipeline is given the CLIP ViT-L/14 text
# encoder -- confirm the text encoder's output width matches before relying on
# text prompts at inference time.
pipe = StableDiffusionPipeline(
    # Core diffusion components.
    unet=unet,
    vae=vae,
    scheduler=noise_scheduler,
    # Prompt handling.
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    # Output safety filtering.
    feature_extractor=feature_extractor,
    safety_checker=safety_checker,
)

pipe.save_pretrained("histopathology-diffusion")