diff --git a/nbviewer/app.py b/nbviewer/app.py
index d5b55f28..14cef5d5 100644
--- a/nbviewer/app.py
+++ b/nbviewer/app.py
@@ -18,7 +18,6 @@
 from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
 
 from tornado import web, httpserver, ioloop, log
-from tornado.httpclient import AsyncHTTPClient
 import tornado.options
 from tornado.options import define, options
 
@@ -30,15 +29,10 @@
 from .handlers import init_handlers
 from .cache import DummyAsyncCache, AsyncMultipartMemcache, MockCache, pylibmc
-from .index import NoSearch, ElasticSearch
+from .index import NoSearch
 from .formats import configure_formats
-
 from .providers import default_providers, default_rewrites
-
-try:
-    from .providers.url.client import NBViewerCurlAsyncHTTPClient as HTTPClientClass
-except ImportError:
-    from .providers.url.client import NBViewerSimpleAsyncHTTPClient as HTTPClientClass
+from .providers.url.client import NBViewerAsyncHTTPClient as HTTPClientClass
 from .ratelimit import RateLimiter
 from .log import log_request
 
@@ -93,6 +87,13 @@ class NBViewer(Application):
     gist_handler = Unicode(default_value="nbviewer.providers.gist.handlers.GistHandler", help="The Tornado handler to use for viewing notebooks stored as GitHub Gists").tag(config=True)
     user_gists_handler = Unicode(default_value="nbviewer.providers.gist.handlers.UserGistsHandler", help="The Tornado handler to use for viewing directory containing all of a user's Gists").tag(config=True)
+    client = Any().tag(config=True)
+    @default('client')
+    def _default_client(self):
+        client = HTTPClientClass()
+        client.cache = self.cache
+        return client
+
     index = Any().tag(config=True)
     @default('index')
     def _load_index(self):
@@ -160,15 +161,6 @@ def cache(self):
         return cache
 
-    # for some reason this needs to be a computed property,
-    # and not a traitlets Any(), otherwise nbviewer won't run
-    @cached_property
-    def client(self):
-        AsyncHTTPClient.configure(HTTPClientClass)
-        client = AsyncHTTPClient()
-        client.cache = self.cache
-        return client
-
     @cached_property
     def env(self):
         env = Environment(loader=FileSystemLoader(self.template_paths), autoescape=True)
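Note: `client` changes from a `@cached_property` (which had to configure the global `AsyncHTTPClient` class) to an ordinary configurable traitlet, presumably possible now that `NBViewerAsyncHTTPClient` is a plain wrapper object rather than an `AsyncHTTPClient` subclass. A minimal standalone sketch of the same `@default`-factory pattern — the `Fetcher` class here is illustrative, not nbviewer code:

    from traitlets import Any, default
    from traitlets.config import Configurable

    class Fetcher(Configurable):
        client = Any(help="HTTP client instance").tag(config=True)

        @default('client')
        def _default_client(self):
            # built lazily on first access, and only when no value
            # was supplied via config
            from tornado.curl_httpclient import CurlAsyncHTTPClient
            return CurlAsyncHTTPClient()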
diff --git a/nbviewer/cache.py b/nbviewer/cache.py
index 42940f8b..c92b668f 100644
--- a/nbviewer/cache.py
+++ b/nbviewer/cache.py
@@ -9,9 +9,9 @@
 from time import monotonic
 from concurrent.futures import ThreadPoolExecutor
 
-from tornado.concurrent import Future
+import asyncio
+from asyncio import Future
 
-from tornado import gen
 from tornado.log import app_log
 
 try:
@@ -28,25 +28,25 @@ class MockCache(object):
     def __init__(self, *args, **kwargs):
         pass
 
-    def get(self, key):
+    async def get(self, key):
         f = Future()
         f.set_result(None)
-        return f
+        return await f
 
-    def set(self, key, value, *args, **kwargs):
+    async def set(self, key, value, *args, **kwargs):
         f = Future()
         f.set_result(None)
-        return f
+        return await f
 
-    def add(self, key, value, *args, **kwargs):
+    async def add(self, key, value, *args, **kwargs):
         f = Future()
         f.set_result(True)
-        return f
+        return await f
 
-    def incr(self, key):
+    async def incr(self, key):
         f = Future()
         f.set_result(None)
-        return f
+        return await f
 
 
 class MockCache(object):
     """Dummy Async Cache.
     Just stores things in a dict of fixed size."""
@@ -55,10 +55,10 @@ def __init__(self, limit=10):
         self._cache_order = []
         self.limit = limit
 
-    def get(self, key):
+    async def get(self, key):
         f = Future()
         f.set_result(self._get(key))
-        return f
+        return await f
 
     def _get(self, key):
         value, deadline = self._cache.get(key, (None, None))
@@ -68,7 +68,7 @@ def _get(self, key):
         else:
             return value
 
-    def set(self, key, value, expires=0):
+    async def set(self, key, value, expires=0):
         if key in self._cache and self._cache_order[-1] != key:
             idx = self._cache_order.index(key)
             del self._cache_order[idx]
@@ -87,18 +87,18 @@ def set(self, key, value, expires=0):
         self._cache[key] = (value, deadline)
         f = Future()
         f.set_result(True)
-        return f
+        return await f
 
-    def add(self, key, value, expires=0):
+    async def add(self, key, value, expires=0):
         f = Future()
         if self._get(key) is not None:
             f.set_result(False)
         else:
-            self.set(key, value, expires)
+            await self.set(key, value, expires)
             f.set_result(True)
-        return f
+        return await f
 
-    def incr(self, key):
+    async def incr(self, key):
         f = Future()
         if self._get(key) is not None:
             value, deadline = self._cache[key]
@@ -107,7 +107,7 @@ def incr(self, key):
         else:
             value = None
         f.set_result(value)
-        return f
+        return await f
 
 
 class AsyncMemcache(object):
     """Wrap pylibmc.Client to run in a background thread
@@ -119,8 +119,12 @@ def __init__(self, *args, **kwargs):
         self.mc = pylibmc.Client(*args, **kwargs)
         self.mc_pool = pylibmc.ThreadMappedPool(self.mc)
-
-    def _call_in_thread(self, method_name, *args, **kwargs):
+
+        self.loop = asyncio.get_event_loop()
+
+    async def _call_in_thread(self, method_name, *args, **kwargs):
+        # https://stackoverflow.com/questions/34376814/await-future-from-executor-future-cant-be-used-in-await-expression
+
         key = args[0]
         if 'multi' in method_name:
             key = sorted(key)[0].decode('ascii') + '[%i]' % len(key)
@@ -130,19 +134,19 @@ def f():
             with self.mc_pool.reserve() as mc:
                 meth = getattr(mc, method_name)
                 return meth(*args, **kwargs)
-        return self.pool.submit(f)
+        return await self.loop.run_in_executor(self.pool, f)
 
-    def get(self, *args, **kwargs):
-        return self._call_in_thread('get', *args, **kwargs)
+    async def get(self, *args, **kwargs):
+        return await self._call_in_thread('get', *args, **kwargs)
 
-    def set(self, *args, **kwargs):
-        return self._call_in_thread('set', *args, **kwargs)
+    async def set(self, *args, **kwargs):
+        return await self._call_in_thread('set', *args, **kwargs)
 
-    def add(self, *args, **kwargs):
-        return self._call_in_thread('add', *args, **kwargs)
+    async def add(self, *args, **kwargs):
+        return await self._call_in_thread('add', *args, **kwargs)
 
-    def incr(self, *args, **kwargs):
-        return self._call_in_thread('incr', *args, **kwargs)
+    async def incr(self, *args, **kwargs):
+        return await self._call_in_thread('incr', *args, **kwargs)
 
 
 class AsyncMultipartMemcache(AsyncMemcache):
     """subclass of AsyncMemcache that splits large files into multiple chunks
@@ -154,11 +158,10 @@ def __init__(self, *args, **kwargs):
         self.max_chunks = kwargs.pop('max_chunks', 16)
         super(AsyncMultipartMemcache, self).__init__(*args, **kwargs)
 
-    @gen.coroutine
-    def get(self, key, *args, **kwargs):
+    async def get(self, key, *args, **kwargs):
         keys = [('%s.%i' % (key, idx)).encode()
                 for idx in range(self.max_chunks)]
-        values = yield self._call_in_thread('get_multi', keys, *args, **kwargs)
+        values = await self._call_in_thread('get_multi', keys, *args, **kwargs)
         parts = []
         for key in keys:
             if key not in values:
@@ -171,10 +174,9 @@ def get(self, key, *args, **kwargs):
             except zlib.error as e:
                 app_log.error("zlib decompression of %s failed: %s", key, e)
         else:
-            raise gen.Return(result)
+            return result
 
-    @gen.coroutine
-    def set(self, key, value, *args, **kwargs):
+    async def set(self, key, value, *args, **kwargs):
         chunk_size = self.chunk_size
         compressed = zlib.compress(value)
         offsets = range(0, len(compressed), chunk_size)
@@ -186,5 +188,5 @@ def set(self, key, value, *args, **kwargs):
             values[('%s.%i' % (key, idx)).encode()] = compressed[
                 offset:offset + chunk_size
             ]
-        return self._call_in_thread('set_multi', values, *args, **kwargs)
+        return await self._call_in_thread('set_multi', values, *args, **kwargs)
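Note: the pre-resolved-`Future` dance in `MockCache` and `DummyAsyncCache` (`f = Future(); f.set_result(x); return await f`) is behaviorally identical to returning `x` directly from the coroutine; it survives here only to keep the diff mechanical. The more substantial change is in `AsyncMemcache`: `pool.submit()` returns a `concurrent.futures.Future`, which cannot be awaited directly, so the blocking pylibmc call is now routed through `run_in_executor` (see the Stack Overflow link above). A minimal sketch of that pattern, with `blocking_get` standing in for a pylibmc call:

    import asyncio
    import time
    from concurrent.futures import ThreadPoolExecutor

    pool = ThreadPoolExecutor(max_workers=4)

    def blocking_get(key):
        time.sleep(0.1)  # stands in for a blocking mc.get()
        return 'value-for-%s' % key

    async def cached_get(key):
        loop = asyncio.get_event_loop()
        # run_in_executor wraps the pool's concurrent.futures.Future
        # in an asyncio future that *can* be awaited
        return await loop.run_in_executor(pool, blocking_get, key)

    print(asyncio.get_event_loop().run_until_complete(cached_get('k')))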
But we never use @web.asynchronous, - because we are using gen.coroutine for async. - currently only works if: - result is not written in multiple chunks @@ -512,7 +505,7 @@ def cache_and_finish(self, content=''): log("caching (expiry=%is) %s", expiry, short_url) try: with time_block("cache set %s" % short_url): - yield self.cache.set( + await self.cache.set( self.cache_key, cache_data, int(time.time() + expiry), ) except Exception: @@ -527,16 +520,15 @@ def cached(method): This only handles getting from the cache, not writing to it. Writing to the cache must be handled in the decorated method. """ - @gen.coroutine - def cached_method(self, *args, **kwargs): + async def cached_method(self, *args, **kwargs): uri = self.request.path short_url = self.truncate(uri) if self.get_argument("flush_cache", False): - yield self.rate_limiter.check(self) + await self.rate_limiter.check(self) app_log.info("flushing cache %s", short_url) # call the wrapped method - yield method(self, *args, **kwargs) + await method(self, *args, **kwargs) return pending_future = self.pending.get(uri, None) @@ -544,7 +536,7 @@ def cached_method(self, *args, **kwargs): if pending_future: app_log.info("Waiting for concurrent request at %s", short_url) tic = loop.time() - yield pending_future + await pending_future toc = loop.time() app_log.info("Waited %.3fs for concurrent request at %s", toc-tic, short_url @@ -552,7 +544,7 @@ def cached_method(self, *args, **kwargs): try: with time_block("cache get %s" % short_url): - cached_pickle = yield self.cache.get(self.cache_key) + cached_pickle = await self.cache.get(self.cache_key) if cached_pickle is not None: cached = pickle.loads(cached_pickle) else: @@ -568,11 +560,11 @@ def cached_method(self, *args, **kwargs): self.write(cached['body']) else: app_log.debug("cache miss %s", short_url) - yield self.rate_limiter.check(self) + await self.rate_limiter.check(self) future = self.pending[uri] = Future() try: # call the wrapped method - yield method(self, *args, **kwargs) + await method(self, *args, **kwargs) finally: self.pending.pop(uri, None) # notify waiters @@ -673,8 +665,7 @@ def render_notebook_template(self, body, nb, download_url, json_notebook, **name date=datetime.utcnow().strftime(self.date_fmt), **namespace) - @gen.coroutine - def finish_notebook(self, json_notebook, download_url, msg=None, + async def finish_notebook(self, json_notebook, download_url, msg=None, public=False, **namespace): """Renders a notebook from its JSON body. 
diff --git a/nbviewer/providers/gist/handlers.py b/nbviewer/providers/gist/handlers.py
index ee9a2ef4..14c3a377 100644
--- a/nbviewer/providers/gist/handlers.py
+++ b/nbviewer/providers/gist/handlers.py
@@ -4,7 +4,7 @@
 import os
 import json
 
-from tornado import web, gen
+from tornado import web
 from tornado.log import app_log
 
@@ -76,15 +76,14 @@ def render_usergists_template(self, entries, user, provider_url, prev_url,
                                       **namespace)
 
     @cached
-    @gen.coroutine
-    def get(self, user, **namespace):
+    async def get(self, user, **namespace):
         page = self.get_argument("page", None)
         params = {}
         if page:
             params['page'] = page
 
         with self.catch_client_error():
-            response = yield self.github_client.get_gists(user, params=params)
+            response = await self.github_client.get_gists(user, params=params)
 
         prev_url, next_url = self.get_page_links(response)
 
@@ -107,17 +106,16 @@ def get(self, user, **namespace):
             prev_url=prev_url, next_url=next_url,
             **namespace
         )
-        yield self.cache_and_finish(html)
+        await self.cache_and_finish(html)
 
 
 class GistHandler(GistClientMixin, RenderingHandler):
     """render a gist notebook, or list files if a multifile gist"""
 
-    @gen.coroutine
-    def parse_gist(self, user, gist_id, filename=''):
+    async def parse_gist(self, user, gist_id, filename=''):
         with self.catch_client_error():
-            response = yield self.github_client.get_gist(gist_id)
+            response = await self.github_client.get_gist(gist_id)
 
         gist = json.loads(response_text(response))
 
@@ -145,8 +143,7 @@ def parse_gist(self, user, gist_id, filename=''):
         return user, gist_id, gist, files, many_files_gist
 
     # Analogous to GitHubTreeHandler
-    @gen.coroutine
-    def tree_get(self, user, gist_id, gist, files):
+    async def tree_get(self, user, gist_id, gist, files):
         """
         user, gist_id, gist, and files are (most) of the values returned by parse_gist
         """
@@ -199,32 +196,30 @@ def tree_get(self, user, gist_id, gist, files):
             executor_url=executor_url,
             **self.PROVIDER_CTX
         )
-        yield self.cache_and_finish(html)
+        await self.cache_and_finish(html)
 
     # Analogous to GitHubBlobHandler
-    @gen.coroutine
-    def file_get(self, user, gist_id, filename, gist, many_files_gist, file):
-        content = yield self.get_notebook_data(gist_id, filename, many_files_gist, file)
+    async def file_get(self, user, gist_id, filename, gist, many_files_gist, file):
+        content = await self.get_notebook_data(gist_id, filename, many_files_gist, file)
 
         if not content:
             return
 
-        yield self.deliver_notebook(user, gist_id, filename, gist, file, content)
+        await self.deliver_notebook(user, gist_id, filename, gist, file, content)
 
     # Only called by file_get
-    @gen.coroutine
-    def get_notebook_data(self, gist_id, filename, many_files_gist, file):
+    async def get_notebook_data(self, gist_id, filename, many_files_gist, file):
         """
         gist_id, filename, many_files_gist, file are all passed to file_get
         """
         if (file['type'] or '').startswith('image/'):
             app_log.debug("Fetching raw image (%s) %s/%s: %s", file['type'], gist_id, filename, file['raw_url'])
-            response = yield self.fetch(file['raw_url'])
+            response = await self.fetch(file['raw_url'])
             # use raw bytes for images:
             content = response.body
         elif file['truncated']:
             app_log.debug("Gist %s/%s truncated, fetching %s", gist_id, filename, file['raw_url'])
-            response = yield self.fetch(file['raw_url'])
+            response = await self.fetch(file['raw_url'])
             content = response_text(response, encoding='utf-8')
         else:
             content = file['content']
@@ -239,8 +234,7 @@ def get_notebook_data(self, gist_id, filename, many_files_gist, file):
         return content
 
     # Only called by file_get
-    @gen.coroutine
-    def deliver_notebook(self, user, gist_id, filename, gist, file, content):
+    async def deliver_notebook(self, user, gist_id, filename, gist, file, content):
         """
         user, gist_id, filename, gist, file, are the same values as those
         passed into file_get, whereas content is returned from
@@ -256,7 +250,7 @@ def deliver_notebook(self, user, gist_id, filename, gist, file, content):
         # provider_url: str, optional
         #     URL to the notebook document upstream at the provider (e.g., GitHub)
 
-        yield self.finish_notebook(
+        await self.finish_notebook(
             content,
             file['raw_url'],
             msg="gist: %s" % gist_id,
@@ -266,16 +260,15 @@ def deliver_notebook(self, user, gist_id, filename, gist, file, content):
             **self.PROVIDER_CTX)
 
     @cached
-    @gen.coroutine
-    def get(self, user, gist_id, filename=''):
+    async def get(self, user, gist_id, filename=''):
         """
         Encompasses both the case of a single file gist, handled by
         `file_get`, as well as a many-file gist, handled by `tree_get`.
         """
-        user, gist_id, gist, files, many_files_gist = yield self.parse_gist(user, gist_id, filename)
+        user, gist_id, gist, files, many_files_gist = await self.parse_gist(user, gist_id, filename)
 
         if many_files_gist and not filename:
-            yield self.tree_get(user, gist_id, gist, files)
+            await self.tree_get(user, gist_id, gist, files)
 
         else:
             if not many_files_gist and not filename:
@@ -286,7 +279,7 @@ def get(self, user, gist_id, filename=''):
 
             file = files[filename]
 
-            yield self.file_get(user, gist_id, filename, gist, many_files_gist, file)
+            await self.file_get(user, gist_id, filename, gist, many_files_gist, file)
 
 
 class GistRedirectHandler(BaseHandler):
""" - user, gist_id, gist, files, many_files_gist = yield self.parse_gist(user, gist_id, filename) + user, gist_id, gist, files, many_files_gist = await self.parse_gist(user, gist_id, filename) if many_files_gist and not filename: - yield self.tree_get(user, gist_id, gist, files) + await self.tree_get(user, gist_id, gist, files) else: if not many_files_gist and not filename: @@ -286,7 +279,7 @@ def get(self, user, gist_id, filename=''): file = files[filename] - yield self.file_get(user, gist_id, filename, gist, many_files_gist, file) + await self.file_get(user, gist_id, filename, gist, many_files_gist, file) class GistRedirectHandler(BaseHandler): diff --git a/nbviewer/providers/github/client.py b/nbviewer/providers/github/client.py index 109c49bd..5aa4966b 100644 --- a/nbviewer/providers/github/client.py +++ b/nbviewer/providers/github/client.py @@ -10,7 +10,6 @@ from urllib.parse import urlparse -from tornado.concurrent import Future from tornado.httpclient import AsyncHTTPClient, HTTPError from tornado.httputil import url_concat from tornado.log import app_log @@ -128,13 +127,17 @@ def get_gists(self, user, callback=None, **kwargs): path = u"users/{user}/gists".format(user=user) return self.github_api_request(path, callback, **kwargs) - def get_tree(self, user, repo, ref='master', recursive=False, callback=None, **kwargs): + def get_tree(self, user, repo, path, ref='master', recursive=False, callback=None, **kwargs): """Get a git tree""" + # only need a recursive fetch if it's not in the top-level dir + if '/' in path: + recursive = True path = u"repos/{user}/{repo}/git/trees/{ref}".format(**locals()) if recursive: params = kwargs.setdefault('params', {}) params['recursive'] = True - return self.github_api_request(path, callback, **kwargs) + tree = self.github_api_request(path, callback, **kwargs) + return tree def get_branches(self, user, repo, callback=None, **kwargs): """List a repo's branches""" @@ -146,13 +149,16 @@ def get_tags(self, user, repo, callback=None, **kwargs): path = u"repos/{user}/{repo}/tags".format(user=user, repo=repo) return self.github_api_request(path, callback, **kwargs) - def _extract_tree_entry(self, path, tree_response): - """extract a single tree entry from a file list + def extract_tree_entry(self, path, tree_response): + """extract a single tree entry from + a tree response using for a path - For use as a callback in get_tree_entry raises 404 if not found + + Useful for finding the blob url for a given path. """ tree_response.rethrow() + app_log.info(tree_response) jsondata = response_text(tree_response) data = json.loads(jsondata) for entry in data['tree']: @@ -160,29 +166,3 @@ def _extract_tree_entry(self, path, tree_response): return entry raise HTTPError(404, "%s not found among %i files" % (path, len(data['tree']))) - - def get_tree_entry(self, user, repo, path, ref='master', callback=None, **kwargs): - """Get a single tree entry for a path - - Useful for finding the blob url for a given path. 
- """ - # only need a recursive fetch if it's not in the top-level dir - if '/' in path: - kwargs['recursive'] = True - - f = Future() - def cb(response): - try: - tree_entry = self._extract_tree_entry(path, response) - except Exception as e: - f.set_exception(e) - return - if callback: - result = callback(tree_entry) - else: - result = tree_entry - f.set_result(result) - - self.get_tree(user, repo, ref=ref, callback=cb, **kwargs) - return f - diff --git a/nbviewer/providers/github/handlers.py b/nbviewer/providers/github/handlers.py index cc04f71f..6dc1083b 100644 --- a/nbviewer/providers/github/handlers.py +++ b/nbviewer/providers/github/handlers.py @@ -9,10 +9,10 @@ import json import mimetypes import re +import asyncio from tornado import ( web, - gen, ) from tornado.log import app_log from tornado.escape import url_unescape @@ -110,14 +110,13 @@ def get(self, url): class GitHubUserHandler(GithubClientMixin, BaseHandler): """list a user's github repos""" @cached - @gen.coroutine - def get(self, user): + async def get(self, user): page = self.get_argument("page", None) params = {'sort' : 'updated'} if page: params['page'] = page with self.catch_client_error(): - response = yield self.github_client.get_repos(user, params=params) + response = await self.github_client.get_repos(user, params=params) prev_url, next_url = self.get_page_links(response) repos = json.loads(response_text(response)) @@ -135,7 +134,7 @@ def get(self, user): next_url=next_url, prev_url=prev_url, **self.PROVIDER_CTX ) - yield self.cache_and_finish(html) + await self.cache_and_finish(html) class GitHubRepoHandler(BaseHandler): @@ -165,18 +164,17 @@ def render_treelist_template(self, entries, breadcrumbs, provider_url, user, rep **self.PROVIDER_CTX, **namespace) @cached - @gen.coroutine - def get(self, user, repo, ref, path, **namespace): + async def get(self, user, repo, ref, path, **namespace): if not self.request.uri.endswith('/'): self.redirect(self.request.uri + '/') return path = path.rstrip('/') with self.catch_client_error(): - response = yield self.github_client.get_contents(user, repo, path, ref=ref) + response = await self.github_client.get_contents(user, repo, path, ref=ref) contents = json.loads(response_text(response)) - branches, tags = yield self.refs(user, repo) + branches, tags = await self.refs(user, repo) for nav_ref in branches + tags: nav_ref["url"] = (u"/github/{user}/{repo}/tree/{ref}/{path}" @@ -264,20 +262,19 @@ def get(self, user, repo, ref, path, **namespace): user=user, repo=repo, ref=ref, path=path, branches=branches, tags=tags, executor_url=executor_url, **namespace ) - yield self.cache_and_finish(html) + await self.cache_and_finish(html) - @gen.coroutine - def refs(self, user, repo): + async def refs(self, user, repo): """get branches and tags for this user/repo""" ref_types = ("branches", "tags") ref_data = [None, None] for i, ref_type in enumerate(ref_types): with self.catch_client_error(): - response = yield getattr(self.github_client, "get_%s" % ref_type)(user, repo) + response = await getattr(self.github_client, "get_%s" % ref_type)(user, repo) ref_data[i] = json.loads(response_text(response)) - raise gen.Return(ref_data) + return ref_data class GitHubBlobHandler(GithubClientMixin, RenderingHandler): @@ -289,8 +286,7 @@ class GitHubBlobHandler(GithubClientMixin, RenderingHandler): - non-notebook file, serve file unmodified - directory, redirect to tree """ - @gen.coroutine - def get_notebook_data(self, user, repo, ref, path): + async def get_notebook_data(self, user, repo, ref, 
path): if os.environ.get('GITHUB_API_URL', '') == '': raw_url = u"https://raw.githubusercontent.com/{user}/{repo}/{ref}/{path}".format( user=user, repo=repo, ref=ref, path=quote(path) @@ -303,9 +299,10 @@ def get_notebook_data(self, user, repo, ref, path): user=user, repo=repo, ref=ref, path=quote(path), github_url=self.github_url ) with self.catch_client_error(): - tree_entry = yield self.github_client.get_tree_entry( + tree = await self.github_client.get_tree( user, repo, path=url_unescape(path), ref=ref ) + tree_entry = self.github_client.extract_tree_entry(path=url_unescape(path), tree_response=tree) if tree_entry['type'] == 'tree': tree_url = "/github/{user}/{repo}/tree/{ref}/{path}/".format( @@ -317,11 +314,10 @@ def get_notebook_data(self, user, repo, ref, path): return raw_url, blob_url, tree_entry - @gen.coroutine - def deliver_notebook(self, user, repo, ref, path, raw_url, blob_url, tree_entry): + async def deliver_notebook(self, user, repo, ref, path, raw_url, blob_url, tree_entry): # fetch file data from the blobs API with self.catch_client_error(): - response = yield self.github_client.fetch(tree_entry['url']) + response = await self.github_client.fetch(tree_entry['url']) data = json.loads(response_text(response)) contents = data['content'] @@ -361,6 +357,7 @@ def deliver_notebook(self, user, repo, ref, path, raw_url, blob_url, tree_entry) except Exception as e: app_log.error("Failed to decode notebook: %s", raw_url, exc_info=True) raise web.HTTPError(400) + # Explanation of some kwargs passed into `finish_notebook`: # provider_url: # URL to the notebook document upstream at the provider (e.g., GitHub) @@ -368,7 +365,7 @@ def deliver_notebook(self, user, repo, ref, path, raw_url, blob_url, tree_entry) # Breadcrumb 'name' and 'url' to render as links at the top of the notebook page # executor_url: str, optional # URL to execute the notebook document (e.g., Binder) - yield self.finish_notebook(nbjson, raw_url, + await self.finish_notebook(nbjson, raw_url, provider_url=blob_url, executor_url=executor_url, breadcrumbs=breadcrumbs, @@ -379,14 +376,13 @@ def deliver_notebook(self, user, repo, ref, path, raw_url, blob_url, tree_entry) else: mime, enc = mimetypes.guess_type(path) self.set_header("Content-Type", mime or 'text/plain') - self.cache_and_finish(filedata) + await self.cache_and_finish(filedata) @cached - @gen.coroutine - def get(self, user, repo, ref, path): - raw_url, blob_url, tree_entry = yield self.get_notebook_data(user, repo, ref, path) + async def get(self, user, repo, ref, path): + raw_url, blob_url, tree_entry = await self.get_notebook_data(user, repo, ref, path) - yield self.deliver_notebook(user, repo, ref, path, raw_url, blob_url, tree_entry) + await self.deliver_notebook(user, repo, ref, path, raw_url, blob_url, tree_entry) def default_handlers(handlers=[], **handler_names): """Tornado handlers""" diff --git a/nbviewer/providers/github/tests/test_client.py b/nbviewer/providers/github/tests/test_client.py index 2f303b92..702494dd 100644 --- a/nbviewer/providers/github/tests/test_client.py +++ b/nbviewer/providers/github/tests/test_client.py @@ -1,9 +1,9 @@ # encoding: utf-8 -import mock +import unittest.mock as mock from tornado.httpclient import AsyncHTTPClient -from tornado.testing import AsyncTestCase +from tornado.testing import AsyncTestCase, gen_test from ..client import AsyncGitHubClient from ....utils import quote @@ -81,11 +81,11 @@ def test_get_tags(self): correct_url = 'https://api.github.com/repos/username/my_awesome_repo/tags' 
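Note: the test module below now imports `gen_test`, which is how `AsyncTestCase` runs coroutine tests under tornado 6. The hunks shown here don't use it yet, but an async test would look like this — an illustrative sketch, not part of this diff:

    from tornado.testing import AsyncTestCase, gen_test

    class ExampleAsyncTest(AsyncTestCase):
        @gen_test
        async def test_round_trip(self):
            # native coroutines can be awaited inside gen_test methods
            result = await self.async_answer()
            self.assertEqual(result, 42)

        async def async_answer(self):
            return 42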
diff --git a/nbviewer/providers/github/tests/test_client.py b/nbviewer/providers/github/tests/test_client.py
index 2f303b92..702494dd 100644
--- a/nbviewer/providers/github/tests/test_client.py
+++ b/nbviewer/providers/github/tests/test_client.py
@@ -1,9 +1,9 @@
 # encoding: utf-8
 
-import mock
+import unittest.mock as mock
 
 from tornado.httpclient import AsyncHTTPClient
-from tornado.testing import AsyncTestCase
+from tornado.testing import AsyncTestCase, gen_test
 
 from ..client import AsyncGitHubClient
 from ....utils import quote
@@ -81,11 +81,11 @@ def test_get_tags(self):
         correct_url = 'https://api.github.com/repos/username/my_awesome_repo/tags'
         self.assertStartsWith(url, correct_url)
 
-    def test_get_tree_entry(self):
+    def test_get_tree(self):
         user = 'username'
         repo = 'my_awesome_repo'
         path = 'extra-path'
-        self.gh_client.get_tree_entry(user, repo, path)
+        self.gh_client.get_tree(user, repo, path)
         url = self._get_url()
         correct_url = 'https://api.github.com/repos/username/my_awesome_repo/git/trees/master'
         self.assertStartsWith(url, correct_url)
diff --git a/nbviewer/providers/local/handlers.py b/nbviewer/providers/local/handlers.py
index c68aec6b..6d05645b 100644
--- a/nbviewer/providers/local/handlers.py
+++ b/nbviewer/providers/local/handlers.py
@@ -12,7 +12,6 @@
 import stat
 
 from tornado import (
-    gen,
     web,
     iostream,
 )
@@ -65,8 +64,7 @@ def breadcrumbs(self, path):
         breadcrumbs.extend(super(LocalFileHandler, self).breadcrumbs(path, self._localfile_path))
         return breadcrumbs
 
-    @gen.coroutine
-    def download(self, fullpath):
+    async def download(self, fullpath):
         """Download the file at the given absolute path.
 
         Parameters
@@ -88,7 +86,7 @@ def download(self, fullpath):
         for chunk in content:
             try:
                 self.write(chunk)
-                yield self.flush()
+                await self.flush()
             except iostream.StreamClosedError:
                 return
 
@@ -138,7 +136,7 @@ def can_show(self, path):
 
         return True
 
-    def get_notebook_data(self, path):
+    async def get_notebook_data(self, path):
         fullpath = os.path.join(self.localfile_path, path)
 
         if not self.can_show(fullpath):
@@ -147,17 +145,17 @@ def get_notebook_data(self, path):
 
         if os.path.isdir(fullpath):
             html = self.show_dir(fullpath, path)
-            raise gen.Return(self.cache_and_finish(html))
+            await self.cache_and_finish(html)
+            return
 
         is_download = self.get_query_arguments('download')
         if is_download:
-            self.download(fullpath)
+            await self.download(fullpath)
             return
-
+
         return fullpath
 
-    @gen.coroutine
-    def deliver_notebook(self, fullpath, path):
+    async def deliver_notebook(self, fullpath, path):
         try:
             with io.open(fullpath, encoding='utf-8') as f:
                 nbdata = f.read()
@@ -173,7 +171,7 @@ def deliver_notebook(self, fullpath, path):
         #     Breadcrumb 'name' and 'url' to render as links at the top of the notebook page
         # title: str
         #     Title to use as the HTML page title (i.e., text on the browser tab)
-        yield self.finish_notebook(nbdata,
+        await self.finish_notebook(nbdata,
                                    download_url='?download',
                                    msg="file from localfile: %s" % path,
                                    public=False,
@@ -181,8 +179,7 @@ def deliver_notebook(self, fullpath, path):
                                    title=os.path.basename(path))
 
     @cached
-    @gen.coroutine
-    def get(self, path):
+    async def get(self, path):
         """Get a directory listing, rendered notebook, or raw file
         at the given path based on the type and URL query parameters.
 
@@ -196,9 +193,12 @@ def get(self, path):
         path: str
             Local filesystem path
         """
-        fullpath = self.get_notebook_data(path)
+        fullpath = await self.get_notebook_data(path)
 
-        yield self.deliver_notebook(fullpath, path)
+        # get_notebook_data returns None when it has already handled the request
+        # itself (directory listing or raw download), leaving nothing to render
+        if fullpath:
+            await self.deliver_notebook(fullpath, path)
 
     # Make available to increase modularity for subclassing
     # E.g. so subclasses can implement templates with custom logic
@@ -237,7 +237,7 @@ def show_dir(self, fullpath, path, **namespace):
             contents = os.listdir(fullpath)
         except IOError as ex:
             if ex.errno == errno.EACCES:
-                # py2/3: can't access the dir, so don't give away its presence
+                # can't access the dir, so don't give away its presence
                 app_log.info("contents of path: '%s' cannot be listed from within nbviewer", fullpath)
                 raise web.HTTPError(404)
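Note: `download()` above streams the file chunk by chunk, awaiting `flush()` so the event loop can drain the socket between writes; a `StreamClosedError` simply means the client disconnected. The same pattern in isolation — a sketch, where `iter_chunks` is a made-up stand-in for reading file contents:

    from tornado import iostream, web

    class DownloadHandler(web.RequestHandler):
        async def get(self):
            # stream a large payload without holding it all in memory
            for chunk in self.iter_chunks():
                try:
                    self.write(chunk)
                    await self.flush()  # yield to the event loop per chunk
                except iostream.StreamClosedError:
                    return  # client went away; stop writing

        def iter_chunks(self, size=64 * 1024):
            data = b'x' * (4 * size)  # stand-in for file contents
            for i in range(0, len(data), size):
                yield data[i:i + size]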
+ """ + tic = time.time() + # when logging, use the URL without params name = request.url.split('?')[0] - cached_response = None app_log.debug("Fetching %s", name) + + # look for a cached response + cached_response = None cache_key = hashlib.sha256(request.url.encode('utf8')).hexdigest() - with time_block("Upstream cache get %s" % name): - cached_response = yield self._get_cached_response(cache_key, name) - + cached_response = await self._get_cached_response(cache_key, name) + toc = time.time() + app_log.info("Upstream cache get %s %.2f ms", name, 1e3 * (toc-tic)) + if cached_response: app_log.info("Upstream cache hit %s", name) # add cache headers, if any @@ -74,58 +91,29 @@ def _fetch_impl(self, request, callback): value = cached_response.headers.get(resp_key) if value: request.headers[req_key] = value + return cached_response else: - app_log.debug("Upstream cache miss %s", name) - - response = yield gen.Task(super(NBViewerAsyncHTTPClient, self).fetch_impl, request) - dt = time.time() - tic - if cached_response and (response.code == 304 or response.code >= 400): - log = app_log.info if response.code == 304 else app_log.warning - log("Upstream %s on %s in %.2f ms, using cached response", - response.code, name, 1e3 * dt) - response = self._update_cached_response(response, cached_response) - callback(response) - else: - if not response.error: - app_log.info("Fetched %s in %.2f ms", name, 1e3 * dt) - callback(response) - if not response.error: - yield self._cache_response(cache_key, name, response) - - def _update_cached_response(self, three_o_four, cached_response): - """Apply any changes to the cached response from the 304 + app_log.info("Upstream cache miss %s", name) - Return the HTTPResponse to be used. + response = await self.client.fetch(request, callback) + dt = time.time() - tic + app_log.info("Fetched %s in %.2f ms", name, 1e3 * dt) + await self._cache_response(cache_key, name, response) + return response - Currently this hardcodes more recent GitHub rate limit headers, - and that's it. - Is there a better way for this to be in the right place? - - """ - # Copy GitHub rate-limit headers from 304 to the cached response - # So we don't log stale rate limits. 
diff --git a/nbviewer/providers/url/handlers.py b/nbviewer/providers/url/handlers.py
index d80f9cb6..b52a7eb7 100644
--- a/nbviewer/providers/url/handlers.py
+++ b/nbviewer/providers/url/handlers.py
@@ -9,7 +9,6 @@
 from urllib import robotparser
 
 from tornado import (
-    gen,
     httpclient,
     web,
 )
@@ -30,8 +29,7 @@ class URLHandler(RenderingHandler):
     """Renderer for /url or /urls"""
 
-    @gen.coroutine
-    def get_notebook_data(self, secure, netloc, url):
+    async def get_notebook_data(self, secure, netloc, url):
         proto = 'http' + secure
         netloc = url_unescape(netloc)
 
@@ -60,7 +58,7 @@ def get_notebook_data(self, secure, netloc, url):
         public = False  # Assume non-public
 
         try:
-            robots_response = yield self.fetch(robots_url)
+            robots_response = await self.fetch(robots_url)
             robotstxt = response_text(robots_response)
             rfp = robotparser.RobotFileParser()
             rfp.set_url(robots_url)
@@ -75,9 +73,8 @@ def get_notebook_data(self, secure, netloc, url):
 
         return remote_url, public
 
-    @gen.coroutine
-    def deliver_notebook(self, remote_url, public):
-        response = yield self.fetch(remote_url)
+    async def deliver_notebook(self, remote_url, public):
+        response = await self.fetch(remote_url)
 
         try:
             nbjson = response_text(response, encoding='utf-8')
@@ -85,17 +82,16 @@ def deliver_notebook(self, remote_url, public):
             app_log.error("Notebook is not utf8: %s", remote_url, exc_info=True)
             raise web.HTTPError(400)
 
-        yield self.finish_notebook(nbjson, download_url=remote_url,
+        await self.finish_notebook(nbjson, download_url=remote_url,
                                    msg="file from url: %s" % remote_url,
                                    public=public,
                                    request=self.request)
 
     @cached
-    @gen.coroutine
-    def get(self, secure, netloc, url):
-        remote_url, public = yield self.get_notebook_data(secure, netloc, url)
+    async def get(self, secure, netloc, url):
+        remote_url, public = await self.get_notebook_data(secure, netloc, url)
 
-        yield self.deliver_notebook(remote_url, public)
+        await self.deliver_notebook(remote_url, public)
 
 
 def default_handlers(handlers=[], **handler_names):
     """Tornado handlers"""
diff --git a/nbviewer/ratelimit.py b/nbviewer/ratelimit.py
index 7bda97a0..f6f9faf3 100644
--- a/nbviewer/ratelimit.py
+++ b/nbviewer/ratelimit.py
@@ -5,7 +5,6 @@
 
 import hashlib
 
-from tornado.gen import coroutine
 from tornado.log import app_log
 from tornado.web import HTTPError
 
@@ -28,8 +27,7 @@ def key_for_handler(self, handler):
             hashlib.md5(agent.encode('utf8', 'replace')).hexdigest(),
         )
 
-    @coroutine
-    def check(self, handler):
+    async def check(self, handler):
         """Check the rate limit for a handler.
 
         Identifies the source by ip and user-agent.
@@ -39,11 +37,11 @@ def check(self, handler):
         if not self.limit:
             return
         key = self.key_for_handler(handler)
-        added = yield self.cache.add(key, 1, self.interval)
+        added = await self.cache.add(key, 1, self.interval)
         if not added:
             # it's been seen before, use incr
             try:
-                count = yield self.cache.incr(key)
+                count = await self.cache.incr(key)
             except Exception as e:
                 app_log.warning("Failed to increment rate limit for %s", key)
                 return
diff --git a/nbviewer/render.py b/nbviewer/render.py
index f1531af3..8c53d3dc 100644
--- a/nbviewer/render.py
+++ b/nbviewer/render.py
@@ -5,8 +5,6 @@
 # the file COPYING, distributed as part of this software.
 #-----------------------------------------------------------------------------
 
-import re
-
 from tornado.log import app_log
 
 from nbconvert.exporters import Exporter
diff --git a/requirements.txt b/requirements.txt
index 4df7ce66..87ff476b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,5 +7,5 @@ nbconvert>=5.4
 ipython
 pycurl
 pylibmc
-tornado<6.0
+tornado>=6.0
 statsd
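Note: the flip from `tornado<6.0` to `tornado>=6.0` in requirements.txt is what drives every hunk above. Tornado 6 dropped Python 2, removed `@web.asynchronous`, and removed the callback-style client APIs, so `@gen.coroutine`/`yield` code is rewritten as native coroutines throughout. The mechanical translation applied across this diff, shown side by side:

    from tornado import gen

    # tornado < 6 style (removed throughout this diff)
    @gen.coroutine
    def fetch_body(client, url):
        response = yield client.fetch(url)
        raise gen.Return(response.body)

    # tornado >= 6 style (introduced throughout this diff)
    async def fetch_body_native(client, url):
        response = await client.fetch(url)
        return response.body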