Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of gRPC trailers-only responses #101

Merged
merged 8 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conformance/google-java/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies {
testImplementation(libs.mockito)
testImplementation(libs.kotlin.coroutines.core)
testImplementation(libs.testcontainers)
testImplementation(libs.slf4j.simple)
}

configure<SpotlessExtension> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import com.connectrpc.impl.ProtocolClient
import com.connectrpc.okhttp.ConnectOkHttpClient
import com.connectrpc.protocols.NetworkProtocol
import com.google.protobuf.ByteString
import com.google.protobuf.Empty
import com.google.protobuf.empty
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
Expand All @@ -50,36 +49,39 @@ import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import org.testcontainers.containers.GenericContainer
import org.testcontainers.containers.wait.strategy.HostPortWaitStrategy
import java.time.Duration
import java.util.Base64
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit


@RunWith(Parameterized::class)
class Conformance(
private val protocol: NetworkProtocol
private val protocol: NetworkProtocol,
private val serverType: ServerType
) {
private lateinit var connectClient: ProtocolClient
private lateinit var shortTimeoutConnectClient: ProtocolClient
private lateinit var unimplementedServiceClient: UnimplementedServiceClient
private lateinit var testServiceConnectClient: TestServiceClient

companion object {
const val CONFORMANCE_VERSION = "0b07f579cb61ad89de24524d62f804a2b03b1acf"
private const val CONFORMANCE_VERSION = "0b07f579cb61ad89de24524d62f804a2b03b1acf"

@JvmStatic
@Parameters(name = "protocol")
fun data(): Iterable<NetworkProtocol> {
@Parameters(name = "client={0},server={1}")
fun data(): Iterable<Array<Any>> {
return arrayListOf(
NetworkProtocol.CONNECT,
NetworkProtocol.GRPC
arrayOf(NetworkProtocol.CONNECT, ServerType.CONNECT_GO),
arrayOf(NetworkProtocol.GRPC, ServerType.CONNECT_GO),
arrayOf(NetworkProtocol.GRPC_WEB, ServerType.CONNECT_GO),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added gRPC web protocol tests as well (and found some issues as a result).

arrayOf(NetworkProtocol.GRPC, ServerType.GRPC_GO)
)
}

@JvmField
@ClassRule
val CONFORMANCE_CONTAINER = GenericContainer("connectrpc/conformance:$CONFORMANCE_VERSION")
val CONFORMANCE_CONTAINER_CONNECT = GenericContainer("connectrpc/conformance:$CONFORMANCE_VERSION")
.withExposedPorts(8080, 8081)
.withCommand(
"/usr/local/bin/serverconnect",
Expand All @@ -92,11 +94,28 @@ class Conformance(
"--key",
"cert/localhost.key"
)
.waitingFor(HostPortWaitStrategy().forPorts(8081))
pkwarren marked this conversation as resolved.
Show resolved Hide resolved

@JvmField
@ClassRule
val CONFORMANCE_CONTAINER_GRPC = GenericContainer("connectrpc/conformance:$CONFORMANCE_VERSION")
pkwarren marked this conversation as resolved.
Show resolved Hide resolved
.withExposedPorts(8081)
.withCommand(
"/usr/local/bin/servergrpc",
"--port",
"8081",
"--cert",
"cert/localhost.crt",
"--key",
"cert/localhost.key"
)
.waitingFor(HostPortWaitStrategy().forPorts(8081))
}

@Before
fun before() {
val host = "https://localhost:${CONFORMANCE_CONTAINER.getMappedPort(8081)}"
val serverPort = if (serverType == ServerType.CONNECT_GO) CONFORMANCE_CONTAINER_CONNECT.getMappedPort(8081) else CONFORMANCE_CONTAINER_GRPC.getMappedPort(8081)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big deal, but calling out that Swift uses 8083 for grpc-go because that's where it's hosted in the container: https://github.com/connectrpc/connect-swift/blob/f7aab0e53c38f15d5ab62e8e8afa2f984c7e34f3/Tests/ConnectLibraryTests/ConnectConformance/ConformanceConfiguration.swift#L58

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could update to match, however that is coming from https://github.com/connectrpc/connect-swift/blob/f7aab0e53c38f15d5ab62e8e8afa2f984c7e34f3/Makefile#L49-L54. We're no longer needing to spawn containers from the Makefile in this project - we're using testcontainers to do that, so port 8081 is actually ephemeral on the host.

val host = "https://localhost:$serverPort"
val (sslSocketFactory, trustManager) = sslContext()
val client = OkHttpClient.Builder()
.protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1))
Expand Down Expand Up @@ -168,14 +187,17 @@ class Conformance(
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
// For some reason we keep timing out on these calls and not actually getting a real response like with grpc?
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.message).isEqualTo("soirée 🎉")
assertThat(result.connectError()!!.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail
)
countDownLatch.countDown()
try {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't we want this to explode if any of the asserts fail rather than allowing tests to continue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will still fail the test, but this will prevent it waiting 5 seconds on the countDownLatch for the failure to register. Verified by adding a failing assert.

// For some reason we keep timing out on these calls and not actually getting a real response like with grpc?
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.message).isEqualTo("soirée 🎉")
assertThat(result.connectError()!!.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail
)
} finally {
countDownLatch.countDown()
}
}
)
}
Expand All @@ -190,12 +212,12 @@ class Conformance(
@Test
fun emptyUnary(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.emptyCall(Empty.newBuilder().build()) { response ->
testServiceConnectClient.emptyCall(empty {}) { response ->
response.failure {
fail<Unit>("expected error to be null")
}
response.success { success ->
assertThat(success.message).isEqualTo(Empty.newBuilder().build())
assertThat(success.message).isEqualTo(empty {})
countDownLatch.countDown()
}
}
Expand Down Expand Up @@ -306,10 +328,13 @@ class Conformance(
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
assertThat(result.error).isNotNull()
assertThat(result.connectError()!!.code).isEqualTo(Code.DEADLINE_EXCEEDED)
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED)
countDownLatch.countDown()
try {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

assertThat(result.error).isNotNull()
assertThat(result.connectError()!!.code).isEqualTo(Code.DEADLINE_EXCEEDED)
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED)
} finally {
countDownLatch.countDown()
}
}
)
}
Expand Down Expand Up @@ -353,7 +378,7 @@ class Conformance(
@Test
fun unimplementedMethod(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
testServiceConnectClient.unimplementedCall(Empty.newBuilder().build()) { response ->
testServiceConnectClient.unimplementedCall(empty {}) { response ->
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
countDownLatch.countDown()
}
Expand All @@ -364,14 +389,41 @@ class Conformance(
@Test
fun unimplementedService(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
unimplementedServiceClient.unimplementedCall(Empty.newBuilder().build()) { response ->
unimplementedServiceClient.unimplementedCall(empty {}) { response ->
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
countDownLatch.countDown()
}
countDownLatch.await(500, TimeUnit.MILLISECONDS)
assertThat(countDownLatch.count).isZero()
}

@Test
fun unimplementedServerStreamingService(): Unit = runBlocking {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this test which caught the streaming behavior issue: #101 (comment)

val countDownLatch = CountDownLatch(1)
val stream = unimplementedServiceClient.unimplementedStreamingOutputCall()
stream.send(empty { })
withContext(Dispatchers.IO) {
val job = async {
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
try {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question about trying here

assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.UNIMPLEMENTED)
} finally {
countDownLatch.countDown()
}
}
)
}
}
countDownLatch.await(5, TimeUnit.SECONDS)
job.cancel()
assertThat(countDownLatch.count).isZero()
stream.close()
}
}

@Test
fun failUnary(): Unit = runBlocking {
val expectedErrorDetail = errorDetail {
Expand Down Expand Up @@ -399,12 +451,12 @@ class Conformance(

@Test
fun emptyUnaryBlocking(): Unit = runBlocking {
val response = testServiceConnectClient.emptyCallBlocking(Empty.newBuilder().build()).execute()
val response = testServiceConnectClient.emptyCallBlocking(empty {}).execute()
response.failure {
fail<Unit>("expected error to be null")
}
response.success { success ->
assertThat(success.message).isEqualTo(Empty.newBuilder().build())
assertThat(success.message).isEqualTo(empty {})
}
}

Expand Down Expand Up @@ -499,13 +551,13 @@ class Conformance(

@Test
fun unimplementedMethodBlocking(): Unit = runBlocking {
val response = testServiceConnectClient.unimplementedCallBlocking(Empty.newBuilder().build()).execute()
val response = testServiceConnectClient.unimplementedCallBlocking(empty {}).execute()
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
}

@Test
fun unimplementedServiceBlocking(): Unit = runBlocking {
val response = unimplementedServiceClient.unimplementedCallBlocking(Empty.newBuilder().build()).execute()
val response = unimplementedServiceClient.unimplementedCallBlocking(empty {}).execute()
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright 2022-2023 The Connect Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.connectrpc.conformance

enum class ServerType {
CONNECT_GO,
GRPC_GO,
}
2 changes: 2 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ moshi = "1.15.0"
okhttp = "4.10.0"
okio = "3.0.0"
protobuf = "3.24.3"
slf4j = "1.7.36"

[libraries]
android = { module = "com.google.android:android", version.ref = "android" }
Expand Down Expand Up @@ -49,6 +50,7 @@ protobuf-java-util = { module = "com.google.protobuf:protobuf-java-util", versio
protobuf-javalite = { module = "com.google.protobuf:protobuf-javalite", version.ref = "protobuf" }
protobuf-kotlin = { module = "com.google.protobuf:protobuf-kotlin", version.ref = "protobuf" }
protobuf-kotlinlite = { module = "com.google.protobuf:protobuf-kotlin-lite", version.ref = "protobuf" }
slf4j-simple = { module = "org.slf4j:slf4j-simple", version.ref = "slf4j" }
spotless = { module = "com.diffplug.spotless:spotless-plugin-gradle", version = "6.13.0" }
testcontainers = { module = "org.testcontainers:testcontainers", version = "1.19.0" }

Expand Down
2 changes: 1 addition & 1 deletion library/src/main/kotlin/com/connectrpc/StreamResult.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.connectrpc
*
* A typical stream receives [Headers] > [Message] > [Message] > [Message] ... > [Complete]
*/
sealed class StreamResult<Output> constructor(
sealed class StreamResult<Output>(
val error: Throwable? = null
) {
// Headers have been received over the stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ object GzipCompressionPool : CompressionPool {

override fun decompress(buffer: Buffer): Buffer {
val result = Buffer()
val source = GzipSource(buffer)
while (source.read(result, Int.MAX_VALUE.toLong()) != -1L) {
// continue reading.
GzipSource(buffer).use {
while (it.read(result, Int.MAX_VALUE.toLong()) != -1L) {
// continue reading.
}
}
return result
}
Expand Down
2 changes: 0 additions & 2 deletions library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import com.connectrpc.http.HTTPClientInterface
import com.connectrpc.http.HTTPRequest
import com.connectrpc.http.Stream
import com.connectrpc.protocols.GETConfiguration
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.suspendCancellableCoroutine
import java.net.URL
Expand Down Expand Up @@ -162,7 +161,6 @@ class ProtocolClient(
return ClientOnlyStream(stream)
}

@OptIn(ExperimentalCoroutinesApi::class)
private suspend fun <Input : Any, Output : Any> bidirectionalStream(
methodSpec: MethodSpec<Input, Output>,
headers: Headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal class ConnectInterceptor(
trailers.putAll(response.headers.toTrailers())
trailers.putAll(response.trailers)
val responseHeaders =
response.headers.filter { entry -> !entry.key.startsWith("trailer") }.toMutableMap()
response.headers.filter { entry -> !entry.key.startsWith("trailer-") }
val compressionPool = clientConfig.compressionPool(responseHeaders[CONTENT_ENCODING]?.first())
val (code, connectError) = if (response.code != Code.OK) {
val error = parseConnectUnaryError(code = response.code, response.headers, response.message.buffer)
Expand Down Expand Up @@ -150,7 +150,7 @@ internal class ConnectInterceptor(
val streamResult: StreamResult<Buffer> = res.fold(
onHeaders = { result ->
val responseHeaders =
result.headers.filter { entry -> !entry.key.startsWith("trailer") }.toMutableMap()
result.headers.filter { entry -> !entry.key.startsWith("trailer-") }
responseCompressionPool =
clientConfig.compressionPool(responseHeaders[CONNECT_STREAMING_CONTENT_ENCODING]?.first())
StreamResult.Headers(responseHeaders)
Expand All @@ -168,7 +168,7 @@ internal class ConnectInterceptor(
}
},
onCompletion = { result ->
val streamTrailers: Trailers = result.trailers
val streamTrailers = result.trailers
val error = result.connectError()
StreamResult.Complete(error?.code ?: Code.OK, error = error, streamTrailers)
}
Expand Down Expand Up @@ -290,7 +290,7 @@ internal class ConnectInterceptor(

private fun Headers.toTrailers(): Trailers {
val trailers = mutableMapOf<String, MutableList<String>>()
for (pair in this.filter { entry -> entry.key.startsWith("trailer") }) {
for (pair in this.filter { entry -> entry.key.startsWith("trailer-") }) {
val key = pair.key.substringAfter("trailer-")
if (trailers.containsKey(key)) {
trailers[key]?.add(pair.value.first())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
package com.connectrpc.protocols

import com.connectrpc.Code
import com.connectrpc.ConnectError
import com.connectrpc.ConnectErrorDetail
import com.connectrpc.Headers
import com.connectrpc.SerializationStrategy
import okio.ByteString

/**
Expand All @@ -29,5 +32,26 @@ internal data class GRPCCompletion(
// Message data.
val message: ByteString,
// List of error details.
val errorDetails: List<ConnectErrorDetail>
val errorDetails: List<ConnectErrorDetail>,
// Set to either message headers (or trailers) where the gRPC status was found.
val metadata: Headers
)

internal fun grpcCompletionToConnectError(completion: GRPCCompletion?, serializationStrategy: SerializationStrategy, error: Throwable?): ConnectError? {
if (error is ConnectError) {
return error
}
val code = completion?.code ?: Code.UNKNOWN
if (error != null || code != Code.OK) {
return ConnectError(
code = code,
errorDetailParser = serializationStrategy.errorDetailParser(),
message = completion?.message?.utf8(),
exception = error,
details = completion?.errorDetails ?: emptyList(),
metadata = completion?.metadata ?: emptyMap()
)
}
// Successful call.
return null
}
Loading