# LoRACLR Inference Notebook

## Imports & Constants

In [None]:
import os
import json
import torch

from PIL import Image

from regionally_controlable_sampling import build_model, prepare_text, sample_image

In [None]:
# Device string handed to build_model() and used for the sampling generator.
DEVICE = "cuda"
# Experiment name; the individual subject names are joined with "+".
CONFIG_FILE = "elsa+moana"
PRETRAINED_MODEL = f"experiments/multi-concepts/{CONFIG_FILE}/combined_model_base"


SUBJECTS = CONFIG_FILE.split("+")
# One token pair per subject, e.g. "<elsa1> <elsa2>" for subject "elsa".
TOKs = [f"<{subject}1> <{subject}2>" for subject in SUBJECTS]

# Pose presets loaded from JSON. The file is decoded explicitly as UTF-8 so the
# result does not depend on the platform's locale encoding. No placeholder value
# is assigned beforehand: if loading fails, POSES stays undefined instead of
# silently being an empty container that later cells would index into.
with open("multi-concept/pose_data/pose.json", encoding="utf-8") as f:
    POSES = json.load(f)

In [None]:
# Seed used for this generation run.
SEED = 6262

# First entry of the loaded pose presets; it supplies both region descriptors
# and the key-pose condition (NOTE(review): "img_dir" is presumably a path to
# the pose image — confirm against pose.json).
POSE = POSES[0]
REGION1, REGION2 = POSE["region1"], POSE["region2"]
KEYPOSE_CONDITION = POSE["img_dir"]
KEYPOSE_ADAPTOR_WEIGHT = 1.0

NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'

# Scene description followed by quality tags. At this point it has no
# "Two people, " prefix, and that prefix-free form is what the region prompts
# below embed.
CONTEXT_PROMPT = "in a forest, standing" + ", 4K, high quality, high resolution, best quality"

REGION1_PROMPT = f"{TOKs[0]}, {CONTEXT_PROMPT}"
REGION2_PROMPT = f"{TOKs[1]}, {CONTEXT_PROMPT}"

# Per-region spec: "<prompt>-*-<negative prompt>-*-<region>", regions joined by "|".
PROMPT_REWRITE = "|".join(
    f"{region_prompt}-*-{NEGATIVE_PROMPT}-*-{region}"
    for region_prompt, region in (
        (REGION1_PROMPT, REGION1),
        (REGION2_PROMPT, REGION2),
    )
)

# Only the global prompt gets the "Two people, " prefix; this must happen after
# the region prompts have been built from the un-prefixed text.
CONTEXT_PROMPT = "Two people, " + CONTEXT_PROMPT

In [None]:
def infer(
    pipe,
    prompt,
    prompt_rewrite,
    negative_prompt="",
    seed=16141,
    keypose_condition=None,
    keypose_adaptor_weight=1.0,
    sketch_condition=None,
    sketch_adaptor_weight=0.0,
    region_sketch_adaptor_weight="",
    region_keypose_adaptor_weight="",
):
    """Sample a single image from ``pipe`` for one prompt / rewrite pair.

    Args:
        pipe: Pipeline object passed straight through to ``sample_image``.
        prompt: Global text prompt.
        prompt_rewrite: Region-level prompt specification; it is handed to
            ``prepare_text`` together with ``prompt`` and the image size.
        negative_prompt: Negative prompt for the sample.
        seed: Seed for the ``torch.Generator`` created on ``DEVICE``.
        keypose_condition: Path to a key-pose image, or ``None``. A path that
            does not exist is treated the same as ``None``.
        keypose_adaptor_weight: Forwarded unchanged to ``sample_image``.
        sketch_condition: Path to a sketch image, or ``None``. A path that
            does not exist is treated the same as ``None``.
        sketch_adaptor_weight: Forwarded unchanged to ``sample_image``.
        region_sketch_adaptor_weight: Forwarded unchanged to ``sample_image``.
        region_keypose_adaptor_weight: Forwarded unchanged to ``sample_image``.

    Returns:
        The first element of the value returned by ``sample_image``.

    Raises:
        AssertionError: If both condition images are loaded and their sizes
            differ.
    """
    # Sketch condition: opened in mode "L" when the file is present.
    sketch_w = sketch_h = 0
    if sketch_condition is None or not os.path.exists(sketch_condition):
        sketch_condition = None
        print("skip sketch condition")
    else:
        sketch_condition = Image.open(sketch_condition).convert("L")
        sketch_w, sketch_h = sketch_condition.size
        print("use sketch condition")

    # Key-pose condition: opened in mode "RGB" when the file is present.
    pose_w = pose_h = 0
    if keypose_condition is None or not os.path.exists(keypose_condition):
        keypose_condition = None
        print("skip pose condition")
    else:
        keypose_condition = Image.open(keypose_condition).convert("RGB")
        pose_w, pose_h = keypose_condition.size
        print("use pose condition")

    # A width of 0 marks "condition not loaded"; only compare real images.
    if sketch_w != 0 and pose_w != 0:
        assert (sketch_w, sketch_h) == (
            pose_w,
            pose_h,
        ), "conditions should be same size"

    # Output size follows whichever condition is present (0 x 0 if neither).
    width = max(pose_w, sketch_w)
    height = max(pose_h, sketch_h)

    # sample_image takes a batch; this function always submits a batch of one.
    input_prompt = [prepare_text(prompt, prompt_rewrite, height, width)]
    print(input_prompt[0][0])

    generator = torch.Generator(DEVICE).manual_seed(seed)
    images = sample_image(
        pipe,
        input_prompt=input_prompt,
        input_neg_prompt=[negative_prompt],
        generator=generator,
        guidance_scale=8.5,
        sketch_adaptor_weight=sketch_adaptor_weight,
        region_sketch_adaptor_weight=region_sketch_adaptor_weight,
        keypose_adaptor_weight=keypose_adaptor_weight,
        region_keypose_adaptor_weight=region_keypose_adaptor_weight,
        sketch_condition=sketch_condition,
        keypose_condition=keypose_condition,
        height=height,
        width=width,
    )

    return images[0]

## Model Build

In [None]:
# Build the sampling pipeline once and reuse it for every infer() call below.
# build_model comes from regionally_controlable_sampling; presumably it loads
# the weights stored under PRETRAINED_MODEL onto DEVICE — TODO confirm.
pipe = build_model(PRETRAINED_MODEL, DEVICE)

In [None]:
# Generate one image from the settings defined in the parameter cell. No sketch
# condition is supplied, so infer() falls back to its sketch defaults.
image = infer(
    pipe,
    prompt=CONTEXT_PROMPT,
    prompt_rewrite=PROMPT_REWRITE,
    negative_prompt=NEGATIVE_PROMPT,
    seed=SEED,
    keypose_condition=KEYPOSE_CONDITION,
    keypose_adaptor_weight=KEYPOSE_ADAPTOR_WEIGHT,
)

# Bare expression so the notebook shows the result as the cell output.
image