diff --git a/CHANGELOG.md b/CHANGELOG.md index cf35484..02907f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,10 @@ Handle all warnings as errors in testing. Fix test parameters to avoid invalid values. Drop support for python 2.7 due xaxis and flux keywords required with "*,". -Drop testing of python 3.4 also while at it. +Drop testing of python 3.4. + +- Add indexing/slicing spectrum with [], (Returns new spectrum) + 0.2.1 14/01/2018 diff --git a/spectrum_overload/spectrum.py b/spectrum_overload/spectrum.py index 460583f..75229fd 100644 --- a/spectrum_overload/spectrum.py +++ b/spectrum_overload/spectrum.py @@ -561,11 +561,9 @@ def spline_interpolate_to(self, reference: Union[ndarray, str, 'Spectrum', List[ raise TypeError("Cannot interpolate with the given object of type" " {}".format(type(reference))) - def remove_nans(self) -> 'Spectrum': - s = self.copy() - s.flux = s.flux[~np.isnan(self.flux)] - s.xaxis = s.xaxis[~np.isnan(self.flux)] - return s + def remove_nans(self) -> "Spectrum": + """Returns new spectrum. Uses slicing with isnan mask.""" + return self[~np.isnan(self.flux)] def continuum(self, method: str = "scalar", degree: Optional[int] = None, **kwargs) -> 'Spectrum': """Fit the continuum of the spectrum. @@ -771,6 +769,17 @@ def xmax(self): def xlimits(self): return [self.xmin(), self.xmax()] + def __getitem__(self, item): + """Be able slice the spectrum. Return new object.""" + if isinstance(item, (type(None), str, int, float, bool)): + raise ValueError( + "Cannot slice with types of type(None),str,int,float,bool." + ) + s = self.copy() + s.flux = self.flux[item] + s.xaxis = self.xaxis[item] + return s + class SpectrumError(Exception): """An error class for spectrum errors.""" diff --git a/spectrum_overload/test/test_Spectrum.py b/spectrum_overload/test/test_Spectrum.py index b96313c..d38f128 100644 --- a/spectrum_overload/test/test_Spectrum.py +++ b/spectrum_overload/test/test_Spectrum.py @@ -16,6 +16,7 @@ # Test using hypothesis from hypothesis import example, given from pkg_resources import resource_filename + from spectrum_overload import Spectrum, SpectrumError @@ -500,3 +501,49 @@ def test_instrument_broaden(phoenix_spectrum, R): assert not np.allclose(new_spec.flux, spec.flux) # Spectrum result equals correct pyasl value. assert np.allclose(new_spec.flux, new_flux) + + +@pytest.mark.parametrize( + "item", + [ + [5], + [-60], + [1, 2, 5, 6, 7], + [-1, 2, 5, 20, -6], + slice(10, 100), + slice(0, 8), + slice(None), + ], +) +def test_spectrum_slicing(phoenix_spectrum, item): + sliced_spectrum = phoenix_spectrum[item] + + assert np.all(sliced_spectrum.xaxis == phoenix_spectrum.xaxis[item]) + assert np.all(sliced_spectrum.flux == phoenix_spectrum.flux[item]) + assert sliced_spectrum.header == phoenix_spectrum.header + assert sliced_spectrum.calibrated == phoenix_spectrum.calibrated + + +def test_spectrum_slicing_with_colon(phoenix_spectrum): + sliced_spectrum = phoenix_spectrum[:] + assert np.all(sliced_spectrum.xaxis == phoenix_spectrum.xaxis[:]) + assert np.all(sliced_spectrum.flux == phoenix_spectrum.flux[:]) + + sliced_spectrum2 = phoenix_spectrum[:50] + assert np.all(sliced_spectrum2.xaxis == phoenix_spectrum.xaxis[:50]) + assert np.all(sliced_spectrum2.flux == phoenix_spectrum.flux[:50]) + + sliced_spectrum3 = phoenix_spectrum[-200:] + assert np.all(sliced_spectrum3.xaxis == phoenix_spectrum.xaxis[-200:]) + assert np.all(sliced_spectrum3.flux == phoenix_spectrum.flux[-200:]) + + sliced_spectrum4 = phoenix_spectrum[50:2:80] + assert np.all(sliced_spectrum4.xaxis == phoenix_spectrum.xaxis[50:2:80]) + assert np.all(sliced_spectrum4.flux == phoenix_spectrum.flux[50:2:80]) + + +@pytest.mark.parametrize("item", [None, "hello", "", 7, 3.14, True, False, 0, 1.0]) +def test_specturm_slicing_invalid_types(phoenix_spectrum, item): + """Invalid scalars and other types.""" + with pytest.raises(ValueError): + phoenix_spectrum[item]