diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs index 17c09257962..f24c7c279f5 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/RestClientProvider.cs @@ -431,22 +431,16 @@ private IEnumerable AppendHeaderParameters(HttpRequestApi r statement = request.SetHeaders([Literal(inputHeaderParameter.SerializedName), toStringExpression.As()]); } - if (!TryGetSpecialHeaderParam(inputHeaderParameter, out _) && (!inputHeaderParameter.IsRequired || type?.IsNullable == true || - (type is { IsValueType: false, IsFrameworkType: true } && type.FrameworkType != typeof(string)))) + // If this is a Content-Type header and there's an optional content parameter, wrap in content null check + if (inputHeaderParameter.IsContentType && contentParam != null && + operation.Parameters.Any(p => p is InputBodyParameter bodyParam && !bodyParam.IsRequired)) { - statement = BuildQueryOrHeaderOrPathParameterNullCheck(type, valueExpression, statement); + statement = new IfStatement(contentParam.NotEqual(Null)) { statement }; } - // If this is a Content-Type header and there's an optional content parameter, wrap in content null check - else if (inputHeaderParameter.IsContentType && contentParam != null) + else if (!TryGetSpecialHeaderParam(inputHeaderParameter, out _) && (!inputHeaderParameter.IsRequired || type?.IsNullable == true || + (type is { IsValueType: false, IsFrameworkType: true } && type.FrameworkType != typeof(string)))) { - // Check if any body parameter in the operation is optional - var hasOptionalBody = operation.Parameters.Any(p => - p is InputBodyParameter bodyParam && !bodyParam.IsRequired); - - if (hasOptionalBody) - { - statement = new IfStatement(contentParam.NotEqual(Null)) { statement }; - } + statement = BuildQueryOrHeaderOrPathParameterNullCheck(type, valueExpression, statement); } statements.Add(statement); diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs index 0111de0dc0e..2a9638dab4b 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/RestClientProviderTests.cs @@ -1655,8 +1655,6 @@ public void TestApiVersionParameterReinjectedInCreateNextRequestMethod() [Test] public void ContentTypeHeaderWrappedInNullCheckWhenContentIsOptional() { - // Test that when there's an optional body parameter with a Content-Type header, - // the Content-Type header setting is wrapped in a null check for the content parameter var contentTypeParam = InputFactory.HeaderParameter( "Content-Type", InputFactory.Literal.String("application/json"), @@ -1671,37 +1669,21 @@ public void ContentTypeHeaderWrappedInNullCheckWhenContentIsOptional() "TestOperation", requestMediaTypes: ["application/json"], parameters: [contentTypeParam, bodyParam]); - var inputServiceMethod = InputFactory.BasicServiceMethod("Test", operation); - var inputClient = InputFactory.Client("TestClient", methods: [inputServiceMethod]); - MockHelpers.LoadMockGenerator(clients: () => [inputClient]); - - var client = ScmCodeModelGenerator.Instance.TypeFactory.CreateClient(inputClient); - Assert.IsNotNull(client); - - var restClient = client!.RestClient; - Assert.IsNotNull(restClient); - - var createMethod = restClient.Methods.FirstOrDefault(m => m.Signature.Name == "CreateTestOperationRequest"); - Assert.IsNotNull(createMethod, "CreateTestOperationRequest method not found"); + var inputClient = InputFactory.Client( + "TestClient", + methods: [InputFactory.BasicServiceMethod("Test", operation)]); - var statements = createMethod!.BodyStatements as MethodBodyStatements; - Assert.IsNotNull(statements); + var clientProvider = new ClientProvider(inputClient); + var restClientProvider = new MockClientProvider(inputClient, clientProvider); - var expectedStatement = @"if ((content != null)) -{ - request.Headers.Set(""Content-Type"", ""application/json""); -} -"; - var statementsString = string.Join("\n", statements!.Select(s => s.ToDisplayString())); - Assert.IsTrue(statements!.Any(s => s.ToDisplayString() == expectedStatement), - $"Expected to find statement:\n{expectedStatement}\nBut got statements:\n{statementsString}"); + var writer = new TypeProviderWriter(restClientProvider); + var file = writer.Write(); + Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content); } [Test] public void ContentTypeHeaderNotWrappedInNullCheckWhenContentIsRequired() { - // Test that when there's a required body parameter with a Content-Type header, - // the Content-Type header setting is NOT wrapped in a null check var contentTypeParam = InputFactory.HeaderParameter( "Content-Type", InputFactory.Literal.String("application/json"), @@ -1716,32 +1698,45 @@ public void ContentTypeHeaderNotWrappedInNullCheckWhenContentIsRequired() "TestOperation", requestMediaTypes: ["application/json"], parameters: [contentTypeParam, bodyParam]); - var inputServiceMethod = InputFactory.BasicServiceMethod("Test", operation); - var inputClient = InputFactory.Client("TestClient", methods: [inputServiceMethod]); - MockHelpers.LoadMockGenerator(clients: () => [inputClient]); + var inputClient = InputFactory.Client( + "TestClient", + methods: [InputFactory.BasicServiceMethod("Test", operation)]); - var client = ScmCodeModelGenerator.Instance.TypeFactory.CreateClient(inputClient); - Assert.IsNotNull(client); + var clientProvider = new ClientProvider(inputClient); + var restClientProvider = new MockClientProvider(inputClient, clientProvider); - var restClient = client!.RestClient; - Assert.IsNotNull(restClient); + var writer = new TypeProviderWriter(restClientProvider); + var file = writer.Write(); + Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content); + } - var createMethod = restClient.Methods.FirstOrDefault(m => m.Signature.Name == "CreateTestOperationRequest"); - Assert.IsNotNull(createMethod, "CreateTestOperationRequest method not found"); + [Test] + public void ContentTypeHeaderWrappedInNullCheckWhenContentTypeIsOptional() + { + var contentTypeParam = InputFactory.HeaderParameter( + "Content-Type", + InputFactory.Literal.String("application/xml"), + isRequired: false, + isContentType: true, + scope: InputParameterScope.Constant); + var bodyParam = InputFactory.BodyParameter( + "body", + InputPrimitiveType.String, + isRequired: false); + var operation = InputFactory.Operation( + "TestOperation", + requestMediaTypes: ["application/xml"], + parameters: [contentTypeParam, bodyParam]); + var inputClient = InputFactory.Client( + "TestClient", + methods: [InputFactory.BasicServiceMethod("Test", operation)]); - var statements = createMethod!.BodyStatements as MethodBodyStatements; - Assert.IsNotNull(statements); + var clientProvider = new ClientProvider(inputClient); + var restClientProvider = new MockClientProvider(inputClient, clientProvider); - // Verify there's no if statement wrapping the Content-Type header - var wrappedStatement = @"if ((content != null)) -{ - request.Headers.Set(""Content-Type"", ""application/json""); -} -"; - var statementsString = string.Join("\n", statements!.Select(s => s.ToDisplayString())); - var hasIfWrappedContentType = statements!.Any(s => s.ToDisplayString().Contains(wrappedStatement)); - Assert.IsFalse(hasIfWrappedContentType, - $"Content-Type should NOT be wrapped in an if statement for required content, but found:\n{statementsString}"); + var writer = new TypeProviderWriter(restClientProvider); + var file = writer.Write(); + Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content); } [Test] diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeHeaderNotWrappedInNullCheckWhenContentIsRequired.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeHeaderNotWrappedInNullCheckWhenContentIsRequired.cs new file mode 100644 index 00000000000..7444c570591 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeHeaderNotWrappedInNullCheckWhenContentIsRequired.cs @@ -0,0 +1,24 @@ +// + +#nullable disable + +using System.ClientModel; +using System.ClientModel.Primitives; + +namespace Sample +{ + public partial class TestClient + { + internal global::System.ClientModel.Primitives.PipelineMessage CreateTestOperationRequest(global::System.ClientModel.BinaryContent content, global::System.ClientModel.Primitives.RequestOptions options) + { + global::Sample.ClientUriBuilder uri = new global::Sample.ClientUriBuilder(); + uri.Reset(_endpoint); + global::System.ClientModel.Primitives.PipelineMessage message = Pipeline.CreateMessage(uri.ToUri(), "GET", PipelineMessageClassifier200); + global::System.ClientModel.Primitives.PipelineRequest request = message.Request; + request.Headers.Set("Content-Type", "application/json"); + request.Content = content; + message.Apply(options); + return message; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeHeaderWrappedInNullCheckWhenContentIsOptional.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeHeaderWrappedInNullCheckWhenContentIsOptional.cs new file mode 100644 index 00000000000..08028ebda4d --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeHeaderWrappedInNullCheckWhenContentIsOptional.cs @@ -0,0 +1,27 @@ +// + +#nullable disable + +using System.ClientModel; +using System.ClientModel.Primitives; + +namespace Sample +{ + public partial class TestClient + { + internal global::System.ClientModel.Primitives.PipelineMessage CreateTestOperationRequest(global::System.ClientModel.BinaryContent content, global::System.ClientModel.Primitives.RequestOptions options) + { + global::Sample.ClientUriBuilder uri = new global::Sample.ClientUriBuilder(); + uri.Reset(_endpoint); + global::System.ClientModel.Primitives.PipelineMessage message = Pipeline.CreateMessage(uri.ToUri(), "GET", PipelineMessageClassifier200); + global::System.ClientModel.Primitives.PipelineRequest request = message.Request; + if ((content != null)) + { + request.Headers.Set("Content-Type", "application/json"); + } + request.Content = content; + message.Apply(options); + return message; + } + } +} diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeHeaderWrappedInNullCheckWhenContentTypeIsOptional.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeHeaderWrappedInNullCheckWhenContentTypeIsOptional.cs new file mode 100644 index 00000000000..47ab29490c4 --- /dev/null +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/RestClientProviders/TestData/RestClientProviderTests/ContentTypeHeaderWrappedInNullCheckWhenContentTypeIsOptional.cs @@ -0,0 +1,27 @@ +// + +#nullable disable + +using System.ClientModel; +using System.ClientModel.Primitives; + +namespace Sample +{ + public partial class TestClient + { + internal global::System.ClientModel.Primitives.PipelineMessage CreateTestOperationRequest(global::System.ClientModel.BinaryContent content, global::System.ClientModel.Primitives.RequestOptions options) + { + global::Sample.ClientUriBuilder uri = new global::Sample.ClientUriBuilder(); + uri.Reset(_endpoint); + global::System.ClientModel.Primitives.PipelineMessage message = Pipeline.CreateMessage(uri.ToUri(), "GET", PipelineMessageClassifier200); + global::System.ClientModel.Primitives.PipelineRequest request = message.Request; + if ((content != null)) + { + request.Headers.Set("Content-Type", "application/xml"); + } + request.Content = content; + message.Apply(options); + return message; + } + } +}