Skip to content

Commit

Permalink
Merge pull request #5 from dan-hipschman-od/support-streaming
Browse files Browse the repository at this point in the history
Support streaming RPCs
  • Loading branch information
d5h committed Oct 7, 2020
2 parents ad5857d + 926bc68 commit b01d7ab
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 22 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
/.idea/
/.nox/
/.venv/
/.vscode/
/coverage.xml
/dist/
/docs/_build/
Expand Down
1 change: 0 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,5 +165,4 @@ Limitations
These are the current limitations, although supporting these is possible. Contributions
or requests are welcome.

* ``ServerInterceptor`` currently only supports unary-unary RPCs.
* The package only provides service interceptors.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "grpc-interceptor"
version = "0.11.0"
version = "0.12.0"
description = "Simplifies gRPC interceptors"
license = "MIT"
readme = "README.md"
Expand Down
30 changes: 20 additions & 10 deletions src/grpc_interceptor/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base class for server-side interceptors."""

import abc
from typing import Any, Callable, NamedTuple
from typing import Any, Callable, NamedTuple, Tuple

import grpc

Expand Down Expand Up @@ -37,7 +37,7 @@ def intercept(
is typically the RPC method response, as a protobuf message. The
interceptor is free to modify this in some way, however.
"""
return method(request, context)
return method(request, context) # pragma: no cover

# Implementation of grpc.ServerInterceptor, do not override.
def intercept_service(self, continuation, handler_call_details):
Expand All @@ -47,24 +47,34 @@ def intercept_service(self, continuation, handler_call_details):
a public name. Do not override it, unless you know what you're doing.
"""
next_handler = continuation(handler_call_details)
# Make sure it's unary_unary:
if next_handler.request_streaming or next_handler.response_streaming:
raise ValueError("ServerInterceptor only handles unary_unary")
handler_factory, next_handler_method = _get_factory_and_method(next_handler)

def invoke_intercept_method(request, context):
next_interceptor_or_implementation = next_handler.unary_unary
method_name = handler_call_details.method
return self.intercept(
next_interceptor_or_implementation, request, context, method_name,
)
return self.intercept(next_handler_method, request, context, method_name,)

return grpc.unary_unary_rpc_method_handler(
return handler_factory(
invoke_intercept_method,
request_deserializer=next_handler.request_deserializer,
response_serializer=next_handler.response_serializer,
)


def _get_factory_and_method(
rpc_handler: grpc.RpcMethodHandler,
) -> Tuple[Callable, Callable]:
if rpc_handler.unary_unary:
return grpc.unary_unary_rpc_method_handler, rpc_handler.unary_unary
elif rpc_handler.unary_stream:
return grpc.unary_stream_rpc_method_handler, rpc_handler.unary_stream
elif rpc_handler.stream_unary:
return grpc.stream_unary_rpc_method_handler, rpc_handler.stream_unary
elif rpc_handler.stream_stream:
return grpc.stream_stream_rpc_method_handler, rpc_handler.stream_stream
else: # pragma: no cover
raise RuntimeError("RPC handler implementation does not exist")


class MethodName(NamedTuple):
"""Represents a gRPC method name.
Expand Down
38 changes: 31 additions & 7 deletions src/grpc_interceptor/testing/dummy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import contextmanager
import os
from tempfile import gettempdir
from typing import Callable, Dict, List
from typing import Callable, Dict, Iterable, List
from uuid import uuid4

import grpc
Expand Down Expand Up @@ -34,13 +34,37 @@ def Execute(
self, request: DummyRequest, context: grpc.ServicerContext
) -> DummyResponse:
"""Echo the input, or take on of the special cases actions."""
inp = request.input
if inp in self._special_cases:
output = self._special_cases[inp](inp)
else:
output = inp
return DummyResponse(output=self._get_output(request))

def ExecuteClientStream(
self, request_iter: Iterable[DummyRequest], context: grpc.ServicerContext
) -> DummyResponse:
"""Iterate over the input and concatenates the strings into the output."""
output = "".join(self._get_output(request) for request in request_iter)
return DummyResponse(output=output)

def ExecuteServerStream(
self, request: DummyRequest, context: grpc.ServicerContext
) -> Iterable[DummyResponse]:
"""Stream one character at a time from the input."""
for c in self._get_output(request):
yield DummyResponse(output=c)

def ExecuteClientServerStream(
self, request_iter: Iterable[DummyRequest], context: grpc.ServicerContext
) -> Iterable[DummyResponse]:
"""Stream input to output."""
for request in request_iter:
yield DummyResponse(output=self._get_output(request))

def _get_output(self, request: DummyRequest) -> str:
input = request.input
if input in self._special_cases:
output = self._special_cases[input](input)
else:
output = input
return output


@contextmanager
def dummy_client(
Expand All @@ -54,7 +78,7 @@ def dummy_client(
dummy_service = DummyService(special_cases)
dummy_pb2_grpc.add_DummyServiceServicer_to_server(dummy_service, server)

if os.name == "nt":
if os.name == "nt": # pragma: no cover
# We use Unix domain sockets when they're supported, to avoid port conflicts.
# However, on Windows, just pick a port.
channel_descriptor = "localhost:50051"
Expand Down
3 changes: 3 additions & 0 deletions src/grpc_interceptor/testing/protos/dummy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ message DummyResponse {

service DummyService {
rpc Execute (DummyRequest) returns (DummyResponse);
rpc ExecuteClientStream (stream DummyRequest) returns (DummyResponse);
rpc ExecuteServerStream (DummyRequest) returns (stream DummyResponse);
rpc ExecuteClientServerStream (stream DummyRequest) returns (stream DummyResponse);
}
33 changes: 30 additions & 3 deletions src/grpc_interceptor/testing/protos/dummy_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions src/grpc_interceptor/testing/protos/dummy_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@ def __init__(self, channel):
request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,
response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,
)
self.ExecuteClientStream = channel.stream_unary(
"/DummyService/ExecuteClientStream",
request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,
response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,
)
self.ExecuteServerStream = channel.unary_stream(
"/DummyService/ExecuteServerStream",
request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,
response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,
)
self.ExecuteClientServerStream = channel.stream_stream(
"/DummyService/ExecuteClientServerStream",
request_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.SerializeToString,
response_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.FromString,
)


class DummyServiceServicer(object):
Expand All @@ -34,6 +49,27 @@ def Execute(self, request, context):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")

def ExecuteClientStream(self, request_iterator, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")

def ExecuteServerStream(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")

def ExecuteClientServerStream(self, request_iterator, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")


def add_DummyServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -42,6 +78,21 @@ def add_DummyServiceServicer_to_server(servicer, server):
request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString,
response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString,
),
"ExecuteClientStream": grpc.stream_unary_rpc_method_handler(
servicer.ExecuteClientStream,
request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString,
response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString,
),
"ExecuteServerStream": grpc.unary_stream_rpc_method_handler(
servicer.ExecuteServerStream,
request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString,
response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString,
),
"ExecuteClientServerStream": grpc.stream_stream_rpc_method_handler(
servicer.ExecuteClientServerStream,
request_deserializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyRequest.FromString,
response_serializer=grpc__interceptor_dot_testing_dot_protos_dot_dummy__pb2.DummyResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
"DummyService", rpc_method_handlers
Expand Down
54 changes: 54 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Test cases for streaming RPCs."""

import grpc
import pytest

from grpc_interceptor import ServerInterceptor
from grpc_interceptor.testing import dummy_client, DummyRequest


class StreamingInterceptor(ServerInterceptor):
"""A test interceptor that streams."""

def intercept(self, method, request, context, method_name):
"""Doesn't do anything; just make sure we handle streaming RPCs."""
return method(request, context)


@pytest.fixture
def interceptors():
"""The interceptor chain for this test suite."""
intr = StreamingInterceptor()
return [intr]


def test_client_streaming(interceptors):
"""Client streaming should work."""
special_cases = {"error": lambda r, c: 1 / 0}
with dummy_client(special_cases=special_cases, interceptors=interceptors) as client:
inputs = ["foo", "bar"]
input_iter = (DummyRequest(input=input) for input in inputs)
assert client.ExecuteClientStream(input_iter).output == "foobar"

inputs = ["foo", "error"]
input_iter = (DummyRequest(input=input) for input in inputs)
with pytest.raises(grpc.RpcError):
client.ExecuteClientStream(input_iter)


def test_server_streaming(interceptors):
"""Server streaming should work."""
with dummy_client(special_cases={}, interceptors=interceptors) as client:
output = [
r.output for r in client.ExecuteServerStream(DummyRequest(input="foo"))
]
assert output == ["f", "o", "o"]


def test_client_server_streaming(interceptors):
"""Bidirectional streaming should work."""
with dummy_client(special_cases={}, interceptors=interceptors) as client:
inputs = ["foo", "bar"]
input_iter = (DummyRequest(input=input) for input in inputs)
response = client.ExecuteClientServerStream(input_iter)
assert [r.output for r in response] == inputs

0 comments on commit b01d7ab

Please sign in to comment.