Permalink
Browse files

First pass at implementing macros for typeclass operators in Spire.

This is a gigantic commit. It has some other necessary restructuring mixed in
as well, which probably should have been pulled out into its own commit. This
was mostly just to demonstrate that the approch works, and to try to get the
2.10.0 branch a bit more up-to-date.
  • Loading branch information...
1 parent bd0fea7 commit bab8af5eac6e70b961da5f5b0138f560f7a1b782 @non committed Jul 3, 2012
View
@@ -9,3 +9,8 @@
9. Match up concrete and generic APIs
10. Get compiler plugin for inlining implicits/ops working
11. Unify operators (e.g. /~, %, /%, **, ~^, etc).
+
+=======
+
+1. type class standardization plus guidelines
+2.
View
@@ -27,11 +27,11 @@ object MyBuild extends Build {
lazy val spire = Project("spire", file("."))
- lazy val examples = Project("examples", file("examples"))
+ lazy val examples = Project("examples", file("examples")).
dependsOn (spire)
- lazy val benchmark: Project = Project("benchmark", file("benchmark"))
- settings (benchmarkSettings: _*)
+ lazy val benchmark: Project = Project("benchmark", file("benchmark")).
+ settings (benchmarkSettings: _*).
dependsOn (spire)
def benchmarkSettings = Seq(
@@ -1,7 +1,9 @@
package spire.algebra
import spire.math._
+import spire.macros._
+import language.experimental.macros
import scala.{specialized => spec}
import scala.math.{abs, ceil, floor}
@@ -12,9 +14,9 @@ trait EuclideanRing[@spec(Int,Long,Float,Double) A] extends Ring[A] {
}
final class EuclideanRingOps[@spec(Int,Long,Float,Double) A](lhs:A)(implicit ev:EuclideanRing[A]) {
- def /~(rhs:A) = ev.quot(lhs, rhs)
- def %(rhs:A) = ev.mod(lhs, rhs)
- def /%(rhs:A) = (ev.quot(lhs, rhs), ev.mod(lhs, rhs))
+ def /~(rhs:A) = macro Macros.quot[A]
+ def %(rhs:A) = macro Macros.mod[A]
+ def /%(rhs:A) = macro Macros.quotmod[A]
}
object EuclideanRing {
@@ -1,7 +1,9 @@
package spire.algebra
import spire.math._
+import spire.macros._
+import language.experimental.macros
import scala.{specialized => spec}
trait Field[@spec(Int,Long,Float,Double) A] extends EuclideanRing[A] {
@@ -17,8 +19,8 @@ trait Field[@spec(Int,Long,Float,Double) A] extends EuclideanRing[A] {
}
final class FieldOps[@spec(Int,Long,Float,Double) A](lhs:A)(implicit ev:Field[A]) {
- def /(rhs:A) = ev.div(lhs, rhs)
- def isWhole = ev.isWhole(lhs)
+ def /(rhs:A) = macro Macros.div[A]
+ def isWhole() = macro Macros.isWhole[A]
}
object Field {
@@ -1,7 +1,9 @@
package spire.algebra
import spire.math._
+import spire.macros._
+import language.experimental.macros
import scala.{specialized => spec, math => mth}
import java.math.MathContext
@@ -21,7 +23,6 @@ trait NRoot[@spec(Double,Float,Int,Long) A] {
def sqrt(a: A): A = nroot(a, 2)
}
-
/**
* A type class for `EuclideanRing`s with `NRoot`s as well. Since the base
* requirement is only a `EuclideanRing`, we can provide instances for `Int`,
@@ -71,8 +72,10 @@ object FieldWithNRoot {
final class NRootOps[@spec(Double, Float, Int, Long) A](lhs: A)(implicit n: NRoot[A]) {
- def nroot(k: Int): A = n.nroot(lhs, k)
- def sqrt: A = n.sqrt(lhs)
+ //def nroot(k: Int): A = n.nroot(lhs, k)
+ //def sqrt: A = n.sqrt(lhs)
+ def nroot(rhs: Int): A = macro Macros.nroot[A]
+ def sqrt(): A = macro Macros.sqrt[A]
}
@@ -173,6 +176,9 @@ trait BigIntIsNRoot extends NRoot[BigInt] {
object NRoot {
+ @inline final def apply[A](implicit ev:NRoot[A]) = ev
+
+ implicit object BigIntIsNRoot extends BigIntIsNRoot
/**
* This will return the largest integer that meets some criteria. Specifically,
@@ -287,5 +293,3 @@ object NRoot {
BigDecimal(unscaled, newscale, ctxt)
}
}
-
-
@@ -1,11 +1,13 @@
package spire.algebra
-import spire.math._
-
import annotation.tailrec
+import language.experimental.macros
import scala.{specialized => spec}
import scala.math.{abs, ceil, floor, pow => mpow}
+import spire.math._
+import spire.macros._
+
/**
* Ring represents a set (A) that is a group over addition (+) and a monoid
* over multiplication (*).
@@ -48,14 +50,14 @@ trait Ring[@spec(Int,Long,Float,Double) A] {
}
final class RingOps[@spec(Int,Long,Float,Double) A](lhs:A)(implicit ev:Ring[A]) {
- def unary_- = ev.negate(lhs)
+ def unary_-() = macro Macros.negate[A]
- def -(rhs:A) = ev.minus(lhs, rhs)
- def +(rhs:A) = ev.plus(lhs, rhs)
- def *(rhs:A) = ev.times(lhs, rhs)
+ def -(rhs:A) = macro Macros.minus[A]
+ def +(rhs:A) = macro Macros.plus[A]
+ def *(rhs:A) = macro Macros.times[A]
- def pow(rhs:Int) = ev.pow(lhs, rhs)
- def **(rhs:Int) = ev.pow(lhs, rhs)
+ def pow(rhs:Int) = macro Macros.pow[A]
+ def **(rhs:Int) = macro Macros.pow[A]
}
object Ring {
@@ -1,6 +1,8 @@
package spire.algebra
import spire.math.Eq
+import spire.macros._
+import language.experimental.macros
trait Semigroup[A] {
def op(x:A, y:A): A
@@ -11,5 +13,5 @@ object Semigroup {
}
final class SemigroupOps[A](lhs:A)(implicit ev:Semigroup[A]) {
- def |+|(rhs:A) = ev.op(lhs, rhs)
+ def |+|(rhs:A) = macro Macros.op[A]
}
@@ -1,7 +1,9 @@
package spire.algebra
-import spire.math.{Fractional, Trig, Order, Real, Rational, Complex}
+import spire.math._
+import spire.macros._
+import language.experimental.macros
import java.lang.{Math => mth}
import scala.{ specialized => spec }
@@ -37,9 +39,12 @@ object Signed extends SignedLow {
}
final class SignedOps[@spec(Double, Float, Int, Long) A](lhs: A)(implicit s: Signed[A]) {
- def abs: A = s.abs(lhs)
- def sign: Sign = s.sign(lhs)
- def signum: Int = s.signum(lhs)
+ //def abs: A = s.abs(lhs)
+ //def sign: Sign = s.sign(lhs)
+ //def signum: Int = s.signum(lhs)
+ def abs(): A = macro Macros.abs[A]
+ def sign(): Sign = macro Macros.sign[A]
+ def signum(): Int = macro Macros.signum[A]
}
class OrderedRingIsSigned[A](implicit o:Order[A], r:Ring[A]) extends Signed[A] {
@@ -0,0 +1,109 @@
+package spire.macros
+
+import language.implicitConversions
+import language.higherKinds
+import language.experimental.macros
+import scala.{specialized => spec}
+import scala.reflect.makro.Context
+
+import spire.math._
+import spire.algebra._
+
+object Macros {
+ // eq
+ def eqv[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "eqv")
+ def neqv[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "neqv")
+
+ // order
+ def gt[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "gt")
+ def gteqv[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "gteqv")
+ def lt[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "lt")
+ def lteqv[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "lteqv")
+ def compare[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Int](c)(rhs, "compare")
+ def min[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "min")
+ def max[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "max")
+
+ // ring
+ def negate[A](c:Context)() = Ops.unop[A](c)("negate")
+ def plus[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "plus")
+ def times[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "times")
+ def minus[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "minus")
+ def pow[A](c:Context)(rhs:c.Expr[Int]) = Ops.binop[Int, A](c)(rhs, "pow")
+
+ // euclidean ring
+ def quot[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "quot")
+ def mod[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "mod")
+ def quotmod[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, (A, A)](c)(rhs, "quotmod")
+
+ // field
+ def div[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "div")
+ def isWhole[A](c:Context)() = Ops.unop[Boolean](c)("isWhole")
+
+ // fractional
+ def ceil[A](c:Context)() = Ops.unop[A](c)("ceil")
+ def floor[A](c:Context)() = Ops.unop[A](c)("floor")
+
+ // nroot
+ def nroot[A](c:Context)(rhs:c.Expr[Int]) = Ops.binop[Int, A](c)(rhs, "nroot")
+ def sqrt[A](c:Context)() = Ops.unop[A](c)("sqrt")
+
+ // semigroup
+ def op[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "op")
+
+ // signed
+ def abs[A](c:Context)() = Ops.unop[A](c)("abs")
+ def sign[A](c:Context)() = Ops.unop[Sign](c)("sign")
+ def signum[A](c:Context)() = Ops.unop[Int](c)("signum")
+
+ // convertable
+ def toByte[A](c:Context)() = Ops.unop[Byte](c)("toByte")
+ def toShort[A](c:Context)() = Ops.unop[Short](c)("toShort")
+ def toInt[A](c:Context)() = Ops.unop[Int](c)("toInt")
+ def toLong[A](c:Context)() = Ops.unop[Long](c)("toLong")
+ def toFloat[A](c:Context)() = Ops.unop[Float](c)("toFloat")
+ def toDouble[A](c:Context)() = Ops.unop[Double](c)("toDouble")
+ def toBigInt[A](c:Context)() = Ops.unop[BigInt](c)("toBigInt")
+ def toBigDecimal[A](c:Context)() = Ops.unop[BigDecimal](c)("toBigDecimal")
+ def toRational[A](c:Context)() = Ops.unop[Rational](c)("toRational")
+}
+
+/**
+ * This trait has some nice methods for working with implicit Ops classes.
+ */
+object Ops {
+ /**
+ * Given context, this method pulls 'evidence' and 'lhs' values out of
+ * instantiations of implicit -Ops classes. For instance,
+ *
+ * Given "new FooOps(x)(ev)", this method returns (ev, x).
+ */
+ def unpack[T[_], A](c:Context) = {
+ import c.universe._
+ c.prefix.tree match {
+ case Apply(Apply(TypeApply(_, _), List(x)), List(ev)) => (ev, x)
+ case t => sys.error("bad tree: %s" format t)
+ }
+ }
+
+ /**
+ * Given context and a method name, this method rewrites the tree to call the
+ * given method with the lhs parameter. This is useful when defining unary
+ * operators as macros, for instance: !x, -x.
+ */
+ def unop[R](c:Context)(name:String):c.Expr[R] = {
+ import c.universe._
+ val (ev, x) = unpack(c)
+ c.Expr[R](Apply(Select(ev, name), List(x)))
+ }
+
+ /**
+ * Given context, an expression, and a method name, this method rewrites the
+ * tree to call the given method with the lhs and rhs parameters. This is
+ * useful when defining binary operators as macros, for instance: x * y.
+ */
+ def binop[A, R](c:Context)(y:c.Expr[A], name:String):c.Expr[R] = {
+ import c.universe._
+ val (ev, x) = unpack(c)
+ c.Expr[R](Apply(Select(ev, name), List(x, y.tree)))
+ }
+}
@@ -39,23 +39,23 @@ extends ScalaNumber with ScalaNumericConversions with Serializable {
// ugh, ScalaNumericConversions ghetto
//
// maybe complex numbers are too different...
- def doubleValue = real.toDouble
- def floatValue = real.toFloat
- def longValue = real.toLong
- def intValue = real.toInt
- def isWhole = (f.fromInt(real.toInt) == real) && (imag == f.zero)
+ def doubleValue = f.toDouble(real)
+ def floatValue = f.toFloat(real)
+ def longValue = f.toLong(real)
+ def intValue = f.toInt(real)
+ def isWhole = (f.eqv(f.fromInt(f.toInt(real)), real)) && (f.eqv(imag, f.zero))
def signum: Int = f.compare(real, f.zero)
def underlying = (real, imag)
- def complexSignum = if (abs == f.zero) {
+ def complexSignum = if (f.eqv(abs, f.zero)) {
Complex.zero
} else {
this / Complex(abs, f.zero)
}
override def hashCode: Int = {
- if (isReal && real.isWhole &&
- real <= f.fromInt(Int.MaxValue) &&
- real >= f.fromInt(Int.MinValue)) real.toInt.##
+ if (isReal && f.isWhole(real) &&
+ f.lteqv(real, f.fromInt(Int.MaxValue)) &&
+ f.gteqv(real, f.fromInt(Int.MinValue))) f.toInt(real).##
else 19 * real.## + 41 * imag.##
}
@@ -71,7 +71,7 @@ extends ScalaNumber with ScalaNumericConversions with Serializable {
// ugh, specialized lazy vals don't work very well
//lazy val abs: T = f.sqrt(real * real + imag * imag)
//lazy val arg: T = t.atan2(imag, real)
- def abs: T = f.sqrt(real * real + imag * imag)
+ def abs: T = f.sqrt(f.plus(f.times(real, real), f.times(imag, imag)))
def arg: T = t.atan2(imag, real)
def conjugate = Complex(real, f.negate(imag))
@@ -95,8 +95,8 @@ extends ScalaNumber with ScalaNumericConversions with Serializable {
f.plus(f.times(imag, b.real), f.times(real, b.imag)))
def /(b:Complex[T]) = {
- val abs_breal = b.real.abs
- val abs_bimag = b.imag.abs
+ val abs_breal = f.abs(b.real)
+ val abs_bimag = f.abs(b.imag)
if (f.gteqv(abs_breal, abs_bimag)) {
if (f.eqv(abs_breal, f.zero)) throw new Exception("/ by zero")
@@ -143,12 +143,15 @@ extends ScalaNumber with ScalaNumericConversions with Serializable {
// TODO: is adding frac**frac reasonable? if not, we won't be able to do
// this without something hacky like the below.
// TODO: we also need log and exp on Field, Fractional, or Trig.
- val len = f.fromDouble(math.pow(abs.toDouble, b.real.toDouble) / exp((arg * b.imag).toDouble))
- val phase = f.fromDouble(arg.toDouble * b.real.toDouble + log(abs.toDouble) * b.imag.toDouble)
+ val len = f.fromDouble(
+ math.pow(f.toDouble(abs), f.toDouble(b.real)) / exp(f.toDouble(f.times(arg, b.imag)))
+ )
+ val phase = f.fromDouble(f.toDouble(arg) * f.toDouble(b.real) +
+ log(f.toDouble(abs)) * f.toDouble(b.imag))
Complex.polar(len, phase)
} else {
- val len = f.fromDouble(math.pow(abs.toDouble, b.real.toDouble))
+ val len = f.fromDouble(math.pow(f.toDouble(abs), f.toDouble(b.real)))
val phase = f.times(arg, b.real)
Complex.polar(len, phase)
}
@@ -237,7 +240,7 @@ object FastComplex {
final def conjugate(d:Long):Long = encode(real(d), -imag(d))
// see if the complex number is a whole value
- final def isWhole(d:Long):Boolean = real(d).isWhole && imag(d).isWhole
+ final def isWhole(d:Long):Boolean = real(d) % 1.0F == 0.0F && imag(d) % 1.0F == 0.0F
// get the sign of the complex number
final def signum(d:Long):Int = real(d) compare 0.0F
Oops, something went wrong.

0 comments on commit bab8af5

Please sign in to comment.