From 19c6014d3589ddaaf56f50fbaf67de5329b7befb Mon Sep 17 00:00:00 2001 From: rsteube Date: Mon, 24 Jul 2023 09:21:38 +0200 Subject: [PATCH] traverse: fix shorthand chain --- defaultActions_test.go | 2 +- example/cmd/chain.go | 45 ++++++++++++++ example/cmd/chain_test.go | 80 +++++++++++++++++++++++++ example/cmd/root_test.go | 1 + internal/pflagfork/flag.go | 67 +++++++-------------- internal/pflagfork/flagset.go | 94 ++++++++++++++++++++++++++++-- internal/pflagfork/flagset_test.go | 48 +++++++++++++++ internalActions.go | 9 ++- traverse.go | 65 ++++++--------------- 9 files changed, 313 insertions(+), 98 deletions(-) create mode 100644 example/cmd/chain.go create mode 100644 example/cmd/chain_test.go create mode 100644 internal/pflagfork/flagset_test.go diff --git a/defaultActions_test.go b/defaultActions_test.go index 14fc8fda3..0ed257179 100644 --- a/defaultActions_test.go +++ b/defaultActions_test.go @@ -39,7 +39,7 @@ func TestActionFlags(t *testing.T) { cmd.Flag("alpha").Changed = true a := actionFlags(cmd).Invoke(Context{Value: "-a"}) - assertEqual(t, ActionValuesDescribed("b", "", "h", "help for this command").Tag("flags").NoSpace().Invoke(Context{}).Prefix("-a"), a) + assertEqual(t, ActionValuesDescribed("b", "", "h", "help for this command").Tag("flags").NoSpace('b', 'h').Invoke(Context{}).Prefix("-a"), a) } func TestActionExecCommandEnv(t *testing.T) { diff --git a/example/cmd/chain.go b/example/cmd/chain.go new file mode 100644 index 000000000..406cb3733 --- /dev/null +++ b/example/cmd/chain.go @@ -0,0 +1,45 @@ +package cmd + +import ( + "github.com/rsteube/carapace" + "github.com/spf13/cobra" +) + +var chainCmd = &cobra.Command{ + Use: "chain", + Short: "shorthand chain", + Run: func(cmd *cobra.Command, args []string) {}, + DisableFlagParsing: true, +} + +func init() { + carapace.Gen(chainCmd).Standalone() + + rootCmd.AddCommand(chainCmd) + + carapace.Gen(chainCmd).PositionalAnyCompletion( + carapace.ActionCallback(func(c carapace.Context) carapace.Action { + cmd := &cobra.Command{} + carapace.Gen(cmd).Standalone() + + cmd.Flags().CountP("count", "c", "") + cmd.Flags().BoolP("bool", "b", false, "") + cmd.Flags().StringP("value", "v", "", "") + cmd.Flags().StringP("optarg", "o", "", "") + + cmd.Flag("optarg").NoOptDefVal = " " + + carapace.Gen(cmd).FlagCompletion(carapace.ActionMap{ + "value": carapace.ActionValues("val1", "val2"), + "optarg": carapace.ActionValues("opt1", "opt2"), + }) + + carapace.Gen(cmd).PositionalCompletion( + carapace.ActionValues("p1", "positional1"), + ) + + return carapace.ActionExecute(cmd) + }), + ) + +} diff --git a/example/cmd/chain_test.go b/example/cmd/chain_test.go new file mode 100644 index 000000000..f59783825 --- /dev/null +++ b/example/cmd/chain_test.go @@ -0,0 +1,80 @@ +package cmd + +import ( + "testing" + + "github.com/rsteube/carapace" + "github.com/rsteube/carapace/pkg/sandbox" + "github.com/rsteube/carapace/pkg/style" +) + +func TestShorthandChain(t *testing.T) { + sandbox.Package(t, "github.com/rsteube/carapace/example")(func(s *sandbox.Sandbox) { + s.Run("chain", "-b"). + Expect(carapace.ActionStyledValues( + "c", style.Default, + "o", style.Yellow, + "v", style.Blue, + ).Prefix("-b"). + NoSpace('c', 'o'). + Tag("flags")) + + s.Run("chain", "-bc"). + Expect(carapace.ActionStyledValues( + "c", style.Default, + "o", style.Yellow, + "v", style.Blue, + ).Prefix("-bc"). + NoSpace('c', 'o'). + Tag("flags")) + + s.Run("chain", "-bcc"). + Expect(carapace.ActionStyledValues( + "c", style.Default, + "o", style.Yellow, + "v", style.Blue, + ).Prefix("-bcc"). + NoSpace('c', 'o'). + Tag("flags")) + + s.Run("chain", "-bcco"). + Expect(carapace.ActionStyledValues( + "c", style.Default, + "v", style.Blue, + ).Prefix("-bcco"). + NoSpace('c'). + Tag("flags")) + + s.Run("chain", "-bcco", ""). + Expect(carapace.ActionValues( + "p1", + "positional1", + )) + + s.Run("chain", "-bcco="). + Expect(carapace.ActionValues( + "opt1", + "opt2", + ).Prefix("-bcco=")) + + s.Run("chain", "-bccv", ""). + Expect(carapace.ActionValues( + "val1", + "val2", + )) + + s.Run("chain", "-bccv="). + Expect(carapace.ActionValues( + "val1", + "val2", + ).Prefix("-bccv=")) + + s.Run("chain", "-bccv", "val1", "-c"). + Expect(carapace.ActionStyledValues( + "c", style.Default, + "o", style.Yellow, + ).Prefix("-c"). + NoSpace('c', 'o'). + Tag("flags")) + }) +} diff --git a/example/cmd/root_test.go b/example/cmd/root_test.go index fcd5907af..10640150c 100644 --- a/example/cmd/root_test.go +++ b/example/cmd/root_test.go @@ -72,6 +72,7 @@ func TestRoot(t *testing.T) { "injection", "just trying to break things", ).Style(style.Magenta).Tag("test commands"), carapace.ActionValuesDescribed( + "chain", "shorthand chain", "completion", "Generate the autocompletion script for the specified shell", "group", "group example", "help", "Help about any command", diff --git a/internal/pflagfork/flag.go b/internal/pflagfork/flag.go index adf359e1b..79465df79 100644 --- a/internal/pflagfork/flag.go +++ b/internal/pflagfork/flag.go @@ -21,6 +21,8 @@ const ( type Flag struct { *pflag.Flag + Prefix string + Args []string } func (f Flag) Nargs() int { @@ -53,52 +55,6 @@ func (f Flag) IsRepeatable() bool { return false } -func (f Flag) Split(arg string) []string { - delimiter := string(f.OptargDelimiter()) - return strings.SplitAfterN(arg, delimiter, 2) -} - -func (f Flag) Matches(arg string, posix bool) bool { - if !strings.HasPrefix(arg, "-") { // not a flag - return false - } - - switch { - - case strings.HasPrefix(arg, "--"): - name := strings.TrimPrefix(arg, "--") - name = strings.SplitN(name, string(f.OptargDelimiter()), 2)[0] - - switch f.Mode() { - case ShorthandOnly, NameAsShorthand: - return false - default: - return name == f.Name - } - - case !posix: - name := strings.TrimPrefix(arg, "-") - name = strings.SplitN(name, string(f.OptargDelimiter()), 2)[0] - - if name == "" { - return false - } - - switch f.Mode() { - case ShorthandOnly: - return name == f.Shorthand - default: - return name == f.Name || name == f.Shorthand - } - - default: - if f.Shorthand != "" { - return strings.HasSuffix(arg, f.Shorthand) - } - return false - } -} - func (f Flag) TakesValue() bool { switch f.Value.Type() { case "bool", "boolSlice", "count": @@ -173,3 +129,22 @@ func (f Flag) Definition() string { return definition } + +func (f Flag) Consumes(arg string) bool { + switch { + case f.Flag == nil: + return false + case !f.TakesValue(): + return false + case f.IsOptarg(): + return false + case len(f.Args) == 0: + return true + case f.Nargs() > 1 && len(f.Args) < f.Nargs(): + return true + case f.Nargs() < 0 && !strings.HasPrefix(arg, "-"): + return true + default: + return false + } +} diff --git a/internal/pflagfork/flagset.go b/internal/pflagfork/flagset.go index e6d4bdecb..17077ebf0 100644 --- a/internal/pflagfork/flagset.go +++ b/internal/pflagfork/flagset.go @@ -29,7 +29,7 @@ func (f FlagSet) IsPosix() bool { } func (f FlagSet) IsShorthandSeries(arg string) bool { - re := regexp.MustCompile("^-(?P[^-=]+)") + re := regexp.MustCompile("^-(?P[^-].*)") return re.MatchString(arg) && f.IsPosix() } @@ -48,20 +48,106 @@ func (f FlagSet) IsMutuallyExclusive(flag *pflag.Flag) bool { func (f *FlagSet) VisitAll(fn func(*Flag)) { f.FlagSet.VisitAll(func(flag *pflag.Flag) { - fn(&Flag{flag}) + fn(&Flag{Flag: flag, Args: []string{}}) }) } func (fs FlagSet) LookupArg(arg string) (result *Flag) { isPosix := fs.IsPosix() - fs.VisitAll(func(f *Flag) { + + switch { + case strings.HasPrefix(arg, "--"): + return fs.lookupPosixLonghandArg(arg) + case isPosix: + return fs.lookupPosixShorthandArg(arg) + case !isPosix: + return fs.lookupNonPosixShorthandArg(arg) + } + return +} + +func (fs FlagSet) ShorthandLookup(name string) *Flag { + if f := fs.FlagSet.ShorthandLookup(name); f != nil { + return &Flag{ + Flag: f, + Args: []string{}, + } + } + return nil +} + +func (fs FlagSet) lookupPosixLonghandArg(arg string) (flag *Flag) { + if !strings.HasPrefix(arg, "--") { + return nil + } + + fs.VisitAll(func(f *Flag) { // TODO needs to be sorted to try longest matching first + if flag != nil || f.Mode() != Default { + return + } + + splitted := strings.SplitAfterN(arg, string(f.OptargDelimiter()), 2) + if strings.TrimSuffix(splitted[0], string(f.OptargDelimiter())) == "--"+f.Name { + flag = f + flag.Prefix = splitted[0] + if len(splitted) > 1 { + flag.Args = splitted[1:] + } + } + }) + return +} + +func (fs FlagSet) lookupPosixShorthandArg(arg string) *Flag { + if !strings.HasPrefix(arg, "-") || !fs.IsPosix() || len(arg) < 2 { + return nil + } + + for index, r := range arg[1:] { + index += 1 + flag := fs.ShorthandLookup(string(r)) + + switch { + case flag == nil: + return flag + case len(arg) == index+1: + flag.Prefix = arg + return flag + case arg[index+1] == byte(flag.OptargDelimiter()) && len(arg) > index+2: + flag.Prefix = arg[:index+2] + flag.Args = []string{arg[index+2:]} + return flag + case arg[index+1] == byte(flag.OptargDelimiter()): + flag.Prefix = arg[:index+2] + flag.Args = []string{""} + return flag + case !flag.IsOptarg() && len(arg) > index+1: + flag.Prefix = arg[:index+1] + flag.Args = []string{arg[index+1:]} + return flag + } + } + return nil +} + +func (fs FlagSet) lookupNonPosixShorthandArg(arg string) (result *Flag) { // TODO pretty much duplicates longhand lookup + if !strings.HasPrefix(arg, "-") { + return nil + } + + fs.VisitAll(func(f *Flag) { // TODO needs to be sorted to try longest matching first if result != nil { return } - if f.Matches(arg, isPosix) { + splitted := strings.SplitAfterN(arg, string(f.OptargDelimiter()), 2) + if strings.TrimSuffix(splitted[0], string(f.OptargDelimiter())) == "-"+f.Shorthand { result = f + result.Prefix = splitted[0] + if len(splitted) > 1 { + result.Args = splitted[1:] + } } }) return diff --git a/internal/pflagfork/flagset_test.go b/internal/pflagfork/flagset_test.go new file mode 100644 index 000000000..ce2c8d6ca --- /dev/null +++ b/internal/pflagfork/flagset_test.go @@ -0,0 +1,48 @@ +package pflagfork + +import ( + "reflect" + "testing" + + "github.com/spf13/pflag" +) + +func TestLookupPosixShorthandArg(t *testing.T) { + _test := func(arg, name, prefix string, args ...string) { + t.Run(arg, func(t *testing.T) { + if args == nil { + args = []string{} + } + + fs := &FlagSet{pflag.NewFlagSet("test", pflag.PanicOnError)} + + fs.BoolP("bool", "b", false, "") + fs.CountP("count", "c", "") + fs.StringP("string", "s", "", "") + + f := fs.lookupPosixShorthandArg(arg) + if f == nil || f.Name != name { + t.Fatalf("should be " + name) + } + + if f.Prefix != prefix { + t.Fatalf("prefix doesnt match actual: %#v, expected: %#v", f.Prefix, prefix) + } + + if !reflect.DeepEqual(f.Args, args) { + t.Fatalf("args dont match %v: actual: %#v expected: %#v", arg, f.Args, args) + } + + }) + } + + _test("-b=", "bool", "-b=", "") + _test("-b=t", "bool", "-b=", "t") + _test("-b=true", "bool", "-b=", "true") + _test("-ccb", "bool", "-ccb") + _test("-ccb=", "bool", "-ccb=", "") + _test("-ccb=t", "bool", "-ccb=", "t") + _test("-ccb=true", "bool", "-ccb=", "true") + _test("-ccbs=val1", "string", "-ccbs=", "val1") + _test("-ccbsval1", "string", "-ccbs", "val1") +} diff --git a/internalActions.go b/internalActions.go index 3a6060c33..3a52e660a 100644 --- a/internalActions.go +++ b/internalActions.go @@ -83,6 +83,7 @@ func actionFlags(cmd *cobra.Command) Action { flagSet := pflagfork.FlagSet{FlagSet: cmd.Flags()} isShorthandSeries := flagSet.IsShorthandSeries(c.Value) + nospace := make([]rune, 0) vals := make([]string, 0) flagSet.VisitAll(func(f *pflagfork.Flag) { switch { @@ -104,6 +105,9 @@ func actionFlags(cmd *cobra.Command) Action { } } vals = append(vals, f.Shorthand, f.Usage, f.Style()) + if f.IsOptarg() { + nospace = append(nospace, []rune(f.Shorthand)[0]) + } } } else { switch f.Mode() { @@ -120,7 +124,10 @@ func actionFlags(cmd *cobra.Command) Action { }) if isShorthandSeries { - return ActionStyledValuesDescribed(vals...).Prefix(c.Value).NoSpace('*') + if len(nospace) > 0 { + return ActionStyledValuesDescribed(vals...).Prefix(c.Value).NoSpace(nospace...) + } + return ActionStyledValuesDescribed(vals...).Prefix(c.Value) } return ActionStyledValuesDescribed(vals...).MultiParts(".") // multiparts completion for flags grouped with `.` }).Tag("flags") diff --git a/traverse.go b/traverse.go index 9c485b0b7..ed2904083 100644 --- a/traverse.go +++ b/traverse.go @@ -10,31 +10,6 @@ import ( "github.com/spf13/cobra" ) -type _inFlag struct { // TODO rename or integrate into pflagfork.Flag? - *pflagfork.Flag - // currently consumed args since encountered flag - Args []string -} - -func (f _inFlag) Consumes(arg string) bool { - switch { - case f.Flag == nil: - return false - case !f.TakesValue(): - return false - case f.IsOptarg(): - return false - case len(f.Args) == 0: - return true - case f.Nargs() > 1 && len(f.Args) < f.Nargs(): - return true - case f.Nargs() < 0 && !strings.HasPrefix(arg, "-"): - return true - default: - return false - } -} - func traverse(c *cobra.Command, args []string) (Action, Context) { LOG.Printf("traverse called for %#v with args %#v\n", c.Name(), args) storage.preRun(c, args) @@ -46,7 +21,7 @@ func traverse(c *cobra.Command, args []string) (Action, Context) { inArgs := []string{} // args consumed by current command inPositionals := []string{} // positionals consumed by current command - var inFlag *_inFlag // last encountered flag that still expects arguments + var inFlag *pflagfork.Flag // last encountered flag that still expects arguments c.LocalFlags() // TODO force c.mergePersistentFlags() which is missing from c.Flags() fs := pflagfork.FlagSet{FlagSet: c.Flags()} @@ -75,15 +50,10 @@ loop: case !c.DisableFlagParsing && strings.HasPrefix(arg, "-") && (fs.IsInterspersed() || len(inPositionals) == 0): LOG.Printf("arg %#v is a flag\n", arg) inArgs = append(inArgs, arg) - inFlag = &_inFlag{ - Flag: fs.LookupArg(arg), - Args: []string{}, - } + inFlag = fs.LookupArg(arg) - if inFlag.Flag == nil { + if inFlag == nil { LOG.Printf("flag %#v is unknown", arg) - } else if splitted := inFlag.Flag.Split(arg); len(splitted) > 1 { - inFlag.Args = append(inFlag.Args, splitted[1]) } continue @@ -117,16 +87,17 @@ loop: if inFlag != nil && len(inFlag.Args) == 0 && inFlag.Consumes("") { LOG.Printf("removing arg %#v since it is a flag missing its argument\n", toParse[len(toParse)-1]) toParse = toParse[:len(toParse)-1] - } else if (fs.IsInterspersed() || len(inPositionals) == 0) && fs.IsShorthandSeries(context.Value) { - LOG.Printf("arg %#v is a shorthand flag series", context.Value) - localInFlag := &_inFlag{ - Flag: fs.LookupArg(context.Value), - Args: []string{}, - } - if localInFlag.Consumes("") && len(context.Value) > 2 { + } else if (fs.IsInterspersed() || len(inPositionals) == 0) && fs.IsShorthandSeries(context.Value) { // TODO shorthand series isn't correct anymore (can have value attached) + LOG.Printf("arg %#v is a shorthand flag series", context.Value) // TODO not aways correct + localInFlag := fs.LookupArg(context.Value) + + if localInFlag != nil && (len(localInFlag.Args) == 0 || localInFlag.Args[0] == "") && (!localInFlag.IsOptarg() || strings.HasSuffix(localInFlag.Prefix, string(localInFlag.OptargDelimiter()))) { // TODO && len(context.Value) > 2 { + // TODO check if empty prefix LOG.Printf("removing shorthand %#v from flag series since it is missing its argument\n", localInFlag.Shorthand) - toParse = append(toParse, strings.TrimSuffix(context.Value, localInFlag.Shorthand)) + LOG.Printf("prefix %#v", localInFlag.Prefix) + toParse = append(toParse, strings.TrimSuffix(strings.TrimSuffix(localInFlag.Prefix, string(localInFlag.OptargDelimiter())), localInFlag.Shorthand)) } else { + LOG.Printf("adding shorthand flag %#v", context.Value) toParse = append(toParse, context.Value) } @@ -162,16 +133,18 @@ loop: // flag case !c.DisableFlagParsing && strings.HasPrefix(context.Value, "-") && (fs.IsInterspersed() || len(inPositionals) == 0): - if f := fs.LookupArg(context.Value); f != nil && strings.Contains(context.Value, string(f.OptargDelimiter())) { - LOG.Printf("completing optional flag argument for arg %#v\n", context.Value) - prefix := f.Split(context.Value)[0] + if f := fs.LookupArg(context.Value); f != nil && len(f.Args) > 0 { + LOG.Printf("completing optional flag argument for arg %#v with prefix %#v\n", context.Value, f.Prefix) switch f.Value.Type() { case "bool": - return ActionValues("true", "false").StyleF(style.ForKeyword).Usage(f.Usage).Prefix(prefix), context + return ActionValues("true", "false").StyleF(style.ForKeyword).Usage(f.Usage).Prefix(f.Prefix), context default: - return storage.getFlag(c, f.Name).Prefix(prefix), context + return storage.getFlag(c, f.Name).Prefix(f.Prefix), context } + } else if f != nil && fs.IsPosix() && !strings.HasPrefix(context.Value, "--") && !f.IsOptarg() && f.Prefix == context.Value { + LOG.Printf("completing attached flag argument for arg %#v with prefix %#v\n", context.Value, f.Prefix) + return storage.getFlag(c, f.Name).Prefix(f.Prefix), context } LOG.Printf("completing flags for arg %#v\n", context.Value) return actionFlags(c), context