In [1]:
"""
    * Input image
    * Get DINOv2 embedding given the image (main part)
    * Plug it into IP-Adapter parameters --> Boom!
"""
import torch
from PIL import Image
import torchvision.transforms as T
import torch

# Channel-wise statistics passed to T.Normalize below.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load model
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov2.to(device)
# Inference only: make sure any train-time behaviour (dropout / stochastic depth) is off.
dinov2.eval()

# Force three channels up front. Slicing channels *after* the transform is too
# late: Normalize is given 3-channel mean/std and raises on RGBA or grayscale
# tensors before the slice is ever reached.
garment_image = Image.open(
    '../images/garments/Ao_thun_oversize_84RISING_-_Vit_Donald_Disneyshadow1.jpg'
).convert('RGB')

# image preprocessing
transform = T.Compose([
    T.Resize(224),        # shorter side -> 224, aspect ratio preserved
    T.CenterCrop(224),    # square 224x224 crop
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])

# `img` is the batched model input: one image with a leading batch dimension.
img = transform(garment_image).unsqueeze(0)

with torch.inference_mode():
    # Index 0 selects the output for the single image in the batch.
    features = dinov2(img.to(device))[0]

Using cache found in /home/ubuntu/.cache/torch/hub/facebookresearch_dinov2_main


In [None]:
# https://github.com/Markus-Pobitzer/Inpainting-Tutorial/tree/main
class Masker:
    """Build a binary mask for one semantic class using a segmentation model.

    The processor and model are loaded once in ``__init__`` and reused for
    every call to ``get_binary_mask``.
    """

    def __init__(self, model_id='mattmdjaga/segformer_b2_clothes'): # SegFormer
        """Load the image processor and segmentation model for ``model_id``."""
        self.processor = SegformerImageProcessor.from_pretrained(model_id)
        self.model = AutoModelForSemanticSegmentation.from_pretrained(model_id)

    def get_binary_mask(self, image, return_pil=False, label_id=4):
        """Return a mask that is 255 where ``label_id`` is predicted and 0 elsewhere.

        Args:
            image: Input image exposing ``.size`` as ``(width, height)``
                (e.g. a PIL image) and accepted by the processor.
            return_pil: If true, wrap the mask in a PIL image; otherwise
                return the ``uint8`` numpy array of shape ``(height, width)``.
            label_id: Class index to keep. The default of 4 is presumably the
                upper-clothes class of the default checkpoint -- confirm
                against the model's ``id2label`` mapping when changing models.

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` or a PIL image, depending on
            ``return_pil``.
        """
        inputs = self.processor(images=image, return_tensors="pt")
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits.cpu()

        # Bring the per-class score maps to the input resolution.
        # PIL reports (width, height) while interpolate expects (height, width).
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False
        )

        # argmax over dim=1 (the class axis) picks the best-scoring class per
        # pixel; [0] drops the batch axis, leaving a (height, width) label map.
        pred_seg = upsampled_logits.argmax(dim=1)[0]
        binary_mask = ((pred_seg == label_id) * 255).numpy().astype(np.uint8)

        return Image.fromarray(binary_mask) if return_pil else binary_mask

In [None]:
def try_on(original_image: PIL.Image.Image, 
           mask_image: PIL.Image.Image,
           ip_image: PIL.Image.Image,
           prompt: Text,
           control_image: Union[PIL.Image.Image, None] = None,
           **kwargs) -> List[PIL.Image.Image]:
    """Run the module-level ``pipeline`` to inpaint a garment onto a person image.

    Args:
        original_image: Image to inpaint (passed as ``image``).
        mask_image: Mask selecting the region to repaint.
        ip_image: Reference image passed as ``ip_adapter_image``.
        prompt: Text prompt.
        control_image: Conditioning image forwarded as ``control_image``.
        **kwargs: Optional settings. Recognised keys: ``negative_prompt``,
            ``guidance_scale``, ``width``, ``height``, ``seed`` (a
            ``torch.Generator`` or an int), ``num_images``, ``strength``,
            ``num_inference_steps``, ``padding_mask_crop``. Unrecognised keys
            are ignored.

    Returns:
        The ``.images`` attribute of the pipeline output.
    """
    # Helper option name -> pipeline keyword name.
    option_to_pipeline_arg = {
        'negative_prompt': 'negative_prompt',
        'guidance_scale': 'guidance_scale',
        'width': 'width',  # output width
        'height': 'height',  # output height
        'seed': 'generator',
        'num_images': 'num_images_per_prompt',
        'strength': 'strength',
        'num_inference_steps': 'num_inference_steps',
        'padding_mask_crop': 'padding_mask_crop',
    }

    # Forward only the options that were actually supplied. Passing an explicit
    # None would override the pipeline's own defaults (e.g. guidance_scale,
    # strength, num_inference_steps) instead of falling back to them.
    pipeline_args = {
        arg_name: kwargs[option]
        for option, arg_name in option_to_pipeline_arg.items()
        if kwargs.get(option) is not None
    }

    # The option is called 'seed', so accept a plain integer as well as a
    # ready-made torch.Generator.
    seed = pipeline_args.get('generator')
    if isinstance(seed, int):
        pipeline_args['generator'] = torch.Generator().manual_seed(seed)

    return pipeline(
        prompt=prompt,
        image=original_image,
        mask_image=mask_image,
        control_image=control_image,
        ip_adapter_image=ip_image,
        **pipeline_args,
    ).images

In [None]:
# Use the `device` selected in the first cell instead of hard-coding 'cuda',
# so this cell does not crash on CPU-only machines. Half precision is kept on
# GPU; on CPU fall back to float32.
pipeline_dtype = torch.float16 if str(device).startswith('cuda') else torch.float32

# Pose-conditioned ControlNet.
controlnet = ControlNetModel.from_pretrained(
    'lllyasviel/sd-controlnet-openpose',
    torch_dtype=pipeline_dtype
).to(device)

# Inpainting pipeline with the ControlNet attached.
inpainting_base_model = 'runwayml/stable-diffusion-inpainting'
pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    inpainting_base_model, 
    variant='fp16',
    torch_dtype=pipeline_dtype,
    controlnet=controlnet,
    safety_checker=None
).to(device)

# IP-Adapter weights; the scale is set to 1.0.
ip_adapter_model = 'h94/IP-Adapter'
pipeline.load_ip_adapter(ip_adapter_model, subfolder='models', weight_name='ip-adapter_sd15.bin')
pipeline.set_ip_adapter_scale(1.0)

# Segmentation helper used to build the inpainting mask.
masker = Masker()

In [None]:
# Person photo, garment reference for the IP-Adapter, and the mask computed
# from the person photo at its original resolution.
image = load_image('../images/male_model.jpg')
ip_image = load_image('../tmp/img4.jpg')
mask_image = masker.get_binary_mask(image, return_pil=True)

assert isinstance(image, PIL.Image.Image), 'image must be a PIL image'
assert isinstance(ip_image, PIL.Image.Image), 'ip_image must be a PIL image'
assert isinstance(mask_image, PIL.Image.Image), 'mask_image must be a PIL image'

INPUT_IMAGE_SIZE = (512, 768)  # (width, height), as expected by PIL's resize
SQUARE_IMAGE_SIZE = (224, 224)
image = image.resize(INPUT_IMAGE_SIZE)
# Nearest-neighbour keeps the mask strictly binary (0 / 255); a smoothing
# filter would introduce intermediate gray values along the mask edges.
mask_image = mask_image.resize(INPUT_IMAGE_SIZE, resample=PIL.Image.NEAREST)

make_image_grid([ip_image.resize(INPUT_IMAGE_SIZE), image, mask_image], rows=1, cols=3)

In [None]:
# Negative prompt forwarded verbatim to the pipeline through pipe_config.
negative_prompt = '3d, cartoon, anime, sketches, (worst quality, bad quality, cropped:1.4) ((monochrome) ), \
((grayscale) ), (bad-hands-5:1) , (badhandv4) , (easynegative:0.8) , (bad-artist-anime:) , (bad-artist:) , (bad_prompt:) , \
(bad-picture-chill-75v:) , (bad_prompt_version2:) , (bad_quality:) , Asian-Less-Neg bad-hands-5 bad_pictures \
bad artist -neg CyberRealistic_Negative-neg negative_hand-neg ng_deepnegative_v1_75t, verybadimagenegative_v1. 3, lowres, \
low quality, jpeg, artifacts, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, \
poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, \
gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, \
drawing, painting, crayon, sketch, graphite, impressionist, noisy, soft'

# Keyword options consumed by try_on(**pipe_config).
pipe_config = {
    'negative_prompt': negative_prompt,
    'guidance_scale': 9.0,
    'num_images': 3,
    'strength': 1,
    'num_inference_steps': 50,
    # Fixed seed for reproducible sampling. The generator is created on the
    # `device` chosen in the first cell rather than a hard-coded "cuda", so
    # this cell also runs on CPU-only machines.
    'seed': torch.Generator(device=device).manual_seed(1337),
    'width': INPUT_IMAGE_SIZE[0],
    'height': INPUT_IMAGE_SIZE[1],
    'padding_mask_crop': None
}

In [None]:
# Load the OpenPose annotator and run it on the (already resized) person image.
# The result is later passed to try_on as `control_image`.
openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
pose = openpose(image)
# Last expression of the cell: show the input image next to the detector output.
# NOTE(review): assumes `pose` is a PIL image like `image` -- confirm.
make_image_grid([image, pose], rows=1, cols=2)

In [None]:
# Run the try-on: inpaint the masked region of `image`, passing `ip_image` as the
# IP-Adapter reference, `pose` as the control image and the settings from pipe_config.
# The outputs are stored in `images`; this cell does not display them.
images = try_on(image, mask_image, ip_image, prompt="a sks black polo neck t-shirt", control_image=pose, **pipe_config)