Skip to content

Commit

Permalink
add utils.record_filter edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
jGaboardi committed Dec 5, 2020
1 parent 016b775 commit ed038e7
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion tigernet/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@

from .. import utils

import numpy
import operator
import pandas
import unittest
from shapely.geometry import MultiLineString


class TestUtils(unittest.TestCase):
class TestUtilsWeldingFuncs(unittest.TestCase):
def setUp(self):
pass

Expand All @@ -32,5 +35,36 @@ def test__weld_MultiLineString_3(self):
self.assertEqual(observed_weld_wkt, known_weld_wkt)


class TestUtilsFilterFuncs(unittest.TestCase):
def setUp(self):
self.df = pandas.DataFrame().from_dict(
{"v1": ["x"] * 2 + ["y"] * 2 + ["z"] * 2, "v2": ["a"] * 3 + ["b"] * 3}
)

def test_record_filter_mval_in(self):
known_values = numpy.array([["x", "a"], ["x", "a"], ["y", "a"], ["y", "b"]])
kws = {"column": "v1", "mval": ["x", "y"], "oper": "in"}
observed_values = utils.record_filter(self.df.copy(), **kws).values
numpy.testing.assert_array_equal(observed_values, known_values)

def test_record_filter_mval_out(self):
known_values = numpy.array([["z", "b"], ["z", "b"]])
kws = {"column": "v1", "mval": ["x", "y"], "oper": "out"}
observed_values = utils.record_filter(self.df.copy(), **kws).values
numpy.testing.assert_array_equal(observed_values, known_values)

def test_record_filter_mval_index(self):
known_values = numpy.array([["y", "b"], ["z", "b"], ["z", "b"]])
kws = {"column": "index", "mval": [0, 1, 2], "oper": "out"}
observed_values = utils.record_filter(self.df.copy(), **kws).values
numpy.testing.assert_array_equal(observed_values, known_values)

def test_record_filter_sval(self):
known_values = numpy.array([["x", "a"], ["x", "a"], ["y", "a"]])
kws = {"column": "v2", "sval": "a", "oper": operator.eq}
observed_values = utils.record_filter(self.df.copy(), **kws).values
numpy.testing.assert_array_equal(observed_values, known_values)


if __name__ == "__main__":
unittest.main()

0 comments on commit ed038e7

Please sign in to comment.