Skip to content

Commit

Permalink
Fix several streaming issues (#106)
Browse files Browse the repository at this point in the history
Update conformance tests to add a new test exercising client side
streaming, which exposed several issues in streaming call
implementations.

The first issue only affected client streaming (it stopped attempting to
read a response from the server once the send side closed - it should
have stopped only if the receive side closed).

The second issue resulted from not calling close on the channel after
the completion message was received, which lead to hangs consuming from
`resultChannel()` (it would never complete). After this fix, both
examples (for java and javalite) were updated and fixed to correctly
exit when finished.

Additionally, several cleanups were made to the API (since the current
API for client streaming was non-functional - it would only return the
initial Headers result and not the message or completion result).

This should help resolve reported streaming issues like #100.

# API Updates

## `com.connectrpc.BidirectionalStreamInterface`

### Removed
* `close()`
* Use `sendClose()` instead. This may have confused callers that the
close() method would close both send and receive sides of the connection
when it was only closing the send side.

## `com.connectrpc.ClientOnlyStreamInterface`

### Added
* `sendClose()`
* This shouldn't typically need to be called as receiveAndClose()
already closes the send side of the stream.
* `isSendClosed()`

### Changed
* `receiveAndClose()`
* Changed to return a ResponseMessage instead of a StreamResult. This
allows callers to easily get access to the response as if they were
calling a unary method. Previously, the StreamResult would only return
the first result retrieved by the call, which typically was a Headers
result (leaving callers unable to access the Message or Completion
contents).

### Removed
* `close()`
  * Replaced with `sendClose()`.

## `com.connectrpc.ServerOnlyStreamInterface`

### Added
* `receiveClose()`
* `isReceiveClosed()`

### Removed
* `close()`
* This closed both the send and receive side of the stream (unlike in
other interfaces which just closed the send side). If needed, callers
should invoke `receiveClose()` instead (although this isn't necessary in
normal use).
* `send()`
* Callers should invoke `sendAndClose()` instead. Otherwise, reading
results from `resultChannel()` will hang since the send side of the
stream should be closed before reading responses.

## `com.connectrpc.StreamResult`

### Removed
* Removed the `error` field from the base `StreamResult` class. It was
never used by the `Headers` or `Message` subclasses and only used on the
`Complete` type. This should make it easier for callers to use `Headers`
and `Message` types since they don't need to worry about handling
`error`.
  • Loading branch information
pkwarren authored Sep 22, 2023
1 parent 5fa6786 commit bb102a8
Show file tree
Hide file tree
Showing 25 changed files with 335 additions and 150 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,35 @@
package com.connectrpc.conformance

import com.connectrpc.Code
import com.connectrpc.ConnectError
import com.connectrpc.Headers
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.RequestCompression
import com.connectrpc.StreamResult
import com.connectrpc.Trailers
import com.connectrpc.compression.GzipCompressionPool
import com.connectrpc.conformance.ssl.sslContext
import com.connectrpc.conformance.v1.ErrorDetail
import com.connectrpc.conformance.v1.PayloadType
import com.connectrpc.conformance.v1.TestServiceClient
import com.connectrpc.conformance.v1.UnimplementedServiceClient
import com.connectrpc.conformance.v1.echoStatus
import com.connectrpc.conformance.v1.errorDetail
import com.connectrpc.conformance.v1.payload
import com.connectrpc.conformance.v1.responseParameters
import com.connectrpc.conformance.v1.simpleRequest
import com.connectrpc.conformance.v1.streamingInputCallRequest
import com.connectrpc.conformance.v1.streamingOutputCallRequest
import com.connectrpc.extensions.GoogleJavaProtobufStrategy
import com.connectrpc.getOrThrow
import com.connectrpc.impl.ProtocolClient
import com.connectrpc.okhttp.ConnectOkHttpClient
import com.connectrpc.protocols.NetworkProtocol
import com.google.protobuf.ByteString
import com.google.protobuf.empty
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import okhttp3.OkHttpClient
Expand All @@ -54,6 +62,7 @@ import java.time.Duration
import java.util.Base64
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean

@RunWith(Parameterized::class)
class Conformance(
Expand Down Expand Up @@ -157,8 +166,7 @@ class Conformance(
}

@Test
fun failServerStreaming() = runBlocking {
val countDownLatch = CountDownLatch(1)
fun failServerStreaming(): Unit = runBlocking {
val expectedErrorDetail = errorDetail {
reason = "soirée 🎉"
domain = "connect-conformance"
Expand All @@ -177,35 +185,32 @@ class Conformance(
}
}

stream.send(
stream.sendAndClose(
streamingOutputCallRequest {
responseParameters.addAll(parameters)
},
)
val countDownLatch = CountDownLatch(1)
withContext(Dispatchers.IO) {
val job = async {
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
try {
// For some reason we keep timing out on these calls and not actually getting a real response like with grpc?
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.message).isEqualTo("soirée 🎉")
assertThat(result.connectError()!!.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail,
)
} finally {
countDownLatch.countDown()
}
},
try {
val result = streamResults(stream.resultChannel())
assertThat(result.messages.map { it.payload.body.size() }).isEqualTo(sizes)
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.error).isInstanceOf(ConnectError::class.java)
val connectError = result.error as ConnectError
assertThat(connectError.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(connectError.message).isEqualTo("soirée 🎉")
assertThat(connectError.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail,
)
} finally {
countDownLatch.countDown()
}
}
countDownLatch.await(5, TimeUnit.SECONDS)
job.cancel()
assertThat(countDownLatch.count).isZero()
stream.close()
}
}

Expand Down Expand Up @@ -308,7 +313,7 @@ class Conformance(
}

@Test
fun timeoutOnSleepingServer() = runBlocking {
fun timeoutOnSleepingServer(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
val client = TestServiceClient(shortTimeoutConnectClient)
val request = streamingOutputCallRequest {
Expand All @@ -325,25 +330,20 @@ class Conformance(
val stream = client.streamingOutputCall()
withContext(Dispatchers.IO) {
val job = async {
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
try {
assertThat(result.error).isNotNull()
assertThat(result.connectError()!!.code).isEqualTo(Code.DEADLINE_EXCEEDED)
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED)
} finally {
countDownLatch.countDown()
}
},
)
try {
val result = streamResults(stream.resultChannel())
assertThat(result.error).isInstanceOf(ConnectError::class.java)
val connectErr = result.error as ConnectError
assertThat(connectErr.code).isEqualTo(Code.DEADLINE_EXCEEDED)
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED)
} finally {
countDownLatch.countDown()
}
}
stream.send(request)
stream.sendAndClose(request)
countDownLatch.await(5, TimeUnit.SECONDS)
job.cancel()
assertThat(countDownLatch.count).isZero()
stream.close()
}
}

Expand Down Expand Up @@ -401,26 +401,22 @@ class Conformance(
fun unimplementedServerStreamingService(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
val stream = unimplementedServiceClient.unimplementedStreamingOutputCall()
stream.send(empty { })
stream.sendAndClose(empty { })
withContext(Dispatchers.IO) {
val job = async {
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
try {
assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.UNIMPLEMENTED)
} finally {
countDownLatch.countDown()
}
},
)
try {
val result = streamResults(stream.resultChannel())
assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED)
assertThat(result.error).isInstanceOf(ConnectError::class.java)
val connectErr = result.error as ConnectError
assertThat(connectErr.code).isEqualTo(Code.UNIMPLEMENTED)
} finally {
countDownLatch.countDown()
}
}
countDownLatch.await(5, TimeUnit.SECONDS)
job.cancel()
assertThat(countDownLatch.count).isZero()
stream.close()
}
}

Expand Down Expand Up @@ -754,6 +750,87 @@ class Conformance(
assertThat(countDownLatch.count).isZero()
}

@Test
fun clientStreaming(): Unit = runBlocking {
val stream = testServiceConnectClient.streamingInputCall(emptyMap())
var sum = 0
listOf(256000, 8, 1024, 32768).forEach { size ->
stream.send(
streamingInputCallRequest {
payload = payload {
type = PayloadType.COMPRESSABLE
body = ByteString.copyFrom(ByteArray(size))
}
},
).getOrThrow()
sum += size
}
val countDownLatch = CountDownLatch(1)
withContext(Dispatchers.IO) {
val job = async {
try {
val result = stream.receiveAndClose().getOrThrow()
assertThat(result.aggregatedPayloadSize).isEqualTo(sum)
} finally {
countDownLatch.countDown()
}
}
countDownLatch.await(5, TimeUnit.MINUTES)
job.cancel()
assertThat(countDownLatch.count).isZero()
}
}

private data class ServerStreamingResult<Output>(
val headers: Headers,
val messages: List<Output>,
val code: Code,
val trailers: Trailers,
val error: Throwable?,
)

/*
* Convenience method to return all results (with sanity checking) for calls which stream results from the server
* (bidi and server streaming).
*
* This allows us to easily verify headers, messages, trailers, and errors without having to use fold/maybeFold
* manually in each location.
*/
private suspend fun <Output> streamResults(channel: ReceiveChannel<StreamResult<Output>>): ServerStreamingResult<Output> {
val seenHeaders = AtomicBoolean(false)
var headers: Headers = emptyMap()
val messages: MutableList<Output> = mutableListOf()
val seenCompletion = AtomicBoolean(false)
var code: Code = Code.UNKNOWN
var trailers: Headers = emptyMap()
var error: Throwable?
for (response in channel) {
response.maybeFold(
onHeaders = {
if (!seenHeaders.compareAndSet(false, true)) {
throw IllegalStateException("multiple onHeaders callbacks")
}
headers = it.headers
},
onMessage = {
messages.add(it.message)
},
onCompletion = {
if (!seenCompletion.compareAndSet(false, true)) {
throw IllegalStateException("multiple onCompletion callbacks")
}
code = it.code
trailers = it.trailers
error = it.error
},
)
}
if (!seenCompletion.get()) {
throw IllegalStateException("didn't get completion message")
}
return ServerStreamingResult(headers, messages, code, trailers, error)
}

private fun b64Encode(trailingValue: ByteArray): String {
return String(Base64.getEncoder().encode(trailingValue))
}
Expand Down
7 changes: 7 additions & 0 deletions examples/buf.gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,17 @@ plugins:
- plugin: java
out: generated-google-java/build/generated/sources/bufgen
protoc_path: .tmp/bin/protoc
- plugin: kotlin
out: generated-google-java/build/generated/sources/bufgen
protoc_path: .tmp/bin/protoc
- plugin: connect-kotlin
out: generated-google-javalite/build/generated/sources/bufgen
path: ./protoc-gen-connect-kotlin/build/install/protoc-gen-connect-kotlin/bin/protoc-gen-connect-kotlin
- plugin: java
out: generated-google-javalite/build/generated/sources/bufgen
protoc_path: .tmp/bin/protoc
opt: lite
- plugin: kotlin
out: generated-google-javalite/build/generated/sources/bufgen
protoc_path: .tmp/bin/protoc
opt: lite
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

package com.connectrpc.examples.kotlin

import com.connectrpc.Code
import com.connectrpc.ConnectError
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.eliza.v1.ConverseRequest
import com.connectrpc.eliza.v1.ElizaServiceClient
import com.connectrpc.eliza.v1.converseRequest
import com.connectrpc.extensions.GoogleJavaProtobufStrategy
import com.connectrpc.impl.ProtocolClient
import com.connectrpc.okhttp.ConnectOkHttpClient
Expand All @@ -33,37 +35,50 @@ class Main {
fun main(args: Array<String>) {
runBlocking {
val host = "https://demo.connectrpc.com"
val okHttpClient = OkHttpClient()
.newBuilder()
.readTimeout(Duration.ofMinutes(10))
.writeTimeout(Duration.ofMinutes(10))
.callTimeout(Duration.ofMinutes(10))
.build()
val client = ProtocolClient(
httpClient = ConnectOkHttpClient(
OkHttpClient()
.newBuilder()
.readTimeout(Duration.ofMinutes(10))
.writeTimeout(Duration.ofMinutes(10))
.callTimeout(Duration.ofMinutes(10))
.build(),
),
httpClient = ConnectOkHttpClient(okHttpClient),
ProtocolClientConfig(
host = host,
serializationStrategy = GoogleJavaProtobufStrategy(),
),
)
val elizaServiceClient = ElizaServiceClient(client)
connectStreaming(elizaServiceClient)
try {
connectStreaming(elizaServiceClient)
} finally {
okHttpClient.dispatcher.executorService.shutdown()
}
}
}

private suspend fun connectStreaming(elizaServiceClient: ElizaServiceClient) {
val stream = elizaServiceClient.converse()
withContext(Dispatchers.IO) {
// Add the message the user is sending to the views.
stream.send(ConverseRequest.newBuilder().setSentence("hello").build())
stream.send(converseRequest { sentence = "hello" })
stream.sendClose()
for (streamResult in stream.resultChannel()) {
streamResult.maybeFold(
onMessage = { result ->
// Update the view with the response.
val elizaResponse = result.message
println(elizaResponse.sentence)
},
onCompletion = { result ->
if (result.code != Code.OK) {
val connectErr = result.connectError()
if (connectErr != null) {
throw connectErr
}
throw ConnectError(code = result.code, metadata = result.trailers)
}
},
)
}
}
Expand Down
Loading

0 comments on commit bb102a8

Please sign in to comment.