-
Notifications
You must be signed in to change notification settings - Fork 185
/
utils.py
99 lines (84 loc) · 3.22 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import functools
import traceback
import json
import numpy as np
# Decorator that prints the traceback of any exception raised inside the wrapped function (e.g. one running in a worker process) and then re-raises it.
def get_traceback(f):
    """Decorate ``f`` so that an exception it raises is printed before propagating.

    On an ``Exception`` (not ``BaseException`` such as ``KeyboardInterrupt``)
    the wrapper prints a header line to stdout and the full traceback via
    ``traceback.print_exc`` (stderr by default). It then re-raises the same
    exception object, so callers still observe it. Return values pass through
    unchanged, and ``functools.wraps`` keeps ``f``'s name and docstring.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception:
            print('Caught exception in worker thread:')
            traceback.print_exc()
            # A bare ``raise`` keeps the original traceback intact;
            # ``raise e`` would prepend an extra entry for this frame.
            raise
    return wrapper
class IdGenerator:
    '''
    Generate unique segment IDs whose RGB encoding is meaningful.

    For a given semantic category, a unique ID is generated whose RGB encoding
    is close to that category's predefined color. The RGB encoding used (see
    ``rgb2id``) is ID = R + 256 * G + 256 * 256 * B.

    The constructor takes a dictionary {id: category_info} that contains every
    semantic category id. Each category_info record is a dict with the fields
    'isthing' and 'color'.

    Stuff categories (isthing == 0) always map to their own color. A thing
    category gets its base color the first time it is requested, and a
    randomly jittered, not-yet-used color on later requests. Thing segments
    never receive a stuff color or the color (0, 0, 0).
    '''

    def __init__(self, categories):
        # Reserve black, which encodes ID 0. This must be a tuple:
        # ``set([0, 0, 0])`` would contain only the int 0 and reserve nothing.
        self.taken_colors = {(0, 0, 0)}
        self.categories = categories
        for category in self.categories.values():
            if category['isthing'] == 0:
                # A stuff color is returned for every segment of that
                # category, so thing segments must never reuse it.
                self.taken_colors.add(tuple(category['color']))

    @staticmethod
    def _random_color(base, max_dist=30):
        """Return ``base`` shifted per channel by a random offset in
        [-max_dist, max_dist], clipped to [0, 255], as a tuple of Python ints."""
        new_color = np.asarray(base) + np.random.randint(low=-max_dist,
                                                         high=max_dist + 1,
                                                         size=3)
        # Plain ints keep the result hashable-equal to int tuples and
        # JSON-serializable (numpy integer scalars are not).
        return tuple(int(c) for c in np.clip(new_color, 0, 255))

    def get_color(self, cat_id):
        """Return a color for a new segment of category ``cat_id``.

        Stuff categories return their ``'color'`` value unchanged (as stored,
        e.g. a list). Thing categories return a tuple of ints that is recorded
        in ``self.taken_colors`` so it is not handed out again. Channel values
        are assumed to be integers in [0, 255].

        Raises:
            KeyError: if ``cat_id`` is not a key of ``categories``.
        """
        category = self.categories[cat_id]
        if category['isthing'] == 0:
            return category['color']
        base_color_array = category['color']
        color = tuple(int(c) for c in base_color_array)
        # Retry with jittered colors until an unused one is found. This would
        # loop forever if every color in the jitter window were already taken.
        while color in self.taken_colors:
            color = self._random_color(base_color_array)
        self.taken_colors.add(color)
        return color

    def get_id(self, cat_id):
        """Return the ID (``rgb2id`` of ``get_color``) for a new segment of ``cat_id``."""
        color = self.get_color(cat_id)
        return rgb2id(color)

    def get_id_and_color(self, cat_id):
        """Return ``(id, color)`` for a new segment of ``cat_id``."""
        color = self.get_color(cat_id)
        return rgb2id(color), color
def rgb2id(color):
    """Convert an RGB color, or an RGB image, to panoptic segment ID(s).

    The encoding is ID = R + 256 * G + 256 * 256 * B.

    Args:
        color: either a 3-D numpy array whose last axis holds the R, G, B
            channels (e.g. an H x W x 3 image), or a length-3 sequence of
            integer channel values (list, tuple, or 1-D array).

    Returns:
        For 3-D array input, an integer array of shape ``color.shape[:2]``.
        Otherwise, a Python int.
    """
    if isinstance(color, np.ndarray) and len(color.shape) == 3:
        if color.dtype == np.uint8:
            # Widen first: 256 * 256 * B does not fit in uint8.
            color = color.astype(np.int32)
        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
    # Convert each channel to a Python int before the arithmetic. Narrow numpy
    # scalars (e.g. np.uint8) would otherwise overflow, or be rejected under
    # NumPy 2's scalar promotion rules (256 is out of range for uint8).
    return int(color[0]) + 256 * int(color[1]) + 256 * 256 * int(color[2])
def id2rgb(id_map):
    """Convert panoptic segment ID(s) back to RGB, inverting the R + 256*G + 256*256*B encoding.

    Channels are peeled off least-significant first:
    R = id % 256, G = (id // 256) % 256, B = (id // 256 ** 2) % 256.

    Args:
        id_map: a numpy array of IDs (any shape), or a single scalar ID.

    Returns:
        For an ndarray, a new uint8 array of shape ``id_map.shape + (3,)``;
        the input array is not modified. For a scalar, a list ``[R, G, B]``.
    """
    if not isinstance(id_map, np.ndarray):
        channels = []
        remaining = id_map
        for _ in range(3):
            remaining, channel = divmod(remaining, 256)
            channels.append(channel)
        return channels

    # Work on a copy so the in-place floor division never touches the caller's array.
    remaining = id_map.copy()
    rgb_map = np.zeros(id_map.shape + (3,), dtype=np.uint8)
    for channel_index in range(3):
        rgb_map[..., channel_index] = remaining % 256
        remaining //= 256
    return rgb_map
def save_json(d, file):
    """Serialize ``d`` as JSON into ``file`` (anything ``open`` accepts), overwriting it."""
    with open(file, mode='w') as out_file:
        json.dump(obj=d, fp=out_file)