From 11832e58a67918cd0af3e5569ad0a0ebfcde4762 Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 2 Mar 2017 18:27:21 -0500
Subject: [PATCH] Refactor to use generators and support arbitrarily sized limits

---
 README.md          |  15 ++++--
 psraw/__init__.py  |   2 +-
 psraw/base.py      | 130 +++++++++++++++++++++------------------------
 psraw/endpoints.py |  63 ++++++++++++++++++++++
 psraw/version.py   |   2 +-
 5 files changed, 137 insertions(+), 75 deletions(-)
 create mode 100644 psraw/endpoints.py

diff --git a/README.md b/README.md
index 7b5f8f7..7aec82a 100644
--- a/README.md
+++ b/README.md
@@ -18,20 +18,29 @@ pip install psraw
 
 ### Usage
 
+Each function in the library is a generator. This allows you to set an extremely
+high limit parameter without the need to call the functions multiple times. To
+get a normal Python list, just wrap the function with the `list()` constructor.
+
 ```python
 import psraw
 import praw
 
 r = praw.Reddit(...)
 
-results = psraw.comment_search(r, q='teaearlgraycold', limit=30)
+# Get a list from the API
+results = list(psraw.comment_search(r, q='teaearlgraycold', limit=30))
+
+# Or use the function as a generator
+for comment in psraw.comment_search(r, q='teaearlgraycold', limit=3000):
+    print(comment.body)  # Do something with the comment
 ```
 
 ### Available functions
 
 [Official Documentation](https://docs.google.com/document/d/171VdjT-QKJi6ul9xYJ4kmiHeC7t_3G31Ce8eozKp3VQ/edit)
 
-In the example function signatures below, `r` is a `praw.Reddit` session object. 
+In the example function signatures below, `r` is a `praw.Reddit` session object.
 **All other arguments must be passed as keyword arguments.**
 
 Each function will return either a list of `praw.models.Comment` objects or
@@ -46,7 +55,7 @@ psraw.comment_fetch(r, author='', subreddit='', limit=0, sort='asc', after=0, be
 ```
 
 ```python
-psraw.submission_search(r, q='', subreddit='', limit=0, sort='asc', after=0)
+psraw.submission_search(r, q='', subreddit='', limit=0, sort='asc', after=0, before=0)
 ```
 
 ```python
diff --git a/psraw/__init__.py b/psraw/__init__.py
index ea45af7..fb8182a 100644
--- a/psraw/__init__.py
+++ b/psraw/__init__.py
@@ -1,2 +1,2 @@
 from .version import __version__
-from .base import comment_search, comment_fetch, submission_search, submission_activity
+from .base import comment_fetch, comment_search, submission_activity, submission_search
diff --git a/psraw/base.py b/psraw/base.py
index 533bbab..d953319 100644
--- a/psraw/base.py
+++ b/psraw/base.py
@@ -1,90 +1,80 @@
 import requests
 import praw
-
 try:
     from urllib import urlencode
 except ImportError:
     from urllib.parse import urlencode
+from .endpoints import ENDPOINTS, BASE_ADDRESS, LIMIT_MAX, LIMIT_DEFAULT
+
+
+def limit_chunk(limit, limit_max):
+    """Return a list of limits given a maximum that can be requested per API
+    request
+
+    :param limit: The total number of items requested
+    :param limit_max: The maximum number of items that can be requested at once
+    """
+    limits = []
+    x = 0
+
+    while x < limit:
+        limits.append(min(limit_max, limit - x))
+        x += limit_max
+
+    return limits
 
 
-def sort_type(value):
-    directions = ['asc', 'desc']
-
-    if value in directions:
-        return value
-    else:
-        raise ValueError('Value must be one of: {}'.format(directions))
-
-
-base_address = 'https://apiv2.pushshift.io/reddit'
-endpoints = {
-    'comment_search': {
-        'params': {
-            'q': str,
-            'subreddit': str,
-            'limit': int,
-            'sort': sort_type,
-            'after': int,
-            'before': int
-        },
-        'return_type': praw.models.Comment,
-        'url': '/search/comment/'
-    },
-    'comment_fetch': {
-        'params': {
-            'author': str,
-            'after': int,
-            'before': int,
-            'limit': int,
-            'subreddit': str,
-            'sort': sort_type
-        },
-        'return_type': praw.models.Comment,
-        'url': '/comment/fetch/'
-    },
-    'submission_search': {
-        'params': {
-            'q': str,
-            'subreddit': str,
-            'limit': int,
-            'sort': sort_type,
-            'after': int
-        },
-        'return_type': praw.models.Submission,
-        'url': '/search/submission/'
-    },
-    'submission_activity': {
-        'params': {
-            'limit': int,
-            'before': int,
-            'after': int
-        },
-        'return_type': praw.models.Submission,
-        'url': '/submission/activity/'
-    }
-}
+def coerce_kwarg_types(kwargs, param_types):
+    """Return a dict with its values converted to types specified in param_types
+
+    :param kwargs: The dict of parameters passed to the endpoint function
+    :param param_types: The dict of all valid parameters and their types (taken
+        from the 'params' key of the endpoint config)
+    """
+    try:
+        return {key: param_types[key](value) for key, value in list(kwargs.items())}
+    except KeyError as e:
+        raise ValueError('{} parameter is not accepted'.format(e.args[0]))
 
 
 def create_endpoint_function(name, config):
+    """Dynamically create a function that handles a single API endpoint
+
+    :param name: The name of the endpoint, which will also become the name of
+        the function
+    :param config: The configuration of the API endpoint
+    """
     def endpoint_func(r, **kwargs):
-        coerced_kwargs = {}
+        """Placeholder that becomes an endpoint handler through closure
+
+        :param r: A reddit session object that is passed to instantiated
+            Comment or Submission objects
+        :param **kwargs: Query parameters passed to the API endpoint
+        """
+        coerced_kwargs = coerce_kwarg_types(kwargs, config['params'])
+        direction = 'before'
+
+        if 'limit' not in coerced_kwargs:
+            coerced_kwargs['limit'] = LIMIT_DEFAULT
+        if coerced_kwargs.get('sort', None) == 'asc':
+            direction = 'after'
+
+        for limit in limit_chunk(coerced_kwargs['limit'], LIMIT_MAX):
+            coerced_kwargs['limit'] = limit
+            url = '{}{}?{}'.format(BASE_ADDRESS, config['url'], urlencode(coerced_kwargs))
+            data = requests.get(url).json()['data']
+
+            for item in data:
+                yield config['return_type'](r, _data=item)
 
-        for key, value in list(kwargs.items()):
-            try:
-                coerced_kwargs[key] = config['params'][key](value)
-            except KeyError:
-                raise ValueError(
-                    '{} parameter is not accepted by {} endpoint'.
-                    format(key, name)
-                )
+            if len(data) < limit:
+                return  # raising StopIteration here is a RuntimeError under PEP 479
 
-        query_params = '?{}'.format(urlencode(coerced_kwargs))
-        resp = requests.get('{}{}{}'.format(base_address, config['url'], query_params))
-        return [config['return_type'](r, _data=x) for x in resp.json()['data']]
+            coerced_kwargs[direction] = data[-1]['created_utc']
 
     endpoint_func.__name__ = name
     return endpoint_func
 
 
-for name, config in list(endpoints.items()):
+for name, config in list(ENDPOINTS.items()):
     globals()[name] = create_endpoint_function(name, config)
diff --git a/psraw/endpoints.py b/psraw/endpoints.py
new file mode 100644
index 0000000..28e5d8f
--- /dev/null
+++ b/psraw/endpoints.py
@@ -0,0 +1,63 @@
+import praw
+
+
+def sort_type(value):
+    """Ensures values passed in are one of the valid set"""
+    directions = {'asc', 'desc'}
+
+    if value in directions:
+        return value
+    else:
+        raise ValueError('Value must be one of: {}'.format(directions))
+
+
+LIMIT_MAX = 500
+LIMIT_DEFAULT = 50
+BASE_ADDRESS = 'https://apiv2.pushshift.io/reddit'
+ENDPOINTS = {
+    'comment_fetch': {
+        'params': {
+            'after': int,
+            'author': str,
+            'before': int,
+            'limit': int,
+            'sort': sort_type,
+            'subreddit': str
+        },
+        'return_type': praw.models.Comment,
+        'url': '/comment/fetch/'
+    },
+    'comment_search': {
+        'params': {
+            'after': int,
+            'before': int,
+            'limit': int,
+            'q': str,
+            'sort': sort_type,
+            'subreddit': str
+        },
+        'return_type': praw.models.Comment,
+        'url': '/search/comment/'
+    },
+    'submission_activity': {
+        'params': {
+            'after': int,
+            'before': int,
+            'limit': int
+        },
+        'return_type': praw.models.Submission,
+        'url': '/submission/activity/'
+    },
+    'submission_search': {
+        'params': {
+            'after': int,
+            'before': int,
+            'limit': int,
+            'q': str,
+            'sort': sort_type,
+            'subreddit': str
+        },
+        'return_type': praw.models.Submission,
+        'url': '/search/submission/'
+    }
+}
diff --git a/psraw/version.py b/psraw/version.py
index d18f409..b794fd4 100644
--- a/psraw/version.py
+++ b/psraw/version.py
@@ -1 +1 @@
-__version__ = '0.0.2'
+__version__ = '0.1.0'