Skip to content

Commit

Permalink
pickle dictionary when it isn't JSON serializable (#2390)
Browse files Browse the repository at this point in the history
* pickle dict when it isn't JSON serializable

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* lint

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fix circular import and add dick pickling to boto agent

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* make dict optional

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* import pickle

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* incorporate suggestion by @pingsutw to fix metadata passage to to_python_value method

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* lint

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* lint

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* replace literal with literalmap

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fix test

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* fix boto test

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* incorporate @pingsutw's suggestions

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* revert outputs and fix lint

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* lint

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* lint

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

* update boto agent test

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>

---------

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
  • Loading branch information
samhita-alla committed May 9, 2024
1 parent 94a48ae commit e6e08f9
Show file tree
Hide file tree
Showing 7 changed files with 412 additions and 98 deletions.
8 changes: 6 additions & 2 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
break

# If the current value is a dataclass, resolve the dataclass with the remaining path
if type(curr_val.value) is _literals_models.Scalar and type(curr_val.value.value) is _struct.Struct:
if (
len(p.attr_path) > 0
and type(curr_val.value) is _literals_models.Scalar
and type(curr_val.value.value) is _struct.Struct
):
st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
Expand Down Expand Up @@ -729,7 +733,7 @@ def binding_data_from_python_std(
lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type)
return _literals_models.BindingData(scalar=lit.scalar)
else:
_, v_type = DictTransformer.get_dict_types(t_value_type)
_, v_type = DictTransformer.extract_types_or_metadata(t_value_type)
m = _literals_models.BindingDataMap(
bindings={
k: binding_data_from_python_std(
Expand Down
89 changes: 72 additions & 17 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import textwrap
import typing
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Type, cast

Expand Down Expand Up @@ -713,7 +714,7 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
return list(map(lambda x: self._fix_val_int(ListTransformer.get_sub_type(t), x), val))

if isinstance(val, dict):
ktype, vtype = DictTransformer.get_dict_types(t)
ktype, vtype = DictTransformer.extract_types_or_metadata(t)
# Handle nested Dict. e.g. {1: {2: 3}, 4: {5: 6}})
return {
self._fix_val_int(cast(type, ktype), k): self._fix_val_int(cast(type, vtype), v) for k, v in val.items()
Expand Down Expand Up @@ -1660,13 +1661,10 @@ class DictTransformer(TypeTransformer[dict]):
"""

def __init__(self):
super().__init__("Typed Dict", dict)
super().__init__("Python Dictionary", dict)

@staticmethod
def get_dict_types(t: Optional[Type[dict]]) -> typing.Tuple[Optional[type], Optional[type]]:
"""
Return the generic Type T of the Dict
"""
def extract_types_or_metadata(t: Optional[Type[dict]]) -> typing.Tuple:
_origin = get_origin(t)
_args = get_args(t)
if _origin is not None:
Expand All @@ -1679,22 +1677,60 @@ def get_dict_types(t: Optional[Type[dict]]) -> typing.Tuple[Optional[type], Opti
raise ValueError(
f"Flytekit does not currently have support for FlyteAnnotations applied to dicts. {t} cannot be parsed."
)
if _origin is dict and _args is not None:
if _origin in [dict, Annotated] and _args is not None:
return _args # type: ignore
return None, None

@staticmethod
def dict_to_generic_literal(v: dict) -> Literal:
def dict_to_generic_literal(v: dict, allow_pickle: bool) -> Literal:
"""
Creates a flyte-specific ``Literal`` value from a native python dictionary.
"""
return Literal(scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())))
from flytekit.types.pickle import FlytePickle

try:
return Literal(
scalar=Scalar(generic=_json_format.Parse(json.dumps(v), _struct.Struct())),
metadata={"format": "json"},
)
except TypeError as e:
if allow_pickle:
remote_path = FlytePickle.to_pickle(v)
return Literal(
scalar=Scalar(
generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct())
),
metadata={"format": "pickle"},
)
raise e

@staticmethod
def is_pickle(python_type: Type[dict]) -> typing.Tuple[bool, Type]:
base_type, *metadata = DictTransformer.extract_types_or_metadata(python_type)

for each_metadata in metadata:
if isinstance(each_metadata, OrderedDict):
allow_pickle = each_metadata.get("allow_pickle", False)
return allow_pickle, base_type

return False, base_type

@staticmethod
def dict_types(python_type: Type) -> typing.Tuple[typing.Any, ...]:
if get_origin(python_type) is Annotated:
base_type, *_ = DictTransformer.extract_types_or_metadata(python_type)
tp = get_args(base_type)
else:
tp = DictTransformer.extract_types_or_metadata(python_type)

return tp

def get_literal_type(self, t: Type[dict]) -> LiteralType:
"""
Transforms a native python dictionary to a flyte-specific ``LiteralType``
"""
tp = self.get_dict_types(t)
tp = self.dict_types(t)

if tp:
if tp[0] == str:
try:
Expand All @@ -1710,21 +1746,33 @@ def to_literal(
if type(python_val) != dict:
raise TypeTransformerFailedError("Expected a dict")

allow_pickle = False
base_type = None

if get_origin(python_type) is Annotated:
allow_pickle, base_type = DictTransformer.is_pickle(python_type)

if expected and expected.simple and expected.simple == SimpleType.STRUCT:
return self.dict_to_generic_literal(python_val)
return self.dict_to_generic_literal(python_val, allow_pickle)

lit_map = {}
for k, v in python_val.items():
if type(k) != str:
raise ValueError("Flyte MapType expects all keys to be strings")
# TODO: log a warning for Annotated objects that contain HashMethod
k_type, v_type = self.get_dict_types(python_type)

if base_type:
_, v_type = get_args(base_type)
else:
_, v_type = self.extract_types_or_metadata(python_type)

lit_map[k] = TypeEngine.to_literal(ctx, v, cast(type, v_type), expected.map_value_type)
return Literal(map=LiteralMap(literals=lit_map))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict:
if lv and lv.map and lv.map.literals is not None:
tp = self.get_dict_types(expected_python_type)
tp = self.dict_types(expected_python_type)

if tp is None or tp[0] is None:
raise TypeError(
"TypeMismatch: Cannot convert to python dictionary from Flyte Literal Dictionary as the given "
Expand All @@ -1741,10 +1789,17 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
# for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict
# evaluates to false
if lv and lv.scalar and lv.scalar.generic is not None:
try:
return json.loads(_json_format.MessageToJson(lv.scalar.generic))
except TypeError:
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
if lv.metadata["format"] == "json":
try:
return json.loads(_json_format.MessageToJson(lv.scalar.generic))
except TypeError:
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")
elif lv.metadata["format"] == "pickle":
from flytekit.types.pickle import FlytePickle

uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file")
return FlytePickle.from_pickle(uri)

raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

def guess_python_type(self, literal_type: LiteralType) -> Union[Type[dict], typing.Dict[Type, Type]]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Optional

from flyteidl.core.execution_pb2 import TaskExecution
from typing_extensions import Annotated

from flytekit import FlyteContextManager, kwtypes
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import (
AgentRegistry,
Resource,
Expand Down Expand Up @@ -54,9 +57,19 @@ async def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = N
inputs=inputs,
)

outputs = None
outputs = {"result": {"result": None}}
if result:
outputs = {"result": result}
ctx = FlyteContextManager.current_context()
outputs = LiteralMap(
literals={
"result": TypeEngine.to_literal(
ctx,
result,
Annotated[dict, kwtypes(allow_pickle=True)],
TypeEngine.to_literal_type(dict),
)
}
)

return Resource(phase=TaskExecution.SUCCEEDED, outputs=outputs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def __init__(
name=name,
task_config=task_config,
task_type=self._TASK_TYPE,
interface=Interface(inputs=inputs, outputs=kwtypes(result=dict)),
interface=Interface(
inputs=inputs,
outputs=kwtypes(result=dict),
),
**kwargs,
)

Expand Down
77 changes: 56 additions & 21 deletions plugins/flytekit-aws-sagemaker/tests/test_boto3_agent.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,63 @@
from datetime import timedelta
from datetime import datetime, timedelta
from unittest import mock

import pytest
from flyteidl.core.execution_pb2 import TaskExecution

from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.interaction.string_literals import literal_map_string_repr
from flytekit.interfaces.cli_identifiers import Identifier
from flytekit.models import literals
from flytekit.models.core.identifier import ResourceType
from flytekit.models.task import RuntimeMetadata, TaskMetadata, TaskTemplate


@pytest.mark.asyncio
@pytest.mark.parametrize(
"mock_return_value",
[
(
{
"ResponseMetadata": {
"RequestId": "66f80391-348a-4ee0-9158-508914d16db2",
"HTTPStatusCode": 200.0,
"RetryAttempts": 0.0,
"HTTPHeaders": {
"content-type": "application/x-amz-json-1.1",
"date": "Wed, 31 Jan 2024 16:43:52 GMT",
"x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2",
"content-length": "114",
},
},
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
}
),
(
{
"ResponseMetadata": {
"RequestId": "66f80391-348a-4ee0-9158-508914d16db2",
"HTTPStatusCode": 200.0,
"RetryAttempts": 0.0,
"HTTPHeaders": {
"content-type": "application/x-amz-json-1.1",
"date": "Wed, 31 Jan 2024 16:43:52 GMT",
"x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2",
"content-length": "114",
},
},
"pickle_check": datetime(2024, 5, 5),
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
}
),
(None),
],
)
@mock.patch(
"flytekitplugins.awssagemaker_inference.boto3_agent.Boto3AgentMixin._call",
return_value={
"ResponseMetadata": {
"RequestId": "66f80391-348a-4ee0-9158-508914d16db2",
"HTTPStatusCode": 200.0,
"RetryAttempts": 0.0,
"HTTPHeaders": {
"content-type": "application/x-amz-json-1.1",
"date": "Wed, 31 Jan 2024 16:43:52 GMT",
"x-amzn-requestid": "66f80391-348a-4ee0-9158-508914d16db2",
"content-length": "114",
},
},
"EndpointConfigArn": "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config",
},
)
async def test_agent(mock_boto_call):
async def test_agent(mock_boto_call, mock_return_value):
mock_boto_call.return_value = mock_return_value

agent = AgentRegistry.get_agent("boto")
task_id = Identifier(
resource_type=ResourceType.TASK,
Expand Down Expand Up @@ -88,9 +116,16 @@ async def test_agent(mock_boto_call):
)

resource = await agent.do(task_template, task_inputs)

assert resource.phase == TaskExecution.SUCCEEDED
assert (
resource.outputs["result"]["EndpointConfigArn"]
== "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config"
)

if mock_return_value:
outputs = literal_map_string_repr(resource.outputs)
if "pickle_check" in mock_return_value:
assert "pickle_file" in outputs["result"]
else:
assert (
outputs["result"]["EndpointConfigArn"]
== "arn:aws:sagemaker:us-east-2:000000000:endpoint-config/sagemaker-xgboost-endpoint-config"
)
elif mock_return_value is None:
assert resource.outputs["result"] == {"result": None}
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def test_stable_cache_key():
}
)
key = _calculate_cache_key("task_name_1", "31415", lm)
assert key == "task_name_1-31415-404b45f8556276183621d4bf37f50049"
assert key == "task_name_1-31415-189e755a8f41c006889c291fcaedb4eb"


@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
Expand Down

0 comments on commit e6e08f9

Please sign in to comment.