Skip to content

Commit

Permalink
Remove hard-coded support for mainargs.Leftover/Flag/SubParser to sup…
Browse files Browse the repository at this point in the history
…port alternate implementations (#62)

This PR moves the handling of `mainargs.Leftover`/`Flag`/`SubParser`
from a hard-coded `ArgSig` that only works with `mainargs.Leftover` or
`mainargs.Flag`, to properties of the `TokensReader` that can be
configured to work with any custom type.

Should probably be reviewed concurrently with
com-lihaoyi/mill#1948, which is the motivation
for this PR: we want to be able to define a CLI entrypoing taking
`mill.define.Task[mainargs.Leftover[T]]` or equivalent, which is
currently impossible due to the hard-coded nature of `mainargs.Leftover`
(and `mainargs.Flag` etc.)

# Major Changes

1. `ArgReader` is eliminated and `ArgSig` is greatly simplified to a
single type with no subtypes or type parameters

2. `TokensReader` is split into 5primary sub-types - `.Simple`,
`Constant`, `.Flag`, `.Leftover`, and `.Class`. These roughly mirror the
original `{ArgSig,ArgReader}.{Simple,Flag,Leftover,Class}` case classes.
The 5 sub-classes control behavior through
`Renderer.scala`/`Invoker.scala`/`TokensGrouping.scala` in the same way.

The major effect of moving the logic from `{ArgSig,ArgReader}` to
`TokensReader` is that they now are no longer hard-coded to work with
`mainargs.{Flag,Leftover,Subparser}` types. Now, anyone who has a custom
type `Foo` can choose whether they want to define a
`TokensReader.Simple` for it, or whether they want to define a
`TokensReader.Leftover` or `TokensReader.Flag`. Similarly, people can
define their own `TokensReader.Class` rather than relying on the default
implementation in `mainargs.ParserForClass`.

# Testing

Tested with two new flavors of `VarargsTests` (now renamed
`VarargsBasedTests`:

1. `VarargsWrappedTests` that exercises using a custom wrapper type to
define a main entrypoints that takes `Wrapper[mainargs.Leftover[T]]`,

2. `VarargsCustomTests` that replaces `mainargs.Leftover[T]` entirely
and defines main entrypoints that take `Wrapper[T]`

3. Added a `ConstantTests.scala` to exercise the code path, which was
previously the `noTokens` codepath and un-tested in this repo

4. All existing tests pass

# Notes

1. I chose to remove the type params from `ArgSig` because they weren't
really paying for their complexity; most of the time we were passing
around `ArgSig[_, _]`s anyway, so we weren't getting type safety, but
they nevertheless caused tons of headaches trying to get types to line
up. The un-typed ` default: Option[Any => Any], reader: TokensReader[_]`
isn't great, but it's a lot easier to work with and TBH not really much
less type-safe than the status quo

2. Because `ArgSig` and `TokensReader` now have a circular dependency on
each other (via `TokensReader.Class` -> `MainData` -> `ArgSig` ->
`TokensReader`), I moved them into the same file. This both makes the
acyclic linter happy, and also kind of makes sense since they're now
part of the same recursive data structure (v.s. previously `ArgSig` was
the recursive data structure with `TokensReader`s hanging off of each
node)

3. The new structure with `TokensReader` as a `sealed trait` with 5
distinct sub-types is different from what it was before, with
`TokensReader` as a single `class` with a grab-bag of all possible
fields and callbacks. I thought the `sealed trait` approach is much
cleaner here, since they reflect exactly the data necessary 4 different
scenarios we care about, whereas otherwise we'd find some fields
meaningless in some cases e.g. `Flag` has no meaningful fields,
`Leftover` doesn't care about `noTokens` or `alwaysRepeatable` or
`allowEmpty`, etc.
  • Loading branch information
lihaoyi committed Apr 29, 2023
1 parent 3f52e88 commit 3298647
Show file tree
Hide file tree
Showing 22 changed files with 594 additions and 387 deletions.
17 changes: 8 additions & 9 deletions mainargs/src-2/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,8 @@ class Macros(val c: Context) {

q"""
new _root_.mainargs.ParserForClass(
_root_.mainargs.ClassMains[${weakTypeOf[T]}](
$route.asInstanceOf[_root_.mainargs.MainData[${weakTypeOf[T]}, Any]],
() => $companionObj
)
$route.asInstanceOf[_root_.mainargs.MainData[${weakTypeOf[T]}, Any]],
() => $companionObj
)
"""
}
Expand Down Expand Up @@ -115,16 +113,17 @@ class Macros(val c: Context) {
case _ => q"new _root_.mainargs.arg()"
}
val argSig = if (vararg) q"""
_root_.mainargs.ArgSig.createVararg[$varargUnwrappedType, $curCls](
${arg.name.decoded},
$instantiateArg,
).widen[_root_.scala.Any]
_root_.mainargs.ArgSig.create[_root_.mainargs.Leftover[$varargUnwrappedType], $curCls](
${arg.name.decoded},
$instantiateArg,
$defaultOpt
)
""" else q"""
_root_.mainargs.ArgSig.create[$varargUnwrappedType, $curCls](
${arg.name.decoded},
$instantiateArg,
$defaultOpt
).widen[_root_.scala.Any]
)
"""

c.internal.setPos(argSig, methodPos)
Expand Down
10 changes: 3 additions & 7 deletions mainargs/src-3/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ object Macros {
companionModuleType match
case '[bCompanion] =>
val mainData = createMainData[B, Any](annotatedMethod, mainAnnotationInstance)
'{
new ParserForClass[B](
ClassMains[B](${ mainData }, () => ${ Ident(companionModule).asExpr })
)
}
'{ new ParserForClass[B](${ mainData }, () => ${ Ident(companionModule).asExpr }) }
}

def createMainData[T: Type, B: Type](using Quotes)(method: quotes.reflect.Symbol, annotation: quotes.reflect.Term): Expr[MainData[T, B]] = {
Expand All @@ -63,13 +59,13 @@ object Macros {
case Some('{ $v: `t`}) => '{ Some(((_: B) => $v)) }
case None => '{ None }
}
val argReader = Expr.summon[mainargs.ArgReader[t]].getOrElse {
val tokensReader = Expr.summon[mainargs.TokensReader[t]].getOrElse {
report.throwError(
s"No mainargs.ArgReader found for parameter ${param.name}",
param.pos.get
)
}
'{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ argReader })).asInstanceOf[ArgSig[Any, B]] }
'{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ tokensReader })) }
})

val invokeRaw: Expr[(B, Seq[Any]) => T] = {
Expand Down
2 changes: 1 addition & 1 deletion mainargs/src/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class arg(
val doc: String = null,
val noDefaultName: Boolean = false,
val positional: Boolean = false,
val isHidden: Boolean = false
val hidden: Boolean = false
) extends ClassfileAnnotation

class main(val name: String = null, val doc: String = null) extends ClassfileAnnotation
111 changes: 62 additions & 49 deletions mainargs/src/Invoker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,49 @@ package mainargs

object Invoker {
def construct[T](
cep: ClassMains[T],
cep: TokensReader.Class[T],
args: Seq[String],
allowPositional: Boolean,
allowRepeats: Boolean
): Result[T] = {
TokenGrouping
.groupArgs(
args,
cep.main.argSigs,
cep.main.flattenedArgSigs,
allowPositional,
allowRepeats,
cep.main.leftoverArgSig.nonEmpty
cep.main.argSigs0.exists(_.reader.isLeftover)
)
.flatMap(invoke(cep.companion(), cep.main, _))
.flatMap((group: TokenGrouping[Any]) => invoke(cep.companion(), cep.main, group))
}

def invoke0[T, B](
base: B,
mainData: MainData[T, B],
kvs: Map[ArgSig.Named[_, B], Seq[String]],
kvs: Map[ArgSig, Seq[String]],
extras: Seq[String]
): Result[T] = {
val readArgValues: Seq[Either[Result[Any], ParamResult[_]]] =
for (a <- mainData.argSigs0) yield {
a match {
case a: ArgSig.Flag[B] =>
a.reader match {
case r: TokensReader.Flag =>
Right(ParamResult.Success(Flag(kvs.contains(a)).asInstanceOf[T]))
case a: ArgSig.Simple[T, B] => Right(makeReadCall(kvs, base, a))
case a: ArgSig.Leftover[T, B] =>
Right(makeReadVarargsCall(a, extras).map(x => Leftover(x: _*).asInstanceOf[T]))
case a: ArgSig.Class[T, B] =>
case r: TokensReader.Simple[T] => Right(makeReadCall(kvs, base, a, r))
case r: TokensReader.Constant[T] => Right(r.read() match {
case Left(s) => ParamResult.Failure(Seq(Result.ParamError.Failed(a, Nil, s)))
case Right(v) => ParamResult.Success(v)
})
case r: TokensReader.Leftover[T, _] => Right(makeReadVarargsCall(a, extras, r))
case r: TokensReader.Class[T] =>
Left(
invoke0[T, B](
a.reader.companion().asInstanceOf[B],
a.reader.main.asInstanceOf[MainData[T, B]],
r.companion().asInstanceOf[B],
r.main.asInstanceOf[MainData[T, B]],
kvs,
extras
)
)

}
}

Expand Down Expand Up @@ -79,18 +84,25 @@ object Invoker {
allowPositional: Boolean,
allowRepeats: Boolean
): Either[Result.Failure.Early, (MainData[Any, B], Result[Any])] = {
def groupArgs(main: MainData[Any, B], argsList: Seq[String]) = Right(
main,
TokenGrouping
.groupArgs(
argsList,
main.argSigs,
allowPositional,
allowRepeats,
main.leftoverArgSig.nonEmpty
)
.flatMap(Invoker.invoke(mains.base(), main, _))
)
def groupArgs(main: MainData[Any, B], argsList: Seq[String]) = {
def invokeLocal(group: TokenGrouping[Any]) =
invoke(mains.base(), main.asInstanceOf[MainData[Any, Any]], group)
Right(
main,
TokenGrouping
.groupArgs(
argsList,
main.flattenedArgSigs,
allowPositional,
allowRepeats,
main.argSigs0.exists {
case x: ArgSig => x.reader.isLeftover
case _ => false
}
)
.flatMap(invokeLocal)
)
}
mains.value match {
case Seq() => Left(Result.Failure.Early.NoMainMethodsDetected())
case Seq(main) => groupArgs(main, args)
Expand All @@ -115,10 +127,11 @@ object Invoker {
try Right(t)
catch { case e: Throwable => Left(error(e)) }
}
def makeReadCall[T, B](
dict: Map[ArgSig.Named[_, B], Seq[String]],
base: B,
arg: ArgSig.Simple[T, B]
def makeReadCall[T](
dict: Map[ArgSig, Seq[String]],
base: Any,
arg: ArgSig,
reader: TokensReader.Simple[_]
): ParamResult[T] = {
def prioritizedDefault = tryEither(
arg.default.map(_(base)),
Expand All @@ -128,14 +141,14 @@ object Invoker {
case Right(v) => ParamResult.Success(v)
}
val tokens = dict.get(arg) match {
case None => if (arg.reader.allowEmpty) Some(Nil) else None
case None => if (reader.allowEmpty) Some(Nil) else None
case Some(tokens) => Some(tokens)
}
val optResult = tokens match {
case None => prioritizedDefault
case Some(tokens) =>
tryEither(
arg.reader.read(tokens),
reader.read(tokens),
Result.ParamError.Exception(arg, tokens, _)
) match {
case Left(ex) => ParamResult.Failure(Seq(ex))
Expand All @@ -144,27 +157,27 @@ object Invoker {
case Right(Right(v)) => ParamResult.Success(Some(v))
}
}
optResult.map(_.get)
optResult.map(_.get.asInstanceOf[T])
}

def makeReadVarargsCall[T, B](
arg: ArgSig.Leftover[T, B],
values: Seq[String]
): ParamResult[Seq[T]] = {
val attempts =
for (token <- values)
yield tryEither(
arg.reader.read(Seq(token)),
Result.ParamError.Exception(arg, Seq(token), _)
) match {
case Left(x) => Left(x)
case Right(Left(errMsg)) => Left(Result.ParamError.Failed(arg, Seq(token), errMsg))
case Right(Right(v)) => Right(v)
}
def makeReadVarargsCall[T](
arg: ArgSig,
values: Seq[String],
reader: TokensReader.Leftover[_, _]
): ParamResult[T] = {
val eithers =
tryEither(
reader.read(values),
Result.ParamError.Exception(arg, values, _)
) match {
case Left(x) => Left(x)
case Right(Left(errMsg)) => Left(Result.ParamError.Failed(arg, values, errMsg))
case Right(Right(v)) => Right(v)
}

attempts.collect { case Left(x) => x } match {
case Nil => ParamResult.Success(attempts.collect { case Right(x) => x })
case bad => ParamResult.Failure(bad)
eithers match {
case Left(s) => ParamResult.Failure(Seq(s))
case Right(v) => ParamResult.Success(v.asInstanceOf[T])
}
}
}
144 changes: 0 additions & 144 deletions mainargs/src/Model.scala

This file was deleted.

0 comments on commit 3298647

Please sign in to comment.