
Commit a6ef142
Merge pull request #9 from mehd-io/feat/explicit-schema
Feat: Ingest pipeline using explicit schema
mehd-io committed Jun 19, 2024
2 parents b0e7f74 + e83765e commit a6ef142
Showing 5 changed files with 183 additions and 61 deletions.
15 changes: 10 additions & 5 deletions ingestion/bigquery.py
@@ -4,8 +4,9 @@
 from google.auth.exceptions import DefaultCredentialsError
 from loguru import logger
 import time
-from ingestion.models import PypiJobParameters
+from ingestion.models import PypiJobParameters, FileDownloads
 import pandas as pd
+import pyarrow as pa

 PYPI_PUBLIC_DATASET = "bigquery-public-data.pypi.file_downloads"

@@ -48,21 +49,25 @@ def get_bigquery_client(project_name: str) -> bigquery.Client:
         raise creds_error


 def get_bigquery_result(
-    query_str: str, bigquery_client: bigquery.Client
-) -> pd.DataFrame:
-    """Get query result from BigQuery and yield rows as dictionaries."""
+    query_str: str, bigquery_client: bigquery.Client, model: FileDownloads
+) -> pa.Table:
+    """Run the query on BigQuery and return the result as a PyArrow Table."""
     try:
         # Start measuring time
         start_time = time.time()
-        # Run the query and directly load into a DataFrame
+        # Run the query and load the result directly into a PyArrow Table
         logger.info(f"Running query: {query_str}")
-        dataframe = bigquery_client.query(query_str).to_dataframe()
+        # dataframe = bigquery_client.query(query_str).to_dataframe(dtypes=FileDownloads().pandas_dtypes)
+        pa_tbl = bigquery_client.query(query_str).to_arrow()
         # Log the time taken for query execution and data loading
         elapsed_time = time.time() - start_time
         logger.info(f"Query executed and data loaded in {elapsed_time:.2f} seconds")
-        # Iterate over DataFrame rows and yield as dictionaries
-        return dataframe
+        return pa_tbl

     except Exception as e:
         logger.error(f"Error running query: {e}")
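
For context on the change above: QueryJob.to_arrow() materializes the result as a PyArrow Table, keeping BigQuery's nested STRUCT columns as Arrow structs instead of converting them to Python dicts in object-dtype pandas columns. A minimal usage sketch (the project name and query below are illustrative, not from this commit):

    from google.cloud import bigquery

    client = bigquery.Client(project="my-gcp-project")  # hypothetical project name
    query_str = """
        SELECT timestamp, country_code, url, project, file, details,
               tls_protocol, tls_cipher
        FROM `bigquery-public-data.pypi.file_downloads`
        WHERE project = 'duckdb'
          AND timestamp >= TIMESTAMP('2023-04-01')
          AND timestamp < TIMESTAMP('2023-04-02')
    """
    pa_tbl = client.query(query_str).to_arrow()  # returns a pyarrow.Table
    print(pa_tbl.num_rows, pa_tbl.schema)
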
20 changes: 15 additions & 5 deletions ingestion/duck.py
@@ -1,14 +1,17 @@
 """ Helper functions for interacting with DuckDB """
 from typing import List
 from loguru import logger
 import pandas as pd


-def create_table_from_dataframe(duckdb_con, table_name: str, dataframe: str):
+def create_table_from_dataframe(duckdb_con, table_name: str, table_ddl: str):
     logger.info(f"Creating table {table_name} in local DuckDB")
+    duckdb_con.sql(table_ddl)
+    logger.info("inserting data into table")
     duckdb_con.sql(
         f"""
-        CREATE TABLE {table_name} AS
+        INSERT INTO {table_name}
         SELECT *
-        FROM {dataframe}
+        FROM pa_tbl
         """
     )

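The INSERT above selects FROM pa_tbl even though no table of that name is ever created: DuckDB's Python API resolves unknown table names against Python variables visible from the call (a replacement scan), so the PyArrow Table bound to pa_tbl upstream in pipeline.py is scanned directly. A self-contained sketch of the mechanism:

    import duckdb
    import pyarrow as pa

    # A toy Arrow table standing in for the BigQuery result.
    pa_tbl = pa.table({"project": ["duckdb", "pandas"], "downloads": [10, 20]})

    con = duckdb.connect()  # in-memory DuckDB
    con.sql("CREATE TABLE demo (project VARCHAR, downloads BIGINT)")
    # DuckDB finds `pa_tbl` via a replacement scan over the Python scope.
    con.sql("INSERT INTO demo SELECT * FROM pa_tbl")
    print(con.sql("SELECT * FROM demo").fetchall())  # [('duckdb', 10), ('pandas', 20)]
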
@@ -51,16 +54,23 @@ def write_to_md_from_duckdb(
     start_date: str,
     end_date: str,
 ):
-    logger.info(f"Writing data to motherduck {remote_database}.main.{table}")
+    logger.info(
+        f"Creating database {remote_database} from {local_database}.{table} if it doesn't exist"
+    )
+    duckdb_con.sql(f"CREATE DATABASE IF NOT EXISTS {remote_database}")
+    logger.info(
+        f"Creating table {remote_database}.main.{table} from {local_database}.{table} if it doesn't exist"
+    )
     duckdb_con.sql(
         f"CREATE TABLE IF NOT EXISTS {remote_database}.{table} AS SELECT * FROM {local_database}.{table} limit 0"
     )
     # Delete any existing data in the date range
     logger.info(f"Deleting data from {start_date} to {end_date}")
     duckdb_con.sql(
         f"DELETE FROM {remote_database}.main.{table} WHERE {timestamp_column} BETWEEN '{start_date}' AND '{end_date}'"
     )
     # Insert new data
+    logger.info(f"Writing data to motherduck {remote_database}.main.{table}")
     duckdb_con.sql(
         f"""
         INSERT INTO {remote_database}.main.{table}
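
The delete-then-insert sequence makes the write idempotent per date window: re-running the job for the same start_date/end_date replaces that window instead of duplicating it. A rough sketch of the pattern (toy table and values, no MotherDuck connection):

    import duckdb

    con = duckdb.connect()
    con.sql("CREATE TABLE downloads (ts TIMESTAMP, project VARCHAR)")

    def reload_window(start_date: str, end_date: str):
        # Wipe the window, then rewrite it: safe to re-run for the same range.
        con.sql(f"DELETE FROM downloads WHERE ts BETWEEN '{start_date}' AND '{end_date}'")
        con.sql(f"INSERT INTO downloads VALUES ('{start_date} 10:00:00', 'duckdb')")

    reload_window("2023-04-01", "2023-04-02")
    reload_window("2023-04-01", "2023-04-02")  # no duplicates on the second run
    print(con.sql("SELECT count(*) FROM downloads").fetchone())  # (1,)
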
58 changes: 39 additions & 19 deletions ingestion/models.py
@@ -2,8 +2,9 @@
 from typing import List, Union, Annotated, Type
 from pydantic import BaseModel, ValidationError
 from datetime import datetime
-from typing import Optional
+from typing import Optional, Dict
 import pandas as pd
+import pyarrow as pa

 DUCKDB_EXTENSION = ["aws", "httpfs"]

@@ -52,17 +53,18 @@ class Details(BaseModel):
     openssl_version: Optional[str]
     setuptools_version: Optional[str]
     rustc_version: Optional[str]
+    ci: Optional[bool]


 class FileDownloads(BaseModel):
-    timestamp: Optional[datetime]
-    country_code: Optional[str]
-    url: Optional[str]
-    project: Optional[str]
-    file: Optional[File]
-    details: Optional[Details]
-    tls_protocol: Optional[str]
-    tls_cipher: Optional[str]
+    timestamp: Optional[datetime] = None
+    country_code: Optional[str] = None
+    url: Optional[str] = None
+    project: Optional[str] = None
+    file: Optional[File] = None
+    details: Optional[Details] = None
+    tls_protocol: Optional[str] = None
+    tls_cipher: Optional[str] = None


 class PypiJobParameters(BaseModel):
@@ -79,29 +81,47 @@ class PypiJobParameters(BaseModel):
     aws_profile: Optional[str]


-class DataFrameValidationError(Exception):
-    """Custom exception for DataFrame validation errors."""
+class TableValidationError(Exception):
+    """Custom exception for Table validation errors."""

     pass


-def validate_dataframe(df: pd.DataFrame, model: Type[BaseModel]):
+def validate_table(table: pa.Table, model: Type[BaseModel]):
     """
-    Validates each row of a DataFrame against a Pydantic model.
-    Raises DataFrameValidationError if any row fails validation.
+    Validates each row of a PyArrow Table against a Pydantic model.
+    Raises TableValidationError if any row fails validation.

-    :param df: DataFrame to validate.
+    :param table: PyArrow Table to validate.
     :param model: Pydantic model to validate against.
-    :raises: DataFrameValidationError
+    :raises: TableValidationError
     """
     errors = []

-    for i, row in enumerate(df.to_dict(orient="records")):
+    for i in range(table.num_rows):
+        row = {column: table[column][i].as_py() for column in table.column_names}
         try:
             model(**row)
         except ValidationError as e:
             errors.append(f"Row {i} failed validation: {e}")

     if errors:
         error_message = "\n".join(errors)
-        raise DataFrameValidationError(
-            f"DataFrame validation failed with the following errors:\n{error_message}"
+        raise TableValidationError(
+            f"Table validation failed with the following errors:\n{error_message}"
         )
+
+
+def duckdb_ddl_file_downloads(table_name="pypi_file_downloads"):
+    return f"""
+    CREATE TABLE IF NOT EXISTS {table_name} (
+        timestamp TIMESTAMP WITH TIME ZONE,
+        country_code VARCHAR,
+        url VARCHAR,
+        project VARCHAR,
+        file STRUCT("filename" VARCHAR, "project" VARCHAR, "version" VARCHAR, "type" VARCHAR),
+        details STRUCT("installer" STRUCT("name" VARCHAR, "version" VARCHAR), "python" VARCHAR, "implementation" STRUCT("name" VARCHAR, "version" VARCHAR), "distro" STRUCT("name" VARCHAR, "version" VARCHAR, "id" VARCHAR, "libc" STRUCT("lib" VARCHAR, "version" VARCHAR)), "system" STRUCT("name" VARCHAR, "release" VARCHAR), "cpu" VARCHAR, "openssl_version" VARCHAR, "setuptools_version" VARCHAR, "rustc_version" VARCHAR, "ci" BOOLEAN),
+        tls_protocol VARCHAR,
+        tls_cipher VARCHAR
+    )
+    """
25 changes: 21 additions & 4 deletions ingestion/pipeline.py
@@ -4,6 +4,7 @@
     build_pypi_query,
 )
 import duckdb
+from datetime import datetime
 from loguru import logger
 from ingestion.duck import (
     create_table_from_dataframe,
@@ -13,20 +14,31 @@
     connect_to_md,
 )
 import fire
-from ingestion.models import validate_dataframe, FileDownloads, PypiJobParameters
+from ingestion.models import (
+    validate_table,
+    FileDownloads,
+    PypiJobParameters,
+    duckdb_ddl_file_downloads,
+)
 import os


 def main(params: PypiJobParameters):
+    start_time = datetime.now()
     # Loading data from BigQuery
-    df = get_bigquery_result(
+    pa_tbl = get_bigquery_result(
         query_str=build_pypi_query(params),
         bigquery_client=get_bigquery_client(project_name=params.gcp_project),
+        model=FileDownloads,
     )
-    validate_dataframe(df, FileDownloads)
+    validate_table(pa_tbl, FileDownloads)
     # Loading to DuckDB
     conn = duckdb.connect()
-    create_table_from_dataframe(conn, params.table_name, "df")
+    create_table_from_dataframe(
+        duckdb_con=conn,
+        table_name=params.table_name,
+        table_ddl=duckdb_ddl_file_downloads("pypi_file_downloads"),
+    )

     logger.info(f"Sinking data to {params.destination}")
     if "local" in params.destination:
@@ -49,6 +61,11 @@ def main(params: PypiJobParameters):
         start_date=params.start_date,
         end_date=params.end_date,
     )
+    end_time = datetime.now()
+    elapsed = (end_time - start_time).total_seconds()
+    logger.info(
+        f"Total job completed in {elapsed // 60} minutes and {elapsed % 60:.2f} seconds."
+    )


 if __name__ == "__main__":
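
One wrinkle in the new timing log: total_seconds() returns a float, so elapsed // 60 renders as, e.g., 2.0 rather than 2. A small variant with an int cast, if whole minutes are preferred:

    from datetime import datetime
    import time

    start_time = datetime.now()
    time.sleep(1.2)  # stand-in for the pipeline work
    elapsed = (datetime.now() - start_time).total_seconds()

    # elapsed is a float; floor-divide then cast for a clean minute count.
    minutes, seconds = int(elapsed // 60), elapsed % 60
    print(f"Total job completed in {minutes} minutes and {seconds:.2f} seconds.")
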
126 changes: 98 additions & 28 deletions ingestion/tests/test_models.py
@@ -3,12 +3,20 @@
 from pydantic import BaseModel
 import duckdb
 from ingestion.models import (
-    validate_dataframe,
+    validate_table,
     PypiJobParameters,
-    DataFrameValidationError,
+    TableValidationError,
     FileDownloads,
+    File,
+    Details,
+    Installer,
+    Implementation,
+    Distro,
+    System,
+    Libc,
 )
 from ingestion.bigquery import build_pypi_query
+import pyarrow as pa


 class MyModel(BaseModel):
@@ -17,19 +25,96 @@ class MyModel(BaseModel):
     column3: float


-def test_validate_dataframe_with_valid_data():
-    valid_data = {"column1": [1, 2], "column2": ["a", "b"], "column3": [1.1, 2.2]}
-    valid_df = pd.DataFrame(valid_data)
-    errors = validate_dataframe(valid_df, MyModel)
-    assert not errors, f"Validation errors were found in valid data: {errors}"
-
-
-def test_validate_dataframe_with_invalid_data():
-    invalid_data = {"column1": ["x", 2], "column2": ["a", 3], "column3": ["y", 2.2]}
-    invalid_df = pd.DataFrame(invalid_data)
-
-    with pytest.raises(DataFrameValidationError) as excinfo:
-        validate_dataframe(invalid_df, MyModel)
+@pytest.fixture
+def file_downloads_table():
+    df = pd.DataFrame(
+        {
+            "timestamp": [
+                pd.Timestamp("2023-01-01T12:00:00Z"),
+                pd.Timestamp("2023-01-02T12:00:00Z"),
+            ],
+            "country_code": ["US", "CA"],
+            "url": ["http://example.com/file1", "http://example.com/file2"],
+            "project": ["project1", "project2"],
+            "file": [
+                File(
+                    filename="file1.txt", project="project1", version="1.0", type="txt"
+                ).dict(),
+                File(
+                    filename="file2.txt", project="project2", version="1.0", type="txt"
+                ).dict(),
+            ],
+            "details": [
+                Details(
+                    installer=Installer(name="pip", version="21.0"),
+                    python="3.8.5",
+                    implementation=Implementation(name="CPython", version="3.8.5"),
+                    distro=Distro(
+                        name="Ubuntu",
+                        version="20.04",
+                        id="ubuntu2004",
+                        libc=Libc(lib="glibc", version="2.31"),
+                    ),
+                    system=System(name="Linux", release="5.4.0-58-generic"),
+                    cpu="x86_64",
+                    openssl_version="1.1.1",
+                    setuptools_version="50.3.0",
+                    rustc_version="1.47.0",
+                    ci=False,
+                ).dict(),
+                Details(
+                    installer=Installer(name="pip", version="21.0"),
+                    python="3.8.5",
+                    implementation=Implementation(name="CPython", version="3.8.5"),
+                    distro=Distro(
+                        name="Ubuntu",
+                        version="20.04",
+                        id="ubuntu2004",
+                        libc=Libc(lib="glibc", version="2.31"),
+                    ),
+                    system=System(name="Linux", release="5.4.0-58-generic"),
+                    cpu="x86_64",
+                    openssl_version="1.1.1",
+                    setuptools_version="50.3.0",
+                    rustc_version="1.47.0",
+                    ci=False,
+                ).dict(),
+            ],
+            "tls_protocol": ["TLSv1.2", "TLSv1.3"],
+            "tls_cipher": ["AES128-GCM-SHA256", "AES256-GCM-SHA384"],
+        }
+    )
+
+    return pa.Table.from_pandas(df)
+
+
+def test_validate_table_with_valid_data():
+    valid_data = {
+        "column1": pa.array([1, 2]),
+        "column2": pa.array(["a", "b"]),
+        "column3": pa.array([1.1, 2.2]),
+    }
+    valid_table = pa.table(valid_data)
+    errors = validate_table(valid_table, MyModel)
+    assert not errors, f"Validation errors were found in valid data: {errors}"
+
+
+def test_file_downloads_validation(file_downloads_table):
+    try:
+        validate_table(file_downloads_table, FileDownloads)
+    except TableValidationError as e:
+        pytest.fail(f"Table validation failed: {e}")


 def test_build_pypi_query():
Expand Down Expand Up @@ -81,7 +166,8 @@ def file_downloads_df():
cpu VARCHAR,
openssl_version VARCHAR,
setuptools_version VARCHAR,
rustc_version VARCHAR
rustc_version VARCHAR,
ci BOOLEAN
),
tls_protocol VARCHAR,
tls_cipher VARCHAR
@@ -93,19 +179,3 @@ def file_downloads_df():
     conn.execute("COPY tbl FROM 'ingestion/tests/sample_file_downloads.csv' (HEADER)")
     # Create DataFrame
     return conn.execute("SELECT * FROM tbl").df()
-
-
-def test_file_downloads_validation(file_downloads_df):
-    try:
-        validate_dataframe(file_downloads_df, FileDownloads)
-    except DataFrameValidationError as e:
-        pytest.fail(f"DataFrame validation failed: {e}")
-
-
-def test_file_downloads_invalid_data(file_downloads_df):
-    # Introduce an invalid data entry
-    file_downloads_df.at[0, "details"] = 123  # Replace with an invalid entry
-
-    # Expect DataFrameValidationError to be raised
-    with pytest.raises(DataFrameValidationError):
-        validate_dataframe(file_downloads_df, FileDownloads)

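The new fixture leans on pa.Table.from_pandas inferring Arrow struct types from the nested dicts that .dict() produces; a minimal sketch of that inference, independent of the commit's models:

    import pandas as pd
    import pyarrow as pa

    # Nested dicts in an object-dtype column become an Arrow struct on conversion.
    df = pd.DataFrame(
        {
            "project": ["duckdb"],
            "file": [{"filename": "f.txt", "version": "1.0"}],
        }
    )
    table = pa.Table.from_pandas(df)
    print(table.schema)              # file: struct<filename: string, version: string>
    print(table["file"][0].as_py())  # {'filename': 'f.txt', 'version': '1.0'}
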