-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix several streaming issues #106
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,27 +15,35 @@ | |
package com.connectrpc.conformance | ||
|
||
import com.connectrpc.Code | ||
import com.connectrpc.ConnectError | ||
import com.connectrpc.Headers | ||
import com.connectrpc.ProtocolClientConfig | ||
import com.connectrpc.RequestCompression | ||
import com.connectrpc.StreamResult | ||
import com.connectrpc.Trailers | ||
import com.connectrpc.compression.GzipCompressionPool | ||
import com.connectrpc.conformance.ssl.sslContext | ||
import com.connectrpc.conformance.v1.ErrorDetail | ||
import com.connectrpc.conformance.v1.PayloadType | ||
import com.connectrpc.conformance.v1.TestServiceClient | ||
import com.connectrpc.conformance.v1.UnimplementedServiceClient | ||
import com.connectrpc.conformance.v1.echoStatus | ||
import com.connectrpc.conformance.v1.errorDetail | ||
import com.connectrpc.conformance.v1.payload | ||
import com.connectrpc.conformance.v1.responseParameters | ||
import com.connectrpc.conformance.v1.simpleRequest | ||
import com.connectrpc.conformance.v1.streamingInputCallRequest | ||
import com.connectrpc.conformance.v1.streamingOutputCallRequest | ||
import com.connectrpc.extensions.GoogleJavaProtobufStrategy | ||
import com.connectrpc.getOrThrow | ||
import com.connectrpc.impl.ProtocolClient | ||
import com.connectrpc.okhttp.ConnectOkHttpClient | ||
import com.connectrpc.protocols.NetworkProtocol | ||
import com.google.protobuf.ByteString | ||
import com.google.protobuf.empty | ||
import kotlinx.coroutines.Dispatchers | ||
import kotlinx.coroutines.async | ||
import kotlinx.coroutines.channels.ReceiveChannel | ||
import kotlinx.coroutines.runBlocking | ||
import kotlinx.coroutines.withContext | ||
import okhttp3.OkHttpClient | ||
|
@@ -54,6 +62,7 @@ import java.time.Duration | |
import java.util.Base64 | ||
import java.util.concurrent.CountDownLatch | ||
import java.util.concurrent.TimeUnit | ||
import java.util.concurrent.atomic.AtomicBoolean | ||
|
||
@RunWith(Parameterized::class) | ||
class Conformance( | ||
|
@@ -157,8 +166,7 @@ class Conformance( | |
} | ||
|
||
@Test | ||
fun failServerStreaming() = runBlocking { | ||
val countDownLatch = CountDownLatch(1) | ||
fun failServerStreaming(): Unit = runBlocking { | ||
val expectedErrorDetail = errorDetail { | ||
reason = "soirée 🎉" | ||
domain = "connect-conformance" | ||
|
@@ -177,35 +185,32 @@ class Conformance( | |
} | ||
} | ||
|
||
stream.send( | ||
stream.sendAndClose( | ||
streamingOutputCallRequest { | ||
responseParameters.addAll(parameters) | ||
}, | ||
) | ||
val countDownLatch = CountDownLatch(1) | ||
withContext(Dispatchers.IO) { | ||
val job = async { | ||
for (res in stream.resultChannel()) { | ||
res.maybeFold( | ||
onCompletion = { result -> | ||
try { | ||
// For some reason we keep timing out on these calls and not actually getting a real response like with grpc? | ||
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED) | ||
assertThat(result.connectError()!!.code).isEqualTo(Code.RESOURCE_EXHAUSTED) | ||
assertThat(result.connectError()!!.message).isEqualTo("soirée 🎉") | ||
assertThat(result.connectError()!!.unpackedDetails(ErrorDetail::class)).containsExactly( | ||
expectedErrorDetail, | ||
) | ||
} finally { | ||
countDownLatch.countDown() | ||
} | ||
}, | ||
try { | ||
val result = streamResults(stream.resultChannel()) | ||
assertThat(result.messages.map { it.payload.body.size() }).isEqualTo(sizes) | ||
assertThat(result.code).isEqualTo(Code.RESOURCE_EXHAUSTED) | ||
assertThat(result.error).isInstanceOf(ConnectError::class.java) | ||
val connectError = result.error as ConnectError | ||
assertThat(connectError.code).isEqualTo(Code.RESOURCE_EXHAUSTED) | ||
assertThat(connectError.message).isEqualTo("soirée 🎉") | ||
assertThat(connectError.unpackedDetails(ErrorDetail::class)).containsExactly( | ||
expectedErrorDetail, | ||
) | ||
} finally { | ||
countDownLatch.countDown() | ||
} | ||
} | ||
countDownLatch.await(5, TimeUnit.SECONDS) | ||
job.cancel() | ||
assertThat(countDownLatch.count).isZero() | ||
stream.close() | ||
} | ||
} | ||
|
||
|
@@ -308,7 +313,7 @@ class Conformance( | |
} | ||
|
||
@Test | ||
fun timeoutOnSleepingServer() = runBlocking { | ||
fun timeoutOnSleepingServer(): Unit = runBlocking { | ||
val countDownLatch = CountDownLatch(1) | ||
val client = TestServiceClient(shortTimeoutConnectClient) | ||
val request = streamingOutputCallRequest { | ||
|
@@ -325,25 +330,20 @@ class Conformance( | |
val stream = client.streamingOutputCall() | ||
withContext(Dispatchers.IO) { | ||
val job = async { | ||
for (res in stream.resultChannel()) { | ||
res.maybeFold( | ||
onCompletion = { result -> | ||
try { | ||
assertThat(result.error).isNotNull() | ||
assertThat(result.connectError()!!.code).isEqualTo(Code.DEADLINE_EXCEEDED) | ||
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED) | ||
} finally { | ||
countDownLatch.countDown() | ||
} | ||
}, | ||
) | ||
try { | ||
val result = streamResults(stream.resultChannel()) | ||
assertThat(result.error).isInstanceOf(ConnectError::class.java) | ||
val connectErr = result.error as ConnectError | ||
assertThat(connectErr.code).isEqualTo(Code.DEADLINE_EXCEEDED) | ||
assertThat(result.code).isEqualTo(Code.DEADLINE_EXCEEDED) | ||
} finally { | ||
countDownLatch.countDown() | ||
} | ||
} | ||
stream.send(request) | ||
stream.sendAndClose(request) | ||
countDownLatch.await(5, TimeUnit.SECONDS) | ||
job.cancel() | ||
assertThat(countDownLatch.count).isZero() | ||
stream.close() | ||
} | ||
} | ||
|
||
|
@@ -401,26 +401,22 @@ class Conformance( | |
fun unimplementedServerStreamingService(): Unit = runBlocking { | ||
val countDownLatch = CountDownLatch(1) | ||
val stream = unimplementedServiceClient.unimplementedStreamingOutputCall() | ||
stream.send(empty { }) | ||
stream.sendAndClose(empty { }) | ||
withContext(Dispatchers.IO) { | ||
val job = async { | ||
for (res in stream.resultChannel()) { | ||
res.maybeFold( | ||
onCompletion = { result -> | ||
try { | ||
assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED) | ||
assertThat(result.connectError()!!.code).isEqualTo(Code.UNIMPLEMENTED) | ||
} finally { | ||
countDownLatch.countDown() | ||
} | ||
}, | ||
) | ||
try { | ||
val result = streamResults(stream.resultChannel()) | ||
assertThat(result.code).isEqualTo(Code.UNIMPLEMENTED) | ||
assertThat(result.error).isInstanceOf(ConnectError::class.java) | ||
val connectErr = result.error as ConnectError | ||
assertThat(connectErr.code).isEqualTo(Code.UNIMPLEMENTED) | ||
} finally { | ||
countDownLatch.countDown() | ||
} | ||
} | ||
countDownLatch.await(5, TimeUnit.SECONDS) | ||
job.cancel() | ||
assertThat(countDownLatch.count).isZero() | ||
stream.close() | ||
} | ||
} | ||
|
||
|
@@ -754,6 +750,87 @@ class Conformance( | |
assertThat(countDownLatch.count).isZero() | ||
} | ||
|
||
@Test | ||
fun clientStreaming(): Unit = runBlocking { | ||
val stream = testServiceConnectClient.streamingInputCall(emptyMap()) | ||
var sum = 0 | ||
listOf(256000, 8, 1024, 32768).forEach { size -> | ||
stream.send( | ||
streamingInputCallRequest { | ||
payload = payload { | ||
type = PayloadType.COMPRESSABLE | ||
body = ByteString.copyFrom(ByteArray(size)) | ||
} | ||
}, | ||
).getOrThrow() | ||
sum += size | ||
} | ||
val countDownLatch = CountDownLatch(1) | ||
withContext(Dispatchers.IO) { | ||
val job = async { | ||
try { | ||
val result = stream.receiveAndClose().getOrThrow() | ||
assertThat(result.aggregatedPayloadSize).isEqualTo(sum) | ||
} finally { | ||
countDownLatch.countDown() | ||
} | ||
} | ||
countDownLatch.await(5, TimeUnit.MINUTES) | ||
job.cancel() | ||
assertThat(countDownLatch.count).isZero() | ||
} | ||
} | ||
|
||
private data class ServerStreamingResult<Output>( | ||
val headers: Headers, | ||
val messages: List<Output>, | ||
val code: Code, | ||
val trailers: Trailers, | ||
val error: Throwable?, | ||
) | ||
|
||
/* | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should really revisit the APIs around streaming calls to make this sort of code unnecessary. For now however, this will at least allow us to easily consume all of the data from a server/bidi streaming response and perform assertions on headers, messages, trailers, and errors. |
||
* Convenience method to return all results (with sanity checking) for calls which stream results from the server | ||
* (bidi and server streaming). | ||
* | ||
* This allows us to easily verify headers, messages, trailers, and errors without having to use fold/maybeFold | ||
* manually in each location. | ||
*/ | ||
private suspend fun <Output> streamResults(channel: ReceiveChannel<StreamResult<Output>>): ServerStreamingResult<Output> { | ||
val seenHeaders = AtomicBoolean(false) | ||
var headers: Headers = emptyMap() | ||
val messages: MutableList<Output> = mutableListOf() | ||
val seenCompletion = AtomicBoolean(false) | ||
var code: Code = Code.UNKNOWN | ||
var trailers: Headers = emptyMap() | ||
var error: Throwable? | ||
for (response in channel) { | ||
response.maybeFold( | ||
onHeaders = { | ||
if (!seenHeaders.compareAndSet(false, true)) { | ||
throw IllegalStateException("multiple onHeaders callbacks") | ||
} | ||
headers = it.headers | ||
}, | ||
onMessage = { | ||
messages.add(it.message) | ||
}, | ||
onCompletion = { | ||
if (!seenCompletion.compareAndSet(false, true)) { | ||
throw IllegalStateException("multiple onCompletion callbacks") | ||
} | ||
code = it.code | ||
trailers = it.trailers | ||
error = it.error | ||
}, | ||
) | ||
} | ||
if (!seenCompletion.get()) { | ||
throw IllegalStateException("didn't get completion message") | ||
} | ||
return ServerStreamingResult(headers, messages, code, trailers, error) | ||
} | ||
|
||
private fun b64Encode(trailingValue: ByteArray): String { | ||
return String(Base64.getEncoder().encode(trailingValue)) | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,10 +9,17 @@ plugins: | |
- plugin: java | ||
out: generated-google-java/build/generated/sources/bufgen | ||
protoc_path: .tmp/bin/protoc | ||
- plugin: kotlin | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Generate kotlin code so our examples don't look so verbose. |
||
out: generated-google-java/build/generated/sources/bufgen | ||
protoc_path: .tmp/bin/protoc | ||
- plugin: connect-kotlin | ||
out: generated-google-javalite/build/generated/sources/bufgen | ||
path: ./protoc-gen-connect-kotlin/build/install/protoc-gen-connect-kotlin/bin/protoc-gen-connect-kotlin | ||
- plugin: java | ||
out: generated-google-javalite/build/generated/sources/bufgen | ||
protoc_path: .tmp/bin/protoc | ||
opt: lite | ||
- plugin: kotlin | ||
out: generated-google-javalite/build/generated/sources/bufgen | ||
protoc_path: .tmp/bin/protoc | ||
opt: lite |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,9 +14,11 @@ | |
|
||
package com.connectrpc.examples.kotlin | ||
|
||
import com.connectrpc.Code | ||
import com.connectrpc.ConnectError | ||
import com.connectrpc.ProtocolClientConfig | ||
import com.connectrpc.eliza.v1.ConverseRequest | ||
import com.connectrpc.eliza.v1.ElizaServiceClient | ||
import com.connectrpc.eliza.v1.converseRequest | ||
import com.connectrpc.extensions.GoogleJavaProtobufStrategy | ||
import com.connectrpc.impl.ProtocolClient | ||
import com.connectrpc.okhttp.ConnectOkHttpClient | ||
|
@@ -33,37 +35,50 @@ class Main { | |
fun main(args: Array<String>) { | ||
runBlocking { | ||
val host = "https://demo.connectrpc.com" | ||
val okHttpClient = OkHttpClient() | ||
.newBuilder() | ||
.readTimeout(Duration.ofMinutes(10)) | ||
.writeTimeout(Duration.ofMinutes(10)) | ||
.callTimeout(Duration.ofMinutes(10)) | ||
.build() | ||
val client = ProtocolClient( | ||
httpClient = ConnectOkHttpClient( | ||
OkHttpClient() | ||
.newBuilder() | ||
.readTimeout(Duration.ofMinutes(10)) | ||
.writeTimeout(Duration.ofMinutes(10)) | ||
.callTimeout(Duration.ofMinutes(10)) | ||
.build(), | ||
), | ||
httpClient = ConnectOkHttpClient(okHttpClient), | ||
ProtocolClientConfig( | ||
host = host, | ||
serializationStrategy = GoogleJavaProtobufStrategy(), | ||
), | ||
) | ||
val elizaServiceClient = ElizaServiceClient(client) | ||
connectStreaming(elizaServiceClient) | ||
try { | ||
connectStreaming(elizaServiceClient) | ||
} finally { | ||
okHttpClient.dispatcher.executorService.shutdown() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is necessary for the client to cleanly shut down (the dispatcher thread pool uses non-daemon threads, so we need to shut it down manually for the app to shut down cleanly. |
||
} | ||
} | ||
} | ||
|
||
private suspend fun connectStreaming(elizaServiceClient: ElizaServiceClient) { | ||
val stream = elizaServiceClient.converse() | ||
withContext(Dispatchers.IO) { | ||
// Add the message the user is sending to the views. | ||
stream.send(ConverseRequest.newBuilder().setSentence("hello").build()) | ||
stream.send(converseRequest { sentence = "hello" }) | ||
stream.sendClose() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we don't close the bidi side of the stream, we won't ever complete. |
||
for (streamResult in stream.resultChannel()) { | ||
streamResult.maybeFold( | ||
onMessage = { result -> | ||
// Update the view with the response. | ||
val elizaResponse = result.message | ||
println(elizaResponse.sentence) | ||
}, | ||
onCompletion = { result -> | ||
if (result.code != Code.OK) { | ||
val connectErr = result.connectError() | ||
if (connectErr != null) { | ||
throw connectErr | ||
} | ||
throw ConnectError(code = result.code, metadata = result.trailers) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added this to at least show how errors need to be handled today in the current API. In a follow up, I'll see if this can be improved. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe just print the error details as a response? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wanted to illustrate how you might want to handle errors in real code (instead of printing code/cause/trailers). I think it also points out opportunities to make it even easier for consumers - if the Code is ever not OK, we should always have a ConnectError (even if it wraps another cause). |
||
} | ||
}, | ||
) | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New test for client streaming (used to find several issues with streaming calls and API definitions).