diff --git a/tests/fixtures/monthly_subquota_limit_classifier.csv b/tests/fixtures/monthly_subquota_limit_classifier.csv new file mode 100644 index 0000000..4d4b917 --- /dev/null +++ b/tests/fixtures/monthly_subquota_limit_classifier.csv @@ -0,0 +1,10 @@ +applicant_id,document_id,issue_date,year,month,subquota_number,total_net_value +1,100,2015-03-01,2015,3,120,10500 +1,100,2015-03-01,2015,3,120,401 +2,100,2015-04-01,2015,4,120,10500 +2,100,2015-04-01,2015,4,120,399 +3,100,2015-05-01,2015,5,120,10500 +3,100,2015-05-01,2015,5,120,400 +4,100,2015-06-01,2015,6,120,10500 +4,100,2015-06-01,2015,6,120,401 +5,100,2015-06-01,2015,6,120,10500 diff --git a/tests/test_monthly_subquota_limit_classifier.py b/tests/test_monthly_subquota_limit_classifier.py new file mode 100644 index 0000000..124281d --- /dev/null +++ b/tests/test_monthly_subquota_limit_classifier.py @@ -0,0 +1,19 @@ +from unittest import TestCase + +import numpy as np +import pandas as pd +from rosie.monthly_subquota_limit_classifier import MonthlySubquotaLimitClassifier + + +class TestMonthlySubquotaLimitClassifier(TestCase): + + def setUp(self): + self.dataset = pd.read_csv('tests/fixtures/monthly_subquota_limit_classifier.csv', + dtype={'subquota_number': np.str}) + self.subject = MonthlySubquotaLimitClassifier() + self.subject.fit_transform(self.dataset) + self.prediction = self.subject.predict(self.dataset) + + def test_predict_false_when_not_in_date_range(self): + self.assertEqual(False, self.prediction[0]) + self.assertEqual(False, self.prediction[1])