Skip to content

Commit

Permalink
Support multi level lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
romainreuillon committed May 10, 2024
1 parent f6de752 commit 4e1683b
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 25 deletions.
132 changes: 107 additions & 25 deletions byte-pack/src/main/scala/bytepack/FieldIndex.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ object FieldIndex:

// inline def fieldIndex[F](f: F => Any): Int = ${ fieldIndexImpl[F]('{f}) }



def fieldIndexImpl[F, T](f: Expr[F => T])(using quotes: Quotes, tpef: Type[F]): Expr[Int] =
// val term = f.show
import quotes.*
Expand All @@ -61,42 +63,122 @@ object FieldIndex:
case Block(l, _) => l.flatMap(t => recur(t, selects))
case _ => List()

def getFieldType(fromType: TypeRepr, fieldName: String): TypeRepr =
def getClassSymbol(tpe: TypeRepr): Symbol = tpe.classSymbol match
case Some(sym) => sym
case None => report.errorAndAbort(s"${tpe} is not a concrete type", f.asTerm.pos)


// We need to do this to support tuples, because even though they conform as case classes in other respects,
// for some reason their field names (_1, _2, etc) have a space at the end, ie `_1 `.
def getTrimmedFieldSymbol(fromTypeSymbol: Symbol): Symbol =
fromTypeSymbol.memberFields.find(_.name.trim == fieldName).getOrElse(Symbol.noSymbol)

object FieldType:
def unapply(fieldSymbol: Symbol): Option[TypeRepr] = fieldSymbol match
case sym if sym.isNoSymbol => None
case sym =>
sym.tree match
case ValDef(_, typeTree, _) => Some(typeTree.tpe)
case _ => None

def swapWithSuppliedType(fromType: TypeRepr, possiblyContainsTypeArgs: TypeRepr): TypeRepr =
val declared = getDeclaredTypeArgs(fromType)
val supplied = getSuppliedTypeArgs(fromType)
val swapDict = declared.view.map(_.name).zip(supplied).toMap

def swapInto(candidate: TypeRepr): TypeRepr =
candidate match
case AppliedType(typeCons, args) => swapInto(typeCons).appliedTo(args.map(swapInto))
case leafType => swapDict.getOrElse(leafType.typeSymbol.name, leafType)

swapInto(possiblyContainsTypeArgs)

def getDeclaredTypeArgs(classType: TypeRepr): List[Symbol] =
classType.classSymbol.map(_.primaryConstructor.paramSymss) match
case Some(typeParamList :: _) if typeParamList.exists(_.isTypeParam) => typeParamList
case _ => Nil

def getSuppliedTypeArgs(fromType: TypeRepr): List[TypeRepr] =
fromType match
case AppliedType(_, argTypeReprs) => argTypeReprs
case _ => Nil

val fromTypeSymbol = getClassSymbol(fromType)
getTrimmedFieldSymbol(fromTypeSymbol) match
case FieldType(possiblyTypeArg) => swapWithSuppliedType(fromType, possiblyTypeArg)
case _ => report.errorAndAbort(s"Couldn't find field type ${fromType.show} $fieldName)")



// println(f.asTerm.show(using Printer.TreeStructure))
// println(recur(f.asTerm, List()))

val sym = TypeRepr.of[F].typeSymbol

// object CaseClass:
// def unapply(term: Term): Option[Term] =
// term.tpe.classSymbol.flatMap: sym =>
// Option.when(sym.flags.is(Flags.Case))(term)
//
//
// f.asTerm match
// case Select(_,_) => "test"

//val source = f.asTerm.pos.sourceCode
def fieldIndex(tpe: TypeRepr, fieldNames: List[String], acc: List[Expr[Int]]): List[Expr[Int]] =
fieldNames match
case Nil => acc.reverse
case fieldName :: tail =>
val packProduct =
tpe.asType match
case '[t] =>
Expr.summon[PackProduct[t]]

val index =
tpe.typeSymbol.caseFields.zipWithIndex.find((f, _) => f.name == fieldName) match
case Some(f) => f._2 //if p.flags.is(Flags.HasDefault)
case None => report.errorAndAbort(s"No field named ${fieldName} found in case class ${tpe}", f.asTerm.pos)

// val field = source.head.dropWhile(_ != '.').drop(1)

val v = Expr(index)


val code =
tpe.asType match
case '[t] =>
Expr.summon[PackProduct[t]] match
case Some(pack) =>
'{
Pack.indexOf[t]($v)(using $pack)
}
case None => report.errorAndAbort(s"No PackProduct type class defined for $tpe", f.asTerm.pos)


val fieldType = getFieldType(tpe, fieldName)
fieldIndex(fieldType, tail, code :: acc)

val selects = recur(f.asTerm, List())
if selects.size != 1 then report.errorAndAbort("Only one level of case class is supported for now", f.asTerm.pos)

val field = selects.head //source.head.dropWhile(_ != '.').drop(1)
val codes = Expr.ofSeq(fieldIndex(TypeRepr.of[F], selects, List()).toVector)

val index =
sym.caseFields.zipWithIndex.find((f, _) => f.name == field) match
case Some(f) => f._2 //if p.flags.is(Flags.HasDefault)
case None => report.errorAndAbort(s"No field named ${field} found in case class ${sym}", f.asTerm.pos)
'{ ${codes}.sum }

// val field = source.head.dropWhile(_ != '.').drop(1)
// if selects.size != 1 then report.errorAndAbort("Only one level of case class is supported for now", f.asTerm.pos)
//
// val field = selects.head //source.head.dropWhile(_ != '.').drop(1)
// val sym = TypeRepr.of[F].typeSymbol
//
//
// val index =
// sym.caseFields.zipWithIndex.find((f, _) => f.name == field) match
// case Some(f) => f._2 //if p.flags.is(Flags.HasDefault)
// case None => report.errorAndAbort(s"No field named ${field} found in case class ${sym}", f.asTerm.pos)
//
// // val field = source.head.dropWhile(_ != '.').drop(1)
//
// val v = Expr(index)
// val pack =
// Expr.summon[PackProduct[F]] match
// case Some(p) => p
// case None => report.errorAndAbort(s"Not found PackProduct for type ${sym}", f.asTerm.pos)
//
// '{
// Pack.indexOf[F]($v)(using $pack)
// }

val v = Expr(index)
val pack =
Expr.summon[PackProduct[F]] match
case Some(p) => p
case None => report.errorAndAbort(s"Not found PackProduct for type ${sym}", f.asTerm.pos)

'{
Pack.indexOf[F]($v)(using $pack)
}


class MkUnpackField[From]:
Expand Down
6 changes: 6 additions & 0 deletions byte-pack/src/test/scala/bytepack/PackTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ class PackTests extends AnyFunSuite:
assert(p == Pack.unpack[UpperClass](packed))
assert(Pack.unpack[UpperClass](_.j)(packed) == 8.toByte)

def unpackMethod = Pack.unpack[UpperClass](_.j)
assert(unpackMethod(packed) == 8.toByte)

assert(Pack.unpack[UpperClass](_.testClass.i)(packed) == 9)
assert(Pack.unpack[UpperClass](_.testClass.x)(packed) == 8.0)

test("Index should be correct"):
import PackTests.*
val p = UpperClass(TestClass(9, 8.0, En.V2, None, Some(En.V1)), 8.toByte)
Expand Down

0 comments on commit 4e1683b

Please sign in to comment.