Skip to content

Commit

Permalink
Fix argument parsing of flags in the presence of allowPositional=true (
Browse files Browse the repository at this point in the history
…#66)

Should fix #58 and
#60

Previously, we allowed any arg to take positional arguments if
`allowPositional = true` (which is the case for Ammonite and Mill
user-defined entrypoints.), even `mainargs.Flag`s. for which being
positional doesn't make sense.

```scala
    val positionalArgSigs = argSigs
      .filter {
        case x: ArgSig.Simple[_, _] if x.reader.noTokens => false
        case x: ArgSig.Simple[_, _] if x.positional => true
        case x => allowPositional
      }
```

The relevant code path was rewritten in
#62, but the buggy behavior
was preserved before and after that change. This wasn't caught in other
uses of `mainargs.Flag`, e.g. for Ammonite/Mill's own flags, because
those did not set `allowPositional=true`

This PR tweaks `TokenGrouping.groupArgs` to be more discerning about how
it selects positional, keyword, and missing arguments:

1. Now, only `TokenReader.Simple[_]`s with `.positional` or
`allowPositional` can be positional; `Flag`s, `Leftover`, etc. cannot

2. Keyword arguments are limited only to `Flag`s and `Simple` with
`!a.positional`

Added `mainargs.IssueTests.issue60` as a regression test, that fails on
main and passes on this PR. Existing tests all pass
  • Loading branch information
lihaoyi committed Apr 29, 2023
1 parent 3298647 commit a9e4c5e
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 31 deletions.
2 changes: 1 addition & 1 deletion build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ trait MainArgsPublishModule extends PublishModule with CrossScalaModule with Mim

def ivyDeps = Agg(
ivy"org.scala-lang.modules::scala-collection-compat::2.8.1"
) ++ Agg(ivy"com.lihaoyi::pprint:0.8.1")
)
}

def scalaMajor(scalaVersion: String) = if (isScala3(scalaVersion)) "3" else "2"
Expand Down
2 changes: 2 additions & 0 deletions mainargs/src/Renderer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ object Renderer {
val flattenedAll: Seq[ArgSig] =
mainMethods.map(_.flattenedArgSigs)
.flatten
.map(_._1)

val leftColWidth = getLeftColWidth(flattenedAll)
mainMethods match {
case Seq() => ""
Expand Down
39 changes: 16 additions & 23 deletions mainargs/src/TokenGrouping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,22 @@ case class TokenGrouping[B](remaining: List[String], grouped: Map[ArgSig, Seq[St
object TokenGrouping {
def groupArgs[B](
flatArgs0: Seq[String],
argSigs0: Seq[ArgSig],
argSigs: Seq[(ArgSig, TokensReader.Terminal[_])],
allowPositional: Boolean,
allowRepeats: Boolean,
allowLeftover: Boolean
): Result[TokenGrouping[B]] = {
val argSigs: Seq[ArgSig] = argSigs0
.map(ArgSig.flatten(_).collect { case x: ArgSig => x })
.flatten

val positionalArgSigs = argSigs
.filter {
case x: ArgSig if x.reader.isLeftover || x.reader.isConstant => false
case x: ArgSig if x.positional => true
case x => allowPositional
}
val positionalArgSigs = argSigs.collect {
case (a, r: TokensReader.Simple[_]) if allowPositional | a.positional =>
a
}

val flatArgs = flatArgs0.toList
val keywordArgMap = argSigs
.filter { case x: ArgSig if x.positional => false; case _ => true }
.collect {
case (a, r: TokensReader.Simple[_]) if !a.positional => a
case (a, r: TokensReader.Flag) => a
}
.flatMap { x => (x.name.map("--" + _) ++ x.shortName.map("-" + _)).map(_ -> x) }
.toMap[String, ArgSig]

Expand Down Expand Up @@ -77,17 +74,13 @@ object TokenGrouping {
}
.toSeq

val missing = argSigs
.collect { case x: ArgSig => x }
.filter { x =>
x.reader match {
case r: TokensReader.Simple[_] =>
!r.allowEmpty &&
x.default.isEmpty &&
!current.contains(x)
case _ => false
}
}
val missing = argSigs.collect {
case (a, r: TokensReader.Simple[_])
if !r.allowEmpty
&& a.default.isEmpty
&& !current.contains(a) =>
a
}

val unknown = if (allowLeftover) Nil else remaining
if (missing.nonEmpty || duplicates.nonEmpty || unknown.nonEmpty) {
Expand Down
10 changes: 5 additions & 5 deletions mainargs/src/TokensReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ object ArgSig {
)
}

def flatten[T](x: ArgSig): Seq[ArgSig] = x.reader match {
case _: TokensReader.Terminal[T] => Seq(x)
def flatten[T](x: ArgSig): Seq[(ArgSig, TokensReader.Terminal[_])] = x.reader match {
case r: TokensReader.Terminal[T] => Seq((x, r))
case cls: TokensReader.Class[_] => cls.main.argSigs0.flatMap(flatten(_))
}
}
Expand Down Expand Up @@ -281,11 +281,11 @@ case class MainData[T, B](
invokeRaw: (B, Seq[Any]) => T
) {

val flattenedArgSigs: Seq[ArgSig] =
argSigs0.iterator.flatMap[ArgSig](ArgSig.flatten(_)).toVector
val flattenedArgSigs: Seq[(ArgSig, TokensReader.Terminal[_])] =
argSigs0.iterator.flatMap[(ArgSig, TokensReader.Terminal[_])](ArgSig.flatten(_)).toVector

val renderedArgSigs: Seq[ArgSig] =
flattenedArgSigs.filter(a => !a.hidden && !a.reader.isConstant)
flattenedArgSigs.collect{case (a, r) if !a.hidden && !r.isConstant => a}
}

object MainData {
Expand Down
4 changes: 2 additions & 2 deletions mainargs/test/src/CoreTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ class CoreTests(allowPositional: Boolean) extends TestSuite {
List("foo", "bar", "qux", "ex")
)
val evaledArgs = check.mains.value.map(_.flattenedArgSigs.map {
case ArgSig(name, s, docs, None, parser, _, _) => (s, docs, None, parser)
case ArgSig(name, s, docs, Some(default), parser, _, _) =>
case (ArgSig(name, s, docs, None, parser, _, _), _) => (s, docs, None, parser)
case (ArgSig(name, s, docs, Some(default), parser, _, _), _) =>
(s, docs, Some(default(CoreBase)), parser)
})

Expand Down
31 changes: 31 additions & 0 deletions mainargs/test/src/IssueTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package mainargs
import utest._

object IssueTests extends TestSuite {

object Main {
@main
def mycmd(@arg(name = "the-flag") f: mainargs.Flag = mainargs.Flag(false),
@arg str: String = "s",
args: Leftover[String]) = {
(f.value, str, args.value)
}
}

val tests = Tests {
test("issue60") {
test {
val parsed = ParserForMethods(Main)
.runEither(Seq("--str", "str", "a", "b", "c", "d"), allowPositional = true)

assert(parsed == Right((false, "str", List("a", "b", "c", "d"))))
}
test {
val parsed = ParserForMethods(Main)
.runEither(Seq("a", "b", "c", "d"), allowPositional = true)

assert(parsed == Right((false, "a", List("b", "c", "d"))))
}
}
}
}

0 comments on commit a9e4c5e

Please sign in to comment.