-
Notifications
You must be signed in to change notification settings - Fork 8
/
self_taught_loc_test.py
126 lines (119 loc) · 6.28 KB
/
self_taught_loc_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import unittest

import numpy as np
import skimage
import skimage.transform
# NOTE(review): newer scikit-image releases may no longer expose imread from
# skimage.data (skimage.io.imread is the documented reader) -- confirm for the
# pinned scikit-image version.
from skimage.data import imread

# Project modules. Star-import order is kept as in the original on purpose:
# later modules may re-export names that shadow earlier ones.
from imgsegmentation import *
from configuration import *
from network import *
from self_taught_loc import *
class SelfTaughtLocTest(unittest.TestCase):
    """Regression tests for SelfTaughtLoc_Grayout.extract_greedy.

    Each test runs greedy self-taught localization (STL) on the same
    ILSVRC2012 validation image with a different STL configuration and
    compares the number of returned segments plus a few sampled mask / bbox
    values against previously recorded reference values.

    The Caffe model directory defaults to the original cluster path and can be
    overridden with the STL_CAFFE_ROOT environment variable; if the directory
    does not exist the tests are skipped rather than erroring out.
    """

    # Caffe model directory used when STL_CAFFE_ROOT is not set.
    DEFAULT_CAFFE_ROOT = (
        "/home/ironfs/scratch/vlg/Data/Images/ILSVRC2012/caffe_model_141118")
    # Input image, resolved relative to the current working directory.
    TEST_IMAGE = 'test_data/ILSVRC2012_val_00000001_n01751748.JPEG'
    # Number of segments expected in each list returned by extract_greedy.
    EXPECTED_SIZES = (217, 123, 93, 47)
    # (list index, segment index, mask pixel) sampled by every test; each test
    # supplies the expected mask value and bbox.xmin for these entries.
    PROBES = ((0, 100, (0, 0)),
              (1, 50, (10, 10)),
              (2, 65, (0, 0)),
              (3, 33, (30, 10)))

    def setUp(self):
        """Build the network and the three STL objects under test.

        Skips the test when the Caffe model directory does not exist.
        """
        root = os.environ.get('STL_CAFFE_ROOT', self.DEFAULT_CAFFE_ROOT)
        if not os.path.isdir(root):
            self.skipTest('Caffe model directory not found: %s '
                          '(set STL_CAFFE_ROOT to override)' % root)
        conf = Configuration(root=root)
        net_params = NetworkCaffe1114Params(conf.caffe_model_spec,
                                            conf.caffe_model,
                                            conf.caffe_wnids_words,
                                            conf.caffe_avg_image,
                                            center_only=True,
                                            wnid_subset=[])
        # instantiate network
        self.net = Network.create_network(net_params)
        # choose segmentation method (Matlab wrapper Felz through SS)
        img_segmenter = ImgSegmMatWraper()
        # STL objects that differ only in alpha, scoring function and padding
        self.stl_grayout = self._make_stl(img_segmenter, 3, 'similarity')
        self.stl_grayout_2 = self._make_stl(img_segmenter, 4,
                                            'similarity+cnnfeature')
        self.stl_grayout_3 = self._make_stl(img_segmenter, 4,
                                            'similarity+cnnfeature',
                                            padding=0.2)

    def tearDown(self):
        """Drop references to the STL objects and the network."""
        self.stl_grayout = None
        self.stl_grayout_2 = None
        self.stl_grayout_3 = None
        self.net = None

    def _make_stl(self, img_segmenter, n_weights, function_stl, **extra):
        """Create a SelfTaughtLoc_Grayout with the settings shared by all tests.

        Args:
            img_segmenter: segmentation object passed through unchanged.
            n_weights: length of the uniform alpha vector (each entry is
                1 / n_weights).
            function_stl: STL scoring function name.
            **extra: additional keyword arguments (e.g. padding) forwarded
                only when given, so the constructor defaults apply otherwise.
        """
        alpha = (1 / float(n_weights)) * np.ones((n_weights,))
        return SelfTaughtLoc_Grayout(self.net, img_segmenter,
                                     min_sz_segm=5, topC=5,
                                     alpha=alpha,
                                     obfuscate_bbox=True,
                                     function_stl=function_stl,
                                     **extra)

    def _load_test_image(self):
        """Read TEST_IMAGE and resize it to the square network input size.

        The float image produced by skimage.transform.resize is converted
        back with skimage.img_as_ubyte before being returned.
        """
        img = imread(self.TEST_IMAGE)
        input_dim = self.net.get_input_dim()
        image_resz = skimage.transform.resize(img, (input_dim, input_dim))
        return skimage.img_as_ubyte(image_resz)

    def _check_segment_lists(self, segment_lists, mask_values, xmins):
        """Compare extract_greedy output with the recorded reference values.

        Args:
            segment_lists: value returned by extract_greedy, indexed as
                segment_lists[list_idx][segment_idx][key] with 'mask' and
                'bbox' keys.
            mask_values: expected mask value at each PROBES pixel, in order.
            xmins: expected bbox.xmin of each PROBES entry, in order.
        """
        # len() rather than np.shape(): np.shape first converts the nested
        # lists to an array, which raises on NumPy >= 1.24 because the inner
        # lists have different lengths.
        self.assertEqual(len(segment_lists), len(self.EXPECTED_SIZES))
        for list_idx, size in enumerate(self.EXPECTED_SIZES):
            self.assertEqual(len(segment_lists[list_idx]), size,
                             'size of segment_lists[%d]' % list_idx)
        for (list_idx, seg_idx, pixel), value in zip(self.PROBES, mask_values):
            self.assertEqual(segment_lists[list_idx][seg_idx]['mask'][pixel],
                             value,
                             'mask%s of segment_lists[%d][%d]'
                             % (pixel, list_idx, seg_idx))
        for (list_idx, seg_idx, _), xmin in zip(self.PROBES, xmins):
            self.assertEqual(segment_lists[list_idx][seg_idx]['bbox'].xmin,
                             xmin,
                             'bbox.xmin of segment_lists[%d][%d]'
                             % (list_idx, seg_idx))

    def test_extract_stl_u(self):
        """Greedy STL with the 'similarity' scoring function."""
        segment_lists = self.stl_grayout.extract_greedy(
            self._load_test_image())
        self._check_segment_lists(segment_lists,
                                  mask_values=(False, True, False, True),
                                  xmins=(87, 41, 9, 73))

    def test_extract_stl_u_cnnfeature(self):
        """Greedy STL with 'similarity+cnnfeature' scoring."""
        segment_lists = self.stl_grayout_2.extract_greedy(
            self._load_test_image())
        self._check_segment_lists(segment_lists,
                                  mask_values=(False, True, False, False),
                                  xmins=(87, 41, 170, 0))

    def test_extract_stl_u_cnnfeature_pad(self):
        """Greedy STL with 'similarity+cnnfeature' scoring and padding=0.2."""
        segment_lists = self.stl_grayout_3.extract_greedy(
            self._load_test_image())
        self._check_segment_lists(segment_lists,
                                  mask_values=(False, True, False, False),
                                  xmins=(87, 41, 82, 47))
# ------------ perform test ------------ #
if __name__ == '__main__':
    # Only when this file is executed directly (not imported): hand control
    # to the standard unittest runner for the TestCase classes in this module.
    unittest.main(module='__main__')