Skip to content

Commit

Permalink
Implement field index with compile time check
Browse files Browse the repository at this point in the history
  • Loading branch information
romainreuillon committed May 9, 2024
1 parent ae64fff commit ccf2eb0
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
35 changes: 35 additions & 0 deletions byte-pack/src/main/scala/bytepack/FieldIndex.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,41 @@ object FieldIndex:

'{ $namesExpr.toMap }


inline def fieldIndex[F](f: F => Any)(using p: PackProduct[F]): Int = ${ fieldIndexImpl[F, Any]('f, 'p) }

def fieldIndexImpl[F, T](f: Expr[F => T], t: Expr[PackProduct[F]])(using quotes: Quotes, tpef: Type[F], typet: Type[T]): Expr[Int] =
// val term = f.show
import quotes.*
import quotes.reflect.*

val sym = TypeRepr.of[F].typeSymbol

// object CaseClass:
// def unapply(term: Term): Option[Term] =
// term.tpe.classSymbol.flatMap: sym =>
// Option.when(sym.flags.is(Flags.Case))(term)
//
//
// f.asTerm match
// case Select(_,_) => "test"

val source = f.asTerm.pos.sourceCode

val field = source.head.dropWhile(_ != '.').drop(1)
val index =
sym.caseFields.zipWithIndex.find((f, _) => f.name == field) match
case Some(f) => f._2 //if p.flags.is(Flags.HasDefault)
case None => report.errorAndAbort(s"No field named ${source} found in case class ${sym}", f.asTerm.pos)

// val field = source.head.dropWhile(_ != '.').drop(1)

val v = Expr(index)

'{
Pack.indexOf[F]($v)(using $t)
}

//
// val body = comp.tree.asInstanceOf[ClassDef].body
// val idents: List[Term] =
Expand Down
5 changes: 5 additions & 0 deletions byte-pack/src/main/scala/bytepack/Pack.scala
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,18 @@ object Pack:
p.unpack(0, b)

def indexOf[T: PackProduct](i: Int) = summon[PackProduct[T]].index(i)
inline def indexOf[F](f: F => Any)(using p: PackProduct[F]): Int = ${ FieldIndex.fieldIndexImpl[F, Any]('f, 'p) }


// TODO find a way to check field name at compile time and get field type
def indexOf[T: PackProduct](field: String): Int =
val pack = summon[PackProduct[T]]
val index = pack.fields.getOrElse(field, throw RuntimeException(s"Field $field not found among ${pack.fields}"))
indexOf[T](index)


//inline def fieldName[F](inline f: F => Any) = FieldIndex.fieldName(f)

def size[T: Pack] = summon[Pack[T]].size


Expand Down
3 changes: 1 addition & 2 deletions byte-pack/src/test/scala/bytepack/PackTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,4 @@ class PackTests extends AnyFunSuite:
val packed = Pack.pack(p)
assert(Pack.indexOf[UpperClass](1) == 11)
assert(packed(Pack.indexOf[UpperClass](1)) == 8.toByte)
assert(Pack.indexOf[UpperClass]("j") == 11)

assert(Pack.indexOf[UpperClass](_.j) == 11)

0 comments on commit ccf2eb0

Please sign in to comment.