Skip to content
Permalink
Browse files

Duplex gRPC via Jetty (#995)

  • Loading branch information...
oldergod committed May 19, 2019
1 parent d9b9745 commit 070a1001ff3157d79fadfada3316c99e508c1d11
@@ -0,0 +1,29 @@
package misk.grpc.miskserver

import misk.grpc.BlockingGrpcChannel
import misk.grpc.GrpcReceiveChannel
import misk.grpc.GrpcSendChannel
import misk.grpc.consumeEach
import misk.web.Grpc
import misk.web.RequestBody
import misk.web.actions.WebAction
import routeguide.RouteNote
import javax.inject.Inject

// TODO: Misk should pass in the channel rather than returning it.
class RouteChatGrpcAction @Inject constructor() : WebAction {
@Grpc("/routeguide.RouteGuide/RouteChat")
fun chat(@RequestBody request: GrpcReceiveChannel<RouteNote>): GrpcSendChannel<RouteNote> {
val response = BlockingGrpcChannel<RouteNote>()

Thread {
response.use { response ->
request.consumeEach { routeNote ->
response.send(RouteNote(message = "ACK: ${routeNote.message}"))
}
}
}.start()

return response
}
}
@@ -12,6 +12,7 @@ class RouteGuideMiskServiceModule : KAbstractModule() {
override fun configure() {
install(WebTestingModule())
install(WebActionModule.create<GetFeatureGrpcAction>())
install(WebActionModule.create<RouteChatGrpcAction>())
}

@Provides
@@ -11,6 +11,7 @@ import org.junit.jupiter.api.Test
import routeguide.Feature
import routeguide.Point
import routeguide.RouteGuide
import routeguide.RouteNote
import javax.inject.Inject
import javax.inject.Provider

@@ -37,4 +38,18 @@ class MiskClientMiskServerTest {
))
}
}

@Test
fun duplexStreaming() {
runBlocking {
val routeGuide = routeGuideProvider.get()

val (sendChannel, receiveChannel) = routeGuide.RouteChat()
sendChannel.send(RouteNote(message = "a"))
assertThat(receiveChannel.receive()).isEqualTo(RouteNote(message = "ACK: a"))
sendChannel.send(RouteNote(message = "b"))
assertThat(receiveChannel.receive()).isEqualTo(RouteNote(message = "ACK: b"))
sendChannel.close()
}
}
}
@@ -0,0 +1,120 @@
package misk.grpc

import com.squareup.wire.ProtoAdapter
import misk.web.ResponseBody
import misk.web.marshal.GenericMarshallers
import misk.web.marshal.GenericUnmarshallers
import misk.web.marshal.Marshaller
import misk.web.marshal.Unmarshaller
import misk.web.mediatype.MediaTypes
import okhttp3.MediaType
import okio.BufferedSink
import okio.BufferedSource
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import javax.inject.Inject
import javax.inject.Singleton
import kotlin.reflect.KType
import kotlin.reflect.jvm.javaType

@Singleton
class GrpcMarshallerFactory @Inject constructor() : Marshaller.Factory {
override fun create(mediaType: MediaType, type: KType): Marshaller<Any>? {
if (mediaType.type() != MediaTypes.APPLICATION_GRPC_MEDIA_TYPE.type() ||
mediaType.subtype() != MediaTypes.APPLICATION_GRPC_MEDIA_TYPE.subtype()) {
return null
}

val responseType = Marshaller.actualResponseType(type)
if (GenericMarshallers.canHandle(responseType)) return null

val elementType = type.streamElementType()

@Suppress("UNCHECKED_CAST") // Guarded by reflection.
return if (elementType != null) {
GrpcStreamMarshaller(ProtoAdapter.get(elementType as Class<Any>)) as Marshaller<Any>
} else {
GrpcSingleMarshaller<Any>(ProtoAdapter.get(responseType as Class<Any>))
}
}
}

internal class GrpcSingleMarshaller<T>(val adapter: ProtoAdapter<T>) : Marshaller<T> {
override fun contentType() = MediaTypes.APPLICATION_GRPC_MEDIA_TYPE

override fun responseBody(o: T): ResponseBody {
return object : ResponseBody {
override fun writeTo(sink: BufferedSink) {
val writer = GrpcWriter.get(sink, adapter)
writer.writeMessage(o)
}
}
}
}

internal class GrpcStreamMarshaller<T>(val adapter: ProtoAdapter<T>) : Marshaller<GrpcReceiveChannel<T>> {
override fun contentType() = MediaTypes.APPLICATION_GRPC_MEDIA_TYPE

override fun responseBody(o: GrpcReceiveChannel<T>): ResponseBody {
return object : ResponseBody {
override fun writeTo(sink: BufferedSink) {
GrpcWriter.get(sink, adapter).use { writer ->
o.consumeEach { message ->
writer.writeMessage(message)
writer.flush()
}
}
}
}
}
}

@Singleton
class GrpcUnmarshallerFactory @Inject constructor() : Unmarshaller.Factory {
override fun create(mediaType: MediaType, type: KType): Unmarshaller? {
if (mediaType.type() != MediaTypes.APPLICATION_GRPC_MEDIA_TYPE.type() ||
mediaType.subtype() != MediaTypes.APPLICATION_GRPC_MEDIA_TYPE.subtype()) {
return null
}

if (GenericUnmarshallers.canHandle(type)) return null

val elementType = type.streamElementType()

@Suppress("UNCHECKED_CAST") // Guarded by reflection.
return if (elementType != null) {
GrpcStreamUnmarshaller(ProtoAdapter.get(elementType as Class<Any>))
} else {
GrpcSingleUnmarshaller(ProtoAdapter.get(type.javaType as Class<Any>))
}
}
}

internal class GrpcSingleUnmarshaller<T>(val adapter: ProtoAdapter<T>) : Unmarshaller {
override fun unmarshal(source: BufferedSource): Any? {
return GrpcReader.get(source, adapter).readMessage()
}
}

internal class GrpcStreamUnmarshaller<T>(val adapter: ProtoAdapter<T>) : Unmarshaller {
override fun unmarshal(source: BufferedSource): GrpcReceiveChannel<T> {
return object : GrpcReceiveChannel<T> {
val grpcReader = GrpcReader.get(source, adapter)

override fun receiveOrNull(): T? {
return grpcReader.readMessage()
}
}
}
}

/**
* Returns the channel element type, like `MyRequest` if this is `Channel<MyRequest>`. Returns null
* if this is not a channel.
*/
private fun KType.streamElementType(): Type? {
val parameterizedType = javaType as? ParameterizedType ?: return null
if (parameterizedType.rawType != GrpcReceiveChannel::class.java &&
parameterizedType.rawType != GrpcSendChannel::class.java) return null
return parameterizedType.actualTypeArguments[0]
}
@@ -0,0 +1,43 @@
package misk.grpc

import java.io.Closeable
import java.util.concurrent.LinkedBlockingDeque

interface GrpcSendChannel<T> : Closeable {
fun send(message: T)
}

interface GrpcReceiveChannel<T> {
fun receiveOrNull(): T?
}

fun <T> GrpcReceiveChannel<T>.consumeEach(block: (T) -> Unit) {
while (true) {
val message = receiveOrNull() ?: return
block(message)
}
}

class BlockingGrpcChannel<T> : GrpcSendChannel<T>, GrpcReceiveChannel<T> {
private val queue = LinkedBlockingDeque<T>(1)
private object Eof

override fun send(message: T) {
queue.put(message)
}

@Suppress("UNCHECKED_CAST")
override fun receiveOrNull(): T? {
val message = queue.take()
if (message == Eof) {
queue.put(Eof as T)
return null
}
return message
}

@Suppress("UNCHECKED_CAST")
override fun close() {
queue.put(Eof as T)
}
}
@@ -37,8 +37,8 @@ import misk.web.interceptors.TracingInterceptor
import misk.web.jetty.JettyConnectionMetricsCollector
import misk.web.jetty.JettyService
import misk.web.jetty.JettyThreadPoolMetricsCollector
import misk.web.marshal.GrpcMarshaller
import misk.web.marshal.GrpcUnmarshaller
import misk.grpc.GrpcMarshallerFactory
import misk.grpc.GrpcUnmarshallerFactory
import misk.web.marshal.JsonMarshaller
import misk.web.marshal.JsonUnmarshaller
import misk.web.marshal.Marshaller
@@ -77,10 +77,10 @@ class MiskWebModule(private val config: WebConfig) : KAbstractModule() {
multibind<Marshaller.Factory>().to<PlainTextMarshaller.Factory>()
multibind<Marshaller.Factory>().to<JsonMarshaller.Factory>()
multibind<Marshaller.Factory>().to<ProtobufMarshaller.Factory>()
multibind<Marshaller.Factory>().to<GrpcMarshaller.Factory>()
multibind<Marshaller.Factory>().to<GrpcMarshallerFactory>()
multibind<Unmarshaller.Factory>().to<JsonUnmarshaller.Factory>()
multibind<Unmarshaller.Factory>().to<ProtobufUnmarshaller.Factory>()
multibind<Unmarshaller.Factory>().to<GrpcUnmarshaller.Factory>()
multibind<Unmarshaller.Factory>().to<GrpcUnmarshallerFactory>()

// Initialize empty sets for our multibindings.
newMultibinder<NetworkInterceptor.Factory>()

This file was deleted.

0 comments on commit 070a100

Please sign in to comment.
You can’t perform that action at this time.