In [None]:
import os
import argparse
from copy import copy

from tqdm import tqdm
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
import cv2

from autolab_core import YamlConfig

from sd_maskrcnn import utils
from sd_maskrcnn.config import MaskConfig
from sd_maskrcnn.dataset import ImageDataset
from sd_maskrcnn.coco_benchmark import coco_benchmark
from sd_maskrcnn.supplement_benchmark import s_benchmark

from mrcnn import model as modellib, utils as utilslib, visualize
from mrcnn.config import Config

In [None]:
class InferenceConfig(Config):
    """Inference-time settings for the pretrained SD Mask R-CNN model.

    Only the attributes below override the mrcnn ``Config`` base class; every
    other setting is inherited unchanged.
    """

    # --- Identity and label space ---------------------------------------
    NAME = "my_sdmaskrcnn"
    NUM_CLASSES = 2  # background + object (a single foreground class)

    # --- Batching: one image on one GPU ---------------------------------
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # --- Backbone and input geometry ------------------------------------
    BACKBONE = "resnet35"
    IMAGE_MIN_DIM = 512
    IMAGE_MAX_DIM = 512
    MEAN_PIXEL = [128, 128, 128]
    USE_MINI_MASK = False

    # --- Proposal / detection filtering ---------------------------------
    # NOTE(review): an IoU threshold of 1.0 presumably means RPN NMS suppresses
    # essentially nothing — confirm this is intended for SD Mask R-CNN.
    RPN_NMS_THRESHOLD = 1.0
    DETECTION_NMS_THRESHOLD = 0.35
    POST_NMS_ROIS_INFERENCE = 2000


config = InferenceConfig()
config.display()

In [None]:
# Build the Mask R-CNN graph in inference mode from the config above.
# NOTE(review): model_dir is presumably mrcnn's log/checkpoint directory — confirm.
model = modellib.MaskRCNN(
    config=config,
    mode="inference",
    model_dir="./models",
)

In [None]:
# Load the pretrained SD Mask R-CNN checkpoint; by_name=True matches saved
# weights to layers by layer name rather than by topological order.
model.load_weights(
    "models/sd_maskrcnn_20200129.h5",
    by_name=True,
)

In [None]:
# Label names handed to visualize.display_instances; position 0 is background.
class_names = ["BG", "obj"]

In [None]:
TEST_DIR = os.path.abspath("./test")

# os.listdir returns entries in arbitrary order; sort so the figures appear in
# a reproducible order from run to run.
for file_name in sorted(os.listdir(TEST_DIR)):
    depth_path = os.path.join(TEST_DIR, file_name)
    depth = cv2.imread(depth_path, cv2.IMREAD_GRAYSCALE)
    if depth is None:
        # cv2.imread signals failure by returning None rather than raising
        # (e.g. .DS_Store, sub-directories, non-image files). Skip such entries
        # instead of crashing on the indexing below.
        print("Skipping unreadable file:", depth_path)
        continue

    # The grayscale read is 2-D (H, W); replicate it into three identical
    # channels to build the (H, W, 3) image passed to the model.
    depth = np.repeat(depth[:, :, np.newaxis], 3, axis=2)

    results = model.detect([depth], verbose=1)
    r = results[0]
    visualize.display_instances(
        depth, r['rois'], r['masks'], r['class_ids'], class_names, r['scores']
    )

In [None]:
# Maximum number of (color, depth) pairs from the WISDOM real high-res split to show.
N_SAMPLES = 5

ROOT_DIR = os.path.abspath("./")
IMAGE_DIR = os.path.join(ROOT_DIR, "datasets/wisdom/wisdom-real/high-res/color_ims/")
DEPTH_DIR = os.path.join(ROOT_DIR, "datasets/wisdom/wisdom-real/high-res/depth_ims/")

# Color and depth frames are paired by position after sorting. This assumes
# both directories list corresponding frames in the same sorted order.
image_files = sorted(os.listdir(IMAGE_DIR))
depth_files = sorted(os.listdir(DEPTH_DIR))
if len(image_files) != len(depth_files):
    print("Warning: {} color images vs {} depth images; sorted-order pairing "
          "may be misaligned.".format(len(image_files), len(depth_files)))

# zip() stops at the shorter list and the slice caps the count, so directories
# with fewer than N_SAMPLES files no longer raise IndexError.
for image_name, depth_name in list(zip(image_files, depth_files))[:N_SAMPLES]:
    image = cv2.imread(os.path.join(IMAGE_DIR, image_name))
    depth = cv2.imread(os.path.join(DEPTH_DIR, depth_name))
    if image is None or depth is None:
        # cv2.imread returns None instead of raising on unreadable files.
        print("Skipping unreadable pair:", image_name, depth_name)
        continue

    # OpenCV loads channels as BGR; convert so matplotlib shows true colors.
    plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    plt.title("Color: " + image_name)
    plt.show()
    plt.imshow(depth)
    plt.title("Depth: " + depth_name)
    plt.show()

    # The depth image, read with default (3-channel) imread flags, is the model input.
    results = model.detect([depth], verbose=1)
    r = results[0]
    visualize.display_instances(
        depth, r['rois'], r['masks'], r['class_ids'], class_names, r['scores']
    )