Skip to content

Commit

Permalink
First pass at subsetting by geometry
Browse files Browse the repository at this point in the history
  • Loading branch information
duncanwp committed Jan 11, 2017
1 parent febf98d commit 8c9f5cf
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion cis/subsetting/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,28 @@ def subset(data, constraint, **kwargs):
from datetime import datetime
from cis.time_util import PartialDateTime
from cis.exceptions import CoordinateNotFoundError
from shapely.wkt import loads
from shapely.geos import ReadingError

constraints = {}

for dim_name, limit in kwargs.items():
# Deal with shape argument
if dim_name == 'shape':
if isinstance(limit, six.string_types):
try:
shape = loads(limit)
except ReadingError:
raise ValueError("Invalid shape string: " + limit)
else:
shape = limit
constraints['shape'] = shape
bounding_box = shape.bounds
# Create the lat/lon box - this will be used to speed up the shape subset
constraints[data.coord(standard_name='longitude').name()] = slice(bounding_box[0], bounding_box[2])
constraints[data.coord(standard_name='latitude').name()] = slice(bounding_box[1], bounding_box[3])
break

c = data._get_coord(dim_name)
if c is None:
raise CoordinateNotFoundError("No coordinate found that matches '{}'. Please check the "
Expand Down Expand Up @@ -71,11 +89,12 @@ class SubsetConstraint(object):
Holds the limits for subsetting in each dimension.
"""

def __init__(self, limits):
def __init__(self, limits, shape=None):
"""
:param dict limits: A dictionary mapping coordinate name to slice objects
"""
self._limits = limits
self._shape = shape
logging.debug("Created SubsetConstraint of type %s", self.__class__.__name__)

def __str__(self):
Expand Down Expand Up @@ -220,3 +239,21 @@ def _create_data_for_subset(self, data):
range_start = -180
coord.set_longitude_range(range_start)
return data


def subset_region(ungridded_data, region):
from shapely.geometry import MultiPoint

cis_data = np.vstack([ungridded_data.lon.data, ungridded_data.lat.data, np.arange(len(ungridded_data.lat.data))])
points = MultiPoint(cis_data.T)

# Perform the actual calculation
selection = region.intersection(points)

# Pull out the indices
if selection.is_empty:
indices = []
else:
indices = np.asarray(selection).T[2].astype(np.int)

return ungridded_data[indices]

0 comments on commit 8c9f5cf

Please sign in to comment.