Skip to content

Commit

Permalink
Add an XarraySource.from_stac() constructor
Browse files Browse the repository at this point in the history
This moves most of the XarraySource creation code from XarraySourceConfig to XarraySource.from_stac(), which allows an XarraySource to be created directly from a STAC Item without having to define an XarraySourceConfig.
  • Loading branch information
AdeelH committed Feb 14, 2024
1 parent 83fa1a1 commit 1157b44
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rastervision.core.data.utils import parse_array_slices_Nd, fill_overflow

if TYPE_CHECKING:
from pystac import Item, ItemCollection
from rastervision.core.data import RasterTransformer, CRSTransformer

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -91,6 +92,78 @@ def __init__(self,
raster_transformers=raster_transformers,
bbox=bbox)

@classmethod
def from_stac(
        cls,
        item_or_item_collection: Union['Item', 'ItemCollection'],
        raster_transformers: Optional[List['RasterTransformer']] = None,
        channel_order: Optional[Sequence[int]] = None,
        bbox: Optional[Box] = None,
        bbox_map_coords: Optional[Box] = None,
        temporal: bool = False,
        allow_streaming: bool = False,
        stackstac_args: Optional[dict] = None) -> 'XarraySource':
    """Construct an ``XarraySource`` from a STAC Item or ItemCollection.

    The item(s) are stacked into a DataArray via ``stackstac.stack()``,
    which is then wrapped in a raster source.

    Args:
        item_or_item_collection: STAC Item or ItemCollection.
        raster_transformers: RasterTransformers to use to transform chips
            after they are read. If None, no transformers are used.
            Defaults to None.
        channel_order: List of indices of channels to extract from raw
            imagery. Can be a subset of the available channels. If None,
            all channels available in the image will be read.
            Defaults to None.
        bbox: User-specified crop of the extent. If None, the full extent
            available in the source file is used. Mutually exclusive with
            ``bbox_map_coords``. Defaults to ``None``.
        bbox_map_coords: User-specified bbox in EPSG:4326 coords of the
            form (ymin, xmin, ymax, xmax). Useful for cropping the raster
            source so that only part of the raster is read from. Mutually
            exclusive with ``bbox``. Defaults to ``None``.
        temporal: If True, the stacked DataArray is expected to have a
            "time" dimension and the chips returned will be of shape
            (T, H, W, C). If False, the DataArray may have at most one
            time step, which is then selected. Defaults to False.
        allow_streaming: If False, load the entire DataArray into memory.
            Defaults to False.
        stackstac_args: Optional keyword arguments to pass to
            ``stackstac.stack()``. If None, ``dict(rescale=False)`` is
            used. Defaults to None.

    Returns:
        The constructed raster source.

    Raises:
        ValueError: If both ``bbox`` and ``bbox_map_coords`` are specified,
            or if ``temporal`` is False but the stacked DataArray has more
            than one time step.
    """
    # Validate cheap argument constraints first so that we fail before
    # stacking (and possibly loading) a large DataArray.
    if bbox is not None and bbox_map_coords is not None:
        raise ValueError('Specify either bbox or bbox_map_coords, '
                         'but not both.')

    # Fresh objects per call: avoids sharing one mutable default between
    # every raster source created without these arguments.
    if raster_transformers is None:
        raster_transformers = []
    if stackstac_args is None:
        stackstac_args = dict(rescale=False)

    # Imported lazily so that stackstac is only required when this
    # constructor is actually used.
    import stackstac

    data_array = stackstac.stack(item_or_item_collection, **stackstac_args)

    if not temporal and 'time' in data_array.dims:
        if len(data_array.time) > 1:
            raise ValueError('temporal=False but len(data_array.time) > 1')
        # Drop the singleton time dimension.
        data_array = data_array.isel(time=0)

    if not allow_streaming:
        from humanize import naturalsize
        log.info('Loading the full DataArray into memory '
                 f'({naturalsize(data_array.nbytes)}).')
        data_array.load()

    crs_transformer = RasterioCRSTransformer(
        transform=data_array.transform, image_crs=data_array.crs)

    if bbox is not None:
        bbox = Box(*bbox)
    elif bbox_map_coords is not None:
        # Convert the map-coordinate bbox to pixel coordinates.
        bbox_map_coords = Box(*bbox_map_coords)
        bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize()

    # Use cls (not a hard-coded class) so that subclasses calling this
    # alternate constructor get an instance of their own type.
    raster_source = cls(
        data_array,
        crs_transformer=crs_transformer,
        raster_transformers=raster_transformers,
        channel_order=channel_order,
        bbox=bbox,
        temporal=temporal)
    return raster_source

@property
def shape(self) -> Tuple[int, int, int]:
"""Shape of the raster as a (height, width, num_channels) tuple."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import logging

from rastervision.pipeline.config import Field, register_config
from rastervision.core.box import Box
from rastervision.core.data.raster_source.raster_source_config import (
RasterSourceConfig)
from rastervision.core.data.crs_transformer import RasterioCRSTransformer
from rastervision.core.data.raster_source.stac_config import (
STACItemConfig, STACItemCollectionConfig)
from rastervision.core.data.raster_source.xarray_source import (XarraySource)
from rastervision.core.data.raster_source.xarray_source import XarraySource

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,43 +36,17 @@ class XarraySourceConfig(RasterSourceConfig):
def build(self,
          tmp_dir: Optional[str] = None,
          use_transformers: bool = True) -> XarraySource:
    """Build an ``XarraySource`` from this config.

    Builds the STAC Item or ItemCollection described by ``self.stac`` and
    delegates the actual raster source construction to
    ``XarraySource.from_stac()``.

    Args:
        tmp_dir: Not used by this method.
        use_transformers: If True, build and attach the configured
            raster transformers; otherwise attach none.
            Defaults to True.

    Returns:
        The constructed raster source.
    """
    item_or_item_collection = self.stac.build()

    if use_transformers:
        raster_transformers = [rt.build() for rt in self.transformers]
    else:
        raster_transformers = []

    raster_source = XarraySource.from_stac(
        item_or_item_collection,
        raster_transformers=raster_transformers,
        channel_order=self.channel_order,
        bbox=self.bbox,
        bbox_map_coords=self.bbox_map_coords,
        temporal=self.temporal,
        allow_streaming=self.allow_streaming,
        stackstac_args=self.stackstac_args,
    )
    return raster_source

0 comments on commit 1157b44

Please sign in to comment.