test_sql_data_splitting.py
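"""Integration test for SQL data splitting.

Loads sample NYC taxi data into a test database, configures a
ConfiguredAssetSqlDataConnector with each splitter method under test, and
checks that the resulting batch definitions and batch row counts match the
expectations defined in the taxi splitting test fixtures.
"""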
from typing import List, Tuple

import pandas as pd
import sqlalchemy as sa

import great_expectations as ge
from great_expectations import DataContext
from great_expectations.core import IDDict
from great_expectations.core.batch import BatchDefinition, BatchRequest
from great_expectations.core.batch_spec import SqlAlchemyDatasourceBatchSpec
from great_expectations.core.yaml_handler import YAMLHandler
from great_expectations.datasource.data_connector import ConfiguredAssetSqlDataConnector
from great_expectations.execution_engine.sqlalchemy_batch_data import (
    SqlAlchemyBatchData,
)
from tests.integration.fixtures.split_data.splitter_test_cases_and_fixtures import (
    TaxiSplittingTestCase,
    TaxiSplittingTestCases,
    TaxiTestData,
)
from tests.test_utils import (
    LoadedTable,
    clean_up_tables_with_prefix,
    get_bigquery_connection_url,
    get_snowflake_connection_url,
    load_data_into_test_database,
)
yaml_handler: YAMLHandler = YAMLHandler()


def _get_connection_string_and_dialect() -> Tuple[str, str]:

    with open("./connection_string.yml") as f:
        db_config: dict = yaml_handler.load(f)

    dialect: str = db_config["dialect"]
    if dialect == "snowflake":
        connection_string: str = get_snowflake_connection_url()
    elif dialect == "bigquery":
        connection_string: str = get_bigquery_connection_url()
    else:
        connection_string: str = db_config["connection_string"]

    return dialect, connection_string


TAXI_DATA_TABLE_NAME: str = "taxi_data_all_samples"
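# Helper that loads the 2018-2020 monthly taxi CSV samples into a test table
# with a random suffix, returning the generated table name and the inserted
# DataFrame.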
def _load_data(
    connection_string: str, table_name: str = TAXI_DATA_TABLE_NAME
) -> LoadedTable:
    # Load the first 10 rows of each month of taxi data
    return load_data_into_test_database(
        table_name=table_name,
        csv_paths=[
            f"./data/yellow_tripdata_sample_{year}-{month}.csv"
            for year in ["2018", "2019", "2020"]
            for month in [f"{mo:02d}" for mo in range(1, 12 + 1)]
        ],
        connection_string=connection_string,
        convert_colnames_to_datetime=["pickup_datetime", "dropoff_datetime"],
        random_table_suffix=True,
    )
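# NOTE (assumption): this guard appears to match the module name under which
# the integration test runner imports this script, so the block below runs when
# the script is executed by that runner rather than as "__main__".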
if __name__ == "test_script_module":

    dialect, connection_string = _get_connection_string_and_dialect()
    print(f"Testing dialect: {dialect}")

    print("Preemptively cleaning old tables")
    clean_up_tables_with_prefix(
        connection_string=connection_string, table_prefix=f"{TAXI_DATA_TABLE_NAME}_"
    )

    loaded_table: LoadedTable = _load_data(connection_string=connection_string)

    test_df: pd.DataFrame = loaded_table.inserted_dataframe
    table_name: str = loaded_table.table_name

    taxi_test_data: TaxiTestData = TaxiTestData(
        test_df, test_column_name="pickup_datetime"
    )
    taxi_splitting_test_cases: TaxiSplittingTestCases = TaxiSplittingTestCases(
        taxi_test_data
    )
    test_cases: List[TaxiSplittingTestCase] = taxi_splitting_test_cases.test_cases()
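    # Each TaxiSplittingTestCase supplies a splitter_method_name, its
    # splitter_kwargs, and the expected batch definitions / row counts that
    # the splitter should produce for the loaded taxi table.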
    for test_case in test_cases:

        print("Testing splitter method:", test_case.splitter_method_name)

        # 1. Setup

        context: DataContext = ge.get_context()

        datasource_name: str = "test_datasource"
        context.add_datasource(
            name=datasource_name,
            class_name="Datasource",
            execution_engine={
                "class_name": "SqlAlchemyExecutionEngine",
                "connection_string": connection_string,
            },
        )

        # 2. Set splitter in data connector config
        data_connector_name: str = "test_data_connector"
        data_asset_name: str = table_name  # Read from generated table name

        column_name: str = taxi_splitting_test_cases.test_column_name

        data_connector: ConfiguredAssetSqlDataConnector = (
            ConfiguredAssetSqlDataConnector(
                name=data_connector_name,
                datasource_name=datasource_name,
                execution_engine=context.datasources[datasource_name].execution_engine,
                assets={
                    data_asset_name: {
                        "splitter_method": test_case.splitter_method_name,
                        "splitter_kwargs": test_case.splitter_kwargs,
                    }
                },
            )
        )
        # 3. Check if resulting batches are as expected
        # using data_connector.get_batch_definition_list_from_batch_request()

        batch_request: BatchRequest = BatchRequest(
            datasource_name=datasource_name,
            data_connector_name=data_connector_name,
            data_asset_name=data_asset_name,
        )
        batch_definition_list: List[
            BatchDefinition
        ] = data_connector.get_batch_definition_list_from_batch_request(batch_request)

        assert len(batch_definition_list) == test_case.num_expected_batch_definitions

        expected_batch_definition_list: List[BatchDefinition] = [
            BatchDefinition(
                datasource_name=datasource_name,
                data_connector_name=data_connector_name,
                data_asset_name=data_asset_name,
                batch_identifiers=IDDict({column_name: pickup_datetime}),
            )
            for pickup_datetime in test_case.expected_pickup_datetimes
        ]

        assert set(batch_definition_list) == set(
            expected_batch_definition_list
        ), f"BatchDefinition lists don't match\n\nbatch_definition_list:\n{batch_definition_list}\n\nexpected_batch_definition_list:\n{expected_batch_definition_list}"
        # 4. Check that loaded data is as expected
        # Use expected_batch_definition_list since it is sorted, and we already
        # asserted that it contains the same items as batch_definition_list
        batch_spec: SqlAlchemyDatasourceBatchSpec = data_connector.build_batch_spec(
            expected_batch_definition_list[0]
        )

        batch_data: SqlAlchemyBatchData = context.datasources[
            datasource_name
        ].execution_engine.get_batch_data(batch_spec=batch_spec)

        num_rows: int = batch_data.execution_engine.engine.execute(
            sa.select([sa.func.count()]).select_from(batch_data.selectable)
        ).scalar()
        assert num_rows == test_case.num_expected_rows_in_first_batch_definition

    print("Clean up tables used in this test")
    clean_up_tables_with_prefix(
        connection_string=connection_string, table_prefix=f"{TAXI_DATA_TABLE_NAME}_"
    )