In [2]:
%load_ext autoreload
%autoreload 2

import json
import os
import re

import numpy as np
import requests
import torch
from PIL import Image, ImageOps
from torchvision.utils import draw_segmentation_masks

# Keywords that, when they start a word in a prompt, mark the prompt as describing several objects.
multiple = ["two", "three", "four", "five", "six", "seven", "multiple", "group", "various", "several"]


def judge_one_object(prompts, keywords=None):
    """Return True when none of ``prompts`` suggests more than one object.

    A prompt counts as "multiple objects" when any keyword appears at the start of a word
    (case-insensitive). Matching at a word start rather than anywhere in the text avoids
    false hits such as "network"/"artwork" (both contain "two"), while plurals such as
    "groups" still match.

    Args:
        prompts: iterable of prompt strings for one image.
        keywords: words to look for; defaults to the module-level ``multiple`` list.

    Returns:
        bool: False if any prompt matches any keyword, otherwise True (also True when
        ``prompts`` is empty).
    """
    words = multiple if keywords is None else keywords
    patterns = [re.compile(r"\b" + re.escape(word.lower())) for word in words]
    return not any(pattern.search(prompt.lower()) for prompt in prompts for pattern in patterns)

# If there are multiple classes, select the main body according to the clip score between the class and image
def find_main_class(classes, image):
    print(classes)
    class_candidate = []
    for c in classes:
        if c not in class_candidate:
            class_candidate.append(c)
    if len(class_candidate) == 1:
        class_select = class_candidate[0]
    else:
        class_score = []
        for c in classes:
            cls_score = metric(transform(image), c)
            class_score.append(cls_score.detach().numpy())
        class_score = np.array(class_score)
        class_score_idx = np.argsort(class_score)
        class_select = classes[class_score_idx[-1]]
    return class_select

def crop_image(mask, image, segments):
    """Crop ``image`` and ``mask`` to a window around the union of the segments' bboxes.

    The window side is 1.5x the longer side of the union box and is centred on that box.
    It is clipped to the image borders, so it can come out non-square near an edge.

    Args:
        mask: mask image, sliced with the same window as ``image`` (presumably the same
            size as ``image`` -- TODO confirm at call sites).
        image: PIL image; ``image.size`` gives (width, height).
        segments: annotations, each with a ``'bbox'`` of [x, y, w, h].

    Returns:
        (mask_crop, image_crop) as PIL images.
    """
    width, height = image.size
    boxes = [seg['bbox'] for seg in segments]

    # Union box of all bboxes, seeded with (width, height, 0, 0) as the running min/max.
    left = min([width] + [int(box[0]) for box in boxes])
    top = min([height] + [int(box[1]) for box in boxes])
    right = max([0] + [int(box[0] + box[2]) for box in boxes])
    bottom = max([0] + [int(box[1] + box[3]) for box in boxes])

    half_side = int(max(bottom - top, right - left) * 1.5) // 2
    center_y = int((top + bottom) / 2)
    center_x = int((left + right) / 2)

    rows = slice(max(0, center_y - half_side), min(height, center_y + half_side))
    cols = slice(max(0, center_x - half_side), min(width, center_x + half_side))

    image_crop = Image.fromarray(np.array(image)[rows, cols])
    mask_crop = Image.fromarray(np.array(mask)[rows, cols])
    return mask_crop, image_crop


def generate_mask_image(segments_select, main_class, classes, init_image, prompts=None):
    """Build a binary mask for ``main_class`` and crop it together with the image.

    ``classes[i]`` is the class name of ``segments_select[i]``. Segments are walked from the
    end of the list. With the ascending-area order produced by ``obtain_segments``, this
    means largest first.

    - If the prompts describe a single object (``judge_one_object``), only the first
      matching segment in that walk is used.
    - Otherwise the masks of all ``main_class`` segments are merged.

    Args:
        segments_select: annotation dicts for the image.
        main_class: class name whose segments form the mask.
        classes: class names parallel to ``segments_select``.
        init_image: PIL image; ``init_image.size`` gives (width, height).
        prompts: prompt strings for the image. When None, falls back to the notebook
            globals ``prompts_dict[image_names[img_idx]]`` (the original, implicit behavior).

    Returns:
        (mask_crop, image_crop) as returned by ``crop_image``.

    Raises:
        ValueError: if no segment has class ``main_class``.
    """
    if prompts is None:
        # Backward-compatible fallback on notebook state set by the driver loop.
        prompts = prompts_dict[image_names[img_idx]]

    width, height = init_image.size
    matches = [seg for seg, cls in zip(reversed(segments_select), reversed(classes))
               if cls == main_class]
    if not matches:
        raise ValueError(f"no segment of class {main_class!r} among {classes}")
    if judge_one_object(prompts):
        matches = matches[:1]

    masks = [annToMask(seg, width, height) for seg in matches]
    # Pixel-wise union of the masks (same result as summing and clipping to 1).
    mask_union = np.any(np.array(masks), axis=0)
    mask_out = Image.fromarray(np.uint8(mask_union * 255)).convert('RGB')
    return crop_image(mask_out, init_image, matches)

import mask as maskUtils
def annToRLE(ann, w, h):
    """
    Convert a COCO-style annotation's segmentation (polygons, uncompressed RLE, or RLE) to RLE.

    :param ann: annotation dict with a 'segmentation' entry.
    :param w: image width in pixels.
    :param h: image height in pixels.
    :return: RLE object as used by ``maskUtils`` (not a decoded mask; see annToMask).
    """
    segm = ann['segmentation']
    if isinstance(segm, list):
        # polygon -- a single object might consist of multiple parts;
        # merge all parts into one RLE
        rles = maskUtils.frPyObjects(segm, h, w)
        return maskUtils.merge(rles)
    if isinstance(segm['counts'], list):
        # uncompressed RLE
        return maskUtils.frPyObjects(segm, h, w)
    # already a compressed RLE: pass through unchanged
    return segm
def annToMask(ann, w, h):
    """
    Decode an annotation (polygons, uncompressed RLE, or RLE) into a binary mask.

    :param ann: annotation dict with a 'segmentation' entry.
    :param w: image width in pixels.
    :param h: image height in pixels.
    :return: binary mask (numpy 2D array) from ``maskUtils.decode``.
    """
    return maskUtils.decode(annToRLE(ann, w, h))



from torchmetrics.multimodal import CLIPScore
import torchvision.transforms as transforms

# PIL image -> (C, H, W) tensor; find_main_class passes this to `metric`.
transform = transforms.Compose([
    transforms.PILToTensor()
])
# Local CLIP checkpoint for CLIPScore. Override with the CLIP_MODEL_PATH environment
# variable (e.g. CLIP_MODEL_PATH=./clip) instead of editing this cell.
path_co_clip = os.environ.get("CLIP_MODEL_PATH", "../../data/assets")
metric = CLIPScore(model_name_or_path=path_co_clip)

Using TensorFlow backend.


In [7]:
# obtain segments for an image
def obtain_segments(image_names, img_idx, instance):
    segments_all = []
    for i in range(len(instance['images'])):
        if instance['images'][i]['file_name'] == image_names[img_idx]:
            image_id = instance['images'][i]['id']
    for i in range(len(instance['annotations'])):
        if instance['annotations'][i]['image_id'] == image_id:
            segments_all.append(instance['annotations'][i])

    # obtain top-5 segments

    area = []
    for seg in segments_all:
        area.append(seg['area'])
    area = np.array(area)
    area_idx = np.argsort(area)
    segments = []
    if len(segments_all) > 5:
        for idx in area_idx[-5:]:
            segments.append(segments_all[idx])
    else:
        for idx in area_idx:
            segments.append(segments_all[idx])

    segments_select = []
    # filter out object less than 1/10 area of the total image if image has more than 2 objects
    if len(segments) > 2:
        for seg in segments:
            if (seg['bbox'][2]*seg['bbox'][3]) > 0.05 * (instance['images'][img_idx]['height'] * instance['images'][img_idx]['width']):
                segments_select.append(seg)
        if len(segments_select) == 0:
            segments_select.append(seg)
    else:
        segments_select = segments
    classes = []
    for seg in segments_select:
        cls = seg['category_id']
        for i in range(len(instance['categories'])):
            if instance['categories'][i]['id'] == cls:
                classes.append(instance['categories'][i]['name'])
    return segments_select, classes


In [18]:
# Prompts per image; the keys are image file names under <COCO root>/images/val2017/.
with open("prompts.json", "r") as f:
    prompts_dict = json.load(f)
image_names = list(prompts_dict.keys())

# `path_to_coco` is not defined in the cells shown here. Reuse it if an earlier cell set it;
# otherwise take the COCO root from the COCO_ROOT environment variable.
path_to_coco = globals().get("path_to_coco") or os.environ.get("COCO_ROOT")
if not path_to_coco:
    raise RuntimeError("path_to_coco is undefined: set the COCO_ROOT environment variable "
                       "or define path_to_coco in an earlier cell")

with open(os.path.join(path_to_coco, "annotations", "instances_val2017.json"), "r") as f:
    instance = json.load(f)

image_folder = "images_crop/"
mask_folder = "mask_crop/"
# Indices into image_names of images reported as erroneous; they are skipped.
SKIP_IMAGE_IDX = {64, 141, 181, 219, 344, 359, 379, 389}
# Image.save does not create missing directories.
os.makedirs(image_folder, exist_ok=True)
os.makedirs(mask_folder, exist_ok=True)

for img_idx in range(len(image_names)):
    if img_idx in SKIP_IMAGE_IDX:
        continue
    image_name = image_names[img_idx]
    with Image.open(os.path.join(path_to_coco, "images", "val2017", image_name)) as raw_image:
        init_image = raw_image.convert('RGB')
    save_path = image_folder + image_name
    mask_path = mask_folder + image_name
    segments_select, classes = obtain_segments(image_names, img_idx, instance)
    main_class = find_main_class(classes, init_image)
    # generate_mask_image reads the globals prompts_dict / image_names / img_idx set here.
    mask_crop, init_image_crop = generate_mask_image(segments_select, main_class, classes, init_image)

    init_image_crop.save(save_path)
    mask_crop.save(mask_path)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
380
2
['bench', 'person']
381
1
['teddy bear']
382
4
['mouse', 'bottle', 'keyboard', 'tv']
383
2
['keyboard', 'teddy bear']
384
8
['person']
385
6
['fire hydrant']
386
2
['teddy bear', 'microwave']
387
1
['scissors']
388
6
['truck', 'car', 'truck', 'fire hydrant']
390
5
['potted plant', 'dog', 'car', 'car']
391
3
['remote', 'remote', 'person']
392
2
['dining table', 'cake']
393
5
['bird', 'potted plant', 'clock', 'apple']
394
2
['dog', 'fire hydrant']
395
3
['knife', 'cup', 'dining table']
396
1
['truck']
397
9
['fire hydrant']
398
15
['motorcycle']
399
28
['person']
