diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..61d9445e --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,18 @@ +--- +name: Pull request +about: Create a pull request for merge + +--- + +## What does this PR do? +E.g. Describe the added feature or what issue it fixes #(issue)... + +## Checklist + - [ ] Did you adhere to [PEP-8](https://www.python.org/dev/peps/pep-0008/) standards? + - [ ] Did you run black and isort prior to submitting your PR? + - [ ] Does your PR pass all existing unit tests? + - [ ] Did you add associated unit tests for any additional functionality? + - [ ] Did you provide documentation ([Numpy Docstring format](https://numpydoc.readthedocs.io/en/latest/format.html#style-guide)) whenever possible, even for simple functions or classes? + +## Review +Request will go to reviewers to approve for merge. \ No newline at end of file diff --git a/.github/workflows/python-coverage.yaml b/.github/workflows/python-coverage.yaml index 2e9dcff5..da804bb6 100644 --- a/.github/workflows/python-coverage.yaml +++ b/.github/workflows/python-coverage.yaml @@ -21,11 +21,17 @@ jobs: with: python-version: '3.8' + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }} + - name: Install dependencies and dev dependencies run: | python -m pip install --upgrade pip - pip install -r DEV_REQUIREMENTS.txt - pip install -r S3_REQUIREMENTS.txt + pip install -r REQUIREMENTS.txt + pip install -r ./requirements/DEV_REQUIREMENTS.txt + pip install -r ./requirements/S3_REQUIREMENTS.txt - name: Test with pytest run: | diff --git a/.github/workflows/python-docs.yaml b/.github/workflows/python-docs.yaml index 88c90275..5df2ab7b 100644 --- a/.github/workflows/python-docs.yaml +++ b/.github/workflows/python-docs.yaml @@ -18,12 +18,18 @@ jobs: uses: actions/setup-python@v2 with: python-version: '3.8' + + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }} + - name: Install dependencies and dev dependencies run: | python -m pip install --upgrade pip pip install -e .[s3] - pip install -r DEV_REQUIREMENTS.txt - pip install -r S3_REQUIREMENTS.txt + pip install -r ./requirements/DEV_REQUIREMENTS.txt + pip install -r ./requirements/S3_REQUIREMENTS.txt - name: Build docs with Portray env: diff --git a/.github/workflows/python-lint.yaml b/.github/workflows/python-lint.yaml new file mode 100644 index 00000000..35d77eee --- /dev/null +++ b/.github/workflows/python-lint.yaml @@ -0,0 +1,42 @@ +# This workflow will run isort and black linters on PRs + +name: lint + +# on: workflow_dispatch +on: + pull_request: + branches: [master] + push: + branches: [master] + +jobs: + run_lint: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ 
hashFiles('./requirements/S3_REQUIREMENTS.txt') }} + + - name: Install dependencies and dev dependencies + run: | + python -m pip install --upgrade pip + pip install -e . + pip install -r ./requirements/DEV_REQUIREMENTS.txt + pip install -r ./requirements/S3_REQUIREMENTS.txt + + - name: Run isort linter + run: | + isort --check . --skip="debug" --skip="versioneer.py" --skip="tests" --skip="_version.py" + + - name: Run black linter + run: | + black --check . --exclude="versioneer.py|_version.py|debug|tests" diff --git a/.github/workflows/python-pytest-s3.yaml b/.github/workflows/python-pytest-s3.yaml index e359199e..7c542d71 100644 --- a/.github/workflows/python-pytest-s3.yaml +++ b/.github/workflows/python-pytest-s3.yaml @@ -23,12 +23,17 @@ jobs: with: python-version: ${{ matrix.python-version }} + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }} + - name: Install dependencies run: | python -m pip install --upgrade pip pip install -e . - pip install -r DEV_REQUIREMENTS.txt - pip install -r S3_REQUIREMENTS.txt + pip install -r ./requirements/DEV_REQUIREMENTS.txt + pip install -r ./requirements/S3_REQUIREMENTS.txt - name: Test with pytest run: | diff --git a/.github/workflows/python-pytest.yml b/.github/workflows/python-pytest.yml index 4f246e96..75cd5333 100644 --- a/.github/workflows/python-pytest.yml +++ b/.github/workflows/python-pytest.yml @@ -23,11 +23,16 @@ jobs: with: python-version: ${{ matrix.python-version }} + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/DEV_REQUIREMENTS.txt') }}-${{ hashFiles('./requirements/S3_REQUIREMENTS.txt') }} + - name: Install dependencies run: | python -m pip install --upgrade pip pip install -e . - pip install -r DEV_REQUIREMENTS.txt + pip install -r ./requirements/DEV_REQUIREMENTS.txt - name: Test with pytest run: | diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index dd05fe30..d3036f40 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,6 +4,7 @@ Requests in the public repository. ## Contribution Guidelines 1. Adhere to [PEP-8](https://www.python.org/dev/peps/pep-0008/) standards. -2. Any changes to core functionality must pass all existing unit tests. -3. Additional functionality should have associated unit tests. -4. Provide documentation (Google Docstring format) whenever possible, even for simple functions or classes. \ No newline at end of file +2. Run black and isort linters before creating a PR. +3. Any changes to core functionality must pass all existing unit tests. +4. Additional functionality should have associated unit tests. +5. Provide documentation ([Numpy Docstring format](https://numpydoc.readthedocs.io/en/latest/format.html#style-guide)) whenever possible, even for simple functions or classes. \ No newline at end of file diff --git a/NOTICE.txt b/NOTICE.txt index 60910e68..21be5235 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -12,5 +12,5 @@ FMR LLC (https://www.fidelity.com/). 
This product relies on the following works (and the dependencies thereof), installed separately: - attrs | https://github.com/python-attrs/attrs | MIT License - GitPython | https://github.com/gitpython-developers/GitPython | BSD 3-Clause License -- PyYAML | https://github.com/yaml/pyyaml | MIT License -- toml | https://github.com/toml-lang/toml | MIT License \ No newline at end of file +- pytomlpp | https://github.com/bobfang1992/pytomlpp | MIT License +- PyYAML | https://github.com/yaml/pyyaml | MIT License \ No newline at end of file diff --git a/README.md b/README.md index 5afe2e18..55ee4fa8 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![License](https://img.shields.io/badge/License-Apache%202.0-9cf)](https://opensource.org/licenses/Apache-2.0) [![Python](https://img.shields.io/badge/python-3.6+-informational.svg)]() +[![Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![PyPI version](https://badge.fury.io/py/spock-config.svg)](https://badge.fury.io/py/spock-config) [![Coverage Status](https://coveralls.io/repos/github/fidelity/spock/badge.svg?branch=master)](https://coveralls.io/github/fidelity/spock?branch=master) ![Tests](https://github.com/fidelity/spock/workflows/pytest/badge.svg?branch=master) diff --git a/REQUIREMENTS.txt b/REQUIREMENTS.txt index 7db37490..b33e46e0 100644 --- a/REQUIREMENTS.txt +++ b/REQUIREMENTS.txt @@ -1,4 +1,4 @@ attrs GitPython -pyYAML -toml \ No newline at end of file +pytomlpp +pyYAML \ No newline at end of file diff --git a/docs/Quick-Start.md b/docs/Quick-Start.md index 2f64ac80..ba107af9 100644 --- a/docs/Quick-Start.md +++ b/docs/Quick-Start.md @@ -107,7 +107,7 @@ fancier_parameter: 64.64 most_fancy_parameter: [768, 768, 512, 128] ``` -Finally, we would run our script and pass the path to the configuration file to the command line (-c or --config): +Finally, we would run our script and pass the path to the configuration file to the command line (`-c` or `--config`): ```bash $ python simple.py -c simple.yaml @@ -131,4 +131,16 @@ configuration(s): fancy_parameter float parameter that multiplies a value fancier_parameter float parameter that gets added to product of val and fancy_parameter most_fancy_parameter List[int] values to apply basic algebra to +``` + +### Spock as a Drop-In for Argparse + +`spock` can easily be used as a drop-in replacement for argparse. This means that all parameter values are required to come in +from the command line or from defaults set within the `@spock` decorated classes. Simply do not pass a `-c` or +`--config` argument at the command line and instead pass all of the automatically generated command-line arguments. + + +```bash +$ python simple.py --BasicConfig.parameter --BasicConfig.fancy_parameter 8.8 --BasicConfig.fancier_parameter 64.64 \ + --BasicConfig.most_fancy_parameter [768, 768, 512, 128] ``` \ No newline at end of file diff --git a/docs/addons/S3.md b/docs/addons/S3.md index d2ac5d32..36dfb89c 100644 --- a/docs/addons/S3.md +++ b/docs/addons/S3.md @@ -46,11 +46,11 @@ session = boto3.Session( ### Using the S3Config Object -As an example let's create a basic `@spock` decorated class, instantiate a `S3Config` object from `spock.addons` with +As an example let's create a basic `@spock` decorated class, instantiate an `S3Config` object from `spock.addons.s3` with the `boto3.session.Session` we created above, and pass it to the `ConfigArgBuilder`.
```python -from spock.addons import S3Config +from spock.addons.s3 import S3Config from spock.builder import ConfigArgBuilder from spock.config import spock from typing import List @@ -123,8 +123,8 @@ With a `S3Config` object passed into the `ConfigArgBuilder` the S3 URI will auto If you require any other settings for uploading or downloading files from S3 the `S3Config` class has two extra attributes: -`download_config` which takes a `S3DownloadConfig` object from `spock.addons` which supports all ExtraArgs from +`download_config` which takes an `S3DownloadConfig` object from `spock.addons.s3` which supports all ExtraArgs from [S3Transfer.ALLOWED_DOWNLOAD_ARGS](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_DOWNLOAD_ARGS) -`upload_config` which takes a `S3UploadConfig` object from `spock.addons` which supports all ExtraArgs from +`upload_config` which takes an `S3UploadConfig` object from `spock.addons.s3` which supports all ExtraArgs from [S3Transfer.ALLOWED_UPLOAD_ARGS](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS) diff --git a/docs/addons/tuner/Ax.md b/docs/addons/tuner/Ax.md new file mode 100644 index 00000000..15037ff3 --- /dev/null +++ b/docs/addons/tuner/Ax.md @@ -0,0 +1 @@ +Test Placeholder \ No newline at end of file diff --git a/docs/addons/tuner/Basics.md b/docs/addons/tuner/Basics.md new file mode 100644 index 00000000..15037ff3 --- /dev/null +++ b/docs/addons/tuner/Basics.md @@ -0,0 +1 @@ +Test Placeholder \ No newline at end of file diff --git a/docs/addons/tuner/Optuna.md b/docs/addons/tuner/Optuna.md new file mode 100644 index 00000000..777b7340 --- /dev/null +++ b/docs/addons/tuner/Optuna.md @@ -0,0 +1,5 @@ +Test Placeholder + +Ask & Tell Define and Run + +https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/009_ask_and_tell.html# \ No newline at end of file diff --git a/docs/advanced_features/Command-Line-Overrides.md b/docs/advanced_features/Command-Line-Overrides.md index 46efc070..128cbe94 100644 --- a/docs/advanced_features/Command-Line-Overrides.md +++ b/docs/advanced_features/Command-Line-Overrides.md @@ -117,5 +117,17 @@ We could override the parameters like so (note that the len must match the defin ```bash $ python tutorial.py --config tutorial.yaml --TypeConfig.nested_list.NestedListStuff.one [1,2] \ ---TypeConfig.nested_list.NestedListStuff.two [ciao,ciao] +--TypeConfig.nested_list.NestedListStuff.two ['ciao','ciao'] +``` + +### Spock as a Drop-In for Argparse + +`spock` can easily be used as a drop-in replacement for argparse. This means that all parameter values are required to come in +from the command line or from defaults set within the `@spock` decorated classes. Simply do not pass a `-c` or +`--config` argument at the command line and instead pass all of the automatically generated command-line arguments. + + +```bash +$ python tutorial.py --TypeConfig.nested_list.NestedListStuff.one [1,2] \ + --TypeConfig.nested_list.NestedListStuff.two [ciao,ciao] ...
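+# the remaining required parameters defined on the other @spock classes would be passed here in the
+# same ClassName.parameter form (exact names depend on the tutorial's config classes, not shown here)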
``` \ No newline at end of file diff --git a/examples/legacy/quick-start/simple.py b/examples/legacy/quick-start/simple.py index 290e364c..d979d111 100644 --- a/examples/legacy/quick-start/simple.py +++ b/examples/legacy/quick-start/simple.py @@ -13,7 +13,12 @@ class BasicConfig: def add_namespace(config): # Lets just do some basic algebra here - val_sum = sum([(config.fancy_parameter * val) + config.fancier_parameter for val in config.most_fancy_parameter]) + val_sum = sum( + [ + (config.fancy_parameter * val) + config.fancier_parameter + for val in config.most_fancy_parameter + ] + ) # If the boolean is true let's round if config.parameter: val_sum = round(val_sum) @@ -38,10 +43,14 @@ def main(): val_sum_namespace = add_namespace(config.BasicConfig) print(val_sum_namespace) # Or pass by parameter - val_sum_parameter = add_by_parameter(config.BasicConfig.fancy_parameter, config.BasicConfig.most_fancy_parameter, - config.BasicConfig.fancier_parameter, config.BasicConfig.parameter) + val_sum_parameter = add_by_parameter( + config.BasicConfig.fancy_parameter, + config.BasicConfig.most_fancy_parameter, + config.BasicConfig.fancier_parameter, + config.BasicConfig.parameter, + ) print(val_sum_parameter) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/legacy/tutorial/advanced/basic_nn.py b/examples/legacy/tutorial/advanced/basic_nn.py index 86015cb9..cf81e88c 100644 --- a/examples/legacy/tutorial/advanced/basic_nn.py +++ b/examples/legacy/tutorial/advanced/basic_nn.py @@ -6,12 +6,16 @@ class BasicNet(nn.Module): def __init__(self, model_config): super(BasicNet, self).__init__() # Make a dictionary of activation functions to select from - self.act_fncs = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh} + self.act_fncs = {"relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh} self.use_act = self.act_fncs.get(model_config.activation)() # Define the layers manually (avoiding list comprehension for clarity) self.layer_1 = nn.Linear(model_config.n_features, model_config.hidden_sizes[0]) - self.layer_2 = nn.Linear(model_config.hidden_sizes[0], model_config.hidden_sizes[1]) - self.layer_3 = nn.Linear(model_config.hidden_sizes[1], model_config.hidden_sizes[2]) + self.layer_2 = nn.Linear( + model_config.hidden_sizes[0], model_config.hidden_sizes[1] + ) + self.layer_3 = nn.Linear( + model_config.hidden_sizes[1], model_config.hidden_sizes[2] + ) # Define some dropout layers self.dropout = [] if model_config.dropout is not None: diff --git a/examples/legacy/tutorial/advanced/tutorial.py b/examples/legacy/tutorial/advanced/tutorial.py index 7bf7e7be..c2aad1c7 100644 --- a/examples/legacy/tutorial/advanced/tutorial.py +++ b/examples/legacy/tutorial/advanced/tutorial.py @@ -1,8 +1,9 @@ +import torch from basic_nn import BasicNet + from spock.args import * from spock.builder import ConfigArgBuilder from spock.config import spock_config -import torch @spock_config @@ -11,8 +12,8 @@ class ModelConfig: n_features: IntArg dropout: ListOptArg[float] hidden_sizes: TupleArg[int] = TupleArg.defaults((32, 32, 32)) - activation: ChoiceArg(choice_set=['relu', 'gelu', 'tanh'], default='relu') - optimizer: ChoiceArg(choice_set=['SGD', 'Adam']) + activation: ChoiceArg(choice_set=["relu", "gelu", "tanh"], default="relu") + optimizer: ChoiceArg(choice_set=["SGD", "Adam"]) cache_path: StrOptArg @@ -38,42 +39,61 @@ class SGDConfig(OptimizerConfig): def train(x_data, y_data, model, model_config, data_config, optimizer_config): - if model_config.optimizer == 'SGD': - optimizer = 
torch.optim.SGD(model.parameters(), lr=optimizer_config.lr, momentum=optimizer_config.momentum, - nesterov=optimizer_config.nesterov) - elif model_config.optimizer == 'Adam': + if model_config.optimizer == "SGD": + optimizer = torch.optim.SGD( + model.parameters(), + lr=optimizer_config.lr, + momentum=optimizer_config.momentum, + nesterov=optimizer_config.nesterov, + ) + elif model_config.optimizer == "Adam": optimizer = torch.optim.Adam(model.parameters(), lr=optimizer_config.lr) else: - raise ValueError(f'Optimizer choice {optimizer_config.optimizer} not available') + raise ValueError(f"Optimizer choice {optimizer_config.optimizer} not available") n_steps_per_epoch = data_config.n_samples % data_config.batch_size for epoch in range(optimizer_config.n_epochs): for i in range(n_steps_per_epoch): # Ugly data slicing for simplicity - x_batch = x_data[i * n_steps_per_epoch:(i + 1) * n_steps_per_epoch, ] - y_batch = y_data[i * n_steps_per_epoch:(i + 1) * n_steps_per_epoch, ] + x_batch = x_data[ + i * n_steps_per_epoch : (i + 1) * n_steps_per_epoch, + ] + y_batch = y_data[ + i * n_steps_per_epoch : (i + 1) * n_steps_per_epoch, + ] optimizer.zero_grad() output = model(x_batch) loss = torch.nn.CrossEntropyLoss(output, y_batch) loss.backward() if optimizer_config.grad_clip: - torch.nn.utils.clip_grad_value(model.parameters(), optimizer_config.grad_clip) + torch.nn.utils.clip_grad_value( + model.parameters(), optimizer_config.grad_clip + ) optimizer.step() - print(f'Finished Epoch {epoch+1}') + print(f"Finished Epoch {epoch+1}") def main(): # A simple description - description = 'spock Advanced Tutorial' + description = "spock Advanced Tutorial" # Build out the parser by passing in Spock config objects as *args after description - config = ConfigArgBuilder(ModelConfig, DataConfig, SGDConfig, desc=description).generate() + config = ConfigArgBuilder( + ModelConfig, DataConfig, SGDConfig, desc=description + ).generate() # Instantiate our neural net using basic_nn = BasicNet(model_config=config.ModelConfig) # Make some random data (BxH): H has dim of features in x_data = torch.rand(config.DataConfig.n_samples, config.ModelConfig.n_features) y_data = torch.randint(0, 3, (config.DataConfig.n_samples,)) # Run some training - train(x_data, y_data, basic_nn, config.ModelConfig, config.DataConfig, config.SGDConfig) + train( + x_data, + y_data, + basic_nn, + config.ModelConfig, + config.DataConfig, + config.SGDConfig, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/legacy/tutorial/basic/basic_nn.py b/examples/legacy/tutorial/basic/basic_nn.py index f4dc96b3..fdbe4a1c 100644 --- a/examples/legacy/tutorial/basic/basic_nn.py +++ b/examples/legacy/tutorial/basic/basic_nn.py @@ -6,12 +6,16 @@ class BasicNet(nn.Module): def __init__(self, model_config): super(BasicNet, self).__init__() # Make a dictionary of activation functions to select from - self.act_fncs = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh} + self.act_fncs = {"relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh} self.use_act = self.act_fncs.get(model_config.activation)() # Define the layers manually (avoiding list comprehension for clarity) self.layer_1 = nn.Linear(model_config.n_features, model_config.hidden_sizes[0]) - self.layer_2 = nn.Linear(model_config.hidden_sizes[0], model_config.hidden_sizes[1]) - self.layer_3 = nn.Linear(model_config.hidden_sizes[1], model_config.hidden_sizes[2]) + self.layer_2 = nn.Linear( + model_config.hidden_sizes[0], model_config.hidden_sizes[1] + ) + self.layer_3 = 
nn.Linear( + model_config.hidden_sizes[1], model_config.hidden_sizes[2] + ) # Define some dropout layers self.dropout_1 = nn.Dropout(model_config.dropout[0]) self.dropout_2 = nn.Dropout(model_config.dropout[1]) diff --git a/examples/legacy/tutorial/basic/tutorial.py b/examples/legacy/tutorial/basic/tutorial.py index ae7ae324..a75a8418 100644 --- a/examples/legacy/tutorial/basic/tutorial.py +++ b/examples/legacy/tutorial/basic/tutorial.py @@ -1,8 +1,9 @@ +import torch from basic_nn import BasicNet + from spock.args import * from spock.builder import ConfigArgBuilder from spock.config import spock_config -import torch @spock_config @@ -11,15 +12,18 @@ class ModelConfig: n_features: IntArg dropout: ListArg[float] hidden_sizes: TupleArg[int] - activation: ChoiceArg(choice_set=['relu', 'gelu', 'tanh']) + activation: ChoiceArg(choice_set=["relu", "gelu", "tanh"]) def main(): # A simple description - description = 'spock Tutorial' + description = "spock Tutorial" # Build out the parser by passing in Spock config objects as *args after description - config = ConfigArgBuilder( - ModelConfig, desc=description, create_save_path=True).save(file_extension='.toml').generate() + config = ( + ConfigArgBuilder(ModelConfig, desc=description, create_save_path=True) + .save(file_extension=".toml") + .generate() + ) # Instantiate our neural net using basic_nn = BasicNet(model_config=config.ModelConfig) # Make some random data (BxH): H has dim of features in @@ -28,5 +32,5 @@ def main(): print(result) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/quick-start/simple.py b/examples/quick-start/simple.py index 0f00a00f..76cd15f9 100644 --- a/examples/quick-start/simple.py +++ b/examples/quick-start/simple.py @@ -1,6 +1,7 @@ +from typing import List + from spock.builder import ConfigArgBuilder from spock.config import spock -from typing import List @spock @@ -14,6 +15,7 @@ class BasicConfig: most_fancy_parameter: values to apply basic algebra to """ + parameter: bool fancy_parameter: float fancier_parameter: float @@ -22,7 +24,12 @@ class BasicConfig: def add_namespace(config): # Lets just do some basic algebra here - val_sum = sum([(config.fancy_parameter * val) + config.fancier_parameter for val in config.most_fancy_parameter]) + val_sum = sum( + [ + (config.fancy_parameter * val) + config.fancier_parameter + for val in config.most_fancy_parameter + ] + ) # If the boolean is true let's round if config.parameter: val_sum = round(val_sum) @@ -40,17 +47,21 @@ def add_by_parameter(multiply_param, list_vals, add_param, tf_round): def main(): # Chain the generate function to the class call - config = ConfigArgBuilder(BasicConfig, desc='Quick start example').generate() + config = ConfigArgBuilder(BasicConfig, desc="Quick start example").generate() # One can now access the Spock config object by class name with the returned namespace print(config.BasicConfig.parameter) # And pass the namespace to our first function val_sum_namespace = add_namespace(config.BasicConfig) print(val_sum_namespace) # Or pass by parameter - val_sum_parameter = add_by_parameter(config.BasicConfig.fancy_parameter, config.BasicConfig.most_fancy_parameter, - config.BasicConfig.fancier_parameter, config.BasicConfig.parameter) + val_sum_parameter = add_by_parameter( + config.BasicConfig.fancy_parameter, + config.BasicConfig.most_fancy_parameter, + config.BasicConfig.fancier_parameter, + config.BasicConfig.parameter, + ) print(val_sum_parameter) -if __name__ == '__main__': +if __name__ == "__main__": main() diff 
--git a/examples/tune/optuna/__init__.py b/examples/tune/optuna/__init__.py new file mode 100644 index 00000000..40a96afc --- /dev/null +++ b/examples/tune/optuna/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/examples/tune/optuna/tune.py b/examples/tune/optuna/tune.py new file mode 100644 index 00000000..5b4a403c --- /dev/null +++ b/examples/tune/optuna/tune.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- + +"""A simple example using sklearn and Optuna support""" + +# Spock ONLY supports the define-and-run style interface from Optuna +# https://optuna.readthedocs.io/en/stable/tutorial/20_recipes/009_ask_and_tell.html#define-and-run + + +from sklearn.datasets import load_iris +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split + +from spock.addons.tune import spockTuner +from spock.addons.tune.config import ( + ChoiceHyperParameter, + OptunaTunerConfig, + RangeHyperParameter, +) +from spock.builder import ConfigArgBuilder +from spock.config import spock + + +@spock +class BasicParams: + n_trials: int + + +@spockTuner +class LogisticRegressionHP: + c: RangeHyperParameter + solver: ChoiceHyperParameter + + +def main(): + # Load the iris data + X, y = load_iris(return_X_y=True) + + # Split the Iris data + X_train, X_valid, y_train, y_valid = train_test_split(X, y) + + # Optuna config -- this will internally spawn the study object for the define-and-run style which will be returned + # as part of the call to sample() + optuna_config = OptunaTunerConfig( + study_name="Iris Logistic Regression", direction="maximize" + ) + + # Use the builder to setup + # Call tuner to indicate that we are going to do some HP tuning -- passing in an optuna study object + attrs_obj = ConfigArgBuilder( + LogisticRegressionHP, + BasicParams, + desc="Example Logistic Regression Hyper-Parameter Tuning", + ).tuner(tuner_config=optuna_config) + + # Here we need some of the fixed parameters first so we can just call the generate fnc to grab all the fixed params + # prior to starting the sampling process + fixed_params = attrs_obj.generate() + + # Now we iterate through a bunch of optuna trials + for _ in range(fixed_params.BasicParams.n_trials): + # The crux of spock support -- call save w/ the add_tuner_sample flag to write the current draw to file and + # then call save to return the composed Spockspace of the fixed parameters and the sampled parameters + # Under the hood spock uses the define-and-run Optuna interface -- thus it handled the underlying 'ask' call + # and returns the necessary trial object in the return dictionary to call 'tell' with the study object + attrs_class, tune_dict = attrs_obj.save( + add_tuner_sample=True, user_specified_path="/tmp" + ).sample() + # Use the currently sampled parameters in a simple LogisticRegression from sklearn + clf = LogisticRegression( + C=attrs_class.LogisticRegressionHP.c, + solver=attrs_class.LogisticRegressionHP.solver, + ) + clf.fit(X_train, y_train) + val_acc = clf.score(X_valid, y_valid) + # Pull the study and trials object out of the return dictionary and pass it to the tell call using the study + # object + tune_dict["study"].tell(tune_dict["trial"], val_acc) + + +if __name__ == "__main__": + main() diff --git a/examples/tune/optuna/tune.yaml b/examples/tune/optuna/tune.yaml new file mode 100644 index 00000000..335c3efe --- /dev/null +++ b/examples/tune/optuna/tune.yaml @@ -0,0 +1,14 @@ +################ +# tune.yaml +################ +BasicParams: + n_trials: 10 + +LogisticRegressionHP: + c: + 
type: float + bounds: [1E-07, 10.0] + log_scale: true + solver: + type: str + choices: ["lbfgs", "saga"] \ No newline at end of file diff --git a/examples/tutorial/advanced/basic_nn.py b/examples/tutorial/advanced/basic_nn.py index 86015cb9..cf81e88c 100644 --- a/examples/tutorial/advanced/basic_nn.py +++ b/examples/tutorial/advanced/basic_nn.py @@ -6,12 +6,16 @@ class BasicNet(nn.Module): def __init__(self, model_config): super(BasicNet, self).__init__() # Make a dictionary of activation functions to select from - self.act_fncs = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh} + self.act_fncs = {"relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh} self.use_act = self.act_fncs.get(model_config.activation)() # Define the layers manually (avoiding list comprehension for clarity) self.layer_1 = nn.Linear(model_config.n_features, model_config.hidden_sizes[0]) - self.layer_2 = nn.Linear(model_config.hidden_sizes[0], model_config.hidden_sizes[1]) - self.layer_3 = nn.Linear(model_config.hidden_sizes[1], model_config.hidden_sizes[2]) + self.layer_2 = nn.Linear( + model_config.hidden_sizes[0], model_config.hidden_sizes[1] + ) + self.layer_3 = nn.Linear( + model_config.hidden_sizes[1], model_config.hidden_sizes[2] + ) # Define some dropout layers self.dropout = [] if model_config.dropout is not None: diff --git a/examples/tutorial/advanced/tutorial.py b/examples/tutorial/advanced/tutorial.py index 13a01a3c..09c54993 100644 --- a/examples/tutorial/advanced/tutorial.py +++ b/examples/tutorial/advanced/tutorial.py @@ -1,23 +1,23 @@ -from basic_nn import BasicNet from enum import Enum +from typing import List, Optional, Tuple + +import torch +from basic_nn import BasicNet + from spock.args import SavePath from spock.builder import ConfigArgBuilder from spock.config import spock -import torch -from typing import List -from typing import Optional -from typing import Tuple class Activation(Enum): - relu = 'relu' - gelu = 'gelu' - tanh = 'tanh' + relu = "relu" + gelu = "gelu" + tanh = "tanh" class Optimizer(Enum): - sgd = 'SGD' - adam = 'Adam' + sgd = "SGD" + adam = "Adam" @spock @@ -26,7 +26,7 @@ class ModelConfig: n_features: int dropout: Optional[List[float]] hidden_sizes: Tuple[int, int, int] = (32, 32, 32) - activation: Activation = 'relu' + activation: Activation = "relu" optimizer: Optimizer cache_path: Optional[str] @@ -53,42 +53,61 @@ class SGDConfig(OptimizerConfig): def train(x_data, y_data, model, model_config, data_config, optimizer_config): - if model_config.optimizer == 'SGD': - optimizer = torch.optim.SGD(model.parameters(), lr=optimizer_config.lr, momentum=optimizer_config.momentum, - nesterov=optimizer_config.nesterov) - elif model_config.optimizer == 'Adam': + if model_config.optimizer == "SGD": + optimizer = torch.optim.SGD( + model.parameters(), + lr=optimizer_config.lr, + momentum=optimizer_config.momentum, + nesterov=optimizer_config.nesterov, + ) + elif model_config.optimizer == "Adam": optimizer = torch.optim.Adam(model.parameters(), lr=optimizer_config.lr) else: - raise ValueError(f'Optimizer choice {optimizer_config.optimizer} not available') + raise ValueError(f"Optimizer choice {optimizer_config.optimizer} not available") n_steps_per_epoch = data_config.n_samples % data_config.batch_size for epoch in range(optimizer_config.n_epochs): for i in range(n_steps_per_epoch): # Ugly data slicing for simplicity - x_batch = x_data[i * n_steps_per_epoch:(i + 1) * n_steps_per_epoch, ] - y_batch = y_data[i * n_steps_per_epoch:(i + 1) * n_steps_per_epoch, ] + x_batch = x_data[ + i * 
n_steps_per_epoch : (i + 1) * n_steps_per_epoch, + ] + y_batch = y_data[ + i * n_steps_per_epoch : (i + 1) * n_steps_per_epoch, + ] optimizer.zero_grad() output = model(x_batch) loss = torch.nn.CrossEntropyLoss(output, y_batch) loss.backward() if optimizer_config.grad_clip: - torch.nn.utils.clip_grad_value(model.parameters(), optimizer_config.grad_clip) + torch.nn.utils.clip_grad_value( + model.parameters(), optimizer_config.grad_clip + ) optimizer.step() - print(f'Finished Epoch {epoch+1}') + print(f"Finished Epoch {epoch+1}") def main(): # A simple description - description = 'spock Advanced Tutorial' + description = "spock Advanced Tutorial" # Build out the parser by passing in Spock config objects as *args after description - config = ConfigArgBuilder(ModelConfig, DataConfig, SGDConfig, desc=description).generate() + config = ConfigArgBuilder( + ModelConfig, DataConfig, SGDConfig, desc=description + ).generate() # Instantiate our neural net using basic_nn = BasicNet(model_config=config.ModelConfig) # Make some random data (BxH): H has dim of features in x_data = torch.rand(config.DataConfig.n_samples, config.ModelConfig.n_features) y_data = torch.randint(0, 3, (config.DataConfig.n_samples,)) # Run some training - train(x_data, y_data, basic_nn, config.ModelConfig, config.DataConfig, config.SGDConfig) + train( + x_data, + y_data, + basic_nn, + config.ModelConfig, + config.DataConfig, + config.SGDConfig, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/tutorial/basic/basic_nn.py b/examples/tutorial/basic/basic_nn.py index f4dc96b3..fdbe4a1c 100644 --- a/examples/tutorial/basic/basic_nn.py +++ b/examples/tutorial/basic/basic_nn.py @@ -6,12 +6,16 @@ class BasicNet(nn.Module): def __init__(self, model_config): super(BasicNet, self).__init__() # Make a dictionary of activation functions to select from - self.act_fncs = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh} + self.act_fncs = {"relu": nn.ReLU, "gelu": nn.GELU, "tanh": nn.Tanh} self.use_act = self.act_fncs.get(model_config.activation)() # Define the layers manually (avoiding list comprehension for clarity) self.layer_1 = nn.Linear(model_config.n_features, model_config.hidden_sizes[0]) - self.layer_2 = nn.Linear(model_config.hidden_sizes[0], model_config.hidden_sizes[1]) - self.layer_3 = nn.Linear(model_config.hidden_sizes[1], model_config.hidden_sizes[2]) + self.layer_2 = nn.Linear( + model_config.hidden_sizes[0], model_config.hidden_sizes[1] + ) + self.layer_3 = nn.Linear( + model_config.hidden_sizes[1], model_config.hidden_sizes[2] + ) # Define some dropout layers self.dropout_1 = nn.Dropout(model_config.dropout[0]) self.dropout_2 = nn.Dropout(model_config.dropout[1]) diff --git a/examples/tutorial/basic/tutorial.py b/examples/tutorial/basic/tutorial.py index 19c953ae..52f832b3 100644 --- a/examples/tutorial/basic/tutorial.py +++ b/examples/tutorial/basic/tutorial.py @@ -1,11 +1,12 @@ -from basic_nn import BasicNet from enum import Enum +from typing import List, Tuple + +import torch +from basic_nn import BasicNet + from spock.args import SavePath from spock.builder import ConfigArgBuilder from spock.config import spock -import torch -from typing import List -from typing import Tuple class Activation(Enum): @@ -16,9 +17,10 @@ class Activation(Enum): gelu: gelu activation tanh: tanh activation """ - relu = 'relu' - gelu = 'gelu' - tanh = 'tanh' + + relu = "relu" + gelu = "gelu" + tanh = "tanh" @spock @@ -32,6 +34,7 @@ class ModelConfig: hidden_sizes: hidden size for each layer activation: 
choice from the Activation enum of the activation function to use """ + save_path: SavePath n_features: int dropout: List[float] @@ -41,10 +44,13 @@ class ModelConfig: def main(): # A simple description - description = 'spock Basic Tutorial' + description = "spock Basic Tutorial" # Build out the parser by passing in Spock config objects as *args after description - config = ConfigArgBuilder( - ModelConfig, desc=description, create_save_path=True).save(file_extension='.toml').generate() + config = ( + ConfigArgBuilder(ModelConfig, desc=description, create_save_path=True) + .save(file_extension=".toml") + .generate() + ) # Instantiate our neural net using basic_nn = BasicNet(model_config=config.ModelConfig) # Make some random data (BxH): H has dim of features in @@ -53,5 +59,5 @@ def main(): print(result) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml index 2a1ba4df..902d2bb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,6 @@ +[tool.isort] +profile = "black" + [tool.portray] extra_dirs = ["resources"] @@ -65,6 +68,13 @@ Motivation = "docs/Motivation.md" [[tool.portray.mkdocs.nav]] [[tool.portray.mkdocs.nav."Addons"]] "S3" = "docs/addons/S3.md" + [[tool.portray.mkdocs.nav."Addons"]] + [[tool.portray.mkdocs.nav."Addons"."Hyper-Parameter Tuning"]] + "Basics" = "docs/addons/tuner/Basics.md" + [[tool.portray.mkdocs.nav."Addons"."Hyper-Parameter Tuning"]] + "Optuna" = "docs/addons/tuner/Optuna.md" + [[tool.portray.mkdocs.nav."Addons"."Hyper-Parameter Tuning"]] + "Ax" = "docs/addons/tuner/Ax.md" [[tool.portray.mkdocs.nav]] Contributing = "CONTRIBUTING.md" diff --git a/DEV_REQUIREMENTS.txt b/requirements/DEV_REQUIREMENTS.txt similarity index 65% rename from DEV_REQUIREMENTS.txt rename to requirements/DEV_REQUIREMENTS.txt index 3934641c..fc406fb0 100644 --- a/DEV_REQUIREMENTS.txt +++ b/requirements/DEV_REQUIREMENTS.txt @@ -1,6 +1,7 @@ --r REQUIREMENTS.txt +black coveralls coverage +isort moto portray pytest diff --git a/S3_REQUIREMENTS.txt b/requirements/S3_REQUIREMENTS.txt similarity index 100% rename from S3_REQUIREMENTS.txt rename to requirements/S3_REQUIREMENTS.txt diff --git a/requirements/TUNE_REQUIREMENTS.txt b/requirements/TUNE_REQUIREMENTS.txt new file mode 100644 index 00000000..2da1ac22 --- /dev/null +++ b/requirements/TUNE_REQUIREMENTS.txt @@ -0,0 +1,4 @@ +optuna==2.8.0 +torchvision +pytorch +#ax-platform \ No newline at end of file diff --git a/setup.py b/setup.py index cb474a50..74093ab0 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,32 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Spock Setup""" -from pkg_resources import parse_requirements import setuptools +from pkg_resources import parse_requirements + import versioneer -with open('README.md', 'r') as fid: +with open("README.md", "r") as fid: long_description = fid.read() -with open('REQUIREMENTS.txt', 'r') as fid: +with open("REQUIREMENTS.txt", "r") as fid: install_reqs = [str(req) for req in parse_requirements(fid)] -with open('S3_REQUIREMENTS.txt', 'r') as fid: +with open("./requirements/S3_REQUIREMENTS.txt", "r") as fid: s3_reqs = [str(req) for req in parse_requirements(fid)] +with open("./requirements/TUNE_REQUIREMENTS.txt", "r") as fid: + tune_reqs = [str(req) for req in parse_requirements(fid)] + setuptools.setup( - name='spock-config', - description='Spock is a framework designed to help manage complex parameter configurations for Python applications', + name="spock-config", + 
description="Spock is a framework designed to help manage complex parameter configurations for Python applications", long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), author="FMR LLC", @@ -33,24 +37,33 @@ "Intended Audience :: Developers", "Natural Language :: English", "License :: OSI Approved :: Apache Software License", - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "Operating System :: OS Independent", "Topic :: Scientific/Engineering", "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development :: Libraries :: Python Modules" + "Topic :: Software Development :: Libraries :: Python Modules", ], project_urls={ "Source": "https://github.com/fidelity/spock", "Documentation": "https://fidelity.github.io/spock/", - "Bug Tracker": "https://fidelity.github.io/spock/issues" + "Bug Tracker": "https://fidelity.github.io/spock/issues", }, - keywords=['configuration', 'argparse', 'parameters', 'machine learning', 'deep learning', 'reproducibility'], - packages=setuptools.find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), - python_requires='>=3.6', + keywords=[ + "configuration", + "argparse", + "parameters", + "machine learning", + "deep learning", + "reproducibility", + ], + packages=setuptools.find_packages( + exclude=["*.tests", "*.tests.*", "tests.*", "tests"] + ), + python_requires=">=3.6", install_requires=install_reqs, - extras_require={'s3': s3_reqs} + extras_require={"s3": s3_reqs, "tune": tune_reqs}, ) diff --git a/spock/__init__.py b/spock/__init__.py index 0ea44bea..46d8e091 100644 --- a/spock/__init__.py +++ b/spock/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """ @@ -13,5 +13,5 @@ __all__ = ["args", "builder", "config"] -__version__ = get_versions()['version'] -del get_versions \ No newline at end of file +__version__ = get_versions()["version"] +del get_versions diff --git a/spock/addons/__init__.py b/spock/addons/__init__.py index 854473e1..ce9724fd 100644 --- a/spock/addons/__init__.py +++ b/spock/addons/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """ @@ -8,8 +8,6 @@ Please refer to the documentation provided in the README.md """ -from spock.addons.s3.utils import S3Config -from spock.addons.s3.configs import S3DownloadConfig -from spock.addons.s3.configs import S3UploadConfig -__all__ = ["s3", "S3Config", "S3DownloadConfig", "S3UploadConfig"] + +__all__ = ["s3", "tune"] diff --git a/spock/addons/s3/__init__.py b/spock/addons/s3/__init__.py index 6927bcbe..eb6f546a 100644 --- a/spock/addons/s3/__init__.py +++ b/spock/addons/s3/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """ @@ -9,4 +9,6 @@ Please refer to the documentation provided in the README.md """ -__all__ = ["utils"] +from spock.addons.s3.configs import S3Config, 
S3DownloadConfig, S3UploadConfig + +__all__ = ["configs", "utils", "S3Config", "S3DownloadConfig", "S3UploadConfig"] diff --git a/spock/addons/s3/configs.py b/spock/addons/s3/configs.py index 01818c02..1c948676 100644 --- a/spock/addons/s3/configs.py +++ b/spock/addons/s3/configs.py @@ -1,45 +1,54 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Handles all S3 related configurations""" import attr + try: import boto3 from botocore.client import BaseClient from s3transfer.manager import TransferManager except ImportError: - print('Missing libraries to support S3 functionality. Please re-install spock with the extra s3 dependencies -- ' - 'pip install spock-config[s3]') -import typing - + print( + "Missing libraries to support S3 functionality. Please re-install spock with the extra s3 dependencies -- " + "pip install spock-config[s3]" + ) +from typing import Optional # Iterate through the allowed download args for S3 and map into optional attr.ib download_attrs = { val: attr.ib( default=None, type=str, - validator=attr.validators.optional(attr.validators.instance_of(str)) - ) for val in TransferManager.ALLOWED_DOWNLOAD_ARGS} + validator=attr.validators.optional(attr.validators.instance_of(str)), + ) + for val in TransferManager.ALLOWED_DOWNLOAD_ARGS +} # Make the class dynamically -S3DownloadConfig = attr.make_class(name="S3DownloadConfig", attrs=download_attrs, kw_only=True, frozen=True) +S3DownloadConfig = attr.make_class( + name="S3DownloadConfig", attrs=download_attrs, kw_only=True, frozen=True +) # Iterate through the allowed upload args for S3 and map into optional attr.ib upload_attrs = { val: attr.ib( default=None, type=str, - validator=attr.validators.optional(attr.validators.instance_of(str)) - ) for val in TransferManager.ALLOWED_UPLOAD_ARGS + validator=attr.validators.optional(attr.validators.instance_of(str)), + ) + for val in TransferManager.ALLOWED_UPLOAD_ARGS } # Make the class dynamically -S3UploadConfig = attr.make_class(name="S3UploadConfig", attrs=upload_attrs, kw_only=True, frozen=True) +S3UploadConfig = attr.make_class( + name="S3UploadConfig", attrs=upload_attrs, kw_only=True, frozen=True +) @attr.s(auto_attribs=True) @@ -56,13 +65,14 @@ class S3Config: upload_config: S3UploadConfig for extra upload configs (optional) """ + session: boto3.Session # s3_session: BaseClient = attr.ib(init=False) - s3_session: typing.Optional[BaseClient] = None - temp_folder: typing.Optional[str] = '/tmp/' + s3_session: Optional[BaseClient] = None + temp_folder: Optional[str] = "/tmp/" download_config: S3DownloadConfig = S3DownloadConfig() upload_config: S3UploadConfig = S3UploadConfig() def __attrs_post_init__(self): if self.s3_session is None: - self.s3_session = self.session.client('s3') + self.s3_session = self.session.client("s3") diff --git a/spock/addons/s3/utils.py b/spock/addons/s3/utils.py index b463f57b..78678bf9 100644 --- a/spock/addons/s3/utils.py +++ b/spock/addons/s3/utils.py @@ -1,25 +1,28 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Handles all S3 related ops -- allows for s3 functionality to be optional to keep req deps light""" import attr + try: import boto3 from botocore.client import BaseClient except ImportError: - print('Missing libraries to support S3 functionality. 
Please re-install spock with the extra s3 dependencies -- ' - 'pip install spock-config[s3]') -from hurry.filesize import size + print( + "Missing libraries to support S3 functionality. Please re-install spock with the extra s3 dependencies -- " + "pip install spock-config[s3]" + ) import os -from urllib.parse import urlparse -from spock.addons.s3.configs import S3Config -from spock.addons.s3.configs import S3DownloadConfig -from spock.addons.s3.configs import S3UploadConfig import sys import typing +from urllib.parse import urlparse + +from hurry.filesize import size + +from spock.addons.s3.configs import S3Config, S3DownloadConfig, S3UploadConfig def handle_s3_load_path(path: str, s3_config: S3Config) -> str: @@ -39,15 +42,20 @@ def handle_s3_load_path(path: str, s3_config: S3Config) -> str: """ if s3_config is None: - raise ValueError('Load from S3 -- Missing S3Config object which is necessary to handle S3 style paths') + raise ValueError( + "Load from S3 -- Missing S3Config object which is necessary to handle S3 style paths" + ) bucket, obj, fid = get_s3_bucket_object_name(s3_path=path) # Construct the full temp path - temp_path = f'{s3_config.temp_folder}/{fid}' + temp_path = f"{s3_config.temp_folder}/{fid}" # Strip double slashes if exist - temp_path = temp_path.replace(r'//', r'/') + temp_path = temp_path.replace(r"//", r"/") temp_path = download_s3( - bucket=bucket, obj=obj, temp_path=temp_path, s3_session=s3_config.s3_session, - download_config=s3_config.download_config + bucket=bucket, + obj=obj, + temp_path=temp_path, + s3_session=s3_config.s3_session, + download_config=s3_config.download_config, ) return temp_path @@ -68,13 +76,18 @@ def handle_s3_save_path(temp_path: str, s3_path: str, name: str, s3_config: S3Co """ if s3_config is None: - raise ValueError('Save to S3 -- Missing S3Config object which is necessary to handle S3 style paths') + raise ValueError( + "Save to S3 -- Missing S3Config object which is necessary to handle S3 style paths" + ) # Fix posix strip - s3_path = s3_path.replace('s3:/', 's3://') - bucket, obj, fid = get_s3_bucket_object_name(f'{s3_path}/{name}') + s3_path = s3_path.replace("s3:/", "s3://") + bucket, obj, fid = get_s3_bucket_object_name(f"{s3_path}/{name}") upload_s3( - bucket=bucket, obj=obj, temp_path=temp_path, - s3_session=s3_config.s3_session, upload_config=s3_config.upload_config + bucket=bucket, + obj=obj, + temp_path=temp_path, + s3_session=s3_config.s3_session, + upload_config=s3_config.upload_config, ) @@ -93,11 +106,16 @@ def get_s3_bucket_object_name(s3_path: str) -> typing.Tuple[str, str, str]: """ parsed = urlparse(s3_path) - return parsed.netloc, parsed.path.lstrip('/'), os.path.basename(parsed.path) + return parsed.netloc, parsed.path.lstrip("/"), os.path.basename(parsed.path) -def download_s3(bucket: str, obj: str, temp_path: str, s3_session: BaseClient, - download_config: S3DownloadConfig) -> str: +def download_s3( + bucket: str, + obj: str, + temp_path: str, + s3_session: BaseClient, + download_config: S3DownloadConfig, +) -> str: """Attempts to download the file from the S3 uri to a temp location using any extra arguments to the download *Args*: @@ -115,9 +133,13 @@ def download_s3(bucket: str, obj: str, temp_path: str, s3_session: BaseClient, """ try: # Unroll the extra options for those values that are not None - extra_options = {k: v for k, v in attr.asdict(download_config).items() if v is not None} - file_size = s3_session.head_object(Bucket=bucket, Key=obj, **extra_options)['ContentLength'] - print(f'Attempting to download 
s3://{bucket}/{obj} (size: {size(file_size)})') + extra_options = { + k: v for k, v in attr.asdict(download_config).items() if v is not None + } + file_size = s3_session.head_object(Bucket=bucket, Key=obj, **extra_options)[ + "ContentLength" + ] + print(f"Attempting to download s3://{bucket}/{obj} (size: {size(file_size)})") current_progress = 0 n_ticks = 50 @@ -126,21 +148,34 @@ def _s3_progress_bar(chunk): # Increment progress current_progress += chunk done = int(n_ticks * (current_progress / file_size)) - sys.stdout.write(f"\r[%s%s] " - f"{int(current_progress/file_size) * 100}%%" % ('=' * done, ' ' * (n_ticks - done))) + sys.stdout.write( + f"\r[%s%s] " + f"{int(current_progress/file_size) * 100}%%" + % ("=" * done, " " * (n_ticks - done)) + ) sys.stdout.flush() - sys.stdout.write('\n\n') + sys.stdout.write("\n\n") + # Download with the progress callback - s3_session.download_file(bucket, obj, temp_path, Callback=_s3_progress_bar, ExtraArgs=extra_options) + s3_session.download_file( + bucket, obj, temp_path, Callback=_s3_progress_bar, ExtraArgs=extra_options + ) return temp_path except IOError: - print(f'Failed to download file from S3 ' - f'(bucket: {bucket}, object: {obj}) ' - f'and write to {temp_path}') - - -def upload_s3(bucket: str, obj: str, temp_path: str, s3_session: BaseClient, - upload_config: S3UploadConfig): + print( + f"Failed to download file from S3 " + f"(bucket: {bucket}, object: {obj}) " + f"and write to {temp_path}" + ) + + +def upload_s3( + bucket: str, + obj: str, + temp_path: str, + s3_session: BaseClient, + upload_config: S3UploadConfig, +): """Attempts to upload the local file to the S3 uri using any extra arguments to the upload *Args*: @@ -156,9 +191,11 @@ def upload_s3(bucket: str, obj: str, temp_path: str, s3_session: BaseClient, """ try: # Unroll the extra options for those values that are not None - extra_options = {k: v for k, v in attr.asdict(upload_config).items() if v is not None} + extra_options = { + k: v for k, v in attr.asdict(upload_config).items() if v is not None + } file_size = os.path.getsize(temp_path) - print(f'Attempting to upload s3://{bucket}/{obj} (size: {size(file_size)})') + print(f"Attempting to upload s3://{bucket}/{obj} (size: {size(file_size)})") current_progress = 0 n_ticks = 50 @@ -167,13 +204,21 @@ def _s3_progress_bar(chunk): # Increment progress current_progress += chunk done = int(n_ticks * (current_progress / file_size)) - sys.stdout.write(f"\r[%s%s] " - f"{int(current_progress/file_size) * 100}%%" % ('=' * done, ' ' * (n_ticks - done))) + sys.stdout.write( + f"\r[%s%s] " + f"{int(current_progress/file_size) * 100}%%" + % ("=" * done, " " * (n_ticks - done)) + ) sys.stdout.flush() - sys.stdout.write('\n\n') + sys.stdout.write("\n\n") + # Upload with progress callback - s3_session.upload_file(temp_path, bucket, obj, Callback=_s3_progress_bar, ExtraArgs=extra_options) + s3_session.upload_file( + temp_path, bucket, obj, Callback=_s3_progress_bar, ExtraArgs=extra_options + ) except IOError: - print(f'Failed to upload file to S3 ' - f'(bucket: {bucket}, object: {obj}) ' - f'from {temp_path}') + print( + f"Failed to upload file to S3 " + f"(bucket: {bucket}, object: {obj}) " + f"from {temp_path}" + ) diff --git a/spock/backend/attr/__init__.py b/spock/addons/tune/__init__.py similarity index 61% rename from spock/backend/attr/__init__.py rename to spock/addons/tune/__init__.py index 1cd0336a..1bdca667 100644 --- a/spock/backend/attr/__init__.py +++ b/spock/addons/tune/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# 
Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """ @@ -9,4 +9,6 @@ Please refer to the documentation provided in the README.md """ -__all__ = ["builder", "config", "payload", "saver", "typed"] +from spock.addons.tune.config import spockTuner + +__all__ = ["builder", "config", "spockTuner"] diff --git a/spock/addons/tune/builder.py b/spock/addons/tune/builder.py new file mode 100644 index 00000000..dd585549 --- /dev/null +++ b/spock/addons/tune/builder.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the tuner builder backend""" + +from spock.backend.builder import BaseBuilder +from spock.utils import make_argument + + +class TunerBuilder(BaseBuilder): + def __init__(self, *args, **kwargs): + """TunerBuilder init + + Args: + *args: list of input classes that link to a backend + configs: None or List of configs to read from + desc: description for the arg parser + no_cmd_line: flag to force no command line reads + **kwargs: any extra keyword args + """ + super().__init__(*args, module_name="spock.addons.tune.config", **kwargs) + + def _handle_arguments(self, args, class_obj): + """Ovverides base -- Handles all argument mapping + + Creates a dictionary of named parameters that are mapped to the final type of object + + *Args*: + + args: read file arguments + class_obj: instance of a class obj + + *Returns*: + + fields: dictionary of mapped parameters + + """ + attr_name = class_obj.__name__ + fields = { + val.name: val.type(**args[attr_name][val.name]) + for val in class_obj.__attrs_attrs__ + } + return fields + + @staticmethod + def _make_group_override_parser(parser, class_obj, class_name): + """Makes a name specific override parser for a given class obj + + Takes a class object of the backend and adds a new argument group with argument names given with name + Class.val.(unrolled config parameters) so that individual parameters specific to a class can be overridden. 
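+ For example, given the `LogisticRegressionHP` class from the Optuna example in this PR, the range hyper-parameter `c` would expose overrides such as `--LogisticRegressionHP.c.bounds` and `--LogisticRegressionHP.c.log_scale` (class and attribute names are illustrative, taken from examples/tune/optuna).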
+ + *Args*: + + parser: argument parser + class_obj: instance of a backend class + class_name: used for module matching + + *Returns*: + + parser: argument parser with new class specific overrides + + """ + attr_name = class_obj.__name__ + group_parser = parser.add_argument_group( + title=str(attr_name) + " Specific Overrides" + ) + for val in class_obj.__attrs_attrs__: + val_type = val.metadata["type"] if "type" in val.metadata else val.type + for arg in val_type.__attrs_attrs__: + arg_name = f"--{str(attr_name)}.{val.name}.{arg.name}" + group_parser = make_argument(arg_name, arg.type, group_parser) + return parser + + def _extract_fnc(self, val, module_name): + return self._extract_other_types(val.type, module_name) diff --git a/spock/addons/tune/config.py b/spock/addons/tune/config.py new file mode 100644 index 00000000..5974f789 --- /dev/null +++ b/spock/addons/tune/config.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Creates the spock config interface that wraps attr -- tune version for hyper-parameters""" +import sys +from typing import List, Optional, Sequence, Tuple, Union + +import attr +import optuna + +from spock.backend.config import _base_attr + + +@attr.s(auto_attribs=True) +class OptunaTunerConfig: + storage: Optional[Union[str, optuna.storages.BaseStorage]] = None + sampler: Optional[optuna.samplers.BaseSampler] = None + pruner: Optional[optuna.pruners.BasePruner] = None + study_name: Optional[str] = None + direction: Optional[Union[str, optuna.study.StudyDirection]] = None + load_if_exists: bool = False + directions: Optional[Sequence[Union[str, optuna.study.StudyDirection]]] = None + + +def _spock_tune(cls): + """Ovverides basic spock_attr decorator with another name + + Using a different name allows spock to easily determine which parameters are normal and which are + meant to be used in a hyper-parameter tuning backend + + *Args*: + + cls: basic class def + + *Returns*: + + cls: slotted attrs class that is frozen and kw only + """ + bases, attrs_dict = _base_attr(cls) + # Dynamically make an attr class + obj = attr.make_class( + name=cls.__name__, bases=bases, attrs=attrs_dict, kw_only=True, frozen=True + ) + # For each class we dynamically create we need to register it within the system modules for pickle to work + setattr(sys.modules["spock"].addons.tune.config, obj.__name__, obj) + # Swap the __doc__ string from cls to obj + obj.__doc__ = cls.__doc__ + return obj + + +# Make the alias for the decorator +spockTuner = _spock_tune + + +@attr.s +class RangeHyperParameter: + """Range based hyper-parameter that is sampled uniformly + + Attributes: + type: type of the hyper-parameter (note: spock will attempt to autocast into this type) + bounds: min and max of the hyper-parameter range + log_scale: log scale the values before sampling + + """ + + type = attr.ib( + type=str, + validator=[ + attr.validators.instance_of(str), + attr.validators.in_(["float", "int"]), + ], + ) + bounds = attr.ib( + type=Union[Tuple[float, float], Tuple[int, int]], + validator=attr.validators.deep_iterable( + member_validator=attr.validators.instance_of((float, int)), + iterable_validator=attr.validators.instance_of(tuple), + ), + ) + log_scale = attr.ib(type=bool, validator=attr.validators.instance_of(bool)) + + +@attr.s +class ChoiceHyperParameter: + """Choice based hyper-parameter that is sampled uniformly + + Attributes: + type: type of the hyper-parameter -- (note: spock will attempt to autocast into this type) + choices: 
list of variable length that contains all the possible choices to select from + + """ + + type = attr.ib( + type=str, + validator=[ + attr.validators.instance_of(str), + attr.validators.in_(["float", "int", "str", "bool"]), + ], + ) + choices = attr.ib( + type=Union[List[str], List[int], List[float], List[bool]], + validator=attr.validators.deep_iterable( + member_validator=attr.validators.instance_of((float, int, bool, str)), + iterable_validator=attr.validators.instance_of(list), + ), + ) diff --git a/spock/addons/tune/interface.py b/spock/addons/tune/interface.py new file mode 100644 index 00000000..373109db --- /dev/null +++ b/spock/addons/tune/interface.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the base interface""" +from abc import ABC, abstractmethod +from typing import Dict + +import attr + +from spock.backend.wrappers import Spockspace + + +class BaseInterface(ABC): + def __init__(self, tuner_config, tuner_namespace: Spockspace): + """Base init call that maps a few variables + + *Args*: + + _tuner_config: necessary object to determine the interface and sample correctly from the underlying library + _tuner_namespace: tuner namespace that has attr classes that maps to an underlying library types + + """ + + self._tuner_config = { + k: v for k, v in attr.asdict(tuner_config).items() if v is not None + } + self._tuner_namespace = tuner_namespace + + @abstractmethod + def sample(self): + """Calls the underlying library sample to get a single sample/draw from the hyper-parameter + sets (e.g. ranges, choices) + + *Returns*: + + Spockspace of the current hyper-parameter draw + dictionary of any extra returns needed to use for the underlying hyper-parameter library + + """ + pass + + @abstractmethod + def _construct(self): + """Constructs the base object needed by the underlying library to construct the correct object that allows + for hyper-parameter sampling + + *Returns*: + + Any typed object needed for support + + """ + pass + + @staticmethod + def _gen_attr_classes(tune_dict: Dict): + for k, v in tune_dict.items(): + attrs_dict = { + ik: attr.ib( + validator=attr.validators.instance_of(type(iv)), type=type(iv) + ) + for ik, iv in v.items() + } + obj = attr.make_class(name=k, attrs=attrs_dict, kw_only=True, frozen=True) + tune_dict.update({k: obj(**v)}) + return tune_dict + + @staticmethod + def _to_spockspace(tune_dict: Dict): + """Converts a dict to a Spockspace + + *Args*: + + tune_dict: current dictionary + + *Returns*: + + Spockspace of dict + + """ + return Spockspace(**tune_dict) + + @staticmethod + def _get_caster(val): + """Gets a callable type object from a string type + + *Args*: + + val: current attr val: + + *Returns*: + + type class object + + """ + return __builtins__[val.type] diff --git a/spock/addons/tune/optuna.py b/spock/addons/tune/optuna.py new file mode 100644 index 00000000..58e6134e --- /dev/null +++ b/spock/addons/tune/optuna.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the optuna backend""" + +import attr +import optuna + +from spock.addons.tune.config import OptunaTunerConfig +from spock.addons.tune.interface import BaseInterface + + +class OptunaInterface(BaseInterface): + """Specific override to support the optuna backend + + *Attributes*: + + _map_type: dictionary that maps class names and types to fns that create optuna distributions + _tuner_obj: necessary object to determine the interface and sample 
correctly from the underlying library + _tuner_namespace: tuner namespace that has attr classes that maps to an underlying library types + _param_obj: underlying object that optuna study can sample from (flat dictionary) + + """ + + def __init__(self, tuner_config: OptunaTunerConfig, tuner_namespace): + """OptunaInterface init call that maps variables, creates a map to fnc calls, and constructs the necessary + underlying objects + + *Args*: + + tuner_config: necessary object to determine the interface and sample correctly from the underlying library + tuner_namespace: tuner namespace that has attr classes that maps to an underlying library types + + """ + super(OptunaInterface, self).__init__(tuner_config, tuner_namespace) + self._tuner_obj = optuna.create_study(**self._tuner_config) + # Mapping spock underlying classes to optuna distributions (define-and-run interface) + self._map_type = { + "RangeHyperParameter": { + "int": self._uniform_int_dist, + "float": self._uniform_float_dist, + }, + "ChoiceHyperParameter": { + "int": self._categorical_dist, + "float": self._categorical_dist, + "str": self._categorical_dist, + "bool": self._categorical_dist, + }, + } + # Build the correct underlying dictionary object for Optuna + self._param_obj = self._construct() + + def sample(self): + trial = self._tuner_obj.ask(self._param_obj) + # Roll this back out into a Spockspace so it can be merged into the fixed parameter Spockspace + # Also need to un-dot the param names to rebuild the nested structure + key_set = {k.split(".")[0] for k in trial.params.keys()} + rollup_dict = {val: {} for val in key_set} + for k, v in trial.params.items(): + split_names = k.split(".") + rollup_dict[split_names[0]].update({split_names[1]: v}) + return self._to_spockspace(self._gen_attr_classes(rollup_dict)), { + "trial": trial, + "study": self._tuner_obj, + } + + def _construct(self): + """Constructs the base object needed by the underlying library to construct the correct object that allows + for hyper-parameter sampling + + *Returns*: + + flat dictionary of all hyper-parameters named with dot notation (class.param_name) + + """ + optuna_dict = {} + # These will only be nested one level deep given the tuner syntax + for k, v in vars(self._tuner_namespace).items(): + for ik, iv in vars(v).items(): + param_fn = self._map_type[type(iv).__name__][iv.type] + optuna_dict.update({f"{k}.{ik}": param_fn(iv)}) + return optuna_dict + + @staticmethod + def _uniform_float_dist(val): + """Assemble the optuna.distributions.(Log)UniformDistribution object + + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.UniformDistribution.html + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.LogUniformDistribution.html + + *Args*: + + val: current attr val + + *Returns*: + + optuna.distributions.UniformDistribution or optuna.distributions.LogUniformDistribution + + """ + try: + low = float(val.bounds[0]) + high = float(val.bounds[1]) + except TypeError: + print( + f"Attempted to cast into type: {val.type} but failed -- check the inputs to RangeHyperParameter" + ) + log_scale = val.log_scale + return ( + optuna.distributions.LogUniformDistribution(low=low, high=high) + if log_scale + else optuna.distributions.UniformDistribution(low=low, high=high) + ) + + @staticmethod + def _uniform_int_dist(val): + """Assemble the optuna.distributions.Int(Log)UniformDistribution object + + 
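+        As an illustrative example (the parameter values are hypothetical), a RangeHyperParameter
+        with type 'int', bounds (16, 256), and log_scale False would be mapped here to
+        optuna.distributions.IntUniformDistribution(low=16, high=256), while log_scale True would
+        instead yield IntLogUniformDistribution(low=16, high=256).
+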
https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntLogUniformDistribution.html + + *Args*: + + val: current attr val + + *Returns*: + + optuna.distributions.IntUniformDistribution or optuna.distributions.IntLogUniformDistribution + + """ + try: + low = int(val.bounds[0]) + high = int(val.bounds[1]) + except TypeError: + print( + f"Attempted to cast into type: {val.type} but failed -- check the inputs to RangeHyperParameter" + ) + log_scale = val.log_scale + return ( + optuna.distributions.IntLogUniformDistribution(low=low, high=high) + if log_scale + else optuna.distributions.IntUniformDistribution(low=low, high=high) + ) + + def _categorical_dist(self, val): + """Assemble the optuna.distributions.CategoricalDistribution object + + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html + + *Args*: + + val: current attr val + + *Returns*: + + optuna.distributions.CategoricalDistribution + + """ + caster = self._get_caster(val) + # Just attempt to cast in a try except + try: + val.choices = [caster(v) for v in val.choices] + except TypeError: + print( + f"Attempted to cast into type: {val.type} but failed -- check the inputs to ChoiceHyperParameter" + ) + return optuna.distributions.CategoricalDistribution(choices=val.choices) diff --git a/spock/addons/tune/payload.py b/spock/addons/tune/payload.py new file mode 100644 index 00000000..85dcfa30 --- /dev/null +++ b/spock/addons/tune/payload.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the tuner payload backend""" + +from spock.backend.payload import BasePayload + + +class TunerPayload(BasePayload): + """Handles building the payload for tuners + + This class builds out the payload from config files of multiple types. 
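+
+    As a purely illustrative example (the class and parameter names are hypothetical), a YAML block
+    such as TrainHP: {lr: {type: float, bounds: [1.0e-4, 1.0e-1], log_scale: true}} would be read in
+    here, with the bounds list coerced to a tuple by _update_payload.
+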
It handles various + file types and also composition of config files via a recursive calls + + *Attributes*: + + _loaders: maps of each file extension to the loader class + + """ + + def __init__(self, s3_config=None): + """Init for TunerPayload + + *Args*: + + s3_config: optional S3 config object + + """ + super().__init__(s3_config=s3_config) + + def __call__(self, *args, **kwargs): + """Call to allow self chaining + + *Args*: + + *args: + **kwargs: + + *Returns*: + + Payload: instance of self + + """ + return TunerPayload(*args, **kwargs) + + @staticmethod + def _update_payload(base_payload, input_classes, ignore_classes, payload): + # Get the ignore fields + ignore_fields = { + attr.__name__: [val.name for val in attr.__attrs_attrs__] + for attr in ignore_classes + } + for k, v in base_payload.items(): + if k not in ignore_fields: + for ik, iv in v.items(): + if "bounds" in iv: + iv["bounds"] = tuple(iv["bounds"]) + return base_payload + + @staticmethod + def _handle_payload_override(payload, key, value): + key_split = key.split(".") + curr_ref = payload + for idx, split in enumerate(key_split): + # If the root isn't in the payload then it needs to be added but only for the first key split + if idx == 0 and (split not in payload): + payload.update({split: {}}) + # Check if it's the last value and figure out the override + if idx == (len(key_split) - 1): + # Handle bool(s) a bit differently as they are store_true + if isinstance(curr_ref, dict) and isinstance(value, bool): + if value is not False: + curr_ref[split] = value + # If we are at the dictionary level we should be able to just payload override + elif isinstance(curr_ref, dict) and not isinstance(value, bool): + curr_ref[split] = value + else: + raise ValueError( + f"cmd-line override failed for {key} -- " + f"Failed to find key {split} within lowest level Dict" + ) + # If it's not keep walking the current payload + else: + curr_ref = curr_ref[split] + return payload diff --git a/spock/addons/tune/tuner.py b/spock/addons/tune/tuner.py new file mode 100644 index 00000000..7a13e9b7 --- /dev/null +++ b/spock/addons/tune/tuner.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the tuner interface interface""" + +from typing import Union + +from spock.addons.tune.config import OptunaTunerConfig +from spock.addons.tune.optuna import OptunaInterface +from spock.backend.wrappers import Spockspace + + +class TunerInterface: + """Handles the general tuner interface by creating the necessary underlying tuner class and dispatches necessary + ops to the class instance + + *Attributes*: + + _fixed_namespace: fixed parameter namespace used for combination with a sample draw + _lib_interface: class instance of the underlying hyper-parameter library + + """ + + def __init__( + self, + tuner_config: Union[OptunaTunerConfig], + tuner_namespace: Spockspace, + fixed_namespace: Spockspace, + ): + """Init call to the TunerInterface + + *Args*: + + tuner_config: necessary object to determine the interface and sample correctly from the underlying library + tuner_namespace: tuner namespace that has attr classes that maps to an underlying library types + fixed_namespace: namespace of fixed parameters + + """ + self._fixed_namespace = fixed_namespace + # Todo: add ax type check here + accept_types = OptunaTunerConfig + if not isinstance(tuner_config, accept_types): + raise ValueError( + f"Passed incorrect tuner_config type of {type(tuner_config)} -- must be of type " + 
f"{repr(accept_types)}" + ) + if isinstance(tuner_config, OptunaTunerConfig): + self._lib_interface = OptunaInterface( + tuner_config=tuner_config, tuner_namespace=tuner_namespace + ) + # # TODO: Add ax class logic + # elif isinstance(tuner_config, (ax.Experiment, ax.SimpleExperiment)): + # pass + + def sample(self): + """Public interface to underlying library sepcific sample that returns a single sample/draw from the + hyper-parameter sets (e.g. ranges, choices) and combines them with the fixed parameters into a single Spockspace + + *Returns*: + + Spockspace of drawn sample of hyper-parameters and fixed parameters + + """ + curr_sample, extra_dict = self._lib_interface.sample() + # Merge w/ fixed parameters + return ( + Spockspace(**vars(curr_sample), **vars(self._fixed_namespace)), + extra_dict, + ) diff --git a/spock/args.py b/spock/args.py index e98032b4..d0e0109e 100644 --- a/spock/args.py +++ b/spock/args.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Handles import aliases to allow backwards compat with backends""" # from spock.backend.dataclass.args import * -from spock.backend.attr.typed import SavePath +from spock.backend.typed import SavePath diff --git a/spock/backend/__init__.py b/spock/backend/__init__.py index d8767c44..5d4e8add 100644 --- a/spock/backend/__init__.py +++ b/spock/backend/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """ @@ -8,4 +8,5 @@ Please refer to the documentation provided in the README.md """ -__all__ = ["attr", "base"] + +__all__ = ["builder", "config", "payload", "saver", "typed"] diff --git a/spock/backend/attr/builder.py b/spock/backend/attr/builder.py deleted file mode 100644 index 9b76ee7d..00000000 --- a/spock/backend/attr/builder.py +++ /dev/null @@ -1,125 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 FMR LLC -# SPDX-License-Identifier: Apache-2.0 - -"""Handles the building/saving of the configurations from the Spock config classes""" - -import attr -from enum import EnumMeta -import re -import sys -from warnings import warn -from spock.backend.base import BaseBuilder - - -class AttrBuilder(BaseBuilder): - """Attr specific builder - - Class that handles building for the attr backend - - *Attributes* - - input_classes: list of input classes that link to a backend - _configs: None or List of configs to read from - _create_save_path: boolean to make the path to save to - _desc: description for the arg parser - _no_cmd_line: flag to force no command line reads - save_path: list of path(s) to save the configs to - - """ - def __init__(self, *args, configs=None, create_save_path=False, desc='', no_cmd_line=False, **kwargs): - super().__init__(*args, configs=configs, create_save_path=create_save_path, desc=desc, - no_cmd_line=no_cmd_line, **kwargs) - for arg in self.input_classes: - if not attr.has(arg): - raise TypeError('*arg inputs to ConfigArgBuilder must all be class instances with attrs attributes') - - def print_usage_and_exit(self, msg=None, sys_exit=True, exit_code=1): - print(f'usage: {sys.argv[0]} -c [--config] config1 [config2, config3, ...]') - print(f'\n{self._desc if self._desc != "" else ""}\n') - print('configuration(s):\n') - self._handle_help_info() - if msg is not None: - print(msg) - if sys_exit: - sys.exit(exit_code) - - def _handle_help_info(self): - self._attrs_help(self.input_classes) - - def _handle_arguments(self, args, class_obj): - attr_name = 
class_obj.__name__ - class_names = [val.__name__ for val in self.input_classes] - # Handle repeated classes - if attr_name in class_names and attr_name in args and isinstance(args[attr_name], list): - fields = self._handle_repeated(args[attr_name], attr_name, class_names) - # Handle non-repeated classes - else: - fields = {} - for val in class_obj.__attrs_attrs__: - # Check if namespace is named and then check for key -- checking for local class def - if attr_name in args and val.name in args[attr_name]: - fields[val.name] = self._handle_nested_class(args, args[attr_name][val.name], class_names) - # If not named then just check for keys -- checking for global def - elif val.name in args: - fields[val.name] = self._handle_nested_class(args, args[val.name], class_names) - # Check for special keys to set - if 'special_key' in val.metadata and val.metadata['special_key'] is not None: - if val.name in args: - self.save_path = args[val.name] - elif val.default is not None: - self.save_path = val.default - return fields - - def _handle_repeated(self, args, check_value, class_names): - """Handles repeated classes as lists - - *Args*: - - args: dictionary of arguments from the configs - check_value: value to check classes against - class_names: current class names - - *Returns*: - - list of input_class[match)idx[0]] types filled with repeated values - - """ - # Check to see if the value trying to be set is actually an input class - match_idx = [idx for idx, val in enumerate(class_names) if val == check_value] - return [self.input_classes[match_idx[0]](**val) for val in args] - - def _handle_nested_class(self, args, check_value, class_names): - """Handles passing another class to the field dictionary - - *Args*: - args: dictionary of arguments from the configs - check_value: value to check classes against - class_names: current class names - - *Returns*: - - either the check_value or the necessary class - - """ - # Check to see if the value trying to be set is actually an input class - match_idx = [idx for idx, val in enumerate(class_names) if val == check_value] - # If so then create the needed class object by unrolling the args to **kwargs and return it - if len(match_idx) > 0: - if len(match_idx) > 1: - raise ValueError('Match error -- multiple classes with the same name definition') - else: - if args.get(self.input_classes[match_idx[0]].__name__) is None: - raise ValueError(f'Missing config file definition for the referenced class ' - f'{self.input_classes[match_idx[0]].__name__}') - current_arg = args.get(self.input_classes[match_idx[0]].__name__) - if isinstance(current_arg, list): - class_value = [self.input_classes[match_idx[0]](**val) for val in current_arg] - else: - class_value = self.input_classes[match_idx[0]](**current_arg) - return_value = class_value - # else return the expected value - else: - return_value = check_value - return return_value diff --git a/spock/backend/attr/payload.py b/spock/backend/attr/payload.py deleted file mode 100644 index 4c8c6dd0..00000000 --- a/spock/backend/attr/payload.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 FMR LLC -# SPDX-License-Identifier: Apache-2.0 - -"""Handles payloads from markup files""" - -from itertools import chain -from spock.backend.attr.utils import convert_to_tuples -from spock.backend.attr.utils import get_type_fields -from spock.backend.attr.utils import deep_update -from spock.backend.base import BasePayload - - -class AttrPayload(BasePayload): - """Handles building the payload for attrs backend - 
- This class builds out the payload from config files of multiple types. It handles various - file types and also composition of config files via a recursive calls - - *Attributes*: - - _loaders: maps of each file extension to the loader class - - """ - def __init__(self, s3_config=None): - super().__init__(s3_config=s3_config) - - def __call__(self, *args, **kwargs): - """Call to allow self chaining - - *Args*: - - *args: - **kwargs: - - *Returns*: - - Payload: instance of self - - """ - return AttrPayload(*args, **kwargs) - - @staticmethod - def _update_payload(base_payload, input_classes, payload): - # Get basic args - attr_fields = {attr.__name__: [val.name for val in attr.__attrs_attrs__] for attr in input_classes} - # Class names - class_names = [val.__name__ for val in input_classes] - # Parse out the types if generic - type_fields = get_type_fields(input_classes) - for keys, values in base_payload.items(): - # check if the keys, value pair is expected by the attr class - if keys != 'config': - # Dict infers that we are overriding a global setting in a specific config - if isinstance(values, dict): - # we're in a namespace - # Check for incorrect specific override of global def - if keys not in attr_fields: - raise TypeError(f'Referring to a class space {keys} that is undefined') - for i_keys in values.keys(): - if i_keys not in attr_fields[keys]: - raise ValueError(f'Provided an unknown argument named {keys}.{i_keys}') - else: - # Check if the key is actually a reference to another class - if keys in class_names: - if isinstance(values, list): - # Check for incorrect specific override of global def - if keys not in attr_fields: - raise ValueError(f'Referring to a class space {keys} that is undefined') - # We are in a repeated class def - # Raise if the key set is different from the defined set (i.e. 
incorrect arguments) - key_set = set(list(chain(*[list(val.keys()) for val in values]))) - for i_keys in key_set: - if i_keys not in attr_fields[keys]: - raise ValueError(f'Provided an unknown argument named {keys}.{i_keys}') - # Chain all the values from multiple spock classes into one list - elif keys not in list(chain(*attr_fields.values())): - raise ValueError(f'Provided an unknown argument named {keys}') - # Chain all the values from multiple spock classes into one list - elif keys not in list(chain(*attr_fields.values())): - raise ValueError(f'Provided an unknown argument named {keys}') - if keys in payload and isinstance(values, dict): - payload[keys].update(values) - else: - payload[keys] = values - tuple_payload = convert_to_tuples(payload, type_fields, class_names) - payload = deep_update(payload, tuple_payload) - return payload diff --git a/spock/backend/attr/saver.py b/spock/backend/attr/saver.py deleted file mode 100644 index a0c45d06..00000000 --- a/spock/backend/attr/saver.py +++ /dev/null @@ -1,93 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 FMR LLC -# SPDX-License-Identifier: Apache-2.0 - -"""Handles prepping and saving the Spock config""" - -import attr -from spock.backend.base import BaseSaver - - -class AttrSaver(BaseSaver): - """Base class for saving configs for the attrs backend - - Contains methods to build a correct output payload and then writes to file based on the file - extension - - *Attributes*: - - _writers: maps file extension to the correct i/o handler - - """ - def __init__(self, s3_config=None): - super().__init__(s3_config=s3_config) - - def __call__(self, *args, **kwargs): - return AttrSaver(*args, **kwargs) - - def _clean_up_values(self, payload, file_extension): - # Dictionary to recursively write to - out_dict = {} - # All of the classes are defined at the top level - all_spock_cls = set(vars(payload).keys()) - out_dict = self._recursively_handle_clean(payload, out_dict, all_cls=all_spock_cls) - # Convert values - clean_dict = self._clean_output(out_dict) - return clean_dict - - def _recursively_handle_clean(self, payload, out_dict, parent_name=None, all_cls=None): - """Recursively works through spock classes and adds clean data to a dictionary - - Given a payload (Spockspace) work recursively through items that don't have parents to catch all - parameter definitions while correctly mapping nested class definitions to their base level class thus - allowing the output markdown to be a valid input file - - *Args*: - - payload: current payload (namespace) - out_dict: output dictionary - parent_name: name of the parent spock class if nested - all_cls: all top level spock class definitions - - *Returns*: - - out_dict: modified dictionary with the cleaned data - - """ - for key, val in vars(payload).items(): - val_name = type(val).__name__ - # This catches basic lists and list of classes - if isinstance(val, list): - # Check if each entry is a spock class - clean_val = [] - repeat_flag = False - for l_val in val: - cls_name = type(l_val).__name__ - # For those that are a spock class and are repeated (cls_name == key) simply convert to dict - if (cls_name in all_cls) and (cls_name == key): - clean_val.append(attr.asdict(l_val)) - # For those whose cls is different than the key just append the cls name - elif cls_name in all_cls: - # Change the flag as this is a repeated class -- which needs to be compressed into a single - # k:v pair - repeat_flag = True - clean_val.append(cls_name) - # Fall back to the passed in values - else: - 
clean_val.append(l_val) - # Handle repeated classes - if repeat_flag: - clean_val = list(set(clean_val))[-1] - out_dict.update({key: clean_val}) - # If it's a spock class but has a parent then just use the class name to reference the values - elif(val_name in all_cls) and parent_name is not None: - out_dict.update({key: val_name}) - # Check if it's a spock class without a parent -- iterate the values and recurse to catch more lists - elif val_name in all_cls: - new_dict = self._recursively_handle_clean(val, {}, parent_name=key, all_cls=all_cls) - out_dict.update({key: new_dict}) - # Either base type or no nested values that could be Spock classes - else: - out_dict.update({key: val}) - return out_dict diff --git a/spock/backend/base.py b/spock/backend/base.py deleted file mode 100644 index a6bd6364..00000000 --- a/spock/backend/base.py +++ /dev/null @@ -1,928 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2019 FMR LLC -# SPDX-License-Identifier: Apache-2.0 - -"""Handles base Spock classes""" - -from abc import ABC -from abc import abstractmethod -import argparse -import attr -from attr import NOTHING -from enum import EnumMeta -import os -from pathlib import Path -import re -import sys -from uuid import uuid1 -import yaml -from spock.handlers import JSONHandler -from spock.handlers import TOMLHandler -from spock.handlers import YAMLHandler -from spock.utils import add_info -from spock.utils import check_path_s3 -from spock.utils import make_argument -from typing import List - - -class Spockspace(argparse.Namespace): - """Inherits from Namespace to implement a pretty print on the obj - - Overwrites the __repr__ method with a pretty version of printing - - """ - def __init__(self, **kwargs): - super(Spockspace, self).__init__(**kwargs) - - def __repr__(self): - # Remove aliases in YAML print - yaml.Dumper.ignore_aliases = lambda *args: True - return yaml.dump(self.__dict__, default_flow_style=False) - - -class BaseHandler(ABC): - """Base class for saver and payload - - *Attributes*: - - _writers: maps file extension to the correct i/o handler - _s3_config: optional S3Config object to handle s3 access - - """ - def __init__(self, s3_config=None): - self._supported_extensions = {'.yaml': YAMLHandler, '.toml': TOMLHandler, '.json': JSONHandler} - self._s3_config = s3_config - - def _check_extension(self, file_extension: str): - if file_extension not in self._supported_extensions: - raise TypeError(f'File extension {file_extension} not supported -- \n' - f'File extension must be from {list(self._supported_extensions.keys())}') - - -class BaseSaver(BaseHandler): # pylint: disable=too-few-public-methods - """Base class for saving configs - - Contains methods to build a correct output payload and then writes to file based on the file - extension - - *Attributes*: - - _writers: maps file extension to the correct i/o handler - _s3_config: optional S3Config object to handle s3 access - - """ - def __init__(self, s3_config=None): - super(BaseSaver, self).__init__(s3_config=s3_config) - - def save(self, payload, path, file_name=None, create_save_path=False, extra_info=True, file_extension='.yaml'): #pylint: disable=too-many-arguments - """Writes Spock config to file - - Cleans and builds an output payload and then correctly writes it to file based on the - specified file extension - - *Args*: - - payload: current config payload - path: path to save - file_name: name of file (will be appended with .spock.cfg.file_extension) -- falls back to uuid if None - create_save_path: boolean to create the path if 
non-existent - extra_info: boolean to write extra info - file_extension: what type of file to write - - *Returns*: - - None - - """ - # Check extension - self._check_extension(file_extension=file_extension) - # Make the filename -- always append a uuid for unique-ness - uuid_str = str(uuid1()) - fname = '' if file_name is None else f'{file_name}.' - name = f'{fname}{uuid_str}.spock.cfg{file_extension}' - # Fix up values -- parameters - out_dict = self._clean_up_values(payload, file_extension) - # Get extra info - extra_dict = add_info() if extra_info else None - try: - self._supported_extensions.get(file_extension)().save( - out_dict=out_dict, info_dict=extra_dict, path=str(path), name=name, - create_path=create_save_path, s3_config=self._s3_config - ) - except OSError as e: - print(f'Unable to write to given path: {path / name}') - raise e - - @abstractmethod - def _clean_up_values(self, payload, file_extension): - """Clean up the config payload so it can be written to file - - *Args*: - - payload: dirty payload - extra_info: boolean to add extra info - file_extension: type of file to write - - *Returns*: - - clean_dict: cleaned output payload - - """ - - def _clean_output(self, out_dict): - """Clean up the dictionary so it can be written to file - - *Args*: - - out_dict: cleaned dictionary - extra_info: boolean to add extra info - - *Returns*: - - clean_dict: cleaned output payload - - """ - # Convert values - clean_dict = {} - for key, val in out_dict.items(): - clean_inner_dict = {} - if isinstance(val, list): - for idx, list_val in enumerate(val): - tmp_dict = {} - for inner_key, inner_val in list_val.items(): - tmp_dict = self._convert(tmp_dict, inner_val, inner_key) - val[idx] = tmp_dict - clean_inner_dict = val - else: - for inner_key, inner_val in val.items(): - clean_inner_dict = self._convert(clean_inner_dict, inner_val, inner_key) - clean_dict.update({key: clean_inner_dict}) - return clean_dict - - def _convert(self, clean_inner_dict, inner_val, inner_key): - # Convert tuples to lists so they get written correctly - if isinstance(inner_val, tuple): - clean_inner_dict.update({inner_key: self._recursive_tuple_to_list(inner_val)}) - elif inner_val is not None: - clean_inner_dict.update({inner_key: inner_val}) - return clean_inner_dict - - def _recursive_tuple_to_list(self, value): - """Recursively turn tuples into lists - - Recursively looks through tuple(s) and convert to lists - - *Args*: - - value: value to check and set typ if necessary - typed: type of the generic alias to check against - - *Returns*: - - value: updated value with correct type casts - - """ - # Check for __args__ as it signifies a generic and make sure it's not already been cast as a tuple - # from a composed payload - list_v = [] - for v in value: - if isinstance(v, tuple): - v = self._recursive_tuple_to_list(v) - list_v.append(v) - else: - list_v.append(v) - return list_v - - -class BaseBuilder(ABC): # pylint: disable=too-few-public-methods - """Base class for building the backend specific builders - - This class handles the interface to the backend with the generic ConfigArgBuilder so that different - backends can be used to handle processing - - *Attributes* - - input_classes: list of input classes that link to a backend - _configs: None or List of configs to read from - _create_save_path: boolean to make the path to save to - _desc: description for the arg parser - _no_cmd_line: flag to force no command line reads - _max_indent: maximum to indent between help prints - save_path: list of path(s) to save 
the configs to - - """ - def __init__(self, *args, configs=None, create_save_path=False, desc='', no_cmd_line=False, - max_indent=4, **kwargs): - self.input_classes = args - self._configs = configs - self._create_save_path = create_save_path - self._desc = desc - self._no_cmd_line = no_cmd_line - self._max_indent = max_indent - self.save_path = None - - @abstractmethod - def print_usage_and_exit(self, msg=None, sys_exit=True): - """Prints the help message and exits - - *Args*: - - msg: message to print pre exit - - *Returns*: - - None - - """ - - @abstractmethod - def _handle_help_info(self): - """Handles walking through classes to get help info - - For each class this function will search __doc__ and attempt to pull out help information for both the class - itself and each attribute within the class - - *Returns*: - - None - - """ - - @abstractmethod - def _handle_arguments(self, args, class_obj): - """Handles all argument mapping - - Creates a dictionary of named parameters that are mapped to the final type of object - - *Args*: - - args: read file arguments - class_obj: instance of a class obj - - *Returns*: - - fields: dictionary of mapped parameters - - """ - - def generate(self, dict_args): - """Method to auto-generate the actual class instances from the generated args - - Based on the generated arguments groups and the args read in from the config file(s) - this function instantiates the classes with the necessary field or attr values - - *Args*: - - dict_args: dictionary of arguments from the configs - - *Returns*: - - namespace containing automatically generated instances of the classes - """ - auto_dict = {} - for attr_classes in self.input_classes: - attr_build = self._auto_generate(dict_args, attr_classes) - if isinstance(attr_build, list): - class_name = list({type(val).__name__ for val in attr_build}) - if len(class_name) > 1: - raise ValueError('Repeated class has more than one unique name') - auto_dict.update({class_name[0]: attr_build}) - else: - auto_dict.update({type(attr_build).__name__: attr_build}) - return Spockspace(**auto_dict) - # return argparse.Namespace(**auto_dict) - - def _auto_generate(self, args, input_class): - """Builds an instance of a DataClass - - Builds an instance with the necessary field values from the argument - dictionary read from the config file(s) - - *Args*: - - args: dictionary of arguments read from the config file(s) - data_class: data class to build - - *Returns*: - - An instance of data_class with correct values assigned to fields - """ - # Handle the basic data types - fields = self._handle_arguments(args, input_class) - if isinstance(fields, list): - return_value = fields - else: - self._handle_late_defaults(args, fields, input_class) - return_value = input_class(**fields) - return return_value - - def _handle_late_defaults(self, args, fields, input_class): - """Handles late defaults when the type is non-standard - - If the default type is not a base python type then we need to catch those defaults here and build the correct - values from the input classes while maintaining the optional nature. 
The trick is to exclude all 'base' types - as these defaults are covered by the attr default value - - *Args*: - - args: dictionary of arguments read from the config file(s) - fields: current fields returned from _handle_arguments - input_class: which input class being checked for late defaults - - *Returns*: - - fields: updated field dictionary with late defaults set - - """ - names = [val.name for val in input_class.__attrs_attrs__] - class_names = [val.__name__ for val in self.input_classes] - field_list = list(fields.keys()) - arg_list = list(args.keys()) - # Exclude all the base types that are supported -- these can be set by attrs - exclude_list = ['_Nothing', 'NoneType', 'bool', 'int', 'float', 'str', 'list', 'tuple'] - for val in names: - if val not in field_list: - default_type_name = type(getattr(input_class.__attrs_attrs__, val).default).__name__ - if default_type_name not in exclude_list: - default_name = getattr(input_class.__attrs_attrs__, val).default.__name__ - else: - default_name = None - if default_name is not None and default_name in arg_list: - if isinstance(args.get(default_name), list): - default_value = [self.input_classes[class_names.index(default_name)](**arg_val) - for arg_val in args.get(default_name)] - else: - default_value = self.input_classes[class_names.index(default_name)](**args.get(default_name)) - fields.update({val: default_value}) - return fields - - def get_config_paths(self): - """Get config paths from all methods - - Config paths can enter from either the command line or be added in the class init call - as a kwarg (configs=[]) - - *Returns*: - - args: namespace of args - - """ - # Check if the no_cmd_line is not flagged and if the configs are not empty - - if self._no_cmd_line and (self._configs is None): - raise ValueError("Flag set for preventing command line read but no paths were passed to the config kwarg") - if not self._no_cmd_line: - args = self._build_override_parsers(desc=self._desc) - else: - args = argparse.Namespace(config=[], help=False) - if self._configs is not None: - args = self._get_from_kwargs(args, self._configs) - return args - - def _build_override_parsers(self, desc): - """Creates parsers for command-line overrides - - Builds the basic command line parser for configs and help then iterates through each attr instance to make - namespace specific cmd line override parsers - - *Args*: - - desc: argparser description - - *Returns*: - - args: argument namespace - - """ - parser = argparse.ArgumentParser(description=desc, add_help=False) - parser.add_argument('-c', '--config', required=False, nargs='+', default=[]) - parser.add_argument('-h', '--help', action='store_true') - # Build out each class override specific parser - for val in self.input_classes: - parser = self._make_group_override_parser(parser=parser, class_obj=val) - args = parser.parse_args() - return args - - def _make_group_override_parser(self, parser, class_obj): - """Makes a name specific override parser for a given class obj - - Takes a class object of the backend and adds a new argument group with argument names given with name - Class.name so that individual parameters specific to a class can be overridden. 
- - *Args*: - - parser: argument parser - class_obj: instance of a backend class - - *Returns*: - - parser: argument parser with new class specific overrides - - """ - attr_name = class_obj.__name__ - group_parser = parser.add_argument_group(title=str(attr_name) + " Specific Overrides") - for val in class_obj.__attrs_attrs__: - val_type = val.metadata['type'] if 'type' in val.metadata else val.type - # Check if the val type has __args__ - # TODO (ncilfone): Fix up this super super ugly logic - if hasattr(val_type, '__args__') and ((list(set(val_type.__args__))[0]).__module__ == 'spock.backend.attr.config') and attr.has((list(set(val_type.__args__))[0])): - args = (list(set(val_type.__args__))[0]) - for inner_val in args.__attrs_attrs__: - arg_name = f"--{str(attr_name)}.{val.name}.{args.__name__}.{inner_val.name}" - group_parser = make_argument(arg_name, List[inner_val.type], group_parser) - else: - arg_name = f"--{str(attr_name)}.{val.name}" - group_parser = make_argument(arg_name, val_type, group_parser) - return parser - - @staticmethod - def _get_from_kwargs(args, configs): - """Get configs from the configs kwarg - - - *Args*: - - args: argument namespace - configs: config kwarg - - *Returns*: - - args: arg namespace - - """ - if type(configs).__name__ == 'list': - args.config.extend(configs) - else: - raise TypeError(f'configs kwarg must be of type list -- given {type(configs)}') - return args - - @staticmethod - def _find_attribute_idx(newline_split_docs): - """Finds the possible split between the header and Attribute annotations - - *Args*: - - newline_split_docs: new line split text - - Returns: - - idx: -1 if none or the idx of Attributes - - """ - for idx, val in enumerate(newline_split_docs): - re_check = re.search(r'(?i)Attribute?s?:', val) - if re_check is not None: - return idx - return -1 - - def _split_docs(self, obj): - """Possibly splits head class doc string from attribute docstrings - - Attempts to find the first contiguous line within the Google style docstring to use as the class docstring. - Splits the docs base on the Attributes tag if present. - - *Args*: - - obj: class object to rip info from - - *Returns*: - - class_doc: class docstring if present or blank str - attr_doc: list of attribute doc strings - - """ - if obj.__doc__ is not None: - # Split by new line - newline_split_docs = obj.__doc__.split('\n') - # Cleanup l/t whitespace - newline_split_docs = [val.strip() for val in newline_split_docs] - else: - newline_split_docs = [] - # Find the break between the class docs and the Attribute section -- if this returns -1 then there is no - # Attributes section - attr_idx = self._find_attribute_idx(newline_split_docs) - head_docs = newline_split_docs[:attr_idx] if attr_idx != -1 else newline_split_docs - attr_docs = newline_split_docs[attr_idx:] if attr_idx != -1 else [] - # Grab only the first contiguous line as everything else will probably be too verbose (e.g. 
the - # mid-level docstring that has detailed descriptions - class_doc = '' - for idx, val in enumerate(head_docs): - class_doc += f' {val}' - if idx + 1 != len(head_docs) and head_docs[idx + 1] == '': - break - # Clean up any l/t whitespace - class_doc = class_doc.strip() - return class_doc, attr_docs - - @staticmethod - def _match_attribute_docs(attr_name, attr_docs, attr_type_str, attr_default=NOTHING): - """Matches class attributes with attribute docstrings via regex - - *Args*: - - attr_name: attribute name - attr_docs: list of attribute docstrings - attr_type_str: str representation of the attribute type - attr_default: str representation of a possible default value - - *Returns*: - - dictionary of packed attribute information - - """ - # Regex match each value - a_str = None - for a_doc in attr_docs: - match_re = re.search(r'(?i)^' + attr_name + '?:', a_doc) - # Find only the first match -- if more than one than ignore - if match_re: - a_str = a_doc[match_re.end():].strip() - return {attr_name: { - 'type': attr_type_str, - 'desc': a_str if a_str is not None else "", - 'default': "(default: " + repr(attr_default) + ")" if type(attr_default).__name__ != '_Nothing' - else "", - 'len': {'name': len(attr_name), 'type': len(attr_type_str)} - }} - - def _handle_attributes_print(self, info_dict): - """Prints attribute information in an argparser style format - - *Args*: - - info_dict: packed attribute info dictionary to print - - """ - # Figure out indents - max_param_length = max([len(k) for k in info_dict.keys()]) - max_type_length = max([v['len']['type'] for v in info_dict.values()]) - # Print akin to the argparser - for k, v in info_dict.items(): - print(f' {k}' + (' ' * (max_param_length - v["len"]["name"] + self._max_indent)) + - f'{v["type"]}' + (' ' * (max_type_length - v["len"]["type"] + self._max_indent)) + - f'{v["desc"]} {v["default"]}') - # Blank for spacing :-/ - print('') - - def _extract_other_types(self, typed): - """Takes a high level type and recursively extracts any enum or class types - - *Args*: - - typed: highest level type - - *Returns*: - - return_list: list of nums (dot notation of module_path.enum_name or module_path.class_name) - - """ - return_list = [] - if hasattr(typed, '__args__'): - for val in typed.__args__: - recurse_return = self._extract_other_types(val) - if isinstance(recurse_return, list): - return_list.extend(recurse_return) - else: - return_list.append(self._extract_other_types(val)) - elif isinstance(typed, EnumMeta) or (typed.__module__ == 'spock.backend.attr.config'): - return f'{typed.__module__}.{typed.__name__}' - return return_list - - def _attrs_help(self, input_classes): - """Handles walking through a list classes to get help info - - For each class this function will search __doc__ and attempt to pull out help information for both the class - itself and each attribute within the class. 
If it finds a repeated class in a iterable object it will - recursively call self to handle information - - *Args*: - - input_classes: list of attr classes - - *Returns*: - - None - - """ - # List to catch Enums and classes and handle post spock wrapped attr classes - other_list = [] - covered_set = set() - for attrs_class in input_classes: - # Split the docs into class docs and any attribute docs - class_doc, attr_docs = self._split_docs(attrs_class) - print(' ' + attrs_class.__name__ + f' ({class_doc})') - # Keep a running info_dict of all the attribute level info - info_dict = {} - for val in attrs_class.__attrs_attrs__: - # If the type is an enum we need to handle it outside of this attr loop - # Match the style of nested enums and return a string of module.name notation - if isinstance(val.type, EnumMeta): - other_list.append(f'{val.type.__module__}.{val.type.__name__}') - # if there is a type (implied Iterable) -- check it for nested Enums or classes - nested_others = self._extract_other_types(val.metadata['type']) if 'type' in val.metadata else [] - if len(nested_others) > 0: - other_list.extend(nested_others) - # Grab the base or type info depending on what is provided - type_string = repr(val.metadata['type']) if 'type' in val.metadata else val.metadata['base'] - # Regex out the typing info if present - type_string = re.sub(r'typing.', '', type_string) - # Regex out any nested_others that have module path information - for other_val in nested_others: - split_other = f"{'.'.join(other_val.split('.')[:-1])}." - type_string = re.sub(split_other, '', type_string) - # Regex the string to see if it matches any Enums in the __main__ module space - # for val in sys.modules - # Construct the type with the metadata - if 'optional' in val.metadata: - type_string = f"Optional[{type_string}]" - info_dict.update(self._match_attribute_docs(val.name, attr_docs, type_string, val.default)) - # Add to covered so we don't print help twice in the case of some recursive nesting - covered_set.add(f'{attrs_class.__module__}.{attrs_class.__name__}') - self._handle_attributes_print(info_dict=info_dict) - # Convert the enum list to a set to remove dupes and then back to a list so it is iterable -- set diff to not - # repeat - other_list = list(set(other_list) - covered_set) - # Iterate any Enum type classes - for other in other_list: - # if it's longer than 2 then it's an embedded Spock class - if '.'.join(other.split('.')[:-1]) == 'spock.backend.attr.config': - class_type = self._get_from_sys_modules(other) - # Invoke recursive call for the class - self._attrs_help([class_type]) - # Fall back to enum style - else: - enum = self._get_from_sys_modules(other) - # Split the docs into class docs and any attribute docs - class_doc, attr_docs = self._split_docs(enum) - print(' ' + enum.__name__ + f' ({class_doc})') - info_dict = {} - for val in enum: - info_dict.update(self._match_attribute_docs( - attr_name=val.name, - attr_docs=attr_docs, - attr_type_str=type(val.value).__name__ - )) - self._handle_attributes_print(info_dict=info_dict) - - @staticmethod - def _get_from_sys_modules(cls_name): - """Gets the class from a dot notation name - - *Args*: - - cls_name: dot notation enum name - - *Returns*: - - module: enum class - - """ - # Split on dot notation - split_string = cls_name.split('.') - module = None - for idx, val in enumerate(split_string): - # idx = 0 will always be a call to the sys.modules dict - if idx == 0: - module = sys.modules[val] - # all other idx are paths along the module that need to be 
traversed - # idx = -1 will always be the final Enum object name we want to grab (final getattr call) - else: - module = getattr(module, val) - return module - - -class BasePayload(BaseHandler): # pylint: disable=too-few-public-methods - """Handles building the payload for config file(s) - - This class builds out the payload from config files of multiple types. It handles various - file types and also composition of config files via recursive calls - - *Attributes*: - - _loaders: maps of each file extension to the loader class - __s3_config: optional S3Config object to handle s3 access - - """ - def __init__(self, s3_config=None): - super(BasePayload, self).__init__(s3_config=s3_config) - - @staticmethod - @abstractmethod - def _update_payload(base_payload, input_classes, payload): - """Updates the payload - - Checks the parameters defined in the config files against the provided classes and if - passable adds them to the payload - - *Args*: - - base_payload: current payload - input_classes: class to roll into - payload: total payload - - *Returns*: - - payload: updated payload - - """ - - def payload(self, input_classes, path, cmd_args, deps): - """Builds the payload from config files - - Public exposed call to build the payload and set any command line overrides - - *Args*: - - input_classes: list of backend classes - path: path to config file(s) - cmd_args: command line overrides - deps: dictionary of config dependencies - - *Returns*: - - payload: dictionary of all mapped parameters - - """ - payload = self._payload(input_classes, path, deps, root=True) - payload = self._handle_overrides(payload, cmd_args) - return payload - - def _payload(self, input_classes, path, deps, root=False): - """Private call to construct the payload - - Main function call that builds out the payload from config files of multiple types. It handles - various file types and also composition of config files via a recursive calls - - *Args*: - input_classes: list of backend classes - path: path to config file(s) - deps: dictionary of config dependencies - - *Returns*: - - payload: dictionary of all mapped parameters - - """ - # Match to loader based on file-extension - config_extension = Path(path).suffix.lower() - # Verify extension - self._check_extension(file_extension=config_extension) - # Load from file - base_payload = self._supported_extensions.get(config_extension)().load(path, s3_config=self._s3_config) - # Check and? update the dependencies - deps = self._handle_dependencies(deps, path, root) - payload = {} - if 'config' in base_payload: - payload = self._handle_includes( - base_payload, config_extension, input_classes, path, payload, deps) - payload = self._update_payload(base_payload, input_classes, payload) - return payload - - @staticmethod - def _handle_dependencies(deps, path, root): - """Handles config file dependencies - - Checks to see if the config path (full or relative) has already been encountered. Essentially a DFS for graph - cycles - - *Args*: - - deps: dictionary of config dependencies - path: current config path - root: boolean if root - - *Returns*: - - deps: updated dependencies - - """ - if root and path in deps.get('paths'): - raise ValueError(f'Duplicate Read -- Config file {path} has already been encountered. ' - f'Please remove duplicate reads of config files.') - elif path in deps.get('paths') or path in deps.get('rel_paths'): - raise ValueError(f'Cyclical Dependency -- Config file {path} has already been encountered. 
' - f'Please remove cyclical dependencies between config files.') - else: - # Update the dependency lists - deps.get('paths').append(path) - deps.get('rel_paths').append(os.path.basename(path)) - if root: - deps.get('roots').append(path) - return deps - - def _handle_includes(self, base_payload, config_extension, input_classes, path, payload, deps): # pylint: disable=too-many-arguments - """Handles config composition - - For all of the config tags in the config file this function will recursively call the payload function - with the composition path to get the additional payload(s) from the composed file(s) -- checks for file - validity or if it is an S3 URI via regex - - *Args*: - - base_payload: base payload that has a config kwarg - config_extension: file type - input_classes: defined backend classes - path: path to base file - payload: payload pulled from composed files - deps: dictionary of config dependencies - - *Returns*: - - payload: payload update from composed files - - """ - included_params = {} - for inc_path in base_payload['config']: - if check_path_s3(inc_path): - use_path = inc_path - elif os.path.exists(inc_path): - use_path = inc_path - elif os.path.join(os.path.dirname(path), inc_path): - use_path = os.path.join(os.path.dirname(path), inc_path) - else: - raise RuntimeError(f'Could not find included {config_extension} file {inc_path} or is not an S3 URI!') - included_params.update(self._payload(input_classes, use_path, deps)) - payload.update(included_params) - return payload - - def _handle_overrides(self, payload, args): - """Handle command line overrides - - Iterate through the command line override values, determine at what level to set them, and set them if possible - - *Args*: - - payload: current payload dictionary - args: command line override args - - *Returns*: - - payload: updated payload dictionary with override values set - - """ - skip_keys = ['config', 'help'] - for k, v in vars(args).items(): - if k not in skip_keys and v is not None: - payload = self._handle_payload_override(payload, k, v) - return payload - - @staticmethod - def _handle_payload_override(payload, key, value): - """Handles the complex logic needed for List[spock class] overrides - - Messy logic that sets overrides for the various different types. The hardest being List[spock class] since str - names have to be mapped backed to sys.modules and can be set at either the general or class level. 
- - *Args*: - - payload: current payload dictionary - key: current arg key - value: value at current arg key - - *Returns*: - - payload: modified payload with overrides - - """ - key_split = key.split('.') - curr_ref = payload - for idx, split in enumerate(key_split): - # If the root isn't in the payload then it needs to be added but only for the first key split - if idx == 0 and (split not in payload): - payload.update({split: {}}) - # Check for curr_ref switch over -- verify by checking the sys modules names - if idx != 0 and (split in payload) and (isinstance(curr_ref, str)) and (hasattr(sys.modules['spock'].backend.attr.config, split)): - curr_ref = payload[split] - elif idx != 0 and (split in payload) and (isinstance(payload[split], str)) and (hasattr(sys.modules['spock'].backend.attr.config, payload[split])): - curr_ref = payload[split] - # elif check if it's the last value and figure out the override - elif idx == (len(key_split)-1): - # Handle bool(s) a bit differently as they are store_true - if isinstance(curr_ref, dict) and isinstance(value, bool): - if value is not False: - curr_ref[split] = value - # If we are at the dictionary level we should be able to just payload override - elif isinstance(curr_ref, dict) and not isinstance(value, bool): - curr_ref[split] = value - # If we are at a list level it must be some form of repeated class since this is the end of the class - # tree -- check the instance type but also make sure the cmd-line override is the correct len - elif isinstance(curr_ref, list) and len(value) == len(curr_ref): - # Walk the list and check for the key - for ref_idx, val in enumerate(curr_ref): - if split in val: - val[split] = value[ref_idx] - else: - raise ValueError(f'cmd-line override failed for {key} -- ' - f'Failed to find key {split} within lowest level List[Dict]') - elif isinstance(curr_ref, list) and len(value) != len(curr_ref): - raise ValueError(f'cmd-line override failed for {key} -- ' - f'Specified key {split} with len {len(value)} does not match len {len(curr_ref)} ' - f'of List[Dict]') - else: - raise ValueError(f'cmd-line override failed for {key} -- ' - f'Failed to find key {split} within lowest level Dict') - # If it's not keep walking the current payload - else: - curr_ref = curr_ref[split] - return payload diff --git a/spock/backend/builder.py b/spock/backend/builder.py new file mode 100644 index 00000000..e3ec2d1b --- /dev/null +++ b/spock/backend/builder.py @@ -0,0 +1,849 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles the building/saving of the configurations from the Spock config classes""" + +import re +import sys +from abc import ABC, abstractmethod +from enum import EnumMeta +from typing import List + +import attr +from attr import NOTHING + +from spock.backend.wrappers import Spockspace +from spock.utils import make_argument + + +class BaseBuilder(ABC): # pylint: disable=too-few-public-methods + """Base class for building the backend specific builders + + This class handles the interface to the backend with the generic ConfigArgBuilder so that different + backends can be used to handle processing + + *Attributes* + + input_classes: list of input classes that link to a backend + _configs: None or List of configs to read from + _desc: description for the arg parser + _no_cmd_line: flag to force no command line reads + _max_indent: maximum to indent between help prints + save_path: list of path(s) to save the configs to + + """ + + def __init__(self, *args, max_indent=4, 
module_name, **kwargs): + self.input_classes = args + self._module_name = module_name + self._max_indent = max_indent + self.save_path = None + + @staticmethod + @abstractmethod + def _make_group_override_parser(parser, class_obj, class_name): + """Makes a name specific override parser for a given class obj + + Takes a class object of the backend and adds a new argument group with argument names given with name + Class.name so that individual parameters specific to a class can be overridden. + + *Args*: + + parser: argument parser + class_obj: instance of a backend class + class_name: used for module matching + + *Returns*: + + parser: argument parser with new class specific overrides + + """ + + def handle_help_info(self): + """Handles walking through classes to get help info + + For each class this function will search __doc__ and attempt to pull out help information for both the class + itself and each attribute within the class + + *Returns*: + + None + + """ + self._attrs_help(self.input_classes, self._module_name) + + def _handle_arguments(self, args, class_obj): + """Handles all argument mapping + + Creates a dictionary of named parameters that are mapped to the final type of object + + *Args*: + + args: read file arguments + class_obj: instance of a class obj + + *Returns*: + + fields: dictionary of mapped parameters + + """ + attr_name = class_obj.__name__ + class_names = [val.__name__ for val in self.input_classes] + # Handle repeated classes + if ( + attr_name in class_names + and attr_name in args + and isinstance(args[attr_name], list) + ): + fields = self._handle_repeated(args[attr_name], attr_name, class_names) + # Handle non-repeated classes + else: + fields = {} + for val in class_obj.__attrs_attrs__: + # Check if namespace is named and then check for key -- checking for local class def + if attr_name in args and val.name in args[attr_name]: + fields[val.name] = self._handle_nested_class( + args, args[attr_name][val.name], class_names + ) + # If not named then just check for keys -- checking for global def + elif val.name in args: + fields[val.name] = self._handle_nested_class( + args, args[val.name], class_names + ) + # Check for special keys to set + if ( + "special_key" in val.metadata + and val.metadata["special_key"] is not None + ): + if val.name in args: + self.save_path = args[val.name] + elif val.default is not None: + self.save_path = val.default + return fields + + def _handle_repeated(self, args, check_value, class_names): + """Handles repeated classes as lists + + *Args*: + + args: dictionary of arguments from the configs + check_value: value to check classes against + class_names: current class names + + *Returns*: + + list of input_class[match)idx[0]] types filled with repeated values + + """ + # Check to see if the value trying to be set is actually an input class + match_idx = [idx for idx, val in enumerate(class_names) if val == check_value] + return [self.input_classes[match_idx[0]](**val) for val in args] + + def _handle_nested_class(self, args, check_value, class_names): + """Handles passing another class to the field dictionary + + *Args*: + args: dictionary of arguments from the configs + check_value: value to check classes against + class_names: current class names + + *Returns*: + + either the check_value or the necessary class + + """ + # Check to see if the value trying to be set is actually an input class + match_idx = [idx for idx, val in enumerate(class_names) if val == check_value] + # If so then create the needed class object by unrolling 
the args to **kwargs and return it + if len(match_idx) > 0: + if len(match_idx) > 1: + raise ValueError( + "Match error -- multiple classes with the same name definition" + ) + else: + if args.get(self.input_classes[match_idx[0]].__name__) is None: + raise ValueError( + f"Missing config file definition for the referenced class " + f"{self.input_classes[match_idx[0]].__name__}" + ) + current_arg = args.get(self.input_classes[match_idx[0]].__name__) + if isinstance(current_arg, list): + class_value = [ + self.input_classes[match_idx[0]](**val) for val in current_arg + ] + else: + class_value = self.input_classes[match_idx[0]](**current_arg) + return_value = class_value + # else return the expected value + else: + return_value = check_value + return return_value + + def generate(self, dict_args): + """Method to auto-generate the actual class instances from the generated args + + Based on the generated argument groups and the args read in from the config file(s) + this function instantiates the classes with the necessary field or attr values + + *Args*: + + dict_args: dictionary of arguments from the configs + + *Returns*: + + namespace containing automatically generated instances of the classes + """ + auto_dict = {} + for attr_classes in self.input_classes: + attr_build = self._auto_generate(dict_args, attr_classes) + if isinstance(attr_build, list): + class_name = list({type(val).__name__ for val in attr_build}) + if len(class_name) > 1: + raise ValueError("Repeated class has more than one unique name") + auto_dict.update({class_name[0]: attr_build}) + else: + auto_dict.update({type(attr_build).__name__: attr_build}) + return Spockspace(**auto_dict) + + def _auto_generate(self, args, input_class): + """Builds an instance of an attr class + + Builds an instance with the necessary field values from the argument + dictionary read from the config file(s) + + *Args*: + + args: dictionary of arguments read from the config file(s) + input_class: input class to build + + *Returns*: + + An instance of input_class with correct values assigned to fields + """ + # Handle the basic data types + fields = self._handle_arguments(args, input_class) + if isinstance(fields, list): + return_value = fields + else: + self._handle_late_defaults(args, fields, input_class) + return_value = input_class(**fields) + return return_value + + def _handle_late_defaults(self, args, fields, input_class): + """Handles late defaults when the type is non-standard + + If the default type is not a base python type then we need to catch those defaults here and build the correct + values from the input classes while maintaining the optional nature.
The trick is to exclude all 'base' types + as these defaults are covered by the attr default value + + *Args*: + + args: dictionary of arguments read from the config file(s) + fields: current fields returned from _handle_arguments + input_class: which input class being checked for late defaults + + *Returns*: + + fields: updated field dictionary with late defaults set + + """ + names = [val.name for val in input_class.__attrs_attrs__] + class_names = [val.__name__ for val in self.input_classes] + field_list = list(fields.keys()) + arg_list = list(args.keys()) + # Exclude all the base types that are supported -- these can be set by attrs + exclude_list = [ + "_Nothing", + "NoneType", + "bool", + "int", + "float", + "str", + "list", + "tuple", + ] + for val in names: + if val not in field_list: + default_type_name = type( + getattr(input_class.__attrs_attrs__, val).default + ).__name__ + if default_type_name not in exclude_list: + default_name = getattr( + input_class.__attrs_attrs__, val + ).default.__name__ + else: + default_name = None + if default_name is not None and default_name in arg_list: + if isinstance(args.get(default_name), list): + default_value = [ + self.input_classes[class_names.index(default_name)]( + **arg_val + ) + for arg_val in args.get(default_name) + ] + else: + default_value = self.input_classes[ + class_names.index(default_name) + ](**args.get(default_name)) + fields.update({val: default_value}) + return fields + + def build_override_parsers(self, parser): + """Creates parsers for command-line overrides + + Builds the basic command line parser for configs and help then iterates through each attr instance to make + namespace specific cmd line override parsers + + *Args*: + + parser: argument parser + + *Returns*: + + parser: argument parser with new class specific overrides + + """ + # Build out each class override specific parser + for val in self.input_classes: + parser = self._make_group_override_parser( + parser=parser, class_obj=val, class_name=self._module_name + ) + return parser + + @staticmethod + def _get_from_kwargs(args, configs): + """Get configs from the configs kwarg + + *Args*: + + args: argument namespace + configs: config kwarg + + *Returns*: + + args: arg namespace + + """ + if isinstance(configs, list): + args.config.extend(configs) + else: + raise TypeError( + f"configs kwarg must be of type list -- given {type(configs)}" + ) + return args + + @staticmethod + def _find_attribute_idx(newline_split_docs): + """Finds the possible split between the header and Attribute annotations + + *Args*: + + newline_split_docs: new line split text + + Returns: + + idx: -1 if none or the idx of Attributes + + """ + for idx, val in enumerate(newline_split_docs): + re_check = re.search(r"(?i)Attribute?s?:", val) + if re_check is not None: + return idx + return -1 + + def _split_docs(self, obj): + """Possibly splits head class doc string from attribute docstrings + + Attempts to find the first contiguous line within the Google style docstring to use as the class docstring. + Splits the docs base on the Attributes tag if present. 
+ + *Args*: + + obj: class object to rip info from + + *Returns*: + + class_doc: class docstring if present or blank str + attr_doc: list of attribute doc strings + + """ + if obj.__doc__ is not None: + # Split by new line + newline_split_docs = obj.__doc__.split("\n") + # Cleanup l/t whitespace + newline_split_docs = [val.strip() for val in newline_split_docs] + else: + newline_split_docs = [] + # Find the break between the class docs and the Attribute section -- if this returns -1 then there is no + # Attributes section + attr_idx = self._find_attribute_idx(newline_split_docs) + head_docs = ( + newline_split_docs[:attr_idx] if attr_idx != -1 else newline_split_docs + ) + attr_docs = newline_split_docs[attr_idx:] if attr_idx != -1 else [] + # Grab only the first contiguous line as everything else will probably be too verbose (e.g. the + # mid-level docstring that has detailed descriptions) + class_doc = "" + for idx, val in enumerate(head_docs): + class_doc += f" {val}" + if idx + 1 != len(head_docs) and head_docs[idx + 1] == "": + break + # Clean up any l/t whitespace + class_doc = class_doc.strip() + if len(class_doc) > 0: + class_doc = f"-- {class_doc}" + return class_doc, attr_docs + + @staticmethod + def _match_attribute_docs( + attr_name, attr_docs, attr_type_str, attr_default=NOTHING + ): + """Matches class attributes with attribute docstrings via regex + + *Args*: + + attr_name: attribute name + attr_docs: list of attribute docstrings + attr_type_str: str representation of the attribute type + attr_default: str representation of a possible default value + + *Returns*: + + dictionary of packed attribute information + + """ + # Regex match each value + a_str = None + for a_doc in attr_docs: + match_re = re.search(r"(?i)^" + attr_name + "?:", a_doc) + # Find only the first match -- if more than one then ignore + if match_re: + a_str = a_doc[match_re.end() :].strip() + return { + attr_name: { + "type": attr_type_str, + "desc": a_str if a_str is not None else "", + "default": "(default: " + repr(attr_default) + ")" + if type(attr_default).__name__ != "_Nothing" + else "", + "len": {"name": len(attr_name), "type": len(attr_type_str)}, + } + } + + def _handle_attributes_print(self, info_dict): + """Prints attribute information in an argparser style format + + *Args*: + + info_dict: packed attribute info dictionary to print + + """ + # Figure out indents + max_param_length = max([len(k) for k in info_dict.keys()]) + max_type_length = max([v["len"]["type"] for v in info_dict.values()]) + # Print akin to the argparser + for k, v in info_dict.items(): + print( + f" {k}" + + (" " * (max_param_length - v["len"]["name"] + self._max_indent)) + + f'{v["type"]}' + + (" " * (max_type_length - v["len"]["type"] + self._max_indent)) + + f'{v["desc"]} {v["default"]}' + ) + # Blank for spacing :-/ + print("") + + def _extract_other_types(self, typed, module_name): + """Takes a high level type and recursively extracts any enum or class types + + *Args*: + + typed: highest level type + module_name: name of module to match + + *Returns*: + + return_list: list of enum/class names (dot notation of module_path.enum_name or module_path.class_name) + + """ + return_list = [] + if hasattr(typed, "__args__"): + for val in typed.__args__: + recurse_return = self._extract_other_types(val, module_name) + if isinstance(recurse_return, list): + return_list.extend(recurse_return) + else: + return_list.append(self._extract_other_types(val, module_name)) + elif isinstance(typed, EnumMeta) or (typed.__module__ == module_name): +
return [f"{typed.__module__}.{typed.__name__}"] + return return_list + + def _attrs_help(self, input_classes, module_name): + """Handles walking through a list classes to get help info + + For each class this function will search __doc__ and attempt to pull out help information for both the class + itself and each attribute within the class. If it finds a repeated class in a iterable object it will + recursively call self to handle information + + *Args*: + + input_classes: list of attr classes + module_name: name of module to match + + *Returns*: + + None + + """ + # Handle the main loop + other_list = self._handle_help_main(input_classes, module_name) + self._handle_help_enums(other_list=other_list, module_name=module_name) + + @staticmethod + def _get_type_string(val, nested_others): + """Gets the type of the attr val as a string + + *Args*: + + val: current attr being processed + nested_others: list of nested others to deal with that might have module path info in the string + + *Returns*: + + type_string: type of the attr as a str + + """ + # Grab the base or type info depending on what is provided + if "type" in val.metadata: + type_string = repr(val.metadata["type"]) + elif "base" in val.metadata: + type_string = val.metadata["base"] + elif hasattr(val.type, "__name__"): + type_string = val.type.__name__ + else: + type_string = str(val.type) + # Regex out the typing info if present + type_string = re.sub(r"typing.", "", type_string) + # Regex out any nested_others that have module path information + for other_val in nested_others: + split_other = f"{'.'.join(other_val.split('.')[:-1])}." + type_string = re.sub(split_other, "", type_string) + # Regex the string to see if it matches any Enums in the __main__ module space + # Construct the type with the metadata + if "optional" in val.metadata: + type_string = f"Optional[{type_string}]" + return type_string + + def _handle_help_main(self, input_classes, module_name): + """Handles the print of the main class types + + *Args*: + + input_classes: current set of input classes + module_name: module name to match + + *Returns*: + + other_list: extended list of other classes/enums to process + + """ + # List to catch Enums and classes and handle post spock wrapped attr classes + other_list = [] + covered_set = set() + for attrs_class in input_classes: + # Split the docs into class docs and any attribute docs + class_doc, attr_docs = self._split_docs(attrs_class) + print(" " + attrs_class.__name__ + f" {class_doc}") + # Keep a running info_dict of all the attribute level info + info_dict = {} + for val in attrs_class.__attrs_attrs__: + # If the type is an enum we need to handle it outside of this attr loop + # Match the style of nested enums and return a string of module.name notation + if isinstance(val.type, EnumMeta): + other_list.append(f"{val.type.__module__}.{val.type.__name__}") + # if there is a type (implied Iterable) -- check it for nested Enums or classes + nested_others = self._extract_fnc(val, module_name) + if len(nested_others) > 0: + other_list.extend(nested_others) + # Get the type represented as a string + type_string = self._get_type_string(val, nested_others) + info_dict.update( + self._match_attribute_docs( + val.name, attr_docs, type_string, val.default + ) + ) + # Add to covered so we don't print help twice in the case of some recursive nesting + covered_set.add(f"{attrs_class.__module__}.{attrs_class.__name__}") + self._handle_attributes_print(info_dict=info_dict) + # Convert the enum list to a set to remove dupes and 
then back to a list so it is iterable -- set diff to not + # repeat + return list(set(other_list) - covered_set) + + def _handle_help_enums(self, other_list, module_name): + """handles any extra enums from non main args + + *Args*: + + other_list: extended list of other classes/enums to process + module_name: module name to match + + *Returns*: + + None + + """ + # Iterate any Enum type classes + for other in other_list: + # if it's longer than 2 then it's an embedded Spock class + if ".".join(other.split(".")[:-1]) == module_name: + class_type = self._get_from_sys_modules(other) + # Invoke recursive call for the class + self._attrs_help([class_type], module_name) + # Fall back to enum style + else: + enum = self._get_from_sys_modules(other) + # Split the docs into class docs and any attribute docs + class_doc, attr_docs = self._split_docs(enum) + print(" " + enum.__name__ + f" ({class_doc})") + info_dict = {} + for val in enum: + info_dict.update( + self._match_attribute_docs( + attr_name=val.name, + attr_docs=attr_docs, + attr_type_str=type(val.value).__name__, + ) + ) + self._handle_attributes_print(info_dict=info_dict) + + @abstractmethod + def _extract_fnc(self, val, module_name): + """Function that gets the nested lists within classes + + *Args*: + + val: current attr + module_name: matching module name + + *Returns*: + + list of any nested classes/enums + + """ + + @staticmethod + def _get_from_sys_modules(cls_name): + """Gets the class from a dot notation name + + *Args*: + + cls_name: dot notation enum name + + *Returns*: + + module: enum class + + """ + # Split on dot notation + split_string = cls_name.split(".") + module = None + for idx, val in enumerate(split_string): + # idx = 0 will always be a call to the sys.modules dict + if idx == 0: + module = sys.modules[val] + # all other idx are paths along the module that need to be traversed + # idx = -1 will always be the final Enum object name we want to grab (final getattr call) + else: + module = getattr(module, val) + return module + + +class AttrBuilder(BaseBuilder): + """Attr specific builder + + Class that handles building for the attr backend + + *Attributes* + + input_classes: list of input classes that link to a backend + _configs: None or List of configs to read from + _create_save_path: boolean to make the path to save to + _desc: description for the arg parser + _no_cmd_line: flag to force no command line reads + save_path: list of path(s) to save the configs to + + """ + + def __init__(self, *args, **kwargs): + """AttrBuilder init + + Args: + *args: list of input classes that link to a backend + configs: None or List of configs to read from + desc: description for the arg parser + no_cmd_line: flag to force no command line reads + **kwargs: any extra keyword args + """ + super().__init__(*args, module_name="spock.backend.config", **kwargs) + + @staticmethod + def _make_group_override_parser(parser, class_obj, class_name): + """Makes a name specific override parser for a given class obj + + Takes a class object of the backend and adds a new argument group with argument names given with name + Class.name so that individual parameters specific to a class can be overridden. 
+ + *Args*: + + parser: argument parser + class_obj: instance of a backend class + class_name: used for module matching + + *Returns*: + + parser: argument parser with new class specific overrides + + """ + attr_name = class_obj.__name__ + group_parser = parser.add_argument_group( + title=str(attr_name) + " Specific Overrides" + ) + for val in class_obj.__attrs_attrs__: + val_type = val.metadata["type"] if "type" in val.metadata else val.type + # Check if the val type has __args__ -- this catches lists? + # TODO (ncilfone): Fix up this super super ugly logic + if ( + hasattr(val_type, "__args__") + and ((list(set(val_type.__args__))[0]).__module__ == class_name) + and attr.has((list(set(val_type.__args__))[0])) + ): + args = list(set(val_type.__args__))[0] + for inner_val in args.__attrs_attrs__: + arg_name = f"--{str(attr_name)}.{val.name}.{args.__name__}.{inner_val.name}" + group_parser = make_argument( + arg_name, List[inner_val.type], group_parser + ) + # If it's a reference to a class it needs to be an arg of a simple string as class matching will take care + # of it later on + elif val_type.__module__ == "spock.backend.config": + arg_name = f"--{str(attr_name)}.{val.name}" + val_type = str + group_parser = make_argument(arg_name, val_type, group_parser) + else: + arg_name = f"--{str(attr_name)}.{val.name}" + group_parser = make_argument(arg_name, val_type, group_parser) + return parser + + def _handle_arguments(self, args, class_obj): + attr_name = class_obj.__name__ + class_names = [val.__name__ for val in self.input_classes] + # Handle repeated classes + if ( + attr_name in class_names + and attr_name in args + and isinstance(args[attr_name], list) + ): + fields = self._handle_repeated(args[attr_name], attr_name, class_names) + # Handle non-repeated classes + else: + fields = {} + for val in class_obj.__attrs_attrs__: + # Check if namespace is named and then check for key -- checking for local class def + if attr_name in args and val.name in args[attr_name]: + fields[val.name] = self._handle_nested_class( + args, args[attr_name][val.name], class_names + ) + # If not named then just check for keys -- checking for global def + elif val.name in args: + fields[val.name] = self._handle_nested_class( + args, args[val.name], class_names + ) + # Check for special keys to set + if ( + "special_key" in val.metadata + and val.metadata["special_key"] is not None + ): + if val.name in args: + self.save_path = args[val.name] + elif val.default is not None: + self.save_path = val.default + return fields + + def _handle_repeated(self, args, check_value, class_names): + """Handles repeated classes as lists + + *Args*: + + args: dictionary of arguments from the configs + check_value: value to check classes against + class_names: current class names + + *Returns*: + + list of input_class[match)idx[0]] types filled with repeated values + + """ + # Check to see if the value trying to be set is actually an input class + match_idx = [idx for idx, val in enumerate(class_names) if val == check_value] + return [self.input_classes[match_idx[0]](**val) for val in args] + + def _handle_nested_class(self, args, check_value, class_names): + """Handles passing another class to the field dictionary + + *Args*: + args: dictionary of arguments from the configs + check_value: value to check classes against + class_names: current class names + + *Returns*: + + either the check_value or the necessary class + + """ + # Check to see if the value trying to be set is actually an input class + match_idx = [idx for idx, val 
in enumerate(class_names) if val == check_value] + # If so then create the needed class object by unrolling the args to **kwargs and return it + if len(match_idx) > 0: + if len(match_idx) > 1: + raise ValueError( + "Match error -- multiple classes with the same name definition" + ) + else: + if args.get(self.input_classes[match_idx[0]].__name__) is None: + raise ValueError( + f"Missing config file definition for the referenced class " + f"{self.input_classes[match_idx[0]].__name__}" + ) + current_arg = args.get(self.input_classes[match_idx[0]].__name__) + if isinstance(current_arg, list): + class_value = [ + self.input_classes[match_idx[0]](**val) for val in current_arg + ] + else: + class_value = self.input_classes[match_idx[0]](**current_arg) + return_value = class_value + # else return the expected value + else: + return_value = check_value + return return_value + + def _extract_fnc(self, val, module_name): + """Function that gets the nested lists within classes + + *Args*: + + val: current attr + module_name: matching module name + + *Returns*: + + list of any nested classes/enums + + """ + return ( + self._extract_other_types(val.metadata["type"], module_name) + if "type" in val.metadata + else [] + ) diff --git a/spock/backend/attr/config.py b/spock/backend/config.py similarity index 67% rename from spock/backend/attr/config.py rename to spock/backend/config.py index a081d14d..9f481294 100644 --- a/spock/backend/attr/config.py +++ b/spock/backend/config.py @@ -1,16 +1,18 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Creates the spock config interface that wraps attr""" import sys + import attr -from spock.backend.attr.typed import katra +from spock.backend.typed import katra -def spock_attr(cls): + +def _base_attr(cls): """Map type hints to katras Connector function that maps type hinting style to the defined katra style which uses the more strict @@ -24,7 +26,6 @@ def spock_attr(cls): cls: slotted attrs class that is frozen and kw only """ - # Since we are not using the @attr.s decorator we need to get the parent classes for inheritance # We do this by using the mro and grabbing anything that is not the first and last indices in the list and wrapping # it into a tuple @@ -35,7 +36,7 @@ def spock_attr(cls): bases = () # Make a blank attrs dict for new attrs attrs_dict = {} - if hasattr(cls, '__annotations__'): + if hasattr(cls, "__annotations__"): for k, v in cls.__annotations__.items(): # If the cls has the attribute then a default was set if hasattr(cls, k): @@ -43,10 +44,30 @@ def spock_attr(cls): else: default = None attrs_dict.update({k: katra(typed=v, default=default)}) + return bases, attrs_dict + + +def spock_attr(cls): + """Map type hints to katras + + Connector function that maps type hinting style to the defined katra style which uses the more strict + attr.ib() definition + + *Args*: + + cls: basic class def + + *Returns*: + + cls: slotted attrs class that is frozen and kw only + """ + bases, attrs_dict = _base_attr(cls) # Dynamically make an attr class - obj = attr.make_class(name=cls.__name__, bases=bases, attrs=attrs_dict, kw_only=True, frozen=True) + obj = attr.make_class( + name=cls.__name__, bases=bases, attrs=attrs_dict, kw_only=True, frozen=True + ) # For each class we dynamically create we need to register it within the system modules for pickle to work - setattr(sys.modules['spock'].backend.attr.config, obj.__name__, obj) + setattr(sys.modules["spock"].backend.config, obj.__name__, obj) # Swap 
the __doc__ string from cls to obj obj.__doc__ = cls.__doc__ return obj diff --git a/spock/backend/handler.py b/spock/backend/handler.py new file mode 100644 index 00000000..534b9b38 --- /dev/null +++ b/spock/backend/handler.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Base handler Spock class""" + +from abc import ABC + +from spock.handlers import JSONHandler, TOMLHandler, YAMLHandler + + +class BaseHandler(ABC): + """Base class for saver and payload + + *Attributes*: + + _writers: maps file extension to the correct i/o handler + _s3_config: optional S3Config object to handle s3 access + + """ + + def __init__(self, s3_config=None): + self._supported_extensions = { + ".yaml": YAMLHandler, + ".toml": TOMLHandler, + ".json": JSONHandler, + } + self._s3_config = s3_config + + def _check_extension(self, file_extension: str): + if file_extension not in self._supported_extensions: + raise TypeError( + f"File extension {file_extension} not supported -- \n" + f"File extension must be from {list(self._supported_extensions.keys())}" + ) diff --git a/spock/backend/payload.py b/spock/backend/payload.py new file mode 100644 index 00000000..96bf0ba2 --- /dev/null +++ b/spock/backend/payload.py @@ -0,0 +1,487 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles payloads from markup files""" + +import os +import sys +from abc import abstractmethod +from itertools import chain +from pathlib import Path + +from spock.backend.handler import BaseHandler +from spock.backend.utils import convert_to_tuples, deep_update, get_type_fields +from spock.utils import check_path_s3 + + +class BasePayload(BaseHandler): # pylint: disable=too-few-public-methods + """Handles building the payload for config file(s) + + This class builds out the payload from config files of multiple types. 
It handles various + file types and also composition of config files via recursive calls + + *Attributes*: + + _loaders: maps of each file extension to the loader class + _s3_config: optional S3Config object to handle s3 access + + """ + + def __init__(self, s3_config=None): + super(BasePayload, self).__init__(s3_config=s3_config) + + @staticmethod + @abstractmethod + def _update_payload(base_payload, input_classes, ignore_classes, payload): + """Updates the payload + + Checks the parameters defined in the config files against the provided classes and if + passable adds them to the payload + + *Args*: + + base_payload: current payload + input_classes: class to roll into + ignore_classes: list of classes to ignore + payload: total payload + + *Returns*: + + payload: updated payload + + """ + + def payload(self, input_classes, ignore_classes, path, cmd_args, deps): + """Builds the payload from config files + + Public exposed call to build the payload and set any command line overrides + + *Args*: + + input_classes: list of backend classes + ignore_classes: list of classes to ignore + path: path to config file(s) + cmd_args: command line overrides + deps: dictionary of config dependencies + + *Returns*: + + payload: dictionary of all mapped parameters + + """ + payload = self._payload(input_classes, ignore_classes, path, deps, root=True) + payload = self._handle_overrides(payload, ignore_classes, cmd_args) + return payload + + def _payload(self, input_classes, ignore_classes, path, deps, root=False): + """Private call to construct the payload + + Main function call that builds out the payload from config files of multiple types. It handles + various file types and also composition of config files via recursive calls + + *Args*: + input_classes: list of backend classes + ignore_classes: list of classes to ignore + path: path to config file(s) + deps: dictionary of config dependencies + + *Returns*: + + payload: dictionary of all mapped parameters + + """ + # empty payload + payload = {} + if path is not None: + # Match to loader based on file-extension + config_extension = Path(path).suffix.lower() + # Verify extension + self._check_extension(file_extension=config_extension) + # Load from file + base_payload = self._supported_extensions.get(config_extension)().load( + path, s3_config=self._s3_config + ) + # Check and update the dependencies + deps = self._handle_dependencies(deps, path, root) + if "config" in base_payload: + payload = self._handle_includes( + base_payload, + config_extension, + input_classes, + ignore_classes, + path, + payload, + deps, + ) + payload = self._update_payload( + base_payload, input_classes, ignore_classes, payload + ) + return payload + + @staticmethod + def _handle_dependencies(deps, path, root): + """Handles config file dependencies + + Checks to see if the config path (full or relative) has already been encountered. Essentially a DFS for graph + cycles + + *Args*: + + deps: dictionary of config dependencies + path: current config path + root: boolean if root + + *Returns*: + + deps: updated dependencies + + """ + if root and path in deps.get("paths"): + raise ValueError( + f"Duplicate Read -- Config file {path} has already been encountered. " + f"Please remove duplicate reads of config files." + ) + elif path in deps.get("paths") or path in deps.get("rel_paths"): + raise ValueError( + f"Cyclical Dependency -- Config file {path} has already been encountered. " + f"Please remove cyclical dependencies between config files."
+ ) + else: + # Update the dependency lists + deps.get("paths").append(path) + deps.get("rel_paths").append(os.path.basename(path)) + if root: + deps.get("roots").append(path) + return deps + + def _handle_includes( + self, + base_payload, + config_extension, + input_classes, + ignore_classes, + path, + payload, + deps, + ): # pylint: disable=too-many-arguments + """Handles config composition + + For all of the config tags in the config file this function will recursively call the payload function + with the composition path to get the additional payload(s) from the composed file(s) -- checks for file + validity or if it is an S3 URI via regex + + *Args*: + + base_payload: base payload that has a config kwarg + config_extension: file type + input_classes: defined backend classes + ignore_classes: list of classes to ignore + path: path to base file + payload: payload pulled from composed files + deps: dictionary of config dependencies + + *Returns*: + + payload: payload update from composed files + + """ + included_params = {} + for inc_path in base_payload["config"]: + if check_path_s3(inc_path): + use_path = inc_path + elif os.path.exists(inc_path): + use_path = inc_path + elif os.path.join(os.path.dirname(path), inc_path): + use_path = os.path.join(os.path.dirname(path), inc_path) + else: + raise RuntimeError( + f"Could not find included {config_extension} file {inc_path} or is not an S3 URI!" + ) + included_params.update( + self._payload(input_classes, ignore_classes, use_path, deps) + ) + payload.update(included_params) + return payload + + def _handle_overrides(self, payload, ignore_classes, args): + """Handle command line overrides + + Iterate through the command line override values, determine at what level to set them, and set them if possible + + *Args*: + + payload: current payload dictionary + args: command line override args + + *Returns*: + + payload: updated payload dictionary with override values set + + """ + skip_keys = ["config", "help"] + pruned_args = self._prune_args(args, ignore_classes) + for k, v in pruned_args.items(): + if k not in skip_keys and v is not None: + payload = self._handle_payload_override(payload, k, v) + return payload + + @staticmethod + def _prune_args(args, ignore_classes): + """Prunes ignored class names from the cmd line args list to prevent incorrect access + + *Args*: + + args: current cmd line args + ignore_classes: list of class names to ignore + + *Returns*: + + dictionary of pruned cmd line args + + """ + ignored_stems = [val.__name__ for val in ignore_classes] + return { + k: v for k, v in vars(args).items() if k.split(".")[0] not in ignored_stems + } + + @staticmethod + @abstractmethod + def _handle_payload_override(payload, key, value): + """Handles the complex logic needed for List[spock class] overrides + + Messy logic that sets overrides for the various different types. The hardest being List[spock class] since str + names have to be mapped backed to sys.modules and can be set at either the general or class level. + + *Args*: + + payload: current payload dictionary + key: current arg key + value: value at current arg key + + *Returns*: + + payload: modified payload with overrides + + """ + + +class AttrPayload(BasePayload): + """Handles building the payload for attrs backend + + This class builds out the payload from config files of multiple types. 
It handles various + file types and also composition of config files via a recursive calls + + *Attributes*: + + _loaders: maps of each file extension to the loader class + + """ + + def __init__(self, s3_config=None): + """Init for AttrPayload + + *Args*: + + s3_config: optional S3 config object + + """ + super().__init__(s3_config=s3_config) + + def __call__(self, *args, **kwargs): + """Call to allow self chaining + + *Args*: + + *args: + **kwargs: + + *Returns*: + + Payload: instance of self + + """ + return AttrPayload(*args, **kwargs) + + @staticmethod + def _update_payload(base_payload, input_classes, ignore_classes, payload): + # Get basic args + attr_fields = { + attr.__name__: [val.name for val in attr.__attrs_attrs__] + for attr in input_classes + } + # Get the ignore fields + ignore_fields = { + attr.__name__: [val.name for val in attr.__attrs_attrs__] + for attr in ignore_classes + } + # Class names + class_names = [val.__name__ for val in input_classes] + # Parse out the types if generic + type_fields = get_type_fields(input_classes) + for keys, values in base_payload.items(): + if keys not in ignore_fields: + # check if the keys, value pair is expected by the attr class + if keys != "config": + # Dict infers that we are overriding a global setting in a specific config + if isinstance(values, dict): + # we're in a namespace + # Check for incorrect specific override of global def + if keys not in attr_fields: + raise TypeError( + f"Referring to a class space {keys} that is undefined" + ) + for i_keys in values.keys(): + if i_keys not in attr_fields[keys]: + raise ValueError( + f"Provided an unknown argument named {keys}.{i_keys}" + ) + else: + # Check if the key is actually a reference to another class + if keys in class_names: + if isinstance(values, list): + # Check for incorrect specific override of global def + if keys not in attr_fields: + raise ValueError( + f"Referring to a class space {keys} that is undefined" + ) + # We are in a repeated class def + # Raise if the key set is different from the defined set (i.e. incorrect arguments) + key_set = set( + list(chain(*[list(val.keys()) for val in values])) + ) + for i_keys in key_set: + if i_keys not in attr_fields[keys]: + raise ValueError( + f"Provided an unknown argument named {keys}.{i_keys}" + ) + # Chain all the values from multiple spock classes into one list + elif keys not in list(chain(*attr_fields.values())): + raise ValueError( + f"Provided an unknown argument named {keys}" + ) + # Chain all the values from multiple spock classes into one list + elif keys not in list(chain(*attr_fields.values())): + raise ValueError( + f"Provided an unknown argument named {keys}" + ) + if keys in payload and isinstance(values, dict): + payload[keys].update(values) + else: + payload[keys] = values + tuple_payload = convert_to_tuples(payload, type_fields, class_names) + payload = deep_update(payload, tuple_payload) + return payload + + @staticmethod + def _handle_payload_override(payload, key, value): + """Handles the complex logic needed for List[spock class] overrides + + Messy logic that sets overrides for the various different types. The hardest being List[spock class] since str + names have to be mapped backed to sys.modules and can be set at either the general or class level. 
+ + *Args*: + + payload: current payload dictionary + key: current arg key + value: value at current arg key + + *Returns*: + + payload: modified payload with overrides + + """ + key_split = key.split(".") + curr_ref = payload + # Handle non existing parts of the payload for specific cases + root_classes = [ + idx + for idx, val in enumerate(key_split) + if hasattr(sys.modules["spock"].backend.config, val) + ] + # Verify any classes have roots in the payload dict + for idx in root_classes: + # Update all root classes if not present + if key_split[idx] not in payload: + payload.update({key_split[idx]: {}}) + # If not updating the root then it is a reference to another class which might not be in the payload + # Make sure it's there by setting it -- since this is an override setting is fine as these should be the + # final say in the param values so don't worry about clashing + if idx != 0: + payload[key_split[0]][key_split[idx - 1]] = key_split[idx] + # Check also for repeated classes -- value will be a list when the type is not + var = getattr( + getattr( + sys.modules["spock"].backend.config, key_split[idx] + ).__attrs_attrs__, + key_split[-1], + ) + if isinstance(value, list) and var.type != list: + # If the dict is blank we need to handle the creation of the list of dicts + if len(payload[key_split[idx]]) == 0: + payload.update( + { + key_split[idx]: [ + {key_split[-1]: None} for _ in range(len(value)) + ] + } + ) + # If it's already partially filled we need to update not overwrite + else: + for val in payload[key_split[idx]]: + val.update({key_split[-1]: None}) + + for idx, split in enumerate(key_split): + # Check for curr_ref switch over -- verify by checking the sys modules names + if ( + idx != 0 + and (split in payload) + and (isinstance(curr_ref, str)) + and (hasattr(sys.modules["spock"].backend.config, split)) + ): + curr_ref = payload[split] + # Look ahead to check if the next value exists in the dictionary + elif ( + idx != 0 + and (split in payload) + and (isinstance(payload[split], str)) + and (hasattr(sys.modules["spock"].backend.config, payload[split])) + ): + curr_ref = payload[split] + # elif check if it's the last value and figure out the override + elif idx == (len(key_split) - 1): + # Handle bool(s) a bit differently as they are store_true + if isinstance(curr_ref, dict) and isinstance(value, bool): + if value is not False: + curr_ref[split] = value + # If we are at the dictionary level we should be able to just payload override + elif isinstance(curr_ref, dict) and not isinstance(value, bool): + curr_ref[split] = value + # If we are at a list level it must be some form of repeated class since this is the end of the class + # tree -- check the instance type but also make sure the cmd-line override is the correct len + elif isinstance(curr_ref, list) and len(value) == len(curr_ref): + # Walk the list and check for the key + for ref_idx, val in enumerate(curr_ref): + if split in val: + val[split] = value[ref_idx] + else: + raise ValueError( + f"cmd-line override failed for {key} -- " + f"Failed to find key {split} within lowest level List[Dict]" + ) + elif isinstance(curr_ref, list) and len(value) != len(curr_ref): + raise ValueError( + f"cmd-line override failed for {key} -- " + f"Specified key {split} with len {len(value)} does not match len {len(curr_ref)} " + f"of List[Dict]" + ) + else: + raise ValueError( + f"cmd-line override failed for {key} -- " + f"Failed to find key {split} within lowest level Dict" + ) + # If it's not keep walking the current payload + else: 
+ curr_ref = curr_ref[split] + return payload diff --git a/spock/backend/saver.py b/spock/backend/saver.py new file mode 100644 index 00000000..63d9ff80 --- /dev/null +++ b/spock/backend/saver.py @@ -0,0 +1,257 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles prepping and saving the Spock config""" + +from abc import abstractmethod +from uuid import uuid1 + +import attr + +from spock.backend.handler import BaseHandler +from spock.utils import add_info + + +class BaseSaver(BaseHandler): # pylint: disable=too-few-public-methods + """Base class for saving configs + + Contains methods to build a correct output payload and then writes to file based on the file + extension + + *Attributes*: + + _writers: maps file extension to the correct i/o handler + _s3_config: optional S3Config object to handle s3 access + + """ + + def __init__(self, s3_config=None): + super(BaseSaver, self).__init__(s3_config=s3_config) + + def save( + self, + payload, + path, + file_name=None, + create_save_path=False, + extra_info=True, + file_extension=".yaml", + ): # pylint: disable=too-many-arguments + """Writes Spock config to file + + Cleans and builds an output payload and then correctly writes it to file based on the + specified file extension + + *Args*: + + payload: current config payload + path: path to save + file_name: name of file (will be appended with .spock.cfg.file_extension) -- falls back to uuid if None + create_save_path: boolean to create the path if non-existent + extra_info: boolean to write extra info + file_extension: what type of file to write + + *Returns*: + + None + + """ + # Check extension + self._check_extension(file_extension=file_extension) + # Make the filename -- always append a uuid for unique-ness + uuid_str = str(uuid1()) + fname = "" if file_name is None else f"{file_name}." 
+ name = f"{fname}{uuid_str}.spock.cfg{file_extension}" + # Fix up values -- parameters + out_dict = self._clean_up_values(payload, file_extension) + # Get extra info + extra_dict = add_info() if extra_info else None + try: + self._supported_extensions.get(file_extension)().save( + out_dict=out_dict, + info_dict=extra_dict, + path=str(path), + name=name, + create_path=create_save_path, + s3_config=self._s3_config, + ) + except OSError as e: + print(f"Unable to write to given path: {path / name}") + raise e + + @abstractmethod + def _clean_up_values(self, payload, file_extension): + """Clean up the config payload so it can be written to file + + *Args*: + + payload: dirty payload + extra_info: boolean to add extra info + file_extension: type of file to write + + *Returns*: + + clean_dict: cleaned output payload + + """ + + def _clean_output(self, out_dict): + """Clean up the dictionary so it can be written to file + + *Args*: + + out_dict: cleaned dictionary + extra_info: boolean to add extra info + + *Returns*: + + clean_dict: cleaned output payload + + """ + # Convert values + clean_dict = {} + for key, val in out_dict.items(): + clean_inner_dict = {} + if isinstance(val, list): + for idx, list_val in enumerate(val): + tmp_dict = {} + for inner_key, inner_val in list_val.items(): + tmp_dict = self._convert(tmp_dict, inner_val, inner_key) + val[idx] = tmp_dict + clean_inner_dict = val + else: + for inner_key, inner_val in val.items(): + clean_inner_dict = self._convert( + clean_inner_dict, inner_val, inner_key + ) + clean_dict.update({key: clean_inner_dict}) + return clean_dict + + def _convert(self, clean_inner_dict, inner_val, inner_key): + # Convert tuples to lists so they get written correctly + if isinstance(inner_val, tuple): + clean_inner_dict.update( + {inner_key: self._recursive_tuple_to_list(inner_val)} + ) + elif inner_val is not None: + clean_inner_dict.update({inner_key: inner_val}) + return clean_inner_dict + + def _recursive_tuple_to_list(self, value): + """Recursively turn tuples into lists + + Recursively looks through tuple(s) and convert to lists + + *Args*: + + value: value to check and set typ if necessary + typed: type of the generic alias to check against + + *Returns*: + + value: updated value with correct type casts + + """ + # Check for __args__ as it signifies a generic and make sure it's not already been cast as a tuple + # from a composed payload + list_v = [] + for v in value: + if isinstance(v, tuple): + v = self._recursive_tuple_to_list(v) + list_v.append(v) + else: + list_v.append(v) + return list_v + + +class AttrSaver(BaseSaver): + """Base class for saving configs for the attrs backend + + Contains methods to build a correct output payload and then writes to file based on the file + extension + + *Attributes*: + + _writers: maps file extension to the correct i/o handler + + """ + + def __init__(self, s3_config=None): + super().__init__(s3_config=s3_config) + + def __call__(self, *args, **kwargs): + return AttrSaver(*args, **kwargs) + + def _clean_up_values(self, payload, file_extension): + # Dictionary to recursively write to + out_dict = {} + # All of the classes are defined at the top level + all_spock_cls = set(vars(payload).keys()) + out_dict = self._recursively_handle_clean( + payload, out_dict, all_cls=all_spock_cls + ) + # Convert values + clean_dict = self._clean_output(out_dict) + return clean_dict + + def _recursively_handle_clean( + self, payload, out_dict, parent_name=None, all_cls=None + ): + """Recursively works through spock classes and 
adds clean data to a dictionary + + Given a payload (Spockspace) work recursively through items that don't have parents to catch all + parameter definitions while correctly mapping nested class definitions to their base level class thus + allowing the output markdown to be a valid input file + + *Args*: + + payload: current payload (namespace) + out_dict: output dictionary + parent_name: name of the parent spock class if nested + all_cls: all top level spock class definitions + + *Returns*: + + out_dict: modified dictionary with the cleaned data + + """ + for key, val in vars(payload).items(): + val_name = type(val).__name__ + # This catches basic lists and list of classes + if isinstance(val, list): + # Check if each entry is a spock class + clean_val = [] + repeat_flag = False + for l_val in val: + cls_name = type(l_val).__name__ + # For those that are a spock class and are repeated (cls_name == key) simply convert to dict + if (cls_name in all_cls) and (cls_name == key): + clean_val.append(attr.asdict(l_val)) + # For those whose cls is different than the key just append the cls name + elif cls_name in all_cls: + # Change the flag as this is a repeated class -- which needs to be compressed into a single + # k:v pair + repeat_flag = True + clean_val.append(cls_name) + # Fall back to the passed in values + else: + clean_val.append(l_val) + # Handle repeated classes + if repeat_flag: + clean_val = list(set(clean_val))[-1] + out_dict.update({key: clean_val}) + # If it's a spock class but has a parent then just use the class name to reference the values + elif (val_name in all_cls) and parent_name is not None: + out_dict.update({key: val_name}) + # Check if it's a spock class without a parent -- iterate the values and recurse to catch more lists + elif val_name in all_cls: + new_dict = self._recursively_handle_clean( + val, {}, parent_name=key, all_cls=all_cls + ) + out_dict.update({key: new_dict}) + # Either base type or no nested values that could be Spock classes + else: + out_dict.update({key: val}) + return out_dict diff --git a/spock/backend/attr/typed.py b/spock/backend/typed.py similarity index 72% rename from spock/backend/attr/typed.py rename to spock/backend/typed.py index 8452e111..fcb8678b 100644 --- a/spock/backend/attr/typed.py +++ b/spock/backend/typed.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Handles the definitions of arguments types for Spock (backend: attrs)""" @@ -8,9 +8,10 @@ import sys from enum import EnumMeta from functools import partial -from typing import TypeVar -from typing import Union +from typing import TypeVar, Union + import attr + minor = sys.version_info.minor if minor < 7: from typing import GenericMeta as _GenericAlias @@ -24,6 +25,7 @@ class SavePath(str): Defines a special key use to save the current Spock config to file """ + def __new__(cls, x): return super().__new__(cls, x) @@ -40,7 +42,7 @@ def _get_name_py_version(typed): name of the type """ - return typed._name if hasattr(typed, '_name') else typed.__name__ + return typed._name if hasattr(typed, "_name") else typed.__name__ def _extract_base_type(typed): @@ -57,7 +59,7 @@ def _extract_base_type(typed): name of type """ - if hasattr(typed, '__args__'): + if hasattr(typed, "__args__"): name = _get_name_py_version(typed=typed) bracket_val = f"{name}[{_extract_base_type(typed.__args__[0])}]" return bracket_val @@ -81,19 +83,28 @@ def _recursive_generic_validator(typed): return_type: recursively built 
deep_iterable validators """ - if hasattr(typed, '__args__'): + if hasattr(typed, "__args__"): # If there are more __args__ then we still need to recurse as it is still a GenericAlias - return_type = attr.validators.deep_iterable( - member_validator=_recursive_generic_validator(typed.__args__[0]), - iterable_validator=attr.validators.instance_of(typed.__origin__) - ) + # Iterate through since there might be multiple types? + if len(typed.__args__) > 1: + return_type = attr.validators.deep_iterable( + member_validator=_recursive_generic_validator(typed.__args__), + iterable_validator=attr.validators.instance_of(typed.__origin__), + ) + else: + return_type = attr.validators.deep_iterable( + member_validator=_recursive_generic_validator(typed.__args__[0]), + iterable_validator=attr.validators.instance_of(typed.__origin__), + ) return return_type else: # If no more __args__ then we are to the base type and need to bubble up the type # But we need to check against base types and enums if isinstance(typed, EnumMeta): base_type, allowed = _check_enum_props(typed) - return_type = attr.validators.and_(attr.validators.instance_of(base_type), attr.validators.in_(allowed)) + return_type = attr.validators.and_( + attr.validators.instance_of(base_type), attr.validators.in_(allowed) + ) else: return_type = attr.validators.instance_of(typed) return return_type @@ -122,19 +133,34 @@ def _generic_alias_katra(typed, default=None, optional=False): # base python class from which a GenericAlias is derived base_typed = typed.__origin__ if default is not None: - x = attr.ib(validator=_recursive_generic_validator(typed), default=default, type=base_typed, - metadata={'base': _extract_base_type(typed), 'type': typed}) + x = attr.ib( + validator=_recursive_generic_validator(typed), + default=default, + type=base_typed, + metadata={"base": _extract_base_type(typed), "type": typed}, + ) # x = attr.ib(validator=_recursive_generic_iterator(typed), default=default, type=base_typed, # metadata={'base': _extract_base_type(typed)}) elif optional: # if there's no default, but marked as optional, then set the default to None - x = attr.ib(validator=attr.validators.optional(_recursive_generic_validator(typed)), type=base_typed, - default=default, metadata={'optional': True, 'base': _extract_base_type(typed), 'type': typed}) + x = attr.ib( + validator=attr.validators.optional(_recursive_generic_validator(typed)), + type=base_typed, + default=default, + metadata={ + "optional": True, + "base": _extract_base_type(typed), + "type": typed, + }, + ) # x = attr.ib(validator=attr.validators.optional(_recursive_generic_iterator(typed)), type=base_typed, # default=default, metadata={'optional': True, 'base': _extract_base_type(typed)}) else: - x = attr.ib(validator=_recursive_generic_validator(typed), type=base_typed, - metadata={'base': _extract_base_type(typed), 'type': typed}) + x = attr.ib( + validator=_recursive_generic_validator(typed), + type=base_typed, + metadata={"base": _extract_base_type(typed), "type": typed}, + ) # x = attr.ib(validator=_recursive_generic_iterator(typed), type=base_typed, # metadata={'base': _extract_base_type(typed)}) return x @@ -184,10 +210,18 @@ def _enum_katra(typed, default=None, optional=False): """ # First check if the types of Enum are the same base_type, allowed = _check_enum_props(typed) - if base_type.__name__ == 'type': - x = _enum_class_katra(typed=typed, allowed=allowed, default=default, optional=optional) + if base_type.__name__ == "type": + x = _enum_class_katra( + typed=typed, 
allowed=allowed, default=default, optional=optional + ) else: - x = _enum_base_katra(typed=typed, base_type=base_type, allowed=allowed, default=default, optional=optional) + x = _enum_base_katra( + typed=typed, + base_type=base_type, + allowed=allowed, + default=default, + optional=optional, + ) return x @@ -214,15 +248,32 @@ def _enum_base_katra(typed, base_type, allowed, default=None, optional=False): """ if default is not None: x = attr.ib( - validator=[attr.validators.instance_of(base_type), attr.validators.in_(allowed)], - default=default, type=typed, metadata={'base': typed.__name__}) + validator=[ + attr.validators.instance_of(base_type), + attr.validators.in_(allowed), + ], + default=default, + type=typed, + metadata={"base": typed.__name__}, + ) elif optional: x = attr.ib( - validator=attr.validators.optional([attr.validators.instance_of(base_type), attr.validators.in_(allowed)]), - default=default, type=typed, metadata={'base': typed.__name__, 'optional': True}) + validator=attr.validators.optional( + [attr.validators.instance_of(base_type), attr.validators.in_(allowed)] + ), + default=default, + type=typed, + metadata={"base": typed.__name__, "optional": True}, + ) else: - x = attr.ib(validator=[attr.validators.instance_of(base_type), attr.validators.in_(allowed)], type=typed, - metadata={'base': typed.__name__}) + x = attr.ib( + validator=[ + attr.validators.instance_of(base_type), + attr.validators.in_(allowed), + ], + type=typed, + metadata={"base": typed.__name__}, + ) return x @@ -242,7 +293,7 @@ def _in_type(instance, attribute, value, options): """ if type(value) not in options: - raise ValueError(f'{attribute.name} must be in {options}') + raise ValueError(f"{attribute.name} must be in {options}") def _enum_class_katra(typed, allowed, default=None, optional=False): @@ -269,14 +320,24 @@ def _enum_class_katra(typed, allowed, default=None, optional=False): """ if default is not None: x = attr.ib( - validator=[partial(_in_type, options=allowed)], default=default, type=typed, - metadata={'base': typed.__name__}) + validator=[partial(_in_type, options=allowed)], + default=default, + type=typed, + metadata={"base": typed.__name__}, + ) elif optional: x = attr.ib( validator=attr.validators.optional([partial(_in_type, options=allowed)]), - default=default, type=typed, metadata={'base': typed.__name__, 'optional': True}) + default=default, + type=typed, + metadata={"base": typed.__name__, "optional": True}, + ) else: - x = attr.ib(validator=[partial(_in_type, options=allowed)], type=typed, metadata={'base': typed.__name__}) + x = attr.ib( + validator=[partial(_in_type, options=allowed)], + type=typed, + metadata={"base": typed.__name__}, + ) return x @@ -306,7 +367,7 @@ def _type_katra(typed, default=None, optional=False): elif isinstance(typed, _GenericAlias): name = _get_name_py_version(typed=typed) else: - raise TypeError('Encountered an unexpected type in _type_katra') + raise TypeError("Encountered an unexpected type in _type_katra") special_key = None # Default booleans to false and optional due to the nature of a boolean if isinstance(typed, type) and name == "bool": @@ -320,14 +381,25 @@ def _type_katra(typed, default=None, optional=False): typed = str if default is not None: # if a default is provided, that takes precedence - x = attr.ib(validator=attr.validators.instance_of(typed), default=default, type=typed, - metadata={'base': name, 'special_key': special_key}) + x = attr.ib( + validator=attr.validators.instance_of(typed), + default=default, + type=typed, + 
metadata={"base": name, "special_key": special_key}, + ) elif optional: - x = attr.ib(validator=attr.validators.optional(attr.validators.instance_of(typed)), default=default, type=typed, - metadata={'optional': True, 'base': name, 'special_key': special_key}) + x = attr.ib( + validator=attr.validators.optional(attr.validators.instance_of(typed)), + default=default, + type=typed, + metadata={"optional": True, "base": name, "special_key": special_key}, + ) else: - x = attr.ib(validator=attr.validators.instance_of(typed), type=typed, metadata={'base': name, - 'special_key': special_key}) + x = attr.ib( + validator=attr.validators.instance_of(typed), + type=typed, + metadata={"base": name, "special_key": special_key}, + ) return x @@ -349,7 +421,7 @@ def _handle_optional_typing(typed): # Set optional to false optional = False # Check if it has __args__ to look for optionality as it is a GenericAlias - if hasattr(typed, '__args__'): + if hasattr(typed, "__args__"): # If it is more than one than it is most likely optional but check against NoneType in the tuple to verify # Check the length of type __args__ type_args = typed.__args__ @@ -364,6 +436,8 @@ def _handle_optional_typing(typed): def _check_generic_recursive_single_type(typed): """Checks generics for the single types -- mixed types of generics are not allowed + DEPRECATED -- NOW SUPPORTS MIXED TYPES OF TUPLES + *Args*: typed: type @@ -372,13 +446,14 @@ def _check_generic_recursive_single_type(typed): """ # Check if it has __args__ to look for optionality as it is a GenericAlias - if hasattr(typed, '__args__'): - if len(set(typed.__args__)) > 1: - type_list = [str(val) for val in typed.__args__] - raise TypeError(f"Passing multiple different subscript types to GenericAlias is not supported: {type_list}") - else: - for val in typed.__args__: - _check_generic_recursive_single_type(typed=val) + # if hasattr(typed, '__args__'): + # if len(set(typed.__args__)) > 1: + # type_list = [str(val) for val in typed.__args__] + # raise TypeError(f"Passing multiple different subscript types to GenericAlias is not supported: {type_list}") + # else: + # for val in typed.__args__: + # _check_generic_recursive_single_type(typed=val) + pass def katra(typed, default=None): @@ -405,7 +480,9 @@ def katra(typed, default=None): _check_generic_recursive_single_type(typed) # We need to check if the type is a _GenericAlias so that we can handle subscripted general types # If it is subscript typed it will not be T which python uses as a generic type name - if isinstance(typed, _GenericAlias) and (not isinstance(typed.__args__[0], TypeVar)): + if isinstance(typed, _GenericAlias) and ( + not isinstance(typed.__args__[0], TypeVar) + ): x = _generic_alias_katra(typed=typed, default=default, optional=optional) elif isinstance(typed, EnumMeta): x = _enum_katra(typed=typed, default=default, optional=optional) diff --git a/spock/backend/attr/utils.py b/spock/backend/utils.py similarity index 83% rename from spock/backend/attr/utils.py rename to spock/backend/utils.py index 51e84e63..c4f1ed3d 100644 --- a/spock/backend/attr/utils.py +++ b/spock/backend/utils.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Attr utility functions for Spock""" @@ -23,8 +23,8 @@ def get_type_fields(input_classes): for attr in input_classes: input_attr = {} for val in attr.__attrs_attrs__: - if 'type' in val.metadata: - input_attr.update({val.name: val.metadata['type']}) + if "type" in val.metadata: + 
input_attr.update({val.name: val.metadata["type"]}) else: input_attr.update({val.name: None}) type_fields.update({attr.__name__: input_attr}) @@ -73,14 +73,16 @@ def convert_to_tuples(input_dict, named_type_dict, class_names): updated_dict = {} all_typed_dict = flatten_type_dict(named_type_dict) for k, v in input_dict.items(): - if k != 'config': + if k != "config": if isinstance(v, dict): updated = convert_to_tuples(v, named_type_dict.get(k), class_names) if updated: updated_dict.update({k: updated}) elif isinstance(v, list) and k in class_names: for val in v: - updated = convert_to_tuples(val, named_type_dict.get(k), class_names) + updated = convert_to_tuples( + val, named_type_dict.get(k), class_names + ) if updated: updated_dict.update({k: updated}) elif all_typed_dict[k] is not None: @@ -133,18 +135,25 @@ def _recursive_list_to_tuple(value, typed, class_names): """ # Check for __args__ as it signifies a generic and make sure it's not already been cast as a tuple # from a composed payload - if hasattr(typed, '__args__') and not isinstance(value, tuple) and not (isinstance(value, str) - and value in class_names): + if ( + hasattr(typed, "__args__") + and not isinstance(value, tuple) + and not (isinstance(value, str) and value in class_names) + ): # Force those with origin tuple types to be of the defined length - if (typed.__origin__.__name__.lower() == 'tuple') and len(value) != len(typed.__args__): - raise ValueError(f'Tuple(s) use a fixed/defined length -- Length of the provided argument ({len(value)}) ' - f'does not match the length of the defined argument ({len(typed.__args__)})') + if (typed.__origin__.__name__.lower() == "tuple") and len(value) != len( + typed.__args__ + ): + raise ValueError( + f"Tuple(s) use a fixed/defined length -- Length of the provided argument ({len(value)}) " + f"does not match the length of the defined argument ({len(typed.__args__)})" + ) # need to recurse before casting as we can't set values in a tuple with idx # Since it's generic it should be iterable to recurse and check it's children for idx, val in enumerate(value): value[idx] = _recursive_list_to_tuple(val, typed.__args__[0], class_names) # First check if list and then swap to tuple if the origin is tuple - if isinstance(value, list) and typed.__origin__.__name__.lower() == 'tuple': + if isinstance(value, list) and typed.__origin__.__name__.lower() == "tuple": value = tuple(value) else: return value diff --git a/spock/backend/wrappers.py b/spock/backend/wrappers.py new file mode 100644 index 00000000..ffe2b9c6 --- /dev/null +++ b/spock/backend/wrappers.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +# Copyright FMR LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Handles Spock data type wrappers""" + +import argparse + +import yaml + + +class Spockspace(argparse.Namespace): + """Inherits from Namespace to implement a pretty print on the obj + + Overwrites the __repr__ method with a pretty version of printing + + """ + + def __init__(self, **kwargs): + super(Spockspace, self).__init__(**kwargs) + + def __repr__(self): + # Remove aliases in YAML print + yaml.Dumper.ignore_aliases = lambda *args: True + return yaml.dump(self.__dict__, default_flow_style=False) diff --git a/spock/builder.py b/spock/builder.py index d6782064..5ca36245 100644 --- a/spock/builder.py +++ b/spock/builder.py @@ -1,18 +1,21 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Handles the building/saving of the configurations from the Spock config classes""" 
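The `Spockspace` wrapper introduced in the spock/backend/wrappers.py hunk above replaces the default `argparse.Namespace` repr with a YAML dump. A minimal sketch of the resulting behavior, using hypothetical attribute names (the import path assumes `spock.backend` is a package, as the other hunks suggest):

    from spock.backend.wrappers import Spockspace

    # Attribute names here are hypothetical -- any keyword args become namespace fields
    space = Spockspace(epochs=10, lr=0.01)
    print(space)
    # Prints a YAML rendering of the namespace (aliases suppressed), roughly:
    #   epochs: 10
    #   lr: 0.01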
+import argparse +import sys +import typing from pathlib import Path + import attr -from spock.backend.attr.builder import AttrBuilder -from spock.backend.attr.payload import AttrPayload -from spock.backend.attr.saver import AttrSaver -from spock.utils import check_payload_overwrite -from spock.utils import deep_payload_update -import typing + +from spock.backend.builder import AttrBuilder +from spock.backend.payload import AttrPayload +from spock.backend.saver import AttrSaver +from spock.utils import check_payload_overwrite, deep_payload_update class ConfigArgBuilder: @@ -25,27 +28,85 @@ class ConfigArgBuilder: *Attributes*: + _args: all command line args _arg_namespace: generated argument namespace _builder_obj: instance of a BaseBuilder class - _create_save_path: boolean to make the path to save to _dict_args: dictionary args from the command line _payload_obj: instance of a BasePayload class _saver_obj: instance of a BaseSaver class + _tune_payload_obj: payload for tuner related objects -- instance of TunerPayload class + _tune_obj: instance of TunerBuilder class + _tuner_interface: interface that handles the underlying library for sampling -- instance of TunerInterface + _tuner_state: current state of the hyper-parameter sampler + _tune_namespace: namespace that hold the generated tuner related parameters + _sample_count: current call to the sample function """ - def __init__(self, *args, configs: typing.Optional[typing.List] = None, create_save_path: bool = False, - desc: str = '', no_cmd_line: bool = False, s3_config=None, **kwargs): - backend = self._set_backend(args) - self._create_save_path = create_save_path - self._builder_obj = backend.get('builder')( - *args, configs=configs, create_save_path=create_save_path, desc=desc, no_cmd_line=no_cmd_line, **kwargs) - self._payload_obj = backend.get('payload')(s3_config=s3_config) - self._saver_obj = backend.get('saver')(s3_config=s3_config) + + def __init__( + self, + *args, + configs: typing.Optional[typing.List] = None, + desc: str = "", + no_cmd_line: bool = False, + s3_config=None, + **kwargs, + ): + """Init call for ConfigArgBuilder + + *Args*: + + *args: tuple of spock decorated classes to process + configs: list of config paths + desc: description for help + no_cmd_line: turn off cmd line args + s3_config: s3Config object for S3 support + **kwargs: keyword args + + """ + # Do some verification first + self._verify_attr(args) + self._configs = configs + self._no_cmd_line = no_cmd_line + self._desc = desc + # Build the payload and saver objects + self._payload_obj = AttrPayload(s3_config=s3_config) + self._saver_obj = AttrSaver(s3_config=s3_config) + # Split the fixed parameters from the tuneable ones (if present) + fixed_args, tune_args = self._strip_tune_parameters(args) + # The fixed parameter builder + self._builder_obj = AttrBuilder(*fixed_args, **kwargs) + # The possible tunable parameter builder -- might return None + self._tune_obj, self._tune_payload_obj = self._handle_tuner_objects( + tune_args, s3_config, kwargs + ) + self._tuner_interface = None + self._tuner_state = None + self._sample_count = 0 try: - self._dict_args = self._get_payload() + # Get all cmd line args and build overrides + self._args = self._handle_cmd_line() + # Get the actual payload from the config files -- fixed configs + self._dict_args = self._get_payload( + payload_obj=self._payload_obj, + input_classes=self._builder_obj.input_classes, + ignore_args=tune_args, + ) + # Build the Spockspace from the payload and the classes + # Fixed configs 
self._arg_namespace = self._builder_obj.generate(self._dict_args) + # Get the payload from the config files -- hyper-parameters -- only if the obj is not None + if self._tune_obj is not None: + self._tune_args = self._get_payload( + payload_obj=self._tune_payload_obj, + input_classes=self._tune_obj.input_classes, + ignore_args=fixed_args, + ) + # Build the Spockspace from the payload and the classes + # Tuneable parameters + self._tune_namespace = self._tune_obj.generate(self._tune_args) except Exception as e: - self._builder_obj.print_usage_and_exit(str(e), sys_exit=False) + self._print_usage_and_exit(str(e), sys_exit=False) raise ValueError(e) def __call__(self, *args, **kwargs): @@ -65,8 +126,6 @@ def __call__(self, *args, **kwargs): def generate(self): """Generate method that returns the actual argument namespace - *Args*: - *Returns*: @@ -75,79 +134,308 @@ def generate(self): """ return self._arg_namespace + def sample(self): + """Sample method that constructs a namespace from the fixed parameters and samples from the tuner space to + generate a Spockspace derived from both + + *Returns*: + + argument namespace(s) -- fixed + drawn sample from tuner backend + + """ + if self._tune_obj is None: + raise ValueError( + f"Called sample method without passing any @spockTuner decorated classes" + ) + if self._tuner_interface is None: + raise ValueError( + f"Called sample method without first calling the tuner method that initializes the " + f"backend library" + ) + return_tuple = self._tuner_state + self._tuner_state = self._tuner_interface.sample() + self._sample_count += 1 + return return_tuple + + def tuner(self, tuner_config): + """Chained call that builds the tuner interface for either optuna or ax depending upon the type of the tuner_obj + + *Args*: + + tuner_config: a class of type optuna.study.Study or AX**** + + *Returns*: + + self so that functions can be chained + + """ + if self._tune_obj is None: + raise ValueError( + f"Called tuner method without passing any @spockTuner decorated classes" + ) + try: + from spock.addons.tune.tuner import TunerInterface + + self._tuner_interface = TunerInterface( + tuner_config=tuner_config, + tuner_namespace=self._tune_namespace, + fixed_namespace=self._arg_namespace, + ) + self._tuner_state = self._tuner_interface.sample() + except ImportError: + print( + "Missing libraries to support tune functionality. 
Please re-install with the extra tune " + "dependencies -- pip install spock-config[tune]" + ) + return self + + def _print_usage_and_exit(self, msg=None, sys_exit=True, exit_code=1): + """Prints the help message and exits + + *Args*: + + msg: message to print pre exit + + *Returns*: + + None + + """ + print(f"usage: {sys.argv[0]} -c [--config] config1 [config2, config3, ...]") + print(f'\n{self._desc if self._desc != "" else ""}\n') + print("configuration(s):\n") + # Call the fixed parameter help info + self._builder_obj.handle_help_info() + if self._tune_obj is not None: + self._tune_obj.handle_help_info() + if msg is not None: + print(msg) + if sys_exit: + sys.exit(exit_code) + @staticmethod - def _set_backend(args: typing.List): - """Determines which backend class to use + def _handle_tuner_objects(tune_args, s3_config, kwargs): + """Handles creating the tuner builder object if @spockTuner classes were passed in *Args*: - args: list of classes passed to the builder + tune_args: list of tuner classes + s3_config: s3Config object for S3 support + kwargs: optional keyword args *Returns*: - backend: class of backend + tuner builder object or None + + """ + if len(tune_args) > 0: + try: + from spock.addons.tune.builder import TunerBuilder + from spock.addons.tune.payload import TunerPayload + + tuner_builder = TunerBuilder(*tune_args, **kwargs) + tuner_payload = TunerPayload(s3_config=s3_config) + return tuner_builder, tuner_payload + except ImportError: + print( + "Missing libraries to support tune functionality. Please re-install with the extra tune " + "dependencies -- pip install spock-config[tune]" + ) + else: + return None, None + + @staticmethod + def _verify_attr(args: typing.Tuple): + """Verifies that all the input classes are attr based + + *Args*: + + args: tuple of classes passed to the builder + + *Returns*: + + None """ # Gather if all attr backend type_attrs = all([attr.has(arg) for arg in args]) if not type_attrs: which_idx = [attr.has(arg) for arg in args].index(False) - if hasattr(args[which_idx], '__name__'): - raise TypeError(f"*args must be of all attrs backend -- missing a @spock decorator on class " - f"{args[which_idx].__name__}") + if hasattr(args[which_idx], "__name__"): + raise TypeError( + f"*args must be of all attrs backend -- missing a @spock decorator on class " + f"{args[which_idx].__name__}" + ) else: - raise TypeError(f"*args must be of all attrs backend -- invalid type " - f"{type(args[which_idx])}") - else: - backend = {'builder': AttrBuilder, 'payload': AttrPayload, 'saver': AttrSaver} - return backend + raise TypeError( + f"*args must be of all attrs backend -- invalid type " + f"{type(args[which_idx])}" + ) - def _get_config_paths(self): - """Get config paths from all methods + @staticmethod + def _strip_tune_parameters(args: typing.Tuple): + """Separates the fixed arguments from any hyper-parameter arguments + + *Args*: + + args: tuple of classes passed to the builder + + *Returns*: + + fixed_args: list of fixed args + tune_args: list of args destined for a tuner backend + + """ + fixed_args = [] + tune_args = [] + for arg in args: + if arg.__module__ == "spock.backend.config": + fixed_args.append(arg) + elif arg.__module__ == "spock.addons.tune.config": + tune_args.append(arg) + return fixed_args, tune_args + + def _handle_cmd_line(self): + """Handle all cmd line related tasks Config paths can enter from either the command line or be added in the class init call - as a kwarg (configs=[]) + as a kwarg (configs=[]) -- also trigger the building of the 
cmd line overrides for each fixed and + tunable objects *Returns*: args: namespace of args """ - # Call the objects get_config_paths function - args = self._builder_obj.get_config_paths() + # Need to hold an overarching parser here that just gets appended to for both fixed and tunable objects + # Check if the no_cmd_line is not flagged and if the configs are not empty + if self._no_cmd_line and (self._configs is None): + raise ValueError( + "Flag set for preventing command line read but no paths were passed to the config kwarg" + ) + # If cmd_line is flagged then build the parsers if not make any empty Namespace + args = ( + self._build_override_parsers(desc=self._desc) + if not self._no_cmd_line + else argparse.Namespace(config=[], help=False) + ) + # If configs are present from the init call then roll these into the namespace + if self._configs is not None: + args = self._get_from_kwargs(args, self._configs) + return args + + def _build_override_parsers(self, desc): + """Creates parsers for command-line overrides + + Builds the basic command line parser for configs and help then iterates through each attr instance to make + namespace specific cmd line override parsers -- handles calling both the fixed and tunable objects + + *Args*: + + desc: argparser description + + *Returns*: + + args: argument namespace + + """ + # Highest level parser object + parser = argparse.ArgumentParser(description=desc, add_help=False) + parser.add_argument("-c", "--config", required=False, nargs="+", default=[]) + parser.add_argument("-h", "--help", action="store_true") + # Handle the builder obj + parser = self._builder_obj.build_override_parsers(parser=parser) + if self._tune_obj is not None: + parser = self._tune_obj.build_override_parsers(parser=parser) + args = parser.parse_args() + return args + + @staticmethod + def _get_from_kwargs(args, configs): + """Get configs from the configs kwarg + + *Args*: + + args: argument namespace + configs: config kwarg + + *Returns*: + + args: arg namespace + + """ + if isinstance(configs, list): + args.config.extend(configs) + else: + raise TypeError( + f"configs kwarg must be of type list -- given {type(configs)}" + ) return args - def _get_payload(self): + def _get_payload(self, payload_obj, input_classes, ignore_args: typing.List): """Get the parameter payload from the config file(s) Calls the various ways to get configs and then parses to retrieve the parameter payload - make sure to call deep update so as to not lose some parameters when only partially updating the payload + *Args*: + + payload_obj: current payload object to call + input_classes: classes to use to get payload + ignore_args: args that were decorated for hyper-parameter tuning + *Returns*: payload: dictionary of parameter values """ - args = self._get_config_paths() - if args.help: + if self._args.help: # Call sys exit with a clean code as this is the help call which is not unexpected behavior - self._builder_obj.print_usage_and_exit(sys_exit=True, exit_code=0) + self._print_usage_and_exit(sys_exit=True, exit_code=0) payload = {} - dependencies = {'paths': [], 'rel_paths': [], 'roots': []} - for configs in args.config: - payload_update = self._payload_obj.payload(self._builder_obj.input_classes, configs, args, dependencies) - check_payload_overwrite(payload, payload_update, configs) - deep_payload_update(payload, payload_update) + dependencies = {"paths": [], "rel_paths": [], "roots": []} + if payload_obj is not None: + # Make sure we are actually trying to map to input classes + if 
len(input_classes) > 0: + # If configs are present then iterate through them and deal with the payload + if len(self._args.config) > 0: + for configs in self._args.config: + payload_update = payload_obj.payload( + input_classes, + ignore_args, + configs, + self._args, + dependencies, + ) + check_payload_overwrite(payload, payload_update, configs) + deep_payload_update(payload, payload_update) + # If there are no configs present we have to fall back only on cmd line args to fill out the necessary + # data -- this is essentially using spock as a drop in replacement of arg-parser + else: + payload_update = payload_obj.payload( + input_classes, ignore_args, None, self._args, dependencies + ) + check_payload_overwrite(payload, payload_update, None) + deep_payload_update(payload, payload_update) return payload - def save(self, file_name: str = None, user_specified_path: str = None, extra_info: bool = True, - file_extension: str = '.yaml'): - """Saves the current config setup to file with a UUID + def _save( + self, + payload, + file_name: str = None, + user_specified_path: str = None, + create_save_path: bool = True, + extra_info: bool = True, + file_extension: str = ".yaml", + ): + """Private interface -- saves the current config setup to file with a UUID *Args*: - file_name: name of file (will be appended with .spock.cfg.file_extension) -- falls back to uuid if None + payload: Spockspace to save + file_name: name of file (will be appended with .spock.cfg.file_extension) -- falls back to just uuid if None user_specified_path: if user provides a path it will be used as the path to write + create_save_path: bool to create the path to save if called extra_info: additional info to write to saved config (run date and git info) file_extension: file type to write (default: yaml) @@ -160,10 +448,61 @@ def save(self, file_name: str = None, user_specified_path: str = None, extra_inf elif self._builder_obj.save_path is not None: save_path = Path(self._builder_obj.save_path) else: - raise ValueError('Save did not receive a valid path from: (1) markup file(s) or (2) ' - 'the keyword arg user_specified_path') + raise ValueError( + "Save did not receive a valid path from: (1) markup file(s) or (2) " + "the keyword arg user_specified_path" + ) # Call the saver class and save function self._saver_obj.save( - self._arg_namespace, save_path, file_name, self._create_save_path, extra_info, file_extension + payload, save_path, file_name, create_save_path, extra_info, file_extension ) return self + + def save( + self, + file_name: str = None, + user_specified_path: str = None, + create_save_path: bool = True, + extra_info: bool = True, + file_extension: str = ".yaml", + add_tuner_sample: bool = False, + ): + """Saves the current config setup to file with a UUID + + *Args*: + + file_name: name of file (will be appended with .spock.cfg.file_extension) -- falls back to just uuid if None + user_specified_path: if user provides a path it will be used as the path to write + create_save_path: bool to create the path to save if called + extra_info: additional info to write to saved config (run date and git info) + file_extension: file type to write (default: yaml) + append_tuner_state: save the current tuner sample to the payload + + *Returns*: + + self so that functions can be chained + """ + if add_tuner_sample: + file_name = ( + f"hp.sample.{self._sample_count+1}" + if file_name is None + else f"{file_name}.hp.sample.{self._sample_count+1}" + ) + self._save( + self._tuner_state[0], + file_name, + user_specified_path, + 
create_save_path, + extra_info, + file_extension, + ) + else: + self._save( + self._arg_namespace, + file_name, + user_specified_path, + create_save_path, + extra_info, + file_extension, + ) + return self diff --git a/spock/config.py b/spock/config.py index 28ee0114..7426af19 100644 --- a/spock/config.py +++ b/spock/config.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Creates the spock config decorator that wraps attrs""" -from spock.backend.attr.config import spock_attr +from spock.backend.config import spock_attr from spock.utils import _is_spock_instance # Simplified decorator for attrs spock = spock_attr # Public alias for checking if an object is a @spock annotated class -isinstance_spock =_is_spock_instance +isinstance_spock = _is_spock_instance diff --git a/spock/handlers.py b/spock/handlers.py index bd4fabda..fb277151 100644 --- a/spock/handlers.py +++ b/spock/handlers.py @@ -1,22 +1,23 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """I/O handlers for various file formats""" -from abc import ABC -from abc import abstractmethod import json import os import re -from spock import __version__ -from spock.utils import check_path_s3 -import toml import typing +from abc import ABC, abstractmethod from warnings import warn + +import pytomlpp import yaml +from spock import __version__ +from spock.utils import check_path_s3 + class Handler(ABC): """Base class for file type loaders @@ -24,6 +25,7 @@ class Handler(ABC): ABC for loaders """ + def load(self, path: str, s3_config=None) -> typing.Dict: """Load function for file type @@ -57,8 +59,15 @@ def _load(self, path: str) -> typing.Dict: """ raise NotImplementedError - def save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str, name: str, - create_path: bool = False, s3_config=None): + def save( + self, + out_dict: typing.Dict, + info_dict: typing.Optional[typing.Dict], + path: str, + name: str, + create_path: bool = False, + s3_config=None, + ): """Write function for file type This will handle local or s3 writes with the boolean is_s3 flag. 
If detected it will conditionally import @@ -84,12 +93,17 @@ def save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], p if is_s3: try: from spock.addons.s3.utils import handle_s3_save_path - handle_s3_save_path(temp_path=write_path, s3_path=path, name=name, s3_config=s3_config) + + handle_s3_save_path( + temp_path=write_path, s3_path=path, name=name, s3_config=s3_config + ) except ImportError: - print('Error importing spock s3 utils after detecting s3:// save path') + print("Error importing spock s3 utils after detecting s3:// save path") @abstractmethod - def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str) -> str: + def _save( + self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str + ) -> str: """Write function for file type *Args*: @@ -124,14 +138,16 @@ def _handle_possible_s3_load_path(path: str, s3_config=None) -> str: if is_s3: try: from spock.addons.s3.utils import handle_s3_load_path + path = handle_s3_load_path(path=path, s3_config=s3_config) except ImportError: - print('Error importing spock s3 utils after detecting s3:// load path') + print("Error importing spock s3 utils after detecting s3:// load path") return path @staticmethod - def _handle_possible_s3_save_path(path: str, name: str, create_path: bool, - s3_config=None) -> typing.Tuple[str, bool]: + def _handle_possible_s3_save_path( + path: str, name: str, create_path: bool, s3_config=None + ) -> typing.Tuple[str, bool]: """Handles the possibility of having to save to a S3 path Checks to see if it detects a S3 uri and if so generates a tmp location to write the file to pre-upload @@ -149,15 +165,17 @@ def _handle_possible_s3_save_path(path: str, name: str, create_path: bool, is_s3 = check_path_s3(path=path) if is_s3: if s3_config is None: - raise ValueError('Save to S3 -- Missing S3Config object which is necessary to handle S3 style paths') - write_path = f'{s3_config.temp_folder}/{name}' + raise ValueError( + "Save to S3 -- Missing S3Config object which is necessary to handle S3 style paths" + ) + write_path = f"{s3_config.temp_folder}/{name}" # Strip double slashes if exist - write_path = write_path.replace(r'//', r'/') + write_path = write_path.replace(r"//", r"/") else: # Handle the path logic for non S3 if not os.path.exists(path) and create_path: os.makedirs(path) - write_path = f'{path}/{name}' + write_path = f"{path}/{name}" return write_path, is_s3 @staticmethod @@ -173,14 +191,14 @@ def write_extra_info(path, info_dict): """ # Write the commented info as new lines - with open(path, 'w+') as fid: + with open(path, "w+") as fid: # Write a spock header - fid.write(f'# Spock Version: {__version__}\n') + fid.write(f"# Spock Version: {__version__}\n") # Write info dict if not None if info_dict is not None: for k, v in info_dict.items(): - fid.write(f'{k}: {v}\n') - fid.write('\n') + fid.write(f"{k}: {v}\n") + fid.write("\n") class YAMLHandler(Handler): @@ -189,19 +207,23 @@ class YAMLHandler(Handler): Base YAML class """ + # override default SafeLoader behavior to correctly # interpret 1e1 (as opposed to 1.e+1) as 10 # https://stackoverflow.com/questions/30458977/yaml-loads-5e-6-as-string-and-not-a-number/30462009#30462009 yaml.SafeLoader.add_implicit_resolver( - u'tag:yaml.org,2002:float', - re.compile(u'''^(?: + "tag:yaml.org,2002:float", + re.compile( + """^(?: [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* |[-+]?\\.(?:inf|Inf|INF) - |\\.(?:nan|NaN|NAN))$''', re.X), - list(u'-+0123456789.') + |\\.(?:nan|NaN|NAN))$""", + re.X, + ), + list("-+0123456789."), ) def _load(self, path: str) -> typing.Dict: @@ -216,12 +238,14 @@ def _load(self, path: str) -> typing.Dict: base_payload: dictionary of read file """ - file_contents = open(path, 'r').read() - file_contents = re.sub(r'--([a-zA-Z0-9_]*)', r'\g<1>: True', file_contents) + file_contents = open(path, "r").read() + file_contents = re.sub(r"--([a-zA-Z0-9_]*)", r"\g<1>: True", file_contents) base_payload = yaml.safe_load(file_contents) return base_payload - def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str): + def _save( + self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str + ): """Write function for YAML type *Args*: @@ -237,7 +261,7 @@ def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], self.write_extra_info(path=path, info_dict=info_dict) # Remove aliases in YAML dump yaml.Dumper.ignore_aliases = lambda *args: True - with open(path, 'a') as yaml_fid: + with open(path, "a") as yaml_fid: yaml.safe_dump(out_dict, yaml_fid, default_flow_style=False) return path @@ -248,6 +272,7 @@ class TOMLHandler(Handler): Base TOML class """ + def _load(self, path: str) -> typing.Dict: """TOML load function @@ -260,10 +285,12 @@ def _load(self, path: str) -> typing.Dict: base_payload: dictionary of read file """ - base_payload = toml.load(path) + base_payload = pytomlpp.load(path) return base_payload - def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str): + def _save( + self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str + ): """Write function for TOML type *Args*: @@ -277,8 +304,8 @@ def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], """ # First write the commented info self.write_extra_info(path=path, info_dict=info_dict) - with open(path, 'a') as toml_fid: - toml.dump(out_dict, toml_fid) + with open(path, "a") as toml_fid: + pytomlpp.dump(out_dict, toml_fid) return path @@ -288,6 +315,7 @@ class JSONHandler(Handler): Base JSON class """ + def _load(self, path: str) -> typing.Dict: """JSON load function @@ -304,7 +332,9 @@ def _load(self, path: str) -> typing.Dict: base_payload = json.load(json_fid) return base_payload - def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str): + def _save( + self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], path: str + ): """Write function for JSON type *Args*: @@ -317,8 +347,10 @@ def _save(self, out_dict: typing.Dict, info_dict: typing.Optional[typing.Dict], """ if info_dict is not None: - warn('JSON does not support comments and thus cannot save extra info to file... removing extra info') + warn( + "JSON does not support comments and thus cannot save extra info to file... 
removing extra info" + ) info_dict = None - with open(path, 'a') as json_fid: - json.dump(out_dict, json_fid, indent=4, separators=(',', ': ')) + with open(path, "a") as json_fid: + json.dump(out_dict, json_fid, indent=4, separators=(",", ": ")) return path diff --git a/spock/utils.py b/spock/utils.py index 282430a5..588cf966 100644 --- a/spock/utils.py +++ b/spock/utils.py @@ -1,22 +1,23 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 """Utility functions for Spock""" import ast -import attr -from enum import EnumMeta import os import re import socket import subprocess import sys -from time import localtime -from time import strftime +from enum import EnumMeta +from time import localtime, strftime from warnings import warn + +import attr import git + minor = sys.version_info.minor if minor < 7: from typing import GenericMeta as _GenericAlias @@ -37,7 +38,7 @@ def check_path_s3(path: str) -> bool: """ # Make a case insensitive s3 regex with single or double forward slash (due to posix stripping) - s3_regex = re.compile(r'(?i)^s3://?').search(path) + s3_regex = re.compile(r"(?i)^s3://?").search(path) # If it returns an object then the path is an s3 style reference return s3_regex is not None @@ -57,7 +58,7 @@ def _is_spock_instance(__obj: object): bool """ - return (__obj.__module__ == 'spock.backend.attr.config') and attr.has(__obj) + return (__obj.__module__ == "spock.backend.config") and attr.has(__obj) def make_argument(arg_name, arg_type, parser): @@ -84,11 +85,11 @@ def make_argument(arg_name, arg_type, parser): elif isinstance(arg_type, EnumMeta): type_set = list({type(val.value) for val in arg_type})[0] # if this is an enum of a class switch the type to str as this is how it gets matched - type_set = str if type_set.__name__ == 'type' else type_set + type_set = str if type_set.__name__ == "type" else type_set parser.add_argument(arg_name, required=False, type=type_set) # For booleans we map to store true elif arg_type == bool: - parser.add_argument(arg_name, required=False, action='store_true') + parser.add_argument(arg_name, required=False, action="store_true") # Else we are a simple base type which we can cast to else: parser.add_argument(arg_name, required=False, type=arg_type) @@ -140,8 +141,8 @@ def make_blank_git(out_dict): out_dict: output dictionary with added git info """ - for key in ('BRANCH', 'COMMIT SHA', 'STATUS', 'ORIGIN'): - out_dict.update({f'# Git {key}': 'UNKNOWN'}) + for key in ("BRANCH", "COMMIT SHA", "STATUS", "ORIGIN"): + out_dict.update({f"# Git {key}": "UNKNOWN"}) return out_dict @@ -161,23 +162,38 @@ def add_repo_info(out_dict): repo = git.Repo(os.getcwd(), search_parent_directories=True) # Check if we are really in a detached head state as later info will fail if we are if minor < 7: - head_result = subprocess.run('git rev-parse --abbrev-ref --symbolic-full-name HEAD', stdout=subprocess.PIPE, - shell=True, check=False) + head_result = subprocess.run( + "git rev-parse --abbrev-ref --symbolic-full-name HEAD", + stdout=subprocess.PIPE, + shell=True, + check=False, + ) else: - head_result = subprocess.run('git rev-parse --abbrev-ref --symbolic-full-name HEAD', capture_output=True, - shell=True, check=False) - if head_result.stdout.decode().rstrip('\n') == 'HEAD': + head_result = subprocess.run( + "git rev-parse --abbrev-ref --symbolic-full-name HEAD", + capture_output=True, + shell=True, + check=False, + ) + if head_result.stdout.decode().rstrip("\n") == "HEAD": out_dict = make_blank_git(out_dict) 
else: - out_dict.update({'# Git Branch': repo.active_branch.name}) - out_dict.update({'# Git Commit': repo.active_branch.commit.hexsha}) - out_dict.update({'# Git Date': repo.active_branch.commit.committed_datetime}) - if len(repo.untracked_files) > 0 or len(repo.active_branch.commit.diff(None)) > 0: - git_status = 'DIRTY' + out_dict.update({"# Git Branch": repo.active_branch.name}) + out_dict.update({"# Git Commit": repo.active_branch.commit.hexsha}) + out_dict.update( + {"# Git Date": repo.active_branch.commit.committed_datetime} + ) + if ( + len(repo.untracked_files) > 0 + or len(repo.active_branch.commit.diff(None)) > 0 + ): + git_status = "DIRTY" else: - git_status = 'CLEAN' - out_dict.update({'# Git Status': git_status}) - out_dict.update({'# Git Origin': repo.active_branch.commit.repo.remotes.origin.url}) + git_status = "CLEAN" + out_dict.update({"# Git Status": git_status}) + out_dict.update( + {"# Git Origin": repo.active_branch.commit.repo.remotes.origin.url} + ) except git.InvalidGitRepositoryError: # pragma: no cover # But it's okay if we are not out_dict = make_blank_git(out_dict) @@ -195,16 +211,20 @@ def add_generic_info(out_dict): out_dict: output dictionary """ - out_dict.update({'# Machine FQDN': socket.getfqdn()}) - out_dict.update({'# Python Executable': sys.executable}) - out_dict.update({'# Python Version': f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}'}) - out_dict.update({'# Python Script': os.path.realpath(sys.argv[0])}) - out_dict.update({'# Run Date': strftime('%Y-%m-%d', localtime())}) - out_dict.update({'# Run Time': strftime('%H:%M:%S', localtime())}) + out_dict.update({"# Machine FQDN": socket.getfqdn()}) + out_dict.update({"# Python Executable": sys.executable}) + out_dict.update( + { + "# Python Version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + } + ) + out_dict.update({"# Python Script": os.path.realpath(sys.argv[0])}) + out_dict.update({"# Run Date": strftime("%Y-%m-%d", localtime())}) + out_dict.update({"# Run Time": strftime("%H:%M:%S", localtime())}) # Make a best effort to determine if run in a container - out_dict.update({'# Run w/ Docker': str(_maybe_docker())}) + out_dict.update({"# Run w/ Docker": str(_maybe_docker())}) # Make a best effort to determine if run in a container via k8s - out_dict.update({'# Run w/ Kubernetes': str(_maybe_k8s())}) + out_dict.update({"# Run w/ Kubernetes": str(_maybe_k8s())}) return out_dict @@ -223,10 +243,12 @@ def _maybe_docker(cgroup_path="/proc/self/cgroup"): """ # A few options seem to be at play here: # 1. Check for /.dockerenv -- docker should create this is any container - bool_env = os.path.exists('/.dockerenv') + bool_env = os.path.exists("/.dockerenv") # 2. Check /proc/self/cgroup for "docker" # https://stackoverflow.com/a/48710609 - bool_cgroup = os.path.isfile(cgroup_path) and any("docker" in line for line in open(cgroup_path)) + bool_cgroup = os.path.isfile(cgroup_path) and any( + "docker" in line for line in open(cgroup_path) + ) return bool_env or bool_cgroup @@ -247,7 +269,9 @@ def _maybe_k8s(cgroup_path="/proc/self/cgroup"): bool_env = os.environ.get("KUBERNETES_SERVICE_HOST") is not None # 2. 
Similar to docker check /proc/self/cgroup for "kubepods" # https://stackoverflow.com/a/48710609 - bool_cgroup = os.path.isfile(cgroup_path) and any("kubepods" in line for line in open(cgroup_path)) + bool_cgroup = os.path.isfile(cgroup_path) and any( + "kubepods" in line for line in open(cgroup_path) + ) return bool_env or bool_cgroup @@ -279,7 +303,7 @@ def deep_payload_update(source, updates): return source -def check_payload_overwrite(payload, updates, configs, overwrite=''): +def check_payload_overwrite(payload, updates, configs, overwrite=""): """Warns when parameters are overwritten across payloads as order will matter *Args*: @@ -294,11 +318,13 @@ def check_payload_overwrite(payload, updates, configs, overwrite=''): """ for k, v in updates.items(): if isinstance(v, dict) and v: - overwrite += (k + ":") + overwrite += k + ":" current_payload = {} if payload.get(k) is None else payload.get(k) check_payload_overwrite(current_payload, v, configs, overwrite=overwrite) else: if k in payload: - warn(f'Overriding an already set parameter {overwrite + k} from {configs}\n' - f'Be aware that value precedence is set by the order of the config files (last to load)...', - SyntaxWarning) + warn( + f"Overriding an already set parameter {overwrite + k} from {configs}\n" + f"Be aware that value precedence is set by the order of the config files (last to load)...", + SyntaxWarning, + ) diff --git a/tests/base/attr_configs_test.py b/tests/base/attr_configs_test.py index 82db5889..3d32ca86 100644 --- a/tests/base/attr_configs_test.py +++ b/tests/base/attr_configs_test.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 from enum import Enum @@ -88,6 +88,8 @@ class TypeConfig: tuple_p_str: Tuple[str, str] # Required Tuple -- Bool tuple_p_bool: Tuple[bool, bool] + # Required Tuple -- mixed + tuple_p_mixed: Tuple[int, float] # Required choice -- Str choice_p_str: StrChoice # Required choice -- Int diff --git a/tests/base/base_asserts_test.py b/tests/base/base_asserts_test.py index eb092e8a..8067ba41 100644 --- a/tests/base/base_asserts_test.py +++ b/tests/base/base_asserts_test.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2019 FMR LLC +# Copyright FMR LLC # SPDX-License-Identifier: Apache-2.0 @@ -21,6 +21,7 @@ def test_all_set(self, arg_builder): assert arg_builder.TypeConfig.tuple_p_int == (10, 20) assert arg_builder.TypeConfig.tuple_p_str == ('Spock', 'Package') assert arg_builder.TypeConfig.tuple_p_bool == (True, False) + assert arg_builder.TypeConfig.tuple_p_mixed == (5, 11.5) assert arg_builder.TypeConfig.choice_p_str == 'option_1' assert arg_builder.TypeConfig.choice_p_int == 10 assert arg_builder.TypeConfig.choice_p_float == 10.0 diff --git a/tests/base/test_cmd_line.py b/tests/base/test_cmd_line.py index 699cd02b..9a3724f5 100644 --- a/tests/base/test_cmd_line.py +++ b/tests/base/test_cmd_line.py @@ -21,6 +21,7 @@ def arg_builder(monkeypatch): '--TypeConfig.tuple_p_float', '(11.0, 21.0)', '--TypeConfig.tuple_p_int', '(11, 21)', '--TypeConfig.tuple_p_str', "('Hooray', 'Working')", '--TypeConfig.tuple_p_bool', '(False, True)', + '--TypeConfig.tuple_p_mixed', '(5, 11.5)', '--TypeConfig.list_list_p_int', "[[11, 21], [11, 21]]", '--TypeConfig.choice_p_str', 'option_2', '--TypeConfig.choice_p_int', '20', '--TypeConfig.choice_p_float', '20.0', @@ -28,6 +29,7 @@ def arg_builder(monkeypatch): '--TypeConfig.list_list_choice_p_str', "[['option_2'], ['option_2']]", '--TypeConfig.list_choice_p_int', '[20]', 
'--TypeConfig.list_choice_p_float', '[20.0]', + '--TypeConfig.class_enum', 'NestedStuff', '--NestedStuff.one', '12', '--NestedStuff.two', 'ancora', '--TypeConfig.nested_list.NestedListStuff.one', '[11, 21]', '--TypeConfig.nested_list.NestedListStuff.two', "['Hooray', 'Working']", @@ -48,6 +50,69 @@ def test_class_overrides(self, arg_builder): assert arg_builder.TypeConfig.tuple_p_int == (11, 21) assert arg_builder.TypeConfig.tuple_p_str == ('Hooray', 'Working') assert arg_builder.TypeConfig.tuple_p_bool == (False, True) + assert arg_builder.TypeConfig.tuple_p_mixed == (5, 11.5) + assert arg_builder.TypeConfig.choice_p_str == 'option_2' + assert arg_builder.TypeConfig.choice_p_int == 20 + assert arg_builder.TypeConfig.choice_p_float == 20.0 + assert arg_builder.TypeConfig.list_list_p_int == [[11, 21], [11, 21]] + assert arg_builder.TypeConfig.list_choice_p_str == ['option_2'] + assert arg_builder.TypeConfig.list_list_choice_p_str == [['option_2'], ['option_2']] + assert arg_builder.TypeConfig.list_choice_p_int == [20] + assert arg_builder.TypeConfig.list_choice_p_float == [20.0] + assert arg_builder.TypeConfig.class_enum.one == 12 + assert arg_builder.TypeConfig.class_enum.two == 'ancora' + assert arg_builder.NestedListStuff[0].one == 11 + assert arg_builder.NestedListStuff[0].two == 'Hooray' + assert arg_builder.NestedListStuff[1].one == 21 + assert arg_builder.NestedListStuff[1].two == 'Working' + + +class TestClassOnlyCmdLine: + """Testing command line overrides""" + @staticmethod + @pytest.fixture + def arg_builder(monkeypatch): + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['', + '--TypeConfig.bool_p', '--TypeConfig.int_p', '11', '--TypeConfig.float_p', '11.0', + '--TypeConfig.string_p', 'Hooray', + '--TypeConfig.list_p_float', '[11.0, 21.0]', '--TypeConfig.list_p_int', '[11, 21]', + '--TypeConfig.list_p_str', "['Hooray', 'Working']", + '--TypeConfig.list_p_bool', '[False, True]', + '--TypeConfig.tuple_p_float', '(11.0, 21.0)', '--TypeConfig.tuple_p_int', '(11, 21)', + '--TypeConfig.tuple_p_str', "('Hooray', 'Working')", + '--TypeConfig.tuple_p_bool', '(False, True)', + '--TypeConfig.tuple_p_mixed', '(5, 11.5)', + '--TypeConfig.list_list_p_int', "[[11, 21], [11, 21]]", + '--TypeConfig.choice_p_str', 'option_2', + '--TypeConfig.choice_p_int', '20', '--TypeConfig.choice_p_float', '20.0', + '--TypeConfig.list_choice_p_str', "['option_2']", + '--TypeConfig.list_list_choice_p_str', "[['option_2'], ['option_2']]", + '--TypeConfig.list_choice_p_int', '[20]', + '--TypeConfig.list_choice_p_float', '[20.0]', + '--TypeConfig.class_enum', 'NestedStuff', + '--TypeConfig.nested', 'NestedStuff', + '--NestedStuff.one', '12', '--NestedStuff.two', 'ancora', + '--TypeConfig.nested_list.NestedListStuff.one', '[11, 21]', + '--TypeConfig.nested_list.NestedListStuff.two', "['Hooray', 'Working']", + ]) + config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, desc='Test Builder') + return config.generate() + + def test_class_overrides(self, arg_builder): + assert arg_builder.TypeConfig.bool_p is True + assert arg_builder.TypeConfig.int_p == 11 + assert arg_builder.TypeConfig.float_p == 11.0 + assert arg_builder.TypeConfig.string_p == 'Hooray' + assert arg_builder.TypeConfig.list_p_float == [11.0, 21.0] + assert arg_builder.TypeConfig.list_p_int == [11, 21] + assert arg_builder.TypeConfig.list_p_str == ['Hooray', 'Working'] + assert arg_builder.TypeConfig.list_p_bool == [False, True] + assert arg_builder.TypeConfig.tuple_p_float == (11.0, 21.0) + assert 
arg_builder.TypeConfig.tuple_p_int == (11, 21) + assert arg_builder.TypeConfig.tuple_p_str == ('Hooray', 'Working') + assert arg_builder.TypeConfig.tuple_p_bool == (False, True) + assert arg_builder.TypeConfig.tuple_p_mixed == (5, 11.5) assert arg_builder.TypeConfig.choice_p_str == 'option_2' assert arg_builder.TypeConfig.choice_p_int == 20 assert arg_builder.TypeConfig.choice_p_float == 20.0 diff --git a/tests/base/test_type_specific.py b/tests/base/test_type_specific.py index e353e174..b3ecea04 100644 --- a/tests/base/test_type_specific.py +++ b/tests/base/test_type_specific.py @@ -62,11 +62,11 @@ def test_enum_class_missing(self, monkeypatch): ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, desc='Test Builder') -class TestMixedGeneric: - def test_mixed_generic(self, monkeypatch): - with monkeypatch.context() as m: - with pytest.raises(TypeError): - @spock - class GenericFail: - generic_fail: Tuple[List[int], List[int], int] +# class TestMixedGeneric: +# def test_mixed_generic(self, monkeypatch): +# with monkeypatch.context() as m: +# with pytest.raises(TypeError): +# @spock +# class GenericFail: +# generic_fail: Tuple[List[int], List[int], int] diff --git a/tests/base/test_writers.py b/tests/base/test_writers.py index 33a1e99d..d65111a6 100644 --- a/tests/base/test_writers.py +++ b/tests/base/test_writers.py @@ -43,10 +43,9 @@ def test_yaml_file_writer_create(self, monkeypatch, tmp_path): with monkeypatch.context() as m: m.setattr(sys, 'argv', ['', '--config', './tests/conf/yaml/test.yaml']) - config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, TypeOptConfig, desc='Test Builder', - create_save_path=True) + config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, TypeOptConfig, desc='Test Builder') # Test the chained version - config.save(user_specified_path=f'{tmp_path}/tmp', file_extension='.yaml').generate() + config.save(user_specified_path=f'{tmp_path}/tmp', create_save_path=True, file_extension='.yaml').generate() check_path = f'{str(tmp_path)}/tmp/*.yaml' fname = glob.glob(check_path)[0] with open(fname, 'r') as fin: @@ -101,7 +100,10 @@ def test_yaml_file_writer(self, monkeypatch, tmp_path): config = ConfigArgBuilder(TypeConfig, NestedStuff, NestedListStuff, TypeOptConfig, desc='Test Builder') # Test the chained version with pytest.raises(FileNotFoundError): - config.save(user_specified_path=f'{str(tmp_path)}/foo.bar/fizz.buzz/', file_extension='.yaml').generate() + config.save( + user_specified_path=f'{str(tmp_path)}/foo.bar/fizz.buzz/', file_extension='.yaml', + create_save_path=False + ).generate() class TestInvalidExtensionTypeRaise: diff --git a/tests/conf/json/test.json b/tests/conf/json/test.json index 8ac6ff3e..e81dff46 100644 --- a/tests/conf/json/test.json +++ b/tests/conf/json/test.json @@ -12,6 +12,7 @@ "tuple_p_int": [10, 20], "tuple_p_str": ["Spock", "Package"], "tuple_p_bool": [true, false], + "tuple_p_mixed": [5, 11.5], "choice_p_str": "option_1", "choice_p_int": 10, "choice_p_float": 10.0, diff --git a/tests/conf/toml/test.toml b/tests/conf/toml/test.toml index 85bcdad0..6e9e0a77 100644 --- a/tests/conf/toml/test.toml +++ b/tests/conf/toml/test.toml @@ -25,6 +25,8 @@ tuple_p_int = [10, 20] tuple_p_str = ["Spock", "Package"] # Required Tuple -- Bool tuple_p_bool = [true, false] +# Required Tuple -- mixed +tuple_p_mixed = [5, 11.5] # Required Choice -- Str type choice_p_str = 'option_1' # Required Choice -- Int diff --git a/tests/conf/yaml/inherited.yaml b/tests/conf/yaml/inherited.yaml index 4d03e7ce..de292a22 100644 --- 
a/tests/conf/yaml/inherited.yaml +++ b/tests/conf/yaml/inherited.yaml @@ -24,6 +24,8 @@ tuple_p_float: [10.0, 20.0] tuple_p_int: [10, 20] # Required Tuple -- Str tuple_p_str: [Spock, Package] +# Required Tuple -- mixed +tuple_p_mixed: [5, 11.5] # Required Tuple -- Bool tuple_p_bool: [True, False] # Required Choice -- Str diff --git a/tests/conf/yaml/test.yaml b/tests/conf/yaml/test.yaml index b344ad5c..983a0b00 100644 --- a/tests/conf/yaml/test.yaml +++ b/tests/conf/yaml/test.yaml @@ -26,6 +26,8 @@ tuple_p_int: [10, 20] tuple_p_str: [Spock, Package] # Required Tuple -- Bool tuple_p_bool: [True, False] +# Required Tuple -- mixed +tuple_p_mixed: [5, 11.5] # Required Choice -- Str choice_p_str: option_1 # Required Choice -- Int diff --git a/tests/conf/yaml/test_class.yaml b/tests/conf/yaml/test_class.yaml index 88867675..93eb81e7 100644 --- a/tests/conf/yaml/test_class.yaml +++ b/tests/conf/yaml/test_class.yaml @@ -2,7 +2,7 @@ ### Required or Boolean Base Types ### TypeConfig: # Boolean - Set - bool_p_set: true + bool_p: true # Required Int int_p: 10 # Required Float @@ -46,7 +46,7 @@ TypeConfig: # Nested List configuration nested_list: NestedListStuff # Class Enum - class_enum: NestedStuff + class_enum: NestedListStuff NestedListStuff: - one: 10 two: hello diff --git a/tests/s3/test_io.py b/tests/s3/test_io.py index 0f3b0262..089b0de4 100644 --- a/tests/s3/test_io.py +++ b/tests/s3/test_io.py @@ -2,7 +2,7 @@ import datetime from tests.base.base_asserts_test import * from spock.builder import ConfigArgBuilder -from spock.addons import S3Config +from spock.addons.s3 import S3Config from tests.base.attr_configs_test import * from tests.s3.fixtures_test import * import re diff --git a/tests/s3/test_raises.py b/tests/s3/test_raises.py index 6415c05d..bf636678 100644 --- a/tests/s3/test_raises.py +++ b/tests/s3/test_raises.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import datetime from spock.builder import ConfigArgBuilder -from spock.addons import S3Config +from spock.addons.s3 import S3Config from tests.base.attr_configs_test import * from tests.s3.fixtures_test import * import sys
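For reference, the tests/s3 hunks above move `S3Config` from `spock.addons` to the `spock.addons.s3` subpackage. A minimal sketch of the updated import pattern follows; the `S3Config` constructor arguments are not part of this diff, so the instance is taken as given, and the helper name is hypothetical:

    # Updated import path after this refactor (previously: from spock.addons import S3Config)
    from spock.addons.s3 import S3Config
    from spock.builder import ConfigArgBuilder


    def build_with_s3(s3_config: S3Config, *spock_classes):
        # Any @spock-decorated classes are passed positionally; the s3_config kwarg matches
        # the ConfigArgBuilder signature shown in the spock/builder.py hunk above.
        return ConfigArgBuilder(
            *spock_classes, desc="S3-backed run", s3_config=s3_config
        ).generate()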