Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
Merge 4649cf9 into bb5d4cd
Browse files Browse the repository at this point in the history
  • Loading branch information
hamersaw committed Mar 31, 2023
2 parents bb5d4cd + 4649cf9 commit 121275b
Show file tree
Hide file tree
Showing 10 changed files with 1,728 additions and 314 deletions.
367 changes: 367 additions & 0 deletions clients/go/coreutils/casting.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,367 @@
package coreutils

import (
"strings"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core"
structpb "github.com/golang/protobuf/ptypes/struct"
)

type typeChecker interface {
CastsFrom(*core.LiteralType) bool
}

type trivialChecker struct {
literalType *core.LiteralType
}

// CastsFrom is a trivial type checker merely checks if types match exactly.
func (t trivialChecker) CastsFrom(upstreamType *core.LiteralType) bool {
// If upstream is an enum, it can be consumed as a string downstream
if upstreamType.GetEnumType() != nil {
if t.literalType.GetSimple() == core.SimpleType_STRING {
return true
}
}
// If t is an enum, it can be created from a string as Enums as just constrained String aliases
if t.literalType.GetEnumType() != nil {
if upstreamType.GetSimple() == core.SimpleType_STRING {
return true
}
}

if GetTagForType(upstreamType) != "" && GetTagForType(t.literalType) != GetTagForType(upstreamType) {
return false
}

// Ignore metadata when comparing types.
upstreamTypeCopy := *upstreamType
downstreamTypeCopy := *t.literalType
upstreamTypeCopy.Structure = &core.TypeStructure{}
downstreamTypeCopy.Structure = &core.TypeStructure{}
upstreamTypeCopy.Metadata = &structpb.Struct{}
downstreamTypeCopy.Metadata = &structpb.Struct{}
upstreamTypeCopy.Annotation = &core.TypeAnnotation{}
downstreamTypeCopy.Annotation = &core.TypeAnnotation{}
return upstreamTypeCopy.String() == downstreamTypeCopy.String()
}

type noneTypeChecker struct{}

// CastsFrom matches only void
func (t noneTypeChecker) CastsFrom(upstreamType *core.LiteralType) bool {
return isNoneType(upstreamType)

Check warning on line 53 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L52-L53

Added lines #L52 - L53 were not covered by tests
}

type mapTypeChecker struct {
literalType *core.LiteralType
}

// CastsFrom checks that the target map type can be cast to the current map type. We need to ensure both the key types
// and value types match.
func (t mapTypeChecker) CastsFrom(upstreamType *core.LiteralType) bool {
// Empty maps should match any collection.
mapLiteralType := upstreamType.GetMapValueType()
if isNoneType(mapLiteralType) {
return true
} else if mapLiteralType != nil {
return getTypeChecker(t.literalType.GetMapValueType()).CastsFrom(mapLiteralType)
}

return false
}

type collectionTypeChecker struct {
literalType *core.LiteralType
}

// CastsFrom checks whether two collection types match. We need to ensure that the nesting is correct and the final
// subtypes match.
func (t collectionTypeChecker) CastsFrom(upstreamType *core.LiteralType) bool {
// Empty collections should match any collection.
collectionType := upstreamType.GetCollectionType()
if isNoneType(upstreamType.GetCollectionType()) {
return true
} else if collectionType != nil {
return getTypeChecker(t.literalType.GetCollectionType()).CastsFrom(collectionType)
}

return false
}

type schemaTypeChecker struct {
literalType *core.LiteralType
}

// CastsFrom handles type casting to the underlying schema type.
// Schemas are more complex types in the Flyte ecosystem. A schema is considered castable in the following
// cases.
//
// 1. The downstream schema has no column types specified. In such a case, it accepts all schema input since it is
// generic.
//
// 2. The downstream schema has a subset of the upstream columns and they match perfectly.
//
// 3. The upstream type can be Schema type or structured dataset type
func (t schemaTypeChecker) CastsFrom(upstreamType *core.LiteralType) bool {
schemaType := upstreamType.GetSchema()
structuredDatasetType := upstreamType.GetStructuredDatasetType()
if structuredDatasetType == nil && schemaType == nil {
return false
}

if schemaType != nil {
return schemaCastFromSchema(schemaType, t.literalType.GetSchema())
}

// Flyte Schema can only be serialized to parquet
if len(structuredDatasetType.Format) != 0 && !strings.EqualFold(structuredDatasetType.Format, "parquet") {
return false

Check warning on line 119 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L119

Added line #L119 was not covered by tests
}

return schemaCastFromStructuredDataset(structuredDatasetType, t.literalType.GetSchema())
}

type structuredDatasetChecker struct {
literalType *core.LiteralType
}

// CastsFrom for Structured dataset are more complex types in the Flyte ecosystem. A structured dataset is considered
// castable in the following cases:
//
// 1. The downstream structured dataset has no column types specified. In such a case, it accepts all structured dataset input since it is
// generic.
//
// 2. The downstream structured dataset has a subset of the upstream structured dataset columns and they match perfectly.
//
// 3. The upstream type can be Schema type or structured dataset type
func (t structuredDatasetChecker) CastsFrom(upstreamType *core.LiteralType) bool {
// structured datasets are nullable
if isNoneType(upstreamType) {
return true
}
structuredDatasetType := upstreamType.GetStructuredDatasetType()
schemaType := upstreamType.GetSchema()
if structuredDatasetType == nil && schemaType == nil {
return false

Check warning on line 146 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L146

Added line #L146 was not covered by tests
}
if schemaType != nil {
// Flyte Schema can only be serialized to parquet
format := t.literalType.GetStructuredDatasetType().Format
if len(format) != 0 && !strings.EqualFold(format, "parquet") {
return false

Check warning on line 152 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L152

Added line #L152 was not covered by tests
}
return structuredDatasetCastFromSchema(schemaType, t.literalType.GetStructuredDatasetType())
}
return structuredDatasetCastFromStructuredDataset(structuredDatasetType, t.literalType.GetStructuredDatasetType())
}

// Upstream (schema) -> downstream (schema)
func schemaCastFromSchema(upstream *core.SchemaType, downstream *core.SchemaType) bool {
if len(upstream.Columns) == 0 || len(downstream.Columns) == 0 {
return true
}

nameToTypeMap := make(map[string]core.SchemaType_SchemaColumn_SchemaColumnType)
for _, column := range upstream.Columns {
nameToTypeMap[column.Name] = column.Type
}

// Check that the downstream schema is a strict sub-set of the upstream schema.
for _, column := range downstream.Columns {
upstreamType, ok := nameToTypeMap[column.Name]
if !ok {
return false
}
if upstreamType != column.Type {
return false
}
}
return true
}

type unionTypeChecker struct {
literalType *core.LiteralType
}

func (t unionTypeChecker) CastsFrom(upstreamType *core.LiteralType) bool {
unionType := t.literalType.GetUnionType()

upstreamUnionType := upstreamType.GetUnionType()
if upstreamUnionType != nil {
// For each upstream variant we must find a compatible downstream variant
for _, u := range upstreamUnionType.GetVariants() {
found := false
for _, d := range unionType.GetVariants() {
if AreTypesCastable(u, d) {
found = true
break
}
}
if !found {
return false
}
}

return true
}

// Matches iff we can unambiguously select a variant
foundOne := false
for _, x := range unionType.GetVariants() {
if AreTypesCastable(upstreamType, x) {
if foundOne {
return false
}
foundOne = true
}
}

return foundOne
}

// Upstream (structuredDatasetType) -> downstream (structuredDatasetType)
func structuredDatasetCastFromStructuredDataset(upstream *core.StructuredDatasetType, downstream *core.StructuredDatasetType) bool {
// Skip the format check here when format is empty. https://github.com/flyteorg/flyte/issues/2864
if len(upstream.Format) != 0 && len(downstream.Format) != 0 && !strings.EqualFold(upstream.Format, downstream.Format) {
return false

Check warning on line 227 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L227

Added line #L227 was not covered by tests
}

if len(upstream.Columns) == 0 || len(downstream.Columns) == 0 {
return true
}

nameToTypeMap := make(map[string]*core.LiteralType)
for _, column := range upstream.Columns {
nameToTypeMap[column.Name] = column.LiteralType
}

// Check that the downstream structured dataset is a strict sub-set of the upstream structured dataset.
for _, column := range downstream.Columns {
upstreamType, ok := nameToTypeMap[column.Name]
if !ok {
return false
}
if !getTypeChecker(column.LiteralType).CastsFrom(upstreamType) {
return false
}
}
return true
}

// Upstream (schemaType) -> downstream (structuredDatasetType)
func structuredDatasetCastFromSchema(upstream *core.SchemaType, downstream *core.StructuredDatasetType) bool {
if len(upstream.Columns) == 0 || len(downstream.Columns) == 0 {
return true
}
nameToTypeMap := make(map[string]core.SchemaType_SchemaColumn_SchemaColumnType)
for _, column := range upstream.Columns {
nameToTypeMap[column.Name] = column.GetType()
}

// Check that the downstream structuredDataset is a strict sub-set of the upstream schema.
for _, column := range downstream.Columns {
upstreamType, ok := nameToTypeMap[column.Name]
if !ok {
return false

Check warning on line 266 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L266

Added line #L266 was not covered by tests
}
if !schemaTypeIsMatchStructuredDatasetType(upstreamType, column.LiteralType.GetSimple()) {
return false
}
}
return true
}

// Upstream (structuredDatasetType) -> downstream (schemaType)
func schemaCastFromStructuredDataset(upstream *core.StructuredDatasetType, downstream *core.SchemaType) bool {
if len(upstream.Columns) == 0 || len(downstream.Columns) == 0 {
return true
}
nameToTypeMap := make(map[string]core.SimpleType)
for _, column := range upstream.Columns {
nameToTypeMap[column.Name] = column.LiteralType.GetSimple()
}

// Check that the downstream schema is a strict sub-set of the upstream structuredDataset.
for _, column := range downstream.Columns {
upstreamType, ok := nameToTypeMap[column.Name]
if !ok {
return false

Check warning on line 289 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L289

Added line #L289 was not covered by tests
}
if !schemaTypeIsMatchStructuredDatasetType(column.GetType(), upstreamType) {
return false

Check warning on line 292 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L292

Added line #L292 was not covered by tests
}
}
return true
}

func schemaTypeIsMatchStructuredDatasetType(schemaType core.SchemaType_SchemaColumn_SchemaColumnType, structuredDatasetType core.SimpleType) bool {
switch schemaType {
case core.SchemaType_SchemaColumn_INTEGER:
return structuredDatasetType == core.SimpleType_INTEGER
case core.SchemaType_SchemaColumn_FLOAT:
return structuredDatasetType == core.SimpleType_FLOAT
case core.SchemaType_SchemaColumn_STRING:
return structuredDatasetType == core.SimpleType_STRING
case core.SchemaType_SchemaColumn_BOOLEAN:
return structuredDatasetType == core.SimpleType_BOOLEAN
case core.SchemaType_SchemaColumn_DATETIME:
return structuredDatasetType == core.SimpleType_DATETIME
case core.SchemaType_SchemaColumn_DURATION:
return structuredDatasetType == core.SimpleType_DURATION

Check warning on line 311 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L302-L311

Added lines #L302 - L311 were not covered by tests
}
return false

Check warning on line 313 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L313

Added line #L313 was not covered by tests
}

func isNoneType(t *core.LiteralType) bool {
switch t.GetType().(type) {
case *core.LiteralType_Simple:
return t.GetSimple() == core.SimpleType_NONE
default:
return false
}
}

func getTypeChecker(t *core.LiteralType) typeChecker {
switch t.GetType().(type) {
case *core.LiteralType_CollectionType:
return collectionTypeChecker{
literalType: t,
}
case *core.LiteralType_MapValueType:
return mapTypeChecker{
literalType: t,
}
case *core.LiteralType_Schema:
return schemaTypeChecker{
literalType: t,
}
case *core.LiteralType_UnionType:
return unionTypeChecker{
literalType: t,
}
case *core.LiteralType_StructuredDatasetType:
return structuredDatasetChecker{
literalType: t,
}
default:
if isNoneType(t) {
return noneTypeChecker{}

Check warning on line 349 in clients/go/coreutils/casting.go

View check run for this annotation

Codecov / codecov/patch

clients/go/coreutils/casting.go#L349

Added line #L349 was not covered by tests
}

return trivialChecker{
literalType: t,
}
}
}

func AreTypesCastable(upstreamType, downstreamType *core.LiteralType) bool {
return getTypeChecker(downstreamType).CastsFrom(upstreamType)
}

func GetTagForType(x *core.LiteralType) string {
if x.GetStructure() == nil {
return ""
}
return x.GetStructure().GetTag()
}
Loading

0 comments on commit 121275b

Please sign in to comment.