Skip to content

Commit

Permalink
improving test coverage + fixing a never raising exception in
Browse files Browse the repository at this point in the history
Trace.__setattr__
  • Loading branch information
barsch committed Aug 17, 2013
1 parent d4260a0 commit e0639cc
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 1 deletion.
94 changes: 94 additions & 0 deletions obspy/core/tests/test_trace.py
Expand Up @@ -12,6 +12,40 @@ class TraceTestCase(unittest.TestCase):
"""
Test suite for obspy.core.trace.Trace.
"""
def test_init(self):
"""
Tests the __init__ method of the Trace class.
"""
# NumPy ndarray
tr = Trace(data=np.arange(4))
self.assertEqual(len(tr), 4)
# NumPy masked array
data = np.ma.array([0, 1, 2, 3], mask=[True, True, False, False])
tr = Trace(data=data)
self.assertEqual(len(tr), 4)
# other data types will raise
self.assertRaises(ValueError, Trace, data=[0, 1, 2, 3])
self.assertRaises(ValueError, Trace, data=(0, 1, 2, 3))
self.assertRaises(ValueError, Trace, data='1234')

def test_setattr(self):
"""
Tests the __setattr__ method of the Trace class.
"""
# NumPy ndarray
tr = Trace()
tr.data = np.arange(4)
self.assertEqual(len(tr), 4)
# NumPy masked array
tr = Trace()
tr.data = np.ma.array([0, 1, 2, 3], mask=[True, True, False, False])
self.assertEqual(len(tr), 4)
# other data types will raise
tr = Trace()
self.assertRaises(ValueError, tr.__setattr__, 'data', [0, 1, 2, 3])
self.assertRaises(ValueError, tr.__setattr__, 'data', (0, 1, 2, 3))
self.assertRaises(ValueError, tr.__setattr__, 'data', '1234')

def test_len(self):
"""
Tests the __len__ and count methods of the Trace class.
Expand All @@ -20,6 +54,29 @@ def test_len(self):
self.assertEqual(len(trace), 1000)
self.assertEqual(trace.count(), 1000)

def test_mul(self):
"""
Tests the __mul__ method of the Trace class.
"""
tr = Trace(data=np.arange(10))
st = tr * 5
self.assertEqual(len(st), 5)
# you may only multiply using an integer
self.assertRaises(TypeError, tr.__mul__, 2.5)
self.assertRaises(TypeError, tr.__mul__, '1234')

def test_div(self):
"""
Tests the __div__ method of the Trace class.
"""
tr = Trace(data=np.arange(1000))
st = tr / 5
self.assertEqual(len(st), 5)
self.assertEqual(len(st[0]), 200)
# you may only multiply using an integer
self.assertRaises(TypeError, tr.__div__, 2.5)
self.assertRaises(TypeError, tr.__div__, '1234')

def test_ltrim(self):
"""
Tests the ltrim method of the Trace class.
Expand All @@ -32,6 +89,9 @@ def test_ltrim(self):
end = UTCDateTime(2000, 1, 1, 0, 0, 4, 995000)
# verify
trace.verify()
# UTCDateTime/int/float required
self.assertRaises(TypeError, trace._ltrim, '1234')
self.assertRaises(TypeError, trace._ltrim, [1, 2, 3, 4])
# ltrim 100 samples
tr = deepcopy(trace)
tr._ltrim(0.5)
Expand Down Expand Up @@ -117,6 +177,9 @@ def test_rtrim(self):
trace.stats.sampling_rate = 200.0
end = UTCDateTime(2000, 1, 1, 0, 0, 4, 995000)
trace.verify()
# UTCDateTime/int/float required
self.assertRaises(TypeError, trace._rtrim, '1234')
self.assertRaises(TypeError, trace._rtrim, [1, 2, 3, 4])
# rtrim 100 samples
tr = deepcopy(trace)
tr._rtrim(0.5)
Expand Down Expand Up @@ -252,6 +315,8 @@ def test_trim(self):
self.assertEqual(trace.stats.sampling_rate, 200.0)
self.assertEqual(trace.stats.starttime, start + 0.5)
self.assertEqual(trace.stats.endtime, end - 0.5)
# starttime should be before endtime
self.assertRaises(ValueError, trace.trim, end, start)

def test_trimAllDoesNotChangeDtype(self):
"""
Expand Down Expand Up @@ -637,6 +702,32 @@ def test_trimFloatingPointWithPadding2(self):
self.assertEqual(tr.data.ctypes.data, mem_pos)
self.assertEqual(tr.stats, org_stats)

def test_add_sanity(self):
"""
Test sanity checks in __add__ method of the Trace object.
"""
tr = Trace(data=np.arange(10))
# you may only add a Trace object
self.assertRaises(TypeError, tr.__add__, 1234)
self.assertRaises(TypeError, tr.__add__, '1234')
self.assertRaises(TypeError, tr.__add__, [1, 2, 3, 4])
# trace id
tr2 = Trace()
tr2.stats.station = 'TEST'
self.assertRaises(TypeError, tr.__add__, tr2)
# sample rate
tr2 = Trace()
tr2.stats.sampling_rate = 20
self.assertRaises(TypeError, tr.__add__, tr2)
# calibration factor
tr2 = Trace()
tr2.stats.calib = 20
self.assertRaises(TypeError, tr.__add__, tr2)
# data type
tr2 = Trace()
tr2.data = np.arange(10, dtype=np.float32)
self.assertRaises(TypeError, tr.__add__, tr2)

def test_addOverlapsDefaultMethod(self):
"""
Test __add__ method of the Trace object.
Expand Down Expand Up @@ -801,6 +892,7 @@ def test_comparisons(self):
tr5 = Trace(np.arange(5), {'station': 'X'})
tr6 = Trace(np.arange(5), {'processing':
["filter:lowpass:{'freq': 10}"]})
tr7 = Trace(np.array([1, 1, 1]))
# tests that should raise a NotImplementedError (i.e. <=, <, >=, >)
self.assertRaises(NotImplementedError, tr1.__lt__, tr1)
self.assertRaises(NotImplementedError, tr1.__le__, tr1)
Expand All @@ -818,6 +910,7 @@ def test_comparisons(self):
self.assertEqual(tr0 == tr4, False)
self.assertEqual(tr0 == tr5, False)
self.assertEqual(tr0 == tr6, False)
self.assertEqual(tr0 == tr7, False)
self.assertEqual(tr5 == tr0, False)
self.assertEqual(tr5 == tr1, False)
self.assertEqual(tr5 == tr2, False)
Expand All @@ -833,6 +926,7 @@ def test_comparisons(self):
self.assertEqual(tr0 != tr4, True)
self.assertEqual(tr0 != tr5, True)
self.assertEqual(tr0 != tr6, True)
self.assertEqual(tr0 != tr7, True)
self.assertEqual(tr5 != tr0, True)
self.assertEqual(tr5 != tr1, True)
self.assertEqual(tr5 != tr2, True)
Expand Down
2 changes: 1 addition & 1 deletion obspy/core/trace.py
Expand Up @@ -365,7 +365,7 @@ def __setattr__(self, key, value):
if key == 'data':
if not isinstance(value, np.ndarray):
msg = "Trace.data must be a NumPy array."
ValueError(msg)
raise ValueError(msg)
self.stats.npts = len(value)
return super(Trace, self).__setattr__(key, value)

Expand Down

0 comments on commit e0639cc

Please sign in to comment.