Skip to content

Commit

Permalink
Merge branch 'new-definition-typing' into specializing
Browse files Browse the repository at this point in the history
  • Loading branch information
LPTK authored Oct 4, 2023
2 parents 7edfb6c + 2976832 commit 8018cbb
Show file tree
Hide file tree
Showing 9 changed files with 543 additions and 40 deletions.
2 changes: 1 addition & 1 deletion shared/src/main/scala/mlscript/NuTypeDefs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1434,7 +1434,7 @@ class NuTypeDefs extends ConstraintSolver { self: Typer =>

val body_ty = td.sig match {
case S(sig) =>
typeType(sig)
ctx.nextLevel { implicit ctx: Ctx => typeType(sig) }
case N =>
err(msg"Type alias definition requires a right-hand side", td.toLoc)
}
Expand Down
61 changes: 27 additions & 34 deletions shared/src/main/scala/mlscript/TyperHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -921,11 +921,32 @@ abstract class TyperHelpers { Typer: Typer =>
res.toSortedMap
}

private def childrenMem(m: NuMember): List[ST] = m match {
case NuParam(nme, ty, pub) => ty.lb.toList ::: ty.ub :: Nil
case TypedNuFun(level, fd, ty) => ty :: Nil
private def childrenMem(m: NuMember): IterableOnce[ST] = m match {
case tf: TypedNuFun =>
tf.bodyType :: Nil
case als: TypedNuAls =>
als.tparams.iterator.map(_._2) ++ S(als.body)
case mxn: TypedNuMxn =>
mxn.tparams.iterator.map(_._2) ++
mxn.members.valuesIterator.flatMap(childrenMem) ++
S(mxn.superTy) ++
S(mxn.thisTy)
case cls: TypedNuCls =>
cls.tparams.iterator.map(_._2) ++
cls.params.toList.flatMap(_.flatMap(p => p._2.lb.toList ::: p._2.ub :: Nil)) ++
cls.auxCtorParams.toList.flatMap(_.values) ++
cls.members.valuesIterator.flatMap(childrenMem) ++
S(cls.thisTy) ++
S(cls.sign)
case trt: TypedNuTrt =>
trt.tparams.iterator.map(_._2) ++
trt.members.valuesIterator.flatMap(childrenMem) ++
S(trt.thisTy) ++
S(trt.sign) ++
trt.parentTP.valuesIterator.flatMap(childrenMem)
case p: NuParam =>
p.ty.lb.toList ::: p.ty.ub :: Nil
case TypedNuDummy(d) => Nil
case _ => ??? // TODO
}
def children(includeBounds: Bool): List[SimpleType] = this match {
case tv @ AssignedVariable(ty) => if (includeBounds) ty :: Nil else Nil
Expand All @@ -949,35 +970,7 @@ abstract class TyperHelpers { Typer: Typer =>
case ConstrainedType(cs, und) => cs.flatMap(lu => lu._1 :: lu._2 :: Nil) ::: und :: Nil
case SpliceType(fs) => fs.flatMap{ case L(l) => l :: Nil case R(r) => r.lb.toList ::: r.ub :: Nil}
case OtherTypeLike(tu) =>
// tu.childrenPol(PolMap.neu).map(tp => tp._1)
val ents = tu.implementedMembers.flatMap {
case tf: TypedNuFun =>
tf.bodyType :: Nil
case als: TypedNuAls =>
als.tparams.iterator.map(_._2) ++ S(als.body)
case mxn: TypedNuMxn =>
mxn.tparams.iterator.map(_._2) ++
mxn.members.valuesIterator.flatMap(childrenMem) ++
S(mxn.superTy) ++
S(mxn.thisTy)
case cls: TypedNuCls =>
cls.tparams.iterator.map(_._2) ++
cls.params.toList.flatMap(_.flatMap(p => p._2.lb.toList ::: p._2.ub :: Nil)) ++
cls.auxCtorParams.toList.flatMap(_.values) ++
cls.members.valuesIterator.flatMap(childrenMem) ++
S(cls.thisTy) ++
S(cls.sign) /* ++
S(cls.instanceType) // Not a real child; to remove */
case trt: TypedNuTrt =>
trt.tparams.iterator.map(_._2) ++
trt.members.valuesIterator.flatMap(childrenMem) ++
S(trt.thisTy) ++
S(trt.sign) ++
trt.parentTP.valuesIterator.flatMap(childrenMem)
case p: NuParam =>
p.ty.lb.toList ::: p.ty.ub :: Nil
case TypedNuDummy(d) => Nil
}
val ents = tu.implementedMembers.flatMap(childrenMem)
ents ::: tu.result.toList
}

Expand Down Expand Up @@ -1105,7 +1098,7 @@ abstract class TyperHelpers { Typer: Typer =>
info.result match {
case S(td: TypedNuAls) =>
assert(td.tparams.size === targs.size)
substSyntax(td.body)(td.tparams.lazyZip(targs).map {
subst(td.body, td.tparams.lazyZip(targs).map {
case (tp, ta) => SkolemTag(tp._2)(noProv) -> ta
}.toMap)
case S(td: TypedNuTrt) =>
Expand Down
2 changes: 1 addition & 1 deletion shared/src/main/scala/mlscript/helpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ trait TermImpl extends StatementImpl { self: Term =>
try R(toType_!.withLocOf(this)) catch {
case e: NotAType =>
import Message._
L(ErrorReport(msg"not a recognized type" -> e.trm.toLoc::Nil, newDefs=true)) }
L(ErrorReport(msg"Not a recognized type" -> e.trm.toLoc::Nil, newDefs=true)) }
protected def toType_! : Type = (this match {
case Var(name) if name.startsWith("`") => TypeVar(R(name.tail), N)
case Var(name) if name.startsWith("'") => TypeVar(R(name), N)
Expand Down
2 changes: 1 addition & 1 deletion shared/src/test/diff/basics/Datatypes.fun
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ data type List a of
Nil
Cons (head: a) (tail: List a)
//│ Parsed: data type List(...a) of {Nil; Cons(...'(' {[head: a,]} ')')(...'(' {[tail: List(...a),]} ')')};
//│ ╔══[ERROR] not a recognized type
//│ ╔══[ERROR] Not a recognized type
//│ ║ l.116: Cons (head: a) (tail: List a)
//│ ╙── ^^^^^^
//│ Desugared: type alias List[a] = Nil[a] | Cons[a]
Expand Down
Loading

0 comments on commit 8018cbb

Please sign in to comment.