diff --git a/CHANGELOG.md b/CHANGELOG.md index 995516452..b099f8eb4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added support for cancelling requests #361 + ### Changed - Bumps Azure Core from 1.20.0 to 1.22.0 #359, #360, #341, #342 diff --git a/src/main/java/com/microsoft/graph/http/CoreHttpCallbackFutureWrapper.java b/src/main/java/com/microsoft/graph/http/CoreHttpCallbackFutureWrapper.java index 1bb930411..4cb926660 100644 --- a/src/main/java/com/microsoft/graph/http/CoreHttpCallbackFutureWrapper.java +++ b/src/main/java/com/microsoft/graph/http/CoreHttpCallbackFutureWrapper.java @@ -1,9 +1,12 @@ package com.microsoft.graph.http; import java.io.IOException; - +import java.util.Objects; +import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; +import javax.annotation.Nonnull; + import okhttp3.Call; import okhttp3.Callback; import okhttp3.Response; @@ -13,6 +16,14 @@ */ class CoreHttpCallbackFutureWrapper implements Callback { final CompletableFuture future = new CompletableFuture<>(); + public CoreHttpCallbackFutureWrapper(@Nonnull final Call call) { + Objects.requireNonNull(call); + future.whenComplete((r, ex) -> { + if (ex != null && (ex instanceof InterruptedException || ex instanceof CancellationException)) { + call.cancel(); + } + }); + } @Override public void onFailure(Call arg0, IOException arg1) { future.completeExceptionally(arg1); diff --git a/src/main/java/com/microsoft/graph/http/CoreHttpProvider.java b/src/main/java/com/microsoft/graph/http/CoreHttpProvider.java index 3b50842c5..6ffa45028 100644 --- a/src/main/java/com/microsoft/graph/http/CoreHttpProvider.java +++ b/src/main/java/com/microsoft/graph/http/CoreHttpProvider.java @@ -51,6 +51,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; +import okhttp3.Call; import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; @@ -375,8 +376,9 @@ private java.util.concurrent.CompletableFuture handler) throws ClientException { final Request coreHttpRequest = getHttpRequest(request, resultClass, serializable); - final CoreHttpCallbackFutureWrapper wrapper = new CoreHttpCallbackFutureWrapper(); - corehttpClient.newCall(coreHttpRequest).enqueue(wrapper); + final Call call = corehttpClient.newCall(coreHttpRequest); + final CoreHttpCallbackFutureWrapper wrapper = new CoreHttpCallbackFutureWrapper(call); + call.enqueue(wrapper); return wrapper.future.thenApply(r -> processResponse(r, request, resultClass, serializable, handler)); } /** diff --git a/src/test/java/com/microsoft/graph/http/CoreHttpCallbackFutureWrapperTests.java b/src/test/java/com/microsoft/graph/http/CoreHttpCallbackFutureWrapperTests.java new file mode 100644 index 000000000..83b1d1f54 --- /dev/null +++ b/src/test/java/com/microsoft/graph/http/CoreHttpCallbackFutureWrapperTests.java @@ -0,0 +1,46 @@ +package com.microsoft.graph.http; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +import org.junit.jupiter.api.Test; + +import okhttp3.Call; +import okhttp3.Response; + +class CoreHttpCallbackFutureWrapperTests { + + @Test + void throwsIfCallIsNull() { + assertThrows(NullPointerException.class, () -> new CoreHttpCallbackFutureWrapper(null)); + } + boolean isCanceled = false; + + @Test + void cancelsCall() { + var call = mock(Call.class); + doAnswer(i -> { + isCanceled = true; + return null; + }).when(call).cancel(); + var wrapper = new CoreHttpCallbackFutureWrapper(call); + wrapper.future.cancel(true); + assertTrue(isCanceled); + } + + @Test + void returnsResponseWhenCompleted() throws IOException, InterruptedException, ExecutionException { + var call = mock(Call.class); + var response = mock(Response.class); + var wrapper = new CoreHttpCallbackFutureWrapper(call); + wrapper.onResponse(call, response); + assertEquals(response, wrapper.future.get()); + } + +}