Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FilteredStreamMessage.onCancellation() #3375

Merged
merged 15 commits into from Mar 26, 2021
Expand Up @@ -49,6 +49,7 @@ final class HttpDecodedResponse extends FilteredHttpResponse {
@Nullable
private StreamDecoder responseDecoder;
private boolean headersReceived;
private boolean decoderClosed;

HttpDecodedResponse(HttpResponse delegate, Map<String, StreamDecoderFactory> availableDecoders,
ByteBufAllocator alloc, boolean strictContentEncoding) {
Expand Down Expand Up @@ -110,20 +111,43 @@ protected HttpObject filter(HttpObject obj) {

@Override
protected void beforeComplete(Subscriber<? super HttpObject> subscriber) {
if (responseDecoder == null) {
final HttpData lastData = closeResponseDecoder();
if (lastData == null) {
return;
}
final HttpData lastData = responseDecoder.finish();
if (!lastData.isEmpty()) {
subscriber.onNext(lastData);
} else {
lastData.close();
}
}

@Override
protected Throwable beforeError(Subscriber<? super HttpObject> subscriber, Throwable cause) {
if (responseDecoder != null) {
responseDecoder.finish();
final HttpData lastData = closeResponseDecoder();
if (lastData != null) {
lastData.close();
}
return cause;
}

@Override
protected void onCancellation(Subscriber<? super HttpObject> subscriber) {
final HttpData lastData = closeResponseDecoder();
if (lastData != null) {
lastData.close();
}
}

@Nullable
private HttpData closeResponseDecoder() {
if (decoderClosed) {
return null;
}
decoderClosed = true;
if (responseDecoder == null) {
return null;
}
return responseDecoder.finish();
}
}
Expand Up @@ -20,7 +20,6 @@
import static com.linecorp.armeria.common.stream.StreamMessageUtil.containsWithPooledObjects;
import static java.util.Objects.requireNonNull;

import java.util.ArrayList;
import java.util.concurrent.CompletableFuture;

import javax.annotation.Nullable;
Expand All @@ -45,13 +44,11 @@ public abstract class FilteredStreamMessage<T, U> implements StreamMessage<U> {

private static final Logger logger = LoggerFactory.getLogger(FilteredStreamMessage.class);

private static final SubscriptionOption[] EMPTY_OPTIONS = new SubscriptionOption[0];

private final StreamMessage<T> upstream;
private final boolean filterSupportsPooledObjects;

/**
* Creates a new {@link FilteredStreamMessage} that filters objects published by {@code delegate}
* Creates a new {@link FilteredStreamMessage} that filters objects published by {@code upstream}
* before passing to a subscriber.
*/
protected FilteredStreamMessage(StreamMessage<T> upstream) {
Expand All @@ -60,7 +57,7 @@ protected FilteredStreamMessage(StreamMessage<T> upstream) {

/**
* (Advanced users only) Creates a new {@link FilteredStreamMessage} that filters objects published by
* {@code delegate} before passing to a subscriber.
* {@code upstream} before passing to a subscriber.
*
* @param withPooledObjects if {@code true}, {@link #filter(Object)} receives the pooled {@link HttpData}
* as is, without making a copy. If you don't know what this means,
Expand All @@ -69,7 +66,7 @@ protected FilteredStreamMessage(StreamMessage<T> upstream) {
*/
@UnstableApi
protected FilteredStreamMessage(StreamMessage<T> upstream, boolean withPooledObjects) {
this.upstream = requireNonNull(upstream, "delegate");
this.upstream = requireNonNull(upstream, "upstream");
filterSupportsPooledObjects = withPooledObjects;
}

Expand Down Expand Up @@ -103,6 +100,11 @@ protected Throwable beforeError(Subscriber<? super U> subscriber, Throwable caus
return cause;
}

/**
* A callback executed when this {@link StreamMessage} is canceled by the {@link Subscriber}.
*/
protected void onCancellation(Subscriber<? super U> subscriber) {}

@Override
public final boolean isOpen() {
return upstream.isOpen();
Expand Down Expand Up @@ -141,19 +143,14 @@ public final void subscribe(Subscriber<? super U> subscriber, EventExecutor exec

private void subscribe(Subscriber<? super U> subscriber, EventExecutor executor, boolean withPooledObjects,
boolean notifyCancellation) {
upstream.subscribe(new FilteringSubscriber(subscriber, withPooledObjects),
executor, filteringSubscriptionOptions(notifyCancellation));
}

private SubscriptionOption[] filteringSubscriptionOptions(boolean notifyCancellation) {
final ArrayList<SubscriptionOption> list = new ArrayList<>(2);
final FilteringSubscriber filteringSubscriber = new FilteringSubscriber(
subscriber, withPooledObjects, notifyCancellation);
if (filterSupportsPooledObjects) {
list.add(SubscriptionOption.WITH_POOLED_OBJECTS);
}
if (notifyCancellation) {
list.add(SubscriptionOption.NOTIFY_CANCELLATION);
upstream.subscribe(filteringSubscriber, executor,
SubscriptionOption.NOTIFY_CANCELLATION, SubscriptionOption.WITH_POOLED_OBJECTS);
} else {
upstream.subscribe(filteringSubscriber, executor, SubscriptionOption.NOTIFY_CANCELLATION);
}
return list.toArray(EMPTY_OPTIONS);
}

@Override
Expand All @@ -175,13 +172,16 @@ private final class FilteringSubscriber implements Subscriber<T> {

private final Subscriber<? super U> delegate;
private final boolean subscribedWithPooledObjects;
private final boolean notifyCancellation;

@Nullable
private Subscription upstream;

FilteringSubscriber(Subscriber<? super U> delegate, boolean subscribedWithPooledObjects) {
FilteringSubscriber(Subscriber<? super U> delegate, boolean subscribedWithPooledObjects,
boolean notifyCancellation) {
this.delegate = requireNonNull(delegate, "delegate");
this.subscribedWithPooledObjects = subscribedWithPooledObjects;
this.notifyCancellation = notifyCancellation;
}

@Override
Expand All @@ -206,6 +206,7 @@ public void onNext(T o) {
filtered = filter(o);
} catch (Throwable ex) {
StreamMessageUtil.closeOrAbort(o);
assert upstream != null;
upstream.cancel();
onError(ex);
return;
Expand All @@ -219,6 +220,13 @@ public void onNext(T o) {

@Override
public void onError(Throwable t) {
if (t instanceof CancelledSubscriptionException) {
onCancellation(delegate);
if (!notifyCancellation) {
return;
}
}

final Throwable filteredCause = beforeError(delegate, t);
if (filteredCause != null) {
delegate.onError(filteredCause);
Expand Down
Expand Up @@ -20,7 +20,8 @@

import javax.annotation.Nullable;

import org.reactivestreams.Subscriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.linecorp.armeria.common.FilteredHttpRequest;
import com.linecorp.armeria.common.FilteredHttpResponse;
Expand All @@ -39,6 +40,8 @@

public final class ContentPreviewingUtil {

private static final Logger logger = LoggerFactory.getLogger(ContentPreviewingUtil.class);

/**
* Sets up the request {@link ContentPreviewer} to set
* {@link RequestLogBuilder#requestContentPreview(String)} when the preview is available.
Expand All @@ -60,31 +63,27 @@ public static HttpRequest setUpRequestContentPreviewer(RequestContext ctx, HttpR
logBuilder.requestContentPreview(null);
return null;
});
return new FilteredHttpRequest(req) {
final FilteredHttpRequest filteredHttpRequest = new FilteredHttpRequest(req) {
@Override
protected HttpObject filter(HttpObject obj) {
if (obj instanceof HttpData) {
requestContentPreviewer.onData((HttpData) obj);
}
return obj;
}

@Override
protected void beforeComplete(Subscriber<? super HttpObject> subscriber) {
logBuilder.requestContentPreview(requestContentPreviewer.produce());
}

@Override
protected Throwable beforeError(Subscriber<? super HttpObject> subscriber,
Throwable cause) {
// Call produce() to release the resources in the previewer. Consider adding close() method.
requestContentPreviewer.produce();

// Set null to make it sure the log is complete.
logBuilder.requestContentPreview(null);
return cause;
}
};
filteredHttpRequest.whenComplete().handle((unused, cause) -> {
String produced = null;
try {
produced = requestContentPreviewer.produce();
} catch (Exception e) {
logger.warn("Unexpected exception while producing the request content preview. " +
"previewer: {}", requestContentPreviewer, e);
}
logBuilder.requestContentPreview(produced);
return null;
});
return filteredHttpRequest;
}

/**
Expand All @@ -96,54 +95,60 @@ public static HttpResponse setUpResponseContentPreviewer(
requireNonNull(factory, "factory");
requireNonNull(ctx, "ctx");
requireNonNull(res, "res");
return new ContentPreviewerHttpResponse(res, factory, ctx);
}

return new FilteredHttpResponse(res) {
@Nullable
ContentPreviewer responseContentPreviewer;
private static class ContentPreviewerHttpResponse extends FilteredHttpResponse {

@Override
protected HttpObject filter(HttpObject obj) {
if (obj instanceof ResponseHeaders) {
final ResponseHeaders resHeaders = (ResponseHeaders) obj;

// Skip informational headers.
final String status = resHeaders.get(HttpHeaderNames.STATUS);
if (ArmeriaHttpUtil.isInformational(status)) {
return obj;
}
final ContentPreviewer contentPreviewer = factory.responseContentPreviewer(ctx, resHeaders);
if (!contentPreviewer.isDisabled()) {
responseContentPreviewer = contentPreviewer;
}
} else if (obj instanceof HttpData) {
if (responseContentPreviewer != null) {
responseContentPreviewer.onData((HttpData) obj);
}
}
return obj;
}
private final ContentPreviewerFactory factory;
private final RequestContext ctx;
@Nullable
ContentPreviewer responseContentPreviewer;

@Override
protected void beforeComplete(Subscriber<? super HttpObject> subscriber) {
protected ContentPreviewerHttpResponse(HttpResponse delegate, ContentPreviewerFactory factory,
RequestContext ctx) {
super(delegate);
this.factory = factory;
this.ctx = ctx;
whenComplete().handle((unused, cause) -> {
if (responseContentPreviewer != null) {
ctx.logBuilder().responseContentPreview(responseContentPreviewer.produce());
String produced = null;
try {
produced = responseContentPreviewer.produce();
} catch (Exception e) {
logger.warn("Unexpected exception while producing the response content preview. " +
"previewer: {}", responseContentPreviewer, e);
}
ctx.logBuilder().responseContentPreview(produced);
} else {
// Call requestContentPreview(null) to make sure that the log is complete.
ctx.logBuilder().responseContentPreview(null);
}
}
return null;
});
}

@Override
protected Throwable beforeError(Subscriber<? super HttpObject> subscriber, Throwable cause) {
@Override
protected HttpObject filter(HttpObject obj) {
if (obj instanceof ResponseHeaders) {
final ResponseHeaders resHeaders = (ResponseHeaders) obj;

// Skip informational headers.
final String status = resHeaders.get(HttpHeaderNames.STATUS);
if (ArmeriaHttpUtil.isInformational(status)) {
return obj;
}
final ContentPreviewer contentPreviewer = factory.responseContentPreviewer(ctx, resHeaders);
if (!contentPreviewer.isDisabled()) {
responseContentPreviewer = contentPreviewer;
}
} else if (obj instanceof HttpData) {
if (responseContentPreviewer != null) {
// Call produce() to release the resources in the previewer. Consider adding close() method.
responseContentPreviewer.produce();
responseContentPreviewer.onData((HttpData) obj);
}
// Set null to make it sure the log is complete.
ctx.logBuilder().responseContentPreview(null);
return cause;
}
};
return obj;
}
}

private ContentPreviewingUtil() {}
Expand Down
Expand Up @@ -16,6 +16,8 @@

package com.linecorp.armeria.server.encoding;

import javax.annotation.Nullable;

import org.reactivestreams.Subscriber;

import com.linecorp.armeria.common.FilteredHttpRequest;
Expand All @@ -34,6 +36,8 @@ final class HttpDecodedRequest extends FilteredHttpRequest {

private final StreamDecoder responseDecoder;

private boolean decoderFinished;

HttpDecodedRequest(HttpRequest delegate, StreamDecoderFactory decoderFactory,
ByteBufAllocator alloc) {
super(delegate);
Expand All @@ -51,15 +55,40 @@ protected HttpObject filter(HttpObject obj) {

@Override
protected void beforeComplete(Subscriber<? super HttpObject> subscriber) {
final HttpData lastData = responseDecoder.finish();
final HttpData lastData = closeResponseDecoder();
if (lastData == null) {
return;
}
if (!lastData.isEmpty()) {
subscriber.onNext(lastData);
} else {
lastData.close();
}
}

@Override
protected Throwable beforeError(Subscriber<? super HttpObject> subscriber, Throwable cause) {
responseDecoder.finish();
final HttpData lastData = closeResponseDecoder();
if (lastData != null) {
lastData.close();
}
return cause;
}

@Override
protected void onCancellation(Subscriber<? super HttpObject> subscriber) {
final HttpData lastData = closeResponseDecoder();
if (lastData != null) {
lastData.close();
}
}

@Nullable
private HttpData closeResponseDecoder() {
if (decoderFinished) {
return null;
}
decoderFinished = true;
return responseDecoder.finish();
}
}