Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Union Types RFCs #1926

Merged
merged 3 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions rfc/core language/sum-types-2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Sum Types (Unions) - Follow Up

**This is a follow-up to:** [the Sum Types RFC](https://github.com/maximsmol/flyte/blob/master/rfc/core%20language/sum-types.md)

# Executive Summary

Some questions on the previously proposed implementations made it clear that a deeper investigation into possible alternatives was required. I consider 3 programming languages with fundamentally different union type implementations in the core language or the standard library. I develop a new version of the sum type IDL representation that accomodates all three languages.

# Examples in Programming Languages

## Python

The type-erased case.

- No runtime representation
- The union is a set of types
- Duplicates get collapsed
- Single-type unions collapse to the underlying type
- Order of types does not matter
- Cannot effect runtime behavior as there is no tag
- Does not effect type equality

```py
>>> # No runtime representation
>>> a : t.Union[str, int] = 10
>>> a
10

>>> # Single-type union collapse
>>> t.Union[int, int]
<class 'int'>

>>> # Trivial duplicate collapse
>>> t.Union[int, int, str]
typing.Union[int, str]

>>> # Non-trivial duplicate collapse:
>>> a = t.Union[t.List[t.Union[str, int]], t.List[t.Union[str, int]]]
>>> b = t.Union[t.List[t.Union[str, int]], t.List[t.Union[int, str]]]
>>> a
typing.List[typing.Union[str, int]]
>>> b
typing.List[typing.Union[str, int]]
>>> a == b
True

>>> # Order does not matter:
>>> t.Union[str, int] == t.Union[int, str]
True
```

## Haskell

Algebraic types case.

- Runtime representation carries symbolic tag
- The union is a set of (tag, type) tuples
- Duplicate tags are a compile-time error
- Duplicate types with different tags are allowed
- Order of variants does not matter as the symbolic tag does not depend on its order in the union
- Defining the same union type but with different order of variants in different compilation unions (when compiling separately from linking) works as expected as the symbols referring to the type contain the symbolic tags

```haskell
data Test = Hello Int | World String

-- Runtime value carries tag, can be introspected
case x of
Hello a -> "Found int: " ++ show a
World b -> "Found string: " ++ show b

-- Duplicates are a compile-time error
data Test1 = Hello | Hello
{-|
test.hs:1:21: error:
Multiple declarations of ‘Hello’
Declared at: test.hs:1:13
test.hs:1:21
-}

-- Duplicate types are allowed
data Test2 = Left Int | Right Int
case x of
Left a -> "Found int: " ++ show a
Right b -> "Found a different int: " ++ show b
```

## C++ (std::variant)

In-between case (indexed union case).

- Runtime representation carries a positional tag
- Two available APIs
- One uses positional indexes as the tag
- One uses the types themselves as tags
- The union is a list of types
- Duplicate types cannot be used with the type-indexed API (unless unambiguous) but can be used with the position-indexed API
- Order of variants matters
- Can influence runtime behavior
- Distinct order means distinct types

```cpp
// Runtime representation carries a position tag, showing both APIs
std::variant<int, bool> a = 10;
assert(std::get<int>(a) == 10);
assert(std::get<0>(a) == 10);
a = false;
assert(std::get<bool>(a) == false);
assert(std::get<1>(a) == false);

// Failure cases
std::get<2>(a); // no matching function for call to 'get'
std::get<double>(a); // no matching function for call to 'get'

std::get<0>(a);
/*
terminate called after throwing an instance of 'std::bad_variant_access'
what(): Unexpected index
*/

// Duplicate types are allowed but must use an unambiguous API
std::variant<int, int> b = 10;
/*
no viable conversion from 'int' to
'std::variant<int, int>'
*/

// Unambiguous uses allow both APIs
std::variant<int, int, bool> c = false;

assert(std::get<bool>(c) == false);
assert(std::get<2>(c) == false);

// Ambiguous uses of the API do not work, the index-based API is the never ambiguous
c.emplace<0>(10); // Assignment using the index-based API
std::get<int>(c);
/*
error:
static_assert failed due to requirement
'__detail::__variant::__exactly_once<int, int, int, bool>' "T should occur
for exactly once in alternatives"
static_assert(__detail::__variant::__exactly_once<_Tp, _Types...>,
*/

std::get<0>(c) == 10;

// Order of types matters
if (c.index() == 0)
std::cout << "First integer" << std::endl;
else if (c.index() == 1)
std::cout << "Second integer" << std::endl;
else if (c.index() == 2)
std::cout << "Boolean" << std::endl;

std::variant<int, bool> x = false;
std::variant<bool, int> y = true;

x = y;
/*
error: no viable overloaded '='
x = y;
~ ^ ~
*/
```

# Design considerations

## Tagged vs Untagged

- To properly support languages like Haskell and C++ the backend representation of union types should use tags of some kind.

## Type of Tag

- First-class Haskell support needs string tags
- Integer tags open the possibility of incorrectly (according to language semantics) decoding a union-typed literal from IDL to a Haskell type if the Haskell source code changes to rearrange the variants
- Such a change is a no-op according to language semantics so this is an issue unless it is guaranteed that IDL `LiteralRepr`s are never serialized (or are never reused across task/workflow versions)
- This is also an issue when linking Haskell object files since they remain compatible even when produced from these two different (but equivalent) versions of the source code
- Both of these error cases are unlikely to occur naturally but not supporting them would cause obscure issues for users
- First-class C++ support needs integer tags

Since integers can be stringified, string tags offer first-class support for both languages

## Tags in Python

Since Python's `typing.Union` is untagged, it could be implemented without a tag and even without a `LiteralRepr` for the union values. Here is why this is not a good choice:

- An untagged implementation requires being able to determine whether a given `LiteralRepr` is castable to a given Python type. The current type transformer implementation is not designed as a type validator and may not throw on incompatible types during the transformation at all, or will throw an error indistinguishable from a normal Python programmer-error exception
- Another issue is that in case a custom class was defined (e.g. `MyInt` which is simply a proxy for the default `int`), then the `LiteralRepr` for an integer value would convert into both `MyInt` and `int` without an error so the choice between the two would be ambiguous and the runtime behavior could be influenced since the custom type transformer for `MyInt` can run arbitrary code

We could use an index-based tag, but there are multiple reasons why that is not a good idea either:

- `typing.Union` semantics are that of a set (the order does not matter for equality comparison, duplicates are eliminated), though in practice it the source ordering of the variants can be recovered in cases without duplicates. Note that the CPython implementation actually uses a tuple behind the scenes, so this is more about intent and less about factual behavior of the class. The [PEP](https://www.python.org/dev/peps/pep-0484/#union-types) also specifies that Union is expected to accept a "set" of types, though it is unclear whether this is referring to a specific data structure or just a figure of speech
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a new insight for me. The whole section around colliding types within a Union and the equivalence of Unions of different variant ordering... Much appreciated. I think as long as we do not store the tag in the Literal itself and rather in the BindingInfo (that's computed for a particular task version), we will be fine, right? because as you said, a Literal produced by task1 v1 may be used to call task2 v1 in one case or task2 v2 in another... and it has to be portable... The binding, however, is computed when compiling the workflow closure (where all versions are pinned)

- If we ignore the apparent intent to implement unions as sets, the issue of code refactoring arises again as changing the order of variants in a union should not effect behavior. This goes away if we can guarantee that IDL is never serialized or that the serialized messages are never reused across different task/workflow versions

These problems compound with the requirements necessary for properly supporting Haskell and C++.

# Proposed Implementation

Use a string tag. In Haskell use the symbolic tag. In C++ use the index (serialized to a string) as the tag. In Python use the name of the type transformer (already present on all transformers).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with most of this and I appreciate the deep dive into various implementations in other languages… one goal that I would like to emphasize is that a task written in Java (or C++.. etc.) that accepts a Union should be usable from a workflow written in Python… and vis versa… I could be wrong but I got the sense that there is an assumption of homogeneity between tasks/wfs when it comes to the language used…
This will manifest in how the tag will be filled and subsequently used. We need to be able to consistently populate it across languages as well as consistently consume it in any of the languages…


In Python the correspondence between a tag and the choice of type transformer must be 1-to-1 i.e. type transformer names must be made unique which is already the case but not formalized.

The matching procedure for Haskell and C++ is trivial. In Python, we must deal specially with duplicates. For example:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we do "compile" the workflows on the server side, we will need to be able to validate something like this:

@task
def task2(typing.Union[int, str]) -> int:
  return 5
 
@workflow
def my_wf() -> int:
   a = task1()
   return task2(in=a)

The system will need to know that task2's input accept task1's output (even if task1 is declared in a completely different language)... so the way we represent the type will need to be language agnostic.
Moreover, since the system is able to make this connection, it should also be able to, at compile time, to define the target type binding/choice (i.e. will it be the int or the str)...

The target container's SDK/language will also need to understand that of course and be able to map between flyteidl's representation of the target type choice (e.g. int) to its needed Union's representation (index based tag or otherwise)...


```py
from typing import Union

# MyList is a proxy type similar to MyInt - a no-op wrapper around the native
# Python type but with a custom type transformer (which might have side effects)
def f(x: Union[MyList[Union[MyInt, int]], MyList[int]]):
...
```

In this case, having the `MyList` tag is not enough to disambiguate the choice of variant. We iterate each candidate variant and try to match it recursively with the literal. It is not required that all type transformers can fail gracefully when given a value of any incorrect type, only that they fail on union `LiteralRepr`s and that the union type transformer can recognize its own literals. The only possible difference between how the types are resolved is in the choice of type transformers, which is made completely unambiguous by traversing the type tree with its tags (since the tags resolve the only ambiguity). This procedure thus guarantees a value will be recovered appropriately from a union IDL representation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The proposal is to make this matching a compile-time matching. Based on the LiteralType of the connection (task1 returns foo of type blah, try to match blah against one of the union variants of x)...
At runtime, the SDK, if needed, can use the compile-time matching information to instruct the language of which variant is desired... I say 'if needed' because in Python, as you said, we don't actually need to pass the typing information of the literal...

172 changes: 172 additions & 0 deletions rfc/core language/sum-types.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Sum Types (Unions)

**Authors:**

- @maximsmol

## 1 Executive Summary

Goals:
- Implement support for sum types (also known as union types) in FlyteIDL
- Example: `a : int | str` can assume values `10` and `"hello world"`
- Implement support for new Python types in Flytekit
- `typing.Union`
- `typing.Optional` (which is a special case of `typing.Union`)

Two implementation are considered.
- A tagged literal representation using a new `Literal` message (primary alternative)
```proto
message LiteralSum {
Literal value = 1;
SumType type = 2;
uint64 summand_idx = 3;
}
```
- A type-erased literal representation where existing literals are made castable to acceptable sum types (secondary alternative)
- Example: `10` is made castable to `int | str`, `bool | int | list str`, etc.

## 2 Motivation

Currently any type can take none values ([see this comment in Propeller's sources](https://github.com/flyteorg/flytepropeller/blob/master/pkg/compiler/validators/typing.go#L32)). This creates a few unwanted outcomes:
- The type system does not enforce required workflow parameters as always having a valid value
- Collections are not in practice homogeneous since they may contain none values
- Example: `[1, 2, 3, null]` is a valid `list int` value since `int` is implicitly nullable
- In `Python` the use of `typing.Optional` is a no-op and furthermore not supported by default as there is no type transformer for `typing.Optional` since it would be useless
- Examining the types of workflow parameters is not enough to determine whether the parameter is intended to be optional
- This particular point affects Latch as we generate workflow interfaces based on type information
- Collections are the most troublesome here as it is unclear whether `list int` is intended to take none-values (and thus whether the interface should allow them)

## 3 Proposed Implementation

- Add the following to [`flyteidl/protos/flyteidl/core/types.proto`](https://github.com/flyteorg/flyteidl/blob/master/protos/flyteidl/core/types.proto):
```proto
message SumType {
repeated LiteralType summands = 1;
}
// ...
message LiteralType {
oneof type {
// ...
SumType sum = 8;
}
// ...
}
```
- Add the following to [`flyteidl/protos/flyteidl/core/literals.proto`](https://github.com/flyteorg/flyteidl/blob/master/protos/flyteidl/core/literals.proto):
```proto
message LiteralSum {
Literal value = 1;
SumType type = 2;
uint64 summand_idx = 3;
}
// ...
message Scalar {
oneof value {
// ...
LiteralSum sum = 8;
}
}
```
- Implement a new type checker in [`flytepropeller/pkg/compiler/validators/typing.go`](https://github.com/flyteorg/flytepropeller/blob/master/pkg/compiler/validators/typing.go):
```go
func (t sumChecker) CastsFrom(upstreamType *flyte.LiteralType) bool {
for _, x := range t.literalType.GetSum().GetSummands() {
if getTypeChecker(x).CastsFrom(upstreamType) {
return true;
}
}
return false;
}
```
- Do not implicitly accept none values for other types (potentially breaking change):
- Do not accept other types as Void downstream
- https://github.com/flyteorg/flytepropeller/blob/master/pkg/compiler/validators/typing.go#L59
- Do not accept Void as other types downstream
- In "trivial" types https://github.com/flyteorg/flytepropeller/blob/master/pkg/compiler/validators/typing.go#L33
- In maps https://github.com/flyteorg/flytepropeller/blob/master/pkg/compiler/validators/typing.go#L66
- In collections https://github.com/flyteorg/flytepropeller/blob/master/pkg/compiler/validators/typing.go#L80
- In schemas https://github.com/flyteorg/flytepropeller/blob/master/pkg/compiler/validators/typing.go#L101
- Update the bindings type validation code in [`flytepropeller/pkg/compiler/validators/bindings.go`](https://github.com/flyteorg/flytepropeller/blob/master/pkg/compiler/validators/bindings.go#L14):
```go
func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, binding *flyte.BindingData,
expectedType *flyte.LiteralType, errs errors.CompileErrors) (
resolvedType *flyte.LiteralType, upstreamNodes []c.NodeID, ok bool) {

switch binding.GetValue().(type) {
case *flyte.BindingData_Scalar:
// Goes through SumType-aware AreTypesCastable
break
case *flyte.BindingData_Promise:
// Goes through SumType-aware AreTypesCastable
break
default:
if expectedType.GetSum() != nil {
for _, t := range expectedType.GetSum().GetSummands() {
if resolvedType, nodeIds, ok := validateBinding(w, nodeID, nodeParam, binding, t, errors.NewCompileErrors()); ok {
// there can be no errors otherwise ok = false
return resolvedType, nodeIds, ok
}
}
errs.Collect(errors.NewMismatchingBindingsErr(nodeID, nodeParam, expectedType.String(), binding.GetCollection().String()))
return nil, nil, !errs.HasErrors()
}
}
// ...
}
```
- TODO: It might be necessary to accumulate the errors for each of the summands' failed binding validations to ease debugging. If that is the case, it would be preferrable to ignore errors by default and re-run the verification if no candidate was found to avoid slowing down the non-exceptional case
- The verbosity of the resulting messages would make it very hard to read so only a broad error is collected right now. It is unclear whether the extra complexity in the code and in the output is justified
- Implement a `typing.Union` type transformer in Python FlyteKit:
- `get_literal_type`:
```py
return LiteralType(sum=_type_models.SumType(summands=[TypeEngine.to_literal_type(x) for x in t.__args__]))
```
- `to_literal`
- Iterate through the types in `python_type.__args__` and try `TypeEngine.to_literal` for each. The first succeeding type is accepted
- TODO: this might mean that order of summands matters e.g. `X | Y` is different from `Y | X`
- `to_python_value`
- Use the `TypeTransformer` for the `lv.sum.type` to transform `lv.sum.value`
- `guess_python_type`
- Return `TypeEngine.guess_python_type(lv.sum.type)`
- All `TypeTransformer`s' `to_literal` must be updated to fail with a specific error class so the `typing.Union` transformer can distinguish between user or programmer error and actual failure to convert type
- Update [`flytekit/core/interface.py`](https://github.com/flyteorg/flytekit/blob/master/flytekit/core/interface.py) to support `None` values as parameter defaults
- Check whether the default is present by comparing with `inspect.Parameter.empty` in [`transform_inputs_to_parameters`](https://github.com/flyteorg/flytekit/blob/master/flytekit/core/interface.py#L186)
- Pass `inspect.Parameter.empty` to the interface as is in [`transform_signature_to_interface`](https://github.com/flyteorg/flytekit/blob/master/flytekit/core/interface.py#L283)

## 4 Metrics & Dashboards

None

## 5 Drawbacks

- Projects relying on types being implicitly nullable will be broken by this update since parameter types and return types might need to be changed to optionals
- A feature flag can be used to ease the transition
- It would be nice to estimate the number of projects affected by this

## 6 Alternatives

TODO: discuss the type-erased version

## 7 Potential Impact and Dependencies

See drawbacks

## 8 Unresolved questions

TODO: discuss the type-erased version

## 9 Conclusion

TODO

## 10 RFC Process Guide, remove this section when done

**Checklist:**

- [x] Copy template
- [x] Draft RFC (think of it as a wireframe)
- [ ] Share as WIP with folks you trust to gut-check
- [ ] Send pull request when comfortable
- [ ] Label accordingly
- [ ] Assign reviewers
- [ ] Merge PR