diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/ScmModelProvider/ScmModelProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/ScmModelProvider/ScmModelProviderTests.cs index bbbccdb3af3..e0d773dd550 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/ScmModelProvider/ScmModelProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/ScmModelProvider/ScmModelProviderTests.cs @@ -105,7 +105,6 @@ public void TestNestedDiscriminatorDynamicModel(bool discriminatedTypeIsDynamicM ]); InputModelType catModel = InputFactory.Model( "cat", - discriminatedKind: "cat", properties: [ InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs index ada1ec67c73..9ed1df69d43 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/src/Providers/ModelProvider.cs @@ -52,6 +52,7 @@ protected override FormattableString BuildDescription() } private readonly bool _isAbstract; + private readonly bool _isMultiLevelDiscriminator; private readonly CSharpType _additionalBinaryDataPropsFieldType = typeof(IDictionary); private readonly Type _additionalPropsUnknownType = typeof(BinaryData); @@ -66,6 +67,7 @@ public ModelProvider(InputModelType inputModel) : base(inputModel) { _inputModel = inputModel; _isAbstract = _inputModel.DiscriminatorProperty is not null && _inputModel.DiscriminatorValue is null; + _isMultiLevelDiscriminator = ComputeIsMultiLevelDiscriminator(); if (_inputModel.BaseModel is not null) { @@ -522,7 +524,7 @@ protected override ConstructorProvider[] BuildConstructors() return [FullConstructor]; } - // Build the initialization constructor + // Build the standard single initialization constructor var accessibility = DeclarationModifiers.HasFlag(TypeSignatureModifiers.Abstract) ? MethodSignatureModifiers.Private | MethodSignatureModifiers.Protected : _inputModel.Usage.HasFlag(InputModelTypeUsage.Input) @@ -543,12 +545,80 @@ protected override ConstructorProvider[] BuildConstructors() }, this); + var constructors = new List { constructor }; + + // Add FullConstructor if parameters are different if (!constructorParameters.SequenceEqual(FullConstructor.Signature.Parameters)) { - return [constructor, FullConstructor]; + constructors.Add(FullConstructor); + } + + // For multi-level discriminators, add one additional private protected constructor + if (_isMultiLevelDiscriminator) + { + var protectedConstructor = BuildProtectedInheritanceConstructor(); + constructors.Add(protectedConstructor); } - return [constructor]; + return [.. constructors]; + } + + /// + /// Determines if this model should have a dual constructor pattern. + /// This is needed when the model shares the same discriminator property name as its base model + /// AND has derived models, indicating it's an intermediate type in a discriminated union hierarchy. + /// + private bool ComputeIsMultiLevelDiscriminator() + { + // Only applies to non-abstract models with a base model + if (_isAbstract || _inputModel.BaseModel == null) + { + return false; + } + // Must have derived models to be considered an intermediate type + if (_inputModel.DerivedModels.Count == 0) + { + return false; + } + + // Check if this model has a discriminator property in the input + if (_inputModel.DiscriminatorProperty == null) + { + return false; + } + + // Check if base model has a discriminator property with the same name + if (_inputModel.BaseModel.DiscriminatorProperty == null) + { + return false; + } + + // If both models have discriminator properties with the same name, + // and this model has derived models, it needs the dual constructor pattern + return _inputModel.DiscriminatorProperty.Name == + _inputModel.BaseModel.DiscriminatorProperty.Name; + } + + /// + /// Builds a private protected constructor for multi-level discriminator inheritance. + /// This allows derived models to call this constructor with their discriminator value. + /// + private ConstructorProvider BuildProtectedInheritanceConstructor() + { + var (parameters, initializer) = BuildConstructorParameters(true, includeDiscriminatorParameter: true); + + return new ConstructorProvider( + signature: new ConstructorSignature( + Type, + $"Initializes a new instance of {Type:C}", + MethodSignatureModifiers.Private | MethodSignatureModifiers.Protected, + parameters, + initializer: initializer), + bodyStatements: new MethodBodyStatement[] + { + GetPropertyInitializers(true, parameters: parameters) + }, + this); } /// @@ -558,7 +628,6 @@ protected override ConstructorProvider[] BuildConstructors() private ConstructorProvider BuildFullConstructor() { var (ctorParameters, ctorInitializer) = BuildConstructorParameters(false); - return new ConstructorProvider( signature: new ConstructorSignature( Type, @@ -573,7 +642,7 @@ private ConstructorProvider BuildFullConstructor() this); } - private IEnumerable GetAllBasePropertiesForConstructorInitialization() + private IEnumerable GetAllBasePropertiesForConstructorInitialization(bool includeAllHierarchyDiscriminator = false) { var properties = new Stack>(); var modelProvider = BaseModelProvider; @@ -585,9 +654,8 @@ private IEnumerable GetAllBasePropertiesForConstructorInitiali { if (property.IsDiscriminator) { - // In the case of nested discriminators, we only need to include the direct base discriminator property, - // as this is the only one that will be initialized in this model's constructor. - if (isDirectBase) + // In the case of nested discriminators, include discriminator property based on the parameter + if (isDirectBase || includeAllHierarchyDiscriminator) { properties.Peek().Add(property); } @@ -624,16 +692,16 @@ private IEnumerable GetAllBaseFieldsForConstructorInitialization( } private (IReadOnlyList Parameters, ConstructorInitializer? Initializer) BuildConstructorParameters( - bool isPrimaryConstructor) + bool isInitializationConstructor, bool includeDiscriminatorParameter = false) { var baseParameters = new List(); var constructorParameters = new List(); IEnumerable baseProperties = []; IEnumerable baseFields = []; - if (isPrimaryConstructor) + if (isInitializationConstructor) { - baseProperties = GetAllBasePropertiesForConstructorInitialization(); + baseProperties = GetAllBasePropertiesForConstructorInitialization(includeDiscriminatorParameter); baseFields = GetAllBaseFieldsForConstructorInitialization(); } else if (BaseModelProvider?.FullConstructor.Signature != null) @@ -646,36 +714,61 @@ private IEnumerable GetAllBaseFieldsForConstructorInitialization( // add the base parameters, if any foreach (var property in baseProperties) { - AddInitializationParameterForCtor(baseParameters, Type.IsStruct, isPrimaryConstructor, property); + AddInitializationParameterForCtor(baseParameters, Type.IsStruct, isInitializationConstructor, property); } // add the base fields, if any foreach (var field in baseFields) { - AddInitializationParameterForCtor(baseParameters, Type.IsStruct, isPrimaryConstructor, field: field); + AddInitializationParameterForCtor(baseParameters, Type.IsStruct, isInitializationConstructor, field: field); } // construct the initializer using the parameters from base signature - var constructorInitializer = new ConstructorInitializer(true, [.. baseParameters.Select(p => GetExpressionForCtor(p, overriddenProperties, isPrimaryConstructor))]); + ConstructorInitializer? constructorInitializer = null; + if (BaseModelProvider != null) + { + if (baseParameters.Count > 0) + { + // Check if base model has dual constructor pattern and we should call private protected constructor + if (isInitializationConstructor && BaseModelProvider._isMultiLevelDiscriminator) + { + // Call base model's private protected constructor with discriminator value + var args = new List(); + args.Add(Literal(_inputModel.DiscriminatorValue ?? "")); + var filteredParams = baseParameters.Where(p => p.Property is null || !p.Property.IsDiscriminator).ToList(); + args.AddRange(filteredParams.Select(p => GetExpressionForCtor(p, overriddenProperties, isInitializationConstructor))); + constructorInitializer = new ConstructorInitializer(true, args); + } + else + { + // Standard base constructor call + constructorInitializer = new ConstructorInitializer(true, [.. baseParameters.Select(p => GetExpressionForCtor(p, overriddenProperties, isInitializationConstructor))]); + } + } + else + { + // Even when no base parameters, we still need a base constructor call if there's a base model + constructorInitializer = new ConstructorInitializer(true, Array.Empty()); + } + } foreach (var property in CanonicalView.Properties) { - AddInitializationParameterForCtor(constructorParameters, Type.IsStruct, isPrimaryConstructor, property); + AddInitializationParameterForCtor(constructorParameters, Type.IsStruct, isInitializationConstructor, property); } foreach (var field in CanonicalView.Fields) { - AddInitializationParameterForCtor(constructorParameters, Type.IsStruct, isPrimaryConstructor, field: field); + AddInitializationParameterForCtor(constructorParameters, Type.IsStruct, isInitializationConstructor, field: field); } constructorParameters.InsertRange(0, _inputModel.IsUnknownDiscriminatorModel ? baseParameters : baseParameters.Where(p => p.Property is null - || (p.Property.IsDiscriminator && !overriddenProperties.Contains(p.Property) && !isPrimaryConstructor) - || (!p.Property.IsDiscriminator && !overriddenProperties.Contains(p.Property)))); + || (!overriddenProperties.Contains(p.Property!) && (!p.Property.IsDiscriminator || !isInitializationConstructor || includeDiscriminatorParameter)))); - if (!isPrimaryConstructor) + if (!isInitializationConstructor) { foreach (var property in AdditionalPropertyProperties) { diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs index d3b32196e21..234654be7eb 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator/test/Providers/ModelProviders/ModelProviderTests.cs @@ -1289,5 +1289,258 @@ public void ModelWithOptionalDiscriminatorProperty() var file = writer.Write(); Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content); } + + [Test] + public void TestMultiLayerDiscriminator_IntermediateWithoutDiscriminator() + { + // Test hierarchy: Pet (base, no discriminator) → Cat (intermediate, no discriminator) → Tiger (leaf, discriminator: "tiger") + + InputModelType tigerModel = InputFactory.Model( + "tiger", + discriminatedKind: "tiger", + properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("stripes", InputPrimitiveType.Int32, isRequired: true) + ]); + + InputModelType catModel = InputFactory.Model( + "cat", + properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("meows", InputPrimitiveType.Boolean, isRequired: true) + ], + discriminatedModels: new Dictionary() { { "tiger", tigerModel } }); + + var baseModel = InputFactory.Model( + "pet", + properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("name", InputPrimitiveType.String, isRequired: true) + ], + discriminatedModels: new Dictionary() { { "cat", catModel } }); + + MockHelpers.LoadMockGenerator(inputModelTypes: [baseModel, catModel, tigerModel]); + + var tigerProvider = new ModelProvider(tigerModel); + + Assert.AreEqual(2, tigerProvider.Constructors.Count); + + var publicConstructor = tigerProvider.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public)); + Assert.IsNotNull(publicConstructor); + Assert.AreEqual(MethodSignatureModifiers.Public, publicConstructor!.Signature.Modifiers); + + // Tiger's public constructor should have parameters: name (from Pet), meows (from Cat), stripes (from Tiger) + var publicParams = publicConstructor.Signature.Parameters; + Assert.AreEqual(3, publicParams.Count); + Assert.AreEqual("name", publicParams[0].Name); + Assert.AreEqual(typeof(string), publicParams[0].Type.FrameworkType); + Assert.AreEqual("meows", publicParams[1].Name); + Assert.AreEqual(typeof(bool), publicParams[1].Type.FrameworkType); + Assert.AreEqual("stripes", publicParams[2].Name); + Assert.AreEqual(typeof(int), publicParams[2].Type.FrameworkType); + + // Tiger should call base constructor with only base parameters (no discriminator from Cat since Cat doesn't have one) + var initializer = publicConstructor.Signature.Initializer; + Assert.IsNotNull(initializer); + Assert.IsTrue(initializer!.IsBase); + + // Should have name and meows parameters from base chain (no discriminator since Cat has no discriminatedKind) + Assert.AreEqual(2, initializer.Arguments.Count); + Assert.AreEqual("name", initializer.Arguments[0].ToDisplayString()); + Assert.AreEqual("meows", initializer.Arguments[1].ToDisplayString()); + + // Verify internal (serialization) constructor signature and parameters + var internalConstructor = tigerProvider.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Internal)); + Assert.IsNotNull(internalConstructor); + Assert.AreEqual(MethodSignatureModifiers.Internal, internalConstructor!.Signature.Modifiers); + + // Internal constructor should have all parameters including serialization params + var internalParams = internalConstructor.Signature.Parameters; + Assert.IsTrue(internalParams.Count >= 4); + + var internalInitializer = internalConstructor.Signature.Initializer; + Assert.IsNotNull(internalInitializer); + Assert.IsTrue(internalInitializer!.IsBase); + } + + [Test] + public void TestMultiLayerDiscriminator_IntermediateWithDiscriminator() + { + // Test hierarchy: Pet (base, no discriminator) → Cat (intermediate, discriminator: "cat") → DomesticCat (leaf, discriminator: "domestic") + + InputModelType domesticCatModel = InputFactory.Model( + "domesticCat", + discriminatedKind: "domestic", + properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("breed", InputPrimitiveType.String, isRequired: true) + ]); + + InputModelType catModel = InputFactory.Model( + "cat", + discriminatedKind: "cat", + properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("meows", InputPrimitiveType.Boolean, isRequired: true) + ], + discriminatedModels: new Dictionary() { { "domestic", domesticCatModel } }); + + var baseModel = InputFactory.Model( + "pet", + properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("name", InputPrimitiveType.String, isRequired: true) + ], + discriminatedModels: new Dictionary() { { "cat", catModel } }); + + MockHelpers.LoadMockGenerator(inputModelTypes: [baseModel, catModel, domesticCatModel]); + + var domesticCatProvider = new ModelProvider(domesticCatModel); + var catProvider = new ModelProvider(catModel); + + Assert.AreEqual(2, domesticCatProvider.Constructors.Count); + + var publicConstructor = domesticCatProvider.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public)); + Assert.IsNotNull(publicConstructor); + Assert.AreEqual(MethodSignatureModifiers.Public, publicConstructor!.Signature.Modifiers); + + // DomesticCat's public constructor should have parameters: name (from Pet), meows (from Cat), breed (from DomesticCat) + var publicParams = publicConstructor.Signature.Parameters; + Assert.AreEqual(3, publicParams.Count); + Assert.AreEqual("name", publicParams[0].Name); + Assert.AreEqual(typeof(string), publicParams[0].Type.FrameworkType); + Assert.AreEqual("meows", publicParams[1].Name); + Assert.AreEqual(typeof(bool), publicParams[1].Type.FrameworkType); + Assert.AreEqual("breed", publicParams[2].Name); + Assert.AreEqual(typeof(string), publicParams[2].Type.FrameworkType); + + // DomesticCat should call Cat's dual constructor with discriminator value since Cat has discriminatedKind + var initializer = publicConstructor.Signature.Initializer; + Assert.IsNotNull(initializer); + Assert.IsTrue(initializer!.IsBase); + + // Should have discriminator + parameters from Cat's dual constructor pattern + Assert.AreEqual(3, initializer.Arguments.Count); // discriminator "domestic" + name + meows + Assert.AreEqual("\"domestic\"", initializer.Arguments[0].ToDisplayString()); + Assert.AreEqual("name", initializer.Arguments[1].ToDisplayString()); + Assert.AreEqual("meows", initializer.Arguments[2].ToDisplayString()); + + // Verify Cat also has dual constructor pattern (public and protected) + Assert.AreEqual(3, catProvider.Constructors.Count); + + var internalConstructor = domesticCatProvider.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Internal)); + Assert.IsNotNull(internalConstructor); + Assert.AreEqual(MethodSignatureModifiers.Internal, internalConstructor!.Signature.Modifiers); + + var internalInitializer = internalConstructor.Signature.Initializer; + Assert.IsNotNull(internalInitializer); + Assert.IsTrue(internalInitializer!.IsBase); + } + + [Test] + public void TestMultiLayerDiscriminator_ThreeLayers() + { + // Test hierarchy: Animal (base) → Pet (discriminator: "pet") → Cat (discriminator: "cat") → DomesticCat (discriminator: "domestic") + + InputModelType domesticCatModel = InputFactory.Model( + "domesticCat", + discriminatedKind: "domestic", + properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("breed", InputPrimitiveType.String, isRequired: true) + ]); + + InputModelType catModel = InputFactory.Model( + "cat", + discriminatedKind: "cat", + properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("meows", InputPrimitiveType.Boolean, isRequired: true) + ], + discriminatedModels: new Dictionary() { { "domestic", domesticCatModel } }); + + InputModelType petModel = InputFactory.Model( + "pet", + discriminatedKind: "pet", + properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("name", InputPrimitiveType.String, isRequired: true) + ], + discriminatedModels: new Dictionary() { { "cat", catModel } }); + + var animalModel = InputFactory.Model( + "animal", + properties: + [ + InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true), + InputFactory.Property("species", InputPrimitiveType.String, isRequired: true) + ], + discriminatedModels: new Dictionary() { { "pet", petModel } }); + + MockHelpers.LoadMockGenerator(inputModelTypes: [animalModel, petModel, catModel, domesticCatModel]); + + var domesticCatProvider = new ModelProvider(domesticCatModel); + var catProvider = new ModelProvider(catModel); + var petProvider = new ModelProvider(petModel); + var animalProvider = new ModelProvider(animalModel); + + Assert.IsNotNull(domesticCatProvider.BaseModelProvider); + Assert.IsNotNull(domesticCatProvider.BaseModelProvider!.BaseModelProvider); + Assert.IsNotNull(domesticCatProvider.BaseModelProvider!.BaseModelProvider!.BaseModelProvider); + + Assert.AreEqual(2, domesticCatProvider.Constructors.Count); // public, internal (leaf type) + Assert.AreEqual(3, catProvider.Constructors.Count); // public, protected with discriminator, internal + Assert.AreEqual(3, petProvider.Constructors.Count); // public, protected with discriminator, internal + Assert.AreEqual(2, animalProvider.Constructors.Count); // public, internal (base type) + + // Verify DomesticCat's public constructor parameters: species (from Animal), name (from Pet), meows (from Cat), breed (from DomesticCat) + var publicConstructor = domesticCatProvider.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Public)); + Assert.IsNotNull(publicConstructor); + Assert.AreEqual(MethodSignatureModifiers.Public, publicConstructor!.Signature.Modifiers); + + var publicParams = publicConstructor.Signature.Parameters; + Assert.AreEqual(4, publicParams.Count); + Assert.AreEqual("species", publicParams[0].Name); + Assert.AreEqual(typeof(string), publicParams[0].Type.FrameworkType); + Assert.AreEqual("name", publicParams[1].Name); + Assert.AreEqual(typeof(string), publicParams[1].Type.FrameworkType); + Assert.AreEqual("meows", publicParams[2].Name); + Assert.AreEqual(typeof(bool), publicParams[2].Type.FrameworkType); + Assert.AreEqual("breed", publicParams[3].Name); + Assert.AreEqual(typeof(string), publicParams[3].Type.FrameworkType); + + var initializer = publicConstructor.Signature.Initializer; + Assert.IsNotNull(initializer); + Assert.IsTrue(initializer!.IsBase); + + Assert.AreEqual(4, initializer.Arguments.Count); // discriminator "domestic" + species + name + meows + Assert.AreEqual("\"domestic\"", initializer.Arguments[0].ToDisplayString()); + Assert.AreEqual("species", initializer.Arguments[1].ToDisplayString()); + Assert.AreEqual("name", initializer.Arguments[2].ToDisplayString()); + Assert.AreEqual("meows", initializer.Arguments[3].ToDisplayString()); + + // Verify Cat's protected constructor exists and has correct signature + var catProtectedConstructor = catProvider.Constructors.FirstOrDefault(c => c.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Protected)); + Assert.IsNotNull(catProtectedConstructor); + Assert.IsTrue(catProtectedConstructor!.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Protected)); + + // Cat's protected constructor should have: discriminator + species + name + meows + var catProtectedParams = catProtectedConstructor.Signature.Parameters; + Assert.AreEqual(4, catProtectedParams.Count); + Assert.AreEqual("kind", catProtectedParams[0].Name); + Assert.AreEqual(typeof(string), catProtectedParams[0].Type.FrameworkType); + Assert.AreEqual("species", catProtectedParams[1].Name); + Assert.AreEqual("name", catProtectedParams[2].Name); + Assert.AreEqual("meows", catProtectedParams[3].Name); + } } }