Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MAINTENANCE] Move splitter related taxi integration test fixtures #4947

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
119 changes: 11 additions & 108 deletions tests/integration/db/test_sql_data_splitting.py
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from typing import List, Tuple

import pandas as pd
Expand All @@ -11,10 +10,14 @@
from great_expectations.core.batch_spec import SqlAlchemyDatasourceBatchSpec
from great_expectations.core.yaml_handler import YAMLHandler
from great_expectations.datasource.data_connector import ConfiguredAssetSqlDataConnector
from great_expectations.execution_engine.split_and_sample.data_splitter import DatePart
from great_expectations.execution_engine.sqlalchemy_batch_data import (
SqlAlchemyBatchData,
)
from tests.integration.fixtures.split_data.splitter_test_cases_and_fixtures import (
TaxiSplittingTestCase,
TaxiSplittingTestCases,
TaxiTestData,
)
from tests.test_utils import (
LoadedTable,
clean_up_tables_with_prefix,
Expand Down Expand Up @@ -77,114 +80,14 @@ def _load_data(
test_df: pd.DataFrame = loaded_table.inserted_dataframe
table_name: str = loaded_table.table_name

YEARS_IN_TAXI_DATA = (
pd.date_range(start="2018-01-01", end="2020-12-31", freq="AS")
.to_pydatetime()
.tolist()
)
YEAR_BATCH_IDENTIFIER_DATA: List[dict] = [
{DatePart.YEAR.value: dt.year} for dt in YEARS_IN_TAXI_DATA
]

MONTHS_IN_TAXI_DATA = (
pd.date_range(start="2018-01-01", end="2020-12-31", freq="MS")
.to_pydatetime()
.tolist()
)
YEAR_MONTH_BATCH_IDENTIFIER_DATA: List[dict] = [
{DatePart.YEAR.value: dt.year, DatePart.MONTH.value: dt.month}
for dt in MONTHS_IN_TAXI_DATA
]
MONTH_BATCH_IDENTIFIER_DATA: List[dict] = [
{DatePart.MONTH.value: dt.month} for dt in MONTHS_IN_TAXI_DATA
]

TEST_COLUMN: str = "pickup_datetime"

# Since taxi data does not contain all days, we need to introspect the data to build the fixture:
YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA: List[dict] = list(
{val[0]: val[1], val[2]: val[3], val[4]: val[5]}
for val in {
(
DatePart.YEAR.value,
dt.year,
DatePart.MONTH.value,
dt.month,
DatePart.DAY.value,
dt.day,
)
for dt in test_df[TEST_COLUMN]
}
taxi_test_data: TaxiTestData = TaxiTestData(
test_df, test_column_name="pickup_datetime"
)
YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA: List[dict] = sorted(
YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA,
key=lambda x: (
x[DatePart.YEAR.value],
x[DatePart.MONTH.value],
x[DatePart.DAY.value],
),
taxi_splitting_test_cases: TaxiSplittingTestCases = TaxiSplittingTestCases(
taxi_test_data
)

@dataclass
class SqlSplittingTestCase:
splitter_method_name: str
splitter_kwargs: dict
num_expected_batch_definitions: int
num_expected_rows_in_first_batch_definition: int
expected_pickup_datetimes: List[dict]

test_cases: List[SqlSplittingTestCase] = [
SqlSplittingTestCase(
splitter_method_name="split_on_year",
splitter_kwargs={"column_name": TEST_COLUMN},
num_expected_batch_definitions=3,
num_expected_rows_in_first_batch_definition=120,
expected_pickup_datetimes=YEAR_BATCH_IDENTIFIER_DATA,
),
SqlSplittingTestCase(
splitter_method_name="split_on_year_and_month",
splitter_kwargs={"column_name": TEST_COLUMN},
num_expected_batch_definitions=36,
num_expected_rows_in_first_batch_definition=10,
expected_pickup_datetimes=YEAR_MONTH_BATCH_IDENTIFIER_DATA,
),
SqlSplittingTestCase(
splitter_method_name="split_on_year_and_month_and_day",
splitter_kwargs={"column_name": TEST_COLUMN},
num_expected_batch_definitions=299,
num_expected_rows_in_first_batch_definition=2,
expected_pickup_datetimes=YEAR_MONTH_DAY_BATCH_IDENTIFIER_DATA,
),
SqlSplittingTestCase(
splitter_method_name="split_on_date_parts",
splitter_kwargs={
"column_name": TEST_COLUMN,
"date_parts": [DatePart.MONTH],
},
num_expected_batch_definitions=12,
num_expected_rows_in_first_batch_definition=30,
expected_pickup_datetimes=MONTH_BATCH_IDENTIFIER_DATA,
),
# date_parts as a string (with mixed case):
SqlSplittingTestCase(
splitter_method_name="split_on_date_parts",
splitter_kwargs={"column_name": TEST_COLUMN, "date_parts": ["mOnTh"]},
num_expected_batch_definitions=12,
num_expected_rows_in_first_batch_definition=30,
expected_pickup_datetimes=MONTH_BATCH_IDENTIFIER_DATA,
),
# Mix of types of date_parts:
SqlSplittingTestCase(
splitter_method_name="split_on_date_parts",
splitter_kwargs={
"column_name": TEST_COLUMN,
"date_parts": [DatePart.YEAR, "month"],
},
num_expected_batch_definitions=36,
num_expected_rows_in_first_batch_definition=10,
expected_pickup_datetimes=YEAR_MONTH_BATCH_IDENTIFIER_DATA,
),
]
test_cases: List[TaxiSplittingTestCase] = taxi_splitting_test_cases.test_cases()

for test_case in test_cases:

Expand All @@ -207,7 +110,7 @@ class SqlSplittingTestCase:
# 2. Set splitter in data connector config
data_connector_name: str = "test_data_connector"
data_asset_name: str = table_name # Read from generated table name
column_name: str = TEST_COLUMN
column_name: str = taxi_splitting_test_cases.test_column_name
data_connector: ConfiguredAssetSqlDataConnector = (
ConfiguredAssetSqlDataConnector(
name=data_connector_name,
Expand Down
@@ -0,0 +1,157 @@
import datetime
from dataclasses import dataclass
from typing import List

import pandas as pd

from great_expectations.execution_engine.split_and_sample.data_splitter import DatePart


class TaxiTestData:
    """Wraps a taxi-data DataFrame plus the datetime column under test, and
    derives the batch-identifier fixtures used by splitter integration tests.
    """

    def __init__(self, test_df: pd.DataFrame, test_column_name: str):
        self._test_df = test_df
        self._test_column_name = test_column_name

    @property
    def test_df(self):
        return self._test_df

    @property
    def test_column_name(self):
        return self._test_column_name

    @staticmethod
    def _fixture_date_range(freq: str) -> List[datetime.datetime]:
        # The taxi fixture data spans 2018 through 2020.
        return (
            pd.date_range(start="2018-01-01", end="2020-12-31", freq=freq)
            .to_pydatetime()
            .tolist()
        )

    def years_in_taxi_data(self) -> List[datetime.datetime]:
        return self._fixture_date_range("AS")

    def year_batch_identifier_data(self) -> List[dict]:
        return [{DatePart.YEAR.value: d.year} for d in self.years_in_taxi_data()]

    def months_in_taxi_data(self) -> List[datetime.datetime]:
        return self._fixture_date_range("MS")

    def year_month_batch_identifier_data(self) -> List[dict]:
        return [
            {DatePart.YEAR.value: d.year, DatePart.MONTH.value: d.month}
            for d in self.months_in_taxi_data()
        ]

    def month_batch_identifier_data(self) -> List[dict]:
        return [{DatePart.MONTH.value: d.month} for d in self.months_in_taxi_data()]

    def year_month_day_batch_identifier_data(self) -> List[dict]:
        # Taxi data does not contain every day, so introspect the actual
        # column values instead of generating a full calendar.
        unique_ymd = {
            (dt.year, dt.month, dt.day) for dt in self.test_df[self.test_column_name]
        }

        return [
            {
                DatePart.YEAR.value: year,
                DatePart.MONTH.value: month,
                DatePart.DAY.value: day,
            }
            for year, month, day in sorted(unique_ymd)
        ]


@dataclass
class TaxiSplittingTestCase:
    """A single splitter test scenario: which splitter method to call, with
    what kwargs, and the batch definitions/rows it is expected to produce.
    """

    splitter_method_name: str  # e.g. "split_on_year", "split_on_date_parts"
    splitter_kwargs: dict  # kwargs passed to the splitter method
    num_expected_batch_definitions: int  # batch definitions the splitter should yield
    num_expected_rows_in_first_batch_definition: int  # row count of the first batch
    expected_pickup_datetimes: List[dict]  # expected batch identifier dicts


class TaxiSplittingTestCases:
    """Builds the standard list of TaxiSplittingTestCase fixtures for a given
    TaxiTestData instance.
    """

    def __init__(self, taxi_test_data: TaxiTestData):
        self._taxi_test_data = taxi_test_data

    @property
    def taxi_test_data(self) -> TaxiTestData:
        return self._taxi_test_data

    @property
    def test_df(self) -> pd.DataFrame:
        return self._taxi_test_data.test_df

    @property
    def test_column_name(self) -> str:
        return self._taxi_test_data.test_column_name

    def test_cases(self) -> List[TaxiSplittingTestCase]:
        # Bind once for readability; every case splits on the same column.
        data = self._taxi_test_data
        column = data.test_column_name

        return [
            TaxiSplittingTestCase(
                splitter_method_name="split_on_year",
                splitter_kwargs={"column_name": column},
                num_expected_batch_definitions=3,
                num_expected_rows_in_first_batch_definition=120,
                expected_pickup_datetimes=data.year_batch_identifier_data(),
            ),
            TaxiSplittingTestCase(
                splitter_method_name="split_on_year_and_month",
                splitter_kwargs={"column_name": column},
                num_expected_batch_definitions=36,
                num_expected_rows_in_first_batch_definition=10,
                expected_pickup_datetimes=data.year_month_batch_identifier_data(),
            ),
            TaxiSplittingTestCase(
                splitter_method_name="split_on_year_and_month_and_day",
                splitter_kwargs={"column_name": column},
                num_expected_batch_definitions=299,
                num_expected_rows_in_first_batch_definition=2,
                expected_pickup_datetimes=data.year_month_day_batch_identifier_data(),
            ),
            TaxiSplittingTestCase(
                splitter_method_name="split_on_date_parts",
                splitter_kwargs={
                    "column_name": column,
                    "date_parts": [DatePart.MONTH],
                },
                num_expected_batch_definitions=12,
                num_expected_rows_in_first_batch_definition=30,
                expected_pickup_datetimes=data.month_batch_identifier_data(),
            ),
            # date_parts as a string (with mixed case):
            TaxiSplittingTestCase(
                splitter_method_name="split_on_date_parts",
                splitter_kwargs={
                    "column_name": column,
                    "date_parts": ["mOnTh"],
                },
                num_expected_batch_definitions=12,
                num_expected_rows_in_first_batch_definition=30,
                expected_pickup_datetimes=data.month_batch_identifier_data(),
            ),
            # Mix of types of date_parts:
            TaxiSplittingTestCase(
                splitter_method_name="split_on_date_parts",
                splitter_kwargs={
                    "column_name": column,
                    "date_parts": [DatePart.YEAR, "month"],
                },
                num_expected_batch_definitions=36,
                num_expected_rows_in_first_batch_definition=10,
                expected_pickup_datetimes=data.year_month_batch_identifier_data(),
            ),
        ]
62 changes: 43 additions & 19 deletions tests/test_utils.py
Expand Up @@ -504,6 +504,45 @@ class LoadedTable:
inserted_dataframe: pd.DataFrame


def load_and_concatenate_csvs(
    csv_paths: List[str],
    load_full_dataset: bool = False,
    convert_column_names_to_datetime: Optional[List[str]] = None,
) -> pd.DataFrame:
    """Load one or more CSV files and concatenate them into a single DataFrame.

    Args:
        csv_paths: List of paths of CSVs to read. These should all be the same shape.
        load_full_dataset: If False, load only the first 10 rows of each CSV.
        convert_column_names_to_datetime: List of column names to convert to
            datetime after loading.

    Returns:
        A pandas dataframe concatenating data loaded from all csvs.

    Raises:
        ValueError: If ``csv_paths`` is empty (nothing to concatenate).
    """

    if not csv_paths:
        # pd.concat raises an opaque error for an empty list; fail clearly instead.
        raise ValueError("csv_paths must contain at least one path")

    if convert_column_names_to_datetime is None:
        convert_column_names_to_datetime = []

    import pandas as pd

    dfs: List[pd.DataFrame] = []
    for csv_path in csv_paths:
        df = pd.read_csv(csv_path)
        for column_name_to_convert in convert_column_names_to_datetime:
            df[column_name_to_convert] = pd.to_datetime(df[column_name_to_convert])
        if not load_full_dataset:
            # Improving test performance by only loading the first 10 rows of our test data into the db
            df = df.head(10)

        dfs.append(df)

    all_dfs_concatenated: pd.DataFrame = pd.concat(dfs)

    return all_dfs_concatenated


def load_data_into_test_database(
table_name: str,
connection_string: str,
Expand All @@ -528,35 +567,20 @@ def load_data_into_test_database(
same prefix.

Returns:
For convenience, the pandas dataframe that was used to load the data.
LoadedTable which for convenience, contains the pandas dataframe that was used to load the data.
"""
if csv_path and csv_paths:
csv_paths.append(csv_path)
elif csv_path and not csv_paths:
csv_paths = [csv_path]

if convert_colnames_to_datetime is None:
convert_colnames_to_datetime = []
all_dfs_concatenated: pd.DataFrame = load_and_concatenate_csvs(
csv_paths, load_full_dataset, convert_colnames_to_datetime
)

if random_table_suffix:
table_name: str = f"{table_name}_{str(uuid.uuid4())[:8]}"

import pandas as pd

print("Generating dataframe of all csv data")
dfs: List[pd.DataFrame] = []
for csv_path in csv_paths:
df = pd.read_csv(csv_path)
for colname_to_convert in convert_colnames_to_datetime:
df[colname_to_convert] = pd.to_datetime(df[colname_to_convert])
if not load_full_dataset:
# Improving test performance by only loading the first 10 rows of our test data into the db
df = df.head(10)

dfs.append(df)

all_dfs_concatenated: pd.DataFrame = pd.concat(dfs)

return_value: LoadedTable = LoadedTable(
table_name=table_name, inserted_dataframe=all_dfs_concatenated
)
Expand Down