Skip to content

Commit

Permalink
Clean up the implementation, as this seems the simplest solution
Browse files Browse the repository at this point in the history
  • Loading branch information
niloc132 committed Aug 29, 2022
1 parent c9677b6 commit 7063333
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 166 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package io.grpc.servlet.jakarta.web;

import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;

/**
* Util class to allow the complete() call to get some work done (writing trailers as a payload) before
* calling the actual container implementation. The container will finish closing the stream before
* invoking the async listener and formally informing the filter that the stream has closed, making
* this our last chance to intercept the closing of the stream before it happens.
*/
public class DelegatingAsyncContext implements AsyncContext {
private final AsyncContext delegate;

public DelegatingAsyncContext(AsyncContext delegate) {
this.delegate = delegate;
}

@Override
public ServletRequest getRequest() {
return delegate.getRequest();
}

@Override
public ServletResponse getResponse() {
return delegate.getResponse();
}

@Override
public boolean hasOriginalRequestAndResponse() {
return delegate.hasOriginalRequestAndResponse();
}

@Override
public void dispatch() {
delegate.dispatch();
}

@Override
public void dispatch(String path) {
delegate.dispatch(path);
}

@Override
public void dispatch(ServletContext context, String path) {
delegate.dispatch(context, path);
}

@Override
public void complete() {
delegate.complete();
}

@Override
public void start(Runnable run) {
delegate.start(run);
}

@Override
public void addListener(AsyncListener listener) {
delegate.addListener(listener);
}

@Override
public void addListener(AsyncListener listener, ServletRequest servletRequest,
ServletResponse servletResponse) {
delegate.addListener(listener, servletRequest, servletResponse);
}

@Override
public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
return delegate.createListener(clazz);
}

@Override
public void setTimeout(long timeout) {
delegate.setTimeout(timeout);
}

@Override
public long getTimeout() {
return delegate.getTimeout();
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import io.grpc.internal.GrpcUtil;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
Expand All @@ -16,10 +14,19 @@

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.function.Supplier;
import java.util.regex.Pattern;

/**
* Servlet filter that translates grpc-web on the fly to match what is expected by GrpcServlet.
* This work is done in-process with no addition copies to the request or response data - only
* the content type header and the trailer content is specially treated at this time.
*
* Note that grpc-web-text is not yet supported.
*/
public class GrpcWebFilter extends HttpFilter {
public static final String CONTENT_TYPE_GRPC_WEB = GrpcUtil.CONTENT_TYPE_GRPC + "-web";

Expand All @@ -39,7 +46,7 @@ public String getContentType() {

@Override
public AsyncContext startAsync() throws IllegalStateException {
return super.startAsync(this, wrappedResponse);
return startAsync(this, wrappedResponse);
}

@Override
Expand All @@ -49,25 +56,40 @@ public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse se
return new DelegatingAsyncContext(delegate) {
@Override
public void complete() {
// Write any trailers out to the output stream as a payload, since grpc-web doesn't
// use proper trailers.
try {
wrappedResponse.finish();
if (wrappedResponse.trailers != null) {
Map<String, String> map = wrappedResponse.trailers.get();
if (map != null) {
// write a payload, even for an empty set of trailers, but not for
// the absence of trailers.
int trailerLength = map.entrySet().stream().mapToInt(e -> e.getKey().length() + e.getValue().length() + 4).sum();
ByteBuffer payload = ByteBuffer.allocate(5 + trailerLength);
payload.put((byte) 0x80);
payload.putInt(trailerLength);
for (Map.Entry<String, String> entry : map.entrySet()) {
payload.put(entry.getKey().getBytes(StandardCharsets.US_ASCII));
payload.put(": ".getBytes(StandardCharsets.US_ASCII));
payload.put(entry.getValue().getBytes(StandardCharsets.US_ASCII));
payload.put("\r\n".getBytes(StandardCharsets.US_ASCII));
}
wrappedResponse.getOutputStream().write(payload.array());
}
}
} catch (IOException e) {
// TODO reconsider this, find a better way to report
throw new UncheckedIOException(e);
}

// Let the superclass complete the stream so we formally close it
super.complete();
}
};
}
};

try {
chain.doFilter(wrappedRequest, wrappedResponse);
} finally {
// if (request.isAsyncStarted()) {
// request.getAsyncContext().addListener(new GrpcWebAsyncListener(wrappedResponse));
// }
}
chain.doFilter(wrappedRequest, wrappedResponse);
} else {
chain.doFilter(request, response);
}
Expand All @@ -80,7 +102,6 @@ private static boolean isGrpcWeb(ServletRequest request) {
// Technically we should throw away content-length too, but the impl won't care
public static class GrpcWebHttpResponse extends HttpServletResponseWrapper {
private Supplier<Map<String, String>> trailers;
private GrpcWebServletOutputStream outputStream;

public GrpcWebHttpResponse(HttpServletResponse response) {
super(response);
Expand All @@ -93,19 +114,6 @@ public void setContentType(String type) {
type.replaceFirst(Pattern.quote(GrpcUtil.CONTENT_TYPE_GRPC), CONTENT_TYPE_GRPC_WEB));
}

@Override
public GrpcWebServletOutputStream getOutputStream() throws IOException {
if (outputStream == null) {
outputStream = new GrpcWebServletOutputStream(super.getOutputStream());
}
return outputStream;
}

@Override
public void flushBuffer() throws IOException {
super.getOutputStream().flush();
}

// intercept trailers and write them out as a message just before we complete
@Override
public void setTrailerFields(Supplier<Map<String, String>> supplier) {
Expand All @@ -116,86 +124,5 @@ public void setTrailerFields(Supplier<Map<String, String>> supplier) {
public Supplier<Map<String, String>> getTrailerFields() {
return trailers;
}

public void finish() throws IOException {
// write any trailers out to the output stream
getOutputStream().writeTrailers(trailers);
}


}

private static class DelegatingAsyncContext implements AsyncContext {
private final AsyncContext delegate;

private DelegatingAsyncContext(AsyncContext delegate) {
this.delegate = delegate;
}

@Override
public ServletRequest getRequest() {
return delegate.getRequest();
}

@Override
public ServletResponse getResponse() {
return delegate.getResponse();
}

@Override
public boolean hasOriginalRequestAndResponse() {
return delegate.hasOriginalRequestAndResponse();
}

@Override
public void dispatch() {
delegate.dispatch();
}

@Override
public void dispatch(String path) {
delegate.dispatch(path);
}

@Override
public void dispatch(ServletContext context, String path) {
delegate.dispatch(context, path);
}

@Override
public void complete() {
delegate.complete();
}

@Override
public void start(Runnable run) {
delegate.start(run);
}

@Override
public void addListener(AsyncListener listener) {
delegate.addListener(listener);
}

@Override
public void addListener(AsyncListener listener, ServletRequest servletRequest,
ServletResponse servletResponse) {
delegate.addListener(listener, servletRequest, servletResponse);
}

@Override
public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
return delegate.createListener(clazz);
}

@Override
public void setTimeout(long timeout) {
delegate.setTimeout(timeout);
}

@Override
public long getTimeout() {
return delegate.getTimeout();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,28 +59,4 @@ public void flush() throws IOException {
public void close() throws IOException {
super.close();
}

public void writeTrailers(Supplier<Map<String, String>> trailers) throws IOException {
// probably could inline this and drop the class if we don't have to frame other messages
if (trailers == null) {
return;
}
Map<String, String> map = trailers.get();
if (map == null) {
return;
}
// write a payload, even for an empty set of trailers
int trailerLength =
map.entrySet().stream().mapToInt(e -> e.getKey().length() + e.getValue().length() + 4).sum();
ByteBuffer payload = ByteBuffer.allocate(5 + trailerLength);
payload.put((byte) 0x80);
payload.putInt(trailerLength);
for (Map.Entry<String, String> entry : map.entrySet()) {
payload.put(entry.getKey().getBytes(StandardCharsets.US_ASCII));
payload.put(": ".getBytes(StandardCharsets.US_ASCII));
payload.put(entry.getValue().getBytes(StandardCharsets.US_ASCII));
payload.put("\r\n".getBytes(StandardCharsets.US_ASCII));
}
wrapped.write(payload.array());
}
}

0 comments on commit 7063333

Please sign in to comment.