Skip to content

Commit

Permalink
remove flow_collect, apply flow control behaviour to collect
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Apr 12, 2019
1 parent 0a2724a commit 167bc80
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 57 deletions.
105 changes: 48 additions & 57 deletions src/treq/content.py
Expand Up @@ -4,7 +4,7 @@
import json

from twisted.internet.defer import (
Deferred, succeed, inlineCallbacks, returnValue
Deferred, succeed, inlineCallbacks, maybeDeferred,
)

from twisted.internet.protocol import Protocol
Expand All @@ -28,52 +28,30 @@ def _encoding_from_headers(headers):
return 'UTF-8'


class _BodyCollector(Protocol):
def __init__(self, finished, collector):
self.finished = finished
self.collector = collector

def dataReceived(self, data):
self.collector(data)

def connectionLost(self, reason):
if reason.check(ResponseDone):
self.finished.callback(None)
elif reason.check(PotentialDataLoss):
# http://twistedmatrix.com/trac/ticket/4840
self.finished.callback(None)
else:
self.finished.errback(reason)
class _NullTransport(object):
@staticmethod
def pauseProducing():
pass

@staticmethod
def resumeProducing():
pass

def collect(response, collector):
"""
Incrementally collect the body of the response.
@staticmethod
def stopProducing():
pass

This function may only be called **once** for a given response.

:param IResponse response: The HTTP response to collect the body from.
:param collector: A callable to be called each time data is available
from the response body.
:type collector: single argument callable
:rtype: Deferred that fires with None when the entire body has been read.
"""
if response.length == 0:
return succeed(None)

d = Deferred()
response.deliverBody(_BodyCollector(d, collector))
return d


class _FlowBodyCollector(Protocol):
class _BodyCollector(Protocol):
def __init__(self, finished, collector):
self.buffer = b''
self.writing = False
self.finished = finished
self.collector = collector

def getTransport(self):
return self.transport or _NullTransport

@inlineCallbacks
def dataReceived(self, data):
try:
Expand All @@ -83,42 +61,41 @@ def dataReceived(self, data):

w = Deferred()
self.writing = w
self.transport.pauseProducing()
self.getTransport().pauseProducing()
while self.buffer:
bufferred = self.buffer
self.buffer = b''
yield self.collector(bufferred)
self.writing = False
self.transport.resumeProducing()
self.getTransport().resumeProducing()
w.callback(None)
except Exception as e:
self.finished.errback(e)
self.transport.stopProducing()
self.getTransport().stopProducing()

@inlineCallbacks
def connectionLost(self, reason):
if self.finished.called:
return
yield self.writing

if reason.check(ResponseDone):
self.finished.callback(None)
elif reason.check(PotentialDataLoss):
# http://twistedmatrix.com/trac/ticket/4840
# PotentialDataLoss due to http://twistedmatrix.com/trac/ticket/4840
if reason.check(ResponseDone, PotentialDataLoss):
self.finished.callback(None)
else:
self.finished.errback(reason)


def flow_collect(response, collector):
def collect(response, collector):
"""
Incrementally collect the body of the response. Respecting flow control.
This function may only be called **once** for a given response.
:param IResponse response: The HTTP response to collect the body from.
:param collector: A callable that returns a deferred to be called each time
data is available from the response body.
:param collector: A callable to be called each time data is available
from the response body. If callable returns a Deferred, it will be
callable will not be called again until that Deferred completes.
:type collector: single argument callable
:rtype: Deferred that fires with None when the entire body has been read.
Expand All @@ -127,11 +104,32 @@ def flow_collect(response, collector):
return succeed(None)

d = Deferred()
response.deliverBody(_FlowBodyCollector(d, collector))
response.deliverBody(_BodyCollector(d, collector))
return d


@inlineCallbacks
class _Reduce(object):
def __init__(self, reducer, response, initializer):
self.reducer = reducer
self.response = response
self.initializer = initializer

def apply(self):
return (
collect(self.response, self.collector)
.addCallback(lambda _: self.initializer)
)

def collector(self, data):
return (
maybeDeferred(self.reducer, self.initializer, data)
.addCallback(self.update)
)

def update(self, v):
self.initializer = v


def reduce(reducer, response, initializer):
"""
Incrementally collect the body of the response. Respecting flow control.
Expand All @@ -145,14 +143,7 @@ def reduce(reducer, response, initializer):
:rtype: Deferred that fires with the accumulator when the entire body has
been read.
"""
ref = [initializer, ] # py2 nonlocal

@inlineCallbacks
def collector(data):
ref[0] = yield reducer(ref[0], data)

yield flow_collect(response, collector)
returnValue(ref[0])
return _Reduce(reducer, response, initializer).apply()


def content(response):
Expand Down
51 changes: 51 additions & 0 deletions src/treq/test/test_content.py
Expand Up @@ -7,18 +7,22 @@
from twisted.web.http_headers import Headers
from twisted.web.client import ResponseDone, ResponseFailed
from twisted.web.http import PotentialDataLoss
from twisted.internet.defer import Deferred

from treq import collect, content, json_content, text_content
from treq.content import reduce
from treq.client import _BufferedResponse


class ContentTests(TestCase):
def setUp(self):
self.response = mock.Mock()
self.transport = mock.Mock()
self.protocol = None

def deliverBody(protocol):
self.protocol = protocol
self.protocol.makeConnection(self.transport)

self.response.deliverBody.side_effect = deliverBody
self.response = _BufferedResponse(self.response)
Expand All @@ -38,6 +42,53 @@ def test_collect(self):

self.assertEqual(data, [b'{', b'"msg": "hell', b'o"}'])

def test_collect_flow_control(self):
data = []
collector_d = Deferred()

def collector(b):
data.append(b)
return collector_d

d = collect(self.response, collector)

self.protocol.dataReceived(b'{')
self.protocol.dataReceived(b'"msg": "hell')
self.protocol.dataReceived(b'o"}')
self.assertEqual(data, [b'{'])

self.transport.pauseProducing.assert_called_once_with()
collector_d.callback(None)
self.transport.resumeProducing.assert_called_once_with()

self.assertEqual(data, [b'{', b'"msg": "hello"}'])

collector_d = Deferred()
self.protocol.dataReceived(b'more_data')
self.protocol.connectionLost(Failure(ResponseDone()))
self.assertFalse(d.called)
collector_d.callback(object())
self.assertEqual(self.successResultOf(d), None)

self.assertEqual(data, [b'{', b'"msg": "hello"}', b'more_data'])

def test_reduce(self):
d = reduce(lambda acc, b: acc+b, self.response, b'')
self.protocol.dataReceived(b'{')
self.protocol.dataReceived(b'"msg": "hell')
self.protocol.dataReceived(b'o"}')
self.protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.successResultOf(d), b'{"msg": "hello"}')

def test_collect_fn_failure(self):
e = Exception()
def collector(b):
raise e

d = collect(self.response, collector)
self.protocol.dataReceived(b'{')
self.assertIs(self.failureResultOf(d).value, e)

def test_collect_failure(self):
data = []

Expand Down

0 comments on commit 167bc80

Please sign in to comment.