Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Csvtransform #1671

Merged
merged 5 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions flytekit/types/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
)


def register_csv_handlers():
    """Register the pandas CSV encoder and decoder with the transformer engine.

    Both handlers are registered with ``default_format_for_type=True``.
    The handler classes are imported inside the function body rather than at
    module import time (presumably to defer loading ``basic_dfs`` -- confirm).
    """
    from .basic_dfs import CSVToPandasDecodingHandler, PandasToCSVEncodingHandler

    # Encoder first, then decoder; each instance is created right before it
    # is handed to the engine.
    for handler_cls in (PandasToCSVEncodingHandler, CSVToPandasDecodingHandler):
        StructuredDatasetTransformerEngine.register(handler_cls(), default_format_for_type=True)


def register_pandas_handlers():
import pandas as pd

Expand Down
49 changes: 49 additions & 0 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
CSV,
PARQUET,
StructuredDataset,
StructuredDatasetDecoder,
Expand All @@ -35,6 +36,54 @@ def get_storage_options(cfg: DataConfig, uri: str, anon: bool = False) -> typing
return None


class PandasToCSVEncodingHandler(StructuredDatasetEncoder):
def __init__(self):
Comment on lines +39 to +40
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to register the handler in this file

super().__init__(pd.DataFrame, None, CSV)

def encode(
    self,
    ctx: FlyteContext,
    structured_dataset: StructuredDataset,
    structured_dataset_type: StructuredDatasetType,
) -> literals.StructuredDataset:
    """Write the dataset's DataFrame to ``<uri>/.csv`` and return a literal for ``uri``.

    The destination is the dataset's own ``uri`` when that is set (truthy);
    otherwise it is obtained from
    ``ctx.file_access.get_random_remote_directory()``. If
    ``ctx.file_access.is_remote`` reports the destination as not remote, the
    directory is created before writing. The DataFrame index is not written.

    Note that ``structured_dataset_type.format`` is overwritten with ``CSV``
    in place before the type is wrapped in the returned metadata.
    """
    file_access = ctx.file_access
    target_dir = typing.cast(str, structured_dataset.uri) or file_access.get_random_remote_directory()

    # Non-remote destinations need the directory to exist before pandas writes.
    if not file_access.is_remote(target_dir):
        Path(target_dir).mkdir(parents=True, exist_ok=True)

    csv_path = os.path.join(target_dir, ".csv")
    frame = typing.cast(pd.DataFrame, structured_dataset.dataframe)
    storage_options = get_storage_options(file_access.data_config, csv_path)
    frame.to_csv(csv_path, index=False, storage_options=storage_options)

    structured_dataset_type.format = CSV
    metadata = StructuredDatasetMetadata(structured_dataset_type)
    return literals.StructuredDataset(uri=target_dir, metadata=metadata)


class CSVToPandasDecodingHandler(StructuredDatasetDecoder):
    """Decoder that loads a CSV-backed structured dataset into a pandas DataFrame."""

    def __init__(self):
        # Python type pd.DataFrame, protocol left as None, storage format CSV.
        super().__init__(pd.DataFrame, None, CSV)

    def decode(
        self,
        ctx: FlyteContext,
        flyte_value: literals.StructuredDataset,
        current_task_metadata: StructuredDatasetMetadata,
    ) -> pd.DataFrame:
        """Read ``<flyte_value.uri>/.csv`` into a DataFrame.

        When the current task metadata carries a structured dataset type that
        declares columns, only those column names are requested (``usecols``);
        otherwise every column is read. If the first read raises
        ``NoCredentialsError``, the read is retried once with storage options
        built with ``anon=True``.
        """
        base_uri = flyte_value.uri
        storage_options = get_storage_options(ctx.file_access.data_config, base_uri)
        csv_path = os.path.join(base_uri, ".csv")

        sd_type = current_task_metadata.structured_dataset_type
        if sd_type and sd_type.columns:
            selected = [col.name for col in sd_type.columns]
        else:
            selected = None

        try:
            return pd.read_csv(csv_path, usecols=selected, storage_options=storage_options)
        except NoCredentialsError:
            logger.debug("S3 source detected, attempting anonymous S3 access")
            anon_options = get_storage_options(ctx.file_access.data_config, base_uri, anon=True)
            return pd.read_csv(csv_path, usecols=selected, storage_options=anon_options)


class PandasToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self):
super().__init__(pd.DataFrame, None, PARQUET)
Expand Down
1 change: 1 addition & 0 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

# Storage formats
PARQUET: StructuredDatasetFormat = "parquet"
CSV: StructuredDatasetFormat = "csv"
GENERIC_FORMAT: StructuredDatasetFormat = ""
GENERIC_PROTOCOL: str = "generic protocol"

Expand Down
14 changes: 14 additions & 0 deletions tests/flytekit/unit/core/test_structured_dataset_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ def test_pandas():
assert df.equals(df2)


def test_csv():
    """Round-trip a small DataFrame through the CSV encoder and decoder."""
    original = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    csv_encoder = basic_dfs.PandasToCSVEncodingHandler()
    csv_decoder = basic_dfs.CSVToPandasDecodingHandler()

    current_ctx = context_manager.FlyteContextManager.current_context()
    csv_type = StructuredDatasetType(format="csv")

    # Encode to a literal, then decode that literal using the same type.
    encoded = csv_encoder.encode(current_ctx, StructuredDataset(dataframe=original), csv_type)
    restored = csv_decoder.decode(current_ctx, encoded, StructuredDatasetMetadata(csv_type))

    assert original.equals(restored)


def test_base_isnt_instantiable():
with pytest.raises(TypeError):
StructuredDatasetEncoder(pd.DataFrame, "", "")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.basic_dfs import CSVToPandasDecodingHandler, PandasToCSVEncodingHandler
from flytekit.types.structured.structured_dataset import (
CSV,
DF,
PARQUET,
StructuredDataset,
Expand Down Expand Up @@ -198,6 +200,23 @@ def t10(dataset: Annotated[StructuredDataset, my_cols]) -> np.ndarray:
return np_array


# Register the CSV encoder and decoder with the transformer engine at import
# time (no default-format flag is passed here). Presumably required by the
# CSV tasks t11/t12 defined below -- confirm.
StructuredDatasetTransformerEngine.register(PandasToCSVEncodingHandler())
StructuredDatasetTransformerEngine.register(CSVToPandasDecodingHandler())


@task
def t11(dataframe: pd.DataFrame) -> Annotated[StructuredDataset, CSV]:
    # pandas -> csv
    # Wrap the incoming frame in a StructuredDataset whose uri is PANDAS_PATH;
    # the return annotation requests the CSV format.
    csv_dataset = StructuredDataset(dataframe=dataframe, uri=PANDAS_PATH)
    return csv_dataset


@task
def t12(dataset: Annotated[StructuredDataset, my_cols]) -> pd.DataFrame:
    # csv -> pandas
    # Open the dataset as a pandas DataFrame and load it in full.
    return dataset.open(pd.DataFrame).all()


@task
def generate_pandas() -> pd.DataFrame:
return pd_df
Expand Down Expand Up @@ -231,6 +250,8 @@ def wf():
t8a(dataframe=arrow_df)
t9(dataframe=np_array)
t10(dataset=StructuredDataset(uri=NUMPY_PATH))
t11(dataframe=df)
t12(dataset=StructuredDataset(uri=PANDAS_PATH))


def test_structured_dataset_wf():
Expand Down