Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix several streaming issues #106

Merged
merged 5 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,35 @@
package com.connectrpc.conformance

import com.connectrpc.Code
import com.connectrpc.ConnectError
import com.connectrpc.Headers
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.RequestCompression
import com.connectrpc.StreamResult
import com.connectrpc.Trailers
import com.connectrpc.compression.GzipCompressionPool
import com.connectrpc.conformance.ssl.sslContext
import com.connectrpc.conformance.v1.ErrorDetail
import com.connectrpc.conformance.v1.PayloadType
import com.connectrpc.conformance.v1.TestServiceClient
import com.connectrpc.conformance.v1.UnimplementedServiceClient
import com.connectrpc.conformance.v1.echoStatus
import com.connectrpc.conformance.v1.errorDetail
import com.connectrpc.conformance.v1.payload
import com.connectrpc.conformance.v1.responseParameters
import com.connectrpc.conformance.v1.simpleRequest
import com.connectrpc.conformance.v1.streamingInputCallRequest
import com.connectrpc.conformance.v1.streamingOutputCallRequest
import com.connectrpc.extensions.GoogleJavaProtobufStrategy
import com.connectrpc.getOrThrow
import com.connectrpc.impl.ProtocolClient
import com.connectrpc.okhttp.ConnectOkHttpClient
import com.connectrpc.protocols.NetworkProtocol
import com.google.protobuf.ByteString
import com.google.protobuf.empty
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import okhttp3.OkHttpClient
Expand All @@ -54,6 +62,7 @@ import java.time.Duration
import java.util.Base64
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean

@RunWith(Parameterized::class)
class Conformance(
Expand Down Expand Up @@ -157,8 +166,7 @@ class Conformance(
}

@Test
fun failServerStreaming() = runBlocking {
val countDownLatch = CountDownLatch(1)
fun failServerStreaming(): Unit = runBlocking {
val expectedErrorDetail = errorDetail {
reason = "soirée 🎉"
domain = "connect-conformance"
Expand All @@ -177,35 +185,32 @@ class Conformance(
}
}

stream.send(
stream.sendAndClose(
streamingOutputCallRequest {
responseParameters.addAll(parameters)
},
)
val countDownLatch = CountDownLatch(1)
withContext(Dispatchers.IO) {
val job = async {
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
try {
// For some reason we keep timing out on these calls and not actually getting a real response like with grpc?
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.connectError()!!.message).isEqualTo("soirée 🎉")
assertThat(result.connectError()!!.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail,
)
} finally {
countDownLatch.countDown()
}
},
try {
val result = streamResults(stream.resultChannel())
assertThat(result.messages.map { it.payload.body.size() }).isEqualTo(sizes)
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(result.error).isInstanceOf(ConnectError::class.java)
val connectError = result.error as ConnectError
assertThat(connectError.code).isEqualTo(Code.RESOURCE_EXHAUSTED)
assertThat(connectError.message).isEqualTo("soirée 🎉")
assertThat(connectError.unpackedDetails(ErrorDetail::class)).containsExactly(
expectedErrorDetail,
)
} finally {
countDownLatch.countDown()
}
}
countDownLatch.await(5, TimeUnit.SECONDS)
job.cancel()
assertThat(countDownLatch.count).isZero()
stream.close()
}
}

Expand Down Expand Up @@ -308,7 +313,7 @@ class Conformance(
}

@Test
fun timeoutOnSleepingServer() = runBlocking {
fun timeoutOnSleepingServer(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
val client = TestServiceClient(shortTimeoutConnectClient)
val request = streamingOutputCallRequest {
Expand All @@ -325,25 +330,20 @@ class Conformance(
val stream = client.streamingOutputCall()
withContext(Dispatchers.IO) {
val job = async {
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
try {
assertThat(result.error).isNotNull()
assertThat(result.connectError()!!.code).isEqualTo(Code.DEADLINE_EXCEEDED)
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED)
} finally {
countDownLatch.countDown()
}
},
)
try {
val result = streamResults(stream.resultChannel())
assertThat(result.error).isInstanceOf(ConnectError::class.java)
val connectErr = result.error as ConnectError
assertThat(connectErr.code).isEqualTo(Code.DEADLINE_EXCEEDED)
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED)
} finally {
countDownLatch.countDown()
}
}
stream.send(request)
stream.sendAndClose(request)
countDownLatch.await(5, TimeUnit.SECONDS)
job.cancel()
assertThat(countDownLatch.count).isZero()
stream.close()
}
}

Expand Down Expand Up @@ -401,26 +401,22 @@ class Conformance(
fun unimplementedServerStreamingService(): Unit = runBlocking {
val countDownLatch = CountDownLatch(1)
val stream = unimplementedServiceClient.unimplementedStreamingOutputCall()
stream.send(empty { })
stream.sendAndClose(empty { })
withContext(Dispatchers.IO) {
val job = async {
for (res in stream.resultChannel()) {
res.maybeFold(
onCompletion = { result ->
try {
assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED)
assertThat(result.connectError()!!.code).isEqualTo(Code.UNIMPLEMENTED)
} finally {
countDownLatch.countDown()
}
},
)
try {
val result = streamResults(stream.resultChannel())
assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED)
assertThat(result.error).isInstanceOf(ConnectError::class.java)
val connectErr = result.error as ConnectError
assertThat(connectErr.code).isEqualTo(Code.UNIMPLEMENTED)
} finally {
countDownLatch.countDown()
}
}
countDownLatch.await(5, TimeUnit.SECONDS)
job.cancel()
assertThat(countDownLatch.count).isZero()
stream.close()
}
}

Expand Down Expand Up @@ -754,6 +750,87 @@ class Conformance(
assertThat(countDownLatch.count).isZero()
}

@Test
fun clientStreaming(): Unit = runBlocking {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New test for client streaming (used to find several issues with streaming calls and API definitions).

val stream = testServiceConnectClient.streamingInputCall(emptyMap())
var sum = 0
listOf(256000, 8, 1024, 32768).forEach { size ->
stream.send(
streamingInputCallRequest {
payload = payload {
type = PayloadType.COMPRESSABLE
body = ByteString.copyFrom(ByteArray(size))
}
},
).getOrThrow()
sum += size
}
val countDownLatch = CountDownLatch(1)
withContext(Dispatchers.IO) {
val job = async {
try {
val result = stream.receiveAndClose().getOrThrow()
assertThat(result.aggregatedPayloadSize).isEqualTo(sum)
} finally {
countDownLatch.countDown()
}
}
countDownLatch.await(5, TimeUnit.MINUTES)
job.cancel()
assertThat(countDownLatch.count).isZero()
}
}

private data class ServerStreamingResult<Output>(
val headers: Headers,
val messages: List<Output>,
val code: Code,
val trailers: Trailers,
val error: Throwable?,
)

/*
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should really revisit the APIs around streaming calls to make this sort of code unnecessary. For now however, this will at least allow us to easily consume all of the data from a server/bidi streaming response and perform assertions on headers, messages, trailers, and errors.

* Convenience method to return all results (with sanity checking) for calls which stream results from the server
* (bidi and server streaming).
*
* This allows us to easily verify headers, messages, trailers, and errors without having to use fold/maybeFold
* manually in each location.
*/
private suspend fun <Output> streamResults(channel: ReceiveChannel<StreamResult<Output>>): ServerStreamingResult<Output> {
val seenHeaders = AtomicBoolean(false)
var headers: Headers = emptyMap()
val messages: MutableList<Output> = mutableListOf()
val seenCompletion = AtomicBoolean(false)
var code: Code = Code.UNKNOWN
var trailers: Headers = emptyMap()
var error: Throwable?
for (response in channel) {
response.maybeFold(
onHeaders = {
if (!seenHeaders.compareAndSet(false, true)) {
throw IllegalStateException("multiple onHeaders callbacks")
}
headers = it.headers
},
onMessage = {
messages.add(it.message)
},
onCompletion = {
if (!seenCompletion.compareAndSet(false, true)) {
throw IllegalStateException("multiple onCompletion callbacks")
}
code = it.code
trailers = it.trailers
error = it.error
},
)
}
if (!seenCompletion.get()) {
throw IllegalStateException("didn't get completion message")
}
return ServerStreamingResult(headers, messages, code, trailers, error)
}

private fun b64Encode(trailingValue: ByteArray): String {
return String(Base64.getEncoder().encode(trailingValue))
}
Expand Down
7 changes: 7 additions & 0 deletions examples/buf.gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,17 @@ plugins:
- plugin: java
out: generated-google-java/build/generated/sources/bufgen
protoc_path: .tmp/bin/protoc
- plugin: kotlin
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generate kotlin code so our examples don't look so verbose.

out: generated-google-java/build/generated/sources/bufgen
protoc_path: .tmp/bin/protoc
- plugin: connect-kotlin
out: generated-google-javalite/build/generated/sources/bufgen
path: ./protoc-gen-connect-kotlin/build/install/protoc-gen-connect-kotlin/bin/protoc-gen-connect-kotlin
- plugin: java
out: generated-google-javalite/build/generated/sources/bufgen
protoc_path: .tmp/bin/protoc
opt: lite
- plugin: kotlin
out: generated-google-javalite/build/generated/sources/bufgen
protoc_path: .tmp/bin/protoc
opt: lite
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

package com.connectrpc.examples.kotlin

import com.connectrpc.Code
import com.connectrpc.ConnectError
import com.connectrpc.ProtocolClientConfig
import com.connectrpc.eliza.v1.ConverseRequest
import com.connectrpc.eliza.v1.ElizaServiceClient
import com.connectrpc.eliza.v1.converseRequest
import com.connectrpc.extensions.GoogleJavaProtobufStrategy
import com.connectrpc.impl.ProtocolClient
import com.connectrpc.okhttp.ConnectOkHttpClient
Expand All @@ -33,37 +35,50 @@ class Main {
fun main(args: Array<String>) {
runBlocking {
val host = "https://demo.connectrpc.com"
val okHttpClient = OkHttpClient()
.newBuilder()
.readTimeout(Duration.ofMinutes(10))
.writeTimeout(Duration.ofMinutes(10))
.callTimeout(Duration.ofMinutes(10))
.build()
val client = ProtocolClient(
httpClient = ConnectOkHttpClient(
OkHttpClient()
.newBuilder()
.readTimeout(Duration.ofMinutes(10))
.writeTimeout(Duration.ofMinutes(10))
.callTimeout(Duration.ofMinutes(10))
.build(),
),
httpClient = ConnectOkHttpClient(okHttpClient),
ProtocolClientConfig(
host = host,
serializationStrategy = GoogleJavaProtobufStrategy(),
),
)
val elizaServiceClient = ElizaServiceClient(client)
connectStreaming(elizaServiceClient)
try {
connectStreaming(elizaServiceClient)
} finally {
okHttpClient.dispatcher.executorService.shutdown()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary for the client to cleanly shut down (the dispatcher thread pool uses non-daemon threads, so we need to shut it down manually for the app to shut down cleanly.

}
}
}

private suspend fun connectStreaming(elizaServiceClient: ElizaServiceClient) {
val stream = elizaServiceClient.converse()
withContext(Dispatchers.IO) {
// Add the message the user is sending to the views.
stream.send(ConverseRequest.newBuilder().setSentence("hello").build())
stream.send(converseRequest { sentence = "hello" })
stream.sendClose()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't close the bidi side of the stream, we won't ever complete.

for (streamResult in stream.resultChannel()) {
streamResult.maybeFold(
onMessage = { result ->
// Update the view with the response.
val elizaResponse = result.message
println(elizaResponse.sentence)
},
onCompletion = { result ->
if (result.code != Code.OK) {
val connectErr = result.connectError()
if (connectErr != null) {
throw connectErr
}
throw ConnectError(code = result.code, metadata = result.trailers)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this to at least show how errors need to be handled today in the current API. In a follow up, I'll see if this can be improved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just print the error details as a response?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to illustrate how you might want to handle errors in real code (instead of printing code/cause/trailers).

I think it also points out opportunities to make it even easier for consumers - if the Code is ever not OK, we should always have a ConnectError (even if it wraps another cause).

}
},
)
}
}
Expand Down
Loading