In [None]:
from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import PIL.Image
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm.auto import tqdm

In [None]:
# https://github.com/devforfu/diffusion-nbs

In [None]:
!pip install --upgrade pip
!pip install -Uq diffusers transformers fastcore fastdownload

In [None]:
def login():
    """Start the interactive Hugging Face login unless a token file already exists.

    Checks both the legacy token location (`~/.huggingface/token`) and the cache
    location (`$HF_HOME/token`, defaulting to `~/.cache/huggingface/token`).
    `huggingface_hub` is imported only when a login is actually required.

    Returns:
        None.
    """
    import os

    home = Path.home()
    hf_home = os.environ.get("HF_HOME")
    cache_dir = Path(hf_home).expanduser() if hf_home else home/'.cache'/'huggingface'
    token_files = (home/'.huggingface'/'token', cache_dir/'token')
    if any(path.exists() for path in token_files):
        return

    from huggingface_hub import notebook_login
    notebook_login()

In [None]:
# Make sure a Hugging Face token is available before any model is downloaded.
login()

In [None]:
# Suppress every log record at WARNING severity or below, process-wide
# (ERROR and CRITICAL records are still emitted).
logging.disable(logging.WARNING)

In [None]:
@dataclass(frozen=True)
class DiffusionConfig:
    """Immutable set of pretrained-model identifiers, one per pipeline component.

    `build` passes each value to the matching `from_pretrained` call.
    """
    # Identifier loaded with `AutoencoderKL`.
    vae: str = "stabilityai/sd-vae-ft-ema"
    # Identifier loaded with `UNet2DConditionModel` (from its "unet" subfolder).
    unet: str = "CompVis/stable-diffusion-v1-4"
    # Identifier loaded with `CLIPTokenizer`.
    clip_tok: str = "openai/clip-vit-large-patch14"
    # Identifier loaded with `CLIPTextModel`.
    clip_enc: str = "openai/clip-vit-large-patch14"

In [None]:
def to(models, where):
    """Call `.to(where)` on each model and return the results as a list, in input order."""
    moved = []
    for model in models:
        moved.append(model.to(where))
    return moved

In [None]:
def build(cfg: DiffusionConfig, device, half=True):
    """Load the four pipeline components named in `cfg` and move the models to `device`.

    Args:
        cfg: identifiers for the VAE, UNet, CLIP tokenizer and CLIP text encoder.
        device: target device for the three models.
        half: load model weights as float16 when true, float32 otherwise.

    Returns:
        `((vae, unet, tokenizer, text_encoder), device, half)` — the same
        positional arguments `Diffusion.__init__` expects.
    """
    dtype = torch.float16 if half else torch.float32

    vae = AutoencoderKL.from_pretrained(cfg.vae, torch_dtype=dtype)
    unet = UNet2DConditionModel.from_pretrained(cfg.unet, subfolder="unet", torch_dtype=dtype)
    # The tokenizer holds no weights: it takes no dtype and is never moved to a device.
    tok = CLIPTokenizer.from_pretrained(cfg.clip_tok)
    enc = CLIPTextModel.from_pretrained(cfg.clip_enc, torch_dtype=dtype)

    vae, unet, enc = to([vae, unet, enc], device)
    return (vae, unet, tok, enc), device, half

In [None]:
@dataclass
class Prompt:
    """A text prompt made of a required positive part and an optional negative part."""
    # Required prompt text (presumably what the image should show — confirm with the consumer).
    positive: str
    # Optional counterpart text; empty by default (presumably what to steer away from — confirm).
    negative: str = ""

In [None]:
class Diffusion:
    """Text-to-image sampler assembled from a VAE, a UNet, a CLIP tokenizer and a CLIP text encoder.

    Build the parts with `build(...)` (or use `Diffusion.from_cfg`) and call
    `generate_images` to obtain PIL images.
    """

    def __init__(self, parts, device, half=True):
        """Store the pipeline components.

        Args:
            parts: `(vae, unet, tokenizer, text_encoder)`, in that order.
            device: device the models live on; inputs are moved there.
            half: when true tensors are cast to float16, otherwise float32.
        """
        self.vae, self.unet, self.tok, self.enc = parts
        self.device = device
        self.half = half

    @property
    def dtype(self):
        """Tensor dtype implied by `half`."""
        return torch.float16 if self.half else torch.float32

    @staticmethod
    def from_cfg(cfg, device, half=True):
        """Load the components described by `cfg` and wrap them in a `Diffusion`."""
        return Diffusion(*build(cfg, device, half))

    def _encode(self, texts, max_len):
        """Tokenize `texts` (padded/truncated to `max_len`) and return the text-encoder output."""
        tokens = self.tok(texts, padding="max_length", max_length=max_len, truncation=True, return_tensors="pt")
        return self.enc(tokens.input_ids.to(self.device))[0].to(self.dtype)

    def embed(self, prompts):
        """Embed prompts for classifier-free guidance.

        Args:
            prompts: sequence of plain strings or of objects exposing `positive`
                and (optionally) `negative` string attributes, such as `Prompt`.
                A plain string is treated as a positive prompt with an empty
                negative prompt.

        Returns:
            Tensor holding the negative/unconditional embeddings followed by the
            positive embeddings along the batch axis (the order `denoise` relies on).
        """
        positives = [getattr(p, "positive", p) for p in prompts]
        negatives = [getattr(p, "negative", "") for p in prompts]
        max_len = self.tok.model_max_length
        # Inference only: do not build an autograd graph for the text encoder.
        with torch.no_grad():
            txt_emb = self._encode(positives, max_len)
            unc_emb = self._encode(negatives, max_len)
        return torch.cat([unc_emb, txt_emb])

    def latents(self, prompts, h, w):
        """Sample one standard-normal latent per prompt for an `h` x `w` pixel image.

        The latent grid is 8x smaller than the image in each spatial dimension.
        Noise is drawn on the CPU generator and then moved, so `torch.manual_seed`
        controls it regardless of `self.device`.
        """
        shape = (len(prompts), self.unet.config.in_channels, h // 8, w // 8)
        return torch.randn(shape).to(device=self.device, dtype=self.dtype)

    def denoise(self, latents, embedded, scheduler, n_steps, g=1.0):
        """Run the guided denoising loop and decode the result to images.

        Args:
            latents: starting noise; left unmodified.
            embedded: output of `embed` (negative/unconditional half first, positive half second).
            scheduler: scheduler providing `set_timesteps`, `init_noise_sigma`,
                `scale_model_input`, `timesteps` and `step`.
            n_steps: number of inference steps.
            g: guidance scale; `1.0` reduces to the text-conditioned prediction.

        Returns:
            Decoded images as a tensor with values clamped to `[0, 1]`.
        """
        scheduler.set_timesteps(n_steps)
        # Out-of-place so the caller's tensor is not silently rescaled.
        latents = latents * scheduler.init_noise_sigma
        with torch.no_grad():
            for ts in tqdm(scheduler.timesteps):
                # Both guidance branches are evaluated in one batched UNet call.
                inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
                uncond, cond = self.unet(inp, ts, encoder_hidden_states=embedded).sample.chunk(2)
                pred = uncond + g * (cond - uncond)
                latents = scheduler.step(pred, ts, latents).prev_sample
            # Latents are rescaled by 1/0.18215 before being handed to the VAE decoder.
            decoded = self.vae.decode(1 / 0.18215 * latents).sample
        # Map from [-1, 1] to [0, 1].
        return (decoded / 2 + 0.5).clamp(0, 1)

    def generate_images(self, prompts: list[Prompt], w, h, scheduler=None, n=70, gs=1.0, seeds=1):
        """Generate images for every combination of seed and guidance scale.

        Args:
            prompts: list of prompt strings (or `Prompt`-like objects, see `embed`).
            w: image width in pixels.
            h: image height in pixels.
            scheduler: scheduler to use; `default_scheduler()` when omitted.
            n: number of denoising steps.
            gs: one guidance scale or a list/tuple of them.
            seeds: one seed or a list/tuple of them.

        Returns:
            List of PIL images ordered by seed, then guidance scale, then prompt.
        """
        if scheduler is None:
            scheduler = default_scheduler()

        gs = list(gs) if isinstance(gs, (list, tuple)) else [gs]
        seeds = list(seeds) if isinstance(seeds, (list, tuple)) else [seeds]

        pil_images = []
        print("processing prompts:")
        print(prompts)

        # Text embeddings depend only on the prompts, so compute them once.
        embedded = self.embed(prompts)

        for seed in seeds:
            # Seeded once per seed: each guidance value below draws fresh noise
            # from the continuing random stream.
            torch.manual_seed(seed)
            print(f"manual seed: {seed} | g=", end="")

            for guidance in gs:
                print(f"{guidance}..", end="")

                latents = self.latents(prompts, h, w)
                denoised = self.denoise(latents, embedded, scheduler, n_steps=n, g=guidance)
                # (batch, channel, height, width) -> (batch, height, width, channel) for PIL.
                arrays = denoised.detach().permute(0, 2, 3, 1).cpu().numpy()
                images = (arrays * 255).round().astype(np.uint8)
                pil_images += [PIL.Image.fromarray(img) for img in images]

            print("done!")

        return pil_images

In [None]:
def default_scheduler():
    """Return a new `LMSDiscreteScheduler` with a scaled-linear beta schedule.

    The betas run from 0.00085 to 0.012 over 1000 training timesteps. The number
    of inference steps is not fixed here; `Diffusion.denoise` sets it through
    `scheduler.set_timesteps`.
    """
    return LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )

In [None]:
def image_grid(imgs, rows, cols):
    """Paste `imgs` row by row into a single `rows` x `cols` RGB image.

    Every cell has the size of the first image; image `i` goes to row
    `i // cols`, column `i % cols`.
    """
    cell_w, cell_h = imgs[0].size
    canvas = PIL.Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas

In [None]:
# Prefer the GPU; without one, fall back to the CPU in float32 so the notebook still runs
# (half precision is only requested on CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parts, device, half = build(DiffusionConfig(), device, half=device.type == "cuda")

In [None]:
# Wrap the loaded components in the sampler used by the cells below.
diff = Diffusion(parts, device, half)

In [None]:
# Text prompts to render; each entry yields one image per (seed, guidance) combination.
prompts = [
    "Labrador in the style of Vermeer",
]

In [None]:
# Render 512x512 images for four seeds at a single guidance scale of 7.5.
images = diff.generate_images(prompts, 512, 512, seeds=[1,2,3,4], gs=7.5)

In [None]:
# Sanity check: one image per prompt, seed and guidance value.
len(images)

In [None]:
# Show the generated images side by side in a single row of four.
image_grid(images, rows=1, cols=4)