In [None]:
# Load the dotenv IPython extension and apply the local .env file to the kernel
# environment. NOTE(review): presumably this supplies the Replicate API token used
# by replicate.run below — confirm the .env contents.
%reload_ext dotenv
%dotenv

In [None]:
import io
import os
from contextlib import ExitStack

import replicate
import requests
from PIL import Image

In [None]:
# Catalogue of garments available for try-on, keyed by a short name.
# For each garment:
#   description -- free-text description sent to the model as garment_des
#   category    -- category sent to the model ("upper_body" / "lower_body")
#   image_path  -- local image of the garment
garments = {
    garment_name: {
        "description": garment_description,
        "category": garment_category,
        "image_path": garment_image,
    }
    for garment_name, garment_description, garment_category, garment_image in (
        ("tshirt", "purple t-shirt", "upper_body", "./garments/tops/tshirt.webp"),
        ("sweater", "oversized pink sweater", "upper_body", "./garments/tops/sweater.jpg"),
        ("jeans", "slim fit washed jeans", "lower_body", "./garments/bottoms/jeans.webp"),
        ("joggers", "pink joggers", "lower_body", "./garments/bottoms/joggers.jpg"),
    )
}

### Generate garment masks (the region of the model each garment category covers)

In [None]:
def generate_mask(category, human_path="./humans/model.jpg", timeout=60):
    """Ask IDM-VTON for the garment-region mask of a category on the model photo.

    Args:
        category: garment category passed to the model (e.g. "upper_body",
            "lower_body").
        human_path: path of the photo of the person. Defaults to the model
            photo used throughout this notebook.
        timeout: seconds to wait when downloading the resulting image.

    Returns:
        PIL.Image.Image: the mask image returned by the model, fully loaded
        into memory (no open file or network handle remains).

    Raises:
        requests.HTTPError: if downloading the result image fails.
    """
    # With mask_only=True the garment does not affect the result, so any
    # catalogued garment image is passed as a placeholder.
    placeholder_garment = next(iter(garments.values()))["image_path"]

    # The `with` block closes both uploads once replicate.run has consumed them.
    with open(placeholder_garment, "rb") as garm_file, open(human_path, "rb") as human_file:
        result_url = replicate.run(
            "cuuupid/idm-vton:c871bb9b046607b680449ecbae55fd8c6d945e0a1948644bf2361b3d021d3ff4",
            input={
                "garm_img": garm_file,
                "human_img": human_file,
                "category": category,
                "mask_only": True,
            },
        )

    # Download the whole body (content decoding applied, unlike `.raw`) and
    # fail loudly on HTTP errors instead of handing an error page to PIL.
    response = requests.get(result_url, timeout=timeout)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    image.load()  # Image.open is lazy; force decoding now
    return image

In [None]:
# Generate a mask for every category put_garment() uses, so both
# ./masks/top.jpg and ./masks/bottom.jpg exist on a fresh kernel.
os.makedirs("./masks", exist_ok=True)
mask_paths = {"upper_body": "./masks/top.jpg", "lower_body": "./masks/bottom.jpg"}
generated_masks = {mask_category: generate_mask(mask_category) for mask_category in mask_paths}
for mask_category, mask_image in generated_masks.items():
    mask_image.save(mask_paths[mask_category])
im = generated_masks["upper_body"]
im

### Put a single garment on the model

In [None]:
def put_garment(garment, human_path="./humans/model.jpg", timeout=60):
    """Dress the model photo in one garment using IDM-VTON.

    Args:
        garment: a catalogue entry (see ``garments``) with "image_path",
            "description" and "category" keys.
        human_path: path of the photo of the person. Defaults to the model
            photo used throughout this notebook.
        timeout: seconds to wait when downloading the resulting image.

    Returns:
        PIL.Image.Image: the try-on result, fully loaded into memory.

    Raises:
        ValueError: if the garment's category has no mask file.
        requests.HTTPError: if downloading the result image fails.
    """
    # Mask files saved earlier in the notebook, one per supported category.
    mask_paths = {"upper_body": "./masks/top.jpg", "lower_body": "./masks/bottom.jpg"}
    category = garment["category"]
    if category not in mask_paths:
        raise ValueError(
            f"No mask available for garment category {category!r}; "
            f"expected one of {sorted(mask_paths)}"
        )

    # ExitStack closes every opened upload, even if a later open() or the
    # model call raises.
    with ExitStack() as stack:
        garm_file = stack.enter_context(open(garment["image_path"], "rb"))
        human_file = stack.enter_context(open(human_path, "rb"))
        mask_file = stack.enter_context(open(mask_paths[category], "rb"))
        result_url = replicate.run(
            "cuuupid/idm-vton:c871bb9b046607b680449ecbae55fd8c6d945e0a1948644bf2361b3d021d3ff4",
            input={
                "garm_img": garm_file,
                "human_img": human_file,
                "mask_img": mask_file,
                "garment_des": garment["description"],
                "category": category,
            },
        )

    response = requests.get(result_url, timeout=timeout)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    image.load()  # Image.open is lazy; force decoding now
    return image

In [None]:
# Dress the model in every catalogued garment; the "Combine garments" cells
# below load all of these files from ./results/single/.
os.makedirs("./results/single", exist_ok=True)
single_results = {}
for garment_name, garment in garments.items():
    single_results[garment_name] = put_garment(garment)
    single_results[garment_name].save(f"./results/single/{garment_name}.jpg")
im = single_results["joggers"]
im

### Combine garments

In [None]:
def _open_and_report(path):
    """Open an image and print its format, size and mode as a quick sanity check."""
    image = Image.open(path)
    print(image.format, image.size, image.mode)
    return image


# refine_mask below reads every image at the human photo's coordinates, so the
# printed sizes are worth eyeballing.
human = _open_and_report('./humans/model.jpg')
mask_bottom = _open_and_report('./masks/bottom.jpg')
mask_top = _open_and_report('./masks/top.jpg')
result_joggers = _open_and_report('./results/single/joggers.jpg')
result_jeans = _open_and_report('./results/single/jeans.jpg')
result_sweater = _open_and_report('./results/single/sweater.jpg')
result_tshirt = _open_and_report('./results/single/tshirt.jpg')

In [None]:
def refine_mask(human, garment_on_human, vton_mask, threshold=1000, mask_threshold=127):
    """Mask the pixels the try-on actually changed inside the garment region.

    A pixel is 255 when both of these hold, otherwise 0:
      * the squared RGB distance between ``human`` and ``garment_on_human``
        at that pixel exceeds ``threshold``;
      * ``vton_mask`` (converted to grayscale) is brighter than
        ``mask_threshold`` there, i.e. the pixel lies inside the garment region.

    Args:
        human: original photo of the person.
        garment_on_human: try-on result for the same photo; must match
            ``human.size``.
        vton_mask: garment-region mask; must match ``human.size``.
        threshold: minimum squared RGB distance that counts as "changed".
        mask_threshold: grayscale level above which a ``vton_mask`` pixel is
            inside the region. The default tolerates JPEG compression noise
            around black. Pass 0 to include every non-black pixel.

    Returns:
        PIL.Image.Image: an "L"-mode mask the size of ``human``.

    Raises:
        ValueError: if the three images do not all have the same size.
    """
    if garment_on_human.size != human.size or vton_mask.size != human.size:
        raise ValueError(
            f"Image sizes differ: human={human.size}, "
            f"garment_on_human={garment_on_human.size}, vton_mask={vton_mask.size}"
        )

    # Normalise modes so photo pixels are always (r, g, b) tuples and mask
    # pixels single ints. Without this, an RGB mask yields tuples that never
    # compare equal to 0, and non-RGB photos break the 3-way unpacking.
    human_pixels = human.convert("RGB").getdata()
    garment_pixels = garment_on_human.convert("RGB").getdata()
    region_pixels = vton_mask.convert("L").getdata()

    # One bulk pass with getdata/putdata (both row-major) instead of per-pixel
    # getpixel/putpixel calls over two full-image loops.
    values = [
        255
        if region > mask_threshold
        and (r2 - r1) ** 2 + (g2 - g1) ** 2 + (b2 - b1) ** 2 > threshold
        else 0
        for (r1, g1, b1), (r2, g2, b2), region in zip(human_pixels, garment_pixels, region_pixels)
    ]
    mask = Image.new("L", human.size)
    mask.putdata(values)
    return mask

In [None]:
def combine_garments(human, result_top, result_bottom, mask_top, mask_bottom):
    """Compose one outfit from a top try-on result and a bottom try-on result.

    The bottom result is the base image. Only the pixels that
    ``refine_mask`` marks as changed by the top try-on (within ``mask_top``)
    are copied over it from ``result_top``.

    ``mask_bottom`` is accepted but not used.

    Returns:
        PIL.Image.Image: an RGB image the size of ``human``.
    """
    top_region = refine_mask(human, result_top, mask_top)

    canvas = Image.new("RGB", human.size)
    canvas.paste(result_bottom, (0, 0))
    canvas.paste(result_top, (0, 0), top_region)
    return canvas

In [None]:
# Overlay the sweater try-on onto the jeans try-on and save the combined outfit.
im = combine_garments(
    human,
    result_top=result_sweater,
    result_bottom=result_jeans,
    mask_top=mask_top,
    mask_bottom=mask_bottom,
)
im.save("./results/multi/sweater_jeans.jpg")
im