build in s3:// downloader
mlin committed Jun 30, 2020
1 parent 835ff0e commit 6abaed4
Showing 3 changed files with 142 additions and 37 deletions.
7 changes: 7 additions & 0 deletions WDL/runtime/config_templates/default.cfg
@@ -98,6 +98,13 @@ enable_patterns = ["*"]
disable_patterns = ["*.php", "*.aspx"]


[download_awscli]
# When an s3:// URI is supplied for a File input, attempt to load AWS credentials using boto3 on the
# miniwdl host. If disabled, the downloader task might still get credentials from the EC2 instance
# metadata service when running on EC2. Failing that, only public S3 objects can be accessed.
host_credentials = true


[plugins]
# Control which plugins are used. Plugins are installed using the Python entry points convention,
# https://packaging.python.org/specifications/entry-points/
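
The [plugins] comment above refers to the Python entry points convention; as a hedged illustration, a third-party package might register an extra file_download downloader roughly like this (the package name, module, and entry point group shown are assumptions for illustration, not part of this commit):

# setup.py of a hypothetical downloader plugin package
from setuptools import setup

setup(
    name="miniwdl-myscheme-download",
    version="0.1.0",
    py_modules=["miniwdl_myscheme_download"],
    entry_points={
        "miniwdl.plugin.file_download": [
            "myscheme = miniwdl_myscheme_download:downloader"
        ]
    },
)
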
165 changes: 131 additions & 34 deletions WDL/runtime/download.py
@@ -18,54 +18,29 @@
import os
import logging
import traceback
import tempfile
import hashlib
import importlib_metadata
from contextlib import ExitStack
from typing import Optional, List, Generator, Dict, Any, Tuple, Callable
from . import config
from .cache import CallCache
from .._util import compose_coroutines
from .._util import StructuredLogMessage as _

# WDL tasks for downloading a file based on its URI scheme


def aria2c_downloader(
cfg: config.Loader, logger: logging.Logger, uri: str, **kwargs
) -> Generator[Dict[str, Any], Dict[str, Any], None]:
wdl = r"""
task aria2c {
input {
String uri
Int connections = 10
}
command <<<
set -euxo pipefail
mkdir __out
cd __out
aria2c -x ~{connections} -s ~{connections} \
--file-allocation=none --retry-wait=2 --stderr=true --enable-color=false \
"~{uri}"
>>>
output {
File file = glob("__out/*")[0]
}
runtime {
cpu: 4
memory: "1G"
docker: "hobbsau/aria2"
}
}
"""
recv = yield {"task_wdl": wdl, "inputs": {"uri": uri}}
yield recv # pyre-ignore


def _load(cfg: config.Loader):
table = getattr(cfg, "_downloaders", None)
if table:
return table

# default public URI downloaders
table = {"https": aria2c_downloader, "http": aria2c_downloader, "ftp": aria2c_downloader}
table = {
"https": aria2c_downloader,
"http": aria2c_downloader,
"ftp": aria2c_downloader,
"s3": awscli_downloader,
}

# plugins
for plugin_name, plugin_fn in config.load_plugins(cfg, "file_download"):
@@ -156,6 +131,128 @@ def run_cached(
return False, cache.put_download(uri, os.path.realpath(filename), logger=logger)


# WDL tasks for downloading a file based on its URI scheme


def aria2c_downloader(
cfg: config.Loader, logger: logging.Logger, uri: str, **kwargs
) -> Generator[Dict[str, Any], Dict[str, Any], None]:
wdl = r"""
task aria2c {
input {
String uri
Int connections = 10
}
command <<<
set -euxo pipefail
mkdir __out
cd __out
aria2c -x ~{connections} -s ~{connections} \
--file-allocation=none --retry-wait=2 --stderr=true --enable-color=false \
"~{uri}"
>>>
output {
File file = glob("__out/*")[0]
}
runtime {
cpu: 4
memory: "1G"
docker: "hobbsau/aria2"
}
}
"""
recv = yield {"task_wdl": wdl, "inputs": {"uri": uri}}
yield recv # pyre-ignore


def awscli_downloader(
cfg: config.Loader, logger: logging.Logger, uri: str, **kwargs
) -> Generator[Dict[str, Any], Dict[str, Any], None]:

# get AWS credentials from boto3 (unless prevented by configuration)
host_aws_credentials = None
if cfg["download_awscli"].get_bool("host_credentials"):
try:
import boto3 # pyre-fixme

b3creds = boto3.session.Session().get_credentials()
host_aws_credentials = "\n".join(
f"export {k}='{v}'"
for (k, v) in {
"AWS_ACCESS_KEY_ID": b3creds.access_key,
"AWS_SECRET_ACCESS_KEY": b3creds.secret_key,
"AWS_SESSION_TOKEN": b3creds.token,
}.items()
if v
)
if host_aws_credentials:
logger.getChild("awscli_downloader").info(
"using host's AWS credentials; to disable, configure [download_awscli] host_credentials=false (MINIWDL__DOWNLOAD_AWSCLI__HOST_CREDENTIALS=false)"
)
except Exception:
pass
if not host_aws_credentials:
logger.getChild("awscli_downloader").warning(
"no AWS credentials available on host; if needed, install awscli+boto3 and `aws configure`"
)

inputs = {"uri": uri}
with ExitStack() as cleanup:
if host_aws_credentials:
# write credentials to temp file that'll self-destruct afterwards
aws_credentials_file = cleanup.enter_context(
tempfile.NamedTemporaryFile(
prefix=hashlib.sha256(host_aws_credentials.encode()).hexdigest(),
delete=True,
mode="w",
)
)
print(host_aws_credentials, file=aws_credentials_file, flush=True)
# make file group-readable to ensure it'll be usable if the docker image runs as non-root
os.chmod(aws_credentials_file.name, os.stat(aws_credentials_file.name).st_mode | 0o40)
inputs["aws_credentials"] = aws_credentials_file.name

wdl = r"""
task aws_s3_cp {
input {
String uri
File? aws_credentials
}
command <<<
set -euo pipefail
if [ -n "~{aws_credentials}" ]; then
source "~{aws_credentials}"
fi
args=""
if ! aws sts get-caller-identity >&2 ; then
# no credentials or instance role; add --no-sign-request to allow requests for
# PUBLIC objects to proceed.
args="--no-sign-request"
fi
mkdir __out
cd __out
aws s3 cp $args "~{uri}" .
>>>
output {
File file = glob("__out/*")[0]
}
runtime {
cpu: 4
memory: "1G"
docker: "amazon/aws-cli"
}
}
"""
recv = yield {
"task_wdl": wdl,
"inputs": inputs,
}
yield recv # pyre-ignore


def gsutil_downloader(
cfg: config.Loader, logger: logging.Logger, uri: str, **kwargs
) -> Generator[Dict[str, Any], Dict[str, Any], None]:
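
For orientation, each downloader above is a generator following a two-step protocol: it first yields the WDL task source plus its inputs, then re-yields whatever the runner sends back after executing that task (typically the task outputs). A minimal sketch of a driver under that assumption; run_downloader_task is a hypothetical stand-in for the real execution machinery in run()/run_cached():

def drive_downloader(downloader, cfg, logger, uri):
    gen = downloader(cfg, logger, uri)
    request = next(gen)  # first yield: {"task_wdl": ..., "inputs": ...}
    outputs = run_downloader_task(request["task_wdl"], request["inputs"])  # assumed helper
    return gen.send(outputs)  # downloader re-yields the (possibly post-processed) outputs
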
7 changes: 4 additions & 3 deletions tests/test_7runner.py
@@ -133,13 +133,14 @@ def test_download_cache3(self, capture):
"dir": os.path.join(self._dir, "cache"),
}
})
inp = {"files": ["https://raw.githubusercontent.com/chanzuckerberg/miniwdl/master/tests/alyssa_ben.txt?xxx"]}
inp = {"files": ["s3://1000genomes/CHANGELOG", "https://raw.githubusercontent.com/chanzuckerberg/miniwdl/master/tests/alyssa_ben.txt?xxx"]}
self._run(self.count_wdl, inp, cfg=cfg)
self._run(self.count_wdl, inp, cfg=cfg)
logs = [str(record.msg) for record in capture.records if str(record.msg).startswith("downloaded input files")]
# cache isn't used due to presence of query string
self.assertTrue("downloaded: 1" in logs[0])
# cache isn't used for alyssa_ben.txt due to presence of query string
self.assertTrue("downloaded: 2" in logs[0])
self.assertTrue("downloaded: 1" in logs[1])
assert next(record for record in capture.records if "AWS credentials" in str(record.msg))

def test_download_cache4(self):
cfg = WDL.runtime.config.Loader(logging.getLogger(self.id()))
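
As a usage note, the new option can also be flipped per run without editing default.cfg; a sketch assuming the MINIWDL__SECTION__KEY environment-override convention named in the awscli_downloader log message:

import logging, os
import WDL.runtime.config as config

# Hypothetical: disable host credential pass-through for this process only
os.environ["MINIWDL__DOWNLOAD_AWSCLI__HOST_CREDENTIALS"] = "false"
cfg = config.Loader(logging.getLogger("demo"))
print(cfg["download_awscli"].get_bool("host_credentials"))  # expected: False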
