Skip to content

Commit

Permalink
simplify headermanager; upgrade h2
Browse files Browse the repository at this point in the history
  • Loading branch information
mattbennett committed Jan 8, 2019
1 parent 268f367 commit 0fd0c79
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 56 deletions.
16 changes: 10 additions & 6 deletions nameko_grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from urllib.parse import urlparse

from grpc import StatusCode
from h2.errors import PROTOCOL_ERROR # changed under h2 from 2.6.4?
from h2.errors import ErrorCodes

from nameko_grpc.compression import SUPPORTED_ENCODINGS, UnsupportedEncoding
from nameko_grpc.connection import ConnectionManager
Expand Down Expand Up @@ -57,11 +57,13 @@ def send_request(self, request_headers):
"""
stream_id = next(self.counter)

request_stream = SendStream(stream_id, headers=request_headers)
request_stream = SendStream(stream_id)
response_stream = ReceiveStream(stream_id)
self.receive_streams[stream_id] = response_stream
self.send_streams[stream_id] = request_stream

request_stream.headers.set(*request_headers)

self.pending_requests.append(stream_id)

return request_stream, response_stream
Expand All @@ -73,9 +75,10 @@ def response_received(self, event):
"""
super().response_received(event)

response_stream = self.receive_streams.get(event.stream_id)
stream_id = event.stream_id
response_stream = self.receive_streams.get(stream_id)
if response_stream is None:
self.conn.reset_stream(event.stream_id, error_code=PROTOCOL_ERROR)
self.conn.reset_stream(stream_id, error_code=ErrorCodes.PROTOCOL_ERROR)
return

headers = response_stream.headers
Expand All @@ -91,9 +94,10 @@ def trailers_received(self, event):
"""
super().trailers_received(event)

response_stream = self.receive_streams.get(event.stream_id)
stream_id = event.stream_id
response_stream = self.receive_streams.get(stream_id)
if response_stream is None:
self.conn.reset_stream(event.stream_id, error_code=PROTOCOL_ERROR)
self.conn.reset_stream(stream_id, error_code=ErrorCodes.PROTOCOL_ERROR)
return

trailers = response_stream.trailers
Expand Down
8 changes: 4 additions & 4 deletions nameko_grpc/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from h2.config import H2Configuration
from h2.connection import H2Connection
from h2.errors import PROTOCOL_ERROR # changed under h2 from 2.6.4?
from h2.errors import ErrorCodes
from h2.events import (
DataReceived,
RemoteSettingsChanged,
Expand Down Expand Up @@ -121,7 +121,7 @@ def response_received(self, event):
if not receive_stream:
return

receive_stream.headers.set(*event.headers)
receive_stream.headers.set(*event.headers, from_wire=True)

def data_received(self, event):
""" Called when data is received on a stream.
Expand All @@ -134,7 +134,7 @@ def data_received(self, event):
receive_stream = self.receive_streams.get(stream_id)
if receive_stream is None:
try:
self.conn.reset_stream(stream_id, error_code=PROTOCOL_ERROR)
self.conn.reset_stream(stream_id, error_code=ErrorCodes.PROTOCOL_ERROR)
except StreamClosedError:
pass
return
Expand Down Expand Up @@ -172,7 +172,7 @@ def trailers_received(self, event):
if not receive_stream:
return

receive_stream.trailers.set(*event.headers)
receive_stream.trailers.set(*event.headers, from_wire=True)

def send_headers(self, stream_id, immediate=False):
""" Attempt to send any headers on a stream.
Expand Down
4 changes: 3 additions & 1 deletion nameko_grpc/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ def request_received(self, event):

stream_id = event.stream_id

request_stream = ReceiveStream(stream_id, headers=event.headers)
request_stream = ReceiveStream(stream_id)
response_stream = SendStream(stream_id)
self.receive_streams[stream_id] = request_stream
self.send_streams[stream_id] = response_stream

request_stream.headers.set(*event.headers, from_wire=True)

compression = select_algorithm(
request_stream.headers.get("grpc-accept-encoding")
)
Expand Down
110 changes: 70 additions & 40 deletions nameko_grpc/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,14 @@ def is_http2_header(name):
return name in ("te", "content-type", "user-agent", "accept-encoding")


def maybe_base64_decode(header):
name, value = header
if name.endswith("-bin"):
value = base64.b64decode(value)
return name, value


def maybe_base64_encode(header):
name, value = header
if name.endswith("-bin"):
value = base64.b64encode(value)
return name, value


def filter_headers_for_application(headers):
def include(header):
name, _ = header
if is_pseudo_header(name) or is_grpc_header(name) or is_http2_header(name):
return False
return True

return filter(include, headers)
return list(filter(include, headers))


def sort_headers_for_wire(headers):
Expand All @@ -53,64 +39,108 @@ def weight(header):
return sorted(headers, key=weight)


def check_encoded(headers):
for name, value in headers:
assert isinstance(name, bytes)
assert isinstance(value, bytes)


def check_decoded(headers):
for name, value in headers:
assert not isinstance(name, bytes)
if name.endswith("-bin"):
assert isinstance(value, bytes)
else:
assert not isinstance(value, bytes)


def decode_header(header):
name, value = header
name = name.decode("utf-8")
if name.endswith("-bin"):
value = base64.b64decode(value)
else:
value = value.decode("utf-8")
return name, value


def encode_header(header):
name, value = header
name = name.encode("utf-8")
if name.endswith(b"-bin"):
value = base64.b64encode(value)
else:
value = value.encode("utf-8")
return name, value


class HeaderManager:
def __init__(self, data=None):
def __init__(self):
"""
Object for managing headers (or trailers).
Accepts a list of tuples in the form `("<name>", "<value>")`
Object for managing headers (or trailers). Headers are encoded before
transmitting over the wire and stored in their non-encoded form.
"""
if data is None:
data = []
self.data = data
self.data = []

@staticmethod
def decode(headers):
return list(map(decode_header, headers))

@staticmethod
def encode(headers):
return list(map(encode_header, headers))

def get(self, name, default=None):
""" Get a header by `name`.
Performs base64 decoding of binary values, and joins duplicate headers into a
comma-separated string of values.
Joins duplicate headers into a comma-separated string of values.
"""
matches = []
for key, value in self.data:
if key == name:
_, value = maybe_base64_decode((key, value))
matches.append(value)
if not matches:
return default
return ",".join(matches)

def set(self, *headers):
def set(self, *headers, from_wire=False):
""" Set headers.
Overwrites any existing header with the same name.
Overwrites any existing header with the same name. Optionally decodes
the headers first.
"""
for name, _ in headers:
self.clear(name)
self.append(*headers)
if from_wire:
headers = self.decode(headers)
check_decoded(headers)

def clear(self, name):
""" Clear a header by `name`.
"""
self.data = list(filter(lambda header: header[0] != name, self.data))
# clear existing headers with these names
to_clear = dict(headers).keys()
self.data = [(key, value) for (key, value) in self.data if key not in to_clear]

self.data.extend(headers)

def append(self, *headers):
def append(self, *headers, from_wire=False):
""" Add new headers.
Preserves and appends to any existing header with the same name.
Preserves and appends to any existing header with the same name. Optionally
decodes the headers first.
"""
if from_wire:
headers = self.decode(headers)
check_decoded(headers)
self.data.extend(headers)

def __len__(self):
return len(self.data)

@property
def for_wire(self):
""" A sorted list of headers for transmitting over the wire.
""" A sorted list of encoded headers for transmitting over the wire.
"""
return list(map(maybe_base64_encode, sort_headers_for_wire(self.data)))
return self.encode(sort_headers_for_wire(self.data))

@property
def for_application(self):
""" A filtered and processed list of headers for use by the application.
""" A filtered list of headers for use by the application.
"""
return list(map(maybe_base64_decode, filter_headers_for_application(self.data)))
return filter_headers_for_application(self.data)
6 changes: 3 additions & 3 deletions nameko_grpc/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def __len__(self):


class StreamBase:
def __init__(self, stream_id, headers=None, trailers=None):
def __init__(self, stream_id):
self.stream_id = stream_id

self.headers = HeaderManager(headers)
self.trailers = HeaderManager(trailers)
self.headers = HeaderManager()
self.trailers = HeaderManager()

self.queue = Queue()
self.buffer = ByteBuffer()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
author="onefinestay",
url="http://github.com/nameko/nameko-grpc",
packages=find_packages(exclude=["test", "test.*"]),
install_requires=["nameko", "h2==2.6.2"],
install_requires=["nameko", "h2>=3"],
extras_require={
"dev": [
"pytest",
Expand Down
1 change: 0 additions & 1 deletion test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def extract_metadata(context):
def maybe_echo_metadata(context):
""" Copy metadata from request to response.
"""

initial_metadata = []
trailing_metadata = []

Expand Down

0 comments on commit 0fd0c79

Please sign in to comment.