From 3fa6f2a80b93e58c1b5449d360349987e01882f6 Mon Sep 17 00:00:00 2001 From: Anthony Burdi Date: Mon, 25 Apr 2022 15:34:07 -0400 Subject: [PATCH] Lint --- .../splitter_test_cases_and_fixtures.py | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/tests/integration/fixtures/split_data/splitter_test_cases_and_fixtures.py b/tests/integration/fixtures/split_data/splitter_test_cases_and_fixtures.py index a900d633973e..6a2abae9b267 100644 --- a/tests/integration/fixtures/split_data/splitter_test_cases_and_fixtures.py +++ b/tests/integration/fixtures/split_data/splitter_test_cases_and_fixtures.py @@ -8,7 +8,6 @@ class TaxiTestData: - def __init__(self, test_df: pd.DataFrame, test_column_name: str): self._test_df = test_df self._test_column_name = test_column_name @@ -21,20 +20,21 @@ def test_df(self): def test_column_name(self): return self._test_column_name - def years_in_taxi_data(self) -> List[datetime.datetime]: - return pd.date_range(start="2018-01-01", end="2020-12-31", freq="AS").to_pydatetime().tolist() + return ( + pd.date_range(start="2018-01-01", end="2020-12-31", freq="AS") + .to_pydatetime() + .tolist() + ) def year_batch_identifier_data(self) -> List[dict]: - return [ - {DatePart.YEAR.value: dt.year} for dt in self.years_in_taxi_data() - ] + return [{DatePart.YEAR.value: dt.year} for dt in self.years_in_taxi_data()] def months_in_taxi_data(self) -> List[datetime.datetime]: return ( pd.date_range(start="2018-01-01", end="2020-12-31", freq="MS") - .to_pydatetime() - .tolist() + .to_pydatetime() + .tolist() ) def year_month_batch_identifier_data(self) -> List[dict]: @@ -44,10 +44,7 @@ def year_month_batch_identifier_data(self) -> List[dict]: ] def month_batch_identifier_data(self) -> List[dict]: - return [ - {DatePart.MONTH.value: dt.month} for dt in self.months_in_taxi_data() - ] - + return [{DatePart.MONTH.value: dt.month} for dt in self.months_in_taxi_data()] def year_month_day_batch_identifier_data(self) -> List[dict]: # Since taxi data does not contain all days, @@ -77,8 +74,6 @@ def year_month_day_batch_identifier_data(self) -> List[dict]: ) - - @dataclass class TaxiSplittingTestCase: splitter_method_name: str @@ -89,7 +84,6 @@ class TaxiSplittingTestCase: class TaxiSplittingTestCases: - def __init__(self, taxi_test_data: TaxiTestData): self._taxi_test_data = taxi_test_data @@ -141,7 +135,10 @@ def test_cases(self) -> List[TaxiSplittingTestCase]: # date_parts as a string (with mixed case): TaxiSplittingTestCase( splitter_method_name="split_on_date_parts", - splitter_kwargs={"column_name": self.taxi_test_data.test_column_name, "date_parts": ["mOnTh"]}, + splitter_kwargs={ + "column_name": self.taxi_test_data.test_column_name, + "date_parts": ["mOnTh"], + }, num_expected_batch_definitions=12, num_expected_rows_in_first_batch_definition=30, expected_pickup_datetimes=self.taxi_test_data.month_batch_identifier_data(),