# Setup

In [None]:
!pip install git+https://github.com/miguelCalado/keras-cv -q

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for keras-cv (setup.py) ... [?25l[?25hdone


In [None]:
from matplotlib import pylab as P
import tensorflow as tf
from PIL import Image
import numpy as np

from keras_cv.models.stable_diffusion import StableDiffusion

You do not have Waymo Open Dataset installed, so KerasCV Waymo metrics are not available.


# Utility functions

In [None]:
def ShowImage(im, ax=None, save_fig=None):
    """Display an image, or a batch of images tiled side by side, without ticks.

    Args:
        im: Image array of shape (H, W) or (H, W, C), or a batch of images of
            shape (N, H, W, C). A batch is concatenated horizontally into a
            single row. Presumably `text_to_image` / `prompt_to_prompt` return
            such (N, H, W, C) batches -- TODO confirm against the fork.
        ax: Optional matplotlib Axes to draw on. When None, a new figure is
            created and its current Axes is used.
        save_fig: Optional output path. When truthy, the figure that contains
            the image is saved there (200 dpi, tight bbox, transparent).
    """
    image = np.asarray(im)
    # `imshow` only accepts 2-D/3-D arrays, so fold a leading batch axis away:
    # (N, H, W, C) -> (H, N * W, C).
    if image.ndim == 4:
        image = np.concatenate(list(image), axis=1)

    if ax is None:
        P.figure()
        ax = P.gca()
    # Draw on the requested Axes instead of whatever pyplot considers current.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)

    if save_fig:
        ax.figure.savefig(
            save_fig, dpi=200, bbox_inches="tight", pad_inches=0.0, transparent=True
        )

In [None]:
def create_attention_weights(prompt, attn_weights, model=None):
    """Create an array of weights to scale the attention maps of each prompt token.

    This is used for manipulating the importance of the prompt tokens,
    increasing or decreasing the importance assigned to each word. Every token
    position starts at weight 1 (unchanged). Each listed word has all of its
    token positions set to the requested weight. A word may span several
    tokens and may occur several times in the prompt.

    Args:
        prompt: The prompt string to tokenize, must be 77 tokens or shorter.
        attn_weights: A list of ``(word, weight)`` tuples naming the words to
            re-weight. An empty list leaves every weight at 1.
        model: Object providing ``tokenize_prompt`` and ``tokenizer.encode``
            (e.g. a ``StableDiffusion`` instance). Defaults to the notebook's
            global ``generator`` for backward compatibility.

    Returns:
        weights: NumPy array of 77 weights, one per prompt token position.

    Raises:
        ValueError: If a word to re-weight does not appear in the tokenized prompt.

    Example:

    ```python
    from keras_cv.models import StableDiffusion

    model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)

    prompt = "a fluffy teddy bear"
    prompt_weights = [("fluffy", -4)]
    attn_weights = create_attention_weights(prompt, prompt_weights, model=model)
    ```
    """
    if model is None:
        model = generator

    # Initialize the weights to 1, i.e. leave every token untouched by default.
    weights = np.ones(77)
    if not attn_weights:
        return weights

    # Flatten to a plain list of token ids. Assumes a single prompt, i.e.
    # `tokenize_prompt` returns shape (77,) or (1, 77) -- TODO confirm.
    tokens = np.asarray(model.tokenize_prompt(prompt)).reshape(-1).tolist()

    for word, weight in attn_weights:
        # `[1:-1]` strips the first and last ids, presumably the start/end
        # markers that `encode` wraps around the word.
        word_tokens = list(model.tokenizer.encode(word)[1:-1])
        span = len(word_tokens)
        # Match the full (possibly multi-token) word at every occurrence.
        positions = [
            start
            for start in range(len(tokens) - span + 1)
            if span and tokens[start : start + span] == word_tokens
        ]
        if not positions:
            raise ValueError(f"Word {word!r} was not found in the prompt {prompt!r}.")
        for start in positions:
            weights[start : start + span] = weight
    return weights

# Prompt to prompt editing

## Generate an image

In [None]:
# Sampling settings shared by the generation cells below.
# On a GPU with little memory, lower BATCH_SIZE to 1.
BATCH_SIZE = 2
NUM_STEPS = 50
UNCONDITIONAL_GUIDANCE_SCALE = 8

# Stable Diffusion 1.x pipeline producing 512x512 images, jit compilation off.
generator = StableDiffusion(img_height=512, img_width=512, jit_compile=False)

# Let's start with a baseline image. The fixed seed is passed again to the
# editing cells so their outputs can be compared with this one.
prompt = "a photo of a chiwawa with sunglasses"
seed = 1235
print("Generating pictures of chiwawas")
img_org = generator.text_to_image(
    prompt=prompt,
    seed=seed,
    batch_size=BATCH_SIZE,
    num_steps=NUM_STEPS,
    unconditional_guidance_scale=UNCONDITIONAL_GUIDANCE_SCALE,
)
ShowImage(img_org)

By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
Generating pictures of chiwawas
Downloading data from https://github.com/openai/CLIP/blob/main/clip/bpe_simple_vocab_16e6.txt.gz?raw=true
Downloading data from https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5
Downloading data from https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5

## Word Swap

In [None]:
tf.keras.backend.clear_session()
# Attention-injection settings (values in [0, 1], presumably fractions of
# NUM_STEPS -- see `prompt_to_prompt`). Reused by the refinement cell below.
self_attn_steps = 0.4
cross_attn_steps = 0.6
# Swap a single word: "sunglasses" -> "goggles". The misspelling "googles" was
# not the intended eyewear swap and may tokenize into several sub-word tokens,
# which a one-word "replace" edit may not handle.
img_edit = generator.prompt_to_prompt(
    prompt="a photo of a chiwawa with sunglasses",
    prompt_edit="a photo of a chiwawa with goggles",
    method="replace",
    self_attn_steps=self_attn_steps,
    cross_attn_steps=cross_attn_steps,
    num_steps=NUM_STEPS,
    unconditional_guidance_scale=UNCONDITIONAL_GUIDANCE_SCALE,
    seed=seed,
    batch_size=BATCH_SIZE,
)
ShowImage(img_edit)

## Prompt refinement

In [None]:
tf.keras.backend.clear_session()
# Refine the prompt by inserting extra words ("heart shaped"), reusing the
# attention-step settings chosen for the word swap.
img_edit = generator.prompt_to_prompt(
    prompt="a photo of a chiwawa with sunglasses",
    prompt_edit="a photo of a chiwawa with heart shaped sunglasses",
    method="refine",
    seed=seed,
    batch_size=BATCH_SIZE,
    num_steps=NUM_STEPS,
    unconditional_guidance_scale=UNCONDITIONAL_GUIDANCE_SCALE,
    self_attn_steps=self_attn_steps,
    cross_attn_steps=cross_attn_steps,
)
ShowImage(img_edit)

## Attention Re-weight

In [None]:
tf.keras.backend.clear_session()
# Baseline teddy bear (a single image) that the re-weighting cell below edits.
prompt = "a fluffy teddy bear"
img_org = generator.text_to_image(
    prompt=prompt,
    num_steps=NUM_STEPS,
    unconditional_guidance_scale=UNCONDITIONAL_GUIDANCE_SCALE,
    seed=seed,
    batch_size=1,
)
# Show the image generated in this cell, not the earlier `img_edit`.
ShowImage(img_org)

In [None]:
# Lower the importance of "fluffy"; every other token keeps weight 1.
prompt_weights = [("fluffy", -5)]
attn_weights = create_attention_weights(prompt, prompt_weights)

# Attention-injection settings (values in [0, 1], presumably fractions of
# NUM_STEPS). The cross-attention value is named `cross_attn_steps` so it is
# not confused with the `attn_edit_weights` keyword, which takes the
# per-token weights computed above.
self_attn_steps = 0.2
cross_attn_steps = 0.6

# Clean up the session to avoid clutter from old models and layers
tf.keras.backend.clear_session()
# Generate Prompt-to-Prompt: same prompt, only the token attention weights change.
img_edit = generator.prompt_to_prompt(
    prompt=prompt,
    prompt_edit=prompt,
    method="reweight",
    self_attn_steps=self_attn_steps,
    cross_attn_steps=cross_attn_steps,
    attn_edit_weights=attn_weights,
    num_steps=NUM_STEPS,
    unconditional_guidance_scale=UNCONDITIONAL_GUIDANCE_SCALE,
    seed=seed,
    batch_size=1,
)
ShowImage(img_edit)

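## Try making the teddy bear more fluffy

Re-run the attention re-weighting above, but give `"fluffy"` a weight larger than 1 (the neutral value) instead of a negative one.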

In [None]:
# your code here