Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1156,7 +1156,7 @@ private void ValidateSecondaryConstructor(
// auth should be the only parameter if endpoint is optional when there is auth
if (_hasSupportedAuth)
{
var expectedName = ctorParams?[0].Type?.Equals(ClientPipelineProvider.Instance.TokenCredentialType) == true
var expectedName = ctorParams?[0].Type?.Equals(ClientPipelineProvider.Instance.TokenCredentialType!) == true
? "tokenProvider"
: "credential";
Assert.AreEqual(expectedName, ctorParams?[0].Name);
Expand All @@ -1179,7 +1179,7 @@ private void ValidateSecondaryConstructor(
Assert.AreEqual(KnownParameters.Endpoint.Name, ctorParams?[0].Name);
if (_hasSupportedAuth)
{
var expectedName = ctorParams?[1].Type?.Equals(ClientPipelineProvider.Instance.TokenCredentialType) == true
var expectedName = ctorParams?[1].Type?.Equals(ClientPipelineProvider.Instance.TokenCredentialType!) == true
? "tokenProvider"
: "credential";
Assert.AreEqual(expectedName, ctorParams?[1].Name);
Expand Down Expand Up @@ -2787,9 +2787,7 @@ public async Task BackCompatibility_ProtocolMethodParamOrderChanged()
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

// Use reflection to invoke internal ProcessTypeForBackCompatibility method
var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var methods = clientProvider!.Methods;
var protocolMethods = methods
Expand Down Expand Up @@ -2874,9 +2872,7 @@ public async Task BackCompatibility_ConvenienceMethodParamOrderChanged()
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

// Use reflection to invoke internal ProcessTypeForBackCompatibility method
var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var methods = clientProvider!.Methods;
var convenienceMethods = methods
Expand Down Expand Up @@ -2956,9 +2952,7 @@ public async Task BackCompatibility_BothMethodsParamOrderChanged()
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

// Use reflection to invoke internal ProcessTypeForBackCompatibility method
var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var methods = clientProvider.Methods;
var protocolMethods = methods
Expand Down Expand Up @@ -3055,9 +3049,7 @@ public async Task BackCompatibility_ExactMatchWithCompatibleOverload()
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

// Use reflection to invoke internal ProcessTypeForBackCompatibility method
var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var methods = clientProvider.Methods;
var processDataMethods = methods
Expand Down Expand Up @@ -3143,10 +3135,8 @@ public async Task BackCompatibility_DuplicateMethodSignatureDoesNotThrow()
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

// Use reflection to invoke internal ProcessTypeForBackCompatibility method
// This should not throw even when there are duplicate method signatures
var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
Assert.DoesNotThrow(() => processMethod?.Invoke(clientProvider, null));
Assert.DoesNotThrow(() => clientProvider!.ProcessTypeForBackCompatibility());
}

// Last contract has GetData(int param1, string param2, CancellationToken) (and async).
Expand Down Expand Up @@ -3183,8 +3173,7 @@ public async Task BackCompatibility_NewOptionalNonBodyParameterAdded()
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var writer = new TypeProviderWriter(new FilteredMethodsTypeProvider(clientProvider!, name => name == "GetData" || name == "GetDataAsync"));
var file = writer.Write();
Expand Down Expand Up @@ -3226,8 +3215,7 @@ public async Task BackCompatibility_MultipleNewOptionalNonBodyParametersAdded()
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var writer = new TypeProviderWriter(new FilteredMethodsTypeProvider(clientProvider!, name => name == "GetData" || name == "GetDataAsync"));
var file = writer.Write();
Expand Down Expand Up @@ -3272,8 +3260,7 @@ public async Task BackCompatibility_NewOptionalNonBodyParameterAddedWithModelBod
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var writer = new TypeProviderWriter(new FilteredMethodsTypeProvider(clientProvider!, name => name == "GetData" || name == "GetDataAsync"));
var file = writer.Write();
Expand Down Expand Up @@ -3312,8 +3299,7 @@ public async Task BackCompatibility_NewOptionalBodyParameterDoesNotAddBackCompat
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var writer = new TypeProviderWriter(new FilteredMethodsTypeProvider(clientProvider!, name => name == "GetData" || name == "GetDataAsync"));
var file = writer.Write();
Expand Down Expand Up @@ -3353,8 +3339,7 @@ public async Task BackCompatibility_NewRequiredParameterDoesNotAddBackCompatOver
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var writer = new TypeProviderWriter(new FilteredMethodsTypeProvider(clientProvider!, name => name == "GetData" || name == "GetDataAsync"));
var file = writer.Write();
Expand Down Expand Up @@ -3399,14 +3384,71 @@ public async Task BackCompatibility_NewOptionalNonBodyParameterAddedWithPathAndH
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var writer = new TypeProviderWriter(new FilteredMethodsTypeProvider(clientProvider!, name => name == "GetData" || name == "GetDataAsync"));
var file = writer.Write();
Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

// Verifies that a back-compat overload added by ProcessTypeForBackCompatibility is removed
// when the user has explicitly suppressed it via [CodeGenSuppress] in custom code. Without
// filtering in ProcessTypeForBackCompatibility, the suppressed overload would be re-introduced
// after the FilterAllCustomizedMembers pass, leaking into the generated client.
[Test]
public async Task BackCompatibility_BackCompatOverloadSuppressedByCustomCode()
{
// Last contract had a GetData(int param1, string param2, CancellationToken) method.
// Current method adds an optional non-body parameter `param3`, which would normally
// cause a back-compat overload matching the previous signature to be added.
var param1 = InputFactory.QueryParameter("param1", InputPrimitiveType.Int32, isRequired: true);
var param2 = InputFactory.BodyParameter("param2", InputPrimitiveType.String, isRequired: true);
var param3 = InputFactory.HeaderParameter("param3", InputPrimitiveType.Boolean, isRequired: false);

var operation = InputFactory.Operation(
"GetData",
parameters: [param1, param2, param3],
responses: [InputFactory.OperationResponse([200], bodytype: InputPrimitiveType.String)]);

List<InputMethodParameter> methodParameters =
[
InputFactory.MethodParameter("param1", InputPrimitiveType.Int32, location: InputRequestLocation.Query, isRequired: true),
InputFactory.MethodParameter("param2", InputPrimitiveType.String, location: InputRequestLocation.Body, isRequired: true),
InputFactory.MethodParameter("param3", InputPrimitiveType.Boolean, location: InputRequestLocation.Header, isRequired: false),
];

var method = InputFactory.BasicServiceMethod("GetData", operation, parameters: [.. methodParameters]);
var client = InputFactory.Client(TestClientName, methods: [method]);

var generator = await MockHelpers.LoadMockGeneratorAsync(
clients: () => [client],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync(parameters: "Custom"),
lastContractCompilation: async () => await Helpers.GetCompilationFromDirectoryAsync(parameters: "Last"));

var clientProvider = generator.Object.OutputLibrary.TypeProviders.OfType<ClientProvider>().FirstOrDefault();
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);
Assert.IsNotNull(clientProvider.CustomCodeView);

clientProvider.ProcessTypeForBackCompatibility();

// The current spec method has param3, the back-compat overload would NOT have param3.
// With the fix, the suppressed back-compat overload should not appear in the methods.
var getDataOverloads = clientProvider.Methods
.Where(m => m.Signature.Name == "GetData" && !m.Signature.Parameters.Any(p => p.Name == "param3"))
.ToList();
var getDataAsyncOverloads = clientProvider.Methods
.Where(m => m.Signature.Name == "GetDataAsync" && !m.Signature.Parameters.Any(p => p.Name == "param3"))
.ToList();

Assert.AreEqual(0, getDataOverloads.Count, "Back-compat overload of GetData should be suppressed by [CodeGenSuppress] in custom code.");
Assert.AreEqual(0, getDataAsyncOverloads.Count, "Back-compat overload of GetDataAsync should be suppressed by [CodeGenSuppress] in custom code.");

// The current methods (with param3) must still be present.
Assert.IsTrue(clientProvider.Methods.Any(m => m.Signature.Name == "GetData" && m.Signature.Parameters.Any(p => p.Name == "param3")));
Assert.IsTrue(clientProvider.Methods.Any(m => m.Signature.Name == "GetDataAsync" && m.Signature.Parameters.Any(p => p.Name == "param3")));
}

// Last contract has only protocol methods: GetData(int param1, BinaryContent content, RequestOptions options = null).
// The current TypeSpec adds a new optional non-body query parameter "$select" whose raw input name
// starts with a reserved character.
Expand Down Expand Up @@ -3443,8 +3485,7 @@ public async Task BackCompatibility_NewOptionalParameterWithReservedName()
Assert.IsNotNull(clientProvider);
Assert.IsNotNull(clientProvider!.LastContractView);

var processMethod = typeof(ClientProvider).GetMethod("ProcessTypeForBackCompatibility", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
processMethod?.Invoke(clientProvider, null);
clientProvider!.ProcessTypeForBackCompatibility();

var writer = new TypeProviderWriter(new FilteredMethodsTypeProvider(clientProvider!, name => name == "GetData" || name == "GetDataAsync"));
var file = writer.Write();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System.Threading;
using Microsoft.TypeSpec.Generator.Customizations;

namespace Sample
{
[CodeGenSuppress("GetData", typeof(int), typeof(string), typeof(CancellationToken))]
[CodeGenSuppress("GetDataAsync", typeof(int), typeof(string), typeof(CancellationToken))]
public partial class TestClient
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Threading;
using System.Threading.Tasks;

namespace Sample
{
public partial class TestClient
{
public virtual ClientResult<string> GetData(int param1, string param2, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}

public virtual Task<ClientResult<string>> GetDataAsync(int param1, string param2, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -962,14 +962,14 @@ public TestMrwSerialization(bool implementsPersistableModel, bool includeDepMode

protected override string BuildName() => "TestMrwSerialization";

protected override CSharpType[] BuildImplements()
protected internal override CSharpType[] BuildImplements()
{
return _implementsPersistableModel
? [new CSharpType(typeof(IPersistableModel<object>))]
: [new CSharpType(typeof(IJsonModel<object>))];
}

protected override PropertyProvider[] BuildProperties()
protected internal override PropertyProvider[] BuildProperties()
{
if (!_includeTypeWithDepModelProperty)
{
Expand Down Expand Up @@ -1021,7 +1021,7 @@ public NonMRWModelProvider(InputModelType inputModel) : base(inputModel)
{
}

protected override CSharpType[] BuildImplements()
protected internal override CSharpType[] BuildImplements()
{
// Don't implement MRW interfaces
return [];
Expand All @@ -1040,7 +1040,7 @@ public MRWTypeProvider() : base()
{
}

protected override CSharpType[] BuildImplements()
protected internal override CSharpType[] BuildImplements()
{
// Implement a framework type that does not implement MRW interfaces
return
Expand Down Expand Up @@ -1492,7 +1492,7 @@ public CustomSerializationProvider(bool usePersistableModel = false)
protected override string BuildName() => "CustomSerializationProvider";
protected override string BuildRelativeFilePath() => "CustomSerializationProvider.cs";

protected override CSharpType[] BuildImplements()
protected internal override CSharpType[] BuildImplements()
{
return [new CSharpType(_usePersistableModel ? typeof(IPersistableModel<object>) : typeof(IJsonModel<object>))];
}
Expand Down Expand Up @@ -1529,7 +1529,7 @@ public TestClientProvider(InputModelType returnTypeModel) : base()

protected override string BuildRelativeFilePath() => "TestClient.cs";

protected override MethodProvider[] BuildMethods()
protected internal override MethodProvider[] BuildMethods()
{
var modelProvider = ScmCodeModelGenerator.Instance.TypeFactory.CreateCSharpType(_returnTypeModel);
var returnType = new CSharpType(typeof(System.Threading.Tasks.Task<>), modelProvider!);
Expand Down Expand Up @@ -1560,7 +1560,7 @@ public TestClientProviderWithFrameworkReturnType() : base()

protected override string BuildRelativeFilePath() => "TestClient.cs";

protected override MethodProvider[] BuildMethods()
protected internal override MethodProvider[] BuildMethods()
{
var returnType = new CSharpType(typeof(FrameworkModelWithMRW));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ protected override MethodProvider[] BuildMethods()
return [.. base.BuildMethods().Where(m => m.Signature.Name.StartsWith("Deserialize"))];
}

protected override FieldProvider[] BuildFields() => [];
protected internal override FieldProvider[] BuildFields() => [];
protected override ConstructorProvider[] BuildConstructors() => [];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ protected override MethodProvider[] BuildMethods()
return [.. base.BuildMethods().Where(m => m.Signature.Name.Equals("JsonModelWriteCore"))];
}

protected override FieldProvider[] BuildFields() => [];
protected internal override FieldProvider[] BuildFields() => [];
protected override ConstructorProvider[] BuildConstructors() => [];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ public PropertyAccessTrackingModelProvider(InputModelType inputModel) : base(inp

public bool PropertiesAccessed => _propertiesAccessed;

protected override PropertyProvider[] BuildProperties()
protected internal override PropertyProvider[] BuildProperties()
{
_propertiesAccessed = true;
return base.BuildProperties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ protected override MethodProvider[] BuildMethods()
return [.. base.BuildMethods().Where(m => m.Signature.Name.Equals("PersistableModelWriteCore"))];
}

protected override FieldProvider[] BuildFields() => [];
protected internal override FieldProvider[] BuildFields() => [];
protected override ConstructorProvider[] BuildConstructors() => [];
}

Expand Down
Loading
Loading