From a39c141ee932e5e250f323476bb48359f9076c20 Mon Sep 17 00:00:00 2001 From: Scala Steward Date: Sun, 25 Feb 2024 18:48:30 +0000 Subject: [PATCH 1/3] Update scalafmt-core to 3.8.0 --- .scalafmt.conf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.scalafmt.conf b/.scalafmt.conf index 4e8575e27..80a98133c 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,4 +1,4 @@ -version = "3.5.9" +version = "3.8.0" runner.dialect = scala212 fileOverride { "glob:**/core/src/test/scala/**" { From 12bd519e9e132dd6ce125b4af67f5af9e4eda59e Mon Sep 17 00:00:00 2001 From: Scala Steward Date: Sun, 25 Feb 2024 18:49:34 +0000 Subject: [PATCH 2/3] Reformat with scalafmt 3.8.0 Executed command: scalafmt --non-interactive --- .../main/scala/org/bykn/bosatsu/Macro.scala | 20 +- .../scala/org/bykn/bosatsu/TestBench.scala | 31 +- .../scala/org/bykn/bosatsu/PathModule.scala | 51 +- .../org/bykn/bosatsu/TypedExprToProto.scala | 1226 +++++++---- .../scala/org/bykn/bosatsu/JsonTest.scala | 13 +- .../org/bykn/bosatsu/PathModuleTest.scala | 160 +- .../org/bykn/bosatsu/TestProtoType.scala | 134 +- .../bosatsu/codegen/python/CodeTest.scala | 217 +- .../codegen/python/PythonGenTest.scala | 76 +- .../src/main/scala/org/bykn/bosatsu/Par.scala | 12 +- .../src/main/scala/org/bykn/bosatsu/Par.scala | 15 +- .../org/bykn/bosatsu/BindingStatement.scala | 10 +- .../org/bykn/bosatsu/CollectionUtils.scala | 31 +- .../org/bykn/bosatsu/CommentStatement.scala | 24 +- .../scala/org/bykn/bosatsu/Declaration.scala | 1107 ++++++---- .../org/bykn/bosatsu/DefRecursionCheck.scala | 331 +-- .../scala/org/bykn/bosatsu/DefStatement.scala | 18 +- .../scala/org/bykn/bosatsu/EditDistance.scala | 6 +- .../scala/org/bykn/bosatsu/Evaluation.scala | 115 +- .../scala/org/bykn/bosatsu/ExportedName.scala | 155 +- .../main/scala/org/bykn/bosatsu/Expr.scala | 242 ++- .../main/scala/org/bykn/bosatsu/FfiCall.scala | 13 +- .../src/main/scala/org/bykn/bosatsu/Fix.scala | 9 +- .../scala/org/bykn/bosatsu/Identifier.scala | 49 +- .../main/scala/org/bykn/bosatsu/Import.scala | 61 +- .../scala/org/bykn/bosatsu/Indented.scala | 20 +- .../scala/org/bykn/bosatsu/IorMethods.scala | 4 +- .../main/scala/org/bykn/bosatsu/Json.scala | 79 +- .../main/scala/org/bykn/bosatsu/Kind.scala | 26 +- .../scala/org/bykn/bosatsu/KindFormula.scala | 27 +- .../scala/org/bykn/bosatsu/ListLang.scala | 95 +- .../scala/org/bykn/bosatsu/ListUtil.scala | 16 +- .../src/main/scala/org/bykn/bosatsu/Lit.scala | 38 +- .../scala/org/bykn/bosatsu/LocationMap.scala | 73 +- .../scala/org/bykn/bosatsu/MainModule.scala | 25 +- .../scala/org/bykn/bosatsu/Matchless.scala | 472 +++-- .../bykn/bosatsu/MatchlessFromTypedExpr.scala | 47 +- .../org/bykn/bosatsu/MatchlessToValue.scala | 139 +- .../scala/org/bykn/bosatsu/MemoryMain.scala | 64 +- .../scala/org/bykn/bosatsu/NameKind.scala | 29 +- .../scala/org/bykn/bosatsu/Operators.scala | 88 +- .../scala/org/bykn/bosatsu/OptIndent.scala | 24 +- .../main/scala/org/bykn/bosatsu/Package.scala | 405 ++-- .../org/bykn/bosatsu/PackageCustoms.scala | 147 +- .../scala/org/bykn/bosatsu/PackageError.scala | 661 ++++-- .../scala/org/bykn/bosatsu/PackageMap.scala | 491 +++-- .../scala/org/bykn/bosatsu/PackageName.scala | 5 +- .../main/scala/org/bykn/bosatsu/Padding.scala | 18 +- .../org/bykn/bosatsu/ParallelViaProduct.scala | 2 +- .../main/scala/org/bykn/bosatsu/Parser.scala | 217 +- .../main/scala/org/bykn/bosatsu/PathGen.scala | 25 +- .../main/scala/org/bykn/bosatsu/Pattern.scala | 642 +++--- .../main/scala/org/bykn/bosatsu/Predef.scala | 131 +- .../main/scala/org/bykn/bosatsu/Program.scala | 9 +- .../scala/org/bykn/bosatsu/Referant.scala | 112 +- .../scala/org/bykn/bosatsu/SelfCallKind.scala | 7 +- .../org/bykn/bosatsu/SourceConverter.scala | 1709 +++++++++------ .../scala/org/bykn/bosatsu/Statement.scala | 414 ++-- .../scala/org/bykn/bosatsu/StringUtil.scala | 78 +- .../main/scala/org/bykn/bosatsu/Test.scala | 74 +- .../org/bykn/bosatsu/TotalityCheck.scala | 737 ++++--- .../scala/org/bykn/bosatsu/TypeParser.scala | 94 +- .../main/scala/org/bykn/bosatsu/TypeRef.scala | 120 +- .../org/bykn/bosatsu/TypeRefConverter.scala | 68 +- .../scala/org/bykn/bosatsu/TypedExpr.scala | 1056 ++++++---- .../bykn/bosatsu/TypedExprNormalization.scala | 522 +++-- .../org/bykn/bosatsu/UnusedLetCheck.scala | 81 +- .../main/scala/org/bykn/bosatsu/Value.scala | 140 +- .../scala/org/bykn/bosatsu/ValueToDoc.scala | 162 +- .../scala/org/bykn/bosatsu/ValueToJson.scala | 570 +++--- .../scala/org/bykn/bosatsu/Variance.scala | 49 +- .../bykn/bosatsu/codegen/python/Code.scala | 330 +-- .../bosatsu/codegen/python/PythonGen.scala | 1289 +++++++----- .../scala/org/bykn/bosatsu/graph/Dag.scala | 3 +- .../org/bykn/bosatsu/graph/Memoize.scala | 45 +- .../scala/org/bykn/bosatsu/graph/Paths.scala | 37 +- .../org/bykn/bosatsu/graph/Toposort.scala | 42 +- .../scala/org/bykn/bosatsu/graph/Tree.scala | 39 +- .../org/bykn/bosatsu/pattern/Matcher.scala | 17 +- .../bosatsu/pattern/NamedSeqPattern.scala | 109 +- .../org/bykn/bosatsu/pattern/SeqPart.scala | 34 +- .../org/bykn/bosatsu/pattern/SeqPattern.scala | 585 +++--- .../org/bykn/bosatsu/pattern/Splitter.scala | 41 +- .../org/bykn/bosatsu/rankn/DataRepr.scala | 9 +- .../org/bykn/bosatsu/rankn/DefinedType.scala | 73 +- .../scala/org/bykn/bosatsu/rankn/Infer.scala | 1824 ++++++++++------- .../bykn/bosatsu/rankn/ParsedTypeEnv.scala | 11 +- .../scala/org/bykn/bosatsu/rankn/Ref.scala | 52 +- .../scala/org/bykn/bosatsu/rankn/Type.scala | 745 ++++--- .../org/bykn/bosatsu/rankn/TypeEnv.scala | 136 +- .../main/scala/org/bykn/bosatsu/set/Rel.scala | 17 +- .../org/bykn/bosatsu/set/Relatable.scala | 237 ++- .../scala/org/bykn/bosatsu/set/SetOps.scala | 126 +- .../bykn/bosatsu/CollectionUtilsTest.scala | 9 +- .../org/bykn/bosatsu/DeclarationTest.scala | 142 +- .../bykn/bosatsu/DefRecursionCheckTest.scala | 8 +- .../org/bykn/bosatsu/EvaluationTest.scala | 1683 ++++++++++----- .../scala/org/bykn/bosatsu/FreeVarTest.scala | 19 +- .../src/test/scala/org/bykn/bosatsu/Gen.scala | 1178 +++++++---- .../test/scala/org/bykn/bosatsu/GenJson.scala | 34 +- .../scala/org/bykn/bosatsu/GenValue.scala | 5 +- .../test/scala/org/bykn/bosatsu/IntLaws.scala | 71 +- .../scala/org/bykn/bosatsu/JsonTest.scala | 82 +- .../org/bykn/bosatsu/KindFormulaTest.scala | 14 +- .../org/bykn/bosatsu/KindParseTest.scala | 20 +- .../scala/org/bykn/bosatsu/ListUtilTest.scala | 49 +- .../test/scala/org/bykn/bosatsu/LitTest.scala | 32 +- .../org/bykn/bosatsu/LocationMapTest.scala | 21 +- .../org/bykn/bosatsu/MatchlessTests.scala | 75 +- .../scala/org/bykn/bosatsu/MonadGen.scala | 2 +- .../scala/org/bykn/bosatsu/OperatorTest.scala | 51 +- .../scala/org/bykn/bosatsu/PackageTest.scala | 51 +- .../test/scala/org/bykn/bosatsu/ParTest.scala | 2 +- .../scala/org/bykn/bosatsu/ParserTest.scala | 1272 ++++++++---- .../scala/org/bykn/bosatsu/PatternTest.scala | 39 +- .../org/bykn/bosatsu/SelfCallKindTest.scala | 28 +- .../bykn/bosatsu/SourceConverterTest.scala | 64 +- .../scala/org/bykn/bosatsu/TestUtils.scala | 148 +- .../scala/org/bykn/bosatsu/TotalityTest.scala | 436 ++-- .../scala/org/bykn/bosatsu/TypeRefTest.scala | 19 +- .../org/bykn/bosatsu/TypedExprTest.scala | 341 +-- .../scala/org/bykn/bosatsu/ValueTest.scala | 11 +- .../org/bykn/bosatsu/ValueToDocTest.scala | 62 +- .../scala/org/bykn/bosatsu/VarianceTest.scala | 25 +- .../codegen/python/PythonGenTest.scala | 20 +- .../org/bykn/bosatsu/graph/ToposortTest.scala | 42 +- .../org/bykn/bosatsu/graph/TreeTest.scala | 34 +- .../bykn/bosatsu/pattern/SeqPatternTest.scala | 401 ++-- .../pattern/StringSeqPatternSetLaws.scala | 113 +- .../org/bykn/bosatsu/rankn/NTypeGen.scala | 84 +- .../bykn/bosatsu/rankn/RankNInferTest.scala | 904 +++++--- .../org/bykn/bosatsu/rankn/TypeTest.scala | 333 ++- .../scala/org/bykn/bosatsu/set/RelLaws.scala | 56 +- .../org/bykn/bosatsu/set/SetOpsLaws.scala | 171 +- .../scala/org/bykn/bosatsu/jsapi/JsApi.scala | 48 +- .../scala/org/bykn/bosatsu/jsui/Action.scala | 5 +- .../scala/org/bykn/bosatsu/jsui/App.scala | 4 +- .../scala/org/bykn/bosatsu/jsui/Store.scala | 87 +- 138 files changed, 17647 insertions(+), 10587 deletions(-) diff --git a/base/src/main/scala/org/bykn/bosatsu/Macro.scala b/base/src/main/scala/org/bykn/bosatsu/Macro.scala index 6e09b1740..623b5f2f6 100644 --- a/base/src/main/scala/org/bykn/bosatsu/Macro.scala +++ b/base/src/main/scala/org/bykn/bosatsu/Macro.scala @@ -15,14 +15,15 @@ class Macro(val c: Context) { if (f.exists()) { val res = Source.fromFile(s, "UTF-8").getLines().mkString("\n") Some(c.Expr[String](q"$res")) - } - else { + } else { None } - } - catch { + } catch { case NonFatal(err) => - c.abort(c.enclosingPosition, s"could not read existing file: $s. Exception: $err") + c.abort( + c.enclosingPosition, + s"could not read existing file: $s. Exception: $err" + ) } file.tree match { @@ -34,11 +35,14 @@ class Macro(val c: Context) { .getOrElse { c.abort( c.enclosingPosition, - s"no file found at: $s. working directory is ${System.getProperty("user.dir")}") + s"no file found at: $s. working directory is ${System.getProperty("user.dir")}" + ) } case otherTree => - c.abort(c.enclosingPosition, s"expected string literal, found: $otherTree") + c.abort( + c.enclosingPosition, + s"expected string literal, found: $otherTree" + ) } } } - diff --git a/bench/src/main/scala/org/bykn/bosatsu/TestBench.scala b/bench/src/main/scala/org/bykn/bosatsu/TestBench.scala index 9ceedc3d1..b1aa3233e 100644 --- a/bench/src/main/scala/org/bykn/bosatsu/TestBench.scala +++ b/bench/src/main/scala/org/bykn/bosatsu/TestBench.scala @@ -12,7 +12,10 @@ class TestBench { // don't use threads in the benchmark which will complicate matters import DirectEC.directEC - private def prepPackages(packages: List[String], mainPackS: String): (PackageMap.Inferred, PackageName) = { + private def prepPackages( + packages: List[String], + mainPackS: String + ): (PackageMap.Inferred, PackageName) = { val mainPack = PackageName.parse(mainPackS).get val parsed = packages.zipWithIndex.traverse { case (pack, i) => @@ -28,11 +31,18 @@ class TestBench { val d = p.showContext(LocationMap.Colorize.None) System.err.println(d.render(100)) } - sys.error("failed to parse") //errs.toString) + sys.error("failed to parse") // errs.toString) } - implicit val show: Show[(String, LocationMap)] = Show.show { case (s, _) => s } - PackageMap.resolveThenInfer(PackageMap.withPredefA(("predef", LocationMap("")), parsedPaths), Nil).strictToValidated match { + implicit val show: Show[(String, LocationMap)] = Show.show { case (s, _) => + s + } + PackageMap + .resolveThenInfer( + PackageMap.withPredefA(("predef", LocationMap("")), parsedPaths), + Nil + ) + .strictToValidated match { case Validated.Valid(packMap) => (packMap, mainPack) case other => sys.error(s"expected clean compilation: $other") @@ -40,11 +50,13 @@ class TestBench { } def gauss(n: Int) = prepPackages( - List(s""" + List(s""" package Gauss gauss$n = range($n).foldLeft(0, add) -"""), "Gauss") +"""), + "Gauss" + ) val compiled0: (PackageMap.Inferred, PackageName) = gauss(10) @@ -69,7 +81,8 @@ gauss$n = range($n).foldLeft(0, add) } val compiled2 = - prepPackages(List(""" + prepPackages( + List(""" package Euler4 def operator >(a, b): @@ -132,7 +145,9 @@ max_pal_opt = max_of(99, \n1 -> first_of(99, product_palindrome(n1))) max_pal = match max_pal_opt: Some(m): m None: 0 -"""), "Euler4") +"""), + "Euler4" + ) @Benchmark def bench2(): Unit = { val c = compiled2 diff --git a/cli/src/main/scala/org/bykn/bosatsu/PathModule.scala b/cli/src/main/scala/org/bykn/bosatsu/PathModule.scala index 402718130..2b61b90f2 100644 --- a/cli/src/main/scala/org/bykn/bosatsu/PathModule.scala +++ b/cli/src/main/scala/org/bykn/bosatsu/PathModule.scala @@ -39,7 +39,10 @@ object PathModule extends MainModule[IO] { def readInterfaces(paths: List[Path]): IO[List[Package.Interface]] = ProtoConverter.readInterfaces(paths) - def writeInterfaces(interfaces: List[Package.Interface], path: Path): IO[Unit] = + def writeInterfaces( + interfaces: List[Package.Interface], + path: Path + ): IO[Unit] = ProtoConverter.writeInterfaces(interfaces, path) def writePackages[A](packages: List[Package.Typed[A]], path: Path): IO[Unit] = @@ -54,14 +57,14 @@ object PathModule extends MainModule[IO] { Some(IO { f.listFiles.iterator.map(_.toPath).toList }) - } - else None + } else None } } } - def hasExtension(str: String): Path => Boolean = - { (path: Path) => path.toString.endsWith(str) } + def hasExtension(str: String): Path => Boolean = { (path: Path) => + path.toString.endsWith(str) + } def print(str: => String): IO[Unit] = IO(println(str)) @@ -71,7 +74,7 @@ object PathModule extends MainModule[IO] { def report(io: IO[Output]): IO[ExitCode] = io.attempt.flatMap { case Right(out) => reportOutput(out) - case Left(err) => reportException(err).as(ExitCode.Error) + case Left(err) => reportException(err).as(ExitCode.Error) } def reportOutput(out: Output): IO[ExitCode] = @@ -84,25 +87,27 @@ object PathModule extends MainModule[IO] { print(testReport.doc.render(80)).as(code) case Output.EvaluationResult(_, tpe, resDoc) => val tDoc = rankn.Type.fullyResolvedDocument.document(tpe) - val doc = resDoc.value + (Doc.lineOrEmpty + Doc.text(": ") + tDoc).nested(4) + val doc = + resDoc.value + (Doc.lineOrEmpty + Doc.text(": ") + tDoc).nested(4) print(doc.render(100)).as(ExitCode.Success) case Output.JsonOutput(json, pathOpt) => val jdoc = json.toDoc (pathOpt match { case Some(path) => CodeGenWrite.writeDoc(path, jdoc) - case None => IO(println(jdoc.renderTrim(80))) + case None => IO(println(jdoc.renderTrim(80))) }).as(ExitCode.Success) case Output.TranspileOut(outs, base) => def path(p: List[String]): Path = p.foldLeft(base)(_.resolve(_)) - outs.toList.map { case (p, d) => - (p, CodeGenWrite.writeDoc(path(p.toList), d)) - } - .sortBy(_._1) - .traverse_ { case (_, w) => w } - .as(ExitCode.Success) + outs.toList + .map { case (p, d) => + (p, CodeGenWrite.writeDoc(path(p.toList), d)) + } + .sortBy(_._1) + .traverse_ { case (_, w) => w } + .as(ExitCode.Success) case Output.CompileOut(packList, ifout, output) => val ifres = ifout match { @@ -128,15 +133,16 @@ object PathModule extends MainModule[IO] { output match { case None => IO.blocking { - doc.renderStreamTrim(80) + doc + .renderStreamTrim(80) .iterator .foreach(System.out.print) System.out.println("") - } - .as(ExitCode.Success) + }.as(ExitCode.Success) case Some(p) => - CodeGenWrite.writeDoc(p, doc) + CodeGenWrite + .writeDoc(p, doc) .as(ExitCode.Success) } } @@ -154,7 +160,8 @@ object PathModule extends MainModule[IO] { import scala.jdk.CollectionConverters._ def getP(p: Path): Option[PackageName] = { - val subPath = p.relativize(packFile) + val subPath = p + .relativize(packFile) .asScala .map { part => part.toString.toLowerCase.capitalize @@ -164,7 +171,7 @@ object PathModule extends MainModule[IO] { val dropExtension = """(.*)\.[^.]*$""".r val toParse = subPath match { case dropExtension(prefix) => prefix - case _ => subPath + case _ => subPath } PackageName.parse(toParse) } @@ -172,9 +179,9 @@ object PathModule extends MainModule[IO] { @annotation.tailrec def loop(roots: List[Path]): Option[PackageName] = roots match { - case Nil => None + case Nil => None case h :: _ if packFile.startsWith(h) => getP(h) - case _ :: t => loop(t) + case _ :: t => loop(t) } if (packFile.toString.isEmpty) None diff --git a/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala b/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala index ca4879f62..ce1c6037b 100644 --- a/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala +++ b/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala @@ -6,7 +6,12 @@ import cats.data.{NonEmptyList, ReaderT, StateT} import cats.effect.IO import org.bykn.bosatsu.graph.Memoize import java.nio.file.Path -import java.io.{FileInputStream, FileOutputStream, BufferedInputStream, BufferedOutputStream} +import java.io.{ + FileInputStream, + FileOutputStream, + BufferedInputStream, + BufferedOutputStream +} import org.bykn.bosatsu.rankn.{DefinedType, Type, TypeEnv} import scala.util.{Failure, Success, Try} import scala.reflect.ClassTag @@ -17,9 +22,8 @@ import Identifier.{Bindable, Constructor} import cats.implicits._ -/** - * convert TypedExpr to and from Protobuf representation - */ +/** convert TypedExpr to and from Protobuf representation + */ object ProtoConverter { case class IdAssignment[A1, A2](mapping: Map[A1, Int], inOrder: Vector[A2]) { def get(a1: A1, a2: => A2): Either[(IdAssignment[A1, A2], Int), Int] = @@ -27,9 +31,8 @@ object ProtoConverter { case Some(id) => Right(id) case None => val id = inOrder.size - val next = copy( - mapping = mapping.updated(a1, id), - inOrder = inOrder :+ a2) + val next = + copy(mapping = mapping.updated(a1, id), inOrder = inOrder :+ a2) Left((next, id)) } @@ -38,25 +41,42 @@ object ProtoConverter { } object IdAssignment { - def empty[A1, A2]: IdAssignment[A1, A2] = IdAssignment(Map.empty, Vector.empty) + def empty[A1, A2]: IdAssignment[A1, A2] = + IdAssignment(Map.empty, Vector.empty) } case class SerState( - strings: IdAssignment[String, String], - types: IdAssignment[Type, proto.Type], - patterns: IdAssignment[Pattern[(PackageName, Constructor), Type], proto.Pattern], - expressions: IdAssignment[TypedExpr[Any], proto.TypedExpr]) { + strings: IdAssignment[String, String], + types: IdAssignment[Type, proto.Type], + patterns: IdAssignment[ + Pattern[(PackageName, Constructor), Type], + proto.Pattern + ], + expressions: IdAssignment[TypedExpr[Any], proto.TypedExpr] + ) { def stringId(s: String): Either[(SerState, Int), Int] = - strings.get(s, s).left.map { case (next, id) => (copy(strings = next), id) } + strings.get(s, s).left.map { case (next, id) => + (copy(strings = next), id) + } - def typeId(t: Type, protoType: => proto.Type): Either[(SerState, Int), Int] = - types.get(t, protoType).left.map { case (next, id) => (copy(types = next), id) } + def typeId( + t: Type, + protoType: => proto.Type + ): Either[(SerState, Int), Int] = + types.get(t, protoType).left.map { case (next, id) => + (copy(types = next), id) + } } object SerState { val empty: SerState = - SerState(IdAssignment.empty, IdAssignment.empty, IdAssignment.empty, IdAssignment.empty) + SerState( + IdAssignment.empty, + IdAssignment.empty, + IdAssignment.empty, + IdAssignment.empty + ) } type Tab[A] = StateT[Try, SerState, A] @@ -73,7 +93,7 @@ object ProtoConverter { case Failure(_) => System.err.println(message) self - } + } } private def tabFail[S, A](ex: Exception): Tab[A] = @@ -82,13 +102,14 @@ object ProtoConverter { Monad[Tab].pure(a) private def get(fn: SerState => Either[(SerState, Int), Int]): Tab[Int] = - StateT.get[Try, SerState] + StateT + .get[Try, SerState] .flatMap { ss => fn(ss) match { - case Right(idx) => StateT.pure(idx + 1) - case Left((ss, idx)) => - StateT.set[Try, SerState](ss).as(idx + 1) - } + case Right(idx) => StateT.pure(idx + 1) + case Left((ss, idx)) => + StateT.set[Try, SerState](ss).as(idx + 1) + } } private def getId(s: String): Tab[Int] = get(_.stringId(s)) @@ -97,11 +118,16 @@ object ProtoConverter { get(_.typeId(t, pt)) private def getProtoTypeTab(t: Type): Tab[Option[Int]] = - StateT.get[Try, SerState] + StateT + .get[Try, SerState] .map(_.types.indexOf(t).map(_ + 1)) - private def writePattern(p: Pattern[(PackageName, Constructor), Type], pp: proto.Pattern): Tab[Int] = - StateT.get[Try, SerState] + private def writePattern( + p: Pattern[(PackageName, Constructor), Type], + pp: proto.Pattern + ): Tab[Int] = + StateT + .get[Try, SerState] .flatMap { s => s.patterns.get(p, pp) match { case Right(_) => @@ -113,7 +139,8 @@ object ProtoConverter { } private def writeExpr(te: TypedExpr[Any], pte: proto.TypedExpr): Tab[Int] = - StateT.get[Try, SerState] + StateT + .get[Try, SerState] .flatMap { s => s.expressions.get(te, pte) match { case Right(_) => @@ -128,11 +155,12 @@ object ProtoConverter { t.run(SerState.empty) class DecodeState private ( - strings: Array[String], - types: Array[Type], - dts: Array[DefinedType[Kind.Arg]], - patterns: Array[Pattern[(PackageName, Constructor), Type]], - expr: Array[TypedExpr[Unit]]) { + strings: Array[String], + types: Array[Type], + dts: Array[DefinedType[Kind.Arg]], + patterns: Array[Pattern[(PackageName, Constructor), Type]], + expr: Array[TypedExpr[Unit]] + ) { def getString(idx: Int): Option[String] = if ((0 <= idx) && (idx < strings.length)) Some(strings(idx)) else None @@ -149,7 +177,10 @@ object ProtoConverter { if ((0 <= idx) && (idx < types.length)) Success(types(idx)) else Failure(new Exception(msg)) - def tryPattern(idx: Int, msg: => String): Try[Pattern[(PackageName, Constructor), Type]] = + def tryPattern( + idx: Int, + msg: => String + ): Try[Pattern[(PackageName, Constructor), Type]] = if ((0 <= idx) && (idx < patterns.length)) Success(patterns(idx)) else Failure(new Exception(msg)) @@ -170,7 +201,9 @@ object ProtoConverter { def withTypes(ary: Array[Type]): DecodeState = new DecodeState(strings, ary, dts, patterns, expr) - def withPatterns(ary: Array[Pattern[(PackageName, Constructor), Type]]): DecodeState = + def withPatterns( + ary: Array[Pattern[(PackageName, Constructor), Type]] + ): DecodeState = new DecodeState(strings, types, dts, ary, expr) def withExprs(ary: Array[TypedExpr[Unit]]): DecodeState = @@ -179,12 +212,20 @@ object ProtoConverter { object DecodeState { def init(strings: Seq[String]): DecodeState = - new DecodeState(strings.toArray, Array.empty, Array.empty, Array.empty, Array.empty) + new DecodeState( + strings.toArray, + Array.empty, + Array.empty, + Array.empty, + Array.empty + ) } type DTab[A] = ReaderT[Try, DecodeState, A] - private def find[A](idx: Int, context: => String)(fn: (DecodeState, Int) => Option[A]): DTab[A] = + private def find[A](idx: Int, context: => String)( + fn: (DecodeState, Int) => Option[A] + ): DTab[A] = ReaderT { decodeState => fn(decodeState, idx - 1) match { case Some(s) => Success(s) @@ -198,29 +239,37 @@ object ProtoConverter { private def lookupType(idx: Int, context: => String): DTab[Type] = find(idx, context)(_.getType(_)) - private def lookupDts(idx: Int, context: => String): DTab[DefinedType[Kind.Arg]] = + private def lookupDts( + idx: Int, + context: => String + ): DTab[DefinedType[Kind.Arg]] = find(idx, context)(_.getDt(_)) private def lookupExpr(idx: Int, context: => String): DTab[TypedExpr[Unit]] = find(idx, context)(_.getExpr(_)) - /** - * this is code to build tables of serialized dags. We use this for types, patterns, expressions - */ - private def buildTable[A, B: ClassTag](ary: Array[A])(fn: (A, Int => Try[B]) => Try[B]): Try[Array[B]] = { + /** this is code to build tables of serialized dags. We use this for types, + * patterns, expressions + */ + private def buildTable[A, B: ClassTag]( + ary: Array[A] + )(fn: (A, Int => Try[B]) => Try[B]): Try[Array[B]] = { val result = new Array[B](ary.length) def lookup(a: A, max: Int): Int => Try[B] = { idx => if (idx > 0 && idx <= max) Success(result(idx - 1)) - else Failure(new Exception(s"while decoding $a, invalid index $idx, max: $max")) + else + Failure( + new Exception(s"while decoding $a, invalid index $idx, max: $max") + ) } var idx = 0 var res: Failure[Array[B]] = null - while((idx < ary.length) && (res eq null)) { + while ((idx < ary.length) && (res eq null)) { val a = ary(idx) val lookupFn = lookup(a, idx) fn(a, lookupFn) match { - case Success(b) => result(idx) = b + case Success(b) => result(idx) = b case Failure(err) => res = Failure(err) } idx = idx + 1 @@ -231,7 +280,6 @@ object ProtoConverter { def buildTypes(types: Seq[proto.Type]): DTab[Array[Type]] = ReaderT[Try, DecodeState, Array[Type]] { ds => - def typeFromProto(p: proto.Type, tpe: Int => Try[Type]): Try[Type] = { import proto.Type.Value import bosatsu.TypedAst.{Type => _, _} @@ -275,10 +323,18 @@ object ProtoConverter { buildTable(types.toArray)(typeFromProto _) } - def buildPatterns(pats: Seq[proto.Pattern]): DTab[Array[Pattern[(PackageName, Constructor), Type]]] = - ReaderT[Try, DecodeState, Array[Pattern[(PackageName, Constructor), Type]]] { ds => - - def patternFromProto(p: proto.Pattern, pat: Int => Try[Pattern[(PackageName, Constructor), Type]]): Try[Pattern[(PackageName, Constructor), Type]] = { + def buildPatterns( + pats: Seq[proto.Pattern] + ): DTab[Array[Pattern[(PackageName, Constructor), Type]]] = + ReaderT[ + Try, + DecodeState, + Array[Pattern[(PackageName, Constructor), Type]] + ] { ds => + def patternFromProto( + p: proto.Pattern, + pat: Int => Try[Pattern[(PackageName, Constructor), Type]] + ): Try[Pattern[(PackageName, Constructor), Type]] = { import proto.Pattern.Value def str(i: Int): Try[String] = @@ -296,24 +352,36 @@ object ProtoConverter { case Value.NamedPat(proto.NamedPat(nidx, pidx, _)) => (bindable(nidx), pat(pidx)).mapN(Pattern.Named(_, _)) case Value.ListPat(proto.ListPat(lp, _)) => - def decodePart(part: proto.ListPart): Try[Pattern.ListPart[Pattern[(PackageName, Constructor), Type]]] = + def decodePart(part: proto.ListPart): Try[ + Pattern.ListPart[Pattern[(PackageName, Constructor), Type]] + ] = part.value match { - case proto.ListPart.Value.Empty => Failure(new Exception(s"invalid empty list pattern in $p")) - case proto.ListPart.Value.ItemPattern(p) => pat(p).map(Pattern.ListPart.Item(_)) - case proto.ListPart.Value.UnnamedList(_) => Success(Pattern.ListPart.WildList) - case proto.ListPart.Value.NamedList(idx) => bindable(idx).map { n => Pattern.ListPart.NamedList(n) } + case proto.ListPart.Value.Empty => + Failure(new Exception(s"invalid empty list pattern in $p")) + case proto.ListPart.Value.ItemPattern(p) => + pat(p).map(Pattern.ListPart.Item(_)) + case proto.ListPart.Value.UnnamedList(_) => + Success(Pattern.ListPart.WildList) + case proto.ListPart.Value.NamedList(idx) => + bindable(idx).map { n => Pattern.ListPart.NamedList(n) } } lp.toList.traverse(decodePart).map(Pattern.ListPat(_)) case Value.StrPat(proto.StrPat(items, _)) => def decodePart(part: proto.StrPart): Try[Pattern.StrPart] = part.value match { - case proto.StrPart.Value.Empty => Failure(new Exception(s"invalid empty list pattern in $p")) - case proto.StrPart.Value.LiteralStr(idx) => str(idx).map(Pattern.StrPart.LitStr(_)) - case proto.StrPart.Value.UnnamedStr(_) => Success(Pattern.StrPart.WildStr) - case proto.StrPart.Value.NamedStr(idx) => bindable(idx).map { n => Pattern.StrPart.NamedStr(n) } - case proto.StrPart.Value.UnnamedChar(_) => Success(Pattern.StrPart.WildChar) - case proto.StrPart.Value.NamedChar(idx) => bindable(idx).map { n => Pattern.StrPart.NamedChar(n) } + case proto.StrPart.Value.Empty => + Failure(new Exception(s"invalid empty list pattern in $p")) + case proto.StrPart.Value.LiteralStr(idx) => + str(idx).map(Pattern.StrPart.LitStr(_)) + case proto.StrPart.Value.UnnamedStr(_) => + Success(Pattern.StrPart.WildStr) + case proto.StrPart.Value.NamedStr(idx) => + bindable(idx).map { n => Pattern.StrPart.NamedStr(n) } + case proto.StrPart.Value.UnnamedChar(_) => + Success(Pattern.StrPart.WildChar) + case proto.StrPart.Value.NamedChar(idx) => + bindable(idx).map { n => Pattern.StrPart.NamedChar(n) } } items.toList match { @@ -326,12 +394,17 @@ object ProtoConverter { .map(Pattern.StrPat(_)) } case Value.AnnotationPat(proto.AnnotationPat(pidx, tidx, _)) => - (pat(pidx), ds.tryType(tidx - 1, s"invalid type index $tidx in: $p")) + ( + pat(pidx), + ds.tryType(tidx - 1, s"invalid type index $tidx in: $p") + ) .mapN(Pattern.Annotation(_, _)) case Value.StructPat(proto.StructPattern(packIdx, cidx, args, _)) => str(packIdx) .product(str(cidx)) - .flatMap { case (p, c) => fullNameFromStr(p, c, s"invalid structpat names: $p, $c") } + .flatMap { case (p, c) => + fullNameFromStr(p, c, s"invalid structpat names: $p, $c") + } .flatMap { pc => args.toList.traverse(pat).map(Pattern.PositionalStruct(pc, _)) } @@ -344,7 +417,11 @@ object ProtoConverter { } case notTwo => - Failure(new Exception(s"invalid union found size: ${notTwo.size}, expected 2 or more")) + Failure( + new Exception( + s"invalid union found size: ${notTwo.size}, expected 2 or more" + ) + ) } } } @@ -352,16 +429,23 @@ object ProtoConverter { buildTable(pats.toArray)(patternFromProto _) } - def recursionKindFromProto(rec: proto.RecursionKind, context: => String): Try[RecursionKind] = + def recursionKindFromProto( + rec: proto.RecursionKind, + context: => String + ): Try[RecursionKind] = rec match { case proto.RecursionKind.NotRec => Success(RecursionKind.NonRecursive) - case proto.RecursionKind.IsRec => Success(RecursionKind.Recursive) - case other => Failure(new Exception(s"invalid recursion kind: $other, in $context")) + case proto.RecursionKind.IsRec => Success(RecursionKind.Recursive) + case other => + Failure(new Exception(s"invalid recursion kind: $other, in $context")) } def buildExprs(exprs: Seq[proto.TypedExpr]): DTab[Array[TypedExpr[Unit]]] = ReaderT[Try, DecodeState, Array[TypedExpr[Unit]]] { ds => - def expressionFromProto(ex: proto.TypedExpr, exprOf: Int => Try[TypedExpr[Unit]]): Try[TypedExpr[Unit]] = { + def expressionFromProto( + ex: proto.TypedExpr, + exprOf: Int => Try[TypedExpr[Unit]] + ): Try[TypedExpr[Unit]] = { import proto.TypedExpr.Value def str(i: Int): Try[String] = @@ -391,40 +475,51 @@ object ProtoConverter { (faList, exList, exprOf(expr)) .mapN { (fa, ex, e) => - Type.Quantification.fromLists(forallList = fa.toList, existList = ex.toList) match { + Type.Quantification.fromLists( + forallList = fa.toList, + existList = ex.toList + ) match { case Some(q) => TypedExpr.Generic(q, e) - case None => e - } + case None => e + } } case Value.AnnotationExpr(proto.AnnotationExpr(expr, tpe, _)) => (exprOf(expr), typeOf(tpe)) .mapN(TypedExpr.Annotation(_, _)) case Value.LambdaExpr(proto.LambdaExpr(varsName, varsTpe, expr, _)) => - (varsName.traverse(bindable(_)), varsTpe.traverse(typeOf(_)), exprOf(expr)) + ( + varsName.traverse(bindable(_)), + varsTpe.traverse(typeOf(_)), + exprOf(expr) + ) .flatMapN { (vs, ts, e) => val vsLen = vs.length if (vsLen <= 0) { Failure(new Exception(s"no bind names in this lambda: $ex")) - } - else if (vsLen == ts.length) { + } else if (vsLen == ts.length) { // we know length > 0 and they match - val args = NonEmptyList.fromListUnsafe(vs.iterator.zip(ts.iterator).toList) + val args = NonEmptyList.fromListUnsafe( + vs.iterator.zip(ts.iterator).toList + ) Success(TypedExpr.AnnotatedLambda(args, e, ())) - } - else { - Failure(new Exception(s"type list length didn't match bind name length in $ex")) + } else { + Failure( + new Exception( + s"type list length didn't match bind name length in $ex" + ) + ) } } case Value.VarExpr(proto.VarExpr(pack, varname, tpe, _)) => val tryPack = if (pack == 0) Success(None) - else for { - ps <- str(pack) - pack <- parsePack(ps, s"expression: $ex") - } yield Some(pack) + else + for { + ps <- str(pack) + pack <- parsePack(ps, s"expression: $ex") + } yield Some(pack) - (tryPack, typeOf(tpe)) - .tupled + (tryPack, typeOf(tpe)).tupled .flatMap { case (None, tpe) => bindable(varname).map(TypedExpr.Local(_, tpe, ())) @@ -447,14 +542,20 @@ object ProtoConverter { .mapN(TypedExpr.Let(_, _, _, _, ())) case Value.LiteralExpr(proto.LiteralExpr(lit, tpe, _)) => lit match { - case None => Failure(new Exception(s"invalid missing literal in $ex")) + case None => + Failure(new Exception(s"invalid missing literal in $ex")) case Some(lit) => (litFromProto(lit), typeOf(tpe)) .mapN(TypedExpr.Literal(_, _, ())) } case Value.MatchExpr(proto.MatchExpr(argId, branches, _)) => - def buildBranch(b: proto.Branch): Try[(Pattern[(PackageName, Constructor), Type], TypedExpr[Unit])] = - (ds.tryPattern(b.pattern - 1, s"invalid pattern in $ex"), exprOf(b.resultExpr)).tupled + def buildBranch(b: proto.Branch): Try[ + (Pattern[(PackageName, Constructor), Type], TypedExpr[Unit]) + ] = + ( + ds.tryPattern(b.pattern - 1, s"invalid pattern in $ex"), + exprOf(b.resultExpr) + ).tupled NonEmptyList.fromList(branches.toList) match { case Some(nel) => @@ -476,11 +577,21 @@ object ProtoConverter { case Some(pack) => Success(pack) } - private def fullNameFromStr(pstr: String, tstr: String, context: => String): Try[(PackageName, Constructor)] = + private def fullNameFromStr( + pstr: String, + tstr: String, + context: => String + ): Try[(PackageName, Constructor)] = (parsePack(pstr, context), toConstructor(tstr)).tupled - def typeConstFromStr(pstr: String, tstr: String, context: => String): Try[Type.Const.Defined] = - fullNameFromStr(pstr, tstr, context).map { case (p, c) => Type.Const.Defined(p, TypeName(c)) } + def typeConstFromStr( + pstr: String, + tstr: String, + context: => String + ): Try[Type.Const.Defined] = + fullNameFromStr(pstr, tstr, context).map { case (p, c) => + Type.Const.Defined(p, TypeName(c)) + } def typeConstFromProto(p: proto.TypeConst): DTab[Type.Const.Defined] = { val proto.TypeConst(packidx, tidx, _) = p @@ -517,28 +628,30 @@ object ProtoConverter { val foralls = q.forallList val exs = q.existList val in = q.in - (foralls.traverse { case (b, k) => varKindToProto(b, k) }, + ( + foralls.traverse { case (b, k) => varKindToProto(b, k) }, exs.traverse { case (b, k) => varKindToProto(b, k) }, - typeToProto(in)) - .flatMapN { (faids, exids, idx) => - val ft0 = - if (exs.nonEmpty) { - val withEx = Type.exists(exs, in) - getTypeId(withEx, - proto.Type( - Value.TypeExists(TypeExists(exids, idx)))) - } - else tabPure(idx) - - ft0.flatMap { t0 => - if (foralls.nonEmpty) { - getTypeId(p, - proto.Type( - Value.TypeForAll(TypeForAll(faids, t0)))) - } - else tabPure(t0) - } + typeToProto(in) + ) + .flatMapN { (faids, exids, idx) => + val ft0 = + if (exs.nonEmpty) { + val withEx = Type.exists(exs, in) + getTypeId( + withEx, + proto.Type(Value.TypeExists(TypeExists(exids, idx))) + ) + } else tabPure(idx) + + ft0.flatMap { t0 => + if (foralls.nonEmpty) { + getTypeId( + p, + proto.Type(Value.TypeForAll(TypeForAll(faids, t0))) + ) + } else tabPure(t0) } + } case Type.TyApply(on, arg) => typeToProto(on) .product(typeToProto(arg)) @@ -565,8 +678,7 @@ object ProtoConverter { case Lit.Integer(i) => try { proto.Literal.Value.IntValueAs64(i.longValueExact) - } - catch { + } catch { case _: ArithmeticException => proto.Literal.Value.IntValueAsString(i.toString) } @@ -593,84 +705,132 @@ object ProtoConverter { } def patternToProto(p: Pattern[(PackageName, Constructor), Type]): Tab[Int] = - StateT.get[Try, SerState] + StateT + .get[Try, SerState] .map(_.patterns.indexOf(p)) .flatMap { case Some(idx) => tabPure(idx + 1) case None => p match { case Pattern.WildCard => - writePattern(p, proto.Pattern(proto.Pattern.Value.WildPat(proto.WildCardPat()))) + writePattern( + p, + proto.Pattern(proto.Pattern.Value.WildPat(proto.WildCardPat())) + ) case Pattern.Literal(lit) => val litP = litToProto(lit) writePattern(p, proto.Pattern(proto.Pattern.Value.LitPat(litP))) case Pattern.Var(n) => getId(n.sourceCodeRepr) .flatMap { idx => - writePattern(p, proto.Pattern(proto.Pattern.Value.VarNamePat(idx))) + writePattern( + p, + proto.Pattern(proto.Pattern.Value.VarNamePat(idx)) + ) } - case named@Pattern.Named(n, p) => + case named @ Pattern.Named(n, p) => getId(n.sourceCodeRepr) .product(patternToProto(p)) .flatMap { case (idx, pidx) => - writePattern(named, proto.Pattern(proto.Pattern.Value.NamedPat(proto.NamedPat(idx, pidx)))) + writePattern( + named, + proto.Pattern( + proto.Pattern.Value.NamedPat(proto.NamedPat(idx, pidx)) + ) + ) } case Pattern.StrPat(parts) => - parts.traverse { - case Pattern.StrPart.WildStr => - tabPure(proto.StrPart(proto.StrPart.Value.UnnamedStr(proto.WildCardPat()))) - case Pattern.StrPart.WildChar => - tabPure(proto.StrPart(proto.StrPart.Value.UnnamedChar(proto.WildCardPat()))) - case Pattern.StrPart.NamedStr(n) => - getId(n.sourceCodeRepr).map { idx => - proto.StrPart(proto.StrPart.Value.NamedStr(idx)) - } - case Pattern.StrPart.NamedChar(n) => - getId(n.sourceCodeRepr).map { idx => - proto.StrPart(proto.StrPart.Value.NamedChar(idx)) - } - case Pattern.StrPart.LitStr(s) => - getId(s).map { idx => - proto.StrPart(proto.StrPart.Value.LiteralStr(idx)) - } - } - .flatMap { parts => - writePattern(p, proto.Pattern(proto.Pattern.Value.StrPat(proto.StrPat(parts.toList)))) - } + parts + .traverse { + case Pattern.StrPart.WildStr => + tabPure( + proto.StrPart( + proto.StrPart.Value.UnnamedStr(proto.WildCardPat()) + ) + ) + case Pattern.StrPart.WildChar => + tabPure( + proto.StrPart( + proto.StrPart.Value.UnnamedChar(proto.WildCardPat()) + ) + ) + case Pattern.StrPart.NamedStr(n) => + getId(n.sourceCodeRepr).map { idx => + proto.StrPart(proto.StrPart.Value.NamedStr(idx)) + } + case Pattern.StrPart.NamedChar(n) => + getId(n.sourceCodeRepr).map { idx => + proto.StrPart(proto.StrPart.Value.NamedChar(idx)) + } + case Pattern.StrPart.LitStr(s) => + getId(s).map { idx => + proto.StrPart(proto.StrPart.Value.LiteralStr(idx)) + } + } + .flatMap { parts => + writePattern( + p, + proto.Pattern( + proto.Pattern.Value.StrPat(proto.StrPat(parts.toList)) + ) + ) + } case Pattern.ListPat(items) => - items.traverse { - case Pattern.ListPart.Item(itemPat) => - patternToProto(itemPat).map { pidx => - proto.ListPart(proto.ListPart.Value.ItemPattern(pidx)) - } - case Pattern.ListPart.WildList => - tabPure(proto.ListPart(proto.ListPart.Value.UnnamedList(proto.WildCardPat()))) - case Pattern.ListPart.NamedList(bindable) => - getId(bindable.sourceCodeRepr).map { idx => - proto.ListPart(proto.ListPart.Value.NamedList(idx)) - } - } - .flatMap { parts => - writePattern(p, proto.Pattern(proto.Pattern.Value.ListPat(proto.ListPat(parts)))) - } - case ann@Pattern.Annotation(p, tpe) => + items + .traverse { + case Pattern.ListPart.Item(itemPat) => + patternToProto(itemPat).map { pidx => + proto.ListPart(proto.ListPart.Value.ItemPattern(pidx)) + } + case Pattern.ListPart.WildList => + tabPure( + proto.ListPart( + proto.ListPart.Value.UnnamedList(proto.WildCardPat()) + ) + ) + case Pattern.ListPart.NamedList(bindable) => + getId(bindable.sourceCodeRepr).map { idx => + proto.ListPart(proto.ListPart.Value.NamedList(idx)) + } + } + .flatMap { parts => + writePattern( + p, + proto.Pattern( + proto.Pattern.Value.ListPat(proto.ListPat(parts)) + ) + ) + } + case ann @ Pattern.Annotation(p, tpe) => patternToProto(p) .product(typeToProto(tpe)) .flatMap { case (pidx, tidx) => - writePattern(ann, proto.Pattern(proto.Pattern.Value.AnnotationPat(proto.AnnotationPat(pidx, tidx)))) + writePattern( + ann, + proto.Pattern( + proto.Pattern.Value + .AnnotationPat(proto.AnnotationPat(pidx, tidx)) + ) + ) } - case pos@Pattern.PositionalStruct((packName, consName), params) => + case pos @ Pattern.PositionalStruct((packName, consName), params) => typeConstToProto(Type.Const.Defined(packName, TypeName(consName))) .flatMap { ptc => params .traverse(patternToProto) .flatMap { parts => - writePattern(pos, - proto.Pattern(proto.Pattern.Value.StructPat( - proto.StructPattern( - packageName = ptc.packageName, - constructorName = ptc.typeName, - params = parts)))) + writePattern( + pos, + proto.Pattern( + proto.Pattern.Value.StructPat( + proto.StructPattern( + packageName = ptc.packageName, + constructorName = ptc.typeName, + params = parts + ) + ) + ) + ) } } @@ -678,7 +838,12 @@ object ProtoConverter { (h :: t.toList) .traverse(patternToProto) .flatMap { us => - writePattern(p, proto.Pattern(proto.Pattern.Value.UnionPat(proto.UnionPattern(us)))) + writePattern( + p, + proto.Pattern( + proto.Pattern.Value.UnionPat(proto.UnionPattern(us)) + ) + ) } } } @@ -689,14 +854,15 @@ object ProtoConverter { } def typedExprToProto(te: TypedExpr[Any]): Tab[Int] = - StateT.get[Try, SerState] + StateT + .get[Try, SerState] .map(_.expressions.indexOf(te)) .flatMap { case Some(idx) => tabPure(idx + 1) case None => import TypedExpr._ te match { - case g@Generic(quant, expr) => + case g @ Generic(quant, expr) => val fas = quant.forallList.traverse { case (v, k) => varKindToProto(v, k) } @@ -706,51 +872,72 @@ object ProtoConverter { (fas, exs, typedExprToProto(expr)) .flatMapN { (fas, exs, exid) => val ex = proto.GenericExpr(forAlls = fas, exists = exs, exid) - writeExpr(g, proto.TypedExpr(proto.TypedExpr.Value.GenericExpr(ex))) + writeExpr( + g, + proto.TypedExpr(proto.TypedExpr.Value.GenericExpr(ex)) + ) } - case a@Annotation(term, tpe) => + case a @ Annotation(term, tpe) => typedExprToProto(term) .product(typeToProto(tpe)) .flatMap { case (term, tpe) => val ex = proto.AnnotationExpr(term, tpe) - writeExpr(a, proto.TypedExpr(proto.TypedExpr.Value.AnnotationExpr(ex))) + writeExpr( + a, + proto.TypedExpr(proto.TypedExpr.Value.AnnotationExpr(ex)) + ) } - case al@AnnotatedLambda(args, res, _) => - args.toList.traverse { case (n, tpe) => - getId(n.sourceCodeRepr).product(typeToProto(tpe)) - } - .product(typedExprToProto(res)) - .flatMap { case (args, resid) => - val ex = proto.LambdaExpr(args.map(_._1), args.map(_._2), resid) - writeExpr(al, proto.TypedExpr(proto.TypedExpr.Value.LambdaExpr(ex))) - } - case l@Local(nm, tpe, _) => + case al @ AnnotatedLambda(args, res, _) => + args.toList + .traverse { case (n, tpe) => + getId(n.sourceCodeRepr).product(typeToProto(tpe)) + } + .product(typedExprToProto(res)) + .flatMap { case (args, resid) => + val ex = + proto.LambdaExpr(args.map(_._1), args.map(_._2), resid) + writeExpr( + al, + proto.TypedExpr(proto.TypedExpr.Value.LambdaExpr(ex)) + ) + } + case l @ Local(nm, tpe, _) => getId(nm.sourceCodeRepr) .product(typeToProto(tpe)) .flatMap { case (varId, tpeId) => val ex = proto.VarExpr(0, varId, tpeId) - writeExpr(l, proto.TypedExpr(proto.TypedExpr.Value.VarExpr(ex))) + writeExpr( + l, + proto.TypedExpr(proto.TypedExpr.Value.VarExpr(ex)) + ) } - case g@Global(pack, nm, tpe, _) => - (getId(pack.asString), + case g @ Global(pack, nm, tpe, _) => + ( + getId(pack.asString), getId(nm.sourceCodeRepr), - typeToProto(tpe)) - .tupled + typeToProto(tpe) + ).tupled .flatMap { case (packId, varId, tpeId) => val ex = proto.VarExpr(packId, varId, tpeId) - writeExpr(g, proto.TypedExpr(proto.TypedExpr.Value.VarExpr(ex))) + writeExpr( + g, + proto.TypedExpr(proto.TypedExpr.Value.VarExpr(ex)) + ) } - case a@App(fn, args, resTpe, _) => + case a @ App(fn, args, resTpe, _) => typedExprToProto(fn) .product(args.traverse(typedExprToProto(_))) .product(typeToProto(resTpe)) .flatMap { case ((fn, args), resTpe) => val ex = proto.AppExpr(fn, args.toList, resTpe) - writeExpr(a, proto.TypedExpr(proto.TypedExpr.Value.AppExpr(ex))) + writeExpr( + a, + proto.TypedExpr(proto.TypedExpr.Value.AppExpr(ex)) + ) } - case let@Let(nm, nmexpr, inexpr, rec, _) => + case let @ Let(nm, nmexpr, inexpr, rec, _) => val prec = rec match { - case RecursionKind.Recursive => proto.RecursionKind.IsRec + case RecursionKind.Recursive => proto.RecursionKind.IsRec case RecursionKind.NonRecursive => proto.RecursionKind.NotRec } getId(nm.sourceCodeRepr) @@ -758,16 +945,24 @@ object ProtoConverter { .product(typedExprToProto(inexpr)) .flatMap { case ((nm, nmexpr), inexpr) => val ex = proto.LetExpr(nm, nmexpr, inexpr, prec) - writeExpr(let, proto.TypedExpr(proto.TypedExpr.Value.LetExpr(ex))) + writeExpr( + let, + proto.TypedExpr(proto.TypedExpr.Value.LetExpr(ex)) + ) } - case lit@Literal(l, tpe, _) => + case lit @ Literal(l, tpe, _) => typeToProto(tpe) .flatMap { tpe => val ex = proto.LiteralExpr(Some(litToProto(l)), tpe) - writeExpr(lit, proto.TypedExpr(proto.TypedExpr.Value.LiteralExpr(ex))) + writeExpr( + lit, + proto.TypedExpr(proto.TypedExpr.Value.LiteralExpr(ex)) + ) } - case m@Match(argE, branches, _) => - def encodeBranch(p: (Pattern[(PackageName, Constructor), Type], TypedExpr[Any])): Tab[proto.Branch] = + case m @ Match(argE, branches, _) => + def encodeBranch( + p: (Pattern[(PackageName, Constructor), Type], TypedExpr[Any]) + ): Tab[proto.Branch] = (patternToProto(p._1), typedExprToProto(p._2)) .mapN { (pat, expr) => proto.Branch(pat, expr) } @@ -775,7 +970,10 @@ object ProtoConverter { .product(branches.toList.traverse(encodeBranch)) .flatMap { case (argId, branches) => val ex = proto.MatchExpr(argId, branches) - writeExpr(m, proto.TypedExpr(proto.TypedExpr.Value.MatchExpr(ex))) + writeExpr( + m, + proto.TypedExpr(proto.TypedExpr.Value.MatchExpr(ex)) + ) } } } @@ -783,19 +981,20 @@ object ProtoConverter { def varianceToProto(v: Variance): proto.Variance = v match { - case Variance.Phantom => proto.Variance.Phantom - case Variance.Covariant => proto.Variance.Covariant + case Variance.Phantom => proto.Variance.Phantom + case Variance.Covariant => proto.Variance.Covariant case Variance.Contravariant => proto.Variance.Contravariant - case Variance.Invariant => proto.Variance.Invariant + case Variance.Invariant => proto.Variance.Invariant } - + def varianceFromProto(p: proto.Variance): Try[Variance] = p match { - case proto.Variance.Phantom => Success(Variance.Phantom) - case proto.Variance.Covariant => Success(Variance.Covariant) + case proto.Variance.Phantom => Success(Variance.Phantom) + case proto.Variance.Covariant => Success(Variance.Covariant) case proto.Variance.Contravariant => Success(Variance.Contravariant) - case proto.Variance.Invariant => Success(Variance.Invariant) - case proto.Variance.Unrecognized(value) => Failure(new Exception(s"unrecognized value for variance: $value")) + case proto.Variance.Invariant => Success(Variance.Invariant) + case proto.Variance.Unrecognized(value) => + Failure(new Exception(s"unrecognized value for variance: $value")) } def kindToProto(kind: Kind): proto.Kind = @@ -809,7 +1008,9 @@ object ProtoConverter { val vp = varianceToProto(v) val ip = kindToProto(i) val op = kindToProto(o) - proto.Kind(proto.Kind.Value.Cons(proto.ConsKind(vp, Some(ip), Some(op)))) + proto.Kind( + proto.Kind.Value.Cons(proto.ConsKind(vp, Some(ip), Some(op))) + ) } } def kindFromProto(kp: Option[proto.Kind]): Try[Kind] = @@ -817,11 +1018,14 @@ object ProtoConverter { case Some(proto.Kind(proto.Kind.Value.Encoded(idx), _)) => Kind.longToKind(idx) match { case Some(k) => Success(k) - case None => + case None => Failure(new Exception(s"could not decode $idx into Kind")) } - case Some(proto.Kind(proto.Kind.Value.Type(proto.TypeKind(_)), _)) => Success(Kind.Type) - case Some(proto.Kind(proto.Kind.Value.Cons(proto.ConsKind(v, i, o, _)), _)) => + case Some(proto.Kind(proto.Kind.Value.Type(proto.TypeKind(_)), _)) => + Success(Kind.Type) + case Some( + proto.Kind(proto.Kind.Value.Cons(proto.ConsKind(v, i, o, _)), _) + ) => for { variance <- varianceFromProto(v) kindI <- kindFromProto(i) @@ -837,7 +1041,11 @@ object ProtoConverter { typeVarBoundToProto(tv._1) .map { tvb => val Kind.Arg(variance, kind) = tv._2 - proto.TypeParam(Some(tvb), varianceToProto(variance), Some(kindToProto(kind))) + proto.TypeParam( + Some(tvb), + varianceToProto(variance), + Some(kindToProto(kind)) + ) } val protoTypeParams: Tab[List[proto.TypeParam]] = @@ -845,30 +1053,36 @@ object ProtoConverter { val constructors: Tab[List[proto.ConstructorFn]] = d.constructors.traverse { cf => - cf.args.traverse { case (b, t) => - typeToProto(t).flatMap { tidx => - getId(b.sourceCodeRepr) - .map { n => - proto.FnParam(n, tidx) + cf.args + .traverse { case (b, t) => + typeToProto(t).flatMap { tidx => + getId(b.sourceCodeRepr) + .map { n => + proto.FnParam(n, tidx) + } + } + } + .flatMap { params => + getId(cf.name.asString) + .map { id => + proto.ConstructorFn(id, params) } } - } - .flatMap { params => - getId(cf.name.asString) - .map { id => - proto.ConstructorFn(id, params) - } - } } (protoTypeParams, constructors) .mapN(proto.DefinedType(Some(tc), _, _)) } - def definedTypeFromProto(pdt: proto.DefinedType): DTab[DefinedType[Kind.Arg]] = { + def definedTypeFromProto( + pdt: proto.DefinedType + ): DTab[DefinedType[Kind.Arg]] = { def paramFromProto(tp: proto.TypeParam): DTab[(Type.Var.Bound, Kind.Arg)] = tp.typeVar match { - case None => ReaderT.liftF(Failure(new Exception(s"expected type variable in $tp"))) + case None => + ReaderT.liftF( + Failure(new Exception(s"expected type variable in $tp")) + ) case Some(tv) => val ka = for { v <- varianceFromProto(tp.variance) @@ -889,10 +1103,12 @@ object ProtoConverter { def consFromProto(c: proto.ConstructorFn): DTab[rankn.ConstructorFn] = lookup(c.name, c.toString) .flatMap { cname => - ReaderT.liftF(toConstructor(cname)) + ReaderT + .liftF(toConstructor(cname)) .flatMap { cname => - //def - c.params.toList.traverse(fnParamFromProto) + // def + c.params.toList + .traverse(fnParamFromProto) .map { fnParams => rankn.ConstructorFn(cname, fnParams) } @@ -900,7 +1116,8 @@ object ProtoConverter { } pdt.typeConst match { - case None => ReaderT.liftF(Failure(new Exception(s"missing typeConst: $pdt"))) + case None => + ReaderT.liftF(Failure(new Exception(s"missing typeConst: $pdt"))) case Some(tc) => for { tconst <- typeConstFromProto(tc) @@ -910,7 +1127,10 @@ object ProtoConverter { } } - def referantToProto[V](allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], r: Referant[V]): Tab[proto.Referant] = + def referantToProto[V]( + allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], + r: Referant[V] + ): Tab[proto.Referant] = r match { case Referant.Value(t) => typeToProto(t).map { tpeId => @@ -921,16 +1141,26 @@ object ProtoConverter { allDts.get(key) match { case Some((_, idx)) => tabPure( - proto.Referant(proto.Referant.Referant.DefinedType( - proto.DefinedTypeReference( - proto.DefinedTypeReference.Value.LocalDefinedTypePtr(idx + 1)))) + proto.Referant( + proto.Referant.Referant.DefinedType( + proto.DefinedTypeReference( + proto.DefinedTypeReference.Value.LocalDefinedTypePtr( + idx + 1 + ) + ) + ) + ) ) case None => // this is a non-local defined type: typeConstToProto(dt.toTypeConst).map { case tc => - proto.Referant(proto.Referant.Referant.DefinedType( - proto.DefinedTypeReference( - proto.DefinedTypeReference.Value.ImportedDefinedType(tc)))) + proto.Referant( + proto.Referant.Referant.DefinedType( + proto.DefinedTypeReference( + proto.DefinedTypeReference.Value.ImportedDefinedType(tc) + ) + ) + ) } } case Referant.Constructor(dt, cf) => @@ -944,11 +1174,21 @@ object ProtoConverter { proto.Referant.Referant.Constructor( proto.ConstructorReference( proto.ConstructorReference.Value.LocalConstructor( - proto.ConstructorPtr(dtIdx + 1, cIdx + 1)))))) - } - else tabFail(new Exception(s"missing contructor for type $key, ${cf.name}, with local: $dt")) + proto.ConstructorPtr(dtIdx + 1, cIdx + 1) + ) + ) + ) + ) + ) + } else + tabFail( + new Exception( + s"missing contructor for type $key, ${cf.name}, with local: $dt" + ) + ) case None => - (getId(dt.packageName.asString), + ( + getId(dt.packageName.asString), getId(dt.name.ident.sourceCodeRepr), getId(cf.name.sourceCodeRepr) ).mapN { (pid, tid, cid) => @@ -956,12 +1196,19 @@ object ProtoConverter { proto.Referant.Referant.Constructor( proto.ConstructorReference( proto.ConstructorReference.Value.ImportedConstructor( - proto.ImportedConstructor(pid, tid, cid))))) - } + proto.ImportedConstructor(pid, tid, cid) + ) + ) + ) + ) + } } } - def expNameToProto[V](allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], e: ExportedName[Referant[V]]): Tab[proto.ExportedName] = { + def expNameToProto[V]( + allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], + e: ExportedName[Referant[V]] + ): Tab[proto.ExportedName] = { val protoRef: Tab[proto.Referant] = referantToProto(allDts, e.tag) val exKind: Tab[(Int, proto.ExportKind)] = e match { case ExportedName.Binding(b, _) => @@ -972,10 +1219,15 @@ object ProtoConverter { getId(n.asString).map((_, proto.ExportKind.ConstructorName)) } - (protoRef, exKind).mapN { case (ref, (idx, k)) => proto.ExportedName(k, idx, Some(ref)) } + (protoRef, exKind).mapN { case (ref, (idx, k)) => + proto.ExportedName(k, idx, Some(ref)) + } } - private def packageDeps(strings: Array[String], dt: proto.DefinedType): List[String] = + private def packageDeps( + strings: Array[String], + dt: proto.DefinedType + ): List[String] = dt.typeConst match { case Some(tc) => strings(tc.packageName - 1) :: Nil @@ -986,7 +1238,10 @@ object ProtoConverter { private def ifaceDeps(iface: proto.Interface): List[String] = { val ary = iface.strings.toArray val thisPack = ary(iface.packageName - 1) - iface.definedTypes.toList.flatMap(packageDeps(ary, _).filterNot(_ == thisPack)).distinct.sorted + iface.definedTypes.toList + .flatMap(packageDeps(ary, _).filterNot(_ == thisPack)) + .distinct + .sorted } // what package names does this full package depend on? @@ -1001,16 +1256,16 @@ object ProtoConverter { } def interfaceToProto(iface: Package.Interface): Try[proto.Interface] = { - val allDts = DefinedType.listToMap( - iface.exports.flatMap { ex => + val allDts = DefinedType + .listToMap(iface.exports.flatMap { ex => /* * allDts are the locally defined types to this package * so we need to filter those outside this package */ - ex.tag - .definedType + ex.tag.definedType .filter(_.packageName == iface.name) - }).mapWithIndex { (dt, idx) => (dt, idx) } + }) + .mapWithIndex { (dt, idx) => (dt, idx) } val tryProtoDts = allDts .traverse { case (dt, _) => definedTypeToProto(dt) } @@ -1023,15 +1278,25 @@ object ProtoConverter { val last = packageId.product(tryProtoDts).product(tryExports) runTab(last).map { case (serstate, ((nm, dts), exps)) => - proto.Interface(serstate.strings.inOrder, serstate.types.inOrder, dts, nm, exps) + proto.Interface( + serstate.strings.inOrder, + serstate.types.inOrder, + dts, + nm, + exps + ) } } - private def referantFromProto(loadDT: Type.Const => Try[DefinedType[Kind.Arg]], ref: proto.Referant): DTab[Referant[Kind.Arg]] = + private def referantFromProto( + loadDT: Type.Const => Try[DefinedType[Kind.Arg]], + ref: proto.Referant + ): DTab[Referant[Kind.Arg]] = ref.referant match { case proto.Referant.Referant.Value(t) => lookupType(t, s"invalid type in $ref").map(Referant.Value(_)) - case proto.Referant.Referant.DefinedType(proto.DefinedTypeReference(dt, _)) => + case proto.Referant.Referant + .DefinedType(proto.DefinedTypeReference(dt, _)) => dt match { case proto.DefinedTypeReference.Value.LocalDefinedTypePtr(idx) => lookupDts(idx, s"invalid defined type in $ref") @@ -1043,31 +1308,43 @@ object ProtoConverter { case proto.DefinedTypeReference.Value.Empty => ReaderT.liftF(Failure(new Exception(s"empty referant found: $ref"))) } - case proto.Referant.Referant.Constructor(proto.ConstructorReference(consRef, _)) => + case proto.Referant.Referant + .Constructor(proto.ConstructorReference(consRef, _)) => consRef match { - case proto.ConstructorReference.Value.LocalConstructor(proto.ConstructorPtr(dtIdx, cIdx, _)) => + case proto.ConstructorReference.Value + .LocalConstructor(proto.ConstructorPtr(dtIdx, cIdx, _)) => lookupDts(dtIdx, s"invalid defined type in $ref").flatMap { dt => // cIdx is 1 based: val fixedIdx = cIdx - 1 ReaderT.liftF(dt.constructors.get(fixedIdx.toLong) match { case None => - Failure(new Exception(s"invalid constructor index: $cIdx in: $dt")) + Failure( + new Exception(s"invalid constructor index: $cIdx in: $dt") + ) case Some(cf) => Success(Referant.Constructor(dt, cf)) }) } - case proto.ConstructorReference.Value.ImportedConstructor(proto.ImportedConstructor(packId, typeId, consId, _)) => - (lookup(packId, s"imported constructor package in $ref"), + case proto.ConstructorReference.Value.ImportedConstructor( + proto.ImportedConstructor(packId, typeId, consId, _) + ) => + ( + lookup(packId, s"imported constructor package in $ref"), lookup(typeId, s"imported constructor typename in $ref"), - lookup(consId, s"imported constructor name in $ref")) - .tupled + lookup(consId, s"imported constructor name in $ref") + ).tupled .flatMapF { case (p, t, c) => for { tc <- typeConstFromStr(p, t, s"in $ref decoding ($p, $t)") dt <- loadDT(tc) cons <- toConstructor(c) idx = dt.constructors.indexWhere(_.name == cons) - _ <- if (idx < 0) Failure(new Exception(s"invalid constuctor name: $cons for $dt")) else Success(()) + _ <- + if (idx < 0) + Failure( + new Exception(s"invalid constuctor name: $cons for $dt") + ) + else Success(()) } yield Referant.Constructor(dt, dt.constructors(idx)) } case proto.ConstructorReference.Value.Empty => @@ -1078,14 +1355,17 @@ object ProtoConverter { } private def exportedNameFromProto( - loadDT: Type.Const => Try[DefinedType[Kind.Arg]], - en: proto.ExportedName): DTab[ExportedName[Referant[Kind.Arg]]] = { + loadDT: Type.Const => Try[DefinedType[Kind.Arg]], + en: proto.ExportedName + ): DTab[ExportedName[Referant[Kind.Arg]]] = { val tryRef: DTab[Referant[Kind.Arg]] = en.referant match { case Some(r) => referantFromProto(loadDT, r) - case None => ReaderT.liftF(Failure(new Exception(s"missing referant in $en"))) + case None => + ReaderT.liftF(Failure(new Exception(s"missing referant in $en"))) } - tryRef.product(lookup(en.name, en.toString)) + tryRef + .product(lookup(en.name, en.toString)) .flatMapF { case (ref, n) => en.exportKind match { case proto.ExportKind.Binding => @@ -1101,7 +1381,7 @@ object ProtoConverter { ExportedName.Constructor(c, ref) } case proto.ExportKind.Unrecognized(idx) => - Failure(new Exception(s"unknown export kind: $idx in $en")) + Failure(new Exception(s"unknown export kind: $idx in $en")) } } } @@ -1112,11 +1392,13 @@ object ProtoConverter { private sealed trait Scoped { def finish[A](dtab: DTab[A]): DTab[A] = this match { - case Scoped.Prep(d, fn) => d.flatMap { b => dtab.local[DecodeState] { ds => fn(ds, b) } } + case Scoped.Prep(d, fn) => + d.flatMap { b => dtab.local[DecodeState] { ds => fn(ds, b) } } } } private object Scoped { - case class Prep[A](dtab: DTab[A], fn: (DecodeState, A) => DecodeState) extends Scoped + case class Prep[A](dtab: DTab[A], fn: (DecodeState, A) => DecodeState) + extends Scoped def apply[A](dtab: DTab[A])(fn: (DecodeState, A) => DecodeState): Scoped = Prep(dtab, fn) @@ -1124,18 +1406,26 @@ object ProtoConverter { s.foldRight(dtab)(_.finish(_)) } - private def interfaceFromProto0(loadDT: Type.Const => Try[DefinedType[Kind.Arg]], protoIface: proto.Interface): Try[Package.Interface] = { + private def interfaceFromProto0( + loadDT: Type.Const => Try[DefinedType[Kind.Arg]], + protoIface: proto.Interface + ): Try[Package.Interface] = { val tab: DTab[Package.Interface] = for { packageName <- lookup(protoIface.packageName, protoIface.toString) pn <- ReaderT.liftF(parsePack(packageName, s"interface: $protoIface")) - exports <- protoIface.exports.toList.traverse(exportedNameFromProto(loadDT, _)) + exports <- protoIface.exports.toList.traverse( + exportedNameFromProto(loadDT, _) + ) } yield Package(pn, Nil, exports, ()) // build up the decoding state by decoding the tables in order - Scoped.run( - Scoped(buildTypes(protoIface.types))(_.withTypes(_)), - Scoped(protoIface.definedTypes.toVector.traverse(definedTypeFromProto))(_.withDefinedTypes(_)) + Scoped + .run( + Scoped(buildTypes(protoIface.types))(_.withTypes(_)), + Scoped(protoIface.definedTypes.toVector.traverse(definedTypeFromProto))( + _.withDefinedTypes(_) + ) )(tab) .run(DecodeState.init(protoIface.strings)) } @@ -1143,17 +1433,23 @@ object ProtoConverter { def interfaceFromProto(protoIface: proto.Interface): Try[Package.Interface] = interfacesFromProto(proto.Interfaces(protoIface :: Nil)).map(_.head) - def interfacesToProto[F[_]: Foldable](ps: F[Package.Interface]): Try[proto.Interfaces] = + def interfacesToProto[F[_]: Foldable]( + ps: F[Package.Interface] + ): Try[proto.Interfaces] = ps.toList.traverse(interfaceToProto).map { ifs => // sort so we are deterministic - proto.Interfaces(ifs.sortBy { iface => iface.strings(iface.packageName - 1) }) + proto.Interfaces(ifs.sortBy { iface => + iface.strings(iface.packageName - 1) + }) } def interfacesFromProto(ps: proto.Interfaces): Try[List[Package.Interface]] = // packagesFromProto can handle just interfaces as well packagesFromProto(ps.interfaces, Nil).map(_._1) - def read[A <: GeneratedMessage](path: Path)(implicit gmc: GeneratedMessageCompanion[A]): IO[A] = + def read[A <: GeneratedMessage]( + path: Path + )(implicit gmc: GeneratedMessageCompanion[A]): IO[A] = IO { val f = path.toFile val ios = new BufferedInputStream(new FileInputStream(f)) @@ -1173,16 +1469,21 @@ object ProtoConverter { } } - def readInterfacesAndPackages(ifacePaths: List[Path], packagePaths: List[Path]): IO[(List[Package.Interface], List[Package.Typed[Unit]])] = - (ifacePaths.traverse(read[proto.Interfaces](_)), - packagePaths.traverse(read[proto.Packages](_))) - .tupled + def readInterfacesAndPackages( + ifacePaths: List[Path], + packagePaths: List[Path] + ): IO[(List[Package.Interface], List[Package.Typed[Unit]])] = + ( + ifacePaths.traverse(read[proto.Interfaces](_)), + packagePaths.traverse(read[proto.Packages](_)) + ).tupled .flatMap { case (ifs, packs) => IO.fromTry( packagesFromProto( ifs.flatMap(_.interfaces), packs.flatMap(_.packages) - )) + ) + ) } def readInterfaces(paths: List[Path]): IO[List[Package.Interface]] = @@ -1191,7 +1492,10 @@ object ProtoConverter { def readPackages(paths: List[Path]): IO[List[Package.Typed[Unit]]] = readInterfacesAndPackages(Nil, paths).map(_._2) - def writeInterfaces(interfaces: List[Package.Interface], path: Path): IO[Unit] = + def writeInterfaces( + interfaces: List[Package.Interface], + path: Path + ): IO[Unit] = IO.fromTry(interfacesToProto(interfaces)) .flatMap(write(_, path)) @@ -1200,17 +1504,17 @@ object ProtoConverter { packages .traverse(packageToProto(_)) .map(proto.Packages(_)) - } - .flatMap(write(_, path)) + }.flatMap(write(_, path)) def importedNameToProto( - allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], - in: ImportedName[NonEmptyList[Referant[Kind.Arg]]]): Tab[proto.ImportedName] = { + allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], + in: ImportedName[NonEmptyList[Referant[Kind.Arg]]] + ): Tab[proto.ImportedName] = { val locName = in match { case ImportedName.OriginalName(_, _) => None - case ImportedName.Renamed(_, l, _) => Some(l) + case ImportedName.Renamed(_, l, _) => Some(l) } for { orig <- getId(in.originalName.sourceCodeRepr) @@ -1220,8 +1524,9 @@ object ProtoConverter { } def importToProto( - allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], - i: Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]): Tab[proto.Imports] = + allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], + i: Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]] + ): Tab[proto.Imports] = for { nm <- getId(i.pack.name.asString) imps <- i.items.toList.traverse(importedNameToProto(allDts, _)) @@ -1230,7 +1535,9 @@ object ProtoConverter { def letToProto(l: (Bindable, RecursionKind, TypedExpr[Any])): Tab[proto.Let] = for { nm <- getId(l._1.sourceCodeRepr) - rec = if (l._2.isRecursive) proto.RecursionKind.IsRec else proto.RecursionKind.NotRec + rec = + if (l._2.isRecursive) proto.RecursionKind.IsRec + else proto.RecursionKind.NotRec tex <- typedExprToProto(l._3) } yield proto.Let(nm, rec, tex) @@ -1243,7 +1550,8 @@ object ProtoConverter { def packageToProto[A](cpack: Package.Typed[A]): Try[proto.Package] = { // the Int is in index in the list of definedTypes: - val allDts: SortedMap[(PackageName, TypeName), (DefinedType[Kind.Arg], Int)] = + val allDts + : SortedMap[(PackageName, TypeName), (DefinedType[Kind.Arg], Int)] = cpack.program.types.definedTypes.mapWithIndex { (dt, idx) => (dt, idx) } val dtVect: Vector[DefinedType[Kind.Arg]] = allDts.values.iterator.map(_._1).toVector @@ -1254,20 +1562,23 @@ object ProtoConverter { exps <- cpack.exports.traverse(expNameToProto(allDts, _)) prog = cpack.program lets <- prog.lets.traverse(letToProto) - exdefs <- prog.externalDefs.traverse { nm => extDefToProto(nm, prog.types.getValue(cpack.name, nm)) } + exdefs <- prog.externalDefs.traverse { nm => + extDefToProto(nm, prog.types.getValue(cpack.name, nm)) + } dts <- dtVect.traverse(definedTypeToProto) } yield { (ss: SerState) => - proto.Package( - strings = ss.strings.inOrder, - types = ss.types.inOrder, - definedTypes = dts, - patterns = ss.patterns.inOrder, - expressions = ss.expressions.inOrder, - packageName = nmId, - imports = imps, - exports = exps, - lets = lets, - externalDefs = exdefs) + proto.Package( + strings = ss.strings.inOrder, + types = ss.types.inOrder, + definedTypes = dts, + patterns = ss.patterns.inOrder, + expressions = ss.expressions.inOrder, + packageName = nmId, + imports = imps, + exports = exps, + lets = lets, + externalDefs = exdefs + ) } runTab(tab).map { case (ss, fn) => fn(ss) } @@ -1277,7 +1588,8 @@ object ProtoConverter { Success(Identifier.Name("$anon")) def toBindable(str: String): Try[Bindable] = - if (str == "$anon") anonBind // used in Expr to create some lambdas with pattern match + if (str == "$anon") + anonBind // used in Expr to create some lambdas with pattern match else Try(Identifier.unsafeParse(Identifier.bindableParser, str)) def toIdent(str: String): Try[Identifier] = @@ -1294,19 +1606,22 @@ object ProtoConverter { lookup(idx, context).flatMapF(toIdent) def importedNameFromProto( - loadDT: Type.Const => Try[DefinedType[Kind.Arg]], - iname: proto.ImportedName): DTab[ImportedName[NonEmptyList[Referant[Kind.Arg]]]] = { + loadDT: Type.Const => Try[DefinedType[Kind.Arg]], + iname: proto.ImportedName + ): DTab[ImportedName[NonEmptyList[Referant[Kind.Arg]]]] = { def build[A](orig: Identifier, ref: A): DTab[ImportedName[A]] = if (iname.localName == 0) { ReaderT.pure(ImportedName.OriginalName(originalName = orig, ref)) - } - else { + } else { lookupIdentifier(iname.localName, iname.toString) .map(ImportedName.Renamed(originalName = orig, _, ref)) } NonEmptyList.fromList(iname.referant.toList) match { - case None => ReaderT.liftF(Failure(new Exception(s"expected at least one imported name: $iname"))) + case None => + ReaderT.liftF( + Failure(new Exception(s"expected at least one imported name: $iname")) + ) case Some(refs) => for { orig <- lookupIdentifier(iname.originalName, iname.toString) @@ -1316,11 +1631,16 @@ object ProtoConverter { } } - def importsFromProto(imp: proto.Imports, - lookupIface: PackageName => Try[Package.Interface], - loadDT: Type.Const => Try[DefinedType[Kind.Arg]]): DTab[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]] = + def importsFromProto( + imp: proto.Imports, + lookupIface: PackageName => Try[Package.Interface], + loadDT: Type.Const => Try[DefinedType[Kind.Arg]] + ): DTab[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]] = NonEmptyList.fromList(imp.names.toList) match { - case None => ReaderT.liftF(Failure(new Exception(s"expected non-empty import names in: $imp"))) + case None => + ReaderT.liftF( + Failure(new Exception(s"expected non-empty import names in: $imp")) + ) case Some(nei) => for { pnameStr <- lookup(imp.packageName, imp.toString) @@ -1330,20 +1650,30 @@ object ProtoConverter { } yield Import(iface, inames) } - def letsFromProto(let: proto.Let): DTab[(Bindable, RecursionKind, TypedExpr[Unit])] = - (lookupBindable(let.name, let.toString), - ReaderT.liftF(recursionKindFromProto(let.rec, let.toString)): DTab[RecursionKind], - lookupExpr(let.expr, let.toString)).tupled + def letsFromProto( + let: proto.Let + ): DTab[(Bindable, RecursionKind, TypedExpr[Unit])] = + ( + lookupBindable(let.name, let.toString), + ReaderT.liftF(recursionKindFromProto(let.rec, let.toString)): DTab[ + RecursionKind + ], + lookupExpr(let.expr, let.toString) + ).tupled def externalDefsFromProto(ed: proto.ExternalDef): DTab[(Bindable, Type)] = - (lookupBindable(ed.name, ed.toString), - lookupType(ed.typeOf, ed.toString)).tupled + ( + lookupBindable(ed.name, ed.toString), + lookupType(ed.typeOf, ed.toString) + ).tupled def buildProgram( - pack: PackageName, - lets: List[(Bindable, RecursionKind, TypedExpr[Unit])], - exts: List[(Bindable, Type)]): DTab[Program[TypeEnv[Kind.Arg], TypedExpr[Unit], Unit]] = - ReaderT.ask[Try, DecodeState] + pack: PackageName, + lets: List[(Bindable, RecursionKind, TypedExpr[Unit])], + exts: List[(Bindable, Type)] + ): DTab[Program[TypeEnv[Kind.Arg], TypedExpr[Unit], Unit]] = + ReaderT + .ask[Try, DecodeState] .map { ds => // this adds all the types and contructors // from the given defined types @@ -1357,21 +1687,24 @@ object ProtoConverter { } def packagesFromProto( - ifaces: Iterable[proto.Interface], - packs: Iterable[proto.Package]): Try[(List[Package.Interface], List[Package.Typed[Unit]])] = { + ifaces: Iterable[proto.Interface], + packs: Iterable[proto.Package] + ): Try[(List[Package.Interface], List[Package.Typed[Unit]])] = { type Node = Either[proto.Interface, proto.Package] def iname(p: proto.Interface): String = - p.strings.lift(p.packageName - 1) + p.strings + .lift(p.packageName - 1) .getOrElse("_unknown_" + p.packageName.toString) def pname(p: proto.Package): String = - p.strings.lift(p.packageName - 1) + p.strings + .lift(p.packageName - 1) .getOrElse("_unknown_" + p.packageName.toString) def nodeName(n: Node): String = n match { - case Left(i) => iname(i) + case Left(i) => iname(i) case Right(p) => pname(p) } @@ -1381,17 +1714,18 @@ object ProtoConverter { (l, r) match { case (Left(_), Right(_)) => -1 case (Right(_), Left(_)) => 1 - case (nl, nr) => nodeName(nl).compareTo(nodeName(nr)) + case (nl, nr) => nodeName(nl).compareTo(nodeName(nr)) } } - val nodes: List[Node] = ifaces.map(Left(_)).toList ::: packs.map(Right(_)).toList + val nodes: List[Node] = + ifaces.map(Left(_)).toList ::: packs.map(Right(_)).toList val nodeMap: Map[String, List[Node]] = nodes.groupBy(nodeName) def getNodes(n: String, parent: Node): List[Node] = nodeMap.get(n) match { - case Some(ns) => ns + case Some(ns) => ns case None if n == PackageName.PredefName.asString => // we can load the predef below Nil @@ -1403,132 +1737,170 @@ object ProtoConverter { // so, the unsafe calls inside are checked before we call def dependsOn(n: Node): List[Node] = n match { - case Left(i) => ifaceDeps(i).flatMap { dep => getNodes(dep, n) } + case Left(i) => ifaceDeps(i).flatMap { dep => getNodes(dep, n) } case Right(p) => packageDeps(p).flatMap { dep => getNodes(dep, n) } } val dupNames: List[String] = - nodeMap - .iterator + nodeMap.iterator .filter { case (_, vs) => vs.lengthCompare(1) > 0 } .map(_._1) .toList .sorted Try(graph.Toposort.sort(nodes)(dependsOn)).flatMap { sorted => + if (dupNames.nonEmpty) { + Failure( + new Exception("duplicate package names: " + dupNames.mkString(", ")) + ) + } else if (sorted.isFailure) { + val loopStr = + sorted.loopNodes + .map { + case Left(i) => "interface: " + iname(i) + case Right(p) => "compiled: " + pname(p) + } + .mkString(", ") + Failure(new Exception(s"circular dependencies in packages: $loopStr")) + } else { + def makeLoadDT( + load: String => Try[Either[ + (Package.Interface, TypeEnv[Kind.Arg]), + Package.Typed[Unit] + ]] + ): Type.Const => Try[DefinedType[Kind.Arg]] = { + case tc @ Type.Const.Defined(p, _) => + val res = load(p.asString).map { + case Left((_, dt)) => + dt.toDefinedType(tc) + case Right(comp) => + comp.program.types.toDefinedType(tc) + } - if (dupNames.nonEmpty) { - Failure(new Exception("duplicate package names: " + dupNames.mkString(", "))) - } - else if (sorted.isFailure) { - val loopStr = - sorted - .loopNodes - .map { - case Left(i) => "interface: " + iname(i) - case Right(p) => "compiled: " + pname(p) - } - .mkString(", ") - Failure(new Exception(s"circular dependencies in packages: $loopStr")) - } - else { - def makeLoadDT( - load: String => Try[Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]]] - ): Type.Const => Try[DefinedType[Kind.Arg]] = { case tc@Type.Const.Defined(p, _) => - val res = load(p.asString).map { - case Left((_, dt)) => - dt.toDefinedType(tc) - case Right(comp) => - comp.program.types.toDefinedType(tc) + res.flatMap { + case None => + Failure(new Exception(s"unknown type $tc not present")) + case Some(dt) => Success(dt) + } } - res.flatMap { - case None => Failure(new Exception(s"unknown type $tc not present")) - case Some(dt) => Success(dt) - } - } + /* + * We know we have a dag now, so we can just go through + * loading them. + * + * We will need a list of these an memoize loading them all + */ - /* - * We know we have a dag now, so we can just go through - * loading them. - * - * We will need a list of these an memoize loading them all - */ - - def packFromProtoUncached( - pack: proto.Package, - load: String => Try[Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]]] - ): Try[Package.Typed[Unit]] = { - val loadIface: PackageName => Try[Package.Interface] = { p => - load(p.asString).map { - case Left((iface, _)) => iface - case Right(pack) => Package.interfaceOf(pack) + def packFromProtoUncached( + pack: proto.Package, + load: String => Try[Either[ + (Package.Interface, TypeEnv[Kind.Arg]), + Package.Typed[Unit] + ]] + ): Try[Package.Typed[Unit]] = { + val loadIface: PackageName => Try[Package.Interface] = { p => + load(p.asString).map { + case Left((iface, _)) => iface + case Right(pack) => Package.interfaceOf(pack) + } } - } - val loadDT = makeLoadDT(load) - - val tab: DTab[Package.Typed[Unit]] = - for { - packageNameStr <- lookup(pack.packageName, pack.toString) - packageName <- ReaderT.liftF(parsePack(packageNameStr, pack.toString)) - imps <- pack.imports.toList.traverse(importsFromProto(_, loadIface, loadDT)) - exps <- pack.exports.toList.traverse(exportedNameFromProto(loadDT, _)) - lets <- pack.lets.toList.traverse(letsFromProto) - eds <- pack.externalDefs.toList.traverse(externalDefsFromProto) - prog <- buildProgram(packageName, lets, eds) - } yield Package(packageName, imps, exps, prog) - - // build up the decoding state by decoding the tables in order - val tab1 = Scoped.run( - Scoped(buildTypes(pack.types))(_.withTypes(_)), - Scoped(pack.definedTypes.toVector.traverse(definedTypeFromProto))(_.withDefinedTypes(_)), - Scoped(buildPatterns(pack.patterns))(_.withPatterns(_)), - Scoped(buildExprs(pack.expressions))(_.withExprs(_)) + val loadDT = makeLoadDT(load) + + val tab: DTab[Package.Typed[Unit]] = + for { + packageNameStr <- lookup(pack.packageName, pack.toString) + packageName <- ReaderT.liftF( + parsePack(packageNameStr, pack.toString) + ) + imps <- pack.imports.toList.traverse( + importsFromProto(_, loadIface, loadDT) + ) + exps <- pack.exports.toList.traverse( + exportedNameFromProto(loadDT, _) + ) + lets <- pack.lets.toList.traverse(letsFromProto) + eds <- pack.externalDefs.toList.traverse(externalDefsFromProto) + prog <- buildProgram(packageName, lets, eds) + } yield Package(packageName, imps, exps, prog) + + // build up the decoding state by decoding the tables in order + val tab1 = Scoped.run( + Scoped(buildTypes(pack.types))(_.withTypes(_)), + Scoped(pack.definedTypes.toVector.traverse(definedTypeFromProto))( + _.withDefinedTypes(_) + ), + Scoped(buildPatterns(pack.patterns))(_.withPatterns(_)), + Scoped(buildExprs(pack.expressions))(_.withExprs(_)) )(tab) - tab1.run(DecodeState.init(pack.strings)) - } + tab1.run(DecodeState.init(pack.strings)) + } - val predefIface = { - val iface = Package.interfaceOf(PackageMap.predefCompiled) - (iface, ExportedName.typeEnvFromExports(iface.name, iface.exports)) - } + val predefIface = { + val iface = Package.interfaceOf(PackageMap.predefCompiled) + (iface, ExportedName.typeEnvFromExports(iface.name, iface.exports)) + } - val load: String => Try[Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]]] = - Memoize.memoizeDagHashed[String, Try[Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]]]] { (pack, rec) => - nodeMap.get(pack) match { - case Some(Left(iface) :: Nil) => - interfaceFromProto0(makeLoadDT(rec), iface) - .map { iface => Left((iface, ExportedName.typeEnvFromExports(iface.name, iface.exports))) } - case Some(Right(p) :: Nil) => - packFromProtoUncached(p, rec) - .map(Right(_)) - case None if pack == PackageName.PredefName.asString => - // if we haven't replaced explicitly, use the built in predef - Success(Left(predefIface)) - case found => - Failure(new Exception(s"missing interface or compiled: $pack, found: $found")) + val load: String => Try[ + Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]] + ] = + Memoize.memoizeDagHashed[String, Try[ + Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]] + ]] { (pack, rec) => + nodeMap.get(pack) match { + case Some(Left(iface) :: Nil) => + interfaceFromProto0(makeLoadDT(rec), iface) + .map { iface => + Left( + ( + iface, + ExportedName.typeEnvFromExports( + iface.name, + iface.exports + ) + ) + ) + } + case Some(Right(p) :: Nil) => + packFromProtoUncached(p, rec) + .map(Right(_)) + case None if pack == PackageName.PredefName.asString => + // if we haven't replaced explicitly, use the built in predef + Success(Left(predefIface)) + case found => + Failure( + new Exception( + s"missing interface or compiled: $pack, found: $found" + ) + ) + } } - } - val deserPack: proto.Package => Try[Package.Typed[Unit]] = { p => - load(pname(p)).flatMap { - case Left((iface, _)) => Failure(new Exception(s"expected compiled for ${iface.name.asString}, found interface")) - case Right(pack) => Success(pack) + val deserPack: proto.Package => Try[Package.Typed[Unit]] = { p => + load(pname(p)).flatMap { + case Left((iface, _)) => + Failure( + new Exception( + s"expected compiled for ${iface.name.asString}, found interface" + ) + ) + case Right(pack) => Success(pack) + } } - } - val deserIface: proto.Interface => Try[Package.Interface] = { p => - load(iname(p)).map { - case Left((iface, _)) => iface - case Right(pack) => Package.interfaceOf(pack) + val deserIface: proto.Interface => Try[Package.Interface] = { p => + load(iname(p)).map { + case Left((iface, _)) => iface + case Right(pack) => Package.interfaceOf(pack) + } } - } - // use the cached versions down here - (ifaces.toList.traverse(deserIface), - packs.toList.traverse(deserPack)).tupled - } + // use the cached versions down here + ( + ifaces.toList.traverse(deserIface), + packs.toList.traverse(deserPack) + ).tupled + } } } } diff --git a/cli/src/test/scala/org/bykn/bosatsu/JsonTest.scala b/cli/src/test/scala/org/bykn/bosatsu/JsonTest.scala index 05682a633..616fa5558 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/JsonTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/JsonTest.scala @@ -1,6 +1,9 @@ package org.bykn.bosatsu -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.typelevel.jawn.ast.{JValue, JParser} import GenJson._ @@ -14,11 +17,11 @@ class JsonJawnTest extends AnyFunSuite { def matches(j1: Json, j2: JValue): Unit = { import Json._ j1 match { - case JString(str) => assert(j2.asString == str); () + case JString(str) => assert(j2.asString == str); () case JNumberStr(nstr) => assert(BigDecimal(nstr) == j2.asBigDecimal); () - case JNull => assert(j2.isNull); () - case JBool.True => assert(j2.asBoolean); () - case JBool.False => assert(!j2.asBoolean); () + case JNull => assert(j2.isNull); () + case JBool.True => assert(j2.asBoolean); () + case JBool.False => assert(!j2.asBoolean); () case JArray(js) => js.zipWithIndex.foreach { case (j, idx) => matches(j, j2.get(idx)) diff --git a/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala b/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala index 4027700cf..06f0b04e4 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala @@ -26,11 +26,32 @@ class PathModuleTest extends AnyFunSuite { def pn(roots: List[String], file: String): Option[PackageName] = PathModule.pathPackage(roots.map(Paths.get(_)), Paths.get(file)) - assert(pn(List("/root0", "/root1"), "/root0/Bar.bosatsu") == Some(PackageName(NonEmptyList.of("Bar")))) - assert(pn(List("/root0", "/root1"), "/root1/Bar/Baz.bosatsu") == Some(PackageName(NonEmptyList.of("Bar", "Baz")))) - assert(pn(List("/root0", "/root0/Bar"), "/root0/Bar/Baz.bosatsu") == Some(PackageName(NonEmptyList.of("Bar", "Baz")))) - assert(pn(List("/root0/", "/root0/Bar"), "/root0/Bar/Baz.bosatsu") == Some(PackageName(NonEmptyList.of("Bar", "Baz")))) - assert(pn(List("/root0/ext", "/root0/Bar"), "/root0/ext/Bar/Baz.bosatsu") == Some(PackageName(NonEmptyList.of("Bar", "Baz")))) + assert( + pn(List("/root0", "/root1"), "/root0/Bar.bosatsu") == Some( + PackageName(NonEmptyList.of("Bar")) + ) + ) + assert( + pn(List("/root0", "/root1"), "/root1/Bar/Baz.bosatsu") == Some( + PackageName(NonEmptyList.of("Bar", "Baz")) + ) + ) + assert( + pn(List("/root0", "/root0/Bar"), "/root0/Bar/Baz.bosatsu") == Some( + PackageName(NonEmptyList.of("Bar", "Baz")) + ) + ) + assert( + pn(List("/root0/", "/root0/Bar"), "/root0/Bar/Baz.bosatsu") == Some( + PackageName(NonEmptyList.of("Bar", "Baz")) + ) + ) + assert( + pn( + List("/root0/ext", "/root0/Bar"), + "/root0/ext/Bar/Baz.bosatsu" + ) == Some(PackageName(NonEmptyList.of("Bar", "Baz"))) + ) } test("no roots means no Package") { @@ -50,7 +71,9 @@ class PathModuleTest extends AnyFunSuite { if (rest.toString != "" && root.toString != "") { val path = root.resolve(rest) val pack = - PackageName.parse(rest.asScala.map(_.toString.toLowerCase.capitalize).mkString("/")) + PackageName.parse( + rest.asScala.map(_.toString.toLowerCase.capitalize).mkString("/") + ) assert(PathModule.pathPackage(root :: otherRoots, path) == pack) } } @@ -60,7 +83,8 @@ class PathModuleTest extends AnyFunSuite { val regressions: List[(Path, List[Path], Path)] = List( (Paths.get(""), Nil, Paths.get("/foo/bar")), - (Paths.get(""), List(Paths.get("")), Paths.get("/foo/bar"))) + (Paths.get(""), List(Paths.get("")), Paths.get("/foo/bar")) + ) regressions.foreach { case (r, o, e) => law(r, o, e) } } @@ -70,7 +94,9 @@ class PathModuleTest extends AnyFunSuite { val roots = (r0 :: roots0).filterNot(_.toString == "") val pack = PathModule.pathPackage(roots, file) - val noPrefix = !roots.exists { r => file.asScala.toList.startsWith(r.asScala.toList) } + val noPrefix = !roots.exists { r => + file.asScala.toList.startsWith(r.asScala.toList) + } if (noPrefix) assert(pack == None) } @@ -80,30 +106,43 @@ class PathModuleTest extends AnyFunSuite { PathModule.run(args.toList) match { case Left(h) => fail(s"got help: $h on command: ${args.toList}") case Right(io) => - io.attempt.flatMap { - case Right(out) => - PathModule.reportOutput(out).as(out) - case Left(err) => - PathModule.reportException(err) *> IO.raiseError(err) - } - .unsafeRunSync() + io.attempt + .flatMap { + case Right(out) => + PathModule.reportOutput(out).as(out) + case Left(err) => + PathModule.reportException(err) *> IO.raiseError(err) + } + .unsafeRunSync() } test("test direct run of a file") { - val out = run("test --input test_workspace/List.bosatsu --input test_workspace/Nat.bosatsu --input test_workspace/Bool.bosatsu --test_file test_workspace/Queue.bosatsu".split("\\s+").toSeq: _*) + val out = run( + "test --input test_workspace/List.bosatsu --input test_workspace/Nat.bosatsu --input test_workspace/Bool.bosatsu --test_file test_workspace/Queue.bosatsu" + .split("\\s+") + .toSeq: _* + ) out match { case PathModule.Output.TestOutput(results, _) => - val res = results.collect { case (pn, Some(t)) if pn.asString == "Queue" => t.value } + val res = results.collect { + case (pn, Some(t)) if pn.asString == "Queue" => t.value + } assert(res.length == 1) case other => fail(s"expected test output: $other") } } test("test search run of a file") { - val out = run("test --package_root test_workspace --search --test_file test_workspace/Bar.bosatsu".split("\\s+").toSeq: _*) + val out = run( + "test --package_root test_workspace --search --test_file test_workspace/Bar.bosatsu" + .split("\\s+") + .toSeq: _* + ) out match { case PathModule.Output.TestOutput(results, _) => - val res = results.collect { case (pn, Some(t)) if pn.asString == "Bar" => t.value } + val res = results.collect { + case (pn, Some(t)) if pn.asString == "Bar" => t.value + } assert(res.length == 1) assert(res.head.assertions == 1) assert(res.head.failureCount == 0) @@ -112,7 +151,11 @@ class PathModuleTest extends AnyFunSuite { } test("test python transpile on the entire test_workspace") { - val out = run("transpile --input_dir test_workspace/ --outdir pyout --lang python --package_root test_workspace".split("\\s+").toSeq: _*) + val out = run( + "transpile --input_dir test_workspace/ --outdir pyout --lang python --package_root test_workspace" + .split("\\s+") + .toSeq: _* + ) out match { case PathModule.Output.TranspileOut(_, _) => assert(true) @@ -122,18 +165,29 @@ class PathModuleTest extends AnyFunSuite { test("test search with json write") { - val out = run("json write --package_root test_workspace --search --main_file test_workspace/Bar.bosatsu".split("\\s+").toSeq: _*) + val out = run( + "json write --package_root test_workspace --search --main_file test_workspace/Bar.bosatsu" + .split("\\s+") + .toSeq: _* + ) out match { - case PathModule.Output.JsonOutput(j@Json.JObject(_), _) => - assert(j.toMap == Map("value" -> Json.JBool(true), "message" -> Json.JString("got the right string"))) + case PathModule.Output.JsonOutput(j @ Json.JObject(_), _) => + assert( + j.toMap == Map( + "value" -> Json.JBool(true), + "message" -> Json.JString("got the right string") + ) + ) assert(j.items.length == 2) case other => fail(s"expected json object output: $other") } } test("test search json apply") { - val cmd = "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" - .split("\\s+").toList :+ "[2, 4]" + val cmd = + "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" + .split("\\s+") + .toList :+ "[2, 4]" run(cmd: _*) match { case PathModule.Output.JsonOutput(Json.JNumberStr("8"), _) => succeed @@ -142,11 +196,17 @@ class PathModuleTest extends AnyFunSuite { } test("test search json traverse") { - val cmd = "json traverse --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" - .split("\\s+").toList :+ "[[2, 4], [3, 5]]" + val cmd = + "json traverse --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" + .split("\\s+") + .toList :+ "[[2, 4], [3, 5]]" run(cmd: _*) match { - case PathModule.Output.JsonOutput(Json.JArray(Vector(Json.JNumberStr("8"), Json.JNumberStr("15"))), _) => succeed + case PathModule.Output.JsonOutput( + Json.JArray(Vector(Json.JNumberStr("8"), Json.JNumberStr("15"))), + _ + ) => + succeed case other => fail(s"expected json object output: $other") } } @@ -163,7 +223,8 @@ class PathModuleTest extends AnyFunSuite { } // ill-typed json fails - val cmd = "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" + val cmd = + "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" fails(cmd, "[\"2\", 4]") fails(cmd, "[2, \"4\"]") // wrong arity @@ -171,42 +232,61 @@ class PathModuleTest extends AnyFunSuite { fails(cmd, "[2]") fails(cmd, "[]") // unknown command fails - val badName = "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::foooooo --json_string 23" + val badName = + "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::foooooo --json_string 23" fails(badName) - val badPack = "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/DoesNotExist --json_string 23" + val badPack = + "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/DoesNotExist --json_string 23" fails(badPack) // bad json fails fails(cmd, "[\"2\", foo, bla]") fails(cmd, "[42, 31] and some junk") // exercise unsupported, we cannot write mult, it is a function - fails("json write --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult") + fails( + "json write --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult" + ) // a bad main name triggers help - PathModule.run("json write --input_dir test_workspace --main Bo//".split(' ').toList) match { - case Left(_) => succeed + PathModule.run( + "json write --input_dir test_workspace --main Bo//".split(' ').toList + ) match { + case Left(_) => succeed case Right(_) => fail() } - PathModule.run("json write --input_dir test_workspace --main Bo:::boop".split(' ').toList) match { - case Left(_) => succeed + PathModule.run( + "json write --input_dir test_workspace --main Bo:::boop".split(' ').toList + ) match { + case Left(_) => succeed case Right(_) => fail() } } test("test running all test in test_workspace") { - val out = run("test --package_root test_workspace --input_dir test_workspace".split("\\s+").toSeq: _*) + val out = run( + "test --package_root test_workspace --input_dir test_workspace" + .split("\\s+") + .toSeq: _* + ) out match { case PathModule.Output.TestOutput(res, _) => val noTests = res.collect { case (pn, None) => pn }.toList assert(noTests == Nil) - val failures = res.collect { case (pn, Some(t)) if t.value.failureCount > 0 => pn } + val failures = res.collect { + case (pn, Some(t)) if t.value.failureCount > 0 => pn + } assert(failures == Nil) case other => fail(s"expected test output: $other") } } test("evaluation by name with shadowing") { - run("json write --package_root test_workspace --input test_workspace/Foo.bosatsu --main Foo::x".split("\\s+").toSeq: _*) match { - case PathModule.Output.JsonOutput(Json.JString("this is Foo"), _) => succeed + run( + "json write --package_root test_workspace --input test_workspace/Foo.bosatsu --main Foo::x" + .split("\\s+") + .toSeq: _* + ) match { + case PathModule.Output.JsonOutput(Json.JString("this is Foo"), _) => + succeed case other => fail(s"unexpeced: $other") } } diff --git a/cli/src/test/scala/org/bykn/bosatsu/TestProtoType.scala b/cli/src/test/scala/org/bykn/bosatsu/TestProtoType.scala index f0e199a0f..cca7a98b2 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/TestProtoType.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/TestProtoType.scala @@ -5,7 +5,10 @@ import cats.Eq import cats.effect.{IO, Resource} import org.bykn.bosatsu.rankn.Type import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import scala.util.{Failure, Success, Try} import cats.implicits._ @@ -17,9 +20,9 @@ import org.scalatest.funsuite.AnyFunSuite class TestProtoType extends AnyFunSuite with ParTest { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 100) - //PropertyCheckConfiguration(minSuccessful = 5) + // PropertyCheckConfiguration(minSuccessful = 5) def law[A: Eq, B](a: A, fn: A => Try[B], gn: B => Try[A]) = { val maybeProto = fn(a) @@ -39,12 +42,16 @@ class TestProtoType extends AnyFunSuite with ParTest { .zip(orig.toString) .zipWithIndex .dropWhile { case ((a, b), _) => a == b } - .headOption.map(_._2) + .headOption + .map(_._2) .getOrElse(0) val context = 100 - assert(Eq[A].eqv(a, orig), s"${a.toString.drop(diffIdx - context/2).take(context)} != ${orig.toString.drop(diffIdx - context/2).take(context)}") - //assert(Eq[A].eqv(a, orig), s"$a\n\n!=\n\n$orig") + assert( + Eq[A].eqv(a, orig), + s"${a.toString.drop(diffIdx - context / 2).take(context)} != ${orig.toString.drop(diffIdx - context / 2).take(context)}" + ) + // assert(Eq[A].eqv(a, orig), s"$a\n\n!=\n\n$orig") } def testWithTempFile(fn: Path => IO[Unit]): Unit = { @@ -63,7 +70,9 @@ class TestProtoType extends AnyFunSuite with ParTest { tempRes.use(fn).unsafeRunSync() } - def tabLaw[A: Eq, B](f: A => ProtoConverter.Tab[B])(g: (ProtoConverter.SerState, B) => ProtoConverter.DTab[A]) = { (a: A) => + def tabLaw[A: Eq, B]( + f: A => ProtoConverter.Tab[B] + )(g: (ProtoConverter.SerState, B) => ProtoConverter.DTab[A]) = { (a: A) => f(a).run(ProtoConverter.SerState.empty) match { case Success((ss, b)) => val ds = ProtoConverter.DecodeState.init(ss.strings.inOrder) @@ -87,10 +96,14 @@ class TestProtoType extends AnyFunSuite with ParTest { } test("we can roundtrip patterns through proto") { - val testFn = tabLaw(ProtoConverter.patternToProto(_: Pattern[(PackageName, Constructor), Type])) { (ss, idx) => + val testFn = tabLaw( + ProtoConverter.patternToProto( + _: Pattern[(PackageName, Constructor), Type] + ) + ) { (ss, idx) => for { tps <- ProtoConverter.buildTypes(ss.types.inOrder) - pats = ProtoConverter.buildPatterns(ss.patterns.inOrder).map(_(idx - 1)) + pats = ProtoConverter.buildPatterns(ss.patterns.inOrder).map(_(idx - 1)) res <- pats.local[ProtoConverter.DecodeState](_.withTypes(tps)) } yield res }(Eq.fromUniversalEquals) @@ -99,22 +112,33 @@ class TestProtoType extends AnyFunSuite with ParTest { } test("we can roundtrip TypedExpr through proto") { - val testFn = tabLaw(ProtoConverter.typedExprToProto(_: TypedExpr[Unit])) { (ss, idx) => - for { - tps <- ProtoConverter.buildTypes(ss.types.inOrder) - pats = ProtoConverter.buildPatterns(ss.patterns.inOrder) - patTab <- pats.local[ProtoConverter.DecodeState](_.withTypes(tps)) - expr = ProtoConverter.buildExprs(ss.expressions.inOrder).map(_(idx - 1)) - res <- expr.local[ProtoConverter.DecodeState](_.withTypes(tps).withPatterns(patTab)) - } yield res + val testFn = tabLaw(ProtoConverter.typedExprToProto(_: TypedExpr[Unit])) { + (ss, idx) => + for { + tps <- ProtoConverter.buildTypes(ss.types.inOrder) + pats = ProtoConverter.buildPatterns(ss.patterns.inOrder) + patTab <- pats.local[ProtoConverter.DecodeState](_.withTypes(tps)) + expr = ProtoConverter + .buildExprs(ss.expressions.inOrder) + .map(_(idx - 1)) + res <- expr.local[ProtoConverter.DecodeState]( + _.withTypes(tps).withPatterns(patTab) + ) + } yield res }(Eq.fromUniversalEquals) - forAll(Generators.genTypedExpr(Gen.const(()), 4, rankn.NTypeGen.genDepth03))(testFn) + forAll( + Generators.genTypedExpr(Gen.const(()), 4, rankn.NTypeGen.genDepth03) + )(testFn) } test("we can roundtrip interface through proto") { forAll(Generators.interfaceGen) { iface => - law(iface, ProtoConverter.interfaceToProto _, ProtoConverter.interfaceFromProto _)(Eq.fromUniversalEquals) + law( + iface, + ProtoConverter.interfaceToProto _, + ProtoConverter.interfaceFromProto _ + )(Eq.fromUniversalEquals) } } @@ -127,49 +151,71 @@ class TestProtoType extends AnyFunSuite with ParTest { } test("we can roundtrip interfaces through proto") { - forAll(Generators.smallDistinctByList(Generators.interfaceGen)(_.name)) { ifaces => - law(ifaces, ProtoConverter.interfacesToProto[List] _, ProtoConverter.interfacesFromProto _)(sortedEq) + forAll(Generators.smallDistinctByList(Generators.interfaceGen)(_.name)) { + ifaces => + law( + ifaces, + ProtoConverter.interfacesToProto[List] _, + ProtoConverter.interfacesFromProto _ + )(sortedEq) } } test("we can roundtrip interfaces from full packages through proto") { forAll(Generators.genPackage(Gen.const(()), 10)) { packMap => - val ifaces = packMap.iterator.map { case (_, p) => Package.interfaceOf(p) }.toList - law(ifaces, ProtoConverter.interfacesToProto[List] _, ProtoConverter.interfacesFromProto _)(sortedEq) + val ifaces = packMap.iterator.map { case (_, p) => + Package.interfaceOf(p) + }.toList + law( + ifaces, + ProtoConverter.interfacesToProto[List] _, + ProtoConverter.interfacesFromProto _ + )(sortedEq) } } test("we can roundtrip interfaces through file") { - forAll(Generators.smallDistinctByList(Generators.interfaceGen)(_.name)) { ifaces => - testWithTempFile { path => - for { - _ <- ProtoConverter.writeInterfaces(ifaces, path) - ifaces1 <- ProtoConverter.readInterfaces(path :: Nil) - _ = assert(sortedEq.eqv(ifaces, ifaces1)) - } yield () - } + forAll(Generators.smallDistinctByList(Generators.interfaceGen)(_.name)) { + ifaces => + testWithTempFile { path => + for { + _ <- ProtoConverter.writeInterfaces(ifaces, path) + ifaces1 <- ProtoConverter.readInterfaces(path :: Nil) + _ = assert(sortedEq.eqv(ifaces, ifaces1)) + } yield () + } } } test("test some hand written packages") { - def ser(p: List[Package.Typed[Unit]]): Try[List[proto.Package]] = - p.traverse(ProtoConverter.packageToProto) - def deser(ps: List[proto.Package]): Try[List[Package.Typed[Unit]]] = - ProtoConverter.packagesFromProto(Nil, ps).map { case (_, p) => p.sortBy(_.name) } + def ser(p: List[Package.Typed[Unit]]): Try[List[proto.Package]] = + p.traverse(ProtoConverter.packageToProto) + def deser(ps: List[proto.Package]): Try[List[Package.Typed[Unit]]] = + ProtoConverter.packagesFromProto(Nil, ps).map { case (_, p) => + p.sortBy(_.name) + } val tf = Package.typedFunctor - TestUtils.testInferred(List( -"""package Foo + TestUtils.testInferred( + List( + """package Foo export bar bar = 1 """ - ), "Foo", { (packs, _) => - law(packs.toMap.values.toList.sortBy(_.name).map { pt => Package.setProgramFrom(tf.void(pt), ()) }, - ser _, - deser _)(Eq.fromUniversalEquals) - }) + ), + "Foo", + { (packs, _) => + law( + packs.toMap.values.toList.sortBy(_.name).map { pt => + Package.setProgramFrom(tf.void(pt), ()) + }, + ser _, + deser _ + )(Eq.fromUniversalEquals) + } + ) } test("we can roundtrip packages through proto") { @@ -177,7 +223,9 @@ bar = 1 def ser(p: List[Package.Typed[Unit]]): Try[List[proto.Package]] = p.traverse(ProtoConverter.packageToProto) def deser(ps: List[proto.Package]): Try[List[Package.Typed[Unit]]] = - ProtoConverter.packagesFromProto(Nil, ps).map { case (_, p) => p.sortBy(_.name) } + ProtoConverter.packagesFromProto(Nil, ps).map { case (_, p) => + p.sortBy(_.name) + } val packList = packMap.toList.sortBy(_._1).map(_._2) law(packList, ser _, deser _)(Eq.fromUniversalEquals) diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala index b947d8e05..820cc741a 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala @@ -3,14 +3,17 @@ package org.bykn.bosatsu.codegen.python import cats.data.NonEmptyList import java.math.BigInteger import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.python.core.{ParserFacade => JythonParserFacade} import org.scalatest.funsuite.AnyFunSuite class CodeTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 50000) - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 50000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 1000) lazy val genPy2Name: Gen[String] = { @@ -38,10 +41,15 @@ class CodeTest extends AnyFunSuite { Gen.oneOf( Gen.identifier.map(Code.PyString), genIdent, - Gen.oneOf(Code.Const.Zero, Code.Const.One, Code.Const.True, Code.Const.False), + Gen.oneOf( + Code.Const.Zero, + Code.Const.One, + Code.Const.True, + Code.Const.False + ), genDotselect, - Gen.choose(-1024, 1024).map(Code.fromInt)) - + Gen.choose(-1024, 1024).map(Code.fromInt) + ) if (depth <= 0) genZero else { @@ -57,9 +65,11 @@ class CodeTest extends AnyFunSuite { Code.Const.Eq, Code.Const.Neq, Code.Const.Gt, - Code.Const.Lt) + Code.Const.Lt + ) - val genOp = Gen.zip(rec, opName, rec).map { case (a, b, c) => Code.Op(a, b, c) } + val genOp = + Gen.zip(rec, opName, rec).map { case (a, b, c) => Code.Op(a, b, c) } val genTup = for { @@ -88,9 +98,19 @@ class CodeTest extends AnyFunSuite { (1, genOp), (2, rec.map(Code.Parens(_))), (2, Gen.zip(rec, Gen.choose(0, 100)).map { case (a, p) => a.get(p) }), - (1, Gen.zip(rec, Gen.option(rec), Gen.option(rec)).map { case (a, s, e) => Code.SelectRange(a, s, e) }), + ( + 1, + Gen.zip(rec, Gen.option(rec), Gen.option(rec)).map { case (a, s, e) => + Code.SelectRange(a, s, e) + } + ), (1, Gen.oneOf(genTup, genList)), // these can really blow things up - (2, Gen.zip(Gen.listOf(genIdent), rec).map { case (args, x) => Code.Lambda(args, x) }), + ( + 2, + Gen.zip(Gen.listOf(genIdent), rec).map { case (args, x) => + Code.Lambda(args, x) + } + ), (1, genApp), (1, genTern) ) @@ -108,7 +128,12 @@ class CodeTest extends AnyFunSuite { Gen.frequency( (10, recX), (1, Gen.zip(recS, rec).map { case (s, r) => s.withValue(r) }), - (1, Gen.zip(genNel(4, cond), rec).map { case (conds, e) => Code.IfElse(conds, e) }) + ( + 1, + Gen.zip(genNel(4, cond), rec).map { case (conds, e) => + Code.IfElse(conds, e) + } + ) ) } @@ -118,11 +143,12 @@ class CodeTest extends AnyFunSuite { lst <- Gen.listOfN(cnt, genA) } yield NonEmptyList.fromListUnsafe(lst) - def genStatement(depth: Int): Gen[Code.Statement] = { val genZero = { val gp = Gen.const(Code.Pass) - val genImp = Gen.zip(genPy2Name, Gen.option(genIdent)).map { case (m, a) => Code.Import(m, a) } + val genImp = Gen.zip(genPy2Name, Gen.option(genIdent)).map { + case (m, a) => Code.Import(m, a) + } Gen.oneOf(gp, genImp) } @@ -151,8 +177,10 @@ class CodeTest extends AnyFunSuite { val genBlock = genNel(5, recStmt).map(Code.Block(_)) val genRet = recVL.map(Code.toReturn(_)) val genAlways = recVL.map(Code.always(_)) - val genAssign = Gen.zip(genIdent, recVL).map { case (v, e) => Code.addAssign(v, e) } - val genWhile = Gen.zip(recExpr, recStmt).map { case (c, b) => Code.While(c, b) } + val genAssign = + Gen.zip(genIdent, recVL).map { case (v, e) => Code.addAssign(v, e) } + val genWhile = + Gen.zip(recExpr, recStmt).map { case (c, b) => Code.While(c, b) } val genIf = for { conds <- genNel(4, Gen.zip(recExpr, recStmt)) @@ -160,11 +188,11 @@ class CodeTest extends AnyFunSuite { } yield Code.ifStatement(conds, elseCond) val genDef = - for { - name <- genIdent - args <- Gen.listOf(genIdent) - body <- recStmt - } yield Code.Def(name, args, body) + for { + name <- genIdent + args <- Gen.listOf(genIdent) + body <- recStmt + } yield Code.Def(name, args, body) Gen.frequency( (20, genZero), @@ -183,10 +211,15 @@ class CodeTest extends AnyFunSuite { def assertParse(str: String) = { try { - val mod = JythonBarrier.run(JythonParserFacade.parseExpressionOrModule(new java.io.StringReader(str), "filename.py", new org.python.core.CompilerFlags())) + val mod = JythonBarrier.run( + JythonParserFacade.parseExpressionOrModule( + new java.io.StringReader(str), + "filename.py", + new org.python.core.CompilerFlags() + ) + ) assert(mod != null) - } - catch { + } catch { case _: Throwable => val msg = "\n\n" + ("=" * 80) + "\n\n" + str + "\n\n" + ("=" * 80) assert(false, msg) @@ -219,19 +252,26 @@ else: test("test some Operator examples") { import Code._ - val apbpc = Op(Ident("a"), Const.Plus, Op(Ident("b"), Const.Plus, Ident("c"))) + val apbpc = + Op(Ident("a"), Const.Plus, Op(Ident("b"), Const.Plus, Ident("c"))) assert(toDoc(apbpc).renderTrim(80) == """a + b + c""") - val apbmc = Op(Ident("a"), Const.Plus, Op(Ident("b"), Const.Minus, Ident("c"))) + val apbmc = + Op(Ident("a"), Const.Plus, Op(Ident("b"), Const.Minus, Ident("c"))) assert(toDoc(apbmc).renderTrim(80) == """a + b - c""") - val ambmc = Op(Ident("a"), Const.Minus, Op(Ident("b"), Const.Minus, Ident("c"))) + val ambmc = + Op(Ident("a"), Const.Minus, Op(Ident("b"), Const.Minus, Ident("c"))) assert(toDoc(ambmc).renderTrim(80) == """a - (b - c)""") - val amzmbmc = Op(Op(Ident("a"), Const.Minus, Ident("z")), Const.Minus, Op(Ident("b"), Const.Minus, Ident("c"))) + val amzmbmc = Op( + Op(Ident("a"), Const.Minus, Ident("z")), + Const.Minus, + Op(Ident("b"), Const.Minus, Ident("c")) + ) assert(toDoc(amzmbmc).renderTrim(80) == """(a - z) - (b - c)""") } @@ -250,11 +290,9 @@ else: if (cmp == 0) { assert(p1.eval(Code.Const.Eq, p2) == Code.Const.True) - } - else if (cmp < 0) { + } else if (cmp < 0) { assert(p1.eval(Code.Const.Lt, p2) == Code.Const.True) - } - else { + } else { assert(p1.eval(Code.Const.Gt, p2) == Code.Const.True) } } @@ -302,17 +340,39 @@ else: val gop = Gen.oneOf(Code.Const.Plus, Code.Const.Minus, Code.Const.Times) forAll(gi, gi, gi, gop, gop) { (a, b, c, op1, op2) => - val left = Code.Op(Code.Op(Code.fromLong(a), op1, Code.fromLong(b)), op2, Code.fromLong(c)) - assert(left.simplify == Code.PyInt(op2(op1(BigInteger.valueOf(a), BigInteger.valueOf(b)), BigInteger.valueOf(c)))) + val left = Code.Op( + Code.Op(Code.fromLong(a), op1, Code.fromLong(b)), + op2, + Code.fromLong(c) + ) + assert( + left.simplify == Code.PyInt( + op2( + op1(BigInteger.valueOf(a), BigInteger.valueOf(b)), + BigInteger.valueOf(c) + ) + ) + ) - val right = Code.Op(Code.fromLong(a), op1, Code.Op(Code.fromLong(b), op2, Code.fromLong(c))) - assert(right.simplify == Code.PyInt(op1(BigInteger.valueOf(a), op2(BigInteger.valueOf(b), BigInteger.valueOf(c))))) + val right = Code.Op( + Code.fromLong(a), + op1, + Code.Op(Code.fromLong(b), op2, Code.fromLong(c)) + ) + assert( + right.simplify == Code.PyInt( + op1( + BigInteger.valueOf(a), + op2(BigInteger.valueOf(b), BigInteger.valueOf(c)) + ) + ) + ) } } def runAll(op: Code.Expression): Option[Code.PyInt] = op match { - case pi@Code.PyInt(_) => Some(pi) + case pi @ Code.PyInt(_) => Some(pi) case Code.Op(left, op: Code.IntOp, right) => for { l <- runAll(left) @@ -321,7 +381,11 @@ else: case _ => None } - def genOp(depth: Int, go: Gen[Code.IntOp], gen0: Gen[Code.Expression]): Gen[Code.Expression] = + def genOp( + depth: Int, + go: Gen[Code.IntOp], + gen0: Gen[Code.Expression] + ): Gen[Code.Expression] = if (depth <= 0) gen0 else { val rec = Gen.lzy(genIntOp(depth - 1, go)) @@ -334,14 +398,22 @@ else: def genIntOp(depth: Int, go: Gen[Code.IntOp]): Gen[Code.Expression] = genOp(depth, go, Gen.choose(-1024, 1024).map(Code.fromInt)) - test("any sequence of IntOps is optimized") { - forAll(genIntOp(5, Gen.oneOf(Code.Const.Plus, Code.Const.Minus, Code.Const.Times))) { op => + forAll( + genIntOp( + 5, + Gen.oneOf(Code.Const.Plus, Code.Const.Minus, Code.Const.Times) + ) + ) { op => // adding zero collapses to an Int assert(Some(op.evalPlus(Code.fromInt(0))) == runAll(op)) assert(Some(Code.fromInt(0).evalPlus(op)) == runAll(op)) assert(Some(op.evalMinus(Code.fromInt(0))) == runAll(op)) - assert(Some(Code.fromInt(0).evalMinus(op)) == runAll(op.evalTimes(Code.fromInt(-1)))) + assert( + Some(Code.fromInt(0).evalMinus(op)) == runAll( + op.evalTimes(Code.fromInt(-1)) + ) + ) assert(Some(Code.fromInt(1).evalTimes(op)) == runAll(op)) } } @@ -350,13 +422,21 @@ else: val gen = genOp( 5, Gen.oneOf(Code.Const.Plus, Code.Const.Minus), - Gen.oneOf(Gen.choose(-1024, 1024).map(Code.fromInt), Gen.identifier.map(Code.Ident(_)))) + Gen.oneOf( + Gen.choose(-1024, 1024).map(Code.fromInt), + Gen.identifier.map(Code.Ident(_)) + ) + ) forAll(gen) { op => val simpOp = op.simplify - def assertGood(x: Code.Expression, isRight: Boolean): org.scalatest.Assertion = + def assertGood( + x: Code.Expression, + isRight: Boolean + ): org.scalatest.Assertion = x match { - case Code.PyInt(_) => assert(isRight, s"found: $x on the left inside of $simpOp") + case Code.PyInt(_) => + assert(isRight, s"found: $x on the left inside of $simpOp") case Code.Op(left, _, right) => assertGood(left, false) assertGood(right, isRight) @@ -376,20 +456,20 @@ else: assert(block(Pass, Pass) == Pass) forAll(genNel(4, genStatement(3))) { case NonEmptyList(h, t) => - val stmt = block(h, t :_*) + val stmt = block(h, t: _*) def passCount(s: Statement): Int = s match { - case Pass => 1 + case Pass => 1 case Block(s) => s.toList.map(passCount).sum - case _ => 0 + case _ => 0 } def notPassCount(s: Statement): Int = s match { - case Pass => 0 + case Pass => 0 case Block(s) => s.toList.map(notPassCount).sum - case _ => 1 + case _ => 1 } val pc = passCount(stmt) @@ -410,7 +490,14 @@ else: val regressions: List[Code.Expression] = List( - Code.SelectItem(Code.Ternary(Code.fromInt(0), Code.fromInt(0), Code.MakeTuple(List(Code.fromInt(42)))), 0) + Code.SelectItem( + Code.Ternary( + Code.fromInt(0), + Code.fromInt(0), + Code.MakeTuple(List(Code.fromInt(42))) + ), + 0 + ) ) regressions.foreach { expr => @@ -425,15 +512,13 @@ else: case Code.PyBool(b) => if (b) { assert(tern == t.simplify) - } - else { + } else { assert(tern == f.simplify) } case Code.PyInt(i) => if (i != BigInteger.ZERO) { assert(tern == t.simplify) - } - else { + } else { assert(tern == f.simplify) } case whoKnows => @@ -446,7 +531,7 @@ else: forAll(genExpr(4)) { expr => expr.identOrParens match { case Code.Ident(_) | Code.Parens(_) => assert(true) - case other => assert(false, other.toString) + case other => assert(false, other.toString) } } } @@ -457,12 +542,16 @@ else: val and = left.evalAnd(right) assert(Code.toDoc(and).renderTrim(80) == "(a == b) and (b == c)") - assert(Code.toDoc(Code.Ident("z").evalAnd(and)).renderTrim(80) == "z and (a == b) and (b == c)") + assert( + Code + .toDoc(Code.Ident("z").evalAnd(and)) + .renderTrim(80) == "z and (a == b) and (b == c)" + ) } test("simplify applies lambdas") { // (lambda x: f(x))(y) == f(y) - val genArgs = + val genArgs = for { n <- Gen.choose(0, 4) largs <- Gen.listOfN(n, genIdent) @@ -471,16 +560,18 @@ else: } yield (Code.Lambda(largs, result), args) forAll(genArgs) { case (lam, arg) => - assert(lam(arg: _*).simplify == Code.substitute(lam.args.zip(arg).toMap, lam.result).simplify) + assert( + lam(arg: _*).simplify == Code + .substitute(lam.args.zip(arg).toMap, lam.result) + .simplify + ) } } test("(lambda x: lambda y: x + y)(y)") { val x = Code.Ident("x") val y = Code.Ident("y") - val hardCase = Code.Lambda(List(x), - Code.Lambda(List(y), x + y) - ) + val hardCase = Code.Lambda(List(x), Code.Lambda(List(y), x + y)) val applied = hardCase(y).simplify val y0 = Code.Ident("y0") @@ -493,11 +584,13 @@ else: test("simplify(Map.empty, x) == x") { forAll(genExpr(4)) { x => - assert(Code.substitute(Map.empty, x) == x) + assert(Code.substitute(Map.empty, x) == x) } } - test("simplify creates subsets of freeIdents (we can remove ternary branches)") { + test( + "simplify creates subsets of freeIdents (we can remove ternary branches)" + ) { forAll(genExpr(4)) { x => assert(Code.freeIdents(x.simplify).subsetOf(Code.freeIdents(x))) } diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala index 0cd7d121d..7c6f197d5 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala @@ -5,9 +5,19 @@ import cats.data.NonEmptyList import java.io.{ByteArrayInputStream, InputStream} import java.nio.file.{Paths, Files} import java.util.concurrent.Semaphore -import org.bykn.bosatsu.{PackageMap, MatchlessFromTypedExpr, Parser, Package, LocationMap, PackageName} +import org.bykn.bosatsu.{ + PackageMap, + MatchlessFromTypedExpr, + Parser, + Package, + LocationMap, + PackageName +} import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.python.util.PythonInterpreter import org.python.core.{PyInteger, PyFunction, PyObject, PyTuple} @@ -58,13 +68,21 @@ class PythonGenTest extends AnyFunSuite { tup.getArray()(0) match { case x if x == zero => // True == one in our encoding - assert(tup.getArray()(1) == one, prefix + "/" + tup.getArray()(2).toString) + assert( + tup.getArray()(1) == one, + prefix + "/" + tup.getArray()(2).toString + ) () case x if x == one => val suite = tup.getArray()(1).toString - foreachList(tup.getArray()(2)) { t => checkTest(t, prefix + "/" + suite); () } + foreachList(tup.getArray()(2)) { t => + checkTest(t, prefix + "/" + suite); () + } case other => - assert(false, s"expected a Test to have 0 or 1 in first tuple entry: $tup, $other") + assert( + false, + s"expected a Test to have 0 or 1 in first tuple entry: $tup, $other" + ) () } } @@ -73,7 +91,6 @@ class PythonGenTest extends AnyFunSuite { def toS(s: String): String = new String(Files.readAllBytes(Paths.get(s)), "UTF-8") - val packNEL = NonEmptyList(path, rest.toList) .map { s => @@ -85,7 +102,7 @@ class PythonGenTest extends AnyFunSuite { val res = PackageMap.typeCheckParsed(packNEL, Nil, "") res.left match { case Some(err) => sys.error(err.toString) - case None => () + case None => () } res.right.get @@ -102,7 +119,8 @@ class PythonGenTest extends AnyFunSuite { val bosatsuPM = compileFile(natPathBosatu) val matchless = MatchlessFromTypedExpr.compile(bosatsuPM) - val packMap = PythonGen.renderAll(matchless, Map.empty, Map.empty, Map.empty) + val packMap = + PythonGen.renderAll(matchless, Map.empty, Map.empty, Map.empty) val natDoc = packMap(PackageName.parts("Bosatsu", "Nat"))._2 val natStr = natDoc.renderTrim(80) @@ -110,8 +128,7 @@ class PythonGenTest extends AnyFunSuite { try { intr.execfile(isfromString(natStr), "nat.py") checkTest(intr.get("tests"), "Nat.bosatsu") - } - catch { + } catch { case t: Throwable => System.err.println("=" * 80) System.err.println("couldn't compile nat.py") @@ -133,8 +150,7 @@ class PythonGenTest extends AnyFunSuite { val res = fn.__call__(arg) if (i <= 0) { assert(res == new PyInteger(0)) - } - else { + } else { assert(fn.__call__(arg) == arg) } } @@ -155,48 +171,52 @@ class PythonGenTest extends AnyFunSuite { JythonBarrier.run(intr.close()) - def runBoTests(path: String, pn: PackageName, testName: String) = JythonBarrier.run { - val intr = new PythonInterpreter() - - val bosatsuPM = compileFile(path) - val matchless = MatchlessFromTypedExpr.compile(bosatsuPM) + def runBoTests(path: String, pn: PackageName, testName: String) = + JythonBarrier.run { + val intr = new PythonInterpreter() - val packMap = PythonGen.renderAll(matchless, Map.empty, Map.empty, Map.empty) - val doc = packMap(pn)._2 + val bosatsuPM = compileFile(path) + val matchless = MatchlessFromTypedExpr.compile(bosatsuPM) - intr.execfile(isfromString(doc.renderTrim(80)), "test.py") - checkTest(intr.get(testName), pn.asString) + val packMap = + PythonGen.renderAll(matchless, Map.empty, Map.empty, Map.empty) + val doc = packMap(pn)._2 - intr.close() - } + intr.execfile(isfromString(doc.renderTrim(80)), "test.py") + checkTest(intr.get(testName), pn.asString) + intr.close() + } test("we can compile StrConcatExample") { runBoTests( "test_workspace/StrConcatExample.bosatsu", PackageName.parts("StrConcatExample"), - "test") + "test" + ) } - test("test some list pattern matches") { runBoTests( "test_workspace/ListPat.bosatsu", PackageName.parts("ListPat"), - "tests") + "tests" + ) } test("test euler6") { runBoTests( "test_workspace/euler6.bosatsu", PackageName.parts("Euler", "P6"), - "tests") + "tests" + ) } test("test PredefTests") { runBoTests( "test_workspace/PredefTests.bosatsu", PackageName.parts("PredefTests"), - "test") + "test" + ) } } diff --git a/core/.js/src/main/scala/org/bykn/bosatsu/Par.scala b/core/.js/src/main/scala/org/bykn/bosatsu/Par.scala index 5744722d8..052bc4bd2 100644 --- a/core/.js/src/main/scala/org/bykn/bosatsu/Par.scala +++ b/core/.js/src/main/scala/org/bykn/bosatsu/Par.scala @@ -1,11 +1,10 @@ package org.bykn.bosatsu -/** - * This is an abstraction to handle parallel computation, not effectful - * computation. It is used in places where we have parallelism in expensive - * operations. Since scalajs cannot handle this, we use conditional build - * to replace the scalajs with just running directly - */ +/** This is an abstraction to handle parallel computation, not effectful + * computation. It is used in places where we have parallelism in expensive + * operations. Since scalajs cannot handle this, we use conditional build to + * replace the scalajs with just running directly + */ object Par { class Box[A] { private[this] var value: A = _ @@ -34,4 +33,3 @@ object Par { @inline def toF[A](pa: P[A]): F[A] = pa.get } - diff --git a/core/.jvm/src/main/scala/org/bykn/bosatsu/Par.scala b/core/.jvm/src/main/scala/org/bykn/bosatsu/Par.scala index 0c4fe41aa..5da6397ce 100644 --- a/core/.jvm/src/main/scala/org/bykn/bosatsu/Par.scala +++ b/core/.jvm/src/main/scala/org/bykn/bosatsu/Par.scala @@ -4,12 +4,11 @@ import java.util.concurrent.Executors import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.concurrent.duration.Duration -/** - * This is an abstraction to handle parallel computation, not effectful - * computation. It is used in places where we have parallelism in expensive - * operations. Since scalajs cannot handle this, we use conditional build - * to replace the scalajs with just running directly - */ +/** This is an abstraction to handle parallel computation, not effectful + * computation. It is used in places where we have parallelism in expensive + * operations. Since scalajs cannot handle this, we use conditional build to + * replace the scalajs with just running directly + */ object Par { type F[A] = Future[A] type P[A] = Promise[A] @@ -22,7 +21,8 @@ object Par { def shutdownService(es: ExecutionService): Unit = es.shutdown() - def ecFromService(es: ExecutionService): EC = ExecutionContext.fromExecutor(es) + def ecFromService(es: ExecutionService): EC = + ExecutionContext.fromExecutor(es) @inline def start[A](a: => A)(implicit ec: EC): F[A] = Future(a) @@ -39,4 +39,3 @@ object Par { @inline def toF[A](pa: P[A]): F[A] = pa.future } - diff --git a/core/src/main/scala/org/bykn/bosatsu/BindingStatement.scala b/core/src/main/scala/org/bykn/bosatsu/BindingStatement.scala index ab37868f0..9078d3539 100644 --- a/core/src/main/scala/org/bykn/bosatsu/BindingStatement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/BindingStatement.scala @@ -1,16 +1,18 @@ package org.bykn.bosatsu -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} case class BindingStatement[B, V, T](name: B, value: V, in: T) object BindingStatement { private[this] val eqDoc = Doc.text(" = ") - implicit def document[A: Document, V: Document, T: Document]: Document[BindingStatement[A, V, T]] = + implicit def document[A: Document, V: Document, T: Document] + : Document[BindingStatement[A, V, T]] = Document.instance[BindingStatement[A, V, T]] { let => import let._ - Document[A].document(name) + eqDoc + Document[V].document(value) + Document[T].document(in) + Document[A].document(name) + eqDoc + Document[V].document( + value + ) + Document[T].document(in) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/CollectionUtils.scala b/core/src/main/scala/org/bykn/bosatsu/CollectionUtils.scala index 0a3ce7b75..5f87819a5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/CollectionUtils.scala +++ b/core/src/main/scala/org/bykn/bosatsu/CollectionUtils.scala @@ -8,10 +8,13 @@ import scala.util.{Success, Failure, Try} import cats.implicits._ object CollectionUtils { - /** - * Return the unique keys on the Right, and the duplicate keys on the Left (and possibly Both) - */ - def uniqueByKey[A, B: Order](as: NonEmptyList[A])(fn: A => B): Ior[NonEmptyMap[B, (A, NonEmptyList[A])], NonEmptyMap[B, A]] = { + + /** Return the unique keys on the Right, and the duplicate keys on the Left + * (and possibly Both) + */ + def uniqueByKey[A, B: Order](as: NonEmptyList[A])( + fn: A => B + ): Ior[NonEmptyMap[B, (A, NonEmptyList[A])], NonEmptyMap[B, A]] = { def check(as: NonEmptyList[A]): Either[(A, NonEmptyList[A]), A] = as match { case NonEmptyList(a, Nil) => @@ -21,13 +24,16 @@ object CollectionUtils { } // We know this is nonEmpty, so good and bad can't both be empty - val checked: SortedMap[B, Either[(A, NonEmptyList[A]), A]] = as.groupBy(fn).map { case (b, as) => (b, check(as)) } + val checked: SortedMap[B, Either[(A, NonEmptyList[A]), A]] = + as.groupBy(fn).map { case (b, as) => (b, check(as)) } val good: SortedMap[B, A] = checked.collect { case (b, Right(a)) => (b, a) } - val bad: SortedMap[B, (A, NonEmptyList[A])] = checked.collect { case (b, Left(a)) => (b, a) } + val bad: SortedMap[B, (A, NonEmptyList[A])] = checked.collect { + case (b, Left(a)) => (b, a) + } (NonEmptyMap.fromMap(bad), NonEmptyMap.fromMap(good)) match { - case (None, Some(goodNE)) => Ior.right(goodNE) - case (Some(badNE), None) => Ior.left(badNE) + case (None, Some(goodNE)) => Ior.right(goodNE) + case (Some(badNE), None) => Ior.left(badNE) case (Some(badNE), Some(goodNE)) => Ior.both(badNE, goodNE) // $COVERAGE-OFF$ case _ => @@ -36,17 +42,20 @@ object CollectionUtils { } } - def listToUnique[A, K: Order, V](l: List[A])(key: A => K, value: A => V, msg: => String): Try[SortedMap[K, V]] = + def listToUnique[A, K: Order, V]( + l: List[A] + )(key: A => K, value: A => V, msg: => String): Try[SortedMap[K, V]] = NonEmptyList.fromList(l) match { case None => Success(SortedMap.empty[K, V]) case Some(nel) => uniqueByKey(nel)(key) match { - case Ior.Right(b) => Success(b.toSortedMap.map { case (k, a) => (k, value(a)) }) + case Ior.Right(b) => + Success(b.toSortedMap.map { case (k, a) => (k, value(a)) }) case Ior.Left(errMap) => Failure(new IllegalArgumentException(s"$msg, $errMap")) case Ior.Both(errMap, _) => Failure(new IllegalArgumentException(s"$msg, $errMap")) - } + } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/CommentStatement.scala b/core/src/main/scala/org/bykn/bosatsu/CommentStatement.scala index e9b3b34e1..102f00326 100644 --- a/core/src/main/scala/org/bykn/bosatsu/CommentStatement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/CommentStatement.scala @@ -2,12 +2,10 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} -/** - * Represents a commented thing. Commented[A] would probably - * be a better name - */ +/** Represents a commented thing. Commented[A] would probably be a better name + */ final case class CommentStatement[T](message: NonEmptyList[String], on: T) object CommentStatement { @@ -16,20 +14,24 @@ object CommentStatement { implicit def document[T: Document]: Document[CommentStatement[T]] = Document.instance[CommentStatement[T]] { comment => import comment._ - val block = Doc.intercalate(Doc.line, message.toList.map { mes => Doc.char('#') + Doc.text(mes) }) + val block = Doc.intercalate( + Doc.line, + message.toList.map { mes => Doc.char('#') + Doc.text(mes) } + ) block + Doc.line + Document[T].document(on) } - /** on should make sure indent is matching - * this is to allow a P[Unit] that does nothing for testing or other applications - */ + /** on should make sure indent is matching this is to allow a P[Unit] that + * does nothing for testing or other applications + */ def parser[T](onP: String => P0[T]): Parser.Indy[CommentStatement[T]] = Parser.Indy { indent => val sep = Parser.newline val commentBlock: P[NonEmptyList[String]] = // if the next line is part of the comment until we see the # or not - (Parser.maybeSpace.with1.soft *> commentPart).repSep(sep = sep) <* Parser.newline.orElse(P.end) + (Parser.maybeSpace.with1.soft *> commentPart) + .repSep(sep = sep) <* Parser.newline.orElse(P.end) (commentBlock ~ onP(indent)) .map { case (m, on) => CommentStatement(m, on) } @@ -38,5 +40,3 @@ object CommentStatement { val commentPart: P[String] = P.char('#') *> P.until0(P.char('\n')) } - - diff --git a/core/src/main/scala/org/bykn/bosatsu/Declaration.scala b/core/src/main/scala/org/bykn/bosatsu/Declaration.scala index 031213b6f..c6402fa7f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Declaration.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Declaration.scala @@ -1,10 +1,19 @@ package org.bykn.bosatsu -import Parser.{ Combinators, Indy, maybeSpace, maybeSpacesAndLines, spaces, toEOL1, keySpace, MaybeTupleOrParens } +import Parser.{ + Combinators, + Indy, + maybeSpace, + maybeSpacesAndLines, + spaces, + toEOL1, + keySpace, + MaybeTupleOrParens +} import cats.data.NonEmptyList import org.bykn.bosatsu.graph.Memoize import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import scala.collection.immutable.SortedSet import Indy.IndyMethods @@ -14,9 +23,9 @@ import ListLang.{KVPair, SpliceOrItem} import Identifier.{Bindable, Constructor} import cats.implicits._ -/** - * Represents the syntactic version of Expr - */ + +/** Represents the syntactic version of Expr + */ sealed abstract class Declaration { import Declaration._ @@ -44,32 +53,48 @@ sealed abstract class Declaration { (args.head.toDoc + Doc.char('.') + fnDoc, args.tail) } - prefix + Doc.char('(') + Doc.intercalate(Doc.text(", "), body.map(_.toDoc)) + Doc.char(')') + prefix + Doc.char('(') + Doc.intercalate( + Doc.text(", "), + body.map(_.toDoc) + ) + Doc.char(')') case ApplyOp(left, Identifier.Operator(opStr), right) => left.toDoc space Doc.text(opStr) space right.toDoc case Binding(b) => val d0 = Document[Padding[Declaration]] val withNewLine = Document.instance[Padding[Declaration]] { pd => - Doc.line + d0.document(pd) + Doc.line + d0.document(pd) } - BindingStatement.document(Document[Pattern.Parsed], Document.instance[NonBinding](_.toDoc), withNewLine).document(b) + BindingStatement + .document( + Document[Pattern.Parsed], + Document.instance[NonBinding](_.toDoc), + withNewLine + ) + .document(b) case LeftApply(pat, _, arg, body) => - Document[Pattern.Parsed].document(pat) + Doc.text(" <- ") + arg.toDoc + Doc.line + - Document[Padding[Declaration]].document(body) + Document[Pattern.Parsed].document(pat) + Doc.text( + " <- " + ) + arg.toDoc + Doc.line + + Document[Padding[Declaration]].document(body) case Comment(c) => CommentStatement.document[Padding[Declaration]].document(c) case CommentNB(c) => CommentStatement.document[Padding[NonBinding]].document(c) case DefFn(d) => - implicit val pairDoc: Document[(OptIndent[Declaration], Padding[Declaration])] = - Document.instance { - case (fnBody, letBody) => - fnBody.sepDoc + - Document[OptIndent[Declaration]].document(fnBody) + - Doc.line + - Document[Padding[Declaration]].document(letBody) + implicit val pairDoc + : Document[(OptIndent[Declaration], Padding[Declaration])] = + Document.instance { case (fnBody, letBody) => + fnBody.sepDoc + + Document[OptIndent[Declaration]].document(fnBody) + + Doc.line + + Document[Padding[Declaration]].document(letBody) } - DefStatement.document[Pattern.Parsed, (OptIndent[Declaration], Padding[Declaration])].document(d) + DefStatement + .document[ + Pattern.Parsed, + (OptIndent[Declaration], Padding[Declaration]) + ] + .document(d) case IfElse(ifCases, elseCase) => def checkBody(cb: (Declaration, OptIndent[Declaration])): Doc = { val (check, optbody) = cb @@ -78,16 +103,22 @@ sealed abstract class Declaration { check.toDoc + Doc.char(':') + rest } - val elseDoc = elseCase.sepDoc + Document[OptIndent[Declaration]].document(elseCase) + val elseDoc = + elseCase.sepDoc + Document[OptIndent[Declaration]].document(elseCase) val tail = Doc.text("else:") + elseDoc :: Nil - val parts = (Doc.text("if ") + checkBody(ifCases.head)) :: (ifCases.tail.map(Doc.text("elif ") + checkBody(_))) ::: tail + val parts = (Doc.text("if ") + checkBody(ifCases.head)) :: (ifCases.tail + .map(Doc.text("elif ") + checkBody(_))) ::: tail Doc.intercalate(Doc.line, parts) case Ternary(trueCase, cond, falseCase) => - Doc.intercalate(Doc.space, - trueCase.toDoc :: Doc.text("if") :: cond.toDoc :: Doc.text("else") :: falseCase.toDoc :: Nil) + Doc.intercalate( + Doc.space, + trueCase.toDoc :: Doc.text("if") :: cond.toDoc :: Doc.text( + "else" + ) :: falseCase.toDoc :: Nil + ) case Lambda(args, body) => // slash style: - //val argDoc = Doc.char('\\') + Doc.intercalate(Doc.text(", "), args.toList.map(Document[Pattern.Parsed].document(_))) + // val argDoc = Doc.char('\\') + Doc.intercalate(Doc.text(", "), args.toList.map(Document[Pattern.Parsed].document(_))) // bare style: val argDoc = args match { case NonEmptyList(one, Nil) => @@ -95,11 +126,13 @@ sealed abstract class Declaration { if (Pattern.isNonUnitTuple(one)) { // wrap with parens Doc.char('(') + od + Doc.char(')') - } - else od + } else od case args => // more than one must wrap in () - Doc.char('(') + Doc.intercalate(Doc.text(", "), args.toList.map(Document[Pattern.Parsed].document(_))) + Doc.char(')') + Doc.char('(') + Doc.intercalate( + Doc.text(", "), + args.toList.map(Document[Pattern.Parsed].document(_)) + ) + Doc.char(')') } argDoc + Doc.text(" -> ") + body.toDoc case Literal(lit) => Document[Lit].document(lit) @@ -108,26 +141,34 @@ sealed abstract class Declaration { val caseDoc = Doc.text("case ") - implicit val patDoc: Document[(Pattern.Parsed, OptIndent[Declaration])] = + implicit val patDoc + : Document[(Pattern.Parsed, OptIndent[Declaration])] = Document.instance[(Pattern.Parsed, OptIndent[Declaration])] { case (pat, decl) => - caseDoc + Document[Pattern.Parsed].document(pat) + Doc.text(":") + decl.sepDoc + pid.document(decl) + caseDoc + Document[Pattern.Parsed].document(pat) + Doc.text( + ":" + ) + decl.sepDoc + pid.document(decl) } implicit def linesDoc[T: Document]: Document[NonEmptyList[T]] = - Document.instance { ts => Doc.intercalate(Doc.line, ts.toList.map(Document[T].document _)) } + Document.instance { ts => + Doc.intercalate(Doc.line, ts.toList.map(Document[T].document _)) + } - val piPat = Document[OptIndent[NonEmptyList[(Pattern.Parsed, OptIndent[Declaration])]]] + val piPat = Document[OptIndent[ + NonEmptyList[(Pattern.Parsed, OptIndent[Declaration])] + ]] val kindDoc = kind match { case RecursionKind.NonRecursive => Doc.text("match ") - case RecursionKind.Recursive => Doc.text("recur ") + case RecursionKind.Recursive => Doc.text("recur ") } // TODO this isn't quite right kindDoc + typeName.toDoc + Doc.char(':') + args.sepDoc + piPat.document(args) - case m@Matches(arg, p) => + case m @ Matches(arg, p) => val da = arg match { // matches binds tighter than all these - case Lambda(_, _) | IfElse(_, _) | ApplyOp(_, _, _) | Match(_, _, _) => + case Lambda(_, _) | IfElse(_, _) | ApplyOp(_, _, _) | + Match(_, _, _) => Parens(arg)(m.region).toDoc case _ => arg.toDoc @@ -139,23 +180,31 @@ sealed abstract class Declaration { // we need a trailing comma here: Doc.char('(') + h.toDoc + Doc.char(',') + Doc.char(')') case TupleCons(items) => - Doc.char('(') + Doc.intercalate(Doc.text(", "), - items.map(_.toDoc)) + Doc.char(')') + Doc.char('(') + Doc.intercalate( + Doc.text(", "), + items.map(_.toDoc) + ) + Doc.char(')') case Var(name) => Document[Identifier].document(name) case StringDecl(parts) => val useDouble = parts.exists { - case StringDecl.Literal(_, str) => str.contains('\'') && !str.contains('"') + case StringDecl.Literal(_, str) => + str.contains('\'') && !str.contains('"') case _ => false } val q = if (useDouble) '"' else '\'' - val inner = Doc.intercalate(Doc.empty, + val inner = Doc.intercalate( + Doc.empty, parts.toList.map { - case StringDecl.Literal(_, str) => Doc.text(StringUtil.escape(q, str)) - case StringDecl.StrExpr(decl) => Doc.text("${") + decl.toDoc + Doc.char('}') - case StringDecl.CharExpr(decl) => Doc.text("$.{") + decl.toDoc + Doc.char('}') - }) + case StringDecl.Literal(_, str) => + Doc.text(StringUtil.escape(q, str)) + case StringDecl.StrExpr(decl) => + Doc.text("${") + decl.toDoc + Doc.char('}') + case StringDecl.CharExpr(decl) => + Doc.text("$.{") + decl.toDoc + Doc.char('}') + } + ) Doc.char(q) + inner + Doc.char(q) case ListDecl(list) => @@ -166,24 +215,28 @@ sealed abstract class Declaration { case RecordConstructor(name, args) => val argDoc = Doc.char('{') + - Doc.intercalate(Doc.char(',') + Doc.space, - args.toList.map(_.toDoc)) + Doc.char('}') + Doc.intercalate( + Doc.char(',') + Doc.space, + args.toList.map(_.toDoc) + ) + Doc.char('}') Declaration.identDoc.document(name) + Doc.space + argDoc } - /** - * Get the set of free variables in this declaration. - * These are variables that must be defined at an outer - * lexical scope in order to typecheck - */ + /** Get the set of free variables in this declaration. These are variables + * that must be defined at an outer lexical scope in order to typecheck + */ def freeVars: SortedSet[Bindable] = { - def loop(decl: Declaration, bound: Set[Bindable], acc: SortedSet[Bindable]): SortedSet[Bindable] = + def loop( + decl: Declaration, + bound: Set[Bindable], + acc: SortedSet[Bindable] + ): SortedSet[Bindable] = decl match { case Annotation(term, _) => loop(term, bound, acc) case Apply(fn, args, _) => (fn :: args).foldLeft(acc) { (acc0, d) => loop(d, bound, acc0) } - case ao@ApplyOp(left, _, right) => + case ao @ ApplyOp(left, _, right) => val acc0 = loop(left, bound, acc) val acc1 = loop(ao.opVar, bound, acc0) loop(right, bound, acc1) @@ -191,7 +244,7 @@ sealed abstract class Declaration { val acc0 = loop(v, bound, acc) val bound1 = bound ++ n.names loop(in.padded, bound1, acc0) - case Comment(c) => loop(c.on.padded, bound, acc) + case Comment(c) => loop(c.on.padded, bound, acc) case CommentNB(c) => loop(c.on.padded, bound, acc) case DefFn(d) => val (body, rest) = d.result @@ -216,7 +269,7 @@ sealed abstract class Declaration { case Lambda(args, body) => val bound1 = bound ++ args.patternNames loop(body, bound1, acc) - case la@LeftApply(_, _, _, _) => + case la @ LeftApply(_, _, _, _) => loop(la.rewrite, bound, acc) case Literal(_) => acc case Match(_, typeName, args) => @@ -231,12 +284,12 @@ sealed abstract class Declaration { case TupleCons(items) => items.foldLeft(acc) { (acc0, d) => loop(d, bound, acc0) } case Var(name: Bindable) if !bound(name) => acc + name - case Var(_) => acc + case Var(_) => acc case StringDecl(items) => items.foldLeft(acc) { - case (acc, StringDecl.StrExpr(nb)) => loop(nb, bound, acc) + case (acc, StringDecl.StrExpr(nb)) => loop(nb, bound, acc) case (acc, StringDecl.CharExpr(nb)) => loop(nb, bound, acc) - case (acc, _) => acc + case (acc, _) => acc } case ListDecl(ListLang.Cons(items)) => items.foldLeft(acc) { (acc0, sori) => @@ -277,25 +330,23 @@ sealed abstract class Declaration { decl match { case Var(_) | Literal(_) => true case Annotation(term, _) => loop(term) - case Parens(p) => loop(p) - case _ => false + case Parens(p) => loop(p) + case _ => false } loop(this) } - /** - * Wrap in Parens is needed - */ + /** Wrap in Parens is needed + */ def toNonBinding: NonBinding = this match { case nb: NonBinding => nb - case decl => Parens(decl)(decl.region) + case decl => Parens(decl)(decl.region) } - /** - * This returns *all* names in the declaration, bound or not - */ + /** This returns *all* names in the declaration, bound or not + */ def allNames: SortedSet[Bindable] = { def loop(decl: Declaration, acc: SortedSet[Bindable]): SortedSet[Bindable] = decl match { @@ -308,9 +359,9 @@ sealed abstract class Declaration { case Binding(BindingStatement(n, v, in)) => val acc0 = loop(v, acc ++ n.names) loop(in.padded, acc0) - case Comment(c) => loop(c.on.padded, acc) + case Comment(c) => loop(c.on.padded, acc) case CommentNB(c) => loop(c.on.padded, acc) - case DefFn(d) => + case DefFn(d) => // def sets up a binding to itself, which // may or may not be recursive val acc1 = (acc + d.name) ++ d.args.toList.flatMap(_.patternNames) @@ -323,7 +374,7 @@ sealed abstract class Declaration { loop(v.get, acc1) } loop(elseCase.get, acc2) - case la@LeftApply(_, _, _, _) => + case la @ LeftApply(_, _, _, _) => loop(la.rewrite, acc) case Ternary(t, c, f) => val acc1 = loop(t, acc) @@ -344,12 +395,12 @@ sealed abstract class Declaration { case TupleCons(items) => items.foldLeft(acc) { (acc0, d) => loop(d, acc0) } case Var(name: Bindable) => acc + name - case Var(_) => acc + case Var(_) => acc case StringDecl(nel) => nel.foldLeft(acc) { - case (acc0, StringDecl.StrExpr(decl)) => loop(decl, acc0) + case (acc0, StringDecl.StrExpr(decl)) => loop(decl, acc0) case (acc0, StringDecl.CharExpr(decl)) => loop(decl, acc0) - case (acc0, _) => acc0 + case (acc0, _) => acc0 } case ListDecl(ListLang.Cons(items)) => items.foldLeft(acc) { (acc0, sori) => @@ -372,7 +423,7 @@ sealed abstract class Declaration { case RecordConstructor(_, args) => args.foldLeft(acc) { case (acc, RecordArg.Pair(_, v)) => loop(v, acc) - case (acc, RecordArg.Simple(n)) => acc + n + case (acc, RecordArg.Simple(n)) => acc + n } } loop(this, SortedSet.empty) @@ -381,11 +432,24 @@ sealed abstract class Declaration { def replaceRegions(r: Region): Declaration = this match { case Binding(BindingStatement(n, v, in)) => - Binding(BindingStatement(n, v.replaceRegionsNB(r), in.map(_.replaceRegions(r))))(r) + Binding( + BindingStatement( + n, + v.replaceRegionsNB(r), + in.map(_.replaceRegions(r)) + ) + )(r) case Comment(CommentStatement(lines, c)) => Comment(CommentStatement(lines, c.map(_.replaceRegions(r))))(r) case DefFn(d) => - DefFn(d.copy(result = (d.result._1.map(_.replaceRegions(r)), d.result._2.map(_.replaceRegions(r)))))(r) + DefFn( + d.copy(result = + ( + d.result._1.map(_.replaceRegions(r)), + d.result._2.map(_.replaceRegions(r)) + ) + ) + )(r) case LeftApply(p, _, right, b) => LeftApply(p, r, right.replaceRegionsNB(r), b.map(_.replaceRegions(r))) case nb: NonBinding => nb.replaceRegionsNB(r) @@ -393,7 +457,8 @@ sealed abstract class Declaration { } object Declaration { - implicit val document: Document[Declaration] = Document.instance[Declaration](_.toDoc) + implicit val document: Document[Declaration] = + Document.instance[Declaration](_.toDoc) implicit val hasRegion: HasRegion[Declaration] = HasRegion.instance[Declaration](_.region) @@ -403,16 +468,18 @@ object Declaration { case object Parens extends ApplyKind } - /** - * Try to substitute ex for ident in the expression: in - * - * This can fail if the free variables in ex are shadowed - * above ident in in. - * - * this code is very similar to TypedExpr.substitute - * if bugs are found in one, consult the other - */ - def substitute[A](ident: Bindable, ex: NonBinding, in: Declaration): Option[Declaration] = { + /** Try to substitute ex for ident in the expression: in + * + * This can fail if the free variables in ex are shadowed above ident in in. + * + * this code is very similar to TypedExpr.substitute if bugs are found in + * one, consult the other + */ + def substitute[A]( + ident: Bindable, + ex: NonBinding, + in: Declaration + ): Option[Declaration] = { // if we hit a shadow, we don't need to substitute down // that branch @inline def shadows(i: Bindable): Boolean = i === ident @@ -421,10 +488,13 @@ object Declaration { // this causes us to return None lazy val masks: Bindable => Boolean = ex.freeVars - def loopLL[F[_]](ll: ListLang[F, NonBinding, Pattern.Parsed])(fn: F[NonBinding] => Option[F[NonBinding]]): Option[ListLang[F, NonBinding, Pattern.Parsed]] = + def loopLL[F[_]](ll: ListLang[F, NonBinding, Pattern.Parsed])( + fn: F[NonBinding] => Option[F[NonBinding]] + ): Option[ListLang[F, NonBinding, Pattern.Parsed]] = ll match { case ListLang.Cons(items) => - items.traverse(fn) + items + .traverse(fn) .map(ListLang.Cons(_)) case ListLang.Comprehension(ex, b, in, filt) => // b sets up bindings for filt and ex @@ -432,7 +502,8 @@ object Declaration { .flatMap { in1 => val pnames = b.names if (pnames.exists(masks)) None - else if (pnames.exists(shadows)) Some(ListLang.Comprehension(ex, b, in1, filt)) + else if (pnames.exists(shadows)) + Some(ListLang.Comprehension(ex, b, in1, filt)) else { // no shadowing or masking (fn(ex), filt.traverse(loop)) @@ -455,8 +526,7 @@ object Declaration { // This is no longer a simple RecordArg Some(RecordArg.Pair(fn, ex)) } - } - else Some(ra) + } else Some(ra) } def loop(decl: NonBinding): Option[NonBinding] = @@ -466,7 +536,7 @@ object Declaration { case Apply(fn, args, kind) => (loop(fn), args.traverse(loop)) .mapN(Apply(_, _, kind)(decl.region)) - case aop@ApplyOp(left, op, right) if (op: Bindable) === ident => + case aop @ ApplyOp(left, op, right) if (op: Bindable) === ident => // we cannot make a general substition on ApplyOp ex match { case Var(op1: Identifier.Operator) => @@ -497,7 +567,7 @@ object Declaration { if (pnames.exists(masks)) None else if (pnames.exists(shadows)) Some(decl) else loopDec(body).map(Lambda(args, _)(decl.region)) - case l@Literal(_) => Some(l) + case l @ Literal(_) => Some(l) case Match(k, arg, cases) => val caseRes = cases @@ -527,7 +597,8 @@ object Declaration { nel .traverse { case StringDecl.StrExpr(nb) => loop(nb).map(StringDecl.StrExpr(_)) - case StringDecl.CharExpr(nb) => loop(nb).map(StringDecl.CharExpr(_)) + case StringDecl.CharExpr(nb) => + loop(nb).map(StringDecl.CharExpr(_)) case lit => Some(lit) } .map(StringDecl(_)(decl.region)) @@ -538,7 +609,8 @@ object Declaration { loopLL(ll)(_.traverse(loop)) .map(DictDecl(_)(decl.region)) case RecordConstructor(c, args) => - args.traverse(loopRA) + args + .traverse(loopRA) .map(RecordConstructor(c, _)(decl.region)) } @@ -553,8 +625,7 @@ object Declaration { .map { v1 => Binding(BindingStatement(n, v1, in))(decl.region) } - } - else { + } else { // we substitute on both (loop(v), in.traverse(loopDec)) .mapN { (v1, in1) => @@ -589,8 +660,7 @@ object Declaration { .map { v1 => LeftApply(n, r, v1, in) } - } - else { + } else { // we substitute on both (loop(v), in.traverse(loopDec)) .mapN { (v1, in1) => @@ -633,16 +703,15 @@ object Declaration { (Identifier.bindableParser ~ (pairFn.?)) .map { - case (b, None) => Simple(b) + case (b, None) => Simple(b) case (b, Some(fn)) => fn(b) } } } - /** - * These are all Declarations other than Binding, DefFn and Comment, - * in other words, things that don't need to start with indentation - */ + /** These are all Declarations other than Binding, DefFn and Comment, in other + * words, things that don't need to start with indentation + */ sealed abstract class NonBinding extends Declaration { def replaceRegionsNB(r: Region): NonBinding = this match { @@ -654,17 +723,29 @@ object Declaration { case CommentNB(CommentStatement(msg, p)) => CommentNB(CommentStatement(msg, p.map(_.replaceRegionsNB(r))))(r) case IfElse(ifCases, elseCase) => - IfElse(ifCases.map { case (bool, res) => (bool.replaceRegionsNB(r), res.map(_.replaceRegions(r))) }, - elseCase.map(_.replaceRegions(r)))(r) + IfElse( + ifCases.map { case (bool, res) => + (bool.replaceRegionsNB(r), res.map(_.replaceRegions(r))) + }, + elseCase.map(_.replaceRegions(r)) + )(r) case Ternary(t, c, f) => - Ternary(t.replaceRegionsNB(r), c.replaceRegionsNB(r), f.replaceRegionsNB(r)) + Ternary( + t.replaceRegionsNB(r), + c.replaceRegionsNB(r), + f.replaceRegionsNB(r) + ) case Lambda(args, body) => Lambda(args, body.replaceRegions(r))(r) case Literal(lit) => Literal(lit)(r) case Match(rec, arg, branches) => - Match(rec, + Match( + rec, arg.replaceRegionsNB(r), - branches.map(_.map { case (p, x) => (p, x.map(_.replaceRegions(r))) }))(r) + branches.map(_.map { case (p, x) => + (p, x.map(_.replaceRegions(r))) + }) + )(r) case Matches(a, p) => Matches(a.replaceRegionsNB(r), p)(r) case Parens(p) => Parens(p.replaceRegions(r))(r) @@ -674,25 +755,41 @@ object Declaration { case StringDecl(nel) => val ne1 = nel.map { case StringDecl.Literal(_, s) => StringDecl.Literal(r, s) - case StringDecl.CharExpr(e) => StringDecl.CharExpr(e.replaceRegionsNB(r)) - case StringDecl.StrExpr(e) => StringDecl.StrExpr(e.replaceRegionsNB(r)) + case StringDecl.CharExpr(e) => + StringDecl.CharExpr(e.replaceRegionsNB(r)) + case StringDecl.StrExpr(e) => + StringDecl.StrExpr(e.replaceRegionsNB(r)) } StringDecl(ne1)(r) case ListDecl(ListLang.Cons(items)) => ListDecl(ListLang.Cons(items.map(_.map(_.replaceRegionsNB(r)))))(r) case ListDecl(ListLang.Comprehension(ex, b, in, filter)) => - ListDecl(ListLang.Comprehension(ex.map(_.replaceRegionsNB(r)), b, in.replaceRegionsNB(r), filter.map(_.replaceRegionsNB(r))))(r) + ListDecl( + ListLang.Comprehension( + ex.map(_.replaceRegionsNB(r)), + b, + in.replaceRegionsNB(r), + filter.map(_.replaceRegionsNB(r)) + ) + )(r) case DictDecl(ListLang.Cons(items)) => - DictDecl(ListLang.Cons(items.map { - case ListLang.KVPair(k, v) => - ListLang.KVPair(k.replaceRegionsNB(r), v.replaceRegionsNB(r)) + DictDecl(ListLang.Cons(items.map { case ListLang.KVPair(k, v) => + ListLang.KVPair(k.replaceRegionsNB(r), v.replaceRegionsNB(r)) }))(r) case DictDecl(ListLang.Comprehension(ex, b, in, filter)) => - DictDecl(ListLang.Comprehension(ex.map(_.replaceRegionsNB(r)), b, in.replaceRegionsNB(r), filter.map(_.replaceRegionsNB(r))))(r) + DictDecl( + ListLang.Comprehension( + ex.map(_.replaceRegionsNB(r)), + b, + in.replaceRegionsNB(r), + filter.map(_.replaceRegionsNB(r)) + ) + )(r) case RecordConstructor(c, args) => val args1 = args.map { case RecordArg.Simple(b) => RecordArg.Simple(b) - case RecordArg.Pair(k, v) => RecordArg.Pair(k, v.replaceRegionsNB(r)) + case RecordArg.Pair(k, v) => + RecordArg.Pair(k, v.replaceRegionsNB(r)) } RecordConstructor(c, args1)(r) } @@ -703,16 +800,35 @@ object Declaration { Document.instance(_.toDoc) } - /** - * These are "binding" kinds, (not-NonBinding) + /** These are "binding" kinds, (not-NonBinding) */ - case class Binding(binding: BindingStatement[Pattern.Parsed, NonBinding, Padding[Declaration]])(implicit val region: Region) extends Declaration - case class Comment(comment: CommentStatement[Padding[Declaration]])(implicit val region: Region) extends Declaration - case class DefFn(deffn: DefStatement[Pattern.Parsed, (OptIndent[Declaration], Padding[Declaration])])(implicit val region: Region) extends Declaration - case class LeftApply(arg: Pattern.Parsed, argRegion: Region, fn: NonBinding, result: Padding[Declaration]) extends Declaration { + case class Binding( + binding: BindingStatement[Pattern.Parsed, NonBinding, Padding[ + Declaration + ]] + )(implicit val region: Region) + extends Declaration + case class Comment(comment: CommentStatement[Padding[Declaration]])(implicit + val region: Region + ) extends Declaration + case class DefFn( + deffn: DefStatement[ + Pattern.Parsed, + (OptIndent[Declaration], Padding[Declaration]) + ] + )(implicit val region: Region) + extends Declaration + case class LeftApply( + arg: Pattern.Parsed, + argRegion: Region, + fn: NonBinding, + result: Padding[Declaration] + ) extends Declaration { def region: Region = argRegion + result.padded.region def rewrite: NonBinding = { - val lam = Lambda(NonEmptyList.one(arg), result.padded)(argRegion + result.padded.region) + val lam = Lambda(NonEmptyList.one(arg), result.padded)( + argRegion + result.padded.region + ) Apply(fn, NonEmptyList.one(lam), ApplyKind.Parens)(region) } } @@ -725,88 +841,129 @@ object Declaration { // value in tests and construct them. // These reasons are a bit abusive, and we may revisit this in the future // - case class Annotation(fn: NonBinding, tpe: TypeRef)(implicit val region: Region) extends NonBinding - case class Apply(fn: NonBinding, args: NonEmptyList[NonBinding], kind: ApplyKind)(implicit val region: Region) extends NonBinding - case class ApplyOp(left: NonBinding, op: Identifier.Operator, right: NonBinding) extends NonBinding { + case class Annotation(fn: NonBinding, tpe: TypeRef)(implicit + val region: Region + ) extends NonBinding + case class Apply( + fn: NonBinding, + args: NonEmptyList[NonBinding], + kind: ApplyKind + )(implicit val region: Region) + extends NonBinding + case class ApplyOp( + left: NonBinding, + op: Identifier.Operator, + right: NonBinding + ) extends NonBinding { val region = left.region + right.region def opVar: Var = Var(op)(Region(left.region.end, right.region.start)) def toApply: Apply = Apply(opVar, NonEmptyList(left, right :: Nil), ApplyKind.Parens)(region) } - case class CommentNB(comment: CommentStatement[Padding[NonBinding]])(implicit val region: Region) extends NonBinding - - case class IfElse(ifCases: NonEmptyList[(NonBinding, OptIndent[Declaration])], - elseCase: OptIndent[Declaration])(implicit val region: Region) extends NonBinding - case class Ternary(trueCase: NonBinding, cond: NonBinding, falseCase: NonBinding) extends NonBinding { + case class CommentNB(comment: CommentStatement[Padding[NonBinding]])(implicit + val region: Region + ) extends NonBinding + + case class IfElse( + ifCases: NonEmptyList[(NonBinding, OptIndent[Declaration])], + elseCase: OptIndent[Declaration] + )(implicit val region: Region) + extends NonBinding + case class Ternary( + trueCase: NonBinding, + cond: NonBinding, + falseCase: NonBinding + ) extends NonBinding { val region = trueCase.region + falseCase.region } - case class Lambda(args: NonEmptyList[Pattern.Parsed], body: Declaration)(implicit val region: Region) extends NonBinding + case class Lambda(args: NonEmptyList[Pattern.Parsed], body: Declaration)( + implicit val region: Region + ) extends NonBinding case class Literal(lit: Lit)(implicit val region: Region) extends NonBinding case class Match( - kind: RecursionKind, - arg: NonBinding, - cases: OptIndent[NonEmptyList[(Pattern.Parsed, OptIndent[Declaration])]])( - implicit val region: Region) extends NonBinding - case class Matches(arg: NonBinding, pattern: Pattern.Parsed)(implicit val region: Region) extends NonBinding - case class Parens(of: Declaration)(implicit val region: Region) extends NonBinding - case class TupleCons(items: List[NonBinding])(implicit val region: Region) extends NonBinding - case class Var(name: Identifier)(implicit val region: Region) extends NonBinding - - /** - * This represents code like: - * Foo { bar: 12 } - */ - case class RecordConstructor(cons: Constructor, arg: NonEmptyList[RecordArg])(implicit val region: Region) extends NonBinding - /** - * This represents interpolated strings - */ - case class StringDecl(items: NonEmptyList[StringDecl.Part])(implicit val region: Region) extends NonBinding + kind: RecursionKind, + arg: NonBinding, + cases: OptIndent[NonEmptyList[(Pattern.Parsed, OptIndent[Declaration])]] + )(implicit val region: Region) + extends NonBinding + case class Matches(arg: NonBinding, pattern: Pattern.Parsed)(implicit + val region: Region + ) extends NonBinding + case class Parens(of: Declaration)(implicit val region: Region) + extends NonBinding + case class TupleCons(items: List[NonBinding])(implicit val region: Region) + extends NonBinding + case class Var(name: Identifier)(implicit val region: Region) + extends NonBinding + + /** This represents code like: Foo { bar: 12 } + */ + case class RecordConstructor(cons: Constructor, arg: NonEmptyList[RecordArg])( + implicit val region: Region + ) extends NonBinding + + /** This represents interpolated strings + */ + case class StringDecl(items: NonEmptyList[StringDecl.Part])(implicit + val region: Region + ) extends NonBinding object StringDecl { sealed abstract class Part case class Literal(region: Region, toStr: String) extends Part case class StrExpr(nonBinding: NonBinding) extends Part case class CharExpr(nonBinding: NonBinding) extends Part } - /** - * This represents the list construction language - */ - case class ListDecl(list: ListLang[SpliceOrItem, NonBinding, Pattern.Parsed])(implicit val region: Region) extends NonBinding - /** - * Here are dict constructors and comprehensions - */ - case class DictDecl(list: ListLang[KVPair, NonBinding, Pattern.Parsed])(implicit val region: Region) extends NonBinding + + /** This represents the list construction language + */ + case class ListDecl(list: ListLang[SpliceOrItem, NonBinding, Pattern.Parsed])( + implicit val region: Region + ) extends NonBinding + + /** Here are dict constructors and comprehensions + */ + case class DictDecl(list: ListLang[KVPair, NonBinding, Pattern.Parsed])( + implicit val region: Region + ) extends NonBinding val matchKindParser: P[RecursionKind] = P.string("match") .as(RecursionKind.NonRecursive) .orElse( P.string("recur") - .as(RecursionKind.Recursive)).soft <* Parser.spaces.peek + .as(RecursionKind.Recursive) + ) + .soft <* Parser.spaces.peek - /** - * A pattern can also be a declaration in some cases - * - * TODO, patterns don't parse with regions, so we lose track of precise position information - * if we want to point to an inner portion of it - */ + /** A pattern can also be a declaration in some cases + * + * TODO, patterns don't parse with regions, so we lose track of precise + * position information if we want to point to an inner portion of it + */ def toPattern(d: NonBinding): Option[Pattern.Parsed] = d match { case Annotation(term, tpe) => toPattern(term).map(Pattern.Annotation(_, tpe)) - case Var(nm@Identifier.Constructor(_)) => - Some(Pattern.PositionalStruct( - Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), Nil)) + case Var(nm @ Identifier.Constructor(_)) => + Some( + Pattern.PositionalStruct( + Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), + Nil + ) + ) case Var(v: Bindable) => Some(Pattern.Var(v)) - case Literal(lit) => Some(Pattern.Literal(lit)) + case Literal(lit) => Some(Pattern.Literal(lit)) case StringDecl(NonEmptyList(StringDecl.Literal(_, s), Nil)) => Some(Pattern.Literal(Lit.Str(s))) case StringDecl(items) => def toStrPart(p: StringDecl.Part): Option[Pattern.StrPart] = p match { case StringDecl.Literal(_, str) => Some(Pattern.StrPart.LitStr(str)) - case StringDecl.StrExpr(Var(v: Bindable)) => Some(Pattern.StrPart.NamedStr(v)) - case StringDecl.CharExpr(Var(v: Bindable)) => Some(Pattern.StrPart.NamedChar(v)) + case StringDecl.StrExpr(Var(v: Bindable)) => + Some(Pattern.StrPart.NamedStr(v)) + case StringDecl.CharExpr(Var(v: Bindable)) => + Some(Pattern.StrPart.NamedChar(v)) case _ => None } items.traverse(toStrPart).map(Pattern.StrPat(_)) @@ -825,10 +982,12 @@ object Declaration { (toPattern(left), toPattern(right)).mapN { (l, r) => Pattern.union(l, r :: Nil) } - case Apply(Var(nm@Identifier.Constructor(_)), args, ApplyKind.Parens) => + case Apply(Var(nm @ Identifier.Constructor(_)), args, ApplyKind.Parens) => args.traverse(toPattern(_)).map { argPats => - Pattern.PositionalStruct(Pattern.StructKind.Named(nm, - Pattern.StructKind.Style.TupleLike), argPats.toList) + Pattern.PositionalStruct( + Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), + argPats.toList + ) } case TupleCons(ps) => ps.traverse(toPattern(_)).map { argPats => @@ -836,14 +995,15 @@ object Declaration { } case Parens(p: NonBinding) => toPattern(p) case RecordConstructor(cons, args) => - args.traverse { - case RecordArg.Simple(b) => Some(Left(b)) - case RecordArg.Pair(k, v) => - toPattern(v).map { vpat => - Right((k, vpat)) - } - } - .map(Pattern.recordPat(cons, _)(Pattern.StructKind.Named(_, _))) + args + .traverse { + case RecordArg.Simple(b) => Some(Left(b)) + case RecordArg.Pair(k, v) => + toPattern(v).map { vpat => + Right((k, vpat)) + } + } + .map(Pattern.recordPat(cons, _)(Pattern.StructKind.Named(_, _))) case _ => None } @@ -857,23 +1017,28 @@ object Declaration { parser.indentBefore.mapF(Padding.parser(_)) def commentP(parser: Indy[Declaration]): Parser.Indy[Declaration] = - CommentStatement.parser( + CommentStatement + .parser( { indent => Padding.parser(P.string0(indent).with1 *> parser(indent)) } ) .region - .map { - case (r, c) => - c.on.padded match { - case nb: NonBinding => - CommentNB(CommentStatement(c.message, Padding(c.on.lines, nb)))(r) - case _ => - Comment(c)(r) - } + .map { case (r, c) => + c.on.padded match { + case nb: NonBinding => + CommentNB(CommentStatement(c.message, Padding(c.on.lines, nb)))(r) + case _ => + Comment(c)(r) + } } def commentNBP(parser: P[NonBinding]): Indy[CommentNB] = - CommentStatement.parser( - { indent => Padding.parser(P.string0(indent).with1 *> (Parser.maybeSpace.soft.with1 *> parser)) } + CommentStatement + .parser( + { indent => + Padding.parser( + P.string0(indent).with1 *> (Parser.maybeSpace.soft.with1 *> parser) + ) + } ) .region .map { case (r, c) => CommentNB(c)(r) } @@ -883,7 +1048,8 @@ object Declaration { OptIndent.indy(parser).product(Indy.lift(toEOL1) *> restP(parser)) restParser.mapF { rp => - DefStatement.parser(Pattern.bindParser, maybeSpace.with1 *> rp) + DefStatement + .parser(Pattern.bindParser, maybeSpace.with1 *> rp) .region .map { case (r, d) => DefFn(d)(r) } } @@ -904,7 +1070,9 @@ object Declaration { .map(_._2) val elifs1 = - ifelif("elif").nonEmptyList(sepIndy = Indy.toEOLIndent) <* Indy.toEOLIndent + ifelif("elif").nonEmptyList(sepIndy = + Indy.toEOLIndent + ) <* Indy.toEOLIndent val notIfs = Indy { indent => elifs1(indent).?.with1 ~ elseTerm(indent) @@ -913,14 +1081,14 @@ object Declaration { (ifelif("if") <* Indy.toEOLIndent) .cutThen(notIfs) .region - .map { - case (region, (ifcase, (optElses, elseBody))) => - val elses = - optElses match { - case None => Nil - case Some(s) => s.toList // type inference works better than fold sadly - } - IfElse(NonEmptyList(ifcase, elses), elseBody)(region) + .map { case (region, (ifcase, (optElses, elseBody))) => + val elses = + optElses match { + case None => Nil + case Some(s) => + s.toList // type inference works better than fold sadly + } + IfElse(NonEmptyList(ifcase, elses), elseBody)(region) } } @@ -932,7 +1100,8 @@ object Declaration { val q2 = '"' inner.mapF { p => - val plist = StringUtil.interpolatedString(q1, start, p, end) + val plist = StringUtil + .interpolatedString(q1, start, p, end) .orElse(StringUtil.interpolatedString(q2, start, p, end)) plist.region.map { @@ -944,24 +1113,29 @@ object Declaration { case (r, h :: tail) => StringDecl(NonEmptyList(h, tail).map { case Right((region, str)) => StringDecl.Literal(region, str) - case Left(expr) => expr + case Left(expr) => expr })(r) - } + } } } def lambdaP(parser: Indy[Declaration]): Indy[Lambda] = { - val params = Indy.lift(P.char('\\') *> maybeSpace *> Pattern.bindParser.nonEmptyList) + val params = + Indy.lift(P.char('\\') *> maybeSpace *> Pattern.bindParser.nonEmptyList) - val withSlash = OptIndent.blockLike(params, parser, maybeSpace.with1 *> rightArrow) + val withSlash = OptIndent + .blockLike(params, parser, maybeSpace.with1 *> rightArrow) .region .map { case (r, (args, body)) => Lambda(args, body.get)(r) } val noSlashParamsArrow = // patterns are ambiguous with expressions wo se need backtracking - MaybeTupleOrParens.parser(Pattern.bindParser) <* (maybeSpace *> ((!Operators.operatorToken) *> rightArrow)) - - val noSlash = OptIndent.blockLike(Indy.lift(noSlashParamsArrow.backtrack), parser, P.unit) + MaybeTupleOrParens.parser( + Pattern.bindParser + ) <* (maybeSpace *> ((!Operators.operatorToken) *> rightArrow)) + + val noSlash = OptIndent + .blockLike(Indy.lift(noSlashParamsArrow.backtrack), parser, P.unit) .region .map { case (r, (rawPat, body)) => val args = rawPat match { @@ -971,7 +1145,7 @@ object Declaration { NonEmptyList.one(p) case MaybeTupleOrParens.Tuple(Nil) => // consider this the same as the pattern () - NonEmptyList.one(Pattern.tuple(Nil)) + NonEmptyList.one(Pattern.tuple(Nil)) case MaybeTupleOrParens.Tuple(h :: tail) => // we consider a top level non-empty tuple to be a list: NonEmptyList(h, tail) @@ -986,52 +1160,70 @@ object Declaration { val withTrailingExpr = expr.cutLeftP(maybeSpace) // TODO: make this strict val bp = (P.string("case") *> Parser.spaces).?.with1 *> Pattern.matchParser - //val bp = (P.string("case") *> Parser.spaces).with1 *> Pattern.matchParser + // val bp = (P.string("case") *> Parser.spaces).with1 *> Pattern.matchParser val branch = OptIndent.block(Indy.lift(bp), withTrailingExpr) - val left = Indy.lift(matchKindParser <* spaces).cutThen(arg).cutLeftP(maybeSpace) - OptIndent.block(left, branch.nonEmptyList(Indy.toEOLIndent)) + val left = + Indy.lift(matchKindParser <* spaces).cutThen(arg).cutLeftP(maybeSpace) + OptIndent + .block(left, branch.nonEmptyList(Indy.toEOLIndent)) .region .map { case (r, ((kind, arg), branches)) => Match(kind, arg, branches)(r) } } - /** - * These are keywords inside declarations (if, match, def) - * that cannot be used by identifiers - */ + /** These are keywords inside declarations (if, match, def) that cannot be + * used by identifiers + */ val keywords: Set[String] = - Set("from", "import", "if", "else", "elif", "match", "matches", "def", "recur", "struct", "enum") - - /** - * A Parser that matches keywords - */ + Set( + "from", + "import", + "if", + "else", + "elif", + "match", + "matches", + "def", + "recur", + "struct", + "enum" + ) + + /** A Parser that matches keywords + */ val keywordsP: P[Unit] = P.oneOf(keywords.toList.sorted.map(P.string(_))) <* spaces val varP: P[Var] = - (!keywordsP).with1 *> Identifier.bindableParser.region.map { case (r, i) => Var(i)(r) } + (!keywordsP).with1 *> Identifier.bindableParser.region.map { case (r, i) => + Var(i)(r) + } // this returns a Var with a Constructor or a RecordConstrutor - def recordConstructorP(indent: String, declP: P[NonBinding], noAnn: P[NonBinding]): P[NonBinding] = { + def recordConstructorP( + indent: String, + declP: P[NonBinding], + noAnn: P[NonBinding] + ): P[NonBinding] = { val ws = Parser.maybeIndentedOrSpace(indent) val kv: P[RecordArg] = RecordArg.parser(indent, noAnn) val kvs = kv.nonEmptyListOfWs(ws) // here is the record style: Foo {x: 1, ... - val recArgs = kvs.bracketed(maybeSpace.with1.soft ~ P.char('{') ~ ws, ws ~ P.char('}')) + val recArgs = + kvs.bracketed(maybeSpace.with1.soft ~ P.char('{') ~ ws, ws ~ P.char('}')) // here is tuple style: Foo(a, b) - val tupArgs = declP - .parensLines1Cut - .region - .map { case (r, args) => - { (nm: Var) => Apply(nm, args, ApplyKind.Parens)(nm.region + r) } + val tupArgs = declP.parensLines1Cut.region + .map { + case (r, args) => { (nm: Var) => + Apply(nm, args, ApplyKind.Parens)(nm.region + r) + } } - (Identifier.consParser ~ Parser.either(recArgs, tupArgs).?) - .region + (Identifier.consParser ~ Parser.either(recArgs, tupArgs).?).region .map { case (region, (n, Some(Left(args)))) => RecordConstructor(n, args)(region) @@ -1047,18 +1239,23 @@ object Declaration { case object Equals extends PatternBindKind case object LeftApplyFn extends PatternBindKind - + val parser: P[PatternBindKind] = eqP.as(Equals) | leftApplyFnP.as(LeftApplyFn) } - private def patternBind(nonBindingParser: Indy[NonBinding], decl: Indy[Declaration]): Indy[Declaration] = { + private def patternBind( + nonBindingParser: Indy[NonBinding], + decl: Indy[Declaration] + ): Indy[Declaration] = { val pat = MaybeTupleOrParens.parser(Pattern.bindParser) - val patPart = pat.region ~ (maybeSpace *> PatternBindKind.parser <* maybeSpace) + val patPart = + pat.region ~ (maybeSpace *> PatternBindKind.parser <* maybeSpace) val parser = nonBindingParser <* Indy.lift(toEOL1) // we can't cut the pattern here because we have some ambiguity in declarations // allow = to be like a block, we can continue on the next line indented - OptIndent.blockLike(Indy.lift(patPart.backtrack), parser, P.unit) + OptIndent + .blockLike(Indy.lift(patPart.backtrack), parser, P.unit) .cutThen(restP(decl)) .region .map { case (region, ((((preg, rawPat), pbk), value), decl)) => @@ -1070,23 +1267,27 @@ object Declaration { case PatternBindKind.LeftApplyFn => val pat = Pattern.fromMaybeTupleOrParens(rawPat) LeftApply(pat, preg, value.get, decl) - + } } } private def listP(p: P[NonBinding], src: P[NonBinding]): P[ListDecl] = - ListLang.parser(p, src, Pattern.bindParser) + ListLang + .parser(p, src, Pattern.bindParser) .region .map { case (r, l) => ListDecl(l)(r) } private def dictP(p: P[NonBinding], src: P[NonBinding]): P[DictDecl] = - ListLang.dictParser(p, src, Pattern.bindParser) + ListLang + .dictParser(p, src, Pattern.bindParser) .region .map { case (r, l) => DictDecl(l)(r) } val lits: P[Literal] = - (Lit.integerParser | Lit.codePointParser).region.map { case (r, l) => Literal(l)(r) } + (Lit.integerParser | Lit.codePointParser).region.map { case (r, l) => + Literal(l)(r) + } private sealed abstract class ParseMode private object ParseMode { @@ -1102,221 +1303,255 @@ object Declaration { * we also parse Bind, Def, Comment */ private[this] val parserCache: ((ParseMode, String)) => P[Declaration] = - Memoize.memoizeDagHashedConcurrent[(ParseMode, String), P[Declaration]] { case ((pm, indent), rec) => - - // TODO: - // since we do a hard set of the mode in these, we lose the thread if we are inside a - // BranchArg so the trailing values : should be interpretted as a the branch end. - // This may actually make the file ambiguous in some cases, or at least point - // to a strange place on parse errors. - // - // I think we need to separate block-like expressions using : from NonBinding - // and make sure that we don't have block like expressions in certain places - - val recurseDecl: P[Declaration] = P.defer(rec((ParseMode.Decl, indent))) // needs to be inside a P for laziness - val recIndy: Indy[Declaration] = Indy { i => rec((ParseMode.Decl, i)) } - - // TODO: aren't NonBinding independent of indentation level> - val recNonBind: P[NonBinding] = P.defer(rec((ParseMode.NB, indent))).asInstanceOf[P[NonBinding]] - val recNBIndy: Indy[NonBinding] = Indy { i => rec((ParseMode.NB, i)).asInstanceOf[P[NonBinding]] } - - val recArg: P[NonBinding] = P.defer(rec((ParseMode.BranchArg, indent)).asInstanceOf[P[NonBinding]]) - val recArgIndy: Indy[NonBinding] = Indy { i => rec((ParseMode.BranchArg, i)).asInstanceOf[P[NonBinding]] } - - val recComp: P[NonBinding] = P.defer(rec((ParseMode.ComprehensionSource, indent))).asInstanceOf[P[NonBinding]] - - val nestedBlock: P[Region => Declaration.NonBinding] = { - /** - * we can either do: ( y = 1 - * y) - * starting a new declaration without indentation, - * or ( - * y = 1 - * y) - * where we allow more indentation. - */ - val noIndent = recurseDecl - val withIndent = Parser.newline *> Parser.spaces.string.flatMap { indent => recIndy(indent) } - maybeSpace.with1 *> (withIndent | noIndent).map { d => { (r: Region) => Parens(d)(r) } } <* maybeSpacesAndLines - } + Memoize.memoizeDagHashedConcurrent[(ParseMode, String), P[Declaration]] { + case ((pm, indent), rec) => + // TODO: + // since we do a hard set of the mode in these, we lose the thread if we are inside a + // BranchArg so the trailing values : should be interpretted as a the branch end. + // This may actually make the file ambiguous in some cases, or at least point + // to a strange place on parse errors. + // + // I think we need to separate block-like expressions using : from NonBinding + // and make sure that we don't have block like expressions in certain places + + val recurseDecl: P[Declaration] = P.defer( + rec((ParseMode.Decl, indent)) + ) // needs to be inside a P for laziness + val recIndy: Indy[Declaration] = Indy { i => rec((ParseMode.Decl, i)) } + + // TODO: aren't NonBinding independent of indentation level> + val recNonBind: P[NonBinding] = + P.defer(rec((ParseMode.NB, indent))).asInstanceOf[P[NonBinding]] + val recNBIndy: Indy[NonBinding] = Indy { i => + rec((ParseMode.NB, i)).asInstanceOf[P[NonBinding]] + } - val tupOrPar: P[NonBinding] = - // TODO: the backtrack here is bad... - Parser.parens(((maybeSpacesAndLines.with1.soft *> ((recNonBind <* (!(maybeSpace ~ bindOp))).backtrack <* maybeSpacesAndLines)) - .tupleOrParens0 - .map { - case Left(p) => { (r: Region) => Parens(p)(r) } - case Right(tup) => { (r: Region) => TupleCons(tup.toList)(r) } - }) - .orElse(nestedBlock) - // or it could be () which is just unit - .orElse(P.pure({ (r: Region) => TupleCons(Nil)(r) })) - , P.unit) - .region - .map { case (r, fn) => fn(r) } - - // since x -> y: t will parse like x -> (y: t) - // if we are in a branch arg, we can't parse annotations on the body of the lambda - val lambBody = if (pm == ParseMode.BranchArg) recArgIndy.asInstanceOf[Indy[Declaration]] else recIndy - val ternaryElseP = if (pm == ParseMode.BranchArg) recArg else recNonBind - - val allNonBind: P[NonBinding] = - P.defer( - P.oneOf( - lambdaP(lambBody)(indent) :: - ifElseP(recArgIndy, recIndy)(indent) :: - matchP(recArgIndy, recIndy)(indent) :: - dictP(recArg, recComp) :: - varP :: - listP(recNonBind, recComp) :: - lits :: - stringDeclOrLit(recNBIndy)(indent) :: - tupOrPar :: - recordConstructorP(indent, recNonBind, recArg) :: - // TODO: comment is ambiguous with binding/non-binding... - // so it prevents us commenting a binding statement - commentNBP(recNonBind)(indent) :: - Nil)) - - /* - * This is where we parse application, either direct, or dot-style - */ - val applied: P[NonBinding] = { - // here we are using . syntax foo.bar(1, 2) - // we also allow foo.(anyExpression)(1, 2) - val fn = varP.orElse(recNonBind.parensCut) - val slashcontinuation = ((maybeSpace ~ P.char('\\') ~ toEOL1).backtrack ~ Parser.maybeSpacesAndLines).?.void - // 0 or more args - val params0 = recNonBind.parensLines0Cut - val justDot = P.not(P.string(".\"") | P.string(".'")).with1 *> P.char('.') - val dotApply: P[NonBinding => NonBinding] = - (slashcontinuation.with1 *> justDot *> (fn ~ params0)) - .region - .map { case (r2, (fn, args)) => + val recArg: P[NonBinding] = P.defer( + rec((ParseMode.BranchArg, indent)).asInstanceOf[P[NonBinding]] + ) + val recArgIndy: Indy[NonBinding] = Indy { i => + rec((ParseMode.BranchArg, i)).asInstanceOf[P[NonBinding]] + } - { (head: NonBinding) => Apply(fn, NonEmptyList(head, args), ApplyKind.Dot)(head.region + r2) } - } + val recComp: P[NonBinding] = P + .defer(rec((ParseMode.ComprehensionSource, indent))) + .asInstanceOf[P[NonBinding]] - // 1 or more args - val params1 = recNonBind.parensLines1Cut - // here we directly call a function foo(1, 2) - val applySuffix: P[NonBinding => NonBinding] = - params1 + val nestedBlock: P[Region => Declaration.NonBinding] = { + + /** we can either do: ( y = 1 y) starting a new declaration without + * indentation, or ( y = 1 y) where we allow more indentation. + */ + val noIndent = recurseDecl + val withIndent = Parser.newline *> Parser.spaces.string.flatMap { + indent => recIndy(indent) + } + maybeSpace.with1 *> (withIndent | noIndent).map { d => + { (r: Region) => Parens(d)(r) } + } <* maybeSpacesAndLines + } + + val tupOrPar: P[NonBinding] = + // TODO: the backtrack here is bad... + Parser + .parens( + ((maybeSpacesAndLines.with1.soft *> ((recNonBind <* (!(maybeSpace ~ bindOp))).backtrack <* maybeSpacesAndLines)).tupleOrParens0 + .map { + case Left(p) => { (r: Region) => Parens(p)(r) } + case Right(tup) => { (r: Region) => TupleCons(tup.toList)(r) } + }) + .orElse(nestedBlock) + // or it could be () which is just unit + .orElse(P.pure({ (r: Region) => TupleCons(Nil)(r) })), + P.unit + ) .region - .map { case (r, args) => - { (fn: NonBinding) => Apply(fn, args, ApplyKind.Parens)(fn.region + r) } + .map { case (r, fn) => fn(r) } + + // since x -> y: t will parse like x -> (y: t) + // if we are in a branch arg, we can't parse annotations on the body of the lambda + val lambBody = + if (pm == ParseMode.BranchArg) + recArgIndy.asInstanceOf[Indy[Declaration]] + else recIndy + val ternaryElseP = if (pm == ParseMode.BranchArg) recArg else recNonBind + + val allNonBind: P[NonBinding] = + P.defer( + P.oneOf( + lambdaP(lambBody)(indent) :: + ifElseP(recArgIndy, recIndy)(indent) :: + matchP(recArgIndy, recIndy)(indent) :: + dictP(recArg, recComp) :: + varP :: + listP(recNonBind, recComp) :: + lits :: + stringDeclOrLit(recNBIndy)(indent) :: + tupOrPar :: + recordConstructorP(indent, recNonBind, recArg) :: + // TODO: comment is ambiguous with binding/non-binding... + // so it prevents us commenting a binding statement + commentNBP(recNonBind)(indent) :: + Nil + ) + ) + + /* + * This is where we parse application, either direct, or dot-style + */ + val applied: P[NonBinding] = { + // here we are using . syntax foo.bar(1, 2) + // we also allow foo.(anyExpression)(1, 2) + val fn = varP.orElse(recNonBind.parensCut) + val slashcontinuation = ((maybeSpace ~ P.char( + '\\' + ) ~ toEOL1).backtrack ~ Parser.maybeSpacesAndLines).?.void + // 0 or more args + val params0 = recNonBind.parensLines0Cut + val justDot = + P.not(P.string(".\"") | P.string(".'")).with1 *> P.char('.') + val dotApply: P[NonBinding => NonBinding] = + (slashcontinuation.with1 *> justDot *> (fn ~ params0)).region + .map { + case (r2, (fn, args)) => { (head: NonBinding) => + Apply(fn, NonEmptyList(head, args), ApplyKind.Dot)( + head.region + r2 + ) + } + } + + // 1 or more args + val params1 = recNonBind.parensLines1Cut + // here we directly call a function foo(1, 2) + val applySuffix: P[NonBinding => NonBinding] = + params1.region + .map { + case (r, args) => { (fn: NonBinding) => + Apply(fn, args, ApplyKind.Parens)(fn.region + r) + } + } + + def repFn[A](fn: P[A => A]): P0[A => A] = + fn.rep0.map { opList => + { (a: A) => opList.foldLeft(a) { (arg, fn) => fn(arg) } } } - def repFn[A](fn: P[A => A]): P0[A => A] = - fn.rep0.map { opList => - { (a: A) => opList.foldLeft(a) { (arg, fn) => fn(arg) } } + (allNonBind ~ repFn(dotApply.orElse(applySuffix))) + .map { case (a, f) => f(a) } + } + // lower priority than calls is type annotation + val annotated: P[NonBinding] = + if (pm == ParseMode.BranchArg) applied + else { + val an: P[NonBinding => NonBinding] = + TypeRef.annotationParser + // TODO remove this backtrack, + // currently we can confuse ending a block with type annotation + // without backtracking here due to nesting losing track of + // when a trailing item is in a BranchArg in e.g. match or if bodies + .backtrack.region + .map { + case (r, tpe) => { (nb: NonBinding) => + Annotation(nb, tpe)(nb.region + r) + } + } + + applied.maybeAp(an) } - (allNonBind ~ repFn(dotApply.orElse(applySuffix))) - .map { case (a, f) => f(a) } - } - // lower priority than calls is type annotation - val annotated: P[NonBinding] = - if (pm == ParseMode.BranchArg) applied - else { - val an: P[NonBinding => NonBinding] = - TypeRef.annotationParser - // TODO remove this backtrack, - // currently we can confuse ending a block with type annotation - // without backtracking here due to nesting losing track of - // when a trailing item is in a BranchArg in e.g. match or if bodies - .backtrack - .region - .map { case (r, tpe) => - { (nb: NonBinding) => Annotation(nb, tpe)(nb.region + r) } + // matched + val matched: P[NonBinding] = { + // x matches p + val matchesOp = + ((maybeSpace.with1 *> P.string( + "matches" + ) *> spaces).backtrack *> Pattern.matchParser).region + .map { + case (region, pat) => { (nb: NonBinding) => + Matches(nb, pat)(nb.region + region) + } } + .rep + .map { fns => fns.toList.reduceLeft(_.andThen(_)) } - applied.maybeAp(an) + annotated.maybeAp(matchesOp) } - // matched - val matched: P[NonBinding] = { - // x matches p - val matchesOp = - ((maybeSpace.with1 *> P.string("matches") *> spaces).backtrack *> Pattern.matchParser) - .region - .map { case (region, pat) => - - { (nb: NonBinding) => Matches(nb, pat)(nb.region + region) } + // Applying is higher precedence than any operators + // now parse an operator apply + def postOperators(nb: P[NonBinding]): P[NonBinding] = { + + def convert(form: Operators.Formula[NonBinding]): NonBinding = + form match { + case Operators.Formula.Sym(r) => r + case Operators.Formula.Op(left, op, right) => + val leftD = convert(left) + val rightD = convert(right) + // `op`(l, r) + ApplyOp(leftD, Identifier.Operator(op), rightD) } - .rep - .map { fns => fns.toList.reduceLeft(_.andThen(_)) } - annotated.maybeAp(matchesOp) - } + // one or more operators + val ops: P[NonBinding => Operators.Formula[NonBinding]] = + maybeSpace.with1.soft *> ((!bindOp).with1 *> Operators.Formula + .infixOps1(nb)) - // Applying is higher precedence than any operators - // now parse an operator apply - def postOperators(nb: P[NonBinding]): P[NonBinding] = { - - def convert(form: Operators.Formula[NonBinding]): NonBinding = - form match { - case Operators.Formula.Sym(r) => r - case Operators.Formula.Op(left, op, right) => - val leftD = convert(left) - val rightD = convert(right) - // `op`(l, r) - ApplyOp(leftD, Identifier.Operator(op), rightD) + // This already parses as many as it can, so we don't need repFn + val form = ops.map { fn => + { (d: NonBinding) => convert(fn(d)) } } - // one or more operators - val ops: P[NonBinding => Operators.Formula[NonBinding]] = - maybeSpace.with1.soft *> ((!bindOp).with1 *> Operators.Formula.infixOps1(nb)) - - // This already parses as many as it can, so we don't need repFn - val form = ops.map { fn => - - { (d: NonBinding) => convert(fn(d)) } + nb.maybeAp(form) } - nb.maybeAp(form) - } + // here is if/ternary operator + // it fully recurses on the else branch, which will parse any repeated ternaryies + // so no need to repeat here for correct precedence + val ternary: P[NonBinding => NonBinding] = + (((spaces *> P.string( + "if" + ) *> spaces).backtrack *> recNonBind) ~ (spaces *> keySpace( + "else" + ) *> ternaryElseP)) + .map { + case (cond, falseCase) => { (trueCase: NonBinding) => + Ternary(trueCase, cond, falseCase) + } + } - // here is if/ternary operator - // it fully recurses on the else branch, which will parse any repeated ternaryies - // so no need to repeat here for correct precedence - val ternary: P[NonBinding => NonBinding] = - (((spaces *> P.string("if") *> spaces).backtrack *> recNonBind) ~ (spaces *> keySpace("else") *> ternaryElseP)) - .map { case (cond, falseCase) => - { (trueCase: NonBinding) => Ternary(trueCase, cond, falseCase) } - } + val finalNonBind: P[NonBinding] = + if (pm != ParseMode.ComprehensionSource) + postOperators(matched).maybeAp(ternary) + else postOperators(matched) - val finalNonBind: P[NonBinding] = - if (pm != ParseMode.ComprehensionSource) postOperators(matched).maybeAp(ternary) - else postOperators(matched) - - if (pm != ParseMode.Decl) finalNonBind - else { - val finalBind: P[Declaration] = P.defer( - P.oneOf( - // these have keywords which need to be parsed before var (def, match, if) - defP(recIndy)(indent) :: - // these are not ambiguous with patterns - commentP(recIndy)(indent) :: - /* - * challenge is that not all Declarations are Patterns, and not - * all Patterns are Declarations. So, bindings, which are: pattern = declaration - * is a bit hard. This also makes cuts a bit dangerous, since this ambiguity - * between pattern and declaration means if we use cuts too aggressively, we - * will fail. - * - * If we parse a declaration first, if we see = we need to convert - * to pattern. If we parse a pattern, but it was actually a declaration, we need - * to convert there. This code tries to parse as a declaration first, then converts - * it to pattern if we see an = - */ - patternBind(recNBIndy, recIndy)(indent) :: - Nil) + if (pm != ParseMode.Decl) finalNonBind + else { + val finalBind: P[Declaration] = P.defer( + P.oneOf( + // these have keywords which need to be parsed before var (def, match, if) + defP(recIndy)(indent) :: + // these are not ambiguous with patterns + commentP(recIndy)(indent) :: + /* + * challenge is that not all Declarations are Patterns, and not + * all Patterns are Declarations. So, bindings, which are: pattern = declaration + * is a bit hard. This also makes cuts a bit dangerous, since this ambiguity + * between pattern and declaration means if we use cuts too aggressively, we + * will fail. + * + * If we parse a declaration first, if we see = we need to convert + * to pattern. If we parse a pattern, but it was actually a declaration, we need + * to convert there. This code tries to parse as a declaration first, then converts + * it to pattern if we see an = + */ + patternBind(recNBIndy, recIndy)(indent) :: + Nil + ) ) - // we have to parse non-binds last - finalBind.orElse(finalNonBind) - } + // we have to parse non-binds last + finalBind.orElse(finalNonBind) + } } val parser: Indy[Declaration] = @@ -1324,7 +1559,9 @@ object Declaration { val nonBindingParser: Indy[NonBinding] = Indy { i => parserCache((ParseMode.NB, i)) }.asInstanceOf[Indy[NonBinding]] val nonBindingParserNoTern: Indy[NonBinding] = - Indy { i => parserCache((ParseMode.ComprehensionSource, i)) }.asInstanceOf[Indy[NonBinding]] + Indy { i => parserCache((ParseMode.ComprehensionSource, i)) } + .asInstanceOf[Indy[NonBinding]] val nonBindingParserNoAnn: Indy[NonBinding] = - Indy { i => parserCache((ParseMode.BranchArg, i)) }.asInstanceOf[Indy[NonBinding]] + Indy { i => parserCache((ParseMode.BranchArg, i)) } + .asInstanceOf[Indy[NonBinding]] } diff --git a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala index 7fbdd83e6..f0d2709ec 100644 --- a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala @@ -7,20 +7,19 @@ import cats.implicits._ import Identifier.Bindable -/** - * Recursion in bosatsu is only allowed on a substructural match - * of one of the parameters to the def. This strict rule, along - * with strictly finite data, ensures that all recursion terminates - * - * The rules are as follows: - * 0. defs may not be shadowed. This makes checking for legal recursion easier - * 1. until we reach a recur match, we cannot access an outer def name. We want to avoid aliasing - * 2. a recur match must occur on one of the literal parameters to the def, and there can - * be only one recur match - * 3. inside each branch of the recur match, we may only recur on substructures in the match - * position. - * 4. if there is a recur match, there must be at least one real recursion - */ +/** Recursion in bosatsu is only allowed on a substructural match of one of the + * parameters to the def. This strict rule, along with strictly finite data, + * ensures that all recursion terminates + * + * The rules are as follows: 0. defs may not be shadowed. This makes checking + * for legal recursion easier + * 1. until we reach a recur match, we cannot access an outer def name. We + * want to avoid aliasing 2. a recur match must occur on one of the + * literal parameters to the def, and there can be only one recur match 3. + * inside each branch of the recur match, we may only recur on + * substructures in the match position. 4. if there is a recur match, + * there must be at least one real recursion + */ object DefRecursionCheck { type Res = ValidatedNel[RecursionError, Unit] @@ -29,30 +28,40 @@ object DefRecursionCheck { def region: Region def message: String } - case class InvalidRecursion(name: Bindable, illegalPosition: Region) extends RecursionError { + case class InvalidRecursion(name: Bindable, illegalPosition: Region) + extends RecursionError { def region = illegalPosition def message = s"invalid recursion on ${name.sourceCodeRepr}" } - case class IllegalShadow(fnname: Bindable, decl: Declaration) extends RecursionError { + case class IllegalShadow(fnname: Bindable, decl: Declaration) + extends RecursionError { def region = decl.region - def message = s"illegal shadowing on: ${fnname.sourceCodeRepr}. Recursive shadowing of def names disallowed" + def message = + s"illegal shadowing on: ${fnname.sourceCodeRepr}. Recursive shadowing of def names disallowed" } case class UnexpectedRecur(decl: Declaration.Match) extends RecursionError { def region = decl.region def message = "unexpected recur: may only appear unnested inside a def" } - case class RecurNotOnArg(decl: Declaration.Match, - fnname: Bindable, - args: NonEmptyList[NonEmptyList[Pattern.Parsed]]) extends RecursionError { + case class RecurNotOnArg( + decl: Declaration.Match, + fnname: Bindable, + args: NonEmptyList[NonEmptyList[Pattern.Parsed]] + ) extends RecursionError { def region = decl.region def message = { val argsDoc = - Doc.intercalate(Doc.empty, + Doc.intercalate( + Doc.empty, args.toList.map { group => (Doc.char('(') + - Doc.intercalate(Doc.comma + Doc.line, - group.toList.map { pat => Pattern.document[TypeRef].document(pat) }) + + Doc.intercalate( + Doc.comma + Doc.line, + group.toList.map { pat => + Pattern.document[TypeRef].document(pat) + } + ) + Doc.char(')')).grouped } ) @@ -60,25 +69,35 @@ object DefRecursionCheck { s"recur not on an argument to the def of ${fnname.sourceCodeRepr}, args: $argStr" } } - case class RecursionArgNotVar(fnname: Bindable, invalidArg: Declaration) extends RecursionError { + case class RecursionArgNotVar(fnname: Bindable, invalidArg: Declaration) + extends RecursionError { def region = invalidArg.region - def message = s"recursion in ${fnname.sourceCodeRepr} is not on a name (expect a name which is exactly a arg to the def)" + def message = + s"recursion in ${fnname.sourceCodeRepr} is not on a name (expect a name which is exactly a arg to the def)" } - case class RecursionNotSubstructural(fnname: Bindable, recurPat: Pattern.Parsed, arg: Declaration.Var) extends RecursionError { + case class RecursionNotSubstructural( + fnname: Bindable, + recurPat: Pattern.Parsed, + arg: Declaration.Var + ) extends RecursionError { def region = arg.region def message = s"recursion in ${fnname.sourceCodeRepr} not substructual" } - case class RecursiveDefNoRecur(defstmt: DefStatement[Pattern.Parsed, Declaration], recur: Declaration.Match) extends RecursionError { + case class RecursiveDefNoRecur( + defstmt: DefStatement[Pattern.Parsed, Declaration], + recur: Declaration.Match + ) extends RecursionError { def region = recur.region - def message = s"recur but no recursive call to ${defstmt.name.sourceCodeRepr}" + def message = + s"recur but no recursive call to ${defstmt.name.sourceCodeRepr}" } - /** - * Check a statement that all inner declarations contain legal - * recursion, or none at all. Note, we don't check for cases that will be caught - * by typechecking: namely, when we have nonrecursive defs, their names are not - * in scope during typechecking, so illegal recursion there simply won't typecheck. - */ + /** Check a statement that all inner declarations contain legal recursion, or + * none at all. Note, we don't check for cases that will be caught by + * typechecking: namely, when we have nonrecursive defs, their names are not + * in scope during typechecking, so illegal recursion there simply won't + * typecheck. + */ def checkStatement(s: Statement): Res = { import Statement._ import Impl._ @@ -124,14 +143,17 @@ object DefRecursionCheck { (dn == n) || outer.defNamesContain(n) } - def inDef(fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]]): InDef = + def inDef( + fnname: Bindable, + args: NonEmptyList[NonEmptyList[Pattern.Parsed]] + ): InDef = InDef(this, fnname, args, Set.empty) } sealed abstract class InDefState extends State { final def inDef: InDef = this match { - case id @ InDef(_, _, _, _) => id - case InDefRecurred(ir, _, _, _, _) => ir.inDef + case id @ InDef(_, _, _, _) => id + case InDefRecurred(ir, _, _, _, _) => ir.inDef case InRecurBranch(InDefRecurred(ir, _, _, _, _), _, _) => ir.inDef } @@ -139,7 +161,12 @@ object DefRecursionCheck { } case object TopLevel extends State - case class InDef(outer: State, fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]], localScope: Set[Bindable]) extends InDefState { + case class InDef( + outer: State, + fnname: Bindable, + args: NonEmptyList[NonEmptyList[Pattern.Parsed]], + localScope: Set[Bindable] + ) extends InDefState { def addLocal(b: Bindable): InDef = InDef(outer, fnname, args, localScope + b) @@ -149,30 +176,34 @@ object DefRecursionCheck { // This is eta-expansion of the function name as a lambda so we can check using the lambda rule def asLambda(region: Region): Declaration.Lambda = { - val allNames = Iterator.iterate(0)(_ + 1).map { idx => Identifier.Name(s"a$idx") }.filterNot(_ == fnname) - + val allNames = Iterator + .iterate(0)(_ + 1) + .map { idx => Identifier.Name(s"a$idx") } + .filterNot(_ == fnname) + val func = cats.Functor[NonEmptyList].compose[NonEmptyList] // we allocate the names first. There is only one name inside: fnname val argsB = func.map(args)(_ => allNames.next()) val argsV: NonEmptyList[NonEmptyList[Declaration.NonBinding]] = - func.map(argsB)( - n => Declaration.Var(n)(region) - ) + func.map(argsB)(n => Declaration.Var(n)(region)) val argsP: NonEmptyList[NonEmptyList[Pattern.Parsed]] = - func.map(argsB)( - n => Pattern.Var(n) - ) + func.map(argsB)(n => Pattern.Var(n)) - // fn == (x, y) -> z -> f(x, y)(z) - val body = argsV.toList.foldLeft(Declaration.Var(fnname)(region): Declaration.NonBinding) { (called, group) => - Declaration.Apply(called, group, Declaration.ApplyKind.Parens)(region) + // fn == (x, y) -> z -> f(x, y)(z) + val body = argsV.toList.foldLeft( + Declaration.Var(fnname)(region): Declaration.NonBinding + ) { (called, group) => + Declaration.Apply(called, group, Declaration.ApplyKind.Parens)(region) } - def lambdify(args: NonEmptyList[NonEmptyList[Pattern.Parsed]], body: Declaration): Declaration.Lambda = { - val body1 = args.tail match { - case Nil => body + def lambdify( + args: NonEmptyList[NonEmptyList[Pattern.Parsed]], + body: Declaration + ): Declaration.Lambda = { + val body1 = args.tail match { + case Nil => body case h :: tail => lambdify(NonEmptyList(h, tail), body) } Declaration.Lambda(args.head, body1)(region) @@ -181,10 +212,20 @@ object DefRecursionCheck { lambdify(argsP, body) } } - case class InDefRecurred(inRec: InDef, group: Int, index: Int, recur: Declaration.Match, recCount: Int) extends InDefState { + case class InDefRecurred( + inRec: InDef, + group: Int, + index: Int, + recur: Declaration.Match, + recCount: Int + ) extends InDefState { def incRecCount: InDefRecurred = copy(recCount = recCount + 1) } - case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed, allowedNames: Set[Bindable]) extends InDefState { + case class InRecurBranch( + inRec: InDefRecurred, + branch: Pattern.Parsed, + allowedNames: Set[Bindable] + ) extends InDefState { def incRecCount: InRecurBranch = copy(inRec = inRec.incRecCount) } @@ -192,10 +233,11 @@ object DefRecursionCheck { * What is the index into the list of def arguments where we are doing our recursion */ def getRecurIndex( - fnname: Bindable, - args: NonEmptyList[NonEmptyList[Pattern.Parsed]], - m: Declaration.Match, - locals: Set[Bindable]): ValidatedNel[RecursionError, (Int, Int)] = { + fnname: Bindable, + args: NonEmptyList[NonEmptyList[Pattern.Parsed]], + m: Declaration.Match, + locals: Set[Bindable] + ): ValidatedNel[RecursionError, (Int, Int)] = { import Declaration._ m.arg match { case Var(v) => @@ -209,7 +251,6 @@ object DefRecursionCheck { if item.topNames.contains(v) } yield (gidx, idx) - if (idxes.hasNext) Validated.valid(idxes.next()) else Validated.invalidNel(RecurNotOnArg(m, fnname, args)) } @@ -222,9 +263,14 @@ object DefRecursionCheck { * Check that decl is a strict substructure of pat. We do this by making sure decl is a Var * and that var is one of the strict substrutures of the pattern. */ - def allowedRecursion(fnname: Bindable, pat: Pattern.Parsed, names: Set[Bindable], decl: Declaration): Res = + def allowedRecursion( + fnname: Bindable, + pat: Pattern.Parsed, + names: Set[Bindable], + decl: Declaration + ): Res = decl match { - case v@Declaration.Var(nm: Bindable) => + case v @ Declaration.Var(nm: Bindable) => if (names.contains(nm)) unitValid else Validated.invalidNel(RecursionNotSubstructural(fnname, pat, v)) case _ => @@ -240,20 +286,25 @@ object DefRecursionCheck { * for the algorithm here, but also for human readers to see that recursion is total */ def checkForIllegalBinds[A]( - state: State, - bs: Iterable[Bindable], - decl: Declaration)(next: ValidatedNel[RecursionError, A]): ValidatedNel[RecursionError, A] = { - val outerSet = state.outerDefNames - if (outerSet.isEmpty) next - else { - NonEmptyList.fromList(bs.iterator.filter(outerSet).toList.sorted) match { - case Some(nel) => - Validated.invalid(nel.map(IllegalShadow(_, decl))) - case None => - next - } + state: State, + bs: Iterable[Bindable], + decl: Declaration + )( + next: ValidatedNel[RecursionError, A] + ): ValidatedNel[RecursionError, A] = { + val outerSet = state.outerDefNames + if (outerSet.isEmpty) next + else { + NonEmptyList.fromList( + bs.iterator.filter(outerSet).toList.sorted + ) match { + case Some(nel) => + Validated.invalid(nel.map(IllegalShadow(_, decl))) + case None => + next } } + } /* * Unfortunately we lose the Applicative structure inside Declaration checking. @@ -277,15 +328,16 @@ object DefRecursionCheck { new cats.data.IndexedStateT((fna, fnb).parMapN { (fn1, fn2) => { (state: State) => fn1(state) match { - case Right((s2, a)) => fn2(s2).map { case (st, b) => (st, (a, b)) } + case Right((s2, a)) => + fn2(s2).map { case (st, b) => (st, (a, b)) } case Left(nel1) => // just skip and merge fn2(state) match { - case Right(_) => Left(nel1) + case Right(_) => Left(nel1) case Left(nel2) => Left(nel1 ::: nel2) } } - } + } }) } } @@ -303,19 +355,22 @@ object DefRecursionCheck { val unitSt: St[Unit] = pureSt(()) def checkForIllegalBindsSt[A]( - bs: Iterable[Bindable], - decl: Declaration): St[Unit] = - for { - state <- getSt - _ <- toSt(checkForIllegalBinds(state, bs, decl)(unitValid)) - _ <- (state match { - case id@InDef(_, _, _, _) => setSt(bs.foldLeft(id)(_.addLocal(_))) - case _ => unitSt - }) - } yield () + bs: Iterable[Bindable], + decl: Declaration + ): St[Unit] = + for { + state <- getSt + _ <- toSt(checkForIllegalBinds(state, bs, decl)(unitValid)) + _ <- (state match { + case id @ InDef(_, _, _, _) => setSt(bs.foldLeft(id)(_.addLocal(_))) + case _ => unitSt + }) + } yield () - private def argsOnDefName(fn: Declaration, - groups: NonEmptyList[NonEmptyList[Declaration]]): Option[(Bindable, NonEmptyList[NonEmptyList[Declaration]])] = + private def argsOnDefName( + fn: Declaration, + groups: NonEmptyList[NonEmptyList[Declaration]] + ): Option[(Bindable, NonEmptyList[NonEmptyList[Declaration]])] = fn match { case Declaration.Var(nm: Bindable) => Some((nm, groups)) case Declaration.Apply(fn1, args, _) => @@ -330,9 +385,11 @@ object DefRecursionCheck { .flatMapN { case (a, InRecurBranch(ir1, b1, _)) => setSt(InRecurBranch(ir1, b1, names)).as(a) - // $COVERAGE-OFF$ this should be unreachable + // $COVERAGE-OFF$ this should be unreachable case (_, unexpected) => - sys.error(s"invariant violation expected InRecurBranch: start = $start, end = $unexpected") + sys.error( + s"invariant violation expected InRecurBranch: start = $start, end = $unexpected" + ) } case notRecur => sys.error(s"called setNames on $notRecur with names: $newNames") @@ -348,19 +405,24 @@ object DefRecursionCheck { setSt(InRecurBranch(ir1, b1, names)).as(a) // $COVERAGE-OFF$ this should be unreachable case (_, unexpected) => - sys.error(s"invariant violation expected InRecurBranch: start = $start, end = $unexpected") + sys.error( + s"invariant violation expected InRecurBranch: start = $start, end = $unexpected" + ) // $COVERAGE-ON$ this should be unreachable } case _ => in } - def checkApply(fn: Declaration, args: NonEmptyList[Declaration], region: Region): St[Unit] = + def checkApply( + fn: Declaration, + args: NonEmptyList[Declaration], + region: Region + ): St[Unit] = getSt.flatMap { case TopLevel => // without any recursion, normal typechecking will detect bad states: checkDecl(fn) *> args.parTraverse_(checkDecl) - case irb@InRecurBranch(inrec, branch, names) => - + case irb @ InRecurBranch(inrec, branch, names) => argsOnDefName(fn, NonEmptyList.one(args)) match { case Some((nm, groups)) => if (nm == irb.defname) { @@ -374,39 +436,38 @@ object DefRecursionCheck { toSt(allowedRecursion(irb.defname, branch, names, arg)) *> setSt(irb.incRecCount) // we have recurred again } - } - else if (irb.defNamesContain(nm)) { + } else if (irb.defNamesContain(nm)) { failSt(InvalidRecursion(nm, region)) - } - else if (names.contains(nm)) { + } else if (names.contains(nm)) { // we are calling a reachable function. Any lambda args are new names: args.parTraverse_[St, Unit] { case Declaration.Lambda(args, body) => val names1 = args.toList.flatMap(_.names) unionNames(names1)(checkDecl(body)) - case v@Declaration.Var(fn: Bindable) if irb.defname == fn => - val Declaration.Lambda(args, body) = irb.inDef.asLambda(v.region) + case v @ Declaration.Var(fn: Bindable) if irb.defname == fn => + val Declaration.Lambda(args, body) = + irb.inDef.asLambda(v.region) val names1 = args.toList.flatMap(_.names) unionNames(names1)(checkDecl(body)) case notLambda => checkDecl(notLambda) } - } - else { + } else { // traverse converting Var(name) to the lambda version to use the above check // not a recursive call args.parTraverse_(checkDecl) } case None => // this isn't a recursive call - checkDecl(fn) *> args.parTraverse_(checkDecl) + checkDecl(fn) *> args.parTraverse_(checkDecl) } case ir: InDefState => // we have either not yet, or already done the recursion argsOnDefName(fn, NonEmptyList.one(args)) match { - case Some((nm, _)) if ir.defNamesContain(nm) => failSt(InvalidRecursion(nm, region)) + case Some((nm, _)) if ir.defNamesContain(nm) => + failSt(InvalidRecursion(nm, region)) case _ => - checkDecl(fn) *> args.parTraverse_(checkDecl) - } + checkDecl(fn) *> args.parTraverse_(checkDecl) + } } /* * With the given state, check the given Declaration to see if @@ -417,13 +478,17 @@ object DefRecursionCheck { decl match { case Annotation(t, _) => checkDecl(t) case Apply(fn, args, _) => - checkApply(fn, args, decl.region) + checkApply(fn, args, decl.region) case ApplyOp(left, op, right) => - checkApply(Var(op)(decl.region), NonEmptyList(left, right :: Nil), decl.region) + checkApply( + Var(op)(decl.region), + NonEmptyList(left, right :: Nil), + decl.region + ) case Binding(BindingStatement(pat, thisDecl, next)) => checkForIllegalBindsSt(pat.names, decl) *> - checkDecl(thisDecl) *> - filterNames(pat.names)(checkDecl(next.padded)) + checkDecl(thisDecl) *> + filterNames(pat.names)(checkDecl(next.padded)) case Comment(cs) => checkDecl(cs.on.padded) case CommentNB(cs) => @@ -441,7 +506,7 @@ object DefRecursionCheck { } val e = checkDecl(elseCase.get) ifs *> e - case la@LeftApply(_, _, _, _) => + case la @ LeftApply(_, _, _, _) => checkDecl(la.rewrite) case Ternary(t, c, f) => checkDecl(t) *> checkDecl(c) *> checkDecl(f) @@ -460,10 +525,11 @@ object DefRecursionCheck { filterNames(pat.names)(checkDecl(next.get)) } argRes *> optRes - case recur@Match(RecursionKind.Recursive, _, cases) => + case recur @ Match(RecursionKind.Recursive, _, cases) => // this is a state change getSt.flatMap { - case TopLevel | InRecurBranch(_, _, _) | InDefRecurred(_, _, _, _, _) => + case TopLevel | InRecurBranch(_, _, _) | + InDefRecurred(_, _, _, _, _) => failSt(UnexpectedRecur(recur)) case InDef(_, defname, args, locals) => toSt(getRecurIndex(defname, args, recur, locals)).flatMap { idx => @@ -471,24 +537,24 @@ object DefRecursionCheck { // parent state def beginBranch(pat: Pattern.Parsed): St[Unit] = getSt.flatMap { - case ir@InDef(_, _, _, _) => + case ir @ InDef(_, _, _, _) => val rec = ir.setRecur(idx, recur) setSt(rec) *> beginBranch(pat) - case irr@InDefRecurred(_, _, _, _, _) => + case irr @ InDefRecurred(_, _, _, _, _) => setSt(InRecurBranch(irr, pat, pat.substructures.toSet)) case illegal => // $COVERAGE-OFF$ this should be unreachable sys.error(s"unreachable: $pat -> $illegal") - // $COVERAGE-ON$ - } + // $COVERAGE-ON$ + } val endBranch: St[Unit] = getSt.flatMap { case InRecurBranch(irr, _, _) => setSt(irr) - case illegal => + case illegal => // $COVERAGE-OFF$ this should be unreachable sys.error(s"unreachable end state: $illegal") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } cases.get.parTraverse_ { case (pat, next) => @@ -500,7 +566,7 @@ object DefRecursionCheck { } yield () } } - } + } case Matches(a, _) => // patterns don't use values checkDecl(a) @@ -517,13 +583,14 @@ object DefRecursionCheck { unitSt case ir: InDefState => // if this were an apply, it would have been handled by Apply(Var(... - if (ir.defNamesContain(v)) failSt(InvalidRecursion(v, decl.region)) + if (ir.defNamesContain(v)) + failSt(InvalidRecursion(v, decl.region)) else unitSt } case StringDecl(parts) => parts.parTraverse_ { - case StringDecl.CharExpr(nb) => checkDecl(nb) - case StringDecl.StrExpr(nb) => checkDecl(nb) + case StringDecl.CharExpr(nb) => checkDecl(nb) + case StringDecl.StrExpr(nb) => checkDecl(nb) case StringDecl.Literal(_, _) => unitSt } case ListDecl(ll) => @@ -565,7 +632,10 @@ object DefRecursionCheck { * Binds are not allowed to be recursive, only defs, so here we just make sure * none of the free variables of the pattern are used in decl */ - def checkDef[A](state: State, defstmt: DefStatement[Pattern.Parsed, (OptIndent[Declaration], A)]): Res = { + def checkDef[A]( + state: State, + defstmt: DefStatement[Pattern.Parsed, (OptIndent[Declaration], A)] + ): Res = { val body = defstmt.result._1.get val nameArgs = defstmt.args.toList.flatMap(_.patternNames) val state1 = state.inDef(defstmt.name, defstmt.args) @@ -579,11 +649,18 @@ object DefRecursionCheck { unitSt case InDefRecurred(_, _, _, recur, 0) => // we hit a recur, but we didn't recurse - failSt[Unit](RecursiveDefNoRecur(defstmt.copy(result = defstmt.result._1.get), recur)) + failSt[Unit]( + RecursiveDefNoRecur( + defstmt.copy(result = defstmt.result._1.get), + recur + ) + ) case unreachable => // $COVERAGE-OFF$ this should be unreachable - sys.error(s"we would like to prove in the types we can't get here: $unreachable, $defstmt"): St[Unit] - // $COVERAGE-ON$ + sys.error( + s"we would like to prove in the types we can't get here: $unreachable, $defstmt" + ): St[Unit] + // $COVERAGE-ON$ }) // Note a def can't change the state // we either have a valid nested def, or we don't diff --git a/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala b/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala index 932814983..803869cff 100644 --- a/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala @@ -29,14 +29,16 @@ object DefStatement { import defs._ val res = retType.fold(Doc.empty) { t => arrow + t.toDoc } val taDoc = typeArgs match { - case None => Doc.empty - case Some(ta) => TypeRef.docTypeArgs(ta.toList) { - case None => Doc.empty - case Some(k) => colonSpace + Kind.toDoc(k) - } + case None => Doc.empty + case Some(ta) => + TypeRef.docTypeArgs(ta.toList) { + case None => Doc.empty + case Some(k) => colonSpace + Kind.toDoc(k) + } } val argDoc = - Doc.intercalate(Doc.empty, + Doc.intercalate( + Doc.empty, args.toList.map { args => Doc.char('(') + Doc.intercalate( @@ -67,7 +69,9 @@ object DefStatement { ( Parser.keySpace( "def" - ) *> (Identifier.bindableParser ~ TypeRef.typeParams(kindAnnot.?).? ~ args.rep) <* maybeSpace, + ) *> (Identifier.bindableParser ~ TypeRef + .typeParams(kindAnnot.?) + .? ~ args.rep) <* maybeSpace, result.with1 <* (maybeSpace.with1 ~ P.char(':')), resultTParser ) diff --git a/core/src/main/scala/org/bykn/bosatsu/EditDistance.scala b/core/src/main/scala/org/bykn/bosatsu/EditDistance.scala index c1a23ca22..f224c0232 100644 --- a/core/src/main/scala/org/bykn/bosatsu/EditDistance.scala +++ b/core/src/main/scala/org/bykn/bosatsu/EditDistance.scala @@ -6,10 +6,10 @@ object EditDistance { def apply[A](a: Iterable[A], b: Iterable[A]): Int = a.foldLeft((0 to b.size).toList) { (prev, x) => (prev zip prev.tail zip b) - .scanLeft(prev.head + 1) { - case (h, ((d, v), y)) => min(min(h + 1, v + 1), d + (if (x == y) 0 else 1)) + .scanLeft(prev.head + 1) { case (h, ((d, v), y)) => + min(min(h + 1, v + 1), d + (if (x == y) 0 else 1)) } - }.last + }.last def string(a: String, b: String): Int = apply(a, b) diff --git a/core/src/main/scala/org/bykn/bosatsu/Evaluation.scala b/core/src/main/scala/org/bykn/bosatsu/Evaluation.scala index d24783b33..307361ef9 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Evaluation.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Evaluation.scala @@ -9,9 +9,9 @@ import cats.implicits._ import Identifier.Bindable case class Evaluation[T](pm: PackageMap.Typed[T], externals: Externals) { - /** - * Holds the final value of the environment for each Package - */ + + /** Holds the final value of the environment for each Package + */ private[this] val envCache: MMap[PackageName, Map[Identifier, Eval[Value]]] = MMap.empty @@ -20,46 +20,46 @@ case class Evaluation[T](pm: PackageMap.Typed[T], externals: Externals) { externalNames.iterator.map { n => val tpe = p.program.types.getValue(p.name, n) match { case Some(t) => t - case None => + case None => // $COVERAGE-OFF$ // should never happen due to typechecking sys.error(s"from ${p.name} import unknown external def: $n") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } externals.toMap.get((p.name, n.asString)) match { case Some(ext) => (n, Eval.later(ext.call(tpe))) - case None => + case None => // $COVERAGE-OFF$ // should never happen due to typechecking sys.error(s"from ${p.name} no External for external def: $n") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - } - .toMap + }.toMap } private[this] lazy val gdr = pm.getDataRepr - private def evalLets(thisPack: PackageName, lets: List[(Bindable, RecursionKind, TypedExpr[T])]): List[(Bindable, Eval[Value])] = { + private def evalLets( + thisPack: PackageName, + lets: List[(Bindable, RecursionKind, TypedExpr[T])] + ): List[(Bindable, Eval[Value])] = { val exprs: List[(Bindable, Matchless.Expr)] = - rankn.RefSpace - .allocCounter + rankn.RefSpace.allocCounter .flatMap { c => lets - .traverse { - case (name, rec, te) => - Matchless.fromLet(name, rec, te, gdr, c) - .map((name, _)) + .traverse { case (name, rec, te) => + Matchless + .fromLet(name, rec, te, gdr, c) + .map((name, _)) } - } - .run - .value + } + .run + .value - val evalFn: (PackageName, Identifier) => Eval[Value] = - { (p, i) => - if (p == thisPack) Eval.defer(evaluate(p)(i)) - else evaluate(p)(i) - } + val evalFn: (PackageName, Identifier) => Eval[Value] = { (p, i) => + if (p == thisPack) Eval.defer(evaluate(p)(i)) + else evaluate(p)(i) + } type F[A] = List[(Bindable, A)] val ffunc = cats.Functor[List].compose(cats.Functor[(Bindable, *)]) @@ -67,10 +67,12 @@ case class Evaluation[T](pm: PackageMap.Typed[T], externals: Externals) { } private def evaluate(packName: PackageName): Map[Identifier, Eval[Value]] = - envCache.getOrElseUpdate(packName, { - val pack = pm.toMap(packName) - externalEnv(pack) ++ evalLets(packName, pack.program.lets) - }) + envCache.getOrElseUpdate( + packName, { + val pack = pm.toMap(packName) + externalEnv(pack) ++ evalLets(packName, pack.program.lets) + } + ) def evaluateLast(p: PackageName): Option[(Eval[Value], Type)] = for { @@ -80,18 +82,21 @@ case class Evaluation[T](pm: PackageMap.Typed[T], externals: Externals) { } yield (value, tpe.getType) // TODO: this only works for lets, not externals - def evaluateName(p: PackageName, name: Bindable): Option[(Eval[Value], Type)] = + def evaluateName( + p: PackageName, + name: Bindable + ): Option[(Eval[Value], Type)] = for { pack <- pm.toMap.get(p) - (_, _, tpe) <- pack.program.lets.filter { case (n, _, _) => n == name }.lastOption + (_, _, tpe) <- pack.program.lets.filter { case (n, _, _) => + n == name + }.lastOption value <- evaluate(p).get(name) } yield (value, tpe.getType) - /** - * Return the last test, if any, in the package. - * this is the test that is run when we test - * the package - */ + /** Return the last test, if any, in the package. this is the test that is run + * when we test the package + */ def lastTest(p: PackageName): Option[Eval[Value]] = for { pack <- pm.toMap.get(p) @@ -115,36 +120,30 @@ case class Evaluation[T](pm: PackageMap.Typed[T], externals: Externals) { Doc.intercalate(Doc.lineOrSpace, packs).render(80) } - */ + */ def evalTest(ps: PackageName): Option[Eval[Test]] = lastTest(ps).map { ea => ea.map(Test.fromValue(_)) } - /** - * Convert a typechecked value to Json - * this code ASSUMES the type is correct. If not, we may throw or return - * incorrect data. - */ - val valueToJson: ValueToJson = ValueToJson({ - case Type.Const.Defined(pn, t) => - for { - pack <- pm.toMap.get(pn) - dt <- pack.program.types.getType(pn, t) - } yield dt + /** Convert a typechecked value to Json this code ASSUMES the type is correct. + * If not, we may throw or return incorrect data. + */ + val valueToJson: ValueToJson = ValueToJson({ case Type.Const.Defined(pn, t) => + for { + pack <- pm.toMap.get(pn) + dt <- pack.program.types.getType(pn, t) + } yield dt }) - /** - * Convert a typechecked value to Doc - * this code ASSUMES the type is correct. If not, we may throw or return - * incorrect data. - */ - val valueToDoc: ValueToDoc = ValueToDoc({ - case Type.Const.Defined(pn, t) => - for { - pack <- pm.toMap.get(pn) - dt <- pack.program.types.getType(pn, t) - } yield dt + /** Convert a typechecked value to Doc this code ASSUMES the type is correct. + * If not, we may throw or return incorrect data. + */ + val valueToDoc: ValueToDoc = ValueToDoc({ case Type.Const.Defined(pn, t) => + for { + pack <- pm.toMap.get(pn) + dt <- pack.program.types.getType(pn, t) + } yield dt }) } diff --git a/core/src/main/scala/org/bykn/bosatsu/ExportedName.scala b/core/src/main/scala/org/bykn/bosatsu/ExportedName.scala index b8d855d1b..f5b94c67c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ExportedName.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ExportedName.scala @@ -16,107 +16,120 @@ sealed abstract class ExportedName[+T] { self: Product => // we use them as hash keys final override val hashCode: Int = MurmurHash3.productHash(this) - /** - * Given name, in the current type environment and fully typed lets - * what does it correspond to? - */ + + /** Given name, in the current type environment and fully typed lets what does + * it correspond to? + */ private def toReferants[A]( - letValue: Option[rankn.Type], - definedType: Option[rankn.DefinedType[A]]): Option[NonEmptyList[ExportedName[Referant[A]]]] = - this match { - case ExportedName.Binding(n, _) => - letValue.map { tpe => - NonEmptyList.one(ExportedName.Binding(n, Referant.Value(tpe))) - } - case ExportedName.TypeName(nm, _) => - definedType.map { dt => - NonEmptyList.one(ExportedName.TypeName(nm, Referant.DefinedT(dt))) - } - case ExportedName.Constructor(nm, _) => - // export the type and all constructors - definedType.map { dt => - val cons = dt.constructors.map { cf => - ExportedName.Constructor(cf.name, Referant.Constructor(dt, cf)) - } - val t = ExportedName.TypeName(nm, Referant.DefinedT(dt)) - NonEmptyList(t, cons) - } - } + letValue: Option[rankn.Type], + definedType: Option[rankn.DefinedType[A]] + ): Option[NonEmptyList[ExportedName[Referant[A]]]] = + this match { + case ExportedName.Binding(n, _) => + letValue.map { tpe => + NonEmptyList.one(ExportedName.Binding(n, Referant.Value(tpe))) + } + case ExportedName.TypeName(nm, _) => + definedType.map { dt => + NonEmptyList.one(ExportedName.TypeName(nm, Referant.DefinedT(dt))) + } + case ExportedName.Constructor(nm, _) => + // export the type and all constructors + definedType.map { dt => + val cons = dt.constructors.map { cf => + ExportedName.Constructor(cf.name, Referant.Constructor(dt, cf)) + } + val t = ExportedName.TypeName(nm, Referant.DefinedT(dt)) + NonEmptyList(t, cons) + } + } } object ExportedName { - case class Binding[T](name: Identifier.Bindable, tag: T) extends ExportedName[T] - case class TypeName[T](name: Identifier.Constructor, tag: T) extends ExportedName[T] - case class Constructor[T](name: Identifier.Constructor, tag: T) extends ExportedName[T] + case class Binding[T](name: Identifier.Bindable, tag: T) + extends ExportedName[T] + case class TypeName[T](name: Identifier.Constructor, tag: T) + extends ExportedName[T] + case class Constructor[T](name: Identifier.Constructor, tag: T) + extends ExportedName[T] private[this] val consDoc = Doc.text("()") implicit val document: Document[ExportedName[Unit]] = { val di = Document[Identifier] Document.instance[ExportedName[Unit]] { - case Binding(n, _) => di.document(n) - case TypeName(n, _) => di.document(n) + case Binding(n, _) => di.document(n) + case TypeName(n, _) => di.document(n) case Constructor(n, _) => di.document(n) + consDoc } } val parser: P[ExportedName[Unit]] = - Identifier.bindableParser.map(Binding(_, ())) + Identifier.bindableParser + .map(Binding(_, ())) .orElse( (Identifier.consParser ~ P.string("()").?) .map { - case (n, None) => TypeName(n, ()) + case (n, None) => TypeName(n, ()) case (n, Some(_)) => Constructor(n, ()) } ) - /** - * Build exports into referants given a typeEnv - * The only error we have have here is if we name an export we didn't define - * Note a name can be two things: - * 1. a type - * 2. a value (e.g. a let or a constructor function) - */ + /** Build exports into referants given a typeEnv The only error we have have + * here is if we name an export we didn't define Note a name can be two + * things: + * 1. a type 2. a value (e.g. a let or a constructor function) + */ def buildExports[E, V, R, D]( - nm: PackageName, - exports: List[ExportedName[E]], - typeEnv: rankn.TypeEnv[V], - lets: List[(Identifier.Bindable, R, TypedExpr[D])])(implicit ev: V <:< Kind.Arg): ValidatedNel[ExportedName[E], List[ExportedName[Referant[V]]]] = { + nm: PackageName, + exports: List[ExportedName[E]], + typeEnv: rankn.TypeEnv[V], + lets: List[(Identifier.Bindable, R, TypedExpr[D])] + )(implicit + ev: V <:< Kind.Arg + ): ValidatedNel[ExportedName[E], List[ExportedName[Referant[V]]]] = { - val letMap = lets.iterator.map { case (n, _, t) => (n, t) }.toMap + val letMap = lets.iterator.map { case (n, _, t) => (n, t) }.toMap - def expName[A](ename: ExportedName[A]): Option[NonEmptyList[ExportedName[Referant[V]]]] = { - import ename.name - val letValue: Option[rankn.Type] = - name.toBindable - .flatMap { bn => - letMap.get(bn) - .map(_.getType) - .orElse { - // It could be an external or imported value in the TypeEnv - typeEnv.getValue(nm, bn) - } - } - val optDT = - name.toConstructor - .flatMap { cn => - typeEnv.getType(nm, org.bykn.bosatsu.TypeName(cn)) - } + def expName[A]( + ename: ExportedName[A] + ): Option[NonEmptyList[ExportedName[Referant[V]]]] = { + import ename.name + val letValue: Option[rankn.Type] = + name.toBindable + .flatMap { bn => + letMap + .get(bn) + .map(_.getType) + .orElse { + // It could be an external or imported value in the TypeEnv + typeEnv.getValue(nm, bn) + } + } + val optDT = + name.toConstructor + .flatMap { cn => + typeEnv.getType(nm, org.bykn.bosatsu.TypeName(cn)) + } - ename.toReferants(letValue, optDT) - } + ename.toReferants(letValue, optDT) + } - def expName1[A](ename: ExportedName[A]): ValidatedNel[ExportedName[A], List[ExportedName[Referant[V]]]] = - expName(ename) match { - case None => Validated.invalid(NonEmptyList.of(ename)) - case Some(v) => Validated.valid(v.toList) - } + def expName1[A]( + ename: ExportedName[A] + ): ValidatedNel[ExportedName[A], List[ExportedName[Referant[V]]]] = + expName(ename) match { + case None => Validated.invalid(NonEmptyList.of(ename)) + case Some(v) => Validated.valid(v.toList) + } - exports.traverse(expName1).map(_.flatten) + exports.traverse(expName1).map(_.flatten) } - def typeEnvFromExports[A](packageName: PackageName, exports: List[ExportedName[Referant[A]]]): TypeEnv[A] = + def typeEnvFromExports[A]( + packageName: PackageName, + exports: List[ExportedName[Referant[A]]] + ): TypeEnv[A] = exports.foldLeft((TypeEnv.empty): TypeEnv[A]) { (te, exp) => exp.tag.addTo(packageName, exp.name, te) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Expr.scala b/core/src/main/scala/org/bykn/bosatsu/Expr.scala index 6a0f6be04..2dbf18669 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Expr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Expr.scala @@ -1,9 +1,8 @@ package org.bykn.bosatsu -/** - * This is a scala port of the example of Hindley Milner inference - * here: http://dev.stephendiehl.com/fun/006_hindley_milner.html - */ +/** This is a scala port of the example of Hindley Milner inference here: + * http://dev.stephendiehl.com/fun/006_hindley_milner.html + */ import cats.implicits._ import cats.data.{Chain, Writer, NonEmptyList} @@ -16,11 +15,9 @@ import Identifier.{Bindable, Constructor} sealed abstract class Expr[T] { def tag: T - /** - * All the free variables in this expression in order - * encountered and with duplicates (to see how often - * they appear) - */ + /** All the free variables in this expression in order encountered and with + * duplicates (to see how often they appear) + */ lazy val freeVarsDup: List[Bindable] = { import Expr._ // nearly identical code to TypedExpr.freeVarsDup, bugs should be fixed in both places @@ -43,8 +40,7 @@ sealed abstract class Expr[T] { val argFree = if (rec.isRecursive) { ListUtil.filterNot(argFree0)(_ === arg) - } - else argFree0 + } else argFree0 argFree ::: (ListUtil.filterNot(in.freeVarsDup)(_ === arg)) case Literal(_, _) => @@ -60,8 +56,7 @@ sealed abstract class Expr[T] { else ListUtil.filterNot(bfree)(newBinds) } // we can only take one branch, so count the max on each branch: - val branchFreeMax = branchFrees - .zipWithIndex + val branchFreeMax = branchFrees.zipWithIndex .flatMap { case (names, br) => names.map((_, br)) } // these groupBys are okay because we sort at the end .groupBy(identity) // group-by-name x branch @@ -86,9 +81,9 @@ sealed abstract class Expr[T] { expr.globals case Annotation(t, _, _) => t.globals - case Local(_, _) => Set.empty + case Local(_, _) => Set.empty case g @ Global(_, _, _) => Set.empty + g - case Lambda(_, res, _) => res.globals + case Lambda(_, res, _) => res.globals case App(fn, args, _) => fn.globals | args.reduceMap(_.globals) case Let(_, argE, in, _, _) => @@ -102,15 +97,15 @@ sealed abstract class Expr[T] { def replaceTag(t: T): Expr[T] = { import Expr._ this match { - case g@Generic(_, e) => g.copy(in = e.replaceTag(t)) - case a@Annotation(_, _, _) => a.copy(tag = t) - case l@Local(_, _) => l.copy(tag = t) - case g @ Global(_, _, _) => g.copy(tag = t) - case l@Lambda(_, _, _) => l.copy(tag = t) - case a@App(_, _, _) => a.copy(tag = t) - case l@Let(_, _, _, _, _) => l.copy(tag = t) - case l@Literal(_, _) => l.copy(tag = t) - case m@Match(_, _, _) => m.copy(tag = t) + case g @ Generic(_, e) => g.copy(in = e.replaceTag(t)) + case a @ Annotation(_, _, _) => a.copy(tag = t) + case l @ Local(_, _) => l.copy(tag = t) + case g @ Global(_, _, _) => g.copy(tag = t) + case l @ Lambda(_, _, _) => l.copy(tag = t) + case a @ App(_, _, _) => a.copy(tag = t) + case l @ Let(_, _, _, _, _) => l.copy(tag = t) + case l @ Literal(_, _) => l.copy(tag = t) + case m @ Match(_, _, _) => m.copy(tag = t) } } @@ -123,18 +118,34 @@ object Expr { case class Annotation[T](expr: Expr[T], tpe: Type, tag: T) extends Expr[T] case class Local[T](name: Bindable, tag: T) extends Name[T] - case class Generic[T](typeVars: NonEmptyList[(Type.Var.Bound, Kind)], in: Expr[T]) extends Expr[T] { + case class Generic[T]( + typeVars: NonEmptyList[(Type.Var.Bound, Kind)], + in: Expr[T] + ) extends Expr[T] { def tag = in.tag } - case class Global[T](pack: PackageName, name: Identifier, tag: T) extends Name[T] - case class App[T](fn: Expr[T], args: NonEmptyList[Expr[T]], tag: T) extends Expr[T] - case class Lambda[T](args: NonEmptyList[(Bindable, Option[Type])], expr: Expr[T], tag: T) extends Expr[T] - case class Let[T](arg: Bindable, expr: Expr[T], in: Expr[T], recursive: RecursionKind, tag: T) extends Expr[T] { - def flatten: (NonEmptyList[(Bindable, RecursionKind, Expr[T], T)], Expr[T]) = { + case class Global[T](pack: PackageName, name: Identifier, tag: T) + extends Name[T] + case class App[T](fn: Expr[T], args: NonEmptyList[Expr[T]], tag: T) + extends Expr[T] + case class Lambda[T]( + args: NonEmptyList[(Bindable, Option[Type])], + expr: Expr[T], + tag: T + ) extends Expr[T] + case class Let[T]( + arg: Bindable, + expr: Expr[T], + in: Expr[T], + recursive: RecursionKind, + tag: T + ) extends Expr[T] { + def flatten + : (NonEmptyList[(Bindable, RecursionKind, Expr[T], T)], Expr[T]) = { val thisLet = (arg, recursive, expr, tag) in match { - case let@Let(_, _, _, _, _) => + case let @ Let(_, _, _, _, _) => val (lets, finalIn) = let.flatten (thisLet :: lets, finalIn) case _ => @@ -144,10 +155,19 @@ object Expr { } } case class Literal[T](lit: Lit, tag: T) extends Expr[T] - case class Match[T](arg: Expr[T], branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], Expr[T])], tag: T) extends Expr[T] + case class Match[T]( + arg: Expr[T], + branches: NonEmptyList[ + (Pattern[(PackageName, Constructor), Type], Expr[T]) + ], + tag: T + ) extends Expr[T] // Inverse of `Let.flatten` - def lets[T](binds: List[(Bindable, RecursionKind, Expr[T], T)], in: Expr[T]): Expr[T] = + def lets[T]( + binds: List[(Bindable, RecursionKind, Expr[T], T)], + in: Expr[T] + ): Expr[T] = binds match { case Nil => in case (b, r, e, t) :: tail => @@ -160,16 +180,19 @@ object Expr { expr match { case Annotation(_, tpe, _) => Some(tpe) case Lambda(args, Annotated(res), _) => - args.traverse { case (_, ot) => ot } + args + .traverse { case (_, ot) => ot } .map { argTpes => Type.Fun(argTpes, res) } - case Literal(lit, _) => Some(Type.getTypeOf(lit)) + case Literal(lit, _) => Some(Type.getTypeOf(lit)) case Let(_, _, Annotated(t), _, _) => Some(t) case Match(_, branches, _) => - branches.traverse { case (_, expr) => unapply(expr) } + branches + .traverse { case (_, expr) => unapply(expr) } .flatMap { allAnnotated => - if (allAnnotated.tail.forall(_ === allAnnotated.head)) Some(allAnnotated.head) + if (allAnnotated.tail.forall(_ === allAnnotated.head)) + Some(allAnnotated.head) else None } case _ => None @@ -189,29 +212,31 @@ object Expr { case Generic(typeVars, in) => Generic(nel ::: typeVars, in) case notAnn => Generic(nel, notAnn) - } + } } def quantifyFrees[A](expr: Expr[A]): Expr[A] = forAll(freeBoundTyVars(expr).map((_, Kind.Type)), expr) - /** - * Report all the Bindable names refered to in the given Expr. - * this can be used to allocate names that can never shadow - * anything being used in the expr - */ + /** Report all the Bindable names refered to in the given Expr. this can be + * used to allocate names that can never shadow anything being used in the + * expr + */ final def allNames[A](expr: Expr[A]): SortedSet[Bindable] = expr match { case Annotation(e, _, _) => allNames(e) - case Local(name, _) => SortedSet(name) - case Generic(_, in) => allNames(in) - case Global(_, _, _) => SortedSet.empty - case App(fn, args, _) => args.foldLeft(allNames(fn))((bs, e) => bs | allNames(e)) + case Local(name, _) => SortedSet(name) + case Generic(_, in) => allNames(in) + case Global(_, _, _) => SortedSet.empty + case App(fn, args, _) => + args.foldLeft(allNames(fn))((bs, e) => bs | allNames(e)) case Lambda(args, e, _) => allNames(e) ++ args.toList.iterator.map(_._1) case Let(arg, expr, in, _, _) => allNames(expr) | allNames(in) + arg - case Literal(_, _) => SortedSet.empty + case Literal(_, _) => SortedSet.empty case Match(exp, branches, _) => - allNames(exp) | branches.foldMap { case (pat, res) => allNames(res) ++ pat.names } + allNames(exp) | branches.foldMap { case (pat, res) => + allNames(res) ++ pat.names + } } implicit def hasRegion[T: HasRegion]: HasRegion[Expr[T]] = @@ -223,63 +248,83 @@ object Expr { private[this] val TruePat: Pattern[(PackageName, Constructor), Type] = Pattern.PositionalStruct((PackageName.PredefName, Constructor("True")), Nil) private[this] val FalsePat: Pattern[(PackageName, Constructor), Type] = - Pattern.PositionalStruct((PackageName.PredefName, Constructor("False")), Nil) - /** - * build a Match expression that is equivalent to if/else using Predef::True and Predef::False - */ - def ifExpr[T](cond: Expr[T], ifTrue: Expr[T], ifFalse: Expr[T], tag: T): Expr[T] = + Pattern.PositionalStruct( + (PackageName.PredefName, Constructor("False")), + Nil + ) + + /** build a Match expression that is equivalent to if/else using Predef::True + * and Predef::False + */ + def ifExpr[T]( + cond: Expr[T], + ifTrue: Expr[T], + ifFalse: Expr[T], + tag: T + ): Expr[T] = Match(cond, NonEmptyList.of((TruePat, ifTrue), (FalsePat, ifFalse)), tag) - /** - * Build an apply expression by appling these args left to right - */ + /** Build an apply expression by appling these args left to right + */ def buildApp[A](fn: Expr[A], args: List[Expr[A]], appTag: A): Expr[A] = args match { case head :: tail => App(fn, NonEmptyList(head, tail), appTag) - case Nil => fn + case Nil => fn } // Traverse all non-bound vars - private def traverseType[T, F[_]](expr: Expr[T], bound: Set[Type.Var.Bound])(fn: (Type, Set[Type.Var.Bound]) => F[Type])(implicit F: Applicative[F]): F[Expr[T]] = + private def traverseType[T, F[_]](expr: Expr[T], bound: Set[Type.Var.Bound])( + fn: (Type, Set[Type.Var.Bound]) => F[Type] + )(implicit F: Applicative[F]): F[Expr[T]] = expr match { case Annotation(e, tpe, a) => (traverseType(e, bound)(fn), fn(tpe, bound)).mapN(Annotation(_, _, a)) case v: Name[T] => F.pure(v) case App(f, args, t) => - (traverseType(f, bound)(fn), args.traverse(traverseType(_, bound)(fn))).mapN(App(_, _, t)) + (traverseType(f, bound)(fn), args.traverse(traverseType(_, bound)(fn))) + .mapN(App(_, _, t)) case Generic(bs, in) => // Seems dangerous since we are hiding from fn that the Type.TyVar inside // matching these are not unbound val bound1 = bound ++ bs.toList.iterator.map(_._1) traverseType(in, bound1)(fn).map(Generic(bs, _)) case Lambda(args, expr, t) => - (args.traverse { case (n, optT) => optT.traverse(fn(_, bound)).map((n, _)) }, - traverseType(expr, bound)(fn)).mapN(Lambda(_, _, t)) + ( + args.traverse { case (n, optT) => + optT.traverse(fn(_, bound)).map((n, _)) + }, + traverseType(expr, bound)(fn) + ).mapN(Lambda(_, _, t)) case Let(arg, exp, in, rec, tag) => - (traverseType(exp, bound)(fn), traverseType(in, bound)(fn)).mapN(Let(arg, _, _, rec, tag)) - case l@Literal(_, _) => F.pure(l) + (traverseType(exp, bound)(fn), traverseType(in, bound)(fn)) + .mapN(Let(arg, _, _, rec, tag)) + case l @ Literal(_, _) => F.pure(l) case Match(arg, branches, tag) => val argB = traverseType(arg, bound)(fn) type B = (Pattern[(PackageName, Constructor), Type], Expr[T]) def branchFn(b: B): F[B] = b match { case (pat, expr) => - pat.traverseType(fn(_, bound)) + pat + .traverseType(fn(_, bound)) .product(traverseType(expr, bound)(fn)) } val branchB = branches.traverse(branchFn _) (argB, branchB).mapN(Match(_, _, tag)) } - private def substExpr[A](keys: NonEmptyList[Type.Var], vals: NonEmptyList[Type.Rho], expr: Expr[A]): Expr[A] = { + private def substExpr[A]( + keys: NonEmptyList[Type.Var], + vals: NonEmptyList[Type.Rho], + expr: Expr[A] + ): Expr[A] = { val fn = Type.substTy(keys, vals) traverseType[A, cats.Id](expr, Set.empty) { (t, bound) => - // we have to remove any of the keys that are bound - val isBound: Type.Var => Boolean = - { - case b @ Type.Var.Bound(_) => bound(b) - case _ => false - } + // we have to remove any of the keys that are bound + val isBound: Type.Var => Boolean = { + case b @ Type.Var.Bound(_) => bound(b) + case _ => false + } if (keys.exists(isBound)) { val kv1 = keys.zip(vals).toList.filter { case (b, _) => !isBound(b) } @@ -290,8 +335,7 @@ object Expr { case None => t } - } - else fn(t) + } else fn(t) } } @@ -305,24 +349,25 @@ object Expr { w.written.iterator.toList.distinct } - /** - * Here we substitute any free bound variables with skolem variables - * - * This is a deviation from the paper. - * We are allowing a syntax like: - * - * def identity(x: a) -> a: - * x - * - * or: - * - * def foo(x: a): x - * - * We handle this by converting a to a skolem variable, - * running inference, then quantifying over that skolem - * variable. - */ - def skolemizeVars[F[_]: Applicative, A](vs: NonEmptyList[(Type.Var.Bound, Kind)], expr: Expr[A])(newSkolemTyVar: (Type.Var.Bound, Kind) => F[Type.Var.Skolem]): F[(NonEmptyList[Type.Var.Skolem], Expr[A])] = { + /** Here we substitute any free bound variables with skolem variables + * + * This is a deviation from the paper. We are allowing a syntax like: + * + * def identity(x: a) -> a: x + * + * or: + * + * def foo(x: a): x + * + * We handle this by converting a to a skolem variable, running inference, + * then quantifying over that skolem variable. + */ + def skolemizeVars[F[_]: Applicative, A]( + vs: NonEmptyList[(Type.Var.Bound, Kind)], + expr: Expr[A] + )( + newSkolemTyVar: (Type.Var.Bound, Kind) => F[Type.Var.Skolem] + ): F[(NonEmptyList[Type.Var.Skolem], Expr[A])] = { vs.traverse { case (b, k) => newSkolemTyVar(b, k) } .map { skVs => val sksT = skVs.map(Type.TyVar(_)) @@ -332,17 +377,15 @@ object Expr { } private[bosatsu] def nameIterator(): Iterator[Bindable] = - Type - .allBinders - .iterator + Type.allBinders.iterator .map(_.name) .map(Identifier.Name(_)) - def buildPatternLambda[A]( - args: NonEmptyList[Pattern[(PackageName, Constructor), Type]], - body: Expr[A], - outer: A): Expr[A] = { + args: NonEmptyList[Pattern[(PackageName, Constructor), Type]], + body: Expr[A], + outer: A + ): Expr[A] = { /* * compute this once if needed, which is why it is lazy. @@ -374,4 +417,3 @@ object Expr { Lambda(justArgs, lambdaResult, outer) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala b/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala index bbf03eaf0..85484d855 100644 --- a/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala +++ b/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala @@ -10,7 +10,9 @@ object FfiCall { final case class Fn1(fn: Value => Value) extends FfiCall { import Value.FnValue - private[this] val evalFn: FnValue = FnValue { case NonEmptyList(a, _) => fn(a) } + private[this] val evalFn: FnValue = FnValue { case NonEmptyList(a, _) => + fn(a) + } def call(t: rankn.Type): Value = evalFn } @@ -43,7 +45,7 @@ object FfiCall { def one(t: rankn.Type): Option[Class[_]] = loop(t, false) match { case c :: Nil => Some(c) - case _ => None + case _ => None } def loop(t: rankn.Type, top: Boolean): List[Class[_]] = { @@ -52,13 +54,15 @@ object FfiCall { val ats = as.map { a => one(a) match { case Some(at) => at - case function => sys.error(s"unsupported function type $function in $t") + case function => + sys.error(s"unsupported function type $function in $t") } } val res = one(b) match { case Some(at) => at - case function => sys.error(s"unsupported function type $function in $t") + case function => + sys.error(s"unsupported function type $function in $t") } ats.toList ::: res :: Nil case rankn.Type.ForAll(_, t) => @@ -69,4 +73,3 @@ object FfiCall { loop(t, true) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Fix.scala b/core/src/main/scala/org/bykn/bosatsu/Fix.scala index f53b76465..3f00bccf8 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Fix.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Fix.scala @@ -1,11 +1,10 @@ package org.bykn.bosatsu final object FixType { - /** - * Use a trick in scala to give an opaque - * type for a fixed point recursion without - * having to allocate wrappers at each level - */ + + /** Use a trick in scala to give an opaque type for a fixed point recursion + * without having to allocate wrappers at each level + */ type Fix[F[_]] final def fix[F[_]](f: F[Fix[F]]): Fix[F] = diff --git a/core/src/main/scala/org/bykn/bosatsu/Identifier.scala b/core/src/main/scala/org/bykn/bosatsu/Identifier.scala index 1df6706a9..2b2bb1a6e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Identifier.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Identifier.scala @@ -2,7 +2,7 @@ package org.bykn.bosatsu import cats.Order import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import Parser.{lowerIdent, upperIdent} @@ -26,21 +26,21 @@ sealed abstract class Identifier { def toBindable: Option[Identifier.Bindable] = this match { case b: Identifier.Bindable => Some(b) - case _ => None + case _ => None } def toConstructor: Option[Identifier.Constructor] = this match { case c: Identifier.Constructor => Some(c) - case _ => None + case _ => None } } object Identifier { - /** - * These are names that can appear in bindings. Importantly, - * we can't bind constructor names except to define types - */ + + /** These are names that can appear in bindings. Importantly, we can't bind + * constructor names except to define types + */ sealed abstract class Bindable extends Identifier final case class Constructor(asString: String) extends Identifier @@ -60,8 +60,8 @@ object Identifier { case Backticked(lit) => Doc.char('`') + Doc.text(Parser.escape('`', lit)) + Doc.char('`') case Constructor(n) => Doc.text(n) - case Name(n) => Doc.text(n) - case Operator(n) => opPrefix + Doc.text(n) + case Name(n) => Doc.text(n) + case Operator(n) => opPrefix + Doc.text(n) } val nameParser: P[Name] = @@ -70,25 +70,24 @@ object Identifier { val consParser: P[Constructor] = upperIdent.map { c => Constructor(c.intern) } - /** - * This is used to apply operators, it is the - * raw operator tokens without an `operator` prefix - */ + /** This is used to apply operators, it is the raw operator tokens without an + * `operator` prefix + */ val rawOperator: P[Operator] = Operators.operatorToken.map { op => Operator(op.intern) } - /** - * the keyword operator preceding a rawOperator - */ + /** the keyword operator preceding a rawOperator + */ val operator: P[Operator] = (P.string("operator").soft *> Parser.spaces) *> rawOperator - /** - * Name, Backticked or non-raw operator - */ + /** Name, Backticked or non-raw operator + */ val bindableParser: P[Bindable] = // operator has to come first to not look like a Name - P.oneOf(operator :: nameParser :: Parser.escapedString('`').map { b => Backticked(b.intern) } :: Nil) + P.oneOf(operator :: nameParser :: Parser.escapedString('`').map { b => + Backticked(b.intern) + } :: Nil) val parser: P[Identifier] = bindableParser.orElse(consParser) @@ -98,22 +97,20 @@ object Identifier { def appendToName(i: Bindable, suffix: String): Bindable = i match { case Backticked(b) => Backticked(b + suffix) - case _ => + case _ => // try to stry the same val p = operator.orElse(nameParser) val cand = i.sourceCodeRepr + suffix p.parseAll(cand) match { case Right(ident) => ident - case _ => + case _ => // just turn it into a Backticked Backticked(i.asString + suffix) } } - - /** - * Build an Identifier by parsing a string - */ + /** Build an Identifier by parsing a string + */ def unsafe(str: String): Identifier = unsafeParse(parser, str) diff --git a/core/src/main/scala/org/bykn/bosatsu/Import.scala b/core/src/main/scala/org/bykn/bosatsu/Import.scala index 2076e66f7..37b2486a5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Import.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Import.scala @@ -23,7 +23,9 @@ sealed abstract class ImportedName[+T] { ImportedName.Renamed(o, l, fn(t)) } - def traverse[F[_], U](fn: T => F[U])(implicit F: Functor[F]): F[ImportedName[U]] = + def traverse[F[_], U]( + fn: T => F[U] + )(implicit F: Functor[F]): F[ImportedName[U]] = this match { case ImportedName.OriginalName(n, t) => F.map(fn(t))(ImportedName.OriginalName(n, _)) @@ -33,11 +35,13 @@ sealed abstract class ImportedName[+T] { } object ImportedName { - case class OriginalName[T](originalName: Identifier, tag: T) extends ImportedName[T] { + case class OriginalName[T](originalName: Identifier, tag: T) + extends ImportedName[T] { def localName = originalName def withTag[U](tag: U): ImportedName[U] = copy(tag = tag) } - case class Renamed[T](originalName: Identifier, localName: Identifier, tag: T) extends ImportedName[T] { + case class Renamed[T](originalName: Identifier, localName: Identifier, tag: T) + extends ImportedName[T] { def withTag[U](tag: U): ImportedName[U] = copy(tag = tag) } @@ -54,7 +58,7 @@ object ImportedName { (of ~ (spaces.soft *> P.string("as") *> spaces *> of).?) .map { case (from, Some(to)) => ImportedName.Renamed(from, to, ()) - case (orig, None) => ImportedName.OriginalName(orig, ()) + case (orig, None) => ImportedName.OriginalName(orig, ()) } basedOn(Identifier.bindableParser) @@ -64,14 +68,17 @@ object ImportedName { case class Import[A, B](pack: A, items: NonEmptyList[ImportedName[B]]) { def resolveToGlobal: Map[Identifier, (A, Identifier)] = - items.foldLeft(Map.empty[Identifier, (A, Identifier)]) { case (m0, impName) => - m0.updated(impName.localName, (pack, impName.originalName)) + items.foldLeft(Map.empty[Identifier, (A, Identifier)]) { + case (m0, impName) => + m0.updated(impName.localName, (pack, impName.originalName)) } - def mapFilter[C](fn: (A, ImportedName[B]) => Option[ImportedName[C]]): Option[Import[A, C]] = + def mapFilter[C]( + fn: (A, ImportedName[B]) => Option[ImportedName[C]] + ): Option[Import[A, C]] = NonEmptyList.fromList(items.toList.flatMap { in => fn(pack, in) }) match { case Some(i1) => Some(Import(pack, i1)) - case None => None + case None => None } } @@ -80,7 +87,9 @@ object Import { Document.instance[Import[PackageName, Unit]] { case Import(pname, items) => val itemDocs = items.toList.map(Document[ImportedName[Unit]].document _) - Doc.text("from") + Doc.space + Document[PackageName].document(pname) + Doc.space + Doc.text("import") + + Doc.text("from") + Doc.space + Document[PackageName].document( + pname + ) + Doc.space + Doc.text("import") + // TODO: use paiges to pack this in nicely using .group or something Doc.space + Doc.intercalate(Doc.text(", "), itemDocs) } @@ -88,21 +97,24 @@ object Import { val parser: P[Import[PackageName, Unit]] = { val pyimps = ImportedName.parser.itemsMaybeParens.map(_._2) - ((P.string("from") ~ spaces).backtrack *> PackageName.parser <* spaces, - P.string("import") *> spaces *> pyimps) + ( + (P.string("from") ~ spaces).backtrack *> PackageName.parser <* spaces, + P.string("import") *> spaces *> pyimps + ) .mapN(Import(_, _)) } - /** - * This only keeps the last name if there are duplicate local names - * checking for duplicate local names should be done at another layer - */ - def locals[F[_]: Foldable, A, B, C](imp: Import[A, F[B]])(pn: PartialFunction[B, C]): Map[Identifier, C] = { + /** This only keeps the last name if there are duplicate local names checking + * for duplicate local names should be done at another layer + */ + def locals[F[_]: Foldable, A, B, C]( + imp: Import[A, F[B]] + )(pn: PartialFunction[B, C]): Map[Identifier, C] = { val fn = pn.lift imp.items.foldLeft(Map.empty[Identifier, C]) { case (m0, impName) => impName.tag.foldLeft(m0) { (m1, b) => fn(b) match { - case None => m1 + case None => m1 case Some(c) => m1.updated(impName.localName, c) } } @@ -110,9 +122,8 @@ object Import { } } -/** - * There are all the distinct imported names and the original ImportedName - */ +/** There are all the distinct imported names and the original ImportedName + */ case class ImportMap[A, B](toMap: Map[Identifier, (A, ImportedName[B])]) { def apply(name: Identifier): Option[(A, ImportedName[B])] = toMap.get(name) @@ -125,16 +136,18 @@ object ImportMap { def empty[A, B]: ImportMap[A, B] = ImportMap(Map.empty) // Return the list of collisions in local names along with a map // with the last name overwriting the import - def fromImports[A, B](is: List[Import[A, B]]): (List[(A, ImportedName[B])], ImportMap[A, B]) = + def fromImports[A, B]( + is: List[Import[A, B]] + ): (List[(A, ImportedName[B])], ImportMap[A, B]) = is.iterator .flatMap { case Import(p, is) => is.toList.iterator.map((p, _)) } .foldLeft((List.empty[(A, ImportedName[B])], ImportMap.empty[A, B])) { - case ((dups, imap), pim@(_, im)) => + case ((dups, imap), pim @ (_, im)) => val dups1 = imap(im.localName) match { case Some(nm) => nm :: dups - case None => dups + case None => dups } (dups1, imap + pim) - } + } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Indented.scala b/core/src/main/scala/org/bykn/bosatsu/Indented.scala index 4a32c4ebb..107fbf363 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Indented.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Indented.scala @@ -1,6 +1,6 @@ package org.bykn.bosatsu -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import cats.parse.{Parser => P} @@ -13,9 +13,9 @@ case class Indented[T](spaces: Int, value: T) { object Indented { def spaceCount(str: String): Int = str.foldLeft(0) { - case (s, ' ') => s + 1 + case (s, ' ') => s + 1 case (s, '\t') => s + 4 - case (_, c) => sys.error(s"unexpected space character($c) in $str") + case (_, c) => sys.error(s"unexpected space character($c) in $str") } implicit def document[T: Document]: Document[Indented[T]] = @@ -23,14 +23,11 @@ object Indented { Doc.spaces(i) + (Document[T].document(t).nested(i)) } - - /** - * This reads a new line at a deeper indentation level - * than we currently are. - * - * So we are starting from the 0 column and read - * the current indentation level plus at least one space more - */ + /** This reads a new line at a deeper indentation level than we currently are. + * + * So we are starting from the 0 column and read the current indentation + * level plus at least one space more + */ def indy[T](p: Parser.Indy[T]): Parser.Indy[Indented[T]] = Parser.Indy { indent => for { @@ -39,4 +36,3 @@ object Indented { } yield Indented(Indented.spaceCount(thisIndent), t) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/IorMethods.scala b/core/src/main/scala/org/bykn/bosatsu/IorMethods.scala index af335f390..a8d4c0e7d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/IorMethods.scala +++ b/core/src/main/scala/org/bykn/bosatsu/IorMethods.scala @@ -6,8 +6,8 @@ object IorMethods { implicit class IorExtension[A, B](val ior: Ior[A, B]) extends AnyVal { def strictToValidated: Validated[A, B] = ior match { - case Ior.Right(b) => Validated.valid(b) - case Ior.Left(a) => Validated.invalid(a) + case Ior.Right(b) => Validated.valid(b) + case Ior.Left(a) => Validated.invalid(a) case Ior.Both(a, _) => Validated.invalid(a) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Json.scala b/core/src/main/scala/org/bykn/bosatsu/Json.scala index d78cd77f1..b272d871c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Json.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Json.scala @@ -5,9 +5,8 @@ import org.typelevel.paiges.Doc import cats.parse.{Parser0 => P0, Parser => P} import cats.Eq -/** - * A simple JSON ast for output - */ +/** A simple JSON ast for output + */ sealed abstract class Json { def toDoc: Doc @@ -45,7 +44,7 @@ object Json { } def unapply(j: Json): Option[BigInteger] = j match { - case num@JNumberStr(str) => + case num @ JNumberStr(str) => if (allDigits(str)) Some(new BigInteger(str)) else num.toBigInteger case _ => None @@ -70,9 +69,9 @@ object Json { def unapply(j: Json): Option[Boolean] = j match { - case True => someTrue + case True => someTrue case False => someFalse - case _ => None + case _ => None } } @@ -86,7 +85,10 @@ object Json { def toDoc = if (toVector.isEmpty) emptyArray else { - val parts = Doc.intercalate(Doc.comma, toVector.map { j => (Doc.line + j.toDoc).grouped }) + val parts = Doc.intercalate( + Doc.comma, + toVector.map { j => (Doc.line + j.toDoc).grouped } + ) "[" +: ((parts :+ " ]").nested(2)) } @@ -104,75 +106,75 @@ object Json { else { val kvs = keys.map { k => val j = toMap(k) - JString(k).toDoc + Doc.char(':') + ((Doc.lineOrSpace + j.toDoc).nested(2)) + JString(k).toDoc + Doc.char(':') + ((Doc.lineOrSpace + j.toDoc) + .nested(2)) } val parts = Doc.intercalate(Doc.comma + Doc.line, kvs).grouped parts.bracketBy(text("{"), text("}")) } - /** - * Return a JObject with each key at most once, but in the order of this - */ + /** Return a JObject with each key at most once, but in the order of this + */ def normalize: JObject = JObject(keys.map { k => (k, toMap(k)) }) def render = toDoc.render(80) } - /** - * this checks for semantic equivalence: - * 1. we use BigDecimal to compare JNumberStr - * 2. we normalize objects - */ + /** this checks for semantic equivalence: + * 1. we use BigDecimal to compare JNumberStr 2. we normalize objects + */ implicit val eqJson: Eq[Json] = new Eq[Json] { def eqv(a: Json, b: Json) = (a, b) match { - case (JNull, JNull) => true - case (JBool.True, JBool.True) => true + case (JNull, JNull) => true + case (JBool.True, JBool.True) => true case (JBool.False, JBool.False) => true case (JString(sa), JString(sb)) => sa == sb case (JNumberStr(sa), JNumberStr(sb)) => new BigDecimal(sa).compareTo(new BigDecimal(sb)) == 0 case (JArray(itemsa), JArray(itemsb)) => (itemsa.size == itemsb.size) && - itemsa.iterator - .zip(itemsb.iterator) - .forall { case (a, b) => eqv(a, b) } - case (oa@JObject(_), ob@JObject(_)) => + itemsa.iterator + .zip(itemsb.iterator) + .forall { case (a, b) => eqv(a, b) } + case (oa @ JObject(_), ob @ JObject(_)) => val na = oa.normalize val nb = ob.normalize (na.toMap.keySet == nb.toMap.keySet) && - na.keys.forall { k => - eqv(na.toMap(k), nb.toMap(k)) - } + na.keys.forall { k => + eqv(na.toMap(k), nb.toMap(k)) + } case (_, _) => false } } private[this] val whitespace: P[Unit] = P.charIn(" \t\r\n").void private[this] val whitespaces0: P0[Unit] = whitespace.rep0.void - /** - * This doesn't have to be super fast (but is fairly fast) since we use it in places - * where speed won't matter: feeding it into a program that will convert it to bosatsu - * structured data - */ + + /** This doesn't have to be super fast (but is fairly fast) since we use it in + * places where speed won't matter: feeding it into a program that will + * convert it to bosatsu structured data + */ val parser: P[Json] = { val recurse = P.defer(parser) // cats-parse uses a radix tree for these so it only needs to check 1 character // to see if it misses - val pconst = P.fromStringMap(Map( - "null" -> JNull, - "true" -> JBool.True, - "false" -> JBool.False - /* you can imagine going nuts, but we should justify this with benchmarks + val pconst = P.fromStringMap( + Map( + "null" -> JNull, + "true" -> JBool.True, + "false" -> JBool.False + /* you can imagine going nuts, but we should justify this with benchmarks "0" -> JNumberStr("0"), "1" -> JNumberStr("1"), "\"\"" -> JString(""), "[]" -> JArray(Vector.empty), "{}" -> JObject(Nil) - */ - )) + */ + ) + ) val justStr = JsonStringUtil.escapedString('"') val str = justStr.map(JString(_)) @@ -197,6 +199,7 @@ object Json { } // any whitespace followed by json followed by whitespace followed by end - val parserFile: P[Json] = whitespaces0.with1 *> (parser ~ whitespaces0 ~ P.end).map(_._1._1) + val parserFile: P[Json] = + whitespaces0.with1 *> (parser ~ whitespaces0 ~ P.end).map(_._1._1) } diff --git a/core/src/main/scala/org/bykn/bosatsu/Kind.scala b/core/src/main/scala/org/bykn/bosatsu/Kind.scala index 3dc1098b3..116a2744e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Kind.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Kind.scala @@ -102,9 +102,11 @@ object Kind { case _ => false } - def validApply[A](left: Kind, right: Kind, onTypeErr: => A)(onSubsumeFail: Cons => A): Either[A, Kind] = + def validApply[A](left: Kind, right: Kind, onTypeErr: => A)( + onSubsumeFail: Cons => A + ): Either[A, Kind] = left match { - case cons@Cons(Kind.Arg(_, lhs), res) => + case cons @ Cons(Kind.Arg(_, lhs), res) => if (leftSubsumesRight(lhs, right)) Right(res) else Left(onSubsumeFail(cons)) case Kind.Type => Left(onTypeErr) @@ -303,7 +305,7 @@ object Kind { val left1 = (left & 1).toLong val right1 = (right & 1).toLong val acc1 = acc | (left1 << (2 * depth + 1)) | (right1 << (2 * depth)) - loop(left >>> 1, right >>> 1, depth + 1, acc1) + loop(left >>> 1, right >>> 1, depth + 1, acc1) } } @@ -329,15 +331,20 @@ object Kind { private def varianceToInt(v: Variance): Int = { import Variance._ v match { - case Invariant => 3 + case Invariant => 3 case Contravariant => 2 - case Covariant => 1 - case Phantom => 0 + case Covariant => 1 + case Phantom => 0 } } private val vars = - Array(Variance.Phantom, Variance.Covariant, Variance.Contravariant, Variance.Invariant) + Array( + Variance.Phantom, + Variance.Covariant, + Variance.Contravariant, + Variance.Invariant + ) private def intToVariance(v: Int): Variance = vars(v & 3) @@ -359,8 +366,7 @@ object Kind { val notZero = intr + 1L if (notZero == 0) None else Some(notZero) - } - else None + } else None } } @@ -375,7 +381,7 @@ object Kind { (longToKind(left), longToKind(right)) .mapN { (kl, kr) => val v = intToVariance(leftVar.toInt) - Cons(Arg(v, kl), kr) + Cons(Arg(v, kl), kr) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/KindFormula.scala b/core/src/main/scala/org/bykn/bosatsu/KindFormula.scala index dc870c1f9..ef2147cb4 100644 --- a/core/src/main/scala/org/bykn/bosatsu/KindFormula.scala +++ b/core/src/main/scala/org/bykn/bosatsu/KindFormula.scala @@ -262,16 +262,18 @@ object KindFormula { dts: List[DefinedType[Either[KnownShape, Kind.Arg]]] ): IorNec[Error, List[DefinedType[Kind.Arg]]] = dts - .foldM((List.empty[DefinedType[Kind.Arg]], Set.empty[RankNType.TyConst])) { case (st @ (acc, failed), dt) => + .foldM( + (List.empty[DefinedType[Kind.Arg]], Set.empty[RankNType.TyConst]) + ) { case (st @ (acc, failed), dt) => // don't evaluate dependsOn if failed is empty if (failed.nonEmpty && dt.dependsOn.exists(failed)) { // there was at least one failure already, just return and let that failure signal Ior.Right(st) - } - else { + } else { solveKind((imports, acc), dt) match { - case Validated.Valid(good) => Ior.Right((good :: acc, failed)) - case Validated.Invalid(errs) => Ior.Both(errs, (acc, failed + dt.toTypeTyConst)) + case Validated.Valid(good) => Ior.Right((good :: acc, failed)) + case Validated.Invalid(errs) => + Ior.Both(errs, (acc, failed + dt.toTypeTyConst)) } } } @@ -369,7 +371,7 @@ object KindFormula { // invariant: all subgraph values must be valid keys in the result // we can process the subgraph list in parallel. Those are all the indepentent next values def allSolutionChunk( - dt: DefinedType[Either[Shape.KnownShape, Kind.Arg]], + dt: DefinedType[Either[Shape.KnownShape, Kind.Arg]], cons: Cons, existing: LongMap[Variance], directions: LongMap[Direction], @@ -430,12 +432,14 @@ object KindFormula { NonEmptyLazyList.fromLazyList(validVariances) match { case Some(nel) => Validated.valid(nel) case None => - Validated.invalidNec(Error.Unsatisfiable(dt, cons, existing, subgraph)) + Validated.invalidNec( + Error.Unsatisfiable(dt, cons, existing, subgraph) + ) } } def allSolutions( - dt: DefinedType[Either[Shape.KnownShape, Kind.Arg]], + dt: DefinedType[Either[Shape.KnownShape, Kind.Arg]], cons: Cons, existing: LongMap[Variance], directions: LongMap[Direction], @@ -447,7 +451,7 @@ object KindFormula { .toEither def go[F[_]: Foldable]( - dt: DefinedType[Either[Shape.KnownShape, Kind.Arg]], + dt: DefinedType[Either[Shape.KnownShape, Kind.Arg]], cons: Cons, directions: LongMap[Direction], topo: F[NonEmptyList[SortedSet[Long]]] @@ -654,8 +658,9 @@ object KindFormula { ): RefSpace[KindFormula] = tpe match { case fa: rankn.Type.Quantified => - val newKindMap = kinds ++ fa.vars.toList.iterator.map { case (b, k) => - b -> BoundState.IsKind(k, fa, b) + val newKindMap = kinds ++ fa.vars.toList.iterator.map { + case (b, k) => + b -> BoundState.IsKind(k, fa, b) } kindOfType(direction, thisKind, cfn, idx, fa.in, newKindMap) diff --git a/core/src/main/scala/org/bykn/bosatsu/ListLang.scala b/core/src/main/scala/org/bykn/bosatsu/ListLang.scala index eb486a054..c19c1254b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ListLang.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ListLang.scala @@ -7,10 +7,9 @@ import cats.parse.{Parser => P} import cats.implicits._ -/** - * Represents the list construction sublanguage - * A is the expression type, B is the pattern type for bindings - */ +/** Represents the list construction sublanguage A is the expression type, B is + * the pattern type for bindings + */ sealed abstract class ListLang[F[_], A, +B] object ListLang { sealed abstract class SpliceOrItem[A] { @@ -37,10 +36,12 @@ object ListLang { .map(Splice(_)) .orElse(pa.map(Item(_))) - implicit def document[A](implicit A: Document[A]): Document[SpliceOrItem[A]] = + implicit def document[A](implicit + A: Document[A] + ): Document[SpliceOrItem[A]] = Document.instance[SpliceOrItem[A]] { case Splice(a) => Doc.char('*') + A.document(a) - case Item(a) => A.document(a) + case Item(a) => A.document(a) } } @@ -58,38 +59,58 @@ object ListLang { .map { case (k, v) => KVPair(k, v) } implicit def document[A](implicit A: Document[A]): Document[KVPair[A]] = - Document.instance[KVPair[A]] { - case KVPair(k, v) => A.document(k) + sep + A.document(v) + Document.instance[KVPair[A]] { case KVPair(k, v) => + A.document(k) + sep + A.document(v) } } case class Cons[F[_], A](items: List[F[A]]) extends ListLang[F, A, Nothing] - case class Comprehension[F[_], A, B](expr: F[A], binding: B, in: A, filter: Option[A]) extends ListLang[F, A, B] - - def parser[A, B](pa: P[A], psrc: P[A], pbind: P[B]): P[ListLang[SpliceOrItem, A, B]] = + case class Comprehension[F[_], A, B]( + expr: F[A], + binding: B, + in: A, + filter: Option[A] + ) extends ListLang[F, A, B] + + def parser[A, B]( + pa: P[A], + psrc: P[A], + pbind: P[B] + ): P[ListLang[SpliceOrItem, A, B]] = genParser(P.char('['), SpliceOrItem.parser(pa), psrc, pbind, P.char(']')) - def dictParser[A, B](pa: P[A], psrc: P[A], pbind: P[B]): P[ListLang[KVPair, A, B]] = + def dictParser[A, B]( + pa: P[A], + psrc: P[A], + pbind: P[B] + ): P[ListLang[KVPair, A, B]] = genParser(P.char('{'), KVPair.parser(pa), psrc, pbind, P.char('}')) - def genParser[F[_], A, B](left: P[Unit], fa: P[F[A]], pa: P[A], pbind: P[B], right: P[Unit]): P[ListLang[F, A, B]] = { + def genParser[F[_], A, B]( + left: P[Unit], + fa: P[F[A]], + pa: P[A], + pbind: P[B], + right: P[Unit] + ): P[ListLang[F, A, B]] = { // construct the tail of a list, so we will finally have at least one item - val consTail = fa.nonEmptyListOfWs(maybeSpacesAndLines).? - .map { tail => - val listTail = tail match { - case None => Nil - case Some(ne) => ne.toList - } - - { (a: F[A]) => Cons(a :: listTail) } + val consTail = fa.nonEmptyListOfWs(maybeSpacesAndLines).?.map { tail => + val listTail = tail match { + case None => Nil + case Some(ne) => ne.toList } + { (a: F[A]) => Cons(a :: listTail) } + } + val filterExpr = P.string("if") *> spacesAndLines *> pa val comp = - (P.string("for") *> spacesAndLines *> pbind <* maybeSpacesAndLines, - P.string("in") *> spacesAndLines *> pa <* maybeSpacesAndLines, - filterExpr.?) + ( + P.string("for") *> spacesAndLines *> pbind <* maybeSpacesAndLines, + P.string("in") *> spacesAndLines *> pa <* maybeSpacesAndLines, + filterExpr.? + ) .mapN { (b, i, f) => { (e: F[A]) => Comprehension(e, b, i, f) } } @@ -99,21 +120,24 @@ object ListLang { (left *> maybeSpacesAndLines *> (fa ~ inner.?).? <* maybeSpacesAndLines <* right) .map { - case None => Cons(Nil) - case Some((a, None)) => Cons(a :: Nil) + case None => Cons(Nil) + case Some((a, None)) => Cons(a :: Nil) case Some((a, Some(rest))) => rest(a) } } - def genDocument[F[_], A, B](left: Doc, right: Doc)(implicit F: Document[F[A]], A: Document[A], B: Document[B]): Document[ListLang[F, A, B]] = + def genDocument[F[_], A, B](left: Doc, right: Doc)(implicit + F: Document[F[A]], + A: Document[A], + B: Document[B] + ): Document[ListLang[F, A, B]] = Document.instance[ListLang[F, A, B]] { case Cons(items) => - left + Doc.intercalate(Doc.text(", "), - items.map(F.document(_))) + + left + Doc.intercalate(Doc.text(", "), items.map(F.document(_))) + right case Comprehension(e, b, i, f) => val filt = f match { - case None => Doc.empty + case None => Doc.empty case Some(e) => Doc.text(" if ") + A.document(e) } left + F.document(e) + Doc.text(" for ") + @@ -122,10 +146,15 @@ object ListLang { right } - implicit def document[A, B](implicit A: Document[A], B: Document[B]): Document[ListLang[SpliceOrItem, A, B]] = + implicit def document[A, B](implicit + A: Document[A], + B: Document[B] + ): Document[ListLang[SpliceOrItem, A, B]] = genDocument[SpliceOrItem, A, B](Doc.char('['), Doc.char(']')) - implicit def documentDict[A, B](implicit A: Document[A], B: Document[B]): Document[ListLang[KVPair, A, B]] = + implicit def documentDict[A, B](implicit + A: Document[A], + B: Document[B] + ): Document[ListLang[KVPair, A, B]] = genDocument[KVPair, A, B](Doc.char('{'), Doc.char('}')) } - diff --git a/core/src/main/scala/org/bykn/bosatsu/ListUtil.scala b/core/src/main/scala/org/bykn/bosatsu/ListUtil.scala index a567e93e5..950204004 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ListUtil.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ListUtil.scala @@ -14,7 +14,9 @@ private[bosatsu] object ListUtil { else (h :: t1) // we only allocate here } - def greedyGroup[A, G](list: NonEmptyList[A])(one: A => G)(combine: (G, A) => Option[G]): NonEmptyList[G] = { + def greedyGroup[A, G]( + list: NonEmptyList[A] + )(one: A => G)(combine: (G, A) => Option[G]): NonEmptyList[G] = { def loop(g: G, tail: List[A]): NonEmptyList[G] = tail match { case Nil => NonEmptyList.one(g) @@ -32,16 +34,20 @@ private[bosatsu] object ListUtil { loop(one(list.head), list.tail) } - def greedyGroup[A, G](list: List[A])(one: A => G)(combine: (G, A) => Option[G]): List[G] = + def greedyGroup[A, G]( + list: List[A] + )(one: A => G)(combine: (G, A) => Option[G]): List[G] = NonEmptyList.fromList(list) match { - case None => Nil + case None => Nil case Some(nel) => greedyGroup(nel)(one)(combine).toList } - def mapConserveNel[A <: AnyRef, B >: A <: AnyRef](nel: NonEmptyList[A])(f: A => B): NonEmptyList[B] = { + def mapConserveNel[A <: AnyRef, B >: A <: AnyRef]( + nel: NonEmptyList[A] + )(f: A => B): NonEmptyList[B] = { val as = nel.toList val bs = as.mapConserve(f) if (bs eq as) nel else NonEmptyList.fromListUnsafe(bs) } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/bykn/bosatsu/Lit.scala b/core/src/main/scala/org/bykn/bosatsu/Lit.scala index 22b15a4ba..67eccaf01 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Lit.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Lit.scala @@ -28,24 +28,21 @@ object Lit { private val cache: Array[Integer] = (INT_MIN_CACHE to INT_MAX_CACHE).map { i => new Integer(BigInteger.valueOf(i.toLong)) - } - .toArray - + }.toArray + def apply(bi: BigInteger): Integer = { val i = bi.intValue if ((INT_MIN_CACHE <= i) && (i <= INT_MAX_CACHE)) { val int = cache(i - INT_MIN_CACHE) if (bi == int.toBigInteger) int else new Integer(bi) - } - else new Integer(bi) + } else new Integer(bi) } def apply(l: Long): Integer = { if ((INT_MIN_CACHE <= l) && (l <= INT_MAX_CACHE)) { cache(l.toInt - INT_MIN_CACHE) - } - else new Integer(BigInteger.valueOf(l)) + } else new Integer(BigInteger.valueOf(l)) } } @@ -62,8 +59,9 @@ object Lit { private[this] val cache: Array[Chr] = (0 until 256).map(build).toArray - /** - * @throws IllegalArgumentException on a bad codepoint + + /** @throws IllegalArgumentException + * on a bad codepoint */ def fromCodePoint(cp: Int): Chr = if ((0 <= cp) && (cp < 256)) cache(cp) @@ -76,7 +74,9 @@ object Lit { def fromChar(c: Char): Lit = if (0xd800 <= c && c < 0xe000) - throw new IllegalArgumentException(s"utf-16 character int=${c.toInt} is not a valid single codepoint") + throw new IllegalArgumentException( + s"utf-16 character int=${c.toInt} is not a valid single codepoint" + ) else Chr.fromCodePoint(c.toInt) def fromCodePoint(cp: Int): Lit = Chr.fromCodePoint(cp) @@ -99,20 +99,21 @@ object Lit { val codePointParser: P[Chr] = { (StringUtil.codepoint(P.string(".\""), P.char('"')) | - StringUtil.codepoint(P.string(".'"), P.char('\''))).map(Chr.fromCodePoint(_)) + StringUtil.codepoint(P.string(".'"), P.char('\''))) + .map(Chr.fromCodePoint(_)) } implicit val litOrdering: Ordering[Lit] = new Ordering[Lit] { def compare(a: Lit, b: Lit): Int = (a, b) match { - case (Integer(a), Integer(b)) => a.compareTo(b) + case (Integer(a), Integer(b)) => a.compareTo(b) case (Integer(_), Str(_) | Chr(_)) => -1 - case (Chr(_), Integer(_)) => 1 - case (Chr(a), Chr(b)) => a.compareTo(b) - case (Chr(_), Str(_)) => -1 - case (Str(_), Integer(_)| Chr(_)) => 1 - case (Str(a), Str(b)) => a.compareTo(b) + case (Chr(_), Integer(_)) => 1 + case (Chr(a), Chr(b)) => a.compareTo(b) + case (Chr(_), Str(_)) => -1 + case (Str(_), Integer(_) | Chr(_)) => 1 + case (Str(a), Str(b)) => a.compareTo(b) } } @@ -125,7 +126,7 @@ object Lit { case Str(str) => val q = if (str.contains('\'') && !str.contains('"')) '"' else '\'' Doc.char(q) + Doc.text(escape(q, str)) + Doc.char(q) - case c @ Chr(_) => + case c @ Chr(_) => val str = c.asStr val (start, end) = if (str.contains('\'') && !str.contains('"')) (".\"", '"') @@ -133,4 +134,3 @@ object Lit { Doc.text(start) + Doc.text(escape(end, str)) + Doc.char(end) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/LocationMap.scala b/core/src/main/scala/org/bykn/bosatsu/LocationMap.scala index baffe6e09..9a5caa101 100644 --- a/core/src/main/scala/org/bykn/bosatsu/LocationMap.scala +++ b/core/src/main/scala/org/bykn/bosatsu/LocationMap.scala @@ -7,15 +7,13 @@ import cats.implicits._ import LocationMap.Colorize -/** - * Build a cache of the rows and columns in a given - * string. This is for showing error messages to users - */ +/** Build a cache of the rows and columns in a given string. This is for showing + * error messages to users + */ case class LocationMap(fromString: String) extends CPLocationMap(fromString) { private def lineRange(start: Int, end: Int): List[(Int, String)] = - (start to end) - .iterator + (start to end).iterator .filter(_ >= 0) .map { r => val liner = getLine(r).get // should never throw @@ -24,10 +22,9 @@ case class LocationMap(fromString: String) extends CPLocationMap(fromString) { } .toList - /** - * convert tab to tab, but otherwise space - * return the white space before this column - */ + /** convert tab to tab, but otherwise space return the white space before this + * column + */ private def spaceOf(row: Int, col: Int): Option[String] = getLine(row) .map { line => @@ -42,7 +39,11 @@ case class LocationMap(fromString: String) extends CPLocationMap(fromString) { bldr.toString() } - def showContext(offset: Int, previousLines: Int, color: Colorize): Option[Doc] = + def showContext( + offset: Int, + previousLines: Int, + color: Colorize + ): Option[Doc] = toLineCol(offset) .map { case (r, c) => val lines = lineRange(r - previousLines, r) @@ -60,10 +61,17 @@ case class LocationMap(fromString: String) extends CPLocationMap(fromString) { val ctx = Doc.intercalate(Doc.hardLine, lineDocs) // convert to spaces val colPad = spaceOf(r, c).get - ctx + Doc.hardLine + pointerPad + LocationMap.pointerTo(colPad, color) + Doc.hardLine + ctx + Doc.hardLine + pointerPad + LocationMap.pointerTo( + colPad, + color + ) + Doc.hardLine } - def showRegion(region: Region, previousLines: Int, color: Colorize): Option[Doc] = + def showRegion( + region: Region, + previousLines: Int, + color: Colorize + ): Option[Doc] = (toLineCol(region.start), toLineCol(region.end - 1)) .mapN { case ((l0, c0), (l1, c1)) => val lines = lineRange(l0 - previousLines, l1) @@ -78,14 +86,19 @@ case class LocationMap(fromString: String) extends CPLocationMap(fromString) { // same line // here is how much extra we need for the pointer val pointerPad = Doc.spaces(toLineStr(l0).render(0).length) - val lineDocs = lines.map { case (no, l) => toLineStr(no) + Doc.text(l) } + val lineDocs = lines.map { case (no, l) => + toLineStr(no) + Doc.text(l) + } val ctx = Doc.intercalate(Doc.hardLine, lineDocs) val c0Pad = spaceOf(l0, c0).get // we go one more to cover the column val c1Pad = spaceOf(l0, c1 + 1).get - ctx + Doc.hardLine + pointerPad + LocationMap.pointerRange(c0Pad, c1Pad, color) + Doc.hardLine - } - else { + ctx + Doc.hardLine + pointerPad + LocationMap.pointerRange( + c0Pad, + c1Pad, + color + ) + Doc.hardLine + } else { // we span multiple lines, show the start and the end: val newPrev = l1 - l0 showContext(region.start, previousLines, color).get + @@ -110,26 +123,32 @@ object LocationMap { object Console extends Colorize { def red(d: Doc) = - Doc.zeroWidth(scala.Console.RED) + d.unzero + Doc.zeroWidth(scala.Console.RESET) + Doc.zeroWidth(scala.Console.RED) + d.unzero + Doc.zeroWidth( + scala.Console.RESET + ) def green(d: Doc) = - Doc.zeroWidth(scala.Console.GREEN) + d.unzero + Doc.zeroWidth(scala.Console.RESET) + Doc.zeroWidth(scala.Console.GREEN) + d.unzero + Doc.zeroWidth( + scala.Console.RESET + ) } object HmtlFont extends Colorize { def red(d: Doc) = - Doc.zeroWidth("") + d.unzero + Doc.zeroWidth("") + Doc.zeroWidth("") + d.unzero + Doc.zeroWidth( + "" + ) def green(d: Doc) = - Doc.zeroWidth("") + d.unzero + Doc.zeroWidth("") + Doc.zeroWidth("") + d.unzero + Doc.zeroWidth( + "" + ) } } - /** - * Provide a string that points with a carat to a given column - * with 0 based indexing: - * e.g. pointerTo(2) == " ^" - */ + /** Provide a string that points with a carat to a given column with 0 based + * indexing: e.g. pointerTo(2) == " ^" + */ def pointerTo(colStr: String, color: Colorize): Doc = { val col = Doc.text(colStr) val pointer = Doc.char('^') @@ -141,7 +160,7 @@ object LocationMap { // just use tab for any tabs val pointerStr = endPad.drop(startPad.length).map { case '\t' => '\t' - case _ => '^' + case _ => '^' } val pointer = Doc.text(pointerStr) col + color.red(pointer) diff --git a/core/src/main/scala/org/bykn/bosatsu/MainModule.scala b/core/src/main/scala/org/bykn/bosatsu/MainModule.scala index bda023629..13684802c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MainModule.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MainModule.scala @@ -98,9 +98,10 @@ abstract class MainModule[IO[_]](implicit extends Output case class ShowOutput( - packages: List[Package.Typed[Any]], - ifaces: List[Package.Interface], - output: Option[Path]) extends Output + packages: List[Package.Typed[Any]], + ifaces: List[Package.Interface], + output: Option[Path] + ) extends Output } sealed abstract class MainException extends Exception { @@ -724,11 +725,20 @@ abstract class MainModule[IO[_]](implicit } yield packPath } - class Show(srcs: PathGen, ifaces: PathGen, includes: PathGen, packageResolver: PackageResolver) extends Inputs { + class Show( + srcs: PathGen, + ifaces: PathGen, + includes: PathGen, + packageResolver: PackageResolver + ) extends Inputs { def loadAndCompile(cmd: MainCommand, errColor: Colorize)(implicit ec: Par.EC ): IO[(List[Package.Interface], List[Package.Typed[Any]])] = - (srcs.read, ifaces.read.flatMap(readInterfaces), includes.read.flatMap(readPackages)) + ( + srcs.read, + ifaces.read.flatMap(readInterfaces), + includes.read.flatMap(readPackages) + ) .flatMapN { case (Nil, ifaces, packs) => moduleIOMonad.pure((ifaces, packs)) @@ -742,8 +752,9 @@ abstract class MainModule[IO[_]](implicit errColor, packageResolver ) - allPacks = (PackageMap.fromIterable(packs) ++ packPath._1.toMap.map(_._2)) - .toMap.toList.map(_._2) + allPacks = (PackageMap.fromIterable( + packs + ) ++ packPath._1.toMap.map(_._2)).toMap.toList.map(_._2) } yield (ifaces, allPacks) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index 362985481..adf8b5a9e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -56,7 +56,7 @@ object Matchless { def from(sp: StrPart): MatchSize = sp match { - case _: Glob => atLeast0 + case _: Glob => atLeast0 case _: CharPart => exactly1 case LitStr(str) => Exactly(str.codePointCount(0, str.length)) @@ -76,13 +76,23 @@ object Matchless { } // name is set for recursive (but not tail recursive) methods - case class Lambda(captures: List[Expr], name: Option[Bindable], args: NonEmptyList[Bindable], expr: Expr) extends FnExpr + case class Lambda( + captures: List[Expr], + name: Option[Bindable], + args: NonEmptyList[Bindable], + expr: Expr + ) extends FnExpr // this is a tail recursive function that should be compiled into a loop // when a call to name is done inside body, that should restart the loop // the type of this Expr a function with the arity of args that returns // the type of body - case class LoopFn(captures: List[Expr], name: Bindable, arg: NonEmptyList[Bindable], body: Expr) extends FnExpr + case class LoopFn( + captures: List[Expr], + name: Bindable, + arg: NonEmptyList[Bindable], + body: Expr + ) extends FnExpr case class Global(pack: PackageName, name: Bindable) extends CheapExpr @@ -97,7 +107,11 @@ object Matchless { // we aggregate all the applications to potentially make dispatch more efficient // note fn is never an App case class App(fn: Expr, arg: NonEmptyList[Expr]) extends Expr - case class Let(arg: Either[LocalAnon, (Bindable, RecursionKind)], expr: Expr, in: Expr) extends Expr + case class Let( + arg: Either[LocalAnon, (Bindable, RecursionKind)], + expr: Expr, + in: Expr + ) extends Expr case class LetMut(name: LocalAnonMut, span: Expr) extends Expr case class Literal(lit: Lit) extends CheapExpr @@ -109,7 +123,7 @@ object Matchless { (this, that) match { case (TrueConst, r) => r case (l, TrueConst) => l - case _ => And(this, that) + case _ => And(this, that) } } // returns 1 if it does, else 0 @@ -119,16 +133,30 @@ object Matchless { case class And(e1: BoolExpr, e2: BoolExpr) extends BoolExpr // checks if variant matches, and if so, writes to // a given mut - case class CheckVariant(expr: CheapExpr, expect: Int, size: Int, famArities: List[Int]) extends BoolExpr + case class CheckVariant( + expr: CheapExpr, + expect: Int, + size: Int, + famArities: List[Int] + ) extends BoolExpr // handle list matching, this is a while loop, that is evaluting // lst is initialized to init, leftAcc is initialized to empty // tail until it is true while mutating lst => lst.tail // this has the side-effect of mutating lst and leftAcc as well as any side effects that check has // which could have nested searches of its own - case class SearchList(lst: LocalAnonMut, init: CheapExpr, check: BoolExpr, leftAcc: Option[LocalAnonMut]) extends BoolExpr + case class SearchList( + lst: LocalAnonMut, + init: CheapExpr, + check: BoolExpr, + leftAcc: Option[LocalAnonMut] + ) extends BoolExpr // set the mutable variable to the given expr and return true // string matching is complex done at a lower level - case class MatchString(arg: CheapExpr, parts: List[StrPart], binds: List[LocalAnonMut]) extends BoolExpr + case class MatchString( + arg: CheapExpr, + parts: List[StrPart], + binds: List[LocalAnonMut] + ) extends BoolExpr // set the mutable variable to the given expr and return true case class SetMut(target: LocalAnonMut, expr: Expr) extends BoolExpr case object TrueConst extends BoolExpr @@ -136,9 +164,11 @@ object Matchless { def hasSideEffect(bx: BoolExpr): Boolean = bx match { case SetMut(_, _) => true - case TrueConst | CheckVariant(_, _, _, _) | EqualsLit(_, _) | EqualsNat(_, _) => false + case TrueConst | CheckVariant(_, _, _, _) | EqualsLit(_, _) | + EqualsNat(_, _) => + false case MatchString(_, _, b) => b.nonEmpty - case And(b1, b2) => hasSideEffect(b1) || hasSideEffect(b2) + case And(b1, b2) => hasSideEffect(b1) || hasSideEffect(b2) case SearchList(_, _, b, l) => l.nonEmpty || hasSideEffect(b) } @@ -149,18 +179,20 @@ object Matchless { if (hasSideEffect(cond)) Always(cond, thenExpr) else thenExpr - /** - * These aren't really super cheap, but when we treat them cheap we check that we will only - * call them one time - */ - case class GetEnumElement(arg: CheapExpr, variant: Int, index: Int, size: Int) extends CheapExpr - case class GetStructElement(arg: CheapExpr, index: Int, size: Int) extends CheapExpr + /** These aren't really super cheap, but when we treat them cheap we check + * that we will only call them one time + */ + case class GetEnumElement(arg: CheapExpr, variant: Int, index: Int, size: Int) + extends CheapExpr + case class GetStructElement(arg: CheapExpr, index: Int, size: Int) + extends CheapExpr sealed abstract class ConsExpr extends Expr { def arity: Int } // we need to compile calls to constructors into these - case class MakeEnum(variant: Int, arity: Int, famArities: List[Int]) extends ConsExpr + case class MakeEnum(variant: Int, arity: Int, famArities: List[Int]) + extends ConsExpr case class MakeStruct(arity: Int) extends ConsExpr case object ZeroNat extends ConsExpr { def arity = 0 @@ -175,25 +207,27 @@ object Matchless { private def asCheap(expr: Expr): Option[CheapExpr] = expr match { case c: CheapExpr => Some(c) - case _ => None + case _ => None } - private def maybeMemo[F[_]: Monad](tmp: F[Long])(fn: CheapExpr => F[Expr]): Expr => F[Expr] = - { (arg: Expr) => - asCheap(arg) match { - case Some(c) => fn(c) - case None => - for { - nm <- tmp - bound = LocalAnon(nm) - res <- fn(bound) - } yield Let(Left(bound), arg, res) - } + private def maybeMemo[F[_]: Monad]( + tmp: F[Long] + )(fn: CheapExpr => F[Expr]): Expr => F[Expr] = { (arg: Expr) => + asCheap(arg) match { + case Some(c) => fn(c) + case None => + for { + nm <- tmp + bound = LocalAnon(nm) + res <- fn(bound) + } yield Let(Left(bound), arg, res) } + } private[this] val empty = (PackageName.PredefName, Constructor("EmptyList")) private[this] val cons = (PackageName.PredefName, Constructor("NonEmptyList")) - private[this] val reverseFn = Global(PackageName.PredefName, Identifier.Name("reverse")) + private[this] val reverseFn = + Global(PackageName.PredefName, Identifier.Name("reverse")) // drop all items in the tail after the first time fn returns true // as a result, we have 0 or 1 items where fn is true in the result @@ -201,45 +235,47 @@ object Matchless { def stopAt[A](nel: NonEmptyList[A])(fn: A => Boolean): NonEmptyList[A] = nel match { case NonEmptyList(h, _) if fn(h) => NonEmptyList(h, Nil) - case s@NonEmptyList(_, Nil) => s - case NonEmptyList(h0, h1 :: t) => h0 :: stopAt(NonEmptyList(h1, t))(fn) + case s @ NonEmptyList(_, Nil) => s + case NonEmptyList(h0, h1 :: t) => h0 :: stopAt(NonEmptyList(h1, t))(fn) } // same as fromLet below, but uses RefSpace - def fromLet[A]( - name: Bindable, - rec: RecursionKind, - te: TypedExpr[A])( - variantOf: (PackageName, Constructor) => Option[DataRepr]): Expr = - (for { - c <- RefSpace.allocCounter - expr <- fromLet(name, rec, te, variantOf, c) - } yield expr).run.value + def fromLet[A](name: Bindable, rec: RecursionKind, te: TypedExpr[A])( + variantOf: (PackageName, Constructor) => Option[DataRepr] + ): Expr = + (for { + c <- RefSpace.allocCounter + expr <- fromLet(name, rec, te, variantOf, c) + } yield expr).run.value // we need a TypeEnv to inline the creation of structs and variants def fromLet[F[_]: Monad, A]( - name: Bindable, - rec: RecursionKind, - te: TypedExpr[A], - variantOf: (PackageName, Constructor) => Option[DataRepr], - makeAnon: F[Long]): F[Expr] = { + name: Bindable, + rec: RecursionKind, + te: TypedExpr[A], + variantOf: (PackageName, Constructor) => Option[DataRepr], + makeAnon: F[Long] + ): F[Expr] = { - type UnionMatch = NonEmptyList[(List[LocalAnonMut], BoolExpr, List[(Bindable, Expr)])] - val wildMatch: UnionMatch = NonEmptyList((Nil, TrueConst, Nil), Nil) + type UnionMatch = + NonEmptyList[(List[LocalAnonMut], BoolExpr, List[(Bindable, Expr)])] + val wildMatch: UnionMatch = NonEmptyList((Nil, TrueConst, Nil), Nil) val emptyExpr: Expr = empty match { case (p, c) => variantOf(p, c) match { case Some(DataRepr.Enum(v, s, f)) => MakeEnum(v, s, f) - case other => + case other => /* We assume the structure of Lists to be standard linked lists * Empty cannot be a struct */ // $COVERAGE-OFF$ - throw new IllegalStateException(s"empty List should be an enum, found: $other") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"empty List should be an enum, found: $other" + ) + // $COVERAGE-ON$ } } @@ -249,22 +285,19 @@ object Matchless { def apply(b: Bindable): Expr = slots.get(b) match { case Some(expr) => expr - case None => Local(b) + case None => Local(b) } def lambdaFrees(frees: List[Bindable]): (LambdaState, List[Expr]) = { name match { case None => - val newSlots = frees - .iterator - .zipWithIndex - .map { case (b, idx) => (b, ClosureSlot(idx)) } - .toMap + val newSlots = frees.iterator.zipWithIndex.map { case (b, idx) => + (b, ClosureSlot(idx)) + }.toMap val captures = frees.map(apply(_)) (copy(slots = newSlots), captures) case Some(n) => - val newSlots = frees - .iterator + val newSlots = frees.iterator .filterNot(_ === n) .zipWithIndex .map { case (b, idx) => (b, ClosureSlot(idx)) } @@ -280,11 +313,15 @@ object Matchless { def inLet(b: Bindable): LambdaState = copy(name = Some(b)) } - def loopLetVal(name: Bindable, e: TypedExpr[A], rec: RecursionKind, slots: LambdaState): F[Expr] = { + def loopLetVal( + name: Bindable, + e: TypedExpr[A], + rec: RecursionKind, + slots: LambdaState + ): F[Expr] = { lazy val e0 = loop(e, if (rec.isRecursive) slots.inLet(name) else slots) rec match { case RecursionKind.Recursive => - def letrec(e: Expr): Expr = Let(Right((name, RecursionKind.Recursive)), e, Local(name)) @@ -309,8 +346,7 @@ object Matchless { // but it definitely does in fuzz tests e0.map(letrec) } - } - else { + } else { // otherwise let rec x = fn in x e0.map(letrec) } @@ -320,39 +356,49 @@ object Matchless { def loop(te: TypedExpr[A], slots: LambdaState): F[Expr] = te match { - case TypedExpr.Generic(_, expr) => loop(expr, slots) + case TypedExpr.Generic(_, expr) => loop(expr, slots) case TypedExpr.Annotation(term, _) => loop(term, slots) case TypedExpr.AnnotatedLambda(args, res, _) => val frees = TypedExpr.freeVars(te :: Nil) val (slots1, captures) = slots.lambdaFrees(frees) - loop(res, slots1.unname).map(Lambda(captures, slots.name, args.map(_._1), _)) - case TypedExpr.Global(pack, cons@Constructor(_), _, _) => + loop(res, slots1.unname).map( + Lambda(captures, slots.name, args.map(_._1), _) + ) + case TypedExpr.Global(pack, cons @ Constructor(_), _, _) => Monad[F].pure(variantOf(pack, cons) match { case Some(dr) => dr match { case DataRepr.Enum(v, a, f) => MakeEnum(v, a, f) - case DataRepr.Struct(a) => MakeStruct(a) - case DataRepr.NewType => MakeStruct(1) - case DataRepr.ZeroNat => ZeroNat - case DataRepr.SuccNat => SuccNat + case DataRepr.Struct(a) => MakeStruct(a) + case DataRepr.NewType => MakeStruct(1) + case DataRepr.ZeroNat => ZeroNat + case DataRepr.SuccNat => SuccNat } - // $COVERAGE-OFF$ + // $COVERAGE-OFF$ case None => - throw new IllegalStateException(s"could not find $cons in global data types") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"could not find $cons in global data types" + ) + // $COVERAGE-ON$ }) case TypedExpr.Global(pack, notCons: Bindable, _, _) => Monad[F].pure(Global(pack, notCons)) case TypedExpr.Local(bind, _, _) => Monad[F].pure(slots(bind)) case TypedExpr.App(fn, as, _, _) => - (loop(fn, slots.unname), as.traverse(loop(_, slots.unname))).mapN(App(_, _)) + (loop(fn, slots.unname), as.traverse(loop(_, slots.unname))) + .mapN(App(_, _)) case TypedExpr.Let(a, e, in, r, _) => - (loopLetVal(a, e, r, slots.unname), loop(in, slots)).mapN(Let(Right((a, r)), _, _)) + (loopLetVal(a, e, r, slots.unname), loop(in, slots)) + .mapN(Let(Right((a, r)), _, _)) case TypedExpr.Literal(lit, _, _) => Monad[F].pure(Literal(lit)) case TypedExpr.Match(arg, branches, _) => - (loop(arg, slots.unname), branches.traverse { case (p, te) => loop(te, slots.unname).map((p, _)) }) - .tupled + ( + loop(arg, slots.unname), + branches.traverse { case (p, te) => + loop(te, slots.unname).map((p, _)) + } + ).tupled .flatMap { case (a, b) => matchExpr(a, makeAnon, b) } } @@ -362,9 +408,11 @@ object Matchless { * 2. a total binding to a given name * 3. or we return None indicating not one of these */ - def maybeSimple(p: Pattern[(PackageName, Constructor), Type]): Option[Either[Bindable, Unit]] = + def maybeSimple( + p: Pattern[(PackageName, Constructor), Type] + ): Option[Either[Bindable, Unit]] = p match { - case Pattern.WildCard => Some(Right(())) + case Pattern.WildCard => Some(Right(())) case Pattern.Literal(_) => // Literals are never total None @@ -372,21 +420,21 @@ object Matchless { case Pattern.Named(v, p) => maybeSimple(p) match { case Some(Right(_)) => Some(Left(v)) - case _ => None + case _ => None } case Pattern.StrPat(s) => s match { case NonEmptyList(Pattern.StrPart.WildStr, Nil) => Some(Right(())) case NonEmptyList(Pattern.StrPart.NamedStr(n), Nil) => Some(Left(n)) - case _ => None + case _ => None } case Pattern.ListPat(l) => l match { - case Pattern.ListPart.WildList :: Nil => Some(Right(())) + case Pattern.ListPart.WildList :: Nil => Some(Right(())) case Pattern.ListPart.NamedList(n) :: Nil => Some(Left(n)) - case _ => None + case _ => None } - case Pattern.Annotation(p, _) => maybeSimple(p) + case Pattern.Annotation(p, _) => maybeSimple(p) case Pattern.PositionalStruct((pack, cname), ps) => // Only branch-free structs with no inner names are simple variantOf(pack, cname) match { @@ -399,10 +447,12 @@ object Matchless { } case _ => None } - // $COVERAGE-OFF$ + // $COVERAGE-OFF$ case None => - throw new IllegalStateException(s"could not find $cons in global data types") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"could not find $cons in global data types" + ) + // $COVERAGE-ON$ } case Pattern.Union(h, t) => (h :: t.toList).traverse(maybeSimple).flatMap { inners => @@ -413,7 +463,11 @@ object Matchless { // return the check expression for the check we need to do, and the list of bindings // if must match is true, we know that the pattern must match, so we can potentially remove some checks - def doesMatch(arg: CheapExpr, pat: Pattern[(PackageName, Constructor), Type], mustMatch: Boolean): F[UnionMatch] = { + def doesMatch( + arg: CheapExpr, + pat: Pattern[(PackageName, Constructor), Type], + mustMatch: Boolean + ): F[UnionMatch] = { pat match { case Pattern.WildCard => // this is a total pattern @@ -428,40 +482,37 @@ object Matchless { }) case Pattern.StrPat(items) => val sbinds: List[Bindable] = - items - .toList + items.toList .collect { // that each name is distinct // should be checked in the SourceConverter/TotalityChecking code - case Pattern.StrPart.NamedStr(n) => n + case Pattern.StrPart.NamedStr(n) => n case Pattern.StrPart.NamedChar(n) => n } - val muts = sbinds.traverse { b => makeAnon.map(LocalAnonMut(_)).map((b, _)) } + val muts = sbinds.traverse { b => + makeAnon.map(LocalAnonMut(_)).map((b, _)) + } val pat = items.toList.map { - case Pattern.StrPart.NamedStr(_) => StrPart.IndexStr - case Pattern.StrPart.NamedChar(_) => StrPart.IndexChar - case Pattern.StrPart.WildStr => StrPart.WildStr - case Pattern.StrPart.WildChar => StrPart.WildChar - case Pattern.StrPart.LitStr(s) => StrPart.LitStr(s) - } + case Pattern.StrPart.NamedStr(_) => StrPart.IndexStr + case Pattern.StrPart.NamedChar(_) => StrPart.IndexChar + case Pattern.StrPart.WildStr => StrPart.WildStr + case Pattern.StrPart.WildChar => StrPart.WildChar + case Pattern.StrPart.LitStr(s) => StrPart.LitStr(s) + } muts.map { binds => val ms = binds.map(_._2) - NonEmptyList.of((ms, - MatchString( - arg, - pat, - ms), - binds)) + NonEmptyList.of((ms, MatchString(arg, pat, ms), binds)) } - case lp@Pattern.ListPat(_) => - + case lp @ Pattern.ListPat(_) => lp.toPositionalStruct(empty, cons) match { case Right(p) => doesMatch(arg, p, mustMatch) - case Left((glob, right@NonEmptyList(Pattern.ListPart.Item(_), _))) => + case Left( + (glob, right @ NonEmptyList(Pattern.ListPart.Item(_), _)) + ) => // we have a non-trailing list pattern // to match, this becomes a search problem // we loop over all the matches of p in the list, @@ -479,8 +530,7 @@ object Matchless { makeAnon.map { nm => Some((LocalAnonMut(nm), ln)) } } - (leftF, makeAnon) - .tupled + (leftF, makeAnon).tupled .flatMap { case (optAnonLeft, tmpList) => val anonList = LocalAnonMut(tmpList) @@ -493,26 +543,38 @@ object Matchless { // this shouldn't be possible, since there are no total list matches with // one item since we recurse on a ListPat with the first item being Right // which as we can see above always returns Some(_) - throw new IllegalStateException(s"$right should not be a total match") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"$right should not be a total match" + ) + // $COVERAGE-ON$ case (preLet, expr, binds) => - val letTail = anonList :: preLet val (resLet, leftOpt, resBind) = optAnonLeft match { case Some((anonLeft, ln)) => - val revList = App(reverseFn, NonEmptyList.one(anonLeft)) - (anonLeft :: letTail, Some(anonLeft), (ln, revList) :: binds) + val revList = + App(reverseFn, NonEmptyList.one(anonLeft)) + ( + anonLeft :: letTail, + Some(anonLeft), + (ln, revList) :: binds + ) case None => (letTail, None, binds) } - (resLet, SearchList(anonList, arg, expr, leftOpt), resBind) + ( + resLet, + SearchList(anonList, arg, expr, leftOpt), + resBind + ) } } } - case Left((glob, right@NonEmptyList(_: Pattern.ListPart.Glob, _))) => + case Left( + (glob, right @ NonEmptyList(_: Pattern.ListPart.Glob, _)) + ) => // we search on the right side, so the left will match nothing // this should be banned by SourceConverter/TotalityChecker because // it is confusing, but it can be handled @@ -531,7 +593,7 @@ object Matchless { } } } - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Pattern.Annotation(p, _) => @@ -541,29 +603,40 @@ object Matchless { // we assume the patterns have already been optimized // so that useless total patterns have been replaced with _ type Locals = Chain[(LocalAnonMut, Expr)] - def asStruct(getter: Int => CheapExpr): WriterT[F, Locals, UnionMatch] = { + def asStruct( + getter: Int => CheapExpr + ): WriterT[F, Locals, UnionMatch] = { // we have an and of a series of ors: // (m1 + m2 + m3) * (m4 + m5 + m6) ... = // we need to multiply them all out into a single set of ors - def operate(pat: Pattern[(PackageName, Constructor), Type], idx: Int): WriterT[F, Locals, UnionMatch] = + def operate( + pat: Pattern[(PackageName, Constructor), Type], + idx: Int + ): WriterT[F, Locals, UnionMatch] = maybeSimple(pat) match { case Some(Right(())) => // this is a total match WriterT.value(wildMatch) case Some(Left(v)) => // this is just an alias - WriterT.value(NonEmptyList((Nil, TrueConst, (v, getter(idx)) :: Nil), Nil)) + WriterT.value( + NonEmptyList((Nil, TrueConst, (v, getter(idx)) :: Nil), Nil) + ) case None => // we make an anonymous variable and write to that: for { nm <- WriterT.valueT[F, Locals, Long](makeAnon) lam = LocalAnonMut(nm) - um <- WriterT.valueT[F, Locals, UnionMatch](doesMatch(lam, pat, mustMatch)) + um <- WriterT.valueT[F, Locals, UnionMatch]( + doesMatch(lam, pat, mustMatch) + ) // if this is a total match, we don't need to do the getter at all - chain = if (um == wildMatch) Chain.empty else Chain.one((lam, getter(idx))) + chain = + if (um == wildMatch) Chain.empty + else Chain.one((lam, getter(idx))) _ <- WriterT.tell[F, Locals](chain) } yield um - } + } val ands: WriterT[F, Locals, List[UnionMatch]] = params.zipWithIndex @@ -571,20 +644,24 @@ object Matchless { ands.map(NonEmptyList.fromList(_) match { case None => wildMatch - case Some(nel) => product(nel) { case ((l1, o1, b1), (l2, o2, b2)) => - (l1 ::: l2, o1 && o2, b1 ::: b2) - } + case Some(nel) => + product(nel) { case ((l1, o1, b1), (l2, o2, b2)) => + (l1 ::: l2, o1 && o2, b1 ::: b2) + } }) } def forStruct(size: Int) = - asStruct { pos => GetStructElement(arg, pos, size) } - .run + asStruct { pos => GetStructElement(arg, pos, size) }.run .map { case (anons, ums) => ums.map { case (pre, cond, bind) => - val pre1 = anons.foldLeft(pre) { case (pre, (a, _)) => a :: pre } + val pre1 = anons.foldLeft(pre) { case (pre, (a, _)) => + a :: pre + } // we have to set these variables before we can evaluate the condition - val cond1 = anons.foldLeft(cond) { case (c, (a, e)) => SetMut(a, e) && c } + val cond1 = anons.foldLeft(cond) { case (c, (a, e)) => + SetMut(a, e) && c + } (pre1, cond1, bind) } } @@ -592,73 +669,97 @@ object Matchless { variantOf(pack, cname) match { case Some(dr) => dr match { - case DataRepr.Struct(size) => forStruct(size) - case DataRepr.NewType => forStruct(1) + case DataRepr.Struct(size) => forStruct(size) + case DataRepr.NewType => forStruct(1) case DataRepr.Enum(vidx, size, f) => // if we match the variant, then treat it as a struct - val cv: BoolExpr = if (mustMatch) TrueConst else CheckVariant(arg, vidx, size, f) - asStruct { pos => GetEnumElement(arg, vidx, pos, size) } - .run + val cv: BoolExpr = + if (mustMatch) TrueConst + else CheckVariant(arg, vidx, size, f) + asStruct { pos => GetEnumElement(arg, vidx, pos, size) }.run .map { case (anons, ums) => if (ums == wildMatch) { // we just need to check the variant - assert(anons.isEmpty, "anons must by construction always be empty on wildMatch") + assert( + anons.isEmpty, + "anons must by construction always be empty on wildMatch" + ) NonEmptyList((Nil, cv, Nil), Nil) - } - else { + } else { // now we need to set up the binds if the variant is right - val cond1 = anons.foldLeft(cv) { case (c, (mut, expr)) => - c && SetMut(mut, expr) + val cond1 = anons.foldLeft(cv) { + case (c, (mut, expr)) => + c && SetMut(mut, expr) } ums.map { case (pre, cond, b) => - val pre1 = anons.foldLeft(pre) { case (pre, (mut, _)) => mut :: pre } + val pre1 = anons.foldLeft(pre) { + case (pre, (mut, _)) => mut :: pre + } (pre1, cond1 && cond, b) } } } case DataRepr.ZeroNat => - val cv: BoolExpr = if (mustMatch) TrueConst else EqualsNat(arg, DataRepr.ZeroNat) + val cv: BoolExpr = + if (mustMatch) TrueConst + else EqualsNat(arg, DataRepr.ZeroNat) Monad[F].pure(NonEmptyList((Nil, cv, Nil), Nil)) case DataRepr.SuccNat => params match { case single :: Nil => // if we match, we recur on the inner pattern and prev of current - val check = if (mustMatch) TrueConst else EqualsNat(arg, DataRepr.SuccNat) + val check = + if (mustMatch) TrueConst + else EqualsNat(arg, DataRepr.SuccNat) for { nm <- makeAnon loc = LocalAnonMut(nm) prev = PrevNat(arg) rest <- doesMatch(loc, single, mustMatch) - } yield rest.map { case (preLets, cond, res) => (loc ::preLets, check && SetMut(loc, prev) && cond, res) } + } yield rest.map { case (preLets, cond, res) => + ( + loc :: preLets, + check && SetMut(loc, prev) && cond, + res + ) + } case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"expected typechecked Nat to only have one param, found: $other in $pat") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"expected typechecked Nat to only have one param, found: $other in $pat" + ) + // $COVERAGE-ON$ } } case None => // $COVERAGE-OFF$ - throw new IllegalStateException(s"could not find $cons in global data types") - // $COVERAGE-ON$ - } + throw new IllegalStateException( + s"could not find $cons in global data types" + ) + // $COVERAGE-ON$ + } case Pattern.Union(h, ts) => // note this list is exactly as long as h :: ts - val unionMustMatch = NonEmptyList.fromListUnsafe(List.fill(ts.size)(false) ::: mustMatch :: Nil) - ((h :: ts).zip(unionMustMatch)).traverse { case (p, mm) => doesMatch(arg, p, mm) }.map { nene => - val nel = nene.flatten - // at the first total match, we can stop - stopAt(nel) { - case (_, TrueConst, _) => true - case _ => false + val unionMustMatch = NonEmptyList.fromListUnsafe( + List.fill(ts.size)(false) ::: mustMatch :: Nil + ) + ((h :: ts) + .zip(unionMustMatch)) + .traverse { case (p, mm) => doesMatch(arg, p, mm) } + .map { nene => + val nel = nene.flatten + // at the first total match, we can stop + stopAt(nel) { + case (_, TrueConst, _) => true + case _ => false + } } - } } } def lets(binds: List[(Bindable, Expr)], in: Expr): Expr = binds.foldRight(in) { case ((b, e), r) => - val arg = Right((b, RecursionKind.NonRecursive)) Let(arg, e, r) } @@ -668,12 +769,27 @@ object Matchless { LetMut(anon, rest) } - def matchExpr(arg: Expr, tmp: F[Long], branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], Expr)]): F[Expr] = { - - def recur(arg: CheapExpr, branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], Expr)]): F[Expr] = { + def matchExpr( + arg: Expr, + tmp: F[Long], + branches: NonEmptyList[ + (Pattern[(PackageName, Constructor), Type], Expr) + ] + ): F[Expr] = { + + def recur( + arg: CheapExpr, + branches: NonEmptyList[ + (Pattern[(PackageName, Constructor), Type], Expr) + ] + ): F[Expr] = { val (p1, r1) = branches.head - def loop(cbs: NonEmptyList[(List[LocalAnonMut], BoolExpr, List[(Bindable, Expr)])]): F[Expr] = + def loop( + cbs: NonEmptyList[ + (List[LocalAnonMut], BoolExpr, List[(Bindable, Expr)]) + ] + ): F[Expr] = cbs match { case NonEmptyList((b0, TrueConst, binds), _) => // this is a total match, no fall through @@ -716,7 +832,10 @@ object Matchless { // toy matcher to see the structure // Left means match any number of items, like *_ - def matchList[A, B: Monoid](items: List[A], pattern: List[Either[List[A] => B, A => Option[B]]]): Option[B] = + def matchList[A, B: Monoid]( + items: List[A], + pattern: List[Either[List[A] => B, A => Option[B]]] + ): Option[B] = pattern match { case Nil => if (items.isEmpty) Some(Monoid[B].empty) @@ -725,7 +844,7 @@ object Matchless { items match { case ih :: it => fn(ih) match { - case None => None + case None => None case Some(b) => matchList(it, pt).map(Monoid[B].combine(b, _)) } case Nil => None @@ -734,13 +853,13 @@ object Matchless { case Left(lstFn) :: Nil => Some(lstFn(items)) - case Left(lstFn) :: (pt@(Left(_) :: _)) => + case Left(lstFn) :: (pt @ (Left(_) :: _)) => // it is ambiguous how much to absorb // so, just assume lstFn gets nothing matchList(items, pt) .map(Monoid.combine(lstFn(Nil), _)) - case Left(lstFn) :: (pt@(Right(_) :: _))=> + case Left(lstFn) :: (pt @ (Right(_) :: _)) => var revLeft: List[A] = Nil var it = items var result: Option[B] = None @@ -756,11 +875,11 @@ object Matchless { } } result - /* - * The above should be an imperative version - * of this code. The imperative code - * is easier to translate into low level - * instructions + /* + * The above should be an imperative version + * of this code. The imperative code + * is easier to translate into low level + * instructions items .toStream .mapWithIndex { (a, idx) => afn(a).map((_, idx)) } @@ -776,13 +895,14 @@ object Matchless { } } .headOption - */ + */ } - /** - * return the expanded product of sums - */ - def product[A1](sum: NonEmptyList[NonEmptyList[A1]])(prod: (A1, A1) => A1): NonEmptyList[A1] = + /** return the expanded product of sums + */ + def product[A1]( + sum: NonEmptyList[NonEmptyList[A1]] + )(prod: (A1, A1) => A1): NonEmptyList[A1] = sum match { case NonEmptyList(h, Nil) => // this (a1 + a2 + a3) case diff --git a/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala index b449f6632..dd269e719 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala @@ -6,34 +6,35 @@ import cats.implicits._ object MatchlessFromTypedExpr { // compile a set of packages given a set of external remappings - def compile[A](pm: PackageMap.Typed[A])(implicit ec: Par.EC): Map[PackageName, List[(Bindable, Matchless.Expr)]] = { + def compile[A]( + pm: PackageMap.Typed[A] + )(implicit ec: Par.EC): Map[PackageName, List[(Bindable, Matchless.Expr)]] = { val gdr = pm.getDataRepr // on JS Par.F[A] is actually Id[A], so we need to hold hands a bit - val allItemsList = pm.toMap - .toList - .traverse[Par.F, (PackageName, List[(Bindable, Matchless.Expr)])] { case (pname, pack) => - val lets = pack.program.lets - - Par.start { - val exprs: List[(Bindable, Matchless.Expr)] = - rankn.RefSpace - .allocCounter - .flatMap { c => - lets - .traverse { - case (name, rec, te) => - Matchless.fromLet(name, rec, te, gdr, c) - .map((name, _)) - } - } - .run - .value - - (pname, exprs) - } + val allItemsList = pm.toMap.toList + .traverse[Par.F, (PackageName, List[(Bindable, Matchless.Expr)])] { + case (pname, pack) => + val lets = pack.program.lets + + Par.start { + val exprs: List[(Bindable, Matchless.Expr)] = + rankn.RefSpace.allocCounter + .flatMap { c => + lets + .traverse { case (name, rec, te) => + Matchless + .fromLet(name, rec, te, gdr, c) + .map((name, _)) + } + } + .run + .value + + (pname, exprs) + } } // JS needs this to not see through the Par.F as Id diff --git a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala index 2acdcec02..61cdae9d0 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala @@ -16,7 +16,9 @@ object MatchlessToValue { import Matchless._ // reuse some cache structures across a number of calls - def traverse[F[_]: Functor](me: F[Expr])(resolve: (PackageName, Identifier) => Eval[Value]): F[Eval[Value]] = { + def traverse[F[_]: Functor]( + me: F[Expr] + )(resolve: (PackageName, Identifier) => Eval[Value]): F[Eval[Value]] = { val env = new Impl.Env(resolve) val fns = Functor[F].map(me) { expr => env.loop(expr) @@ -42,9 +44,10 @@ object MatchlessToValue { case MakeEnum(variant, arity, _) => if (arity == 0) SumValue(variant, UnitValue) else if (arity == 1) { - FnValue { case NonEmptyList(v, _) => SumValue(variant, ProductValue.single(v)) } - } - else + FnValue { case NonEmptyList(v, _) => + SumValue(variant, ProductValue.single(v)) + } + } else // arity > 1 FnValue { args => val prod = ProductValue.fromList(args.toList) @@ -53,9 +56,10 @@ object MatchlessToValue { case MakeStruct(arity) => if (arity == 0) UnitValue else if (arity == 1) FnValue.identity - else FnValue { args => - ProductValue.fromList(args.toList) - } + else + FnValue { args => + ProductValue.fromList(args.toList) + } case ZeroNat => zeroNat case SuccNat => succNat } @@ -65,10 +69,11 @@ object MatchlessToValue { val uninit: Value = ExternalValue(Uninitialized) final case class Scope( - locals: Map[Bindable, Eval[Value]], - anon: LongMap[Value], - muts: MLongMap[Value], - slots: Vector[Value]) { + locals: Map[Bindable, Eval[Value]], + anon: LongMap[Value], + muts: MLongMap[Value], + slots: Vector[Value] + ) { def let(b: Bindable, v: Eval[Value]): Scope = copy(locals = locals.updated(b, v)) @@ -92,22 +97,26 @@ object MatchlessToValue { def capture(it: Vector[Value], name: Option[Bindable]): Scope = Scope( name match { - case None => Map.empty + case None => Map.empty case Some(n) => Map((n, locals(n))) }, LongMap.empty, MLongMap(), - it) + it + ) } object Scope { - def empty(): Scope = Scope(Map.empty, LongMap.empty, MLongMap(), Vector.empty) + def empty(): Scope = + Scope(Map.empty, LongMap.empty, MLongMap(), Vector.empty) } sealed abstract class Scoped[A] { def apply(s: Scope): A def map[B](fn: A => B): Scoped[B] - def and(that: Scoped[Boolean])(implicit ev: Is[A, Boolean]): Scoped[Boolean] = { + def and( + that: Scoped[Boolean] + )(implicit ev: Is[A, Boolean]): Scoped[Boolean] = { // boolean conditions are generally never static, so we can't easily exercise // this code if we specialize it. So, we assume it is dynamic here val thisBool = ev.substitute[Scoped](this) @@ -140,7 +149,9 @@ object MatchlessToValue { def pure[A](a: A): Scoped[A] = Static(a) override def map[A, B](aa: Scoped[A])(fn: A => B): Scoped[B] = aa.map(fn) - override def map2[A, B, C](aa: Scoped[A], ab: Scoped[B])(fn: (A, B) => C): Scoped[C] = + override def map2[A, B, C](aa: Scoped[A], ab: Scoped[B])( + fn: (A, B) => C + ): Scoped[C] = (aa, ab) match { case (Static(a), Static(b)) => Static(fn(a, b)) case (Static(a), db) => @@ -162,7 +173,6 @@ object MatchlessToValue { private def boolExpr(ix: BoolExpr): Scoped[Boolean] = ix match { case EqualsLit(expr, lit) => - val litAny = lit.unboxToAny loop(expr).map { e => @@ -205,7 +215,9 @@ object MatchlessToValue { matchString(arg, pat, 0) != null } case _ => - val bary = binds.iterator.collect { case LocalAnonMut(id) => id }.toArray + val bary = binds.iterator.collect { case LocalAnonMut(id) => + id + }.toArray // this may be static val matchScope = loop(str).map { str => @@ -222,8 +234,7 @@ object MatchlessToValue { idx = idx + 1 } true - } - else false + } else false } } @@ -243,19 +254,24 @@ object MatchlessToValue { var res = false while (currentList ne null) { currentList match { - case nonempty@VList.Cons(_, tail) => + case nonempty @ VList.Cons(_, tail) => scope.updateMut(mutV, nonempty) res = checkF(scope) if (res) { currentList = null } else { currentList = tail } case _ => currentList = null - // we don't match empty lists + // we don't match empty lists } } res } - case SearchList(LocalAnonMut(mutV), init, check, Some(LocalAnonMut(left))) => + case SearchList( + LocalAnonMut(mutV), + init, + check, + Some(LocalAnonMut(left)) + ) => val initF = loop(init) val checkF = boolExpr(check) @@ -266,7 +282,7 @@ object MatchlessToValue { var leftList = VList.VNil while (currentList ne null) { currentList match { - case nonempty@VList.Cons(head, tail) => + case nonempty @ VList.Cons(head, tail) => scope.updateMut(mutV, nonempty) scope.updateMut(left, leftList) res = checkF(scope) @@ -277,14 +293,19 @@ object MatchlessToValue { } case _ => currentList = null - // we don't match empty lists + // we don't match empty lists } } res } } - def buildLoop(caps: Vector[Scoped[Value]], fnName: Bindable, args: NonEmptyList[Bindable], body: Scoped[Value]): Scoped[Value] = { + def buildLoop( + caps: Vector[Scoped[Value]], + fnName: Bindable, + args: NonEmptyList[Bindable], + body: Scoped[Value] + ): Scoped[Value] = { val argCount = args.length val argNames: Array[Bindable] = args.toList.toArray if (caps.isEmpty) { @@ -324,8 +345,7 @@ object MatchlessToValue { } Static(fn) - } - else { + } else { Dynamic { scope => // TODO this maybe isn't helpful // it doesn't matter if the scope @@ -384,13 +404,12 @@ object MatchlessToValue { resFn(scope2) } Static(fn) - } - else { + } else { val capScoped = caps.map(loop).toVector Dynamic { scope => val scope1 = scope .capture(capScoped.map { scoped => scoped(scope) }, name) - + // hopefully optimization/normalization has lifted anything // that doesn't depend on argV above this lambda FnValue { argV => @@ -409,11 +428,11 @@ object MatchlessToValue { // this has to be lazy because it could be // in this package, which isn't complete yet Dynamic { (_: Scope) => res.value } - case Local(b) => Dynamic(_.locals(b).value) - case LocalAnon(a) => Dynamic(_.anon(a)) - case LocalAnonMut(m) => Dynamic(_.muts(m)) + case Local(b) => Dynamic(_.locals(b).value) + case LocalAnon(a) => Dynamic(_.anon(a)) + case LocalAnonMut(m) => Dynamic(_.muts(m)) case ClosureSlot(idx) => Dynamic(_.slots(idx)) - case App(expr, args) => + case App(expr, args) => // TODO: App(LoopFn(.. // can be optimized into a while // loop, but there isn't any prior optimization @@ -425,7 +444,8 @@ object MatchlessToValue { Applicative[Scoped].map2(exprFn, argsFn) { (fn, args) => fn.applyAll(args) } - case Let(Right((n1, r)), loopFn@LoopFn(_, n2, _, _), Local(n3)) if (n1 === n3) && (n1 === n2) && r.isRecursive => + case Let(Right((n1, r)), loopFn @ LoopFn(_, n2, _, _), Local(n3)) + if (n1 === n3) && (n1 === n2) && r.isRecursive => // LoopFn already correctly handles recursion loop(loopFn) case Let(localOrBind, value, in) => @@ -447,8 +467,7 @@ object MatchlessToValue { scope1 } - } - else { + } else { inF.withScope { (scope: Scope) => val vv = Eval.now(valueF(scope)) scope.let(b, vv) @@ -462,7 +481,7 @@ object MatchlessToValue { } case LetMut(LocalAnonMut(l), in) => loop(in) match { - case s@Static(_) => s + case s @ Static(_) => s case Dynamic(inF) => Dynamic { (scope: Scope) => // we make sure there is @@ -515,8 +534,7 @@ object MatchlessToValue { if (sz == 1) { // this is a newtype loopFn - } - else { + } else { loop(expr).map { p => p.asProduct.get(idx) } @@ -536,23 +554,29 @@ object MatchlessToValue { } } - + private[this] val emptyStringArray: Array[String] = new Array[String](0) def matchString( - str: String, - pat: List[StrPart], - binds: Int): Array[String] = { + str: String, + pat: List[StrPart], + binds: Int + ): Array[String] = { import Matchless.StrPart._ val strLen = str.length() - val results = if (binds > 0) new Array[String](binds) else emptyStringArray + val results = + if (binds > 0) new Array[String](binds) else emptyStringArray def loop(offset: Int, pat: List[StrPart], next: Int): Boolean = pat match { case Nil => offset == strLen case LitStr(expect) :: tail => val len = expect.length - str.regionMatches(offset, expect, 0, len) && loop(offset + len, tail, next) + str.regionMatches(offset, expect, 0, len) && loop( + offset + len, + tail, + next + ) case (c: CharPart) :: tail => try { val nextOffset = str.offsetByCodePoints(offset, 1) @@ -560,12 +584,10 @@ object MatchlessToValue { if (c.capture) { results(next) = str.substring(offset, nextOffset) next + 1 - } - else next + } else next loop(nextOffset, tail, n) - } - catch { + } catch { case _: IndexOutOfBoundsException => false } case (h: Glob) :: tail => @@ -601,7 +623,7 @@ object MatchlessToValue { if (h.capture) { results(next) = str.substring(offset, off1) } - true + true } case LitStr(expect) :: tail2 => val next1 = if (h.capture) next + 1 else next @@ -618,7 +640,8 @@ object MatchlessToValue { if (candidate >= 0) { // we have to skip the current expect string val nextOff = candidate + expect.length - val check1 = canMatch(nextOff) && loop(nextOff, tail2, next1) + val check1 = + canMatch(nextOff) && loop(nextOff, tail2, next1) if (check1) { // this was a match, write into next if needed if (h.capture) { @@ -626,13 +649,13 @@ object MatchlessToValue { } result = true start = -1 - } - else { + } else { // we couldn't match here, try just after candidate - start = candidate + Character.charCount(str.codePointAt(candidate)) + start = candidate + Character.charCount( + str.codePointAt(candidate) + ) } - } - else { + } else { // no more candidates start = -1 } diff --git a/core/src/main/scala/org/bykn/bosatsu/MemoryMain.scala b/core/src/main/scala/org/bykn/bosatsu/MemoryMain.scala index fc96cbb8c..353ecc516 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MemoryMain.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MemoryMain.scala @@ -7,27 +7,33 @@ import scala.collection.immutable.SortedMap import cats.implicits._ -class MemoryMain[F[_], K: Ordering](split: K => List[String])( - implicit val pathArg: Argument[K], - val innerMonad: MonadError[F, Throwable]) extends MainModule[Kleisli[F, MemoryMain.State[K], *]] { +class MemoryMain[F[_], K: Ordering](split: K => List[String])(implicit + val pathArg: Argument[K], + val innerMonad: MonadError[F, Throwable] +) extends MainModule[Kleisli[F, MemoryMain.State[K], *]] { type IO[A] = Kleisli[F, MemoryMain.State[K], A] type Path = K def readPath(p: Path): IO[String] = - Kleisli.ask[F, MemoryMain.State[K]] + Kleisli + .ask[F, MemoryMain.State[K]] .flatMap { files => files.get(p) match { case Some(MemoryMain.FileContent.Str(res)) => moduleIOMonad.pure(res) - case other => moduleIOMonad.raiseError(new Exception(s"expect String content, found: $other")) + case other => + moduleIOMonad.raiseError( + new Exception(s"expect String content, found: $other") + ) } } def resolvePath: Option[(Path, PackageName) => IO[Option[Path]]] = None def readPackages(paths: List[Path]): IO[List[Package.Typed[Unit]]] = - Kleisli.ask[F, MemoryMain.State[K]] + Kleisli + .ask[F, MemoryMain.State[K]] .flatMap { files => paths .traverse { path => @@ -36,14 +42,16 @@ class MemoryMain[F[_], K: Ordering](split: K => List[String])( moduleIOMonad.pure(res) case other => moduleIOMonad.raiseError[List[Package.Typed[Unit]]]( - new Exception(s"expect Packages content, found: $other")) + new Exception(s"expect Packages content, found: $other") + ) } } .map(_.flatten) } def readInterfaces(paths: List[Path]): IO[List[Package.Interface]] = - Kleisli.ask[F, MemoryMain.State[K]] + Kleisli + .ask[F, MemoryMain.State[K]] .flatMap { files => paths .traverse { path => @@ -52,7 +60,8 @@ class MemoryMain[F[_], K: Ordering](split: K => List[String])( moduleIOMonad.pure(res) case other => moduleIOMonad.raiseError[List[Package.Interface]]( - new Exception(s"expect Packages content, found: $other")) + new Exception(s"expect Packages content, found: $other") + ) } } .map(_.flatten) @@ -66,22 +75,27 @@ class MemoryMain[F[_], K: Ordering](split: K => List[String])( def runWith( files: Iterable[(K, String)], packages: Iterable[(K, List[Package.Typed[Unit]])] = Nil, - interfaces: Iterable[(K, List[Package.Interface])] = Nil)(cmd: List[String]): F[Output] = - run(cmd) match { - case Left(msg) => - innerMonad.raiseError[Output](new Exception(s"got the help message for: $cmd: $msg")) - case Right(io) => - val state0 = files.foldLeft(SortedMap.empty[K, MemoryMain.FileContent]) { case (st, (k, str)) => + interfaces: Iterable[(K, List[Package.Interface])] = Nil + )(cmd: List[String]): F[Output] = + run(cmd) match { + case Left(msg) => + innerMonad.raiseError[Output]( + new Exception(s"got the help message for: $cmd: $msg") + ) + case Right(io) => + val state0 = + files.foldLeft(SortedMap.empty[K, MemoryMain.FileContent]) { + case (st, (k, str)) => st.updated(k, MemoryMain.FileContent.Str(str)) - } - val state1 = packages.foldLeft(state0) { case (st, (k, packs)) => - st.updated(k, MemoryMain.FileContent.Packages(packs)) - } - val state2 = interfaces.foldLeft(state1) { case (st, (k, ifs)) => - st.updated(k, MemoryMain.FileContent.Interfaces(ifs)) - } - io.run(state2) + } + val state1 = packages.foldLeft(state0) { case (st, (k, packs)) => + st.updated(k, MemoryMain.FileContent.Packages(packs)) } + val state2 = interfaces.foldLeft(state1) { case (st, (k, ifs)) => + st.updated(k, MemoryMain.FileContent.Interfaces(ifs)) + } + io.run(state2) + } def pathPackage(roots: List[Path], packFile: Path): Option[PackageName] = { val fparts = split(packFile) @@ -91,8 +105,7 @@ class MemoryMain[F[_], K: Ordering](split: K => List[String])( if (fparts.startsWith(splitP)) { val parts = fparts.drop(splitP.length) PackageName.parse(parts.mkString("/")) - } - else None + } else None } roots.collectFirstSome(getP) @@ -102,7 +115,6 @@ class MemoryMain[F[_], K: Ordering](split: K => List[String])( Kleisli(_ => innerMonad.pure(a)) } - object MemoryMain { sealed abstract class FileContent object FileContent { diff --git a/core/src/main/scala/org/bykn/bosatsu/NameKind.scala b/core/src/main/scala/org/bykn/bosatsu/NameKind.scala index f0b9d06c9..9932cdc9c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/NameKind.scala +++ b/core/src/main/scala/org/bykn/bosatsu/NameKind.scala @@ -4,14 +4,24 @@ import Identifier.Bindable sealed abstract class NameKind[T] object NameKind { - case class Let[T](name: Bindable, recursive: RecursionKind, value: TypedExpr[T]) extends NameKind[T] + case class Let[T]( + name: Bindable, + recursive: RecursionKind, + value: TypedExpr[T] + ) extends NameKind[T] case class Constructor[T]( - cn: Identifier.Constructor, - params: List[(Bindable, rankn.Type)], - defined: rankn.DefinedType[Kind.Arg], - valueType: rankn.Type) extends NameKind[T] - case class Import[T](fromPack: Package.Interface, originalName: Identifier) extends NameKind[T] - case class ExternalDef[T](pack: PackageName, defName: Identifier, defType: rankn.Type) extends NameKind[T] + cn: Identifier.Constructor, + params: List[(Bindable, rankn.Type)], + defined: rankn.DefinedType[Kind.Arg], + valueType: rankn.Type + ) extends NameKind[T] + case class Import[T](fromPack: Package.Interface, originalName: Identifier) + extends NameKind[T] + case class ExternalDef[T]( + pack: PackageName, + defName: Identifier, + defType: rankn.Type + ) extends NameKind[T] def externals[T](from: Package.Typed[T]): Iterable[ExternalDef[T]] = { val prog = from.program @@ -24,7 +34,10 @@ object NameKind { } } - def apply[T](from: Package.Typed[T], item: Identifier): Option[NameKind[T]] = { + def apply[T]( + from: Package.Typed[T], + item: Identifier + ): Option[NameKind[T]] = { val prog = from.program def getLet: Option[NameKind[T]] = diff --git a/core/src/main/scala/org/bykn/bosatsu/Operators.scala b/core/src/main/scala/org/bykn/bosatsu/Operators.scala index 17b327970..971d15b4f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Operators.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Operators.scala @@ -11,7 +11,7 @@ object Operators { val leftDone = left.length <= idx val rightDone = right.length <= idx (leftDone, rightDone) match { - case (true, true) => 0 + case (true, true) => 0 case (true, false) => 1 case (false, true) => -1 case (false, false) => @@ -21,7 +21,8 @@ object Operators { else { Integer.compare( priorityMap.getOrElse(lc, Int.MaxValue), - priorityMap.getOrElse(rc, Int.MaxValue)) + priorityMap.getOrElse(rc, Int.MaxValue) + ) } } } @@ -30,26 +31,19 @@ object Operators { else loop(0) } - /** - * strings for operators allowed in single character - * operators (excludes = and .) - */ + /** strings for operators allowed in single character operators (excludes = + * and .) + */ val singleToks = - List( - "/", "%", "*", - "-", "+", - "<", ">", - "!", "$", - "&", "^", "|", - "?", "~").map(_.intern) + List("/", "%", "*", "-", "+", "<", ">", "!", "$", "&", "^", "|", "?", "~") + .map(_.intern) private def from(strs: Iterable[String]): P[Unit] = P.stringIn(strs).void - /** - * strings for operators allowed in single character - * operators includes singleToks and . and = - */ + /** strings for operators allowed in single character operators includes + * singleToks and . and = + */ val multiToks: List[String] = ".".intern :: singleToks ::: List("=".intern) @@ -57,22 +51,17 @@ object Operators { from(multiToks) private val priorityMap: Map[String, Int] = - multiToks - .iterator - .zipWithIndex - .toMap - - /** - * Here are a list of operators we allow - */ + multiToks.iterator.zipWithIndex.toMap + + /** Here are a list of operators we allow + */ val operatorToken: P[String] = { val singles = from(singleToks) // write this in a way to avoid backtracking (((P.string("<-") | P.char('=') | P.string("->")) ~ multiToksP.rep).void | (singles ~ multiToksP.rep0).void | - multiToksP.rep(min = 2).void) - .string + multiToksP.rep(min = 2).void).string .map(_.intern) } @@ -87,17 +76,19 @@ object Operators { object Formula { case class Sym[A](value: A) extends Formula[A] - case class Op[A](left: Formula[A], op: String, right: Formula[A]) extends Formula[A] - - /** - * 1 * 2 + 3 => (1 * 2) + 3 - * 1 * 2 * 3 => ((1 * 2) * 3) - */ - def toFormula[A](init: Formula[A], rest: List[(String, Formula[A])]): Formula[A] = + case class Op[A](left: Formula[A], op: String, right: Formula[A]) + extends Formula[A] + + /** 1 * 2 + 3 => (1 * 2) + 3 1 * 2 * 3 => ((1 * 2) * 3) + */ + def toFormula[A]( + init: Formula[A], + rest: List[(String, Formula[A])] + ): Formula[A] = rest match { - case Nil => init + case Nil => init case (op, next) :: Nil => Op(init, op, next) - case (op1, next1) :: (right@((op2, next2) :: tail)) => + case (op1, next1) :: (right @ ((op2, next2) :: tail)) => val c = compareOperator(op1, op2) if (c > 0) { // right binds tighter @@ -106,36 +97,35 @@ object Operators { // in this example, then starting again val f2 = Op(next1, op2, next2) toFormula(init, (op1, f2) :: tail) - } - else { + } else { // 1 + 2 + 3 => (1 + 2) + 3 // 1 * 2 + 3 => (1 * 2) + 3 toFormula(Op(init, op1, next1), right) } } - /** - * Parse a chain of at least 1 operator being applied - * with the operator precedence handled by the formula - */ + /** Parse a chain of at least 1 operator being applied with the operator + * precedence handled by the formula + */ def infixOps1[A](p: P[A]): P[A => Formula[A]] = { val opA = operatorToken ~ (Parser.maybeSpacesAndLines.with1 *> p) val chain: P[NonEmptyList[(String, A)]] = P.repSep(opA, min = 1, sep = Parser.maybeSpace) chain.map { rest => - - { (a: A) => toFormula(Sym(a), rest.toList.map { case (o, s) => (o, Sym(s)) }) } + { (a: A) => + toFormula(Sym(a), rest.toList.map { case (o, s) => (o, Sym(s)) }) + } } } - /** - * An a formula is a series of A's separated by spaces, with - * the correct parenthesis - */ + + /** An a formula is a series of A's separated by spaces, with the correct + * parenthesis + */ def parser[A](p: P[A]): P[Formula[A]] = (p ~ (Parser.maybeSpace.with1 *> infixOps1(p)).?) .map { - case (a, None) => Sym(a) + case (a, None) => Sym(a) case (a, Some(f)) => f(a) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/OptIndent.scala b/core/src/main/scala/org/bykn/bosatsu/OptIndent.scala index 53e9af25f..7b1b67c61 100644 --- a/core/src/main/scala/org/bykn/bosatsu/OptIndent.scala +++ b/core/src/main/scala/org/bykn/bosatsu/OptIndent.scala @@ -14,7 +14,7 @@ sealed abstract class OptIndent[A] { def sepDoc: Doc = this match { - case OptIndent.SameLine(_) => Doc.space + case OptIndent.SameLine(_) => Doc.space case OptIndent.NotSameLine(_) => Doc.empty } @@ -43,7 +43,8 @@ object OptIndent { NotSameLine(toPadIndent) case class SameLine[A](get: A) extends OptIndent[A] - case class NotSameLine[A](toPadIndent: Padding[Indented[A]]) extends OptIndent[A] { + case class NotSameLine[A](toPadIndent: Padding[Indented[A]]) + extends OptIndent[A] { def get: A = toPadIndent.padded.value } @@ -56,7 +57,7 @@ object OptIndent { val dpi = Document[Padding[Indented[A]]] Document.instance[OptIndent[A]] { - case SameLine(a) => da.document(a) + case SameLine(a) => da.document(a) case NotSameLine(p) => dpi.document(p) } } @@ -64,20 +65,23 @@ object OptIndent { def indy[A](p: Indy[A]): Indy[OptIndent[A]] = { val ind = Indented.indy(p) // we need to read at least 1 new line here - val not = ind.mapF { p => Padding.parser1(p).map(notSame[A](_)): P[OptIndent[A]] } + val not = ind.mapF { p => + Padding.parser1(p).map(notSame[A](_)): P[OptIndent[A]] + } val sm = p.map(same[A](_)) not <+> sm } - /** - * A: B or - * A: - * B - */ + /** A: B or A: B + */ def block[A, B](first: Indy[A], next: Indy[B]): Indy[(A, OptIndent[B])] = blockLike(first, next, (maybeSpace ~ P.char(':')).void) - def blockLike[A, B](first: Indy[A], next: Indy[B], sep: P0[Unit]): Indy[(A, OptIndent[B])] = + def blockLike[A, B]( + first: Indy[A], + next: Indy[B], + sep: P0[Unit] + ): Indy[(A, OptIndent[B])] = first .cutLeftP(sep ~ maybeSpace) .cutThen(OptIndent.indy(next)) diff --git a/core/src/main/scala/org/bykn/bosatsu/Package.scala b/core/src/main/scala/org/bykn/bosatsu/Package.scala index dfac98a3e..870687fdd 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Package.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Package.scala @@ -11,14 +11,16 @@ import rankn._ import Parser.{spaces, Combinators} import FixType.Fix -/** - * Represents a package over its life-cycle: from parsed to resolved to inferred - */ + +/** Represents a package over its life-cycle: from parsed to resolved to + * inferred + */ final case class Package[A, B, C, +D]( - name: PackageName, - imports: List[Import[A, B]], - exports: List[ExportedName[C]], - program: D) { + name: PackageName, + imports: List[Import[A, B]], + exports: List[ExportedName[C]], + program: D +) { // It is really important to cache the hashcode and these large dags if // we use them as hash keys @@ -49,27 +51,28 @@ final case class Package[A, B, C, +D]( def mapProgram[D1](fn: D => D1): Package[A, B, C, D1] = Package(name, imports, exports, fn(program)) - def replaceImports[A1, B1](newImports: List[Import[A1, B1]]): Package[A1, B1, C, D] = + def replaceImports[A1, B1]( + newImports: List[Import[A1, B1]] + ): Package[A1, B1, C, D] = Package(name, newImports, exports, program) } object Package { type Interface = Package[Nothing, Nothing, Referant[Kind.Arg], Unit] - /** - * This is a package whose import type is Either: - * 1 a package of the same kind - * 2 an interface - */ + + /** This is a package whose import type is Either: 1 a package of the same + * kind 2 an interface + */ type FixPackage[B, C, D] = Fix[λ[a => Either[Interface, Package[a, B, C, D]]]] - type PackageF[A, B, C] = Either[Interface, Package[FixPackage[A, B, C], A, B, C]] + type PackageF[A, B, C] = + Either[Interface, Package[FixPackage[A, B, C], A, B, C]] type PackageF2[A, B] = PackageF[A, A, B] type Parsed = Package[PackageName, Unit, Unit, List[Statement]] - type Resolved = FixPackage[Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])] - type Typed[T] = Package[ - Interface, - NonEmptyList[Referant[Kind.Arg]], - Referant[Kind.Arg], - Program[TypeEnv[Kind.Arg], TypedExpr[T], Any]] + type Resolved = + FixPackage[Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])] + type Typed[T] = Package[Interface, NonEmptyList[Referant[Kind.Arg]], Referant[ + Kind.Arg + ], Program[TypeEnv[Kind.Arg], TypedExpr[T], Any]] type Inferred = Typed[Declaration] val typedFunctor: Functor[Typed] = @@ -82,37 +85,39 @@ object Package { } } - /** - * Return the last binding in the file with the test type - */ - def testValue[A](tp: Typed[A]): Option[(Identifier.Bindable, RecursionKind, TypedExpr[A])] = - tp - .program - .lets - .filter { case (_, _, te) => te.getType == Type.TestType } - .lastOption - - /** - * Discard any top level values that are not referenced, exported, - * the final test value, or the final expression - * - * This is used to remove private top levels that were inlined. - */ + /** Return the last binding in the file with the test type + */ + def testValue[A]( + tp: Typed[A] + ): Option[(Identifier.Bindable, RecursionKind, TypedExpr[A])] = + tp.program.lets.filter { case (_, _, te) => + te.getType == Type.TestType + }.lastOption + + /** Discard any top level values that are not referenced, exported, the final + * test value, or the final expression + * + * This is used to remove private top levels that were inlined. + */ def discardUnused[A](tp: Typed[A]): Typed[A] = { val pinned: Set[Identifier] = tp.exports.iterator.map(_.name).toSet ++ - tp.program.lets.lastOption.map(_._1) ++ + tp.program.lets.lastOption.map(_._1) ++ testValue(tp).map(_._1) def topLevels(s: Set[(PackageName, Identifier)]): Set[Identifier] = s.collect { case (p, i) if p === tp.name => i } - val letWithGlobals = tp.program.lets.map { case tup @ (_, _, te) => (tup, topLevels(te.globals)) } + val letWithGlobals = tp.program.lets.map { case tup @ (_, _, te) => + (tup, topLevels(te.globals)) + } @annotation.tailrec def loop(reached: Set[Identifier]): Set[Identifier] = { val step = letWithGlobals - .foldMap { case ((bn, _, _), tops) => if (reached(bn)) tops else Set.empty[Identifier] } + .foldMap { case ((bn, _, _), tops) => + if (reached(bn)) tops else Set.empty[Identifier] + } if (step.forall(reached)) reached else loop(step | reached) @@ -120,7 +125,9 @@ object Package { val reached = loop(pinned) - val reachedLets = letWithGlobals.collect { case (tup @ (bn, _, _), _) if reached(bn) => tup } + val reachedLets = letWithGlobals.collect { + case (tup @ (bn, _, _), _) if reached(bn) => tup + } tp.copy(program = tp.program.copy(lets = reachedLets)) } @@ -129,10 +136,10 @@ object Package { def unfix[A, B, C](fp: FixPackage[A, B, C]): PackageF[A, B, C] = FixType.unfix[λ[a => Either[Interface, Package[a, A, B, C]]]](fp) - /** - * build a Parsed Package from a Statement. This is useful for testing or - * library usages. - */ + + /** build a Parsed Package from a Statement. This is useful for testing or + * library usages. + */ def fromStatements(pn: PackageName, stmts: List[Statement]): Package.Parsed = Package(pn, Nil, Nil, stmts) @@ -142,98 +149,135 @@ object Package { def setProgramFrom[A, B](t: Typed[A], newFrom: B): Typed[A] = t.copy(program = t.program.copy(from = newFrom)) - implicit val document: Document[Package[PackageName, Unit, Unit, List[Statement]]] = - Document.instance[Package.Parsed] { case Package(name, imports, exports, statments) => - val p = Doc.text("package ") + Document[PackageName].document(name) + Doc.line - val i = imports match { - case Nil => Doc.empty - case nonEmptyImports => - Doc.line + - Doc.intercalate(Doc.line, nonEmptyImports.map(Document[Import[PackageName, Unit]].document _)) + - Doc.line - } - val e = exports match { - case Nil => Doc.empty - case nonEmptyExports => - Doc.line + - Doc.text("export ") + - Doc.intercalate(Doc.text(", "), nonEmptyExports.map(Document[ExportedName[Unit]].document _)) + - Doc.line - } - val b = statments.map(Document[Statement].document(_)) - Doc.intercalate(Doc.empty, p :: i :: e :: b) + implicit val document + : Document[Package[PackageName, Unit, Unit, List[Statement]]] = + Document.instance[Package.Parsed] { + case Package(name, imports, exports, statments) => + val p = + Doc.text("package ") + Document[PackageName].document(name) + Doc.line + val i = imports match { + case Nil => Doc.empty + case nonEmptyImports => + Doc.line + + Doc.intercalate( + Doc.line, + nonEmptyImports.map( + Document[Import[PackageName, Unit]].document _ + ) + ) + + Doc.line + } + val e = exports match { + case Nil => Doc.empty + case nonEmptyExports => + Doc.line + + Doc.text("export ") + + Doc.intercalate( + Doc.text(", "), + nonEmptyExports.map(Document[ExportedName[Unit]].document _) + ) + + Doc.line + } + val b = statments.map(Document[Statement].document(_)) + Doc.intercalate(Doc.empty, p :: i :: e :: b) } - def parser(defaultPack: Option[PackageName]): P0[Package[PackageName, Unit, Unit, List[Statement]]] = { + def parser( + defaultPack: Option[PackageName] + ): P0[Package[PackageName, Unit, Unit, List[Statement]]] = { // TODO: support comments before the Statement - val parsePack = Padding.parser((P.string("package").soft ~ spaces) *> PackageName.parser <* Parser.toEOL).map(_.padded) + val parsePack = Padding + .parser( + (P.string("package") + .soft ~ spaces) *> PackageName.parser <* Parser.toEOL + ) + .map(_.padded) val pname: P0[PackageName] = defaultPack match { - case None => parsePack + case None => parsePack case Some(p) => parsePack.?.map(_.getOrElse(p)) } val im = Padding.parser(Import.parser <* Parser.toEOL).map(_.padded).rep0 - val ex = Padding.parser((P.string("export").soft ~ spaces) *> ExportedName.parser.itemsMaybeParens.map(_._2) <* Parser.toEOL).map(_.padded) + val ex = Padding + .parser( + (P.string("export") + .soft ~ spaces) *> ExportedName.parser.itemsMaybeParens + .map(_._2) <* Parser.toEOL + ) + .map(_.padded) val body: P0[List[Statement]] = Statement.parser (pname, im, Parser.nonEmptyListToList(ex), body) .mapN { (p, i, e, b) => Package(p, i, e, b) } } - /** - * After having type checked the imports, we now type check the body - * in order to type check the exports - * - * This is used by test code - */ + /** After having type checked the imports, we now type check the body in order + * to type check the exports + * + * This is used by test code + */ def inferBody( - p: PackageName, - imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]], - stmts: List[Statement]): - Ior[NonEmptyList[PackageError], - Program[TypeEnv[Kind.Arg], TypedExpr[Declaration], List[Statement]]] = - inferBodyUnopt(p, imps, stmts).map { - case (fullTypeEnv, prog) => - TypedExprNormalization.normalizeProgram(p, fullTypeEnv, prog) - } + p: PackageName, + imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]], + stmts: List[Statement] + ): Ior[NonEmptyList[PackageError], Program[TypeEnv[Kind.Arg], TypedExpr[ + Declaration + ], List[Statement]]] = + inferBodyUnopt(p, imps, stmts).map { case (fullTypeEnv, prog) => + TypedExprNormalization.normalizeProgram(p, fullTypeEnv, prog) + } - /** - * Infer the types but do not optimize/normalize the lets - */ + /** Infer the types but do not optimize/normalize the lets + */ def inferBodyUnopt( - p: PackageName, - imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]], - stmts: List[Statement]): - Ior[NonEmptyList[PackageError], - (TypeEnv[Kind.Arg], Program[TypeEnv[Kind.Arg], TypedExpr[Declaration], List[Statement]])] = { + p: PackageName, + imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]], + stmts: List[Statement] + ): Ior[NonEmptyList[ + PackageError + ], (TypeEnv[Kind.Arg], Program[TypeEnv[Kind.Arg], TypedExpr[Declaration], List[Statement]])] = { // here we make a pass to get all the local names - val optProg = SourceConverter.toProgram(p, imps.map { i => i.copy(pack = i.pack.name) }, stmts) - .leftMap(_.map(PackageError.SourceConverterErrorIn(_, p): PackageError).toNonEmptyList) + val optProg = SourceConverter + .toProgram(p, imps.map { i => i.copy(pack = i.pack.name) }, stmts) + .leftMap( + _.map( + PackageError.SourceConverterErrorIn(_, p): PackageError + ).toNonEmptyList + ) lazy val typeDefRegions: Map[Type.Const.Defined, Region] = - stmts.iterator.collect { - case tds: TypeDefinitionStatement => - Type.Const.Defined(p, TypeName(tds.name)) -> tds.region - } - .toMap + stmts.iterator.collect { case tds: TypeDefinitionStatement => + Type.Const.Defined(p, TypeName(tds.name)) -> tds.region + }.toMap optProg.flatMap { case Program((importedTypeEnv, parsedTypeEnv), lets, extDefs, _) => - val inferVarianceParsed: Ior[NonEmptyList[PackageError], ParsedTypeEnv[Kind.Arg]] = - KindFormula.solveShapesAndKinds(importedTypeEnv, parsedTypeEnv.allDefinedTypes.reverse) - .bimap({ necError => - necError.map(PackageError.KindInferenceError(p, _, typeDefRegions)).toNonEmptyList - }, { infDTs => - ParsedTypeEnv(infDTs, parsedTypeEnv.externalDefs) - }) + val inferVarianceParsed + : Ior[NonEmptyList[PackageError], ParsedTypeEnv[Kind.Arg]] = + KindFormula + .solveShapesAndKinds( + importedTypeEnv, + parsedTypeEnv.allDefinedTypes.reverse + ) + .bimap( + { necError => + necError + .map(PackageError.KindInferenceError(p, _, typeDefRegions)) + .toNonEmptyList + }, + { infDTs => + ParsedTypeEnv(infDTs, parsedTypeEnv.externalDefs) + } + ) inferVarianceParsed.flatMap { parsedTypeEnv => /* * Check that all recursion is allowable */ val defRecursionCheck: ValidatedNel[PackageError, Unit] = - stmts.traverse_(DefRecursionCheck.checkStatement(_)) + stmts + .traverse_(DefRecursionCheck.checkStatement(_)) .leftMap { badRecursions => badRecursions.map(PackageError.RecursionError(p, _)) } @@ -241,19 +285,21 @@ object Package { val typeEnv: TypeEnv[Kind.Arg] = TypeEnv.fromParsed(parsedTypeEnv) /* - * These are values, including all constructor functions - * that have been imported, this includes local external - * defs - */ + * These are values, including all constructor functions + * that have been imported, this includes local external + * defs + */ val withFQN: Map[(Option[PackageName], Identifier), Type] = { val fqn = - Referant.fullyQualifiedImportedValues(imps)(_.name) + Referant + .fullyQualifiedImportedValues(imps)(_.name) .iterator .map { case ((p, n), t) => ((Some(p), n), t) } // these are local construtors/externals val localDefined = - typeEnv.localValuesOf(p) + typeEnv + .localValuesOf(p) .iterator .map { case (n, t) => ((Some(p), n), t) } @@ -263,11 +309,17 @@ object Package { val fullTypeEnv = importedTypeEnv ++ typeEnv val totalityCheck = lets - .traverse { case (_, _, expr) => TotalityCheck(fullTypeEnv).checkExpr(expr) } - .leftMap { errs => errs.map(PackageError.TotalityCheckError(p, _)) } + .traverse { case (_, _, expr) => + TotalityCheck(fullTypeEnv).checkExpr(expr) + } + .leftMap { errs => + errs.map(PackageError.TotalityCheckError(p, _)) + } - val inferenceEither = Infer.typeCheckLets(p, lets) - .runFully(withFQN, + val inferenceEither = Infer + .typeCheckLets(p, lets) + .runFully( + withFQN, Referant.typeConstructors(imps) ++ typeEnv.typeConstructors, fullTypeEnv.toKindMap ) @@ -278,12 +330,15 @@ object Package { .map(PackageError.TypeErrorIn(_, p)) val checkUnusedLets = - lets.traverse_ { case (_, _, expr) => - UnusedLetCheck.check(expr) - } - .leftMap { errs => - NonEmptyList.one(PackageError.UnusedLetError(p, errs.toNonEmptyList)) - } + lets + .traverse_ { case (_, _, expr) => + UnusedLetCheck.check(expr) + } + .leftMap { errs => + NonEmptyList.one( + PackageError.UnusedLetError(p, errs.toNonEmptyList) + ) + } /* * Checks accumulate errors, but have no return value: @@ -291,11 +346,13 @@ object Package { * error accumulation */ val checks = List( - defRecursionCheck, checkUnusedLets, totalityCheck - ) - .sequence_ + defRecursionCheck, + checkUnusedLets, + totalityCheck + ).sequence_ - val inference = Validated.fromEither(inferenceEither).leftMap(NonEmptyList.of(_)) + val inference = + Validated.fromEither(inferenceEither).leftMap(NonEmptyList.of(_)) Parallel[Ior[NonEmptyList[PackageError], *]] .parProductR(checks.toIor)(inference.toIor) @@ -303,29 +360,39 @@ object Package { } } - /** - * The parsed representation of the predef. - */ + /** The parsed representation of the predef. + */ lazy val predefPackage: Package.Parsed = parser(None).parse(Predef.predefString) match { case Right((_, pack)) => // Make function defs: - def paramType(n: Int) = (TypeRef.TypeVar(s"i$n"), Some(Kind.Arg(Variance.contra, Kind.Type))) - def makeFns(n: Int, - typeArgs: List[(TypeRef.TypeVar, Option[Kind.Arg])], - acc: List[Statement.ExternalStruct]): List[Statement.ExternalStruct] = + def paramType(n: Int) = + (TypeRef.TypeVar(s"i$n"), Some(Kind.Arg(Variance.contra, Kind.Type))) + def makeFns( + n: Int, + typeArgs: List[(TypeRef.TypeVar, Option[Kind.Arg])], + acc: List[Statement.ExternalStruct] + ): List[Statement.ExternalStruct] = if (n > Type.FnType.MaxSize) acc else { - val fn = Statement.ExternalStruct(Identifier.Constructor(s"Fn$n"), typeArgs)(Region(0, 1)) + val fn = Statement.ExternalStruct( + Identifier.Constructor(s"Fn$n"), + typeArgs + )(Region(0, 1)) val acc1 = fn :: acc makeFns(n + 1, paramType(n) :: typeArgs, acc1) } val out = (TypeRef.TypeVar("z"), Some(Kind.Arg(Variance.co, Kind.Type))) val allFns = makeFns(1, paramType(0) :: out :: Nil, Nil).reverse - val exported = allFns.map { extstr => ExportedName.TypeName(extstr.name, ()) } + val exported = allFns.map { extstr => + ExportedName.TypeName(extstr.name, ()) + } // Add functions into the predef - pack.copy(exports = exported ::: pack.exports, program = allFns ::: pack.program) + pack.copy( + exports = exported ::: pack.exports, + program = allFns ::: pack.program + ) case Left(err) => val idx = err.failedAtOffset val lm = LocationMap(Predef.predefString) @@ -340,29 +407,50 @@ object Package { def document(pack: Typed[Any]): Doc = Doc.text("package: ") + Doc.text(pack.name.asString) + { val lines = Doc.hardLine - val imps = Doc.text("imports: ") + Doc.intercalate(Doc.line, pack.imports.map { imp => - Doc.text(imp.pack.name.asString) + Doc.space + (Doc.char('[') + Doc.line + - Doc.intercalate(Doc.comma + Doc.line, imp.items.toList.map { imp => - Doc.text(imp.originalName.sourceCodeRepr) - }) + Doc.line + Doc.char(']') - ).grouped - }).nested(4) - - val exports = Doc.text("exports: ") + Doc.intercalate(Doc.line, + val imps = Doc.text("imports: ") + Doc + .intercalate( + Doc.line, + pack.imports.map { imp => + Doc.text(imp.pack.name.asString) + Doc.space + (Doc.char( + '[' + ) + Doc.line + + Doc.intercalate( + Doc.comma + Doc.line, + imp.items.toList.map { imp => + Doc.text(imp.originalName.sourceCodeRepr) + } + ) + Doc.line + Doc.char(']')).grouped + } + ) + .nested(4) + + val exports = Doc.text("exports: ") + Doc + .intercalate( + Doc.line, pack.exports.map { exp => Doc.text(exp.name.sourceCodeRepr) - }).grouped.nested(4) - - val tpes = Doc.text("types: ") + Doc.intercalate(Doc.comma + Doc.line, - pack.program.types.definedTypes.toList.map { case (_, t) => - Doc.text(t.name.ident.sourceCodeRepr) - }).grouped.nested(4) + } + ) + .grouped + .nested(4) + + val tpes = Doc.text("types: ") + Doc + .intercalate( + Doc.comma + Doc.line, + pack.program.types.definedTypes.toList.map { case (_, t) => + Doc.text(t.name.ident.sourceCodeRepr) + } + ) + .grouped + .nested(4) val eqDoc = Doc.text(" = ") - val exprs = Doc.intercalate(Doc.hardLine + Doc.hardLine, + val exprs = Doc.intercalate( + Doc.hardLine + Doc.hardLine, pack.program.lets.map { case (n, _, te) => Doc.text(n.sourceCodeRepr) + eqDoc + te.repr - }) + } + ) val all = lines :: imps :: exports :: tpes :: exprs :: Nil @@ -376,14 +464,19 @@ object Package { Doc.text("interface: ") + Doc.text(iface.name.asString) + { val lines = Doc.hardLine - val exports = Doc.text("exports: ") + Doc.intercalate(Doc.line, + val exports = Doc.text("exports: ") + Doc + .intercalate( + Doc.line, iface.exports.map { exp => Doc.text(exp.name.sourceCodeRepr) - }).grouped.nested(4) + } + ) + .grouped + .nested(4) val all = lines :: exports :: Nil Doc.intercalate(Doc.hardLine, all) }.nested(4) } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageCustoms.scala b/core/src/main/scala/org/bykn/bosatsu/PackageCustoms.scala index ae362960a..d3beb4b54 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageCustoms.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageCustoms.scala @@ -1,7 +1,15 @@ package org.bykn.bosatsu import cats.Monad -import cats.data.{Chain, NonEmptyList, NonEmptySet, NonEmptyChain, State, Validated, ValidatedNec} +import cats.data.{ + Chain, + NonEmptyList, + NonEmptySet, + NonEmptyChain, + State, + Validated, + ValidatedNec +} import org.bykn.bosatsu.rankn.{Type, TypeEnv, DefinedType} import scala.collection.immutable.SortedSet @@ -10,31 +18,32 @@ import org.bykn.bosatsu.Referant.Constructor import org.bykn.bosatsu.Referant.DefinedT import org.bykn.bosatsu.TypedExpr.Match -/** - * This checks the imports and exports of compiled packages - * and makes sure they are valid +/** This checks the imports and exports of compiled packages and makes sure they + * are valid */ object PackageCustoms { - def apply[A](pack: Package.Typed[A]): ValidatedNec[PackageError, Package.Typed[A]] = { + def apply[A]( + pack: Package.Typed[A] + ): ValidatedNec[PackageError, Package.Typed[A]] = { checkValuesHaveExportedTypes(pack.name, pack.exports) *> allImportsAreUsed(pack) } private def removeUnused[A]( - vals: Option[NonEmptySet[(PackageName, Identifier)]], - types: Option[NonEmptySet[(PackageName, Type.Const)]], - pack: Package.Typed[A] + vals: Option[NonEmptySet[(PackageName, Identifier)]], + types: Option[NonEmptySet[(PackageName, Type.Const)]], + pack: Package.Typed[A] ): Package.Typed[A] = (vals, types) match { case (None, None) => pack case (ov, ot) => val unV = ov match { case Some(v) => v.toSortedSet - case None => SortedSet.empty[(PackageName, Identifier)] + case None => SortedSet.empty[(PackageName, Identifier)] } val unT = ot match { case Some(v) => v.toSortedSet - case None => SortedSet.empty[(PackageName, Type.Const)] + case None => SortedSet.empty[(PackageName, Type.Const)] } val i = pack.imports.flatMap { imp => imp.mapFilter { (pack, item) => @@ -51,7 +60,9 @@ object PackageCustoms { pack.copy(imports = i) } - private def allImportsAreUsed[A](pack: Package.Typed[A]): ValidatedNec[PackageError, Package.Typed[A]] = { + private def allImportsAreUsed[A]( + pack: Package.Typed[A] + ): ValidatedNec[PackageError, Package.Typed[A]] = { // Note, we can't import just a name or just a type when we have Structs, // we get both. So, we will count a Constructor used if the Identifier // OR the type is used @@ -89,30 +100,36 @@ object PackageCustoms { val usedValuesSt: VState[Unit] = pack.program.lets.traverse_ { case (_, _, te) => te.traverseUp { - case g@TypedExpr.Global(p, n, _, _) => + case g @ TypedExpr.Global(p, n, _, _) => State { s => (s + ((p, n)), g) } - case m @ Match(_, branches, _) => - branches.traverse_ { - case (pat, _) => - pat.traverseStruct[VState, (PackageName, Identifier.Constructor)] { (n, parts) => - State.modify[VSet](_ + n) *> - parts.map { inner => - Pattern.PositionalStruct(n, inner) - } - } - .void - }.as(m) + case m @ Match(_, branches, _) => + branches + .traverse_ { case (pat, _) => + pat + .traverseStruct[ + VState, + (PackageName, Identifier.Constructor) + ] { (n, parts) => + State.modify[VSet](_ + n) *> + parts.map { inner => + Pattern.PositionalStruct(n, inner) + } + } + .void + } + .as(m) case te => Monad[VState].pure(te) } } val usedValues = usedValuesSt.runS(Set.empty).value - val usedTypes: Set[Type.Const] = - pack.program.lets.iterator.flatMap( - _._3.allTypes.flatMap(Type.constantsOf(_)) - ).toSet + pack.program.lets.iterator + .flatMap( + _._3.allTypes.flatMap(Type.constantsOf(_)) + ) + .toSet val unusedValues = impValues.filterNot { tup => tup match { @@ -122,7 +139,7 @@ object PackageCustoms { case _ => // no ambiguity for bindable usedValues(tup) - } + } } val unusedTypes = impTypes.filterNot { case (pn, t) => // deal with the ambiguity of capital names @@ -131,51 +148,57 @@ object PackageCustoms { val unusedValMap = unusedValues.groupByNes(_._1) val unusedTypeMap = unusedTypes.groupByNes(_._1) - val unusedPacks = (unusedValMap.keySet | unusedTypeMap.keySet) - PackageName.PredefName + val unusedPacks = + (unusedValMap.keySet | unusedTypeMap.keySet) - PackageName.PredefName if (unusedPacks.isEmpty) { // remove unused - Validated.valid(removeUnused( - unusedValMap.get(PackageName.PredefName), - unusedTypeMap.get(PackageName.PredefName), - pack - )) - } - else { - val badImports = NonEmptyList.fromListUnsafe( - unusedPacks - .iterator - .map { ipack => - val thisVals = unusedValMap.get(ipack) match { - case Some(nes) => nes.toSortedSet.toList.map { case (_, i) => + Validated.valid( + removeUnused( + unusedValMap.get(PackageName.PredefName), + unusedTypeMap.get(PackageName.PredefName), + pack + ) + ) + } else { + val badImports = + NonEmptyList.fromListUnsafe(unusedPacks.iterator.map { ipack => + val thisVals = unusedValMap.get(ipack) match { + case Some(nes) => + nes.toSortedSet.toList.map { case (_, i) => ImportedName.OriginalName(i, ()) } - case None => Nil - } + case None => Nil + } - val thisTpes = unusedTypeMap.get(ipack) match { - case Some(nes) => nes.toSortedSet.toList.map { case (_, t) => + val thisTpes = unusedTypeMap.get(ipack) match { + case Some(nes) => + nes.toSortedSet.toList.map { case (_, t) => ImportedName.OriginalName(t.toDefined.name.ident, ()) } - case None => Nil - } - // one or the other or both of these is non-empty - Import(ipack, NonEmptyList.fromListUnsafe((thisVals ::: thisTpes).distinct)) + case None => Nil } - .toList) + // one or the other or both of these is non-empty + Import( + ipack, + NonEmptyList.fromListUnsafe((thisVals ::: thisTpes).distinct) + ) + }.toList) Validated.invalidNec(PackageError.UnusedImport(pack.name, badImports)) } } } - private def checkValuesHaveExportedTypes[V](pn: PackageName, exports: List[ExportedName[Referant[V]]]): ValidatedNec[PackageError, Unit] = { - val exportedTypes: List[DefinedType[V]] = exports - .iterator + private def checkValuesHaveExportedTypes[V]( + pn: PackageName, + exports: List[ExportedName[Referant[V]]] + ): ValidatedNec[PackageError, Unit] = { + val exportedTypes: List[DefinedType[V]] = exports.iterator .map(_.tag) .collect { case Referant.Constructor(dt, _) => dt - case Referant.DefinedT(dt) => dt + case Referant.DefinedT(dt) => dt } .toList .distinct @@ -183,18 +206,16 @@ object PackageCustoms { val exportedTE = TypeEnv.fromDefinitions(exportedTypes) type Exp = ExportedName[Referant[V]] - val usedTypes: Iterator[(Type.Const, Exp, Type)] = exports - .iterator + val usedTypes: Iterator[(Type.Const, Exp, Type)] = exports.iterator .flatMap { n => n.tag match { case Referant.Value(t) => Iterator.single((t, n)) - case _ => Iterator.empty + case _ => Iterator.empty } } .flatMap { case (t, n) => Type.constantsOf(t).map((_, n, t)) } .filter { case (Type.Const.Defined(p, _), _, _) => p === pn } - def errorFor(t: (Type.Const, Exp, Type)): List[PackageError] = exportedTE.toDefinedType(t._1) match { case None => @@ -202,9 +223,11 @@ object PackageCustoms { case Some(_) => Nil } - NonEmptyChain.fromChain(Chain.fromIterableOnce(usedTypes.flatMap(errorFor))) match { - case None => Validated.valid(()) + NonEmptyChain.fromChain( + Chain.fromIterableOnce(usedTypes.flatMap(errorFor)) + ) match { + case None => Validated.valid(()) case Some(nel) => Validated.invalid(nel) } } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageError.scala b/core/src/main/scala/org/bykn/bosatsu/PackageError.scala index ee1b266d6..40cf7869a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageError.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageError.scala @@ -7,7 +7,10 @@ import rankn._ import LocationMap.Colorize sealed abstract class PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize): String + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ): String } object PackageError { @@ -15,17 +18,17 @@ object PackageError { // TODO: we should use the imports in each package to talk about // types in ways that are local to that package require(pack ne null) - tpes - .iterator - .map { t => - (t, Type.fullyResolvedDocument.document(t)) - } - .toMap + tpes.iterator.map { t => + (t, Type.fullyResolvedDocument.document(t)) + }.toMap } - def nearest[A](ident: Identifier, existing: Iterable[(Identifier, A)], count: Int): List[(Identifier, A)] = - existing - .iterator + def nearest[A]( + ident: Identifier, + existing: Iterable[(Identifier, A)], + count: Int + ): List[(Identifier, A)] = + existing.iterator .map { case (i, a) => val d = EditDistance.string(ident.asString, i.asString) (i, d, a) @@ -47,24 +50,28 @@ object PackageError { def headLine(packageName: PackageName, region: Option[Region]): Doc = { val (lm, sourceName) = getMapSrc(packageName) val suffix = (region.flatMap { r => lm.toLineCol(r.start) }) match { - case Some((line, col)) => s":${line + 1}:${col + 1}" - case None => "" + case Some((line, col)) => s":${line + 1}:${col + 1}" + case None => "" } Doc.text(s"in file: $sourceName$suffix, package ${packageName.asString}") } def getMapSrc(pack: PackageName): (LocationMap, String) = sm.get(pack) match { - case None => (emptyLocMap, "") + case None => (emptyLocMap, "") case Some(found) => found } } - - case class UnknownExport[A](ex: ExportedName[A], - in: PackageName, - lets: List[(Identifier.Bindable, RecursionKind, TypedExpr[Declaration])]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class UnknownExport[A]( + ex: ExportedName[A], + in: PackageName, + lets: List[(Identifier.Bindable, RecursionKind, TypedExpr[Declaration])] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, sourceName) = sourceMap.getMapSrc(in) val header = s"in $sourceName unknown export ${ex.name.sourceCodeRepr}" @@ -73,7 +80,10 @@ object PackageError { val candidates = nearest(ex.name, candidateMap, 3) .map { case (n, r) => - val pos = lm.toLineCol(r.start).map { case (l, c) => s":${l + 1}:${c + 1}" }.getOrElse("") + val pos = lm + .toLineCol(r.start) + .map { case (l, c) => s":${l + 1}:${c + 1}" } + .getOrElse("") s"${n.asString}$pos" } val candstr = candidates.mkString("\n\t", "\n\t", "\n") @@ -84,34 +94,51 @@ object PackageError { } } - case class PrivateTypeEscape[A](ex: ExportedName[A], - exType: Type, - in: PackageName, - privateType: Type.Const) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class PrivateTypeEscape[A]( + ex: ExportedName[A], + exType: Type, + in: PackageName, + privateType: Type.Const + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (_, sourceName) = sourceMap.getMapSrc(in) val pt = Type.TyConst(privateType) val tpeMap = showTypes(in, exType :: pt :: Nil) - val first = s"in $sourceName export ${ex.name.sourceCodeRepr} of type ${tpeMap(exType).render(80)}" + val first = + s"in $sourceName export ${ex.name.sourceCodeRepr} of type ${tpeMap(exType).render(80)}" if (exType == pt) { s"$first has an unexported (private) type." - } - else { + } else { s"$first references an unexported (private) type ${tpeMap(pt).render(80)}." } } } - case class UnknownImportPackage[A, B, C](pack: PackageName, fromName: PackageName) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class UnknownImportPackage[A, B, C]( + pack: PackageName, + fromName: PackageName + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (_, sourceName) = sourceMap.getMapSrc(fromName) s"in $sourceName package ${fromName.asString} imports unknown package ${pack.asString}" } } - case class DuplicatedImport(duplicates: NonEmptyList[(PackageName, ImportedName[Unit])]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = - duplicates.sortBy(_._2.localName) + case class DuplicatedImport( + duplicates: NonEmptyList[(PackageName, ImportedName[Unit])] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = + duplicates + .sortBy(_._2.localName) .toList .iterator .map { case (pack, imp) => @@ -123,55 +150,70 @@ object PackageError { // We could check if we forgot to export the name in the package and give that error case class UnknownImportName[A, B]( - in: PackageName, - importedPackage: PackageName, - letMap: Map[Identifier, Unit], - iname: ImportedName[A], - exports: List[ExportedName[B]]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { - val ipname = importedPackage - - val (_, sourceName) = sourceMap.getMapSrc(in) - letMap - .get(iname.originalName) match { - case Some(_) => - s"in $sourceName package: ${ipname.asString} has ${iname.originalName.sourceCodeRepr} but it is not exported. Add to exports" - case None => - val near = nearest(iname.originalName, letMap, 3) - .map { case (n, _) => n.sourceCodeRepr } - .mkString(" Nearest: ", ", ", "") - s"in $sourceName package: ${ipname.asString} does not have name ${iname.originalName.sourceCodeRepr}.$near" - } + in: PackageName, + importedPackage: PackageName, + letMap: Map[Identifier, Unit], + iname: ImportedName[A], + exports: List[ExportedName[B]] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { + val ipname = importedPackage + + val (_, sourceName) = sourceMap.getMapSrc(in) + letMap + .get(iname.originalName) match { + case Some(_) => + s"in $sourceName package: ${ipname.asString} has ${iname.originalName.sourceCodeRepr} but it is not exported. Add to exports" + case None => + val near = nearest(iname.originalName, letMap, 3) + .map { case (n, _) => n.sourceCodeRepr } + .mkString(" Nearest: ", ", ", "") + s"in $sourceName package: ${ipname.asString} does not have name ${iname.originalName.sourceCodeRepr}.$near" } } + } case class UnknownImportFromInterface[A, B]( - in: PackageName, - importingName: PackageName, - exportNames: List[Identifier], - iname: ImportedName[A], - exports: List[ExportedName[B]]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { - - val exportMap = exportNames.map { e => (e, ()) }.toMap - - val near = Doc.text(" Nearest: ") + - (Doc.intercalate( + in: PackageName, + importingName: PackageName, + exportNames: List[Identifier], + iname: ImportedName[A], + exports: List[ExportedName[B]] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { + + val exportMap = exportNames.map { e => (e, ()) }.toMap + + val near = Doc.text(" Nearest: ") + + (Doc + .intercalate( Doc.text(",") + Doc.line, nearest(iname.originalName, exportMap, 3) .map { ident => Doc.text(ident._1.sourceCodeRepr) } ) .nested(4) .grouped) - - (sourceMap.headLine(importingName, None) + Doc.hardLine + Doc.text( - s"does not have name ${iname.originalName}.") + near - ).render(80) - } + + (sourceMap.headLine(importingName, None) + Doc.hardLine + Doc.text( + s"does not have name ${iname.originalName}." + ) + near).render(80) } + } - case class CircularDependency[A, B, C](from: PackageName, path: NonEmptyList[PackageName]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class CircularDependency[A, B, C]( + from: PackageName, + path: NonEmptyList[PackageName] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val packs = from :: (path.toList) val msg = packs.map { p => val (_, src) = sourceMap.getMapSrc(p) @@ -182,36 +224,62 @@ object PackageError { } } - case class VarianceInferenceFailure(from: PackageName, failed: NonEmptyList[rankn.DefinedType[Unit]]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { - s"failed to infer variance in ${from.asString} of " + failed.toList.map(_.name.ident.asString).sorted.mkString(", ") + case class VarianceInferenceFailure( + from: PackageName, + failed: NonEmptyList[rankn.DefinedType[Unit]] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { + s"failed to infer variance in ${from.asString} of " + failed.toList + .map(_.name.ident.asString) + .sorted + .mkString(", ") } } - case class TypeErrorIn(tpeErr: Infer.Error, pack: PackageName) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class TypeErrorIn(tpeErr: Infer.Error, pack: PackageName) + extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) def singleToDoc(tpeErr: Infer.Error.Single): Doc = { val (teMessage, region) = tpeErr match { case Infer.Error.NotUnifiable(t0, t1, r0, r1) => val context0 = - if (r0 == r1) Doc.space // sometimes the region of the error is the same on right and left + if (r0 == r1) + Doc.space // sometimes the region of the error is the same on right and left else { - val m = lm.showRegion(r0, 2, errColor).getOrElse(Doc.str(r0)) // we should highlight the whole region + val m = lm + .showRegion(r0, 2, errColor) + .getOrElse( + Doc.str(r0) + ) // we should highlight the whole region Doc.hardLine + m + Doc.hardLine } val context1 = - lm.showRegion(r1, 2, errColor).getOrElse(Doc.str(r1)) // we should highlight the whole region + lm.showRegion(r1, 2, errColor) + .getOrElse(Doc.str(r1)) // we should highlight the whole region val fnHint = (t0, t1) match { - case (Type.RootConst(Type.FnType(_, leftSize)), - Type.RootConst(Type.FnType(_, rightSize))) => + case ( + Type.RootConst(Type.FnType(_, leftSize)), + Type.RootConst(Type.FnType(_, rightSize)) + ) => // both are functions - def args(n: Int) = if (n == 1) "one argument" else s"$n arguments" - Doc.text(s"hint: the first type is a function with ${args(leftSize)} and the second is a function with ${args(rightSize)}.") + Doc.hardLine + def args(n: Int) = + if (n == 1) "one argument" else s"$n arguments" + Doc.text( + s"hint: the first type is a function with ${args(leftSize)} and the second is a function with ${args(rightSize)}." + ) + Doc.hardLine case (Type.Fun(_, _), _) | (_, Type.Fun(_, _)) => - Doc.text("hint: this often happens when you apply the wrong number of arguments to a function.") + Doc.hardLine + Doc.text( + "hint: this often happens when you apply the wrong number of arguments to a function." + ) + Doc.hardLine case _ => Doc.empty } @@ -223,26 +291,38 @@ object PackageError { (doc, Some(r0)) case Infer.Error.VarNotInScope((_, name), scope, region) => - val ctx = lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) + val ctx = + lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) val candidates: List[String] = nearest(name, scope.map { case ((_, n), _) => (n, ()) }, 3) .map { case (n, _) => n.asString } val cmessage = - if (candidates.nonEmpty) candidates.mkString("\nClosest: ", ", ", ".\n") + if (candidates.nonEmpty) + candidates.mkString("\nClosest: ", ", ", ".\n") else "" val qname = "\"" + name.sourceCodeRepr + "\"" - (Doc.text("name ") + Doc.text(qname) + Doc.text(" unknown.") + Doc.text(cmessage) + Doc.hardLine + - ctx, Some(region)) + ( + Doc.text("name ") + Doc.text(qname) + Doc.text(" unknown.") + Doc + .text(cmessage) + Doc.hardLine + + ctx, + Some(region) + ) case Infer.Error.SubsumptionCheckFailure(t0, t1, r0, r1, _) => val context0 = - if (r0 == r1) Doc.space // sometimes the region of the error is the same on right and left + if (r0 == r1) + Doc.space // sometimes the region of the error is the same on right and left else { - val m = lm.showRegion(r0, 2, errColor).getOrElse(Doc.str(r0)) // we should highlight the whole region + val m = lm + .showRegion(r0, 2, errColor) + .getOrElse( + Doc.str(r0) + ) // we should highlight the whole region Doc.hardLine + m + Doc.hardLine } val context1 = - lm.showRegion(r1, 2, errColor).getOrElse(Doc.str(r1)) // we should highlight the whole region + lm.showRegion(r1, 2, errColor) + .getOrElse(Doc.str(r1)) // we should highlight the whole region val tmap = showTypes(pack, List(t0, t1)) val doc = Doc.text("type ") + tmap(t0) + context0 + @@ -250,8 +330,12 @@ object PackageError { context1 (doc, Some(r0)) - case uc@Infer.Error.UnknownConstructor((_, n), region, _) => - val near = nearest(n, uc.knownConstructors.map { case (_, n) => (n, ()) }.toMap, 3) + case uc @ Infer.Error.UnknownConstructor((_, n), region, _) => + val near = nearest( + n, + uc.knownConstructors.map { case (_, n) => (n, ()) }.toMap, + 3 + ) .map { case (n, _) => n.asString } val nearStr = @@ -259,7 +343,10 @@ object PackageError { else near.mkString(", nearest: ", ", ", "") val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) // we should highlight the whole region + lm.showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region val doc = Doc.text("unknown constructor ") + Doc.text(n.asString) + Doc.text(nearStr) + Doc.hardLine + context @@ -267,9 +354,14 @@ object PackageError { case Infer.Error.KindCannotTyApply(applied, region) => val tmap = showTypes(pack, applied :: Nil) val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) // we should highlight the whole region + lm.showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region val doc = Doc.text("kind error: for kind of the left of ") + - tmap(applied) + Doc.text(" is *. Cannot apply to kind *.") + Doc.hardLine + + tmap(applied) + Doc.text( + " is *. Cannot apply to kind *." + ) + Doc.hardLine + context (doc, Some(region)) @@ -278,78 +370,112 @@ object PackageError { val rightT = applied.arg val tmap = showTypes(pack, applied :: leftT :: rightT :: Nil) val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) - val doc = Doc.text("kind error: ") + Doc.text("the type: ") + tmap(applied) + - Doc.text(" is invalid because the left ") + tmap(leftT) + Doc.text(" has kind ") + Kind.toDoc(leftK) + - Doc.text(" and the right ") + tmap(rightT) + Doc.text(" has kind ") + Kind.toDoc(rightK) + + lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) + val doc = Doc.text("kind error: ") + Doc.text("the type: ") + tmap( + applied + ) + + Doc.text(" is invalid because the left ") + tmap(leftT) + Doc + .text(" has kind ") + Kind.toDoc(leftK) + + Doc.text(" and the right ") + tmap(rightT) + Doc.text( + " has kind " + ) + Kind.toDoc(rightK) + Doc.text(s" but left cannot accept the kind of the right:") + Doc.hardLine + context (doc, Some(region)) - case Infer.Error.KindMismatch(meta, metaK, rightT, rightK, metaR, rightR) => + case Infer.Error.KindMismatch( + meta, + metaK, + rightT, + rightK, + metaR, + rightR + ) => val tmap = showTypes(pack, meta :: rightT :: Nil) val context0 = - lm.showRegion(metaR, 2, errColor).getOrElse(Doc.str(metaR)) // we should highlight the whole region + lm.showRegion(metaR, 2, errColor) + .getOrElse( + Doc.str(metaR) + ) // we should highlight the whole region val context1 = { if (metaR != rightR) { Doc.text(" at: ") + Doc.hardLine + - lm.showRegion(rightR, 2, errColor).getOrElse(Doc.str(rightR)) + // we should highlight the whole region - Doc.hardLine - } - else { + lm.showRegion(rightR, 2, errColor) + .getOrElse( + Doc.str(rightR) + ) + // we should highlight the whole region + Doc.hardLine + } else { Doc.empty } } - val doc = Doc.text("kind error: ") + Doc.text("the type: ") + tmap(meta) + - Doc.text(" of kind: ") + Kind.toDoc(metaK) + Doc.text(" at: ") + Doc.hardLine + - context0 + Doc.hardLine + Doc.hardLine + - Doc.text("cannot be unified with the type ") + tmap(rightT) + - Doc.text(" of kind: ") + Kind.toDoc(rightK) + context1 + - Doc.hardLine + - Doc.text("because the first kind does not subsume the second.") + val doc = + Doc.text("kind error: ") + Doc.text("the type: ") + tmap(meta) + + Doc.text(" of kind: ") + Kind.toDoc(metaK) + Doc.text( + " at: " + ) + Doc.hardLine + + context0 + Doc.hardLine + Doc.hardLine + + Doc.text("cannot be unified with the type ") + tmap(rightT) + + Doc.text(" of kind: ") + Kind.toDoc(rightK) + context1 + + Doc.hardLine + + Doc.text("because the first kind does not subsume the second.") (doc, Some(metaR)) case Infer.Error.UnexpectedMeta(meta, in, metaR, rightR) => val tymeta = Type.TyMeta(meta) val tmap = showTypes(pack, tymeta :: in :: Nil) val context0 = - lm.showRegion(metaR, 2, errColor).getOrElse(Doc.str(metaR)) // we should highlight the whole region + lm.showRegion(metaR, 2, errColor) + .getOrElse( + Doc.str(metaR) + ) // we should highlight the whole region val context1 = { if (metaR != rightR) { Doc.text(" at: ") + Doc.hardLine + - lm.showRegion(rightR, 2, errColor).getOrElse(Doc.str(rightR)) + // we should highlight the whole region - Doc.hardLine - } - else { + lm.showRegion(rightR, 2, errColor) + .getOrElse( + Doc.str(rightR) + ) + // we should highlight the whole region + Doc.hardLine + } else { Doc.empty } } - val doc = Doc.text("Unexpected unknown: the type: ") + tmap(tymeta) + - Doc.text(" of kind: ") + Kind.toDoc(meta.kind) + Doc.text(" at: ") + Doc.hardLine + - context0 + Doc.hardLine + Doc.hardLine + - Doc.text("inside the type ") + tmap(in) + context1 + - Doc.hardLine + - Doc.text("this sometimes happens when a function arg has been omitted, or an illegal recursive type or function.") + val doc = + Doc.text("Unexpected unknown: the type: ") + tmap(tymeta) + + Doc.text(" of kind: ") + Kind.toDoc(meta.kind) + Doc.text( + " at: " + ) + Doc.hardLine + + context0 + Doc.hardLine + Doc.hardLine + + Doc.text("inside the type ") + tmap(in) + context1 + + Doc.hardLine + + Doc.text( + "this sometimes happens when a function arg has been omitted, or an illegal recursive type or function." + ) (doc, Some(metaR)) - case Infer.Error.NotPolymorphicEnough(tpe, _, _, region) => + case Infer.Error.NotPolymorphicEnough(tpe, _, _, region) => val tmap = showTypes(pack, tpe :: Nil) val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) + lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) - (Doc.text("the type ") + tmap(tpe) + Doc.text(" is not polymorphic enough") + Doc.hardLine + context, Some(region)) - case Infer.Error.ArityMismatch(leftA, leftR, rightA, rightR) => + ( + Doc.text("the type ") + tmap(tpe) + Doc.text( + " is not polymorphic enough" + ) + Doc.hardLine + context, + Some(region) + ) + case Infer.Error.ArityMismatch(leftA, leftR, rightA, rightR) => val context0 = - lm.showRegion(leftR, 2, errColor).getOrElse(Doc.str(leftR)) + lm.showRegion(leftR, 2, errColor).getOrElse(Doc.str(leftR)) val context1 = { if (leftR != rightR) { Doc.text(" at: ") + Doc.hardLine + - lm.showRegion(rightR, 2, errColor).getOrElse(Doc.str(rightR)) - } - else { + lm.showRegion(rightR, 2, errColor).getOrElse(Doc.str(rightR)) + } else { Doc.empty } } @@ -357,68 +483,103 @@ object PackageError { def args(n: Int) = if (n == 1) "one argument" else s"$n arguments" - (Doc.text(s"function with ${args(leftA)} at:") + Doc.hardLine + context0 + - Doc.text(s" does not match function with ${args(rightA)}") + context1, Some(leftR)) + ( + Doc.text( + s"function with ${args(leftA)} at:" + ) + Doc.hardLine + context0 + + Doc.text( + s" does not match function with ${args(rightA)}" + ) + context1, + Some(leftR) + ) case Infer.Error.ArityTooLarge(found, max, region) => val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) + lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) - (Doc.text(s"function with $found arguments is too large. Maximum function argument count is $max.") + Doc.hardLine + context, - Some(region)) + ( + Doc.text( + s"function with $found arguments is too large. Maximum function argument count is $max." + ) + Doc.hardLine + context, + Some(region) + ) case Infer.Error.UnexpectedBound(bound, _, reg, _) => val tyvar = Type.TyVar(bound) val tmap = showTypes(pack, tyvar :: Nil) val context = - lm.showRegion(reg, 2, errColor).getOrElse(Doc.str(reg)) + lm.showRegion(reg, 2, errColor).getOrElse(Doc.str(reg)) - (Doc.text("unexpected bound: ") + tmap(tyvar) + Doc.hardLine + context, Some(reg)) + ( + Doc.text("unexpected bound: ") + tmap( + tyvar + ) + Doc.hardLine + context, + Some(reg) + ) case Infer.Error.UnionPatternBindMismatch(_, names, region) => val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) + lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) val uniqueSets = graph.Tree.distinctBy(names)(_.toSet) - val uniqs = Doc.intercalate(Doc.char(',') + Doc.line, + val uniqs = Doc.intercalate( + Doc.char(',') + Doc.line, uniqueSets.toList.map { names => - Doc.text(names.iterator.map(_.sourceCodeRepr).mkString("[", ", ", "]")) + Doc.text( + names.iterator.map(_.sourceCodeRepr).mkString("[", ", ", "]") + ) } ) - (Doc.text("not all union elements bind the same names: ") + - (Doc.line + uniqs + context).nested(4).grouped, - Some(region)) + ( + Doc.text("not all union elements bind the same names: ") + + (Doc.line + uniqs + context).nested(4).grouped, + Some(region) + ) case Infer.Error.UnknownDefined(const, reg) => val tpe = Type.TyConst(const) val tmap = showTypes(pack, tpe :: Nil) val context = - lm.showRegion(reg, 2, errColor).getOrElse(Doc.str(reg)) + lm.showRegion(reg, 2, errColor).getOrElse(Doc.str(reg)) - (Doc.text("unknown type: ") + tmap(tpe) + Doc.hardLine + context, Some(reg)) + ( + Doc.text("unknown type: ") + tmap(tpe) + Doc.hardLine + context, + Some(reg) + ) case ie: Infer.Error.InternalError => val context = - lm.showRegion(ie.region, 2, errColor).getOrElse(Doc.str(ie.region)) + lm.showRegion(ie.region, 2, errColor) + .getOrElse(Doc.str(ie.region)) (Doc.text(ie.message) + Doc.hardLine + context, Some(ie.region)) } - val h = sourceMap.headLine(pack, region) + val h = sourceMap.headLine(pack, region) (h + Doc.hardLine + teMessage) } - + val finalDoc = tpeErr match { case s: Infer.Error.Single => singleToDoc(s) - case c@Infer.Error.Combine(_, _) => + case c @ Infer.Error.Combine(_, _) => val twoLines = Doc.hardLine + Doc.hardLine - c.flatten.iterator.map(singleToDoc).reduce((a, b) => a + (twoLines + b)) + c.flatten.iterator + .map(singleToDoc) + .reduce((a, b) => a + (twoLines + b)) } finalDoc.render(80) } } - case class SourceConverterErrorIn(err: SourceConverter.Error, pack: PackageName) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class SourceConverterErrorIn( + err: SourceConverter.Error, + pack: PackageName + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) val msg = { val context = lm.showRegion(err.region, 2, errColor) - .getOrElse(Doc.str(err.region)) // we should highlight the whole region + .getOrElse( + Doc.str(err.region) + ) // we should highlight the whole region Doc.text(err.message) + Doc.hardLine + context } @@ -428,16 +589,27 @@ object PackageError { } } - case class TotalityCheckError(pack: PackageName, err: TotalityCheck.ExprError[Declaration]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class TotalityCheckError( + pack: PackageName, + err: TotalityCheck.ExprError[Declaration] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) val region = err.matchExpr.tag.region val context1 = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) // we should highlight the whole region + lm.showRegion(region, 2, errColor) + .getOrElse(Doc.str(region)) // we should highlight the whole region val teMessage = err match { case TotalityCheck.NonTotalMatch(_, missing) => - val allTypes = missing.traverse(_.traverseType { t => Writer(Chain.one(t), ()) }) - .run._1.toList.distinct + val allTypes = missing + .traverse(_.traverseType { t => Writer(Chain.one(t), ()) }) + .run + ._1 + .toList + .distinct val showT = showTypes(pack, allTypes) val doc = Pattern.compiledDocument(Document.instance[Type] { t => @@ -445,11 +617,17 @@ object PackageError { }) Doc.text("non-total match, missing: ") + - (Doc.intercalate(Doc.char(',') + Doc.lineOrSpace, - missing.toList.map(doc.document(_)))) + (Doc.intercalate( + Doc.char(',') + Doc.lineOrSpace, + missing.toList.map(doc.document(_)) + )) case TotalityCheck.UnreachableBranches(_, unreachableBranches) => - val allTypes = unreachableBranches.traverse(_.traverseType { t => Writer(Chain.one(t), ()) }) - .run._1.toList.distinct + val allTypes = unreachableBranches + .traverse(_.traverseType { t => Writer(Chain.one(t), ()) }) + .run + ._1 + .toList + .distinct val showT = showTypes(pack, allTypes) val doc = Pattern.compiledDocument(Document.instance[Type] { t => @@ -457,13 +635,17 @@ object PackageError { }) Doc.text("unreachable branches: ") + - (Doc.intercalate(Doc.char(',') + Doc.lineOrSpace, - unreachableBranches.toList.map(doc.document(_)))) + (Doc.intercalate( + Doc.char(',') + Doc.lineOrSpace, + unreachableBranches.toList.map(doc.document(_)) + )) case TotalityCheck.InvalidPattern(_, err) => import TotalityCheck._ err match { case ArityMismatch((_, n), _, _, exp, found) => - Doc.text(s"arity mismatch: ${n.asString} expected $exp parameters, found $found") + Doc.text( + s"arity mismatch: ${n.asString} expected $exp parameters, found $found" + ) case UnknownConstructor((_, n), _, _) => Doc.text(s"unknown constructor: ${n.asString}") case InvalidStrPat(pat, _) => @@ -472,8 +654,10 @@ object PackageError { Doc.text(" (adjacent string bindings aren't allowed)") case MultipleSplicesInPattern(_, _) => // TODO: get printing of compiled patterns working well - //val docp = Document[Pattern.Parsed].document(Pattern.ListPat(pat)) + - Doc.text("multiple splices in pattern, only one per match allowed") + // val docp = Document[Pattern.Parsed].document(Pattern.ListPat(pat)) + + Doc.text( + "multiple splices in pattern, only one per match allowed" + ) } } val prefix = sourceMap.headLine(pack, Some(region)) @@ -484,27 +668,43 @@ object PackageError { } } - case class UnusedLetError(pack: PackageName, errs: NonEmptyList[(Identifier.Bindable, Region)]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class UnusedLetError( + pack: PackageName, + errs: NonEmptyList[(Identifier.Bindable, Region)] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) val docs = errs .sortBy(_._2) .map { case (bn, region) => - val rdoc = lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) // we should highlight the whole region + val rdoc = lm + .showRegion(region, 2, errColor) + .getOrElse(Doc.str(region)) // we should highlight the whole region val message = Doc.text("unused let binding: " + bn.sourceCodeRepr) message + Doc.hardLine + rdoc } val packDoc = sourceMap.headLine(pack, Some(errs.head._2)) val line2 = Doc.hardLine + Doc.hardLine - (packDoc + (line2 + Doc.intercalate(line2, docs.toList)).nested(2)).render(80) + (packDoc + (line2 + Doc.intercalate(line2, docs.toList)).nested(2)) + .render(80) } } - case class RecursionError(pack: PackageName, err: DefRecursionCheck.RecursionError) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class RecursionError( + pack: PackageName, + err: DefRecursionCheck.RecursionError + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) - val ctx = lm.showRegion(err.region, 2, errColor) + val ctx = lm + .showRegion(err.region, 2, errColor) .getOrElse(Doc.str(err.region)) // we should highlight the whole region val errMessage = err.message // TODO use the sourceMap/regions in RecursionError @@ -516,17 +716,22 @@ object PackageError { } } - case class DuplicatedPackageError(dups: NonEmptyMap[PackageName, (String, NonEmptyList[String])]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class DuplicatedPackageError( + dups: NonEmptyMap[PackageName, (String, NonEmptyList[String])] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val packDoc = Doc.text("package ") val dupInDoc = Doc.text(" duplicated in ") - val dupMessages = dups - .toSortedMap + val dupMessages = dups.toSortedMap .map { case (pname, (one, nelist)) => - val dupsrcs = Doc.intercalate(Doc.comma + Doc.lineOrSpace, - (one :: nelist.toList) - .sorted - .map(Doc.text(_)) + val dupsrcs = Doc + .intercalate( + Doc.comma + Doc.lineOrSpace, + (one :: nelist.toList).sorted + .map(Doc.text(_)) ) .nested(4) packDoc + Doc.text(pname.asString) + dupInDoc + dupsrcs @@ -536,36 +741,54 @@ object PackageError { } } - case class KindInferenceError(pack: PackageName, kindError: KindFormula.Error, regions: Map[Type.Const.Defined, Region]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class KindInferenceError( + pack: PackageName, + kindError: KindFormula.Error, + regions: Map[Type.Const.Defined, Region] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) val region = regions(kindError.dt.toTypeConst) - val ctx = lm.showRegion(region, 2, errColor) + val ctx = lm + .showRegion(region, 2, errColor) .getOrElse(Doc.str(region)) // we should highlight the whole region val prefix = sourceMap.headLine(pack, Some(region)) val message = kindError match { - case KindFormula.Error.Unsatisfiable(_, _, _, _) => - // TODO: would be good to give a more precise problem, e.g. which - // type parameters are the problem. + case KindFormula.Error.Unsatisfiable(_, _, _, _) => + // TODO: would be good to give a more precise problem, e.g. which + // type parameters are the problem. Doc.text("could not solve for valid variances") case KindFormula.Error.FromShapeError(se) => se match { case Shape.UnificationError(_, cons, left, right) => - Doc.text("shape error: expected ") + Shape.shapeDoc(left) + Doc.text(" and ") + Shape.shapeDoc(right) + - Doc.text(s" to match in the constructor ${cons.name.sourceCodeRepr}") + Doc.hardLine + Doc.text("shape error: expected ") + Shape.shapeDoc(left) + Doc + .text(" and ") + Shape.shapeDoc(right) + + Doc.text( + s" to match in the constructor ${cons.name.sourceCodeRepr}" + ) + Doc.hardLine case Shape.ShapeMismatch(_, cons, outer, tyApp, right) => val tmap = showTypes(pack, outer :: tyApp :: Nil) val typeDoc = - if (outer != tyApp) (tmap(outer) + Doc.text(" at application ") + tmap(tyApp)) + if (outer != tyApp) + (tmap(outer) + Doc.text(" at application ") + tmap(tyApp)) else tmap(outer) - Doc.text("shape error: expected ") + Shape.shapeDoc(right) + Doc.text(" -> ?") + Doc.text(" but found * ") + - Doc.text(s"in the constructor ${cons.name.sourceCodeRepr} inside type ") + - typeDoc + + Doc.text("shape error: expected ") + Shape.shapeDoc(right) + Doc + .text(" -> ?") + Doc.text(" but found * ") + + Doc.text( + s"in the constructor ${cons.name.sourceCodeRepr} inside type " + ) + + typeDoc + Doc.hardLine case Shape.FinishFailure(dt, left, right) => - val tdoc = showTypes(pack, dt.toTypeTyConst :: Nil)(dt.toTypeTyConst) - Doc.text("in type ") + tdoc + Doc.text(" could not unify shapes: ") + Shape.shapeDoc(left) + Doc.text(" and ") + + val tdoc = + showTypes(pack, dt.toTypeTyConst :: Nil)(dt.toTypeTyConst) + Doc.text("in type ") + tdoc + Doc.text( + " could not unify shapes: " + ) + Shape.shapeDoc(left) + Doc.text(" and ") + Shape.shapeDoc(right) case Shape.ShapeLoop(dt, tpe, _) => val tpe2 = tpe match { @@ -574,24 +797,30 @@ object PackageError { } val tdocs = showTypes(pack, dt.toTypeTyConst :: tpe2 :: Nil) - Doc.text("in type ") + tdocs(dt.toTypeTyConst) + Doc.text(" cyclic dependency encountered in ") + + Doc.text("in type ") + tdocs(dt.toTypeTyConst) + Doc.text( + " cyclic dependency encountered in " + ) + tdocs(tpe2) case Shape.UnboundVar(dt, cfn, v) => val tpe2 = Type.TyVar(v) val tdocs = showTypes(pack, dt.toTypeTyConst :: tpe2 :: Nil) - val cfnMsg = if (dt.isStruct) Doc.empty else { - Doc.text(s" in constructor ${cfn.name.sourceCodeRepr} ") - } + val cfnMsg = + if (dt.isStruct) Doc.empty + else { + Doc.text(s" in constructor ${cfn.name.sourceCodeRepr} ") + } Doc.text("in type ") + tdocs(dt.toTypeTyConst) + Doc.text(" unbound type variable ") + tdocs(tpe2) + cfnMsg case Shape.UnknownConst(dt, cfn, c) => val tpe2 = Type.TyConst(c) val tdocs = showTypes(pack, dt.toTypeTyConst :: tpe2 :: Nil) - val cfnMsg = if (dt.isStruct) Doc.empty else { - Doc.text(s" in constructor ${cfn.name.sourceCodeRepr} ") - } + val cfnMsg = + if (dt.isStruct) Doc.empty + else { + Doc.text(s" in constructor ${cfn.name.sourceCodeRepr} ") + } Doc.text("in type ") + tdocs(dt.toTypeTyConst) + Doc.text(" unknown type ") + tdocs(tpe2) + cfnMsg } @@ -600,16 +829,26 @@ object PackageError { } } - case class UnusedImport(inPack: PackageName, badImports: NonEmptyList[Import[PackageName, Unit]]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class UnusedImport( + inPack: PackageName, + badImports: NonEmptyList[Import[PackageName, Unit]] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val prefix = sourceMap.headLine(inPack, None) - val di = (Doc.hardLine + Doc.intercalate(Doc.hardLine, - badImports.toList.map(Document[Import[PackageName, Unit]].document(_)) - )) + val di = (Doc.hardLine + Doc.intercalate( + Doc.hardLine, + badImports.toList.map(Document[Import[PackageName, Unit]].document(_)) + )) .nested(2) - val imports = if (badImports.tail.lengthCompare(0) == 0) "import" else "imports" - (prefix + Doc.hardLine + Doc.text(s"unused $imports of:") + di + Doc.hardLine).render(80) + val imports = + if (badImports.tail.lengthCompare(0) == 0) "import" else "imports" + (prefix + Doc.hardLine + Doc.text( + s"unused $imports of:" + ) + di + Doc.hardLine).render(80) } } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala b/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala index 37a584bd2..bedf25506 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala @@ -2,7 +2,15 @@ package org.bykn.bosatsu import org.bykn.bosatsu.graph.Memoize import cats.{Foldable, Monad, Show} -import cats.data.{Ior, IorT, NonEmptyList, NonEmptyMap, Validated, ValidatedNel, ReaderT} +import cats.data.{ + Ior, + IorT, + NonEmptyList, + NonEmptyMap, + Validated, + ValidatedNel, + ReaderT +} import scala.collection.immutable.SortedMap import Identifier.Constructor @@ -12,45 +20,56 @@ import rankn.{DataRepr, TypeEnv} import cats.implicits._ -case class PackageMap[A, B, C, +D](toMap: SortedMap[PackageName, Package[A, B, C, D]]) { +case class PackageMap[A, B, C, +D]( + toMap: SortedMap[PackageName, Package[A, B, C, D]] +) { def +[D1 >: D](pack: Package[A, B, C, D1]): PackageMap[A, B, C, D1] = PackageMap(toMap + (pack.name -> pack)) - def ++[D1 >: D](packs: Iterable[Package[A, B, C, D1]]): PackageMap[A, B, C, D1] = + def ++[D1 >: D]( + packs: Iterable[Package[A, B, C, D1]] + ): PackageMap[A, B, C, D1] = packs.foldLeft(this: PackageMap[A, B, C, D1])(_ + _) - def getDataRepr(implicit ev: D <:< Program[TypeEnv[Any], Any, Any]): (PackageName, Constructor) => Option[DataRepr] = { - (pname, cons) => - toMap.get(pname) - .flatMap { pack => - ev(pack.program) - .types - .getConstructor(pname, cons) - .map(_._1.dataRepr(cons)) - } - } - - def allExternals(implicit ev: D <:< Program[TypeEnv[Any], Any, Any]): Map[PackageName, List[Identifier.Bindable]] = + def getDataRepr(implicit + ev: D <:< Program[TypeEnv[Any], Any, Any] + ): (PackageName, Constructor) => Option[DataRepr] = { (pname, cons) => toMap - .iterator - .map { case (name, pack) => - (name, ev(pack.program).externalDefs) + .get(pname) + .flatMap { pack => + ev(pack.program).types + .getConstructor(pname, cons) + .map(_._1.dataRepr(cons)) } - .toMap + } + + def allExternals(implicit + ev: D <:< Program[TypeEnv[Any], Any, Any] + ): Map[PackageName, List[Identifier.Bindable]] = + toMap.iterator.map { case (name, pack) => + (name, ev(pack.program).externalDefs) + }.toMap } object PackageMap { def empty[A, B, C, D]: PackageMap[A, B, C, D] = PackageMap(SortedMap.empty) - def fromIterable[A, B, C, D](ps: Iterable[Package[A, B, C, D]]): PackageMap[A, B, C, D] = + def fromIterable[A, B, C, D]( + ps: Iterable[Package[A, B, C, D]] + ): PackageMap[A, B, C, D] = empty[A, B, C, D] ++ ps import Package.FixPackage type MapF3[A, B, C] = PackageMap[FixPackage[A, B, C], A, B, C] type MapF2[A, B] = MapF3[A, A, B] - type ParsedImp = PackageMap[PackageName, Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])] + type ParsedImp = PackageMap[ + PackageName, + Unit, + Unit, + (List[Statement], ImportMap[PackageName, Unit]) + ] type Resolved = MapF2[Unit, (List[Statement], ImportMap[PackageName, Unit])] type Typed[+T] = PackageMap[ Package.Interface, @@ -62,7 +81,7 @@ object PackageMap { Any ] ] - + type SourceMap = Map[PackageName, (LocationMap, String)] // convenience for type inference @@ -70,65 +89,97 @@ object PackageMap { type Inferred = Typed[Declaration] - /** - * This builds a DAG of actual packages where names have been replaced by the fully resolved - * packages - */ - def resolvePackages[A, B, C](map: PackageMap[PackageName, A, B, C], ifs: List[Package.Interface]): ValidatedNel[PackageError, MapF3[A, B, C]] = { + /** This builds a DAG of actual packages where names have been replaced by the + * fully resolved packages + */ + def resolvePackages[A, B, C]( + map: PackageMap[PackageName, A, B, C], + ifs: List[Package.Interface] + ): ValidatedNel[PackageError, MapF3[A, B, C]] = { val interfaceMap = ifs.iterator.map { iface => (iface.name, iface) }.toMap def getPackage( - i: Import[PackageName, A], - from: Package[PackageName, A, B, C]): ValidatedNel[PackageError, Import[Either[Package.Interface, Package[PackageName, A, B, C]], A]] = - map.toMap.get(i.pack) match { - case Some(pack) => Validated.valid(Import(Right(pack), i.items)) - case None => - interfaceMap.get(i.pack) match { - case Some(iface) => - Validated.valid(Import(Left(iface), i.items)) - case None => - Validated.invalidNel(PackageError.UnknownImportPackage(i.pack, from.name)) - } - } + i: Import[PackageName, A], + from: Package[PackageName, A, B, C] + ): ValidatedNel[PackageError, Import[ + Either[Package.Interface, Package[PackageName, A, B, C]], + A + ]] = + map.toMap.get(i.pack) match { + case Some(pack) => Validated.valid(Import(Right(pack), i.items)) + case None => + interfaceMap.get(i.pack) match { + case Some(iface) => + Validated.valid(Import(Left(iface), i.items)) + case None => + Validated.invalidNel( + PackageError.UnknownImportPackage(i.pack, from.name) + ) + } + } type PackageFix = Package[FixPackage[A, B, C], A, B, C] // We use the ReaderT to build the list of imports we are on // to detect circular dependencies, if the current package imports itself transitively we // want to report the full path - val step: Package[PackageName, A, B, C] => ReaderT[Either[NonEmptyList[PackageError], *], List[PackageName], PackageFix] = - Memoize.memoizeDagHashed[Package[PackageName, A, B, C], ReaderT[Either[NonEmptyList[PackageError], *], List[PackageName], PackageFix]] { (p, rec) => - val edeps = ReaderT.ask[Either[NonEmptyList[PackageError], *], List[PackageName]] - .flatMapF { - case nonE@(h :: tail) if nonE.contains(p.name) => - Left(NonEmptyList.of(PackageError.CircularDependency(p.name, NonEmptyList(h, tail)))) - case _ => - val deps = p.imports.traverse(getPackage(_, p)) // the packages p depends on - deps.toEither - } - - edeps - .flatMap { (deps: List[Import[Either[Package.Interface, Package[PackageName, A, B, C]], A]]) => - deps.traverse { i => - i.pack match { - case Right(pack) => - rec(pack) - .local[List[PackageName]](p.name :: _) // add this package into the path of all the deps - .map { p => Import(Package.fix[A, B, C](Right(p)), i.items) } - case Left(iface) => - ReaderT.pure[ - Either[NonEmptyList[PackageError], *], - List[PackageName], - Import[FixPackage[A, B, C], A]](Import(Package.fix[A, B, C](Left(iface)), i.items)) - } + val step: Package[PackageName, A, B, C] => ReaderT[Either[NonEmptyList[ + PackageError + ], *], List[PackageName], PackageFix] = + Memoize.memoizeDagHashed[Package[PackageName, A, B, C], ReaderT[ + Either[NonEmptyList[PackageError], *], + List[PackageName], + PackageFix + ]] { (p, rec) => + val edeps = ReaderT + .ask[Either[NonEmptyList[PackageError], *], List[PackageName]] + .flatMapF { + case nonE @ (h :: tail) if nonE.contains(p.name) => + Left( + NonEmptyList.of( + PackageError.CircularDependency(p.name, NonEmptyList(h, tail)) + ) + ) + case _ => + val deps = p.imports.traverse( + getPackage(_, p) + ) // the packages p depends on + deps.toEither } - .map { imports => - Package(p.name, imports, p.exports, p.program) + + edeps + .flatMap { + (deps: List[Import[ + Either[Package.Interface, Package[PackageName, A, B, C]], + A + ]]) => + deps + .traverse { i => + i.pack match { + case Right(pack) => + rec(pack) + .local[List[PackageName]]( + p.name :: _ + ) // add this package into the path of all the deps + .map { p => + Import(Package.fix[A, B, C](Right(p)), i.items) + } + case Left(iface) => + ReaderT.pure[Either[NonEmptyList[PackageError], *], List[ + PackageName + ], Import[FixPackage[A, B, C], A]]( + Import(Package.fix[A, B, C](Left(iface)), i.items) + ) + } + } + .map { imports => + Package(p.name, imports, p.exports, p.program) + } } - } - } + } type M = SortedMap[PackageName, PackageFix] - val r: ReaderT[Either[NonEmptyList[PackageError], *], List[PackageName], M] = + val r + : ReaderT[Either[NonEmptyList[PackageError], *], List[PackageName], M] = map.toMap.traverse(step) // we start with no imports on @@ -137,41 +188,69 @@ object PackageMap { m.map(PackageMap(_)).toValidated } - /** - * Convenience method to create a PackageMap then resolve it - */ - def resolveAll[A: Show](ps: List[(A, Package.Parsed)], ifs: List[Package.Interface]): Ior[NonEmptyList[PackageError], Resolved] = { + /** Convenience method to create a PackageMap then resolve it + */ + def resolveAll[A: Show]( + ps: List[(A, Package.Parsed)], + ifs: List[Package.Interface] + ): Ior[NonEmptyList[PackageError], Resolved] = { type AP = (A, Package.Parsed) - val (nonUnique, unique): (SortedMap[PackageName, (AP, NonEmptyList[AP])], SortedMap[PackageName, AP]) = + val (nonUnique, unique): ( + SortedMap[PackageName, (AP, NonEmptyList[AP])], + SortedMap[PackageName, AP] + ) = NonEmptyList.fromList(ps) match { case Some(neps) => - CollectionUtils.uniqueByKey(neps)(_._2.name) + CollectionUtils + .uniqueByKey(neps)(_._2.name) .fold( { a => (a.toSortedMap, SortedMap.empty[PackageName, AP]) }, - { b => (SortedMap.empty[PackageName, (AP, NonEmptyList[AP])], b.toSortedMap) }, + { b => + ( + SortedMap.empty[PackageName, (AP, NonEmptyList[AP])], + b.toSortedMap + ) + }, { (a, b) => (a.toSortedMap, b.toSortedMap) } ) case None => - (SortedMap.empty[PackageName, (AP, NonEmptyList[AP])], SortedMap.empty[PackageName, AP]) + ( + SortedMap.empty[PackageName, (AP, NonEmptyList[AP])], + SortedMap.empty[PackageName, AP] + ) } - def toProg(p: Package.Parsed): - (Option[PackageError], - Package[PackageName, Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])]) = { + def toProg(p: Package.Parsed): ( + Option[PackageError], + Package[ + PackageName, + Unit, + Unit, + (List[Statement], ImportMap[PackageName, Unit]) + ] + ) = { val (errs0, imap) = ImportMap.fromImports(p.imports) val errs = - NonEmptyList.fromList(errs0) + NonEmptyList + .fromList(errs0) .map(PackageError.DuplicatedImport) (errs, p.mapProgram((_, imap))) } // we know all the package names are unique here - def foldMap(m: Map[PackageName, (A, Package.Parsed)]): (List[PackageError], PackageMap.ParsedImp) = { + def foldMap( + m: Map[PackageName, (A, Package.Parsed)] + ): (List[PackageError], PackageMap.ParsedImp) = { val initPm = PackageMap - .empty[PackageName, Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])] + .empty[ + PackageName, + Unit, + Unit, + (List[Statement], ImportMap[PackageName, Unit]) + ] m.iterator.foldLeft((List.empty[PackageError], initPm)) { case ((errs, pm), (_, (_, pack))) => @@ -195,9 +274,13 @@ object PackageMap { NonEmptyMap.fromMap(nonUnique) match { case Some(nenu) => val paths = nenu.map { case ((a, _), rest) => - (a.show, rest.map(_._1.show)) + (a.show, rest.map(_._1.show)) } - Ior.left(NonEmptyList.one[PackageError](PackageError.DuplicatedPackageError(paths))) + Ior.left( + NonEmptyList.one[PackageError]( + PackageError.DuplicatedPackageError(paths) + ) + ) case None => Ior.right(()) } @@ -205,10 +288,11 @@ object PackageMap { (nuEr, check, res.toIor).parMapN { (_, _, r) => r } } - /** - * Infer all the types in a resolved PackageMap - */ - def inferAll(ps: Resolved)(implicit cpuEC: Par.EC): Ior[NonEmptyList[PackageError], Inferred] = { + /** Infer all the types in a resolved PackageMap + */ + def inferAll( + ps: Resolved + )(implicit cpuEC: Par.EC): Ior[NonEmptyList[PackageError], Inferred] = { import Par.F @@ -217,48 +301,65 @@ object PackageMap { FixPackage[Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])], Unit, Unit, - (List[Statement], ImportMap[PackageName, Unit])] + (List[Statement], ImportMap[PackageName, Unit]) + ] type FutVal[A] = IorT[F, NonEmptyList[PackageError], A] /* * We memoize this function to avoid recomputing diamond dependencies */ - val infer0: ResolvedU => Par.F[Ior[NonEmptyList[PackageError], (TypeEnv[Kind.Arg], Package.Inferred)]] = - Memoize.memoizeDagFuture[ResolvedU, Ior[NonEmptyList[PackageError], (TypeEnv[Kind.Arg], Package.Inferred)]] { + val infer0: ResolvedU => Par.F[ + Ior[NonEmptyList[PackageError], (TypeEnv[Kind.Arg], Package.Inferred)] + ] = + Memoize.memoizeDagFuture[ResolvedU, Ior[NonEmptyList[ + PackageError + ], (TypeEnv[Kind.Arg], Package.Inferred)]] { // TODO, we ignore importMap here, we only check earlier we don't // have duplicate imports case (Package(nm, imports, exports, (stmt, _)), recurse) => - - def getImport[A, B](packF: Package.Inferred, - exMap: Map[Identifier, NonEmptyList[ExportedName[A]]], - i: ImportedName[B]): Ior[NonEmptyList[PackageError], ImportedName[NonEmptyList[A]]] = + def getImport[A, B]( + packF: Package.Inferred, + exMap: Map[Identifier, NonEmptyList[ExportedName[A]]], + i: ImportedName[B] + ): Ior[NonEmptyList[PackageError], ImportedName[NonEmptyList[A]]] = exMap.get(i.originalName) match { case None => - Ior.left(NonEmptyList.one( - PackageError.UnknownImportName( - nm, - packF.name, - packF - .program - .lets - .iterator - .map { case (n, _, _) => (n: Identifier, ()) }.toMap, - i, - exMap.iterator.flatMap(_._2.toList).toList))) + Ior.left( + NonEmptyList.one( + PackageError.UnknownImportName( + nm, + packF.name, + packF.program.lets.iterator.map { case (n, _, _) => + (n: Identifier, ()) + }.toMap, + i, + exMap.iterator.flatMap(_._2.toList).toList + ) + ) + ) case Some(exps) => val bs = exps.map(_.tag) Ior.right(i.map(_ => bs)) } - def getImportIface[A, B](packF: Package.Interface, - exMap: Map[Identifier, NonEmptyList[ExportedName[A]]], - i: ImportedName[B]): Ior[NonEmptyList[PackageError], ImportedName[NonEmptyList[A]]] = + def getImportIface[A, B]( + packF: Package.Interface, + exMap: Map[Identifier, NonEmptyList[ExportedName[A]]], + i: ImportedName[B] + ): Ior[NonEmptyList[PackageError], ImportedName[NonEmptyList[A]]] = exMap.get(i.originalName) match { case None => - Ior.left(NonEmptyList.one( - PackageError.UnknownImportFromInterface( - nm, packF.name, packF.exports.map(_.name), i, - exMap.iterator.flatMap(_._2.toList).toList))) + Ior.left( + NonEmptyList.one( + PackageError.UnknownImportFromInterface( + nm, + packF.name, + packF.exports.map(_.name), + i, + exMap.iterator.flatMap(_._2.toList).toList + ) + ) + ) case Some(exps) => val bs = exps.map(_.tag) Ior.right(i.map(_ => bs)) @@ -271,8 +372,11 @@ object PackageMap { * type can have the same name as a constructor. After this step, each * distinct object has its own entry in the list */ - type ImpRes = Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]] - def stepImport(imp: Import[Package.Resolved, Unit]): FutVal[ImpRes] = { + type ImpRes = + Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]] + def stepImport( + imp: Import[Package.Resolved, Unit] + ): FutVal[ImpRes] = { val Import(fixpack, items) = imp Package.unfix(fixpack) match { case Right(p) => @@ -308,47 +412,57 @@ object PackageMap { inferImports .flatMap { imps => // run this in a thread - IorT(Par.start(Package.inferBodyUnopt(nm, imps, stmt).map((imps, _)))) + IorT( + Par.start( + Package.inferBodyUnopt(nm, imps, stmt).map((imps, _)) + ) + ) } - inferBody - .flatMap { case (imps, (fte, program@Program(types, lets, _, _))) => + inferBody.flatMap { + case (imps, (fte, program @ Program(types, lets, _, _))) => val ior = ExportedName .buildExports(nm, exports, types, lets) match { - case Validated.Valid(exports) => - // We have a result, which we can continue to check - val pack = Package(nm, imps, exports, program) - val res = (fte, pack) - PackageCustoms(pack) match { - case Validated.Valid(p1) => Ior.right((fte, p1)) - case Validated.Invalid(errs) => - Ior.both(errs.toNonEmptyList, res) - } - case Validated.Invalid(badPackages) => - Ior.left(badPackages.map { n => - PackageError.UnknownExport(n, nm, lets): PackageError - }) - } + case Validated.Valid(exports) => + // We have a result, which we can continue to check + val pack = Package(nm, imps, exports, program) + val res = (fte, pack) + PackageCustoms(pack) match { + case Validated.Valid(p1) => Ior.right((fte, p1)) + case Validated.Invalid(errs) => + Ior.both(errs.toNonEmptyList, res) + } + case Validated.Invalid(badPackages) => + Ior.left(badPackages.map { n => + PackageError.UnknownExport(n, nm, lets): PackageError + }) + } IorT.fromIor(ior) - } - .value - } + }.value + } /* * Since Par.F is starts computation when start is called * we want to start all the computations *then* collect * the result together */ - val infer: ResolvedU => Par.F[Ior[NonEmptyList[PackageError], Package.Inferred]] = + val infer: ResolvedU => Par.F[ + Ior[NonEmptyList[PackageError], Package.Inferred] + ] = infer0.andThen { parF => // As soon as each Par.F is complete, we can start normalizing that one Monad[Par.F].flatMap(parF) { ior => - ior.traverse { - case (fte, pack) => - Par.start { - val optPack = pack.copy(program = TypedExprNormalization.normalizeProgram(pack.name, fte, pack.program)) - Package.discardUnused(optPack) - } + ior.traverse { case (fte, pack) => + Par.start { + val optPack = pack.copy(program = + TypedExprNormalization.normalizeProgram( + pack.name, + fte, + pack.program + ) + ) + Package.discardUnused(optPack) + } } } } @@ -356,58 +470,82 @@ object PackageMap { val fut = ps.toMap.parTraverse(infer.andThen(IorT(_))) // Wait until all the resolution is complete - Par.await(fut.value) + Par + .await(fut.value) .map(PackageMap(_)) } def resolveThenInfer[A: Show]( - ps: List[(A, Package.Parsed)], - ifs: List[Package.Interface])(implicit cpuEC: Par.EC): Ior[NonEmptyList[PackageError], Inferred] = - resolveAll(ps, ifs).flatMap(inferAll) - - def buildSourceMap[F[_]: Foldable, A](parsedFiles: F[((A, LocationMap), Package.Parsed)]): Map[PackageName, (LocationMap, String)] = - parsedFiles.foldLeft(Map.empty[PackageName, (LocationMap, String)]) { case (map, ((path, lm), pack)) => - map.updated(pack.name, (lm, path.toString)) + ps: List[(A, Package.Parsed)], + ifs: List[Package.Interface] + )(implicit cpuEC: Par.EC): Ior[NonEmptyList[PackageError], Inferred] = + resolveAll(ps, ifs).flatMap(inferAll) + + def buildSourceMap[F[_]: Foldable, A]( + parsedFiles: F[((A, LocationMap), Package.Parsed)] + ): Map[PackageName, (LocationMap, String)] = + parsedFiles.foldLeft(Map.empty[PackageName, (LocationMap, String)]) { + case (map, ((path, lm), pack)) => + map.updated(pack.name, (lm, path.toString)) } /** typecheck a list of packages given a list of interface dependencies - * - * @param packs a list of parsed packages, along with a key A to tag the source - * @param ifs the interfaces we are compiling against. If Bosatsu.Predef is not in this list, the default is added - */ + * + * @param packs + * a list of parsed packages, along with a key A to tag the source + * @param ifs + * the interfaces we are compiling against. If Bosatsu.Predef is not in + * this list, the default is added + */ def typeCheckParsed[A: Show]( - packs: NonEmptyList[((A, LocationMap), Package.Parsed)], - ifs: List[Package.Interface], - predefKey: A)(implicit cpuEC: Par.EC): Ior[NonEmptyList[PackageError], PackageMap.Inferred] = { + packs: NonEmptyList[((A, LocationMap), Package.Parsed)], + ifs: List[Package.Interface], + predefKey: A + )(implicit + cpuEC: Par.EC + ): Ior[NonEmptyList[PackageError], PackageMap.Inferred] = { // if we have passed in a use supplied predef, don't use the internal one - val useInternalPredef = !ifs.exists { (p: Package.Interface) => p.name == PackageName.PredefName } + val useInternalPredef = !ifs.exists { (p: Package.Interface) => + p.name == PackageName.PredefName + } // Now we have completed all IO, here we do all the checks we need for correctness val parsed = - if (useInternalPredef) withPredefA[(A, LocationMap)]((predefKey, LocationMap("")), packs.toList) - else withPredefImportsA[(A, LocationMap)](packs.toList) - - PackageMap.resolveThenInfer[A]( - parsed.map { case ((a, _), p) => (a, p) }, - ifs) + if (useInternalPredef) + withPredefA[(A, LocationMap)]( + (predefKey, LocationMap("")), + packs.toList + ) + else withPredefImportsA[(A, LocationMap)](packs.toList) + + PackageMap + .resolveThenInfer[A](parsed.map { case ((a, _), p) => (a, p) }, ifs) } - /** - * Here is the fully compiled Predef - */ + /** Here is the fully compiled Predef + */ val predefCompiled: Package.Inferred = { import DirectEC.directEC - //implicit val showUnit: Show[Unit] = Show.show[Unit](_ => "predefCompiled") - val inferred = PackageMap.resolveThenInfer(((), Package.predefPackage) :: Nil, Nil).strictToValidated + // implicit val showUnit: Show[Unit] = Show.show[Unit](_ => "predefCompiled") + val inferred = PackageMap + .resolveThenInfer(((), Package.predefPackage) :: Nil, Nil) + .strictToValidated inferred match { case Validated.Valid(v) => v.toMap.get(PackageName.PredefName) match { - case None => sys.error("internal error: predef package not found after compilation") + case None => + sys.error( + "internal error: predef package not found after compilation" + ) case Some(inf) => inf } case Validated.Invalid(errs) => - val map = Map(PackageName.PredefName -> (LocationMap(Predef.predefString), "")) + val map = Map( + PackageName.PredefName -> (LocationMap( + Predef.predefString + ), "") + ) errs.iterator.foreach { err => println(err.message(map, LocationMap.Colorize.None)) } @@ -424,13 +562,18 @@ object PackageMap { private val predefImports: Import[PackageName, Unit] = Import(PackageName.PredefName, NonEmptyList.fromList(predefImportList).get) - private def withPredefImportsA[A](ps: List[(A, Package.Parsed)]): List[(A, Package.Parsed)] = + private def withPredefImportsA[A]( + ps: List[(A, Package.Parsed)] + ): List[(A, Package.Parsed)] = ps.map { case (a, p) => (a, p.withImport(predefImports)) } def withPredef(ps: List[Package.Parsed]): List[Package.Parsed] = Package.predefPackage :: ps.map(_.withImport(predefImports)) - def withPredefA[A](predefA: A, ps: List[(A, Package.Parsed)]): List[(A, Package.Parsed)] = + def withPredefA[A]( + predefA: A, + ps: List[(A, Package.Parsed)] + ): List[(A, Package.Parsed)] = (predefA, Package.predefPackage) :: withPredefImportsA(ps) } diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageName.scala b/core/src/main/scala/org/bykn/bosatsu/PackageName.scala index db70b6c77..b3b871d9f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageName.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageName.scala @@ -16,7 +16,7 @@ case class PackageName(parts: NonEmptyList[String]) { object PackageName { def parts(first: String, rest: String*): PackageName = - PackageName(NonEmptyList.of(first, rest :_*)) + PackageName(NonEmptyList.of(first, rest: _*)) implicit val document: Document[PackageName] = Document.instance[PackageName] { pn => Doc.text(pn.asString) } @@ -30,7 +30,7 @@ object PackageName { def parse(s: String): Option[PackageName] = parser.parse(s) match { case Right(("", pn)) => Some(pn) - case _ => None + case _ => None } implicit val order: Order[PackageName] = @@ -42,4 +42,3 @@ object PackageName { val PredefName: PackageName = PackageName(NonEmptyList.of("Bosatsu", "Predef")) } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Padding.scala b/core/src/main/scala/org/bykn/bosatsu/Padding.scala index e37962c9d..f9286d121 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Padding.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Padding.scala @@ -2,7 +2,7 @@ package org.bykn.bosatsu import cats.Functor import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import Parser.maybeSpace @@ -20,9 +20,8 @@ object Padding { Doc.line.repeat(padding.lines) + Document[T].document(padding.padded) } - /** - * This allows an empty padding - */ + /** This allows an empty padding + */ def parser[T](p: P[T]): P[Padding[T]] = { val spacing = (maybeSpace.with1.soft ~ Parser.newline).void.rep0 @@ -30,17 +29,14 @@ object Padding { .map { case (vec, t) => Padding(vec.size, t) } } - /** - * Parses a padding of length 1 or more, then p - */ + /** Parses a padding of length 1 or more, then p + */ def parser1[T](p: P0[T]): P[Padding[T]] = ((maybeSpace.with1.soft ~ Parser.newline).void.rep ~ p) .map { case (vec, t) => Padding(vec.size, t) } - /** - * This is parser1 by itself, with the padded value being () - */ + /** This is parser1 by itself, with the padded value being () + */ val nonEmptyParser: P[Padding[Unit]] = parser1(P.unit) } - diff --git a/core/src/main/scala/org/bykn/bosatsu/ParallelViaProduct.scala b/core/src/main/scala/org/bykn/bosatsu/ParallelViaProduct.scala index 4021b0702..e2f10be29 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ParallelViaProduct.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ParallelViaProduct.scala @@ -30,4 +30,4 @@ abstract class ParallelViaProduct[G[_]] extends Parallel[G] { self => override def product[A, B](fa: F[A], fb: F[B]) = parallelProduct(fa, fb) } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/bykn/bosatsu/Parser.scala b/core/src/main/scala/org/bykn/bosatsu/Parser.scala index 442f9f1a0..89c9db043 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Parser.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Parser.scala @@ -9,12 +9,10 @@ import cats.implicits._ import java.math.BigInteger object Parser { - /** - * This is an indentation aware - * parser, the input is the string that - * should be parsed after a new-line to - * continue the current indentation block - */ + + /** This is an indentation aware parser, the input is the string that should + * be parsed after a new-line to continue the current indentation block + */ type Indy[A] = Kleisli[P, String, A] object Indy { @@ -24,9 +22,8 @@ object Parser { def lift[A](p: P[A]): Indy[A] = Kleisli.liftF(p) - /** - * Parse spaces, end of line, then the next indentation - */ + /** Parse spaces, end of line, then the next indentation + */ val toEOLIndent: Indy[Unit] = apply { indent => toEOL1 *> P.string0(indent) @@ -36,10 +33,8 @@ object Parser { def region: Indy[(Region, A)] = toKleisli.mapF(_.region) - /** - * Parse exactly the current indentation - * starting now - */ + /** Parse exactly the current indentation starting now + */ def indentBefore: Indy[A] = apply(indent => P.string0(indent).with1 *> toKleisli.run(indent)) @@ -65,16 +60,13 @@ object Parser { toKleisli(indent) *> that(indent) } - /** - * This optionally allows extra indentation that starts now - */ + /** This optionally allows extra indentation that starts now + */ def maybeMore: Parser.Indy[A] = Indy { indent => // run this one time, not each spaces are parsed val noIndent = toKleisli.run(indent) - val someIndent: P[A] = Parser - .spaces - .string + val someIndent: P[A] = Parser.spaces.string .flatMap { thisIndent => toKleisli.run(indent + thisIndent) } @@ -93,25 +85,35 @@ object Parser { } object Error { - case class ParseFailure(position: Int, locations: LocationMap, expected: NonEmptyList[P.Expectation]) extends Error - - def showExpectations(locations: LocationMap, expected: NonEmptyList[P.Expectation], errColor: LocationMap.Colorize): Doc = { - val errs: SortedMap[Int, NonEmptyList[P.Expectation]] = expected.groupBy(_.offset) + case class ParseFailure( + position: Int, + locations: LocationMap, + expected: NonEmptyList[P.Expectation] + ) extends Error + + def showExpectations( + locations: LocationMap, + expected: NonEmptyList[P.Expectation], + errColor: LocationMap.Colorize + ): Doc = { + val errs: SortedMap[Int, NonEmptyList[P.Expectation]] = + expected.groupBy(_.offset) def show(s: String): Doc = { val q = '\'' if (s.forall(_.isWhitespace)) { val chars = s.length val plural = if (chars == 1) "char" else "chars" - Doc.text(s"$chars whitespace $plural \"") + Doc.intercalate(Doc.empty, + Doc.text(s"$chars whitespace $plural \"") + Doc.intercalate( + Doc.empty, s.map { case '\t' => Doc.text("\\t") case '\n' => Doc.text("\\n") case '\r' => Doc.text("\\r") - case c => Doc.char(c) - }) + Doc.char('"') - } - else { + case c => Doc.char(c) + } + ) + Doc.char('"') + } else { Doc.char(q) + Doc.text(escape(q, s)) + Doc.char(q) } } @@ -123,22 +125,30 @@ object Parser { case one :: Nil => Doc.text("expected ") + show(one) case _ => - Doc.text("expected one of: ") + Doc.intercalate(Doc.line, strs.map(show)).grouped.nested(4) + Doc.text("expected one of: ") + Doc + .intercalate(Doc.line, strs.map(show)) + .grouped + .nested(4) } case P.Expectation.InRange(_, lower, upper) => if (lower == upper) { Doc.text("expected char: ") + show(lower.toString) + } else { + Doc.text("expected char in range: [") + show(lower.toString) + Doc + .text(", ") + show(upper.toString) + Doc.text("]") } - else { - Doc.text("expected char in range: [") + show(lower.toString) + Doc.text(", ") + show(upper.toString) + Doc.text("]") - } - case P.Expectation.StartOfString(_) => Doc.text("expected start of the file") + case P.Expectation.StartOfString(_) => + Doc.text("expected start of the file") case P.Expectation.EndOfString(_, length) => Doc.text(s"expected end of file but $length characters remaining") case P.Expectation.Length(_, expected, actual) => - Doc.text(s"expected $expected more characters but only $actual remaining") + Doc.text( + s"expected $expected more characters but only $actual remaining" + ) case P.Expectation.ExpectedFailureAt(_, matched) => - Doc.text("expected failure but the parser matched: ") + show(matched) + Doc.text("expected failure but the parser matched: ") + show( + matched + ) case P.Expectation.Fail(_) => Doc.text("failed") case P.Expectation.FailWith(_, message) => @@ -147,10 +157,14 @@ object Parser { expToDoc(expect) } - Doc.intercalate(Doc.hardLine, errs.map { case (pos, xs) => - locations.showContext(pos, 2, errColor).get + ( - Doc.hardLine + Doc.intercalate(Doc.comma + Doc.line, xs.toList.map(expToDoc)).grouped).nested(4) - }) + Doc.intercalate( + Doc.hardLine, + errs.map { case (pos, xs) => + locations.showContext(pos, 2, errColor).get + (Doc.hardLine + Doc + .intercalate(Doc.comma + Doc.line, xs.toList.map(expToDoc)) + .grouped).nested(4) + } + ) } } @@ -166,14 +180,16 @@ object Parser { } val identifierCharsP: P0[String] = - P.charIn('_' :: ('a' to 'z').toList ::: ('A' to 'Z').toList ::: ('0' to '9').toList).repAs0 + P.charIn( + '_' :: ('a' to 'z').toList ::: ('A' to 'Z').toList ::: ('0' to '9').toList + ).repAs0 // parse one or more space characters val spaces: P[Unit] = P.charIn(Set(' ', '\t')).rep.void val maybeSpace: P0[Unit] = spaces.?.void /** prefer to parse Right, then Left - */ + */ def either[A, B](pb: P0[B], pa: P0[A]): P0[Either[B, A]] = pa.map(Right(_)).orElse(pb.map(Left(_))) @@ -193,7 +209,9 @@ object Parser { (P.charIn('A' to 'Z') ~ identifierCharsP).string val py2Ident: P[String] = - (P.charIn('_' :: ('A' to 'Z').toList ::: ('a' to 'z').toList) ~ identifierCharsP).string + (P.charIn( + '_' :: ('A' to 'Z').toList ::: ('a' to 'z').toList + ) ~ identifierCharsP).string // parse a keyword and some space or backtrack def keySpace(str: String): P[Unit] = @@ -201,16 +219,13 @@ object Parser { val digit19: P[Char] = P.charIn('1' to '9') val digit09: P[Char] = P.charIn('0' to '9') - /** - * This parser allows _ between any two digits to allow - * literals such as: - * 1_000_000 - * - * It will also parse terrible examples like: - * 1_0_0_0_0_0_0 - * but I think banning things like that shouldn't - * be done by the parser - */ + + /** This parser allows _ between any two digits to allow literals such as: + * 1_000_000 + * + * It will also parse terrible examples like: 1_0_0_0_0_0_0 but I think + * banning things like that shouldn't be done by the parser + */ val positiveIntegerString: P[String] = { val rest = (P.char('_').?.with1 ~ digit09).rep0 val nonZero: P[Unit] = (digit19 ~ rest).void @@ -226,16 +241,15 @@ object Parser { val signs = "+" :: "-" :: Nil ((0 to 99).iterator.map { i => - (i.toString, BigInteger.valueOf(i.toLong)) + (i.toString, BigInteger.valueOf(i.toLong)) } ++ - (0 to 9).iterator.flatMap { i => - signs.map { sign => - if (sign == "-") { - (s"-$i", BigInteger.valueOf(-i.toLong)) + (0 to 9).iterator.flatMap { i => + signs.map { sign => + if (sign == "-") { + (s"-$i", BigInteger.valueOf(-i.toLong)) + } else (s"+$i", BigInteger.valueOf(i.toLong)) } - else (s"+$i", BigInteger.valueOf(i.toLong)) - } - }).toMap + }).toMap } val integerWithBase: P[(BigInteger, Int)] = { val binDigit = P.charIn(('0' to '1')) @@ -247,9 +261,9 @@ object Parser { d.repSep(sep = under.?).string def base(n: Int, str: String, d: P[Char]) = - (P.string(str.toLowerCase) | P.string(str)) *> - (under.?.with1 *> rest(d)) - .map((_, n)) + (P.string(str.toLowerCase) | P.string(str)) *> + (under.?.with1 *> rest(d)) + .map((_, n)) val not10 = base(2, "0B", binDigit) | base(8, "0O", octDigit) | @@ -269,19 +283,13 @@ object Parser { } object JsonNumber { - /** - * from: https://tools.ietf.org/html/rfc4627 - * number = [ minus ] int [ frac ] [ exp ] - * decimal-point = %x2E ; . - * digit1-9 = %x31-39 ; 1-9 - * e = %x65 / %x45 ; e E - * exp = e [ minus / plus ] 1*DIGIT - * frac = decimal-point 1*DIGIT - * int = zero / ( digit1-9 *DIGIT ) - * minus = %x2D ; - - * plus = %x2B ; + - * zero = %x30 ; 0 - */ + + /** from: https://tools.ietf.org/html/rfc4627 number = [ minus ] int [ frac + * ] [ exp ] decimal-point = %x2E ; . digit1-9 = %x31-39 ; 1-9 e = %x65 / + * %x45 ; e E exp = e [ minus / plus ] 1*DIGIT frac = decimal-point 1*DIGIT + * int = zero / ( digit1-9 *DIGIT ) minus = %x2D ; - plus = %x2B ; + zero = + * %x30 ; 0 + */ val digits: P0[Unit] = digit09.rep0.void val digits1: P[Unit] = digit09.rep.void val int: P[Unit] = P.char('0') <+> (digit19 ~ digits).void @@ -292,7 +300,12 @@ object Parser { (P.char('-').?.with1 ~ int ~ frac.? ~ exp.?).string // this gives you the individual parts of a floating point string - case class Parts(negative: Boolean, leftOfPoint: String, floatingPart: String, exp: String) { + case class Parts( + negative: Boolean, + leftOfPoint: String, + floatingPart: String, + exp: String + ) { def asString: String = { val neg = if (negative) "-" else "" s"$neg$leftOfPoint$floatingPart$exp" @@ -317,14 +330,13 @@ object Parser { def nonEmptyListToList[T](p: P0[NonEmptyList[T]]): P0[List[T]] = p.?.map { - case None => Nil + case None => Nil case Some(ne) => ne.toList } - /** - * Parse python-like dicts: delimited by curlies "{" "}" and - * keys separated by colon - */ + /** Parse python-like dicts: delimited by curlies "{" "}" and keys separated + * by colon + */ def dictLikeParser[K, V](pkey: P[K], pvalue: P[V]): P[List[(K, V)]] = { val ws = maybeSpacesAndLines val kv = (pkey ~ ((ws ~ P.char(':') ~ ws).with1 *> pvalue)) @@ -344,11 +356,15 @@ object Parser { def maybeAp(fn: P0[T => T]): P[T] = (item ~ fn.?) .map { - case (a, None) => a + case (a, None) => a case (a, Some(f)) => f(a) } - def nonEmptyListOfWsSep(ws: P0[Unit], sep: P0[Unit], allowTrailing: Boolean): P[NonEmptyList[T]] = { + def nonEmptyListOfWsSep( + ws: P0[Unit], + sep: P0[Unit], + allowTrailing: Boolean + ): P[NonEmptyList[T]] = { val wsSep = (ws.soft ~ sep ~ ws).void val trail = if (allowTrailing) (ws.soft ~ sep).?.void @@ -381,33 +397,29 @@ object Parser { parens(item) def parensLines1Cut: P[NonEmptyList[T]] = - item.nonEmptyListOfWs(maybeSpacesAndLines) - .parensCut + item.nonEmptyListOfWs(maybeSpacesAndLines).parensCut def parensLines0Cut: P[List[T]] = parens(nonEmptyListToList(item.nonEmptyListOfWs(maybeSpacesAndLines))) - /** - * either: a, b, c, .. - * or (a, b, c, ) where we allow newlines: - * return true if we do have parens - */ + + /** either: a, b, c, .. or (a, b, c, ) where we allow newlines: return true + * if we do have parens + */ def itemsMaybeParens: P[(Boolean, NonEmptyList[T])] = { val withP = item.parensLines1Cut.map((true, _)) val noP = item.nonEmptyListOfWs(maybeSpace).map((false, _)) withP.orElse(noP) } - /** - * Parse a python-like tuple or a parens - */ + /** Parse a python-like tuple or a parens + */ def tupleOrParens: P[Either[T, List[T]]] = parens { - tupleOrParens0.? - .map { - case None => Right(Nil) - case Some(Left(t)) => Left(t) - case Some(Right(l)) => Right(l.toList) - } + tupleOrParens0.?.map { + case None => Right(Nil) + case Some(Left(t)) => Left(t) + case Some(Right(l)) => Right(l.toList) + } } def tupleOrParens0: P[Either[T, NonEmptyList[T]]] = { @@ -447,7 +459,8 @@ object Parser { case Right(a) => a case Left(err) => val idx = err.failedAtOffset - sys.error(s"failed to parse: $str: at $idx: (${str.substring(idx)}) with errors: ${err.expected}") + sys.error(s"failed to parse: $str: at $idx: (${str + .substring(idx)}) with errors: ${err.expected}") } sealed abstract class MaybeTupleOrParens[A] @@ -459,7 +472,7 @@ object Parser { def tupleOrParens[A](p: P[A]): P[NotBare[A]] = p.tupleOrParens.map { - case Right(tup) => Tuple(tup) + case Right(tup) => Tuple(tup) case Left(parens) => Parens(parens) } diff --git a/core/src/main/scala/org/bykn/bosatsu/PathGen.scala b/core/src/main/scala/org/bykn/bosatsu/PathGen.scala index d888c959c..7fd0c3127 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PathGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PathGen.scala @@ -12,14 +12,20 @@ object PathGen { def read(implicit m: Monad[IO]): IO[List[Path]] = m.pure(path :: Nil) } - final case class ChildrenOfDir[IO[_], Path](dir: Path, select: Path => Boolean, recurse: Boolean, unfold: Path => IO[Option[IO[List[Path]]]]) extends PathGen[IO, Path] { + final case class ChildrenOfDir[IO[_], Path]( + dir: Path, + select: Path => Boolean, + recurse: Boolean, + unfold: Path => IO[Option[IO[List[Path]]]] + ) extends PathGen[IO, Path] { def read(implicit m: Monad[IO]): IO[List[Path]] = { val pureEmpty: IO[List[Path]] = m.pure(Nil) lazy val rec: List[Path] => IO[List[Path]] = - if (recurse) { (children: List[Path]) => children.traverse(step).map(_.flatten) } - else { (_: List[Path]) => pureEmpty } + if (recurse) { (children: List[Path]) => + children.traverse(step).map(_.flatten) + } else { (_: List[Path]) => pureEmpty } def step(path: Path): IO[List[Path]] = unfold(path).flatMap { @@ -35,7 +41,8 @@ object PathGen { step(dir) } } - final case class Combine[IO[_], Path](gens: List[PathGen[IO, Path]]) extends PathGen[IO, Path] { + final case class Combine[IO[_], Path](gens: List[PathGen[IO, Path]]) + extends PathGen[IO, Path] { def read(implicit m: Monad[IO]): IO[List[Path]] = gens.traverse(_.read).map(_.flatten) } @@ -45,12 +52,12 @@ object PathGen { val empty: PathGen[IO, Path] = Combine(Nil) def combine(a: PathGen[IO, Path], b: PathGen[IO, Path]) = (a, b) match { - case (Combine(Nil), b) => b - case (a, Combine(Nil)) => a + case (Combine(Nil), b) => b + case (a, Combine(Nil)) => a case (Combine(as), Combine(bs)) => Combine(as ::: bs) - case (Combine(as), b) => Combine(as :+ b) - case (a, Combine(bs)) => Combine(a :: bs) - case (a, b) => Combine(a :: b :: Nil) + case (Combine(as), b) => Combine(as :+ b) + case (a, Combine(bs)) => Combine(a :: bs) + case (a, b) => Combine(a :: b :: Nil) } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala index 1dd22a307..95eda61e8 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala @@ -3,10 +3,10 @@ package org.bykn.bosatsu import cats.{Applicative, Foldable} import cats.data.NonEmptyList import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import org.bykn.bosatsu.pattern.{NamedSeqPattern, SeqPattern, SeqPart} -import Parser.{ Combinators, maybeSpace, MaybeTupleOrParens } +import Parser.{Combinators, maybeSpace, MaybeTupleOrParens} import cats.implicits._ import Identifier.{Bindable, Constructor} @@ -20,16 +20,20 @@ sealed abstract class Pattern[+N, +T] { def mapType[U](fn: T => U): Pattern[N, U] = (new Pattern.InvariantPattern(this)).traverseType[cats.Id, U](fn) - /** - * List all the names that are bound in Vars inside this pattern - * in the left to right order they are encountered, without any duplication - */ + /** List all the names that are bound in Vars inside this pattern in the left + * to right order they are encountered, without any duplication + */ lazy val names: List[Bindable] = { @annotation.tailrec - def loop(stack: List[Pattern[N, T]], seen: Set[Bindable], acc: List[Bindable]): List[Bindable] = + def loop( + stack: List[Pattern[N, T]], + seen: Set[Bindable], + acc: List[Bindable] + ): List[Bindable] = stack match { case Nil => acc.reverse - case (Pattern.WildCard | Pattern.Literal(_)) :: tail => loop(tail, seen, acc) + case (Pattern.WildCard | Pattern.Literal(_)) :: tail => + loop(tail, seen, acc) case Pattern.Var(v) :: tail => if (seen(v)) loop(tail, seen, acc) else loop(tail, seen + v, v :: acc) @@ -37,14 +41,20 @@ sealed abstract class Pattern[+N, +T] { if (seen(v)) loop(p :: tail, seen, acc) else loop(p :: tail, seen + v, v :: acc) case Pattern.StrPat(items) :: tail => - val names = items.collect { - case Pattern.StrPart.NamedStr(n) => n - case Pattern.StrPart.NamedChar(n) => n - }.filterNot(seen) + val names = items + .collect { + case Pattern.StrPart.NamedStr(n) => n + case Pattern.StrPart.NamedChar(n) => n + } + .filterNot(seen) loop(tail, seen ++ names, names reverse_::: acc) case Pattern.ListPat(items) :: tail => - val globs = items.collect { case Pattern.ListPart.NamedList(glob) => glob }.filterNot(seen) - val next = items.collect { case Pattern.ListPart.Item(inner) => inner } + val globs = items + .collect { case Pattern.ListPart.NamedList(glob) => glob } + .filterNot(seen) + val next = items.collect { case Pattern.ListPart.Item(inner) => + inner + } loop(next ::: tail, seen ++ globs, globs reverse_::: acc) case Pattern.Annotation(p, _) :: tail => loop(p :: tail, seen, acc) case Pattern.PositionalStruct(_, params) :: tail => @@ -56,68 +66,88 @@ sealed abstract class Pattern[+N, +T] { loop(this :: Nil, Set.empty, Nil) } - /** - * What are the names that will be bound to the entire pattern, - * Bar(x) as foo would return List(foo) - * foo as bar as baz would return List(baz, bar, foo) - * Bar(x) would return Nil - */ + /** What are the names that will be bound to the entire pattern, Bar(x) as foo + * would return List(foo) foo as bar as baz would return List(baz, bar, foo) + * Bar(x) would return Nil + */ lazy val topNames: List[Bindable] = { this match { - case Pattern.Var(v) => v :: Nil + case Pattern.Var(v) => v :: Nil case Pattern.Named(v, p) => (v :: p.topNames).distinct case Pattern.ListPat(Pattern.ListPart.NamedList(n) :: Nil) => n :: Nil - case Pattern.Annotation(p, _) => p.topNames - case Pattern.Union(h, t) => + case Pattern.Annotation(p, _) => p.topNames + case Pattern.Union(h, t) => // the intersection of all top level names // is okay val pats = h :: t.toList val patIntr = pats.map(_.topNames.toSet).reduce(_ & _) // put them in the same order as written: pats.flatMap(_.topNames).iterator.filter(patIntr).toList.distinct - case Pattern.ListPat(_) | Pattern.WildCard | Pattern.Literal(_) | Pattern.StrPat(_) | Pattern.PositionalStruct(_, _) => Nil + case Pattern.ListPat(_) | Pattern.WildCard | Pattern.Literal(_) | + Pattern.StrPat(_) | Pattern.PositionalStruct(_, _) => + Nil } } - /** - * List all the names that strictly smaller than anything that would match this pattern - * e.g. a top level var, would not be returned - */ + /** List all the names that strictly smaller than anything that would match + * this pattern e.g. a top level var, would not be returned + */ def substructures: List[Bindable] = { - def cheat(stack: List[(Pattern[N, T], Boolean)], seen: Set[Bindable], acc: List[Bindable]): List[Bindable] = + def cheat( + stack: List[(Pattern[N, T], Boolean)], + seen: Set[Bindable], + acc: List[Bindable] + ): List[Bindable] = loop(stack, seen, acc) import Pattern.{ListPart, StrPart} @annotation.tailrec - def loop(stack: List[(Pattern[N, T], Boolean)], seen: Set[Bindable], acc: List[Bindable]): List[Bindable] = + def loop( + stack: List[(Pattern[N, T], Boolean)], + seen: Set[Bindable], + acc: List[Bindable] + ): List[Bindable] = stack match { case Nil => acc.reverse - case ((Pattern.WildCard, _) | (Pattern.Literal(_), _)) :: tail => loop(tail, seen, acc) + case ((Pattern.WildCard, _) | (Pattern.Literal(_), _)) :: tail => + loop(tail, seen, acc) case (Pattern.Var(v), isTop) :: tail => if (seen(v) || isTop) loop(tail, seen, acc) else loop(tail, seen + v, v :: acc) case (Pattern.Named(v, p), isTop) :: tail => if (seen(v) || isTop) loop((p, isTop) :: tail, seen, acc) else loop((p, isTop) :: tail, seen + v, v :: acc) - case (Pattern.StrPat(NonEmptyList(StrPart.NamedStr(_), Nil)), true) :: tail => - // this is a total match at the top level, not a substructure - loop(tail, seen, acc) + case ( + Pattern.StrPat(NonEmptyList(StrPart.NamedStr(_), Nil)), + true + ) :: tail => + // this is a total match at the top level, not a substructure + loop(tail, seen, acc) case (Pattern.StrPat(items), _) :: tail => - val globs = items.collect { case StrPart.NamedStr(glob) => glob }.filterNot(seen) + val globs = items + .collect { case StrPart.NamedStr(glob) => glob } + .filterNot(seen) loop(tail, seen ++ globs, globs reverse_::: acc) case (Pattern.ListPat(ListPart.NamedList(_) :: Nil), true) :: tail => - // this is a total match at the top level, not a substructure - loop(tail, seen, acc) + // this is a total match at the top level, not a substructure + loop(tail, seen, acc) case (Pattern.ListPat(items), _) :: tail => - val globs = items.collect { case ListPart.NamedList(glob) => glob }.filterNot(seen) - val next = items.collect { case ListPart.Item(inner) => (inner, false) } + val globs = items + .collect { case ListPart.NamedList(glob) => glob } + .filterNot(seen) + val next = items.collect { case ListPart.Item(inner) => + (inner, false) + } loop(next ::: tail, seen ++ globs, globs reverse_::: acc) - case (Pattern.Annotation(p, _), isTop) :: tail => loop((p, isTop) :: tail, seen, acc) + case (Pattern.Annotation(p, _), isTop) :: tail => + loop((p, isTop) :: tail, seen, acc) case (Pattern.PositionalStruct(_, params), _) :: tail => loop(params.map((_, false)) ::: tail, seen, acc) case (Pattern.Union(h, t), isTop) :: tail => - val all = (h :: t.toList).map { p => cheat((p, isTop) :: tail, seen, acc) } + val all = (h :: t.toList).map { p => + cheat((p, isTop) :: tail, seen, acc) + } // we need to be substructual on all: val intr = all.map(_.toSet).reduce(_.intersect(_)) all.flatMap(_.filter(intr)).distinct @@ -126,20 +156,18 @@ sealed abstract class Pattern[+N, +T] { loop((this, true) :: Nil, Set.empty, Nil) } - /** - * Return the pattern with all the binding names removed - */ + /** Return the pattern with all the binding names removed + */ def unbind: Pattern[N, T] = filterVars(Set.empty) - /** - * replace all Var names with Wildcard that are not - * satifying the keep predicate - */ + /** replace all Var names with Wildcard that are not satifying the keep + * predicate + */ def filterVars(keep: Bindable => Boolean): Pattern[N, T] = this match { case Pattern.WildCard | Pattern.Literal(_) => this - case p@Pattern.Var(v) => + case p @ Pattern.Var(v) => if (keep(v)) p else Pattern.WildCard case Pattern.Named(v, p) => val inner = p.filterVars(keep) @@ -147,18 +175,20 @@ sealed abstract class Pattern[+N, +T] { else inner case Pattern.StrPat(items) => Pattern.StrPat(items.map { - case wl@(Pattern.StrPart.WildStr | Pattern.StrPart.WildChar | Pattern.StrPart.LitStr(_)) => wl - case in@Pattern.StrPart.NamedStr(n) => + case wl @ (Pattern.StrPart.WildStr | Pattern.StrPart.WildChar | + Pattern.StrPart.LitStr(_)) => + wl + case in @ Pattern.StrPart.NamedStr(n) => if (keep(n)) in else Pattern.StrPart.WildStr - case in@Pattern.StrPart.NamedChar(n) => + case in @ Pattern.StrPart.NamedChar(n) => if (keep(n)) in else Pattern.StrPart.WildChar }) case Pattern.ListPat(items) => Pattern.ListPat(items.map { case Pattern.ListPart.WildList => Pattern.ListPart.WildList - case in@Pattern.ListPart.NamedList(n) => + case in @ Pattern.ListPart.NamedList(n) => if (keep(n)) in else Pattern.ListPart.WildList case Pattern.ListPart.Item(p) => @@ -172,23 +202,27 @@ sealed abstract class Pattern[+N, +T] { Pattern.Union(h.filterVars(keep), t.map(_.filterVars(keep))) } - /** - * a collision happens when the same binding happens twice - * not separated by a union - */ + /** a collision happens when the same binding happens twice not separated by a + * union + */ def collisionBinds: List[Bindable] = { def loop(pat: Pattern[N, T]): (Set[Bindable], List[Bindable]) = pat match { case Pattern.WildCard | Pattern.Literal(_) => (Set.empty, Nil) - case Pattern.Var(v) => (Set(v), Nil) + case Pattern.Var(v) => (Set(v), Nil) case Pattern.Named(v, p) => val (s1, l1) = loop(p) if (s1(v)) (s1, v :: l1) else (s1 + v, l1) case Pattern.StrPat(items) => items.foldLeft((Set.empty[Bindable], List.empty[Bindable])) { - case (res, Pattern.StrPart.WildStr | Pattern.StrPart.WildChar | Pattern.StrPart.LitStr(_)) => res + case ( + res, + Pattern.StrPart.WildStr | Pattern.StrPart.WildChar | + Pattern.StrPart.LitStr(_) + ) => + res case ((s1, l1), Pattern.StrPart.NamedStr(v)) => if (s1(v)) (s1, v :: l1) else (s1 + v, l1) @@ -223,32 +257,30 @@ sealed abstract class Pattern[+N, +T] { loop(this)._2.distinct.sorted } - /** - * @return the type if we can directly see it + /** @return + * the type if we can directly see it */ def simpleTypeOf: Option[T] = this match { - case Pattern.Named(_, p) => p.simpleTypeOf + case Pattern.Named(_, p) => p.simpleTypeOf case Pattern.Annotation(_, t) => Some(t) case Pattern.Union(_, _) | Pattern.ListPat(_) | Pattern.Literal(_) | - Pattern.WildCard | Pattern.Var(_) | Pattern.StrPat(_) | - Pattern.PositionalStruct(_, _) => None + Pattern.WildCard | Pattern.Var(_) | Pattern.StrPat(_) | + Pattern.PositionalStruct(_, _) => + None } } object Pattern { - /** - * Represents the different patterns that are all for structs - * (2, 3) - * Foo(2, 3) - * etc... - */ + /** Represents the different patterns that are all for structs (2, 3) Foo(2, + * 3) etc... + */ sealed abstract class StructKind { def namedStyle: Option[StructKind.Style] = this match { - case StructKind.Tuple => None - case StructKind.Named(_, style) => Some(style) + case StructKind.Tuple => None + case StructKind.Named(_, style) => Some(style) case StructKind.NamedPartial(_, style) => Some(style) } } @@ -276,7 +308,8 @@ object Pattern { // Represents a complete tuple-like pattern Foo(a, b) final case class Named(name: Constructor, style: Style) extends NamedKind // Represents a partial tuple-like pattern Foo(a, ...) - final case class NamedPartial(name: Constructor, style: Style) extends NamedKind + final case class NamedPartial(name: Constructor, style: Style) + extends NamedKind } sealed abstract class StrPart @@ -297,17 +330,18 @@ object Pattern { def document(q: Char): Document[StrPart] = Document.instance { - case WildStr => wildDoc + case WildStr => wildDoc case WildChar => wildCharDoc - case NamedStr(b) => prefix + Document[Bindable].document(b) + Doc.char('}') - case NamedChar(b) => prefixChar + Document[Bindable].document(b) + Doc.char('}') + case NamedStr(b) => + prefix + Document[Bindable].document(b) + Doc.char('}') + case NamedChar(b) => + prefixChar + Document[Bindable].document(b) + Doc.char('}') case LitStr(s) => Doc.text(StringUtil.escape(q, s)) } } - /** - * represents items in a list pattern - */ + /** represents items in a list pattern + */ sealed abstract class ListPart[+A] { def map[B](fn: A => B): ListPart[B] } @@ -322,30 +356,29 @@ object Pattern { } } - /** - * This will match any list without any binding - */ + /** This will match any list without any binding + */ val AnyList: Pattern[Nothing, Nothing] = Pattern.ListPat(ListPart.WildList :: Nil) type Parsed = Pattern[StructKind, TypeRef] - /** - * Flatten a pattern out such that there are no top-level - * unions - */ + /** Flatten a pattern out such that there are no top-level unions + */ def flatten[N, T](p: Pattern[N, T]): NonEmptyList[Pattern[N, T]] = p match { case Union(h, t) => NonEmptyList(h, t.toList).flatMap(flatten(_)) - case nonU => NonEmptyList.one(nonU) + case nonU => NonEmptyList.one(nonU) } - /** - * Create a normalized pattern, which doesn't have nested top level unions - */ - def union[N, T](head: Pattern[N, T], tail: List[Pattern[N, T]]): Pattern[N, T] = { + /** Create a normalized pattern, which doesn't have nested top level unions + */ + def union[N, T]( + head: Pattern[N, T], + tail: List[Pattern[N, T]] + ): Pattern[N, T] = { NonEmptyList(head, tail).flatMap(flatten(_)) match { - case NonEmptyList(h, Nil) => h + case NonEmptyList(h, Nil) => h case NonEmptyList(h0, h1 :: tail) => Union(h0, NonEmptyList(h1, tail)) } } @@ -355,26 +388,36 @@ object Pattern { traversePattern[F, N, T1]( { (n, args) => args.map(PositionalStruct(n, _)) }, fn, - { parts => parts.map(ListPat(_)) }) - - def traverseStruct[F[_]: Applicative, N1](parts: (N, F[List[Pattern[N1, T]]]) => F[Pattern[N1, T]]): F[Pattern[N1, T]] = - traversePattern[F, N1, T](parts, Applicative[F].pure(_), { parts => parts.map(ListPat(_)) }) - - def mapStruct[N1](parts: (N, List[Pattern[N1, T]]) => Pattern[N1, T]): Pattern[N1, T] = + { parts => parts.map(ListPat(_)) } + ) + + def traverseStruct[F[_]: Applicative, N1]( + parts: (N, F[List[Pattern[N1, T]]]) => F[Pattern[N1, T]] + ): F[Pattern[N1, T]] = + traversePattern[F, N1, T]( + parts, + Applicative[F].pure(_), + { parts => parts.map(ListPat(_)) } + ) + + def mapStruct[N1]( + parts: (N, List[Pattern[N1, T]]) => Pattern[N1, T] + ): Pattern[N1, T] = traverseStruct[cats.Id, N1](parts) def traversePattern[F[_]: Applicative, N1, T1]( - parts: (N, F[List[Pattern[N1, T1]]]) => F[Pattern[N1, T1]], - tpeFn: T => F[T1], - listFn: F[List[ListPart[Pattern[N1, T1]]]] => F[Pattern[N1, T1]]): F[Pattern[N1, T1]] = { + parts: (N, F[List[Pattern[N1, T1]]]) => F[Pattern[N1, T1]], + tpeFn: T => F[T1], + listFn: F[List[ListPart[Pattern[N1, T1]]]] => F[Pattern[N1, T1]] + ): F[Pattern[N1, T1]] = { lazy val pwild: F[Pattern[N1, T1]] = Applicative[F].pure(Pattern.WildCard) def go(pat: Pattern[N, T]): F[Pattern[N1, T1]] = pat match { - case Pattern.WildCard => pwild + case Pattern.WildCard => pwild case Pattern.Literal(lit) => Applicative[F].pure(Pattern.Literal(lit)) - case Pattern.Var(v) => Applicative[F].pure(Pattern.Var(v)) - case Pattern.StrPat(s) => Applicative[F].pure(Pattern.StrPat(s)) + case Pattern.Var(v) => Applicative[F].pure(Pattern.Var(v)) + case Pattern.StrPat(s) => Applicative[F].pure(Pattern.StrPat(s)) case Pattern.Named(v, p) => go(p).map(Pattern.Named(v, _)) case Pattern.ListPat(items) => @@ -401,14 +444,17 @@ object Pattern { } } - implicit class FoldablePattern[F[_], N, T](private val pats: F[Pattern[N, T]]) extends AnyVal { - def patternNames(implicit F: Foldable[F]): List[Bindable] = F.toList(pats).flatMap(_.names) + implicit class FoldablePattern[F[_], N, T](private val pats: F[Pattern[N, T]]) + extends AnyVal { + def patternNames(implicit F: Foldable[F]): List[Bindable] = + F.toList(pats).flatMap(_.names) } case object WildCard extends Pattern[Nothing, Nothing] case class Literal(toLit: Lit) extends Pattern[Nothing, Nothing] case class Var(name: Bindable) extends Pattern[Nothing, Nothing] - case class StrPat(parts: NonEmptyList[StrPart]) extends Pattern[Nothing, Nothing] { + case class StrPat(parts: NonEmptyList[StrPart]) + extends Pattern[Nothing, Nothing] { def isEmpty: Boolean = this == StrPat.Empty lazy val isTotal: Boolean = { @@ -416,7 +462,7 @@ object Pattern { !parts.exists { case LitStr(_) | WildChar | NamedChar(_) => true - case _ => false + case _ => false } } @@ -434,28 +480,37 @@ object Pattern { isTotal || matcher(str).isDefined } - /** - * Patterns like Some(_) as foo - * as binds tighter than |, so use ( ) with groups you want to bind - */ - case class Named[N, T](name: Bindable, pat: Pattern[N, T]) extends Pattern[N, T] - case class ListPat[N, T](parts: List[ListPart[Pattern[N, T]]]) extends Pattern[N, T] { + /** Patterns like Some(_) as foo as binds tighter than |, so use ( ) with + * groups you want to bind + */ + case class Named[N, T](name: Bindable, pat: Pattern[N, T]) + extends Pattern[N, T] + case class ListPat[N, T](parts: List[ListPart[Pattern[N, T]]]) + extends Pattern[N, T] { lazy val toNamedSeqPattern: NamedSeqPattern[Pattern[N, T]] = ListPat.toNamedSeqPattern(this) lazy val toSeqPattern: SeqPattern[Pattern[N, T]] = toNamedSeqPattern.unname - def toPositionalStruct(empty: N, cons: N): Either[(ListPart.Glob, NonEmptyList[ListPart[Pattern[N, T]]]), Pattern[N, T]] = { - def loop(parts: List[ListPart[Pattern[N, T]]]): Either[(ListPart.Glob, NonEmptyList[ListPart[Pattern[N, T]]]), Pattern[N, T]] = + def toPositionalStruct(empty: N, cons: N): Either[ + (ListPart.Glob, NonEmptyList[ListPart[Pattern[N, T]]]), + Pattern[N, T] + ] = { + def loop( + parts: List[ListPart[Pattern[N, T]]] + ): Either[(ListPart.Glob, NonEmptyList[ListPart[Pattern[N, T]]]), Pattern[ + N, + T + ]] = parts match { - case Nil => Right(PositionalStruct(empty, Nil)) + case Nil => Right(PositionalStruct(empty, Nil)) case ListPart.WildList :: Nil => Right(WildCard) case ListPart.NamedList(glob) :: Nil => Right(Var(glob)) - case ListPart.Item(p) :: tail => + case ListPart.Item(p) :: tail => // we can always make some progress here val tailPat = loop(tail).toOption.getOrElse(ListPat(tail)) Right(PositionalStruct(cons, List(p, tailPat))) - case (l@ListPart.WildList) :: (r@ListPart.Item(WildCard)) :: t => + case (l @ ListPart.WildList) :: (r @ ListPart.Item(WildCard)) :: t => // we can switch *_, _ with _, *_ loop(r :: l :: t) case (glob: ListPart.Glob) :: h1 :: t => @@ -466,9 +521,12 @@ object Pattern { loop(parts) } } - case class Annotation[N, T](pattern: Pattern[N, T], tpe: T) extends Pattern[N, T] - case class PositionalStruct[N, T](name: N, params: List[Pattern[N, T]]) extends Pattern[N, T] - case class Union[N, T](head: Pattern[N, T], rest: NonEmptyList[Pattern[N, T]]) extends Pattern[N, T] { + case class Annotation[N, T](pattern: Pattern[N, T], tpe: T) + extends Pattern[N, T] + case class PositionalStruct[N, T](name: N, params: List[Pattern[N, T]]) + extends Pattern[N, T] + case class Union[N, T](head: Pattern[N, T], rest: NonEmptyList[Pattern[N, T]]) + extends Pattern[N, T] { def split: (Pattern[N, T], Pattern[N, T]) = { // we have at least two patterns here val pats = head :: rest.flatMap(flatten(_)) @@ -492,7 +550,10 @@ object Pattern { def fromSeqPattern[N, T](sp: SeqPattern[Pattern[N, T]]): ListPat[N, T] = { @annotation.tailrec - def loop(ps: List[SeqPart[Pattern[N, T]]], front: List[ListPart[Pattern[N, T]]]): List[ListPart[Pattern[N, T]]] = + def loop( + ps: List[SeqPart[Pattern[N, T]]], + front: List[ListPart[Pattern[N, T]]] + ): List[ListPart[Pattern[N, T]]] = ps match { case Nil => front.reverse case SeqPart.Lit(p) :: tail => @@ -509,19 +570,25 @@ object Pattern { ListPat(loop(sp.toList, Nil)) } - def toNamedSeqPattern[N, T](lp: ListPat[N, T]): NamedSeqPattern[Pattern[N, T]] = { - def partToNsp(lp: ListPart[Pattern[N, T]]): NamedSeqPattern[Pattern[N, T]] = + def toNamedSeqPattern[N, T]( + lp: ListPat[N, T] + ): NamedSeqPattern[Pattern[N, T]] = { + def partToNsp( + lp: ListPart[Pattern[N, T]] + ): NamedSeqPattern[Pattern[N, T]] = lp match { case ListPart.Item(WildCard) => NamedSeqPattern.Any - case ListPart.Item(p) => NamedSeqPattern.fromLit(p) - case ListPart.WildList => NamedSeqPattern.Wild + case ListPart.Item(p) => NamedSeqPattern.fromLit(p) + case ListPart.WildList => NamedSeqPattern.Wild case ListPart.NamedList(n) => NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Wild) } - def loop(lp: List[ListPart[Pattern[N, T]]]): NamedSeqPattern[Pattern[N, T]] = + def loop( + lp: List[ListPart[Pattern[N, T]]] + ): NamedSeqPattern[Pattern[N, T]] = lp match { - case Nil => NamedSeqPattern.NEmpty + case Nil => NamedSeqPattern.NEmpty case h :: Nil => partToNsp(h) case h :: t => NamedSeqPattern.NCat(partToNsp(h), loop(t)) @@ -540,7 +607,10 @@ object Pattern { if (rev.isEmpty) Nil else StrPart.LitStr(rev.reverse.mkString) :: Nil - def loop(ps: List[SeqPart[Char]], front: List[Char]): NonEmptyList[StrPart] = + def loop( + ps: List[SeqPart[Char]], + front: List[Char] + ): NonEmptyList[StrPart] = ps match { case Nil => NonEmptyList.fromList(lit(front)).getOrElse(Empty.parts) case SeqPart.Lit(c) :: tail => @@ -562,7 +632,7 @@ object Pattern { else tr.prepend(StrPart.WildStr) NonEmptyList.fromList(lit(front)) match { - case None => tailRes + case None => tailRes case Some(h) => h ::: tailRes } } @@ -579,17 +649,18 @@ object Pattern { NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Wild) case StrPart.NamedChar(n) => NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Any) - case StrPart.WildStr => NamedSeqPattern.Wild + case StrPart.WildStr => NamedSeqPattern.Wild case StrPart.WildChar => NamedSeqPattern.Any case StrPart.LitStr(s) => if (s.isEmpty) empty - else s.toList.foldRight(empty) { (c, tail) => - NamedSeqPattern.NCat(NamedSeqPattern.fromLit(c), tail) - } + else + s.toList.foldRight(empty) { (c, tail) => + NamedSeqPattern.NCat(NamedSeqPattern.fromLit(c), tail) + } } sp.parts.toList.foldRight(empty) { (h, t) => - NamedSeqPattern.NCat(partToNsp(h), t) + NamedSeqPattern.NCat(partToNsp(h), t) } } @@ -597,18 +668,13 @@ object Pattern { StrPat(NonEmptyList.one(StrPart.LitStr(s))) } - /** - * If this pattern is: - * x - * (x: T) - * unnamed as x - * x | x | x - * then it is "SinglyNamed" - */ + /** If this pattern is: x (x: T) unnamed as x x | x | x then it is + * "SinglyNamed" + */ object SinglyNamed { def unapply[N, T](p: Pattern[N, T]): Option[Bindable] = p match { - case Var(b) => Some(b) + case Var(b) => Some(b) case Annotation(SinglyNamed(b), _) => Some(b) case Named(b, inner) => if (inner.names.isEmpty) Some(b) @@ -621,7 +687,8 @@ object Pattern { } } - implicit def patternOrdering[N: Ordering, T: Ordering]: Ordering[Pattern[N, T]] = + implicit def patternOrdering[N: Ordering, T: Ordering] + : Ordering[Pattern[N, T]] = new Ordering[Pattern[N, T]] { val ordN = implicitly[Ordering[N]] val ordT = implicitly[Ordering[T]] @@ -631,14 +698,14 @@ object Pattern { new Ordering[ListPart[A]] { def compare(a: ListPart[A], b: ListPart[A]) = (a, b) match { - case (ListPart.WildList, ListPart.WildList) => 0 - case (ListPart.WildList, _) => -1 + case (ListPart.WildList, ListPart.WildList) => 0 + case (ListPart.WildList, _) => -1 case (ListPart.NamedList(_), ListPart.WildList) => 1 case (ListPart.NamedList(a), ListPart.NamedList(b)) => ordBin.compare(a, b) case (ListPart.NamedList(_), ListPart.Item(_)) => -1 case (ListPart.Item(a), ListPart.Item(b)) => ordA.compare(a, b) - case (ListPart.Item(_), _) => 1 + case (ListPart.Item(_), _) => 1 } } @@ -649,19 +716,19 @@ object Pattern { def compare(a: StrPart, b: StrPart) = (a, b) match { - case (WildStr, WildStr) => 0 - case (WildStr, _) => -1 - case (WildChar, WildStr) => 1 - case (WildChar, WildChar) => 0 - case (WildChar, _) => -1 - case (LitStr(_), WildStr | WildChar) => 1 - case (LitStr(sa), LitStr(sb)) => sa.compareTo(sb) + case (WildStr, WildStr) => 0 + case (WildStr, _) => -1 + case (WildChar, WildStr) => 1 + case (WildChar, WildChar) => 0 + case (WildChar, _) => -1 + case (LitStr(_), WildStr | WildChar) => 1 + case (LitStr(sa), LitStr(sb)) => sa.compareTo(sb) case (LitStr(_), NamedStr(_) | NamedChar(_)) => -1 case (NamedChar(_), WildStr | WildChar | LitStr(_)) => 1 case (NamedChar(na), NamedChar(nb)) => ordBin.compare(na, nb) - case (NamedChar(_), NamedStr(_)) => -1 - case (NamedStr(na), NamedStr(nb)) => ordBin.compare(na, nb) - case (NamedStr(_), _) => 1 + case (NamedChar(_), NamedStr(_)) => -1 + case (NamedStr(na), NamedStr(nb)) => ordBin.compare(na, nb) + case (NamedStr(_), _) => 1 } } val strOrd = ListOrdering.onType(ordStrPart) @@ -670,30 +737,34 @@ object Pattern { def compare(a: Pattern[N, T], b: Pattern[N, T]): Int = (a, b) match { - case (WildCard, WildCard) => 0 - case (WildCard, _) => -1 - case (Literal(_), WildCard) => 1 - case (Literal(a), Literal(b)) => Lit.litOrdering.compare(a, b) - case (Literal(_), _) => -1 + case (WildCard, WildCard) => 0 + case (WildCard, _) => -1 + case (Literal(_), WildCard) => 1 + case (Literal(a), Literal(b)) => Lit.litOrdering.compare(a, b) + case (Literal(_), _) => -1 case (Var(_), WildCard | Literal(_)) => 1 - case (Var(a), Var(b)) => compIdent.compare(a, b) - case (Var(_), _) => -1 + case (Var(a), Var(b)) => compIdent.compare(a, b) + case (Var(_), _) => -1 case (Named(_, _), WildCard | Literal(_) | Var(_)) => 1 case (Named(n1, p1), Named(n2, p2)) => val c = compIdent.compare(n1, n2) if (c == 0) compare(p1, p2) else c - case (Named(_, _), _) => -1 + case (Named(_, _), _) => -1 case (StrPat(_), WildCard | Literal(_) | Var(_) | Named(_, _)) => 1 case (StrPat(as), StrPat(bs)) => strOrd.compare(as.toList, bs.toList) - case (StrPat(_), _) => -1 - case (ListPat(_), WildCard | Literal(_) | Var(_) | Named(_, _) | StrPat(_)) => 1 + case (StrPat(_), _) => -1 + case ( + ListPat(_), + WildCard | Literal(_) | Var(_) | Named(_, _) | StrPat(_) + ) => + 1 case (ListPat(as), ListPat(bs)) => listE.compare(as, bs) - case (ListPat(_), _) => -1 + case (ListPat(_), _) => -1 case (Annotation(_, _), PositionalStruct(_, _) | Union(_, _)) => -1 case (Annotation(a0, t0), Annotation(a1, t1)) => val c = compare(a0, a1) if (c == 0) ordT.compare(t0, t1) else c - case (Annotation(_, _), _) => 1 + case (Annotation(_, _), _) => 1 case (PositionalStruct(_, _), Union(_, _)) => -1 case (PositionalStruct(n0, a0), PositionalStruct(n1, a1)) => val c = ordN.compare(n0, n1) @@ -707,32 +778,39 @@ object Pattern { implicit def document[T: Document]: Document[Pattern[StructKind, T]] = Document.instance[Pattern[StructKind, T]] { - case WildCard => Doc.char('_') - case Literal(lit) => Document[Lit].document(lit) - case Var(n) => Document[Identifier].document(n) - case Named(n, u@Union(_, _)) => + case WildCard => Doc.char('_') + case Literal(lit) => Document[Lit].document(lit) + case Var(n) => Document[Identifier].document(n) + case Named(n, u @ Union(_, _)) => // union is also an operator, so we need to use parens to explicitly bind | more tightly // than the @ on the left. - Doc.char('(') + document.document(u) + Doc.char(')') + Doc.text(" as ") + Document[Identifier].document(n) + Doc.char('(') + document.document(u) + Doc.char(')') + Doc.text( + " as " + ) + Document[Identifier].document(n) case Named(n, p) => - document.document(p) + Doc.text(" as ") + Document[Identifier].document(n) + document.document(p) + Doc.text(" as ") + Document[Identifier].document( + n + ) case StrPat(items) => // prefer ' if possible, else use " val useDouble = items.exists { case StrPart.LitStr(str) => str.contains('\'') && !str.contains('"') - case _ => false + case _ => false } val q = if (useDouble) '"' else '\'' val sd = StrPart.document(q) val inner = Doc.intercalate(Doc.empty, items.toList.map(sd.document(_))) Doc.char(q) + inner + Doc.char(q) case ListPat(items) => - Doc.char('[') + Doc.intercalate(Doc.text(", "), + Doc.char('[') + Doc.intercalate( + Doc.text(", "), items.map { case ListPart.WildList => Doc.text("*_") - case ListPart.NamedList(glob) => Doc.char('*') + Document[Identifier].document(glob) + case ListPart.NamedList(glob) => + Doc.char('*') + Document[Identifier].document(glob) case ListPart.Item(p) => document.document(p) - }) + Doc.char(']') + } + ) + Doc.char(']') case Annotation(p, t) => /* * We need to know what package we are in and what imports we depend on here. @@ -783,13 +861,16 @@ object Pattern { // of fields here val cspace = Doc.text(": ") val identDoc = Document[Identifier] - val kvargs = Doc.intercalate(Doc.text(", "), - fields.toList.zip(args) + val kvargs = Doc.intercalate( + Doc.text(", "), + fields.toList + .zip(args) .map { case (StructKind.Style.FieldKind.Explicit(n), adoc) => identDoc.document(n) + cspace + adoc case (StructKind.Style.FieldKind.Implicit(_), adoc) => adoc - }) + } + ) prefix + Doc.text(" {") + kvargs + @@ -810,42 +891,49 @@ object Pattern { } def recordPat[N <: StructKind.NamedKind]( - name: Constructor, - args: NonEmptyList[Either[Bindable, (Bindable, Parsed)]])( - fn: (Constructor, StructKind.Style) => N): PositionalStruct[StructKind, TypeRef] = { + name: Constructor, + args: NonEmptyList[Either[Bindable, (Bindable, Parsed)]] + )( + fn: (Constructor, StructKind.Style) => N + ): PositionalStruct[StructKind, TypeRef] = { val fields = args.map { - case Left(b) => StructKind.Style.FieldKind.Implicit(b) + case Left(b) => StructKind.Style.FieldKind.Implicit(b) case Right((b, _)) => StructKind.Style.FieldKind.Explicit(b) } val structArgs = args.toList.map { - case Left(b) => Pattern.Var(b) + case Left(b) => Pattern.Var(b) case Right((_, p)) => p } - PositionalStruct( - fn(name, StructKind.Style.RecordLike(fields)), - structArgs) + PositionalStruct(fn(name, StructKind.Style.RecordLike(fields)), structArgs) } - def compiledDocument[A: Document]: Document[Pattern[(PackageName, Constructor), A]] = { - lazy val doc: Document[Pattern[(PackageName, Constructor), A]] = compiledDocument[A] + def compiledDocument[A: Document] + : Document[Pattern[(PackageName, Constructor), A]] = { + lazy val doc: Document[Pattern[(PackageName, Constructor), A]] = + compiledDocument[A] Document.instance[Pattern[(PackageName, Constructor), A]] { - case WildCard => Doc.char('_') - case Literal(lit) => Document[Lit].document(lit) - case Var(n) => Document[Identifier].document(n) - case Named(n, u@Union(_, _)) => + case WildCard => Doc.char('_') + case Literal(lit) => Document[Lit].document(lit) + case Var(n) => Document[Identifier].document(n) + case Named(n, u @ Union(_, _)) => // union is also an operator, so we need to use parens to explicitly bind | more tightly // than the as on the left. - Doc.char('(') + doc.document(u) + Doc.char(')') + Doc.text(" as ") + Document[Identifier].document(n) + Doc.char('(') + doc.document(u) + Doc.char(')') + Doc.text( + " as " + ) + Document[Identifier].document(n) case Named(n, p) => doc.document(p) + Doc.text(" as ") + Document[Identifier].document(n) case StrPat(items) => document.document(StrPat(items)) case ListPat(items) => - Doc.char('[') + Doc.intercalate(Doc.text(", "), + Doc.char('[') + Doc.intercalate( + Doc.text(", "), items.map { case ListPart.WildList => Doc.text("*_") - case ListPart.NamedList(glob) => Doc.char('*') + Document[Identifier].document(glob) + case ListPart.NamedList(glob) => + Doc.char('*') + Document[Identifier].document(glob) case ListPart.Item(p) => doc.document(p) - }) + Doc.char(']') + } + ) + Doc.char(']') case Annotation(p, t) => /* * We need to know what package we are in and what imports we depend on here. @@ -856,12 +944,20 @@ object Pattern { * case */ doc.document(p) + Doc.text(": ") + Document[A].document(t) - case ps@PositionalStruct((_, c), a) => - def untuple(p: Pattern[(PackageName, Constructor), A]): Option[List[Doc]] = + case ps @ PositionalStruct((_, c), a) => + def untuple( + p: Pattern[(PackageName, Constructor), A] + ): Option[List[Doc]] = p match { - case PositionalStruct((PackageName.PredefName, Constructor("Unit")), Nil) => + case PositionalStruct( + (PackageName.PredefName, Constructor("Unit")), + Nil + ) => Some(Nil) - case PositionalStruct((PackageName.PredefName, Constructor("TupleCons")), a :: b :: Nil) => + case PositionalStruct( + (PackageName.PredefName, Constructor("TupleCons")), + a :: b :: Nil + ) => untuple(b).map { l => doc.document(a) :: l } case _ => None } @@ -875,7 +971,7 @@ object Pattern { case None => val args = a match { case Nil => Doc.empty - case _ => tup(a.map(doc.document(_))) + case _ => tup(a.map(doc.document(_))) } Doc.text(c.asString) + args } @@ -893,26 +989,27 @@ object Pattern { } } - /** - * For fully typed patterns, compute the type environment of the bindings - * from this pattern. This will sys.error if you pass a bad pattern, which - * you should never do (and this code will never do unless there is some - * broken invariant) - */ - def envOf[C, K, T](p: Pattern[C, T], env: Map[K, T])(kfn: Identifier => K): Map[K, T] = { + /** For fully typed patterns, compute the type environment of the bindings + * from this pattern. This will sys.error if you pass a bad pattern, which + * you should never do (and this code will never do unless there is some + * broken invariant) + */ + def envOf[C, K, T](p: Pattern[C, T], env: Map[K, T])( + kfn: Identifier => K + ): Map[K, T] = { def update(env: Map[K, T], n: Identifier, typeOf: Option[T]): Map[K, T] = - typeOf match { - case None => - // $COVERAGE-OFF$ should be unreachable - sys.error(s"no type found for $n in $p") - // $COVERAGE-ON$ should be unreachable - case Some(t) => env.updated(kfn(n), t) - } + typeOf match { + case None => + // $COVERAGE-OFF$ should be unreachable + sys.error(s"no type found for $n in $p") + // $COVERAGE-ON$ should be unreachable + case Some(t) => env.updated(kfn(n), t) + } def loop(p0: Pattern[C, T], typeOf: Option[T], env: Map[K, T]): Map[K, T] = p0 match { - case WildCard => env + case WildCard => env case Literal(_) => env - case Var(n) => update(env, n, typeOf) + case Var(n) => update(env, n, typeOf) case Named(n, p1) => val e1 = loop(p1, typeOf, env) update(e1, n, typeOf) @@ -921,12 +1018,12 @@ object Pattern { items .foldLeft(env) { case (env, StrPart.NamedStr(n)) => update(env, n, typeOf) - case (env, _) => env + case (env, _) => env } case ListPat(items) => items.foldLeft(env) { - case (env, ListPart.WildList) => env + case (env, ListPart.WildList) => env case (env, ListPart.NamedList(n)) => // the type of a named sub-list is // the same as the type of the list @@ -949,10 +1046,10 @@ object Pattern { private[this] val plit: P[Pattern[Nothing, Nothing]] = { val intp = (Lit.integerParser | Lit.codePointParser).map(Literal(_)) val startStr = P.string("${").as { (opt: Option[Bindable]) => - opt.fold(StrPart.WildStr: StrPart)(StrPart.NamedStr(_)) + opt.fold(StrPart.WildStr: StrPart)(StrPart.NamedStr(_)) } val startChar = P.string("$.{").as { (opt: Option[Bindable]) => - opt.fold(StrPart.WildChar: StrPart)(StrPart.NamedChar(_)) + opt.fold(StrPart.WildChar: StrPart)(StrPart.NamedChar(_)) } val start = startStr | startChar val end = P.char('}') @@ -962,37 +1059,38 @@ object Pattern { val part: P[Option[Bindable]] = pwild | pname def strp(q: Char): P[List[StrPart]] = - StringUtil.interpolatedString(q, start, part, end) + StringUtil + .interpolatedString(q, start, part, end) .map(_.map { - case Left(p) => p + case Left(p) => p case Right((_, str)) => StrPart.LitStr(str) }) val eitherString = strp('\'') <+> strp('"') // don't emit complex patterns for simple strings: val str = eitherString.map { - case Nil => Literal(Lit.EmptyStr) + case Nil => Literal(Lit.EmptyStr) case StrPart.LitStr(str) :: Nil => Literal(Lit.Str(str)) - case h :: tail => StrPat(NonEmptyList(h, tail)) + case h :: tail => StrPat(NonEmptyList(h, tail)) } str <+> intp } - /** - * This does not allow a top-level type annotation which would be ambiguous - * with : used for ending the match case block - */ + /** This does not allow a top-level type annotation which would be ambiguous + * with : used for ending the match case block + */ val matchParser: P[Parsed] = P.defer(matchOrNot(isMatch = true)) - /** - * A Pattern in a match position allows top level un-parenthesized type annotation - */ + /** A Pattern in a match position allows top level un-parenthesized type + * annotation + */ val bindParser: P[Parsed] = P.defer(matchOrNot(isMatch = false)) - private val maybePartial: P0[(Constructor, StructKind.Style) => StructKind.NamedKind] = { + private val maybePartial + : P0[(Constructor, StructKind.Style) => StructKind.NamedKind] = { val partial = (maybeSpace.soft ~ P.string("...")).as( { (n: Constructor, s: StructKind.Style) => StructKind.NamedPartial(n, s) } ) @@ -1004,24 +1102,30 @@ object Pattern { partial.orElse(notPartial) } - private def parseRecordStruct(recurse: P0[Parsed]): P[Constructor => PositionalStruct[StructKind, TypeRef]] = { + private def parseRecordStruct( + recurse: P0[Parsed] + ): P[Constructor => PositionalStruct[StructKind, TypeRef]] = { // We do maybeSpace, then { } then either a Bindable or Bindable: Pattern // maybe followed by ... val item: P[Either[Bindable, (Bindable, Parsed)]] = - (Identifier.bindableParser ~ ((maybeSpace.soft ~ P.char(':') ~ maybeSpace) *> recurse).?) + (Identifier.bindableParser ~ ((maybeSpace.soft ~ P.char( + ':' + ) ~ maybeSpace) *> recurse).?) .map { - case (b, None) => Left(b) + case (b, None) => Left(b) case (b, Some(pat)) => Right((b, pat)) } val items = item.nonEmptyList ~ maybePartial - ((maybeSpace.with1.soft ~ P.char('{') ~ maybeSpace) *> items <* (maybeSpace ~ P.char('}'))) - .map { case (args, fn) => - { (c: Constructor) => recordPat(c, args)(fn) } - } + ((maybeSpace.with1.soft ~ P.char( + '{' + ) ~ maybeSpace) *> items <* (maybeSpace ~ P.char('}'))) + .map { case (args, fn) => { (c: Constructor) => recordPat(c, args)(fn) } } } - private def parseTupleStruct(recurse: P[Parsed]): P[Constructor => PositionalStruct[StructKind, TypeRef]] = { + private def parseTupleStruct( + recurse: P[Parsed] + ): P[Constructor => PositionalStruct[StructKind, TypeRef]] = { // There are three cases: // Foo(1 or more patterns) // Foo(1 or more patterns, ...) @@ -1029,13 +1133,19 @@ object Pattern { val oneOrMore = recurse.nonEmptyList.map(_.toList) ~ maybePartial val onlyPartial = P.string("...").as { - (Nil, { (n: Constructor, s: StructKind.Style) => StructKind.NamedPartial(n, s) }) + ( + Nil, + { (n: Constructor, s: StructKind.Style) => + StructKind.NamedPartial(n, s) + } + ) } - (oneOrMore <+> onlyPartial) - .parensCut - .map { case (args, fn) => - { (n: Constructor) => PositionalStruct(fn(n, StructKind.Style.TupleLike), args) } + (oneOrMore <+> onlyPartial).parensCut + .map { + case (args, fn) => { (n: Constructor) => + PositionalStruct(fn(n, StructKind.Style.TupleLike), args) + } } } @@ -1045,30 +1155,35 @@ object Pattern { def isNonUnitTuple(arg: Parsed): Boolean = arg match { case PositionalStruct(StructKind.Tuple, args) => args.nonEmpty - case _ => false + case _ => false } def fromTupleOrParens(e: Either[Parsed, List[Parsed]]): Parsed = e match { - case Right(tup) => tuple(tup) + case Right(tup) => tuple(tup) case Left(parens) => parens } def fromMaybeTupleOrParens(p: MaybeTupleOrParens[Parsed]): Parsed = p match { - case MaybeTupleOrParens.Bare(b) => b + case MaybeTupleOrParens.Bare(b) => b case MaybeTupleOrParens.Parens(p) => p - case MaybeTupleOrParens.Tuple(p) => tuple(p) + case MaybeTupleOrParens.Tuple(p) => tuple(p) } private def matchOrNot(isMatch: Boolean): P[Parsed] = { val recurse = P.defer(bindParser) val positional = - (Identifier.consParser ~ (parseTupleStruct(recurse) <+> parseRecordStruct(recurse)).?) + (Identifier.consParser ~ (parseTupleStruct(recurse) <+> parseRecordStruct( + recurse + )).?) .map { case (n, None) => - PositionalStruct(StructKind.Named(n, StructKind.Style.TupleLike), Nil) + PositionalStruct( + StructKind.Named(n, StructKind.Style.TupleLike), + Nil + ) case (n, Some(fn)) => fn(n) } @@ -1088,10 +1203,16 @@ object Pattern { val pvar = Identifier.bindableParser.map(Var(_)) val nonAnnotated = - P.defer(P.oneOf(plit :: pwild :: tupleOrParens :: positional :: listP :: pvar :: Nil)) + P.defer( + P.oneOf( + plit :: pwild :: tupleOrParens :: positional :: listP :: pvar :: Nil + ) + ) val namedOp: P[Parsed => Parsed] = - ((maybeSpace.with1 *> P.string("as") <* Parser.spaces).backtrack *> Identifier.bindableParser) + ((maybeSpace.with1 *> P.string( + "as" + ) <* Parser.spaces).backtrack *> Identifier.bindableParser) .map { n => { (pat: Parsed) => Named(n, pat) } } @@ -1123,4 +1244,3 @@ object Pattern { else withAs.maybeAp(unionOp.orElse(typeAnnotOp)) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Predef.scala b/core/src/main/scala/org/bykn/bosatsu/Predef.scala index 017861e8e..93e6f5ddf 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Predef.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Predef.scala @@ -5,16 +5,16 @@ import java.math.BigInteger import language.experimental.macros object Predef { - /** - * Loads a file *at compile time* as a means of embedding - * external files into strings. This lets us avoid resources - * which compilicate matters for scalajs. - */ - private[bosatsu] def loadFileInCompile(file: String): String = macro Macro.loadFileInCompileImpl - - /** - * String representation of the predef - */ + + /** Loads a file *at compile time* as a means of embedding external files into + * strings. This lets us avoid resources which compilicate matters for + * scalajs. + */ + private[bosatsu] def loadFileInCompile(file: String): String = + macro Macro.loadFileInCompileImpl + + /** String representation of the predef + */ val predefString: String = loadFileInCompile("core/src/main/resources/bosatsu/predef.bosatsu") @@ -22,8 +22,7 @@ object Predef { PackageName.PredefName val jvmExternals: Externals = - Externals - .empty + Externals.empty .add(packageName, "add", FfiCall.Fn2(PredefImpl.add(_, _))) .add(packageName, "div", FfiCall.Fn2(PredefImpl.div(_, _))) .add(packageName, "sub", FfiCall.Fn2(PredefImpl.sub(_, _))) @@ -32,20 +31,52 @@ object Predef { .add(packageName, "cmp_Int", FfiCall.Fn2(PredefImpl.cmp_Int(_, _))) .add(packageName, "gcd_Int", FfiCall.Fn2(PredefImpl.gcd_Int(_, _))) .add(packageName, "mod_Int", FfiCall.Fn2(PredefImpl.mod_Int(_, _))) - .add(packageName, "shift_right_Int", FfiCall.Fn2(PredefImpl.shift_right_Int(_, _))) - .add(packageName, "shift_left_Int", FfiCall.Fn2(PredefImpl.shift_left_Int(_, _))) + .add( + packageName, + "shift_right_Int", + FfiCall.Fn2(PredefImpl.shift_right_Int(_, _)) + ) + .add( + packageName, + "shift_left_Int", + FfiCall.Fn2(PredefImpl.shift_left_Int(_, _)) + ) .add(packageName, "and_Int", FfiCall.Fn2(PredefImpl.and_Int(_, _))) .add(packageName, "or_Int", FfiCall.Fn2(PredefImpl.or_Int(_, _))) .add(packageName, "xor_Int", FfiCall.Fn2(PredefImpl.xor_Int(_, _))) .add(packageName, "not_Int", FfiCall.Fn1(PredefImpl.not_Int(_))) .add(packageName, "int_loop", FfiCall.Fn3(PredefImpl.intLoop(_, _, _))) - .add(packageName, "int_to_String", FfiCall.Fn1(PredefImpl.int_to_String(_))) + .add( + packageName, + "int_to_String", + FfiCall.Fn1(PredefImpl.int_to_String(_)) + ) .add(packageName, "trace", FfiCall.Fn2(PredefImpl.trace(_, _))) - .add(packageName, "string_Order_fn", FfiCall.Fn2(PredefImpl.string_Order_Fn(_, _))) - .add(packageName, "concat_String", FfiCall.Fn1(PredefImpl.concat_String(_))) - .add(packageName, "char_to_String", FfiCall.Fn1(PredefImpl.char_to_String(_))) - .add(packageName, "partition_String", FfiCall.Fn2(PredefImpl.partitionString(_, _))) - .add(packageName, "rpartition_String", FfiCall.Fn2(PredefImpl.rightPartitionString(_, _))) + .add( + packageName, + "string_Order_fn", + FfiCall.Fn2(PredefImpl.string_Order_Fn(_, _)) + ) + .add( + packageName, + "concat_String", + FfiCall.Fn1(PredefImpl.concat_String(_)) + ) + .add( + packageName, + "char_to_String", + FfiCall.Fn1(PredefImpl.char_to_String(_)) + ) + .add( + packageName, + "partition_String", + FfiCall.Fn2(PredefImpl.partitionString(_, _)) + ) + .add( + packageName, + "rpartition_String", + FfiCall.Fn2(PredefImpl.rightPartitionString(_, _)) + ) } object PredefImpl { @@ -55,7 +86,7 @@ object PredefImpl { private def i(a: Value): BigInteger = a match { case VInt(bi) => bi - case _ => sys.error(s"expected integer: $a") + case _ => sys.error(s"expected integer: $a") } def add(a: Value, b: Value): Value = @@ -128,13 +159,12 @@ object PredefImpl { val bi = b.intValue() val a1 = a.shiftRight(bi) if (b.compareTo(MaxIntBI) > 0) { - //$COVERAGE-OFF$ + // $COVERAGE-OFF$ // java bigInteger can't actually store arbitrarily large // integers, just blow up here sys.error(s"invalid huge shiftRight($a, $b)") - //$COVERAGE-ON$ - } - else { + // $COVERAGE-ON$ + } else { a1 } } @@ -146,13 +176,12 @@ object PredefImpl { val bi = b.intValue() val a1 = a.shiftLeft(bi) if (b.compareTo(MaxIntBI) > 0) { - //$COVERAGE-OFF$ + // $COVERAGE-OFF$ // java bigInteger can't actually store arbitrarily large // integers, just blow up here sys.error(s"invalid huge shiftLeft($a, $b)") - //$COVERAGE-ON$ - } - else { + // $COVERAGE-ON$ + } else { a1 } } @@ -172,7 +201,7 @@ object PredefImpl { def not_Int(a: Value): Value = VInt(i(a).not()) - //def intLoop(intValue: Int, state: a, fn: Int -> a -> Tuple2[Int, a]) -> a + // def intLoop(intValue: Int, state: a, fn: Int -> a -> Tuple2[Int, a]) -> a final def intLoop(intValue: Value, state: Value, fn: Value): Value = { val fnT = fn.asFn @@ -186,9 +215,9 @@ object PredefImpl { if (n.compareTo(bi) >= 0) { // we are done in this case nextA - } - else loop(nextI, n, nextA) - case other => sys.error(s"unexpected ill-typed value: at $bi, $state, $other") + } else loop(nextI, n, nextA) + case other => + sys.error(s"unexpected ill-typed value: at $bi, $state, $other") } } @@ -219,17 +248,16 @@ object PredefImpl { case Value.VList(parts) => Value.Str(parts.iterator.map { case Value.Str(s) => s - case other => - //$COVERAGE-OFF$ + case other => + // $COVERAGE-OFF$ sys.error(s"type error: $other") - //$COVERAGE-ON$ - } - .mkString) + // $COVERAGE-ON$ + }.mkString) case other => - //$COVERAGE-OFF$ + // $COVERAGE-OFF$ sys.error(s"type error: $other") - //$COVERAGE-ON$ + // $COVERAGE-ON$ } // return an Option[(String, String)] @@ -242,11 +270,12 @@ object PredefImpl { val idx = argS.indexOf(sepS) if (idx < 0) Value.VOption.none - else Value.VOption.some { - val left = argS.substring(0, idx) - val right = argS.substring(idx + sepS.length) - Value.Tuple(Value.ExternalValue(left), Value.ExternalValue(right)) - } + else + Value.VOption.some { + val left = argS.substring(0, idx) + val right = argS.substring(idx + sepS.length) + Value.Tuple(Value.ExternalValue(left), Value.ExternalValue(right)) + } } } @@ -258,12 +287,12 @@ object PredefImpl { val argS = arg.asExternal.toAny.asInstanceOf[String] val idx = argS.lastIndexOf(sepS) if (idx < 0) Value.VOption.none - else Value.VOption.some { - val left = argS.substring(0, idx) - val right = argS.substring(idx + sepS.length) - Value.Tuple(Value.ExternalValue(left), Value.ExternalValue(right)) - } + else + Value.VOption.some { + val left = argS.substring(0, idx) + val right = argS.substring(idx + sepS.length) + Value.Tuple(Value.ExternalValue(left), Value.ExternalValue(right)) + } } } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Program.scala b/core/src/main/scala/org/bykn/bosatsu/Program.scala index c1684fa80..f1729c74e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Program.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Program.scala @@ -3,10 +3,11 @@ package org.bykn.bosatsu import Identifier.Bindable case class Program[+T, +D, +S]( - types: T, - lets: List[(Bindable, RecursionKind, D)], - externalDefs: List[Bindable], - from: S) { + types: T, + lets: List[(Bindable, RecursionKind, D)], + externalDefs: List[Bindable], + from: S +) { private[this] lazy val letMap: Map[Bindable, (RecursionKind, D)] = lets.iterator.map { case (n, r, d) => (n, (r, d)) }.toMap diff --git a/core/src/main/scala/org/bykn/bosatsu/Referant.scala b/core/src/main/scala/org/bykn/bosatsu/Referant.scala index 4a5ad2d40..71cff11c3 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Referant.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Referant.scala @@ -6,20 +6,23 @@ import rankn.{ConstructorFn, DefinedType, Type, TypeEnv} import Identifier.{Constructor => ConstructorName} -/** - * A Referant is something that can be exported or imported after resolving - * Before resolving, imports and exports are just names. - */ +/** A Referant is something that can be exported or imported after resolving + * Before resolving, imports and exports are just names. + */ sealed abstract class Referant[+A] { // if this is a Constructor or DefinedT, return the associated DefinedType def definedType: Option[DefinedType[A]] = this match { - case Referant.Value(_) => None - case Referant.DefinedT(dt) => Some(dt) + case Referant.Value(_) => None + case Referant.DefinedT(dt) => Some(dt) case Referant.Constructor(dt, _) => Some(dt) } - def addTo[A1 >: A](packageName: PackageName, name: Identifier, te: TypeEnv[A1]): TypeEnv[A1] = + def addTo[A1 >: A]( + packageName: PackageName, + name: Identifier, + te: TypeEnv[A1] + ): TypeEnv[A1] = this match { case Referant.Value(t) => te.addExternalValue(packageName, name, t) @@ -33,71 +36,84 @@ sealed abstract class Referant[+A] { object Referant { case class Value(scheme: Type) extends Referant[Nothing] case class DefinedT[A](dtype: DefinedType[A]) extends Referant[A] - case class Constructor[A](dtype: DefinedType[A], fn: ConstructorFn) extends Referant[A] + case class Constructor[A](dtype: DefinedType[A], fn: ConstructorFn) + extends Referant[A] - private def imported[A, B, C](imps: List[Import[A, NonEmptyList[Referant[C]]]])(fn: PartialFunction[Referant[C], B]): Map[Identifier, B] = + private def imported[A, B, C]( + imps: List[Import[A, NonEmptyList[Referant[C]]]] + )(fn: PartialFunction[Referant[C], B]): Map[Identifier, B] = imps.foldLeft(Map.empty[Identifier, B]) { (m0, imp) => m0 ++ Import.locals(imp)(fn) } - def importedTypes[A, B](imps: List[Import[A, NonEmptyList[Referant[B]]]]): Map[Identifier, (PackageName, TypeName)] = - imported(imps) { - case Referant.DefinedT(dt) => (dt.packageName, dt.name) + def importedTypes[A, B]( + imps: List[Import[A, NonEmptyList[Referant[B]]]] + ): Map[Identifier, (PackageName, TypeName)] = + imported(imps) { case Referant.DefinedT(dt) => + (dt.packageName, dt.name) } - /** - * These are all the imported items that may be used in a match - */ - def importedConsNames[A, B](imps: List[Import[A, NonEmptyList[Referant[B]]]]): Map[Identifier, (PackageName, ConstructorName)] = - imported(imps) { - case Referant.Constructor(dt, fn) => (dt.packageName, fn.name) + /** These are all the imported items that may be used in a match + */ + def importedConsNames[A, B]( + imps: List[Import[A, NonEmptyList[Referant[B]]]] + ): Map[Identifier, (PackageName, ConstructorName)] = + imported(imps) { case Referant.Constructor(dt, fn) => + (dt.packageName, fn.name) } - /** - * Fully qualified original names - */ + /** Fully qualified original names + */ def fullyQualifiedImportedValues[A, B]( - imps: List[Import[A, NonEmptyList[Referant[B]]]])(nameOf: A => PackageName)(implicit ev: B <:< Kind.Arg): Map[(PackageName, Identifier), Type] = + imps: List[Import[A, NonEmptyList[Referant[B]]]] + )( + nameOf: A => PackageName + )(implicit ev: B <:< Kind.Arg): Map[(PackageName, Identifier), Type] = imps.iterator.flatMap { item => val pn = nameOf(item.pack) item.items.toList.iterator.flatMap { i => val orig = i.originalName val key = (pn, orig) i.tag.toList.iterator.collect { - case Referant.Value(t) => (key, t) + case Referant.Value(t) => (key, t) case Referant.Constructor(dt, fn) => (key, dt.fnTypeOf(fn)) } } - } - .toMap + }.toMap def typeConstructors[A, B]( - imps: List[Import[A, NonEmptyList[Referant[B]]]]): - Map[(PackageName, ConstructorName), (List[(Type.Var.Bound, B)], List[Type], Type.Const.Defined)] = { - val refs: Iterator[Referant[B]] = imps.iterator.flatMap(_.items.toList.iterator.flatMap(_.tag.toList)) + imps: List[Import[A, NonEmptyList[Referant[B]]]] + ): Map[ + (PackageName, ConstructorName), + (List[(Type.Var.Bound, B)], List[Type], Type.Const.Defined) + ] = { + val refs: Iterator[Referant[B]] = + imps.iterator.flatMap(_.items.toList.iterator.flatMap(_.tag.toList)) refs.collect { case Constructor(dt, fn) => - ((dt.packageName, fn.name), (dt.annotatedTypeParams, fn.args.map(_._2), dt.toTypeConst)) - } - .toMap + ( + (dt.packageName, fn.name), + (dt.annotatedTypeParams, fn.args.map(_._2), dt.toTypeConst) + ) + }.toMap } - /** - * Build the TypeEnv view of the given imports - */ - def importedTypeEnv[A, B](inps: List[Import[A, NonEmptyList[Referant[B]]]])(nameOf: A => PackageName): TypeEnv[B] = - inps.foldLeft((TypeEnv.empty): TypeEnv[B]) { - case (te, imps) => - val pack = nameOf(imps.pack) - imps.items.foldLeft(te) { (te, imp) => - val nm = imp.localName - imp.tag.foldLeft(te) { - case (te1, Referant.Value(t)) => - te1.addExternalValue(pack, nm, t) - case (te1, Referant.Constructor(dt, cf)) => - te1.addConstructor(pack, dt, cf) - case (te1, Referant.DefinedT(dt)) => - te1.addDefinedType(dt) - } + /** Build the TypeEnv view of the given imports + */ + def importedTypeEnv[A, B](inps: List[Import[A, NonEmptyList[Referant[B]]]])( + nameOf: A => PackageName + ): TypeEnv[B] = + inps.foldLeft((TypeEnv.empty): TypeEnv[B]) { case (te, imps) => + val pack = nameOf(imps.pack) + imps.items.foldLeft(te) { (te, imp) => + val nm = imp.localName + imp.tag.foldLeft(te) { + case (te1, Referant.Value(t)) => + te1.addExternalValue(pack, nm, t) + case (te1, Referant.Constructor(dt, cf)) => + te1.addConstructor(pack, dt, cf) + case (te1, Referant.DefinedT(dt)) => + te1.addDefinedType(dt) } + } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/SelfCallKind.scala b/core/src/main/scala/org/bykn/bosatsu/SelfCallKind.scala index 25d4d5604..769c27375 100644 --- a/core/src/main/scala/org/bykn/bosatsu/SelfCallKind.scala +++ b/core/src/main/scala/org/bykn/bosatsu/SelfCallKind.scala @@ -44,11 +44,12 @@ object SelfCallKind { private def isFn[A](n: Bindable, te: TypedExpr[A]): Boolean = te match { - case TypedExpr.Generic(_, in) => isFn(n, in) + case TypedExpr.Generic(_, in) => isFn(n, in) case TypedExpr.Annotation(te, _) => isFn(n, te) - case TypedExpr.Local(vn, _, _) => vn == n - case _ => false + case TypedExpr.Local(vn, _, _) => vn == n + case _ => false } + /** assuming expr is bound to nm, what kind of self call does it contain? */ def apply[A](n: Bindable, te: TypedExpr[A]): SelfCallKind = diff --git a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala index 2962c9311..13f0081ee 100644 --- a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala @@ -1,7 +1,7 @@ package org.bykn.bosatsu import cats.{Applicative, Traverse} -import cats.data.{ Chain, Ior, NonEmptyChain, NonEmptyList, State } +import cats.data.{Chain, Ior, NonEmptyChain, NonEmptyList, State} import org.bykn.bosatsu.rankn.{ParsedTypeEnv, Type, TypeEnv} import scala.collection.immutable.SortedSet import scala.collection.mutable.{Map => MMap} @@ -20,14 +20,14 @@ import Declaration._ import SourceConverter.{success, Result} -/** - * Convert a source types (a syntactic expression) into - * the internal representations - */ +/** Convert a source types (a syntactic expression) into the internal + * representations + */ final class SourceConverter( - thisPackage: PackageName, - imports: List[Import[PackageName, NonEmptyList[Referant[Kind.Arg]]]], - localDefs: List[TypeDefinitionStatement]) { + thisPackage: PackageName, + imports: List[Import[PackageName, NonEmptyList[Referant[Kind.Arg]]]], + localDefs: List[TypeDefinitionStatement] +) { /* * We should probably error for non-predef name collisions. * Maybe we should even error even or predef collisions that @@ -37,7 +37,8 @@ final class SourceConverter( private val localConstructors = localDefs.flatMap(_.constructors).toSet private val typeCache: MMap[Constructor, Type.Const] = MMap.empty - private val consCache: MMap[Constructor, (PackageName, Constructor)] = MMap.empty + private val consCache: MMap[Constructor, (PackageName, Constructor)] = + MMap.empty private val importedTypes: Map[Identifier, (PackageName, TypeName)] = Referant.importedTypes(imports) @@ -51,7 +52,10 @@ final class SourceConverter( val importedTypeEnv: TypeEnv[Kind.Arg] = Referant.importedTypeEnv(imports)(identity) - private def nameToType(c: Constructor, region: Region): Result[rankn.Type.Const] = + private def nameToType( + c: Constructor, + region: Region + ): Result[rankn.Type.Const] = typeCache.get(c) match { case Some(r) => success(r) case None => @@ -60,8 +64,7 @@ final class SourceConverter( val res = Type.Const.Defined(thisPackage, tc) typeCache.update(c, res) success(res) - } - else { + } else { importedTypes.get(c) match { case Some((p, t)) => val res = Type.Const.Defined(p, t) @@ -72,23 +75,28 @@ final class SourceConverter( val bestEffort = Type.Const.Defined(thisPackage, tc) SourceConverter.partial( SourceConverter.UnknownTypeName(c, region), - bestEffort) + bestEffort + ) } } } private def nameToCons(c: Constructor): (PackageName, Constructor) = - consCache.getOrElseUpdate(c, { - if (localConstructors(c)) (thisPackage, c) - else resolveImportedCons.getOrElse(c, (thisPackage, c)) - }) + consCache.getOrElseUpdate( + c, { + if (localConstructors(c)) (thisPackage, c) + else resolveImportedCons.getOrElse(c, (thisPackage, c)) + } + ) /* * This ignores the name completely and just returns the lambda expression here */ private def toLambdaExpr[B]( - ds: DefStatement[Pattern.Parsed, B], region: Region, tag: Result[Declaration])( - resultExpr: B => Result[Expr[Declaration]]): Result[Expr[Declaration]] = { + ds: DefStatement[Pattern.Parsed, B], + region: Region, + tag: Result[Declaration] + )(resultExpr: B => Result[Expr[Declaration]]): Result[Expr[Declaration]] = { val unTypedBody = resultExpr(ds.result) val bodyType: Option[Result[Type]] = ds.retType.map(toType(_, region)) @@ -102,7 +110,7 @@ final class SourceConverter( type Pat = Pattern[(PackageName, Constructor), Type] val convertedArgs: Result[NonEmptyList[NonEmptyList[Pat]]] = - travNE2.traverse(ds.args)(convertPattern(_, region)) + travNE2.traverse(ds.args)(convertPattern(_, region)) // If we have the full type of the lambda, apply it. This // helps in recursive cases since we can see at the call site @@ -110,49 +118,63 @@ final class SourceConverter( // was incorrect. Without this, type errors become very non-specific. val maybeFullyTyped: Result[Option[Type]] = (convertedArgs, bodyType.sequence).parMapN { case (args, optResTpe) => - (travNE2.traverse(args)((p: Pat) => p.simpleTypeOf), optResTpe).mapN { case (argsTpe, resTpe) => - argsTpe.toList.foldRight(resTpe) { (args, res) => rankn.Type.Fun(args, res) } + (travNE2.traverse(args)((p: Pat) => p.simpleTypeOf), optResTpe).mapN { + case (argsTpe, resTpe) => + argsTpe.toList.foldRight(resTpe) { (args, res) => + rankn.Type.Fun(args, res) + } } } - (convertedArgs, - bodyExp, - tag, - maybeFullyTyped).parMapN { (groups, b, t, fullType) => - val lambda0 = groups.toList.foldRight(b) { case (as, b) => Expr.buildPatternLambda(as, b, t) } - val lambda = fullType.fold(lambda0)(Expr.Annotation(lambda0, _, t)) - ds.typeArgs match { - case None => success(lambda) - case Some(args) => - val bs = args.map { - case (tr, optK) => - (tr.toBoundVar, optK match { - case None => Kind.Type - case Some(k) => k - }) + (convertedArgs, bodyExp, tag, maybeFullyTyped).parMapN { + (groups, b, t, fullType) => + val lambda0 = groups.toList.foldRight(b) { case (as, b) => + Expr.buildPatternLambda(as, b, t) + } + val lambda = fullType.fold(lambda0)(Expr.Annotation(lambda0, _, t)) + ds.typeArgs match { + case None => success(lambda) + case Some(args) => + val bs = args.map { case (tr, optK) => + ( + tr.toBoundVar, + optK match { + case None => Kind.Type + case Some(k) => k + } + ) } - val gen = Expr.forAll(bs.toList, lambda) - val freeVarsList = Expr.freeBoundTyVars(lambda) - val freeVars = freeVarsList.toSet - val notFreeDecl = bs.exists { case (a, _) => !freeVars(a) } - if (notFreeDecl) { - // we have a lint that fails if declTV is not - // a superset of what you would derive from the args - // the purpose here is to control the *order* of - // and to allow introducing phantom parameters, not - // it is confusing if some are explicit, but some are not - SourceConverter.partial( - SourceConverter.InvalidDefTypeParameters(args, freeVarsList, ds, region), - gen) - } - else success(gen) - } - } - .flatten + val gen = Expr.forAll(bs.toList, lambda) + val freeVarsList = Expr.freeBoundTyVars(lambda) + val freeVars = freeVarsList.toSet + val notFreeDecl = bs.exists { case (a, _) => !freeVars(a) } + if (notFreeDecl) { + // we have a lint that fails if declTV is not + // a superset of what you would derive from the args + // the purpose here is to control the *order* of + // and to allow introducing phantom parameters, not + // it is confusing if some are explicit, but some are not + SourceConverter.partial( + SourceConverter.InvalidDefTypeParameters( + args, + freeVarsList, + ds, + region + ), + gen + ) + } else success(gen) + } + }.flatten } - private def resolveToVar[A](ident: Identifier, decl: A, bound: Set[Bindable], topBound: Set[Bindable]): Expr[A] = + private def resolveToVar[A]( + ident: Identifier, + decl: A, + bound: Set[Bindable], + topBound: Set[Bindable] + ): Expr[A] = ident match { - case c@Constructor(_) => + case c @ Constructor(_) => val (p, cons) = nameToCons(c) Expr.Global(p, cons, decl) case b: Bindable => @@ -160,11 +182,10 @@ final class SourceConverter( else if (topBound(b)) { // local top level bindings can shadow imports after they are imported Expr.Global(thisPackage, b, decl) - } - else { + } else { importedNames.get(ident) match { case Some((p, n)) => Expr.Global(p, n, decl) - case None => + case None => // this is an error, but it will be caught later // at type-checking Expr.Local(b, decl) @@ -172,84 +193,98 @@ final class SourceConverter( } } - - private val unitName = Identifier.Constructor("Unit") - // this is lazy so it isn't initialized before Type - private lazy val tup: Array[Declaration => Expr[Declaration]] = - (Iterator.single( - { (tc: Declaration) => - Expr.Global(PackageName.PredefName, unitName, tc) - } - ) ++ (1 to Type.FnType.MaxSize) - .iterator - .map { idx => - val tup = Type.Tuple.Arity(idx) - val defined = tup.tpe.toDefined - val pn = defined.packageName - val cn = defined.name.ident - - { (tc: Declaration) => Expr.Global(pn, cn, tc) } - }).toArray - - private def makeTuple(tc: Declaration, args: List[Declaration])(conv: Declaration => Result[Expr[Declaration]]): Result[Expr[Declaration]] = - args.traverse(conv) - .flatMap { exps => - val size = exps.length - if (size <= Type.FnType.MaxSize) { - val fn = tup(size)(tc) - val res = Expr.buildApp(fn, exps, tc) - success(res) - } - else { - SourceConverter.failure( - SourceConverter.TooManyConstructorArgs( - Type.Tuple.Arity(32).tpe.toDefined.name.ident, - size, 32, tc.region) + private val unitName = Identifier.Constructor("Unit") + // this is lazy so it isn't initialized before Type + private lazy val tup: Array[Declaration => Expr[Declaration]] = + (Iterator.single( + { (tc: Declaration) => + Expr.Global(PackageName.PredefName, unitName, tc) + } + ) ++ (1 to Type.FnType.MaxSize).iterator + .map { idx => + val tup = Type.Tuple.Arity(idx) + val defined = tup.tpe.toDefined + val pn = defined.packageName + val cn = defined.name.ident + + { (tc: Declaration) => Expr.Global(pn, cn, tc) } + }).toArray + + private def makeTuple(tc: Declaration, args: List[Declaration])( + conv: Declaration => Result[Expr[Declaration]] + ): Result[Expr[Declaration]] = + args + .traverse(conv) + .flatMap { exps => + val size = exps.length + if (size <= Type.FnType.MaxSize) { + val fn = tup(size)(tc) + val res = Expr.buildApp(fn, exps, tc) + success(res) + } else { + SourceConverter.failure( + SourceConverter.TooManyConstructorArgs( + Type.Tuple.Arity(32).tpe.toDefined.name.ident, + size, + 32, + tc.region ) - } + ) } - - private val unitPat = - success(Pattern.PositionalStruct( - (PackageName.PredefName, unitName), - Nil)) - - def makeTuplePattern[A](args: List[Pattern[(PackageName, Constructor), A]], region: Region): Result[Pattern[(PackageName, Constructor), A]] = - args match { - case Nil => unitPat - case nonEmpty => - val size = nonEmpty.size - val tupleCons = Type.Tuple.Arity(size) - val defined = tupleCons.tpe.toDefined - val pat = Pattern.PositionalStruct( - (defined.packageName, defined.name.ident), - nonEmpty) - if (size <= Type.FnType.MaxSize) { - success(pat) - } - else { - SourceConverter.partial( - SourceConverter.TooManyConstructorArgs( - Type.Tuple.Arity(32).tpe.toDefined.name.ident, - size, 32, region), - pat - ) - } } - private def fromDecl(decl: Declaration, bound: Set[Bindable], topBound: Set[Bindable]): Result[Expr[Declaration]] = { + private val unitPat = + success(Pattern.PositionalStruct((PackageName.PredefName, unitName), Nil)) + + def makeTuplePattern[A]( + args: List[Pattern[(PackageName, Constructor), A]], + region: Region + ): Result[Pattern[(PackageName, Constructor), A]] = + args match { + case Nil => unitPat + case nonEmpty => + val size = nonEmpty.size + val tupleCons = Type.Tuple.Arity(size) + val defined = tupleCons.tpe.toDefined + val pat = Pattern.PositionalStruct( + (defined.packageName, defined.name.ident), + nonEmpty + ) + if (size <= Type.FnType.MaxSize) { + success(pat) + } else { + SourceConverter.partial( + SourceConverter.TooManyConstructorArgs( + Type.Tuple.Arity(32).tpe.toDefined.name.ident, + size, + 32, + region + ), + pat + ) + } + } + + private def fromDecl( + decl: Declaration, + bound: Set[Bindable], + topBound: Set[Bindable] + ): Result[Expr[Declaration]] = { implicit val parAp = SourceConverter.parallelIor def loop(decl: Declaration) = fromDecl(decl, bound, topBound) - def withBound(decl: Declaration, newB: Iterable[Bindable]) = fromDecl(decl, bound ++ newB, topBound) + def withBound(decl: Declaration, newB: Iterable[Bindable]) = + fromDecl(decl, bound ++ newB, topBound) decl match { case Annotation(term, tpe) => - (loop(term), toType(tpe, decl.region)).parMapN(Expr.Annotation(_, _, decl)) + (loop(term), toType(tpe, decl.region)) + .parMapN(Expr.Annotation(_, _, decl)) case Apply(fn, args, _) => (loop(fn), args.toList.traverse(loop(_))) .parMapN { Expr.buildApp(_, _, decl) } - case ao@ApplyOp(left, op, right) => - val opVar: Expr[Declaration] = resolveToVar(op, ao.opVar, bound, topBound) + case ao @ ApplyOp(left, op, right) => + val opVar: Expr[Declaration] = + resolveToVar(op, ao.opVar, bound, topBound) (loop(left), loop(right)).parMapN { (l, r) => Expr.buildApp(opVar, l :: r :: Nil, decl) } @@ -257,7 +292,10 @@ final class SourceConverter( val erest = withBound(rest, pat.names) val assignRegion = decl.region - value.region - def solvePat(pat: Pattern.Parsed, rrhs: Result[Expr[Declaration]]): Result[Expr[Declaration]] = + def solvePat( + pat: Pattern.Parsed, + rrhs: Result[Expr[Declaration]] + ): Result[Expr[Declaration]] = pat match { case Pattern.Var(arg) => (erest, rrhs).parMapN { (e, rhs) => @@ -268,19 +306,22 @@ final class SourceConverter( // move the annotation to the right // it's not ideal to use the Declaration of rhs, but it's better // than the entire let - val newRhs = rrhs.map { r => Expr.Annotation(r, realTpe, r.tag) } + val newRhs = rrhs.map { r => + Expr.Annotation(r, realTpe, r.tag) + } solvePat(pat, newRhs) } case Pattern.Named(nm, p) => - // this is the same as creating a let nm = value first + // this is the same as creating a let nm = value first (solvePat(p, rrhs), rrhs).parMapN { (inner, rhs) => Expr.Let(nm, rhs, inner, RecursionKind.NonRecursive, decl) } case pat => // TODO: we need the region on the pattern... - (convertPattern(pat, assignRegion), erest, rrhs).parMapN { (newPattern, e, rhs) => - val expBranches = NonEmptyList.of((newPattern, e)) - Expr.Match(rhs, expBranches, decl) + (convertPattern(pat, assignRegion), erest, rrhs).parMapN { + (newPattern, e, rhs) => + val expBranches = NonEmptyList.of((newPattern, e)) + Expr.Match(rhs, expBranches, decl) } } @@ -289,23 +330,30 @@ final class SourceConverter( loop(decl).map(_.replaceTag(decl)) case CommentNB(CommentStatement(_, Padding(_, decl))) => loop(decl).map(_.replaceTag(decl)) - case DefFn(defstmt@DefStatement(_, _, _, _, _)) => + case DefFn(defstmt @ DefStatement(_, _, _, _, _)) => val inExpr = defstmt.result match { case (_, Padding(_, in)) => withBound(in, defstmt.name :: Nil) } - val newBindings = defstmt.name :: defstmt.args.toList.flatMap(_.patternNames) - val lambda = toLambdaExpr(defstmt, decl.region, success(decl))({ res => withBound(res._1.get, newBindings) }) + val newBindings = + defstmt.name :: defstmt.args.toList.flatMap(_.patternNames) + val lambda = toLambdaExpr(defstmt, decl.region, success(decl))({ res => + withBound(res._1.get, newBindings) + }) (inExpr, lambda).parMapN { (in, lam) => // We rely on DefRecursionCheck to rule out bad recursions val boundName = defstmt.name val rec = - if (UnusedLetCheck.freeBound(lam).contains(boundName)) RecursionKind.Recursive + if (UnusedLetCheck.freeBound(lam).contains(boundName)) + RecursionKind.Recursive else RecursionKind.NonRecursive Expr.Let(boundName, lam, in, recursive = rec, decl) } case IfElse(ifCases, elseCase) => - def loop0(ifs: NonEmptyList[(Expr[Declaration], Expr[Declaration])], elseC: Expr[Declaration]): Expr[Declaration] = + def loop0( + ifs: NonEmptyList[(Expr[Declaration], Expr[Declaration])], + elseC: Expr[Declaration] + ): Expr[Declaration] = ifs match { case NonEmptyList((cond, ifTrue), Nil) => Expr.ifExpr(cond, ifTrue, elseC, decl) @@ -319,15 +367,19 @@ final class SourceConverter( val else1 = loop(elseCase.get) (if1, else1).parMapN(loop0(_, _)) - case tern@Ternary(t, c, f) => - loop(IfElse(NonEmptyList.one((c, OptIndent.same(t))), OptIndent.same(f))(tern.region)) + case tern @ Ternary(t, c, f) => + loop( + IfElse(NonEmptyList.one((c, OptIndent.same(t))), OptIndent.same(f))( + tern.region + ) + ) case Lambda(args, body) => val argsRes = args.traverse(convertPattern(_, decl.region)) val bodyRes = withBound(body, args.patternNames) (argsRes, bodyRes).parMapN { (args, body) => Expr.buildPatternLambda(args, body, decl) } - case la@LeftApply(_, _, _, _) => + case la @ LeftApply(_, _, _, _) => loop(la.rewrite).map(_.replaceTag(decl)) case Literal(lit) => success(Expr.Literal(lit, decl)) @@ -346,55 +398,80 @@ final class SourceConverter( newPattern.product(withBound(decl, pat.names)) } (loop(arg), expBranches).parMapN(Expr.Match(_, _, decl)) - case m@Matches(a, p) => + case m @ Matches(a, p) => // x matches p == // match x: // p: True // _: False - val True: Expr[Declaration] = Expr.Global(PackageName.PredefName, Identifier.Constructor("True"), m) - val False: Expr[Declaration] = Expr.Global(PackageName.PredefName, Identifier.Constructor("False"), m) + val True: Expr[Declaration] = + Expr.Global(PackageName.PredefName, Identifier.Constructor("True"), m) + val False: Expr[Declaration] = Expr.Global( + PackageName.PredefName, + Identifier.Constructor("False"), + m + ) (loop(a), convertPattern(p, m.region)).mapN { (a, p) => - val branches = NonEmptyList((p, True), (Pattern.WildCard, False) :: Nil) + val branches = + NonEmptyList((p, True), (Pattern.WildCard, False) :: Nil) Expr.Match(a, branches, m) } - case tc@TupleCons(its) => makeTuple(tc, its)(loop) - case s@StringDecl(parts) => + case tc @ TupleCons(its) => makeTuple(tc, its)(loop) + case s @ StringDecl(parts) => // a single string item should be converted // to that thing, // two or more should be converted this to concat_String([items]) def charToString(expr: Expr[Declaration]): Expr[Declaration] = { val fnName: Expr[Declaration] = - Expr.Global(PackageName.PredefName, Identifier.Name("char_to_String"), expr.tag) + Expr.Global( + PackageName.PredefName, + Identifier.Name("char_to_String"), + expr.tag + ) Expr.buildApp(fnName, expr :: Nil, expr.tag) } val decls = parts.parTraverse { case StringDecl.Literal(r, str) => loop(Literal(Lit(str))(r)) - case StringDecl.CharExpr(decl) => loop(decl).map(charToString) - case StringDecl.StrExpr(decl) => loop(decl) + case StringDecl.CharExpr(decl) => loop(decl).map(charToString) + case StringDecl.StrExpr(decl) => loop(decl) } decls.map { case NonEmptyList(one, Nil) => one case twoOrMore => - def listOf(expr: List[Expr[Declaration]], restDecl: Declaration): Expr[Declaration] = + def listOf( + expr: List[Expr[Declaration]], + restDecl: Declaration + ): Expr[Declaration] = expr match { case Nil => - Expr.Global(PackageName.PredefName, Identifier.Constructor("EmptyList"), restDecl) + Expr.Global( + PackageName.PredefName, + Identifier.Constructor("EmptyList"), + restDecl + ) case h :: tail => - val cons = Expr.Global(PackageName.PredefName, Identifier.Constructor("NonEmptyList"), restDecl) + val cons = Expr.Global( + PackageName.PredefName, + Identifier.Constructor("NonEmptyList"), + restDecl + ) val tailExpr = listOf(tail, h.tag) Expr.buildApp(cons, h :: tailExpr :: Nil, restDecl) } val fnName: Expr[Declaration] = - Expr.Global(PackageName.PredefName, Identifier.Name("concat_String"), s) + Expr.Global( + PackageName.PredefName, + Identifier.Name("concat_String"), + s + ) Expr.buildApp(fnName, listOf(twoOrMore.toList, s) :: Nil, s) } - case l@ListDecl(list) => + case l @ ListDecl(list) => list match { case ListLang.Cons(items) => val revDecs: Result[List[SpliceOrItem[Expr[Declaration]]]] = @@ -412,10 +489,16 @@ final class SourceConverter( Expr.Global(pn, Identifier.Name(c), l) val Empty: Expr[Declaration] = mkC("EmptyList") - def cons(head: Expr[Declaration], tail: Expr[Declaration]): Expr[Declaration] = + def cons( + head: Expr[Declaration], + tail: Expr[Declaration] + ): Expr[Declaration] = Expr.buildApp(mkC("NonEmptyList"), head :: tail :: Nil, l) - def concat(headList: Expr[Declaration], tail: Expr[Declaration]): Expr[Declaration] = + def concat( + headList: Expr[Declaration], + tail: Expr[Declaration] + ): Expr[Declaration] = Expr.buildApp(mkN("concat"), headList :: tail :: Nil, l) revDecs.map(_.foldLeft(Empty) { @@ -458,10 +541,11 @@ final class SourceConverter( "flat_map_List" } val newBound = binding.names - val opExpr: Expr[Declaration] = Expr.Global(pn, Identifier.Name(opName), l) + val opExpr: Expr[Declaration] = + Expr.Global(pn, Identifier.Name(opName), l) val resExpr: Result[Expr[Declaration]] = filter match { - case None => withBound(res.value, newBound) + case None => withBound(res.value, newBound) case Some(cond) => // To do filters, we lift all results into lists, // so single items must be made singleton lists @@ -473,9 +557,14 @@ final class SourceConverter( // here we lift the result into a a singleton list withBound(r, newBound).map { ritem => Expr.App( - Expr.Global(pn, Identifier.Constructor("NonEmptyList"), rdec), + Expr.Global( + pn, + Identifier.Constructor("NonEmptyList"), + rdec + ), NonEmptyList(ritem, empty :: Nil), - rdec) + rdec + ) } case SpliceOrItem.Splice(r) => withBound(r, newBound) } @@ -484,34 +573,40 @@ final class SourceConverter( Expr.ifExpr(c, sing, empty, cond) } } - (convertPattern(binding, decl.region), - resExpr, - loop(in)).mapN { (newPattern, resExpr, in) => - val fnExpr: Expr[Declaration] = - Expr.buildPatternLambda(NonEmptyList.of(newPattern), resExpr, l) - Expr.buildApp(opExpr, in :: fnExpr :: Nil, l) + (convertPattern(binding, decl.region), resExpr, loop(in)).mapN { + (newPattern, resExpr, in) => + val fnExpr: Expr[Declaration] = + Expr.buildPatternLambda( + NonEmptyList.of(newPattern), + resExpr, + l + ) + Expr.buildApp(opExpr, in :: fnExpr :: Nil, l) } } - case l@DictDecl(dict) => + case l @ DictDecl(dict) => val pn = PackageName.PredefName def mkN(n: String): Expr[Declaration] = Expr.Global(pn, Identifier.Name(n), l) val empty: Expr[Declaration] = Expr.App(mkN("empty_Dict"), NonEmptyList.one(mkN("string_Order")), l) - def add(dict: Expr[Declaration], k: Expr[Declaration], v: Expr[Declaration]): Expr[Declaration] = { + def add( + dict: Expr[Declaration], + k: Expr[Declaration], + v: Expr[Declaration] + ): Expr[Declaration] = { val fn = mkN("add_key") Expr.buildApp(fn, dict :: k :: v :: Nil, l) } dict match { case ListLang.Cons(items) => val revDecs: Result[List[KVPair[Expr[Declaration]]]] = - items.reverse.traverse { - case KVPair(k, v) => - (loop(k), loop(v)).mapN(KVPair(_, _)) + items.reverse.traverse { case KVPair(k, v) => + (loop(k), loop(v)).mapN(KVPair(_, _)) } - revDecs.map(_.foldLeft(empty) { - case (dict, KVPair(k, v)) => add(dict, k, v) + revDecs.map(_.foldLeft(empty) { case (dict, KVPair(k, v)) => + add(dict, k, v) }) case ListLang.Comprehension(KVPair(k, v), binding, in, filter) => /* @@ -529,19 +624,19 @@ final class SourceConverter( val newBound = binding.names val pn = PackageName.PredefName - val opExpr: Expr[Declaration] = Expr.Global(pn, Identifier.Name("foldLeft"), l) + val opExpr: Expr[Declaration] = + Expr.Global(pn, Identifier.Name("foldLeft"), l) val dictSymbol = unusedNames(decl.allNames).next() val init: Expr[Declaration] = Expr.Local(dictSymbol, l) - val added = (withBound(k, newBound), withBound(v, newBound)).mapN(add(init, _, _)) + val added = (withBound(k, newBound), withBound(v, newBound)).mapN( + add(init, _, _) + ) val resExpr: Result[Expr[Declaration]] = filter match { case None => added case Some(cond0) => (added, withBound(cond0, newBound)).mapN { (added, cond) => - Expr.ifExpr(cond, - added, - init, - cond0) + Expr.ifExpr(cond, added, init, cond0) } } val newPattern = convertPattern(binding, decl.region) @@ -550,67 +645,83 @@ final class SourceConverter( Expr.buildPatternLambda( NonEmptyList(Pattern.Var(dictSymbol), pat :: Nil), res, - l) + l + ) Expr.buildApp(opExpr, in :: empty :: foldFn :: Nil, l) } - } - case rc@RecordConstructor(name, args) => - val (p, c) = nameToCons(name) - val cons: Expr[Declaration] = Expr.Global(p, c, rc) - localTypeEnv.flatMap(_.getConstructorParams(p, c) match { - case Some(params) => - def argExpr(arg: RecordArg): (Bindable, Result[Expr[Declaration]]) = - arg match { - case RecordArg.Simple(b) => - (b, success(resolveToVar(b, rc, bound, topBound))) - case RecordArg.Pair(k, v) => - (k, loop(v)) - } - val mappingList = args.toList.map(argExpr) - val mapping = mappingList.toMap - - lazy val present = - mappingList - .iterator - .map(_._1) - .foldLeft(SortedSet.empty[Bindable])(_ + _) - - def get(b: Bindable): Result[Expr[Declaration]] = - mapping.get(b) match { - case Some(expr) => expr - case None => - SourceConverter.failure( - SourceConverter.MissingArg(name, rc, present, b, rc.region)) - } - val exprArgs = params.traverse { case (b, _) => get(b) } - - val res = exprArgs.map { args => - Expr.buildApp(cons, args.toList, rc) + } + case rc @ RecordConstructor(name, args) => + val (p, c) = nameToCons(name) + val cons: Expr[Declaration] = Expr.Global(p, c, rc) + localTypeEnv.flatMap(_.getConstructorParams(p, c) match { + case Some(params) => + def argExpr(arg: RecordArg): (Bindable, Result[Expr[Declaration]]) = + arg match { + case RecordArg.Simple(b) => + (b, success(resolveToVar(b, rc, bound, topBound))) + case RecordArg.Pair(k, v) => + (k, loop(v)) } - // we also need to check that there are no unused or duplicated - // fields - val paramNamesList = params.map(_._1) - val paramNames = paramNamesList.toSet - // here are all the fields we don't understand - val extra = mappingList.collect { case (k, _) if !paramNames(k) => k } - // Check that the mapping is exactly the right size - NonEmptyList.fromList(extra) match { - case None => res - case Some(extra) => - SourceConverter - .addError(res, - SourceConverter.UnexpectedField(name, rc, extra, paramNamesList, rc.region)) + val mappingList = args.toList.map(argExpr) + val mapping = mappingList.toMap + + lazy val present = + mappingList.iterator + .map(_._1) + .foldLeft(SortedSet.empty[Bindable])(_ + _) + + def get(b: Bindable): Result[Expr[Declaration]] = + mapping.get(b) match { + case Some(expr) => expr + case None => + SourceConverter.failure( + SourceConverter.MissingArg(name, rc, present, b, rc.region) + ) } - case None => - SourceConverter.failure(SourceConverter.UnknownConstructor(name, rc, decl.region)) - }) + val exprArgs = params.traverse { case (b, _) => get(b) } + + val res = exprArgs.map { args => + Expr.buildApp(cons, args.toList, rc) + } + // we also need to check that there are no unused or duplicated + // fields + val paramNamesList = params.map(_._1) + val paramNames = paramNamesList.toSet + // here are all the fields we don't understand + val extra = mappingList.collect { + case (k, _) if !paramNames(k) => k + } + // Check that the mapping is exactly the right size + NonEmptyList.fromList(extra) match { + case None => res + case Some(extra) => + SourceConverter + .addError( + res, + SourceConverter.UnexpectedField( + name, + rc, + extra, + paramNamesList, + rc.region + ) + ) + } + case None => + SourceConverter.failure( + SourceConverter.UnknownConstructor(name, rc, decl.region) + ) + }) } } private def toType(t: TypeRef, region: Region): Result[Type] = TypeRefConverter[Result](t)(nameToType(_, region)) - def toDefinition(pname: PackageName, tds: TypeDefinitionStatement): Result[rankn.DefinedType[Option[Kind.Arg]]] = { + def toDefinition( + pname: PackageName, + tds: TypeDefinitionStatement + ): Result[rankn.DefinedType[Option[Kind.Arg]]] = { import Statement._ type StT = ((Set[Type.TyVar], List[Type.TyVar]), LazyList[Type.TyVar]) @@ -637,114 +748,144 @@ final class SourceConverter( Type.freeTyVars(pt).map(Type.TyVar(_)) } - def buildParams(args: List[(Bindable, Option[Type])]): VarState[List[(Bindable, Type)]] = + def buildParams( + args: List[(Bindable, Option[Type])] + ): VarState[List[(Bindable, Type)]] = args.traverse(buildParam _) // This is a traverse on List[(Bindable, Option[A])] - val deep = Traverse[List].compose(Traverse[(Bindable, *)]).compose(Traverse[Option]) + val deep = + Traverse[List].compose(Traverse[(Bindable, *)]).compose(Traverse[Option]) def updateInferedWithDecl( - typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], - typeParams0: List[Type.Var.Bound]): Result[List[(Type.Var.Bound, Option[Kind.Arg])]] = - typeArgs match { - case None => success(typeParams0.map((_, None))) - case Some(decl) => - val neBound = decl.map { case (v, k) => (v.toBoundVar, k) } - val declSet = neBound.toList.iterator.map(_._1).toSet - val missingFromDecl = typeParams0.filterNot(declSet) - if ((declSet.size != neBound.size) || missingFromDecl.nonEmpty) { - val bestEffort = neBound.toList.distinctBy(_._1) ::: missingFromDecl.map((_, None)) - // we have a lint that fails if declTV is not - // a superset of what you would derive from the args - // the purpose here is to control the *order* of - // and to allow introducing phantom parameters, not - // it is confusing if some are explicit, but some are not - SourceConverter.partial( - SourceConverter.InvalidTypeParameters(decl, typeParams0, tds), - bestEffort) - } - else success(neBound.toList ::: missingFromDecl.map((_, None))) - } + typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], + typeParams0: List[Type.Var.Bound] + ): Result[List[(Type.Var.Bound, Option[Kind.Arg])]] = + typeArgs match { + case None => success(typeParams0.map((_, None))) + case Some(decl) => + val neBound = decl.map { case (v, k) => (v.toBoundVar, k) } + val declSet = neBound.toList.iterator.map(_._1).toSet + val missingFromDecl = typeParams0.filterNot(declSet) + if ((declSet.size != neBound.size) || missingFromDecl.nonEmpty) { + val bestEffort = + neBound.toList.distinctBy(_._1) ::: missingFromDecl.map((_, None)) + // we have a lint that fails if declTV is not + // a superset of what you would derive from the args + // the purpose here is to control the *order* of + // and to allow introducing phantom parameters, not + // it is confusing if some are explicit, but some are not + SourceConverter.partial( + SourceConverter.InvalidTypeParameters(decl, typeParams0, tds), + bestEffort + ) + } else success(neBound.toList ::: missingFromDecl.map((_, None))) + } - def validateArgCount(nm: Constructor, args: Int, region: Region): Result[Unit] = + def validateArgCount( + nm: Constructor, + args: Int, + region: Region + ): Result[Unit] = if (args <= Type.FnType.MaxSize) SourceConverter.successUnit - else SourceConverter.partial( - SourceConverter.TooManyConstructorArgs(nm, args, Type.FnType.MaxSize, region), ()) + else + SourceConverter.partial( + SourceConverter + .TooManyConstructorArgs(nm, args, Type.FnType.MaxSize, region), + () + ) // TODO we have to make sure we don't have more than 8 arguments to a struct // or the constructor Fn won't be a valid function tds match { case Struct(nm, typeArgs, args) => validateArgCount(nm, args.length, tds.region) *> - deep.traverse(args)(toType(_, tds.region)) - .flatMap { argsType => - val declVars = typeArgs.iterator.flatMap(_.toList).map { p => Type.TyVar(p._1.toBoundVar) } - val initVars = existingVars(argsType) - val initState = ((initVars.toSet ++ declVars, initVars.reverse), Type.allBinders.map(Type.TyVar)) - val (((_, typeVars), _), params) = buildParams(argsType).run(initState).value + deep + .traverse(args)(toType(_, tds.region)) + .flatMap { argsType => + val declVars = typeArgs.iterator.flatMap(_.toList).map { p => + Type.TyVar(p._1.toBoundVar) + } + val initVars = existingVars(argsType) + val initState = ( + (initVars.toSet ++ declVars, initVars.reverse), + Type.allBinders.map(Type.TyVar) + ) + val (((_, typeVars), _), params) = + buildParams(argsType).run(initState).value + // we reverse to make sure we see in traversal order + val typeParams0 = reverseMap(typeVars) { tv => + tv.toVar match { + case b @ Type.Var.Bound(_) => b + // $COVERAGE-OFF$ this should be unreachable + case unexpected => + sys.error( + s"unexpectedly parsed a non bound var: $unexpected" + ) + // $COVERAGE-ON$ + } + } + + updateInferedWithDecl(typeArgs, typeParams0).map { typeParams => + val tname = TypeName(nm) + val consFn = rankn.ConstructorFn(nm, params) + + rankn.DefinedType(pname, tname, typeParams, consFn :: Nil) + } + } + case Enum(nm, typeArgs, items) => + items.get + .traverse { case (nm, args) => + validateArgCount(nm, args.length, tds.region) *> + deep + .traverse(args)(toType(_, tds.region)) + .map((nm, _)) + } + .flatMap { conArgs => + val constructorsS = conArgs.traverse { case (nm, argsType) => + buildParams(argsType).map { params => + (nm, params) + } + } + val declVars = typeArgs.iterator.flatMap(_.toList).map { p => + Type.TyVar(p._1.toBoundVar) + } + val initVars = existingVars(conArgs.toList.flatMap(_._2)) + val initState = ( + (initVars.toSet ++ declVars, initVars.reverse), + Type.allBinders.map(Type.TyVar) + ) + val (((_, typeVars), _), constructors) = + constructorsS.run(initState).value // we reverse to make sure we see in traversal order val typeParams0 = reverseMap(typeVars) { tv => tv.toVar match { - case b@Type.Var.Bound(_) => b + case b @ Type.Var.Bound(_) => b // $COVERAGE-OFF$ this should be unreachable case unexpected => sys.error(s"unexpectedly parsed a non bound var: $unexpected") // $COVERAGE-ON$ } } - updateInferedWithDecl(typeArgs, typeParams0).map { typeParams => - val tname = TypeName(nm) - val consFn = rankn.ConstructorFn(nm, params) - - rankn.DefinedType(pname, - tname, - typeParams, - consFn :: Nil) - } - } - case Enum(nm, typeArgs, items) => - items.get.traverse { case (nm, args) => - validateArgCount(nm, args.length, tds.region) *> - deep.traverse(args)(toType(_, tds.region)) - .map((nm, _)) - } - .flatMap { conArgs => - - val constructorsS = conArgs.traverse { case (nm, argsType) => - buildParams(argsType).map { params => - (nm, params) - } - } - val declVars = typeArgs.iterator.flatMap(_.toList).map { p => Type.TyVar(p._1.toBoundVar) } - val initVars = existingVars(conArgs.toList.flatMap(_._2)) - val initState = ((initVars.toSet ++ declVars, initVars.reverse), Type.allBinders.map(Type.TyVar)) - val (((_, typeVars), _), constructors) = constructorsS.run(initState).value - // we reverse to make sure we see in traversal order - val typeParams0 = reverseMap(typeVars) { tv => - tv.toVar match { - case b@Type.Var.Bound(_) => b - // $COVERAGE-OFF$ this should be unreachable - case unexpected => sys.error(s"unexpectedly parsed a non bound var: $unexpected") - // $COVERAGE-ON$ - } - } - updateInferedWithDecl(typeArgs, typeParams0).map { typeParams => - val finalCons = constructors.toList.map { case (c, params) => - rankn.ConstructorFn(c, params) + val finalCons = constructors.toList.map { case (c, params) => + rankn.ConstructorFn(c, params) + } + rankn.DefinedType(pname, TypeName(nm), typeParams, finalCons) } - rankn.DefinedType(pname, TypeName(nm), typeParams, finalCons) } - } case ExternalStruct(nm, targs) => // TODO make a real check here of allowed kinds success( rankn.DefinedType( pname, TypeName(nm), - targs.map { case (TypeRef.TypeVar(v), optK) => (Type.Var.Bound(v), optK) }, - Nil) + targs.map { case (TypeRef.TypeVar(v), optK) => + (Type.Var.Bound(v), optK) + }, + Nil ) + ) } } @@ -760,25 +901,44 @@ final class SourceConverter( loop(as, Nil) } - private def convertPattern(pat: Pattern.Parsed, region: Region): Result[Pattern[(PackageName, Constructor), rankn.Type]] = { + private def convertPattern( + pat: Pattern.Parsed, + region: Region + ): Result[Pattern[(PackageName, Constructor), rankn.Type]] = { val nonTupled = unTuplePattern(pat, region) val collisions = pat.collisionBinds NonEmptyList.fromList(collisions) match { case None => nonTupled case Some(nel) => - SourceConverter.addError(nonTupled, SourceConverter.PatternShadow(nel, pat, region)) + SourceConverter.addError( + nonTupled, + SourceConverter.PatternShadow(nel, pat, region) + ) } } - private[this] val empty = Pattern.PositionalStruct((PackageName.PredefName, Constructor("EmptyList")), Nil) - private[this] val nonEmpty = (PackageName.PredefName, Constructor("NonEmptyList")) - - /** - * As much as possible, convert a list pattern into a normal enum pattern which simplifies - * matching, and possibly allows us to more easily statically remove more of the match - */ - private def unlistPattern(parts: List[Pattern.ListPart[Pattern[(PackageName, Constructor), rankn.Type]]]): Pattern[(PackageName, Constructor), rankn.Type] = { - def loop(parts: List[Pattern.ListPart[Pattern[(PackageName, Constructor), rankn.Type]]], topLevel: Boolean): Pattern[(PackageName, Constructor), rankn.Type] = + private[this] val empty = Pattern.PositionalStruct( + (PackageName.PredefName, Constructor("EmptyList")), + Nil + ) + private[this] val nonEmpty = + (PackageName.PredefName, Constructor("NonEmptyList")) + + /** As much as possible, convert a list pattern into a normal enum pattern + * which simplifies matching, and possibly allows us to more easily + * statically remove more of the match + */ + private def unlistPattern( + parts: List[ + Pattern.ListPart[Pattern[(PackageName, Constructor), rankn.Type]] + ] + ): Pattern[(PackageName, Constructor), rankn.Type] = { + def loop( + parts: List[ + Pattern.ListPart[Pattern[(PackageName, Constructor), rankn.Type]] + ], + topLevel: Boolean + ): Pattern[(PackageName, Constructor), rankn.Type] = parts match { case Nil => empty case Pattern.ListPart.Item(h) :: tail => @@ -790,8 +950,7 @@ final class SourceConverter( // changing to _ would allow more things to typecheck, which we can't do // and we can't annotate because we don't know the type of the list Pattern.ListPat(parts) - } - else { + } else { // we are already in the tail of a list, so we can just put _ here Pattern.WildCard } @@ -801,12 +960,13 @@ final class SourceConverter( // changing to _ would allow more things to typecheck, which we can't do // and we can't annotate because we don't know the type of the list Pattern.ListPat(parts) - } - else { + } else { // we are already in the tail of a list, so we can just put n here Pattern.Var(n) } - case (Pattern.ListPart.WildList :: (i@Pattern.ListPart.Item(Pattern.WildCard)) :: tail) => + case (Pattern.ListPart.WildList :: (i @ Pattern.ListPart.Item( + Pattern.WildCard + )) :: tail) => // [*_, _, x...] = [_, *_, x...] loop(i :: Pattern.ListPart.WildList :: tail, topLevel) case (Pattern.ListPart.WildList | Pattern.ListPart.NamedList(_)) :: _ => @@ -817,138 +977,211 @@ final class SourceConverter( loop(parts, true) } - /** - * Tuples are converted into standard types using an HList strategy - */ - private def unTuplePattern(pat: Pattern.Parsed, region: Region): Result[Pattern[(PackageName, Constructor), rankn.Type]] = - pat.traversePattern[Result, (PackageName, Constructor), rankn.Type]({ - case (Pattern.StructKind.Tuple, args) => - // this is a tuple pattern - args.flatMap(makeTuplePattern(_, region)) - case (Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), rargs) => - rargs.flatMap { args => - val pc@(p, c) = nameToCons(nm) - localTypeEnv.flatMap(_.getConstructorParams(p, c) match { - case Some(params) => - val argLen = args.size - val paramLen = params.size - if (argLen == paramLen) { - SourceConverter.success(Pattern.PositionalStruct(pc, args)) - } - else { - // do the best we can - val fixedArgs = (args ::: List.fill(paramLen - argLen)(Pattern.WildCard)).take(paramLen) - SourceConverter.partial( - SourceConverter.InvalidArgCount(nm, pat, argLen, paramLen, region), - Pattern.PositionalStruct(pc, fixedArgs)) - } - case None => - SourceConverter.failure(SourceConverter.UnknownConstructor(nm, pat, region)) - }) - } - case (Pattern.StructKind.NamedPartial(nm, Pattern.StructKind.Style.TupleLike), rargs) => - rargs.flatMap { args => - val pc@(p, c) = nameToCons(nm) - localTypeEnv.flatMap(_.getConstructorParams(p, c) match { - case Some(params) => - val argLen = args.size - val paramLen = params.size - if (argLen <= paramLen) { - val fixedArgs = if (argLen < paramLen) (args ::: List.fill(paramLen - argLen)(Pattern.WildCard)) else args - SourceConverter.success(Pattern.PositionalStruct(pc, fixedArgs)) - } - else { - // we have too many - val fixedArgs = args.take(paramLen) - SourceConverter.partial( - SourceConverter.InvalidArgCount(nm, pat, argLen, paramLen, region), - Pattern.PositionalStruct(pc, fixedArgs)) - } - case None => - SourceConverter.failure(SourceConverter.UnknownConstructor(nm, pat, region)) - }) - } - case (Pattern.StructKind.Named(nm, Pattern.StructKind.Style.RecordLike(fs)), rargs) => - rargs.flatMap { args => - val pc@(p, c) = nameToCons(nm) - localTypeEnv.flatMap(_.getConstructorParams(p, c) match { - case Some(params) => - val mapping = fs.toList.iterator.map(_.field).zip(args.iterator).toMap - lazy val present = SortedSet(fs.toList.iterator.map(_.field).toList: _*) - def get(b: Bindable): Result[Pattern[(PackageName, Constructor), rankn.Type]] = - mapping.get(b) match { - case Some(pat) => - SourceConverter.success(pat) - case None => - SourceConverter.partial(SourceConverter.MissingArg(nm, pat, present, b, region), Pattern.WildCard) + /** Tuples are converted into standard types using an HList strategy + */ + private def unTuplePattern( + pat: Pattern.Parsed, + region: Region + ): Result[Pattern[(PackageName, Constructor), rankn.Type]] = + pat.traversePattern[Result, (PackageName, Constructor), rankn.Type]( + { + case (Pattern.StructKind.Tuple, args) => + // this is a tuple pattern + args.flatMap(makeTuplePattern(_, region)) + case ( + Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), + rargs + ) => + rargs.flatMap { args => + val pc @ (p, c) = nameToCons(nm) + localTypeEnv.flatMap(_.getConstructorParams(p, c) match { + case Some(params) => + val argLen = args.size + val paramLen = params.size + if (argLen == paramLen) { + SourceConverter.success(Pattern.PositionalStruct(pc, args)) + } else { + // do the best we can + val fixedArgs = + (args ::: List.fill(paramLen - argLen)(Pattern.WildCard)) + .take(paramLen) + SourceConverter.partial( + SourceConverter + .InvalidArgCount(nm, pat, argLen, paramLen, region), + Pattern.PositionalStruct(pc, fixedArgs) + ) } - val mapped = - params - .traverse { case (b, _) => get(b) }(SourceConverter.parallelIor) - .map(Pattern.PositionalStruct(pc, _)) - - val paramNamesList = params.map(_._1) - val paramNames = paramNamesList.toSet - // here are all the fields we don't understand - val extra = fs.toList.iterator.map(_.field).filterNot(paramNames).toList - // Check that the mapping is exactly the right size - NonEmptyList.fromList(extra) match { - case None => mapped - case Some(extra) => - SourceConverter - .addError(mapped, - SourceConverter.UnexpectedField(nm, pat, extra, paramNamesList, region)) - } - case None => - SourceConverter.failure(SourceConverter.UnknownConstructor(nm, pat, region)) - }) - } - case (Pattern.StructKind.NamedPartial(nm, Pattern.StructKind.Style.RecordLike(fs)), rargs) => - rargs.flatMap { args => - val pc@(p, c) = nameToCons(nm) - localTypeEnv.flatMap(_.getConstructorParams(p, c) match { - case Some(params) => - val mapping = fs.toList.iterator.map(_.field).zip(args.iterator).toMap - def get(b: Bindable): Pattern[(PackageName, Constructor), rankn.Type] = - mapping.get(b) match { - case Some(pat) => pat - case None => Pattern.WildCard + case None => + SourceConverter.failure( + SourceConverter.UnknownConstructor(nm, pat, region) + ) + }) + } + case ( + Pattern.StructKind + .NamedPartial(nm, Pattern.StructKind.Style.TupleLike), + rargs + ) => + rargs.flatMap { args => + val pc @ (p, c) = nameToCons(nm) + localTypeEnv.flatMap(_.getConstructorParams(p, c) match { + case Some(params) => + val argLen = args.size + val paramLen = params.size + if (argLen <= paramLen) { + val fixedArgs = + if (argLen < paramLen) + (args ::: List.fill(paramLen - argLen)(Pattern.WildCard)) + else args + SourceConverter.success( + Pattern.PositionalStruct(pc, fixedArgs) + ) + } else { + // we have too many + val fixedArgs = args.take(paramLen) + SourceConverter.partial( + SourceConverter + .InvalidArgCount(nm, pat, argLen, paramLen, region), + Pattern.PositionalStruct(pc, fixedArgs) + ) } - val derefArgs = params.map { case (b, _) => get(b) } - val res0 = SourceConverter.success(Pattern.PositionalStruct(pc, derefArgs)) - - val paramNamesList = params.map(_._1) - val paramNames = paramNamesList.toSet - // here are all the fields we don't understand - val extra = fs.toList.iterator.map(_.field).filterNot(paramNames).toList - // Check that the mapping is exactly the right size - NonEmptyList.fromList(extra) match { - case None => res0 - case Some(extra) => - SourceConverter - .addError(res0, - SourceConverter.UnexpectedField(nm, pat, extra, paramNamesList, region)) - } - case None => - SourceConverter.failure(SourceConverter.UnknownConstructor(nm, pat, region)) - }) - } + case None => + SourceConverter.failure( + SourceConverter.UnknownConstructor(nm, pat, region) + ) + }) + } + case ( + Pattern.StructKind + .Named(nm, Pattern.StructKind.Style.RecordLike(fs)), + rargs + ) => + rargs.flatMap { args => + val pc @ (p, c) = nameToCons(nm) + localTypeEnv.flatMap(_.getConstructorParams(p, c) match { + case Some(params) => + val mapping = + fs.toList.iterator.map(_.field).zip(args.iterator).toMap + lazy val present = + SortedSet(fs.toList.iterator.map(_.field).toList: _*) + def get( + b: Bindable + ): Result[Pattern[(PackageName, Constructor), rankn.Type]] = + mapping.get(b) match { + case Some(pat) => + SourceConverter.success(pat) + case None => + SourceConverter.partial( + SourceConverter.MissingArg(nm, pat, present, b, region), + Pattern.WildCard + ) + } + val mapped = + params + .traverse { case (b, _) => get(b) }( + SourceConverter.parallelIor + ) + .map(Pattern.PositionalStruct(pc, _)) + + val paramNamesList = params.map(_._1) + val paramNames = paramNamesList.toSet + // here are all the fields we don't understand + val extra = + fs.toList.iterator.map(_.field).filterNot(paramNames).toList + // Check that the mapping is exactly the right size + NonEmptyList.fromList(extra) match { + case None => mapped + case Some(extra) => + SourceConverter + .addError( + mapped, + SourceConverter.UnexpectedField( + nm, + pat, + extra, + paramNamesList, + region + ) + ) + } + case None => + SourceConverter.failure( + SourceConverter.UnknownConstructor(nm, pat, region) + ) + }) + } + case ( + Pattern.StructKind + .NamedPartial(nm, Pattern.StructKind.Style.RecordLike(fs)), + rargs + ) => + rargs.flatMap { args => + val pc @ (p, c) = nameToCons(nm) + localTypeEnv.flatMap(_.getConstructorParams(p, c) match { + case Some(params) => + val mapping = + fs.toList.iterator.map(_.field).zip(args.iterator).toMap + def get( + b: Bindable + ): Pattern[(PackageName, Constructor), rankn.Type] = + mapping.get(b) match { + case Some(pat) => pat + case None => Pattern.WildCard + } + val derefArgs = params.map { case (b, _) => get(b) } + val res0 = SourceConverter.success( + Pattern.PositionalStruct(pc, derefArgs) + ) + + val paramNamesList = params.map(_._1) + val paramNames = paramNamesList.toSet + // here are all the fields we don't understand + val extra = + fs.toList.iterator.map(_.field).filterNot(paramNames).toList + // Check that the mapping is exactly the right size + NonEmptyList.fromList(extra) match { + case None => res0 + case Some(extra) => + SourceConverter + .addError( + res0, + SourceConverter.UnexpectedField( + nm, + pat, + extra, + paramNamesList, + region + ) + ) + } + case None => + SourceConverter.failure( + SourceConverter.UnknownConstructor(nm, pat, region) + ) + }) + } }, { t => toType(t, region) }, { items => items.map(unlistPattern) } - )(SourceConverter.parallelIor) // use the parallel, not the default Applicative which is Monadic + )( + SourceConverter.parallelIor + ) // use the parallel, not the default Applicative which is Monadic private lazy val toTypeEnv: Result[ParsedTypeEnv[Option[Kind.Arg]]] = { val sunit = success(()) - val dupTypes = localDefs.groupByNel(_.name) + val dupTypes = localDefs + .groupByNel(_.name) .toList .traverse { case (n, tes) => if (tes.tail.isEmpty) sunit else { val dupRegions = tes.map(_.region) - SourceConverter.partial(SourceConverter.Duplication(n, SourceConverter.DupKind.TypeName, dupRegions), - ()) + SourceConverter.partial( + SourceConverter + .Duplication(n, SourceConverter.DupKind.TypeName, dupRegions), + () + ) } } @@ -962,11 +1195,13 @@ final class SourceConverter( // these are colliding constructors, but if they also collide on type // name we have already reported it above sunit - } - else { + } else { val dupRegions = tes.map(_._2.region) - SourceConverter.partial(SourceConverter.Duplication(n, SourceConverter.DupKind.Constructor, dupRegions), - ()) + SourceConverter.partial( + SourceConverter + .Duplication(n, SourceConverter.DupKind.Constructor, dupRegions), + () + ) } } @@ -983,9 +1218,7 @@ final class SourceConverter( toTypeEnv.map { p => importedTypeEnv ++ TypeEnv.fromParsed(p) } private def anonNameStrings(): Iterator[String] = - rankn.Type - .allBinders - .iterator + rankn.Type.allBinders.iterator .map(_.name) private def unusedNames(allNames: Bindable => Boolean): Iterator[Bindable] = @@ -993,13 +1226,14 @@ final class SourceConverter( .map(Identifier.Name(_)) .filterNot(allNames) - /** - * Externals are not permitted to be shadowed at the top level - */ - private def checkExternalDefShadowing(values: List[Statement.ValueStatement]): Result[Unit] = { + /** Externals are not permitted to be shadowed at the top level + */ + private def checkExternalDefShadowing( + values: List[Statement.ValueStatement] + ): Result[Unit] = { val extDefNames = - values.collect { - case ed@Statement.ExternalDef(name, _, _) => (name, ed.region) + values.collect { case ed @ Statement.ExternalDef(name, _, _) => + (name, ed.region) } val sunit = success(()) @@ -1014,15 +1248,22 @@ final class SourceConverter( case NonEmptyList(_, Nil) => sunit case NonEmptyList((_, r1), (_, r2) :: rest) => SourceConverter.partial( - SourceConverter.Duplication(name, SourceConverter.DupKind.ExtDef, NonEmptyList(r1, r2 :: rest.map(_._2))), - ()) + SourceConverter.Duplication( + name, + SourceConverter.DupKind.ExtDef, + NonEmptyList(r1, r2 :: rest.map(_._2)) + ), + () + ) } } - def bindOrDef(s: Statement.ValueStatement): Option[Either[Statement.Bind, Statement.Def]] = + def bindOrDef( + s: Statement.ValueStatement + ): Option[Either[Statement.Bind, Statement.Def]] = s match { - case b@Statement.Bind(_) => Some(Left(b)) - case d@Statement.Def(_) => Some(Right(d)) + case b @ Statement.Bind(_) => Some(Left(b)) + case d @ Statement.Def(_) => Some(Right(d)) case Statement.ExternalDef(_, _, _) => None } @@ -1034,16 +1275,15 @@ final class SourceConverter( val shadows = names.filter(extDefNamesSet) NonEmptyList.fromList(shadows) match { - case None => sunit + case None => sunit case Some(nel) => // we are shadowing SourceConverter.partial( - SourceConverter.ExtDefShadow( - SourceConverter.BindKind.Bind, - nel, - s.region), - ()) - } + SourceConverter + .ExtDefShadow(SourceConverter.BindKind.Bind, nel, s.region), + () + ) + } } dupRes *> values.traverse_(checkDefBind) @@ -1051,9 +1291,9 @@ final class SourceConverter( } // Flatten pattern bindings out - private def bindingsDecl( - b: Pattern.Parsed, - decl: Declaration)(alloc: () => Bindable): NonEmptyList[(Bindable, Declaration)] = + private def bindingsDecl(b: Pattern.Parsed, decl: Declaration)( + alloc: () => Bindable + ): NonEmptyList[(Bindable, Declaration)] = b match { case Pattern.Var(nm) => NonEmptyList.one((nm, decl)) @@ -1076,8 +1316,7 @@ final class SourceConverter( if (decl.isCheap) { // no need to make a new var to point to a var (Nil, decl) - } - else { + } else { val ident = alloc() val v = Var(ident)(decl.region) ((ident, decl) :: Nil, v) @@ -1090,7 +1329,8 @@ final class SourceConverter( Match( RecursionKind.NonRecursive, rhsNB, - OptIndent.same(NonEmptyList.one((pat, resOI))))(decl.region) + OptIndent.same(NonEmptyList.one((pat, resOI))) + )(decl.region) } val tail: List[(Bindable, Declaration)] = @@ -1113,16 +1353,17 @@ final class SourceConverter( } } - private def parFold[F[_], S, A, B](s0: S, as: List[A])(fn: (S, A) => (S, F[B]))(implicit F: Applicative[F]): F[List[B]] = { + private def parFold[F[_], S, A, B](s0: S, as: List[A])( + fn: (S, A) => (S, F[B]) + )(implicit F: Applicative[F]): F[List[B]] = { val avec = as.toVector def loop(start: Int, end: Int, s: S): (S, F[Chain[B]]) = if (start >= end) (s, F.pure(Chain.empty)) else if (start == (end - 1)) { val (s1, fb) = fn(s, avec(start)) (s1, fb.map(Chain.one(_))) - } - else { - val mid = start + (end - start)/2 + } else { + val mid = start + (end - start) / 2 val (s1, f1) = loop(start, mid, s) val (s2, f2) = loop(mid, end, s1) (s2, F.map2(f1, f2)(_ ++ _)) @@ -1131,17 +1372,16 @@ final class SourceConverter( loop(0, avec.size, s0)._2.map(_.toList) } - /** - * Return the lets in order they appear - */ - private def toLets(stmts: Seq[Statement.ValueStatement]): Result[List[(Bindable, RecursionKind, Expr[Declaration])]] = { + /** Return the lets in order they appear + */ + private def toLets( + stmts: Seq[Statement.ValueStatement] + ): Result[List[(Bindable, RecursionKind, Expr[Declaration])]] = { import Statement._ val newName: () => Bindable = { lazy val allNames: Set[Bindable] = - stmts - .flatMap { v => v.names.iterator ++ v.allNames.iterator } - .toSet + stmts.flatMap { v => v.names.iterator ++ v.allNames.iterator }.toSet // Each time we need a name, we can call anonNames.next() // it is mutable, but in a limited scope @@ -1154,15 +1394,14 @@ final class SourceConverter( val flatList: List[(Bindable, RecursionKind, Flattened)] = stmts.toList.flatMap { - case d@Def(_) => + case d @ Def(_) => (d.defstatement.name, RecursionKind.Recursive, Left(d)) :: Nil case ExternalDef(_, _, _) => // we don't allow external defs to shadow at all, so skip it here Nil case Bind(BindingStatement(bound, decl, _)) => - bindingsDecl(bound, decl)(newName) - .toList - .map { case pair@(b, _) => + bindingsDecl(bound, decl)(newName).toList + .map { case pair @ (b, _) => (b, RecursionKind.NonRecursive, Right(pair)) } } @@ -1173,45 +1412,52 @@ final class SourceConverter( // TODO make a better name, close to the original, but also not colliding // by using idx val newNameV: Bindable = newName() - val fn: Flattened => Flattened = - { - case Left(d@Def(dstmt)) => - val d1 = if (dstmt.name === bind) dstmt.copy(name = newNameV) else dstmt - val res = - if (dstmt.args.flatten.iterator.flatMap(_.names).exists(_ == bind)) { - // the args are shadowing the binding, so we don't need to substitute - dstmt.result - } - else { - dstmt.result.map { body => - Declaration.substitute(bind, Var(newNameV)(body.region), body) match { - case Some(body1) => body1 - case None => - // $COVERAGE-OFF$ - throw new IllegalStateException("we know newName can't mask") - // $COVERAGE-ON$ - } + val fn: Flattened => Flattened = { + case Left(d @ Def(dstmt)) => + val d1 = + if (dstmt.name === bind) dstmt.copy(name = newNameV) else dstmt + val res = + if ( + dstmt.args.flatten.iterator.flatMap(_.names).exists(_ == bind) + ) { + // the args are shadowing the binding, so we don't need to substitute + dstmt.result + } else { + dstmt.result.map { body => + Declaration.substitute( + bind, + Var(newNameV)(body.region), + body + ) match { + case Some(body1) => body1 + case None => + // $COVERAGE-OFF$ + throw new IllegalStateException( + "we know newName can't mask" + ) + // $COVERAGE-ON$ } } - Left(Def(d1.copy(result = res))(d.region)) - case Right((b0, d)) => - // we don't need to update b0, we discard it anyway - Declaration.substitute(bind, Var(newNameV)(d.region), d) match { - case Some(d1) => Right((b0, d1)) - // $COVERAGE-OFF$ - case None => - throw new IllegalStateException("we know newName can't mask") - // $COVERAGE-ON$ } - } + Left(Def(d1.copy(result = res))(d.region)) + case Right((b0, d)) => + // we don't need to update b0, we discard it anyway + Declaration.substitute(bind, Var(newNameV)(d.region), d) match { + case Some(d1) => Right((b0, d1)) + // $COVERAGE-OFF$ + case None => + throw new IllegalStateException("we know newName can't mask") + // $COVERAGE-ON$ + } + } (newNameV, fn) } val withEx: List[Either[ExternalDef, Flattened]] = - stmts.collect { case e@ExternalDef(_, _, _) => Left(e) }.toList ::: + stmts.collect { case e @ ExternalDef(_, _, _) => Left(e) }.toList ::: flatIn.map { - case (b, _, Left(d@Def(dstmt))) => + case (b, _, Left(d @ Def(dstmt))) => Right(Left(Def(dstmt.copy(name = b))(d.region))) case (b, _, Right((_, d))) => Right(Right((b, d))) } @@ -1219,15 +1465,20 @@ final class SourceConverter( parFold(Set.empty[Bindable], withEx) { case (topBound, stmt) => stmt match { case Right(Right((nm, decl))) => - - val r = fromDecl(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil) + val r = fromDecl(decl, Set.empty, topBound).map( + (nm, RecursionKind.NonRecursive, _) :: Nil + ) // make sure all the free types are Generic // we have to do this at the top level because in Declaration => Expr // we allow closing over type variables defined at a higher level - val r1 = r.map { exs => exs.map { case (n, r, e) => (n, r, Expr.quantifyFrees(e)) } } + val r1 = r.map { exs => + exs.map { case (n, r, e) => (n, r, Expr.quantifyFrees(e)) } + } (topBound + nm, r1) - case Right(Left(d @ Def(defstmt@DefStatement(_, _, argGroups, _, _)))) => + case Right( + Left(d @ Def(defstmt @ DefStatement(_, _, argGroups, _, _))) + ) => // using body for the outer here is a bummer, but not really a good outer otherwise val boundName = defstmt.name @@ -1238,15 +1489,20 @@ final class SourceConverter( toLambdaExpr[OptIndent[Declaration]]( defstmt, d.region, - success(defstmt.result.get))( - { (res: OptIndent[Declaration]) => - fromDecl(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1) - }) + success(defstmt.result.get) + )({ (res: OptIndent[Declaration]) => + fromDecl( + res.get, + argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, + topBound1 + ) + }) val r = lam.map { (l: Expr[Declaration]) => // We rely on DefRecursionCheck to rule out bad recursions val rec = - if (UnusedLetCheck.freeBound(l).contains(boundName)) RecursionKind.Recursive + if (UnusedLetCheck.freeBound(l).contains(boundName)) + RecursionKind.Recursive else RecursionKind.NonRecursive // make sure all the free types are Generic // we have to do this at the top level because in Declaration => Expr @@ -1259,14 +1515,21 @@ final class SourceConverter( (topBound + n, success(Nil)) } }(SourceConverter.parallelIor) - .map(_.flatten) + .map(_.flatten) } - def toProgram(ss: List[Statement]): Result[Program[(TypeEnv[Kind.Arg], ParsedTypeEnv[Option[Kind.Arg]]), Expr[Declaration], List[Statement]]] = { + def toProgram( + ss: List[Statement] + ): Result[Program[(TypeEnv[Kind.Arg], ParsedTypeEnv[Option[Kind.Arg]]), Expr[ + Declaration + ], List[Statement]]] = { val stmts = Statement.valuesOf(ss).toList - stmts.collect { - case ed@Statement.ExternalDef(name, params, result) => - (params.traverse { p => toType(p._2, ed.region) }, toType(result, ed.region)) + stmts + .collect { case ed @ Statement.ExternalDef(name, params, result) => + ( + params.traverse { p => toType(p._2, ed.region) }, + toType(result, ed.region) + ) .flatMapN { (paramTypes, resType) => NonEmptyList.fromList(paramTypes) match { case None => success(resType) @@ -1276,7 +1539,10 @@ final class SourceConverter( case None => val invalid = rankn.Type.Fun(nel, resType) SourceConverter - .partial(SourceConverter.InvalidArity(nel.length, ed.region), invalid) + .partial( + SourceConverter.InvalidArity(nel.length, ed.region), + invalid + ) } } } @@ -1284,32 +1550,34 @@ final class SourceConverter( val freeVars = rankn.Type.freeTyVars(tpe :: Nil) // these vars were parsed so they are never skolem vars val freeBound = freeVars.map { - case b@rankn.Type.Var.Bound(_) => b - case s@rankn.Type.Var.Skolem(_, _, _, _) => + case b @ rankn.Type.Var.Bound(_) => b + case s @ rankn.Type.Var.Skolem(_, _, _, _) => // $COVERAGE-OFF$ this should be unreachable sys.error(s"invariant violation: parsed a skolem var: $s") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } // TODO: Kind support parsing kinds - val maybeForAll = rankn.Type.forAll(freeBound.map { n => (n, Kind.Type) }, tpe) + val maybeForAll = + rankn.Type.forAll(freeBound.map { n => (n, Kind.Type) }, tpe) (name, maybeForAll) } - } - // TODO: we could implement Iterable[Ior[A, B]] => Ior[A, Iterble[B]] - // where we drop all total failures in order to make more progress - .sequence - .flatMap { exts => - val pte1 = toTypeEnv.map { p => - exts.foldLeft(p) { case (pte, (name, tpe)) => - pte.addExternalValue(thisPackage, name, tpe) - } } + // TODO: we could implement Iterable[Ior[A, B]] => Ior[A, Iterble[B]] + // where we drop all total failures in order to make more progress + .sequence + .flatMap { exts => + val pte1 = toTypeEnv.map { p => + exts.foldLeft(p) { case (pte, (name, tpe)) => + pte.addExternalValue(thisPackage, name, tpe) + } + } - implicit val parallel = SourceConverter.parallelIor - (checkExternalDefShadowing(stmts), toLets(stmts), pte1).mapN { (_, binds, pte1) => - Program((importedTypeEnv, pte1), binds, exts.map(_._1).toList, ss) + implicit val parallel = SourceConverter.parallelIor + (checkExternalDefShadowing(stmts), toLets(stmts), pte1).mapN { + (_, binds, pte1) => + Program((importedTypeEnv, pte1), binds, exts.map(_._1).toList, ss) + } } - } } } @@ -1319,7 +1587,8 @@ object SourceConverter { def success[A](a: A): Result[A] = Ior.Right(a) val successUnit: Result[Unit] = success(()) - def partial[A](err: Error, a: A): Result[A] = Ior.Both(NonEmptyChain.one(err), a) + def partial[A](err: Error, a: A): Result[A] = + Ior.Both(NonEmptyChain.one(err), a) def failure[A](err: Error): Result[A] = Ior.Left(NonEmptyChain.one(err)) def addError[A](r: Result[A], err: Error): Result[A] = @@ -1329,114 +1598,124 @@ object SourceConverter { private val parallelIor: Applicative[Result] = Ior.catsDataParallelForIor[NonEmptyChain[Error]].applicative - def toProgram( - thisPackage: PackageName, - imports: List[Import[PackageName, NonEmptyList[Referant[Kind.Arg]]]], - stmts: List[Statement]): Result[Program[(TypeEnv[Kind.Arg], ParsedTypeEnv[Option[Kind.Arg]]), Expr[Declaration], List[Statement]]] = - (new SourceConverter(thisPackage, imports, Statement.definitionsOf(stmts).toList)).toProgram(stmts) + def toProgram( + thisPackage: PackageName, + imports: List[Import[PackageName, NonEmptyList[Referant[Kind.Arg]]]], + stmts: List[Statement] + ): Result[Program[(TypeEnv[Kind.Arg], ParsedTypeEnv[Option[Kind.Arg]]), Expr[ + Declaration + ], List[Statement]]] = + (new SourceConverter( + thisPackage, + imports, + Statement.definitionsOf(stmts).toList + )).toProgram(stmts) private def concat[A](ls: List[A], tail: NonEmptyList[A]): NonEmptyList[A] = ls match { - case Nil => tail + case Nil => tail case h :: t => NonEmptyList(h, t ::: tail.toList) } - /** - * For all duplicate binds, for all but the final - * value, rename them - */ - def makeLetsUnique[D]( - lets: List[(Bindable, RecursionKind, D)])( - newName: (Bindable, Int) => (Bindable, D => D)): List[(Bindable, RecursionKind, D)] = - NonEmptyList.fromList(lets) match { - case None => Nil - case Some(nelets) => - // there is at least 1 let, but maybe no duplicates - val dups: Map[Bindable, Int] = - nelets.foldLeft(Map.empty[Bindable, Int]) { - case (bound, (b, _, _)) => - bound.get(b) match { - case Some(c) => bound.updated(b, c + 1) - case None => bound.updated(b, 1) - } + /** For all duplicate binds, for all but the final value, rename them + */ + def makeLetsUnique[D](lets: List[(Bindable, RecursionKind, D)])( + newName: (Bindable, Int) => (Bindable, D => D) + ): List[(Bindable, RecursionKind, D)] = + NonEmptyList.fromList(lets) match { + case None => Nil + case Some(nelets) => + // there is at least 1 let, but maybe no duplicates + val dups: Map[Bindable, Int] = + nelets + .foldLeft(Map.empty[Bindable, Int]) { case (bound, (b, _, _)) => + bound.get(b) match { + case Some(c) => bound.updated(b, c + 1) + case None => bound.updated(b, 1) + } } .filter { case (_, v) => v > 1 } - if (dups.isEmpty) { - // no duplicated top level names - lets - } - else { - // we rename all but the last name for each duplicate - type BRD = (Bindable, RecursionKind, D) - - /* - * Invariant, lets.exists(_._1 == name) == true - * if this is false, this method will throw - */ - @annotation.tailrec - def renameUntilNext(name: Bindable, lets: NonEmptyList[BRD], acc: List[BRD])(fn: D => D): NonEmptyList[BRD] = { - // note this is a total match: - val NonEmptyList(head @ (b, r, d), tail) = lets - - if (b == name) { - val head1 = - if (r.isRecursive) { - // the new b is in scope right away - head - } - else { - // the old b1 is in scope for this one - (b, r, fn(d)) - } - NonEmptyList(head1, acc).reverse.concat(tail) - } - else { - // if b != name, then that implies there is - // at least one item in the tail with b, - // so tail cannot be empty - val netail = NonEmptyList.fromListUnsafe(tail) - renameUntilNext(name, netail, (b, r, fn(d)) :: acc)(fn) - } + if (dups.isEmpty) { + // no duplicated top level names + lets + } else { + // we rename all but the last name for each duplicate + type BRD = (Bindable, RecursionKind, D) + + /* + * Invariant, lets.exists(_._1 == name) == true + * if this is false, this method will throw + */ + @annotation.tailrec + def renameUntilNext( + name: Bindable, + lets: NonEmptyList[BRD], + acc: List[BRD] + )(fn: D => D): NonEmptyList[BRD] = { + // note this is a total match: + val NonEmptyList(head @ (b, r, d), tail) = lets + + if (b == name) { + val head1 = + if (r.isRecursive) { + // the new b is in scope right away + head + } else { + // the old b1 is in scope for this one + (b, r, fn(d)) + } + NonEmptyList(head1, acc).reverse.concat(tail) + } else { + // if b != name, then that implies there is + // at least one item in the tail with b, + // so tail cannot be empty + val netail = NonEmptyList.fromListUnsafe(tail) + renameUntilNext(name, netail, (b, r, fn(d)) :: acc)(fn) } + } - @annotation.tailrec - def loop(lets: NonEmptyList[BRD], state: Map[Bindable, (Int, Int)], acc: List[BRD]): NonEmptyList[BRD] = { - val head = lets.head - NonEmptyList.fromList(lets.tail) match { - case Some(netail) => - val (b, r, d) = head - state.get(b) match { - case Some((cnt, sz)) if cnt < (sz - 1) => - val newState = state.updated(b, (cnt + 1, sz)) - // we have to rename until the next bind - val (b1, renamer) = newName(b, cnt) - val d1 = - if (r.isRecursive) renamer(d) - else d - - val head1 = (b1, r, d1) - // since cnt < (sz - 1) we know that - // b must occur at least once in netail - val tail1 = renameUntilNext(b, netail, Nil)(renamer) - loop(tail1, newState, head1 :: acc) - case _ => - // this is the last one or not a duplicate, we don't change it - loop(netail, state, head :: acc) - } - case None => - // the last one is never renamed - NonEmptyList(head, acc).reverse - } + @annotation.tailrec + def loop( + lets: NonEmptyList[BRD], + state: Map[Bindable, (Int, Int)], + acc: List[BRD] + ): NonEmptyList[BRD] = { + val head = lets.head + NonEmptyList.fromList(lets.tail) match { + case Some(netail) => + val (b, r, d) = head + state.get(b) match { + case Some((cnt, sz)) if cnt < (sz - 1) => + val newState = state.updated(b, (cnt + 1, sz)) + // we have to rename until the next bind + val (b1, renamer) = newName(b, cnt) + val d1 = + if (r.isRecursive) renamer(d) + else d + + val head1 = (b1, r, d1) + // since cnt < (sz - 1) we know that + // b must occur at least once in netail + val tail1 = renameUntilNext(b, netail, Nil)(renamer) + loop(tail1, newState, head1 :: acc) + case _ => + // this is the last one or not a duplicate, we don't change it + loop(netail, state, head :: acc) + } + case None => + // the last one is never renamed + NonEmptyList(head, acc).reverse } + } - // there are duplicates - val dupState: Map[Bindable, (Int, Int)] = - dups.iterator.map { case (k, sz) => (k, (0, sz)) }.toMap + // there are duplicates + val dupState: Map[Bindable, (Int, Int)] = + dups.iterator.map { case (k, sz) => (k, (0, sz)) }.toMap - loop(nelets, dupState, Nil).toList - } + loop(nelets, dupState, Nil).toList } + } sealed abstract class Error { def region: Region @@ -1454,7 +1733,11 @@ object SourceConverter { final case object Bind extends BindKind("bind") } - final case class ExtDefShadow(kind: BindKind, names: NonEmptyList[Bindable], region: Region) extends Error { + final case class ExtDefShadow( + kind: BindKind, + names: NonEmptyList[Bindable], + region: Region + ) extends Error { def message = { val ns = names.toList.iterator.map(_.sourceCodeRepr).mkString(", ") s"${kind.asString} names $ns shadow external def" @@ -1468,13 +1751,21 @@ object SourceConverter { case object Constructor extends DupKind("constructor") } - final case class Duplication(name: Identifier, kind: DupKind, duplicates: NonEmptyList[Region]) extends Error { + final case class Duplication( + name: Identifier, + kind: DupKind, + duplicates: NonEmptyList[Region] + ) extends Error { def region = duplicates.head def message = s"${kind.asString}: ${name.sourceCodeRepr} defined multiple times" } - final case class PatternShadow(names: NonEmptyList[Bindable], pattern: Pattern.Parsed, region: Region) extends Error { + final case class PatternShadow( + names: NonEmptyList[Bindable], + pattern: Pattern.Parsed, + region: Region + ) extends Error { def message = { val str = names.toList.map(_.sourceCodeRepr).mkString(", ") "repeated bindings in pattern: " + str @@ -1488,7 +1779,8 @@ object SourceConverter { final case class Pat(toPattern: Pattern.Parsed) extends ConstructorSyntax { def toDoc = Document[Pattern.Parsed].document(toPattern) } - final case class RecCons(toDeclaration: Declaration.RecordConstructor) extends ConstructorSyntax { + final case class RecCons(toDeclaration: Declaration.RecordConstructor) + extends ConstructorSyntax { def toDoc = toDeclaration.toDoc } @@ -1499,10 +1791,19 @@ object SourceConverter { RecCons(c) } - final case class UnknownConstructor(name: Constructor, syntax: ConstructorSyntax, region: Region) extends ConstructorError { + final case class UnknownConstructor( + name: Constructor, + syntax: ConstructorSyntax, + region: Region + ) extends ConstructorError { def message = { val maybeDoc = syntax match { - case ConstructorSyntax.Pat(Pattern.PositionalStruct(Pattern.StructKind.Named(n, Pattern.StructKind.Style.TupleLike), Nil)) if n == name => + case ConstructorSyntax.Pat( + Pattern.PositionalStruct( + Pattern.StructKind.Named(n, Pattern.StructKind.Style.TupleLike), + Nil + ) + ) if n == name => // the pattern is just name Doc.empty case _ => @@ -1511,28 +1812,59 @@ object SourceConverter { (Doc.text(s"unknown constructor ${name.asString}") + maybeDoc).render(80) } } - final case class InvalidArgCount(name: Constructor, syntax: ConstructorSyntax, argCount: Int, expected: Int, region: Region) extends ConstructorError { + final case class InvalidArgCount( + name: Constructor, + syntax: ConstructorSyntax, + argCount: Int, + expected: Int, + region: Region + ) extends ConstructorError { def message = - (Doc.text(s"invalid argument count in ${name.asString}, found $argCount expected $expected") + Doc.lineOrSpace + syntax.toDoc).render(80) + (Doc.text( + s"invalid argument count in ${name.asString}, found $argCount expected $expected" + ) + Doc.lineOrSpace + syntax.toDoc).render(80) } - final case class MissingArg(name: Constructor, syntax: ConstructorSyntax, present: SortedSet[Bindable], missing: Bindable, region: Region) extends ConstructorError { + final case class MissingArg( + name: Constructor, + syntax: ConstructorSyntax, + present: SortedSet[Bindable], + missing: Bindable, + region: Region + ) extends ConstructorError { def message = - (Doc.text(s"missing field ${missing.asString} in ${name.asString}") + Doc.lineOrSpace + syntax.toDoc).render(80) + (Doc.text( + s"missing field ${missing.asString} in ${name.asString}" + ) + Doc.lineOrSpace + syntax.toDoc).render(80) } - final case class UnexpectedField(name: Constructor, syntax: ConstructorSyntax, unexpected: NonEmptyList[Bindable], expected: List[Bindable], region: Region) extends ConstructorError { + final case class UnexpectedField( + name: Constructor, + syntax: ConstructorSyntax, + unexpected: NonEmptyList[Bindable], + expected: List[Bindable], + region: Region + ) extends ConstructorError { def message = { val plural = if (unexpected.tail.isEmpty) "field" else "fields" - val unexDoc = Doc.intercalate(Doc.comma + Doc.lineOrSpace, unexpected.toList.map { b => Doc.text(b.asString) }) - val exDoc = Doc.intercalate(Doc.comma + Doc.lineOrSpace, expected.map { b => Doc.text(b.asString) }) + val unexDoc = Doc.intercalate( + Doc.comma + Doc.lineOrSpace, + unexpected.toList.map { b => Doc.text(b.asString) } + ) + val exDoc = Doc.intercalate( + Doc.comma + Doc.lineOrSpace, + expected.map { b => Doc.text(b.asString) } + ) (Doc.text(s"unexpected $plural: ") + unexDoc + Doc.lineOrSpace + - Doc.text(s"in ${name.asString}, expected: ") + exDoc + Doc.lineOrSpace + syntax.toDoc).render(80) - } + Doc.text( + s"in ${name.asString}, expected: " + ) + exDoc + Doc.lineOrSpace + syntax.toDoc).render(80) + } } final case class InvalidTypeParameters( - declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])], - discoveredTypes: List[Type.Var.Bound], - statement: TypeDefinitionStatement) extends Error { + declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])], + discoveredTypes: List[Type.Var.Bound], + statement: TypeDefinitionStatement + ) extends Error { def region = statement.region def message = { @@ -1540,50 +1872,69 @@ object SourceConverter { l.iterator.map(_.name).mkString("[", ", ", "]") val decl = - TypeRef.docTypeArgs(declaredParams.toList) { - case None => Doc.empty - case Some(ka) => Doc.text(": ") + Kind.argDoc(ka) - }.renderTrim(80) + TypeRef + .docTypeArgs(declaredParams.toList) { + case None => Doc.empty + case Some(ka) => Doc.text(": ") + Kind.argDoc(ka) + } + .renderTrim(80) val disc = tstr(discoveredTypes) s"${statement.name.asString} found declared: $decl, not a superset of $disc" } } final case class InvalidDefTypeParameters[B]( - declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind])], - free: List[Type.Var.Bound], - defstmt: DefStatement[Pattern.Parsed, B], - region: Region) extends Error { + declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind])], + free: List[Type.Var.Bound], + defstmt: DefStatement[Pattern.Parsed, B], + region: Region + ) extends Error { def message = { def tstr(l: List[Type.Var.Bound]): String = l.iterator.map(_.name).mkString("[", ", ", "]") - val decl = TypeRef.docTypeArgs(declaredParams.toList) { - case None => Doc.empty - case Some(k) => Doc.text(": ") + Kind.toDoc(k) - }.renderTrim(80) + val decl = TypeRef + .docTypeArgs(declaredParams.toList) { + case None => Doc.empty + case Some(k) => Doc.text(": ") + Kind.toDoc(k) + } + .renderTrim(80) val freeStr = tstr(free) s"${defstmt.name.asString} found declared types: $decl, not a subset of $freeStr" } } - final case class UnknownTypeName(tpe: Constructor, region: Region) extends Error { + final case class UnknownTypeName(tpe: Constructor, region: Region) + extends Error { def message = s"unknown type: ${tpe.asString}" } final case class InvalidArity(size: Int, region: Region) extends Error { - def message = s"invalid function arguments = $size, maximum = ${rankn.Type.FnType.MaxSize}" + def message = + s"invalid function arguments = $size, maximum = ${rankn.Type.FnType.MaxSize}" } - final case class TooManyConstructorArgs(name: Constructor, argCount: Int, max: Int, region: Region) extends Error { + final case class TooManyConstructorArgs( + name: Constructor, + argCount: Int, + max: Int, + region: Region + ) extends Error { def message = if (name.asString == "Tuple32") { - Doc.text(s"invalid tuple size. Found $argCount, but maximum allowed ${Type.FnType.MaxSize}").render(80) - } - else { - Doc.text(s"invalid argument count in constructor for ${name.asString} found $argCount maximum allowed $max").render(80) + Doc + .text( + s"invalid tuple size. Found $argCount, but maximum allowed ${Type.FnType.MaxSize}" + ) + .render(80) + } else { + Doc + .text( + s"invalid argument count in constructor for ${name.asString} found $argCount maximum allowed $max" + ) + .render(80) } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Statement.scala b/core/src/main/scala/org/bykn/bosatsu/Statement.scala index 4469a94b9..16fb21833 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Statement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Statement.scala @@ -1,10 +1,10 @@ package org.bykn.bosatsu -import Parser.{ Combinators, Indy, maybeSpace, keySpace, toEOL } +import Parser.{Combinators, Indy, maybeSpace, keySpace, toEOL} import cats.data.NonEmptyList import cats.implicits._ import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import scala.collection.immutable.SortedSet import Indy.IndyMethods @@ -12,10 +12,9 @@ import Identifier.{Bindable, Constructor} sealed abstract class Statement { - /** - * This describes the region of the current statement, not the entire linked list - * of statements - */ + /** This describes the region of the current statement, not the entire linked + * list of statements + */ def region: Region def replaceRegions(r: Region): Statement = { @@ -46,14 +45,12 @@ sealed abstract class Statement { sealed abstract class TypeDefinitionStatement extends Statement { import Statement.{Struct, Enum, ExternalStruct} - /** - * This is the name of the type being defined - */ + /** This is the name of the type being defined + */ def name: Constructor - /** - * here are the names of the constructors for this type - */ + /** here are the names of the constructors for this type + */ def constructors: List[Constructor] = this match { case Struct(nm, _, _) => nm :: Nil @@ -65,48 +62,53 @@ sealed abstract class TypeDefinitionStatement extends Statement { object Statement { - def definitionsOf(stmts: Iterable[Statement]): LazyList[TypeDefinitionStatement] = - stmts.iterator.collect { case tds: TypeDefinitionStatement => tds }.to(LazyList) + def definitionsOf( + stmts: Iterable[Statement] + ): LazyList[TypeDefinitionStatement] = + stmts.iterator + .collect { case tds: TypeDefinitionStatement => tds } + .to(LazyList) def valuesOf(stmts: Iterable[Statement]): LazyList[ValueStatement] = stmts.iterator.collect { case vs: ValueStatement => vs }.to(LazyList) - /** - * These introduce new values into scope - */ + /** These introduce new values into scope + */ sealed abstract class ValueStatement extends Statement { - /** - * All the names that are bound by this statement - */ + + /** All the names that are bound by this statement + */ def names: List[Bindable] = this match { - case Bind(BindingStatement(bound, _, _)) => bound.names // TODO Keep identifiers - case Def(defstatement) => defstatement.name :: Nil + case Bind(BindingStatement(bound, _, _)) => + bound.names // TODO Keep identifiers + case Def(defstatement) => defstatement.name :: Nil case ExternalDef(name, _, _) => name :: Nil } - /** - * These are all the free bindable names in the right hand side - * of this binding - */ + /** These are all the free bindable names in the right hand side of this + * binding + */ def freeVars: SortedSet[Bindable] = this match { case Bind(BindingStatement(_, decl, _)) => decl.freeVars case Def(defstatement) => val innerFrees = defstatement.result.get.freeVars // but the def name and, args shadow - (innerFrees - defstatement.name) -- defstatement.args.toList.flatMap(_.patternNames) + (innerFrees - defstatement.name) -- defstatement.args.toList.flatMap( + _.patternNames + ) case ExternalDef(_, _, _) => SortedSet.empty } - /** - * These are all the bindings, free or not, in this Statement - */ + /** These are all the bindings, free or not, in this Statement + */ def allNames: SortedSet[Bindable] = { this match { case Bind(BindingStatement(pat, decl, _)) => decl.allNames ++ pat.names case Def(defstatement) => - (defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.toList.flatMap(_.patternNames) + (defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.toList + .flatMap(_.patternNames) case ExternalDef(name, _, _) => SortedSet(name) } } @@ -114,180 +116,236 @@ object Statement { ////// // All the ValueStatements, which set up new bindings in the order they appear in the file - /////. - case class Bind(bind: BindingStatement[Pattern.Parsed, Declaration.NonBinding, Unit])(val region: Region) extends ValueStatement - case class Def(defstatement: DefStatement[Pattern.Parsed, OptIndent[Declaration]])(val region: Region) extends ValueStatement - case class ExternalDef(name: Bindable, params: List[(Bindable, TypeRef)], result: TypeRef)(val region: Region) extends ValueStatement + ///// . + case class Bind( + bind: BindingStatement[Pattern.Parsed, Declaration.NonBinding, Unit] + )(val region: Region) + extends ValueStatement + case class Def( + defstatement: DefStatement[Pattern.Parsed, OptIndent[Declaration]] + )(val region: Region) + extends ValueStatement + case class ExternalDef( + name: Bindable, + params: List[(Bindable, TypeRef)], + result: TypeRef + )(val region: Region) + extends ValueStatement ////// // TypeDefinitionStatement types: ////// - case class Enum(name: Constructor, - typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], - items: OptIndent[NonEmptyList[(Constructor, List[(Bindable, Option[TypeRef])])]] - )(val region: Region) extends TypeDefinitionStatement - case class ExternalStruct(name: Constructor, typeArgs: List[(TypeRef.TypeVar, Option[Kind.Arg])])(val region: Region) extends TypeDefinitionStatement - case class Struct(name: Constructor, - typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], - args: List[(Bindable, Option[TypeRef])])(val region: Region) extends TypeDefinitionStatement + case class Enum( + name: Constructor, + typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], + items: OptIndent[ + NonEmptyList[(Constructor, List[(Bindable, Option[TypeRef])])] + ] + )(val region: Region) + extends TypeDefinitionStatement + case class ExternalStruct( + name: Constructor, + typeArgs: List[(TypeRef.TypeVar, Option[Kind.Arg])] + )(val region: Region) + extends TypeDefinitionStatement + case class Struct( + name: Constructor, + typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], + args: List[(Bindable, Option[TypeRef])] + )(val region: Region) + extends TypeDefinitionStatement //// // These have no effect on the semantics of the Statement linked list //// - case class PaddingStatement(padding: Padding[Unit])(val region: Region) extends Statement - case class Comment(comment: CommentStatement[Unit])(val region: Region) extends Statement + case class PaddingStatement(padding: Padding[Unit])(val region: Region) + extends Statement + case class Comment(comment: CommentStatement[Unit])(val region: Region) + extends Statement // Parse a single item final val parser1: P[Statement] = { - import Declaration.NonBinding + import Declaration.NonBinding val bindingLike: Indy[(Pattern.Parsed, OptIndent[NonBinding])] = { val pat = Pattern.bindParser val patPart = pat <* (maybeSpace *> Declaration.eqP *> maybeSpace) // allow = to be like a block, we can continue on the next line indented - OptIndent.blockLike(Indy.lift(patPart), Declaration.nonBindingParser, P.unit) + OptIndent.blockLike( + Indy.lift(patPart), + Declaration.nonBindingParser, + P.unit + ) } - val bindingP: P[Statement] = - (bindingLike("") <* toEOL) - .region - .map { case (region, (pat, value)) => - Bind(BindingStatement(pat, value.get, ()))(region) - } - - val paddingSP: P[Statement] = - Padding - .nonEmptyParser - .region - .map { case (region, p) => PaddingStatement(p)(region) } - - val commentP: P[Statement] = - CommentStatement.parser(_ => P.unit).region - .map { case (region, cs) => Comment(cs)(region) }.run("") - - val defBody = maybeSpace.with1 *> OptIndent.indy(Declaration.parser).run("") - val defP: P[Statement] = - DefStatement.parser(Pattern.bindParser, defBody <* toEOL).region + val bindingP: P[Statement] = + (bindingLike("") <* toEOL).region + .map { case (region, (pat, value)) => + Bind(BindingStatement(pat, value.get, ()))(region) + } + + val paddingSP: P[Statement] = + Padding.nonEmptyParser.region + .map { case (region, p) => PaddingStatement(p)(region) } + + val commentP: P[Statement] = + CommentStatement + .parser(_ => P.unit) + .region + .map { case (region, cs) => Comment(cs)(region) } + .run("") + + val defBody = maybeSpace.with1 *> OptIndent.indy(Declaration.parser).run("") + val defP: P[Statement] = + DefStatement + .parser(Pattern.bindParser, defBody <* toEOL) + .region .map { case (region, DefStatement(nm, ta, args, ret, body)) => Def(DefStatement(nm, ta, args, ret, body))(region) } - val argParser: P[(Bindable, Option[TypeRef])] = - Identifier.bindableParser ~ TypeRef.annotationParser.? - - val structKey = keySpace("struct") - - val typeParams: P[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]] = { - val kindAnnot: P[Kind.Arg] = - (maybeSpace.soft.with1 *> (P.char(':') *> maybeSpace *> Kind.paramKindParser)) - - TypeRef.typeParams(kindAnnot.?) - } - val external = { - val externalStruct = - (structKey *> (Identifier.consParser ~ Parser.nonEmptyListToList(typeParams)).region <* toEOL) - .map { - case (region, (name, tva)) => ExternalStruct(name, tva)(region) - } - - val argParser: P[(Bindable, TypeRef)] = Identifier.bindableParser ~ TypeRef.annotationParser - - val externalDef = { - - val args = P.char('(') *> maybeSpace *> argParser.nonEmptyList <* maybeSpace <* P.char(')') - - val result = maybeSpace.with1 *> P.string("->") *> maybeSpace *> TypeRef.parser - - (((keySpace("def") *> Identifier.bindableParser ~ args ~ result).region) <* toEOL) - .map { - case (region, ((name, args), resType)) => - ExternalDef(name, args.toList, resType)(region) - } - } - - val externalVal = - (argParser <* toEOL) - .region - .map { case (region, (name, resType)) => - ExternalDef(name, Nil, resType)(region) - } - - keySpace("external") *> P.oneOf(externalStruct :: externalDef :: externalVal :: Nil) - } - - val struct = - ((structKey *> Identifier.consParser ~ typeParams.? ~ Parser.nonEmptyListToList(argParser.parensLines1Cut)).region <* toEOL) - .map { case (region, ((name, typeArgs), argsList)) => - Struct(name, typeArgs, argsList)(region) - } - - val enumP = { - val constructorP = - (Identifier.consParser ~ argParser.parensLines1Cut.?) - .map { - case (n, None) => (n, Nil) - case (n, Some(args)) => (n, args.toList) - } - - val sep = (Indy.lift(P.char(',') <* maybeSpace)) - .combineK(Indy.toEOLIndent) - .void - - val variants = Indy.lift(constructorP <* maybeSpace).nonEmptyList(sep) - - val nameVars = - OptIndent.block( - Indy.lift(keySpace("enum") *> Identifier.consParser ~ (typeParams.?)), - variants - ) - .run("") - .region - - (nameVars <* toEOL) - .map { case (region, ((ename, typeArgs), vars)) => - Enum(ename, typeArgs, vars)(region) - } - } - - // bindingP should come last so there is no ambiguity about identifiers - P.oneOf(commentP :: paddingSP :: defP :: struct :: enumP :: external :: bindingP :: Nil) + val argParser: P[(Bindable, Option[TypeRef])] = + Identifier.bindableParser ~ TypeRef.annotationParser.? + + val structKey = keySpace("struct") + + val typeParams: P[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]] = { + val kindAnnot: P[Kind.Arg] = + (maybeSpace.soft.with1 *> (P.char( + ':' + ) *> maybeSpace *> Kind.paramKindParser)) + + TypeRef.typeParams(kindAnnot.?) + } + val external = { + val externalStruct = + (structKey *> (Identifier.consParser ~ Parser.nonEmptyListToList( + typeParams + )).region <* toEOL) + .map { case (region, (name, tva)) => + ExternalStruct(name, tva)(region) + } + + val argParser: P[(Bindable, TypeRef)] = + Identifier.bindableParser ~ TypeRef.annotationParser + + val externalDef = { + + val args = + P.char('(') *> maybeSpace *> argParser.nonEmptyList <* maybeSpace <* P + .char(')') + + val result = + maybeSpace.with1 *> P.string("->") *> maybeSpace *> TypeRef.parser + + (((keySpace( + "def" + ) *> Identifier.bindableParser ~ args ~ result).region) <* toEOL) + .map { case (region, ((name, args), resType)) => + ExternalDef(name, args.toList, resType)(region) + } + } + + val externalVal = + (argParser <* toEOL).region + .map { case (region, (name, resType)) => + ExternalDef(name, Nil, resType)(region) + } + + keySpace("external") *> P.oneOf( + externalStruct :: externalDef :: externalVal :: Nil + ) + } + + val struct = + ((structKey *> Identifier.consParser ~ typeParams.? ~ Parser + .nonEmptyListToList(argParser.parensLines1Cut)).region <* toEOL) + .map { case (region, ((name, typeArgs), argsList)) => + Struct(name, typeArgs, argsList)(region) + } + + val enumP = { + val constructorP = + (Identifier.consParser ~ argParser.parensLines1Cut.?) + .map { + case (n, None) => (n, Nil) + case (n, Some(args)) => (n, args.toList) + } + + val sep = (Indy + .lift(P.char(',') <* maybeSpace)) + .combineK(Indy.toEOLIndent) + .void + + val variants = Indy.lift(constructorP <* maybeSpace).nonEmptyList(sep) + + val nameVars = + OptIndent + .block( + Indy.lift( + keySpace("enum") *> Identifier.consParser ~ (typeParams.?) + ), + variants + ) + .run("") + .region + + (nameVars <* toEOL) + .map { case (region, ((ename, typeArgs), vars)) => + Enum(ename, typeArgs, vars)(region) + } + } + + // bindingP should come last so there is no ambiguity about identifiers + P.oneOf( + commentP :: paddingSP :: defP :: struct :: enumP :: external :: bindingP :: Nil + ) } - /** - * This parses the *rest* of the string (it must end with End) - */ + /** This parses the *rest* of the string (it must end with End) + */ val parser: P0[List[Statement]] = parser1.rep0 <* Parser.maybeSpacesAndLines <* P.end - private def constructor(name: Constructor, taDoc: Doc, args: List[(Bindable, Option[TypeRef])]): Doc = + private def constructor( + name: Constructor, + taDoc: Doc, + args: List[(Bindable, Option[TypeRef])] + ): Doc = Document[Identifier].document(name) + taDoc + - (if (args.nonEmpty) { Doc.char('(') + Doc.intercalate(Doc.text(", "), args.toList.map(TypeRef.argDoc[Bindable] _)) + Doc.char(')') } - else Doc.empty) + (if (args.nonEmpty) { + Doc.char('(') + Doc.intercalate( + Doc.text(", "), + args.toList.map(TypeRef.argDoc[Bindable] _) + ) + Doc.char(')') + } else Doc.empty) private val colonSpace = Doc.text(": ") - private implicit val dunit: Document[Unit] = Document.instance[Unit](_ => Doc.empty) + private implicit val dunit: Document[Unit] = + Document.instance[Unit](_ => Doc.empty) private val optKindArgs: Document[Option[Kind.Arg]] = Document { - case None => Doc.empty + case None => Doc.empty case Some(ka) => colonSpace + Kind.argDoc(ka) } implicit lazy val document: Document[Statement] = { - val db = Document[BindingStatement[Pattern.Parsed, Declaration.NonBinding, Unit]] + val db = + Document[BindingStatement[Pattern.Parsed, Declaration.NonBinding, Unit]] val dc = Document[CommentStatement[Unit]] implicit val pair: Document[OptIndent[Declaration]] = - Document.instance[OptIndent[Declaration]] { - body => - body.sepDoc + + Document.instance[OptIndent[Declaration]] { body => + body.sepDoc + OptIndent.document(Declaration.document).document(body) } val dd = DefStatement.document[Pattern.Parsed, OptIndent[Declaration]] - implicit val consDoc = Document.instance[(Constructor, List[(Bindable, Option[TypeRef])])] { - case (nm, parts) => constructor(nm, Doc.empty, parts) - } + implicit val consDoc = + Document.instance[(Constructor, List[(Bindable, Option[TypeRef])])] { + case (nm, parts) => constructor(nm, Doc.empty, parts) + } Document.instance[Statement] { case Bind(bs) => @@ -302,14 +360,13 @@ object Statement { Padding.document[Unit].document(p) case Struct(nm, typeArgs, args) => val taDoc = typeArgs match { - case None => Doc.empty + case None => Doc.empty case Some(ta) => TypeRef.docTypeArgs(ta.toList)(optKindArgs.document) } Doc.text("struct ") + constructor(nm, taDoc, args) + Doc.line case Enum(nm, typeArgs, parts) => - val (colonSep, itemSep) = parts match { - case OptIndent.SameLine(_) => (Doc.space, Doc.text(", ")) + case OptIndent.SameLine(_) => (Doc.space, Doc.text(", ")) case OptIndent.NotSameLine(_) => (Doc.empty, Doc.line) } @@ -321,30 +378,40 @@ object Statement { val indentedCons = OptIndent.document(neDoc(consDoc)).document(parts) val taDoc = typeArgs match { - case None => Doc.empty + case None => Doc.empty case Some(ta) => TypeRef.docTypeArgs(ta.toList)(optKindArgs.document) } - Doc.text("enum ") + Document[Constructor].document(nm) + taDoc + Doc.char(':') + + Doc.text("enum ") + Document[Constructor].document(nm) + taDoc + Doc + .char(':') + colonSep + indentedCons + Doc.line case ExternalDef(name, Nil, res) => - Doc.text("external ") + Document[Bindable].document(name) + Doc.text(": ") + res.toDoc + Doc.line + Doc.text("external ") + Document[Bindable].document(name) + Doc.text( + ": " + ) + res.toDoc + Doc.line case ExternalDef(name, args, res) => val argDoc = { - val da = Doc.intercalate(Doc.text(", "), args.map { case (n, tr) => - Document[Bindable].document(n) + Doc.text(": ") + tr.toDoc - }) + val da = Doc.intercalate( + Doc.text(", "), + args.map { case (n, tr) => + Document[Bindable].document(n) + Doc.text(": ") + tr.toDoc + } + ) Doc.char('(') + da + Doc.char(')') } - Doc.text("external def ") + Document[Bindable].document(name) + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line + Doc.text("external def ") + Document[Bindable].document( + name + ) + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line case ExternalStruct(nm, typeArgs) => val taDoc = TypeRef.docTypeArgs(typeArgs.toList) { - case None => Doc.empty + case None => Doc.empty case Some(ka) => Doc.text(": ") + Kind.argDoc(ka) } - Doc.text("external struct ") + Document[Constructor].document(nm) + taDoc + Doc.line + Doc.text("external struct ") + Document[Constructor].document( + nm + ) + taDoc + Doc.line } } @@ -353,4 +420,3 @@ object Statement { Doc.intercalate(Doc.empty, stmts.toList.map(document.document(_))) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala b/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala index ce34f577a..dd7673bb1 100644 --- a/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala +++ b/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala @@ -5,14 +5,16 @@ import cats.parse.{Parser0 => P0, Parser => P, Accumulator, Appender} abstract class GenericStringUtil { protected def decodeTable: Map[Char, Char] - private val encodeTable = decodeTable.iterator.map { case (v, k) => (k, s"\\$v") }.toMap + private val encodeTable = decodeTable.iterator.map { case (v, k) => + (k, s"\\$v") + }.toMap private val nonPrintEscape: Array[String] = (0 until 32).map { c => val strHex = c.toHexString val strPad = List.fill(4 - strHex.length)('0').mkString s"\\u$strPad$strHex" - }.toArray + }.toArray val escapedToken: P[Int] = { def parseIntStr(p: P[Any], base: Int): P[Int] = @@ -49,11 +51,10 @@ abstract class GenericStringUtil { P.charWhere { c => val ci = c.toInt (0xdc00 <= ci) && (ci <= 0xdfff) - } - .map { low => + }.map { low => val lowOff = low - 0xdc00 + 0x10000 - { high => + { high => val highPart = (high - 0xd800) * 0x400 highPart + lowOff } @@ -64,7 +65,7 @@ abstract class GenericStringUtil { val codePointAccumulator: Accumulator[Int, String] = new Accumulator[Int, String] { - def newAppender(first: Int): Appender[Int,String] = + def newAppender(first: Int): Appender[Int, String] = new Appender[Int, String] { val strbuilder = new java.lang.StringBuilder strbuilder.appendCodePoint(first) @@ -76,11 +77,12 @@ abstract class GenericStringUtil { def finish(): String = strbuilder.toString } } - /** - * String content without the delimiter - */ + + /** String content without the delimiter + */ def undelimitedString1(endP: P[Unit]): P[String] = { - escapedToken.orElse((!endP).with1 *> utf16Codepoint) + escapedToken + .orElse((!endP).with1 *> utf16Codepoint) .repAs(codePointAccumulator) } @@ -94,14 +96,22 @@ abstract class GenericStringUtil { end *> undelimitedString1(end).orElse(P.pure("")) <* end } - def interpolatedString[A, B](quoteChar: Char, istart: P[A => B], interp: P0[A], iend: P[Unit]): P[List[Either[B, (Region, String)]]] = { + def interpolatedString[A, B]( + quoteChar: Char, + istart: P[A => B], + interp: P0[A], + iend: P[Unit] + ): P[List[Either[B, (Region, String)]]] = { val strQuote = P.char(quoteChar) val strLit: P[String] = undelimitedString1(strQuote.orElse(istart.void)) - val notStr: P[B] = (istart ~ interp ~ iend).map { case ((fn, a), _) => fn(a) } + val notStr: P[B] = (istart ~ interp ~ iend).map { case ((fn, a), _) => + fn(a) + } val either: P[Either[B, (Region, String)]] = - ((P.index.with1 ~ strLit ~ P.index).map { case ((s, str), l) => Right((Region(s, l), str)) }) + ((P.index.with1 ~ strLit ~ P.index) + .map { case ((s, str), l) => Right((Region(s, l), str)) }) .orElse(notStr.map(Left(_))) (strQuote ~ either.rep0 ~ strQuote).map { case ((_, lst), _) => lst } @@ -110,15 +120,17 @@ abstract class GenericStringUtil { def escape(quoteChar: Char, str: String): String = { // We can ignore escaping the opposite character used for the string // x isn't escaped anyway and is kind of a hack here - val ignoreEscape = if (quoteChar == '\'') '"' else if (quoteChar == '"') '\'' else 'x' + val ignoreEscape = + if (quoteChar == '\'') '"' else if (quoteChar == '"') '\'' else 'x' str.flatMap { c => if (c == ignoreEscape) c.toString - else encodeTable.get(c) match { - case None => - if (c < ' ') nonPrintEscape(c.toInt) - else c.toString - case Some(esc) => esc - } + else + encodeTable.get(c) match { + case None => + if (c < ' ') nonPrintEscape(c.toInt) + else c.toString + case Some(esc) => esc + } } } @@ -140,25 +152,21 @@ abstract class GenericStringUtil { if (idx >= str.length) { // done idx - } - else if (idx < 0) { + } else if (idx < 0) { // error from decodeNum idx - } - else { + } else { val c0 = str.charAt(idx) if (c0 != '\\') { sb.append(c0) loop(idx + 1) - } - else { + } else { // str(idx) == \ val nextIdx = idx + 1 if (nextIdx >= str.length) { // error we expect there to be a character after \ ~idx - } - else { + } else { val c = str.charAt(nextIdx) decodeTable.get(c) match { case Some(d) => @@ -166,10 +174,10 @@ abstract class GenericStringUtil { loop(idx + 2) case None => c match { - case 'o' => loop(decodeNum(idx + 2, 2, 8)) - case 'x' => loop(decodeNum(idx + 2, 2, 16)) - case 'u' => loop(decodeNum(idx + 2, 4, 16)) - case 'U' => loop(decodeNum(idx + 2, 8, 16)) + case 'o' => loop(decodeNum(idx + 2, 2, 8)) + case 'x' => loop(decodeNum(idx + 2, 2, 16)) + case 'u' => loop(decodeNum(idx + 2, 4, 16)) + case 'U' => loop(decodeNum(idx + 2, 8, 16)) case other => // \c is interpretted as just \c, if the character isn't escaped sb.append('\\') @@ -202,7 +210,8 @@ object StringUtil extends GenericStringUtil { ('n', '\n'), ('r', '\r'), ('t', '\t'), - ('v', 11.toChar)) // vertical tab + ('v', 11.toChar) + ) // vertical tab } object JsonStringUtil extends GenericStringUtil { @@ -216,5 +225,6 @@ object JsonStringUtil extends GenericStringUtil { ('f', 12.toChar), // form-feed ('n', '\n'), ('r', '\r'), - ('t', '\t')) + ('t', '\t') + ) } diff --git a/core/src/main/scala/org/bykn/bosatsu/Test.scala b/core/src/main/scala/org/bykn/bosatsu/Test.scala index 03aaf754d..88055e914 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Test.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Test.scala @@ -9,8 +9,8 @@ sealed abstract class Test { def failures: Option[Test] = this match { - case Test.Assertion(true, _) => None - case f@Test.Assertion(false, _) => Some(f) + case Test.Assertion(true, _) => None + case f @ Test.Assertion(false, _) => Some(f) case Test.Suite(nm, ts) => { val innerFails = ts.flatMap(_.failures.toList) if (innerFails.isEmpty) None @@ -62,7 +62,13 @@ object Test { loop(t, None, 0, 0, Doc.empty) @annotation.tailrec - def loop(ts: List[Test], lastSuite: Option[(Int, Int)], passes: Int, fails: Int, front: Doc): Report = + def loop( + ts: List[Test], + lastSuite: Option[(Int, Int)], + passes: Int, + fails: Int, + front: Doc + ): Report = ts match { case Nil => val sumDoc = @@ -75,33 +81,48 @@ object Test { case Assertion(true, _) :: rest => loop(rest, lastSuite, passes + 1, fails, front) case Assertion(false, label) :: rest => - loop(rest, lastSuite, passes, fails + 1, front + (Doc.line + Doc.text(label) + colonSpace + failDoc)) + loop( + rest, + lastSuite, + passes, + fails + 1, + front + (Doc.line + Doc.text(label) + colonSpace + failDoc) + ) case Suite(label, rest) :: tail => val Report(p, f, d) = init(rest) - val res = Doc.line + Doc.text(label) + Doc.char(':') + (Doc.lineOrSpace + d).nested(2) + val res = Doc.line + Doc.text(label) + Doc.char( + ':' + ) + (Doc.lineOrSpace + d).nested(2) loop(tail, Some((p, f)), passes + p, fails + f, front + res) } init(t :: Nil) } - def outputFor(resultList: List[(PackageName, Option[Eval[Test]])], color: LocationMap.Colorize): Report = { + def outputFor( + resultList: List[(PackageName, Option[Eval[Test]])], + color: LocationMap.Colorize + ): Report = { val noTests = resultList.collect { case (p, None) => p } - val results = resultList.collect { case (p, Some(t)) => (p, Test.report(t.value, color)) }.sortBy(_._1) + val results = resultList + .collect { case (p, Some(t)) => (p, Test.report(t.value, color)) } + .sortBy(_._1) val successes = results.iterator.map { case (_, Report(s, _, _)) => s }.sum val failures = results.iterator.map { case (_, Report(_, f, _)) => f }.sum val success = noTests.isEmpty && (failures == 0) val suffix = - if (results.lengthCompare(1) > 0) (Doc.hardLine + Doc.hardLine + Test.summary(successes, failures, color)) + if (results.lengthCompare(1) > 0) + (Doc.hardLine + Doc.hardLine + Test.summary(successes, failures, color)) else Doc.empty val docRes: Doc = - Doc.intercalate(Doc.hardLine + Doc.hardLine, + Doc.intercalate( + Doc.hardLine + Doc.hardLine, results.map { case (p, Report(_, _, d)) => Doc.text(p.asString) + Doc.char(':') + (Doc.lineOrSpace + d).nested(2) - }) + suffix - + } + ) + suffix if (success) Report(successes, failures, docRes) else { @@ -109,11 +130,17 @@ object Test { if (noTests.isEmpty) Nil else { val prefix = Doc.text("packages with missing tests: ") - val missingDoc = Doc.intercalate(Doc.comma + Doc.lineOrSpace, noTests.sorted.map { p => Doc.text(p.asString) }) + val missingDoc = Doc.intercalate( + Doc.comma + Doc.lineOrSpace, + noTests.sorted.map { p => Doc.text(p.asString) } + ) (prefix + missingDoc.nested(2)) :: Nil } - val fullOut = Doc.intercalate(Doc.hardLine + Doc.hardLine + (Doc.char('#') * 80) + Doc.line, docRes :: missingDoc) + val fullOut = Doc.intercalate( + Doc.hardLine + Doc.hardLine + (Doc.char('#') * 80) + Doc.line, + docRes :: missingDoc + ) val failureStr = if (failures == 1) "1 test failure" @@ -124,10 +151,13 @@ object Test { if (missingCount > 0) { val packString = if (missingCount == 1) "package" else "packages" s"$failureStr and $missingCount $packString with no tests found" - } - else failureStr + } else failureStr - Report(successes, failures, fullOut + Doc.hardLine + Doc.hardLine + Doc.text(excepMessage)) + Report( + successes, + failures, + fullOut + Doc.hardLine + Doc.hardLine + Doc.text(excepMessage) + ) } } @@ -137,7 +167,7 @@ object Test { a match { case ProductValue(b, Str(message)) => val bool = b match { - case True => true + case True => true case False => false case _ => sys.error(s"expected test value: $a") @@ -146,8 +176,8 @@ object Test { case other => // $COVERAGE-OFF$ sys.error(s"expected test value: $other") - // $COVERAGE-ON$ - } + // $COVERAGE-ON$ + } def toSuite(a: ProductValue): Test = a match { case ProductValue(Str(name), VList(tests)) => @@ -155,7 +185,7 @@ object Test { case other => // $COVERAGE-OFF$ sys.error(s"expected test value: $other") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } def toTest(a: Value): Test = @@ -171,9 +201,9 @@ object Test { case unexpected => // $COVERAGE-OFF$ sys.error(s"unreachable if compilation has worked: $unexpected") - // $COVERAGE-ON$ + // $COVERAGE-ON$ - } + } toTest(value) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala index cee8e49fd..961c68ab3 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala @@ -21,58 +21,83 @@ object TotalityCheck { type ListPatElem = ListPart[Pattern[Cons, Type]] sealed abstract class Error - case class ArityMismatch(cons: Cons, in: Pattern[Cons, Type], env: TypeEnv[Any], expected: Int, found: Int) extends Error - case class UnknownConstructor(cons: Cons, in: Pattern[Cons, Type], env: TypeEnv[Any]) extends Error - case class MultipleSplicesInPattern(pat: ListPat[Cons, Type], env: TypeEnv[Any]) extends Error + case class ArityMismatch( + cons: Cons, + in: Pattern[Cons, Type], + env: TypeEnv[Any], + expected: Int, + found: Int + ) extends Error + case class UnknownConstructor( + cons: Cons, + in: Pattern[Cons, Type], + env: TypeEnv[Any] + ) extends Error + case class MultipleSplicesInPattern( + pat: ListPat[Cons, Type], + env: TypeEnv[Any] + ) extends Error case class InvalidStrPat(pat: StrPat, env: TypeEnv[Any]) extends Error sealed abstract class ExprError[A] { def matchExpr: Expr.Match[A] } - case class NonTotalMatch[A](matchExpr: Expr.Match[A], missing: NonEmptyList[Pattern[Cons, Type]]) extends ExprError[A] - case class InvalidPattern[A](matchExpr: Expr.Match[A], err: Error) extends ExprError[A] - case class UnreachableBranches[A](matchExpr: Expr.Match[A], branches: NonEmptyList[Pattern[Cons, Type]]) extends ExprError[A] + case class NonTotalMatch[A]( + matchExpr: Expr.Match[A], + missing: NonEmptyList[Pattern[Cons, Type]] + ) extends ExprError[A] + case class InvalidPattern[A](matchExpr: Expr.Match[A], err: Error) + extends ExprError[A] + case class UnreachableBranches[A]( + matchExpr: Expr.Match[A], + branches: NonEmptyList[Pattern[Cons, Type]] + ) extends ExprError[A] } -/** - * Here is code for performing totality checks of matches. - * One key thing: we can assume that any two patterns are describing the same type, or otherwise - * typechecking cannot pass. So, this allows us to make certain inferences, e.g. - * _ - [_] = [_, _, *_] - * because we know the type must be a list of some kind of [_] is to be a well typed pattern. - * - * similarly, some things are ill-typed: `1 - 'foo'` doesn't make any sense. Those two patterns - * don't describe the same type. - */ +/** Here is code for performing totality checks of matches. One key thing: we + * can assume that any two patterns are describing the same type, or otherwise + * typechecking cannot pass. So, this allows us to make certain inferences, + * e.g. _ - [_] = [_, _, *_] because we know the type must be a list of some + * kind of [_] is to be a well typed pattern. + * + * similarly, some things are ill-typed: `1 - 'foo'` doesn't make any sense. + * Those two patterns don't describe the same type. + */ case class TotalityCheck(inEnv: TypeEnv[Any]) { import TotalityCheck._ - /** - * Constructors must match all items to be legal - */ - private def checkArity(nm: Cons, size: Int, pat: Pattern[Cons, Type]): Res[Unit] = + /** Constructors must match all items to be legal + */ + private def checkArity( + nm: Cons, + size: Int, + pat: Pattern[Cons, Type] + ): Res[Unit] = inEnv.typeConstructors.get(nm) match { case None => Left(NonEmptyList.of(UnknownConstructor(nm, pat, inEnv))) case Some((_, params, _)) => val cmp = params.lengthCompare(size) if (cmp == 0) validUnit - else Left(NonEmptyList.of(ArityMismatch(nm, pat, inEnv, size, params.size))) + else + Left( + NonEmptyList.of(ArityMismatch(nm, pat, inEnv, size, params.size)) + ) } private[this] val validUnit: Res[Unit] = Right(()) - /** - * Check that a given pattern follows all the rules. - * - * The main rules are: - * * in strings, you cannot have two adjacent variable patterns (where should one end?) - * * in lists we cannot have more than one variable pattern (maybe relaxed later to the above) - */ + + /** Check that a given pattern follows all the rules. + * + * The main rules are: * in strings, you cannot have two adjacent variable + * patterns (where should one end?) * in lists we cannot have more than one + * variable pattern (maybe relaxed later to the above) + */ def validatePattern(p: Pattern[Cons, Type]): Res[Unit] = p match { - case lp@ListPat(parts) => + case lp @ ListPat(parts) => val twoAdj = lp.toSeqPattern.toList.sliding(2).exists { case Seq(SeqPart.Wildcard, SeqPart.Wildcard) => true - case _ => false + case _ => false } val outer = if (!twoAdj) validUnit @@ -81,12 +106,12 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { val inners: Res[Unit] = parts.parTraverse_ { case ListPart.Item(p) => validatePattern(p) - case _ => validUnit + case _ => validUnit } (outer, inners).parMapN { (_, _) => () } - case sp@StrPat(_) => + case sp @ StrPat(_) => val simp = sp.toSeqPattern def hasAdjacentWild[A](seq: SeqPattern[A]): Boolean = seq match { @@ -110,20 +135,19 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { case _ => validUnit } - /** - * Check that an expression, and all inner expressions, are total, or return - * a NonEmptyList of matches that are not total - */ + /** Check that an expression, and all inner expressions, are total, or return + * a NonEmptyList of matches that are not total + */ def checkExpr[A](expr: Expr[A]): ValidatedNel[ExprError[A], Unit] = { import Expr._ expr match { - case Annotation(e, _, _) => checkExpr(e) - case Generic(_, e) => checkExpr(e) - case Lambda(_, e, _) => checkExpr(e) + case Annotation(e, _, _) => checkExpr(e) + case Generic(_, e) => checkExpr(e) + case Lambda(_, e, _) => checkExpr(e) case Global(_, _, _) | Local(_, _) | Literal(_, _) => Validated.valid(()) - case App(fn, args, _) => checkExpr(fn) *> args.traverse_(checkExpr) + case App(fn, args, _) => checkExpr(fn) *> args.traverse_(checkExpr) case Let(_, e1, e2, _, _) => checkExpr(e1) *> checkExpr(e2) - case m@Match(arg, branches, _) => + case m @ Match(arg, branches, _) => val patterns = branches.toList.map(_._1) patterns .parTraverse_(validatePattern) @@ -149,7 +173,9 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { val unr = patternSetOps.unreachableBranches(patterns) NonEmptyList.fromList(unr) match { case Some(nel) => - Validated.invalidNel(UnreachableBranches(m, nel): ExprError[A]) + Validated.invalidNel( + UnreachableBranches(m, nel): ExprError[A] + ) case None => Validated.valid(()) } } @@ -175,15 +201,22 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { def difference(a: Pattern[Cons, Type], b: Pattern[Cons, Type]): Patterns = patternSetOps.difference(a, b) - private def structToList(n: Cons, args: List[Pattern[Cons, Type]]): Option[Pattern.ListPat[Cons, Type]] = + private def structToList( + n: Cons, + args: List[Pattern[Cons, Type]] + ): Option[Pattern.ListPat[Cons, Type]] = (n, args) match { - case ((PackageName.PredefName, Constructor("EmptyList")), Nil) => Some(Pattern.ListPat(Nil)) - case ((PackageName.PredefName, Constructor("NonEmptyList")), h :: t :: Nil) => + case ((PackageName.PredefName, Constructor("EmptyList")), Nil) => + Some(Pattern.ListPat(Nil)) + case ( + (PackageName.PredefName, Constructor("NonEmptyList")), + h :: t :: Nil + ) => val tailRes = t match { case Pattern.PositionalStruct(n, a) => structToList(n, a).map(_.parts) case Pattern.ListPat(parts) => Some(parts) - case _ => + case _ => if (patternSetOps.isTop(t)) Some(Pattern.ListPart.WildList :: Nil) else None } @@ -200,13 +233,18 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { SetOps.imap[SeqPattern[Pattern[Cons, Type]], ListPat[Cons, Type]]( seqP, ListPat.fromSeqPattern(_), - _.toSeqPattern) + _.toSeqPattern + ) private val strPatternSetOps: SetOps[StrPat] = SetOps.imap[SeqPattern[Char], StrPat]( - SeqPattern.seqPatternSetOps(SeqPart.part1SetOps(SetOps.distinct[Char]), implicitly), + SeqPattern.seqPatternSetOps( + SeqPart.part1SetOps(SetOps.distinct[Char]), + implicitly + ), StrPat.fromSeqPattern(_), - _.toSeqPattern) + _.toSeqPattern + ) private val getProd: Int => SetOps[List[Pattern[Cons, Type]]] = memoizeDagHashed[Int, SetOps[List[Pattern[Cons, Type]]]] { @@ -226,28 +264,33 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { // $COVERAGE-OFF$ case _ => sys.error(s"invalid arity: $arity, found empty list") // $COVERAGE-ON$ - }) + } + ) } - private def fromList(pats: List[Pattern[Cons, Type]]): Option[Pattern[Cons, Type]] = - pats match { - case Nil => None - case one :: Nil => Some(one) - case h :: tail => Some(Pattern.union(h, tail)) - } + private def fromList( + pats: List[Pattern[Cons, Type]] + ): Option[Pattern[Cons, Type]] = + pats match { + case Nil => None + case one :: Nil => Some(one) + case h :: tail => Some(Pattern.union(h, tail)) + } lazy val patternSetOps: SetOps[Pattern[Cons, Type]] = new SetOps[Pattern[Cons, Type]] { self => - val urm = new Relatable.UnionRelModule[Option[Pattern[Cons, Type]]] { - def relatable: Relatable[Option[Pattern[Cons,Type]]] = + def relatable: Relatable[Option[Pattern[Cons, Type]]] = new Relatable[Option[Pattern[Cons, Type]]] { - def relate(left: Option[Pattern[Cons,Type]], right: Option[Pattern[Cons,Type]]): Rel = + def relate( + left: Option[Pattern[Cons, Type]], + right: Option[Pattern[Cons, Type]] + ): Rel = (left, right) match { case (Some(l), Some(r)) => self.relate(l, r) - case (None, None) => Rel.Same - case (None, Some(_)) => Rel.Sub - case (Some(_), None) => Rel.Super + case (None, None) => Rel.Same + case (None, Some(_)) => Rel.Sub + case (Some(_), None) => Rel.Super } } @@ -259,14 +302,13 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { val size = left.size if (size == right.size) { getProd(size).intersection(left, right) - } - else Nil + } else Nil }, solveOne = { prod => (prod.zipWithIndex.collectFirstSome { case (param, idx) => deunion(Some(param)) match { case Right((Some(p1), Some(p2))) => Some(((p1, p2), idx)) - case _ => None + case _ => None } }) match { case Some(((p1, p2), idx)) => @@ -279,20 +321,30 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { prodSetOps.missingBranches(prod :: Nil, union).isEmpty }) } - })(new Relatable[List[Pattern[Cons, Type]]] { - def relate(a: List[Pattern[Cons, Type]], b: List[Pattern[Cons, Type]]) = { - val size = a.size - if (size == b.size) getProd(size).relate(a, b) - else Rel.Disjoint - } - }) + } + )(new Relatable[List[Pattern[Cons, Type]]] { + def relate( + a: List[Pattern[Cons, Type]], + b: List[Pattern[Cons, Type]] + ) = { + val size = a.size + if (size == b.size) getProd(size).relate(a, b) + else Rel.Disjoint + } + }) - def cheapUnion(head: Option[Pattern[Cons, Type]], tail: List[Option[Pattern[Cons,Type]]]): Option[Pattern[Cons,Type]] = + def cheapUnion( + head: Option[Pattern[Cons, Type]], + tail: List[Option[Pattern[Cons, Type]]] + ): Option[Pattern[Cons, Type]] = fromList((head :: tail).flatten) // there are no empty patterns - def isEmpty(a: Option[Pattern[Cons,Type]]): Boolean = a.isEmpty - def intersect(a: Option[Pattern[Cons,Type]], b: Option[Pattern[Cons,Type]]): Option[Pattern[Cons,Type]] = + def isEmpty(a: Option[Pattern[Cons, Type]]): Boolean = a.isEmpty + def intersect( + a: Option[Pattern[Cons, Type]], + b: Option[Pattern[Cons, Type]] + ): Option[Pattern[Cons, Type]] = (a, b) match { case (Some(l), Some(r)) => fromList(self.intersection(l, r)) @@ -301,7 +353,13 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { } // we know that a is nonEmpty by contract - def deunion(a: Option[Pattern[Cons,Type]]): Either[(Option[Pattern[Cons,Type]], Option[Pattern[Cons,Type]]) => Rel.SuperOrSame,(Option[Pattern[Cons,Type]], Option[Pattern[Cons,Type]])] = + def deunion(a: Option[Pattern[Cons, Type]]): Either[ + ( + Option[Pattern[Cons, Type]], + Option[Pattern[Cons, Type]] + ) => Rel.SuperOrSame, + (Option[Pattern[Cons, Type]], Option[Pattern[Cons, Type]]) + ] = a.get match { case u @ Pattern.Union(_, _) => val (left, right) = u.split @@ -309,10 +367,10 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { case Pattern.Literal(_) => // if a literal is >= something it is same, no partial supersets Left((_, _) => Rel.Same) - case Named(_, pat) => deunion(Some(pat)) + case Named(_, pat) => deunion(Some(pat)) case Annotation(pat, _) => deunion(Some(pat)) case WildCard | Var(_) => - Left({(a, b) => + Left({ (a, b) => // unify union returns no top level unions // so isTop is cheap if (unifyUnion(a.toList ::: b.toList).exists(isTop)) Rel.Same @@ -321,45 +379,55 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { case pos @ PositionalStruct(name, params) => structToList(name, params) match { case Some(lp) => deunion(Some(lp)) - case None => - // this isn't a list and is >= the union of two items - // so it has to be a struct with the same name - // assuming pos >= pat - // we know that pat is a Struct(name, _) or union of those - // decompose it into List representing the union, of a list - // of the fields - - val paramSize = params.size - - def unstruct(pat: Option[Pattern[Cons, Type]]): List[List[Pattern[Cons, Type]]] = - pat match { - case None => Nil - case Some(PositionalStruct(n, ps)) => - assert(n == name) - assert(ps.size == paramSize) - ps :: Nil - case Some(Pattern.Union(h, t)) => - (h :: t.toList).flatMap { p => unstruct(Some(p)) } - // $COVERAGE-OFF$ - case Some(Annotation(p, _)) => unstruct(Some(p)) - case Some(Named(_, p)) => unstruct(Some(p)) - case Some(unexpected) => - sys.error(s"unexpected sub pattern of ($pos) in deunion: $unexpected") - // $COVERAGE-ON$ - } - - def solve(p1: Option[Pattern[Cons,Type]], p2: Option[Pattern[Cons,Type]]): Rel.SuperOrSame = { - val unionParams = (unstruct(p1) ::: unstruct(p2)).distinct - // pos <:> (p1 | p2) we can just element wise un - - unionProduct.relate(params :: Nil, unionParams) - .asInstanceOf[Rel.SuperOrSame] + case None => + // this isn't a list and is >= the union of two items + // so it has to be a struct with the same name + // assuming pos >= pat + // we know that pat is a Struct(name, _) or union of those + // decompose it into List representing the union, of a list + // of the fields + + val paramSize = params.size + + def unstruct( + pat: Option[Pattern[Cons, Type]] + ): List[List[Pattern[Cons, Type]]] = + pat match { + case None => Nil + case Some(PositionalStruct(n, ps)) => + assert(n == name) + assert(ps.size == paramSize) + ps :: Nil + case Some(Pattern.Union(h, t)) => + (h :: t.toList).flatMap { p => unstruct(Some(p)) } + // $COVERAGE-OFF$ + case Some(Annotation(p, _)) => unstruct(Some(p)) + case Some(Named(_, p)) => unstruct(Some(p)) + case Some(unexpected) => + sys.error( + s"unexpected sub pattern of ($pos) in deunion: $unexpected" + ) + // $COVERAGE-ON$ } - Left(solve(_, _)) - } - case lp@ListPat(_) => - def optPatternToList(p: Option[Pattern[Cons, Type]]): List[ListPat[Cons, Type]] = + def solve( + p1: Option[Pattern[Cons, Type]], + p2: Option[Pattern[Cons, Type]] + ): Rel.SuperOrSame = { + val unionParams = (unstruct(p1) ::: unstruct(p2)).distinct + // pos <:> (p1 | p2) we can just element wise un + + unionProduct + .relate(params :: Nil, unionParams) + .asInstanceOf[Rel.SuperOrSame] + } + + Left(solve(_, _)) + } + case lp @ ListPat(_) => + def optPatternToList( + p: Option[Pattern[Cons, Type]] + ): List[ListPat[Cons, Type]] = p match { case Some(sp @ ListPat(_)) => sp :: Nil case Some(PositionalStruct(n, ps)) => @@ -374,14 +442,17 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { } Left( - {(b, c) => - val rhs = optPatternToList(b) ::: optPatternToList(c) - if (listPatternSetOps.missingBranches(lp :: Nil, rhs).isEmpty) Rel.Same + { (b, c) => + val rhs = optPatternToList(b) ::: optPatternToList(c) + if (listPatternSetOps.missingBranches(lp :: Nil, rhs).isEmpty) + Rel.Same else Rel.Super } ) - case sp@StrPat(_) => - def optPatternToStr(p: Option[Pattern[Cons, Type]]): List[StrPat] = + case sp @ StrPat(_) => + def optPatternToStr( + p: Option[Pattern[Cons, Type]] + ): List[StrPat] = p match { case Some(sp @ StrPat(_)) => sp :: Nil case Some(Literal(Lit.Str(s))) => @@ -395,9 +466,10 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { } Left( - {(b, c) => + { (b, c) => val rhs = optPatternToStr(b) ::: optPatternToStr(c) - if (strPatternSetOps.missingBranches(sp :: Nil, rhs).isEmpty) Rel.Same + if (strPatternSetOps.missingBranches(sp :: Nil, rhs).isEmpty) + Rel.Same else Rel.Super } ) @@ -408,91 +480,98 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { def topFor(dt: DefinedType[Any]): List[PositionalStruct[Cons, Type]] = dt.constructors.map { cf => - PositionalStruct((dt.packageName, cf.name), cf.args.map(_ => WildCard)) + PositionalStruct( + (dt.packageName, cf.name), + cf.args.map(_ => WildCard) + ) } def intersection( - left: Pattern[Cons, Type], - right: Pattern[Cons, Type]): List[Pattern[Cons, Type]] = - (left, right) match { - case (_, _) if left == right => left :: Nil - case (Var(va), Var(vb)) => Var(Ordering[Bindable].min(va, vb)) :: Nil - case (Var(_), v) => v :: Nil - case (v, Var(_)) => v :: Nil - case (Named(va, pa), Named(vb, pb)) if va == vb => - intersection(pa, pb).map(Named(va, _)) - case (Named(_, pa), r) => intersection(pa, r) - case (l, Named(_, pb)) => intersection(l, pb) - case (WildCard, v) => v :: Nil - case (v, WildCard) => v :: Nil - case (Annotation(p, _), t) => intersection(p, t) - case (t, Annotation(p, _)) => intersection(t, p) - case (Literal(a), Literal(b)) => - if (a == b) left :: Nil - else Nil - case (Literal(Lit.Str(s)), p@StrPat(_)) => - if (p.matches(s)) left :: Nil - else Nil - case (p@StrPat(_), Literal(Lit.Str(s))) => - if (p.matches(s)) right :: Nil - else Nil - case (p1@StrPat(_), p2@StrPat(_)) => - strPatternSetOps.intersection(p1, p2) - case (lp@ListPat(_), rp@ListPat(_)) => - listPatternSetOps.intersection(lp, rp) - case (PositionalStruct(n, as), rp@ListPat(_)) => - structToList(n, as) match { - case Some(lp) => intersection(lp, rp) - case None => - if (isTop(rp)) left :: Nil - else Nil - } - case (lp@ListPat(_), pos@PositionalStruct(_, _)) => - intersection(pos, lp) - case (PositionalStruct(ln, lps), PositionalStruct(rn, rps)) => - if (ln == rn) { - val la = lps.size - if (rps.size == la) { - // the arity must match or check expr fails - // if the arity doesn't match, just consider this - // a mismatch - unifyUnion( - getProd(la) - .intersection(lps, rps) - .map(PositionalStruct(ln, _)) - ) - } + left: Pattern[Cons, Type], + right: Pattern[Cons, Type] + ): List[Pattern[Cons, Type]] = + (left, right) match { + case (_, _) if left == right => left :: Nil + case (Var(va), Var(vb)) => Var(Ordering[Bindable].min(va, vb)) :: Nil + case (Var(_), v) => v :: Nil + case (v, Var(_)) => v :: Nil + case (Named(va, pa), Named(vb, pb)) if va == vb => + intersection(pa, pb).map(Named(va, _)) + case (Named(_, pa), r) => intersection(pa, r) + case (l, Named(_, pb)) => intersection(l, pb) + case (WildCard, v) => v :: Nil + case (v, WildCard) => v :: Nil + case (Annotation(p, _), t) => intersection(p, t) + case (t, Annotation(p, _)) => intersection(t, p) + case (Literal(a), Literal(b)) => + if (a == b) left :: Nil + else Nil + case (Literal(Lit.Str(s)), p @ StrPat(_)) => + if (p.matches(s)) left :: Nil + else Nil + case (p @ StrPat(_), Literal(Lit.Str(s))) => + if (p.matches(s)) right :: Nil + else Nil + case (p1 @ StrPat(_), p2 @ StrPat(_)) => + strPatternSetOps.intersection(p1, p2) + case (lp @ ListPat(_), rp @ ListPat(_)) => + listPatternSetOps.intersection(lp, rp) + case (PositionalStruct(n, as), rp @ ListPat(_)) => + structToList(n, as) match { + case Some(lp) => intersection(lp, rp) + case None => + if (isTop(rp)) left :: Nil else Nil - } - else Nil - case _ => - relate(left, right) match { - case Rel.Disjoint => Nil - case Rel.Sub => left :: Nil - case Rel.Same => normalizePattern(left) :: Nil - case Rel.Super => right :: Nil - case Rel.Intersects => - // we know that neither left nor right can be a top - // value because top >= everything - // - // non-trivial intersection - (left, right) match { - case (Union(a, b), p) => - val u = a :: b.toList - unifyUnion(u.flatMap(intersection(_, p))) - case (p, Union(a, b)) => - val u = a :: b.toList - unifyUnion(u.flatMap(intersection(p, _))) - case _ => - sys.error(s"can't intersect and get here: intersection($left, $right)") - } - } - } + } + case (lp @ ListPat(_), pos @ PositionalStruct(_, _)) => + intersection(pos, lp) + case (PositionalStruct(ln, lps), PositionalStruct(rn, rps)) => + if (ln == rn) { + val la = lps.size + if (rps.size == la) { + // the arity must match or check expr fails + // if the arity doesn't match, just consider this + // a mismatch + unifyUnion( + getProd(la) + .intersection(lps, rps) + .map(PositionalStruct(ln, _)) + ) + } else Nil + } else Nil + case _ => + relate(left, right) match { + case Rel.Disjoint => Nil + case Rel.Sub => left :: Nil + case Rel.Same => normalizePattern(left) :: Nil + case Rel.Super => right :: Nil + case Rel.Intersects => + // we know that neither left nor right can be a top + // value because top >= everything + // + // non-trivial intersection + (left, right) match { + case (Union(a, b), p) => + val u = a :: b.toList + unifyUnion(u.flatMap(intersection(_, p))) + case (p, Union(a, b)) => + val u = a :: b.toList + unifyUnion(u.flatMap(intersection(p, _))) + case _ => + sys.error( + s"can't intersect and get here: intersection($left, $right)" + ) + } + } + } - def difference(left: Pattern[Cons, Type], right: Pattern[Cons, Type]): Patterns = + def difference( + left: Pattern[Cons, Type], + right: Pattern[Cons, Type] + ): Patterns = relate(left, right) match { case Rel.Sub | Rel.Same => Nil - case Rel.Disjoint => left :: Nil + case Rel.Disjoint => left :: Nil case _ => lazy val leftIsTop = isTop(left) // left is a superset of right or intersects @@ -501,28 +580,33 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { // also, right cannot be top, because nothing // is a strict superset or only intersects with top (left, right) match { - case (Named(_, p), r) => difference(p, r) - case (l, Named(_, p)) => difference(l, p) + case (Named(_, p), r) => difference(p, r) + case (l, Named(_, p)) => difference(l, p) case (Annotation(p, _), r) => difference(p, r) case (l, Annotation(p, _)) => difference(l, p) - case (left@ListPat(_), right@ListPat(_)) => + case (left @ ListPat(_), right @ ListPat(_)) => listPatternSetOps.difference(left, right) - case (_, listPat@ListPat(_)) if leftIsTop => + case (_, listPat @ ListPat(_)) if leftIsTop => // _ is the same as [*_] for well typed expressions - listPatternSetOps.difference(ListPat(ListPart.WildList :: Nil), listPat) - case (sa@StrPat(_), Literal(Lit.Str(str))) => + listPatternSetOps.difference( + ListPat(ListPart.WildList :: Nil), + listPat + ) + case (sa @ StrPat(_), Literal(Lit.Str(str))) => strPatternSetOps.difference(sa, StrPat.fromLitStr(str)) - case (sa@StrPat(_), sb@StrPat(_)) => + case (sa @ StrPat(_), sb @ StrPat(_)) => strPatternSetOps.difference(sa, sb) - case (_, right@StrPat(_)) if leftIsTop => + case (_, right @ StrPat(_)) if leftIsTop => // _ is the same as "${_}" for well typed expressions strPatternSetOps.difference(StrPat.Wild, right) case (_, Literal(Lit.Str(s))) if leftIsTop => if (s.isEmpty) { // "${_}" - "" == "$.{_}${_}" - strPatternSetOps.difference(StrPat.Wild, StrPat.fromLitStr("")) - } - else { + strPatternSetOps.difference( + StrPat.Wild, + StrPat.fromLitStr("") + ) + } else { // this is not(str), but we can't represent that, :( topList } @@ -541,7 +625,8 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { case (left, Union(_, _)) => val u = Pattern.flatten(right).toList unifyUnion(differenceAll(left :: Nil, u)) - case (PositionalStruct(ln, lp), PositionalStruct(rn, rp)) if ln == rn => + case (PositionalStruct(ln, lp), PositionalStruct(rn, rp)) + if ln == rn => val la = lp.size if (rp.size == la) { // the arity must match or check expr fails @@ -552,17 +637,16 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { .difference(lp, rp) .map(PositionalStruct(ln, _)) ) - } - else (left :: Nil) - case (PositionalStruct(n, as), rp@ListPat(_)) => + } else (left :: Nil) + case (PositionalStruct(n, as), rp @ ListPat(_)) => structToList(n, as) match { case Some(lp) => difference(lp, rp) - case None => left :: Nil + case None => left :: Nil } - case (lp@ListPat(_), PositionalStruct(n, as)) => + case (lp @ ListPat(_), PositionalStruct(n, as)) => structToList(n, as) match { case Some(rp) => difference(lp, rp) - case None => left :: Nil + case None => left :: Nil } case (_, PositionalStruct(nm, _)) if leftIsTop => inEnv.definedTypeFor(nm) match { @@ -577,16 +661,16 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { if (leftIsTop) topList else left :: Nil } - } + } def isTop(p: Pattern[Cons, Type]): Boolean = p match { case Pattern.WildCard | Pattern.Var(_) => true - case Pattern.Named(_, p) => isTop(p) - case Pattern.Annotation(p, _) => isTop(p) - case Pattern.Literal(_) => false // literals are not total - case s@Pattern.StrPat(_) => strPatternSetOps.isTop(s) - case l@Pattern.ListPat(_) => listPatternSetOps.isTop(l) + case Pattern.Named(_, p) => isTop(p) + case Pattern.Annotation(p, _) => isTop(p) + case Pattern.Literal(_) => false // literals are not total + case s @ Pattern.StrPat(_) => strPatternSetOps.isTop(s) + case l @ Pattern.ListPat(_) => listPatternSetOps.isTop(l) case Pattern.PositionalStruct(name, params) => inEnv.definedTypeFor(name) match { case Some(dt) => @@ -599,31 +683,33 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { unifyUnion(union :: Nil).exists(isTop) } - override def subset(a0: Pattern[Cons, Type], b0: Pattern[Cons, Type]): Boolean = + override def subset( + a0: Pattern[Cons, Type], + b0: Pattern[Cons, Type] + ): Boolean = relate(a0, b0).isSubtype - def relate(a0: Pattern[Cons,Type], b0: Pattern[Cons,Type]): Rel = { + def relate(a0: Pattern[Cons, Type], b0: Pattern[Cons, Type]): Rel = { def loop(a: Pattern[Cons, Type], b: Pattern[Cons, Type]): Rel = (a, b) match { - case _ if a == b => Rel.Same - case (Named(_, p), _) => loop(p, b) - case (_, Named(_, p)) => loop(a, p) + case _ if a == b => Rel.Same + case (Named(_, p), _) => loop(p, b) + case (_, Named(_, p)) => loop(a, p) case (Annotation(p, _), _) => loop(p, b) case (_, Annotation(p, _)) => loop(a, p) - case (_, u@Union(_, _)) => + case (_, u @ Union(_, _)) => val utop = isTop(u) if (isTop(a)) { if (utop) Rel.Same else Rel.Super - } - else if (utop) Rel.Sub + } else if (utop) Rel.Sub else { val (ua, ub) = u.split urm.unionRelCompare(Some(a), Some(ua), Some(ub)) } case (Union(_, _), _) => loop(b, a).invert // All unions have been handled by this point - case (Literal(Lit.Str(s)), sp@Pattern.StrPat(_)) => + case (Literal(Lit.Str(s)), sp @ Pattern.StrPat(_)) => sp.toLiteralString match { case Some(rs) => if (s == rs) Rel.Same @@ -636,17 +722,16 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { if (isTop(b)) Rel.Sub else Rel.Disjoint case (_, Literal(_)) => loop(b, a).invert - case (s1@Pattern.StrPat(_), s2@Pattern.StrPat(_)) => + case (s1 @ Pattern.StrPat(_), s2 @ Pattern.StrPat(_)) => strPatternSetOps.relate(s1, s2) - case (s1@Pattern.StrPat(_), _) => + case (s1 @ Pattern.StrPat(_), _) => if (isTop(b)) { if (s1.isTotal) Rel.Same else Rel.Sub - } - else if (s1.isTotal) Rel.Super + } else if (s1.isTotal) Rel.Super else Rel.Disjoint case (_, Pattern.StrPat(_)) => loop(b, a).invert - case (l1@Pattern.ListPat(_), l2@Pattern.ListPat(_)) => + case (l1 @ Pattern.ListPat(_), l2 @ Pattern.ListPat(_)) => listPatternSetOps.relate(l1, l2) case (lp @ Pattern.ListPat(_), Pattern.PositionalStruct(n, p)) => structToList(n, p) match { @@ -657,41 +742,43 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { } case (Pattern.PositionalStruct(_, _), Pattern.ListPat(_)) => relate(b, a).invert - case (s1@Pattern.ListPat(_), _) => + case (s1 @ Pattern.ListPat(_), _) => if (isTop(b)) { if (listPatternSetOps.isTop(s1)) Rel.Same else Rel.Sub - } - else if (listPatternSetOps.isTop(s1)) Rel.Super + } else if (listPatternSetOps.isTop(s1)) Rel.Super else Rel.Disjoint case (_, Pattern.ListPat(_)) => loop(b, a).invert - case (Pattern.PositionalStruct(ln, lp), Pattern.PositionalStruct(rn, rp)) => + case ( + Pattern.PositionalStruct(ln, lp), + Pattern.PositionalStruct(rn, rp) + ) => if ((ln == rn) && (lp.size == rp.size)) { (lp.zip(rp).foldLeft(Rel.Same: Rel) { case (acc, (l, r)) => acc.lazyCombine(loop(l, r)) }) - } - else Rel.Disjoint + } else Rel.Disjoint case (Pattern.PositionalStruct(_, _), _) => if (isTop(b)) { if (isTop(a)) Rel.Same else Rel.Sub - } - else Rel.Disjoint + } else Rel.Disjoint case (_, Pattern.PositionalStruct(_, _)) => loop(b, a).invert case (Var(_) | WildCard, Var(_) | WildCard) => Rel.Same } - loop(a0, b0) - } + loop(a0, b0) + } - private def unwrap(p: Pattern[Cons, Type]): NonEmptyList[Pattern[Cons, Type]] = + private def unwrap( + p: Pattern[Cons, Type] + ): NonEmptyList[Pattern[Cons, Type]] = p match { - case Named(_, pat) => unwrap(pat) + case Named(_, pat) => unwrap(pat) case Annotation(pat, _) => unwrap(pat) - case u @ Union(_, _) => Pattern.flatten(u).flatMap(unwrap) - case _ => NonEmptyList.one(p) + case u @ Union(_, _) => Pattern.flatten(u).flatMap(unwrap) + case _ => NonEmptyList.one(p) } // Invariant: this returns no top level unions @@ -699,95 +786,101 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { u0.flatMap(unwrap(_).toList) match { case Nil => Nil case singleton @ (one :: Nil) => - if (isTop(one)) topList else singleton + if (isTop(one)) topList else singleton case u => + val structsDs = u + .collect { case Pattern.PositionalStruct(n, a) => (n, a) } + .groupByNel { case (n, a) => (n, a.size) } + .iterator + .flatMap { case ((n, arity), as) => + val optDt = inEnv.definedTypeFor(n) + getProd(arity) + .unifyUnion(as.toList.map(_._2)) + .map { params => + (optDt, Pattern.PositionalStruct(n, params)) + } + } + .toList + + // See if all the structs together have created all the needed + // items for a top value + val hasTopStruct = structsDs + .groupBy(_._1) + .iterator + .exists { + case (Some(dt), dtsPs) => + val topList = topFor(dt) + val structSet = dtsPs.map(_._2).toSet + topList.forall(structSet) + case (None, _) => false + } - val structsDs = u - .collect { case Pattern.PositionalStruct(n, a) => (n, a) } - .groupByNel { case (n, a) => (n, a.size) } - .iterator - .flatMap { case ((n, arity), as) => - val optDt = inEnv.definedTypeFor(n) - getProd(arity) - .unifyUnion(as.toList.map(_._2)) - .map { params => - (optDt, Pattern.PositionalStruct(n, params)) - } - } - .toList - - // See if all the structs together have created all the needed - // items for a top value - val hasTopStruct = structsDs - .groupBy(_._1) - .iterator - .exists { - case (Some(dt), dtsPs) => - val topList = topFor(dt) - val structSet = dtsPs.map(_._2).toSet - topList.forall(structSet) - case (None, _) => false - } - - if (hasTopStruct) topList - else { - val structs = structsDs.map(_._2) - val lps = listPatternSetOps.unifyUnion(u.collect { case lp@Pattern.ListPat(_) => lp }) - val sps = strPatternSetOps.unifyUnion(u.collect { case sp@Pattern.StrPat(_) => sp }) - - if (lps.exists(isTop) || sps.exists(isTop)) topList + if (hasTopStruct) topList else { + val structs = structsDs.map(_._2) + val lps = listPatternSetOps.unifyUnion(u.collect { + case lp @ Pattern.ListPat(_) => lp + }) + val sps = strPatternSetOps.unifyUnion(u.collect { + case sp @ Pattern.StrPat(_) => sp + }) - val strs = u.collect { case Pattern.Literal(ls@Lit.Str(_)) => ls } + if (lps.exists(isTop) || sps.exists(isTop)) topList + else { - val distinctStrs = - strs - .distinct - .filterNot { s => - sps.exists(_.matches(s.toStr)) + val strs = u.collect { case Pattern.Literal(ls @ Lit.Str(_)) => + ls } - .sortBy(_.toStr) - .map(Pattern.Literal(_)) - val notListStr = u.filterNot { - case Pattern.ListPat(_) | Pattern.StrPat(_) | Pattern.Literal(Lit.Str(_)) | Pattern.PositionalStruct(_, _) => true - case _ => false + val distinctStrs = + strs.distinct + .filterNot { s => + sps.exists(_.matches(s.toStr)) + } + .sortBy(_.toStr) + .map(Pattern.Literal(_)) + + val notListStr = u.filterNot { + case Pattern.ListPat(_) | Pattern.StrPat(_) | + Pattern.Literal(Lit.Str(_)) | + Pattern.PositionalStruct(_, _) => + true + case _ => false + }.distinct + + if (notListStr.exists(isTop)) topList + else + (lps ::: sps ::: distinctStrs ::: notListStr ::: structs).sorted + } } - .distinct - - if (notListStr.exists(isTop)) topList - else (lps ::: sps ::: distinctStrs ::: notListStr ::: structs).sorted - } } - } } - /** - * recursively replace as much as possible with Wildcard - * This should match exactly the same set for the same type as - * the previous pattern, without any binding names - */ + /** recursively replace as much as possible with Wildcard This should match + * exactly the same set for the same type as the previous pattern, without + * any binding names + */ def normalizePattern(p: Pattern[Cons, Type]): Pattern[Cons, Type] = p match { - case WildCard | Literal(_) => p - case Var(_) => WildCard - case Named(_, p) => normalizePattern(p) - case Annotation(p, _) => normalizePattern(p) + case WildCard | Literal(_) => p + case Var(_) => WildCard + case Named(_, p) => normalizePattern(p) + case Annotation(p, _) => normalizePattern(p) case _ if patternSetOps.isTop(p) => WildCard - case u@Union(_, _) => + case u @ Union(_, _) => val flattened = Pattern.flatten(u).map(normalizePattern(_)) patternSetOps.unifyUnion(flattened.toList) match { case h :: t => Pattern.union(h, t) - case Nil => + case Nil => // $COVERAGE-OFF$ sys.error("unreachable: union can't remove items") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - case strPat@StrPat(_) => + case strPat @ StrPat(_) => strPat.toLiteralString match { case Some(str) => Literal(Lit.Str(str)) - case None => StrPat.fromSeqPattern(strPat.toSeqPattern) + case None => StrPat.fromSeqPattern(strPat.toSeqPattern) } case ListPat(parts) => val p1 = @@ -802,15 +895,13 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { case PositionalStruct(n, params) => val normParams = params.map(normalizePattern) structToList(n, normParams) match { - case None => PositionalStruct(n, normParams) + case None => PositionalStruct(n, normParams) case Some(lp) => lp } } - /** - * This tells if two patterns for the same type - * would match the same values - */ + /** This tells if two patterns for the same type would match the same values + */ val eqPat: Eq[Pattern[Cons, Type]] = new Eq[Pattern[Cons, Type]] { def eqv(l: Pattern[Cons, Type], r: Pattern[Cons, Type]) = diff --git a/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala b/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala index a23c3e492..ffbb85245 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala @@ -2,10 +2,17 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import cats.parse.{Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import cats.syntax.all._ -import Parser.{ Combinators, MaybeTupleOrParens, lowerIdent, maybeSpace, maybeSpacesAndLines, keySpace } +import Parser.{ + Combinators, + MaybeTupleOrParens, + lowerIdent, + maybeSpace, + maybeSpacesAndLines, + keySpace +} abstract class TypeParser[A] { /* @@ -14,7 +21,10 @@ abstract class TypeParser[A] { protected def parseRoot: P[A] protected def makeFn(in: NonEmptyList[A], out: A): A protected def universal(vars: NonEmptyList[(String, Option[Kind])], in: A): A - protected def existential(vars: NonEmptyList[(String, Option[Kind])], in: A): A + protected def existential( + vars: NonEmptyList[(String, Option[Kind])], + in: A + ): A protected def applyTypes(cons: A, args: NonEmptyList[A]): A protected def makeTuple(items: List[A]): A @@ -23,25 +33,37 @@ abstract class TypeParser[A] { */ protected def unapplyRoot(a: A): Option[Doc] protected def unapplyFn(a: A): Option[(NonEmptyList[A], A)] - protected def unapplyUniversal(a: A): Option[(List[(String, Option[Kind])], A)] - protected def unapplyExistential(a: A): Option[(List[(String, Option[Kind])], A)] + protected def unapplyUniversal( + a: A + ): Option[(List[(String, Option[Kind])], A)] + protected def unapplyExistential( + a: A + ): Option[(List[(String, Option[Kind])], A)] protected def unapplyTypeApply(a: A): Option[(A, List[A])] protected def unapplyTuple(a: A): Option[List[A]] - final val parser: P[A] = P.recursive[A] { recurse => val univItem: P[(String, Option[Kind])] = { val kindP: P[Kind] = - (maybeSpacesAndLines.soft.with1 *> (P.char(':') *> maybeSpacesAndLines *> Kind.parser)) + (maybeSpacesAndLines.soft.with1 *> (P.char( + ':' + ) *> maybeSpacesAndLines *> Kind.parser)) lowerIdent ~ kindP.? } val quantified: P[(NonEmptyList[(String, Option[Kind])], A) => A] = keySpace("forall").as(universal(_, _)) | - keySpace("exists").as(existential(_, _)) + keySpace("exists").as(existential(_, _)) val lambda: P[MaybeTupleOrParens[A]] = - (quantified, univItem.nonEmptyListOfWs(maybeSpacesAndLines) ~ (maybeSpacesAndLines *> P.char('.') *> maybeSpacesAndLines *> recurse)) + ( + quantified, + univItem.nonEmptyListOfWs( + maybeSpacesAndLines + ) ~ (maybeSpacesAndLines *> P.char( + '.' + ) *> maybeSpacesAndLines *> recurse) + ) .mapN { case (fn, (args, e)) => MaybeTupleOrParens.Bare(fn(args, e)) } val tupleOrParens: P[MaybeTupleOrParens[A]] = @@ -49,22 +71,25 @@ abstract class TypeParser[A] { def nonArrow(mtp: MaybeTupleOrParens[A]): A = mtp match { - case MaybeTupleOrParens.Bare(a) => a + case MaybeTupleOrParens.Bare(a) => a case MaybeTupleOrParens.Parens(a) => a case MaybeTupleOrParens.Tuple(as) => makeTuple(as) } val appP: P[MaybeTupleOrParens[A] => MaybeTupleOrParens[A]] = - (P.char('[') *> maybeSpacesAndLines *> recurse.nonEmptyListOfWs(maybeSpacesAndLines) <* maybeSpacesAndLines <* P.char(']')) + (P.char('[') *> maybeSpacesAndLines *> recurse.nonEmptyListOfWs( + maybeSpacesAndLines + ) <* maybeSpacesAndLines <* P.char(']')) .map { args => - { left => MaybeTupleOrParens.Bare(applyTypes(nonArrow(left), args)) } } val arrowP: P[MaybeTupleOrParens[A] => MaybeTupleOrParens[A]] = - ((maybeSpace.with1.soft ~ P.string("->") ~ maybeSpacesAndLines) *> recurse) + ((maybeSpace.with1.soft ~ P.string( + "->" + ) ~ maybeSpacesAndLines) *> recurse) .map { right => { case MaybeTupleOrParens.Bare(a) => @@ -73,7 +98,7 @@ abstract class TypeParser[A] { MaybeTupleOrParens.Bare(makeFn(NonEmptyList.one(a), right)) case MaybeTupleOrParens.Tuple(items) => val args = NonEmptyList.fromList(items) match { - case None => NonEmptyList.one(makeTuple(Nil)) + case None => NonEmptyList.one(makeTuple(Nil)) case Some(nel) => nel } // We know th @@ -81,8 +106,11 @@ abstract class TypeParser[A] { } } - P.oneOf(lambda :: parseRoot.map(MaybeTupleOrParens.Bare(_)) :: tupleOrParens :: Nil) - .maybeAp(appP) + P.oneOf( + lambda :: parseRoot.map( + MaybeTupleOrParens.Bare(_) + ) :: tupleOrParens :: Nil + ).maybeAp(appP) .maybeAp(arrowP) .map(nonArrow) } @@ -100,7 +128,7 @@ abstract class TypeParser[A] { case None => () case Some(ts) => return ts match { - case Nil => unitDoc + case Nil => unitDoc case h :: Nil => Doc.char('(') + toDoc(h) + commaPar case twoAndMore => p(Doc.intercalate(commaSpace, twoAndMore.map(toDoc))) @@ -118,19 +146,18 @@ abstract class TypeParser[A] { .orElse(unapplyExistential(in0)) .orElse(unapplyTuple(in0)) match { case Some(_) => par(din) - case None => din + case None => din } - } - else { + } else { // there is more than 1 arg so parens are always used: (a, b) -> c par(Doc.intercalate(commaSpace, ins.toList.map(toDoc))) - + } return (args + (spaceArrow + toDoc(out))) } unapplyRoot(a) match { - case None => () + case None => () case Some(d) => return d } @@ -139,31 +166,38 @@ abstract class TypeParser[A] { case Some((of, args)) => val ofDoc0 = toDoc(of) val ofDoc = unapplyUniversal(of).orElse(unapplyExistential(of)) match { - case None => ofDoc0 + case None => ofDoc0 case Some(_) => par(ofDoc0) } - return ofDoc + Doc.char('[') + Doc.intercalate(commaSpace, args.map(toDoc)) + Doc.char(']') + return ofDoc + Doc.char('[') + Doc.intercalate( + commaSpace, + args.map(toDoc) + ) + Doc.char(']') } unapplyUniversal(a) match { case None => () case Some((vars, in)) => - return forAll + Doc.intercalate(commaSpace, + return forAll + Doc.intercalate( + commaSpace, vars.map { - case (a, None) => Doc.text(a) + case (a, None) => Doc.text(a) case (a, Some(k)) => Doc.text(a) + TypeParser.colonSpace + k.toDoc - }) + + } + ) + Doc.char('.') + Doc.space + toDoc(in) } unapplyExistential(a) match { case None => () case Some((vars, in)) => - return exists + Doc.intercalate(commaSpace, + return exists + Doc.intercalate( + commaSpace, vars.map { - case (a, None) => Doc.text(a) + case (a, None) => Doc.text(a) case (a, Some(k)) => Doc.text(a) + TypeParser.colonSpace + k.toDoc - }) + + } + ) + Doc.char('.') + Doc.space + toDoc(in) } diff --git a/core/src/main/scala/org/bykn/bosatsu/TypeRef.scala b/core/src/main/scala/org/bykn/bosatsu/TypeRef.scala index ab7a63cff..b166d9138 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypeRef.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypeRef.scala @@ -6,47 +6,45 @@ import cats.implicits._ import cats.parse.{Parser => P, Parser0} import org.bykn.bosatsu.rankn.Type import org.bykn.bosatsu.{TypeName => Name} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import Parser.{lowerIdent, maybeSpace, Combinators} -/** - * This AST is the syntactic version of Type - * it is shaped slightly differently to match the way - * the syntax looks (nested non empty lists are explicit - * whereas we use a recursion/cons style in Type - */ +/** This AST is the syntactic version of Type it is shaped slightly differently + * to match the way the syntax looks (nested non empty lists are explicit + * whereas we use a recursion/cons style in Type + */ sealed abstract class TypeRef { import TypeRef._ def toDoc: Doc = TypeRef.document.document(this) - /** - * Nested TypeForAll can be combined, and should be generally - */ + /** Nested TypeForAll can be combined, and should be generally + */ def normalizeForAll: TypeRef = this match { case TypeVar(_) | TypeName(_) => this - case TypeArrow(a, b) => TypeArrow(a.map(_.normalizeForAll), b.normalizeForAll) + case TypeArrow(a, b) => + TypeArrow(a.map(_.normalizeForAll), b.normalizeForAll) case TypeApply(a, bs) => TypeApply(a.normalizeForAll, bs.map(_.normalizeForAll)) case TypeForAll(pars, e) => // Remove `Some(Type)` since that's the default val normPars = pars.map { case (v, Some(Kind.Type)) => (v, None) - case other => other + case other => other } // we normalize to lifting all the foralls to the outside e.normalizeForAll match { case TypeForAll(p1, in) => TypeForAll(normPars ::: p1, in) - case notForAll => TypeForAll(normPars, notForAll) + case notForAll => TypeForAll(normPars, notForAll) } case TypeExists(pars, e) => // Remove `Some(Type)` since that's the default val normPars = pars.map { case (v, Some(Kind.Type)) => (v, None) - case other => other + case other => other } // we normalize to lifting all the foralls to the outside e.normalizeForAll match { @@ -65,10 +63,9 @@ sealed abstract class TypeRef { object TypeRef { private val colonSpace = Doc.text(": ") - def argDoc[A: Document](st: (A, Option[TypeRef])): Doc = st match { - case (s, None) => Document[A].document(s) + case (s, None) => Document[A].document(s) case (s, Some(tr)) => Document[A].document(s) + colonSpace + (tr.toDoc) } @@ -86,8 +83,14 @@ object TypeRef { case class TypeApply(of: TypeRef, args: NonEmptyList[TypeRef]) extends TypeRef - case class TypeForAll(params: NonEmptyList[(TypeVar, Option[Kind])], in: TypeRef) extends TypeRef - case class TypeExists(params: NonEmptyList[(TypeVar, Option[Kind])], in: TypeRef) extends TypeRef + case class TypeForAll( + params: NonEmptyList[(TypeVar, Option[Kind])], + in: TypeRef + ) extends TypeRef + case class TypeExists( + params: NonEmptyList[(TypeVar, Option[Kind])], + in: TypeRef + ) extends TypeRef case class TypeTuple(params: List[TypeRef]) extends TypeRef implicit val typeRefOrdering: Ordering[TypeRef] = @@ -99,36 +102,46 @@ object TypeRef { def compare(a: TypeRef, b: TypeRef): Int = (a, b) match { - case (TypeVar(v0), TypeVar(v1)) => v0.compareTo(v1) - case (TypeVar(_), _) => -1 + case (TypeVar(v0), TypeVar(v1)) => v0.compareTo(v1) + case (TypeVar(_), _) => -1 case (TypeName(v0), TypeName(v1)) => Ordering[Name].compare(v0, v1) - case (TypeName(_), TypeVar(_)) => 1 - case (TypeName(_), _) => -1 + case (TypeName(_), TypeVar(_)) => 1 + case (TypeName(_), _) => -1 case (TypeArrow(a0, b0), TypeArrow(a1, b1)) => val c = nelTR.compare(a0, a1) if (c == 0) compare(b0, b1) else c case (TypeArrow(_, _), TypeVar(_) | TypeName(_)) => 1 - case (TypeArrow(_, _), _) => -1 + case (TypeArrow(_, _), _) => -1 case (TypeApply(o0, a0), TypeApply(o1, a1)) => val c = compare(o0, o1) if (c != 0) c else list.compare(a0.toList, a1.toList) - case (TypeApply(_, _), TypeVar(_) | TypeName(_) | TypeArrow(_, _)) => 1 - case (TypeApply(_, _), _) => -1 + case (TypeApply(_, _), TypeVar(_) | TypeName(_) | TypeArrow(_, _)) => + 1 + case (TypeApply(_, _), _) => -1 case (TypeForAll(p0, in0), TypeForAll(p1, in1)) => // TODO, we could normalize the parmeters here val c = nelistKind.compare(p0, p1) if (c == 0) compare(in0, in1) else c - case (TypeForAll(_, _), TypeVar(_) | TypeName(_) | TypeArrow(_, _) | TypeApply(_, _)) => 1 + case ( + TypeForAll(_, _), + TypeVar(_) | TypeName(_) | TypeArrow(_, _) | TypeApply(_, _) + ) => + 1 case (TypeForAll(_, _), TypeTuple(_) | TypeExists(_, _)) => -1 - case (TypeExists(p0, in0), TypeExists(p1, in1)) => + case (TypeExists(p0, in0), TypeExists(p1, in1)) => // TODO, we could normalize the parmeters here val c = nelistKind.compare(p0, p1) if (c == 0) compare(in0, in1) else c - case (TypeExists(_, _), TypeForAll(_, _) | TypeVar(_) | TypeName(_) | TypeArrow(_, _) | TypeApply(_, _)) => 1 - case (TypeExists(_, _), _) => -1 + case ( + TypeExists(_, _), + TypeForAll(_, _) | TypeVar(_) | TypeName(_) | TypeArrow(_, _) | + TypeApply(_, _) + ) => + 1 + case (TypeExists(_, _), _) => -1 case (TypeTuple(t0), TypeTuple(t1)) => list.compare(t0, t1) - case (TypeTuple(_), _) => 1 + case (TypeTuple(_), _) => 1 } } @@ -139,13 +152,21 @@ object TypeRef { tvar.orElse(tname) } - def makeFn(in: NonEmptyList[TypeRef], out: TypeRef): TypeRef = TypeArrow(in, out) - - def applyTypes(cons: TypeRef, args: NonEmptyList[TypeRef]): TypeRef = TypeApply(cons, args) - def universal(vars: NonEmptyList[(String, Option[Kind])], in: TypeRef): TypeRef = + def makeFn(in: NonEmptyList[TypeRef], out: TypeRef): TypeRef = + TypeArrow(in, out) + + def applyTypes(cons: TypeRef, args: NonEmptyList[TypeRef]): TypeRef = + TypeApply(cons, args) + def universal( + vars: NonEmptyList[(String, Option[Kind])], + in: TypeRef + ): TypeRef = TypeForAll(vars.map { case (s, k) => (TypeVar(s), k) }, in) - def existential(vars: NonEmptyList[(String, Option[Kind])], in: TypeRef): TypeRef = + def existential( + vars: NonEmptyList[(String, Option[Kind])], + in: TypeRef + ): TypeRef = TypeExists(vars.map { case (s, k) => (TypeVar(s), k) }, in) def makeTuple(items: List[TypeRef]): TypeRef = TypeTuple(items) @@ -153,38 +174,44 @@ object TypeRef { def unapplyRoot(a: TypeRef): Option[Doc] = a match { case TypeName(n) => Some(Document[Identifier].document(n.ident)) - case TypeVar(s) => Some(Doc.text(s)) - case _ => None + case TypeVar(s) => Some(Doc.text(s)) + case _ => None } def unapplyFn(a: TypeRef): Option[(NonEmptyList[TypeRef], TypeRef)] = a match { case TypeArrow(a, b) => Some((a, b)) - case _ => None + case _ => None } - def unapplyUniversal(a: TypeRef): Option[(List[(String, Option[Kind])], TypeRef)] = + def unapplyUniversal( + a: TypeRef + ): Option[(List[(String, Option[Kind])], TypeRef)] = a match { - case TypeForAll(vs, a) => Some(((vs.map { case (v, k) => (v.asString, k) }).toList, a)) + case TypeForAll(vs, a) => + Some(((vs.map { case (v, k) => (v.asString, k) }).toList, a)) case _ => None } - def unapplyExistential(a: TypeRef): Option[(List[(String, Option[Kind])], TypeRef)] = + def unapplyExistential( + a: TypeRef + ): Option[(List[(String, Option[Kind])], TypeRef)] = a match { - case TypeExists(vs, a) => Some(((vs.map { case (v, k) => (v.asString, k) }).toList, a)) + case TypeExists(vs, a) => + Some(((vs.map { case (v, k) => (v.asString, k) }).toList, a)) case _ => None } def unapplyTypeApply(a: TypeRef): Option[(TypeRef, List[TypeRef])] = a match { case TypeApply(a, args) => Some((a, args.toList)) - case _ => None + case _ => None } def unapplyTuple(a: TypeRef): Option[List[TypeRef]] = a match { case TypeTuple(as) => Some(as) - case _ => None + case _ => None } } @@ -198,7 +225,9 @@ object TypeRef { targs match { case Nil => Doc.empty case nonEmpty => - val params = nonEmpty.map { case (TypeRef.TypeVar(v), a) => Doc.text(v) + aDoc(a) } + val params = nonEmpty.map { case (TypeRef.TypeVar(v), a) => + Doc.text(v) + aDoc(a) + } Doc.char('[') + Doc.intercalate(Doc.text(", "), params) + Doc.char(']') } @@ -207,4 +236,3 @@ object TypeRef { nel.map { case (s, a) => (TypeRef.TypeVar(s.intern), a) } } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala b/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala index ba3ac1449..15abe8ebc 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala @@ -8,18 +8,20 @@ import org.bykn.bosatsu.Identifier.Constructor import cats.implicits._ object TypeRefConverter { - /** - * given the ability to convert a name to a fully resolved - * type constant, convert TypeRef to Type - */ - def apply[F[_]: Applicative](t: TypeRef)(nameToType: Constructor => F[Type.Const]): F[Type] = { + + /** given the ability to convert a name to a fully resolved type constant, + * convert TypeRef to Type + */ + def apply[F[_]: Applicative]( + t: TypeRef + )(nameToType: Constructor => F[Type.Const]): F[Type] = { def toType(t: TypeRef): F[Type] = apply(t)(nameToType) import Type._ import TypeRef._ t match { - case TypeVar(v) => Applicative[F].pure(TyVar(Type.Var.Bound(v))) + case TypeVar(v) => Applicative[F].pure(TyVar(Type.Var.Bound(v))) case TypeName(n) => nameToType(n.ident).map(TyConst(_)) case TypeArrow(as, b) => (as.traverse(toType(_)), toType(b)).mapN(Fun(_, _)) @@ -27,23 +29,29 @@ object TypeRefConverter { (toType(a), bs.toList.traverse(toType)).mapN(Type.applyAll(_, _)) case TypeForAll(pars, e) => toType(e).map { te => - Type.forAll(pars.map { case (TypeVar(v), optK) => - val k = optK match { - case None => Kind.Type - case Some(k) => k - } - (Type.Var.Bound(v), k) - }, te) + Type.forAll( + pars.map { case (TypeVar(v), optK) => + val k = optK match { + case None => Kind.Type + case Some(k) => k + } + (Type.Var.Bound(v), k) + }, + te + ) } case TypeExists(pars, e) => toType(e).map { te => - Type.exists(pars.map { case (TypeVar(v), optK) => - val k = optK match { - case None => Kind.Type - case Some(k) => k - } - (Type.Var.Bound(v), k) - }, te) + Type.exists( + pars.map { case (TypeVar(v), optK) => + val k = optK match { + case None => Kind.Type + case Some(k) => k + } + (Type.Var.Bound(v), k) + }, + te + ) } case TypeTuple(ts) => ts.traverse(toType).map(Type.Tuple(_)) @@ -51,10 +59,11 @@ object TypeRefConverter { } def fromTypeA[F[_]: Applicative]( - tpe: Type, - onSkolem: Type.Var.Skolem => F[TypeRef], - onMeta: Long => F[TypeRef], - onConst: Type.Const.Defined => F[TypeRef]): F[TypeRef] = { + tpe: Type, + onSkolem: Type.Var.Skolem => F[TypeRef], + onMeta: Long => F[TypeRef], + onConst: Type.Const.Defined => F[TypeRef] + ): F[TypeRef] = { import Type._ import TypeRef._ @@ -76,17 +85,16 @@ object TypeRefConverter { case Type.Tuple(ts) => // this needs to be above TyConst ts.traverse(loop(_)).map(TypeTuple(_)) - case TyConst(defined@Type.Const.Defined(_, _)) => + case TyConst(defined @ Type.Const.Defined(_, _)) => onConst(defined) case Type.Fun(args, to) => (args.traverse(loop), loop(to)).mapN { (ftr, ttr) => TypeArrow(ftr, ttr) } - case ta@TyApply(_, _) => + case ta @ TyApply(_, _) => val (on, args) = unapplyAll(ta) - (loop(on), args.traverse(loop)).mapN { - (of, arg1) => - TypeApply(of, NonEmptyList.fromListUnsafe(arg1)) + (loop(on), args.traverse(loop)).mapN { (of, arg1) => + TypeApply(of, NonEmptyList.fromListUnsafe(arg1)) } case TyVar(tv) => tv match { @@ -101,7 +109,7 @@ object TypeRefConverter { case other => // the extractors mess this up sys.error(s"unreachable: $other") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } } diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala index 9e880e695..8dd703b05 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala @@ -5,7 +5,7 @@ import cats.arrow.FunctionK import cats.data.{NonEmptyList, Writer} import cats.implicits._ import org.bykn.bosatsu.rankn.Type -import org.typelevel.paiges.{Doc, Document } +import org.typelevel.paiges.{Doc, Document} import scala.collection.immutable.SortedSet import scala.util.hashing.MurmurHash3 @@ -21,23 +21,20 @@ sealed abstract class TypedExpr[+T] { self: Product => MurmurHash3.productHash(this) def tag: T - /** - * For any well typed expression, i.e. - * one that has already gone through type - * inference, we should be able to get a type - * for each expression - * - */ + + /** For any well typed expression, i.e. one that has already gone through type + * inference, we should be able to get a type for each expression + */ lazy val getType: Type = this match { - case g@Generic(_, _) => g.quantType + case g @ Generic(_, _) => g.quantType case Annotation(_, tpe) => tpe case AnnotatedLambda(args, res, _) => Type.Fun(args.map(_._2), res.getType) - case Local(_, tpe, _) => tpe + case Local(_, tpe, _) => tpe case Global(_, _, tpe, _) => tpe - case App(_, _, tpe, _) => tpe + case App(_, _, tpe, _) => tpe case Let(_, _, in, _, _) => in.getType case Literal(_, tpe, _) => @@ -47,9 +44,9 @@ sealed abstract class TypedExpr[+T] { self: Product => branches.head._2.getType } - lazy val size: Int = + lazy val size: Int = this match { - case Generic(_, g) => g.size + case Generic(_, g) => g.size case Annotation(a, _) => a.size case AnnotatedLambda(_, res, _) => res.size @@ -58,7 +55,7 @@ sealed abstract class TypedExpr[+T] { self: Product => case Let(_, e, in, _, _) => e.size + in.size case Match(a, branches, _) => a.size + branches.foldMap(_._2.size) - } + } // TODO: we need to make sure this parsable and maybe have a mode that has the compiler // emit these @@ -67,29 +64,50 @@ sealed abstract class TypedExpr[+T] { self: Product => def loop(te: TypedExpr[T]): Doc = { te match { - case g@Generic(_, expr) => - (Doc.text("(generic") + Doc.lineOrSpace + rept(g.quantType) + Doc.lineOrSpace + loop(expr) + Doc.char(')')).nested(4) + case g @ Generic(_, expr) => + (Doc.text("(generic") + Doc.lineOrSpace + rept( + g.quantType + ) + Doc.lineOrSpace + loop(expr) + Doc.char(')')).nested(4) case Annotation(expr, tpe) => - (Doc.text("(ann") + Doc.lineOrSpace + rept(tpe) + Doc.lineOrSpace + loop(expr) + Doc.char(')')).nested(4) + (Doc.text("(ann") + Doc.lineOrSpace + rept( + tpe + ) + Doc.lineOrSpace + loop(expr) + Doc.char(')')).nested(4) case AnnotatedLambda(args, res, _) => (Doc.text("(lambda") + Doc.lineOrSpace + ( - Doc.char('[') + Doc.intercalate(Doc.lineOrSpace, args.toList.map { case (arg, tpe) => + Doc.char('[') + Doc.intercalate( + Doc.lineOrSpace, + args.toList.map { case (arg, tpe) => Doc.text(arg.sourceCodeRepr) + Doc.lineOrSpace + rept(tpe) - }) + Doc.char(']') - ) + Doc.lineOrSpace + loop(res) + Doc.char(')')).nested(4) + } + ) + Doc.char(']') + ) + Doc.lineOrSpace + loop(res) + Doc.char(')')).nested(4) case Local(v, tpe, _) => - (Doc.text("(var") + Doc.lineOrSpace + Doc.text(v.sourceCodeRepr) + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) + (Doc.text("(var") + Doc.lineOrSpace + Doc.text( + v.sourceCodeRepr + ) + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) case Global(p, v, tpe, _) => val pstr = Doc.text(p.asString + "::" + v.sourceCodeRepr) - (Doc.text("(var") + Doc.lineOrSpace + pstr + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) + (Doc.text("(var") + Doc.lineOrSpace + pstr + Doc.lineOrSpace + rept( + tpe + ) + Doc.char(')')).nested(4) case App(fn, args, tpe, _) => val argsDoc = Doc.intercalate(Doc.lineOrSpace, args.toList.map(loop)) - (Doc.text("(ap") + Doc.lineOrSpace + loop(fn) + Doc.lineOrSpace + argsDoc + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) + (Doc.text("(ap") + Doc.lineOrSpace + loop( + fn + ) + Doc.lineOrSpace + argsDoc + Doc.lineOrSpace + rept(tpe) + Doc + .char(')')).nested(4) case Let(n, b, in, rec, _) => - val nm = if (rec.isRecursive) Doc.text("(letrec") else Doc.text("(let") - (nm + Doc.lineOrSpace + Doc.text(n.sourceCodeRepr) + Doc.lineOrSpace + loop(b) + Doc.lineOrSpace + loop(in) + Doc.char(')')).nested(4) + val nm = + if (rec.isRecursive) Doc.text("(letrec") else Doc.text("(let") + (nm + Doc.lineOrSpace + Doc.text( + n.sourceCodeRepr + ) + Doc.lineOrSpace + loop(b) + Doc.lineOrSpace + loop(in) + Doc.char( + ')' + )).nested(4) case Literal(v, tpe, _) => - (Doc.text("(lit") + Doc.lineOrSpace + Doc.text(v.repr) + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) + (Doc.text("(lit") + Doc.lineOrSpace + Doc.text( + v.repr + ) + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) case Match(arg, branches, _) => implicit val docType: Document[Type] = Document.instance { tpe => rept(tpe) } @@ -98,22 +116,22 @@ sealed abstract class TypedExpr[+T] { self: Product => cpat.document(p) val bstr = branches.toList.map { case (p, t) => - (Doc.char('[') + pat(p) + Doc.comma + Doc.lineOrSpace + loop(t).grouped + Doc.char(']')).nested(4) + (Doc.char('[') + pat(p) + Doc.comma + Doc.lineOrSpace + loop( + t + ).grouped + Doc.char(']')).nested(4) } (Doc.text("(match") + Doc.lineOrSpace + loop(arg) + (Doc.hardLine + - Doc.intercalate(Doc.hardLine, bstr)).nested(4) + Doc.char(')')).nested(4) + Doc.intercalate(Doc.hardLine, bstr)).nested(4) + Doc.char(')')) + .nested(4) } } loop(this) } - - /** - * All the free variables in this expression in order - * encountered and with duplicates (to see how often - * they appear) - */ + /** All the free variables in this expression in order encountered and with + * duplicates (to see how often they appear) + */ lazy val freeVarsDup: List[Bindable] = // nearly identical code to Expr.freeVarsDup, bugs should be fixed in both places this match { @@ -135,8 +153,7 @@ sealed abstract class TypedExpr[+T] { self: Product => val argFree = if (rec.isRecursive) { ListUtil.filterNot(argFree0)(_ === arg) - } - else argFree0 + } else argFree0 argFree ::: (ListUtil.filterNot(in.freeVarsDup)(_ === arg)) case Literal(_, _, _) => @@ -152,8 +169,7 @@ sealed abstract class TypedExpr[+T] { self: Product => else ListUtil.filterNot(bfree)(newBinds) } // we can only take one branch, so count the max on each branch: - val branchFreeMax = branchFrees - .zipWithIndex + val branchFreeMax = branchFrees.zipWithIndex .flatMap { case (names, br) => names.map((_, br)) } // these groupBys are okay because we sort at the end .groupBy(identity) // group-by-name x branch @@ -176,55 +192,86 @@ sealed abstract class TypedExpr[+T] { self: Product => object TypedExpr { - type Rho[A] = TypedExpr[A] // an expression with a Rho type (no top level forall) + type Rho[A] = + TypedExpr[A] // an expression with a Rho type (no top level forall) sealed abstract class Name[A] extends TypedExpr[A] with Product - /** - * This says that the resulting term is generic on a given param - * - * The paper says to add TyLam and TyApp nodes, but it never mentions what to do with them - */ - case class Generic[T](quant: Type.Quantification, in: TypedExpr[T]) extends TypedExpr[T] { - lazy val quantType: Type.Quantified = + + /** This says that the resulting term is generic on a given param + * + * The paper says to add TyLam and TyApp nodes, but it never mentions what to + * do with them + */ + case class Generic[T](quant: Type.Quantification, in: TypedExpr[T]) + extends TypedExpr[T] { + lazy val quantType: Type.Quantified = Type.quantify(quant, in.getType) def tag: T = in.tag } // Annotation really means "widen", the term has a type that is a subtype of coerce, so we are widening // to the given type. This happens on Locals/Globals also in their tpe - case class Annotation[T](term: TypedExpr[T], coerce: Type) extends TypedExpr[T] { + case class Annotation[T](term: TypedExpr[T], coerce: Type) + extends TypedExpr[T] { def tag: T = term.tag } - case class AnnotatedLambda[T](args: NonEmptyList[(Bindable, Type)], expr: TypedExpr[T], tag: T) extends TypedExpr[T] + case class AnnotatedLambda[T]( + args: NonEmptyList[(Bindable, Type)], + expr: TypedExpr[T], + tag: T + ) extends TypedExpr[T] case class Local[T](name: Bindable, tpe: Type, tag: T) extends Name[T] - case class Global[T](pack: PackageName, name: Identifier, tpe: Type, tag: T) extends Name[T] - case class App[T](fn: TypedExpr[T], args: NonEmptyList[TypedExpr[T]], result: Type, tag: T) extends TypedExpr[T] - case class Let[T](arg: Bindable, expr: TypedExpr[T], in: TypedExpr[T], recursive: RecursionKind, tag: T) extends TypedExpr[T] + case class Global[T](pack: PackageName, name: Identifier, tpe: Type, tag: T) + extends Name[T] + case class App[T]( + fn: TypedExpr[T], + args: NonEmptyList[TypedExpr[T]], + result: Type, + tag: T + ) extends TypedExpr[T] + case class Let[T]( + arg: Bindable, + expr: TypedExpr[T], + in: TypedExpr[T], + recursive: RecursionKind, + tag: T + ) extends TypedExpr[T] // TODO, this shouldn't have a type, we know the type from Lit currently case class Literal[T](lit: Lit, tpe: Type, tag: T) extends TypedExpr[T] - case class Match[T](arg: TypedExpr[T], branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], TypedExpr[T])], tag: T) extends TypedExpr[T] - - def letAllNonRec[T](binds: NonEmptyList[(Bindable, TypedExpr[T])], in: TypedExpr[T], tag: T): Let[T] = { + case class Match[T]( + arg: TypedExpr[T], + branches: NonEmptyList[ + (Pattern[(PackageName, Constructor), Type], TypedExpr[T]) + ], + tag: T + ) extends TypedExpr[T] + + def letAllNonRec[T]( + binds: NonEmptyList[(Bindable, TypedExpr[T])], + in: TypedExpr[T], + tag: T + ): Let[T] = { val in1 = binds.tail match { - case Nil => in + case Nil => in case h1 :: t1 => letAllNonRec(NonEmptyList(h1, t1), in, tag) } val (n, ne) = binds.head Let(n, ne, in1, RecursionKind.NonRecursive, tag) } - /** - * If we expect expr to be a lambda of the given arity, return - * the parameter names and types and the rest of the body - */ - def toArgsBody[A](arity: Int, expr: TypedExpr[A]): Option[(NonEmptyList[(Bindable, Type)], TypedExpr[A])] = + /** If we expect expr to be a lambda of the given arity, return the parameter + * names and types and the rest of the body + */ + def toArgsBody[A]( + arity: Int, + expr: TypedExpr[A] + ): Option[(NonEmptyList[(Bindable, Type)], TypedExpr[A])] = expr match { - case Generic(_, e) => toArgsBody(arity, e) + case Generic(_, e) => toArgsBody(arity, e) case Annotation(e, _) => toArgsBody(arity, e) case AnnotatedLambda(args, expr, _) => if (args.length == arity) { Some((args, expr)) - } - else { + } else { None } case Let(arg, e, in, r, t) => @@ -236,8 +283,7 @@ object TypedExpr { // can't lift, we could alpha-rename to // deal with this case None - } - else { + } else { // push it down: Some((args, Let(arg, e, body, r, t))) } @@ -251,8 +297,7 @@ object TypedExpr { // can't lift, we could alpha-rename to // deal with this case None - } - else { + } else { Some((n, (p, b1))) } } @@ -261,8 +306,7 @@ object TypedExpr { argSetO.flatMap { argSet => if (argSet.map(_._1).toList.toSet.size == 1) { Some((argSet.head._1, Match(arg, argSet.map(_._2), tag))) - } - else { + } else { None } } @@ -274,7 +318,9 @@ object TypedExpr { implicit class InvariantTypedExpr[A](val self: TypedExpr[A]) extends AnyVal { def allTypes: SortedSet[Type] = - traverseType { t => Writer[SortedSet[Type], Type](SortedSet(t), t) }.run._1 + traverseType { t => + Writer[SortedSet[Type], Type](SortedSet(t), t) + }.run._1 def allBound: SortedSet[Type.Var.Bound] = traverseType { @@ -286,49 +332,51 @@ object TypedExpr { def freeTyVars: List[Type.Var] = { def loop(self: TypedExpr[A]): Set[Type.Var] = - self match { - case Generic(quant, expr) => - loop(expr) -- quant.vars.iterator.map(_._1) - case Annotation(of, tpe) => - loop(of) ++ Type.freeTyVars(tpe :: Nil) - case AnnotatedLambda(args, res, _) => - loop(res) ++ Type.freeTyVars(args.toList.map { case (_, t) => t }) - case Local(_, tpe, _) => - Type.freeTyVars(tpe :: Nil).toSet - case Global(_, _, tpe, _) => - // this shouldn't happen but does in generated tests - Type.freeTyVars(tpe :: Nil).toSet - case App(f, args, tpe, _) => - args.foldLeft(loop(f))(_ | loop(_)) ++ - Type.freeTyVars(tpe :: Nil) - case Let(_, exp, in, _, _) => - loop(exp) | loop(in) - case Literal(_, tpe, _) => - // this shouldn't happen but does in generated tests - Type.freeTyVars(tpe :: Nil).toSet - case Match(expr, branches, _) => - // all branches have the same type: - branches.foldLeft(loop(expr)) { case (acc, (p, t)) => + self match { + case Generic(quant, expr) => + loop(expr) -- quant.vars.iterator.map(_._1) + case Annotation(of, tpe) => + loop(of) ++ Type.freeTyVars(tpe :: Nil) + case AnnotatedLambda(args, res, _) => + loop(res) ++ Type.freeTyVars(args.toList.map { case (_, t) => t }) + case Local(_, tpe, _) => + Type.freeTyVars(tpe :: Nil).toSet + case Global(_, _, tpe, _) => + // this shouldn't happen but does in generated tests + Type.freeTyVars(tpe :: Nil).toSet + case App(f, args, tpe, _) => + args.foldLeft(loop(f))(_ | loop(_)) ++ + Type.freeTyVars(tpe :: Nil) + case Let(_, exp, in, _, _) => + loop(exp) | loop(in) + case Literal(_, tpe, _) => + // this shouldn't happen but does in generated tests + Type.freeTyVars(tpe :: Nil).toSet + case Match(expr, branches, _) => + // all branches have the same type: + branches.foldLeft(loop(expr)) { case (acc, (p, t)) => (acc | loop(t)) ++ allPatternTypes(p).iterator.collect { case Type.TyVar(v) => v } - } - } + } + } loop(self).toList.sorted } - /** - * Traverse all the *non-shadowed* types inside the TypedExpr - */ + + /** Traverse all the *non-shadowed* types inside the TypedExpr + */ def traverseType[F[_]: Applicative](fn: Type => F[Type]): F[TypedExpr[A]] = self match { case gen @ Generic(quant, expr) => // params shadow below, so they are not free values // and can easily create bugs if passed into fn val params = quant.vars - val shadowed: Set[Type.Var.Bound] = params.toList.iterator.map(_._1).toSet + val shadowed: Set[Type.Var.Bound] = + params.toList.iterator.map(_._1).toSet val shadowFn: Type => F[Type] = { - case tvar@Type.TyVar(v: Type.Var.Bound) if shadowed(v) => Applicative[F].pure(tvar) + case tvar @ Type.TyVar(v: Type.Var.Bound) if shadowed(v) => + Applicative[F].pure(tvar) case notShadowed => fn(notShadowed) } @@ -337,19 +385,20 @@ object TypedExpr { .map(Generic(quant, _)) case Annotation(of, tpe) => (of.traverseType(fn), fn(tpe)).mapN(Annotation(_, _)) - case lam@AnnotatedLambda(args, res, tag) => + case lam @ AnnotatedLambda(args, res, tag) => val a1 = args.traverse { case (n, t) => fn(t).map(n -> _) } fn(lam.getType) *> (a1, res.traverseType(fn)).mapN { - AnnotatedLambda( _, _, tag) + AnnotatedLambda(_, _, tag) } case Local(v, tpe, tag) => fn(tpe).map(Local(v, _, tag)) case Global(p, v, tpe, tag) => fn(tpe).map(Global(p, v, _, tag)) case App(f, args, tpe, tag) => - (f.traverseType(fn), args.traverse(_.traverseType(fn)), fn(tpe)).mapN { - App(_, _, _, tag) - } + (f.traverseType(fn), args.traverse(_.traverseType(fn)), fn(tpe)) + .mapN { + App(_, _, _, tag) + } case Let(v, exp, in, rec, tag) => (exp.traverseType(fn), in.traverseType(fn)).mapN { Let(v, _, _, rec, tag) @@ -358,18 +407,18 @@ object TypedExpr { fn(tpe).map(Literal(lit, _, tag)) case Match(expr, branches, tag) => // all branches have the same type: - val tbranch = branches.traverse { - case (p, t) => - p.traverseType(fn).product(t.traverseType(fn)) + val tbranch = branches.traverse { case (p, t) => + p.traverseType(fn).product(t.traverseType(fn)) } (expr.traverseType(fn), tbranch).mapN(Match(_, _, tag)) } - /** - * This applies fn on all the contained types, replaces the elements, then calls on the - * resulting. This is "bottom up" - */ - def traverseUp[F[_]: Monad](fn: TypedExpr[A] => F[TypedExpr[A]]): F[TypedExpr[A]] = { + /** This applies fn on all the contained types, replaces the elements, then + * calls on the resulting. This is "bottom up" + */ + def traverseUp[F[_]: Monad]( + fn: TypedExpr[A] => F[TypedExpr[A]] + ): F[TypedExpr[A]] = { // be careful not to mistake loop with fn def loop(te: TypedExpr[A]): F[TypedExpr[A]] = te.traverseUp(fn) @@ -386,8 +435,8 @@ object TypedExpr { loop(res).flatMap { res1 => fn(AnnotatedLambda(args, res1, tag)) } - case v@(Global(_, _, _, _) | Local(_, _, _) | Literal(_, _, _)) => - fn(v) + case v @ (Global(_, _, _, _) | Local(_, _, _) | Literal(_, _, _)) => + fn(v) case App(f, args, tpe, tag) => (loop(f), args.traverse(loop(_))) .mapN(App(_, _, tpe, tag)) @@ -397,8 +446,8 @@ object TypedExpr { .mapN(Let(v, _, _, rec, tag)) .flatMap(fn) case Match(expr, branches, tag) => - val tbranch = branches.traverse { - case (p, t) => loop(t).map((p, _)) + val tbranch = branches.traverse { case (p, t) => + loop(t).map((p, _)) } (loop(expr), tbranch) .mapN(Match(_, _, tag)) @@ -406,32 +455,33 @@ object TypedExpr { } } - /** - * Here are all the global names inside this expression - */ + /** Here are all the global names inside this expression + */ def globals: Set[(PackageName, Identifier)] = traverseUp[Writer[Set[(PackageName, Identifier)], *]] { - case g @ Global(p, i, _, _) => Writer.tell(Set[(PackageName, Identifier)]((p, i))).as(g) + case g @ Global(p, i, _, _) => + Writer.tell(Set[(PackageName, Identifier)]((p, i))).as(g) case notG => Monad[Writer[Set[(PackageName, Identifier)], *]].pure(notG) - } - .written + }.written } - def zonkMeta[F[_]: Applicative, A](te: TypedExpr[A])(fn: Type.Meta => F[Option[Type.Rho]]): F[TypedExpr[A]] = + def zonkMeta[F[_]: Applicative, A](te: TypedExpr[A])( + fn: Type.Meta => F[Option[Type.Rho]] + ): F[TypedExpr[A]] = te.traverseType(Type.zonkMeta(_)(fn)) - /** - * quantify every meta variable that is not escaped into - * the outer environment. - * - * TODO: This can probably be optimized. I think it is currently - * quadradic in depth of the TypedExpr - */ + /** quantify every meta variable that is not escaped into the outer + * environment. + * + * TODO: This can probably be optimized. I think it is currently quadradic in + * depth of the TypedExpr + */ def quantify[F[_]: Monad, A]( - env: Map[(Option[PackageName], Identifier), Type], - rho: TypedExpr.Rho[A], - readFn: Type.Meta => F[Option[Type.Rho]], - writeFn: (Type.Meta, Type.Rho) => F[Unit]): F[TypedExpr[A]] = { + env: Map[(Option[PackageName], Identifier), Type], + rho: TypedExpr.Rho[A], + readFn: Type.Meta => F[Option[Type.Rho]], + writeFn: (Type.Meta, Type.Rho) => F[Unit] + ): F[TypedExpr[A]] = { val zFn = Type.zonk(SortedSet.empty, readFn, writeFn) // we need to zonk before so any known metas are removed @@ -442,13 +492,18 @@ object TypedExpr { } } - def quantify0(metaList: List[Type.Meta], rho: TypedExpr[A]): F[TypedExpr[A]] = + def quantify0( + metaList: List[Type.Meta], + rho: TypedExpr[A] + ): F[TypedExpr[A]] = NonEmptyList.fromList(metaList) match { case None => Applicative[F].pure(rho) case Some(metas) => val used: Set[Type.Var.Bound] = rho.allBound val aligned = Type.alignBinders(metas, used) - val bound = aligned.traverse { case (m, n) => writeFn(m, Type.TyVar(n)).as(((n, m.kind), m.existential)) } + val bound = aligned.traverse { case (m, n) => + writeFn(m, Type.TyVar(n)).as(((n, m.kind), m.existential)) + } // we only need to zonk after doing a write: // it isnot clear that zonkMeta correctly here because the existentials // here have been realized to Type.Var now, and and meta pointing at them should @@ -456,7 +511,8 @@ object TypedExpr { val zFn = Type.zonk( metas.iterator.filter(_.existential).to(SortedSet), readFn, - writeFn) + writeFn + ) (bound, zonkMeta(rho)(zFn)) .mapN { (typeArgs, r) => val forAlls = typeArgs.collect { case (nk, false) => nk } @@ -465,7 +521,11 @@ object TypedExpr { } } - def quantifyMetas(envList: => List[Type], metas: SortedSet[Type.Meta], te: TypedExpr[A]): F[TypedExpr[A]] = + def quantifyMetas( + envList: => List[Type], + metas: SortedSet[Type.Meta], + te: TypedExpr[A] + ): F[TypedExpr[A]] = if (metas.isEmpty) Applicative[F].pure(te) else { for { @@ -478,10 +538,11 @@ object TypedExpr { def quantifyFree(env: Set[Type], te: TypedExpr[A]): F[TypedExpr[A]] = { // this is lazy because we only evaluate it if there is an existential skolem lazy val envList = env.toList - lazy val envExistSkols = Type.freeTyVars(envList) + lazy val envExistSkols = Type + .freeTyVars(envList) .iterator - .collect { - case ex @ Skolem(_, _, true, _) => ex + .collect { case ex @ Skolem(_, _, true, _) => + ex } .toSet[Type.Var.Skolem] @@ -490,7 +551,7 @@ object TypedExpr { .collect { case ex @ Skolem(_, _, true, _) if !envExistSkols(ex) => ex } - + val te1 = NonEmptyList.fromList(teSkols) match { case None => te case Some(nel) => @@ -499,14 +560,17 @@ object TypedExpr { }.toSet val names = Type.alignBinders(nel, used) - val aligned = names.iterator.map { - case (v, b) => (v, Type.TyVar(b)) - } - .toMap[Type.Var, Type] + val aligned = names.iterator + .map { case (v, b) => + (v, Type.TyVar(b)) + } + .toMap[Type.Var, Type] - quantVars(Nil, + quantVars( + Nil, names.toList.map { case (sk, b) => (b, sk.kind) }, - substituteTypeVar(te, aligned)) + substituteTypeVar(te, aligned) + ) } getMetaTyVars(te1.allTypes.toList) @@ -542,7 +606,10 @@ object TypedExpr { // this introduces something into the env val inEnv = env + expr.getType val exprEnv = if (rec.isRecursive) inEnv else env - (deepQuantify(exprEnv + te.getType, expr), deepQuantify(inEnv + te.getType, in)) + ( + deepQuantify(exprEnv + te.getType, expr), + deepQuantify(inEnv + te.getType, in) + ) .mapN { (e1, i1) => Let(arg, e1, i1, rec, tag) } @@ -572,15 +639,20 @@ object TypedExpr { * which has a type forall a. Int which is the same * as Int */ - type Branch = (Pattern[(PackageName, Constructor), Type], TypedExpr[A]) + type Branch = + (Pattern[(PackageName, Constructor), Type], TypedExpr[A]) val allMatchMetas: F[SortedSet[Type.Meta]] = - getMetaTyVars(arg.getType :: branches.foldMap { case (p, _) => allPatternTypes(p) }.toList) + getMetaTyVars(arg.getType :: branches.foldMap { case (p, _) => + allPatternTypes(p) + }.toList) val env1 = env + te.getType def handleBranch(br: Branch): F[Branch] = { val (p, expr) = br - val branchEnv = env1 ++ Pattern.envOf(p, Map.empty) { ident => (None, ident) }.values + val branchEnv = env1 ++ Pattern + .envOf(p, Map.empty) { ident => (None, ident) } + .values deepQuantify(branchEnv, expr).map((p, _)) } @@ -596,123 +668,147 @@ object TypedExpr { // we still need to recurse on arg deepQuantify(env1, arg).map(Match(_, branches, tag)) case Generic(quants, expr) => - finish(expr).map(quantVars(quants.forallList, quants.existList, _)) + finish(expr).map( + quantVars(quants.forallList, quants.existList, _) + ) // $COVERAGE-OFF$ case unreach => - sys.error(s"Match quantification yielded neither Generic nor Match: $unreach") + sys.error( + s"Match quantification yielded neither Generic nor Match: $unreach" + ) // $COVERAGE-ON$ } noArg.flatMap(finish) - case nonest@(Global(_, _, _, _) | Local(_, _, _) | Literal(_, _, _)) => + case nonest @ (Global(_, _, _, _) | Local(_, _, _) | + Literal(_, _, _)) => Applicative[F].pure(nonest) } deepQuantify(env.values.toSet, rho) } - implicit val traverseTypedExpr: Traverse[TypedExpr] = new Traverse[TypedExpr] { - def traverse[F[_]: Applicative, T, S](typedExprT: TypedExpr[T])(fn: T => F[S]): F[TypedExpr[S]] = - typedExprT match { - case Generic(params, expr) => - expr.traverse(fn).map(Generic(params, _)) - case Annotation(of, tpe) => - of.traverse(fn).map(Annotation(_, tpe)) - case AnnotatedLambda(args, res, tag) => - (res.traverse(fn), fn(tag)).mapN { - AnnotatedLambda(args, _, _) - } - case Local(v, tpe, tag) => - fn(tag).map(Local(v, tpe, _)) - case Global(p, v, tpe, tag) => - fn(tag).map(Global(p, v, tpe, _)) - case App(f, args, tpe, tag) => - (f.traverse(fn), args.traverse(_.traverse(fn)), fn(tag)).mapN { - App(_, _, tpe, _) - } - case Let(v, exp, in, rec, tag) => - (exp.traverse(fn), in.traverse(fn), fn(tag)).mapN { - Let(v, _, _, rec, _) - } - case Literal(lit, tpe, tag) => - fn(tag).map(Literal(lit, tpe, _)) - case Match(expr, branches, tag) => - // all branches have the same type: - val tbranch = branches.traverse { - case (p, t) => + implicit val traverseTypedExpr: Traverse[TypedExpr] = + new Traverse[TypedExpr] { + def traverse[F[_]: Applicative, T, S]( + typedExprT: TypedExpr[T] + )(fn: T => F[S]): F[TypedExpr[S]] = + typedExprT match { + case Generic(params, expr) => + expr.traverse(fn).map(Generic(params, _)) + case Annotation(of, tpe) => + of.traverse(fn).map(Annotation(_, tpe)) + case AnnotatedLambda(args, res, tag) => + (res.traverse(fn), fn(tag)).mapN { + AnnotatedLambda(args, _, _) + } + case Local(v, tpe, tag) => + fn(tag).map(Local(v, tpe, _)) + case Global(p, v, tpe, tag) => + fn(tag).map(Global(p, v, tpe, _)) + case App(f, args, tpe, tag) => + (f.traverse(fn), args.traverse(_.traverse(fn)), fn(tag)).mapN { + App(_, _, tpe, _) + } + case Let(v, exp, in, rec, tag) => + (exp.traverse(fn), in.traverse(fn), fn(tag)).mapN { + Let(v, _, _, rec, _) + } + case Literal(lit, tpe, tag) => + fn(tag).map(Literal(lit, tpe, _)) + case Match(expr, branches, tag) => + // all branches have the same type: + val tbranch = branches.traverse { case (p, t) => t.traverse(fn).map((p, _)) - } - (expr.traverse(fn), tbranch, fn(tag)).mapN(Match(_, _, _)) - } + } + (expr.traverse(fn), tbranch, fn(tag)).mapN(Match(_, _, _)) + } - def foldLeft[A, B](typedExprA: TypedExpr[A], b: B)(f: (B, A) => B): B = typedExprA match { - case Generic(_, e) => - foldLeft(e, b)(f) - case Annotation(e, _) => - foldLeft(e, b)(f) - case AnnotatedLambda(_, e, tag) => - val b1 = foldLeft(e, b)(f) - f(b1, tag) - case n: Name[A] => f(b, n.tag) - case App(fn, args, _, tag) => - val b1 = foldLeft(fn, b)(f) - val b2 = args.foldLeft(b1)((b1, a) => foldLeft(a, b1)(f)) - f(b2, tag) - case Let(_, exp, in, _, tag) => - val b1 = foldLeft(exp, b)(f) - val b2 = foldLeft(in, b1)(f) - f(b2, tag) - case Literal(_, _, tag) => - f(b, tag) - case Match(arg, branches, tag) => - val b1 = foldLeft(arg, b)(f) - val b2 = branches.foldLeft(b1) { case (bn, (_, t)) => foldLeft(t, bn)(f) } - f(b2, tag) - } + def foldLeft[A, B](typedExprA: TypedExpr[A], b: B)(f: (B, A) => B): B = + typedExprA match { + case Generic(_, e) => + foldLeft(e, b)(f) + case Annotation(e, _) => + foldLeft(e, b)(f) + case AnnotatedLambda(_, e, tag) => + val b1 = foldLeft(e, b)(f) + f(b1, tag) + case n: Name[A] => f(b, n.tag) + case App(fn, args, _, tag) => + val b1 = foldLeft(fn, b)(f) + val b2 = args.foldLeft(b1)((b1, a) => foldLeft(a, b1)(f)) + f(b2, tag) + case Let(_, exp, in, _, tag) => + val b1 = foldLeft(exp, b)(f) + val b2 = foldLeft(in, b1)(f) + f(b2, tag) + case Literal(_, _, tag) => + f(b, tag) + case Match(arg, branches, tag) => + val b1 = foldLeft(arg, b)(f) + val b2 = branches.foldLeft(b1) { case (bn, (_, t)) => + foldLeft(t, bn)(f) + } + f(b2, tag) + } - def foldRight[A, B](typedExprA: TypedExpr[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] = typedExprA match { - case Generic(_, e) => - foldRight(e, lb)(f) - case Annotation(e, _) => - foldRight(e, lb)(f) - case AnnotatedLambda(_, e, tag) => - val lb1 = f(tag, lb) - foldRight(e, lb1)(f) - case n: Name[A] => f(n.tag, lb) - case App(fn, args, _, tag) => - val b1 = f(tag, lb) - val b2 = args.toList.foldRight(b1)((a, b1) => foldRight(a, b1)(f)) - foldRight(fn, b2)(f) - case Let(_, exp, in, _, tag) => - val b1 = f(tag, lb) - val b2 = foldRight(in, b1)(f) - foldRight(exp, b2)(f) - case Literal(_, _, tag) => - f(tag, lb) - case Match(arg, branches, tag) => - val b1 = f(tag, lb) - val b2 = branches.foldRight(b1) { case ((_, t), bn) => foldRight(t, bn)(f) } - foldRight(arg, b2)(f) - } + def foldRight[A, B](typedExprA: TypedExpr[A], lb: Eval[B])( + f: (A, Eval[B]) => Eval[B] + ): Eval[B] = typedExprA match { + case Generic(_, e) => + foldRight(e, lb)(f) + case Annotation(e, _) => + foldRight(e, lb)(f) + case AnnotatedLambda(_, e, tag) => + val lb1 = f(tag, lb) + foldRight(e, lb1)(f) + case n: Name[A] => f(n.tag, lb) + case App(fn, args, _, tag) => + val b1 = f(tag, lb) + val b2 = args.toList.foldRight(b1)((a, b1) => foldRight(a, b1)(f)) + foldRight(fn, b2)(f) + case Let(_, exp, in, _, tag) => + val b1 = f(tag, lb) + val b2 = foldRight(in, b1)(f) + foldRight(exp, b2)(f) + case Literal(_, _, tag) => + f(tag, lb) + case Match(arg, branches, tag) => + val b1 = f(tag, lb) + val b2 = branches.foldRight(b1) { case ((_, t), bn) => + foldRight(t, bn)(f) + } + foldRight(arg, b2)(f) + } - override def map[A, B](te: TypedExpr[A])(fn: A => B): TypedExpr[B] = te match { - case Generic(tv, in) => Generic(tv, map(in)(fn)) - case Annotation(term, tpe) => Annotation(map(term)(fn), tpe) - case AnnotatedLambda(args, expr, tag) => AnnotatedLambda(args, map(expr)(fn), fn(tag)) - case l@Local(_, _, _) => l.copy(tag = fn(l.tag)) - case g@Global(_, _, _, _) => g.copy(tag = fn(g.tag)) - case App(fnT, args, tpe, tag) => App(map(fnT)(fn), args.map(map(_)(fn)), tpe, fn(tag)) - case Let(b, e, in, r, t) => Let(b, map(e)(fn), map(in)(fn), r, fn(t)) - case lit@Literal(_, _, _) => lit.copy(tag = fn(lit.tag)) - case Match(arg, branches, tag) => - Match(map(arg)(fn), branches.map { case (p, t) => (p, map(t)(fn)) }, fn(tag)) + override def map[A, B](te: TypedExpr[A])(fn: A => B): TypedExpr[B] = + te match { + case Generic(tv, in) => Generic(tv, map(in)(fn)) + case Annotation(term, tpe) => Annotation(map(term)(fn), tpe) + case AnnotatedLambda(args, expr, tag) => + AnnotatedLambda(args, map(expr)(fn), fn(tag)) + case l @ Local(_, _, _) => l.copy(tag = fn(l.tag)) + case g @ Global(_, _, _, _) => g.copy(tag = fn(g.tag)) + case App(fnT, args, tpe, tag) => + App(map(fnT)(fn), args.map(map(_)(fn)), tpe, fn(tag)) + case Let(b, e, in, r, t) => Let(b, map(e)(fn), map(in)(fn), r, fn(t)) + case lit @ Literal(_, _, _) => lit.copy(tag = fn(lit.tag)) + case Match(arg, branches, tag) => + Match( + map(arg)(fn), + branches.map { case (p, t) => (p, map(t)(fn)) }, + fn(tag) + ) + } } - } type Coerce = FunctionK[TypedExpr, TypedExpr] - private def pushDownCovariant(tpe: Type.Quantified, kinds: Type => Option[Kind]): Type = + private def pushDownCovariant( + tpe: Type.Quantified, + kinds: Type => Option[Kind] + ): Type = tpe match { case Type.ForAll(targs, in) => val (cons, cargs) = Type.unapplyAll(in) @@ -723,55 +819,58 @@ object TypedExpr { // recursions) tpe case Some(kind) => - val kindArgs = kind.toArgs - val kindArgsWithArgs = kindArgs.zip(cargs).map { case (ka, a) => (Some(ka), a) } ::: - cargs.drop(kindArgs.length).map((None, _)) + val kindArgsWithArgs = + kindArgs.zip(cargs).map { case (ka, a) => (Some(ka), a) } ::: + cargs.drop(kindArgs.length).map((None, _)) - val argsVectorIdx = kindArgsWithArgs - .iterator - .zipWithIndex - .map { case ((optKA, tpe), idx) => + val argsVectorIdx = kindArgsWithArgs.iterator.zipWithIndex.map { + case ((optKA, tpe), idx) => (Type.freeBoundTyVars(tpe :: Nil).toSet, optKA, tpe, idx) - } - .toVector + }.toVector // if an arg is covariant, it can pull all it's unique freeVars def uniqueFreeVars(idx: Int): Set[Type.Var.Bound] = { val (justIdx, optKA, _, _) = argsVectorIdx(idx) if (optKA.exists(_.variance == Variance.co)) { - argsVectorIdx.iterator.filter(_._4 != idx) + argsVectorIdx.iterator + .filter(_._4 != idx) .foldLeft(justIdx) { case (acc, (s, _, _, _)) => acc -- s } - } - else Set.empty + } else Set.empty } - val withPulled = argsVectorIdx.map { case rec@(_, _, _, idx) => + val withPulled = argsVectorIdx.map { case rec @ (_, _, _, idx) => (rec, uniqueFreeVars(idx)) } val allPulled: Set[Type.Var.Bound] = withPulled.foldMap(_._2) val nonpulled = targs.filterNot { case (v, _) => allPulled(v) } - val pulledArgs = withPulled.iterator.map { case ((_, _, tpe, _), uniques) => - val keep: Type.Var.Bound => Boolean = uniques - Type.forAll(targs.filter { case (t, _) => keep(t) }, tpe) - } - .toList + val pulledArgs = withPulled.iterator.map { + case ((_, _, tpe, _), uniques) => + val keep: Type.Var.Bound => Boolean = uniques + Type.forAll(targs.filter { case (t, _) => keep(t) }, tpe) + }.toList Type.forAll(nonpulled, Type.applyAll(cons, pulledArgs)) - } - case notForAll => - // TODO: we can push down existentials too - notForAll - } + } + case notForAll => + // TODO: we can push down existentials too + notForAll + } // We know initTpe <:< instTpe, we may be able to simply // fix some of the universally quantified variables - def instantiateTo[A](gen: Generic[A], instTpe: Type.Rho, kinds: Type => Option[Kind]): TypedExpr[A] = { + def instantiateTo[A]( + gen: Generic[A], + instTpe: Type.Rho, + kinds: Type => Option[Kind] + ): TypedExpr[A] = { import Type._ - def solve(left: Type, - right: Type, - state: Map[Type.Var, Type], - solveSet: Set[Type.Var], - varKinds: Map[Type.Var, Kind]): Option[Map[Type.Var, Type]] = + def solve( + left: Type, + right: Type, + state: Map[Type.Var, Type], + solveSet: Set[Type.Var], + varKinds: Map[Type.Var, Kind] + ): Option[Map[Type.Var, Type]] = (left, right) match { case (TyVar(v), right) if solveSet(v) => state.get(v) match { @@ -786,18 +885,19 @@ object TypedExpr { else { val vlist = fa.vars.toList - solve(fa.in, + solve( + fa.in, r, state, solveSet -- vlist.iterator.map(_._1), - varKinds ++ vlist) + varKinds ++ vlist + ) } case (_, fa: Type.Quantified) => - val kindsWithVars: Type => Option[Kind] = - { - case v: Type.TyVar => varKinds.get(v.toVar) - case t => kinds(t) - } + val kindsWithVars: Type => Option[Kind] = { + case v: Type.TyVar => varKinds.get(v.toVar) + case t => kinds(t) + } val fa1 = pushDownCovariant(fa, kindsWithVars) if (fa1 != fa) solve(left, fa1, state, solveSet, varKinds) else { @@ -816,14 +916,13 @@ object TypedExpr { if (left == right) { // can't recurse further into left Some(state) - } - else None + } else None case (TyApply(_, _), _) => None } val (bs, in) = gen.quantType match { case Type.ForAll(a, in) => (a.toList, in) - case notForAll => (Nil, notForAll) + case notForAll => (Nil, notForAll) } val solveSet: Set[Var] = bs.iterator.map(_._1).toSet @@ -833,8 +932,12 @@ object TypedExpr { .map { subs => val freeVars = solveSet -- subs.keySet val subBody = substituteTypeVar(gen.in, subs) - val freeExists = gen.quantType.existList.filter { case (t, _) => freeVars(t) } - val freeForall = gen.quantType.forallList.filter { case (t, _) => freeVars(t) } + val freeExists = gen.quantType.existList.filter { case (t, _) => + freeVars(t) + } + val freeForall = gen.quantType.forallList.filter { case (t, _) => + freeVars(t) + } val q = Type.quantify( forallList = freeForall, existList = freeExists, @@ -845,23 +948,23 @@ object TypedExpr { case tq: Type.Quantified => val newGen = Generic(tq.quant, subBody) pushGeneric(newGen) match { - case badOpt @ (None | Some(Generic(_, _)))=> + case badOpt @ (None | Some(Generic(_, _))) => // just wrap ann(badOpt.getOrElse(newGen), instTpe) case Some(notGen) => notGen } - } + } } result match { case None => // TODO some of these just don't look fully unified yet, for instance: // could not solve instantiate: - // + // // forall b: *. Bosatsu/Predef::Order[b] -> forall a: *. Bosatsu/Predef::Dict[b, a] - // + // // to - // + // // Bosatsu/Predef::Order[?338] -> Bosatsu/Predef::Dict[$k$303, $v$304] // but those two types aren't the same. It seems like we have to later // learn that ?338 == $k$303, but we don't seem to know that yet @@ -873,17 +976,21 @@ object TypedExpr { } private def allPatternTypes[N](p: Pattern[N, Type]): SortedSet[Type] = - p.traverseType { t => Writer[SortedSet[Type], Type](SortedSet(t), t) }.run._1 + p.traverseType { t => Writer[SortedSet[Type], Type](SortedSet(t), t) } + .run + ._1 // Invariant, nel must have at least one item in common with quant.vars private def filterQuant( - nel: NonEmptyList[Type.Var], - quant: Type.Quantification + nel: NonEmptyList[Type.Var], + quant: Type.Quantification ): Type.Quantification = { val innerSet = nel.toList.toSet - Type.Quantification.fromLists( - forallList = quant.forallList.filter { case (v, _) => innerSet(v) }, - existList = quant.existList.filter { case (v, _) => innerSet(v) }) + Type.Quantification + .fromLists( + forallList = quant.forallList.filter { case (v, _) => innerSet(v) }, + existList = quant.existList.filter { case (v, _) => innerSet(v) } + ) // this get is safe because at least one var is present .get } @@ -892,7 +999,9 @@ object TypedExpr { g.in match { case AnnotatedLambda(args, body, a) => val argFree = Type.freeBoundTyVars(args.toList.map(_._2)).toSet - val (outer, inner) = g.quantType.vars.toList.partition { case (b, _) => argFree(b) } + val (outer, inner) = g.quantType.vars.toList.partition { case (b, _) => + argFree(b) + } NonEmptyList.fromList(inner).map { inner => // we know this has at least one item val inners = inner.map(_._1) @@ -901,21 +1010,22 @@ object TypedExpr { val pushedBody = pushGeneric(gbody).getOrElse(gbody) val lam = AnnotatedLambda(args, pushedBody, a) NonEmptyList.fromList(outer) match { - case None => lam + case None => lam case Some(outer) => forAll(outer, lam) } } // we can do the same thing on Match case Match(arg, branches, tag) => - val preTypes = branches.foldLeft(arg.allTypes) { case (ts, (p, _)) => ts | allPatternTypes(p) } + val preTypes = branches.foldLeft(arg.allTypes) { case (ts, (p, _)) => + ts | allPatternTypes(p) + } val argFree = Type.freeBoundTyVars(preTypes.toList).toSet if (g.quantType.vars.exists { case (b, _) => argFree(b) }) { None - } - else { + } else { // the only the branches have generics val b1 = branches.map { case (p, b) => - val gb = Generic(g.quant, b) + val gb = Generic(g.quant, b) val gb1 = pushGeneric(gb).getOrElse(gb) (p, gb1) } @@ -925,8 +1035,7 @@ object TypedExpr { val argFree = Type.freeBoundTyVars(v.getType :: Nil).toSet if (g.quantType.vars.exists { case (b, _) => argFree(b) }) { None - } - else { + } else { val gin = Generic(g.quant, in) val gin1 = pushGeneric(gin).getOrElse(gin) Some(Let(b, v, gin1, rec, tag)) @@ -943,7 +1052,7 @@ object TypedExpr { val cb = coerceRho(b, kinds) val cas = args.map { case aRho: Type.Rho => Some(coerceRho(aRho, kinds)) - case _ => None + case _ => None } coerceFn1(args, b, cas, cb, kinds) @@ -952,13 +1061,14 @@ object TypedExpr { def apply[A](expr: TypedExpr[A]) = expr match { case _ if expr.getType.sameAs(tpe) => expr - case Annotation(t, _) => self(t) - case Local(_, _, _) | Global(_, _, _, _) | AnnotatedLambda(_, _, _)| Literal(_, _, _) => + case Annotation(t, _) => self(t) + case Local(_, _, _) | Global(_, _, _, _) | + AnnotatedLambda(_, _, _) | Literal(_, _, _) => // All of these are widened. The lambda seems like we should be able to do // better, but the type isn't a Fun(Type, Type.Rho)... this is probably unreachable for // the AnnotatedLambda Annotation(expr, tpe) - case gen@Generic(_, _) => + case gen @ Generic(_, _) => pushGeneric(gen) match { case Some(e1) => self(e1) case None => @@ -967,7 +1077,7 @@ object TypedExpr { case App(fn, aargs, _, tag) => fn match { case AnnotatedLambda(lamArgs, body, _) => - //(\xs - res)(ys) == let x1 = y1 in let x2 = y2 in ... res + // (\xs - res)(ys) == let x1 = y1 in let x2 = y2 in ... res val binds = lamArgs.zip(aargs).map { case ((n, rho: Type.Rho), arg) => (n, coerceRho(rho, kinds)(arg)) @@ -984,7 +1094,13 @@ object TypedExpr { case (arg, nonRho) => (arg, nonRho, None) } - val fn1 = coerceFn1(cArgs.map(_._2), tpe, cArgs.map(_._3), self, kinds)(fn) + val fn1 = coerceFn1( + cArgs.map(_._2), + tpe, + cArgs.map(_._3), + self, + kinds + )(fn) App(fn1, cArgs.map(_._1), tpe, tag) case _ => // TODO, what should we do here? @@ -1002,14 +1118,17 @@ object TypedExpr { // TODO: this may be wrong. e.g. we could leaving meta in the types // embedded in patterns, this does not seem to happen since we would // error if metas escape typechecking - Match(arg, branches.map { case (p, expr) => (p, self(expr)) }, tag) + Match( + arg, + branches.map { case (p, expr) => (p, self(expr)) }, + tag + ) } } } - /** - * Return the list of the free vars - */ + /** Return the list of the free vars + */ def freeVars[A](ts: List[TypedExpr[A]]): List[Bindable] = freeVarsDup(ts).distinct @@ -1019,16 +1138,18 @@ object TypedExpr { private def freeVarsDup[A](ts: List[TypedExpr[A]]): List[Bindable] = ts.flatMap(_.freeVarsDup) - /** - * Try to substitute ex for ident in the expression: in - * - * This can fail if the free variables in ex are shadowed - * above ident in in. - * - * this code is very similar to Declaration.substitute - * if bugs are found in one, consult the other - */ - def substitute[A](ident: Bindable, ex: TypedExpr[A], in: TypedExpr[A]): Option[TypedExpr[A]] = { + /** Try to substitute ex for ident in the expression: in + * + * This can fail if the free variables in ex are shadowed above ident in in. + * + * this code is very similar to Declaration.substitute if bugs are found in + * one, consult the other + */ + def substitute[A]( + ident: Bindable, + ex: TypedExpr[A], + in: TypedExpr[A] + ): Option[TypedExpr[A]] = { // if we hit a shadow, we don't need to substitute down // that branch @inline def shadows(i: Bindable): Boolean = i === ident @@ -1040,7 +1161,7 @@ object TypedExpr { def loop(in: TypedExpr[A]): Option[TypedExpr[A]] = in match { - case Local(i, _, _) if i === ident => Some(ex) + case Local(i, _, _) if i === ident => Some(ex) case Global(_, _, _, _) | Local(_, _, _) | Literal(_, _, _) => Some(in) case Generic(a, expr) => loop(expr).map(Generic(a, _)) @@ -1052,20 +1173,19 @@ object TypedExpr { else loop(res).map(AnnotatedLambda(args, _, tag)) case App(fn, args, tpe, tag) => (loop(fn), args.traverse(loop(_))).mapN(App(_, _, tpe, tag)) - case let@Let(arg, argE, in, rec, tag) => + case let @ Let(arg, argE, in, rec, tag) => if (masks(arg)) None else if (shadows(arg)) { // recursive shadow blocks both argE and in if (rec.isRecursive) Some(let) else loop(argE).map(Let(arg, _, in, rec, tag)) - } - else { + } else { (loop(argE), loop(in)).mapN(Let(arg, _, _, rec, tag)) } case Match(arg, branches, tag) => // Maintain the order we encounter things: val arg1 = loop(arg) - val b1 = branches.traverse { case in@(p, b) => + val b1 = branches.traverse { case in @ (p, b) => // these are not free variables in this branch val ns = p.names if (ns.exists(masks)) None @@ -1078,73 +1198,81 @@ object TypedExpr { loop(in) } - def substituteTypeVar[A](typedExpr: TypedExpr[A], env: Map[Type.Var, Type]): TypedExpr[A] = + def substituteTypeVar[A]( + typedExpr: TypedExpr[A], + env: Map[Type.Var, Type] + ): TypedExpr[A] = if (env.isEmpty) typedExpr - else typedExpr match { - case Generic(quant, expr) => - // we need to remove the params which are shadowed below - val paramSet: Set[Type.Var] = quant.vars.toList.iterator.map(_._1).toSet - val env1 = env.iterator.filter { case (k, _) => !paramSet(k) }.toMap - Generic(quant, substituteTypeVar(expr, env1)) - case Annotation(of, tpe) => - Annotation( - substituteTypeVar(of, env), - Type.substituteVar(tpe, env)) - case AnnotatedLambda(args, res, tag) => - AnnotatedLambda( - args.map { case (n, tpe) => - (n, Type.substituteVar(tpe, env)) - }, - substituteTypeVar(res, env), - tag) - case Local(v, tpe, tag) => - Local(v, Type.substituteVar(tpe, env), tag) - case Global(p, v, tpe, tag) => - Global(p, v, Type.substituteVar(tpe, env), tag) - case App(f, args, tpe, tag) => - App( - substituteTypeVar(f, env), - args.map(substituteTypeVar(_, env)), - Type.substituteVar(tpe, env), - tag) - case Let(v, exp, in, rec, tag) => - Let( - v, - substituteTypeVar(exp, env), - substituteTypeVar(in, env), - rec, - tag) - case Literal(lit, tpe, tag) => - Literal(lit, Type.substituteVar(tpe, env), tag) - case Match(expr, branches, tag) => - val branches1 = branches.map { - case (p, t) => + else + typedExpr match { + case Generic(quant, expr) => + // we need to remove the params which are shadowed below + val paramSet: Set[Type.Var] = + quant.vars.toList.iterator.map(_._1).toSet + val env1 = env.iterator.filter { case (k, _) => !paramSet(k) }.toMap + Generic(quant, substituteTypeVar(expr, env1)) + case Annotation(of, tpe) => + Annotation(substituteTypeVar(of, env), Type.substituteVar(tpe, env)) + case AnnotatedLambda(args, res, tag) => + AnnotatedLambda( + args.map { case (n, tpe) => + (n, Type.substituteVar(tpe, env)) + }, + substituteTypeVar(res, env), + tag + ) + case Local(v, tpe, tag) => + Local(v, Type.substituteVar(tpe, env), tag) + case Global(p, v, tpe, tag) => + Global(p, v, Type.substituteVar(tpe, env), tag) + case App(f, args, tpe, tag) => + App( + substituteTypeVar(f, env), + args.map(substituteTypeVar(_, env)), + Type.substituteVar(tpe, env), + tag + ) + case Let(v, exp, in, rec, tag) => + Let( + v, + substituteTypeVar(exp, env), + substituteTypeVar(in, env), + rec, + tag + ) + case Literal(lit, tpe, tag) => + Literal(lit, Type.substituteVar(tpe, env), tag) + case Match(expr, branches, tag) => + val branches1 = branches.map { case (p, t) => val p1 = p.mapType(Type.substituteVar(_, env)) val t1 = substituteTypeVar(t, env) (p1, t1) - } - val expr1 = substituteTypeVar(expr, env) - Match(expr1, branches1, tag) - } + } + val expr1 = substituteTypeVar(expr, env) + Match(expr1, branches1, tag) + } - private def replaceVarType[A](te: TypedExpr[A], name: Bindable, tpe: Type): TypedExpr[A] = { + private def replaceVarType[A]( + te: TypedExpr[A], + name: Bindable, + tpe: Type + ): TypedExpr[A] = { def recur(t: TypedExpr[A]) = replaceVarType(t, name, tpe) te match { - case Generic(tv, in) => Generic(tv, recur(in)) - case Annotation(term, tpe) => Annotation(recur(term), tpe) + case Generic(tv, in) => Generic(tv, recur(in)) + case Annotation(term, tpe) => Annotation(recur(term), tpe) case AnnotatedLambda(args, expr, tag) => // this is a kind of let: if (args.exists(_._1 == name)) { // we are shadowing, so we are done: te - } - else { + } else { // no shadow AnnotatedLambda(args, recur(expr), tag) } case Local(nm, _, tag) if nm == name => Local(name, tpe, tag) - case n: Name[A] => n + case n: Name[A] => n case App(fnT, args, tpe, tag) => App(recur(fnT), args.map(recur), tpe, tag) case Let(b, e, in, r, t) => @@ -1158,9 +1286,8 @@ object TypedExpr { // but b does shadow inside `in` Let(b, recur(e), in, r, t) } - } - else Let(b, recur(e), recur(in), r, t) - case lit@Literal(_, _, _) => lit + } else Let(b, recur(e), recur(in), r, t) + case lit @ Literal(_, _, _) => lit case Match(arg, branches, tag) => Match(recur(arg), branches.map { case (p, t) => (p, recur(t)) }, tag) } @@ -1170,40 +1297,49 @@ object TypedExpr { if (te.getType.sameAs(tpe)) te else Annotation(te, tpe) - /** - * TODO this seems pretty expensive to blindly apply: we are deoptimizing - * the nodes pretty heavily - */ - def coerceFn(args: NonEmptyList[Type], result: Type.Rho, coarg: NonEmptyList[Coerce], cores: Coerce, kinds: Type => Option[Kind]): Coerce = + /** TODO this seems pretty expensive to blindly apply: we are deoptimizing the + * nodes pretty heavily + */ + def coerceFn( + args: NonEmptyList[Type], + result: Type.Rho, + coarg: NonEmptyList[Coerce], + cores: Coerce, + kinds: Type => Option[Kind] + ): Coerce = coerceFn1(args, result, coarg.map(Some(_)), cores, kinds) - private def coerceFn1(arg: NonEmptyList[Type], result: Type.Rho, coargOpt: NonEmptyList[Option[Coerce]], cores: Coerce, kinds: Type => Option[Kind]): Coerce = + private def coerceFn1( + arg: NonEmptyList[Type], + result: Type.Rho, + coargOpt: NonEmptyList[Option[Coerce]], + cores: Coerce, + kinds: Type => Option[Kind] + ): Coerce = new FunctionK[TypedExpr, TypedExpr] { self => val fntpe = Type.Fun(arg, result) def apply[A](expr: TypedExpr[A]) = { expr match { - case _ if expr.getType.sameAs(fntpe) => expr - case Annotation(t, _) => self(t) + case _ if expr.getType.sameAs(fntpe) => expr + case Annotation(t, _) => self(t) case AnnotatedLambda(args0, res, tag) => // note, Var(None, name, originalType, tag) // is hanging out in res, or it is unused - val args1 = args0.zip(arg).map { - case ((n, _), t) => (n, t) + val args1 = args0.zip(arg).map { case ((n, _), t) => + (n, t) } - val res1 = args1 - .toList - .foldRight(res) { - case ((name, arg), res) => - replaceVarType(res, name, arg) + val res1 = args1.toList + .foldRight(res) { case ((name, arg), res) => + replaceVarType(res, name, arg) } AnnotatedLambda(args1, cores(res1), tag) - case gen@Generic(_, _) => + case gen @ Generic(_, _) => pushGeneric(gen) match { case Some(e1) => self(e1) case None => instantiateTo(gen, fntpe, kinds) - } + } case Local(_, _, _) | Global(_, _, _, _) | Literal(_, _, _) => ann(expr, fntpe) case Let(arg, argE, in, rec, tag) => @@ -1214,21 +1350,23 @@ object TypedExpr { // error if metas escape typechecking Match(arg, branches.map { case (p, expr) => (p, self(expr)) }, tag) case App(AnnotatedLambda(lamArgs, body, _), aArgs, _, tag) => - //(\x - res)(y) == let x = y in res + // (\x - res)(y) == let x = y in res val arg1 = lamArgs.zip(aArgs).map { case ((n, rho: Type.Rho), arg) => (n, coerceRho(rho, kinds)(arg)) - case ((n, _), arg) => (n, arg) + case ((n, _), arg) => (n, arg) } letAllNonRec(arg1, self(body), tag) case App(_, _, _, _) => /* - * We have to be careful not to collide with the free vars in expr - * TODO: it is unclear why we are doing this... it may have just been - * a cute trick in the original rankn types paper, but I'm not - * sure what is buying us. - */ + * We have to be careful not to collide with the free vars in expr + * TODO: it is unclear why we are doing this... it may have just been + * a cute trick in the original rankn types paper, but I'm not + * sure what is buying us. + */ val free = freeVarsSet(expr :: Nil) - val nameGen = Type.allBinders.iterator.map { v => Identifier.Name(v.name) }.filterNot(free) + val nameGen = Type.allBinders.iterator + .map { v => Identifier.Name(v.name) } + .filterNot(free) val lamArgs = arg.map { t => (nameGen.next(), t) } val aArgs = lamArgs.map { case (n, t) => Local(n, t, expr.tag) } // name -> (expr((name: arg)): result) @@ -1238,14 +1376,21 @@ object TypedExpr { } } - def forAll[A](params: NonEmptyList[(Type.Var.Bound, Kind)], expr: TypedExpr[A]): TypedExpr[A] = + def forAll[A]( + params: NonEmptyList[(Type.Var.Bound, Kind)], + expr: TypedExpr[A] + ): TypedExpr[A] = quantVars(forallList = params.toList, Nil, expr) - def normalizeQuantVars[A](q: Type.Quantification, expr: TypedExpr[A]): TypedExpr[A] = + def normalizeQuantVars[A]( + q: Type.Quantification, + expr: TypedExpr[A] + ): TypedExpr[A] = expr match { case Generic(oldQuant, ex0) => normalizeQuantVars(q.concat(oldQuant), ex0) - case Annotation(term, tpe) if Type.quantify(q, tpe).sameAs(term.getType) => + case Annotation(term, tpe) + if Type.quantify(q, tpe).sameAs(term.getType) => // we not uncommonly add an annotation just to make a generic wrapper to get back where term case _ => @@ -1253,10 +1398,9 @@ object TypedExpr { // We cannot rebind to any used typed inside of expr, but we can reuse // any that are q val frees: Set[Type.Var.Bound] = - expr.freeTyVars.iterator.collect { - case b: Type.Var.Bound => b - } - .toSet + expr.freeTyVars.iterator.collect { case b: Type.Var.Bound => + b + }.toSet q.filter(frees) match { case None => expr @@ -1269,52 +1413,70 @@ object TypedExpr { q match { case ForAll(vars) => val fa1 = Type.alignBinders(vars, avoid) - val subs = fa1.iterator.collect { case ((b, _), b1) if b != b1 => - (b, Type.TyVar(b1)) - } - .toMap[Type.Var, Type] + val subs = fa1.iterator + .collect { + case ((b, _), b1) if b != b1 => + (b, Type.TyVar(b1)) + } + .toMap[Type.Var, Type] Generic( - ForAll(fa1.map { case ((_, k), b) => (b, k)}), - substituteTypeVar(expr, subs)) - case Exists(vars) => + ForAll(fa1.map { case ((_, k), b) => (b, k) }), + substituteTypeVar(expr, subs) + ) + case Exists(vars) => val ex1 = Type.alignBinders(vars, avoid) - val subs = ex1.iterator.collect { case ((b, _), b1) if b != b1 => - (b, Type.TyVar(b1)) - } - .toMap[Type.Var, Type] + val subs = ex1.iterator + .collect { + case ((b, _), b1) if b != b1 => + (b, Type.TyVar(b1)) + } + .toMap[Type.Var, Type] Generic( - Exists(ex1.map { case ((_, k), b) => (b, k)}), - substituteTypeVar(expr, subs)) - case Dual(foralls, exists) => + Exists(ex1.map { case ((_, k), b) => (b, k) }), + substituteTypeVar(expr, subs) + ) + case Dual(foralls, exists) => val fa1 = Type.alignBinders(foralls, avoid) - val ex1 = Type.alignBinders(exists, avoid ++ fa1.iterator.map(_._2)) - val subs = (fa1.iterator ++ ex1.iterator).collect { case ((b, _), b1) if b != b1 => - (b, Type.TyVar(b1)) - } - .toMap[Type.Var, Type] + val ex1 = + Type.alignBinders(exists, avoid ++ fa1.iterator.map(_._2)) + val subs = (fa1.iterator ++ ex1.iterator) + .collect { + case ((b, _), b1) if b != b1 => + (b, Type.TyVar(b1)) + } + .toMap[Type.Var, Type] Generic( Dual( - fa1.map { case ((_, k), b) => (b, k)}, - ex1.map { case ((_, k), b) => (b, k)} + fa1.map { case ((_, k), b) => (b, k) }, + ex1.map { case ((_, k), b) => (b, k) } ), - substituteTypeVar(expr, subs)) + substituteTypeVar(expr, subs) + ) } } } def quantVars[A]( - forallList: List[(Type.Var.Bound, Kind)], - existList: List[(Type.Var.Bound, Kind)], - expr: TypedExpr[A]): TypedExpr[A] = - Type.Quantification.fromLists(forallList = forallList, existList = existList) match { + forallList: List[(Type.Var.Bound, Kind)], + existList: List[(Type.Var.Bound, Kind)], + expr: TypedExpr[A] + ): TypedExpr[A] = + Type.Quantification.fromLists( + forallList = forallList, + existList = existList + ) match { case Some(q) => Generic(q, expr) - case None => expr + case None => expr } - private def lambda[A](args: NonEmptyList[(Bindable, Type)], expr: TypedExpr[A], tag: A): TypedExpr[A] = + private def lambda[A]( + args: NonEmptyList[(Bindable, Type)], + expr: TypedExpr[A], + tag: A + ): TypedExpr[A] = AnnotatedLambda(args, expr, tag) implicit def typedExprHasRegion[T: HasRegion]: HasRegion[TypedExpr[T]] = diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala index 4c2dd898a..8dc2f4eaa 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala @@ -11,41 +11,74 @@ import cats.syntax.all._ object TypedExprNormalization { import TypedExpr._ - type ScopeT[A, S] = Map[(Option[PackageName], Bindable), (RecursionKind, TypedExpr[A], S)] + type ScopeT[A, S] = + Map[(Option[PackageName], Bindable), (RecursionKind, TypedExpr[A], S)] type Scope[A] = FixType.Fix[ScopeT[A, *]] def emptyScope[A]: Scope[A] = FixType.fix[ScopeT[A, *]](Map.empty) implicit final class ScopeOps[A](private val scope: Scope[A]) extends AnyVal { - def updated(key: Bindable, value: (RecursionKind, TypedExpr[A], Scope[A])): Scope[A] = - FixType.fix[ScopeT[A, *]](FixType.unfix[ScopeT[A, *]](scope).updated((None, key), value)) - - def updatedGlobal(pack: PackageName, key: Bindable, value: (RecursionKind, TypedExpr[A], Scope[A])): Scope[A] = - FixType.fix[ScopeT[A, *]](FixType.unfix[ScopeT[A, *]](scope).updated((Some(pack), key), value)) + def updated( + key: Bindable, + value: (RecursionKind, TypedExpr[A], Scope[A]) + ): Scope[A] = + FixType.fix[ScopeT[A, *]]( + FixType.unfix[ScopeT[A, *]](scope).updated((None, key), value) + ) + + def updatedGlobal( + pack: PackageName, + key: Bindable, + value: (RecursionKind, TypedExpr[A], Scope[A]) + ): Scope[A] = + FixType.fix[ScopeT[A, *]]( + FixType.unfix[ScopeT[A, *]](scope).updated((Some(pack), key), value) + ) def -(key: Bindable): Scope[A] = - FixType.fix[ScopeT[A, *]](FixType.unfix[ScopeT[A, *]](scope) - (None -> key)) + FixType.fix[ScopeT[A, *]]( + FixType.unfix[ScopeT[A, *]](scope) - (None -> key) + ) def --(keys: Iterable[Bindable]): Scope[A] = keys.foldLeft(scope)(_ - _) - def getLocal(key: Bindable): Option[(RecursionKind, TypedExpr[A], Scope[A])] = + def getLocal( + key: Bindable + ): Option[(RecursionKind, TypedExpr[A], Scope[A])] = FixType.unfix[ScopeT[A, *]](scope).get((None, key)) - def getGlobal(pack: PackageName, n: Bindable): Option[(RecursionKind, TypedExpr[A], Scope[A])] = + def getGlobal( + pack: PackageName, + n: Bindable + ): Option[(RecursionKind, TypedExpr[A], Scope[A])] = FixType.unfix[ScopeT[A, *]](scope).get((Some(pack), n)) } - private def nameScope[A](b: Bindable, r: RecursionKind, scope: Scope[A]): (Option[Bindable], Scope[A]) = + private def nameScope[A]( + b: Bindable, + r: RecursionKind, + scope: Scope[A] + ): (Option[Bindable], Scope[A]) = if (r.isRecursive) (Some(b), scope - b) else (None, scope) - def normalizeAll[A, V](pack: PackageName, lets: List[(Bindable, RecursionKind, TypedExpr[A])], typeEnv: TypeEnv[V])(implicit ev: V <:< Kind.Arg): List[(Bindable, RecursionKind, TypedExpr[A])] = { + def normalizeAll[A, V]( + pack: PackageName, + lets: List[(Bindable, RecursionKind, TypedExpr[A])], + typeEnv: TypeEnv[V] + )(implicit + ev: V <:< Kind.Arg + ): List[(Bindable, RecursionKind, TypedExpr[A])] = { @annotation.tailrec - def loop(scope: Scope[A], lets: List[(Bindable, RecursionKind, TypedExpr[A])], acc: List[(Bindable, RecursionKind, TypedExpr[A])]): List[(Bindable, RecursionKind, TypedExpr[A])] = + def loop( + scope: Scope[A], + lets: List[(Bindable, RecursionKind, TypedExpr[A])], + acc: List[(Bindable, RecursionKind, TypedExpr[A])] + ): List[(Bindable, RecursionKind, TypedExpr[A])] = lets match { - case Nil => acc.reverse + case Nil => acc.reverse case (b, r, t) :: tail => // if we have a recursive value it shadows the scope val (optName, s0) = nameScope(b, r, scope) @@ -58,34 +91,49 @@ object TypedExprNormalization { } def normalizeProgram[A, V]( - p: PackageName, - fullTypeEnv: TypeEnv[V], - prog: Program[TypeEnv[V], TypedExpr[Declaration], A])(implicit ev: V <:< Kind.Arg): Program[TypeEnv[V], TypedExpr[Declaration], A] = { - val Program(typeEnv, lets, extDefs, stmts) = prog - val normalLets = normalizeAll(p, lets, fullTypeEnv) - Program(typeEnv, normalLets, extDefs, stmts) - } + p: PackageName, + fullTypeEnv: TypeEnv[V], + prog: Program[TypeEnv[V], TypedExpr[Declaration], A] + )(implicit + ev: V <:< Kind.Arg + ): Program[TypeEnv[V], TypedExpr[Declaration], A] = { + val Program(typeEnv, lets, extDefs, stmts) = prog + val normalLets = normalizeAll(p, lets, fullTypeEnv) + Program(typeEnv, normalLets, extDefs, stmts) + } // if you have made one step of progress, use this to recurse // so we don't throw away if we don't progress more - private def normalize1[A, V](namerec: Option[Bindable], te: TypedExpr[A], scope: Scope[A], typeEnv: TypeEnv[V])(implicit ev: V <:< Kind.Arg): Some[TypedExpr[A]] = + private def normalize1[A, V]( + namerec: Option[Bindable], + te: TypedExpr[A], + scope: Scope[A], + typeEnv: TypeEnv[V] + )(implicit ev: V <:< Kind.Arg): Some[TypedExpr[A]] = normalizeLetOpt(namerec, te, scope, typeEnv) match { - case None => Some(te) - case s@Some(_) => s + case None => Some(te) + case s @ Some(_) => s } private def setType[A](expr: TypedExpr[A], tpe: Type): TypedExpr[A] = if (!tpe.sameAs(expr.getType)) Annotation(expr, tpe) else expr - /** - * if the te is not in normal form, transform it into normal form - */ - private def normalizeLetOpt[A, V](namerec: Option[Bindable], te: TypedExpr[A], scope: Scope[A], typeEnv: TypeEnv[V])(implicit ev: V <:< Kind.Arg): Option[TypedExpr[A]] = { + /** if the te is not in normal form, transform it into normal form + */ + private def normalizeLetOpt[A, V]( + namerec: Option[Bindable], + te: TypedExpr[A], + scope: Scope[A], + typeEnv: TypeEnv[V] + )(implicit ev: V <:< Kind.Arg): Option[TypedExpr[A]] = { val kindOf: Type => Option[Kind] = - Type.kindOfOption { case const @ Type.TyConst(_) => typeEnv.getType(const).map(_.kindOf) } + Type.kindOfOption { case const @ Type.TyConst(_) => + typeEnv.getType(const).map(_.kindOf) + } te match { - case g@Generic(_, Annotation(term, _)) if g.getType.sameAs(term.getType) => + case g @ Generic(_, Annotation(term, _)) + if g.getType.sameAs(term.getType) => normalize1(namerec, term, scope, typeEnv) case Generic(q0, Generic(q1, in)) => val term = Generic(q0.concat(q1), in) @@ -103,7 +151,7 @@ object TypedExprNormalization { case _ if e1.getType.sameAs(tpe) => // the type is already right Some(e1) - case (gen@Generic(_, _), rho: Type.Rho) => + case (gen @ Generic(_, _), rho: Type.Rho) => val inst = TypedExpr.instantiateTo(gen, rho, kindOf) // we compare thes to te because instantiate // can add an Annotation back @@ -114,12 +162,12 @@ object TypedExprNormalization { if (notSameTpe eq term) { if (nt == tpe) None else Some(Annotation(term, nt)) - } - else Some(Annotation(notSameTpe, nt)) + } else Some(Annotation(notSameTpe, nt)) } case AnnotatedLambda(lamArgs0, expr, tag) => - lazy val anons: Iterator[Bindable] = Expr.nameIterator() + lazy val anons: Iterator[Bindable] = Expr + .nameIterator() .filterNot(expr.freeVarsDup.toSet) val bodyScope = scope -- lamArgs0.toList.map(_._1) @@ -133,88 +181,101 @@ object TypedExprNormalization { val next = anons.next() changed = changed || (next != n) next - } - else { + } else { n } (n1, Type.normalize(t)) } if (changed) { - normalize1(namerec, - AnnotatedLambda(lamArgs, e1, tag), - scope, - typeEnv) + normalize1(namerec, AnnotatedLambda(lamArgs, e1, tag), scope, typeEnv) - } - else { - - def doesntUseArgs(te: TypedExpr[A]): Boolean = - lamArgs.forall { case (n, _) => te.notFree(n) } - - // assuming b is bound below lamArgs, return true if it doesn't shadow an arg - def doesntShadow(b: Bindable): Boolean = - !lamArgs.exists { case (n, _) => n === b } - - def matchesArgs(nel: NonEmptyList[TypedExpr[A]]): Boolean = - (nel.length == lamArgs.length) && lamArgs.iterator.zip(nel.iterator).forall { - case ((lamN, _), Local(argN, _, _)) => lamN === argN - case _ => false - } + } else { - val ws = Impl.WithScope(scope, ev.substituteCo[TypeEnv](typeEnv)) - e1 match { - case App(fn, aargs, _, _) if matchesArgs(aargs) && doesntUseArgs(fn) => - // x -> f(x) == f (eta conversion) - normalize1(None, setType(fn, te.getType), scope, typeEnv) - case App(ws.ResolveToLambda(Nil, args1, body, ftag), aargs, resT, atag) if namerec.isEmpty => - // args -> (args1 -> e1)(...) - // this is inlining, which we do only when nested directly inside another lambda - // TODO: this is possibly very expensive to always apply. It can really increase - // code size. We probably need better hueristics for when to inline, - // or remove inlining from here unless it can hever hurt and put inlining at a - // different phase. - val fn1 = AnnotatedLambda(args1, body, ftag) - val e2 = App(fn1, aargs, resT, atag) - if (e1 != e2) { - // in this case we have inlined, vs there already being - // a literal lambda being applied - // by normalizing this, it will become a let binding - val e3 = normalize1(None, e2, bodyScope, typeEnv).get - - if (e3.size <= expr.size) { - // we haven't made the code larger - normalize1(namerec, - AnnotatedLambda(lamArgs, e3, tag), - scope, typeEnv) + def doesntUseArgs(te: TypedExpr[A]): Boolean = + lamArgs.forall { case (n, _) => te.notFree(n) } + + // assuming b is bound below lamArgs, return true if it doesn't shadow an arg + def doesntShadow(b: Bindable): Boolean = + !lamArgs.exists { case (n, _) => n === b } + + def matchesArgs(nel: NonEmptyList[TypedExpr[A]]): Boolean = + (nel.length == lamArgs.length) && lamArgs.iterator + .zip(nel.iterator) + .forall { + case ((lamN, _), Local(argN, _, _)) => lamN === argN + case _ => false } - else { - // inlining will make the code larger that it was originally + + val ws = Impl.WithScope(scope, ev.substituteCo[TypeEnv](typeEnv)) + e1 match { + case App(fn, aargs, _, _) + if matchesArgs(aargs) && doesntUseArgs(fn) => + // x -> f(x) == f (eta conversion) + normalize1(None, setType(fn, te.getType), scope, typeEnv) + case App( + ws.ResolveToLambda(Nil, args1, body, ftag), + aargs, + resT, + atag + ) if namerec.isEmpty => + // args -> (args1 -> e1)(...) + // this is inlining, which we do only when nested directly inside another lambda + // TODO: this is possibly very expensive to always apply. It can really increase + // code size. We probably need better hueristics for when to inline, + // or remove inlining from here unless it can hever hurt and put inlining at a + // different phase. + val fn1 = AnnotatedLambda(args1, body, ftag) + val e2 = App(fn1, aargs, resT, atag) + if (e1 != e2) { + // in this case we have inlined, vs there already being + // a literal lambda being applied + // by normalizing this, it will become a let binding + val e3 = normalize1(None, e2, bodyScope, typeEnv).get + + if (e3.size <= expr.size) { + // we haven't made the code larger + normalize1( + namerec, + AnnotatedLambda(lamArgs, e3, tag), + scope, + typeEnv + ) + } else { + // inlining will make the code larger that it was originally + if ((e1 eq expr) && (lamArgs === lamArgs0)) None + else Some(AnnotatedLambda(lamArgs, e1, tag)) + } + } else { if ((e1 eq expr) && (lamArgs === lamArgs0)) None else Some(AnnotatedLambda(lamArgs, e1, tag)) } - } - else { - if ((e1 eq expr) && (lamArgs === lamArgs0)) None - else Some(AnnotatedLambda(lamArgs, e1, tag)) - } - case Let(arg1, ex, in, rec, tag1) if doesntUseArgs(ex) && doesntShadow(arg1) => - // x -> - // y = z - // f(y) - //same as: - //y = z - //x -> f(y) - //avoid recomputing y - //TODO: we could reorder Lets if we have several in a row - normalize1(None, Let(arg1, ex, AnnotatedLambda(lamArgs, in, tag), rec, tag1), scope, typeEnv) - case m@Match(arg1, branches, tag1) if lamArgs.forall { case (arg, _) => arg1.notFree(arg) } => - // same as above: if match does not depend on lambda arg, lift it out + case Let(arg1, ex, in, rec, tag1) + if doesntUseArgs(ex) && doesntShadow(arg1) => + // x -> + // y = z + // f(y) + // same as: + // y = z + // x -> f(y) + // avoid recomputing y + // TODO: we could reorder Lets if we have several in a row + normalize1( + None, + Let(arg1, ex, AnnotatedLambda(lamArgs, in, tag), rec, tag1), + scope, + typeEnv + ) + case m @ Match(arg1, branches, tag1) if lamArgs.forall { + case (arg, _) => arg1.notFree(arg) + } => + // same as above: if match does not depend on lambda arg, lift it out val b1 = branches.traverse { case (p, b) => - if (!lamArgs.exists { case (arg, _) => p.names.contains(arg) }) { + if ( + !lamArgs.exists { case (arg, _) => p.names.contains(arg) } + ) { Some((p, AnnotatedLambda(lamArgs, b, tag))) - } - else None + } else None } b1 match { case None => @@ -224,11 +285,11 @@ object TypedExprNormalization { val m1 = Match(arg1, bs, tag1) normalize1(namerec, m1, scope, typeEnv) } - case notApp => - if ((notApp eq expr) && (lamArgs === lamArgs0)) None - else Some(AnnotatedLambda(lamArgs, notApp, tag)) + case notApp => + if ((notApp eq expr) && (lamArgs === lamArgs0)) None + else Some(AnnotatedLambda(lamArgs, notApp, tag)) + } } - } case Literal(_, _, _) => // these are fundamental None @@ -239,7 +300,8 @@ object TypedExprNormalization { case Global(p, n: Bindable, tpe0, tag) => scope.getGlobal(p, n).flatMap { - case (RecursionKind.NonRecursive, te, _) if Impl.isSimple(te, lambdaSimple = false) => + case (RecursionKind.NonRecursive, te, _) + if Impl.isSimple(te, lambdaSimple = false) => // inlining lambdas naively can cause an exponential blow up in size Some(te) case _ => @@ -267,15 +329,20 @@ object TypedExprNormalization { f1 match { case AnnotatedLambda(lamArgs, expr, _) => // (y -> z)(x) = let y = x in z - val lets = lamArgs.zip(args).map { - case ((n, ltpe), arg) => (n, setType(arg, ltpe)) + val lets = lamArgs.zip(args).map { case ((n, ltpe), arg) => + (n, setType(arg, ltpe)) } val expr2 = setType(expr, tpe) val l = TypedExpr.letAllNonRec(lets, expr2, tag) normalize1(namerec, l, scope, typeEnv) case Let(arg1, ex, in, rec, tag1) if a1.forall(_.notFree(arg1)) => - // (app (let x y z) w) == (let x y (app z w)) if w does not have x free - normalize1(namerec, Let(arg1, ex, App(in, a1, tpe, tag), rec, tag1), scope, typeEnv) + // (app (let x y z) w) == (let x y (app z w)) if w does not have x free + normalize1( + namerec, + Let(arg1, ex, App(in, a1, tpe, tag), rec, tag1), + scope, + typeEnv + ) case _ => if ((f1 eq fn) && (tpe == tpe0) && (a1 eq args)) None else Some(App(f1, a1, tpe, tag)) @@ -286,7 +353,8 @@ object TypedExprNormalization { val (ni, si) = nameScope(arg, rec, scope) val ex1 = normalize1(ni, ex, si, typeEnv).get ex1 match { - case Let(ex1a, ex1ex, ex1in, RecursionKind.NonRecursive, ex1tag) if !rec.isRecursive && in.notFree(ex1a) => + case Let(ex1a, ex1ex, ex1in, RecursionKind.NonRecursive, ex1tag) + if !rec.isRecursive && in.notFree(ex1a) => // according to a SPJ paper, it is generally better // to float lets out of nesting inside in: // let foo = let bar = x in bar in foo @@ -296,14 +364,23 @@ object TypedExprNormalization { // since you are going to evaluate and keep in scope // the expression // we can lift - val l1 = Let(ex1a, ex1ex, Let(arg, ex1in, in, RecursionKind.NonRecursive, tag), RecursionKind.NonRecursive, ex1tag) + val l1 = Let( + ex1a, + ex1ex, + Let(arg, ex1in, in, RecursionKind.NonRecursive, tag), + RecursionKind.NonRecursive, + ex1tag + ) normalize1(namerec, l1, scope, typeEnv) case _ => val scopeIn = si.updated(arg, (rec, ex1, si)) val in1 = normalize1(namerec, in, scopeIn, typeEnv).get in1 match { - case Match(marg, branches, mtag) if !rec.isRecursive && marg.notFree(arg) && branches.exists { case (p, r) => p.names.contains(arg) || r.notFree(arg) } => + case Match(marg, branches, mtag) + if !rec.isRecursive && marg.notFree(arg) && branches.exists { + case (p, r) => p.names.contains(arg) || r.notFree(arg) + } => // x = y // match z: // case w: ww @@ -327,7 +404,8 @@ object TypedExprNormalization { val shouldInline = (!rec.isRecursive) && { (cnt == 1) || Impl.isSimple(ex1, lambdaSimple = true) } - val inlined = if (shouldInline) substitute(arg, ex1, in1) else None + val inlined = + if (shouldInline) substitute(arg, ex1, in1) else None inlined match { case Some(il) => normalize1(namerec, il, scope, typeEnv) @@ -338,15 +416,15 @@ object TypedExprNormalization { normalize1(namerec, step, scope, typeEnv) } } - } - else { + } else { // let x = y in z if x isn't free in z = z Some(in1) } } } - case Match(_, NonEmptyList((p, e), Nil), _) if !e.freeVarsDup.exists(p.names.toSet) => + case Match(_, NonEmptyList((p, e), Nil), _) + if !e.freeVarsDup.exists(p.names.toSet) => // match x: // foo: fn // @@ -356,13 +434,20 @@ object TypedExprNormalization { // match x: // y: fn // let y = x in fn - normalize1(namerec, Let(y, arg, e, RecursionKind.NonRecursive, tag), scope, typeEnv) + normalize1( + namerec, + Let(y, arg, e, RecursionKind.NonRecursive, tag), + scope, + typeEnv + ) case Match(arg, branches, tag) => - - def ncount(shadows: Iterable[Bindable], e: TypedExpr[A]): (Int, TypedExpr[A]) = + def ncount( + shadows: Iterable[Bindable], + e: TypedExpr[A] + ): (Int, TypedExpr[A]) = // the final result of the branch is what is assigned to the name normalizeLetOpt(None, e, scope -- shadows, typeEnv) match { - case None => (0, e) + case None => (0, e) case Some(e) => (1, e) } // we can remove any bindings that aren't used in branches @@ -385,7 +470,10 @@ object TypedExprNormalization { case Pattern.WildCard => (changed0, branches1) case notWild if notWild.names.isEmpty => - val newb = branches1.init ::: ((Pattern.WildCard, branches1.last._2) :: Nil) + val newb = branches1.init ::: (( + Pattern.WildCard, + branches1.last._2 + ) :: Nil) // this newb list clearly has more than 0 elements (changed0 + 1, NonEmptyList.fromListUnsafe(newb)) case _ => @@ -409,8 +497,7 @@ object TypedExprNormalization { normalize1(namerec, m2, scope, typeEnv) case _ => None } - } - else { + } else { // there has been some change, so // see if that unlocked any new changes normalize1(namerec, Match(a1, branches1a, tag), scope, typeEnv) @@ -423,41 +510,64 @@ object TypedExprNormalization { private object Impl { - def scopeMatches[A](names: Set[Bindable], scope: Scope[A], scope1: Scope[A]): Boolean = + def scopeMatches[A]( + names: Set[Bindable], + scope: Scope[A], + scope1: Scope[A] + ): Boolean = names.forall { b => (scope.getLocal(b), scope1.getLocal(b)) match { case (None, None) => true case (Some((r1, t1, s1)), Some((r2, t2, s2))) => (r1 == r2) && - (t1.void == t2.void) && - scopeMatches(t1.freeVarsDup.toSet, s1, s2) + (t1.void == t2.void) && + scopeMatches(t1.freeVarsDup.toSet, s1, s2) case _ => false } } case class WithScope[A](scope: Scope[A], typeEnv: TypeEnv[Kind.Arg]) { private lazy val kindOf: Type => Option[Kind] = - Type.kindOfOption { case const @ Type.TyConst(_) => typeEnv.getType(const).map(_.kindOf) } + Type.kindOfOption { case const @ Type.TyConst(_) => + typeEnv.getType(const).map(_.kindOf) + } object ResolveToLambda { // TODO: don't we need to worry about the type environment for locals? They // can also capture type references to outer Generics - def unapply(te: TypedExpr[A]): Option[(List[(Type.Var.Bound, Kind)], NonEmptyList[(Bindable, Type)], TypedExpr[A], A)] = + def unapply(te: TypedExpr[A]): Option[ + ( + List[(Type.Var.Bound, Kind)], + NonEmptyList[(Bindable, Type)], + TypedExpr[A], + A + ) + ] = te match { - case Annotation(ResolveToLambda((h :: t), args, ex, tag), rho: Type.Rho) => + case Annotation( + ResolveToLambda((h :: t), args, ex, tag), + rho: Type.Rho + ) => val body = AnnotatedLambda(args, ex, tag) val quant = Type.Quantification.ForAll(NonEmptyList(h, t)) val asGen = Generic(quant, body) TypedExpr.instantiateTo(asGen, rho, kindOf) match { case AnnotatedLambda(a, e, t) => Some((Nil, a, e, t)) - case Generic(Type.Quantification.ForAll(nel), AnnotatedLambda(a, e, t)) => + case Generic( + Type.Quantification.ForAll(nel), + AnnotatedLambda(a, e, t) + ) => Some((nel.toList, a, e, t)) case _ => None } - case Generic(Type.Quantification.ForAll(frees), ResolveToLambda(f1, args, ex, tag)) => + case Generic( + Type.Quantification.ForAll(frees), + ResolveToLambda(f1, args, ex, tag) + ) => Some((frees.toList ::: f1, args, ex, tag)) - case AnnotatedLambda(args, expr, ltag) => Some((Nil, args, expr, ltag)) + case AnnotatedLambda(args, expr, ltag) => + Some((Nil, args, expr, ltag)) case Global(p, n: Bindable, _, _) => scope.getGlobal(p, n).flatMap { case (RecursionKind.NonRecursive, te, scope1) => @@ -467,10 +577,15 @@ object TypedExprNormalization { // we can't just replace variables if the scopes don't match. // we could also repair the scope by making a let binding // for any names that don't match (which has to be done recursively - if (scopeMatches(expr.freeVarsDup.toSet -- args.iterator.map(_._1), scope, scope1)) { + if ( + scopeMatches( + expr.freeVarsDup.toSet -- args.iterator.map(_._1), + scope, + scope1 + ) + ) { Some((frees, args, expr, ltag)) - } - else None + } else None case _ => None } case _ => None @@ -484,10 +599,15 @@ object TypedExprNormalization { // we can't just replace variables if the scopes don't match. // we could also repair the scope by making a let binding // for any names that don't match (which has to be done recursively - if (scopeMatches(expr.freeVarsDup.toSet -- args.iterator.map(_._1), scope, scope1)) { + if ( + scopeMatches( + expr.freeVarsDup.toSet -- args.iterator.map(_._1), + scope, + scope1 + ) + ) { Some((frees, args, expr, ltag)) - } - else None + } else None case _ => None } case _ => None @@ -501,8 +621,8 @@ object TypedExprNormalization { final def isSimple[A](ex: TypedExpr[A], lambdaSimple: Boolean): Boolean = ex match { case Literal(_, _, _) | Local(_, _, _) | Global(_, _, _, _) => true - case Annotation(t, _) => isSimple(t, lambdaSimple) - case Generic(_, t) => isSimple(t, lambdaSimple) + case Annotation(t, _) => isSimple(t, lambdaSimple) + case Generic(_, t) => isSimple(t, lambdaSimple) case AnnotatedLambda(_, _, _) => // maybe inline lambdas so we can possibly // apply (x -> f)(g) => let x = g in f @@ -512,15 +632,21 @@ object TypedExprNormalization { sealed abstract class EvalResult[A] object EvalResult { - case class Cons[A](pack: PackageName, cons: Constructor, args: List[TypedExpr[A]]) extends EvalResult[A] + case class Cons[A]( + pack: PackageName, + cons: Constructor, + args: List[TypedExpr[A]] + ) extends EvalResult[A] case class Constant[A](lit: Lit) extends EvalResult[A] } object FnArgs { - def unapply[A](te: TypedExpr[A]): Option[(TypedExpr[A], NonEmptyList[TypedExpr[A]])] = + def unapply[A]( + te: TypedExpr[A] + ): Option[(TypedExpr[A], NonEmptyList[TypedExpr[A]])] = te match { case App(fn, args, _, _) => Some((fn, args)) - case _ => None + case _ => None } } @@ -537,17 +663,24 @@ object TypedExprNormalization { case _ => None } case Let(arg, expr, in, RecursionKind.NonRecursive, _) => - evaluate(in, scope.updated(arg, (RecursionKind.NonRecursive, expr, scope))) + evaluate( + in, + scope.updated(arg, (RecursionKind.NonRecursive, expr, scope)) + ) case FnArgs(fn, args) => evaluate(fn, scope).map { - case EvalResult.Cons(p, c, ahead) => EvalResult.Cons(p, c, ahead ::: args.toList) + case EvalResult.Cons(p, c, ahead) => + EvalResult.Cons(p, c, ahead ::: args.toList) // $COVERAGE-OFF$ case EvalResult.Constant(c) => // this really shouldn't happen, - sys.error(s"unreachable: cannot apply a constant: $te => ${fn.repr} => $c") + sys.error( + s"unreachable: cannot apply a constant: $te => ${fn.repr} => $c" + ) // $COVERAGE-ON$ } - case Global(pack, cons: Constructor, _, _) => Some(EvalResult.Cons(pack, cons, Nil)) + case Global(pack, cons: Constructor, _, _) => + Some(EvalResult.Cons(pack, cons, Nil)) case Global(pack, n: Bindable, _, _) => scope.getGlobal(pack, n).flatMap { case (RecursionKind.NonRecursive, t, s) => @@ -568,25 +701,27 @@ object TypedExprNormalization { type Pat = Pattern[(PackageName, Constructor), Type] type Branch[A] = (Pat, TypedExpr[A]) - def maybeEvalMatch[A](m: Match[_ <: A], scope: Scope[A]): Option[TypedExpr[A]] = + def maybeEvalMatch[A]( + m: Match[_ <: A], + scope: Scope[A] + ): Option[TypedExpr[A]] = evaluate(m.arg, scope).flatMap { case EvalResult.Cons(p, c, args) => - val alen = args.length def isTotal(p: Pat): Boolean = p match { case Pattern.WildCard | Pattern.Var(_) => true - case Pattern.Named(_, p) => isTotal(p) - case Pattern.Annotation(p, _) => isTotal(p) + case Pattern.Named(_, p) => isTotal(p) + case Pattern.Annotation(p, _) => isTotal(p) case Pattern.Union(h, t) => isTotal(h) || t.exists(isTotal) - case _ => false + case _ => false } // The Option signals we can't complete def filterPat(pat: Pat): Option[Option[Pat]] = pat match { - case ps@Pattern.PositionalStruct((p0, c0), args0) => + case ps @ Pattern.PositionalStruct((p0, c0), args0) => if (p0 == p && c0 == c && args0.length == alen) Some(Some(ps)) else Some(None) // we definitely don't match this branch case Pattern.Named(n, p) => @@ -599,18 +734,18 @@ object TypedExprNormalization { case Pattern.Union(h, t) => (filterPat(h), t.traverse(filterPat)) .mapN { (optP1, p2s) => - val flatP2s: List[Pat] = p2s.toList.flatten - optP1 match { - case None => - flatP2s match { - case Nil => None - case h :: t => Some(Pattern.union(h, t)) - } - case Some(p1) => Some(Pattern.union(p1, flatP2s)) - } + val flatP2s: List[Pat] = p2s.toList.flatten + optP1 match { + case None => + flatP2s match { + case Nil => None + case h :: t => Some(Pattern.union(h, t)) + } + case Some(p1) => Some(Pattern.union(p1, flatP2s)) + } } case Pattern.WildCard | Pattern.Var(_) => Some(Some(pat)) - case Pattern.ListPat(_) => + case Pattern.ListPat(_) => // TODO some of these patterns we could evaluate None case _ => None @@ -624,7 +759,8 @@ object TypedExprNormalization { case Pattern.PositionalStruct(_, pats) => Some((Nil, pats)) case Pattern.WildCard => Some((Nil, args.as(Pattern.WildCard))) - case Pattern.Var(n) => Some((n :: Nil, args.as(Pattern.WildCard))) + case Pattern.Var(n) => + Some((n :: Nil, args.as(Pattern.WildCard))) case _ => None } @@ -635,18 +771,28 @@ object TypedExprNormalization { // if we can check all the branches for a match, maybe we can evaluate .flatMap { branches => val candidates: List[(Pat, TypedExpr[A])] = - branches.collect { case (Some(p), r) => (p, r)} + branches.collect { case (Some(p), r) => (p, r) } candidates match { - // $COVERAGE-OFF$ + // $COVERAGE-OFF$ case Nil => // TODO hitting this looks like a bug - sys.error(s"no branch matched in ${m.repr} matched: $p::$c(${args.map(_.repr)})") - // $COVERAGE-ON$ - case (MaybeNamedStruct(b, pats), r) :: rest if rest.isEmpty || pats.forall(isTotal) => + sys.error( + s"no branch matched in ${m.repr} matched: $p::$c(${args.map(_.repr)})" + ) + // $COVERAGE-ON$ + case (MaybeNamedStruct(b, pats), r) :: rest + if rest.isEmpty || pats.forall(isTotal) => // If there are no more items, or all inner patterns are total, we are done // exactly one matches, this can be a sequential match - def matchAll(argPat: List[(TypedExpr[A], Pattern[(PackageName, Constructor), Type])]): TypedExpr[A] = + def matchAll( + argPat: List[ + ( + TypedExpr[A], + Pattern[(PackageName, Constructor), Type] + ) + ] + ): TypedExpr[A] = argPat match { case Nil => r case (a, p) :: tail => @@ -664,7 +810,11 @@ object TypedExprNormalization { } val res = matchAll(args.zip(pats)) - Some(b.foldRight(res)(Let(_, m.arg, _, RecursionKind.NonRecursive, m.tag))) + Some( + b.foldRight(res)( + Let(_, m.arg, _, RecursionKind.NonRecursive, m.tag) + ) + ) case h :: t => // more than one branch might match, wait till runtime val m1 = Match(m.arg, NonEmptyList(h, t), m.tag) @@ -674,12 +824,14 @@ object TypedExprNormalization { } case EvalResult.Constant(li @ Lit.Integer(i)) => - def makeLet(p: Pattern[(PackageName, Constructor), Type]): Option[List[Bindable]] = + def makeLet( + p: Pattern[(PackageName, Constructor), Type] + ): Option[List[Bindable]] = p match { case Pattern.Named(v, p) => makeLet(p).map(v :: _) - case Pattern.WildCard => Some(Nil) - case Pattern.Var(v) => Some(v :: Nil) + case Pattern.WildCard => Some(Nil) + case Pattern.Var(v) => Some(v :: Nil) case Pattern.Annotation(p, _) => makeLet(p) case Pattern.Literal(Lit.Integer(j)) => if (j == i) Some(Nil) @@ -687,19 +839,23 @@ object TypedExprNormalization { case Pattern.Union(h, t) => (h :: t).toList.iterator.map(makeLet).reduce(_.orElse(_)) // $COVERAGE-OFF$ this is ill-typed so should be unreachable - case Pattern.PositionalStruct(_, _) | Pattern.ListPat(_) | Pattern.StrPat(_) | Pattern.Literal(Lit.Str(_) | Lit.Chr(_)) => None + case Pattern.PositionalStruct(_, _) | Pattern.ListPat(_) | + Pattern.StrPat(_) | + Pattern.Literal(Lit.Str(_) | Lit.Chr(_)) => + None // $COVERAGE-ON$ } Foldable[NonEmptyList] - .collectFirstSome[Branch[A], TypedExpr[A]](m.branches) { case (p, r) => - makeLet(p).map { names => - val lit = Literal[A](li, Type.getTypeOf(li), m.tag) - // all these names are bound to the lit - names.distinct.foldLeft(r) { case (r, n) => - Let(n, lit, r, RecursionKind.NonRecursive, m.tag) + .collectFirstSome[Branch[A], TypedExpr[A]](m.branches) { + case (p, r) => + makeLet(p).map { names => + val lit = Literal[A](li, Type.getTypeOf(li), m.tag) + // all these names are bound to the lit + names.distinct.foldLeft(r) { case (r, n) => + Let(n, lit, r, RecursionKind.NonRecursive, m.tag) + } } - } } case EvalResult.Constant(Lit.Str(_) | Lit.Chr(_)) => // TODO, we can match some of these statically diff --git a/core/src/main/scala/org/bykn/bosatsu/UnusedLetCheck.scala b/core/src/main/scala/org/bykn/bosatsu/UnusedLetCheck.scala index 94ab40499..2d2342cd5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/UnusedLetCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/UnusedLetCheck.scala @@ -1,7 +1,14 @@ package org.bykn.bosatsu import cats.Applicative -import cats.data.{Chain, NonEmptyList, Validated, ValidatedNec, Writer, NonEmptyChain} +import cats.data.{ + Chain, + NonEmptyList, + Validated, + ValidatedNec, + Writer, + NonEmptyChain +} import cats.implicits._ import Expr._ @@ -10,9 +17,14 @@ import Identifier.Bindable object UnusedLetCheck { private[this] val ap = Applicative[Writer[Chain[(Bindable, Region)], *]] - private[this] val empty: Writer[Chain[(Bindable, Region)], Set[Bindable]] = ap.pure(Set.empty) + private[this] val empty: Writer[Chain[(Bindable, Region)], Set[Bindable]] = + ap.pure(Set.empty) - private[this] def checkArg(arg: Bindable, reg: => Region, w: Writer[Chain[(Bindable, Region)], Set[Bindable]]) = + private[this] def checkArg( + arg: Bindable, + reg: => Region, + w: Writer[Chain[(Bindable, Region)], Set[Bindable]] + ) = w.flatMap { free => if (free(arg)) ap.pure(free - arg) else { @@ -21,14 +33,16 @@ object UnusedLetCheck { } } - private[this] def loop[A: HasRegion](e: Expr[A]): Writer[Chain[(Bindable, Region)], Set[Bindable]] = + private[this] def loop[A: HasRegion]( + e: Expr[A] + ): Writer[Chain[(Bindable, Region)], Set[Bindable]] = e match { case Annotation(expr, _, _) => loop(expr) case Generic(_, in) => loop(in) case Lambda(args, expr, _) => args.toList.foldRight(loop(expr)) { (arg, res) => - checkArg(arg._1, HasRegion.region(e), res) + checkArg(arg._1, HasRegion.region(e), res) } case Let(arg, expr, in, rec, _) => val exprCheck = loop(expr) @@ -38,13 +52,15 @@ object UnusedLetCheck { if (rec.isRecursive) exprCheck.map(_ - arg) else exprCheck // the region of the let isn't directly tracked, but // it would start with the whole region starts and end at expr - val inCheck = checkArg(arg, - { + val inCheck = checkArg( + arg, { val wholeRegion = HasRegion.region(e) val endRegion = HasRegion.region(expr) val bindRegion = wholeRegion.copy(end = endRegion.end) bindRegion - }, loop(in)) + }, + loop(in) + ) (exprRes, inCheck).mapN(_ ++ _) case Local(name, _) => // this is a free variable: @@ -57,40 +73,47 @@ object UnusedLetCheck { // TODO: patterns need their own region val branchRegions = NonEmptyList.fromListUnsafe( - branches.toList.scanLeft((HasRegion.region(arg), Option.empty[Region])) { case ((prev, _), (_, caseExpr)) => - // between the previous expression and the case is the pattern - (HasRegion.region(caseExpr), Some(Region(prev.end, HasRegion.region(caseExpr).start))) - } - .collect { case (_, Some(r)) => r } + branches.toList + .scanLeft((HasRegion.region(arg), Option.empty[Region])) { + case ((prev, _), (_, caseExpr)) => + // between the previous expression and the case is the pattern + ( + HasRegion.region(caseExpr), + Some(Region(prev.end, HasRegion.region(caseExpr).start)) + ) + } + .collect { case (_, Some(r)) => r } ) - val bcheck = branchRegions.zip(branches).traverse { case (region, (pat, expr)) => - loop(expr).flatMap { frees => - val thisPatNames = pat.names - val unused = thisPatNames.filterNot(frees) - val nextFrees = frees -- thisPatNames + val bcheck = branchRegions + .zip(branches) + .traverse { case (region, (pat, expr)) => + loop(expr).flatMap { frees => + val thisPatNames = pat.names + val unused = thisPatNames.filterNot(frees) + val nextFrees = frees -- thisPatNames - ap.pure(nextFrees).tell(Chain.fromSeq(unused.map((_, region)))) + ap.pure(nextFrees).tell(Chain.fromSeq(unused.map((_, region)))) + } } - } - .map(_.combineAll) + .map(_.combineAll) (argCheck, bcheck).mapN(_ ++ _) } - /** - * Check for any unused lets, defs, or pattern bindings - */ - def check[A: HasRegion](e: Expr[A]): ValidatedNec[(Bindable, Region), Unit] = { + /** Check for any unused lets, defs, or pattern bindings + */ + def check[A: HasRegion]( + e: Expr[A] + ): ValidatedNec[(Bindable, Region), Unit] = { val (chain, _) = loop(e).run NonEmptyChain.fromChain(chain) match { - case None => Validated.valid(()) + case None => Validated.valid(()) case Some(nec) => Validated.invalid(nec.distinct) } } - /** - * Return the free Bindable names in this expression - */ + /** Return the free Bindable names in this expression + */ def freeBound[A](e: Expr[A]): Set[Bindable] = loop(e)(HasRegion.instance(_ => Region(0, 0))).run._2 } diff --git a/core/src/main/scala/org/bykn/bosatsu/Value.scala b/core/src/main/scala/org/bykn/bosatsu/Value.scala index 989f1607e..b0c0199af 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Value.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Value.scala @@ -4,50 +4,47 @@ import cats.data.NonEmptyList import java.math.BigInteger import scala.collection.immutable.SortedMap -/** - * If we later determine that this performance matters - * and this wrapping is hurting, we could replace - * Value with a less structured type and put - * all the reflection into unapply calls but keep - * most of the API - */ +/** If we later determine that this performance matters and this wrapping is + * hurting, we could replace Value with a less structured type and put all the + * reflection into unapply calls but keep most of the API + */ sealed abstract class Value { import Value._ def asFn: NonEmptyList[Value] => Value = this match { case FnValue(f) => f - case other => + case other => // $COVERAGE-OFF$this should be unreachable sys.error(s"invalid cast to Fn: $other") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } def asSum: SumValue = this match { case s: SumValue => s - case _ => + case _ => // $COVERAGE-OFF$this should be unreachable sys.error(s"invalid cast to SumValue: $this") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - def asProduct:ProductValue = + def asProduct: ProductValue = this match { case p: ProductValue => p - case _ => + case _ => // $COVERAGE-OFF$this should be unreachable sys.error(s"invalid cast to ProductValue: $this") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } def asExternal: ExternalValue = this match { case ex: ExternalValue => ex - case _ => + case _ => // $COVERAGE-OFF$this should be unreachable sys.error(s"invalid cast to ExternalValue: $this") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } final def applyAll(args: NonEmptyList[Value]): Value = @@ -56,7 +53,8 @@ sealed abstract class Value { object Value { final class ProductValue(val values: Array[Value]) extends Value { - override lazy val hashCode = scala.util.hashing.MurmurHash3.arrayHash(values) + override lazy val hashCode = + scala.util.hashing.MurmurHash3.arrayHash(values) final def get(idx: Int): Value = values(idx) override def equals(obj: Any): Boolean = @@ -65,7 +63,8 @@ object Value { (this eq thatP) || java.util.Arrays.equals( values.asInstanceOf[Array[AnyRef]], - thatP.values.asInstanceOf[Array[AnyRef]]) + thatP.values.asInstanceOf[Array[AnyRef]] + ) case _ => false } override def toString: String = values.mkString("ProductValue(", ",", ")") @@ -84,14 +83,16 @@ object Value { def unapplySeq(v: Value): Option[Seq[Value]] = v match { case p: ProductValue => Some(p.values.toSeq) - case _ => None + case _ => None } } - final class SumValue(val variant: Int, val value: ProductValue) extends Value { + final class SumValue(val variant: Int, val value: ProductValue) + extends Value { override def equals(that: Any) = that match { - case s: SumValue => (s eq this) || ((variant == s.variant) && (value == s.value)) + case s: SumValue => + (s eq this) || ((variant == s.variant) && (value == s.value)) case _ => false } override def hashCode: Int = @@ -107,7 +108,8 @@ object Value { (0 until constCount).map(new SumValue(_, UnitValue)).toArray def apply(variant: Int, value: ProductValue): SumValue = - if ((value == UnitValue) && ((variant & sizeMask) == 0)) constants(variant) + if ((value == UnitValue) && ((variant & sizeMask) == 0)) + constants(variant) else new SumValue(variant, value) } @@ -122,11 +124,12 @@ object Value { case class SimpleFnValue(toFn: NonEmptyList[Value] => Value) extends Arg - def apply(toFn: NonEmptyList[Value] => Value): FnValue = new FnValue(SimpleFnValue(toFn)) - def unapply(fnValue: FnValue): Some[NonEmptyList[Value] => Value] = Some(fnValue.arg.toFn) + def unapply(fnValue: FnValue): Some[NonEmptyList[Value] => Value] = Some( + fnValue.arg.toFn + ) val identity: FnValue = FnValue(vs => vs.head) } @@ -140,7 +143,7 @@ object Value { def unapply(v: Value): Option[List[Value]] = v match { case p: ProductValue => Some(p.values.toList) - case _ => None + case _ => None } def fromList(vs: List[Value]): ProductValue = @@ -162,7 +165,7 @@ object Value { def fromLit(l: Lit): Value = l match { - case Lit.Str(s) => ExternalValue(s) + case Lit.Str(s) => ExternalValue(s) case Lit.Integer(i) => ExternalValue(i) case c @ Lit.Chr(_) => ExternalValue(c.asStr) } @@ -173,7 +176,7 @@ object Value { def unapply(v: Value): Option[BigInteger] = v match { case ExternalValue(v: BigInteger) => Some(v) - case _ => None + case _ => None } } @@ -182,7 +185,7 @@ object Value { def unapply(v: Value): Option[String] = v match { case ExternalValue(str: String) => Some(str) - case _ => None + case _ => None } } @@ -199,10 +202,9 @@ object Value { else if ((s.variant == 1)) { s.value.values match { case Array(head) => Some(Some(head)) - case _ => None + case _ => None } - } - else None + } else None case _ => None } } @@ -219,10 +221,9 @@ object Value { if (s.variant == 1) { s.value.values match { case Array(head, rest) => Some((head, rest)) - case _ => None + case _ => None } - } - else None + } else None case _ => None } } @@ -231,7 +232,7 @@ object Value { @annotation.tailrec def go(vs: List[Value], acc: Value): Value = vs match { - case Nil => acc + case Nil => acc case h :: tail => go(tail, Cons(h, acc)) } go(items.reverse, VNil) @@ -256,25 +257,28 @@ object Value { fn.applyAll( NonEmptyList( new ProductValue(Array(v1, null)), - new ProductValue(Array(v2, null)) :: Nil) - ) - .asSum - .variant + new ProductValue(Array(v2, null)) :: Nil + ) + ).asSum + .variant // v = 0, 1, 2 for LT, EQ, GT v - 1 } } - //enum Tree: Empty, Branch(size: Int, height: Int, key: a, left: Tree[a], right: Tree[a]) - //struct Dict[k, v](ord: Order[(k, v)], tree: Tree[(k, v)]) + // enum Tree: Empty, Branch(size: Int, height: Int, key: a, left: Tree[a], right: Tree[a]) + // struct Dict[k, v](ord: Order[(k, v)], tree: Tree[(k, v)]) def unapply(v: Value): Option[SortedMap[Value, Value]] = v match { case ProductValue(ordFn: FnValue, tree) => implicit val ord: Ordering[Value] = keyOrderingFromOrdFn(ordFn) - def treeToList(t: Value, acc: SortedMap[Value, Value]): SortedMap[Value, Value] = { + def treeToList( + t: Value, + acc: SortedMap[Value, Value] + ): SortedMap[Value, Value] = { val v = t.asSum if (v.variant == 0) acc // empty else { @@ -286,7 +290,7 @@ object Value { case other => // $COVERAGE-OFF$ sys.error(s"ill-shaped: $other") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } } } @@ -299,21 +303,24 @@ object Value { } val strOrdFn: FnValue = - FnValue { - case NonEmptyList(tup1, tup2 :: Nil) => - (tup1, tup2) match { - case (Tuple(ExternalValue(k1: String) :: _), Tuple(ExternalValue(k2: String) :: _)) => - Comparison.fromInt(k1.compareTo(k2)) - case _ => - // $COVERAGE-OFF$ - sys.error(s"ill-typed in String Dict order: $tup1, $tup2") - // $COVERAGE-ON$ - } + FnValue { case NonEmptyList(tup1, tup2 :: Nil) => + (tup1, tup2) match { + case ( + Tuple(ExternalValue(k1: String) :: _), + Tuple(ExternalValue(k2: String) :: _) + ) => + Comparison.fromInt(k1.compareTo(k2)) + case _ => + // $COVERAGE-OFF$ + sys.error(s"ill-typed in String Dict order: $tup1, $tup2") + // $COVERAGE-ON$ } + } def fromStringKeys(kvs: List[(String, Value)]): Value = { val allItems: Array[(String, Value)] = kvs.toMap.toArray - java.util.Arrays.sort(allItems, Ordering[String].on { (kv: (String, Value)) => kv._1 }) + java.util.Arrays + .sort(allItems, Ordering[String].on { (kv: (String, Value)) => kv._1 }) val empty = (BigInteger.ZERO, BigInteger.ZERO, SumValue(0, UnitValue)) @@ -326,15 +333,22 @@ object Value { val (rh, rz, right) = makeTree(mid + 1, end) val h = lh.max(rh).add(BigInteger.ONE) val z = lz.add(rz).add(BigInteger.ONE) - (h, z, SumValue(1, - new ProductValue( - Array( - ExternalValue(z), - ExternalValue(h), - new ProductValue(Array(ExternalValue(k), v)), - left, - right) - ))) + ( + h, + z, + SumValue( + 1, + new ProductValue( + Array( + ExternalValue(z), + ExternalValue(h), + new ProductValue(Array(ExternalValue(k), v)), + left, + right + ) + ) + ) + ) } val (_, _, tree) = makeTree(0, allItems.length) diff --git a/core/src/main/scala/org/bykn/bosatsu/ValueToDoc.scala b/core/src/main/scala/org/bykn/bosatsu/ValueToDoc.scala index ca88ea430..b66d60d40 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ValueToDoc.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ValueToDoc.scala @@ -14,15 +14,14 @@ import JsonEncodingError.IllTyped case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { - /** - * Convert a typechecked value to a Document representation - * - * Note, we statically build the conversion function if it is possible at - * all, after that, only value errors can occur - * - * this code ASSUMES the type is correct. If not, we may return - * incorrect data if it is not clearly illtyped - */ + /** Convert a typechecked value to a Document representation + * + * Note, we statically build the conversion function if it is possible at + * all, after that, only value errors can occur + * + * this code ASSUMES the type is correct. If not, we may return incorrect + * data if it is not clearly illtyped + */ def toDoc(tpe: Type): Value => Either[IllTyped, Doc] = { type Fn = Value => Either[IllTyped, Doc] @@ -39,33 +38,31 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Some(fn) => fn case None => val res: Eval[Fn] = Eval.later(tpe match { - case Type.IntType => - { - case ExternalValue(v: BigInteger) => - Right(Doc.str(v)) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } - case Type.StrType => - { - case ExternalValue(v: String) => - Right(Document[Lit].document(Lit.Str(v))) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } + case Type.IntType => { + case ExternalValue(v: BigInteger) => + Right(Doc.str(v)) + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTyped(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ + } + case Type.StrType => { + case ExternalValue(v: String) => + Right(Document[Lit].document(Lit.Str(v))) + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTyped(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ + } case Type.UnitType => // encode this as null { case UnitValue => Right(Doc.text("()")) - case other => + case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } + // $COVERAGE-ON$ + } case Type.ListT(t1) => lazy val inner = loop(t1, tpe :: revPath).value @@ -73,12 +70,14 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case VList(vs) => vs.traverse(inner) .map { inners => - Doc.char('[') + (Doc.lineOrEmpty + commaBlock(inners) + Doc.lineOrEmpty).aligned + Doc.char(']') + Doc.char('[') + (Doc.lineOrEmpty + commaBlock( + inners + ) + Doc.lineOrEmpty).aligned + Doc.char(']') } case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.DictT(Type.StrType, vt) => lazy val inner = loop(vt, tpe :: revPath).value @@ -86,25 +85,30 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { { case VDict(d) => - d.toList.traverse { case (k, v) => - k match { - case Str(kstr) => - inner(v).map { vdoc => - (docStr.document(Lit.Str(kstr)) + (Doc.char(':') + Doc.line + vdoc).nested(4)).grouped - } - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTyped(revPath.reverse, tpe, other)) + d.toList + .traverse { case (k, v) => + k match { + case Str(kstr) => + inner(v).map { vdoc => + (docStr.document(Lit.Str(kstr)) + (Doc.char( + ':' + ) + Doc.line + vdoc).nested(4)).grouped + } + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTyped(revPath.reverse, tpe, other)) // $COVERAGE-ON$ + } + } + .map { kvs => + Doc.char('{') + (Doc.lineOrEmpty + commaBlock( + kvs + ) + Doc.lineOrEmpty).aligned + Doc.char('}') } - } - .map { kvs => - Doc.char('{') + (Doc.lineOrEmpty + commaBlock(kvs) + Doc.lineOrEmpty).aligned + Doc.char('}') - } case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.Tuple(ts) => val p1 = tpe :: revPath @@ -117,12 +121,13 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { .toVector .traverse { case (a, fn) => fn(a) } .map { items => - Doc.char('(') + (Doc.lineOrEmpty + commaBlock(items) + Doc.char(',') + Doc.lineOrEmpty).aligned + Doc.char(')') + Doc.char('(') + (Doc.lineOrEmpty + commaBlock(items) + Doc + .char(',') + Doc.lineOrEmpty).aligned + Doc.char(')') } case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.ForAll(_, inner) => @@ -131,7 +136,7 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Type.TyVar(_) => // we don't really know what to do with { _ => Right(Doc.text("")) } - case fn@Type.Fun(_, _) => + case fn @ Type.Fun(_, _) => def arity(fn: Type): Int = fn match { case Type.Fun(_, dest) => @@ -147,7 +152,7 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case _ => // We can have complicated recursion here, we @@ -168,9 +173,11 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Some(dt) => val cons = dt.constructors val (_, targs) = Type.unapplyAll(tpe) - val replaceMap = dt.typeParams.zip(targs).toMap[Type.Var, Type] + val replaceMap = + dt.typeParams.zip(targs).toMap[Type.Var, Type] - lazy val resInner: Map[Int, (Constructor, List[(String, Fn)])] = + lazy val resInner + : Map[Int, (Constructor, List[(String, Fn)])] = cons.zipWithIndex .traverse { case (cf, idx) => val rec = cf.args.traverse { case (field, t) => @@ -183,7 +190,11 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { .map(_.toMap) .value - def params(variant: Int, params: List[Value], src: Value): Either[IllTyped, Doc] = + def params( + variant: Int, + params: List[Value], + src: Value + ): Either[IllTyped, Doc] = resInner.get(variant) match { case None => Left(IllTyped(revPath.reverse, tpe, src)) @@ -193,42 +204,43 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { .zip(fields) .traverse { case (v, (nm, fn)) => fn(v).map { vdoc => - (Doc.text(nm) + Doc.char(':') + Doc.lineOrSpace + vdoc).nested(4) + (Doc.text(nm) + Doc.char( + ':' + ) + Doc.lineOrSpace + vdoc).nested(4) } } .map { paramsDoc => val nm = Doc.text(name.asString) if (paramsDoc.isEmpty) nm else { - nm + Doc.space + - (Doc.char('{') + (Doc.line + commaBlock(paramsDoc)).nested(4) + Doc.line + Doc.char('}')).grouped + nm + Doc.space + + (Doc.char('{') + (Doc.line + commaBlock( + paramsDoc + )).nested(4) + Doc.line + Doc + .char('}')).grouped } } - } - else Left(IllTyped(revPath.reverse, tpe, src)) + } else Left(IllTyped(revPath.reverse, tpe, src)) } - dt.dataFamily match { case DataFamily.NewType => // the outer wrapping is so we add it back { v => params(0, v :: Nil, v) } - case DataFamily.Struct => - { - case prod: ProductValue => - params(0, prod.values.toList, prod) + case DataFamily.Struct => { + case prod: ProductValue => + params(0, prod.values.toList, prod) - case other => - Left(IllTyped(revPath.reverse, tpe, other)) - } - case DataFamily.Enum => - { - case s: SumValue => - params(s.variant, s.value.values.toList, s) - case a => - Left(IllTyped(revPath.reverse, tpe, a)) - } + case other => + Left(IllTyped(revPath.reverse, tpe, other)) + } + case DataFamily.Enum => { + case s: SumValue => + params(s.variant, s.value.values.toList, s) + case a => + Left(IllTyped(revPath.reverse, tpe, a)) + } case DataFamily.Nat => // this is nat-like // TODO, maybe give a warning @@ -239,8 +251,8 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { Left(IllTyped(revPath.reverse, tpe, other)) } } - } - }) + } + }) // put the result in the cache before we compute it // so we can recurse successCache.put(tpe, res) diff --git a/core/src/main/scala/org/bykn/bosatsu/ValueToJson.scala b/core/src/main/scala/org/bykn/bosatsu/ValueToJson.scala index 5dd0ef37d..74e976600 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ValueToJson.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ValueToJson.scala @@ -15,19 +15,17 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { def canEncodeToNull(t: Type): Boolean = t match { - case Type.UnitType => true + case Type.UnitType => true case Type.OptionT(inner) => // if the inside of an Option cannot be null, we can use null // to represent None !canEncodeToNull(inner) case Type.ForAll(_, inner) => canEncodeToNull(inner) - case _ => false + case _ => false } - - /** - * Is a given type supported for Json conversion - */ + /** Is a given type supported for Json conversion + */ def supported(t: Type): Either[UnsupportedType, Unit] = { // if we are currently working on a Type // we assume it is supported, and it isn't @@ -39,13 +37,13 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { Left(UnsupportedType(NonEmptyList(t, working).reverse)) t match { - case _ if working.contains(t) => good + case _ if working.contains(t) => good case Type.IntType | Type.StrType | Type.BoolType | Type.UnitType => good - case Type.OptionT(inner) => loop(inner, t :: working) - case Type.ListT(inner) => loop(inner, t :: working ) + case Type.OptionT(inner) => loop(inner, t :: working) + case Type.ListT(inner) => loop(inner, t :: working) case Type.DictT(Type.StrType, inner) => loop(inner, t :: working) - case Type.ForAll(_, _) => bad - case Type.TyVar(_) | Type.TyMeta(_) => bad + case Type.ForAll(_, _) => bad + case Type.TyVar(_) | Type.TyMeta(_) => bad case Type.Tuple(ts) => val w1 = t :: working ts.traverse_(loop(_, w1)) @@ -59,7 +57,8 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Some(dt) => val cons = dt.constructors val (_, targs) = Type.unapplyAll(consOrApply) - val replaceMap = dt.typeParams.zip(targs).toMap[Type.Var, Type] + val replaceMap = + dt.typeParams.zip(targs).toMap[Type.Var, Type] cons.traverse_ { cf => cf.args.traverse_ { case (_, t) => @@ -80,19 +79,20 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { // $COVERAGE-OFF$ case Left(u) => sys.error(s"should have only called on a supported type: $u") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - /** - * Convert a typechecked value to Json - * - * Note, we statically build the conversion function if it is possible at - * all, after that, only value errors can occur - * - * this code ASSUMES the type is correct. If not, we may return - * incorrect data. - */ - def toJson(tpe: Type): Either[UnsupportedType, Value => Either[IllTyped, Json]] = { + /** Convert a typechecked value to Json + * + * Note, we statically build the conversion function if it is possible at + * all, after that, only value errors can occur + * + * this code ASSUMES the type is correct. If not, we may return incorrect + * data. + */ + def toJson( + tpe: Type + ): Either[UnsupportedType, Value => Either[IllTyped, Json]] = { type Fn = Value => Either[IllTyped, Json] // when we complete a custom type, we put it in here @@ -105,60 +105,57 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Some(fn) => fn case None => val res: Eval[Fn] = Eval.later(tpe match { - case Type.IntType => - { - case ExternalValue(v: BigInteger) => - Right(Json.JNumberStr(v.toString)) - // $COVERAGE-OFF$this should be unreachable - case other => - Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } - case Type.StrType => - { - case ExternalValue(v: String) => - Right(Json.JString(v)) + case Type.IntType => { + case ExternalValue(v: BigInteger) => + Right(Json.JNumberStr(v.toString)) + // $COVERAGE-OFF$this should be unreachable + case other => + Left(IllTyped(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ + } + case Type.StrType => { + case ExternalValue(v: String) => + Right(Json.JString(v)) + // $COVERAGE-OFF$this should be unreachable + case other => + Left(IllTyped(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ + } + case Type.BoolType => { + case True => Right(Json.JBool(true)) + case False => Right(Json.JBool(false)) + case other => // $COVERAGE-OFF$this should be unreachable - case other => - Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } - case Type.BoolType => - { - case True => Right(Json.JBool(true)) - case False => Right(Json.JBool(false)) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } + Left(IllTyped(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ + } case Type.UnitType => // encode this as null { case UnitValue => Right(Json.JNull) - case other => + case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } - case opt@Type.OptionT(t1) => + // $COVERAGE-ON$ + } + case opt @ Type.OptionT(t1) => lazy val inner = loop(t1, tpe :: revPath).value if (canEncodeToNull(opt)) { - // not a nested option + // not a nested option { - case VOption(None) => Right(Json.JNull) + case VOption(None) => Right(Json.JNull) case VOption(Some(a)) => inner(a) case other => Left(IllTyped(revPath.reverse, tpe, other)) } - } - else { + } else { { case VOption(None) => Right(Json.JArray(Vector.empty)) - case VOption(Some(a)) => inner(a).map { j => Json.JArray(Vector(j)) } + case VOption(Some(a)) => + inner(a).map { j => Json.JArray(Vector(j)) } case other => Left(IllTyped(revPath.reverse, tpe, other)) } @@ -174,27 +171,28 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.DictT(Type.StrType, vt) => lazy val inner = loop(vt, tpe :: revPath).value { case VDict(d) => - d.toList.traverse { case (k, v) => - k match { - case Str(kstr) => inner(v).map((kstr, _)) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTyped(revPath.reverse, tpe, other)) + d.toList + .traverse { case (k, v) => + k match { + case Str(kstr) => inner(v).map((kstr, _)) + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTyped(revPath.reverse, tpe, other)) // $COVERAGE-ON$ + } } - } - .map(Json.JObject(_)) + .map(Json.JObject(_)) case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.Tuple(ts) => val p1 = tpe :: revPath @@ -210,7 +208,7 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.ForAll(_, inner) => @@ -227,24 +225,26 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { getDefinedType(const) match { case Some(dt) => Right(dt) case None => - Left(UnsupportedType(NonEmptyList(tpe, revPath).reverse)) + Left( + UnsupportedType(NonEmptyList(tpe, revPath).reverse) + ) } case None => Left(UnsupportedType(NonEmptyList(tpe, revPath).reverse)) }) dt.dataFamily match { - case DataFamily.Nat => - { - case ExternalValue(b: BigInteger) => - Right(Json.JNumberStr(b.toString)) - case other => - Left(IllTyped(revPath.reverse, tpe, other)) - } + case DataFamily.Nat => { + case ExternalValue(b: BigInteger) => + Right(Json.JNumberStr(b.toString)) + case other => + Left(IllTyped(revPath.reverse, tpe, other)) + } case notNat => val cons = dt.constructors val (_, targs) = Type.unapplyAll(tpe) - val replaceMap = dt.typeParams.zip(targs).toMap[Type.Var, Type] + val replaceMap = + dt.typeParams.zip(targs).toMap[Type.Var, Type] val resInner: Eval[Map[Int, List[(String, Fn)]]] = cons.zipWithIndex @@ -258,7 +258,6 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { } .map(_.toMap) - notNat match { case DataFamily.NewType => lazy val inner = resInner.value.head._2.head._2 @@ -273,13 +272,13 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { val plist = prod.values.toList if (plist.size == size) { - plist.zip(productsInner) + plist + .zip(productsInner) .traverse { case (p, (key, f)) => f(p).map((key, _)) } .map { ps => Json.JObject(ps) } - } - else { + } else { Left(IllTyped(revPath.reverse, tpe, prod)) } @@ -297,41 +296,42 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Some(fn) => val vlist = s.value.values.toList if (vlist.size == fn.size) { - vlist.zip(fn) + vlist + .zip(fn) .traverse { case (p, (key, f)) => f(p).map((key, _)) } .map { ps => Json.JObject(ps) } - } - else Left(IllTyped(revPath.reverse, tpe, s)) + } else Left(IllTyped(revPath.reverse, tpe, s)) case None => Left(IllTyped(revPath.reverse, tpe, s)) } case a => Left(IllTyped(revPath.reverse, tpe, a)) - } - } + } + } } - }) + }) // put the result in the cache before we compute it // so we can recurse successCache.put(tpe, res) res - } + } supported(tpe).map(_ => loop(tpe, Nil).value) } - /** - * Convert a Json to a Value - * - * Note, we statically build the conversion function if it is possible at - * all, after that, only value errors can occur - * - * this code ASSUMES the type is correct. If not, we may return - * incorrect data. - */ - def toValue(tpe: Type): Either[UnsupportedType, Json => Either[IllTypedJson, Value]] = { + /** Convert a Json to a Value + * + * Note, we statically build the conversion function if it is possible at + * all, after that, only value errors can occur + * + * this code ASSUMES the type is correct. If not, we may return incorrect + * data. + */ + def toValue( + tpe: Type + ): Either[UnsupportedType, Json => Either[IllTypedJson, Value]] = { type Fn = Json => Either[IllTypedJson, Value] // when we complete a custom type, we put it in here @@ -341,192 +341,199 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { successCache.get(tpe) match { case Some(res) => res case None => - val res: Eval[Json => Either[IllTypedJson, Value]] = Eval.later(tpe match { - case Type.IntType => - { + val res: Eval[Json => Either[IllTypedJson, Value]] = + Eval.later(tpe match { + case Type.IntType => { case Json.JBigInteger(b) => Right(ExternalValue(b)) case other => Left(IllTypedJson(revPath.reverse, tpe, other)) } - case Type.StrType => - { + case Type.StrType => { case Json.JString(v) => Right(ExternalValue(v)) case other => // $COVERAGE-OFF$this should be unreachable Left(IllTypedJson(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - case Type.BoolType => - { + case Type.BoolType => { case Json.JBool(value) => Right(if (value) True else False) case other => // $COVERAGE-OFF$this should be unreachable Left(IllTypedJson(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } - case Type.UnitType => - // encode this as null - { - case Json.JNull => Right(UnitValue) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTypedJson(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - case opt@Type.OptionT(t1) => - if (canEncodeToNull(opt)) { - // not a nested option - lazy val inner = loop(t1, tpe :: revPath).value - + case Type.UnitType => + // encode this as null { - case Json.JNull => Right(VOption.none) - case notNull => inner(notNull).map(VOption.some(_)) + case Json.JNull => Right(UnitValue) + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTypedJson(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ } - } - else { - // we can't encode Option[Option[T]] as null or not, so we encode - // as list of 0 or 1 items + case opt @ Type.OptionT(t1) => + if (canEncodeToNull(opt)) { + // not a nested option + lazy val inner = loop(t1, tpe :: revPath).value + + { + case Json.JNull => Right(VOption.none) + case notNull => inner(notNull).map(VOption.some(_)) + } + } else { + // we can't encode Option[Option[T]] as null or not, so we encode + // as list of 0 or 1 items - lazy val inner = loop(t1, tpe :: revPath).value + lazy val inner = loop(t1, tpe :: revPath).value - { - case Json.JArray(items) if items.lengthCompare(1) <= 0 => - items.headOption match { - case None => Right(VOption.none) - case Some(a) => inner(a).map(VOption.some(_)) - } - case other => + { + case Json.JArray(items) if items.lengthCompare(1) <= 0 => + items.headOption match { + case None => Right(VOption.none) + case Some(a) => inner(a).map(VOption.some(_)) + } + case other => Left(IllTypedJson(revPath.reverse, tpe, other)) + } } - } - case Type.ListT(t) => - lazy val inner = loop(t, tpe :: revPath).value + case Type.ListT(t) => + lazy val inner = loop(t, tpe :: revPath).value - { - case Json.JArray(vs) => - vs.toVector - .traverse(inner) - .map { vs => VList(vs.toList) } - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTypedJson(revPath.reverse, tpe, other)) + { + case Json.JArray(vs) => + vs.toVector + .traverse(inner) + .map { vs => VList(vs.toList) } + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTypedJson(revPath.reverse, tpe, other)) // $COVERAGE-ON$ - } - case Type.DictT(Type.StrType, vt) => - lazy val inner = loop(vt, tpe :: revPath).value + } + case Type.DictT(Type.StrType, vt) => + lazy val inner = loop(vt, tpe :: revPath).value - { - case Json.JObject(items) => - items.traverse { case (k, v) => - inner(v).map((k, _)) - } - .map { kvs => - VDict.fromStringKeys(kvs) - } - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTypedJson(revPath.reverse, tpe, other)) + { + case Json.JObject(items) => + items + .traverse { case (k, v) => + inner(v).map((k, _)) + } + .map { kvs => + VDict.fromStringKeys(kvs) + } + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTypedJson(revPath.reverse, tpe, other)) // $COVERAGE-ON$ - } - case Type.Tuple(ts) => - val p1 = tpe :: revPath - lazy val inners = ts.traverse(loop(_, p1)).value + } + case Type.Tuple(ts) => + val p1 = tpe :: revPath + lazy val inners = ts.traverse(loop(_, p1)).value - { - case ary@Json.JArray(as) => - if (as.size == inners.size) { - as.zip(inners) - .toVector - .traverse { case (a, fn) => fn(a) } - .map { vs => Tuple.fromList(vs.toList) } - } - else Left(IllTypedJson(revPath.reverse, tpe, ary)) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTypedJson(revPath.reverse, tpe, other)) + { + case ary @ Json.JArray(as) => + if (as.size == inners.size) { + as.zip(inners) + .toVector + .traverse { case (a, fn) => fn(a) } + .map { vs => Tuple.fromList(vs.toList) } + } else Left(IllTypedJson(revPath.reverse, tpe, ary)) + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTypedJson(revPath.reverse, tpe, other)) // $COVERAGE-ON$ - } - - case Type.ForAll(_, inner) => - // we assume the generic positions don't matter and to continue - loop(inner, tpe :: revPath).value - case _ => - val fullPath = tpe :: revPath - - val dt = - get(Type.rootConst(tpe) match { - case Some(Type.TyConst(const)) => - getDefinedType(const) match { - case Some(dt) => Right(dt) - case None => - Left(UnsupportedType(NonEmptyList(tpe, revPath).reverse)) - } - case None => - Left(UnsupportedType(NonEmptyList(tpe, revPath).reverse)) - }) + } - val resInner: Eval[ - List[(Int, List[(String, Json => Either[IllTypedJson, Value])])] - ] = { + case Type.ForAll(_, inner) => + // we assume the generic positions don't matter and to continue + loop(inner, tpe :: revPath).value + case _ => + val fullPath = tpe :: revPath + + val dt = + get(Type.rootConst(tpe) match { + case Some(Type.TyConst(const)) => + getDefinedType(const) match { + case Some(dt) => Right(dt) + case None => + Left( + UnsupportedType(NonEmptyList(tpe, revPath).reverse) + ) + } + case None => + Left(UnsupportedType(NonEmptyList(tpe, revPath).reverse)) + }) + + val resInner: Eval[ + List[ + (Int, List[(String, Json => Either[IllTypedJson, Value])]) + ] + ] = { val cons = dt.constructors val (_, targs) = Type.unapplyAll(tpe) - val replaceMap = dt.typeParams.zip(targs).toMap[Type.Var, Type] + val replaceMap = + dt.typeParams.zip(targs).toMap[Type.Var, Type] cons.zipWithIndex .traverse { case (cf, idx) => - cf.args.traverse { case (pn, t) => - val subsT = Type.substituteVar(t, replaceMap) - loop(subsT, fullPath) - .map((pn.asString, _)) - } - .map { pair => (idx, pair) } + cf.args + .traverse { case (pn, t) => + val subsT = Type.substituteVar(t, replaceMap) + loop(subsT, fullPath) + .map((pn.asString, _)) + } + .map { pair => (idx, pair) } } - } + } - dt.dataFamily match { - case DataFamily.NewType => - // there is one single arg constructor - lazy val inner = resInner.value.head._2.head._2 + dt.dataFamily match { + case DataFamily.NewType => + // there is one single arg constructor + lazy val inner = resInner.value.head._2.head._2 - { j => inner(j) } - case DataFamily.Struct | DataFamily.Enum => + { j => inner(j) } + case DataFamily.Struct | DataFamily.Enum => // This is lazy because we don't want to run // the Evals until we have the first value lazy val mapping: List[(Int, Map[String, (Int, Fn)])] = // if we are in here, all constituent parts can be solved - resInner.value.map { case (idx, lst) => - (idx, - lst - .iterator - .zipWithIndex - .map { case ((nm, fn), idx) => (nm, (idx, fn)) } - .toMap) - } + resInner.value.map { case (idx, lst) => + ( + idx, + lst.iterator.zipWithIndex.map { + case ((nm, fn), idx) => (nm, (idx, fn)) + }.toMap + ) + } { - case obj@Json.JObject(_) => + case obj @ Json.JObject(_) => val keySet = obj.toMap.keySet - def run(cand: List[(Int, Map[String, (Int, Fn)])]): Either[IllTypedJson, Value] = + def run( + cand: List[(Int, Map[String, (Int, Fn)])] + ): Either[IllTypedJson, Value] = cand match { case Nil => Left(IllTypedJson(revPath.reverse, tpe, obj)) - case (variant, decode) :: _ if keySet == decode.keySet => + case (variant, decode) :: _ + if keySet == decode.keySet => val itemArray = new Array[Value](keySet.size) - obj.items.foldM(itemArray) { case (ary, (k, v)) => - val (idx, fn) = decode(k) - fn(v).map { value => - ary(idx) = value - ary + obj.items + .foldM(itemArray) { case (ary, (k, v)) => + val (idx, fn) = decode(k) + fn(v).map { value => + ary(idx) = value + ary + } + } + .map { ary => + val prod = ProductValue.fromList(ary.toList) + if (dt.isStruct) prod + else SumValue(variant, prod) } - } - .map { ary => - val prod = ProductValue.fromList(ary.toList) - if (dt.isStruct) prod - else SumValue(variant, prod) - } case _ :: tail => run(tail) } @@ -534,15 +541,15 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case other => Left(IllTypedJson(revPath.reverse, tpe, other)) } - case DataFamily.Nat => - // this is a nat like type which we encode into integers - { - case Json.JBigInteger(bi) => - Right(ExternalValue(bi)) - case other => - Left(IllTypedJson(revPath.reverse, tpe, other)) - } - } + case DataFamily.Nat => + // this is a nat like type which we encode into integers + { + case Json.JBigInteger(bi) => + Right(ExternalValue(bi)) + case other => + Left(IllTypedJson(revPath.reverse, tpe, other)) + } + } }) successCache.put(tpe, res) @@ -552,11 +559,13 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { supported(tpe).map(_ => loop(tpe, Nil).value) } - /** - * Given a type return the function to convert it a function - * if it is not a function, we consider it a function of 0-arity - */ - def valueFnToJsonFn(t: Type): Either[UnsupportedType, (Int, Value => Either[DataError, Json.JArray => Either[DataError, Json]])] = + /** Given a type return the function to convert it a function if it is not a + * function, we consider it a function of 0-arity + */ + def valueFnToJsonFn(t: Type): Either[ + UnsupportedType, + (Int, Value => Either[DataError, Json.JArray => Either[DataError, Json]]) + ] = t match { case Type.Fun((args, res)) => (args.traverse(toValue(_)), toJson(res)).mapN { (argsFn, resFn) => @@ -565,38 +574,44 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { val arity = argsFn.size val argsFnVector = argsFn.toList.toVector - (arity, { - case Value.FnValue(fn) => - - val jsonFn = { (inputs: Json.JArray) => - if (inputs.toVector.size != arity) Left(IllTypedJson(Nil, t, inputs)) - else { - // we know arity >= 1 because it is a function, so the fromListUnsafe will succeed - inputs.toVector - .zip(argsFnVector) - .traverse { case (a, fn) => fn(a) } - .map { vect => fn(NonEmptyList.fromListUnsafe(vect.toList)) } - .flatMap(resFn) + ( + arity, + { + case Value.FnValue(fn) => + val jsonFn = { (inputs: Json.JArray) => + if (inputs.toVector.size != arity) + Left(IllTypedJson(Nil, t, inputs)) + else { + // we know arity >= 1 because it is a function, so the fromListUnsafe will succeed + inputs.toVector + .zip(argsFnVector) + .traverse { case (a, fn) => fn(a) } + .map { vect => + fn(NonEmptyList.fromListUnsafe(vect.toList)) + } + .flatMap(resFn) + } } - } - Right(jsonFn) - case notFn => Left(IllTyped(Nil, t, notFn)) - }) + Right(jsonFn) + case notFn => Left(IllTyped(Nil, t, notFn)) + } + ) } case _ => // this isn't a function at all toJson(t).map { (fn: (Value) => Either[DataError, Json]) => - - (0, fn.andThen { either => - either.map { result => - - { (args: Json.JArray) => - if (args.toVector.isEmpty) Right(result) - else Left(IllTypedJson(Nil, t, args)) + ( + 0, + fn.andThen { either => + either.map { result => + { (args: Json.JArray) => + if (args.toVector.isEmpty) Right(result) + else Left(IllTypedJson(Nil, t, args)) + } } } - }) + ) } } @@ -606,8 +621,11 @@ sealed abstract class JsonEncodingError object JsonEncodingError { sealed abstract class DataError extends JsonEncodingError - final case class UnsupportedType(path: NonEmptyList[Type]) extends JsonEncodingError + final case class UnsupportedType(path: NonEmptyList[Type]) + extends JsonEncodingError - final case class IllTyped(path: List[Type], tpe: Type, value: Value) extends DataError - final case class IllTypedJson(path: List[Type], tpe: Type, value: Json) extends DataError + final case class IllTyped(path: List[Type], tpe: Type, value: Value) + extends DataError + final case class IllTypedJson(path: List[Type], tpe: Type, value: Json) + extends DataError } diff --git a/core/src/main/scala/org/bykn/bosatsu/Variance.scala b/core/src/main/scala/org/bykn/bosatsu/Variance.scala index eb5ab4925..f6a77f70a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Variance.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Variance.scala @@ -8,35 +8,35 @@ sealed abstract class Variance { def unary_- : Variance = this match { case Contravariant => Covariant - case Covariant => Contravariant - case topOrBottom => topOrBottom + case Covariant => Contravariant + case topOrBottom => topOrBottom } // if you have f[x] the variance of the result is the arg of f times variance of x def *(that: Variance): Variance = (this, that) match { - case (Phantom, _) => Phantom - case (_, Phantom) => Phantom - case (Invariant, _) => Invariant - case (_, Invariant) => Invariant - case (Covariant, r) => r + case (Phantom, _) => Phantom + case (_, Phantom) => Phantom + case (Invariant, _) => Invariant + case (_, Invariant) => Invariant + case (Covariant, r) => r case (Contravariant, Contravariant) => Covariant - case (Contravariant, Covariant) => Contravariant + case (Contravariant, Covariant) => Contravariant } - /** - * Variance forms a lattice with Phantom at the bottom and Invariant at the top. - */ + /** Variance forms a lattice with Phantom at the bottom and Invariant at the + * top. + */ def +(that: Variance): Variance = (this, that) match { - case (Phantom, r) => r - case (r, Phantom) => r - case (Invariant, _) => Invariant - case (_, Invariant) => Invariant - case (Covariant, Covariant) => Covariant + case (Phantom, r) => r + case (r, Phantom) => r + case (Invariant, _) => Invariant + case (_, Invariant) => Invariant + case (Covariant, Covariant) => Covariant case (Contravariant, Contravariant) => Contravariant - case (Covariant, Contravariant) => Invariant - case (Contravariant, Covariant) => Invariant + case (Covariant, Contravariant) => Invariant + case (Contravariant, Covariant) => Invariant } } object Variance { @@ -51,7 +51,8 @@ object Variance { def contra: Variance = Contravariant def in: Variance = Invariant - val all: List[Variance] = Phantom :: Covariant :: Contravariant :: Invariant :: Nil + val all: List[Variance] = + Phantom :: Covariant :: Contravariant :: Invariant :: Nil implicit val varianceBoundedSemilattice: BoundedSemilattice[Variance] = new BoundedSemilattice[Variance] { @@ -67,15 +68,15 @@ object Variance { case Phantom => if (right == Phantom) 0 else -1 case Covariant => right match { - case Phantom => 1 - case Covariant => 0 + case Phantom => 1 + case Covariant => 0 case Contravariant | Invariant => -1 } - case Contravariant => + case Contravariant => right match { case Phantom | Covariant => 1 - case Contravariant => 0 - case Invariant => -1 + case Contravariant => 0 + case Invariant => -1 } case Invariant => if (right == Invariant) 0 else 1 } diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala index 02e17e734..ce9c88cf9 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala @@ -22,9 +22,9 @@ object Code { sealed abstract class Expression extends ValueLike with Code { def identOrParens: Expression = this match { - case i: Code.Ident => i - case p@Code.Parens(_) => p - case other => Code.Parens(other) + case i: Code.Ident => i + case p @ Code.Parens(_) => p + case other => Code.Parens(other) } def apply(args: Expression*): Apply = @@ -62,7 +62,7 @@ object Code { def evalMinus(that: Expression): Expression = eval(Const.Minus, that) - + def -(that: Expression): Expression = evalMinus(that) def evalTimes(that: Expression): Expression = @@ -81,7 +81,7 @@ object Code { def statements: NonEmptyList[Statement] = this match { case Block(ss) => ss - case notBlock => NonEmptyList.one(notBlock) + case notBlock => NonEmptyList.one(notBlock) } def +:(stmt: Statement): Statement = @@ -105,8 +105,8 @@ object Code { case Pass => vl case _ => vl match { - case wv@WithValue(_, _) => this +: wv - case _ => WithValue(this, vl) + case wv @ WithValue(_, _) => this +: wv + case _ => WithValue(this, vl) } } } @@ -117,11 +117,12 @@ object Code { private def maybePar(c: Expression): Doc = c match { case Lambda(_, _) | Ternary(_, _, _) => par(toDoc(c)) - case _ => toDoc(c) + case _ => toDoc(c) } private def iflike(name: String, cond: Doc, body: Doc): Doc = - Doc.text(name) + Doc.space + cond + Doc.char(':') + (Doc.hardLine + body).nested(4) + Doc.text(name) + Doc.space + cond + Doc.char(':') + (Doc.hardLine + body) + .nested(4) private val trueDoc = Doc.text("True") private val falseDoc = Doc.text("False") @@ -139,42 +140,59 @@ object Code { def exprToDoc(expr: Expression): Doc = expr match { case PyInt(bi) => Doc.text(bi.toString) - case PyString(s) => Doc.char('"') + Doc.text(StringUtil.escape('"', s)) + Doc.char('"') + case PyString(s) => + Doc.char('"') + Doc.text(StringUtil.escape('"', s)) + Doc.char('"') case PyBool(b) => if (b) trueDoc else falseDoc - case Ident(i) => Doc.text(i) - case o@Op(_, _, _) => o.toDoc - case Parens(inner@Parens(_)) => exprToDoc(inner) - case Parens(p) => par(exprToDoc(p)) + case Ident(i) => Doc.text(i) + case o @ Op(_, _, _) => o.toDoc + case Parens(inner @ Parens(_)) => exprToDoc(inner) + case Parens(p) => par(exprToDoc(p)) case SelectItem(x, i) => maybePar(x) + Doc.char('[') + exprToDoc(i) + Doc.char(']') case SelectRange(x, os, oe) => - val middle = os.fold(Doc.empty)(exprToDoc) + Doc.char(':') + oe.fold(Doc.empty)(exprToDoc) + val middle = os.fold(Doc.empty)(exprToDoc) + Doc.char(':') + oe.fold( + Doc.empty + )(exprToDoc) maybePar(x) + (Doc.char('[') + middle + Doc.char(']')).nested(4) case Ternary(ift, cond, iff) => // python parses the else condition as the rest of experssion, so // no need to put parens around it - maybePar(ift) + spaceIfSpace + maybePar(cond) + spaceElseSpace + exprToDoc(iff) + maybePar(ift) + spaceIfSpace + maybePar( + cond + ) + spaceElseSpace + exprToDoc(iff) case MakeTuple(items) => items match { - case Nil => unitDoc + case Nil => unitDoc case h :: Nil => par(exprToDoc(h) + Doc.comma).nested(4) - case twoOrMore => par(Doc.intercalate(Doc.comma + Doc.line, twoOrMore.map(exprToDoc)).grouped).nested(4) + case twoOrMore => + par( + Doc + .intercalate(Doc.comma + Doc.line, twoOrMore.map(exprToDoc)) + .grouped + ).nested(4) } case MakeList(items) => val inner = items.map(exprToDoc) - (Doc.char('[') + Doc.intercalate(Doc.comma + Doc.line, inner).grouped + Doc.char(']')).nested(4) + (Doc.char('[') + Doc + .intercalate(Doc.comma + Doc.line, inner) + .grouped + Doc.char(']')).nested(4) case Lambda(args, res) => - lamDoc + Doc.intercalate(Doc.comma + Doc.space, args.map(exprToDoc)) + colonSpace + exprToDoc(res) + lamDoc + Doc.intercalate( + Doc.comma + Doc.space, + args.map(exprToDoc) + ) + colonSpace + exprToDoc(res) case Apply(fn, args) => - maybePar(fn) + par(Doc.intercalate(Doc.comma + Doc.line, args.map(exprToDoc)).grouped).nested(4) + maybePar(fn) + par( + Doc.intercalate(Doc.comma + Doc.line, args.map(exprToDoc)).grouped + ).nested(4) case DotSelect(left, right) => val ld = left match { case PyInt(_) => par(exprToDoc(left)) - case _ => exprToDoc(left) + case _ => exprToDoc(left) } ld + Doc.char('.') + exprToDoc(right) } @@ -182,14 +200,16 @@ object Code { def toDoc(c: Code): Doc = c match { case expr: Expression => exprToDoc(expr) - case Call(ap) => toDoc(ap) + case Call(ap) => toDoc(ap) case ClassDef(name, ex, body) => val exDoc = if (ex.isEmpty) Doc.empty else par(Doc.intercalate(Doc.comma + Doc.space, ex.map(toDoc))) - Doc.text("class") + Doc.space + Doc.text(name.name) + exDoc + Doc.char(':') + (Doc.hardLine + + Doc.text("class") + Doc.space + Doc.text(name.name) + exDoc + Doc.char( + ':' + ) + (Doc.hardLine + toDoc(body)).nested(4) case IfStatement(conds, Some(Pass)) => @@ -199,7 +219,9 @@ object Code { val condsDoc = conds.map { case (x, b) => (toDoc(x), toDoc(b)) } val i1 = iflike("if", condsDoc.head._1, condsDoc.head._2) val i2 = condsDoc.tail.map { case (x, b) => iflike("elif", x, b) } - val el = optElse.fold(Doc.empty) { els => Doc.hardLine + elseColon + (Doc.hardLine + toDoc(els)).nested(4) } + val el = optElse.fold(Doc.empty) { els => + Doc.hardLine + elseColon + (Doc.hardLine + toDoc(els)).nested(4) + } Doc.intercalate(Doc.hardLine, i1 :: i2) + el @@ -208,18 +230,23 @@ object Code { case Def(nm, args, body) => defDoc + Doc.space + Doc.text(nm.name) + - par(Doc.intercalate(Doc.comma + Doc.lineOrSpace, args.map(toDoc))).nested(4) + Doc.char(':') + (Doc.hardLine + toDoc(body)).nested(4) + par(Doc.intercalate(Doc.comma + Doc.lineOrSpace, args.map(toDoc))) + .nested(4) + Doc.char(':') + (Doc.hardLine + toDoc(body)).nested(4) case Return(expr) => retSpaceDoc + toDoc(expr) case Assign(nm, expr) => toDoc(nm) + spaceEqSpace + toDoc(expr) - case Pass => Doc.text("pass") + case Pass => Doc.text("pass") case While(cond, body) => - whileDoc + Doc.space + toDoc(cond) + Doc.char(':') + (Doc.hardLine + toDoc(body)).nested(4) + whileDoc + Doc.space + toDoc(cond) + Doc.char( + ':' + ) + (Doc.hardLine + toDoc(body)).nested(4) case Import(name, aliasOpt) => // import name as alias val imp = Doc.text("import") + Doc.space + Doc.text(name) - aliasOpt.fold(imp) { a => imp + Doc.space + Doc.text("as") + Doc.space + toDoc(a) } + aliasOpt.fold(imp) { a => + imp + Doc.space + Doc.text("as") + Doc.space + toDoc(a) + } } ///////////////////////// @@ -239,12 +266,16 @@ object Code { def simplify: Expression = this } // Binary operator used for +, -, and, == etc... - case class Op(left: Expression, op: Operator, right: Expression) extends Expression { + case class Op(left: Expression, op: Operator, right: Expression) + extends Expression { // operators like + can associate // def toDoc: Doc = { // invariant: all items in right associate - def loop(left: Expression, rights: NonEmptyList[(Operator, Expression)]): Doc = + def loop( + left: Expression, + rights: NonEmptyList[(Operator, Expression)] + ): Doc = // a op1 b op2 c if op1 and op2 associate no need for a parens // left match { // case Op(_, @@ -255,7 +286,7 @@ object Code { else loop(Parens(left), rights) case leftNotOp => rights.head match { - case (ol, or@Op(_, o2, _)) if !ol.associates(o2) => + case (ol, or @ Op(_, o2, _)) if !ol.associates(o2) => // we we can't break rights.head because the ops // don't associate. We wrap or in Parens val rights1 = (ol, Parens(or)) @@ -266,17 +297,22 @@ object Code { case (ol, rightNotOp) => rights.tail match { case Nil => - maybePar(leftNotOp) + Doc.space + Doc.text(ol.name) + Doc.space + maybePar(rightNotOp) + maybePar(leftNotOp) + Doc.space + Doc.text( + ol.name + ) + Doc.space + maybePar(rightNotOp) case (o2, r2) :: rest => // everything in rights associate - val leftDoc = maybePar(leftNotOp) + Doc.space + Doc.text(ol.name) + Doc.space + val leftDoc = maybePar(leftNotOp) + Doc.space + Doc.text( + ol.name + ) + Doc.space if (ol.associates(o2)) { leftDoc + loop(rightNotOp, NonEmptyList((o2, r2), rest)) - } - else { + } else { // we need to put a parens ending after rightNotOp // leftNotOp ol (rightNotOp o2 r2 :: rest) - leftDoc + par(loop(rightNotOp, NonEmptyList((o2, r2), rest))) + leftDoc + par( + loop(rightNotOp, NonEmptyList((o2, r2), rest)) + ) } } @@ -291,11 +327,11 @@ object Code { this match { case Op(PyInt(a), io: IntOp, PyInt(b)) => PyInt(io(a, b)) - case Op(i@PyInt(a), Const.Times, right) => + case Op(i @ PyInt(a), Const.Times, right) => if (a == BigInteger.ZERO) i else if (a == BigInteger.ONE) right.simplify else right.simplify.evalTimes(i) - case Op(left, Const.Times, i@PyInt(b)) => + case Op(left, Const.Times, i @ PyInt(b)) => if (b == BigInteger.ZERO) i else if (b == BigInteger.ONE) left.simplify else { @@ -303,14 +339,14 @@ object Code { if (l1 == left) this else (l1.evalTimes(i)) } - case Op(i@PyInt(a), Const.Plus, right) => + case Op(i @ PyInt(a), Const.Plus, right) => if (a == BigInteger.ZERO) right.simplify else { val r1 = right.simplify // put the constant on the right r1.evalPlus(i) } - case Op(left, Const.Plus, i@PyInt(b)) => + case Op(left, Const.Plus, i @ PyInt(b)) => if (b == BigInteger.ZERO) left.simplify else { val l1 = left.simplify @@ -322,16 +358,15 @@ object Code { // right associate ll.evalPlus(rl.evalPlus(i)) case Const.Minus => - //(ll - rl) + i == ll - (rl - i) + // (ll - rl) + i == ll - (rl - i) ll.evalMinus(rl.evalMinus(i)) case _ => this } case _ => this } - } - else (l1.evalPlus(i)) + } else (l1.evalPlus(i)) } - case Op(i@PyInt(_), Const.Minus, right) => + case Op(i @ PyInt(_), Const.Minus, right) => val r1 = right.simplify if (r1 == right) { r1 match { @@ -341,9 +376,9 @@ object Code { // right associate rl.evalPlus(rr.evalPlus(i)) case Const.Minus => - //i - (rl - rr) + // i - (rl - rr) rr match { - case ri@PyInt(_) => + case ri @ PyInt(_) => Op(i.evalPlus(ri), Const.Minus, rl) case _ => this } @@ -351,9 +386,8 @@ object Code { } case _ => this } - } - else (i.evalMinus(r1)) - case Op(left, Const.Minus, i@PyInt(b)) => + } else (i.evalMinus(r1)) + case Op(left, Const.Minus, i @ PyInt(b)) => if (b == BigInteger.ZERO) left.simplify else { val l1 = left.simplify @@ -365,16 +399,15 @@ object Code { // (ll + rl) - i == ll + (rl - i) ll.evalPlus(rl.evalMinus(i)) case Const.Minus => - //(ll - rl) - i == ll - (rl + i) + // (ll - rl) - i == ll - (rl + i) ll.evalMinus(rl.evalPlus(i)) case _ => this } case _ => this } - } - else (l1.evalMinus(i)) + } else (l1.evalMinus(i)) } - case Op(a, Const.Eq, b) if a == b => Const.True + case Op(a, Const.Eq, b) if a == b => Const.True case Op(a, Const.Gt | Const.Lt | Const.Neq, b) if a == b => Const.False case Op(PyInt(a), Const.Gt, PyInt(b)) => fromBoolean(a.compareTo(b) > 0) @@ -386,13 +419,13 @@ object Code { fromBoolean(a == b) case Op(a, Const.And, b) => a.simplify match { - case Const.True => b.simplify + case Const.True => b.simplify case Const.False => Const.False case a1 => b.simplify match { - case Const.True => a1 + case Const.True => a1 case Const.False => Const.False - case b1 => Op(a1, Const.And, b1) + case b1 => Op(a1, Const.And, b1) } } case _ => @@ -400,8 +433,7 @@ object Code { val r1 = right.simplify if ((l1 != left) || (r1 != right)) { Op(l1, op, r1).simplify - } - else { + } else { (left, op) match { case (Op(ll, Const.Plus, lr), Const.Plus) => // right associate @@ -424,18 +456,22 @@ object Code { case class Parens(expr: Expression) extends Expression { def simplify: Expression = expr.simplify match { - case x@(PyBool(_) | Ident(_) | PyInt(_) | PyString(_) | Parens(_)) => x + case x @ (PyBool(_) | Ident(_) | PyInt(_) | PyString(_) | Parens(_)) => + x case exprS => Parens(exprS) } } - case class SelectItem(arg: Expression, position: Expression) extends Expression { + case class SelectItem(arg: Expression, position: Expression) + extends Expression { def simplify: Expression = (arg.simplify, position.simplify) match { case (MakeTuple(items), PyInt(bi)) - if (items.lengthCompare(bi.intValue()) > 0) && (bi.intValue() >= 0) => + if (items.lengthCompare(bi.intValue()) > 0) && (bi + .intValue() >= 0) => items(bi.intValue()) case (MakeList(items), PyInt(bi)) - if (items.lengthCompare(bi.intValue()) > 0) && (bi.intValue() >= 0) => + if (items.lengthCompare(bi.intValue()) > 0) && (bi + .intValue() >= 0) => items(bi.intValue()) case (simp, spos) => SelectItem(simp, spos) @@ -446,10 +482,16 @@ object Code { SelectItem(arg, Code.fromInt(position)) } // foo[a:b] - case class SelectRange(arg: Expression, start: Option[Expression], end: Option[Expression]) extends Expression { - def simplify: Expression = SelectRange(arg, start.map(_.simplify), end.map(_.simplify)) + case class SelectRange( + arg: Expression, + start: Option[Expression], + end: Option[Expression] + ) extends Expression { + def simplify: Expression = + SelectRange(arg, start.map(_.simplify), end.map(_.simplify)) } - case class Ternary(ifTrue: Expression, cond: Expression, ifFalse: Expression) extends Expression { + case class Ternary(ifTrue: Expression, cond: Expression, ifFalse: Expression) + extends Expression { def simplify: Expression = cond.simplify match { case PyBool(b) => @@ -481,19 +523,19 @@ object Code { def alloc(rename: List[Ident], avoid: Set[Ident]): List[Ident] = rename match { case Nil => Nil - case (i @ Ident(nm)) :: tail => + case (i @ Ident(nm)) :: tail => val nm1 = if (clashIdent(i)) { // the following iterator is infinite and distinct, and the avoid // set is finite, so the get here must terminate in at most avoid.size // steps - Iterator.from(0) + Iterator + .from(0) .map { i => Ident(nm + i.toString) } .collectFirst { case n if !avoid(n) => n } .get - } - else i + } else i nm1 :: alloc(tail, avoid + nm1) } @@ -510,21 +552,21 @@ object Code { case class Apply(fn: Expression, args: List[Expression]) extends Expression { def simplify: Expression = { fn.simplify match { - case Lambda(largs, result) if largs.length == args.length => + case Lambda(largs, result) if largs.length == args.length => // if this is a lambda, but the args don't match, let // the python error val subMap = largs.iterator.zip(args).toMap val subs = substitute(subMap, result) // now we can simplify after we have inlined the args subs.simplify - case Parens(Lambda(largs, result)) if largs.length == args.length => + case Parens(Lambda(largs, result)) if largs.length == args.length => // if this is a lambda, but the args don't match, let // the python error val subMap = largs.iterator.zip(args).toMap val subs = substitute(subMap, result) // now we can simplify after we have inlined the args subs.simplify - + case notLambda => Apply(notLambda, args.map(_.simplify)) } @@ -539,15 +581,18 @@ object Code { ///////////////////////// // this prepares an expression with a number of statements - case class WithValue(statement: Statement, value: ValueLike) extends ValueLike { + case class WithValue(statement: Statement, value: ValueLike) + extends ValueLike { def +:(stmt: Statement): WithValue = WithValue(stmt +: statement, value) def :+(stmt: Statement): WithValue = WithValue(statement :+ stmt, value) } - case class IfElse(conds: NonEmptyList[(Expression, ValueLike)], elseCond: ValueLike) extends ValueLike - + case class IfElse( + conds: NonEmptyList[(Expression, ValueLike)], + elseCond: ValueLike + ) extends ValueLike ///////////////////////// // Here are all the Statements @@ -555,29 +600,42 @@ object Code { case class Call(sideEffect: Apply) extends Statement // extends are really certain DotSelects, but we can't constrain that much - case class ClassDef(name: Ident, extendList: List[Expression], body: Statement) extends Statement + case class ClassDef( + name: Ident, + extendList: List[Expression], + body: Statement + ) extends Statement case class Block(stmts: NonEmptyList[Statement]) extends Statement - case class IfStatement(conds: NonEmptyList[(Expression, Statement)], elseCond: Option[Statement]) extends Statement - case class Def(name: Ident, args: List[Ident], body: Statement) extends Statement + case class IfStatement( + conds: NonEmptyList[(Expression, Statement)], + elseCond: Option[Statement] + ) extends Statement + case class Def(name: Ident, args: List[Ident], body: Statement) + extends Statement case class Return(expr: Expression) extends Statement case class Assign(target: Expression, value: Expression) extends Statement case object Pass extends Statement case class While(cond: Expression, body: Statement) extends Statement case class Import(modname: String, alias: Option[Ident]) extends Statement - def ifStatement(conds: NonEmptyList[(Expression, Statement)], elseCond: Option[Statement]): Statement = { + def ifStatement( + conds: NonEmptyList[(Expression, Statement)], + elseCond: Option[Statement] + ): Statement = { val simpConds = conds.map { case (e, s) => (e.simplify, s) } val allBranches: NonEmptyList[(Expression, Statement)] = elseCond match { case Some(s) => simpConds :+ ((Code.Const.True, s)) - case None => simpConds + case None => simpConds } // we know the returned expression is never a constant expression - def untilTrue(lst: List[(Expression, Statement)]): (List[(Expression, Statement)], Statement) = + def untilTrue( + lst: List[(Expression, Statement)] + ): (List[(Expression, Statement)], Statement) = lst match { - case Nil => (Nil, Pass) + case Nil => (Nil, Pass) case (Code.Const.True, last) :: _ => (Nil, last) case head :: tail => val (rest, e) = untilTrue(tail) @@ -596,7 +654,11 @@ object Code { def if1(cond: Expression, stmt: Statement): Statement = ifStatement(NonEmptyList.one((cond, stmt)), None) - def ifElseS(cond: Expression, ifCase: Statement, elseCase: Statement): Statement = + def ifElseS( + cond: Expression, + ifCase: Statement, + elseCase: Statement + ): Statement = ifStatement(NonEmptyList.one((cond, ifCase)), Some(elseCase)) /* @@ -625,9 +687,9 @@ object Code { def flatten(s: Statement): List[Statement] = s match { - case Pass => Nil + case Pass => Nil case Block(stmts) => stmts.toList.flatMap(flatten) - case single => single :: Nil + case single => single :: Nil } def block(stmt: Statement, rest: Statement*): Statement = @@ -636,7 +698,7 @@ object Code { def blockFromList(list: List[Statement]): Statement = { val all = list.flatMap(flatten) all match { - case Nil => Pass + case Nil => Pass case one :: Nil => one case head :: tail => Block(NonEmptyList(head, tail)) @@ -653,24 +715,24 @@ object Code { unapply(stmts.last).map { case (s0, i, e) => val s1 = NonEmptyList.fromList(stmts.init) match { - case None => s0 + case None => s0 case Some(inits) => Block(inits) :+ s0 } (s1, i, e) } case Assign(i @ Ident(_), expr) => Some((Pass, i, expr)) - case _ => None + case _ => None } } def substitute(subMap: Map[Ident, Expression], in: Expression): Expression = in match { case PyInt(_) | PyString(_) | PyBool(_) => in - case i@Ident(_) => + case i @ Ident(_) => subMap.get(i) match { case Some(value) => value - case None => i + case None => i } case Op(left, op, right) => Op(substitute(subMap, left), op, substitute(subMap, right)) @@ -679,14 +741,17 @@ object Code { case SelectItem(arg, position) => SelectItem(substitute(subMap, arg), substitute(subMap, position)) case SelectRange(arg, start, end) => - SelectRange(substitute(subMap, arg), + SelectRange( + substitute(subMap, arg), start.map(substitute(subMap, _)), - end.map(substitute(subMap, _))) + end.map(substitute(subMap, _)) + ) case Ternary(ifTrue, cond, ifFalse) => Ternary( substitute(subMap, ifTrue), substitute(subMap, cond), - substitute(subMap, ifFalse)) + substitute(subMap, ifFalse) + ) case MakeTuple(args) => MakeTuple(args.map(substitute(subMap, _))) case MakeList(args) => @@ -698,8 +763,7 @@ object Code { val nonShadowed = subMap.filterNot { case (i, _) => argsSet(i) } // if subFrees is empty, unshadow is a no-op. // but that is efficiently handled by unshadow - val subFrees = nonShadowed - .iterator + val subFrees = nonShadowed.iterator .map { case (_, v) => freeIdents(v) } .foldLeft(nonShadowed.keySet)(_ | _) @@ -718,7 +782,7 @@ object Code { def loop(ex: Expression, bound: Set[Ident]): Set[Ident] = ex match { case PyInt(_) | PyString(_) | PyBool(_) => Set.empty - case i@Ident(_) => + case i @ Ident(_) => if (bound(i)) Set.empty else Set(i) case Op(left, _, right) => loop(left, bound) | loop(right, bound) @@ -743,7 +807,7 @@ object Code { case DotSelect(ex, _) => loop(ex, bound) } - + loop(ex, Set.empty) } @@ -759,10 +823,10 @@ object Code { conds.map { case (c, v) => (c, toReturn(v)) }, - Some(toReturn(elseCond))) + Some(toReturn(elseCond)) + ) } - // boolean expressions can contain side effects // this runs the side effects but discards // and resulting value @@ -784,9 +848,9 @@ object Code { def litToExpr(lit: Lit): Expression = lit match { - case Lit.Str(s) => PyString(s) + case Lit.Str(s) => PyString(s) case Lit.Integer(bi) => PyInt(bi) - case Lit.Chr(s) => PyString(s) + case Lit.Chr(s) => PyString(s) } implicit def fromInt(i: Int): Expression = @@ -800,7 +864,6 @@ object Code { else if (i == 1L) Const.One else PyInt(BigInteger.valueOf(i)) - implicit def fromBoolean(b: Boolean): Expression = if (b) Code.Const.True else Code.Const.False @@ -808,9 +871,9 @@ object Code { def associates(that: Operator): Boolean = { // true if (a this b) that c == a this (b that c) this match { - case Const.Plus => (that == Const.Plus) || (that == Const.Minus) + case Const.Plus => (that == Const.Plus) || (that == Const.Minus) case Const.Minus => false - case Const.And => that == Const.And + case Const.And => that == Const.And case Const.Times => // (a * b) * c == a * (b * c) // (a * b) + c != a * (b + c) @@ -826,15 +889,15 @@ object Code { sealed abstract class IntOp(nm: String) extends Operator(nm) { def apply(a: BigInteger, b: BigInteger): BigInteger = this match { - case Const.Plus => a.add(b) - case Const.Minus => a.subtract(b) - case Const.Times => a.multiply(b) - case Const.Div => PredefImpl.divBigInteger(a, b) - case Const.Mod => PredefImpl.modBigInteger(a, b) - case Const.BitwiseAnd => a.and(b) - case Const.BitwiseOr => a.or(b) - case Const.BitwiseXor => a.xor(b) - case Const.BitwiseShiftLeft => PredefImpl.shiftLeft(a, b) + case Const.Plus => a.add(b) + case Const.Minus => a.subtract(b) + case Const.Times => a.multiply(b) + case Const.Div => PredefImpl.divBigInteger(a, b) + case Const.Mod => PredefImpl.modBigInteger(a, b) + case Const.BitwiseAnd => a.and(b) + case Const.BitwiseOr => a.or(b) + case Const.BitwiseXor => a.xor(b) + case Const.BitwiseShiftLeft => PredefImpl.shiftLeft(a, b) case Const.BitwiseShiftRight => PredefImpl.shiftRight(a, b) } } @@ -869,13 +932,36 @@ object Code { "[_A-Za-z][_0-9A-Za-z]*".r.pattern val pyKeywordList: Set[String] = Set( - "and", "del", "from", "not", "while", - "as", "elif", "global", "or", "with", - "assert", "else", "if", "pass", "yield", - "break", "except", "import", "print", - "class", "exec", "in", "raise", - "continue", "finally", "is", "return", - "def", "for", "lambda", "try" + "and", + "del", + "from", + "not", + "while", + "as", + "elif", + "global", + "or", + "with", + "assert", + "else", + "if", + "pass", + "yield", + "break", + "except", + "import", + "print", + "class", + "exec", + "in", + "raise", + "continue", + "finally", + "is", + "return", + "def", + "for", + "lambda", + "try" ) } - diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index c3667dda8..a3572718e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -3,7 +3,14 @@ package org.bykn.bosatsu.codegen.python import cats.Monad import cats.data.{NonEmptyList, State} import cats.parse.{Parser => P} -import org.bykn.bosatsu.{PackageName, Identifier, Matchless, Par, Parser, RecursionKind} +import org.bykn.bosatsu.{ + PackageName, + Identifier, + Matchless, + Par, + Parser, + RecursionKind +} import org.bykn.bosatsu.rankn.Type import org.typelevel.paiges.Doc @@ -40,18 +47,24 @@ object PythonGen { private object Impl { case class EnvState( - imports: Map[Module, Code.Ident], - bindings: Map[Bindable, (Int, List[Code.Ident])], - tops: Set[Bindable], - nextTmp: Long) { - - private def bindInc(b: Bindable, inc: Int)(fn: Int => Code.Ident): (EnvState, Code.Ident) = { + imports: Map[Module, Code.Ident], + bindings: Map[Bindable, (Int, List[Code.Ident])], + tops: Set[Bindable], + nextTmp: Long + ) { + + private def bindInc(b: Bindable, inc: Int)( + fn: Int => Code.Ident + ): (EnvState, Code.Ident) = { val (c, s) = bindings.getOrElse(b, (0, Nil)) val pname = fn(c) - (copy( - bindings = bindings.updated(b, (c + inc, pname :: s)) - ), pname) + ( + copy( + bindings = bindings.updated(b, (c + inc, pname :: s)) + ), + pname + ) } def bind(b: Bindable): (EnvState, Code.Ident) = @@ -68,11 +81,13 @@ object PythonGen { // see if we are shadowing, or top level bindings.get(b) match { case Some((_, h :: _)) => h - case _ if tops(b) => escape(b) - case other => + case _ if tops(b) => escape(b) + case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"unexpected deref: $b with bindings: $other") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"unexpected deref: $b with bindings: $other" + ) + // $COVERAGE-ON$ } def unbind(b: Bindable): EnvState = @@ -81,11 +96,14 @@ object PythonGen { copy(bindings = bindings.updated(b, (cnt, tail))) case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"invalid scope: $other for $b with $bindings") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"invalid scope: $other for $b with $bindings" + ) + // $COVERAGE-ON$ } - def getNextTmp: (EnvState, Long) = (copy(nextTmp = nextTmp + 1L), nextTmp) + def getNextTmp: (EnvState, Long) = + (copy(nextTmp = nextTmp + 1L), nextTmp) def topLevel(b: Bindable): (EnvState, Code.Ident) = (copy(tops = tops + b), escape(b)) @@ -95,13 +113,14 @@ object PythonGen { case Some(alias) => (this, alias) case None => val impNumber = imports.size - val alias = Code.Ident(escapeRaw("___i", mod.last.name + impNumber.toString)) + val alias = Code.Ident( + escapeRaw("___i", mod.last.name + impNumber.toString) + ) (copy(imports = imports.updated(mod, alias)), alias) } def importStatements: List[Code.Import] = - imports - .iterator + imports.iterator .map { case (path, alias) => val modName = path.map(_.name).toList.mkString(".") Code.Import(modName, Some(alias)) @@ -163,7 +182,8 @@ object PythonGen { Env.pure(Code.Ident(s"___a$long")) def newAssignableVar: Env[Code.Ident] = - Impl.env(_.getNextTmp) + Impl + .env(_.getNextTmp) .map { long => Code.Ident(s"___t$long") } @@ -182,8 +202,14 @@ object PythonGen { def topLevelName(n: Bindable): Env[Code.Ident] = Impl.env(_.topLevel(n)) - def onLastsM(cs: List[ValueLike])(fn: List[Expression] => Env[ValueLike]): Env[ValueLike] = { - def loop(cs: List[ValueLike], setup: List[Statement], args: List[Expression]): Env[ValueLike] = + def onLastsM( + cs: List[ValueLike] + )(fn: List[Expression] => Env[ValueLike]): Env[ValueLike] = { + def loop( + cs: List[ValueLike], + setup: List[Statement], + args: List[Expression] + ): Env[ValueLike] = cs match { case Nil => val res = fn(args.reverse) @@ -191,11 +217,11 @@ object PythonGen { case None => res case Some(nel) => val stmts = nel.reverse - val stmt = Code.block(stmts.head, stmts.tail :_*) + val stmt = Code.block(stmts.head, stmts.tail: _*) res.map(stmt.withValue(_)) } - case (e: Expression) :: t => loop(t, setup, e :: args) - case (ifelse@IfElse(_, _)) :: tail => + case (e: Expression) :: t => loop(t, setup, e :: args) + case (ifelse @ IfElse(_, _)) :: tail => // we allocate a result and assign // the result on each value Env.newAssignableVar.flatMap { v => @@ -208,31 +234,44 @@ object PythonGen { loop(cs, Nil, Nil) } - def onLasts(cs: List[ValueLike])(fn: List[Expression] => ValueLike): Env[ValueLike] = + def onLasts(cs: List[ValueLike])( + fn: List[Expression] => ValueLike + ): Env[ValueLike] = onLastsM(cs)(fn.andThen(Env.pure(_))) - def onLastM(c: ValueLike)(fn: Expression => Env[ValueLike]): Env[ValueLike] = + def onLastM( + c: ValueLike + )(fn: Expression => Env[ValueLike]): Env[ValueLike] = onLastsM(c :: Nil) { case x :: Nil => fn(x) - case other => + case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"expected list to have size 1: $other") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"expected list to have size 1: $other" + ) + // $COVERAGE-ON$ } def onLast(c: ValueLike)(fn: Expression => ValueLike): Env[ValueLike] = onLastM(c)(fn.andThen(Env.pure(_))) - def onLast2(c1: ValueLike, c2: ValueLike)(fn: (Expression, Expression) => ValueLike): Env[ValueLike] = + def onLast2(c1: ValueLike, c2: ValueLike)( + fn: (Expression, Expression) => ValueLike + ): Env[ValueLike] = onLasts(c1 :: c2 :: Nil) { case x1 :: x2 :: Nil => fn(x1, x2) - case other => + case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"expected list to have size 2: $other") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"expected list to have size 2: $other" + ) + // $COVERAGE-ON$ } - def ifElse(conds: NonEmptyList[(ValueLike, ValueLike)], elseV: ValueLike): Env[ValueLike] = { + def ifElse( + conds: NonEmptyList[(ValueLike, ValueLike)], + elseV: ValueLike + ): Env[ValueLike] = { // for all the non-expression conditions, we need to defer evaluating them // until they are really needed conds match { @@ -268,7 +307,11 @@ object PythonGen { } } - def ifElseS(cond: ValueLike, thenS: Statement, elseS: Statement): Env[Statement] = + def ifElseS( + cond: ValueLike, + thenS: Statement, + elseS: Statement + ): Env[Statement] = cond match { case x: Expression => Env.pure(Code.ifElseS(x, thenS, elseS)) case WithValue(stmt, vl) => @@ -287,7 +330,8 @@ object PythonGen { def andCode(c1: ValueLike, c2: ValueLike): Env[ValueLike] = (c1, c2) match { - case (t: Expression, c2) if t.simplify == Code.Const.True => Env.pure(c2) + case (t: Expression, c2) if t.simplify == Code.Const.True => + Env.pure(c2) case (_, x2: Expression) => onLast(c1)(_.evalAnd(x2)) case _ => @@ -300,7 +344,8 @@ object PythonGen { res <- Env.newAssignableVar ifstmt <- ifElseS(x1, res := c2, Code.Pass) } yield { - Code.block( + Code + .block( res := Code.Const.False, ifstmt ) @@ -309,54 +354,67 @@ object PythonGen { } } - def makeDef(defName: Code.Ident, arg: NonEmptyList[Code.Ident], v: ValueLike): Code.Def = + def makeDef( + defName: Code.Ident, + arg: NonEmptyList[Code.Ident], + v: ValueLike + ): Code.Def = Code.Def(defName, arg.toList, toReturn(v)) - def replaceTailCallWithAssign(name: Ident, argSize: Int, body: ValueLike)(onArgs: List[Expression] => Statement): Env[ValueLike] = { + def replaceTailCallWithAssign(name: Ident, argSize: Int, body: ValueLike)( + onArgs: List[Expression] => Statement + ): Env[ValueLike] = { val initBody = body def loop(body: ValueLike): Env[ValueLike] = body match { - case a@Apply(fn0, args0) => + case a @ Apply(fn0, args0) => if (fn0 == name) { if (args0.length == argSize) { val all = onArgs(args0) // set all the values and return the empty tuple Env.pure(all.withValue(Code.Const.Unit)) - } - else { + } else { // $COVERAGE-OFF$ - throw new IllegalStateException(s"expected a tailcall for $name in $initBody, but found: $a") + throw new IllegalStateException( + s"expected a tailcall for $name in $initBody, but found: $a" + ) // $COVERAGE-ON$ } - } - else { + } else { Env.pure(a) } case Parens(p) => loop(p).flatMap(onLast(_)(Parens(_))) case IfElse(ifCases, elseCase) => // only the result types are in tail position, we don't need to recurse on conds - val ifs = ifCases.traverse { case (cond, res) => loop(res).map((cond, _)) } + val ifs = ifCases.traverse { case (cond, res) => + loop(res).map((cond, _)) + } (ifs, loop(elseCase)) .mapN(ifElse(_, _)) .flatten case Ternary(ifTrue, cond, ifFalse) => // both results are in the tail position - (loop(ifTrue), loop(ifFalse)) - .mapN { (t, f) => - ifElse(NonEmptyList.one((cond, t)), f) - } - .flatten + (loop(ifTrue), loop(ifFalse)).mapN { (t, f) => + ifElse(NonEmptyList.one((cond, t)), f) + }.flatten case WithValue(stmt, v) => loop(v).map(stmt.withValue(_)) // the rest cannot have a call in the tail position - case DotSelect(_, _) | Op(_, _, _) | Lambda(_, _) | MakeTuple(_) | MakeList(_) | SelectItem(_, _) | SelectRange(_, _, _) | Ident(_) | PyBool(_) | PyString(_) | PyInt(_) => Env.pure(body) + case DotSelect(_, _) | Op(_, _, _) | Lambda(_, _) | MakeTuple(_) | + MakeList(_) | SelectItem(_, _) | SelectRange(_, _, _) | Ident(_) | + PyBool(_) | PyString(_) | PyInt(_) => + Env.pure(body) } loop(initBody) } // these are always recursive so we can use def to define them - def buildLoop(selfName: Ident, fnMutArgs: NonEmptyList[(Ident, Ident)], body: ValueLike): Env[Statement] = { + def buildLoop( + selfName: Ident, + fnMutArgs: NonEmptyList[(Ident, Ident)], + body: ValueLike + ): Env[Statement] = { /* * bodyUpdate = body except App(foo, args) is replaced with * reseting the inputs, and setting cont to True and having @@ -378,17 +436,14 @@ object PythonGen { // we could mutate a variable a later expression depends on // some times we generate code that does x = x, remove those cases val (left, right) = - mutArgs.toList.zip(args) - .filter { case (x, y) => x != y } - .unzip + mutArgs.toList.zip(args).filter { case (x, y) => x != y }.unzip Code.block( cont := Const.True, if (left.isEmpty) Pass else if (left.lengthCompare(1) == 0) { left.head := right.head - } - else { + } else { (MakeTuple(left) := MakeTuple(right)) } ) @@ -399,7 +454,9 @@ object PythonGen { ac = assignMut(cont)(fnArgs.toList) res <- Env.newAssignableVar ar = (res := Code.Const.Unit) - body1 <- replaceTailCallWithAssign(selfName, mutArgs.length, body)(assignMut(cont)) + body1 <- replaceTailCallWithAssign(selfName, mutArgs.length, body)( + assignMut(cont) + ) setRes = res := body1 loop = While(cont, (cont := false) +: setRes) newBody = (ac +: ar +: loop).withValue(res) @@ -408,10 +465,10 @@ object PythonGen { } - private[this] val base62Items = (('0' to '9') ++ ('A' to 'Z') ++ ('a' to 'z')).toSet + private[this] val base62Items = + (('0' to '9') ++ ('A' to 'Z') ++ ('a' to 'z')).toSet private def toBase62(c: Char): String = - if (base62Items(c)) c.toString else if (c == '_') "__" else { @@ -420,8 +477,7 @@ object PythonGen { // $COVERAGE-OFF$ sys.error(s"invalid in: $i0") // $COVERAGE-ON$ - } - else if (i0 < 10) (i0 + '0'.toInt).toChar + } else if (i0 < 10) (i0 + '0'.toInt).toChar else if (i0 < 36) (i0 - 10 + 'A'.toInt).toChar else if (i0 < 62) (i0 - 36 + 'a'.toInt).toChar else { @@ -445,11 +501,15 @@ object PythonGen { private def escapeRaw(prefix: String, str: String): String = str.map(toBase62).mkString(prefix, "", "") - private def unBase62(str: String, offset: Int, bldr: java.lang.StringBuilder): Int = { + private def unBase62( + str: String, + offset: Int, + bldr: java.lang.StringBuilder + ): Int = { var idx = offset var num = 0 - while(idx < str.length) { + while (idx < str.length) { val c = str.charAt(idx) idx += 1 if (c == '_') { @@ -458,14 +518,12 @@ object PythonGen { val numC = num.toChar bldr.append(numC) return (idx - offset) - } - else { + } else { // "__" decodes to "_" bldr.append('_') return (idx - offset) } - } - else { + } else { val base = if (c <= '9') '0'.toInt else if (c <= 'Z') ('A'.toInt - 10) @@ -489,7 +547,10 @@ object PythonGen { // ___b: shadowable (internal) names def escape(n: Bindable): Code.Ident = { val str = n.asString - if (!str.startsWith("___") && Code.python2Name.matcher(str).matches && !Code.pyKeywordList(str)) Code.Ident(str) + if ( + !str.startsWith("___") && Code.python2Name.matcher(str).matches && !Code + .pyKeywordList(str) + ) Code.Ident(str) else { // we need to escape Code.Ident(escapeRaw("___n", str)) @@ -497,7 +558,10 @@ object PythonGen { } def escapeModule(str: String): Code.Ident = { - if (!str.startsWith("___") && Code.python2Name.matcher(str).matches && !Code.pyKeywordList(str)) Code.Ident(str) + if ( + !str.startsWith("___") && Code.python2Name.matcher(str).matches && !Code + .pyKeywordList(str) + ) Code.Ident(str) else { // we need to escape Code.Ident(escapeRaw("___m", str)) @@ -518,15 +582,13 @@ object PythonGen { else { idx += res } - } - else { + } else { bldr.append(c) } } bldr.toString() - } - else { + } else { str } @@ -538,25 +600,27 @@ object PythonGen { } } - /** - * Remap is used to handle remapping external values - */ - private def apply(packName: PackageName, name: Bindable, me: Expr)(remap: (PackageName, Bindable) => Env[Option[ValueLike]]): Env[Statement] = { + /** Remap is used to handle remapping external values + */ + private def apply(packName: PackageName, name: Bindable, me: Expr)( + remap: (PackageName, Bindable) => Env[Option[ValueLike]] + ): Env[Statement] = { val ops = new Impl.Ops(packName, remap) // if we have a top level let rec with the same name, handle it more cleanly me match { - case Let(Right((n1, RecursionKind.NonRecursive)), inner, Local(n2)) if ((n1 === name) && (n2 === name)) => + case Let(Right((n1, RecursionKind.NonRecursive)), inner, Local(n2)) + if ((n1 === name) && (n2 === name)) => // we can just bind now at the top level for { nm <- Env.topLevelName(name) res <- inner match { case fn: FnExpr => ops.topFn(nm, fn, None) - case _ => ops.loop(inner, None).map(nm := _) + case _ => ops.loop(inner, None).map(nm := _) } } yield res case Let(Right((n1, RecursionKind.Recursive)), fn: FnExpr, Local(n2)) - if (n1 === name) && (n2 === name) => + if (n1 === name) && (n2 === name) => for { nm <- Env.topLevelName(name) res <- ops.topFn(nm, fn, None) @@ -582,10 +646,11 @@ object PythonGen { // def test_all(self): // # iterate through making assertions // - (Env.importLiteral(NonEmptyList.one(Code.Ident("unittest"))), + ( + Env.importLiteral(NonEmptyList.one(Code.Ident("unittest"))), Env.newAssignableVar, Env.topLevelName(name) - ) + ) .mapN { (importedName, tmpVar, testName) => import Impl._ @@ -598,48 +663,54 @@ object PythonGen { // Assertion(bool, msg) val testAssertion: Code.Statement = - Code.Call(Code.Apply(selfName.dot(Code.Ident("assertTrue")), - argName.get(1) :: argName.get(2) :: Nil)) + Code.Call( + Code.Apply( + selfName.dot(Code.Ident("assertTrue")), + argName.get(1) :: argName.get(2) :: Nil + ) + ) // TestSuite(suiteName, tests) val testSuite: Code.Statement = Code.block( tmpVar := argName.get(2), // get the test list - Code.While(isNonEmpty(tmpVar), + Code.While( + isNonEmpty(tmpVar), Code.block( Code.Call(Code.Apply(loopName, headList(tmpVar) :: Nil)), tmpVar := tailList(tmpVar) ) ) - ) + ) val loopBody: Code.Statement = Code.IfStatement( NonEmptyList.one((isAssertion, testAssertion)), - Some(testSuite)) + Some(testSuite) + ) val recTest = - Code.Def( - loopName, - argName :: Nil, - loopBody) + Code.Def(loopName, argName :: Nil, loopBody) val body = - Code.block( - recTest, - Code.Call(Code.Apply(loopName, testName :: Nil))) + Code.block(recTest, Code.Call(Code.Apply(loopName, testName :: Nil))) val defBody = - Code.Def(Code.Ident("test_all"), - selfName :: Nil, - body) + Code.Def(Code.Ident("test_all"), selfName :: Nil, body) - Code.ClassDef(Code.Ident("BosatsuTests"), List(importedName.dot(Code.Ident("TestCase"))), - defBody) + Code.ClassDef( + Code.Ident("BosatsuTests"), + List(importedName.dot(Code.Ident("TestCase"))), + defBody + ) } } - private def addMainEval(name: Bindable, mod: Module, ci: Code.Ident): Env[Statement] = + private def addMainEval( + name: Bindable, + mod: Module, + ci: Code.Ident + ): Env[Statement] = /* * this does: * if __name__ == "__main__": @@ -672,7 +743,10 @@ object PythonGen { Parser.dictLikeParser(Identifier.bindableParser, modParser) val outer: P[List[(PackageName, List[(Bindable, (Module, Code.Ident))])]] = - Parser.maybeSpacesAndLines.with1 *> Parser.dictLikeParser(PackageName.parser, inner) <* Parser.maybeSpacesAndLines + Parser.maybeSpacesAndLines.with1 *> Parser.dictLikeParser( + PackageName.parser, + inner + ) <* Parser.maybeSpacesAndLines outer.map { items => items.flatMap { case (p, bs) => @@ -684,27 +758,31 @@ object PythonGen { // parses a map of of evaluators // { fullyqualifiedType: foo.bar.baz, } val evaluatorParser: P[List[(Type, (Module, Code.Ident))]] = - Parser.maybeSpacesAndLines.with1 *> Parser.dictLikeParser(Type.fullyResolvedParser, modParser) <* Parser.maybeSpacesAndLines + Parser.maybeSpacesAndLines.with1 *> Parser.dictLikeParser( + Type.fullyResolvedParser, + modParser + ) <* Parser.maybeSpacesAndLines // compile a set of packages given a set of external remappings def renderAll( - pm: Map[PackageName, List[(Bindable, Expr)]], - externals: Map[(PackageName, Bindable), (Module, Code.Ident)], - tests: Map[PackageName, Bindable], - evaluators: Map[PackageName, (Bindable, Module, Code.Ident)])(implicit ec: Par.EC): Map[PackageName, (Module, Doc)] = { - - val externalRemap: (PackageName, Bindable) => Env[Option[ValueLike]] = - { (p, b) => + pm: Map[PackageName, List[(Bindable, Expr)]], + externals: Map[(PackageName, Bindable), (Module, Code.Ident)], + tests: Map[PackageName, Bindable], + evaluators: Map[PackageName, (Bindable, Module, Code.Ident)] + )(implicit ec: Par.EC): Map[PackageName, (Module, Doc)] = { + + val externalRemap: (PackageName, Bindable) => Env[Option[ValueLike]] = { + (p, b) => externals.get((p, b)) match { case None => Env.pure(None) case Some((m, i)) => - Env.importLiteral(m) + Env + .importLiteral(m) .map { alias => Some(Code.DotSelect(alias, i)) } } - } + } - val all = pm - .toList + val all = pm.toList .traverse { case (p, lets) => Par.start { val stmts0: Env[List[Statement]] = @@ -714,7 +792,9 @@ object PythonGen { } val evalStmt: Env[Option[Statement]] = - evaluators.get(p).traverse { case (b, m, c) => addMainEval(b, m, c) } + evaluators.get(p).traverse { case (b, m, c) => + addMainEval(b, m, c) + } val testStmt: Env[Option[Statement]] = tests.get(p).traverse(addUnitTest) @@ -763,92 +843,172 @@ object PythonGen { lst.get(2).simplify object PredefExternal { - private val cmpFn: List[ValueLike] => Env[ValueLike] = { - input => - Env.onLast2(input.head, input.tail.head) { (arg0, arg1) => - Code.Ternary( - 0, - arg0 :< arg1, - Code.Ternary(1, arg0 =:= arg1, 2) - ).simplify - } + private val cmpFn: List[ValueLike] => Env[ValueLike] = { input => + Env.onLast2(input.head, input.tail.head) { (arg0, arg1) => + Code + .Ternary( + 0, + arg0 :< arg1, + Code.Ternary(1, arg0 =:= arg1, 2) + ) + .simplify + } } val results: Map[Bindable, (List[ValueLike] => Env[ValueLike], Int)] = Map( - (Identifier.unsafeBindable("add"), + ( + Identifier.unsafeBindable("add"), ( - input => Env.onLast2(input.head, input.tail.head)(_.evalPlus(_)) - , 2)), - (Identifier.unsafeBindable("sub"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.evalMinus(_)) - } , 2)), - (Identifier.unsafeBindable("times"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.evalTimes(_)) - }, 2)), - (Identifier.unsafeBindable("div"), - ({ - input => Env.onLast2(input.head, input.tail.head) { (a, b) => - Code.Ternary( - Code.Op(a, Code.Const.Div, b), - b, // 0 is false in python - 0 - ).simplify - } - }, 2)), - (Identifier.unsafeBindable("mod_Int"), - ({ - input => Env.onLast2(input.head, input.tail.head) { (a, b) => - Code.Ternary( - Code.Op(a, Code.Const.Mod, b), - b, // 0 is false in python - a - ).simplify - } - }, 2)), + input => Env.onLast2(input.head, input.tail.head)(_.evalPlus(_)), + 2 + ) + ), + ( + Identifier.unsafeBindable("sub"), + ( + { input => + Env.onLast2(input.head, input.tail.head)(_.evalMinus(_)) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("times"), + ( + { input => + Env.onLast2(input.head, input.tail.head)(_.evalTimes(_)) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("div"), + ( + { input => + Env.onLast2(input.head, input.tail.head) { (a, b) => + Code + .Ternary( + Code.Op(a, Code.Const.Div, b), + b, // 0 is false in python + 0 + ) + .simplify + } + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("mod_Int"), + ( + { input => + Env.onLast2(input.head, input.tail.head) { (a, b) => + Code + .Ternary( + Code.Op(a, Code.Const.Mod, b), + b, // 0 is false in python + a + ) + .simplify + } + }, + 2 + ) + ), (Identifier.unsafeBindable("cmp_Int"), (cmpFn, 2)), - (Identifier.unsafeBindable("eq_Int"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.eval(Code.Const.Eq, _)) - }, 2)), - - (Identifier.unsafeBindable("shift_left_Int"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.eval(Code.Const.BitwiseShiftLeft, _)) - }, 2)), - (Identifier.unsafeBindable("shift_right_Int"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.eval(Code.Const.BitwiseShiftRight, _)) - }, 2)), - (Identifier.unsafeBindable("and_Int"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.eval(Code.Const.BitwiseAnd, _)) - }, 2)), - (Identifier.unsafeBindable("or_Int"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.eval(Code.Const.BitwiseOr, _)) - }, 2)), - (Identifier.unsafeBindable("xor_Int"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.eval(Code.Const.BitwiseXor, _)) - }, 2)), - (Identifier.unsafeBindable("not_Int"), - ({ - // leverage not(x) == -1 - x - input => Env.onLast(input.head)(Code.fromInt(-1).evalMinus(_)) - }, 2)), - (Identifier.unsafeBindable("gcd_Int"), - ({ - input => - (Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar) - .mapN { (tmpa, tmpb, tmpc) => - Env.onLast2(input.head, input.tail.head) { (a, b) => - Code.block( + ( + Identifier.unsafeBindable("eq_Int"), + ( + { input => + Env.onLast2(input.head, input.tail.head)( + _.eval(Code.Const.Eq, _) + ) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("shift_left_Int"), + ( + { input => + Env.onLast2(input.head, input.tail.head)( + _.eval(Code.Const.BitwiseShiftLeft, _) + ) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("shift_right_Int"), + ( + { input => + Env.onLast2(input.head, input.tail.head)( + _.eval(Code.Const.BitwiseShiftRight, _) + ) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("and_Int"), + ( + { input => + Env.onLast2(input.head, input.tail.head)( + _.eval(Code.Const.BitwiseAnd, _) + ) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("or_Int"), + ( + { input => + Env.onLast2(input.head, input.tail.head)( + _.eval(Code.Const.BitwiseOr, _) + ) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("xor_Int"), + ( + { input => + Env.onLast2(input.head, input.tail.head)( + _.eval(Code.Const.BitwiseXor, _) + ) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("not_Int"), + ( + { + // leverage not(x) == -1 - x + input => Env.onLast(input.head)(Code.fromInt(-1).evalMinus(_)) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("gcd_Int"), + ( + { input => + ( + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar + ).mapN { (tmpa, tmpb, tmpc) => + Env.onLast2(input.head, input.tail.head) { (a, b) => + Code + .block( tmpa := a, tmpb := b, - Code.While(tmpb, + Code.While( + tmpb, Code.block( tmpc := tmpb, // we know b != 0 because we are in the while loop @@ -859,165 +1019,215 @@ object PythonGen { ) ) .withValue(tmpa) - } } - .flatten - }, 2)), - //external def int_loop(intValue: Int, state: a, fn: (Int, a) -> (Int, a) -> a - // def int_loop(i, a, fn): - // if i <= 0: a - // else: - // (i1, a1) = fn(i, a) - // if i <= i1: a - // else int_loop(i1, a, fn) - // - // def int_loop(i, a, fn): - // cont = (0 < i) - // res = a - // _i = i - // _a = a - // while cont: - // res = fn(_i, _a) - // tmp_i = res[0] - // _a = res[1][0] - // cont = (0 < tmp_i) and (tmp_i < _i) - // _i = tmp_i - // return _a - (Identifier.unsafeBindable("int_loop"), - ({ - input => - (Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar) - .tupled + }.flatten + }, + 2 + ) + ), + // external def int_loop(intValue: Int, state: a, fn: (Int, a) -> (Int, a) -> a + // def int_loop(i, a, fn): + // if i <= 0: a + // else: + // (i1, a1) = fn(i, a) + // if i <= i1: a + // else int_loop(i1, a, fn) + // + // def int_loop(i, a, fn): + // cont = (0 < i) + // res = a + // _i = i + // _a = a + // while cont: + // res = fn(_i, _a) + // tmp_i = res[0] + // _a = res[1][0] + // cont = (0 < tmp_i) and (tmp_i < _i) + // _i = tmp_i + // return _a + ( + Identifier.unsafeBindable("int_loop"), + ( + { input => + ( + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar + ).tupled .flatMap { case (cont, res, _i, _a, tmp_i) => Env.onLasts(input) { case i :: a :: fn :: Nil => - Code.block( - cont := (Code.fromInt(0) :< i), - res := a, - _i := i, - _a := a, - Code.While(cont, { - fn(_i, _a).simplify match { - case Code.MakeTuple(fst :: snd :: Nil) => - // inline the tuple allocation and destructuring - Code.block( - tmp_i := fst, - _a := snd, - cont := (Code.fromInt(0) :< tmp_i).evalAnd(tmp_i :< _i), - _i := tmp_i - ) - case notTup => - Code.block( - res := notTup, - tmp_i := res.get(0), - _a := res.get(1), - cont := (Code.fromInt(0) :< tmp_i).evalAnd(tmp_i :< _i), - _i := tmp_i - ) - } - } + Code + .block( + cont := (Code.fromInt(0) :< i), + res := a, + _i := i, + _a := a, + Code.While( + cont, { + fn(_i, _a).simplify match { + case Code.MakeTuple(fst :: snd :: Nil) => + // inline the tuple allocation and destructuring + Code.block( + tmp_i := fst, + _a := snd, + cont := (Code.fromInt(0) :< tmp_i) + .evalAnd(tmp_i :< _i), + _i := tmp_i + ) + case notTup => + Code.block( + res := notTup, + tmp_i := res.get(0), + _a := res.get(1), + cont := (Code.fromInt(0) :< tmp_i) + .evalAnd(tmp_i :< _i), + _i := tmp_i + ) + } + } + ) ) - ) - .withValue(_a) + .withValue(_a) case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"expected arity 3 got: $other") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"expected arity 3 got: $other" + ) + // $COVERAGE-ON$ } } - }, 3)), - (Identifier.unsafeBindable("concat_String"), - ({ input => - Env.onLastM(input.head) { listOfStrings => - // convert to python list, then call "".join(seq) - Env.newAssignableVar - .flatMap { pyList => - bosatsuListToPython(pyList, listOfStrings) - .map { loop => - Code.block( - pyList := Code.MakeList(Nil), - loop - ) - .withValue { - Code.PyString("").dot(Code.Ident("join"))(pyList) - } + }, + 3 + ) + ), + ( + Identifier.unsafeBindable("concat_String"), + ( + { input => + Env.onLastM(input.head) { listOfStrings => + // convert to python list, then call "".join(seq) + Env.newAssignableVar + .flatMap { pyList => + bosatsuListToPython(pyList, listOfStrings) + .map { loop => + Code + .block( + pyList := Code.MakeList(Nil), + loop + ) + .withValue { + Code.PyString("").dot(Code.Ident("join"))(pyList) + } + } } } - } - }, 1)), - (Identifier.unsafeBindable("int_to_String"), - ({ - input => Env.onLast(input.head) { - case Code.PyInt(i) => Code.PyString(i.toString) - case i => Code.Apply(Code.DotSelect(i, Code.Ident("__str__")), Nil) - } - }, 1)), - (Identifier.unsafeBindable("char_to_String"), - // we encode chars as strings so this is just identity - ({ input => Env.envMonad.pure(input.head) }, 1)), - (Identifier.unsafeBindable("trace"), - ({ - input => Env.onLast2(input.head, input.tail.head) { (msg, i) => - Code.Call(Code.Apply(Code.Ident("print"), msg :: i :: Nil)) - .withValue(i) - } - }, 2)), - (Identifier.unsafeBindable("partition_String"), - ({ - input => - Env.newAssignableVar - .flatMap { res => - Env.onLast2(input.head, input.tail.head) { (str, sep) => - // if sep == "": None - // else: - // (a, s1, b) = str.partition(sep) - // if s1: (1, (a, b)) - // else: (0, ) - val a = res.get(0) - val s1 = res.get(1) - val b = res.get(2) - val success = Code.MakeTuple(Code.fromInt(1) :: - Code.MakeTuple(a :: b :: Nil) :: - Nil - ) - val fail = Code.MakeTuple(Code.fromInt(0) :: Nil) - val nonEmpty = - (res := str.dot(Code.Ident("partition"))(sep)) - .withValue(Code.Ternary(success, s1, fail)) - - Code.IfElse(NonEmptyList.one((sep, nonEmpty)), fail) + }, + 1 + ) + ), + ( + Identifier.unsafeBindable("int_to_String"), + ( + { input => + Env.onLast(input.head) { + case Code.PyInt(i) => Code.PyString(i.toString) + case i => + Code.Apply(Code.DotSelect(i, Code.Ident("__str__")), Nil) } - } - }, 2)), - (Identifier.unsafeBindable("rpartition_String"), - ({ - input => - Env.newAssignableVar - .flatMap { res => - Env.onLast2(input.head, input.tail.head) { (str, sep) => - // (a, s1, b) = str.partition(sep) - // if s1: (1, (a, b)) - // else: (0, ) - val a = res.get(0) - val s1 = res.get(1) - val b = res.get(2) - val success = Code.MakeTuple(Code.fromInt(1) :: - Code.MakeTuple(a :: b :: Nil) :: - Nil - ) - val fail = Code.MakeTuple(Code.fromInt(0) :: Nil) - val nonEmpty = - (res := str.dot(Code.Ident("rpartition"))(sep)) - .withValue(Code.Ternary(success, s1, fail)) - - Code.IfElse(NonEmptyList.one((sep, nonEmpty)), fail) + }, + 1 + ) + ), + ( + Identifier.unsafeBindable("char_to_String"), + // we encode chars as strings so this is just identity + ({ input => Env.envMonad.pure(input.head) }, 1) + ), + ( + Identifier.unsafeBindable("trace"), + ( + { input => + Env.onLast2(input.head, input.tail.head) { (msg, i) => + Code + .Call(Code.Apply(Code.Ident("print"), msg :: i :: Nil)) + .withValue(i) } - } - }, 2)), - (Identifier.unsafeBindable("string_Order_fn"), (cmpFn, 2)) - ) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("partition_String"), + ( + { input => + Env.newAssignableVar + .flatMap { res => + Env.onLast2(input.head, input.tail.head) { (str, sep) => + // if sep == "": None + // else: + // (a, s1, b) = str.partition(sep) + // if s1: (1, (a, b)) + // else: (0, ) + val a = res.get(0) + val s1 = res.get(1) + val b = res.get(2) + val success = Code.MakeTuple( + Code.fromInt(1) :: + Code.MakeTuple(a :: b :: Nil) :: + Nil + ) + val fail = Code.MakeTuple(Code.fromInt(0) :: Nil) + val nonEmpty = + (res := str.dot(Code.Ident("partition"))(sep)) + .withValue(Code.Ternary(success, s1, fail)) + + Code.IfElse(NonEmptyList.one((sep, nonEmpty)), fail) + } + } + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("rpartition_String"), + ( + { input => + Env.newAssignableVar + .flatMap { res => + Env.onLast2(input.head, input.tail.head) { (str, sep) => + // (a, s1, b) = str.partition(sep) + // if s1: (1, (a, b)) + // else: (0, ) + val a = res.get(0) + val s1 = res.get(1) + val b = res.get(2) + val success = Code.MakeTuple( + Code.fromInt(1) :: + Code.MakeTuple(a :: b :: Nil) :: + Nil + ) + val fail = Code.MakeTuple(Code.fromInt(0) :: Nil) + val nonEmpty = + (res := str.dot(Code.Ident("rpartition"))(sep)) + .withValue(Code.Ternary(success, s1, fail)) - def bosatsuListToPython(pyList: Code.Ident, bList: Expression): Env[Statement] = + Code.IfElse(NonEmptyList.one((sep, nonEmpty)), fail) + } + } + }, + 2 + ) + ), + (Identifier.unsafeBindable("string_Order_fn"), (cmpFn, 2)) + ) + + def bosatsuListToPython( + pyList: Code.Ident, + bList: Expression + ): Env[Statement] = Env.newAssignableVar .map { tmp => // tmp = bList @@ -1027,7 +1237,8 @@ object PythonGen { // tmp = tmp[2] Code.block( tmp := bList, - Code.While(isNonEmpty(tmp), + Code.While( + isNonEmpty(tmp), Code.block( Code.Call(pyList.dot(Code.Ident("append"))(headList(tmp))), tmp := tailList(tmp) @@ -1036,13 +1247,17 @@ object PythonGen { ) } - def unapply(expr: Expr): Option[(List[ValueLike] => Env[ValueLike], Int)] = + def unapply( + expr: Expr + ): Option[(List[ValueLike] => Env[ValueLike], Int)] = expr match { case Global(PackageName.PredefName, name) => results.get(name) - case _ => None + case _ => None } - def makeLambda(arity: Int)(fn: List[ValueLike] => Env[ValueLike]): Env[ValueLike] = + def makeLambda( + arity: Int + )(fn: List[ValueLike] => Env[ValueLike]): Env[ValueLike] = for { vars <- (1 to arity).toList.traverse(_ => Env.newAssignableVar) body <- fn(vars) @@ -1053,7 +1268,10 @@ object PythonGen { } yield res } - class Ops(packName: PackageName, remap: (PackageName, Bindable) => Env[Option[ValueLike]]) { + class Ops( + packName: PackageName, + remap: (PackageName, Bindable) => Env[Option[ValueLike]] + ) { /* * enums with no fields are integers * enums and structs are tuples @@ -1077,9 +1295,9 @@ object PythonGen { Env.onLasts(vExpr :: args)(Code.MakeTuple(_)) } case MakeStruct(arity) => - if (arity == 0) Env.pure(Code.Const.Unit) - else if (arity == 1) Env.pure(args.head) - else Env.onLasts(args)(Code.MakeTuple(_)) + if (arity == 0) Env.pure(Code.Const.Unit) + else if (arity == 1) Env.pure(args.head) + else Env.onLasts(args)(Code.MakeTuple(_)) case ZeroNat => Env.pure(Code.Const.Zero) case SuccNat => @@ -1094,8 +1312,7 @@ object PythonGen { // $COVERAGE-OFF$ throw new IllegalStateException(s"invalid arity $sz for $ce") // $COVERAGE-ON$ - } - else { + } else { // this is the case where we are using the constructor like a function assert(args.isEmpty) for { @@ -1138,58 +1355,68 @@ object PythonGen { loop(enumV, slotName).flatMap { tup => Env.onLast(tup) { t => (if (useInts) { - // this is represented as an integer - t =:= idx - } - else - t.get(0) =:= idx).simplify + // this is represented as an integer + t =:= idx + } else + t.get(0) =:= idx).simplify } } case SetMut(LocalAnonMut(mut), expr) => - (Env.nameForAnon(mut), loop(expr, slotName)) - .mapN { (ident, result) => + (Env.nameForAnon(mut), loop(expr, slotName)).mapN { + (ident, result) => Env.onLast(result) { resx => (ident := resx).withValue(Code.Const.True) } - } - .flatten + }.flatten case MatchString(str, pat, binds) => - (loop(str, slotName), binds.traverse { case LocalAnonMut(m) => Env.nameForAnon(m) }) - .mapN { (strVL, binds) => - Env.onLastM(strVL)(matchString(_, pat, binds)) - } - .flatten + ( + loop(str, slotName), + binds.traverse { case LocalAnonMut(m) => Env.nameForAnon(m) } + ).mapN { (strVL, binds) => + Env.onLastM(strVL)(matchString(_, pat, binds)) + }.flatten case SearchList(locMut, init, check, optLeft) => // check to see if we can find a non-empty // list that matches check - (loop(init, slotName), boolExpr(check, slotName)) - .mapN { (initVL, checkVL) => + (loop(init, slotName), boolExpr(check, slotName)).mapN { + (initVL, checkVL) => searchList(locMut, initVL, checkVL, optLeft) - } - .flatten + }.flatten } - def matchString(strEx: Expression, pat: List[StrPart], binds: List[Code.Ident]): Env[ValueLike] = { + def matchString( + strEx: Expression, + pat: List[StrPart], + binds: List[Code.Ident] + ): Env[ValueLike] = { import StrPart.{LitStr, Glob, CharPart} val bindArray = binds.toArray // return a value like expression that contains the boolean result // and assigns all the bindings along the way - def loop(offsetIdent: Code.Ident, pat: List[StrPart], next: Int): Env[ValueLike] = + def loop( + offsetIdent: Code.Ident, + pat: List[StrPart], + next: Int + ): Env[ValueLike] = pat match { case Nil => - //offset == str.length + // offset == str.length Env.pure(offsetIdent =:= strEx.len()) case LitStr(expect) :: tail => - //val len = expect.length - //str.regionMatches(offset, expect, 0, len) && loop(offset + len, tail, next) + // val len = expect.length + // str.regionMatches(offset, expect, 0, len) && loop(offset + len, tail, next) // // strEx.startswith(expect, offsetIdent) loop(offsetIdent, tail, next) .flatMap { loopRes => - val regionMatches = strEx.dot(Code.Ident("startswith"))(expect, offsetIdent) + val regionMatches = + strEx.dot(Code.Ident("startswith"))(expect, offsetIdent) val rest = ( - offsetIdent := offsetIdent + expect.codePointCount(0, expect.length) + offsetIdent := offsetIdent + expect.codePointCount( + 0, + expect.length + ) ).withValue(loopRes) Env.andCode(regionMatches, rest) @@ -1200,13 +1427,13 @@ object PythonGen { val stmt = if (c.capture) { // b = str[offset] - Code.block( - bindArray(next) := Code.SelectItem(strEx, offsetIdent), - offsetIdent := offsetIdent + 1 - ) - .withValue(true) - } - else (offsetIdent := offsetIdent + 1).withValue(true) + Code + .block( + bindArray(next) := Code.SelectItem(strEx, offsetIdent), + offsetIdent := offsetIdent + 1 + ) + .withValue(true) + } else (offsetIdent := offsetIdent + 1).withValue(true) for { tailRes <- loop(offsetIdent, tail, n1) and2 <- Env.andCode(stmt, tailRes) @@ -1219,10 +1446,10 @@ object PythonGen { Env.pure( if (h.capture) { // b = str[offset:] - (bindArray(next) := Code.SelectRange(strEx, Some(offsetIdent), None)) + (bindArray(next) := Code + .SelectRange(strEx, Some(offsetIdent), None)) .withValue(true) - } - else Code.Const.True + } else Code.Const.True ) case LitStr(expect) :: tail2 => // here we have to make a loop @@ -1262,88 +1489,111 @@ object PythonGen { } } result - */ - (Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar) - .mapN { (start, result, candidate, candOffset) => - val searchEnv = loop(candOffset, tail2, next1) - - def onSearch(search: ValueLike): Env[Statement] = - Env.ifElseS(search, - { - // we have matched - val capture = if (h.capture) (bindArray(next) := Code.SelectRange(strEx, Some(offsetIdent), Some(candidate))) else Code.Pass - Code.block( - capture, - result := true, - start := -1 - ) - }, - { - // we couldn't match at start, advance just after the - // candidate - start := candidate + 1 - }) - - def findBranch(search: ValueLike): Env[Statement] = - onSearch(search) - .flatMap { onS => - Env.ifElseS( - candidate :> -1, - // update candidate and search - Code.block( - candOffset := candidate + expect.codePointCount(0, expect.length), - onS) - , - // else no more candidates - start := -1 - ) - } + */ + ( + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar + ).mapN { (start, result, candidate, candOffset) => + val searchEnv = loop(candOffset, tail2, next1) + + def onSearch(search: ValueLike): Env[Statement] = + Env.ifElseS( + search, { + // we have matched + val capture = + if (h.capture) + (bindArray(next) := Code.SelectRange( + strEx, + Some(offsetIdent), + Some(candidate) + )) + else Code.Pass + Code.block( + capture, + result := true, + start := -1 + ) + }, { + // we couldn't match at start, advance just after the + // candidate + start := candidate + 1 + } + ) - for { - search <- searchEnv - find <- findBranch(search) - } yield - (Code.block( - start := offsetIdent, - result := false, - Code.While((start :> -1), + def findBranch(search: ValueLike): Env[Statement] = + onSearch(search) + .flatMap { onS => + Env.ifElseS( + candidate :> -1, + // update candidate and search Code.block( - candidate := strEx.dot(Code.Ident("find"))(expect, start), - find - ) + candOffset := candidate + expect.codePointCount( + 0, + expect.length + ), + onS + ), + // else no more candidates + start := -1 + ) + } + + for { + search <- searchEnv + find <- findBranch(search) + } yield (Code + .block( + start := offsetIdent, + result := false, + Code.While( + (start :> -1), + Code.block( + candidate := strEx + .dot(Code.Ident("find"))(expect, start), + find ) ) - .withValue(result)) - } - .flatten + ) + .withValue(result)) + }.flatten case (_: CharPart) :: _ => val next1 = if (h.capture) (next + 1) else next for { matched <- Env.newAssignableVar off1 <- Env.newAssignableVar - tailMatched <- loop(off1, tail, next1) - - matchStmt = Code.block( - matched := false, - off1 := offsetIdent, - Code.While((!matched).evalAnd(off1 :< strEx.len()), - matched := tailMatched // the tail match increments the + tailMatched <- loop(off1, tail, next1) + + matchStmt = Code + .block( + matched := false, + off1 := offsetIdent, + Code.While( + (!matched).evalAnd(off1 :< strEx.len()), + matched := tailMatched // the tail match increments the + ) ) - ).withValue(matched) + .withValue(matched) fullMatch <- if (!h.capture) Env.pure(matchStmt) else { - val capture = Code.block( - bindArray(next) := Code.SelectRange(strEx, Some(offsetIdent), Some(off1)) - ).withValue(true) + val capture = Code + .block( + bindArray(next) := Code + .SelectRange(strEx, Some(offsetIdent), Some(off1)) + ) + .withValue(true) Env.andCode(matchStmt, capture) } } yield fullMatch // $COVERAGE-OFF$ case (_: Glob) :: _ => - throw new IllegalArgumentException(s"pattern: $pat should have been prevented: adjacent globs are not permitted (one is always empty)") + throw new IllegalArgumentException( + s"pattern: $pat should have been prevented: adjacent globs are not permitted (one is always empty)" + ) // $COVERAGE-ON$ } } @@ -1354,7 +1604,12 @@ object PythonGen { } yield (offsetIdent := 0).withValue(res) } - def searchList(locMut: LocalAnonMut, initVL: ValueLike, checkVL: ValueLike, optLeft: Option[LocalAnonMut]): Env[ValueLike] = { + def searchList( + locMut: LocalAnonMut, + initVL: ValueLike, + checkVL: ValueLike, + optLeft: Option[LocalAnonMut] + ): Env[ValueLike] = { /* * here is the implementation from MatchlessToValue * @@ -1380,38 +1635,49 @@ object PythonGen { } res } - */ - (Env.nameForAnon(locMut.ident), optLeft.traverse { lm => Env.nameForAnon(lm.ident) }, Env.newAssignableVar, Env.newAssignableVar) - .mapN { (currentList, optLeft, res, tmpList) => + */ + ( + Env.nameForAnon(locMut.ident), + optLeft.traverse { lm => Env.nameForAnon(lm.ident) }, + Env.newAssignableVar, + Env.newAssignableVar + ) + .mapN { (currentList, optLeft, res, tmpList) => Code .block( res := Code.Const.False, tmpList := initVL, optLeft.fold(Code.pass)(_ := emptyList), // we don't match empty lists, so if currentList reaches Empty we are done - Code.While(isNonEmpty(tmpList), + Code.While( + isNonEmpty(tmpList), Code.block( - currentList := tmpList, - res := checkVL, - Code.ifElseS( - res, - tmpList := emptyList, - { - Code.block( - tmpList := tailList(tmpList), - optLeft.fold(Code.pass) { left => - left := consList(headList(currentList), left) - } - ) - }) + currentList := tmpList, + res := checkVL, + Code.ifElseS( + res, + tmpList := emptyList, { + Code.block( + tmpList := tailList(tmpList), + optLeft.fold(Code.pass) { left => + left := consList(headList(currentList), left) + } + ) + } ) ) - ).withValue(res) - } + ) + ) + .withValue(res) + } } // if expr is a LoopFn or Lambda handle it - def topFn(name: Code.Ident, expr: FnExpr, slotName: Option[Code.Ident]): Env[Statement] = + def topFn( + name: Code.Ident, + expr: FnExpr, + slotName: Option[Code.Ident] + ): Env[Statement] = expr match { case LoopFn(captures, _, args, b) => // note, name is already bound @@ -1438,19 +1704,22 @@ object PythonGen { // we can ignore name because python already allows recursion // we can use topLevelName on makeDefs since they are already // shadowing in the same rules as bosatsu - (args.traverse(Env.topLevelName(_)), makeSlots(captures, slotName)(loop(body, _))) - .mapN { - case (as, (slots, body)) => - Code.blockFromList( - slots.toList ::: + ( + args.traverse(Env.topLevelName(_)), + makeSlots(captures, slotName)(loop(body, _)) + ) + .mapN { case (as, (slots, body)) => + Code.blockFromList( + slots.toList ::: Env.makeDef(name, as, body) :: Nil - ) + ) } } - def makeSlots[A](captures: List[Expr], - slotName: Option[Code.Ident])(fn: Option[Code.Ident] => Env[A]): Env[(Option[Statement], A)] = + def makeSlots[A](captures: List[Expr], slotName: Option[Code.Ident])( + fn: Option[Code.Ident] => Env[A] + ): Env[(Option[Statement], A)] = if (captures.isEmpty) fn(None).map((None, _)) else { for { @@ -1467,7 +1736,10 @@ object PythonGen { // we ignore name because python already supports recursion // we can use topLevelName on makeDefs since they are already // shadowing in the same rules as bosatsu - (args.traverse(Env.topLevelName(_)), makeSlots(captures, slotName)(loop(res, _))) + ( + args.traverse(Env.topLevelName(_)), + makeSlots(captures, slotName)(loop(res, _)) + ) .flatMapN { case (args, (None, x: Expression)) => Env.pure(Code.Lambda(args.toList, x)) @@ -1500,7 +1772,9 @@ object PythonGen { loopRes <- Env.buildLoop(nameI, subs1, body) // we have bound the args twice: once as args, once as interal muts _ <- subs.traverse_ { case (a, _) => Env.unbind(a) } - } yield Code.blockFromList(prefix.toList :+ loopRes).withValue(nameI) + } yield Code + .blockFromList(prefix.toList :+ loopRes) + .withValue(nameI) case PredefExternal((fn, arity)) => // make a lambda @@ -1513,42 +1787,46 @@ object PythonGen { if (p == packName) { // This is just a name in the local package Env.topLevelName(n) - } - else { - (Env.importPackage(p), Env.topLevelName(n)).mapN(Code.DotSelect(_, _)) + } else { + (Env.importPackage(p), Env.topLevelName(n)) + .mapN(Code.DotSelect(_, _)) } } - case Local(b) => Env.deref(b) - case LocalAnon(a) => Env.nameForAnon(a) + case Local(b) => Env.deref(b) + case LocalAnon(a) => Env.nameForAnon(a) case LocalAnonMut(m) => Env.nameForAnon(m) case ClosureSlot(idx) => slotName match { case Some(ident) => Env.pure(ident.get(idx)) - case None => + case None => // $COVERAGE-OFF$ // this should be impossible for well formed Matchless AST - throw new IllegalStateException(s"saw $expr when there is no defined slot") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"saw $expr when there is no defined slot" + ) + // $COVERAGE-ON$ } case App(PredefExternal((fn, _)), args) => - args - .toList + args.toList .traverse(loop(_, slotName)) .flatMap(fn) case App(cons: ConsExpr, args) => - args.traverse(loop(_, slotName)).flatMap { pxs => makeCons(cons, pxs.toList) } + args.traverse(loop(_, slotName)).flatMap { pxs => + makeCons(cons, pxs.toList) + } case App(expr, args) => - (loop(expr, slotName), args.traverse(loop(_, slotName))) - .mapN { (fn, args) => + (loop(expr, slotName), args.traverse(loop(_, slotName))).mapN { + (fn, args) => Env.onLasts(fn :: args.toList) { case fn :: args => Code.Apply(fn, args) - case other => + case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"got $other, expected to match $expr") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"got $other, expected to match $expr" + ) + // $COVERAGE-ON$ } - } - .flatten + }.flatten case Let(localOrBind, fn: FnExpr, in) => val inF = loop(in, slotName) @@ -1565,7 +1843,8 @@ object PythonGen { } yield tl.withValue(ine) case Left(LocalAnon(l)) => // anonymous names never shadow - Env.nameForAnon(l) + Env + .nameForAnon(l) .flatMap { bi => val v = topFn(bi, fn, slotName) (v, inF).mapN(_.withValue(_)) @@ -1585,8 +1864,7 @@ object PythonGen { ine <- inF _ <- Env.unbind(b) } yield ((bi := v).withValue(ine)) - } - else { + } else { // value b is in scope after ve for { ve <- loop(notFn, slotName) @@ -1624,11 +1902,9 @@ object PythonGen { (boolExpr(c, slotName), loop(t, slotName)).tupled } - (ifsV, loop(last, slotName)) - .mapN { (ifs, elseV) => - Env.ifElse(ifs, elseV) - } - .flatten + (ifsV, loop(last, slotName)).mapN { (ifs, elseV) => + Env.ifElse(ifs, elseV) + }.flatten case Always(cond, expr) => (boolExpr(cond, slotName).map(Code.always), loop(expr, slotName)) @@ -1646,8 +1922,7 @@ object PythonGen { if (sz == 1) { // we don't bother to wrap single item structs exprR - } - else { + } else { // structs are just tuples exprR.flatMap { tup => Env.onLast(tup)(_.get(idx)) diff --git a/core/src/main/scala/org/bykn/bosatsu/graph/Dag.scala b/core/src/main/scala/org/bykn/bosatsu/graph/Dag.scala index 1eecae403..9afd19c83 100644 --- a/core/src/main/scala/org/bykn/bosatsu/graph/Dag.scala +++ b/core/src/main/scala/org/bykn/bosatsu/graph/Dag.scala @@ -29,8 +29,7 @@ sealed trait Dag[A] { def toToposorted: Toposort.Success[A] = { val layerMap: Map[Int, SortedSet[A]] = nodes.groupBy(layerOf(_)) - val ls = (0 until layers) - .iterator + val ls = (0 until layers).iterator .map { idx => // by construction all layers have at least 1 item NonEmptyList.fromListUnsafe(layerMap(idx).toList) diff --git a/core/src/main/scala/org/bykn/bosatsu/graph/Memoize.scala b/core/src/main/scala/org/bykn/bosatsu/graph/Memoize.scala index e4cb03aaa..f221e6382 100644 --- a/core/src/main/scala/org/bykn/bosatsu/graph/Memoize.scala +++ b/core/src/main/scala/org/bykn/bosatsu/graph/Memoize.scala @@ -6,11 +6,13 @@ import scala.collection.immutable.SortedMap object Memoize { - /** - * This memoizes using a sorted map (not a hashMap) in a non-threadsafe manner - * returning None, means we cannot compute this function because it loops forever - */ - def memoizeSorted[A: Ordering, B](fn: (A, A => Option[B]) => Option[B]): A => Option[B] = { + /** This memoizes using a sorted map (not a hashMap) in a non-threadsafe + * manner returning None, means we cannot compute this function because it + * loops forever + */ + def memoizeSorted[A: Ordering, B]( + fn: (A, A => Option[B]) => Option[B] + ): A => Option[B] = { var cache = SortedMap.empty[A, Option[B]] new Function[A, Option[B]] { self => @@ -29,10 +31,9 @@ object Memoize { } } - /** - * This memoizes using a hash map in a non-threadsafe manner - * this throws if you don't have a dag - */ + /** This memoizes using a hash map in a non-threadsafe manner this throws if + * you don't have a dag + */ def memoizeDagHashed[A, B](fn: (A, A => B) => B): A => B = { var cache = Map.empty[A, Option[B]] @@ -48,15 +49,14 @@ object Memoize { cache = cache.updated(a, Some(b)) b case Some(Some(b)) => b - case Some(None) => sys.error(s"loop found evaluating $a") + case Some(None) => sys.error(s"loop found evaluating $a") } } } - /** - * This memoizes using a hash map in a threadsafe manner - * it may loop forever and stack overflow if you don't have a DAG - */ + /** This memoizes using a hash map in a threadsafe manner it may loop forever + * and stack overflow if you don't have a DAG + */ def memoizeDagHashedConcurrent[A, B](fn: (A, A => B) => B): A => B = { val cache: ConcurrentHashMap[A, B] = new ConcurrentHashMap[A, B]() @@ -77,12 +77,14 @@ object Memoize { } } - /** - * This memoizes using a hash map in a threadsafe manner - * if the dependencies do not form a dag, you will deadlock - */ - def memoizeDagFuture[A, B](fn: (A, A => Par.F[B]) => Par.F[B]): A => Par.F[B] = { - val cache: ConcurrentHashMap[A, Par.P[B]] = new ConcurrentHashMap[A, Par.P[B]]() + /** This memoizes using a hash map in a threadsafe manner if the dependencies + * do not form a dag, you will deadlock + */ + def memoizeDagFuture[A, B]( + fn: (A, A => Par.F[B]) => Par.F[B] + ): A => Par.F[B] = { + val cache: ConcurrentHashMap[A, Par.P[B]] = + new ConcurrentHashMap[A, Par.P[B]]() new Function[A, Par.F[B]] { self => def apply(a: A) = { @@ -93,8 +95,7 @@ object Memoize { val resFut = fn(a, self) Par.complete(prom, resFut) resFut - } - else { + } else { // someone else is already working: Par.toF(prevProm) } diff --git a/core/src/main/scala/org/bykn/bosatsu/graph/Paths.scala b/core/src/main/scala/org/bykn/bosatsu/graph/Paths.scala index 13ef8bedf..5fdebfadf 100644 --- a/core/src/main/scala/org/bykn/bosatsu/graph/Paths.scala +++ b/core/src/main/scala/org/bykn/bosatsu/graph/Paths.scala @@ -3,29 +3,32 @@ package org.bykn.bosatsu.graph import cats.data.NonEmptyList object Paths { - /** - * A list of cycles all terminating at node - * E is intended to carry state about the edge in the graph - */ - def allCycles[A, E](node: A)(nfn: A => List[(E, A)]): List[NonEmptyList[(E, A)]] = + + /** A list of cycles all terminating at node E is intended to carry state + * about the edge in the graph + */ + def allCycles[A, E](node: A)( + nfn: A => List[(E, A)] + ): List[NonEmptyList[(E, A)]] = allPaths(node, node)(nfn) - /** - * A list of paths all terminating at to, but omitting from. - * E is intended to carry state about the edge in the graph - */ - def allPaths[A, E](from: A, to: A)(nfn: A => List[(E, A)]): List[NonEmptyList[(E, A)]] = { + /** A list of paths all terminating at to, but omitting from. E is intended to + * carry state about the edge in the graph + */ + def allPaths[A, E](from: A, to: A)( + nfn: A => List[(E, A)] + ): List[NonEmptyList[(E, A)]] = { def loop(from: A, to: A, avoid: Set[A]): List[NonEmptyList[(E, A)]] = { val newPaths = nfn(from).filterNot { case (_, a) => avoid(a) } val (ends, notEnds) = newPaths.partition { case (_, a) => a == to } - val rest = notEnds.flatMap { case edge@(_, a) => + val rest = notEnds.flatMap { case edge @ (_, a) => // don't loop back on a, loops to a are handled by ends loop(a, to, avoid + a).map(edge :: _) } NonEmptyList.fromList(ends) match { - case None => rest + case None => rest case Some(endsNE) => endsNE :: rest } } @@ -33,15 +36,13 @@ object Paths { loop(from, to, Set.empty) } - /** - * Same as allPaths but without the edge annotation type - */ + /** Same as allPaths but without the edge annotation type + */ def allPaths0[A](start: A, end: A)(nfn: A => List[A]): List[NonEmptyList[A]] = allPaths(start, end)(nfn.andThen(_.map(((), _)))).map(_.map(_._2)) - /** - * Same as allCycles but without the edge annotation type - */ + /** Same as allCycles but without the edge annotation type + */ def allCycle0[A](start: A)(nfn: A => List[A]): List[NonEmptyList[A]] = allPaths0(start, start)(nfn) } diff --git a/core/src/main/scala/org/bykn/bosatsu/graph/Toposort.scala b/core/src/main/scala/org/bykn/bosatsu/graph/Toposort.scala index 6f32343e6..981b797ea 100644 --- a/core/src/main/scala/org/bykn/bosatsu/graph/Toposort.scala +++ b/core/src/main/scala/org/bykn/bosatsu/graph/Toposort.scala @@ -5,10 +5,9 @@ import cats.syntax.all._ object Toposort { - /** - * A result is the subdag in layers, - * as well as a set of loopNodes (a sorted list of nodes that don't form a dag) - */ + /** A result is the subdag in layers, as well as a set of loopNodes (a sorted + * list of nodes that don't form a dag) + */ sealed abstract class Result[A] { // these are the nodes which depend on a cyclic subgraph def loopNodes: List[A] @@ -18,7 +17,7 @@ object Toposort { def toSuccess: Option[Vector[NonEmptyList[A]]] = this match { case Success(res, _) => Some(res) - case Failure(_, _) => None + case Failure(_, _) => None } // true if each layer has exactly one item in it @@ -33,18 +32,23 @@ object Toposort { def isFailure: Boolean = !isSuccess } - final case class Success[A](layers: Vector[NonEmptyList[A]], nfn: A => List[A]) extends Result[A] { + final case class Success[A]( + layers: Vector[NonEmptyList[A]], + nfn: A => List[A] + ) extends Result[A] { def loopNodes = List.empty[A] } - final case class Failure[A](loopNodes: List[A], layers: Vector[NonEmptyList[A]]) extends Result[A] + final case class Failure[A]( + loopNodes: List[A], + layers: Vector[NonEmptyList[A]] + ) extends Result[A] - /** - * Build a deterministic topological sort - * of a graph. The items in the position i depend only - * on things at position i-1 or less. - * - * return a result which tells us the layers of the dag, and the non-dag nodes - */ + /** Build a deterministic topological sort of a graph. The items in the + * position i depend only on things at position i-1 or less. + * + * return a result which tells us the layers of the dag, and the non-dag + * nodes + */ def sort[A: Ordering](n: Iterable[A])(fn: A => List[A]): Result[A] = if (n.isEmpty) Success(Vector.empty, fn) else { @@ -57,13 +61,12 @@ object Toposort { nonEmpty.traverse(rec).map(_.max + 1) } } - val res = n - .toList + val res = n.toList // go through in a deterministic order .sorted .map { n => depth(n) match { - case None => Left(n) + case None => Left(n) case Some(d) => Right((d, n)) } } @@ -77,13 +80,12 @@ object Toposort { // we have to be bad if we aren't good bad = true Vector.empty - } - else { + } else { val len = goodIt.max + 1 val ary = Array.fill(len)(List.newBuilder[A]) res.foreach { case Right((idx, a)) => ary(idx) += a - case Left(_) => bad = true + case Left(_) => bad = true } // the items are already sorted since we added them in sorted order diff --git a/core/src/main/scala/org/bykn/bosatsu/graph/Tree.scala b/core/src/main/scala/org/bykn/bosatsu/graph/Tree.scala index de318341f..30e2cab6f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/graph/Tree.scala +++ b/core/src/main/scala/org/bykn/bosatsu/graph/Tree.scala @@ -13,25 +13,31 @@ object Tree { val mapToTree: Map[A, Tree[A]] = toMap(t) - { (a: A) => mapToTree.get(a).fold(List.empty[A])(_.children.map(_.item)) } } - /** - * either return a tree representation of this dag or all cycles - * - * Note, this could run in a monadic context if we needed that: - * nfn: A => F[List[A]] for some monad F[_] - */ - def dagToTree[A](node: A)(nfn: A => List[A]): ValidatedNel[NonEmptyList[A], Tree[A]] = { - def treeOf(path: NonEmptyList[A], visited: Set[A]): ValidatedNel[NonEmptyList[A], Tree[A]] = { + /** either return a tree representation of this dag or all cycles + * + * Note, this could run in a monadic context if we needed that: nfn: A => + * F[List[A]] for some monad F[_] + */ + def dagToTree[A]( + node: A + )(nfn: A => List[A]): ValidatedNel[NonEmptyList[A], Tree[A]] = { + def treeOf( + path: NonEmptyList[A], + visited: Set[A] + ): ValidatedNel[NonEmptyList[A], Tree[A]] = { val children = nfn(path.head) - def assumeValid(children: List[A]): ValidatedNel[NonEmptyList[A], Tree[A]] = - children.traverse { a => - // we grow the path out here - treeOf(a :: path, visited + a) - } - .map(Tree(path.head, _)) + def assumeValid( + children: List[A] + ): ValidatedNel[NonEmptyList[A], Tree[A]] = + children + .traverse { a => + // we grow the path out here + treeOf(a :: path, visited + a) + } + .map(Tree(path.head, _)) NonEmptyList.fromList(children.filter(visited)) match { case Some(loops) => @@ -66,8 +72,7 @@ object Tree { def distinctBy[A, B](nel: List[A])(fn: A => B): List[A] = NonEmptyList.fromList(nel) match { - case None => Nil + case None => Nil case Some(nel) => distinctBy(nel)(fn).toList } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala index 3ee40af2a..5286cac16 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala @@ -13,7 +13,8 @@ trait Matcher[-P, -S, +R] { self => } object Matcher { - implicit class InvariantMatcher[P, S, R](val self: Matcher[P, S, R]) extends AnyVal { + implicit class InvariantMatcher[P, S, R](val self: Matcher[P, S, R]) + extends AnyVal { def mapWithInput[R1](fn: (S, R) => R1): Matcher[P, S, R1] = new Matcher[P, S, R1] { def apply(p: P): S => Option[R1] = { @@ -21,7 +22,7 @@ object Matcher { { (s: S) => next(s) match { - case None => None + case None => None case Some(r) => Some(fn(s, r)) } } @@ -33,15 +34,17 @@ object Matcher { def eqMatcher[A](implicit eqA: Eq[A]): Matcher[A, A, Unit] = new Matcher[A, A, Unit] { - def apply(a: A): A => Option[Unit] = - { (s: A) => if (eqA.eqv(a, s)) someUnit else None } + def apply(a: A): A => Option[Unit] = { (s: A) => + if (eqA.eqv(a, s)) someUnit else None + } } - val charMatcher: Matcher[Char, Char, Unit] = eqMatcher(Eq.fromUniversalEquals[Char]) + val charMatcher: Matcher[Char, Char, Unit] = eqMatcher( + Eq.fromUniversalEquals[Char] + ) def fnMatch[A]: Matcher[A => Boolean, A, Unit] = new Matcher[A => Boolean, A, Unit] { - def apply(p: A => Boolean) = - { (a: A) => if (p(a)) someUnit else None } + def apply(p: A => Boolean) = { (a: A) => if (p(a)) someUnit else None } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/NamedSeqPattern.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/NamedSeqPattern.scala index c58448fde..8ca9f4d64 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/NamedSeqPattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/NamedSeqPattern.scala @@ -10,7 +10,7 @@ sealed trait NamedSeqPattern[+A] { def loop(n: NamedSeqPattern[A], right: List[SeqPart[A]]): List[SeqPart[A]] = n match { case Bind(_, n) => loop(n, right) - case NEmpty => right + case NEmpty => right case NCat(first, second) => val r2 = loop(second, right) loop(first, r2) @@ -29,10 +29,10 @@ sealed trait NamedSeqPattern[+A] { // we are renderable if all Wild/AnyElem are named def isRenderable: Boolean = this match { - case NEmpty => true - case Bind(_, _) => true + case NEmpty => true + case Bind(_, _) => true case NSeqPart(Lit(_)) => true - case NSeqPart(_) => false + case NSeqPart(_) => false case NCat(l, r) => l.isRenderable && r.isRenderable } @@ -42,14 +42,15 @@ sealed trait NamedSeqPattern[+A] { def loop(n: NamedSeqPattern[A], right: S): Option[S] = n match { - case NEmpty => Some(right) + case NEmpty => Some(right) case Bind(nm, r) => // since we have this name, we don't need to recurse - names.get(nm) + names + .get(nm) .map { seq => ms.combine(seq, right) } .orElse(loop(r, right)) case NSeqPart(SeqPart.Lit(c)) => Some(ms.combine(fn(c), right)) - case NSeqPart(_) => None + case NSeqPart(_) => None case NCat(l, r) => loop(r, right) .flatMap { right => @@ -62,9 +63,9 @@ sealed trait NamedSeqPattern[+A] { def names: List[String] = this match { - case Bind(name, nsp) => name :: nsp.names + case Bind(name, nsp) => name :: nsp.names case NEmpty | NSeqPart(_) => Nil - case NCat(h, t) => h.names ::: t.names + case NCat(h, t) => h.names ::: t.names } } @@ -77,15 +78,19 @@ object NamedSeqPattern { val Wild: NamedSeqPattern[Nothing] = NSeqPart(SeqPart.Wildcard) val Any: NamedSeqPattern[Nothing] = NSeqPart(SeqPart.AnyElem) - case class Bind[A](name: String, p: NamedSeqPattern[A]) extends NamedSeqPattern[A] + case class Bind[A](name: String, p: NamedSeqPattern[A]) + extends NamedSeqPattern[A] case object NEmpty extends NamedSeqPattern[Nothing] case class NSeqPart[A](part: SeqPart[A]) extends NamedSeqPattern[A] - case class NCat[A](first: NamedSeqPattern[A], second: NamedSeqPattern[A]) extends NamedSeqPattern[A] + case class NCat[A](first: NamedSeqPattern[A], second: NamedSeqPattern[A]) + extends NamedSeqPattern[A] def fromLit[A](a: A): NamedSeqPattern[A] = NSeqPart(SeqPart.Lit(a)) - def matcher[E, I, S, R](split: Splitter[E, I, S, R]): Matcher[NamedSeqPattern[E], S, (R, Map[String, S])] = + def matcher[E, I, S, R]( + split: Splitter[E, I, S, R] + ): Matcher[NamedSeqPattern[E], S, (R, Map[String, S])] = new Matcher[NamedSeqPattern[E], S, (R, Map[String, S])] { def apply(nsp: NamedSeqPattern[E]): S => Option[(R, Map[String, S])] = { val machine = Impl.toMachine(nsp, Nil) @@ -95,7 +100,10 @@ object NamedSeqPattern { } private[this] object Impl { - def toMachine[A](n: NamedSeqPattern[A], right: List[Machine[A]]): List[Machine[A]] = + def toMachine[A]( + n: NamedSeqPattern[A], + right: List[Machine[A]] + ): List[Machine[A]] = n match { case NEmpty => right case Bind(name, n) => @@ -112,31 +120,32 @@ object NamedSeqPattern { def hasWildLeft(m: List[Machine[Any]]): Boolean = m match { - case Nil => false + case Nil => false case MSeqPart(SeqPart.Wildcard) :: _ => true - case MSeqPart(_) :: _ => false - case _ :: tail => hasWildLeft(tail) + case MSeqPart(_) :: _ => false + case _ :: tail => hasWildLeft(tail) } import SeqPart.{AnyElem, Lit, SeqPart1, Wildcard} - def capture[S](empty: S, capturing: List[String], res: Map[String, S])(fn: S => S): Map[String, S] = + def capture[S](empty: S, capturing: List[String], res: Map[String, S])( + fn: S => S + ): Map[String, S] = capturing.foldLeft(res) { (mapB, n) => val right = mapB.get(n) match { - case None => empty + case None => empty case Some(bv) => bv } mapB.updated(n, fn(right)) } def matches[E, I, S, R]( - split: Splitter[E, I, S, R], - m: List[Machine[E]], - capturing: List[String]): S => Option[(R, Map[String, S])] = - + split: Splitter[E, I, S, R], + m: List[Machine[E]], + capturing: List[String] + ): S => Option[(R, Map[String, S])] = m match { case Nil => - val res = Some((split.monoidResult.empty, Map.empty[String, S])) { (str: S) => @@ -150,7 +159,7 @@ object NamedSeqPattern { case Nil => // $COVERAGE-OFF$ sys.error("illegal End with no capturing") - // $COVERAGE-ON$ + // $COVERAGE-ON$ case n :: cap => // if n captured nothing, we need // to add an empty list @@ -169,13 +178,11 @@ object NamedSeqPattern { if (hasWildLeft(tail)) { // two adjacent wilds means this one matches nothing matches(split, tail, capturing) - } - else { + } else { val me = matchEnd(split, tail, capturing) me.andThen { stream => - stream - .headOption + stream.headOption .map { case (prefix, (rightR, rightBind)) => // now merge the prefix result val resMatched = capturing.foldLeft(rightBind) { (st, n) => @@ -195,13 +202,11 @@ object NamedSeqPattern { val headm: I => Option[R] = p1 match { case AnyElem => { (_: I) => someEmpty } - case Lit(c) => split.matcher(c) + case Lit(c) => split.matcher(c) } val tailm: S => Option[(R, Map[String, S])] = - matches(split, - tail, - capturing) + matches(split, tail, capturing) { (str: S) => for { @@ -210,14 +215,18 @@ object NamedSeqPattern { rh <- headm(h) rt <- tailm(t) (tailr, tailm) = rt - } yield (split.monoidResult.combine(rh, tailr), capture(split.emptySeq, capturing, tailm)(split.cons(h, _))) + } yield ( + split.monoidResult.combine(rh, tailr), + capture(split.emptySeq, capturing, tailm)(split.cons(h, _)) + ) } } def matchEnd[E, I, S, R]( - split: Splitter[E, I, S, R], - m: List[Machine[E]], - capturing: List[String]): S => LazyList[(S, (R, Map[String, S]))] = + split: Splitter[E, I, S, R], + m: List[Machine[E]], + capturing: List[String] + ): S => LazyList[(S, (R, Map[String, S]))] = m match { case Nil => // we always match the end @@ -233,7 +242,7 @@ object NamedSeqPattern { case Nil => // $COVERAGE-OFF$ sys.error("illegal End with no capturing") - // $COVERAGE-ON$ + // $COVERAGE-ON$ case n :: cap => // if n captured nothing, we need // to add an empty list @@ -254,24 +263,24 @@ object NamedSeqPattern { val mtail = matches(split, tail, capturing) val splits = p1 match { - case Lit(c) => split.positions(c) + case Lit(c) => split.positions(c) case AnyElem => split.anySplits(_: S) } - { (s: S) => - splits(s).map { case (pre, i, r, post) => - mtail(post) - .map { case (rp, mapRes) => - val res1 = split.monoidResult.combine(r, rp) - val res2 = capture(split.emptySeq, capturing, mapRes)(split.cons(i, _)) - (pre, (res1, res2)) - } - } - .collect { case Some(res) => res } + splits(s) + .map { case (pre, i, r, post) => + mtail(post) + .map { case (rp, mapRes) => + val res1 = split.monoidResult.combine(r, rp) + val res2 = capture(split.emptySeq, capturing, mapRes)( + split.cons(i, _) + ) + (pre, (res1, res2)) + } + } + .collect { case Some(res) => res } } } } } - - diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPart.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPart.scala index 4d402771b..a248c7f3e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPart.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPart.scala @@ -11,17 +11,19 @@ object SeqPart { override def notWild: Boolean = true } - implicit def partOrdering[E](implicit elemOrdering: Ordering[E]): Ordering[SeqPart[E]] = + implicit def partOrdering[E](implicit + elemOrdering: Ordering[E] + ): Ordering[SeqPart[E]] = new Ordering[SeqPart[E]] { def compare(a: SeqPart[E], b: SeqPart[E]) = (a, b) match { case (Lit(i1), Lit(i2)) => elemOrdering.compare(i1, i2) - case (Lit(_), _) => -1 - case (_, Lit(_)) => 1 - case (AnyElem, AnyElem) => 0 - case (AnyElem, Wildcard) => -1 - case (Wildcard, AnyElem) => 1 + case (Lit(_), _) => -1 + case (_, Lit(_)) => 1 + case (AnyElem, AnyElem) => 0 + case (AnyElem, Wildcard) => -1 + case (Wildcard, AnyElem) => 1 case (Wildcard, Wildcard) => 0 } } @@ -31,7 +33,9 @@ object SeqPart { // 0 or more characters case object Wildcard extends SeqPart[Nothing] - implicit def part1SetOps[A](implicit setOpsA: SetOps[A]): SetOps[SeqPart1[A]] = + implicit def part1SetOps[A](implicit + setOpsA: SetOps[A] + ): SetOps[SeqPart1[A]] = new SetOps[SeqPart1[A]] { private val anyList = AnyElem :: Nil @@ -42,7 +46,7 @@ object SeqPart { def anyDiff(a: A) = setOpsA.top match { - case None => anyList + case None => anyList case Some(topA) => setOpsA.difference(topA, a).map(toPart1) } @@ -50,13 +54,14 @@ object SeqPart { def isTop(c: SeqPart1[A]) = c match { case AnyElem => true - case Lit(a) => setOpsA.isTop(a) + case Lit(a) => setOpsA.isTop(a) } def intersection(p1: SeqPart1[A], p2: SeqPart1[A]): List[SeqPart1[A]] = (p1, p2) match { case (Lit(c1), Lit(c2)) => - setOpsA.intersection(c1, c2) + setOpsA + .intersection(c1, c2) .map(toPart1(_)) case (AnyElem, _) => if (isTop(p2)) AnyElem :: Nil @@ -112,18 +117,15 @@ object SeqPart { def litOpt(u: List[SeqPart1[A]], acc: List[A]): Option[List[Lit[A]]] = u match { case Nil => Some(setOpsA.unifyUnion(acc.reverse).map(Lit(_))) - case AnyElem :: _ => None + case AnyElem :: _ => None case Lit(a) :: _ if setOpsA.isTop(a) => None - case Lit(a) :: tail => litOpt(tail, a :: acc) + case Lit(a) :: tail => litOpt(tail, a :: acc) } - litOpt(u, Nil) match { - case None => AnyElem :: Nil + case None => AnyElem :: Nil case Some(u) => u } } } } - - diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala index 2f86cc4ab..195178e99 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala @@ -10,23 +10,22 @@ sealed trait SeqPattern[+A] { def matchesAny: Boolean = this match { - case Empty => false + case Empty => false case Cat(Wildcard, t) => t.matchesEmpty - case Cat(_, _) => false + case Cat(_, _) => false } def matchesEmpty: Boolean = this match { - case Empty => true + case Empty => true case Cat(Wildcard, t) => t.matchesEmpty - case Cat(_, _) => false + case Cat(_, _) => false } def isEmpty: Boolean = this == Empty - /** - * Concat that SeqPattern on the right - */ + /** Concat that SeqPattern on the right + */ def +[A1 >: A](that: SeqPattern[A1]): SeqPattern[A1] = SeqPattern.fromList(toList ::: that.toList) @@ -35,14 +34,14 @@ sealed trait SeqPattern[+A] { def prependWild: SeqPattern[A] = this match { - case Cat(AnyElem, t) => Cat(AnyElem, Cat(Wildcard, t)) + case Cat(AnyElem, t) => Cat(AnyElem, Cat(Wildcard, t)) case Cat(Wildcard, _) => this - case notAlreadyWild => Cat(Wildcard, notAlreadyWild) + case notAlreadyWild => Cat(Wildcard, notAlreadyWild) } def toList: List[SeqPart[A]] = this match { - case Empty => Nil + case Empty => Nil case Cat(h, t) => h :: t.toList } @@ -50,7 +49,7 @@ sealed trait SeqPattern[+A] { @annotation.tailrec def loop(sp: SeqPattern[A], acc: Int): Int = sp match { - case Empty => acc + case Empty => acc case Cat(_, tail) => loop(tail, acc + 1) } @@ -66,20 +65,19 @@ sealed trait SeqPattern[+A] { case Cat(_, _) => None } - /** - * If two wilds are adjacent, the left one will always match empty string - * this normalize just removes the left wild - * - * combine adjacent strings - */ + /** If two wilds are adjacent, the left one will always match empty string + * this normalize just removes the left wild + * + * combine adjacent strings + */ def normalize: SeqPattern[A] = this match { - case Empty => Empty + case Empty => Empty case Cat(Wildcard, Cat(AnyElem, t)) => // move AnyElem out val wtn = Cat(Wildcard, t) Cat(AnyElem, wtn.normalize) - case Cat(Wildcard, tail@Cat(Wildcard, _)) => + case Cat(Wildcard, tail @ Cat(Wildcard, _)) => // remove duplicate Wildcard tail.normalize case Cat(h, tail) => @@ -87,29 +85,27 @@ sealed trait SeqPattern[+A] { } def show: String = - toList - .iterator - .map { - case Lit('.') => "\\." - case Lit('*') => "\\*" - case Lit(c) => c.toString - case AnyElem => "." - case Wildcard => "*" - } - .mkString + toList.iterator.map { + case Lit('.') => "\\." + case Lit('*') => "\\*" + case Lit(c) => c.toString + case AnyElem => "." + case Wildcard => "*" + }.mkString } object SeqPattern { case object Empty extends SeqPattern[Nothing] - case class Cat[A](head: SeqPart[A], tail: SeqPattern[A]) extends SeqPattern[A] { + case class Cat[A](head: SeqPart[A], tail: SeqPattern[A]) + extends SeqPattern[A] { // return the last non-empty @annotation.tailrec final def rightMost: SeqPart[A] = tail match { - case Empty => head - case Cat(h, Empty) => h - case Cat(_, r@Cat(_, _)) => r.rightMost + case Empty => head + case Cat(h, Empty) => h + case Cat(_, r @ Cat(_, _)) => r.rightMost } def reverseCat: Cat[A] = { @@ -120,7 +116,7 @@ object SeqPattern { def fromList[A](ps: List[SeqPart[A]]): SeqPattern[A] = ps.foldRight(Empty: SeqPattern[A]) { (h, tail) => - Cat(h, tail) + Cat(h, tail) } val Wild: SeqPattern[Nothing] = Cat(SeqPart.Wildcard, Empty) @@ -131,7 +127,7 @@ object SeqPattern { val ordSeqPart: Ordering[SeqPart[A]] = implicitly[Ordering[SeqPart[A]]] def compare(a: SeqPattern[A], b: SeqPattern[A]) = (a, b) match { - case (Empty, Empty) => 0 + case (Empty, Empty) => 0 case (Empty, Cat(_, _)) => -1 case (Cat(_, _), Empty) => 1 case (Cat(h1, t1), Cat(h2, t2)) => @@ -141,7 +137,10 @@ object SeqPattern { } } - implicit def seqPatternSetOps[A](implicit part1SetOps: SetOps[SeqPart.SeqPart1[A]], ordA: Ordering[A]): SetOps[SeqPattern[A]] = + implicit def seqPatternSetOps[A](implicit + part1SetOps: SetOps[SeqPart.SeqPart1[A]], + ordA: Ordering[A] + ): SetOps[SeqPattern[A]] = new SetOps[SeqPattern[A]] { self => import SeqPart.{SeqPart1, AnyElem, Wildcard} @@ -161,14 +160,14 @@ object SeqPattern { .map(normalize(_)) .distinct .map(_.toList) - } - .map { s => normalize(SeqPattern.fromList(s)) } - .sorted + }.map { s => normalize(SeqPattern.fromList(s)) }.sorted private[this] val someWild = Some(Wildcard :: Nil) private[this] val someNil = Some(Nil) - private def unifyUnionList(union: List[List[SeqPart[A]]]): List[List[SeqPart[A]]] = { + private def unifyUnionList( + union: List[List[SeqPart[A]]] + ): List[List[SeqPart[A]]] = { // if a part of Sequences are the same except this part, can we merge by appending // something? @@ -176,36 +175,35 @@ object SeqPattern { list match { case (a: SeqPart1[A]) :: Wildcard :: Nil if isAny(a) => someWild case Wildcard :: (a: SeqPart1[A]) :: Nil if isAny(a) => someWild - case Wildcard :: Wildcard :: Nil => someWild - case Wildcard :: Nil => someWild - case Nil => someNil - case _ => None + case Wildcard :: Wildcard :: Nil => someWild + case Wildcard :: Nil => someWild + case Nil => someNil + case _ => None } - def unifyPair(left: List[SeqPart[A]], right: List[SeqPart[A]]): Option[List[SeqPart[A]]] = { + def unifyPair( + left: List[SeqPart[A]], + right: List[SeqPart[A]] + ): Option[List[SeqPart[A]]] = { def o1 = if (left.startsWith(right)) { unifySeqPart(left.drop(right.size)).map(right ::: _) - } - else None + } else None def o2 = if (right.startsWith(left)) { unifySeqPart(right.drop(left.size)).map(left ::: _) - } - else None + } else None def o3 = if (left.endsWith(right)) { unifySeqPart(left.dropRight(right.size)).map(_ ::: right) - } - else None + } else None def o4 = if (right.endsWith(left)) { unifySeqPart(right.dropRight(left.size)).map(_ ::: left) - } - else None + } else None def o5 = if (subsetList(left, right)) Some(right) else None def o6 = if (subsetList(right, left)) Some(left) else None @@ -234,13 +232,11 @@ object SeqPattern { val rest = items.iterator.filterNot(_ == null).toList // let's look again unifyUnionList(pair :: rest) - } - else union + } else union } - /** - * return true if p1 <= p2, can give false negatives - */ + /** return true if p1 <= p2, can give false negatives + */ override def subset(p1: SeqPattern[A], p2: SeqPattern[A]): Boolean = p2.matchesAny || { // if p2 doesn't matchEmpty but p1 does, we are done @@ -256,12 +252,11 @@ object SeqPattern { final def isAny(p: SeqPart1[A]): Boolean = part1SetOps.isTop(p) - /** - * If two wilds are adjacent, the left one will always match empty string - * this normalize just removes the left wild - * - * combine adjacent strings - */ + /** If two wilds are adjacent, the left one will always match empty string + * this normalize just removes the left wild + * + * combine adjacent strings + */ private def normalize(sp: SeqPattern[A]): SeqPattern[A] = sp match { case Empty => Empty @@ -278,20 +273,21 @@ object SeqPattern { // move AnyElem out val wtn = Cat(Wildcard, t) Cat(AnyElem, normalize(wtn)) - } - else { + } else { Cat(Wildcard, Cat(a, normalize(t))) } - case tail@Cat(Wildcard, _) => + case tail @ Cat(Wildcard, _) => // remove duplicate Wildcard normalize(tail) } } - - private def subsetList(p1: List[SeqPart[A]], p2: List[SeqPart[A]]): Boolean = + private def subsetList( + p1: List[SeqPart[A]], + p2: List[SeqPart[A]] + ): Boolean = (p1, p2) match { - case (Nil, Nil) => true + case (Nil, Nil) => true case (Nil, (_: SeqPart1[A]) :: _) => false case (Nil, Wildcard :: t) => subsetList(Nil, t) @@ -299,19 +295,19 @@ object SeqPattern { case ((h1: SeqPart1[A]) :: t1, (h2: SeqPart1[A]) :: t2) => part1SetOps.subset(h1, h2) && subsetList(t1, t2) case (Wildcard :: Wildcard :: t1, _) => - // normalize the left: - subsetList(Wildcard :: t1, p2) + // normalize the left: + subsetList(Wildcard :: t1, p2) case (_, Wildcard :: Wildcard :: t2) => - // normalize the right: - subsetList(p1, Wildcard :: t2) + // normalize the right: + subsetList(p1, Wildcard :: t2) case (_, Wildcard :: (a2: SeqPart1[A]) :: t2) if isAny(a2) => - // we know that right can't match empty, - // let's see if that helps us rule out matches on the left - subsetList(p1, AnyElem :: Wildcard :: t2) + // we know that right can't match empty, + // let's see if that helps us rule out matches on the left + subsetList(p1, AnyElem :: Wildcard :: t2) case (Wildcard :: (a1: SeqPart1[A]) :: t1, _) if isAny(a1) => - // we know that left can't match empty, - // let's see if that helps us rule out matches on the left - subsetList(AnyElem :: Wildcard :: t1, p2) + // we know that left can't match empty, + // let's see if that helps us rule out matches on the left + subsetList(AnyElem :: Wildcard :: t1, p2) // either t1 or t2 also ends with Wildcard case (_ :: _, Wildcard :: _) if p2.last.notWild => // wild on the right but not at the end @@ -327,8 +323,8 @@ object SeqPattern { // p1 = *t1 = t1 + _:p1 // _:p1 <= h2:t2 => (_ <= h2) && (p1 <= t2) isAny(h2) && - subsetList(t1, p2) && - subsetList(p1, t2) + subsetList(t1, p2) && + subsetList(p1, t2) case ((_: SeqPart1[A]) :: t1, Wildcard :: t2) => // we could pop off one wildcard to match head // or we could match with nothing but the rest @@ -353,34 +349,36 @@ object SeqPattern { private def relateList(p1: List[SeqPart[A]], p2: List[SeqPart[A]]): Rel = (p1, p2) match { - case (Nil, Nil) => Rel.Same + case (Nil, Nil) => Rel.Same case (Nil, (_: SeqPart1[A]) :: _) => // [] is h :: t are disjoint when h matches at least 1 Rel.Disjoint case (Nil, Wildcard :: t) => // [] <:> * :: t, if t matchesEmpty, this is subset, - if (t.exists { - case _: SeqPart1[A] => true - case Wildcard => false - }) Rel.Disjoint + if ( + t.exists { + case _: SeqPart1[A] => true + case Wildcard => false + } + ) Rel.Disjoint else Rel.Sub case (_ :: _, Nil) => relateList(p2, p1).invert case ((h1: SeqPart1[A]) :: t1, (h2: SeqPart1[A]) :: t2) => part1SetOps.relate(h1, h2).lazyCombine(relateList(t1, t2)) case (Wildcard :: Wildcard :: t1, _) => - // normalize the left: - relateList(Wildcard :: t1, p2) + // normalize the left: + relateList(Wildcard :: t1, p2) case (_, Wildcard :: Wildcard :: t2) => - // normalize the right: - relateList(p1, Wildcard :: t2) + // normalize the right: + relateList(p1, Wildcard :: t2) case (_, Wildcard :: (a2: SeqPart1[A]) :: t2) if isAny(a2) => - // we know that right can't match empty, - // let's see if that helps us rule out matches on the left - relateList(p1, AnyElem :: Wildcard :: t2) + // we know that right can't match empty, + // let's see if that helps us rule out matches on the left + relateList(p1, AnyElem :: Wildcard :: t2) case (Wildcard :: (a1: SeqPart1[A]) :: t1, _) if isAny(a1) => - // we know that left can't match empty, - // let's see if that helps us rule out matches on the left - relateList(AnyElem :: Wildcard :: t1, p2) + // we know that left can't match empty, + // let's see if that helps us rule out matches on the left + relateList(AnyElem :: Wildcard :: t1, p2) // either t1 or t2 also ends with Wildcard case (_ :: _, Wildcard :: _) if p2.last.notWild => // wild on the right but not at the end @@ -393,7 +391,8 @@ object SeqPattern { case _ => viaIntersection.relate( SeqPattern.fromList(p1), - SeqPattern.fromList(p2)) + SeqPattern.fromList(p2) + ) } private def min(p1: SeqPattern[A], p2: SeqPattern[A]): SeqPattern[A] = { @@ -406,7 +405,7 @@ object SeqPattern { import SeqPart.{Lit, AnyElem, Wildcard} def loop(p1: SeqPattern[A], p2: SeqPattern[A]): SeqPattern[A] = (p1, p2) match { - case (Empty, Empty) => Empty + case (Empty, Empty) => Empty case (Cat(AnyElem, _), Cat(_, _)) => p1 case (Cat(h1 @ Lit(_), t1), Cat(h2 @ Lit(_), t2)) => if (part1SetOps.equiv(h1, h2)) Cat(h1, loop(t1, t2)) @@ -417,21 +416,23 @@ object SeqPattern { sys.error(s"invariant violation equiv($h1, $h2) == false") // $COVERAGE-ON$ } - case (Cat(_, _), Cat(AnyElem, _)) => p2 + case (Cat(_, _), Cat(AnyElem, _)) => p2 case (Cat(_, _), Cat(Wildcard, _)) => p1 - case (Cat(Wildcard, _), _) => p2 - case (Cat(_, _), Empty) => Empty - case (Empty, Cat(_, _)) => Empty + case (Cat(Wildcard, _), _) => p2 + case (Cat(_, _), Empty) => Empty + case (Empty, Cat(_, _)) => Empty } loop(p1, p2) } } - /** - * Compute a list of patterns that matches both patterns exactly - */ - def intersection(p1: SeqPattern[A], p2: SeqPattern[A]): List[SeqPattern[A]] = + /** Compute a list of patterns that matches both patterns exactly + */ + def intersection( + p1: SeqPattern[A], + p2: SeqPattern[A] + ): List[SeqPattern[A]] = (p1, p2) match { case (Empty, _) => if (p2.matchesEmpty) p1 :: Nil else Nil @@ -443,17 +444,16 @@ object SeqPattern { if (p1.matchesAny) { // both match any, return a normalized value Wild - } - else p1 + } else p1 res :: Nil case (Cat(Wildcard, _), p2) if p1.matchesAny => // p1 matches anything, but p2 doesn't p2 :: Nil - case (Cat(Wildcard, t1@Cat(Wildcard, _)), _) => + case (Cat(Wildcard, t1 @ Cat(Wildcard, _)), _) => // unnormalized intersection(t1, p2) - case (_, Cat(Wildcard, t2@Cat(Wildcard, _))) => + case (_, Cat(Wildcard, t2 @ Cat(Wildcard, _))) => // unnormalized intersection(p1, t2) case (Cat(Wildcard, Cat(a1: SeqPart1[A], t1)), _) if isAny(a1) => @@ -469,7 +469,8 @@ object SeqPattern { } yield Cat(h, t) unifyUnion(intr) - case (c1@Cat(Wildcard, _), c2@Cat(Wildcard, _)) if c1.rightMost.notWild || c2.rightMost.notWild => + case (c1 @ Cat(Wildcard, _), c2 @ Cat(Wildcard, _)) + if c1.rightMost.notWild || c2.rightMost.notWild => // let's avoid the most complex case of both having // wild on the front if possible intersection(c1.reverse, c2.reverse).map(_.reverse) @@ -509,188 +510,184 @@ object SeqPattern { intersection(p2, p1) } - /** - * return the patterns that match p1 but not p2 - * - * For fixed sets A, B if we have (A1 x B1) - (A2 x B2) = - * A1 = (A1 n A2) u (A1 - A2) - * A2 = (A1 n A2) u (A2 - A1) - * so we can decompose: - * - * A1 x B1 = (A1 n A2)xB1 u (A1 - A2)xB1 - * A2 x B2 = (A1 n A2)xB2 u (A2 - A1)xB2 - * - * the difference is: - * (A1 n A2)x(B1 - B2) u (A1 - A2)xB1 - * - * A - (B1 u B2) = (A - B1) n (A - B2) - * A - (B1 u B2) <= ((A - B1) u (A - B2)) - (B1 n B2) - * A - (B1 u B2) >= (A - B1) u (A - B2) - * - * so if (B1 n B2) = 0, then: - * A - (B1 u B2) = (A - B1) u (A - B2) - * - * (A1 u A2) - B = (A1 - B) u (A2 - B) - * - * The last challenge is we need to operate on - * s ingle characters, so we need to expand - * wild into [*] = [] | [_, *], since our pattern - * language doesn't have a symbol for - * a single character match we have to be a bit more careful - * - * also, we can't exactly represent Wildcard - Lit - * so this is actually an upperbound on the difference - * which is to say, all the returned patterns match p1, - * but some of them also match p2 - */ - def difference(p1: SeqPattern[A], p2: SeqPattern[A]): List[SeqPattern[A]] = + /** return the patterns that match p1 but not p2 + * + * For fixed sets A, B if we have (A1 x B1) - (A2 x B2) = A1 = (A1 n A2) + * u (A1 - A2) A2 = (A1 n A2) u (A2 - A1) so we can decompose: + * + * A1 x B1 = (A1 n A2)xB1 u (A1 - A2)xB1 A2 x B2 = (A1 n A2)xB2 u (A2 - + * A1)xB2 + * + * the difference is: (A1 n A2)x(B1 - B2) u (A1 - A2)xB1 + * + * A - (B1 u B2) = (A - B1) n (A - B2) A - (B1 u B2) <= ((A - B1) u (A - + * B2)) - (B1 n B2) A - (B1 u B2) >= (A - B1) u (A - B2) + * + * so if (B1 n B2) = 0, then: A - (B1 u B2) = (A - B1) u (A - B2) + * + * (A1 u A2) - B = (A1 - B) u (A2 - B) + * + * The last challenge is we need to operate on s ingle characters, so we + * need to expand wild into [*] = [] | [_, *], since our pattern language + * doesn't have a symbol for a single character match we have to be a bit + * more careful + * + * also, we can't exactly represent Wildcard - Lit so this is actually an + * upperbound on the difference which is to say, all the returned + * patterns match p1, but some of them also match p2 + */ + def difference( + p1: SeqPattern[A], + p2: SeqPattern[A] + ): List[SeqPattern[A]] = relate(p1, p2) match { case Rel.Sub | Rel.Same => Nil - case Rel.Disjoint => p1 :: Nil - case _ => + case Rel.Disjoint => p1 :: Nil + case _ => // We know p1 is a strict super set of p2 or it // intersects. We can never return Nil or p1 :: Nil - (p1, p2) match { - case (Cat(Wildcard, t1@Cat(Wildcard, _)), _) => - // unnormalized - difference(t1, p2) - case (_, Cat(Wildcard, t2@Cat(Wildcard, _))) => - // unnormalized - difference(p1, t2) - case (Cat(Wildcard, Cat(a1: SeqPart1[A], t1)), _) if isAny(a1) => - // *. == .*, push Wildcards to the end - difference(Cat(a1, Cat(Wildcard, t1)), p2) - case (_, Cat(Wildcard, Cat(a2: SeqPart1[A], t2))) if isAny(a2) => - // *. == .*, push Wildcards to the end - difference(p1, Cat(AnyElem, Cat(Wildcard, t2))) - case (Cat(Wildcard, t1), Empty) => - // we know that t1 matches empty or these wouldn't intersect - // use (A + B) - C = (A - C) + (B - C) - // *:t1 = t1 + _:p1 - // _:p1 - [] = _:p1 - unifyUnion(Cat(AnyElem, p1) :: difference(t1, Empty)) - case (Cat(h1: SeqPart1[A], t1), Cat(h2: SeqPart1[A], t2)) => - // h1:t1 - h2:t2 = (h1 n h2):(t1 - t2) + (h1 - h2):t1 - // = (t1 n t2):(h1 - h2) + (t1 - t2):h1 - // if t1 n t2 = 0 then t1 - t2 == t1 - val intH = part1SetOps.intersection(h1, h2) - val d1 = - for { - h <- intH - t <- difference(t1, t2) - } yield Cat(h, t) - - val d2 = part1SetOps.difference(h1, h2).map(Cat(_, t1)) - unifyUnion(d1 ::: d2) - case (Cat(h1: SeqPart1[A], t1), Cat(Wildcard, t2)) => - // h1:t1 - (*:t2) = ((h1:t1 - _:p2) - t2) - // - // h1:t1 - _:p2 = (h1 n _) : (t1 - p2) + (h1 - _):t1 - // = h1 : (t1 - p2) - val dtail = difference(t1, p2).map(Cat(h1, _)) - val u = differenceAll(dtail, t2 :: Nil) - - unifyUnion(u) - case (Cat(Wildcard, t1), Cat(h2: SeqPart1[A], t2)) => - // *:t1 - (h2:t2) = t1 + _:p1 - h2:t2 - // = (t1 - p2) + (_:p1 - h2:t2) - val d12 = { - //(_:p1 - h2:t2) = - //(_ n h2):(p1 - t2) + (_ - h2):p1 - //h2:(p1 - t2) + (_ - h2):p1 - // - //or: - //(_ - h2):(p1 n t2) + _:(p1 - t2) - if (disjoint(p1, t2)) { - Cat(AnyElem, p1) :: Nil - } - else { - val dtail = difference(p1, t2) - val d1 = dtail.map(Cat(h2, _)) - val d2 = part1SetOps.difference(AnyElem, h2).map(Cat(_, p1)) - d1 ::: d2 - } - } - val d3 = difference(t1, p2) - unifyUnion(d12 ::: d3) - case (c1@Cat(Wildcard, t1), c2@Cat(Wildcard, t2)) => - if (c1.rightMost.notWild || c2.rightMost.notWild) { - // let's avoid the most complex case of both having - // wild on the front if possible - difference(c1.reverse, c2.reverse).map(_.reverse) - } - else { - // both start and end with wildcard - // - // p1 - (t2 + _:p2) = - // if (t2 n (_:*:t2)).isEmpty, then - // then we can use a simpler formula - // but that is very uncommon (maybe - // we can find a proof it can't - // happen if t2 ends with Wildcard - // which is the case we are in. - // - // otherwise - // this branch is approximate: - // (p1 - t2) n (p1 - _:p2) = - // (p1 - t2) n ((t1 + _:p1) - _:p2) - // (p1 - t2) n ((t1 - _:p2) + _:(p1 - p2)) - // - // x = a n (b + _:x) - // = (a n b) + a n (_:x) - // <= (a n b) + (_:x) - // = *:(a n b) - // - // if t1 = [], then the above gives - // either empty set or *. - // this is a common case when we are - // searching for missing branches, we - // start at * - x - // - // (* - t2) n (([] - _:p2) + _:(p1 - p2)) - // a = * - t2 - // = (* - t2) n ([] + _:(p1 - p2)) - // p1 - p2 = a n ([] + _:(a n ([] + _:(a n ([] + _: ... - // <= a n ([] _ :(a n []) + _ _ :(a n []) +++ - // = a n (*:(a n [])) - // - // since a <= *, in the right side we have - // a n * = a - // so p1 - p2 <= a - // - // note, a is always an upper bound due - // to formula x = a n (...) - val as = difference(p1, t2) - if (t1.isEmpty) { - as - } - else { - // if x <= *:(a n b) and a then it is <= a n (*:(a n b)) - val bs = difference(t1, Cat(AnyElem, p2)) - // (a1 + a2) n (b1 + b2) = - val intr = + (p1, p2) match { + case (Cat(Wildcard, t1 @ Cat(Wildcard, _)), _) => + // unnormalized + difference(t1, p2) + case (_, Cat(Wildcard, t2 @ Cat(Wildcard, _))) => + // unnormalized + difference(p1, t2) + case (Cat(Wildcard, Cat(a1: SeqPart1[A], t1)), _) if isAny(a1) => + // *. == .*, push Wildcards to the end + difference(Cat(a1, Cat(Wildcard, t1)), p2) + case (_, Cat(Wildcard, Cat(a2: SeqPart1[A], t2))) if isAny(a2) => + // *. == .*, push Wildcards to the end + difference(p1, Cat(AnyElem, Cat(Wildcard, t2))) + case (Cat(Wildcard, t1), Empty) => + // we know that t1 matches empty or these wouldn't intersect + // use (A + B) - C = (A - C) + (B - C) + // *:t1 = t1 + _:p1 + // _:p1 - [] = _:p1 + unifyUnion(Cat(AnyElem, p1) :: difference(t1, Empty)) + case (Cat(h1: SeqPart1[A], t1), Cat(h2: SeqPart1[A], t2)) => + // h1:t1 - h2:t2 = (h1 n h2):(t1 - t2) + (h1 - h2):t1 + // = (t1 n t2):(h1 - h2) + (t1 - t2):h1 + // if t1 n t2 = 0 then t1 - t2 == t1 + val intH = part1SetOps.intersection(h1, h2) + val d1 = for { - ai <- as - bi <- bs - c <- intersection(ai, bi) - // we know that everything - // in the result must be in a - a2 <- as - ca <- intersection(c.prependWild, a2) - } yield ca - - unifyUnion(intr) - } + h <- intH + t <- difference(t1, t2) + } yield Cat(h, t) + + val d2 = part1SetOps.difference(h1, h2).map(Cat(_, t1)) + unifyUnion(d1 ::: d2) + case (Cat(h1: SeqPart1[A], t1), Cat(Wildcard, t2)) => + // h1:t1 - (*:t2) = ((h1:t1 - _:p2) - t2) + // + // h1:t1 - _:p2 = (h1 n _) : (t1 - p2) + (h1 - _):t1 + // = h1 : (t1 - p2) + val dtail = difference(t1, p2).map(Cat(h1, _)) + val u = differenceAll(dtail, t2 :: Nil) + + unifyUnion(u) + case (Cat(Wildcard, t1), Cat(h2: SeqPart1[A], t2)) => + // *:t1 - (h2:t2) = t1 + _:p1 - h2:t2 + // = (t1 - p2) + (_:p1 - h2:t2) + val d12 = { + // (_:p1 - h2:t2) = + // (_ n h2):(p1 - t2) + (_ - h2):p1 + // h2:(p1 - t2) + (_ - h2):p1 + // + // or: + // (_ - h2):(p1 n t2) + _:(p1 - t2) + if (disjoint(p1, t2)) { + Cat(AnyElem, p1) :: Nil + } else { + val dtail = difference(p1, t2) + val d1 = dtail.map(Cat(h2, _)) + val d2 = part1SetOps.difference(AnyElem, h2).map(Cat(_, p1)) + d1 ::: d2 + } + } + val d3 = difference(t1, p2) + unifyUnion(d12 ::: d3) + case (c1 @ Cat(Wildcard, t1), c2 @ Cat(Wildcard, t2)) => + if (c1.rightMost.notWild || c2.rightMost.notWild) { + // let's avoid the most complex case of both having + // wild on the front if possible + difference(c1.reverse, c2.reverse).map(_.reverse) + } else { + // both start and end with wildcard + // + // p1 - (t2 + _:p2) = + // if (t2 n (_:*:t2)).isEmpty, then + // then we can use a simpler formula + // but that is very uncommon (maybe + // we can find a proof it can't + // happen if t2 ends with Wildcard + // which is the case we are in. + // + // otherwise + // this branch is approximate: + // (p1 - t2) n (p1 - _:p2) = + // (p1 - t2) n ((t1 + _:p1) - _:p2) + // (p1 - t2) n ((t1 - _:p2) + _:(p1 - p2)) + // + // x = a n (b + _:x) + // = (a n b) + a n (_:x) + // <= (a n b) + (_:x) + // = *:(a n b) + // + // if t1 = [], then the above gives + // either empty set or *. + // this is a common case when we are + // searching for missing branches, we + // start at * - x + // + // (* - t2) n (([] - _:p2) + _:(p1 - p2)) + // a = * - t2 + // = (* - t2) n ([] + _:(p1 - p2)) + // p1 - p2 = a n ([] + _:(a n ([] + _:(a n ([] + _: ... + // <= a n ([] _ :(a n []) + _ _ :(a n []) +++ + // = a n (*:(a n [])) + // + // since a <= *, in the right side we have + // a n * = a + // so p1 - p2 <= a + // + // note, a is always an upper bound due + // to formula x = a n (...) + val as = difference(p1, t2) + if (t1.isEmpty) { + as + } else { + // if x <= *:(a n b) and a then it is <= a n (*:(a n b)) + val bs = difference(t1, Cat(AnyElem, p2)) + // (a1 + a2) n (b1 + b2) = + val intr = + for { + ai <- as + bi <- bs + c <- intersection(ai, bi) + // we know that everything + // in the result must be in a + a2 <- as + ca <- intersection(c.prependWild, a2) + } yield ca + + unifyUnion(intr) + } + } + // $COVERAGE-OFF$ + case pair => + sys.error( + s"unreachable shouldn't be Super or Intersects: $pair" + ) + // $COVERAGE-ON$ } - // $COVERAGE-OFF$ - case pair => - sys.error(s"unreachable shouldn't be Super or Intersects: $pair") - // $COVERAGE-ON$ } - } } - def matcher[A, I, S, R](split: Splitter[A, I, S, R]): Matcher[SeqPattern[A], S, R] = + def matcher[A, I, S, R]( + split: Splitter[A, I, S, R] + ): Matcher[SeqPattern[A], S, R] = new Matcher[SeqPattern[A], S, R] { import SeqPart.{AnyElem, Lit, SeqPart1, Wildcard} @@ -708,7 +705,8 @@ object SeqPattern { (h, t) <- split.uncons(s) rh <- mh(h) rt <- mt(t) - } yield split.monoidResult.combine(rh, rt) } + } yield split.monoidResult.combine(rh, rt) + } case Cat(AnyElem, t) => val mt = apply(t) @@ -716,26 +714,31 @@ object SeqPattern { for { (_, t) <- split.uncons(s) rt <- mt(t) - } yield rt } + } yield rt + } case Cat(Wildcard, t) => matchEnd(t).andThen(_.headOption.map(_._2)) } def matchEnd(p: SeqPattern[A]): S => LazyList[(S, R)] = p match { - case Empty => { (s: S) => (s, split.monoidResult.empty) #:: LazyList.empty } + case Empty => { (s: S) => + (s, split.monoidResult.empty) #:: LazyList.empty + } case Cat(p: SeqPart1[A], t) => val splitFn: S => LazyList[(S, I, R, S)] = p match { - case Lit(c) => split.positions(c) + case Lit(c) => split.positions(c) case AnyElem => split.anySplits(_: S) } val tailMatch = apply(t) { (s: S) => splitFn(s) - .map { case (pre, _, r, post) => + .map { case (pre, _, r, post) => tailMatch(post) - .map { rtail => (pre, split.monoidResult.combine(r, rtail)) } + .map { rtail => + (pre, split.monoidResult.combine(r, rtail)) + } } .collect { case Some(res) => res } } diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala index a988b7cff..45ac0cda7 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala @@ -24,22 +24,29 @@ abstract class Splitter[-Elem, Item, Sequence, R] { } object Splitter { - def stringSplitter[R](fn: Char => R)(implicit m: Monoid[R]): Splitter[Char, Char, String, R] = + def stringSplitter[R]( + fn: Char => R + )(implicit m: Monoid[R]): Splitter[Char, Char, String, R] = new Splitter[Char, Char, String, R] { val matcher = Matcher.charMatcher .mapWithInput { (s, _) => fn(s) } val monoidResult = m - def positions(c: Char): String => LazyList[(String, Char, R, String)] = { str => - def loop(init: Int): LazyList[(String, Char, R, String)] = - if (init >= str.length) LazyList.empty - else if (str.charAt(init) == c) { - (str.substring(0, init), c, fn(c), str.substring(init + 1)) #:: loop(init + 1) - } - else loop(init + 1) - - loop(0) + def positions(c: Char): String => LazyList[(String, Char, R, String)] = { + str => + def loop(init: Int): LazyList[(String, Char, R, String)] = + if (init >= str.length) LazyList.empty + else if (str.charAt(init) == c) { + ( + str.substring(0, init), + c, + fn(c), + str.substring(init + 1) + ) #:: loop(init + 1) + } else loop(init + 1) + + loop(0) } def anySplits(str: String): LazyList[(String, Char, R, String)] = @@ -74,12 +81,15 @@ object Splitter { val matchFn = matcher(c) { (str: List[V]) => - def loop(tail: List[V], acc: List[V]): LazyList[(List[V], V, R, List[V])] = + def loop( + tail: List[V], + acc: List[V] + ): LazyList[(List[V], V, R, List[V])] = tail match { case Nil => LazyList.empty case h :: t => matchFn(h) match { - case None => loop(t, h :: acc) + case None => loop(t, h :: acc) case Some(r) => (acc.reverse, h, r, t) #:: loop(t, h :: acc) } } @@ -92,7 +102,8 @@ object Splitter { def loop(str: List[V], acc: List[V]): LazyList[(List[V], V, R, List[V])] = str match { case Nil => LazyList.empty - case h :: t => (acc.reverse, h, monoidResult.empty, t) #:: loop(t, h :: acc) + case h :: t => + (acc.reverse, h, monoidResult.empty, t) #:: loop(t, h :: acc) } loop(str, Nil) } @@ -111,7 +122,9 @@ object Splitter { final override def fromList(cs: List[V]) = cs } - def listSplitter[P, V, R](m: Matcher[P, V, R])(implicit mon: Monoid[R]): Splitter[P, V, List[V], R] = + def listSplitter[P, V, R]( + m: Matcher[P, V, R] + )(implicit mon: Monoid[R]): Splitter[P, V, List[V], R] = new ListSplitter[P, V, R] { val matcher = m val monoidResult = mon diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/DataRepr.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/DataRepr.scala index db7f7c593..c17f2ac9d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/DataRepr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/DataRepr.scala @@ -1,9 +1,7 @@ package org.bykn.bosatsu.rankn -/** - * How is a non-external data type - * represented - */ +/** How is a non-external data type represented + */ sealed abstract class DataRepr @@ -13,7 +11,8 @@ object DataRepr { case object ZeroNat extends Nat(true) case object SuccNat extends Nat(false) - case class Enum(variant: Int, arity: Int, familyArities: List[Int]) extends DataRepr + case class Enum(variant: Int, arity: Int, familyArities: List[Int]) + extends DataRepr // a struct with arity 1 can be elided, and is called a new-type case class Struct(arity: Int) extends DataRepr { require(arity != 1) diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala index 818f316f4..2fa6f946e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala @@ -10,21 +10,20 @@ import cats.implicits._ import cats.data.NonEmptyList final case class DefinedType[+A]( - packageName: PackageName, - name: TypeName, - annotatedTypeParams: List[(Type.Var.Bound, A)], - constructors: List[ConstructorFn]) { + packageName: PackageName, + name: TypeName, + annotatedTypeParams: List[(Type.Var.Bound, A)], + constructors: List[ConstructorFn] +) { val typeParams: List[Type.Var.Bound] = annotatedTypeParams.map(_._1) require(typeParams.distinct == typeParams, typeParams.toString) - /** - * This is not the full type, since the full type - * has a ForAll(typeParams, ... in front if the - * typeParams is nonEmpty - */ + /** This is not the full type, since the full type has a ForAll(typeParams, + * ... in front if the typeParams is nonEmpty + */ val toTypeConst: Type.Const.Defined = DefinedType.toTypeConst(packageName, name) @@ -38,9 +37,9 @@ final case class DefinedType[+A]( (_, t) <- cfn.args } yield t ) - /** - * A type with exactly one constructor is a struct - */ + + /** A type with exactly one constructor is a struct + */ def isStruct: Boolean = dataFamily == DataFamily.Struct val dataRepr: Constructor => DataRepr = @@ -54,13 +53,11 @@ final case class DefinedType[+A]( val zero = c0.name { cons => if (cons == zero) DataRepr.ZeroNat else DataRepr.SuccNat } - } - else if (c1.isZeroArg && c0.hasSingleArgType(toTypeTyConst)) { + } else if (c1.isZeroArg && c0.hasSingleArgType(toTypeTyConst)) { val zero = c1.name { cons => if (cons == zero) DataRepr.ZeroNat else DataRepr.SuccNat } - } - else { + } else { val famArities = c0.arity :: c1.arity :: Nil val zero = c0.name val zrep = DataRepr.Enum(0, c0.arity, famArities) @@ -70,7 +67,9 @@ final case class DefinedType[+A]( } case cons => val famArities = cons.map(_.arity) - val mapping = cons.zipWithIndex.map { case (c, idx) => c.name -> DataRepr.Enum(idx, c.arity, famArities) }.toMap + val mapping = cons.zipWithIndex.map { case (c, idx) => + c.name -> DataRepr.Enum(idx, c.arity, famArities) + }.toMap mapping } @@ -83,12 +82,15 @@ final case class DefinedType[+A]( case c0 :: c1 :: Nil => // exactly two constructor functions if (c0.isZeroArg && c1.hasSingleArgType(toTypeTyConst)) DataFamily.Nat - else if (c1.isZeroArg && c0.hasSingleArgType(toTypeTyConst)) DataFamily.Nat + else if (c1.isZeroArg && c0.hasSingleArgType(toTypeTyConst)) + DataFamily.Nat else DataFamily.Enum case _ => DataFamily.Enum - } + } - private def toAnnotatedKinds(implicit ev: A <:< Kind.Arg): List[(Type.Var.Bound, Kind.Arg)] = { + private def toAnnotatedKinds(implicit + ev: A <:< Kind.Arg + ): List[(Type.Var.Bound, Kind.Arg)] = { type L[+X] = List[(Type.Var.Bound, X)] ev.substituteCo[L](annotatedTypeParams) } @@ -98,11 +100,11 @@ final case class DefinedType[+A]( val tc: Type.Rho = Type.const(packageName, name) val res = typeParams.foldLeft(tc) { (res, v) => - Type.TyApply(res, Type.TyVar(v)) - } + Type.TyApply(res, Type.TyVar(v)) + } val resT = NonEmptyList.fromList(cf.args.map(_._2)) match { case Some(nel) => Type.Fun(nel, res) - case None => res + case None => res } val typeArgs = toAnnotatedKinds.map { case (b, ka) => (b, ka.kind) } Type.forAll(typeArgs, resT) @@ -116,24 +118,33 @@ object DefinedType { def toTypeConst(pn: PackageName, nm: TypeName): Type.Const.Defined = Type.Const.Defined(pn, nm) - def listToMap[A](dts: List[DefinedType[A]]): SortedMap[(PackageName, TypeName), DefinedType[A]] = + def listToMap[A]( + dts: List[DefinedType[A]] + ): SortedMap[(PackageName, TypeName), DefinedType[A]] = SortedMap(dts.map { dt => (dt.packageName, dt.name) -> dt }: _*) - def toKindMap[F[_]: Foldable](dts: F[DefinedType[Kind.Arg]]): Map[Type.Const.Defined, Kind] = - dts.foldLeft( - Map.newBuilder[Type.Const.Defined, Kind] - ) { (b, dt) => b += ((dt.toTypeConst.toDefined, dt.kindOf)) } - .result() + def toKindMap[F[_]: Foldable]( + dts: F[DefinedType[Kind.Arg]] + ): Map[Type.Const.Defined, Kind] = + dts + .foldLeft( + Map.newBuilder[Type.Const.Defined, Kind] + ) { (b, dt) => b += ((dt.toTypeConst.toDefined, dt.kindOf)) } + .result() implicit val definedTypeTraverse: Traverse[DefinedType] = new Traverse[DefinedType] { val listTup = Traverse[List].compose[(Type.Var.Bound, *)] - def traverse[F[_]: Applicative, A, B](da: DefinedType[A])(fn: A => F[B]): F[DefinedType[B]] = + def traverse[F[_]: Applicative, A, B]( + da: DefinedType[A] + )(fn: A => F[B]): F[DefinedType[B]] = listTup.traverse(da.annotatedTypeParams)(fn).map { ap => da.copy(annotatedTypeParams = ap) } - def foldRight[A, B](fa: DefinedType[A], b: Eval[B])(fn: (A, Eval[B]) => Eval[B]): Eval[B] = + def foldRight[A, B](fa: DefinedType[A], b: Eval[B])( + fn: (A, Eval[B]) => Eval[B] + ): Eval[B] = listTup.foldRight(fa.annotatedTypeParams, b)(fn) def foldLeft[A, B](fa: DefinedType[A], b: B)(fn: (B, A) => B): B = diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index fd569c319..3002eef04 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -18,7 +18,8 @@ import org.bykn.bosatsu.{ Region, RecursionKind, TypedExpr, - Variance} + Variance +} import scala.collection.immutable.SortedSet @@ -43,15 +44,17 @@ sealed abstract class Infer[+A] { Infer.Impl.MapEither(this, fn) final def runVar( - v: Map[Infer.Name, Type], - tpes: Map[(PackageName, Constructor), Infer.Cons], - kinds: Map[Type.Const.Defined, Kind]): RefSpace[Either[Error, A]] = + v: Map[Infer.Name, Type], + tpes: Map[(PackageName, Constructor), Infer.Cons], + kinds: Map[Type.Const.Defined, Kind] + ): RefSpace[Either[Error, A]] = Infer.Env.init(v, tpes, kinds).flatMap(run(_)) final def runFully( - v: Map[Infer.Name, Type], - tpes: Map[(PackageName, Constructor), Infer.Cons], - kinds: Map[Type.Const.Defined, Kind]): Either[Error, A] = + v: Map[Infer.Name, Type], + tpes: Map[(PackageName, Constructor), Infer.Cons], + kinds: Map[Type.Const.Defined, Kind] + ): Either[Error, A] = runVar(v, tpes, kinds).run.value } @@ -71,35 +74,35 @@ object Infer { TailRecM(a, fn) } - implicit val inferParallel: cats.Parallel[Infer] = - new ParallelViaProduct[Infer] { - def monad = inferMonad - def parallelProduct[A, B](fa: Infer[A], fb: Infer[B]): Infer[(A, B)] = - ParallelProduct(fa, fb) - } - + implicit val inferParallel: cats.Parallel[Infer] = + new ParallelViaProduct[Infer] { + def monad = inferMonad + def parallelProduct[A, B](fa: Infer[A], fb: Infer[B]): Infer[(A, B)] = + ParallelProduct(fa, fb) + } - /** - * The first element of the tuple are the the bound type - * vars for this type. - * the next are the types of the args of the constructor - * the final is the defined type this creates - */ + /** The first element of the tuple are the the bound type vars for this type. + * the next are the types of the args of the constructor the final is the + * defined type this creates + */ type Cons = (List[(Type.Var.Bound, Kind.Arg)], List[Type], Type.Const.Defined) type Name = (Option[PackageName], Identifier) class Env( - val uniq: Ref[Long], - val vars: Map[Name, Type], - val typeCons: Map[(PackageName, Constructor), Cons], - val variances: Map[Type.Const.Defined, Kind]) { + val uniq: Ref[Long], + val vars: Map[Name, Type], + val typeCons: Map[(PackageName, Constructor), Cons], + val variances: Map[Type.Const.Defined, Kind] + ) { def addVars(vt: NonEmptyList[(Name, Type)]): Env = new Env(uniq, vars = (vars + vt.head) ++ vt.tail, typeCons, variances) private[this] val kindCache: Type => Either[Region => Error, Kind] = Type.kindOf[Region => Error]( - b => { region => Error.UnknownKindOfVar(Type.TyVar(b), region, s"unbound var: $b") }, + b => { region => + Error.UnknownKindOfVar(Type.TyVar(b), region, s"unbound var: $b") + }, ap => { region => Error.KindCannotTyApply(ap, region) }, @@ -107,17 +110,16 @@ object Infer { Error.KindInvalidApply(ap, cons, rhs, region) }, { case Type.TyConst(const) => - val d = const.toDefined - // some tests rely on syntax without importing - // TODO remove this - variances.get(d).orElse(Type.builtInKinds.get(d)) match { - case Some(ks) => Right(ks) - case None => Left({ region => Error.UnknownDefined(d, region) }) - } + val d = const.toDefined + // some tests rely on syntax without importing + // TODO remove this + variances.get(d).orElse(Type.builtInKinds.get(d)) match { + case Some(ks) => Right(ks) + case None => Left({ region => Error.UnknownDefined(d, region) }) } + } ) - def getKindOpt(t: Type): Option[Kind] = kindCache(t).toOption @@ -129,9 +131,10 @@ object Infer { object Env { def init( - vars: Map[Name, Type], - tpes: Map[(PackageName, Constructor), Cons], - kinds: Map[Type.Const.Defined, Kind]): RefSpace[Env] = + vars: Map[Name, Type], + tpes: Map[(PackageName, Constructor), Cons], + kinds: Map[Type.Const.Defined, Kind] + ): RefSpace[Env] = RefSpace.newRef(0L).map(new Env(_, vars, tpes, kinds)) } @@ -152,7 +155,7 @@ object Infer { def lookupVarType(v: Name, reg: Region): Infer[Type] = getEnv.flatMap { env => env.get(v) match { - case None => fail(Error.VarNotInScope(v, env, reg)) + case None => fail(Error.VarNotInScope(v, env, reg)) case Some(t) => pure(t) } } @@ -161,59 +164,119 @@ object Infer { object Error { sealed abstract class Single extends Error - /** - * These are errors in the ability to type the code - * Generally these cannot be caught by other phases - */ + + /** These are errors in the ability to type the code Generally these cannot + * be caught by other phases + */ sealed abstract class TypeError extends Single - case class NotUnifiable(left: Type, right: Type, leftRegion: Region, rightRegion: Region) extends TypeError - case class KindInvalidApply(typeApply: Type.TyApply, leftK: Kind.Cons, rightK: Kind, region: Region) extends TypeError - case class KindMismatch(target: Type, targetKind: Kind, source: Type, sourceKind: Kind, targetRegion: Region, sourceRegion: Region) extends TypeError - case class KindCannotTyApply(ap: Type.TyApply, region: Region) extends TypeError - case class UnknownDefined(tpe: Type.Const.Defined, region: Region) extends TypeError - case class NotPolymorphicEnough(tpe: Type, in: Expr[_], badTvs: NonEmptyList[Type], reg: Region) extends TypeError - case class SubsumptionCheckFailure(inferred: Type, declared: Type, infRegion: Region, decRegion: Region, badTvs: NonEmptyList[Type]) extends TypeError + case class NotUnifiable( + left: Type, + right: Type, + leftRegion: Region, + rightRegion: Region + ) extends TypeError + case class KindInvalidApply( + typeApply: Type.TyApply, + leftK: Kind.Cons, + rightK: Kind, + region: Region + ) extends TypeError + case class KindMismatch( + target: Type, + targetKind: Kind, + source: Type, + sourceKind: Kind, + targetRegion: Region, + sourceRegion: Region + ) extends TypeError + case class KindCannotTyApply(ap: Type.TyApply, region: Region) + extends TypeError + case class UnknownDefined(tpe: Type.Const.Defined, region: Region) + extends TypeError + case class NotPolymorphicEnough( + tpe: Type, + in: Expr[_], + badTvs: NonEmptyList[Type], + reg: Region + ) extends TypeError + case class SubsumptionCheckFailure( + inferred: Type, + declared: Type, + infRegion: Region, + decRegion: Region, + badTvs: NonEmptyList[Type] + ) extends TypeError // this sounds internal but can be due to an infinite type attempted to be defined - case class UnexpectedMeta(m: Type.Meta, in: Type, left: Region, right: Region) extends TypeError - case class ArityMismatch(leftArity: Int, leftRegion: Region, rightArity: Int, rightRegion: Region) extends TypeError - case class ArityTooLarge(arity: Int, maxArity: Int, region: Region) extends TypeError - - /** - * These are errors that prevent typing due to unknown names, - * They could be caught in a phase that collects all the naming errors - */ + case class UnexpectedMeta( + m: Type.Meta, + in: Type, + left: Region, + right: Region + ) extends TypeError + case class ArityMismatch( + leftArity: Int, + leftRegion: Region, + rightArity: Int, + rightRegion: Region + ) extends TypeError + case class ArityTooLarge(arity: Int, maxArity: Int, region: Region) + extends TypeError + + /** These are errors that prevent typing due to unknown names, They could be + * caught in a phase that collects all the naming errors + */ sealed abstract class NameError extends Single // This could be a user error if we don't check scoping before typing - case class VarNotInScope(varName: Name, vars: Map[Name, Type], region: Region) extends NameError + case class VarNotInScope( + varName: Name, + vars: Map[Name, Type], + region: Region + ) extends NameError // This could be a user error if we don't check scoping before typing - case class UnexpectedBound(v: Type.Var.Bound, in: Type, rb: Region, rt: Region) extends NameError - case class UnknownConstructor(name: (PackageName, Constructor), region: Region, env: Env) extends NameError { - def knownConstructors: List[(PackageName, Constructor)] = env.typeCons.keys.toList.sorted + case class UnexpectedBound( + v: Type.Var.Bound, + in: Type, + rb: Region, + rt: Region + ) extends NameError + case class UnknownConstructor( + name: (PackageName, Constructor), + region: Region, + env: Env + ) extends NameError { + def knownConstructors: List[(PackageName, Constructor)] = + env.typeCons.keys.toList.sorted } - case class UnionPatternBindMismatch(pattern: Pattern, names: NonEmptyList[List[Identifier.Bindable]], region: Region) extends NameError + case class UnionPatternBindMismatch( + pattern: Pattern, + names: NonEmptyList[List[Identifier.Bindable]], + region: Region + ) extends NameError - /** - * These can only happen if the compiler has bugs at some point - */ + /** These can only happen if the compiler has bugs at some point + */ sealed abstract class InternalError extends Single { def message: String def region: Region } // This is a logic error which should never happen - case class InferIncomplete(term: Expr[_], region: Region) extends InternalError { + case class InferIncomplete(term: Expr[_], region: Region) + extends InternalError { // $COVERAGE-OFF$ we don't test these messages, maybe they should be removed def message = s"inferRho not complete for $term" // $COVERAGE-ON$ we don't test these messages, maybe they should be removed } - case class ExpectedRho(tpe: Type, context: String, region: Region) extends InternalError { + case class ExpectedRho(tpe: Type, context: String, region: Region) + extends InternalError { // $COVERAGE-OFF$ we don't test these messages, maybe they should be removed def message = s"expected $tpe to be a Type.Rho, at $context" // $COVERAGE-ON$ we don't test these messages, maybe they should be removed } - case class UnknownKindOfVar(tpe: Type, region: Region, mess: String) extends InternalError { + case class UnknownKindOfVar(tpe: Type, region: Region, mess: String) + extends InternalError { // $COVERAGE-OFF$ we don't test these messages, maybe they should be removed def message = s"unknown var in $tpe: $mess at $region" // $COVERAGE-ON$ we don't test these messages, maybe they should be removed @@ -221,7 +284,11 @@ object Infer { // here is when we have more than one error case class Combine(left: Error, right: Error) extends Error { - private def flatten(errs: NonEmptyList[Error], inAcc: Set[Single], acc: Chain[Single]): NonEmptyChain[Single] = + private def flatten( + errs: NonEmptyList[Error], + inAcc: Set[Single], + acc: Chain[Single] + ): NonEmptyChain[Single] = errs match { case NonEmptyList(s: Single, tail) => tail match { @@ -229,15 +296,13 @@ object Infer { if (inAcc(s)) { // we know s is in acc, so Chain must not be empty NonEmptyChain.fromChainUnsafe(acc) - } - else { + } else { NonEmptyChain.fromChainAppend(acc, s) } case h :: t => if (inAcc(s)) { flatten(NonEmptyList(h, t), inAcc, acc) - } - else { + } else { flatten(NonEmptyList(h, t), inAcc + s, acc :+ s) } } @@ -250,16 +315,14 @@ object Infer { } } - - /** - * This is where the internal implementation goes. - * It is here to make it easy to make one block private - * and not do so on every little helper function - */ + /** This is where the internal implementation goes. It is here to make it easy + * to make one block private and not do so on every little helper function + */ private object Impl { sealed abstract class Expected[A] object Expected { - case class Inf[A](ref: Ref[Either[Error.InferIncomplete, A]]) extends Expected[A] { + case class Inf[A](ref: Ref[Either[Error.InferIncomplete, A]]) + extends Expected[A] { def set(a: A): Infer[Unit] = Infer.lift(ref.set(Right(a))) } @@ -269,23 +332,25 @@ object Infer { case class FlatMap[A, B](fa: Infer[A], fn: A => Infer[B]) extends Infer[B] { def run(env: Env) = fa.run(env).flatMap { - case Right(a) => fn(a).run(env) + case Right(a) => fn(a).run(env) case left @ Left(_) => RefSpace.pure(left.rightCast) } } - case class ParallelProduct[A, B](fa: Infer[A], fb: Infer[B]) extends Infer[(A, B)] { + case class ParallelProduct[A, B](fa: Infer[A], fb: Infer[B]) + extends Infer[(A, B)] { def run(env: Env) = fa.run(env).flatMap { - case Right(a) => fb.run(env).map { - case Right(b) => Right((a, b)) - case left @ Left(_) => left.rightCast - } + case Right(a) => + fb.run(env).map { + case Right(b) => Right((a, b)) + case left @ Left(_) => left.rightCast + } case left @ Left(errA) => fb.run(env).map { - case Right(_) => left.rightCast + case Right(_) => left.rightCast case Left(errB) => Left(Error.Combine(errA, errB)) - } + } } } @@ -298,22 +363,24 @@ object Infer { // $COVERAGE-ON$ this should be unreachable } } - case class MapEither[A, B](fa: Infer[A], fn: A => Either[Error, B]) extends Infer[B] { + case class MapEither[A, B](fa: Infer[A], fn: A => Either[Error, B]) + extends Infer[B] { def run(env: Env) = fa.run(env).flatMap { - case Right(a) => RefSpace.pure(fn(a)) + case Right(a) => RefSpace.pure(fn(a)) case left @ Left(_) => RefSpace.pure(left.rightCast) } } // $COVERAGE-OFF$ needed for Monad, but not actually used - case class TailRecM[A, B](init: A, fn: A => Infer[Either[A, B]]) extends Infer[B] { + case class TailRecM[A, B](init: A, fn: A => Infer[Either[A, B]]) + extends Infer[B] { def run(env: Env) = { // RefSpace uses Eval so this is fine, if not maybe the fastest thing ever def loop(a: A): RefSpace[Either[Error, B]] = fn(a).run(env).flatMap { - case Left(err) => RefSpace.pure(Left(err)) - case Right(Left(a)) => loop(a) + case Left(err) => RefSpace.pure(Left(err)) + case Right(Left(a)) => loop(a) case Right(Right(b)) => RefSpace.pure(Right(b)) } loop(init) @@ -322,7 +389,8 @@ object Infer { // $COVERAGE-ON$ case object GetEnv extends Infer[Env] { - def run(env: Env): RefSpace[Either[Error, Env]] = RefSpace.pure(Right(env)) + def run(env: Env): RefSpace[Either[Error, Env]] = + RefSpace.pure(Right(env)) } def GetDataCons(fqn: (PackageName, Constructor), reg: Region): Infer[Cons] = @@ -334,7 +402,8 @@ object Infer { } } - case class ExtendEnvs[A](vt: NonEmptyList[(Name, Type)], in: Infer[A]) extends Infer[A] { + case class ExtendEnvs[A](vt: NonEmptyList[(Name, Type)], in: Infer[A]) + extends Infer[A] { def run(env: Env) = in.run(env.addVars(vt)) } @@ -360,8 +429,7 @@ object Infer { private val checkedKinds: Infer[Type => Option[Kind]] = { val emptyRegion = Region(0, 0) GetEnv.map { env => - - { tpe => env.getKind(tpe, emptyRegion).toOption } + { tpe => env.getKind(tpe, emptyRegion).toOption } } } @@ -371,41 +439,46 @@ object Infer { kindOf(ta.on, region) .flatMap(varianceOfConsKind(ta, _, region)) - def varianceOfConsKind(ta: Type.TyApply, k: Kind, region: Region): Infer[Variance] = + def varianceOfConsKind( + ta: Type.TyApply, + k: Kind, + region: Region + ): Infer[Variance] = k match { case Kind.Cons(Kind.Arg(v, _), _) => pure(v) case Kind.Type => fail(Error.KindCannotTyApply(ta, region)) } - /** - * Skolemize on a function just recurses on the result type. - * - * Skolemize replaces ForAll parameters with skolem variables - * and then skolemizes recurses on the substituted value - * - * otherwise we return the type. - * - * The returned type is in weak-prenex form: all ForAlls have - * been floated up over covariant parameters - * - * see: https://www.csd.uwo.ca/~lkari/prenex.pdf - * It seems that if C[x] is covariant, then - * C[forall x. D[x]] == forall x. C[D[x]] - * - * this is always true for existential quantification I think, but - * for universal, we need that C is covariant which roughtly - * means C[x] either has x in a return position of a function, or - * not at all, which then gives us that - * (forall x. (A(x) u B(x))) == (forall x A(x)) u (forall x B(x)) - * where A(x) and B(x) represent the union branches of the type C - */ + /** Skolemize on a function just recurses on the result type. + * + * Skolemize replaces ForAll parameters with skolem variables and then + * skolemizes recurses on the substituted value + * + * otherwise we return the type. + * + * The returned type is in weak-prenex form: all ForAlls have been floated + * up over covariant parameters + * + * see: https://www.csd.uwo.ca/~lkari/prenex.pdf It seems that if C[x] is + * covariant, then C[forall x. D[x]] == forall x. C[D[x]] + * + * this is always true for existential quantification I think, but for + * universal, we need that C is covariant which roughtly means C[x] either + * has x in a return position of a function, or not at all, which then + * gives us that (forall x. (A(x) u B(x))) == (forall x A(x)) u (forall x + * B(x)) where A(x) and B(x) represent the union branches of the type C + */ private def skolemize( - t: Type, - region: Region): Infer[(List[Type.Var.Skolem], List[Type.TyMeta], Type.Rho)] = { - + t: Type, + region: Region + ): Infer[(List[Type.Var.Skolem], List[Type.TyMeta], Type.Rho)] = { + // Invariant: if t is Rho, then result._3 is Rho - def loop(t: Type, path: Variance): Infer[(List[Type.Var.Skolem], List[Type.TyMeta], Type)] = + def loop( + t: Type, + path: Variance + ): Infer[(List[Type.Var.Skolem], List[Type.TyMeta], Type)] = t match { case q: Type.Quantified => if (path == Variance.co) { @@ -414,33 +487,33 @@ object Infer { val ty = q.in // Rule PRPOLY for { - sks1 <- univ.traverse { case (b, k) => newSkolemTyVar(b, k, existential = false) } - ms <- exists.traverse { case (_, k) => newExistential(k) } + sks1 <- univ.traverse { case (b, k) => + newSkolemTyVar(b, k, existential = false) + } + ms <- exists.traverse { case (_, k) => newExistential(k) } sksT = sks1.map(Type.TyVar(_)) ty1 = Type.substituteRhoVar( ty, (exists.map(_._1).iterator.zip(ms) ++ - univ.map(_._1).iterator.zip(sksT.iterator)) - .toMap) + univ.map(_._1).iterator.zip(sksT.iterator)).toMap + ) (sks2, ms2, ty) <- loop(ty1, path) } yield (sks1 ::: sks2, ms ::: ms2, ty) - } - else pure((Nil, Nil, t)) + } else pure((Nil, Nil, t)) - case ta@Type.TyApply(left, right) => + case ta @ Type.TyApply(left, right) => // Rule PRFUN // we know the kind of left is k -> x, and right has kind k // since left: Rho, we know loop(left, path)._3 is Rho (varianceOfCons(ta, region), loop(left, path)) - .flatMapN { - case (consVar, (sksl, el, ltpe0)) => - // due to loop invariant - val ltpe: Type.Rho = ltpe0.asInstanceOf[Type.Rho] - val rightPath = consVar * path - loop(right, rightPath) - .map { case (sksr, er, rtpe) => - (sksl ::: sksr, el ::: er, Type.TyApply(ltpe, rtpe)) - } + .flatMapN { case (consVar, (sksl, el, ltpe0)) => + // due to loop invariant + val ltpe: Type.Rho = ltpe0.asInstanceOf[Type.Rho] + val rightPath = consVar * path + loop(right, rightPath) + .map { case (sksr, er, rtpe) => + (sksl ::: sksr, el ::: er, Type.TyApply(ltpe, rtpe)) + } } case other: Type.Rho => // Rule PRMONO @@ -452,7 +525,8 @@ object Infer { (skols, metas, rho) // $COVERAGE-OFF$ this should be unreachable // because we only return ForAll on paths nested inside noncovariant path in TyApply - case (sks, metas, notRho) => sys.error(s"type = $t, sks = $sks, metas = $metas notRho = $notRho") + case (sks, metas, notRho) => + sys.error(s"type = $t, sks = $sks, metas = $metas notRho = $notRho") // $COVERAGE-ON$ this should be unreachable } } @@ -466,7 +540,7 @@ object Infer { def existentialsOf(tm: Type.Meta): Infer[Set[Type.Meta]] = { val parents = readMeta(tm).flatMap { case Some(Type.TyMeta(m2)) => existentialsOf(m2) - case _ => pureEmpty + case _ => pureEmpty } if (tm.existential) parents.map(_ + tm) @@ -483,33 +557,44 @@ object Infer { val zonk: Type.Meta => Infer[Option[Type.Rho]] = Type.zonk[Infer](SortedSet.empty, readMeta _, writeMeta _) - /** - * This fills in any meta vars that have been - * quantified and replaces them with what they point to - */ + /** This fills in any meta vars that have been quantified and replaces them + * with what they point to + */ def zonkType(t: Type): Infer[Type] = Type.zonkMeta(t)(zonk(_)) def zonkTypedExpr[A](e: TypedExpr[A]): Infer[TypedExpr[A]] = TypedExpr.zonkMeta(e)(zonk(_)) - val zonkTypeExprK: FunctionK[TypedExpr.Rho, Lambda[x => Infer[TypedExpr[x]]]] = + val zonkTypeExprK + : FunctionK[TypedExpr.Rho, Lambda[x => Infer[TypedExpr[x]]]] = new FunctionK[TypedExpr.Rho, Lambda[x => Infer[TypedExpr[x]]]] { def apply[A](fa: TypedExpr[A]): Infer[TypedExpr[A]] = zonkTypedExpr(fa) } - def initRef[E: HasRegion, A](t: Expr[E]): Infer[Ref[Either[Error.InferIncomplete, A]]] = - lift(RefSpace.newRef[Either[Error.InferIncomplete, A]]( - Left(Error.InferIncomplete(t, region(t)))) + def initRef[E: HasRegion, A]( + t: Expr[E] + ): Infer[Ref[Either[Error.InferIncomplete, A]]] = + lift( + RefSpace.newRef[Either[Error.InferIncomplete, A]]( + Left(Error.InferIncomplete(t, region(t))) + ) ) - def substTyRho(keys: NonEmptyList[Type.Var], vals: NonEmptyList[Type.Rho]): Type.Rho => Type.Rho = { + def substTyRho( + keys: NonEmptyList[Type.Var], + vals: NonEmptyList[Type.Rho] + ): Type.Rho => Type.Rho = { val env = keys.toList.iterator.zip(vals.toList.iterator).toMap { t => Type.substituteRhoVar(t, env) } } - def substTyExpr[A](keys: NonEmptyList[Type.Var], vals: NonEmptyList[Type.Rho], expr: TypedExpr[A]): TypedExpr[A] = { + def substTyExpr[A]( + keys: NonEmptyList[Type.Var], + vals: NonEmptyList[Type.Rho], + expr: TypedExpr[A] + ): TypedExpr[A] = { val fn = Type.substTy(keys, vals) expr.traverseType[cats.Id](fn) } @@ -523,7 +608,11 @@ object Infer { * new meta variables for each bound variable in ForAll or skolemize * which replaces the ForAll variables with skolem variables */ - def assertRho(t: Type, context: => String, region: Region): Infer[Type.Rho] = + def assertRho( + t: Type, + context: => String, + region: Region + ): Infer[Type.Rho] = t match { case r: Type.Rho => pure(r) // $COVERAGE-OFF$ this should be unreachable @@ -546,36 +635,47 @@ object Infer { val univRho = NonEmptyList.fromList(univs) match { case Some(vars) => - vars.traverse { case (_, k) => newMetaType(k) } + vars + .traverse { case (_, k) => newMetaType(k) } .map { vars1T => substTyRho(vars.map(_._1), vars1T)(rho) } case None => pure(rho) } - univRho.flatMap { rho => - val exists = q.existList - for { - skols <- exists.traverse { case (b, k) => newSkolemTyVar(b, k, existential = true) } - env = exists - .iterator - .map(_._1) - .zip(skols.iterator.map(Type.TyVar)) - .toMap[Type.Var, Type.TyVar] - rho1 = Type.substituteRhoVar(rho, env) - } yield (skols, rho1) - } + univRho.flatMap { rho => + val exists = q.existList + for { + skols <- exists.traverse { case (b, k) => + newSkolemTyVar(b, k, existential = true) + } + env = exists.iterator + .map(_._1) + .zip(skols.iterator.map(Type.TyVar)) + .toMap[Type.Var, Type.TyVar] + rho1 = Type.substituteRhoVar(rho, env) + } yield (skols, rho1) + } case rho: Type.Rho => pure((Nil, rho)) } /* * Invariant: r2 needs to be in weak prenex form */ - def subsCheckFn(a1s: NonEmptyList[Type], r1: Type, a2s: NonEmptyList[Type], r2: Type.Rho, left: Region, right: Region): Infer[TypedExpr.Coerce] = + def subsCheckFn( + a1s: NonEmptyList[Type], + r1: Type, + a2s: NonEmptyList[Type], + r2: Type.Rho, + left: Region, + right: Region + ): Infer[TypedExpr.Coerce] = // note due to contravariance in input, we reverse the order there for { // we know that they have the same length because we have already called unifyFnRho - coarg <- a2s.zip(a1s).parTraverse { case (a2, a1) => subsCheck(a2, a1, right, left) } + coarg <- a2s.zip(a1s).parTraverse { case (a2, a1) => + subsCheck(a2, a1, right, left) + } // r2 is already in weak-prenex form cores <- subsCheckRho(r1, r2, left, right) ks <- checkedKinds @@ -589,12 +689,17 @@ object Infer { * was rewritten to: * forall a, b. a -> (b -> b) */ - def subsCheckRho(t: Type, rho: Type.Rho, left: Region, right: Region): Infer[TypedExpr.Coerce] = + def subsCheckRho( + t: Type, + rho: Type.Rho, + left: Region, + right: Region + ): Infer[TypedExpr.Coerce] = (t, rho) match { case (fa: Type.Quantified, rho) => subsInstantiate(fa, rho, left, right) match { case Some(inf) => inf - case None => + case None => // Rule SPEC for { (exSkols, faRho) <- instantiate(fa) @@ -611,86 +716,107 @@ object Infer { private val idCoerce = pure(FunctionK.id[TypedExpr]) // if t <:< rho, then coerce to rho - def subsCheckRho2(t: Type.Rho, rho: Type.Rho, left: Region, right: Region): Infer[TypedExpr.Coerce] = + def subsCheckRho2( + t: Type.Rho, + rho: Type.Rho, + left: Region, + right: Region + ): Infer[TypedExpr.Coerce] = if (t == rho) idCoerce - else (t, rho) match { - case (rho1, Type.Fun(a2, r2)) => - // Rule FUN - for { - (a1, r1) <- unifyFnRho(a2.length, rho1, left, right) - // since rho is in weak prenex form, and Fun is covariant on r2, we know - // r2 is in weak-prenex form and a rho type - rhor2 <- assertRho(r2, s"subsCheckRho2($t, $rho, $left, $right), line 619", right) - coerce <- subsCheckFn(a1, r1, a2, rhor2, left, right) - } yield coerce - case (Type.Fun(a1, r1), rho2) => - // Rule FUN - for { - (a2, r2) <- unifyFnRho(a1.length, rho2, right, left) - // since rho is in weak prenex form, and Fun is covariant on r2, we know - // r2 is in weak-prenex form - rhor2 <- assertRho(r2, s"subsCheckRho2($t, $rho, $left, $right), line 628", right) - coerce <- subsCheckFn(a1, r1, a2, rhor2, left, right) - } yield coerce - case (rho1, ta@Type.TyApply(l2, r2)) => - for { - (kl, kr) <- validateKinds(ta, right) - (l1, r1) <- unifyTyApp(rho1, kl, kr, left, right) - // Check from right to left - _ <- subsCheckRho2(l1, l2, left, right) - _ <- varianceOfConsKind(ta, kl, right).flatMap { - case Variance.Covariant => - subsCheck(r1, r2, left, right) - case Variance.Contravariant => - subsCheck(r2, r1, right, left) - case Variance.Phantom => - // this doesn't matter - unit - case Variance.Invariant => - unifyType(r1, r2, left, right) - } - ks <- checkedKinds - } yield TypedExpr.coerceRho(ta, ks) - case (ta@Type.TyApply(l1, r1), rho2) => - // here we know that rho2 != TyApply - for { - (kl, kr) <- validateKinds(ta, left) - // here we set the kinds of l2: kl and r2: k2 - // so the kinds definitely match - (l2, r2) <- unifyTyApp(rho2, kl, kr, right, left) - // Check from right to left - _ <- subsCheckRho2(l1, l2, left, right) - // we know that l2 has kind kl - _ <- varianceOfConsKind(Type.TyApply(l2, r2), kl, right).flatMap { - case Variance.Covariant => - subsCheck(r1, r2, left, right) - case Variance.Contravariant => - subsCheck(r2, r1, right, left) - case Variance.Phantom => - // this doesn't matter - unit - case Variance.Invariant => - unifyType(r1, r2, left, right) - } - ks <- checkedKinds - } yield TypedExpr.coerceRho(rho2, ks) - case (t1, t2) => - // rule: MONO - for { - _ <- unify(t1, t2, left, right) - ck <- checkedKinds - } yield TypedExpr.coerceRho(t1, ck) // TODO this coerce seems right, since we have unified - } + else + (t, rho) match { + case (rho1, Type.Fun(a2, r2)) => + // Rule FUN + for { + (a1, r1) <- unifyFnRho(a2.length, rho1, left, right) + // since rho is in weak prenex form, and Fun is covariant on r2, we know + // r2 is in weak-prenex form and a rho type + rhor2 <- assertRho( + r2, + s"subsCheckRho2($t, $rho, $left, $right), line 619", + right + ) + coerce <- subsCheckFn(a1, r1, a2, rhor2, left, right) + } yield coerce + case (Type.Fun(a1, r1), rho2) => + // Rule FUN + for { + (a2, r2) <- unifyFnRho(a1.length, rho2, right, left) + // since rho is in weak prenex form, and Fun is covariant on r2, we know + // r2 is in weak-prenex form + rhor2 <- assertRho( + r2, + s"subsCheckRho2($t, $rho, $left, $right), line 628", + right + ) + coerce <- subsCheckFn(a1, r1, a2, rhor2, left, right) + } yield coerce + case (rho1, ta @ Type.TyApply(l2, r2)) => + for { + (kl, kr) <- validateKinds(ta, right) + (l1, r1) <- unifyTyApp(rho1, kl, kr, left, right) + // Check from right to left + _ <- subsCheckRho2(l1, l2, left, right) + _ <- varianceOfConsKind(ta, kl, right).flatMap { + case Variance.Covariant => + subsCheck(r1, r2, left, right) + case Variance.Contravariant => + subsCheck(r2, r1, right, left) + case Variance.Phantom => + // this doesn't matter + unit + case Variance.Invariant => + unifyType(r1, r2, left, right) + } + ks <- checkedKinds + } yield TypedExpr.coerceRho(ta, ks) + case (ta @ Type.TyApply(l1, r1), rho2) => + // here we know that rho2 != TyApply + for { + (kl, kr) <- validateKinds(ta, left) + // here we set the kinds of l2: kl and r2: k2 + // so the kinds definitely match + (l2, r2) <- unifyTyApp(rho2, kl, kr, right, left) + // Check from right to left + _ <- subsCheckRho2(l1, l2, left, right) + // we know that l2 has kind kl + _ <- varianceOfConsKind(Type.TyApply(l2, r2), kl, right).flatMap { + case Variance.Covariant => + subsCheck(r1, r2, left, right) + case Variance.Contravariant => + subsCheck(r2, r1, right, left) + case Variance.Phantom => + // this doesn't matter + unit + case Variance.Invariant => + unifyType(r1, r2, left, right) + } + ks <- checkedKinds + } yield TypedExpr.coerceRho(rho2, ks) + case (t1, t2) => + // rule: MONO + for { + _ <- unify(t1, t2, left, right) + ck <- checkedKinds + } yield TypedExpr.coerceRho( + t1, + ck + ) // TODO this coerce seems right, since we have unified + } /* * Invariant: if the second argument is (Check rho) then rho is in weak prenex form */ - def instSigma(sigma: Type, expect: Expected[(Type.Rho, Region)], r: Region): Infer[TypedExpr.Coerce] = + def instSigma( + sigma: Type, + expect: Expected[(Type.Rho, Region)], + r: Region + ): Infer[TypedExpr.Coerce] = expect match { case Expected.Check((t, tr)) => // note t is in weak-prenex form subsCheckRho(sigma, t, r, tr) - case infer@Expected.Inf(_) => + case infer @ Expected.Inf(_) => for { (exSkols, rho) <- instantiate(sigma) _ <- infer.set((rho, r)) @@ -699,12 +825,18 @@ object Infer { } yield coerce.andThen(unskolemizeExists(exSkols)) } - def unifyFnRho(arity: Int, fnType: Type.Rho, fnRegion: Region, evidenceRegion: Region): Infer[(NonEmptyList[Type], Type)] = + def unifyFnRho( + arity: Int, + fnType: Type.Rho, + fnRegion: Region, + evidenceRegion: Region + ): Infer[(NonEmptyList[Type], Type)] = fnType match { case Type.Fun(arg, res) => val fnArity = arg.length if (fnArity == arity) pure((arg, res)) - else fail(Error.ArityMismatch(fnArity, fnRegion, arity, evidenceRegion)) + else + fail(Error.ArityMismatch(fnArity, fnRegion, arity, evidenceRegion)) case tau => if (Type.FnType.ValidArity.unapply(arity)) { val sized = NonEmptyList.fromListUnsafe((1 to arity).toList) @@ -713,9 +845,10 @@ object Infer { resT <- newMeta _ <- unify(tau, Type.Fun(argT, resT), fnRegion, evidenceRegion) } yield (argT, resT) - } - else { - fail(Error.ArityTooLarge(arity, Type.FnType.MaxSize, evidenceRegion)) + } else { + fail( + Error.ArityTooLarge(arity, Type.FnType.MaxSize, evidenceRegion) + ) } } @@ -723,18 +856,27 @@ object Infer { kindOf(ta.on, region) .parProduct(kindOf(ta.arg, region)) .flatMap { case tup @ (lKind, rKind) => - Kind.validApply[Error](lKind, rKind, - Error.KindCannotTyApply(ta, region)) { cons => - Error.KindInvalidApply(ta, cons, rKind, region) - } match { - case Right(_) => pure(tup) - case Left(err) => fail(err) - } + Kind.validApply[Error]( + lKind, + rKind, + Error.KindCannotTyApply(ta, region) + ) { cons => + Error.KindInvalidApply(ta, cons, rKind, region) + } match { + case Right(_) => pure(tup) + case Left(err) => fail(err) + } } // destructure apType in left[right] // invariant apType is being checked against some rho with validated kind: lKind[rKind] - def unifyTyApp(apType: Type.Rho, lKind: Kind, rKind: Kind, apRegion: Region, evidenceRegion: Region): Infer[(Type.Rho, Type)] = + def unifyTyApp( + apType: Type.Rho, + lKind: Kind, + rKind: Kind, + apRegion: Region, + evidenceRegion: Region + ): Infer[(Type.Rho, Type)] = apType match { case ta @ Type.TyApply(left, right) => // this branch only happens when checking ta <:< (rho: lKind[rKind]) or >:> (rho) @@ -752,32 +894,36 @@ object Infer { } // invariant the flexible type variable ty1 is not bound - def unifyUnboundVar(ty1: Type.TyMeta, ty2: Type.Tau, left: Region, right: Region): Infer[Unit] = + def unifyUnboundVar( + ty1: Type.TyMeta, + ty2: Type.Tau, + left: Region, + right: Region + ): Infer[Unit] = ty2 match { - case meta2@Type.TyMeta(m2) => + case meta2 @ Type.TyMeta(m2) => val m = ty1.toMeta // we have to check that the kind matches before writing to a meta if (m.kind == m2.kind) { val cmp = Ordering[Type.Meta].compare(m, m2) if (cmp == 0) unit - else (readMeta(m2).flatMap { - case Some(ty2) => - // we know that m2 is set, but m is not because ty1 is unbound - if (m.existential == m2.existential) { - // we unify here because ty2 could possibly be ty1 - unify(ty1, ty2, left, right) - } - else if (m.existential) { - // m2.existential == false - // we need to point m2 at m - writeMeta(m, ty2) *> writeMeta(m2, ty1) - } - else { - // m.existential == false && m2.existential == true - // we need to point m at m2 - writeMeta(m, meta2) - } - case None => + else + (readMeta(m2).flatMap { + case Some(ty2) => + // we know that m2 is set, but m is not because ty1 is unbound + if (m.existential == m2.existential) { + // we unify here because ty2 could possibly be ty1 + unify(ty1, ty2, left, right) + } else if (m.existential) { + // m2.existential == false + // we need to point m2 at m + writeMeta(m, ty2) *> writeMeta(m2, ty1) + } else { + // m.existential == false && m2.existential == true + // we need to point m at m2 + writeMeta(m, meta2) + } + case None => // Both m and m2 are not set. We just point one at the other // by convention point to the smaller item which // definitely prevents cycles. @@ -788,11 +934,19 @@ object Infer { // creating a self-loop here writeMeta(m2, ty1) } - }) - } - else { - fail(Error.KindMismatch(ty1, ty1.toMeta.kind, meta2, m2.kind, left, right)) - } + }) + } else { + fail( + Error.KindMismatch( + ty1, + ty1.toMeta.kind, + meta2, + m2.kind, + left, + right + ) + ) + } case nonMeta => // we have a non-meta, but inside of it (TyApply) we may have // metas. Let's go ahead and zonk them now to minimize nesting @@ -813,48 +967,65 @@ object Infer { // I can't seem to find a way to exploit this to produce // a forall a. a value. writeMeta(m, nonMeta) - } - else { - fail(Error.KindMismatch(ty1, m.kind, nonMeta, nmk, left, right)) + } else { + fail( + Error.KindMismatch( + ty1, + m.kind, + nonMeta, + nmk, + left, + right + ) + ) } } - } + } } } - def unifyVar(tv: Type.TyMeta, t: Type.Tau, left: Region, right: Region): Infer[Unit] = + def unifyVar( + tv: Type.TyMeta, + t: Type.Tau, + left: Region, + right: Region + ): Infer[Unit] = readMeta(tv.toMeta).flatMap { - case None => unifyUnboundVar(tv, t, left, right) + case None => unifyUnboundVar(tv, t, left, right) case Some(ty1) => unify(ty1, t, left, right) } - def show(t: Type): String = Type.fullyResolvedDocument.document(t).render(80) + def show(t: Type): String = + Type.fullyResolvedDocument.document(t).render(80) - def unify(t1: Type.Tau, t2: Type.Tau, r1: Region, r2: Region): Infer[Unit] = { + def unify( + t1: Type.Tau, + t2: Type.Tau, + r1: Region, + r2: Region + ): Infer[Unit] = { (t1, t2) match { case (Type.TyMeta(m1), Type.TyMeta(m2)) if m1.id == m2.id => unit - case (meta@Type.TyMeta(_), tpe) => unifyVar(meta, tpe, r1, r2) - case (tpe, meta@Type.TyMeta(_)) => unifyVar(meta, tpe, r2, r1) + case (meta @ Type.TyMeta(_), tpe) => unifyVar(meta, tpe, r1, r2) + case (tpe, meta @ Type.TyMeta(_)) => unifyVar(meta, tpe, r2, r1) case (t1 @ Type.TyApply(a1, b1), t2 @ Type.TyApply(a2, b2)) => - validateKinds(t1, r1) &> + validateKinds(t1, r1) &> validateKinds(t2, r2) &> unify(a1, a2, r1, r2) &> unifyType(b1, b2, r1, r2) case (Type.TyConst(c1), Type.TyConst(c2)) if c1 == c2 => unit - case (Type.TyVar(v1), Type.TyVar(v2)) if v1 == v2 => unit - case (Type.TyVar(b@Type.Var.Bound(_)), _) => + case (Type.TyVar(v1), Type.TyVar(v2)) if v1 == v2 => unit + case (Type.TyVar(b @ Type.Var.Bound(_)), _) => fail(Error.UnexpectedBound(b, t2, r1, r2)) - case (_, Type.TyVar(b@Type.Var.Bound(_))) => + case (_, Type.TyVar(b @ Type.Var.Bound(_))) => fail(Error.UnexpectedBound(b, t1, r2, r1)) case (left, right) => fail(Error.NotUnifiable(left, right, r1, r2)) } } - /** - * for a type to be unified, we mean we can substitute in either - * direction - */ + /** for a type to be unified, we mean we can substitute in either direction + */ def unifyType(t1: Type, t2: Type, r1: Region, r2: Region): Infer[Unit] = (t1, t2) match { case (rho1: Type.Rho, rho2: Type.Rho) => @@ -864,10 +1035,10 @@ object Infer { } private val emptyRef = lift(RefSpace.newRef[Option[Type.Tau]](None)) - /** - * Allocate a new Meta variable which - * will point to a Tau (no forall anywhere) type - */ + + /** Allocate a new Meta variable which will point to a Tau (no forall + * anywhere) type + */ def newMetaType0(kind: Kind, existential: Boolean): Infer[Type.TyMeta] = for { id <- nextId @@ -885,33 +1056,40 @@ object Infer { newMetaType0(kind, existential = true) // TODO: it would be nice to support kind inference on skolem variables - def newSkolemTyVar(tv: Type.Var.Bound, kind: Kind, existential: Boolean): Infer[Type.Var.Skolem] = + def newSkolemTyVar( + tv: Type.Var.Bound, + kind: Kind, + existential: Boolean + ): Infer[Type.Var.Skolem] = nextId.map(Type.Var.Skolem(tv.name, kind, existential, _)) - /** - * See if the meta variable has been set with a Tau - * type - */ + /** See if the meta variable has been set with a Tau type + */ def readMeta(m: Type.Meta): Infer[Option[Type.Tau]] = lift(m.ref.get) - /** - * Set the meta variable to point to a Tau type - */ + /** Set the meta variable to point to a Tau type + */ private def writeMeta(m: Type.Meta, v: Type.Tau): Infer[Unit] = lift(m.ref.set(Some(v))) private def clearMeta(m: Type.Meta): Infer[Unit] = lift(m.ref.set(None)) - implicit class AndThenMap[F[_], G[_], J[_]](private val fk: FunctionK[F, Lambda[x => G[J[x]]]]) extends AnyVal { - def andThenMap[H[_]](fn2: FunctionK[J, H])(implicit G: Functor[G]): FunctionK[F, Lambda[x => G[H[x]]]] = + implicit class AndThenMap[F[_], G[_], J[_]]( + private val fk: FunctionK[F, Lambda[x => G[J[x]]]] + ) extends AnyVal { + def andThenMap[H[_]]( + fn2: FunctionK[J, H] + )(implicit G: Functor[G]): FunctionK[F, Lambda[x => G[H[x]]]] = new FunctionK[F, Lambda[x => G[H[x]]]] { def apply[A](fa: F[A]): G[H[A]] = fk(fa).map(fn2(_)) } - def andThenFlatMap[H[_]](fn2: FunctionK[J, Lambda[x => G[H[x]]]])(implicit G: Monad[G]): FunctionK[F, Lambda[x => G[H[x]]]] = + def andThenFlatMap[H[_]]( + fn2: FunctionK[J, Lambda[x => G[H[x]]]] + )(implicit G: Monad[G]): FunctionK[F, Lambda[x => G[H[x]]]] = new FunctionK[F, Lambda[x => G[H[x]]]] { def apply[A](fa: F[A]): G[H[A]] = fk(fa).flatMap(fn2(_)) @@ -919,12 +1097,13 @@ object Infer { } def checkEscapeSkols[A]( - skols: List[Type.Var.Skolem], - declared: Type, - envTpes: Infer[List[Type]], - a: A, - onErr: NonEmptyList[Type] => Error)( - fn: (A, NonEmptyList[Type.Var.Skolem]) => A + skols: List[Type.Var.Skolem], + declared: Type, + envTpes: Infer[List[Type]], + a: A, + onErr: NonEmptyList[Type] => Error + )( + fn: (A, NonEmptyList[Type.Var.Skolem]) => A ): Infer[A] = skols match { case Nil => pure(a) @@ -941,15 +1120,16 @@ object Infer { case Some(badTvs) => fail(onErr(badTvs.map(Type.TyVar(_)))) } } - } + } def checkEscapeMetas[A]( - metas: List[Type.TyMeta], - declared: Type, - envTpes: Infer[List[Type]], - a: A, - onErr: NonEmptyList[Type] => Error)( - fn: (A, NonEmptyList[Type.TyMeta]) => Infer[A] + metas: List[Type.TyMeta], + declared: Type, + envTpes: Infer[List[Type]], + a: A, + onErr: NonEmptyList[Type] => Error + )( + fn: (A, NonEmptyList[Type.TyMeta]) => Infer[A] ): Infer[A] = metas match { case Nil => pure(a) @@ -968,59 +1148,73 @@ object Infer { case Some(badTvs) => fail(onErr(badTvs)) } } - } + } def subsUpper[F[_], G[_]: Functor]( - declared: Type, - region: Region, - envTpes: Infer[List[Type]])( - fn: (List[Type.TyMeta], Type.Rho) => Infer[FunctionK[F, Lambda[x => G[TypedExpr[x]]]]])( - onErr: NonEmptyList[Type] => Error): Infer[FunctionK[F, Lambda[x => G[TypedExpr[x]]]]] = + declared: Type, + region: Region, + envTpes: Infer[List[Type]] + )( + fn: ( + List[Type.TyMeta], + Type.Rho + ) => Infer[FunctionK[F, Lambda[x => G[TypedExpr[x]]]]] + )( + onErr: NonEmptyList[Type] => Error + ): Infer[FunctionK[F, Lambda[x => G[TypedExpr[x]]]]] = for { (skols, metas, rho) <- skolemize(declared, region) coerce <- fn(metas, rho) // if there are no skolem variables, we can shortcut here, because empty.filter(fn) == empty - resSkols <- checkEscapeSkols( - skols, - declared, - envTpes, - coerce, - onErr) { (coerce, nel) => coerce.andThenMap(unskolemize(nel)) } - res <- checkEscapeMetas( - metas, - declared, - envTpes, - resSkols, - onErr) { (coerce, _) => + resSkols <- checkEscapeSkols(skols, declared, envTpes, coerce, onErr) { + (coerce, nel) => coerce.andThenMap(unskolemize(nel)) + } + res <- checkEscapeMetas(metas, declared, envTpes, resSkols, onErr) { + (coerce, _) => // TODO maybe this function should go ahead and quantify pure(coerce) - } + } } yield res - def subsInstantiate(inferred: Type, declared: Type, left: Region, right: Region): Option[Infer[TypedExpr.Coerce]] = + def subsInstantiate( + inferred: Type, + declared: Type, + left: Region, + right: Region + ): Option[Infer[TypedExpr.Coerce]] = inferred match { case Type.ForAll(vars, inT) => - Type.instantiate(vars.iterator.toMap, inT, declared).map { case (_, subs) => - validateSubs(subs.toList, left, right) - .as { - new FunctionK[TypedExpr, TypedExpr] { - def apply[A](te: TypedExpr[A]): TypedExpr[A] = - // we apply the annotation here and let Normalization - // instantiate. We could explicitly have - // instantiation TypedExpr where you pass the variables to set - TypedExpr.Annotation(te, declared) + Type.instantiate(vars.iterator.toMap, inT, declared).map { + case (_, subs) => + validateSubs(subs.toList, left, right) + .as { + new FunctionK[TypedExpr, TypedExpr] { + def apply[A](te: TypedExpr[A]): TypedExpr[A] = + // we apply the annotation here and let Normalization + // instantiate. We could explicitly have + // instantiation TypedExpr where you pass the variables to set + TypedExpr.Annotation(te, declared) + } } - } } case _ => None } // note, this is identical to subsCheckRho when declared is a Rho type - def subsCheck(inferred: Type, declared: Type, left: Region, right: Region): Infer[TypedExpr.Coerce] = { + def subsCheck( + inferred: Type, + declared: Type, + left: Region, + right: Region + ): Infer[TypedExpr.Coerce] = { subsInstantiate(inferred, declared, left, right) match { case Some(inf) => inf - case None => + case None => // DEEP-SKOL rule - subsUpper[TypedExpr, cats.Id](declared, right, pure(inferred :: Nil)) { (_, rho) => + subsUpper[TypedExpr, cats.Id]( + declared, + right, + pure(inferred :: Nil) + ) { (_, rho) => // TODO: we are ignoring the metas, but we can't easily write them // with the current design since Coerce can't do any Meta writing subsCheckRho(inferred, rho, left, right) @@ -1030,99 +1224,121 @@ object Infer { } } - def inferForAll[A: HasRegion](tpes: NonEmptyList[(Type.Var.Bound, Kind)], expr: Expr[A]): Infer[TypedExpr[A]] = + def inferForAll[A: HasRegion]( + tpes: NonEmptyList[(Type.Var.Bound, Kind)], + expr: Expr[A] + ): Infer[TypedExpr[A]] = for { - (skols, t1) <- Expr.skolemizeVars(tpes, expr)(newSkolemTyVar(_, _, existential = false)) + (skols, t1) <- Expr.skolemizeVars(tpes, expr)( + newSkolemTyVar(_, _, existential = false) + ) sigmaT <- inferSigma(t1) z <- zonkTypedExpr(sigmaT) } yield unskolemize(skols)(z) def unsolvedExistentials(ts: List[Type]): Infer[List[Type.Meta]] = - Type.metaTvs(ts) + Type + .metaTvs(ts) .iterator .filter(_.existential) .toList .traverseFilter { m => readMeta(m).map { - case None => Some(m) + case None => Some(m) case Some(_) => None } } private[this] val pureNone: Infer[None.type] = pure(None) - def solvedExistentitals(lst: List[Type.Meta]): Infer[SortedMap[Type.Meta, Type.Rho]] = - lst.traverseFilter { m => - readMeta(m).flatMap { - case Some(tau) => - // reset this meta - clearMeta(m).as(Some((m, tau))) - case None => pureNone + def solvedExistentitals( + lst: List[Type.Meta] + ): Infer[SortedMap[Type.Meta, Type.Rho]] = + lst + .traverseFilter { m => + readMeta(m).flatMap { + case Some(tau) => + // reset this meta + clearMeta(m).as(Some((m, tau))) + case None => pureNone + } } - } - .map(_.to(SortedMap)) + .map(_.to(SortedMap)) - def unifyExistential(m: Type.Meta, values: List[(Type.Rho, Region)]): Infer[Unit] = { + def unifyExistential( + m: Type.Meta, + values: List[(Type.Rho, Region)] + ): Infer[Unit] = { def loop(values: List[(Type.Rho, Region)]): Infer[Unit] = values match { case (Type.TyMeta(m1), r1) :: tail => readMeta(m1).flatMap { case Some(rho1) => loop((rho1, r1) :: tail) - case None => loop(tail) + case None => loop(tail) } case (h1, _) :: (t1 @ ((h2, _) :: _)) => if (h1 == h2) loop(t1) else { // there are at least two distinct values, make a new meta skolem nextId.flatMap { id => - val skol = Type.Var.Skolem(s"meta${m.id}", m.kind, existential = true, id) + val skol = Type.Var.Skolem( + s"meta${m.id}", + m.kind, + existential = true, + id + ) val tpe = Type.TyVar(skol) writeMeta(m, tpe) } } case (h, _) :: Nil => writeMeta(m, h) - case Nil => unit + case Nil => unit } loop(values) } - /** - * This idea here is that each branch may solve a different value of a given - * existential type. If that happens, we can assign an existential skolem - * variable to the type and move on. - * - * In this way, existentials are a kind of union type, and the skolem represents - * a value later that will be exists x. ... + /** This idea here is that each branch may solve a different value of a + * given existential type. If that happens, we can assign an existential + * skolem variable to the type and move on. + * + * In this way, existentials are a kind of union type, and the skolem + * represents a value later that will be exists x. ... */ def unifyBranchExistentials( - lst: List[Type.Meta], - branches: NonEmptyList[(SortedMap[Type.Meta, Type.Rho], Region)]): Infer[Unit] = + lst: List[Type.Meta], + branches: NonEmptyList[(SortedMap[Type.Meta, Type.Rho], Region)] + ): Infer[Unit] = lst.traverse_ { m => val toUnify = branches.toList.mapFilter { case (s, region) => - s.get(m).map((_, region)) + s.get(m).map((_, region)) } - + // despite the name, this can't fail unifyExistential(m, toUnify) } - - def maybeSimple[A: HasRegion](term: Expr[A]): Option[Infer[TypedExpr[A]]] = { + def maybeSimple[A: HasRegion]( + term: Expr[A] + ): Option[Infer[TypedExpr[A]]] = { import Expr._ term match { case Literal(lit, t) => val tpe = Type.getTypeOf(lit) Some(pure(TypedExpr.Literal(lit, tpe, t))) case Local(name, tag) => - Some(lookupVarType((None, name), region(term)) - .map { vSigma => - TypedExpr.Local(name, vSigma, tag) - }) + Some( + lookupVarType((None, name), region(term)) + .map { vSigma => + TypedExpr.Local(name, vSigma, tag) + } + ) case Global(pack, name, tag) => - Some(lookupVarType((Some(pack), name), region(term)) - .map { vSigma => - TypedExpr.Global(pack, name, vSigma, tag) - }) + Some( + lookupVarType((Some(pack), name), region(term)) + .map { vSigma => + TypedExpr.Global(pack, name, vSigma, tag) + } + ) case Annotation(term, tpe, _) => Some(checkSigma(term, tpe)) case _ => @@ -1130,37 +1346,56 @@ object Infer { } } - def validateSubs(list: List[(Type.Var.Bound, (Kind, Type))], left: Region, right: Region): Infer[Unit] = + def validateSubs( + list: List[(Type.Var.Bound, (Kind, Type))], + left: Region, + right: Region + ): Infer[Unit] = list.parTraverse_ { case (boundVar, (kind, tpe)) => kindOf(tpe, right).flatMap { k => if (Kind.leftSubsumesRight(kind, k)) { unit - } - else { - fail(Error.KindMismatch(Type.TyVar(boundVar), kind, tpe, k, left, right)) + } else { + fail( + Error.KindMismatch( + Type.TyVar(boundVar), + kind, + tpe, + k, + left, + right + ) + ) } } } - def checkApply[A: HasRegion](fn: Expr[A], args: NonEmptyList[Expr[A]], tag: A, tpe: Type, tpeRegion: Region): Infer[TypedExpr[A]] = { + def checkApply[A: HasRegion]( + fn: Expr[A], + args: NonEmptyList[Expr[A]], + tag: A, + tpe: Type, + tpeRegion: Region + ): Infer[TypedExpr[A]] = { val infOpt = maybeSimple(fn).flatTraverse { inferFnExpr => inferFnExpr.map { fnTe => fnTe.getType match { - case Type.Fun.SimpleUniversal(univ, inT, outT) if inT.length == args.length => + case Type.Fun.SimpleUniversal(univ, inT, outT) + if inT.length == args.length => // see if we can instantiate the result type // if we can, we use that to fix the known parameters and continue - Type.instantiate(univ.iterator.toMap, outT, tpe).flatMap { case (frees, inst) => - // if instantiate works, we know outT => tpe - if (inst.nonEmpty && frees.isEmpty) { - // we made some progress and there are no frees - // TODO: we could support frees it seems but - // it triggers failures in tests now - Some((fnTe, inT, frees, inst)) - } - else { - // We learned nothing - None - } + Type.instantiate(univ.iterator.toMap, outT, tpe).flatMap { + case (frees, inst) => + // if instantiate works, we know outT => tpe + if (inst.nonEmpty && frees.isEmpty) { + // we made some progress and there are no frees + // TODO: we could support frees it seems but + // it triggers failures in tests now + Some((fnTe, inT, frees, inst)) + } else { + // We learned nothing + None + } } case _ => None @@ -1171,7 +1406,8 @@ object Infer { infOpt.flatMap { case Some((fnTe, inT, frees, inst)) => val regTe = region(tag) - val validKinds: Infer[Unit] = validateSubs(inst.toList, region(fn), regTe) + val validKinds: Infer[Unit] = + validateSubs(inst.toList, region(fn), regTe) val instNoKind = inst.iterator .map { case (k, (_, t)) => (k, t) } .toMap[Type.Var, Type] @@ -1181,29 +1417,30 @@ object Infer { validKinds.parProductR { val remainingFree = NonEmptyList.fromList( - frees.iterator.map { case (_, (k, b)) => (b, k) } - .toList + frees.iterator.map { case (_, (k, b)) => (b, k) }.toList ) - + remainingFree match { case None => // we can fully instantiate - args.zip(subIn).parTraverse { case (e, t) => - checkSigma(e, t) - } - .map { argsTE => - TypedExpr.App(fnTe, argsTE, tpe, tag) - } - + args + .zip(subIn) + .parTraverse { case (e, t) => + checkSigma(e, t) + } + .map { argsTE => + TypedExpr.App(fnTe, argsTE, tpe, tag) + } + // $COVERAGE-OFF$ - //case Some(remainingFree) => + // case Some(remainingFree) => case Some(_) => - // Currently we are only returning infOpt as Some when - // there are no remaining free variables due to unit - // tests not passing - sys.error("unreachable") + // Currently we are only returning infOpt as Some when + // there are no remaining free variables due to unit + // tests not passing + sys.error("unreachable") // $COVERAGE-ON$ - /* + /* // some items are still free // TODO we could use the args to try to fix these val freeSub = frees.iterator @@ -1218,8 +1455,8 @@ object Infer { val fn1 = Expr.Annotation(fn, tpe1, fn.tag) val inner = Expr.App(fn1, args, tag) checkSigma(inner, tpe) - */ - } + */ + } } case None => tpe match { @@ -1232,50 +1469,69 @@ object Infer { } } - def applyRhoExpect[A: HasRegion](fn: Expr[A], args: NonEmptyList[Expr[A]], tag: A, expect: Expected[(Type.Rho, Region)]): Infer[TypedExpr.Rho[A]] = + def applyRhoExpect[A: HasRegion]( + fn: Expr[A], + args: NonEmptyList[Expr[A]], + tag: A, + expect: Expected[(Type.Rho, Region)] + ): Infer[TypedExpr.Rho[A]] = for { (typedFn, fnTRho) <- inferRho(fn) argsRegion = args.reduceMap(region[Expr[A]](_)) (argT, resT) <- unifyFnRho(args.length, fnTRho, region(fn), argsRegion) - typedArg <- args.zip(argT).parTraverse { case (arg, argT) => checkSigma(arg, argT) } + typedArg <- args.zip(argT).parTraverse { case (arg, argT) => + checkSigma(arg, argT) + } coerce <- instSigma(resT, expect, region(tag)) res <- zonkTypedExpr(TypedExpr.App(typedFn, typedArg, resT, tag)) } yield coerce(res) - def checkAnnotated[A: HasRegion](inner: Expr[A], tpe: Type, tpeRegion: Region, expect: Expected[(Type.Rho, Region)]): Infer[TypedExpr.Rho[A]] = + def checkAnnotated[A: HasRegion]( + inner: Expr[A], + tpe: Type, + tpeRegion: Region, + expect: Expected[(Type.Rho, Region)] + ): Infer[TypedExpr.Rho[A]] = (checkSigma(inner, tpe), instSigma(tpe, expect, tpeRegion)) .parFlatMapN { (typedTerm, coerce) => zonkTypedExpr(typedTerm).map(coerce(_)) } - /** - * Invariant: if the second argument is (Check rho) then rho is in weak prenex form - */ - def typeCheckRho[A: HasRegion](term: Expr[A], expect: Expected[(Type.Rho, Region)]): Infer[TypedExpr.Rho[A]] = { + + /** Invariant: if the second argument is (Check rho) then rho is in weak + * prenex form + */ + def typeCheckRho[A: HasRegion]( + term: Expr[A], + expect: Expected[(Type.Rho, Region)] + ): Infer[TypedExpr.Rho[A]] = { import Expr._ term match { case Literal(lit, t) => val tpe = Type.getTypeOf(lit) - instSigma(tpe, expect, region(term)).map(_(TypedExpr.Literal(lit, tpe, t))) + instSigma(tpe, expect, region(term)).map( + _(TypedExpr.Literal(lit, tpe, t)) + ) case Local(name, tag) => for { vSigma <- lookupVarType((None, name), region(term)) coerce <- instSigma(vSigma, expect, region(term)) res0 = TypedExpr.Local(name, vSigma, tag) res <- zonkTypedExpr(res0) - } yield coerce(res) + } yield coerce(res) case Global(pack, name, tag) => for { vSigma <- lookupVarType((Some(pack), name), region(term)) coerce <- instSigma(vSigma, expect, region(term)) res <- zonkTypedExpr(TypedExpr.Global(pack, name, vSigma, tag)) - } yield coerce(res) + } yield coerce(res) case Annotation(App(fn, args, tag), resT, annTag) => - (checkApply(fn, args, tag, resT, region(annTag)), + ( + checkApply(fn, args, tag, resT, region(annTag)), instSigma(resT, expect, region(annTag)) ) - .parFlatMapN { (typedTerm, coerce) => - zonkTypedExpr(typedTerm).map(coerce(_)) - } + .parFlatMapN { (typedTerm, coerce) => + zonkTypedExpr(typedTerm).map(coerce(_)) + } case App(fn, args, tag) => expect match { case Expected.Check((rho, reg)) => @@ -1284,38 +1540,47 @@ object Infer { applyRhoExpect(fn, args, tag, inf) } case Generic(tpes, in) => - for { - unSkol <- inferForAll(tpes, in) - // unSkol is not a Rho type, we need instantiate it - coerce <- instSigma(unSkol.getType, expect, region(term)) - } yield coerce(unSkol) + for { + unSkol <- inferForAll(tpes, in) + // unSkol is not a Rho type, we need instantiate it + coerce <- instSigma(unSkol.getType, expect, region(term)) + } yield coerce(unSkol) case Lambda(args, result, tag) => expect match { case Expected.Check((expTy, rr)) => for { // we know expTy is in weak-prenex form, and since Fn is covariant, bodyT must be // in weak prenex form - (varsT, bodyT) <- unifyFnRho(args.length, expTy, rr, region(term)) - bodyTRho <- assertRho(bodyT, s"expected ${show(expTy)} at $rr to be in weak-prenex form.", region(result)) + (varsT, bodyT) <- unifyFnRho( + args.length, + expTy, + rr, + region(term) + ) + bodyTRho <- assertRho( + bodyT, + s"expected ${show(expTy)} at $rr to be in weak-prenex form.", + region(result) + ) // the length of args and varsT must be the same because of unifyFnRho zipped = args.zip(varsT) namesVarsT = zipped.map { case ((n, _), t) => (n, t) } typedBody <- extendEnvList(namesVarsT.toList) { - // TODO we are ignoring the result of subsCheck here - // should we be coercing a var? - // - // this comes from page 54 of the paper, but I can't seem to find examples - // where this will fail if we reverse (as we had for a long time), which - // indicates the testing coverage is incomplete - zipped.parTraverse_ { - case ((_, Some(tpe)), varT) => - subsCheck(varT, tpe, region(term), rr) - case ((_, None), _) => unit - } &> + // TODO we are ignoring the result of subsCheck here + // should we be coercing a var? + // + // this comes from page 54 of the paper, but I can't seem to find examples + // where this will fail if we reverse (as we had for a long time), which + // indicates the testing coverage is incomplete + zipped.parTraverse_ { + case ((_, Some(tpe)), varT) => + subsCheck(varT, tpe, region(term), rr) + case ((_, None), _) => unit + } &> checkRho(result, bodyTRho) - } + } } yield TypedExpr.AnnotatedLambda(namesVarsT, typedBody, tag) - case infer@Expected.Inf(_) => + case infer @ Expected.Inf(_) => for { nameVarsT <- args.parTraverse { case (n, Some(tpe)) => @@ -1325,10 +1590,14 @@ object Infer { // all functions args of kind type newMeta.map((n, _)) } - (typedBody, bodyT) <- extendEnvNonEmptyList(nameVarsT)(inferRho(result)) - _ <- infer.set((Type.Fun(nameVarsT.map(_._2), bodyT), region(term))) + (typedBody, bodyT) <- extendEnvNonEmptyList(nameVarsT)( + inferRho(result) + ) + _ <- infer.set( + (Type.Fun(nameVarsT.map(_._2), bodyT), region(term)) + ) } yield TypedExpr.AnnotatedLambda(nameVarsT, typedBody, tag) - } + } case Let(name, rhs, body, isRecursive, tag) => if (isRecursive.isRecursive) { // all defs are marked at potentially recursive. @@ -1342,9 +1611,9 @@ object Infer { // cases differently val rhsBody = rhs match { case Expr.Annotated(tpe) => - extendEnv(name, tpe) { - checkSigma(rhs, tpe).parProduct(typeCheckRho(body, expect)) - } + extendEnv(name, tpe) { + checkSigma(rhs, tpe).parProduct(typeCheckRho(body, expect)) + } case notAnnotated => newMeta // the kind of a let value is a Type .flatMap { rhsTpe => @@ -1352,15 +1621,20 @@ object Infer { for { // the type variable needs to be unified with varT // note, varT could be a sigma type, it is not a Tau or Rho - typedRhs <- inferSigmaMeta(notAnnotated, Some((name, rhsTpe, region(notAnnotated)))) + typedRhs <- inferSigmaMeta( + notAnnotated, + Some((name, rhsTpe, region(notAnnotated))) + ) varT = typedRhs.getType // we need to overwrite the metavariable now with the full type - typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect)) + typedBody <- extendEnv(name, varT)( + typeCheckRho(body, expect) + ) } yield (typedRhs, typedBody) } } - } - + } + rhsBody.map { case (rhs, body) => // TODO: a more efficient algorithm would do this top down // for each top level TypedExpr and build it bottom up. @@ -1369,43 +1643,53 @@ object Infer { val isRecursive = RecursionKind.recursive(frees.contains(name)) TypedExpr.Let(name, rhs, body, isRecursive, tag) } - } - else { + } else { // In this branch, we typecheck the rhs *without* name in the environment // so any recursion in this case won't typecheck, and shadowing rules are // in place val rhsBody = rhs match { - case Expr.Annotated(tpe) => - // check in parallel so we collect more errors - checkSigma(rhs, tpe) - .parProduct( - extendEnv(name, tpe) { typeCheckRho(body, expect) } - ) + case Expr.Annotated(tpe) => + // check in parallel so we collect more errors + checkSigma(rhs, tpe) + .parProduct( + extendEnv(name, tpe) { typeCheckRho(body, expect) } + ) case _ => // we don't know the type of rhs, so we have to infer then check the body for { typedRhs <- inferSigma(rhs) - typedBody <- extendEnv(name, typedRhs.getType)(typeCheckRho(body, expect)) + typedBody <- extendEnv(name, typedRhs.getType)( + typeCheckRho(body, expect) + ) } yield (typedRhs, typedBody) - } + } rhsBody.map { case (rhs, body) => // Note: in this branch, we know isRecursive.isRecursive == false - TypedExpr.Let(name, rhs, body, recursive = RecursionKind.NonRecursive, tag) + TypedExpr.Let( + name, + rhs, + body, + recursive = RecursionKind.NonRecursive, + tag + ) } } case Annotation(term, tpe, tag) => val inner = term match { - case Match(arg, branches, mtag) => + case Match(arg, branches, mtag) => // We push the Annotation down to help with // existential type checking where each branch // has a different type - Match(arg, branches.map { case (p, r) => - // we have to put the tag to be r.tag - // because that's where the regions come from - (p, Annotation(r, tpe, r.tag)) - }, - mtag) + Match( + arg, + branches.map { case (p, r) => + // we have to put the tag to be r.tag + // because that's where the regions come from + (p, Annotation(r, tpe, r.tag)) + }, + mtag + ) case notMatch => notMatch } @@ -1428,13 +1712,24 @@ object Infer { // see if substitute rho with subs <:< expected // else set inferred value val validKinds: Infer[Unit] = - validateSubs(subs.toList, region(term), region(tag)) + validateSubs( + subs.toList, + region(term), + region(tag) + ) validKinds.parProductR(expect match { case Expected.Check((r1, reg1)) => for { - co <- subsCheckRho2(rho, r1, region(term), reg1) - z <- zonkTypedExpr(TypedExpr.Annotation(te, rho)) + co <- subsCheckRho2( + rho, + r1, + region(term), + reg1 + ) + z <- zonkTypedExpr( + TypedExpr.Annotation(te, rho) + ) } yield co(z) case inf @ Expected.Inf(_) => for { @@ -1444,12 +1739,12 @@ object Infer { }) case _ => default - } + } case _ => default } } - + case None => default } case _ => @@ -1487,19 +1782,25 @@ object Infer { // note, resT is in weak-prenex form, so this call is permitted checkBranch(p, check, r, resT) } - } - else { + } else { for { tbranches <- branches.parTraverse { case (p, r) => // note, resT is in weak-prenex form, so this call is permitted checkBranch(p, check, r, resT) - .product(solvedExistentitals(unknownExs).map((_, region(r)))) + .product( + solvedExistentitals(unknownExs).map( + (_, region(r)) + ) + ) } - _ <- unifyBranchExistentials(unknownExs, tbranches.map(_._2)) + _ <- unifyBranchExistentials( + unknownExs, + tbranches.map(_._2) + ) } yield tbranches.map(_._1) } } yield TypedExpr.Match(tsigma, tbranches, tag) - case infer@Expected.Inf(_) => + case infer @ Expected.Inf(_) => for { tbranches <- branches.parTraverse { case (p, r) => inferBranch(p, check, r) @@ -1512,9 +1813,13 @@ object Infer { } } - def widenBranches[A: HasRegion](branches: NonEmptyList[(Pattern, (TypedExpr.Rho[A], Type.Rho))]): Infer[(Type.Rho, Region, NonEmptyList[(Pattern, TypedExpr.Rho[A])])] = { + def widenBranches[A: HasRegion]( + branches: NonEmptyList[(Pattern, (TypedExpr.Rho[A], Type.Rho))] + ): Infer[(Type.Rho, Region, NonEmptyList[(Pattern, TypedExpr.Rho[A])])] = { - def maxBy[M[_]: Monad, B](head: B, tail: List[B])(gteq: (B, B) => M[Boolean]): M[B] = + def maxBy[M[_]: Monad, B](head: B, tail: List[B])( + gteq: (B, B) => M[Boolean] + ): M[B] = tail match { case Nil => Monad[M].pure(head) case h :: tail => @@ -1523,9 +1828,12 @@ object Infer { val next = if (keep) head else h maxBy(next, tail)(gteq) } - } + } - def gtEq[K](left: (TypedExpr[A], K), right: (TypedExpr[A], K)): Infer[Boolean] = { + def gtEq[K]( + left: (TypedExpr[A], K), + right: (TypedExpr[A], K) + ): Infer[Boolean] = { val leftTE = left._1 val rightTE = right._1 val lt = leftTE.getType @@ -1533,14 +1841,12 @@ object Infer { val rt = rightTE.getType val rr = region(rightTE) // left >= right if right subsumes left - subsCheck(rt, lt, rr, lr) - .peek + subsCheck(rt, lt, rr, lr).peek .flatMap { case Right(_) => pure(true) - case Left(_) => + case Left(_) => // maybe the other way around - subsCheck(lt, rt, lr, rr) - .peek + subsCheck(lt, rt, lr, rr).peek .flatMap { case Right(_) => // okay, we see right > left @@ -1552,37 +1858,50 @@ object Infer { } } - val withIdx = branches.zipWithIndex.map { case ((p, (te, tpe)), idx) => (te, (p, tpe, idx)) } + val withIdx = branches.zipWithIndex.map { case ((p, (te, tpe)), idx) => + (te, (p, tpe, idx)) + } for { - (maxRes, (maxPat, resTRho, maxIdx)) <- maxBy(withIdx.head, withIdx.tail)((a, b) => gtEq(a, b)) + (maxRes, (maxPat, resTRho, maxIdx)) <- maxBy( + withIdx.head, + withIdx.tail + )((a, b) => gtEq(a, b)) resRegion = region(maxRes) resBranches <- withIdx.parTraverse { case (te, (p, tpe, idx)) => if (idx != maxIdx) { // unfortunately we have to check each branch again to get the correct coerce subsCheckRho2(tpe, resTRho, region(te), resRegion) .map { coerce => - (p, coerce(te)) + (p, coerce(te)) } - } - else pure((p, te)) + } else pure((p, te)) } } yield (resTRho, resRegion, resBranches) } - + /* * we require resT in weak prenex form because we call checkRho with it * TODO: if sigma is an existential, then maybe we should reset any * existentials after and leave them opaque? Or maybe if we could * avoid allocating the metas until we get into the branch */ - def checkBranch[A: HasRegion](p: Pattern, sigma: Expected.Check[(Type, Region)], res: Expr[A], resT: Type.Rho): Infer[(Pattern, TypedExpr.Rho[A])] = + def checkBranch[A: HasRegion]( + p: Pattern, + sigma: Expected.Check[(Type, Region)], + res: Expr[A], + resT: Type.Rho + ): Infer[(Pattern, TypedExpr.Rho[A])] = for { (pattern, bindings) <- typeCheckPattern(p, sigma, region(res)) tres <- extendEnvList(bindings)(checkRho(res, resT)) } yield (pattern, tres) - def inferBranch[A: HasRegion](p: Pattern, sigma: Expected.Check[(Type, Region)], res: Expr[A]): Infer[(Pattern, (TypedExpr.Rho[A], Type.Rho))] = + def inferBranch[A: HasRegion]( + p: Pattern, + sigma: Expected.Check[(Type, Region)], + res: Expr[A] + ): Infer[(Pattern, (TypedExpr.Rho[A], Type.Rho))] = for { patBind <- typeCheckPattern(p, sigma, region(res)) (pattern, bindings) = patBind @@ -1590,13 +1909,16 @@ object Infer { res <- extendEnvList(bindings)(inferRho(res)) } yield (pattern, res) - /** - * patterns can be a sigma type, not neccesarily a rho/tau - * return a list of bound names and their (sigma) types - * - * TODO: Pattern needs to have a region for each part - */ - def typeCheckPattern(pat: Pattern, sigma: Expected.Check[(Type, Region)], reg: Region): Infer[(Pattern, List[(Bindable, Type)])] = { + /** patterns can be a sigma type, not neccesarily a rho/tau return a list of + * bound names and their (sigma) types + * + * TODO: Pattern needs to have a region for each part + */ + def typeCheckPattern( + pat: Pattern, + sigma: Expected.Check[(Type, Region)], + reg: Region + ): Infer[(Pattern, List[(Bindable, Type)])] = { pat match { case GenPattern.WildCard => Infer.pure((pat, Nil)) case GenPattern.Literal(lit) => @@ -1616,7 +1938,8 @@ object Infer { def inner(pat: Pattern) = sigma match { case Expected.Check((t, _)) => - val res = (GenPattern.Annotation(GenPattern.Named(n, pat), t), t) + val res = + (GenPattern.Annotation(GenPattern.Named(n, pat), t), t) Infer.pure(res) } // We always return an annotation here, which is the only @@ -1633,7 +1956,7 @@ object Infer { case Expected.Check((t, tr)) => subsCheck(tpe, t, reg, tr) } val names = items.collect { - case GenPattern.StrPart.NamedStr(n) => (n, tpe) + case GenPattern.StrPart.NamedStr(n) => (n, tpe) case GenPattern.StrPart.NamedChar(n) => (n, Type.CharType) } // we need to apply the type so the names are well typed @@ -1647,25 +1970,34 @@ object Infer { * of them have type A. */ def checkItem( - inner: Type, - lst: Type, - e: ListPart[Pattern]): Infer[(ListPart[Pattern], List[(Bindable, Type)])] = - e match { - case l@ListPart.WildList => - // this is *a pattern that has list type, and binds that type to the name - Infer.pure((l, Nil)) - case l@ListPart.NamedList(splice) => - // this is *a pattern that has list type, and binds that type to the name - Infer.pure((l, (splice, lst) :: Nil)) - case ListPart.Item(p) => - // This is a non-splice - checkPat(p, inner, reg).map { case (p, l) => (ListPart.Item(p), l) } - } + inner: Type, + lst: Type, + e: ListPart[Pattern] + ): Infer[(ListPart[Pattern], List[(Bindable, Type)])] = + e match { + case l @ ListPart.WildList => + // this is *a pattern that has list type, and binds that type to the name + Infer.pure((l, Nil)) + case l @ ListPart.NamedList(splice) => + // this is *a pattern that has list type, and binds that type to the name + Infer.pure((l, (splice, lst) :: Nil)) + case ListPart.Item(p) => + // This is a non-splice + checkPat(p, inner, reg).map { case (p, l) => + (ListPart.Item(p), l) + } + } val tpeOfList: Infer[Type] = sigma.value match { case (Type.TyApply(Type.ListType, item), _) => pure(item) - case (Type.ForAll(b@NonEmptyList(_, Nil), Type.TyApply(Type.ListType, item)), _) => + case ( + Type.ForAll( + b @ NonEmptyList(_, Nil), + Type.TyApply(Type.ListType, item) + ), + _ + ) => // list is covariant so we can push down pure(Type.forAll(b, item)) case (_, reg) => @@ -1682,7 +2014,10 @@ object Infer { inners <- items.parTraverse(checkItem(tpeA, listA, _)) innerPat = inners.map(_._1) innerBinds = inners.flatMap(_._2) - } yield (GenPattern.Annotation(GenPattern.ListPat(innerPat), listA), innerBinds) + } yield ( + GenPattern.Annotation(GenPattern.ListPat(innerPat), listA), + innerBinds + ) case GenPattern.Annotation(p, tpe) => // like in the case of an annotation, we check the type, then @@ -1707,19 +2042,24 @@ object Infer { pats = envs.map(_._1) bindings = envs.map(_._2) } yield (GenPattern.PositionalStruct(nm, pats), bindings.flatten) - case u@GenPattern.Union(h, t) => - (typeCheckPattern(h, sigma, reg), t.parTraverse(typeCheckPattern(_, sigma, reg))) - .parMapN { case ((h, binds), neList) => - val pat = GenPattern.Union(h, neList.map(_._1)) - val allBinds = NonEmptyList(binds, (neList.map(_._2).toList)) - identicalBinds(u, allBinds, reg).as((pat, binds)) - } - .flatten + case u @ GenPattern.Union(h, t) => + ( + typeCheckPattern(h, sigma, reg), + t.parTraverse(typeCheckPattern(_, sigma, reg)) + ).parMapN { case ((h, binds), neList) => + val pat = GenPattern.Union(h, neList.map(_._1)) + val allBinds = NonEmptyList(binds, (neList.map(_._2).toList)) + identicalBinds(u, allBinds, reg).as((pat, binds)) + }.flatten } } // Unions have to have identical bindings in all branches - def identicalBinds(u: Pattern, binds: NonEmptyList[List[(Bindable, Type)]], reg: Region): Infer[Unit] = + def identicalBinds( + u: Pattern, + binds: NonEmptyList[List[(Bindable, Type)]], + reg: Region + ): Infer[Unit] = binds.map(_.map(_._1)) match { case nel @ NonEmptyList(h, t) => val bs = h.toSet @@ -1735,21 +2075,28 @@ object Infer { unifyType(tpe, tpe2, reg, reg) } } - } - else fail(Error.UnionPatternBindMismatch(u, nel, reg)) + } else fail(Error.UnionPatternBindMismatch(u, nel, reg)) } // TODO: we should be able to derive a region for any pattern - def checkPat(pat: Pattern, sigma: Type, reg: Region): Infer[(Pattern, List[(Bindable, Type)])] = + def checkPat( + pat: Pattern, + sigma: Type, + reg: Region + ): Infer[(Pattern, List[(Bindable, Type)])] = typeCheckPattern(pat, Expected.Check((sigma, reg)), reg) - /** - * To do this, Infer will need to know the names of the type - * constructors in scope. - * - * Instantiation fills in all - */ - def instDataCon(consName: (PackageName, Constructor), sigma: Type, reg: Region, sigmaRegion: Region): Infer[List[Type]] = + /** To do this, Infer will need to know the names of the type constructors + * in scope. + * + * Instantiation fills in all + */ + def instDataCon( + consName: (PackageName, Constructor), + sigma: Type, + reg: Region, + sigmaRegion: Region + ): Infer[List[Type]] = GetDataCons(consName, reg).flatMap { case (args, consParams, tpeName) => val thisTpe = Type.TyConst(tpeName) @@ -1758,7 +2105,11 @@ object Infer { // for a pattern match, where we have already type checked the scrutinee // and the type constructor is well-kinded by the checks done at kind // inference time. - def loop(revArgs: List[(Type.Var.Bound, Kind.Arg)], leftKind: Kind, sigma: Type): Infer[Map[Type.Var, Type]] = + def loop( + revArgs: List[(Type.Var.Bound, Kind.Arg)], + leftKind: Kind, + sigma: Type + ): Infer[Map[Type.Var, Type]] = (revArgs, sigma) match { case (Nil, tpe) => for { @@ -1771,16 +2122,21 @@ object Infer { case (_, fa: Type.Quantified) => // we have to instantiate a rho type instantiate(fa) - .flatMap { case (_, faRho) => - // TODO: it seems like we shouldn't ignore the existential skolems - loop(revArgs, leftKind, faRho) - } + .flatMap { case (_, faRho) => + // TODO: it seems like we shouldn't ignore the existential skolems + loop(revArgs, leftKind, faRho) + } case ((v0, k) :: rest, _) => // (k -> leftKind)(k) for { left <- newMetaType(Kind.Cons(k, leftKind)) right <- newMetaType(k.kind) - _ <- unifyType(Type.TyApply(left, right), sigma, reg, sigmaRegion) + _ <- unifyType( + Type.TyApply(left, right), + sigma, + reg, + sigmaRegion + ) nextKind = Kind.Cons(k, leftKind) rest <- loop(rest, nextKind, left) } yield rest.updated(v0, right) @@ -1789,47 +2145,62 @@ object Infer { // so we push the forall down to avoid allocating a metaVar which can only // hold a monotype def pushDownCovariant( - revArgs: List[(Type.Var.Bound, Kind.Arg)], - revForAlls: List[(Type.Var.Bound, Kind)], - sigma: Type): Type = { - (revArgs, sigma) match { + revArgs: List[(Type.Var.Bound, Kind.Arg)], + revForAlls: List[(Type.Var.Bound, Kind)], + sigma: Type + ): Type = { + (revArgs, sigma) match { case (_, Type.ForAll(params, over)) => - pushDownCovariant(revArgs, params.toList reverse_::: revForAlls, over) - case ((_, Kind.Arg(Variance.Covariant, _)) :: rest, Type.TyApply(left, right)) => - // TODO Phantom variance has some special rules too. I guess we - // can push into phantom as well (though that's rare) - val leftFree = Type.freeBoundTyVars(left :: Nil).toSet - val rightFree = Type.freeBoundTyVars(right :: Nil).toSet - - val (nextRFA, nextRight) = - revForAlls.filter { case (leftA, _) => rightFree(leftA) && !leftFree(leftA) } match { - case Nil => (revForAlls, right) - case pushed => - // it is safe to push it down - val pushedSet = pushed.iterator.map(_._1).toSet - val revFA1 = revForAlls.toList.filterNot { case (b, _) => pushedSet(b) } - val pushedRight = Type.forAll(pushed.reverse, right) - (revFA1, pushedRight) - } - pushDownCovariant(rest, nextRFA, left) match { - case Type.ForAll(bs, l) => - // TODO: I think we can push down existentials too - Type.forAll(bs, Type.apply1(l, nextRight)) - case rho /*: Type.Rho */ => - Type.apply1(rho, nextRight) + pushDownCovariant( + revArgs, + params.toList reverse_::: revForAlls, + over + ) + case ( + (_, Kind.Arg(Variance.Covariant, _)) :: rest, + Type.TyApply(left, right) + ) => + // TODO Phantom variance has some special rules too. I guess we + // can push into phantom as well (though that's rare) + val leftFree = Type.freeBoundTyVars(left :: Nil).toSet + val rightFree = Type.freeBoundTyVars(right :: Nil).toSet + + val (nextRFA, nextRight) = + revForAlls.filter { case (leftA, _) => + rightFree(leftA) && !leftFree(leftA) + } match { + case Nil => (revForAlls, right) + case pushed => + // it is safe to push it down + val pushedSet = pushed.iterator.map(_._1).toSet + val revFA1 = revForAlls.toList.filterNot { case (b, _) => + pushedSet(b) + } + val pushedRight = Type.forAll(pushed.reverse, right) + (revFA1, pushedRight) } + pushDownCovariant(rest, nextRFA, left) match { + case Type.ForAll(bs, l) => + // TODO: I think we can push down existentials too + Type.forAll(bs, Type.apply1(l, nextRight)) + case rho /*: Type.Rho */ => + Type.apply1(rho, nextRight) + } case (_ :: rest, Type.TyApply(left, right)) => - val rightFree = Type.freeBoundTyVars(right :: Nil).toSet - val (keptRight, lefts) = - revForAlls.partition { case (leftA, _) => rightFree(leftA) } - - Type.forAll(keptRight.reverse, pushDownCovariant(rest, lefts, left)) match { - case Type.ForAll(bs, l) => - // TODO: we could possibly have an existential here? - Type.forAll(bs, Type.apply1(l, right)) - case rho /*: Type.Rho */=> - Type.apply1(rho, right) - } + val rightFree = Type.freeBoundTyVars(right :: Nil).toSet + val (keptRight, lefts) = + revForAlls.partition { case (leftA, _) => rightFree(leftA) } + + Type.forAll( + keptRight.reverse, + pushDownCovariant(rest, lefts, left) + ) match { + case Type.ForAll(bs, l) => + // TODO: we could possibly have an existential here? + Type.forAll(bs, Type.apply1(l, right)) + case rho /*: Type.Rho */ => + Type.apply1(rho, right) + } case _ => Type.forAll(revForAlls.reverse, sigma) } @@ -1846,7 +2217,10 @@ object Infer { inferSigmaMeta(e, None) // invariant: if meta.isDefined then e is not Expr.Annotated - def inferSigmaMeta[A: HasRegion](e: Expr[A], meta: Option[(Identifier, Type.TyMeta, Region)]): Infer[TypedExpr[A]] = { + def inferSigmaMeta[A: HasRegion]( + e: Expr[A], + meta: Option[(Identifier, Type.TyMeta, Region)] + ): Infer[TypedExpr[A]] = { def unifySelf(tpe: Type.Rho): Infer[Map[Name, Type]] = meta match { case None => getEnv @@ -1857,25 +2231,29 @@ object Infer { } } - /** - * if meta is Some, it is because it recursive, but those are almost - * always functions, so we can at least fix the arity of the function. - */ - val init: Infer[Unit] = - meta match { - case Some((_, tpe, rtpe)) => - def maybeUnified(e: Expr[A]): Infer[Unit] = - e match { - case Expr.Lambda(args, res, _) => - unifyFnRho(args.length, tpe, rtpe, region(e) - region(res)).void - case _ => - // we just have to wait to infer - unit - } + /** if meta is Some, it is because it recursive, but those are almost + * always functions, so we can at least fix the arity of the function. + */ + val init: Infer[Unit] = + meta match { + case Some((_, tpe, rtpe)) => + def maybeUnified(e: Expr[A]): Infer[Unit] = + e match { + case Expr.Lambda(args, res, _) => + unifyFnRho( + args.length, + tpe, + rtpe, + region(e) - region(res) + ).void + case _ => + // we just have to wait to infer + unit + } - maybeUnified(e) - case None => unit - } + maybeUnified(e) + case None => unit + } for { _ <- init @@ -1885,7 +2263,10 @@ object Infer { } yield q } - def quantify[A](env: Infer[Map[Name, Type]], rho: TypedExpr.Rho[A]): Infer[TypedExpr[A]] = + def quantify[A]( + env: Infer[Map[Name, Type]], + rho: TypedExpr.Rho[A] + ): Infer[TypedExpr[A]] = for { e <- env zrho <- zonkTypedExpr(rho) @@ -1894,8 +2275,10 @@ object Infer { // allocate this once and reuse private val envTypes = getEnv.map(_.values.toList) - - def quantifyMetas(metas: List[Type.TyMeta]): FunctionK[TypedExpr, Lambda[x => Infer[TypedExpr[x]]]] = + + def quantifyMetas( + metas: List[Type.TyMeta] + ): FunctionK[TypedExpr, Lambda[x => Infer[TypedExpr[x]]]] = NonEmptyList.fromList(metas) match { case None => new FunctionK[TypedExpr, Lambda[x => Infer[TypedExpr[x]]]] { @@ -1909,7 +2292,8 @@ object Infer { val aligned = Type.alignBinders(nel, used) val bound = aligned.toList.traverseFilter { case (m, n) => val meta = m.toMeta - if (meta.existential) writeMeta(m.toMeta, Type.TyVar(n)).as(Some((n, meta.kind))) + if (meta.existential) + writeMeta(m.toMeta, Type.TyVar(n)).as(Some((n, meta.kind))) else pure(None) } // we only need to zonk after doing a write: @@ -1917,9 +2301,13 @@ object Infer { // here have been realized to Type.Var now, and and meta pointing at them should // become visible (no longer hidden) val zFn = Type.zonk( - metas.iterator.map(_.toMeta).filter(_.existential).to(SortedSet), + metas.iterator + .map(_.toMeta) + .filter(_.existential) + .to(SortedSet), readMeta, - writeMeta) + writeMeta + ) (bound, TypedExpr.zonkMeta(fa)(zFn)) .mapN { (typeArgs, r) => TypedExpr.quantVars(forallList = Nil, existList = typeArgs, r) @@ -1931,17 +2319,21 @@ object Infer { def checkSigma[A: HasRegion](t: Expr[A], tpe: Type): Infer[TypedExpr[A]] = { val regionT = region(t) for { - check <- subsUpper[Lambda[x => (Expr[x], HasRegion[x])], Infer](tpe, regionT, envTypes) { (metas, rho) => + check <- subsUpper[Lambda[x => (Expr[x], HasRegion[x])], Infer]( + tpe, + regionT, + envTypes + ) { (metas, rho) => val cRho = checkRhoK(rho) if (tpe === rho) { // we don't need to zonk here pure(cRho) - } - else { + } else { // we need to zonk before we unskolemize because some of the metas could be skolems - pure(cRho - .andThenFlatMap[TypedExpr](zonkTypeExprK) - .andThenFlatMap[TypedExpr](quantifyMetas(metas)) + pure( + cRho + .andThenFlatMap[TypedExpr](zonkTypeExprK) + .andThenFlatMap[TypedExpr](quantifyMetas(metas)) ) } } { badTvs => @@ -1951,22 +2343,30 @@ object Infer { } yield te } - /** - * invariant: rho needs to be in weak-prenex form - */ - def checkRho[A: HasRegion](t: Expr[A], rho: Type.Rho): Infer[TypedExpr.Rho[A]] = + /** invariant: rho needs to be in weak-prenex form + */ + def checkRho[A: HasRegion]( + t: Expr[A], + rho: Type.Rho + ): Infer[TypedExpr.Rho[A]] = typeCheckRho(t, Expected.Check((rho, region(t)))) // same as checkRho but as a FunctionK - def checkRhoK(rho: Type.Rho): FunctionK[Lambda[x => (Expr[x], HasRegion[x])], Lambda[x => Infer[TypedExpr.Rho[x]]]] = - new FunctionK[Lambda[x => (Expr[x], HasRegion[x])], Lambda[x => Infer[TypedExpr.Rho[x]]]] { - def apply[A](fa: (Expr[A], HasRegion[A])): Infer[TypedExpr[A]] = + def checkRhoK(rho: Type.Rho): FunctionK[Lambda[ + x => (Expr[x], HasRegion[x]) + ], Lambda[x => Infer[TypedExpr.Rho[x]]]] = + new FunctionK[Lambda[x => (Expr[x], HasRegion[x])], Lambda[ + x => Infer[TypedExpr.Rho[x]] + ]] { + def apply[A](fa: (Expr[A], HasRegion[A])): Infer[TypedExpr[A]] = checkRho(fa._1, rho)(fa._2) } - /** - * recall a rho type never has a top level Forall - */ - def inferRho[A: HasRegion](t: Expr[A]): Infer[(TypedExpr.Rho[A], Type.Rho)] = + + /** recall a rho type never has a top level Forall + */ + def inferRho[A: HasRegion]( + t: Expr[A] + ): Infer[(TypedExpr.Rho[A], Type.Rho)] = for { ref <- initRef[A, (Type.Rho, Region)](t) expr <- typeCheckRho(t, Expected.Inf(ref)) @@ -1974,38 +2374,47 @@ object Infer { eitherTpe <- lift(ref.get <* ref.reset) tpe <- eitherTpe match { case Right(rho) => pure(rho._1) - case Left(err) => fail(err) + case Left(err) => fail(err) } } yield (expr, tpe) } - private def recursiveTypeCheck[A: HasRegion](name: Bindable, expr: Expr[A]): Infer[TypedExpr[A]] = + private def recursiveTypeCheck[A: HasRegion]( + name: Bindable, + expr: Expr[A] + ): Infer[TypedExpr[A]] = // values are of kind Type expr match { case Expr.Annotated(tpe) => extendEnv(name, tpe)(checkSigma(expr, tpe)) case notAnnotated => newMeta.flatMap { tpe => - extendEnv(name, tpe)(typeCheckMeta(notAnnotated, Some((name, tpe, region(notAnnotated))))) + extendEnv(name, tpe)( + typeCheckMeta(notAnnotated, Some((name, tpe, region(notAnnotated)))) + ) } } - def typeCheck[A: HasRegion](t: Expr[A]): Infer[TypedExpr[A]] = typeCheckMeta(t, None) - private def unskolemize(skols: NonEmptyList[Type.Var.Skolem]): TypedExpr.Coerce = + private def unskolemize( + skols: NonEmptyList[Type.Var.Skolem] + ): TypedExpr.Coerce = new FunctionK[TypedExpr, TypedExpr] { def apply[A](te: TypedExpr[A]) = { // now replace the skols with generics val used = te.allBound val aligned = Type.alignBinders(skols, used) - val te2 = substTyExpr(skols, aligned.map { case (_, b) => Type.TyVar(b) }, te) + val te2 = + substTyExpr(skols, aligned.map { case (_, b) => Type.TyVar(b) }, te) TypedExpr.forAll(aligned.map { case (s, b) => (b, s.kind) }, te2) } } - private def unskolemizeExists(skols: List[Type.Var.Skolem]): TypedExpr.Coerce = + private def unskolemizeExists( + skols: List[Type.Var.Skolem] + ): TypedExpr.Coerce = NonEmptyList.fromList(skols) match { case None => FunctionK.id[TypedExpr] case Some(skols) => @@ -2014,22 +2423,32 @@ object Infer { // now replace the skols with generics val used = te.allBound val aligned = Type.alignBinders(skols, used) - val te2 = substTyExpr(skols, aligned.map { case (_, b) => Type.TyVar(b) }, te) + val te2 = substTyExpr( + skols, + aligned.map { case (_, b) => Type.TyVar(b) }, + te + ) TypedExpr.quantVars( forallList = Nil, existList = aligned.toList.map { case (s, b) => (b, s.kind) }, - te2) + te2 + ) } - } + } } // Invariant: if optMeta.isDefined then t is not Expr.Annotated - private def typeCheckMeta[A: HasRegion](t: Expr[A], optMeta: Option[(Identifier, Type.TyMeta, Region)]): Infer[TypedExpr[A]] = { + private def typeCheckMeta[A: HasRegion]( + t: Expr[A], + optMeta: Option[(Identifier, Type.TyMeta, Region)] + ): Infer[TypedExpr[A]] = { def run(t: Expr[A]) = inferSigmaMeta(t, optMeta).flatMap(zonkTypedExpr _) val optSkols = t match { case Expr.Generic(vs, e) => - Some(Expr.skolemizeVars(vs, e)(newSkolemTyVar(_, _, existential = false))) + Some( + Expr.skolemizeVars(vs, e)(newSkolemTyVar(_, _, existential = false)) + ) case _ => None } @@ -2047,64 +2466,81 @@ object Infer { def extendEnv[A](varName: Bindable, tpe: Type)(of: Infer[A]): Infer[A] = extendEnvNonEmptyList(NonEmptyList.one((varName, tpe)))(of) - def extendEnvList[A](bindings: List[(Bindable, Type)])(of: Infer[A]): Infer[A] = + def extendEnvList[A]( + bindings: List[(Bindable, Type)] + )(of: Infer[A]): Infer[A] = NonEmptyList.fromList(bindings) match { case Some(nel) => extendEnvNonEmptyList(nel)(of) - case None => of + case None => of } - def extendEnvNonEmptyList[A](bindings: NonEmptyList[(Bindable, Type)])(of: Infer[A]): Infer[A] = + def extendEnvNonEmptyList[A](bindings: NonEmptyList[(Bindable, Type)])( + of: Infer[A] + ): Infer[A] = Infer.Impl.ExtendEnvs(bindings.map { case (n, t) => ((None, n), t) }, of) - private def extendEnvListPack[A](pack: PackageName, nameTpe: NonEmptyList[(Bindable, Type)])(of: Infer[A]): Infer[A] = - Infer.Impl.ExtendEnvs(nameTpe.map { case (name, tpe) => ((Some(pack), name), tpe) }, of) - - /** - * Packages are generally just lists of lets, this allows you to infer - * the scheme for each in the context of the list - */ - def typeCheckLets[A: HasRegion](pack: PackageName, ls: List[(Bindable, RecursionKind, Expr[A])]): Infer[List[(Bindable, RecursionKind, TypedExpr[A])]] = { + private def extendEnvListPack[A]( + pack: PackageName, + nameTpe: NonEmptyList[(Bindable, Type)] + )(of: Infer[A]): Infer[A] = + Infer.Impl.ExtendEnvs( + nameTpe.map { case (name, tpe) => ((Some(pack), name), tpe) }, + of + ) + + /** Packages are generally just lists of lets, this allows you to infer the + * scheme for each in the context of the list + */ + def typeCheckLets[A: HasRegion]( + pack: PackageName, + ls: List[(Bindable, RecursionKind, Expr[A])] + ): Infer[List[(Bindable, RecursionKind, TypedExpr[A])]] = { // Group together lets that don't include each other to get more type errors // if we can type G = NonEmptyChain[(Bindable, RecursionKind, Expr[A])] - def run(groups: List[G]): Infer[List[(Bindable, RecursionKind, TypedExpr[A])]] = + def run( + groups: List[G] + ): Infer[List[(Bindable, RecursionKind, TypedExpr[A])]] = groups match { case Nil => Infer.pure(Nil) case group :: tail => for { groupChain <- group.parTraverse { case (name, rec, expr) => - (if (rec.isRecursive) recursiveTypeCheck(name, expr) else typeCheck(expr)) + (if (rec.isRecursive) recursiveTypeCheck(name, expr) + else typeCheck(expr)) .map { te => (name, rec, te) } } glist = groupChain.toNonEmptyList - tailRes <- extendEnvListPack(pack, glist.map { case (b, _, te) => + tailRes <- extendEnvListPack( + pack, + glist.map { case (b, _, te) => (b, te.getType) - }) { - run(tail) - } + } + ) { + run(tail) + } } yield glist.head :: glist.tail ::: tailRes - } + } val groups: List[G] = - ListUtil.greedyGroup(ls)({ item => NonEmptyChain.one(item) }) { case (bs, item @ (_, _, expr)) => - val dependsOnGroup = - expr.globals.iterator.exists { - case Expr.Global(p, n1, _) => (p == pack) && bs.exists(_._1 == n1) - } - if (dependsOnGroup) None // we can't run in parallel - else Some(bs :+ item) + ListUtil.greedyGroup(ls)({ item => NonEmptyChain.one(item) }) { + case (bs, item @ (_, _, expr)) => + val dependsOnGroup = + expr.globals.iterator.exists { case Expr.Global(p, n1, _) => + (p == pack) && bs.exists(_._1 == n1) + } + if (dependsOnGroup) None // we can't run in parallel + else Some(bs :+ item) } run(groups) } - /** - * This is useful to testing purposes. - * - * Given types a and b, can we substitute - * a for for b - */ + /** This is useful to testing purposes. + * + * Given types a and b, can we substitute a for for b + */ def substitutionCheck(a: Type, b: Type, ra: Region, rb: Region): Infer[Unit] = subsCheck(a, b, ra, rb).void } diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/ParsedTypeEnv.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/ParsedTypeEnv.scala index fc3c98747..954f9d202 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/ParsedTypeEnv.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/ParsedTypeEnv.scala @@ -4,11 +4,18 @@ import org.bykn.bosatsu.{PackageName, Identifier} import Identifier.Bindable -case class ParsedTypeEnv[+A](allDefinedTypes: List[DefinedType[A]], externalDefs: List[(PackageName, Bindable, Type)]) { +case class ParsedTypeEnv[+A]( + allDefinedTypes: List[DefinedType[A]], + externalDefs: List[(PackageName, Bindable, Type)] +) { def addDefinedType[A1 >: A](dt: DefinedType[A1]): ParsedTypeEnv[A1] = copy(allDefinedTypes = dt :: allDefinedTypes) - def addExternalValue(pn: PackageName, name: Bindable, tpe: Type): ParsedTypeEnv[A] = + def addExternalValue( + pn: PackageName, + name: Bindable, + tpe: Type + ): ParsedTypeEnv[A] = copy(externalDefs = (pn, name, tpe) :: externalDefs) } diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Ref.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Ref.scala index fb22f9cbd..f071dd24b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Ref.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Ref.scala @@ -5,9 +5,8 @@ import cats.{StackSafeMonad, Eval} import scala.collection.mutable.{LongMap => MutableMap} import java.util.concurrent.atomic.AtomicLong -/** - * This gives a mutable reference in a monadic context - */ +/** This gives a mutable reference in a monadic context + */ sealed trait Ref[A] { def get: RefSpace[A] def set(a: A): RefSpace[Unit] @@ -41,7 +40,9 @@ object RefSpace { protected def runState(al: AtomicLong, state: State) = value } - private case class AllocRef[A](handle: Long, init: A) extends RefSpace[A] with Ref[A] { + private case class AllocRef[A](handle: Long, init: A) + extends RefSpace[A] + with Ref[A] { def get: RefSpace[A] = this def set(a: A): RefSpace[Unit] = SetRef(handle, a) val reset: RefSpace[Unit] = Reset(handle) @@ -57,26 +58,31 @@ object RefSpace { } } private case class SetRef(handle: Long, value: Any) extends RefSpace[Unit] { - protected def runState(al: AtomicLong, state: State): Eval[Unit] = - { state.put(handle, value); Eval.Unit } + protected def runState(al: AtomicLong, state: State): Eval[Unit] = { + state.put(handle, value); Eval.Unit + } } private case class Reset(handle: Long) extends RefSpace[Unit] { - protected def runState(al: AtomicLong, state: State): Eval[Unit] = - { state.remove(handle); Eval.Unit } + protected def runState(al: AtomicLong, state: State): Eval[Unit] = { + state.remove(handle); Eval.Unit + } } private case class Alloc[A](init: A) extends RefSpace[Ref[A]] { protected def runState(al: AtomicLong, state: State): Eval[Ref[A]] = Eval.now(AllocRef(al.getAndIncrement, init)) } - private case class Map[A, B](init: RefSpace[A], fn: A => B) extends RefSpace[B] { + private case class Map[A, B](init: RefSpace[A], fn: A => B) + extends RefSpace[B] { protected def runState(al: AtomicLong, state: State) = Eval.defer(init.runState(al, state)).map(fn) } - private case class FlatMap[A, B](init: RefSpace[A], fn: A => RefSpace[B]) extends RefSpace[B] { + private case class FlatMap[A, B](init: RefSpace[A], fn: A => RefSpace[B]) + extends RefSpace[B] { protected def runState(al: AtomicLong, state: State): Eval[B] = - Eval.defer(init.runState(al, state)) + Eval + .defer(init.runState(al, state)) .flatMap { a => fn(a).runState(al, state) } @@ -95,23 +101,24 @@ object RefSpace { def put(key: Long, value: Any): Unit = discard(map.put(key, value)) def get(key: Long): Option[Any] = map.get(key) - def remove(key: Long): Unit = + def remove(key: Long): Unit = discard(map.remove(key)) } - private[RefSpace] class Fork(under: State, over: MutableMap[Option[Any]]) extends State { + private[RefSpace] class Fork(under: State, over: MutableMap[Option[Any]]) + extends State { def put(key: Long, value: Any): Unit = discard(over.put(key, Some(value))) def get(key: Long): Option[Any] = over.get(key) match { case Some(s) => s - case None => under.get(key) + case None => under.get(key) } def remove(key: Long): Unit = discard(over.put(key, None)) def flush(): Unit = { over.foreach { case (k, Some(v)) => under.put(k, v) - case (k, None) => under.remove(k) + case (k, None) => under.remove(k) } } @@ -122,16 +129,20 @@ object RefSpace { def fork(state: State): Fork = new Fork(state, MutableMap.empty) } - private case class ResetOnLeft[A, B, C](init: RefSpace[A], fn: A => Either[B, C]) extends RefSpace[Either[B, C]] { + private case class ResetOnLeft[A, B, C]( + init: RefSpace[A], + fn: A => Either[B, C] + ) extends RefSpace[Either[B, C]] { protected def runState(al: AtomicLong, state: State): Eval[Either[B, C]] = { val forked = State.fork(state) - init.runState(al, forked) + init + .runState(al, forked) .map { a => fn(a) match { - case r@Right(_) => + case r @ Right(_) => forked.flush() r - case l@Left(_) => + case l @ Left(_) => // just let the forked state disappear l } @@ -159,7 +170,8 @@ object RefSpace { // a counter that starts at 0 val allocCounter: RefSpace[RefSpace[Long]] = - RefSpace.newRef(0L) + RefSpace + .newRef(0L) .map { ref => for { a <- ref.get diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala index 7bff119f8..c978c772b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala @@ -4,7 +4,15 @@ import cats.data.NonEmptyList import cats.parse.{Parser => P, Numbers} import cats.{Applicative, Monad, Order} import org.typelevel.paiges.{Doc, Document} -import org.bykn.bosatsu.{Kind, PackageName, Lit, TypeName, Identifier, Parser, TypeParser} +import org.bykn.bosatsu.{ + Kind, + PackageName, + Lit, + TypeName, + Identifier, + Parser, + TypeParser +} import org.bykn.bosatsu.graph.Memoize.memoizeDagHashedConcurrent import scala.collection.immutable.{SortedSet, SortedMap} @@ -17,13 +25,13 @@ sealed abstract class Type { } object Type { - /** - * A type with no top level quantification - */ + + /** A type with no top level quantification + */ sealed abstract class Rho extends Type { override def normalize: Rho } - + object Rho { implicit val orderRho: Order[Rho] = new Order[Rho] { @@ -35,16 +43,16 @@ object Type { case (TyVar(v0), TyVar(v1)) => Ordering[Var].compare(v0, v1) case (TyVar(_), TyConst(_)) => 1 - case (TyVar(_), _) => -1 + case (TyVar(_), _) => -1 case (TyMeta(m0), TyMeta(m1)) => Meta.orderingMeta.compare(m0, m1) case (TyMeta(_), TyApply(_, _)) => -1 - case (TyMeta(_), _) => 1 + case (TyMeta(_), _) => 1 case (TyApply(a0, b0), TyApply(a1, b1)) => val c = Type.typeOrder.compare(a0, a1) if (c == 0) Type.typeOrder.compare(b0, b1) else c case (TyApply(_, _), _) => 1 - } + } } implicit val orderingRho: Ordering[Rho] = orderRho.toOrdering @@ -64,41 +72,45 @@ object Type { def filter(fn: Var.Bound => Boolean): Option[Quantification] = Quantification.fromLists( forallList.filter { case (b, _) => fn(b) }, - existList.filter { case (b, _) => fn(b) }) + existList.filter { case (b, _) => fn(b) } + ) def existsQuant(fn: ((Var.Bound, Kind)) => Boolean): Boolean } object Quantification { - case class ForAll(vars: NonEmptyList[(Var.Bound, Kind)]) extends Quantification { + case class ForAll(vars: NonEmptyList[(Var.Bound, Kind)]) + extends Quantification { def existList: List[(Var.Bound, Kind)] = Nil def forallList: List[(Var.Bound, Kind)] = vars.toList def concat(that: Quantification): Quantification = that match { case ForAll(vars1) => ForAll(vars ::: vars1) case Exists(evars) => Dual(vars, evars) - case Dual(f, e) => Dual(vars ::: f, e) + case Dual(f, e) => Dual(vars ::: f, e) } def existsQuant(fn: ((Var.Bound, Kind)) => Boolean): Boolean = vars.exists(fn) } - case class Exists(vars: NonEmptyList[(Var.Bound, Kind)]) extends Quantification { + case class Exists(vars: NonEmptyList[(Var.Bound, Kind)]) + extends Quantification { def existList: List[(Var.Bound, Kind)] = vars.toList def forallList: List[(Var.Bound, Kind)] = Nil def concat(that: Quantification): Quantification = that match { case ForAll(vars1) => Dual(vars1, vars) case Exists(evars) => Exists(vars ::: evars) - case Dual(f, e) => Dual(f, vars ::: e) + case Dual(f, e) => Dual(f, vars ::: e) } def existsQuant(fn: ((Var.Bound, Kind)) => Boolean): Boolean = vars.exists(fn) } case class Dual( - foralls: NonEmptyList[(Var.Bound, Kind)], - exists: NonEmptyList[(Var.Bound, Kind)]) extends Quantification { + foralls: NonEmptyList[(Var.Bound, Kind)], + exists: NonEmptyList[(Var.Bound, Kind)] + ) extends Quantification { lazy val vars = foralls ::: exists def existList: List[(Var.Bound, Kind)] = exists.toList @@ -107,7 +119,7 @@ object Type { that match { case ForAll(vars1) => Dual(foralls ::: vars1, exists) case Exists(evars) => Dual(foralls, exists ::: evars) - case Dual(f, e) => Dual(foralls ::: f, exists ::: e) + case Dual(f, e) => Dual(foralls ::: f, exists ::: e) } def existsQuant(fn: ((Var.Bound, Kind)) => Boolean): Boolean = @@ -121,26 +133,30 @@ object Type { def compare(a: Quantification, b: Quantification): Int = (a, b) match { case (ForAll(v0), ForAll(v1)) => nelist.compare(v0, v1) - case (ForAll(_), _) => -1 - case (Exists(_), ForAll(_)) => 1 + case (ForAll(_), _) => -1 + case (Exists(_), ForAll(_)) => 1 case (Exists(v0), Exists(v1)) => nelist.compare(v0, v1) - case (Exists(_), _) => -1 + case (Exists(_), _) => -1 case (Dual(fa0, ex0), Dual(fa1, ex1)) => val c1 = nelist.compare(fa0, fa1) if (c1 != 0) c1 else nelist.compare(ex0, ex1) case (Dual(_, _), _) => 1 } - } + } - def fromLists(forallList: List[(Var.Bound, Kind)], existList: List[(Var.Bound, Kind)]): Option[Quantification] = + def fromLists( + forallList: List[(Var.Bound, Kind)], + existList: List[(Var.Bound, Kind)] + ): Option[Quantification] = forallList match { case Nil => NonEmptyList.fromList(existList).map(Exists(_)) - case head :: tail => + case head :: tail => Some(existList match { case Nil => ForAll(NonEmptyList(head, tail)) - case eh :: et => Dual(NonEmptyList(head, tail), NonEmptyList(eh, et)) + case eh :: et => + Dual(NonEmptyList(head, tail), NonEmptyList(eh, et)) }) } } @@ -177,11 +193,11 @@ object Type { // a Leaf is never equal to TyApply right match { case rightLeaf: Leaf => leftLeaf == rightLeaf - case _: TyApply => false - case q: Quantified => leftLeaf == normalize(q) + case _: TyApply => false + case q: Quantified => leftLeaf == normalize(q) } case _: TyApply if right.isInstanceOf[Leaf] => false - case _ => + case _ => // at least one is quantified or both are TyApply normalize(left) == normalize(right) } @@ -204,14 +220,14 @@ object Type { @annotation.tailrec def applyAllRho(rho: Rho, args: List[Type]): Rho = args match { - case Nil => rho + case Nil => rho case a :: as => applyAllRho(TyApply(rho, a), as) } def apply1(fn: Type, arg: Type): Type = fn match { case rho: Rho => TyApply(rho, arg) - case q => applyAll(q, arg :: Nil) + case q => applyAll(q, arg :: Nil) } def applyAll(fn: Type, args: List[Type]): Type = @@ -221,43 +237,43 @@ object Type { val freeBound = freeBoundTyVars(fn :: args) if (freeBound.isEmpty) { Quantified(q, applyAllRho(rho, args)) - } - else { + } else { val freeBoundSet: Set[Var.Bound] = freeBound.toSet val collisions = q.existsQuant { case (b, _) => freeBoundSet(b) } if (!collisions) { // we don't need to rename the vars Quantified(q, applyAllRho(rho, args)) - } - else { + } else { // we have to to rename the collisions so the free set // is unchanged val fa1 = alignBinders(q.forallList, freeBoundSet) val ex1 = alignBinders(q.existList, freeBoundSet ++ fa1.map(_._2)) - val subMap = (fa1.iterator ++ ex1.iterator).map { - case ((b0, _), b1) => (b0, TyVar(b1)) - } - .toMap[Var, Rho] + val subMap = (fa1.iterator ++ ex1.iterator) + .map { case ((b0, _), b1) => + (b0, TyVar(b1)) + } + .toMap[Var, Rho] val rho1 = substituteRhoVar(rho, subMap) - val q1 = Quantification.fromLists( - forallList = fa1.map { case ((_, k), b) => (b, k)}, - existList = ex1.map { case ((_, k), b) => (b, k)} - ) - .get // this Option must be defined because we started with a defined q + val q1 = Quantification + .fromLists( + forallList = fa1.map { case ((_, k), b) => (b, k) }, + existList = ex1.map { case ((_, k), b) => (b, k) } + ) + .get // this Option must be defined because we started with a defined q Quantified(q1, applyAllRho(rho1, args)) } } - } + } def unapplyAll(fn: Type): (Type, List[Type]) = { @annotation.tailrec def loop(fn: Type, acc: List[Type]): (Type, List[Type]) = fn match { case TyApply(fn, a) => loop(fn, a :: acc) - case notApply => (notApply, acc) + case notApply => (notApply, acc) } loop(fn, Nil) @@ -270,7 +286,8 @@ object Type { case q: Quantified => q.quant match { case Quantification.ForAll(vars) => Some((vars, q.in)) - case Quantification.Dual(foralls, existsNel) => Some((foralls, exists(existsNel, q.in))) + case Quantification.Dual(foralls, existsNel) => + Some((foralls, exists(existsNel, q.in))) case _ => None } } @@ -283,7 +300,8 @@ object Type { case q: Quantified => q.quant match { case Quantification.Exists(vars) => Some((vars, q.in)) - case Quantification.Dual(foralls, existsNel) => Some((existsNel, forAll(foralls, q.in))) + case Quantification.Dual(foralls, existsNel) => + Some((existsNel, forAll(foralls, q.in))) case _ => None } } @@ -291,26 +309,26 @@ object Type { def constantsOf(t: Type): List[Const] = t match { - case Quantified(_, t) => constantsOf(t) - case TyApply(on, arg) => constantsOf(on) ::: constantsOf(arg) - case TyConst(c) => c :: Nil + case Quantified(_, t) => constantsOf(t) + case TyApply(on, arg) => constantsOf(on) ::: constantsOf(arg) + case TyConst(c) => c :: Nil case TyVar(_) | TyMeta(_) => Nil } def hasNoVars(t: Type): Boolean = t match { - case TyConst(_) => true + case TyConst(_) => true case TyVar(_) | TyMeta(_) => false - case TyApply(on, arg) => hasNoVars(on) && hasNoVars(arg) - case q: Quantified => freeTyVars(q :: Nil).isEmpty + case TyApply(on, arg) => hasNoVars(on) && hasNoVars(arg) + case q: Quantified => freeTyVars(q :: Nil).isEmpty } def hasNoUnboundVars(t: Type): Boolean = { def loop(t: Type, bound: Set[Var.Bound]): Boolean = t match { case TyVar(b: Var.Bound) => bound(b) - case _: Leaf => true - case TyApply(on, arg) => loop(on, bound) && loop(arg, bound) + case _: Leaf => true + case TyApply(on, arg) => loop(on, bound) && loop(arg, bound) case q: Quantified => loop(q.in, bound ++ q.vars.iterator.map(_._1)) } @@ -320,11 +338,14 @@ object Type { final def forAll(vars: List[(Var.Bound, Kind)], in: Type): Type = NonEmptyList.fromList(vars) match { - case None => in + case None => in case Some(ne) => forAll(ne, in) } - final def forAll(vars: NonEmptyList[(Var.Bound, Kind)], in: Type): Type.Quantified = + final def forAll( + vars: NonEmptyList[(Var.Bound, Kind)], + in: Type + ): Type.Quantified = in match { case rho: Rho => Quantified(Quantification.ForAll(vars), rho) @@ -335,11 +356,17 @@ object Type { case Quantification.Exists(ne1) => Quantified(Quantification.Dual(foralls = vars, exists = ne1), q.in) case Quantification.Dual(fa0, e) => - Quantified(Quantification.Dual(foralls = vars ::: fa0, exists = e), q.in) + Quantified( + Quantification.Dual(foralls = vars ::: fa0, exists = e), + q.in + ) } } - final def exists(vars: NonEmptyList[(Var.Bound, Kind)], in: Type): Type.Quantified = + final def exists( + vars: NonEmptyList[(Var.Bound, Kind)], + in: Type + ): Type.Quantified = in match { case rho: Rho => Quantified(Quantification.Exists(vars), rho) @@ -350,23 +377,36 @@ object Type { case Quantification.ForAll(ne1) => Quantified(Quantification.Dual(foralls = ne1, exists = vars), q.in) case Quantification.Dual(fa0, e) => - Quantified(Quantification.Dual(foralls = fa0, exists = vars ::: e), q.in) + Quantified( + Quantification.Dual(foralls = fa0, exists = vars ::: e), + q.in + ) } } final def exists(vars: List[(Var.Bound, Kind)], in: Type): Type = vars match { case h :: t => exists(NonEmptyList(h, t), in) - case Nil => in + case Nil => in } - final def quantify(forallList: List[(Var.Bound, Kind)], existList: List[(Var.Bound, Kind)], in: Type): Type = - Quantification.fromLists(forallList = forallList, existList = existList) match { + final def quantify( + forallList: List[(Var.Bound, Kind)], + existList: List[(Var.Bound, Kind)], + in: Type + ): Type = + Quantification.fromLists( + forallList = forallList, + existList = existList + ) match { case Some(q) => quantify(q, in) - case None => in + case None => in } - final def quantify(quantification: Quantification, tpe: Type): Type.Quantified = + final def quantify( + quantification: Quantification, + tpe: Type + ): Type.Quantified = quantification match { case Quantification.ForAll(vars) => forAll(vars, tpe) case Quantification.Exists(vars) => exists(vars, tpe) @@ -376,29 +416,27 @@ object Type { def getTypeOf(lit: Lit): Type = lit match { case Lit.Integer(_) => Type.IntType - case Lit.Str(_) => Type.StrType - case Lit.Chr(_) => Type.CharType + case Lit.Str(_) => Type.StrType + case Lit.Chr(_) => Type.CharType } - /** - * types are var, meta, or const, or applied or forall on one of - * those. This returns the Type.TyConst found - * by recursing - */ + /** types are var, meta, or const, or applied or forall on one of those. This + * returns the Type.TyConst found by recursing + */ @annotation.tailrec final def rootConst(t: Type): Option[Type.TyConst] = t match { - case tyc@TyConst(_) => Some(tyc) + case tyc @ TyConst(_) => Some(tyc) case TyVar(_) | TyMeta(_) => None - case TyApply(left, _) => rootConst(left) - case q: Quantified => rootConst(q.in) + case TyApply(left, _) => rootConst(left) + case q: Quantified => rootConst(q.in) } def allConsts(ts: List[Type]): List[TyConst] = { @annotation.tailrec def loop(ts: List[Type], acc: List[TyConst]): List[TyConst] = ts match { - case (tyc@TyConst(_)) :: tail => + case (tyc @ TyConst(_)) :: tail => loop(tail, tyc :: acc) case (TyVar(_) | TyMeta(_)) :: tail => loop(tail, acc) @@ -418,10 +456,12 @@ object Type { rootConst(t) } - /** - * This form is often useful in Infer - */ - def substTy(keys: NonEmptyList[Var], vals: NonEmptyList[Type]): Type => Type = { + /** This form is often useful in Infer + */ + def substTy( + keys: NonEmptyList[Var], + vals: NonEmptyList[Type] + ): Type => Type = { val env = keys.iterator.zip(vals.iterator).toMap { t => substituteVar(t, env) } @@ -429,44 +469,46 @@ object Type { def substituteVar(t: Type, env: Map[Type.Var, Type]): Type = if (env.isEmpty) t - else (t match { - case TyApply(on, arg) => - apply1(substituteVar(on, env), substituteVar(arg, env)) - case v@TyVar(n) => - env.get(n) match { - case Some(rho) => rho - case None => v - } - case m@TyMeta(_) => m - case c@TyConst(_) => c - case q: Quantified => - val boundSet = q.vars.iterator.map(_._1).toSet[Type.Var] - val env1 = env.iterator.filter { case (v, _) => !boundSet(v) }.toMap - val subin = substituteVar(q.in, env1) - quantify(q.quant, subin) - }) + else + (t match { + case TyApply(on, arg) => + apply1(substituteVar(on, env), substituteVar(arg, env)) + case v @ TyVar(n) => + env.get(n) match { + case Some(rho) => rho + case None => v + } + case m @ TyMeta(_) => m + case c @ TyConst(_) => c + case q: Quantified => + val boundSet = q.vars.iterator.map(_._1).toSet[Type.Var] + val env1 = env.iterator.filter { case (v, _) => !boundSet(v) }.toMap + val subin = substituteVar(q.in, env1) + quantify(q.quant, subin) + }) def substituteRhoVar(t: Type.Rho, env: Map[Type.Var, Type.Rho]): Type.Rho = t match { case TyApply(on, arg) => TyApply(substituteRhoVar(on, env), substituteVar(arg, env)) - case v@TyVar(n) => + case v @ TyVar(n) => env.get(n) match { case Some(rho) => rho - case None => v + case None => v } - case m@TyMeta(_) => m - case c@TyConst(_) => c + case m @ TyMeta(_) => m + case c @ TyConst(_) => c } - /** - * Kind of the opposite of substitute: given a Map of vars, can - * we set those vars to some Type and get from to match to exactly - */ - def instantiate( - vars: Map[Var.Bound, Kind], - from: Type, - to: Type): Option[(SortedMap[Var.Bound, (Kind, Var.Bound)], SortedMap[Var.Bound, (Kind, Type)])] = { + /** Kind of the opposite of substitute: given a Map of vars, can we set those + * vars to some Type and get from to match to exactly + */ + def instantiate(vars: Map[Var.Bound, Kind], from: Type, to: Type): Option[ + ( + SortedMap[Var.Bound, (Kind, Var.Bound)], + SortedMap[Var.Bound, (Kind, Type)] + ) + ] = { sealed abstract class BoundState case object Unknown extends BoundState @@ -474,8 +516,9 @@ object Type { case class Free(rightName: Var.Bound) extends BoundState case class State( - fixed: Map[Var.Bound, (Kind, BoundState)], - rightFrees: Map[Var.Bound, Kind]) { + fixed: Map[Var.Bound, (Kind, BoundState)], + rightFrees: Map[Var.Bound, Kind] + ) { def get(b: Var.Bound): Option[(Kind, BoundState)] = fixed.get(b) @@ -502,11 +545,10 @@ object Type { case Some(toBKind) => if (Kind.leftSubsumesRight(kind, toBKind)) { Some(state.updated(b, (toBKind, Free(toB)))) - } - else None + } else None case None => None - // don't set to vars to non-free bound variables - // this shouldn't happen in real inference + // don't set to vars to non-free bound variables + // this shouldn't happen in real inference } case _ if hasNoUnboundVars(to) => Some(state.updated(b, (kind, Fixed(to)))) @@ -518,7 +560,7 @@ object Type { case Free(rightName) => to match { case TyVar(toB) if rightName == toB => Some(state) - case _ => None + case _ => None } } case None => @@ -529,19 +571,26 @@ object Type { to match { case TyApply(ta, tb) => loop(a, ta, state).flatMap { s1 => - loop(b, tb, s1) - } + loop(b, tb, s1) + } case ForAll(rightFrees, rightT) => // TODO handle shadowing - if (rightFrees.exists { case (b, _) => state.rightFrees.contains(b) }) { + if ( + rightFrees.exists { case (b, _) => + state.rightFrees.contains(b) + } + ) { None - } - else { - loop(from, + } else { + loop( + from, rightT, - state.copy(rightFrees = state.rightFrees ++ rightFrees.iterator)) + state.copy(rightFrees = + state.rightFrees ++ rightFrees.iterator + ) + ) .map { s1 => - s1.copy(rightFrees = state.rightFrees) + s1.copy(rightFrees = state.rightFrees) } } case _ => None @@ -549,7 +598,9 @@ object Type { case ForAll(shadows, from1) => val noShadow = state -- shadows.iterator.map(_._1) loop(from1, to, noShadow).map { s1 => - s1 ++ shadows.iterator.flatMap { case (v, _) => state.get(v).map(v -> _) } + s1 ++ shadows.iterator.flatMap { case (v, _) => + state.get(v).map(v -> _) + } } case _ => // We can't use sameAt to compare Var.Bound since we know the variances @@ -559,67 +610,74 @@ object Type { } val initState = State( - vars.iterator.map { case (v, a) => (v, (a, Unknown))}.toMap, + vars.iterator.map { case (v, a) => (v, (a, Unknown)) }.toMap, Map.empty ) loop(from, to, initState) .map { state => ( - state.fixed.iterator.collect { - case (t, (k, Free(t1))) => (t, (k, t1)) - case (t, (k, Unknown)) => (t, (k, t)) - } - .to(SortedMap), - state.fixed.iterator.collect { - case (t, (k, Fixed(f))) => (t, (k, f)) - } - .to(SortedMap)) + state.fixed.iterator + .collect { + case (t, (k, Free(t1))) => (t, (k, t1)) + case (t, (k, Unknown)) => (t, (k, t)) + } + .to(SortedMap), + state.fixed.iterator + .collect { case (t, (k, Fixed(f))) => + (t, (k, f)) + } + .to(SortedMap) + ) } } - /** - * Return the Bound and Skolem variables that - * are free in the given list of types - */ + /** Return the Bound and Skolem variables that are free in the given list of + * types + */ def freeTyVars(ts: List[Type]): List[Type.Var] = { // usually we can recurse in a loop, but sometimes not - def cheat(ts: List[Type], bound: Set[Type.Var.Bound], acc: List[Type.Var]): List[Type.Var] = + def cheat( + ts: List[Type], + bound: Set[Type.Var.Bound], + acc: List[Type.Var] + ): List[Type.Var] = go(ts, bound, acc) @annotation.tailrec - def go(ts: List[Type], bound: Set[Type.Var.Bound], acc: List[Type.Var]): List[Type.Var] = + def go( + ts: List[Type], + bound: Set[Type.Var.Bound], + acc: List[Type.Var] + ): List[Type.Var] = ts match { - case Nil => acc + case Nil => acc case Type.TyVar(tv) :: rest => // we only check here, we don't add val isBound = tv match { - case b@Type.Var.Bound(_) => bound(b) - case _: Type.Var.Skolem => false + case b @ Type.Var.Bound(_) => bound(b) + case _: Type.Var.Skolem => false } if (isBound) go(rest, bound, acc) else go(rest, bound, tv :: acc) case Type.TyApply(a, b) :: rest => go(a :: b :: rest, bound, acc) case (Type.TyMeta(_) | Type.TyConst(_)) :: rest => go(rest, bound, acc) case (q: Quantified) :: rest => - val acc1 = cheat(q.in :: Nil, bound ++ q.vars.toList.iterator.map(_._1), acc) + val acc1 = + cheat(q.in :: Nil, bound ++ q.vars.toList.iterator.map(_._1), acc) // note, q.vars ARE NOT bound in rest go(rest, bound, acc1) } - go(ts, Set.empty, Nil) - .reverse - .distinct + go(ts, Set.empty, Nil).reverse.distinct } - /** - * Return the Bound variables that - * are free in the given list of types - */ + /** Return the Bound variables that are free in the given list of types + */ def freeBoundTyVars(ts: List[Type]): List[Type.Var.Bound] = - freeTyVars(ts).collect { case b@Type.Var.Bound(_) => b } + freeTyVars(ts).collect { case b @ Type.Var.Bound(_) => b } @inline final def normalize(tpe: Type): Type = tpe.normalize @@ -646,7 +704,7 @@ object Type { val foralls = removeDups(q.forallList) val exists = removeDups(q.existList) val in = q.in - + val inFree = freeBoundTyVars(in :: Nil) // sort the quantification by the order of appearance val order = inFree.iterator.zipWithIndex.toMap @@ -665,8 +723,7 @@ object Type { if (bs.nonEmpty) { val subMap = - bs - .iterator + bs.iterator .map { case ((bold, _), bnew) => bold -> TyVar(bnew) } @@ -678,37 +735,39 @@ object Type { val forAllSize = fa1.size val (normfas, normexs) = newVars.splitAt(forAllSize) quantify(forallList = normfas, existList = normexs, normin) - } - else { + } else { // there is nothing to substitute, so we have nothing // to quantify in.normalize } case ta @ TyApply(_, _) => ta.normalize - case _ => tpe + case _ => tpe } - + def kindOfOption( - cons: TyConst => Option[Kind] + cons: TyConst => Option[Kind] ): Type => Option[Kind] = { val unknown: Either[Unit, Kind] = Left(()) val consE = (tc: TyConst) => cons(tc).fold(unknown)(Right(_)) val fn = kindOf[Unit](_ => (), _ => (), (_, _, _) => (), consE) - + fn.andThen { case Right(kind) => Some(kind) - case Left(_) => None + case Left(_) => None } } def kindOf[A]( - unknownVar: Var.Bound => A, - invalidApply: TyApply => A, - kindSubsumeError: (TyApply, Kind.Cons, Kind) => A, - cons: TyConst => Either[A, Kind], + unknownVar: Var.Bound => A, + invalidApply: TyApply => A, + kindSubsumeError: (TyApply, Kind.Cons, Kind) => A, + cons: TyConst => Either[A, Kind] ): Type => Either[A, Kind] = { - val fn = memoizeDagHashedConcurrent[(Type, Map[Var.Bound, Kind]), Either[A, Kind]] { case ((tpe, locals), rec) => + val fn = memoizeDagHashedConcurrent[(Type, Map[Var.Bound, Kind]), Either[ + A, + Kind + ]] { case ((tpe, locals), rec) => tpe match { case Type.TyVar(b @ Type.Var.Bound(_)) => locals.get(b) match { @@ -718,14 +777,15 @@ object Type { // $COVERAGE-ON$ this should be unreachable } case Type.TyVar(Type.Var.Skolem(_, kind, _, _)) => Right(kind) - case Type.TyMeta(Type.Meta(kind, _, _, _)) => Right(kind) - case tc@Type.TyConst(_) => cons(tc) - case ap@Type.TyApply(left, right) => + case Type.TyMeta(Type.Meta(kind, _, _, _)) => Right(kind) + case tc @ Type.TyConst(_) => cons(tc) + case ap @ Type.TyApply(left, right) => rec((left, locals)) .product(rec((right, locals))) - .flatMap { - case (leftKind, rhs) => - Kind.validApply[A](leftKind, rhs, invalidApply(ap))(kindSubsumeError(ap, _, rhs)) + .flatMap { case (leftKind, rhs) => + Kind.validApply[A](leftKind, rhs, invalidApply(ap))( + kindSubsumeError(ap, _, rhs) + ) } case q: Quantified => val varList = q.vars.toList @@ -735,24 +795,23 @@ object Type { { t => fn((t, Map.empty)) } } - /** - * These are upper-case to leverage scala's pattern - * matching on upper-cased vals - */ + + /** These are upper-case to leverage scala's pattern matching on upper-cased + * vals + */ val BoolType: Type.TyConst = TyConst(Const.predef("Bool")) val DictType: Type.TyConst = TyConst(Const.predef("Dict")) object FnType { final val MaxSize = 32 - private def predefFn(n: Int) = TyConst(Const.predef(s"Fn$n")) + private def predefFn(n: Int) = TyConst(Const.predef(s"Fn$n")) private val tpes = (1 to MaxSize).map(predefFn) private val fnMap: Map[TyConst, (TyConst, Int)] = (1 to MaxSize).iterator.map { idx => val tyconst = tpes(idx - 1) - tyconst -> (tyconst, idx) - } - .toMap + tyconst -> (tyconst, idx) + }.toMap object ValidArity { def unapply(n: Int): Boolean = @@ -762,7 +821,9 @@ object Type { def apply(n: Int): Type.TyConst = if (ValidArity.unapply(n)) tpes(n - 1) else { - throw new IllegalArgumentException(s"invalid FnType arity = $n, must be 0 < n <= $MaxSize") + throw new IllegalArgumentException( + s"invalid FnType arity = $n, must be 0 < n <= $MaxSize" + ) } def maybeFakeName(n: Int): Type.TyConst = @@ -775,7 +836,7 @@ object Type { def unapply(tpe: Type): Option[(Type.TyConst, Int)] = tpe match { case tyConst @ Type.TyConst(_) => fnMap.get(tyConst) - case _ => None + case _ => None } // FnType -> Kind(Kind.Type.contra, Kind.Type.co), @@ -784,11 +845,9 @@ object Type { def kindSize(n: Int): Kind = Kind((Vector.fill(n)(Kind.Type.contra) :+ Kind.Type.co): _*) - tpes - .iterator - .zipWithIndex - .map { case (t, n1) => (t, kindSize(n1 + 1)) } - .toList + tpes.iterator.zipWithIndex.map { case (t, n1) => + (t, kindSize(n1 + 1)) + }.toList } } val IntType: Type.TyConst = TyConst(Const.predef("Int")) @@ -811,7 +870,12 @@ object Type { } def unapply(t: Type): Option[(NonEmptyList[Type], Type)] = { - def check(n: Int, t: Type, applied: List[Type], last: Type): Option[(NonEmptyList[Type], Type)] = + def check( + n: Int, + t: Type, + applied: List[Type], + last: Type + ): Option[(NonEmptyList[Type], Type)] = t match { case TyApply(inner, arg) => check(n + 1, inner, arg :: applied, last) @@ -830,11 +894,9 @@ object Type { } } - /** - * Match if a type is a simple universal function, - * which is to say forall a, b. C -> D - * where the result type is a Rho type. - */ + /** Match if a type is a simple universal function, which is to say forall + * a, b. C -> D where the result type is a Rho type. + */ object SimpleUniversal { def unapply(t: Type): Option[ (NonEmptyList[(Type.Var.Bound, Kind)], NonEmptyList[Type], Type) @@ -851,16 +913,21 @@ object Type { case None => Some((univ ::: univR, args, res)) case Some(interNel) => - val good = univR.collect { case pair@(b, _) if !firstSet(b) => pair } + val good = univR.collect { + case pair @ (b, _) if !firstSet(b) => pair + } val avoid = firstSet ++ good.iterator.map(_._1) val rename = alignBinders(interNel, avoid) val subMap = - rename.iterator.map { case ((oldB, _), newB) => - (oldB, Type.TyVar(newB)) - } - .toMap[Type.Var, Type.Rho] - - val bounds = univ.concat(good) ::: rename.map { case ((_, k), b) => (b, k) } + rename.iterator + .map { case ((oldB, _), newB) => + (oldB, Type.TyVar(newB)) + } + .toMap[Type.Var, Type.Rho] + + val bounds = univ.concat(good) ::: rename.map { + case ((_, k), b) => (b, k) + } val newRes = Type.substituteVar(res, subMap) Some((bounds, args, newRes)) } @@ -881,8 +948,8 @@ object Type { def arity(t: Type): Int = t match { case Quantified(_, t) => arity(t) - case Fun(args, _) => args.length - case _ => 0 + case Fun(args, _) => args.length + case _ => 0 } } @@ -901,7 +968,8 @@ object Type { def unapply(t: Type): Option[Int] = t match { case Type.UnitType => Some(0) - case Type.TyConst(Const.Predef(cons)) if cons.asString.startsWith("Tuple") => + case Type.TyConst(Const.Predef(cons)) + if cons.asString.startsWith("Tuple") => Some(cons.asString.drop(5).toInt) case _ => None } @@ -920,9 +988,9 @@ object Type { } def apply(ts: List[Type]): Type = { - val sz = ts.size + val sz = ts.size val root: Type.Rho = Arity(sz) - ts.foldLeft(root) { (acc, t) => TyApply(acc, t)} + ts.foldLeft(root) { (acc, t) => TyApply(acc, t) } } val Kinds: List[(Type.TyConst, Kind)] = { @@ -930,10 +998,7 @@ object Type { def kindSize(n: Int): Kind = Kind(Vector.fill(n)(Kind.Type.co): _*) - (1 to 32) - .iterator - .map { n => (Arity(n), kindSize(n)) } - .toList + (1 to 32).iterator.map { n => (Arity(n), kindSize(n)) }.toList } } @@ -941,7 +1006,7 @@ object Type { def unapply(t: Type): Option[Type] = t match { case TyApply(OptionType, t) => Some(t) - case _ => None + case _ => None } } @@ -949,7 +1014,7 @@ object Type { def unapply(t: Type): Option[(Type, Type)] = t match { case TyApply(TyApply(DictType, kt), vt) => Some((kt, vt)) - case _ => None + case _ => None } } @@ -957,7 +1022,7 @@ object Type { def unapply(t: Type): Option[Type] = t match { case TyApply(ListType, t) => Some(t) - case _ => None + case _ => None } } @@ -976,7 +1041,7 @@ object Type { def unapply(c: Const): Option[Identifier.Constructor] = c match { case Defined(PackageName.PredefName, TypeName(cons)) => Some(cons) - case _ => None + case _ => None } } @@ -996,7 +1061,8 @@ object Type { } object Var { case class Bound(name: String) extends Var - case class Skolem(name: String, kind: Kind, existential: Boolean, id: Long) extends Var + case class Skolem(name: String, kind: Kind, existential: Boolean, id: Long) + extends Var object Bound { private[this] val cache: Array[Bound] = @@ -1007,10 +1073,8 @@ object Type { val c = str.charAt(0) if ('a' <= c && c <= 'z') { cache(c - 'a') - } - else new Bound(str) - } - else new Bound(str) + } else new Bound(str) + } else new Bound(str) implicit val orderBound: Order[Bound] = Order.by[Bound, String](_.name) @@ -1021,7 +1085,7 @@ object Type { def compare(a: Var, b: Var): Int = (a, b) match { case (Bound(a), Bound(b)) => a.compareTo(b) - case (Bound(_), _) => -1 + case (Bound(_), _) => -1 case (Skolem(n0, k0, ex0, i0), Skolem(n1, k1, ex1, i1)) => val c = java.lang.Long.compare(i0, i1) if (c != 0) c @@ -1051,20 +1115,33 @@ object Type { letters.map { c => Var.Bound(c.toString) } #::: lettersWithNumber } - def alignBinders[A](items: NonEmptyList[A], avoid: Var.Bound => Boolean): NonEmptyList[(A, Var.Bound)] = { + def alignBinders[A]( + items: NonEmptyList[A], + avoid: Var.Bound => Boolean + ): NonEmptyList[(A, Var.Bound)] = { val sz = items.size // for some reason on 2.11 we need to do .iterator or this will be an infinite loop - val bs = NonEmptyList.fromListUnsafe(allBinders.iterator.filterNot(avoid).take(sz).toList) + val bs = NonEmptyList.fromListUnsafe( + allBinders.iterator.filterNot(avoid).take(sz).toList + ) NonEmptyList((items.head, bs.head), items.tail.zip(bs.tail)) } - def alignBinders[A](items: List[A], avoid: Var.Bound => Boolean): List[(A, Var.Bound)] = + def alignBinders[A]( + items: List[A], + avoid: Var.Bound => Boolean + ): List[(A, Var.Bound)] = NonEmptyList.fromList(items) match { case Some(nel) => alignBinders(nel, avoid).toList - case None => Nil + case None => Nil } - case class Meta(kind: Kind, id: Long, existential: Boolean, ref: Ref[Option[Type.Tau]]) + case class Meta( + kind: Kind, + id: Long, + existential: Boolean, + ref: Ref[Option[Type.Tau]] + ) object Meta { implicit val orderingMeta: Ordering[Meta] = @@ -1073,34 +1150,31 @@ object Type { if (x.existential) { if (y.existential) java.lang.Long.compare(x.id, y.id) else -1 - } - else { + } else { if (!y.existential) java.lang.Long.compare(x.id, y.id) else 1 } } } - /** - * Final the set of all of Metas inside the list of given types - */ + /** Final the set of all of Metas inside the list of given types + */ def metaTvs(s: List[Type]): SortedSet[Meta] = { @annotation.tailrec def go(check: List[Type], acc: SortedSet[Meta]): SortedSet[Meta] = check match { - case Nil => acc + case Nil => acc case Quantified(_, r) :: tail => go(r :: tail, acc) - case TyApply(a, r) :: tail => go(a :: r :: tail, acc) - case TyMeta(m) :: tail => go(tail, acc + m) - case _ :: tail => go(tail, acc) + case TyApply(a, r) :: tail => go(a :: r :: tail, acc) + case TyMeta(m) :: tail => go(tail, acc + m) + case _ :: tail => go(tail, acc) } go(s, SortedSet.empty) } - /** - * Report bound variables which are used in quantify. When we - * infer a sigma type - */ + /** Report bound variables which are used in quantify. When we infer a sigma + * type + */ def tyVarBinders(tpes: List[Type]): Set[Type.Var.Bound] = { @annotation.tailrec def loop(tpes: List[Type], acc: Set[Type.Var.Bound]): Set[Type.Var.Bound] = @@ -1115,44 +1189,48 @@ object Type { loop(tpes, Set.empty) } - /** - * strange name, but the idea is to replace a Meta with a resolved Rho - * value. I think the name resolve might be better, but the paper I started - * from used zonk + /** strange name, but the idea is to replace a Meta with a resolved Rho value. + * I think the name resolve might be better, but the paper I started from + * used zonk */ def zonk[F[_]: Monad]( - transparent: SortedSet[Meta], - readMeta: Meta => F[Option[Rho]], - writeMeta: (Meta, Type.Rho) => F[Unit]): Meta => F[Option[Rho]] = { + transparent: SortedSet[Meta], + readMeta: Meta => F[Option[Rho]], + writeMeta: (Meta, Type.Rho) => F[Unit] + ): Meta => F[Option[Rho]] = { val pureNone = Monad[F].pure(Option.empty[Rho]) lazy val fn: Meta => F[Option[Rho]] = { (m: Meta) => if (m.existential && !transparent(m)) pureNone - else readMeta(m).flatMap { - case None => pureNone - case (sm @ Some(tm: Type.TyMeta)) if tm.toMeta.existential && !transparent(tm.toMeta) => - // don't zonk from non-existential past existential or we forget - // that this variable is existential and can see through it - Monad[F].pure(sm) - case sty @ Some(ty) => - zonkRhoMeta(ty)(fn).flatMap { ty1 => - if ((ty1: Type) === ty) Monad[F].pure(sty) - else { - // we were able to resolve more of the inner metas - // inside ty, so update the state - writeMeta(m, ty1).as(Some(ty1)) + else + readMeta(m).flatMap { + case None => pureNone + case (sm @ Some(tm: Type.TyMeta)) + if tm.toMeta.existential && !transparent(tm.toMeta) => + // don't zonk from non-existential past existential or we forget + // that this variable is existential and can see through it + Monad[F].pure(sm) + case sty @ Some(ty) => + zonkRhoMeta(ty)(fn).flatMap { ty1 => + if ((ty1: Type) === ty) Monad[F].pure(sty) + else { + // we were able to resolve more of the inner metas + // inside ty, so update the state + writeMeta(m, ty1).as(Some(ty1)) + } } - } - } + } } fn } - /** - * Resolve known meta variables nested inside t - */ - def zonkMeta[F[_]: Applicative](t: Type)(m: Meta => F[Option[Type.Rho]]): F[Type] = + + /** Resolve known meta variables nested inside t + */ + def zonkMeta[F[_]: Applicative]( + t: Type + )(m: Meta => F[Option[Type.Rho]]): F[Type] = t match { case rho: Rho => zonkRhoMeta(rho)(m).widen case q: Quantified => @@ -1161,16 +1239,17 @@ object Type { } } - /** - * Resolve known meta variables nested inside t - */ - def zonkRhoMeta[F[_]: Applicative](t: Type.Rho)(mfn: Meta => F[Option[Type.Rho]]): F[Type.Rho] = + /** Resolve known meta variables nested inside t + */ + def zonkRhoMeta[F[_]: Applicative]( + t: Type.Rho + )(mfn: Meta => F[Option[Type.Rho]]): F[Type.Rho] = t match { case Type.TyApply(on, arg) => (zonkRhoMeta(on)(mfn), zonkMeta(arg)(mfn)).mapN(Type.TyApply(_, _)) - case t@Type.TyMeta(m) => + case t @ Type.TyMeta(m) => mfn(m).map { - case None => t + case None => t case Some(rho) => rho } case (Type.TyConst(_) | Type.TyVar(_)) => Applicative[F].pure(t) @@ -1179,8 +1258,11 @@ object Type { private object FullResolved extends TypeParser[Type] { lazy val parseRoot: P[Type] = { val tvar = Parser.lowerIdent.map { s => Type.TyVar(Type.Var.Bound(s)) } - val name = ((PackageName.parser <* P.string("::")) ~ Identifier.consParser) - .map { case (p, n) => Type.TyConst(Type.Const.Defined(p, TypeName(n))) } + val name = + ((PackageName.parser <* P.string("::")) ~ Identifier.consParser) + .map { case (p, n) => + Type.TyConst(Type.Const.Defined(p, TypeName(n))) + } val longParser: P[Long] = Numbers.signedIntString.mapFilter { str => try Some(str.toLong) catch { @@ -1188,9 +1270,14 @@ object Type { } } val existential = P.char('e').? - val skolem = (P.char('$') *> Parser.lowerIdent, P.char('$') *> (longParser ~ existential)) + val skolem = ( + P.char('$') *> Parser.lowerIdent, + P.char('$') *> (longParser ~ existential) + ) // TODO Kind/existential - .mapN { case (n, (id, ex)) => Var.Skolem(n, Kind.Type, ex.isDefined, id) } + .mapN { case (n, (id, ex)) => + Var.Skolem(n, Kind.Type, ex.isDefined, id) + } .map(TyVar(_)) // this null is bad, but we have no way to reallocate this @@ -1199,7 +1286,9 @@ object Type { // to have fully inferred types with no skolems or metas // TODO Kind val meta = (P.char('?') *> (existential ~ longParser)) - .map { case (opt, l) => TyMeta(Meta(Kind.Type, l, opt.isDefined, null)) } + .map { case (opt, l) => + TyMeta(Meta(Kind.Type, l, opt.isDefined, null)) + } tvar.orElse(name).orElse(skolem).orElse(meta) } @@ -1208,19 +1297,26 @@ object Type { // this may be an invalid function, but typechecking verifies that. Type.Fun(in, out) - def applyTypes(left: Type, args: NonEmptyList[Type]) = applyAll(left, args.toList) + def applyTypes(left: Type, args: NonEmptyList[Type]) = + applyAll(left, args.toList) def universal(vs: NonEmptyList[(String, Option[Kind])], on: Type): Type = - Type.forAll(vs.map { - case (s, None) => (Type.Var.Bound(s), Kind.Type) - case (s, Some(k)) => (Type.Var.Bound(s), k) - }, on) + Type.forAll( + vs.map { + case (s, None) => (Type.Var.Bound(s), Kind.Type) + case (s, Some(k)) => (Type.Var.Bound(s), k) + }, + on + ) def existential(vs: NonEmptyList[(String, Option[Kind])], on: Type): Type = - Type.exists(vs.map { - case (s, None) => (Type.Var.Bound(s), Kind.Type) - case (s, Some(k)) => (Type.Var.Bound(s), k) - }, on) + Type.exists( + vs.map { + case (s, None) => (Type.Var.Bound(s), Kind.Type) + case (s, Some(k)) => (Type.Var.Bound(s), k) + }, + on + ) def makeTuple(lst: List[Type]) = Type.Tuple(lst) @@ -1229,8 +1325,11 @@ object Type { def unapplyRoot(a: Type): Option[Doc] = a match { case TyConst(Const.Defined(p, n)) => - Some(Document[PackageName].document(p) + coloncolon + Document[Identifier].document(n.ident)) - case TyVar(Var.Bound(s)) => Some(Doc.text(s)) + Some( + Document[PackageName] + .document(p) + coloncolon + Document[Identifier].document(n.ident) + ) + case TyVar(Var.Bound(s)) => Some(Doc.text(s)) case TyVar(Var.Skolem(n, _, e, i)) => // TODO Kind val dol = "$" @@ -1246,58 +1345,82 @@ object Type { def unapplyFn(a: Type): Option[(NonEmptyList[Type], Type)] = a match { case Fun(as, b) => Some((as, b)) - case _ => None + case _ => None } - def unapplyUniversal(a: Type): Option[(List[(String, Option[Kind])], Type)] = + def unapplyUniversal( + a: Type + ): Option[(List[(String, Option[Kind])], Type)] = a match { case _: Rho => None case q: Quantified => q.quant match { case Quantification.ForAll(vs) => - Some((vs.map { - case (v, k) => (v.name, Some(k)) - }.toList, q.in)) + Some( + ( + vs.map { case (v, k) => + (v.name, Some(k)) + }.toList, + q.in + ) + ) case Quantification.Dual(forall, ex) => - Some((forall.map { - case (v, k) => (v.name, Some(k)) - }.toList, exists(ex, q.in))) + Some( + ( + forall.map { case (v, k) => + (v.name, Some(k)) + }.toList, + exists(ex, q.in) + ) + ) case _ => None } } - def unapplyExistential(a: Type): Option[(List[(String, Option[Kind])], Type)] = + def unapplyExistential( + a: Type + ): Option[(List[(String, Option[Kind])], Type)] = a match { case _: Rho => None case q: Quantified => q.quant match { case Quantification.Exists(vs) => - Some((vs.map { - case (v, k) => (v.name, Some(k)) - }.toList, q.in)) + Some( + ( + vs.map { case (v, k) => + (v.name, Some(k)) + }.toList, + q.in + ) + ) case Quantification.Dual(forall, exists) => - Some((exists.map { - case (v, k) => (v.name, Some(k)) - }.toList, forAll(forall, q.in))) + Some( + ( + exists.map { case (v, k) => + (v.name, Some(k)) + }.toList, + forAll(forall, q.in) + ) + ) case _ => None } } def unapplyTypeApply(a: Type): Option[(Type, List[Type])] = a match { - case ta@TyApply(_, _) => Some(unapplyAll(ta)) - case _ => None + case ta @ TyApply(_, _) => Some(unapplyAll(ta)) + case _ => None } def unapplyTuple(a: Type): Option[List[Type]] = a match { case Tuple(as) => Some(as) - case _ => None + case _ => None } } - /** - * Parse fully resolved types: package::type - */ + + /** Parse fully resolved types: package::type + */ def fullyResolvedParser: P[Type] = FullResolved.parser def fullyResolvedDocument: Document[Type] = FullResolved.document def typeParser: TypeParser[Type] = FullResolved @@ -1307,21 +1430,19 @@ object Type { val tpes = alignBinders((1 to n).toList, Set.empty) - val tpesDecl = tpes - .iterator + val tpesDecl = tpes.iterator .map { case (_, tpe) => s"${tpe.name}: +*" } .mkString("[", ", ", "]") - tpes - .iterator + tpes.iterator .map { case (i, v) => s"item$i: ${v.name}" } .mkString(s"struct Tuple$n$tpesDecl(", ", ", ")") } def allTupleCode: String = { - (1 to 32).iterator.map { i => s"Tuple$i()"}.mkString("", ",\n", ",\n") + + (1 to 32).iterator.map { i => s"Tuple$i()" }.mkString("", ",\n", ",\n") + (1 to 32).iterator.map(tupleCodeGen).mkString("\n") } @@ -1334,8 +1455,6 @@ object Type { StrType -> Kind.Type, CharType -> Kind.Type, UnitType -> Kind.Type - )) - .map { case (t, k) => (t.tpe.toDefined, k) } - .toMap + )).map { case (t, k) => (t.tpe.toDefined, k) }.toMap } diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala index d61586d2a..38b4c2b0d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala @@ -5,9 +5,13 @@ import org.bykn.bosatsu.Identifier.{Bindable, Constructor} import scala.collection.immutable.SortedMap class TypeEnv[+A] private ( - protected val values: SortedMap[(PackageName, Identifier), Type], - protected val constructors: SortedMap[(PackageName, Constructor), (DefinedType[A], ConstructorFn)], - val definedTypes: SortedMap[(PackageName, TypeName), DefinedType[A]]) { + protected val values: SortedMap[(PackageName, Identifier), Type], + protected val constructors: SortedMap[ + (PackageName, Constructor), + (DefinedType[A], ConstructorFn) + ], + val definedTypes: SortedMap[(PackageName, TypeName), DefinedType[A]] +) { override def equals(that: Any): Boolean = that match { @@ -35,10 +39,16 @@ class TypeEnv[+A] private ( def allDefinedTypes: List[DefinedType[A]] = definedTypes.values.toList.sortBy { dt => (dt.packageName, dt.name) } - def getConstructor(p: PackageName, c: Constructor): Option[(DefinedType[A], ConstructorFn)] = + def getConstructor( + p: PackageName, + c: Constructor + ): Option[(DefinedType[A], ConstructorFn)] = constructors.get((p, c)) - def getConstructorParams(p: PackageName, c: Constructor): Option[List[(Bindable, Type)]] = + def getConstructorParams( + p: PackageName, + c: Constructor + ): Option[List[(Bindable, Type)]] = constructors.get((p, c)).map(_._2.args) def getType(p: PackageName, t: TypeName): Option[DefinedType[A]] = @@ -53,7 +63,9 @@ class TypeEnv[+A] private ( values.get((p, n)) // when we have resolved, we can get the types of constructors out - def getValue(p: PackageName, n: Identifier)(implicit ev: A <:< Kind.Arg): Option[Type] = + def getValue(p: PackageName, n: Identifier)(implicit + ev: A <:< Kind.Arg + ): Option[Type] = n match { case c @ Constructor(_) => // constructors are never external defs @@ -62,65 +74,82 @@ class TypeEnv[+A] private ( } // when we have resolved, we can get the types of constructors out - def localValuesOf(p: PackageName)(implicit ev: A <:< Kind.Arg): SortedMap[Identifier, Type] = { + def localValuesOf( + p: PackageName + )(implicit ev: A <:< Kind.Arg): SortedMap[Identifier, Type] = { val bldr = SortedMap.newBuilder[Identifier, Type] // add externals bldr ++= values.iterator.collect { case ((pn, n), v) if pn == p => (n, v) } // add constructors - bldr ++= constructors.iterator.collect { case ((pn, n), (dt, cf)) if pn == p => (n, dt.fnTypeOf(cf)) } + bldr ++= constructors.iterator.collect { + case ((pn, n), (dt, cf)) if pn == p => (n, dt.fnTypeOf(cf)) + } bldr.result() } - def addConstructor[A1 >: A](pack: PackageName, - dt: DefinedType[A1], - cf: ConstructorFn): TypeEnv[A1] = { - val nec = constructors.updated((pack, cf.name), (dt, cf)) - val dt1 = definedTypes.updated((dt.packageName, dt.name), dt) - new TypeEnv(values = values, constructors = nec, definedTypes = dt1) - } + def addConstructor[A1 >: A]( + pack: PackageName, + dt: DefinedType[A1], + cf: ConstructorFn + ): TypeEnv[A1] = { + val nec = constructors.updated((pack, cf.name), (dt, cf)) + val dt1 = definedTypes.updated((dt.packageName, dt.name), dt) + new TypeEnv(values = values, constructors = nec, definedTypes = dt1) + } - /** - * only add the type, do not add any of the constructors - * used when importing values - */ + /** only add the type, do not add any of the constructors used when importing + * values + */ def addDefinedType[A1 >: A](dt: DefinedType[A1]): TypeEnv[A1] = { val dt1 = definedTypes.updated((dt.packageName, dt.name), dt) - new TypeEnv(constructors = constructors, definedTypes = dt1, values = values) + new TypeEnv( + constructors = constructors, + definedTypes = dt1, + values = values + ) } - /** - * add a DefinedType and all of its constructors. This is done locally for - * a package - */ - def addDefinedTypeAndConstructors[A1 >: A](dt: DefinedType[A1]): TypeEnv[A1] = { + /** add a DefinedType and all of its constructors. This is done locally for a + * package + */ + def addDefinedTypeAndConstructors[A1 >: A]( + dt: DefinedType[A1] + ): TypeEnv[A1] = { val dt1 = definedTypes.updated((dt.packageName, dt.name), dt) val cons1 = dt.constructors - .foldLeft(constructors: SortedMap[(PackageName, Constructor), (DefinedType[A1], ConstructorFn)]) { - case (cons0, cf) => - cons0.updated((dt.packageName, cf.name), (dt, cf)) + .foldLeft( + constructors: SortedMap[ + (PackageName, Constructor), + (DefinedType[A1], ConstructorFn) + ] + ) { case (cons0, cf) => + cons0.updated((dt.packageName, cf.name), (dt, cf)) } new TypeEnv(constructors = cons1, definedTypes = dt1, values = values) } - /** - * External values cannot be inferred and have to be fully - * annotated - */ - def addExternalValue(pack: PackageName, name: Identifier, t: Type): TypeEnv[A] = + /** External values cannot be inferred and have to be fully annotated + */ + def addExternalValue( + pack: PackageName, + name: Identifier, + t: Type + ): TypeEnv[A] = new TypeEnv( constructors = constructors, definedTypes = definedTypes, - values = values.updated((pack, name), t)) + values = values.updated((pack, name), t) + ) - lazy val typeConstructors: SortedMap[(PackageName, Constructor), (List[(Type.Var.Bound, A)], List[Type], Type.Const.Defined)] = + lazy val typeConstructors: SortedMap[ + (PackageName, Constructor), + (List[(Type.Var.Bound, A)], List[Type], Type.Const.Defined) + ] = constructors.map { case (pc, (dt, cf)) => - (pc, - (dt.annotatedTypeParams, - cf.args.map(_._2), - dt.toTypeConst)) + (pc, (dt.annotatedTypeParams, cf.args.map(_._2), dt.toTypeConst)) } def definedTypeFor(c: (PackageName, Constructor)): Option[DefinedType[A]] = @@ -132,9 +161,11 @@ class TypeEnv[+A] private ( } def ++[A1 >: A](that: TypeEnv[A1]): TypeEnv[A1] = - new TypeEnv(values ++ that.values, + new TypeEnv( + values ++ that.values, constructors ++ that.constructors, - definedTypes ++ that.definedTypes) + definedTypes ++ that.definedTypes + ) def toKindMap(implicit ev: A <:< Kind.Arg): Map[Type.Const.Defined, Kind] = { type F[+Z] = List[DefinedType[Z]] @@ -147,17 +178,24 @@ object TypeEnv { val empty: TypeEnv[Nothing] = new TypeEnv( SortedMap.empty[(PackageName, Identifier), Type], - SortedMap.empty[(PackageName, Constructor), (DefinedType[Nothing], ConstructorFn)], - SortedMap.empty[(PackageName, TypeName), DefinedType[Nothing]]) - - /** - * Adds all the types and all the constructors from the given types - */ + SortedMap.empty[ + (PackageName, Constructor), + (DefinedType[Nothing], ConstructorFn) + ], + SortedMap.empty[(PackageName, TypeName), DefinedType[Nothing]] + ) + + /** Adds all the types and all the constructors from the given types + */ def fromDefinitions[A](defs: List[DefinedType[A]]): TypeEnv[A] = defs.foldLeft(empty: TypeEnv[A])(_.addDefinedTypeAndConstructors(_)) def fromParsed[A](p: ParsedTypeEnv[A]): TypeEnv[A] = { - val t1 = p.allDefinedTypes.foldLeft(empty: TypeEnv[A])(_.addDefinedTypeAndConstructors(_)) - p.externalDefs.foldLeft(t1) { case (t1, (p, n, t)) => t1.addExternalValue(p, n, t) } + val t1 = p.allDefinedTypes.foldLeft(empty: TypeEnv[A])( + _.addDefinedTypeAndConstructors(_) + ) + p.externalDefs.foldLeft(t1) { case (t1, (p, n, t)) => + t1.addExternalValue(p, n, t) + } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/set/Rel.scala b/core/src/main/scala/org/bykn/bosatsu/set/Rel.scala index b740a6186..fcadec20b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/set/Rel.scala +++ b/core/src/main/scala/org/bykn/bosatsu/set/Rel.scala @@ -10,7 +10,7 @@ sealed abstract class Rel { lhs => def isSubtype: Boolean = this match { case Sub | Same => true - case _ => false + case _ => false } def isStrictSupertype: Boolean = @@ -19,14 +19,14 @@ sealed abstract class Rel { lhs => def isSupertype: Boolean = this match { case Super | Same => true - case _ => false + case _ => false } def invert: Rel = this match { - case Sub => Super + case Sub => Super case Super => Sub - case _ => this + case _ => this } // implements transitivity of comparisons @@ -38,11 +38,11 @@ sealed abstract class Rel { lhs => // e.g. (a < b) and (b < c) imply (a < c) def |+|(rhs: Rel): Rel = (lhs, rhs) match { - case (x, y) if x == y => x + case (x, y) if x == y => x case (Disjoint, _) | (_, Disjoint) => Disjoint - case (Same, _) => rhs - case (_, Same) => lhs - case _ => Intersects + case (Same, _) => rhs + case (_, Same) => lhs + case _ => Intersects } def lazyCombine(rhs: => Rel): Rel = @@ -59,4 +59,3 @@ object Rel { case object Intersects extends Rel case object Disjoint extends Rel } - diff --git a/core/src/main/scala/org/bykn/bosatsu/set/Relatable.scala b/core/src/main/scala/org/bykn/bosatsu/set/Relatable.scala index 384e90368..1e611ca13 100644 --- a/core/src/main/scala/org/bykn/bosatsu/set/Relatable.scala +++ b/core/src/main/scala/org/bykn/bosatsu/set/Relatable.scala @@ -8,10 +8,10 @@ object Relatable { def apply[A](implicit r: Relatable[A]): Relatable[A] = r def fromSubsetIntersects[A]( - // True when all elements of left are in right - subset: (A, A) => Boolean, - // true when there exists 1 or more element in both - intersects: (A, A) => Boolean + // True when all elements of left are in right + subset: (A, A) => Boolean, + // true when there exists 1 or more element in both + intersects: (A, A) => Boolean ): Relatable[A] = new Relatable[A] { def relate(left: A, right: A): Rel = { @@ -20,34 +20,36 @@ object Relatable { if (leftSub) { if (rightSub) Rel.Same else Rel.Sub - } - else if (rightSub) Rel.Super + } else if (rightSub) Rel.Super else if (intersects(left, right)) Rel.Intersects else Rel.Disjoint } } def setRelatable[A]: Relatable[Set[A]] = - fromSubsetIntersects(_.subsetOf(_), (s1, s2) => { - if (s1.size <= s2.size) s1.exists(s2) - else s2.exists(s1) - }) - + fromSubsetIntersects( + _.subsetOf(_), + (s1, s2) => { + if (s1.size <= s2.size) s1.exists(s2) + else s2.exists(s1) + } + ) + def fromUniversalEquals[A]: Relatable[A] = new Relatable[A] { def relate(i: A, j: A) = if (i == j) Rel.Same else Rel.Disjoint } - /** - * Make a relatable where unions are represented by Lists. - * we need three functions: - * 1. is a value A empty - * 2. compute the intersection of two values - * 3. given A, either split it in two, or give a function to see if a == union(as) (given a >= union(as)) - */ + /** Make a relatable where unions are represented by Lists. we need three + * functions: + * 1. is a value A empty 2. compute the intersection of two values 3. given + * A, either split it in two, or give a function to see if a == + * union(as) (given a >= union(as)) + */ def listUnion[A: Relatable]( - isEmptyFn: A => Boolean, - intersectFn: (A, A) => List[A], - solveOne: A => Either[List[A] => Boolean, (A, A)]): Relatable[List[A]] = + isEmptyFn: A => Boolean, + intersectFn: (A, A) => List[A], + solveOne: A => Either[List[A] => Boolean, (A, A)] + ): Relatable[List[A]] = new Relatable[List[A]] { self => val unionRelMod: UnionRelModule[List[A]] = new UnionRelModule[List[A]] { @@ -55,9 +57,9 @@ object Relatable { def isEmpty(ls: List[A]) = ls.forall(isEmptyFn) def deunion(ls: List[A]) = ls.size match { - // $COVERAGE-OFF$ + // $COVERAGE-OFF$ case 0 => sys.error("invariant violation: deunion(Nil)") - // $COVERAGE-ON$ + // $COVERAGE-ON$ case 1 => solveOne(ls.head) match { case Left(equ) => @@ -73,15 +75,17 @@ object Relatable { Right((ls.splitAt(sz / 2))) } - def cheapUnion(head: List[A], tail: List[List[A]]) = (head :: tail).flatten.distinct + def cheapUnion(head: List[A], tail: List[List[A]]) = + (head :: tail).flatten.distinct def intersect(a: List[A], b: List[A]): List[A] = if (a.isEmpty || b.isEmpty) Nil - else for { - ai <- a - bi <- b - i <- intersectFn(ai, bi) - } yield i + else + for { + ai <- a + bi <- b + i <- intersectFn(ai, bi) + } yield i } def relate(left: List[A], right: List[A]): Rel = @@ -106,17 +110,18 @@ object Relatable { } } - /** - * unionCompare compares a <:> (b | c) - * - * It can give a Rel or a PartialRel as a result - * Note: we always evaluate a <:> b so if you - * can choose, that should be the simpler value to check - * - * Important: b and c cannot be bottom values. They cannot - * be empty. - */ - private def unionRelCompare1[A: Relatable](a: A, b: A, c: A): Either[PartialRel, Rel] = { + /** unionCompare compares a <:> (b | c) + * + * It can give a Rel or a PartialRel as a result Note: we always evaluate a + * <:> b so if you can choose, that should be the simpler value to check + * + * Important: b and c cannot be bottom values. They cannot be empty. + */ + private def unionRelCompare1[A: Relatable]( + a: A, + b: A, + c: A + ): Either[PartialRel, Rel] = { import Rel._ import PartialRel._ @@ -127,40 +132,46 @@ object Relatable { case Same => ac match { // a=b, so b|c = a|c, so a <= b|c case Sub => Right(Sub) // (a=b) then b|c = a|c, which == a, if c == a. - case Same => Right(Same) // a = b = c + case Same => Right(Same) // a = b = c case Super => Right(Same) // (a=b), a > c. So, b|c = a|c = a - case Intersects => Right(Sub) // a=b, a n c. b|c = a|c which is bigger than a + case Intersects => + Right(Sub) // a=b, a n c. b|c = a|c which is bigger than a case Disjoint => Right(Sub) // a=b, so b|c = a|c } case Super => ac match { - case Sub => Right(Sub) // a < c, so a < (b|c) + case Sub => Right(Sub) // a < c, so a < (b|c) case Same => Right(Same) // a = c, a > b. b|c = b|a, and a > b - case Super => Left(SuperSame) // if a > b, a > c, then a > (b|c) or a = (b|c) + case Super => + Left(SuperSame) // if a > b, a > c, then a > (b|c) or a = (b|c) case Intersects => // a > b, c has some outside a, but b could cover all not in c so // a < b|c or a n b|c, Sub or Intersect. Left(SubIntersects) - case Disjoint => Right(Intersects) // a > b, a ! c, all of c is outside a, but all b inside + case Disjoint => + Right( + Intersects + ) // a > b, a ! c, all of c is outside a, but all b inside } case Intersects => ac match { - case Sub => Right(Sub) // a < c so a < (b|c) - case Same => Right(Sub) // a n b, a = c. b|c = b|a, so a < b|a + case Sub => Right(Sub) // a < c so a < (b|c) + case Same => Right(Sub) // a n b, a = c. b|c = b|a, so a < b|a case Super => // a > c, b has some outside a, but c could cover all not in b so // a < b|c or a n b|c, Sub or Intersect. Left(SubIntersects) - case Intersects => Left(SubIntersects) // a n b, a n c, so a < (b|c) or a n (b|c). + case Intersects => + Left(SubIntersects) // a n b, a n c, so a < (b|c) or a n (b|c). case Disjoint => Right(Intersects) // a n b, a ! c, b|c } case Disjoint => ac match { - case Sub => Right(Sub) - case Same => Right(Sub) - case Super => Right(Intersects) + case Sub => Right(Sub) + case Same => Right(Sub) + case Super => Right(Intersects) case Intersects => Right(Intersects) - case Disjoint => Right(Disjoint) + case Disjoint => Right(Disjoint) } } } @@ -170,18 +181,16 @@ object Relatable { import PartialRel._ def relatable: Relatable[A] - /** - * Either deunion *non-empty* a into two non-empty values - * or return a function that solves the problem - * of a <:> (b | c) where we know for sure - * that the answer is either Super or Same - * which is to say, we know that a >= (b | c) - */ + + /** Either deunion *non-empty* a into two non-empty values or return a + * function that solves the problem of a <:> (b | c) where we know for sure + * that the answer is either Super or Same which is to say, we know that a + * >= (b | c) + */ def deunion(a: A): Either[(A, A) => Rel.SuperOrSame, (A, A)] - /** - * This can be a cheap union, not a totally - * normalizing union. - */ + + /** This can be a cheap union, not a totally normalizing union. + */ def cheapUnion(head: A, tail: List[A]): A def intersect(a: A, b: A): A @@ -191,62 +200,68 @@ object Relatable { private def subIntersectsCase(ab: A, a1: A, a2: A): Rel = unionRelCompare1(ab, a1, a2)(relatable) match { - case Right(Sub) => Intersects + case Right(Sub) => Intersects case Right(Same) => Sub - case Left(SubIntersects) => Intersects // we know a <:> b is < or n, so a&b <:> a is < implies this + case Left(SubIntersects) => + Intersects // we know a <:> b is < or n, so a&b <:> a is < implies this case Left(SuperSame) => Sub - // $COVERAGE-OFF$ + // $COVERAGE-OFF$ case Right(rel) => // this should never happen because we know that ab is sub or intersect sys.error(s"unexpected rel: $rel, ab = $ab, a1 = $a1, a2 = $a2") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - /** - * compare a to (b1|b2) - */ + /** compare a to (b1|b2) + */ final def unionRelCompare(a: A, b1: A, b2: A): Rel = if (isEmpty(b1)) relatable.relate(a, b2) else if (isEmpty(b2)) relatable.relate(a, b1) - else unionRelCompare1(a, b1, b2)(relatable) match { - case Right(rel) => - rel - case Left(p) => - // Note, a is never empty here because if it is, unionRelCompare1 is Sub - (deunion(a), p) match { - case (Right((a1, a2)), SubIntersects) => - val head = intersect(b1, a1) - val tail = intersect(b2, a1) :: intersect(b1, a2) :: intersect(b2, a2) :: Nil - val ab = cheapUnion(head, tail) - subIntersectsCase(ab, a1, a2) - case (Right((a1, a2)), SuperSame) => - // if we have SuperSame and invert(p1) what is the result - @inline def andInvert(p1: PartialRel): Rel = - p1 match { - case SuperSame => Same - case SubIntersects => Super - } + else + unionRelCompare1(a, b1, b2)(relatable) match { + case Right(rel) => + rel + case Left(p) => + // Note, a is never empty here because if it is, unionRelCompare1 is Sub + (deunion(a), p) match { + case (Right((a1, a2)), SubIntersects) => + val head = intersect(b1, a1) + val tail = intersect(b2, a1) :: intersect(b1, a2) :: intersect( + b2, + a2 + ) :: Nil + val ab = cheapUnion(head, tail) + subIntersectsCase(ab, a1, a2) + case (Right((a1, a2)), SuperSame) => + // if we have SuperSame and invert(p1) what is the result + @inline def andInvert(p1: PartialRel): Rel = + p1 match { + case SuperSame => Same + case SubIntersects => Super + } - // we know that a1 and a2 are not empty because they are the result - // of a deunion - unionRelCompare1(cheapUnion(b1, b2 :: Nil), a1, a2)(relatable) match { - case Left(r) => andInvert(r) - case Right(r) => r.invert - } - case (Left(f), SubIntersects) => - // we know a < (b1| b2) or it intersects - // so, if we try again with ((b1 | b2) & a) - // if a < (b1 | b2), then a <:> ((b1 | b2) & a) == Same - // else if a intersects (b1 | b2), then a <:> ((b1 | b2) & a) == Super - val ab1 = intersect(a, b1) - val ab2 = intersect(a, b2) - f(ab1, ab2) match { - case Same => Sub - case Super => Intersects - } - case (Left(f), SuperSame) => - f(b1, b2) - } - } - } -} \ No newline at end of file + // we know that a1 and a2 are not empty because they are the result + // of a deunion + unionRelCompare1(cheapUnion(b1, b2 :: Nil), a1, a2)( + relatable + ) match { + case Left(r) => andInvert(r) + case Right(r) => r.invert + } + case (Left(f), SubIntersects) => + // we know a < (b1| b2) or it intersects + // so, if we try again with ((b1 | b2) & a) + // if a < (b1 | b2), then a <:> ((b1 | b2) & a) == Same + // else if a intersects (b1 | b2), then a <:> ((b1 | b2) & a) == Super + val ab1 = intersect(a, b1) + val ab2 = intersect(a, b2) + f(ab1, ab2) match { + case Same => Sub + case Super => Intersects + } + case (Left(f), SuperSame) => + f(b1, b2) + } + } + } +} diff --git a/core/src/main/scala/org/bykn/bosatsu/set/SetOps.scala b/core/src/main/scala/org/bykn/bosatsu/set/SetOps.scala index bea2740e5..0c4670e62 100644 --- a/core/src/main/scala/org/bykn/bosatsu/set/SetOps.scala +++ b/core/src/main/scala/org/bykn/bosatsu/set/SetOps.scala @@ -1,86 +1,75 @@ package org.bykn.bosatsu.set -/** - * These are set operations we can do on patterns - */ +/** These are set operations we can do on patterns + */ trait SetOps[A] extends Relatable[A] { - /** - * a representation of the set with everything in it - * not all sets have upper bounds we can represent - */ + /** a representation of the set with everything in it not all sets have upper + * bounds we can represent + */ def top: Option[A] - /** - * if everything is <= A, maybe more than one representation of top - */ + /** if everything is <= A, maybe more than one representation of top + */ def isTop(a: A): Boolean - /** - * intersect two values and return a union represented as a list - */ + /** intersect two values and return a union represented as a list + */ def intersection(a1: A, a2: A): List[A] def relate(a1: A, a2: A): Rel - /** - * Return true if a1 and a2 are disjoint - */ + /** Return true if a1 and a2 are disjoint + */ def disjoint(a1: A, a2: A): Boolean = relate(a1, a2) == Rel.Disjoint - /** - * remove a2 from a1 return a union represented as a list - * - * this should be the tightest upperbound we can find - */ + /** remove a2 from a1 return a union represented as a list + * + * this should be the tightest upperbound we can find + */ def difference(a1: A, a2: A): List[A] - /** - * This should unify the union into the fewest number - * of patterns without changing the meaning of the union - */ + /** This should unify the union into the fewest number of patterns without + * changing the meaning of the union + */ def unifyUnion(u: List[A]): List[A] - /** - * if true, all elements in a are in b, - * if false, there is no promise - * - * this should be a reasonable cheap operation - * that is allowed to say no in order - * to avoid very expensive work - */ + /** if true, all elements in a are in b, if false, there is no promise + * + * this should be a reasonable cheap operation that is allowed to say no in + * order to avoid very expensive work + */ def subset(a: A, b: A): Boolean = relate(a, b).isSubtype def equiv(a: A, b: A): Boolean = relate(a, b) == Rel.Same - /** - * Remove all items in p2 from all items in p1 - * and unify the remaining union - */ + /** Remove all items in p2 from all items in p1 and unify the remaining union + */ def differenceAll(p1: List[A], p2: List[A]): List[A] = p2.foldLeft(p1) { (p1s, p) => // remove p from all of p1s p1s.flatMap(difference(_, p)) } - /** - * if top is defined - * a list of matches that would make the current set of matches total - * - * Note, a law here is that: - * missingBranches(te, t, branches).flatMap { ms => - * assert(missingBranches(te, t, branches ::: ms).isEmpty) - * } - */ + /** if top is defined a list of matches that would make the current set of + * matches total + * + * Note, a law here is that: missingBranches(te, t, branches).flatMap { ms => + * assert(missingBranches(te, t, branches ::: ms).isEmpty) } + */ def missingBranches(top: List[A], branches: List[A]): List[A] = { def clearSubs(branches: List[A], front: List[A]): List[A] = branches match { case Nil => Nil case h :: tail => - if (tail.exists(relate(h, _).isSubtype) || front.exists(relate(h, _).isSubtype)) clearSubs(tail, front) + if ( + tail.exists(relate(h, _).isSubtype) || front.exists( + relate(h, _).isSubtype + ) + ) clearSubs(tail, front) else h :: clearSubs(tail, h :: front) } // we can subtract in any order @@ -103,7 +92,9 @@ trait SetOps[A] extends Relatable[A] { } } val normB = clearSubs(branches, Nil) - val missing = SetOps.greedySearch(lookahead, top, unifyUnion(normB))(differenceAll(_, _))(superSetIsSmaller) + val missing = SetOps.greedySearch(lookahead, top, unifyUnion(normB))( + differenceAll(_, _) + )(superSetIsSmaller) // filter any unreachable, which can happen when earlier items shadow later // ones @@ -111,10 +102,9 @@ trait SetOps[A] extends Relatable[A] { missing.filterNot(unreach.toSet) } - /** - * if we match these branches in order, which of them - * are completely covered by previous matches - */ + /** if we match these branches in order, which of them are completely covered + * by previous matches + */ def unreachableBranches(branches: List[A]): List[A] = unreachableBranches(init = Nil, branches = branches) @@ -127,9 +117,10 @@ trait SetOps[A] extends Relatable[A] { } withPrev(branches, init) - .collect { case (p, prev) if differenceAll(p :: Nil, prev).isEmpty => - // if there is nothing, this is unreachable - p + .collect { + case (p, prev) if differenceAll(p :: Nil, prev).isEmpty => + // if there is nothing, this is unreachable + p } } } @@ -153,7 +144,11 @@ object SetOps { // we search for the best order to apply the diffs that minimizes the score @annotation.tailrec - final def greedySearch[A: Ordering, B](lookahead: Int, union: A, diffs: List[B])(fn: (A, List[B]) => A): A = + final def greedySearch[A: Ordering, B]( + lookahead: Int, + union: A, + diffs: List[B] + )(fn: (A, List[B]) => A): A = diffs match { case Nil => union case _ => @@ -166,7 +161,9 @@ object SetOps { // choose a diff that starts the most // number of results that are the smallest val best = trials - .collect { case (u1, p) if implicitly[Ordering[A]].equiv(u1, smallest._1) => p } + .collect { + case (u1, p) if implicitly[Ordering[A]].equiv(u1, smallest._1) => p + } .groupBy(identity) .map { case (k, v) => (k, v.size) } .maxBy(_._2) @@ -176,7 +173,6 @@ object SetOps { greedySearch(lookahead, u1, diffs.filterNot(_ == best))(fn) } - def distinct[A](implicit ordA: Ordering[A]): SetOps[A] = new SetOps[A] { def top: Option[A] = None @@ -194,7 +190,7 @@ object SetOps { def nub(u: List[A]): List[A] = u match { case Nil | _ :: Nil => u - case h1 :: (t1@(h2 :: _)) => + case h1 :: (t1 @ (h2 :: _)) => if (ordA.equiv(h1, h2)) nub(t1) else h1 :: nub(t1) } @@ -239,7 +235,7 @@ object SetOps { else b.exists(a) } ) - + def relate(a: Set[A], b: Set[A]) = rel.relate(a, b) } @@ -285,7 +281,7 @@ object SetOps { def top: Option[(A, B)] = (sa.top, sb.top) match { case (Some(a), Some(b)) => Some((a, b)) - case _ => None + case _ => None } def isTop(a: (A, B)): Boolean = @@ -349,15 +345,15 @@ object SetOps { def unifyUnion(u: List[(A, B)]): List[(A, B)] = { def step[X, Y](u: List[(X, Y)], sy: SetOps[Y]): Option[List[(X, Y)]] = { var change = false - val u1 = u.groupBy(_._1) + val u1 = u + .groupBy(_._1) .iterator .flatMap { case (x, xys) => val uy = sy.unifyUnion(xys.map(_._2)) if (uy.size < xys.size) { change = true uy.map((x, _)) - } - else xys + } else xys } .toList @@ -369,7 +365,7 @@ object SetOps { step(u, sb) match { case None => step(u.map(_.swap), sa) match { - case None => u + case None => u case Some(u2) => // we got a change unifying a loop(u2.map(_.swap)) diff --git a/core/src/test/scala/org/bykn/bosatsu/CollectionUtilsTest.scala b/core/src/test/scala/org/bykn/bosatsu/CollectionUtilsTest.scala index f812b4e62..6998905bb 100644 --- a/core/src/test/scala/org/bykn/bosatsu/CollectionUtilsTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/CollectionUtilsTest.scala @@ -1,13 +1,16 @@ package org.bykn.bosatsu -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class CollectionUtilsTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) - //PropertyCheckConfiguration(minSuccessful = 5) + // PropertyCheckConfiguration(minSuccessful = 5) test("listToUnique works for maps converted to lists") { forAll { (m: Map[Int, Int]) => diff --git a/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala b/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala index f162145d5..b57f0893c 100644 --- a/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import Identifier.Bindable @@ -14,27 +17,32 @@ class DeclarationTest extends AnyFunSuite { import Generators.shrinkDecl implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 200 else 20) - //PropertyCheckConfiguration(minSuccessful = 50) + // PropertyCheckConfiguration(minSuccessful = 5000) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 200 else 20 + ) + // PropertyCheckConfiguration(minSuccessful = 50) implicit val emptyRegion: Region = Region(0, 0) val genDecl = Generators.genDeclaration(depth = 4) lazy val genNonFree: Gen[Declaration.NonBinding] = - genDecl.flatMap { - case decl: Declaration.NonBinding if decl.freeVars.isEmpty => Gen.const(decl) - case _ => genNonFree - } - + genDecl.flatMap { + case decl: Declaration.NonBinding if decl.freeVars.isEmpty => + Gen.const(decl) + case _ => genNonFree + } test("freeVarsSet is a subset of allVars") { forAll(genDecl) { decl => val frees = decl.freeVars val av = decl.allNames val missing = frees -- av - assert(missing.isEmpty, s"expression:\n\n${decl}\n\nallVars: $av\n\nfrees: $frees") + assert( + missing.isEmpty, + s"expression:\n\n${decl}\n\nallVars: $av\n\nfrees: $frees" + ) } } @@ -53,8 +61,10 @@ class DeclarationTest extends AnyFunSuite { val d1Str = d1.toDoc.render(80) val dSubStr = d0sub.toDoc.render(80) - assert(!d0sub.freeVars.contains(b), - s"subs:\n\n$d0Str\n\n===============\n\n$d1Str===============\n\n$dSubStr") + assert( + !d0sub.freeVars.contains(b), + s"subs:\n\n$d0Str\n\n===============\n\n$d1Str===============\n\n$dSubStr" + ) } } } @@ -67,11 +77,11 @@ class DeclarationTest extends AnyFunSuite { lazy val notFree: Gen[Bindable] = Generators.bindIdentGen.flatMap { case b if frees(b) => notFree - case b => Gen.const(b) + case b => Gen.const(b) } notFree.map((decl, _)) - } + } def law(b: Bindable, d1: Declaration.NonBinding, d0: Declaration) = { val frees = d0.freeVars @@ -90,8 +100,7 @@ class DeclarationTest extends AnyFunSuite { // there must be some diff val diffPos = - left - .iterator + left.iterator .zip(right.iterator) .zipWithIndex .dropWhile { case ((a, b), _) => a == b } @@ -102,7 +111,8 @@ class DeclarationTest extends AnyFunSuite { val leftAt = left.drop(diffPos).take(50) val rightAt = right.drop(diffPos).take(50) val diff = s"offset: $diffPos$line$leftAt\n\n$line$rightAt" - val msg = s"left$line${left}\n\nright$line$right\n\ndiff$line$diff" + val msg = + s"left$line${left}\n\nright$line$right\n\ndiff$line$diff" assert(false, msg) } } @@ -123,10 +133,74 @@ class DeclarationTest extends AnyFunSuite { val b = Identifier.Backticked("") val d1 = Literal(Lit.fromInt(0)) val d0 = DefFn( - DefStatement(Name("mfLjwok"),None, NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("foo")))),None, - (NotSameLine(Padding(10,Indented(10,Var(Backticked(""))))), - Padding(10,Binding(BindingStatement( - Pattern.Var(Backticked("")),Var(Constructor("Rgt")),Padding(1,DefFn(DefStatement(Backticked(""),None,NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("bar")))),None,(NotSameLine(Padding(2,Indented(4,Literal(Lit.fromInt(42))))),Padding(2,DefFn(DefStatement(Name("gkxAckqpatu"),None, NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("quux")))),Some(TypeRef.TypeName(TypeName(Constructor("Y")))),(NotSameLine(Padding(6,Indented(8,Literal(Lit("oimsu"))))),Padding(2,Var(Name("j"))))))))))))))))) + DefStatement( + Name("mfLjwok"), + None, + NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("foo")))), + None, + ( + NotSameLine(Padding(10, Indented(10, Var(Backticked(""))))), + Padding( + 10, + Binding( + BindingStatement( + Pattern.Var(Backticked("")), + Var(Constructor("Rgt")), + Padding( + 1, + DefFn( + DefStatement( + Backticked(""), + None, + NonEmptyList.one( + NonEmptyList.one(Pattern.Var(Name("bar"))) + ), + None, + ( + NotSameLine( + Padding( + 2, + Indented(4, Literal(Lit.fromInt(42))) + ) + ), + Padding( + 2, + DefFn( + DefStatement( + Name("gkxAckqpatu"), + None, + NonEmptyList.one( + NonEmptyList.one( + Pattern.Var(Name("quux")) + ) + ), + Some( + TypeRef.TypeName( + TypeName(Constructor("Y")) + ) + ), + ( + NotSameLine( + Padding( + 6, + Indented(8, Literal(Lit("oimsu"))) + ) + ), + Padding(2, Var(Name("j"))) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) (b, d1, d0) } @@ -141,7 +215,7 @@ class DeclarationTest extends AnyFunSuite { genDecl.flatMap { decl => val frees = decl.freeVars.toList frees match { - case Nil => genFrees + case Nil => genFrees case nonEmpty => Gen.oneOf(nonEmpty).map((decl, _)) } } @@ -178,13 +252,17 @@ class DeclarationTest extends AnyFunSuite { val resD = res.map(unsafeParse(Declaration.parser(""), _)) val b = unsafeParse(Identifier.bindableParser, bStr) - assert(Declaration.substitute(b, d1.toNonBinding, d0) == resD) } - law("b", "12", """x = b -x""", Some("""x = 12 -x""")) + law( + "b", + "12", + """x = b +x""", + Some("""x = 12 +x""") + ) law("b", "12", """[x for b in y]""", Some("""[x for b in y]""")) law("b", "12", """[b for z in y]""", Some("""[12 for z in y]""")) @@ -209,8 +287,16 @@ x""")) law("[a for b in c if b]", List("a", "c"), List("a", "b", "c")) law("[b for b in c if d]", List("c", "d"), List("b", "c", "d")) law("[b for b in c if b]", List("c"), List("b", "c")) - law("{ k: a for b in c if d}", List("k", "a", "c", "d"), List("k", "a", "b", "c", "d")) - law("{ k: a for b in c if b}", List("k", "a", "c"), List("k", "a", "b", "c")) + law( + "{ k: a for b in c if d}", + List("k", "a", "c", "d"), + List("k", "a", "b", "c", "d") + ) + law( + "{ k: a for b in c if b}", + List("k", "a", "c"), + List("k", "a", "b", "c") + ) law("Foo { a }", List("a"), List("a")) law("Foo { a: b }", List("b"), List("b")) } diff --git a/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala b/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala index e6c4e6cba..840514408 100644 --- a/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala @@ -19,7 +19,7 @@ class DefRecursionCheckTest extends AnyFunSuite { def disallowed(teStr: String) = { val stmt = TestUtils.statementsOf(teStr) stmt.traverse_(DefRecursionCheck.checkStatement(_)) match { - case Validated.Valid(_) => fail("expected failure") + case Validated.Valid(_) => fail("expected failure") case Validated.Invalid(_) => succeed } } @@ -376,7 +376,7 @@ def nest(lst): } test("we can't use an outer def recursively") { -disallowed("""# + disallowed("""# def foo(x): def bar(y): foo(y) @@ -385,7 +385,7 @@ def foo(x): } test("we can make a recursive def in another recursive def") { -allowed("""# + allowed("""# def len(lst): # this is doing nothing, but is a nested recursion def len0(lst): @@ -400,7 +400,7 @@ def len(lst): } test("we can call a non-outer function in a recur branch") { -allowed("""# + allowed("""# def id(x): x def len(lst): diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index 04059694a..2b7316c19 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -15,7 +15,10 @@ class EvaluationTest extends AnyFunSuite with ParTest { package Foo x = 1 -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) evalTest(List("x = 1"), "Package0", VInt(1)) @@ -23,7 +26,10 @@ x = 1 List(""" # test shadowing x = match 1: case x: x -"""), "Package0", VInt(1)) +"""), + "Package0", + VInt(1) + ) evalTest( List(""" @@ -31,7 +37,10 @@ package Foo # exercise calling directly a lambda x = (y -> y)("hello") -"""), "Foo", Str("hello")) +"""), + "Foo", + Str("hello") + ) runBosatsuTest( List(""" @@ -45,7 +54,10 @@ def eq_String(a, b): case _: False test = Assertion(eq_String("hello", foo), "checking equality") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) runBosatsuTest( List(""" @@ -59,7 +71,10 @@ foo = ( ) test = Assertion(foo matches 4, "checking equality") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) runBosatsuTest( List(""" @@ -69,7 +84,10 @@ test = TestSuite("three trivial tests", [ Assertion(True, "t0"), Assertion(True, "t1"), Assertion(True, "t2"), ]) -"""), "Foo", 3) +"""), + "Foo", + 3 + ) } test("test if/else") { @@ -84,7 +102,10 @@ z = match x.cmp_Int(1): "foo" case _: "bar" -"""), "Foo", Str("foo")) +"""), + "Foo", + Str("foo") + ) evalTest( List(""" @@ -94,7 +115,10 @@ x = 1 # here if the single expression python style z = "foo" if x.eq_Int(2) else "bar" -"""), "Foo", Str("bar")) +"""), + "Foo", + Str("bar") + ) } test("exercise option from predef") { @@ -107,7 +131,10 @@ x = Some(1) z = match x: case Some(v): add(v, 10) case None: 0 -"""), "Foo", VInt(11)) +"""), + "Foo", + VInt(11) + ) // Use a local name collision and see it not have a problem evalTest( @@ -121,7 +148,10 @@ x = Some(1) z = match x: case Some(v): add(v, 10) case None: 0 -"""), "Foo", VInt(11)) +"""), + "Foo", + VInt(11) + ) evalTest( List(""" @@ -135,7 +165,10 @@ x = Some(1) z = match x: case None: 0 case Some(v): add(v, 10) -"""), "Foo", VInt(11)) +"""), + "Foo", + VInt(11) + ) } test("test matching unions") { @@ -150,7 +183,10 @@ x = Pair(Pair(1, "1"), "2") main = match x: Pair(_, "2" | "3"): "good" _: "bad" -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) evalTest( List("""package Foo @@ -164,7 +200,10 @@ def run(z): y main = run(x) -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) evalTest( List(""" @@ -179,10 +218,12 @@ def run(z): y main = run(x) -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) - evalFail( - List(""" + evalFail(List(""" package Err enum IntOrString: IntCase(i: Int), StringCase(i: Int, s: String) @@ -207,10 +248,15 @@ def go(x): main = go(IntCase(42)) """ - val packs = Map((PackageName.parts("Err"), (LocationMap(errPack), "Err.bosatsu"))) - evalFail(List(errPack)) { case te@PackageError.TypeErrorIn(_, _) => + val packs = + Map((PackageName.parts("Err"), (LocationMap(errPack), "Err.bosatsu"))) + evalFail(List(errPack)) { case te @ PackageError.TypeErrorIn(_, _) => val msg = te.message(packs, Colorize.None) - assert(msg.contains("type error: expected type Bosatsu/Predef::Int to be the same as type Bosatsu/Predef::String")) + assert( + msg.contains( + "type error: expected type Bosatsu/Predef::Int to be the same as type Bosatsu/Predef::String" + ) + ) () } @@ -227,7 +273,10 @@ def go(x): 42 main = go(IntCase(42)) -"""), "Union", VInt(42)) +"""), + "Union", + VInt(42) + ) } test("test matching literals") { @@ -240,7 +289,10 @@ x = 1 main = match x: case 1: "good" case _: "bad" -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) evalTest( List(""" @@ -252,7 +304,10 @@ x = [1] main = match x: EmptyList: "empty" NonEmptyList(...): "notempty" -"""), "Foo", Str("notempty")) +"""), + "Foo", + Str("notempty") + ) evalTest( List(""" @@ -263,7 +318,10 @@ x = "1" main = match x: case "1": "good" case _: "bad" -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) evalTest( List(""" @@ -276,7 +334,10 @@ x = Pair(1, "1") main = match x: case Pair(_, "1"): "good" case _: "bad" -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) } test("test tuples") { @@ -289,7 +350,10 @@ x = (1, "1") main = match x: case (_, "1"): "good" case _: "bad" -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) evalTest( List(""" @@ -306,7 +370,10 @@ def go(u): case _: "bad" main = go(()) -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) } test("do a fold") { @@ -323,7 +390,10 @@ sum0 = sum(three) sum1 = three.foldLeft(0, (x, y) -> add(x, y)) same = sum0.eq_Int(sum1) -"""), "Foo", True) +"""), + "Foo", + True + ) evalTest( List(""" @@ -335,7 +405,10 @@ sum0 = three.foldLeft(0, add) sum1 = three.foldLeft(0, \x, y -> add(x, y)) same = sum0.eq_Int(sum1) -"""), "Foo", True) +"""), + "Foo", + True + ) } @@ -345,7 +418,10 @@ same = sum0.eq_Int(sum1) package Foo main = 6.mod_Int(4) -"""), "Foo", VInt(2)) +"""), + "Foo", + VInt(2) + ) evalTest( List(""" @@ -355,14 +431,20 @@ main = match 6.div(4): case 0: 42 case 1: 100 case x: x -"""), "Foo", VInt(100)) +"""), + "Foo", + VInt(100) + ) evalTest( List(""" package Foo main = 6.gcd_Int(3) -"""), "Foo", VInt(3)) +"""), + "Foo", + VInt(3) + ) } test("use range") { @@ -396,10 +478,13 @@ def eq_list(a, b, fn): same_items(zip(a, b), fn) same = eq_list(three, threer, eq_Int) -"""), "Foo", True) +"""), + "Foo", + True + ) -evalTest( - List(""" + evalTest( + List(""" package Foo def zip(as: List[a], bs: List[b]) -> List[(a, b)]: @@ -411,31 +496,43 @@ def zip(as: List[a], bs: List[b]) -> List[(a, b)]: case [bh, *btail]: [(ah, bh), *zip(atail, btail)] main = 1 -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) } test("test range_fold") { -evalTest( - List(""" + evalTest( + List(""" package Foo main = range_fold(0, 10, 0, add) -"""), "Foo", VInt(45)) +"""), + "Foo", + VInt(45) + ) -evalTest( - List(""" + evalTest( + List(""" package Foo main = range_fold(0, 10, 0, (_, y) -> y) -"""), "Foo", VInt(9)) +"""), + "Foo", + VInt(9) + ) -evalTest( - List(""" + evalTest( + List(""" package Foo main = range_fold(0, 10, 100, (x, _) -> x) -"""), "Foo", VInt(100)) +"""), + "Foo", + VInt(100) + ) } test("test some list matches") { @@ -449,7 +546,10 @@ def headOption(as): case [a, *_]: Some(a) main = headOption([1]) -"""), "Foo", SumValue(1, ProductValue.single(VInt(1)))) +"""), + "Foo", + SumValue(1, ProductValue.single(VInt(1))) + ) runBosatsuTest( List(""" @@ -467,7 +567,10 @@ test = TestSuite("exists", [ Assertion(not(exists([])), "![]"), Assertion(not(exists([False])), "![False]"), ]) -"""), "Foo", 5) +"""), + "Foo", + 5 + ) } test("test generics in defs") { @@ -479,7 +582,10 @@ def id(x: a) -> a: x main = id(1) -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) } test("exercise struct creation") { @@ -490,8 +596,10 @@ package Foo struct Bar(a: Int) main = Bar(1) -"""), "Foo", - VInt(1)) +"""), + "Foo", + VInt(1) + ) evalTest( List(""" @@ -501,7 +609,10 @@ struct Bar(a: Int) # destructuring top-level let Bar(main) = Bar(1) -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) evalTest( List(""" @@ -511,7 +622,10 @@ struct Bar(a: Int) # destructuring top-level let Bar(main: Int) = Bar(1) -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) evalTest( List(""" @@ -522,7 +636,10 @@ struct Bar(a: Int) y = Bar(1) # destructuring top-level let Bar(main: Int) = y -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) evalTestJson( List(""" @@ -531,12 +648,16 @@ package Foo struct Bar(a: Int, s: String) main = Bar(1, "foo") -"""), "Foo", Json.JObject(List("a" -> Json.JNumberStr("1"), "s" -> Json.JString("foo")))) +"""), + "Foo", + Json.JObject( + List("a" -> Json.JNumberStr("1"), "s" -> Json.JString("foo")) + ) + ) } test("test some type errors") { - evalFail( - List(""" + evalFail(List(""" package Foo main = if True: @@ -546,7 +667,9 @@ else: """)) { case PackageError.TypeErrorIn(_, _) => () } } - test("test the list literals work even when we have conflicting local names") { + test( + "test the list literals work even when we have conflicting local names" + ) { evalTest( List(""" package Foo @@ -554,8 +677,10 @@ package Foo struct EmptyList main = [1, 2] -"""), "Foo", - VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil))) +"""), + "Foo", + VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil)) + ) evalTest( List(""" @@ -564,8 +689,10 @@ package Foo struct NonEmptyList main = [1, 2] -"""), "Foo", - VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil))) +"""), + "Foo", + VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil)) + ) evalTest( List(""" @@ -574,13 +701,14 @@ package Foo def concat(a): a main = [1, 2] -"""), "Foo", - VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil))) +"""), + "Foo", + VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil)) + ) } test("forbid the y-combinator") { - evalFail( - List(""" + evalFail(List(""" package Y struct W(fn: W[a, b] -> a -> b) @@ -599,29 +727,27 @@ def ltEqZero(i): fac = trace("made fac", y(\f, i -> 1 if ltEqZero(i) else f(i).times(i))) main = fac(6) -""")) { - case PackageError.KindInferenceError(_, _, _) => () - } +""")) { case PackageError.KindInferenceError(_, _, _) => + () + } - evalFail( - List(""" + evalFail(List(""" package Y struct W(wf: f[a, b] -> a -> b) def apply(w): W(fn) = w fn(w) -""")) { - case err@PackageError.TypeErrorIn(_, _) => +""")) { case err @ PackageError.TypeErrorIn(_, _) => val message = err.message(Map.empty, Colorize.None) assert(message.contains("illegal recursive type or function")) () - } + } } test("check type aligned enum") { - evalTest( - List(""" + evalTest( + List(""" package A enum GoodOrBad: @@ -633,22 +759,27 @@ def unbox(gb: GoodOrBad[a]): case Bad(b): b (main: Int) = unbox(Good(42)) -"""), "A", VInt(42)) +"""), + "A", + VInt(42) + ) - evalTest( - List(""" + evalTest( + List(""" package A enum GoodOrBad: Bad(a: a), Good(a: a) Bad(main) | Good(main) = Good(42) -"""), "A", VInt(42)) +"""), + "A", + VInt(42) + ) } test("nontotal matches fail even if not at runtime") { - evalFail( - List(""" + evalFail(List(""" package Total enum Opt: Nope, Yep(get) @@ -664,8 +795,7 @@ main = one } test("unreachable patterns are an error") { - evalFail( - List(""" + evalFail(List(""" package Total enum Opt: Nope, Yep(get) @@ -683,8 +813,8 @@ main = one } test("Leibniz type equality example") { - evalTest( - List(""" + evalTest( + List(""" package A struct Leib(subst: forall f: * -> *. f[a] -> f[b]) @@ -713,11 +843,13 @@ def getValue(v: StringOrInt[a]) -> a: case IsInt(i, leib): coerce(i, leib) main = getValue(int) -"""), "A", VInt(42)) +"""), + "A", + VInt(42) + ) - // If we leave out the coerce it fails - evalFail( - List(""" + // If we leave out the coerce it fails + evalFail(List(""" package A struct Leib(subst: forall f. f[a] -> f[b]) @@ -739,13 +871,12 @@ def getValue(v): case IsInt(i, _): i main = getValue(int) -""")){ case PackageError.TypeErrorIn(_, _) => () } +""")) { case PackageError.TypeErrorIn(_, _) => () } } test("overly generic methods fail compilation") { - evalFail( - List(""" + evalFail(List(""" package A # this shouldn't compile, a is too generic @@ -753,12 +884,11 @@ def plus(x: a, y): x.add(y) main = plus(1, 2) -""")){ case PackageError.TypeErrorIn(_, _) => () } +""")) { case PackageError.TypeErrorIn(_, _) => () } } test("unused let fails compilation") { - evalFail( - List(""" + evalFail(List(""" package A # this shouldn't compile, z is unused @@ -767,7 +897,7 @@ def plus(x, y): x.add(y) main = plus(1, 2) -""")){ case le@PackageError.UnusedLetError(_, _) => +""")) { case le @ PackageError.UnusedLetError(_, _) => val msg = le.message(Map.empty, Colorize.None) assert(!msg.contains("Name(")) assert(msg.contains("unused let binding: z\n Region(68,73)")) @@ -776,8 +906,8 @@ main = plus(1, 2) } test("structual recursion is allowed") { - evalTest( - List(""" + evalTest( + List(""" package A def len(lst, acc): @@ -786,10 +916,13 @@ def len(lst, acc): [_, *tail]: len(tail, acc.add(1)) main = len([1, 2, 3], 0) -"""), "A", VInt(3)) +"""), + "A", + VInt(3) + ) - evalTest( - List(""" + evalTest( + List(""" package A enum PNat: One, Even(of: PNat), Odd(of: PNat) @@ -801,10 +934,12 @@ def toInt(pnat): Odd(of): toInt(of).times(2).add(1) main = toInt(Even(Even(One))) -"""), "A", VInt(4)) +"""), + "A", + VInt(4) + ) - evalFail( - List(""" + evalFail(List(""" package A enum Foo: Bar, Baz @@ -815,20 +950,23 @@ def bad(foo): baz: bad(baz) main = bad(Bar) -""")){ case PackageError.RecursionError(_, _) => () } +""")) { case PackageError.RecursionError(_, _) => () } - evalTest( - List(""" + evalTest( + List(""" package A big_list = range(3_000) main = big_list.foldLeft(0, add) -"""), "A", VInt((0 until 3000).sum)) +"""), + "A", + VInt((0 until 3000).sum) + ) - def sumFn(n: Int): Int = if (n <= 0) 0 else { sumFn(n-1) + n } - evalTest( - List(""" + def sumFn(n: Int): Int = if (n <= 0) 0 else { sumFn(n - 1) + n } + evalTest( + List(""" package A enum Nat: Zero, Succ(of: Nat) @@ -844,11 +982,14 @@ def sum(nat): Succ(n): sum(n).add(toInt(nat)) main = sum(Succ(Succ(Succ(Zero)))) -"""), "A", VInt(sumFn(3))) +"""), + "A", + VInt(sumFn(3)) + ) - // try with Succ first in the Nat - evalTest( - List(""" + // try with Succ first in the Nat + evalTest( + List(""" package A enum Nat: Zero, Succ(of: Nat) @@ -864,12 +1005,15 @@ def sum(nat): Zero: 0 main = sum(Succ(Succ(Succ(Zero)))) -"""), "A", VInt(sumFn(3))) +"""), + "A", + VInt(sumFn(3)) + ) } test("we can mix literal and enum forms of List") { - evalTest( - List(""" + evalTest( + List(""" package A def len(lst, acc): @@ -878,9 +1022,12 @@ def len(lst, acc): [_, *tail]: len(tail, acc.add(1)) main = len([1, 2, 3], 0) -"""), "A", VInt(3)) - evalTest( - List(""" +"""), + "A", + VInt(3) + ) + evalTest( + List(""" package A def len(lst, acc): @@ -889,47 +1036,65 @@ def len(lst, acc): NonEmptyList(_, tail): len(tail, acc.add(1)) main = len([1, 2, 3], 0) -"""), "A", VInt(3)) +"""), + "A", + VInt(3) + ) } test("list comphension test") { - evalTest( - List(""" + evalTest( + List(""" package A main = [x for x in range(4)].foldLeft(0, add) -"""), "A", VInt(6)) - evalTest( - List(""" +"""), + "A", + VInt(6) + ) + evalTest( + List(""" package A main = [*[x] for x in range(4)].foldLeft(0, add) -"""), "A", VInt(6)) +"""), + "A", + VInt(6) + ) - evalTest( - List(""" + evalTest( + List(""" package A doub = [(x, x) for x in range(4)] main = [x.times(y) for (x, y) in doub].foldLeft(0, add) -"""), "A", VInt(1 + 4 + 9)) - evalTest( - List(""" +"""), + "A", + VInt(1 + 4 + 9) + ) + evalTest( + List(""" package A main = [x for x in range(4) if x.eq_Int(2)].foldLeft(0, add) -"""), "A", VInt(2)) +"""), + "A", + VInt(2) + ) - evalTest( - List(""" + evalTest( + List(""" package A main = [*[x, x] for x in range(4) if x.eq_Int(2)].foldLeft(0, add) -"""), "A", VInt(4)) +"""), + "A", + VInt(4) + ) - evalTest( - List(""" + evalTest( + List(""" package A def eq_List(lst1, lst2): @@ -951,12 +1116,15 @@ lst3 = [*[y, y] for (_, y) in [(x, x) for x in range(4)]] main = match (eq_List(lst1, lst2), eq_List(lst1, lst3)): case (True, True): 1 case _ : 0 -"""), "A", VInt(1)) +"""), + "A", + VInt(1) + ) } test("test fib using recursion") { - evalTest( - List(""" + evalTest( + List(""" package A enum Nat: Z, S(p: Nat) @@ -969,10 +1137,13 @@ def fib(n): # fib(5) = 1, 1, 2, 3, 5, 8 main = fib(S(S(S(S(S(Z)))))) -"""), "A", VInt(8)) +"""), + "A", + VInt(8) + ) - evalTest( - List(""" + evalTest( + List(""" package A enum Nat[a]: Z, S(p: Nat[a]) @@ -985,10 +1156,13 @@ def fib(n): # fib(5) = 1, 1, 2, 3, 5, 8 main = fib(S(S(S(S(S(Z)))))) -"""), "A", VInt(8)) +"""), + "A", + VInt(8) + ) - evalTest( - List(""" + evalTest( + List(""" package A enum Nat: S(p: Nat), Z @@ -1001,11 +1175,15 @@ def fib(n): # fib(5) = 1, 1, 2, 3, 5, 8 main = fib(S(S(S(S(S(Z)))))) -"""), "A", VInt(8)) +"""), + "A", + VInt(8) + ) } test("test matching the front of a list") { - evalTest(List(""" + evalTest( + List(""" package A def bad_len(list): @@ -1014,9 +1192,13 @@ def bad_len(list): case [*init, _]: bad_len(init).add(1) main = bad_len([1, 2, 3, 5]) -"""), "A", VInt(4)) +"""), + "A", + VInt(4) + ) - evalTest(List(""" + evalTest( + List(""" package A def last(list): @@ -1025,10 +1207,14 @@ def last(list): case [*_, s]: s main = last([1, 2, 3, 5]) -"""), "A", VInt(5)) +"""), + "A", + VInt(5) + ) } test("test a named pattern that doesn't match") { - evalTest(List(""" + evalTest( + List(""" package A def bad_len(list): @@ -1046,10 +1232,14 @@ def bad_len(list): bad_len(init).add(1) main = bad_len([1, 2, 3, 5]) -"""), "A", VInt(4)) +"""), + "A", + VInt(4) + ) } test("uncurry2") { - evalTest(List(""" + evalTest( + List(""" package A struct TwoVar(one, two) @@ -1059,10 +1249,14 @@ constructed = uncurry2(x -> y -> TwoVar(x, y))(1, "two") main = match constructed: case TwoVar(1, "two"): "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) } test("uncurry3") { - evalTest(List(""" + evalTest( + List(""" package A struct ThreeVar(one, two, three) @@ -1072,11 +1266,15 @@ constructed = uncurry3(x -> y -> z -> ThreeVar(x, y, z))(1, "two", 3) main = match constructed: case ThreeVar(1, "two", 3): "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) } test("Dict methods") { - evalTest(List(""" + evalTest( + List(""" package A e = empty_Dict(string_Order) @@ -1084,9 +1282,13 @@ e = empty_Dict(string_Order) e1 = e.add_key("hello", "world") main = e1.get_key("hello") -"""), "A", VOption.some(Str("world"))) +"""), + "A", + VOption.some(Str("world")) + ) - evalTest(List(""" + evalTest( + List(""" package A e = empty_Dict(string_Order) @@ -1094,9 +1296,13 @@ e = empty_Dict(string_Order) e1 = e.clear_Dict().add_key("hello2", "world2") main = e1.get_key("hello") -"""), "A", VOption.none) +"""), + "A", + VOption.none + ) - evalTest(List(""" + evalTest( + List(""" package A e = empty_Dict(string_Order) @@ -1105,9 +1311,13 @@ e1 = e.add_key("hello", "world") e2 = e1.remove_key("hello") main = e2.get_key("hello") -"""), "A", VOption.none) +"""), + "A", + VOption.none + ) - evalTest(List(""" + evalTest( + List(""" package A e1 = empty_Dict(string_Order) @@ -1117,9 +1327,13 @@ lst = e2.items() main = match lst: case [("hello", "world"), ("hello1", "world1")]: "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) - evalTest(List(""" + evalTest( + List(""" package A e1 = {} @@ -1129,9 +1343,13 @@ lst = e2.items() main = match lst: case [("hello", "world"), ("hello1", "world1")]: "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) - evalTest(List(""" + evalTest( + List(""" package A e = { @@ -1143,9 +1361,13 @@ lst = e.items() main = match lst: case [("hello", "world"), ("hello1", "world1")]: "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) - evalTest(List(""" + evalTest( + List(""" package A pairs = [("hello", "world"), ("hello1", "world1")] @@ -1156,9 +1378,13 @@ lst = e.items() main = match lst: case [("hello", "world"), ("hello1", "world1")]: "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) - evalTest(List(""" + evalTest( + List(""" package A pairs = [("hello", 42), ("hello1", 24)] @@ -1174,7 +1400,10 @@ lst = e.items() main = match lst: case [("hello", res)]: res case _: -1 -"""), "A", VInt(42)) +"""), + "A", + VInt(42) + ) evalTestJson( List(""" @@ -1183,7 +1412,10 @@ package Foo bar = {'a': '1', 's': 'foo' } main = bar -"""), "Foo", Json.JObject(List("a" -> Json.JString("1"), "s" -> Json.JString("foo")))) +"""), + "Foo", + Json.JObject(List("a" -> Json.JString("1"), "s" -> Json.JString("foo"))) + ) evalTestJson( List(""" @@ -1193,7 +1425,10 @@ package Foo bar: Dict[String, Option[Int]] = {'a': None, 's': None } main = bar -"""), "Foo", Json.JObject(List("a" -> Json.JNull, "s" -> Json.JNull))) +"""), + "Foo", + Json.JObject(List("a" -> Json.JNull, "s" -> Json.JNull)) + ) evalTestJson( List(""" @@ -1202,7 +1437,10 @@ package Foo bar = {'a': None, 's': Some(1) } main = bar -"""), "Foo", Json.JObject(List("a" -> Json.JNull, "s" -> Json.JNumberStr("1")))) +"""), + "Foo", + Json.JObject(List("a" -> Json.JNull, "s" -> Json.JNumberStr("1"))) + ) evalTestJson( List(""" @@ -1211,9 +1449,15 @@ package Foo bar = {'a': [], 's': [1] } main = bar -"""), "Foo", Json.JObject( - List("a" -> Json.JArray(Vector.empty), - "s" -> Json.JArray(Vector(Json.JNumberStr("1")))))) +"""), + "Foo", + Json.JObject( + List( + "a" -> Json.JArray(Vector.empty), + "s" -> Json.JArray(Vector(Json.JNumberStr("1"))) + ) + ) + ) evalTestJson( List(""" @@ -1222,32 +1466,36 @@ package Foo bar = {'a': True, 's': False } main = bar -"""), "Foo", Json.JObject( - List("a" -> Json.JBool(true), - "s" -> Json.JBool(false)))) +"""), + "Foo", + Json.JObject(List("a" -> Json.JBool(true), "s" -> Json.JBool(false))) + ) evalTestJson( List(""" package Foo main = (1, "1", ()) -"""), "Foo", Json.JArray( - Vector(Json.JNumberStr("1"), - Json.JString("1"), - Json.JNull))) +"""), + "Foo", + Json.JArray(Vector(Json.JNumberStr("1"), Json.JString("1"), Json.JNull)) + ) evalTestJson( List(""" package Foo main = [Some(Some(1)), Some(None), None] -"""), "Foo", - Json.JArray( - Vector( - Json.JArray(Vector(Json.JNumberStr("1"))), - Json.JArray(Vector(Json.JNull)), - Json.JArray(Vector.empty) - ))) +"""), + "Foo", + Json.JArray( + Vector( + Json.JArray(Vector(Json.JNumberStr("1"))), + Json.JArray(Vector(Json.JNull)), + Json.JArray(Vector.empty) + ) + ) + ) evalTestJson( List(""" @@ -1256,13 +1504,15 @@ package Foo enum FooBar: Foo(foo), Bar(bar) main = [Foo(1), Bar("1")] -"""), "Foo", - Json.JArray( - Vector( - Json.JObject( - List("foo" -> Json.JNumberStr("1"))), - Json.JObject( - List("bar" -> Json.JString("1")))))) +"""), + "Foo", + Json.JArray( + Vector( + Json.JObject(List("foo" -> Json.JNumberStr("1"))), + Json.JObject(List("bar" -> Json.JString("1"))) + ) + ) + ) } test("json handling of Nat special case") { @@ -1273,12 +1523,12 @@ package Foo enum Nat: Z, S(n: Nat) main = [Z, S(Z), S(S(Z))] -"""), "Foo", - Json.JArray( - Vector( - Json.JNumberStr("0"), - Json.JNumberStr("1"), - Json.JNumberStr("2")))) +"""), + "Foo", + Json.JArray( + Vector(Json.JNumberStr("0"), Json.JNumberStr("1"), Json.JNumberStr("2")) + ) + ) } test("json with backticks") { @@ -1291,29 +1541,37 @@ struct Foo(`struct`, `second key`, `enum`, `def`) `package` = 2 main = Foo(1, `package`, 3, 4) -"""), "Foo", - Json.JObject( - List( - ("struct" -> Json.JNumberStr("1")), - ("second key" -> Json.JNumberStr("2")), - ("enum" -> Json.JNumberStr("3")), - ("def" -> Json.JNumberStr("4"))) - )) +"""), + "Foo", + Json.JObject( + List( + ("struct" -> Json.JNumberStr("1")), + ("second key" -> Json.JNumberStr("2")), + ("enum" -> Json.JNumberStr("3")), + ("def" -> Json.JNumberStr("4")) + ) + ) + ) } test("test operators") { - evalTest(List(""" + evalTest( + List(""" package A operator + = add operator * = times main = 1 + 2 * 3 -"""), "A", VInt(7)) +"""), + "A", + VInt(7) + ) } test("patterns in lambdas") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A # you can't write \x: Int -> x.add(1) @@ -1323,17 +1581,25 @@ inc: Int -> Int = x -> x.add(1) inc2: Int -> Int = (x: Int) -> x.add(1) test = Assertion(inc(1).eq_Int(inc2(1)), "inc(1) == 2") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A def inc(x: Int): x.add(1) test = Assertion(inc(1).eq_Int(2), "inc(1) == 2") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A struct Foo(v) @@ -1360,9 +1626,13 @@ test5 = Assertion(inc4(Pair(F(1), Foo(1))).eq_Int(2), "inc4(Pair(F(1), Foo(1))) test6 = Assertion(inc4(Pair(B(1), Foo(1))).eq_Int(2), "inc4(Pair(B(1), Foo(1))) == 2") suite = TestSuite("match tests", [test0, test1, test2, test3, test4, test5, test6]) -"""), "A", 7) +"""), + "A", + 7 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A struct Foo(v) @@ -1389,96 +1659,111 @@ test5 = Assertion(inc4(Pair(F(1), Foo(1))).eq_Int(2), "inc4(Pair(F(1), Foo(1))) test6 = Assertion(inc4(Pair(B(1), Foo(1))).eq_Int(2), "inc4(Pair(B(1), Foo(1))) == 2") suite = TestSuite("match tests", [test0, test1, test2, test3, test4, test5, test6]) -"""), "A", 7) +"""), + "A", + 7 + ) } test("test some error messages") { evalFail( - List(""" + List( + """ package A a = 1 -""", """ +""", + """ package B from A import a -main = a""")) { case PackageError.UnknownImportName(_, _, _, _, _) => () } +main = a""" + ) + ) { case PackageError.UnknownImportName(_, _, _, _, _) => () } - evalFail( - List(""" + evalFail(List(""" package B from A import a main = a""")) { case PackageError.UnknownImportPackage(_, _) => () } - evalFail( - List(""" + evalFail(List(""" package B -main = a""")) { case te@PackageError.TypeErrorIn(_, _) => - val msg = te.message(Map.empty, Colorize.None) - assert(!msg.contains("Name(")) - assert(msg.contains("package B\nname \"a\" unknown")) - () - } +main = a""")) { case te @ PackageError.TypeErrorIn(_, _) => + val msg = te.message(Map.empty, Colorize.None) + assert(!msg.contains("Name(")) + assert(msg.contains("package B\nname \"a\" unknown")) + () + } - evalFail( - List(""" + evalFail(List(""" package B x = 1 main = match x: case Foo: 2 -""")) { case te@PackageError.SourceConverterErrorIn(_, _) => - val msg = te.message(Map.empty, Colorize.None) - assert(!msg.contains("Name(")) - assert(msg.contains("package B\nunknown constructor Foo")) - () - } +""")) { case te @ PackageError.SourceConverterErrorIn(_, _) => + val msg = te.message(Map.empty, Colorize.None) + assert(!msg.contains("Name(")) + assert(msg.contains("package B\nunknown constructor Foo")) + () + } - evalFail( - List(""" + evalFail(List(""" package B struct X main = match 1: case X1: 0 -""")) { case te@PackageError.SourceConverterErrorIn(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package B\nunknown constructor X1\nRegion(49,50)") +""")) { case te @ PackageError.SourceConverterErrorIn(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package B\nunknown constructor X1\nRegion(49,50)" + ) () } - evalFail( - List(""" + evalFail(List(""" package A main = match [1, 2, 3]: case []: 0 case [*a, *b, _]: 2 -""")) { case te@PackageError.TotalityCheckError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nRegion(19,70)\nmultiple splices in pattern, only one per match allowed") +""")) { case te @ PackageError.TotalityCheckError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nRegion(19,70)\nmultiple splices in pattern, only one per match allowed" + ) () } - evalFail( - List(""" + evalFail(List(""" package A enum Foo: Bar(a), Baz(b) main = match Bar(a): case Baz(b): b -""")) { case te@PackageError.TotalityCheckError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nRegion(45,75)\nnon-total match, missing: Bar(_)") +""")) { case te @ PackageError.TotalityCheckError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nRegion(45,75)\nnon-total match, missing: Bar(_)" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x): @@ -1486,13 +1771,17 @@ def fn(x): y: 0 main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nrecur but no recursive call to fn\nRegion(25,42)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nrecur but no recursive call to fn\nRegion(25,42)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x): @@ -1500,13 +1789,17 @@ def fn(x): y: 0 main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nrecur not on an argument to the def of fn, args: (x)\nRegion(25,43)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nrecur not on an argument to the def of fn, args: (x)\nRegion(25,43)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x): @@ -1514,13 +1807,17 @@ def fn(x): y: 0 main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nrecur not on an argument to the def of fn, args: (x)\nRegion(25,42)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nrecur not on an argument to the def of fn, args: (x)\nRegion(25,42)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x): @@ -1530,13 +1827,17 @@ def fn(x): z: 100 main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nunexpected recur: may only appear unnested inside a def\nRegion(47,70)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nunexpected recur: may only appear unnested inside a def\nRegion(47,70)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x): @@ -1547,13 +1848,17 @@ def fn(x): z: 100 main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nillegal shadowing on: fn. Recursive shadowing of def names disallowed\nRegion(25,81)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nillegal shadowing on: fn. Recursive shadowing of def names disallowed\nRegion(25,81)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x, y): @@ -1562,13 +1867,17 @@ def fn(x, y): case x: fn(x - 1, y + 1) main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\ninvalid recursion on fn\nRegion(63,79)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\ninvalid recursion on fn\nRegion(63,79)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x, y): @@ -1577,13 +1886,16 @@ def fn(x, y): case x: x main = fn(0, 1, 2) -""")) { case te@PackageError.TypeErrorIn(_, _) => - assert(te.message(Map.empty, Colorize.None).contains("does not match function with 3 arguments at:")) +""")) { case te @ PackageError.TypeErrorIn(_, _) => + assert( + te.message(Map.empty, Colorize.None) + .contains("does not match function with 3 arguments at:") + ) () } // we should have the region set inside - val code1571 = """ + val code1571 = """ package A def fn(x): @@ -1593,48 +1905,72 @@ def fn(x): main = fn([1, 2]) """ - evalFail(code1571 :: Nil) { case te@PackageError.TypeErrorIn(_, _) => + evalFail(code1571 :: Nil) { case te @ PackageError.TypeErrorIn(_, _) => // Make sure we point at the function directly assert(code1571.substring(67, 69) == "fn") - assert(te.message(Map.empty, Colorize.None) - .contains("the first type is a function with one argument and the second is a function with 2 arguments")) - assert(te.message(Map.empty, Colorize.None) - .contains("Region(67,69)")) + assert( + te.message(Map.empty, Colorize.None) + .contains( + "the first type is a function with one argument and the second is a function with 2 arguments" + ) + ) + assert( + te.message(Map.empty, Colorize.None) + .contains("Region(67,69)") + ) () } evalFail( - List(""" + List( + """ package A export foo foo = 3 -""", """ +""", + """ package B from A import fooz baz = fooz -""")) { case te@PackageError.UnknownImportName(_, _, _, _, _) => - assert(te.message(Map.empty, Colorize.None) == "in package: A does not have name fooz. Nearest: foo") +""" + ) + ) { case te @ PackageError.UnknownImportName(_, _, _, _, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in package: A does not have name fooz. Nearest: foo" + ) () } evalFail( - List(""" + List( + """ package A export foo foo = 3 bar = 3 -""", """ +""", + """ package B from A import bar baz = bar -""")) { case te@PackageError.UnknownImportName(_, _, _, _, _) => - assert(te.message(Map.empty, Colorize.None) == "in package: A has bar but it is not exported. Add to exports") +""" + ) + ) { case te @ PackageError.UnknownImportName(_, _, _, _, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in package: A has bar but it is not exported. Add to exports" + ) () } } @@ -1658,7 +1994,10 @@ tests = TestSuite("test triple", [ Assertion(a.eq_Int(3), "a == 3"), Assertion(bgood, b), Assertion(c.eq_Int(1), "c == 1") ]) -"""), "A", 3) +"""), + "A", + 3 + ) } test("regression from a map_List/list comprehension example from snoble") { @@ -1805,7 +2144,10 @@ tests = TestSuite("reordering", Assertion(equal_rows.equal_List(rs0.list_of_rows(), [[REBool(RecordValue(False)), REInt(RecordValue(1)), REString(RecordValue("a")), REInt(RecordValue(3))]]), "swap") ] ) -"""), "RecordSet/Library", 1) +"""), + "RecordSet/Library", + 1 + ) } test("record patterns") { @@ -1822,7 +2164,10 @@ tests = TestSuite("test record", [ Assertion(f2.eq_Int(1), "f2 == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1839,7 +2184,10 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1855,7 +2203,10 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1871,7 +2222,10 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1887,7 +2241,10 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1903,7 +2260,10 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1921,10 +2281,12 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1934,10 +2296,11 @@ get = Pair(first, ...) -> first # missing second first = 1 res = get(Pair { first }) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1948,10 +2311,11 @@ get = Pair(first, ...) -> first first = 1 second = 3 res = get(Pair { first, second, third }) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1959,10 +2323,11 @@ struct Pair(first, second) get = Pair { first } -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1971,10 +2336,11 @@ struct Pair(first, second) get = Pair(first) -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1983,10 +2349,11 @@ struct Pair(first, second) get = \Pair { first, sec: _ } -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1995,10 +2362,11 @@ struct Pair(first, second) get = Pair { first, sec: _, ... } -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -2007,10 +2375,11 @@ struct Pair(first, second) get = Pair(first, _, _) -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -2019,11 +2388,14 @@ struct Pair(first, second) get = Pair(first, _, _, ...) -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } } test("exercise total matching inside of a struct with a list") { - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct ListWrapper(items: List[a], b: Bool) @@ -2032,9 +2404,13 @@ w = ListWrapper([], True) ListWrapper([*_], r) = w tests = Assertion(r, "match with total list pattern") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct ListWrapper2(items: List[a], others: List[b], b: Bool) @@ -2043,9 +2419,13 @@ w = ListWrapper2([], [], True) ListWrapper2(_, _, r) = w tests = Assertion(r, "match with total list pattern") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct ListWrapper(items: List[(a, b)], b: Bool) @@ -2054,12 +2434,16 @@ w = ListWrapper([], True) ListWrapper(_, r) = w tests = Assertion(r, "match with total list pattern") -"""), "A", 1) +"""), + "A", + 1 + ) } test("test scoping bug (issue #311)") { - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct Foo(x, y) @@ -2069,9 +2453,13 @@ tests = TestSuite("test record", [ Assertion(x.eq_Int(42), "x == 42"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct Foo(x, y) @@ -2083,9 +2471,13 @@ tests = TestSuite("test record", [ Assertion(x.eq_Int(42), "x == 42"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct Foo(x, y) @@ -2106,12 +2498,16 @@ tests = TestSuite("test record", [ Assertion(y.eq_Int(43), "y == 43"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) } test("test ordered shadowing issue #328") { - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A one = 1 @@ -2127,10 +2523,14 @@ tests = TestSuite("test", [ Assertion(good, ""), ]) -"""), "A", 1) +"""), + "A", + 1 + ) // test record syntax - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct Foo(one) @@ -2150,10 +2550,14 @@ tests = TestSuite("test", [ Assertion(good, ""), ]) -"""), "A", 1) +"""), + "A", + 1 + ) // test local shadowing of a duplicate - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A one = 1 @@ -2172,10 +2576,14 @@ tests = TestSuite("test", [ Assertion(good, ""), ]) -"""), "A", 1) +"""), + "A", + 1 + ) // test an example using a predef function, like add - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A # this should be add from predef two = add(1, 1) @@ -2191,51 +2599,67 @@ tests = TestSuite("test", [ Assertion(good, ""), ]) -"""), "A", 1) +"""), + "A", + 1 + ) } test("shadowing of external def isn't allowed") { - evalFail( - List(""" + evalFail(List(""" package A external def foo(x: String) -> List[String] def foo(x): x -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => - assert(s.message(Map.empty, Colorize.None) == "in file: , package A\nbind names foo shadow external def\nRegion(57,71)") +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + assert( + s.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nbind names foo shadow external def\nRegion(57,71)" + ) () } - evalFail( - List(""" + evalFail(List(""" package A external def foo(x: String) -> List[String] foo = 1 -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => - assert(s.message(Map.empty, Colorize.None) == "in file: , package A\nbind names foo shadow external def\nRegion(57,65)") +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + assert( + s.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nbind names foo shadow external def\nRegion(57,65)" + ) () } - evalFail( - List(""" + evalFail(List(""" package A external def foo(x: String) -> List[String] external def foo(x: String) -> List[String] -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => - assert(s.message(Map.empty, Colorize.None) == "in file: , package A\nexternal def: foo defined multiple times\nRegion(21,55)") +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + assert( + s.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nexternal def: foo defined multiple times\nRegion(21,55)" + ) () } } test("test meta escape bug") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A struct Build[f] @@ -2249,69 +2673,92 @@ def useList(args: List[Build[File]]): check = useList([]) tests = Assertion(check, "none") -"""), "A", 1) +"""), + "A", + 1 + ) } test("type parameters must be supersets for structs and enums fails") { -evalFail( - List(""" + evalFail(List(""" package Err struct Foo[a](a) main = Foo(1, "2") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nFoo found declared: [a], not a superset of [b]\nRegion(14,30)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nFoo found declared: [a], not a superset of [b]\nRegion(14,30)" + ) () } -evalFail( - List(""" + evalFail(List(""" package Err struct Foo[a](a: a, b: b) main = Foo(1, "2") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nFoo found declared: [a], not a superset of [a, b]\nRegion(14,39)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nFoo found declared: [a], not a superset of [a, b]\nRegion(14,39)" + ) () } -evalFail( - List(""" + evalFail(List(""" package Err enum Enum[a]: Foo(a) main = Foo(1, "2") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nEnum found declared: [a], not a superset of [b]\nRegion(14,34)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nEnum found declared: [a], not a superset of [b]\nRegion(14,34)" + ) () } -evalFail( - List(""" + evalFail(List(""" package Err enum Enum[a]: Foo(a: a), Bar(a: b) main = Foo(1, "2") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nEnum found declared: [a], not a superset of [a, b]\nRegion(14,48)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nEnum found declared: [a], not a superset of [a, b]\nRegion(14,48)" + ) () } } test("test duplicate import message") { - evalFail( - List(""" + evalFail(List(""" package Err from Bosatsu/Predef import foldLeft main = 1 -""")) { case sce@PackageError.DuplicatedImport(_) => - assert(sce.message(Map.empty, Colorize.None) == "duplicate import in package Bosatsu/Predef imports foldLeft as foldLeft") +""")) { case sce @ PackageError.DuplicatedImport(_) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "duplicate import in package Bosatsu/Predef imports foldLeft as foldLeft" + ) () } } @@ -2326,15 +2773,20 @@ main = 1 |main = 1 |""".stripMargin - evalFail(List(pack, pack)) { case sce@PackageError.DuplicatedPackageError(_) => - assert(sce.message(Map.empty, Colorize.None) == "package Err duplicated in 0, 1") - () + evalFail(List(pack, pack)) { + case sce @ PackageError.DuplicatedPackageError(_) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "package Err duplicated in 0, 1" + ) + () } } test("test bad list pattern message") { - evalFail( - List(""" + evalFail(List(""" package Err x = [1, 2, 3] @@ -2343,16 +2795,20 @@ main = match x: case [*_, *_]: "bad" case _: "still bad" -""")) { case sce@PackageError.TotalityCheckError(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nRegion(36,89)\nmultiple splices in pattern, only one per match allowed") +""")) { case sce @ PackageError.TotalityCheckError(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nRegion(36,89)\nmultiple splices in pattern, only one per match allowed" + ) () } } test("test bad string pattern message") { val dollar = '$' - evalFail( - List(s""" + evalFail(List(s""" package Err x = "foo bar" @@ -2361,16 +2817,19 @@ main = match x: case "$dollar{_}$dollar{_}": "bad" case _: "still bad" -""")) { case sce@PackageError.TotalityCheckError(_, _) => +""")) { case sce @ PackageError.TotalityCheckError(_, _) => val dollar = '$' - assert(sce.message(Map.empty, Colorize.None) == - s"in file: , package Err\nRegion(36,91)\ninvalid string pattern: '$dollar{_}$dollar{_}' (adjacent string bindings aren't allowed)") + assert( + sce.message(Map.empty, Colorize.None) == + s"in file: , package Err\nRegion(36,91)\ninvalid string pattern: '$dollar{_}$dollar{_}' (adjacent string bindings aren't allowed)" + ) () } } test("test parsing type annotations") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A x: Int = 1 @@ -2381,9 +2840,13 @@ y = ( ) tests = Assertion(y.eq_Int(x), "none") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A x: Int = 1 @@ -2394,11 +2857,15 @@ y = ( ) tests = Assertion(y.eq_Int(x), "none") -"""), "A", 1) +"""), + "A", + 1 + ) } test("improve coverage of typedexpr normalization") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A enum MyBool: T, F @@ -2407,9 +2874,13 @@ main = match T: case F: False tests = Assertion(main, "t1") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A f = _ -> True @@ -2421,9 +2892,13 @@ tests = Assertion(fn((y = 1 # ignore y _ = y 2)), "t1") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A def inc(x): @@ -2434,9 +2909,13 @@ def inc(x): z.add(y) tests = Assertion(inc(1).eq_Int(2), "t1") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A w = 1 @@ -2451,9 +2930,13 @@ def inc(x): case x: x tests = Assertion(inc(1).eq_Int(2), "t1") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package QueueTest struct Queue[a](front: List[a], back: List[a]) @@ -2463,18 +2946,26 @@ def fold_Queue(Queue(f, b): Queue[a], binit: b, fold_fn: (b, a) -> b) -> b: b.reverse().foldLeft(front, fold_fn) test = Assertion(Queue([1], [2]).fold_Queue(0, add).eq_Int(3), "foldQueue") -"""), "QueueTest", 1) +"""), + "QueueTest", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A three = (x = 1 y -> x.add(y))(2) test = Assertion(three.eq_Int(3), "let inside apply") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A substitute = ( @@ -2484,22 +2975,28 @@ substitute = ( ) test = Assertion(substitute.eq_Int(42), "basis substitution") -"""), "A", 1) +"""), + "A", + 1 + ) } test("we can use .( ) to get |> like syntax for lambdas") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A three = 2.(x -> add(x, 1))() test = Assertion(three.eq_Int(3), "let inside apply") -"""), "A", 1) +"""), + "A", + 1 + ) } test("colliding type names cause errors") { - evalFail( - List(s""" + evalFail(List(s""" package Err struct Foo @@ -2507,15 +3004,19 @@ struct Foo struct Foo(x) main = Foo(1) -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\ntype name: Foo defined multiple times\nRegion(14,24)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\ntype name: Foo defined multiple times\nRegion(14,24)" + ) () } } test("colliding constructor names cause errors") { - evalFail( - List(s""" + evalFail(List(s""" package Err enum Bar: Foo @@ -2523,23 +3024,33 @@ enum Bar: Foo struct Foo(x) main = Foo(1) -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nconstructor: Foo defined multiple times\nRegion(14,27)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nconstructor: Foo defined multiple times\nRegion(14,27)" + ) () } } test("non binding top levels work") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A # this is basically a typecheck only _ = add(1, 2) test = Assertion(True, "") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A # this is basically a typecheck only @@ -2547,9 +3058,13 @@ x = (1, "1") (_, _) = x test = Assertion(True, "") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A struct Foo(x, y) @@ -2558,11 +3073,15 @@ x = Foo(1, "1") Foo(_, _) = x test = Assertion(True, "") -"""), "A", 1) +"""), + "A", + 1 + ) } test("recursion check with _ pattern: issue 573") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package VarSet/Recursion enum Thing: @@ -2574,7 +3093,10 @@ def bar(y, _: String, x): Thing2(i, t): bar(i, "boom", t) test = Assertion(True, "") -"""), "VarSet/Recursion", 1) +"""), + "VarSet/Recursion", + 1 + ) } test("recursion check with shadowing") { @@ -2591,8 +3113,13 @@ def bar(y, _: String, x): Thing2(i, t): bar(i, "boom", t) test = Assertion(True, "") -""")) { case re@PackageError.RecursionError(_, _) => - assert(re.message(Map.empty, Colorize.None) == "in file: , package S\nrecur not on an argument to the def of bar, args: (y, _: String, x)\nRegion(107,165)\n") +""")) { case re @ PackageError.RecursionError(_, _) => + assert( + re.message( + Map.empty, + Colorize.None + ) == "in file: , package S\nrecur not on an argument to the def of bar, args: (y, _: String, x)\nRegion(107,165)\n" + ) () } } @@ -2605,8 +3132,13 @@ out = match (1,2): case (a, a): a test = Assertion(True, "") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Foo\nrepeated bindings in pattern: a\nRegion(48,49)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Foo\nrepeated bindings in pattern: a\nRegion(48,49)" + ) () } evalFail(List(""" @@ -2617,11 +3149,17 @@ out = match [(1,2), (1, 0)]: case _: 0 test = Assertion(True, "") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Foo\nrepeated bindings in pattern: a\nRegion(68,69)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Foo\nrepeated bindings in pattern: a\nRegion(68,69)" + ) () } - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo out = match [(1,2), (1, 0)]: @@ -2629,11 +3167,15 @@ out = match [(1,2), (1, 0)]: case _: 0 test = Assertion(out.eq_Int(1), "") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) } test("test some complex list patterns, issue 574") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo out = match [(True, 2), (True, 0)]: @@ -2642,7 +3184,10 @@ out = match [(True, 2), (True, 0)]: case _: -1 test = Assertion(out.eq_Int(0), "") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) } test("unknown type constructor message is good. issue 653") { @@ -2653,8 +3198,13 @@ struct Bar(baz: Either[Int, String]) test = Assertion(True, "") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Foo\nunknown type: Either\nRegion(14,50)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Foo\nunknown type: Either\nRegion(14,50)" + ) () } } @@ -2668,7 +3218,7 @@ export FooE() enum FooE: Foo1, Foo2 """ :: -""" + """ package Bar from Foo import Foo1, Foo2 @@ -2679,7 +3229,10 @@ m = match x: case Foo2: False test = Assertion(m, "x matches Foo1") -""" :: Nil, "Bar", 1) +""" :: Nil, + "Bar", + 1 + ) } test("its an error to export a value and not its type. issue 782") { @@ -2698,13 +3251,19 @@ from Foo import bar x = bar """ :: Nil) { case sce => - assert(sce.message(Map.empty, Colorize.None) == "in export bar of type Foo::Bar has an unexported (private) type.") + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in export bar of type Foo::Bar has an unexported (private) type." + ) () } } test("test def with type params") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo def foo[a](a: a) -> a: @@ -2714,7 +3273,10 @@ def foo[a](a: a) -> a: and_again(again(x)) test = Assertion(foo(True), "") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) evalFail(List(""" package Foo @@ -2725,8 +3287,13 @@ def foo[a](a: a) -> a: def and_again[b](x: a): x and_again(again(x)) -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Foo\nand_again found declared types: [b], not a subset of [a]\nRegion(71,118)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Foo\nand_again found declared types: [b], not a subset of [a]\nRegion(71,118)" + ) () } } @@ -2744,12 +3311,12 @@ struct RecordGetter[shape, t]( def get[shape](sh: shape[RecordValue], RecordGetter(getter): RecordGetter[shape, t]) -> t: RecordValue(result) = sh.getter() result -""")) { case PackageError.TypeErrorIn(_, _) => () - } +""")) { case PackageError.TypeErrorIn(_, _) => () } } test("test quicklook example") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo def f(fn: forall a. List[a] -> List[a]) -> Int: @@ -2784,7 +3351,10 @@ pair = Pair1(single_id1, single_id2) comp = x -> f(g(x)) test = Assertion(True, "") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) } test("ill-kinded structs point to the right region") { @@ -2793,12 +3363,14 @@ test = Assertion(True, "") package Foo struct Foo(a: f[a], b: f) -""")) { case kie@PackageError.KindInferenceError(_, _, _) => - assert(kie.message(Map.empty, Colorize.None) == - """in file: , package Foo +""")) { case kie @ PackageError.KindInferenceError(_, _, _) => + assert( + kie.message(Map.empty, Colorize.None) == + """in file: , package Foo shape error: expected kind(f) and * to match in the constructor Foo -Region(14,39)""") +Region(14,39)""" + ) () } @@ -2806,18 +3378,21 @@ Region(14,39)""") package Foo struct Foo[a: *](a: a[Int]) -""")) { case kie@PackageError.KindInferenceError(_, _, _) => - assert(kie.message(Map.empty, Colorize.None) == - """in file: , package Foo +""")) { case kie @ PackageError.KindInferenceError(_, _, _) => + assert( + kie.message(Map.empty, Colorize.None) == + """in file: , package Foo shape error: expected * -> ? but found * in the constructor Foo inside type a[Bosatsu/Predef::Int] -Region(14,41)""") +Region(14,41)""" + ) () } } test("example from issue #264") { - runBosatsuTest(""" + runBosatsuTest( + """ package SubsumeTest def lengths(l1: List[Int], l2: List[String], maybeFn: Option[forall tt. List[tt] -> Int]): @@ -2836,7 +3411,10 @@ x = match []: case [h, *_]: (h: forall a. a) test = Assertion(lengths([], [], None) matches 0, "test") - """ :: Nil, "SubsumeTest", 1) + """ :: Nil, + "SubsumeTest", + 1 + ) } test("ill kinded code examples") { @@ -2850,9 +3428,10 @@ struct Id(a) # this code could run if we ignored kinds def makeFoo(v: Int): Foo(Id(v)) -""")) { case kie@PackageError.TypeErrorIn(_, _) => - assert(kie.message(Map.empty, Colorize.None) == - """in file: , package Foo +""")) { case kie @ PackageError.TypeErrorIn(_, _) => + assert( + kie.message(Map.empty, Colorize.None) == + """in file: , package Foo kind error: the type: ?0 of kind: (* -> *) -> * at: Region(183,188) @@ -2872,18 +3451,19 @@ struct Id(a) # this code could run if we ignored kinds def makeFoo(v: Int) -> Foo[Id, Int]: Foo(Id(v)) -""")) { case kie@PackageError.TypeErrorIn(_, _) => - assert(kie.message(Map.empty, Colorize.None) == - """in file: , package Foo +""")) { case kie @ PackageError.TypeErrorIn(_, _) => + assert( + kie.message(Map.empty, Colorize.None) == + """in file: , package Foo kind error: the type: Foo::Foo[Foo::Id] is invalid because the left Foo::Foo has kind ((* -> *) -> *) -> (* -> *) -> * and the right Foo::Id has kind +* -> * but left cannot accept the kind of the right: Region(195,205)""" ) () } - + } test("print a decent message when arguments are omitted") { - evalFail(List(""" + evalFail(List(""" package QS def quick_sort0(cmp, left, right): @@ -2898,7 +3478,7 @@ def quick_sort0(cmp, left, right): # we accidentally omit bigger below bigs = quick_sort0(cmp, tail) [*smalls, *bigs] -""")) { case kie@PackageError.TypeErrorIn(_, _) => +""")) { case kie @ PackageError.TypeErrorIn(_, _) => assert(kie.message(Map.empty, Colorize.None) == """in file: , package QS type error: expected type Bosatsu/Predef::Fn3[(?13, ?9) -> Bosatsu/Predef::Comparison] Region(403,414) @@ -2907,7 +3487,7 @@ hint: the first type is a function with 3 arguments and the second is a function Region(415,424)""") () } - + } test("error early on a bad type in a recursive function") { @@ -2922,7 +3502,7 @@ def toInt(n: N, acc: Int) -> Int: case S(n): toInt(n, "foo") """ - evalFail(List(testCode)) { case kie@PackageError.TypeErrorIn(_, _) => + evalFail(List(testCode)) { case kie @ PackageError.TypeErrorIn(_, _) => val message = kie.message(Map.empty, Colorize.None) assert(message.contains("Region(122,127)")) val badRegion = testCode.substring(122, 127) @@ -2932,7 +3512,8 @@ def toInt(n: N, acc: Int) -> Int: } test("declaring a generic parameter works fine") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Generic enum NEList[a: +*]: @@ -2944,9 +3525,13 @@ def head(nel: NEList[a]) -> a: case One(a) | Many(a, _): a test = Assertion(head(One(True)), "") -"""), "Generic", 1) +"""), + "Generic", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Generic enum NEList[a: +*]: @@ -2958,10 +3543,14 @@ def head[a](nel: NEList[a]) -> a: case One(a) | Many(a, _): a test = Assertion(head(One(True)), "") -"""), "Generic", 1) +"""), + "Generic", + 1 + ) // With recursion - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Generic enum NEList[a: +*]: @@ -2974,10 +3563,14 @@ def last(nel: NEList[a]) -> a: case Many(_, tail): last(tail) test = Assertion(last(One(True)), "") -"""), "Generic", 1) +"""), + "Generic", + 1 + ) // With recursion - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Generic enum NEList[a: +*]: @@ -2990,7 +3583,10 @@ def last[a](nel: NEList[a]) -> a: case Many(_, tail): last(tail) test = Assertion(last(One(True)), "") -"""), "Generic", 1) +"""), + "Generic", + 1 + ) } test("support polymorphic recursion") { @@ -3009,8 +3605,10 @@ def poly_rec(count: Nat, a: a) -> a: b test = Assertion(True, "") -""") - , "PolyRec", 1) +"""), + "PolyRec", + 1 + ) runBosatsuTest( List(""" @@ -3032,8 +3630,10 @@ def call(a): poly_rec(NZero, a) test = Assertion(True, "") -""") - , "PolyRec", 1) +"""), + "PolyRec", + 1 + ) } test("recursion on continuations") { @@ -3057,7 +3657,10 @@ def loop(box: Cont) -> Int: v = loop(b) main = v -"""), "A", VInt(2)) +"""), + "A", + VInt(2) + ) // Generic version evalTest( @@ -3078,10 +3681,13 @@ def loop[a](box: Cont[a]) -> a: loopgen: forall a. Cont[a] -> a = loop b: Cont[Int] = Item(1).map(x -> x.add(1)) main: Int = loop(b) -"""), "A", VInt(2)) +"""), + "A", + VInt(2) + ) - // this example also exercises polymorphic recursion - evalTest( + // this example also exercises polymorphic recursion + evalTest( List(""" package A enum Box[a: +*]: @@ -3100,7 +3706,10 @@ def loop[a](box: Box[a]) -> a: v = loop(b) main = v -"""), "A", VInt(1)) +"""), + "A", + VInt(1) + ) } test("we get error messages from multiple type errors top level") { @@ -3111,7 +3720,7 @@ x: Int = "1" y: String = 1 """ - evalFail(List(testCode)) { case kie@PackageError.TypeErrorIn(_, _) => + evalFail(List(testCode)) { case kie @ PackageError.TypeErrorIn(_, _) => val message = kie.message(Map.empty, Colorize.None) assert(message.contains("Region(30,33)")) assert(testCode.substring(30, 33) == "\"1\"") @@ -3132,7 +3741,7 @@ z = ( ) """ - evalFail(List(testCode)) { case kie@PackageError.TypeErrorIn(_, _) => + evalFail(List(testCode)) { case kie @ PackageError.TypeErrorIn(_, _) => val message = kie.message(Map.empty, Colorize.None) assert(message.contains("Region(38,41)")) assert(testCode.substring(38, 41) == "\"1\"") @@ -3168,11 +3777,15 @@ def last(str) -> Option[Char]: test3 = Assertion(last("foo") matches Some(.'o'), "last test") all = TestSuite("chars", [test1, test2, test3]) -"""), "Foo", 3) +"""), + "Foo", + 3 + ) } test("test universal quantified list match") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo empty: (forall a. List[a]) = [] @@ -3182,12 +3795,16 @@ res = match empty: case [_, *_]: 1 test = Assertion(res matches 0, "one") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) } test("existential quantification in a match") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo enum FreeF[a]: @@ -3212,9 +3829,13 @@ def run[a](fa: FreeF[a]) -> a: res = run(pure(0).map(x -> x.add(1))) test = Assertion(res matches 1, "one") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo enum ListE[a]: @@ -3232,7 +3853,10 @@ def uncons[a](l: ListE[a]) -> exists b. Option[((a, b), ListE[b])]: res = cons((1, 0), Empty).uncons() test = Assertion(res matches Some(((1, _), Empty)), "one") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) } test("tuples bigger than 32 fail") { @@ -3245,11 +3869,16 @@ z = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33) """ - evalFail(List(testCode)) { case kie@PackageError.SourceConverterErrorIn(_, _) => - val message = kie.message(Map.empty, Colorize.None) - assert(message.contains("invalid tuple size. Found 33, but maximum allowed 32")) - assert(message.contains("Region(25,154)")) - () + evalFail(List(testCode)) { + case kie @ PackageError.SourceConverterErrorIn(_, _) => + val message = kie.message(Map.empty, Colorize.None) + assert( + message.contains( + "invalid tuple size. Found 33, but maximum allowed 32" + ) + ) + assert(message.contains("Region(25,154)")) + () } val testCode1 = """ @@ -3266,11 +3895,16 @@ res = z matches (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 31, 32, 33) """ - evalFail(List(testCode1)) { case kie@PackageError.SourceConverterErrorIn(_, _) => - val message = kie.message(Map.empty, Colorize.None) - assert(message.contains("invalid tuple size. Found 33, but maximum allowed 32")) - assert(message.contains("Region(158,297)")) - () + evalFail(List(testCode1)) { + case kie @ PackageError.SourceConverterErrorIn(_, _) => + val message = kie.message(Map.empty, Colorize.None) + assert( + message.contains( + "invalid tuple size. Found 33, but maximum allowed 32" + ) + ) + assert(message.contains("Region(158,297)")) + () } } @@ -3281,16 +3915,18 @@ package ErrorCheck struct Foo[a: -*](get: a) """ - evalFail(List(testCode)) { case kie@PackageError.KindInferenceError(_, _, _) => - val message = kie.message(Map.empty, Colorize.None) - assert(message.contains("Region(21,46)")) - assert(testCode.substring(21, 46) == "struct Foo[a: -*](get: a)") - () + evalFail(List(testCode)) { + case kie @ PackageError.KindInferenceError(_, _, _) => + val message = kie.message(Map.empty, Colorize.None) + assert(message.contains("Region(21,46)")) + assert(testCode.substring(21, 46) == "struct Foo[a: -*](get: a)") + () } } test("test non-base 10 literals") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo test = TestSuite("bases", @@ -3308,7 +3944,10 @@ test = TestSuite("bases", Assertion(2 matches 0b10, "11"), Assertion(2 matches 0B10, "12"), ]) -"""), "Foo", 12) +"""), + "Foo", + 12 + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/FreeVarTest.scala b/core/src/test/scala/org/bykn/bosatsu/FreeVarTest.scala index 5db08f8e6..f2209da8a 100644 --- a/core/src/test/scala/org/bykn/bosatsu/FreeVarTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/FreeVarTest.scala @@ -1,14 +1,17 @@ package org.bykn.bosatsu -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class FreeVarTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = PropertyCheckConfiguration(minSuccessful = 1000) - //PropertyCheckConfiguration(minSuccessful = 300) - //PropertyCheckConfiguration(minSuccessful = 5) + // PropertyCheckConfiguration(minSuccessful = 300) + // PropertyCheckConfiguration(minSuccessful = 5) def assertFreeVars(stmt: String, vars: List[String]) = Statement.parser.parseAll(stmt) match { @@ -25,14 +28,18 @@ class FreeVarTest extends AnyFunSuite { assertFreeVars("""y = 1""", Nil) assertFreeVars("""external foo: Int""", Nil) assertFreeVars("""def foo(x): y""", List("y")) - assertFreeVars("""def foo(x): + assertFreeVars( + """def foo(x): y = x - y""", Nil) + y""", + Nil + ) } test("freeVars is a subset of allNames") { forAll(Generators.genStatement(3)) { stmt => - Statement.valuesOf(stmt :: Nil) + Statement + .valuesOf(stmt :: Nil) .foreach { v => assert(v.freeVars.subsetOf(v.allNames)) } diff --git a/core/src/test/scala/org/bykn/bosatsu/Gen.scala b/core/src/test/scala/org/bykn/bosatsu/Gen.scala index 5a4661c84..22ecfce66 100644 --- a/core/src/test/scala/org/bykn/bosatsu/Gen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/Gen.scala @@ -43,7 +43,10 @@ object Generators { for { e <- Gen.lzy(typeRefGen) cnt <- Gen.choose(1, 3) - args <- Gen.listOfN(cnt, Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKind))) + args <- Gen.listOfN( + cnt, + Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKind)) + ) nel = NonEmptyList.fromListUnsafe(args) } yield TypeRef.TypeForAll(nel, e) @@ -51,7 +54,10 @@ object Generators { for { e <- Gen.lzy(typeRefGen) cnt <- Gen.choose(1, 3) - args <- Gen.listOfN(cnt, Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKind))) + args <- Gen.listOfN( + cnt, + Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKind)) + ) nel = NonEmptyList.fromListUnsafe(args) } yield TypeRef.TypeExists(nel, e) @@ -63,7 +69,10 @@ object Generators { multiGen = Gen.oneOf(Operators.multiToks) ms <- Gen.listOfN(c, multiGen) asStr = ms.mkString - res <- (if ((asStr != "<-") && (asStr != "->")) Gen.const(Identifier.Operator(asStr)) else multi) + res <- + (if ((asStr != "<-") && (asStr != "->")) + Gen.const(Identifier.Operator(asStr)) + else multi) } yield res Gen.frequency((4, sing), (1, multi)) @@ -71,11 +80,20 @@ object Generators { val bindIdentGen: Gen[Identifier.Bindable] = Gen.frequency( - (10, lowerIdent.filter { n => !Declaration.keywords(n) }.map { n => Identifier.Name(n) }), + ( + 10, + lowerIdent.filter { n => !Declaration.keywords(n) }.map { n => + Identifier.Name(n) + } + ), (1, opGen), - (1, Arbitrary.arbitrary[String].map { s => - Identifier.Backticked(s) - })) + ( + 1, + Arbitrary.arbitrary[String].map { s => + Identifier.Backticked(s) + } + ) + ) lazy val typeRefGen: Gen[TypeRef] = { import TypeRef._ @@ -101,11 +119,17 @@ object Generators { Gen.frequency( (5, tvar), (5, tname), - (1, Gen.zip(Gen.lzy(smallNonEmptyList(typeRefGen, 4)), Gen.lzy(typeRefGen)).map { case (a, b) => TypeArrow(a, b) }), + ( + 1, + Gen + .zip(Gen.lzy(smallNonEmptyList(typeRefGen, 4)), Gen.lzy(typeRefGen)) + .map { case (a, b) => TypeArrow(a, b) } + ), (1, tLambda), (1, typeRefExistsGen), (1, tTup), - (1, tApply)) + (1, tApply) + ) } implicit val shrinkTypeRef: Shrink[TypeRef] = @@ -119,13 +143,13 @@ object Generators { case TypeApply(of, args) => of #:: args.toList.toStream case TypeForAll(par, expr) => val rest = NonEmptyList.fromList(par.tail) match { - case None => Stream.empty + case None => Stream.empty case Some(nel) => TypeForAll(nel, expr) #:: Stream.empty } expr #:: rest case TypeExists(par, expr) => val rest = NonEmptyList.fromList(par.tail) match { - case None => Stream.empty + case None => Stream.empty case Some(nel) => TypeExists(nel, expr) #:: Stream.empty } expr #:: rest @@ -144,7 +168,9 @@ object Generators { }) def commentGen[T](dec: Gen[T]): Gen[CommentStatement[T]] = { - def cleanNewLine(s: String): String = s.map { c => if (c == '\n') ' ' else c } + def cleanNewLine(s: String): String = s.map { c => + if (c == '\n') ' ' else c + } for { cs <- nonEmpty(Arbitrary.arbitrary[String]) t <- dec @@ -159,7 +185,7 @@ object Generators { def argToPat(arg: (Identifier.Bindable, Option[TypeRef])): Pattern.Parsed = arg match { - case (bn, None) => Pattern.Var(bn) + case (bn, None) => Pattern.Var(bn) case (bn, Some(t)) => Pattern.Annotation(Pattern.Var(bn), t) } @@ -170,21 +196,43 @@ object Generators { tpes <- smallList(Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKind))) retType <- Gen.option(typeRefGen) body <- dec - } yield DefStatement(name, NonEmptyList.fromList(tpes), args.map(_.map(argToPat)), retType, body) + } yield DefStatement( + name, + NonEmptyList.fromList(tpes), + args.map(_.map(argToPat)), + retType, + body + ) - def genSpliceOrItem[A](spliceGen: Gen[A], itemGen: Gen[A]): Gen[ListLang.SpliceOrItem[A]] = - Gen.oneOf(spliceGen.map(ListLang.SpliceOrItem.Splice(_)), - itemGen.map(ListLang.SpliceOrItem.Item(_))) + def genSpliceOrItem[A]( + spliceGen: Gen[A], + itemGen: Gen[A] + ): Gen[ListLang.SpliceOrItem[A]] = + Gen.oneOf( + spliceGen.map(ListLang.SpliceOrItem.Splice(_)), + itemGen.map(ListLang.SpliceOrItem.Item(_)) + ) - def genListLangCons[A](spliceGen: Gen[A], itemGen: Gen[A]): Gen[ListLang.Cons[ListLang.SpliceOrItem, A]] = { - Gen.choose(0, 5) + def genListLangCons[A]( + spliceGen: Gen[A], + itemGen: Gen[A] + ): Gen[ListLang.Cons[ListLang.SpliceOrItem, A]] = { + Gen + .choose(0, 5) .flatMap(Gen.listOfN(_, genSpliceOrItem(spliceGen, itemGen))) .map(ListLang.Cons(_)) } - def genListLangDictCons[A](itemGen: Gen[A]): Gen[ListLang.Cons[ListLang.KVPair, A]] = { - Gen.choose(0, 5) - .flatMap(Gen.listOfN(_, - Gen.zip(itemGen, itemGen).map { case (k, v) => ListLang.KVPair(k, v) })) + def genListLangDictCons[A]( + itemGen: Gen[A] + ): Gen[ListLang.Cons[ListLang.KVPair, A]] = { + Gen + .choose(0, 5) + .flatMap( + Gen.listOfN( + _, + Gen.zip(itemGen, itemGen).map { case (k, v) => ListLang.KVPair(k, v) } + ) + ) .map(ListLang.Cons(_)) } @@ -193,15 +241,21 @@ object Generators { val item = Gen.oneOf( - Arbitrary.arbitrary[String].filter(_.length > 1).map { s => StringDecl.Literal(emptyRegion, s) }, + Arbitrary.arbitrary[String].filter(_.length > 1).map { s => + StringDecl.Literal(emptyRegion, s) + }, dec0.map(StringDecl.StrExpr(_)), - dec0.map(StringDecl.CharExpr(_)), - ) + dec0.map(StringDecl.CharExpr(_)) + ) - def removeAdj[A](nea: NonEmptyList[A])(fn: (A, A) => Boolean): NonEmptyList[A] = + def removeAdj[A]( + nea: NonEmptyList[A] + )(fn: (A, A) => Boolean): NonEmptyList[A] = nea match { - case NonEmptyList(a1, a2 :: tail) if fn(a1, a2) => removeAdj(NonEmptyList(a2, tail))(fn) - case NonEmptyList(a1, a2 :: tail) => NonEmptyList(a1, removeAdj(NonEmptyList(a2, tail))(fn).toList) + case NonEmptyList(a1, a2 :: tail) if fn(a1, a2) => + removeAdj(NonEmptyList(a2, tail))(fn) + case NonEmptyList(a1, a2 :: tail) => + NonEmptyList(a1, removeAdj(NonEmptyList(a2, tail))(fn).toList) case ne1 => ne1 } @@ -211,13 +265,15 @@ object Generators { nel = NonEmptyList.fromListUnsafe(lst) // make sure we don't have two adjacent strings nel1 = removeAdj(nel) { - case (StringDecl.Literal(_, _), StringDecl.Literal(_, _)) => true - case _ => false + case (StringDecl.Literal(_, _), StringDecl.Literal(_, _)) => true + case _ => false } } yield Declaration.StringDecl(nel1)(emptyRegion) res.filter { - case Declaration.StringDecl(NonEmptyList(StringDecl.Literal(_, _), Nil)) => + case Declaration.StringDecl( + NonEmptyList(StringDecl.Literal(_, _), Nil) + ) => false case _ => true } @@ -225,8 +281,8 @@ object Generators { def listGen(dec0: Gen[NonBinding]): Gen[Declaration.ListDecl] = { lazy val filterFn: NonBinding => Boolean = { - case Declaration.IfElse(_, _) => false - case Declaration.Match(_, _, _) => false + case Declaration.IfElse(_, _) => false + case Declaration.Match(_, _, _) => false case Declaration.Lambda(_, body: NonBinding) => filterFn(body) case Declaration.Apply(f, args, _) => filterFn(f) && args.forall(filterFn) @@ -238,10 +294,11 @@ object Generators { // TODO we can't parse if since we get confused about it being a ternary expression val pat = genPattern(1, useUnion = true) - val comp = Gen.zip(genSpliceOrItem(dec, dec), pat, dec, Gen.option(dec)) + val comp = Gen + .zip(genSpliceOrItem(dec, dec), pat, dec, Gen.option(dec)) .map { case (a, b, c0, _) => val c = c0 match { - case tern@Declaration.Ternary(_, _, _) => + case tern @ Declaration.Ternary(_, _, _) => Declaration.Parens(tern)(emptyRegion) case not => not } @@ -253,10 +310,10 @@ object Generators { def dictGen(dec0: Gen[NonBinding]): Gen[Declaration.DictDecl] = { lazy val filterFn: NonBinding => Boolean = { - case Declaration.Annotation(_, _) => false - case Declaration.IfElse(_, _) => false - case Declaration.Match(_, _, _) => false - case Declaration.ApplyOp(_, _, _) => false + case Declaration.Annotation(_, _) => false + case Declaration.IfElse(_, _) => false + case Declaration.Match(_, _, _) => false + case Declaration.ApplyOp(_, _, _) => false case Declaration.Lambda(_, body: NonBinding) => filterFn(body) case Declaration.Apply(f, args, _) => filterFn(f) && args.forall(filterFn) @@ -268,10 +325,11 @@ object Generators { // TODO we can't parse if since we get confused about it being a ternary expression val pat = genPattern(1, useUnion = true) - val comp = Gen.zip(dec, dec, pat, dec, Gen.option(dec)) + val comp = Gen + .zip(dec, dec, pat, dec, Gen.option(dec)) .map { case (k, v, b, c0, _) => val c = c0 match { - case tern@Declaration.Ternary(_, _, _) => + case tern @ Declaration.Ternary(_, _, _) => Declaration.Parens(tern)(emptyRegion) case not => not } @@ -285,7 +343,9 @@ object Generators { decl.map { case n: Declaration.NonBinding => n match { - case v@(Declaration.Var(_) | Declaration.Parens(_) | Declaration.Apply(_, _, _)) => v + case v @ (Declaration.Var(_) | Declaration.Parens(_) | + Declaration.Apply(_, _, _)) => + v case notVar => Declaration.Parens(notVar)(emptyRegion) } case notVar => Declaration.Parens(notVar)(emptyRegion) @@ -297,18 +357,26 @@ object Generators { def isVar(d: Declaration): Boolean = d match { case Declaration.Var(_) => true - case _ => false + case _ => false } - def applyGen(fnGen: Gen[NonBinding], arg: Gen[NonBinding], dotApplyGen: Gen[Boolean]): Gen[Declaration.Apply] = { + def applyGen( + fnGen: Gen[NonBinding], + arg: Gen[NonBinding], + dotApplyGen: Gen[Boolean] + ): Gen[Declaration.Apply] = { import Declaration._ - Gen.lzy(for { - fn <- fnGen - dotApply <- dotApplyGen - useDot = dotApply && isVar(fn) // f.bar needs the fn to be a var - argsGen = if (useDot) arg.map(NonEmptyList.one(_)) else smallNonEmptyList(arg, 8) - args <- argsGen - } yield Apply(fn, args, ApplyKind.Parens)(emptyRegion)) // TODO this should pass if we use `foo.bar(a, b)` syntax + Gen.lzy( + for { + fn <- fnGen + dotApply <- dotApplyGen + useDot = dotApply && isVar(fn) // f.bar needs the fn to be a var + argsGen = + if (useDot) arg.map(NonEmptyList.one(_)) + else smallNonEmptyList(arg, 8) + args <- argsGen + } yield Apply(fn, args, ApplyKind.Parens)(emptyRegion) + ) // TODO this should pass if we use `foo.bar(a, b)` syntax } def applyOpGen(arg: Gen[NonBinding]): Gen[Declaration.ApplyOp] = @@ -326,20 +394,31 @@ object Generators { ApplyOp(protect(l), op, protect(r)) } - def bindGen[A, T](patGen: Gen[A], dec: Gen[NonBinding], tgen: Gen[T]): Gen[BindingStatement[A, NonBinding, T]] = - Gen.zip(patGen, dec, tgen) + def bindGen[A, T]( + patGen: Gen[A], + dec: Gen[NonBinding], + tgen: Gen[T] + ): Gen[BindingStatement[A, NonBinding, T]] = + Gen + .zip(patGen, dec, tgen) .map { case (b, value, in) => BindingStatement(b, value, in) } - def leftApplyGen(patGen: Gen[Pattern.Parsed], dec: Gen[NonBinding], bodyGen: Gen[Declaration]): Gen[Declaration.LeftApply] = - Gen.zip(patGen, dec, padding(bodyGen)) + def leftApplyGen( + patGen: Gen[Pattern.Parsed], + dec: Gen[NonBinding], + bodyGen: Gen[Declaration] + ): Gen[Declaration.LeftApply] = + Gen + .zip(patGen, dec, padding(bodyGen)) .map { case (p, value, in) => Declaration.LeftApply(p, emptyRegion, value, in) } def padding[T](tgen: Gen[T], min: Int = 0): Gen[Padding[T]] = - Gen.zip(Gen.choose(min, 10), tgen) + Gen + .zip(Gen.choose(min, 10), tgen) .map { case (e, t) => Padding(e, t) } def indented[T](tgen: Gen[T]): Gen[Indented[T]] = @@ -352,34 +431,37 @@ object Generators { for { args <- nonEmpty(bindIdentGen) body <- bodyGen - } yield Declaration.Lambda(args.map(Pattern.Var(_)), body)(emptyRegion) + } yield Declaration.Lambda(args.map(Pattern.Var(_)), body)(emptyRegion) def optIndent[A](genA: Gen[A]): Gen[OptIndent[A]] = { val indentation = Gen.choose(1, 10) indentation.flatMap { i => - - // TODO support parsing if foo: bar - //Gen.oneOf( - padding(genA.map(Indented(i, _)), min = 1).map(OptIndent.notSame(_)) - //, - //bodyGen.map(Left(_): OptIndent[Declaration])) + // TODO support parsing if foo: bar + // Gen.oneOf( + padding(genA.map(Indented(i, _)), min = 1).map(OptIndent.notSame(_)) + // , + // bodyGen.map(Left(_): OptIndent[Declaration])) } } - def ifElseGen(argGen0: Gen[NonBinding], bodyGen: Gen[Declaration]): Gen[Declaration.IfElse] = { + def ifElseGen( + argGen0: Gen[NonBinding], + bodyGen: Gen[Declaration] + ): Gen[Declaration.IfElse] = { import Declaration._ // args can't have raw annotations: val argGen = argGen0.map { - case ann@Annotation(_, _) => Parens(ann)(emptyRegion) - case notAnn => notAnn + case ann @ Annotation(_, _) => Parens(ann)(emptyRegion) + case notAnn => notAnn } val padBody = optIndent(bodyGen) val genIf: Gen[(NonBinding, OptIndent[Declaration])] = Gen.zip(argGen, padBody) - Gen.zip(nonEmptyN(genIf, 2), padBody) + Gen + .zip(nonEmptyN(genIf, 2), padBody) .map { case (ifs, elsec) => IfElse(ifs, elsec)(emptyRegion) } } @@ -387,14 +469,15 @@ object Generators { import Declaration._ val argGen = argGen0.map { - case lam@Lambda(_, _) => Parens(lam)(emptyRegion) - case ife@IfElse(_, _) => Parens(ife)(emptyRegion) - case tern@Ternary(_, _, _) => Parens(tern)(emptyRegion) - case matches@Matches(_, _) => Parens(matches)(emptyRegion) - case m@Match(_, _, _) => Parens(m)(emptyRegion) - case not => not + case lam @ Lambda(_, _) => Parens(lam)(emptyRegion) + case ife @ IfElse(_, _) => Parens(ife)(emptyRegion) + case tern @ Ternary(_, _, _) => Parens(tern)(emptyRegion) + case matches @ Matches(_, _) => Parens(matches)(emptyRegion) + case m @ Match(_, _, _) => Parens(m)(emptyRegion) + case not => not } - Gen.zip(argGen, argGen, argGen) + Gen + .zip(argGen, argGen, argGen) .map { case (t, c, f) => Ternary(t, c, f) } } @@ -406,25 +489,33 @@ object Generators { def toArg(p: Pattern.Parsed): Gen[Pattern.StructKind.Style.FieldKind] = p match { case Pattern.Var(b: Identifier.Bindable) => - Gen.oneOf(Gen.const(Pattern.StructKind.Style.FieldKind.Implicit(b)), - Gen.oneOf(bindIdentGen, Gen.const(b)) + Gen.oneOf( + Gen.const(Pattern.StructKind.Style.FieldKind.Implicit(b)), + Gen + .oneOf(bindIdentGen, Gen.const(b)) .map(Pattern.StructKind.Style.FieldKind.Explicit(_)) - ) + ) case Pattern.Annotation(p, _) => toArg(p) - case _ => + case _ => // if we don't have a var, we can't omit the key bindIdentGen.map(Pattern.StructKind.Style.FieldKind.Explicit(_)) } - lazy val args = tail.foldLeft(toArg(h) - .map(NonEmptyList.one)) { case (args, a) => - Gen.zip(args, toArg(a)).map { case (args, a) => NonEmptyList(a, args.toList) } + lazy val args = tail + .foldLeft( + toArg(h) + .map(NonEmptyList.one) + ) { case (args, a) => + Gen.zip(args, toArg(a)).map { case (args, a) => + NonEmptyList(a, args.toList) + } } .map(_.reverse) Gen.oneOf( Gen.const(Pattern.StructKind.Style.TupleLike), - Gen.lzy(args.map(Pattern.StructKind.Style.RecordLike(_)))) + Gen.lzy(args.map(Pattern.StructKind.Style.RecordLike(_))) + ) } def genStructKind(args: List[Pattern.Parsed]): Gen[Pattern.StructKind] = @@ -435,7 +526,8 @@ object Generators { }, Gen.zip(consIdentGen, genStyle(args)).map { case (n, s) => Pattern.StructKind.NamedPartial(n, s) - }) + } + ) def genPattern(depth: Int, useUnion: Boolean = true): Gen[Pattern.Parsed] = genPatternGen( @@ -443,60 +535,76 @@ object Generators { typeRefGen, depth, useUnion, - useAnnotation = false) + useAnnotation = false + ) - lazy val genStrPat: Gen[Pattern.StrPat] = { - val recurse = Gen.lzy(genStrPat) + lazy val genStrPat: Gen[Pattern.StrPat] = { + val recurse = Gen.lzy(genStrPat) - val genPart: Gen[Pattern.StrPart] = - Gen.oneOf( - lowerIdent.map(Pattern.StrPart.LitStr(_)), - bindIdentGen.map(Pattern.StrPart.NamedStr(_)), - bindIdentGen.map(Pattern.StrPart.NamedChar(_)), - Gen.const(Pattern.StrPart.WildStr), - Gen.const(Pattern.StrPart.WildChar)) - - def isWild(p: Pattern.StrPart): Boolean = - p match { - case Pattern.StrPart.LitStr(_) | - Pattern.StrPart.NamedChar(_) | - Pattern.StrPart.WildChar => false - case _ => true - } + val genPart: Gen[Pattern.StrPart] = + Gen.oneOf( + lowerIdent.map(Pattern.StrPart.LitStr(_)), + bindIdentGen.map(Pattern.StrPart.NamedStr(_)), + bindIdentGen.map(Pattern.StrPart.NamedChar(_)), + Gen.const(Pattern.StrPart.WildStr), + Gen.const(Pattern.StrPart.WildChar) + ) - def makeValid(nel: NonEmptyList[Pattern.StrPart]): NonEmptyList[Pattern.StrPart] = - nel match { - case NonEmptyList(_, Nil) => nel - case NonEmptyList(h1, h2 :: t) if isWild(h1) && isWild(h2) => - makeValid(NonEmptyList(h2, t)) - case NonEmptyList(Pattern.StrPart.LitStr(h1), Pattern.StrPart.LitStr(h2) :: t) => - makeValid(NonEmptyList(Pattern.StrPart.LitStr(h1 + h2), t)) - case NonEmptyList(h1, h2 :: t) => - NonEmptyList(h1, makeValid(NonEmptyList(h2, t)).toList) - } + def isWild(p: Pattern.StrPart): Boolean = + p match { + case Pattern.StrPart.LitStr(_) | Pattern.StrPart.NamedChar(_) | + Pattern.StrPart.WildChar => + false + case _ => true + } - for { - sz <- Gen.choose(1, 4) // don't get too giant, intersections blow up - inner <- nonEmptyN(genPart, sz) - p0 = Pattern.StrPat(makeValid(inner)) - notStr <- p0.toLiteralString.fold(Gen.const(p0))(_ => recurse) - } yield notStr - } + def makeValid( + nel: NonEmptyList[Pattern.StrPart] + ): NonEmptyList[Pattern.StrPart] = + nel match { + case NonEmptyList(_, Nil) => nel + case NonEmptyList(h1, h2 :: t) if isWild(h1) && isWild(h2) => + makeValid(NonEmptyList(h2, t)) + case NonEmptyList( + Pattern.StrPart.LitStr(h1), + Pattern.StrPart.LitStr(h2) :: t + ) => + makeValid(NonEmptyList(Pattern.StrPart.LitStr(h1 + h2), t)) + case NonEmptyList(h1, h2 :: t) => + NonEmptyList(h1, makeValid(NonEmptyList(h2, t)).toList) + } + + for { + sz <- Gen.choose(1, 4) // don't get too giant, intersections blow up + inner <- nonEmptyN(genPart, sz) + p0 = Pattern.StrPat(makeValid(inner)) + notStr <- p0.toLiteralString.fold(Gen.const(p0))(_ => recurse) + } yield notStr + } - def genPatternGen[N, T](genName: List[Pattern[N, T]] => Gen[N], genT: Gen[T], depth: Int, useUnion: Boolean, useAnnotation: Boolean): Gen[Pattern[N, T]] = { - val recurse = Gen.lzy(genPatternGen(genName, genT, depth - 1, useUnion, useAnnotation)) + def genPatternGen[N, T]( + genName: List[Pattern[N, T]] => Gen[N], + genT: Gen[T], + depth: Int, + useUnion: Boolean, + useAnnotation: Boolean + ): Gen[Pattern[N, T]] = { + val recurse = + Gen.lzy(genPatternGen(genName, genT, depth - 1, useUnion, useAnnotation)) val genVar = bindIdentGen.map(Pattern.Var(_)) val genWild = Gen.const(Pattern.WildCard) val genLitPat = genLit.map(Pattern.Literal(_)) if (depth <= 0) Gen.oneOf(genVar, genWild, genLitPat) else { - val genNamed = Gen.zip(bindIdentGen, recurse).map { case (n, p) => Pattern.Named(n, p) } - val genTyped = Gen.zip(recurse, genT) + val genNamed = Gen.zip(bindIdentGen, recurse).map { case (n, p) => + Pattern.Named(n, p) + } + val genTyped = Gen + .zip(recurse, genT) .map { case (p, t) => Pattern.Annotation(p, t) } - - val genStruct = for { + val genStruct = for { cnt <- Gen.choose(0, 6) args <- Gen.listOfN(cnt, recurse) nm <- genName(args) @@ -505,38 +613,50 @@ object Generators { def makeOneSplice(ps: List[Pattern.ListPart[Pattern[N, T]]]) = { val sz = ps.size if (sz == 0) Gen.const(ps) - else Gen.choose(0, sz - 1).flatMap { idx => - val splice = Gen.oneOf( - Gen.const(Pattern.ListPart.WildList), - bindIdentGen.map { v => Pattern.ListPart.NamedList(v) }) - - splice.map { v => ps.updated(idx, v) } - } + else + Gen.choose(0, sz - 1).flatMap { idx => + val splice = Gen.oneOf( + Gen.const(Pattern.ListPart.WildList), + bindIdentGen.map { v => Pattern.ListPart.NamedList(v) } + ) + + splice.map { v => ps.updated(idx, v) } + } } val genListItem: Gen[Pattern.ListPart[Pattern[N, T]]] = recurse.map(Pattern.ListPart.Item(_)) - val genList = Gen.choose(0, 5) + val genList = Gen + .choose(0, 5) .flatMap(Gen.listOfN(_, genListItem)) .flatMap { ls => - Gen.oneOf(true, false) + Gen + .oneOf(true, false) .flatMap { - case true => Gen.const(ls) + case true => Gen.const(ls) case false => makeOneSplice(ls) } } .map(Pattern.ListPat(_)) - val genUnion = Gen.choose(0, 2) + val genUnion = Gen + .choose(0, 2) .flatMap { sz => Gen.zip(recurse, recurse, Gen.listOfN(sz, recurse)) } - .map { - case (h0, h1, tail) => - Pattern.union(h0, h1 :: tail) + .map { case (h0, h1, tail) => + Pattern.union(h0, h1 :: tail) } val tailGens: List[Gen[Pattern[N, T]]] = - List(genVar, genWild, genNamed, genStrPat, genLitPat, genStruct, genList) + List( + genVar, + genWild, + genNamed, + genStrPat, + genLitPat, + genStruct, + genList + ) val withU = if (useUnion) genUnion :: tailGens else tailGens val withT = (if (useAnnotation) genTyped :: withU else withU).toArray @@ -545,20 +665,33 @@ object Generators { } } - def genCompiledPattern(depth: Int, useUnion: Boolean = true, useAnnotation: Boolean = true): Gen[Pattern[(PackageName, Identifier.Constructor), rankn.Type]] = + def genCompiledPattern( + depth: Int, + useUnion: Boolean = true, + useAnnotation: Boolean = true + ): Gen[Pattern[(PackageName, Identifier.Constructor), rankn.Type]] = genPatternGen( - { (_: List[Pattern[(PackageName, Identifier.Constructor), rankn.Type]]) => Gen.zip(packageNameGen, consIdentGen) }, - NTypeGen.genDepth03, depth, useUnion = useUnion, useAnnotation = useAnnotation) + { (_: List[Pattern[(PackageName, Identifier.Constructor), rankn.Type]]) => + Gen.zip(packageNameGen, consIdentGen) + }, + NTypeGen.genDepth03, + depth, + useUnion = useUnion, + useAnnotation = useAnnotation + ) - def matchGen(argGen0: Gen[NonBinding], bodyGen: Gen[Declaration]): Gen[Declaration.Match] = { + def matchGen( + argGen0: Gen[NonBinding], + bodyGen: Gen[Declaration] + ): Gen[Declaration.Match] = { import Declaration._ val padBody = optIndent(bodyGen) // args can't have raw annotations: val argGen = argGen0.map { - case ann@Annotation(_, _) => Parens(ann)(emptyRegion) - case notAnn => notAnn + case ann @ Annotation(_, _) => Parens(ann)(emptyRegion) + case notAnn => notAnn } val genCase: Gen[(Pattern.Parsed, OptIndent[Declaration])] = @@ -566,7 +699,10 @@ object Generators { for { cnt <- Gen.choose(1, 2) - kind <- Gen.frequency((10, Gen.const(RecursionKind.NonRecursive)), (1, Gen.const(RecursionKind.Recursive))) + kind <- Gen.frequency( + (10, Gen.const(RecursionKind.NonRecursive)), + (1, Gen.const(RecursionKind.Recursive)) + ) expr <- argGen cases <- optIndent(nonEmptyN(genCase, cnt)) } yield Match(kind, expr, cases)(emptyRegion) @@ -578,7 +714,9 @@ object Generators { val fixa = a match { // matches binds tighter than all these - case Lambda(_, _) | IfElse(_, _) | ApplyOp(_, _, _) | Match(_, _, _) | Ternary(_, _, _) => Parens(a)(emptyRegion) + case Lambda(_, _) | IfElse(_, _) | ApplyOp(_, _, _) | Match(_, _, _) | + Ternary(_, _, _) => + Parens(a)(emptyRegion) case _ => a } Matches(fixa, p)(emptyRegion) @@ -586,14 +724,15 @@ object Generators { val genLit: Gen[Lit] = { val str = for { - //q <- Gen.oneOf('\'', '"') - //str <- Arbitrary.arbitrary[String] + // q <- Gen.oneOf('\'', '"') + // str <- Arbitrary.arbitrary[String] str <- lowerIdent // TODO } yield Lit.Str(str) val char = Gen.choose(0, 0xd7ff).map { i => Lit.Chr.fromCodePoint(i) } - val bi = Arbitrary.arbitrary[BigInt].map { bi => Lit.Integer(bi.bigInteger) } + val bi = + Arbitrary.arbitrary[BigInt].map { bi => Lit.Integer(bi.bigInteger) } Gen.oneOf(str, bi, char) } @@ -610,19 +749,20 @@ object Generators { Gen.frequency( (1, consDeclGen), (2, varGen), - (1, genLit.map(Declaration.Literal(_)(emptyRegion)))) + (1, genLit.map(Declaration.Literal(_)(emptyRegion))) + ) def annGen(g: Gen[NonBinding]): Gen[Declaration.Annotation] = { import Declaration._ Gen.zip(typeRefGen, g).map { - case (t, r@(Var(_) | Apply(_, _, _) | Parens(_))) => Annotation(r, t)(emptyRegion) + case (t, r @ (Var(_) | Apply(_, _, _) | Parens(_))) => + Annotation(r, t)(emptyRegion) case (t, wrap) => Annotation(Parens(wrap)(emptyRegion), t)(emptyRegion) } } - /** - * Generate a Declaration that can be parsed as a pattern - */ + /** Generate a Declaration that can be parsed as a pattern + */ def patternDecl(depth: Int): Gen[NonBinding] = { import Declaration._ val recur = Gen.lzy(patternDecl(depth - 1)) @@ -630,12 +770,14 @@ object Generators { val applyCons = applyGen(consDeclGen, recur, Gen.const(false)) if (depth <= 0) unnestedDeclGen - else Gen.frequency( - (12, unnestedDeclGen), - (2, applyCons), - (1, recur.map(Parens(_)(emptyRegion))), - (1, annGen(recur)), - (1, genListLangCons(varGen, recur).map(ListDecl(_)(emptyRegion)))) + else + Gen.frequency( + (12, unnestedDeclGen), + (2, applyCons), + (1, recur.map(Parens(_)(emptyRegion))), + (1, annGen(recur)), + (1, genListLangCons(varGen, recur).map(ListDecl(_)(emptyRegion))) + ) } def simpleDecl(depth: Int): Gen[NonBinding] = { @@ -645,33 +787,45 @@ object Generators { val recur = Gen.lzy(simpleDecl(depth - 1)) if (depth <= 0) unnested - else Gen.frequency( - (13, unnested), - (2, lambdaGen(recur)), - (2, applyGen(recur)), - (1, applyOpGen(recur)), - (1, genStringDecl(recur)), - (1, listGen(recur)), - (1, dictGen(recur)), - (1, annGen(recur)), - (1, Gen.choose(0, 4).flatMap(Gen.listOfN(_, recur)).map(TupleCons(_)(emptyRegion))) - ) + else + Gen.frequency( + (13, unnested), + (2, lambdaGen(recur)), + (2, applyGen(recur)), + (1, applyOpGen(recur)), + (1, genStringDecl(recur)), + (1, listGen(recur)), + (1, dictGen(recur)), + (1, annGen(recur)), + ( + 1, + Gen + .choose(0, 4) + .flatMap(Gen.listOfN(_, recur)) + .map(TupleCons(_)(emptyRegion)) + ) + ) } def genRecordArg(dgen: Gen[NonBinding]): Gen[Declaration.RecordArg] = - Gen.zip(bindIdentGen, Gen.option(dgen)) + Gen + .zip(bindIdentGen, Gen.option(dgen)) .map { - case (b, None) => Declaration.RecordArg.Simple(b) + case (b, None) => Declaration.RecordArg.Simple(b) case (b, Some(decl)) => Declaration.RecordArg.Pair(b, decl) } - def genRecordDeclaration(dgen: Gen[NonBinding]): Gen[Declaration.RecordConstructor] = { + def genRecordDeclaration( + dgen: Gen[NonBinding] + ): Gen[Declaration.RecordConstructor] = { val args = for { tailSize <- Gen.choose(0, 4) args <- nonEmptyN(genRecordArg(dgen), tailSize) } yield args - Gen.zip(consIdentGen, args).map { case (c, a) => Declaration.RecordConstructor(c, a)(emptyRegion) } + Gen.zip(consIdentGen, args).map { case (c, a) => + Declaration.RecordConstructor(c, a)(emptyRegion) + } } def genNonBinding(depth: Int): Gen[NonBinding] = { @@ -682,28 +836,37 @@ object Generators { val recur = Gen.lzy(genDeclaration(depth - 1)) val recNon = Gen.lzy(genNonBinding(depth - 1)) if (depth <= 0) unnested - else Gen.frequency( - (14, unnested), - (2, lambdaGen(recNon)), - (2, applyGen(recNon)), - (1, applyOpGen(simpleDecl(depth - 1))), - (1, ifElseGen(recNon, recur)), - (1, ternaryGen(recNon)), - (1, genStringDecl(recNon)), - (1, listGen(recNon)), - (1, dictGen(recNon)), - (1, matchGen(recNon, recur)), - (1, matchesGen(recNon)), - (1, Gen.choose(0, 4).flatMap(Gen.listOfN(_, recNon)).map(TupleCons(_)(emptyRegion))), - (1, genRecordDeclaration(recNon)) - ) + else + Gen.frequency( + (14, unnested), + (2, lambdaGen(recNon)), + (2, applyGen(recNon)), + (1, applyOpGen(simpleDecl(depth - 1))), + (1, ifElseGen(recNon, recur)), + (1, ternaryGen(recNon)), + (1, genStringDecl(recNon)), + (1, listGen(recNon)), + (1, dictGen(recNon)), + (1, matchGen(recNon, recur)), + (1, matchesGen(recNon)), + ( + 1, + Gen + .choose(0, 4) + .flatMap(Gen.listOfN(_, recNon)) + .map(TupleCons(_)(emptyRegion)) + ), + (1, genRecordDeclaration(recNon)) + ) } def makeComment(c: CommentStatement[Padding[Declaration]]): Declaration = { import Declaration._ c.on.padded match { case nb: NonBinding => - CommentNB(CommentStatement(c.message, Padding(c.on.lines, nb)))(emptyRegion) + CommentNB(CommentStatement(c.message, Padding(c.on.lines, nb)))( + emptyRegion + ) case _ => Comment(c)(emptyRegion) } @@ -715,18 +878,29 @@ object Generators { val unnested = unnestedDeclGen val pat: Gen[Pattern.Parsed] = bindIdentGen.map(Pattern.Var(_)) - //val pat = genPattern(0) + // val pat = genPattern(0) val recur = Gen.lzy(genDeclaration(depth - 1)) val recNon = Gen.lzy(genNonBinding(depth - 1)) if (depth <= 0) unnested - else Gen.frequency( - (3, genNonBinding(depth)), - (1, commentGen(padding(recur, 1)).map(makeComment)), // make sure we have 1 space to prevent comments following each other - (1, defGen(Gen.zip(optIndent(recur), padding(recur, 1))).map(DefFn(_)(emptyRegion))), - (1, bindGen(pat, recNon, padding(recur, 1)).map(Binding(_)(emptyRegion))), - (1, leftApplyGen(pat, recNon, recur)) - ) + else + Gen.frequency( + (3, genNonBinding(depth)), + ( + 1, + commentGen(padding(recur, 1)).map(makeComment) + ), // make sure we have 1 space to prevent comments following each other + ( + 1, + defGen(Gen.zip(optIndent(recur), padding(recur, 1))) + .map(DefFn(_)(emptyRegion)) + ), + ( + 1, + bindGen(pat, recNon, padding(recur, 1)).map(Binding(_)(emptyRegion)) + ), + (1, leftApplyGen(pat, recNon, recur)) + ) } implicit val shrinkDecl: Shrink[Declaration] = @@ -739,7 +913,7 @@ object Generators { case Apply(fn, args, _) => val next = fn #:: args.toList.toStream next.flatMap(apply _) - case ao@ApplyOp(left, _, right) => + case ao @ ApplyOp(left, _, right) => left #:: ao.opVar #:: right #:: Stream.empty case Binding(b) => val next = b.value #:: b.in.padded #:: Stream.empty @@ -757,27 +931,30 @@ object Generators { // todo, we should really interleave shrinking r and b r #:: b.padded #:: Stream.empty case Match(_, _, args) => - args.get.toList.toStream.flatMap { - case (_, decl) => decl.get #:: apply(decl.get) + args.get.toList.toStream.flatMap { case (_, decl) => + decl.get #:: apply(decl.get) } case Matches(a, _) => a #:: apply(a) // the rest can't be shrunk - case Comment(c) => c.on.padded #:: Stream.empty - case CommentNB(c) => c.on.padded #:: Stream.empty + case Comment(c) => c.on.padded #:: Stream.empty + case CommentNB(c) => c.on.padded #:: Stream.empty case Lambda(_, body) => body #:: Stream.empty - case Literal(_) => Stream.empty - case Parens(_) => + case Literal(_) => Stream.empty + case Parens(_) => // by removing parens we can make invalid // expressions Stream.empty case TupleCons(Nil) => Stream.empty - case TupleCons(h :: tail) => h #:: TupleCons(tail)(emptyRegion) #:: apply(TupleCons(tail)(emptyRegion)) + case TupleCons(h :: tail) => + h #:: TupleCons(tail)(emptyRegion) #:: apply( + TupleCons(tail)(emptyRegion) + ) case Var(_) => Stream.empty case StringDecl(parts) => parts.toList.toStream.map { - case StringDecl.StrExpr(nb) => nb - case StringDecl.CharExpr(nb) => nb + case StringDecl.StrExpr(nb) => nb + case StringDecl.CharExpr(nb) => nb case StringDecl.Literal(r, str) => Literal(Lit.Str(str))(r) } case ListDecl(ListLang.Cons(items)) => @@ -792,30 +969,38 @@ object Generators { def head: Stream[Declaration] = args.head match { case RecordArg.Pair(n, d) => Stream(Var(n)(emptyRegion), d) - case RecordArg.Simple(n) => Stream(Var(n)(emptyRegion)) + case RecordArg.Simple(n) => Stream(Var(n)(emptyRegion)) } - def tailStream(of: NonEmptyList[RecordArg]): Stream[NonEmptyList[RecordArg]] = + def tailStream( + of: NonEmptyList[RecordArg] + ): Stream[NonEmptyList[RecordArg]] = NonEmptyList.fromList(of.tail) match { case None => Stream.empty case Some(tailArgs) => - tailArgs #:: tailStream(tailArgs) #::: tailStream(NonEmptyList(of.head, tailArgs.tail)) + tailArgs #:: tailStream(tailArgs) #::: tailStream( + NonEmptyList(of.head, tailArgs.tail) + ) } Var(n)(emptyRegion) #:: head #::: - tailStream(args).map(RecordConstructor(n, _)(emptyRegion): Declaration) // type annotation for scala 2.11 + tailStream(args).map( + RecordConstructor(n, _)(emptyRegion): Declaration + ) // type annotation for scala 2.11 } }) def interleave[A](s1: Stream[A], s2: Stream[A]): Stream[A] = - if (s1.isEmpty) s2 else if (s2.isEmpty) s1 else { + if (s1.isEmpty) s2 + else if (s2.isEmpty) s1 + else { s1.head #:: interleave(s2, s1.tail) } def interleaveAll[A](ss: List[Stream[A]]): Stream[A] = ss match { - case Nil => Stream.empty + case Nil => Stream.empty case one :: Nil => one case twoOrMore => val (l, r) = twoOrMore.splitAt(twoOrMore.size / 2) @@ -825,7 +1010,8 @@ object Generators { def shrinkOne[A: Shrink](list: List[A]): Stream[List[A]] = interleaveAll((0 until list.size).toList.map { idx => val aIdx = list(idx) - implicitly[Shrink[A]].shrink(aIdx) + implicitly[Shrink[A]] + .shrink(aIdx) .map { a => list.updated(idx, a) } @@ -836,49 +1022,55 @@ object Generators { case Nil | _ :: Nil => Stream.empty case twoOrMore => (0 until twoOrMore.size).toStream.map { idx => - list.take(idx) ::: list.drop(idx + 1) + list.take(idx) ::: list.drop(idx + 1) } } implicit def shrinkPattern[N, T]: Shrink[Pattern[N, T]] = { lazy val res: Shrink[Pattern[N, T]] = Shrink(new Function1[Pattern[N, T], Stream[Pattern[N, T]]] { - def apply(p: Pattern[N, T]) = - p match { - case Pattern.WildCard => Stream.empty - case Pattern.Var(_) => Pattern.WildCard #:: Stream.empty - case Pattern.Annotation(pattern, _) => pattern #:: Stream.empty - case Pattern.Named(_, pat) => pat #:: Stream.empty - case Pattern.PositionalStruct(n, params) => - // shrink all the params - shrinkOne(params)(res).map(Pattern.PositionalStruct(n, _)) - case Pattern.Literal(Lit.Str(s)) => - implicitly[Shrink[String]].shrink(s).map(s => Pattern.Literal(Lit(s))) - case Pattern.Literal(Lit.Integer(s)) => - implicitly[Shrink[BigInt]].shrink(BigInt(s)) - .map(s => Pattern.Literal(Lit.Integer(s.bigInteger))) - case Pattern.Literal(_) => Stream.empty - case Pattern.ListPat(ls) => - if (ls.isEmpty) Stream.empty - else (Pattern.ListPat(ls.tail) #:: Stream.empty) - case Pattern.StrPat(ls) => - if (ls.tail.isEmpty) Stream.empty - else (Pattern.StrPat(NonEmptyList.fromListUnsafe(ls.tail)) #:: Stream.empty) - case u@Pattern.Union(_, _) => - val flat = Pattern.flatten(u).toList - val sameLen = - shrinkOne[Pattern[N, T]](flat)(res) - .map { us => - Pattern.union(us.head, us.tail) - } - // unions have 2 or more, so this won't throw - val oneLess = dropItemList(flat).map { smaller => - Pattern.union(smaller.head, smaller.tail) - } + def apply(p: Pattern[N, T]) = + p match { + case Pattern.WildCard => Stream.empty + case Pattern.Var(_) => Pattern.WildCard #:: Stream.empty + case Pattern.Annotation(pattern, _) => pattern #:: Stream.empty + case Pattern.Named(_, pat) => pat #:: Stream.empty + case Pattern.PositionalStruct(n, params) => + // shrink all the params + shrinkOne(params)(res).map(Pattern.PositionalStruct(n, _)) + case Pattern.Literal(Lit.Str(s)) => + implicitly[Shrink[String]] + .shrink(s) + .map(s => Pattern.Literal(Lit(s))) + case Pattern.Literal(Lit.Integer(s)) => + implicitly[Shrink[BigInt]] + .shrink(BigInt(s)) + .map(s => Pattern.Literal(Lit.Integer(s.bigInteger))) + case Pattern.Literal(_) => Stream.empty + case Pattern.ListPat(ls) => + if (ls.isEmpty) Stream.empty + else (Pattern.ListPat(ls.tail) #:: Stream.empty) + case Pattern.StrPat(ls) => + if (ls.tail.isEmpty) Stream.empty + else + (Pattern.StrPat( + NonEmptyList.fromListUnsafe(ls.tail) + ) #:: Stream.empty) + case u @ Pattern.Union(_, _) => + val flat = Pattern.flatten(u).toList + val sameLen = + shrinkOne[Pattern[N, T]](flat)(res) + .map { us => + Pattern.union(us.head, us.tail) + } + // unions have 2 or more, so this won't throw + val oneLess = dropItemList(flat).map { smaller => + Pattern.union(smaller.head, smaller.tail) + } - interleave(oneLess, sameLen) - } - }) + interleave(oneLess, sameLen) + } + }) res } @@ -888,13 +1080,14 @@ object Generators { import Statement._ def apply(s: Statement): Stream[Statement] = s match { - case Bind(bs@BindingStatement(_, d, _)) => + case Bind(bs @ BindingStatement(_, d, _)) => shrinkDecl.shrink(d).collect { case sd: NonBinding => Bind(bs.copy(value = sd))(emptyRegion) } case Def(ds) => val body = ds.result - body.traverse(shrinkDecl.shrink(_)) + body + .traverse(shrinkDecl.shrink(_)) .map { bod => Def(ds.copy(result = bod))(emptyRegion) } @@ -902,17 +1095,21 @@ object Generators { } }) - val constructorGen: Gen[(Identifier.Constructor, List[(Identifier.Bindable, Option[TypeRef])])] = + val constructorGen: Gen[ + (Identifier.Constructor, List[(Identifier.Bindable, Option[TypeRef])]) + ] = for { name <- consIdentGen args <- smallList(argGen) } yield (name, args) val genTypeArgs: Gen[List[(TypeRef.TypeVar, Option[Kind.Arg])]] = - smallList(Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKindArg))).map(_.distinctBy(_._1)) + smallList(Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKindArg))) + .map(_.distinctBy(_._1)) val genStruct: Gen[Statement] = - Gen.zip(constructorGen, genTypeArgs) + Gen + .zip(constructorGen, genTypeArgs) .map { case ((name, args), ta) => Statement.Struct(name, NonEmptyList.fromList(ta), args)(emptyRegion) } @@ -946,14 +1143,22 @@ object Generators { // TODO make more powerful val pat: Gen[Pattern.Parsed] = genPattern(1) Gen.frequency( - (1, bindGen(pat, nonB, Gen.const(())).map(Statement.Bind(_)(emptyRegion))), + ( + 1, + bindGen(pat, nonB, Gen.const(())).map(Statement.Bind(_)(emptyRegion)) + ), (1, commentGen(Gen.const(())).map(Statement.Comment(_)(emptyRegion))), (1, defGen(optIndent(decl)).map(Statement.Def(_)(emptyRegion))), (1, genStruct), (1, genExternalStruct), (1, genExternalDef), (1, genEnum), - (1, padding(Gen.const(()), 1).map(Statement.PaddingStatement(_)(emptyRegion)))) + ( + 1, + padding(Gen.const(()), 1) + .map(Statement.PaddingStatement(_)(emptyRegion)) + ) + ) } def genStatements(depth: Int, maxLength: Int): Gen[List[Statement]] = { @@ -965,12 +1170,22 @@ object Generators { */ def combineDuplicates(stmts: List[Statement]): List[Statement] = stmts match { - case Nil => Nil + case Nil => Nil case h :: Nil => h :: Nil - case PaddingStatement(Padding(a, _)) :: PaddingStatement(Padding(b, _)) :: rest => - combineDuplicates(PaddingStatement(Padding(a + b, ()))(emptyRegion) :: rest) - case Comment(CommentStatement(lines1, _)) :: Comment(CommentStatement(lines2, _)) :: rest => - combineDuplicates(Comment(CommentStatement(lines1 ::: lines2, ()))(emptyRegion) :: rest) + case PaddingStatement(Padding(a, _)) :: PaddingStatement( + Padding(b, _) + ) :: rest => + combineDuplicates( + PaddingStatement(Padding(a + b, ()))(emptyRegion) :: rest + ) + case Comment(CommentStatement(lines1, _)) :: Comment( + CommentStatement(lines2, _) + ) :: rest => + combineDuplicates( + Comment(CommentStatement(lines1 ::: lines2, ()))( + emptyRegion + ) :: rest + ) case h1 :: rest => h1 :: combineDuplicates(rest) } @@ -1007,24 +1222,27 @@ object Generators { Gen.oneOf( bindIdentGen.map(ExportedName.Binding(_, ())), consIdentGen.map(ExportedName.TypeName(_, ())), - consIdentGen.map(ExportedName.Constructor(_, ()))) + consIdentGen.map(ExportedName.Constructor(_, ())) + ) def smallList[A](g: Gen[A]): Gen[List[A]] = - Gen.choose(0, 8).flatMap(Gen.listOfN(_, g)) def smallNonEmptyList[A](g: Gen[A], maxLen: Int): Gen[NonEmptyList[A]] = // bias to small numbers - Gen.geometric(2.0) + Gen + .geometric(2.0) .flatMap { case n if n <= 0 => g.map(NonEmptyList.one) case n => - Gen.zip(g, Gen.listOfN((n - 1) min (maxLen - 1), g)) + Gen + .zip(g, Gen.listOfN((n - 1) min (maxLen - 1), g)) .map { case (h, t) => NonEmptyList(h, t) } } def smallDistinctByList[A, B](g: Gen[A])(fn: A => B): Gen[List[A]] = - Gen.choose(0, 8) + Gen + .choose(0, 8) .flatMap(Gen.listOfN(_, g)) .map(graph.Tree.distinctBy(_)(fn)) @@ -1037,8 +1255,11 @@ object Generators { body <- genStatements(depth, 10) } yield Package(p, imports, exports, body) - - def genDefinedType[A](p: PackageName, inner: Gen[A], genType: Gen[rankn.Type]): Gen[rankn.DefinedType[A]] = + def genDefinedType[A]( + p: PackageName, + inner: Gen[A], + genType: Gen[rankn.Type] + ): Gen[rankn.DefinedType[A]] = for { t <- typeNameGen paramKeys <- smallList(NTypeGen.genBound).map(_.distinct) @@ -1063,13 +1284,16 @@ object Generators { } yield ExportedName.Binding(n, Referant.Value(t)) te.allDefinedTypes match { - case Nil => bind(NTypeGen.genDepth03) + case Nil => bind(NTypeGen.genDepth03) case dts0 => // only make one of each type val dts = dts0.map { dt => (dt.name.ident, dt) }.toMap.values.toList - val b = bind(Gen.oneOf(NTypeGen.genDepth03, Gen.oneOf(dts).map(_.toTypeTyConst))) - val genExpT = Gen.oneOf(dts) + val b = bind( + Gen.oneOf(NTypeGen.genDepth03, Gen.oneOf(dts).map(_.toTypeTyConst)) + ) + val genExpT = Gen + .oneOf(dts) .map { dt => ExportedName.TypeName(dt.name.ident, Referant.DefinedT(dt)) } @@ -1080,7 +1304,10 @@ object Generators { val c = for { dt <- Gen.oneOf(nonEmpty) cf <- Gen.oneOf(dt.constructors) - } yield ExportedName.Constructor(cf.name, Referant.Constructor(dt, cf)) + } yield ExportedName.Constructor( + cf.name, + Referant.Constructor(dt, cf) + ) Gen.oneOf(b, genExpT, c) } } @@ -1089,55 +1316,118 @@ object Generators { val interfaceGen: Gen[Package.Interface] = for { p <- packageNameGen - te <- typeEnvGen(p, Gen.oneOf(Kind.Type.co, Kind.Type.phantom, Kind.Type.contra, Kind.Type.in)) + te <- typeEnvGen( + p, + Gen.oneOf( + Kind.Type.co, + Kind.Type.phantom, + Kind.Type.contra, + Kind.Type.in + ) + ) exs0 <- smallList(exportGen(te)) - exs = exs0.map { ex => (ex.name, ex) }.toMap.values.toList // don't duplicate exported names + exs = exs0 + .map { ex => (ex.name, ex) } + .toMap + .values + .toList // don't duplicate exported names } yield Package(p, Nil, exs, ()) - /** - * This is a totally random, and not well typed expression. - * It is suitable for some tests, but it is not a valid output - * of a typechecking process - */ - def genTypedExpr[A](genTag: Gen[A], depth: Int, typeGen: Gen[rankn.Type]): Gen[TypedExpr[A]] = { + /** This is a totally random, and not well typed expression. It is suitable + * for some tests, but it is not a valid output of a typechecking process + */ + def genTypedExpr[A]( + genTag: Gen[A], + depth: Int, + typeGen: Gen[rankn.Type] + ): Gen[TypedExpr[A]] = { val recurse = Gen.lzy(genTypedExpr(genTag, depth - 1, typeGen)) - val lit = Gen.zip(genLit, NTypeGen.genDepth03, genTag).map { case (l, tpe, tag) => TypedExpr.Literal(l, tpe, tag) } + val lit = Gen.zip(genLit, NTypeGen.genDepth03, genTag).map { + case (l, tpe, tag) => TypedExpr.Literal(l, tpe, tag) + } // only literal doesn't recurse if (depth <= 0) lit else { val genGeneric = - Gen.zip(Generators.nonEmpty(Gen.zip(NTypeGen.genBound, NTypeGen.genKind)), recurse) + Gen + .zip( + Generators.nonEmpty(Gen.zip(NTypeGen.genBound, NTypeGen.genKind)), + recurse + ) .map { case (vs, t) => TypedExpr.forAll(vs, t) } val ann = - Gen.zip(recurse, typeGen) + Gen + .zip(recurse, typeGen) .map { case (te, tpe) => TypedExpr.Annotation(te, tpe) } val lam = - Gen.zip(smallNonEmptyList(Gen.zip(bindIdentGen, typeGen), 8), recurse, genTag) - .map { case (args, res, tag) => TypedExpr.AnnotatedLambda(args, res, tag) } + Gen + .zip( + smallNonEmptyList(Gen.zip(bindIdentGen, typeGen), 8), + recurse, + genTag + ) + .map { case (args, res, tag) => + TypedExpr.AnnotatedLambda(args, res, tag) + } val localGen = - Gen.zip(bindIdentGen, typeGen, genTag) + Gen + .zip(bindIdentGen, typeGen, genTag) .map { case (n, t, tag) => TypedExpr.Local(n, t, tag) } val globalGen = - Gen.zip(packageNameGen, identifierGen, typeGen, genTag) + Gen + .zip(packageNameGen, identifierGen, typeGen, genTag) .map { case (p, n, t, tag) => TypedExpr.Global(p, n, t, tag) } val app = - Gen.zip(recurse, smallNonEmptyList(recurse, 8), typeGen, genTag) - .map { case (fn, args, tpe, tag) => TypedExpr.App(fn, args, tpe, tag) } + Gen + .zip(recurse, smallNonEmptyList(recurse, 8), typeGen, genTag) + .map { case (fn, args, tpe, tag) => + TypedExpr.App(fn, args, tpe, tag) + } val let = - Gen.zip(bindIdentGen, recurse, recurse, Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive), genTag) - .map { case (n, ex, in, rec, tag) => TypedExpr.Let(n, ex, in, rec, tag) } + Gen + .zip( + bindIdentGen, + recurse, + recurse, + Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive), + genTag + ) + .map { case (n, ex, in, rec, tag) => + TypedExpr.Let(n, ex, in, rec, tag) + } val matchGen = - Gen.zip(recurse, Gen.choose(1, 4).flatMap(nonEmptyN(Gen.zip(genCompiledPattern(depth), recurse), _)), genTag) - .map { case (arg, branches, tag) => TypedExpr.Match(arg, branches, tag) } + Gen + .zip( + recurse, + Gen + .choose(1, 4) + .flatMap( + nonEmptyN(Gen.zip(genCompiledPattern(depth), recurse), _) + ), + genTag + ) + .map { case (arg, branches, tag) => + TypedExpr.Match(arg, branches, tag) + } - Gen.oneOf(genGeneric, ann, lam, localGen, globalGen, app, let, lit, matchGen) + Gen.oneOf( + genGeneric, + ann, + lam, + localGen, + globalGen, + app, + let, + lit, + matchGen + ) } } @@ -1154,7 +1444,8 @@ object Generators { def loop(idx: Int): Gen[List[Int]] = if (idx >= size) Gen.const(Nil) else - Gen.zip(Gen.choose(idx, size - 1), loop(idx + 1)) + Gen + .zip(Gen.choose(idx, size - 1), loop(idx + 1)) .map { case (h, tail) => h :: tail } loop(0).map { swaps => @@ -1166,21 +1457,32 @@ object Generators { } ary.toList } - } + } - def genOnePackage[A](genA: Gen[A], existing: Map[PackageName, Package.Typed[A]]): Gen[Package.Typed[A]] = { + def genOnePackage[A]( + genA: Gen[A], + existing: Map[PackageName, Package.Typed[A]] + ): Gen[Package.Typed[A]] = { val genDeps: Gen[Map[PackageName, Package.Typed[A]]] = Gen.frequency( - (5, Gen.const(Map.empty)), // usually have no deps, otherwise the graph gets enormous + ( + 5, + Gen.const(Map.empty) + ), // usually have no deps, otherwise the graph gets enormous (1, shuffle(existing.toList).map(_.take(2).toMap)) ) - def impFromExp(exp: List[(Package.Interface, ExportedName[Referant[Kind.Arg]])]): Gen[List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]]] = - exp.groupBy(_._1) + def impFromExp( + exp: List[(Package.Interface, ExportedName[Referant[Kind.Arg]])] + ): Gen[List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]]] = + exp + .groupBy(_._1) .toList .traverse { case (p, exps) => - val genImps: Gen[List[ImportedName[NonEmptyList[Referant[Kind.Arg]]]]] = - exps.groupBy(_._2.name) + val genImps + : Gen[List[ImportedName[NonEmptyList[Referant[Kind.Arg]]]]] = + exps + .groupBy(_._2.name) .iterator .toList .traverse { case (ident, exps) => @@ -1210,7 +1512,9 @@ object Generators { } } - val genImports: Gen[List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]]] = + val genImports: Gen[ + List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]] + ] = genDeps.flatMap { packs => val exps: List[(Package.Interface, ExportedName[Referant[Kind.Arg]])] = (for { @@ -1225,90 +1529,135 @@ object Generators { } yield imp } - def definedTypesFromImp(i: Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]): List[rankn.Type.Const] = + def definedTypesFromImp( + i: Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]] + ): List[rankn.Type.Const] = i.items.toList.flatMap { in => in.tag.toList.flatMap { - case Referant.DefinedT(dt) => dt.toTypeConst :: Nil + case Referant.DefinedT(dt) => dt.toTypeConst :: Nil case Referant.Constructor(dt, _) => dt.toTypeConst :: Nil - case Referant.Value(_) => Nil + case Referant.Value(_) => Nil } } - def genTypeEnv(pn: PackageName, - imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]]): StateT[Gen, (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable]), Unit] = - StateT.get[Gen, (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable])] - .flatMap { case (te, extDefs) => - StateT.liftF(Gen.choose(0, 9)) - .flatMap { - case 0 => - // 1 in 10 chance of stopping - StateT.pure[Gen, (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable]), Unit](()) - case _ => - // add something: - val tyconsts = - te.allDefinedTypes.map(_.toTypeConst) ++ - imps.flatMap(definedTypesFromImp) - val theseTypes = NTypeGen.genDepth(4, if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts))) - val genV: Gen[Kind.Arg] = - Gen.oneOf(Kind.Type.co, Kind.Type.contra, Kind.Type.in, Kind.Type.phantom) - val genDT = genDefinedType(pn, genV, theseTypes) - val genEx: Gen[(Identifier.Bindable, rankn.Type)] = - Gen.zip(bindIdentGen, theseTypes) - - // we can do one of the following: - // 1: add an external value - // 2: add a defined type - StateT.liftF(Gen.frequency( - (5, genDT.map { dt => (te.addDefinedTypeAndConstructors(dt), extDefs) }), - (1, genEx.map { case (b, t) => (te.addExternalValue(pn, b, t), extDefs + b) }))) - .flatMap(StateT.set(_)) - } - } + def genTypeEnv( + pn: PackageName, + imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]] + ): StateT[Gen, (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable]), Unit] = + StateT + .get[Gen, (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable])] + .flatMap { case (te, extDefs) => + StateT + .liftF(Gen.choose(0, 9)) + .flatMap { + case 0 => + // 1 in 10 chance of stopping + StateT.pure[ + Gen, + (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable]), + Unit + ](()) + case _ => + // add something: + val tyconsts = + te.allDefinedTypes.map(_.toTypeConst) ++ + imps.flatMap(definedTypesFromImp) + val theseTypes = NTypeGen.genDepth( + 4, + if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts)) + ) + val genV: Gen[Kind.Arg] = + Gen.oneOf( + Kind.Type.co, + Kind.Type.contra, + Kind.Type.in, + Kind.Type.phantom + ) + val genDT = genDefinedType(pn, genV, theseTypes) + val genEx: Gen[(Identifier.Bindable, rankn.Type)] = + Gen.zip(bindIdentGen, theseTypes) + + // we can do one of the following: + // 1: add an external value + // 2: add a defined type + StateT + .liftF( + Gen.frequency( + ( + 5, + genDT.map { dt => + (te.addDefinedTypeAndConstructors(dt), extDefs) + } + ), + ( + 1, + genEx.map { case (b, t) => + (te.addExternalValue(pn, b, t), extDefs + b) + } + ) + ) + ) + .flatMap(StateT.set(_)) + } + } - def genLets(te: rankn.TypeEnv[Kind.Arg], - exts: Set[Identifier.Bindable]): Gen[List[(Identifier.Bindable, RecursionKind, TypedExpr[A])]] = { - val allTC = te.allDefinedTypes.map(_.toTypeConst) - val theseTypes = NTypeGen.genDepth(4, if (allTC.isEmpty) None else Some(Gen.oneOf(allTC))) - val oneLet = Gen.zip(bindIdentGen.filter { b => !exts(b) }, - Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive), - genTypedExpr(genA, 4, theseTypes)) + def genLets( + te: rankn.TypeEnv[Kind.Arg], + exts: Set[Identifier.Bindable] + ): Gen[List[(Identifier.Bindable, RecursionKind, TypedExpr[A])]] = { + val allTC = te.allDefinedTypes.map(_.toTypeConst) + val theseTypes = NTypeGen.genDepth( + 4, + if (allTC.isEmpty) None else Some(Gen.oneOf(allTC)) + ) + val oneLet = Gen.zip( + bindIdentGen.filter { b => !exts(b) }, + Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive), + genTypedExpr(genA, 4, theseTypes) + ) - Gen.choose(0, 6).flatMap(Gen.listOfN(_, oneLet)) - } + Gen.choose(0, 6).flatMap(Gen.listOfN(_, oneLet)) + } def genProg( - pn: PackageName, - imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]]): Gen[Program[rankn.TypeEnv[Kind.Arg], TypedExpr[A], Any]] = - genTypeEnv(pn, imps) - .runS((rankn.TypeEnv.empty, Set.empty)) - .flatMap { case (te, b) => - genLets(te, b).map(Program(te, _, b.toList.sorted, ())) - } + pn: PackageName, + imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]] + ): Gen[Program[rankn.TypeEnv[Kind.Arg], TypedExpr[A], Any]] = + genTypeEnv(pn, imps) + .runS((rankn.TypeEnv.empty, Set.empty)) + .flatMap { case (te, b) => + genLets(te, b).map(Program(te, _, b.toList.sorted, ())) + } /* * Exports are types, constructors, or values */ - def genExports(pn: PackageName, p: Program[rankn.TypeEnv[Kind.Arg], TypedExpr[A], Any]): Gen[List[ExportedName[Referant[Kind.Arg]]]] = { + def genExports( + pn: PackageName, + p: Program[rankn.TypeEnv[Kind.Arg], TypedExpr[A], Any] + ): Gen[List[ExportedName[Referant[Kind.Arg]]]] = { def expnames: List[ExportedName[Referant[Kind.Arg]]] = p.lets.map { case (n, _, te) => ExportedName.Binding(n, Referant.Value(te.getType)) } def exts: List[ExportedName[Referant[Kind.Arg]]] = p.externalDefs.flatMap { n => - p.types.getValue(pn, n).map { t => ExportedName.Binding(n, Referant.Value(t)) } + p.types.getValue(pn, n).map { t => + ExportedName.Binding(n, Referant.Value(t)) + } } def cons: List[ExportedName[Referant[Kind.Arg]]] = p.types.allDefinedTypes.flatMap { dt => if (dt.packageName == pn) { - val dtex = ExportedName.TypeName(dt.name.ident, Referant.DefinedT(dt)) + val dtex = + ExportedName.TypeName(dt.name.ident, Referant.DefinedT(dt)) val cons = dt.constructors.map { cf => ExportedName.Constructor(cf.name, Referant.Constructor(dt, cf)) } dtex :: cons - } - else Nil + } else Nil } for { @@ -1327,21 +1676,33 @@ object Generators { } yield Package(pn, imps, exps, prog) } - def genPackagesSt[A](genA: Gen[A], maxSize: Int): StateT[Gen, Map[PackageName, Package.Typed[A]], Unit] = - StateT.get[Gen, Map[PackageName, Package.Typed[A]]] + def genPackagesSt[A]( + genA: Gen[A], + maxSize: Int + ): StateT[Gen, Map[PackageName, Package.Typed[A]], Unit] = + StateT + .get[Gen, Map[PackageName, Package.Typed[A]]] .flatMap { m => if (m.size >= maxSize) StateT.pure(()) else { // make one more and try again for { - p <- StateT.liftF[Gen, Map[PackageName, Package.Typed[A]], Package.Typed[A]](genOnePackage(genA, m)) - _ <- StateT.set[Gen, Map[PackageName, Package.Typed[A]]](m.updated(p.name, p)) + p <- StateT + .liftF[Gen, Map[PackageName, Package.Typed[A]], Package.Typed[A]]( + genOnePackage(genA, m) + ) + _ <- StateT.set[Gen, Map[PackageName, Package.Typed[A]]]( + m.updated(p.name, p) + ) _ <- genPackagesSt(genA, maxSize) } yield () } } - def genPackage[A](genA: Gen[A], maxSize: Int): Gen[Map[PackageName, Package.Typed[A]]] = + def genPackage[A]( + genA: Gen[A], + maxSize: Int + ): Gen[Map[PackageName, Package.Typed[A]]] = genPackagesSt(genA, maxSize).runS(Map.empty) object Exprs { @@ -1352,7 +1713,12 @@ object Generators { Gen.frequency( (1, Gen.zip(genLit, genA).map { case (l, t) => Literal(l, t) }), (1, Gen.zip(bindIdentGen, genA).map { case (b, t) => Local(b, t) }), - (1, Gen.zip(NTypeGen.packageNameGen, identifierGen, genA).map { case (p, i, t) => Global(p, i, t) }) + ( + 1, + Gen.zip(NTypeGen.packageNameGen, identifierGen, genA).map { + case (p, i, t) => Global(p, i, t) + } + ) ) if (depth <= 0) roots @@ -1360,19 +1726,69 @@ object Generators { val recur = Gen.lzy(gen(genA, depth - 1)) Gen.frequency( (1, roots), - (1, Gen.zip(recur, NTypeGen.genDepth03, genA).map { case (e, t, tag) => Annotation(e, t, tag) } ), - (1, Gen.zip(smallNonEmptyList(Gen.zip(NTypeGen.genBound, NTypeGen.genKind), 4), recur).map { case (ts, in) => - Generic(ts, in) - }), - (2, Gen.zip(recur, smallNonEmptyList(recur, 5), genA).map { case (fn, as, t) => App(fn, as, t) }), - (2, Gen.zip(smallNonEmptyList(Gen.zip(bindIdentGen, Gen.option(NTypeGen.genDepth03)), 4), recur, genA).map { case (as, e, t) => - Lambda(as, e, t) - }), - (4, Gen.zip(bindIdentGen, recur, recur, Gen.oneOf(RecursionKind.Recursive, RecursionKind.NonRecursive), genA) - .map { case (a, e, in, r, t) => Let(a, e, in, r, t) }), - (1, Gen.zip(recur, - smallNonEmptyList(Gen.zip(genCompiledPattern(4), recur), 3), - genA).map { case (a, bs, t) => Match(a, bs, t)}) + ( + 1, + Gen.zip(recur, NTypeGen.genDepth03, genA).map { case (e, t, tag) => + Annotation(e, t, tag) + } + ), + ( + 1, + Gen + .zip( + smallNonEmptyList( + Gen.zip(NTypeGen.genBound, NTypeGen.genKind), + 4 + ), + recur + ) + .map { case (ts, in) => + Generic(ts, in) + } + ), + ( + 2, + Gen.zip(recur, smallNonEmptyList(recur, 5), genA).map { + case (fn, as, t) => App(fn, as, t) + } + ), + ( + 2, + Gen + .zip( + smallNonEmptyList( + Gen.zip(bindIdentGen, Gen.option(NTypeGen.genDepth03)), + 4 + ), + recur, + genA + ) + .map { case (as, e, t) => + Lambda(as, e, t) + } + ), + ( + 4, + Gen + .zip( + bindIdentGen, + recur, + recur, + Gen.oneOf(RecursionKind.Recursive, RecursionKind.NonRecursive), + genA + ) + .map { case (a, e, in, r, t) => Let(a, e, in, r, t) } + ), + ( + 1, + Gen + .zip( + recur, + smallNonEmptyList(Gen.zip(genCompiledPattern(4), recur), 3), + genA + ) + .map { case (a, bs, t) => Match(a, bs, t) } + ) ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/GenJson.scala b/core/src/test/scala/org/bykn/bosatsu/GenJson.scala index 1b3a24825..a527929fe 100644 --- a/core/src/test/scala/org/bykn/bosatsu/GenJson.scala +++ b/core/src/test/scala/org/bykn/bosatsu/GenJson.scala @@ -7,33 +7,38 @@ object GenJson { val genJsonNumber: Gen[Json.JNumberStr] = { def cat(gs: List[Gen[String]]): Gen[String] = gs match { - case Nil => Gen.const("") + case Nil => Gen.const("") case h :: tail => Gen.zip(h, cat(tail)).map { case (a, b) => a + b } } val digit09 = Gen.oneOf('0' to '9').map(_.toString) val digit19 = Gen.oneOf('1' to '9').map(_.toString) val digits = Gen.listOf(digit09).map(_.mkString) - val digits1 = Gen.zip(digit09, Gen.listOf(digit09)).map { case (h, t) => (h :: t).mkString } + val digits1 = Gen.zip(digit09, Gen.listOf(digit09)).map { case (h, t) => + (h :: t).mkString + } val int = Gen.frequency( (1, Gen.const("0")), - (20, Gen.zip(digit19, digits).map { case (h, t) => h + t })) + (20, Gen.zip(digit19, digits).map { case (h, t) => h + t }) + ) val frac = digits1.map("." + _) def opt(g: Gen[String]): Gen[String] = Gen.oneOf(true, false).flatMap { - case true => g + case true => g case false => Gen.const("") } val exp = cat(List(Gen.oneOf("e", "E"), opt(Gen.oneOf("+", "-")), digits1)) - cat(List(opt(Gen.const("-")), int, opt(frac), opt(exp))).map(Json.JNumberStr(_)) + cat(List(opt(Gen.const("-")), int, opt(frac), opt(exp))) + .map(Json.JNumberStr(_)) } def genJson(depth: Int): Gen[Json] = { val genString = Gen.listOf(Gen.choose(1.toChar, 127.toChar)).map(_.mkString) val str = genString.map(Json.JString(_)) val nd1 = Arbitrary.arbitrary[Long].map { i => Json.JNumberStr(i.toString) } - val nd2 = Arbitrary.arbitrary[Double].map { d => Json.JNumberStr(d.toString) } + val nd2 = + Arbitrary.arbitrary[Double].map { d => Json.JNumberStr(d.toString) } val nd3 = Arbitrary.arbitrary[Int].map { i => Json.JNumberStr(i.toString) } val b = Gen.oneOf(Json.JBool(true), Json.JBool(false)) @@ -42,9 +47,12 @@ object GenJson { else { val recurse = Gen.lzy(genJson(depth - 1)) val collectionSize = Gen.choose(0, depth * depth) - val ary = collectionSize.flatMap(Gen.listOfN(_, recurse).map { l => Json.JArray(l.toVector) }) + val ary = collectionSize.flatMap( + Gen.listOfN(_, recurse).map { l => Json.JArray(l.toVector) } + ) val map = collectionSize.flatMap { sz => - Gen.listOfN(sz, Gen.zip(genString, recurse)) + Gen + .listOfN(sz, Gen.zip(genString, recurse)) .map { m => Json.JObject(m).normalize } } Gen.frequency((10, d0), (1, ary), (1, map)) @@ -54,16 +62,16 @@ object GenJson { implicit val arbJson: Arbitrary[Json] = Arbitrary(Gen.choose(0, 4).flatMap(genJson(_))) - implicit def shrinkJson( - implicit ss: Shrink[String] + implicit def shrinkJson(implicit + ss: Shrink[String] ): Shrink[Json] = Shrink[Json](new Function1[Json, Stream[Json]] { def apply(j: Json): Stream[Json] = { import Json._ j match { - case JString(str) => ss.shrink(str).map(JString(_)) - case JNumberStr(_) => Stream.empty - case JNull => Stream.empty + case JString(str) => ss.shrink(str).map(JString(_)) + case JNumberStr(_) => Stream.empty + case JNull => Stream.empty case JBool.True | JBool.False => Stream.empty case JArray(js) => (0 until js.size).toStream.map { sz => diff --git a/core/src/test/scala/org/bykn/bosatsu/GenValue.scala b/core/src/test/scala/org/bykn/bosatsu/GenValue.scala index 125636d16..818020d7d 100644 --- a/core/src/test/scala/org/bykn/bosatsu/GenValue.scala +++ b/core/src/test/scala/org/bykn/bosatsu/GenValue.scala @@ -12,7 +12,7 @@ object GenValue { lazy val genProd: Gen[ProductValue] = for { len <- Gen.exponential(0.5) - vs <- Gen.listOfN(len.toInt, genValue) + vs <- Gen.listOfN(len.toInt, genValue) } yield ProductValue.fromList(vs) lazy val genValue: Gen[Value] = { @@ -28,7 +28,8 @@ object GenValue { val genExt: Gen[Value] = Gen.oneOf( Gen.choose(Int.MinValue, Int.MaxValue).map(VInt(_)), - Arbitrary.arbitrary[String].map(Str(_))) + Arbitrary.arbitrary[String].map(Str(_)) + ) val genFn: Gen[FnValue] = { val fn: Gen[NonEmptyList[Value] => Value] = Gen.function1(recur)( diff --git a/core/src/test/scala/org/bykn/bosatsu/IntLaws.scala b/core/src/test/scala/org/bykn/bosatsu/IntLaws.scala index 64995f770..3e44c5f7d 100644 --- a/core/src/test/scala/org/bykn/bosatsu/IntLaws.scala +++ b/core/src/test/scala/org/bykn/bosatsu/IntLaws.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu import java.math.BigInteger import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite object IntLaws { @@ -20,26 +23,56 @@ class IntLaws extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = PropertyCheckConfiguration(minSuccessful = 50000) - //PropertyCheckConfiguration(minSuccessful = 5000) - //PropertyCheckConfiguration(minSuccessful = 500) + // PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 500) val genBI: Gen[BigInteger] = - Gen.choose(-128L, 128L) + Gen + .choose(-128L, 128L) .map(BigInteger.valueOf(_)) test("match python on some examples") { - assert(BigInteger.valueOf(4L) % BigInteger.valueOf(-3L) == BigInteger.valueOf(-2L)) - - assert(BigInteger.valueOf(-8L) % BigInteger.valueOf(-2L) == BigInteger.valueOf(0L)) - assert(BigInteger.valueOf(-8L) / BigInteger.valueOf(-2L) == BigInteger.valueOf(4L)) - - assert(BigInteger.valueOf(-4L) % BigInteger.valueOf(-3L) == BigInteger.valueOf(-1L)) - assert(BigInteger.valueOf(13L) % BigInteger.valueOf(3L) == BigInteger.valueOf(1L)) - assert(BigInteger.valueOf(-113L) / BigInteger.valueOf(16L) == BigInteger.valueOf(-8L)) - - - assert(BigInteger.valueOf(54L) % BigInteger.valueOf(-3L) == BigInteger.valueOf(0L)) - assert(BigInteger.valueOf(54L) / BigInteger.valueOf(-3L) == BigInteger.valueOf(-18L)) + assert( + BigInteger.valueOf(4L) % BigInteger.valueOf(-3L) == BigInteger.valueOf( + -2L + ) + ) + + assert( + BigInteger.valueOf(-8L) % BigInteger.valueOf(-2L) == BigInteger.valueOf( + 0L + ) + ) + assert( + BigInteger.valueOf(-8L) / BigInteger.valueOf(-2L) == BigInteger.valueOf( + 4L + ) + ) + + assert( + BigInteger.valueOf(-4L) % BigInteger.valueOf(-3L) == BigInteger.valueOf( + -1L + ) + ) + assert( + BigInteger.valueOf(13L) % BigInteger.valueOf(3L) == BigInteger.valueOf(1L) + ) + assert( + BigInteger.valueOf(-113L) / BigInteger.valueOf(16L) == BigInteger.valueOf( + -8L + ) + ) + + assert( + BigInteger.valueOf(54L) % BigInteger.valueOf(-3L) == BigInteger.valueOf( + 0L + ) + ) + assert( + BigInteger.valueOf(54L) / BigInteger.valueOf(-3L) == BigInteger.valueOf( + -18L + ) + ) } test("a = (a / b) * b + (a % b)") { @@ -107,7 +140,9 @@ class IntLaws extends AnyFunSuite { test("a / b <= a if b >= 0 and a >= 0") { forAll(genBI, genBI) { (a, b) => - if (b.compareTo(BigInteger.ZERO) >= 0 && a.compareTo(BigInteger.ZERO) >= 0) { + if ( + b.compareTo(BigInteger.ZERO) >= 0 && a.compareTo(BigInteger.ZERO) >= 0 + ) { val div = a / b assert(div.compareTo(a) <= 0, div) } @@ -133,7 +168,7 @@ class IntLaws extends AnyFunSuite { forAll(genBI, genBI) { (a, b) => val mod = a % b if (mod == BigInteger.ZERO) { - assert((a/b)*b == a) + assert((a / b) * b == a) } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/JsonTest.scala b/core/src/test/scala/org/bykn/bosatsu/JsonTest.scala index 2d55427bc..1183e48a6 100644 --- a/core/src/test/scala/org/bykn/bosatsu/JsonTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/JsonTest.scala @@ -3,7 +3,10 @@ package org.bykn.bosatsu import cats.Eq import cats.implicits._ import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import TestUtils.typeEnvOf import rankn.{NTypeGen, Type, TypeEnv} @@ -14,7 +17,9 @@ import org.scalatest.funsuite.AnyFunSuite class JsonTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 1000 else 20) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 1000 else 20 + ) def law(j: Json) = assert(Parser.unsafeParse(Json.parser, j.render) == j) @@ -25,7 +30,10 @@ class JsonTest extends AnyFunSuite { .flatMap { te => val tyconsts = te.allDefinedTypes.map(_.toTypeConst) - val theseTypes = NTypeGen.genDepth(4, if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts))) + val theseTypes = NTypeGen.genDepth( + 4, + if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts)) + ) theseTypes.map((te, _)) } @@ -37,10 +45,19 @@ class JsonTest extends AnyFunSuite { optTE = if (none) None else Some(te) } yield (optTE, tpe) - test("test some example escapes") { - assert(Parser.unsafeParse(JsonStringUtil.escapedToken.string, "\\u0000") == "\\u0000") - assert(Parser.unsafeParse(JsonStringUtil.escapedString('\''), "'\\u0000'") == 0.toChar.toString) + assert( + Parser.unsafeParse( + JsonStringUtil.escapedToken.string, + "\\u0000" + ) == "\\u0000" + ) + assert( + Parser.unsafeParse( + JsonStringUtil.escapedString('\''), + "'\\u0000'" + ) == 0.toChar.toString + ) } test("we can parse all the json we generate") { @@ -51,13 +68,12 @@ class JsonTest extends AnyFunSuite { forAll(genJsonNumber)(law(_)) forAll(genJsonNumber) { num => - val parts = Parser.unsafeParse(Parser.JsonNumber.partsParser, num.asString) + val parts = + Parser.unsafeParse(Parser.JsonNumber.partsParser, num.asString) assert(parts.asString == num.asString) } - val regressions = List( - Json.JNumberStr("2E9"), - Json.JNumberStr("-9E+19")) + val regressions = List(Json.JNumberStr("2E9"), Json.JNumberStr("-9E+19")) regressions.foreach { n => law(n) @@ -68,7 +84,7 @@ class JsonTest extends AnyFunSuite { def law(te: Option[TypeEnv[Any]], t: Type, j: Json) = { val jsonCodec = te match { - case None => ValueToJson(_ => None) + case None => ValueToJson(_ => None) case Some(te) => ValueToJson(te.toDefinedType(_)) } val toJson = jsonCodec.toJson(t) @@ -84,13 +100,17 @@ class JsonTest extends AnyFunSuite { ej1 match { case Right(j1) => assert(Eq[Json].eqv(j1, j), s"$j1 != $j") - case Left(_) => () + case Left(_) => () } } - forAll(optTE, GenJson.arbJson.arbitrary) { case ((ote, tpe), json) => law(ote, tpe, json) } + forAll(optTE, GenJson.arbJson.arbitrary) { case ((ote, tpe), json) => + law(ote, tpe, json) + } - val regressions = List((None, Type.TyApply(Type.OptionType, Type.BoolType), Json.JBool.False)) + val regressions = List( + (None, Type.TyApply(Type.OptionType, Type.BoolType), Json.JBool.False) + ) regressions.foreach { case (te, t, j) => law(te, t, j) } } @@ -114,7 +134,7 @@ class JsonTest extends AnyFunSuite { def law(ote: Option[TypeEnv[Unit]], t: Type, v: Value) = { val jsonCodec = ote match { - case None => ValueToJson(_ => None) + case None => ValueToJson(_ => None) case Some(te) => ValueToJson(te.toDefinedType(_)) } val toJson = jsonCodec.toJson(t) @@ -130,7 +150,7 @@ class JsonTest extends AnyFunSuite { ej1 match { case Right(v1) => assert(v1 == v, s"$v1 != $v") - case Left(_) => () + case Left(_) => () } } @@ -143,7 +163,9 @@ class JsonTest extends AnyFunSuite { } test("some hand written cases round trip") { - val te = typeEnvOf(PackageName.parts("Test"), """ + val te = typeEnvOf( + PackageName.parts("Test"), + """ struct MyUnit # wrappers are removed @@ -153,19 +175,21 @@ struct MyPair(fst, snd) enum MyEither: L(left), R(right) enum MyNat: Z, S(prev: MyNat) -""") +""" + ) val jsonConv = ValueToJson(te.toDefinedType(_)) def stringToType(t: String): Type = { val tr = Parser.unsafeParse(TypeRef.parser, t) TypeRefConverter[cats.Id](tr) { cons => - te.referencedPackages.toList.flatMap { pack => - val const = Type.Const.Defined(pack, TypeName(cons)) - te.toDefinedType(const).map(_ => const) - } - .headOption - .getOrElse(Type.Const.predef(cons.asString)) + te.referencedPackages.toList + .flatMap { pack => + val const = Type.Const.Defined(pack, TypeName(cons)) + te.toDefinedType(const).map(_ => const) + } + .headOption + .getOrElse(Type.Const.predef(cons.asString)) } } @@ -186,9 +210,11 @@ enum MyNat: Z, S(prev: MyNat) case Right(j1) => assert(Eq[Json].eqv(j1, j), s"$j1 != $j") case Left(err) => fail(err.toString) } - case Left(err) => fail(s"could not handle to Json: $tpe, $t, $toV, $err") + case Left(err) => + fail(s"could not handle to Json: $tpe, $t, $toV, $err") } - case Left(err) => fail(s"could not handle to Value: $tpe, $t, $toJ, $err") + case Left(err) => + fail(s"could not handle to Value: $tpe, $t, $toJ, $err") } } @@ -196,7 +222,7 @@ enum MyNat: Z, S(prev: MyNat) val t = stringToType(tpe) jsonConv.supported(t) match { case Right(_) => fail(s"expected $tpe to be unsupported") - case Left(_) => succeed + case Left(_) => succeed } } @@ -210,7 +236,7 @@ enum MyNat: Z, S(prev: MyNat) assert(toJ.isRight) val j = stringToJson(json) toV(j) match { - case Left(_) => succeed + case Left(_) => succeed case Right(v) => fail(s"expected $json to be ill-typed: $v") } case Left(err) => fail(s"could not handle to Value: $tpe, $t, $err") diff --git a/core/src/test/scala/org/bykn/bosatsu/KindFormulaTest.scala b/core/src/test/scala/org/bykn/bosatsu/KindFormulaTest.scala index 214351c87..2db505647 100644 --- a/core/src/test/scala/org/bykn/bosatsu/KindFormulaTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/KindFormulaTest.scala @@ -32,7 +32,10 @@ class KindFormulaTest extends AnyFunSuite { def testKind(teStr: String, shapes: Map[String, String]) = testKindEither(makeTE(teStr), shapes) - def testKindEither(te: Either[Any, TypeEnv[Kind.Arg]], shapes: Map[String, String]) = + def testKindEither( + te: Either[Any, TypeEnv[Kind.Arg]], + shapes: Map[String, String] + ) = te match { case Right(te) => shapes.foreach { case (n, vs) => @@ -206,9 +209,10 @@ struct Leib[a, b](cast: forall f. f[a] -> f[b]) ) ) } - + test("test Applicative example") { - testKind("""# + testKind( + """# # Represents the Applicative typeclass struct Fn[a: -*, b: +*] struct Unit @@ -221,7 +225,9 @@ struct Applicative( map2: forall a, b, c. f[a] -> f[b] -> (a -> b -> c) -> f[c], product: forall a, b. f[a] -> f[b] -> f[(a, b)]) -""", Map("Applicative" -> "(* -> *) -> *")) +""", + Map("Applicative" -> "(* -> *) -> *") + ) } test("linked list is allowed") { diff --git a/core/src/test/scala/org/bykn/bosatsu/KindParseTest.scala b/core/src/test/scala/org/bykn/bosatsu/KindParseTest.scala index c1193b248..ea14fbc14 100644 --- a/core/src/test/scala/org/bykn/bosatsu/KindParseTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/KindParseTest.scala @@ -391,7 +391,7 @@ class KindParseTest extends ParserTestBase { case Some(idx) => assert(Kind.longToKind(idx) == Some(k)) case None => () - } + } } assert(Kind.kindToLong(Kind.Type) == Some(0L)) @@ -403,28 +403,34 @@ class KindParseTest extends ParserTestBase { // small kinds have small codes Kind.allKinds.take(21).foreach { k => // these can all be encoded in 1 byte in proto - assert(Kind.kindToLong(k).get < 0x7fL) + assert(Kind.kindToLong(k).get < 0x7fL) } Kind.allKinds.take(217).foreach { k => // these can all be encoded in 2 byte in proto - assert(Kind.kindToLong(k).get < 0x7fffL) + assert(Kind.kindToLong(k).get < 0x7fffL) } } test("interleave and uninterleave -> inverses") { forAll { (l: Long) => - val res = Kind.uninterleave(l) + val res = Kind.uninterleave(l) val high = (res >>> 32).toInt val low = (res & 0xffffffffL).toInt - assert(Kind.interleave(high, low) == l, s"res = $res low = $low, high = $high") + assert( + Kind.interleave(high, low) == l, + s"res = $res low = $low, high = $high" + ) } forAll { (low: Int, high: Int) => val long = Kind.interleave(high, low) - val res = Kind.uninterleave(long) + val res = Kind.uninterleave(long) val high1 = (res >>> 32).toInt val low1 = (res & 0xffffffffL).toInt - assert((high, low) == (high1, low1), s"interleave($low, $high) = $long uninterleave($long) = $res") + assert( + (high, low) == (high1, low1), + s"interleave($low, $high) = $long uninterleave($long) = $res" + ) } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/ListUtilTest.scala b/core/src/test/scala/org/bykn/bosatsu/ListUtilTest.scala index d9091f177..2483438b2 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ListUtilTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ListUtilTest.scala @@ -1,7 +1,10 @@ package org.bykn.bosatsu import cats.data.NonEmptyList -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite import org.scalacheck.{Arbitrary, Gen} @@ -12,25 +15,30 @@ class ListUtilTest extends AnyFunSuite { def genNEL[A](ga: Gen[A]): Gen[NonEmptyList[A]] = Gen.sized { sz => - if (sz <= 1) ga.map(NonEmptyList.one) - else Gen.zip(ga, Gen.listOfN(sz - 1, ga)).map { case (h, t) => NonEmptyList(h, t) } - } + if (sz <= 1) ga.map(NonEmptyList.one) + else + Gen.zip(ga, Gen.listOfN(sz - 1, ga)).map { case (h, t) => + NonEmptyList(h, t) + } + } implicit def arbNEL[A: Arbitrary]: Arbitrary[NonEmptyList[A]] = Arbitrary(genNEL(Arbitrary.arbitrary[A])) test("unit group has 1 item") { forAll { (nel: NonEmptyList[Int]) => - val unit = ListUtil.greedyGroup(nel)(_ => ())((_, _) => Some(())) + val unit = ListUtil.greedyGroup(nel)(_ => ())((_, _) => Some(())) assert(unit == NonEmptyList.one(())) } } test("groups satisfy edge property") { forAll { (nel: NonEmptyList[Int], accept: (Int, Int) => Boolean) => - val groups = ListUtil.greedyGroup(nel)(a => Set(a))((s, i) => if (s.forall(accept(_, i))) Some(s + i) else None) + val groups = ListUtil.greedyGroup(nel)(a => Set(a))((s, i) => + if (s.forall(accept(_, i))) Some(s + i) else None + ) groups.toList.foreach { g => - val items = g.toList.zipWithIndex + val items = g.toList.zipWithIndex for { (i1, idx1) <- items (i2, idx2) <- items @@ -40,9 +48,14 @@ class ListUtilTest extends AnyFunSuite { } test("there are as most as many groups as inputs") { - forAll { (nel: NonEmptyList[Int], one: Int => Int, accept: (Int, Int) => Option[Int]) => - val groups = ListUtil.greedyGroup(nel)(one)(accept) - assert(groups.length <= nel.length) + forAll { + ( + nel: NonEmptyList[Int], + one: Int => Int, + accept: (Int, Int) => Option[Int] + ) => + val groups = ListUtil.greedyGroup(nel)(one)(accept) + assert(groups.length <= nel.length) } } @@ -62,21 +75,23 @@ class ListUtilTest extends AnyFunSuite { test("groups direct property") { forAll { (nel: NonEmptyList[Int], accept: (List[Int], Int) => Boolean) => - val groups = ListUtil.greedyGroup(nel)(a => a :: Nil)((s, i) => if (accept(s, i)) Some(i :: s) else None) + val groups = ListUtil.greedyGroup(nel)(a => a :: Nil)((s, i) => + if (accept(s, i)) Some(i :: s) else None + ) groups.toList.foreach { g => def check(g: List[Int]): Unit = g match { - case Nil => fail("expected at least one item") + case Nil => fail("expected at least one item") case _ :: Nil => - // this can always happen - () + // this can always happen + () case head :: tail => - assert(accept(tail, head)) - check(tail) + assert(accept(tail, head)) + check(tail) } check(g) } } } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/bykn/bosatsu/LitTest.scala b/core/src/test/scala/org/bykn/bosatsu/LitTest.scala index 9ff3239c9..527d1498c 100644 --- a/core/src/test/scala/org/bykn/bosatsu/LitTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/LitTest.scala @@ -2,21 +2,29 @@ package org.bykn.bosatsu import org.scalacheck.Gen import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalacheck.Arbitrary import org.typelevel.paiges.Document class LitTest extends AnyFunSuite { def config: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 1000 else 100) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 1000 else 100 + ) val genLit: Gen[Lit] = Gen.oneOf( Gen.choose(-10000, 10000).map(Lit.fromInt), Arbitrary.arbitrary[String].map(Lit(_)), - Gen.frequency( - (10, Gen.choose(0, 0xd800)), - (1, Gen.choose(0xe000, 1000000))).map(Lit.fromCodePoint) + Gen + .frequency( + (10, Gen.choose(0, 0xd800)), + (1, Gen.choose(0xe000, 1000000)) + ) + .map(Lit.fromCodePoint) ) test("we can convert from Char to Lit") { @@ -24,8 +32,7 @@ class LitTest extends AnyFunSuite { try { val chr = Lit.fromChar(c) assert(chr.asInstanceOf[Lit.Chr].asStr == c.toString) - } - catch { + } catch { case _: IllegalArgumentException => // there are at least 1million valid codepoints val cp = c.toInt @@ -39,8 +46,7 @@ class LitTest extends AnyFunSuite { try { val chr = Lit.fromCodePoint(cp) assert(chr.asInstanceOf[Lit.Chr].toCodePoint == cp) - } - catch { + } catch { case _: IllegalArgumentException => // there are at least 1million valid codepoints assert(cp < 0 || (0xd800 <= cp && cp < 0xe000) || (cp > 1000000)) @@ -50,13 +56,15 @@ class LitTest extends AnyFunSuite { test("Lit ordering is correct") { forAll(genLit, genLit, genLit) { (a, b, c) => - OrderingLaws.law(a, b, c) + OrderingLaws.law(a, b, c) } } test("we can parse from document") { forAll(genLit) { l => - assert(Lit.parser.parseAll(Document[Lit].document(l).render(80)) == Right(l)) + assert( + Lit.parser.parseAll(Document[Lit].document(l).render(80)) == Right(l) + ) } } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/bykn/bosatsu/LocationMapTest.scala b/core/src/test/scala/org/bykn/bosatsu/LocationMapTest.scala index 7fc3f32a0..1bc3ed6db 100644 --- a/core/src/test/scala/org/bykn/bosatsu/LocationMapTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/LocationMapTest.scala @@ -1,12 +1,17 @@ package org.bykn.bosatsu import org.scalacheck.{Arbitrary, Gen} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class LocationMapTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 50000 else 100) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 50000 else 100 + ) test("single line locations") { val singleLine: Gen[String] = @@ -30,7 +35,8 @@ class LocationMapTest extends AnyFunSuite { forAll { (str: String) => val lm = LocationMap(str) - val reconstruct = Iterator.iterate(0)(_ + 1) + val reconstruct = Iterator + .iterate(0)(_ + 1) .map(lm.getLine _) .takeWhile(_.isDefined) .collect { case Some(l) => l } @@ -39,7 +45,9 @@ class LocationMapTest extends AnyFunSuite { assert(reconstruct === str) } } - test("toLineCol is defined for all valid offsets, and getLine isDefined consistently") { + test( + "toLineCol is defined for all valid offsets, and getLine isDefined consistently" + ) { forAll { (s: String, offset: Int) => val lm = LocationMap(s) @@ -53,7 +61,8 @@ class LocationMapTest extends AnyFunSuite { case None => assert(offset == s.length) case Some(line) => assert(line.length >= col) - if (line.length == col) assert(offset == s.length || s(offset) == '\n') + if (line.length == col) + assert(offset == s.length || s(offset) == '\n') else assert(line(col) == s(offset)) } } @@ -67,7 +76,7 @@ class LocationMapTest extends AnyFunSuite { forAll { (s: String) => LocationMap(s).toLineCol(0) match { case Some(r) => assert(r == ((0, 0))) - case None => assert(s.isEmpty) + case None => assert(s.isEmpty) } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala index 1a4258401..a902d71c6 100644 --- a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala +++ b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import org.scalacheck.{Arbitrary, Gen} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration} +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import Identifier.{Bindable, Constructor} import rankn.DataRepr @@ -12,26 +15,28 @@ import org.scalatest.funsuite.AnyFunSuite class MatchlessTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 1000 else 20) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 1000 else 20 + ) type Fn = (PackageName, Constructor) => Option[DataRepr] - def fnFromTypeEnv[A](te: rankn.TypeEnv[A]): Fn = - { - // the list constructors *have* to be in scope or matching will generate - // bad code - case (PackageName.PredefName, Constructor("EmptyList")) => - Some(DataRepr.Enum(0, 0, List(0, 1))) - case (PackageName.PredefName, Constructor("NonEmptyList")) => - Some(DataRepr.Enum(1, 2, List(0, 1))) - case (pn, cons) => - te.getConstructor(pn, cons) - .map(_._1.dataRepr(cons)) - .orElse(Some(DataRepr.Struct(0))) - } + def fnFromTypeEnv[A](te: rankn.TypeEnv[A]): Fn = { + // the list constructors *have* to be in scope or matching will generate + // bad code + case (PackageName.PredefName, Constructor("EmptyList")) => + Some(DataRepr.Enum(0, 0, List(0, 1))) + case (PackageName.PredefName, Constructor("NonEmptyList")) => + Some(DataRepr.Enum(1, 2, List(0, 1))) + case (pn, cons) => + te.getConstructor(pn, cons) + .map(_._1.dataRepr(cons)) + .orElse(Some(DataRepr.Struct(0))) + } lazy val genInputs: Gen[(Bindable, RecursionKind, TypedExpr[Unit], Fn)] = - Generators.genPackage(Gen.const(()), 5) + Generators + .genPackage(Gen.const(()), 5) .flatMap { (m: Map[PackageName, Package.Typed[Unit]]) => val candidates = m.filter { case (_, t) => t.program.lets.nonEmpty } @@ -59,7 +64,9 @@ class MatchlessTest extends AnyFunSuite { val name = Identifier.Name("foo") val te = TypedExpr.Local(name, rankn.Type.IntType, ()) // this should not throw - val me = Matchless.fromLet(name, RecursionKind.Recursive, te)(fnFromTypeEnv(rankn.TypeEnv.empty)) + val me = Matchless.fromLet(name, RecursionKind.Recursive, te)( + fnFromTypeEnv(rankn.TypeEnv.empty) + ) assert(me != null) } @@ -83,14 +90,16 @@ class MatchlessTest extends AnyFunSuite { } test("Matchless.stopAt works") { - forAll(genNE(100, Gen.choose(-100, 100)), Arbitrary.arbitrary[Int => Boolean]) { (nel, fn) => + forAll( + genNE(100, Gen.choose(-100, 100)), + Arbitrary.arbitrary[Int => Boolean] + ) { (nel, fn) => val stopped = Matchless.stopAt(nel)(fn) if (fn(stopped.last)) { // none of the items before the last are true: assert(stopped.init.exists(fn) == false) - } - else { + } else { // none of them were true assert(stopped == nel) assert(nel.exists(fn) == false) @@ -105,31 +114,39 @@ class MatchlessTest extends AnyFunSuite { for { s <- size left <- Gen.listOfN(s, bytes) - sright <- Gen.choose(0, 2*s) - pat <- Gen.listOfN(sright, Arbitrary.arbitrary[Option[Byte => Option[Int]]]) + sright <- Gen.choose(0, 2 * s) + pat <- Gen.listOfN( + sright, + Arbitrary.arbitrary[Option[Byte => Option[Int]]] + ) } yield (left, pat) } import pattern.{SeqPattern, SeqPart, Splitter, Matcher} - def toSeqPat[A, B](pat: List[Option[A => Option[B]]]): SeqPattern[A => Option[B]] = + def toSeqPat[A, B]( + pat: List[Option[A => Option[B]]] + ): SeqPattern[A => Option[B]] = SeqPattern.fromList(pat.map { - case None => SeqPart.Wildcard - case Some(fn) =>SeqPart.Lit(fn) + case None => SeqPart.Wildcard + case Some(fn) => SeqPart.Lit(fn) }) val matcher = SeqPattern.matcher( Splitter.listSplitter(new Matcher[Byte => Option[Int], Byte, Int] { def apply(fn: Byte => Option[Int]) = fn - })) + }) + ) forAll(genArgs) { case (targ, pat) => val seqPat = toSeqPat(pat) val matchRes = matcher(seqPat)(targ) - val matchlessRes = Matchless.matchList(targ, + val matchlessRes = Matchless.matchList( + targ, pat.map { - case None => Left { (_: List[Byte]) => 0 } + case None => Left { (_: List[Byte]) => 0 } case Some(fn) => Right(fn) - }) + } + ) assert(matchlessRes == matchRes) } diff --git a/core/src/test/scala/org/bykn/bosatsu/MonadGen.scala b/core/src/test/scala/org/bykn/bosatsu/MonadGen.scala index c655acd3d..812e27e58 100644 --- a/core/src/test/scala/org/bykn/bosatsu/MonadGen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/MonadGen.scala @@ -14,4 +14,4 @@ object MonadGen { def tailRecM[A, B](a: A)(fn: A => Gen[Either[A, B]]): Gen[B] = Gen.tailRecM(a)(fn) } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/bykn/bosatsu/OperatorTest.scala b/core/src/test/scala/org/bykn/bosatsu/OperatorTest.scala index 817695225..34f6aa2b2 100644 --- a/core/src/test/scala/org/bykn/bosatsu/OperatorTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/OperatorTest.scala @@ -1,6 +1,6 @@ package org.bykn.bosatsu -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import cats.parse.{Parser => P} @@ -14,7 +14,7 @@ class OperatorTest extends ParserTestBase { sealed abstract class F { def toFormula: Formula[String] = this match { - case F.Num(s) => Formula.Sym(s) + case F.Num(s) => Formula.Sym(s) case F.Form(Formula.Sym(n)) => n.toFormula case F.Form(Formula.Op(left, op, right)) => Formula.Op(F.Form(left).toFormula, op, F.Form(right).toFormula) @@ -28,8 +28,7 @@ class OperatorTest extends ParserTestBase { lazy val formP: P[F] = Operators.Formula .parser( - Parser - .integerString + Parser.integerString .map(F.Num(_)) .orElse(P.defer(formP.parensCut)) ) @@ -43,7 +42,11 @@ class OperatorTest extends ParserTestBase { } def parseSame(left: String, right: String) = - assert(Parser.unsafeParse(formP, left).toFormula == Parser.unsafeParse(formP, right).toFormula) + assert( + Parser.unsafeParse(formP, left).toFormula == Parser + .unsafeParse(formP, right) + .toFormula + ) test("we can parse integer formulas") { parseSame("1+2", "1 + 2") @@ -69,7 +72,8 @@ class OperatorTest extends ParserTestBase { } test("test operator precedence in real programs") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Test operator + = add @@ -83,9 +87,13 @@ test = TestSuite("precedence", Assertion(1 * 2 * 3 == (1 * 2) * 3, "p1"), Assertion(1 + 2 % 3 == 1 + (2 % 3), "p1") ]) -"""), "Test", 3) +"""), + "Test", + 3 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Test # this is non-associative so we can test order @@ -102,9 +110,14 @@ test = TestSuite("precedence", [ Assertion(1 *> 2 *> 3 == (1 *> 2) *> 3, "p1"), ]) -"""), "Test", 1) - - runBosatsuTest(List(""" +"""), + "Test", + 1 + ) + + runBosatsuTest( + List( + """ package T1 export operator +, operator *, operator == @@ -113,7 +126,7 @@ operator + = add operator * = times operator == = eq_Int """, - """ + """ package T2 from T1 import operator + as operator ++, `*`, `==` @@ -124,11 +137,16 @@ from T1 import operator + as operator ++, `*`, `==` test = TestSuite("import export", [ Assertion(1 +. (2 * 3) == 1 .+ (2 * 3), "p1"), Assertion(1 .+ 2 * 3 == (1 .+ 2) * 3, "p1") ]) -"""), "T2", 2) +""" + ), + "T2", + 2 + ) } test("test ternary operator precedence") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Test operator == = eq_Int @@ -147,6 +165,9 @@ test = TestSuite("precedence", Assertion(left1 == right1, "p1"), Assertion(left2 == right2, "p2"), ]) -"""), "Test", 2) +"""), + "Test", + 2 + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/PackageTest.scala b/core/src/test/scala/org/bykn/bosatsu/PackageTest.scala index a235d9011..3c4846b53 100644 --- a/core/src/test/scala/org/bykn/bosatsu/PackageTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/PackageTest.scala @@ -8,9 +8,12 @@ import org.scalatest.funsuite.AnyFunSuite class PackageTest extends AnyFunSuite with ParTest { - def resolveThenInfer(ps: Iterable[Package.Parsed]): ValidatedNel[PackageError, PackageMap.Inferred] = { + def resolveThenInfer( + ps: Iterable[Package.Parsed] + ): ValidatedNel[PackageError, PackageMap.Inferred] = { implicit val showInt: Show[Int] = Show.fromToString - PackageMap.resolveThenInfer(ps.toList.zipWithIndex.map(_.swap), Nil) + PackageMap + .resolveThenInfer(ps.toList.zipWithIndex.map(_.swap), Nil) .strictToValidated } @@ -22,7 +25,7 @@ class PackageTest extends AnyFunSuite with ParTest { def valid[A, B](v: Validated[A, B]) = v match { - case Validated.Valid(_) => succeed + case Validated.Valid(_) => succeed case Validated.Invalid(err) => fail(err.toString) } @@ -35,15 +38,13 @@ class PackageTest extends AnyFunSuite with ParTest { } test("simple package resolves") { - val p1 = parse( -""" + val p1 = parse(""" package Foo export main main = 1 """) - val p2 = parse( -""" + val p2 = parse(""" package Foo2 from Foo import main as mainFoo export main, @@ -51,8 +52,7 @@ export main, main = mainFoo """) - val p3 = parse( -""" + val p3 = parse(""" package Foo from Foo2 import main as mainFoo @@ -63,8 +63,7 @@ main = 1 valid(resolveThenInfer(List(p1, p2))) invalid(resolveThenInfer(List(p2, p3))) // loop here - val p4 = parse( -""" + val p4 = parse(""" package P4 from Foo2 import main as one @@ -74,8 +73,7 @@ main = add(one, 42) """) valid(resolveThenInfer(List(p1, p2, p4))) - val p5 = parse( -""" + val p5 = parse(""" package P5 export Option(), List(), head, tail @@ -101,8 +99,7 @@ def tail(list): case NonEmpty(_, t): Some(t) """) - val p6 = parse( -""" + val p6 = parse(""" package P6 from P5 import Option, List, NonEmpty, Empty, head export data @@ -113,8 +110,7 @@ main = head(data) """) valid(resolveThenInfer(List(p5, p6))) - val p7 = parse( -""" + val p7 = parse(""" package P7 from P6 import data as p6_data from P5 import Option, List, NonEmpty as Cons, Empty as Nil, head @@ -131,8 +127,7 @@ main = head(data1) assert(Package.predefPackage != null) - val p = parse( -""" + val p = parse(""" package UsePredef def maybeOne(x): @@ -148,8 +143,7 @@ main = maybeOne(42) test("test using a renamed type") { - val p1 = parse( -""" + val p1 = parse(""" package R1 export Foo(), mkFoo, takeFoo @@ -163,8 +157,7 @@ def takeFoo(foo): 0 """) - val p2 = parse( -""" + val p2 = parse(""" package R2 from R1 import Foo as Bar, mkFoo, takeFoo @@ -184,15 +177,13 @@ main2 = match baz: } test("unused imports is an error") { - val p1 = parse( -""" + val p1 = parse(""" package Foo export main main = 1 """) - val p2 = parse( -""" + val p2 = parse(""" package Foo2 from Foo import main as mainFoo export main, @@ -202,8 +193,7 @@ main = 2 invalid(resolveThenInfer(List(p1, p2))) - val p3 = parse( -""" + val p3 = parse(""" package Foo export main, Foo @@ -211,8 +201,7 @@ enum Foo: Bar, Baz main = 1 """) - val p4 = parse( -""" + val p4 = parse(""" package Foo2 from Foo import main as mainFoo, Foo export main, diff --git a/core/src/test/scala/org/bykn/bosatsu/ParTest.scala b/core/src/test/scala/org/bykn/bosatsu/ParTest.scala index e1d284513..49977f33a 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ParTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ParTest.scala @@ -1,6 +1,6 @@ package org.bykn.bosatsu -import org.scalatest.{BeforeAndAfterAll, Suite } +import org.scalatest.{BeforeAndAfterAll, Suite} trait ParTest extends BeforeAndAfterAll { self: Suite => diff --git a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala index 335af544a..0ba3a670a 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala @@ -4,7 +4,10 @@ import cats.data.NonEmptyList import Parser.Combinators import java.math.BigInteger import org.scalacheck.{Arbitrary, Gen} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.typelevel.paiges.{Doc, Document} import cats.implicits._ @@ -20,8 +23,7 @@ trait ParseFns { else if (s0.length == idx) { val s = s0 + "*" ("...(" + s.drop(idx - 20).take(20) + ")...") - } - else { + } else { val s = s0.updated(idx, '*') ("...(" + s.drop(idx - 20).take(30) + ")...") } @@ -31,7 +33,8 @@ trait ParseFns { else if (s1.isEmpty) s2 else if (s2.isEmpty) s1 else if (s1(0) == s2(0)) firstDiff(s1.tail, s2.tail) - else s"${s1(0).toInt}: ${s1.take(20)}... != ${s2(0).toInt}: ${s2.take(20)}..." + else + s"${s1(0).toInt}: ${s1.take(20)}... != ${s2(0).toInt}: ${s2.take(20)}..." } @@ -47,11 +50,16 @@ abstract class ParserTestBase extends AnyFunSuite with ParseFns { case Right((rest, t)) => val idx = if (rest == "") str.length else str.indexOf(rest) lazy val message = firstDiff(t.toString, expected.toString) - assert(t == expected, s"difference: $message, input syntax:\n\n\n$str\n\n") + assert( + t == expected, + s"difference: $message, input syntax:\n\n\n$str\n\n" + ) assert(idx == exidx) case Left(err) => val idx = err.failedAtOffset - fail(s"failed to parse: $str: at $idx in region ${region(str, idx)} with err: ${err}") + fail( + s"failed to parse: $str: at $idx in region ${region(str, idx)} with err: ${err}" + ) } def parseTestAll[T](p: P0[T], str: String, expected: T) = @@ -71,11 +79,15 @@ abstract class ParserTestBase extends AnyFunSuite with ParseFns { case Left(err) => val idx = err.failedAtOffset val diff = firstDiff(str, tstr) - fail(s"Diff: $diff.\nfailed to reparse: $tstr: $idx in region ${region(tstr, idx)} with err: ${err}") + fail( + s"Diff: $diff.\nfailed to reparse: $tstr: $idx in region ${region(tstr, idx)} with err: ${err}" + ) } case Left(err) => val idx = err.failedAtOffset - fail(s"failed to parse: $str: $idx in region ${region(str, idx)} with err: ${err}") + fail( + s"failed to parse: $str: $idx in region ${region(str, idx)} with err: ${err}" + ) } def roundTripExact[T: Document](p: P0[T], str: String) = @@ -87,7 +99,9 @@ abstract class ParserTestBase extends AnyFunSuite with ParseFns { assert(tstr == str) case Left(err) => val idx = err.failedAtOffset - fail(s"failed to parse: $str: $idx in region ${region(str, idx)} with err: ${err}") + fail( + s"failed to parse: $str: $idx in region ${region(str, idx)} with err: ${err}" + ) } def law[T: Document](p: P0[T])(t: T) = { @@ -102,12 +116,15 @@ abstract class ParserTestBase extends AnyFunSuite with ParseFns { fail(s"parsed $t to: $idx: ${region(str, idx)}") case Left(err) => val idx = err.failedAtOffset - def msg = s"failed to parse: $str: at $idx in region ${region(str, idx)} with err: ${err}" + def msg = + s"failed to parse: $str: at $idx in region ${region(str, idx)} with err: ${err}" assert(idx == atIdx, msg) } def config: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 300 else 10) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 300 else 10 + ) } class ParserTest extends ParserTestBase { @@ -129,10 +146,11 @@ class ParserTest extends ParserTestBase { def loop(b: String): Gen[String] = if (b.length <= 1) Gen.const(b) - else for { - s <- sep - tail <- loop(b.tail) - } yield s"${b.charAt(0)}$s$tail" + else + for { + s <- sep + tail <- loop(b.tail) + } yield s"${b.charAt(0)}$s$tail" loop(bstr).map(Opaque(_)) } @@ -149,11 +167,12 @@ class ParserTest extends ParserTestBase { try { Parser.unescape(str1) match { case Right(str2) => assert(str2 == str) - case Left(idx) => fail(s"failed at idx: $idx in $str: ${region(str, idx)}") + case Left(idx) => + fail(s"failed at idx: $idx in $str: ${region(str, idx)}") } - } - catch { - case t: Throwable => fail(s"failed to decode: $str1 from $str, exception: $t") + } catch { + case t: Throwable => + fail(s"failed to decode: $str1 from $str, exception: $t") } } @@ -206,7 +225,6 @@ class ParserTest extends ParserTestBase { val regressions = List(("'", '\'')) - regressions.foreach { case (s, c) => law(s, c) } } @@ -214,12 +232,19 @@ class ParserTest extends ParserTestBase { def singleq(str1: String, res: List[Either[Json, String]]) = parseTestAll( StringUtil - .interpolatedString('\'', P.string("${").as((j: Json) => j), Json.parser, P.char('}')) + .interpolatedString( + '\'', + P.string("${").as((j: Json) => j), + Json.parser, + P.char('}') + ) .map(_.map { case Right((_, str)) => Right(str) - case Left(l) => Left(l) - }) - , str1, res) + case Left(l) => Left(l) + }), + str1, + res + ) // scala complains about things that look like interpolation strings that aren't interpolated val dollar = '$'.toString @@ -229,25 +254,40 @@ class ParserTest extends ParserTestBase { singleq(s"'foo\\$dollar{bar}'", List(Right(s"foo$dollar{bar}"))) // foo$bar is okay, it is only foo${bar} that needs to be escaped singleq(s"'foo${dollar}bar'", List(Right(s"foo${dollar}bar"))) - singleq(s"'foo$dollar{42}'", List(Right("foo"), Left(Json.JNumberStr("42")))) + singleq( + s"'foo$dollar{42}'", + List(Right("foo"), Left(Json.JNumberStr("42"))) + ) singleq(s"'$dollar{42}'", List(Left(Json.JNumberStr("42")))) - singleq(s"'$dollar{42}bar'", List(Left(Json.JNumberStr("42")), Right("bar"))) + singleq( + s"'$dollar{42}bar'", + List(Left(Json.JNumberStr("42")), Right("bar")) + ) } test("we can decode any utf16") { - val p = StringUtil.utf16Codepoint.repAs(StringUtil.codePointAccumulator) | P.pure("") + val p = + StringUtil.utf16Codepoint.repAs(StringUtil.codePointAccumulator) | P.pure( + "" + ) val genCodePoints: Gen[Int] = Gen.frequency( (10, Gen.choose(0, 0xd7ff)), - (1, Gen.choose(0, 0x10ffff).filterNot { cp => - (0xD800 <= cp && cp <= 0xDFFF) - }) + ( + 1, + Gen.choose(0, 0x10ffff).filterNot { cp => + (0xd800 <= cp && cp <= 0xdfff) + } + ) ) // .codePoints isn't available in scalajs def jsCompatCodepoints(s: String): List[Int] = if (s.isEmpty) Nil - else (s.codePointAt(0) :: jsCompatCodepoints(s.substring(s.offsetByCodePoints(0, 1)))) + else + (s.codePointAt(0) :: jsCompatCodepoints( + s.substring(s.offsetByCodePoints(0, 1)) + )) forAll(Gen.listOf(genCodePoints)) { cps => val strbuilder = new java.lang.StringBuilder @@ -258,16 +298,18 @@ class ParserTest extends ParserTestBase { val parsed = p.parseAll(str) assert(parsed == Right(str)) - assert(parsed.map(jsCompatCodepoints) == Right(cps), - s"hex = $hex, str = ${jsCompatCodepoints(str)} utf16 = ${str.toCharArray().toList.map(_.toInt.toHexString)}") + assert( + parsed.map(jsCompatCodepoints) == Right(cps), + s"hex = $hex, str = ${jsCompatCodepoints(str)} utf16 = ${str.toCharArray().toList.map(_.toInt.toHexString)}" + ) } } test("Identifier round trips") { forAll(Generators.identifierGen)(law(Identifier.parser)) - val examples = List("foo", "`bar`", "`bar foo`", - "`with \\`internal`", "operator +") + val examples = + List("foo", "`bar`", "`bar foo`", "`with \\`internal`", "operator +") examples.foreach(roundTrip(Identifier.parser, _)) } @@ -287,38 +329,55 @@ class ParserTest extends ParserTestBase { val str0 = ls.toString val str = str0.flatMap { case ',' => "," + (" " * spaceCount) - case c => c.toString + case c => c.toString } val listOfStr: P[List[String]] = P.string("List(") *> - Parser.integerString.nonEmptyList.map(_.toList) + Parser.integerString.nonEmptyList + .map(_.toList) .orElse(P.pure(Nil)) <* - P.char(')') + P.char(')') - parseTestAll( - listOfStr, - str, - ls.map(_.toString)) + parseTestAll(listOfStr, str, ls.map(_.toString)) } } test("we can parse dicts") { - val strDict = Parser.dictLikeParser(Parser.escapedString('\''), Parser.escapedString('\'')) + val strDict = Parser.dictLikeParser( + Parser.escapedString('\''), + Parser.escapedString('\'') + ) parseTestAll(strDict, "{}", Nil) parseTestAll(strDict, "{'a': 'b'}", List(("a", "b"))) parseTestAll(strDict, "{ 'a' : 'b' }", List(("a", "b"))) parseTestAll(strDict, "{'a' : 'b', 'c': 'd'}", List(("a", "b"), ("c", "d"))) - parseTestAll(strDict, "{'a' : 'b',\n'c': 'd'}", List(("a", "b"), ("c", "d"))) - parseTestAll(strDict, "{'a' : 'b',\n\t'c': 'd'}", List(("a", "b"), ("c", "d"))) - parseTestAll(strDict, "{'a' : 'b',\n 'c': 'd'}", List(("a", "b"), ("c", "d"))) - - case class WildDict(stringRepNoCurlies: List[String], original: List[(String, String)]) { + parseTestAll( + strDict, + "{'a' : 'b',\n'c': 'd'}", + List(("a", "b"), ("c", "d")) + ) + parseTestAll( + strDict, + "{'a' : 'b',\n\t'c': 'd'}", + List(("a", "b"), ("c", "d")) + ) + parseTestAll( + strDict, + "{'a' : 'b',\n 'c': 'd'}", + List(("a", "b"), ("c", "d")) + ) + + case class WildDict( + stringRepNoCurlies: List[String], + original: List[(String, String)] + ) { def stringRep: String = stringRepNoCurlies.mkString("{", "", "}") def addEntry(strings: List[String], k: String, v: String): WildDict = if (stringRepNoCurlies.isEmpty) WildDict(strings, (k, v) :: original) - else WildDict(strings ::: ("," :: stringRepNoCurlies), (k, v) :: original) + else + WildDict(strings ::: ("," :: stringRepNoCurlies), (k, v) :: original) } val genString = Arbitrary.arbitrary[String] @@ -349,7 +408,14 @@ class ParserTest extends ParserTestBase { test("we can parse RecordConstructors") { def check(str: String) = - roundTrip[Declaration](Declaration.recordConstructorP("", Declaration.varP, Declaration.varP.orElse(Declaration.lits)), str) + roundTrip[Declaration]( + Declaration.recordConstructorP( + "", + Declaration.varP, + Declaration.varP.orElse(Declaration.lits) + ), + str + ) check("Foo { bar }") check("Foo{bar}") @@ -385,7 +451,7 @@ class ParserTest extends ParserTestBase { check("Foo{x:1}") // from scalacheck - //check("Ze8lujlrbo {wlqOvp: {}}") + // check("Ze8lujlrbo {wlqOvp: {}}") } test("we can parse tuples") { @@ -398,9 +464,11 @@ class ParserTest extends ParserTestBase { case _ => ls.mkString("(", "," + pad, ")") } - parseTestAll(Parser.integerString.tupleOrParens, + parseTestAll( + Parser.integerString.tupleOrParens, str, - Right(ls.map(_.toString))) + Right(ls.map(_.toString)) + ) } // a single item is parsed as parens @@ -408,45 +476,71 @@ class ParserTest extends ParserTestBase { val spaceCount = spaceCnt0 & 7 val pad = " " * spaceCount val str = s"($it$pad)" - parseTestAll(Parser.integerString.tupleOrParens, - str, - Left(it.toString)) + parseTestAll(Parser.integerString.tupleOrParens, str, Left(it.toString)) } } test("we can parse blocks") { - val indy = OptIndent.block(Indy.lift(P.string("if foo")), Indy.lift(P.string("bar"))) + val indy = + OptIndent.block(Indy.lift(P.string("if foo")), Indy.lift(P.string("bar"))) val p = indy.run("") parseTestAll(p, "if foo: bar", ((), OptIndent.same(()))) parseTestAll(p, "if foo:\n\tbar", ((), OptIndent.paddedIndented(1, 4, ()))) - parseTestAll(p, "if foo:\n bar", ((), OptIndent.paddedIndented(1, 4, ()))) + parseTestAll( + p, + "if foo:\n bar", + ((), OptIndent.paddedIndented(1, 4, ())) + ) parseTestAll(p, "if foo:\n bar", ((), OptIndent.paddedIndented(1, 2, ()))) import Indy.IndyMethods val repeated = indy.nonEmptyList(Indy.lift(Parser.toEOL1)) val single = ((), OptIndent.notSame(Padding(1, Indented(2, ())))) - parseTestAll(repeated.run(""), "if foo:\n bar\nif foo:\n bar", - NonEmptyList.of(single, single)) + parseTestAll( + repeated.run(""), + "if foo:\n bar\nif foo:\n bar", + NonEmptyList.of(single, single) + ) // we can nest blocks - parseTestAll(OptIndent.block(Indy.lift(P.string("nest")), indy)(""), "nest: if foo: bar", - ((), OptIndent.same(((), OptIndent.same(()))))) - parseTestAll(OptIndent.block(Indy.lift(P.string("nest")), indy)(""), "nest:\n if foo: bar", - ((), OptIndent.paddedIndented(1, 2, ((), OptIndent.same(()))))) - parseTestAll(OptIndent.block(Indy.lift(P.string("nest")), indy)(""), "nest:\n if foo:\n bar", - ((), OptIndent.paddedIndented(1, 2, ((), OptIndent.paddedIndented(1, 2, ()))))) - - val simpleBlock = OptIndent.block(Indy.lift(Parser.lowerIdent <* Parser.maybeSpace), Indy.lift(Parser.lowerIdent)) + parseTestAll( + OptIndent.block(Indy.lift(P.string("nest")), indy)(""), + "nest: if foo: bar", + ((), OptIndent.same(((), OptIndent.same(())))) + ) + parseTestAll( + OptIndent.block(Indy.lift(P.string("nest")), indy)(""), + "nest:\n if foo: bar", + ((), OptIndent.paddedIndented(1, 2, ((), OptIndent.same(())))) + ) + parseTestAll( + OptIndent.block(Indy.lift(P.string("nest")), indy)(""), + "nest:\n if foo:\n bar", + ( + (), + OptIndent.paddedIndented(1, 2, ((), OptIndent.paddedIndented(1, 2, ()))) + ) + ) + + val simpleBlock = OptIndent + .block( + Indy.lift(Parser.lowerIdent <* Parser.maybeSpace), + Indy.lift(Parser.lowerIdent) + ) .nonEmptyList(Indy.toEOLIndent) - val sbRes = NonEmptyList.of(("x1", OptIndent.paddedIndented(1, 2, "x2")), - ("y1", OptIndent.paddedIndented(1, 3, "y2"))) + val sbRes = NonEmptyList.of( + ("x1", OptIndent.paddedIndented(1, 2, "x2")), + ("y1", OptIndent.paddedIndented(1, 3, "y2")) + ) parseTestAll(simpleBlock(""), "x1:\n x2\ny1:\n y2", sbRes) - parseTestAll(OptIndent.block(Indy.lift(Parser.lowerIdent), simpleBlock)(""), + parseTestAll( + OptIndent.block(Indy.lift(Parser.lowerIdent), simpleBlock)(""), "block:\n x1:\n x2\n y1:\n y2", - ("block", OptIndent.paddedIndented(1, 2, sbRes))) + ("block", OptIndent.paddedIndented(1, 2, sbRes)) + ) } def trName(s: String): TypeRef.TypeName = @@ -456,23 +550,62 @@ class ParserTest extends ParserTestBase { parseTestAll(TypeRef.parser, "foo", TypeRef.TypeVar("foo")) parseTestAll(TypeRef.parser, "Foo", trName("Foo")) - parseTestAll(TypeRef.parser, "forall a. a", - TypeRef.TypeForAll(NonEmptyList.of((TypeRef.TypeVar("a"), None)), TypeRef.TypeVar("a"))) - parseTestAll(TypeRef.parser, "forall a, b. f[a] -> f[b]", - TypeRef.TypeForAll(NonEmptyList.of((TypeRef.TypeVar("a"), None), (TypeRef.TypeVar("b"), None)), + parseTestAll( + TypeRef.parser, + "forall a. a", + TypeRef.TypeForAll( + NonEmptyList.of((TypeRef.TypeVar("a"), None)), + TypeRef.TypeVar("a") + ) + ) + parseTestAll( + TypeRef.parser, + "forall a, b. f[a] -> f[b]", + TypeRef.TypeForAll( + NonEmptyList + .of((TypeRef.TypeVar("a"), None), (TypeRef.TypeVar("b"), None)), TypeRef.TypeArrow( - TypeRef.TypeApply(TypeRef.TypeVar("f"), NonEmptyList.of(TypeRef.TypeVar("a"))), - TypeRef.TypeApply(TypeRef.TypeVar("f"), NonEmptyList.of(TypeRef.TypeVar("b")))))) + TypeRef.TypeApply( + TypeRef.TypeVar("f"), + NonEmptyList.of(TypeRef.TypeVar("a")) + ), + TypeRef.TypeApply( + TypeRef.TypeVar("f"), + NonEmptyList.of(TypeRef.TypeVar("b")) + ) + ) + ) + ) roundTrip(TypeRef.parser, "forall a, b. f[a] -> f[b]") roundTrip(TypeRef.parser, "(forall a, b. f[a]) -> f[b]") roundTrip(TypeRef.parser, "(forall a, b. f[a])[Int]") // apply a type - parseTestAll(TypeRef.parser, "Foo -> Bar", TypeRef.TypeArrow(trName("Foo"), trName("Bar"))) - parseTestAll(TypeRef.parser, "Foo -> Bar -> baz", - TypeRef.TypeArrow(trName("Foo"), TypeRef.TypeArrow(trName("Bar"), TypeRef.TypeVar("baz")))) - parseTestAll(TypeRef.parser, "(Foo -> Bar) -> baz", - TypeRef.TypeArrow(TypeRef.TypeArrow(trName("Foo"), trName("Bar")), TypeRef.TypeVar("baz"))) - parseTestAll(TypeRef.parser, "Foo[Bar]", TypeRef.TypeApply(trName("Foo"), NonEmptyList.of(trName("Bar")))) + parseTestAll( + TypeRef.parser, + "Foo -> Bar", + TypeRef.TypeArrow(trName("Foo"), trName("Bar")) + ) + parseTestAll( + TypeRef.parser, + "Foo -> Bar -> baz", + TypeRef.TypeArrow( + trName("Foo"), + TypeRef.TypeArrow(trName("Bar"), TypeRef.TypeVar("baz")) + ) + ) + parseTestAll( + TypeRef.parser, + "(Foo -> Bar) -> baz", + TypeRef.TypeArrow( + TypeRef.TypeArrow(trName("Foo"), trName("Bar")), + TypeRef.TypeVar("baz") + ) + ) + parseTestAll( + TypeRef.parser, + "Foo[Bar]", + TypeRef.TypeApply(trName("Foo"), NonEmptyList.of(trName("Bar"))) + ) forAll(Generators.typeRefGen) { tref => parseTestAll(TypeRef.parser, tref.toDoc.render(80), tref) @@ -487,19 +620,37 @@ class ParserTest extends ParserTestBase { val varA = TyVar(Var.Bound("a")) val varB = TyVar(Var.Bound("b")) - val FooBarBar = TyConst(Const.Defined(PackageName.parts("Foo", "Bar"), TypeName(Identifier.Constructor("Bar")))) + val FooBarBar = TyConst( + Const.Defined( + PackageName.parts("Foo", "Bar"), + TypeName(Identifier.Constructor("Bar")) + ) + ) check("a", varA) check("Foo/Bar::Bar", FooBarBar) check("a -> Foo/Bar::Bar", Fun(varA, FooBarBar)) - check("forall a, b. Foo/Bar::Bar[a, b]", Type.forAll(List((Var.Bound("a"), Kind.Type), (Var.Bound("b"), Kind.Type)), TyApply(TyApply(FooBarBar, varA), varB))) - check("forall a. forall b. Foo/Bar::Bar[a, b]", Type.forAll(List((Var.Bound("a"), Kind.Type), (Var.Bound("b"), Kind.Type)), TyApply(TyApply(FooBarBar, varA), varB))) + check( + "forall a, b. Foo/Bar::Bar[a, b]", + Type.forAll( + List((Var.Bound("a"), Kind.Type), (Var.Bound("b"), Kind.Type)), + TyApply(TyApply(FooBarBar, varA), varB) + ) + ) + check( + "forall a. forall b. Foo/Bar::Bar[a, b]", + Type.forAll( + List((Var.Bound("a"), Kind.Type), (Var.Bound("b"), Kind.Type)), + TyApply(TyApply(FooBarBar, varA), varB) + ) + ) check("(a)", varA) check("(a, b)", Tuple(List(varA, varB))) } test("we can parse python style list expressions") { val pident = Parser.lowerIdent - implicit val stringDoc: Document[String] = Document.instance[String](Doc.text(_)) + implicit val stringDoc: Document[String] = + Document.instance[String](Doc.text(_)) val llp = ListLang.parser(pident, pident, pident) roundTrip(llp, "[a]") @@ -519,9 +670,22 @@ class ParserTest extends ParserTestBase { test("we can parse operators") { val singleToks = List( - "+", "-", "*", "!", "$", "%", - "^", "&", "*", "|", "?", "/", "<", - ">", "~") + "+", + "-", + "*", + "!", + "$", + "%", + "^", + "&", + "*", + "|", + "?", + "/", + "<", + ">", + "~" + ) val withEq = "=" :: singleToks val allLen2 = (withEq, withEq).mapN(_ + _) @@ -545,9 +709,8 @@ class ParserTest extends ParserTestBase { } } -/** - * This is a separate class since some of these are very slow - */ +/** This is a separate class since some of these are very slow + */ class SyntaxParseTest extends ParserTestBase { implicit val generatorDrivenConfig: PropertyCheckConfiguration = config @@ -556,11 +719,18 @@ class SyntaxParseTest extends ParserTestBase { Declaration.Var(Identifier.Name(n)) test("we can parse comments") { - val gen = Generators.commentGen(Generators.padding(Generators.genDeclaration(0), 1)) + val gen = + Generators.commentGen(Generators.padding(Generators.genDeclaration(0), 1)) forAll(gen) { comment => - parseTestAll(CommentStatement.parser(i => Padding.parser(Declaration.parser(i))).run(""), - Document[CommentStatement[Padding[Declaration]]].document(comment).render(80), - comment) + parseTestAll( + CommentStatement + .parser(i => Padding.parser(Declaration.parser(i))) + .run(""), + Document[CommentStatement[Padding[Declaration]]] + .document(comment) + .render(80), + comment + ) } val commentLit = """#foo @@ -571,8 +741,12 @@ class SyntaxParseTest extends ParserTestBase { Declaration.parser(""), commentLit, Declaration.CommentNB( - CommentStatement(NonEmptyList.of("foo", "bar"), - Padding(1, Declaration.Literal(Lit.fromInt(1)))))) + CommentStatement( + NonEmptyList.of("foo", "bar"), + Padding(1, Declaration.Literal(Lit.fromInt(1))) + ) + ) + ) val parensComment = """(#foo #bar @@ -581,9 +755,15 @@ class SyntaxParseTest extends ParserTestBase { parseTestAll( Declaration.parser(""), parensComment, - Declaration.Parens(Declaration.CommentNB( - CommentStatement(NonEmptyList.of("foo", "bar"), - Padding(1, Declaration.Literal(Lit.fromInt(1))))))) + Declaration.Parens( + Declaration.CommentNB( + CommentStatement( + NonEmptyList.of("foo", "bar"), + Padding(1, Declaration.Literal(Lit.fromInt(1))) + ) + ) + ) + ) } test("we can parse Lit.Integer") { @@ -593,11 +773,19 @@ class SyntaxParseTest extends ParserTestBase { } test("we can parse DefStatement") { - forAll(Generators.defGen(Generators.optIndent(Generators.genDeclaration(0)))) { defn => + forAll( + Generators.defGen(Generators.optIndent(Generators.genDeclaration(0))) + ) { defn => parseTestAll[DefStatement[Pattern.Parsed, OptIndent[Declaration]]]( - DefStatement.parser(Pattern.bindParser, Parser.maybeSpace.with1 *> OptIndent.indy(Declaration.parser).run("")), - Document[DefStatement[Pattern.Parsed, OptIndent[Declaration]]].document(defn).render(80), - defn) + DefStatement.parser( + Pattern.bindParser, + Parser.maybeSpace.with1 *> OptIndent.indy(Declaration.parser).run("") + ), + Document[DefStatement[Pattern.Parsed, OptIndent[Declaration]]] + .document(defn) + .render(80), + defn + ) } val defWithComment = """def foo(a): @@ -607,81 +795,160 @@ foo""" parseTestAll( Declaration.parser(""), defWithComment, - Declaration.DefFn(DefStatement(Identifier.Name("foo"), None, - NonEmptyList.one(NonEmptyList.one(Pattern.Var(Identifier.Name("a")))), None, - (OptIndent.paddedIndented(1, 2, Declaration.CommentNB(CommentStatement(NonEmptyList.of(" comment here"), - Padding(0, mkVar("a"))))), - Padding(0, mkVar("foo")))))) + Declaration.DefFn( + DefStatement( + Identifier.Name("foo"), + None, + NonEmptyList.one(NonEmptyList.one(Pattern.Var(Identifier.Name("a")))), + None, + ( + OptIndent.paddedIndented( + 1, + 2, + Declaration.CommentNB( + CommentStatement( + NonEmptyList.of(" comment here"), + Padding(0, mkVar("a")) + ) + ) + ), + Padding(0, mkVar("foo")) + ) + ) + ) + ) roundTrip(Declaration.parser(""), defWithComment) // Here is a pretty brutal randomly generated case - roundTrip(Declaration.parser(""), -"""def uwr(dw: h6lmZhgg) -> forall lnNR. Z5syis -> Mhgm: + roundTrip( + Declaration.parser(""), + """def uwr(dw: h6lmZhgg) -> forall lnNR. Z5syis -> Mhgm: -349743008 -foo""") +foo""" + ) } test("we can parse BindingStatement") { val dp = Declaration.parser("") - parseTestAll(dp, + parseTestAll( + dp, """foo = 5 5""", - Declaration.Binding(BindingStatement(Pattern.Var(Identifier.Name("foo")), Declaration.Literal(Lit.fromInt(5)), - Padding(1, Declaration.Literal(Lit.fromInt(5)))))) - + Declaration.Binding( + BindingStatement( + Pattern.Var(Identifier.Name("foo")), + Declaration.Literal(Lit.fromInt(5)), + Padding(1, Declaration.Literal(Lit.fromInt(5))) + ) + ) + ) - roundTrip(dp, -"""# + roundTrip( + dp, + """# Pair(_, x) = z -x""") +x""" + ) } test("we can parse any Apply") { import Declaration._ - import ApplyKind.{Dot => ADot, Parens => AParens } + import ApplyKind.{Dot => ADot, Parens => AParens} - parseTestAll(parser(""), + parseTestAll( + parser(""), "x(f)", - Apply(mkVar("x"), NonEmptyList.of(mkVar("f")), AParens)) + Apply(mkVar("x"), NonEmptyList.of(mkVar("f")), AParens) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), "f.x()", - Apply(mkVar("x"), NonEmptyList.of(mkVar("f")), ADot)) + Apply(mkVar("x"), NonEmptyList.of(mkVar("f")), ADot) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), "f(foo).x()", - Apply(mkVar("x"), NonEmptyList.of(Apply(mkVar("f"), NonEmptyList.of(mkVar("foo")), AParens)), ADot)) + Apply( + mkVar("x"), + NonEmptyList.of( + Apply(mkVar("f"), NonEmptyList.of(mkVar("foo")), AParens) + ), + ADot + ) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), "f.foo(x)", // foo(f, x) - Apply(mkVar("foo"), NonEmptyList.of(mkVar("f"), mkVar("x")), ADot)) + Apply(mkVar("foo"), NonEmptyList.of(mkVar("f"), mkVar("x")), ADot) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), "(\\x -> x)(f)", - Apply(Parens(Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x"))), NonEmptyList.of(mkVar("f")), AParens)) + Apply( + Parens( + Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x")) + ), + NonEmptyList.of(mkVar("f")), + AParens + ) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), "((\\x -> x)(f))", - Parens(Apply(Parens(Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x"))), NonEmptyList.of(mkVar("f")), AParens))) + Parens( + Apply( + Parens( + Lambda( + NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), + mkVar("x") + ) + ), + NonEmptyList.of(mkVar("f")), + AParens + ) + ) + ) // bare lambda - parseTestAll(parser(""), + parseTestAll( + parser(""), "((x -> x)(f))", - Parens(Apply(Parens(Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x"))), NonEmptyList.of(mkVar("f")), AParens))) + Parens( + Apply( + Parens( + Lambda( + NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), + mkVar("x") + ) + ), + NonEmptyList.of(mkVar("f")), + AParens + ) + ) + ) - val expected = Apply(Parens(Parens(Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x")))), NonEmptyList.of(mkVar("f")), AParens) - parseTestAll(parser(""), - "((\\x -> x))(f)", - expected) + val expected = Apply( + Parens( + Parens( + Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x")) + ) + ), + NonEmptyList.of(mkVar("f")), + AParens + ) + parseTestAll(parser(""), "((\\x -> x))(f)", expected) - parseTestAll(parser(""), - expected.toDoc.render(80), - expected) + parseTestAll(parser(""), expected.toDoc.render(80), expected) } @@ -724,7 +991,7 @@ x""") test("Declaration.toPattern works for all Pattern-like declarations") { def law1(dec: Declaration.NonBinding) = { Declaration.toPattern(dec) match { - case None => fail("expected to convert to pattern") + case None => fail("expected to convert to pattern") case Some(pat) => // if we convert to string this parses the same as a pattern: val decStr = dec.toDoc.render(80) @@ -739,8 +1006,18 @@ x""") import Identifier.{Name, Operator, Constructor} // this operator application can be a pattern List( - ApplyOp(Var(Name("q")),Operator("|"),Var(Name("npzma"))), - ApplyOp(Parens(ApplyOp(Parens(Literal(Lit.Str("igyimc"))),Operator("|"),Var(Name("ncf5Eo9")))),Operator("|"),Var(Constructor("K"))) + ApplyOp(Var(Name("q")), Operator("|"), Var(Name("npzma"))), + ApplyOp( + Parens( + ApplyOp( + Parens(Literal(Lit.Str("igyimc"))), + Operator("|"), + Var(Name("ncf5Eo9")) + ) + ), + Operator("|"), + Var(Constructor("K")) + ) ) } @@ -750,10 +1027,12 @@ x""") val decStr = dec.toDoc.render(80) val parsePat = optionParse(Pattern.matchParser, decStr) (Declaration.toPattern(dec), parsePat) match { - case (None, None) => succeed + case (None, None) => succeed case (Some(p0), Some(p1)) => assert(p0 == p1) - case (None, Some(_)) => fail(s"toPattern failed, but parsed $decStr to: $parsePat") - case (Some(p), None) => fail(s"toPattern succeeded: $p but pattern parse failed") + case (None, Some(_)) => + fail(s"toPattern failed, but parsed $decStr to: $parsePat") + case (Some(p), None) => + fail(s"toPattern succeeded: $p but pattern parse failed") } } @@ -761,13 +1040,13 @@ x""") forAll(Generators.genNonBinding(5))(law2(_)) regressions.foreach(law2(_)) - def testEqual(decl: String) = { - val dec = unsafeParse(Declaration.parser(""), decl).asInstanceOf[Declaration.NonBinding] + val dec = unsafeParse(Declaration.parser(""), decl) + .asInstanceOf[Declaration.NonBinding] val patt = unsafeParse(Pattern.matchParser, decl) Declaration.toPattern(dec) match { case Some(p2) => assert(p2 == patt) - case None => fail(s"could not convert $decl to pattern") + case None => fail(s"could not convert $decl to pattern") } } @@ -780,28 +1059,70 @@ x""") test("we can parse bind") { import Declaration._ - parseTestAll(parser(""), + parseTestAll( + parser(""), """x = 4 x""", - Binding(BindingStatement(Pattern.Var(Identifier.Name("x")), Literal(Lit.fromInt(4)), Padding(0, mkVar("x"))))) + Binding( + BindingStatement( + Pattern.Var(Identifier.Name("x")), + Literal(Lit.fromInt(4)), + Padding(0, mkVar("x")) + ) + ) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), """x = foo(4) x""", - Binding(BindingStatement(Pattern.Var(Identifier.Name("x")), Apply(mkVar("foo"), NonEmptyList.of(Literal(Lit.fromInt(4))), ApplyKind.Parens), Padding(1, mkVar("x"))))) + Binding( + BindingStatement( + Pattern.Var(Identifier.Name("x")), + Apply( + mkVar("foo"), + NonEmptyList.of(Literal(Lit.fromInt(4))), + ApplyKind.Parens + ), + Padding(1, mkVar("x")) + ) + ) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), """x = foo(4) # x is really great x""", - Binding(BindingStatement(Pattern.Var(Identifier.Name("x")),Apply(mkVar("foo"),NonEmptyList.of(Literal(Lit.fromInt(4))), ApplyKind.Parens),Padding(0,CommentNB(CommentStatement(NonEmptyList.of(" x is really great"),Padding(0,mkVar("x")))))))) + Binding( + BindingStatement( + Pattern.Var(Identifier.Name("x")), + Apply( + mkVar("foo"), + NonEmptyList.of(Literal(Lit.fromInt(4))), + ApplyKind.Parens + ), + Padding( + 0, + CommentNB( + CommentStatement( + NonEmptyList.of(" x is really great"), + Padding(0, mkVar("x")) + ) + ) + ) + ) + ) + ) // allow indentation after = - roundTrip(parser(""), + roundTrip( + parser(""), """x = | foo - |x""".stripMargin) + |x""".stripMargin + ) } test("we can parse if") { @@ -811,103 +1132,139 @@ x""", val liftVar0 = Parser.Indy.lift(varP: P[NonBinding]) val parser0 = ifElseP(liftVar0, liftVar)("") - roundTrip[Declaration](parser0, + roundTrip[Declaration]( + parser0, """if w: x else: - y""") + y""" + ) - roundTrip[Declaration](parser0, + roundTrip[Declaration]( + parser0, """if w: | x |else: - | y""".stripMargin) + | y""".stripMargin + ) - roundTrip(parser(""), + roundTrip( + parser(""), """if eq_Int(x, 3): x else: - y""") + y""" + ) - expectFail(parser0, + expectFail( + parser0, """if x: x else - y""", 18) + y""", + 18 + ) - expectFail(parser0, + expectFail( + parser0, """if x: x -else y""", 13) +else y""", + 13 + ) - expectFail(parser(""), + expectFail( + parser(""), """if x: x else - y""", 18) + y""", + 18 + ) - expectFail(parser(""), + expectFail( + parser(""), """if x: x -else y""", 13) +else y""", + 13 + ) - expectFail(parser(""), + expectFail( + parser(""), """if f: 0 -else 1""", 13) +else 1""", + 13 + ) - roundTrip(parser(""), + roundTrip( + parser(""), """if eq_Int(x, 3): x elif foo: z else: - y""") + y""" + ) - roundTrip[Declaration](parser0, + roundTrip[Declaration]( + parser0, """if w: x -else: y""") - roundTrip(parser(""), +else: y""" + ) + roundTrip( + parser(""), """if eq_Int(x, 3): x -else: y""") +else: y""" + ) - roundTrip(parser(""), + roundTrip( + parser(""), """if eq_Int(x, 3): x elif foo: z -else: y""") +else: y""" + ) } test("we can parse a match") { val liftVar = Parser.Indy.lift(Declaration.varP: P[Declaration]) val liftVar0 = Parser.Indy.lift(Declaration.varP: P[Declaration.NonBinding]) - roundTrip[Declaration](Declaration.matchP(liftVar0, liftVar)(""), -"""match x: + roundTrip[Declaration]( + Declaration.matchP(liftVar0, liftVar)(""), + """match x: case y: z case w: - r""") - roundTrip(Declaration.parser(""), -"""match 1: + r""" + ) + roundTrip( + Declaration.parser(""), + """match 1: case Foo(a, b): a.plus(b) case Bar: - 42""") - roundTrip(Declaration.parser(""), - -"""match 1: + 42""" + ) + roundTrip( + Declaration.parser(""), + """match 1: case (a, b): a.plus(b) case (): - 42""") + 42""" + ) - roundTrip(Declaration.parser(""), - -"""match 1: + roundTrip( + Declaration.parser(""), + """match 1: case (a, (b, c)): a.plus(b).plus(e) case (1,): - 42""") + 42""" + ) - roundTrip(Declaration.parser(""), -"""match 1: + roundTrip( + Declaration.parser(""), + """match 1: case Foo(a, b): a.plus(b) case Bar: @@ -915,20 +1272,24 @@ else: y""") case True: 100 case False: - 99""") + 99""" + ) - roundTrip(Declaration.parser(""), -"""foo(1, match 2: + roundTrip( + Declaration.parser(""), + """foo(1, match 2: case Foo: foo case Bar: # this is the bar case - bar, 100)""") + bar, 100)""" + ) - roundTrip(Declaration.parser(""), -"""if (match 2: + roundTrip( + Declaration.parser(""), + """if (match 2: case Foo: foo @@ -938,93 +1299,127 @@ else: y""") bar): 1 else: - 2""") + 2""" + ) - roundTrip(Declaration.parser(""), -"""if True: + roundTrip( + Declaration.parser(""), + """if True: match 1: case Foo(f): 1 else: - 100""") + 100""" + ) - roundTrip(Declaration.parser(""), -"""match x: + roundTrip( + Declaration.parser(""), + """match x: case Bar(_, _): - 10""") + 10""" + ) - roundTrip(Declaration.parser(""), -"""match x: + roundTrip( + Declaration.parser(""), + """match x: case Bar(_, _): if True: 0 - else: 10""") + else: 10""" + ) - roundTrip(Declaration.parser(""), -"""match x: + roundTrip( + Declaration.parser(""), + """match x: case Bar(_, _): if True: 0 - else: 10""") + else: 10""" + ) - roundTrip(Declaration.parser(""), -"""match x: + roundTrip( + Declaration.parser(""), + """match x: case []: 0 case [x]: 1 - case _: 2""") + case _: 2""" + ) - roundTrip(Declaration.parser(""), -"""Foo(x) = bar -x""") + roundTrip( + Declaration.parser(""), + """Foo(x) = bar +x""" + ) - roundTrip(Declaration.parser(""), -"""Foo { x } = bar -x""") + roundTrip( + Declaration.parser(""), + """Foo { x } = bar +x""" + ) - roundTrip(Declaration.parser(""), -"""Foo { x } = Foo{x:1} -x""") + roundTrip( + Declaration.parser(""), + """Foo { x } = Foo{x:1} +x""" + ) - roundTrip(Declaration.parser(""), -"""match x: - case Some(_) | None: 1""") + roundTrip( + Declaration.parser(""), + """match x: + case Some(_) | None: 1""" + ) - roundTrip(Declaration.parser(""), -"""match x: + roundTrip( + Declaration.parser(""), + """match x: case Some(_) | None: 1 case y: y - case [x | y, _]: z""") - + case [x | y, _]: z""" + ) - roundTrip(Declaration.parser(""), -"""Foo(x) | Bar(x) = bar -x""") + roundTrip( + Declaration.parser(""), + """Foo(x) | Bar(x) = bar +x""" + ) - roundTrip(Declaration.parser(""), -"""(x: Int) = bar -x""") - roundTrip(Declaration.parser(""), -"""x: Int = bar -x""") + roundTrip( + Declaration.parser(""), + """(x: Int) = bar +x""" + ) + roundTrip( + Declaration.parser(""), + """x: Int = bar +x""" + ) } test("we allow extra indentation on elif and else for better alignment") { - roundTrip(Declaration.parser(""), + roundTrip( + Declaration.parser(""), """z = if w: | x | else: | y - |z""".stripMargin) + |z""".stripMargin + ) - roundTrip(Declaration.parser(""), + roundTrip( + Declaration.parser(""), """z = if w: x | elif y: z | else: quux - |z""".stripMargin) + |z""".stripMargin + ) } test("we can parse declaration lists") { - val ll = ListLang.parser(Declaration.parser(""), Declaration.nonBindingParserNoTern(""), Pattern.matchParser) + val ll = ListLang.parser( + Declaration.parser(""), + Declaration.nonBindingParserNoTern(""), + Pattern.matchParser + ) roundTrip(Declaration.parser(""), "[]") roundTrip(Declaration.parser(""), "[1]") @@ -1041,13 +1436,18 @@ x""") roundTrip(ll, "[x for x in range(4) if x.eq_Int(2)]") roundTrip(ListLang.SpliceOrItem.parser(Declaration.parser("")), "a") roundTrip(ListLang.SpliceOrItem.parser(Declaration.parser("")), "foo(a, b)") - roundTrip(ListLang.SpliceOrItem.parser(Declaration.parser("")), "*foo(a, b)") + roundTrip( + ListLang.SpliceOrItem.parser(Declaration.parser("")), + "*foo(a, b)" + ) roundTrip(Declaration.parser(""), "[x for y in [1, 2]]") roundTrip(Declaration.parser(""), "[x for y in [1, 2] if foo]") } test("we can parse any Declaration") { - forAll(Generators.genDeclaration(5))(law(Declaration.parser("").map(_.replaceRegions(emptyRegion)))) + forAll(Generators.genDeclaration(5))( + law(Declaration.parser("").map(_.replaceRegions(emptyRegion))) + ) def decl(s: String) = roundTrip(Declaration.parser(""), s) @@ -1095,26 +1495,35 @@ x""") } test("we can parse any Statement") { - forAll(Generators.genStatements(4, 10))(law(Statement.parser.map(_.map(_.replaceRegions(emptyRegion))))) - - roundTrip(Statement.parser, -"""# -def foo(x): x""") - - roundTrip(Statement.parser, -"""# + forAll(Generators.genStatements(4, 10))( + law(Statement.parser.map(_.map(_.replaceRegions(emptyRegion)))) + ) + + roundTrip( + Statement.parser, + """# +def foo(x): x""" + ) + + roundTrip( + Statement.parser, + """# def foo(x): - x""") + x""" + ) - roundTrip(Statement.parser, -"""# + roundTrip( + Statement.parser, + """# operator + = plus x = 1+2 -""") +""" + ) - roundTrip(Statement.parser, -"""# header + roundTrip( + Statement.parser, + """# header y = if eq_Int(x, 2): True else: @@ -1128,10 +1537,12 @@ fn = \x, y -> x.plus(y) x = ( foo ) -""") +""" + ) - roundTrip(Statement.parser, -"""# header + roundTrip( + Statement.parser, + """# header def foo(x: forall f. f[a] -> f[b], y: a) -> b: x(y) @@ -1140,11 +1551,13 @@ fn = \x, y -> x.plus(y) x = ( foo ) -""") +""" + ) // we can add spaces at the end of the file - roundTrip(Statement.parser, -"""# header + roundTrip( + Statement.parser, + """# header def foo(x: forall f. f[a] -> f[b], y: a) -> b: x(y) @@ -1152,78 +1565,99 @@ def foo(x: forall f. f[a] -> f[b], y: a) -> b: fn = \x, y -> x.plus(y) x = ( foo ) - """) + """ + ) - roundTrip(Statement.parser, -"""# + roundTrip( + Statement.parser, + """# x = Pair([], b) -""") +""" + ) - roundTrip(Statement.parser, -"""# + roundTrip( + Statement.parser, + """# Pair(x, _) = Pair([], b) -""") +""" + ) - roundTrip(Statement.parser, -"""# MONADS!!!! + roundTrip( + Statement.parser, + """# MONADS!!!! struct Monad(pure: forall a. a -> f[a], flatMap: forall a, b. f[a] -> (a -> f[b]) -> f[b]) -""") +""" + ) // we can put new-lines in structs - roundTrip(Statement.parser, -"""# MONADS!!!! + roundTrip( + Statement.parser, + """# MONADS!!!! struct Monad( pure: forall a. a -> f[a], flatMap: forall a, b. f[a] -> (a -> f[b]) -> f[b]) -""") +""" + ) // we can put type params in - roundTrip(Statement.parser, -"""# MONADS!!!! + roundTrip( + Statement.parser, + """# MONADS!!!! struct Monad[f]( pure: forall a. a -> f[a], flatMap: forall a, b. f[a] -> (a -> f[b]) -> f[b]) -""") +""" + ) // we can put new-lines in defs - roundTrip(Statement.parser, -"""# + roundTrip( + Statement.parser, + """# def foo( x, y: Int): x.add(y) -""") +""" + ) roundTrip(Statement.parser, """enum Option: None, Some(a)""") roundTrip(Statement.parser, """enum Option[a]: None, Some(a: a)""") - roundTrip(Statement.parser, -"""enum Option: + roundTrip( + Statement.parser, + """enum Option: None - Some(a)""") + Some(a)""" + ) - roundTrip(Statement.parser, -"""enum Option[a]: + roundTrip( + Statement.parser, + """enum Option[a]: None - Some(a: a)""") - - roundTrip(Statement.parser, -"""enum Option: - None, Some(a)""") - - roundTripExact(Statement.parser, -"""def run(z): + Some(a: a)""" + ) + + roundTrip( + Statement.parser, + """enum Option: + None, Some(a)""" + ) + + roundTripExact( + Statement.parser, + """def run(z): Err(y) | Good(y) = z y -""") +""" + ) } def dropTrailingPadding(s: List[Statement]): List[Statement] = s.reverse.dropWhile { case Statement.PaddingStatement(_) => true - case _ => false + case _ => false }.reverse test("Any statement may append trailing whitespace and continue to parse") { @@ -1233,55 +1667,77 @@ def foo( } } - test("Any statement ending in a newline may have it removed and continue to parse") { + test( + "Any statement ending in a newline may have it removed and continue to parse" + ) { forAll(Generators.genStatement(5)) { s => val str = Document[Statement].document(s).render(80) - roundTrip(Statement.parser.map(dropTrailingPadding(_)), str.reverse.dropWhile(_ == '\n').reverse) + roundTrip( + Statement.parser.map(dropTrailingPadding(_)), + str.reverse.dropWhile(_ == '\n').reverse + ) } } - test("Any declaration may append any whitespace and optionally a comma and parse") { - forAll(Generators.genDeclaration(4), Gen.listOf(Gen.oneOf(' ', '\t')).map(_.mkString), Gen.oneOf(true, false)) { - case (s, ws, comma) => - val str = Document[Declaration].document(s).render(80) + ws + (if (comma) "," else "") - roundTrip(Declaration.parser(""), str, lax = true) + test( + "Any declaration may append any whitespace and optionally a comma and parse" + ) { + forAll( + Generators.genDeclaration(4), + Gen.listOf(Gen.oneOf(' ', '\t')).map(_.mkString), + Gen.oneOf(true, false) + ) { case (s, ws, comma) => + val str = + Document[Declaration].document(s).render(80) + ws + (if (comma) "," + else "") + roundTrip(Declaration.parser(""), str, lax = true) } } test("parse external defs") { - roundTrip(Statement.parser, -"""# header + roundTrip( + Statement.parser, + """# header external foo: String -""") - roundTrip(Statement.parser, -"""# header +""" + ) + roundTrip( + Statement.parser, + """# header external def foo(i: Integer) -> String -""") - roundTrip(Statement.parser, -"""# header +""" + ) + roundTrip( + Statement.parser, + """# header external def foo(i: Integer, b: a) -> String external def foo2(i: Integer, b: a) -> String -""") +""" + ) } - test("we can parse any package") { - roundTrip(Package.parser(None), -""" + roundTrip( + Package.parser(None), + """ package Foo/Bar from Baz import Bippy export foo foo = 1 -""") +""" + ) - val pp = Package.parser(None).map { pack => pack.copy(program = pack.program.map(_.replaceRegions(emptyRegion))) } + val pp = Package.parser(None).map { pack => + pack.copy(program = pack.program.map(_.replaceRegions(emptyRegion))) + } forAll(Generators.packageGen(4))(law(pp)) - roundTripExact(Package.parser(None), -"""package Foo + roundTripExact( + Package.parser(None), + """package Foo enum Res[a, b]: Err(a: a), Good(a: a, b: b) @@ -1292,104 +1748,141 @@ def run(z): y main = run(x) -""") +""" + ) } test("parse errors point near where they occur") { - expectFail(Statement.parser, + expectFail( + Statement.parser, """x = 1 z = 3 z = 4 y = {'x': 'x' : 'y'} -""", 32) +""", + 32 + ) - expectFail(Statement.parser, + expectFail( + Statement.parser, """x = 1 z = ( x = 1 x x) -""", 24) +""", + 24 + ) - expectFail(Statement.parser, + expectFail( + Statement.parser, """x = 1 z = ( x = 1 y = [1, 2, 3] x x) -""", 40) +""", + 40 + ) - expectFail(Statement.parser, + expectFail( + Statement.parser, """z = ( if f: 0 else 1) -""", 23) +""", + 23 + ) - expectFail(Package.parser(None), + expectFail( + Package.parser(None), """package Foo from Baz import a, , b x = 1 -""", 31) +""", + 31 + ) - expectFail(Package.parser(None), + expectFail( + Package.parser(None), """package Foo export x, , y x = 1 -""", 22) +""", + 22 + ) - expectFail(Package.parser(None), + expectFail( + Package.parser(None), """package Foo export x, , x = 1 -""", 22) - expectFail(Package.parser(None), +""", + 22 + ) + expectFail( + Package.parser(None), """package Foo x = Foo(bar if bar) -""", 31) - +""", + 31 + ) - expectFail(Package.parser(None), + expectFail( + Package.parser(None), """package Foo z = [x for x in xs if x < y else ] -""", 41) +""", + 41 + ) } test("using parens to make blocks") { - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( y = 3 y ) -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( # some pattern matching Foo(y, _) = foo y ) -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( # an if/else block if True: 1 else: 0 ) -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( def foo(x): x @@ -1397,10 +1890,13 @@ x = ( foo(1) ) ) -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( # here is foo @@ -1410,33 +1906,44 @@ x = ( foo(1) ) ) -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( y = 3 y ) -""", lax = true) +""", + lax = true + ) } test("lambdas can have new lines") { - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = z -> z -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = z -> # we can comment here z -""", lax = true) +""", + lax = true + ) } test("Parser.integerWithBase works") { @@ -1455,11 +1962,12 @@ x = z -> prefix <- Gen.oneOf(prefix0, prefix0.toLowerCase) bi <- Gen.long.map(BigInteger.valueOf(_)) biStr = bi.toString(base) - withPrefix <- if (biStr(0) == '-') { - intersperse(biStr.tail).map { t => - s"-$prefix$t" - } - } else intersperse(biStr).map(prefix + _) + withPrefix <- + if (biStr(0) == '-') { + intersperse(biStr.tail).map { t => + s"-$prefix$t" + } + } else intersperse(biStr).map(prefix + _) } yield Args(bi, withPrefix, base) forAll(gen) { case Args(bi, inBase, base) => @@ -1468,7 +1976,7 @@ x = z -> assert(biP == bi) assert(b == base) case Left(err) => fail(err.toString) - } + } } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala b/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala index 9b6a8e4d6..64b705868 100644 --- a/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import org.scalacheck.{Gen, Arbitrary} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration} +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class PatternTest extends AnyFunSuite { @@ -23,7 +26,9 @@ class PatternTest extends AnyFunSuite { test("filtering for names not in a pattern is unbind") { forAll(patGen, Gen.listOf(Gen.identifier)) { (p, ids0) => val ids = ids0.map(Identifier.unsafe(_)) - assert(p.unbind == p.filterVars(ids.toSet.filterNot(p.names.toSet[Identifier]))) + assert( + p.unbind == p.filterVars(ids.toSet.filterNot(p.names.toSet[Identifier])) + ) } } @@ -60,27 +65,38 @@ class PatternTest extends AnyFunSuite { // we can name with the same name, and still be singly named assert(Pattern.SinglyNamed.unapply(Pattern.Named(n, p)) == Some(n)) // we can annotate and not lose singly named-ness - assert(Pattern.SinglyNamed.unapply(Pattern.Annotation(p, null)) == Some(n)) + assert( + Pattern.SinglyNamed.unapply(Pattern.Annotation(p, null)) == Some(n) + ) // we can make a union and not lose singly named-ness - assert(Pattern.SinglyNamed.unapply(Pattern.union(Pattern.Var(n), p :: Nil)) == Some(n)) + assert( + Pattern.SinglyNamed.unapply( + Pattern.union(Pattern.Var(n), p :: Nil) + ) == Some(n) + ) case _ => } forAll(patGen) { p => law(p) } - law(Pattern.Named(Identifier.Name("x"), Pattern.Named(Identifier.Name("x"), Pattern.WildCard))) + law( + Pattern.Named( + Identifier.Name("x"), + Pattern.Named(Identifier.Name("x"), Pattern.WildCard) + ) + ) } test("test some examples for singly named") { def check(str: String, nm: String) = pat(str) match { case Pattern.SinglyNamed(n) => assert(n == Identifier.unsafe(nm)) - case other => fail(s"expected singlynamed: $other") + case other => fail(s"expected singlynamed: $other") } def checkNot(str: String) = pat(str) match { case Pattern.SinglyNamed(n) => fail(s"unexpected singlynamed: $n") - case _ => succeed + case _ => succeed } check("foo", "foo") @@ -109,7 +125,12 @@ class PatternTest extends AnyFunSuite { val bar = Identifier.Name("bar") assert(Pattern.Var(foo).substructures.isEmpty) assert(Pattern.Annotation(Pattern.Var(foo), "Type").substructures.isEmpty) - assert(Pattern.Union(Pattern.Var(foo), NonEmptyList.of(Pattern.Var(bar))).substructures.isEmpty) + assert( + Pattern + .Union(Pattern.Var(foo), NonEmptyList.of(Pattern.Var(bar))) + .substructures + .isEmpty + ) } test("unions with total matches work correctly") { @@ -128,7 +149,7 @@ class PatternTest extends AnyFunSuite { assert(sp.matcher(str.take(len)).isDefined) } case _ => () - } + } } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/SelfCallKindTest.scala b/core/src/test/scala/org/bykn/bosatsu/SelfCallKindTest.scala index 5c28e1142..b2b1fbf98 100644 --- a/core/src/test/scala/org/bykn/bosatsu/SelfCallKindTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/SelfCallKindTest.scala @@ -2,14 +2,17 @@ package org.bykn.bosatsu import org.scalacheck.Gen import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.bykn.bosatsu.rankn.NTypeGen import org.bykn.bosatsu.TestUtils.checkLast import org.bykn.bosatsu.Identifier.Name class SelfCallKindTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) def gen[A](g: Gen[A]): Gen[TypedExpr[A]] = @@ -21,8 +24,7 @@ class SelfCallKindTest extends AnyFunSuite { test("test selfCallKind") { import SelfCallKind.{NoCall, NonTailCall, TailCall, apply => selfCallKind} - checkLast( - """ + checkLast(""" enum List[a]: E, NE(head: a, tail: List[a]) enum N: Z, S(prev: N) @@ -32,8 +34,7 @@ def list_len(list, acc): case NE(_, t): list_len(t, S(acc)) """) { te => assert(selfCallKind(Name("list_len"), te) == TailCall) } - checkLast( - """ + checkLast(""" enum List[a]: E, NE(head: a, tail: List[a]) enum N: Z, S(prev: N) @@ -43,8 +44,7 @@ def list_len(list): case NE(_, t): S(list_len(t)) """) { te => assert(selfCallKind(Name("list_len"), te) == NonTailCall) } - checkLast( - """ + checkLast(""" enum List[a]: E, NE(head: a, tail: List[a]) def list_len(list): @@ -56,8 +56,7 @@ def list_len(list): } test("for_all example") { - checkLast( - """ + checkLast(""" enum List[a]: E, NE(head: a, tail: List[a]) enum B: T, F @@ -68,7 +67,12 @@ def for_all(xs: List[a], fn: a -> B) -> B: match fn(head): case T: for_all(tail, fn) case F: F -""") { te => assert(SelfCallKind(Name("for_all"), te) == SelfCallKind.TailCall, s"${te.repr}") } +""") { te => + assert( + SelfCallKind(Name("for_all"), te) == SelfCallKind.TailCall, + s"${te.repr}" + ) + } } test("TypedExpr.Let.selfCallKind terminates and doesn't throw") { @@ -101,4 +105,4 @@ def for_all(xs: List[a], fn: a -> B) -> B: } } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/bykn/bosatsu/SourceConverterTest.scala b/core/src/test/scala/org/bykn/bosatsu/SourceConverterTest.scala index 865c84e36..b7529e91f 100644 --- a/core/src/test/scala/org/bykn/bosatsu/SourceConverterTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/SourceConverterTest.scala @@ -1,7 +1,10 @@ package org.bykn.bosatsu import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration} +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import Identifier.Bindable @@ -10,14 +13,19 @@ import org.scalatest.funsuite.AnyFunSuite class SourceConverterTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 3000 else 20) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 3000 else 20 + ) val genRec = Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive) test("makeLetsUnique preserves let count") { val genLets = for { cnt <- Gen.choose(0, 100) - lets <- Gen.listOfN(cnt, Gen.zip(Generators.bindIdentGen, genRec, Gen.const(()))) + lets <- Gen.listOfN( + cnt, + Gen.zip(Generators.bindIdentGen, genRec, Gen.const(())) + ) } yield lets forAll(genLets) { lets => @@ -41,7 +49,8 @@ class SourceConverterTest extends AnyFunSuite { names <- Gen.listOfN(cnt, Generators.bindIdentGen) namesDistinct = names.distinct lets <- Generators.traverseGen(namesDistinct) { nm => - Gen.zip(genRec, Gen.choose(0, 10)) + Gen + .zip(genRec, Gen.choose(0, 10)) .map { case (r, d) => (nm, r, d) } } } yield lets @@ -58,7 +67,10 @@ class SourceConverterTest extends AnyFunSuite { test("makeLetsUnique applies to rhs for recursive binds") { val genLets = for { cnt <- Gen.choose(0, 100) - lets <- Gen.listOfN(cnt, Generators.bindIdentGen.map { b => (b, RecursionKind.Recursive, b) }) + lets <- Gen.listOfN( + cnt, + Generators.bindIdentGen.map { b => (b, RecursionKind.Recursive, b) } + ) } yield lets forAll(genLets) { lets => @@ -77,12 +89,33 @@ class SourceConverterTest extends AnyFunSuite { { // non recursive val l1 = List( - (Identifier.Name("b"), RecursionKind.NonRecursive, Option.empty[String]), - (Identifier.Name("a"), RecursionKind.NonRecursive, Option.empty[String]), - (Identifier.Name("c"), RecursionKind.NonRecursive, Option.empty[String]), - (Identifier.Name("a"), RecursionKind.NonRecursive, Option.empty[String]), - (Identifier.Name("d"), RecursionKind.NonRecursive, Option.empty[String]), - (Identifier.Name("a"), RecursionKind.NonRecursive, Option.empty[String])) + ( + Identifier.Name("b"), + RecursionKind.NonRecursive, + Option.empty[String] + ), + ( + Identifier.Name("a"), + RecursionKind.NonRecursive, + Option.empty[String] + ), + ( + Identifier.Name("c"), + RecursionKind.NonRecursive, + Option.empty[String] + ), + ( + Identifier.Name("a"), + RecursionKind.NonRecursive, + Option.empty[String] + ), + ( + Identifier.Name("d"), + RecursionKind.NonRecursive, + Option.empty[String] + ), + (Identifier.Name("a"), RecursionKind.NonRecursive, Option.empty[String]) + ) val up1 = SourceConverter.makeLetsUnique(l1) { case (Identifier.Name(n), idx) => @@ -99,7 +132,8 @@ class SourceConverterTest extends AnyFunSuite { (Identifier.Name("c"), RecursionKind.NonRecursive, Some("a0")), (Identifier.Name("a1"), RecursionKind.NonRecursive, Some("a0")), (Identifier.Name("d"), RecursionKind.NonRecursive, Some("a1")), - (Identifier.Name("a"), RecursionKind.NonRecursive, Some("a1"))) + (Identifier.Name("a"), RecursionKind.NonRecursive, Some("a1")) + ) assert(up1 == expectl1) } @@ -111,7 +145,8 @@ class SourceConverterTest extends AnyFunSuite { (Identifier.Name("c"), RecursionKind.Recursive, Option.empty[String]), (Identifier.Name("a"), RecursionKind.Recursive, Option.empty[String]), (Identifier.Name("d"), RecursionKind.Recursive, Option.empty[String]), - (Identifier.Name("a"), RecursionKind.Recursive, Option.empty[String])) + (Identifier.Name("a"), RecursionKind.Recursive, Option.empty[String]) + ) val up1 = SourceConverter.makeLetsUnique(l1) { case (Identifier.Name(n), idx) => @@ -128,7 +163,8 @@ class SourceConverterTest extends AnyFunSuite { (Identifier.Name("c"), RecursionKind.Recursive, Some("a0")), (Identifier.Name("a1"), RecursionKind.Recursive, Some("a1")), (Identifier.Name("d"), RecursionKind.Recursive, Some("a1")), - (Identifier.Name("a"), RecursionKind.Recursive, None)) + (Identifier.Name("a"), RecursionKind.Recursive, None) + ) assert(up1 == expectl1) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala index ed2def56e..b79d772ad 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala @@ -10,13 +10,16 @@ import IorMethods.IorExtension object TestUtils { - def parsedTypeEnvOf(pack: PackageName, str: String): ParsedTypeEnv[Option[Kind.Arg]] = { + def parsedTypeEnvOf( + pack: PackageName, + str: String + ): ParsedTypeEnv[Option[Kind.Arg]] = { val stmt = statementsOf(str) val prog = SourceConverter.toProgram(pack, Nil, stmt) match { - case Ior.Right(prog) => prog + case Ior.Right(prog) => prog case Ior.Both(_, prog) => prog - case Ior.Left(err) => sys.error(err.toString) + case Ior.Left(err) => sys.error(err.toString) } prog.types._2 } @@ -24,9 +27,9 @@ object TestUtils { val predefParsedTypeEnv: ParsedTypeEnv[Option[Kind.Arg]] = { val p = Package.predefPackage val prog = SourceConverter.toProgram(p.name, Nil, p.program) match { - case Ior.Right(prog) => prog + case Ior.Right(prog) => prog case Ior.Both(_, prog) => prog - case Ior.Left(err) => sys.error(err.toString) + case Ior.Left(err) => sys.error(err.toString) } prog.types._2 } @@ -37,46 +40,57 @@ object TestUtils { def statementsOf(str: String): List[Statement] = Parser.unsafeParse(Statement.parser, str) - /** - * Make sure no illegal final types escaped into a TypedExpr - */ + /** Make sure no illegal final types escaped into a TypedExpr + */ def assertValid[A](te: TypedExpr[A]): Unit = { def checkType(t: Type, bound: Set[Type.Var.Bound]): Type = t match { - case t@Type.TyVar(Type.Var.Skolem(_, _, _, _)) => + case t @ Type.TyVar(Type.Var.Skolem(_, _, _, _)) => sys.error(s"illegal skolem ($t) escape in ${te.repr}") case Type.TyVar(Type.Var.Bound(_)) => t - case t@Type.TyMeta(_) => + case t @ Type.TyMeta(_) => sys.error(s"illegal meta ($t) escape in ${te.repr}") case Type.TyApply(left, right) => Type.TyApply( checkType(left, bound).asInstanceOf[Type.Rho], - checkType(right, bound)) + checkType(right, bound) + ) case q: Type.Quantified => - q.copy(in = checkType(q.in, bound ++ q.vars.toList.map(_._1)).asInstanceOf[Type.Rho]) + q.copy(in = + checkType(q.in, bound ++ q.vars.toList.map(_._1)) + .asInstanceOf[Type.Rho] + ) case Type.TyConst(_) => t } te.traverseType[cats.Id](checkType(_, Set.empty)) val tp = te.getType lazy val teStr = Type.fullyResolvedDocument.document(tp).render(80) - scala.Predef.require(Type.freeTyVars(tp :: Nil).isEmpty, - s"illegal inferred type: $teStr in: ${te.repr}") - - scala.Predef.require(Type.metaTvs(tp :: Nil).isEmpty, - s"illegal inferred type: $teStr in: ${te.repr}") + scala.Predef.require( + Type.freeTyVars(tp :: Nil).isEmpty, + s"illegal inferred type: $teStr in: ${te.repr}" + ) + + scala.Predef.require( + Type.metaTvs(tp :: Nil).isEmpty, + s"illegal inferred type: $teStr in: ${te.repr}" + ) } val testPackage: PackageName = PackageName.parts("Test") - def checkLast(statement: String)(fn: TypedExpr[Declaration] => Assertion): Assertion = { + def checkLast( + statement: String + )(fn: TypedExpr[Declaration] => Assertion): Assertion = { val stmts = Parser.unsafeParse(Statement.parser, statement) Package.inferBody(testPackage, Nil, stmts).strictToValidated match { case Validated.Invalid(errs) => val lm = LocationMap(statement) val packMap = Map((testPackage, (lm, statement))) - val msg = errs.toList.map { err => - err.message(packMap, LocationMap.Colorize.None) - }.mkString("", "\n==========\n", "\n") + val msg = errs.toList + .map { err => + err.message(packMap, LocationMap.Colorize.None) + } + .mkString("", "\n==========\n", "\n") fail("inference failure: " + msg) case Validated.Valid(program) => // make sure all the TypedExpr are valid @@ -86,7 +100,9 @@ object TestUtils { } def makeInputArgs(files: List[(Int, Any)]): List[String] = - ("--package_root" :: Int.MaxValue.toString :: Nil) ::: files.flatMap { case (idx, _) => "--input" :: idx.toString :: Nil } + ("--package_root" :: Int.MaxValue.toString :: Nil) ::: files.flatMap { + case (idx, _) => "--input" :: idx.toString :: Nil + } private val module = new MemoryMain[Either[Throwable, *], Int]({ idx => if (idx == Int.MaxValue) Nil @@ -96,24 +112,37 @@ object TestUtils { def evalTest(packages: List[String], mainPackS: String, expected: Value) = { val files = packages.zipWithIndex.map(_.swap) - module.runWith(files)("eval" :: "--main" :: mainPackS :: makeInputArgs(files)) match { + module.runWith(files)( + "eval" :: "--main" :: mainPackS :: makeInputArgs(files) + ) match { case Right(module.Output.EvaluationResult(got, _, gotDoc)) => val gv = got.value - assert(gv == expected, s"${gotDoc.value.render(80)}\n\n$gv != $expected") + assert( + gv == expected, + s"${gotDoc.value.render(80)}\n\n$gv != $expected" + ) case Right(other) => fail(s"got an unexpected success: $other") case Left(err) => module.mainExceptionToString(err) match { case Some(msg) => fail(msg) - case None => fail(s"got an exception: $err") + case None => fail(s"got an exception: $err") } } } - def evalTestJson(packages: List[String], mainPackS: String, expected: Json) = { + def evalTestJson( + packages: List[String], + mainPackS: String, + expected: Json + ) = { val files = packages.zipWithIndex.map(_.swap) - module.runWith(files)("json" :: "write" :: "--main" :: mainPackS :: "--output" :: "-1" :: makeInputArgs(files)) match { + module.runWith(files)( + "json" :: "write" :: "--main" :: mainPackS :: "--output" :: "-1" :: makeInputArgs( + files + ) + ) match { case Right(module.Output.JsonOutput(got, _)) => assert(got == expected, s"$got != $expected") case Right(other) => @@ -123,15 +152,25 @@ object TestUtils { } } - def runBosatsuTest(packages: List[String], mainPackS: String, assertionCount: Int) = { + def runBosatsuTest( + packages: List[String], + mainPackS: String, + assertionCount: Int + ) = { val files = packages.zipWithIndex.map(_.swap) - module.runWith(files)("test" :: "--test_package" :: mainPackS :: makeInputArgs(files)) match { + module.runWith(files)( + "test" :: "--test_package" :: mainPackS :: makeInputArgs(files) + ) match { case Right(module.Output.TestOutput(results, _)) => results.collect { case (_, Some(t)) => t.value } match { case t :: Nil => - assert(t.assertions == assertionCount, s"${t.assertions} != $assertionCount") - val Test.Report(_, failcount, message) = Test.report(t, LocationMap.Colorize.None) + assert( + t.assertions == assertionCount, + s"${t.assertions} != $assertionCount" + ) + val Test.Report(_, failcount, message) = + Test.report(t, LocationMap.Colorize.None) assert(t.failures.map(_.assertions).getOrElse(0) == failcount) if (failcount > 0) fail(message.render(80)) else succeed @@ -151,44 +190,54 @@ object TestUtils { } } - def testInferred(packages: List[String], mainPackS: String, inferredHandler: (PackageMap.Inferred, PackageName) => Assertion)(implicit ec: Par.EC) = { + def testInferred( + packages: List[String], + mainPackS: String, + inferredHandler: (PackageMap.Inferred, PackageName) => Assertion + )(implicit ec: Par.EC) = { val mainPack = PackageName.parse(mainPackS).get val parsed = packages.zipWithIndex.traverse { case (pack, i) => Parser.parse(Package.parser(None), pack).map { case (lm, parsed) => ((i.toString, lm), parsed) } - } + } val parsedPaths = parsed match { case Validated.Valid(vs) => vs case Validated.Invalid(errs) => errs.toList.foreach { p => - System.err.println(p.showContext(LocationMap.Colorize.None).render(80)) + System.err.println( + p.showContext(LocationMap.Colorize.None).render(80) + ) } - sys.error("failed to parse") //errs.toString) + sys.error("failed to parse") // errs.toString) } val fullParsed = - PackageMap.withPredefA(("predef", LocationMap("")), parsedPaths) - .map { case ((path, _), p) => (path, p) } + PackageMap + .withPredefA(("predef", LocationMap("")), parsedPaths) + .map { case ((path, _), p) => (path, p) } PackageMap - .resolveThenInfer(fullParsed , Nil).strictToValidated match { - case Validated.Valid(packMap) => - inferredHandler(packMap, mainPack) - - case Validated.Invalid(errs) => - val tes = errs.toList.collect { - case PackageError.TypeErrorIn(te, _) => - te.toString + .resolveThenInfer(fullParsed, Nil) + .strictToValidated match { + case Validated.Valid(packMap) => + inferredHandler(packMap, mainPack) + + case Validated.Invalid(errs) => + val tes = errs.toList + .collect { case PackageError.TypeErrorIn(te, _) => + te.toString } .mkString("\n") - fail(tes + "\n" + errs.toString) - } + fail(tes + "\n" + errs.toString) + } } - def evalFail(packages: List[String])(errFn: PartialFunction[PackageError, Unit])(implicit ec: Par.EC) = { + def evalFail( + packages: List[String] + )(errFn: PartialFunction[PackageError, Unit])(implicit ec: Par.EC) = { val parsed = packages.zipWithIndex.traverse { case (pack, i) => Parser.parse(Package.parser(None), pack).map { case (lm, parsed) => @@ -217,9 +266,8 @@ object TestUtils { errs.toList.foreach(_.message(sm, LocationMap.Colorize.None)) assert(true) case Some(errs) => - fail(s"failed, but no type errors: $errs") + fail(s"failed, but no type errors: $errs") } } - } diff --git a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala index 922b610b8..5316af028 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala @@ -16,7 +16,8 @@ import org.typelevel.paiges.Document import Identifier.Constructor import org.scalacheck.Shrink -class TotalityTest extends SetOpsLaws[Pattern[(PackageName, Constructor), Type]] { +class TotalityTest + extends SetOpsLaws[Pattern[(PackageName, Constructor), Type]] { import Generators.shrinkPattern type Pat = Pattern[(PackageName, Constructor), Type] @@ -30,9 +31,8 @@ class TotalityTest extends SetOpsLaws[Pattern[(PackageName, Constructor), Type]] Generators.genCompiledPattern(5, useAnnotation = false) def showPat(pat: Pattern[(PackageName, Constructor), Type]): String = { - val pat0 = pat.mapName { - case (_, n) => - Pattern.StructKind.Named(n, Pattern.StructKind.Style.TupleLike) + val pat0 = pat.mapName { case (_, n) => + Pattern.StructKind.Named(n, Pattern.StructKind.Style.TupleLike) } implicit val tdoc = Type.fullyResolvedDocument @@ -64,8 +64,8 @@ enum Bool: False, True override def genItem: Gen[Pattern[(PackageName, Constructor), Type]] = genPattern - override val shrinkItem: Shrink[Pattern[(PackageName, Constructor),Type]] = - shrinkPattern + override val shrinkItem: Shrink[Pattern[(PackageName, Constructor), Type]] = + shrinkPattern val genPatternNoUnion: Gen[Pattern[(PackageName, Constructor), Type]] = Generators.genCompiledPattern(5, useUnion = false, useAnnotation = false) @@ -84,62 +84,76 @@ enum Bool: False, True new Eq[List[Pattern[(PackageName, Constructor), Type]]] { val e1 = PredefTotalityCheck.eqPat - def eqv(a: List[Pattern[(PackageName, Constructor), Type]], - b: List[Pattern[(PackageName, Constructor), Type]]) = - (NonEmptyList.fromList(a), NonEmptyList.fromList(b)) match { - case (oa, ob) if oa == ob => true - case (Some(a), Some(b)) => - e1.eqv(Pattern.union(a.head, a.tail), Pattern.union(b.head, b.tail)) - case _ => false - } + def eqv( + a: List[Pattern[(PackageName, Constructor), Type]], + b: List[Pattern[(PackageName, Constructor), Type]] + ) = + (NonEmptyList.fromList(a), NonEmptyList.fromList(b)) match { + case (oa, ob) if oa == ob => true + case (Some(a), Some(b)) => + e1.eqv(Pattern.union(a.head, a.tail), Pattern.union(b.head, b.tail)) + case _ => false + } } def eqUnion: Gen[Eq[List[Pattern[(PackageName, Constructor), Type]]]] = Gen.const(eqPatterns) def patterns(str: String): List[Pattern[(PackageName, Constructor), Type]] = { - val nameToCons: Constructor => (PackageName, Constructor) = - { cons => (PackageName.PredefName, cons) } - - /** - * This is sufficient for these tests, but is not - * a full features pattern compiler. - */ - def parsedToExpr(pat: Pattern.Parsed): Pattern[(PackageName, Constructor), rankn.Type] = - pat.mapStruct[(PackageName, Constructor)] { - case (Pattern.StructKind.Tuple, args) => - // this is a tuple pattern - def loop(args: List[Pattern[(PackageName, Constructor), TypeRef]]): Pattern[(PackageName, Constructor), TypeRef] = - args match { - case Nil => - // () - Pattern.PositionalStruct( - (PackageName.PredefName, Constructor("Unit")), - Nil) - case h :: tail => - val tailP = loop(tail) - Pattern.PositionalStruct( - (PackageName.PredefName, Constructor("TupleCons")), - h :: tailP :: Nil) - } - - loop(args) - case (Pattern.StructKind.Named(nm, _), args) => - Pattern.PositionalStruct(nameToCons(nm), args) - case (Pattern.StructKind.NamedPartial(nm, _), args) => - Pattern.PositionalStruct(nameToCons(nm), args) - } - .mapType { tref => - TypeRefConverter[cats.Id](tref) { tpe => - Type.Const.Defined(PackageName.PredefName, TypeName(tpe)) + val nameToCons: Constructor => (PackageName, Constructor) = { cons => + (PackageName.PredefName, cons) + } + + /** This is sufficient for these tests, but is not a full features pattern + * compiler. + */ + def parsedToExpr( + pat: Pattern.Parsed + ): Pattern[(PackageName, Constructor), rankn.Type] = + pat + .mapStruct[(PackageName, Constructor)] { + case (Pattern.StructKind.Tuple, args) => + // this is a tuple pattern + def loop( + args: List[Pattern[(PackageName, Constructor), TypeRef]] + ): Pattern[(PackageName, Constructor), TypeRef] = + args match { + case Nil => + // () + Pattern.PositionalStruct( + (PackageName.PredefName, Constructor("Unit")), + Nil + ) + case h :: tail => + val tailP = loop(tail) + Pattern.PositionalStruct( + (PackageName.PredefName, Constructor("TupleCons")), + h :: tailP :: Nil + ) + } + + loop(args) + case (Pattern.StructKind.Named(nm, _), args) => + Pattern.PositionalStruct(nameToCons(nm), args) + case (Pattern.StructKind.NamedPartial(nm, _), args) => + Pattern.PositionalStruct(nameToCons(nm), args) + } + .mapType { tref => + TypeRefConverter[cats.Id](tref) { tpe => + Type.Const.Defined(PackageName.PredefName, TypeName(tpe)) + } } - } - Parser.unsafeParse(Pattern.matchParser.listSyntax, str) + Parser + .unsafeParse(Pattern.matchParser.listSyntax, str) .map(parsedToExpr _) } - def notTotal(te: TypeEnv[Any], pats: List[Pattern[(PackageName, Constructor), Type]], testMissing: Boolean = true): Unit = { + def notTotal( + te: TypeEnv[Any], + pats: List[Pattern[(PackageName, Constructor), Type]], + testMissing: Boolean = true + ): Unit = { val res = TotalityCheck(te).missingBranches(pats).isEmpty assert(!res, pats.toString) @@ -157,7 +171,11 @@ enum Bool: False, True } } - def testTotality(te: TypeEnv[Any], pats: List[Pattern[(PackageName, Constructor), Type]], tight: Boolean = false)(implicit loc: munit.Location) = { + def testTotality( + te: TypeEnv[Any], + pats: List[Pattern[(PackageName, Constructor), Type]], + tight: Boolean = false + )(implicit loc: munit.Location) = { val res = TotalityCheck(te).missingBranches(pats) val asStr = res.map(showPat) assertEquals(asStr, Nil, showPats(pats)) @@ -165,7 +183,7 @@ enum Bool: False, True // any missing pattern shouldn't be total: def allButOne[A](head: A, tail: List[A]): List[List[A]] = tail match { - case Nil => Nil + case Nil => Nil case h :: rest => // we can either delete the head or one from the tail: val keepHead = allButOne(h, rest).map(head :: _) @@ -174,7 +192,9 @@ enum Bool: False, True pats match { case h :: tail if tight => - allButOne(h, tail).foreach(notTotal(te, _, testMissing = false)) // don't make an infinite loop here + allButOne(h, tail).foreach( + notTotal(te, _, testMissing = false) + ) // don't make an infinite loop here case _ => () } } @@ -192,7 +212,6 @@ struct Unit val pats = patterns("[Unit]") testTotality(te, pats) - val te1 = typeEnvOf("""# struct TupleCons(a, b) """) @@ -209,7 +228,11 @@ enum Option: None, Some(get) testTotality(te, patterns("[Some(_) | None]"), tight = true) testTotality(te, patterns("[Some(_), _]")) testTotality(te, patterns("[Some(1), Some(x), None]")) - testTotality(te, patterns("[Some(Some(_)), Some(None), None]"), tight = true) + testTotality( + te, + patterns("[Some(Some(_)), Some(None), None]"), + tight = true + ) testTotality(te, patterns("[Some(Some(_) | None), None]"), tight = true) notTotal(te, patterns("[Some(_)]")) @@ -224,13 +247,19 @@ enum Option: None, Some(get) enum Either: Left(l), Right(r) """) testTotality(te, patterns("[Left(_), Right(_)]")) - testTotality(te, - patterns("[Left(Right(_)), Left(Left(_)), Right(Left(_)), Right(Right(_))]"), - tight = true) + testTotality( + te, + patterns( + "[Left(Right(_)), Left(Left(_)), Right(Left(_)), Right(Right(_))]" + ), + tight = true + ) - testTotality(te, + testTotality( + te, patterns("[Left(Right(_) | Left(_)), Right(Left(_) | Right(_))]"), - tight = true) + tight = true + ) notTotal(te, patterns("[Left(_)]")) notTotal(te, patterns("[Right(_)]")) @@ -240,10 +269,22 @@ enum Either: Left(l), Right(r) test("test List matching") { testTotality(predefTE, patterns("[[], [h, *tail]]"), tight = true) - testTotality(predefTE, patterns("[[], [h, *tail], [h0, h1, *tail]]"), tight = true) + testTotality( + predefTE, + patterns("[[], [h, *tail], [h0, h1, *tail]]"), + tight = true + ) testTotality(predefTE, patterns("[[], [*tail, _]]"), tight = true) - testTotality(predefTE, patterns("[[*_, True, *_], [], [False, *_]]"), tight = true) - testTotality(predefTE, patterns("[[*_, True, *_], [] | [False, *_]]"), tight = true) + testTotality( + predefTE, + patterns("[[*_, True, *_], [], [False, *_]]"), + tight = true + ) + testTotality( + predefTE, + patterns("[[*_, True, *_], [] | [False, *_]]"), + tight = true + ) notTotal(predefTE, patterns("[[], [h, *tail, _]]")) } @@ -255,44 +296,66 @@ enum Option: None, Some(get) struct TupleCons(fst, snd) """) - testTotality(te, patterns("[None, Some(Left(_)), Some(Right(_))]"), tight = true) + testTotality( + te, + patterns("[None, Some(Left(_)), Some(Right(_))]"), + tight = true + ) testTotality(te, patterns("[None, Some(Left(_) | Right(_))]"), tight = true) - testTotality(te, patterns("[None, Some(TupleCons(Left(_), _)), Some(TupleCons(_, Right(_))), Some(TupleCons(Right(_), Left(_)))]"), tight = true) - testTotality(te, patterns("[None, Some(TupleCons(Left(_), _) | TupleCons(_, Right(_))), Some(TupleCons(Right(_), Left(_)))]"), tight = true) + testTotality( + te, + patterns( + "[None, Some(TupleCons(Left(_), _)), Some(TupleCons(_, Right(_))), Some(TupleCons(Right(_), Left(_)))]" + ), + tight = true + ) + testTotality( + te, + patterns( + "[None, Some(TupleCons(Left(_), _) | TupleCons(_, Right(_))), Some(TupleCons(Right(_), Left(_)))]" + ), + tight = true + ) } test("compose List with structs") { val te = typeEnvOf("""# enum Either: Left(l), Right(r) """) - testTotality(te, patterns("[[Left(_), *_], [Right(_), *_], [], [_, _, *_]]"), tight = true) - testTotality(te, patterns("[Left([]), Left([h, *_]), Right([]), Right([h, *_])]"), tight = true) + testTotality( + te, + patterns("[[Left(_), *_], [Right(_), *_], [], [_, _, *_]]"), + tight = true + ) + testTotality( + te, + patterns("[Left([]), Left([h, *_]), Right([]), Right([h, *_])]"), + tight = true + ) } - test("test intersection") { val p0 :: p1 :: p1norm :: Nil = patterns("[[*_], [*_, _], [_, *_]]") - PredefTotalityCheck.intersection(p0, p1) match { - case List(intr) => assert(intr == p1norm) - case other => fail(s"expected exactly one intersection: $other") - } + PredefTotalityCheck.intersection(p0, p1) match { + case List(intr) => assert(intr == p1norm) + case other => fail(s"expected exactly one intersection: $other") + } val p2 :: p3 :: Nil = patterns("[[*_], [_, _]]") - PredefTotalityCheck.intersection(p2, p3) match { - case List(intr) => assert(p3 == intr) - case other => fail(s"expected exactly one intersection: $other") - } - + PredefTotalityCheck.intersection(p2, p3) match { + case List(intr) => assert(p3 == intr) + case other => fail(s"expected exactly one intersection: $other") + } + // a regression { - val p0 :: p1 :: p2 :: Nil = patterns( - """["${_}$.{_}$.{_}", + val p0 :: p1 :: p2 :: Nil = patterns("""["${_}$.{_}$.{_}", "$.{foo}", "baz"]""") assert(PredefTotalityCheck.intersection(p0, p1).isEmpty) assert(PredefTotalityCheck.intersection(p1, p2).isEmpty) - + import pattern.SeqPattern.{stringUnitMatcher, Cat, Empty} import pattern.SeqPart.AnyElem import pattern.Splitter.stringUnit @@ -334,7 +397,7 @@ enum Either: Left(l), Right(r) val p0 :: p1 :: Nil = patterns("[[*_, _], [_, *_]]") PredefTotalityCheck.intersection(p0, p1) match { case List(res) if res == p0 || res == p1 => () - case Nil => fail("these do overlap") + case Nil => fail("these do overlap") case nonUnified => fail(s"didn't unify to one: $nonUnified") } } @@ -347,13 +410,52 @@ enum Either: Left(l), Right(r) import Identifier.Name val regressions: List[(Pat, Pat, Pat)] = - (Named(Name("hTt"), StrPat(NonEmptyList.of(NamedStr(Name("rfb")), LitStr("q"), NamedStr(Name("ngkrx"))))), + ( + Named( + Name("hTt"), + StrPat( + NonEmptyList + .of(NamedStr(Name("rfb")), LitStr("q"), NamedStr(Name("ngkrx"))) + ) + ), WildCard, - Named(Name("hjbmtklh"),StrPat(NonEmptyList.of(NamedStr(Name("qz8lcT")), WildStr, LitStr("p7"), NamedStr(Name("hqxprG")))))) :: - (WildCard, - ListPat(List(NamedList(Name("nv6")), Item(Literal(Lit.fromInt(-17))), Item(WildCard))), - ListPat(List(Item(StrPat(NonEmptyList.of(WildStr))), Item(StrPat(NonEmptyList.of(NamedStr(Name("eejhh")), LitStr("jbuzfcwsumP"), WildStr)))))) :: - Nil + Named( + Name("hjbmtklh"), + StrPat( + NonEmptyList.of( + NamedStr(Name("qz8lcT")), + WildStr, + LitStr("p7"), + NamedStr(Name("hqxprG")) + ) + ) + ) + ) :: + ( + WildCard, + ListPat( + List( + NamedList(Name("nv6")), + Item(Literal(Lit.fromInt(-17))), + Item(WildCard) + ) + ), + ListPat( + List( + Item(StrPat(NonEmptyList.of(WildStr))), + Item( + StrPat( + NonEmptyList.of( + NamedStr(Name("eejhh")), + LitStr("jbuzfcwsumP"), + WildStr + ) + ) + ) + ) + ) + ) :: + Nil regressions.foreach { case (a, b, c) => diffIntersectionLaw(a, b, c) @@ -364,16 +466,21 @@ enum Either: Left(l), Right(r) // see: https://github.com/johnynek/bosatsu/issues/475 def law(x: Pat, y: Pat)(implicit loc: munit.Location) = { if (setOps.isTop(y)) - assert(setOps.difference(x, y).isEmpty, s"x = ${showPat(x)}, y = ${showPat(y)}") + assert( + setOps.difference(x, y).isEmpty, + s"x = ${showPat(x)}, y = ${showPat(y)}" + ) } val regressions: List[(Pat, Pat)] = - List( - { - val struct = Pattern.PositionalStruct((PackageName(NonEmptyList.of("Pack")), Identifier.Constructor("Foo")), Nil) - val lst = Pattern.ListPat(List(Pattern.ListPart.WildList)) - (struct, lst) - }) + List({ + val struct = Pattern.PositionalStruct( + (PackageName(NonEmptyList.of("Pack")), Identifier.Constructor("Foo")), + Nil + ) + val lst = Pattern.ListPat(List(Pattern.ListPart.WildList)) + (struct, lst) + }) regressions.foreach { case (a, b) => law(a, b) } } @@ -383,35 +490,42 @@ enum Either: Left(l), Right(r) (pats(0), pats(1)) } test("subset consistency regressions") { - val regressions: List[(Pat, Pat)] = - { - val struct = Pattern.PositionalStruct((PackageName(NonEmptyList.of("Pack")), Identifier.Constructor("Foo")), Nil) - val lst = Pattern.ListPat(List(Pattern.ListPart.WildList)) - (struct, lst) - } :: { - import Pattern._ - import ListPart._ - - val a = WildCard - /* + val regressions: List[(Pat, Pat)] = { + val struct = Pattern.PositionalStruct( + (PackageName(NonEmptyList.of("Pack")), Identifier.Constructor("Foo")), + Nil + ) + val lst = Pattern.ListPat(List(Pattern.ListPart.WildList)) + (struct, lst) + } :: { + import Pattern._ + import ListPart._ + + val a = WildCard + /* val b = Union(ListPat(List()),NonEmptyList.of( ListPat(List(Item(Named(n("tmxb"),Union(Var(n("op")),NonEmptyList.of(Var(n("mjpqdwRbkz")), Literal(Lit.Chr("鱛")))))))), ListPat(List(Item(ListPat(List(WildList))), Item(Named(n("e7psNp0ok"),WildCard)), WildList)))) - */ - - val b = Union( - ListPat(List()), - NonEmptyList.of( - //ListPat(List(Item(Union(WildCard, NonEmptyList.of(WildCard, Literal(Lit.Chr("鱛"))))))), - ListPat(List(Item(WildCard))), - ListPat(List(Item(ListPat(List(WildList))), Item(WildCard), WildList)))) - //ListPat(List(Item(WildCard), Item(WildCard), WildList)))) - - assert(setOps.isTop(ListPat(List(WildList)))) - assertEquals(setOps.unifyUnion(Pattern.flatten(b).toList), WildCard :: Nil) - assertEquals(setOps.relate(a, b), Rel.Same) - (a, b) - } :: + */ + + val b = Union( + ListPat(List()), + NonEmptyList.of( + // ListPat(List(Item(Union(WildCard, NonEmptyList.of(WildCard, Literal(Lit.Chr("鱛"))))))), + ListPat(List(Item(WildCard))), + ListPat(List(Item(ListPat(List(WildList))), Item(WildCard), WildList)) + ) + ) + // ListPat(List(Item(WildCard), Item(WildCard), WildList)))) + + assert(setOps.isTop(ListPat(List(WildList)))) + assertEquals( + setOps.unifyUnion(Pattern.flatten(b).toList), + WildCard :: Nil + ) + assertEquals(setOps.relate(a, b), Rel.Same) + (a, b) + } :: Nil regressions.foreach { case (a, b) => @@ -419,7 +533,6 @@ enum Either: Left(l), Right(r) } } - test("difference is idempotent regressions") { import Pattern._ import ListPart._ @@ -430,13 +543,23 @@ enum Either: Left(l), Right(r) List( { val left = ListPat(List(Item(WildCard), WildList)) - val right = ListPat(List(Item(Var(Name("bey6ct"))), Item(Literal(Lit.fromInt(42))), Item(StrPat(NonEmptyList.of(WildStr))), Item(Literal(Lit("agfn"))), Item(WildCard))) + val right = ListPat( + List( + Item(Var(Name("bey6ct"))), + Item(Literal(Lit.fromInt(42))), + Item(StrPat(NonEmptyList.of(WildStr))), + Item(Literal(Lit("agfn"))), + Item(WildCard) + ) + ) (left, right) }, pair("""[[] | [_, *_], "$.{_}${_}"]""") ) - regressions.foreach { case (a, b) => differenceIsIdempotent(a, b, eqPatterns) } + regressions.foreach { case (a, b) => + differenceIsIdempotent(a, b, eqPatterns) + } } test("if a n b = 0 then a - b = a regressions") { @@ -450,17 +573,28 @@ enum Either: Left(l), Right(r) List( { val left = ListPat(List(Item(WildCard), WildList)) - val right = ListPat(List(Item(Var(Name("bey6ct"))), Item(Literal(Lit.fromInt(42))), Item(StrPat(NonEmptyList.of(WildStr))), Item(Literal(Lit("agfn"))), Item(WildCard))) + val right = ListPat( + List( + Item(Var(Name("bey6ct"))), + Item(Literal(Lit.fromInt(42))), + Item(StrPat(NonEmptyList.of(WildStr))), + Item(Literal(Lit("agfn"))), + Item(WildCard) + ) + ) (left, right) - }, - { - val left = ListPat(List(NamedList(Name("a")), Item(WildCard), Item(Var(Name("b"))))) + }, { + val left = ListPat( + List(NamedList(Name("a")), Item(WildCard), Item(Var(Name("b")))) + ) val right = ListPat(List()) (left, right) } ) - regressions.foreach { case (a, b) => emptyIntersectionMeansDiffIdent(a, b, eqPatterns) } + regressions.foreach { case (a, b) => + emptyIntersectionMeansDiffIdent(a, b, eqPatterns) + } } test("difference returns distinct regressions") { @@ -478,8 +612,7 @@ enum Either: Left(l), Right(r) val tc = PredefTotalityCheck { - val p0 :: p1 :: Nil = patterns( - """[ + val p0 :: p1 :: Nil = patterns("""[ "${_}$.{_}", "${_}$.{_}${_}", ]""") @@ -502,7 +635,7 @@ enum Either: Left(l), Right(r) override def missingBranchesIfAddedRegressions: List[List[Pat]] = { patterns("""[[*foo, "$.{_}", "$.{_}"], [[b, *_]]]""") :: - Nil + Nil } test("var pattern is super or same") { @@ -510,7 +643,7 @@ enum Either: Left(l), Right(r) val p1 :: p2 :: _ = patterns("""[foo, Bar(1)]""") val rel = tc.patternSetOps.relate(p1, p2) - assertEquals(rel, Rel.Super) + assertEquals(rel, Rel.Super) } test("union commutes with type wrappers: Some(1 | 2) == Some(1) | Some(2)") { @@ -519,40 +652,43 @@ enum Either: Left(l), Right(r) { val p1 :: p2 :: _ = patterns("""[Some(1 | 2), Some(1) | Some(2)]""") val rel = tc.patternSetOps.relate(p1, p2) - assertEquals(rel, Rel.Same) + assertEquals(rel, Rel.Same) } { val p1 :: p2 :: _ = patterns("""[Some(1 | 2 | 3), Some(1) | Some(2)]""") val rel = tc.patternSetOps.relate(p1, p2) - assertEquals(rel, Rel.Super) + assertEquals(rel, Rel.Super) } } property("unifyUnion returns no top-level unions") { forAll(Gen.listOf(genPattern)) { pats => - val unions = setOps.unifyUnion(pats).collect { case u @ Pattern.Union(_, _) => u } + val unions = + setOps.unifyUnion(pats).collect { case u @ Pattern.Union(_, _) => u } assertEquals(unions, Nil) } } property("unifyUnion(u) <:> u == Same") { def law(pat1: Pat, pat2: Pat)(implicit loc: munit.Location) = { - val unions = NonEmptyList.fromListUnsafe(setOps.unifyUnion(pat1 :: pat2 :: Nil)) + val unions = + NonEmptyList.fromListUnsafe(setOps.unifyUnion(pat1 :: pat2 :: Nil)) val u1 = Pattern.union(unions.head, unions.tail) assertEquals( setOps.relate(Pattern.union(pat1, pat2 :: Nil), u1), Rel.Same, - s"p1 = ${showPat(pat1)}\np2 = ${showPat(pat2)}\nunified = ${showPat(u1)}") + s"p1 = ${showPat(pat1)}\np2 = ${showPat(pat2)}\nunified = ${showPat(u1)}" + ) } val regressions = pair("""["$.{_}${_}$.{_}", "$.{_}${_}"]""") :: - pair("""["$.{_}", "${_}$.{_}$.{_}" as e]""") :: - pair("""["$.{a}", "${b}$.{c}$.{d}" as e]""") :: - pair("""["$.{bar}" as baz, "${_}${_}$.{c}${d}"]""") :: - pair("""["$.{bar}${_}$.{_}", "$.{_}${_}" as foo]""") :: - Nil + pair("""["$.{_}", "${_}$.{_}$.{_}" as e]""") :: + pair("""["$.{a}", "${b}$.{c}$.{d}" as e]""") :: + pair("""["$.{bar}" as baz, "${_}${_}$.{c}${d}"]""") :: + pair("""["$.{bar}${_}$.{_}", "$.{_}${_}" as foo]""") :: + Nil regressions.foreach { case (a, b) => law(a, b) } forAll(genPattern, genPattern)(law(_, _)) @@ -561,7 +697,7 @@ enum Either: Left(l), Right(r) test("x - y where isTop(y) regressions") { val regressions = (pair("""["foo", ([] | [_, *_])]"""), true) :: - Nil + Nil regressions.foreach { case ((x, y), top) => val rel = setOps.relate(x, y) @@ -570,7 +706,11 @@ enum Either: Left(l), Right(r) assert(yIsTop) } if (yIsTop) { - assertEquals(setOps.difference(x, y), Nil, s"${showPat(x)} - ${showPat(y)}, rel = $rel") + assertEquals( + setOps.difference(x, y), + Nil, + s"${showPat(x)} - ${showPat(y)}, rel = $rel" + ) } } } @@ -580,8 +720,8 @@ enum Either: Left(l), Right(r) val normp = PredefTotalityCheck.normalizePattern(p) assertEquals( setOps.relate(normp, q), - setOps.relate(p, q), - ) + setOps.relate(p, q) + ) assertEquals(setOps.relate(normp, p), Rel.Same) } diff --git a/core/src/test/scala/org/bykn/bosatsu/TypeRefTest.scala b/core/src/test/scala/org/bykn/bosatsu/TypeRefTest.scala index 9eca9bfe3..34eb06590 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TypeRefTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TypeRefTest.scala @@ -1,12 +1,15 @@ package org.bykn.bosatsu -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.bykn.bosatsu.rankn.Type import org.scalatest.funsuite.AnyFunSuite class TypeRefTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 500000) + // PropertyCheckConfiguration(minSuccessful = 500000) PropertyCheckConfiguration(minSuccessful = 5000) import Generators.{typeRefGen, shrinkTypeRef} @@ -24,14 +27,18 @@ class TypeRefTest extends AnyFunSuite { val pn = PackageName.parts("Test") forAll(typeRefGen) { tr => - val tpe = TypeRefConverter[cats.Id](tr) { c => Type.Const.Defined(pn, TypeName(c)) } - val tr1 = TypeRefConverter.fromTypeA[Option](tpe, + val tpe = TypeRefConverter[cats.Id](tr) { c => + Type.Const.Defined(pn, TypeName(c)) + } + val tr1 = TypeRefConverter.fromTypeA[Option]( + tpe, { _ => None }, { _ => None }, { case Type.Const.Defined(p, t) if p == pn => Some(TypeRef.TypeName(t)) - case _ => None - }) + case _ => None + } + ) assert(tr1 == Some(tr.normalizeForAll), s"tpe = $tpe") } diff --git a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala index ecf90c59b..a5a34e9ba 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala @@ -4,7 +4,10 @@ import cats.data.{NonEmptyList, State, Writer} import cats.implicits._ import org.scalacheck.{Arbitrary, Gen} import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import scala.collection.immutable.SortedSet import Arbitrary.arbitrary @@ -15,21 +18,21 @@ import rankn.{Type, NTypeGen} class TypedExprTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) def allVars[A](te: TypedExpr[A]): Set[Bindable] = { type W[B] = Writer[Set[Bindable], B] te.traverseUp[W] { - case v@TypedExpr.Local(ident, _, _) => Writer(Set(ident), v) - case notVar => Writer(Set.empty, notVar) - }.run._1 + case v @ TypedExpr.Local(ident, _, _) => Writer(Set(ident), v) + case notVar => Writer(Set.empty, notVar) + }.run + ._1 } - /** - * Assert two bits of code normalize to the same thing - */ + /** Assert two bits of code normalize to the same thing + */ def normSame(s1: String, s2: String) = checkLast(s1) { t1 => checkLast(s2) { t2 => @@ -42,7 +45,10 @@ class TypedExprTest extends AnyFunSuite { val frees = TypedExpr.freeVarsSet(te :: Nil).toSet val av = allVars(te) val missing = frees -- av - assert(missing.isEmpty, s"expression:\n\n${te.repr}\n\nallVars: $av\n\nfrees: $frees") + assert( + missing.isEmpty, + s"expression:\n\n${te.repr}\n\nallVars: $av\n\nfrees: $frees" + ) } forAll(genTypedExpr)(law _) @@ -97,16 +103,19 @@ y = match x: case notLit => fail(s"expected Literal got: ${notLit.repr}") } - normSame("""# + normSame( + """# struct Tup2(a, b) x = 23 x = Tup2(1, 2) y = match x: case Tup2(a, _): a -""", """# +""", + """# y = 1 -""") +""" + ) checkLast("""# struct Tup2(a, b) @@ -234,7 +243,8 @@ y = match x: } test("we can lift a match above a lambda") { - normSame("""# + normSame( + """# struct Tup2(a, b) y = Tup2(1, 2) @@ -242,12 +252,15 @@ y = Tup2(1, 2) def inner_match(x): match y: case Tup2(a, _): Tup2(a, x) -""", """# +""", + """# struct Tup2(a, b) inner_match = x -> Tup2(1, x) -""") +""" + ) - normSame("""# + normSame( + """# struct Tup2(a, b) enum Eith: L(left), R(right) @@ -258,7 +271,8 @@ def run(y): case R(b): Tup2(x, b) inner_match -""", """# +""", + """# struct Tup2(a, b) enum Eith: L(left), R(right) @@ -266,11 +280,13 @@ def run(y): match y: case L(a): x -> Tup2(a, x) case R(b): x -> Tup2(x, b) -""") +""" + ) } test("we can push lets into match") { - normSame("""# + normSame( + """# struct Tup2(a, b) enum Eith: L(left), R(right) @@ -279,7 +295,8 @@ def run(y, x): match x: L(_): z R(r): r -""", """# +""", + """# struct Tup2(a, b) enum Eith: L(left), R(right) @@ -287,35 +304,45 @@ def run(y, x): match x: L(_): y R(r): r -""") +""" + ) } test("we can evaluate constant matches") { - normSame("""# + normSame( + """# x = match 1: case (1 | 2) as x: x case _: -1 -""", """# +""", + """# x = 1 -""") +""" + ) - normSame("""# + normSame( + """# x = match 1: case _: -1 -""", """# +""", + """# x = -1 -""") +""" + ) - normSame("""# + normSame( + """# y = 21 def foo(_): match y: case 42: 0 case x: x -""", """# +""", + """# foo = _ -> 21 -""") +""" + ) /* * This does not yet work @@ -333,7 +360,7 @@ def foo(_): """, """# foo = _ -> 1 """) - */ + */ } val intTpe = Type.IntType @@ -347,53 +374,77 @@ foo = _ -> 1 def varTE(n: String, tpe: Type): TypedExpr[Unit] = TypedExpr.Local(Identifier.Name(n), tpe, ()) - def let(n: String, ex1: TypedExpr[Unit], ex2: TypedExpr[Unit]): TypedExpr[Unit] = + def let( + n: String, + ex1: TypedExpr[Unit], + ex2: TypedExpr[Unit] + ): TypedExpr[Unit] = TypedExpr.Let(Identifier.Name(n), ex1, ex2, RecursionKind.NonRecursive, ()) - def letrec(n: String, ex1: TypedExpr[Unit], ex2: TypedExpr[Unit]): TypedExpr[Unit] = + def letrec( + n: String, + ex1: TypedExpr[Unit], + ex2: TypedExpr[Unit] + ): TypedExpr[Unit] = TypedExpr.Let(Identifier.Name(n), ex1, ex2, RecursionKind.Recursive, ()) - def app(fn: TypedExpr[Unit], arg: TypedExpr[Unit], tpe: Type): TypedExpr[Unit] = + def app( + fn: TypedExpr[Unit], + arg: TypedExpr[Unit], + tpe: Type + ): TypedExpr[Unit] = TypedExpr.App(fn, NonEmptyList.one(arg), tpe, ()) - def lam(n: String, nt: Type, res: TypedExpr[Unit]): TypedExpr[Unit] = - TypedExpr.AnnotatedLambda(NonEmptyList.one((Identifier.Name(n), nt)), res, ()) + def lam(n: String, nt: Type, res: TypedExpr[Unit]): TypedExpr[Unit] = + TypedExpr.AnnotatedLambda( + NonEmptyList.one((Identifier.Name(n), nt)), + res, + () + ) test("test let substitution") { { // substitution in let val let1 = let("y", varTE("x", intTpe), varTE("y", intTpe)) - assert(TypedExpr.substitute(Identifier.Name("x"), int(2), let1) == - Some(let("y", int(2), varTE("y", intTpe)))) + assert( + TypedExpr.substitute(Identifier.Name("x"), int(2), let1) == + Some(let("y", int(2), varTE("y", intTpe))) + ) } { // substitution in let with a masking val let1 = let("y", varTE("x", intTpe), varTE("y", intTpe)) - assert(TypedExpr.substitute(Identifier.Name("x"), varTE("y", intTpe), let1) == - None) + assert( + TypedExpr.substitute(Identifier.Name("x"), varTE("y", intTpe), let1) == + None + ) } { // substitution in let with a shadowing in result val let1 = let("y", varTE("y", intTpe), varTE("y", intTpe)) - assert(TypedExpr.substitute(Identifier.Name("y"), int(42), let1) == - Some(let("y", int(42), varTE("y", intTpe)))) + assert( + TypedExpr.substitute(Identifier.Name("y"), int(42), let1) == + Some(let("y", int(42), varTE("y", intTpe))) + ) } { // substitution in letrec with a shadowing in bind and result val let1 = letrec("y", varTE("y", intTpe), varTE("y", intTpe)) - assert(TypedExpr.substitute(Identifier.Name("y"), int(42), let1) == - Some(let1)) + assert( + TypedExpr.substitute(Identifier.Name("y"), int(42), let1) == + Some(let1) + ) } } lazy val genNonFree: Gen[TypedExpr[Unit]] = - genTypedExpr.flatMap { te => - if (TypedExpr.freeVars(te :: Nil).isEmpty) Gen.const(te) - else genNonFree - } + genTypedExpr.flatMap { te => + if (TypedExpr.freeVars(te :: Nil).isEmpty) Gen.const(te) + else genNonFree + } test("after substitution, a variable is no longer free") { forAll(genTypedExpr, genNonFree) { (te0, te1) => @@ -418,7 +469,7 @@ foo = _ -> 1 lazy val nf: Gen[Bindable] = Generators.bindIdentGen.flatMap { case isfree if frees(isfree) => nf - case notfree => Gen.const(notfree) + case notfree => Gen.const(notfree) } nf @@ -430,28 +481,40 @@ foo = _ -> 1 } yield (nf, te) forAll(pair, genNonFree) { case ((b, te0), te1) => - TypedExpr.substitute(b, te1, te0) match { - case None => - // te1 has no free variables, this shouldn't fail - assert(false) + TypedExpr.substitute(b, te1, te0) match { + case None => + // te1 has no free variables, this shouldn't fail + assert(false) - case Some(te0sub) => assert(te0sub == te0) - } + case Some(te0sub) => assert(te0sub == te0) + } } } - test("let x = y in x == y") { // inline lets of vars - assert(TypedExprNormalization.normalize(let("x", varTE("y", intTpe), varTE("x", intTpe))) == - Some(varTE("y", intTpe))) + assert( + TypedExprNormalization.normalize( + let("x", varTE("y", intTpe), varTE("x", intTpe)) + ) == + Some(varTE("y", intTpe)) + ) } val normalLet = - let("x", varTE("y", intTpe), - let("y", app(varTE("z", intTpe), int(43), intTpe), - app(app(varTE("x", intTpe), varTE("y", intTpe), intTpe), - varTE("y", intTpe), intTpe))) + let( + "x", + varTE("y", intTpe), + let( + "y", + app(varTE("z", intTpe), int(43), intTpe), + app( + app(varTE("x", intTpe), varTE("y", intTpe), intTpe), + varTE("y", intTpe), + intTpe + ) + ) + ) test("we can't inline using a shadow: let x = y in let y = z in x(y, y)") { // we can't inline a shadow @@ -462,20 +525,37 @@ foo = _ -> 1 } test("if w doesn't have x free: (app (let x y z) w) == let x y (app z w)") { - assert(TypedExprNormalization.normalize(app(normalLet, varTE("w", intTpe), intTpe)) == - Some( - let("x", varTE("y", intTpe), - let("y", app(varTE("z", intTpe), int(43), intTpe), - app(app(app(varTE("x", intTpe), varTE("y", intTpe), intTpe), - varTE("y", intTpe), intTpe), - varTE("w", intTpe), intTpe))))) + assert( + TypedExprNormalization.normalize( + app(normalLet, varTE("w", intTpe), intTpe) + ) == + Some( + let( + "x", + varTE("y", intTpe), + let( + "y", + app(varTE("z", intTpe), int(43), intTpe), + app( + app( + app(varTE("x", intTpe), varTE("y", intTpe), intTpe), + varTE("y", intTpe), + intTpe + ), + varTE("w", intTpe), + intTpe + ) + ) + ) + ) + ) } test("x -> f(x) == f") { val f = varTE("f", Type.Fun(intTpe, intTpe)) val left = lam("x", intTpe, app(f, varTE("x", intTpe), intTpe)) - + assert(TypedExprNormalization.normalize(left) == Some(f)) checkLast(""" @@ -506,7 +586,8 @@ x = Foo val int2int = Type.Fun(intTpe, intTpe) val f = varTE("f", Type.Fun(intTpe, int2int)) val z = varTE("z", intTpe) - val lamf = lam("x", intTpe, app(app(f, varTE("x", intTpe), int2int), z, intTpe)) + val lamf = + lam("x", intTpe, app(app(f, varTE("x", intTpe), int2int), z, intTpe)) val y = varTE("y", intTpe) val left = app(lamf, y, intTpe) val right = app(app(f, y, int2int), z, intTpe) @@ -521,7 +602,6 @@ res = ( y -> (x -> f(x, z))(y) ) """) { te1 => - checkLast(""" res = _ -> 1 """) { te2 => @@ -534,7 +614,6 @@ f = (_, y) -> y z = 1 res = y -> (x -> f(x, z))(y) """) { te1 => - checkLast(""" res = _ -> 1 """) { te2 => @@ -555,7 +634,6 @@ res = ( Tup(x, x) ) """) { te1 => - checkLast(""" struct Tup(a, b) def f(x): x @@ -584,7 +662,7 @@ fn = ( ) ) """) { te1 => - checkLast(""" + checkLast(""" enum FooBar: Foo, Bar fn = (x: FooBar) -> x @@ -604,7 +682,7 @@ x = ( c ) """) { te1 => - checkLast(""" + checkLast(""" enum FooBar: Foo, Bar x = Foo @@ -628,7 +706,9 @@ x = Foo test("TypedExpr.substituteTypeVar of identity is identity") { forAll(genTypedExpr, Gen.listOf(NTypeGen.genBound)) { (te, bounds) => - val identMap: Map[Type.Var, Type] = bounds.map { b => (b, Type.TyVar(b)) }.toMap + val identMap: Map[Type.Var, Type] = bounds.map { b => + (b, Type.TyVar(b)) + }.toMap assert(TypedExpr.substituteTypeVar(te, identMap) == te) } } @@ -637,9 +717,12 @@ x = Foo forAll(genTypedExpr, Gen.listOf(NTypeGen.genBound)) { (te, bounds) => val tpes = te.allTypes val avoid = tpes.toSet | bounds.map(Type.TyVar(_)).toSet - val replacements = Type.allBinders.iterator.filterNot { t => avoid(Type.TyVar(t)) } + val replacements = Type.allBinders.iterator.filterNot { t => + avoid(Type.TyVar(t)) + } val identMap: Map[Type.Var, Type] = - bounds.iterator.zip(replacements) + bounds.iterator + .zip(replacements) .map { case (b, v) => (b, Type.TyVar(v)) } .toMap val te1 = TypedExpr.substituteTypeVar(te, identMap) @@ -651,26 +734,33 @@ x = Foo test("TypedExpr.substituteTypeVar is not an identity function") { // if we replace all the current types with some bound types, things won't be the same forAll(genTypedExpr) { te => - val tpes: Set[Type.Var] = te.allTypes.iterator.collect { case Type.TyVar(b) => b }.toSet + val tpes: Set[Type.Var] = te.allTypes.iterator.collect { + case Type.TyVar(b) => b + }.toSet - implicit def setM[A: Ordering]: cats.Monoid[SortedSet[A]] = new cats.Monoid[SortedSet[A]] { def empty = SortedSet.empty def combine(a: SortedSet[A], b: SortedSet[A]) = a ++ b - } + } // All the vars that are used in bounds - val bounds: Set[Type.Var] = te.traverseType { (t: Type) => - t match { - case q: Type.Quantified => Writer(SortedSet[Type.Var](q.vars.toList.map(_._1): _*), t) - case _ => Writer(SortedSet[Type.Var](), t) + val bounds: Set[Type.Var] = te + .traverseType { (t: Type) => + t match { + case q: Type.Quantified => + Writer(SortedSet[Type.Var](q.vars.toList.map(_._1): _*), t) + case _ => Writer(SortedSet[Type.Var](), t) + } } - }.run._1.toSet[Type.Var] + .run + ._1 + .toSet[Type.Var] val replacements = Type.allBinders.iterator.filterNot(tpes) val identMap: Map[Type.Var, Type] = - tpes.filterNot(bounds) + tpes + .filterNot(bounds) .iterator .zip(replacements) .map { case (b, v) => (b, Type.TyVar(v)) } @@ -689,7 +779,9 @@ x = Foo } } - def count[A](te: TypedExpr[A])(fn: PartialFunction[TypedExpr[A], Boolean]): Int = { + def count[A]( + te: TypedExpr[A] + )(fn: PartialFunction[TypedExpr[A], Boolean]): Int = { type W[B] = Writer[Int, B] val (count, _) = te.traverseUp[W] { inner => @@ -700,24 +792,25 @@ x = Foo count } - def countMatch[A](te: TypedExpr[A]) = count(te) { case TypedExpr.Match(_, _, _) => true } - def countLet[A](te: TypedExpr[A]) = count(te) { case TypedExpr.Let(_, _, _, _, _) => true } + def countMatch[A](te: TypedExpr[A]) = count(te) { + case TypedExpr.Match(_, _, _) => true + } + def countLet[A](te: TypedExpr[A]) = count(te) { + case TypedExpr.Let(_, _, _, _, _) => true + } test("test match removed from some examples") { - checkLast( - """ + checkLast(""" x = _ -> 1 """) { te => assert(countMatch(te) == 0) } - checkLast( - """ + checkLast(""" x = 10 y = match x: case z: z """) { te => assert(countMatch(te) == 0) } - checkLast( - """ + checkLast(""" x = 10 y = match x: case _: 20 @@ -726,15 +819,13 @@ y = match x: test("test let removed from some examples") { // this should turn into `y = 20` as the last expression - checkLast( - """ + checkLast(""" x = 10 y = match x: case _: 20 """) { te => assert(countLet(te) == 0) } - checkLast( - """ + checkLast(""" foo = ( x = 1 _ = x @@ -760,7 +851,7 @@ x = ( fn1(NE(1, NE(2, E))) ) """) { te1 => - checkLast(""" + checkLast(""" enum L[a]: E, NE(head: a, tail: L[a]) x = ( @@ -803,7 +894,9 @@ x = ( } def lawR[A, B](te: TypedExpr[B], a: A)(fn: (B, A) => A) = { - val viaFold = te.foldRight(cats.Eval.now(a)) { (b, r) => r.map { j => fn(b, j) } }.value + val viaFold = te + .foldRight(cats.Eval.now(a)) { (b, r) => r.map { j => fn(b, j) } } + .value val viaTraverse: State[A, Unit] = te.traverse_[State[A, *], Unit] { b => for { i <- State.get[A] @@ -815,7 +908,6 @@ x = ( assert(viaFold == viaTraverse.runS(a).value, s"${te.repr}") } - forAll(genTypedExprInt, Gen.choose(0, 1000)) { (te, init) => // make a commutative int function law(init, te) { (a, b) => (a + 1) * b } @@ -830,9 +922,11 @@ x = ( def law[A, B: Monoid](te: TypedExpr[A])(fn: A => B) = { val viaFold = te.foldMap(fn) - val viaTraverse: Const[B, Unit] = te.traverse[Const[B, *], Unit] { b => - Const[B, Unit](fn(b)) - }.void + val viaTraverse: Const[B, Unit] = te + .traverse[Const[B, *], Unit] { b => + Const[B, Unit](fn(b)) + } + .void assert(viaFold == viaTraverse.getConst, s"${te.repr}") } @@ -841,11 +935,12 @@ x = ( // non-commutative forAll(genTypedExprChar, arbitrary[Char => String])(law(_)(_)) - val lamconst: TypedExpr[String] = + val lamconst: TypedExpr[String] = TypedExpr.AnnotatedLambda( NonEmptyList.one((Identifier.Name("x"), intTpe)), int(1).as("a"), - "b") + "b" + ) assert(lamconst.foldMap(identity) == "ab") assert(lamconst.traverse { a => Const[String, Unit](a) }.getConst == "ab") @@ -854,15 +949,17 @@ x = ( test("TypedExpr.traverse.void matches traverse_") { import cats.data.Const forAll(genTypedExprInt, arbitrary[Int => String]) { (te, fn) => - assert(te.traverse { i => Const[String, Unit](fn(i)) }.void == - te.traverse_ { i => Const[String, Unit](fn(i)) }) + assert( + te.traverse { i => Const[String, Unit](fn(i)) }.void == + te.traverse_ { i => Const[String, Unit](fn(i)) } + ) } } test("TypedExpr.foldRight matches foldRight for commutative funs") { forAll(genTypedExprInt, Gen.choose(0, 1000)) { (te, init) => - - val right = te.foldRight(cats.Eval.now(init)) { (i, ej) => ej.map(_ + i) }.value + val right = + te.foldRight(cats.Eval.now(init)) { (i, ej) => ej.map(_ + i) }.value val left = te.foldLeft(init)(_ + _) assert(right == left) } @@ -870,8 +967,11 @@ x = ( test("TypedExpr.foldRight matches foldRight for non-commutative funs") { forAll(genTypedExprInt) { te => - - val right = te.foldRight(cats.Eval.now("")) { (i, ej) => ej.map { j => i.toString + j } }.value + val right = te + .foldRight(cats.Eval.now("")) { (i, ej) => + ej.map { j => i.toString + j } + } + .value val left = te.foldLeft("") { (i, j) => i + j.toString } assert(right == left) } @@ -885,9 +985,14 @@ x = ( test("freeTyVars is a superset of the frees in the outer type") { forAll(genTypedExpr) { te => - assert(Type.freeTyVars(te.getType :: Nil).toSet.subsetOf( - te.freeTyVars.toSet - )) + assert( + Type + .freeTyVars(te.getType :: Nil) + .toSet + .subsetOf( + te.freeTyVars.toSet + ) + ) } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/ValueTest.scala b/core/src/test/scala/org/bykn/bosatsu/ValueTest.scala index b022037d7..b3898c3e3 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ValueTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ValueTest.scala @@ -1,7 +1,10 @@ package org.bykn.bosatsu import org.scalacheck.{Arbitrary, Gen} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import Value._ import org.scalatest.funsuite.AnyFunSuite @@ -9,7 +12,7 @@ class ValueTest extends AnyFunSuite { import GenValue.genValue implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) test("SumValue.toString is what we expect") { @@ -29,7 +32,7 @@ class ValueTest extends AnyFunSuite { forAll(genValue) { v => VOption.some(v) match { case VOption(Some(v1)) => assert(v1 == v) - case other => fail(s"expected Some($v) got $other") + case other => fail(s"expected Some($v) got $other") } } @@ -50,7 +53,7 @@ class ValueTest extends AnyFunSuite { forAll(Gen.listOf(genValue)) { vs => VList(vs) match { case VList(vs1) => assert(vs1 == vs) - case other => fail(s"expected VList($vs) got $other") + case other => fail(s"expected VList($vs) got $other") } } diff --git a/core/src/test/scala/org/bykn/bosatsu/ValueToDocTest.scala b/core/src/test/scala/org/bykn/bosatsu/ValueToDocTest.scala index a9d25ee70..4fd7ab235 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ValueToDocTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ValueToDocTest.scala @@ -1,7 +1,10 @@ package org.bykn.bosatsu import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import rankn.{NTypeGen, Type, TypeEnv} import TestUtils.typeEnvOf @@ -10,7 +13,9 @@ import org.scalatest.funsuite.AnyFunSuite class ValueToDocTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 1000 else 20) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 1000 else 20 + ) test("never throw when converting to doc") { val tegen = Generators.typeEnvGen(PackageName.parts("Foo"), Gen.const(())) @@ -19,12 +24,14 @@ class ValueToDocTest extends AnyFunSuite { tegen.flatMap { te => val tyconsts = te.allDefinedTypes.map(_.toTypeConst) - val theseTypes = NTypeGen.genDepth(4, if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts))) + val theseTypes = NTypeGen.genDepth( + 4, + if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts)) + ) theseTypes.map((te, _)) } - forAll(withType, GenValue.genValue) { case ((te, t), v) => val vd = ValueToDoc(te.toDefinedType(_)) vd.toDoc(t)(v) @@ -33,7 +40,9 @@ class ValueToDocTest extends AnyFunSuite { } test("some hand written cases round trip") { - val te = typeEnvOf(PackageName.parts("Test"), """ + val te = typeEnvOf( + PackageName.parts("Test"), + """ struct MyUnit # wrappers are removed @@ -43,19 +52,21 @@ struct MyPair(fst, snd) enum MyEither: L(left), R(right) enum MyNat: Z, S(prev: MyNat) -""") +""" + ) val conv = ValueToDoc(te.toDefinedType(_)) def stringToType(t: String): Type = { val tr = Parser.unsafeParse(TypeRef.parser, t) TypeRefConverter[cats.Id](tr) { cons => - te.referencedPackages.toList.flatMap { pack => - val const = Type.Const.Defined(pack, TypeName(cons)) - te.toDefinedType(const).map(_ => const) - } - .headOption - .getOrElse(Type.Const.predef(cons.asString)) + te.referencedPackages.toList + .flatMap { pack => + val const = Type.Const.Defined(pack, TypeName(cons)) + te.toDefinedType(const).map(_ => const) + } + .headOption + .getOrElse(Type.Const.predef(cons.asString)) } } @@ -65,7 +76,7 @@ enum MyNat: Z, S(prev: MyNat) toDoc(v) match { case Right(doc) => assert(doc.render(80) == str) - case Left(err) => fail(s"could not handle to Value: $tpe, $v, $err") + case Left(err) => fail(s"could not handle to Value: $tpe, $v, $err") } } @@ -73,10 +84,25 @@ enum MyNat: Z, S(prev: MyNat) law("String", Value.Str("hello world"), "'hello world'") law("MyUnit", Value.UnitValue, "MyUnit") law("MyWrapper[MyUnit]", Value.UnitValue, "MyWrapper { item: MyUnit }") - law("MyWrapper[MyWrapper[MyUnit]]", Value.UnitValue, "MyWrapper { item: MyWrapper { item: MyUnit } }") - law("MyPair[MyUnit, MyUnit]", Value.ProductValue.fromList(List(Value.UnitValue, Value.UnitValue)), - "MyPair { fst: MyUnit, snd: MyUnit }") - law("MyEither[MyUnit, MyUnit]", Value.SumValue(0, Value.ProductValue.fromList(List(Value.UnitValue))), "L { left: MyUnit }") - law("MyEither[MyUnit, MyUnit]", Value.SumValue(1, Value.ProductValue.fromList(List(Value.UnitValue))), "R { right: MyUnit }") + law( + "MyWrapper[MyWrapper[MyUnit]]", + Value.UnitValue, + "MyWrapper { item: MyWrapper { item: MyUnit } }" + ) + law( + "MyPair[MyUnit, MyUnit]", + Value.ProductValue.fromList(List(Value.UnitValue, Value.UnitValue)), + "MyPair { fst: MyUnit, snd: MyUnit }" + ) + law( + "MyEither[MyUnit, MyUnit]", + Value.SumValue(0, Value.ProductValue.fromList(List(Value.UnitValue))), + "L { left: MyUnit }" + ) + law( + "MyEither[MyUnit, MyUnit]", + Value.SumValue(1, Value.ProductValue.fromList(List(Value.UnitValue))), + "R { right: MyUnit }" + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/VarianceTest.scala b/core/src/test/scala/org/bykn/bosatsu/VarianceTest.scala index df8529dc1..34afddbb8 100644 --- a/core/src/test/scala/org/bykn/bosatsu/VarianceTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/VarianceTest.scala @@ -9,7 +9,8 @@ object VarianceGen { Variance.Phantom, Variance.Contravariant, Variance.Covariant, - Variance.Invariant) + Variance.Invariant + ) implicit val arbVar: Arbitrary[Variance] = Arbitrary(gen) } @@ -27,7 +28,9 @@ class VarianceTest extends AnyFunSuite { test("variance combine is associative") { forAll { (v1: Variance, v2: Variance, v3: Variance) => - assert(V.combine(v1, V.combine(v2, v3)) == V.combine(V.combine(v1, v2), v3)) + assert( + V.combine(v1, V.combine(v2, v3)) == V.combine(V.combine(v1, v2), v3) + ) } } @@ -45,7 +48,7 @@ class VarianceTest extends AnyFunSuite { test("variance is distributive") { forAll { (v1: Variance, v2: Variance, v3: Variance) => - val left = v1 * (v2 + v3) + val left = v1 * (v2 + v3) val right = (v1 * v2) + (v1 * v3) assert(left == right, s"$left != $right") } @@ -56,7 +59,7 @@ class VarianceTest extends AnyFunSuite { val v2 = Variance.phantom val v3 = Variance.co - val left = v1 * (v2 + v3) + val left = v1 * (v2 + v3) val right = (v1 * v2) + (v1 * v3) assert(left == right, s"$left != $right") } @@ -112,7 +115,12 @@ class VarianceTest extends AnyFunSuite { } test("covariant combines to get either covariant or invariant") { - assert(V.combine(Variance.Covariant, Variance.Contravariant) == Variance.Invariant) + assert( + V.combine( + Variance.Covariant, + Variance.Contravariant + ) == Variance.Invariant + ) val results = Set(Variance.co, Variance.in) forAll { (v1: Variance) => assert(results(V.combine(v1, Variance.Covariant))) @@ -120,7 +128,12 @@ class VarianceTest extends AnyFunSuite { } test("contravariant combines to get either contravariant or invariant") { - assert(V.combine(Variance.Covariant, Variance.Contravariant) == Variance.Invariant) + assert( + V.combine( + Variance.Covariant, + Variance.Contravariant + ) == Variance.Invariant + ) val results = Set(Variance.contra, Variance.in) forAll { (v1: Variance) => assert(results(V.combine(v1, Variance.Contravariant))) diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala index 7ef56c142..a00b62edd 100644 --- a/core/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala @@ -2,14 +2,17 @@ package org.bykn.bosatsu.codegen.python import org.bykn.bosatsu.Identifier.{Bindable, unsafeBindable} import org.bykn.bosatsu.Generators.bindIdentGen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class PythonGenTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 50000) + // PropertyCheckConfiguration(minSuccessful = 50000) PropertyCheckConfiguration(minSuccessful = 5000) - //PropertyCheckConfiguration(minSuccessful = 500) + // PropertyCheckConfiguration(minSuccessful = 500) test("PythonGen.escape round trips") { @@ -17,16 +20,14 @@ class PythonGenTest extends AnyFunSuite { val ident = PythonGen.escape(b) PythonGen.unescape(ident) match { case Some(b1) => assert(b1.asString == b.asString) - case None => assert(false, s"$b => $ident could not round trip") + case None => assert(false, s"$b => $ident could not round trip") } } forAll(bindIdentGen)(law(_)) val examples: List[Bindable] = - List( - "`12 =_=`", - "`N`").map(unsafeBindable) + List("`12 =_=`", "`N`").map(unsafeBindable) examples.foreach(law(_)) @@ -38,7 +39,10 @@ class PythonGenTest extends AnyFunSuite { forAll(bindIdentGen) { b => val str = PythonGen.escape(b).name - assert(PythonName.matcher(str).matches(), s"escaped: ${b.sourceCodeRepr} to $str") + assert( + PythonName.matcher(str).matches(), + s"escaped: ${b.sourceCodeRepr} to $str" + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/graph/ToposortTest.scala b/core/src/test/scala/org/bykn/bosatsu/graph/ToposortTest.scala index b23f6bdf6..0bbcd96b9 100644 --- a/core/src/test/scala/org/bykn/bosatsu/graph/ToposortTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/graph/ToposortTest.scala @@ -3,14 +3,17 @@ package org.bykn.bosatsu.graph import cats.Order import cats.data.NonEmptyList import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import cats.implicits._ import org.scalatest.funsuite.AnyFunSuite class ToposortTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 1000) test("toposort can recover full sort") { @@ -25,7 +28,11 @@ class ToposortTest extends AnyFunSuite { assert(res.isSuccess) assert(res.isFailure == res.loopNodes.nonEmpty) assert(res.toSuccess == Some(res.layers)) - assert(res.layers == items.toVector.sorted(Order[A].toOrdering).map(NonEmptyList(_, Nil))) + assert( + res.layers == items.toVector + .sorted(Order[A].toOrdering) + .map(NonEmptyList(_, Nil)) + ) assert(res.layersAreTotalOrder) } @@ -43,7 +50,10 @@ class ToposortTest extends AnyFunSuite { val nset = fn(n).toSet if (nset.nonEmpty) { (id until layers.size).foreach { id1 => - assert(layers(id1).filter(nset).isEmpty, s"node $n in layer $id has points to later layers: $id1") + assert( + layers(id1).filter(nset).isEmpty, + s"node $n in layer $id has points to later layers: $id1" + ) } } } @@ -59,13 +69,16 @@ class ToposortTest extends AnyFunSuite { val nid = Gen.choose(0, 100) val pair = for { n <- nid - neighbor <- Gen.listOf(nid).map(_.filter(_ < n).distinct) // make sure it is a dag + neighbor <- Gen + .listOf(nid) + .map(_.filter(_ < n).distinct) // make sure it is a dag } yield (n, neighbor) val genDag = Gen.mapOf(pair).map(Dag(_)) forAll(genDag) { case Dag(graph) => val allNodes = graph.flatMap { case (h, t) => h :: t }.toSet - val Toposort.Success(sorted, _) = Toposort.sort(allNodes)(graph.getOrElse(_, Nil)) + val Toposort.Success(sorted, _) = + Toposort.sort(allNodes)(graph.getOrElse(_, Nil)) assert(sorted.flatMap(_.toList).sorted == allNodes.toList.sorted) noEdgesToLater(sorted)(n => graph.getOrElse(n, Nil)) layersAreSorted(sorted) @@ -87,7 +100,9 @@ class ToposortTest extends AnyFunSuite { layersAreSorted(layers) // all the nodes is the same set: val goodNodes = layers.flatMap(_.toList) - assert((goodNodes.toList ::: res.loopNodes).sorted == allNodes.toList.sorted) + assert( + (goodNodes.toList ::: res.loopNodes).sorted == allNodes.toList.sorted + ) // good nodes are distinct assert(goodNodes == goodNodes.distinct) // loop nodes are distinct @@ -103,7 +118,16 @@ class ToposortTest extends AnyFunSuite { } test("we return the least node with a loop") { - assert(Toposort.sort(List(1, 2))(Function.const(List(1, 2))) == Toposort.Failure(List(1, 2), Vector.empty)) - assert(Toposort.sort(List("bb", "aa"))(Function.const(List("aa", "bb"))) == Toposort.Failure(List("aa", "bb"), Vector.empty)) + assert( + Toposort.sort(List(1, 2))(Function.const(List(1, 2))) == Toposort.Failure( + List(1, 2), + Vector.empty + ) + ) + assert( + Toposort.sort(List("bb", "aa"))( + Function.const(List("aa", "bb")) + ) == Toposort.Failure(List("aa", "bb"), Vector.empty) + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/graph/TreeTest.scala b/core/src/test/scala/org/bykn/bosatsu/graph/TreeTest.scala index f077064cc..2b80c71de 100644 --- a/core/src/test/scala/org/bykn/bosatsu/graph/TreeTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/graph/TreeTest.scala @@ -10,26 +10,27 @@ import org.scalatest.funsuite.AnyFunSuite class TreeTest extends AnyFunSuite { test("explicit dags never fail") { - val dagFn: Gen[Int => List[Int]] = + val dagFn: Gen[Int => List[Int]] = Gen.choose(1L, Long.MaxValue).map { seed => - val rng = new java.util.Random(seed) val cache = scala.collection.mutable.Map[Int, List[Int]]() { (node: Int) => // the expected number of neighbors is 1.5, that means the graph is expected to be finite - cache.getOrElseUpdate(node, { - val count = rng.nextInt(3) - (node + 1 until (node + count + 1)).toList.filter(_ > node) - }) + cache.getOrElseUpdate( + node, { + val count = rng.nextInt(3) + (node + 1 until (node + count + 1)).toList.filter(_ > node) + } + ) } } forAll(Gen.choose(0, Int.MaxValue), dagFn) { (start, nfn) => Tree.dagToTree(start)(nfn) match { - case v@Validated.Valid(tree) => + case v @ Validated.Valid(tree) => // the neightbor function should give the same tree: val treeFn = Tree.neighborsFn(tree) val tree2 = Tree.dagToTree(tree.item)(treeFn) @@ -42,9 +43,10 @@ class TreeTest extends AnyFunSuite { } test("circular graphs are invalid") { - val prime = Gen.oneOf(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89) + val prime = Gen.oneOf(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, + 47, 53, 59, 61, 67, 71, 73, 79, 83, 89) - val dagFn: Gen[(Int, Int => List[Int])] = + val dagFn: Gen[(Int, Int => List[Int])] = for { p <- prime p1 = p - 1 @@ -54,10 +56,13 @@ class TreeTest extends AnyFunSuite { b <- nodeGen } yield { - (init, { (node: Int) => - // only 1 neighbor, but this is in a cyclic group so it can't be a dag - List((node * a + b) % p) - }) + ( + init, + { (node: Int) => + // only 1 neighbor, but this is in a cyclic group so it can't be a dag + List((node * a + b) % p) + } + ) } forAll(dagFn) { case (start, nfn) => @@ -97,7 +102,8 @@ class TreeTest extends AnyFunSuite { NonEmptyList.fromList(l1.filterNot(nel0.toList.toSet)) match { case None => succeed case Some(diffs) => - val got = Tree.distinctBy(nel0)(identity) ::: Tree.distinctBy(diffs)(identity) + val got = + Tree.distinctBy(nel0)(identity) ::: Tree.distinctBy(diffs)(identity) val expected = Tree.distinctBy(nel0 ::: diffs)(identity) assert(got == expected) } diff --git a/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala b/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala index 6acfe090e..7bd63bbeb 100644 --- a/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu.pattern import org.bykn.bosatsu.set.{Rel, SetOps} import org.scalacheck.{Arbitrary, Gen, Shrink} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import SeqPattern.{Cat, Empty} import SeqPart.{Wildcard, AnyElem, Lit} @@ -29,7 +32,8 @@ object StringSeqPatternGen { Gen.frequency( (15, Gen.oneOf(Lit('0'), Lit('1'))), (2, Gen.const(AnyElem)), - (1, Gen.const(Wildcard))) + (1, Gen.const(Wildcard)) + ) } val genPat: Gen[SeqPattern[Char]] = { @@ -39,9 +43,7 @@ object StringSeqPatternGen { t <- genPat } yield Cat(h, t) - Gen.frequency( - (1, Gen.const(Empty)), - (5, cat)) + Gen.frequency((1, Gen.const(Empty)), (5, cat)) } implicit val arbPattern: Arbitrary[SeqPattern[Char]] = Arbitrary(genPat) @@ -55,7 +57,10 @@ object StringSeqPatternGen { t #:: shrinkPat.shrink(t) } - def genNamedFn[A](gp: Gen[SeqPart[A]], nextId: Int): Gen[(Int, NamedSeqPattern[A])] = { + def genNamedFn[A]( + gp: Gen[SeqPart[A]], + nextId: Int + ): Gen[(Int, NamedSeqPattern[A])] = { lazy val recur = Gen.lzy(res) lazy val genNm: Gen[(Int, Named.Bind[A])] = @@ -69,13 +74,17 @@ object StringSeqPatternGen { // L = (4/9) / (1 - (2/3 + 1/9)) = 4 / (9 - 5) = 1 lazy val res: Gen[(Int, NamedSeqPattern[A])] = Gen.frequency( - (3, for { - (i0, n0) <- recur - (i1, n1) <- genNamedFn(gp, i0) - } yield (i1, NamedSeqPattern.NCat(n0, n1))), + ( + 3, + for { + (i0, n0) <- recur + (i1, n1) <- genNamedFn(gp, i0) + } yield (i1, NamedSeqPattern.NCat(n0, n1)) + ), (1, genNm), (1, Gen.const((nextId, NamedSeqPattern.NEmpty))), - (4, gp.map { p => (nextId, NamedSeqPattern.NSeqPart(p)) })) + (4, gp.map { p => (nextId, NamedSeqPattern.NSeqPart(p)) }) + ) res } @@ -86,7 +95,9 @@ object StringSeqPatternGen { implicit val arbNamed: Arbitrary[NamedSeqPattern[Char]] = Arbitrary(genNamed) def interleave[A](s1: Stream[A], s2: Stream[A]): Stream[A] = - if (s1.isEmpty) s2 else if (s2.isEmpty) s1 else { + if (s1.isEmpty) s2 + else if (s2.isEmpty) s1 + else { s1.head #:: interleave(s2, s1.tail) } @@ -103,32 +114,34 @@ object StringSeqPatternGen { val sp = p match { case Wildcard => AnyElem #:: tail - case AnyElem => tail - case Lit(_) => Stream.Empty + case AnyElem => tail + case Lit(_) => Stream.Empty } sp.map(NamedSeqPattern.NSeqPart(_)) case NamedSeqPattern.NCat(fst, snd) => val s1 = shrinkNamedSeqPattern.shrink(fst) val s2 = shrinkNamedSeqPattern.shrink(snd) - interleave(s1, s2).iterator.sliding(2).map { - case Seq(a, b) => NamedSeqPattern.NCat(a, b) - case _ => NamedSeqPattern.NEmpty - } - .toStream + interleave(s1, s2).iterator + .sliding(2) + .map { + case Seq(a, b) => NamedSeqPattern.NCat(a, b) + case _ => NamedSeqPattern.NEmpty + } + .toStream } def unany[A](p: SeqPattern[A]): SeqPattern[A] = p match { - case Empty => Empty + case Empty => Empty case Cat(AnyElem, t) => unany(t) - case Cat(h, t) => Cat(h, unany(t)) + case Cat(h, t) => Cat(h, unany(t)) } def unwild[A](p: SeqPattern[A]): SeqPattern[A] = p match { - case Empty => Empty + case Empty => Empty case Cat(Wildcard, t) => unwild(t) - case Cat(h, t) => Cat(h, unwild(t)) + case Cat(h, t) => Cat(h, unwild(t)) } implicit val arbString: Arbitrary[String] = Arbitrary(genBitString) @@ -150,9 +163,9 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { val Named = NamedSeqPattern implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 50000) + // PropertyCheckConfiguration(minSuccessful = 50000) PropertyCheckConfiguration(minSuccessful = 5000) - //PropertyCheckConfiguration(minSuccessful = 50) + // PropertyCheckConfiguration(minSuccessful = 50) def genPattern: Gen[Pattern] def genNamed: Gen[Named] @@ -175,8 +188,8 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { val together = intersect.exists(matches(_, x)) assert(together == sep, s"n1: $n1, n2: $n2, intersection: $intersect") - //if (together != sep) sys.error(s"n1: $n1, n2: $n2, intersection: $intersect") - //else succeed + // if (together != sep) sys.error(s"n1: $n1, n2: $n2, intersection: $intersect") + // else succeed } def namedMatchesPatternLaw(n: Named, str: S) = { @@ -209,7 +222,7 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { if (p2.matchesAny) assert(diff == Nil) // the law we wish we had: - //if (p2.matches(s) && p1.matches(s)) assert(!diffmatch) + // if (p2.matches(s) && p1.matches(s)) assert(!diffmatch) } test("reverse is idempotent") { @@ -221,8 +234,8 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { test("cat.reverse == cat.reverseCat") { forAll(genPattern) { p => p match { - case Empty => assert(p.reverse == Empty) - case c@Cat(_, _) => assert(c.reverseCat == c.reverse) + case Empty => assert(p.reverse == Empty) + case c @ Cat(_, _) => assert(c.reverseCat == c.reverse) } } } @@ -230,16 +243,19 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { test("reverse matches the reverse string") { forAll(genPattern, genSeq) { (p: Pattern, str: S) => val rstr = splitter.fromList(splitter.toList(str).reverse) - assert(matches(p, str) == matches(p.reverse, rstr), s"p.reverse = ${p.reverse}") + assert( + matches(p, str) == matches(p.reverse, rstr), + s"p.reverse = ${p.reverse}" + ) } } test("unlit patterns match everything") { def unlit(p: Pattern): Pattern = p match { - case Empty => Empty + case Empty => Empty case Cat(Lit(_) | AnyElem, t) => unlit(t) - case Cat(h, t) => Cat(h, unlit(t)) + case Cat(h, t) => Cat(h, unlit(t)) } forAll(genPattern, genSeq) { (p0, str) => @@ -259,7 +275,7 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { forAll(genPattern) { p0 => val list = p0.normalize.toList list.sliding(2).foreach { - case bad@Seq(Wildcard, Wildcard) => + case bad @ Seq(Wildcard, Wildcard) => fail(s"saw adjacent: $bad in ${p0.normalize}") case _ => () } @@ -319,13 +335,17 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { } test("if subset(a, b) then matching a implies matching b") { - forAll(genPattern, genPattern, genSeq)(subsetConsistentWithMatchLaw(_, _, _)) + forAll(genPattern, genPattern, genSeq)( + subsetConsistentWithMatchLaw(_, _, _) + ) } def diffUBRegressions: List[(Pattern, Pattern, S)] = Nil test("difference is an upper bound") { - forAll(genPattern, genPattern, genSeq) { case (p1, p2, s) => differenceUBLaw(p1, p2, s) } + forAll(genPattern, genPattern, genSeq) { case (p1, p2, s) => + differenceUBLaw(p1, p2, s) + } diffUBRegressions.foreach { case (p1, p2, s) => differenceUBLaw(p1, p2, s) @@ -333,10 +353,11 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { } test("p + q match (s + t) if p.matches(s) && q.matches(t)") { - forAll(genPattern, genPattern, genSeq, genSeq) { (p: Pattern, q: Pattern, s: S, t: S) => - if (matches(p, s) && matches(q, t)) { - assert(matches(p + q, splitter.catSeqs(s :: t :: Nil))) - } + forAll(genPattern, genPattern, genSeq, genSeq) { + (p: Pattern, q: Pattern, s: S, t: S) => + if (matches(p, s) && matches(q, t)) { + assert(matches(p + q, splitter.catSeqs(s :: t :: Nil))) + } } } @@ -346,8 +367,16 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { test("* - [] - [_, *] == empty") { val diff1 = setOps.difference(Cat(Wildcard, Empty), Empty) - assert(diff1.flatMap(setOps.difference(_, Cat(AnyElem, Cat(Wildcard, Empty)))) == Nil) - assert(diff1.flatMap(setOps.difference(_, Cat(Wildcard, Cat(AnyElem, Empty)))) == Nil) + assert( + diff1.flatMap( + setOps.difference(_, Cat(AnyElem, Cat(Wildcard, Empty))) + ) == Nil + ) + assert( + diff1.flatMap( + setOps.difference(_, Cat(Wildcard, Cat(AnyElem, Empty))) + ) == Nil + ) } test("[*, _] n [*, _, *] commutes and is [_, *]") { @@ -371,24 +400,25 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { forAll(genPattern, genPattern) { (x: Pattern, y: Pattern) => val i1 = setOps.intersection(x, y) val i2 = setOps.intersection(y, x) - assert(i1 == i2, - s"${i1.map(_.show)} != ${i2.map(_.show)}") + assert(i1 == i2, s"${i1.map(_.show)} != ${i2.map(_.show)}") } } test("if x - y is empty, (x + z) - (y + z) is empty") { - forAll(genPattern, genPattern, genPattern) { (x0: Pattern, y0: Pattern, z0: Pattern) => - val x = Pattern.fromList(x0.toList.take(3)) - val y = Pattern.fromList(y0.toList.take(3)) - val z = Pattern.fromList(z0.toList.take(3)) - if (setOps.difference(x, y).isEmpty) { - assert(setOps.differenceAll(x :: z :: Nil, y :: z :: Nil) == Nil) - } + forAll(genPattern, genPattern, genPattern) { + (x0: Pattern, y0: Pattern, z0: Pattern) => + val x = Pattern.fromList(x0.toList.take(3)) + val y = Pattern.fromList(y0.toList.take(3)) + val z = Pattern.fromList(z0.toList.take(3)) + if (setOps.difference(x, y).isEmpty) { + assert(setOps.differenceAll(x :: z :: Nil, y :: z :: Nil) == Nil) + } } } } -class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Boolean], Unit] { +class BoolSeqPatternTest + extends SeqPatternLaws[Set[Boolean], Boolean, List[Boolean], Unit] { implicit override val generatorDrivenConfig: PropertyCheckConfiguration = // these tests wind up running very long sometimes @@ -400,7 +430,10 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool case Cat(Wildcard, t) => (Cat(AnyElem, t) #:: t #:: Stream.empty).flatMap(shrinkPat.shrink) case Cat(AnyElem, t) => - (Cat(Lit(Set(false)), t) #:: Cat(Lit(Set(true)), t) #:: t #:: Stream.empty).flatMap(shrinkPat.shrink) + (Cat(Lit(Set(false)), t) #:: Cat( + Lit(Set(true)), + t + ) #:: t #:: Stream.empty).flatMap(shrinkPat.shrink) case Cat(_, t) => t #:: shrinkPat.shrink(t) } @@ -418,7 +451,8 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool Gen.frequency( (15, genSetBool.map(Lit(_))), (2, Gen.const(AnyElem)), - (1, Gen.const(Wildcard))) + (1, Gen.const(Wildcard)) + ) } val genNamed: Gen[NamedSeqPattern[Set[Boolean]]] = @@ -428,29 +462,42 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool val sp = Gen.frequency( (1, SeqPart.Wildcard), (2, SeqPart.AnyElem), - (8, genSetBool.map(SeqPart.Lit(_)))) + (8, genSetBool.map(SeqPart.Lit(_))) + ) Gen.frequency( (1, Gen.const(SeqPattern.Empty)), - (4, Gen.zip(sp, Gen.lzy(genPattern)).map { case (h, t) => SeqPattern.Cat(h, t) }) + ( + 4, + Gen.zip(sp, Gen.lzy(genPattern)).map { case (h, t) => + SeqPattern.Cat(h, t) + } + ) ) } def genSeq: Gen[List[Boolean]] = Gen.choose(0, 20).flatMap(Gen.listOfN(_, genBool)) - val splitter = Splitter.listSplitter(Matcher.fnMatch[Boolean]: Matcher[Set[Boolean], Boolean, Unit]) + val splitter = Splitter.listSplitter( + Matcher.fnMatch[Boolean]: Matcher[Set[Boolean], Boolean, Unit] + ) val pmatcher = SeqPattern.matcher(splitter) - def matches(p: SeqPattern[Set[Boolean]], s: List[Boolean]): Boolean = pmatcher(p)(s).isDefined - def namedMatches(p: NamedSeqPattern[Set[Boolean]], s: List[Boolean]): Boolean = + def matches(p: SeqPattern[Set[Boolean]], s: List[Boolean]): Boolean = + pmatcher(p)(s).isDefined + def namedMatches( + p: NamedSeqPattern[Set[Boolean]], + s: List[Boolean] + ): Boolean = NamedSeqPattern.matcher(splitter)(p)(s).isDefined - implicit val setOpsBool: SetOps[Set[Boolean]] = SetOps.fromFinite(List(true, false)) + implicit val setOpsBool: SetOps[Set[Boolean]] = + SetOps.fromFinite(List(true, false)) implicit val ordSet: Ordering[Set[Boolean]] = Ordering[List[Boolean]].on { (s: Set[Boolean]) => s.toList.sorted } - val setOps: SetOps[SeqPattern[Set[Boolean]]] = SeqPattern.seqPatternSetOps[Set[Boolean]] - + val setOps: SetOps[SeqPattern[Set[Boolean]]] = + SeqPattern.seqPatternSetOps[Set[Boolean]] // we can sometimes enumerate a finite LazyList of matches def enumerate(sp: SeqPattern[Set[Boolean]]): Option[LazyList[List[Boolean]]] = @@ -462,7 +509,7 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool val rests = loop(rest) val heads = hsp match { case Lit(s) if s.size == 1 => s.head :: Nil - case _ => + case _ => // we assume any because there // are no wilds List(false, true) @@ -479,8 +526,14 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool override def diffUBRegressions: List[(Pattern, Pattern, List[Boolean])] = List({ - val p1 = Cat(Wildcard,Empty) - val p2 = Cat(Lit(Set(true)),Cat(Wildcard,Cat(Lit(Set(true, false)),Cat(Lit(Set(true)),Cat(Wildcard, Empty))))) + val p1 = Cat(Wildcard, Empty) + val p2 = Cat( + Lit(Set(true)), + Cat( + Wildcard, + Cat(Lit(Set(true, false)), Cat(Lit(Set(true)), Cat(Wildcard, Empty))) + ) + ) val s = List(true, false, false) (p1, p2, s) @@ -493,20 +546,21 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool SeqPattern.fromList(Wildcard :: Lit(Set(true)) :: Wildcard :: Nil), SeqPattern.fromList(Lit(Set(false)) :: Wildcard :: Nil), SeqPattern.fromList(Nil) - )) + ) + ) assert(missing == Nil) } test("a n b == b n a regressions (bool)") { - //val a = Cat(AnyElem, Cat(Wildcard, Cat(Lit(Set(true, false)), Empty))) - //val b = Cat(AnyElem, Cat(Wildcard, Cat(AnyElem, Empty))) - val a = Cat(AnyElem, Cat(Wildcard, Cat(Lit(Set(true, false)), Empty))) - val b = Cat(AnyElem, Cat(Wildcard, Cat(AnyElem, Empty))) + // val a = Cat(AnyElem, Cat(Wildcard, Cat(Lit(Set(true, false)), Empty))) + // val b = Cat(AnyElem, Cat(Wildcard, Cat(AnyElem, Empty))) + val a = Cat(AnyElem, Cat(Wildcard, Cat(Lit(Set(true, false)), Empty))) + val b = Cat(AnyElem, Cat(Wildcard, Cat(AnyElem, Empty))) - val ab = setOps.intersection(a, b) - val ba = setOps.intersection(b, a) - assert(ab == ba) + val ab = setOps.intersection(a, b) + val ba = setOps.intersection(b, a) + assert(ab == ba) } test("if we can enumerate p1 and p1 - p2 == 0, then all match p2") { @@ -523,8 +577,7 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool assert(matches(p1, s), s"p1: $s") assert(matches(p2, s), s"p2: $s") } - } - else { + } else { // difference is an upper-bound // so truediff <= diff // if ms does not match diff, then it must @@ -585,26 +638,82 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool val subsets: List[(Pattern, Pattern, Boolean)] = List( { - val p0 = Cat(Wildcard,Cat(Lit(Set(false)),Cat(Lit(Set(true, false)),Cat(Lit(Set(true, false)),Empty)))) - val p1 = Cat(Wildcard,Cat(Lit(Set(true, false)),Empty)) + val p0 = Cat( + Wildcard, + Cat( + Lit(Set(false)), + Cat(Lit(Set(true, false)), Cat(Lit(Set(true, false)), Empty)) + ) + ) + val p1 = Cat(Wildcard, Cat(Lit(Set(true, false)), Empty)) (p0, p1, true) } ) - subsets.foreach { case (p1, p2, res) => assert(setOps.subset(p1, p2) == res) } + subsets.foreach { case (p1, p2, res) => + assert(setOps.subset(p1, p2) == res) + } val regressions: List[(Pattern, Pattern, List[List[Boolean]])] = List( { - val p0 = Cat(Lit(Set(false)),Cat(Lit(Set(false)),Cat(Lit(Set(false)),Cat(Lit(Set(true)),Cat(AnyElem,Cat(Lit(Set(false)),Cat(Lit(Set(false)),Cat(Lit(Set(false)),Cat(Lit(Set(true, false)),Cat(Lit(Set(true)),Empty)))))))))) - val p1 = Cat(Wildcard,Cat(Lit(Set(true, false)),Cat(Lit(Set(true)),Cat(Lit(Set(false)),Cat(Lit(Set(true, false)),Cat(Wildcard,Cat(Lit(Set(true, false)),Cat(Lit(Set(true, false)),Empty)))))))) + val p0 = Cat( + Lit(Set(false)), + Cat( + Lit(Set(false)), + Cat( + Lit(Set(false)), + Cat( + Lit(Set(true)), + Cat( + AnyElem, + Cat( + Lit(Set(false)), + Cat( + Lit(Set(false)), + Cat( + Lit(Set(false)), + Cat(Lit(Set(true, false)), Cat(Lit(Set(true)), Empty)) + ) + ) + ) + ) + ) + ) + ) + ) + val p1 = Cat( + Wildcard, + Cat( + Lit(Set(true, false)), + Cat( + Lit(Set(true)), + Cat( + Lit(Set(false)), + Cat( + Lit(Set(true, false)), + Cat( + Wildcard, + Cat( + Lit(Set(true, false)), + Cat(Lit(Set(true, false)), Empty) + ) + ) + ) + ) + ) + ) + ) val matchp0 = Nil (p0, p1, matchp0) - }, - { - val p0 = Cat(Lit(Set(true)),Cat(AnyElem,Cat(Lit(Set(false)),Empty))) - val p1 = Cat(Wildcard,Cat(Lit(Set(true)),Cat(Lit(Set(false)),Cat(Wildcard,Empty)))) + }, { + val p0 = + Cat(Lit(Set(true)), Cat(AnyElem, Cat(Lit(Set(false)), Empty))) + val p1 = Cat( + Wildcard, + Cat(Lit(Set(true)), Cat(Lit(Set(false)), Cat(Wildcard, Empty))) + ) val matchp0 = Nil (p0, p1, matchp0) @@ -616,11 +725,23 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool } test("test some missing branches") { - assert(setOps.missingBranches(Cat(Wildcard, Empty) :: Nil, Pattern.fromList(List(Wildcard, AnyElem, Wildcard)) :: Nil) == - Pattern.fromList(Nil) :: Nil) + assert( + setOps.missingBranches( + Cat(Wildcard, Empty) :: Nil, + Pattern.fromList(List(Wildcard, AnyElem, Wildcard)) :: Nil + ) == + Pattern.fromList(Nil) :: Nil + ) - assert(setOps.missingBranches(Cat(Wildcard, Empty) :: Nil, Pattern.fromList(List(Wildcard, Lit(Set(true)), Wildcard)) :: Nil) == - Pattern.fromList(Nil) :: Pattern.fromList(Lit(Set(false)) :: Wildcard :: Nil) :: Nil) + assert( + setOps.missingBranches( + Cat(Wildcard, Empty) :: Nil, + Pattern.fromList(List(Wildcard, Lit(Set(true)), Wildcard)) :: Nil + ) == + Pattern.fromList(Nil) :: Pattern.fromList( + Lit(Set(false)) :: Wildcard :: Nil + ) :: Nil + ) } } @@ -653,26 +774,33 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { Pattern.Cat(Lit(c), r) } - - override def diffUBRegressions: List[(SeqPattern[Char], SeqPattern[Char], String)] = + override def diffUBRegressions + : List[(SeqPattern[Char], SeqPattern[Char], String)] = List({ - val p1 = Cat(AnyElem,Cat(Wildcard,Empty)) - val p2 = Cat(Wildcard,Cat(AnyElem,Cat(Lit('0'),Cat(Lit('1'),Cat(Wildcard, Empty))))) + val p1 = Cat(AnyElem, Cat(Wildcard, Empty)) + val p2 = Cat( + Wildcard, + Cat(AnyElem, Cat(Lit('0'), Cat(Lit('1'), Cat(Wildcard, Empty)))) + ) (p1, p2, "11") }) test("some matching examples") { - val ms: List[(Pattern, String)] = + val ms: List[(Pattern, String)] = (Pattern.Wild + Pattern.Any + Pattern.Any + toPattern("1"), "111") :: - (toPattern("1") + Pattern.Any + toPattern("1"), "111") :: - Nil + (toPattern("1") + Pattern.Any + toPattern("1"), "111") :: + Nil ms.foreach { case (p, s) => assert(matches(p, s), s"matches($p, $s)") } } test("wildcard on either side is the same as contains") { forAll { (ps: String, s: String) => - assert(matches(Pattern.Wild + toPattern(ps) + Pattern.Wild, s) == s.contains(ps)) + assert( + matches(Pattern.Wild + toPattern(ps) + Pattern.Wild, s) == s.contains( + ps + ) + ) } } test("wildcard on front side is the same as endsWith") { @@ -686,18 +814,38 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { } } - test("intersection(p1, p2).matches(x) == p1.matches(x) && p2.matches(x) regressions") { + test( + "intersection(p1, p2).matches(x) == p1.matches(x) && p2.matches(x) regressions" + ) { val regressions: List[(Pattern, Pattern, String)] = - (toPattern("0") + Pattern.Any + Pattern.Wild, Pattern.Any + toPattern("01") + Pattern.Any, "001") :: - (Pattern.Wild + Pattern.Any + Pattern.Any + toPattern("1"), toPattern("1") + Pattern.Any + toPattern("1"), "111") :: - Nil + ( + toPattern("0") + Pattern.Any + Pattern.Wild, + Pattern.Any + toPattern("01") + Pattern.Any, + "001" + ) :: + ( + Pattern.Wild + Pattern.Any + Pattern.Any + toPattern("1"), + toPattern("1") + Pattern.Any + toPattern("1"), + "111" + ) :: + Nil regressions.foreach { case (p1, p2, s) => intersectionMatchLaw(p1, p2, s) } } test("subset is consistent with match regressions") { - assert(setOps.subset(toPattern("00") + Pattern.Wild, toPattern("0") + Pattern.Wild)) - assert(setOps.subset(toPattern("00") + Pattern.Any + Pattern.Wild, toPattern("0") + Pattern.Any + Pattern.Wild)) + assert( + setOps.subset( + toPattern("00") + Pattern.Wild, + toPattern("0") + Pattern.Wild + ) + ) + assert( + setOps.subset( + toPattern("00") + Pattern.Any + Pattern.Wild, + toPattern("0") + Pattern.Any + Pattern.Wild + ) + ) } test("if y - x is empty, (yz - xz) for all strings is empty") { @@ -714,13 +862,14 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { } test("if y - x is empty, (zy - zx) for all strings is empty") { - forAll(genPattern, genPattern, genSeq) { (x0: Pattern, y0: Pattern, str: String) => - val x = Pattern.fromList(x0.toList.take(5)) - val y = Pattern.fromList(y0.toList.take(5)) - if (setOps.difference(y, x) == Nil) { - val left = toPattern(str) + y - assert(setOps.difference(left, toPattern(str) + x) == Nil) - } + forAll(genPattern, genPattern, genSeq) { + (x0: Pattern, y0: Pattern, str: String) => + val x = Pattern.fromList(x0.toList.take(5)) + val y = Pattern.fromList(y0.toList.take(5)) + if (setOps.difference(y, x) == Nil) { + val left = toPattern(str) + y + assert(setOps.difference(left, toPattern(str) + x) == Nil) + } } } @@ -729,7 +878,7 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { namedMatch(n, str).foreach { m => n.render(m)(_.toString) match { case Some(s0) => assert(s0 == str, s"m = $m") - case None => + case None => // this can only happen if we have unnamed Wild/AnyElem assert(n.isRenderable == false, s"m = $m") } @@ -742,8 +891,11 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { import SeqPart.Wildcard val regressions: List[(Named, String)] = - (NCat(Bind("0",NSeqPart(Wildcard)),Bind("1",NSeqPart(Wildcard))), "") :: - Nil + ( + NCat(Bind("0", NSeqPart(Wildcard)), Bind("1", NSeqPart(Wildcard))), + "" + ) :: + Nil regressions.foreach { case (n, s) => law(n, s) } } @@ -755,30 +907,35 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { import Named._ val regressions: List[(Named, String)] = - (NCat(NEmpty,NCat(NSeqPart(Lit('1')),NSeqPart(Wildcard))), "1") :: - (NCat(NSeqPart(Lit('1')),NSeqPart(Wildcard)), "1") :: - (NSeqPart(Lit('1')), "1") :: - Nil + (NCat(NEmpty, NCat(NSeqPart(Lit('1')), NSeqPart(Wildcard))), "1") :: + (NCat(NSeqPart(Lit('1')), NSeqPart(Wildcard)), "1") :: + (NSeqPart(Lit('1')), "1") :: + Nil regressions.foreach { case (n, s) => namedMatchesPatternLaw(n, s) } } test("Test some examples of Named matching") { // foo@("bar" baz@(*"baz")) - val p1 = (named("bar") + (Named.Wild + named("baz")).name("baz")).name("foo") - assert(namedMatch(p1, "bar and baz") == Some(Map("foo" -> "bar and baz", "baz" -> " and baz"))) + val p1 = + (named("bar") + (Named.Wild + named("baz")).name("baz")).name("foo") + assert( + namedMatch(p1, "bar and baz") == Some( + Map("foo" -> "bar and baz", "baz" -> " and baz") + ) + ) } test("regression subset example") { - val p1 = Cat(AnyElem,Cat(Wildcard,Empty)) - val p2 = Cat(Wildcard,Cat(Lit('0'),Cat(Lit('1'),Empty))) + val p1 = Cat(AnyElem, Cat(Wildcard, Empty)) + val p2 = Cat(Wildcard, Cat(Lit('0'), Cat(Lit('1'), Empty))) assert(setOps.subset(p2, p1)) } test("intersection regression") { - val p1 = Cat(Wildcard,Cat(Lit('0'),Cat(Lit('1'),Empty))) - val p2 = Cat(Lit('0'),Cat(Lit('0'),Cat(Lit('0'),Cat(Wildcard,Empty)))) + val p1 = Cat(Wildcard, Cat(Lit('0'), Cat(Lit('1'), Empty))) + val p2 = Cat(Lit('0'), Cat(Lit('0'), Cat(Lit('0'), Cat(Wildcard, Empty)))) assert(setOps.relate(p2, p1) == Rel.Intersects) assert(setOps.intersection(p1, p2).nonEmpty) diff --git a/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala b/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala index 7ca77cf3d..e78ec2093 100644 --- a/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala +++ b/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala @@ -29,27 +29,32 @@ class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { def matches(p: Pattern, s: String): Boolean = pmatcher(p)(s).isDefined def eqUnion: Gen[Eq[List[Pattern]]] = - Gen.listOfN(5000, Gen.frequency( - 10 -> StringSeqPatternGen.genBitString, - 1 -> Gen.listOf(Gen.oneOf(List('0', '1', '2'))).map(_.mkString) - )).map { tests => - // we have to generate more than just 01 strings, - // since Any can match more than that - new Eq[List[Pattern]] { - // this can flake because if two things are different, - // but happen to have the same match results for this - // set of items, then you are hosed - def eqv(a: List[Pattern], b: List[Pattern]) = - (a, b) match { - case (ah :: Nil, bh :: Nil) => - setOps.equiv(ah, bh) - case _ => - (a.toSet == b.toSet) || tests.forall { s => - a.exists(matches(_, s)) == b.exists(matches(_, s)) - } - } + Gen + .listOfN( + 5000, + Gen.frequency( + 10 -> StringSeqPatternGen.genBitString, + 1 -> Gen.listOf(Gen.oneOf(List('0', '1', '2'))).map(_.mkString) + ) + ) + .map { tests => + // we have to generate more than just 01 strings, + // since Any can match more than that + new Eq[List[Pattern]] { + // this can flake because if two things are different, + // but happen to have the same match results for this + // set of items, then you are hosed + def eqv(a: List[Pattern], b: List[Pattern]) = + (a, b) match { + case (ah :: Nil, bh :: Nil) => + setOps.equiv(ah, bh) + case _ => + (a.toSet == b.toSet) || tests.forall { s => + a.exists(matches(_, s)) == b.exists(matches(_, s)) + } + } + } } - } implicit val setOpsChar: SetOps[Char] = SetOps.distinct[Char] val setOps: SetOps[Pattern] = Pattern.seqPatternSetOps[Char] @@ -59,22 +64,31 @@ class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { import SeqPart.Lit val regressions: List[(SeqPattern[Char], SeqPattern[Char])] = - (Cat(Lit('1'),Cat(Lit('1'),Cat(Lit('1'),Cat(Lit('1'),Empty)))), - Cat(Lit('0'),Cat(Lit('1'),Cat(Lit('1'),Empty)))) :: - (Cat(Lit('1'),Cat(Lit('0'),Cat(Lit('1'),Cat(Lit('0'),Empty)))), - Cat(Lit('0'),Cat(Lit('1'),Empty))) :: - Nil - - regressions.foreach { case (a, b) => subsetConsistencyLaw(a, b, Eq.fromUniversalEquals) } + ( + Cat(Lit('1'), Cat(Lit('1'), Cat(Lit('1'), Cat(Lit('1'), Empty)))), + Cat(Lit('0'), Cat(Lit('1'), Cat(Lit('1'), Empty))) + ) :: + ( + Cat(Lit('1'), Cat(Lit('0'), Cat(Lit('1'), Cat(Lit('0'), Empty)))), + Cat(Lit('0'), Cat(Lit('1'), Empty)) + ) :: + Nil + + regressions.foreach { case (a, b) => + subsetConsistencyLaw(a, b, Eq.fromUniversalEquals) + } } test("*x* problems") { import SeqPattern.{Cat, Empty} import SeqPart.{Lit, Wildcard} - val x = Cat(Wildcard,Cat(Lit('q'),Cat(Wildcard,Cat(Lit('p'),Cat(Wildcard,Empty))))) - val y = Cat(Wildcard,Cat(Lit('p'),Cat(Wildcard,Empty))) - val z = Cat(Wildcard,Cat(Lit('q'),Cat(Wildcard,Empty))) + val x = Cat( + Wildcard, + Cat(Lit('q'), Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty)))) + ) + val y = Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty))) + val z = Cat(Wildcard, Cat(Lit('q'), Cat(Wildcard, Empty))) // note y and z are clearly bigger than x because they are prefix/suffix that end/start with // Wildcard assert(setOps.difference(x, y).isEmpty) @@ -82,34 +96,41 @@ class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { } test("(a - b) n c = (a n c) - (b n c) regressions") { - val regressions: List[(SeqPattern[Char], SeqPattern[Char], SeqPattern[Char])] = - (Cat(Wildcard, Empty), - Cat(AnyElem,Cat(Lit('1'),Cat(AnyElem,Empty))), - Cat(AnyElem,Cat(Lit('1'),Cat(Lit('0'),Empty)))) :: - (Cat(Wildcard,Cat(Lit('0'),Empty)), - Cat(AnyElem,Cat(Lit('1'),Cat(AnyElem,Cat(Lit('0'),Empty)))), - Cat(AnyElem,Cat(Lit('1'),Cat(Lit('0'),Cat(Lit('0'),Empty))))) :: - (Cat(Wildcard, Cat(Lit('q'), Cat(Wildcard, Empty))), + val regressions + : List[(SeqPattern[Char], SeqPattern[Char], SeqPattern[Char])] = + ( Cat(Wildcard, Empty), - Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty)))) :: - /* - * This fails currently - * see: https://github.com/johnynek/bosatsu/issues/486 + Cat(AnyElem, Cat(Lit('1'), Cat(AnyElem, Empty))), + Cat(AnyElem, Cat(Lit('1'), Cat(Lit('0'), Empty))) + ) :: + ( + Cat(Wildcard, Cat(Lit('0'), Empty)), + Cat(AnyElem, Cat(Lit('1'), Cat(AnyElem, Cat(Lit('0'), Empty)))), + Cat(AnyElem, Cat(Lit('1'), Cat(Lit('0'), Cat(Lit('0'), Empty)))) + ) :: + ( + Cat(Wildcard, Cat(Lit('q'), Cat(Wildcard, Empty))), + Cat(Wildcard, Empty), + Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty))) + ) :: + /* + * This fails currently + * see: https://github.com/johnynek/bosatsu/issues/486 { val p1 = Cat(Wildcard,Cat(Lit('1'),Cat(Lit('0'),Cat(Lit('0'),Empty)))) val p2 = Cat(AnyElem,Cat(Lit('1'),Cat(Wildcard,Cat(Lit('0'),Empty)))) val p3 = Cat(Lit('1'),Cat(Lit('1'),Cat(Wildcard,Cat(Lit('0'),Empty)))) (p1, p2, p3) } :: - */ - Nil + */ + Nil regressions.foreach { case (a, b, c) => diffIntersectionLaw(a, b, c) } } test("intersection regression") { - val p1 = Cat(Wildcard,Cat(Lit('0'),Cat(Lit('1'),Empty))) - val p2 = Cat(Lit('0'),Cat(Lit('0'),Cat(Lit('0'),Cat(Wildcard,Empty)))) + val p1 = Cat(Wildcard, Cat(Lit('0'), Cat(Lit('1'), Empty))) + val p2 = Cat(Lit('0'), Cat(Lit('0'), Cat(Lit('0'), Cat(Wildcard, Empty)))) assert(setOps.relate(p1, p2) == Rel.Intersects) assert(setOps.relate(p2, p1) == Rel.Intersects) diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala index cdb685eaa..ba88fa8c8 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala @@ -24,16 +24,30 @@ object NTypeGen { consIdentGen.map(TypeName(_)) val keyWords = Set( - "if", "ffi", "match", "struct", "enum", "else", "elif", - "def", "external", "package", "import", "export", "forall", "exists", - "recur", "recursive") + "if", + "ffi", + "match", + "struct", + "enum", + "else", + "elif", + "def", + "external", + "package", + "import", + "export", + "forall", + "exists", + "recur", + "recursive" + ) val lowerIdent: Gen[String] = (for { c <- lower cnt <- Gen.choose(0, 10) rest <- Gen.listOfN(cnt, identC) - } yield (c :: rest).mkString).filter { s=> !keyWords(s) } + } yield (c :: rest).mkString).filter { s => !keyWords(s) } val packageNameGen: Gen[PackageName] = for { @@ -45,7 +59,8 @@ object NTypeGen { } yield PackageName(NonEmptyList(h, tail)) val genConst: Gen[Type.Const] = - Gen.zip(packageNameGen, typeNameGen) + Gen + .zip(packageNameGen, typeNameGen) .map { case (p, n) => Type.Const.Defined(p, n) } val genBound: Gen[Type.Var.Bound] = @@ -97,7 +112,9 @@ object NTypeGen { shrink(in).map(Type.exists(items.tail, _)) case _: Leaf => Stream.empty case TyApply(on, arg) => - on #:: arg #:: shrink(on).collect { case r: Type.Rho => TyApply(r, arg) } #::: shrink(arg).map(TyApply(on, _)) + on #:: arg #:: shrink(on).collect { case r: Type.Rho => + TyApply(r, arg) + } #::: shrink(arg).map(TyApply(on, _)) } Shrink(shrink(_)) } @@ -127,21 +144,31 @@ object NTypeGen { // either Unit, Tuple2(a, b) Gen.oneOf( Gen.const(UnitType), - Gen.zip(recurse, recTup).map { case (h, t) => Type.TyApply(Type.TyApply(Tuple.Arity(2), h), t) }) + Gen.zip(recurse, recTup).map { case (h, t) => + Type.TyApply(Type.TyApply(Tuple.Arity(2), h), t) + } + ) } Gen.frequency( (6, Gen.oneOf(t0)), - (2, for { - cons <- Gen.oneOf(t1) - param <- recurse - } yield TyApply(cons, param)), + ( + 2, + for { + cons <- Gen.oneOf(t1) + param <- recurse + } yield TyApply(cons, param) + ), (1, tupleTypes), - (1, for { - cons <- Gen.oneOf(t2) - param1 <- recurse - param2 <- recurse - } yield TyApply(TyApply(cons, param1), param2))) + ( + 1, + for { + cons <- Gen.oneOf(t2) + param1 <- recurse + param2 <- recurse + } yield TyApply(TyApply(cons, param1), param2) + ) + ) } val genQuantArgs: Gen[List[(Type.Var.Bound, Kind)]] = @@ -152,25 +179,26 @@ object NTypeGen { } yield as lazy val genQuant: Gen[Type.Quantification] = - Gen.zip(genQuantArgs, genQuantArgs) + Gen + .zip(genQuantArgs, genQuantArgs) .flatMap { case (fa, ex0) => val faSet = fa.map(_._1).toSet val ex = ex0.filterNot { case (b, _) => faSet(b) } Type.Quantification.fromLists(fa, ex) match { case Some(q) => Gen.const(q) - case None => genQuant - } + case None => genQuant + } } - def genTypeRho(d: Int, genC: Option[Gen[Type.Const]]): Gen[Type.Rho] = { val root = genRootType(genC) if (d <= 0) root else { val recurse = Gen.lzy(genTypeRho(d - 1, genC)) - val genApply = Gen.zip(recurse, genDepth(d - 1, genC)) + val genApply = Gen + .zip(recurse, genDepth(d - 1, genC)) .map { case (a, b) => Type.TyApply(a, b) } - + Gen.frequency((3, root), (1, genApply)) } } @@ -196,16 +224,20 @@ object NTypeGen { } yield Type.exists(as, in) val genQ = Gen.zip(NTypeGen.genQuant, recurse).map { case (q, t) => - Type.quantify(q, t) + Type.quantify(q, t) } - val genApply = Gen.zip(genTypeRho(d - 1, genC), recurse).map { case (a, b) => Type.TyApply(a, b) } + val genApply = Gen.zip(genTypeRho(d - 1, genC), recurse).map { + case (a, b) => Type.TyApply(a, b) + } Gen.frequency( (2, recurse), (1, genApply), - (1, Gen.oneOf(genForAll, genExists, genQ))) + (1, Gen.oneOf(genForAll, genExists, genQ)) + ) } - val genDepth03: Gen[Type] = Gen.choose(0, 3).flatMap(genDepth(_, Some(genConst))) + val genDepth03: Gen[Type] = + Gen.choose(0, 3).flatMap(genDepth(_, Some(genConst))) } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala index 231b58373..10a0b0594 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala @@ -19,17 +19,20 @@ class RankNInferTest extends AnyFunSuite { val emptyRegion: Region = Region(0, 0) - implicit val unitRegion: HasRegion[Unit] = HasRegion.instance(_ => emptyRegion) + implicit val unitRegion: HasRegion[Unit] = + HasRegion.instance(_ => emptyRegion) private def strToConst(str: Identifier.Constructor): Type.Const = str.asString match { - case "Int" => Type.Const.predef("Int") + case "Int" => Type.Const.predef("Int") case "String" => Type.Const.predef("String") - case "List" => Type.Const.predef("List") - case _ => Type.Const.Defined(testPackage, TypeName(str)) + case "List" => Type.Const.predef("List") + case _ => Type.Const.Defined(testPackage, TypeName(str)) } - def asFullyQualified(ns: Iterable[(Identifier, Type)]): Map[Infer.Name, Type] = + def asFullyQualified( + ns: Iterable[(Identifier, Type)] + ): Map[Infer.Name, Type] = ns.iterator.map { case (n, t) => ((Some(testPackage), n), t) }.toMap def typeFrom(str: String): Type = { @@ -56,11 +59,13 @@ class RankNInferTest extends AnyFunSuite { val t1 = typeFrom(left) val t2 = typeFrom(right) - val res1 = Infer.substitutionCheck(t1, t2, emptyRegion, emptyRegion) + val res1 = Infer + .substitutionCheck(t1, t2, emptyRegion, emptyRegion) .runFully(Map.empty, Map.empty, Type.builtInKinds) assert(res1.isRight, s"$left is not :<: $right\n\n$res1") - val res2 = Infer.substitutionCheck(t2, t1, emptyRegion, emptyRegion) + val res2 = Infer + .substitutionCheck(t2, t1, emptyRegion, emptyRegion) .runFully(Map.empty, Map.empty, Type.builtInKinds) assert(res2.isLeft, s"$left is unexpectedly = $right\n\n$res2") } @@ -69,9 +74,11 @@ class RankNInferTest extends AnyFunSuite { val t1 = typeFrom(left) val t2 = typeFrom(right) - val res1 = Infer.substitutionCheck(t1, t2, emptyRegion, emptyRegion) + val res1 = Infer + .substitutionCheck(t1, t2, emptyRegion, emptyRegion) .runFully(Map.empty, Map.empty, Type.builtInKinds) - val res2 = Infer.substitutionCheck(t2, t1, emptyRegion, emptyRegion) + val res2 = Infer + .substitutionCheck(t2, t1, emptyRegion, emptyRegion) .runFully(Map.empty, Map.empty, Type.builtInKinds) assert(res1.isLeft, s"$left is unexpectedly :<: $right\n\n$res1") @@ -86,31 +93,53 @@ class RankNInferTest extends AnyFunSuite { val withBools: Map[Infer.Name, Type] = Map( - (Some(PackageName.PredefName), Identifier.unsafe("True")) -> Type.BoolType, - (Some(PackageName.PredefName), Identifier.unsafe("False")) -> Type.BoolType) + ( + Some(PackageName.PredefName), + Identifier.unsafe("True") + ) -> Type.BoolType, + ( + Some(PackageName.PredefName), + Identifier.unsafe("False") + ) -> Type.BoolType + ) val boolTypes: Map[(PackageName, Constructor), Infer.Cons] = Map( - ((PackageName.PredefName, Constructor("True")), (Nil, Nil, Type.Const.predef("Bool"))), - ((PackageName.PredefName, Constructor("False")), (Nil, Nil, Type.Const.predef("Bool")))) + ( + (PackageName.PredefName, Constructor("True")), + (Nil, Nil, Type.Const.predef("Bool")) + ), + ( + (PackageName.PredefName, Constructor("False")), + (Nil, Nil, Type.Const.predef("Bool")) + ) + ) def testType[A: HasRegion](term: Expr[A], ty: Type) = Infer.typeCheck(term).runFully(withBools, boolTypes, Map.empty) match { - case Left(err) => assert(false, err) + case Left(err) => assert(false, err) case Right(tpe) => assert(tpe.getType.sameAs(ty), term.toString) } def testLetTypes[A: HasRegion](terms: List[(String, Expr[A], Type)]) = - Infer.typeCheckLets(testPackage, terms.map { case (k, v, _) => (Identifier.Name(k), RecursionKind.NonRecursive, v) }) + Infer + .typeCheckLets( + testPackage, + terms.map { case (k, v, _) => + (Identifier.Name(k), RecursionKind.NonRecursive, v) + } + ) .runFully(withBools, boolTypes, Type.builtInKinds) match { - case Left(err) => assert(false, err) - case Right(tpes) => - assert(tpes.size == terms.size) - terms.zip(tpes).foreach { case ((n, exp, expt), (n1, _, te)) => - assert(n == n1.asString, s"the name changed: $n != $n1") - assert(te.getType == expt, s"$n = $exp failed to typecheck to $expt, got ${te.getType}") - } - } - + case Left(err) => assert(false, err) + case Right(tpes) => + assert(tpes.size == terms.size) + terms.zip(tpes).foreach { case ((n, exp, expt), (n1, _, te)) => + assert(n == n1.asString, s"the name changed: $n != $n1") + assert( + te.getType == expt, + s"$n = $exp failed to typecheck to $expt, got ${te.getType}" + ) + } + } def lit(i: Int): Expr[Unit] = Literal(Lit(i.toLong), ()) def lit(b: Boolean): Expr[Unit] = @@ -122,36 +151,41 @@ class RankNInferTest extends AnyFunSuite { Lambda(NonEmptyList.one((Identifier.Name(arg), None)), result, ()) def v(name: String): Expr[Unit] = Identifier.unsafe(name) match { - case c@Identifier.Constructor(_) => Global(testPackage, c, ()) - case b: Identifier.Bindable => Local(b, ()) + case c @ Identifier.Constructor(_) => Global(testPackage, c, ()) + case b: Identifier.Bindable => Local(b, ()) } def ann(expr: Expr[Unit], t: Type): Expr[Unit] = Annotation(expr, t, ()) - def app(fn: Expr[Unit], arg: Expr[Unit]): Expr[Unit] = App(fn, NonEmptyList.one(arg), ()) + def app(fn: Expr[Unit], arg: Expr[Unit]): Expr[Unit] = + App(fn, NonEmptyList.one(arg), ()) def alam(arg: String, tpe: Type, res: Expr[Unit]): Expr[Unit] = Lambda(NonEmptyList.one((Identifier.Name(arg), Some(tpe))), res, ()) - def ife(cond: Expr[Unit], ift: Expr[Unit], iff: Expr[Unit]): Expr[Unit] = Expr.ifExpr(cond, ift, iff, ()) - def matche(arg: Expr[Unit], branches: NonEmptyList[(Pattern[String, Type], Expr[Unit])]): Expr[Unit] = - Match(arg, + def ife(cond: Expr[Unit], ift: Expr[Unit], iff: Expr[Unit]): Expr[Unit] = + Expr.ifExpr(cond, ift, iff, ()) + def matche( + arg: Expr[Unit], + branches: NonEmptyList[(Pattern[String, Type], Expr[Unit])] + ): Expr[Unit] = + Match( + arg, branches.map { case (p, e) => val p1 = p.mapName { n => (testPackage, Constructor(n)) } (p1, e) }, - ()) + () + ) - /** - * Check that a no import program has a given type - */ + /** Check that a no import program has a given type + */ def parseProgram(statement: String, tpe: String) = checkLast(statement) { te0 => - val te = te0 // TypedExprNormalization.normalize(te0).getOrElse(te0) te.traverseType[cats.Id] { - case t@Type.TyVar(Type.Var.Skolem(_, _, _, _)) => + case t @ Type.TyVar(Type.Var.Skolem(_, _, _, _)) => fail(s"illegate skolem ($t) escape in $te") t - case t@Type.TyMeta(_) => + case t @ Type.TyMeta(_) => fail(s"illegate meta ($t) escape in $te") t case good => @@ -161,10 +195,12 @@ class RankNInferTest extends AnyFunSuite { val rendered = te.repr val tp = te.getType lazy val teStr = Type.fullyResolvedDocument.document(tp).render(80) - assert(Type.freeTyVars(tp :: Nil).isEmpty, s"illegal inferred type: $teStr, in: $rendered") + assert( + Type.freeTyVars(tp :: Nil).isEmpty, + s"illegal inferred type: $teStr, in: $rendered" + ) - assert(Type.metaTvs(tp :: Nil).isEmpty, - s"illegal inferred type: $teStr") + assert(Type.metaTvs(tp :: Nil).isEmpty, s"illegal inferred type: $teStr") assert(te.getType.sameAs(typeFrom(tpe)), s"found: ${te.repr}") } @@ -172,17 +208,20 @@ class RankNInferTest extends AnyFunSuite { def checkTERepr(statement: String, repr: String) = checkLast(statement) { te => assert(te.repr == repr) } - /** - * Test that a program is ill-typed - */ + /** Test that a program is ill-typed + */ def parseProgramIllTyped(statement: String) = { val stmts = Parser.unsafeParse(Statement.parser, statement) Package.inferBody(testPackage, Nil, stmts) match { case Ior.Left(_) | Ior.Both(_, _) => assert(true) case Ior.Right(program) => - fail("expected an invalid program, but got:\n\n" + program.lets.map { case (b, r, t) => - s"$b: $r = ${t.repr}" - }.mkString("\n\n")) + fail( + "expected an invalid program, but got:\n\n" + program.lets + .map { case (b, r, t) => + s"$b: $r = ${t.repr}" + } + .mkString("\n\n") + ) } } @@ -213,19 +252,32 @@ class RankNInferTest extends AnyFunSuite { assert_:<:("forall a. a -> Int", "(forall a. a) -> Int") assertTypesUnify( "((forall a. a) -> Int) -> Int", - "forall a. (a -> Int) -> Int") + "forall a. (a -> Int) -> Int" + ) assert_:<:("List[forall a. a -> Int]", "List[(forall a. a) -> Int]") - assertTypesUnify("forall f: +* -> *. f[forall a. a]", "forall a. forall f: +* -> *. f[a]") + assertTypesUnify( + "forall f: +* -> *. f[forall a. a]", + "forall a. forall f: +* -> *. f[a]" + ) assert_:<:("forall f: * -> *. f[Int]", "forall f: +* -> *. f[Int]") assert_:<:("forall f: * -> *. f[Int]", "forall f: -* -> *. f[Int]") assert_:<:("forall f: +* -> *. f[Int]", "forall f: 👻* -> *. f[Int]") assert_:<:("forall f: -* -> *. f[Int]", "forall f: 👻* -> *. f[Int]") - assert_:<:("forall a. forall f: * -> *. f[a]", "forall f: * -> *. f[forall a. a]") - assert_:<:("forall a. forall f: -* -> *. f[a]", "forall f: -* -> *. f[forall a. a]") + assert_:<:( + "forall a. forall f: * -> *. f[a]", + "forall f: * -> *. f[forall a. a]" + ) + assert_:<:( + "forall a. forall f: -* -> *. f[a]", + "forall f: -* -> *. f[forall a. a]" + ) assertTypesUnify("(forall a. a) -> Int", "(forall a. a) -> Int") - assertTypesUnify("(forall a. a -> Int) -> Int", "(forall a. a -> Int) -> Int") + assertTypesUnify( + "(forall a. a -> Int) -> Int", + "(forall a. a -> Int) -> Int" + ) assert_:<:("forall a, b. a -> b -> b", "forall a. a -> a -> a") assert_:<:("forall a, b. a -> b", "forall b, c. b -> (c -> Int)") assert_:<:("forall a, f: * -> *. f[a]", "forall x. List[x]") @@ -236,31 +288,44 @@ class RankNInferTest extends AnyFunSuite { assertTypesDisjoint("Int -> Unit", "String") assertTypesDisjoint("Int -> Unit", "String -> a") assertTypesUnify("forall a. Int", "Int") - + // Test unbound vars assertTypesDisjoint("a", "Int") assertTypesDisjoint("Int", "a") - assert_:<:( "forall f: * -> *, a, b. (f[a], a -> f[b]) -> f[b]", - "forall f: +* -> *, a, b. (f[a], a -> f[b]) -> f[b]") + "forall f: +* -> *, a, b. (f[a], a -> f[b]) -> f[b]" + ) } test("Basic inferences") { testType(lit(100), Type.IntType) testType(let("x", lambda("y", v("y")), lit(100)), Type.IntType) - testType(lambda("y", v("y")), - forAll(NonEmptyList.of(b("a")), - Type.Fun(Type.TyVar(Bound("a")),Type.TyVar(Bound("a"))))) - testType(lambda("y", lambda("z", v("y"))), - forAll(NonEmptyList.of(b("a"), b("b")), - Type.Fun(Type.TyVar(Bound("a")), - Type.Fun(Type.TyVar(Bound("b")),Type.TyVar(Bound("a")))))) + testType( + lambda("y", v("y")), + forAll( + NonEmptyList.of(b("a")), + Type.Fun(Type.TyVar(Bound("a")), Type.TyVar(Bound("a"))) + ) + ) + testType( + lambda("y", lambda("z", v("y"))), + forAll( + NonEmptyList.of(b("a"), b("b")), + Type.Fun( + Type.TyVar(Bound("a")), + Type.Fun(Type.TyVar(Bound("b")), Type.TyVar(Bound("a"))) + ) + ) + ) testType(app(lambda("x", v("x")), lit(100)), Type.IntType) - testType(ann(app(lambda("x", v("x")), lit(100)), Type.IntType), Type.IntType) + testType( + ann(app(lambda("x", v("x")), lit(100)), Type.IntType), + Type.IntType + ) testType(app(alam("x", Type.IntType, v("x")), lit(100)), Type.IntType) // test branches @@ -268,37 +333,58 @@ class RankNInferTest extends AnyFunSuite { testType(let("x", lit(0), ife(lit(true), v("x"), lit(1))), Type.IntType) val identFnType = - forAll(NonEmptyList.of(b("a")), - Type.Fun(Type.TyVar(Bound("a")), Type.TyVar(Bound("a")))) - testType(let("x", lambda("y", v("y")), - ife(lit(true), v("x"), - ann(lambda("x", v("x")), identFnType))), identFnType) + forAll( + NonEmptyList.of(b("a")), + Type.Fun(Type.TyVar(Bound("a")), Type.TyVar(Bound("a"))) + ) + testType( + let( + "x", + lambda("y", v("y")), + ife(lit(true), v("x"), ann(lambda("x", v("x")), identFnType)) + ), + identFnType + ) // test some lets testLetTypes( List( ("x", lit(100), Type.IntType), - ("y", Expr.Global(testPackage, Identifier.Name("x"), ()), Type.IntType))) + ("y", Expr.Global(testPackage, Identifier.Name("x"), ()), Type.IntType) + ) + ) } test("match inference") { testType( - matche(lit(10), + matche( + lit(10), NonEmptyList.of( (Pattern.WildCard, lit(0)) - )), Type.IntType) + ) + ), + Type.IntType + ) testType( - matche(lit(true), + matche( + lit(true), NonEmptyList.of( (Pattern.WildCard, lit(0)) - )), Type.IntType) + ) + ), + Type.IntType + ) testType( - matche(lit(true), + matche( + lit(true), NonEmptyList.of( (Pattern.Annotation(Pattern.WildCard, Type.BoolType), lit(0)) - )), Type.IntType) + ) + ), + Type.IntType + ) } object OptionTypes { @@ -308,11 +394,23 @@ class RankNInferTest extends AnyFunSuite { val pn = testPackage val definedOption = Map( ((pn, Constructor("Some")), (Nil, List(Type.IntType), optName)), - ((pn, Constructor("None")), (Nil, Nil, optName))) + ((pn, Constructor("None")), (Nil, Nil, optName)) + ) val definedOptionGen = Map( - ((pn, Constructor("Some")), (List((Bound("a"), Kind.Type.co)), List(Type.TyVar(Bound("a"))), optName)), - ((pn, Constructor("None")), (List((Bound("a"), Kind.Type.co)), Nil, optName))) + ( + (pn, Constructor("Some")), + ( + List((Bound("a"), Kind.Type.co)), + List(Type.TyVar(Bound("a"))), + optName + ) + ), + ( + (pn, Constructor("None")), + (List((Bound("a"), Kind.Type.co)), Nil, optName) + ) + ) } test("match with custom non-generic types") { @@ -325,41 +423,65 @@ class RankNInferTest extends AnyFunSuite { val kindNotGen = Type.builtInKinds.updated(optName, Kind.Type) def testWithOpt[A: HasRegion](term: Expr[A], ty: Type) = - Infer.typeCheck(term).runFully( - withBools ++ asFullyQualified(constructors), - definedOption ++ boolTypes, - kindNotGen) match { - case Left(err) => assert(false, err) + Infer + .typeCheck(term) + .runFully( + withBools ++ asFullyQualified(constructors), + definedOption ++ boolTypes, + kindNotGen + ) match { + case Left(err) => assert(false, err) case Right(tpe) => assert(tpe.getType == ty, term.toString) } def failWithOpt[A: HasRegion](term: Expr[A]) = - Infer.typeCheck(term).runFully( - withBools ++ asFullyQualified(constructors), - definedOption ++ boolTypes, - kinds) match { + Infer + .typeCheck(term) + .runFully( + withBools ++ asFullyQualified(constructors), + definedOption ++ boolTypes, + kinds + ) match { case Left(_) => assert(true) - case Right(tpe) => assert(false, s"expected to fail, but inferred type $tpe") + case Right(tpe) => + assert(false, s"expected to fail, but inferred type $tpe") } testWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( (Pattern.WildCard, lit(0)) - )), Type.IntType) + ) + ), + Type.IntType + ) testWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( - (Pattern.PositionalStruct("Some", List(Pattern.Var(Identifier.Name("a")))), v("a")), + ( + Pattern.PositionalStruct( + "Some", + List(Pattern.Var(Identifier.Name("a"))) + ), + v("a") + ), (Pattern.PositionalStruct("None", Nil), lit(42)) - )), Type.IntType) + ) + ), + Type.IntType + ) failWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( (Pattern.PositionalStruct("Foo", List(Pattern.WildCard)), lit(0)) - ))) + ) + ) + ) } test("match with custom generic types") { @@ -370,8 +492,17 @@ class RankNInferTest extends AnyFunSuite { val kinds = Type.builtInKinds.updated(optName, Kind(Kind.Type.co)) val constructors = Map( - (Identifier.unsafe("Some"), Type.forAll(NonEmptyList.of(b("a")), Type.Fun(tv("a"), Type.TyApply(optType, tv("a"))))), - (Identifier.unsafe("None"), Type.forAll(NonEmptyList.of(b("a")), Type.TyApply(optType, tv("a")))) + ( + Identifier.unsafe("Some"), + Type.forAll( + NonEmptyList.of(b("a")), + Type.Fun(tv("a"), Type.TyApply(optType, tv("a"))) + ) + ), + ( + Identifier.unsafe("None"), + Type.forAll(NonEmptyList.of(b("a")), Type.TyApply(optType, tv("a"))) + ) ) def testWithOpt[A: HasRegion](term: Expr[A], ty: Type) = @@ -380,45 +511,77 @@ class RankNInferTest extends AnyFunSuite { .runFully( withBools ++ asFullyQualified(constructors), definedOptionGen ++ boolTypes, - kinds) match { - case Left(err) => assert(false, err) - case Right(tpe) => assert(tpe.getType == ty, term.toString) - } + kinds + ) match { + case Left(err) => assert(false, err) + case Right(tpe) => assert(tpe.getType == ty, term.toString) + } def failWithOpt[A: HasRegion](term: Expr[A]) = - Infer.typeCheck(term).runFully( - withBools ++ asFullyQualified(constructors), - definedOptionGen ++ boolTypes, - kinds) match { + Infer + .typeCheck(term) + .runFully( + withBools ++ asFullyQualified(constructors), + definedOptionGen ++ boolTypes, + kinds + ) match { case Left(_) => assert(true) - case Right(tpe) => assert(false, s"expected to fail, but inferred type $tpe") + case Right(tpe) => + assert(false, s"expected to fail, but inferred type $tpe") } testWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( (Pattern.WildCard, lit(0)) - )), Type.IntType) + ) + ), + Type.IntType + ) testWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( - (Pattern.PositionalStruct("Some", List(Pattern.Var(Identifier.Name("a")))), v("a")), + ( + Pattern.PositionalStruct( + "Some", + List(Pattern.Var(Identifier.Name("a"))) + ), + v("a") + ), (Pattern.PositionalStruct("None", Nil), lit(42)) - )), Type.IntType) + ) + ), + Type.IntType + ) // Nested Some testWithOpt( - matche(app(v("Some"), app(v("Some"), lit(1))), + matche( + app(v("Some"), app(v("Some"), lit(1))), NonEmptyList.of( - (Pattern.PositionalStruct("Some", List(Pattern.Var(Identifier.Name("a")))), v("a")) - )), Type.TyApply(optType, Type.IntType)) + ( + Pattern.PositionalStruct( + "Some", + List(Pattern.Var(Identifier.Name("a"))) + ), + v("a") + ) + ) + ), + Type.TyApply(optType, Type.IntType) + ) failWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( (Pattern.PositionalStruct("Foo", List(Pattern.WildCard)), lit(0)) - ))) + ) + ) + ) } test("Test a constructor with ForAll") { @@ -429,73 +592,130 @@ class RankNInferTest extends AnyFunSuite { val optType: Type.Tau = Type.TyConst(optName) val pn = testPackage - /** - * struct Pure(pure: forall a. a -> f[a]) - */ + + /** struct Pure(pure: forall a. a -> f[a]) + */ val defined = Map( - ((pn, Constructor("Pure")), (List((Type.Var.Bound("f"), Kind(Kind.Type.in).in)), - List(Type.forAll(NonEmptyList.of((Type.Var.Bound("a"), Kind.Type)), Type.Fun(tv("a"), Type.TyApply(tv("f"), tv("a"))))), - pureName)), - ((pn, Constructor("Some")), (List((Type.Var.Bound("a"), Kind.Type.co)), List(tv("a")), optName)), - ((pn, Constructor("None")), (List((Type.Var.Bound("a"), Kind.Type.co)), Nil, optName))) + ( + (pn, Constructor("Pure")), + ( + List((Type.Var.Bound("f"), Kind(Kind.Type.in).in)), + List( + Type.forAll( + NonEmptyList.of((Type.Var.Bound("a"), Kind.Type)), + Type.Fun(tv("a"), Type.TyApply(tv("f"), tv("a"))) + ) + ), + pureName + ) + ), + ( + (pn, Constructor("Some")), + (List((Type.Var.Bound("a"), Kind.Type.co)), List(tv("a")), optName) + ), + ( + (pn, Constructor("None")), + (List((Type.Var.Bound("a"), Kind.Type.co)), Nil, optName) + ) + ) val constructors = Map( - (Identifier.unsafe("Pure"), Type.forAll(NonEmptyList.of(b1("f")), - Type.Fun(Type.forAll(NonEmptyList.of(b("a")), Type.Fun(tv("a"), Type.TyApply(tv("f"), tv("a")))), - Type.TyApply(Type.TyConst(pureName), tv("f")) ))), - (Identifier.unsafe("Some"), Type.forAll(NonEmptyList.of(b("a")), Type.Fun(tv("a"), Type.TyApply(optType, tv("a"))))), - (Identifier.unsafe("None"), Type.forAll(NonEmptyList.of(b("a")), Type.TyApply(optType, tv("a")))) + ( + Identifier.unsafe("Pure"), + Type.forAll( + NonEmptyList.of(b1("f")), + Type.Fun( + Type.forAll( + NonEmptyList.of(b("a")), + Type.Fun(tv("a"), Type.TyApply(tv("f"), tv("a"))) + ), + Type.TyApply(Type.TyConst(pureName), tv("f")) + ) + ) + ), + ( + Identifier.unsafe("Some"), + Type.forAll( + NonEmptyList.of(b("a")), + Type.Fun(tv("a"), Type.TyApply(optType, tv("a"))) + ) + ), + ( + Identifier.unsafe("None"), + Type.forAll(NonEmptyList.of(b("a")), Type.TyApply(optType, tv("a"))) + ) ) def testWithTypes[A: HasRegion](term: Expr[A], ty: Type) = - Infer.typeCheck(term).runFully( - withBools ++ asFullyQualified(constructors), - defined ++ boolTypes, - Type.builtInKinds.updated(optName, Kind(Kind.Type.co))) match { - case Left(err) => assert(false, err) + Infer + .typeCheck(term) + .runFully( + withBools ++ asFullyQualified(constructors), + defined ++ boolTypes, + Type.builtInKinds.updated(optName, Kind(Kind.Type.co)) + ) match { + case Left(err) => assert(false, err) case Right(tpe) => assert(tpe.getType == ty, term.toString) } testWithTypes( - app(v("Pure"), v("Some")), Type.TyApply(Type.TyConst(pureName), optType)) + app(v("Pure"), v("Some")), + Type.TyApply(Type.TyConst(pureName), optType) + ) } test("test inference of basic expressions") { - parseProgram("""# + parseProgram( + """# main = (x -> x)(1) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# x = 1 y = x main = y -""", "Int") +""", + "Int" + ) } test("test inference with partial def annotation") { - parseProgram("""# + parseProgram( + """# ident: forall a. a -> a = x -> x main = ident(1) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# def ident(x: a): x main = ident(1) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# def ident(x) -> a: x main = ident(1) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# enum MyBool: T, F struct Pair(fst, snd) @@ -511,10 +731,12 @@ res = ( ) main = res -""", "Int") - +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# struct Pair(fst: a, snd: a) @@ -527,43 +749,61 @@ fst = ( ) main = fst -""", "Int") +""", + "Int" + ) } test("test inference with some defined types") { - parseProgram("""# + parseProgram( + """# struct Unit main = Unit -""", "Unit") +""", + "Unit" + ) - parseProgram("""# + parseProgram( + """# enum Option: None Some(a) main = Some(1) -""", "Option[Int]") +""", + "Option[Int]" + ) - parseProgram("""# + parseProgram( + """# enum Option: None Some(a) main = Some -""", "forall a. a -> Option[a]") +""", + "forall a. a -> Option[a]" + ) - parseProgram("""# + parseProgram( + """# id = x -> x main = id -""", "forall a. a -> a") +""", + "forall a. a -> a" + ) - parseProgram("""# + parseProgram( + """# id = x -> x main = id(1) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# enum Option: None Some(a) @@ -572,9 +812,12 @@ x = Some(1) main = match x: case None: 0 case Some(y): y -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# enum List: Empty NonEmpty(a: a, tail: b) @@ -583,9 +826,12 @@ x = NonEmpty(1, Empty) main = match x: case Empty: 0 case NonEmpty(y, _): y -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# enum Opt: None, Some(a) @@ -597,9 +843,12 @@ def optBind(opt, bindFn): case Some(a): bindFn(a) main = Monad(Some, optBind) -""", "Monad[Opt]") +""", + "Monad[Opt]" + ) - parseProgram("""# + parseProgram( + """# enum Opt: None, Some(a) @@ -622,15 +871,18 @@ def use_bind(m, a, b, c): a1.bind(_ -> b1.bind(_ -> c1)) main = use_bind(option_monad, None, None, None) -""", "forall a. Opt[a]") - - // TODO: - // The challenge here is that the naive curried form of the - // def will not see the forall until the final parameter - // we need to bubble up the forall on the whole function. - // - // same as the above with a different order in use_bind - parseProgram("""# +""", + "forall a. Opt[a]" + ) + + // TODO: + // The challenge here is that the naive curried form of the + // def will not see the forall until the final parameter + // we need to bubble up the forall on the whole function. + // + // same as the above with a different order in use_bind + parseProgram( + """# enum Opt: None, Some(a) @@ -653,20 +905,26 @@ def use_bind(a, b, c, m): a1.bind(_ -> b1.bind(_ -> c1)) main = use_bind(None, None, None, option_monad) -""", "forall a. Opt[a]") +""", + "forall a. Opt[a]" + ) } test("test zero arg defs") { - parseProgram("""# + parseProgram( + """# struct Foo fst: Foo = Foo main = fst -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# enum Foo: Bar, Baz(a) @@ -674,13 +932,15 @@ enum Foo: fst: Foo[a] = Bar main = fst -""", "forall a. Foo[a]") +""", + "forall a. Foo[a]" + ) } - test("substition works correctly") { - parseProgram("""# + parseProgram( + """# (id: forall a. a -> a) = x -> x struct Foo @@ -688,9 +948,12 @@ struct Foo def apply(fn, arg: Foo): fn(arg) main = apply(id, Foo) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# (id: forall a. a -> a) = x -> x struct Foo @@ -700,9 +963,12 @@ struct Foo def apply(fn, arg: Foo): fn(arg) main = apply(id, Foo) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# struct FnWrapper(fn: a -> a) @@ -717,9 +983,12 @@ def apply(fn, arg: Foo): f(arg) main = apply(id, Foo) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# struct Foo (id: forall a. a -> Foo) = _ -> Foo @@ -729,7 +998,9 @@ struct Foo (idGen2: (forall a. a) -> Foo) = id2 main = Foo -""", "Foo") +""", + "Foo" + ) parseProgramIllTyped("""# @@ -741,7 +1012,8 @@ struct Foo main = Foo """) - parseProgram("""# + parseProgram( + """# enum Foo: Bar, Baz (bar1: forall a. (Foo -> a) -> a) = fn -> fn(Bar) @@ -764,8 +1036,11 @@ enum Foo: Bar, Baz (producer1: Foo -> ((Foo -> Foo) -> Foo)) = producer main = Bar -""", "Foo") - parseProgram("""# +""", + "Foo" + ) + parseProgram( + """# enum Foo: Bar, Baz struct Cont[b: +*, a](cont: (b -> a) -> a) @@ -790,7 +1065,9 @@ struct Cont[b: +*, a](cont: (b -> a) -> a) (producer1: Foo -> Cont[Foo, Foo]) = producer main = Bar -""", "Foo") +""", + "Foo" + ) parseProgramIllTyped("""# enum Foo: Bar, Baz @@ -804,7 +1081,8 @@ struct Cont(cont: (b -> a) -> a) main = Bar """) - parseProgram("""# + parseProgram( + """# struct Foo enum Opt: Nope, Yep(a) @@ -816,9 +1094,11 @@ enum Opt: Nope, Yep(a) (consumer1: (forall a. Opt[a]) -> Foo) = consumer main = Foo -""", "Foo") +""", + "Foo" + ) - parseProgramIllTyped("""# + parseProgramIllTyped("""# struct Foo enum Opt: Nope, Yep(a) @@ -828,7 +1108,7 @@ enum Opt: Nope, Yep(a) main = Foo """) - parseProgramIllTyped("""# + parseProgramIllTyped("""# struct Foo enum Opt: Nope, Yep(a) @@ -839,7 +1119,8 @@ enum Opt: Nope, Yep(a) main = Foo """) - parseProgram("""# + parseProgram( + """# struct Foo enum Opt: Nope, Yep(a) @@ -853,7 +1134,9 @@ struct FnWrapper(fn: a -> b) (consumer1: FnWrapper[forall a. Opt[a], Foo]) = consumer main = Foo -""", "Foo") +""", + "Foo" + ) parseProgramIllTyped("""# struct Foo @@ -883,7 +1166,8 @@ main = Foo } test("def with type annotation and use the types inside") { - parseProgram("""# + parseProgram( + """# struct Pair(fst, snd) @@ -892,11 +1176,13 @@ def fst(p: Pair[a, b]) -> a: f main = fst(Pair(1, "1")) -""", "Int") +""", + "Int" + ) } test("test that we see some ill typed programs") { - parseProgramIllTyped("""# + parseProgramIllTyped("""# def foo(i: Int): i @@ -906,7 +1192,7 @@ main = foo("Not an Int") test("using a literal the wrong type is ill-typed") { - parseProgramIllTyped("""# + parseProgramIllTyped("""# x = "foo" @@ -915,7 +1201,7 @@ main = match x: case y: y """) - parseProgramIllTyped("""# + parseProgramIllTyped("""# x = 1 @@ -946,7 +1232,8 @@ main = 1 } test("structural recursion can be typed") { - parseProgram("""# + parseProgram( + """# enum Nat: Zero, Succ(prev: Nat) @@ -956,9 +1243,12 @@ def len(l): case Succ(p): len(p) main = len(Succ(Succ(Zero))) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# enum Nat: Zero, Succ(prev: Nat) @@ -970,12 +1260,15 @@ def len(l): len0(l) main = len(Succ(Succ(Zero))) -""", "Int") +""", + "Int" + ) } test("nested def example") { - parseProgram("""# + parseProgram( + """# struct Pair(first, second) def bar(x): @@ -985,12 +1278,15 @@ def bar(x): baz(10) main = bar(5) -""", "Pair[Int, Int]") +""", + "Pair[Int, Int]" + ) } test("test checkRho on annotated lambda") { - parseProgram("""# + parseProgram( + """# struct Foo struct Bar @@ -1002,10 +1298,13 @@ struct Bar dontCall = \(_: (forall a. a) -> Bar) -> Foo (main: Foo) = dontCall(fn) -""", "Foo") +""", + "Foo" + ) } test("ForAll as function arg") { - parseProgram("""# + parseProgram( + """# struct Wrap[bbbb](y1: bbbb) struct Foo[cccc](y2: cccc) struct Nil @@ -1021,11 +1320,14 @@ def foo(cra_fn: Wrap[(forall ssss. Foo[ssss]) -> Nil]): match cra_fn: case (_: Wrap[(forall x. Foo[x]) -> Nil]): Nil main = foo -""", "Wrap[(forall ssss. Foo[ssss]) -> Nil] -> Nil") +""", + "Wrap[(forall ssss. Foo[ssss]) -> Nil] -> Nil" + ) } test("use a type annotation inside a def") { - parseProgram("""# + parseProgram( + """# struct Foo struct Bar def ignore(_): Foo @@ -1033,9 +1335,12 @@ def add(x): (y: Foo) = x _ = ignore(y) Bar -""", "Foo -> Bar") +""", + "Foo -> Bar" + ) - parseProgram("""# + parseProgram( + """# struct Foo struct Bar(f: Foo) def ignore(_): Foo @@ -1043,7 +1348,9 @@ def add(x): ((y: Foo) as b) = x _ = ignore(y) Bar(b) -""", "Foo -> Bar") +""", + "Foo -> Bar" + ) } test("top level matches don't introduce colliding bindings") { @@ -1104,7 +1411,8 @@ struct Bar x: Bar = Foo """) - parseProgram("""# + parseProgram( + """# struct Foo struct Bar @@ -1112,9 +1420,12 @@ x = ( f = Foo f: Foo ) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# struct Foo struct Bar @@ -1122,8 +1433,11 @@ x = ( f: Foo = Foo f ) -""", "Foo") - parseProgram("""# +""", + "Foo" + ) + parseProgram( + """# struct Pair(a, b) struct Foo @@ -1134,9 +1448,12 @@ x = ( _ = ignore(g) f ) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# struct Pair(a, b) struct Foo @@ -1144,17 +1461,23 @@ x = ( Pair(f, _) = Pair(Foo: Foo, Foo) f ) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# struct Foo x: Foo = Foo -""", "Foo") +""", + "Foo" + ) } test("test inner quantification") { - parseProgram("""# + parseProgram( + """# struct Foo # this should just be: type Foo @@ -1166,11 +1489,14 @@ foo = ( ident(Foo) ) -""", "Foo") +""", + "Foo" + ) } test("widening inside a match") { - parseProgram("""# + parseProgram( + """# enum B: True, False def not(b): @@ -1184,7 +1510,9 @@ def branch(x): case False: i -> not(i) res = branch(True)(True) -""", "B") +""", + "B" + ) parseProgramIllTyped("""# enum B: True, False @@ -1204,32 +1532,48 @@ res = branch(True)(True) } test("basic existential types") { - parseProgram("""# + parseProgram( + """# x: exists b. b = 1 -""", "exists b. b") +""", + "exists b. b" + ) - parseProgram("""# + parseProgram( + """# def hide[b](x: b) -> exists a. a: x -""", "forall a. a -> (exists a. a)") +""", + "forall a. a -> (exists a. a)" + ) - parseProgram("""# + parseProgram( + """# def hide[b](x: b) -> exists a. a: x x = hide(1) -""", "exists a. a") +""", + "exists a. a" + ) - parseProgram("""# + parseProgram( + """# def hide[b](x: b) -> exists a. a: x y: exists x. x = 1 x = hide(y) -""", "exists a. a") +""", + "exists a. a" + ) - parseProgram("""# + parseProgram( + """# def hide[b](x: b) -> exists a. a: x y = hide(1) x = hide(y) -""", "exists a. a") +""", + "exists a. a" + ) - parseProgram("""# + parseProgram( + """# struct Tup(a, b) def hide[b](x: b) -> exists a. a: x @@ -1237,19 +1581,25 @@ def makeTup[a, b](x: a, y: b) -> Tup[a, b]: Tup(x, y) x = hide(1) y = hide("1") z: Tup[exists a. a, exists b. b] = makeTup(x, y) -""", "Tup[exists a. a, exists b. b]") - parseProgram("""# +""", + "Tup[exists a. a, exists b. b]" + ) + parseProgram( + """# enum B: T, F struct Inv[a: *](item: a) any: exists a. a = T x: Inv[exists a. a] = Inv(any) -""", "Inv[exists a. a]") +""", + "Inv[exists a. a]" + ) } test("we can use existentials in branches") { - parseProgram("""# + parseProgram( + """# enum MyBool: T, F def branch(b) -> exists a. a: @@ -1258,9 +1608,12 @@ def branch(b) -> exists a. a: case F: "1" x = branch(T) -""", "exists a. a") +""", + "exists a. a" + ) - parseProgram("""# + parseProgram( + """# enum Maybe: Nothing, Something(item: exists a. a) enum Opt[a]: None, Some(a: a) @@ -1272,9 +1625,12 @@ def branch(b: Maybe) -> exists a. Opt[a]: case Nothing: None x = branch(x) -""", "exists a. Opt[a]") +""", + "exists a. Opt[a]" + ) - parseProgram("""# + parseProgram( + """# struct MyTup(a, b) enum MyBool: T, F @@ -1283,7 +1639,9 @@ b = T x = MyTup((match b: case T: F case F: T), (x: MyBool) -> x): exists a. MyTup[a, a -> MyBool] -""", "exists a. MyTup[a, a -> MyBool]") +""", + "exists a. MyTup[a, a -> MyBool]" + ) parseProgramIllTyped("""# struct MyTup(a, b) @@ -1298,7 +1656,8 @@ x = MyTup((match b: } test("use existentials in ADTs") { - parseProgram("""# + parseProgram( + """# struct Tup(a, b) enum FreeF[a]: Pure(a: a) @@ -1314,12 +1673,15 @@ def branch[a](b: FreeF[a]) -> exists b. Opt[Tup[FreeF[b], b -> a]]: case Mapped(x): Some(x) case _: None -""", "forall a. FreeF[a] -> exists b. Opt[Tup[FreeF[b], b -> a]]") +""", + "forall a. FreeF[a] -> exists b. Opt[Tup[FreeF[b], b -> a]]" + ) } test("we can use existentials to delay calls") { - parseProgram("""# + parseProgram( + """# struct MyTup(a, b) def delay[a, b](fn: a -> b, a: a) -> exists c. MyTup[c -> b, c]: @@ -1330,7 +1692,9 @@ def call[a](tup: exists c. MyTup[c -> a, c]) -> a: fn(arg) x = call(delay(x -> x, 1)) -""", "Int") +""", + "Int" + ) } test("we can't see through existentials") { @@ -1404,26 +1768,33 @@ def unsound[f: * -> *](fany: f[exists a. a], get: forall a. f[a] -> a) -> forall } test("we can use existentials with invariant types") { - parseProgram("""# + parseProgram( + """# struct Box(a) x: exists a. a = 1 fn: (exists a. a) -> Box[exists a. a] = Box y = fn(x) -""", "Box[exists a. a]") +""", + "Box[exists a. a]" + ) - parseProgram("""# + parseProgram( + """# struct Box(a) x: exists a. a = 1 y: Box[exists a. a] = Box(x) -""", "Box[exists a. a]") +""", + "Box[exists a. a]" + ) } test("invariant instantiation regression") { - parseProgram("""# + parseProgram( + """# struct Box[x: *](a: x) struct One enum Opt[a]: None, Some(a: a) @@ -1446,17 +1817,22 @@ def process(o: Box[Opt[One]]) -> One: case Box(None): One z = process(y) -""", "One") +""", + "One" + ) } test("some subtyping relationships") { -parseProgram(""" + parseProgram( + """ struct Foo[a: *] f1: forall a. Foo[a] = Foo f2: Foo[forall a. a] = Foo f3: Foo[forall a. a] = f1 -""", "Foo[forall a. a]") +""", + "Foo[forall a. a]" + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala index eb3f1bd45..516beb3f9 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala @@ -3,22 +3,27 @@ package org.bykn.bosatsu.rankn import cats.data.NonEmptyList import org.bykn.bosatsu.Kind import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class TypeTest extends AnyFunSuite { import NTypeGen.shrinkType implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 1000) - //PropertyCheckConfiguration(minSuccessful = 5) + // PropertyCheckConfiguration(minSuccessful = 5) def parse(s: String): Type = Type.fullyResolvedParser.parseAll(s) match { case Right(t) => t case Left(err) => - sys.error(s"failed to parse: <$s> at ${s.drop(err.failedAtOffset)}\n\n$err") + sys.error( + s"failed to parse: <$s> at ${s.drop(err.failedAtOffset)}\n\n$err" + ) } test("free vars are not duplicated") { @@ -33,8 +38,10 @@ class TypeTest extends AnyFunSuite { val frees = Type.freeTyVars(ts :: Nil) val norm = Type.normalize(ts) val freeNorm = Type.freeTyVars(norm :: Nil) - assert(frees == freeNorm, - s"${Type.typeParser.render(ts)} => ${Type.typeParser.render(norm)}") + assert( + frees == freeNorm, + s"${Type.typeParser.render(ts)} => ${Type.typeParser.render(norm)}" + ) } } @@ -42,13 +49,15 @@ class TypeTest extends AnyFunSuite { forAll(Gen.listOf(NTypeGen.genDepth03)) { ts => Type.Tuple(ts) match { case Type.Tuple(ts1) => assert(ts1 == ts) - case notTup => fail(notTup.toString) + case notTup => fail(notTup.toString) } } assert(Type.Tuple.unapply(parse("()")) == Some(Nil)) - assert(Type.Tuple.unapply(parse("(a, b, c)")) == - Some(List("a", "b", "c").map(parse))) + assert( + Type.Tuple.unapply(parse("(a, b, c)")) == + Some(List("a", "b", "c").map(parse)) + ) } test("unapplyAll is the inverse of applyAll") { @@ -57,8 +66,10 @@ class TypeTest extends AnyFunSuite { assert(Type.applyAll(left, args) == ts) } - assert(Type.unapplyAll(parse("foo[bar]")) == - (parse("foo"), List(parse("bar")))) + assert( + Type.unapplyAll(parse("foo[bar]")) == + (parse("foo"), List(parse("bar"))) + ) } test("freeBoundVar doesn't change by applyAll") { @@ -66,10 +77,11 @@ class TypeTest extends AnyFunSuite { val applied = Type.applyAll(ts, args) val free0 = Type.freeBoundTyVars(ts :: args) val free1 = Type.freeBoundTyVars(applied :: Nil) - assert(free1.toSet == free0.toSet, - s"applied = ${Type.typeParser.render(applied)}, (${Type.typeParser.render(ts)})[${ - args.iterator.map(Type.typeParser.render(_)).mkString(", ") - }]})") + assert( + free1.toSet == free0.toSet, + s"applied = ${Type.typeParser.render(applied)}, (${Type.typeParser + .render(ts)})[${args.iterator.map(Type.typeParser.render(_)).mkString(", ")}]})" + ) } } @@ -99,24 +111,123 @@ class TypeTest extends AnyFunSuite { forAll(NTypeGen.genDepth03) { t => assert(t.sameAs(Type.normalize(t))) } - + { import Type._ import Var.Bound import org.bykn.bosatsu.Variance._ import org.bykn.bosatsu.Kind.{Arg, Cons, Type => KType} - - val qt1 = Quantified(Quantification.Dual( - NonEmptyList((Bound("qsnMgkhqY"), Cons(Arg(Covariant, Cons(Arg(Covariant, KType), KType)), Cons(Arg(Phantom, KType), KType))), List((Bound("u"), Cons(Arg(Contravariant, KType), Cons(Arg(Invariant, KType), KType))))), - NonEmptyList((Bound("nack"), Cons(Arg(Invariant, Cons(Arg(Phantom, KType), KType)), Cons(Arg(Phantom, KType), KType))), List((Bound("u"), Cons(Arg(Contravariant, Cons(Arg(Contravariant, KType), KType)), Cons(Arg(Contravariant, KType), KType))), (Bound("vHxbikOne"), Cons(Arg(Invariant, Cons(Arg(Covariant, KType), KType)), Cons(Arg(Contravariant, KType), KType))), (Bound("jofpdjgp"), Cons(Arg(Covariant, Cons(Arg(Phantom, KType), KType)), KType)), (Bound("r"), Cons(Arg(Invariant, Cons(Arg(Covariant, KType), KType)), Cons(Arg(Invariant, KType), KType)))))), - TyVar(Bound("u"))) + + val qt1 = Quantified( + Quantification.Dual( + NonEmptyList( + ( + Bound("qsnMgkhqY"), + Cons( + Arg(Covariant, Cons(Arg(Covariant, KType), KType)), + Cons(Arg(Phantom, KType), KType) + ) + ), + List( + ( + Bound("u"), + Cons( + Arg(Contravariant, KType), + Cons(Arg(Invariant, KType), KType) + ) + ) + ) + ), + NonEmptyList( + ( + Bound("nack"), + Cons( + Arg(Invariant, Cons(Arg(Phantom, KType), KType)), + Cons(Arg(Phantom, KType), KType) + ) + ), + List( + ( + Bound("u"), + Cons( + Arg(Contravariant, Cons(Arg(Contravariant, KType), KType)), + Cons(Arg(Contravariant, KType), KType) + ) + ), + ( + Bound("vHxbikOne"), + Cons( + Arg(Invariant, Cons(Arg(Covariant, KType), KType)), + Cons(Arg(Contravariant, KType), KType) + ) + ), + ( + Bound("jofpdjgp"), + Cons(Arg(Covariant, Cons(Arg(Phantom, KType), KType)), KType) + ), + ( + Bound("r"), + Cons( + Arg(Invariant, Cons(Arg(Covariant, KType), KType)), + Cons(Arg(Invariant, KType), KType) + ) + ) + ) + ) + ), + TyVar(Bound("u")) + ) val qt2 = Quantified( - Quantification.Exists(NonEmptyList((Bound("chajb"), Cons(Arg(Contravariant, Cons(Arg(Covariant, KType), KType)), Cons(Arg(Contravariant, KType), KType))), List((Bound("e"), Cons(Arg(Invariant, Cons(Arg(Phantom, KType), KType)), Cons(Arg(Phantom, KType), Cons(Arg(Phantom, KType), KType)))), (Bound("vg"), Cons(Arg(Phantom, Cons(Arg(Phantom, KType), KType)), Cons(Arg(Phantom, KType), KType))), (Bound("vvki"), Cons(Arg(Contravariant, Cons(Arg(Phantom, KType), Cons(Arg(Phantom, KType), KType))), KType)), (Bound("e"), Cons(Arg(Invariant, Cons(Arg(Invariant, KType), KType)), Cons(Arg(Phantom, KType), KType)))))), - TyVar(Bound("e"))) + Quantification.Exists( + NonEmptyList( + ( + Bound("chajb"), + Cons( + Arg(Contravariant, Cons(Arg(Covariant, KType), KType)), + Cons(Arg(Contravariant, KType), KType) + ) + ), + List( + ( + Bound("e"), + Cons( + Arg(Invariant, Cons(Arg(Phantom, KType), KType)), + Cons(Arg(Phantom, KType), Cons(Arg(Phantom, KType), KType)) + ) + ), + ( + Bound("vg"), + Cons( + Arg(Phantom, Cons(Arg(Phantom, KType), KType)), + Cons(Arg(Phantom, KType), KType) + ) + ), + ( + Bound("vvki"), + Cons( + Arg( + Contravariant, + Cons(Arg(Phantom, KType), Cons(Arg(Phantom, KType), KType)) + ), + KType + ) + ), + ( + Bound("e"), + Cons( + Arg(Invariant, Cons(Arg(Invariant, KType), KType)), + Cons(Arg(Phantom, KType), KType) + ) + ) + ) + ) + ), + TyVar(Bound("e")) + ) val regressions: List[Type] = - qt1 :: + qt1 :: qt2 :: Nil @@ -126,11 +237,10 @@ class TypeTest extends AnyFunSuite { Type.typeParser.render(t) val normt2 = Type.normalize(normt) - assert(normt == normt2, - s"${show(normt)} normalizes to ${show(normt2)}") - assert(t.sameAs(normt), s"${show(t)}.sameAs(${show(normt)}) == false") + assert(normt == normt2, s"${show(normt)} normalizes to ${show(normt2)}") + assert(t.sameAs(normt), s"${show(t)}.sameAs(${show(normt)}) == false") } - + assert(Type.freeBoundTyVars(qt1.in :: Nil) == List(Bound("u"))) } } @@ -151,16 +261,16 @@ class TypeTest extends AnyFunSuite { forAll(NTypeGen.genDepth03)(law(_)) - - forAll(NTypeGen.lowerIdent, Gen.choose(Long.MinValue, Long.MaxValue)) { (b, id) => - val str = "$" + b + "$" + id.toString - val tpe = parse(str) - law(tpe) - tpe match { - case Type.TyVar(Type.Var.Skolem(b1, k1, _, i1)) => - assert((b1, k1, i1) === (b, Kind.Type ,id)) - case other => fail(other.toString) - } + forAll(NTypeGen.lowerIdent, Gen.choose(Long.MinValue, Long.MaxValue)) { + (b, id) => + val str = "$" + b + "$" + id.toString + val tpe = parse(str) + law(tpe) + tpe match { + case Type.TyVar(Type.Var.Skolem(b1, k1, _, i1)) => + assert((b1, k1, i1) === (b, Kind.Type, id)) + case other => fail(other.toString) + } } forAll { (l: Long) => @@ -169,8 +279,10 @@ class TypeTest extends AnyFunSuite { } test("test all binders") { - assert(Type.allBinders.filter(_.name.startsWith("a")).take(100).map(_.name) == - ("a" #:: Stream.iterate(0)(_ + 1).map { i => s"a$i" }).take(100)) + assert( + Type.allBinders.filter(_.name.startsWith("a")).take(100).map(_.name) == + ("a" #:: Stream.iterate(0)(_ + 1).map { i => s"a$i" }).take(100) + ) } test("tyVarBinders is identity for Bound") { @@ -198,7 +310,7 @@ class TypeTest extends AnyFunSuite { test("hasNoVars fully recurses") { def allTypesIn(t: Type): List[Type] = t match { - case f@Type.ForAll(bounds, in) => + case f @ Type.ForAll(bounds, in) => // filter bounds out, since they are shadowed val boundSet = bounds.toList.iterator.map(_._1).toSet[Type.Var] f :: (allTypesIn(in).filterNot { it => @@ -206,8 +318,8 @@ class TypeTest extends AnyFunSuite { // if we intersect, this is not a legit type to consider (boundSet & frees).nonEmpty }) - case t@Type.TyApply(a, b) => t :: allTypesIn(a) ::: allTypesIn(b) - case other => other :: Nil + case t @ Type.TyApply(a, b) => t :: allTypesIn(a) ::: allTypesIn(b) + case other => other :: Nil } def law(t: Type) = { @@ -222,10 +334,19 @@ class TypeTest extends AnyFunSuite { val pastFails = List( - Type.forAll(NonEmptyList.of((Type.Var.Bound("x"), Kind.Type), (Type.Var.Bound("ogtumm"), Kind.Type), (Type.Var.Bound("t"), Kind.Type)), - Type.TyVar(Type.Var.Bound("x"))), - Type.forAll(NonEmptyList.of((Type.Var.Bound("a"), Kind.Type)),Type.TyVar(Type.Var.Bound("a"))) + Type.forAll( + NonEmptyList.of( + (Type.Var.Bound("x"), Kind.Type), + (Type.Var.Bound("ogtumm"), Kind.Type), + (Type.Var.Bound("t"), Kind.Type) + ), + Type.TyVar(Type.Var.Bound("x")) + ), + Type.forAll( + NonEmptyList.of((Type.Var.Bound("a"), Kind.Type)), + Type.TyVar(Type.Var.Bound("a")) ) + ) pastFails.foreach(law) } @@ -263,7 +384,8 @@ class TypeTest extends AnyFunSuite { def genSubs(depth: Int): Gen[Map[Type.Var, Type]] = { val pair = Gen.zip( NTypeGen.genBound, - NTypeGen.genDepth(depth, Some(NTypeGen.genConst))) + NTypeGen.genDepth(depth, Some(NTypeGen.genConst)) + ) Gen.mapOf(pair) } @@ -308,7 +430,9 @@ class TypeTest extends AnyFunSuite { // now subs1 has keys that can be completely removed, so // after substitution, those keys should be gone val t1 = Type.substituteVar(t, subs1) - assert((Type.freeBoundTyVars(t1 :: Nil).toSet & subs1.keySet) == Set.empty) + assert( + (Type.freeBoundTyVars(t1 :: Nil).toSet & subs1.keySet) == Set.empty + ) } forAll(NTypeGen.genDepth03, genSubs(3))(law _) @@ -320,15 +444,25 @@ class TypeTest extends AnyFunSuite { case Type.ForAll(fas, t) => Type.instantiate(fas.iterator.toMap, t, t2) match { case Some((frees, subs)) => - val t3 = Type.substituteVar(t, subs.iterator.map { case (k, (_, v)) => (k, v)}.toMap) - - val t4 = Type.substituteVar(t3, frees.iterator.map { - case (v1, (_, v2)) => (v1, Type.TyVar(v2)) - }.toMap) - - val t5 = Type.quantify(forallList = frees.iterator.map { - case (_, tup) => tup.swap - }.toList, existList = Nil, t4) + val t3 = Type.substituteVar( + t, + subs.iterator.map { case (k, (_, v)) => (k, v) }.toMap + ) + + val t4 = Type.substituteVar( + t3, + frees.iterator.map { case (v1, (_, v2)) => + (v1, Type.TyVar(v2)) + }.toMap + ) + + val t5 = Type.quantify( + forallList = frees.iterator.map { case (_, tup) => + tup.swap + }.toList, + existList = Nil, + t4 + ) assert(t5.sameAs(t2)) case None => @@ -362,23 +496,40 @@ class TypeTest extends AnyFunSuite { assert(res == None) } - check("forall a. a", "Bosatsu/Predef::Int", - List("a" -> "Bosatsu/Predef::Int")) - check("forall a. a -> a", "Bosatsu/Predef::Int -> Bosatsu/Predef::Int", - List("a" -> "Bosatsu/Predef::Int")) - check("forall a. a -> Bosatsu/Predef::Foo[a]", "Bosatsu/Predef::Int -> Bosatsu/Predef::Foo[Bosatsu/Predef::Int]", - List("a" -> "Bosatsu/Predef::Int")) - check("forall a. Bosatsu/Predef::Option[a]", "Bosatsu/Predef::Option[Bosatsu/Predef::Int]", - List("a" -> "Bosatsu/Predef::Int")) - - check("forall a. a", "forall a. a", - List("a" -> "forall a. a")) + check( + "forall a. a", + "Bosatsu/Predef::Int", + List("a" -> "Bosatsu/Predef::Int") + ) + check( + "forall a. a -> a", + "Bosatsu/Predef::Int -> Bosatsu/Predef::Int", + List("a" -> "Bosatsu/Predef::Int") + ) + check( + "forall a. a -> Bosatsu/Predef::Foo[a]", + "Bosatsu/Predef::Int -> Bosatsu/Predef::Foo[Bosatsu/Predef::Int]", + List("a" -> "Bosatsu/Predef::Int") + ) + check( + "forall a. Bosatsu/Predef::Option[a]", + "Bosatsu/Predef::Option[Bosatsu/Predef::Int]", + List("a" -> "Bosatsu/Predef::Int") + ) - check("forall a, b. a -> b", "forall c. c -> Bosatsu/Predef::Int", - List("b" -> "Bosatsu/Predef::Int")) + check("forall a. a", "forall a. a", List("a" -> "forall a. a")) - check("forall a, b. T::Cont[a, b]", "forall a. T::Cont[a, T::Foo]", - List("b" -> "T::Foo")) + check( + "forall a, b. a -> b", + "forall c. c -> Bosatsu/Predef::Int", + List("b" -> "Bosatsu/Predef::Int") + ) + + check( + "forall a, b. T::Cont[a, b]", + "forall a. T::Cont[a, T::Foo]", + List("b" -> "T::Foo") + ) noSub("forall a, b. T::Cont[a, b]", "forall a: * -> *. T::Cont[a, T::Foo]") noSub("forall a. T::Box[a]", "forall a. T::Box[T::Opt[a]]") @@ -394,7 +545,7 @@ class TypeTest extends AnyFunSuite { } yield NonEmptyList(head, tail) forAll(genArgs, NTypeGen.genDepth03) { (args, res) => - val fnType = Type.Fun(args, res) + val fnType = Type.Fun(args, res) fnType match { case Type.Fun(args1, res1) => assert(args1 == args) @@ -406,22 +557,25 @@ class TypeTest extends AnyFunSuite { } test("Quantification.concat is associative") { - forAll(NTypeGen.genQuant, NTypeGen.genQuant, NTypeGen.genQuant) { (a, b, c) => - assert(a.concat(b).concat(c) == a.concat(b.concat(c))) + forAll(NTypeGen.genQuant, NTypeGen.genQuant, NTypeGen.genQuant) { + (a, b, c) => + assert(a.concat(b).concat(c) == a.concat(b.concat(c))) } } test("Quantification.toLists/fromList identity") { forAll(NTypeGen.genQuant) { q => - assert(Type.Quantification.fromLists(q.forallList, q.existList) == Some(q)) + assert( + Type.Quantification.fromLists(q.forallList, q.existList) == Some(q) + ) } } test("unexists/exists | unforall/forall iso") { forAll(NTypeGen.genDepth03) { - case t@Type.Exists(ps, in) => + case t @ Type.Exists(ps, in) => assert(Type.exists(ps, in) == t) - case t@Type.ForAll(ps, in) => + case t @ Type.ForAll(ps, in) => assert(Type.forAll(ps, in) == t) case _ => () } @@ -475,7 +629,7 @@ class TypeTest extends AnyFunSuite { val consts = allConsts(t :: Nil) t match { - case tyc@TyConst(_) => + case tyc @ TyConst(_) => assert(consts == (tyc :: Nil)) case (TyVar(_) | TyMeta(_)) => assert(consts == Nil) @@ -496,7 +650,8 @@ class TypeTest extends AnyFunSuite { ) expect match { - case None => fail(s"$fn resulted in ${Type.typeParser.render(resTpe)}") + case None => + fail(s"$fn resulted in ${Type.typeParser.render(resTpe)}") case Some(exTpe) => val exT = parse(exTpe) assert(resTpe.sameAs(exT), s"${resTpe}.sameAs($exT) == false") @@ -504,16 +659,26 @@ class TypeTest extends AnyFunSuite { case _ => expect match { case None => succeed - case Some(exTpe) => fail(s"$fn is not SimpleUniversal but expected: $exTpe") + case Some(exTpe) => + fail(s"$fn is not SimpleUniversal but expected: $exTpe") } } } check("forall a. a -> a", Some("forall a. a -> a")) - check("forall a. a -> Foo::Option[a]", Some("forall a. a -> Foo::Option[a]")) + check( + "forall a. a -> Foo::Option[a]", + Some("forall a. a -> Foo::Option[a]") + ) check("forall a. a -> (forall b. b)", Some("forall a, b. a -> b")) - check("forall a. a -> (forall b. Foo::Option[b])", Some("forall a, b. a -> Foo::Option[b]")) + check( + "forall a. a -> (forall b. Foo::Option[b])", + Some("forall a, b. a -> Foo::Option[b]") + ) check("forall a. a -> (forall a. a)", Some("forall a, b. a -> b")) - check("forall a. a -> (forall a, c. a -> c)", Some("forall a, b, c. a -> (b -> c)")) + check( + "forall a. a -> (forall a, c. a -> c)", + Some("forall a, b, c. a -> (b -> c)") + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/set/RelLaws.scala b/core/src/test/scala/org/bykn/bosatsu/set/RelLaws.scala index ec95ae167..d05a167ee 100644 --- a/core/src/test/scala/org/bykn/bosatsu/set/RelLaws.scala +++ b/core/src/test/scala/org/bykn/bosatsu/set/RelLaws.scala @@ -173,7 +173,7 @@ abstract class GenIntersectionRelLaws extends munit.ScalaCheckSuite { laws => forAll { (x: S, a: E) => val res = x <:> lift(a) match { case Super | Same | Intersects => true - case Disjoint | Sub => false + case Disjoint | Sub => false } assertEquals(x.contains(a), res) } @@ -250,15 +250,13 @@ abstract class GenRelLaws extends GenIntersectionRelLaws { laws => val both = sa | sb if ((both <:> sa).isStrictSupertype) { assertEquals(sa <:> sb, Rel.Disjoint) - } - else { + } else { assertEquals(sa <:> sb, Rel.Same) } if ((sa <:> sb).isSupertype) { // if one is a supertype, then they are the same assertEquals(sa <:> sb, Rel.Same) - } - else { + } else { assert(!(sb <:> sa).isSupertype) } } @@ -284,13 +282,19 @@ abstract class GenRelLaws extends GenIntersectionRelLaws { laws => property("distibutive (t1 | t2) & t3 = (t1 & t3) | (t2 & t3) homomorphism") { forAll { (t1: S, t2: S, t3: S, j: E) => - assertEquals(((t1 | t2) & t3).contains(j), ((t1 & t3) | (t2 & t3)).contains(j)) + assertEquals( + ((t1 | t2) & t3).contains(j), + ((t1 & t3) | (t2 & t3)).contains(j) + ) } } property("distibutive (t1 & t2) | t3 = (t1 | t3) & (t2 | t3) homomorphism") { forAll { (t1: S, t2: S, t3: S, j: E) => - assertEquals(((t1 & t2) | t3).contains(j), ((t1 | t3) & (t2 | t3)).contains(j)) + assertEquals( + ((t1 & t2) | t3).contains(j), + ((t1 | t3) & (t2 | t3)).contains(j) + ) } } @@ -323,7 +327,10 @@ abstract class GenRelLaws extends GenIntersectionRelLaws { laws => // This checks the testing mechanisms and support // code using Set, where intersection and union are // trivial -abstract class SetGenRelLaws[A](implicit val arbSet: Arbitrary[Set[A]], val arbElement: Arbitrary[A]) extends GenRelLaws { setgenrellaws => +abstract class SetGenRelLaws[A](implicit + val arbSet: Arbitrary[Set[A]], + val arbElement: Arbitrary[A] +) extends GenRelLaws { setgenrellaws => type S = Set[A] type E = A @@ -341,7 +348,7 @@ abstract class SetGenRelLaws[A](implicit val arbSet: Arbitrary[Set[A]], val arbE def relatable = setgenrellaws.relatable def deunion(a: S): Either[(S, S) => Rel.SuperOrSame, (S, S)] = if (a.size > 1) Right((Set(a.head), a.tail)) - else Left({(s1, s2) => if (a == (s1 | s2)) Rel.Same else Rel.Super }) + else Left({ (s1, s2) => if (a == (s1 | s2)) Rel.Same else Rel.Super }) def cheapUnion(head: S, tail: List[S]): S = tail.foldLeft(head)(_ | _) @@ -352,7 +359,10 @@ abstract class SetGenRelLaws[A](implicit val arbSet: Arbitrary[Set[A]], val arbE property("test unionRelCompare") { forAll { (s1: S, s2: S, s3: S) => - assertEquals(urm.unionRelCompare(s1, s2, s3), relatable.relate(s1, s2 | s3)) + assertEquals( + urm.unionRelCompare(s1, s2, s3), + relatable.relate(s1, s2 | s3) + ) } } } @@ -381,8 +391,9 @@ class ListUnionRelatableTests extends munit.ScalaCheckSuite { val listRel: Relatable[List[Byte]] = Relatable.listUnion[Byte]( _ => false, - {(i1, i2) => if (i1 == i2) i1 :: Nil else Nil}, - { i => Left(_.distinct == (i :: Nil)) }) + { (i1, i2) => if (i1 == i2) i1 :: Nil else Nil }, + { i => Left(_.distinct == (i :: Nil)) } + ) property("listRel agrees with setRel") { forAll { (s1: Set[Byte], s2: Set[Byte]) => @@ -392,7 +403,7 @@ class ListUnionRelatableTests extends munit.ScalaCheckSuite { val listSetRel: Relatable[List[Set[Byte]]] = Relatable.listUnion[Set[Byte]]( _.isEmpty, - {(i1, i2) => + { (i1, i2) => val i = i1 & i2 if (i.isEmpty) Nil else (i :: Nil) }, @@ -404,17 +415,18 @@ class ListUnionRelatableTests extends munit.ScalaCheckSuite { if ((sz >= 2) && ((i.hashCode & 1) == 1)) { val (l, r) = i.toList.splitAt(sz / 2) Right((l.toSet, r.toSet)) - } - else { + } else { // these is a either a single value or empty // set which is >= so the fold results in // a set that is empty or has one value Left { is => is.foldLeft(Set.empty[Byte])(_ | _) == i } } - }) + } + ) def smallList[A: Arbitrary]: Gen[List[A]] = - Gen.geometric(4.0) + Gen + .geometric(4.0) .flatMap(Gen.listOfN(_, Arbitrary.arbitrary[A])) property("listUnion works with Set elements") { @@ -422,10 +434,12 @@ class ListUnionRelatableTests extends munit.ScalaCheckSuite { // this is similar to how we use unionRelMod // in code since each item it itself a set - - forAll(smallList[Set[Byte]], smallList[Set[Byte]]) { (s1: List[Set[Byte]], s2: List[Set[Byte]]) => - assertEquals(listSetRel.relate(s1, s2), - setRel.relate(s1.combineAll, s2.combineAll)) + forAll(smallList[Set[Byte]], smallList[Set[Byte]]) { + (s1: List[Set[Byte]], s2: List[Set[Byte]]) => + assertEquals( + listSetRel.relate(s1, s2), + setRel.relate(s1.combineAll, s2.combineAll) + ) } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/set/SetOpsLaws.scala b/core/src/test/scala/org/bykn/bosatsu/set/SetOpsLaws.scala index 51d5f2908..6385b699e 100644 --- a/core/src/test/scala/org/bykn/bosatsu/set/SetOpsLaws.scala +++ b/core/src/test/scala/org/bykn/bosatsu/set/SetOpsLaws.scala @@ -27,7 +27,9 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { assert(eqA.eqv(a12, a21), s"$a12 != $a21") } - def differenceIsIdempotent(a: A, b: A, eqAs: Eq[List[A]])(implicit loc: munit.Location) = { + def differenceIsIdempotent(a: A, b: A, eqAs: Eq[List[A]])(implicit + loc: munit.Location + ) = { val c = unifyUnion(difference(a, b)) val c1 = unifyUnion(differenceAll(c, b :: Nil)) assert(eqAs.eqv(c, c1), s"c = $c\n\nc1 = $c1") @@ -42,7 +44,6 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { } } - test("intersection is commutative") { forAll(genItem, genItem, eqUnion)(intersectionIsCommutative(_, _, _)) } @@ -63,15 +64,19 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { } test("intersection is associative") { - forAll(genItem, genItem, genItem, eqUnion)(intersectionIsAssociative(_, _, _, _)) + forAll(genItem, genItem, genItem, eqUnion)( + intersectionIsAssociative(_, _, _, _) + ) } test("unify union makes size <= input") { forAll(genUnion) { (ps: List[A]) => val unified = unifyUnion(ps) - assert(ps.size >= unified.size, - s"input(${ps.size}): $ps\n\nunified(${unified.size}) = $unified\n\n") + assert( + ps.size >= unified.size, + s"input(${ps.size}): $ps\n\nunified(${unified.size}) = $unified\n\n" + ) } } @@ -112,14 +117,14 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { } test("if a n b = 0 then a - b = a") { - // difference is an upper bound, so this is not true - // although we wish it were - /* + // difference is an upper bound, so this is not true + // although we wish it were + /* if (diff.map(_.normalize).distinct == p1.normalize :: Nil) { // intersection is 0 assert(inter == Nil) } - */ + */ forAll(genItem, genItem, eqUnion)(emptyIntersectionMeansDiffIdent(_, _, _)) } @@ -146,8 +151,7 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { assert(intSub) assert(diffSub) assertEquals(intSub, diffSub) - } - else { + } else { // we can have false positives of intSub // when we have a sampling equality assertEquals(diffSub, false) @@ -191,8 +195,7 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { val missing = missingBranches(wild :: Nil, patsGood) if (missing.nonEmpty) { unreachableBranches(patsGood ::: missing).isEmpty - } - else true + } else true } } } @@ -212,8 +215,7 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { // should be in that case, if (a - b) = a, then // clearly we expect (a n c) == (a n c) - (b n c) // so, b n c has to not intersect with a, but it might - } - else if (isTop(a) && intBC.isEmpty) { + } else if (isTop(a) && intBC.isEmpty) { // in patterns, we "cast" ill-typed comparisions // since we can don't care about cases that don't // type-check. But this can make this law fail: @@ -223,8 +225,7 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { // but (_ n c) = c, and b n c = 0 val leftEqC = differenceAll(unifyUnion(left), c :: Nil).isEmpty assert((left == Nil) || leftEqC) - } - else { + } else { val intAC = intersection(a, c) val right = differenceAll(intAC, intBC) @@ -233,10 +234,13 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { val leftu = unifyUnion(left) if (leftu == unifyUnion(intAC)) { () - } - else { + } else { val rightu = unifyUnion(right) - assertEquals(leftu, rightu, s"diffAB = $diffab, intAC = $intAC, intBC = $intBC") + assertEquals( + leftu, + rightu, + s"diffAB = $diffab, intAC = $intAC, intBC = $intBC" + ) } } } @@ -248,11 +252,13 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { test("(a - b) n c = (a n c) - (b n c)") { forAll(genItem, genItem, genItem)(diffIntersectionLaw(_, _, _)) } - */ + */ def missingBranchesIfAddedRegressions: List[List[A]] = Nil - test("missing branches, if added are total and none of the missing are unreachable") { + test( + "missing branches, if added are total and none of the missing are unreachable" + ) { def law(top: A, pats: List[A]) = { @@ -260,9 +266,12 @@ abstract class SetOpsLaws[A] extends munit.ScalaCheckSuite { val rest1 = missingBranches(top :: Nil, pats ::: rest) if (rest1.isEmpty) { val unreach = unreachableBranches(pats ::: rest) - assertEquals(unreach.filter(rest.toSet), Nil, s"\n\nrest = ${rest}\n\ninit: ${pats}") - } - else { + assertEquals( + unreach.filter(rest.toSet), + Nil, + s"\n\nrest = ${rest}\n\ninit: ${pats}" + ) + } else { fail(s"after adding ${rest} we still need ${rest1}") } } @@ -350,12 +359,12 @@ class DistinctSetOpsTest extends SetOpsLaws[Byte] { class FiniteSetOpsTest extends SetOpsLaws[Set[Int]] { val setOps: SetOps[Set[Int]] = SetOps.fromFinite(0 to 9) - val genItem: Gen[Set[Int]] = { + val genItem: Gen[Set[Int]] = { // don't generate empty sets, items that are empty aren't lawful // the ways the laws are written val gi = Gen.choose(0, 9) Gen.zip(gi, Gen.listOf(gi)).map { case (h, t) => - t.toSet + h + t.toSet + h } } @@ -368,9 +377,11 @@ class FiniteSetOpsTest extends SetOpsLaws[Set[Int]] { class IMapSetOpsTest extends SetOpsLaws[Byte] { val setOps: SetOps[Byte] = - SetOps.imap(SetOps.distinct[Byte], - { (b: Byte) => (b ^ 0xFF).toByte }, - { (b: Byte) => (b ^ 0xFF).toByte }) + SetOps.imap( + SetOps.distinct[Byte], + { (b: Byte) => (b ^ 0xff).toByte }, + { (b: Byte) => (b ^ 0xff).toByte } + ) val genItem: Gen[Byte] = Gen.choose(Byte.MinValue, Byte.MaxValue) @@ -381,7 +392,8 @@ class IMapSetOpsTest extends SetOpsLaws[Byte] { } class ProductSetOpsTest extends SetOpsLaws[(Boolean, Boolean)] { - val setOps: SetOps[(Boolean, Boolean)] = SetOps.product(SetOps.distinct[Boolean], SetOps.distinct[Boolean]) + val setOps: SetOps[(Boolean, Boolean)] = + SetOps.product(SetOps.distinct[Boolean], SetOps.distinct[Boolean]) val genItem: Gen[(Boolean, Boolean)] = Gen.oneOf((false, false), (false, true), (true, false), (true, true)) @@ -404,7 +416,6 @@ class UnitSetOpsTest extends SetOpsLaws[Unit] { }) } - case class Predicate[A](toFn: A => Boolean) { self => def apply(a: A): Boolean = toFn(a) def &&(that: Predicate[A]): Predicate[A] = @@ -428,7 +439,6 @@ object Predicate { Arbitrary(genPred[A]) } - class SetOpsTests extends munit.ScalaCheckSuite { override def scalaCheckTestParameters = @@ -437,35 +447,41 @@ class SetOpsTests extends munit.ScalaCheckSuite { .withMaxDiscardRatio(10) test("allPerms is correct") { - forAll(Gen.choose(0, 6).flatMap(Gen.listOfN(_, Arbitrary.arbitrary[Int]))) { is0 => - // make everything distinct - val is = is0.zipWithIndex - val perms = SetOps.allPerms(is) + forAll(Gen.choose(0, 6).flatMap(Gen.listOfN(_, Arbitrary.arbitrary[Int]))) { + is0 => + // make everything distinct + val is = is0.zipWithIndex + val perms = SetOps.allPerms(is) - def fact(i: Int, acc: Int): Int = - if (i <= 1) acc - else fact(i - 1, i * acc) + def fact(i: Int, acc: Int): Int = + if (i <= 1) acc + else fact(i - 1, i * acc) - assertEquals(perms.length, fact(is0.size, 1)) + assertEquals(perms.length, fact(is0.size, 1)) - perms.foreach { p => - assertEquals(p.sorted, is.sorted) - } - val pi = perms.zipWithIndex + perms.foreach { p => + assertEquals(p.sorted, is.sorted) + } + val pi = perms.zipWithIndex - for { - (p1, i1) <- pi - (p2, i2) <- pi - } assert((i1 >= i2 || (p1 != p2))) + for { + (p1, i1) <- pi + (p2, i2) <- pi + } assert((i1 >= i2 || (p1 != p2))) } } - test("greedySearch finds the optimal path if lookahead is greater than size") { + test( + "greedySearch finds the optimal path if lookahead is greater than size" + ) { // we need a non-commutative operation to test this // use 2x2 matrix multiplication - def mult(left: Vector[Vector[Double]], right: Vector[Vector[Double]]): Vector[Vector[Double]] = { + def mult( + left: Vector[Vector[Double]], + right: Vector[Vector[Double]] + ): Vector[Vector[Double]] = { def dot(v1: Vector[Double], v2: Vector[Double]) = - v1.iterator.zip(v2.iterator).map { case (a, b) => a*b }.sum + v1.iterator.zip(v2.iterator).map { case (a, b) => a * b }.sum def trans(v1: Vector[Vector[Double]]) = Vector(Vector(v1(0)(0), v1(1)(0)), Vector(v1(0)(1), v1(1)(1))) @@ -477,12 +493,13 @@ class SetOpsTests extends munit.ScalaCheckSuite { (c, ci) <- trans(right).zipWithIndex } yield ((ri, ci), dot(r, c)) - data.foldLeft(res) { case (v, ((r, c), d)) => v.updated(r, v(r).updated(c, d)) } + data.foldLeft(res) { case (v, ((r, c), d)) => + v.updated(r, v(r).updated(c, d)) + } } def norm(left: Vector[Vector[Double]]): Double = - left.map(_.map { x => x*x }.sum).sum - + left.map(_.map { x => x * x }.sum).sum val genMat: Gen[Vector[Vector[Double]]] = { val elem = Gen.choose(-1.0, 1.0) @@ -496,7 +513,9 @@ class SetOpsTests extends munit.ScalaCheckSuite { forAll(genMat, Gen.listOfN(5, genMat)) { (v0, prods) => val ord = Ordering.by[Vector[Vector[Double]], Double](norm) - val res = SetOps.greedySearch(5, v0, prods)({(v, ps) => ps.foldLeft(v)(mult(_, _))})(ord) + val res = SetOps.greedySearch(5, v0, prods)({ (v, ps) => + ps.foldLeft(v)(mult(_, _)) + })(ord) val normRes = norm(res) val naive = norm(prods.foldLeft(v0)(mult(_, _))) assert(normRes <= naive) @@ -513,7 +532,10 @@ class SetOpsTests extends munit.ScalaCheckSuite { val bb = pb(b) val bc = pc(b) if (!right(b)) { - assert(!left(b), s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}") + assert( + !left(b), + s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}" + ) } } } @@ -529,7 +551,10 @@ class SetOpsTests extends munit.ScalaCheckSuite { val bb = pb(b) val bc = pc(b) if (left(b)) { - assert(right(b), s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}") + assert( + right(b), + s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}" + ) } } } @@ -544,7 +569,11 @@ class SetOpsTests extends munit.ScalaCheckSuite { val ba = pa(b) val bb = pb(b) val bc = pc(b) - assertEquals(left(b), right(b), s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}") + assertEquals( + left(b), + right(b), + s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}" + ) } } } @@ -559,18 +588,28 @@ class SetOpsTests extends munit.ScalaCheckSuite { val bb = pb(b) val bc = pc(b) if (!right(b)) { - assert(!left(b), s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}") + assert( + !left(b), + s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}" + ) } } } } test("A1 x B1 - A2 x B2 = (A1 n A2)x(B1 - B2) u (A1 - A2)xB1") { - forAll { (a1: Predicate[Byte], a2: Predicate[Byte], b1: Predicate[Byte], b2: Predicate[Byte], checks: List[(Byte, Byte)]) => - val left = a1.product(b1) - a2.product(b2) - val right = (a1 && a2).product(b1 - b2) || (a1 - a2).product(b1) - checks.foreach { ab => - assertEquals(left(ab), right(ab)) - } + forAll { + ( + a1: Predicate[Byte], + a2: Predicate[Byte], + b1: Predicate[Byte], + b2: Predicate[Byte], + checks: List[(Byte, Byte)] + ) => + val left = a1.product(b1) - a2.product(b2) + val right = (a1 && a2).product(b1 - b2) || (a1 - a2).product(b1) + checks.foreach { ab => + assertEquals(left(ab), right(ab)) + } } } } diff --git a/jsapi/src/main/scala/org/bykn/bosatsu/jsapi/JsApi.scala b/jsapi/src/main/scala/org/bykn/bosatsu/jsapi/JsApi.scala index 4d8ce6020..261792d2a 100644 --- a/jsapi/src/main/scala/org/bykn/bosatsu/jsapi/JsApi.scala +++ b/jsapi/src/main/scala/org/bykn/bosatsu/jsapi/JsApi.scala @@ -25,12 +25,14 @@ object JsApi { class EvalSuccess(val result: js.Any) extends js.Object - /** - * mainPackage can be null, in which case we find the package - * in mainFile - */ + /** mainPackage can be null, in which case we find the package in mainFile + */ @JSExport - def evaluate(mainPackage: String, mainFile: String, files: js.Dictionary[String]): EvalSuccess | Error = { + def evaluate( + mainPackage: String, + mainFile: String, + files: js.Dictionary[String] + ): EvalSuccess | Error = { val baseArgs = "--package_root" :: "" :: "--color" :: "html" :: Nil val main = if (mainPackage != null) "--main" :: mainPackage :: baseArgs @@ -39,8 +41,9 @@ object JsApi { case Left(err) => new Error(s"error: ${err.getMessage}") case Right(module.Output.EvaluationResult(_, tpe, resDoc)) => - val tDoc = rankn.Type.fullyResolvedDocument.document(tpe) - val doc = resDoc.value + (Doc.lineOrEmpty + Doc.text(": ") + tDoc).nested(4) + val tDoc = rankn.Type.fullyResolvedDocument.document(tpe) + val doc = + resDoc.value + (Doc.lineOrEmpty + Doc.text(": ") + tDoc).nested(4) new EvalSuccess(doc.render(80)) case Right(other) => new Error(s"internal error. got unexpected result: $other") @@ -49,16 +52,16 @@ object JsApi { def jsonToAny(j: Json): js.Any = j match { - case Json.JString(s) => s + case Json.JString(s) => s case Json.JNumberStr(str) => // javascript only really has doubles try str.toDouble catch { case (_: NumberFormatException) => str } - case Json.JBool.True => true + case Json.JBool.True => true case Json.JBool.False => false - case Json.JNull => null + case Json.JNull => null case Json.JArray(items) => val ary = new js.Array[js.Any](items.size) items.iterator.zipWithIndex.foreach { case (j, idx) => @@ -66,23 +69,28 @@ object JsApi { } ary case Json.JObject(kvs) => - js.Dictionary[js.Any]( - kvs.map { case (k, v) => - (k, jsonToAny(v)) - } :_*) + js.Dictionary[js.Any](kvs.map { case (k, v) => + (k, jsonToAny(v)) + }: _*) } - /** - * mainPackage can be null, in which case we find the package - * in mainFile - */ + /** mainPackage can be null, in which case we find the package in mainFile + */ @JSExport - def evaluateJson(mainPackage: String, mainFile: String, files: js.Dictionary[String]): EvalSuccess | Error = { + def evaluateJson( + mainPackage: String, + mainFile: String, + files: js.Dictionary[String] + ): EvalSuccess | Error = { val baseArgs = "--package_root" :: "" :: "--color" :: "html" :: Nil val main = if (mainPackage != null) "--main" :: mainPackage :: baseArgs else "--main_file" :: mainFile :: baseArgs - module.runWith(files)("json" :: "write" :: "--output" :: "" :: main ::: makeInputArgs(files.keys)) match { + module.runWith(files)( + "json" :: "write" :: "--output" :: "" :: main ::: makeInputArgs( + files.keys + ) + ) match { case Left(err) => new Error(s"error: ${err.getMessage}") case Right(module.Output.JsonOutput(json, _)) => diff --git a/jsui/src/main/scala/org/bykn/bosatsu/jsui/Action.scala b/jsui/src/main/scala/org/bykn/bosatsu/jsui/Action.scala index f9909fdde..99af4b9a3 100644 --- a/jsui/src/main/scala/org/bykn/bosatsu/jsui/Action.scala +++ b/jsui/src/main/scala/org/bykn/bosatsu/jsui/Action.scala @@ -13,5 +13,6 @@ object Action { } case class CodeEntered(text: String) extends Action case class Run(cmd: Cmd) extends Action - case class CmdCompleted(result: String, duration: Duration, cmd: Cmd) extends Action -} \ No newline at end of file + case class CmdCompleted(result: String, duration: Duration, cmd: Cmd) + extends Action +} diff --git a/jsui/src/main/scala/org/bykn/bosatsu/jsui/App.scala b/jsui/src/main/scala/org/bykn/bosatsu/jsui/App.scala index 58cce3aff..964dd65ac 100644 --- a/jsui/src/main/scala/org/bykn/bosatsu/jsui/App.scala +++ b/jsui/src/main/scala/org/bykn/bosatsu/jsui/App.scala @@ -4,7 +4,7 @@ import cats.effect.IO import cats.effect.kernel.Resource class App extends ff4s.App[IO, State, Action] with View { - def store: Resource[IO,ff4s.Store[IO,State,Action]] = Store.value + def store: Resource[IO, ff4s.Store[IO, State, Action]] = Store.value } -object MainApp extends ff4s.IOEntryPoint(new App) \ No newline at end of file +object MainApp extends ff4s.IOEntryPoint(new App) diff --git a/jsui/src/main/scala/org/bykn/bosatsu/jsui/Store.scala b/jsui/src/main/scala/org/bykn/bosatsu/jsui/Store.scala index 115b65447..3a5b86934 100644 --- a/jsui/src/main/scala/org/bykn/bosatsu/jsui/Store.scala +++ b/jsui/src/main/scala/org/bykn/bosatsu/jsui/Store.scala @@ -8,21 +8,30 @@ import org.scalajs.dom.window.localStorage import Action.Cmd object Store { - val memoryMain = new MemoryMain[Either[Throwable, *], String](_.split("/", -1).toList) + val memoryMain = + new MemoryMain[Either[Throwable, *], String](_.split("/", -1).toList) type HandlerFn = memoryMain.Output => String def cmdHandler(cmd: Cmd): (List[String], HandlerFn) = cmd match { case Cmd.Eval => val args = List( - "eval", "--input", "root/WebDemo", "--package_root", "root", - "--main_file", "root/WebDemo", "--color", "html" + "eval", + "--input", + "root/WebDemo", + "--package_root", + "root", + "--main_file", + "root/WebDemo", + "--color", + "html" ) - + val handler: HandlerFn = { case memoryMain.Output.EvaluationResult(_, tpe, resDoc) => val tDoc = rankn.Type.fullyResolvedDocument.document(tpe) - val doc = resDoc.value + (Doc.lineOrEmpty + Doc.text(": ") + tDoc).nested(4) + val doc = + resDoc.value + (Doc.lineOrEmpty + Doc.text(": ") + tDoc).nested(4) doc.render(80) case other => s"internal error. got unexpected result: $other" @@ -30,8 +39,15 @@ object Store { (args, handler) case Cmd.Test => val args = List( - "test", "--input", "root/WebDemo", "--package_root", "root", - "--test_file", "root/WebDemo", "--color", "html" + "test", + "--input", + "root/WebDemo", + "--package_root", + "root", + "--test_file", + "root/WebDemo", + "--color", + "html" ) val handler: HandlerFn = { case memoryMain.Output.TestOutput(resMap, color) => @@ -43,7 +59,13 @@ object Store { (args, handler) case Cmd.Show => val args = List( - "show", "--input", "root/WebDemo", "--package_root", "root", "--color", "html" + "show", + "--input", + "root/WebDemo", + "--package_root", + "root", + "--color", + "html" ) val handler: HandlerFn = { case memoryMain.Output.ShowOutput(packs, ifaces, _) => @@ -74,7 +96,7 @@ object Store { case Left(err) => memoryMain.mainExceptionToString(err) match { case Some(e) => e - case None => s"unknown error: $err" + case None => s"unknown error: $err" } } @@ -91,29 +113,29 @@ object Store { def initialState: IO[State] = IO(localStorage.getItem("state")).flatMap { init => if (init == null) IO.pure(State.Init) - else (State.stringToState(init) match { - case Right(s) => IO.pure(s) - case Left(err) => - IO.println(s"could not deserialize:\n\n$init\n\n$err") - .as(State.Init) - }) + else + (State.stringToState(init) match { + case Right(s) => IO.pure(s) + case Left(err) => + IO.println(s"could not deserialize:\n\n$init\n\n$err") + .as(State.Init) + }) } val value: Resource[IO, ff4s.Store[IO, State, Action]] = for { - init <- Resource.liftK(initialState) - store <- ff4s.Store[IO, State, Action](init) { store => - { - case Action.CodeEntered(text) => - { + init <- Resource.liftK(initialState) + store <- ff4s.Store[IO, State, Action](init) { store => + { + case Action.CodeEntered(text) => { case State.Init | State.WithText(_) => (State.WithText(text), None) - case c @ State.Compiling(_) => (c, None) - case comp @ State.Compiled(_, _, _) => (comp.copy(editorText = text), None) + case c @ State.Compiling(_) => (c, None) + case comp @ State.Compiled(_, _, _) => + (comp.copy(editorText = text), None) } - case Action.Run(cmd) => - { - case State.Init => (State.Init, None) + case Action.Run(cmd) => { + case State.Init => (State.Init, None) case c @ State.Compiling(_) => (c, None) case ht: State.HasText => val action = @@ -122,13 +144,14 @@ object Store { start <- IO.monotonic output <- run(cmd, ht.editorText) end <- IO.monotonic - _ <- store.dispatch(Action.CmdCompleted(output, end - start, cmd)) + _ <- store.dispatch( + Action.CmdCompleted(output, end - start, cmd) + ) } yield () (State.Compiling(ht), Some(action)) } - case Action.CmdCompleted(result, dur, _) => - { + case Action.CmdCompleted(result, dur, _) => { case State.Compiling(ht) => val next = State.Compiled(ht.editorText, result, dur) (next, Some(stateSetter(next))) @@ -137,7 +160,7 @@ object Store { println(s"unexpected Complete: $result => $unexpected") (unexpected, None) } - } - } - } yield store -} \ No newline at end of file + } + } + } yield store +} From 9998b1eaa4f6e7c3c85e1318f64887abf8de027a Mon Sep 17 00:00:00 2001 From: Scala Steward Date: Sun, 25 Feb 2024 18:49:34 +0000 Subject: [PATCH 3/3] Add 'Reformat with scalafmt 3.8.0' to .git-blame-ignore-revs --- .git-blame-ignore-revs | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .git-blame-ignore-revs diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..9847ee5c7 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Scala Steward: Reformat with scalafmt 3.8.0 +12bd519e9e132dd6ce125b4af67f5af9e4eda59e