Skip to content

Commit

Permalink
add sd test and fix import
Browse files: browse the repository at this point in the history
Signed-off-by: Samhita Alla <samhitaalla@Samhitas-MacBook-Pro.local>
Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
  • Loading branch information
samhita-alla committed Feb 2, 2023
1 parent be60572 commit d95dce2
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions plugins/flytekit-duckdb/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
-from typing import Union, Annotated
+from typing import Union

 import pandas as pd
 import pyarrow as pa
 from flytekitplugins.duckdb import DuckDBQuery
+from typing_extensions import Annotated

 from flytekit import kwtypes, task, workflow
 from flytekit.types.structured.structured_dataset import StructuredDataset
Expand Down Expand Up @@ -61,9 +62,11 @@ def arrow_wf(arrow_table: pa.Table) -> pa.Table:
assert isinstance(arrow_wf(arrow_table=get_arrow_table()), pa.Table)


-def test_structured_dataset():
+def test_structured_dataset_arrow_table():
     duckdb_task = DuckDBQuery(
-        name="duckdb_sd", query="SELECT * FROM arrow_table WHERE i = 2", inputs=kwtypes(arrow_table=StructuredDataset)
+        name="duckdb_sd_table",
+        query="SELECT * FROM arrow_table WHERE i = 2",
+        inputs=kwtypes(arrow_table=StructuredDataset),
     )

@task
Expand All @@ -79,6 +82,26 @@ def arrow_wf(arrow_table: StructuredDataset) -> pa.Table:
assert isinstance(arrow_wf(arrow_table=get_arrow_table()), pa.Table)


def test_structured_dataset_pandas_df():
    """Run a DuckDBQuery task on a StructuredDataset that wraps a pandas DataFrame.

    The query task declares a single ``StructuredDataset`` input named
    ``pandas_df`` and is invoked from a workflow annotated to return
    ``pd.DataFrame``. The test passes when the workflow's result is a
    ``pd.DataFrame`` instance; the row contents are not inspected.
    """
    # NOTE(review): the SQL refers to ``pandas_df``, the same name as the task
    # input -- presumably that is how the plugin exposes the input to the
    # query; confirm against the DuckDBQuery implementation.
    query_task = DuckDBQuery(
        name="duckdb_sd_df",
        query="SELECT * FROM pandas_df WHERE i = 2",
        inputs=kwtypes(pandas_df=StructuredDataset),
    )

    @task
    def get_pandas_df() -> StructuredDataset:
        # Four-row fixture; only the row with i == 2 satisfies the WHERE clause.
        frame = pd.DataFrame.from_dict({"i": [1, 2, 3, 4], "j": ["one", "two", "three", "four"]})
        return StructuredDataset(dataframe=frame)

    @workflow
    def pandas_wf(pandas_df: StructuredDataset) -> pd.DataFrame:
        return query_task(pandas_df=pandas_df)

    dataset = get_pandas_df()
    query_result = pandas_wf(pandas_df=dataset)
    assert isinstance(query_result, pd.DataFrame)


def test_distinct_params():
duckdb_params_query = DuckDBQuery(
name="params_query",
Expand Down

0 comments on commit d95dce2

Please sign in to comment.