Skip to content
This repository has been archived by the owner on Jul 6, 2022. It is now read-only.

Commit

Permalink
Added Macro support for Case Classes
Browse files Browse the repository at this point in the history
SCALA-168
  • Loading branch information
rozza committed Dec 15, 2016
1 parent 930d8fe commit f4a1717
Show file tree
Hide file tree
Showing 10 changed files with 1,033 additions and 2 deletions.
86 changes: 86 additions & 0 deletions bson/src/main/scala/org/mongodb/scala/bson/codecs/Macros.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2016 MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.mongodb.scala.bson.codecs

import scala.annotation.compileTimeOnly
import scala.language.experimental.macros
import scala.language.implicitConversions

import org.bson.codecs.Codec
import org.bson.codecs.configuration.{ CodecProvider, CodecRegistry }

import org.mongodb.scala.bson.codecs.macrocodecs.{ CaseClassCodec, CaseClassProvider }

/**
* Macro based Codecs
*
* Allows the compile time creation of Codecs for case classes.
*
* The recommended approach is to use the implicit [[Macros.createCodecProvider]] method to help build a codecRegistry:
* ```
* import org.mongodb.scala.bson.codecs.Macros.createCodecProvider
* import org.bson.codecs.configuration.CodecRegistries.{fromRegistries, fromProviders}
*
* case class Contact(phone: String)
* case class User(_id: Int, username: String, age: Int, hobbies: List[String], contacts: List[Contact])
*
* val codecRegistry = fromRegistries(fromProviders(classOf[User], classOf[Contact]), MongoClient.DEFAULT_CODEC_REGISTRY)
* ```
*
* @since 2.0
*/
object Macros {

/**
* Creates a CodecProvider for a case class
*
* @tparam T the case class to create a Codec from
* @return the CodecProvider for the case class
*/
@compileTimeOnly("Creating a CodecProvider utilises Macros and must be run at compile time.")
def createCodecProvider[T](): CodecProvider = macro CaseClassProvider.createCaseClassProvider[T]

/**
* Creates a CodecProvider for a case class using the given class to represent the case class
*
* @param clazz the clazz that is the case class
* @tparam T the case class to create a Codec from
* @return the CodecProvider for the case class
*/
@compileTimeOnly("Creating a CodecProvider utilises Macros and must be run at compile time.")
implicit def createCodecProvider[T](clazz: Class[T]): CodecProvider = macro CaseClassProvider.createCaseClassProviderWithClass[T]

/**
* Creates a Codec for a case class
*
* @tparam T the case class to create a Codec from
* @return the Codec for the case class
*/
@compileTimeOnly("Creating a Codec utilises Macros and must be run at compile time.")
def createCodec[T](): Codec[T] = macro CaseClassCodec.createCodecNoArgs[T]

/**
* Creates a Codec for a case class
*
* @param codecRegistry the Codec Registry to use
* @tparam T the case class to create a codec from
* @return the Codec for the case class
*/
@compileTimeOnly("Creating a Codec utilises Macros and must be run at compile time.")
def createCodec[T](codecRegistry: CodecRegistry): Codec[T] = macro CaseClassCodec.createCodec[T]

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
/*
* Copyright 2016 MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.mongodb.scala.bson.codecs.macrocodecs

import scala.collection.MapLike
import scala.reflect.macros.whitebox

import org.bson.codecs.Codec
import org.bson.codecs.configuration.CodecRegistry

private[codecs] object CaseClassCodec {

def createCodecNoArgs[T: c.WeakTypeTag](c: whitebox.Context)(): c.Expr[Codec[T]] = {
import c.universe._
createCodec[T](c)(c.Expr[CodecRegistry](
q"""
import org.mongodb.scala.bson.codecs.DEFAULT_CODEC_REGISTRY
DEFAULT_CODEC_REGISTRY
"""
)).asInstanceOf[c.Expr[Codec[T]]]
}

// scalastyle:off method.length
def createCodec[T: c.WeakTypeTag](c: whitebox.Context)(codecRegistry: c.Expr[CodecRegistry]): c.Expr[Codec[T]] = {
import c.universe._

// Declared types
val mainType = weakTypeOf[T]

val stringType = typeOf[String]
val mapTypeSymbol = typeOf[MapLike[_, _, _]].typeSymbol

// Names
val classTypeName = mainType.typeSymbol.name.toTypeName
val codecName = TypeName(s"${classTypeName}MacroCodec")

// Type checkers
def isCaseClass(t: Type): Boolean = t.typeSymbol.isClass && t.typeSymbol.asClass.isCaseClass
def isMap(t: Type): Boolean = t.baseClasses.contains(mapTypeSymbol)
def isOption(t: Type): Boolean = t.typeSymbol == definitions.OptionClass
def isTuple(t: Type): Boolean = definitions.TupleClass.seq.contains(t.typeSymbol)
def isSealed(t: Type): Boolean = t.typeSymbol.isClass && t.typeSymbol.asClass.isSealed
def isCaseClassOrSealed(t: Type): Boolean = isCaseClass(t) || isSealed(t)

// Data converters
def keyName(t: Type): Literal = Literal(Constant(t.typeSymbol.name.decodedName.toString))
def keyNameTerm(t: TermName): Literal = Literal(Constant(t.toString))

def allSubclasses(s: Symbol): Set[Symbol] = {
val directSubClasses = s.asClass.knownDirectSubclasses
directSubClasses ++ directSubClasses.flatMap({ s: Symbol => allSubclasses(s) })
}
val subClasses: List[Type] = allSubclasses(mainType.typeSymbol).map(_.asClass.toType).filter(isCaseClass).toList
if (isSealed(mainType) && subClasses.isEmpty) c.abort(c.enclosingPosition, "No known subclasses of the sealed class")
val knownTypes = (mainType +: subClasses).reverse
def fields: Map[Type, List[(TermName, Type)]] = knownTypes.map(t => (t, t.members.sorted.filter(_.isMethod).map(_.asMethod).filter(_.isGetter)
.map(m => (m.name, m.returnType.asSeenFrom(t, t.typeSymbol))))).toMap

// Primitives type map
val primitiveTypesMap: Map[Type, Type] = Map(
typeOf[Boolean] -> typeOf[java.lang.Boolean],
typeOf[Byte] -> typeOf[java.lang.Byte],
typeOf[Char] -> typeOf[java.lang.Character],
typeOf[Double] -> typeOf[java.lang.Double],
typeOf[Float] -> typeOf[java.lang.Float],
typeOf[Int] -> typeOf[java.lang.Integer],
typeOf[Long] -> typeOf[java.lang.Long],
typeOf[Short] -> typeOf[java.lang.Short]
)

/**
* Flattens the type args for any given type.
*
* Removes the key field from Maps as they have to be strings.
* Removes Option type as the Option value is wrapped automatically below.
* Throws if the case class contains a Tuple
*
* @param t the type to flatten the arguments for
* @return a list of the type arguments for the type
*/
def flattenTypeArgs(t: Type): List[c.universe.Type] = {
val typeArgs = if (isMap(t)) {
if (t.typeArgs.head != stringType) c.abort(c.enclosingPosition, "Maps must contain string types for keys")
t.typeArgs.tail
} else {
t.typeArgs
}
val types = t +: typeArgs.flatMap(x => flattenTypeArgs(x))
if (types.exists(isTuple)) c.abort(c.enclosingPosition, "Tuples currently aren't supported in case classes")
types.filter(x => !isOption(x)).map(x => primitiveTypesMap.getOrElse(x, x))
}

/**
* Maps the given field names to type args for the values in the field
*
* ```
* addresses: Seq[Address] => (addresses, List[classOf[Seq], classOf[Address]])
* nestedAddresses: Seq[Seq[Address]] => (addresses, List[classOf[Seq], classOf[Seq], classOf[Address]])
* ```
*
* @return a map of the field names with a list of the contain types
*/
def createFieldTypeArgsMap(fields: List[(TermName, Type)]) = {
val setTypeArgs = fields.map({
case (name, f) =>
val key = keyNameTerm(name)
q"""
typeArgs += ($key -> {
val tpeArgs = mutable.ListBuffer.empty[Class[_]]
..${flattenTypeArgs(f).map(t => q"tpeArgs += classOf[${t.finalResultType}]")}
tpeArgs.toList
})"""
})

q"""
val typeArgs = mutable.Map[String, List[Class[_]]]()
..$setTypeArgs
typeArgs.toMap
"""
}

/**
* For each case class sets the Map of the given field names and their field types.
*/
def createClassFieldTypeArgsMap = {
val setClassFieldTypeArgs = fields.map(field =>
q"""
classFieldTypeArgs += (${keyName(field._1)} -> ${createFieldTypeArgsMap(field._2)})
""")

q"""
val classFieldTypeArgs = mutable.Map[String, Map[String, List[Class[_]]]]()
..$setClassFieldTypeArgs
classFieldTypeArgs.toMap
"""
}

/**
* Creates a `Map[String, Class[_]]` mapping the case class name and the type.
*
* @return the case classes map
*/
def caseClassesMap = {
val setSubClasses = knownTypes.map(t => q"caseClassesMap += (${keyName(t)} -> classOf[${t.finalResultType}])")
q"""
val caseClassesMap = mutable.Map[String, Class[_]]()
..$setSubClasses
caseClassesMap.toMap
"""
}

/**
* Creates a `Map[Class[_], Boolean]` mapping field types to a boolean representing if they are a case class.
*
* @return the class to case classes map
*/
def classToCaseClassMap = {
val flattenedFieldTypes = fields.flatMap({ case (t, types) => types.map(f => f._2) :+ t })
val setclassToCaseClassMap = flattenedFieldTypes.map(t => q"""classToCaseClassMap ++= ${
flattenTypeArgs(t).map(t =>
q"(classOf[${t.finalResultType}], ${isCaseClassOrSealed(t)})")
}""")

q"""
val classToCaseClassMap = mutable.Map[Class[_], Boolean]()
..$setclassToCaseClassMap
classToCaseClassMap.toMap
"""
}

/**
* Handles the writing of case class fields.
*
* @param fields the list of fields
* @return the tree that writes the case class fields
*/
def writeClassValues(fields: List[(TermName, Type)]): List[Tree] = {
fields.map({
case (name, f) =>
val key = keyNameTerm(name)
f match {
case optional if isOption(optional) => q"""
val localVal = instanceValue.$name
writer.writeName($key)
if (localVal.isDefined) {
this.writeValue(writer, localVal.get, encoderContext)
} else {
this.writeValue(writer, this.bsonNull, encoderContext)
}"""
case _ => q"""
val localVal = instanceValue.$name
writer.writeName($key)
this.writeValue(writer, localVal, encoderContext)
"""
}
})
}

/**
* Writes the Case Class fields and values to the BsonWriter
*/
def writeValue: Tree = {
val cases: Seq[Tree] = {
fields.map(field => cq""" ${keyName(field._1)} =>
val instanceValue = value.asInstanceOf[${field._1}]
..${writeClassValues(field._2)}""").toSeq
}

q"""
writer.writeStartDocument()
this.writeClassFieldName(writer, className, encoderContext)
className match { case ..$cases }
writer.writeEndDocument()
"""
}

def fieldSetters(fields: List[(TermName, Type)]) = {
fields.map({
case (name, f) =>
val key = keyNameTerm(name)
f match {
case optional if isOption(optional) => q"$name = Option(fieldData($key)).asInstanceOf[$f]"
case _ => q"$name = fieldData($key).asInstanceOf[$f]"
}
})
}

def getInstance = {
val cases = knownTypes.map { st =>
cq"${keyName(st)} => new $st(..${fieldSetters(fields(st))})"
} :+ cq"""_ => throw new CodecConfigurationException("Unexpected class type: " + className)"""
q"className match { case ..$cases }"
}

c.Expr[Codec[T]](
q"""
import scala.collection.mutable
import org.bson.BsonWriter
import org.bson.codecs.EncoderContext
import org.bson.codecs.configuration.{CodecRegistry, CodecConfigurationException}
import org.mongodb.scala.bson.codecs.macrocodecs.MacroCodec

case class $codecName(codecRegistry: CodecRegistry) extends MacroCodec[$classTypeName] {
val caseClassesMap = $caseClassesMap
val classToCaseClassMap = $classToCaseClassMap
val classFieldTypeArgsMap = $createClassFieldTypeArgsMap
val encoderClass = classOf[$classTypeName]
def getInstance(className: String, fieldData: Map[String, Any]) = $getInstance
def writeCaseClassData(className: String, writer: BsonWriter, value: $mainType, encoderContext: EncoderContext) = $writeValue
}

${codecName.toTermName}($codecRegistry).asInstanceOf[Codec[$mainType]]
"""
)
}
// scalastyle:on method.length
}

0 comments on commit f4a1717

Please sign in to comment.