More flexibility to the core parsing API #34

Closed
JesseFarebro opened this issue Feb 7, 2023 · 2 comments

Comments

@JesseFarebro
Contributor

JesseFarebro commented Feb 7, 2023

Hi! Just wanted to start off by saying that the library looks great 🎉. I'm trying to assess the feasibility of Tyro for my use case, and there are a couple of suggestions you might consider:

  1. It would be great to have a way to get the parsed arguments without calling the function. One option is a function that returns something like inspect.BoundArguments once parsing is complete, which leaves it up to the caller how to call the function. Another possibility might be an API where you don't provide a function, just a dataclass that gets hydrated from the CLI. EDIT: I see now that you can parse a dataclass using tyro.cli.

  2. It would be great if there were a flag like strict that controls whether argparse parses all arguments or only known ones. For example, you could parse known arguments with Tyro, then fall back to absl to parse JAX config flags. To accomplish this, you'd need some way of knowing which arguments Tyro wasn't able to parse. This is already somewhat possible by doing:

    1. _, unknown_args = tyro.extras.get_parser(f).parse_known_args()
    2. Filter unknown_args out of sys.argv and parse them separately.
    3. tyro.cli(f, args=...) with the filtered args.

    It would be nice if there were an easier way to do this. Perhaps it could be folded into (1): if you specify strict=False, the function returns a tuple (BoundArguments, List[str]) whose second element is the list of unknown arguments.

  3. A crazy feature that might be useful to others: a way to go from a dataclass instance (or an annotation of a dataclass) to the (minimal?) CLI arguments needed to generate it, e.g., something like tyro.extras.to_cli_args. The use case is keeping all your configs and sweeps in Python; this would be especially useful for sweeps. You could have a generator over (annotated?) dataclasses and convert each one to CLI arguments when launching a job.
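The workaround in point 2, together with the (BoundArguments, List[str]) return proposed for strict=False, can be approximated with the standard library alone. The sketch below uses argparse's parse_known_args directly rather than tyro; the helper name parse_known, its kebab-case flag naming, and its restriction to parameters annotated with simple converters such as int, float, or str are all assumptions made for illustration.

```python
import argparse
import inspect
from typing import Callable, List, Optional, Tuple


def parse_known(
    f: Callable, args: Optional[List[str]] = None
) -> Tuple[inspect.BoundArguments, List[str]]:
    """Hypothetical helper (not tyro's API): bind f's arguments, return leftovers.

    Assumes every parameter is annotated with a simple converter such as
    int, float, or str.
    """
    sig = inspect.signature(f)
    # allow_abbrev=False stops argparse from matching unknown flags by prefix.
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    for name, param in sig.parameters.items():
        required = param.default is inspect.Parameter.empty
        parser.add_argument(
            "--" + name.replace("_", "-"),
            dest=name,
            type=param.annotation,
            required=required,
            default=None if required else param.default,
        )
    # Unlike parse_args(), parse_known_args() returns unrecognized flags
    # instead of exiting with an error.
    namespace, unknown = parser.parse_known_args(args)
    return sig.bind(**vars(namespace)), unknown


def train(learning_rate: float, steps: int = 10) -> None:
    print(learning_rate, steps)


bound, unknown = parse_known(
    train, ["--learning-rate", "0.1", "--jax_enable_x64", "--steps", "5"]
)
print(dict(bound.arguments))  # {'learning_rate': 0.1, 'steps': 5}
print(unknown)  # ['--jax_enable_x64']
train(*bound.args, **bound.kwargs)  # the caller decides when to call train
```

The leftover list could then be handed to a second parser (absl's, for instance). Unlike tyro, this sketch does no nested structures, help text, or type inference beyond calling the annotation.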

@brentyi
Owner

brentyi commented Feb 8, 2023

Thanks for giving the library a try!

  1. For generating BoundArguments, I think something like this should work:

    import functools
    import inspect
    import tyro
    
    def func(a: int, b: int):
        print(a + b)
    
    bind_args = functools.wraps(func)(
        lambda *args, **kwargs: inspect.signature(func).bind(*args, **kwargs)
    )
    
    args = tyro.cli(bind_args)
    print(args)

    Ideally we could just call tyro.cli() on signature(func).bind, but unfortunately this doesn't return a function with an inspectable signature. Calling wraps() without the redundant-looking lambda function also results in an error.

  2. Sure, I'm open to adding this. To match argparse naming, maybe this could be a return_unknown_args: bool flag in tyro.cli() that adds an extra output containing the unknown arguments? We may also want to consider a flag for turning off or renaming the --help flag. PRs would be appreciated here if you have time.

    In the meantime, for JAX config specifically, it seems like adding an extra field to the callable that tyro takes, then calling JAX's usual config.update(...) on it, would give cleaner error messages.
    This could even be a dict:

    import tyro
    
    
    def train(
        # train_config: YourTrainConfigDataclassCanAlsoGoHere,
        jax_config: dict = {
            "enable_x64": False,
            "debug_nans": False,
            "disable_jit": False,
        }
    ):
        import jax
    
        # Each key maps to a JAX flag, e.g. "enable_x64" -> "jax_enable_x64".
        for k, v in jax_config.items():
            jax.config.update("jax_" + k, v)
    
    
    tyro.cli(train)

    Or, to change the flags from --jax-config.enable-x64 to --jax.enable-x64:

    from typing_extensions import Annotated
    
    import tyro
    
    
    def train(
        # train_config: YourTrainConfigDataclassCanAlsoGoHere,
        jax_config: Annotated[dict, tyro.conf.arg(name="jax")] = {
            "enable_x64": False,
            "debug_nans": False,
            "disable_jit": False,
        }
    ):
        import jax
    
        # Each key maps to a JAX flag, e.g. "enable_x64" -> "jax_enable_x64".
        for k, v in jax_config.items():
            jax.config.update("jax_" + k, v)
    
    
    tyro.cli(train)

  3. I've thought about this, and getting #33 ("Smarter matching for default subcommands") merged should help. The main reason this may not be feasible is that argument parsing depends heavily on rules for converting strings from the command line into instances of annotated types. Unless we constrain the types supported by tyro and handwrite inverse rules, there will be cases where we can't invert this conversion, and I haven't yet thought of a use case compelling enough to motivate the development and maintenance effort.
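The functools.wraps trick in point 1 can be checked with the standard library alone. wraps() sets bind_args.__wrapped__ = func, and inspect.signature() follows __wrapped__ by default, so anything that inspects bind_args sees (a: int, b: int) even though the lambda itself only takes *args, **kwargs. The sketch below calls bind_args by hand in place of tyro.cli to show what tyro would hand back after parsing.

```python
import functools
import inspect


def func(a: int, b: int) -> None:
    print(a + b)


# wraps() copies func's metadata and sets __wrapped__; inspect.signature()
# follows __wrapped__, so the reported signature is func's, not the lambda's.
bind_args = functools.wraps(func)(
    lambda *args, **kwargs: inspect.signature(func).bind(*args, **kwargs)
)

print(inspect.signature(bind_args))  # (a: int, b: int) -> None

# Calling it returns a BoundArguments instead of running func, so the caller
# decides when (and whether) to call func with the parsed values.
bound = bind_args(a=1, b=2)
print(bound)  # <BoundArguments (a=1, b=2)>
func(*bound.args, **bound.kwargs)  # prints 3
```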

Does that all make sense?

@JesseFarebro
Contributor Author

  1. Interesting. When I wrote my original post I didn't fully understand tyro.cli, but I see now that it's more flexible than I thought.

  2. Makes sense. I can probably circle back and make a PR for this. The JAX config was a hypothetical; my reasons for needing this go beyond that, so it'd be nice to have it built in.

  3. Yeah, that makes sense. There's probably a variation of this that's already possible, e.g., serializing the dataclass to YAML, then loading it back and passing it to tyro.cli(..., default=...).
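For the narrow case where the inversion discussed in point 3 is well defined (a flat dataclass whose fields are int, float, or str), a to_cli_args helper along the lines floated in the original post is short to write. The sketch below is hypothetical and not part of tyro: it assumes tyro's kebab-case flag names, emits only fields that differ from their defaults (the "minimal" arguments), and raises for every other field type, which is exactly where a general inverse breaks down.

```python
import dataclasses
from typing import Any, List


@dataclasses.dataclass
class TrainConfig:
    learning_rate: float = 3e-4
    batch_size: int = 32
    optimizer: str = "adam"


def to_cli_args(instance: Any) -> List[str]:
    """Hypothetical helper (not part of tyro): dataclass instance -> CLI args.

    Emits only fields that differ from their defaults. Handles flat dataclasses
    with int, float, or str fields and raises for anything else, since the
    string conversion is not generally invertible.
    """
    args: List[str] = []
    for field in dataclasses.fields(instance):
        value = getattr(instance, field.name)
        # Exact type check, so bool (a subclass of int) is rejected too.
        if type(value) not in (int, float, str):
            raise TypeError(
                f"no inverse rule for field {field.name!r} "
                f"of type {type(value).__name__}"
            )
        if field.default is not dataclasses.MISSING and value == field.default:
            continue  # Matches the default, so the flag can be omitted.
        # Assumes kebab-case flag names, as in --jax-config.enable-x64 above.
        args += ["--" + field.name.replace("_", "-"), str(value)]
    return args


# A sweep kept in Python, converted to command lines at launch time.
for lr in (1e-3, 1e-4):
    print(["python", "train.py", *to_cli_args(TrainConfig(learning_rate=lr))])
```

str() is used for values because Python's float formatting round-trips exactly; supporting nested dataclasses, unions, or collections would each need its own handwritten rule.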

Thanks for the quick response! I'll submit a PR soon for (2).
