-
Notifications
You must be signed in to change notification settings - Fork 378
/
object_detection_label_source.py
128 lines (107 loc) · 4.91 KB
/
object_detection_label_source.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from typing import TYPE_CHECKING, Any, Optional, Tuple
import numpy as np
from rastervision.core.box import Box
from rastervision.core.data.label import ObjectDetectionLabels
from rastervision.core.data.label_source import LabelSource
from rastervision.core.data.vector_source import VectorSource
from rastervision.core.data.utils import parse_array_slices
if TYPE_CHECKING:
from rastervision.core.data import CRSTransformer
class ObjectDetectionLabelSource(LabelSource):
    """A read-only label source for object detection."""

    def __init__(self,
                 vector_source: VectorSource,
                 extent: Optional[Box] = None,
                 ioa_thresh: Optional[float] = None,
                 clip: bool = False):
        """Constructor.

        Args:
            vector_source (VectorSource): A VectorSource.
            extent (Optional[Box]): User-specified extent. If None, the full
                extent of the vector source is used.
            ioa_thresh (Optional[float], optional): IOA threshold to apply when
                retrieving labels for a window via ``__getitem__``. If None,
                1e-6 is used. Defaults to None.
            clip (bool, optional): Clip bounding boxes to window limits when
                retrieving labels for a window via ``__getitem__``.
                Defaults to False.

        Raises:
            ValueError: If the vector source's GeoJSON contains Point or
                LineString geometries, or features without a class_id.
        """
        self.vector_source = vector_source
        geojson = self.vector_source.get_geojson()
        self.validate_geojson(geojson)
        # NOTE(review): the user-supplied extent (possibly None) is passed to
        # from_geojson *before* it is defaulted below, so when extent is None
        # the vector source's extent is not used here. Kept as-is to preserve
        # behavior; confirm this ordering is intended.
        self.labels = ObjectDetectionLabels.from_geojson(
            geojson, extent=extent)
        if extent is None:
            extent = vector_source.extent
        self._extent = extent
        self.ioa_thresh = ioa_thresh if ioa_thresh is not None else 1e-6
        self.clip = clip

    def get_labels(self,
                   window: Optional[Box] = None,
                   ioa_thresh: float = 1e-6,
                   clip: bool = False) -> ObjectDetectionLabels:
        """Get labels (in global coords) for a window.

        Args:
            window (Optional[Box]): Window coords, relative to this label
                source's extent. If None, all labels are returned.
            ioa_thresh (float, optional): IOA threshold passed to
                ``ObjectDetectionLabels.get_overlapping``. Defaults to 1e-6.
            clip (bool, optional): Whether to clip boxes to the window; passed
                to ``ObjectDetectionLabels.get_overlapping``.
                Defaults to False.

        Note:
            ``ioa_thresh`` and ``clip`` do NOT fall back to the values given to
            the constructor; ``__getitem__`` passes those explicitly.

        Returns:
            ObjectDetectionLabels: Labels with sufficient overlap with the
                window. The returned labels are in global coords
                (i.e. coords within the full extent).
        """
        if window is None:
            return self.labels
        # Translate the window from extent-relative to global coords so it can
        # be compared against the stored (global) labels.
        global_window = window.shift_origin(self.extent)
        return ObjectDetectionLabels.get_overlapping(
            self.labels, global_window, ioa_thresh=ioa_thresh, clip=clip)

    def __getitem__(self, key: Any) -> Tuple[np.ndarray, np.ndarray, str]:
        """Get labels (in window coords) for a window.

        Returns a 3-tuple: (npboxes, class_ids, box_format).

        - npboxes is a float np.ndarray of shape (num_boxes, 4) representing
          pixel coords of bounding boxes in the form [ymin, xmin, ymax, xmax].
        - class_ids is a np.ndarray of shape (num_boxes,) representing the
          class labels for each of the boxes.
        - box_format is the format of npboxes which, in this case, is always
          'yxyx'.

        Args:
            key (Any): Either a Box (window coords, relative to this label
                source's extent) or array-style 2D slices, which are parsed
                with ``parse_array_slices``. If the slices have steps, box
                coords are divided by the corresponding step.

        Returns:
            Tuple[np.ndarray, np.ndarray, str]: 3-tuple of
                (npboxes, class_ids, box_format). The returned npboxes are in
                window coords (i.e. coords within the window).
        """
        if isinstance(key, Box):
            window = key
            labels = self.get_labels(
                window, ioa_thresh=self.ioa_thresh, clip=self.clip)
            class_ids = labels.get_class_ids()
            npboxes = labels.get_npboxes()
            # get_labels() shifted the window by the extent's origin and
            # returned boxes in global coords. Convert back using that same
            # shifted window; using the unshifted window would leave boxes
            # offset by the extent's origin whenever it is not at (0, 0).
            global_window = window.shift_origin(self.extent)
            npboxes = ObjectDetectionLabels.global_to_local(
                npboxes, global_window)
            return npboxes, class_ids, 'yxyx'

        window, (h, w) = parse_array_slices(key, extent=self.extent, dims=2)
        npboxes, class_ids, fmt = self[window]
        # Rescale if steps were specified (e.g. label_source[::2, ::2]).
        # The Box branch above always returns fmt == 'yxyx', so y coords are
        # columns 0 and 2 and x coords are columns 1 and 3.
        if h.step is not None:
            npboxes[:, [0, 2]] /= h.step
        if w.step is not None:
            npboxes[:, [1, 3]] /= w.step
        return npboxes, class_ids, fmt

    def validate_geojson(self, geojson: dict) -> None:
        """Check that GeoJSON features are usable for object detection.

        Args:
            geojson (dict): GeoJSON dict with a 'features' list.

        Raises:
            ValueError: If any feature's geometry type contains 'Point' or
                'LineString' (which includes the Multi* variants), or if any
                feature lacks a non-None 'class_id' property. Geometry checks
                are done for all features before class_id checks.
        """
        features = geojson['features']
        for f in features:
            # GeoJSON permits null geometry; `or {}` avoids calling .get()
            # on None.
            geom_type = (f.get('geometry') or {}).get('type', '')
            if 'Point' in geom_type or 'LineString' in geom_type:
                raise ValueError(
                    'LineStrings and Points are not supported '
                    'in ObjectDetectionLabelSource. Use BufferTransformer '
                    'to buffer them into Polygons.')
        for f in features:
            # GeoJSON permits null properties; treat them as missing class_id
            # rather than crashing with an AttributeError.
            props = f.get('properties') or {}
            if props.get('class_id') is None:
                raise ValueError('All GeoJSON features must have a class_id '
                                 'field in their properties.')

    @property
    def extent(self) -> Box:
        """Extent of this label source: the user-specified extent, or the
        vector source's extent if none was given."""
        return self._extent

    @property
    def crs_transformer(self) -> 'CRSTransformer':
        """CRS transformer of the underlying vector source."""
        return self.vector_source.crs_transformer

    def set_extent(self, extent: 'Box') -> None:
        """Replace this label source's extent.

        Only the extent used to interpret windows is updated; the already
        loaded ``self.labels`` are not reloaded or refiltered.
        """
        self._extent = extent