In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import cv2
import torch
import gradio as gr
import time
import random
import numpy as np
from PIL import Image

from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
from controlnet_aux import OpenposeDetector

class ControleNet_generator:
    """Wrap a Stable Diffusion + ControlNet pipeline for single-image generation.

    The pipeline is loaded in float32, switched to a UniPC multistep scheduler
    and moved to CUDA when the object is constructed.
    """

    def __init__(
        self,
        controlnet_model_dir,
        pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base",
        progress_bar_disable=True,
    ) -> None:
        """Load the ControlNet weights and the base pipeline onto the GPU.

        Args:
            controlnet_model_dir: Path or model id of the ControlNet weights;
                passed through ``str`` so ``pathlib.Path`` objects work too.
            pretrained_model_name_or_path: Path or model id of the base
                Stable Diffusion checkpoint.
            progress_bar_disable: If True, hide the per-step progress bar.

        Note:
            This sets torch's *process-global* default tensor type to
            ``torch.cuda.FloatTensor``, which affects every tensor created
            afterwards, not only those belonging to this pipeline.
        """
        self.controlnet = ControlNetModel.from_pretrained(
            str(controlnet_model_dir), torch_dtype=torch.float32
        )

        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            controlnet=self.controlnet,
            safety_checker=None,
            revision=None,
            torch_dtype=torch.float32,
        )

        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe = self.pipe.to("cuda")
        # Alternative to the .to("cuda") above: let diffusers offload
        # sub-models to CPU between uses via self.pipe.enable_model_cpu_offload().

        self.pipe.set_progress_bar_config(disable=progress_bar_disable)
        # NOTE(review): torch.set_default_tensor_type is deprecated in recent
        # torch releases (set_default_dtype / set_default_device replace it).
        # Kept unchanged because code running after construction may rely on
        # new tensors defaulting to CUDA.
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    def generate(
        self,
        source_image: np.ndarray,
        positive_prompt: str = "",
        negative_prompt: str = "low quality, worst quality, bad fingers, out of focus, bad face, extra arms, extra legs, blurry, bokeh, ugly",
        seed: int = -1,
        num_inference_steps: int = 20,
        device: str = "cuda",
    ):
        """Generate one image conditioned on ``source_image``.

        Args:
            source_image: Conditioning image, converted with
                ``PIL.Image.fromarray`` (presumably an HxWx3 uint8 array —
                confirm against callers).
            positive_prompt: Text describing the desired output.
            negative_prompt: Text describing what to avoid.
            seed: RNG seed. Any negative value (the default, -1) draws a random
                seed; 0 and positive values are used as-is so results are
                reproducible.
            num_inference_steps: Number of denoising steps.
            device: Device on which the ``torch.Generator`` is created.

        Returns:
            The first image of the pipeline output, as a numpy array.
        """
        # Only negative seeds mean "random": 0 is a perfectly valid seed and
        # must not be silently replaced.
        if seed < 0:
            seed = random.randint(1, 10000000)
        generator = torch.Generator(device=device).manual_seed(seed)

        # Run the denoising loop under CUDA autocast even though the weights
        # were loaded as float32.
        with torch.autocast("cuda"):
            output = self.pipe(
                prompt=positive_prompt,
                image=Image.fromarray(source_image),
                negative_prompt=negative_prompt,
                generator=generator,
                num_inference_steps=num_inference_steps,
            )

        return np.array(output.images[0])

# Pose detector, loaded first: constructing ControleNet_generator below changes
# torch's global default tensor type, so keep this ordering as-is.
openpose_model = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")

# OpenPose-conditioned ControlNet on top of the SD v1.5 base checkpoint.
controlnet_model_dir = "fusing/stable-diffusion-v1-5-controlnet-openpose"
controlnet_model = ControleNet_generator(
    controlnet_model_dir,
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    progress_bar_disable=False,
)

# Working resolution (width, height) for every image passed through the demo.
img_size = (512, 512)

def get_pose_img(img):
    """Return the OpenPose rendering of ``img`` as an array of size ``img_size``.

    The input is resized to ``img_size`` before detection, and the detector's
    output is resized again afterwards (presumably because the detector may
    return a different resolution).
    """
    resized_input = cv2.resize(img, img_size)
    skeleton = np.array(openpose_model(resized_input))
    return cv2.resize(skeleton, img_size)

def get_generated_img(pose_img, gen_prompt):
    """Generate an image from a pose rendering and a text prompt.

    ``pose_img`` is resized to ``img_size`` first; generation uses 10 denoising
    steps and the generator's default negative prompt and random seed.
    """
    conditioning = cv2.resize(pose_img, img_size)
    return controlnet_model.generate(
        conditioning,
        positive_prompt=gen_prompt,
        num_inference_steps=10,
    )

def process(img, gen_prompt="best quality"):
    """Gradio callback: extract the pose from ``img`` and generate a new image.

    Args:
        img: Input image array from the Gradio ``Image`` component, or None
            when the image slot is empty.
        gen_prompt: Text prompt for the generator.

    Returns:
        A tuple ``(strip, resized_input)``. ``strip`` is the pose rendering and
        the generated image concatenated horizontally. ``resized_input`` is
        ``img`` resized to ``img_size``.

    Raises:
        gr.Error: If no input image was supplied.
    """
    if img is None:
        # Gradio passes None for an empty image slot (every example row has no
        # image). Fail with a readable message instead of an opaque OpenCV error.
        raise gr.Error("Please provide an input image.")
    img = cv2.resize(img, img_size)
    pose_img = get_pose_img(img)
    gen_img = get_generated_img(pose_img, gen_prompt)
    return cv2.hconcat([pose_img, gen_img]), img


# Example rows: (input image, prompt). The image cells are None, so the user
# has to supply their own picture.
examples = [
    [None, "super-hero character, powerful, masterpiece, best quality, extremely detailed"],
    [None, "an astronaut on the moon, digital art"],
    [None, "Dancing Darth Vader, best quality, extremely detailed"],
]

demo = gr.Interface(
    process,
    [
        # For live webcam input (gradio 3.x API) use instead:
        # gr.Image(source="webcam", streaming=True)
        gr.Image(),
        # gr.Textbox(value=...) replaces the deprecated
        # gr.inputs.Textbox(default=...), which gradio 4 removed.
        gr.Textbox(
            lines=2,
            label="Prompt",
            value="super-hero character, powerful, masterpiece, best quality, extremely detailed",
        ),
    ],
    # Outputs: the (pose | generated) strip, then the resized input image.
    ["image", "image"],
    examples=examples,
)
demo.launch(server_name="172.16.15.127", server_port=9999, debug=True)
# For a temporary public URL instead: demo.launch(share=True)
    