In [None]:
import os
import pickle
import logging
import torch
import torch.optim as optim
import torch.nn.functional as F

from accelerate import Accelerator
from utils.dataset_loader import CustomDatasetWithLatent
from torch.utils.data import DataLoader

from diffusers import StableDiffusionPipeline
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor

In [None]:
# Accelerate wrapper for training: fp16 mixed precision, with TensorBoard
# tracking rooted at ./logs/training.
accelerator = Accelerator(
    mixed_precision="fp16",
    log_with="tensorboard",
    project_dir="./logs/training",
)

# Override the device accelerate selected so everything runs on GPU index 2.
# NOTE(review): this is hard-coded with no CPU fallback — confirm the target
# machine actually exposes at least three CUDA devices.
accelerator.state.device = torch.device('cuda:2')

# Make sure the log directory exists before the log file is opened later on.
# exist_ok=True replaces the check-then-create pattern, which can race with
# another process creating the same directory.
os.makedirs('logs', exist_ok=True)

In [None]:
# Run name: used for the log file name below.
save_name = "histopathology-diffusion-e2i-256"

# Training device: GPU index 2 when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:2')
else:
    device = torch.device('cpu')

# Send INFO-and-above records to logs/<save_name>.log with a timestamped format.
logging.basicConfig(
    filename=f'logs/{save_name}.log',
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

In [None]:
# Hub repositories the pretrained components are pulled from. Naming them once
# keeps the individual from_pretrained calls from drifting apart.
sd_repo = "stable-diffusion-v1-5/stable-diffusion-v1-5"
clip_repo = "openai/clip-vit-large-patch14"

# Pretrained pieces that are bundled into the final pipeline; none of them is
# passed to the optimizer in this notebook.
vae = AutoencoderKL.from_pretrained(sd_repo, subfolder="vae")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(sd_repo, subfolder="safety_checker")
tokenizer = CLIPTokenizer.from_pretrained(clip_repo)
text_encoder = CLIPTextModel.from_pretrained(clip_repo)
feature_extractor = CLIPImageProcessor.from_pretrained(clip_repo)

# Noise schedule used to corrupt latents during training (1000 diffusion steps).
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# Conditional UNet built from scratch (randomly initialised, not loaded from a
# checkpoint); this is the only model whose parameters are trained.
unet = UNet2DConditionModel(
    act_fn="silu",
    attention_head_dim=8,
    center_input_sample=False,
    downsample_padding=1,
    flip_sin_to_cos=True,
    freq_shift=0,
    mid_block_scale_factor=1,
    norm_eps=1e-05,
    norm_num_groups=32,
    # 32x32 latents. Assuming the usual 8x spatial factor of the SD v1.5 VAE,
    # that corresponds to 256x256 images (consistent with the "256" in the
    # run name), not 512x512.
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(320, 640, 1280, 1280),
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    # Width of the conditioning vectors passed as encoder_hidden_states.
    # NOTE(review): the CLIP ViT-L/14 text encoder loaded above is expected to
    # emit 768-wide states, so it may not be usable with this UNet as-is —
    # confirm which encoder produced the training embeddings.
    cross_attention_dim=384,
)

In [None]:
# Training hyperparameters.
num_epochs = 10
batch_size = 32

# AdamW over every UNet parameter with a fixed learning rate.
optimizer = optim.AdamW(params=unet.parameters(), lr=1e-5)

# Record the run configuration in the log file, one line per setting.
for template, value in (
    ("Optimizer: {}", optimizer),
    ("Batch size: {}", batch_size),
    ("Number of epochs: {}", num_epochs),
):
    logging.info(template.format(value))

In [None]:
# Single source of truth for the pickle location, so the log line below always
# reports the file that was actually read.
latent_pickle_path = "./latent_files/dataset_with_latents_e2i.pkl"

# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# pickles produced by a trusted pipeline.
# The context manager closes the file handle as soon as loading finishes.
with open(latent_pickle_path, "rb") as latent_file:
    latent_records = pickle.load(latent_file)

# Wrap the unpickled object in the project's Dataset class and batch it.
dataset_with_latents = CustomDatasetWithLatent(latent_records)
unet_dataloader = DataLoader(
    dataset_with_latents,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=24,
)

logging.info("Dataset with latents loaded from {}".format(latent_pickle_path))
logging.info("Dataset size: {}".format(len(dataset_with_latents)))

# Move the UNet to the training device before handing it to accelerate.
unet.to(device)

In [None]:
# Hand the training objects to accelerate. The flags line up positionally with
# the objects passed to prepare(): the UNet and optimizer are opted out of
# automatic device placement, the dataloader is opted in.
placement_flags = [False, False, True]  # unet, optimizer, unet_dataloader
unet, optimizer, unet_dataloader = accelerator.prepare(
    unet,
    optimizer,
    unet_dataloader,
    device_placement=placement_flags,
)

In [None]:
# Noise-prediction training loop: for each batch, corrupt the latents at a
# random timestep and train the UNet to predict the noise that was added.
for epoch in range(num_epochs):
    # Running sum of per-batch losses, kept as a tensor on the training device
    # so that no host synchronisation is needed inside the batch loop.
    epoch_loss_sum = torch.zeros((), device=device)
    num_batches = 0

    for batch in unet_dataloader:
        latents = batch["latent"].to(device)
        embeddings = batch["embedding"].to(device)

        # randn_like already matches the device of `latents`, and randint
        # returns int64 by default, so no extra casts are required.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=device,
        )

        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        with accelerator.accumulate(unet):
            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=embeddings).sample

            loss = F.mse_loss(noise_pred, noise)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

        epoch_loss_sum += loss.detach().float()
        num_batches += 1

    # Fail with a clear message instead of a NameError on `loss` when the
    # dataloader yields nothing.
    if num_batches == 0:
        raise RuntimeError("unet_dataloader yielded no batches; nothing to train on")

    # Report the mean loss over the epoch rather than the last batch's loss:
    # timesteps are sampled at random, so a single batch is a noisy indicator.
    # .item() is called once so the summary costs a single device sync.
    mean_epoch_loss = (epoch_loss_sum / num_batches).item()
    epoch_summary = f"Epoch {epoch + 1}/{num_epochs}, Loss: {mean_epoch_loss}"
    print(epoch_summary)
    logging.info(epoch_summary)

In [None]:
# `unet` is the object returned by accelerator.prepare(). Unwrap it so the
# pipeline stores the underlying model rather than any wrapper accelerate may
# have put around it for training.
trained_unet = accelerator.unwrap_model(unet)

# Bundle the trained UNet with the pretrained components into a standard
# Stable Diffusion pipeline.
# NOTE(review): the UNet was trained on precomputed embeddings with
# cross_attention_dim=384, while text_encoder is CLIP ViT-L/14 — confirm the
# bundled tokenizer/text_encoder produce conditioning this UNet can consume
# before relying on text prompts at inference time.
pipe = StableDiffusionPipeline(
    vae=vae,
    unet=trained_unet,
    scheduler=noise_scheduler,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
)

# Write the whole pipeline to a directory named after the run.
pipe.save_pretrained(save_name)

logging.info("Model saved as {}".format(save_name))