
[MAINTENANCE] Move splitter related taxi integration test fixtures (#…
anthonyburdi committed Apr 25, 2022
1 parent 81ec5cb commit 562874b
Showing 3 changed files with 211 additions and 127 deletions.
119 changes: 11 additions & 108 deletions tests/integration/db/test_sql_data_splitting.py
@@ -1,4 +1,3 @@
-from dataclasses import dataclass
 from typing import List, Tuple

 import pandas as pd
@@ -11,10 +10,14 @@
 from great_expectations.core.batch_spec import SqlAlchemyDatasourceBatchSpec
 from great_expectations.core.yaml_handler import YAMLHandler
 from great_expectations.datasource.data_connector import ConfiguredAssetSqlDataConnector
-from great_expectations.execution_engine.split_and_sample.data_splitter import DatePart
 from great_expectations.execution_engine.sqlalchemy_batch_data import (
     SqlAlchemyBatchData,
 )
+from tests.integration.fixtures.split_data.splitter_test_cases_and_fixtures import (
+    TaxiSplittingTestCase,
+    TaxiSplittingTestCases,
+    TaxiTestData,
+)
 from tests.test_utils import (
     LoadedTable,
     clean_up_tables_with_prefix,
@@ -77,114 +80,14 @@ def _load_data(
 test_df: pd.DataFrame = loaded_table.inserted_dataframe
 table_name: str = loaded_table.table_name

-YEARS_IN_TAXI_DATA = (
-    pd.date_range(start="2018-01-01", end="2020-12-31", freq="AS")
-    .to_pydatetime()
-    .tolist()
-)
-YEAR_BATCH_IDENTIFIER_DATA: List[dict] = [
-    {DatePart.YEAR.value: dt.year} for dt in YEARS_IN_TAXI_DATA
-]
-
-MONTHS_IN_TAXI_DATA = (
-    pd.date_range(start="2018-01-01", end="2020-12-31", freq="MS")
-    .to_pydatetime()
-    .tolist()
-)
-YEAR_MONTH_BATCH_IDENTIFIER_DATA: List[dict] = [
-    {DatePart.YEAR.value: dt.year, DatePart.MONTH.value: dt.month}
-    for dt in MONTHS_IN_TAXI_DATA
-]
-MONTH_BATCH_IDENTIFIER_DATA: List[dict] = [
-    {DatePart.MONTH.value: dt.month} for dt in MONTHS_IN_TAXI_DATA
-]
-
-TEST_COLUMN: str = "pickup_datetime"
-
-# Since taxi data does not contain all days, we need to introspect the data to build the fixture:
-YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA: List[dict] = list(
-    {val[0]: val[1], val[2]: val[3], val[4]: val[5]}
-    for val in {
-        (
-            DatePart.YEAR.value,
-            dt.year,
-            DatePart.MONTH.value,
-            dt.month,
-            DatePart.DAY.value,
-            dt.day,
-        )
-        for dt in test_df[TEST_COLUMN]
-    }
-)
+taxi_test_data: TaxiTestData = TaxiTestData(
+    test_df, test_column_name="pickup_datetime"
+)
-YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA: List[dict] = sorted(
-    YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA,
-    key=lambda x: (
-        x[DatePart.YEAR.value],
-        x[DatePart.MONTH.value],
-        x[DatePart.DAY.value],
-    ),
-)
+taxi_splitting_test_cases: TaxiSplittingTestCases = TaxiSplittingTestCases(
+    taxi_test_data
+)

-@dataclass
-class SqlSplittingTestCase:
-    splitter_method_name: str
-    splitter_kwargs: dict
-    num_expected_batch_definitions: int
-    num_expected_rows_in_first_batch_definition: int
-    expected_pickup_datetimes: List[dict]
-
-test_cases: List[SqlSplittingTestCase] = [
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_year",
-        splitter_kwargs={"column_name": TEST_COLUMN},
-        num_expected_batch_definitions=3,
-        num_expected_rows_in_first_batch_definition=120,
-        expected_pickup_datetimes=YEAR_BATCH_IDENTIFIER_DATA,
-    ),
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_year_and_month",
-        splitter_kwargs={"column_name": TEST_COLUMN},
-        num_expected_batch_definitions=36,
-        num_expected_rows_in_first_batch_definition=10,
-        expected_pickup_datetimes=YEAR_MONTH_BATCH_IDENTIFIER_DATA,
-    ),
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_year_and_month_and_day",
-        splitter_kwargs={"column_name": TEST_COLUMN},
-        num_expected_batch_definitions=299,
-        num_expected_rows_in_first_batch_definition=2,
-        expected_pickup_datetimes=YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA,
-    ),
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_date_parts",
-        splitter_kwargs={
-            "column_name": TEST_COLUMN,
-            "date_parts": [DatePart.MONTH],
-        },
-        num_expected_batch_definitions=12,
-        num_expected_rows_in_first_batch_definition=30,
-        expected_pickup_datetimes=MONTH_BATCH_IDENTIFIER_DATA,
-    ),
-    # date_parts as a string (with mixed case):
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_date_parts",
-        splitter_kwargs={"column_name": TEST_COLUMN, "date_parts": ["mOnTh"]},
-        num_expected_batch_definitions=12,
-        num_expected_rows_in_first_batch_definition=30,
-        expected_pickup_datetimes=MONTH_BATCH_IDENTIFIER_DATA,
-    ),
-    # Mix of types of date_parts:
-    SqlSplittingTestCase(
-        splitter_method_name="split_on_date_parts",
-        splitter_kwargs={
-            "column_name": TEST_COLUMN,
-            "date_parts": [DatePart.YEAR, "month"],
-        },
-        num_expected_batch_definitions=36,
-        num_expected_rows_in_first_batch_definition=10,
-        expected_pickup_datetimes=YEAR_MONTH_BATCH_IDENTIFIER_DATA,
-    ),
-]
+test_cases: List[TaxiSplittingTestCase] = taxi_splitting_test_cases.test_cases()

 for test_case in test_cases:

@@ -207,7 +110,7 @@ class SqlSplittingTestCase:
     # 2. Set splitter in data connector config
     data_connector_name: str = "test_data_connector"
     data_asset_name: str = table_name  # Read from generated table name
-    column_name: str = TEST_COLUMN
+    column_name: str = taxi_splitting_test_cases.test_column_name
     data_connector: ConfiguredAssetSqlDataConnector = (
         ConfiguredAssetSqlDataConnector(
             name=data_connector_name,
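The diff for this file is collapsed past this point. For orientation, the wiring that the loop body builds looks roughly like the sketch below. This is a hedged reconstruction against the 0.15-era Great Expectations API, not the literal test body; the connection string, datasource name, and asset/table name are illustrative placeholders.

```python
from great_expectations.core.batch import BatchRequest
from great_expectations.datasource.data_connector import (
    ConfiguredAssetSqlDataConnector,
)
from great_expectations.execution_engine import SqlAlchemyExecutionEngine

# Assumed names: the connection string, datasource name, and asset/table
# name below are placeholders, not values from this commit.
execution_engine = SqlAlchemyExecutionEngine(connection_string="sqlite:///taxi.db")

data_connector = ConfiguredAssetSqlDataConnector(
    name="test_data_connector",
    datasource_name="my_test_datasource",
    execution_engine=execution_engine,
    assets={
        "taxi_data_all": {
            # Filled per test case from test_case.splitter_method_name
            # and test_case.splitter_kwargs:
            "splitter_method": "split_on_year",
            "splitter_kwargs": {"column_name": "pickup_datetime"},
        }
    },
)

# The splitter yields one batch definition per distinct split value,
# e.g. three for split_on_year over 2018-2020 taxi data.
batch_request = BatchRequest(
    datasource_name="my_test_datasource",
    data_connector_name="test_data_connector",
    data_asset_name="taxi_data_all",
)
batch_definitions = data_connector.get_batch_definition_list_from_batch_request(
    batch_request=batch_request
)
```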
157 changes: 157 additions & 0 deletions tests/integration/fixtures/split_data/splitter_test_cases_and_fixtures.py
@@ -0,0 +1,157 @@
import datetime
from dataclasses import dataclass
from typing import List

import pandas as pd

from great_expectations.execution_engine.split_and_sample.data_splitter import DatePart


class TaxiTestData:
    def __init__(self, test_df: pd.DataFrame, test_column_name: str):
        self._test_df = test_df
        self._test_column_name = test_column_name

    @property
    def test_df(self):
        return self._test_df

    @property
    def test_column_name(self):
        return self._test_column_name

    def years_in_taxi_data(self) -> List[datetime.datetime]:
        return (
            pd.date_range(start="2018-01-01", end="2020-12-31", freq="AS")
            .to_pydatetime()
            .tolist()
        )

    def year_batch_identifier_data(self) -> List[dict]:
        return [{DatePart.YEAR.value: dt.year} for dt in self.years_in_taxi_data()]

    def months_in_taxi_data(self) -> List[datetime.datetime]:
        return (
            pd.date_range(start="2018-01-01", end="2020-12-31", freq="MS")
            .to_pydatetime()
            .tolist()
        )

    def year_month_batch_identifier_data(self) -> List[dict]:
        return [
            {DatePart.YEAR.value: dt.year, DatePart.MONTH.value: dt.month}
            for dt in self.months_in_taxi_data()
        ]

    def month_batch_identifier_data(self) -> List[dict]:
        return [{DatePart.MONTH.value: dt.month} for dt in self.months_in_taxi_data()]

    def year_month_day_batch_identifier_data(self) -> List[dict]:
        # Since taxi data does not contain all days,
        # we need to introspect the data to build the fixture:
        year_month_day_batch_identifier_list_unsorted: List[dict] = list(
            {val[0]: val[1], val[2]: val[3], val[4]: val[5]}
            for val in {
                (
                    DatePart.YEAR.value,
                    dt.year,
                    DatePart.MONTH.value,
                    dt.month,
                    DatePart.DAY.value,
                    dt.day,
                )
                for dt in self.test_df[self.test_column_name]
            }
        )

        return sorted(
            year_month_day_batch_identifier_list_unsorted,
            key=lambda x: (
                x[DatePart.YEAR.value],
                x[DatePart.MONTH.value],
                x[DatePart.DAY.value],
            ),
        )


@dataclass
class TaxiSplittingTestCase:
    splitter_method_name: str
    splitter_kwargs: dict
    num_expected_batch_definitions: int
    num_expected_rows_in_first_batch_definition: int
    expected_pickup_datetimes: List[dict]


class TaxiSplittingTestCases:
    def __init__(self, taxi_test_data: TaxiTestData):
        self._taxi_test_data = taxi_test_data

    @property
    def taxi_test_data(self) -> TaxiTestData:
        return self._taxi_test_data

    @property
    def test_df(self) -> pd.DataFrame:
        return self._taxi_test_data.test_df

    @property
    def test_column_name(self) -> str:
        return self._taxi_test_data.test_column_name

    def test_cases(self) -> List[TaxiSplittingTestCase]:
        return [
            TaxiSplittingTestCase(
                splitter_method_name="split_on_year",
                splitter_kwargs={"column_name": self.taxi_test_data.test_column_name},
                num_expected_batch_definitions=3,
                num_expected_rows_in_first_batch_definition=120,
                expected_pickup_datetimes=self.taxi_test_data.year_batch_identifier_data(),
            ),
            TaxiSplittingTestCase(
                splitter_method_name="split_on_year_and_month",
                splitter_kwargs={"column_name": self.taxi_test_data.test_column_name},
                num_expected_batch_definitions=36,
                num_expected_rows_in_first_batch_definition=10,
                expected_pickup_datetimes=self.taxi_test_data.year_month_batch_identifier_data(),
            ),
            TaxiSplittingTestCase(
                splitter_method_name="split_on_year_and_month_and_day",
                splitter_kwargs={"column_name": self.taxi_test_data.test_column_name},
                num_expected_batch_definitions=299,
                num_expected_rows_in_first_batch_definition=2,
                expected_pickup_datetimes=self.taxi_test_data.year_month_day_batch_identifier_data(),
            ),
            TaxiSplittingTestCase(
                splitter_method_name="split_on_date_parts",
                splitter_kwargs={
                    "column_name": self.taxi_test_data.test_column_name,
                    "date_parts": [DatePart.MONTH],
                },
                num_expected_batch_definitions=12,
                num_expected_rows_in_first_batch_definition=30,
                expected_pickup_datetimes=self.taxi_test_data.month_batch_identifier_data(),
            ),
            # date_parts as a string (with mixed case):
            TaxiSplittingTestCase(
                splitter_method_name="split_on_date_parts",
                splitter_kwargs={
                    "column_name": self.taxi_test_data.test_column_name,
                    "date_parts": ["mOnTh"],
                },
                num_expected_batch_definitions=12,
                num_expected_rows_in_first_batch_definition=30,
                expected_pickup_datetimes=self.taxi_test_data.month_batch_identifier_data(),
            ),
            # Mix of types of date_parts:
            TaxiSplittingTestCase(
                splitter_method_name="split_on_date_parts",
                splitter_kwargs={
                    "column_name": self.taxi_test_data.test_column_name,
                    "date_parts": [DatePart.YEAR, "month"],
                },
                num_expected_batch_definitions=36,
                num_expected_rows_in_first_batch_definition=10,
                expected_pickup_datetimes=self.taxi_test_data.year_month_batch_identifier_data(),
            ),
        ]
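With the new module in place, a test builds its parameterized cases in two lines, exactly as the updated test file above does. A minimal, self-contained sketch, using a tiny stand-in DataFrame rather than the real taxi data:

```python
import pandas as pd

from tests.integration.fixtures.split_data.splitter_test_cases_and_fixtures import (
    TaxiSplittingTestCases,
    TaxiTestData,
)

# Stand-in for the loaded taxi table; the real tests use LoadedTable.inserted_dataframe.
test_df = pd.DataFrame(
    {"pickup_datetime": pd.to_datetime(["2018-01-05", "2019-02-11", "2020-03-20"])}
)

taxi_test_data = TaxiTestData(test_df, test_column_name="pickup_datetime")
taxi_splitting_test_cases = TaxiSplittingTestCases(taxi_test_data)

for test_case in taxi_splitting_test_cases.test_cases():
    # Each case bundles a splitter method, its kwargs, and the expected batch
    # counts/identifiers; note the counts are calibrated to the real taxi data,
    # not to this three-row stand-in.
    print(
        test_case.splitter_method_name,
        test_case.num_expected_batch_definitions,
        len(test_case.expected_pickup_datetimes),
    )
```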
62 changes: 43 additions & 19 deletions tests/test_utils.py
@@ -504,6 +504,45 @@ class LoadedTable:
     inserted_dataframe: pd.DataFrame


+def load_and_concatenate_csvs(
+    csv_paths: List[str],
+    load_full_dataset: bool = False,
+    convert_column_names_to_datetime: Optional[List[str]] = None,
+) -> pd.DataFrame:
+    """Utility for loading test data from one or more CSVs into a pandas dataframe.
+
+    Args:
+        csv_paths: Paths of the CSVs to load; a single-element list is fine.
+            All CSVs should have the same columns.
+        load_full_dataset: If False, load only the first 10 rows of each CSV.
+        convert_column_names_to_datetime: Column names to convert to datetime
+            before writing to the db.
+    Returns:
+        A pandas dataframe concatenating the data loaded from all CSVs.
+    """
+    if convert_column_names_to_datetime is None:
+        convert_column_names_to_datetime = []
+
+    import pandas as pd
+
+    dfs: List[pd.DataFrame] = []
+    for csv_path in csv_paths:
+        df = pd.read_csv(csv_path)
+        for column_name_to_convert in convert_column_names_to_datetime:
+            df[column_name_to_convert] = pd.to_datetime(df[column_name_to_convert])
+        if not load_full_dataset:
+            # Improving test performance by only loading the first 10 rows of our test data into the db
+            df = df.head(10)
+
+        dfs.append(df)
+
+    all_dfs_concatenated: pd.DataFrame = pd.concat(dfs)
+
+    return all_dfs_concatenated


 def load_data_into_test_database(
     table_name: str,
     connection_string: str,
@@ -528,35 +567,20 @@ def load_data_into_test_database(
             same prefix.
     Returns:
-        For convenience, the pandas dataframe that was used to load the data.
+        A LoadedTable, which for convenience contains the pandas dataframe that was used to load the data.
     """
     if csv_path and csv_paths:
         csv_paths.append(csv_path)
     elif csv_path and not csv_paths:
         csv_paths = [csv_path]

-    if convert_colnames_to_datetime is None:
-        convert_colnames_to_datetime = []
+    all_dfs_concatenated: pd.DataFrame = load_and_concatenate_csvs(
+        csv_paths, load_full_dataset, convert_colnames_to_datetime
+    )

     if random_table_suffix:
         table_name: str = f"{table_name}_{str(uuid.uuid4())[:8]}"

-    import pandas as pd
-
-    print("Generating dataframe of all csv data")
-    dfs: List[pd.DataFrame] = []
-    for csv_path in csv_paths:
-        df = pd.read_csv(csv_path)
-        for colname_to_convert in convert_colnames_to_datetime:
-            df[colname_to_convert] = pd.to_datetime(df[colname_to_convert])
-        if not load_full_dataset:
-            # Improving test performance by only loading the first 10 rows of our test data into the db
-            df = df.head(10)
-
-        dfs.append(df)
-
-    all_dfs_concatenated: pd.DataFrame = pd.concat(dfs)

     return_value: LoadedTable = LoadedTable(
         table_name=table_name, inserted_dataframe=all_dfs_concatenated
     )

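As a quick illustration of the extracted helper on its own (the CSV paths are hypothetical; by default only the first 10 rows of each file are kept, per the performance note in the code):

```python
from tests.test_utils import load_and_concatenate_csvs

# Hypothetical paths for illustration; any CSVs with identical columns work.
csv_paths = [
    "data/yellow_tripdata_sample_2019-01.csv",
    "data/yellow_tripdata_sample_2019-02.csv",
]

# Unless load_full_dataset=True, only the first 10 rows of each CSV are kept,
# and the named columns are parsed as datetimes before concatenation.
df = load_and_concatenate_csvs(
    csv_paths,
    load_full_dataset=False,
    convert_column_names_to_datetime=["pickup_datetime"],
)
print(df.shape)  # at most 20 rows for two CSVs
```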