In [3]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as ipw
import cv2
from skimage import morphology

In [4]:
def _disc(n):
    if n % 2 == 0:
        n -= 1
    n = int((n - 1)/2)
    y,x = np.ogrid[-n: n+1, -n: n+1]
    mask = x**2+y**2 <= n**2
    return mask.astype(np.uint8)

def dtform_fill(depth_mask):
    """Fill zero-valued pixels with the value of the nearest non-zero pixel.

    Parameters
    ----------
    depth_mask : numpy array
        Label image in which 0 marks "unknown" and every non-zero value is a
        label to propagate. Expected to be a 2-D single-channel image, since it
        is handed to ``cv2.distanceTransformWithLabels``.

    Returns
    -------
    numpy array with the same shape and dtype as ``depth_mask``. Every pixel
    holds the value of its nearest non-zero input pixel (L2 distance with a
    3x3 mask approximation). Non-zero input pixels keep their own value. If
    the input has no non-zero pixels there is nothing to propagate, so a copy
    of the input is returned.
    """
    # Pixels to fill become 1; known pixels become 0, i.e. the "zero pixels"
    # the distance transform measures distances to.
    unknown = (depth_mask == 0).astype(np.uint8)
    known_values = depth_mask[depth_mask != 0]
    if known_values.size == 0:
        # No source pixels: the label lookup below would index an empty array.
        return depth_mask.copy()

    # With DIST_LABEL_PIXEL every known pixel gets its own label 1..K, and each
    # unknown pixel receives the label of its nearest known pixel.
    # DIST_MASK_PRECISE is not supported by this labelled variant, hence maskSize=3.
    _, label = cv2.distanceTransformWithLabels(
        unknown,
        distanceType=cv2.DIST_L2,
        maskSize=3,
        labelType=cv2.DIST_LABEL_PIXEL)

    # This relies on OpenCV numbering the known pixels in row-major order,
    # which matches the order boolean indexing produced ``known_values`` in,
    # so ``label - 1`` indexes straight into it.
    return known_values[label - 1]

# test.npy holds a pickled object (a dict, given the .item() and key lookups
# below). NumPy >= 1.16.3 refuses to load object arrays unless allow_pickle=True.
# Unpickling can execute arbitrary code, so only load files from a trusted source.
data = np.load("test.npy", allow_pickle=True).item()

In [30]:
# Unpack the saved sample. ``uncert`` later feeds the edge detector and
# ``annot`` is compared against the ``sky`` class id; ``pred`` is only displayed.
uncert = data["uncert"]
pred = data["pred"]
annot = data["annot"]
path = data["path"]

disp_path = path.replace("leftImg8bit", "disparity")
image = cv2.imread(path)  # OpenCV loads colour images in BGR channel order
disp = cv2.imread(disp_path, cv2.IMREAD_UNCHANGED)
# cv2.imread signals a missing/unreadable file by returning None instead of
# raising, which would otherwise surface as a confusing error further down.
if image is None:
    raise FileNotFoundError(f"could not read image {path!r}")
if disp is None:
    raise FileNotFoundError(f"could not read disparity {disp_path!r}")
# matplotlib expects RGB; convert a display copy so ``image`` itself is unchanged.
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure()
plt.imshow(image_rgb)
plt.figure()
plt.imshow(uncert)
plt.figure()
plt.imshow(pred)
plt.figure()
plt.imshow(annot)
plt.figure()
plt.imshow(disp)

fx = 2256.47         # presumably focal length in pixels -- confirm against calibration
baseline = 0.209313  # presumably stereo baseline (metres?) -- confirm against calibration
sky = 10             # class id in ``annot`` that is treated as sky

# Decode disparity: stored value p > 0 maps to (p - 1) / 256 and p == 0 means
# "no measurement" (presumably the Cityscapes encoding -- confirm).
mask = disp != 0
disp = mask * ((disp.astype(np.float32) - 1.) / 256.)
with np.errstate(divide='ignore'):
    depth = (baseline * fx) / disp
depth[disp == 0] = 0               # zero disparity -> unknown depth (was inf)
depth[annot == sky] = depth.max()  # push sky to the farthest observed depth
depth = cv2.medianBlur(depth, 5)

# Zero out connected regions of valid depth smaller than 1000 px, then inpaint
# every remaining hole (depth == 0) from its surroundings.
depth = depth * morphology.remove_small_objects(depth != 0, min_size=1000, connectivity=2)
depth = cv2.inpaint(depth, (depth == 0).astype(np.uint8) * 255, 3, cv2.INPAINT_NS)

depth_show = np.copy(depth)
depth_show[depth == 0] = 1
plt.figure()
plt.imshow(depth_show)

near = 18
far = 50

# Quantise depth into bands: 0 = unknown, 1 = near, 2 = mid, 3 = far.
# Sky is forced to the far band.
depth_mask = np.zeros_like(depth, dtype=np.uint8)
depth_mask[(depth > 0) & (depth <= near)] = 1
depth_mask[(depth > near) & (depth <= far)] = 2
depth_mask[depth > far] = 3
depth_mask[annot == sky] = 3

# Fixed colour limits so every depth_mask figure uses the same 0..3 scale,
# without writing a sentinel 0 into the data (which later cells read).
MASK_CLIM = dict(vmin=0, vmax=3)

plt.figure()
plt.imshow(depth_mask, **MASK_CLIM)

# Smooth band boundaries: opening removes small specks, closing fills small gaps.
kern = _disc(7)
depth_mask = cv2.morphologyEx(depth_mask, cv2.MORPH_OPEN, kern)
depth_mask = cv2.morphologyEx(depth_mask, cv2.MORPH_CLOSE, kern)

plt.figure()
plt.imshow(depth_mask, **MASK_CLIM)

# Fill unknown (0) pixels from the nearest labelled pixel.
depth_mask = dtform_fill(depth_mask)

plt.figure()
plt.imshow(depth_mask, **MASK_CLIM)

# Median-smooth, then drop connected regions of each band smaller than 1000 px
# (they become 0) and re-fill them from their neighbours.
depth_mask = cv2.medianBlur(depth_mask, 11)
new_mask = np.zeros_like(depth_mask)

for v in (1, 2, 3):
    tmp_mask = morphology.remove_small_objects(depth_mask == v, min_size=1000, connectivity=2)
    new_mask[tmp_mask] = v

depth_mask = new_mask
depth_mask = dtform_fill(depth_mask)

plt.figure()
plt.imshow(depth_mask, **MASK_CLIM)

# Overlay the final bands on the image (in RGB for correct colours).
plt.figure()
plt.imshow(image_rgb)
plt.imshow(depth_mask, alpha=0.35, **MASK_CLIM)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7ffa103ac358>

In [7]:
def get_img(thresh=128, dilate=5, erode=5):
    """Edge map of the global ``uncert`` image, dilated and then eroded.

    ``uncert`` is scaled by 255 and cast to uint8, then passed to Canny with
    ``thresh`` as both hysteresis thresholds. The edges are dilated with a
    ``dilate x dilate`` square kernel and eroded with an ``erode x erode``
    square kernel. Even kernel sizes are first reduced by one, and sizes
    below 1 are raised to 1.

    NOTE(review): the uint8 cast wraps values outside 0..255, so this assumes
    ``uncert`` lies in [0, 1] -- confirm.

    Returns the result of the erosion.
    """
    def _odd_at_least_one(size):
        if size % 2 == 0:
            size -= 1
        return max(size, 1)

    dilate = _odd_at_least_one(dilate)
    erode = _odd_at_least_one(erode)

    scaled = (uncert * 255).astype(np.uint8)
    edges = cv2.Canny(scaled, thresh, thresh)
    thickened = cv2.dilate(edges, np.ones((dilate, dilate)))
    return cv2.erode(thickened, np.ones((erode, erode)))

def get_img_depth(thresh, dilate, erode, depth):
    """Edge map from ``get_img`` restricted to one depth band.

    Parameters
    ----------
    thresh, dilate, erode :
        Forwarded unchanged to ``get_img``.
    depth :
        Band value compared against the global ``depth_mask`` (1 = near,
        2 = mid, 3 = far in this notebook). Pass ``None`` to get the unmasked
        edge map.

    Returns
    -------
    The edge map with every pixel outside the selected band set to 0.
    """
    img = get_img(thresh, dilate, erode)
    # ``is not None`` rather than ``!= None``: the latter is elementwise for
    # NumPy values and is discouraged by PEP 8.
    if depth is not None:
        img = img * (depth_mask == depth)
    return img
    
def interactive_plot(depth):
    """Show the edge map for one depth band, with sliders to tune it.

    Parameters
    ----------
    depth :
        Depth band passed to ``get_img_depth`` (1, 2, 3, or ``None`` for all
        bands).

    Returns
    -------
    (thresh, dilate, erode) :
        The three ``ipywidgets.IntSlider`` objects, so the chosen values can
        be read later through their ``.value`` attribute.
    """
    thresh = ipw.IntSlider(min=0, max=256, step=1, value=128)
    dilate = ipw.IntSlider(min=3, max=31, step=2, value=5)
    erode = ipw.IntSlider(min=3, max=31, step=2, value=5)

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    # Start from the same band-filtered image the sliders produce. Pin the
    # colour limits to the 0..255 edge range: set_data does not re-autoscale,
    # so an all-zero first frame would otherwise mis-scale every later update.
    init = get_img_depth(thresh.value, dilate.value, erode.value, depth)
    im = ax.imshow(init, vmin=0, vmax=255)

    def update(thresh, dilate, erode):
        im.set_data(get_img_depth(thresh, dilate, erode, depth))
        fig.canvas.draw_idle()

    plt.show()
    ipw.interact(update, thresh=thresh, dilate=dilate, erode=erode)
    return thresh, dilate, erode

In [8]:
# Tune edge parameters for the near band (depth_mask == 1).
# Values noted by the author (thresh, dilate, erode): 164, 11, 11
near_thresh, near_dilate, near_erode = interactive_plot(1)

<IPython.core.display.Javascript object>

interactive(children=(IntSlider(value=128, description='thresh', max=256), IntSlider(value=5, description='dil…

In [9]:
# Tune edge parameters for the mid band (depth_mask == 2).
# Values noted by the author (thresh, dilate, erode): 193, 7, 7
med_thresh, med_dilate, med_erode = interactive_plot(2)

<IPython.core.display.Javascript object>

interactive(children=(IntSlider(value=128, description='thresh', max=256), IntSlider(value=5, description='dil…

In [14]:
# Tune edge parameters for the far band (depth_mask == 3).
# Values noted by the author (thresh, dilate, erode): 104, 3, 3
far_thresh, far_dilate, far_erode = interactive_plot(3)

<IPython.core.display.Javascript object>

interactive(children=(IntSlider(value=128, description='thresh', max=256), IntSlider(value=5, description='dil…

In [15]:
# Tune edge parameters for the unfiltered edge map (all bands, no depth masking).
all_thresh, all_dilate, all_erode = interactive_plot(None)

<IPython.core.display.Javascript object>

interactive(children=(IntSlider(value=128, description='thresh', max=256), IntSlider(value=5, description='dil…

In [16]:
# Read back the slider values chosen above. The widgets are not persisted, so
# these reflect whatever was set in the current kernel session.
print(near_thresh.value, near_dilate.value, near_erode.value)
print(med_thresh.value, med_dilate.value, med_erode.value)
print(far_thresh.value, far_dilate.value, far_erode.value)
print(all_thresh.value, all_dilate.value, all_erode.value)

# Per-band edge maps, each with its own tuned parameters, combined as a union.
# np.maximum does not rely on the bands being disjoint (``+=`` on uint8 wraps
# on overlap).
final_by_depth = np.maximum.reduce([
    get_img_depth(near_thresh.value, near_dilate.value, near_erode.value, 1),
    get_img_depth(med_thresh.value, med_dilate.value, med_erode.value, 2),
    get_img_depth(far_thresh.value, far_dilate.value, far_erode.value, 3),
])

# The unmasked edge map with the "all" parameters is what the next cell uses.
# NOTE(review): ``final_by_depth`` is not used downstream -- confirm which of
# the two maps is intended for the border-probability step.
final = get_img_depth(all_thresh.value, all_dilate.value, all_erode.value, None)

plt.figure()
plt.imshow(final)

162 11 11
191 7 7
128 5 5


<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7ff9fc67bb70>

In [28]:
# Distance in pixels from every pixel to the nearest edge pixel of ``final``.
# Inverting the 0/255 edge map turns edges into the zero pixels that
# cv2.distanceTransform measures distances to.
dtform = cv2.distanceTransform(
    255 - final,
    distanceType=cv2.DIST_L2,
    maskSize=cv2.DIST_MASK_PRECISE,
)

# Saturate the distance at 10 px and rescale to [0, 1]: 0 on an edge, 1 at
# 10 px or more from any edge. (A 1/sqrt(distance) falloff was also tried.)
dtform = np.minimum(dtform, 10) / 10

# NOTE(review): border_probs is 0 on edges and grows to 1 away from them, so the
# product below suppresses ``uncert`` near edges. The 1/sqrt(distance) variant
# would weight edges up instead -- confirm which direction is intended.
border_probs = dtform
probs = cv2.dilate(border_probs * uncert, np.ones((3, 3)))

fig = plt.figure()
plt.imshow(probs, cmap="jet")

plt.show()
plt.draw()
# fig.savefig('tessstttyyy.png', dpi=1000)

<IPython.core.display.Javascript object>