Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FilteredStreamMessage.onCancellation() #3375

Merged
merged 15 commits into from
Mar 26, 2021
Expand Up @@ -21,6 +21,7 @@
import javax.annotation.Nullable;

import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import com.google.common.base.Ascii;

Expand Down Expand Up @@ -113,4 +114,11 @@ protected Throwable beforeError(Subscriber<? super HttpObject> subscriber, Throw
}
return cause;
}

@Override
protected void beforeCancel(Subscription subscription) {
if (responseDecoder != null) {
responseDecoder.finish();
}
}
}
Expand Up @@ -103,6 +103,13 @@ protected Throwable beforeError(Subscriber<? super U> subscriber, Throwable caus
return cause;
}

/**
* A callback executed just before calling the upstream {@link Subscription#cancel()}.
* Override this method to execute any cleanup logic that may be needed before completing or failing the
* subscription.
*/
protected void beforeCancel(Subscription subscription) {}

@Override
public final boolean isOpen() {
return upstream.isOpen();
Expand Down Expand Up @@ -184,7 +191,7 @@ private final class FilteringSubscriber implements Subscriber<T> {
@Override
public void onSubscribe(Subscription s) {
beforeSubscribe(delegate, s);
delegate.onSubscribe(s);
delegate.onSubscribe(new SubscriptionWrapper(s));
}

@Override
Expand Down Expand Up @@ -216,4 +223,24 @@ public void onComplete() {
delegate.onComplete();
}
}

private final class SubscriptionWrapper implements Subscription {

private final Subscription subscription;

SubscriptionWrapper(Subscription s) {
subscription = s;
}

@Override
public void request(long n) {
subscription.request(n);
}

@Override
public void cancel() {
beforeCancel(subscription);
subscription.cancel();
}
}
}
Expand Up @@ -21,6 +21,9 @@
import javax.annotation.Nullable;

import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.linecorp.armeria.common.FilteredHttpRequest;
import com.linecorp.armeria.common.FilteredHttpResponse;
Expand All @@ -39,6 +42,8 @@

public final class ContentPreviewingUtil {

private static final Logger logger = LoggerFactory.getLogger(ContentPreviewingUtil.class);

/**
* Sets up the request {@link ContentPreviewer} to set
* {@link RequestLogBuilder#requestContentPreview(String)} when the preview is available.
Expand Down Expand Up @@ -71,19 +76,40 @@ protected HttpObject filter(HttpObject obj) {

@Override
protected void beforeComplete(Subscriber<? super HttpObject> subscriber) {
logBuilder.requestContentPreview(requestContentPreviewer.produce());
produceRequestContentPreview(requestContentPreviewer);
}

@Override
protected Throwable beforeError(Subscriber<? super HttpObject> subscriber,
Throwable cause) {
// Call produce() to release the resources in the previewer. Consider adding close() method.
requestContentPreviewer.produce();
try {
// Call produce() to release the resources in the previewer. Consider adding close() method.
requestContentPreviewer.produce();
} catch (Exception e) {
logger.warn("Unexpected exception while producing the request content preview. " +
"previewer: {}", requestContentPreviewer, e);
}

// Set null to make it sure the log is complete.
logBuilder.requestContentPreview(null);
return cause;
}

@Override
protected void beforeCancel(Subscription subscription) {
produceRequestContentPreview(requestContentPreviewer);
}

private void produceRequestContentPreview(ContentPreviewer requestContentPreviewer) {
minwoox marked this conversation as resolved.
Show resolved Hide resolved
String produced = null;
try {
produced = requestContentPreviewer.produce();
} catch (Exception e) {
logger.warn("Unexpected exception while producing the request content preview. " +
"previewer: {}", requestContentPreviewer, e);
}
logBuilder.requestContentPreview(produced);
}
};
}

Expand Down Expand Up @@ -125,12 +151,7 @@ protected HttpObject filter(HttpObject obj) {

@Override
protected void beforeComplete(Subscriber<? super HttpObject> subscriber) {
if (responseContentPreviewer != null) {
ctx.logBuilder().responseContentPreview(responseContentPreviewer.produce());
} else {
// Call requestContentPreview(null) to make sure that the log is complete.
ctx.logBuilder().responseContentPreview(null);
}
produceResponseContentPreview(responseContentPreviewer);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any chance that subscriber.onComplete() and subscription.cancel() are called concurrently?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could have an atomic flag, then we won't even need the cause instanceof CancelledSubscriptionException checks.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized that if we introduce an atomic flag, a user needs to do the extra work because beforeCancel() and beforeError() can be called twice in certain situation.
So let me just remove beforeCancel() and handle the cancelation event in beforeError.

}

@Override
Expand All @@ -143,6 +164,27 @@ protected Throwable beforeError(Subscriber<? super HttpObject> subscriber, Throw
ctx.logBuilder().responseContentPreview(null);
return cause;
}

@Override
protected void beforeCancel(Subscription subscription) {
produceResponseContentPreview(responseContentPreviewer);
}

private void produceResponseContentPreview(@Nullable ContentPreviewer responseContentPreviewer) {
minwoox marked this conversation as resolved.
Show resolved Hide resolved
if (responseContentPreviewer != null) {
String produced = null;
try {
produced = responseContentPreviewer.produce();
} catch (Exception e) {
logger.warn("Unexpected exception while producing the response content preview. " +
"previewer: {}", responseContentPreviewer, e);
}
ctx.logBuilder().responseContentPreview(produced);
} else {
// Call requestContentPreview(null) to make sure that the log is complete.
ctx.logBuilder().responseContentPreview(null);
}
}
};
}

Expand Down
Expand Up @@ -17,6 +17,7 @@
package com.linecorp.armeria.server.encoding;

import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import com.linecorp.armeria.common.FilteredHttpRequest;
import com.linecorp.armeria.common.HttpData;
Expand Down Expand Up @@ -62,4 +63,9 @@ protected Throwable beforeError(Subscriber<? super HttpObject> subscriber, Throw
responseDecoder.finish();
return cause;
}

@Override
protected void beforeCancel(Subscription subscription) {
responseDecoder.finish();
minwoox marked this conversation as resolved.
Show resolved Hide resolved
}
}
Expand Up @@ -26,6 +26,7 @@
import javax.annotation.Nullable;

import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -155,6 +156,11 @@ protected Throwable beforeError(Subscriber<? super HttpObject> subscriber, Throw
return cause;
}

@Override
protected void beforeCancel(Subscription subscription) {
closeEncoder();
minwoox marked this conversation as resolved.
Show resolved Hide resolved
}

private void closeEncoder() {
if (encodingStream == null) {
return;
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Expand Up @@ -17,6 +17,12 @@
package com.linecorp.armeria.client.encoding;

import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand All @@ -36,8 +42,10 @@
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpObject;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpResponseWriter;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.encoding.StreamDecoder;
import com.linecorp.armeria.common.stream.SubscriptionOption;

import io.netty.buffer.ByteBuf;
Expand Down Expand Up @@ -164,6 +172,26 @@ void pooledPayload_pooledDrain_withOldDecoder() {
assertThat(decodedPayloadBuf.refCnt()).isZero();
}

@Test
void streamDecoderFinishedIsCalledWhenRequestCanceled() throws InterruptedException {
final HttpResponseWriter response = HttpResponse.streaming();
response.write(ResponseHeaders.of(HttpStatus.OK, HttpHeaderNames.CONTENT_ENCODING, "foo"));
final HttpData data = HttpData.ofUtf8("bar");
response.write(data);

final com.linecorp.armeria.common.encoding.StreamDecoderFactory factory = mock(
com.linecorp.armeria.common.encoding.StreamDecoderFactory.class);
final com.linecorp.armeria.common.encoding.StreamDecoder streamDecoder = mock(StreamDecoder.class);
when(factory.newDecoder(any())).thenReturn(streamDecoder);
when(streamDecoder.decode(any())).thenReturn(data);

final HttpResponse decoded = new HttpDecodedResponse(response, ImmutableMap.of("foo", factory),
ByteBufAllocator.DEFAULT);
decoded.subscribe(new CancelSubscriber());

await().untilAsserted(() -> verify(streamDecoder, times(1)).finish());
}

private static HttpData responseData(HttpResponse decoded, boolean withPooledObjects) {
final CompletableFuture<HttpData> future = new CompletableFuture<>();
final Subscriber<HttpObject> subscriber = new Subscriber<HttpObject>() {
Expand Down Expand Up @@ -194,4 +222,28 @@ public void onComplete() {}

return future.join();
}

private static class CancelSubscriber implements Subscriber<HttpObject> {

private Subscription subscription;

@Override
public void onSubscribe(Subscription s) {
subscription = s;
s.request(Long.MAX_VALUE);
}

@Override
public void onNext(HttpObject httpObject) {
if (httpObject instanceof HttpData) {
subscription.cancel();
}
}

@Override
public void onError(Throwable t) {}

@Override
public void onComplete() {}
}
}