Skip to content

Commit

Permalink
Merge pull request #1518 from claudiodsf/arrivals_ops
Browse files Browse the repository at this point in the history
List operations on obspy.taup.tau.Arrivals should return an Arrivals object
  • Loading branch information
krischer committed Sep 9, 2016
2 parents a715b39 + db12a68 commit eba4c46
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 1 deletion.
74 changes: 73 additions & 1 deletion obspy/taup/tau.py
Expand Up @@ -15,6 +15,7 @@
import matplotlib.text
import numpy as np

from .helper_classes import Arrival
from .tau_model import TauModel
from .taup_create import TauPCreate
from .taup_path import TauPPath
Expand Down Expand Up @@ -68,7 +69,7 @@ def draw(self, renderer, *args, **kwargs):

class Arrivals(list):
"""
List of arrivals returned by :class:`TauPyModel` methods.
List like object of arrivals returned by :class:`TauPyModel` methods.
:param arrivals: Initial arrivals to store.
:type arrivals: :class:`list` of
Expand All @@ -83,6 +84,66 @@ def __init__(self, arrivals, model):
self.model = model
self.extend(arrivals)

def __add__(self, other):
if isinstance(other, Arrival):
other = Arrivals([other], model=self.model)
if not isinstance(other, Arrivals):
raise TypeError
return self.__class__(super(Arrivals, self).__add__(other),
model=self.model)

def __iadd__(self, other):
if isinstance(other, Arrival):
other = Arrivals([other], model=self.model)
if not isinstance(other, Arrivals):
raise TypeError
self.extend(other)
return self

def __mul__(self, num):
if not isinstance(num, int):
raise TypeError("Integer expected")
arr = self.copy()
for _i in range(num-1):
arr += self.copy()
return arr

def __imul__(self, num):
if not isinstance(num, int):
raise TypeError("Integer expected")
arr = self.copy()
for _i in range(num-1):
self += arr
return self

def __setitem__(self, index, arrival):
if (isinstance(index, slice) and
all(isinstance(x, Arrival) for x in arrival)):
super(Arrivals, self).__setitem__(index, arrival)
elif isinstance(arrival, Arrival):
super(Arrivals, self).__setitem__(index, arrival)
else:
msg = 'Only Arrival objects can be assigned.'
raise TypeError(msg)

def __setslice__(self, i, j, seq):
if all(isinstance(x, Arrival) for x in seq):
super(Arrivals, self).__setslice__(i, j, seq)
else:
msg = 'Only Arrival objects can be assigned.'
raise TypeError(msg)

def __getitem__(self, index):
if isinstance(index, slice):
return self.__class__(super(Arrivals, self).__getitem__(index),
model=self.model)
else:
return super(Arrivals, self).__getitem__(index)

def __getslice__(self, i, j):
return self.__class__(super(Arrivals, self).__getslice__(i, j),
model=self.model)

def __str__(self):
return (
"{count} arrivals\n\t{arrivals}"
Expand All @@ -93,6 +154,17 @@ def __str__(self):
def __repr__(self):
return "[%s]" % (", ".join([repr(_i) for _i in self]))

def append(self, arrival):
if isinstance(arrival, Arrival):
super(Arrivals, self).append(arrival)
else:
msg = 'Append only supports a single Arrival object as argument.'
raise TypeError(msg)

def copy(self):
return self.__class__(super(Arrivals, self).copy(),
model=self.model)

def plot(self, plot_type="spherical", plot_all=True, legend=True,
label_arrivals=False, ax=None, show=True):
"""
Expand Down
50 changes: 50 additions & 0 deletions obspy/taup/tests/test_tau.py
Expand Up @@ -18,6 +18,7 @@

from obspy.taup import TauPyModel
from obspy.taup.taup_geo import calc_dist
from obspy.taup.tau import Arrivals
import obspy.geodetics.base as geodetics


Expand Down Expand Up @@ -952,6 +953,55 @@ def test_paths_for_crustal_phases(self):
np.testing.assert_allclose([_i[1] for _i in pn_path],
paths[1].path["depth"])

def test_arrivals_class(self):
"""
Tests list operations on the Arrivals class.
See #1518.
"""
model = TauPyModel(model='iasp91')
arrivals = model.get_ray_paths(source_depth_in_km=0,
distance_in_degree=1,
phase_list=['Pn', 'PmP'])
self.assertEqual(len(arrivals), 2)
# test copy
self.assertTrue(isinstance(arrivals.copy(), Arrivals))
# test sum
self.assertTrue(isinstance(arrivals+arrivals, Arrivals))
self.assertTrue(isinstance(arrivals+arrivals[0], Arrivals))
# test multiplying
self.assertTrue(isinstance(arrivals*2, Arrivals))
arrivals *= 3
self.assertEqual(len(arrivals), 6)
self.assertTrue(isinstance(arrivals, Arrivals))
# test slicing
self.assertTrue(isinstance(arrivals[2:5], Arrivals))
# test appending
arrivals.append(arrivals[0])
self.assertEqual(len(arrivals), 7)
self.assertTrue(isinstance(arrivals, Arrivals))
# test assignment
arrivals[0] = arrivals[-1]
self.assertTrue(isinstance(arrivals, Arrivals))
arrivals[2:5] = arrivals[1:4]
self.assertTrue(isinstance(arrivals, Arrivals))
# test assignment with wrong type
with self.assertRaises(TypeError):
arrivals[0] = 10.
with self.assertRaises(TypeError):
arrivals[2:5] = [0, 1, 2]
with self.assertRaises(TypeError):
arrivals.append(arrivals)
# test add and mul with wrong type
with self.assertRaises(TypeError):
arrivals + [2, ]
with self.assertRaises(TypeError):
arrivals += [2, ]
with self.assertRaises(TypeError):
arrivals * [2, ]
with self.assertRaises(TypeError):
arrivals *= [2, ]


def suite():
return unittest.makeSuite(TauPyModelTestCase, 'test')
Expand Down

0 comments on commit eba4c46

Please sign in to comment.