In [5]:
import os
import numpy as np
import cv2
import random
import matplotlib.pyplot as plt 
from PIL import Image
import glob
import json
from pycococreatortools import pycococreatortools

labels = ['chair', 'cushion', 'door', 'indoor-plant', 'sofa', 'table']
semantic_json_root = '/checkpoint/apratik/ActiveVision/active_vision/info_semantic'

def load_semantic_json(scene):
    replica_root = '/datasets01/replica/061819/18_scenes'
    habitat_semantic_json = os.path.join(replica_root, scene, 'habitat', 'info_semantic.json')
#         habitat_semantic_json = os.path.join(self.sjr, scene + '_info_semantic.json')
#         print(f"Using habitat semantic json {habitat_semantic_json}")
    with open(habitat_semantic_json, "r") as f:
        hsd = json.load(f)
    if hsd is None:
        print("Semantic json not found!")
    return hsd

hsd = load_semantic_json('apartment_0')

label_id_dict = {}
new_old_id = {}
idc = 1
for obj_cls in hsd["classes"]:
    if obj_cls["name"] in labels:
        label_id_dict[obj_cls["id"]] = obj_cls["name"]
        new_old_id[obj_cls['id']] = idc
        idc += 1

In [38]:
class PickGoodCandidates:
    def __init__(self, img_dir, depth_dir, seg_dir, instance_ids):
        self.imgdir = img_dir
        self.depthdir = depth_dir
        self.segdir = seg_dir
        self.iids = instance_ids
        self.filtered = False
        self.chosen = set()
        
    def is_open_contour(self, c):
        # check for a bunch of edge points
        # c is of the format num_points * 1 * 2
        edge_points = []
        for x in c:
            if x[0][0] == 0 or x[0][1] == 0 or x[0][0] == 511 or x[0][1] == 511:
                edge_points.append(x)
#         print(len(edge_points))
        if len(edge_points) > 0:
            return True
        return False

    def find_nearest2(self, x):
        dist = 10000
        res = -1
        for y, _ in self.good_candidates:
            if abs(x-y) < dist and y not in self.chosen:
                dist = abs(x-y)
                res = y
        # now look in vicinity of res for frame with max size 
        for x in range(4):
            self.chosen.add(res+x)
            self.chosen.add(res-x)
        return res
    
    def sample_uniform_nn2(self, n):
        frames = set
        for x in self.iids:
            if not self.filtered:
                self.filter_candidates(x)
                
            # now pick n best ids
            print(f'{len(self.good_candidates)} good candidates for instance id {x}')
            
            if len(self.good_candidates) > 0:
                # sort a list of tuples by the second element 
                sorted(self.good_candidates, key= lambda x: x[1], reverse=True)

                print(f'sorted candidates {self.good_candidates[:5]}')
                frames.add([x[0] for x, v in self.good_candidates[:n]])
        return list(frames)
        
    def filter_candidates(self, iid):
        self.good_candidates = []
        self.bad_candidates = []
        for x in range(len(os.listdir(self.imgdir))):
            res, size = self.is_good_candidate(x, iid)
            if res:
                self.good_candidates.append((x, size))
#                 self.vis(x)
            elif res == False:
                self.bad_candidates.append(x)
            elif not res:
                print(f'None for {x}')
                
        assert len(os.listdir(self.imgdir)) == len(self.good_candidates) + len(self.bad_candidates)
#         print(f'good candidates {self.good_candidates}')
        
    def is_good_candidate(self, fname, iid, vis=False):
        dpath = os.path.join(self.depthdir, "{:05d}.npy".format(fname))
        imgpath = os.path.join(self.imgdir, "{:05d}.jpg".format(fname))
        segpath = os.path.join(self.segdir, "{:05d}.npy".format(fname))
                
        # Load Image
        if not os.path.isfile(imgpath):
            print(f'looking for {imgpath}')
            return None, None
        img = cv2.cvtColor(cv2.imread(imgpath), cv2.COLOR_BGR2RGB)
#         print(img.shape)
        
        # Load Annotations 
        annot = np.load(segpath).astype(np.uint32)
        
        if iid in np.unique(annot):
            print(f'{iid} in {np.unique(annot)}')
            binary_mask = (annot == iid).astype(np.uint32)
            print(f'mask area {binary_mask.sum()}')

        if vis:
            plt.imshow(binary_mask)
            plt.show()
#             print(np.unique(all_binary_mask))
            
        if not binary_mask.any():
            return False, None
        
        # Check that all masks are within a certain distance from the boundary
        # all pixels [:10,:], [:,:10], [-10:], [:-10] must be 0:
        if binary_mask[:10,:].any() or all_binary_mask[:,:10].any() or all_binary_mask[:,-10:].any() or all_binary_mask[-10:,:].any():
            return False, None
        
        
        
        return True, (binary_mask == 1).sum()
        
    def visualize_good_bad(self, num):
        # TODO: sample num numbers from all, then look at he 
        # sample num from good bad
        good = random.sample(self.good_candidates, num)
        bad = random.sample(self.bad_candidates, num)
        
        for x in range(num):
            gim = os.path.join(self.imgdir, "{:05d}.jpg".format(good[x][0]))
            gim = cv2.cvtColor(cv2.imread(gim), cv2.COLOR_BGR2RGB)
            
            bim = os.path.join(self.imgdir, "{:05d}.jpg".format(bad[x]))
            bim = cv2.cvtColor(cv2.imread(bim), cv2.COLOR_BGR2RGB)
            
            arr = [gim, bim]
            titles = ['good', 'bad']
            plt.figure(figsize=(5,4))
            for i, data in enumerate(arr):
                ax = plt.subplot(1, 2, i+1)
                ax.axis('off')
                ax.set_title(titles[i])
                plt.imshow(data)
            plt.show()
        
    def vis(self, x, contours=None):
        
        imgpath = os.path.join(self.imgdir, "{:05d}.jpg".format(x))
        image = cv2.cvtColor(cv2.imread(imgpath), cv2.COLOR_BGR2RGB)
        
        if contours:
            image = cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
      
        arr = [image]
        titles = ["{:05d}.jpg".format(x)]
        plt.figure(figsize=(5,4))
        for i, data in enumerate(arr):
    #         print(f'data.shape {data.shape}')
            ax = plt.subplot(1, 1, i+1)
            ax.axis('off')
            ax.set_title(titles[i])
            plt.imshow(data)
        plt.show()

In [39]:
data_dir = '/checkpoint/apratik/data_devfair0187/apartment_0/straightline/no_noise/instance_detection_ids_allinone_auto'

img_dir = os.path.join(data_dir, 'rgb')
depth_dir = os.path.join(data_dir, 'depth')
seg_dir = os.path.join(data_dir, 'seg')
instance_ids = [404,196,172,243,133,129,170]


s = PickGoodCandidates(img_dir=img_dir, depth_dir=depth_dir, seg_dir=seg_dir, instance_ids=instance_ids)

for gt in [1,5]:
#     xs = s.sample_uniform_nn(gt)
    # pick gt frames for each instance 
    x2 = s.sample_uniform_nn2(gt)
    print(x2)
    s.chosen.clear()

404 in [ 23  30  31  32  54  58 133 182 248 275 288 303 312 318 319 348 391 404
 432]
mask area 128
404 in [ 23  30  31  32  54  58 124 133 182 199 244 248 275 288 303 312 318 319
 348 391 404 432]
mask area 120
404 in [ 23  30  31  32  54  58 124 133 182 199 244 248 275 288 303 312 318 348
 391 404 432 433]
mask area 1036
404 in [ 23  30  32  54  58 124 133 182 199 248 275 288 303 312 318 348 391 404
 432 433]
mask area 703
404 in [ 23  30  32  54  58 124 133 182 199 248 275 288 303 312 318 348 391 404
 433]
mask area 370
404 in [ 23  30  32  58 133 182 244 248 275 288 303 312 391 404]
mask area 2
404 in [  0   3  18  23  38  45  47  49  50  52  53  54  58 121 124 152 165 173
 188 193 196 199 224 234 241 251 254 270 297 318 348 361 364 400 404 428
 430 431]
mask area 553
404 in [  0   3  18  23  45  47  49  50  52  53  54  58 121 124 152 165 173 193
 196 199 234 241 251 254 270 318 348 361 364 400 404 428 430 431]
mask area 721
0 good candidates for instance id 404
196 in [  0   3   6

TypeError: 'int' object is not subscriptable