-
Notifications
You must be signed in to change notification settings - Fork 1k
/
MacroAnnotation.scala
212 lines (209 loc) · 10.1 KB
/
MacroAnnotation.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
// TODO in which package should this class be located?
package scala
package annotation
import scala.quoted._
/** Base trait for macro annotation implementation.
 *  Macro annotations can transform definitions and add new definitions.
 *
 *  See: `MacroAnnotation.transform`
 *
 *  @syntax markdown
 */
@experimental
trait MacroAnnotation extends StaticAnnotation:

  /** Transform the `tree` definition and add new definitions
   *
   *  This method takes as argument the annotated definition.
   *  It returns a non-empty list containing the modified version of the annotated definition.
   *  The new tree for the definition must use the original symbol.
   *  New definitions can be added to the list before or after the transformed definition; this order
   *  is retained. New definitions will not be visible from outside the macro expansion.
   *
   *  #### Restrictions
   *  - All definitions in the result must have the same owner. The owner can be recovered from `Symbol.spliceOwner`.
   *    - Special case: an annotated top-level `def`, `val`, `var`, `lazy val` can return a `class`/`object`
   *      definition that is owned by the package or package object.
   *  - Cannot return a `type`.
   *  - An annotated top-level `class`/`object` cannot return top-level `def`, `val`, `var`, `lazy val`.
   *  - User-written code cannot refer to the new definitions.
   *
   *  #### Good practices
   *  - Make your new definitions private if you can.
   *  - New definitions added as class members should use a fresh name (`Symbol.freshName`) to avoid collisions.
   *  - New top-level definitions should use a fresh name (`Symbol.freshName`) that includes the name of the annotated
   *    member as a prefix to avoid collisions of definitions added in other files.
   *
   *  **IMPORTANT**: When developing and testing a macro annotation, you must enable `-Xcheck-macros` and `-Ycheck:all`.
   *
   *  #### Example 1
   *  This example shows how to modify a `def` and add a `val` next to it using a macro annotation.
   *  ```scala
   *  import scala.annotation.{experimental, MacroAnnotation}
   *  import scala.quoted.*
   *  import scala.collection.mutable
   *
   *  @experimental
   *  class memoize extends MacroAnnotation:
   *    def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
   *      import quotes.reflect.*
   *      tree match
   *        case DefDef(name, TermParamClause(param :: Nil) :: Nil, tpt, Some(rhsTree)) =>
   *          (param.tpt.tpe.asType, tpt.tpe.asType) match
   *            case ('[t], '[u]) =>
   *              val cacheName = Symbol.freshName(name + "Cache")
   *              val cacheSymbol = Symbol.newVal(Symbol.spliceOwner, cacheName, TypeRepr.of[mutable.Map[t, u]], Flags.Private, Symbol.noSymbol)
   *              val cacheRhs =
   *                given Quotes = cacheSymbol.asQuotes
   *                '{ mutable.Map.empty[t, u] }.asTerm
   *              val cacheVal = ValDef(cacheSymbol, Some(cacheRhs))
   *              val newRhs =
   *                given Quotes = tree.symbol.asQuotes
   *                val cacheRefExpr = Ref(cacheSymbol).asExprOf[mutable.Map[t, u]]
   *                val paramRefExpr = Ref(param.symbol).asExprOf[t]
   *                val rhsExpr = rhsTree.asExprOf[u]
   *                '{ $cacheRefExpr.getOrElseUpdate($paramRefExpr, $rhsExpr) }.asTerm
   *              val newTree = DefDef.copy(tree)(name, TermParamClause(param :: Nil) :: Nil, tpt, Some(newRhs))
   *              List(cacheVal, newTree)
   *        case _ =>
   *          report.error("Annotation only supported on `def` with a single argument")
   *          List(tree)
   *  ```
   *  with this macro annotation a user can write
   *  ```scala
   *  //{
   *  class memoize extends scala.annotation.StaticAnnotation
   *  //}
   *  @memoize
   *  def fib(n: Int): Int =
   *    println(s"compute fib of $n")
   *    if n <= 1 then n else fib(n - 1) + fib(n - 2)
   *  ```
   *  and the macro will modify the definition to create
   *  ```scala
   *  private val fibCache$macro$1 =
   *    scala.collection.mutable.Map.empty[Int, Int]
   *  def fib(n: Int): Int =
   *    fibCache$macro$1.getOrElseUpdate(
   *      n,
   *      {
   *        println(s"compute fib of $n")
   *        if n <= 1 then n else fib(n - 1) + fib(n - 2)
   *      }
   *    )
   *  ```
   *
   *  #### Example 2
   *  This example shows how to modify a `class` using a macro annotation.
   *  It shows how to override inherited members and add new ones.
   *  ```scala
   *  import scala.annotation.{experimental, MacroAnnotation}
   *  import scala.quoted.*
   *
   *  @experimental
   *  class equals extends MacroAnnotation:
   *    def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] =
   *      import quotes.reflect.*
   *      tree match
   *        case ClassDef(className, ctr, parents, self, body) =>
   *          val cls = tree.symbol
   *
   *          val constructorParameters = ctr.paramss.collect { case clause: TermParamClause => clause }
   *          if constructorParameters.size != 1 || constructorParameters.head.params.isEmpty then
   *            report.errorAndAbort("@equals class must have a single argument list with at least one argument", ctr.pos)
   *          def checkNotOverridden(sym: Symbol): Unit =
   *            if sym.overridingSymbol(cls).exists then
   *              report.error(s"Cannot override ${sym.name} in a @equals class")
   *
   *          val fields = body.collect {
   *            case vdef: ValDef if vdef.symbol.flags.is(Flags.ParamAccessor) =>
   *              Select(This(cls), vdef.symbol).asExpr
   *          }
   *
   *          val equalsSym = Symbol.requiredMethod("java.lang.Object.equals")
   *          checkNotOverridden(equalsSym)
   *          val equalsOverrideSym = Symbol.newMethod(cls, "equals", equalsSym.info, Flags.Override, Symbol.noSymbol)
   *          def equalsOverrideDefBody(argss: List[List[Tree]]): Option[Term] =
   *            given Quotes = equalsOverrideSym.asQuotes
   *            cls.typeRef.asType match
   *              case '[c] =>
   *                Some(equalsExpr[c](argss.head.head.asExpr, fields).asTerm)
   *          val equalsOverrideDef = DefDef(equalsOverrideSym, equalsOverrideDefBody)
   *
   *          val hashSym = Symbol.newVal(cls, Symbol.freshName("hash"), TypeRepr.of[Int], Flags.Private | Flags.Lazy, Symbol.noSymbol)
   *          val hashVal = ValDef(hashSym, Some(hashCodeExpr(className, fields)(using hashSym.asQuotes).asTerm))
   *
   *          val hashCodeSym = Symbol.requiredMethod("java.lang.Object.hashCode")
   *          checkNotOverridden(hashCodeSym)
   *          val hashCodeOverrideSym = Symbol.newMethod(cls, "hashCode", hashCodeSym.info, Flags.Override, Symbol.noSymbol)
   *          val hashCodeOverrideDef = DefDef(hashCodeOverrideSym, _ => Some(Ref(hashSym)))
   *
   *          val newBody = equalsOverrideDef :: hashVal :: hashCodeOverrideDef :: body
   *          List(ClassDef.copy(tree)(className, ctr, parents, self, newBody))
   *        case _ =>
   *          report.error("Annotation only supports `class`")
   *          List(tree)
   *
   *    private def equalsExpr[T: Type](that: Expr[Any], thisFields: List[Expr[Any]])(using Quotes): Expr[Boolean] =
   *      '{
   *        $that match
   *          case that: T @unchecked =>
   *            ${
   *              val thatFields: List[Expr[Any]] =
   *                import quotes.reflect.*
   *                thisFields.map(field => Select('{that}.asTerm, field.asTerm.symbol).asExpr)
   *              thisFields.zip(thatFields)
   *                .map { case (thisField, thatField) => '{ $thisField == $thatField } }
   *                .reduce { case (pred1, pred2) => '{ $pred1 && $pred2 } }
   *            }
   *          case _ => false
   *      }
   *
   *    private def hashCodeExpr(className: String, thisFields: List[Expr[Any]])(using Quotes): Expr[Int] =
   *      '{
   *        var acc: Int = ${ Expr(scala.runtime.Statics.mix(-889275714, className.hashCode)) }
   *        ${
   *          Expr.block(
   *            thisFields.map {
   *              case '{ $field: Boolean } => '{ if $field then 1231 else 1237 }
   *              case '{ $field: Byte } => '{ $field.toInt }
   *              case '{ $field: Char } => '{ $field.toInt }
   *              case '{ $field: Short } => '{ $field.toInt }
   *              case '{ $field: Int } => field
   *              case '{ $field: Long } => '{ scala.runtime.Statics.longHash($field) }
   *              case '{ $field: Double } => '{ scala.runtime.Statics.doubleHash($field) }
   *              case '{ $field: Float } => '{ scala.runtime.Statics.floatHash($field) }
   *              case '{ $field: Null } => '{ 0 }
   *              case '{ $field: Unit } => '{ 0 }
   *              case field => '{ scala.runtime.Statics.anyHash($field) }
   *            }.map(hash => '{ acc = scala.runtime.Statics.mix(acc, $hash) }),
   *            '{ scala.runtime.Statics.finalizeHash(acc, ${Expr(thisFields.size)}) }
   *          )
   *        }
   *      }
   *  ```
   *  with this macro annotation a user can write
   *  ```scala
   *  //{
   *  class equals extends scala.annotation.StaticAnnotation
   *  //}
   *  @equals class User(val name: String, val id: Int)
   *  ```
   *  and the macro will modify the class definition to generate the following code
   *  ```scala
   *  class User(val name: String, val id: Int):
   *    override def equals(that: Any): Boolean =
   *      that match
   *        case that: User => this.name == that.name && this.id == that.id
   *        case _ => false
   *    private lazy val hash$macro$1: Int =
   *      var acc = 515782504 // scala.runtime.Statics.mix(-889275714, "User".hashCode)
   *      acc = scala.runtime.Statics.mix(acc, scala.runtime.Statics.anyHash(name))
   *      acc = scala.runtime.Statics.mix(acc, id)
   *      scala.runtime.Statics.finalizeHash(acc, 2)
   *    override def hashCode(): Int = hash$macro$1
   *  ```
   *
   *  @param Quotes Implicit instance of Quotes used for tree reflection
   *  @param tree   Tree that will be transformed
   *  @return a non-empty list containing the transformed version of `tree` (which must keep the
   *          original symbol), optionally preceded and/or followed by new definitions
   *
   *  @syntax markdown
   */
  def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition]