diff --git a/global.json b/global.json index ee681bf..bb62055 100644 --- a/global.json +++ b/global.json @@ -1,6 +1,6 @@ { "sdk": { - "version": "10.0.200", + "version": "10.0.102", "rollForward": "latestMinor" } } diff --git a/src/SwaggerProvider.DesignTime/v3/OperationCompiler.fs b/src/SwaggerProvider.DesignTime/v3/OperationCompiler.fs index 2eb5fcc..33b8705 100644 --- a/src/SwaggerProvider.DesignTime/v3/OperationCompiler.fs +++ b/src/SwaggerProvider.DesignTime/v3/OperationCompiler.fs @@ -60,7 +60,7 @@ type PayloadType = /// Object for compiling operations. type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler, ignoreControllerPrefix, ignoreOperationId, asAsync: bool) = - let compileOperation (providedMethodName: string) (apiCall: ApiCall) = + let compileOperation (providedMethodName: string) (apiCall: ApiCall) (includeCancellationToken: bool) = let path, pathItem, opTy = apiCall let operation = pathItem.Operations[opTy] @@ -178,7 +178,16 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler, // reverse it again so that all required properties come first |> List.rev - payloadTy.ToMediaType(), providedParameters + let parameters = + if includeCancellationToken then + let ctParam = + ProvidedParameter("cancellationToken", typeof) + + providedParameters @ [ ctParam ] + else + providedParameters + + payloadTy.ToMediaType(), parameters // find the inner type value let retMimeAndTy = @@ -264,8 +273,21 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler, // Locates parameters matching the arguments let mutable payloadExp = None + // When the CancellationToken overload is generated, CancellationToken is always appended last. + // Extract it by position to avoid name-collision issues and invalid Expr.Coerce + // on a struct type (which generates an invalid castclass IL instruction). + let apiArgs, ct = + let allArgs = List.tail args // skip `this` + + if includeCancellationToken then + match List.rev allArgs with + | ctArg :: revApiArgs -> List.rev revApiArgs, Expr.Cast(ctArg) + | [] -> failwith "Expected CancellationToken argument but argument list was empty" + else + allArgs, <@ Threading.CancellationToken.None @> + let parameters = - List.tail args // skip `this` param + apiArgs |> List.choose (function | ShapeVar sVar as expr -> let param = @@ -392,7 +414,7 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler, @> let action = - <@ (%this).CallAsync(%httpRequestMessageWithPayload, errorCodes, errorDescriptions) @> + <@ (%this).CallAsync(%httpRequestMessageWithPayload, errorCodes, errorDescriptions, %ct) @> let responseObj = let innerReturnType = defaultArg retTy null @@ -591,7 +613,7 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler, let methodNameScope = UniqueNameGenerator() operations - |> List.map(fun op -> + |> List.collect(fun op -> let skipLength = if String.IsNullOrEmpty clientName then 0 @@ -599,5 +621,11 @@ type OperationCompiler(schema: OpenApiDocument, defCompiler: DefinitionCompiler, clientName.Length + 1 let name = OperationCompiler.GetMethodNameCandidate op skipLength ignoreOperationId - compileOperation (methodNameScope.MakeUnique name) op) + let uniqueName = methodNameScope.MakeUnique name + // Generate two overloads: one without CancellationToken (backward compatible) + // and one with an explicit CancellationToken parameter. + // We cannot use an optional struct parameter with a default value because + // struct values (e.g., CancellationToken.None) cannot be stored in DefaultParameterValue + // custom attributes. + [ compileOperation uniqueName op false; compileOperation uniqueName op true ]) |> ty.AddMembers) diff --git a/src/SwaggerProvider.Runtime/ProvidedApiClientBase.fs b/src/SwaggerProvider.Runtime/ProvidedApiClientBase.fs index 44f12cf..0dd4511 100644 --- a/src/SwaggerProvider.Runtime/ProvidedApiClientBase.fs +++ b/src/SwaggerProvider.Runtime/ProvidedApiClientBase.fs @@ -45,8 +45,13 @@ type ProvidedApiClientBase(httpClient: HttpClient, options: JsonSerializerOption JsonSerializer.Deserialize(value, retTy, options) member this.CallAsync(request: HttpRequestMessage, errorCodes: string[], errorDescriptions: string[]) : Task = + this.CallAsync(request, errorCodes, errorDescriptions, System.Threading.CancellationToken.None) + + member this.CallAsync + (request: HttpRequestMessage, errorCodes: string[], errorDescriptions: string[], cancellationToken: System.Threading.CancellationToken) + : Task = task { - let! response = this.HttpClient.SendAsync(request) + let! response = this.HttpClient.SendAsync(request, cancellationToken) if response.IsSuccessStatusCode then return response.Content diff --git a/tests/SwaggerProvider.ProviderTests/SwaggerProvider.ProviderTests.fsproj b/tests/SwaggerProvider.ProviderTests/SwaggerProvider.ProviderTests.fsproj index ecb8b27..14c5674 100644 --- a/tests/SwaggerProvider.ProviderTests/SwaggerProvider.ProviderTests.fsproj +++ b/tests/SwaggerProvider.ProviderTests/SwaggerProvider.ProviderTests.fsproj @@ -30,6 +30,7 @@ + diff --git a/tests/SwaggerProvider.ProviderTests/v3/Swashbuckle.CancellationToken.Tests.fs b/tests/SwaggerProvider.ProviderTests/v3/Swashbuckle.CancellationToken.Tests.fs new file mode 100644 index 0000000..d54ec99 --- /dev/null +++ b/tests/SwaggerProvider.ProviderTests/v3/Swashbuckle.CancellationToken.Tests.fs @@ -0,0 +1,43 @@ +module Swashbuckle.v3.CancellationTokenTests + +open Xunit +open FsUnitTyped +open System +open System.Threading +open Swashbuckle.v3.ReturnControllersTests + +[] +let ``Call generated method with explicit CancellationToken None``() = + task { + let! result = api.GetApiReturnBoolean(CancellationToken.None) + result |> shouldEqual true + } + +[] +let ``Call generated method with valid CancellationTokenSource token``() = + task { + use cts = new CancellationTokenSource() + let! result = api.GetApiReturnInt32(cts.Token) + result |> shouldEqual 42 + } + +[] +let ``Call generated method with already-cancelled token raises OperationCanceledException``() = + task { + use cts = new CancellationTokenSource() + cts.Cancel() + + try + let! _ = api.GetApiReturnString(cts.Token) + failwith "Expected OperationCanceledException" + with + | :? OperationCanceledException -> () + | :? System.AggregateException as aex when (aex.InnerException :? OperationCanceledException) -> () + } + +[] +let ``Call POST generated method with explicit CancellationToken None``() = + task { + let! result = api.PostApiReturnString(CancellationToken.None) + result |> shouldEqual "Hello world" + } diff --git a/tests/SwaggerProvider.Tests/RuntimeHelpersTests.fs b/tests/SwaggerProvider.Tests/RuntimeHelpersTests.fs index 1933a32..f534bd2 100644 --- a/tests/SwaggerProvider.Tests/RuntimeHelpersTests.fs +++ b/tests/SwaggerProvider.Tests/RuntimeHelpersTests.fs @@ -372,7 +372,8 @@ module ToContentTests = type private StubHttpMessageHandler(statusCode: HttpStatusCode, responseBody: string) = inherit HttpMessageHandler() - override _.SendAsync(_request: HttpRequestMessage, _cancellationToken: CancellationToken) = + override _.SendAsync(_request: HttpRequestMessage, cancellationToken: CancellationToken) = + cancellationToken.ThrowIfCancellationRequested() let response = new HttpResponseMessage(statusCode) response.Content <- new StringContent(responseBody) Task.FromResult(response) @@ -495,3 +496,34 @@ module OpenApiExceptionTests = () } + + [] + let ``CallAsync with CancellationToken returns content on success``() = + task { + use handler = new StubHttpMessageHandler(HttpStatusCode.OK, "result") + let client = makeClient handler + use request = new HttpRequestMessage(HttpMethod.Get, "http://stub/pets/1") + let! content = client.CallAsync(request, [||], [||], CancellationToken.None) + let! body = content.ReadAsStringAsync() + body |> shouldEqual "result" + } + + [] + let ``CallAsync with already-cancelled token raises OperationCanceledException``() = + task { + use cts = new CancellationTokenSource() + cts.Cancel() + + use handler = new StubHttpMessageHandler(HttpStatusCode.OK, "ok") + let client = makeClient handler + use request = new HttpRequestMessage(HttpMethod.Get, "http://stub/pets/1") + + let! _ = + Assert.ThrowsAnyAsync(fun () -> + task { + let! _ = client.CallAsync(request, [||], [||], cts.Token) + () + }) + + () + }