This repository has been archived by the owner on Jan 12, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 16
/
AvroTypeProviderMacro.scala
74 lines (58 loc) · 3.17 KB
/
AvroTypeProviderMacro.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
package com.julianpeeters.avro.annotations
import provider._
import scala.reflect.macros.blackbox.Context
import scala.language.experimental.macros
import scala.annotation.StaticAnnotation
import collection.JavaConversions._
import java.io.File
import com.typesafe.scalalogging._
/**
 * Macro implementation behind the `AvroTypeProvider` annotation.
 *
 * During expansion it reads the schemas from the file whose path is obtained via `FilePathProbe`,
 * picks the record schema whose name matches the annotated class, and prepends the fields
 * generated by `ValDefGenerator` to the class's first constructor parameter list.
 */
object AvroTypeProviderMacro extends LazyLogging {

  // Textual forms of the `@AvroRecord` annotation that switch the generated fields to vars.
  private val AvroRecordAnnotations: Set[String] = Set("new AvroRecord()", "new AvroRecord(null)")

  /**
   * Rewrites the annotated class so that its constructor carries the schema's fields.
   *
   * @param c         the macro context of the current expansion
   * @param annottees the annotated class, optionally followed by its companion object
   * @return the rewritten class definition, followed by the untouched companion when one exists
   */
  def impl(c: Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._
    import Flag._

    val result: Tree = annottees.map(_.tree).toList match {
      case q"$mods class $name[..$tparams](..$first)(...$rest) extends ..$parents { $self => ..$body }" :: tail =>
        // get the namespace from the context and pass it around instead of using schema.getNamespace
        // in order to read schemas that omit namespace (e.g. nested schemas or python's avrostorage default)
        val namespace = NamespaceProbe.getNamespace(c)
        val fullName: String = Option(namespace).fold(name.toString)(ns => s"$ns.$name")

        // currently, having an `@AvroRecord` is the only thing that triggers the writing of vars instead of vals
        val isImmutable: Boolean =
          !mods.annotations.exists(annotation => AvroRecordAnnotations(annotation.toString))

        // helpful for IDE users who may not be able to easily see where their files live
        logger.info(s"Current path: ${new File(".").getAbsolutePath}")

        // get the schema for the record that this class represents
        val avroFilePath = FilePathProbe.getPath(c)
        val infile = new File(avroFilePath)
        val fileSchemas = FileParser.getSchemas(infile)
        val nestedSchemas = fileSchemas.flatMap(NestedSchemaExtractor.getNestedSchemas)

        // first try matching schema record full name to class full name, then by the
        // regular name in case we're trying to read from a non-namespaced schema
        val classSchema = nestedSchemas
          .find(s => s.getFullName == fullName)
          .orElse(nestedSchemas.find(s => s.getName == name.toString && s.getNamespace == null))
          // abort (rather than throw) so the failure is reported at the annotation's position
          .getOrElse(c.abort(c.enclosingPosition, "no record found with name " + name))

        // wraps each schema field in a quasiquote, returning immutable val defs if immutable flag is true
        val newFields: List[ValDef] = ValDefGenerator.asScalaFields(classSchema, namespace, isImmutable, c)

        val classDef =
          q"$mods class $name[..$tparams](..${newFields ::: first})(...$rest) extends ..$parents { $self => ..$body }"

        tail match {
          // if there is no preexisting companion
          case Nil => classDef
          // if there is a preexisting companion, emit it unchanged (modifiers, parents and body
          // are all preserved) right after the updated classDef
          case (companion: ModuleDef) :: Nil => q"$classDef; $companion"
          // anything else would otherwise surface as a raw MatchError during expansion
          case _ =>
            c.abort(c.enclosingPosition, "unexpected definitions following the annotated class " + name)
        }

      // the annotation was placed on something that is not a class with the expected shape
      case _ =>
        c.abort(c.enclosingPosition, "the AvroTypeProvider annotation can only be applied to a class definition")
    }

    c.Expr[Any](result)
  }
}
/**
 * Macro annotation whose expansion is delegated to `AvroTypeProviderMacro.impl`.
 *
 * Marked compile-time-only so that a build without macro annotation support fails with a clear
 * message instead of silently leaving the annotated class unexpanded.
 *
 * @param inputPath not read by this class itself; presumably the path of the Avro file that the
 *                  macro extracts from the annotation's arguments — confirm against `FilePathProbe`
 */
@scala.annotation.compileTimeOnly("enable macro paradise to expand macro annotations")
class AvroTypeProvider(inputPath: String) extends StaticAnnotation {
  // Entry point required for macro annotations; receives the annotated definitions.
  def macroTransform(annottees: Any*): Any = macro AvroTypeProviderMacro.impl
}