In [None]:
from pathlib import Path
import cv2
import pytest
import torch
import matplotlib.pyplot as plt
import numpy as np
from lama_cleaner.helper import load_img

from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler

# Results go to ./result, resolved against the kernel's working directory
# (a location derived from __file__ was tried earlier and is no longer used).
save_dir = Path("result")
save_dir.mkdir(parents=True, exist_ok=True)

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Saving results to: {save_dir}")
print(f"Using device: {device}")


In [None]:
def get_data(
        fx: float = 1,
        fy: float = 1.0,
        img_p = "",
        mask_p = ""
):
    """Load an image and its mask from disk and rescale both by the same factors.

    Args:
        fx: Horizontal scale factor passed to ``cv2.resize``.
        fy: Vertical scale factor passed to ``cv2.resize``.
        img_p: Path of the colour image (anything ``str()`` turns into a path).
        mask_p: Path of the mask; it is read with ``cv2.IMREAD_GRAYSCALE``.

    Returns:
        ``(img, mask)``: the colour-converted, resized image and the resized mask.

    Raises:
        FileNotFoundError: If OpenCV cannot read either file.
    """
    # cv2.imread reports failure by returning None instead of raising, so check
    # both reads up front rather than letting None reach cvtColor/resize.
    img = cv2.imread(str(img_p))
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_p}")
    mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_p}")

    # NOTE(review): imread's default flag should already yield 3-channel BGR, so
    # COLOR_BGR2RGB would state the intent more directly - confirm before changing.
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)

    # Area interpolation for the picture, nearest-neighbour for the mask so that
    # resizing does not introduce new intermediate mask values.
    img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
    mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
    return img, mask


def get_data_2(
        fx: float = 1,
        fy: float = 1.0,
        img_p = "",
        mask_p = "",
):
    """Load an image with OpenCV and its mask through ``load_img``, then rescale both.

    Unlike ``get_data``, the mask file is read as raw bytes and decoded by
    ``load_img``; the second value ``load_img`` returns is discarded.

    Args:
        fx: Horizontal scale factor passed to ``cv2.resize``.
        fy: Vertical scale factor passed to ``cv2.resize``.
        img_p: Path of the colour image (anything ``str()`` turns into a path).
        mask_p: Path of the mask file, opened in binary mode.

    Returns:
        ``(img, mask)``: the colour-converted, resized image and the resized mask.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image, or the mask file
            cannot be opened.
    """
    print(f"Reading image from: {img_p}")
    img = cv2.imread(str(img_p))
    if img is None:
        # cv2.imread reports failure by returning None instead of raising
        raise FileNotFoundError(f"Could not read image: {img_p}")
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)

    # The context manager closes the handle even when decoding raises.
    # NOTE(review): load_img is called without a grayscale option, so the mask
    # may come back with more than one channel - confirm this is intended.
    with open(mask_p, "rb") as mask_file:
        mask, _alpha_channel = load_img(mask_file.read())

    # Binarising the mask (threshold at 127) was tried earlier and is left out.
    img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
    mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
    return img, mask


In [None]:
'''
p2pSteps: 50
p2pImageGuidanceScale: 1.5
p2pGuidanceScale: 7.5
controlnet_conditioning_scale: 0.4
controlnet_method: control_v11p_sd15_canny
'''


def get_config(strategy, **kwargs):
    """Assemble the ``Config`` used for the inpainting runs in this notebook.

    Args:
        strategy: Value stored under ``hd_strategy``.
        **kwargs: Field overrides. Each one replaces the default of the same
            name, or is added as an extra field when no default exists.

    Returns:
        ``Config`` built from the defaults below merged with ``kwargs``.

    Note:
        A lighter variant used earlier had ``ldm_steps=1``,
        ``hd_strategy_crop_margin=32``, ``hd_strategy_crop_trigger_size=200``
        and ``hd_strategy_resize_limit=200``.
    """
    # prompt, negative_prompt and paint_by_example_example_image are not given
    # defaults here; supply them through kwargs when they are needed.
    defaults = {
        # ldm_* / hd_strategy_* / zits_* fields
        "ldm_steps": 25,
        "ldm_sampler": LDMSampler.plms,
        "hd_strategy": strategy,
        "zits_wireframe": True,
        "hd_strategy_crop_margin": 196,
        "hd_strategy_crop_trigger_size": 800,
        "hd_strategy_resize_limit": 2048,
        # croper_* fields (use_croper is off)
        "use_croper": False,
        "croper_x": 144,
        "croper_y": 11,
        "croper_height": 512,
        "croper_width": 512,
        # sd_* fields
        "sd_scale": 1,
        "sd_mask_blur": 5,
        "sd_strength": 0.75,
        "sd_steps": 50,
        "sd_guidance_scale": 7.5,
        "sd_sampler": "uni_pc",
        "sd_seed": -1,
        "sd_match_histograms": False,
        # cv2_* fields
        "cv2_flag": "INPAINT_NS",
        "cv2_radius": 5,
        # paint_by_example_* fields
        "paint_by_example_steps": 50,
        "paint_by_example_guidance_scale": 7.5,
        "paint_by_example_mask_blur": 5,
        "paint_by_example_seed": -1,
        "paint_by_example_match_histograms": False,
        # p2p_* fields
        "p2p_steps": 50,
        "p2p_image_guidance_scale": 1.5,
        "p2p_guidance_scale": 7.5,
        # controlnet_* fields
        "controlnet_conditioning_scale": 0.4,
        "controlnet_method": "control_v11p_sd15_canny",
    }

    # Caller-supplied values win over the defaults.
    return Config(**{**defaults, **kwargs})


In [None]:
def display_mask_image(image_path):
    """Show the image at ``image_path`` as a single-channel grayscale figure.

    Args:
        image_path: Path of the image file (anything ``str()`` turns into a path).

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    # Read once, directly as grayscale; cv2.imread returns None on failure.
    image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    plt.figure(figsize=(20,20))
    # An explicit gray colormap keeps matplotlib from applying its default
    # colormap to the 2-D array, so the mask is drawn in shades of gray.
    plt.imshow(image, cmap="gray")
    plt.axis('off')
    plt.show()

def print_grayscale_image_stats(image_path, fx: float = 1.0, fy: float = 1.0):
    """Print size, shape, dtype and minimum of an image read as grayscale.

    The scale factors are explicit parameters so the function does not depend
    on notebook globals that may not exist yet when it is called.

    Args:
        image_path: Path of the image file (anything ``str()`` turns into a path).
        fx: Horizontal scale factor passed to ``cv2.resize``.
        fy: Vertical scale factor passed to ``cv2.resize``.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.resize(image, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)

    # A grayscale read gives a 2-D array, so shape unpacks to (rows, columns)
    height, width = image.shape
    print(f"w: {width}, h: {height}")

    print(f"Image path: {image_path}")
    print(f"Image shape: {image.shape}")
    print(f"Image dtype: {image.dtype}")
    print(f"Image min: {image.min()}")

def print_color_image_stats(image_path, fx: float = 1.0, fy: float = 1.0):
    """Print size, channel count, shape, dtype and minimum of a colour image.

    The scale factors are explicit parameters so the function does not depend
    on notebook globals that may not exist yet when it is called.

    Args:
        image_path: Path of the image file (anything ``str()`` turns into a path).
        fx: Horizontal scale factor passed to ``cv2.resize``.
        fy: Vertical scale factor passed to ``cv2.resize``.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread reports failure by returning None instead of raising
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
    image = cv2.resize(image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)

    # The shape is unpacked as (rows, columns, channels)
    height, width, channels = image.shape
    print(f"w: {width}, h: {height}, channels: {channels}")

    print(f"Image path: {image_path}")
    print(f"Image shape: {image.shape}")
    print(f"Image dtype: {image.dtype}")
    print(f"Image min: {image.min()}")
    
# Print summary statistics for the two colour inputs used in this notebook.
# Each mask path is defined next to its image; pass it to
# print_grayscale_image_stats(...) when the mask needs inspecting as well.
mask_path, image_path = "../output/dog/0.png", "../images/dog.jpg"
print_color_image_stats(image_path)
print("")

mask_path_2, image_path_2 = "../images/women_fence_mask.png", "../images/women_fence.png"
print_color_image_stats(image_path_2)


In [None]:

# Scale factors for the inputs; 1 leaves both image and mask at their stored size.
fx = 1
fy = 1

# Alternative input pair: img_p="../images/dog.jpg", mask_p="../output/dog/0.png"
img, mask = get_data(fx=fx, fy=fy, img_p="../images/women_fence.png", mask_p="../output/women_fence/0.png")

print(f"Input image shape: {img.shape}")

model = ModelManager(name="lama", device=device)
strategy = HDStrategy.CROP
config = get_config(strategy)
# The output name carries the strategy value with only its first character upper-cased.
# NOTE(review): indexing assumes HDStrategy members behave like str - confirm.
gt_name = f"lama_{strategy[0].upper() + strategy[1:]}_fx_{fx}_result.png"

res = model(img, mask, config)
output_path = str(save_dir / gt_name)
# cv2.imwrite returns False instead of raising when the file cannot be written,
# so check the result before reporting success.
if not cv2.imwrite(
    output_path,
    res,
    [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
):
    raise IOError(f"Could not write result to: {output_path}")
print(f"Saved result to: {output_path}")

In [None]:

# Reload the file written by the previous cell and display it. The channels are
# converted from BGR to RGB before the array is handed to matplotlib.
image = cv2.cvtColor(cv2.imread(output_path), cv2.COLOR_BGR2RGB)
plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis("off")
plt.show()
