
[pr into #785] Turn structured dataset into dataclass #802

Merged · 21 commits · Jan 11, 2022
Changes from 18 commits
36 changes: 5 additions & 31 deletions dataset_example.py
@@ -10,8 +10,6 @@
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pyspark.sql.dataframe
from pyspark.sql import SparkSession

from flytekit import FlyteContext, kwtypes, task, workflow
from flytekit.models import literals
@@ -27,8 +25,8 @@
)
from flytekit.types.structured.utils import get_filesystem

PANDAS_PATH = "/tmp/pandas"
NUMPY_PATH = "/tmp/numpy"
PANDAS_PATH = "s3://flyte-batch/my-s3-bucket/test-data/pandas"
NUMPY_PATH = "s3://flyte-batch/my-s3-bucket/test-data/numpy"
BQ_PATH = "bq://photo-313016:flyte.new_table3"

# https://github.com/flyteorg/flyte/issues/523
@@ -60,7 +58,7 @@ def t2(dataframe: pd.DataFrame) -> Annotated[pd.DataFrame, arrow_schema]:
@task
def t3(dataset: StructuredDataset[my_cols]) -> StructuredDataset[my_cols]:
# s3 (parquet) -> pandas -> s3 (parquet)
print("Pandas dataframe")
print("Pandas dataframe:")
print(dataset.open(pd.DataFrame).all())
# In the example, we download dataset when we open it.
# Here we won't upload anything, since we're returning just the input object.
@@ -148,6 +146,8 @@ def decode(

FLYTE_DATASET_TRANSFORMER.register_handler(NumpyEncodingHandlers(np.ndarray, "/", "parquet"))
FLYTE_DATASET_TRANSFORMER.register_handler(NumpyDecodingHandlers(np.ndarray, "/", "parquet"))
FLYTE_DATASET_TRANSFORMER.register_handler(NumpyEncodingHandlers(np.ndarray, "s3://", "parquet"))
FLYTE_DATASET_TRANSFORMER.register_handler(NumpyDecodingHandlers(np.ndarray, "s3://", "parquet"))


@task
@@ -163,17 +163,6 @@ def t10(dataset: StructuredDataset[my_cols]) -> np.ndarray:
return np_array


@task
def t11(dataframe: pyspark.sql.dataframe.DataFrame) -> StructuredDataset[my_cols]:
return StructuredDataset(dataframe)


@task
def t12(dataset: StructuredDataset[my_cols]) -> pyspark.sql.dataframe.DataFrame:
spark_df = dataset.open(pyspark.sql.dataframe.DataFrame).all()
return spark_df


@task
def generate_pandas() -> pd.DataFrame:
return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
@@ -189,24 +178,11 @@ def generate_arrow() -> pa.Table:
return pa.Table.from_pandas(pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}))


@task
def generate_spark_dataframe() -> pyspark.sql.dataframe.DataFrame:
data = [
{"Category": "A", "ID": 1, "Value": 121.44, "Truth": True},
{"Category": "B", "ID": 2, "Value": 300.01, "Truth": False},
{"Category": "C", "ID": 3, "Value": 10.99, "Truth": None},
{"Category": "E", "ID": 4, "Value": 33.87, "Truth": True},
]
spark = SparkSession.builder.getOrCreate()
return spark.createDataFrame(data)


@workflow()
def wf():
df = generate_pandas()
np_array = generate_numpy()
arrow_df = generate_arrow()
spark_df = generate_spark_dataframe()
t1(dataframe=df)
t1a(dataframe=df)
t2(dataframe=df)
@@ -220,8 +196,6 @@ def wf():
t8a(dataframe=arrow_df)
t9(dataframe=np_array)
t10(dataset=StructuredDataset(uri=NUMPY_PATH))
dataset = t11(dataframe=spark_df)
t12(dataset=dataset)


if __name__ == "__main__":
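A minimal sketch of the S3/pandas round-trip the updated example exercises, in the spirit of t3 above. The task and workflow names here are invented for illustration, and a parquet dataset is assumed to already exist at PANDAS_PATH:

```python
import pandas as pd

from flytekit import kwtypes, task, workflow
from flytekit.types.structured.structured_dataset import StructuredDataset

my_cols = kwtypes(Name=str, Age=int)
# Same URI as in the example; data is assumed to have been written there already.
PANDAS_PATH = "s3://flyte-batch/my-s3-bucket/test-data/pandas"


@task
def roundtrip(dataset: StructuredDataset[my_cols]) -> StructuredDataset[my_cols]:
    # Opening the dataset downloads it and hands back a pandas dataframe.
    print(dataset.open(pd.DataFrame).all())
    # Returning the input object reuses the existing literal; nothing new is uploaded.
    return dataset


@workflow()
def wf():
    roundtrip(dataset=StructuredDataset(uri=PANDAS_PATH))
```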
6 changes: 5 additions & 1 deletion flytekit/core/interface.py
@@ -275,7 +275,11 @@ def transform_function_to_interface(fn: Callable, docstring: Optional[Docstring]
For now the fancy object, maybe in the future a dumb object.

"""
type_hints = typing.get_type_hints(fn)
try:
# include_extras can only be used in python >= 3.9
type_hints = typing.get_type_hints(fn, include_extras=True)
except TypeError:
type_hints = typing.get_type_hints(fn)
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)

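To see what the new try/except buys, here is a standalone illustration. It assumes Python 3.9+, where include_extras is available; on older interpreters passing the keyword raises the TypeError that the code above catches and falls back from:

```python
import typing
from typing import Annotated


def fn(x: int) -> Annotated[int, "column metadata"]:
    return x


print(typing.get_type_hints(fn)["return"])
# <class 'int'>  -- the Annotated wrapper and its metadata are stripped

print(typing.get_type_hints(fn, include_extras=True)["return"])
# typing.Annotated[int, 'column metadata']  -- metadata preserved for the type engine
```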
28 changes: 27 additions & 1 deletion flytekit/core/type_engine.py
@@ -42,8 +42,9 @@
Primitive,
Scalar,
Schema,
StructuredDatasetMetadata,
)
from flytekit.models.types import LiteralType, SimpleType
from flytekit.models.types import LiteralType, SimpleType, StructuredDatasetType

T = typing.TypeVar("T")
DEFINITIONS = "definitions"
@@ -280,6 +281,7 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
from flytekit.types.directory.types import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.schema.types import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset

for f in dataclasses.fields(python_type):
v = python_val.__getattribute__(f.name)
@@ -288,6 +290,7 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
issubclass(field_type, FlyteSchema)
or issubclass(field_type, FlyteFile)
or issubclass(field_type, FlyteDirectory)
or issubclass(field_type, StructuredDataset)
):
lv = TypeEngine.to_literal(FlyteContext.current_context(), v, field_type, None)
# dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a
@@ -300,6 +303,13 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]):
# as determined by the transformer.
if issubclass(field_type, FlyteFile) or issubclass(field_type, FlyteDirectory):
python_val.__setattr__(f.name, field_type(path=lv.scalar.blob.uri))
elif issubclass(field_type, StructuredDataset):
python_val.__setattr__(
f.name,
field_type(
uri=lv.scalar.structured_dataset.uri,
),
)

elif dataclasses.is_dataclass(field_type):
self._serialize_flyte_type(v, field_type)
@@ -308,6 +318,7 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from flytekit.types.schema.types import FlyteSchema, FlyteSchemaTransformer
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine

if not dataclasses.is_dataclass(expected_python_type):
return python_val
@@ -353,6 +364,21 @@ def _deserialize_flyte_type(self, python_val: T, expected_python_type: Type) ->
),
expected_python_type,
)
elif issubclass(expected_python_type, StructuredDataset):
return StructuredDatasetTransformerEngine().to_python_value(
FlyteContext.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=python_val.file_format)
),
uri=python_val.uri,
)
)
),
expected_python_type,
)
else:
for f in dataclasses.fields(expected_python_type):
value = python_val.__getattribute__(f.name)
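The user-facing effect of these type_engine.py changes is that a StructuredDataset can now live inside a dataclass_json dataclass and survive serialization to a literal and back. A hedged sketch of what that enables — the dataclass and task names below are invented for illustration:

```python
from dataclasses import dataclass

import pandas as pd
from dataclasses_json import dataclass_json

from flytekit import task, workflow
from flytekit.types.structured.structured_dataset import StructuredDataset


@dataclass_json
@dataclass
class TrainingArtifacts:
    run_name: str
    dataset: StructuredDataset


@task
def produce() -> TrainingArtifacts:
    df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    # Serialization writes the dataframe out and keeps only the resulting uri on the dataclass field.
    return TrainingArtifacts(run_name="demo", dataset=StructuredDataset(dataframe=df))


@task
def consume(artifacts: TrainingArtifacts) -> pd.DataFrame:
    # Deserialization rebuilds a StructuredDataset from the stored uri, so it can be opened again.
    return artifacts.dataset.open(pd.DataFrame).all()


@workflow()
def wf() -> pd.DataFrame:
    return consume(artifacts=produce())
```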
2 changes: 1 addition & 1 deletion flytekit/models/types.py
@@ -165,7 +165,7 @@ def external_schema_bytes(self) -> bytes:
def to_flyte_idl(self) -> _types_pb2.StructuredDatasetType:
return _types_pb2.StructuredDatasetType(
columns=[c.to_flyte_idl() for c in self.columns] if self.columns else None,
format=self._format,
format=self.format,
external_schema_type=self.external_schema_type if self.external_schema_type else None,
external_schema_bytes=self.external_schema_bytes if self.external_schema_bytes else None,
)
6 changes: 1 addition & 5 deletions flytekit/types/structured/basic_dfs.py
@@ -61,11 +61,7 @@ def decode(
path = flyte_value.uri
local_dir = ctx.file_access.get_random_local_directory()
ctx.file_access.get_data(path, local_dir, is_multipart=True)
frames = [pandas.read_parquet(os.path.join(local_dir, f)) for f in os.listdir(local_dir)]
if len(frames) == 1:
return frames[0]
elif len(frames) > 1:
return pandas.concat(frames, copy=True)
return pd.read_parquet(local_dir)


class ArrowToParquetEncodingHandler(StructuredDatasetEncoder):
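The simplification above relies on pandas (through its pyarrow engine) accepting a directory of parquet part files and reading them back as a single frame, which is what the old listdir/concat loop did by hand. A small standalone sketch of that behaviour, using hypothetical paths:

```python
import os

import pandas as pd

local_dir = "/tmp/parquet_parts"  # hypothetical directory of part files
os.makedirs(local_dir, exist_ok=True)
pd.DataFrame({"Name": ["Tom"], "Age": [20]}).to_parquet(os.path.join(local_dir, "00000"))
pd.DataFrame({"Name": ["Joseph"], "Age": [22]}).to_parquet(os.path.join(local_dir, "00001"))

# Reading the directory returns the concatenation of every part file inside it.
print(pd.read_parquet(local_dir))
```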
69 changes: 39 additions & 30 deletions flytekit/types/structured/structured_dataset.py
@@ -1,12 +1,18 @@
from __future__ import annotations

import collections
import inspect
import os
import re
import types
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Dict, Generator, Optional, Type, Union

from dataclasses_json import config, dataclass_json
from marshmallow import fields

try:
from typing import Annotated, get_args, get_origin
except ImportError:
@@ -37,7 +43,11 @@
PARQUET = "parquet"


@dataclass_json
@dataclass
class StructuredDataset(object):
uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String()))
file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String()))
"""
This is the user facing StructuredDataset class. Please don't confuse it with the literals.StructuredDataset
class (that is just a model, a Python class representation of the protobuf).
@@ -93,9 +103,12 @@ def __init__(
dataframe: typing.Optional[typing.Any] = None,
uri: Optional[str] = None,
metadata: typing.Optional[literals.StructuredDatasetMetadata] = None,
**kwargs,
):
self._dataframe = dataframe
self._uri = uri
# Make these fields public, so that the dataclass transformer can set a value for it
# https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
self.uri = uri
# This is a special attribute that indicates if the data was either downloaded or uploaded
self._metadata = metadata
# This is not for users to set, the transformer will set this.
@@ -107,18 +120,6 @@ def __init__(
def dataframe(self) -> Type[typing.Any]:
return self._dataframe

@property
def uri(self) -> Optional[str]:
return self._uri

@uri.setter
def uri(self, uri: str):
self._uri = uri

@property
def file_format(self) -> str:
return self.FILE_FORMAT

@property
def metadata(self) -> Optional[StructuredDatasetMetadata]:
return self._metadata
@@ -192,6 +193,11 @@ def encode(

:param ctx:
:param structured_dataset: This is a StructuredDataset wrapper object. See more info above.
:param structured_dataset_type: This is the StructuredDatasetType, as found in the LiteralType of the interface
of the task that invoked this encoding call. It is passed along to encoders so that authors of encoders
can include it in the returned literals.StructuredDataset. See the IDL for more information on why this
literal in particular carries the type information along with it. If the encoder doesn't supply it, it will
also be filled in after the encoder runs by the transformer engine.
:return: This function should return a StructuredDataset literal object. Do not confuse this with the
StructuredDataset wrapper class used as input to this function - that is the user facing Python class.
This function needs to return the IDL StructuredDataset.
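To make the docstring concrete, here is a hedged sketch of a custom encoder that returns the IDL literal with the type information attached. The class name is invented, the (python_type, protocol, format) constructor arguments mirror how the handlers in dataset_example.py are built, and the exact base-class signature and import location are assumptions:

```python
import os
import typing

import pandas as pd

from flytekit import FlyteContext
from flytekit.models import literals
from flytekit.models.types import StructuredDatasetType
from flytekit.types.structured.structured_dataset import (
    PARQUET,
    StructuredDataset,
    StructuredDatasetEncoder,
)


class PandasToParquetLocalEncoder(StructuredDatasetEncoder):
    def __init__(self):
        # (python_type, protocol, format), following the handler constructors in the example file.
        super().__init__(pd.DataFrame, "/", PARQUET)

    def encode(
        self,
        ctx: FlyteContext,
        structured_dataset: StructuredDataset,
        structured_dataset_type: StructuredDatasetType,
    ) -> literals.StructuredDataset:
        df = typing.cast(pd.DataFrame, structured_dataset.dataframe)
        local_dir = ctx.file_access.get_random_local_directory()
        df.to_parquet(os.path.join(local_dir, "00000"))
        # Hand the incoming type information back on the literal, as the docstring recommends,
        # so the transformer engine does not have to back-fill it.
        return literals.StructuredDataset(
            uri=local_dir,
            metadata=literals.StructuredDatasetMetadata(structured_dataset_type=structured_dataset_type),
        )
```

Registration would then go through FLYTE_DATASET_TRANSFORMER.register_handler, as shown in dataset_example.py above.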
@@ -306,11 +312,9 @@ def _finder(self, handler_map, df_type: Type, protocol: str, format: str):
raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt {format}")

def get_encoder(self, df_type: Type, protocol: str, format: str):
protocol.replace("://", "")
return self._finder(self.ENCODERS, df_type, protocol, format)

def get_decoder(self, df_type: Type, protocol: str, format: str):
protocol.replace("://", "")
return self._finder(self.DECODERS, df_type, protocol, format)

def _handler_finder(self, h: Handlers) -> Dict[str, Handlers]:
@@ -365,21 +369,21 @@ def to_literal(
) -> Literal:
# Make a copy in case we need to hand off to encoders, since we can't be sure of mutations.
# Check first to see if it's even an SD type. For backwards compatibility, we may be getting a
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None))
if expected.structured_dataset_type:

if expected and expected.structured_dataset_type:
sdt = StructuredDatasetType(
columns=expected.structured_dataset_type.columns,
format=expected.structured_dataset_type.format,
external_schema_type=expected.structured_dataset_type.external_schema_type,
external_schema_bytes=expected.structured_dataset_type.external_schema_bytes,
)

if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]

# If the type signature has the StructuredDataset class, it will, or at least should, also be a
# StructuredDataset instance.
if issubclass(python_type, StructuredDataset):
if inspect.isclass(python_type) and issubclass(python_type, StructuredDataset):
Contributor: works for python 3.7–3.10 right?

Member Author: yeah, I've tested it with python 3.7–3.10.
assert isinstance(python_val, StructuredDataset)
# There are three cases that we need to take care of here.

@@ -431,7 +435,7 @@ def to_literal(
# Otherwise assume it's a dataframe instance. Wrap it with some defaults
fmt = self.DEFAULT_FORMATS[python_type]
protocol = self.DEFAULT_PROTOCOLS[python_type]
meta = StructuredDatasetMetadata(structured_dataset_type=expected.structured_dataset_type)
meta = StructuredDatasetMetadata(structured_dataset_type=expected.structured_dataset_type if expected else None)
sd = StructuredDataset(dataframe=python_val, metadata=meta)
return self.encode(ctx, sd, python_type, protocol, fmt, sdt)

@@ -464,19 +468,22 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
# The literal that we get in might be an old FlyteSchema.
# We'll continue to support this for the time being.
if lv.scalar.schema is not None:
sd = StructuredDataset()
sd_literal = literals.StructuredDataset(
uri=lv.scalar.schema.uri,
metadata=literals.StructuredDatasetMetadata(
# Dataframe will always be serialized to parquet file by FlyteSchema transformer
structured_dataset_type=StructuredDatasetType(format=PARQUET)
),
)
sd._literal_sd = sd_literal
if issubclass(expected_python_type, StructuredDataset):
raise ValueError("We do not support FlyteSchema -> StructuredDataset transformations")
return sd
else:
sd_literal = literals.StructuredDataset(
uri=lv.scalar.schema.uri,
metadata=literals.StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=self.DEFAULT_FORMATS[expected_python_type])
),
)
return self.open_as(ctx, sd_literal, df_type=expected_python_type)

# Either a StructuredDataset type or some dataframe type.
if issubclass(expected_python_type, StructuredDataset):
if inspect.isclass(expected_python_type) and issubclass(expected_python_type, StructuredDataset):
# Just save the literal for now. If in the future we find that we need the StructuredDataset type hint
# type also, we can add it.
sd = expected_python_type(
@@ -490,6 +497,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:

# If the requested type was not a StructuredDataset, then it means it was a plain dataframe type, which means
# we should do the opening/downloading and whatever else it might entail right now. No iteration option here.
if get_origin(expected_python_type) is Annotated:
expected_python_type = get_args(expected_python_type)[0]
return self.open_as(ctx, lv.scalar.structured_dataset, df_type=expected_python_type)
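The two added lines strip the Annotated wrapper before the dataframe is opened. A standalone sketch of that unwrapping, assuming Python 3.9+ for typing.Annotated (older interpreters would import it from typing_extensions, as the module above does), with my_cols as a plain dict standing in for the kwtypes column annotation:

```python
from typing import Annotated, get_args, get_origin

import pandas as pd

my_cols = {"Name": str, "Age": int}
expected_python_type = Annotated[pd.DataFrame, my_cols]

if get_origin(expected_python_type) is Annotated:
    expected_python_type = get_args(expected_python_type)[0]

print(expected_python_type)  # <class 'pandas.core.frame.DataFrame'>
```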

def open_as(self, ctx: FlyteContext, sd: literals.StructuredDataset, df_type: Type[DF]) -> DF:
@@ -542,7 +551,7 @@ def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]
raise ValueError(f"Unrecognized Annotated type for StructuredDataset {t}")

# 2. Fill in columns by checking for StructuredDataset metadata. For example, StructuredDataset[my_cols, parquet]
elif issubclass(t, StructuredDataset):
elif inspect.isclass(t) and issubclass(t, StructuredDataset):
Contributor: remind me again what this inspect.isclass is supposed to catch? Can you add a comment? I keep forgetting.

Member Author: That's for Annotated[pd.DataFrame, my_col]. I just moved expected_python_type = get_args(expected_python_type)[0] to the beginning of to_python_value and to_literal, so we don't need inspect.isclass(t) anymore, and I removed it.
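For reference, the guard existed because Annotated aliases are not classes, so passing one straight to issubclass raises. A tiny illustration, assuming Python 3.9+ for typing.Annotated:

```python
import inspect
from typing import Annotated

import pandas as pd

t = Annotated[pd.DataFrame, {"Name": str}]

print(inspect.isclass(t))  # False -- an Annotated alias is not a class
try:
    issubclass(t, pd.DataFrame)
except TypeError as e:
    print(e)  # issubclass() arg 1 must be a class
```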

for k, v in t.columns().items():
lt = self._get_dataset_column_literal_type(v)
converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt))
Empty file.
19 changes: 19 additions & 0 deletions tests/flytekit/unit/core/hint_handling/a.py
@@ -0,0 +1,19 @@
try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated


class AA(object):
...


my_aa = Annotated[AA, "some annotation"]


def t1(in1: int) -> AA:
return AA()


def t2(in1: int) -> my_aa:
return my_aa()