Partials - Dynamic Config Dataclasses for arbitrary callables #156

Merged · 32 commits · Apr 19, 2023

Commits
8ca5445 Partials feature POC (lebrice, May 9, 2022)
0ae990b Functools black magic, partials are pickleable (lebrice, May 11, 2022)
4fe5e19 Partials feature POC (lebrice, May 9, 2022)
9af694a Functools black magic, partials are pickleable (lebrice, May 11, 2022)
6f2e1ab Add postponed annotation version of test (lebrice, Aug 11, 2022)
13d03d9 Apply pre-commit hooks to partial.py (lebrice, Aug 11, 2022)
a1707dc Fix example, rename typevars (lebrice, Aug 11, 2022)
6638c7a Add comments in the partials_example.py (lebrice, Aug 15, 2022)
0ed5c89 Merge branch 'master' into partials (lebrice, Jan 11, 2023)
c42ef22 Merge branch 'master' into partials (lebrice, Jan 26, 2023)
f94f785 Fix the partials_example.py file (lebrice, Feb 7, 2023)
72dfffc Add `nested_partial` helper function (lebrice, Feb 7, 2023)
a88c953 Tweak the partials_example.py (lebrice, Feb 7, 2023)
b204879 Merge branch 'master' into partials (lebrice, Feb 7, 2023)
4e86e8a Fix issue with using functools.partial[T] in py37 (lebrice, Feb 20, 2023)
dff9d35 Adding some more tests for Partial (lebrice, Feb 20, 2023)
89fdd9d Merge branch 'master' into partials (lebrice, Feb 20, 2023)
4a9c8b7 Merge branch 'master' into partials (lebrice, Mar 1, 2023)
c5e619d Simplify `partial.py` a bit (lebrice, Mar 1, 2023)
8b9bc8f Merge branch 'master' into partials (lebrice, Mar 13, 2023)
b2e39e7 Add test from PR suggestion, add `sp.config_for` (lebrice, Mar 13, 2023)
7844d4e Fix missing ``` in docstring (lebrice, Mar 13, 2023)
feba0f7 Remove torch.optim.SGD fix an old BUG comment (lebrice, Mar 14, 2023)
bd66e57 Improve docstring of `config_for` (lebrice, Mar 14, 2023)
115f258 Add `adjust_default` in __all__ (lebrice, Mar 14, 2023)
52b1c9d Fix import issue in test_partial_postponed.py (lebrice, Mar 14, 2023)
4810776 Remove kw_only which appeared in py>=3.9 (lebrice, Mar 14, 2023)
a47d485 Update regression files (idk why though?!) (lebrice, Mar 14, 2023)
18b5699 Actually use a frozen instance as default in test (lebrice, Mar 14, 2023)
657b5d7 Add `frozen` argument that gets passed through (lebrice, Mar 14, 2023)
c8c49c7 Fix doctest (lebrice, Apr 19, 2023)
cad23cc Merge branch 'master' into partials (lebrice, Apr 19, 2023)
2 changes: 2 additions & 0 deletions examples/partials/README.md
@@ -0,0 +1,2 @@
# Partials - Configuring arbitrary classes / callables

86 changes: 86 additions & 0 deletions examples/partials/partials_example.py
@@ -0,0 +1,86 @@
from __future__ import annotations

from dataclasses import dataclass

from simple_parsing import ArgumentParser
from simple_parsing.helpers import subgroups
from simple_parsing.helpers.partial import Partial, config_for


# Suppose we want to choose between the Adam and SGD optimizers from PyTorch.
# (NOTE: We don't import PyTorch here; these stand-in classes just illustrate the idea.)
class Optimizer:
    def __init__(self, params):
        ...


class Adam(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 3e-4,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-08,
    ):
        self.params = params
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps


class SGD(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 3e-4,
        weight_decay: float | None = None,
        momentum: float = 0.9,
        eps: float = 1e-08,
    ):
        self.params = params
        self.lr = lr
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.eps = eps


# Dynamically create a dataclass that will be used for the above type:
# NOTE: We could use Partial[Adam] or Partial[Optimizer], however this would treat `params` as a
# required argument.
# AdamConfig = Partial[Adam] # would treat 'params' as a required argument.
# SGDConfig = Partial[SGD] # same here
AdamConfig: type[Partial[Adam]] = config_for(Adam, ignore_args="params")
SGDConfig: type[Partial[SGD]] = config_for(SGD, ignore_args="params")


@dataclass
class Config:
    # Which optimizer to use.
    optimizer: Partial[Optimizer] = subgroups(
        {
            "sgd": SGDConfig,
            "adam": AdamConfig,
        },
        default_factory=AdamConfig,
    )


parser = ArgumentParser()
parser.add_arguments(Config, "config")
args = parser.parse_args()


config: Config = args.config
print(config)
expected = "Config(optimizer=AdamConfig(lr=0.0003, beta1=0.9, beta2=0.999, eps=1e-08))"

my_model_parameters = [123] # nn.Sequential(...).parameters()

optimizer = config.optimizer(params=my_model_parameters)
print(vars(optimizer))
expected += """
{'params': [123], 'lr': 0.0003, 'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-08}
"""
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,2 +1,2 @@
 docstring-parser~=0.15
-typing_extensions>=4.3.0
+typing_extensions>=4.5.0
4 changes: 4 additions & 0 deletions simple_parsing/__init__.py
@@ -6,8 +6,10 @@
 from .decorators import main
 from .help_formatter import SimpleHelpFormatter
 from .helpers import (
+    Partial,
     Serializable,
     choice,
+    config_for,
     field,
     flag,
     list_field,
@@ -31,6 +33,7 @@
     "ArgumentGenerationMode",
     "ArgumentParser",
     "choice",
+    "config_for",
     "ConflictResolution",
     "DashVariant",
     "field",
@@ -44,6 +47,7 @@
     "parse_known_args",
     "parse",
     "ParsingError",
+    "Partial",
     "replace",
     "Serializable",
     "SimpleHelpFormatter",
1 change: 1 addition & 0 deletions simple_parsing/helpers/__init__.py
@@ -2,6 +2,7 @@
 from .fields import *
 from .flatten import FlattenedAccess
 from .hparams import HyperParameters
+from .partial import Partial, config_for
 from .serialization import FrozenSerializable, Serializable, SimpleJsonEncoder, encode
 
 try:
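
Taken together, these changes re-export the new names both from the package root and from `simple_parsing.helpers`. A minimal sketch of the resulting import paths:

# Either import path should work after this change:
from simple_parsing import Partial, config_for
# or, equivalently:
# from simple_parsing.helpers import Partial, config_for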
48 changes: 48 additions & 0 deletions simple_parsing/helpers/nested_partial.py
@@ -0,0 +1,48 @@
import functools
from typing import Any, Generic, TypeVar

_T = TypeVar("_T")


class npartial(functools.partial, Generic[_T]):
"""Partial that also invokes partials in args and kwargs before feeding them to the function.

Useful for creating nested partials, e.g.:


>>> from dataclasses import dataclass, field
>>> @dataclass
... class Value:
... v: int = 0
>>> @dataclass
... class ValueWrapper:
... value: Value
...
>>> from functools import partial
>>> @dataclass
... class WithRegularPartial:
... wrapped: ValueWrapper = field(
... default_factory=partial(ValueWrapper, value=Value(v=123)),
... )

Here's the problem: This here is BAD! They both share the same instance of Value!

>>> WithRegularPartial().wrapped.value is WithRegularPartial().wrapped.value
True
>>> @dataclass
... class WithNPartial:
... wrapped: ValueWrapper = field(
... default_factory=npartial(ValueWrapper, value=npartial(Value, v=123)),
... )
>>> WithNPartial().wrapped.value is WithNPartial().wrapped.value
False

This is fine now!
"""

def __call__(self, *args: Any, **keywords: Any) -> _T:
keywords = {**self.keywords, **keywords}
args = self.args + args
args = tuple(arg() if isinstance(arg, npartial) else arg for arg in args)
keywords = {k: v() if isinstance(v, npartial) else v for k, v in keywords.items()}
return self.func(*args, **keywords)
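
As a standalone illustration of the guarantee the doctest above checks, here is a short sketch (not part of the module) reusing the `Value` and `ValueWrapper` dataclasses defined in the docstring:

# Each call re-invokes the inner npartial, so every ValueWrapper gets a
# freshly constructed Value, unlike with a plain functools.partial default.
make_wrapper = npartial(ValueWrapper, value=npartial(Value, v=123))
a, b = make_wrapper(), make_wrapper()
assert a.value is not b.value
assert a.value.v == b.value.v == 123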