# Setup

In [None]:
!pip install -q git+https://github.com/keras-team/keras-cv.git

In [None]:
import time
import keras_cv
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
from IPython.display import Image as IImage

# Utility Functions

In [None]:
def plot_images(images):
    """Draw every entry of `images` side by side in a single row, axes hidden."""
    total = len(images)
    plt.figure(figsize=(20, 20))
    # Subplot positions are 1-based, hence start=1.
    for position, picture in enumerate(images, start=1):
        plt.subplot(1, total, position)
        plt.imshow(picture)
        plt.axis("off")

def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):
    """Save a sequence of image frames as an endlessly looping animated GIF.

    Args:
        filename: Destination passed to the first frame's `save` method.
        images: Non-empty sequence of frames. The first frame's `save` method
            is called with the remaining frames as `append_images`. The
            sequence itself is never modified.
        frames_per_second: Playback speed; each frame is shown for
            `1000 // frames_per_second` milliseconds.
        rubber_band: If True, the frames `images[2:-1]` are appended again in
            reverse order so the animation plays back towards its start
            before it loops.
    """
    # Work on a copy. The previous `images += ...` extended the caller's list
    # in place, so the list grew every time the function was called on it.
    frames = list(images)
    if rubber_band:
        frames += frames[2:-1][::-1]
    frames[0].save(
        filename,
        save_all=True,
        append_images=frames[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )

# Stable Diffusion model

In [None]:
# Build the KerasCV Stable Diffusion pipeline, configured for 512x512 images.
model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)

Let's try a simple prompt.

In [None]:
# Generate a batch of three images from one short prompt and show them in a row.
images = model.text_to_image(
    "photograph of an astronaut riding a horse",
    batch_size=3,
)
plot_images(images)

Try out a more complex prompt

In [None]:
# Adjacent string literals are joined into one prompt; replace the placeholder
# text with your own detailed description.
detailed_prompt = (
    "your detailed prompt here "
    "your detailed prompt here "
    "your detailed prompt here "
)
images = model.text_to_image(detailed_prompt, batch_size=3)
plot_images(images)

# Improve performance of the model

## Standard model

In [None]:
# Baseline: time one batch of three images with the model as built above,
# before mixed precision or XLA are enabled.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.perf_counter()
plot_images(images)

print(f"Standard model: {(end - start):.2f} seconds")
keras.backend.clear_session()  # Clear session to preserve memory.

## With Mixed precision

In [None]:
# Switch the global Keras dtype policy to "mixed_float16" for the cells below.
keras.mixed_precision.set_global_policy("mixed_float16")

In [None]:
# Rebuild the model now that the mixed-precision policy is active. Keras layers
# take their dtype policy at construction time, so the model created earlier
# would keep running in float32 and the benchmark would show no difference.
model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)

# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)

Now check the performance for yourself after enabling mixed precision.

In [None]:
# your code here

## With XLA compilation
TensorFlow comes with the XLA (Accelerated Linear Algebra) compiler built in. `keras_cv.models.StableDiffusion` supports a `jit_compile` argument out of the box. Setting this argument to `True` enables XLA compilation, resulting in a significant speed-up.

Let's use this below:

In [None]:
# Reset to full precision first so this section measures XLA on its own. The
# "mixed_float16" global policy set in the previous section would otherwise
# still be in effect for the model built here.
keras.backend.clear_session()
keras.mixed_precision.set_global_policy("float32")

model = keras_cv.models.StableDiffusion(jit_compile=True)
# Before we benchmark the model, we run inference once to make sure the TensorFlow
# graph has already been traced.
images = model.text_to_image("An avocado armchair", batch_size=3)
plot_images(images)

Now check the performance for yourself after enabling XLA compilation.

In [None]:
# your code here

## Try them together
Test out your model with both mixed precision and XLA compilation enabled.

In [None]:
# your code here

# A walk through latent space with Stable Diffusion

## Interpolating between text prompts

In [None]:
interpolation_steps = 5
prompt_1 = "A watercolor painting of a Golden Retriever at the beach"
prompt_2 = "A still life DSLR photo of a bowl of fruit"

# Encode both prompts and drop any size-1 dimensions from each encoding.
encoding_1, encoding_2 = (
    tf.squeeze(model.encode_text(prompt)) for prompt in (prompt_1, prompt_2)
)

# Evenly spaced encodings running from the first prompt to the second.
interpolated_encodings = tf.linspace(
    start=encoding_1, stop=encoding_2, num=interpolation_steps
)

# Report the shape of a single prompt encoding.
print(f"Encoding shape: {encoding_1.shape}")

In [None]:
seed = 12345
# Noise is sampled at 1/8 of the 512-pixel image side, with 4 channels.
latent_side = 512 // 8
noise = tf.random.normal((latent_side, latent_side, 4), seed=seed)

# One image per interpolated encoding, all generated from the same noise tensor.
images = model.generate_image(
    interpolated_encodings,
    batch_size=interpolation_steps,
    diffusion_noise=noise,
)

In [None]:
# Convert each generated array to a PIL image, write the animation, then
# display the saved file as the cell output.
gif_path = "doggo-and-fruit-5.gif"
gif_frames = [Image.fromarray(array) for array in images]
export_as_gif(gif_path, gif_frames, frames_per_second=2, rubber_band=True)
IImage(gif_path)

## Try out more fine-grained interpolation, using hundreds of steps

In [None]:
# your code here

## Break Task [optional] : Implement Interpolation for 4 prompts

In [None]:
# your code here