Migration updates and bug fixes from 3.x #77

Merged
merged 1 commit on Dec 7, 2021
32 changes: 32 additions & 0 deletions jupyter_archive/__init__.py
@@ -1,4 +1,9 @@
import os

from notebook.utils import url_path_join
from traitlets.config import Configurable
from traitlets import Int, default

from .handlers import DownloadArchiveHandler
from .handlers import ExtractArchiveHandler

@@ -8,7 +13,34 @@ def _jupyter_server_extension_paths():
return [{"module": "jupyter_archive"}]


class JupyterArchive(Configurable):
stream_max_buffer_size = Int(help="The max size of tornado IOStream buffer",
config=True)

@default("stream_max_buffer_size")
def _default_stream_max_buffer_size(self):
# 100 * 1024 * 1024 bytes = 100 MB
return int(os.environ.get("JA_IOSTREAM_MAX_BUFFER_SIZE", 100 * 1024 * 1024))

handler_max_buffer_length = Int(help="The max length of chunks in tornado RequestHandler",
config=True)

@default("handler_max_buffer_length")
def _default_handler_max_buffer_length(self):
# at 8 KB per chunk, 10240 chunks = 80 MB
return int(os.environ.get("JA_HANDLER_MAX_BUFFER_LENGTH", 10240))

archive_download_flush_delay = Int(help="The delay in ms at which we send the chunk of data to the client.",
config=True)

@default("archive_download_flush_delay")
def _default_archive_download_flush_delay(self):
return int(os.environ.get("JA_ARCHIVE_DOWNLOAD_FLUSH_DELAY", 100))


def load_jupyter_server_extension(nbapp):
config = JupyterArchive(config=nbapp.config)
nbapp.web_app.settings["jupyter_archive"] = config

# Add download handler.
base_url = url_path_join(nbapp.web_app.settings["base_url"], r"/directories/(.*)")
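For reference, these three limits can be tuned without touching code. A minimal sketch, assuming the traitlets route via a standard jupyter_notebook_config.py (the trait names come from the class above; the values are illustrative, not recommendations):

# jupyter_notebook_config.py -- illustrative values only.
# The same knobs are also read from the JA_* environment variables above.
c.JupyterArchive.stream_max_buffer_size = 200 * 1024 * 1024  # bytes (200 MB)
c.JupyterArchive.handler_max_buffer_length = 20480  # chunks awaiting flush
c.JupyterArchive.archive_download_flush_delay = 50  # ms between flushes
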
58 changes: 42 additions & 16 deletions jupyter_archive/handlers.py
@@ -1,14 +1,15 @@
import os
import asyncio
import time
import zipfile
import tarfile
import pathlib
from urllib.parse import quote

from tornado import gen, web, iostream, ioloop
from notebook.base.handlers import IPythonHandler
from notebook.utils import url2path


# The delay in ms at which we send the chunk of data
# to the client.
ARCHIVE_DOWNLOAD_FLUSH_DELAY = 100
@@ -31,6 +32,18 @@ def __init__(self, handler):
self.position = 0

def write(self, data):
if self.handler.canceled:
raise ValueError("File download canceled")
# give up after 600 s in this while loop
time_out_cnt = 600 * 1000 / self.handler.archive_download_flush_delay
while len(self.handler._write_buffer) > self.handler.handler_max_buffer_length:
# the handler's write buffer is too long; wait for a flush cycle
time.sleep(self.handler.archive_download_flush_delay / 1000)
if self.handler.canceled:
raise ValueError("File download canceled")
time_out_cnt -= 1
if time_out_cnt <= 0:
raise ValueError("Time out for writing into tornado buffer")
self.position += len(data)
self.handler.write(data)
del data
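
The loop above gives up after roughly ten minutes regardless of the configured delay; a quick check of the arithmetic with the default values:

# With the default archive_download_flush_delay of 100 ms:
flush_delay_ms = 100
iterations = 600 * 1000 / flush_delay_ms  # 6000 sleep cycles allowed
total_wait_s = iterations * flush_delay_ms / 1000  # 600 s, i.e. 10 minutes
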
@@ -63,23 +76,41 @@ def make_writer(handler, archive_format="zip"):


def make_reader(archive_path):
archive_format = "".join(archive_path.suffixes)

archive_format = "".join(archive_path.suffixes)[1:]

if archive_format == "zip":
if archive_format.endswith(".zip"):
archive_file = zipfile.ZipFile(archive_path, mode="r")
elif archive_format in ["tgz", "tar.gz"]:
elif any([archive_format.endswith(ext) for ext in [".tgz", ".tar.gz"]]):
archive_file = tarfile.open(archive_path, mode="r|gz")
elif archive_format in ["tbz", "tbz2", "tar.bz", "tar.bz2"]:
elif any([archive_format.endswith(ext) for ext in [".tbz", ".tbz2", ".tar.bz", ".tar.bz2"]]):
archive_file = tarfile.open(archive_path, mode="r|bz2")
elif archive_format in ["txz", "tar.xz"]:
elif any([archive_format.endswith(ext) for ext in [".txz", ".tar.xz"]]):
archive_file = tarfile.open(archive_path, mode="r|xz")
else:
raise ValueError("'{}' is not a valid archive format.".format(archive_format))
return archive_file
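
For intuition on why the membership tests were replaced with endswith checks: pathlib's suffixes collects every dot-separated component of the file name, so dotted stems leak extra prefixes into the joined format string. A small sketch with hypothetical file names:

import pathlib

print(pathlib.Path("data.tar.gz").suffixes)  # ['.tar', '.gz']
print(pathlib.Path("v1.2-export.tar.gz").suffixes)  # ['.2-export', '.tar', '.gz']

joined = "".join(pathlib.Path("v1.2-export.tar.gz").suffixes)
print(joined)  # '.2-export.tar.gz'
print(joined.endswith(".tar.gz"))  # True: endswith still recognizes the format
print(joined[1:] in ["tgz", "tar.gz"])  # False: the old equality test did not
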


class DownloadArchiveHandler(IPythonHandler):
@property
def stream_max_buffer_size(self):
return self.settings["jupyter_archive"].stream_max_buffer_size

@property
def handler_max_buffer_length(self):
return self.settings["jupyter_archive"].handler_max_buffer_length

@property
def archive_download_flush_delay(self):
return self.settings["jupyter_archive"].archive_download_flush_delay

def flush(self, include_footers=False):
# skip flush when stream_buffer is larger than stream_max_buffer_size
stream_buffer = self.request.connection.stream._write_buffer
if stream_buffer and len(stream_buffer) > self.stream_max_buffer_size:
return
return super(DownloadArchiveHandler, self).flush(include_footers)
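
Taken together, ArchiveStream.write() and this flush() override form a two-level back-pressure scheme: the archiving thread blocks while the handler holds too many unflushed chunks, and a flush is skipped while the socket-level IOStream buffer is still draining. A toy model of the idea (plain Python, not the tornado internals themselves):

import queue
import threading
import time

buf = queue.Queue(maxsize=4)  # stands in for handler_max_buffer_length

def archiver():
    for i in range(10):
        buf.put(b"chunk")  # blocks while the buffer is full, like write()

def flusher():
    for _ in range(10):
        time.sleep(0.01)  # stands in for archive_download_flush_delay
        buf.get()  # stands in for a successful flush to the client

t = threading.Thread(target=archiver)
t.start()
flusher()
t.join()
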

@web.authenticated
@gen.coroutine
def get(self, archive_path, include_body=False):
@@ -113,11 +144,9 @@ def get(self, archive_path, include_body=False):
else:
raise web.HTTPError(400)

archive_path = os.path.join(cm.root_dir, url2path(archive_path))

archive_path = pathlib.Path(archive_path)
archive_name = archive_path.name
archive_filename = archive_path.with_suffix(".{}".format(archive_format)).name
archive_path = pathlib.Path(cm.root_dir) / url2path(archive_path)
archive_filename = f"{archive_path.name}.{archive_format}"
archive_filename = quote(archive_filename)

self.log.info("Prepare {} for archiving and downloading.".format(archive_filename))
self.set_header("content-type", "application/octet-stream")
@@ -169,7 +198,6 @@ class ExtractArchiveHandler(IPythonHandler):
@web.authenticated
@gen.coroutine
def get(self, archive_path, include_body=False):

# /extract-archive/ requests must originate from the same site
self.check_xsrf_cookie()
cm = self.contents_manager
@@ -178,15 +206,13 @@ def get(self, archive_path, include_body=False):
self.log.info("Refusing to serve hidden file, via 404 Error")
raise web.HTTPError(404)

archive_path = os.path.join(cm.root_dir, url2path(archive_path))
archive_path = pathlib.Path(archive_path)
archive_path = pathlib.Path(cm.root_dir) / url2path(archive_path)

yield ioloop.IOLoop.current().run_in_executor(None, self.extract_archive, archive_path)

self.finish()

def extract_archive(self, archive_path):

archive_destination = archive_path.parent
self.log.info("Begin extraction of {} to {}.".format(archive_path, archive_destination))

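
The run_in_executor call in get() above is what keeps a large extraction from blocking the server's event loop. A minimal standalone sketch of the same pattern, assuming plain asyncio and an illustrative file name rather than the handler's own machinery:

import asyncio
import tarfile

def extract(archive_path, destination):
    # Blocking work: runs on a thread-pool worker, not on the event loop.
    with tarfile.open(archive_path, mode="r:*") as tar:
        tar.extractall(destination)

async def main():
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, extract, "data.tar.gz", ".")

asyncio.run(main())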