Skip to content

Commit

Permalink
Implement modifers
Browse files Browse the repository at this point in the history
  • Loading branch information
romainreuillon committed May 13, 2024
1 parent 73c9e54 commit b5ed403
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 109 deletions.
215 changes: 107 additions & 108 deletions byte-pack/src/main/scala/bytepack/FieldIndex.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package bytepack

import bytepack.Pack.{Mutation, UnsetModifier}

/*
* Copyright (C) 2024 Romain Reuillon
*
Expand Down Expand Up @@ -41,14 +43,12 @@ object FieldIndex:
'{ $namesExpr.toMap }



class MkFieldIndex[From]:
transparent inline def apply[To](inline lambda: From => To): Int = ${ fieldIndexImpl[From, To]('{lambda}) }

// inline def fieldIndex[F](f: F => Any): Int = ${ fieldIndexImpl[F]('{f}) }



def fieldIndexImpl[F, T](f: Expr[F => T])(using quotes: Quotes, tpef: Type[F]): Expr[Int] =
// val term = f.show
import quotes.*
Expand All @@ -63,91 +63,9 @@ object FieldIndex:
case Block(l, _) => l.flatMap(t => recur(t, selects))
case _ => List()

def getFieldType(fromType: TypeRepr, fieldName: String): TypeRepr =
def getClassSymbol(tpe: TypeRepr): Symbol = tpe.classSymbol match
case Some(sym) => sym
case None => report.errorAndAbort(s"${tpe} is not a concrete type")


// We need to do this to support tuples, because even though they conform as case classes in other respects,
// for some reason their field names (_1, _2, etc) have a space at the end, ie `_1 `.
def getTrimmedFieldSymbol(fromTypeSymbol: Symbol): Symbol =
fromTypeSymbol.memberFields.find(_.name.trim == fieldName).getOrElse(Symbol.noSymbol)

object FieldType:
def unapply(fieldSymbol: Symbol): Option[TypeRepr] = fieldSymbol match
case sym if sym.isNoSymbol => None
case sym =>
sym.tree match
case ValDef(_, typeTree, _) => Some(typeTree.tpe)
case _ => None

def swapWithSuppliedType(fromType: TypeRepr, possiblyContainsTypeArgs: TypeRepr): TypeRepr =
val declared = getDeclaredTypeArgs(fromType)
val supplied = getSuppliedTypeArgs(fromType)
val swapDict = declared.view.map(_.name).zip(supplied).toMap

def swapInto(candidate: TypeRepr): TypeRepr =
candidate match
case AppliedType(typeCons, args) => swapInto(typeCons).appliedTo(args.map(swapInto))
case leafType => swapDict.getOrElse(leafType.typeSymbol.name, leafType)

swapInto(possiblyContainsTypeArgs)

def getDeclaredTypeArgs(classType: TypeRepr): List[Symbol] =
classType.classSymbol.map(_.primaryConstructor.paramSymss) match
case Some(typeParamList :: _) if typeParamList.exists(_.isTypeParam) => typeParamList
case _ => Nil

def getSuppliedTypeArgs(fromType: TypeRepr): List[TypeRepr] =
fromType match
case AppliedType(_, argTypeReprs) => argTypeReprs
case _ => Nil

val fromTypeSymbol = getClassSymbol(fromType)
getTrimmedFieldSymbol(fromTypeSymbol) match
case FieldType(possiblyTypeArg) => swapWithSuppliedType(fromType, possiblyTypeArg)
case _ => report.errorAndAbort(s"Couldn't find field type ${fromType.show} $fieldName)")



// println(f.asTerm.show(using Printer.TreeStructure))
// println(recur(f.asTerm, List()))



def fieldIndex(tpe: TypeRepr, fieldNames: List[String], acc: List[Expr[Int]]): List[Expr[Int]] =
fieldNames match
case Nil => acc.reverse
case fieldName :: tail =>
val packProduct =
tpe.asType match
case '[t] =>
Expr.summon[PackProduct[t]]

val index =
tpe.typeSymbol.caseFields.zipWithIndex.find((f, _) => f.name == fieldName) match
case Some(f) => f._2 //if p.flags.is(Flags.HasDefault)
case None => report.errorAndAbort(s"No field named ${fieldName} found in case class ${tpe}")

// val field = source.head.dropWhile(_ != '.').drop(1)

val v = Expr(index)


val code =
tpe.asType match
case '[t] =>
Expr.summon[PackProduct[t]] match
case Some(pack) =>
'{
Pack.indexOf[t]($v)(using $pack)
}
case None => report.errorAndAbort(s"No PackProduct type class defined for $tpe")

val fieldType = getFieldType(tpe, fieldName)
fieldIndex(fieldType, tail, code :: acc)

val selects = recur(f.asTerm, List())

val codes = fieldIndex(TypeRepr.of[F], selects, List()).toVector
Expand All @@ -160,30 +78,6 @@ object FieldIndex:
case 4 => '{ ${codes(0)} + ${codes(1)} + ${codes(2)} + ${codes(3)} }
case _ => '{ ${Expr.ofSeq(codes)}.sum }

// if selects.size != 1 then report.errorAndAbort("Only one level of case class is supported for now", f.asTerm.pos)
//
// val field = selects.head //source.head.dropWhile(_ != '.').drop(1)
// val sym = TypeRepr.of[F].typeSymbol
//
//
// val index =
// sym.caseFields.zipWithIndex.find((f, _) => f.name == field) match
// case Some(f) => f._2 //if p.flags.is(Flags.HasDefault)
// case None => report.errorAndAbort(s"No field named ${field} found in case class ${sym}", f.asTerm.pos)
//
// // val field = source.head.dropWhile(_ != '.').drop(1)
//
// val v = Expr(index)
// val pack =
// Expr.summon[PackProduct[F]] match
// case Some(p) => p
// case None => report.errorAndAbort(s"Not found PackProduct for type ${sym}", f.asTerm.pos)
//
// '{
// Pack.indexOf[F]($v)(using $pack)
// }




class MkUnpackField[From]:
Expand Down Expand Up @@ -216,3 +110,108 @@ object FieldIndex:
// Expr.ofList(idents.map(_.asExpr))
//
// '{ $namesExpr.zip($identsExpr).toMap }

class MkModifyField[From]:
transparent inline def apply[To](inline lambda: From => To): UnsetModifier[To] = ${ modifyFieldImpl[From, To]('{ lambda }) }


def modifyFieldImpl[F, T](f: Expr[F => T])(using quotes: Quotes, tpef: Type[F], tpeT: Type[T]): Expr[UnsetModifier[T]] =
import quotes.*
import quotes.reflect.*

val packT =
Expr.summon[Pack[T]] match
case Some(p) => p
case None => report.errorAndAbort(s"Not found Pack for type ${tpeT}", f.asTerm.pos)

'{
val index = Pack.indexOf[F]($f)
new UnsetModifier[T]:
def set(t: T): Mutation =
val packedT = IArray.toArray(Pack.pack[T](t)(using $packT))
(b: Array[Byte]) => System.arraycopy(packedT, 0, b, index, packedT.length)
}

def fieldIndex(using quotes: Quotes)(tpe: quotes.reflect.TypeRepr, fieldNames: List[String], acc: List[Expr[Int]]): List[Expr[Int]] =
import quotes.*
import quotes.reflect.*

fieldNames match
case Nil => acc.reverse
case fieldName :: tail =>
val packProduct =
tpe.asType match
case '[t] =>
Expr.summon[PackProduct[t]]

val index =
tpe.typeSymbol.caseFields.zipWithIndex.find((f, _) => f.name == fieldName) match
case Some(f) => f._2 //if p.flags.is(Flags.HasDefault)
case None => report.errorAndAbort(s"No field named ${fieldName} found in case class ${tpe}")

// val field = source.head.dropWhile(_ != '.').drop(1)

val v = Expr(index)


val code =
tpe.asType match
case '[t] =>
Expr.summon[PackProduct[t]] match
case Some(pack) =>
'{
Pack.indexOf[t]($v)(using $pack)
}
case None => report.errorAndAbort(s"No PackProduct type class defined for $tpe")

val fieldType = getFieldType(tpe, fieldName)
fieldIndex(fieldType, tail, code :: acc)

def getFieldType(using quotes: Quotes)(fromType: quotes.reflect.TypeRepr, fieldName: String): quotes.reflect.TypeRepr =
import quotes.*
import quotes.reflect.*
def getClassSymbol(tpe: TypeRepr): Symbol = tpe.classSymbol match
case Some(sym) => sym
case None => report.errorAndAbort(s"${tpe} is not a concrete type")


// We need to do this to support tuples, because even though they conform as case classes in other respects,
// for some reason their field names (_1, _2, etc) have a space at the end, ie `_1 `.
def getTrimmedFieldSymbol(fromTypeSymbol: Symbol): Symbol =
fromTypeSymbol.memberFields.find(_.name.trim == fieldName).getOrElse(Symbol.noSymbol)

object FieldType:
def unapply(fieldSymbol: Symbol): Option[TypeRepr] = fieldSymbol match
case sym if sym.isNoSymbol => None
case sym =>
sym.tree match
case ValDef(_, typeTree, _) => Some(typeTree.tpe)
case _ => None

def swapWithSuppliedType(fromType: TypeRepr, possiblyContainsTypeArgs: TypeRepr): TypeRepr =
val declared = getDeclaredTypeArgs(fromType)
val supplied = getSuppliedTypeArgs(fromType)
val swapDict = declared.view.map(_.name).zip(supplied).toMap

def swapInto(candidate: TypeRepr): TypeRepr =
candidate match
case AppliedType(typeCons, args) => swapInto(typeCons).appliedTo(args.map(swapInto))
case leafType => swapDict.getOrElse(leafType.typeSymbol.name, leafType)

swapInto(possiblyContainsTypeArgs)

def getDeclaredTypeArgs(classType: TypeRepr): List[Symbol] =
classType.classSymbol.map(_.primaryConstructor.paramSymss) match
case Some(typeParamList :: _) if typeParamList.exists(_.isTypeParam) => typeParamList
case _ => Nil

def getSuppliedTypeArgs(fromType: TypeRepr): List[TypeRepr] =
fromType match
case AppliedType(_, argTypeReprs) => argTypeReprs
case _ => Nil

val fromTypeSymbol = getClassSymbol(fromType)
getTrimmedFieldSymbol(fromTypeSymbol) match
case FieldType(possiblyTypeArg) => swapWithSuppliedType(fromType, possiblyTypeArg)
case _ => report.errorAndAbort(s"Couldn't find field type ${fromType.show} $fieldName)")

15 changes: 14 additions & 1 deletion byte-pack/src/main/scala/bytepack/Pack.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package bytepack
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

import bytepack.FieldIndex.{MkFieldIndex, MkUnpackField}
import bytepack.FieldIndex.{MkFieldIndex, MkModifyField, MkUnpackField}

import scala.deriving.*
import scala.compiletime.*
Expand Down Expand Up @@ -99,6 +99,19 @@ object Pack:
def size[T: Pack] = summon[Pack[T]].size


type Mutation = Array[Byte] => Unit
trait UnsetModifier[T]:
def set(v: T): Mutation

def modifier[From]: MkModifyField[From] = new MkModifyField[From]

def modify[From](p: IArray[Byte], mutation: Mutation*): IArray[Byte] =
val arr = p.toArray
mutation.foreach: m =>
m(arr)
IArray.unsafeFromArray(arr)


def packProduct[T](p: Mirror.ProductOf[T], elems: => Array[Pack[_]]): Pack[T] with PackProduct[T] =
inline def packElement(elem: Pack[_])(x: Any, b: ByteBuffer): Unit =
elem.asInstanceOf[Pack[Any]].pack(x, b)
Expand Down
13 changes: 13 additions & 0 deletions byte-pack/src/test/scala/bytepack/PackTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,16 @@ class PackTests extends AnyFunSuite:
assert(Pack.indexOf[UpperClass](1) == 11)
assert(packed(Pack.indexOf[UpperClass](1)) == 8.toByte)
assert(Pack.indexOf[UpperClass](_.j) == 11)

test("Modify should work"):
import PackTests.*
val p = UpperClass(TestClass(9, 8.0, En.V2, None, Some(En.V1)), 8.toByte)
val packed = Pack.pack(p)

val modifyX = Pack.modifier[UpperClass](_.testClass.x)
val modifyJ = Pack.modifier[UpperClass](_.j)

val newPacked = Pack.modify(packed, modifyX.set(20.0f), modifyJ.set(100.toByte))

assert(Pack.unpack[UpperClass](packed) == p)
assert(Pack.unpack[UpperClass](newPacked) == p.copy(testClass = p.testClass.copy(x = 20.0f), j = 100.toByte))

0 comments on commit b5ed403

Please sign in to comment.