diff --git a/tests/integration/db/test_sql_data_splitting.py b/tests/integration/db/test_sql_data_splitting.py
index a098ad66a6fc..4f7a67f5d396 100644
--- a/tests/integration/db/test_sql_data_splitting.py
+++ b/tests/integration/db/test_sql_data_splitting.py
@@ -1,4 +1,3 @@
-from dataclasses import dataclass
 from typing import List, Tuple
 
 import pandas as pd
@@ -11,10 +10,14 @@
 from great_expectations.core.batch_spec import SqlAlchemyDatasourceBatchSpec
 from great_expectations.core.yaml_handler import YAMLHandler
 from great_expectations.datasource.data_connector import ConfiguredAssetSqlDataConnector
-from great_expectations.execution_engine.split_and_sample.data_splitter import DatePart
 from great_expectations.execution_engine.sqlalchemy_batch_data import (
     SqlAlchemyBatchData,
 )
+from tests.integration.fixtures.split_data.splitter_test_cases_and_fixtures import (
+    TaxiSplittingTestCase,
+    TaxiSplittingTestCases,
+    TaxiTestData,
+)
 from tests.test_utils import (
     LoadedTable,
     clean_up_tables_with_prefix,
@@ -77,114 +80,14 @@ def _load_data(
 test_df: pd.DataFrame = loaded_table.inserted_dataframe
 table_name: str = loaded_table.table_name
 
-YEARS_IN_TAXI_DATA = (
-    pd.date_range(start="2018-01-01", end="2020-12-31", freq="AS")
-    .to_pydatetime()
-    .tolist()
-)
-YEAR_BATCH_IDENTIFIER_DATA: List[dict] = [
-    {DatePart.YEAR.value: dt.year} for dt in YEARS_IN_TAXI_DATA
-]
-
-MONTHS_IN_TAXI_DATA = (
-    pd.date_range(start="2018-01-01", end="2020-12-31", freq="MS")
-    .to_pydatetime()
-    .tolist()
-)
-YEAR_MONTH_BATCH_IDENTIFIER_DATA: List[dict] = [
-    {DatePart.YEAR.value: dt.year, DatePart.MONTH.value: dt.month}
-    for dt in MONTHS_IN_TAXI_DATA
-]
-MONTH_BATCH_IDENTIFIER_DATA: List[dict] = [
-    {DatePart.MONTH.value: dt.month} for dt in MONTHS_IN_TAXI_DATA
-]
-
-TEST_COLUMN: str = "pickup_datetime"
-
-# Since taxi data does not contain all days, we need to introspect the data to build the fixture:
-YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA: List[dict] = list(
-    {val[0]: val[1], val[2]: val[3], val[4]: val[5]}
-    for val in {
-        (
-            DatePart.YEAR.value,
-            dt.year,
-            DatePart.MONTH.value,
-            dt.month,
-            DatePart.DAY.value,
-            dt.day,
-        )
-        for dt in test_df[TEST_COLUMN]
-    }
+taxi_test_data: TaxiTestData = TaxiTestData(
+    test_df, test_column_name="pickup_datetime"
 )
-YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA: List[dict] = sorted(
-    YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA,
-    key=lambda x: (
-        x[DatePart.YEAR.value],
-        x[DatePart.MONTH.value],
-        x[DatePart.DAY.value],
-    ),
+taxi_splitting_test_cases: TaxiSplittingTestCases = TaxiSplittingTestCases(
+    taxi_test_data
 )
 
-@dataclass
-class SqlSplittingTestCase:
-    splitter_method_name: str
-    splitter_kwargs: dict
-    num_expected_batch_definitions: int
-    num_expected_rows_in_first_batch_definition: int
-    expected_pickup_datetimes: List[dict]
-
-test_cases: List[SqlSplittingTestCase] = [
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_year",
-        splitter_kwargs={"column_name": TEST_COLUMN},
-        num_expected_batch_definitions=3,
-        num_expected_rows_in_first_batch_definition=120,
-        expected_pickup_datetimes=YEAR_BATCH_IDENTIFIER_DATA,
-    ),
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_year_and_month",
-        splitter_kwargs={"column_name": TEST_COLUMN},
-        num_expected_batch_definitions=36,
-        num_expected_rows_in_first_batch_definition=10,
-        expected_pickup_datetimes=YEAR_MONTH_BATCH_IDENTIFIER_DATA,
-    ),
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_year_and_month_and_day",
-        splitter_kwargs={"column_name": TEST_COLUMN},
-        num_expected_batch_definitions=299,
-        num_expected_rows_in_first_batch_definition=2,
-        expected_pickup_datetimes=YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA,
-    ),
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_date_parts",
-        splitter_kwargs={
-            "column_name": TEST_COLUMN,
-            "date_parts": [DatePart.MONTH],
-        },
-        num_expected_batch_definitions=12,
-        num_expected_rows_in_first_batch_definition=30,
-        expected_pickup_datetimes=MONTH_BATCH_IDENTIFIER_DATA,
-    ),
-    # date_parts as a string (with mixed case):
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_date_parts",
-        splitter_kwargs={"column_name": TEST_COLUMN, "date_parts": ["mOnTh"]},
-        num_expected_batch_definitions=12,
-        num_expected_rows_in_first_batch_definition=30,
-        expected_pickup_datetimes=MONTH_BATCH_IDENTIFIER_DATA,
-    ),
-    # Mix of types of date_parts:
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_date_parts",
-        splitter_kwargs={
-            "column_name": TEST_COLUMN,
-            "date_parts": [DatePart.YEAR, "month"],
-        },
-        num_expected_batch_definitions=36,
-        num_expected_rows_in_first_batch_definition=10,
-        expected_pickup_datetimes=YEAR_MONTH_BATCH_IDENTIFIER_DATA,
-    ),
-]
+test_cases: List[TaxiSplittingTestCase] = taxi_splitting_test_cases.test_cases()
 
 for test_case in test_cases:
 
@@ -207,7 +110,7 @@ class SqlSplittingTestCase:
     # 2. Set splitter in data connector config
     data_connector_name: str = "test_data_connector"
     data_asset_name: str = table_name  # Read from generated table name
-    column_name: str = TEST_COLUMN
+    column_name: str = taxi_splitting_test_cases.test_column_name
     data_connector: ConfiguredAssetSqlDataConnector = (
         ConfiguredAssetSqlDataConnector(
             name=data_connector_name,
diff --git a/tests/integration/fixtures/split_data/splitter_test_cases_and_fixtures.py b/tests/integration/fixtures/split_data/splitter_test_cases_and_fixtures.py
new file mode 100644
index 000000000000..6a2abae9b267
--- /dev/null
+++ b/tests/integration/fixtures/split_data/splitter_test_cases_and_fixtures.py
@@ -0,0 +1,157 @@
+import datetime
+from dataclasses import dataclass
+from typing import List
+
+import pandas as pd
+
+from great_expectations.execution_engine.split_and_sample.data_splitter import DatePart
+
+
+class TaxiTestData:
+    def __init__(self, test_df: pd.DataFrame, test_column_name: str):
+        self._test_df = test_df
+        self._test_column_name = test_column_name
+
+    @property
+    def test_df(self):
+        return self._test_df
+
+    @property
+    def test_column_name(self):
+        return self._test_column_name
+
+    def years_in_taxi_data(self) -> List[datetime.datetime]:
+        return (
+            pd.date_range(start="2018-01-01", end="2020-12-31", freq="AS")
+            .to_pydatetime()
+            .tolist()
+        )
+
+    def year_batch_identifier_data(self) -> List[dict]:
+        return [{DatePart.YEAR.value: dt.year} for dt in self.years_in_taxi_data()]
+
+    def months_in_taxi_data(self) -> List[datetime.datetime]:
+        return (
+            pd.date_range(start="2018-01-01", end="2020-12-31", freq="MS")
+            .to_pydatetime()
+            .tolist()
+        )
+
+    def year_month_batch_identifier_data(self) -> List[dict]:
+        return [
+            {DatePart.YEAR.value: dt.year, DatePart.MONTH.value: dt.month}
+            for dt in self.months_in_taxi_data()
+        ]
+
+    def month_batch_identifier_data(self) -> List[dict]:
+        return [{DatePart.MONTH.value: dt.month} for dt in self.months_in_taxi_data()]
+
+    def year_month_day_batch_identifier_data(self) -> List[dict]:
+        # Since taxi data does not contain all days,
+        # we need to introspect the data to build the fixture:
+        year_month_day_batch_identifier_list_unsorted: List[dict] = list(
+            {val[0]: val[1], val[2]: val[3], val[4]: val[5]}
+            for val in {
+                (
+                    DatePart.YEAR.value,
+                    dt.year,
+                    DatePart.MONTH.value,
+                    dt.month,
+                    DatePart.DAY.value,
+                    dt.day,
+                )
+                for dt in self.test_df[self.test_column_name]
+            }
+        )
+
+        return sorted(
+            year_month_day_batch_identifier_list_unsorted,
+            key=lambda x: (
+                x[DatePart.YEAR.value],
+                x[DatePart.MONTH.value],
+                x[DatePart.DAY.value],
+            ),
+        )
+
+
+@dataclass
+class TaxiSplittingTestCase:
+    splitter_method_name: str
+    splitter_kwargs: dict
+    num_expected_batch_definitions: int
+    num_expected_rows_in_first_batch_definition: int
+    expected_pickup_datetimes: List[dict]
+
+
+class TaxiSplittingTestCases:
+    def __init__(self, taxi_test_data: TaxiTestData):
+        self._taxi_test_data = taxi_test_data
+
+    @property
+    def taxi_test_data(self) -> TaxiTestData:
+        return self._taxi_test_data
+
+    @property
+    def test_df(self) -> pd.DataFrame:
+        return self._taxi_test_data.test_df
+
+    @property
+    def test_column_name(self) -> str:
+        return self._taxi_test_data.test_column_name
+
+    def test_cases(self) -> List[TaxiSplittingTestCase]:
+        return [
+            TaxiSplittingTestCase(
+                splitter_method_name="split_on_year",
+                splitter_kwargs={"column_name": self.taxi_test_data.test_column_name},
+                num_expected_batch_definitions=3,
+                num_expected_rows_in_first_batch_definition=120,
+                expected_pickup_datetimes=self.taxi_test_data.year_batch_identifier_data(),
+            ),
+            TaxiSplittingTestCase(
+                splitter_method_name="split_on_year_and_month",
+                splitter_kwargs={"column_name": self.taxi_test_data.test_column_name},
+                num_expected_batch_definitions=36,
+                num_expected_rows_in_first_batch_definition=10,
+                expected_pickup_datetimes=self.taxi_test_data.year_month_batch_identifier_data(),
+            ),
+            TaxiSplittingTestCase(
+                splitter_method_name="split_on_year_and_month_and_day",
+                splitter_kwargs={"column_name": self.taxi_test_data.test_column_name},
+                num_expected_batch_definitions=299,
+                num_expected_rows_in_first_batch_definition=2,
+                expected_pickup_datetimes=self.taxi_test_data.year_month_day_batch_identifier_data(),
+            ),
+            TaxiSplittingTestCase(
+                splitter_method_name="split_on_date_parts",
+                splitter_kwargs={
+                    "column_name": self.taxi_test_data.test_column_name,
+                    "date_parts": [DatePart.MONTH],
+                },
+                num_expected_batch_definitions=12,
+                num_expected_rows_in_first_batch_definition=30,
+                expected_pickup_datetimes=self.taxi_test_data.month_batch_identifier_data(),
+            ),
+            # date_parts as a string (with mixed case):
+            TaxiSplittingTestCase(
+                splitter_method_name="split_on_date_parts",
+                splitter_kwargs={
+                    "column_name": self.taxi_test_data.test_column_name,
+                    "date_parts": ["mOnTh"],
+                },
+                num_expected_batch_definitions=12,
+                num_expected_rows_in_first_batch_definition=30,
+                expected_pickup_datetimes=self.taxi_test_data.month_batch_identifier_data(),
+            ),
+            # Mix of types of date_parts:
+            TaxiSplittingTestCase(
+                splitter_method_name="split_on_date_parts",
+                splitter_kwargs={
+                    "column_name": self.taxi_test_data.test_column_name,
+                    "date_parts": [DatePart.YEAR, "month"],
+                },
+                num_expected_batch_definitions=36,
+                num_expected_rows_in_first_batch_definition=10,
+                expected_pickup_datetimes=self.taxi_test_data.year_month_batch_identifier_data(),
+            ),
+        ]
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 7bf95b76e48d..09df5fb4683a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -504,6 +504,45 @@ class LoadedTable:
     inserted_dataframe: pd.DataFrame
 
 
+def load_and_concatenate_csvs(
+    csv_paths: List[str],
+    load_full_dataset: bool = False,
+    convert_column_names_to_datetime: Optional[List[str]] = None,
+) -> pd.DataFrame:
+    """Utility function for loading test data from one or more csvs into a single pandas dataframe.
+
+    Its parameters control how much of each csv is loaded and which columns are parsed as datetimes.
+
+    Args:
+        csv_paths: list of paths of csvs to load; can contain a single path. These should all be the same shape.
+        load_full_dataset: if False, load only the first 10 rows of each csv.
+        convert_column_names_to_datetime: list of column names to convert to datetime in the loaded dataframe.
+
+    Returns:
+        A pandas dataframe concatenating data loaded from all csvs.
+    """
+
+    if convert_column_names_to_datetime is None:
+        convert_column_names_to_datetime = []
+
+    import pandas as pd
+
+    dfs: List[pd.DataFrame] = []
+    for csv_path in csv_paths:
+        df = pd.read_csv(csv_path)
+        for column_name_to_convert in convert_column_names_to_datetime:
+            df[column_name_to_convert] = pd.to_datetime(df[column_name_to_convert])
+        if not load_full_dataset:
+            # Improve test performance by loading only the first 10 rows of each csv of test data
+            df = df.head(10)
+
+        dfs.append(df)
+
+    all_dfs_concatenated: pd.DataFrame = pd.concat(dfs)
+
+    return all_dfs_concatenated
+
+
 def load_data_into_test_database(
     table_name: str,
     connection_string: str,
@@ -528,35 +567,20 @@ def load_data_into_test_database(
         same prefix.
 
     Returns:
-        For convenience, the pandas dataframe that was used to load the data.
+        A LoadedTable, which for convenience contains the pandas dataframe that was used to load the data.
     """
     if csv_path and csv_paths:
         csv_paths.append(csv_path)
     elif csv_path and not csv_paths:
         csv_paths = [csv_path]
 
-    if convert_colnames_to_datetime is None:
-        convert_colnames_to_datetime = []
+    all_dfs_concatenated: pd.DataFrame = load_and_concatenate_csvs(
+        csv_paths, load_full_dataset, convert_colnames_to_datetime
+    )
 
     if random_table_suffix:
         table_name: str = f"{table_name}_{str(uuid.uuid4())[:8]}"
 
-    import pandas as pd
-
-    print("Generating dataframe of all csv data")
-    dfs: List[pd.DataFrame] = []
-    for csv_path in csv_paths:
-        df = pd.read_csv(csv_path)
-        for colname_to_convert in convert_colnames_to_datetime:
-            df[colname_to_convert] = pd.to_datetime(df[colname_to_convert])
-        if not load_full_dataset:
-            # Improving test performance by only loading the first 10 rows of our test data into the db
-            df = df.head(10)
-
-        dfs.append(df)
-
-    all_dfs_concatenated: pd.DataFrame = pd.concat(dfs)
-
     return_value: LoadedTable = LoadedTable(
        table_name=table_name, inserted_dataframe=all_dfs_concatenated
    )
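Reviewer note: below is a minimal sketch of how the helpers introduced in this diff compose, for anyone adapting these fixtures to another backend test. It is not part of the patch; the csv path is hypothetical, while `TaxiTestData`, `TaxiSplittingTestCases`, and `load_and_concatenate_csvs` are the names defined in the files above.

```python
from typing import List

import pandas as pd

from tests.integration.fixtures.split_data.splitter_test_cases_and_fixtures import (
    TaxiSplittingTestCase,
    TaxiSplittingTestCases,
    TaxiTestData,
)
from tests.test_utils import load_and_concatenate_csvs

# Hypothetical path; any csv with a datetime column of taxi pickups works here.
csv_paths: List[str] = ["data/yellow_tripdata_sample_2018-01.csv"]

# load_full_dataset=True loads every row instead of the 10-row testing default.
test_df: pd.DataFrame = load_and_concatenate_csvs(
    csv_paths,
    load_full_dataset=True,
    convert_column_names_to_datetime=["pickup_datetime"],
)

# Wrap the dataframe once; all expected batch identifiers derive from it.
taxi_test_data = TaxiTestData(test_df, test_column_name="pickup_datetime")
test_cases: List[TaxiSplittingTestCase] = TaxiSplittingTestCases(
    taxi_test_data
).test_cases()

# Each test case bundles a splitter method, its kwargs, and the expected
# batch definitions, so backend-specific tests only vary the execution engine.
for test_case in test_cases:
    print(
        test_case.splitter_method_name,
        test_case.num_expected_batch_definitions,
        test_case.num_expected_rows_in_first_batch_definition,
    )
```

Because `year_month_day_batch_identifier_data()` introspects whatever dataframe is passed in, the same test cases can drive the SQL test above and any future engine-specific variants without duplicating the expected-splits fixtures.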