Skip to content

Commit

Permalink
Move the subgroup tests to a folder
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Jul 26, 2023
1 parent cd6fd81 commit 646631c
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 27 deletions.
45 changes: 29 additions & 16 deletions test/test_subgroups.py → test/subgroups/test_subgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from simple_parsing import ArgumentParser, parse, subgroups
from simple_parsing.wrappers.field_wrapper import ArgumentGenerationMode, NestedMode

from .test_choice import Color
from .testutils import TestSetup, raises_invalid_choice, raises_missing_required_arg
from ..test_choice import Color
from ..testutils import TestSetup, raises_invalid_choice, raises_missing_required_arg

TestClass = TypeVar("TestClass", bound=TestSetup)

Expand Down Expand Up @@ -189,7 +189,7 @@ def test_parse(dataclass_type: type[TestClass], args: str, expected: TestClass):


def test_subgroup_choice_is_saved_on_namespace():
"""test for https://github.com/lebrice/SimpleParsing/issues/139
"""Test for https://github.com/lebrice/SimpleParsing/issues/139.
Need to save the chosen subgroup name somewhere on the args.
"""
Expand Down Expand Up @@ -244,7 +244,6 @@ def test_two_subgroups_with_conflict(args_str: str, expected: TwoSubgroupsWithCo


def test_subgroups_with_key_default() -> None:

with pytest.raises(ValueError):
subgroups({"a": A, "b": B}, default_factory="a")

Expand All @@ -270,13 +269,17 @@ def test_subgroup_default_needs_to_be_key_in_dict():


def test_subgroup_default_factory_needs_to_be_value_in_dict():
with pytest.raises(ValueError, match="`default_factory` must be a value in the subgroups dict"):
with pytest.raises(
ValueError, match="`default_factory` must be a value in the subgroups dict"
):
_ = subgroups({"a": B, "aa": A}, default_factory=C)


def test_lambdas_dont_return_same_instance():
"""Slightly unrelated, but I just want to check if lambda expressions return the same object
instance when a default factory looks like `lambda: A()`. If so, then I won't encourage this.
instance when a default factory looks like `lambda: A()`.
If so, then I won't encourage this.
"""

@dataclass
Expand All @@ -292,8 +295,7 @@ class Config(TestSetup):

def test_partials_new_args_overwrite_set_values():
"""Double-check that functools.partial overwrites the keywords that are stored when it is
created with the ones that are passed when calling it.
"""
created with the ones that are passed when calling it."""
# just to avoid the test passing if I were to hard-code the same value as the default by
# accident.
default_a = A().a
Expand Down Expand Up @@ -423,7 +425,10 @@ class Foo(TestSetup):
marks=pytest.mark.xfail(
strict=True,
raises=NotImplementedError,
reason="Lambda expressions aren't allowed in the subgroup dict or default_factory at the moment.",
reason=(
"Lambda expressions aren't allowed in the subgroup dict or default_factory at the "
"moment."
),
),
)

Expand All @@ -438,7 +443,8 @@ class Foo(TestSetup):
],
)
def test_other_default_factories(a_factory: Callable[[], A], b_factory: Callable[[], B]):
"""Test using other kinds of default factories (i.e. functools.partial or lambda expressions)"""
"""Test using other kinds of default factories (i.e. functools.partial or lambda
expressions)"""

@dataclass
class Foo(TestSetup):
Expand Down Expand Up @@ -467,6 +473,7 @@ def test_help_string_displays_default_factory_arguments(
When using `functools.partial` or lambda expressions, we'd ideally also like the help text to
show the field values from inside the `partial` or lambda, if possible.
"""

# NOTE: Here we need to return just A() and B() with these default factories, so the defaults
# for the fields are the same
@dataclass
Expand Down Expand Up @@ -585,7 +592,6 @@ class ModelBConfig(ModelConfig):

@dataclass
class Config(TestSetup):

# Which model to use
model: ModelConfig = subgroups(
{"model_a": ModelAConfig, "model_b": ModelBConfig},
Expand All @@ -594,7 +600,7 @@ class Config(TestSetup):


def test_destination_substring_of_other_destination_issue191():
"""Test for https://github.com/lebrice/SimpleParsing/issues/191"""
"""Test for https://github.com/lebrice/SimpleParsing/issues/191."""

parser = ArgumentParser()
parser.add_arguments(Config, dest="config")
Expand Down Expand Up @@ -666,7 +672,9 @@ def test_annotated_as_subgroups():

@dataclasses.dataclass
class Config(TestSetup):
model: Model = subgroups({"small": SmallModel, "big": BigModel}, default_factory=SmallModel)
model: Model = subgroups(
{"small": SmallModel, "big": BigModel}, default_factory=SmallModel
)

assert Config.setup().model == SmallModel()
# Hopefully this illustrates why Annotated aren't exactly great:
Expand Down Expand Up @@ -722,6 +730,9 @@ def test_subgroups_supports_frozen_instances(command: str, expected: ConfigWithF
assert ConfigWithFrozen.setup(command) == expected


@pytest.mark.skipif(
sys.version_info[:2] != (3, 11), reason="The regression check is only run with Python 3.11"
)
@pytest.mark.parametrize(
("dataclass_type", "command"),
[
Expand Down Expand Up @@ -782,7 +793,10 @@ def test_help(

# @dataclasses.dataclass
# class Config(TestSetup):
# model: Model = subgroups({"small": SmallModel, "big": BigModel}, default_factory=SmallModel)
# model: Model = subgroups(
# {"small": SmallModel, "big": BigModel},
# default_factory=SmallModel,
# )

# assert Config.setup().model == SmallModel()
# # Hopefully this illustrates why Annotated aren't exactly great:
Expand All @@ -799,7 +813,7 @@ def test_help(

@pytest.mark.parametrize("frozen", [True, False])
def test_nested_subgroups(frozen: bool):
"""Assert that #160 is fixed: https://github.com/lebrice/SimpleParsing/issues/160"""
"""Assert that #160 is fixed: https://github.com/lebrice/SimpleParsing/issues/160."""

@dataclass(frozen=frozen)
class FooConfig:
Expand Down Expand Up @@ -880,7 +894,6 @@ class Dataset2Config(DatasetConfig):

@dataclass
class Config(TestSetup):

# Which model to use
model: ModelConfig = subgroups(
{"model_a": ModelAConfig, "model_b": ModelBConfig},
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Regression file for [this test](test/test_subgroups.py:725)
# Regression file for [this test](test/subgroups/test_subgroups.py:724)

Given Source code:

```python
@dataclass
class Config(TestSetup):

# Which model to use
model: ModelConfig = subgroups(
{"model_a": ModelAConfig, "model_b": ModelBConfig},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Regression file for [this test](test/test_subgroups.py:725)
# Regression file for [this test](test/subgroups/test_subgroups.py:724)

Given Source code:

```python
@dataclass
class Config(TestSetup):

# Which model to use
model: ModelConfig = subgroups(
{"model_a": ModelAConfig, "model_b": ModelBConfig},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Regression file for [this test](test/test_subgroups.py:725)
# Regression file for [this test](test/subgroups/test_subgroups.py:724)

Given Source code:

```python
@dataclass
class Config(TestSetup):

# Which model to use
model: ModelConfig = subgroups(
{"model_a": ModelAConfig, "model_b": ModelBConfig},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Regression file for [this test](test/test_subgroups.py:725)
# Regression file for [this test](test/subgroups/test_subgroups.py:724)

Given Source code:

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Regression file for [this test](test/test_subgroups.py:725)
# Regression file for [this test](test/subgroups/test_subgroups.py:724)

Given Source code:

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Regression file for [this test](test/test_subgroups.py:725)
# Regression file for [this test](test/subgroups/test_subgroups.py:724)

Given Source code:

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Regression file for [this test](test/test_subgroups.py:725)
# Regression file for [this test](test/subgroups/test_subgroups.py:724)

Given Source code:

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Regression file for [this test](test/test_subgroups.py:725)
# Regression file for [this test](test/subgroups/test_subgroups.py:724)

Given Source code:

Expand Down

0 comments on commit 646631c

Please sign in to comment.