Skip to content

Commit

Permalink
Merge pull request #6 from non/topic/support-type-bounds
Browse files Browse the repository at this point in the history
Type bounds proof-of-concept.
  • Loading branch information
non committed Feb 18, 2015
2 parents 87faa6d + 519caab commit 9d9a1db
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 24 deletions.
4 changes: 2 additions & 2 deletions build.sbt
Expand Up @@ -4,7 +4,7 @@ organization := "org.spire-math"

version := "0.5.2"

scalaVersion := "2.10.3"
scalaVersion := "2.11.4"

seq(bintrayResolverSettings: _*)

Expand All @@ -25,7 +25,7 @@ scalacOptions in Test <+= (packageBin in Compile) map {
pluginJar => "-Xplugin:" + pluginJar
}

crossScalaVersions := Seq("2.9.3", "2.11.0")
crossScalaVersions := Seq("2.10.4", "2.11.4")

seq(bintrayPublishSettings: _*)

Expand Down
48 changes: 26 additions & 22 deletions src/main/scala/KindProjector.scala
Expand Up @@ -12,6 +12,8 @@ import nsc.symtab.Flags._
import nsc.ast.TreeDSL
import nsc.typechecker

import scala.reflect.NameTransformer

class KindProjector(val global: Global) extends Plugin {
val name = "kind-projector"
val description = "Expand type lambda syntax"
Expand All @@ -23,6 +25,8 @@ class KindRewriter(plugin: Plugin, val global: Global)

import global._

val sp = new StringParser[global.type](global)

val runsAfter = "parser" :: Nil
val phaseName = "kind-projector"

Expand All @@ -41,44 +45,44 @@ class KindRewriter(plugin: Plugin, val global: Global)
val Plus = newTypeName("$plus")
val Minus = newTypeName("$minus")

override def transform(tree: Tree): Tree = {
def rssi(b: String, c: String) =
Select(Select(Ident("_root_"), b), newTypeName(c))

def rssi(b: String, c: String) =
Select(Select(Ident("_root_"), b), newTypeName(c))
val NothingLower = rssi("scala", "Nothing")
val AnyUpper = rssi("scala", "Any")
val DefaultBounds = TypeBoundsTree(NothingLower, AnyUpper)

// Handy way to build the bounds that we'll frequently be using.
def bounds = TypeBoundsTree(rssi("scala", "Nothing"), rssi("scala", "Any"))
override def transform(tree: Tree): Tree = {

// Handy way to make a TypeName from a Name.
def makeTypeName(name: Name) =
newTypeName(name.toString)

// We use this to create type parameters inside our type project, e.g.
// the A in: ({type L[A] = (A, Int) => A})#L.
def makeTypeParam(name: Name) =
def makeTypeParam(name: Name, bounds: TypeBoundsTree = DefaultBounds) =
TypeDef(Modifiers(PARAM), makeTypeName(name), Nil, bounds)

// Like makeTypeParam but with covariance, e.g.
// ({type L[+A] = ... })#L.
def makeTypeParamCo(name: Name) =
def makeTypeParamCo(name: Name, bounds: TypeBoundsTree = DefaultBounds) =
TypeDef(Modifiers(PARAM | COVARIANT), makeTypeName(name), Nil, bounds)

// Like makeTypeParam but with contravariance, e.g.
// ({type L[-A] = ... })#L.
def makeTypeParamContra(name: Name) =
def makeTypeParamContra(name: Name, bounds: TypeBoundsTree = DefaultBounds) =
TypeDef(Modifiers(PARAM | CONTRAVARIANT), makeTypeName(name), Nil, bounds)

// Detects which makeTypeParam* method to call based on name.
// Names like +A are covariant, names like -A are contravariant,
// all others are invariant.
def makeTypeParamFromName(name: Name) =
if (name.startsWith("$plus")) {
makeTypeParamCo(newTypeName(name.toString.substring(5)))
} else if (name.startsWith("$minus")) {
makeTypeParamContra(newTypeName(name.toString.substring(6)))
} else {
makeTypeParam(name)
// Given a name, e.g. A or `+A` or `A <: Foo`, build a type
// parameter tree using the given name, bounds, variance, etc.
def makeTypeParamFromName(name: Name) = {
val decoded = NameTransformer.decode(name.toString)
val src = s"type _X_[$decoded] = Unit"
sp.parse(src) match {
case Some(TypeDef(_, _, List(tpe), _)) => tpe
case None => unit.error(tree.pos, s"Can't parse param: $name"); null
}
}

// Like makeTypeParam, but can be used recursively in the case of types
// that are themselves parameterized.
Expand All @@ -91,7 +95,7 @@ class KindRewriter(plugin: Plugin, val global: Global)

case ExistentialTypeTree(AppliedTypeTree(Ident(name), ps), _) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, bounds)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case x =>
unit.error(x.pos, "Can't parse %s (%s)" format (x, x.getClass.getName))
Expand Down Expand Up @@ -143,11 +147,11 @@ class KindRewriter(plugin: Plugin, val global: Global)

case AppliedTypeTree(Ident(name), ps) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, bounds)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case ExistentialTypeTree(AppliedTypeTree(Ident(name), ps), _) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, bounds)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case x =>
unit.error(x.pos, "Can't parse %s (%s)" format (x, x.getClass.getName))
Expand Down Expand Up @@ -183,7 +187,7 @@ class KindRewriter(plugin: Plugin, val global: Global)
case (Ident(name), Some(Right(ContraPlaceholder))) =>
makeTypeParamContra(name)
case (Ident(name), Some(Left(tparams))) =>
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, bounds)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)
}

val args = xyz.map(_._1)
Expand Down
21 changes: 21 additions & 0 deletions src/main/scala/StringParser.scala
@@ -0,0 +1,21 @@
package d_m

import scala.reflect.macros.ParseException
import scala.tools.nsc.Global
import scala.tools.nsc.reporters.StoreReporter

class StringParser[G <: Global](val global: G) {
import global._
def parse(code: String): Option[Tree] = {
val oldReporter = global.reporter
try {
val r = new StoreReporter()
global.reporter = r
val tree = newUnitParser(code).templateStats().headOption
if (r.infos.isEmpty) tree else None
} finally {
global.reporter = oldReporter
}
}
}

13 changes: 13 additions & 0 deletions src/test/scala/bounds.scala
@@ -0,0 +1,13 @@
package bounds

trait Leibniz[-L, +H >: L, A >: L <: H, B >: L <: H]

object Test {
trait Foo
trait Bar extends Foo

def outer[A >: Bar <: Foo] = {
def test[F[_ >: Bar <: Foo]] = 999
test[λ[`b >: Bar <: Foo` => Leibniz[Bar, Foo, A, b]]]
}
}

0 comments on commit 9d9a1db

Please sign in to comment.