Skip to content

Commit

Permalink
Add rate-limit support
Browse files Browse the repository at this point in the history
To be able to get the HTTP headers from Twisted, I had to:

1) Implement my own connect function that returns the HTTPDownloader
   object
2) Work around a bug on Twisted 8.2.0

Signed-off-by: Eduardo Habkost <ehabkost@raisama.net>
  • Loading branch information
ehabkost committed Dec 19, 2009
1 parent 7102090 commit e9c3033
Showing 1 changed file with 90 additions and 4 deletions.
94 changes: 90 additions & 4 deletions twittytwister/twitter.py
Expand Up @@ -10,10 +10,11 @@
import urllib
import mimetypes
import mimetools
import logging

from oauth import oauth

from twisted.internet import defer
from twisted.internet import defer, reactor
from twisted.web import client

import txml
Expand All @@ -23,6 +24,31 @@
BASE_URL="http://twitter.com"
SEARCH_URL="http://search.twitter.com/search.atom"


logger = logging.getLogger('twittytwister.twitter')


##### ugly hack to work around a bug on HTTPDownloader on Twisted 8.2.0 (fixed on 9.0.0)
def install_twisted_fix():
orig_method = client.HTTPDownloader.gotHeaders
def gotHeaders(self, headers):
client.HTTPClientFactory.gotHeaders(self, headers)
orig_method(self, headers)
client.HTTPDownloader.gotHeaders = gotHeaders

def buggy_twisted():
o = client.HTTPDownloader('http://dummy-url/foo', None)
client.HTTPDownloader.gotHeaders(o, {})
if o.response_headers is None:
return True
return False

if buggy_twisted():
install_twisted_fix()

##### end of hack


class TwitterClientInfo:
def __init__ (self, name, version = None, url = None):
self.name = name
Expand All @@ -40,6 +66,29 @@ def get_headers (self):
def get_source (self):
return self.name


def downloadPage(url, file, **kwargs):
"""Start a HTTP download, returning a HTTPDownloader object"""

# The Twisted API is weird:
# 1) web.client.downloadPage() doesn't give us the HTTP headers
# 2) there is no method that simply accepts a URL and gives you back
# a HTTPDownloader object

#TODO: convert getPage() usage to something similar, too

downloader = client.HTTPDownloader(url, file, **kwargs)
if downloader.scheme == 'https':
from twisted.internet import ssl
contextFactory = ssl.ClientContextFactory()
reactor.connectSSL(downloader.host, downloader.port,
downloader, contextFactory)
else:
reactor.connectTCP(downloader.host, downloader.port,
downloader)
return downloader


class Twitter(object):

agent="twitty twister"
Expand All @@ -55,6 +104,11 @@ def __init__(self, user=None, passwd=None,
self.use_oauth = False
self.client_info = None

# rate-limit info:
self.rate_limit_limit = None
self.rate_limit_remaining = None
self.rate_limit_reset = None

if user and passwd:
self.use_auth = True
self.username = user
Expand Down Expand Up @@ -132,6 +186,23 @@ def __encodeMultipart(self, fields, files):

return boundary, body

def gotHeaders(self, headers):
logger.debug("hdrs: %r", headers)

def ih(hdr):
r = headers.get(hdr)
if r is not None and len(r) > 0 and r[0]:
return int(r[0])
else:
return None

self.rate_limit_limit = ih('x-ratelimit-limit')
self.rate_limit_remaining = ih('x-ratelimit-remaining')
self.rate_limit_reset = ih('x-ratelimit-reset')

logger.debug('hdrs end')


def __getContentType(self, filename):
return mimetypes.guess_type(filename)[0] or 'application/octet-stream'

Expand Down Expand Up @@ -165,6 +236,21 @@ def __post(self, path, args={}):
agent=self.agent,
postdata=self._urlencode(args), headers=headers)

def __doDownloadPage(self, *args, **kwargs):
"""Works like client.downloadPage(), but handle incoming headers
"""
logger.debug("download page: %r, %r" % (args, kwargs))

d = defer.Deferred()
c = downloadPage(*args, **kwargs)

def done(*args, **kwargs):
self.gotHeaders(c.response_headers)
return d.callback(*args, **kwargs)

c.deferred.addCallbacks(done, d.errback)
return d

def __postPage(self, path, parser, args={}, params=None):
url = self.base_url + path
if params:
Expand All @@ -176,7 +262,7 @@ def __postPage(self, path, parser, args={}, params=None):
headers.update(self.client_info.get_headers())
args['source'] = self.client_info.get_source()

return client.downloadPage(url, parser, method='POST',
return self.__doDownloadPage(url, parser, method='POST',
agent=self.agent,
postdata=self._urlencode(args), headers=headers)

Expand All @@ -187,7 +273,7 @@ def __downloadPage(self, path, parser, params=None):

headers = self.makeAuthHeader('GET', url)

return client.downloadPage(url, parser,
return self.__doDownloadPage(url, parser,
agent=self.agent, headers=headers)

def __get(self, path, delegate, params, parser_factory=txml.Feed, extra_args=None):
Expand Down Expand Up @@ -368,7 +454,7 @@ def exampleDelegate(entry):
if args is None:
args = {}
args['q'] = query
return client.downloadPage(self.search_url + '?' + self._urlencode(args),
return self.__doDownloadPage(self.search_url + '?' + self._urlencode(args),
txml.Feed(delegate, extra_args), agent=self.agent)

def block(self, user):
Expand Down

0 comments on commit e9c3033

Please sign in to comment.