diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d8610dbd..ef29127b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,26 +1,67 @@ name: Python package -on: [push, pull_request] +on: + push: + branches: [master] + pull_request: jobs: - build: + linting: + name: Run linting/pre-commit checks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: '3.11' + cache: 'pip' + - run: pip install pre-commit + - run: pre-commit --version + - run: pre-commit install + - run: pre-commit run --all-files + build: + needs: [linting] runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + - name: Install poetry + run: | + python -m pip install --upgrade pip + pip install poetry - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - cache: "pip" - - name: Install dependencies + cache: poetry + + - name: Install base dependencies + run: poetry install + + - name: Unit tests with Pytest (no extras) + timeout-minutes: 3 run: | - python -m pip install --upgrade pip - pip install -e .[all] - - name: Unit tests with Pytest + poetry run pytest --benchmark-disable --cov=simple_parsing --cov-report=xml --cov-append + + + - name: Install extra dependencies + run: poetry install --all-extras + + - name: Unit tests with Pytest (with extra dependencies) + timeout-minutes: 3 run: | - pytest --benchmark-disable + poetry run pytest --benchmark-disable --cov=simple_parsing --cov-report=xml --cov-append + + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + env_vars: PLATFORM,PYTHON + name: codecov-umbrella + 
fail_ci_if_error: false diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 63a6fa57..8d6276ed 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -5,28 +5,36 @@ name: Upload Python Package on: release: - types: [created] + types: [published] workflow_dispatch: {} jobs: - deploy: + publish: + strategy: + matrix: + python-version: [3.9] + os: [ubuntu-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - name: Install poetry + run: | + python -m pip install --upgrade pip + pip install poetry - runs-on: ubuntu-latest + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + cache: poetry - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install setuptools wheel twine - - name: Build and publish - env: - TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} - TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} - run: | - python setup.py sdist bdist_wheel - twine upload dist/* + - name: Install dependencies + run: | + poetry install + poetry self add "poetry-dynamic-versioning[plugin]" + poetry dynamic-versioning enable + + - name: Publish package + env: + POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }} + run: poetry publish --build diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2852dae..92165901 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,18 +28,15 @@ repos: - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: 'v0.0.261' + rev: 'v0.1.14' hooks: + # Run the linter. 
- id: ruff - args: ['--line-length', '99', '--fix'] + args: ['--line-length', '99', "--select", "I,UP", '--fix'] require_serial: true - - # python code formatting - - repo: https://github.com/psf/black - rev: 22.12.0 - hooks: - - id: black - args: [--line-length, "99"] + # Run the formatter. + - id: ruff-format + args: ['--line-length', '99'] require_serial: true # python docstring formatting @@ -47,6 +44,7 @@ repos: rev: v1.5.1 hooks: - id: docformatter + exclude: ^test/test_docstrings.py args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] require_serial: true @@ -72,6 +70,7 @@ repos: hooks: - id: mdformat args: ["--number"] + exclude: ^.github/ISSUE_TEMPLATE/.*\.md$ additional_dependencies: - mdformat-gfm - mdformat-tables @@ -80,6 +79,11 @@ repos: # - mdformat-black require_serial: true + - repo: https://github.com/python-poetry/poetry + rev: 1.7.0 + hooks: + - id: poetry-check + require_serial: true # word spelling linter - repo: https://github.com/codespell-project/codespell diff --git a/README.md b/README.md index fdd349b1..c4d86590 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ ![Build Status](https://github.com/lebrice/SimpleParsing/actions/workflows/build.yml/badge.svg) [![PyPI version](https://badge.fury.io/py/simple-parsing.svg)](https://badge.fury.io/py/simple-parsing) - # Simple, Elegant, Typed Argument Parsing `simple-parsing` allows you to transform your ugly `argparse` scripts into beautifully structured, strongly typed little works of art. 
This isn't a fancy, complicated new command-line tool either, ***this simply adds new features to plain-old argparse!*** @@ -28,11 +27,13 @@ args = parser.parse_args() print("foo:", args.foo) print("options:", args.options) ``` + ```console $ python examples/demo.py --log_dir logs --foo 123 foo: 123 options: Options(log_dir='logs', learning_rate=0.0001) ``` + ```console $ python examples/demo.py --help usage: demo.py [-h] [--foo int] --log_dir str [--learning_rate float] @@ -60,7 +61,6 @@ options: Options = simple_parsing.parse(Options) options, leftover_args = simple_parsing.parse_known_args(Options) ``` - ## installation `pip install simple-parsing` @@ -70,66 +70,69 @@ options, leftover_args = simple_parsing.parse_known_args(Options) ## [API Documentation](https://github.com/lebrice/SimpleParsing/tree/master/docs/README.md) (Under construction) ## Features + - ### [Automatic "--help" strings](https://github.com/lebrice/SimpleParsing/tree/master/examples/docstrings/README.md) - As developers, we want to make it easy for people coming into our projects to understand how to run them. However, a user-friendly `--help` message is often hard to write and to maintain, especially as the number of arguments increases. + As developers, we want to make it easy for people coming into our projects to understand how to run them. However, a user-friendly `--help` message is often hard to write and to maintain, especially as the number of arguments increases. - With `simple-parsing`, your arguments and their descriptions are defined in the same place, making your code easier to read, write, and maintain. + With `simple-parsing`, your arguments and their descriptions are defined in the same place, making your code easier to read, write, and maintain. 
- ### Modular, Reusable, Cleanly Grouped Arguments - *(no more copy-pasting)* - - When you need to add a new group of command-line arguments similar to an existing one, instead of copy-pasting a block of `argparse` code and renaming variables, you can reuse your argument class, and let the `ArgumentParser` take care of adding relevant prefixes to the arguments for you: - - ```python - parser.add_arguments(Options, dest="train") - parser.add_arguments(Options, dest="valid") - args = parser.parse_args() - train_options: Options = args.train - valid_options: Options = args.valid - print(train_options) - print(valid_options) - ``` - ```console - $ python examples/demo.py \ - --train.log_dir "training" \ - --valid.log_dir "validation" - Options(log_dir='training', learning_rate=0.0001) - Options(log_dir='validation', learning_rate=0.0001) - ``` - - These prefixes can also be set explicitly, or not be used at all. For more info, take a look at the [Prefixing Guide](https://github.com/lebrice/SimpleParsing/tree/master/examples/prefixing/README.md) + *(no more copy-pasting)* + + When you need to add a new group of command-line arguments similar to an existing one, instead of copy-pasting a block of `argparse` code and renaming variables, you can reuse your argument class, and let the `ArgumentParser` take care of adding relevant prefixes to the arguments for you: + + ```python + parser.add_arguments(Options, dest="train") + parser.add_arguments(Options, dest="valid") + args = parser.parse_args() + train_options: Options = args.train + valid_options: Options = args.valid + print(train_options) + print(valid_options) + ``` + + ```console + $ python examples/demo.py \ + --train.log_dir "training" \ + --valid.log_dir "validation" + Options(log_dir='training', learning_rate=0.0001) + Options(log_dir='validation', learning_rate=0.0001) + ``` + + These prefixes can also be set explicitly, or not be used at all. 
For more info, take a look at the [Prefixing Guide](https://github.com/lebrice/SimpleParsing/tree/master/examples/prefixing/README.md) - ### [Argument subgroups](https://github.com/lebrice/SimpleParsing/tree/master/examples/subgroups/README.md) - It's easy to choose between different argument groups of arguments, with the `subgroups` - function! + It's easy to choose between different argument groups of arguments, with the `subgroups` + function! - ### [Setting defaults from Configuration files](https://github.com/lebrice/SimpleParsing/tree/master/examples/config_files/README.md) - Default values for command-line arguments can easily be read from many different formats, including json/yaml! + Default values for command-line arguments can easily be read from many different formats, including json/yaml! - ### [**Easy serialization**](https://github.com/lebrice/SimpleParsing/tree/master/examples/serialization/README.md): - Easily save/load configs to `json` or `yaml`!. + Easily save/load configs to `json` or `yaml`!. - ### [**Inheritance**!](https://github.com/lebrice/SimpleParsing/tree/master/examples/inheritance/README.md) - You can easily customize an existing argument class by extending it and adding your own attributes, which helps promote code reuse across projects. For more info, take a look at the [inheritance example](https://github.com/lebrice/SimpleParsing/tree/master/examples/inheritance/inheritance_example.py) + You can easily customize an existing argument class by extending it and adding your own attributes, which helps promote code reuse across projects. For more info, take a look at the [inheritance example](https://github.com/lebrice/SimpleParsing/tree/master/examples/inheritance/inheritance_example.py) - ### [**Nesting**!](https://github.com/lebrice/SimpleParsing/tree/master/examples/nesting/README.md): - Dataclasses can be nested within dataclasses, as deep as you need! + Dataclasses can be nested within dataclasses, as deep as you need! 
- ### [Easier parsing of lists and tuples](https://github.com/lebrice/SimpleParsing/tree/master/examples/container_types/README.md) : - This is sometimes tricky to do with regular `argparse`, but `simple-parsing` makes it a lot easier by using the python's builtin type annotations to automatically convert the values to the right type for you. - As an added feature, by using these type annotations, `simple-parsing` allows you to parse nested lists or tuples, as can be seen in [this example](https://github.com/lebrice/SimpleParsing/tree/master/examples/merging/README.md) + + This is sometimes tricky to do with regular `argparse`, but `simple-parsing` makes it a lot easier by using the python's builtin type annotations to automatically convert the values to the right type for you. + As an added feature, by using these type annotations, `simple-parsing` allows you to parse nested lists or tuples, as can be seen in [this example](https://github.com/lebrice/SimpleParsing/tree/master/examples/merging/README.md) - ### [Enums support](https://github.com/lebrice/SimpleParsing/tree/master/examples/enums/README.md) - (More to come!) 
- ## Examples: + Additional examples for all the features mentioned above can be found in the [examples folder](https://github.com/lebrice/SimpleParsing/tree/master/examples/README.md) diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..e5ee46dc --- /dev/null +++ b/codecov.yml @@ -0,0 +1,14 @@ +coverage: + status: + project: + default: + target: auto # auto compares coverage to the previous base commit + informational: true + patch: + default: + target: 100% + informational: true + + +# When modifying this file, please validate using +# curl -X POST --data-binary @codecov.yml https://codecov.io/validate diff --git a/docs/Roadmap.md b/docs/Roadmap.md index 101c67ee..89b33182 100644 --- a/docs/Roadmap.md +++ b/docs/Roadmap.md @@ -1,11 +1,13 @@ ## Currently supported features: -* Parsing of attributes of built-in types: - * `int`, `float`, `str` attributes - * `bool` attributes (using either the `--` or the `-- ` syntax) - * `list` attributes - * `tuple` attributes -* Parsing of multiple instances of a given dataclass, for the above-mentioned attribute types -* Nested parsing of instances (dataclasses within dataclasses) + +- Parsing of attributes of built-in types: + - `int`, `float`, `str` attributes + - `bool` attributes (using either the `--` or the `-- ` syntax) + - `list` attributes + - `tuple` attributes +- Parsing of multiple instances of a given dataclass, for the above-mentioned attribute types +- Nested parsing of instances (dataclasses within dataclasses) ## Possible Future Enhancements: -* Parsing two different dataclasses which share a base class (this currently would cause a conflict for the base class arguments. + +- Parsing two different dataclasses which share a base class (this currently would cause a conflict for the base class arguments. 
diff --git a/examples/ML/README.md b/examples/ML/README.md index 3b958a30..f9baa69a 100644 --- a/examples/ML/README.md +++ b/examples/ML/README.md @@ -1,6 +1,9 @@ ## Use-Case Example: ML Scripts + Let's look at a great use-case for `simple-parsing`: ugly ML code: + ### Before: + ```python import argparse @@ -39,7 +42,9 @@ class MyModel(): m = MyModel(data_dir, log_dir, checkpoint_dir, learning_rate, momentum) # Ok, what if we wanted to add a new hyperparameter?! ``` + ### After: + ```python from dataclasses import dataclass from simple_parsing import ArgumentParser diff --git a/examples/ML/ml_example_after.py b/examples/ML/ml_example_after.py index 6783542a..c69ad76c 100644 --- a/examples/ML/ml_example_after.py +++ b/examples/ML/ml_example_after.py @@ -8,7 +8,7 @@ @dataclass class MyModelHyperParameters: - """Hyperparameters of MyModel""" + """Hyperparameters of MyModel.""" # Learning rate of the Adam optimizer. learning_rate: float = 0.05 @@ -18,7 +18,7 @@ class MyModelHyperParameters: @dataclass class TrainingConfig: - """Training configuration settings""" + """Training configuration settings.""" data_dir: str = "/data" log_dir: str = "/logs" diff --git a/examples/ML/other_ml_example.py b/examples/ML/other_ml_example.py index 7a464968..089cadef 100644 --- a/examples/ML/other_ml_example.py +++ b/examples/ML/other_ml_example.py @@ -8,7 +8,7 @@ @dataclass class MyModelHyperParameters: - """Hyperparameters of MyModel""" + """Hyperparameters of MyModel.""" # Batch size (per-GPU) batch_size: int = 32 @@ -20,7 +20,7 @@ class MyModelHyperParameters: @dataclass class TrainingConfig: - """Settings related to Training""" + """Settings related to Training.""" data_dir: str = "data" log_dir: str = "logs" @@ -29,7 +29,7 @@ class TrainingConfig: @dataclass class EvalConfig: - """Settings related to evaluation""" + """Settings related to evaluation.""" eval_dir: str = "eval_data" diff --git a/examples/README.md b/examples/README.md index bff33f25..f957fc5e 100644 --- 
a/examples/README.md +++ b/examples/README.md @@ -1,4 +1,5 @@ # Examples + - [dataclasses intro](dataclasses/README.md): Quick intro to Python's new [dataclasses](https://docs.python.org/3.7/library/dataclasses.html) module. - **[Simple example](simple/basic.py)**: Simple use-case example with a before/after comparison. diff --git a/examples/aliases/README.md b/examples/aliases/README.md index 04eeb573..7412c4a7 100644 --- a/examples/aliases/README.md +++ b/examples/aliases/README.md @@ -1,24 +1,23 @@ # Using Aliases - ## Notes about `option_strings`: + Additional names for the same argument can be added via the `alias` argument -of the `field` function (see [the custom_args Example]( - /examples/custom_args/README.md) for more info). +of the `field` function (see [the custom_args Example](/examples/custom_args/README.md) for more info). The `simple_parsing.ArgumentParser` accepts an argument (currently called `add_option_string_dash_variants`, which defaults to False) which adds additional variants to allow using either dashes or underscores to refer to an argument: -- Whenever the name of an attribute includes underscores ("_"), the same -argument can be passed by using dashes ("-") instead. This also includes -aliases. -- If an alias contained leading dashes, either single or double, the -same number of dashes will be used, even in the case where a prefix is -added. -For instance, consider the following example. -Here we have two prefixes: `"train"` and `"valid"`. -The corresponding option_strings for each argument will be -`["--train.debug", "-train.d"]` and `["--valid.debug", "-valid.d"]`, -respectively, as shown here: +- Whenever the name of an attribute includes underscores ("\_"), the same + argument can be passed by using dashes ("-") instead. This also includes + aliases. +- If an alias contained leading dashes, either single or double, the + same number of dashes will be used, even in the case where a prefix is + added. 
+ For instance, consider the following example. + Here we have two prefixes: `"train"` and `"valid"`. + The corresponding option_strings for each argument will be + `["--train.debug", "-train.d"]` and `["--valid.debug", "-valid.d"]`, + respectively, as shown here: ```python from dataclasses import dataclass diff --git a/examples/config_files/README.md b/examples/config_files/README.md index b37b87ec..2464bb90 100644 --- a/examples/config_files/README.md +++ b/examples/config_files/README.md @@ -14,11 +14,11 @@ When using both options (the `config_path` parameter of `ArgumentParser.__init__ updated with the contents of the `--config_path` file(s). In other words, the default values are set like so, in increasing priority: + 1. normal defaults (e.g. from the dataclass definitions) 2. updated with the contents of the `config_path` file(s) of `ArgumentParser.__init__` 3. updated with the contents of the `--config_path` file(s) from the command-line. - ## [Single Config example](one_config.py) When using a single config dataclass, the `simple_parsing.parse` function can then be used to simplify the argument parsing setup a bit. diff --git a/examples/config_files/composition.py b/examples/config_files/composition.py index 9ed16ff2..1e7e22aa 100644 --- a/examples/config_files/composition.py +++ b/examples/config_files/composition.py @@ -1,4 +1,4 @@ -""" Example where we compose different configurations! 
""" +"""Example where we compose different configurations!""" import shlex from dataclasses import dataclass diff --git a/examples/config_files/many_configs.py b/examples/config_files/many_configs.py index 50e6147d..41807306 100644 --- a/examples/config_files/many_configs.py +++ b/examples/config_files/many_configs.py @@ -6,7 +6,7 @@ @dataclass class TrainConfig: - """Training config for Machine Learning""" + """Training config for Machine Learning.""" workers: int = 8 # The number of workers for training exp_name: str = "default_exp" # The experiment name diff --git a/examples/config_files/one_config.py b/examples/config_files/one_config.py index ec30a1d9..a24a7f4b 100644 --- a/examples/config_files/one_config.py +++ b/examples/config_files/one_config.py @@ -1,4 +1,4 @@ -""" Example adapted from https://github.com/eladrich/pyrallis#my-first-pyrallis-example- """ +"""Example adapted from https://github.com/eladrich/pyrallis#my-first-pyrallis-example-""" from dataclasses import dataclass import simple_parsing diff --git a/examples/container_types/lists_example.py b/examples/container_types/lists_example.py index bda1d303..c28d0a39 100644 --- a/examples/container_types/lists_example.py +++ b/examples/container_types/lists_example.py @@ -10,9 +10,10 @@ class Example: some_integers: List[int] = field( default_factory=list ) # This is a list of integers (empty by default) - """This list is empty, by default. when passed some parameters, they are - automatically converted to integers, since we annotated the attribute with - a type (typing.List[]). + """This list is empty, by default. + + when passed some parameters, they are automatically converted to integers, since we annotated + the attribute with a type (typing.List[]). 
""" # When using a list attribute, the dataclasses module requires us to use `dataclass.field()`, diff --git a/examples/custom_args/README.md b/examples/custom_args/README.md index c6493fd9..2c8d5398 100644 --- a/examples/custom_args/README.md +++ b/examples/custom_args/README.md @@ -3,8 +3,7 @@ The `dataclasses.field()` function is used to customize the declaration of fields on a dataclass. It accepts, among others, the `default`, `default_factory`, arguments used to set the default instance values to fields -(please take a look at the [official documentation]( -https://docs.python.org/3/library/dataclasses.html#dataclasses.field) For more +(please take a look at the [official documentation](https://docs.python.org/3/library/dataclasses.html#dataclasses.field) For more information). `simple-parsing` provides an overloaded version of this function: @@ -18,55 +17,59 @@ The values passed this way take precedence over those auto-generated by ## Examples - ### List of choices - For example, here is how you would create a list of choices, whereby any - of the passed arguments can only be contained within the choices: - ```python + For example, here is how you would create a list of choices, whereby any + of the passed arguments can only be contained within the choices: - from dataclasses import dataclass - from simple_parsing import ArgumentParser, field - from typing import List + ```python - @dataclass - class Foo: - """ Some class Foo """ + from dataclasses import dataclass + from simple_parsing import ArgumentParser, field + from typing import List - # A sequence of tasks. - task_sequence: List[str] = field(choices=["train", "test", "ood"]) + @dataclass + class Foo: + """ Some class Foo """ - parser = ArgumentParser() - parser.add_arguments(Foo, "foo") + # A sequence of tasks. 
+ task_sequence: List[str] = field(choices=["train", "test", "ood"]) - args = parser.parse_args("--task_sequence train train ood".split()) - foo: Foo = args.foo - print(foo) - assert foo.task_sequence == ["train", "train", "ood"] + parser = ArgumentParser() + parser.add_arguments(Foo, "foo") + + args = parser.parse_args("--task_sequence train train ood".split()) + foo: Foo = args.foo + print(foo) + assert foo.task_sequence == ["train", "train", "ood"] + + ``` - ``` - ### Adding additional aliases for an argument - By passing the - ```python - @dataclass - class Foo(TestSetup): - """ Some example Foo. """ - # The output directory. (can be passed using any of "-o" or --out or ) - output_dir: str = field( - default="/out", - alias=["-o", "--out"], - choices=["/out", "/bob"] - ) - - foo = Foo.setup("--output_dir /bob") - assert foo.output_dir == "/bob" - - with raises(): - foo = Foo.setup("-o /cat") - assert foo.output_dir == "/cat" - - foo = Foo.setup("--out /bob") - assert foo.output_dir == "/bob" - ``` + By passing the + + ```python + @dataclass + class Foo(TestSetup): + """ Some example Foo. """ + # The output directory. (can be passed using any of "-o" or --out or ) + output_dir: str = field( + default="/out", + alias=["-o", "--out"], + choices=["/out", "/bob"] + ) + + foo = Foo.setup("--output_dir /bob") + assert foo.output_dir == "/bob" + + with raises(): + foo = Foo.setup("-o /cat") + assert foo.output_dir == "/cat" + + foo = Foo.setup("--out /bob") + assert foo.output_dir == "/bob" + ``` - ### Adding Flags with "store-true" or "store-false" - Additionally, + + Additionally, diff --git a/examples/custom_args/custom_args_example.py b/examples/custom_args/custom_args_example.py index 7f0680e4..dd154407 100644 --- a/examples/custom_args/custom_args_example.py +++ b/examples/custom_args/custom_args_example.py @@ -1,5 +1,4 @@ -"""Example of overwriting auto-generated argparse options with custom ones. 
-""" +"""Example of overwriting auto-generated argparse options with custom ones.""" from dataclasses import dataclass from typing import List @@ -29,7 +28,9 @@ class Example1: assert parse(Example1, "") == Example1(pets_to_walk=["dog"]) assert parse(Example1, "--pets_to_walk") == Example1(pets_to_walk=[]) assert parse(Example1, "--pets_to_walk cat") == Example1(pets_to_walk=["cat"]) -assert parse(Example1, "--pets_to_walk dog dog cat") == Example1(pets_to_walk=["dog", "dog", "cat"]) +assert parse(Example1, "--pets_to_walk dog dog cat") == Example1( + pets_to_walk=["dog", "dog", "cat"] +) # # Passing a value not in 'choices' produces an error: diff --git a/examples/dataclasses/README.md b/examples/dataclasses/README.md index 837dd3d6..47b953cd 100644 --- a/examples/dataclasses/README.md +++ b/examples/dataclasses/README.md @@ -4,12 +4,11 @@ These are simple examples showing how to use `@dataclass` to create argument cla First, take a look at the official [dataclasses module documentation](https://docs.python.org/3.7/library/dataclasses.html). - With `simple-parsing`, groups of attributes are defined directly in code as dataclasses, each holding a set of related parameters. Methods can also be added to these dataclasses, which helps to promote the "Separation of Concerns" principle by keeping all the logic related to argument parsing in the same place as the arguments themselves. ## Examples: -- [dataclass_example.py](dataclass_example.py): a simple toy example showing an example of a dataclass +- [dataclass_example.py](dataclass_example.py): a simple toy example showing an example of a dataclass - [hyperparameters_example.py](hyperparameters_example.py): Shows an example of an argument dataclass which also defines serialization methods. 
diff --git a/examples/dataclasses/dataclass_example.py b/examples/dataclasses/dataclass_example.py index a6e3e19a..f55c8f45 100644 --- a/examples/dataclasses/dataclass_example.py +++ b/examples/dataclasses/dataclass_example.py @@ -4,7 +4,7 @@ @dataclass class Point: - """simple class Point""" + """Simple class Point.""" x: float = 0.0 y: float = 0.0 diff --git a/examples/demo.py b/examples/demo.py index 523b8027..679fc34f 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -9,7 +9,7 @@ @dataclass class Options: - """Help string for this group of command-line arguments""" + """Help string for this group of command-line arguments.""" log_dir: str # Help string for a required str argument learning_rate: float = 1e-4 # Help string for a float argument diff --git a/examples/demo_simple.py b/examples/demo_simple.py index cbd02efb..15f0e31b 100644 --- a/examples/demo_simple.py +++ b/examples/demo_simple.py @@ -6,7 +6,7 @@ @dataclass class Options: - """Help string for this group of command-line arguments""" + """Help string for this group of command-line arguments.""" log_dir: str # Help string for a required str argument learning_rate: float = 1e-4 # Help string for a float argument diff --git a/examples/docstrings/README.md b/examples/docstrings/README.md index 0c9c0225..1ee2c893 100644 --- a/examples/docstrings/README.md +++ b/examples/docstrings/README.md @@ -1,13 +1,15 @@ # Docstrings A docstring can either be: + - A comment on the same line as the attribute definition - A single or multi-line comment on the line(s) preceding the attribute definition - A single or multi-line docstring on the line(s) following the attribute -definition, starting with either `"""` or `'''` and ending with the same token. + definition, starting with either `"""` or `'''` and ending with the same token. When more than one docstring options are present, one of them is chosen to be used as the '--help' text of the attribute, according to the following ordering: + 1. 
docstring below the attribute 2. comment above the attribute 3. inline comment diff --git a/examples/docstrings/docstrings_example.py b/examples/docstrings/docstrings_example.py index ff7c6784..97f3e85f 100644 --- a/examples/docstrings/docstrings_example.py +++ b/examples/docstrings/docstrings_example.py @@ -1,7 +1,4 @@ -""" -A simple example to demonstrate the 'attribute docstrings' mechanism of simple-parsing. - -""" +"""A simple example to demonstrate the 'attribute docstrings' mechanism of simple-parsing.""" from dataclasses import dataclass from simple_parsing import ArgumentParser @@ -16,7 +13,7 @@ class DocStringsExample: """ attribute1: float = 1.0 - """docstring below, When used, this always shows up in the --help text for this attribute""" + """Docstring below, When used, this always shows up in the --help text for this attribute.""" # Comment above only: this shows up in the help text, since there is no docstring below. attribute2: float = 1.0 diff --git a/examples/enums/README.md b/examples/enums/README.md index 45409742..25ecaaf6 100644 --- a/examples/enums/README.md +++ b/examples/enums/README.md @@ -34,8 +34,4 @@ print(prefs) ``` - - - - You parse most datatypes using `simple-parsing`, as the type annotation on an argument is called as a conversion function in case the type of the attribute is not a builtin type or a dataclass. 
diff --git a/examples/enums/enums_example.py b/examples/enums/enums_example.py index 12f68f3a..1cafff84 100644 --- a/examples/enums/enums_example.py +++ b/examples/enums/enums_example.py @@ -21,7 +21,7 @@ class Temperature(enum.Enum): @dataclass class MyPreferences: - """You can use Enums""" + """You can use Enums.""" color: Color = Color.BLUE # my favorite colour temp: Temperature = Temperature.WARM diff --git a/examples/inheritance/ml_inheritance.py b/examples/inheritance/ml_inheritance.py index 9dca0fa1..e0a2f75d 100644 --- a/examples/inheritance/ml_inheritance.py +++ b/examples/inheritance/ml_inheritance.py @@ -22,30 +22,31 @@ def __init__(self, hparams: HyperParameters): class WGAN(GAN): - """ - Wasserstein GAN - """ + """Wasserstein GAN.""" @dataclass class HyperParameters(GAN.HyperParameters): e_drift: float = 1e-4 - """Coefficient from the progan authors which penalizes critic outputs for having a large magnitude.""" + """Coefficient from the progan authors which penalizes critic outputs for having a large + magnitude.""" def __init__(self, hparams: HyperParameters): self.hparams = hparams class WGANGP(WGAN): - """ - Wasserstein GAN with Gradient Penalty - """ + """Wasserstein GAN with Gradient Penalty.""" @dataclass class HyperParameters(WGAN.HyperParameters): e_drift: float = 1e-4 - """Coefficient from the progan authors which penalizes critic outputs for having a large magnitude.""" + """Coefficient from the progan authors which penalizes critic outputs for having a large + magnitude.""" gp_coefficient: float = 10.0 - """Multiplying coefficient for the gradient penalty term of the loss equation. (10.0 is the default value, and was used by the PROGAN authors.)""" + """Multiplying coefficient for the gradient penalty term of the loss equation. + + (10.0 is the default value, and was used by the PROGAN authors.) 
+ """ def __init__(self, hparams: HyperParameters): self.hparams: WGANGP.HyperParameters = hparams diff --git a/examples/inheritance/ml_inheritance_2.py b/examples/inheritance/ml_inheritance_2.py index a506bcb8..4ddb9f19 100644 --- a/examples/inheritance/ml_inheritance_2.py +++ b/examples/inheritance/ml_inheritance_2.py @@ -17,14 +17,14 @@ class ConvBlock(Serializable): @dataclass class GeneratorHParams(ConvBlock): - """Settings of the Generator model""" + """Settings of the Generator model.""" optimizer: str = choice("ADAM", "RMSPROP", "SGD", default="ADAM") @dataclass class DiscriminatorHParams(ConvBlock): - """Settings of the Discriminator model""" + """Settings of the Discriminator model.""" optimizer: str = choice("ADAM", "RMSPROP", "SGD", default="ADAM") @@ -48,14 +48,15 @@ def __init__(self, hparams: GanHParams): @dataclass class WGanHParams(GanHParams): - """HParams of the WGAN model""" + """HParams of the WGAN model.""" e_drift: float = 1e-4 - """Coefficient from the progan authors which penalizes critic outputs for having a large magnitude.""" + """Coefficient from the progan authors which penalizes critic outputs for having a large + magnitude.""" class WGAN(GAN): - """Wasserstein GAN""" + """Wasserstein GAN.""" def __init__(self, hparams: WGanHParams): self.hparams = hparams @@ -70,21 +71,23 @@ class CriticHParams(DiscriminatorHParams): @dataclass class WGanGPHParams(WGanHParams): - """Hyperparameters of the WGAN with Gradient Penalty""" + """Hyperparameters of the WGAN with Gradient Penalty.""" e_drift: float = 1e-4 - """Coefficient from the progan authors which penalizes critic outputs for having a large magnitude.""" + """Coefficient from the progan authors which penalizes critic outputs for having a large + magnitude.""" gp_coefficient: float = 10.0 - """Multiplying coefficient for the gradient penalty term of the loss equation. 
(10.0 is the default value, and was used by the PROGAN authors.)""" + """Multiplying coefficient for the gradient penalty term of the loss equation. + + (10.0 is the default value, and was used by the PROGAN authors.) + """ disc: CriticHParams = field(default_factory=CriticHParams) # overwrite the usual 'disc' field of the WGanHParams dataclass. """ Parameters of the Critic. """ class WGANGP(WGAN): - """ - Wasserstein GAN with Gradient Penalty - """ + """Wasserstein GAN with Gradient Penalty.""" def __init__(self, hparams: WGanGPHParams): self.hparams = hparams diff --git a/examples/merging/README.md b/examples/merging/README.md index d1f682a6..896c3ef0 100644 --- a/examples/merging/README.md +++ b/examples/merging/README.md @@ -10,5 +10,6 @@ To do this, we pass the `ConflictResolution.ALWAYS_MERGE` option to the argument For more info, check out the docstring of the `ConflictResolution` enum. ## Examples: + - [multiple_example.py](multiple_example.py) - [multiple_lists_example.py](multiple_lists_example.py) diff --git a/examples/merging/multiple_example.py b/examples/merging/multiple_example.py index 5064ac35..f582748e 100644 --- a/examples/merging/multiple_example.py +++ b/examples/merging/multiple_example.py @@ -1,4 +1,5 @@ """Example of how to create multiple instances of a class from the command-line. + # NOTE: If your dataclass has a list attribute, and you wish to parse multiple instances of that class from the command line, # simply enclose each list with single or double quotes. # For this example, something like: @@ -18,7 +19,11 @@ class Config: run_name: str = "train" # Some parameter for the run name. some_int: int = 10 # an optional int parameter. log_dir: str = "logs" # an optional string parameter. - """the logging directory to use. (This is an attribute docstring for the log_dir attribute, and shows up when using the "--help" argument!)""" + """the logging directory to use. 
+ + (This is an attribute docstring for the log_dir attribute, and shows up when using the "--help" + argument!) + """ parser.add_arguments(Config, "train_config") diff --git a/examples/merging/multiple_lists_example.py b/examples/merging/multiple_lists_example.py index b8fd4dba..0a0a3aca 100644 --- a/examples/merging/multiple_lists_example.py +++ b/examples/merging/multiple_lists_example.py @@ -1,7 +1,6 @@ -""" -Here, we demonstrate parsing multiple classes each of which has a list attribute. -There are a few options for doing this. For example, if we want to let each instance -have a distinct prefix for its arguments, we could use the ConflictResolution.AUTO option. +"""Here, we demonstrate parsing multiple classes each of which has a list attribute. There are a +few options for doing this. For example, if we want to let each instance have a distinct prefix for +its arguments, we could use the ConflictResolution.AUTO option. Here, we want to create a few instances of `CNNStack` from the command line, but don't want to have a different prefix for each instance. 
diff --git a/examples/nesting/nesting_example.py b/examples/nesting/nesting_example.py index 63060b3b..55076c74 100644 --- a/examples/nesting/nesting_example.py +++ b/examples/nesting/nesting_example.py @@ -7,9 +7,7 @@ @dataclass class TaskHyperParameters: - """ - HyperParameters for a task-specific model - """ + """HyperParameters for a task-specific model.""" # name of the task name: str diff --git a/examples/partials/README.md b/examples/partials/README.md index 560767d1..0b63f062 100644 --- a/examples/partials/README.md +++ b/examples/partials/README.md @@ -1,2 +1 @@ # Partials - Configuring arbitrary classes / callables - diff --git a/examples/partials/partials_example.py b/examples/partials/partials_example.py index 34d30a23..f3c2b7c6 100644 --- a/examples/partials/partials_example.py +++ b/examples/partials/partials_example.py @@ -57,7 +57,6 @@ def __init__( @dataclass class Config: - # Which optimizer to use. optimizer: Partial[Optimizer] = subgroups( { diff --git a/examples/prefixing/manual_prefix_example.py b/examples/prefixing/manual_prefix_example.py index 576ba812..3c8053c7 100644 --- a/examples/prefixing/manual_prefix_example.py +++ b/examples/prefixing/manual_prefix_example.py @@ -5,7 +5,7 @@ @dataclass class Config: - """Simple example of a class that can be reused""" + """Simple example of a class that can be reused.""" log_dir: str = "logs" diff --git a/examples/serialization/README.md b/examples/serialization/README.md index 3a6e4566..91f3d9c5 100644 --- a/examples/serialization/README.md +++ b/examples/serialization/README.md @@ -40,8 +40,8 @@ Student(name='Bob', age=20, domain='Computer Science', average_grade=0.8) >>> assert _bob == bob ``` - ## Adding custom types + Register a new encoding function using `encode`, and a new decoding function using `register_decoding_fn` For example: Consider the same example as above, but we add a Tensor attribute from `pytorch`. 
diff --git a/examples/serialization/custom_types_example.py b/examples/serialization/custom_types_example.py index ae0190e7..4a1908c2 100644 --- a/examples/serialization/custom_types_example.py +++ b/examples/serialization/custom_types_example.py @@ -27,7 +27,7 @@ class Student(Person): @encode.register def encode_tensor(obj: Tensor) -> List: - """We choose to encode a tensor as a list, for instance""" + """We choose to encode a tensor as a list, for instance.""" return obj.tolist() diff --git a/examples/serialization/serialization_example.ipynb b/examples/serialization/serialization_example.ipynb index 6f6aaf60..cd56862b 100644 --- a/examples/serialization/serialization_example.ipynb +++ b/examples/serialization/serialization_example.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -22,17 +22,11 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "tags": [] }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": "age: 20\nname: Charlie\n\n{\"name\": \"Charlie\", \"age\": 20}\n{\"name\": \"Charlie\", \"age\": 20}\n" - } - ], + "outputs": [], "source": [ "\n", "# Serialization:\n", @@ -45,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "tags": [] }, @@ -71,6 +65,10 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -82,11 +80,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10-final" - }, - "orig_nbformat": 2, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" } }, "nbformat": 4, diff --git a/examples/simple/flag.py b/examples/simple/flag.py index 4a4c6daf..f3f28f44 100644 --- a/examples/simple/flag.py +++ b/examples/simple/flag.py @@ -35,7 +35,7 @@ class HParams: assert parse(HParams, 
"--no-train") == HParams(train=False) -# showing what --help outputs +# showing what --help outputs parser = ArgumentParser() # Create an argument parser parser.add_arguments(HParams, dest="hparams") # add arguments for the dataclass parser.print_help() @@ -56,4 +56,4 @@ class HParams: (default: 0.001) --train bool, --no-train bool (default: True) -""" \ No newline at end of file +""" diff --git a/examples/simple/help.py b/examples/simple/help.py index 65d5f6bb..b0bbdbbb 100644 --- a/examples/simple/help.py +++ b/examples/simple/help.py @@ -30,10 +30,8 @@ class HParams: learning_rate: float = 0.001 # Learning_rate used by the optimizer. alpha: float = 0.05 # TODO: Tune this. (This doesn't appear in '--help') - """ - A detailed description of this new 'alpha' parameter, which can potentially - span multiple lines. - """ + """A detailed description of this new 'alpha' parameter, which can potentially span multiple + lines.""" parser = ArgumentParser() diff --git a/examples/simple/reuse.py b/examples/simple/reuse.py index 5902997b..bce48e62 100644 --- a/examples/simple/reuse.py +++ b/examples/simple/reuse.py @@ -1,12 +1,9 @@ -""" Modular and reusable! -With SimpleParsing, you can easily add similar groups of command-line arguments -by simply reusing the dataclasses you define! -There is no longer need for any copy-pasting of blocks, or adding prefixes -everywhere by hand. +"""Modular and reusable! With SimpleParsing, you can easily add similar groups of command-line +arguments by simply reusing the dataclasses you define! There is no longer need for any copy- +pasting of blocks, or adding prefixes everywhere by hand. -Instead, the ArgumentParser detects when more than one instance of -the same `@dataclass` needs to be parsed, and automatically adds the relevant -prefixes to the arguments for you. 
+Instead, the ArgumentParser detects when more than one instance of the same `@dataclass` needs to +be parsed, and automatically adds the relevant prefixes to the arguments for you. """ from dataclasses import dataclass diff --git a/examples/subgroups/README.md b/examples/subgroups/README.md index d878de04..0f500989 100644 --- a/examples/subgroups/README.md +++ b/examples/subgroups/README.md @@ -3,7 +3,6 @@ Adding a choice between different subgroups of arguments can be very difficult using Argparse. Subparsers are not exactly meant for this, and they introduce many errors - This friction is one of the motivating factors for a plethora of argument parsing frameworks such as Hydra, Click, and others. @@ -118,7 +117,6 @@ Dataset2Config ['config.dataset']: --dataset.bar float (default: 1.2) ``` - ```console $ python examples/subgroups/subgroups_example.py --model model_b --help usage: subgroups_example.py [-h] [--model {model_a,model_b}] [--dataset {dataset_1,dataset_2}] [--model.lr float] [--model.optimizer str] [--model.momentum float] diff --git a/examples/subgroups/subgroups_example.py b/examples/subgroups/subgroups_example.py index 27ba02c7..0b4610bc 100644 --- a/examples/subgroups/subgroups_example.py +++ b/examples/subgroups/subgroups_example.py @@ -45,7 +45,6 @@ class Dataset2Config(DatasetConfig): @dataclass class Config: - # Which model to use model: ModelConfig = subgroups( {"model_a": ModelAConfig, "model_b": ModelBConfig}, diff --git a/examples/subparsers/README.md b/examples/subparsers/README.md index 12b13d7f..3a29cfd5 100644 --- a/examples/subparsers/README.md +++ b/examples/subparsers/README.md @@ -1,15 +1,15 @@ ### [(Examples Home)](../README.md) -# Creating Commands with Subparsers +# Creating Commands with Subparsers Subparsers are one of the more advanced features of `argparse`. They allow the creation of subcommands, each having their own set of arguments. 
The `git` command, for instance, takes different arguments than the `pull` subcommand in `git pull`. For some more info on subparsers, check out the [argparse documentation](https://docs.python.org/3/library/argparse.html#argparse.ArgumentParser.add_subparsers). - With `simple-parsing`, subparsers can easily be created by using a `Union` type annotation on a dataclass attribute. By annotating a variable with a Union type, for example `x: Union[T1, T2]`, we simply state that `x` can either be of type `T1` or `T2`. When the arguments to the `Union` type **are all dataclasses**, `simple-parsing` creates subparsers for each dataclass type, using the lowercased class name as the command name by default. If you want to extend or change this behaviour (to have "t" and "train" map to the same training subcommand, for example), use the `subparsers` function, passing in a dictionary mapping command names to the appropriate type. + ## Example: @@ -61,49 +61,55 @@ prog.execute() ``` Here are some usage examples: + - Executing the training command: - ```console - $ python examples/subparsers/subparsers_example.py train - prog: Program(command=Train(train_dir=PosixPath('~/train')), verbose=False) - Executing Program (verbose: False) - Training in directory ~/train - ``` + ```console + $ python examples/subparsers/subparsers_example.py train + prog: Program(command=Train(train_dir=PosixPath('~/train')), verbose=False) + Executing Program (verbose: False) + Training in directory ~/train + ``` + - Passing a custom training directory: - ```console - $ python examples/subparsers/subparsers_example.py train --train_dir ~/train - prog: Program(command=Train(train_dir=PosixPath('/home/fabrice/train')), verbose=False) - Executing Program (verbose: False) - Training in directory /home/fabrice/train - ``` + + ```console + $ python examples/subparsers/subparsers_example.py train --train_dir ~/train + prog: Program(command=Train(train_dir=PosixPath('/home/fabrice/train')), verbose=False) + 
Executing Program (verbose: False) + Training in directory /home/fabrice/train + ``` + - Getting help for a subcommand: - ```console - $ python examples/subparsers/subparsers_example.py train --help - usage: subparsers_example.py train [-h] [--train_dir Path] - optional arguments: - -h, --help show this help message and exit + ```console + $ python examples/subparsers/subparsers_example.py train --help + usage: subparsers_example.py train [-h] [--train_dir Path] + + optional arguments: + -h, --help show this help message and exit + + Train ['prog.command']: + Example of a command to start a Training run. - Train ['prog.command']: - Example of a command to start a Training run. + --train_dir Path the training directory (default: ~/train) + ``` - --train_dir Path the training directory (default: ~/train) - ``` - Getting Help for the parent command: - ```console - $ python examples/subparsers/subparsers_example.py --help - usage: subparsers_example.py [-h] [--verbose [str2bool]] {train,test} ... + ```console + $ python examples/subparsers/subparsers_example.py --help + usage: subparsers_example.py [-h] [--verbose [str2bool]] {train,test} ... - optional arguments: - -h, --help show this help message and exit + optional arguments: + -h, --help show this help message and exit - Program ['prog']: - Some top-level command + Program ['prog']: + Some top-level command - --verbose [str2bool] log additional messages in the console. (default: - False) + --verbose [str2bool] log additional messages in the console. 
(default: + False) - command: - {train,test} - ``` + command: + {train,test} + ``` diff --git a/examples/subparsers/subparsers_example.py b/examples/subparsers/subparsers_example.py index d16a423e..7b30c7aa 100644 --- a/examples/subparsers/subparsers_example.py +++ b/examples/subparsers/subparsers_example.py @@ -29,7 +29,7 @@ def execute(self): @dataclass class Program: - """Some top-level command""" + """Some top-level command.""" command: Union[Train, Test] verbose: bool = False # log additional messages in the console. diff --git a/examples/ugly/ugly_example_after.py b/examples/ugly/ugly_example_after.py index 4c490183..f0fc1cec 100644 --- a/examples/ugly/ugly_example_after.py +++ b/examples/ugly/ugly_example_after.py @@ -14,12 +14,11 @@ @dataclass class DatasetParams: - """Dataset Parameters""" + """Dataset Parameters.""" default_root: ClassVar[str] = "/dataset" # the default root directory to use. - dataset: str = "objects_folder_multi" # laptop,pistol - """ dataset name: [shapenet, objects_folder, objects_folder]') """ + """dataset name: [shapenet, objects_folder, objects_folder]')""" root_dir: str = default_root # dataset root directory root_dir1: str = default_root # dataset root directory @@ -89,7 +88,7 @@ class NetworkParams: @dataclass class OptimizerParams: - """Optimization parameters""" + """Optimization parameters.""" optimizer: str = "adam" # Optimizer (adam, rmsprop) lr: float = 0.0001 # learning rate, default=0.0002 @@ -109,12 +108,14 @@ class OptimizerParams: n_iter: int = 76201 # number of iterations to train batchSize: int = 4 # input batch size alt_opt_zn_interval: Optional[int] = None - """ Alternating optimization interval. + """Alternating optimization interval. + - None: joint optimization - 20: every 20 iterations, etc. """ alt_opt_zn_start: int = 100000 """Alternating optimization start interaction. + - -1: starts immediately, - '100: starts alternating after the first 100 iterations. 
""" @@ -122,7 +123,7 @@ class OptimizerParams: @dataclass class GanParams: - """Gan parameters""" + """Gan parameters.""" criterion: str = choice("GAN", "WGAN", default="WGAN") # GAN Training criterion gp: str = choice("None", "original", default="original") # Add gradient penalty @@ -133,7 +134,7 @@ class GanParams: @dataclass class OtherParams: - """Other parameters""" + """Other parameters.""" manualSeed: int = 1 # manual seed no_cuda: bool = False # enables cuda @@ -144,7 +145,7 @@ class OtherParams: @dataclass class CameraParams: - """Camera Parameters""" + """Camera Parameters.""" cam_pos: Tuple[float, float, float] = (0.0, 0.0, 0.0) # Camera position. width: int = 128 @@ -179,7 +180,7 @@ class RenderingParams: est_normals: bool = False # Estimate normals from splat positions. n_splats: Optional[int] = None same_view: bool = False # before we add conditioning on cam pose, this is necessary - """ data with view fixed """ + """Data with view fixed.""" print_interval: int = 10 # Print loss interval. save_image_interval: int = 100 # Save image interval. @@ -204,7 +205,7 @@ class Parameters: other: OtherParams = field(default_factory=OtherParams) def __post_init__(self): - """Post-initialization code""" + """Post-initialization code.""" # Make output folder # try: # os.makedirs(self.other.out_dir) diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 00000000..c0d4fc47 --- /dev/null +++ b/poetry.lock @@ -0,0 +1,420 @@ +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." 
+optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "coverage" +version = "7.4.1" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "coverage-7.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:077d366e724f24fc02dbfe9d946534357fda71af9764ff99d73c3c596001bbd7"}, + {file = "coverage-7.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0193657651f5399d433c92f8ae264aff31fc1d066deee4b831549526433f3f61"}, + {file = "coverage-7.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d17bbc946f52ca67adf72a5ee783cd7cd3477f8f8796f59b4974a9b59cacc9ee"}, + {file = "coverage-7.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a3277f5fa7483c927fe3a7b017b39351610265308f5267ac6d4c2b64cc1d8d25"}, + {file = "coverage-7.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dceb61d40cbfcf45f51e59933c784a50846dc03211054bd76b421a713dcdf19"}, + {file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6008adeca04a445ea6ef31b2cbaf1d01d02986047606f7da266629afee982630"}, + {file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c61f66d93d712f6e03369b6a7769233bfda880b12f417eefdd4f16d1deb2fc4c"}, + {file = "coverage-7.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b9bb62fac84d5f2ff523304e59e5c439955fb3b7f44e3d7b2085184db74d733b"}, + {file = "coverage-7.4.1-cp310-cp310-win32.whl", hash = "sha256:f86f368e1c7ce897bf2457b9eb61169a44e2ef797099fb5728482b8d69f3f016"}, + {file = 
"coverage-7.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:869b5046d41abfea3e381dd143407b0d29b8282a904a19cb908fa24d090cc018"}, + {file = "coverage-7.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b8ffb498a83d7e0305968289441914154fb0ef5d8b3157df02a90c6695978295"}, + {file = "coverage-7.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3cacfaefe6089d477264001f90f55b7881ba615953414999c46cc9713ff93c8c"}, + {file = "coverage-7.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d6850e6e36e332d5511a48a251790ddc545e16e8beaf046c03985c69ccb2676"}, + {file = "coverage-7.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18e961aa13b6d47f758cc5879383d27b5b3f3dcd9ce8cdbfdc2571fe86feb4dd"}, + {file = "coverage-7.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dfd1e1b9f0898817babf840b77ce9fe655ecbe8b1b327983df485b30df8cc011"}, + {file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6b00e21f86598b6330f0019b40fb397e705135040dbedc2ca9a93c7441178e74"}, + {file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:536d609c6963c50055bab766d9951b6c394759190d03311f3e9fcf194ca909e1"}, + {file = "coverage-7.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7ac8f8eb153724f84885a1374999b7e45734bf93a87d8df1e7ce2146860edef6"}, + {file = "coverage-7.4.1-cp311-cp311-win32.whl", hash = "sha256:f3771b23bb3675a06f5d885c3630b1d01ea6cac9e84a01aaf5508706dba546c5"}, + {file = "coverage-7.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:9d2f9d4cc2a53b38cabc2d6d80f7f9b7e3da26b2f53d48f05876fef7956b6968"}, + {file = "coverage-7.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f68ef3660677e6624c8cace943e4765545f8191313a07288a53d3da188bd8581"}, + {file = "coverage-7.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:23b27b8a698e749b61809fb637eb98ebf0e505710ec46a8aa6f1be7dc0dc43a6"}, + {file = "coverage-7.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e3424c554391dc9ef4a92ad28665756566a28fecf47308f91841f6c49288e66"}, + {file = "coverage-7.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e0860a348bf7004c812c8368d1fc7f77fe8e4c095d661a579196a9533778e156"}, + {file = "coverage-7.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe558371c1bdf3b8fa03e097c523fb9645b8730399c14fe7721ee9c9e2a545d3"}, + {file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3468cc8720402af37b6c6e7e2a9cdb9f6c16c728638a2ebc768ba1ef6f26c3a1"}, + {file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:02f2edb575d62172aa28fe00efe821ae31f25dc3d589055b3fb64d51e52e4ab1"}, + {file = "coverage-7.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ca6e61dc52f601d1d224526360cdeab0d0712ec104a2ce6cc5ccef6ed9a233bc"}, + {file = "coverage-7.4.1-cp312-cp312-win32.whl", hash = "sha256:ca7b26a5e456a843b9b6683eada193fc1f65c761b3a473941efe5a291f604c74"}, + {file = "coverage-7.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:85ccc5fa54c2ed64bd91ed3b4a627b9cce04646a659512a051fa82a92c04a448"}, + {file = "coverage-7.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8bdb0285a0202888d19ec6b6d23d5990410decb932b709f2b0dfe216d031d218"}, + {file = "coverage-7.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:918440dea04521f499721c039863ef95433314b1db00ff826a02580c1f503e45"}, + {file = "coverage-7.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:379d4c7abad5afbe9d88cc31ea8ca262296480a86af945b08214eb1a556a3e4d"}, + {file = "coverage-7.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:b094116f0b6155e36a304ff912f89bbb5067157aff5f94060ff20bbabdc8da06"}, + {file = "coverage-7.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2f5968608b1fe2a1d00d01ad1017ee27efd99b3437e08b83ded9b7af3f6f766"}, + {file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:10e88e7f41e6197ea0429ae18f21ff521d4f4490aa33048f6c6f94c6045a6a75"}, + {file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a4a3907011d39dbc3e37bdc5df0a8c93853c369039b59efa33a7b6669de04c60"}, + {file = "coverage-7.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6d224f0c4c9c98290a6990259073f496fcec1b5cc613eecbd22786d398ded3ad"}, + {file = "coverage-7.4.1-cp38-cp38-win32.whl", hash = "sha256:23f5881362dcb0e1a92b84b3c2809bdc90db892332daab81ad8f642d8ed55042"}, + {file = "coverage-7.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:a07f61fc452c43cd5328b392e52555f7d1952400a1ad09086c4a8addccbd138d"}, + {file = "coverage-7.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8e738a492b6221f8dcf281b67129510835461132b03024830ac0e554311a5c54"}, + {file = "coverage-7.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:46342fed0fff72efcda77040b14728049200cbba1279e0bf1188f1f2078c1d70"}, + {file = "coverage-7.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9641e21670c68c7e57d2053ddf6c443e4f0a6e18e547e86af3fad0795414a628"}, + {file = "coverage-7.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aeb2c2688ed93b027eb0d26aa188ada34acb22dceea256d76390eea135083950"}, + {file = "coverage-7.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d12c923757de24e4e2110cf8832d83a886a4cf215c6e61ed506006872b43a6d1"}, + {file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0491275c3b9971cdbd28a4595c2cb5838f08036bca31765bad5e17edf900b2c7"}, + 
{file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8dfc5e195bbef80aabd81596ef52a1277ee7143fe419efc3c4d8ba2754671756"}, + {file = "coverage-7.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1a78b656a4d12b0490ca72651fe4d9f5e07e3c6461063a9b6265ee45eb2bdd35"}, + {file = "coverage-7.4.1-cp39-cp39-win32.whl", hash = "sha256:f90515974b39f4dea2f27c0959688621b46d96d5a626cf9c53dbc653a895c05c"}, + {file = "coverage-7.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:64e723ca82a84053dd7bfcc986bdb34af8d9da83c521c19d6b472bc6880e191a"}, + {file = "coverage-7.4.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:32a8d985462e37cfdab611a6f95b09d7c091d07668fdc26e47a725ee575fe166"}, + {file = "coverage-7.4.1.tar.gz", hash = "sha256:1ed4b95480952b1a26d863e546fa5094564aa0065e1e5f0d4d0041f293251d04"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + +[[package]] +name = "docstring-parser" +version = "0.15" +description = "Parse Python docstrings in reST, Google and Numpydoc format" +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "docstring_parser-0.15-py3-none-any.whl", hash = "sha256:d1679b86250d269d06a99670924d6bce45adc00b08069dae8c47d98e89b667a9"}, + {file = "docstring_parser-0.15.tar.gz", hash = "sha256:48ddc093e8b1865899956fcc03b03e66bb7240c310fac5af81814580c55bf682"}, +] + +[[package]] +name = "exceptiongroup" +version = "1.2.0" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, + {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"}, +] + +[package.extras] +test = ["pytest (>=6)"] + +[[package]] +name = "execnet" +version = "2.0.2" +description = 
"execnet: rapid multi-Python deployment" +optional = false +python-versions = ">=3.7" +files = [ + {file = "execnet-2.0.2-py3-none-any.whl", hash = "sha256:88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41"}, + {file = "execnet-2.0.2.tar.gz", hash = "sha256:cc59bc4423742fd71ad227122eb0dd44db51efb3dc4095b45ac9a08c770096af"}, +] + +[package.extras] +testing = ["hatch", "pre-commit", "pytest", "tox"] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "numpy" +version = "1.24.4" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, + {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, + {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, + {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = 
"sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, + {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, + {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, + {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, + {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, + {file = 
"numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, + {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, + {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, + {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, +] + +[[package]] +name = "packaging" +version = "23.2" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, + {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, +] + +[[package]] +name = "pluggy" +version = "1.4.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.4.0-py3-none-any.whl", hash = "sha256:7db9f7b503d67d1c5b95f59773ebb58a8c1c288129a88665838012cfb07b8981"}, + {file = "pluggy-1.4.0.tar.gz", hash = "sha256:8c85c2876142a764e5b7548e7d9a0e0ddb46f5185161049a79b7e974454223be"}, +] + 
+[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +description = "Get CPU info with pure Python" +optional = false +python-versions = "*" +files = [ + {file = "py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690"}, + {file = "py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5"}, +] + +[[package]] +name = "pytest" +version = "8.0.0" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.0.0-py3-none-any.whl", hash = "sha256:50fb9cbe836c3f20f0dfa99c565201fb75dc54c8d76373cd1bde06b06657bdb6"}, + {file = "pytest-8.0.0.tar.gz", hash = "sha256:249b1b0864530ba251b7438274c4d251c58d868edaaec8762893ad4a0d71c36c"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.3.0,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-benchmark" +version = "4.0.0" +description = "A ``pytest`` fixture for benchmarking code. It will group the tests into rounds that are calibrated to the chosen timer." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-benchmark-4.0.0.tar.gz", hash = "sha256:fb0785b83efe599a6a956361c0691ae1dbb5318018561af10f3e915caa0048d1"}, + {file = "pytest_benchmark-4.0.0-py3-none-any.whl", hash = "sha256:fdb7db64e31c8b277dff9850d2a2556d8b60bcb0ea6524e36e28ffd7c87f71d6"}, +] + +[package.dependencies] +py-cpuinfo = "*" +pytest = ">=3.8" + +[package.extras] +aspect = ["aspectlib"] +elasticsearch = ["elasticsearch"] +histogram = ["pygal", "pygaljs"] + +[[package]] +name = "pytest-cov" +version = "4.1.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, + {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, +] + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] + +[[package]] +name = "pytest-datadir" +version = "1.5.0" +description = "pytest plugin for test data directories and files" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-datadir-1.5.0.tar.gz", hash = "sha256:1617ed92f9afda0c877e4eac91904b5f779d24ba8f5e438752e3ae39d8d2ee3f"}, + {file = "pytest_datadir-1.5.0-py3-none-any.whl", hash = "sha256:34adf361bcc7b37961bbc1dfa8d25a4829e778bab461703c38a5c50ca9c36dc8"}, +] + +[package.dependencies] +pytest = ">=5.0" + +[[package]] +name = "pytest-regressions" +version = "2.5.0" +description = "Easy to use fixtures to write regression tests." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-regressions-2.5.0.tar.gz", hash = "sha256:818c7884c1cff3babf89eebb02cbc27b307856b1985427c24d52cb812a106fd9"}, + {file = "pytest_regressions-2.5.0-py3-none-any.whl", hash = "sha256:8c4e5c4037325fdb0825bc1fdcb75e17e03adf3407049f0cb704bb996d496255"}, +] + +[package.dependencies] +pytest = ">=6.2.0" +pytest-datadir = ">=1.2.0" +pyyaml = "*" + +[package.extras] +dataframe = ["numpy", "pandas"] +dev = ["matplotlib", "mypy", "numpy", "pandas", "pillow", "pre-commit", "restructuredtext-lint", "tox"] +image = ["numpy", "pillow"] +num = ["numpy", "pandas"] + +[[package]] +name = "pytest-xdist" +version = "3.5.0" +description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-xdist-3.5.0.tar.gz", hash = "sha256:cbb36f3d67e0c478baa57fa4edc8843887e0f6cfc42d677530a36d7472b32d8a"}, + {file = "pytest_xdist-3.5.0-py3-none-any.whl", hash = "sha256:d075629c7e00b611df89f490a5063944bee7a4362a5ff11c7cc7824a03dfce24"}, +] + +[package.dependencies] +execnet = ">=1.1" +pytest = ">=6.2.0" + +[package.extras] +psutil = ["psutil (>=3.0)"] +setproctitle = ["setproctitle"] +testing = ["filelock"] + +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + +[[package]] +name = "tomli" +version = "2.0.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, + {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, +] + +[[package]] +name = "tomli-w" +version = "1.0.0" +description = "A lil' TOML writer" +optional = true +python-versions = ">=3.7" +files = [ + {file = "tomli_w-1.0.0-py3-none-any.whl", hash = "sha256:9f2a07e8be30a0729e533ec968016807069991ae2fd921a78d42f429ae5f4463"}, + {file = "tomli_w-1.0.0.tar.gz", hash = "sha256:f463434305e0336248cac9c2dc8076b707d8a12d019dd349f5c1e382dd1ae1b9"}, +] + +[[package]] +name = "typing-extensions" +version = "4.9.0" +description = "Backported and Experimental Type Hints for 
Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"}, + {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, +] + +[extras] +toml = ["tomli", "tomli-w"] +yaml = ["pyyaml"] + +[metadata] +lock-version = "2.0" +python-versions = "^3.8" +content-hash = "babe84c4662a3f7fb04b313fc61736472503b75b1c8cb6e01640e165cdebae5a" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..16afd4d7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[tool.poetry] +name = "simple-parsing" +version = "0.0.0" +description = "A small utility for simplifying and cleaning up argument parsing scripts." +authors = ["Fabrice Normandin "] +license = "MIT" +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.8" +docstring-parser = "~=0.15" +typing-extensions = ">=4.5.0" +pyyaml = {version = "^6.0.1", optional = true} +tomli = {version = "^2.0.1", optional = true} +tomli-w = {version = "^1.0.0", optional = true} + +[tool.poetry.group.dev.dependencies] +pytest = "^8.0.0" +pytest-xdist = "^3.5.0" +pytest-regressions = "^2.4.2" +pytest-cov = "^4.1.0" +numpy = "^1.24.2" +pytest-benchmark = "^4.0.0" + + +[tool.poetry.extras] +yaml = ["pyyaml"] +toml = ["tomli", "tomli_w"] + + +[tool.poetry-dynamic-versioning] +enable = true + +[build-system] +requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"] +build-backend = "poetry_dynamic_versioning.backend" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index f8e82b03..00000000 --- a/setup.cfg +++ /dev/null @@ -1,16 +0,0 @@ -[versioneer] -VCS=git -style=pep440-post -versionfile_source=simple_parsing/_version.py -versionfile_build=simple_parsing/_version.py -tag_prefix=v -parentdir_prefix=simple_parsing- - -[metadata] -license_file=LICENSE - -[flake8] -ignore = E203, 
E266, E501, W503, F403, F401 -max-line-length = 79 -max-complexity = 18 -select = B,C,E,F,W,T4,B9 diff --git a/setup.py b/setup.py deleted file mode 100644 index 186af475..00000000 --- a/setup.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -import sys - -import setuptools - -import versioneer - -with open("README.md") as fh: - long_description = fh.read() -packages = setuptools.find_namespace_packages(include=["simple_parsing*"]) -print("PACKAGES FOUND:", packages) -print(sys.version_info) - -with open("requirements.txt") as req_file: - install_requires = req_file.read().splitlines(keepends=False) - -extras_require: dict[str, list[str]] = { - "test": [ - "pytest", - "pytest-xdist", - "pytest-regressions", - "pytest-benchmark", - "numpy", - # "torch", - ], - "yaml": ["pyyaml"], - "toml": ["tomli", "tomli_w"], -} -extras_require["all"] = list(set(sum(extras_require.values(), []))) - - -setuptools.setup( - name="simple_parsing", - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - author="Fabrice Normandin", - author_email="fabrice.normandin@gmail.com", - description="A small utility for simplifying and cleaning up argument parsing scripts.", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/lebrice/SimpleParsing", - packages=packages, - package_data={"simple_parsing": ["py.typed"]}, - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - python_requires=">=3.7", - install_requires=install_requires, - extras_require=extras_require, - setup_requires=["pre-commit"], -) diff --git a/simple_parsing/__init__.py b/simple_parsing/__init__.py index 08671c6b..59fbbf3f 100644 --- a/simple_parsing/__init__.py +++ b/simple_parsing/__init__.py @@ -1,4 +1,5 @@ """Simple, Elegant Argument parsing. + @author: Fabrice Normandin """ from . 
import helpers, utils, wrappers @@ -57,7 +58,3 @@ "utils", "wrappers", ] - -from . import _version - -__version__ = _version.get_versions()["version"] diff --git a/simple_parsing/_version.py b/simple_parsing/_version.py deleted file mode 100644 index e7124597..00000000 --- a/simple_parsing/_version.py +++ /dev/null @@ -1,657 +0,0 @@ -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.20 (https://github.com/python-versioneer/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). 
- git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: # pylint: disable=too-few-public-methods - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440-post" - cfg.tag_prefix = "v" - cfg.parentdir_prefix = "simple_parsing-" - cfg.versionfile_source = "simple_parsing/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - - return decorate - - -# pylint:disable=too-many-arguments,consider-using-with # noqa -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen( - [command] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print(f"unable to find command, tried {commands}") - return None, None - stdout = 
process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. 
- keywords = {} - try: - with open(versionfile_abs) as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. 
The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r"\d", r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r"\d", r): - continue - if verbose: - print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. 
- """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. 
- branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '{}' doesn't start with prefix '{}'".format( - full_tag, - tag_prefix, - ) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. 
- date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post0.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 
0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post0.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. 
- - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for _ in cfg.versionfile_source.split("/"): - root = os.path.dirname(root) - except NameError: - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None, - } - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } diff --git a/simple_parsing/annotation_utils/get_field_annotations.py b/simple_parsing/annotation_utils/get_field_annotations.py index d35e398e..a45f2d79 100644 --- a/simple_parsing/annotation_utils/get_field_annotations.py +++ b/simple_parsing/annotation_utils/get_field_annotations.py @@ -24,8 +24,7 @@ @contextmanager def _initvar_patcher() -> Iterator[None]: - """ - Patch InitVar to not fail when annotations are postponed. + """Patch InitVar to not fail when annotations are postponed. `TypeVar('Forward references must evaluate to types. 
Got dataclasses.InitVar[tp].')` is raised when postponed annotations are enabled and `get_type_hints` is called @@ -148,7 +147,6 @@ def _get_old_style_annotation(annotation: str) -> str: def _replace_new_union_syntax_with_old_union_syntax( annotations_dict: Dict[str, str], context: collections.ChainMap ) -> Dict[str, Any]: - new_annotations = annotations_dict.copy() for field, annotation_str in annotations_dict.items(): updated_annotation = _get_old_style_annotation(annotation_str) @@ -193,10 +191,11 @@ def get_field_type_from_annotations(some_class: type, field_name: str) -> type: # Get the global_ns in the module starting from the deepest base until the module with the field_name last definition. global_ns = {} - classes_to_iterate = list(dropwhile( - lambda cls: field_name not in getattr(cls, "__annotations__", {}), - some_class.mro() - )) + classes_to_iterate = list( + dropwhile( + lambda cls: field_name not in getattr(cls, "__annotations__", {}), some_class.mro() + ) + ) for base_cls in reversed(classes_to_iterate): global_ns.update(sys.modules[base_cls.__module__].__dict__) diff --git a/simple_parsing/conflicts.py b/simple_parsing/conflicts.py index fb71ecde..9d522fcb 100644 --- a/simple_parsing/conflicts.py +++ b/simple_parsing/conflicts.py @@ -13,7 +13,6 @@ class ConflictResolution(enum.Enum): """Determines prefixing when adding the same dataclass more than once. - - NONE: Disallow using the same dataclass in two different destinations without explicitly setting a distinct prefix for at least one of them. @@ -66,9 +65,8 @@ def __init__(self, conflict_resolution=ConflictResolution.AUTO): def resolve_and_flatten(self, wrappers: list[DataclassWrapper]) -> list[DataclassWrapper]: """Given the list of all dataclass wrappers, find and resolve any conflicts between fields. - Returns the new list of (possibly mutated in-place) dataclass wrappers. - This returned list is flattened, i.e. it contains all the dataclass wrappers and their - children. 
+ Returns the new list of (possibly mutated in-place) dataclass wrappers. This returned list + is flattened, i.e. it contains all the dataclass wrappers and their children. """ from simple_parsing.parsing import _assert_no_duplicates, _flatten_wrappers @@ -377,7 +375,10 @@ def _get_conflicting_group(self, all_wrappers: list[DataclassWrapper]) -> Confli return None def _conflict_exists(self, all_wrappers: list[DataclassWrapper]) -> bool: - """Return True whenever a conflict exists. (option strings overlap).""" + """Return True whenever a conflict exists. + + (option strings overlap). + """ arg_names: set[str] = set() for wrapper in all_wrappers: for field in wrapper.fields: diff --git a/simple_parsing/decorators.py b/simple_parsing/decorators.py index 68580eab..16162c71 100644 --- a/simple_parsing/decorators.py +++ b/simple_parsing/decorators.py @@ -8,7 +8,9 @@ from typing import Any, Callable, NamedTuple import docstring_parser as dp + from simple_parsing.docstring import dp_parse, inspect_getdoc + from . import helpers, parsing diff --git a/simple_parsing/docstring.py b/simple_parsing/docstring.py index 17d52d0a..4cc9e129 100644 --- a/simple_parsing/docstring.py +++ b/simple_parsing/docstring.py @@ -1,4 +1,5 @@ -"""Utility for retrieveing the docstring of a dataclass's attributes +"""Utility for retrieveing the docstring of a dataclass's attributes. + @author: Fabrice Normandin """ from __future__ import annotations @@ -28,11 +29,12 @@ class AttributeDocString: docstring_below: str = "" desc_from_cls_docstring: str = "" - """ The description of this field from the class docstring. 
""" + """The description of this field from the class docstring.""" @property def help_string(self) -> str: - """Returns the value that will be used for the "--help" string, using the contents of self.""" + """Returns the value that will be used for the "--help" string, using the contents of + self.""" return ( self.docstring_below or self.comment_above @@ -104,6 +106,7 @@ def get_attribute_docstring( @functools.lru_cache(2048) def _get_attribute_docstring(dataclass: type, field_name: str) -> AttributeDocString | None: """Gets the AttributeDocString of the given field in the given dataclass. + Doesn't inspect base classes. """ try: diff --git a/simple_parsing/help_formatter.py b/simple_parsing/help_formatter.py index 67727138..11d717e7 100644 --- a/simple_parsing/help_formatter.py +++ b/simple_parsing/help_formatter.py @@ -38,9 +38,9 @@ def _format_args(self, action: Action, default_metavar: str): elif action.nargs == OPTIONAL: result = "[%s]" % _get_metavar(1) elif action.nargs == ZERO_OR_MORE: - result = "[%s [%s ...]]" % _get_metavar(2) + result = "[%s [%s ...]]" % _get_metavar(2) # noqa: UP031 elif action.nargs == ONE_OR_MORE: - result = "%s [%s ...]" % _get_metavar(2) + result = "%s [%s ...]" % _get_metavar(2) # noqa: UP031 elif action.nargs == REMAINDER: result = "..." elif action.nargs == PARSER: diff --git a/simple_parsing/helpers/__init__.py b/simple_parsing/helpers/__init__.py index a9717151..b0e635bd 100644 --- a/simple_parsing/helpers/__init__.py +++ b/simple_parsing/helpers/__init__.py @@ -1,4 +1,4 @@ -""" Collection of helper classes and functions to reduce boilerplate code. 
""" +"""Collection of helper classes and functions to reduce boilerplate code.""" from .fields import * from .flatten import FlattenedAccess from .hparams import HyperParameters diff --git a/simple_parsing/helpers/fields.py b/simple_parsing/helpers/fields.py index 1a0ab9ee..a2008084 100644 --- a/simple_parsing/helpers/fields.py +++ b/simple_parsing/helpers/fields.py @@ -1,5 +1,4 @@ -""" Utility functions that simplify defining field of dataclasses. -""" +"""Utility functions that simplify defining field of dataclasses.""" from __future__ import annotations import dataclasses @@ -130,7 +129,6 @@ def field( "to True when using the store_false action." ) default = True # type: ignore - elif action == "store_true": if default not in {MISSING, False}: raise RuntimeError( @@ -138,7 +136,6 @@ def field( "to False when using the store_true action." ) default = False # type: ignore - if default is not MISSING: return dataclasses.field( # type: ignore default=default, @@ -190,8 +187,7 @@ def choice( def choice(*choices, default=MISSING, **kwargs): - """Makes a field which can be chosen from the set of choices from the - command-line. + """Makes a field which can be chosen from the set of choices from the command-line. Returns a regular `dataclasses.field()`, but with metadata which indicates the allowed values. @@ -245,9 +241,8 @@ def choice(*choices, default=MISSING, **kwargs): # TODO: If the choice dict is given, then add encoding/decoding functions that just # get/set the right key. def _encoding_fn(value: Any) -> str: - """Custom encoding function that will simply represent the value as the - the key in the dict rather than the value itself. 
- """ + """Custom encoding function that will simply represent the value as the the key in + the dict rather than the value itself.""" if value in choice_dict.keys(): return value elif value in choice_dict.values(): @@ -257,9 +252,8 @@ def _encoding_fn(value: Any) -> str: kwargs.setdefault("encoding_fn", _encoding_fn) def _decoding_fn(value: Any) -> Any: - """Custom decoding function that will retrieve the value from the - stored key in the dictionary. - """ + """Custom decoding function that will retrieve the value from the stored key in the + dictionary.""" return choice_dict.get(value, value) kwargs.setdefault("decoding_fn", _decoding_fn) @@ -268,8 +262,8 @@ def _decoding_fn(value: Any) -> Any: def list_field(*default_items: T, **kwargs) -> list[T]: - """shorthand function for setting a `list` attribute on a dataclass, - so that every instance of the dataclass doesn't share the same list. + """shorthand function for setting a `list` attribute on a dataclass, so that every instance of + the dataclass doesn't share the same list. Accepts any of the arguments of the `dataclasses.field` function. @@ -291,8 +285,8 @@ def list_field(*default_items: T, **kwargs) -> list[T]: def dict_field(default_items: dict[K, V] | Iterable[tuple[K, V]] = (), **kwargs) -> dict[K, V]: - """shorthand function for setting a `dict` attribute on a dataclass, - so that every instance of the dataclass doesn't share the same `dict`. + """shorthand function for setting a `dict` attribute on a dataclass, so that every instance of + the dataclass doesn't share the same `dict`. NOTE: Do not use keyword arguments as you usually would with a dictionary (as in something like `dict_field(a=1, b=2, c=3)`). Instead pass in a @@ -325,7 +319,9 @@ def mutable_field( *fn_args: P.args, **fn_kwargs: P.kwargs, ) -> T: - """Shorthand for `dataclasses.field(default_factory=functools.partial(fn, *fn_args, **fn_kwargs))`. 
+ """Shorthand for `dataclasses.field(default_factory=functools.partial(fn, *fn_args, + + **fn_kwargs))`. NOTE: The *fn_args and **fn_kwargs here are passed to `fn`, and are never used by the argparse Action! diff --git a/simple_parsing/helpers/flatten.py b/simple_parsing/helpers/flatten.py index 8b556072..7eacea61 100644 --- a/simple_parsing/helpers/flatten.py +++ b/simple_parsing/helpers/flatten.py @@ -97,8 +97,8 @@ def __getattr__(self, name: str): def __setattr__(self, name: str, value: Any): """Write the attribute in self or in the children that has it. - If more than one child has attributes that match the given one, an - `AttributeError` is raised. + If more than one child has attributes that match the given one, an `AttributeError` is + raised. """ # potential parents and corresponding values. parents: List[str] = [] diff --git a/simple_parsing/helpers/hparams/hyperparameters.py b/simple_parsing/helpers/hparams/hyperparameters.py index 47a3aa9e..2574449b 100644 --- a/simple_parsing/helpers/hparams/hyperparameters.py +++ b/simple_parsing/helpers/hparams/hyperparameters.py @@ -1,16 +1,17 @@ from __future__ import annotations + import dataclasses import inspect import math import pickle import random +import typing from collections import OrderedDict from dataclasses import Field, dataclass, fields from functools import singledispatch, total_ordering from logging import getLogger from pathlib import Path -from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Type, TypeVar -import typing +from typing import Any, ClassVar, NamedTuple, TypeVar from simple_parsing import utils from simple_parsing.helpers.serialization.serializable import Serializable @@ -39,7 +40,7 @@ class BoundInfo(Serializable): name: str # One of 'continuous', 'discrete' or 'bandit' (unsupported). 
type: str = "continuous" - domain: Tuple[float, float] = (-math.inf, math.inf) + domain: tuple[float, float] = (-math.inf, math.inf) @dataclass @@ -69,53 +70,53 @@ def __post_init__(self): setattr(self, name, new_value) @classmethod - def field_names(cls) -> List[str]: + def field_names(cls) -> list[str]: return [f.name for f in fields(cls)] def id(self): return compute_identity(**self.to_dict()) - def seed(self, seed: Optional[int]) -> None: + def seed(self, seed: int | None) -> None: """TODO: Seed all priors with the given seed. (recursively if nested dataclasses are present.) """ raise NotImplementedError("TODO") @classmethod - def get_priors(cls) -> Dict[str, Prior]: + def get_priors(cls) -> dict[str, Prior]: """Returns a dictionary of the Priors for the hparam fields in this class.""" - priors_dict: Dict[str, Prior] = {} + priors_dict: dict[str, Prior] = {} for field in fields(cls): # If a HyperParameters class contains another HyperParameters class as a field # we perform returned a flattened dict. if inspect.isclass(field.type) and issubclass(field.type, HyperParameters): priors_dict[field.name] = field.type.get_priors() else: - prior: Optional[Prior] = field.metadata.get("prior") + prior: Prior | None = field.metadata.get("prior") if prior: priors_dict[field.name] = prior return priors_dict @classmethod - def get_orion_space_dict(cls) -> Dict[str, str]: - result: Dict[str, str] = {} + def get_orion_space_dict(cls) -> dict[str, str]: + result: dict[str, str] = {} for field in fields(cls): # If a HyperParameters class contains another HyperParameters class as a field # we perform returned a flattened dict. 
if inspect.isclass(field.type) and issubclass(field.type, HyperParameters): result[field.name] = field.type.get_orion_space_dict() else: - prior: Optional[Prior] = field.metadata.get("prior") + prior: Prior | None = field.metadata.get("prior") if prior: result[field.name] = prior.get_orion_space_string() return result - def get_orion_space(self) -> Dict[str, str]: + def get_orion_space(self) -> dict[str, str]: """NOTE: This might be more useful in some cases than the above classmethod version, for example when a field is a different kind of dataclass than its annotation. """ - result: Dict[str, str] = {} + result: dict[str, str] = {} for field in fields(self): value = getattr(self, field.name) # If a HyperParameters class contains another HyperParameters class as a field @@ -123,7 +124,7 @@ def get_orion_space(self) -> Dict[str, str]: if isinstance(value, HyperParameters): result[field.name] = value.get_orion_space() else: - prior: Optional[Prior] = field.metadata.get("prior") + prior: Prior | None = field.metadata.get("prior") if prior: result[field.name] = prior.get_orion_space_string() return result @@ -133,12 +134,12 @@ def space_id(cls): return compute_identity(**cls.get_orion_space_dict()) @classmethod - def get_bounds(cls) -> List[BoundInfo]: + def get_bounds(cls) -> list[BoundInfo]: """Returns the bounds of the search domain for this type of HParam. Returns them as a list of `BoundInfo` objects, in the format expected by GPyOpt. """ - bounds: List[BoundInfo] = [] + bounds: list[BoundInfo] = [] for f in fields(cls): # TODO: handle a hparam which is categorical (i.e. 
choices) min_v = f.metadata.get("min") @@ -155,14 +156,14 @@ def get_bounds(cls) -> List[BoundInfo]: return bounds @classmethod - def get_bounds_dicts(cls) -> List[Dict[str, Any]]: + def get_bounds_dicts(cls) -> list[dict[str, Any]]: """Returns the bounds of the search space for this type of HParam, in the format expected by the `GPyOpt` package.""" return [b.to_dict() for b in cls.get_bounds()] @classmethod def sample(cls): - kwargs: Dict[str, Any] = {} + kwargs: dict[str, Any] = {} for field in dataclasses.fields(cls): if inspect.isclass(field.type) and issubclass(field.type, HyperParameters): # TODO: Should we allow adding a 'prior' in terms of a dataclass field? @@ -177,7 +178,7 @@ def sample(cls): value = chosen_class.sample() kwargs[field.name] = value else: - prior: Optional[Prior] = field.metadata.get("prior") + prior: Prior | None = field.metadata.get("prior") if prior is not None: try: import numpy as np @@ -209,7 +210,7 @@ def to_array(self, dtype: numpy.dtype | None = None) -> numpy.ndarray: import numpy as np dtype = np.float32 if dtype is None else dtype - values: List[float] = [] + values: list[float] = [] for k, v in self.to_dict(dict_factory=OrderedDict).items(): try: v = float(v) @@ -220,7 +221,7 @@ def to_array(self, dtype: numpy.dtype | None = None) -> numpy.ndarray: return np.array(values, dtype=dtype) @classmethod - def from_array(cls: Type[HP], array: numpy.ndarray) -> HP: + def from_array(cls: type[HP], array: numpy.ndarray) -> HP: import numpy as np if len(array.shape) == 2 and array.shape[0] == 1: @@ -287,7 +288,7 @@ def __eq__(self, other: object): hps_equal = hp_id == other_id return hps_equal and self.perf == other[1] - def __gt__(self, other: Tuple[object, ...]) -> bool: + def __gt__(self, other: tuple[object, ...]) -> bool: # Even though the tuple has (hp, perf), compare based on the order # (perf, hp). # This means that sorting a list of Points will work as expected! 
diff --git a/simple_parsing/helpers/hparams/hyperparameters_test.py b/simple_parsing/helpers/hparams/hyperparameters_test.py index 06d7c9ca..d675ac54 100644 --- a/simple_parsing/helpers/hparams/hyperparameters_test.py +++ b/simple_parsing/helpers/hparams/hyperparameters_test.py @@ -48,9 +48,8 @@ class C(HyperParameters): def test_clip_within_bounds(): - """Test to make sure that the `clip_within_bounds` actually restricts the - values of the HyperParameters to be within the bounds. - """ + """Test to make sure that the `clip_within_bounds` actually restricts the values of the + HyperParameters to be within the bounds.""" # valid range for learning_rate is (0 - 1]. a = A(learning_rate=123) assert a.learning_rate == 123 @@ -87,9 +86,8 @@ class C(HyperParameters): def test_strict_bounds(): - """When creating a class and using a hparam field with `strict=True`, the values - will be restricted to be within the given bounds. - """ + """When creating a class and using a hparam field with `strict=True`, the values will be + restricted to be within the given bounds.""" @dataclass class C(HyperParameters): @@ -147,7 +145,9 @@ class Child(HyperParameters): bob = Child.sample() assert bob.hparam in {1.23, 4.56, 7.89} - assert Child.get_orion_space_dict() == {"hparam": "choices(['a', 'b', 'c'], default_value='a')"} + assert Child.get_orion_space_dict() == { + "hparam": "choices(['a', 'b', 'c'], default_value='a')" + } def test_choice_field_with_values_of_a_weird_type(): @@ -213,7 +213,6 @@ class Foo(HyperParameters): def test_priors_with_shape(): - foo = Foo() assert foo.x == (5, 5) assert foo.y == (5, 5, 5) diff --git a/simple_parsing/helpers/hparams/priors.py b/simple_parsing/helpers/hparams/priors.py index 8817a2d7..eeb025aa 100644 --- a/simple_parsing/helpers/hparams/priors.py +++ b/simple_parsing/helpers/hparams/priors.py @@ -1,3 +1,4 @@ +import importlib.util import math import random from abc import abstractmethod @@ -14,14 +15,15 @@ overload, ) -import importlib.util 
class _np_lazy: def __getattr__(self, attr): global np import numpy as np + return getattr(np, attr) + np = _np_lazy() numpy_installed = importlib.util.find_spec("numpy") is not None diff --git a/simple_parsing/helpers/partial.py b/simple_parsing/helpers/partial.py index c6b2bffb..64b9f2c9 100644 --- a/simple_parsing/helpers/partial.py +++ b/simple_parsing/helpers/partial.py @@ -1,4 +1,4 @@ -""" A Partial helper that can be used to add arguments for an arbitrary class or callable. """ +"""A Partial helper that can be used to add arguments for an arbitrary class or callable.""" from __future__ import annotations import dataclasses @@ -192,7 +192,9 @@ def config_for( logger.debug(f"Adding optional field: {fields[-1]}") cls_name = _get_generated_config_class_name(cls) - config_class = make_dataclass(cls_name=cls_name, bases=(Partial,), fields=fields, frozen=frozen) + config_class = make_dataclass( + cls_name=cls_name, bases=(Partial,), fields=fields, frozen=frozen + ) config_class._target_ = cls config_class.__doc__ = ( f"Auto-Generated configuration dataclass for {cls.__module__}.{cls.__qualname__}\n" @@ -205,8 +207,7 @@ def config_for( @singledispatch def infer_type_annotation_from_default(default: Any) -> Any | type: """Used when there is a default value, but no type annotation, to infer the type of field to - create on the config dataclass. - """ + create on the config dataclass.""" if isinstance(default, (int, str, float, bool)): return type(default) if isinstance(default, tuple): @@ -281,9 +282,8 @@ def __getitem__(cls, target: Callable[_P, _T]) -> type[Callable[_P, _T]]: def __getattr__(name: str): - """ - Getting an attribute on this module here will check for the autogenerated config class with that name. 
- """ + """Getting an attribute on this module here will check for the autogenerated config class with + that name.""" if name in globals(): return globals()[name] diff --git a/simple_parsing/helpers/serialization/decoding.py b/simple_parsing/helpers/serialization/decoding.py index 2d3aab73..4e4a5fed 100644 --- a/simple_parsing/helpers/serialization/decoding.py +++ b/simple_parsing/helpers/serialization/decoding.py @@ -1,5 +1,4 @@ -""" Functions for decoding dataclass fields from "raw" values (e.g. from json). -""" +"""Functions for decoding dataclass fields from "raw" values (e.g. from json).""" from __future__ import annotations import inspect @@ -201,7 +200,6 @@ def get_decoding_fn(type_annotation: type[T] | str) -> Callable[..., T]: Returns: Callable[[Any], T]: A function that decodes a 'raw' value to an instance of type `t`. - """ from .serializable import from_dict @@ -454,8 +452,7 @@ def _decode_dict(val: dict[Any, Any] | list[tuple[Any, Any]]) -> dict[K, V]: def decode_enum(item_type: type[Enum]) -> Callable[[str], Enum]: - """ - Creates a decoding function for an enum type. + """Creates a decoding function for an enum type. Args: item_type (Type[Enum]): the type of the items in the set. 
diff --git a/simple_parsing/helpers/serialization/serializable.py b/simple_parsing/helpers/serialization/serializable.py index 64487c92..bc845ad5 100644 --- a/simple_parsing/helpers/serialization/serializable.py +++ b/simple_parsing/helpers/serialization/serializable.py @@ -461,7 +461,8 @@ class SimpleSerializable(SerializableMixin, decode_into_subclasses=True): def get_serializable_dataclass_types_from_forward_ref( forward_ref: type, serializable_base_class: type[S] = SerializableMixin ) -> list[type[S]]: - """Gets all the subclasses of `serializable_base_class` that have the same name as the argument of this forward reference annotation.""" + """Gets all the subclasses of `serializable_base_class` that have the same name as the argument + of this forward reference annotation.""" arg = get_forward_arg(forward_ref) potential_classes: list[type] = [] for serializable_class in serializable_base_class.subclasses: @@ -526,7 +527,7 @@ def load( # Load a dict from the file. d = read_file(path) elif load_fn: - with (path.open() if isinstance(path, Path) else path) as f: + with path.open() if isinstance(path, Path) else path as f: d = load_fn(f) else: raise ValueError( @@ -928,13 +929,11 @@ def is_dataclass_or_optional_dataclass_type(t: type) -> bool: def _locate(path: str) -> Any: - """ - COPIED FROM Hydra: - https://github.com/facebookresearch/hydra/blob/f8940600d0ab5c695961ad83abd042ffe9458caf/hydra/_internal/utils.py#L614 + """COPIED FROM Hydra: https://github.com/facebookresearch/hydra/blob/f8940600d0ab5c695961ad83ab + d042ffe9458caf/hydra/_internal/utils.py#L614. - Locate an object by name or dotted path, importing as necessary. - This is similar to the pydoc function `locate`, except that it checks for - the module from the given path from back to front. + Locate an object by name or dotted path, importing as necessary. This is similar to the pydoc + function `locate`, except that it checks for the module from the given path from back to front. 
""" if path == "": raise ImportError("Empty path") diff --git a/simple_parsing/helpers/serialization/yaml_serialization.py b/simple_parsing/helpers/serialization/yaml_serialization.py index 7091c5e4..13aa034c 100644 --- a/simple_parsing/helpers/serialization/yaml_serialization.py +++ b/simple_parsing/helpers/serialization/yaml_serialization.py @@ -4,7 +4,10 @@ from pathlib import Path from typing import IO -import yaml +try: + import yaml +except ImportError: + pass from .serializable import D, Serializable @@ -12,18 +15,22 @@ class YamlSerializable(Serializable): - """Convenience class, just sets different `load_fn` and `dump_fn` defaults - for the `dump`, `dumps`, `load`, `loads` methods of `Serializable`. + """Convenience class, just sets different `load_fn` and `dump_fn` defaults for the `dump`, + `dumps`, `load`, `loads` methods of `Serializable`. Uses the `yaml.safe_load` and `yaml.dump` for loading and dumping. Requires the pyyaml package. """ - def dump(self, fp: IO[str], dump_fn=yaml.dump, **kwargs) -> None: + def dump(self, fp: IO[str], dump_fn=None, **kwargs) -> None: + if dump_fn is None: + dump_fn = yaml.dump dump_fn(self.to_dict(), fp, **kwargs) - def dumps(self, dump_fn=yaml.dump, **kwargs) -> str: + def dumps(self, dump_fn=None, **kwargs) -> str: + if dump_fn is None: + dump_fn = yaml.dump return dump_fn(self.to_dict(), **kwargs) @classmethod @@ -31,9 +38,12 @@ def load( cls: type[D], path: Path | str | IO[str], drop_extra_fields: bool | None = None, - load_fn=yaml.safe_load, + load_fn=None, **kwargs, ) -> D: + if load_fn is None: + load_fn = yaml.safe_load + return super().load(path, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs) @classmethod @@ -41,9 +51,11 @@ def loads( cls: type[D], s: str, drop_extra_fields: bool | None = None, - load_fn=yaml.safe_load, + load_fn=None, **kwargs, ) -> D: + if load_fn is None: + load_fn = yaml.safe_load return super().loads(s, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs) 
 @classmethod
@@ -51,7 +63,9 @@ def _load(
     cls: type[D],
     fp: IO[str],
     drop_extra_fields: bool | None = None,
-    load_fn=yaml.safe_load,
+    load_fn=None,
     **kwargs,
 ) -> D:
+    if load_fn is None:
+        load_fn = yaml.safe_load
     return super()._load(fp, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs)
diff --git a/simple_parsing/helpers/subgroups.py b/simple_parsing/helpers/subgroups.py
index d9487998..8b16794c 100644
--- a/simple_parsing/helpers/subgroups.py
+++ b/simple_parsing/helpers/subgroups.py
@@ -194,7 +194,14 @@ def subgroups(
     from .fields import choice
 
-    return choice(choices, *args, default=default, default_factory=default_factory, metadata=metadata, **kwargs)  # type: ignore
+    return choice(
+        choices,
+        *args,
+        default=default,
+        default_factory=default_factory,
+        metadata=metadata,
+        **kwargs,
+    )  # type: ignore
 
 
 def _get_dataclass_type_from_callable(
@@ -226,7 +233,6 @@ def _get_dataclass_type_from_callable(
     # Recurse, so this also works with partial(partial(...)) (idk why you'd do that though.)
     if isinstance(signature.return_annotation, str):
-        dataclass_fn_type = signature.return_annotation
         if caller_frame is not None:
             # Travel up until we find the right frame where the subgroup is defined.
@@ -265,7 +271,7 @@ def _get_dataclass_type_from_callable(
 def is_lambda(obj: Any) -> bool:
     """Returns True if the given object is a lambda expression.
 
-    Taken from https://stackoverflow.com/questions/3655842/how-can-i-test-whether-a-variable-holds-a-lambda
+    Taken from https://stackoverflow.com/questions/3655842/how-can-i-test-whether-a-variable-holds-a-lambda
     """
     LAMBDA = lambda: 0  # noqa: E731
     return isinstance(obj, type(LAMBDA)) and obj.__name__ == LAMBDA.__name__
diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py
index ec81bdd1..ab1aa524 100644
--- a/simple_parsing/parsing.py
+++ b/simple_parsing/parsing.py
@@ -1,4 +1,5 @@
 """Simple, Elegant Argument parsing.
+ @author: Fabrice Normandin """ from __future__ import annotations @@ -98,7 +99,6 @@ class ArgumentParser(argparse.ArgumentParser): - add_config_path_arg : bool, optional When set to `True`, adds a `--config_path` argument, of type Path, which is used to parse - """ def __init__( @@ -761,6 +761,7 @@ def _resolve_subgroups( def _remove_subgroups_from_namespace(self, parsed_args: argparse.Namespace) -> None: """Removes the subgroup choice results from the namespace. + Modifies the namespace in-place. """ # find all subgroup fields @@ -979,11 +980,12 @@ def _fill_constructor_arguments_with_fields( @property def confilct_resolver_max_attempts(self) -> int: return self._conflict_resolver.max_attempts - + @confilct_resolver_max_attempts.setter def confilct_resolver_max_attempts(self, value: int): self._conflict_resolver.max_attempts = value + # TODO: Change the order of arguments to put `args` as the second argument. def parse( config_class: type[DataclassT], @@ -1068,7 +1070,9 @@ def parse_known_args( add_config_path_arg=add_config_path_arg, ) parser.add_arguments(config_class, dest=dest, default=default) - parsed_args, unknown_args = parser.parse_known_args(args, attempt_to_reorder=attempt_to_reorder) + parsed_args, unknown_args = parser.parse_known_args( + args, attempt_to_reorder=attempt_to_reorder + ) config: Dataclass = getattr(parsed_args, dest) return config, unknown_args @@ -1118,7 +1122,6 @@ def _create_dataclass_instance( constructor: Callable[..., DataclassT], constructor_args: dict[str, Any], ) -> DataclassT | None: - # Check if the dataclass annotation is marked as Optional. # In this case, if no arguments were passed, and the default value is None, then return # None. @@ -1126,7 +1129,6 @@ def _create_dataclass_instance( # command-line from the case where no arguments are passed at all! 
if wrapper.optional and wrapper.default is None: for field_wrapper in wrapper.fields: - arg_value = constructor_args[field_wrapper.name] default_value = field_wrapper.default logger.debug( diff --git a/simple_parsing/replace.py b/simple_parsing/replace.py index 3735a336..db350fba 100644 --- a/simple_parsing/replace.py +++ b/simple_parsing/replace.py @@ -112,9 +112,8 @@ def replace(obj: DataclassT, changes_dict: dict[str, Any] | None = None, **chang def replace_subgroups( obj: DataclassT, selections: dict[str, Key | DataclassT] | None = None ) -> DataclassT: - """ - This function replaces the dataclass of subgroups, union, and optional union. - The `selections` dict can be in flat format or in nested format. + """This function replaces the dataclass of subgroups, union, and optional union. The + `selections` dict can be in flat format or in nested format. The values of selections can be `Key` of subgroups, dataclass type, and dataclass instance. """ @@ -184,9 +183,8 @@ def replace_subgroups( def _unflatten_selection_dict( flattened: Mapping[str, V], keyword: str = "__key__", sep: str = ".", recursive: bool = True ) -> PossiblyNestedDict[str, V]: - """ - This function convert a flattened dict into a nested dict - and it inserts the `keyword` as the selection into the nested dict. + """This function convert a flattened dict into a nested dict and it inserts the `keyword` as + the selection into the nested dict. 
>>> _unflatten_selection_dict({'ab_or_cd': 'cd', 'ab_or_cd.c_or_d': 'd'}) {'ab_or_cd': {'__key__': 'cd', 'c_or_d': 'd'}} diff --git a/simple_parsing/utils.py b/simple_parsing/utils.py index 24a97f16..752fbbcb 100644 --- a/simple_parsing/utils.py +++ b/simple_parsing/utils.py @@ -12,9 +12,8 @@ import sys import types import typing -from collections import OrderedDict +from collections import OrderedDict, defaultdict from collections import abc as c_abc -from collections import defaultdict from dataclasses import _MISSING_TYPE, MISSING, Field from enum import Enum from logging import getLogger @@ -108,9 +107,8 @@ def is_subparser_field(field: Field) -> bool: class InconsistentArgumentError(RuntimeError): - """ - Error raised when the number of arguments provided is inconsistent when parsing multiple instances from command line. - """ + """Error raised when the number of arguments provided is inconsistent when parsing multiple + instances from command line.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -126,9 +124,8 @@ def camel_case(name): def str2bool(raw_value: str | bool) -> bool: - """ - Taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse - """ + """Taken from https://stackoverflow.com/questions/15008758/parsing-boolean-values-with- + argparse.""" if isinstance(raw_value, bool): return raw_value v = raw_value.strip().lower() @@ -208,9 +205,9 @@ def get_item_type(container_type: type[Container[T]]) -> T: def get_argparse_type_for_container( container_type: type[Container[T]], ) -> type[T] | Callable[[str], T]: - """Gets the argparse 'type' option to be used for a given container type. - When an annotation is present, the 'type' option of argparse is set to that type. - if not, then the default value of 'str' is returned. + """Gets the argparse 'type' option to be used for a given container type. When an annotation is + present, the 'type' option of argparse is set to that type. 
if not, then the default value of + 'str' is returned. Arguments: container_type {Type} -- A container type (ideally a typing.Type such as List, Tuple, along with an item annotation: List[str], Tuple[int, int], etc.) @@ -413,7 +410,9 @@ def is_dataclass_type_or_typevar(t: type) -> bool: Returns: bool: Whether its a dataclass type. """ - return dataclasses.is_dataclass(t) or (is_typevar(t) and dataclasses.is_dataclass(get_bound(t))) + return dataclasses.is_dataclass(t) or ( + is_typevar(t) and dataclasses.is_dataclass(get_bound(t)) + ) def is_enum(t: type) -> bool: @@ -431,7 +430,7 @@ def is_tuple_or_list(t: type) -> bool: def is_union(t: type) -> bool: - """Returns whether or not the given Type annotation is a variant (or subclass) of typing.Union + """Returns whether or not the given Type annotation is a variant (or subclass) of typing.Union. Args: t (Type): some type annotation @@ -453,8 +452,7 @@ def is_union(t: type) -> bool: def is_homogeneous_tuple_type(t: type[tuple]) -> bool: - """Returns whether the given Tuple type is homogeneous: if all items types are the - same. + """Returns whether the given Tuple type is homogeneous: if all items types are the same. This also includes Tuple[, ...] @@ -658,6 +656,7 @@ def _parse(value: str) -> list[Any]: def _parse_literal(value: str) -> list[Any] | Any: """try to parse the string to a python expression directly. + (useful for nested lists or tuples.) """ literal = ast.literal_eval(value) @@ -723,8 +722,8 @@ def get_nesting_level(possibly_nested_list): def default_value(field: dataclasses.Field) -> T | _MISSING_TYPE: - """Returns the default value of a field in a dataclass, if available. - When not available, returns `dataclasses.MISSING`. + """Returns the default value of a field in a dataclass, if available. When not available, + returns `dataclasses.MISSING`. Args: field (dataclasses.Field): The dataclasses.Field to get the default value of. 
@@ -781,7 +780,6 @@ def keep_keys(d: dict, keys_to_keep: Iterable[str]) -> tuple[dict, dict]: Tuple[Dict, Dict] The same dictionary (with all the unwanted keys removed) as well as a new dict containing only the removed item. - """ d_keys = set(d.keys()) # save a copy since we will modify the dict. removed = {} @@ -792,7 +790,7 @@ def keep_keys(d: dict, keys_to_keep: Iterable[str]) -> tuple[dict, dict]: def compute_identity(size: int = 16, **sample) -> str: - """Compute a unique hash out of a dictionary + """Compute a unique hash out of a dictionary. Parameters ---------- @@ -801,7 +799,6 @@ def compute_identity(size: int = 16, **sample) -> str: **sample: Dictionary to compute the hash from - """ sample_hash = hashlib.sha256() @@ -840,7 +837,7 @@ def zip_dicts(*dicts: dict[K, V]) -> Iterable[tuple[K, tuple[V | None, ...]]]: def dict_union(*dicts: dict[K, V], recurse: bool = True, dict_factory=dict) -> dict[K, V]: - """Simple dict union until we use python 3.9 + """Simple dict union until we use python 3.9. If `recurse` is True, also does the union of nested dictionaries. NOTE: The returned dictionary has keys sorted alphabetically. @@ -924,7 +921,8 @@ def unflatten(flattened: Mapping[tuple[K, ...], V]) -> PossiblyNestedDict[K, V]: def flatten_join(nested: PossiblyNestedMapping[str, V], sep: str = ".") -> dict[str, V]: - """Flatten a dictionary of dictionaries. Joins different nesting levels with `sep` as separator. + """Flatten a dictionary of dictionaries. Joins different nesting levels with `sep` as + separator. 
>>> flatten_join({'a': {'b': 2, 'c': 3}, 'c': {'d': 3, 'e': 4}}) {'a.b': 2, 'a.c': 3, 'c.d': 3, 'c.e': 4} diff --git a/simple_parsing/wrappers/dataclass_wrapper.py b/simple_parsing/wrappers/dataclass_wrapper.py index cbcd43f1..2efa98af 100644 --- a/simple_parsing/wrappers/dataclass_wrapper.py +++ b/simple_parsing/wrappers/dataclass_wrapper.py @@ -3,8 +3,6 @@ import argparse import dataclasses import functools -import inspect -import sys import textwrap from dataclasses import MISSING from logging import getLogger @@ -21,10 +19,11 @@ logger = getLogger(__name__) MAX_DOCSTRING_DESC_LINES_HEIGHT: int = 50 -""" -Maximum number of lines of the class docstring to include in the autogenerated argument group -description. If fields don't have docstrings or help text, then this is not used, and the entire -docstring is used as the description of the argument group. +"""Maximum number of lines of the class docstring to include in the autogenerated argument group +description. + +If fields don't have docstrings or help text, then this is not used, and the entire docstring is +used as the description of the argument group. """ DataclassWrapperType = TypeVar("DataclassWrapperType", bound="DataclassWrapper") @@ -311,7 +310,9 @@ def set_default(self, value: DataclassT | dict | None): unknown_names.remove(nested_dataclass_wrapper.name) unknown_names.discard("_type_") if unknown_names: - raise RuntimeError(f"{sorted(unknown_names)} are not fields of {self.dataclass} at path {self.dest!r}!") + raise RuntimeError( + f"{sorted(unknown_names)} are not fields of {self.dataclass} at path {self.dest!r}!" + ) @property def title(self) -> str: @@ -420,6 +421,7 @@ def destinations(self, value: list[str]): def merge(self, other: DataclassWrapper): """Absorb all the relevant attributes from another wrapper. + Args: other (DataclassWrapper): Another instance to absorb into this one. 
""" @@ -447,10 +449,6 @@ def _get_dataclass_fields(dataclass: type[Dataclass]) -> tuple[dataclasses.Field # NOTE: `dataclasses.fields` method retrieves only `dataclasses._FIELD` # NOTE: but we also want to know about `dataclasses._FIELD_INITVAR` # NOTE: therefore we partly copy-paste its implementation - if sys.version_info[:2] < (3, 8): - # Before 3.8 `InitVar[tp] is InitVar` so it's impossible to retrieve field type - # therefore we should skip it just to be fully backward compatible - return dataclasses.fields(dataclass) try: dataclass_fields_map = getattr(dataclass, dataclasses._FIELDS) except AttributeError: diff --git a/simple_parsing/wrappers/field_parsing.py b/simple_parsing/wrappers/field_parsing.py index 4d0b08aa..6ab74220 100644 --- a/simple_parsing/wrappers/field_parsing.py +++ b/simple_parsing/wrappers/field_parsing.py @@ -1,7 +1,6 @@ """Functions that are to be used to parse a field. -Somewhat analogous to the 'parse' function in the -`helpers.serialization.parsing` package. +Somewhat analogous to the 'parse' function in the `helpers.serialization.parsing` package. """ import enum import functools diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 0becac2b..3a4d1860 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -29,31 +29,27 @@ class ArgumentGenerationMode(Enum): - """ - Enum for argument generation modes. - """ + """Enum for argument generation modes.""" FLAT = auto() - """ Tries to generate flat arguments, removing the argument destination path when possible. """ + """Tries to generate flat arguments, removing the argument destination path when possible.""" NESTED = auto() - """ Generates arguments with their full destination path. """ + """Generates arguments with their full destination path.""" BOTH = auto() - """ Generates both the flat and nested arguments. 
""" + """Generates both the flat and nested arguments.""" class NestedMode(Enum): - """ - Controls how nested arguments are generated. - """ + """Controls how nested arguments are generated.""" DEFAULT = auto() - """ By default, the full destination path is used. """ + """By default, the full destination path is used.""" WITHOUT_ROOT = auto() - """ - The full destination path is used, but the first level is removed. + """The full destination path is used, but the first level is removed. + Useful because sometimes the first level is uninformative (i.e. 'args'). """ @@ -69,7 +65,6 @@ class DashVariant(Enum): - UNDERSCORE_AND_DASH: - DASH: - """ AUTO = False @@ -79,12 +74,10 @@ class DashVariant(Enum): class FieldWrapper(Wrapper): - """ - The FieldWrapper class acts a bit like an 'argparse.Action' class, which - essentially just creates the `option_strings` and `arg_options` that get - passed to the `add_argument(*option_strings, **arg_options)` function of the - `argparse._ArgumentGroup` (in this case represented by the `parent` - attribute, an instance of the class `DataclassWrapper`). + """The FieldWrapper class acts a bit like an 'argparse.Action' class, which essentially just + creates the `option_strings` and `arg_options` that get passed to the + `add_argument(*option_strings, **arg_options)` function of the `argparse._ArgumentGroup` (in + this case represented by the `parent` attribute, an instance of the class `DataclassWrapper`). The `option_strings`, `required`, `help`, `metavar`, `default`, etc. attributes just autogenerate the argument of the same name of the @@ -179,9 +172,8 @@ def __call__( constructor_arguments: dict[str, dict[str, Any]], option_string: str | None = None, ): - """Immitates a custom Action, which sets the corresponding value from - `values` at the right destination in the `constructor_arguments` of the - parser. 
+ """Immitates a custom Action, which sets the corresponding value from `values` at the right + destination in the `constructor_arguments` of the parser. TODO: Doesn't seem currently possible to check whether the argument was passed in the first place. @@ -206,7 +198,6 @@ def __call__( self._results = {} for destination, value in zip(self.destinations, values): - if self.is_subgroup: logger.debug(f"Ignoring the FieldWrapper for subgroup at dest {self.dest}") return @@ -288,7 +279,9 @@ def get_arg_options(self) -> dict[str, Any]: # Union[, NoneType] assert type_arguments non_none_types = [ - t for t in type_arguments if t is not type(None) # noqa: E721 + t + for t in type_arguments + if t is not type(None) # noqa: E721 ] # noqa: E721 assert non_none_types if len(non_none_types) == 1: @@ -464,8 +457,8 @@ def duplicate_if_needed(self, parsed_values: Any) -> list[Any]: ) def postprocess(self, raw_parsed_value: Any) -> Any: - """Applies any conversions to the 'raw' parsed value before it is used - in the constructor of the dataclass. + """Applies any conversions to the 'raw' parsed value before it is used in the constructor + of the dataclass. Args: raw_parsed_value (Any): The 'raw' parsed value. @@ -584,7 +577,6 @@ def option_strings(self) -> list[str]: added. For an illustration of this, see the aliases example. - """ dashes: list[str] = [] # contains the leading dashes. 
@@ -649,7 +641,9 @@ def add_args(dash: str, candidates: list[str]) -> None: if add_dash_variants == DashVariant.UNDERSCORE_AND_DASH: additional_options = [option.replace("_", "-") for option in options if "_" in option] - additional_dashes = ["-" if len(option) == 1 else "--" for option in additional_options] + additional_dashes = [ + "-" if len(option) == 1 else "--" for option in additional_options + ] options.extend(additional_options) dashes.extend(additional_dashes) @@ -684,13 +678,12 @@ def is_proxy(self) -> bool: @property def dest_field(self) -> FieldWrapper | None: - """Return the `FieldWrapper` for which `self` is a proxy (same dest). - When a `dest` argument is passed to `field()`, and its value is a - `Field`, that indicates that this Field is just a proxy for another. + """Return the `FieldWrapper` for which `self` is a proxy (same dest). When a `dest` + argument is passed to `field()`, and its value is a `Field`, that indicates that this Field + is just a proxy for another. - In such a case, we replace the dest of `self` with that of the other - wrapper's we then find the corresponding FieldWrapper and use its `dest` - instead of ours. + In such a case, we replace the dest of `self` with that of the other wrapper's we then find + the corresponding FieldWrapper and use its `dest` instead of ours. """ if self._dest_field is not None: return self._dest_field @@ -715,9 +708,9 @@ def nargs(self): @property def default(self) -> Any: - """Either a single default value, when parsing a single argument, or - the list of default values, when this argument is reused multiple times - (which only happens with the `ConflictResolution.ALWAYS_MERGE` option). + """Either a single default value, when parsing a single argument, or the list of default + values, when this argument is reused multiple times (which only happens with the + `ConflictResolution.ALWAYS_MERGE` option). In order of increasing priority, this could either be: 1. 
The default attribute of the field @@ -845,7 +838,9 @@ def type(self) -> type[Any]: get_field_type_from_annotations, ) - field_type = get_field_type_from_annotations(self.parent.dataclass, self.field.name) + field_type = get_field_type_from_annotations( + self.parent.dataclass, self.field.name + ) self._type = field_type elif isinstance(self._type, dataclasses.InitVar): self._type = self._type.type @@ -860,7 +855,8 @@ def is_choice(self) -> bool: @property def choices(self) -> list | None: - """The list of possible values that can be passed on the command-line for this field, or None.""" + """The list of possible values that can be passed on the command-line for this field, or + None.""" if "choices" in self.custom_arg_options: return self.custom_arg_options["choices"] @@ -913,9 +909,7 @@ def help(self, value: str): @property def metavar(self) -> str | None: - """Returns the 'metavar' when set using one of the `field` functions, - else None. - """ + """Returns the 'metavar' when set using one of the `field` functions, else None.""" if self._metavar: return self._metavar self._metavar = self.custom_arg_options.get("metavar") @@ -983,9 +977,9 @@ def parent(self) -> DataclassWrapper: @property def subparsers_dict(self) -> dict[str, type] | None: - """The dict of subparsers, which is created either when using a - Union[, ] type annotation, or when using the - `subparsers()` function. + """The dict of subparsers, which is created either when using a Union[, + + ] type annotation, or when using the `subparsers()` function. 
""" if self.field.metadata.get("subparsers"): return self.field.metadata["subparsers"] diff --git a/test/conftest.py b/test/conftest.py index 20fed2de..9fdad448 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -9,8 +9,10 @@ from typing import Any, Generic, TypeVar import pytest -from typing_extensions import NamedTuple # For Generic NamedTuples -from typing_extensions import Literal +from typing_extensions import ( + Literal, + NamedTuple, # For Generic NamedTuples +) pytest.register_assert_rewrite("test.testutils") @@ -71,7 +73,12 @@ class SimpleAttributeWithDefault(NamedTuple, Generic[T]): # TODO: Also add something like `[Optional[t] for t in simple_arguments]`! -default_values_for_type = {int: [0, -111], str: ["bob", ""], float: [0.0, 1e2], bool: [True, False]} +default_values_for_type = { + int: [0, -111], + str: ["bob", ""], + float: [0.0, 1e2], + bool: [True, False], +} @pytest.fixture( @@ -143,7 +150,7 @@ def setup_logging(): ch.setFormatter( logging.Formatter( "%(levelname)s {%(pathname)s:%(lineno)d} - %(message)s", - "%m-%d %H:%M:%S" + "%m-%d %H:%M:%S", # "%(asctime)-15s::%(levelname)s::%(pathname)s::%(lineno)d::%(message)s" ) ) @@ -192,10 +199,7 @@ def no_warning_log_messages(caplog): @pytest.fixture def silent(no_stdout, no_warning_log_messages): - """ - Test fixture that will make a test fail if it prints anything to stdout or - logs warnings - """ + """Test fixture that will make a test fail if it prints anything to stdout or logs warnings.""" @pytest.fixture diff --git a/test/helpers/test_encoding.py b/test/helpers/test_encoding.py index 854ecc3b..1ffd0340 100644 --- a/test/helpers/test_encoding.py +++ b/test/helpers/test_encoding.py @@ -8,6 +8,8 @@ from simple_parsing.helpers.serialization import load, save +from ..testutils import needs_yaml + @dataclass class A: @@ -42,7 +44,7 @@ def __post_init__(self): Container(item=BB(b="hey", extra_field=111)), ], ) -@pytest.mark.parametrize("file_type", [".json", ".yaml"]) 
+@pytest.mark.parametrize("file_type", [".json", pytest.param(".yaml", marks=needs_yaml)]) def test_encoding_with_dc_types( obj: Container, file_type: str, tmp_path: Path, file_regression: FileRegressionFixture ): @@ -66,7 +68,7 @@ def reset_encoding_fns(): _decoding_fns.update(copy) -@pytest.mark.parametrize("file_type", [".json", ".yaml"]) +@pytest.mark.parametrize("file_type", [".json", pytest.param(".yaml", marks=needs_yaml)]) def test_encoding_inner_dc_types_raises_warning_and_doest_work(tmp_path: Path, file_type: str): file = (tmp_path / "test").with_suffix(file_type) diff --git a/test/helpers/test_enum_serialization.py b/test/helpers/test_enum_serialization.py index 4feb6fe0..d3eef604 100644 --- a/test/helpers/test_enum_serialization.py +++ b/test/helpers/test_enum_serialization.py @@ -1,15 +1,20 @@ import textwrap +import typing from dataclasses import dataclass, field from enum import Enum from pathlib import Path from typing import List, Optional import pytest -import yaml from simple_parsing import Serializable from simple_parsing.helpers.serialization.serializable import dumps_yaml, loads_yaml +if typing.TYPE_CHECKING: + import yaml +else: + yaml = pytest.importorskip("yaml") + class LoggingTypes(Enum): JSONL = "jsonl" @@ -32,7 +37,8 @@ class Parameters(Serializable): raises=KeyError, match="'jsonl'", strict=True, reason="Enums are saved by name, not by value." 
) def test_decode_enum_saved_by_value_doesnt_work(tmp_path: Path): - """Test to reproduce https://github.com/lebrice/SimpleParsing/issues/219#issuecomment-1437817369""" + """Test to reproduce + https://github.com/lebrice/SimpleParsing/issues/219#issuecomment-1437817369.""" with open(tmp_path / "conf.yaml", "w") as f: f.write( textwrap.dedent( diff --git a/test/helpers/test_save.py b/test/helpers/test_save.py index 89b4acdc..3589ddb9 100644 --- a/test/helpers/test_save.py +++ b/test/helpers/test_save.py @@ -3,8 +3,10 @@ import pytest from ..nesting.example_use_cases import HyperParameters +from ..testutils import needs_toml, needs_yaml +@needs_yaml def test_save_yaml(tmpdir: Path): hparams = HyperParameters.setup("") tmp_path = Path(tmpdir / "temp.yml") @@ -17,11 +19,12 @@ def test_save_yaml(tmpdir: Path): def test_save_json(tmpdir: Path): hparams = HyperParameters.setup("") tmp_path = Path(tmpdir / "temp.json") - hparams.save_yaml(tmp_path) - _hparams = HyperParameters.load_yaml(tmp_path) + hparams.save_json(tmp_path) + _hparams = HyperParameters.load_json(tmp_path) assert hparams == _hparams +@needs_yaml def test_save_yml(tmpdir: Path): hparams = HyperParameters.setup("") tmp_path = Path(tmpdir / "temp.yml") @@ -65,6 +68,7 @@ def test_save_torch(tmpdir: Path): assert hparams == _hparams +@needs_toml def test_save_toml(tmpdir: Path): hparams = HyperParameters.setup("") tmp_path = Path(tmpdir / "temp.toml") diff --git a/test/helpers/test_serializable.py b/test/helpers/test_serializable.py index c720c74e..4ce05dfd 100644 --- a/test/helpers/test_serializable.py +++ b/test/helpers/test_serializable.py @@ -1,5 +1,4 @@ -"""Adds typed dataclasses for the "config" yaml files. 
-""" +"""Adds typed dataclasses for the "config" yaml files.""" from collections import OrderedDict from dataclasses import dataclass from enum import Enum @@ -119,7 +118,6 @@ class ParentWithOptionalChildrenWithFriends(ParentWithOptionalChildren): def test_to_dict(silent, Child, Parent): - bob = Child("Bob") clarice = Child("Clarice") nancy = Parent("Nancy", children=dict(bob=bob, clarice=clarice)) @@ -497,11 +495,9 @@ class Container(FrozenSerializable if frozen else Serializable): def test_choice_dict_with_nonserializable_values(frozen: bool): - """Test that when a choice_dict has values of some non-json-FrozenSerializable if frozen else Serializable type, a - custom encoding/decoding function is provided that will map to/from the dict keys - rather than attempt to serialize the field value. - - """ + """Test that when a choice_dict has values of some non-json-FrozenSerializable if frozen else + Serializable type, a custom encoding/decoding function is provided that will map to/from the + dict keys rather than attempt to serialize the field value.""" from simple_parsing import choice def identity(x: int): diff --git a/test/nesting/example_use_cases.py b/test/nesting/example_use_cases.py index cd80c51c..d99d8b23 100644 --- a/test/nesting/example_use_cases.py +++ b/test/nesting/example_use_cases.py @@ -17,9 +17,7 @@ @dataclass class HParams(TestSetup): - """ - Model Hyper-parameters - """ + """Model Hyper-parameters.""" # Number of examples per batch batch_size: int = 32 @@ -33,14 +31,14 @@ class HParams(TestSetup): # number of layers. num_layers: int = default_num_layers # the number of neurons at each layer - neurons_per_layer: List[int] = field(default_factory=lambda: [128] * HParams.default_num_layers) + neurons_per_layer: List[int] = field( + default_factory=lambda: [128] * HParams.default_num_layers + ) @dataclass class RunConfig(TestSetup): - """ - Group of settings used during a training or validation run. 
- """ + """Group of settings used during a training or validation run.""" # the set of hyperparameters for this run. hparams: HParams = field(default_factory=HParams) @@ -56,9 +54,7 @@ def __post_init__(self): @dataclass class TrainConfig(TestSetup): - """ - Top-level settings for multiple runs. - """ + """Top-level settings for multiple runs.""" # run config to be used during training train: RunConfig = field(default_factory=functools.partial(RunConfig, log_dir="train")) @@ -68,9 +64,7 @@ class TrainConfig(TestSetup): @dataclass class TaskHyperParameters(TestSetup): - """ - HyperParameters for a task-specific model - """ + """HyperParameters for a task-specific model.""" # name of the task name: str diff --git a/test/nesting/test_nesting_defaults.py b/test/nesting/test_nesting_defaults.py index c3dbf5c7..a7c1ca78 100644 --- a/test/nesting/test_nesting_defaults.py +++ b/test/nesting/test_nesting_defaults.py @@ -7,6 +7,8 @@ from simple_parsing.helpers import field from simple_parsing.helpers.serialization.serializable import Serializable +from ..testutils import needs_yaml + @dataclass class AdvTraining(Serializable): @@ -41,8 +43,8 @@ class TrainConfig(Serializable): cpu: bool = False +@needs_yaml def test_comment_pull115(tmp_path): - config_in_file = TrainConfig( data_config=DatasetConfig(name="bob", split="victim", prop="123", value=1.23), epochs=1, diff --git a/test/nesting/test_weird_use_cases.py b/test/nesting/test_weird_use_cases.py index 0fcfece4..20f93a33 100644 --- a/test/nesting/test_weird_use_cases.py +++ b/test/nesting/test_weird_use_cases.py @@ -75,7 +75,9 @@ def test_beautiful_tree_structure_merge(): assert abcd.child_cd.child_d.d == "ABCD_CD_d" -def tree_structure_with_repetitions(some_type: Type[T], default_value_function: Callable[[str], T]): +def tree_structure_with_repetitions( + some_type: Type[T], default_value_function: Callable[[str], T] +): @dataclass class A: a: some_type = default_value_function("a") # type: ignore @@ -94,35 +96,35 @@ 
class D: @dataclass class AA: - """Weird AA Class""" + """Weird AA Class.""" a1: A = field(default_factory=functools.partial(A, default_value_function("A_1"))) a2: A = field(default_factory=functools.partial(A, default_value_function("A_2"))) @dataclass class BB: - """Weird BB Class""" + """Weird BB Class.""" b1: B = field(default_factory=functools.partial(B, default_value_function("B_1"))) b2: B = field(default_factory=functools.partial(B, default_value_function("B_2"))) @dataclass class CC: - """Weird CC Class""" + """Weird CC Class.""" c1: C = field(default_factory=functools.partial(C, default_value_function("C_1"))) c2: C = field(default_factory=functools.partial(C, default_value_function("C_2"))) @dataclass class DD: - """Weird DD Class""" + """Weird DD Class.""" d1: D = field(default_factory=functools.partial(D, default_value_function("D_1"))) d2: D = field(default_factory=functools.partial(D, default_value_function("D_2"))) @dataclass class AABB: - """Weird AABB Class""" + """Weird AABB Class.""" aa: AA = field( default_factory=functools.partial( @@ -137,7 +139,7 @@ class AABB: @dataclass class CCDD: - """Weird CCDD Class""" + """Weird CCDD Class.""" cc: CC = field( default_factory=functools.partial( @@ -152,7 +154,7 @@ class CCDD: @dataclass class AABBCCDD(TestSetup): - """Weird AABBCCDD Class""" + """Weird AABBCCDD Class.""" aabb: AABB = field( default_factory=functools.partial( @@ -183,7 +185,7 @@ class AABBCCDD(TestSetup): @dataclass class AABBCCDDWeird(TestSetup): - """Weird AABBCCDDWeird Class""" + """Weird AABBCCDDWeird Class.""" a: A = field(default_factory=functools.partial(A, "a")) b: B = field(default_factory=functools.partial(B, "b")) diff --git a/test/postponed_annotations/test_postponed_annotations.py b/test/postponed_annotations/test_postponed_annotations.py index faa25276..7d0a9053 100644 --- a/test/postponed_annotations/test_postponed_annotations.py +++ b/test/postponed_annotations/test_postponed_annotations.py @@ -71,13 +71,13 @@ def 
test_overwrite_base(): def test_overwrite_field(): """Test that postponed annotations don't break attribute overwriting in multiple files.""" - import test.postponed_annotations.overwrite_base as overwrite_base import test.postponed_annotations.overwrite_attribute as overwrite_attribute + import test.postponed_annotations.overwrite_base as overwrite_base instance = overwrite_attribute.Subclass.setup("--v True") - assert type(instance.attribute) != overwrite_base.ParamCls, ( - "attribute type from Base class correctly ignored" - ) + assert ( + type(instance.attribute) != overwrite_base.ParamCls + ), "attribute type from Base class correctly ignored" assert instance == overwrite_attribute.Subclass( attribute=overwrite_attribute.ParamClsSubclass(True) ), "parsed attribute value is correct" diff --git a/test/test_base.py b/test/test_base.py index ba2bdb91..66acf6aa 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -53,7 +53,7 @@ def test_works_fine_with_other_argparse_arguments(simple_attribute, silent): @dataclass class SomeClass: a: some_type # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" parser = ArgumentParser() parser.add_argument("--x", type=int) @@ -79,7 +79,7 @@ def test_arg_value_is_set_when_args_are_provided( @dataclass class SomeClass(TestSetup): a: some_type = default_value # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" class_a = SomeClass.setup(f"--a {arg_value}") assert class_a.a != default_value @@ -100,7 +100,7 @@ def test_not_providing_required_argument_throws_error(some_type): @dataclass class SomeClass(TestSetup): a: some_type # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" with raises(SystemExit): _ = SomeClass.setup("") @@ -111,7 +111,7 @@ def test_not_providing_required_argument_name_but_no_value_throws_error(some_typ @dataclass class SomeClass(TestSetup): a: some_type # type: ignore - 
"""some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" with raises(SystemExit): _ = SomeClass.setup("--a") @@ -125,10 +125,10 @@ class Color(Enum): @dataclass class Base(TestSetup): - """A simple base-class example""" + """A simple base-class example.""" a: int # TODO: finetune this - """docstring for attribute 'a'""" + """Docstring for attribute 'a'.""" b: float = 5.0 # inline comment on attribute 'b' c: str = "" @@ -138,7 +138,7 @@ class Extended(Base): """Some extension of base-class `Base`""" d: int = 5 - """ docstring for 'd' in Extended. """ + """docstring for 'd' in Extended.""" e: Color = Color.BLUE @@ -183,7 +183,7 @@ def test_passing_default_value(simple_attribute, silent): @dataclass class SomeClass(TestSetup): a: some_type = passed_value # type: ignore - """some docstring for attribute 'a' """ + """Some docstring for attribute 'a'.""" some_class = SomeClass.setup(default=SomeClass(expected_value)) assert some_class.a == expected_value @@ -265,7 +265,7 @@ class Temperature(enum.Enum): @dataclass class MyPreferences(TestSetup): - """You can use Enums""" + """You can use Enums.""" color: Color = Color.BLUE # my favorite colour # a list of colors diff --git a/test/test_bools.py b/test/test_bools.py index 5da7c08e..34d9c6e4 100644 --- a/test/test_bools.py +++ b/test/test_bools.py @@ -137,7 +137,7 @@ def test_bool_nargs( ): @dataclass class MyClass(TestSetup): - """Some test class""" + """Some test class.""" a: bool = helpers.field(nargs=nargs) @@ -183,7 +183,7 @@ def test_list_of_bools_nargs( ): @dataclass class MyClass(TestSetup): - """Some test class""" + """Some test class.""" a: List[bool] = helpers.field(nargs=nargs) @@ -199,9 +199,8 @@ class MyClass(TestSetup): @pytest.mark.parametrize("field", [field, flag]) def test_using_custom_negative_prefix(field): - """Check that we can customize the negative prefix for a boolean field, either with the - `field` function or with the `flag` function. 
- """ + """Check that we can customize the negative prefix for a boolean field, either with the `field` + function or with the `flag` function.""" @dataclass class Config(TestSetup): @@ -254,7 +253,7 @@ class OtherConfig(TestSetup): @pytest.mark.parametrize("default_value", [True, False]) def test_nested_bool_field_negative_args(default_value: bool): - """Test that we get --train.nodebug instead of --notrain.debug""" + """Test that we get --train.nodebug instead of --notrain.debug.""" @dataclass class Options: @@ -322,10 +321,10 @@ def test_bool_nested_field_when_conflict_has_two_dashes( bool_field: Callable[..., bool], default_value: bool ): """TODO:""" + # Check that there isn't a "-train.d bool" argument generated here, only "--train.d bool" @dataclass class Options: - # whether or not to execute in debug mode. d: bool = bool_field(default=default_value) diff --git a/test/test_conflicts.py b/test/test_conflicts.py index a99b0974..c6561c27 100644 --- a/test/test_conflicts.py +++ b/test/test_conflicts.py @@ -1,5 +1,4 @@ -"""Tests for weird conflicts. -""" +"""Tests for weird conflicts.""" import argparse import functools from dataclasses import dataclass, field diff --git a/test/test_custom_args.py b/test/test_custom_args.py index 4b36b209..a2e614aa 100644 --- a/test/test_custom_args.py +++ b/test/test_custom_args.py @@ -172,13 +172,13 @@ def test_store_false_action(): def test_only_dashes(): @dataclass class AClass(TestSetup): - """foo""" + """Foo.""" a_var: int @dataclass class SomeClass(TestSetup): - """lol""" + """Lol.""" my_var: int a: AClass @@ -228,7 +228,7 @@ class SomeClass(TestSetup): def test_list_of_choices(): @dataclass class Foo(TestSetup): - """Some class Foo""" + """Some class Foo.""" # A sequence of tasks. 
task_sequence: List[str] = field(choices=["train", "test", "ood"]) diff --git a/test/test_decoding.py b/test/test_decoding.py index 3581e840..a52a884d 100644 --- a/test/test_decoding.py +++ b/test/test_decoding.py @@ -16,9 +16,10 @@ from simple_parsing.helpers.serialization.serializable import loads_json from simple_parsing.utils import DataclassT +from .testutils import needs_yaml -def test_encode_something(simple_attribute): +def test_encode_something(simple_attribute): some_type, passed_value, expected_value = simple_attribute @dataclass @@ -159,6 +160,7 @@ class Parameters(Serializable): def test_implicit_int_casting(tmp_path: Path): """Test that we do in fact perform the unsafe casting as described in #227: + https://github.com/lebrice/SimpleParsing/issues/227 """ with open(tmp_path / "conf.yaml", "w") as f: @@ -172,6 +174,7 @@ def test_implicit_int_casting(tmp_path: Path): """ ) ) + _yaml = pytest.importorskip("yaml") with pytest.warns(RuntimeWarning, match="Unsafe casting"): file_config = Parameters.load(tmp_path / "conf.yaml") assert file_config == Parameters(hparams=Hparams(severity=0, probs=[0, 0])) @@ -197,8 +200,9 @@ def reset_int_decoding_fns_after_test(): _decoding_fns.update(backup) +@needs_yaml def test_registering_safe_casting_decoding_fn(): - """Test the solution to 'issue' #227: https://github.com/lebrice/SimpleParsing/issues/227""" + """Test the solution to 'issue' #227: https://github.com/lebrice/SimpleParsing/issues/227.""" # Solution: register a decoding function for `int` that casts to int, but raises an error if # the value would lose precision. 
@@ -211,19 +215,16 @@ def _safe_cast(v: Any) -> int: register_decoding_fn(int, _safe_cast, overwrite=True) - assert ( - Parameters.loads_yaml( - textwrap.dedent( - """\ + assert Parameters.loads_yaml( + textwrap.dedent( + """\ hparams: use_log: 1 severity: 0.0 probs: [3, 4.0] """ - ) ) - == Parameters(hparams=Hparams(severity=0, probs=[3, 4])) - ) + ) == Parameters(hparams=Hparams(severity=0, probs=[3, 4])) with pytest.raises(ValueError, match="Cannot safely cast 0.1 to int"): Parameters.loads_yaml( @@ -312,7 +313,8 @@ def test_issue_227_unsafe_int_casting_on_load( expected_message: str, expected_result: DataclassT, ): - """Test that a warning is raised when performing a lossy cast when deserializing a dataclass.""" + """Test that a warning is raised when performing a lossy cast when deserializing a + dataclass.""" with pytest.warns( RuntimeWarning, match=expected_message, diff --git a/test/test_decorator.py b/test/test_decorator.py index 9521eee7..a3748ebe 100644 --- a/test/test_decorator.py +++ b/test/test_decorator.py @@ -2,10 +2,11 @@ import collections import dataclasses import functools +import inspect import sys -from typing import Callable import typing -import inspect +from typing import Callable + import pytest import simple_parsing as sp diff --git a/test/test_docstrings.py b/test/test_docstrings.py index 5f805e38..d13307fa 100644 --- a/test/test_docstrings.py +++ b/test/test_docstrings.py @@ -9,7 +9,7 @@ @dataclass class Base: - """A simple base-class example""" + """A simple base-class example.""" a: int # TODO: finetune this @@ -98,7 +98,7 @@ class UniqueFoo(TestSetup): def test_docstrings_with_multiple_inheritance(): - """Test to reproduce issue 162: https://github.com/lebrice/SimpleParsing/issues/162""" + """Test to reproduce issue 162: https://github.com/lebrice/SimpleParsing/issues/162.""" @dataclass class Fooz: @@ -149,7 +149,7 @@ class Base2(TestSetup): """ bar: int = 123 # inline - """field docstring from base class""" + """Field docstring from 
base class.""" @dataclass class FooB(Base2): @@ -159,10 +159,10 @@ class FooB(Base2): assert get_attribute_docstring(FooB, "bar") == AttributeDocString( comment_inline="The bar property", comment_above="Above", - docstring_below="field docstring from base class", + docstring_below="Field docstring from base class.", ) - assert "field docstring from base class" in FooB.get_help_text() + assert "Field docstring from base class." in FooB.get_help_text() def test_getdocstring_bug(): @@ -187,24 +187,23 @@ class SomeClass: ---------- batch_size : int, optional _description_, by default 32 - """ # above batch_size: int = 32 # side - """below""" + """Below.""" assert get_attribute_docstring(SomeClass, "batch_size") == AttributeDocString( desc_from_cls_docstring="_description_, by default 32", comment_above="above", comment_inline="side", - docstring_below="below", + docstring_below="Below.", ) def test_help_takes_value_from_docstring(): @dataclass - class Args(TestSetup): + class Args: """args. Attributes: diff --git a/test/test_examples.py b/test/test_examples.py index d20967c4..55b91b29 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -1,6 +1,6 @@ """A test to make sure that all the example files work without crashing. -(Could be seen as a kind of integration test.) +(Could be seen as a kind of integration test.) 
""" from __future__ import annotations @@ -16,6 +16,8 @@ import pytest +from .testutils import needs_yaml + expected = "" @@ -107,18 +109,28 @@ def test_running_example_outputs_expected( *[ pytest.param( p, - marks=[ - pytest.mark.skipif( - sys.version_info[:2] == (3, 6), - reason="Example uses __future__ annotations feature", - ), - pytest.mark.xfail( - reason="Example has different indentation depending on python version.", - ), - ], + marks=( + [ + pytest.mark.skipif( + sys.version_info[:2] == (3, 6), + reason="Example uses __future__ annotations feature", + ), + pytest.mark.xfail( + reason="Example has different indentation depending on python version.", + ), + ] + if p == "examples/subgroups/subgroups_example.py" + else [needs_yaml] + if p + in [ + "examples/config_files/one_config.py", + "examples/config_files/composition.py", + "examples/config_files/many_configs.py", + "examples/serialization/serialization_example.py", + ] + else [] + ), ) - if p == "examples/subgroups/subgroups_example.py" - else p for p in glob.glob("examples/**/*.py") if p not in { @@ -134,7 +146,9 @@ def test_running_example_outputs_expected_without_arg( set_prog_name: Callable[[str, str | None], None], assert_equals_stdout: Callable[[str, str], None], ): - return test_running_example_outputs_expected(file_path, "", set_prog_name, assert_equals_stdout) + return test_running_example_outputs_expected( + file_path, "", set_prog_name, assert_equals_stdout + ) @contextmanager diff --git a/test/test_future_annotations.py b/test/test_future_annotations.py index 51e9c0a6..aa8f443e 100644 --- a/test/test_future_annotations.py +++ b/test/test_future_annotations.py @@ -1,4 +1,4 @@ -""" Tests for compatibility with the postponed evaluation of annotations. 
""" +"""Tests for compatibility with the postponed evaluation of annotations.""" from __future__ import annotations import dataclasses @@ -16,7 +16,7 @@ from simple_parsing.helpers import Serializable from simple_parsing.utils import is_list, is_tuple -from .testutils import TestSetup +from .testutils import YAML_INSTALLED, TestSetup @dataclass @@ -77,8 +77,7 @@ class ClassWithNewUnionSyntax(TestSetup): @dataclass class OtherClassWithNewUnionSyntax(ClassWithNewUnionSyntax): """Create a child class without annotations, just to check that they are picked up from the - base class. - """ + base class.""" @pytest.mark.parametrize( @@ -233,19 +232,21 @@ def test_serialization_deserialization(): assert Opts2 in Serializable.subclasses assert Wrapper.from_dict(opts.to_dict()) == opts assert Wrapper.loads_json(opts.dumps_json()) == opts - assert Wrapper.loads_yaml(opts.dumps_yaml()) == opts + + if YAML_INSTALLED: + assert Wrapper.loads_yaml(opts.dumps_yaml()) == opts @dataclass class OptimizerConfig(TestSetup): lr_scheduler: str = "cosine" - """ LR scheduler to use. """ + """LR scheduler to use.""" @dataclass class SubclassOfOptimizerConfig(OptimizerConfig): bar: int | float = 123 - """ some dummy arg bar. """ + """some dummy arg bar.""" def test_missing_annotation_on_subclass(): diff --git a/test/test_huggingface_compat.py b/test/test_huggingface_compat.py index 46576a07..5fd17fc0 100644 --- a/test/test_huggingface_compat.py +++ b/test/test_huggingface_compat.py @@ -1,4 +1,4 @@ -""" Simple test for compatibility with HuggingFace's help text convention. +"""Simple test for compatibility with HuggingFace's help text convention. This checks that Simple-Parsing can be used as a replacement for the HFArgumentParser. 
""" @@ -13,14 +13,13 @@ from simple_parsing import ArgumentParser from simple_parsing.docstring import get_attribute_docstring -from .testutils import TestSetup, raises_invalid_choice +from .testutils import TestSetup, needs_yaml, raises_invalid_choice @dataclass class ModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. - """ + """Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train + from scratch.""" model_name_or_path: Optional[str] = field( default=None, @@ -93,9 +92,7 @@ def __post_init__(self): @dataclass class DataTrainingArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. - """ + """Arguments pertaining to what data we are going to input our model for training and eval.""" dataset_name: Optional[str] = field( default=None, @@ -217,7 +214,8 @@ class Config(TestSetup): def test_choices(): - """Checks that the `choices` in the field metadata are used as the `choice` argument to `add_argument`""" + """Checks that the `choices` in the field metadata are used as the `choice` argument to + `add_argument`""" with raises_invalid_choice(): Config.setup("--log_level invalid") @@ -226,9 +224,7 @@ def test_choices(): class ExplicitEnum(str, Enum): - """ - Enum with more explicit error message for missing values. - """ + """Enum with more explicit error message for missing values.""" @classmethod def _missing_(cls, value): @@ -276,9 +272,7 @@ class SchedulerType(ExplicitEnum): class OptimizerNames(ExplicitEnum): - """ - Stores the acceptable string identifiers for optimizers. 
- """ + """Stores the acceptable string identifiers for optimizers.""" ADAMW_HF = "adamw_hf" ADAMW_TORCH = "adamw_torch" @@ -319,10 +313,9 @@ def test_enums_are_parsed_to_enum_member(): @dataclass class TrainingArguments(TestSetup): - """ - TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop - itself**. - Using [`HfArgumentParser`] we can turn this class into + """TrainingArguments is the subset of the arguments we use in our example scripts **which + relate to the training loop itself**. Using [`HfArgumentParser`] we can turn this class into. + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the command line. Parameters: @@ -1266,7 +1259,6 @@ def __post_init__(self): @pytest.mark.xfail(reason="docstring_parser can't parse the docstring of TrainingArguments!") def test_docstring_parse_works_with_hf_training_args(): - assert get_attribute_docstring(TrainingArguments, "output_dir").desc_from_cls_docstring == ( "The output directory where the model predictions and checkpoints will be written." 
) @@ -1289,7 +1281,15 @@ def test_entire_docstring_isnt_used_as_help(): TrainingArguments(save_strategy=IntervalStrategy.EPOCH), ], ) -@pytest.mark.parametrize("filename", ["bob.yaml", "bob.json", "bob.pkl", "bob.yml"]) +@pytest.mark.parametrize( + "filename", + [ + pytest.param("bob.yaml", marks=needs_yaml), + "bob.json", + "bob.pkl", + pytest.param("bob.yml", marks=needs_yaml), + ], +) def test_serialization(tmp_path: Path, filename: str, args: TrainingArguments): """test that serializing / deserializing a TrainingArguments works.""" from simple_parsing.helpers.serialization import load, save diff --git a/test/test_inheritance.py b/test/test_inheritance.py index 0547e68c..7f6fce7c 100644 --- a/test/test_inheritance.py +++ b/test/test_inheritance.py @@ -83,7 +83,7 @@ def test_subclasses_with_same_base_class_with_args_merge(): def test_weird_structure(): - """both is-a, and has-a at the same time, a very weird inheritance structure""" + """Both is-a, and has-a at the same time, a very weird inheritance structure.""" @dataclass class ConvBlock(Serializable): @@ -94,14 +94,14 @@ class ConvBlock(Serializable): @dataclass class GeneratorHParams(ConvBlock): - """Settings of the Generator model""" + """Settings of the Generator model.""" conv: ConvBlock = field(default_factory=ConvBlock) optimizer: str = choice("ADAM", "RMSPROP", "SGD", default="ADAM") @dataclass class DiscriminatorHParams(ConvBlock): - """Settings of the Discriminator model""" + """Settings of the Discriminator model.""" conv: ConvBlock = field(default_factory=ConvBlock) optimizer: str = choice("ADAM", "RMSPROP", "SGD", default="ADAM") diff --git a/test/test_issue64.py b/test/test_issue64.py index 9e79d225..27e51418 100644 --- a/test/test_issue64.py +++ b/test/test_issue64.py @@ -10,7 +10,7 @@ @dataclass class Options: - """These are the options""" + """These are the options.""" foo: str = "aaa" # Description bar: str = "bbb" @@ -18,7 +18,6 @@ class Options: @pytest.mark.xfail(reason="Issue64 is solved 
below.") def test_reproduce_issue64(): - parser = ArgumentParser("issue64") parser.add_arguments(Options, dest="options") @@ -46,11 +45,11 @@ def test_reproduce_issue64(): def test_vanilla_argparse_issue64(): - """This test shows that the ArgumentDefaultsHelpFormatter of argparse doesn't add - the "(default: xyz)" if the 'help' argument isn't already passed! + """This test shows that the ArgumentDefaultsHelpFormatter of argparse doesn't add the + "(default: xyz)" if the 'help' argument isn't already passed! - This begs the question: Should simple-parsing add a 'help' argument always, so that - the formatter can then add the default string after? + This begs the question: Should simple-parsing add a 'help' argument always, so that the + formatter can then add the default string after? """ import argparse @@ -86,10 +85,8 @@ def test_vanilla_argparse_issue64(): def test_solved_issue64(): - """test that shows that Issue 64 is solved now, by adding a single space as the - 'help' argument, the help formatter can then add the "(default: bbb)" after the - argument. 
- """ + """test that shows that Issue 64 is solved now, by adding a single space as the 'help' + argument, the help formatter can then add the "(default: bbb)" after the argument.""" parser = ArgumentParser("issue64") parser.add_arguments(Options, dest="options") diff --git a/test/test_issue_107.py b/test/test_issue_107.py index 608edca0..e4bce4b6 100644 --- a/test/test_issue_107.py +++ b/test/test_issue_107.py @@ -1,4 +1,4 @@ -""" test for https://github.com/lebrice/SimpleParsing/issues/107 """ +"""Test for https://github.com/lebrice/SimpleParsing/issues/107.""" from dataclasses import dataclass from typing import Any diff --git a/test/test_issue_132.py b/test/test_issue_132.py index ea37181c..ab3f3b92 100644 --- a/test/test_issue_132.py +++ b/test/test_issue_132.py @@ -1,4 +1,4 @@ -""" Test for https://github.com/lebrice/SimpleParsing/issues/132 """ +"""Test for https://github.com/lebrice/SimpleParsing/issues/132.""" from dataclasses import dataclass from simple_parsing import field diff --git a/test/test_issue_144.py b/test/test_issue_144.py index 1b08b0a3..a16a6252 100644 --- a/test/test_issue_144.py +++ b/test/test_issue_144.py @@ -1,8 +1,8 @@ -""" Tests for issue 144: https://github.com/lebrice/SimpleParsing/issues/144 """ +"""Tests for issue 144: https://github.com/lebrice/SimpleParsing/issues/144.""" from __future__ import annotations from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union +from typing import Union import pytest @@ -12,7 +12,7 @@ class TestOptional: @dataclass class Foo(Serializable): - foo: Optional[int] = 123 + foo: int | None = 123 @pytest.mark.parametrize("d", [{"foo": None}, {"foo": 1}]) def test_round_trip(self, d: dict): @@ -24,7 +24,7 @@ def test_round_trip(self, d: dict): class TestUnion: @dataclass class Foo(Serializable): - foo: Union[int, dict[int, bool]] = 123 + foo: Union[int, dict[int, bool]] = 123 # noqa: UP007 @pytest.mark.parametrize("d", [{"foo": None}, {"foo": {1: "False"}}]) def 
test_round_trip(self, d: dict): @@ -36,7 +36,7 @@ def test_round_trip(self, d: dict): class TestList: @dataclass class Foo(Serializable): - foo: List[int] = field(default_factory=list) + foo: list[int] = field(default_factory=list) @pytest.mark.parametrize("d", [{"foo": []}, {"foo": [123, 456]}]) def test_round_trip(self, d: dict): @@ -48,7 +48,7 @@ def test_round_trip(self, d: dict): class TestTuple: @dataclass class Foo(Serializable): - foo: Tuple[int, float, bool] + foo: tuple[int, float, bool] @pytest.mark.parametrize("d", [{"foo": (1, 1.2, False)}, {"foo": ("1", "1.2", "True")}]) def test_round_trip(self, d: dict): @@ -60,7 +60,7 @@ def test_round_trip(self, d: dict): class TestDict: @dataclass class Foo(Serializable): - foo: Dict[int, float] = field(default_factory=dict) + foo: dict[int, float] = field(default_factory=dict) @pytest.mark.parametrize("d", [{"foo": {}}, {"foo": {"123": "4.56"}}]) def test_round_trip(self, d: dict): diff --git a/test/test_lists.py b/test/test_lists.py index e2892bf9..6d34a4cd 100644 --- a/test/test_lists.py +++ b/test/test_lists.py @@ -110,7 +110,7 @@ def test_list_supported_formats( @dataclass class SomeClass(TestSetup): a: List[item_type] = field(default_factory=list) # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" arguments = "--a " + list_formatting_function(passed_values) print(arguments) @@ -148,7 +148,7 @@ def test_parse_multiple_with_list_attributes( @dataclass class SomeClass(TestSetup): a: List[item_type] = field(default_factory=list) # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" arguments = "--a " + list_of_lists_formatting_function(passed_values) classes = list(SomeClass.setup_multiple(3, arguments)) diff --git a/test/test_literal.py b/test/test_literal.py index 78bfff5f..caac3be4 100644 --- a/test/test_literal.py +++ b/test/test_literal.py @@ -110,7 +110,7 @@ class SomeFoo(TestSetup): 
@pytest.mark.skipif(sys.version_info != (3, 9), reason="Bug is only in 3.9") @pytest.mark.xfail(strict=True, reason="This bug was fixed by #260") def test_reproduce_issue_259_parsing_literal_py39(): - """Reproduces https://github.com/lebrice/SimpleParsing/issues/259""" + """Reproduces https://github.com/lebrice/SimpleParsing/issues/259.""" # $ python issue.py # usage: issue.py [-h] [--param typing.Literal['bar', 'biz']] # issue.py: error: argument --param: invalid typing.Literal['bar', 'biz'] value: 'biz' diff --git a/test/test_multiple.py b/test/test_multiple.py index 19f8bdc5..d94ba596 100644 --- a/test/test_multiple.py +++ b/test/test_multiple.py @@ -39,7 +39,7 @@ def test_parse_multiple_with_no_arguments_sets_default_value( @dataclass class SomeClass(TestSetup): a: some_type = expected_value # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" classes = SomeClass.setup_multiple(num_instances, "") assert len(classes) == num_instances @@ -61,7 +61,7 @@ def test_parse_multiple_with_single_arg_value_sets_that_value_for_all_instances( @dataclass class SomeClass(TestSetup): a: some_type = expected_value # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" classes = SomeClass.setup_multiple(num_instances, f"--a {passed_value}") @@ -88,7 +88,7 @@ def test_parse_multiple_with_provided_value_for_each_instance( @dataclass class SomeClass(TestSetup): a: some_type = default_value # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" # TODO: maybe test out other syntaxes for passing in multiple argument values? (This looks a lot like passing in a list of values..) 
arguments = f"--a {' '.join(str(p) for p in passed_values)}" @@ -107,7 +107,7 @@ def test_parse_multiple_without_required_arguments(some_type: Type): @dataclass class SomeClass(TestSetup): a: some_type # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" with exits_and_writes_to_stderr(): SomeClass.setup_multiple(2, "") @@ -115,11 +115,13 @@ class SomeClass(TestSetup): @parametrize("container_type", [List, Tuple]) @parametrize("item_type", [int, float, str, bool]) -def test_parse_multiple_without_required_container_arguments(container_type: Type, item_type: Type): +def test_parse_multiple_without_required_container_arguments( + container_type: Type, item_type: Type +): @dataclass class SomeClass(TestSetup): a: container_type[item_type] # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" with exits_and_writes_to_stderr("the following arguments are required:"): _ = SomeClass.setup_multiple(3, "") @@ -131,7 +133,7 @@ def test_parse_multiple_with_arg_name_without_arg_value(container_type: Type, it @dataclass class SomeClass(TestSetup): a: container_type[item_type] # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" with exits_and_writes_to_stderr("expected at least one argument"): _ = SomeClass.setup_multiple(3, "--a") @@ -154,7 +156,7 @@ def test_parse_multiple_containers_default_value( @dataclass class SomeClass(TestSetup): a: container_type = field(default_factory=lambda: default_value.copy()) # type: ignore - """some docstring for attribute 'a'""" + """Some docstring for attribute 'a'.""" values = list(SomeClass.setup_multiple(num_instances)) assert values == [SomeClass(default_value) for i in range(num_instances)] diff --git a/test/test_optional.py b/test/test_optional.py index babebb44..425aa0d3 100644 --- a/test/test_optional.py +++ b/test/test_optional.py @@ -50,6 +50,7 @@ class Parent(TestSetup): def 
test_optional_parameter_group(): """Reproduces issue #28 : + https://github.com/lebrice/SimpleParsing/issues/28#issue-663689719 """ parent: Parent = Parent.setup("--breed Shitzu") diff --git a/test/test_optional_subparsers.py b/test/test_optional_subparsers.py index 8b13c7e4..2d8e07d2 100644 --- a/test/test_optional_subparsers.py +++ b/test/test_optional_subparsers.py @@ -130,7 +130,6 @@ class Options(HyperParameters): @pytest.mark.parametrize("seed", [123, 456, 789]) def test_sample_with_subparsers_field(seed: int): - random.seed(seed) samples = [Options.sample() for _ in range(10)] diff --git a/test/test_performance.py b/test/test_performance.py index 273e45bd..47e54465 100644 --- a/test/test_performance.py +++ b/test/test_performance.py @@ -1,11 +1,14 @@ import functools import importlib -from pathlib import Path import sys +from pathlib import Path from typing import Callable, TypeVar + import pytest from pytest_benchmark.fixture import BenchmarkFixture +from .testutils import needs_yaml + C = TypeVar("C", bound=Callable) @@ -53,9 +56,10 @@ def test_import_performance(benchmark: BenchmarkFixture): group="parse", ) def test_parse_performance(benchmark: BenchmarkFixture): - import simple_parsing as sp from test.nesting.example_use_cases import HyperParameters + import simple_parsing as sp + benchmark( call_before(clear_lru_caches, sp.parse), HyperParameters, @@ -66,11 +70,12 @@ def test_parse_performance(benchmark: BenchmarkFixture): @pytest.mark.benchmark( group="serialization", ) -@pytest.mark.parametrize("filetype", [".yaml", ".json", ".pkl"]) +@pytest.mark.parametrize("filetype", [pytest.param(".yaml", marks=needs_yaml), ".json", ".pkl"]) def test_serialization_performance(benchmark: BenchmarkFixture, tmp_path: Path, filetype: str): - from simple_parsing.helpers.serialization import save, load from test.test_huggingface_compat import TrainingArguments + from simple_parsing.helpers.serialization import load, save + args = TrainingArguments() path = (tmp_path / 
"bob").with_suffix(filetype) diff --git a/test/test_set_defaults.py b/test/test_set_defaults.py index 7754a2d7..cf4bfa32 100644 --- a/test/test_set_defaults.py +++ b/test/test_set_defaults.py @@ -1,14 +1,19 @@ -""" Tests for the setdefaults method of the parser. """ +"""Tests for the setdefaults method of the parser.""" +import typing from dataclasses import dataclass, field from pathlib import Path import pytest -import yaml from simple_parsing.helpers.serialization.serializable import save, to_dict from simple_parsing.parsing import ArgumentParser from simple_parsing.wrappers.field_wrapper import NestedMode +if typing.TYPE_CHECKING: + import yaml +else: + yaml = pytest.importorskip("yaml") + from .testutils import TestSetup @@ -62,8 +67,7 @@ def test_set_broken_defaults_from_file(tmp_path: Path): def test_set_defaults_from_file_without_root(tmp_path: Path): """test that set_defaults accepts the fields of the dataclass directly, when the parser has - nested_mode=NestedMode.WITHOUT_ROOT. - """ + nested_mode=NestedMode.WITHOUT_ROOT.""" parser = ArgumentParser(nested_mode=NestedMode.WITHOUT_ROOT) parser.add_arguments(Foo, dest="foo") @@ -102,9 +106,8 @@ class ConfigWithFoo(TestSetup): @pytest.mark.parametrize("with_root", [True, False]) @pytest.mark.parametrize("add_arguments_before", [True, False]) def test_with_nested_field(tmp_path: Path, add_arguments_before: bool, with_root: bool): - """Test that when we use set_defaults with a config that has a nested dataclass field, - we can pass a path to a yaml file for one of the field, and it also works. 
- """ + """Test that when we use set_defaults with a config that has a nested dataclass field, we can + pass a path to a yaml file for one of the field, and it also works.""" parser = ArgumentParser( nested_mode=NestedMode.WITHOUT_ROOT if not with_root else NestedMode.DEFAULT ) diff --git a/test/test_subgroups.py b/test/test_subgroups.py index 3c4eba3f..1fb2688a 100644 --- a/test/test_subgroups.py +++ b/test/test_subgroups.py @@ -189,7 +189,7 @@ def test_parse(dataclass_type: type[TestClass], args: str, expected: TestClass): def test_subgroup_choice_is_saved_on_namespace(): - """test for https://github.com/lebrice/SimpleParsing/issues/139 + """Test for https://github.com/lebrice/SimpleParsing/issues/139. Need to save the chosen subgroup name somewhere on the args. """ @@ -244,7 +244,6 @@ def test_two_subgroups_with_conflict(args_str: str, expected: TwoSubgroupsWithCo def test_subgroups_with_key_default() -> None: - with pytest.raises(ValueError): subgroups({"a": A, "b": B}, default_factory="a") @@ -270,13 +269,17 @@ def test_subgroup_default_needs_to_be_key_in_dict(): def test_subgroup_default_factory_needs_to_be_value_in_dict(): - with pytest.raises(ValueError, match="`default_factory` must be a value in the subgroups dict"): + with pytest.raises( + ValueError, match="`default_factory` must be a value in the subgroups dict" + ): _ = subgroups({"a": B, "aa": A}, default_factory=C) def test_lambdas_dont_return_same_instance(): """Slightly unrelated, but I just want to check if lambda expressions return the same object - instance when a default factory looks like `lambda: A()`. If so, then I won't encourage this. + instance when a default factory looks like `lambda: A()`. + + If so, then I won't encourage this. """ @dataclass @@ -292,8 +295,7 @@ class Config(TestSetup): def test_partials_new_args_overwrite_set_values(): """Double-check that functools.partial overwrites the keywords that are stored when it is - created with the ones that are passed when calling it. 
- """ + created with the ones that are passed when calling it.""" # just to avoid the test passing if I were to hard-code the same value as the default by # accident. default_a = A().a @@ -438,7 +440,8 @@ class Foo(TestSetup): ], ) def test_other_default_factories(a_factory: Callable[[], A], b_factory: Callable[[], B]): - """Test using other kinds of default factories (i.e. functools.partial or lambda expressions)""" + """Test using other kinds of default factories (i.e. functools.partial or lambda + expressions)""" @dataclass class Foo(TestSetup): @@ -467,6 +470,7 @@ def test_help_string_displays_default_factory_arguments( When using `functools.partial` or lambda expressions, we'd ideally also like the help text to show the field values from inside the `partial` or lambda, if possible. """ + # NOTE: Here we need to return just A() and B() with these default factories, so the defaults # for the fields are the same @dataclass @@ -585,7 +589,6 @@ class ModelBConfig(ModelConfig): @dataclass class Config(TestSetup): - # Which model to use model: ModelConfig = subgroups( {"model_a": ModelAConfig, "model_b": ModelBConfig}, @@ -594,7 +597,7 @@ class Config(TestSetup): def test_destination_substring_of_other_destination_issue191(): - """Test for https://github.com/lebrice/SimpleParsing/issues/191""" + """Test for https://github.com/lebrice/SimpleParsing/issues/191.""" parser = ArgumentParser() parser.add_arguments(Config, dest="config") @@ -666,7 +669,9 @@ def test_annotated_as_subgroups(): @dataclasses.dataclass class Config(TestSetup): - model: Model = subgroups({"small": SmallModel, "big": BigModel}, default_factory=SmallModel) + model: Model = subgroups( + {"small": SmallModel, "big": BigModel}, default_factory=SmallModel + ) assert Config.setup().model == SmallModel() # Hopefully this illustrates why Annotated aren't exactly great: @@ -799,7 +804,7 @@ def test_help( @pytest.mark.parametrize("frozen", [True, False]) def test_nested_subgroups(frozen: bool): - """Assert 
that #160 is fixed: https://github.com/lebrice/SimpleParsing/issues/160""" + """Assert that #160 is fixed: https://github.com/lebrice/SimpleParsing/issues/160.""" @dataclass(frozen=frozen) class FooConfig: @@ -880,7 +885,6 @@ class Dataset2Config(DatasetConfig): @dataclass class Config(TestSetup): - # Which model to use model: ModelConfig = subgroups( {"model_a": ModelAConfig, "model_b": ModelBConfig}, diff --git a/test/test_subgroups/test_help[Config---help].md b/test/test_subgroups/test_help[Config---help].md index fbf29578..9a51d7d7 100644 --- a/test/test_subgroups/test_help[Config---help].md +++ b/test/test_subgroups/test_help[Config---help].md @@ -1,11 +1,10 @@ -# Regression file for [this test](test/test_subgroups.py:725) +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: ```python @dataclass class Config(TestSetup): - # Which model to use model: ModelConfig = subgroups( {"model_a": ModelAConfig, "model_b": ModelBConfig}, diff --git a/test/test_subgroups/test_help[Config---model=model_a --help].md b/test/test_subgroups/test_help[Config---model=model_a --help].md index 1b99968a..7d2f8970 100644 --- a/test/test_subgroups/test_help[Config---model=model_a --help].md +++ b/test/test_subgroups/test_help[Config---model=model_a --help].md @@ -1,11 +1,10 @@ -# Regression file for [this test](test/test_subgroups.py:725) +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: ```python @dataclass class Config(TestSetup): - # Which model to use model: ModelConfig = subgroups( {"model_a": ModelAConfig, "model_b": ModelBConfig}, diff --git a/test/test_subgroups/test_help[Config---model=model_b --help].md b/test/test_subgroups/test_help[Config---model=model_b --help].md index e9097f94..1e2fb4c0 100644 --- a/test/test_subgroups/test_help[Config---model=model_b --help].md +++ b/test/test_subgroups/test_help[Config---model=model_b --help].md @@ -1,11 +1,10 @@ -# Regression file for [this test](test/test_subgroups.py:725) 
+# Regression file for [this test](test/test_subgroups.py:730) Given Source code: ```python @dataclass class Config(TestSetup): - # Which model to use model: ModelConfig = subgroups( {"model_a": ModelAConfig, "model_b": ModelBConfig}, diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md index da79dce2..5b82f578 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --a 100 --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:725) +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md index a3911bd7..312b7218 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=even --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:725) +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md index 6ea4e758..cbfa7eeb 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --a 123 --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:725) +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md index 1f1560d7..890e7326 100644 --- 
a/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---conf=odd --help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:725) +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subgroups/test_help[ConfigWithFrozen---help].md b/test/test_subgroups/test_help[ConfigWithFrozen---help].md index ba951380..d05b94ef 100644 --- a/test/test_subgroups/test_help[ConfigWithFrozen---help].md +++ b/test/test_subgroups/test_help[ConfigWithFrozen---help].md @@ -1,4 +1,4 @@ -# Regression file for [this test](test/test_subgroups.py:725) +# Regression file for [this test](test/test_subgroups.py:730) Given Source code: diff --git a/test/test_subparsers.py b/test/test_subparsers.py index 2b2ed665..4add289a 100644 --- a/test/test_subparsers.py +++ b/test/test_subparsers.py @@ -16,7 +16,7 @@ @dataclass class TrainOptions: - """Training Options""" + """Training Options.""" lr: float = 1e-3 train_path: Path = Path("./train") @@ -24,7 +24,7 @@ class TrainOptions: @dataclass class ValidOptions: - """Validation Options""" + """Validation Options.""" test_path: Path = Path("./test") metric: str = "accuracy" @@ -32,7 +32,7 @@ class ValidOptions: @dataclass class GlobalOptions(TestSetup): - """Global Options""" + """Global Options.""" # mode, either Train or Valid. 
mode: Union[TrainOptions, ValidOptions] = subparsers( @@ -71,7 +71,7 @@ def test_help_text_works(): @dataclass class Start: - """Start command""" + """Start command.""" value: str = "start command value" @@ -82,7 +82,7 @@ def execute(self, verbose=False): @dataclass class Stop: - """Stop command""" + """Stop command.""" value: str = "stop command value" @@ -117,7 +117,7 @@ def execute(self, verbose=False): @dataclass class Program(TestSetup): - """Some top-level command""" + """Some top-level command.""" command: Union[Push, Pull] verbose: bool = False @@ -276,7 +276,6 @@ class Parent(TestSetup): def test_argparse_version_giving_extra_args_to_parent(): - parser = argparse.ArgumentParser() parser.add_argument("--foo", type=int, default=3) diff --git a/test/test_tuples.py b/test/test_tuples.py index a01c46cf..05e3093f 100644 --- a/test/test_tuples.py +++ b/test/test_tuples.py @@ -177,9 +177,7 @@ def test_vanilla_argparse_beheviour( ], ) def test_arg_options_created(self, field_type: Type, expected_options: Dict[str, Any]): - """Check the 'arg_options' that get created for different types of tuple - fields. 
- """ + """Check the 'arg_options' that get created for different types of tuple fields.""" parser = ArgumentParser() @dataclass diff --git a/test/test_utils.py b/test/test_utils.py index 22a82b64..72d2a2f4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -167,20 +167,20 @@ class A(TestSetup): def test_flatten(): - """Test basic functionality of flatten""" + """Test basic functionality of flatten.""" d = {"a": {"b": 2, "c": 3}, "c": {"d": 3, "e": 4}} assert flatten_join(d) == {"a.b": 2, "a.c": 3, "c.d": 3, "c.e": 4} def test_flatten_double_ref(): - """Test proper handling of double references in dicts""" + """Test proper handling of double references in dicts.""" a = {"b": 2, "c": 3} d = {"a": a, "d": {"e": a}} assert flatten_join(d) == {"a.b": 2, "a.c": 3, "d.e.b": 2, "d.e.c": 3} def test_unflatten(): - """Test than unflatten(flatten(x)) is idempotent""" + """Test than unflatten(flatten(x)) is idempotent.""" a = {"b": 2, "c": 3} d = {"a": a, "d": {"e": a}} assert unflatten_split(flatten_join(d)) == d diff --git a/test/testutils.py b/test/testutils.py index 06ddabf1..374df6e5 100644 --- a/test/testutils.py +++ b/test/testutils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import importlib.util import os import shlex import string @@ -94,7 +95,8 @@ def assert_help_output_equals(actual: str, expected: str) -> None: class TestParser(simple_parsing.ArgumentParser, Generic[T]): __test__ = False - """ A parser subclass just used for testing. + """A parser subclass just used for testing. + Makes the retrieval of the arguments a bit easier to read. 
""" @@ -293,3 +295,21 @@ def format_lists_using_double_quotes(list_of_lists: list[list[Any]]) -> str: def format_lists_using_single_quotes(list_of_lists: list[list[Any]]) -> str: return " ".join(format_list_using_single_quotes(value_list) for value_list in list_of_lists) + + +YAML_INSTALLED = importlib.util.find_spec("yaml") is not None +needs_yaml = pytest.mark.xfail( + not YAML_INSTALLED, + raises=ModuleNotFoundError, + reason="Test requires pyyaml to be installed.", +) + +TOML_INSTALLED = ( + importlib.util.find_spec("tomli") is not None + and importlib.util.find_spec("tomli_w") is not None +) +needs_toml = pytest.mark.xfail( + not TOML_INSTALLED, + raises=ModuleNotFoundError, + reason="Test requires tomli and tomli_w to be installed.", +) diff --git a/test/utils/test_flattened.py b/test/utils/test_flattened.py index 8f18c7d9..4ce5ef13 100644 --- a/test/utils/test_flattened.py +++ b/test/utils/test_flattened.py @@ -1,5 +1,4 @@ -"""Adds typed dataclasses for the "config" yaml files. -""" +"""Adds typed dataclasses for the "config" yaml files.""" import functools from dataclasses import dataclass, field from test.testutils import pytest, raises diff --git a/test/utils/test_yaml.py b/test/utils/test_yaml.py index a4bfc774..0c4e7825 100644 --- a/test/utils/test_yaml.py +++ b/test/utils/test_yaml.py @@ -1,10 +1,15 @@ -""" Tests for serialization to/from yaml files. 
""" +"""Tests for serialization to/from yaml files.""" import textwrap from dataclasses import dataclass from typing import List +import pytest + from simple_parsing import list_field -from simple_parsing.helpers.serialization import YamlSerializable + +yaml = pytest.importorskip("yaml") + +from simple_parsing.helpers.serialization.yaml_serialization import YamlSerializable # noqa: E402 @dataclass diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index 89e57280..00000000 --- a/versioneer.py +++ /dev/null @@ -1,2108 +0,0 @@ -# Version: 0.20 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions! -* https://github.com/python-versioneer/python-versioneer -* Brian Warner -* License: Public Domain -* Compatible with: Python 3.6, 3.7, 3.8, 3.9 and pypy3 -* [![Latest Version][pypi-image]][pypi-url] -* [![Build Status][travis-image]][travis-url] - -This is a tool for managing a recorded version number in distutils-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. 
- - -## Quick Install - -* `pip install versioneer` to somewhere in your $PATH -* add a `[versioneer]` section to your setup.cfg (see [Install](INSTALL.md)) -* run `versioneer install` in your source tree, commit the results -* Verify version information with `python setup.py version` - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes). 
- -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. - -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. 
- -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. - -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. 
For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on Github -[issues page](https://github.com/python-versioneer/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other languages) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. 
However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/python-versioneer/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/python-versioneer/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. - -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. 
- -[Bug #83](https://github.com/python-versioneer/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it. - - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easily extended to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - -## Similar projects - -* [setuptools_scm](https://github.com/pypa/setuptools_scm/) - a non-vendored build-time - dependency -* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of - versioneer -* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools - plugin - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. The `_version.py` that it creates is also in the public domain. 
-Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . - -[pypi-image]: https://img.shields.io/pypi/v/versioneer.svg -[pypi-url]: https://pypi.python.org/pypi/versioneer/ -[travis-image]: -https://img.shields.io/travis/com/python-versioneer/python-versioneer.svg -[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer - -""" - -import configparser -import errno -import json -import os -import re -import subprocess -import sys - - -class VersioneerConfig: # pylint: disable=too-few-public-methods # noqa - """Container for Versioneer configuration parameters.""" - - -def get_root(): - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ( - "Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND')." 
- ) - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. - my_path = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(my_path)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print( - "Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(my_path), versioneer_py) - ) - except NameError: - pass - return root - - -def get_config_from_root(root): - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . 
- setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.ConfigParser() - with open(setup_cfg) as cfg_file: - parser.read_file(cfg_file) - VCS = parser.get("versioneer", "VCS") # mandatory - - # Dict-like interface for non-mandatory entries - section = parser["versioneer"] - - # pylint:disable=attribute-defined-outside-init # noqa - cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = section.get("style", "") - cfg.versionfile_source = section.get("versionfile_source") - cfg.versionfile_build = section.get("versionfile_build") - cfg.tag_prefix = section.get("tag_prefix") - if cfg.tag_prefix in ("''", '""'): - cfg.tag_prefix = "" - cfg.parentdir_prefix = section.get("parentdir_prefix") - cfg.verbose = section.get("verbose") - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - HANDLERS.setdefault(vcs, {})[method] = f - return f - - return decorate - - -# pylint:disable=too-many-arguments,consider-using-with # noqa -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen( - [command] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print(f"unable to find command, tried 
{commands}") - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -LONG_VERSION_PY[ - "git" -] = r''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.20 (https://github.com/python-versioneer/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). 
- git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: # pylint: disable=too-few-public-methods - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -# pylint:disable=too-many-arguments,consider-using-with # noqa -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" 
%% (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. 
- keywords = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. 
The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. 
- """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. 
- branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. 
- date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post0.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 
0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post0.dev%%d" %% pieces["distance"] - else: - # exception #1 - rendered = "0.post0.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. 
- - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. 
- keywords = {} - try: - with open(versionfile_abs) as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. 
The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r"\d", r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r"\d", r): - continue - if verbose: - print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. 
- """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. 
- branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. - branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '{}' doesn't start with prefix '{}'".format( - full_tag, - tag_prefix, - ) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. 
- date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(manifest_in, versionfile_source, ipy): - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] - if ipy: - files.append(ipy) - try: - my_path = __file__ - if my_path.endswith(".pyc") or my_path.endswith(".pyo"): - my_path = os.path.splitext(my_path)[0] + ".py" - versioneer_file = os.path.relpath(my_path) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - with open(".gitattributes") as fobj: - for line in fobj: - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - break - except OSError: - pass - if not present: - with open(".gitattributes", "a+") as fobj: - fobj.write(f"{versionfile_source} export-subst\n") - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. 
We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print( - "Tried directories %s but none started with prefix %s" - % (str(rootdirs), parentdir_prefix) - ) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.20) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. - -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename): - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except OSError: - raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) - if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename, versions): - """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set {} to '{}'".format(filename, versions["version"])) - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a 
.""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post0.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 
0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post0.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose=False): - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. 
- """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. - - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print(f"got version from file {versionfile_abs} {ver}") - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - print("unable to 
compute version") - - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } - - -def get_version(): - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(cmdclass=None): - """Get the custom setuptools/distutils subclasses used by Versioneer. - - If the package uses a different cmdclass (e.g. one from numpy), it - should be provide as an argument. - """ - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to it's pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. 
- # Also see https://github.com/python-versioneer/python-versioneer/issues/52 - - cmds = {} if cmdclass is None else cmdclass.copy() - - # we add "version" to both distutils and setuptools - from distutils.core import Command - - class cmd_version(Command): - description = "report generated version string" - user_options = [] - boolean_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - - cmds["version"] = cmd_version - - # we override "build_py" in both distutils and setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? - # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? 
- - # we override different "build_py" commands for both environments - if "build_py" in cmds: - _build_py = cmds["build_py"] - elif "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py - else: - from distutils.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - cmds["build_py"] = cmd_build_py - - if "build_ext" in cmds: - _build_ext = cmds["build_ext"] - elif "setuptools" in sys.modules: - from setuptools.command.build_ext import build_ext as _build_ext - else: - from distutils.command.build_ext import build_ext as _build_ext - - class cmd_build_ext(_build_ext): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_ext.run(self) - if self.inplace: - # build_ext --inplace will only build extensions in - # build/lib<..> dir with no _version.py to write to. - # As in place builds will already have a _version.py - # in the module dir, we do not need to write one. - return - # now locate _version.py in the new build/ directory and replace - # it with an updated value - target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - cmds["build_ext"] = cmd_build_ext - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe - - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. 
- # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if "py2exe" in sys.modules: # py2exe enabled? - from py2exe.distutils_buildexe import py2exe as _py2exe - - class cmd_py2exe(_py2exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - cmds["py2exe"] = cmd_py2exe - - # we override different "sdist" commands for both environments - if "sdist" in cmds: - _sdist = cmds["sdist"] - elif "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist - else: - from distutils.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self): - versions = get_versions() - # pylint:disable=attribute-defined-outside-init # noqa - 
self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir, files): - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, self._versioneer_generated_versions) - - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. - -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -OLD_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - -INIT_PY_SNIPPET = """ -from . 
import {0} -__version__ = {0}.get_versions()['version'] -""" - - -def do_setup(): - """Do main VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write( - LONG - % { - "DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - } - ) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") - if os.path.exists(ipy): - try: - with open(ipy) as f: - old = f.read() - except OSError: - old = "" - module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] - snippet = INIT_PY_SNIPPET.format(module) - if OLD_SNIPPET in old: - print(" replacing boilerplate in %s" % ipy) - with open(ipy, "w") as f: - f.write(old.replace(OLD_SNIPPET, snippet)) - elif snippet not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(snippet) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. 
- manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in) as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except OSError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. - if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % cfg.versionfile_source) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. - do_vcs_install(manifest_in, cfg.versionfile_source, ipy) - return 0 - - -def scan_setup_py(): - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). 
Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1)