In [None]:
import numpy as np
import tritonclient.http
from diffusers.utils import load_image
from PIL import Image
import cv2


In [None]:
# Which Triton model to call and where the server's HTTP endpoint lives.
model_name: str = "stable_diffusion_cnet_bls"
model_version: str = "1"  # kept as a string, as passed to the tritonclient calls
url: str = "0.0.0.0:8000"  # host:port of the HTTP endpoint
batch_size: int = 1  # leading dimension used when declaring the per-request inputs

In [None]:
# Canny edge map of the input image, replicated into 3 identical channels
# and wrapped as a PIL image.
low_threshold = 100   # lower hysteresis threshold passed to cv2.Canny
high_threshold = 200  # upper hysteresis threshold passed to cv2.Canny

source_pixels = np.array(load_image("./input_image_vermeer.png"))
edge_map = cv2.Canny(source_pixels, low_threshold, high_threshold)

# (H, W) edge map -> (H, W, 3) by stacking the same channel three times.
edges_rgb = np.stack([edge_map, edge_map, edge_map], axis=2)
canny_image = Image.fromarray(edges_rgb)

In [None]:
# model input params
prompt = "A Corgi is flying in the red sky"
negative_prompt = "bad quality"
samples = 1 
scheduler = "PNDMScheduler"
steps = 20
guidance_scale = 7.5
seed = 1024

In [None]:
# Open an HTTP client to Triton and stop early if the target model is not ready.
triton_client = tritonclient.http.InferenceServerClient(url=url, verbose=False)
model_ready = triton_client.is_model_ready(model_name=model_name, model_version=model_version)
assert model_ready, f"model {model_name} not yet ready"

In [None]:
# Input placeholder
prompt_in = tritonclient.http.InferInput(name="PROMPT", shape=(batch_size,), datatype="BYTES")
negative_prompt_in = tritonclient.http.InferInput(name="NEGATIVE_PROMPT", shape=(batch_size,), datatype="BYTES")
pose_image = tritonclient.http.InferInput("POSE_IMAGE", canny_image.shape, "FP32")
scheduler_in = tritonclient.http.InferInput(name="SCHEDULER", shape=(batch_size,), datatype="BYTES")
steps_in = tritonclient.http.InferInput("STEPS", (batch_size, ), "INT32")
guidance_scale_in = tritonclient.http.InferInput("GUIDANCE_SCALE", (batch_size, ), "FP32")
seed_in = tritonclient.http.InferInput("SEED", (batch_size, ), "INT64")

images = tritonclient.http.InferRequestedOutput(name="IMAGES", binary_data=False)

In [None]:
# Fill each input with data of its declared datatype, then run inference.
# Per-request values are repeated batch_size times so every tensor has shape
# (batch_size,), matching how the inputs were declared.
prompt_in.set_data_from_numpy(np.asarray([prompt] * batch_size, dtype=object))
negative_prompt_in.set_data_from_numpy(np.asarray([negative_prompt] * batch_size, dtype=object))
# POSE_IMAGE is declared FP32: send the canny edge image itself (previously the
# integer sample count was sent here, which matched neither dtype nor shape).
# NOTE(review): pixel values stay in 0..255 — confirm whether the server expects 0..1.
pose_image.set_data_from_numpy(np.asarray(canny_image, dtype=np.float32))
scheduler_in.set_data_from_numpy(np.asarray([scheduler] * batch_size, dtype=object))
steps_in.set_data_from_numpy(np.asarray([steps] * batch_size, dtype=np.int32))
guidance_scale_in.set_data_from_numpy(np.asarray([guidance_scale] * batch_size, dtype=np.float32))
seed_in.set_data_from_numpy(np.asarray([seed] * batch_size, dtype=np.int64))

response = triton_client.infer(
    model_name=model_name,
    model_version=model_version,
    inputs=[prompt_in, negative_prompt_in, pose_image, scheduler_in, steps_in, guidance_scale_in, seed_in],
    outputs=[images],
)

In [None]:
images = response.as_numpy(name="IMAGES")  # IMAGES output tensor as a numpy array


In [None]:
# Convert the model's float image output into a list of PIL images.
# `Image` is already imported in the first cell; no re-import needed here.
# New names (instead of reassigning `images`) keep this cell safe to re-run.

# Promote a single (H, W, C) image to a batch of one.
image_batch = images[None, ...] if images.ndim == 3 else images
# NOTE(review): assumes float pixel values nominally in [0, 1] — confirm against the
# server's postprocessing. Clip first: out-of-range floats are not representable in
# uint8 and the cast does not saturate, so they would show up as corrupted pixels.
image_batch_u8 = (np.clip(image_batch, 0.0, 1.0) * 255).round().astype("uint8")
pil_images = [Image.fromarray(frame) for frame in image_batch_u8]

In [None]:
def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` RGB grid.

    Images are placed row-major: image ``i`` lands in row ``i // cols`` and
    column ``i % cols``. The cell size ``(w, h)`` is taken from the first image,
    so all images are assumed to share that size.

    Args:
        imgs: sequence of PIL images; ``len(imgs)`` must equal ``rows * cols``.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``.

    Raises:
        AssertionError: if ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols, (
        f"expected rows * cols = {rows * cols} images, got {len(imgs)}"
    )

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

In [None]:
# Show every returned sample side by side in a single row. Deriving cols from the
# actual output count keeps rows * cols == len(pil_images) without manual edits.
rows = 1
cols = len(pil_images)
image_grid(pil_images, rows, cols)