In [1]:
from pathlib import Path

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers import LMSDiscreteScheduler
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


def latents_to_pil(latents, vae):
    """Decode a batch of latents with ``vae`` and return a list of PIL images.

    Args:
        latents: Tensor handed to ``vae.decode`` after being multiplied by
            ``1 / 0.18215``.
        vae: Object whose ``decode(...)`` result exposes a ``.sample`` tensor.

    Returns:
        list: One ``PIL.Image`` per entry along the first axis of the decoded
        batch.
    """
    rescaled = latents * (1 / 0.18215)
    with torch.no_grad():
        decoded = vae.decode(rescaled).sample

    # Shift/scale the decoder output with x / 2 + 0.5 and clip it to [0, 1].
    unit_range = (decoded / 2 + 0.5).clamp(0, 1)

    # Axis 1 is moved last before handing the data to PIL
    # (presumably NCHW -> NHWC; confirm against the VAE's output layout).
    channels_last = unit_range.detach().cpu().permute(0, 2, 3, 1).numpy()
    as_bytes = (channels_last * 255).round().astype("uint8")
    return list(map(Image.fromarray, as_bytes))

class StableDiffusion:
    """Minimal text-to-image sampler assembled from Stable Diffusion parts.

    Loads a CLIP tokenizer and text encoder together with the VAE and UNet of
    a Stable Diffusion checkpoint. All models are loaded in float16 and moved
    to the ``"cuda"`` device, so a CUDA-capable GPU is required.
    """

    def __init__(
            self,
            sd_model="CompVis/stable-diffusion-v1-4",
            encoder_model="openai/clip-vit-large-patch14",
            max_embeddings_length=77,
        ):
        """Load all sub-models and cache the empty-prompt embedding.

        Args:
            sd_model: Checkpoint id/path providing the ``vae`` and ``unet``
                subfolders.
            encoder_model: Checkpoint id/path of the CLIP tokenizer and text
                encoder.
            max_embeddings_length: Token length every prompt is padded or
                truncated to.
        """
        # A tokenizer holds no weights, so it takes no ``torch_dtype``.
        self.tokenizer = CLIPTokenizer.from_pretrained(encoder_model)
        self.text_encoder = CLIPTextModel.from_pretrained(encoder_model, torch_dtype=torch.float16).to("cuda")
        self.vae = AutoencoderKL.from_pretrained(sd_model, subfolder="vae", torch_dtype=torch.float16).to("cuda")
        self.unet = UNet2DConditionModel.from_pretrained(sd_model, subfolder="unet", torch_dtype=torch.float16).to("cuda")

        self.max_embeddings_length = max_embeddings_length
        # Detach so the cached tensor does not keep the text encoder's
        # autograd graph alive for the lifetime of this object.
        self.unconditional_embeddings = self.encode_text([""]).detach()

    def encode_text(self, prompts):
        """Tokenize ``prompts`` and return the text encoder's first output.

        Args:
            prompts: List of prompt strings.

        Returns:
            torch.Tensor: Half-precision tensor on ``"cuda"``; element ``[0]``
            of the text encoder's output for the padded/truncated token ids.
        """
        inp = self.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.max_embeddings_length,
            truncation=True,
            return_tensors="pt"
        )
        return self.text_encoder(inp.input_ids.to("cuda"))[0].half()

    def _initial_noise(self, bs, dim, seed):
        """Draw the starting latents on the CPU, reproducibly for a given seed.

        A dedicated ``torch.Generator`` is used so the global RNG state is
        left untouched. ``seed=None`` falls back to the global RNG.
        """
        shape = (bs, self.unet.config.in_channels, dim // 8, dim // 8)
        if seed is None:
            return torch.randn(shape)
        generator = torch.Generator().manual_seed(int(seed))
        return torch.randn(shape, generator=generator)

    def generate(self, embeddings,  g=7.5, seed=100, steps=70, dim=512):
        """Sample images conditioned on ``embeddings``.

        Args:
            embeddings: Text embeddings, one row per image (e.g. the output
                of ``encode_text``).
            g: Guidance scale applied to the difference between the
                text-conditioned and unconditional noise predictions.
            seed: Seed for the initial latent noise; the same seed yields the
                same starting latents. ``None`` draws unseeded noise.
            steps: Number of scheduler timesteps.
            dim: Image side length in pixels; latents are ``dim // 8`` wide.

        Returns:
            The result of ``latents_to_pil`` for the final latents.
        """
        bs = embeddings.shape[0]
        embeddings = embeddings.to("cuda").half()

        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000
        )
        scheduler.set_timesteps(steps)

        latents = self._initial_noise(bs, dim, seed)
        latents = latents.to("cuda").half() * scheduler.init_noise_sigma

        # Batch layout: first ``bs`` rows unconditional, last ``bs`` rows
        # conditioned, so one UNet pass serves both predictions.
        emb = torch.cat([torch.cat([self.unconditional_embeddings] * bs), embeddings])

        for ts in tqdm(scheduler.timesteps):
            inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
            with torch.no_grad():
                u, t = self.unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)

            # Classifier-free guidance: push the prediction away from the
            # unconditional one, towards the text-conditioned one.
            pred = u + g * (t - u)
            latents = scheduler.step(pred, ts, latents).prev_sample

        return latents_to_pil(latents, self.vae)


In [2]:
sd = StableDiffusion()

# Encode the two prompts, then append a third row that is the midpoint of the
# two embeddings so the third image is generated from the blend.
prompt_embeddings = sd.encode_text([
    "Cute picture of a dog",
    "Cute picture of a cat"
])
text_embeddings = torch.cat([
    prompt_embeddings,
    ((prompt_embeddings[0] + prompt_embeddings[1]) / 2).unsqueeze(0),
])

imgs = sd.generate(text_embeddings)

# Create the output directory first so saving works on a fresh checkout.
out_dir = Path("img")
out_dir.mkdir(parents=True, exist_ok=True)
for idx, img in enumerate(imgs):
    img.save(out_dir / f"img{idx}.png")

  0%|          | 0/70 [00:00<?, ?it/s]