In [None]:
!pip install -q --upgrade accelerate transformers ftfy
!pip install -q git+https://github.com/huggingface/diffusers
# !pip install -q diffusers


In [None]:
!pip install grpcio==1.56.0 grpcio-tools==1.33.2

In [None]:
!pip list | grep -e grpcio -e protobuf

In [None]:
# ModelMesh Serving gRPC inference endpoint.
grpc_host, grpc_port = 'modelmesh-serving', 8033

# Model names registered on the serving runtime, one per Stable Diffusion
# component that is executed remotely instead of locally.
textencoder_model_name = 'textencoder'
unet_model_name = 'unet'
vaeencoder_model_name, vaedecoder_model_name = 'vaeencoder', 'vaedecoder'

In [None]:
import sys
sys.path.append('./serving')

import grpc
import serving.grpc_predict_v2_pb2 as grpc_predict_v2_pb2
import serving.grpc_predict_v2_pb2_grpc as grpc_predict_v2_pb2_grpc

# Plaintext (no TLS) channel to the inference endpoint; `stub` is the single
# client shared by all of the *_grpc_request helpers.
channel = grpc.insecure_channel("{}:{}".format(grpc_host, grpc_port))
stub = grpc_predict_v2_pb2_grpc.GRPCInferenceServiceStub(channel)


In [None]:
from PIL import Image
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler

# replace the VAE decoder with gRPC requests
# vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae", use_safetensors=True)
def vae_decoder_grpc_request(latent_sample):
    """Decode latents into an image tensor with the remote VAE decoder.

    Replaces the local ``vae.decode(latents).sample`` with a ``ModelInfer``
    call to the ``vaedecoder_model_name`` model through ``stub``.

    Args:
        latent_sample: 4-D latent tensor or array, (batch, 4, h, w). The
            request shape is taken from it instead of being hard-coded.

    Returns:
        float32 torch.Tensor of shape (batch, 3, 8 * h, 8 * w).
    """
    # Convert once to a contiguous float32 numpy array so the proto field is
    # filled from plain Python floats, not by iterating over 0-d torch tensors.
    if isinstance(latent_sample, torch.Tensor):
        latent_sample = latent_sample.detach().cpu().numpy()
    latent_arr = np.ascontiguousarray(latent_sample, dtype=np.float32)
    batch, _, latent_h, latent_w = latent_arr.shape

    infer_input = grpc_predict_v2_pb2.ModelInferRequest().InferInputTensor()
    infer_input.name = "latent_sample"
    infer_input.datatype = "FP32"
    infer_input.shape.extend(latent_arr.shape)
    infer_input.contents.fp32_contents.extend(latent_arr.ravel().tolist())

    request = grpc_predict_v2_pb2.ModelInferRequest()
    request.model_name = vaedecoder_model_name
    request.inputs.extend([infer_input])

    response = stub.ModelInfer(request)
    out_sample = np.frombuffer(response.raw_output_contents[0], dtype=np.float32)

    # The decoder upsamples the latent grid 8x per spatial dim
    # (64x64 latents -> 512x512 pixels in the original hard-coded shapes).
    return torch.tensor(out_sample.reshape([batch, 3, latent_h * 8, latent_w * 8]))


# The CLIP tokenizer still runs locally; only model forward passes are remote.
tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="tokenizer",
)


# replace text encoder with grpc requests
# text_encoder = CLIPTextModel.from_pretrained(
#     "runwayml/stable-diffusion-v1-5", subfolder="text_encoder", use_safetensors=False
# )

def textencoder_grpc_request(input_arr):
    """Run the remote CLIP text encoder on a batch of token ids.

    Replaces the local ``text_encoder(input_ids)[0]`` with a ``ModelInfer``
    call to the ``textencoder_model_name`` model through ``stub``.

    Args:
        input_arr: token ids as a numpy array or torch tensor, shaped
            (batch, seq_len) or (seq_len,). A 1-D input is sent as a batch of one.

    Returns:
        float32 torch.Tensor of shape (batch, seq_len, 768).
    """
    if isinstance(input_arr, torch.Tensor):
        input_arr = input_arr.detach().cpu().numpy()
    # The server input is declared INT32. Convert once and declare the real
    # (batch, seq_len) shape; a fixed [1, 77] is wrong for batch_size > 1.
    ids = np.ascontiguousarray(input_arr, dtype=np.int32)
    ids = ids.reshape(-1, ids.shape[-1])

    infer_input = grpc_predict_v2_pb2.ModelInferRequest().InferInputTensor()
    infer_input.name = "input_ids"
    infer_input.datatype = "INT32"
    infer_input.shape.extend(ids.shape)
    infer_input.contents.int_contents.extend(ids.ravel().tolist())

    request = grpc_predict_v2_pb2.ModelInferRequest()
    request.model_name = textencoder_model_name
    request.inputs.extend([infer_input])

    response = stub.ModelInfer(request)
    text_embeddings = np.frombuffer(response.raw_output_contents[0], dtype=np.float32)

    return torch.tensor(
        text_embeddings.reshape([-1, ids.shape[1], 768]), dtype=torch.float32
    )


# unet = UNet2DConditionModel.from_pretrained(
#     "runwayml/stable-diffusion-v1-5", subfolder="unet", use_safetensors=False
# )

def unet_grpc_request(encoder_hidden_states, timestep, sample):
    """Predict the noise residual with the remote UNet.

    Replaces the local
    ``unet(sample, t, encoder_hidden_states=encoder_hidden_states).sample``
    with a ``ModelInfer`` call to the ``unet_model_name`` model through ``stub``.

    Args:
        encoder_hidden_states: text embeddings, (N, seq_len, 768) float.
        timestep: N timesteps, one per batch row, in any shape (e.g. (N,) or
            (N, 1)). Sent as INT64 with shape (N, 1).
        sample: latent model input, (N, 4, h, w) float.

    Returns:
        float32 torch.Tensor shaped (-1, *sample.shape[1:]).
    """
    def _as_array(value, dtype):
        # Detach/move torch tensors and convert to one contiguous numpy array
        # so each proto field is filled from Python scalars in a single pass
        # instead of iterating element-by-element over 0-d tensors.
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        return np.ascontiguousarray(value, dtype=dtype)

    hidden_arr = _as_array(encoder_hidden_states, np.float32)
    timestep_arr = _as_array(timestep, np.int64).reshape(-1, 1)
    sample_arr = _as_array(sample, np.float32)

    # (tensor name, datatype, typed contents field, data), in the order the
    # server has always received them.
    input_specs = [
        ("encoder_hidden_states", "FP32", "fp32_contents", hidden_arr),
        ("timestep", "INT64", "int64_contents", timestep_arr),
        ("sample", "FP32", "fp32_contents", sample_arr),
    ]

    inputs = []
    for name, datatype, contents_field, arr in input_specs:
        infer_input = grpc_predict_v2_pb2.ModelInferRequest().InferInputTensor()
        infer_input.name = name
        infer_input.datatype = datatype
        # Declare the actual shape instead of a hard-coded batch of 2 / 64x64.
        infer_input.shape.extend(arr.shape)
        getattr(infer_input.contents, contents_field).extend(arr.ravel().tolist())
        inputs.append(infer_input)

    request = grpc_predict_v2_pb2.ModelInferRequest()
    request.model_name = unet_model_name
    request.inputs.extend(inputs)

    response = stub.ModelInfer(request)
    out_sample = np.frombuffer(response.raw_output_contents[0], dtype=np.float32)

    # The predicted noise has the same per-row shape as the latent input.
    return torch.tensor(
        out_sample.reshape([-1, *sample_arr.shape[1:]]), dtype=torch.float32
    )

In [None]:
from diffusers import DDIMScheduler

# DDIM scheduler config published alongside the cfchase/stable-diffusion-rhteddy checkpoint.
scheduler = DDIMScheduler.from_pretrained(
    "cfchase/stable-diffusion-rhteddy",
    subfolder="scheduler",
)

In [None]:
# just use cpu and offload gpu requests to grpc server
# Local work (tokenization, scheduling, guidance arithmetic) stays on the CPU;
# the model forward passes are offloaded to the gRPC server.
torch_device = "cpu"

In [None]:
# Generation settings.
prompt = ["a photo of a rhteddy dog on the beach"]
batch_size = len(prompt)

height, width = 512, 512   # default output resolution of Stable Diffusion
num_inference_steps = 50   # number of denoising steps
guidance_scale = 7.5       # scale for classifier-free guidance

# Seeded generator used to draw the initial latent noise.
generator = torch.manual_seed(10)

In [None]:
import numpy as np

# Tokenize the prompt(s), padding and truncating to the tokenizer's max length.
text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    # Device-placed ids as the local text_encoder would have consumed them;
    # the gRPC request below is fed the numpy ids directly.
    text_encoder_args = text_input.input_ids.to(torch_device)
    # Text-encoder forward pass on the gRPC server.
    text_embeddings = textencoder_grpc_request(text_input.input_ids.numpy())


In [None]:
# Empty-prompt ("unconditional") ids, padded to the same length as the prompt
# ids so both embedding batches can be concatenated later.
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
    [""] * batch_size,
    padding="max_length",
    max_length=max_length,
    return_tensors="pt",
)

# Text-encoder forward pass on the gRPC server.
uncond_embeddings = textencoder_grpc_request(uncond_input.input_ids.numpy())

In [None]:
# Stack [unconditional; conditional] along the batch dim. The order matters:
# the denoising loop splits the UNet output with .chunk(2) as (uncond, text).
# Note: this rebinds text_embeddings, so re-running this cell alone grows it again.
text_embeddings = torch.cat((uncond_embeddings, text_embeddings), dim=0)

In [None]:
# Initial Gaussian noise in latent space: 4 channels (the value the local
# unet.in_channels would have supplied; hard-coded because the UNet is
# remote) at 1/8 of the image resolution.
latents = torch.randn(
    (batch_size, 4, height // 8, width // 8),
    generator=generator,
).to(torch_device)

In [None]:
from tqdm.auto import tqdm
import torch

# Denoising loop. It updates `latents` in place, so re-run the latent
# initialisation cell before re-running this one.
scheduler.set_timesteps(num_inference_steps)

for t in tqdm(scheduler.timesteps):
    # Duplicate the latents so the unconditional and text-conditioned
    # predictions come from a single UNet request (classifier-free guidance).
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

    # Send one timestep per batch row, shaped (N, 1), instead of a bare `t`
    # (which hits a batch error on the Triton gRPC server). N follows the
    # actual batch, so batch_size > 1 also works.
    timesteps = torch.full(
        (latent_model_input.shape[0], 1), int(t), dtype=torch.int64
    )

    # Predict the noise residual on the gRPC server (replaces the local
    # `unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample`).
    noise_pred = unet_grpc_request(text_embeddings, timesteps, latent_model_input)

    # Classifier-free guidance: move away from the unconditional prediction.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # Compute the previous noisy sample x_t -> x_{t-1}.
    latents = scheduler.step(noise_pred, t, latents).prev_sample

In [None]:
import numpy as np

# VAE latent scaling factor (what diffusers exposes as vae.config.scaling_factor);
# latents are divided by it before decoding.
vae_scaling_factor = 0.18215

# Scale into a new tensor instead of overwriting `latents`, so re-running this
# cell does not rescale the latents a second time.
scaled_latents = (1 / vae_scaling_factor) * latents

# Decode on the gRPC server (replaces the local `vae.decode(latents).sample`).
image = vae_decoder_grpc_request(scaled_latents)

In [None]:
# Map the decoder output (assumed roughly in [-1, 1]) to [0, 1].
pixels = (image / 2 + 0.5).clamp(0, 1)
# NCHW float -> NHWC uint8, rounding rather than truncating when quantizing.
# (The old `(image * 255)` on an already-uint8 array overflowed; it is removed.)
pixels = (pixels.permute(0, 2, 3, 1) * 255).round().to(torch.uint8).cpu().numpy()
# One PIL image per prompt; display the first.
images = [Image.fromarray(frame) for frame in pixels]
image = images[0]
image