Skip to content

Commit

Permalink
csv tests
Browse files Browse the repository at this point in the history
Signed-off-by: ChungYujoyce <joyce.bhps@gmail.com>
  • Loading branch information
ChungYujoyce committed Jun 2, 2023
1 parent daa67a4 commit 4dc0052
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/flytekit/unit/core/test_structured_dataset_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,19 @@ def test_pandas():
sd = StructuredDataset(dataframe=df)
sd_type = StructuredDatasetType(format="parquet")
sd_lit = encoder.encode(ctx, sd, sd_type)
df2 = decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type))
assert df.equals(df2)


def test_csv():
    # Round-trip check: a DataFrame written by the CSV encoding handler and
    # read back by the CSV decoding handler must compare equal to the input.
    original = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    to_csv = basic_dfs.PandasToCSVEncodingHandler()
    from_csv = basic_dfs.CSVToPandasDecodingHandler()
    flyte_ctx = context_manager.FlyteContextManager.current_context()

    # The same "csv" dataset type is used for both directions so the decoder
    # is handed metadata describing exactly what the encoder produced.
    wrapped = StructuredDataset(dataframe=original)
    csv_type = StructuredDatasetType(format="csv")
    encoded = to_csv.encode(flyte_ctx, wrapped, csv_type)
    restored = from_csv.decode(flyte_ctx, encoded, StructuredDatasetMetadata(csv_type))

    assert original.equals(restored)
Expand All @@ -51,4 +64,5 @@ def test_arrow():
assert decoder.protocol is None
assert encoder.python_type is decoder.python_type
d = StructuredDatasetTransformerEngine.DECODERS[encoder.python_type]["fsspec"]["parquet"]
print(d)
assert d is not None
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@

from flytekit import FlyteContext, FlyteContextManager, kwtypes, task, workflow
from flytekit.models import literals
from flytekit.types.structured.basic_dfs import PandasToCSVEncodingHandler, CSVToPandasDecodingHandler
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
DF,
PARQUET,
CSV,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
Expand Down Expand Up @@ -198,6 +200,21 @@ def t10(dataset: Annotated[StructuredDataset, my_cols]) -> np.ndarray:
return np_array


# Register the pandas <-> CSV handlers with the transformer engine at import
# time so the CSV tasks defined below (t11 / t12) can be encoded and decoded.
# NOTE(review): presumably these handlers are not registered by default;
# confirm that a second registration elsewhere would not be rejected.
StructuredDatasetTransformerEngine.register(PandasToCSVEncodingHandler())
StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler())

@task
def t11(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, CSV]:
    # pandas -> csv: wrap the incoming frame in a StructuredDataset pointed at
    # PANDAS_PATH; the CSV annotation on the return type selects the format.
    csv_dataset = StructuredDataset(dataframe=dataframe, uri=PANDAS_PATH)
    return csv_dataset

@task
def t12(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
    # csv -> pandas: open the dataset as a pandas DataFrame and load all of it.
    # NOTE(review): the input is annotated with my_cols rather than CSV --
    # confirm that the CSV decoder is the one actually selected here.
    return dataset.open(pd.DataFrame).all()


@task
def generate_pandas() -> pd.DataFrame:
    # Hand back the shared pd_df frame (defined outside this function) as-is.
    frame = pd_df
    return frame
Expand Down Expand Up @@ -231,6 +248,8 @@ def wf():
t8a(dataframe=arrow_df)
t9(dataframe=np_array)
t10(dataset=StructuredDataset(uri=NUMPY_PATH))
t11(dataframe=df)
t12(dataset=StructuredDataset(uri=PANDAS_PATH))


def test_structured_dataset_wf():
Expand Down

0 comments on commit 4dc0052

Please sign in to comment.