-
Notifications
You must be signed in to change notification settings - Fork 3
/
selection.py
84 lines (63 loc) · 2.49 KB
/
selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import numpy as np
import tensorflow as tf
import skimage
import skimage.io
import skimage.transform
import numpy as np
def select(masks, left_image, left_shift=16, name="select"):
    '''
    Build the right view as a per-pixel weighted sum of horizontally
    shifted copies of the left view.

    assumes inputs:
        masks, shape N, H, W, S (one weight per pixel and per shift)
        left_image, shape N, H, W, C
        left_shift, number of columns reflect-padded on each side of W
        name, variable scope under which the ops are created
    returns
        right_image, shape N, H, W, C
    '''
    height, width, num_shifts = masks.get_shape().as_list()[1:]
    with tf.variable_scope(name):
        # Reflect-pad only the width axis so every shifted window stays in
        # bounds (requires num_shifts <= 2 * left_shift + 1).
        pad_spec = [[0, 0], [0, 0], [left_shift, left_shift], [0, 0]]
        padded = tf.pad(left_image, pad_spec, mode='REFLECT')

        # Window number `shift` is the W-wide crop starting at that column
        # of the padded image.
        windows = [
            tf.slice(padded, [0, 0, shift, 0], [-1, height, width, -1])
            for shift in range(num_shifts)
        ]
        stacked = tf.stack(windows, axis=4)  # N, H, W, C, S

        # Insert a channel axis into the masks (N, H, W, 1, S) so they
        # broadcast across C, then collapse the shift axis.
        weights = tf.expand_dims(masks, axis=3)
        weighted = tf.multiply(stacked, weights)
        # NOTE: an alternative formulation accumulates mask_slice * pad_slice
        # one shift at a time instead of stacking all windows first.
        return tf.reduce_sum(weighted, axis=4)
# YOU CAN IGNORE THE FOLLOWING:
def compose(masks, left_image, left_shift=16):
    '''
    NumPy version of the pixel-selection step.

    THIS FUNCTION IS ASSUMING WIDTH IS ON AXIS 1, MASK AXIS IS 3
    Takes disparity masks and applies pixel selection to generate a right
    frame. Shift index s rolls the left image by (s - left_shift) along the
    width axis; np.roll wraps around at the image border.

    Inputs:
        masks: narray of shape (N, W, H, S)
        left_image: narray of shape (N, W, H, C)
        left_shift: maximum pixel left shift
    Outputs:
        right_image: float narray of shape (N, W, H, C)
    Raises:
        ValueError: if either input is not 4-D or their first three
            dimensions differ.
    '''
    masks = np.asarray(masks)
    left_image = np.asarray(left_image)
    if masks.ndim != 4 or left_image.ndim != 4:
        raise ValueError(
            "masks and left_image must both be 4-D, got shapes %s and %s"
            % (masks.shape, left_image.shape))
    if masks.shape[:3] != left_image.shape[:3]:
        raise ValueError(
            "masks %s and left_image %s must agree on their first 3 dims"
            % (masks.shape, left_image.shape))

    # The number of candidate shifts comes from the mask's last axis; the
    # channel count comes from the image, not from the mask.
    S = masks.shape[3]

    shift_images = np.zeros(left_image.shape + (S,))
    for s in range(S):
        # width is axis 1
        shift_images[..., s] = np.roll(left_image, s - left_shift, axis=1)

    # Add a singleton channel axis so the (N, W, H, S) masks broadcast over
    # the channel axis of the (N, W, H, C, S) stack.
    weights = masks[:, :, :, np.newaxis, :]
    return np.sum(shift_images * weights, axis=4)
if __name__ == '__main__':
    # Manual smoke test: decode a sample JPEG and print its shape.
    # An InteractiveSession installs itself as the default session, which
    # is what lets `image.eval()` run without an explicit session argument.
    sess = tf.InteractiveSession()
    image_contents = tf.read_file('./test_data/tiger.jpeg')
    image = tf.image.decode_jpeg(image_contents, channels=3)
    # Parenthesised so the line parses on both Python 2 and Python 3.
    print(image.eval().shape)
def load_image(path):
    '''
    Read the image at `path`, scale its values by 1/255 and resize it.

    returns
        the scaled image resized to 180 rows by 320 columns
    '''
    scaled = skimage.io.imread(path) / 255.0
    # Sanity check: after scaling, every value must lie in [0, 1].
    assert (scaled >= 0).all() and (scaled <= 1.0).all()
    return skimage.transform.resize(scaled, (180, 320))