Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce boilerplate with macro bundles #58

Merged
merged 1 commit into from
Sep 17, 2015
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 24 additions & 36 deletions core/src/main/scala/scalacache/memoization/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,37 @@ import scala.reflect.macros.blackbox
import scala.concurrent.duration.Duration
import scalacache.{ Flags, ScalaCache }

object Macros {
class Macros(val c: blackbox.Context) {
import c.universe._

def memoizeImpl[A: c.WeakTypeTag](c: blackbox.Context)(f: c.Expr[A])(scalaCache: c.Expr[ScalaCache], flags: c.Expr[Flags]) = {
import c.universe._

commonMacroImpl(c)(scalaCache, { keyName =>
def memoizeImpl[A: c.WeakTypeTag](f: c.Expr[A])(scalaCache: c.Expr[ScalaCache], flags: c.Expr[Flags]): Tree = {
commonMacroImpl(scalaCache, { keyName =>
q"""_root_.scalacache.caching($keyName)($f)($scalaCache, $flags)"""
})
}

def memoizeImplWithTTL[A: c.WeakTypeTag](c: blackbox.Context)(ttl: c.Expr[Duration])(f: c.Expr[A])(scalaCache: c.Expr[ScalaCache], flags: c.Expr[Flags]) = {
import c.universe._

commonMacroImpl(c)(scalaCache, { keyName =>
def memoizeImplWithTTL[A: c.WeakTypeTag](ttl: c.Expr[Duration])(f: c.Expr[A])(scalaCache: c.Expr[ScalaCache], flags: c.Expr[Flags]): Tree = {
commonMacroImpl(scalaCache, { keyName =>
q"""_root_.scalacache.cachingWithTTL($keyName)($ttl)($f)($scalaCache, $flags)"""
})
}

private def commonMacroImpl[A: c.WeakTypeTag](c: blackbox.Context)(scalaCache: c.Expr[ScalaCache], keyNameToCachingCall: c.TermName => c.Tree) = {
import c.universe._
private def commonMacroImpl[A: c.WeakTypeTag](scalaCache: c.Expr[ScalaCache], keyNameToCachingCall: (c.TermName) => c.Tree): Tree = {

val enclosingMethodSymbol = getMethodSymbol(c)
val classSymbol = getClassSymbol(c)
val enclosingMethodSymbol = getMethodSymbol()
val classSymbol = getClassSymbol()

/*
* Gather all the info needed to build the cache key:
* class name, method name and the method parameters lists
*/
val classNameTree = getFullClassName(c)(classSymbol)
val classParamssTree = getConstructorParams(c)(classSymbol)
val methodNameTree = getMethodName(c)(enclosingMethodSymbol)
val classNameTree = getFullClassName(classSymbol)
val classParamssTree = getConstructorParams(classSymbol)
val methodNameTree = getMethodName(enclosingMethodSymbol)
val methodParamssSymbols = c.internal.enclosingOwner.info.paramLists
val methodParamssTree = paramListsToTree(c)(methodParamssSymbols)
val methodParamssTree = paramListsToTree(methodParamssSymbols)

val keyName = createKeyName(c)
val keyName = createKeyName()
val scalacacheCall = keyNameToCachingCall(keyName)
val tree = q"""
val $keyName = $scalaCache.memoization.toStringConverter.toString($classNameTree, $classParamssTree, $methodNameTree, $methodParamssTree)
Expand All @@ -54,8 +50,7 @@ object Macros {
* Get the symbol of the method that encloses the macro,
* or abort the compilation if we can't find one.
*/
private def getMethodSymbol(c: blackbox.Context): c.Symbol = {
import c.universe._
private def getMethodSymbol(): c.Symbol = {

def getMethodSymbolRecursively(sym: Symbol): Symbol = {
if (sym == null || sym == NoSymbol || sym.owner == sym)
Expand All @@ -74,16 +69,13 @@ object Macros {
/**
* Convert the given method symbol to a tree representing the method name.
*/
private def getMethodName(c: blackbox.Context)(methodSymbol: c.Symbol): c.Tree = {
import c.universe._
private def getMethodName(methodSymbol: c.Symbol): c.Tree = {
val methodName = methodSymbol.asMethod.name.toString
// return a Tree
q"$methodName"
}

private def getClassSymbol(c: blackbox.Context): c.Symbol = {
import c.universe._

private def getClassSymbol(): c.Symbol = {
def getClassSymbolRecursively(sym: Symbol): Symbol = {
if (sym == null)
c.abort(c.enclosingPosition, "Encountered a null symbol while searching for enclosing class")
Expand All @@ -101,48 +93,44 @@ object Macros {
*
* @param classSymbol should be either a ClassSymbol or a ModuleSymbol
*/
private def getFullClassName(c: blackbox.Context)(classSymbol: c.Symbol): c.Tree = {
import c.universe._
private def getFullClassName(classSymbol: c.Symbol): c.Tree = {
val className = classSymbol.fullName
// return a Tree
q"$className"
}

private def getConstructorParams(c: blackbox.Context)(classSymbol: c.Symbol): c.Tree = {
import c.universe._
private def getConstructorParams(classSymbol: c.Symbol): c.Tree = {
if (classSymbol.isClass) {
val symbolss = classSymbol.asClass.primaryConstructor.asMethod.paramLists
if (symbolss == List(Nil)) {
q"_root_.scala.collection.immutable.Nil"
} else {
paramListsToTree(c)(symbolss)
paramListsToTree(symbolss)
}
} else {
q"_root_.scala.collection.immutable.Nil"
}
}

private def paramListsToTree(c: blackbox.Context)(symbolss: List[List[c.Symbol]]): c.Tree = {
import c.universe._
private def paramListsToTree(symbolss: List[List[c.Symbol]]): c.Tree = {
val cacheKeyExcludeType = c.typeOf[cacheKeyExclude]
def shouldExclude(s: c.Symbol) = {
s.annotations.exists(a => a.tree.tpe == cacheKeyExcludeType)
}
val identss: List[List[Ident]] = symbolss.map(ss => ss.collect {
case s if !shouldExclude(s) => Ident(s.name)
})
listToTree(c)(identss.map(is => listToTree(c)(is)))
listToTree(identss.map(is => listToTree(is)))
}

/**
* Convert a List[Tree] to a Tree by calling scala.collection.immutable.list.apply()
*/
private def listToTree(c: blackbox.Context)(ts: List[c.Tree]): c.Tree = {
import c.universe._
private def listToTree(ts: List[c.Tree]): c.Tree = {
q"_root_.scala.collection.immutable.List(..$ts)"
}

private def createKeyName(c: blackbox.Context) = {
private def createKeyName(): TermName = {
// We must create a fresh name for any vals that we define, to ensure we don't clash with any user-defined terms.
// See https://github.com/cb372/scalacache/issues/13
// (Note that c.freshName("key") does not work as expected.
Expand Down