Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ public void TestNestedDiscriminatorDynamicModel(bool discriminatedTypeIsDynamicM
]);
InputModelType catModel = InputFactory.Model(
"cat",
discriminatedKind: "cat",
properties:
[
InputFactory.Property("kind", InputPrimitiveType.String, isRequired: true, isDiscriminator: true),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ protected override FormattableString BuildDescription()
}

private readonly bool _isAbstract;
private readonly bool _isMultiLevelDiscriminator;

private readonly CSharpType _additionalBinaryDataPropsFieldType = typeof(IDictionary<string, BinaryData>);
private readonly Type _additionalPropsUnknownType = typeof(BinaryData);
Expand All @@ -66,6 +67,7 @@ public ModelProvider(InputModelType inputModel) : base(inputModel)
{
_inputModel = inputModel;
_isAbstract = _inputModel.DiscriminatorProperty is not null && _inputModel.DiscriminatorValue is null;
_isMultiLevelDiscriminator = ComputeIsMultiLevelDiscriminator();

if (_inputModel.BaseModel is not null)
{
Expand Down Expand Up @@ -522,7 +524,7 @@ protected override ConstructorProvider[] BuildConstructors()
return [FullConstructor];
}

// Build the initialization constructor
// Build the standard single initialization constructor
var accessibility = DeclarationModifiers.HasFlag(TypeSignatureModifiers.Abstract)
? MethodSignatureModifiers.Private | MethodSignatureModifiers.Protected
: _inputModel.Usage.HasFlag(InputModelTypeUsage.Input)
Expand All @@ -543,12 +545,80 @@ protected override ConstructorProvider[] BuildConstructors()
},
this);

var constructors = new List<ConstructorProvider> { constructor };

// Add FullConstructor if parameters are different
if (!constructorParameters.SequenceEqual(FullConstructor.Signature.Parameters))
{
return [constructor, FullConstructor];
constructors.Add(FullConstructor);
}

// For multi-level discriminators, add one additional private protected constructor
if (_isMultiLevelDiscriminator)
{
var protectedConstructor = BuildProtectedInheritanceConstructor();
constructors.Add(protectedConstructor);
}

return [constructor];
return [.. constructors];
}

/// <summary>
/// Determines if this model should have a dual constructor pattern.
/// This is needed when the model shares the same discriminator property name as its base model
/// AND has derived models, indicating it's an intermediate type in a discriminated union hierarchy.
/// </summary>
private bool ComputeIsMultiLevelDiscriminator()
{
// Only applies to non-abstract models with a base model
if (_isAbstract || _inputModel.BaseModel == null)
{
return false;
}
// Must have derived models to be considered an intermediate type
if (_inputModel.DerivedModels.Count == 0)
{
return false;
}

// Check if this model has a discriminator property in the input
if (_inputModel.DiscriminatorProperty == null)
{
return false;
}

// Check if base model has a discriminator property with the same name
if (_inputModel.BaseModel.DiscriminatorProperty == null)
{
return false;
}

// If both models have discriminator properties with the same name,
// and this model has derived models, it needs the dual constructor pattern
return _inputModel.DiscriminatorProperty.Name ==
_inputModel.BaseModel.DiscriminatorProperty.Name;
}

/// <summary>
/// Builds a private protected constructor for multi-level discriminator inheritance.
/// This allows derived models to call this constructor with their discriminator value.
/// </summary>
private ConstructorProvider BuildProtectedInheritanceConstructor()
{
var (parameters, initializer) = BuildConstructorParameters(true, includeDiscriminatorParameter: true);

return new ConstructorProvider(
signature: new ConstructorSignature(
Type,
$"Initializes a new instance of {Type:C}",
MethodSignatureModifiers.Private | MethodSignatureModifiers.Protected,
parameters,
initializer: initializer),
bodyStatements: new MethodBodyStatement[]
{
GetPropertyInitializers(true, parameters: parameters)
},
this);
}

/// <summary>
Expand All @@ -558,7 +628,6 @@ protected override ConstructorProvider[] BuildConstructors()
private ConstructorProvider BuildFullConstructor()
{
var (ctorParameters, ctorInitializer) = BuildConstructorParameters(false);

return new ConstructorProvider(
signature: new ConstructorSignature(
Type,
Expand All @@ -573,7 +642,7 @@ private ConstructorProvider BuildFullConstructor()
this);
}

private IEnumerable<PropertyProvider> GetAllBasePropertiesForConstructorInitialization()
private IEnumerable<PropertyProvider> GetAllBasePropertiesForConstructorInitialization(bool includeAllHierarchyDiscriminator = false)
{
var properties = new Stack<List<PropertyProvider>>();
var modelProvider = BaseModelProvider;
Expand All @@ -585,9 +654,8 @@ private IEnumerable<PropertyProvider> GetAllBasePropertiesForConstructorInitiali
{
if (property.IsDiscriminator)
{
// In the case of nested discriminators, we only need to include the direct base discriminator property,
// as this is the only one that will be initialized in this model's constructor.
if (isDirectBase)
// In the case of nested discriminators, include discriminator property based on the parameter
if (isDirectBase || includeAllHierarchyDiscriminator)
{
properties.Peek().Add(property);
}
Expand Down Expand Up @@ -624,16 +692,16 @@ private IEnumerable<FieldProvider> GetAllBaseFieldsForConstructorInitialization(
}

private (IReadOnlyList<ParameterProvider> Parameters, ConstructorInitializer? Initializer) BuildConstructorParameters(
bool isPrimaryConstructor)
bool isInitializationConstructor, bool includeDiscriminatorParameter = false)
{
var baseParameters = new List<ParameterProvider>();
var constructorParameters = new List<ParameterProvider>();
IEnumerable<PropertyProvider> baseProperties = [];
IEnumerable<FieldProvider> baseFields = [];

if (isPrimaryConstructor)
if (isInitializationConstructor)
{
baseProperties = GetAllBasePropertiesForConstructorInitialization();
baseProperties = GetAllBasePropertiesForConstructorInitialization(includeDiscriminatorParameter);
baseFields = GetAllBaseFieldsForConstructorInitialization();
}
else if (BaseModelProvider?.FullConstructor.Signature != null)
Expand All @@ -646,36 +714,61 @@ private IEnumerable<FieldProvider> GetAllBaseFieldsForConstructorInitialization(
// add the base parameters, if any
foreach (var property in baseProperties)
{
AddInitializationParameterForCtor(baseParameters, Type.IsStruct, isPrimaryConstructor, property);
AddInitializationParameterForCtor(baseParameters, Type.IsStruct, isInitializationConstructor, property);
}

// add the base fields, if any
foreach (var field in baseFields)
{
AddInitializationParameterForCtor(baseParameters, Type.IsStruct, isPrimaryConstructor, field: field);
AddInitializationParameterForCtor(baseParameters, Type.IsStruct, isInitializationConstructor, field: field);
}

// construct the initializer using the parameters from base signature
var constructorInitializer = new ConstructorInitializer(true, [.. baseParameters.Select(p => GetExpressionForCtor(p, overriddenProperties, isPrimaryConstructor))]);
ConstructorInitializer? constructorInitializer = null;
if (BaseModelProvider != null)
{
if (baseParameters.Count > 0)
{
// Check if base model has dual constructor pattern and we should call private protected constructor
if (isInitializationConstructor && BaseModelProvider._isMultiLevelDiscriminator)
{
// Call base model's private protected constructor with discriminator value
var args = new List<ValueExpression>();
args.Add(Literal(_inputModel.DiscriminatorValue ?? ""));
var filteredParams = baseParameters.Where(p => p.Property is null || !p.Property.IsDiscriminator).ToList();
args.AddRange(filteredParams.Select(p => GetExpressionForCtor(p, overriddenProperties, isInitializationConstructor)));
constructorInitializer = new ConstructorInitializer(true, args);
}
else
{
// Standard base constructor call
constructorInitializer = new ConstructorInitializer(true, [.. baseParameters.Select(p => GetExpressionForCtor(p, overriddenProperties, isInitializationConstructor))]);
}
}
else
{
// Even when no base parameters, we still need a base constructor call if there's a base model
constructorInitializer = new ConstructorInitializer(true, Array.Empty<ValueExpression>());
}
}

foreach (var property in CanonicalView.Properties)
{
AddInitializationParameterForCtor(constructorParameters, Type.IsStruct, isPrimaryConstructor, property);
AddInitializationParameterForCtor(constructorParameters, Type.IsStruct, isInitializationConstructor, property);
}

foreach (var field in CanonicalView.Fields)
{
AddInitializationParameterForCtor(constructorParameters, Type.IsStruct, isPrimaryConstructor, field: field);
AddInitializationParameterForCtor(constructorParameters, Type.IsStruct, isInitializationConstructor, field: field);
}

constructorParameters.InsertRange(0, _inputModel.IsUnknownDiscriminatorModel
? baseParameters
: baseParameters.Where(p =>
p.Property is null
|| (p.Property.IsDiscriminator && !overriddenProperties.Contains(p.Property) && !isPrimaryConstructor)
|| (!p.Property.IsDiscriminator && !overriddenProperties.Contains(p.Property))));
|| (!overriddenProperties.Contains(p.Property!) && (!p.Property.IsDiscriminator || !isInitializationConstructor || includeDiscriminatorParameter))));

if (!isPrimaryConstructor)
if (!isInitializationConstructor)
{
foreach (var property in AdditionalPropertyProperties)
{
Expand Down
Loading
Loading