In [None]:
import os
import logging
import torch
import torch.optim as optim
import torch.nn.functional as F

from utils.dataset_loader import CustomDatasetFromSource
from torch.utils.data import DataLoader, Dataset
from diffusers import StableDiffusionPipeline
from torchvision import transforms
from datasets import load_dataset

In [None]:
# Create the directory that will hold the training log. `exist_ok=True` makes
# the call a no-op when the directory is already there, so the cell is safe to
# re-run and has no check-then-create race.
os.makedirs('logs', exist_ok=True)

In [None]:
# Train on the GPU when one is visible to PyTorch; otherwise fall back to the
# CPU so the notebook still runs (slowly) instead of crashing on `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
save_name = "histopathology-diffusion-t2i"

# All `logging.*` calls in the notebook are written to ./logs/<save_name>.log.
# The directory is created here as well so this cell does not depend on any
# other cell having been executed first.
log_path = f'./logs/{save_name}.log'
os.makedirs(os.path.dirname(log_path), exist_ok=True)
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    filename=log_path,)

In [None]:
# Load the pretrained Stable Diffusion v1.5 pipeline and pull out the
# components that the training loop uses directly.
pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
unet = pipeline.unet.to(device)
vae = pipeline.vae.to(device)
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder  # not moved here; stays on its load device
scheduler = pipeline.scheduler

# Only the UNet is handed to the optimizer, so freeze the VAE and the text
# encoder: with requires_grad disabled, autograd does not record a graph
# through them, which saves memory on every training step.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
vae.eval()
text_encoder.eval()

# Trailing semicolon keeps the notebook from echoing the full UNet repr.
unet.train();

In [None]:
# Fine-tuning hyper-parameters. Only the UNet's parameters are given to AdamW.
num_epochs = 5
batch_size = 1
optimizer = optim.AdamW(params=unet.parameters(), lr=1e-5)

# Write the run configuration to the log, one entry per setting.
for _fmt, _val in (
    ("Optimizer: {}", optimizer),
    ("Batch size: {}", batch_size),
    ("Number of epochs: {}", num_epochs),
):
    logging.info(_fmt.format(_val))

In [None]:
# Image preprocessing applied to every sample, in order: resize to 224x224,
# convert to a tensor, then normalise with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(preprocessing_steps)

transform_log_line = "Transform: {}".format(transform)
logging.info(transform_log_line)

In [None]:
# Fetch the dataset and carve a seeded 85/15 split out of its "train"
# partition; only the larger ("train") side is used below.
dataset = load_dataset("Cilem/histopathology")
split_datasets = dataset["train"].train_test_split(seed=42, test_size=0.15)
train_dset = split_datasets["train"]

# Wrap the split so that `transform` is applied to its samples, and iterate
# over it in shuffled batches.
train_dataset = CustomDatasetFromSource(train_dset, transform=transform)
dataloader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=batch_size,
)

dataset_log_line = "Dataset: {}".format(train_dset)
logging.info(dataset_log_line)
size_log_line = "Train dataset size: {}".format(len(train_dataset))
logging.info(size_log_line)

In [None]:
# Fixed caption used as the text condition for the training images.
prompt = "a histopathology image"

prompt_log_line = "Prompt: {}".format(prompt)
logging.info(prompt_log_line)

In [None]:
# Fine-tune the UNet to predict the noise that was added to VAE latents,
# conditioned on the embedding of the fixed text prompt.

# The prompt is identical for every sample, so tokenize and encode it once
# instead of on every batch. Padding to the tokenizer's maximum length and
# passing only `input_ids` follows the diffusers text-to-image training
# example, so the UNet sees embeddings shaped like the ones it gets at
# inference time.
with torch.no_grad():
    prompt_tokens = tokenizer(
        [prompt],
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    )
    # Feed the ids on whatever device the text encoder lives on, then move the
    # resulting embedding to the training device.
    prompt_input_ids = prompt_tokens.input_ids.to(text_encoder.device)
    prompt_embeds = text_encoder(prompt_input_ids).last_hidden_state.to(device)

for epoch in range(num_epochs):
    # Accumulate the loss on-device so no per-step host sync is forced.
    running_loss = torch.zeros((), device=device)
    num_batches = 0

    for batch in dataloader:
        images = batch["image"].to(device)

        # Broadcast the single prompt embedding across the batch (a view, no copy).
        encoder_hidden_states = prompt_embeds.expand(images.shape[0], -1, -1)

        # The VAE is not being trained, so encode without recording a graph.
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample()
            # Use the scaling factor stored in the VAE config rather than a
            # hard-coded constant (0.18215 for Stable Diffusion v1.x).
            latents = latents * vae.config.scaling_factor

        # Draw the noise and one random timestep per sample.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=device,
        )

        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample

        loss = F.mse_loss(noise_pred, noise)
        # A fresh graph is built every step, so there is nothing to retain.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.detach()
        num_batches += 1

    if num_batches == 0:
        # Nothing to report (and nothing to train on) if the loader is empty.
        logging.warning("Dataloader yielded no batches; stopping training.")
        break

    # Report the mean loss over the epoch; a single batch's loss at a random
    # timestep is too noisy to be informative.
    epoch_loss = (running_loss / num_batches).item()
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")
    logging.info(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}")

In [None]:
# Save the fine-tuned model as a complete Stable Diffusion pipeline.
#
# `unet` and `vae` were taken from `pipeline` via `nn.Module.to`, which returns
# the module itself, and the optimizer updated the UNet weights in place. The
# pipeline loaded at the start therefore already holds the fine-tuned UNet
# together with the same VAE, scheduler, tokenizer and text encoder.
# Saving it directly avoids loading the base checkpoint a second time, a
# hub/disk round-trip that could fail only after training has finished.
pipe = pipeline
pipe.save_pretrained(save_name)

logging.info("Model saved as {}".format(save_name))