# MaskDINO Analysis

In [None]:
import argparse
import multiprocessing as mp
import os
import warnings
warnings.filterwarnings("ignore")

# fmt: off
import sys
# Resolve the parent of the notebook's working directory (the repository root)
# and put it on sys.path, presumably so the repository's `maskdino` package
# imported below can be found.
home_dir = os.path.abspath(os.getcwd()+"/../")
sys.path.insert(1, home_dir)
print(home_dir)
# Dataset root for detectron2. The path is relative, so it resolves against the
# current working directory. NOTE(review): presumably read when detectron2 registers
# its builtin datasets, which is why it is set before the detectron2 imports — confirm.
os.environ["DETECTRON2_DATASETS"] = "../datasets"

import tqdm
import torch
import numpy as np
import matplotlib as mpl
import matplotlib.colors as mplc
import matplotlib.figure as mplfigure
import colorsys
from pprint import pprint

import detectron2.data.transforms as T
from detectron2.modeling import build_model
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data import detection_utils as utils

from detectron2.utils.visualizer import VisImage, _create_text_labels, GenericMask
from detectron2.structures import ImageList, BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes
from detectron2.utils.colormap import random_color

from detectron2.data import (
    MetadataCatalog,
    build_detection_test_loader,
    build_detection_train_loader,
)

from maskdino.utils import box_ops
from maskdino import add_maskdino_config
from maskdino import COCOInstanceNewBaselineDatasetMapper

In [None]:
def setup_cfg(args):
    """Build the frozen detectron2 config used throughout this notebook.

    Steps, in order:
    1. Start from detectron2's defaults.
    2. Register the extra DeepLab and MaskDINO config keys.
    3. Merge the YAML file at ``args.config_file``.
    4. Apply the ``KEY VALUE`` overrides in ``args.opts``.

    Because the overrides are applied last, they win over the YAML.

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    # Extension keys must be registered before merging so the YAML's extra keys are accepted.
    for register_keys in (add_deeplab_config, add_maskdino_config):
        register_keys(config)
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config

def get_parser():
    """Create the command-line parser for the MaskDINO demo.

    Options:
        --config-file  YAML config path (defaults to the R50 COCO
                       instance-segmentation config).
        --input        one or more image paths, or a single glob pattern.
        --opts         every remaining token, read as ``KEY VALUE`` config overrides.
        --output       destination file/directory for visualizations.

    Returns:
        argparse.ArgumentParser: the configured parser (nothing is parsed yet).
    """
    default_config = "../configs/coco/instance-segmentation/maskdino_R50_bs16_50ep_3s.yaml"
    input_help = (
        "A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'"
    )
    opts_help = "Modify config options using the command-line 'KEY VALUE' pairs"
    output_help = (
        "A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window."
    )

    parser = argparse.ArgumentParser(description="maskdino demo for builtin configs")
    parser.add_argument("--config-file", default=default_config, metavar="FILE", help="path to config file")
    parser.add_argument("--input", nargs="+", help=input_help)
    parser.add_argument("--opts", help=opts_help, default=[], nargs=argparse.REMAINDER)
    parser.add_argument("--output", help=output_help)
    return parser

In [None]:
# Use the "spawn" start method for worker processes (force=True so re-running the cell does not fail).
mp.set_start_method("spawn", force=True)
# Parse an explicit, empty argument list rather than sys.argv: in a notebook
# kernel, sys.argv holds the kernel's own flags. The fields are filled in by hand below.
args = get_parser().parse_args([])

args.input = ["../images/fruit.jpg"]
args.opts = ['MODEL.WEIGHTS', '../ckpts/maskdino_r50_50ep_300q_hid1024_3sd1_instance_maskenhanced_mask46.1ap_box51.5ap.pth']
args.output = os.path.join(home_dir, "outputs")
cfg = setup_cfg(args)

In [None]:
# Dump the fully merged config for inspection.
pprint(cfg)

In [None]:
# Display the resolved repository root as the cell output.
home_dir

In [None]:
# Mapper that turns raw dataset records into model inputs for the loader below.
# NOTE(review): the positional `True` is presumably `is_train` (training-time
# augmentation) — confirm against the mapper's signature.
mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True)

In [None]:
# Training loader built from the config, using the mapper above. Each item it
# yields is a batch: a list of per-image dicts with "image" and "instances".
data_loader = build_detection_train_loader(cfg, mapper=mapper)

### Load Model

In [None]:
# Build MaskDINO from the config and load the checkpoint weights.
# The model is deliberately left in training mode: the decoder-analysis cells
# below pass ground-truth `targets` to the decoder, and the Inference cell
# switches to model.eval() later.
model = build_model(cfg)
model.train()

# Dataset metadata (e.g. "thing_classes") used for labels in the Visualization cell.
# When no test dataset is configured, fall back to an empty placeholder entry so
# `metadata` is always defined instead of raising NameError later.
metadata = MetadataCatalog.get(
    cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
)

checkpointer = DetectionCheckpointer(model)
checkpointer.load(cfg.MODEL.WEIGHTS)

# Test-time resize: shortest edge -> MIN_SIZE_TEST, longest edge capped at MAX_SIZE_TEST.
aug = T.ResizeShortestEdge([cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST)
# Channel order the model expects ("RGB" or "BGR"); used by the preprocessing cells.
input_format = cfg.INPUT.FORMAT

### Backbone (ResNet)
- input shape: [B, 3, H, W]
- output shapes (one feature map per level; the cells below print these 0-indexed, as level:0–3):
    - level 1: [B, 256, H/4, W/4]
    - level 2: [B, 512, H/8, W/8]
    - level 3: [B, 1024, H/16, W/16]
    - level 4: [B, 2048, H/32, W/32]

### Pixel Decoder

In [None]:
def prepare_targets(targets, images):
    """Convert per-image ground truth into the target dicts passed to the MaskDINO decoder.

    Args:
        targets: per-image ground-truth instances exposing ``image_size``,
            ``gt_classes``, ``gt_masks`` and ``gt_boxes``.
            - ``gt_masks`` is a tensor of shape (N, h, w).
            - ``gt_boxes`` carries the xyxy boxes in its ``.tensor``.
        images: the padded ``ImageList`` for the same batch. The last two dims
            of its tensor give the padded (H, W).

    Returns:
        list[dict]: one dict per image with
            - ``"labels"``: ``gt_classes`` unchanged,
            - ``"masks"``: ``gt_masks`` zero-padded to the padded (H, W),
            - ``"boxes"``: boxes converted to (cx, cy, w, h) and divided by the
              unpadded image size (w, h, w, h).
    """
    h_pad, w_pad = images.tensor.shape[-2:]
    new_targets = []
    for targets_per_image in targets:
        h, w = targets_per_image.image_size
        gt_boxes = targets_per_image.gt_boxes.tensor
        # Build the normalizer on the boxes' own device instead of reading the
        # notebook-global `model`, so this helper does not depend on it.
        image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=gt_boxes.device)

        # Zero-pad the masks to the batch's padded size so they line up with the padded image.
        gt_masks = targets_per_image.gt_masks
        padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
        padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks

        new_targets.append(
            {
                "labels": targets_per_image.gt_classes,
                "masks": padded_masks,
                "boxes": box_ops.box_xyxy_to_cxcywh(gt_boxes) / image_size_xyxy,
            }
        )
    return new_targets

## Training Loss

In [None]:
# for batched_inputs in data_loader:
#     loss_dict = model(batched_inputs)
#     print(loss_dict)

In [None]:
# Pull a single batch from the training loader for the stage-by-stage walkthrough below.
data_loader_iter_obj = iter(data_loader)
data = next(data_loader_iter_obj)

### Transformer Decoder

In [None]:
# Push one training batch through the network stage by stage, printing the tensor
# shapes produced by the backbone, the pixel decoder and the transformer decoder.
# The model is still in training mode here, so ground-truth targets are passed to the decoder.
with torch.no_grad():
    # Normalize with the model's pixel mean/std, then pad into one batch tensor
    # whose size is divisible by model.size_divisibility.
    images = [x["image"].to(model.device) for x in data]
    images = [(x - model.pixel_mean) / model.pixel_std for x in images]
    images = ImageList.from_tensors(images, model.size_divisibility)
    
    print("|Input Shape|")
    print(f"  {images.tensor.shape}")
    features = model.backbone(images.tensor)

    # `features` is indexed by the keys it yields; enumerate adds a 0-based level number.
    print("|Backbone Output|")
    for lvl, f in enumerate(features):
        print(f"  level:{lvl}, {features[f].shape}")

    # Ground truth for the same batch, converted to the decoder's dict format.
    gt_instances = [x["instances"].to(model.device) for x in data]
    targets = prepare_targets(gt_instances, images)

    print("|Pixel Decoder Output|")
    # NOTE(review): the second argument (masks) is None — presumably "no padding masks"; confirm.
    mask_features, transformer_encoder_features, multi_scale_features = model.sem_seg_head.pixel_decoder.forward_features(features, None)
    print(f"  mask_features: {mask_features.shape}")
    print(f"  transformer_encoder_features: {transformer_encoder_features.shape}")
    print(f"  multi_scale_features")
    for lvl, f in enumerate(multi_scale_features):
        print(f"    level:{lvl}, {f.shape}")

    print("|Transformer Decoder Output|")
    # Returns an output dict plus `mask_dict`, which is printed in full.
    # NOTE(review): mask_dict presumably holds denoising bookkeeping and can be large.
    outputs, mask_dict = model.sem_seg_head.predictor(multi_scale_features, mask_features, None, targets=targets)
    print(outputs.keys(), mask_dict)

In [None]:
# Same stage-by-stage walkthrough as above, but on the first demo image
# instead of a training batch.
img = read_image(args.input[0], format="BGR")
with torch.no_grad():
    # read_image returned BGR; flip the channels only if the model expects RGB.
    # (Before, `original_image` was left undefined for BGR models.)
    original_image = img[:, :, ::-1] if input_format == "RGB" else img
    height, width = original_image.shape[:2]
    image = aug.get_transform(original_image).apply_image(original_image)
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
    # Tensor.to() is not in-place: reassign, otherwise the call does nothing.
    image = image.to(cfg.MODEL.DEVICE)

    batched_inputs = {"image": image, "height": height, "width": width}
    batched_inputs = [batched_inputs]
    images = [x["image"].to(model.device) for x in batched_inputs]
    images = [(x - model.pixel_mean) / model.pixel_std for x in images]
    images = ImageList.from_tensors(images, model.size_divisibility)

    print("|Input Shape|")
    print(f"  {images.tensor.shape}")
    features = model.backbone(images.tensor)

    print("|Backbone Output|")
    for lvl, f in enumerate(features):
        print(f"  level:{lvl}, {features[f].shape}")

    print("|Pixel Decoder Output|")
    mask_features, transformer_encoder_features, multi_scale_features = model.sem_seg_head.pixel_decoder.forward_features(features, None)
    print(f"  mask_features: {mask_features.shape}")
    print(f"  transformer_encoder_features: {transformer_encoder_features.shape}")
    print(f"  multi_scale_features")
    for lvl, f in enumerate(multi_scale_features):
        print(f"    level:{lvl}, {f.shape}")

    print("|Transformer Decoder Output|")
    # NOTE(review): `targets` comes from the training batch above, not from this
    # image. It is presumably passed only because the model is still in training
    # mode (model.eval() is first called in the Inference cell) — confirm.
    predictions = model.sem_seg_head.predictor(multi_scale_features, mask_features, None, targets=targets)

### Inference

In [None]:
# Run end-to-end inference on every input image. `predictions` and `img` keep the
# last image's values for the Visualization cell.
model.eval()
for path in tqdm.tqdm(args.input):
    img = read_image(path, format="BGR")

    print(img.shape)
    with torch.no_grad():
        # read_image returned BGR; flip the channels only if the model expects RGB.
        # (Before, `original_image` was left undefined for BGR models.)
        original_image = img[:, :, ::-1] if input_format == "RGB" else img
        height, width = original_image.shape[:2]
        image = aug.get_transform(original_image).apply_image(original_image)
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
        # Tensor.to() is not in-place: reassign, otherwise the call does nothing.
        image = image.to(cfg.MODEL.DEVICE)

        # height/width are the pre-resize image size (presumably used by the
        # model to map outputs back to the original resolution).
        inputs = {"image": image, "height": height, "width": width}

        print(image.shape, height, width)
        predictions = model([inputs])[0]

        # NOTE: computed but not yet used — nothing is written to args.output.
        out_filename = os.path.join(args.output, os.path.basename(path))

### Visualization

In [None]:
# Prepare the last inference result for drawing.
# read_image(..., format="BGR") returned BGR; flip to RGB for display.
img = img[:, :, ::-1]
img = np.asarray(img).clip(0, 255).astype(np.uint8)
# Canvas that the draw_* helpers below add patches and text to.
output = VisImage(img, scale=1.0)

cpu_device = torch.device("cpu")
instances = predictions["instances"].to(cpu_device)

# Pull out whichever prediction fields are present (None when absent).
boxes = instances.pred_boxes if instances.has("pred_boxes") else None
scores = instances.scores if instances.has("scores") else None
classes = instances.pred_classes.tolist() if instances.has("pred_classes") else None
# NOTE(review): `metadata` comes from the Load Model cell; make sure it is defined there.
labels = _create_text_labels(classes, scores, metadata.get("thing_classes", None))

# Wrap each predicted mask in a GenericMask sized to the canvas (provides .polygons, .area(), .bbox()).
masks = np.asarray(instances.pred_masks)
masks = [GenericMask(x, output.height, output.width) for x in masks]

In [None]:
# Drawing defaults shared by the helpers and drawing loop below.
# NOTE(review): `colors` is not used in the visible cells.
colors = None
# Fill opacity for the mask polygons.
alpha = 0.5
# Base font size grows with image size: sqrt(area) // 90, but never below
# 10 // 1.0 (= 10.0). The 1.0 matches the VisImage scale used above.
default_font_size = max(
    np.sqrt(output.height * output.width) // 90, 10 // 1.0
)
# Objects whose box area is below this (times output.scale) get their label moved
# to the side of the box to avoid covering them.
_SMALL_OBJECT_AREA_THRESH = 1000

def convert_boxes(boxes):
    """Return ``boxes`` as an ``(N, B)`` NumPy array (B = 4 axis-aligned, 5 rotated).

    detectron2 ``Boxes`` / ``RotatedBoxes`` are unwrapped through their detached
    ``.tensor``. Anything else is handed to ``np.asarray`` unchanged.
    """
    if not isinstance(boxes, (Boxes, RotatedBoxes)):
        return np.asarray(boxes)
    return boxes.tensor.detach().numpy()

def convert_masks(masks_or_polygons):
    """Normalize masks or polygons into a list of ``GenericMask``.

    Accepts ``PolygonMasks``, ``BitMasks``, a torch tensor, or any iterable of
    per-instance masks/polygons. Items that are already ``GenericMask`` pass
    through untouched. New ones are sized to the notebook-global ``output`` canvas.

    Returns:
        list[GenericMask]
    """
    source = masks_or_polygons
    # Unwrap the container types into something iterable per instance.
    if isinstance(source, PolygonMasks):
        source = source.polygons
    elif isinstance(source, BitMasks):
        source = source.tensor.numpy()
    elif isinstance(source, torch.Tensor):
        source = source.numpy()
    return [
        item if isinstance(item, GenericMask) else GenericMask(item, output.height, output.width)
        for item in source
    ]

def draw_box(output, box_coord, alpha=0.5, edge_color="g", line_style="-"):
    """Outline one axis-aligned box on ``output``.

    Args:
        output (VisImage): canvas to draw on; it is modified and returned.
        box_coord (tuple): (x0, y0, x1, y1), the top-left and bottom-right corners.
        alpha (float): opacity of the outline; smaller is more transparent.
        edge_color: any matplotlib color spec for the outline.
        line_style (str): matplotlib line style for the outline.

    Returns:
        VisImage: ``output`` with the rectangle added.
    """
    x0, y0, x1, y1 = box_coord
    # Stroke width follows the notebook-global default font size (at least 1),
    # scaled by the canvas scale.
    stroke = max(default_font_size / 4, 1) * output.scale
    outline = mpl.patches.Rectangle(
        (x0, y0),
        x1 - x0,
        y1 - y0,
        fill=False,
        edgecolor=edge_color,
        linewidth=stroke,
        alpha=alpha,
        linestyle=line_style,
    )
    output.ax.add_patch(outline)
    return output

def draw_polygon(output, segment, color, edge_color=None, alpha=0.5):
    """Fill one polygon on ``output``.

    Args:
        output (VisImage): canvas to draw on; it is modified and returned.
        segment: (N, 2) array of polygon vertices.
        color: fill color (any matplotlib color spec). The fill uses ``alpha``.
        edge_color: outline color, drawn fully opaque. When omitted, the default
            depends on ``alpha``:
            - above 0.8: a darkened version of ``color``;
            - otherwise: ``color`` itself.
        alpha (float): fill opacity; smaller is more transparent.

    Returns:
        VisImage: ``output`` with the polygon added.
    """
    if edge_color is None:
        edge_color = (
            change_color_brightness(color, brightness_factor=-0.7) if alpha > 0.8 else color
        )
    edge_rgba = mplc.to_rgb(edge_color) + (1,)
    face_rgba = mplc.to_rgb(color) + (alpha,)

    patch = mpl.patches.Polygon(
        segment,
        fill=True,
        facecolor=face_rgba,
        edgecolor=edge_rgba,
        linewidth=max(default_font_size // 15 * output.scale, 1),
    )
    output.ax.add_patch(patch)
    return output

def change_color_brightness(color, brightness_factor):
    """Return ``color`` with its HLS lightness scaled by ``(1 + brightness_factor)``.

    Args:
        color: any matplotlib color spec.
        brightness_factor (float): in [-1.0, 1.0].
            - negative values darken;
            - positive values lighten;
            - 0 leaves the lightness unchanged.

    Returns:
        tuple: the RGB components, each clipped to [0.0, 1.0].
    """
    assert -1.0 <= brightness_factor <= 1.0
    hue, lightness, saturation = colorsys.rgb_to_hls(*mplc.to_rgb(color))
    # Shift lightness relative to itself, then clamp into [0, 1].
    new_lightness = lightness + (brightness_factor * lightness)
    new_lightness = min(max(new_lightness, 0.0), 1.0)
    rgb = colorsys.hls_to_rgb(hue, new_lightness, saturation)
    return tuple(np.clip(rgb, 0.0, 1.0))

def draw_text(
    output,
    text,
    position,
    *,
    font_size=None,
    color="g",
    horizontal_alignment="center",
    rotation=0,
):
    """Write ``text`` on ``output`` over a semi-opaque black label box.

    Args:
        output (VisImage): canvas to draw on; it is modified and returned.
        text (str): the label.
        position (tuple): (x, y) anchor. The text is top-aligned at y.
        font_size (float, optional): size before scaling by ``output.scale``.
            Any falsy value falls back to the notebook-global ``default_font_size``.
        color: any matplotlib color spec; it is brightened for legibility.
        horizontal_alignment (str): passed to ``matplotlib.text.Text``.
        rotation: counter-clockwise rotation in degrees.

    Returns:
        VisImage: ``output`` with the text added.
    """
    size = font_size or default_font_size

    # The label box is dark, so keep the text light: floor every channel at 0.2
    # and lift the strongest channel to at least 0.8.
    rgb = np.maximum(list(mplc.to_rgb(color)), 0.2)
    strongest = np.argmax(rgb)
    rgb[strongest] = max(0.8, np.max(rgb))

    label_box = {"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}
    x, y = position
    output.ax.text(
        x,
        y,
        text,
        size=size * output.scale,
        family="sans-serif",
        bbox=label_box,
        verticalalignment="top",
        horizontalalignment=horizontal_alignment,
        color=rgb,
        zorder=10,
        rotation=rotation,
    )
    return output

# Normalize the prediction fields and count the instances. Boxes, masks and labels
# must all agree on that count.
num_instances = 0
if boxes is not None:
    # Box container -> NumPy array with one row per instance.
    boxes = convert_boxes(boxes)
    num_instances = len(boxes)

if masks is not None:
    # The masks are already GenericMask objects here, so convert_masks passes them through.
    masks = convert_masks(masks)
    if num_instances:
        assert len(masks) == num_instances
    else:
        num_instances = len(masks)

if labels is not None:
    assert len(labels) == num_instances
# One random RGB color per instance (components scaled to maximum=1).
assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]

In [None]:
# Compute each instance's area (from its box, else its mask) so that instances
# can be drawn largest-first.
areas = None
if boxes is not None:
    areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
elif masks is not None:
    areas = np.asarray([x.area() for x in masks])

if areas is not None:
    sorted_idxs = np.argsort(-areas).tolist()
    # Re-order overlapped instances in descending order of area, so smaller
    # instances are drawn later (on top).
    boxes = boxes[sorted_idxs] if boxes is not None else None
    labels = [labels[k] for k in sorted_idxs] if labels is not None else None
    masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None
    assigned_colors = [assigned_colors[idx] for idx in sorted_idxs]

# Draw each instance: box outline, filled mask polygons, then its text label.
for i in range(num_instances):
    color = assigned_colors[i]
    if boxes is not None:
        output = draw_box(output, boxes[i], edge_color=color)

    if masks is not None:
        for segment in masks[i].polygons:
            output = draw_polygon(output, segment.reshape(-1, 2), color, alpha=alpha)

    if labels is not None:
        # first get a box
        if boxes is not None:
            x0, y0, x1, y1 = boxes[i]
            text_pos = (x0, y0)  # if drawing boxes, put text on the box corner.
            horiz_align = "left"
        elif masks is not None:
            # skip small mask without polygon
            if len(masks[i].polygons) == 0:
                continue

            x0, y0, x1, y1 = masks[i].bbox()

            # draw text in the center (defined by median) when box is not drawn
            # median is less sensitive to outliers.
            text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
            horiz_align = "center"
        else:
            continue  # drawing the box confidence for keypoints isn't very useful.
        # for small objects, draw text at the side to avoid occlusion
        instance_area = (y1 - y0) * (x1 - x0)
        if (
            instance_area < _SMALL_OBJECT_AREA_THRESH * output.scale
            or y1 - y0 < 40 * output.scale
        ):
            # Near the bottom edge: put the label at the top-right corner; otherwise below the box.
            if y1 >= output.height - 5:
                text_pos = (x1, y0)
            else:
                text_pos = (x0, y1)

        # Font size grows with the instance's height relative to the image, clipped
        # to [0.6, 1.0] x default_font_size; the label uses a lightened instance color.
        height_ratio = (y1 - y0) / np.sqrt(output.height * output.width)
        lighter_color = change_color_brightness(color, brightness_factor=0.7)
        font_size = (
            np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
            * 0.5
            * default_font_size
        )
        # draw_text modifies `output` in place; `vis_output` is the same object.
        vis_output = draw_text(
            output,
            labels[i],
            text_pos,
            color=lighter_color,
            horizontal_alignment=horiz_align,
            font_size=font_size,
        )
