Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of gRPC trailers-only responses #101

Merged
merged 8 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conformance/google-java/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies {
testImplementation(libs.mockito)
testImplementation(libs.kotlin.coroutines.core)
testImplementation(libs.testcontainers)
testImplementation(libs.slf4j.simple)
}

configure<SpotlessExtension> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import com.connectrpc.impl.ProtocolClient
import com.connectrpc.okhttp.ConnectOkHttpClient
import com.connectrpc.protocols.NetworkProtocol
import com.google.protobuf.ByteString
import com.google.protobuf.Empty
import com.google.protobuf.empty
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
Expand All @@ -58,24 +57,25 @@ import java.util.concurrent.TimeUnit

@RunWith(Parameterized::class)
class Conformance(
private val clientProtocol: NetworkProtocol,
private val serverProtocol: NetworkProtocol
private val protocol: NetworkProtocol,
private val serverType: ServerType
) {
private lateinit var protocolClient: ProtocolClient
private lateinit var connectClient: ProtocolClient
private lateinit var shortTimeoutConnectClient: ProtocolClient
private lateinit var unimplementedServiceClient: UnimplementedServiceClient
private lateinit var testServiceClient: TestServiceClient
private lateinit var testServiceConnectClient: TestServiceClient

companion object {
private const val CONFORMANCE_VERSION = "0b07f579cb61ad89de24524d62f804a2b03b1acf"

@JvmStatic
@Parameters(name = "client={0},server={1}")
fun data(): Iterable<Array<NetworkProtocol>> {
fun data(): Iterable<Array<Any>> {
return arrayListOf(
arrayOf(NetworkProtocol.CONNECT, NetworkProtocol.CONNECT),
arrayOf(NetworkProtocol.GRPC, NetworkProtocol.CONNECT),
arrayOf(NetworkProtocol.GRPC, NetworkProtocol.GRPC)
arrayOf(NetworkProtocol.CONNECT, ServerType.CONNECT_GO),
arrayOf(NetworkProtocol.GRPC, ServerType.CONNECT_GO),
arrayOf(NetworkProtocol.GRPC_WEB, ServerType.CONNECT_GO),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added gRPC web protocol tests as well (and found some issues as a result).

arrayOf(NetworkProtocol.GRPC, ServerType.GRPC_GO)
)
}

Expand Down Expand Up @@ -114,7 +114,7 @@ class Conformance(

@Before
fun before() {
val serverPort = if (serverProtocol == NetworkProtocol.CONNECT) CONFORMANCE_CONTAINER_CONNECT.getMappedPort(8081) else CONFORMANCE_CONTAINER_GRPC.getMappedPort(8081)
val serverPort = if (serverType == ServerType.CONNECT_GO) CONFORMANCE_CONTAINER_CONNECT.getMappedPort(8081) else CONFORMANCE_CONTAINER_GRPC.getMappedPort(8081)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big deal, but calling out that Swift uses 8083 for grpc-go because that's where it's hosted in the container: https://github.com/connectrpc/connect-swift/blob/f7aab0e53c38f15d5ab62e8e8afa2f984c7e34f3/Tests/ConnectLibraryTests/ConnectConformance/ConformanceConfiguration.swift#L58

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could update to match, however that is coming from https://github.com/connectrpc/connect-swift/blob/f7aab0e53c38f15d5ab62e8e8afa2f984c7e34f3/Makefile#L49-L54. We're no longer needing to spawn containers from the Makefile in this project - we're using testcontainers to do that, so port 8081 is actually ephemeral on the host.

val host = "https://localhost:$serverPort"
val (sslSocketFactory, trustManager) = sslContext()
val client = OkHttpClient.Builder()
Expand All @@ -137,23 +137,23 @@ class Conformance(
ProtocolClientConfig(
host = host,
serializationStrategy = GoogleJavaProtobufStrategy(),
networkProtocol = clientProtocol,
networkProtocol = protocol,
requestCompression = RequestCompression(10, GzipCompressionPool),
compressionPools = listOf(GzipCompressionPool)
)
)
protocolClient = ProtocolClient(
connectClient = ProtocolClient(
httpClient = ConnectOkHttpClient(client),
ProtocolClientConfig(
host = host,
serializationStrategy = GoogleJavaProtobufStrategy(),
networkProtocol = clientProtocol,
networkProtocol = protocol,
requestCompression = RequestCompression(10, GzipCompressionPool),
compressionPools = listOf(GzipCompressionPool)
)
)
testServiceClient = TestServiceClient(protocolClient)
unimplementedServiceClient = UnimplementedServiceClient(protocolClient)
testServiceConnectClient = TestServiceClient(connectClient)
unimplementedServiceClient = UnimplementedServiceClient(connectClient)
}

@Test
Expand All @@ -163,7 +163,7 @@ class Conformance(
reason = "soirée 🎉"
domain = "connect-conformance"
}
val stream = testServiceClient.failStreamingOutputCall()
val stream = testServiceConnectClient.failStreamingOutputCall()
val sizes = listOf(
31415,
9,
Expand All @@ -187,14 +187,17 @@ class Conformance(
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
// For some reason we keep timing out on these calls and not actually getting a real response like with grpc?
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.message).isEqualTo("soirée 🎉")
assertThat(result.connectError()!!.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail
)
countDownLatch.countDown()
try {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't we want this to explode if any of the asserts fail rather than allowing tests to continue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will still fail the test, but this will prevent it waiting 5 seconds on the countDownLatch for the failure to register. Verified by adding a failing assert.

// For some reason we keep timing out on these calls and not actually getting a real response like with grpc?
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.message).isEqualTo("soirée 🎉")
assertThat(result.connectError()!!.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail
)
} finally {
countDownLatch.countDown()
}
}
)
}
Expand All @@ -209,12 +212,12 @@ class Conformance(
@Test
fun emptyUnary(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
testServiceClient.emptyCall(Empty.newBuilder().build()) { response ->
testServiceConnectClient.emptyCall(empty {}) { response ->
response.failure {
fail<Unit>("expected error to be null")
}
response.success { success ->
assertThat(success.message).isEqualTo(Empty.newBuilder().build())
assertThat(success.message).isEqualTo(empty {})
countDownLatch.countDown()
}
}
Expand All @@ -232,7 +235,7 @@ class Conformance(
}
}
val countDownLatch = CountDownLatch(1)
testServiceClient.unaryCall(message) { response ->
testServiceConnectClient.unaryCall(message) { response ->
response.failure {
fail<Unit>("expected error to be null")
}
Expand Down Expand Up @@ -262,7 +265,7 @@ class Conformance(
payload = payload { body = ByteString.copyFrom(ByteArray(size)) }
}
val countDownLatch = CountDownLatch(1)
testServiceClient.unaryCall(message, headers) { response ->
testServiceConnectClient.unaryCall(message, headers) { response ->
assertThat(response.code).isEqualTo(Code.OK)
assertThat(response.headers[leadingKey]).containsExactly(leadingValue)
assertThat(response.trailers[trailingKey]).containsExactly(b64Encode(trailingValue))
Expand All @@ -287,7 +290,7 @@ class Conformance(
}
}
val countDownLatch = CountDownLatch(1)
testServiceClient.unaryCall(message) { response ->
testServiceConnectClient.unaryCall(message) { response ->
assertThat(response.code).isEqualTo(Code.UNKNOWN)
response.failure { errorResponse ->
assertThat(errorResponse.error).isNotNull()
Expand Down Expand Up @@ -325,10 +328,13 @@ class Conformance(
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
assertThat(result.error).isNotNull()
assertThat(result.connectError()!!.code).isEqualTo(Code.DEADLINE_EXCEEDED)
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED)
countDownLatch.countDown()
try {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

assertThat(result.error).isNotNull()
assertThat(result.connectError()!!.code).isEqualTo(Code.DEADLINE_EXCEEDED)
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED)
} finally {
countDownLatch.countDown()
}
}
)
}
Expand All @@ -346,7 +352,7 @@ class Conformance(
val statusMessage =
"\\t\\ntest with whitespace\\r\\nand Unicode BMP ☺ and non-BMP \uD83D\uDE08\\t\\n"
val countDownLatch = CountDownLatch(1)
testServiceClient.unaryCall(
testServiceConnectClient.unaryCall(
simpleRequest {
responseStatus = echoStatus {
code = 2
Expand All @@ -372,7 +378,7 @@ class Conformance(
@Test
fun unimplementedMethod(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
testServiceClient.unimplementedCall(Empty.newBuilder().build()) { response ->
testServiceConnectClient.unimplementedCall(empty {}) { response ->
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
countDownLatch.countDown()
}
Expand All @@ -383,22 +389,49 @@ class Conformance(
@Test
fun unimplementedService(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
unimplementedServiceClient.unimplementedCall(Empty.newBuilder().build()) { response ->
unimplementedServiceClient.unimplementedCall(empty {}) { response ->
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
countDownLatch.countDown()
}
countDownLatch.await(500, TimeUnit.MILLISECONDS)
assertThat(countDownLatch.count).isZero()
}

@Test
fun unimplementedServerStreamingService(): Unit = runBlocking {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this test which caught the streaming behavior issue: #101 (comment)

val countDownLatch = CountDownLatch(1)
val stream = unimplementedServiceClient.unimplementedStreamingOutputCall()
stream.send(empty { })
withContext(Dispatchers.IO) {
val job = async {
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
try {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question about trying here

assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.UNIMPLEMENTED)
} finally {
countDownLatch.countDown()
}
}
)
}
}
countDownLatch.await(5, TimeUnit.SECONDS)
job.cancel()
assertThat(countDownLatch.count).isZero()
stream.close()
}
}

@Test
fun failUnary(): Unit = runBlocking {
val expectedErrorDetail = errorDetail {
reason = "soirée 🎉"
domain = "connect-conformance"
}
val countDownLatch = CountDownLatch(1)
testServiceClient.failUnaryCall(simpleRequest {}) { response ->
testServiceConnectClient.failUnaryCall(simpleRequest {}) { response ->
assertThat(response.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
response.failure { errorResponse ->
val error = errorResponse.error
Expand All @@ -418,12 +451,12 @@ class Conformance(

@Test
fun emptyUnaryBlocking(): Unit = runBlocking {
val response = testServiceClient.emptyCallBlocking(Empty.newBuilder().build()).execute()
val response = testServiceConnectClient.emptyCallBlocking(empty {}).execute()
response.failure {
fail<Unit>("expected error to be null")
}
response.success { success ->
assertThat(success.message).isEqualTo(Empty.newBuilder().build())
assertThat(success.message).isEqualTo(empty {})
}
}

Expand All @@ -436,7 +469,7 @@ class Conformance(
body = ByteString.copyFrom(ByteArray(size))
}
}
val response = testServiceClient.unaryCallBlocking(message).execute()
val response = testServiceConnectClient.unaryCallBlocking(message).execute()
response.failure {
fail<Unit>("expected error to be null")
}
Expand All @@ -461,7 +494,7 @@ class Conformance(
responseSize = size
payload = payload { body = ByteString.copyFrom(ByteArray(size)) }
}
val response = testServiceClient.unaryCallBlocking(message, headers).execute()
val response = testServiceConnectClient.unaryCallBlocking(message, headers).execute()
assertThat(response.code).isEqualTo(Code.OK)
assertThat(response.headers[leadingKey]).containsExactly(leadingValue)
assertThat(response.trailers[trailingKey]).containsExactly(b64Encode(trailingValue))
Expand All @@ -481,7 +514,7 @@ class Conformance(
message = "test status message"
}
}
val response = testServiceClient.unaryCallBlocking(message).execute()
val response = testServiceConnectClient.unaryCallBlocking(message).execute()
assertThat(response.code).isEqualTo(Code.UNKNOWN)
response.failure { errorResponse ->
assertThat(errorResponse.error).isNotNull()
Expand All @@ -497,7 +530,7 @@ class Conformance(
fun specialStatusBlocking(): Unit = runBlocking {
val statusMessage =
"\\t\\ntest with whitespace\\r\\nand Unicode BMP ☺ and non-BMP \uD83D\uDE08\\t\\n"
val response = testServiceClient.unaryCallBlocking(
val response = testServiceConnectClient.unaryCallBlocking(
simpleRequest {
responseStatus = echoStatus {
code = 2
Expand All @@ -518,13 +551,13 @@ class Conformance(

@Test
fun unimplementedMethodBlocking(): Unit = runBlocking {
val response = testServiceClient.unimplementedCallBlocking(Empty.newBuilder().build()).execute()
val response = testServiceConnectClient.unimplementedCallBlocking(empty {}).execute()
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
}

@Test
fun unimplementedServiceBlocking(): Unit = runBlocking {
val response = unimplementedServiceClient.unimplementedCallBlocking(Empty.newBuilder().build()).execute()
val response = unimplementedServiceClient.unimplementedCallBlocking(empty {}).execute()
assertThat(response.code).isEqualTo(Code.UNIMPLEMENTED)
}

Expand All @@ -534,7 +567,7 @@ class Conformance(
reason = "soirée 🎉"
domain = "connect-conformance"
}
val response = testServiceClient.failUnaryCallBlocking(simpleRequest {}).execute()
val response = testServiceConnectClient.failUnaryCallBlocking(simpleRequest {}).execute()
assertThat(response.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
response.failure { errorResponse ->
val error = errorResponse.error
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright 2022-2023 The Connect Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.connectrpc.conformance

enum class ServerType {
CONNECT_GO,
GRPC_GO,
}
2 changes: 2 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ moshi = "1.15.0"
okhttp = "4.10.0"
okio = "3.0.0"
protobuf = "3.24.3"
slf4j = "1.7.36"

[libraries]
android = { module = "com.google.android:android", version.ref = "android" }
Expand Down Expand Up @@ -49,6 +50,7 @@ protobuf-java-util = { module = "com.google.protobuf:protobuf-java-util", versio
protobuf-javalite = { module = "com.google.protobuf:protobuf-javalite", version.ref = "protobuf" }
protobuf-kotlin = { module = "com.google.protobuf:protobuf-kotlin", version.ref = "protobuf" }
protobuf-kotlinlite = { module = "com.google.protobuf:protobuf-kotlin-lite", version.ref = "protobuf" }
slf4j-simple = { module = "org.slf4j:slf4j-simple", version.ref = "slf4j" }
spotless = { module = "com.diffplug.spotless:spotless-plugin-gradle", version = "6.13.0" }
testcontainers = { module = "org.testcontainers:testcontainers", version = "1.19.0" }

Expand Down
2 changes: 1 addition & 1 deletion library/src/main/kotlin/com/connectrpc/StreamResult.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.connectrpc
*
* A typical stream receives [Headers] > [Message] > [Message] > [Message] ... > [Complete]
*/
sealed class StreamResult<Output> constructor(
sealed class StreamResult<Output>(
val error: Throwable? = null
) {
// Headers have been received over the stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ object GzipCompressionPool : CompressionPool {

override fun decompress(buffer: Buffer): Buffer {
val result = Buffer()
val source = GzipSource(buffer)
while (source.read(result, Int.MAX_VALUE.toLong()) != -1L) {
// continue reading.
GzipSource(buffer).use {
while (it.read(result, Int.MAX_VALUE.toLong()) != -1L) {
// continue reading.
}
}
return result
}
Expand Down
Loading