Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize compile-time linear arithmetic ops in comparer #16138

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
160 changes: 160 additions & 0 deletions compiler/src/dotty/tools/dotc/core/CompiletimeOpsComparer.scala
@@ -0,0 +1,160 @@
package dotty.tools.dotc.core

import Types.*, Contexts.*, Symbols.*, Constants.*, Definitions.*,
Denotations.*, Decorators.*, Names.*, StdNames.*, Periods.*

abstract class CompiletimeOpsComparer[N <: Matchable]:
def moduleClass(using Context): Symbol
def addType(using Context): Type
def multiplyType(using Context): Type
def zero: N
def one: N
def minusOne: N
def add(x: N, y: N): N
def multiply(x: N, y: N): N
def isN(x: Any): Boolean
def toN(x: Any): N

def equiv(a: Type, b: Type)(using Context) = (a, b) match
case (_, Op(_, _)) | (Op(_, _), _) =>
sumFromTypeNormalizedCached(a) == sumFromTypeNormalizedCached(b)
case _ => false

def isSingletonOp(tp: Type)(using Context): Boolean = tp match
case Op(_, args) => args.forall(isSingletonOp)
case tv: TypeVar if tv.isInstantiated => isSingletonOp(tv.underlying)
case tp => tp.isStable

object Op:
def unapply(tp: Type)(using Context): Option[(Name, List[Type])] = tp match
case AppliedType(tycon: TypeRef, args)
if tycon.symbol.denot != SymDenotations.NoDenotation && tycon.symbol.owner == moduleClass =>
Some((tycon.symbol.name, args))
case _ => None

val minusOneProd = Product(Nil, minusOne)

def negate(sum: Sum)(using Context) = Sum(sum.terms.map(_ * minusOneProd))

val sumFromTypeNormalizedCached =
cachedTypeOp(tp => sumFromType(tp).normalized)

def sumFromType(tp: Type)(using Context): Sum =
underlyingSingletonDeep(tp.dealias) match
case ConstantType(Constant(c)) if isN(c) =>
Sum(List(Product(Nil, toN(c))))
case Op(tpnme.Negate, List(x)) =>
negate(sumFromType(x))
case Op(tpnme.Plus, List(x, y)) =>
sumFromType(x) + sumFromType(y)
case Op(tpnme.Minus, List(x, y)) =>
sumFromType(x) + negate(sumFromType(y))
case Op(tpnme.Times, _) =>
Sum(List(productFromType(tp)))
case tp =>
Sum(List(Product(List(tp))))

def productFromType(tp: Type)(using Context): Product =
underlyingSingletonDeep(tp.dealias) match
case ConstantType(Constant(c)) if isN(c) =>
Product(Nil, toN(c))
case Op(tpnme.Times, List(x, y)) =>
productFromType(x) * productFromType(y)
case Op(tpnme.Negate | tpnme.Plus | tpnme.Minus, _) =>
Product(List(sumFromType(tp)))
case tp =>
Product(List(tp))

case class Sum(terms: List[Product] = Nil):
infix def +(that: Sum) =
Sum(terms ++ that.terms)
def normalized(using Context): Sum =
val (singletonTerms, nonSingletonTerms) =
terms
.map(_.normalized)
.partition(_.isSingleton)
val normalizedTerms =
singletonTerms
.groupMapReduce(_.facts)(_.c)(add)
.toList
.filter({
case (_, c) if c == zero => false
case _ => true
})
.map(Product.apply)
.concat(nonSingletonTerms)
.sortBy(_.hashCode())
Sum(normalizedTerms)
def isSingleton(using Context): Boolean =
terms.forall(_.isSingleton)
def show(using Context): String =
terms.map(_.show).mkString(" +! ")

case class Product(facts: List[Sum | Type] = Nil, c: N = one):
infix def *(that: Product)(using Context) =
Product(facts ++ that.facts, multiply(c, that.c))
def normalized(using Context): Product =
val normalizedFacts =
facts
.map[Sum | Type]({
case s: Sum => s.normalized
case tp: Type => tp
})
.sortBy(_.hashCode())
Product(normalizedFacts, c)
def isSingleton(using Context): Boolean =
facts
.forall({
case p: Sum => p.isSingleton
case tp: Type => tp.isStable
})
def show(using Context): String =
facts
.map({
case p: Sum => p.show
case tp: Type => tp.show
})
.mkString(" *! ") + " *! " + c

def underlyingSingletonDeep(tp: Type)(using Context): Type = tp match
case tp: SingletonType if tp.underlying.isStable =>
underlyingSingletonDeep(tp.underlying)
case tv: TypeVar if tv.isInstantiated =>
underlyingSingletonDeep(tv.underlying)
case _ => tp

object IntOpsComparer extends CompiletimeOpsComparer[Int]:
def moduleClass(using Context) = defn.CompiletimeOpsIntModuleClass
def addType(using Context) = defn.CompiletimeOpsInt_Add
def multiplyType(using Context) = defn.CompiletimeOpsInt_Multiply
def zero = 0
def one = 1
def minusOne = -1
def add(x: Int, y: Int) = x + y
def multiply(x: Int, y: Int) = x * y
def isN(x: Any) = x.isInstanceOf[Int]
def toN(x: Any) = x.asInstanceOf[Int]

object LongOpsComparer extends CompiletimeOpsComparer[Long]:
def moduleClass(using Context) = defn.CompiletimeOpsLongModuleClass
def addType(using Context) = defn.CompiletimeOpsLong_Add
def multiplyType(using Context) = defn.CompiletimeOpsLong_Multiply
def zero = 0L
def one = 1L
def minusOne = -1L
def add(x: Long, y: Long) = x + y
def multiply(x: Long, y: Long) = x * y
def isN(x: Any) = x.isInstanceOf[Long]
def toN(x: Any) = x.asInstanceOf[Long]

def cachedTypeOp[T](f: Type => Context ?=> T): Type => Context ?=> T =
val cache = collection.mutable.Map.empty[Type, (Period, T)]
def cached(tp: Type)(using Context) =
if tp.isProvisional then f(tp)
else
val res = cache.updateWith(tp) {
case Some((p, v)) if p == ctx.period => Some((p, v))
case Some(_) | None => Some((ctx.period, f(tp)))
}
res.get._2
cached
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Expand Up @@ -272,7 +272,11 @@ class Definitions {
@tu lazy val CompiletimeOpsPackage: Symbol = requiredPackage("scala.compiletime.ops")
@tu lazy val CompiletimeOpsAnyModuleClass: Symbol = requiredModule("scala.compiletime.ops.any").moduleClass
@tu lazy val CompiletimeOpsIntModuleClass: Symbol = requiredModule("scala.compiletime.ops.int").moduleClass
@tu lazy val CompiletimeOpsInt_Add: Type = CompiletimeOpsIntModuleClass.requiredTypeRef(tpnme.Plus)
@tu lazy val CompiletimeOpsInt_Multiply: Type = CompiletimeOpsIntModuleClass.requiredTypeRef(tpnme.Times)
@tu lazy val CompiletimeOpsLongModuleClass: Symbol = requiredModule("scala.compiletime.ops.long").moduleClass
@tu lazy val CompiletimeOpsLong_Add: Type = CompiletimeOpsLongModuleClass.requiredTypeRef(tpnme.Plus)
@tu lazy val CompiletimeOpsLong_Multiply: Type = CompiletimeOpsLongModuleClass.requiredTypeRef(tpnme.Times)
@tu lazy val CompiletimeOpsFloatModuleClass: Symbol = requiredModule("scala.compiletime.ops.float").moduleClass
@tu lazy val CompiletimeOpsDoubleModuleClass: Symbol = requiredModule("scala.compiletime.ops.double").moduleClass
@tu lazy val CompiletimeOpsStringModuleClass: Symbol = requiredModule("scala.compiletime.ops.string").moduleClass
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Denotations.scala
Expand Up @@ -345,6 +345,9 @@ object Denotations {
info.member(name).requiredSymbol("type", name, this)(_.isType).asType
}

def requiredTypeRef(pname: PreName)(using Context): TypeRef =
requiredType(pname).typeRef

/** The alternative of this denotation that has a type matching `targetType` when seen
* as a member of type `site` and that has a target name matching `targetName`, or
* `NoDenotation` if none exists.
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Expand Up @@ -1387,7 +1387,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
if (defn.isCompiletime_S(tp.tycon.typeSymbol)) compareS(tp, other, fromBelow)
else {
val folded = tp.tryCompiletimeConstantFold
if (fromBelow) recur(other, folded) else recur(folded, other)
def compareFolded = if (fromBelow) recur(other, folded) else recur(folded, other)
compareFolded || IntOpsComparer.equiv(tp, other) || LongOpsComparer.equiv(tp, other)
}
}

Expand Down
9 changes: 9 additions & 0 deletions tests/neg/singleton-ops-int-normalize.scala
@@ -0,0 +1,9 @@
import scala.compiletime.ops.int.*

object Test:
type Pos <: Int
type Neg <: Int

// Non-singleton types are not grouped.
summon[Pos - Pos + Neg =:= Neg] // error
summon[Pos + Pos =:= 2 * Pos] // error
71 changes: 71 additions & 0 deletions tests/pos/singleton-ops-int-normalize.scala
@@ -0,0 +1,71 @@
// Generated from singleton-ops-long-normalize.scala by singleton-ops-int-normalize_make.sh
import scala.compiletime.ops.int.*

object Test:
// Operations with constant arguments are constant-folded.
summon[3 =:= 2 + 1]
summon[1 + 2 =:= 3]

// Non-constant arguments are sorted.
val m: Int = 2
val n: Int = 3
summon[1 + m.type =:= 1 + m.type]
summon[1 + m.type =:= m.type + 1]
summon[1 + m.type + 1 =:= m.type + 2]
summon[m.type + n.type =:= n.type + m.type]
summon[m.type * n.type =:= n.type * m.type]
summon[2 * m.type * n.type =:= 2 * n.type * m.type]

// -x is normalized to -1 * x
summon[m.type - n.type =:= -1 * n.type + m.type]

// Summing x n times is normalized to x * n.
summon[2 * m.type =:= m.type + m.type]
summon[2 * m.type + 2 * m.type =:= m.type + 3 * m.type]
summon[2 * m.type * m.type =:= m.type * 2 * m.type]

// Addition is distributed over multiplication.
//summon[2 * (m.type + n.type) =:= 2 * m.type + 2 * n.type]
//summon[(m.type + n.type) * (m.type + n.type) =:= m.type * m.type + 2 * m.type * n.type + n.type * n.type]

// Works with TermRefs arguments referencing other TermRefs.
type SInt = Int & Singleton
final val x: Int = ???
final val y: x.type + 1 = ???
summon[y.type + 1 =:= x.type + 2]

// Terms are canceled
summon[1 + m.type -1 =:= m.type]
summon[1 + m.type - 1 =:= m.type]
summon[1 + m.type + Negate[1] =:= m.type]
summon[1 + m.type - m.type =:= 1]
summon[1 + m.type + Negate[m.type] =:= 1]

// Val with prefixes are correctly sorted.
object A:
val m: Int = 4
object B:
val m: Int = 5
import A.{m => am}
summon[B.m.type + am.type + m.type + A.m.type =:= m.type + A.m.type + B.m.type + am.type]

// Works with inline transparent def.
transparent inline def f[T <: Int & Singleton](t: T) = t + 5
val a: 9 = f(4)

// Arguments are normalized.
val b: 4 + (10/2) = ???
summon[b.type =:= 9]
val d: Singleton & Int = ???
summon[(10/2) + d.type =:= d.type + 5]

// Non-singleton types are also sorted.
type Pos <: Int
type Neg <: Int
summon[Pos + Neg =:= Neg + Pos]
// (But not grouped; see tests/neg/singleton-ops-int-normalize.scala)

// Non-singleton terms do not prevent singleton ones from being grouped.
def g1[T <: Int](x: T): 1 + T - 1 = x
def g2[T <: Int](x: T): x.type + T - x.type = x
def g3[T <: Int](x: T): T - 0 = x
3 changes: 3 additions & 0 deletions tests/pos/singleton-ops-int-normalize_make.sh
@@ -0,0 +1,3 @@
echo "// Generated from singleton-ops-long-normalize.scala by singleton-ops-int-normalize_make.sh" > singleton-ops-int-normalize.scala

sed -E -e "s/([0-9]+)L/\1/g" -e "s/Long/Int/g" -e "s/long/int/g" singleton-ops-long-normalize.scala >> singleton-ops-int-normalize.scala
70 changes: 70 additions & 0 deletions tests/pos/singleton-ops-long-normalize.scala
@@ -0,0 +1,70 @@
import scala.compiletime.ops.long.*

object Test:
// Operations with constant arguments are constant-folded.
summon[3L =:= 2L + 1L]
summon[1L + 2L =:= 3L]

// Non-constant arguments are sorted.
val m: Long = 2L
val n: Long = 3L
summon[1L + m.type =:= 1L + m.type]
summon[1L + m.type =:= m.type + 1L]
summon[1L + m.type + 1L =:= m.type + 2L]
summon[m.type + n.type =:= n.type + m.type]
summon[m.type * n.type =:= n.type * m.type]
summon[2L * m.type * n.type =:= 2L * n.type * m.type]

// -x is normalized to -1L * x
summon[m.type - n.type =:= -1L * n.type + m.type]

// Summing x n times is normalized to x * n.
summon[2L * m.type =:= m.type + m.type]
summon[2L * m.type + 2L * m.type =:= m.type + 3L * m.type]
summon[2L * m.type * m.type =:= m.type * 2L * m.type]

// Addition is distributed over multiplication.
//summon[2L * (m.type + n.type) =:= 2L * m.type + 2L * n.type]
//summon[(m.type + n.type) * (m.type + n.type) =:= m.type * m.type + 2L * m.type * n.type + n.type * n.type]

// Works with TermRefs arguments referencing other TermRefs.
type SLong = Long & Singleton
final val x: Long = ???
final val y: x.type + 1L = ???
summon[y.type + 1L =:= x.type + 2L]

// Terms are canceled
summon[1L + m.type -1L =:= m.type]
summon[1L + m.type - 1L =:= m.type]
summon[1L + m.type + Negate[1L] =:= m.type]
summon[1L + m.type - m.type =:= 1L]
summon[1L + m.type + Negate[m.type] =:= 1L]

// Val with prefixes are correctly sorted.
object A:
val m: Long = 4L
object B:
val m: Long = 5L
import A.{m => am}
summon[B.m.type + am.type + m.type + A.m.type =:= m.type + A.m.type + B.m.type + am.type]

// Works with inline transparent def.
transparent inline def f[T <: Long & Singleton](t: T) = t + 5L
val a: 9L = f(4L)

// Arguments are normalized.
val b: 4L + (10L/2L) = ???
summon[b.type =:= 9L]
val d: Singleton & Long = ???
summon[(10L/2L) + d.type =:= d.type + 5L]

// Non-singleton types are also sorted.
type Pos <: Long
type Neg <: Long
summon[Pos + Neg =:= Neg + Pos]
// (But not grouped; see tests/neg/singleton-ops-int-normalize.scala)

// Non-singleton terms do not prevent singleton ones from being grouped.
def g1[T <: Long](x: T): 1L + T - 1L = x
def g2[T <: Long](x: T): x.type + T - x.type = x
def g3[T <: Long](x: T): T - 0L = x