diff --git a/core/src/main/scala/scalacache/memoization/Macros.scala b/core/src/main/scala/scalacache/memoization/Macros.scala index 381ec74b..b5d2a9da 100644 --- a/core/src/main/scala/scalacache/memoization/Macros.scala +++ b/core/src/main/scala/scalacache/memoization/Macros.scala @@ -5,41 +5,37 @@ import scala.reflect.macros.blackbox import scala.concurrent.duration.Duration import scalacache.{ Flags, ScalaCache } -object Macros { +class Macros(val c: blackbox.Context) { + import c.universe._ - def memoizeImpl[A: c.WeakTypeTag](c: blackbox.Context)(f: c.Expr[A])(scalaCache: c.Expr[ScalaCache], flags: c.Expr[Flags]) = { - import c.universe._ - - commonMacroImpl(c)(scalaCache, { keyName => + def memoizeImpl[A: c.WeakTypeTag](f: c.Expr[A])(scalaCache: c.Expr[ScalaCache], flags: c.Expr[Flags]): Tree = { + commonMacroImpl(scalaCache, { keyName => q"""_root_.scalacache.caching($keyName)($f)($scalaCache, $flags)""" }) } - def memoizeImplWithTTL[A: c.WeakTypeTag](c: blackbox.Context)(ttl: c.Expr[Duration])(f: c.Expr[A])(scalaCache: c.Expr[ScalaCache], flags: c.Expr[Flags]) = { - import c.universe._ - - commonMacroImpl(c)(scalaCache, { keyName => + def memoizeImplWithTTL[A: c.WeakTypeTag](ttl: c.Expr[Duration])(f: c.Expr[A])(scalaCache: c.Expr[ScalaCache], flags: c.Expr[Flags]): Tree = { + commonMacroImpl(scalaCache, { keyName => q"""_root_.scalacache.cachingWithTTL($keyName)($ttl)($f)($scalaCache, $flags)""" }) } - private def commonMacroImpl[A: c.WeakTypeTag](c: blackbox.Context)(scalaCache: c.Expr[ScalaCache], keyNameToCachingCall: c.TermName => c.Tree) = { - import c.universe._ + private def commonMacroImpl[A: c.WeakTypeTag](scalaCache: c.Expr[ScalaCache], keyNameToCachingCall: (c.TermName) => c.Tree): Tree = { - val enclosingMethodSymbol = getMethodSymbol(c) - val classSymbol = getClassSymbol(c) + val enclosingMethodSymbol = getMethodSymbol() + val classSymbol = getClassSymbol() /* * Gather all the info needed to build the cache key: * class name, method name and the method parameters lists */ - val classNameTree = getFullClassName(c)(classSymbol) - val classParamssTree = getConstructorParams(c)(classSymbol) - val methodNameTree = getMethodName(c)(enclosingMethodSymbol) + val classNameTree = getFullClassName(classSymbol) + val classParamssTree = getConstructorParams(classSymbol) + val methodNameTree = getMethodName(enclosingMethodSymbol) val methodParamssSymbols = c.internal.enclosingOwner.info.paramLists - val methodParamssTree = paramListsToTree(c)(methodParamssSymbols) + val methodParamssTree = paramListsToTree(methodParamssSymbols) - val keyName = createKeyName(c) + val keyName = createKeyName() val scalacacheCall = keyNameToCachingCall(keyName) val tree = q""" val $keyName = $scalaCache.memoization.toStringConverter.toString($classNameTree, $classParamssTree, $methodNameTree, $methodParamssTree) @@ -54,8 +50,7 @@ object Macros { * Get the symbol of the method that encloses the macro, * or abort the compilation if we can't find one. */ - private def getMethodSymbol(c: blackbox.Context): c.Symbol = { - import c.universe._ + private def getMethodSymbol(): c.Symbol = { def getMethodSymbolRecursively(sym: Symbol): Symbol = { if (sym == null || sym == NoSymbol || sym.owner == sym) @@ -74,16 +69,13 @@ object Macros { /** * Convert the given method symbol to a tree representing the method name. */ - private def getMethodName(c: blackbox.Context)(methodSymbol: c.Symbol): c.Tree = { - import c.universe._ + private def getMethodName(methodSymbol: c.Symbol): c.Tree = { val methodName = methodSymbol.asMethod.name.toString // return a Tree q"$methodName" } - private def getClassSymbol(c: blackbox.Context): c.Symbol = { - import c.universe._ - + private def getClassSymbol(): c.Symbol = { def getClassSymbolRecursively(sym: Symbol): Symbol = { if (sym == null) c.abort(c.enclosingPosition, "Encountered a null symbol while searching for enclosing class") @@ -101,29 +93,26 @@ object Macros { * * @param classSymbol should be either a ClassSymbol or a ModuleSymbol */ - private def getFullClassName(c: blackbox.Context)(classSymbol: c.Symbol): c.Tree = { - import c.universe._ + private def getFullClassName(classSymbol: c.Symbol): c.Tree = { val className = classSymbol.fullName // return a Tree q"$className" } - private def getConstructorParams(c: blackbox.Context)(classSymbol: c.Symbol): c.Tree = { - import c.universe._ + private def getConstructorParams(classSymbol: c.Symbol): c.Tree = { if (classSymbol.isClass) { val symbolss = classSymbol.asClass.primaryConstructor.asMethod.paramLists if (symbolss == List(Nil)) { q"_root_.scala.collection.immutable.Nil" } else { - paramListsToTree(c)(symbolss) + paramListsToTree(symbolss) } } else { q"_root_.scala.collection.immutable.Nil" } } - private def paramListsToTree(c: blackbox.Context)(symbolss: List[List[c.Symbol]]): c.Tree = { - import c.universe._ + private def paramListsToTree(symbolss: List[List[c.Symbol]]): c.Tree = { val cacheKeyExcludeType = c.typeOf[cacheKeyExclude] def shouldExclude(s: c.Symbol) = { s.annotations.exists(a => a.tree.tpe == cacheKeyExcludeType) @@ -131,18 +120,17 @@ object Macros { val identss: List[List[Ident]] = symbolss.map(ss => ss.collect { case s if !shouldExclude(s) => Ident(s.name) }) - listToTree(c)(identss.map(is => listToTree(c)(is))) + listToTree(identss.map(is => listToTree(is))) } /** * Convert a List[Tree] to a Tree by calling scala.collection.immutable.list.apply() */ - private def listToTree(c: blackbox.Context)(ts: List[c.Tree]): c.Tree = { - import c.universe._ + private def listToTree(ts: List[c.Tree]): c.Tree = { q"_root_.scala.collection.immutable.List(..$ts)" } - private def createKeyName(c: blackbox.Context) = { + private def createKeyName(): TermName = { // We must create a fresh name for any vals that we define, to ensure we don't clash with any user-defined terms. // See https://github.com/cb372/scalacache/issues/13 // (Note that c.freshName("key") does not work as expected.