
Prompt-to-Prompt: TensorFlow Implementation

Open In Colab | Hugging Face Demo

Unofficial implementation of the paper "Prompt-to-Prompt Image Editing with Cross Attention Control".

[teaser image]

Link to the paper | Official PyTorch implementation | Project page

This repository contains the TensorFlow/Keras implementation of the paper "Prompt-to-Prompt Image Editing with Cross Attention Control".

🚀 Quickstart

Current state-of-the-art methods require the user to provide a spatial mask to localize the edit, which discards the original structure and content within the masked region. The paper proposes a novel technique to edit the content generated by large-scale text-conditioned image generation models such as DALL·E 2, Imagen, or Stable Diffusion, by manipulating only the text of the original prompt.

To achieve this, the authors present the Prompt-to-Prompt framework, which comprises two functionalities:

  • Prompt Editing: the key idea is to inject the cross-attention maps of the original prompt during the diffusion process, controlling which pixels attend to which tokens of the prompt text (see the sketch after this list).

  • Attention Re-weighting: amplifies or attenuates the effect of a word in the generated image. This is done by first assigning a weight to each token and then scaling the attention map associated with that token. It's a nice alternative to negative prompting and multi-prompting.
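To make the injection idea concrete, here is a minimal, illustrative sketch of cross-attention injection. It is not the repository's actual API; the function and argument names are ours:

def inject_cross_attention(attn_edit, attn_orig, step, num_steps,
                           start=0.0, end=1.0):
    """Return the original prompt's cross-attention maps while the
    normalized diffusion progress lies in [start, end); otherwise
    keep the maps produced by the edited prompt."""
    progress = step / num_steps
    if start <= progress < end:
        return attn_orig  # preserve the original image's layout
    return attn_edit      # let the edited prompt reshape the output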

βš™οΈ Installation

Install the dependencies using requirements.txt:

pip install -r requirements.txt

Essentially, you need TensorFlow and KerasCV installed.
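If you prefer installing the two core packages directly (requirements.txt pins the exact versions), the following should be enough:

pip install tensorflow keras-cv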

📚 Notebooks

Try it yourself:

🎯 Prompt-to-Prompt Examples

To start using the Prompt-to-Prompt framework, you first need to set up a TensorFlow distribution strategy for running computations across multiple devices (in case you have more than one).

For example, you can check the available hardware with:

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
tpus = tf.config.list_physical_devices("TPU")
print(f"Num GPUs Available: {len(gpus)} | Num TPUs Available: {len(tpus)}")

And adjust the strategy to your needs:

import tensorflow as tf

# For running on multiple GPUs
strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1", ...])
# To get the default strategy
strategy = tf.distribute.get_strategy()
...

Prompt Editing

Once the strategy is set, you can start generating images just like in KerasCV:

# Imports
import tensorflow as tf
from stable_diffusion import StableDiffusion

generator = StableDiffusion(
    strategy=strategy,
    img_height=512,
    img_width=512,
    jit_compile=False,
)

# Generate text-to-image
img = generator.text_to_image(
    prompt="a photo of a chiwawa with sunglasses and a bandana",
    num_steps=50,
    unconditional_guidance_scale=8,
    seed=5681067,
    batch_size=1,
)
# Generate Prompt-to-Prompt
img_edit = generator.text_to_image_ptp(
    prompt="a photo of a chiwawa with sunglasses and a bandana",
    prompt_edit="a photo of a chiwawa with sunglasses and a pirate bandana",
    num_steps=50,
    unconditional_guidance_scale=8,
    cross_attn2_replace_steps_start=0.0,
    cross_attn2_replace_steps_end=1.0,
    cross_attn1_replace_steps_start=0.8,
    cross_attn1_replace_steps_end=1.0,
    seed=5681067,
    batch_size=1,
)

This generates the original and pirate-bandana images shown below. You can play around and change the <bandana> and <sunglasses> attributes, among many others!
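The generators return image batches; assuming these come back as uint8 arrays of shape (batch_size, 512, 512, 3), as in KerasCV, you can save the results with Pillow:

from PIL import Image

# img and img_edit each hold batch_size generated images.
Image.fromarray(img[0]).save("chiwawa_original.png")
Image.fromarray(img_edit[0]).save("chiwawa_pirate_bandana.png")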

[generated images: original and pirate-bandana edits]

Another example of prompt editing where one can control the content of the basket just by replacing a couple of words in the prompt:

img_edit = generator.text_to_image_ptp(
    prompt="a photo of basket with apples",
    prompt_edit="a photo of basket with oranges",
    num_steps=50,
    unconditional_guidance_scale=8,
    cross_attn2_replace_steps_start=0.0,
    cross_attn2_replace_steps_end=1.0,
    cross_attn1_replace_steps_start=0.0,
    cross_attn1_replace_steps_end=1.0,
    seed=1597337,
    batch_size=1,
)

The image below showcases examples where only the word <apples> was replaced with other fruits or animals. Try changing <basket> to other containers (e.g. bowl or nest) and see what happens!

[generated images: basket contents replaced with other fruits and animals]

Attention Re-weighting

To manipulate the relative importance of tokens, we've added a prompt_weights argument to both the text_to_image and text_to_image_ptp methods. You can create an array of weights using our create_prompt_weights method.

For example, if you generated a pizza that doesn't have enough pineapple on it, you can edit the weights of your prompt:

prompt = "a photo of a pizza with pineapple"
prompt_weights = generator.create_prompt_weights(prompt, [('pineapple', 2)])

This creates an array filled with 1's, except at the position of the pineapple token, where the weight is 2.
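Conceptually, each weight scales the cross-attention column assigned to its token. A small NumPy illustration with dummy shapes and values (not the repository's internals):

import numpy as np

tokens = ["a", "photo", "of", "a", "pizza", "with", "pineapple"]
weights = np.ones(len(tokens))
weights[tokens.index("pineapple")] = 2.0  # amplify "pineapple"

# Cross-attention map: one row per spatial location, one column per token.
attn = np.random.rand(64 * 64, len(tokens))
attn_reweighted = attn * weights  # broadcasts the per-token weights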

To generate a pizza with more pineapple (yuck!), you just need to pass the prompt_weights variable to the text_to_image method:

img = generator.text_to_image(
    prompt="a photo of a pizza with pineapple",
    num_steps=50,
    unconditional_guidance_scale=8,
    prompt_weights=prompt_weights,
    seed=1234,
    batch_size=1,
)

[generated images: pizza with the pineapple weight amplified]

Now suppose you want to reduce the amount of blossom on a tree:

prompt = "A photo of a blossom tree"
prompt_weights = generator.create_prompt_weights(prompt, [('blossom', -1)])

img = generator.text_to_image(
    prompt="A photo of a blossom tree",
    num_steps=50,
    unconditional_guidance_scale=8,
    prompt_weights=prompt_weights,
    seed=1407923,
    batch_size=1,
)

Decreasing the weight associated with <blossom> generates the following images.

[generated images: blossom tree with the blossom weight decreased]

Note about the cross-attention parameters

For the prompt editing method, implemented in the text_to_image_ptp function, varying the parameters that control in which phase of the diffusion process the edited cross-attention maps are injected (e.g. cross_attn2_replace_steps_start, cross_attn1_replace_steps_start) may produce different results (image below).

The cross-attention and prompt-weight hyperparameters should be tuned to the user's needs and desired outputs.

[generated images: results for different cross-attention injection parameters]
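The start/end values appear to be fractions of the diffusion process (0.0 is the first step, 1.0 is the last). Assuming that convention, the bandana example above injects the cross_attn1 maps only during the final 20% of the steps:

num_steps = 50
start, end = 0.8, 1.0  # cross_attn1_replace_steps_start / _end

inject_from = int(start * num_steps)  # step 40
inject_to = int(end * num_steps)      # step 50
print(f"Inject cross_attn1 maps during steps [{inject_from}, {inject_to})")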

More info in bloc97/CrossAttentionControl and the paper.

β˜‘οΈ TODO

  • Add tutorials and Google Colabs.
  • Add multi-batch support.
  • Add examples for Stable Diffusion 2.x.

πŸ‘¨β€πŸŽ“ References

🔬 Contributing

Feel free to open an issue or create a Pull Request.

For PRs, after implementing the changes, please run the Makefile targets to format and lint the submitted code:

  • make init: to create a Python environment with all the developer packages (optional).
  • make format: to format the code.
  • make lint: to lint the code.
  • make type_check: to check for type hints.
  • make all: to run all the checks.

📜 License

Licensed under the Apache License 2.0. See LICENSE to read it in full.
