Skip to content

Commit

Permalink
Use the operator_wrapper.
Browse files Browse the repository at this point in the history
Comment and move old operator functions.
  • Loading branch information
jason-neal committed Jun 22, 2017
1 parent 1da3ee2 commit 22ac0fe
Showing 1 changed file with 121 additions and 106 deletions.
227 changes: 121 additions & 106 deletions spectrum_overload/Spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,65 +530,9 @@ def spline_interpolate_to(self, reference, w=None, bbox=[None, None], k=3,
# ######################################################

def _operation_wrapper(operation):
def __add__(self, other):
"""Overloaded addition method for Spectrum.
If there is addition between two Spectrum objects which have
difference xaxis values then the second Spectrum is interpolated
to the xaxis of the first Spectum
e.g. if len(a.xaxis) = 10 and len(b.xaxis = 15)
then if len(a + b) = 10 and len(b + a) = 15.
This makes a + b != b + a
"""
Perform an operation (addition, subtraction, mutiplication, division,
etc.) after checking for shape matching.
# Checks for type errors and size. It interpolates other if needed.
prepared_other = self._prepare_other(other)
newspec = self.copy()
newspec.flux = newspec.flux + prepared_other
return newspec

def __radd__(self, other):
"""Right addition."""
# E.g. for first Item in Sum 0 + Spectrum fails.
newspec = self.copy()
newspec.flux = newspec.flux + other
return newspec

def __sub__(self, other):
"""Overloaded subtraction method for Spectrum.
If there is subtraction between two Spectrum objects which have
difference xaxis values then the second Spectrum is interpolated
to the xaxis of the first Spectum.
e.g. if len(a.xaxis) = 10 and len(b.xaxis = 15)
then if len(a - b) = 10 and len(b - a) = 15.
# This makes a - b != -b + a
"""
# Checks for type errors and size. It interpolates other if needed.
prepared_other = self._prepare_other(other)
newspec = self.copy()
newspec.flux = newspec.flux - prepared_other
return newspec

def __mul__(self, other):
"""Overloaded multiplication method for Spectrum.
If there is multiplication between two Spectrum objects which have
difference xaxis values then the second Spectrum is interpolated
to the xaxis of the first Spectum.
e.g. if len(a.xaxis) = 10 and len(b.xaxis = 15)
then if len(a * b) = 10 and len(b * a) = 15.
This makes a * b != b * a
"""
def ofunc(self, other):
""" operation function """
Expand Down Expand Up @@ -620,37 +564,18 @@ def ofunc(self, other):
other_flux = other_copy.flux
else:
raise TypeError("Unexpected type {} for operation with Spectrum".format(type(other)))
# Checks for type errors and size. It interpolates other if needed.
prepared_other = self._prepare_other(other)
#new_flux = self.flux * prepared_other
newspec = self.copy()
newspec.flux = newspec.flux * prepared_other
return newspec

newspec.flux = operation(newspec.flux, other_flux) # Perform the operation
return newspec
def __truediv__(self, other):
"""Overloaded truedivision (/) method for Spectrum.

return ofunc
If there is truedivision between two Spectrum objects which have
difference xaxis values then the second Spectrum is interpolated
to the xaxis of the first Spectum.

e.g. if len(a.xaxis) = 10 and len(b.xaxis = 15)
then if len(a / b) = 10 and len(b / a) = 15.

This makes (a / b) != (1/b) / (1/a).
"""
# Checks for type errors and size. It interpolates other if needed.
prepared_other = self._prepare_other(other)
# Divide by zero only gives a runtime warning with numpy
newspec = self.copy()
newspec.flux = newspec.flux / prepared_other
# May want to change the inf to something else, nan, 0?...
# new_flux[new_flux == np.inf] = np.nan
return newspec
__add__ = _operation_wrapper(np.add)
__radd__ = _operation_wrapper(np.add)
__sub__ = _operation_wrapper(np.subtract)
__mul__ = _operation_wrapper(np.multiply)
__div__ = _operation_wrapper(np.divide)
__truediv__ = _operation_wrapper(np.divide)

def __pow__(self, other):
"""Exponetial magic method."""
Expand Down Expand Up @@ -696,32 +621,122 @@ def __abs__(self):
return Spectrum(flux=absflux, xaxis=self.xaxis, header=self.header,
calibrated=self.calibrated)

def _prepare_other(self, other):
if isinstance(other, Spectrum):
if self.calibrated != other.calibrated:
# Checking the Spectra are of same calibration state
raise SpectrumError("Spectra are not calibrated similarly.")
if np.all(self.xaxis == other.xaxis): # Only for equal xaxis
# Easiest condition in which xaxis of both are the same
return other.copy().flux
else: # Uneven length xaxis need to be interpolated
no_overlap_lower = (np.min(self.xaxis) > np.max(other.xaxis))
no_overlap_upper = (np.max(self.xaxis) < np.min(other.xaxis))
if no_overlap_lower | no_overlap_upper:
raise ValueError("The xaxis do not overlap so cannot"
" be interpolated")
else:
other_copy = other.copy()
# other_copy.interpolate_to(self)
other_copy.spline_interpolate_to(self)
return other_copy.flux
elif np.isscalar(other):
return other
return copy.copy(other)
else:
raise TypeError("Unexpected type {} given".format(type(other)))


class SpectrumError(Exception):
"""An errorclass for specturm errors."""
pass


# def __add__(self, other):
# """Overloaded addition method for Spectrum.
#
# If there is addition between two Spectrum objects which have
# difference xaxis values then the second Spectrum is interpolated
# to the xaxis of the first Spectum
#
# e.g. if len(a.xaxis) = 10 and len(b.xaxis = 15)
# then if len(a + b) = 10 and len(b + a) = 15.
#
# This makes a + b != b + a
#
# """
# # Checks for type errors and size. It interpolates other if needed.
# prepared_other = self._prepare_other(other)
# newspec = self.copy()
# newspec.flux = newspec.flux + prepared_other
# return newspec

# def __radd__(self, other):
# """Right addition."""
# # E.g. for first Item in Sum 0 + Spectrum fails.
# newspec = self.copy()
# newspec.flux = newspec.flux + other
# return newspec
#
# def __sub__(self, other):
# """Overloaded subtraction method for Spectrum.
#
# If there is subtraction between two Spectrum objects which have
# difference xaxis values then the second Spectrum is interpolated
# to the xaxis of the first Spectum.
#
# e.g. if len(a.xaxis) = 10 and len(b.xaxis = 15)
# then if len(a - b) = 10 and len(b - a) = 15.
#
# # This makes a - b != -b + a
#
# """
# # Checks for type errors and size. It interpolates other if needed.
# prepared_other = self._prepare_other(other)
# newspec = self.copy()
# newspec.flux = newspec.flux - prepared_other
# return newspec

# def __mul__(self, other):
# """Overloaded multiplication method for Spectrum.
#
# If there is multiplication between two Spectrum objects which have
# difference xaxis values then the second Spectrum is interpolated
# to the xaxis of the first Spectum.
#
# e.g. if len(a.xaxis) = 10 and len(b.xaxis = 15)
# then if len(a * b) = 10 and len(b * a) = 15.
#
# This makes a * b != b * a
#
# """
# # Checks for type errors and size. It interpolates other if needed.
# prepared_other = self._prepare_other(other)
# #new_flux = self.flux * prepared_other
# newspec = self.copy()
# newspec.flux = newspec.flux * prepared_other
# return newspec
#
# def __truediv__(self, other):
# """Overloaded truedivision (/) method for Spectrum.
#
# If there is truedivision between two Spectrum objects which have
# difference xaxis values then the second Spectrum is interpolated
# to the xaxis of the first Spectum.
#
# e.g. if len(a.xaxis) = 10 and len(b.xaxis = 15)
# then if len(a / b) = 10 and len(b / a) = 15.
#
# This makes (a / b) != (1/b) / (1/a).
#
# """
# # Checks for type errors and size. It interpolates other if needed.
# prepared_other = self._prepare_other(other)
# # Divide by zero only gives a runtime warning with numpy
# newspec = self.copy()
# newspec.flux = newspec.flux / prepared_other
# # May want to change the inf to something else, nan, 0?...
# # new_flux[new_flux == np.inf] = np.nan
# return newspec
#
# def _prepare_other(self, other):
# if isinstance(other, Spectrum):
# if self.calibrated != other.calibrated:
# # Checking the Spectra are of same calibration state
# raise SpectrumError("Spectra are not calibrated similarly.")
#
# if np.all(self.xaxis == other.xaxis): # Only for equal xaxis
# # Easiest condition in which xaxis of both are the same
# return other.copy().flux
# else: # Uneven length xaxis need to be interpolated
# no_overlap_lower = (np.min(self.xaxis) > np.max(other.xaxis))
# no_overlap_upper = (np.max(self.xaxis) < np.min(other.xaxis))
# if no_overlap_lower | no_overlap_upper:
# raise ValueError("The xaxis do not overlap so cannot"
# " be interpolated")
# else:
# other_copy = other.copy()
# # other_copy.interpolate_to(self)
# other_copy.spline_interpolate_to(self)
# return other_copy.flux
# elif np.isscalar(other):
# return other
# elif isinstance(other, np.ndarray):
# return copy.copy(other)
# else:
# raise TypeError("Unexpected type {} given".format(type(other)))

0 comments on commit 22ac0fe

Please sign in to comment.