In [None]:
import numpy as np
import cv2
import torch
from diffusers import ControlNetModel, DDIMScheduler
from PIL import Image
from pipelines.controlnet_inpainting_pipeline import StableDiffusionControlNetInpaintPipeline
from settings.prompt_enhancer import PromptEnhancer
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation

from .ade import ade_palette, ADE_CLASSES, NEG_CLASSES, INIT_CLASSES
from .segformer_model import SegFormer
import utils.utils as utils

In [None]:
def rounded_rectangle(src, top_left, bottom_right, radius=1, color=255, thickness=1, line_type=cv2.LINE_AA):
    """Draw a rectangle with rounded corners onto ``src`` and return it.

    Adapted from Stack Overflow: https://stackoverflow.com/a/60210706

    Point convention (inherited from that answer, and relied on by callers):
    ``top_left`` is read as ``(x, y)`` while ``bottom_right`` is read as
    ``(y, x)`` -- its first element is the bottom row, its second element is
    the right column.

    Args:
        src: Image array drawn on in place.
        top_left: ``(x, y)`` of the top-left corner.
        bottom_right: ``(y, x)`` of the bottom-right corner (see note above).
        radius: Corner radius as a fraction of half the rectangle height;
            values above 1 are clamped to 1.
        color: Colour passed through to OpenCV.
        thickness: Stroke thickness; a negative value fills the shape.
        line_type: OpenCV line type for the straight edges and arcs.

    Returns:
        The same ``src`` array after drawing.
    """
    # Corner layout:
    #   tl - tr
    #   |     |
    #   bl - br
    tl = top_left
    tr = (bottom_right[1], top_left[1])
    br = (bottom_right[1], bottom_right[0])
    bl = (top_left[0], bottom_right[0])

    half_height = abs(bottom_right[0] - top_left[1]) / 2
    radius = 1 if radius > 1 else radius
    r = int(radius * half_height)

    if thickness < 0:
        # Filled shape: a centre strip plus left/right strips cover the
        # interior; the filled corner arcs drawn below round off the corners.
        fill_rects = (
            ((int(tl[0] + r), int(tl[1])), (int(br[0] - r), int(br[1]))),
            ((tl[0], tl[1] + r), (bl[0] + r, bl[1] - r)),
            ((tr[0] - r, tr[1] + r), (br[0], br[1] - r)),
        )
        for rect_start, rect_end in fill_rects:
            cv2.rectangle(src, rect_start, rect_end, color, thickness)

    # Straight edges, shortened by the corner radius at each end.
    stroke = abs(thickness)
    edges = (
        ((tl[0] + r, tl[1]), (tr[0] - r, tr[1])),  # top
        ((tr[0], tr[1] + r), (br[0], br[1] - r)),  # right
        ((br[0] - r, bl[1]), (bl[0] + r, br[1])),  # bottom
        ((bl[0], bl[1] - r), (tl[0], tl[1] + r)),  # left
    )
    for start, end in edges:
        cv2.line(src, start, end, color, stroke, line_type)

    # Quarter-circle arcs, one per corner: (centre, rotation in degrees).
    arcs = (
        ((tl[0] + r, tl[1] + r), 180.0),
        ((tr[0] - r, tr[1] + r), 270.0),
        ((br[0] - r, br[1] - r), 0.0),
        ((bl[0] + r, bl[1] - r), 90.0),
    )
    for centre, rotation in arcs:
        cv2.ellipse(src, centre, (r, r), rotation, 0, 90, color, thickness, line_type)

    return src

def overlay(image, mask, color, alpha, resize=None):
    """Combine an image and its segmentation mask into a single image.

    Adapted from
    https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay

    Params:
        image: Image array laid out channel-last as (H, W, 3); callers in this
            file pass ``np.array(<RGB PIL image>)``.
        mask: (H, W) array; every non-zero pixel is painted with ``color``.
        color: 3-tuple colour for the mask. It is reversed before use
            (RGB <-> BGR), so e.g. ``(255, 0, 0)`` is written into the three
            channels as ``(0, 0, 255)``.
        alpha: Weight of the coloured overlay in the blend (0 keeps ``image``
            unchanged, 1 shows only the overlay).
        resize: Optional ``(width, height)``. If given, both the image and the
            overlay are resized to it before blending.

    Returns:
        image_combined: The blended image, channel-last.
    """
    color = color[::-1]
    # Replicate the 2-D mask across the channel axis so it lines up with the
    # (H, W, 3) image; masked pixels are then replaced by ``color``.
    colored_mask = np.repeat(np.asarray(mask)[..., None], 3, axis=-1)
    masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
    image_overlay = masked.filled()

    if resize is not None:
        # Both arrays are already channel-last, which is what cv2.resize
        # expects. The previous transpose(1, 2, 0) assumed channel-first
        # input and scrambled the axes.
        image = cv2.resize(image, resize)
        image_overlay = cv2.resize(image_overlay, resize)

    image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
    return image_combined

def rescale_image(image, size=(512, 512), pad=True):
    """Fit a PIL image to ``size``, either by letterboxing or by stretching.

    Args:
        image: PIL image to rescale. It is not modified.
        size: Target ``(width, height)`` of the returned image.
        pad: If True, shrink ``image`` with ``thumbnail`` (aspect ratio kept)
            so it fits inside ``size``, then paste it centred on a blank canvas
            of exactly ``size``. If False, resize ``image`` directly to ``size``.

    Returns:
        ``(rescaled_image, content_size)``. With ``pad=True``,
        ``content_size`` is the ``(width, height)`` of the pasted content,
        which is what ``unpad_image`` needs to crop it back out. With
        ``pad=False`` it is ``size`` itself.
    """
    W, H = size
    if not pad:
        return image.resize(size), size

    # Image.thumbnail works in place, so shrink a copy to leave the caller's
    # image untouched. Use the requested size, not a hard-coded 512x512, so
    # non-default sizes letterbox correctly.
    content = image.copy()
    content.thumbnail(size)
    rW, rH = content.size

    canvas = Image.new(image.mode, size)
    canvas.paste(content, (W // 2 - rW // 2, H // 2 - rH // 2))
    return canvas, (rW, rH)

def unpad_image(padded_img, size):
    """Crop centred content back out of an image letterboxed by ``rescale_image``.

    Args:
        padded_img: PIL image with content pasted at its centre.
        size: ``(width, height)`` of that content, as returned by
            ``rescale_image(..., pad=True)``.

    Returns:
        A new image of exactly ``size`` pixels.
    """
    W, H = padded_img.size
    rW, rH = size
    # Start at the same top-left offset rescale_image pastes at, and extend by
    # the full content size. The old ``centre ± size // 2`` box lost one pixel
    # on any odd width or height.
    x1 = W // 2 - rW // 2
    y1 = H // 2 - rH // 2
    return padded_img.crop((x1, y1, x1 + rW, y1 + rH))


In [None]:
class ControlNetInpaint:
    """ControlNet-guided Stable Diffusion inpainting with automatic masking.

    Bundles a segmentation ControlNet, an inpainting Stable Diffusion pipeline
    and a UperNet semantic segmentor. The segmentor is used both to build the
    ControlNet conditioning image and to derive inpainting masks.
    """

    def __init__(
            self,
            cn_model='lllyasviel/sd-controlnet-seg',
            # cn_model='BertChristiaens/controlnet-seg-room',
            sd_model='runwayml/stable-diffusion-inpainting',
            use_cuda=True
        ) -> None:
        """Select the device and load all models.

        Args:
            cn_model: Hugging Face id (or local path) of the ControlNet weights.
            sd_model: Hugging Face id (or local path) of the inpainting model.
            use_cuda: Use CUDA if it is available; otherwise fall back to CPU.
        """
        self.use_cuda = use_cuda
        self.device = 'cuda' if use_cuda and torch.cuda.is_available() else 'cpu'
        self.load_models(sd_model, cn_model)

    def load_models(self, sd_model, cn_model):
        """Load the ControlNet, the inpainting pipeline and the segmentor."""
        on_gpu = self.device == 'cuda'
        # Half precision is only used on GPU; many fp16 kernels have
        # historically been unavailable on CPU, so use fp32 there.
        dtype = torch.float16 if on_gpu else torch.float32

        self.cn_model = ControlNetModel.from_pretrained(
                cn_model,
                torch_dtype=dtype
        )
        self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
                sd_model,
                controlnet=self.cn_model,
                torch_dtype=dtype
        )
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)

        # xFormers attention needs CUDA. Check the resolved device rather than
        # the use_cuda flag, so a CUDA request on a CPU-only machine still works.
        if on_gpu:
            self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe.to(self.device)

        # Prompt enhancer is currently disabled:
        # self.prompt_enhancer = PromptEnhancer()

        # Segmentation models (not moved off the default device).
        self.image_processor = AutoImageProcessor.from_pretrained(
            'openmmlab/upernet-convnext-small'
        )
        self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
            'openmmlab/upernet-convnext-small'
        )

    def get_cn_seg_control(self, image, return_mask=False,
                           dilation_kernel=(5, 5),
                           iterations=10,
                           mask_option='erode',
                           classes=INIT_CLASSES,
                           neg_classes=NEG_CLASSES):
        """Segment ``image`` and return either a ControlNet image or a mask.

        Args:
            image: PIL image to segment.
            return_mask: If True, return a binary inpainting mask of the
                requested ``classes`` found in the image. Otherwise return the
                ADE-palette colour segmentation used as ControlNet input.
            dilation_kernel: Shape of the all-ones kernel used to erode/dilate.
            iterations: Number of erode/dilate iterations.
            mask_option: ``'erode'`` shrinks the mask; any other value dilates it.
            classes: ADE class names to include in the mask; empty means all
                ADE classes.
            neg_classes: Class names that are never included in the mask.

        Returns:
            A PIL image the size of ``image``: an 8-bit 0/255 mask when
            ``return_mask`` is True, otherwise an RGB colour-coded segmentation.
        """
        pixel_values = self.image_processor(
                np.array(image),
                return_tensors='pt'
        ).pixel_values

        with torch.no_grad():
            seg = self.image_segmentor(pixel_values)

        # Convert the model output to a per-pixel label map at the input
        # resolution (PIL size is (W, H), hence the reversal).
        seg = self.image_processor.post_process_semantic_segmentation(
            seg, target_sizes=[image.size[::-1]]
        )[0]
        seg_map = seg.cpu().numpy()

        if return_mask:
            if not len(classes):
                classes = ADE_CLASSES

            excluded = set(neg_classes)
            class_labels, class_idxs = [], []
            for class_label in classes:
                if class_label in excluded:
                    continue
                class_idx = ADE_CLASSES.index(class_label)
                if np.any(seg_map == class_idx):
                    class_labels.append(class_label)
                    class_idxs.append(class_idx)
            print(class_labels)

            # 255 wherever one of the detected classes is present, 0 elsewhere.
            mask = np.where(np.isin(seg_map, class_idxs), 255, 0).astype(np.float32)

            # Shrink or grow the mask depending on `mask_option`.
            kernel = np.ones(dilation_kernel, np.uint8)
            if mask_option == 'erode':
                mask = cv2.erode(mask, kernel, iterations=iterations)
            else:
                mask = cv2.dilate(mask, kernel, iterations=iterations)
            return Image.fromarray(mask.astype(np.uint8))

        # Colour-code every label with the ADE palette for the ControlNet.
        color_seg = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(np.array(ade_palette())):
            color_seg[seg_map == label] = color
        return Image.fromarray(color_seg)

    def get_prompts(
            self,
            image,
            room_type=None,
            architecture_style=None,
            override_prompt=None
        ):
        """Build the positive and negative text prompts.

        Args:
            image: Not used; kept for interface compatibility.
            room_type: Forwarded to the prompt builder.
            architecture_style: Forwarded to the prompt builder.
            override_prompt: If truthy, replaces the generated positive prompt.

        Returns:
            ``(prompt, negative_prompt)``.
        """
        # NOTE(review): `vs` is not imported anywhere in the visible source,
        # so this raises NameError unless it is provided elsewhere (presumably
        # a prompt-template module). TODO confirm and add the import.
        prompt, add_prompt, negative_prompt = vs.create_prompts(
                room_type=room_type,
                architecture_style=architecture_style
        )
        prompt = prompt + ', ' + add_prompt

        if override_prompt:
            prompt = override_prompt

        # Prompt enhancement is currently disabled:
        # prompt = random.choice(self.prompt_enhancer(prompt))

        return prompt, negative_prompt

    def run_model(
            self,
            prompt,
            negative_prompt,
            image,
            mask,
            controlnet_image,
            num_inference_step,
            guidance_scale,
            control_strength,
            generator
        ):
        """Run one ControlNet inpainting pass and return the pipeline's images.

        Prompts longer than the tokenizer's ``model_max_length`` are encoded
        in consecutive chunks, and the chunk embeddings are concatenated. This
        works around the text encoder's token limit.
        """
        tokenizer = self.pipe.tokenizer
        input_ids = tokenizer(
                prompt,
                return_tensors='pt'
        ).input_ids.to(self.device)
        # Pad (never truncate) the negative prompt to the positive prompt's
        # token length so both can be chunked with the same boundaries.
        negative_ids = tokenizer(
                negative_prompt,
                truncation=False,
                padding='max_length',
                max_length=input_ids.shape[-1],
                return_tensors='pt'
        ).input_ids.to(self.device)

        max_length = tokenizer.model_max_length
        # These encoder calls are made directly, not inside the pipeline call,
        # so disable autograd explicitly to avoid building a gradient graph.
        with torch.no_grad():
            prompt_chunks, negative_chunks = [], []
            for start in range(0, input_ids.shape[-1], max_length):
                stop = start + max_length
                prompt_chunks.append(self.pipe.text_encoder(input_ids[:, start:stop])[0])
                negative_chunks.append(self.pipe.text_encoder(negative_ids[:, start:stop])[0])

        prompt_embeds = torch.cat(prompt_chunks, dim=1)
        # A negative prompt longer than the positive one leaves extra tokens
        # in its last chunk; trim so both sequences have the same length.
        negative_prompt_embeds = torch.cat(negative_chunks, dim=1)[:, :prompt_embeds.shape[1], :]

        images = self.pipe(
                image=image,
                mask_image=mask,
                control_image=controlnet_image,
                num_images_per_prompt=1,
                num_inference_steps=num_inference_step,
                guidance_scale=guidance_scale,
                generator=generator,
                controlnet_conditioning_scale=control_strength,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds
        ).images
        return images

    def generate_mask(self, image, mask_dilation,
                       mask_option, use_rounded=False):
        """Run only the segmentation-mask step, for mask debugging.

        Wired to the `Generate Mask` button in the Gradio app.

        Returns:
            A one-element list: ``image`` with the mask blended on top.
        """
        if isinstance(image, dict):
            # Presumably the Gradio image-with-mask payload; use its image.
            image = image['image'].convert('RGB')

        mask = self.get_mask(image, mask_dilation, mask_option=mask_option,
                             use_rounded=use_rounded)
        blended = overlay(np.array(image), np.array(mask), (255, 0, 0), 0.5)
        return [Image.fromarray(blended)]

    def get_mask(self, image, mask_dilation,
                  mask_option, use_rounded=False,
                  classes=INIT_CLASSES,
                  padding=10,
                  neg_classes=NEG_CLASSES):
        """Derive an inpainting mask for ``image`` from semantic segmentation.

        Args:
            image: PIL image (converted to RGB if needed).
            mask_dilation: Erode/dilate iteration count.
            mask_option: ``'erode'`` or anything else for dilate.
            use_rounded: If True, also restrict the mask to a rounded
                rectangle inset ``padding`` pixels from the image border.
            classes: ADE class names to mask; empty means all classes.
            padding: Inset in pixels for the rounded rectangle.
            neg_classes: Class names never masked.

        Returns:
            An 8-bit 0/255 PIL mask the size of ``image``.
        """
        W, H = image.size
        if image.mode != 'RGB':
            image = image.convert('RGB')

        smask = self.get_cn_seg_control(image, return_mask=True,
                                        iterations=mask_dilation,
                                        mask_option=mask_option,
                                        classes=classes,
                                        neg_classes=neg_classes)
        smask = np.array(smask)

        if use_rounded:
            # rounded_rectangle reads its bottom-right corner as (row, col).
            frame = rounded_rectangle(np.zeros_like(smask), (0 + padding, 0 + padding),
                                      (H - padding, W - padding), radius=0.5,
                                      color=(255, 255, 255), thickness=-1)
            # Intersect the two binary masks, then restore the 0/255 range.
            smask = (frame // 255) * (smask // 255)
            smask *= 255

        smask = cv2.resize(smask, (W, H), interpolation=cv2.INTER_NEAREST)
        return Image.fromarray(smask)

    def run_single_iteratiation(self,
                                input_image,
                                mask_image,
                                prompt,
                                negative_prompt,
                                control_strength,
                                guidance_scale,
                                num_inference_step,
                                seed
                                ):
        """Run one inpainting pass on ``input_image`` and return its first output.

        The ControlNet conditioning image is the colour segmentation of
        ``input_image`` itself. A ``seed`` of 0 means "draw a random seed".
        """
        rW, rH = input_image.size
        control = self.get_cn_seg_control(image=input_image.copy())
        control = cv2.resize(
                    np.array(control), (rW, rH), interpolation=cv2.INTER_NEAREST)
        control = Image.fromarray(control)

        if seed == 0:
            itr_seed = int(torch.randint(0, 1000000, (1,)))
        else:
            itr_seed = seed
        itr_seed_gen = torch.manual_seed(itr_seed)

        print(f'\tUsing model with seed {itr_seed} '
                f'strength {control_strength} '
                f'and guidance {guidance_scale} '
                f'with prompt: \n\t{prompt}')

        output = self.run_model(
            prompt,
            negative_prompt,
            input_image,
            mask_image,
            control,
            num_inference_step,
            guidance_scale,
            control_strength,
            itr_seed_gen
        )
        return output[0]

    def __call__(
            self,
            image_dict,
            room_type,
            architecture_style=None,
            negative_prompt="",
            num_images_per_prompt=5,
            guidance_scale=12,
            num_inference_step=20,
            strength_min=0.1,
            strength_max=0.5,
            seed=0,
            override_prompt=None,
            upscale=False,
            mask_dilation=1,
            mask_option='dilate',
            use_fixed_strength=False,
            use_rounded=True,
            padding=10
        ):
        """Run the full two-stage generation.

        Stage 1: ``num_images_per_prompt`` chained inpainting passes, where
        each output becomes the next input. The ControlNet strength ramps
        linearly up to ``strength_max`` unless ``use_fixed_strength`` is set,
        in which case every pass uses ``strength_max``.

        Stage 2: the last output is re-segmented. Newly detected objects that
        lie inside the stage-1 mask are inpainted once more. The final image
        is then merged with the original via ``utils.post_process_image``.

        Args:
            image_dict: ``{'image': PIL image, 'mask': PIL image}``. An
                all-zero mask means the mask is derived automatically.
            negative_prompt: NOTE(review): currently ignored; it is replaced
                by the negative prompt returned from ``get_prompts``.
            seed: 0 draws a random base seed. Any other value is used as the
                base seed, and per-pass seeds are drawn within ±10000 of it.
            upscale: Resize the outputs and mask overlay back to the original
                image size (a plain resize, not super-resolution).

        Returns:
            ``(output_images, mask)``. ``output_images`` holds every stage-1
            output followed by the final post-processed image. ``mask`` is a
            one-element list: the stage-1 mask overlaid on the resized input.

        Raises:
            ValueError: If ``num_images_per_prompt`` is less than 1.
        """
        if num_images_per_prompt < 1:
            raise ValueError('num_images_per_prompt must be at least 1')

        org_image = image_dict['image'].copy()
        input_image = image_dict['image'].convert('RGB')
        input_image = Image.fromarray(utils.resize_image(np.array(input_image), 512))
        rW, rH = input_image.size
        mask_image = image_dict['mask'].convert('RGB')
        mask_image = cv2.resize(
                        np.array(mask_image), (rW, rH), interpolation=cv2.INTER_NEAREST)
        mask_image = Image.fromarray(mask_image)

        if not np.any(np.array(mask_image)):
            print('Extracting Mask...')
            mask_image = self.get_mask(input_image, mask_dilation,
                                       mask_option=mask_option,
                                       use_rounded=use_rounded,
                                       padding=padding)

        strength_factor = (strength_max - strength_min) / num_images_per_prompt
        control_strength = strength_max
        output_images = []

        prompt, negative_prompt = self.get_prompts(
                input_image,
                room_type,
                architecture_style=architecture_style,
                override_prompt=override_prompt
        )

        # Range for random seeds; the origin of these bounds is undocumented.
        min_seed = 1800000000
        max_seed = 4200000001
        # max_seed exceeds the int32 range, so request int64 explicitly: the
        # default integer is 32-bit on some platforms (e.g. Windows, NumPy < 2).
        if seed == 0:
            seed_value = int(np.random.randint(min_seed, max_seed, dtype=np.int64))
        else:
            seed_value = int(seed)

        # Stage 1: chained passes with (optionally) increasing control strength.
        for i in range(num_images_per_prompt):
            if not use_fixed_strength:
                control_strength = strength_min + (i + 1) * strength_factor

            itr_seed = int(torch.randint(max(0, seed_value - 10000), seed_value + 10000, (1,)))

            output = self.run_single_iteratiation(input_image,
                                                  mask_image,
                                                  prompt,
                                                  negative_prompt,
                                                  control_strength,
                                                  guidance_scale,
                                                  num_inference_step,
                                                  itr_seed)
            output_images.append(output.copy())
            input_image = output.copy()

        ##################################################################
        # Stage 2: renovate the generated image by extracting object masks
        ##################################################################
        final_seed = int(np.random.randint(min_seed, max_seed, dtype=np.int64))

        input_image = output_images[-1].copy()
        final_mask = self.get_mask(input_image,
                                   20,
                                   mask_option='dilate',
                                   use_rounded=False,
                                   padding=padding,
                                   classes=[],
                                   neg_classes=NEG_CLASSES + INIT_CLASSES + ['light'])

        # Keep only the stage-2 objects lying inside the stage-1 mask. Both
        # masks are binarised (values below 255 become 0).
        stage1_mask = np.array(mask_image.convert('L')) // 255
        final_mask = (np.array(final_mask.convert('L')) // 255) * stage1_mask
        final_mask *= 255
        final_mask = Image.fromarray(final_mask)
        stage1_mask *= 255
        mask_image = Image.fromarray(stage1_mask)

        output = self.run_single_iteratiation(input_image,
                                              final_mask,
                                              prompt,
                                              negative_prompt,
                                              0.4,
                                              15,
                                              num_inference_step,
                                              final_seed)
        output_images.append(output.copy())
        ##################################################################

        # Bring the original image and the stage-1 mask to the output resolution.
        W, H = output_images[0].size[:2]
        input_image = np.array(rescale_image(org_image.convert('RGB'),
                                             size=(W, H), pad=False)[0])
        mask_image = np.array(rescale_image(mask_image.convert('L'),
                                            size=(W, H), pad=False)[0])

        # Merge the input and generated images to reduce the oily look the
        # model introduces.
        output_images[-1] = utils.post_process_image(output_images[-1],
                                                     input_image, mask_image)

        # Overlay the mask on the input image for display.
        mask = [Image.fromarray(overlay(input_image, mask_image,
                                        (255, 0, 0), 0.5))]

        # Not actually upscaling: just resizing back to the original size.
        if upscale:
            output_images = [rescale_image(img, size=org_image.size,
                                           pad=False)[0] for img in output_images]
            mask = [rescale_image(m, size=org_image.size,
                                  pad=False)[0] for m in mask]

        return output_images, mask
