Skip to content

Commit

Permalink
🐛 LightCurveCollection.stitch() should ignore unmergeable columns (fixes
Browse files Browse the repository at this point in the history
 #954) (#996)

* 🐛 Fixes #954

* ✅ Add a regression test

* 📝 Update CHANGES.rst
  • Loading branch information
barentsen committed Mar 12, 2021
1 parent 08c3b19 commit a22ef3f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
- Modified ``LightCurve.bin()`` to partially restore the ``bins`` parameter which
was available in Lightkurve v1.x, to improve backwards compatibility. [#995]

- Modified ``LightCurveCollection.stitch()`` to ignore incompatible columns
instead of having them raise a ``ValueError``. [#996]



2.0.4 (2021-03-11)
==================
Expand Down
21 changes: 20 additions & 1 deletion src/lightkurve/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import MPLSTYLE
from .targetpixelfile import TargetPixelFile
from .utils import LightkurveDeprecationWarning
from .utils import LightkurveWarning, LightkurveDeprecationWarning


log = logging.getLogger(__name__)
Expand Down Expand Up @@ -227,6 +227,25 @@ def stitch(self, corrector_func=lambda x: x.normalize()):
with warnings.catch_warnings(): # ignore "already normalized" message
warnings.filterwarnings("ignore", message=".*already.*")
lcs = [corrector_func(lc) for lc in self]

# Address issue #954: ignore incompatible columns with the same name
columns_to_remove = set()
for col in lcs[0].columns:
for lc in lcs[1:]:
if col in lc.columns:
if not (issubclass(lcs[0][col].__class__, lc[col].__class__) \
or lcs[0][col].__class__.info is lc[col].__class__.info):
columns_to_remove.add(col)
continue

if len(columns_to_remove) > 0:
warnings.warn(
f"The following columns will be excluded from stitching because the column types are incompatible: {columns_to_remove}",
LightkurveWarning,
)
lcs = [lc.copy() for lc in lcs]
[lc.remove_columns(columns_to_remove.intersection(lc.columns)) for lc in lcs]

# Need `join_type='inner'` until AstroPy supports masked Quantities
return vstack(lcs, join_type="inner", metadata_conflicts="silent")

Expand Down
13 changes: 11 additions & 2 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings

import pytest
from astropy import units as u
from astropy.utils.data import get_pkg_data_filename
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -107,7 +108,7 @@ def test_collection_getitem_by_boolean_array():

lcc_f = lcc[[True, False, True]]
assert lcc_f.data == [lc0, lc2]
assert (type(lcc_f), LightCurveCollection)
assert type(lcc_f) is LightCurveCollection

# boundary case: 1 element
lcc_f = lcc[[False, True, False]]
Expand Down Expand Up @@ -215,7 +216,7 @@ def test_tpfcollection():
# ensure index by boolean array also works for TPFs
tpfc_f = tpfc[[False, True, True]]
assert tpfc_f.data == [tpf2, tpf2]
assert (type(tpfc_f), TargetPixelFileCollection)
assert type(tpfc_f) is TargetPixelFileCollection
# Test __setitem__
tpf3 = KeplerTargetPixelFile(filename_tpf_one_center, targetid=55)
tpfc[1] = tpf3
Expand Down Expand Up @@ -353,3 +354,11 @@ def test_accessor_k2_campaign():
tpf1.hdu[0].header["CAMPAIGN"] = 1
tpfc = TargetPixelFileCollection([tpf0, tpf1])
assert (tpfc.campaign == [2, 1]).all()


def test_unmergeable_columns():
"""Regression test for #954."""
lc1 = LightCurve(data={'time': [1,2,3], 'x': [1,2,3]})
lc2 = LightCurve(data={'time': [1,2,3], 'x': [1,2,3]*u.electron/u.second})
with pytest.warns(LightkurveWarning, match="column types are incompatible"):
LightCurveCollection([lc1, lc2]).stitch()

0 comments on commit a22ef3f

Please sign in to comment.