diff --git a/mainargs/src-2/Macros.scala b/mainargs/src-2/Macros.scala index 04b6e8f..f5b3574 100755 --- a/mainargs/src-2/Macros.scala +++ b/mainargs/src-2/Macros.scala @@ -37,10 +37,8 @@ class Macros(val c: Context) { q""" new _root_.mainargs.ParserForClass( - _root_.mainargs.ClassMains[${weakTypeOf[T]}]( - $route.asInstanceOf[_root_.mainargs.MainData[${weakTypeOf[T]}, Any]], - () => $companionObj - ) + $route.asInstanceOf[_root_.mainargs.MainData[${weakTypeOf[T]}, Any]], + () => $companionObj ) """ } @@ -115,16 +113,17 @@ class Macros(val c: Context) { case _ => q"new _root_.mainargs.arg()" } val argSig = if (vararg) q""" - _root_.mainargs.ArgSig.createVararg[$varargUnwrappedType, $curCls]( - ${arg.name.decoded}, - $instantiateArg, - ).widen[_root_.scala.Any] + _root_.mainargs.ArgSig.create[_root_.mainargs.Leftover[$varargUnwrappedType], $curCls]( + ${arg.name.decoded}, + $instantiateArg, + $defaultOpt + ) """ else q""" _root_.mainargs.ArgSig.create[$varargUnwrappedType, $curCls]( ${arg.name.decoded}, $instantiateArg, $defaultOpt - ).widen[_root_.scala.Any] + ) """ c.internal.setPos(argSig, methodPos) diff --git a/mainargs/src-3/Macros.scala b/mainargs/src-3/Macros.scala index 4145ca5..6cd75ab 100644 --- a/mainargs/src-3/Macros.scala +++ b/mainargs/src-3/Macros.scala @@ -41,11 +41,7 @@ object Macros { companionModuleType match case '[bCompanion] => val mainData = createMainData[B, Any](annotatedMethod, mainAnnotationInstance) - '{ - new ParserForClass[B]( - ClassMains[B](${ mainData }, () => ${ Ident(companionModule).asExpr }) - ) - } + '{ new ParserForClass[B](${ mainData }, () => ${ Ident(companionModule).asExpr }) } } def createMainData[T: Type, B: Type](using Quotes)(method: quotes.reflect.Symbol, annotation: quotes.reflect.Term): Expr[MainData[T, B]] = { @@ -63,13 +59,13 @@ object Macros { case Some('{ $v: `t`}) => '{ Some(((_: B) => $v)) } case None => '{ None } } - val argReader = Expr.summon[mainargs.ArgReader[t]].getOrElse { + val tokensReader = Expr.summon[mainargs.TokensReader[t]].getOrElse { report.throwError( s"No mainargs.ArgReader found for parameter ${param.name}", param.pos.get ) } - '{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ argReader })).asInstanceOf[ArgSig[Any, B]] } + '{ (ArgSig.create[t, B](${ Expr(param.name) }, ${ arg }, ${ defaultParam })(using ${ tokensReader })) } }) val invokeRaw: Expr[(B, Seq[Any]) => T] = { diff --git a/mainargs/src/Annotations.scala b/mainargs/src/Annotations.scala index aad5b88..927e86a 100644 --- a/mainargs/src/Annotations.scala +++ b/mainargs/src/Annotations.scala @@ -7,7 +7,7 @@ class arg( val doc: String = null, val noDefaultName: Boolean = false, val positional: Boolean = false, - val isHidden: Boolean = false + val hidden: Boolean = false ) extends ClassfileAnnotation class main(val name: String = null, val doc: String = null) extends ClassfileAnnotation diff --git a/mainargs/src/Invoker.scala b/mainargs/src/Invoker.scala index a543f25..a937d66 100644 --- a/mainargs/src/Invoker.scala +++ b/mainargs/src/Invoker.scala @@ -2,7 +2,7 @@ package mainargs object Invoker { def construct[T]( - cep: ClassMains[T], + cep: TokensReader.Class[T], args: Seq[String], allowPositional: Boolean, allowRepeats: Boolean @@ -10,36 +10,41 @@ object Invoker { TokenGrouping .groupArgs( args, - cep.main.argSigs, + cep.main.flattenedArgSigs, allowPositional, allowRepeats, - cep.main.leftoverArgSig.nonEmpty + cep.main.argSigs0.exists(_.reader.isLeftover) ) - .flatMap(invoke(cep.companion(), cep.main, _)) + .flatMap((group: TokenGrouping[Any]) => invoke(cep.companion(), cep.main, group)) } + def invoke0[T, B]( base: B, mainData: MainData[T, B], - kvs: Map[ArgSig.Named[_, B], Seq[String]], + kvs: Map[ArgSig, Seq[String]], extras: Seq[String] ): Result[T] = { val readArgValues: Seq[Either[Result[Any], ParamResult[_]]] = for (a <- mainData.argSigs0) yield { - a match { - case a: ArgSig.Flag[B] => + a.reader match { + case r: TokensReader.Flag => Right(ParamResult.Success(Flag(kvs.contains(a)).asInstanceOf[T])) - case a: ArgSig.Simple[T, B] => Right(makeReadCall(kvs, base, a)) - case a: ArgSig.Leftover[T, B] => - Right(makeReadVarargsCall(a, extras).map(x => Leftover(x: _*).asInstanceOf[T])) - case a: ArgSig.Class[T, B] => + case r: TokensReader.Simple[T] => Right(makeReadCall(kvs, base, a, r)) + case r: TokensReader.Constant[T] => Right(r.read() match { + case Left(s) => ParamResult.Failure(Seq(Result.ParamError.Failed(a, Nil, s))) + case Right(v) => ParamResult.Success(v) + }) + case r: TokensReader.Leftover[T, _] => Right(makeReadVarargsCall(a, extras, r)) + case r: TokensReader.Class[T] => Left( invoke0[T, B]( - a.reader.companion().asInstanceOf[B], - a.reader.main.asInstanceOf[MainData[T, B]], + r.companion().asInstanceOf[B], + r.main.asInstanceOf[MainData[T, B]], kvs, extras ) ) + } } @@ -79,18 +84,25 @@ object Invoker { allowPositional: Boolean, allowRepeats: Boolean ): Either[Result.Failure.Early, (MainData[Any, B], Result[Any])] = { - def groupArgs(main: MainData[Any, B], argsList: Seq[String]) = Right( - main, - TokenGrouping - .groupArgs( - argsList, - main.argSigs, - allowPositional, - allowRepeats, - main.leftoverArgSig.nonEmpty - ) - .flatMap(Invoker.invoke(mains.base(), main, _)) - ) + def groupArgs(main: MainData[Any, B], argsList: Seq[String]) = { + def invokeLocal(group: TokenGrouping[Any]) = + invoke(mains.base(), main.asInstanceOf[MainData[Any, Any]], group) + Right( + main, + TokenGrouping + .groupArgs( + argsList, + main.flattenedArgSigs, + allowPositional, + allowRepeats, + main.argSigs0.exists { + case x: ArgSig => x.reader.isLeftover + case _ => false + } + ) + .flatMap(invokeLocal) + ) + } mains.value match { case Seq() => Left(Result.Failure.Early.NoMainMethodsDetected()) case Seq(main) => groupArgs(main, args) @@ -115,10 +127,11 @@ object Invoker { try Right(t) catch { case e: Throwable => Left(error(e)) } } - def makeReadCall[T, B]( - dict: Map[ArgSig.Named[_, B], Seq[String]], - base: B, - arg: ArgSig.Simple[T, B] + def makeReadCall[T]( + dict: Map[ArgSig, Seq[String]], + base: Any, + arg: ArgSig, + reader: TokensReader.Simple[_] ): ParamResult[T] = { def prioritizedDefault = tryEither( arg.default.map(_(base)), @@ -128,14 +141,14 @@ object Invoker { case Right(v) => ParamResult.Success(v) } val tokens = dict.get(arg) match { - case None => if (arg.reader.allowEmpty) Some(Nil) else None + case None => if (reader.allowEmpty) Some(Nil) else None case Some(tokens) => Some(tokens) } val optResult = tokens match { case None => prioritizedDefault case Some(tokens) => tryEither( - arg.reader.read(tokens), + reader.read(tokens), Result.ParamError.Exception(arg, tokens, _) ) match { case Left(ex) => ParamResult.Failure(Seq(ex)) @@ -144,27 +157,27 @@ object Invoker { case Right(Right(v)) => ParamResult.Success(Some(v)) } } - optResult.map(_.get) + optResult.map(_.get.asInstanceOf[T]) } - def makeReadVarargsCall[T, B]( - arg: ArgSig.Leftover[T, B], - values: Seq[String] - ): ParamResult[Seq[T]] = { - val attempts = - for (token <- values) - yield tryEither( - arg.reader.read(Seq(token)), - Result.ParamError.Exception(arg, Seq(token), _) - ) match { - case Left(x) => Left(x) - case Right(Left(errMsg)) => Left(Result.ParamError.Failed(arg, Seq(token), errMsg)) - case Right(Right(v)) => Right(v) - } + def makeReadVarargsCall[T]( + arg: ArgSig, + values: Seq[String], + reader: TokensReader.Leftover[_, _] + ): ParamResult[T] = { + val eithers = + tryEither( + reader.read(values), + Result.ParamError.Exception(arg, values, _) + ) match { + case Left(x) => Left(x) + case Right(Left(errMsg)) => Left(Result.ParamError.Failed(arg, values, errMsg)) + case Right(Right(v)) => Right(v) + } - attempts.collect { case Left(x) => x } match { - case Nil => ParamResult.Success(attempts.collect { case Right(x) => x }) - case bad => ParamResult.Failure(bad) + eithers match { + case Left(s) => ParamResult.Failure(Seq(s)) + case Right(v) => ParamResult.Success(v.asInstanceOf[T]) } } } diff --git a/mainargs/src/Model.scala b/mainargs/src/Model.scala deleted file mode 100644 index 33a9cdd..0000000 --- a/mainargs/src/Model.scala +++ /dev/null @@ -1,144 +0,0 @@ -package mainargs - -sealed trait ArgSig[T, B] { - def widen[V >: T] = this.asInstanceOf[ArgSig[V, B]] -} -object ArgSig { - def createVararg[T, B](name0: String, arg: mainargs.arg)(implicit - argParser: ArgReader.Leftover[T] - ) = { - val name = scala.Option(arg.name).getOrElse(name0) - val docOpt = scala.Option(arg.doc) - Leftover[T, B](name, docOpt, argParser.reader) - } - def create[T, B](name0: String, arg: mainargs.arg, defaultOpt: Option[B => T])(implicit - argParser: ArgReader[T] - ): ArgSig[T, B] = { - val nameOpt = scala.Option(arg.name).orElse(if (name0.length == 1 || arg.noDefaultName) None - else Some(name0)) - val shortOpt = arg.short match { - case '\u0000' => if (name0.length != 1 || arg.noDefaultName) None else Some(name0(0)); - case c => Some(c) - } - val docOpt = scala.Option(arg.doc) - val isHidden = arg.isHidden - argParser match { - case ArgReader.Flag() => ArgSig.Flag[B](nameOpt, shortOpt, docOpt, isHidden) - case ArgReader.Class(parser) => Class(parser.mains) - case ArgReader.Leftover(reader: TokensReader[T]) => - Leftover[T, B](scala.Option(arg.name).getOrElse(name0), docOpt, reader) - case ArgReader.Simple(reader) => - Simple[T, B](nameOpt, shortOpt, docOpt, defaultOpt, reader, arg.positional, isHidden) - } - } - - sealed trait Terminal[T, B] extends ArgSig[T, B] { - def name: Option[String] - def doc: Option[String] - } - - sealed trait Named[T, B] extends Terminal[T, B] { - def shortName: Option[Char] - def isHidden: Boolean - } - - /** - * Models what is known by the router about a single argument: that it has - * a [[name]], a human-readable [[typeString]] describing what the type is - * (just for logging and reading, not a replacement for a `TypeTag`) and - * possible a function that can compute its default value - */ - case class Simple[T, B]( - override val name: Option[String], - override val shortName: Option[Char], - override val doc: Option[String], - default: Option[B => T], - reader: TokensReader[T], - positional: Boolean, - override val isHidden: Boolean - ) extends ArgSig.Named[T, B] { - def typeString = reader.shortName - } - - case class Flag[B]( - override val name: Option[String], - override val shortName: Option[Char], - override val doc: Option[String], - override val isHidden: Boolean - ) extends ArgSig.Named[mainargs.Flag, B] - - def flatten[T, B](x: ArgSig[T, B]): Seq[Terminal[T, B]] = x match { - case x: Terminal[T, B] => Seq(x) - case x: Class[T, B] => - x.reader.main.argSigs.flatMap(x => flatten(x.asInstanceOf[Terminal[T, B]])) - } - - case class Class[T, B](reader: ClassMains[T]) extends ArgSig[T, B] - - case class Leftover[T, B](name0: String, doc: Option[String], reader: TokensReader[T]) - extends ArgSig.Terminal[T, B] { - def name = Some(name0) - } -} - -sealed trait ArgReader[T] -object ArgReader { - implicit def createSimple[T: TokensReader]: Simple[T] = Simple(implicitly[TokensReader[T]]) - case class Simple[T](x: TokensReader[T]) extends ArgReader[T] - - implicit def createClass[T: SubParser]: Class[T] = Class(implicitly[SubParser[T]]) - case class Class[T](x: SubParser[T]) extends ArgReader[T] - - implicit def createLeftover[T: TokensReader]: Leftover[T] = Leftover(implicitly[TokensReader[T]]) - case class Leftover[T](reader: TokensReader[T]) extends ArgReader[mainargs.Leftover[T]] - - implicit def createFlag: Flag = Flag() - case class Flag() extends ArgReader[mainargs.Flag] -} - -trait SubParser[T] { - def mains: ClassMains[T] -} - -case class MethodMains[B](value: Seq[MainData[Any, B]], base: () => B) - -case class ClassMains[T](main: MainData[T, Any], companion: () => Any) - -/** - * What is known about a single endpoint for our routes. It has a [[name]], - * [[argSigs]] for each argument, and a macro-generated [[invoke0]] - * that performs all the necessary argument parsing and de-serialization. - * - * Realistically, you will probably spend most of your time calling [[Invoker.invoke]] - * instead, which provides a nicer API to call it that mimmicks the API of - * calling a Scala method. - */ -case class MainData[T, B]( - name: String, - argSigs0: Seq[ArgSig[_, B]], - doc: Option[String], - invokeRaw: (B, Seq[Any]) => T -) { - - val argSigs: Seq[ArgSig.Terminal[_, B]] = - argSigs0.iterator.flatMap[ArgSig.Terminal[_, B]](ArgSig.flatten(_)).toVector - val leftoverArgSig: Seq[ArgSig.Leftover[_, _]] = - argSigs.collect { case x: ArgSig.Leftover[_, B] => x } - -} - -object MainData { - def create[T, B]( - methodName: String, - main: mainargs.main, - argSigs: Seq[ArgSig[Any, B]], - invokeRaw: (B, Seq[Any]) => T - ) = { - MainData( - Option(main.name).getOrElse(methodName), - argSigs, - Option(main.doc), - invokeRaw - ) - } -} diff --git a/mainargs/src/Parser.scala b/mainargs/src/Parser.scala index 072c623..d225b7a 100644 --- a/mainargs/src/Parser.scala +++ b/mainargs/src/Parser.scala @@ -175,7 +175,8 @@ class ParserForMethods[B](val mains: MethodMains[B]) { } object ParserForClass extends ParserForClassCompanionVersionSpecific -class ParserForClass[T](val mains: ClassMains[T]) extends SubParser[T] { +class ParserForClass[T](val main: MainData[T, Any], val companion: () => Any) + extends TokensReader.Class[T] { def helpText( totalWidth: Int = 100, docsOnNewLine: Boolean = false, @@ -184,10 +185,10 @@ class ParserForClass[T](val mains: ClassMains[T]) extends SubParser[T] { sorted: Boolean = true ): String = { Renderer.formatMainMethodSignature( - mains.main, + main, 0, totalWidth, - Renderer.getLeftColWidth(mains.main.argSigs), + Renderer.getLeftColWidth(main.renderedArgSigs), docsOnNewLine, Option(customName), Option(customDoc), @@ -281,7 +282,7 @@ class ParserForClass[T](val mains: ClassMains[T]) extends SubParser[T] { case f: Result.Failure => Left( Renderer.renderResult( - mains.main, + main, f, totalWidth, printHelpOnExit, @@ -323,6 +324,6 @@ class ParserForClass[T](val mains: ClassMains[T]) extends SubParser[T] { allowPositional: Boolean = false, allowRepeats: Boolean = false ): Result[T] = { - Invoker.construct[T](mains, args, allowPositional, allowRepeats) + Invoker.construct[T](this, args, allowPositional, allowRepeats) } } diff --git a/mainargs/src/Renderer.scala b/mainargs/src/Renderer.scala index c84e6d2..2b648b6 100644 --- a/mainargs/src/Renderer.scala +++ b/mainargs/src/Renderer.scala @@ -5,44 +5,45 @@ import scala.math object Renderer { - def getLeftColWidth(items: Seq[ArgSig.Terminal[_, _]]) = { + def getLeftColWidth(items: Seq[ArgSig]) = { if (items.isEmpty) 0 - else items.filter(nonHidden).map(renderArgShort(_).length).max + else items.map(renderArgShort(_).length).max } val newLine = System.lineSeparator() def normalizeNewlines(s: String) = s.replace("\r", "").replace("\n", newLine) - def renderArgShort(arg: ArgSig.Terminal[_, _]) = arg match { - case arg: ArgSig.Flag[_] => + def renderArgShort(arg: ArgSig) = arg.reader match { + case r: TokensReader.Flag => val shortPrefix = arg.shortName.map(c => s"-$c") val nameSuffix = arg.name.map(s => s"--$s") (shortPrefix ++ nameSuffix).mkString(" ") - case arg: ArgSig.Simple[_, _] => - val shortPrefix = arg.shortName.map(c => s"-$c") - val typeSuffix = s"<${arg.typeString}>" + case r: TokensReader.Simple[_] => + val shortPrefix = arg.shortName.map(c => s"-$c") + val typeSuffix = s"<${r.shortName}>" val nameSuffix = if (arg.positional) arg.name else arg.name.map(s => s"--$s") (shortPrefix ++ nameSuffix ++ Seq(typeSuffix)).mkString(" ") - case arg: ArgSig.Leftover[_, _] => - s"${arg.name0} <${arg.reader.shortName}>..." + + case r: TokensReader.Leftover[_, _] => + s"${arg.name.get} <${r.shortName}>..." } /** * Returns a `Some[string]` with the sortable string or a `None` if it is an leftover. */ - private def sortableName(arg: ArgSig.Terminal[_, _]): Option[String] = arg match { - case l: ArgSig.Leftover[_, _] => - None - case a: ArgSig.Named[_, _] => + private def sortableName(arg: ArgSig): Option[String] = arg match { + case arg: ArgSig if arg.reader.isLeftover => None + + case a: ArgSig => a.shortName.map(_.toString).orElse(a.name).orElse(Some("")) - case a: ArgSig.Terminal[_, _] => + case a: ArgSig => a.name.orElse(Some("")) } - object ArgOrd extends math.Ordering[ArgSig.Terminal[_, _]] { - override def compare(x: ArgSig.Terminal[_, _], y: ArgSig.Terminal[_, _]): Int = + object ArgOrd extends math.Ordering[ArgSig] { + override def compare(x: ArgSig, y: ArgSig): Int = (sortableName(x), sortableName(y)) match { case (None, None) => 0 // don't sort leftovers case (None, Some(_)) => 1 // keep left overs at the end @@ -51,13 +52,8 @@ object Renderer { } } - private[this] val nonHidden: ArgSig.Terminal[_, _] => Boolean = { - case arg: ArgSig.Named[_, _] => !arg.isHidden - case _ => true - } - def renderArg( - arg: ArgSig.Terminal[_, _], + arg: ArgSig, leftOffset: Int, wrappedWidth: Int ): (String, String) = { @@ -73,8 +69,8 @@ object Renderer { customDocs: Map[String, String], sorted: Boolean ): String = { - val flattenedAll: Seq[ArgSig.Terminal[_, _]] = - mainMethods.map(_.argSigs) + val flattenedAll: Seq[ArgSig] = + mainMethods.map(_.flattenedArgSigs) .flatten val leftColWidth = getLeftColWidth(flattenedAll) mainMethods match { @@ -142,10 +138,10 @@ object Renderer { val argLeftCol = if (docsOnNewLine) leftIndent + 8 else leftColWidth + leftIndent + 2 + 2 val sortedArgs = - if (sorted) main.argSigs.sorted(ArgOrd) - else main.argSigs + if (sorted) main.renderedArgSigs.sorted(ArgOrd) + else main.renderedArgSigs - val args = sortedArgs.filter(nonHidden).map(renderArg(_, argLeftCol, totalWidth)) + val args = sortedArgs.map(renderArg(_, argLeftCol, totalWidth)) val leftIndentStr = " " * leftIndent @@ -240,7 +236,7 @@ object Renderer { def expectedMsg() = { if (printHelpOnError) { - val leftColWidth = getLeftColWidth(main.argSigs) + val leftColWidth = getLeftColWidth(main.renderedArgSigs) "Expected Signature: " + Renderer.formatMainMethodSignature( main, diff --git a/mainargs/src/Result.scala b/mainargs/src/Result.scala index 9e8601e..60c5147 100644 --- a/mainargs/src/Result.scala +++ b/mainargs/src/Result.scala @@ -47,10 +47,10 @@ object Result { * did not line up with the arguments expected */ case class MismatchedArguments( - missing: Seq[ArgSig.Simple[_, _]] = Nil, + missing: Seq[ArgSig] = Nil, unknown: Seq[String] = Nil, - duplicate: Seq[(ArgSig.Named[_, _], Seq[String])] = Nil, - incomplete: Option[ArgSig.Simple[_, _]] = None + duplicate: Seq[(ArgSig, Seq[String])] = Nil, + incomplete: Option[ArgSig] = None ) extends Failure /** @@ -66,21 +66,21 @@ object Result { /** * Something went wrong trying to de-serialize the input parameter */ - case class Failed(arg: ArgSig.Terminal[_, _], tokens: Seq[String], errMsg: String) + case class Failed(arg: ArgSig, tokens: Seq[String], errMsg: String) extends ParamError /** * Something went wrong trying to de-serialize the input parameter; * the thrown exception is stored in [[ex]] */ - case class Exception(arg: ArgSig.Terminal[_, _], tokens: Seq[String], ex: Throwable) + case class Exception(arg: ArgSig, tokens: Seq[String], ex: Throwable) extends ParamError /** * Something went wrong trying to evaluate the default value * for this input parameter */ - case class DefaultFailed(arg: ArgSig.Simple[_, _], ex: Throwable) extends ParamError + case class DefaultFailed(arg: ArgSig, ex: Throwable) extends ParamError } } diff --git a/mainargs/src/TokenGrouping.scala b/mainargs/src/TokenGrouping.scala index 6ef5186..b275e59 100644 --- a/mainargs/src/TokenGrouping.scala +++ b/mainargs/src/TokenGrouping.scala @@ -2,51 +2,51 @@ package mainargs import scala.annotation.tailrec -case class TokenGrouping[B](remaining: List[String], grouped: Map[ArgSig.Named[_, B], Seq[String]]) +case class TokenGrouping[B](remaining: List[String], grouped: Map[ArgSig, Seq[String]]) object TokenGrouping { def groupArgs[B]( flatArgs0: Seq[String], - argSigs0: Seq[ArgSig[_, B]], + argSigs0: Seq[ArgSig], allowPositional: Boolean, allowRepeats: Boolean, allowLeftover: Boolean ): Result[TokenGrouping[B]] = { - val argSigs: Seq[ArgSig.Named[_, B]] = argSigs0 - .map(ArgSig.flatten(_).collect { case x: ArgSig.Named[_, _] => x }) + val argSigs: Seq[ArgSig] = argSigs0 + .map(ArgSig.flatten(_).collect { case x: ArgSig => x }) .flatten val positionalArgSigs = argSigs .filter { - case x: ArgSig.Simple[_, _] if x.reader.noTokens => false - case x: ArgSig.Simple[_, _] if x.positional => true + case x: ArgSig if x.reader.isLeftover || x.reader.isConstant => false + case x: ArgSig if x.positional => true case x => allowPositional } val flatArgs = flatArgs0.toList val keywordArgMap = argSigs - .filter { case x: ArgSig.Simple[_, _] if x.positional => false; case _ => true } + .filter { case x: ArgSig if x.positional => false; case _ => true } .flatMap { x => (x.name.map("--" + _) ++ x.shortName.map("-" + _)).map(_ -> x) } - .toMap[String, ArgSig.Named[_, B]] + .toMap[String, ArgSig] @tailrec def rec( remaining: List[String], - current: Map[ArgSig.Named[_, B], Vector[String]] + current: Map[ArgSig, Vector[String]] ): Result[TokenGrouping[B]] = { remaining match { case head :: rest => if (head.startsWith("-") && head.exists(_ != '-')) { keywordArgMap.get(head) match { - case Some(cliArg: ArgSig.Flag[_]) => + case Some(cliArg: ArgSig) if cliArg.reader.isFlag => rec(rest, Util.appendMap(current, cliArg, "")) - case Some(cliArg: ArgSig.Simple[_, _]) => + case Some(cliArg: ArgSig) if !cliArg.reader.isLeftover => rest match { case next :: rest2 => rec(rest2, Util.appendMap(current, cliArg, next)) case Nil => Result.Failure.MismatchedArguments(Nil, Nil, Nil, incomplete = Some(cliArg)) } - case None => complete(remaining, current) + case _ => complete(remaining, current) } } else { positionalArgSigs.find(!current.contains(_)) match { @@ -56,29 +56,39 @@ object TokenGrouping { } case _ => complete(remaining, current) - } } + def complete( remaining: List[String], - current: Map[ArgSig.Named[_, B], Vector[String]] + current: Map[ArgSig, Vector[String]] ): Result[TokenGrouping[B]] = { val duplicates = current .filter { - case (a: ArgSig.Flag[_], vs) => vs.size > 1 && !allowRepeats - case (a: ArgSig.Simple[_, _], vs) => - vs.size > 1 && !a.reader.alwaysRepeatable && !allowRepeats + case (a: ArgSig, vs) => + a.reader match { + case r: TokensReader.Flag => vs.size > 1 && !allowRepeats + case r: TokensReader.Simple[_] => vs.size > 1 && !r.alwaysRepeatable && !allowRepeats + case r: TokensReader.Leftover[_, _] => false + case r: TokensReader.Constant[_] => false + } + } .toSeq val missing = argSigs - .collect { case x: ArgSig.Simple[_, _] => x } + .collect { case x: ArgSig => x } .filter { x => - !x.reader.allowEmpty && - x.default.isEmpty && - !current.contains(x) + x.reader match { + case r: TokensReader.Simple[_] => + !r.allowEmpty && + x.default.isEmpty && + !current.contains(x) + case _ => false + } } + val unknown = if (allowLeftover) Nil else remaining if (missing.nonEmpty || duplicates.nonEmpty || unknown.nonEmpty) { Result.Failure.MismatchedArguments( diff --git a/mainargs/src/TokensReader.scala b/mainargs/src/TokensReader.scala index 4e4c9cb..f0c45bb 100644 --- a/mainargs/src/TokensReader.scala +++ b/mainargs/src/TokensReader.scala @@ -1,54 +1,187 @@ package mainargs import scala.collection.compat._ import scala.collection.mutable -class TokensReader[T]( - val shortName: String, - val read: Seq[String] => Either[String, T], - val alwaysRepeatable: Boolean = false, - val allowEmpty: Boolean = false, - val noTokens: Boolean = false -) + +/** + * Represents the ability to parse CLI input arguments into a type [[T]] + * + * Has a fixed number of direct subtypes - [[Simple]], [[Constant]], [[Flag]], + * [[Leftover]], and [[Class]] - but each of those can be extended by an + * arbitrary number of user-specified instances. + */ +sealed trait TokensReader[T] { + def isLeftover = false + def isFlag = false + def isClass = false + def isConstant = false + def isSimple = false +} + object TokensReader { + + sealed trait Terminal[T] extends TokensReader[T] + + sealed trait ShortNamed[T] extends Terminal[T] { + /** + * The label that shows up in the CLI help message, e.g. the `bar` in + * `--foo ` + */ + def shortName: String + } + + /** + * A [[TokensReader]] for a single CLI parameter that takes a value + * e.g. `--foo bar` + */ + trait Simple[T] extends ShortNamed[T] { + /** + * Converts the given input tokens to a [[T]] or an error `String`. + * The input is a `Seq` because input tokens can be passed more than once, + * e.g. `--foo bar --foo qux` will result in [[read]] being passed + * `["foo", "qux"]` + */ + def read(strs: Seq[String]): Either[String, T] + + /** + * Whether is CLI param is repeatable + */ + def alwaysRepeatable: Boolean = false + + /** + * Whether this CLI param can be no passed from the CLI, even if a default + * value is not specified. In that case, [[read]] receives an empty `Seq` + */ + def allowEmpty: Boolean = false + override def isSimple = true + } + + /** + * A [[TokensReader]] that doesn't read any tokens and just returns a value. + * Useful sometimes for injecting things into main methods that aren't + * strictly computed from CLI argument tokens but nevertheless need to get + * passed in. + */ + trait Constant[T] extends Terminal[T] { + def read(): Either[String, T] + override def isConstant = true + } + + /** + * A [[TokensReader]] for a flag that does not take any value, e.g. `--foo` + */ + trait Flag extends Terminal[mainargs.Flag] { + override def isFlag = true + } + + /** + * A [[TokensReader]] for parsing the left-over parameters that do not belong + * to any other flag or parameter. + */ + trait Leftover[T, V] extends ShortNamed[T] { + def read(strs: Seq[String]): Either[String, T] + + def shortName: String + override def isLeftover = true + } + + /** + * A [[TokensReader]] that can parse an instance of the class [[T]], which + * may contain multiple fields each parsed by their own [[TokensReader]] + */ + trait Class[T] extends TokensReader[T] { + def companion: () => Any + def main: MainData[T, Any] + override def isClass = true + } + def tryEither[T](f: => T) = try Right(f) catch { case e: Throwable => Left(e.toString) } - implicit object StringRead extends TokensReader[String]("str", strs => Right(strs.last)) - implicit object BooleanRead - extends TokensReader[Boolean]("bool", strs => tryEither(strs.last.toBoolean)) - implicit object ByteRead extends TokensReader[Byte]("byte", strs => tryEither(strs.last.toByte)) - implicit object ShortRead - extends TokensReader[Short]("short", strs => tryEither(strs.last.toShort)) - implicit object IntRead extends TokensReader[Int]("int", strs => tryEither(strs.last.toInt)) - implicit object LongRead extends TokensReader[Long]("long", strs => tryEither(strs.last.toLong)) - implicit object FloatRead - extends TokensReader[Float]("float", strs => tryEither(strs.last.toFloat)) - implicit object DoubleRead - extends TokensReader[Double]("double", strs => tryEither(strs.last.toDouble)) - - implicit def OptionRead[T: TokensReader]: TokensReader[Option[T]] = new TokensReader[Option[T]]( - implicitly[TokensReader[T]].shortName, - strs => { + implicit object FlagRead extends Flag + implicit object StringRead extends Simple[String] { + def shortName = "str" + def read(strs: Seq[String]) = Right(strs.last) + } + implicit object BooleanRead extends Simple[Boolean] { + def shortName = "bool" + def read(strs: Seq[String]) = tryEither(strs.last.toBoolean) + } + implicit object ByteRead extends Simple[Byte] { + def shortName = "byte" + def read(strs: Seq[String]) = tryEither(strs.last.toByte) + } + implicit object ShortRead extends Simple[Short] { + def shortName = "short" + def read(strs: Seq[String]) = tryEither(strs.last.toShort) + } + implicit object IntRead extends Simple[Int] { + def shortName = "int" + def read(strs: Seq[String]) = tryEither(strs.last.toInt) + } + implicit object LongRead extends Simple[Long] { + def shortName = "long" + def read(strs: Seq[String]) = tryEither(strs.last.toLong) + } + implicit object FloatRead extends Simple[Float] { + def shortName = "float" + def read(strs: Seq[String]) = tryEither(strs.last.toFloat) + } + implicit object DoubleRead extends Simple[Double] { + def shortName = "double" + def read(strs: Seq[String]) = tryEither(strs.last.toDouble) + } + + implicit def LeftoverRead[T: TokensReader.Simple]: TokensReader.Leftover[mainargs.Leftover[T], T] = + new LeftoverRead[T]()(implicitly[TokensReader.Simple[T]]) + + class LeftoverRead[T](implicit wrapped: TokensReader.Simple[T]) + extends Leftover[mainargs.Leftover[T], T] { + def read(strs: Seq[String]) = { + val (failures, successes) = strs + .map(s => + implicitly[TokensReader[T]] match{ + case r: TokensReader.Simple[T] => r.read(Seq(s)) + case r: TokensReader.Leftover[T, _] => r.read(Seq(s)) + } + ) + .partitionMap(identity) + + if (failures.nonEmpty) Left(failures.head) + else Right(Leftover(successes: _*)) + } + def shortName = wrapped.shortName + } + + implicit def OptionRead[T: TokensReader.Simple]: TokensReader[Option[T]] = new OptionRead[T] + class OptionRead[T: TokensReader.Simple] extends Simple[Option[T]] { + def shortName = implicitly[TokensReader.Simple[T]].shortName + def read(strs: Seq[String]) = { strs.lastOption match { case None => Right(None) - case Some(s) => implicitly[TokensReader[T]].read(Seq(s)) match { + case Some(s) => implicitly[TokensReader.Simple[T]].read(Seq(s)) match { case Left(s) => Left(s) case Right(s) => Right(Some(s)) } } - }, - allowEmpty = true - ) - implicit def SeqRead[C[_] <: Iterable[_], T: TokensReader](implicit + } + override def allowEmpty = true + } + + implicit def SeqRead[C[_] <: Iterable[_], T: TokensReader.Simple](implicit factory: Factory[T, C[T]] - ): TokensReader[C[T]] = new TokensReader[C[T]]( - implicitly[TokensReader[T]].shortName, - strs => { + ): TokensReader[C[T]] = + new SeqRead[C, T] + + class SeqRead[C[_] <: Iterable[_], T: TokensReader.Simple](implicit factory: Factory[T, C[T]]) + extends Simple[C[T]] { + def shortName = implicitly[TokensReader.Simple[T]].shortName + def read(strs: Seq[String]) = { strs .foldLeft(Right(factory.newBuilder): Either[String, mutable.Builder[T, C[T]]]) { case (Left(s), _) => Left(s) case (Right(builder), token) => - implicitly[TokensReader[T]].read(Seq(token)) match { + implicitly[TokensReader.Simple[T]].read(Seq(token)) match { case Left(s) => Left(s) case Right(v) => builder += v @@ -56,31 +189,117 @@ object TokensReader { } } .map(_.result()) - }, - alwaysRepeatable = true, - allowEmpty = true - ) - implicit def MapRead[K: TokensReader, V: TokensReader]: TokensReader[Map[K, V]] = - new TokensReader[Map[K, V]]( - "k=v", - strs => { - strs.foldLeft[Either[String, Map[K, V]]](Right(Map())) { - case (Left(s), _) => Left(s) - case (Right(prev), token) => - token.split("=", 2) match { - case Array(k, v) => - for { - tuple <- Right((k, v)): Either[String, (String, String)] - (k, v) = tuple - key <- implicitly[TokensReader[K]].read(Seq(k)) - value <- implicitly[TokensReader[V]].read(Seq(v)) - } yield prev + (key -> value) - - case _ => Left("parameter must be in k=v format") - } - } - }, - alwaysRepeatable = true, - allowEmpty = true + } + override def alwaysRepeatable = true + override def allowEmpty = true + } + + implicit def MapRead[K: TokensReader.Simple, V: TokensReader.Simple]: TokensReader[Map[K, V]] = + new MapRead[K, V] + class MapRead[K: TokensReader.Simple, V: TokensReader.Simple] extends Simple[Map[K, V]] { + def shortName = "k=v" + def read(strs: Seq[String]) = { + strs.foldLeft[Either[String, Map[K, V]]](Right(Map())) { + case (Left(s), _) => Left(s) + case (Right(prev), token) => + token.split("=", 2) match { + case Array(k, v) => + for { + tuple <- Right((k, v)): Either[String, (String, String)] + (k, v) = tuple + key <- implicitly[TokensReader.Simple[K]].read(Seq(k)) + value <- implicitly[TokensReader.Simple[V]].read(Seq(v)) + } yield prev + (key -> value) + + case _ => Left("parameter must be in k=v format") + } + } + } + override def alwaysRepeatable = true + override def allowEmpty = true + } +} + +object ArgSig { + def create[T, B](name0: String, arg: mainargs.arg, defaultOpt: Option[B => T]) + (implicit tokensReader: TokensReader[T]): ArgSig = { + val nameOpt = scala.Option(arg.name).orElse(if (name0.length == 1 || arg.noDefaultName) None + else Some(name0)) + val shortOpt = arg.short match { + case '\u0000' => if (name0.length != 1 || arg.noDefaultName) None else Some(name0(0)); + case c => Some(c) + } + val docOpt = scala.Option(arg.doc) + ArgSig( + nameOpt, + shortOpt, + docOpt, + defaultOpt.asInstanceOf[Option[Any => Any]], + tokensReader, + arg.positional, + arg.hidden + ) + } + + def flatten[T](x: ArgSig): Seq[ArgSig] = x.reader match { + case _: TokensReader.Terminal[T] => Seq(x) + case cls: TokensReader.Class[_] => cls.main.argSigs0.flatMap(flatten(_)) + } +} + +/** + * Models what is known by the router about a single argument: that it has + * a [[name]], a human-readable [[typeString]] describing what the type is + * (just for logging and reading, not a replacement for a `TypeTag`) and + * possible a function that can compute its default value + */ +case class ArgSig( + name: Option[String], + shortName: Option[Char], + doc: Option[String], + default: Option[Any => Any], + reader: TokensReader[_], + positional: Boolean, + hidden: Boolean +) + +case class MethodMains[B](value: Seq[MainData[Any, B]], base: () => B) + +/** + * What is known about a single endpoint for our routes. It has a [[name]], + * [[flattenedArgSigs]] for each argument, and a macro-generated [[invoke0]] + * that performs all the necessary argument parsing and de-serialization. + * + * Realistically, you will probably spend most of your time calling [[Invoker.invoke]] + * instead, which provides a nicer API to call it that mimmicks the API of + * calling a Scala method. + */ +case class MainData[T, B]( + name: String, + argSigs0: Seq[ArgSig], + doc: Option[String], + invokeRaw: (B, Seq[Any]) => T +) { + + val flattenedArgSigs: Seq[ArgSig] = + argSigs0.iterator.flatMap[ArgSig](ArgSig.flatten(_)).toVector + + val renderedArgSigs: Seq[ArgSig] = + flattenedArgSigs.filter(a => !a.hidden && !a.reader.isConstant) +} + +object MainData { + def create[T, B]( + methodName: String, + main: mainargs.main, + argSigs: Seq[ArgSig], + invokeRaw: (B, Seq[Any]) => T + ) = { + MainData( + Option(main.name).getOrElse(methodName), + argSigs, + Option(main.doc), + invokeRaw ) + } } diff --git a/mainargs/test/src-2/OldVarargsTests.scala b/mainargs/test/src-2/VarargsOldTests.scala similarity index 86% rename from mainargs/test/src-2/OldVarargsTests.scala rename to mainargs/test/src-2/VarargsOldTests.scala index 6e7bdef..88b217a 100644 --- a/mainargs/test/src-2/OldVarargsTests.scala +++ b/mainargs/test/src-2/VarargsOldTests.scala @@ -1,7 +1,7 @@ package mainargs import utest._ -object OldVarargsTests extends VarargsTests { +object VarargsOldTests extends VarargsBaseTests { object Base { @main diff --git a/mainargs/test/src-jvm-2/AmmoniteTests.scala b/mainargs/test/src-jvm-2/AmmoniteTests.scala index 6fc0e91..c9ea6db 100644 --- a/mainargs/test/src-jvm-2/AmmoniteTests.scala +++ b/mainargs/test/src-jvm-2/AmmoniteTests.scala @@ -23,10 +23,19 @@ case class AmmoniteConfig( ) object AmmoniteConfig { - implicit object PathRead - extends TokensReader[os.Path]("path", strs => Right(os.Path(strs.head, os.pwd))) + implicit object PathRead extends TokensReader.Simple[os.Path] { + def shortName = "path" + def read(strs: Seq[String]) = Right(os.Path(strs.head, os.pwd)) + } + + case class InjectedConstant() + + implicit object InjectedTokensReader extends TokensReader.Constant[InjectedConstant] { + def read() = Right(new InjectedConstant()) + } @main case class Core( + injectedConstant: InjectedConstant, @arg( name = "no-default-predef", doc = "Disable the default predef and run Ammonite with the minimal predef possible" @@ -211,6 +220,7 @@ object AmmoniteTests extends TestSuite { Right( AmmoniteConfig( AmmoniteConfig.Core( + injectedConstant = AmmoniteConfig.InjectedConstant(), noDefaultPredef = Flag(), silent = Flag(), watch = Flag(), diff --git a/mainargs/test/src-jvm-2/MillTests.scala b/mainargs/test/src-jvm-2/MillTests.scala index 2cd89d4..c849fe9 100644 --- a/mainargs/test/src-jvm-2/MillTests.scala +++ b/mainargs/test/src-jvm-2/MillTests.scala @@ -3,9 +3,11 @@ package mainargs import utest._ object MillTests extends TestSuite { + implicit object PathRead extends TokensReader.Simple[os.Path] { + def shortName = "path" + def read(strs: Seq[String]) = Right(os.Path(strs.head, os.pwd)) + } - implicit object PathRead - extends TokensReader[os.Path]("path", strs => Right(os.Path(strs.head, os.pwd))) @main( name = "Mill Build Tool", doc = "usage: mill [mill-options] [target [target-options]]" @@ -69,6 +71,7 @@ object MillTests extends TestSuite { ) threadCount: Int = 1, ammoniteConfig: AmmoniteConfig.Core = AmmoniteConfig.Core( + injectedConstant = AmmoniteConfig.InjectedConstant(), noDefaultPredef = Flag(), silent = Flag(), watch = Flag(), @@ -78,7 +81,7 @@ object MillTests extends TestSuite { ), @arg( name = "hidden-dummy", - isHidden = true + hidden = true ) hiddenDummy: String = "" ) diff --git a/mainargs/test/src/ClassTests.scala b/mainargs/test/src/ClassTests.scala index 4570bde..3bdb664 100644 --- a/mainargs/test/src/ClassTests.scala +++ b/mainargs/test/src/ClassTests.scala @@ -32,14 +32,14 @@ object ClassTests extends TestSuite { fooParser.constructRaw(Seq("-x", "1")) ==> Result.Failure.MismatchedArguments( Seq( - ArgSig.Simple( + ArgSig( None, Some('y'), None, None, mainargs.TokensReader.IntRead, positional = false, - isHidden = false + hidden = false ) ), List(), @@ -63,14 +63,14 @@ object ClassTests extends TestSuite { barParser.constructRaw(Seq("-w", "-x", "1", "-z", "xxx")) ==> Result.Failure.MismatchedArguments( Seq( - ArgSig.Simple( + ArgSig( None, Some('y'), None, None, mainargs.TokensReader.IntRead, positional = false, - isHidden = false + hidden = false ) ), List(), @@ -85,14 +85,14 @@ object ClassTests extends TestSuite { barParser.constructRaw(Seq("-w", "-x", "1", "-y", "2")) ==> Result.Failure.MismatchedArguments( Seq( - ArgSig.Simple( + ArgSig( Some("zzzz"), Some('z'), None, None, mainargs.TokensReader.StringRead, positional = false, - isHidden = false + hidden = false ) ), List(), @@ -108,23 +108,23 @@ object ClassTests extends TestSuite { barParser.constructRaw(Seq("-w", "-x", "1")) ==> Result.Failure.MismatchedArguments( Seq( - ArgSig.Simple( + ArgSig( None, Some('y'), None, None, mainargs.TokensReader.IntRead, positional = false, - isHidden = false + hidden = false ), - ArgSig.Simple( + ArgSig( Some("zzzz"), Some('z'), None, None, mainargs.TokensReader.StringRead, positional = false, - isHidden = false + hidden = false ) ), List(), @@ -143,12 +143,12 @@ object ClassTests extends TestSuite { case Result.Failure.InvalidArguments( Seq( Result.ParamError.Failed( - ArgSig.Simple(None, Some('x'), None, None, _, false, _), + ArgSig(None, Some('x'), None, None, _, false, _), Seq("xxx"), _ ), Result.ParamError.Failed( - ArgSig.Simple(None, Some('y'), None, None, _, false, _), + ArgSig(None, Some('y'), None, None, _, false, _), Seq("hohoho"), _ ) diff --git a/mainargs/test/src/ConstantTests.scala b/mainargs/test/src/ConstantTests.scala new file mode 100644 index 0000000..d5e9d0e --- /dev/null +++ b/mainargs/test/src/ConstantTests.scala @@ -0,0 +1,27 @@ +package mainargs +import utest._ + +object ConstantTests extends TestSuite { + + case class Injected() + implicit def InjectedTokensReader: TokensReader.Constant[Injected] = + new TokensReader.Constant[Injected]{ + def read() = Right(new Injected()) + } + object Base { + @main + def flaggy(a: Injected, b: Boolean) = a.toString + " " + b + } + val check = new Checker(ParserForMethods(Base), allowPositional = true) + + val tests = Tests { + test - check( + List("-b", "true"), + Result.Success("Injected() true") + ) + test - check( + List("-b", "false"), + Result.Success("Injected() false") + ) + } +} diff --git a/mainargs/test/src/CoreTests.scala b/mainargs/test/src/CoreTests.scala index 531cbd8..f80715f 100644 --- a/mainargs/test/src/CoreTests.scala +++ b/mainargs/test/src/CoreTests.scala @@ -58,9 +58,9 @@ class CoreTests(allowPositional: Boolean) extends TestSuite { names == List("foo", "bar", "qux", "ex") ) - val evaledArgs = check.mains.value.map(_.argSigs.map { - case ArgSig.Simple(name, s, docs, None, parser, _, _) => (s, docs, None, parser) - case ArgSig.Simple(name, s, docs, Some(default), parser, _, _) => + val evaledArgs = check.mains.value.map(_.flattenedArgSigs.map { + case ArgSig(name, s, docs, None, parser, _, _) => (s, docs, None, parser) + case ArgSig(name, s, docs, Some(default), parser, _, _) => (s, docs, Some(default(CoreBase)), parser) }) @@ -113,7 +113,7 @@ class CoreTests(allowPositional: Boolean) extends TestSuite { test("missingParams") { test - assertMatch(check.parseInvoke(List("bar"))) { case Result.Failure.MismatchedArguments( - Seq(ArgSig.Simple(None, Some('i'), _, _, _, _, _)), + Seq(ArgSig(None, Some('i'), _, _, _, _, _)), Nil, Nil, None @@ -121,7 +121,7 @@ class CoreTests(allowPositional: Boolean) extends TestSuite { } test - assertMatch(check.parseInvoke(List("qux", "-s", "omg"))) { case Result.Failure.MismatchedArguments( - Seq(ArgSig.Simple(None, Some('i'), _, _, _, _, _)), + Seq(ArgSig(None, Some('i'), _, _, _, _, _)), Nil, Nil, None @@ -150,14 +150,14 @@ object CorePositionalDisabledOnlyTests extends TestSuite { test - check( List("bar", "2"), MismatchedArguments( - missing = List(ArgSig.Simple( + missing = List(ArgSig( None, Some('i'), None, None, TokensReader.IntRead, positional = false, - isHidden = false + hidden = false )), unknown = List("2") ) @@ -165,14 +165,14 @@ object CorePositionalDisabledOnlyTests extends TestSuite { test - check( List("qux", "2"), MismatchedArguments( - missing = List(ArgSig.Simple( + missing = List(ArgSig( None, Some('i'), None, None, TokensReader.IntRead, positional = false, - isHidden = false + hidden = false )), unknown = List("2") ) @@ -180,14 +180,14 @@ object CorePositionalDisabledOnlyTests extends TestSuite { test - check( List("qux", "3", "x"), MismatchedArguments( - missing = List(ArgSig.Simple( + missing = List(ArgSig( None, Some('i'), None, None, TokensReader.IntRead, positional = false, - isHidden = false + hidden = false )), unknown = List("3", "x") ) @@ -202,14 +202,14 @@ object CorePositionalDisabledOnlyTests extends TestSuite { test("invalidParams") - check( List("bar", "lol"), MismatchedArguments( - missing = List(ArgSig.Simple( + missing = List(ArgSig( None, Some('i'), None, None, TokensReader.IntRead, positional = false, - isHidden = false + hidden = false )), unknown = List("lol") ) @@ -219,7 +219,7 @@ object CorePositionalDisabledOnlyTests extends TestSuite { test("redundantParams") - check( List("qux", "1", "-i", "2"), MismatchedArguments( - missing = List(ArgSig.Simple(None, Some('i'), None, None, TokensReader.IntRead, positional = false, isHidden = false)), + missing = List(ArgSig(None, Some('i'), None, None, TokensReader.IntRead, positional = false, hidden = false)), unknown = List("1", "-i", "2") ) ) @@ -249,7 +249,7 @@ object CorePositionalEnabledOnlyTests extends TestSuite { ) { case Result.Failure.InvalidArguments( List(Result.ParamError.Failed( - ArgSig.Simple(None, Some('i'), _, _, _, _, _), + ArgSig(None, Some('i'), _, _, _, _, _), Seq("lol"), _ )) @@ -262,7 +262,7 @@ object CorePositionalEnabledOnlyTests extends TestSuite { case Result.Failure.MismatchedArguments( Nil, Nil, - Seq((ArgSig.Simple(None, Some('i'), _, _, _, _, _), Seq("1", "2"))), + Seq((ArgSig(None, Some('i'), _, _, _, _, _), Seq("1", "2"))), None ) => } diff --git a/mainargs/test/src/PositionalTests.scala b/mainargs/test/src/PositionalTests.scala index 3361072..c82d092 100644 --- a/mainargs/test/src/PositionalTests.scala +++ b/mainargs/test/src/PositionalTests.scala @@ -14,23 +14,23 @@ object PositionalTests extends TestSuite { List("true", "true", "true"), Result.Failure.MismatchedArguments( Vector( - ArgSig.Simple( + ArgSig( None, Some('x'), None, None, TokensReader.BooleanRead, positional = false, - isHidden = false + hidden = false ), - ArgSig.Simple( + ArgSig( None, Some('z'), None, None, TokensReader.BooleanRead, positional = false, - isHidden = false + hidden = false ) ), List("true", "true"), @@ -46,23 +46,23 @@ object PositionalTests extends TestSuite { List("-x", "true", "-y", "false", "-z", "false"), Result.Failure.MismatchedArguments( Vector( - ArgSig.Simple( + ArgSig( None, Some('y'), None, None, TokensReader.BooleanRead, positional = true, - isHidden = false + hidden = false ), - ArgSig.Simple( + ArgSig( None, Some('z'), None, None, TokensReader.BooleanRead, positional = false, - isHidden = false + hidden = false ) ), List("-y", "false", "-z", "false"), diff --git a/mainargs/test/src/VarargsTests.scala b/mainargs/test/src/VarargsBaseTests.scala similarity index 81% rename from mainargs/test/src/VarargsTests.scala rename to mainargs/test/src/VarargsBaseTests.scala index bda3efa..5b67bab 100644 --- a/mainargs/test/src/VarargsTests.scala +++ b/mainargs/test/src/VarargsBaseTests.scala @@ -1,7 +1,7 @@ package mainargs import utest._ -trait VarargsTests extends TestSuite { +trait VarargsBaseTests extends TestSuite { def check: Checker[_] def isNewVarargsTests: Boolean val tests = Tests { @@ -42,8 +42,8 @@ trait VarargsTests extends TestSuite { case Result.Failure.InvalidArguments( List( Result.ParamError.Failed( - ArgSig.Leftover("nums", _, _), - Seq("--nums"), + ArgSig(Some("nums"), _, _, _, _, _, _), + Seq("--nums", "31337"), """java.lang.NumberFormatException: For input string: "--nums"""" | """java.lang.NumberFormatException: --nums""" ) @@ -57,8 +57,8 @@ trait VarargsTests extends TestSuite { case Result.Failure.InvalidArguments( List( Result.ParamError.Failed( - ArgSig.Leftover("nums", _, _), - Seq("--nums"), + ArgSig(Some("nums"), _, _, _, _, _, _), + Seq("1", "2", "3", "--nums", "4"), "java.lang.NumberFormatException: For input string: \"--nums\"" | "java.lang.NumberFormatException: --nums" ) @@ -75,7 +75,7 @@ trait VarargsTests extends TestSuite { test("notEnoughNormalArgsStillFails") { assertMatch(check.parseInvoke(List("mixedVariadic"))) { case Result.Failure.MismatchedArguments( - Seq(ArgSig.Simple(Some("first"), _, _, _, _, _, _)), + Seq(ArgSig(Some("first"), _, _, _, _, _, _)), Nil, Nil, None @@ -89,16 +89,10 @@ trait VarargsTests extends TestSuite { case Result.Failure.InvalidArguments( List( Result.ParamError.Failed( - ArgSig.Leftover("nums", _, _), - Seq("aa"), + ArgSig(Some("nums"), _, _, _, _, _, _), + Seq("aa", "bb", "3"), "java.lang.NumberFormatException: For input string: \"aa\"" | "java.lang.NumberFormatException: aa" - ), - Result.ParamError.Failed( - ArgSig.Leftover("nums", _, _), - Seq("bb"), - "java.lang.NumberFormatException: For input string: \"bb\"" | - "java.lang.NumberFormatException: bb" ) ) ) => @@ -110,7 +104,7 @@ trait VarargsTests extends TestSuite { case Result.Failure.InvalidArguments( List( Result.ParamError.Failed( - ArgSig.Simple(Some("first"), _, _, _, _, _, _), + ArgSig(Some("first"), _, _, _, _, _, _), Seq("aa"), "java.lang.NumberFormatException: For input string: \"aa\"" | "java.lang.NumberFormatException: aa" diff --git a/mainargs/test/src/VarargsCustomTests.scala b/mainargs/test/src/VarargsCustomTests.scala new file mode 100644 index 0000000..0073874 --- /dev/null +++ b/mainargs/test/src/VarargsCustomTests.scala @@ -0,0 +1,43 @@ +package mainargs +import utest._ + +object VarargsCustomTests extends VarargsBaseTests { + // Test that we are able to replace the `Leftover` type entirely with our + // own implementation + class Wrapper[T](val unwrap: Seq[T]) + class WrapperRead[T](implicit wrapped: TokensReader.Simple[T]) + extends TokensReader.Leftover[Wrapper[T], T] { + def read(strs: Seq[String]) = { + val results = strs.map(s => implicitly[TokensReader.Simple[T]].read(Seq(s))) + val failures = results.collect { case Left(x) => x } + val successes = results.collect { case Right(x) => x } + + if (failures.nonEmpty) Left(failures.head) + else Right(new Wrapper(successes)) + } + def shortName = wrapped.shortName + } + + implicit def WrapperRead[T: TokensReader.Simple]: TokensReader[Wrapper[T]] = + new WrapperRead[T] + + object Base { + @main + def pureVariadic(nums: Wrapper[Int]) = nums.unwrap.sum + + @main + def mixedVariadic(@arg(short = 'f') first: Int, args: Wrapper[String]) = { + first + args.unwrap.mkString + } + @main + def mixedVariadicWithDefault( + @arg(short = 'f') first: Int = 1337, + args: Wrapper[String] + ) = { + first + args.unwrap.mkString + } + } + + val check = new Checker(ParserForMethods(Base), allowPositional = true) + val isNewVarargsTests = true +} diff --git a/mainargs/test/src/NewVarargsTests.scala b/mainargs/test/src/VarargsNewTests.scala similarity index 91% rename from mainargs/test/src/NewVarargsTests.scala rename to mainargs/test/src/VarargsNewTests.scala index c55065b..a0e9081 100644 --- a/mainargs/test/src/NewVarargsTests.scala +++ b/mainargs/test/src/VarargsNewTests.scala @@ -1,6 +1,6 @@ package mainargs import utest._ -object NewVarargsTests extends VarargsTests { +object VarargsNewTests extends VarargsBaseTests { object Base { @main def pureVariadic(nums: Leftover[Int]) = nums.value.sum diff --git a/mainargs/test/src/VarargsWrappedTests.scala b/mainargs/test/src/VarargsWrappedTests.scala new file mode 100644 index 0000000..00d03bb --- /dev/null +++ b/mainargs/test/src/VarargsWrappedTests.scala @@ -0,0 +1,40 @@ +package mainargs +import utest._ + +object VarargsWrappedTests extends VarargsBaseTests { + // Test that we are able to wrap the `Leftover` type we use for Varargs in + // our own custom types, and have things work + class Wrapper[T](val unwrap: T) + class WrapperRead[T](implicit wrapped: TokensReader.ShortNamed[T]) + extends TokensReader.Leftover[Wrapper[T], T] { + + def read(strs: Seq[String]) = wrapped + .asInstanceOf[TokensReader.Leftover[T, _]] + .read(strs).map(new Wrapper(_)) + + def shortName = wrapped.shortName + } + + implicit def WrapperRead[T: TokensReader.ShortNamed]: TokensReader[Wrapper[T]] = + new WrapperRead[T] + + object Base { + @main + def pureVariadic(nums: Wrapper[Leftover[Int]]) = nums.unwrap.value.sum + + @main + def mixedVariadic(@arg(short = 'f') first: Int, args: Wrapper[Leftover[String]]) = { + first + args.unwrap.value.mkString + } + @main + def mixedVariadicWithDefault( + @arg(short = 'f') first: Int = 1337, + args: Wrapper[Leftover[String]] + ) = { + first + args.unwrap.value.mkString + } + } + + val check = new Checker(ParserForMethods(Base), allowPositional = true) + val isNewVarargsTests = true +} diff --git a/readme.md b/readme.md index 23ea70b..898dd3b 100644 --- a/readme.md +++ b/readme.md @@ -310,7 +310,7 @@ customize your usage: - `doc: String`: a documentation string used to provide additional information about the command -- `isHidden: Boolean`: if `true` this arg will not be included in the rendered help text. +- `hidden: Boolean`: if `true` this arg will not be included in the rendered help text. ## Customization