Skip to content

Commit

Permalink
Fix memory leak in reading LC/TPF (lightkurve#1390)
Browse files Browse the repository at this point in the history
* Fixed memleak for lc in lightkurve#1388

* Fixed memleak for tpf in lightkurve#1388

* add test for read HDUList

* Explicit tests for read memory leaks (LC & TPF)
- Run in memtest workflow in CI (pytest -m memtest --remote-data)

* Test tpf.from_fits_images() to ensure no unclosed file handles

* Revert lc.hdu change in PR lightkurve#1299

* Revert raising ResourceWarning as error during tests in PR lightkurve#1299
- For it to actually work (to ensure no unclosed files), "error::pytest.PytestUnraisableExceptionWarning" will also be needed
- but it'll create many false alarms.
- Explicit testing for unclosed file handles is done in specific tests instead.

* add changelog  [skip ci]
  • Loading branch information
orionlee committed Dec 5, 2023
1 parent f8e8c16 commit 68fdf03
Show file tree
Hide file tree
Showing 9 changed files with 276 additions and 146 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
2.5.0 (unreleased)
=====================

- Fixed memory leak in reading Lightcurve / TargetPixel FITS files in v2.4.2 [#1390]


2.4.2 (2023-11-03)
=====================
Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,3 @@ testpaths = [
"tests",
"src",
]
filterwarnings = [
"error::ResourceWarning",
]
207 changes: 105 additions & 102 deletions src/lightkurve/io/generic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Read a generic FITS table containing a light curve."""
import logging
import warnings
from copy import deepcopy

from astropy.io import fits
from astropy.table import Table
Expand Down Expand Up @@ -33,105 +32,109 @@ def read_generic_lightcurve(
into a generic `LightCurve` object.
"""
if isinstance(filename, fits.HDUList):
hdulist = filename # Allow HDUList to be passed
hdulist, to_close_hdul = filename, False # Allow HDUList to be passed
else:
with fits.open(filename) as hdulist:
hdulist = deepcopy(hdulist)

# Raise an exception if the requested extension is invalid
if isinstance(ext, str):
validate_method(ext, supported_methods=[hdu.name.lower() for hdu in hdulist])
with warnings.catch_warnings():
# By default, AstroPy emits noisy warnings about units commonly used
# in archived TESS data products (e.g., "e-/s" and "pixels").
# We ignore them here because they don't affect Lightkurve's features.
# Inconsistencies between TESS data products and the FITS standard
# ought to be addressed at the archive level. (See issue #1216.)
warnings.simplefilter("ignore", category=UnitsWarning)
tab = Table.read(hdulist[ext], format="fits")

# Make sure the meta data also includes header fields from extension #0
tab.meta.update(hdulist[0].header)

tab.meta = {k: v for k, v in tab.meta.items()}

for colname in tab.colnames:
# Ensure units have the correct astropy format
# Speed-up: comparing units by their string representation is 1000x
# faster than performing full-blown unit comparison
unitstr = str(tab[colname].unit)
if unitstr == "e-/s":
tab[colname].unit = "electron/s"
elif unitstr == "pixels":
tab[colname].unit = "pixel"
elif unitstr == "ppm" and repr(tab[colname].unit).startswith("Unrecognized"):
# Workaround for issue #956
tab[colname].unit = ppm
elif unitstr == "ADU":
tab[colname].unit = "adu"
elif unitstr.lower() == "unitless":
tab[colname].unit = ""
elif unitstr.lower() == "degcelcius":
# CDIPS has non-astropy units
tab[colname].unit = "deg_C"
# Rename columns to lowercase
tab.rename_column(colname, colname.lower())

# Some KEPLER files used to have a T column instead of TIME.
if time_column == "time" and "time" not in tab.columns and "t" in tab.colnames:
tab.rename_column("t", "time")
if time_column != "time":
tab.rename_column(time_column, "time")

# We *have* to remove rows with TIME=NaN because the Astropy Time
# object does not support the presence of NaNs.
# Fortunately, such rows are always bad data.
nans = np.isnan(tab["time"].data)
if np.any(nans):
log.debug("Ignoring {} rows with NaN times".format(np.sum(nans)))
tab = tab[~nans]

# Prepare a special time column
if not time_format:
if hdulist[ext].header.get("BJDREFI") == 2454833:
time_format = "bkjd"
elif hdulist[ext].header.get("BJDREFI") == 2457000:
time_format = "btjd"
else:
raise ValueError(f"Input file has unclear time format: {filename}")
time = Time(
tab["time"].data,
scale=hdulist[ext].header.get("TIMESYS", "tdb").lower(),
format=time_format,
)
tab.remove_column("time")

# For backwards compatibility with Lightkurve v1.x,
# we make sure standard columns and attributes exist.
if "flux" not in tab.columns:
tab.add_column(tab[flux_column], name="flux", index=0)
if "flux_err" not in tab.columns:
# Try falling back to `{flux_column}_err` if possible
if flux_err_column not in tab.columns:
flux_err_column = flux_column + "_err"
if flux_err_column in tab.columns:
tab.add_column(tab[flux_err_column], name="flux_err", index=1)
if "quality" not in tab.columns and quality_column in tab.columns:
tab.add_column(tab[quality_column], name="quality", index=2)
if "cadenceno" not in tab.columns and cadenceno_column in tab.columns:
tab.add_column(tab[cadenceno_column], name="cadenceno", index=3)
if "centroid_col" not in tab.columns and centroid_col_column in tab.columns:
tab.add_column(tab[centroid_col_column], name="centroid_col", index=4)
if "centroid_row" not in tab.columns and centroid_row_column in tab.columns:
tab.add_column(tab[centroid_row_column], name="centroid_row", index=5)

tab.meta["LABEL"] = hdulist[0].header.get("OBJECT")
tab.meta["MISSION"] = hdulist[0].header.get(
"MISSION", hdulist[0].header.get("TELESCOP")
)
tab.meta["RA"] = hdulist[0].header.get("RA_OBJ")
tab.meta["DEC"] = hdulist[0].header.get("DEC_OBJ")
tab.meta["FILENAME"] = filename
tab.meta["FLUX_ORIGIN"] = flux_column

return LightCurve(time=time, data=tab)
hdulist, to_close_hdul = fits.open(filename), True

try:
# Raise an exception if the requested extension is invalid
if isinstance(ext, str):
validate_method(ext, supported_methods=[hdu.name.lower() for hdu in hdulist])
with warnings.catch_warnings():
# By default, AstroPy emits noisy warnings about units commonly used
# in archived TESS data products (e.g., "e-/s" and "pixels").
# We ignore them here because they don't affect Lightkurve's features.
# Inconsistencies between TESS data products and the FITS standard
# ought to be addressed at the archive level. (See issue #1216.)
warnings.simplefilter("ignore", category=UnitsWarning)
tab = Table.read(hdulist[ext], format="fits")

# Make sure the meta data also includes header fields from extension #0
tab.meta.update(hdulist[0].header)

tab.meta = {k: v for k, v in tab.meta.items()}

for colname in tab.colnames:
# Ensure units have the correct astropy format
# Speed-up: comparing units by their string representation is 1000x
# faster than performing full-blown unit comparison
unitstr = str(tab[colname].unit)
if unitstr == "e-/s":
tab[colname].unit = "electron/s"
elif unitstr == "pixels":
tab[colname].unit = "pixel"
elif unitstr == "ppm" and repr(tab[colname].unit).startswith("Unrecognized"):
# Workaround for issue #956
tab[colname].unit = ppm
elif unitstr == "ADU":
tab[colname].unit = "adu"
elif unitstr.lower() == "unitless":
tab[colname].unit = ""
elif unitstr.lower() == "degcelcius":
# CDIPS has non-astropy units
tab[colname].unit = "deg_C"
# Rename columns to lowercase
tab.rename_column(colname, colname.lower())

# Some KEPLER files used to have a T column instead of TIME.
if time_column == "time" and "time" not in tab.columns and "t" in tab.colnames:
tab.rename_column("t", "time")
if time_column != "time":
tab.rename_column(time_column, "time")

# We *have* to remove rows with TIME=NaN because the Astropy Time
# object does not support the presence of NaNs.
# Fortunately, such rows are always bad data.
nans = np.isnan(tab["time"].data)
if np.any(nans):
log.debug("Ignoring {} rows with NaN times".format(np.sum(nans)))
tab = tab[~nans]

# Prepare a special time column
if not time_format:
if hdulist[ext].header.get("BJDREFI") == 2454833:
time_format = "bkjd"
elif hdulist[ext].header.get("BJDREFI") == 2457000:
time_format = "btjd"
else:
raise ValueError(f"Input file has unclear time format: {filename}")
time = Time(
tab["time"].data,
scale=hdulist[ext].header.get("TIMESYS", "tdb").lower(),
format=time_format,
)
tab.remove_column("time")

# For backwards compatibility with Lightkurve v1.x,
# we make sure standard columns and attributes exist.
if "flux" not in tab.columns:
tab.add_column(tab[flux_column], name="flux", index=0)
if "flux_err" not in tab.columns:
# Try falling back to `{flux_column}_err` if possible
if flux_err_column not in tab.columns:
flux_err_column = flux_column + "_err"
if flux_err_column in tab.columns:
tab.add_column(tab[flux_err_column], name="flux_err", index=1)
if "quality" not in tab.columns and quality_column in tab.columns:
tab.add_column(tab[quality_column], name="quality", index=2)
if "cadenceno" not in tab.columns and cadenceno_column in tab.columns:
tab.add_column(tab[cadenceno_column], name="cadenceno", index=3)
if "centroid_col" not in tab.columns and centroid_col_column in tab.columns:
tab.add_column(tab[centroid_col_column], name="centroid_col", index=4)
if "centroid_row" not in tab.columns and centroid_row_column in tab.columns:
tab.add_column(tab[centroid_row_column], name="centroid_row", index=5)

tab.meta["LABEL"] = hdulist[0].header.get("OBJECT")
tab.meta["MISSION"] = hdulist[0].header.get(
"MISSION", hdulist[0].header.get("TELESCOP")
)
tab.meta["RA"] = hdulist[0].header.get("RA_OBJ")
tab.meta["DEC"] = hdulist[0].header.get("DEC_OBJ")
tab.meta["FILENAME"] = filename
tab.meta["FLUX_ORIGIN"] = flux_column

return LightCurve(time=time, data=tab)
finally:
if to_close_hdul:
# avoid hdulist closing from emitting exceptions
hdulist.close(output_verify="warn")
4 changes: 1 addition & 3 deletions src/lightkurve/lightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,7 @@ def flux_quantity(self):
warning_type=LightkurveDeprecationWarning,
)
def hdu(self):
with fits.open(self.filename) as hdulist:
hdulist = hdulist.copy()
return hdulist
return fits.open(self.filename)

@property
@deprecated("2.0", warning_type=LightkurveDeprecationWarning)
Expand Down
9 changes: 4 additions & 5 deletions src/lightkurve/targetpixelfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ def __init__(self, path, quality_bitmask="default", targetid=None, **kwargs):
if isinstance(path, fits.HDUList):
self.hdu = path
else:
with fits.open(self.path, **kwargs) as hdulist:
self.hdu = deepcopy(hdulist)
self.hdu = fits.open(self.path, **kwargs)
self.quality_bitmask = quality_bitmask
self.targetid = targetid

Expand Down Expand Up @@ -660,10 +659,10 @@ def _parse_aperture_mask(self, aperture_mask):
# Kepler and TESS pipeline style integer flags
aperture_mask = (aperture_mask & 2) == 2
else:
aperture_mask = aperture_mask.astype(bool)
aperture_mask = aperture_mask.astype(bool)
elif np.issubdtype(aperture_mask.dtype, float):
aperture_mask = aperture_mask.astype(bool)
self._last_aperture_mask = aperture_mask
aperture_mask = aperture_mask.astype(bool)
self._last_aperture_mask = aperture_mask
return aperture_mask

def create_threshold_mask(self, threshold=3, reference_pixel="center"):
Expand Down
19 changes: 19 additions & 0 deletions tests/data/test-lc-tess-pimen-100-cadences.fits

Large diffs are not rendered by default.

96 changes: 95 additions & 1 deletion tests/io/test_read.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import gc
import os
import warnings
import tempfile
import tracemalloc

import pytest

from astropy.io import fits

from lightkurve.utils import LightkurveDeprecationWarning, LightkurveError
from lightkurve import (
PACKAGEDIR,
Expand All @@ -12,11 +16,41 @@
LightCurve,
)
from lightkurve.io import read
from lightkurve.io.generic import read_generic_lightcurve

from .. import TESTDATA
from ..test_lightcurve import TABBY_Q8
from ..test_targetpixelfile import TABBY_TPF

#
# Tests marked with pytest error::ResourceWarning
# and error::pytest.PytestUnraisableExceptionWarning
# ensure that all internal file handles are closed in read operations.
# (A ResourceWarning emitted for an unclosed file handle
# is wrapped by pytest as PytestUnraisableExceptionWarning.)
#

@pytest.mark.filterwarnings("error::ResourceWarning")
@pytest.mark.filterwarnings("error::pytest.PytestUnraisableExceptionWarning")
def test_read_lc():
    """Reading a light curve FITS file yields a `LightCurve`.

    The warning filters on this test turn ResourceWarning and
    PytestUnraisableExceptionWarning into errors for the duration of the test.
    """
    lc_path = os.path.join(TESTDATA, "test-lc-tess-pimen-100-cadences.fits")
    light_curve = read(lc_path)
    assert isinstance(light_curve, LightCurve)


@pytest.mark.filterwarnings("error::ResourceWarning")
@pytest.mark.filterwarnings("error::pytest.PytestUnraisableExceptionWarning")
def test_read_lc_in_hdu():
    """A light curve read from a caller-supplied HDUList stays usable after close.

    The HDUList is opened by the test itself and handed to
    `read_generic_lightcurve`; after the test closes it, the resulting
    light curve must still expose its flux values.
    """
    filename_lc = os.path.join(TESTDATA, "test-lc-tess-pimen-100-cadences.fits")
    # Use a context manager so the HDUList is closed even when the read raises;
    # otherwise the leaked handle would surface as an unrelated ResourceWarning
    # error and obscure the real failure.
    with fits.open(filename_lc) as hdul:
        # lk.read() does not support hdul as input
        lc = read_generic_lightcurve(hdul, flux_column="pdcsap_flux", time_format="btjd")
    # The HDUList is closed at this point.
    assert len(lc.flux) > 0, "LC should be functional even the hdul is closed."

def test_read():

# tpf.hdu keeps an open file handle, so TPF reads are not tested for unclosed file handles
def test_read_tpf():
# define paths to k2 and tess data
k2_path = os.path.join(TESTDATA, "test-tpf-star.fits")
tess_path = os.path.join(TESTDATA, "tess25155310-s01-first-cadences.fits.gz")
Expand Down Expand Up @@ -72,6 +106,8 @@ def test_filenotfound():
assert filename in str(excinfo.value)


@pytest.mark.filterwarnings("error::ResourceWarning")
@pytest.mark.filterwarnings("error::pytest.PytestUnraisableExceptionWarning")
@pytest.mark.filterwarnings("ignore:.*been truncated.*") # ignore AstropyUserWarning: File may have been truncated
def test_file_corrupted():
"""Regression test for #1184; ensure lk.read() yields an error that includes the filename."""
Expand Down Expand Up @@ -110,3 +146,61 @@ def test_basic_ascii_io():
finally:
tabfile.close()
os.remove(tabfile.name)


@pytest.mark.memtest
@pytest.mark.remote_data
@pytest.mark.parametrize("fits_path, iterations_warmup, run_iterations", [
    (TABBY_Q8, 40, 60),
    (TABBY_TPF, 40, 60),
])
def test_read_memory_usage(fits_path, iterations_warmup, run_iterations):
    """Check that repeatedly reading an LC/TPF does not leak memory (regression test for #1388).

    Real data files are read (rather than trimmed-down test data) to better
    simulate real-life usage. The file is read `iterations_warmup` times,
    then `run_iterations` more times; traced memory after the second phase
    must stay below twice the traced memory after the first.
    """
    # History of tracemalloc readings, only used for error reporting.
    mem_history, peak_history = [], []

    def read_once():
        # Read inside a function so the object goes out of scope (and can be
        # freed) as soon as the read is done, simulating the typical scenario.
        loaded = read(fits_path)
        return len(loaded)

    def measure_phase(num_reads, label):
        # Read `num_reads` times, logging traced memory after each read, then
        # return the (current, peak) reading taken after a GC pass.
        for _ in range(num_reads):
            read_once()
            mem_now, peak_now = tracemalloc.get_traced_memory()
            mem_history.append(mem_now)
            peak_history.append(peak_now)
        gc.collect()  # run GC so that the numbers are more consistent across runs
        mem_now, peak_now = tracemalloc.get_traced_memory()
        mem_history.append(mem_now)
        peak_history.append(f"{peak_now} ({label}, after GC)")
        return mem_now, peak_now

    tracemalloc.start()
    try:
        post_warmup_mem, post_warmup_peak = measure_phase(iterations_warmup, "post-warmup")
        post_run_mem, post_run_peak = measure_phase(run_iterations, "post-run")

        # If the test fails, print the detailed memory-usage history for diagnosis.
        history_lines = [str((c, p)) for c, p in zip(mem_history, peak_history)]
        assert_err_msg = (
            "Memory usage should not keep increasing. "
            f"After warmup ({iterations_warmup}): mem: {post_warmup_mem}, peak: {post_warmup_peak} . "
            f"After run ({run_iterations} more): mem: {post_run_mem}, peak: {post_run_peak} . "
            f"History of (current, peak):\n "
        ) + "\n".join(history_lines)

        # Leave plenty of buffer (2X of post-warmup memory) for leak detection:
        # with a real leak, the final memory usage is likely to be much higher.
        assert post_run_mem < 2 * post_warmup_mem, assert_err_msg
    finally:
        tracemalloc.stop()

0 comments on commit 68fdf03

Please sign in to comment.