# Matrix multiplications, homework

> Implementing negative prompts

In [1]:
#| default_exp matmul_hw

In [15]:
#| export
from typing import List

import torch
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline
from slowai.overview import TORCH_DEVICE, StableDiffusion
from tqdm import tqdm

Negative prompts are an extension of the Classifier Free Guidance Module. Recall this is part of the `pred_noise` method of `StableDiffusion`

In [3]:
StableDiffusion.pred_noise?

[0;31mSignature:[0m [0mStableDiffusion[0m[0;34m.[0m[0mpred_noise[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mprompt_embedding[0m[0;34m,[0m [0ml[0m[0;34m,[0m [0mt[0m[0;34m,[0m [0mguidance_scale[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      ~/Code/SlowAI/slowai/overview.py
[0;31mType:[0m      function

Let's define a helper method to load StableDiffusion, as in the "Overview" notebook

In [4]:
#| export
def get_stable_diffusion(cls=StableDiffusion) -> StableDiffusion:
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    # Use a simple noising scheduler for the initial draft
    pipe.scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )
    pipe = pipe.to(TORCH_DEVICE)
    pipe.enable_attention_slicing()
    return cls(
        tokenizer=pipe.tokenizer,
        text_encoder=pipe.text_encoder,
        scheduler=pipe.scheduler,
        unet=pipe.unet,
        vae=pipe.vae,
    )

In [5]:
sd = get_stable_diffusion()

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


`prompt_embedding` is a tensor four-rank tensor of `batch_size x seq_len x channels`, where the batch size is `2` because its the concatenated unconditional prompt and the conditional prompt.

In [6]:
sd.embed_prompt("a photo of a giraffe in paris").shape

torch.Size([2, 77, 768])

We want to add the negative prompt and run this through the denoising unet at the same time. This should make the batch size into `3`.

In [7]:
#| export
def embed_prompt(sd, prompt, max_length):
    prompt_tokens = sd.tokenizer(
        prompt,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return sd.text_encoder(
            prompt_tokens.input_ids.to(TORCH_DEVICE)
        ).last_hidden_state

In [8]:
class StableDiffusionWithNegativePrompt(StableDiffusion):
    def embed_prompt(self, prompt, negative_prompt):
        orig_embedding = super().embed_prompt(prompt)
        _, max_length, _ = orig_embedding.shape
        neg_text_embeddings = embed_prompt(self, negative_prompt, max_length)
        return torch.cat([orig_embedding, neg_text_embeddings])


sd = get_stable_diffusion(StableDiffusionWithNegativePrompt)
embedding = sd.embed_prompt("a photo of a giraffe in paris", "blurry")
embedding.shape

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


Now, we need to pretty much totally rewrite the denoising method to incorporate this negative guidance.

In [10]:
#| export
def pred_noise(sd, prompt_embedding, l, t, guidance_scale_pos, guidance_scale_neg):
    latent_model_input = torch.cat([l] * 3)  # note all 3 latents injected with prompt
    # Scale the initial noise by the variance required by the scheduler
    latent_model_input = sd.scheduler.scale_model_input(latent_model_input, t)
    with torch.no_grad():
        noise_pred = sd.unet(
            latent_model_input, t, encoder_hidden_states=prompt_embedding
        ).sample
    noise_pred_uncond, noise_pred_text_pos, noise_pred_text_neg = noise_pred.chunk(3)
    noise_pred = noise_pred_uncond
    noise_pred += guidance_scale_pos * (noise_pred_text_pos - noise_pred_uncond)
    noise_pred -= guidance_scale_neg * (noise_pred_text_neg - noise_pred_uncond)
    return noise_pred

In [11]:
pred_noise(sd, embedding, sd.init_latents(), 0, 7.5, 2).shape

torch.Size([1, 4, 64, 64])

Finally, we incorporate the negative prompt into the class API.

In [16]:
#| export
class StableDiffusionWithNegativePrompts(StableDiffusion):
    def embed_prompt(self, prompt, negative_prompt):
        orig_embedding = super().embed_prompt(prompt)
        _, max_length, _ = orig_embedding.shape
        neg_text_embeddings = embed_prompt(self, negative_prompt, max_length)
        return torch.cat([orig_embedding, neg_text_embeddings])

        def denoise(
            self,
            prompt_embedding,
            guidance_scale_pos,
            guidance_scale_neg,
            l,  # latents
            t,  # timestep
            i,  # global progress
        ):
            noise_pred = self.pred_noise(
                self, prompt_embedding, l, t, guidance_scale_pos, guidance_scale_neg
            )
            return self.scheduler.step(noise_pred, t, l).prev_sample

    def __call__(
        self,
        prompt,
        negative_prompt,
        guidance_scale=7.5,
        neg_guidance_scale=2,
        n_inference_steps=30,
        as_pil=False,
    ):
        prompt_embedding = self.embed_prompt(prompt, negative_prompt)
        l = self.init_latents()
        self.init_schedule(n_inference_steps)
        # Note that the time steps aren't neccesarily 1, 2, 3, etc
        for i, t in tqdm(enumerate(self.scheduler.timesteps), total=n_inference_steps):
            # workaround for ARM Macs where float64's are not supported
            t = t.to(torch.float32).to(TORCH_DEVICE)
            l = self.denoise(
                prompt_embedding, guidance_scale, neg_guidance_scale, l, t, i
            )
        return decompress(l, vae, as_pil=as_pil)


StableDiffusionWithNegativePrompts.pred_noise = pred_noise

In [17]:
sd = get_stable_diffusion(StableDiffusionWithNegativePrompts)

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [18]:
sd(
    "leonardo da vinci painting of barack obama, renaissance masterpiece",
    "amateur, ugly, disfigured",
    as_pil=True,
)

  0%|                                                  | 0/30 [00:00<?, ?it/s]


TypeError: pred_noise() missing 1 required positional argument: 'guidance_scale_neg'

In [None]:
#| hide
import nbdev

nbdev.nbdev_export()