Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote client failed to fetch FlytePickle object #764

Merged
merged 3 commits into from
Dec 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,17 @@ def __init__(
# Not exposing this as a property for now.
self._entrypoint_settings = entrypoint_settings

raw_output_data_prefix = auth_config.RAW_OUTPUT_DATA_PREFIX.get() or os.path.join(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this up, above the preceding line? We should be able to do something like:

if not file_access:
    ...  # all the code you have (with the defaults from the config file)
# ...build()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, sorry. my bad

sdk_config.LOCAL_SANDBOX.get(), "control_plane_raw"
)
self._file_access = file_access or FileAccessProvider(
local_sandbox_dir=os.path.join(sdk_config.LOCAL_SANDBOX.get(), "control_plane_metadata"),
raw_output_prefix=raw_output_data_prefix,
)
# Save the file access object locally, but also make it available for use from the context.
FlyteContextManager.with_context(FlyteContextManager.current_context().with_file_access(file_access).build())
self._file_access = file_access
FlyteContextManager.with_context(
FlyteContextManager.current_context().with_file_access(self._file_access).build()
)

# TODO: Reconsider whether we want this. Probably best to not cache.
self._serialized_entity_cache = OrderedDict()
Expand Down
7 changes: 6 additions & 1 deletion flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types.pickle.pickle import FlytePickleTransformer


def noop():
Expand Down Expand Up @@ -348,7 +349,11 @@ def _downloader():
return ff

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
    """
    Reverse-map a literal type to the matching ``FlyteFile`` type.

    Only single-dimensional blob literal types are handled. Blobs whose format
    equals ``FlytePickleTransformer.PYTHON_PICKLE_FORMAT`` are deliberately
    rejected here so that they are not guessed as ``FlyteFile``.

    :param literal_type: The literal type to reverse.
    :return: ``FlyteFile`` parameterized with the blob's format string.
    :raises ValueError: If ``literal_type`` has no blob, is not a single blob,
        or carries the pickle format.
    """
    if (
        literal_type.blob is not None
        and literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE
        # Pickle-formatted blobs must not be claimed by the file transformer.
        and literal_type.blob.format != FlytePickleTransformer.PYTHON_PICKLE_FORMAT
    ):
        return FlyteFile.__class_getitem__(literal_type.blob.format)

    raise ValueError(f"Transformer {self} cannot reverse {literal_type}")
Expand Down
10 changes: 10 additions & 0 deletions flytekit/types/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
ctx.file_access.put_data(uri, remote_path, is_multipart=False)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlytePickle[typing.Any]]:
    """
    Reverse-map a literal type to ``FlytePickle``.

    :param literal_type: The literal type to reverse.
    :return: ``FlytePickle`` when ``literal_type`` is a single-dimensional blob
        whose format is ``FlytePickleTransformer.PYTHON_PICKLE_FORMAT``.
    :raises ValueError: For any other literal type.
    """
    blob_type = literal_type.blob
    is_single_pickle_blob = (
        blob_type is not None
        and blob_type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
        and blob_type.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
    )
    if not is_single_pickle_blob:
        raise ValueError(f"Transformer {self} cannot reverse {literal_type}")

    return FlytePickle

def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
Expand Down
27 changes: 27 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,14 @@ def test_guessing_basic():
pt = TypeEngine.guess_python_type(lt)
assert pt is None

lt = model_types.LiteralType(
blob=BlobType(
format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=BlobType.BlobDimensionality.SINGLE
)
)
pt = TypeEngine.guess_python_type(lt)
assert pt is FlytePickle


def test_guessing_containers():
b = model_types.LiteralType(simple=model_types.SimpleType.BOOLEAN)
Expand Down Expand Up @@ -552,6 +560,25 @@ def test_enum_type():
TypeEngine.to_literal_type(UnsupportedEnumValues)


def test_pickle_type():
    """Round-trip a plain Python object through the FlytePickle transformer."""

    class Foo(object):
        def __init__(self, number: int):
            self.number = number

    # The literal type for FlytePickle should be a single blob in pickle format.
    literal_type = TypeEngine.to_literal_type(FlytePickle)
    blob_type = literal_type.blob
    assert blob_type.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT
    assert blob_type.dimensionality == BlobType.BlobDimensionality.SINGLE

    # Serializing an instance should produce a blob URI under the expected path.
    ctx = FlyteContextManager.current_context()
    literal = TypeEngine.to_literal(ctx, Foo(1), FlytePickle, literal_type)
    assert "/tmp/flyte/" in literal.scalar.blob.uri

    # Guess the Python type back from the literal type and deserialize with it.
    pickle_transformer = FlytePickleTransformer()
    guessed_type = pickle_transformer.guess_python_type(literal_type)
    restored = pickle_transformer.to_python_value(ctx, literal, expected_python_type=guessed_type)
    assert Foo(1).number == restored.number


def test_enum_in_dataclass():
@dataclass_json
@dataclass
Expand Down