In [None]:
from share import *
import config

import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

# logging improved.


In [None]:
# Usage example for `Predictor` (the class itself is defined in the next cell).
# NOTE(review): `from .predictor import Predictor` is a relative import; it only works when this
# code lives inside a package that contains a `predictor` module. In a notebook kernel there is
# no parent package, so this line presumably raises ImportError — TODO confirm the intended layout
# (e.g. use the class defined in the next cell instead).
# NOTE(review): config_path, weight_path, input_image and prompts are not defined anywhere in this
# notebook; they must be set before this cell can run.
from .predictor import Predictor

predictor = Predictor(config_path, weight_path) # config_path: the model's YAML config; weight_path: weights file passed to load_state_dict
out_gen = predictor.gen(input_image, prompts) # input_image: np.ndarray; prompts: list of str (one generation per prompt)

# Iterate the generator to get one result per prompt: [inverted edge map] followed by the generated samples.
for res in out_gen:
    # placeholder: consume / display `res` here
    ...

In [None]:
class Predictor:
    """Canny-edge-conditioned ControlNet image generator.

    Holds a model built from a YAML config plus a weights file, a Canny edge
    detector, and a DDIM sampler wrapping the model. `process` generates images
    for one prompt; `gen` lazily yields one `process` result per prompt.

    NOTE(review): the model and the conditioning tensors are moved to CUDA, so
    a GPU is required.
    """

    def __init__(self, config_path, weight_path):
        """Build the model, load its weights, move it to CUDA and set up the sampler.

        Args:
            config_path: path (str or path-like) to the model's YAML config,
                passed to `create_model`.
            weight_path: path (str or path-like) to the weights file, passed
                to `load_state_dict`.
        """
        self.apply_canny = CannyDetector()
        model = create_model(str(config_path)).cpu()
        # A torch module's `load_state_dict` returns the missing/unexpected keys,
        # not the module itself, so the device move must be a separate step.
        model.load_state_dict(load_state_dict(str(weight_path), location='cuda'))
        self.model = model.cuda()
        self.ddim_sampler = DDIMSampler(self.model)

    def gen(self, input_image, prompts, **kwargs):
        """Lazily run `process` once per prompt.

        Args:
            input_image: image passed unchanged to `process` for every prompt.
            prompts: iterable of prompt strings; a single str is treated as one prompt.
            **kwargs: forwarded to `process` (e.g. `seed`, `num_samples`).

        Yields:
            The `process` result for each prompt, in order.
        """
        if isinstance(prompts, str):
            # A bare string would otherwise be iterated character by character.
            prompts = [prompts]
        for prompt in prompts:
            yield self.process(input_image, prompt, **kwargs)

    def process(self,
                input_image,
                prompt,
                a_prompt='',
                n_prompt='',
                num_samples=2,
                image_resolution=512,
                ddim_steps=20,
                guess_mode=False,
                strength=1.0,
                low_threshold=100,
                high_threshold=200,
                scale=9.0,
                seed=42,
                eta=0.0):
        """Generate `num_samples` images for one prompt, guided by the Canny edges of `input_image`.

        Args:
            input_image: source image as a numpy array; normalized with `HWC3`
                and resized with `resize_image` before edge detection.
            prompt: text prompt.
            a_prompt: extra text appended to `prompt` after ', '.
            n_prompt: text used for the unconditional (negative) conditioning.
            num_samples: number of images generated in one batch.
            image_resolution: target resolution passed to `resize_image`.
            ddim_steps: number of DDIM sampling steps.
            guess_mode: if True, the unconditional branch gets no control input
                and the control scales decay geometrically (see below).
            strength: multiplier applied to every control scale.
            low_threshold: low threshold passed to the Canny detector.
            high_threshold: high threshold passed to the Canny detector.
            scale: guidance scale passed as `unconditional_guidance_scale`.
            seed: seed passed to `seed_everything`; -1 picks a random seed in [0, 65535].
            eta: DDIM eta passed to the sampler.

        Returns:
            A list whose first element is the inverted edge map (`255 - edges`),
            followed by `num_samples` uint8 images in (H, W, C) layout.
        """
        with torch.no_grad():
            img = resize_image(HWC3(input_image), image_resolution)
            H, W = img.shape[:2]

            detected_map = HWC3(self.apply_canny(img, low_threshold, high_threshold))

            # Edge map -> float tensor scaled to [0, 1], repeated once per sample, channels first.
            control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
            control = torch.stack([control for _ in range(num_samples)], dim=0)
            control = einops.rearrange(control, 'b h w c -> b c h w').clone()

            if seed == -1:
                seed = random.randint(0, 65535)
            seed_everything(seed)

            if config.save_memory:
                self.model.low_vram_shift(is_diffusing=False)

            cond = {"c_concat": [control],
                    "c_crossattn": [self.model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
            un_cond = {"c_concat": None if guess_mode else [control],
                       "c_crossattn": [self.model.get_learned_conditioning([n_prompt] * num_samples)]}
            # Sampler output shape: 4 channels at 1/8 of the image resolution
            # (presumably the latent space of the first-stage model).
            shape = (4, H // 8, W // 8)

            if config.save_memory:
                self.model.low_vram_shift(is_diffusing=True)

            # 13 control scales (presumably one per control injection point). In guess
            # mode they decay as strength * 0.825**(12 - i): the last equals `strength`
            # and the first is ~0.099 * strength. The constant is empirical; note that
            # 0.825**12 ≈ 0.099 < 0.1 < 0.826**12 ≈ 0.101.
            if guess_mode:
                self.model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)]
            else:
                self.model.control_scales = [strength] * 13
            samples, _intermediates = self.ddim_sampler.sample(ddim_steps, num_samples,
                                                               shape, cond, verbose=False, eta=eta,
                                                               unconditional_guidance_scale=scale,
                                                               unconditional_conditioning=un_cond)

            if config.save_memory:
                self.model.low_vram_shift(is_diffusing=False)

            # Decode, move channels last, and map to uint8; the *127.5 + 127.5 step
            # assumes the decoder output is roughly in [-1, 1].
            x_samples = self.model.decode_first_stage(samples)
            x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

            results = [x_samples[i] for i in range(num_samples)]
        return [255 - detected_map] + results