Skip to content

Commit

Permalink
Fix handling of gRPC trailers-only responses (#101)
Browse files Browse the repository at this point in the history
connect-kotlin wasn't properly handling gRPC trailers-only responses,
leading to the inability to read grpc-status headers properly. Update
the GRPCCompletionParser to first look for `grpc-status` in headers then
trailers to handle these cases.

Update the connect-kotlin conformance tests to run combinations of both
Connect/gRPC/gRPC-Web protocols and Connect/gRPC servers instead of just
Connect/gRPC against a Connect server (which always sends trailers). By
enabling this earlier, we would've detected the trailers-only issue
earlier with the conformance test.

Fix gRPC protocol handlers to not filter out headers with `trailer-`
prefix - this should only happen for the Connect protocol. Stop sending
`TE: trailers` on gRPC-web requests (they don't use trailers). Preserve
header value case in gRPC-web (leading to errors parsing base64-encoded
error details).
  • Loading branch information
pkwarren authored Sep 19, 2023
1 parent 68c8d2b commit 2fd1600
Show file tree
Hide file tree
Showing 15 changed files with 211 additions and 140 deletions.
1 change: 1 addition & 0 deletions conformance/google-java/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies {
testImplementation(libs.mockito)
testImplementation(libs.kotlin.coroutines.core)
testImplementation(libs.testcontainers)
testImplementation(libs.slf4j.simple)
}

configure<SpotlessExtension> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import com.connectrpc.impl.ProtocolClient
import com.connectrpc.okhttp.ConnectOkHttpClient
import com.connectrpc.protocols.NetworkProtocol
import com.google.protobuf.ByteString
import com.google.protobuf.Empty
import com.google.protobuf.empty
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
Expand All @@ -50,36 +49,39 @@ import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import org.testcontainers.containers.GenericContainer
import org.testcontainers.containers.wait.strategy.HostPortWaitStrategy
import java.time.Duration
import java.util.Base64
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit


@RunWith(Parameterized::class)
class Conformance(
private val protocol: NetworkProtocol
private val protocol: NetworkProtocol,
private val serverType: ServerType
) {
private lateinit var connectClient: ProtocolClient
private lateinit var shortTimeoutConnectClient: ProtocolClient
private lateinit var unimplementedServiceClient: UnimplementedServiceClient
private lateinit var testServiceConnectClient: TestServiceClient

companion object {
const val CONFORMANCE_VERSION = "0b07f579cb61ad89de24524d62f804a2b03b1acf"
private const val CONFORMANCE_VERSION = "0b07f579cb61ad89de24524d62f804a2b03b1acf"

@JvmStatic
@Parameters(name = "protocol")
fun data(): Iterable<NetworkProtocol> {
@Parameters(name = "client={0},server={1}")
fun data(): Iterable<Array<Any>> {
return arrayListOf(
NetworkProtocol.CONNECT,
NetworkProtocol.GRPC
arrayOf(NetworkProtocol.CONNECT, ServerType.CONNECT_GO),
arrayOf(NetworkProtocol.GRPC, ServerType.CONNECT_GO),
arrayOf(NetworkProtocol.GRPC_WEB, ServerType.CONNECT_GO),
arrayOf(NetworkProtocol.GRPC, ServerType.GRPC_GO)
)
}

@JvmField
@ClassRule
val CONFORMANCE_CONTAINER = GenericContainer("connectrpc/conformance:$CONFORMANCE_VERSION")
val CONFORMANCE_CONTAINER_CONNECT = GenericContainer("connectrpc/conformance:$CONFORMANCE_VERSION")
.withExposedPorts(8080, 8081)
.withCommand(
"/usr/local/bin/serverconnect",
Expand All @@ -92,11 +94,28 @@ class Conformance(
"--key",
"cert/localhost.key"
)
.waitingFor(HostPortWaitStrategy().forPorts(8081))

@JvmField
@ClassRule
val CONFORMANCE_CONTAINER_GRPC = GenericContainer("connectrpc/conformance:$CONFORMANCE_VERSION")
.withExposedPorts(8081)
.withCommand(
"/usr/local/bin/servergrpc",
"--port",
"8081",
"--cert",
"cert/localhost.crt",
"--key",
"cert/localhost.key"
)
.waitingFor(HostPortWaitStrategy().forPorts(8081))
}

@Before
fun before() {
val host = "https://localhost:${CONFORMANCE_CONTAINER.getMappedPort(8081)}"
val serverPort = if (serverType == ServerType.CONNECT_GO) CONFORMANCE_CONTAINER_CONNECT.getMappedPort(8081) else CONFORMANCE_CONTAINER_GRPC.getMappedPort(8081)
val host = "https://localhost:$serverPort"
val (sslSocketFactory, trustManager) = sslContext()
val client = OkHttpClient.Builder()
.protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1))
Expand Down Expand Up @@ -168,14 +187,17 @@ class Conformance(
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
// For some reason we keep timing out on these calls and not actually getting a real response like with grpc?
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.message).isEqualTo("soirée 🎉")
assertThat(result.connectError()!!.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail
)
countDownLatch.countDown()
try {
// For some reason we keep timing out on these calls and not actually getting a real response like with grpc?
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.message).isEqualTo("soirée 🎉")
assertThat(result.connectError()!!.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail
)
} finally {
countDownLatch.countDown()
}
}
)
}
Expand All @@ -190,12 +212,12 @@ class Conformance(
@Test
fun emptyUnary(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.emptyCall(Empty.newBuilder().build()) { response ->
testServiceConnectClient.emptyCall(empty {}) { response ->
response.failure {
fail<Unit>("expected error to be null")
}
response.success { success ->
assertThat(success.message).isEqualTo(Empty.newBuilder().build())
assertThat(success.message).isEqualTo(empty {})
countDownLatch.countDown()
}
}
Expand Down Expand Up @@ -306,10 +328,13 @@ class Conformance(
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
assertThat(result.error).isNotNull()
assertThat(result.connectError()!!.code).isEqualTo(Code.DEADLINE_EXCEEDED)
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED)
countDownLatch.countDown()
try {
assertThat(result.error).isNotNull()
assertThat(result.connectError()!!.code).isEqualTo(Code.DEADLINE_EXCEEDED)
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED)
} finally {
countDownLatch.countDown()
}
}
)
}
Expand Down Expand Up @@ -353,7 +378,7 @@ class Conformance(
@Test
fun unimplementedMethod(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.unimplementedCall(Empty.newBuilder().build()) { response ->
testServiceConnectClient.unimplementedCall(empty {}) { response ->
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
countDownLatch.countDown()
}
Expand All @@ -364,14 +389,41 @@ class Conformance(
@Test
fun unimplementedService(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
unimplementedServiceClient.unimplementedCall(Empty.newBuilder().build()) { response ->
unimplementedServiceClient.unimplementedCall(empty {}) { response ->
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
countDownLatch.countDown()
}
countDownLatch.await(500, TimeUnit.MILLISECONDS)
assertThat(countDownLatch.count).isZero()
}

@Test
fun unimplementedServerStreamingService(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
val stream = unimplementedServiceClient.unimplementedStreamingOutputCall()
stream.send(empty { })
withContext(Dispatchers.IO) {
val job = async {
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
try {
assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.UNIMPLEMENTED)
} finally {
countDownLatch.countDown()
}
}
)
}
}
countDownLatch.await(5, TimeUnit.SECONDS)
job.cancel()
assertThat(countDownLatch.count).isZero()
stream.close()
}
}

@Test
fun failUnary(): Unit = runBlocking {
val expectedErrorDetail = errorDetail {
Expand Down Expand Up @@ -399,12 +451,12 @@ class Conformance(

@Test
fun emptyUnaryBlocking(): Unit = runBlocking {
val response = testServiceConnectClient.emptyCallBlocking(Empty.newBuilder().build()).execute()
val response = testServiceConnectClient.emptyCallBlocking(empty {}).execute()
response.failure {
fail<Unit>("expected error to be null")
}
response.success { success ->
assertThat(success.message).isEqualTo(Empty.newBuilder().build())
assertThat(success.message).isEqualTo(empty {})
}
}

Expand Down Expand Up @@ -499,13 +551,13 @@ class Conformance(

@Test
fun unimplementedMethodBlocking(): Unit = runBlocking {
val response = testServiceConnectClient.unimplementedCallBlocking(Empty.newBuilder().build()).execute()
val response = testServiceConnectClient.unimplementedCallBlocking(empty {}).execute()
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
}

@Test
fun unimplementedServiceBlocking(): Unit = runBlocking {
val response = unimplementedServiceClient.unimplementedCallBlocking(Empty.newBuilder().build()).execute()
val response = unimplementedServiceClient.unimplementedCallBlocking(empty {}).execute()
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright 2022-2023 The Connect Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.connectrpc.conformance

enum class ServerType {
CONNECT_GO,
GRPC_GO,
}
2 changes: 2 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ moshi = "1.15.0"
okhttp = "4.10.0"
okio = "3.0.0"
protobuf = "3.24.3"
slf4j = "1.7.36"

[libraries]
android = { module = "com.google.android:android", version.ref = "android" }
Expand Down Expand Up @@ -49,6 +50,7 @@ protobuf-java-util = { module = "com.google.protobuf:protobuf-java-util", versio
protobuf-javalite = { module = "com.google.protobuf:protobuf-javalite", version.ref = "protobuf" }
protobuf-kotlin = { module = "com.google.protobuf:protobuf-kotlin", version.ref = "protobuf" }
protobuf-kotlinlite = { module = "com.google.protobuf:protobuf-kotlin-lite", version.ref = "protobuf" }
slf4j-simple = { module = "org.slf4j:slf4j-simple", version.ref = "slf4j" }
spotless = { module = "com.diffplug.spotless:spotless-plugin-gradle", version = "6.13.0" }
testcontainers = { module = "org.testcontainers:testcontainers", version = "1.19.0" }

Expand Down
2 changes: 1 addition & 1 deletion library/src/main/kotlin/com/connectrpc/StreamResult.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.connectrpc
*
* A typical stream receives [Headers] > [Message] > [Message] > [Message] ... > [Complete]
*/
sealed class StreamResult<Output> constructor(
sealed class StreamResult<Output>(
val error: Throwable? = null
) {
// Headers have been received over the stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ object GzipCompressionPool : CompressionPool {

override fun decompress(buffer: Buffer): Buffer {
val result = Buffer()
val source = GzipSource(buffer)
while (source.read(result, Int.MAX_VALUE.toLong()) != -1L) {
// continue reading.
GzipSource(buffer).use {
while (it.read(result, Int.MAX_VALUE.toLong()) != -1L) {
// continue reading.
}
}
return result
}
Expand Down
2 changes: 0 additions & 2 deletions library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import com.connectrpc.http.HTTPClientInterface
import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.Stream
import com.connectrpc.protocols.GETConfiguration
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.suspendCancellableCoroutine
import java.net.URL
Expand Down Expand Up @@ -162,7 +161,6 @@ class ProtocolClient(
return ClientOnlyStream(stream)
}

@OptIn(ExperimentalCoroutinesApi::class)
private suspend fun <Input : Any, Output : Any> bidirectionalStream(
methodSpec: MethodSpec<Input, Output>,
headers: Headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal class ConnectInterceptor(
trailers.putAll(response.headers.toTrailers())
trailers.putAll(response.trailers)
val responseHeaders =
response.headers.filter { entry -> !entry.key.startsWith("trailer") }.toMutableMap()
response.headers.filter { entry -> !entry.key.startsWith("trailer-") }
val compressionPool = clientConfig.compressionPool(responseHeaders[CONTENT_ENCODING]?.first())
val (code, connectError) = if (response.code != Code.OK) {
val error = parseConnectUnaryError(code = response.code, response.headers, response.message.buffer)
Expand Down Expand Up @@ -150,7 +150,7 @@ internal class ConnectInterceptor(
val streamResult: StreamResult<Buffer> = res.fold(
onHeaders = { result ->
val responseHeaders =
result.headers.filter { entry -> !entry.key.startsWith("trailer") }.toMutableMap()
result.headers.filter { entry -> !entry.key.startsWith("trailer-") }
responseCompressionPool =
clientConfig.compressionPool(responseHeaders[CONNECT_STREAMING_CONTENT_ENCODING]?.first())
StreamResult.Headers(responseHeaders)
Expand All @@ -168,7 +168,7 @@ internal class ConnectInterceptor(
}
},
onCompletion = { result ->
val streamTrailers: Trailers = result.trailers
val streamTrailers = result.trailers
val error = result.connectError()
StreamResult.Complete(error?.code ?: Code.OK, error = error, streamTrailers)
}
Expand Down Expand Up @@ -290,7 +290,7 @@ internal class ConnectInterceptor(

private fun Headers.toTrailers(): Trailers {
val trailers = mutableMapOf<String, MutableList<String>>()
for (pair in this.filter { entry -> entry.key.startsWith("trailer") }) {
for (pair in this.filter { entry -> entry.key.startsWith("trailer-") }) {
val key = pair.key.substringAfter("trailer-")
if (trailers.containsKey(key)) {
trailers[key]?.add(pair.value.first())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
package com.connectrpc.protocols

import com.connectrpc.Code
import com.connectrpc.ConnectError
import com.connectrpc.ConnectErrorDetail
import com.connectrpc.Headers
import com.connectrpc.SerializationStrategy
import okio.ByteString

/**
Expand All @@ -29,5 +32,26 @@ internal data class GRPCCompletion(
// Message data.
val message: ByteString,
// List of error details.
val errorDetails: List<ConnectErrorDetail>
val errorDetails: List<ConnectErrorDetail>,
// Set to either message headers (or trailers) where the gRPC status was found.
val metadata: Headers
)

internal fun grpcCompletionToConnectError(completion: GRPCCompletion?, serializationStrategy: SerializationStrategy, error: Throwable?): ConnectError? {
if (error is ConnectError) {
return error
}
val code = completion?.code ?: Code.UNKNOWN
if (error != null || code != Code.OK) {
return ConnectError(
code = code,
errorDetailParser = serializationStrategy.errorDetailParser(),
message = completion?.message?.utf8(),
exception = error,
details = completion?.errorDetails ?: emptyList(),
metadata = completion?.metadata ?: emptyMap()
)
}
// Successful call.
return null
}
Loading

0 comments on commit 2fd1600

Please sign in to comment.