Multiple class choices & initialization functions #32
Thanks for giving the library a try! It won't look as clean as your factory methods, but how do you feel about defining a dataclass to replace each factory method, something like:

```python
@dataclass
class BarFromX:
    arg1: int
    arg2: int

    def instantiate(self) -> Bar:
        return Bar(self.arg1, self.arg2, "X")
```

And then taking the union over all of them + calling

I've thought about factory methods like what you've described a few times (#30 is an attempt at something related). Basic support in

```python
def main(
    experiment_name: str = "experiment",
    config: Config = Config(...),
) -> None:
    pass

tyro.cli(main)
```

When If

```python
def main(
    experiment_name: str = "experiment",
    config: Config = Config.from_args(arg1=3, arg2=5),
) -> None:
    pass

tyro.cli(main)
```

Intuitively, this should create a Note that |
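A minimal, self-contained sketch of the dataclass-per-factory idea above, using only the standard library. `Bar`, `BarFromY` and `main` are hypothetical illustrations: each constructor becomes its own dataclass, a `Union` collects them, and a single `instantiate()` call builds the object regardless of which variant was chosen.

```python
from dataclasses import dataclass
from typing import Union


class Bar:
    """Hypothetical target class that several constructors can build."""

    def __init__(self, arg1: int, arg2: int, source: str) -> None:
        self.arg1 = arg1
        self.arg2 = arg2
        self.source = source


@dataclass
class BarFromX:
    arg1: int
    arg2: int

    def instantiate(self) -> Bar:
        return Bar(self.arg1, self.arg2, "X")


@dataclass
class BarFromY:
    # Hypothetical second constructor with a different argument set.
    total: int

    def instantiate(self) -> Bar:
        return Bar(self.total // 2, self.total - self.total // 2, "Y")


# One dataclass per constructor; the union lists every available constructor choice.
BarConfig = Union[BarFromX, BarFromY]


def main(config: BarConfig) -> Bar:
    # Every variant exposes the same instantiate() method, so no branching is needed here.
    return config.instantiate()
```

Passing `main` to `tyro.cli(main)` should then expose `BarFromX` and `BarFromY` as subcommands, since tyro turns a `Union` of dataclasses into subcommands; the exact subcommand names it generates may depend on the tyro version.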
Hi Brent, thanks a lot for the thorough reply! The idea of combining With

```python
# constructor: Foo.from_X, Foo.from_Y, Bar.from_X, ...
# dynamic_config_class_name: Foo, Bar, ...
cfg = builds(constructor, populate_full_signature=True, zen_dataclass={'cls_name': dynamic_config_class_name})
```

Then I aggregate them together in a list to define a new

```python
# FooBar can combine different ctor configs that can build either Foo or Bar
FooBar = TypeVar(group_name, cfg1, cfg2, cfg3, ...)

def my_main(foobar: FooBar) -> None:
    my_foo_or_bar_instance = instantiate(foobar)  # hydra-zen for cfg -> instance here

tyro.cli(my_main)  # let tyro handle cli
```

The only limitation I have so far is somewhat cosmetic. The usage of subcommands means that:

Are there any hooks in place to customize subcommand / help behaviors? |
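To make the hydra-zen step above concrete, here is a rough standard-library approximation of what generating a config dataclass from a constructor involves. It is not hydra-zen: `builds_like` and `instantiate_like` are hypothetical names that loosely mimic `builds(..., populate_full_signature=True)` and `instantiate` by reading the constructor's signature with `inspect.signature` and generating a dataclass with `dataclasses.make_dataclass`.

```python
import dataclasses
import inspect
from typing import Any, Callable


def builds_like(fn: Callable[..., Any], cls_name: str) -> type:
    """Hypothetical stand-in for hydra_zen.builds(fn, populate_full_signature=True)."""
    field_specs = []
    for name, param in inspect.signature(fn).parameters.items():
        annotation = Any if param.annotation is inspect.Parameter.empty else param.annotation
        if param.default is inspect.Parameter.empty:
            field_specs.append((name, annotation))
        else:
            field_specs.append((name, annotation, dataclasses.field(default=param.default)))
    config_cls = dataclasses.make_dataclass(cls_name, field_specs)
    # Remember which callable to invoke; staticmethod keeps it from being re-bound.
    config_cls._target_ = staticmethod(fn)
    return config_cls


def instantiate_like(config: Any) -> Any:
    """Hypothetical stand-in for hydra_zen.instantiate: call the target with the field values."""
    kwargs = {f.name: getattr(config, f.name) for f in dataclasses.fields(config)}
    return config._target_(**kwargs)


class Foo:
    def __init__(self, description: str) -> None:
        self.description = description

    @classmethod
    def from_X(cls, a: int, b: int = 2) -> "Foo":
        return cls(f"X: {a + b}")

    @classmethod
    def from_Y(cls, c: str) -> "Foo":
        return cls(f"Y: {c}")


# One generated dataclass per constructor (a bound classmethod's signature omits `cls`).
FooFromX = builds_like(Foo.from_X, "FooFromX")
FooFromY = builds_like(Foo.from_Y, "FooFromY")
```

The generated classes are ordinary dataclasses, which is why they can be collected into a union and handed to tyro like hand-written ones.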
Glad you got that working! Depending on how much you care about writing "correct" type signatures, you might consider replacing the

```python
from typing import Type, TYPE_CHECKING, Union

from hydra_zen.typing import Builds

if TYPE_CHECKING:
    # For type checking, use a statically analyzable type.
    FooBarConfig = Builds[Type[Union[FooConfig, BarConfig]]]
else:
    # At runtime, use the dynamic dataclasses. This will be what's consumed by `tyro`.
    FooBarConfig = Union[cfg1, cfg2, etc]
```

This is gross but will fix static type resolution + tab complete for

For your two questions:
|
Thanks again! The

```python
dynamic_types = call_inspection_func()  # this one returns [cfg1, cfg2, cfg3...]
FooBar = TypeVar(group_name, *dynamic_types)  # type: ignore
# or alternatively to "fool" the warnings:
# T = TypeVar('T', dynamic_types[0], *dynamic_types[1:])
```

Union doesn't really like the asterisk (*) operator, which is why I ended up using

```python
Union[TypeVar(group_name, *dynamic_types).__constraints__]  # type: ignore
```

One tradeoff, maybe, is to opt for your mode if users explicitly generated dataclasses for everything. |
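On the asterisk problem: subscripting `Union` with a tuple is equivalent to passing the members comma-separated, so a dynamically built list can go in as `Union[tuple(dynamic_types)]` without a `TypeVar` detour. A small sketch; the generated `Cfg0`, `Cfg1`, `Cfg2` dataclasses are hypothetical stand-ins for `cfg1, cfg2, cfg3`:

```python
from dataclasses import make_dataclass
from typing import TypeVar, Union, get_args

# Hypothetical stand-ins for the dynamically generated config dataclasses.
dynamic_types = [make_dataclass(f"Cfg{i}", [("x", int)]) for i in range(3)]

# Subscripting with a tuple is equivalent to Union[Cfg0, Cfg1, Cfg2].
FooBar = Union[tuple(dynamic_types)]
members = get_args(FooBar)

# The TypeVar route from the comment above yields the same union via __constraints__.
via_typevar = Union[TypeVar("FooBar", *dynamic_types).__constraints__]
```

One edge case applies to both routes: with a single generated class, `Union[(Cfg0,)]` collapses to `Cfg0` itself, and `TypeVar` rejects a single constraint with a `TypeError`, so the one-constructor case needs separate handling either way.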
For replacing the TypeVar with a Union, how about

Another suggestion is that you could generate the dynamic union type from the statically analyzable one. This might require less boilerplate. I tried mocking something up for this, which works and seems OK:

```python
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, Type, Union, get_args, get_origin

import hydra_zen
import tyro
from hydra_zen.typing import Builds
from typing_extensions import Annotated, reveal_type


class Foo:
    @classmethod
    def from_X(cls, a: int, b: int) -> Foo:
        return Foo()

    @classmethod
    def from_Y(cls, c: int, d: int) -> Foo:
        return Foo()


class Bar:
    @classmethod
    def from_X(cls, a: int, b: int) -> Bar:
        return Bar()

    @classmethod
    def from_Y(cls, c: int, d: int) -> Bar:
        return Bar()


def dynamic_union_from_static_union(typ: Type[Builds[Type]]) -> Any:
    # Builds[Type[Foo | Bar]] => Type[Foo | Bar]
    (typ,) = get_args(typ)

    # Type[Foo | Bar] => Foo | Bar
    assert get_origin(typ) is type
    (union_type,) = get_args(typ)

    # Foo | Bar => Foo, Bar
    config_types = get_args(union_type)

    # Get constructors: `from_*` classmethods bound to the class itself.
    constructors = []
    for config_type in config_types:
        constructors.extend(
            [
                method
                for name, method in inspect.getmembers(
                    config_type, predicate=inspect.ismethod
                )
                if name.startswith("from_")
                and hasattr(method, "__self__")
                and method.__self__ is config_type
            ]
        )

    # Return union over dynamic dataclasses, one for each constructor type.
    return Union.__getitem__(  # type: ignore
        tuple(
            Annotated[
                # Create the dynamic dataclass.
                hydra_zen.builds(c, populate_full_signature=True),
                # Rename the subcommand.
                tyro.conf.subcommand(
                    c.__self__.__name__.lower() + "_" + c.__name__.lower(),
                    prefix_name=False,
                ),
            ]
            for c in constructors
        )
    )


Config = Builds[Type[Union[Foo, Bar]]]
if not TYPE_CHECKING:
    Config = dynamic_union_from_static_union(Config)


def main(config: Config) -> None:
    # Should resolve to `Bar | Foo`.
    reveal_type(hydra_zen.instantiate(config))


if __name__ == "__main__":
    tyro.cli(main)
```

Documentation and

For the helptext stuff, I'm guessing you could figure this out yourself, but the custom argparse formatter is probably what you want to look at! |
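The constructor-discovery step inside `dynamic_union_from_static_union` can be checked on its own, without hydra-zen or tyro installed. This standard-library sketch repeats only the unwrapping (`Type[Union[Foo, Bar]]` => `Foo, Bar`), the `from_*` classmethod lookup, and the `<class>_<method>` subcommand naming; `discover_constructors` and `subcommand_name` are hypothetical helper names.

```python
import inspect
from typing import Any, Callable, List, Type, Union, get_args, get_origin


class Foo:
    @classmethod
    def from_X(cls, a: int, b: int) -> "Foo":
        return cls()

    @classmethod
    def from_Y(cls, c: int, d: int) -> "Foo":
        return cls()


class Bar:
    @classmethod
    def from_X(cls, a: int, b: int) -> "Bar":
        return cls()


def discover_constructors(typ: Any) -> List[Callable[..., Any]]:
    # Type[Foo | Bar] => Foo | Bar => (Foo, Bar)
    assert get_origin(typ) is type
    (union_type,) = get_args(typ)
    constructors = []
    for config_type in get_args(union_type):
        # Classmethods accessed on the class are bound methods whose __self__ is the class.
        # inspect.getmembers returns members sorted by name.
        for name, method in inspect.getmembers(config_type, predicate=inspect.ismethod):
            if name.startswith("from_") and method.__self__ is config_type:
                constructors.append(method)
    return constructors


def subcommand_name(constructor: Callable[..., Any]) -> str:
    # Same naming rule as the mock-up: Foo.from_X -> "foo_from_x".
    return constructor.__self__.__name__.lower() + "_" + constructor.__name__.lower()


constructors = discover_constructors(Type[Union[Foo, Bar]])
names = [subcommand_name(c) for c in constructors]
```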
@brentyi Thanks for all the useful advice again! I finally took care of the flat

```python
import argparse

import tyro._argparse_formatter


def _is_subparser_action(action):
    return isinstance(action, argparse._SubParsersAction)


def _is_help_action(action):
    return isinstance(action, argparse._HelpAction)


class TyroFlatSubcommandHelpFormatter(tyro._argparse_formatter.TyroArgparseHelpFormatter):
    def add_usage(self, usage, actions, groups, prefix=None):
        aggregated_subcommand_group = []
        for action_name, sub_parser in self.collect_subcommands_parsers(actions).items():
            for sub_action_group in sub_parser._action_groups:
                sub_group_actions = sub_action_group._group_actions
                if len(sub_group_actions) > 0:
                    if any(_is_subparser_action(a) and not _is_help_action(a) for a in sub_group_actions):
                        aggregated_subcommand_group.append(sub_action_group)

        # Remove duplicate subcommand parsers
        aggregated_subcommand_group = list(
            {a._group_actions[0].metavar: a for a in aggregated_subcommand_group}.values()
        )
        next_actions = [g._group_actions[0] for g in aggregated_subcommand_group]
        # Build a new list: `actions` is the parser's own list, so extending it in place
        # would keep adding actions to the parser every time help is formatted.
        actions = list(actions) + next_actions
        super().add_usage(usage, actions, groups, prefix)

    def add_arguments(self, action_group):
        if len(action_group) > 0 and action_group[0].container.title == "subcommands":
            # If a subcommands action group - rename first subcommand (for which this function was invoked)
            choices_header = next(iter(action_group[0].choices))
            choices_title = choices_header.split(":")[0] + " choices"
            action_group[0].container.title = choices_title
            # The formatter has already started a section, so override its heading.
            self._current_section.heading = choices_title

        # Invoke default
        super().add_arguments(action_group)

        aggregated_action_group = []
        aggregated_subcommand_group = []
        for action in action_group:
            if not _is_subparser_action(action):
                continue
            for action_name, sub_parser in self.collect_subcommands_parsers([action]).items():
                # formatter_class must be a class (argparse instantiates it), not a formatter instance.
                sub_parser.formatter_class = type(self)
                for sub_action_group in sub_parser._action_groups:
                    sub_group_actions = sub_action_group._group_actions
                    if len(sub_group_actions) > 0:
                        if any(not _is_subparser_action(a) and not _is_help_action(a) for a in sub_group_actions):
                            # Regular arguments: title the group after its subcommand.
                            for a in sub_group_actions:
                                a.container.title = action_name + " arguments"
                            aggregated_action_group.append(sub_action_group)
                        elif any(not _is_help_action(a) for a in sub_group_actions):
                            # Nested subcommands: title the group after the first choice's prefix.
                            choices_header = next(iter(sub_group_actions[0].choices))
                            for a in sub_group_actions:
                                a.container.title = choices_header.split(":")[0] + " choices"
                            aggregated_subcommand_group.append(sub_action_group)

        # Remove duplicate subcommand parsers
        aggregated_subcommand_group = list(
            {a._group_actions[0].metavar: a for a in aggregated_subcommand_group}.values()
        )
        for aggregated_group in (aggregated_subcommand_group, aggregated_action_group):
            for next_action_group in aggregated_group:
                self.end_section()
                self.start_section(next_action_group.title)
                self.add_text(next_action_group.description)
                super().add_arguments(next_action_group._group_actions)

    def collect_subcommands_parsers(self, actions):
        collected_titles = []
        collected_subparsers = []
        parsers = []

        def _handle_actions(_actions):
            action_choices = [action.choices for action in _actions if _is_subparser_action(action)]
            for choices in action_choices:
                for subcommand, subcommand_parser in choices.items():
                    collected_titles.append(subcommand)
                    collected_subparsers.append(subcommand_parser)
                    parsers.append(subcommand_parser)

        # Breadth-first walk over the (possibly nested) subcommand tree.
        _handle_actions(actions)
        while parsers:
            parser = parsers.pop(0)
            _handle_actions(parser._actions)

        # Eliminate duplicates and preserve order (dicts are guaranteed to preserve insertion order from Python >= 3.7)
        return dict(zip(collected_titles, collected_subparsers))
```

I can test it like this:

```python
parser = tyro.extras.get_parser(Config)
parser.formatter_class = TyroFlatSubcommandHelpFormatter
parser.print_help()
```

but |
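For lighter-weight helptext tweaks than the full formatter above, plain `argparse` already routes every section heading through `HelpFormatter.start_section`, so a subclass can rename headings in one place. This sketch uses only the standard-library `argparse.HelpFormatter` (it does not subclass tyro's private formatter), and `HEADING_MAP` is a hypothetical lookup table:

```python
import argparse


class RenamingHelpFormatter(argparse.HelpFormatter):
    """Rename section headings via a lookup table; other headings pass through unchanged."""

    # Hypothetical mapping from argparse group titles to the headings we want printed.
    HEADING_MAP = {"subcommands": "constructor choices"}

    def start_section(self, heading):
        super().start_section(self.HEADING_MAP.get(heading, heading))


# formatter_class takes the class itself; argparse creates a new instance per help call.
parser = argparse.ArgumentParser(prog="demo", formatter_class=RenamingHelpFormatter)
subparsers = parser.add_subparsers(title="subcommands")
subparsers.add_parser("foo-from-x")
subparsers.add_parser("bar-from-x")

help_text = parser.format_help()
```

Overriding `start_section` leaves argument formatting untouched, which is the main difference from also overriding `add_arguments` as the formatter above does.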
Cool! Yeah, I guess the hacky short-term solution is a monkey patch?

```python
tyro._argparse_formatter.TyroArgparseHelpFormatter = TyroFlatSubcommandHelpFormatter
tyro.cli(...)
```

Accepting + supporting custom formatters seems like a can of worms that I'm not sure we want to open...! |
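If the monkey patch should not leak into the rest of the program, `unittest.mock.patch.object` can scope it to a `with` block and restore the original afterwards. The sketch below uses a stand-in module (the hypothetical `fake_formatter_module`) instead of tyro's private `_argparse_formatter`, and it assumes, as the monkey patch above does, that the library looks the formatter class up on that module at call time.

```python
import types
from unittest import mock


class DefaultFormatter:
    pass


class FlatFormatter(DefaultFormatter):
    pass


# Hypothetical stand-in for tyro._argparse_formatter.
fake_formatter_module = types.ModuleType("fake_formatter_module")
fake_formatter_module.TyroArgparseHelpFormatter = DefaultFormatter


def current_formatter() -> type:
    # Reads the module attribute at call time, which is what makes monkey patching work.
    return fake_formatter_module.TyroArgparseHelpFormatter


with mock.patch.object(fake_formatter_module, "TyroArgparseHelpFormatter", FlatFormatter):
    inside = current_formatter()  # the patched class; tyro.cli(...) would be called here
outside = current_formatter()  # the original class, restored when the block exits
```

With tyro, the equivalent would patch `tyro._argparse_formatter` and call `tyro.cli(...)` inside the `with` block; since that module is private, its layout may change between releases.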
That makes sense!
|
That makes sense! A marker makes sense, but without documentation it would imply some level of fine-grained control, for example if we have a deeply nested subcommand tree and apply the annotation at an intermediate level. Is this possible to implement? For aesthetic things like this, global state might also be okay, like the `tyro.extras.set_accent_color` function we currently have: https://brentyi.github.io/tyro/api/tyro/extras/#tyro.extras.set_accent_color We could broaden this a bit into something like |
Hi tyro team,

Thanks for a great library! `tyro` is a blast and really helps simplify config code so far. There is a convoluted scenario I'm still trying to figure out how to apply tyro to:

- `Foo` and `Bar` are interchangeable classes (e.g. think of two different dataset implementations), and each can be constructed with different constructors.
- We have a function which generates a `dataclass` automatically from functions / class `__init__`.

We want our config system to be able to:

- Decide whether a `Foo` or `Bar` config should be created, something similar to the Subcommands except we don't have explicit dataclasses (as they're constructed dynamically).
- Support multiple `from_..` construction methods within the same config, but still allow the CLI to explicitly show these are separate argument groups (i.e. `Foo.fromX` has `arg1`, `arg2` and `Foo.fromY` has `arg3` and `arg4`).

What's the best practice to approach this with `tyro`?

Thanks in advance!