## Dependencies

In [None]:
import json
import requests
import numpy as np
import matplotlib.pyplot as plt
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer

## Auxiliary functions

In [None]:
def predict_rest(json_data, url, timeout=None):
    """POST a JSON payload to a REST ``:predict`` endpoint and return its predictions.

    Args:
        json_data: Request body, already serialized to a JSON string.
        url: Full URL of the model's ``:predict`` endpoint.
        timeout: Optional ``requests`` timeout in seconds. ``None`` (the
            default) waits indefinitely, matching the previous behavior.

    Returns:
        ``np.ndarray`` built from the ``"predictions"`` field of the response.

    Raises:
        requests.HTTPError: If the server answers with a non-2xx status; the
            response body (which usually carries the server's error message)
            is included in the exception text.
        RuntimeError: If the JSON response has no ``"predictions"`` key.
    """
    response = requests.post(url, data=json_data, timeout=timeout)
    # Fail loudly on HTTP errors instead of later hitting an opaque KeyError;
    # keep the body because it normally explains what went wrong.
    if not response.ok:
        raise requests.HTTPError(
            f"{response.status_code} error from {url}: {response.text}",
            response=response,
        )
    payload = response.json()
    if "predictions" not in payload:
        raise RuntimeError(f"Unexpected response from {url}: {payload}")
    return np.array(payload["predictions"])


def plot_images(images):
    """Draw ``images`` side by side in one row of a new 20x20-inch figure.

    Axes are hidden for every subplot.

    Args:
        images: Sequence of images, each accepted by ``plt.imshow``.
    """
    plt.figure(figsize=(20, 20))
    n_images = len(images)
    for position, image in enumerate(images, start=1):
        # subplot indices are 1-based
        plt.subplot(1, n_images, position)
        plt.imshow(image)
        plt.axis("off")

## Parameters

In [None]:
# Prompts are right-padded with PADDING_TOKEN up to MAX_PROMPT_LENGTH token ids.
MAX_PROMPT_LENGTH = 77
PADDING_TOKEN = 49407

batch_size = 1
num_steps = 1  # not sent to any endpoint in the visible cells

# CLIP tokenizer used to encode prompts. "<token>" is registered as an extra
# token (presumably a placeholder concept token — confirm its purpose).
tokenizer = SimpleTokenizer()
tokenizer.add_tokens("<token>")

# REST ":predict" endpoints, one per served sub-model.
text_encoder_url = "http://localhost:8501/v1/models/text_encoder:predict"
diffusion_model_url = "http://localhost:8501/v1/models/diffusion_model:predict"
decoder_url = "http://localhost:8501/v1/models/decoder:predict"

## Inference

### Text encoder

In [None]:
text = "An image of a squirrel in Picasso style"
tokens = tokenizer.encode(text)
# The request is padded to a fixed MAX_PROMPT_LENGTH ids (presumably what the
# served text encoder expects); an over-long prompt would get no padding and be
# sent over-length, so reject it up front.
if len(tokens) > MAX_PROMPT_LENGTH:
    raise ValueError(
        f"Prompt is too long ({len(tokens)} tokens; "
        f"should be <= {MAX_PROMPT_LENGTH})"
    )
# Right-pad with PADDING_TOKEN up to the fixed prompt length.
tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens))

data = [{
    "tokens": tokens,
    "batch_size": batch_size,
}]
json_data = json.dumps({"signature_name": "serving_default", "instances": data})

encoded_text = predict_rest(json_data, text_encoder_url)
print(f"REST output shape: {encoded_text.shape}")
# Each prediction is a dict of named outputs; list them, then show the shapes
# of the two embeddings used by the diffusion model.
print(encoded_text[0].keys())
print(np.array(encoded_text[0]["context"]).shape)
print(np.array(encoded_text[0]["unconditional_context"]).shape)

### Diffusion model

In [None]:
# Forward both embeddings from the text encoder's first prediction to the
# diffusion model. num_steps / batch_size are deliberately not sent; presumably
# the served signature fixes them — confirm before adding them to the request.
text_prediction = encoded_text[0]
data = [
    {
        "context": text_prediction["context"],
        "unconditional_context": text_prediction["unconditional_context"],
    }
]
request_body = {"signature_name": "serving_default", "instances": data}
json_data = json.dumps(request_body)

latents = predict_rest(json_data, diffusion_model_url)
print(f"REST output shape: {latents.shape}")

### Decoder

In [None]:
# Send the first latent returned by the diffusion model to the decoder endpoint.
first_latent = latents[0]
data = [{"latent": first_latent.tolist()}]
request_body = {"signature_name": "serving_default", "instances": data}
json_data = json.dumps(request_body)

decoded_images = predict_rest(json_data, decoder_url)
print(f"REST output shape: {decoded_images.shape}")

## Generated images

In [None]:
# Plot the images returned by the decoder endpoint in the previous cell.
# Assumes decoded_images is indexable as (num_images, height, width, channels)
# — TODO confirm against the decoder's output signature.
plot_images(decoded_images)