diff --git a/cpp/test/computed_fields_test.cc b/cpp/test/computed_fields_test.cc index eb39a2c0..20713691 100644 --- a/cpp/test/computed_fields_test.cc +++ b/cpp/test/computed_fields_test.cc @@ -119,8 +119,8 @@ TEST(ComputedFieldsTest, SwitchExpression) { r.optional_named_array = {}; ASSERT_EQ(r.OptionalNamedArrayLength(), 0); ASSERT_EQ(r.OptionalNamedArrayLengthWithDiscard(), 0); - static_assert(std::is_same_v); - static_assert(std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v); r.int_float_union = 42; ASSERT_EQ(r.IntFloatUnionAsFloat(), 42.0f); diff --git a/cpp/test/generated/hdf5/protocols.cc b/cpp/test/generated/hdf5/protocols.cc index dc496ce2..e8a06dc6 100644 --- a/cpp/test/generated/hdf5/protocols.cc +++ b/cpp/test/generated/hdf5/protocols.cc @@ -132,7 +132,7 @@ namespace { [[maybe_unused]] H5::EnumType GetSizeBasedEnumHdf5Ddl() { H5::EnumType t(yardl::hdf5::SizeTypeDdl()); - size_t i = 0ULL; + yardl::Size i = 0ULL; t.insert("a", &i); i = 1ULL; t.insert("b", &i); diff --git a/cpp/test/generated/types.h b/cpp/test/generated/types.h index 5711637a..e3b992c9 100644 --- a/cpp/test/generated/types.h +++ b/cpp/test/generated/types.h @@ -88,7 +88,7 @@ struct RecordWithPrimitives { uint32_t uint32_field{}; int64_t int64_field{}; uint64_t uint64_field{}; - size_t size_field{}; + yardl::Size size_field{}; float float32_field{}; double float64_field{}; std::complex complexfloat32_field{}; @@ -362,7 +362,7 @@ enum class Int64Enum : int64_t { kB = -4611686018427387904LL, }; -enum class SizeBasedEnum : size_t { +enum class SizeBasedEnum : yardl::Size { kA = 0ULL, kB = 1ULL, kC = 2ULL, @@ -590,31 +590,31 @@ struct RecordWithComputedFields { return const_cast(std::as_const(*this).AccessVectorOfVectorsField()); } - size_t ArraySize() const { + yardl::Size ArraySize() const { return array_field.size(); } - size_t ArrayXSize() const { + yardl::Size ArrayXSize() const { return array_field.shape(0); } - size_t ArrayYSize() const { + yardl::Size ArrayYSize() const { return array_field.shape(1); } - size_t Array0Size() const { + yardl::Size Array0Size() const { return array_field.shape(0); } - size_t Array1Size() const { + yardl::Size Array1Size() const { return array_field.shape(1); } - size_t ArraySizeFromIntField() const { + yardl::Size ArraySizeFromIntField() const { return array_field.shape(int_field); } - size_t ArraySizeFromStringField() const { + yardl::Size ArraySizeFromStringField() const { return array_field.shape(([](std::string dim_name) { if (dim_name == "x") return 0; if (dim_name == "y") return 1; @@ -622,43 +622,43 @@ struct RecordWithComputedFields { })(string_field)); } - size_t ArraySizeFromNestedIntField() const { + yardl::Size ArraySizeFromNestedIntField() const { return array_field.shape(tuple_field.v1); } - size_t ArrayFieldMapDimensionsXSize() const { + yardl::Size ArrayFieldMapDimensionsXSize() const { return array_field_map_dimensions.shape(0); } - size_t FixedArraySize() const { + yardl::Size FixedArraySize() const { return 12ULL; } - size_t FixedArrayXSize() const { + yardl::Size FixedArrayXSize() const { return 3ULL; } - size_t FixedArray0Size() const { + yardl::Size FixedArray0Size() const { return 3ULL; } - size_t VectorSize() const { + yardl::Size VectorSize() const { return vector_field.size(); } - size_t FixedVectorSize() const { + yardl::Size FixedVectorSize() const { return 3ULL; } - size_t ArrayDimensionXIndex() const { + yardl::Size ArrayDimensionXIndex() const { return 0ULL; } - size_t ArrayDimensionYIndex() const { + yardl::Size ArrayDimensionYIndex() const { return 1ULL; } - size_t ArrayDimensionIndexFromStringField() const { + yardl::Size ArrayDimensionIndexFromStringField() const { return ([](std::string dim_name) { if (dim_name == "x") return 0; if (dim_name == "y") return 1; @@ -666,16 +666,16 @@ struct RecordWithComputedFields { })(string_field); } - size_t ArrayDimensionCount() const { + yardl::Size ArrayDimensionCount() const { return 2ULL; } - size_t DynamicArrayDimensionCount() const { + yardl::Size DynamicArrayDimensionCount() const { return dynamic_array_field.dimension(); } - size_t OptionalNamedArrayLength() const { - return [](auto&& __case_arg__) -> size_t { + yardl::Size OptionalNamedArrayLength() const { + return [](auto&& __case_arg__) -> yardl::Size { if (__case_arg__.has_value()) { test_model::NamedNDArray const& arr = __case_arg__.value(); return arr.size(); @@ -684,8 +684,8 @@ struct RecordWithComputedFields { }(optional_named_array); } - size_t OptionalNamedArrayLengthWithDiscard() const { - return [](auto&& __case_arg__) -> size_t { + yardl::Size OptionalNamedArrayLengthWithDiscard() const { + return [](auto&& __case_arg__) -> yardl::Size { if (__case_arg__.has_value()) { test_model::NamedNDArray const& arr = __case_arg__.value(); return arr.size(); diff --git a/cpp/test/roundtrip_test.cc b/cpp/test/roundtrip_test.cc index c62af120..81729662 100644 --- a/cpp/test/roundtrip_test.cc +++ b/cpp/test/roundtrip_test.cc @@ -48,7 +48,7 @@ TEST_P(RoundTripTests, Scalars) { rec.uint32_field = 55; rec.int64_field = -66; rec.uint64_field = 66; - rec.size_field = sizeof(size_t); + rec.size_field = UINT64_MAX; rec.float32_field = 4290.39; rec.float64_field = 2234290.39; rec.complexfloat32_field = {1.3, 2.2}; @@ -560,7 +560,7 @@ INSTANTIATE_TEST_SUITE_P(, ::testing::Values( Format::kBinary, Format::kHdf5), - [](const ::testing::TestParamInfo& info) { + [](::testing::TestParamInfo const& info) { switch (info.param) { case Format::kBinary: return "Binary"; diff --git a/tooling/internal/cpp/common/common.go b/tooling/internal/cpp/common/common.go index 8966c6ce..2f0ca4c6 100644 --- a/tooling/internal/cpp/common/common.go +++ b/tooling/internal/cpp/common/common.go @@ -311,8 +311,10 @@ func TypeDefinitionSyntax(t dsl.TypeDefinition) string { func PrimitiveSyntax(p dsl.PrimitiveDefinition) string { switch p { - case dsl.Int8, dsl.Uint8, dsl.Int16, dsl.Uint16, dsl.Int32, dsl.Uint32, dsl.Int64, dsl.Uint64, dsl.Size: + case dsl.Int8, dsl.Uint8, dsl.Int16, dsl.Uint16, dsl.Int32, dsl.Uint32, dsl.Int64, dsl.Uint64: return string(p) + "_t" + case dsl.Size: + return "yardl::Size" case dsl.Float32: return "float" case dsl.Float64: diff --git a/tooling/internal/cpp/include/detail/hdf5/ddl.h b/tooling/internal/cpp/include/detail/hdf5/ddl.h index 8ce32c83..8139e946 100644 --- a/tooling/internal/cpp/include/detail/hdf5/ddl.h +++ b/tooling/internal/cpp/include/detail/hdf5/ddl.h @@ -24,7 +24,7 @@ namespace yardl::hdf5 { /** - * @brief Returns the HDF5 type for size_t. + * @brief Returns the HDF5 type for yardl::Size. */ static inline H5::PredType const& SizeTypeDdl() { static_assert(sizeof(hsize_t) == sizeof(size_t)); diff --git a/tooling/internal/cpp/include/yardl.h b/tooling/internal/cpp/include/yardl.h index 15814bef..98b62e7f 100644 --- a/tooling/internal/cpp/include/yardl.h +++ b/tooling/internal/cpp/include/yardl.h @@ -69,4 +69,9 @@ using Time = std::chrono::duration; using DateTime = std::chrono::time_point>; +/** + * @brief The same as size_t when it is 64 bits, otherwise uint64_t. + */ +using Size = std::conditional_t; + } // namespace yardl diff --git a/tooling/pkg/dsl/rewriter.go b/tooling/pkg/dsl/rewriter.go index cb110f6e..a947e832 100644 --- a/tooling/pkg/dsl/rewriter.go +++ b/tooling/pkg/dsl/rewriter.go @@ -90,7 +90,7 @@ func (rewriter RewriterWithContext[T]) DefaultRewrite(node Node, context T) Node rewrittenEnv.Namespaces = rewrittenNamespaces return &rewrittenEnv case *Namespace: - rewrittenTypes := rewriteIntefaceSlice(t.TypeDefinitions, context, rewriter) + rewrittenTypes := rewriteInterfaceSlice(t.TypeDefinitions, context, rewriter) rewrittenProtocols := rewriteSlice(t.Protocols, context, rewriter) if rewrittenTypes == nil && rewrittenProtocols == nil { @@ -209,7 +209,7 @@ func (rewriter RewriterWithContext[T]) DefaultRewrite(node Node, context T) Node case *GenericTypeParameter: return t case *SimpleType: - rewrittenTypeArguments := rewriteIntefaceSlice(t.TypeArguments, context, rewriter) + rewrittenTypeArguments := rewriteInterfaceSlice(t.TypeArguments, context, rewriter) if rewrittenTypeArguments == nil { return t } @@ -218,7 +218,7 @@ func (rewriter RewriterWithContext[T]) DefaultRewrite(node Node, context T) Node rewrittenSimpleType.TypeArguments = rewrittenTypeArguments return &rewrittenSimpleType case *GeneralizedType: - rewrittenTypeCases := rewriteIntefaceSlice(t.Cases, context, rewriter) + rewrittenTypeCases := rewriteInterfaceSlice(t.Cases, context, rewriter) var rewrittenDimensionality Dimensionality if t.Dimensionality != nil { @@ -283,7 +283,7 @@ func (rewriter RewriterWithContext[T]) DefaultRewrite(node Node, context T) Node rewrittenTarget = rewriter.Rewrite(t.Target, context).(Expression) } - rewrittenArguments := rewriteIntefaceSlice(t.Arguments, context, rewriter) + rewrittenArguments := rewriteInterfaceSlice(t.Arguments, context, rewriter) if rewrittenTarget == t.Target && rewrittenArguments == nil { return t @@ -306,7 +306,7 @@ func (rewriter RewriterWithContext[T]) DefaultRewrite(node Node, context T) Node return &rewrittenArgument case *FunctionCallExpression: - rewrittenArguments := rewriteIntefaceSlice(t.Arguments, context, rewriter) + rewrittenArguments := rewriteInterfaceSlice(t.Arguments, context, rewriter) if rewrittenArguments == nil { return t @@ -379,7 +379,7 @@ func (rewriter RewriterWithContext[T]) DefaultRewrite(node Node, context T) Node } } -// Rewites a slice of pointers to types that implement the Node interface, e.g, []*Field +// Rewrites a slice of pointers to types that implement the Node interface, e.g, []*Field // Returns nil if no changes were made and the original slice should be used. func rewriteSlice[TContext any, TElement any, T interface { *TElement @@ -402,9 +402,9 @@ func rewriteSlice[TContext any, TElement any, T interface { return rewrittenSlice } -// Rewites a slice of an interface that implements the Node interface, e.g, []Expression +// Rewrites a slice of an interface that implements the Node interface, e.g, []Expression // Returns nil if no changes were made and the original slice should be used. -func rewriteIntefaceSlice[TContext any, T Node](slice []T, context TContext, rewriter RewriterWithContext[TContext]) []T { +func rewriteInterfaceSlice[TContext any, T Node](slice []T, context TContext, rewriter RewriterWithContext[TContext]) []T { var rewrittenSlice []T for i, element := range slice { visited := rewriter.Rewrite(T(element), context) diff --git a/tooling/pkg/dsl/typefunctions.go b/tooling/pkg/dsl/typefunctions.go index f99483cf..4150bd50 100644 --- a/tooling/pkg/dsl/typefunctions.go +++ b/tooling/pkg/dsl/typefunctions.go @@ -73,6 +73,11 @@ func TypeDefinitionsEqual(a, b TypeDefinition) bool { aMeta := a.GetDefinitionMeta() bMeta := b.GetDefinitionMeta() if aMeta.Namespace != bMeta.Namespace || aMeta.Name != bMeta.Name { + if a == PrimitiveSize && b == PrimitiveUint64 || a == PrimitiveUint64 && b == PrimitiveSize { + // Special case: `size` and `uint64` are equivalent though not aliases + return true + } + return false } @@ -185,6 +190,10 @@ func TypesEqual(a, b Type) bool { if b == nil { return a == nil } + + a = GetUnderlyingType(a) + b = GetUnderlyingType(b) + switch ta := a.(type) { case *SimpleType: tb, ok := b.(*SimpleType) diff --git a/tooling/pkg/dsl/typeparser.go b/tooling/pkg/dsl/typeparser.go index e0521b8a..bdc463ff 100644 --- a/tooling/pkg/dsl/typeparser.go +++ b/tooling/pkg/dsl/typeparser.go @@ -10,9 +10,10 @@ import ( ) type simpleTypeTree struct { - Name string `json:"name"` - TypeArguments []simpleTypeTree `json:"args,omitempty"` - Optional bool `json:"optional,omitempty"` + Name string `json:"name"` + TypeArguments []simpleTypeTree `json:"args,omitempty"` + Optional bool `json:"optional,omitempty"` + PositionOffset int `json:"positionOffset,omitempty"` } func (pt *simpleTypeTree) String() string { @@ -67,6 +68,7 @@ func (tp *typeParser) consumeIdentifier() string { func (tp *typeParser) parseTypeString() (simpleTypeTree, error) { parsed := simpleTypeTree{} tp.skipWhitespace() + parsed.PositionOffset = tp.position parsed.Name = tp.consumeIdentifier() tp.skipWhitespace() @@ -154,7 +156,10 @@ func parseSimpleTypeStringAllowingRemaining(typeString string) (typeTree simpleT } func (tree simpleTypeTree) ToType(node NodeMeta) Type { - simpleType := SimpleType{NodeMeta: node, Name: tree.Name} + nodeWithPositionUpdated := node + nodeWithPositionUpdated.Column += tree.PositionOffset + + simpleType := SimpleType{NodeMeta: nodeWithPositionUpdated, Name: tree.Name} for _, typeArg := range tree.TypeArguments { simpleType.TypeArguments = append(simpleType.TypeArguments, typeArg.ToType(node)) } diff --git a/tooling/pkg/dsl/typeparser_test.go b/tooling/pkg/dsl/typeparser_test.go index 6d14dcdc..e0d00558 100644 --- a/tooling/pkg/dsl/typeparser_test.go +++ b/tooling/pkg/dsl/typeparser_test.go @@ -17,17 +17,17 @@ func TestTypeParsing_Valid(t *testing.T) { }{ {input: "Foo", expected: `{"name":"Foo"}`}, {input: "Foo ", expected: `{"name":"Foo"}`}, - {input: " Foo ", expected: `{"name":"Foo"}`}, + {input: " Foo ", expected: `{"name":"Foo","positionOffset":1}`}, {input: "Foo?", expected: `{"name":"Foo","optional":true}`}, {input: "Foo ? ", expected: `{"name":"Foo","optional":true}`}, - {input: "Foo", expected: `{"name":"Foo","args":[{"name":"int"}]}`}, - {input: "Foo< int >", expected: `{"name":"Foo","args":[{"name":"int"}]}`}, - {input: "Foo", expected: `{"name":"Foo","args":[{"name":"int","optional":true}]}`}, - {input: "Foo?", expected: `{"name":"Foo","args":[{"name":"int"}],"optional":true}`}, - {input: "Foo", expected: `{"name":"Foo","args":[{"name":"int"},{"name":"float"}]}`}, - {input: " Foo < int , float > ", expected: `{"name":"Foo","args":[{"name":"int"},{"name":"float"}]}`}, - {input: "Foo>", expected: `{"name":"Foo","args":[{"name":"Bar","args":[{"name":"int"}]}]}`}, - {input: "Foo,Baz>", expected: `{"name":"Foo","args":[{"name":"Bar","args":[{"name":"int"}]},{"name":"Baz","args":[{"name":"long"}]}]}`}, + {input: "Foo", expected: `{"name":"Foo","args":[{"name":"int","positionOffset":4}]}`}, + {input: "Foo< int >", expected: `{"name":"Foo","args":[{"name":"int","positionOffset":5}]}`}, + {input: "Foo", expected: `{"name":"Foo","args":[{"name":"int","optional":true,"positionOffset":4}]}`}, + {input: "Foo?", expected: `{"name":"Foo","args":[{"name":"int","positionOffset":4}],"optional":true}`}, + {input: "Foo", expected: `{"name":"Foo","args":[{"name":"int","positionOffset":4},{"name":"float","positionOffset":8}]}`}, + {input: " Foo < int , float > ", expected: `{"name":"Foo","args":[{"name":"int","positionOffset":7},{"name":"float","positionOffset":13}],"positionOffset":1}`}, + {input: "Foo>", expected: `{"name":"Foo","args":[{"name":"Bar","args":[{"name":"int","positionOffset":8}],"positionOffset":4}]}`}, + {input: "Foo,Baz>", expected: `{"name":"Foo","args":[{"name":"Bar","args":[{"name":"int","positionOffset":8}],"positionOffset":4},{"name":"Baz","args":[{"name":"long","positionOffset":17}],"positionOffset":13}]}`}, } for _, tc := range testCases { t.Run(tc.input, func(t *testing.T) { diff --git a/tooling/pkg/dsl/types.go b/tooling/pkg/dsl/types.go index 502692c5..c360c6c9 100644 --- a/tooling/pkg/dsl/types.go +++ b/tooling/pkg/dsl/types.go @@ -21,6 +21,17 @@ type NodeMeta struct { Column int `json:"-"` } +func (n *NodeMeta) String() string { + return fmt.Sprintf("%s:%d:%d", n.File, n.Line, n.Column) +} + +func (n *NodeMeta) Equals(other *NodeMeta) bool { + return n == other || (n != nil && other != nil && + n.File == other.File && + n.Line == other.Line && + n.Column == other.Column) +} + func (n *NodeMeta) GetNodeMeta() *NodeMeta { return n } @@ -605,7 +616,7 @@ var ( _ TypeDefinition = (*RecordDefinition)(nil) _ TypeDefinition = (*EnumDefinition)(nil) - _ TypeDefinition = (*PrimitiveDefinition)(nil) + _ TypeDefinition = (PrimitiveDefinition)("") _ TypeDefinition = (*NamedType)(nil) _ TypeDefinition = (*ProtocolDefinition)(nil) _ TypeDefinition = (*GenericTypeParameter)(nil) diff --git a/tooling/pkg/dsl/types_test.go b/tooling/pkg/dsl/types_test.go index 77de2a1d..2263ac3f 100644 --- a/tooling/pkg/dsl/types_test.go +++ b/tooling/pkg/dsl/types_test.go @@ -41,14 +41,14 @@ Z: T`, true, t.Run(tC.spec, func(t *testing.T) { env, err := parseAndValidate(t, tC.spec) require.Nil(t, err) - x := typeDefnitionByName(env, "X").(*NamedType) - y := typeDefnitionByName(env, "Y").(*NamedType) + x := typeDefinitionByName(env, "X").(*NamedType) + y := typeDefinitionByName(env, "Y").(*NamedType) require.Equal(t, tC.expectedResult, TypesEqual(GetUnderlyingType(y.Type), x.Type)) }) } } -func typeDefnitionByName(env *Environment, name string) TypeDefinition { +func typeDefinitionByName(env *Environment, name string) TypeDefinition { for _, ns := range env.Namespaces { for _, td := range ns.TypeDefinitions { if td.GetDefinitionMeta().Name == name { diff --git a/tooling/pkg/dsl/validation.go b/tooling/pkg/dsl/validation.go index 29603a5c..20149995 100644 --- a/tooling/pkg/dsl/validation.go +++ b/tooling/pkg/dsl/validation.go @@ -34,9 +34,10 @@ func Validate(namespaces []*Namespace) (*Environment, error) { validateStreams, buildSymbolTable, resolveTypes, - validateUnionCases, + assignUnionCaseLabels, topologicalSortTypes, convertGenericReferences, + validateUnionCases, validateEnums, resolveComputedFields, } diff --git a/tooling/pkg/dsl/validation_topological_sort.go b/tooling/pkg/dsl/validation_topological_sort.go index 52a22980..df6d5e8c 100644 --- a/tooling/pkg/dsl/validation_topological_sort.go +++ b/tooling/pkg/dsl/validation_topological_sort.go @@ -31,7 +31,7 @@ func topologicalSortTypes(env *Environment, errorSink *validation.ErrorSink) *En case *EnumDefinition: return fmt.Sprintf("Enum '%s'", nt.Name) case *NamedType: - return fmt.Sprintf("Array '%s'", nt.Name) + return fmt.Sprintf("Alias '%s'", nt.Name) case *Field: return fmt.Sprintf("Field '%s'", nt.Name) default: diff --git a/tooling/pkg/dsl/validation_topological_sort_test.go b/tooling/pkg/dsl/validation_topological_sort_test.go index 3bb638d0..ca5edcb1 100644 --- a/tooling/pkg/dsl/validation_topological_sort_test.go +++ b/tooling/pkg/dsl/validation_topological_sort_test.go @@ -64,5 +64,5 @@ Image: !array --- ` _, err := parseAndValidate(t, src) - require.ErrorContains(t, err, "there is a reference cycle, which is not supported, within namespace 'test': Array 'Image' -> Array 'Image'") + require.ErrorContains(t, err, "there is a reference cycle, which is not supported, within namespace 'test': Alias 'Image' -> Alias 'Image'") } diff --git a/tooling/pkg/dsl/validation_unions.go b/tooling/pkg/dsl/validation_unions.go index 252955a4..62a2ba6a 100644 --- a/tooling/pkg/dsl/validation_unions.go +++ b/tooling/pkg/dsl/validation_unions.go @@ -11,9 +11,56 @@ import ( "github.com/microsoft/yardl/tooling/internal/validation" ) -func validateUnionCases(env *Environment, errorSink *validation.ErrorSink) *Environment { +func assignUnionCaseLabels(env *Environment, errorSink *validation.ErrorSink) *Environment { Visit(env, func(self Visitor, node Node) { - if t, ok := node.(*GeneralizedType); ok { + if t, ok := node.(*GeneralizedType); ok && t.Cases.IsUnion() { + // assign labels to union cases + for _, typeCase := range t.Cases { + if !typeCase.IsNullType() { + typeCase.Label = typeLabel(typeCase.Type, true) + } + } + + duplicates := make(map[string][]int) + for i, typeCase := range t.Cases { + if !typeCase.IsNullType() { + duplicates[typeCase.Label] = append(duplicates[typeCase.Label], i) + } + } + + for _, v := range duplicates { + if len(v) > 1 { + for _, i := range v { + t.Cases[i].Label = typeLabel(t.Cases[i].Type, false) + } + } + } + + labels := make(map[string]any) + for _, item := range t.Cases { + if item != nil { + if _, found := labels[item.Label]; found { + errorSink.Add(validationError(node, "internal error: union cases must have distinct labels within the union")) + } + } + } + } + + self.VisitChildren(node) + }) + + return env +} + +func validateUnionCases(env *Environment, errorSink *validation.ErrorSink) *Environment { + if len(errorSink.Errors) > 0 { + // Only perform this if all types are resolved + return env + } + + VisitWithContext(env, false, func(self VisitorWithContext[bool], node Node, visitingReference bool) { + switch t := node.(type) { + case *GeneralizedType: cases := t.Cases if len(cases) == 0 { errorSink.Add(validationError(node, "a field cannot be a union type with no options")) @@ -35,51 +82,80 @@ func validateUnionCases(env *Environment, errorSink *validation.ErrorSink) *Envi errorSink.Add(validationError(typeCase, "unions may not immediately contain other unions")) } } - } - - if cases.IsUnion() { - // assign labels to union cases - for _, typeCase := range cases { - if !typeCase.IsNullType() { - typeCase.Label = typeLabel(typeCase.Type, true) - } - } - - duplicates := make(map[string][]int) - for i, typeCase := range cases { - if !typeCase.IsNullType() { - duplicates[typeCase.Label] = append(duplicates[typeCase.Label], i) - } - } - for _, v := range duplicates { - if len(v) > 1 { - for _, i := range v { - cases[i].Label = typeLabel(cases[i].Type, false) + for i, item := range t.Cases { + for j := i + 1; j < len(t.Cases); j++ { + otherItem := t.Cases[j] + + if TypesEqual(item.Type, otherItem.Type) { + additionalExplanation := "" + if !item.IsNullType() { + // determine if this is because size and uint64 were used, which are equivalent but not aliases + if itemPrimitive, ok := GetPrimitiveType(item.Type); ok { + if otherItemPrimitive, ok := GetPrimitiveType(otherItem.Type); ok { + if itemPrimitive != otherItemPrimitive && + (itemPrimitive == PrimitiveUint64 && otherItemPrimitive == PrimitiveSize || + itemPrimitive == PrimitiveSize && otherItemPrimitive == PrimitiveUint64) { + additionalExplanation = " (uint64 and size are equivalent)" + } + } + } + + // Determine if the types are defined at a different location than the cases + // This indicates that the cause of the duplicate is a type argument. + + itemNodeMeta := item.GetNodeMeta() + itemTypeNodeMeta := item.Type.GetNodeMeta() + + otherItemNodeMeta := otherItem.GetNodeMeta() + otherItemTypeNodeMeta := otherItem.Type.GetNodeMeta() + + itemDefinedElsewhere := !itemNodeMeta.Equals(itemTypeNodeMeta) + otherItemDefinedElsewhere := !otherItemNodeMeta.Equals(otherItemTypeNodeMeta) + + if itemDefinedElsewhere || otherItemDefinedElsewhere { + if itemDefinedElsewhere && otherItemDefinedElsewhere { + // both are type arguments + errorSink.Add(validationError(item, "redundant union type cases resulting from the type arguments given at %s and %s%s", itemTypeNodeMeta, otherItemTypeNodeMeta, additionalExplanation)) + continue + } + + // only one is a type argument + var typeParameterNode Node + var redundantNode Node + if itemDefinedElsewhere { + typeParameterNode = itemTypeNodeMeta + redundantNode = otherItemNodeMeta + } else { + typeParameterNode = otherItemTypeNodeMeta + redundantNode = itemNodeMeta + } + + errorSink.Add(validationError(redundantNode, "redundant union type cases resulting from the type argument given at %s%s", typeParameterNode, additionalExplanation)) + continue + } + } + // No contributions from type arguments. + // To avoid reporting the same error multiple times, we only report the error + // if we we visiting the type directly, i.e. not through a reference. + if !visitingReference { + errorSink.Add(validationError(item, "redundant union type cases%s", additionalExplanation)) + } } } } + } - for i, item := range cases { - for j := i + 1; j < len(cases); j++ { - if TypesEqual(item.Type, cases[j].Type) { - errorSink.Add(validationError(item, "all type cases in a union must be distinct")) - } - } - } + self.VisitChildren(node, visitingReference) - labels := make(map[string]any) - for _, item := range cases { - if item != nil { - if _, found := labels[item.Label]; found { - errorSink.Add(validationError(node, "union cases must have distict labels within the union")) - } - } - } + case *SimpleType: + if len(t.ResolvedDefinition.GetDefinitionMeta().TypeArguments) > 0 { + // Check the referenced type with the type arguments provided + self.Visit(t.ResolvedDefinition, true) } + default: + self.VisitChildren(node, visitingReference) } - - self.VisitChildren(node) }) return env diff --git a/tooling/pkg/dsl/validation_unions_test.go b/tooling/pkg/dsl/validation_unions_test.go index f48da6c9..c822000e 100644 --- a/tooling/pkg/dsl/validation_unions_test.go +++ b/tooling/pkg/dsl/validation_unions_test.go @@ -6,6 +6,7 @@ package dsl import ( "os" "path" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -48,16 +49,16 @@ X: !record assert.ErrorContains(t, err, "if null is specified in a union type, it must be the first option") } -func TestUnionElementsMustBeDistict(t *testing.T) { +func TestUnionElementsMustBeDistinct(t *testing.T) { src := ` X: !record fields: f: [int, int]` _, err := parseAndValidate(t, src) - assert.ErrorContains(t, err, "all type cases in a union must be distinct") + assert.ErrorContains(t, err, "redundant union type cases") } -func TestUnionElementsMustBeDistictWithGenerics(t *testing.T) { +func TestUnionElementsMustBeDistinctWithGenerics(t *testing.T) { src := ` X: !record fields: @@ -65,10 +66,10 @@ X: !record GenericRecord: !record` _, err := parseAndValidate(t, src) - assert.ErrorContains(t, err, "all type cases in a union must be distinct") + assert.ErrorContains(t, err, "redundant union type cases") } -func TestUnionElementsAreDistictWithGenerics(t *testing.T) { +func TestUnionElementsAreDistinctWithGenerics(t *testing.T) { src := ` X: !record fields: @@ -79,35 +80,36 @@ GenericRecord: !record` assert.Nil(t, err) } -func TestUnionElementsMustBeDistict_SameUnrecognizedType(t *testing.T) { +func TestUnionElementsMustBeDistinct_SameUnrecognizedType(t *testing.T) { src := ` +Bar: int X: !record fields: f: [Bar, Bar]` _, err := parseAndValidate(t, src) - assert.ErrorContains(t, err, "all type cases in a union must be distinct") + assert.ErrorContains(t, err, "redundant union type cases") } -func TestUnionElementsMustBeDistict_DifferentUnrecognizedType(t *testing.T) { +func TestUnionElementsMustBeDistinct_DifferentUnrecognizedType(t *testing.T) { src := ` X: !record fields: f: [Foo, Bar]` _, err := parseAndValidate(t, src) assert.NotNil(t, err) - assert.NotContains(t, err.Error(), "all type cases in a union must be distinct") + assert.NotContains(t, err.Error(), "redundant union type cases") } -func TestUnionElementsMustBeDistict_MultipleNulls(t *testing.T) { +func TestUnionElementsMustBeDistinct_MultipleNulls(t *testing.T) { src := ` X: !record fields: f: [null, null, null]` _, err := parseAndValidate(t, src) - assert.ErrorContains(t, err, "all type cases in a union must be distinct") + assert.ErrorContains(t, err, "redundant union type cases") } -func TestUnionElementsMustBeDistict_Complex(t *testing.T) { +func TestUnionElementsMustBeDistinct_Complex(t *testing.T) { src := ` X: !record fields: @@ -119,10 +121,10 @@ X: !record items: [int, float] length: 10` _, err := parseAndValidate(t, src) - assert.ErrorContains(t, err, "all type cases in a union must be distinct") + assert.ErrorContains(t, err, "redundant union type cases") } -func TestUnionElementsMustBeDistict_Nested(t *testing.T) { +func TestUnionElementsMustBeDistinct_Nested(t *testing.T) { src := ` X: !record fields: @@ -130,7 +132,118 @@ X: !record - int - [ float, float ]` _, err := parseAndValidate(t, src) - assert.ErrorContains(t, err, "all type cases in a union must be distinct") + assert.ErrorContains(t, err, "redundant union type cases") +} + +func TestUnionElementsMustBeDistinct_AliasedType(t *testing.T) { + src := ` +MyIntType: uint64 +MyRecord: !record + fields: + one: [uint64, MyIntType]` + _, err := parseAndValidate(t, src) + assert.ErrorContains(t, err, "redundant union type cases") +} + +func TestUnionElementsMustBeDistinct_AliasedWithinVector(t *testing.T) { + src := ` +MyIntType: uint64 +MyRecord: !record + fields: + f: + - !vector + items: MyIntType + - !vector + items: uint64` + _, err := parseAndValidate(t, src) + assert.ErrorContains(t, err, "redundant union type cases") +} + +func TestUnionElementsMustBeDistinct_DifferentGenericArgs(t *testing.T) { + src := ` +MyIntType: uint64 +Image: !array + items: T +MyRecord: !record + fields: + f: [Image, Image]` + _, err := parseAndValidate(t, src) + assert.Nil(t, err) +} + +func TestUnionElementsMustBeDistinct_GenericUnionAlias_AllUnique(t *testing.T) { + src := ` +MyUnionType: [T, U] +MyRecord: !record + fields: + f: MyUnionType` + _, err := parseAndValidate(t, src) + assert.Nil(t, err) +} + +func TestUnionElementsMustBeDistinct_GenericUnionAlias_NotUnique_MultipleTypeArgs(t *testing.T) { + src := ` +MyUnionType: [T, U] +MyRecord: !record + fields: + f: MyUnionType` + _, err := parseAndValidate(t, src) + require.NotNil(t, err) + assert.Regexp(t, ".yaml:2:21: redundant union type cases resulting from the type arguments given at .*.yaml:5:20 and .*.yaml:5:25", err.Error()) +} + +func TestUnionElementsMustBeDistinct_GenericUnionAlias_NotUnique_SingleTypeArg(t *testing.T) { + src := ` +MyUnionType: [T, int] +MyRecord: !record + fields: + f: MyUnionType` + _, err := parseAndValidate(t, src) + require.NotNil(t, err) + assert.Regexp(t, ".yaml:2:24: redundant union type cases resulting from the type argument given at .*.yaml:5:20$", err.Error()) +} + +func TestUnionElementsMustBeDistinct_GenericUnionAliasChain_SingleTypeArg(t *testing.T) { + src := ` +Rec: !record + fields: + f: [T, int] +Alias1: Rec +Alias2: Alias1` + _, err := parseAndValidate(t, src) + require.NotNil(t, err) + assert.Regexp(t, ".yaml:4:12: redundant union type cases resulting from the type argument given at .*.yaml:6:16$", err.Error()) +} + +func TestUnionElementsMustBeDistinct_GenericUnionAliasChain_ErrorsNotDuplicated(t *testing.T) { + src := ` +Rec: !record + fields: + f: [int, int] +Alias1: Rec +Alias2: Alias1` + _, err := parseAndValidate(t, src) + require.NotNil(t, err) + assert.Regexp(t, ".yaml:4:9: redundant union type cases$", err.Error()) + assert.Equal(t, 1, len(strings.Split(err.Error(), "\n"))) +} + +func TestUnionElementsMustBeDistinct_SizeAndUnit64Direct(t *testing.T) { + src := ` +T: [size, uint64]` + _, err := parseAndValidate(t, src) + require.NotNil(t, err) + assert.ErrorContains(t, err, "redundant union type cases (uint64 and size are equivalent)") +} + +func TestUnionElementsMustBeDistinct_SizeAndUnit64Aliased(t *testing.T) { + src := ` +MySize: size +MyUint64: uint64 +T: [MySize, MyUint64]` + _, err := parseAndValidate(t, src) + require.NotNil(t, err) + assert.ErrorContains(t, err, "redundant union type cases (uint64 and size are equivalent)") } func TestVectorLengthCannotBeNegative(t *testing.T) { diff --git a/tooling/pkg/dsl/visitor.go b/tooling/pkg/dsl/visitor.go index e5ee1838..1abf9bcd 100644 --- a/tooling/pkg/dsl/visitor.go +++ b/tooling/pkg/dsl/visitor.go @@ -23,8 +23,8 @@ type VisitorFunc func(self Visitor, node Node) // Visits a Node tree from the given root, threading a context parameter throughout. func VisitWithContext[T any](root Node, context T, visitor VisitorWithContextFunc[T]) { - vistorWithContext := VisitorWithContext[T]{visitorWithContext: visitor} - visitor(vistorWithContext, root, context) + visitorWithContext := VisitorWithContext[T]{visitorWithContext: visitor} + visitor(visitorWithContext, root, context) } type VisitorWithContextFunc[T any] func(self VisitorWithContext[T], node Node, context T)