# DiffEdit: Diffusion-Based Semantic Image Editing with Mask Guidance

In this notebook, we will implement [DiffEdit](https://arxiv.org/abs/2210.11427), a semantic image-editing method discussed in Lesson 11 of the fastai course.

**One-sentence summary:** Modify an input image in accordance with a textual transformation query, while otherwise leaving the image as close as possible to the original.

![Figure 1 of the DiffEdit paper](./images/diffedit_task.png)

## The problem
Let us properly define the problem and introduce the notation.
Given:
* $x_0$, an input image, e.g. !["A bowl of fruits"](./images/bowl_fruits.png)
* $R$, a reference text (a caption of $x_0$), e.g. "A bowl of fruits"; and
* $Q$, a textual transformation query, e.g. $\text{fruits} \to \text{pears}$

Generate:
* $y_0$, a minimally modified version of $x_0$ that satisfies the transformation query $Q$, e.g. !["A bowl of pears"](./images/bowl_pears.png)


## Method overview
The method consists of three main steps:
1. $M_{\eta}(x_0,R,Q)\to M$: compute the edit mask $M$ by noising $x_0$ with noise ratio $\eta$ and comparing the noise predictions conditioned on $R$ and on $Q$;
2. $E_{r}(x_0,Q = \emptyset) \to x_r$: encode the input image with DDIM up to edit strength $r$, without text conditioning ($Q=\emptyset$);
3. $D_{r,M}(x_r, Q) \to y_0$: decode $x_r$ with edit strength $r$, mask guidance, and conditioning on $Q$.
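Steps 1 and 3 can be sketched without loading Stable Diffusion. The NumPy code below is a minimal sketch under stated assumptions: `eps_fn(x_t, emb)` is a hypothetical stand-in for the text-conditioned UNet noise prediction, `alpha_bar_eta` is the cumulative noise-schedule value at the mask noise ratio $\eta$, and the defaults (10 noise samples, threshold 0.5) follow our reading of the DiffEdit paper. The paper's removal of extreme values before averaging is omitted here.

```python
import numpy as np

def diffedit_mask(x0, eps_fn, ref_emb, query_emb, alpha_bar_eta, n_samples=10, threshold=0.5, seed=0):
    """Hypothetical helper: estimate a DiffEdit-style edit mask M from noise-prediction differences.

    x0: clean latent of shape (C, H, W).
    eps_fn(x_t, emb): stand-in for the UNet's noise prediction conditioned on a text embedding.
    alpha_bar_eta: cumulative alpha of the noise schedule at the mask noise ratio eta.
    """
    rng = np.random.default_rng(seed)
    diff = np.zeros(x0.shape[1:])
    for _ in range(n_samples):
        noise = rng.standard_normal(x0.shape)
        # Forward-diffuse x0 to the mask noise level
        x_t = np.sqrt(alpha_bar_eta) * x0 + np.sqrt(1.0 - alpha_bar_eta) * noise
        # Per-pixel gap between the R- and Q-conditioned predictions, averaged over channels
        diff += np.abs(eps_fn(x_t, ref_emb) - eps_fn(x_t, query_emb)).mean(axis=0)
    diff /= n_samples
    # Rescale to [0, 1], then binarize
    diff = (diff - diff.min()) / (diff.max() - diff.min() + 1e-8)
    return (diff > threshold).astype(x0.dtype)

def mask_guided_latent(y_t, x_t, mask):
    # Mask guidance at one decoding step: the edited latent inside M, the encoded original outside
    return mask * y_t + (1.0 - mask) * x_t
```

In the full pipeline the mask would live at latent resolution (64×64 for a 512×512 image), and `x_t` in `mask_guided_latent` would be the DDIM-encoded latent of the original image at the same timestep, so regions outside $M$ stay tied to $x_0$.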



## Setup

This implementation uses Hugging Face Diffusers.

### Importing Libraries

In [21]:
#|hide

import torch
from torch import autocast
from transformers import CLIPModel, CLIPVisionModel, CLIPProcessor
from transformers import logging
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler, DDIMScheduler
from tqdm.auto import tqdm
from PIL import Image
from matplotlib import pyplot as plt
import numpy
from torchvision import transforms as tfms
from fastdownload import FastDownload

# Suppress some unnecessary warnings when loading the CLIP text model
logging.set_verbosity_error()

from wwf.utils import *
state_versions(['fastai', 'torch', 'diffusers'])


---
This article is also a Jupyter Notebook available to be run from the top down. There
will be code snippets that you can then run in any environment.

Below are the versions of `fastai`, `torch`, and `diffusers` used at the time of writing:
* `fastai` : 2.7.10 
* `torch` : 1.13.0 
* `diffusers` : 0.7.2 
---

In [9]:
#|hide

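# Set devices: autocast gets "cuda" or "cpu", while tensors may also be placed on Apple's "mps" backend
cast_device = "cuda" if torch.cuda.is_available() else "cpu"
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"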
{'cast':cast_device, 'torch':torch_device}

{'cast': 'cpu', 'torch': 'mps'}


### Loading Models

In [19]:
#|hide

# CLIP ViT-L/14: its text tower is the text encoder Stable Diffusion v1.x was trained with
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
# The autoencoder (VAE) that encodes images into latents and decodes latents back into image space
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
# The UNet that predicts the noise in the latents
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# Attention slicing: compute attention in chunks to reduce peak memory (recommended on Apple Silicon / MPS)
if torch_device == "mps":
    slice_size = unet.config.attention_head_dim // 2
    unet.set_attention_slice(slice_size)

# The noise scheduler
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
# DDIM scheduler (deterministic sampler; DiffEdit uses DDIM for encoding and decoding)
ddim = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
tokenizer = processor.tokenizer

# To the GPU we go!
vae = vae.to(torch_device)
text_encoder = model.text_model.to(torch_device)
image_encoder = model.vision_model.to(torch_device)
unet = unet.to(torch_device)
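DiffEdit's encoding step $E_r$ is a deterministic DDIM inversion, which is presumably why a `DDIMScheduler` is loaded next to the LMS one. The NumPy sketch below shows the underlying update without the Diffusers API: it rebuilds `alphas_cumprod` from the same `scaled_linear` betas, uses a hypothetical `eps_fn(x, t)` in place of the UNet, and `ddim_encode`/`ddim_decode` are illustrative helpers, not library functions.

```python
import numpy as np

def scaled_linear_alphas_cumprod(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    # "scaled_linear" betas (as configured for both schedulers above): linear in sqrt-space, then squared
    betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
    return np.cumprod(1.0 - betas)

def ddim_move(x, eps, a_from, a_to):
    # Deterministic DDIM update (eta = 0): estimate the clean sample, then re-noise it to level a_to
    x0_hat = (x - np.sqrt(1.0 - a_from) * eps) / np.sqrt(a_from)
    return np.sqrt(a_to) * x0_hat + np.sqrt(1.0 - a_to) * eps

def ddim_encode(x0, eps_fn, timesteps, alphas_cumprod, r=0.5):
    # Hypothetical helper for E_r: run DDIM "upwards" over the first fraction r of the ascending timesteps
    ts = timesteps[: int(len(timesteps) * r)]
    x, a_prev = x0, 1.0  # alpha_bar = 1 corresponds to the clean latent
    for t in ts:
        a_t = alphas_cumprod[t]
        # The noise predicted at the current latent approximates the noise at the next level
        x = ddim_move(x, eps_fn(x, t), a_prev, a_t)
        a_prev = a_t
    return x, ts

def ddim_decode(x_r, eps_fn, ts, alphas_cumprod):
    # Hypothetical helper for the decoding half: walk the same timesteps back down to the clean level
    x = x_r
    for i in range(len(ts) - 1, -1, -1):
        a_t = alphas_cumprod[ts[i]]
        a_prev = alphas_cumprod[ts[i - 1]] if i > 0 else 1.0
        x = ddim_move(x, eps_fn(x, ts[i]), a_t, a_prev)
    return x
```

When the noise prediction does not depend on `x`, encoding and decoding are exact inverses; with a real UNet the inversion is only approximate. In DiffEdit, decoding would additionally condition the UNet on $Q$ and apply mask guidance at every step.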

### Utility functions

In [43]:
def pil_to_latent(input_im):
    # Single image -> single latent in a batch (so size 1, 4, 64, 64)
    with torch.no_grad():
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
    return 0.18215 * latent.latent_dist.sample()

def latents_to_pil(latents):
    # Batch of latents -> list of PIL images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images


def generate(prompt="", width=512, height=512, steps=30, guidance=7.5, seed=42):
    # Unconditional ("") and conditional prompts, encoded in one batch for classifier-free guidance
    prompts = ["", prompt]
    inputs = processor(text=prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        # Cast the embeddings to the UNet's dtype (fp32 here) to avoid dtype mismatches on CPU/MPS
        t_e = text_encoder(input_ids=inputs.input_ids.to(torch_device))[0].to(next(unet.parameters()).dtype)
    generator = torch.manual_seed(seed)
    # Initial noise in latent space: 4 channels at 1/8 of the image resolution
    latents = torch.randn((1, unet.in_channels, height // 8, width // 8), generator=generator)
    latents = latents.to(torch_device)
    scheduler.set_timesteps(steps)
    # Scale the initial noise to the scheduler's starting noise level
    latents = latents * scheduler.init_noise_sigma
    frames = []
    # Denoising loop
    with autocast(cast_device):
        for t in tqdm(scheduler.timesteps):
            # Duplicate the latents so one forward pass yields both the unconditional and the text-conditioned prediction
            latent_model_input = torch.cat([latents] * 2)
            # Scale the latents (preconditioning)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            # Predict the noise residual; MPS does not support float64 timesteps
            with torch.no_grad():
                ts = t.type(torch.float32) if torch_device == "mps" else t
                noise_pred = unet(latent_model_input, ts, encoder_hidden_states=t_e).sample
            # Classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred, t, latents).prev_sample
            frames.append(latents)
    return latents_to_pil(latents), seed, frames
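`generate` runs the UNet once on a batch of two (empty prompt first, real prompt second) and then combines the two noise predictions with classifier-free guidance. The combination is a one-liner; the NumPy sketch below (hypothetical function names) isolates it so its limiting cases are easy to check: `guidance=0` returns the unconditional prediction, `guidance=1` the text-conditioned one, and larger values extrapolate past it.

```python
import numpy as np

def classifier_free_guidance(noise_uncond, noise_text, guidance=7.5):
    # Move from the unconditional estimate toward (and past) the text-conditioned one
    return noise_uncond + guidance * (noise_text - noise_uncond)

def guided_from_batch(noise_pred, guidance=7.5):
    # noise_pred stacks the unconditional prediction first, mirroring prompts = ["", prompt]
    noise_uncond, noise_text = np.split(noise_pred, 2, axis=0)
    return classifier_free_guidance(noise_uncond, noise_text, guidance)
```

Batching both predictions, as `generate` does with `torch.cat([latents] * 2)`, trades one larger forward pass for two smaller ones.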

In [45]:
image, seed, frames = generate("an astronaut riding a horse")

In [48]:
frames

