Skip to content

Commit

Permalink
Merge 4a7f9a2 into 5fd9809
Browse files Browse the repository at this point in the history
  • Loading branch information
ncilfone committed May 5, 2021
2 parents 5fd9809 + 4a7f9a2 commit 4b8cbb2
Show file tree
Hide file tree
Showing 20 changed files with 297 additions and 374 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Created by .ignore support plugin (hsz.mobi)
### Python template

# Debugging folder
debug/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,6 @@
keywords=['configuration', 'argparse', 'parameters', 'machine learning', 'deep learning', 'reproducibility'],
packages=setuptools.find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
python_requires='>=3.6',
install_requires=install_reqs
install_requires=install_reqs,
extras_require={'s3': ['boto3', 'botocore', 'hurry.filesize']}
)
13 changes: 13 additions & 0 deletions spock/addons/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-

# Copyright 2019 FMR LLC <opensource@fidelity.com>
# SPDX-License-Identifier: Apache-2.0

"""
Spock is a framework that helps manage complex parameter configurations for Python applications
Please refer to the documentation provided in the README.md
"""
from spock.addons.s3.utils import S3Config

__all__ = ["s3", "S3Config"]
12 changes: 12 additions & 0 deletions spock/addons/s3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-

# Copyright 2019 FMR LLC <opensource@fidelity.com>
# SPDX-License-Identifier: Apache-2.0

"""
Spock is a framework that helps manage complex parameter configurations for Python applications
Please refer to the documentation provided in the README.md
"""

__all__ = ["utils"]
130 changes: 130 additions & 0 deletions spock/addons/s3/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-

# Copyright 2019 FMR LLC <opensource@fidelity.com>
# SPDX-License-Identifier: Apache-2.0

"""Handles all S3 related ops -- allows for s3 functionality to be optional to keep req deps light"""

import os
import sys
import typing
from urllib.parse import urlparse

import attr

try:
    # Optional dependencies -- only present when installed via `pip install spock-config[s3]`.
    # hurry.filesize is part of the same extra, so it must live inside this guard too:
    # leaving it outside meant the helpful message printed and the import still crashed.
    import boto3
    from botocore.client import BaseClient
    from hurry.filesize import size
except ImportError:
    print('Missing libraries to support S3 functionality. Please re-install spock with the extra s3 dependencies -- '
          'pip install spock-config[s3]')


@attr.s(auto_attribs=True)
class S3Config:
    """Configuration class for S3 support

    *Attributes*:

        session: instantiated boto3 session object
        s3_session: automatically generated s3 client from the boto3 session
        kms_arn: AWS KMS key ARN (optional)
        temp_folder: temporary working folder to write/read spock configuration(s) (optional: defaults to /tmp)

    """
    # attrs generates __init__ from these annotated attributes in declaration order
    session: boto3.Session
    # init=False -- not a constructor argument; derived from `session` in __attrs_post_init__
    s3_session: BaseClient = attr.ib(init=False)
    kms_arn: typing.Optional[str] = None
    temp_folder: typing.Optional[str] = '/tmp/'

    def __attrs_post_init__(self):
        # Build the low-level s3 client from the user-supplied boto3 session after init
        self.s3_session = self.session.client('s3')


def handle_s3_load_path(path: str, s3_config: S3Config) -> str:
    """Handles loading from an S3 uri

    Downloads the file behind the given s3 uri into a local temp location and hands that
    local path back to the loader machinery.

    *Args*:

        path: s3 uri path
        s3_config: s3_config object

    *Returns*:

        temp_path: the temporary path of the config file downloaded from s3

    """
    if s3_config is None:
        raise ValueError('Missing S3Config object which is necessary to handle S3 style paths')
    bucket_name, object_key, file_name = get_s3_bucket_object_name(s3_path=path)
    # Build the local destination and collapse any accidental double slashes
    local_path = f'{s3_config.temp_folder}/{file_name}'.replace(r'//', r'/')
    return download_s3(bucket=bucket_name, obj=object_key, temp_path=local_path,
                       s3_session=s3_config.s3_session)


def get_s3_bucket_object_name(s3_path: str) -> typing.Tuple[str, str, str]:
    """Splits an S3 uri into bucket, object, name

    *Args*:

        s3_path: s3 uri

    *Returns*:

        bucket
        object
        name

    """
    parsed_uri = urlparse(s3_path)
    bucket = parsed_uri.netloc
    # Drop the leading slash so the remainder is a valid S3 object key
    object_key = parsed_uri.path.lstrip('/')
    file_name = os.path.basename(parsed_uri.path)
    return bucket, object_key, file_name


def download_s3(bucket: str, obj: str, temp_path: str, s3_session: BaseClient) -> str:
    """Attempts to download the file from the S3 uri to a temp location

    *Args*:

        bucket: s3 bucket
        obj: s3 object
        temp_path: local temporary path to write file
        s3_session: current s3 session

    *Returns*:

        temp_path: the temporary path of the config file downloaded from s3

    *Raises*:

        IOError: if the download or the local write fails (logged then re-raised)

    """
    try:
        # HEAD the object first so the progress bar knows the total transfer size
        file_size = s3_session.head_object(Bucket=bucket, Key=obj)['ContentLength']
        print(f'Attempting to download s3://{bucket}/{obj} (size: {size(file_size)})')
        current_progress = 0
        n_ticks = 50

        def _s3_progress_bar(chunk):
            # boto3 invokes this with the byte count of each transferred chunk
            nonlocal current_progress
            current_progress += chunk
            done = int(n_ticks * (current_progress / file_size))
            # int() must wrap the whole product: int(fraction) * 100 printed 0% until completion
            sys.stdout.write(f"\r[{'=' * done}{' ' * (n_ticks - done)}] "
                             f"{int(100 * current_progress / file_size)}%")
            sys.stdout.flush()
        # Download with the progress callback
        s3_session.download_file(bucket, obj, temp_path, Callback=_s3_progress_bar)
        # Terminate the progress line once, after the transfer -- previously this was written
        # inside the callback and spammed newlines on every chunk
        sys.stdout.write('\n\n')
        return temp_path
    except IOError:
        # Log context for the failure, then re-raise: swallowing here returned None,
        # which the caller would have treated as a valid file path
        print(f'Failed to download file from S3 '
              f'(bucket: {bucket}, object: {obj}) '
              f'and write to {temp_path}')
        raise


def upload_s3(self):
    """Placeholder for future S3 upload support -- currently a no-op

    NOTE(review): takes `self` despite being a module-level function -- presumably it was
    sketched as a method (or meant to receive an S3Config); confirm intent before implementing.
    """
    # Here it should upload to S3 from the written path (/tmp?)
    # How to manage KMS or if file is encrypted? Config obj? Would the session have it already
    pass
6 changes: 3 additions & 3 deletions spock/backend/attr/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class AttrPayload(BasePayload):
_loaders: maps of each file extension to the loader class
"""
def __init__(self):
super().__init__()
def __init__(self, s3_config=None):
super().__init__(s3_config=s3_config)

def __call__(self, *args, **kwargs):
"""Call to allow self chaining
Expand All @@ -39,7 +39,7 @@ def __call__(self, *args, **kwargs):
Payload: instance of self
"""
return AttrPayload()
return AttrPayload(*args, **kwargs)

@staticmethod
def _update_payload(base_payload, input_classes, payload):
Expand Down
6 changes: 3 additions & 3 deletions spock/backend/attr/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class AttrSaver(BaseSaver):
_writers: maps file extension to the correct i/o handler
"""
def __init__(self):
super().__init__()
def __init__(self, s3_config=None):
super().__init__(s3_config=s3_config)

def __call__(self, *args, **kwargs):
return AttrSaver()
return AttrSaver(*args, **kwargs)

def _clean_up_values(self, payload, file_extension):
# Dictionary to recursively write to
Expand Down
69 changes: 45 additions & 24 deletions spock/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from spock.handlers import TOMLHandler
from spock.handlers import YAMLHandler
from spock.utils import add_info
from spock.utils import check_path_s3
from spock.utils import make_argument
from typing import List

Expand All @@ -40,7 +41,26 @@ def __repr__(self):
return yaml.dump(self.__dict__, default_flow_style=False)


class BaseSaver(ABC): # pylint: disable=too-few-public-methods
class BaseHandler(ABC):
    """Base class for saver and payload

    *Attributes*:

        _supported_extensions: maps file extension to the correct i/o handler
        _s3_config: optional S3Config object to handle s3 access

    """
    def __init__(self, s3_config=None):
        # Single source of truth for which config file formats are supported
        self._supported_extensions = {'.yaml': YAMLHandler, '.toml': TOMLHandler, '.json': JSONHandler}
        self._s3_config = s3_config

    def _check_extension(self, file_extension: str):
        # Reject any extension we have no handler for, listing the valid choices
        if file_extension not in self._supported_extensions:
            raise TypeError(f'File extension {file_extension} not supported -- \n'
                            f'File extension must be from {list(self._supported_extensions.keys())}')


class BaseSaver(BaseHandler): # pylint: disable=too-few-public-methods
"""Base class for saving configs
Contains methods to build a correct output payload and then writes to file based on the file
Expand All @@ -49,10 +69,11 @@ class BaseSaver(ABC): # pylint: disable=too-few-public-methods
*Attributes*:
_writers: maps file extension to the correct i/o handler
_s3_config: optional S3Config object to handle s3 access
"""
def __init__(self):
self._writers = {'.yaml': YAMLHandler, '.toml': TOMLHandler, '.json': JSONHandler}
def __init__(self, s3_config=None):
super(BaseSaver, self).__init__(s3_config=s3_config)

def save(self, payload, path, file_name=None, create_save_path=False, extra_info=True, file_extension='.yaml'): #pylint: disable=too-many-arguments
"""Writes Spock config to file
Expand All @@ -74,9 +95,8 @@ def save(self, payload, path, file_name=None, create_save_path=False, extra_info
None
"""
supported_extensions = list(self._writers.keys())
if file_extension not in self._writers:
raise ValueError(f'Invalid fileout extension. Expected a fileout from {supported_extensions}')
# Check extension
self._check_extension(file_extension=file_extension)
# Make the filename
fname = str(uuid1()) if file_name is None else file_name
name = f'{fname}.spock.cfg{file_extension}'
Expand All @@ -89,7 +109,7 @@ def save(self, payload, path, file_name=None, create_save_path=False, extra_info
if not os.path.exists(path) and create_save_path:
os.makedirs(path)
with open(fid, 'w') as file_out:
self._writers.get(file_extension)().save(out_dict, extra_dict, file_out)
self._supported_extensions.get(file_extension)().save(out_dict, extra_dict, file_out)
except OSError as e:
print(f'Not a valid file path to write to: {fid}')
raise e
Expand Down Expand Up @@ -672,19 +692,20 @@ def _get_from_sys_modules(cls_name):
return module


class BasePayload(ABC): # pylint: disable=too-few-public-methods
class BasePayload(BaseHandler): # pylint: disable=too-few-public-methods
"""Handles building the payload for config file(s)
This class builds out the payload from config files of multiple types. It handles various
file types and also composition of config files via a recursive calls
file types and also composition of config files via recursive calls
*Attributes*:
_loaders: maps of each file extension to the loader class
__s3_config: optional S3Config object to handle s3 access
"""
def __init__(self):
self._loaders = {'.yaml': YAMLHandler(), '.toml': TOMLHandler(), '.json': JSONHandler()}
def __init__(self, s3_config=None):
super(BasePayload, self).__init__(s3_config=s3_config)

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -745,12 +766,10 @@ def _payload(self, input_classes, path, deps, root=False):
"""
# Match to loader based on file-extension
config_extension = Path(path).suffix.lower()
supported_extensions = list(self._loaders.keys())
if config_extension not in supported_extensions:
raise TypeError(f'File extension {config_extension} not supported\n'
f'Must be from {supported_extensions}')
# Verify extension
self._check_extension(file_extension=config_extension)
# Load from file
base_payload = self._loaders.get(config_extension).load(path)
base_payload = self._supported_extensions.get(config_extension)().load(path, s3_config=self._s3_config)
# Check and? update the dependencies
deps = self._handle_dependencies(deps, path, root)
payload = {}
Expand Down Expand Up @@ -796,7 +815,8 @@ def _handle_includes(self, base_payload, config_extension, input_classes, path,
"""Handles config composition
For all of the config tags in the config file this function will recursively call the payload function
with the composition path to get the additional payload(s) from the composed file(s)
with the composition path to get the additional payload(s) from the composed file(s) -- checks for file
validity or if it is an S3 URI via regex
*Args*:
Expand All @@ -814,14 +834,15 @@ def _handle_includes(self, base_payload, config_extension, input_classes, path,
"""
included_params = {}
for inc_path in base_payload['config']:
if not os.path.exists(inc_path):
# maybe it's relative?
abs_inc_path = os.path.join(os.path.dirname(path), inc_path)
if check_path_s3(inc_path):
use_path = inc_path
elif os.path.exists(inc_path):
use_path = inc_path
elif os.path.join(os.path.dirname(path), inc_path):
use_path = os.path.join(os.path.dirname(path), inc_path)
else:
abs_inc_path = inc_path
if not os.path.exists(abs_inc_path):
raise RuntimeError(f'Could not find included {config_extension} file {inc_path}!')
included_params.update(self._payload(input_classes, abs_inc_path, deps))
raise RuntimeError(f'Could not find included {config_extension} file {inc_path} or is not an S3 URI!')
included_params.update(self._payload(input_classes, use_path, deps))
payload.update(included_params)
return payload

Expand Down
Loading

0 comments on commit 4b8cbb2

Please sign in to comment.