From eac197a99528a30d8cb403e68238d194aea55ff6 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Sun, 4 Nov 2018 02:56:06 +0000 Subject: [PATCH 1/8] Add TotalityCheck and tests --- .../org/bykn/bosatsu/TotalityCheck.scala | 366 ++++++++++++++++++ .../org/bykn/bosatsu/rankn/DefinedType.scala | 6 + .../scala/org/bykn/bosatsu/rankn/Type.scala | 7 + .../org/bykn/bosatsu/rankn/TypeEnv.scala | 8 +- .../scala/org/bykn/bosatsu/ParserTest.scala | 22 +- .../scala/org/bykn/bosatsu/TotalityTest.scala | 143 +++++++ 6 files changed, 542 insertions(+), 10 deletions(-) create mode 100644 core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala create mode 100644 core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala diff --git a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala new file mode 100644 index 000000000..3d296a705 --- /dev/null +++ b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala @@ -0,0 +1,366 @@ +package org.bykn.bosatsu + +import cats.{Monad, Applicative} +import cats.data.NonEmptyList +import cats.implicits._ + +import rankn.{Type, TypeEnv} +import Pattern._ + +object TotalityCheck { + type Cons = (PackageName, ConstructorName) + type Res[+A] = Either[NonEmptyList[Error], A] + type Patterns = List[Pattern[Cons, Type]] + + sealed abstract class Error + case class ArityMismatch(cons: Cons, in: Pattern[Cons, Type], env: TypeEnv, expected: Int, found: Int) extends Error + case class UnknownConstructor(cons: Cons, in: Pattern[Cons, Type], env: TypeEnv) extends Error + case class UntypedPattern(pat: Pattern[Cons, Type], env: TypeEnv) extends Error +} + +case class TotalityCheck(inEnv: TypeEnv) { + import TotalityCheck._ + + /** + * in the given type environment, return + * a list of matches that would make the current set of matches total + * + * Note, a law here is that: + * missingBranches(te, t, branches).flatMap { ms => + * assert(missingBranches(te, t, branches ::: ms).isEmpty) + * } + */ + def missingBranches(branches: Patterns): Res[Patterns] = { + def step(patMiss: (Patterns, Patterns)): Res[Either[(Patterns, Patterns), Patterns]] = { + val (branches, missing0) = patMiss + branches match { + case Nil => + Right(Right(missing0)) + case h :: tail => + difference(missing0, h) + .map { newMissing => + Left((tail, newMissing)) + } + } + } + + Monad[Res].tailRecM((branches, List(WildCard): Patterns))(step _) + } + + def isTotal(branches: Patterns): Res[Boolean] = + missingBranches(branches).map(_.isEmpty) + + /** + * This is like a non-symmetric set difference, where we are removing the right from the left + */ + def difference(left: Patterns, right: Pattern[Cons, Type]): Res[Patterns] = + left.traverse(difference0(_, right)).map(_.flatten) + + @annotation.tailrec + private def matchesEmpty(lp: ListPat[Cons, Type]): Boolean = + lp.parts match { + case Nil => true + case Left(_) :: tail => matchesEmpty(ListPat(tail)) + case Right(_) :: _ => false + } + + def difference0(left: Pattern[Cons, Type], right: Pattern[Cons, Type]): Res[Patterns] = { + isTotal(right).flatMap { + case true => Right(Nil): Res[Patterns] + case false => + (left, right) match { + case (WildCard | Var(_), Literal(_)) => + // the left is infinite, and the right is just one value + Right(left :: Nil) + case (WildCard | Var(_), lp@ListPat(_)) => + // _ is the same as [*_] for well typed expressions + difference0(ListPat(Left(None) :: Nil), lp) + case (ListPat(lp), rightList@ListPat(rp)) => + (lp, rp) match { + case (Nil, Nil) => + // total overlap + Right(Nil) + case (Nil, Right(_) :: _) => + // a list of 1 or more, can't match less + Right(left :: Nil) + case (Nil, Left(_) :: tail) => + // we can have zero or more, 1 or more clearly can't match: + // if the tail can match 0, we anhilate, otherwise not + if (matchesEmpty(ListPat(tail))) Right(Nil) + else Right(left :: Nil) + case (Right(_) :: _, Nil) => + // left has at least one + Right(left :: Nil) + case (Right(lhead) :: ltail, Right(rhead) :: rtail) => + // we use productDifference here + productDifference((lhead, rhead) :: (ListPat(ltail), ListPat(rtail)) :: Nil) + .map { listOfList => + listOfList.map { + case h :: ListPat(tail) :: Nil => + ListPat(Right(h) :: tail) + case other => + sys.error(s"expected exactly two items: $other") + } + } + case (Left(_) :: tail, Nil) => + // if tail matches empty, then we can only match 1 or more + // else, these are disjoint + if (matchesEmpty(ListPat(tail))) + Right(ListPat(Right(WildCard) :: lp) :: Nil) + else Right(left :: Nil) + case (Left(_) :: tail, Right(_) :: _) => + // The right hand side can't match a zero length list + val zero = ListPat(tail) + val oneOrMore = ListPat(Right(WildCard) :: lp) + difference0(oneOrMore, right) + .map(zero :: _) + case (_, Left(_) :: rtail) if matchesEmpty(ListPat(rtail)) => + // this is a total match + Right(Nil) + case (_, Left(_) :: rtail) => + // In this branch, the right cannot match + // the empty list, but the left side can + // we could in principle match a finite + // list from either direction, so we reverse + // and try again + difference0(ListPat(lp.reverse), ListPat(rp.reverse)) + .map(_.map { + case ListPat(diff) => ListPat(diff.reverse) + case other => sys.error(s"unreachable: list patterns can't difference to non-list: $other") + }) + } + case (WildCard | Var(_), PositionalStruct(nm, ps)) => + inEnv.definedTypeFor(nm) match { + case None => Left(NonEmptyList.of(UnknownConstructor(nm, right, inEnv))) + case Some(dt) => + dt.constructors.traverse { + case (c, params, _) if (dt.packageName, c) == nm => + /* + * At each position we compute the difference with _ + * then make: + * Struct(d1, _, _), Struct(_, d2, _), ... + */ + def poke[M[_]: Applicative, A](items: List[A])(fn: A => M[List[A]]): M[List[List[A]]] = + items match { + case Nil => Applicative[M].pure(Nil) + case h :: tail => + val ptail = poke(tail)(fn) + val head = fn(h) + (head, ptail).mapN { (heads, tails) => + val t2 = tails.map(h :: _) + val h1 = heads.map(_ :: tail) + h1 ::: t2 + } + } + + // for this one, we need to compute the difference for each: + poke(ps) { p => difference0(WildCard, p) } + .map(_.map(PositionalStruct(nm, _))) + + case (c, params, _) => + // TODO, this could be smarter + // we need to learn how to deal with typed generics + def argToPat(t: (ParamName, Type)): Pattern[Cons, Type] = + if (Type.hasNoVars(t._2)) Annotation(WildCard, t._2) + else WildCard + + Right(List(PositionalStruct((dt.packageName, c), params.map(argToPat)))) + } + .map(_.flatten) + } + case (llit@Literal(l), Literal(r)) => + if (l == r) Right(Nil): Res[Patterns] + else Right(llit :: Nil): Res[Patterns] + case (PositionalStruct(ln, lp), PositionalStruct(rn, rp)) if ln == rn => + // we have two matching structs + val arityMatch = + checkArity(ln, lp.size, left) + .product(checkArity(rn, rp.size, right)) + .as(()) + productDifference(lp zip rp).map { pats => + pats.map(PositionalStruct(ln, _)) + } + case _ => + // There is no overlap + Right(left :: Nil): Res[Patterns] + } + } + } + + def intersection( + left: Pattern[Cons, Type], + right: Pattern[Cons, Type]): Res[List[Pattern[Cons, Type]]] = + (left, right) match { + case (WildCard | Var(_), v) => Right(List(v)) + case (v, WildCard | Var(_)) => Right(List(v)) + case (Annotation(p, _), t) => intersection(p, t) + case (t, Annotation(p, _)) => intersection(t, p) + case (Literal(a), Literal(b)) => + if (a == b) Right(List(left)) + else Right(Nil) + case (Literal(_), _) => Right(Nil) + case (_, Literal(_)) => Right(Nil) + case (ListPat(leftL), ListPat(rightL)) => + (leftL, rightL) match { + case (Nil, Nil) => Right(List(left)) + case (Nil, Right(_) :: _) => Right(Nil) + case (_, Left(_) :: tail) if matchesEmpty(ListPat(tail)) => Right(List(left)) + case (Nil, Left(_) :: _) => Right(List(left)) + case (Right(_) :: _, Nil) => Right(Nil) + case (Right(lh) :: lt, Right(rh) :: rt) => + intersection(lh, rh).flatMap { + case Nil => Right(Nil) + case nonEmpty => + intersection(ListPat(lt), ListPat(rt)) + .map(_.flatMap { + case ListPat(ts) => nonEmpty.map { h => ListPat(Right(h) :: ts) } + case other => sys.error(s"unreachable: list patterns can't intersect to non-list: $other") + }) + } + case (Right(lh) :: lt, Left(rh) :: rt) => + val zero = ListPat(rt) + val oneOrMore = ListPat(Right(WildCard) :: rightL) + // a n (b0 + b1) = (a n b0) + (a n b1) + for { + withZ <- intersection(left, zero) + with0 <- intersection(left, oneOrMore) + } yield withZ ::: with0 + case (Left(_) :: lt, Left(_) :: rt) => + intersection(ListPat(lt), ListPat(rt)) + .map(_.map { + case ListPat(tail) => ListPat(Left(None) :: tail) + case other => sys.error(s"unreachable: list patterns can't intersect to non-list: $other") + }) + case (_, _) => + // intersection is symmetric + intersection(right, left) + } + case (ListPat(_), _) => Right(Nil) + case (_, ListPat(_)) => Right(Nil) + case (PositionalStruct(ln, lps), PositionalStruct(rn, rps)) => + if (ln == rn) { + val check = for { + _ <- checkArity(ln, lps.size, left) + _ <- checkArity(rn, rps.size, right) + } yield () + + type ResList[A] = Res[List[A]] + implicit val app = Applicative[Res].compose(Applicative[List]) + val parts = check.flatMap { _ => + lps.zip(rps).traverse[ResList, Pattern[Cons, Type]] { + case (l, r) => intersection(l, r) + } + } + + parts.map(_.map(PositionalStruct(ln, _))) + } + else Right(Nil) + } + + /** + * There the list is a tuple or product pattern + * the left and right should be the same size and the result will be a list of lists + * with the inner having the same size + */ + def productDifference( + zip: List[(Pattern[Cons, Type], Pattern[Cons, Type])] + ): Res[List[List[Pattern[Cons, Type]]]] = + /* + * (Left(_), _) -- (Right(_), Right(_)) = (Left(_), _) + * (Left(_), _) -- (Left(_), Right(_)) = (Left(_), Left(_)) + * + * (Left(_), _, _) -- (Left(_), Right(_), Right(_)) = (L, L, R), (L, R, L), (L, R, R) + * + * (Left(_), _) -- (Left(Right(_)), Right(_)) = (L(L(_)), _), (L(R), L(_)) + * + * This seems to be difference of a product of sets. The formula for this + * seems to be: + * + * (a0 x a1) - (b0 x b1) = (a0 - b0) x a1 + (a0 n b0) x (a1 - b1) + */ + zip match { + case Nil => Right(Nil) // complete match + case (lh, rh) :: tail => + type Result = Res[List[List[Pattern[Cons, Type]]]] + val headDiff: Result = + difference0(lh, rh).map(_.map(_ :: tail.map(_._1))) + + val tailDiff: Result = + intersection(lh, rh).flatMap { + case Nil => + // we don't need to recurse on the rest + Right(Nil) + case nonEmpty => + productDifference(tail).map { pats => + nonEmpty.flatMap { intr => + pats.map(intr :: _) + } + } + } + + (headDiff, tailDiff).mapN(_ ::: _) + } + + /** + * Constructors must match all items to be legal + */ + private def checkArity(nm: Cons, size: Int, pat: Pattern[Cons, Type]): Res[Unit] = + inEnv.typeConstructors.get(nm) match { + case None => Left(NonEmptyList.of(UnknownConstructor(nm, pat, inEnv))) + case Some((_, params, _)) => + val cmp = params.lengthCompare(size) + if (cmp == 0) Right(()) + else Left(NonEmptyList.of(ArityMismatch(nm, pat, inEnv, size, params.size))) + } + + // def typeOf(p: Pattern[Cons, Type]): Res[Type] = { + // def err = Left(NonEmptyList.of(UntypedPattern(p, inEnv))) + // p match { + // case Annotation(_, t) => Right(t) + // case Literal(lit) => Right(Type.getTypeOf(lit)) + // case WildCard | Var(_) => err + // case ListPat(pats) => + // // we assume we are well typed, so the first right hand side tells us + // val listTypeParam = + // pats.foldLeft(err: Res[Type]) { + // case (e, Right(p)) => + // typeOf(p) match { + // case Left(_) => e + // case right => right + // } + // case (e, _) => e + // } + + // listTypeParam.map { a => Type.TyApply(Type.ListType, a) } + // case PositionalStruct(nm, _) => + // inEnv.definedTypeFor(nm) match { + // case None => + // Left(NonEmptyList.of(UnknownConstructor(nm, p, inEnv))) + // case Some(dt) => + // Right(dt.fullType) + // } + // } + // } + + /** + * Can a given pattern match everything for a the current type + */ + private def isTotal(p: Pattern[Cons, Type]): Res[Boolean] = + p match { + case Pattern.WildCard | Pattern.Var(_) => Right(true) + case Pattern.Literal(_) => Right(false) // literals are not total + case Pattern.ListPat(_) => Right(false) // empty list pattern matching *COULD* be total, if we understood Void + case Pattern.Annotation(p, _) => isTotal(p) + case Pattern.PositionalStruct(name, params) => + // This is total if the struct has a single constructor AND each of the patterns is total + inEnv.definedTypeFor(name) match { + case None => + Left(NonEmptyList.of(UnknownConstructor(name, p, inEnv))) + case Some(dt) => + if (dt.isStruct) params.forallM(isTotal) + else Right(false) + } + } + + +} diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala index bdebd5424..4c16afa60 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala @@ -27,6 +27,12 @@ case class DefinedType( def toOpaque: DefinedType = copy(constructors = Nil) + + /** + * This may be a ForAll type if there are typeParams + */ + def fullType: Type = + Type.forAll(typeParams, toTypeTyConst) } object DefinedType { diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala index 805809ec7..37f9f8dfc 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala @@ -30,6 +30,13 @@ object Type { case TyVar(_) | TyMeta(_) => Nil } + def hasNoVars(t: Type): Boolean = + t match { + case TyConst(c) => true + case TyVar(_) | TyMeta(_) | ForAll(_, _) => false + case TyApply(on, arg) => hasNoVars(on) && hasNoVars(arg) + } + @annotation.tailrec final def forAll(vars: List[Var.Bound], in: Type): Type = vars match { diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala index aca830957..000cbfcdd 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala @@ -35,7 +35,7 @@ case class TypeEnv( // TODO to support parameter named patterns we'd need to know the // parameter names - def typeConstructors: Map[(PackageName, ConstructorName), (List[Type.Var], List[Type], Type.Const.Defined)] = + lazy val typeConstructors: Map[(PackageName, ConstructorName), (List[Type.Var], List[Type], Type.Const.Defined)] = constructors.map { case (pc, (params, dt, _)) => (pc, (dt.typeParams, @@ -43,6 +43,12 @@ case class TypeEnv( dt.toTypeConst)) } + def definedTypeFor(c: (PackageName, ConstructorName)): Option[DefinedType] = + typeConstructors.get(c).flatMap { case (_, _, d) => toDefinedType(d) } + + def toDefinedType(t: Type.Const.Defined): Option[DefinedType] = + definedTypes.get((t.packageName, TypeName(t.name))) + } object TypeEnv { diff --git a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala index eb504f528..e60a2566f 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala @@ -12,15 +12,7 @@ import Parser.Indy import Generators.shrinkDecl -class ParserTest extends FunSuite { - // This is so we can make Declarations without the region - private[this] implicit val emptyRegion: Region = Region(0, 0) - - implicit val generatorDrivenConfig = - //PropertyCheckConfiguration(minSuccessful = 500) - PropertyCheckConfiguration(minSuccessful = 50) - //PropertyCheckConfiguration(minSuccessful = 5) - +object TestParseUtils { def region(s0: String, idx: Int): String = if (s0.isEmpty) s"empty string, idx = $idx" else if (s0.length == idx) { @@ -39,6 +31,18 @@ class ParserTest extends FunSuite { else if (s1(0) == s2(0)) firstDiff(s1.tail, s2.tail) else s"${s1(0).toInt}: ${s1.take(20)}... != ${s2(0).toInt}: ${s2.take(20)}..." +} + +class ParserTest extends FunSuite { + import TestParseUtils._ + // This is so we can make Declarations without the region + private[this] implicit val emptyRegion: Region = Region(0, 0) + + implicit val generatorDrivenConfig = + //PropertyCheckConfiguration(minSuccessful = 500) + PropertyCheckConfiguration(minSuccessful = 50) + //PropertyCheckConfiguration(minSuccessful = 5) + def parseTest[T](p: Parser[T], str: String, expected: T, exidx: Int) = p.parse(str) match { case Parsed.Success(t, idx) => diff --git a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala new file mode 100644 index 000000000..8484751eb --- /dev/null +++ b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala @@ -0,0 +1,143 @@ +package org.bykn.bosatsu + +import org.scalatest.FunSuite + +import rankn._ + +import Parser.Combinators +import fastparse.all.Parsed + +class TotalityTest extends FunSuite { + import TestParseUtils._ + + val pack = PackageName.parts("Test") + def const(t: String): Type = + Type.TyConst(Type.Const.Defined(pack, t)) + + def typeEnvOf(str: String): TypeEnv = + Statement.parser.parse(str) match { + case Parsed.Success(stmt, idx) => + assert(idx == str.length) + val prog = Program.fromStatement( + pack, + { tpe => Type.Const.Defined(pack, tpe) }, + { cons => (pack, ConstructorName(cons)) }, + stmt) + prog.types + case Parsed.Failure(exp, idx, extra) => + fail(s"failed to parse: $str: $exp at $idx in region ${region(str, idx)} with trace: ${extra.traced.trace}") + sys.error("could not produce TypeEnv") + } + + def patterns(str: String): List[Pattern[(PackageName, ConstructorName), Type]] = + Pattern.parser.listSyntax.parse(str) match { + case Parsed.Success(pats, idx) => + pats.map { pat => + pat + .mapName { n => (pack, ConstructorName(n)) } + .mapType(_.toType { n => Type.Const.Defined(pack, n) }) + } + case Parsed.Failure(exp, idx, extra) => + fail(s"failed to parse: $str: $exp at $idx in region ${region(str, idx)} with trace: ${extra.traced.trace}") + sys.error("could not produce TypeEnv") + } + + def notTotal(te: TypeEnv, pats: List[Pattern[(PackageName, ConstructorName), Type]]) = + TotalityCheck(te).isTotal(pats) match { + case Right(res) => assert(!res, pats.toString) + case Left(errs) => fail(errs.toString) + } + + def testTotality(te: TypeEnv, pats: List[Pattern[(PackageName, ConstructorName), Type]], tight: Boolean = false) = { + TotalityCheck(te).isTotal(pats) match { + case Right(res) => assert(res) + case Left(errs) => fail(errs.toString) + } + // any missing pattern shouldn't be total: + def allButOne[A](head: A, tail: List[A]): List[List[A]] = + tail match { + case Nil => Nil + case h :: rest => + // we can either delete the head or one from the tail: + val keepHead = allButOne(h, rest).map(head :: _) + tail :: keepHead + } + + pats match { + case h :: tail if tight => + allButOne(h, tail).foreach(notTotal(te, _)) + case _ => () + } + } + + + + test("totality test") { + val te = typeEnvOf("""# +struct Unit +""") + val pats = patterns("[Unit]") + testTotality(te, pats) + + + val te1 = typeEnvOf("""# +struct Tuple2(a, b) +""") + testTotality(te1, patterns("[Tuple2(_, _)]")) + testTotality(te1, patterns("[Tuple2(_, 0), Tuple2(_, _)]")) + notTotal(te1, patterns("[Tuple2(_, 0)]")) + } + + test("test Option types") { + val te = typeEnvOf("""# +enum Option: None, Some(get) +""") + testTotality(te, patterns("[Some(_), None]"), tight = true) + testTotality(te, patterns("[Some(_), _]")) + testTotality(te, patterns("[Some(1), Some(x), None]")) + testTotality(te, patterns("[Some(Some(_)), Some(None), None]"), tight = true) + + notTotal(te, patterns("[Some(_)]")) + notTotal(te, patterns("[Some(Some(_)), None]")) + notTotal(te, patterns("[None]")) + notTotal(te, patterns("[]")) + } + + test("test Either types") { + val te = typeEnvOf("""# +enum Either: Left(l), Right(r) +""") + testTotality(te, patterns("[Left(_), Right(_)]")) + testTotality(te, + patterns("[Left(Right(_)), Left(Left(_)), Right(Left(_)), Right(Right(_))]"), + tight = true) + + notTotal(te, patterns("[Left(_)]")) + notTotal(te, patterns("[Right(_)]")) + notTotal(te, patterns("[Left(Right(_)), Right(_)]")) + } + + test("test List matching") { + testTotality(TypeEnv.empty, patterns("[[], [h, *tail]]"), tight = true) + testTotality(TypeEnv.empty, patterns("[[], [h, *tail], [h0, h1, *tail]]"), tight = true) + + notTotal(TypeEnv.empty, patterns("[[], [h, *tail, _]]")) + notTotal(TypeEnv.empty, patterns("[[], [*tail, _]]")) + } + + test("multiple struct compose") { + val te = typeEnvOf("""# +enum Either: Left(l), Right(r) +enum Option: None, Some(get) +""") + + testTotality(te, patterns("[None, Some(Left(_)), Some(Right(_))]"), tight = true) + } + + test("compose List with structs") { + val te = typeEnvOf("""# +enum Either: Left(l), Right(r) +""") + testTotality(te, patterns("[[Left(_), *_], [Right(_), *_], [], [_, _, *_]]"), tight = true) + } +} From 6f069227f215e988ba4c1b66d7e9952473015225 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Sun, 4 Nov 2018 03:04:56 +0000 Subject: [PATCH 2/8] minor cleanup --- .../org/bykn/bosatsu/TotalityCheck.scala | 41 +++++-------------- 1 file changed, 11 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala index 3d296a705..abe8a2e71 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala @@ -47,6 +47,12 @@ case class TotalityCheck(inEnv: TypeEnv) { Monad[Res].tailRecM((branches, List(WildCard): Patterns))(step _) } + /** + * Return true of this set of branches represents a total match + * + * useful for testing, but a better error message will be obtained from using + * missingBranches + */ def isTotal(branches: Patterns): Res[Boolean] = missingBranches(branches).map(_.isEmpty) @@ -313,35 +319,6 @@ case class TotalityCheck(inEnv: TypeEnv) { else Left(NonEmptyList.of(ArityMismatch(nm, pat, inEnv, size, params.size))) } - // def typeOf(p: Pattern[Cons, Type]): Res[Type] = { - // def err = Left(NonEmptyList.of(UntypedPattern(p, inEnv))) - // p match { - // case Annotation(_, t) => Right(t) - // case Literal(lit) => Right(Type.getTypeOf(lit)) - // case WildCard | Var(_) => err - // case ListPat(pats) => - // // we assume we are well typed, so the first right hand side tells us - // val listTypeParam = - // pats.foldLeft(err: Res[Type]) { - // case (e, Right(p)) => - // typeOf(p) match { - // case Left(_) => e - // case right => right - // } - // case (e, _) => e - // } - - // listTypeParam.map { a => Type.TyApply(Type.ListType, a) } - // case PositionalStruct(nm, _) => - // inEnv.definedTypeFor(nm) match { - // case None => - // Left(NonEmptyList.of(UnknownConstructor(nm, p, inEnv))) - // case Some(dt) => - // Right(dt.fullType) - // } - // } - // } - /** * Can a given pattern match everything for a the current type */ @@ -349,7 +326,11 @@ case class TotalityCheck(inEnv: TypeEnv) { p match { case Pattern.WildCard | Pattern.Var(_) => Right(true) case Pattern.Literal(_) => Right(false) // literals are not total - case Pattern.ListPat(_) => Right(false) // empty list pattern matching *COULD* be total, if we understood Void + case Pattern.ListPat(Left(_) :: rest) => + Right(matchesEmpty(ListPat(rest))) + case Pattern.ListPat(_) => + // either can't match everything on the front or back + Right(false) case Pattern.Annotation(p, _) => isTotal(p) case Pattern.PositionalStruct(name, params) => // This is total if the struct has a single constructor AND each of the patterns is total From 6de89979aae8902dfeeffac82c5794d1d6465aac Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Sun, 4 Nov 2018 03:14:12 +0000 Subject: [PATCH 3/8] use the arity checks, move list difference into a method --- .../org/bykn/bosatsu/TotalityCheck.scala | 127 +++++++++--------- 1 file changed, 67 insertions(+), 60 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala index abe8a2e71..61ac5f505 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala @@ -70,6 +70,64 @@ case class TotalityCheck(inEnv: TypeEnv) { case Right(_) :: _ => false } + private def difference0List( + lp: List[Either[Option[String], Pattern[Cons, Type]]], + rp: List[Either[Option[String], Pattern[Cons, Type]]]): Res[Patterns] = { + (lp, rp) match { + case (Nil, Nil) => + // total overlap + Right(Nil) + case (Nil, Right(_) :: _) => + // a list of 1 or more, can't match less + Right(ListPat(lp) :: Nil) + case (Nil, Left(_) :: tail) => + // we can have zero or more, 1 or more clearly can't match: + // if the tail can match 0, we anhilate, otherwise not + if (matchesEmpty(ListPat(tail))) Right(Nil) + else Right(ListPat(lp) :: Nil) + case (Right(_) :: _, Nil) => + // left has at least one + Right(ListPat(lp) :: Nil) + case (Right(lhead) :: ltail, Right(rhead) :: rtail) => + // we use productDifference here + productDifference((lhead, rhead) :: (ListPat(ltail), ListPat(rtail)) :: Nil) + .map { listOfList => + listOfList.map { + case h :: ListPat(tail) :: Nil => + ListPat(Right(h) :: tail) + case other => + sys.error(s"expected exactly two items: $other") + } + } + case (Left(_) :: tail, Nil) => + // if tail matches empty, then we can only match 1 or more + // else, these are disjoint + if (matchesEmpty(ListPat(tail))) + Right(ListPat(Right(WildCard) :: lp) :: Nil) + else Right(ListPat(lp) :: Nil) + case (Left(_) :: tail, Right(_) :: _) => + // The right hand side can't match a zero length list + val zero = ListPat(tail) + val oneOrMore = Right(WildCard) :: lp + difference0List(oneOrMore, rp) + .map(zero :: _) + case (_, Left(_) :: rtail) if matchesEmpty(ListPat(rtail)) => + // this is a total match + Right(Nil) + case (_, Left(_) :: _) => + // In this branch, the right cannot match + // the empty list, but the left side can + // we could in principle match a finite + // list from either direction, so we reverse + // and try again + difference0List(lp.reverse, rp.reverse) + .map(_.map { + case ListPat(diff) => ListPat(diff.reverse) + case other => sys.error(s"unreachable: list patterns can't difference to non-list: $other") + }) + } + } + def difference0(left: Pattern[Cons, Type], right: Pattern[Cons, Type]): Res[Patterns] = { isTotal(right).flatMap { case true => Right(Nil): Res[Patterns] @@ -82,59 +140,7 @@ case class TotalityCheck(inEnv: TypeEnv) { // _ is the same as [*_] for well typed expressions difference0(ListPat(Left(None) :: Nil), lp) case (ListPat(lp), rightList@ListPat(rp)) => - (lp, rp) match { - case (Nil, Nil) => - // total overlap - Right(Nil) - case (Nil, Right(_) :: _) => - // a list of 1 or more, can't match less - Right(left :: Nil) - case (Nil, Left(_) :: tail) => - // we can have zero or more, 1 or more clearly can't match: - // if the tail can match 0, we anhilate, otherwise not - if (matchesEmpty(ListPat(tail))) Right(Nil) - else Right(left :: Nil) - case (Right(_) :: _, Nil) => - // left has at least one - Right(left :: Nil) - case (Right(lhead) :: ltail, Right(rhead) :: rtail) => - // we use productDifference here - productDifference((lhead, rhead) :: (ListPat(ltail), ListPat(rtail)) :: Nil) - .map { listOfList => - listOfList.map { - case h :: ListPat(tail) :: Nil => - ListPat(Right(h) :: tail) - case other => - sys.error(s"expected exactly two items: $other") - } - } - case (Left(_) :: tail, Nil) => - // if tail matches empty, then we can only match 1 or more - // else, these are disjoint - if (matchesEmpty(ListPat(tail))) - Right(ListPat(Right(WildCard) :: lp) :: Nil) - else Right(left :: Nil) - case (Left(_) :: tail, Right(_) :: _) => - // The right hand side can't match a zero length list - val zero = ListPat(tail) - val oneOrMore = ListPat(Right(WildCard) :: lp) - difference0(oneOrMore, right) - .map(zero :: _) - case (_, Left(_) :: rtail) if matchesEmpty(ListPat(rtail)) => - // this is a total match - Right(Nil) - case (_, Left(_) :: rtail) => - // In this branch, the right cannot match - // the empty list, but the left side can - // we could in principle match a finite - // list from either direction, so we reverse - // and try again - difference0(ListPat(lp.reverse), ListPat(rp.reverse)) - .map(_.map { - case ListPat(diff) => ListPat(diff.reverse) - case other => sys.error(s"unreachable: list patterns can't difference to non-list: $other") - }) - } + difference0List(lp, rp) case (WildCard | Var(_), PositionalStruct(nm, ps)) => inEnv.definedTypeFor(nm) match { case None => Left(NonEmptyList.of(UnknownConstructor(nm, right, inEnv))) @@ -183,7 +189,8 @@ case class TotalityCheck(inEnv: TypeEnv) { checkArity(ln, lp.size, left) .product(checkArity(rn, rp.size, right)) .as(()) - productDifference(lp zip rp).map { pats => + + arityMatch >> productDifference(lp zip rp).map { pats => pats.map(PositionalStruct(ln, _)) } case _ => @@ -252,13 +259,13 @@ case class TotalityCheck(inEnv: TypeEnv) { type ResList[A] = Res[List[A]] implicit val app = Applicative[Res].compose(Applicative[List]) - val parts = check.flatMap { _ => - lps.zip(rps).traverse[ResList, Pattern[Cons, Type]] { - case (l, r) => intersection(l, r) + check >> + check.flatMap { _ => + lps.zip(rps).traverse[ResList, Pattern[Cons, Type]] { + case (l, r) => intersection(l, r) + } } - } - - parts.map(_.map(PositionalStruct(ln, _))) + .map(_.map(PositionalStruct(ln, _))) } else Right(Nil) } From 8995ee77890a46c0c98f122ae373000d8c7781f6 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Sun, 4 Nov 2018 03:42:19 +0000 Subject: [PATCH 4/8] add some test coverage I hope --- .../scala/org/bykn/bosatsu/TotalityTest.scala | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala index 8484751eb..164de3428 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala @@ -7,6 +7,8 @@ import rankn._ import Parser.Combinators import fastparse.all.Parsed +import org.typelevel.paiges.Document + class TotalityTest extends FunSuite { import TestParseUtils._ @@ -49,8 +51,14 @@ class TotalityTest extends FunSuite { } def testTotality(te: TypeEnv, pats: List[Pattern[(PackageName, ConstructorName), Type]], tight: Boolean = false) = { - TotalityCheck(te).isTotal(pats) match { - case Right(res) => assert(res) + TotalityCheck(te).missingBranches(pats) match { + case Right(res) => + val asStr = res.map { pat => + val pat0 = pat.mapName { case (_, ConstructorName(n)) => n } + .mapType { t => TypeRef.fromType(t).get } + Document[Pattern[String, TypeRef]].document(pat0).render(80) + } + assert(asStr == Nil) case Left(errs) => fail(errs.toString) } // any missing pattern shouldn't be total: @@ -129,9 +137,11 @@ enum Either: Left(l), Right(r) val te = typeEnvOf("""# enum Either: Left(l), Right(r) enum Option: None, Some(get) +struct Tuple2(fst, snd) """) testTotality(te, patterns("[None, Some(Left(_)), Some(Right(_))]"), tight = true) + testTotality(te, patterns("[None, Some(Tuple2(Left(_), _)), Some(Tuple2(_, Right(_))), Some(Tuple2(Right(_), Left(_)))]"), tight = true) } test("compose List with structs") { @@ -139,5 +149,6 @@ enum Option: None, Some(get) enum Either: Left(l), Right(r) """) testTotality(te, patterns("[[Left(_), *_], [Right(_), *_], [], [_, _, *_]]"), tight = true) + testTotality(te, patterns("[Left([]), Left([h, *_]), Right([]), Right([h, *_])]"), tight = true) } } From b77315b70d2abacebb0c097213f601131a311280 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Mon, 5 Nov 2018 05:33:47 +0000 Subject: [PATCH 5/8] Add scalacheck, fix a ton of bugs --- .../main/scala/org/bykn/bosatsu/Pattern.scala | 19 + .../org/bykn/bosatsu/TotalityCheck.scala | 375 ++++++++++++------ .../org/bykn/bosatsu/rankn/DefinedType.scala | 6 - .../src/test/scala/org/bykn/bosatsu/Gen.scala | 75 +++- .../scala/org/bykn/bosatsu/TotalityTest.scala | 343 +++++++++++++++- .../org/bykn/bosatsu/rankn/TypeTest.scala | 16 + 6 files changed, 670 insertions(+), 164 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala index 8792fb0a8..8907bd7e6 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala @@ -52,6 +52,24 @@ object Pattern { Pattern.PositionalStruct(name, ps) } } + + /** + * Return the pattern with all the binding names removed + */ + def unbind: Pattern[N, T] = + pat match { + case Pattern.WildCard | Pattern.Literal(_) => pat + case Pattern.Var(_) => Pattern.WildCard + case Pattern.ListPat(items) => + Pattern.ListPat(items.map { + case Left(_) => Left(None) + case Right(p) => Right(p.unbind) + }) + case Pattern.Annotation(p, tpe) => + Pattern.Annotation(p.unbind, tpe) + case Pattern.PositionalStruct(name, params) => + Pattern.PositionalStruct(name, params.map(_.unbind)) + } } case object WildCard extends Pattern[Nothing, Nothing] @@ -61,6 +79,7 @@ object Pattern { case class Annotation[N, T](pattern: Pattern[N, T], tpe: T) extends Pattern[N, T] case class PositionalStruct[N, T](name: N, params: List[Pattern[N, T]]) extends Pattern[N, T] + implicit lazy val document: Document[Pattern[String, TypeRef]] = Document.instance[Pattern[String, TypeRef]] { case WildCard => Doc.char('_') diff --git a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala index 61ac5f505..b2801d71d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala @@ -1,6 +1,6 @@ package org.bykn.bosatsu -import cats.{Monad, Applicative} +import cats.{Monad, Applicative, Eq} import cats.data.NonEmptyList import cats.implicits._ @@ -11,11 +11,13 @@ object TotalityCheck { type Cons = (PackageName, ConstructorName) type Res[+A] = Either[NonEmptyList[Error], A] type Patterns = List[Pattern[Cons, Type]] + type ListPatElem = Either[Option[String], Pattern[Cons, Type]] sealed abstract class Error case class ArityMismatch(cons: Cons, in: Pattern[Cons, Type], env: TypeEnv, expected: Int, found: Int) extends Error case class UnknownConstructor(cons: Cons, in: Pattern[Cons, Type], env: TypeEnv) extends Error case class UntypedPattern(pat: Pattern[Cons, Type], env: TypeEnv) extends Error + case class MultipleSplicesInPattern(pat: ListPat[Cons, Type], env: TypeEnv) extends Error } case class TotalityCheck(inEnv: TypeEnv) { @@ -63,10 +65,10 @@ case class TotalityCheck(inEnv: TypeEnv) { left.traverse(difference0(_, right)).map(_.flatten) @annotation.tailrec - private def matchesEmpty(lp: ListPat[Cons, Type]): Boolean = - lp.parts match { + private def matchesEmpty(lp: List[ListPatElem]): Boolean = + lp match { case Nil => true - case Left(_) :: tail => matchesEmpty(ListPat(tail)) + case Left(_) :: tail => matchesEmpty(tail) case Right(_) :: _ => false } @@ -83,7 +85,7 @@ case class TotalityCheck(inEnv: TypeEnv) { case (Nil, Left(_) :: tail) => // we can have zero or more, 1 or more clearly can't match: // if the tail can match 0, we anhilate, otherwise not - if (matchesEmpty(ListPat(tail))) Right(Nil) + if (matchesEmpty(tail)) Right(Nil) else Right(ListPat(lp) :: Nil) case (Right(_) :: _, Nil) => // left has at least one @@ -102,110 +104,131 @@ case class TotalityCheck(inEnv: TypeEnv) { case (Left(_) :: tail, Nil) => // if tail matches empty, then we can only match 1 or more // else, these are disjoint - if (matchesEmpty(ListPat(tail))) + if (matchesEmpty(tail)) Right(ListPat(Right(WildCard) :: lp) :: Nil) else Right(ListPat(lp) :: Nil) case (Left(_) :: tail, Right(_) :: _) => - // The right hand side can't match a zero length list - val zero = ListPat(tail) + val zero = tail val oneOrMore = Right(WildCard) :: lp - difference0List(oneOrMore, rp) - .map(zero :: _) - case (_, Left(_) :: rtail) if matchesEmpty(ListPat(rtail)) => + (difference0List(zero, rp), difference0List(oneOrMore, rp)).mapN(_ ::: _) + case (_, Left(_) :: rtail) if matchesEmpty(rtail) => // this is a total match Right(Nil) case (_, Left(_) :: _) => - // In this branch, the right cannot match - // the empty list, but the left side can - // we could in principle match a finite - // list from either direction, so we reverse - // and try again - difference0List(lp.reverse, rp.reverse) - .map(_.map { - case ListPat(diff) => ListPat(diff.reverse) - case other => sys.error(s"unreachable: list patterns can't difference to non-list: $other") - }) + // if this pattern ends with Left(_) we have + // a hard match problem on our hands. For now, we ban it: + val revRight = rp.reverse + revRight match { + case Left(_) :: tail if !matchesEmpty(tail) => + Left(NonEmptyList.of(MultipleSplicesInPattern(ListPat(rp), inEnv))) + case _ => + // we can make progress: + + // In this branch, the right cannot match + // the empty list, but the left side can + // we could in principle match a finite + // list from either direction, so we reverse + // and try again + difference0List(lp.reverse, revRight) + .map(_.map { + case ListPat(diff) => ListPat(diff.reverse) + case other => sys.error(s"unreachable: list patterns can't difference to non-list: $other") + }) + } } } - def difference0(left: Pattern[Cons, Type], right: Pattern[Cons, Type]): Res[Patterns] = { - isTotal(right).flatMap { - case true => Right(Nil): Res[Patterns] - case false => - (left, right) match { - case (WildCard | Var(_), Literal(_)) => - // the left is infinite, and the right is just one value - Right(left :: Nil) - case (WildCard | Var(_), lp@ListPat(_)) => - // _ is the same as [*_] for well typed expressions - difference0(ListPat(Left(None) :: Nil), lp) - case (ListPat(lp), rightList@ListPat(rp)) => - difference0List(lp, rp) - case (WildCard | Var(_), PositionalStruct(nm, ps)) => - inEnv.definedTypeFor(nm) match { - case None => Left(NonEmptyList.of(UnknownConstructor(nm, right, inEnv))) - case Some(dt) => - dt.constructors.traverse { - case (c, params, _) if (dt.packageName, c) == nm => - /* - * At each position we compute the difference with _ - * then make: - * Struct(d1, _, _), Struct(_, d2, _), ... - */ - def poke[M[_]: Applicative, A](items: List[A])(fn: A => M[List[A]]): M[List[List[A]]] = - items match { - case Nil => Applicative[M].pure(Nil) - case h :: tail => - val ptail = poke(tail)(fn) - val head = fn(h) - (head, ptail).mapN { (heads, tails) => - val t2 = tails.map(h :: _) - val h1 = heads.map(_ :: tail) - h1 ::: t2 - } + def difference0(left: Pattern[Cons, Type], right: Pattern[Cons, Type]): Res[Patterns] = + (left, right) match { + case (_, WildCard | Var(_)) => Right(Nil) + case (WildCard | Var(_), Literal(_)) => + // the left is infinite, and the right is just one value + Right(left :: Nil) + case (WildCard | Var(_), _) if isTotal(right) == Right(true) => + Right(Nil) + case (WildCard, ListPat(rp)) => + // _ is the same as [*_] for well typed expressions + difference0List(Left(None) :: Nil, rp) + case (Var(v), ListPat(rp)) => + // v is the same as [*v] for well typed expressions + difference0List(Left(Some(v)) :: Nil, rp) + case (ListPat(lp), ListPat(rp)) => + difference0List(lp, rp) + case (Literal(_), ListPat(_) | PositionalStruct(_, _)) => + Right(left :: Nil) + case (ListPat(_), Literal(_) | PositionalStruct(_, _)) => + Right(left :: Nil) + case (PositionalStruct(_, _), Literal(_) | ListPat(_)) => + Right(left :: Nil) + case (WildCard | Var(_), PositionalStruct(nm, ps)) => + inEnv.definedTypeFor(nm) match { + case None => Left(NonEmptyList.of(UnknownConstructor(nm, right, inEnv))) + case Some(dt) => + dt.constructors.traverse { + case (c, params, _) if (dt.packageName, c) == nm => + /* + * At each position we compute the difference with _ + * then make: + * Struct(d1, _, _), Struct(_, d2, _), ... + */ + def poke[M[_]: Applicative, A](items: List[A])(fn: A => M[List[A]]): M[List[List[A]]] = + items match { + case Nil => Applicative[M].pure(Nil) + case h :: tail => + val ptail = poke(tail)(fn) + val head = fn(h) + (head, ptail).mapN { (heads, tails) => + val t2 = tails.map(h :: _) + val h1 = heads.map(_ :: tail) + h1 ::: t2 } + } - // for this one, we need to compute the difference for each: - poke(ps) { p => difference0(WildCard, p) } - .map(_.map(PositionalStruct(nm, _))) + // for this one, we need to compute the difference for each: + poke(ps) { p => difference0(WildCard, p) } + .map(_.map(PositionalStruct(nm, _))) - case (c, params, _) => - // TODO, this could be smarter - // we need to learn how to deal with typed generics - def argToPat(t: (ParamName, Type)): Pattern[Cons, Type] = - if (Type.hasNoVars(t._2)) Annotation(WildCard, t._2) - else WildCard + case (c, params, _) => + // TODO, this could be smarter + // we need to learn how to deal with typed generics + def argToPat(t: (ParamName, Type)): Pattern[Cons, Type] = + if (Type.hasNoVars(t._2)) Annotation(WildCard, t._2) + else WildCard - Right(List(PositionalStruct((dt.packageName, c), params.map(argToPat)))) - } - .map(_.flatten) + Right(List(PositionalStruct((dt.packageName, c), params.map(argToPat)))) } - case (llit@Literal(l), Literal(r)) => - if (l == r) Right(Nil): Res[Patterns] - else Right(llit :: Nil): Res[Patterns] - case (PositionalStruct(ln, lp), PositionalStruct(rn, rp)) if ln == rn => - // we have two matching structs - val arityMatch = - checkArity(ln, lp.size, left) - .product(checkArity(rn, rp.size, right)) - .as(()) + .map(_.flatten) + } + case (llit@Literal(l), Literal(r)) => + if (l == r) Right(Nil): Res[Patterns] + else Right(llit :: Nil): Res[Patterns] + case (PositionalStruct(ln, lp), PositionalStruct(rn, rp)) if ln == rn => + // we have two matching structs + val arityMatch = + checkArity(ln, lp.size, left) + .product(checkArity(rn, rp.size, right)) + .as(()) - arityMatch >> productDifference(lp zip rp).map { pats => - pats.map(PositionalStruct(ln, _)) - } - case _ => - // There is no overlap - Right(left :: Nil): Res[Patterns] + arityMatch >> productDifference(lp zip rp).map { pats => + pats.map(PositionalStruct(ln, _)) } + case (PositionalStruct(_, _), PositionalStruct(_, _)) => + Right(left :: Nil) + + // case _ => + // // There is no overlap + // Right(left :: Nil): Res[Patterns] } - } def intersection( left: Pattern[Cons, Type], right: Pattern[Cons, Type]): Res[List[Pattern[Cons, Type]]] = (left, right) match { - case (WildCard | Var(_), v) => Right(List(v)) - case (v, WildCard | Var(_)) => Right(List(v)) + case (Var(va), Var(vb)) => Right(List(Var(Ordering[String].min(va, vb)))) + case (WildCard, v) => Right(List(v)) + case (v, WildCard) => Right(List(v)) + case (Var(_), v) => Right(List(v)) + case (v, Var(_)) => Right(List(v)) case (Annotation(p, _), t) => intersection(p, t) case (t, Annotation(p, _)) => intersection(t, p) case (Literal(a), Literal(b)) => @@ -214,40 +237,7 @@ case class TotalityCheck(inEnv: TypeEnv) { case (Literal(_), _) => Right(Nil) case (_, Literal(_)) => Right(Nil) case (ListPat(leftL), ListPat(rightL)) => - (leftL, rightL) match { - case (Nil, Nil) => Right(List(left)) - case (Nil, Right(_) :: _) => Right(Nil) - case (_, Left(_) :: tail) if matchesEmpty(ListPat(tail)) => Right(List(left)) - case (Nil, Left(_) :: _) => Right(List(left)) - case (Right(_) :: _, Nil) => Right(Nil) - case (Right(lh) :: lt, Right(rh) :: rt) => - intersection(lh, rh).flatMap { - case Nil => Right(Nil) - case nonEmpty => - intersection(ListPat(lt), ListPat(rt)) - .map(_.flatMap { - case ListPat(ts) => nonEmpty.map { h => ListPat(Right(h) :: ts) } - case other => sys.error(s"unreachable: list patterns can't intersect to non-list: $other") - }) - } - case (Right(lh) :: lt, Left(rh) :: rt) => - val zero = ListPat(rt) - val oneOrMore = ListPat(Right(WildCard) :: rightL) - // a n (b0 + b1) = (a n b0) + (a n b1) - for { - withZ <- intersection(left, zero) - with0 <- intersection(left, oneOrMore) - } yield withZ ::: with0 - case (Left(_) :: lt, Left(_) :: rt) => - intersection(ListPat(lt), ListPat(rt)) - .map(_.map { - case ListPat(tail) => ListPat(Left(None) :: tail) - case other => sys.error(s"unreachable: list patterns can't intersect to non-list: $other") - }) - case (_, _) => - // intersection is symmetric - intersection(right, left) - } + intersectionList(leftL, rightL) case (ListPat(_), _) => Right(Nil) case (_, ListPat(_)) => Right(Nil) case (PositionalStruct(ln, lps), PositionalStruct(rn, rps)) => @@ -260,16 +250,104 @@ case class TotalityCheck(inEnv: TypeEnv) { type ResList[A] = Res[List[A]] implicit val app = Applicative[Res].compose(Applicative[List]) check >> - check.flatMap { _ => - lps.zip(rps).traverse[ResList, Pattern[Cons, Type]] { - case (l, r) => intersection(l, r) - } + lps.zip(rps).traverse[ResList, Pattern[Cons, Type]] { + case (l, r) => intersection(l, r) } .map(_.map(PositionalStruct(ln, _))) } else Right(Nil) } + private def intersectionList(leftL: List[ListPatElem], rightL: List[ListPatElem]): Res[Patterns] = { + def left = ListPat(leftL) + (leftL, rightL) match { + case (_, Left(_) :: tail) if matchesEmpty(tail) => + // the right hand side is a top value, it can match any list, so intersection with top is + // left + Right(left :: Nil) + case (Left(_) :: tail, _) if matchesEmpty(tail) => + // the left hand side is a top value, it can match any list, so intersection with top is + // right + Right(ListPat(rightL) :: Nil) + case (Nil, Nil) => Right(List(left)) + case (Nil, Right(_) :: _) => Right(Nil) + case (Nil, Left(_) :: _) | (Left(_) :: _, Nil) => + // the non Nil patterns can't match empty due to the above: + Right(Nil) + case (Right(_) :: _, Nil) => Right(Nil) + case (Right(lh) :: lt, Right(rh) :: rt) => + intersection(lh, rh).flatMap { + case Nil => Right(Nil) + case nonEmpty => + intersectionList(lt, rt) + .map(_.flatMap { + case ListPat(ts) => + // this could create duplicates + nonEmpty + .map { h => ListPat(Right(h) :: ts) } + .distinct + case other => sys.error(s"unreachable: list patterns can't intersect to non-list: $other") + }) + } + case (Right(lh) :: lt, Left(rh) :: rt) => + // a n (b0 + b1) = (a n b0) + (a n b1) + val zero = rt + val oneOrMore = Right(WildCard) :: rightL + for { + withZ <- intersectionList(leftL, zero) + with0 <- intersectionList(leftL, oneOrMore) + } yield (withZ ::: with0).distinct + case (Left(_) :: _, Right(_) :: _) => + // intersection is symmetric + intersectionList(rightL, leftL) + case (Left(a) :: lt, Left(b) :: rt) => + /* + * the left and right can consume any number + * of items before matching the rest. + * + * if we assume rt has no additional Lefts, + * we can pad the left to be the same size + * by adding wildcards, and repeat + */ + def hasMultiple(ps: List[ListPatElem]): Boolean = + ps.exists { + case Left(_) => true + case Right(_) => false + } + + (hasMultiple(lt), hasMultiple(rt)) match { + case (true, true) => + Left(NonEmptyList.of( + MultipleSplicesInPattern(ListPat(rightL), inEnv), + MultipleSplicesInPattern(ListPat(leftL), inEnv))) + case (_, false) => + /* + * make suffix of rt that lines up with lt + */ + val rtSize = rt.size + val ltSize = lt.size + val padSize = rtSize - ltSize + val (initLt, lastLt) = + if (padSize > 0) { + (List.empty[ListPatElem], List.fill(padSize)(Right(WildCard)) reverse_::: lt) + } + else { + (lt.take(-padSize), lt.drop(-padSize)) + } + intersectionList(lastLt, rt) + .map(_.map { + case ListPat(tail) => + val m: ListPatElem = Left(if (a == b) a else None) + ListPat(m :: initLt ::: tail) + case other => sys.error(s"unreachable: list patterns can't intersect to non-list: $other") + }) + case (false, _) => + // intersection is symmetric + intersectionList(rightL, leftL) + } + } + } + /** * There the list is a tuple or product pattern * the left and right should be the same size and the result will be a list of lists @@ -290,6 +368,9 @@ case class TotalityCheck(inEnv: TypeEnv) { * seems to be: * * (a0 x a1) - (b0 x b1) = (a0 - b0) x a1 + (a0 n b0) x (a1 - b1) + * + * Note, if a1 - b1 = a1, this becomes: + * ((a0 - b0) + (a0 n b0)) x a1 = a0 x a1 */ zip match { case Nil => Right(Nil) // complete match @@ -311,7 +392,16 @@ case class TotalityCheck(inEnv: TypeEnv) { } } - (headDiff, tailDiff).mapN(_ ::: _) + tailDiff match { + case Right((intr :: td) :: Nil) + if tail.zip(td).forall { case ((a1, _), d) => eqPat.eqv(a1, d) } => + // this is the rule that if the rest has no diff, the first + // part has no diff + // not needed for correctness, but useful for normalizing + Right(zip.map(_._1) :: Nil) + case _ => + (headDiff, tailDiff).mapN(_ ::: _) + } } /** @@ -334,9 +424,9 @@ case class TotalityCheck(inEnv: TypeEnv) { case Pattern.WildCard | Pattern.Var(_) => Right(true) case Pattern.Literal(_) => Right(false) // literals are not total case Pattern.ListPat(Left(_) :: rest) => - Right(matchesEmpty(ListPat(rest))) + Right(matchesEmpty(rest)) case Pattern.ListPat(_) => - // either can't match everything on the front or back + // can't match everything on the front Right(false) case Pattern.Annotation(p, _) => isTotal(p) case Pattern.PositionalStruct(name, params) => @@ -350,5 +440,36 @@ case class TotalityCheck(inEnv: TypeEnv) { } } + def normalizePattern(p: Pattern[Cons, Type]): Pattern[Cons, Type] = + isTotal(p) match { + case Right(true) => WildCard + case _ => + p match { + case WildCard | Literal(_) => p + case Var(_) => WildCard + case ListPat(ls) => + val normLs: List[ListPatElem] = + ls.map { + case Left(_) => Left(None) + case Right(p) => Right(normalizePattern(p)) + } + normLs match { + case Left(None) :: Nil => WildCard + case rest => ListPat(rest) + } + case Annotation(p, t) => Annotation(normalizePattern(p), t) + case PositionalStruct(n, params) => + PositionalStruct(n, params.map(normalizePattern)) + } + } + /** + * This tells if two patterns for the same type + * would match the same values + */ + val eqPat: Eq[Pattern[Cons, Type]] = + new Eq[Pattern[Cons, Type]] { + def eqv(l: Pattern[Cons, Type], r: Pattern[Cons, Type]) = + normalizePattern(l) == normalizePattern(r) + } } diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala index 4c16afa60..bdebd5424 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala @@ -27,12 +27,6 @@ case class DefinedType( def toOpaque: DefinedType = copy(constructors = Nil) - - /** - * This may be a ForAll type if there are typeParams - */ - def fullType: Type = - Type.forAll(typeParams, toTypeTyConst) } object DefinedType { diff --git a/core/src/test/scala/org/bykn/bosatsu/Gen.scala b/core/src/test/scala/org/bykn/bosatsu/Gen.scala index b6b043ebe..c4cc59d83 100644 --- a/core/src/test/scala/org/bykn/bosatsu/Gen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/Gen.scala @@ -181,23 +181,61 @@ object Generators { .map { case (ifs, elsec) => IfElse(ifs, elsec)(emptyRegion) } } + def genPattern(depth: Int): Gen[Pattern[String, TypeRef]] = { + val recurse = Gen.lzy(genPattern(depth - 1)) + val genVar = lowerIdent.map(Pattern.Var(_)) + val genWild = Gen.const(Pattern.WildCard) + val genLitPat = genLit.map(Pattern.Literal(_)) + + if (depth <= 0) Gen.oneOf(genVar, genWild, genLitPat) + else { + val genTyped = Gen.zip(recurse, typeRefGen) + .map { case (p, t) => Pattern.Annotation(p, t) } + + val genStruct = for { + nm <- upperIdent + cnt <- Gen.choose(0, 6) + args <- Gen.listOfN(cnt, recurse) + } yield Pattern.PositionalStruct(nm, args) + + def makeOneSplice(ps: List[Either[Option[String], Pattern[String, TypeRef]]]) = { + val sz = ps.size + if (sz == 0) Gen.const(ps) + else Gen.choose(0, sz - 1).flatMap { idx => + val splice = Gen.oneOf( + Gen.const(Left(None)), + lowerIdent.map { v => Left(Some(v)) }) + + splice.map { v => ps.updated(idx, v) } + } + } + + val genListItem: Gen[Either[Option[String], Pattern[String, TypeRef]]] = + recurse.map(Right(_)) + + val genList = Gen.choose(0, 5) + .flatMap(Gen.listOfN(_, genListItem)) + .flatMap { ls => + Gen.oneOf(true, false) + .flatMap { + case true => Gen.const(ls) + case false => makeOneSplice(ls) + } + } + .map(Pattern.ListPat(_)) + + Gen.oneOf(genVar, genWild, genLitPat, genStruct, genList /*, genTyped */) + } + } + def matchGen(bodyGen: Gen[Declaration]): Gen[Declaration.Match] = { import Declaration._ val padBody = optIndent(bodyGen) - val genPattern = for { - nm <- upperIdent - cnt <- Gen.choose(0, 6) - args <- Gen.listOfN(cnt, Gen.option(lowerIdent)) - argPat = args.map { - case None => Pattern.WildCard - case Some(v) => Pattern.Var(v) - } - } yield Pattern.PositionalStruct(nm, argPat) val genCase: Gen[(Pattern[String, TypeRef], OptIndent[Declaration])] = - Gen.zip(genPattern, padBody) + Gen.zip(genPattern(3), padBody) for { cnt <- Gen.choose(1, 2) @@ -206,20 +244,25 @@ object Generators { } yield Match(expr, cases)(emptyRegion) } - def genDeclaration(depth: Int): Gen[Declaration] = { - import Declaration._ - + val genLit: Gen[Lit] = { val str = for { q <- Gen.oneOf('\'', '"') //str <- Arbitrary.arbitrary[String] str <- lowerIdent // TODO - } yield Literal(Lit.Str(str))(emptyRegion) + } yield Lit.Str(str) + + val bi = Arbitrary.arbitrary[BigInt].map { bi => Lit.Integer(bi.bigInteger) } + Gen.oneOf(str, bi) + } + + + def genDeclaration(depth: Int): Gen[Declaration] = { + import Declaration._ val unnested = Gen.oneOf( lowerIdent.map(Var(_)(emptyRegion)), upperIdent.map(Constructor(_)(emptyRegion)), - Arbitrary.arbitrary[BigInt].map { bi => Literal(Lit.Integer(bi.bigInteger))(emptyRegion) }, - str) + genLit.map(Literal(_)(emptyRegion))) val recur = Gen.lzy(genDeclaration(depth - 1)) if (depth <= 0) unnested diff --git a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala index 164de3428..c69552a95 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala @@ -1,6 +1,9 @@ package org.bykn.bosatsu +import cats.Eq import org.scalatest.FunSuite +import org.scalatest.prop.PropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalacheck.Gen import rankn._ @@ -12,18 +15,46 @@ import org.typelevel.paiges.Document class TotalityTest extends FunSuite { import TestParseUtils._ + implicit val generatorDrivenConfig = + PropertyCheckConfiguration(minSuccessful = 20000) + //PropertyCheckConfiguration(minSuccessful = 50) + //PropertyCheckConfiguration(minSuccessful = 5) + val pack = PackageName.parts("Test") def const(t: String): Type = Type.TyConst(Type.Const.Defined(pack, t)) + val tpeFn: String => Type.Const = + { tpe => Type.Const.Defined(pack, tpe) } + + val consFn: String => (PackageName, ConstructorName) = + { cons => (pack, ConstructorName(cons)) } + + def parsedToExpr(pat: Pattern[String, TypeRef]): Pattern[(PackageName, ConstructorName), Type] = + pat + .mapName(consFn) + .mapType(_.toType(tpeFn)) + + val genPattern: Gen[Pattern[(PackageName, ConstructorName), Type]] = + Generators.genPattern(5) + .map(parsedToExpr _) + + def showPat(pat: Pattern[(PackageName, ConstructorName), Type]): String = { + val pat0 = pat.mapName { case (_, ConstructorName(n)) => n } + .mapType { t => TypeRef.fromType(t).get } + Document[Pattern[String, TypeRef]].document(pat0).render(80) + } + def showPatU(pat: Pattern[(PackageName, ConstructorName), Type]): String = + showPat(pat.unbind) + def typeEnvOf(str: String): TypeEnv = Statement.parser.parse(str) match { case Parsed.Success(stmt, idx) => assert(idx == str.length) val prog = Program.fromStatement( pack, - { tpe => Type.Const.Defined(pack, tpe) }, - { cons => (pack, ConstructorName(cons)) }, + tpeFn, + consFn, stmt) prog.types case Parsed.Failure(exp, idx, extra) => @@ -34,30 +65,39 @@ class TotalityTest extends FunSuite { def patterns(str: String): List[Pattern[(PackageName, ConstructorName), Type]] = Pattern.parser.listSyntax.parse(str) match { case Parsed.Success(pats, idx) => - pats.map { pat => - pat - .mapName { n => (pack, ConstructorName(n)) } - .mapType(_.toType { n => Type.Const.Defined(pack, n) }) - } + pats.map(parsedToExpr _) case Parsed.Failure(exp, idx, extra) => fail(s"failed to parse: $str: $exp at $idx in region ${region(str, idx)} with trace: ${extra.traced.trace}") sys.error("could not produce TypeEnv") } - def notTotal(te: TypeEnv, pats: List[Pattern[(PackageName, ConstructorName), Type]]) = + def notTotal(te: TypeEnv, pats: List[Pattern[(PackageName, ConstructorName), Type]], testMissing: Boolean = true): Unit = { TotalityCheck(te).isTotal(pats) match { case Right(res) => assert(!res, pats.toString) case Left(errs) => fail(errs.toString) } + if (testMissing) { + // if we add the missing, it should be total + TotalityCheck(te).missingBranches(pats) match { + case Left(errs) => fail(errs.toString) + case Right(mb) => + // missing branches can't be tight because + // for instance: + // match x: + // 1: foo + // + // is not total, but can only be made total by + // adding a wildcard match, which by itself is total + testTotality(te, pats ::: mb, tight = false) + } + } + } + def testTotality(te: TypeEnv, pats: List[Pattern[(PackageName, ConstructorName), Type]], tight: Boolean = false) = { TotalityCheck(te).missingBranches(pats) match { case Right(res) => - val asStr = res.map { pat => - val pat0 = pat.mapName { case (_, ConstructorName(n)) => n } - .mapType { t => TypeRef.fromType(t).get } - Document[Pattern[String, TypeRef]].document(pat0).render(80) - } + val asStr = res.map(showPat) assert(asStr == Nil) case Left(errs) => fail(errs.toString) } @@ -73,7 +113,7 @@ class TotalityTest extends FunSuite { pats match { case h :: tail if tight => - allButOne(h, tail).foreach(notTotal(te, _)) + allButOne(h, tail).foreach(notTotal(te, _, testMissing = false)) // don't make an infinite loop here case _ => () } } @@ -128,9 +168,9 @@ enum Either: Left(l), Right(r) test("test List matching") { testTotality(TypeEnv.empty, patterns("[[], [h, *tail]]"), tight = true) testTotality(TypeEnv.empty, patterns("[[], [h, *tail], [h0, h1, *tail]]"), tight = true) + testTotality(TypeEnv.empty, patterns("[[], [*tail, _]]"), tight = true) notTotal(TypeEnv.empty, patterns("[[], [h, *tail, _]]")) - notTotal(TypeEnv.empty, patterns("[[], [*tail, _]]")) } test("multiple struct compose") { @@ -151,4 +191,277 @@ enum Either: Left(l), Right(r) testTotality(te, patterns("[[Left(_), *_], [Right(_), *_], [], [_, _, *_]]"), tight = true) testTotality(te, patterns("[Left([]), Left([h, *_]), Right([]), Right([h, *_])]"), tight = true) } + + + test("test intersection") { + val p0 :: p1 :: Nil = patterns("[[*_], [*_, _]]") + TotalityCheck(TypeEnv.empty).intersection(p0, p1) match { + case Left(err) => fail(err.toString) + case Right(intr :: Nil) => assert(p1 == intr) + case Right(many) => fail(s"expected exactly one intersection: ${many.map(showPat)}") + } + + val p2 :: p3 :: Nil = patterns("[[*_], [_, _]]") + TotalityCheck(TypeEnv.empty).intersection(p2, p3) match { + case Left(err) => fail(err.toString) + case Right(intr :: Nil) => assert(p3 == intr) + case Right(many) => fail(s"expected exactly one intersection: ${many.map(showPat)}") + } + } + + test("test some difference examples") { + val tc = TotalityCheck(TypeEnv.empty) + import tc.eqPat.eqv + { + val p0 :: p1 :: Nil = patterns("[[1], [\"foo\", _]]") + tc.difference0(p0, p1) match { + case Left(err) => fail(err.toString) + case Right(diff :: Nil) => assert(eqv(p0, diff)) + case Right(many) => fail(s"expected exactly one difference: ${many.map(showPat)}") + } + } + + { + val p0 :: p1 :: Nil = patterns("[[_, _], [[*foo]]]") + TotalityCheck(TypeEnv.empty).difference0(p1, p0) match { + case Left(err) => fail(err.toString) + case Right(diff :: Nil) => assert(eqv(diff, p1)) + case Right(many) => fail(s"expected exactly one difference: ${many.map(showPat)}") + } + TotalityCheck(TypeEnv.empty).difference0(p0, p1) match { + case Left(err) => fail(err.toString) + case Right(diff :: Nil) => assert(eqv(diff, p0)) + case Right(many) => fail(s"expected exactly one difference: ${many.map(showPat)}") + } + } + } + + test("intersection(a, a) == a") { + def law(p: Pattern[(PackageName, ConstructorName), Type]) = + // this would be better if we could get + // generate random patterns from a sane + // type Env... thats a TODO) + TotalityCheck(TypeEnv.empty) + .intersection(p, p) match { + case Left(_) => () // we can often fail now due to bad patterns + case Right(h :: Nil) => assert(h == p) + case Right(many) => fail(s"expected one intersection, found many: $many") + } + + forAll(genPattern)(law) + + def manualTest(str: String) = { + val a :: Nil = patterns(str) + law(a) + } + + manualTest("[[_, _]]") + manualTest("[[*foo]]") + manualTest("[[*_]]") + } + + test("intersection(a, b) == intersection(b, a)") { + def law( + a: Pattern[(PackageName, ConstructorName), Type], + b: Pattern[(PackageName, ConstructorName), Type]) = { + // this would be better if we could get + // generate random patterns from a sane + // type Env... thats a TODO) + val ab = TotalityCheck(TypeEnv.empty) + .intersection(a, b) + val ba = TotalityCheck(TypeEnv.empty) + .intersection(b, a) + (ab, ba) match { + case (Left(_), Left(_)) => () + case (Left(err), Right(_)) => + fail(s"a = ${showPat(a)} b = ${showPat(b)} ab fails, but ba succeeds: $err") + case (Right(_), Left(err)) => + fail(s"a = ${showPat(a)} b = ${showPat(b)} ba fails, but ab succeeds: $err") + case (Right(ab), Right(ba)) => + assert(ab.map(showPatU) == ba.map(showPatU), s"a = ${showPat(a)} b = ${showPat(b)}") + } + } + + forAll(genPattern, genPattern)(law) + + def manualTest(str: String) = { + val a :: b :: Nil = patterns(str) + law(a, b) + } + + manualTest("[[_, _], [*foo]]") + manualTest("[[*foo], [1, 2, *_]]") + } + + test("if intersection(a, b) = 0, then a - b == a") { + def law( + a: Pattern[(PackageName, ConstructorName), Type], + b: Pattern[(PackageName, ConstructorName), Type]) = { + + // this would be better if we could get + // generate random patterns from a sane + // type Env... thats a TODO) + val tc = TotalityCheck(TypeEnv.empty) + tc.intersection(a, b) match { + case Left(_) => () // we can often fail now due to bad patterns + case Right(Nil) => + tc.difference0(a, b) match { + case Left(err) => () // due to our generators, we fail a lot + case Right(h :: Nil) => + assert(tc.eqPat.eqv(h, a), s"${showPat(h)} != ${showPat(a)}") + case Right(newDiff) => + // our tests are not well typed, in well typed + // code, this shouldn't happen, but it can if the + // right side is total: + tc.isTotal(b::Nil) match { + case Left(_) => () + case Right(true) => () + case Right(false) => + fail(s"a = ${showPat(a)} b = ${showPat(b)}, a n b == 0. expected no diff, found ${newDiff.map(showPat)}") + } + } + case Right(_) => () + } + } + + def manualTest(str: String) = { + val a :: b :: Nil = patterns(str) + law(a, b) + } + + /** + * These are some harder regressions that have caught us in the past + */ + manualTest("[[_, _], [*foo]]") + /* + * the following is trick: + * [_, _] n [[*foo]] == 0 because the left matches 2 item lists and the right only 1 + * + * if we consider it as a product: + * _ x [_] - [*foo] x [] + * + * the product difference formula: + * (a0 x a1) - (b0 x b1) = (a0 - b0) x a1 + (a0 n b0) x (a1 - b1) + * + * suggests: _ - [*foo] = 0 for well typed expressions + * and (_ n [*foo]) = [*foo] or _ depending on the way to write it + * + * then [_] - [] = [_] since they don't overlap + * + * if we write `(_ n [*foo]) as _ we get the right answer + * if we write it as [*foo] we get [[*foo], _] + */ + manualTest("[[_, _], [[*foo]]]") + manualTest("[['foo'], [*foo, [*_]]]") + + //forAll(genPattern, genPattern)(law(_, _)) + } + + test("difference returns distinct values") { + forAll(genPattern, genPattern) { (a, b) => + // this would be better if we could get + // generate random patterns from a sane + // type Env... thats a TODO) + val tc = TotalityCheck(TypeEnv.empty) + tc.difference0(a, b) match { + case Left(_) => () // we can often fail now due to bad patterns + case Right(c) => assert(c == c.distinct) + } + } + } + test("intersection returns distinct values") { + forAll(genPattern, genPattern) { (a, b) => + // this would be better if we could get + // generate random patterns from a sane + // type Env... thats a TODO) + val tc = TotalityCheck(TypeEnv.empty) + tc.intersection(a, b) match { + case Left(_) => () // we can often fail now due to bad patterns + case Right(c) => assert(c == c.distinct) + } + } + } + + test("a - b = c then c - b == c, because we have already removed all of b") { + def law( + a: Pattern[(PackageName, ConstructorName), Type], + b: Pattern[(PackageName, ConstructorName), Type]) = { + // this would be better if we could get + // generate random patterns from a sane + // type Env... thats a TODO) + val tc = TotalityCheck(TypeEnv.empty) + tc.difference0(a, b) match { + case Left(_) => () // we can often fail now due to bad patterns + case Right(c) => + tc.difference(c, b) match { + case Left(err) => () // due to our generators, we fail a lot + case Right(c1) => + // this quadradic, but order independent + val eqList = new Eq[TotalityCheck.Patterns] { + def eqv(a: TotalityCheck.Patterns, b: TotalityCheck.Patterns) = { + (a, b) match { + case (ah :: taila, _) if taila.exists(tc.eqPat.eqv(ah, _)) => + // duplicate, skip it + eqv(taila, b) + case (_, bh :: tailb) if tailb.exists(tc.eqPat.eqv(bh, _)) => + // duplicate, skip it + eqv(a, tailb) + case (ah :: taila, Nil) => false + case (ah :: taila, bh :: tailb) => + b.exists(tc.eqPat.eqv(_, ah)) && + a.exists(tc.eqPat.eqv(_, bh)) && + eqv(taila, tailb) + case (Nil, Nil) => true + case (Nil, _) => false + case (_, Nil) => false + } + } + } + + assert(eqList.eqv(c1, c), + s"${showPat(a)} - (b=${showPat(b)}) = ${c.map(showPat)} - b = ${c1.map(showPat)} diff = ${ + c.map(showPatU).diff(c1.map(showPatU))}" ) + } + } + } + def manualTest(str: String) = { + val a :: b :: Nil = patterns(str) + law(a, b) + } + + manualTest("[[*foo, bar], [baz]]") + //forAll(genPattern, genPattern)(law) + } + + test("a - a = 0") { + forAll(genPattern) { a => + // this would be better if we could get + // generate random patterns from a sane + // type Env... thats a TODO) + val tc = TotalityCheck(TypeEnv.empty) + tc.difference0(a, a) match { + case Left(_) => () // we can often fail now due to bad patterns + case Right(Nil) => succeed + case Right(many) => fail(s"expected empty difference: $many") + } + } + } + + test("missing branches, if added is total") { + val pats = Gen.choose(0, 10).flatMap(Gen.listOfN(_, genPattern)) + + forAll(pats) { pats => + val tc = TotalityCheck(TypeEnv.empty) + tc.missingBranches(pats) match { + case Left(_) => () + case Right(rest) => + tc.missingBranches(pats ::: rest) match { + case Left(err) => fail(err.toString) + case Right(Nil) => succeed + case Right(rest1) => + fail(s"after adding $rest we still need $rest1") + } + } + } + } } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala index bde4be31c..d2255ae73 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala @@ -69,4 +69,20 @@ class TypeTest extends FunSuite { } } + test("if Type.hasNoVars then freeVars is empty") { + forAll(NTypeGen.genDepth03) { t => + if (Type.hasNoVars(t)) assert(Type.freeTyVars(t :: Nil).isEmpty) + else () + } + } + + test("hasNoVars fully recurses") { + forAll(NTypeGen.genDepth03) { t => + val allT = NTypeGen.allTypesIn(t) + val hnv = Type.hasNoVars(t) + + if (hnv) assert(allT.forall(Type.hasNoVars), "hasNoVars == true") + else assert(allT.exists { t => !Type.hasNoVars(t) }, "hasNoVars == false") + } + } } From 457b91b68b31fa48afba20a8a3788c20be3e174e Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Mon, 5 Nov 2018 06:58:54 +0000 Subject: [PATCH 6/8] turn on checks left off, fix another issue --- .../main/scala/org/bykn/bosatsu/TotalityCheck.scala | 10 +++++++++- .../src/test/scala/org/bykn/bosatsu/TotalityTest.scala | 5 +++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala index b2801d71d..c07caae9f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala @@ -110,7 +110,15 @@ case class TotalityCheck(inEnv: TypeEnv) { case (Left(_) :: tail, Right(_) :: _) => val zero = tail val oneOrMore = Right(WildCard) :: lp - (difference0List(zero, rp), difference0List(oneOrMore, rp)).mapN(_ ::: _) + // If the left and right are disjoint, + // this creates a different representation + // of the left + (difference0List(zero, rp), difference0List(oneOrMore, rp)) + .mapN { + case ((z :: Nil), (o :: Nil)) if eqPat.eqv(z, ListPat(zero)) && eqPat.eqv(o, ListPat(oneOrMore)) => + ListPat(lp) :: Nil + case (zz, oo) => zz ::: oo + } case (_, Left(_) :: rtail) if matchesEmpty(rtail) => // this is a total match Right(Nil) diff --git a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala index c69552a95..b8eb104fd 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala @@ -353,8 +353,9 @@ enum Either: Left(l), Right(r) */ manualTest("[[_, _], [[*foo]]]") manualTest("[['foo'], [*foo, [*_]]]") + manualTest("[[*_, [_]], [1]]") - //forAll(genPattern, genPattern)(law(_, _)) + forAll(genPattern, genPattern)(law(_, _)) } test("difference returns distinct values") { @@ -430,7 +431,7 @@ enum Either: Left(l), Right(r) } manualTest("[[*foo, bar], [baz]]") - //forAll(genPattern, genPattern)(law) + forAll(genPattern, genPattern)(law) } test("a - a = 0") { From 8cca0ca47513cc9bfd9c5c8ce582a8392a077d13 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Tue, 6 Nov 2018 09:16:05 +0000 Subject: [PATCH 7/8] fix more bugs --- .../org/bykn/bosatsu/TotalityCheck.scala | 372 +++++++++++------- .../scala/org/bykn/bosatsu/TotalityTest.scala | 46 ++- 2 files changed, 265 insertions(+), 153 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala index c07caae9f..551b7ec3b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala @@ -72,9 +72,13 @@ case class TotalityCheck(inEnv: TypeEnv) { case Right(_) :: _ => false } + /** + * By invariant, we never allow invalid patterns to enter + * this method + */ private def difference0List( lp: List[Either[Option[String], Pattern[Cons, Type]]], - rp: List[Either[Option[String], Pattern[Cons, Type]]]): Res[Patterns] = { + rp: List[Either[Option[String], Pattern[Cons, Type]]]): Res[List[ListPat[Cons, Type]]] = (lp, rp) match { case (Nil, Nil) => // total overlap @@ -108,6 +112,13 @@ case class TotalityCheck(inEnv: TypeEnv) { Right(ListPat(Right(WildCard) :: lp) :: Nil) else Right(ListPat(lp) :: Nil) case (Left(_) :: tail, Right(_) :: _) => + /** + * Note since we only allow a single splice, + * tail has no splices, and is thus finite + * + * This is using the rule: + * [*_, rest] = [rest] | [_, *_, rest] + */ val zero = tail val oneOrMore = Right(WildCard) :: lp // If the left and right are disjoint, @@ -115,7 +126,9 @@ case class TotalityCheck(inEnv: TypeEnv) { // of the left (difference0List(zero, rp), difference0List(oneOrMore, rp)) .mapN { - case ((z :: Nil), (o :: Nil)) if eqPat.eqv(z, ListPat(zero)) && eqPat.eqv(o, ListPat(oneOrMore)) => + case (zz, oo) + if leftIsSuperSet(zz, ListPat(zero)) && + leftIsSuperSet(oo, ListPat(oneOrMore)) => ListPat(lp) :: Nil case (zz, oo) => zz ::: oo } @@ -123,28 +136,15 @@ case class TotalityCheck(inEnv: TypeEnv) { // this is a total match Right(Nil) case (_, Left(_) :: _) => - // if this pattern ends with Left(_) we have - // a hard match problem on our hands. For now, we ban it: - val revRight = rp.reverse - revRight match { - case Left(_) :: tail if !matchesEmpty(tail) => - Left(NonEmptyList.of(MultipleSplicesInPattern(ListPat(rp), inEnv))) - case _ => - // we can make progress: - - // In this branch, the right cannot match - // the empty list, but the left side can - // we could in principle match a finite - // list from either direction, so we reverse - // and try again - difference0List(lp.reverse, revRight) - .map(_.map { - case ListPat(diff) => ListPat(diff.reverse) - case other => sys.error(s"unreachable: list patterns can't difference to non-list: $other") - }) - } + // we know the right can't end in Left since + // it starts with Left and the tail is not empty + // we can make progress: + + difference0List(lp.reverse, rp.reverse) + .map(_.map { + case ListPat(diff) => ListPat(diff.reverse) + }) } - } def difference0(left: Pattern[Cons, Type], right: Pattern[Cons, Type]): Res[Patterns] = (left, right) match { @@ -154,14 +154,17 @@ case class TotalityCheck(inEnv: TypeEnv) { Right(left :: Nil) case (WildCard | Var(_), _) if isTotal(right) == Right(true) => Right(Nil) - case (WildCard, ListPat(rp)) => + case (WildCard, listPat@ListPat(rp)) => // _ is the same as [*_] for well typed expressions - difference0List(Left(None) :: Nil, rp) - case (Var(v), ListPat(rp)) => + checkListPats(listPat :: Nil) *> + difference0List(Left(None) :: Nil, rp) + case (Var(v), listPat@ListPat(rp)) => // v is the same as [*v] for well typed expressions - difference0List(Left(Some(v)) :: Nil, rp) - case (ListPat(lp), ListPat(rp)) => - difference0List(lp, rp) + checkListPats(listPat :: Nil) *> + difference0List(Left(Some(v)) :: Nil, rp) + case (left@ListPat(lp), right@ListPat(rp)) => + checkListPats(left :: right :: Nil) *> + difference0List(lp, rp) case (Literal(_), ListPat(_) | PositionalStruct(_, _)) => Right(left :: Nil) case (ListPat(_), Literal(_) | PositionalStruct(_, _)) => @@ -222,32 +225,44 @@ case class TotalityCheck(inEnv: TypeEnv) { } case (PositionalStruct(_, _), PositionalStruct(_, _)) => Right(left :: Nil) + } + - // case _ => - // // There is no overlap - // Right(left :: Nil): Res[Patterns] + private def checkListPats(pats: List[ListPat[Cons, Type]]): Res[Unit] = { + def hasMultiple(ps: ListPat[Cons, Type]): Boolean = + ps.parts.count { + case Left(_) => true + case Right(_) => false + } > 1 + + pats.filter(hasMultiple) match { + case Nil => Right(()) + case h :: tail => + Left(NonEmptyList(h, tail).map(MultipleSplicesInPattern(_, inEnv))) } + } def intersection( left: Pattern[Cons, Type], - right: Pattern[Cons, Type]): Res[List[Pattern[Cons, Type]]] = + right: Pattern[Cons, Type]): Res[Option[Pattern[Cons, Type]]] = (left, right) match { - case (Var(va), Var(vb)) => Right(List(Var(Ordering[String].min(va, vb)))) - case (WildCard, v) => Right(List(v)) - case (v, WildCard) => Right(List(v)) - case (Var(_), v) => Right(List(v)) - case (v, Var(_)) => Right(List(v)) + case (Var(va), Var(vb)) => Right(Some(Var(Ordering[String].min(va, vb)))) + case (WildCard, v) => Right(Some(v)) + case (v, WildCard) => Right(Some(v)) + case (Var(_), v) => Right(Some(v)) + case (v, Var(_)) => Right(Some(v)) case (Annotation(p, _), t) => intersection(p, t) case (t, Annotation(p, _)) => intersection(t, p) case (Literal(a), Literal(b)) => - if (a == b) Right(List(left)) - else Right(Nil) - case (Literal(_), _) => Right(Nil) - case (_, Literal(_)) => Right(Nil) - case (ListPat(leftL), ListPat(rightL)) => - intersectionList(leftL, rightL) - case (ListPat(_), _) => Right(Nil) - case (_, ListPat(_)) => Right(Nil) + if (a == b) Right(Some(left)) + else Right(None) + case (Literal(_), _) => Right(None) + case (_, Literal(_)) => Right(None) + case (lp@ListPat(leftL), rp@ListPat(rightL)) => + checkListPats(lp :: rp :: Nil) *> + intersectionList(leftL, rightL) + case (ListPat(_), _) => Right(None) + case (_, ListPat(_)) => Right(None) case (PositionalStruct(ln, lps), PositionalStruct(rn, rps)) => if (ln == rn) { val check = for { @@ -255,56 +270,71 @@ case class TotalityCheck(inEnv: TypeEnv) { _ <- checkArity(rn, rps.size, right) } yield () - type ResList[A] = Res[List[A]] - implicit val app = Applicative[Res].compose(Applicative[List]) + type ResOption[A] = Res[Option[A]] + implicit val app = Applicative[Res].compose(Applicative[Option]) check >> - lps.zip(rps).traverse[ResList, Pattern[Cons, Type]] { + lps.zip(rps).traverse[ResOption, Pattern[Cons, Type]] { case (l, r) => intersection(l, r) } .map(_.map(PositionalStruct(ln, _))) } - else Right(Nil) + else Right(None) } - private def intersectionList(leftL: List[ListPatElem], rightL: List[ListPatElem]): Res[Patterns] = { + // invariant: each input has at most 1 splice pattern. This should be checked by callers. + private def intersectionList(leftL: List[ListPatElem], rightL: List[ListPatElem]): Res[Option[ListPat[Cons, Type]]] = { def left = ListPat(leftL) (leftL, rightL) match { case (_, Left(_) :: tail) if matchesEmpty(tail) => // the right hand side is a top value, it can match any list, so intersection with top is // left - Right(left :: Nil) + Right(Some(left)) case (Left(_) :: tail, _) if matchesEmpty(tail) => // the left hand side is a top value, it can match any list, so intersection with top is // right - Right(ListPat(rightL) :: Nil) - case (Nil, Nil) => Right(List(left)) - case (Nil, Right(_) :: _) => Right(Nil) + Right(Some(ListPat(rightL))) + case (Nil, Nil) => Right(Some(left)) + case (Nil, Right(_) :: _) => Right(None) case (Nil, Left(_) :: _) | (Left(_) :: _, Nil) => // the non Nil patterns can't match empty due to the above: - Right(Nil) - case (Right(_) :: _, Nil) => Right(Nil) + Right(None) + case (Right(_) :: _, Nil) => Right(None) case (Right(lh) :: lt, Right(rh) :: rt) => intersection(lh, rh).flatMap { - case Nil => Right(Nil) - case nonEmpty => + case None => Right(None) + case Some(h) => intersectionList(lt, rt) - .map(_.flatMap { - case ListPat(ts) => - // this could create duplicates - nonEmpty - .map { h => ListPat(Right(h) :: ts) } - .distinct - case other => sys.error(s"unreachable: list patterns can't intersect to non-list: $other") + .map(_.map { + case ListPat(ts) => ListPat(Right(h) :: ts) }) } case (Right(lh) :: lt, Left(rh) :: rt) => - // a n (b0 + b1) = (a n b0) + (a n b1) - val zero = rt - val oneOrMore = Right(WildCard) :: rightL - for { - withZ <- intersectionList(leftL, zero) - with0 <- intersectionList(leftL, oneOrMore) - } yield (withZ ::: with0).distinct + // we know rt is not empty, because otherwise it would + // matchEmpty above + // + // if lt does not end with a Left, + // we can intersect the ends, and repeat. + // + // This leaves only the case of + // [a, b... *c] n [*d, e, f]. + // which is [a, b, ... ,*_, e, f...] + lt.lastOption match { + case None => + // we have a singleton list matching at least one after the splice: + // only zero from the splice can match + intersectionList(leftL, rt) + case Some(Right(_)) => + // can reverse and and recurse + intersectionList(leftL.reverse, rightL.reverse) + .map(_.map { + case ListPat(res) => ListPat(res.reverse) + }) + case Some(Left(_)) => + // left side can have infinite right, right side can have infinite left: + val leftSide = leftL.init + val rightSide = rt + Right(Some(ListPat(leftSide ::: (Left(None) :: rightSide)))) + } case (Left(_) :: _, Right(_) :: _) => // intersection is symmetric intersectionList(rightL, leftL) @@ -317,51 +347,104 @@ case class TotalityCheck(inEnv: TypeEnv) { * we can pad the left to be the same size * by adding wildcards, and repeat */ - def hasMultiple(ps: List[ListPatElem]): Boolean = - ps.exists { - case Left(_) => true - case Right(_) => false + /* + * make suffix of rt that lines up with lt + */ + val rtSize = rt.size + val ltSize = lt.size + val padSize = rtSize - ltSize + val (initLt, lastLt) = + if (padSize > 0) { + (List.empty[ListPatElem], List.fill(padSize)(Right(WildCard)) reverse_::: lt) } - - (hasMultiple(lt), hasMultiple(rt)) match { - case (true, true) => - Left(NonEmptyList.of( - MultipleSplicesInPattern(ListPat(rightL), inEnv), - MultipleSplicesInPattern(ListPat(leftL), inEnv))) - case (_, false) => - /* - * make suffix of rt that lines up with lt - */ - val rtSize = rt.size - val ltSize = lt.size - val padSize = rtSize - ltSize - val (initLt, lastLt) = - if (padSize > 0) { - (List.empty[ListPatElem], List.fill(padSize)(Right(WildCard)) reverse_::: lt) - } - else { - (lt.take(-padSize), lt.drop(-padSize)) - } - intersectionList(lastLt, rt) - .map(_.map { - case ListPat(tail) => - val m: ListPatElem = Left(if (a == b) a else None) - ListPat(m :: initLt ::: tail) - case other => sys.error(s"unreachable: list patterns can't intersect to non-list: $other") - }) - case (false, _) => - // intersection is symmetric - intersectionList(rightL, leftL) - } + else { + (lt.take(-padSize), lt.drop(-padSize)) + } + intersectionList(lastLt, rt) + .map(_.map { + case ListPat(tail) => + val m: ListPatElem = Left(if (a == b) a else None) + ListPat(m :: initLt ::: tail) + }) } } + /* + * TODO, this a little weak now, would be great to make this tight and directly + * tested. I think introducing union patterns will force our hand + * + * This is private because it is currently an approximation that sometimes + * give false negatives + */ + private def leftIsSuperSet(superSet: Patterns, subSet: Pattern[Cons, Type]): Boolean = { + // This is true, but doesn't terminate + // superSet match { + // case Nil => false + // case h :: tail => + // difference(subSet, h) match { + // case Left(_) => false + // case Right(newSubs) => + // leftIsSuperSet(tail, newSubs) + // } + // } + def loop(superSet: Patterns, subSet: Pattern[Cons, Type]): Boolean = + (superSet, subSet) match { + case ((WildCard | Var(_)) :: _, _) => true + case (_, Annotation(p, _)) => loop(superSet, p) + case (_, (WildCard | Var(_))) => false // we never call this on a total superset + case ((Literal(a) :: tail), Literal(b)) if a == b => true + case ((Literal(_) :: tail), notLit) => loop(tail, notLit) + case (Annotation(p, _) :: tail, sub) => loop(p :: tail, sub) + case (_, PositionalStruct(psub, partsSub)) => + val partsSize = partsSub.size + val structs = superSet.collect { case PositionalStruct(n, parts) if n == psub => parts } + def toList(p: Patterns): ListPat[Cons, Type] = + ListPat(p.map(Right(_))) + val subListPat = toList(partsSub) + loop(structs.map(toList), subListPat) + case (PositionalStruct(_, _) :: tail, ListPat(_) | Literal(_)) => loop(tail, subSet) + case ((left@ListPat(_)) :: tail, right@ListPat(_)) => + // in case + val nonList = tail.filter { + case ListPat(_) => false + case _ => true + } + val tailLists: List[ListPat[Cons, Type]] = tail.collect { case lp@ListPat(_) => lp } + loop(nonList, right) || listSuper(left :: tailLists, right) + case (Nil, _) => false + } + + loop(superSet, subSet) + } + + + /** + * This is the complex part of this problem + * [] | [_, *_] == [*_] + * [] | [*_, _] == [*_] + * + * we could also concat onto the front or back + */ + private def listSuper(left: List[ListPat[Cons, Type]], right: ListPat[Cons, Type]): Boolean = + left.exists(eqPat.eqv(_, right)) + // case (ListPat(Right(p) :: lrest) :: tail, ListPat(Right(subp) :: subrest)) => + // (listSuper(p :: Nil, subp) && listSuper(ListPat(lrest) :: Nil, ListPat(subrest))) || + // listSuper(tail, subSet) + // case ((lp@ListPat(Left(_) :: lrest)) :: tail, ListPat(Right(_) :: subrest)) => + // // the left can absorb this right + // (listSuper(lp :: Nil, ListPat(subrest))) || + // listSuper(tail, subSet) + // case (ListPat(Right(p) :: suprest) :: tail, ListPat(Left(_) :: subrest)) => + // listSuper(superSet, ListPat(subrest)) + // case (ListPat(Nil) :: tail, ListPat(parts)) => + // parts.isEmpty || listSuper(tail, subSet) + /** * There the list is a tuple or product pattern * the left and right should be the same size and the result will be a list of lists * with the inner having the same size */ - def productDifference( + private def productDifference( zip: List[(Pattern[Cons, Type], Pattern[Cons, Type])] ): Res[List[List[Pattern[Cons, Type]]]] = /* @@ -379,36 +462,55 @@ case class TotalityCheck(inEnv: TypeEnv) { * * Note, if a1 - b1 = a1, this becomes: * ((a0 - b0) + (a0 n b0)) x a1 = a0 x a1 + * + * similarly: a0 - b0 = a0, implies a0 n b0 = 0 + * so, the difference is a0 x a1, or no difference... + * + * note that a0 - b0 <= a0, so if we have a0 - b0 >= a0, we know a0 - b0 = a0 */ zip match { case Nil => Right(Nil) // complete match case (lh, rh) :: tail => type Result = Res[List[List[Pattern[Cons, Type]]]] - val headDiff: Result = - difference0(lh, rh).map(_.map(_ :: tail.map(_._1))) - - val tailDiff: Result = - intersection(lh, rh).flatMap { - case Nil => - // we don't need to recurse on the rest - Right(Nil) - case nonEmpty => - productDifference(tail).map { pats => - nonEmpty.flatMap { intr => - pats.map(intr :: _) - } + + val headDiff = difference0(lh, rh) + + def noDiffResult: List[Patterns] = zip.map(_._1) :: Nil + + headDiff.right.flatMap { + case noDiff if leftIsSuperSet(noDiff, lh) => + Right(noDiffResult) + case hd => + val tailDiff: Result = + intersection(lh, rh).flatMap { + case None => + // we don't need to recurse on the rest + Right(Nil) + case Some(intr) => + productDifference(tail).map { pats => + pats.map(intr :: _) + } } - } - tailDiff match { - case Right((intr :: td) :: Nil) - if tail.zip(td).forall { case ((a1, _), d) => eqPat.eqv(a1, d) } => - // this is the rule that if the rest has no diff, the first - // part has no diff - // not needed for correctness, but useful for normalizing - Right(zip.map(_._1) :: Nil) - case _ => - (headDiff, tailDiff).mapN(_ ::: _) + def productAsList(prod: List[Pattern[Cons, Type]]): Pattern[Cons, Type] = + ListPat(prod.map(Right(_))) + + tailDiff.map { union => + // we have already attached on each inner list + val unionAsList = union.map { t => productAsList(t.tail) } + val tailProd = productAsList(tail.map(_._1)) + + if (leftIsSuperSet(unionAsList, tailProd)) { + // this is the rule that if the rest has no diff, the first + // part has no diff + // not needed for correctness, but useful for normalizing + noDiffResult + } + else { + val headDiffWithRest = hd.map(_ :: tail.map(_._1)) + headDiffWithRest ::: union + } + } } } @@ -448,6 +550,11 @@ case class TotalityCheck(inEnv: TypeEnv) { } } + /** + * recursively replace as much as possible with Wildcard + * This should match exactly the same set for the same type as + * the previous pattern, without any binding names + */ def normalizePattern(p: Pattern[Cons, Type]): Pattern[Cons, Type] = isTotal(p) match { case Right(true) => WildCard @@ -470,7 +577,6 @@ case class TotalityCheck(inEnv: TypeEnv) { PositionalStruct(n, params.map(normalizePattern)) } } - /** * This tells if two patterns for the same type * would match the same values diff --git a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala index b8eb104fd..fc0594f06 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala @@ -16,7 +16,7 @@ class TotalityTest extends FunSuite { import TestParseUtils._ implicit val generatorDrivenConfig = - PropertyCheckConfiguration(minSuccessful = 20000) + PropertyCheckConfiguration(minSuccessful = 50000) //PropertyCheckConfiguration(minSuccessful = 50) //PropertyCheckConfiguration(minSuccessful = 5) @@ -197,15 +197,15 @@ enum Either: Left(l), Right(r) val p0 :: p1 :: Nil = patterns("[[*_], [*_, _]]") TotalityCheck(TypeEnv.empty).intersection(p0, p1) match { case Left(err) => fail(err.toString) - case Right(intr :: Nil) => assert(p1 == intr) - case Right(many) => fail(s"expected exactly one intersection: ${many.map(showPat)}") + case Right(Some(intr)) => assert(p1 == intr) + case Right(None) => fail("expected exactly one intersection") } val p2 :: p3 :: Nil = patterns("[[*_], [_, _]]") TotalityCheck(TypeEnv.empty).intersection(p2, p3) match { case Left(err) => fail(err.toString) - case Right(intr :: Nil) => assert(p3 == intr) - case Right(many) => fail(s"expected exactly one intersection: ${many.map(showPat)}") + case Right(Some(intr)) => assert(p3 == intr) + case Right(None) => fail("expected exactly one intersection") } } @@ -244,8 +244,8 @@ enum Either: Left(l), Right(r) TotalityCheck(TypeEnv.empty) .intersection(p, p) match { case Left(_) => () // we can often fail now due to bad patterns - case Right(h :: Nil) => assert(h == p) - case Right(many) => fail(s"expected one intersection, found many: $many") + case Right(Some(h)) => assert(h == p) + case Right(None) => fail(s"expected one intersection, found none") } forAll(genPattern)(law) @@ -304,7 +304,7 @@ enum Either: Left(l), Right(r) val tc = TotalityCheck(TypeEnv.empty) tc.intersection(a, b) match { case Left(_) => () // we can often fail now due to bad patterns - case Right(Nil) => + case Right(None) => tc.difference0(a, b) match { case Left(err) => () // due to our generators, we fail a lot case Right(h :: Nil) => @@ -354,6 +354,11 @@ enum Either: Left(l), Right(r) manualTest("[[_, _], [[*foo]]]") manualTest("[['foo'], [*foo, [*_]]]") manualTest("[[*_, [_]], [1]]") + /* + * This is hard because they are orthogonal, one is list of 2, one is a list of three, + * but the first element has a difference + */ + manualTest("[[[cvspypdahs, *_], ['jnC']], [[*_, 5921457613766301145, 'j'], p, bmhvhs]]") forAll(genPattern, genPattern)(law(_, _)) } @@ -370,18 +375,6 @@ enum Either: Left(l), Right(r) } } } - test("intersection returns distinct values") { - forAll(genPattern, genPattern) { (a, b) => - // this would be better if we could get - // generate random patterns from a sane - // type Env... thats a TODO) - val tc = TotalityCheck(TypeEnv.empty) - tc.intersection(a, b) match { - case Left(_) => () // we can often fail now due to bad patterns - case Right(c) => assert(c == c.distinct) - } - } - } test("a - b = c then c - b == c, because we have already removed all of b") { def law( @@ -465,4 +458,17 @@ enum Either: Left(l), Right(r) } } } + + test("missing branches are distinct") { + val pats = Gen.choose(0, 10).flatMap(Gen.listOfN(_, genPattern)) + + forAll(pats) { pats => + val tc = TotalityCheck(TypeEnv.empty) + tc.missingBranches(pats) match { + case Left(_) => () + case Right(rest) => + assert(rest == rest.distinct) + } + } + } } From 6ad9ddc8f22766675dc07afe44c8ca8d97c99b58 Mon Sep 17 00:00:00 2001 From: Oscar Boykin Date: Tue, 6 Nov 2018 09:25:05 +0000 Subject: [PATCH 8/8] fix doc error --- core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala index 551b7ec3b..a41cfba70 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala @@ -112,7 +112,7 @@ case class TotalityCheck(inEnv: TypeEnv) { Right(ListPat(Right(WildCard) :: lp) :: Nil) else Right(ListPat(lp) :: Nil) case (Left(_) :: tail, Right(_) :: _) => - /** + /* * Note since we only allow a single splice, * tail has no splices, and is thus finite *