Skip to content

Commit

Permalink
First pass at implementing macros for typeclass operators in Spire.
Browse files Browse the repository at this point in the history
This is a gigantic commit. It has some other necessary restructuring mixed in
as well, which probably should have been pulled out into its own commit. This
was mostly just to demonstrate that the approch works, and to try to get the
2.10.0 branch a bit more up-to-date.
  • Loading branch information
non committed Jul 3, 2012
1 parent bd0fea7 commit bab8af5
Show file tree
Hide file tree
Showing 23 changed files with 290 additions and 129 deletions.
5 changes: 5 additions & 0 deletions TODO
Expand Up @@ -9,3 +9,8 @@
9. Match up concrete and generic APIs
10. Get compiler plugin for inlining implicits/ops working
11. Unify operators (e.g. /~, %, /%, **, ~^, etc).

=======

1. type class standardization plus guidelines
2.
6 changes: 3 additions & 3 deletions project/Build.scala
Expand Up @@ -27,11 +27,11 @@ object MyBuild extends Build {

lazy val spire = Project("spire", file("."))

lazy val examples = Project("examples", file("examples"))
lazy val examples = Project("examples", file("examples")).
dependsOn (spire)

lazy val benchmark: Project = Project("benchmark", file("benchmark"))
settings (benchmarkSettings: _*)
lazy val benchmark: Project = Project("benchmark", file("benchmark")).
settings (benchmarkSettings: _*).
dependsOn (spire)

def benchmarkSettings = Seq(
Expand Down
8 changes: 5 additions & 3 deletions src/main/scala/spire/algebra/EuclideanRing.scala
@@ -1,7 +1,9 @@
package spire.algebra

import spire.math._
import spire.macros._

import language.experimental.macros
import scala.{specialized => spec}
import scala.math.{abs, ceil, floor}

Expand All @@ -12,9 +14,9 @@ trait EuclideanRing[@spec(Int,Long,Float,Double) A] extends Ring[A] {
}

final class EuclideanRingOps[@spec(Int,Long,Float,Double) A](lhs:A)(implicit ev:EuclideanRing[A]) {
def /~(rhs:A) = ev.quot(lhs, rhs)
def %(rhs:A) = ev.mod(lhs, rhs)
def /%(rhs:A) = (ev.quot(lhs, rhs), ev.mod(lhs, rhs))
def /~(rhs:A) = macro Macros.quot[A]
def %(rhs:A) = macro Macros.mod[A]
def /%(rhs:A) = macro Macros.quotmod[A]
}

object EuclideanRing {
Expand Down
6 changes: 4 additions & 2 deletions src/main/scala/spire/algebra/Field.scala
@@ -1,7 +1,9 @@
package spire.algebra

import spire.math._
import spire.macros._

import language.experimental.macros
import scala.{specialized => spec}

trait Field[@spec(Int,Long,Float,Double) A] extends EuclideanRing[A] {
Expand All @@ -17,8 +19,8 @@ trait Field[@spec(Int,Long,Float,Double) A] extends EuclideanRing[A] {
}

final class FieldOps[@spec(Int,Long,Float,Double) A](lhs:A)(implicit ev:Field[A]) {
def /(rhs:A) = ev.div(lhs, rhs)
def isWhole = ev.isWhole(lhs)
def /(rhs:A) = macro Macros.div[A]
def isWhole() = macro Macros.isWhole[A]
}

object Field {
Expand Down
14 changes: 9 additions & 5 deletions src/main/scala/spire/algebra/NRoot.scala
@@ -1,7 +1,9 @@
package spire.algebra

import spire.math._
import spire.macros._

import language.experimental.macros
import scala.{specialized => spec, math => mth}
import java.math.MathContext

Expand All @@ -21,7 +23,6 @@ trait NRoot[@spec(Double,Float,Int,Long) A] {
def sqrt(a: A): A = nroot(a, 2)
}


/**
* A type class for `EuclideanRing`s with `NRoot`s as well. Since the base
* requirement is only a `EuclideanRing`, we can provide instances for `Int`,
Expand Down Expand Up @@ -71,8 +72,10 @@ object FieldWithNRoot {


final class NRootOps[@spec(Double, Float, Int, Long) A](lhs: A)(implicit n: NRoot[A]) {
def nroot(k: Int): A = n.nroot(lhs, k)
def sqrt: A = n.sqrt(lhs)
//def nroot(k: Int): A = n.nroot(lhs, k)
//def sqrt: A = n.sqrt(lhs)
def nroot(rhs: Int): A = macro Macros.nroot[A]
def sqrt(): A = macro Macros.sqrt[A]
}


Expand Down Expand Up @@ -173,6 +176,9 @@ trait BigIntIsNRoot extends NRoot[BigInt] {


object NRoot {
@inline final def apply[A](implicit ev:NRoot[A]) = ev

implicit object BigIntIsNRoot extends BigIntIsNRoot

/**
* This will return the largest integer that meets some criteria. Specifically,
Expand Down Expand Up @@ -287,5 +293,3 @@ object NRoot {
BigDecimal(unscaled, newscale, ctxt)
}
}


18 changes: 10 additions & 8 deletions src/main/scala/spire/algebra/Ring.scala
@@ -1,11 +1,13 @@
package spire.algebra

import spire.math._

import annotation.tailrec
import language.experimental.macros
import scala.{specialized => spec}
import scala.math.{abs, ceil, floor, pow => mpow}

import spire.math._
import spire.macros._

/**
* Ring represents a set (A) that is a group over addition (+) and a monoid
* over multiplication (*).
Expand Down Expand Up @@ -48,14 +50,14 @@ trait Ring[@spec(Int,Long,Float,Double) A] {
}

final class RingOps[@spec(Int,Long,Float,Double) A](lhs:A)(implicit ev:Ring[A]) {
def unary_- = ev.negate(lhs)
def unary_-() = macro Macros.negate[A]

def -(rhs:A) = ev.minus(lhs, rhs)
def +(rhs:A) = ev.plus(lhs, rhs)
def *(rhs:A) = ev.times(lhs, rhs)
def -(rhs:A) = macro Macros.minus[A]
def +(rhs:A) = macro Macros.plus[A]
def *(rhs:A) = macro Macros.times[A]

def pow(rhs:Int) = ev.pow(lhs, rhs)
def **(rhs:Int) = ev.pow(lhs, rhs)
def pow(rhs:Int) = macro Macros.pow[A]
def **(rhs:Int) = macro Macros.pow[A]
}

object Ring {
Expand Down
4 changes: 3 additions & 1 deletion src/main/scala/spire/algebra/Semigroup.scala
@@ -1,6 +1,8 @@
package spire.algebra

import spire.math.Eq
import spire.macros._
import language.experimental.macros

trait Semigroup[A] {
def op(x:A, y:A): A
Expand All @@ -11,5 +13,5 @@ object Semigroup {
}

final class SemigroupOps[A](lhs:A)(implicit ev:Semigroup[A]) {
def |+|(rhs:A) = ev.op(lhs, rhs)
def |+|(rhs:A) = macro Macros.op[A]
}
13 changes: 9 additions & 4 deletions src/main/scala/spire/algebra/Signed.scala
@@ -1,7 +1,9 @@
package spire.algebra

import spire.math.{Fractional, Trig, Order, Real, Rational, Complex}
import spire.math._
import spire.macros._

import language.experimental.macros
import java.lang.{Math => mth}
import scala.{ specialized => spec }

Expand Down Expand Up @@ -37,9 +39,12 @@ object Signed extends SignedLow {
}

final class SignedOps[@spec(Double, Float, Int, Long) A](lhs: A)(implicit s: Signed[A]) {
def abs: A = s.abs(lhs)
def sign: Sign = s.sign(lhs)
def signum: Int = s.signum(lhs)
//def abs: A = s.abs(lhs)
//def sign: Sign = s.sign(lhs)
//def signum: Int = s.signum(lhs)
def abs(): A = macro Macros.abs[A]
def sign(): Sign = macro Macros.sign[A]
def signum(): Int = macro Macros.signum[A]
}

class OrderedRingIsSigned[A](implicit o:Order[A], r:Ring[A]) extends Signed[A] {
Expand Down
109 changes: 109 additions & 0 deletions src/main/scala/spire/macros/Ops.scala
@@ -0,0 +1,109 @@
package spire.macros

import language.implicitConversions
import language.higherKinds
import language.experimental.macros
import scala.{specialized => spec}
import scala.reflect.makro.Context

import spire.math._
import spire.algebra._

object Macros {
// eq
def eqv[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "eqv")
def neqv[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "neqv")

// order
def gt[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "gt")
def gteqv[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "gteqv")
def lt[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "lt")
def lteqv[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Boolean](c)(rhs, "lteqv")
def compare[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, Int](c)(rhs, "compare")
def min[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "min")
def max[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "max")

// ring
def negate[A](c:Context)() = Ops.unop[A](c)("negate")
def plus[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "plus")
def times[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "times")
def minus[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "minus")
def pow[A](c:Context)(rhs:c.Expr[Int]) = Ops.binop[Int, A](c)(rhs, "pow")

// euclidean ring
def quot[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "quot")
def mod[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "mod")
def quotmod[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, (A, A)](c)(rhs, "quotmod")

// field
def div[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "div")
def isWhole[A](c:Context)() = Ops.unop[Boolean](c)("isWhole")

// fractional
def ceil[A](c:Context)() = Ops.unop[A](c)("ceil")
def floor[A](c:Context)() = Ops.unop[A](c)("floor")

// nroot
def nroot[A](c:Context)(rhs:c.Expr[Int]) = Ops.binop[Int, A](c)(rhs, "nroot")
def sqrt[A](c:Context)() = Ops.unop[A](c)("sqrt")

// semigroup
def op[A](c:Context)(rhs:c.Expr[A]) = Ops.binop[A, A](c)(rhs, "op")

// signed
def abs[A](c:Context)() = Ops.unop[A](c)("abs")
def sign[A](c:Context)() = Ops.unop[Sign](c)("sign")
def signum[A](c:Context)() = Ops.unop[Int](c)("signum")

// convertable
def toByte[A](c:Context)() = Ops.unop[Byte](c)("toByte")
def toShort[A](c:Context)() = Ops.unop[Short](c)("toShort")
def toInt[A](c:Context)() = Ops.unop[Int](c)("toInt")
def toLong[A](c:Context)() = Ops.unop[Long](c)("toLong")
def toFloat[A](c:Context)() = Ops.unop[Float](c)("toFloat")
def toDouble[A](c:Context)() = Ops.unop[Double](c)("toDouble")
def toBigInt[A](c:Context)() = Ops.unop[BigInt](c)("toBigInt")
def toBigDecimal[A](c:Context)() = Ops.unop[BigDecimal](c)("toBigDecimal")
def toRational[A](c:Context)() = Ops.unop[Rational](c)("toRational")
}

/**
* This trait has some nice methods for working with implicit Ops classes.
*/
object Ops {
/**
* Given context, this method pulls 'evidence' and 'lhs' values out of
* instantiations of implicit -Ops classes. For instance,
*
* Given "new FooOps(x)(ev)", this method returns (ev, x).
*/
def unpack[T[_], A](c:Context) = {
import c.universe._
c.prefix.tree match {
case Apply(Apply(TypeApply(_, _), List(x)), List(ev)) => (ev, x)
case t => sys.error("bad tree: %s" format t)
}
}

/**
* Given context and a method name, this method rewrites the tree to call the
* given method with the lhs parameter. This is useful when defining unary
* operators as macros, for instance: !x, -x.
*/
def unop[R](c:Context)(name:String):c.Expr[R] = {
import c.universe._
val (ev, x) = unpack(c)
c.Expr[R](Apply(Select(ev, name), List(x)))
}

/**
* Given context, an expression, and a method name, this method rewrites the
* tree to call the given method with the lhs and rhs parameters. This is
* useful when defining binary operators as macros, for instance: x * y.
*/
def binop[A, R](c:Context)(y:c.Expr[A], name:String):c.Expr[R] = {
import c.universe._
val (ev, x) = unpack(c)
c.Expr[R](Apply(Select(ev, name), List(x, y.tree)))
}
}
35 changes: 19 additions & 16 deletions src/main/scala/spire/math/Complex.scala
Expand Up @@ -39,23 +39,23 @@ extends ScalaNumber with ScalaNumericConversions with Serializable {
// ugh, ScalaNumericConversions ghetto
//
// maybe complex numbers are too different...
def doubleValue = real.toDouble
def floatValue = real.toFloat
def longValue = real.toLong
def intValue = real.toInt
def isWhole = (f.fromInt(real.toInt) == real) && (imag == f.zero)
def doubleValue = f.toDouble(real)
def floatValue = f.toFloat(real)
def longValue = f.toLong(real)
def intValue = f.toInt(real)
def isWhole = (f.eqv(f.fromInt(f.toInt(real)), real)) && (f.eqv(imag, f.zero))
def signum: Int = f.compare(real, f.zero)
def underlying = (real, imag)
def complexSignum = if (abs == f.zero) {
def complexSignum = if (f.eqv(abs, f.zero)) {
Complex.zero
} else {
this / Complex(abs, f.zero)
}

override def hashCode: Int = {
if (isReal && real.isWhole &&
real <= f.fromInt(Int.MaxValue) &&
real >= f.fromInt(Int.MinValue)) real.toInt.##
if (isReal && f.isWhole(real) &&
f.lteqv(real, f.fromInt(Int.MaxValue)) &&
f.gteqv(real, f.fromInt(Int.MinValue))) f.toInt(real).##
else 19 * real.## + 41 * imag.##
}

Expand All @@ -71,7 +71,7 @@ extends ScalaNumber with ScalaNumericConversions with Serializable {
// ugh, specialized lazy vals don't work very well
//lazy val abs: T = f.sqrt(real * real + imag * imag)
//lazy val arg: T = t.atan2(imag, real)
def abs: T = f.sqrt(real * real + imag * imag)
def abs: T = f.sqrt(f.plus(f.times(real, real), f.times(imag, imag)))
def arg: T = t.atan2(imag, real)

def conjugate = Complex(real, f.negate(imag))
Expand All @@ -95,8 +95,8 @@ extends ScalaNumber with ScalaNumericConversions with Serializable {
f.plus(f.times(imag, b.real), f.times(real, b.imag)))

def /(b:Complex[T]) = {
val abs_breal = b.real.abs
val abs_bimag = b.imag.abs
val abs_breal = f.abs(b.real)
val abs_bimag = f.abs(b.imag)

if (f.gteqv(abs_breal, abs_bimag)) {
if (f.eqv(abs_breal, f.zero)) throw new Exception("/ by zero")
Expand Down Expand Up @@ -143,12 +143,15 @@ extends ScalaNumber with ScalaNumericConversions with Serializable {
// TODO: is adding frac**frac reasonable? if not, we won't be able to do
// this without something hacky like the below.
// TODO: we also need log and exp on Field, Fractional, or Trig.
val len = f.fromDouble(math.pow(abs.toDouble, b.real.toDouble) / exp((arg * b.imag).toDouble))
val phase = f.fromDouble(arg.toDouble * b.real.toDouble + log(abs.toDouble) * b.imag.toDouble)
val len = f.fromDouble(
math.pow(f.toDouble(abs), f.toDouble(b.real)) / exp(f.toDouble(f.times(arg, b.imag)))
)
val phase = f.fromDouble(f.toDouble(arg) * f.toDouble(b.real) +
log(f.toDouble(abs)) * f.toDouble(b.imag))
Complex.polar(len, phase)

} else {
val len = f.fromDouble(math.pow(abs.toDouble, b.real.toDouble))
val len = f.fromDouble(math.pow(f.toDouble(abs), f.toDouble(b.real)))
val phase = f.times(arg, b.real)
Complex.polar(len, phase)
}
Expand Down Expand Up @@ -237,7 +240,7 @@ object FastComplex {
final def conjugate(d:Long):Long = encode(real(d), -imag(d))

// see if the complex number is a whole value
final def isWhole(d:Long):Boolean = real(d).isWhole && imag(d).isWhole
final def isWhole(d:Long):Boolean = real(d) % 1.0F == 0.0F && imag(d) % 1.0F == 0.0F

// get the sign of the complex number
final def signum(d:Long):Int = real(d) compare 0.0F
Expand Down

0 comments on commit bab8af5

Please sign in to comment.