Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory usage of MapEvaluator #4989

Merged
merged 5 commits into from Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
56 changes: 31 additions & 25 deletions gammapy/datasets/evaluator.py
Expand Up @@ -69,7 +69,6 @@
self.mask = mask
self.gti = gti
self.use_cache = use_cache
self._init_position = None
self.contributes = True
self.psf_containment = None

Expand All @@ -93,9 +92,8 @@
self._cached_position = (0, 0)
self._computation_cache = None
self._spatial_oversampling_factor = 1
if self.exposure is not None:
if not self.geom.is_region or self.geom.region is not None:
self.update_spatial_oversampling_factor(self.geom)
if exposure is not None:
self.update_spatial_oversampling_factor(self.geom)

def _repr_html_(self):
try:
Expand All @@ -111,7 +109,10 @@
@property
def geom(self):
"""True energy map geometry (`~gammapy.maps.Geom`)."""
return self.exposure.geom
if self.exposure is not None:
return self.exposure.geom
else:
return None

Check warning on line 115 in gammapy/datasets/evaluator.py

View check run for this annotation

Codecov / codecov/patch

gammapy/datasets/evaluator.py#L115

Added line #L115 was not covered by tests

@property
def _geom_reco(self):
Expand Down Expand Up @@ -159,7 +160,12 @@
is_circle_region = isinstance(geom.region, CircleSkyRegion)
return is_point_model & is_circle_region

@property
@lazyproperty
def position(self):
"""Latest evaluation position."""
return self.model.position

@lazyproperty
def cutout_width(self):
"""Cutout width for the model component."""
return self.psf_width + 2 * (self.model.evaluation_radius + CUTOUT_MARGIN)
Expand All @@ -183,11 +189,14 @@
# TODO: simplify and clean up
log.debug("Updating model evaluator")

del self.position
del self.cutout_width

# lookup edisp
if edisp:
energy_axis = geom.axes["energy"]
self.edisp = edisp.get_edisp_kernel(
position=self.model.position, energy_axis=energy_axis
position=self.position, energy_axis=energy_axis
)
del self._edisp_diagonal

Expand All @@ -208,25 +217,20 @@
geom_psf = geom_psf.to_wcs_geom()

self.psf = psf.get_psf_kernel(
position=self.model.position,
position=self.position,
geom=geom_psf,
containment=PSF_CONTAINMENT,
max_radius=PSF_MAX_RADIUS,
)

self.exposure = exposure
if self.evaluation_mode == "local":
self.contributes = self.model.contributes(mask=mask, margin=self.psf_width)

if self.contributes:
self.exposure = exposure.cutout(
position=self.model.position, width=self.cutout_width, odd_npix=True
if self.contributes and not self.geom.is_region:
self.exposure = exposure._cutout_view(
position=self.position, width=self.cutout_width, odd_npix=True
)
else:
self.exposure = exposure

if self.contributes:
if not self.geom.is_region or self.geom.region is not None:
self.update_spatial_oversampling_factor(self.geom)
self.update_spatial_oversampling_factor(self.geom)

self.reset_cache_properties()
self._computation_cache = None
Expand All @@ -241,15 +245,17 @@

def update_spatial_oversampling_factor(self, geom):
"""Update spatial oversampling_factor for model evaluation."""
res_scale = self.model.evaluation_bin_size_min

res_scale = res_scale.to_value("deg") if res_scale is not None else 0
if self.contributes and (not geom.is_region or geom.region is not None):
res_scale = self.model.evaluation_bin_size_min

res_scale = res_scale.to_value("deg") if res_scale is not None else 0

if res_scale != 0:
if geom.is_region or geom.is_hpx:
geom = geom.to_wcs_geom()
factor = int(np.ceil(np.max(geom.pixel_scales.deg) / res_scale))
self._spatial_oversampling_factor = factor
if res_scale != 0:
if geom.is_region or geom.is_hpx:
geom = geom.to_wcs_geom()
factor = int(np.ceil(np.max(geom.pixel_scales.deg) / res_scale))
self._spatial_oversampling_factor = factor

def compute_dnde(self):
"""Compute model differential flux at map pixel centers.
Expand Down
32 changes: 32 additions & 0 deletions gammapy/maps/wcs/ndmap.py
Expand Up @@ -994,6 +994,38 @@ def cutout(self, position, width, mode="trim", odd_npix=False):

return self._init_copy(geom=geom_cutout, data=data)

def _cutout_view(self, position, width, odd_npix=False):
"""
Create a cutout around a given position without copy of the data.

Parameters
----------
position : `~astropy.coordinates.SkyCoord`
Center position of the cutout region.
width : tuple of `~astropy.coordinates.Angle`
Angular sizes of the region in (lon, lat) in that specific order.
If only one value is passed, a square region is extracted.
odd_npix : bool, optional
Force width to odd number of pixels.
Default is False.

Returns
-------
cutout : `~gammapy.maps.WcsNDMap`
Cutout map.
"""
geom_cutout = self.geom.cutout(
position=position, width=width, mode="trim", odd_npix=odd_npix
)
cutout_info = geom_cutout.cutout_slices(self.geom, mode="trim")

slices = cutout_info["parent-slices"]
parent_slices = Ellipsis, slices[0], slices[1]

return self.__class__.from_geom(
geom=geom_cutout, data=self.quantity[parent_slices]
)

def stack(self, other, weights=None, nan_to_num=True):
"""Stack cutout into map.

Expand Down