Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of gRPC trailers-only responses #101

Merged
merged 8 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,36 +49,38 @@ import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import org.testcontainers.containers.GenericContainer
import org.testcontainers.containers.wait.strategy.HostPortWaitStrategy
import java.time.Duration
import java.util.Base64
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit


@RunWith(Parameterized::class)
class Conformance(
private val protocol: NetworkProtocol
private val clientProtocol: NetworkProtocol,
private val serverProtocol: NetworkProtocol
pkwarren marked this conversation as resolved.
Show resolved Hide resolved
) {
private lateinit var connectClient: ProtocolClient
private lateinit var protocolClient: ProtocolClient
private lateinit var shortTimeoutConnectClient: ProtocolClient
private lateinit var unimplementedServiceClient: UnimplementedServiceClient
private lateinit var testServiceConnectClient: TestServiceClient
private lateinit var testServiceClient: TestServiceClient

companion object {
const val CONFORMANCE_VERSION = "0b07f579cb61ad89de24524d62f804a2b03b1acf"
private const val CONFORMANCE_VERSION = "0b07f579cb61ad89de24524d62f804a2b03b1acf"

@JvmStatic
@Parameters(name = "protocol")
fun data(): Iterable<NetworkProtocol> {
@Parameters(name = "client={0},server={1}")
fun data(): Iterable<Array<NetworkProtocol>> {
return arrayListOf(
NetworkProtocol.CONNECT,
NetworkProtocol.GRPC
arrayOf(NetworkProtocol.CONNECT, NetworkProtocol.CONNECT),
arrayOf(NetworkProtocol.GRPC, NetworkProtocol.CONNECT),
arrayOf(NetworkProtocol.GRPC, NetworkProtocol.GRPC)
pkwarren marked this conversation as resolved.
Show resolved Hide resolved
)
}

@JvmField
@ClassRule
val CONFORMANCE_CONTAINER = GenericContainer("connectrpc/conformance:$CONFORMANCE_VERSION")
val CONFORMANCE_CONTAINER_CONNECT = GenericContainer("connectrpc/conformance:$CONFORMANCE_VERSION")
.withExposedPorts(8080, 8081)
.withCommand(
"/usr/local/bin/serverconnect",
Expand All @@ -91,11 +93,28 @@ class Conformance(
"--key",
"cert/localhost.key"
)
.waitingFor(HostPortWaitStrategy().forPorts(8081))
pkwarren marked this conversation as resolved.
Show resolved Hide resolved

@JvmField
@ClassRule
val CONFORMANCE_CONTAINER_GRPC = GenericContainer("connectrpc/conformance:$CONFORMANCE_VERSION")
pkwarren marked this conversation as resolved.
Show resolved Hide resolved
.withExposedPorts(8081)
.withCommand(
"/usr/local/bin/servergrpc",
"--port",
"8081",
"--cert",
"cert/localhost.crt",
"--key",
"cert/localhost.key"
)
.waitingFor(HostPortWaitStrategy().forPorts(8081))
}

@Before
fun before() {
val host = "https://localhost:${CONFORMANCE_CONTAINER.getMappedPort(8081)}"
val serverPort = if (serverProtocol == NetworkProtocol.CONNECT) CONFORMANCE_CONTAINER_CONNECT.getMappedPort(8081) else CONFORMANCE_CONTAINER_GRPC.getMappedPort(8081)
val host = "https://localhost:$serverPort"
val (sslSocketFactory, trustManager) = sslContext()
val client = OkHttpClient.Builder()
.protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1))
Expand All @@ -117,23 +136,23 @@ class Conformance(
ProtocolClientConfig(
host = host,
serializationStrategy = GoogleJavaProtobufStrategy(),
networkProtocol = protocol,
networkProtocol = clientProtocol,
requestCompression = RequestCompression(10, GzipCompressionPool),
compressionPools = listOf(GzipCompressionPool)
)
)
connectClient = ProtocolClient(
protocolClient = ProtocolClient(
httpClient = ConnectOkHttpClient(client),
ProtocolClientConfig(
host = host,
serializationStrategy = GoogleJavaProtobufStrategy(),
networkProtocol = protocol,
networkProtocol = clientProtocol,
requestCompression = RequestCompression(10, GzipCompressionPool),
compressionPools = listOf(GzipCompressionPool)
)
)
testServiceConnectClient = TestServiceClient(connectClient)
unimplementedServiceClient = UnimplementedServiceClient(connectClient)
testServiceClient = TestServiceClient(protocolClient)
unimplementedServiceClient = UnimplementedServiceClient(protocolClient)
}

@Test
Expand All @@ -143,7 +162,7 @@ class Conformance(
reason = "soirée 🎉"
domain = "connect-conformance"
}
val stream = testServiceConnectClient.failStreamingOutputCall()
val stream = testServiceClient.failStreamingOutputCall()
val sizes = listOf(
31415,
9,
Expand Down Expand Up @@ -189,7 +208,7 @@ class Conformance(
@Test
fun emptyUnary(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.emptyCall(Empty.newBuilder().build()) { response ->
testServiceClient.emptyCall(Empty.newBuilder().build()) { response ->
response.failure {
fail<Unit>("expected error to be null")
}
Expand All @@ -212,7 +231,7 @@ class Conformance(
}
}
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.unaryCall(message) { response ->
testServiceClient.unaryCall(message) { response ->
response.failure {
fail<Unit>("expected error to be null")
}
Expand Down Expand Up @@ -242,7 +261,7 @@ class Conformance(
payload = payload { body = ByteString.copyFrom(ByteArray(size)) }
}
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.unaryCall(message, headers) { response ->
testServiceClient.unaryCall(message, headers) { response ->
assertThat(response.code).isEqualTo(Code.OK)
assertThat(response.headers[leadingKey]).containsExactly(leadingValue)
assertThat(response.trailers[trailingKey]).containsExactly(b64Encode(trailingValue))
Expand All @@ -267,7 +286,7 @@ class Conformance(
}
}
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.unaryCall(message) { response ->
testServiceClient.unaryCall(message) { response ->
assertThat(response.code).isEqualTo(Code.UNKNOWN)
response.failure { errorResponse ->
assertThat(errorResponse.error).isNotNull()
Expand Down Expand Up @@ -326,7 +345,7 @@ class Conformance(
val statusMessage =
"\\t\\ntest with whitespace\\r\\nand Unicode BMP ☺ and non-BMP \uD83D\uDE08\\t\\n"
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.unaryCall(
testServiceClient.unaryCall(
simpleRequest {
responseStatus = echoStatus {
code = 2
Expand All @@ -352,7 +371,7 @@ class Conformance(
@Test
fun unimplementedMethod(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.unimplementedCall(Empty.newBuilder().build()) { response ->
testServiceClient.unimplementedCall(Empty.newBuilder().build()) { response ->
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
countDownLatch.countDown()
}
Expand All @@ -378,7 +397,7 @@ class Conformance(
domain = "connect-conformance"
}
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.failUnaryCall(simpleRequest {}) { response ->
testServiceClient.failUnaryCall(simpleRequest {}) { response ->
assertThat(response.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
response.failure { errorResponse ->
val error = errorResponse.error
Expand All @@ -398,7 +417,7 @@ class Conformance(

@Test
fun emptyUnaryBlocking(): Unit = runBlocking {
val response = testServiceConnectClient.emptyCallBlocking(Empty.newBuilder().build()).execute()
val response = testServiceClient.emptyCallBlocking(Empty.newBuilder().build()).execute()
response.failure {
fail<Unit>("expected error to be null")
}
Expand All @@ -416,7 +435,7 @@ class Conformance(
body = ByteString.copyFrom(ByteArray(size))
}
}
val response = testServiceConnectClient.unaryCallBlocking(message).execute()
val response = testServiceClient.unaryCallBlocking(message).execute()
response.failure {
fail<Unit>("expected error to be null")
}
Expand All @@ -441,7 +460,7 @@ class Conformance(
responseSize = size
payload = payload { body = ByteString.copyFrom(ByteArray(size)) }
}
val response = testServiceConnectClient.unaryCallBlocking(message, headers).execute()
val response = testServiceClient.unaryCallBlocking(message, headers).execute()
assertThat(response.code).isEqualTo(Code.OK)
assertThat(response.headers[leadingKey]).containsExactly(leadingValue)
assertThat(response.trailers[trailingKey]).containsExactly(b64Encode(trailingValue))
Expand All @@ -461,7 +480,7 @@ class Conformance(
message = "test status message"
}
}
val response = testServiceConnectClient.unaryCallBlocking(message).execute()
val response = testServiceClient.unaryCallBlocking(message).execute()
assertThat(response.code).isEqualTo(Code.UNKNOWN)
response.failure { errorResponse ->
assertThat(errorResponse.error).isNotNull()
Expand All @@ -477,7 +496,7 @@ class Conformance(
fun specialStatusBlocking(): Unit = runBlocking {
val statusMessage =
"\\t\\ntest with whitespace\\r\\nand Unicode BMP ☺ and non-BMP \uD83D\uDE08\\t\\n"
val response = testServiceConnectClient.unaryCallBlocking(
val response = testServiceClient.unaryCallBlocking(
simpleRequest {
responseStatus = echoStatus {
code = 2
Expand All @@ -498,7 +517,7 @@ class Conformance(

@Test
fun unimplementedMethodBlocking(): Unit = runBlocking {
val response = testServiceConnectClient.unimplementedCallBlocking(Empty.newBuilder().build()).execute()
val response = testServiceClient.unimplementedCallBlocking(Empty.newBuilder().build()).execute()
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
}

Expand All @@ -514,7 +533,7 @@ class Conformance(
reason = "soirée 🎉"
domain = "connect-conformance"
}
val response = testServiceConnectClient.failUnaryCallBlocking(simpleRequest {}).execute()
val response = testServiceClient.failUnaryCallBlocking(simpleRequest {}).execute()
assertThat(response.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
response.failure { errorResponse ->
val error = errorResponse.error
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal class ConnectInterceptor(
trailers.putAll(response.headers.toTrailers())
trailers.putAll(response.trailers)
val responseHeaders =
response.headers.filter { entry -> !entry.key.startsWith("trailer") }.toMutableMap()
response.headers.filter { entry -> !entry.key.startsWith("trailer") }
pkwarren marked this conversation as resolved.
Show resolved Hide resolved
val compressionPool = clientConfig.compressionPool(responseHeaders[CONTENT_ENCODING]?.first())
val (code, connectError) = if (response.code != Code.OK) {
val error = parseConnectUnaryError(code = response.code, response.headers, response.message.buffer)
Expand Down Expand Up @@ -150,7 +150,7 @@ internal class ConnectInterceptor(
val streamResult: StreamResult<Buffer> = res.fold(
onHeaders = { result ->
val responseHeaders =
result.headers.filter { entry -> !entry.key.startsWith("trailer") }.toMutableMap()
result.headers.filter { entry -> !entry.key.startsWith("trailer") }
responseCompressionPool =
clientConfig.compressionPool(responseHeaders[CONNECT_STREAMING_CONTENT_ENCODING]?.first())
StreamResult.Headers(responseHeaders)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,30 +72,31 @@ internal class GRPCInterceptor(
)
},
responseFunction = { response ->
val headers = response.headers
val trailers = response.trailers
val completion = completionParser.parse(trailers)
// Handle Trailers-Only response: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#responses
val completion = completionParser.parse(trailers.ifEmpty { headers })
pkwarren marked this conversation as resolved.
Show resolved Hide resolved
val code = completion?.code ?: Code.UNKNOWN
val responseHeaders = response.headers.toMutableMap()
if (response.code != Code.OK) {
return@UnaryFunction HTTPResponse(
code = response.code,
headers = response.headers.toMutableMap(),
headers = headers,
message = Buffer(),
trailers = trailers,
error = response.error,
tracingInfo = response.tracingInfo
)
}
val compressionPool =
clientConfig.compressionPool(responseHeaders[GRPC_ENCODING]?.first())
clientConfig.compressionPool(headers[GRPC_ENCODING]?.first())
if (code == Code.OK) {
val (_, message) = Envelope.unpackWithHeaderByte(
response.message.buffer,
compressionPool
)
HTTPResponse(
code = code,
headers = responseHeaders,
headers = headers,
message = message,
trailers = trailers,
error = response.error,
Expand All @@ -109,7 +110,7 @@ internal class GRPCInterceptor(
}
HTTPResponse(
code = code,
headers = responseHeaders,
headers = headers,
message = result,
trailers = trailers,
error = ConnectError(
Expand Down Expand Up @@ -142,7 +143,7 @@ internal class GRPCInterceptor(
streamResultFunction = { res ->
val streamResult = res.fold(
onHeaders = { result ->
val responseHeaders = result.headers.filter { entry -> !entry.key.startsWith("trailer") }.toMutableMap()
val responseHeaders = result.headers.filter { entry -> !entry.key.startsWith("trailer") }
pkwarren marked this conversation as resolved.
Show resolved Hide resolved
responseCompressionPool = clientConfig.compressionPool(responseHeaders[GRPC_ENCODING]?.first())
StreamResult.Headers(responseHeaders)
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,19 @@ internal class GRPCWebInterceptor(
)
},
responseFunction = { response ->
val headers = response.headers
if (response.code != Code.OK) {
return@UnaryFunction HTTPResponse(
code = response.code,
headers = response.headers.toMutableMap(),
headers = headers,
message = Buffer(),
trailers = emptyMap(),
error = response.error,
tracingInfo = response.tracingInfo
)
}
val responseHeaders = response.headers.toMutableMap()
val compressionPool =
clientConfig.compressionPool(responseHeaders[GRPC_ENCODING]?.first())
clientConfig.compressionPool(headers[GRPC_ENCODING]?.first())
// gRPC Web returns data in 2 chunks (either/both of which may be compressed):
// 1. OPTIONAL (when not trailers-only): The (headers and length prefixed)
// message data.
Expand All @@ -106,7 +106,7 @@ internal class GRPCWebInterceptor(
}
HTTPResponse(
code = code,
headers = responseHeaders,
headers = headers,
message = result,
trailers = trailers,
error = ConnectError(
Expand Down Expand Up @@ -160,7 +160,7 @@ internal class GRPCWebInterceptor(
}
HTTPResponse(
code = finalCode,
headers = responseHeaders,
headers = headers,
message = unpacked,
trailers = finalTrailers,
error = error,
Expand Down Expand Up @@ -188,7 +188,7 @@ internal class GRPCWebInterceptor(
streamResultFunction = { res ->
val streamResult = res.fold(
onHeaders = { result ->
val responseHeaders = result.headers.filter { entry -> !entry.key.startsWith("trailer") }.toMutableMap()
val responseHeaders = result.headers.filter { entry -> !entry.key.startsWith("trailer") }
responseCompressionPool = clientConfig.compressionPool(responseHeaders[GRPC_ENCODING]?.first())
// Trailers are passed in the headers for GRPC.
val streamTrailers: Trailers = responseHeaders
Expand Down