-
Notifications
You must be signed in to change notification settings - Fork 378
/
chip_classification_geojson_store.py
76 lines (64 loc) · 2.73 KB
/
chip_classification_geojson_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from typing import TYPE_CHECKING, Optional
from rastervision.pipeline.file_system import json_to_file
from rastervision.core.data.label import ChipClassificationLabels
from rastervision.core.data.label_store import LabelStore
from rastervision.core.data.label_store.utils import boxes_to_geojson
from rastervision.core.data.label_source import (
ChipClassificationLabelSourceConfig)
from rastervision.core.data.vector_source import (GeoJSONVectorSourceConfig)
if TYPE_CHECKING:
from rastervision.core.box import Box
from rastervision.core.data import ClassConfig, CRSTransformer
class ChipClassificationGeoJSONStore(LabelStore):
    """Label store that keeps chip classification labels in a GeoJSON file.

    Labels are serialized to the GeoJSON file at ``uri`` by :meth:`save` and
    read back from that same file by :meth:`get_labels`.
    """

    def __init__(self,
                 uri: str,
                 class_config: 'ClassConfig',
                 crs_transformer: 'CRSTransformer',
                 bbox: Optional['Box'] = None):
        """Constructor.

        Args:
            uri: URI of the GeoJSON file containing the labels.
            class_config: ClassConfig passed along when labels are written
                and when they are read back.
            crs_transformer: CRSTransformer to convert from the map coords
                of the labels in the GeoJSON file to pixel coords.
            bbox: User-specified crop of the extent. If provided, only
                labels falling inside it are returned by
                :meth:`.ChipClassificationGeoJSONStore.get_labels`. Must be
                provided if the corresponding RasterSource has
                bbox != extent. Defaults to None.
        """
        # Read-only state is kept private and exposed through properties;
        # the bbox can additionally be replaced later via set_bbox().
        self._bbox = bbox
        self._crs_transformer = crs_transformer
        self.class_config = class_config
        self.uri = uri

    def save(self, labels: ChipClassificationLabels) -> None:
        """Save labels to URI if writable.

        The cells, class IDs and scores of ``labels`` are converted to
        GeoJSON with ``boxes_to_geojson`` and written to ``self.uri``.

        Note that if the grid is inferred from polygons, only the grid will
        be written, not the original polygons.

        Args:
            labels: The chip classification labels to write out.
        """
        geojson = boxes_to_geojson(
            labels.get_cells(),
            labels.get_class_ids(),
            self.crs_transformer,
            self.class_config,
            # Scores are materialized into a plain list before conversion.
            scores=list(labels.get_scores()),
            bbox=self.bbox)
        json_to_file(geojson, self.uri)

    def get_labels(self) -> ChipClassificationLabels:
        """Read the labels stored in the GeoJSON file at ``self.uri``.

        A chip classification label source is built on top of a GeoJSON
        vector source pointing at ``self.uri``, using this store's class
        config, CRS transformer and bbox.

        Returns:
            The labels produced by that label source.
        """
        vector_source_cfg = GeoJSONVectorSourceConfig(uris=self.uri)
        label_source_cfg = ChipClassificationLabelSourceConfig(
            vector_source=vector_source_cfg)
        label_source = label_source_cfg.build(
            class_config=self.class_config,
            crs_transformer=self.crs_transformer,
            bbox=self.bbox)
        return label_source.get_labels()

    @property
    def bbox(self) -> 'Box':
        """The bbox given at construction or last set via :meth:`set_bbox`."""
        return self._bbox

    @property
    def crs_transformer(self) -> 'CRSTransformer':
        """The CRSTransformer given at construction."""
        return self._crs_transformer

    def set_bbox(self, bbox: 'Box') -> None:
        """Replace the stored bbox.

        Args:
            bbox: The new bbox; used by subsequent :meth:`save` and
                :meth:`get_labels` calls.
        """
        self._bbox = bbox