Skip to content

Commit

Permalink
util: net: Change cached_download() and derivatives to functions
Browse files Browse the repository at this point in the history
Fixes: #1056

Signed-off-by: mHash1m <hashimchaudry23@gmail.com>
  • Loading branch information
mhash1m committed Apr 27, 2021
1 parent 64d8797 commit 135d8cb
Show file tree
Hide file tree
Showing 16 changed files with 228 additions and 359 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Calls to hashlib now go through helper functions
- Build docs using `dffml service dev docs`
- `cached_download/unpack_archive()` are now functions
### Fixed
- Record object key properties are now always strings

Expand Down
4 changes: 2 additions & 2 deletions dffml/service/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,11 +1164,11 @@ async def run(self):

copybutton_path = pages_path / "_static" / "copybutton.js"

cached_download(
await cached_download(
"https://raw.githubusercontent.com/python/python-docs-theme/master/python_docs_theme/static/copybutton.js",
copybutton_path,
"061b550f64fb65ccb73fbe61ce15f49c17bc5f30737f42bf3c9481c89f7996d0004a11bf283d6bd26cf0b65130fc1d4b",
).add_target_to_args_and_validate([])
)

nojekyll_path = pages_path / ".nojekyll"
nojekyll_path.touch(exist_ok=True)
Expand Down
241 changes: 45 additions & 196 deletions dffml/util/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import dataclasses
import urllib.request
from typing import List, Union

from .os import chdir
from .file import validate_file_hash
from .log import LOGGER, get_download_logger
Expand Down Expand Up @@ -145,106 +144,20 @@ def sync_urlretrieve_and_validate(
):
if not target_path.parent.is_dir():
target_path.parent.mkdir(parents=True)
sync_urlretrieve(
path, _ = sync_urlretrieve(
url,
filename=str(target_path),
protocol_allowlist=protocol_allowlist,
reporthook=progress_reporthook,
)
return path
validate_file_hash(
target_path, expected_sha384_hash=expected_sha384_hash,
)
return target_path.absolute()


@dataclasses.dataclass
class CachedDownloadWrapper:
url: Union[str, urllib.request.Request]
target_path: Union[str, pathlib.Path]
expected_hash: str
protocol_allowlist: List[str] = dataclasses.field(
default_factory=lambda: DEFAULT_PROTOCOL_ALLOWLIST
)

def __post_init__(self):
self.target_path = pathlib.Path(self.target_path)

def __call__(self, func):
if inspect.isasyncgenfunction(func) and hasattr(func, "__aenter__"):

@contextlib.asynccontextmanager
async def wrapped(*args, **kwargs):
async with func(
*self.add_target_to_args_and_validate(args), **kwargs,
) as result:
yield result

elif inspect.isasyncgenfunction(func):

async def wrapped(*args, **kwargs):
async for result in func(
*self.add_target_to_args_and_validate(args), **kwargs,
):
yield result

elif inspect.iscoroutinefunction(func):

async def wrapped(*args, **kwargs):
return await func(
*self.add_target_to_args_and_validate(args), **kwargs,
)

elif inspect.isgeneratorfunction(func) and hasattr(func, "__enter__"):

@contextlib.contextmanager
def wrapped(*args, **kwargs):
with func(
*self.add_target_to_args_and_validate(args), **kwargs,
) as result:
yield result

elif inspect.isgeneratorfunction(func):

def wrapped(*args, **kwargs):
yield from func(
*self.add_target_to_args_and_validate(args), **kwargs,
)

else:

def wrapped(*args, **kwargs):
return func(
*self.add_target_to_args_and_validate(args), **kwargs
)

# Wrap with functools
wrapped = functools.wraps(func)(wrapped)

return wrapped

def __enter__(self):
self.add_target_to_args_and_validate([])
return self.target_path

def __exit__(self, _exc_type, _exc_value, _traceback):
pass

async def __aenter__(self):
return self.__enter__()

async def __aexit__(self, _exc_type, _exc_value, _traceback):
pass

def add_target_to_args_and_validate(self, args):
sync_urlretrieve_and_validate(
self.url,
self.target_path,
expected_sha384_hash=self.expected_hash,
protocol_allowlist=self.protocol_allowlist,
)
return list(args) + [self.target_path]


def cached_download(
async def cached_download(
url: Union[str, urllib.request.Request],
target_path: Union[str, pathlib.Path],
expected_hash: str,
Expand Down Expand Up @@ -282,86 +195,29 @@ def cached_download(
--------
>>> import asyncio
>>> import contextlib
>>> from dffml import cached_download
>>> from dffml import *
>>>
>>> cached_manifest = cached_download(
... "https://github.com/intel/dffml/raw/152c2b92535fac6beec419236f8639b0d75d707d/MANIFEST.in",
... "MANIFEST.in",
... "f7aadf5cdcf39f161a779b4fa77ec56a49630cf7680e21fb3dc6c36ce2d8c6fae0d03d5d3094a6aec4fea1561393c14c",
>>> cached_manifest = asyncio.run(
... cached_download(
... "https://github.com/intel/dffml/raw/152c2b92535fac6beec419236f8639b0d75d707d/MANIFEST.in",
... "MANIFEST.in",
... "f7aadf5cdcf39f161a779b4fa77ec56a49630cf7680e21fb3dc6c36ce2d8c6fae0d03d5d3094a6aec4fea1561393c14c",
... )
... )
>>>
>>> @cached_manifest
... async def first_line_in_manifest_152c2b(manifest):
... return manifest.read_text().split()[:2]
>>>
>>> asyncio.run(first_line_in_manifest_152c2b())
['include', 'README.md']
>>>
>>> @cached_manifest
... def first_line_in_manifest_152c2b(manifest):
... return manifest.read_text().split()[:2]
>>>
>>> first_line_in_manifest_152c2b()
['include', 'README.md']
>>>
>>> @cached_manifest
... def first_line_in_manifest_152c2b(manifest):
... yield manifest.read_text().split()[:2]
>>>
>>> for contents in first_line_in_manifest_152c2b():
... print(contents)
['include', 'README.md']
>>>
>>> @cached_manifest
... async def first_line_in_manifest_152c2b(manifest):
... yield manifest.read_text().split()[:2]
>>>
>>> async def main():
... async for contents in first_line_in_manifest_152c2b():
... print(contents)
>>>
>>> asyncio.run(main())
['include', 'README.md']
>>>
>>> @cached_manifest
... @contextlib.contextmanager
... def first_line_in_manifest_152c2b(manifest):
... yield manifest.read_text().split()[:2]
>>>
>>> with first_line_in_manifest_152c2b() as contents:
... print(contents)
['include', 'README.md']
>>>
>>> @cached_manifest
... @contextlib.asynccontextmanager
... async def first_line_in_manifest_152c2b(manifest):
... yield manifest.read_text().split()[:2]
>>>
>>> async def main():
... async with first_line_in_manifest_152c2b() as contents:
... print(contents)
>>>
>>> asyncio.run(main())
['include', 'README.md']
>>>
... with cached_manifest as manifest_path:
... print(manifest_path.read_text().split()[:2])
['include', 'README.md']
>>>
>>> async def main():
... async with cached_manifest as manifest_path:
... print(manifest_path.read_text().split()[:2])
>>>
>>> asyncio.run(main())
>>> with open(cached_manifest) as manifest:
... print(manifest.read().split()[:2])
['include', 'README.md']
"""
return CachedDownloadWrapper(
url, target_path, expected_hash, protocol_allowlist=protocol_allowlist,
return sync_urlretrieve_and_validate(
url,
target_path,
expected_sha384_hash=expected_hash,
protocol_allowlist=protocol_allowlist,
)


def cached_download_unpack_archive(
async def cached_download_unpack_archive(
url,
file_path,
directory_path,
Expand Down Expand Up @@ -406,17 +262,16 @@ def cached_download_unpack_archive(
>>> import asyncio
>>> from dffml import cached_download_unpack_archive
>>>
>>> @cached_download_unpack_archive(
... "https://github.com/intel/dffml/archive/c4469abfe6007a50144858d485537324046ff229.tar.gz",
... "dffml.tar.gz",
... "dffml",
... "bb9bb47c4e6e4c6b7147bb3c000bc4069d69c0c77a3e560b69f476a78e6b5084adf5467ee83cbbcc47ba5a4a0696fdfc",
>>> dffml_dir = asyncio.run(
... cached_download_unpack_archive(
... "https://github.com/intel/dffml/archive/c4469abfe6007a50144858d485537324046ff229.tar.gz",
... "dffml.tar.gz",
... "dffml",
... "bb9bb47c4e6e4c6b7147bb3c000bc4069d69c0c77a3e560b69f476a78e6b5084adf5467ee83cbbcc47ba5a4a0696fdfc",
... )
... )
... async def files_in_dffml_commit_c4469a(dffml_dir):
... return len(list(dffml_dir.rglob("**/*")))
>>>
>>> asyncio.run(files_in_dffml_commit_c4469a())
124
>>> print(len(list(dffml_dir.rglob("**/*"))))
495
"""

def on_error(func, path, exc_info):
Expand All @@ -432,26 +287,20 @@ def on_error(func, path, exc_info):
directory_path = pathlib.Path(directory_path)

async def extractor(download_path):
download_path = download_path.absolute()
with chdir(directory_path):
try:
shutil.unpack_archive(str(download_path), ".")
except Exception as error:
shutil.rmtree(directory_path, onerror=on_error)
raise DirectoryNotExtractedError(directory_path) from error

extract = cached_download(
url, file_path, expected_hash, protocol_allowlist=protocol_allowlist
)(extractor)

def mkwrapper(func):
@functools.wraps(func)
async def wrapper(*args, **kwds):
if not directory_path.is_dir():
directory_path.mkdir(parents=True)
await extract()
return await func(*(list(args) + [directory_path]), **kwds)

return wrapper

return mkwrapper
try:
shutil.unpack_archive(str(download_path), str(directory_path))
except Exception as error:
shutil.rmtree(directory_path, onerror=on_error)
raise DirectoryNotExtractedError(directory_path) from error

if not directory_path.is_dir():
directory_path.mkdir(parents=True)
await extractor(
await cached_download(
url,
file_path,
expected_hash,
protocol_allowlist=protocol_allowlist,
)
)
return directory_path.absolute()
31 changes: 16 additions & 15 deletions examples/or_covid_data_by_county.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,22 @@ async def predict(self, sources: SourcesContext) -> AsyncIterator[Record]:
yield record


# DFFML has a function to download files and validate their contents using
# SHA 384 hashes. If you need to download files from an http:// site, you need
# to add the following to the call to cached_download()
# protocol_allowlist=["https://", "http://"]
@cached_download(
"https://github.com/intel/dffml/files/5773999/COVID.Oregon.Counties.Train.Clean.to.2020-10-24.csv.gz",
"training.csv.gz",
"af9536ab41580e04dd72b1285f6b2b703977aee5b95b80422bbe7cc11262297da265e6c0e333bfc1faa7b4f263f5496e",
)
@cached_download(
"https://github.com/intel/dffml/files/5773998/COVID.Oregon.Counties.Test.Clean.2020-10-25.to.2020-10-31.csv.gz",
"test.csv.gz",
"10ee8bcf06a511019f98c3e0e40f315585b2ed84d4a736f743567861d72438afcb7914f117e16640800959324f0f518d",
)
async def main(training_file, test_file):
async def main():
# DFFML has a function to download files and validate their contents using
# SHA 384 hashes. If you need to download files from an http:// site, you need
# to add the following to the call to cached_download()
# protocol_allowlist=["https://", "http://"]
training_file = await cached_download(
"https://github.com/intel/dffml/files/5773999/COVID.Oregon.Counties.Train.Clean.to.2020-10-24.csv.gz",
"training.csv.gz",
"af9536ab41580e04dd72b1285f6b2b703977aee5b95b80422bbe7cc11262297da265e6c0e333bfc1faa7b4f263f5496e",
)
testing_file = await cached_download(
"https://github.com/intel/dffml/files/5773998/COVID.Oregon.Counties.Test.Clean.2020-10-25.to.2020-10-31.csv.gz",
"test.csv.gz",
"10ee8bcf06a511019f98c3e0e40f315585b2ed84d4a736f743567861d72438afcb7914f117e16640800959324f0f518d",
)

# Load the training data
training_data = [record async for record in load(training_file)]

Expand Down

0 comments on commit 135d8cb

Please sign in to comment.