-
Notifications
You must be signed in to change notification settings - Fork 250
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[pr into #785] Turn structured dataset into dataclass #802
Changes from 18 commits
99931c2
55f2e3e
2ab49a8
3254f37
9f36800
baf9e77
8373e5e
7fd01ff
6fb8602
7701660
8b71946
db66694
6b4f8b4
68daa48
92286bc
8b256b8
b80295d
705d2ce
6ef85d5
dad5e0e
f70e32d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,18 @@ | ||
from __future__ import annotations | ||
|
||
import collections | ||
import inspect | ||
import os | ||
import re | ||
import types | ||
import typing | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass, field | ||
from typing import Dict, Generator, Optional, Type, Union | ||
|
||
from dataclasses_json import config, dataclass_json | ||
from marshmallow import fields | ||
|
||
try: | ||
from typing import Annotated, get_args, get_origin | ||
except ImportError: | ||
|
@@ -37,7 +43,11 @@ | |
PARQUET = "parquet" | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class StructuredDataset(object): | ||
uri: typing.Optional[os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) | ||
file_format: typing.Optional[str] = field(default=PARQUET, metadata=config(mm_field=fields.String())) | ||
""" | ||
This is the user facing StructuredDataset class. Please don't confuse it with the literals.StructuredDataset | ||
class (that is just a model, a Python class representation of the protobuf). | ||
|
@@ -93,9 +103,12 @@ def __init__( | |
dataframe: typing.Optional[typing.Any] = None, | ||
uri: Optional[str] = None, | ||
metadata: typing.Optional[literals.StructuredDatasetMetadata] = None, | ||
**kwargs, | ||
): | ||
self._dataframe = dataframe | ||
self._uri = uri | ||
# Make these fields public, so that the dataclass transformer can set a value for it | ||
# https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298 | ||
self.uri = uri | ||
# This is a special attribute that indicates if the data was either downloaded or uploaded | ||
self._metadata = metadata | ||
# This is not for users to set, the transformer will set this. | ||
|
@@ -107,18 +120,6 @@ def __init__( | |
def dataframe(self) -> Type[typing.Any]: | ||
return self._dataframe | ||
|
||
@property | ||
def uri(self) -> Optional[str]: | ||
return self._uri | ||
|
||
@uri.setter | ||
def uri(self, uri: str): | ||
self._uri = uri | ||
|
||
@property | ||
def file_format(self) -> str: | ||
return self.FILE_FORMAT | ||
|
||
@property | ||
def metadata(self) -> Optional[StructuredDatasetMetadata]: | ||
return self._metadata | ||
|
@@ -192,6 +193,11 @@ def encode( | |
|
||
:param ctx: | ||
:param structured_dataset: This is a StructuredDataset wrapper object. See more info above. | ||
:param structured_dataset_type: This the StructuredDatasetType, as found in the LiteralType of the interface | ||
of the task that invoked this encoding call. It is passed along to encoders so that authors of encoders | ||
can include it in the returned literals.StructuredDataset. See the IDL for more information on why this | ||
literal in particular carries the type information along with it. If the encoder doesn't supply it, it will | ||
also be filled in after the encoder runs by the transformer engine. | ||
:return: This function should return a StructuredDataset literal object. Do not confuse this with the | ||
StructuredDataset wrapper class used as input to this function - that is the user facing Python class. | ||
This function needs to return the IDL StructuredDataset. | ||
|
@@ -306,11 +312,9 @@ def _finder(self, handler_map, df_type: Type, protocol: str, format: str): | |
raise ValueError(f"Failed to find a handler for {df_type}, protocol {protocol}, fmt {format}") | ||
|
||
def get_encoder(self, df_type: Type, protocol: str, format: str): | ||
protocol.replace("://", "") | ||
return self._finder(self.ENCODERS, df_type, protocol, format) | ||
|
||
def get_decoder(self, df_type: Type, protocol: str, format: str): | ||
protocol.replace("://", "") | ||
return self._finder(self.DECODERS, df_type, protocol, format) | ||
|
||
def _handler_finder(self, h: Handlers) -> Dict[str, Handlers]: | ||
|
@@ -365,21 +369,21 @@ def to_literal( | |
) -> Literal: | ||
# Make a copy in case we need to hand off to encoders, since we can't be sure of mutations. | ||
# Check first to see if it's even an SD type. For backwards compatibility, we may be getting a | ||
if get_origin(python_type) is Annotated: | ||
python_type = get_args(python_type)[0] | ||
sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None)) | ||
if expected.structured_dataset_type: | ||
|
||
if expected and expected.structured_dataset_type: | ||
sdt = StructuredDatasetType( | ||
columns=expected.structured_dataset_type.columns, | ||
format=expected.structured_dataset_type.format, | ||
external_schema_type=expected.structured_dataset_type.external_schema_type, | ||
external_schema_bytes=expected.structured_dataset_type.external_schema_bytes, | ||
) | ||
|
||
if get_origin(python_type) is Annotated: | ||
python_type = get_args(python_type)[0] | ||
|
||
# If the type signature has the StructuredDataset class, it will, or at least should, also be a | ||
# StructuredDataset instance. | ||
if issubclass(python_type, StructuredDataset): | ||
if inspect.isclass(python_type) and issubclass(python_type, StructuredDataset): | ||
assert isinstance(python_val, StructuredDataset) | ||
# There are three cases that we need to take care of here. | ||
|
||
|
@@ -431,7 +435,7 @@ def to_literal( | |
# Otherwise assume it's a dataframe instance. Wrap it with some defaults | ||
fmt = self.DEFAULT_FORMATS[python_type] | ||
protocol = self.DEFAULT_PROTOCOLS[python_type] | ||
meta = StructuredDatasetMetadata(structured_dataset_type=expected.structured_dataset_type) | ||
meta = StructuredDatasetMetadata(structured_dataset_type=expected.structured_dataset_type if expected else None) | ||
sd = StructuredDataset(dataframe=python_val, metadata=meta) | ||
return self.encode(ctx, sd, python_type, protocol, fmt, sdt) | ||
|
||
|
@@ -464,19 +468,22 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: | |
# The literal that we get in might be an old FlyteSchema. | ||
# We'll continue to support this for the time being. | ||
if lv.scalar.schema is not None: | ||
sd = StructuredDataset() | ||
sd_literal = literals.StructuredDataset( | ||
uri=lv.scalar.schema.uri, | ||
metadata=literals.StructuredDatasetMetadata( | ||
# Dataframe will always be serialized to parquet file by FlyteSchema transformer | ||
structured_dataset_type=StructuredDatasetType(format=PARQUET) | ||
), | ||
) | ||
sd._literal_sd = sd_literal | ||
if issubclass(expected_python_type, StructuredDataset): | ||
raise ValueError("We do not support FlyteSchema -> StructuredDataset transformations") | ||
return sd | ||
else: | ||
sd_literal = literals.StructuredDataset( | ||
uri=lv.scalar.schema.uri, | ||
metadata=literals.StructuredDatasetMetadata( | ||
structured_dataset_type=StructuredDatasetType(format=self.DEFAULT_FORMATS[expected_python_type]) | ||
), | ||
) | ||
return self.open_as(ctx, sd_literal, df_type=expected_python_type) | ||
|
||
# Either a StructuredDataset type or some dataframe type. | ||
if issubclass(expected_python_type, StructuredDataset): | ||
if inspect.isclass(expected_python_type) and issubclass(expected_python_type, StructuredDataset): | ||
# Just save the literal for now. If in the future we find that we need the StructuredDataset type hint | ||
# type also, we can add it. | ||
sd = expected_python_type( | ||
|
@@ -490,6 +497,8 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: | |
|
||
# If the requested type was not a StructuredDataset, then it means it was a plain dataframe type, which means | ||
# we should do the opening/downloading and whatever else it might entail right now. No iteration option here. | ||
if get_origin(expected_python_type) is Annotated: | ||
expected_python_type = get_args(expected_python_type)[0] | ||
return self.open_as(ctx, lv.scalar.structured_dataset, df_type=expected_python_type) | ||
|
||
def open_as(self, ctx: FlyteContext, sd: literals.StructuredDataset, df_type: Type[DF]) -> DF: | ||
|
@@ -542,7 +551,7 @@ def _get_dataset_type(self, t: typing.Union[Type[StructuredDataset], typing.Any] | |
raise ValueError(f"Unrecognized Annotated type for StructuredDataset {t}") | ||
|
||
# 2. Fill in columns by checking for StructuredDataset metadata. For example, StructuredDataset[my_cols, parquet] | ||
elif issubclass(t, StructuredDataset): | ||
elif inspect.isclass(t) and issubclass(t, StructuredDataset): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remind me again what this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's for |
||
for k, v in t.columns().items(): | ||
lt = self._get_dataset_column_literal_type(v) | ||
converted_cols.append(StructuredDatasetType.DatasetColumn(name=k, literal_type=lt)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
try: | ||
from typing import Annotated | ||
except ImportError: | ||
from typing_extensions import Annotated | ||
|
||
|
||
class AA(object): | ||
... | ||
|
||
|
||
my_aa = Annotated[AA, "some annotation"] | ||
|
||
|
||
def t1(in1: int) -> AA: | ||
return AA() | ||
|
||
|
||
def t2(in1: int) -> my_aa: | ||
return my_aa() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this work for Python 3.7–3.10?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I've tested it with Python 3.7–3.10.