Skip to content

Commit

Permalink
Use less specific val type, sort types by references in code rather…
Browse files Browse the repository at this point in the history
… than by traversing types.
  • Loading branch information
mrdziuban committed Feb 23, 2024
1 parent ea9ecc1 commit d5f86d3
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 213 deletions.
22 changes: 0 additions & 22 deletions src/main/scala/scalats/ReferencedTypes.scala

This file was deleted.

63 changes: 0 additions & 63 deletions src/main/scala/scalats/TsGenerator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -526,67 +526,4 @@ final class TsGenerator(
)
)), model)
}

/** Parses the types that a scala `case class` refers to */
private def referencedTypeNamesInterface(iface: TsModel.Interface, currType: TypeName): Map[TypeName, Set[TypeName]] =
iface.fields.foldMap { case TsModel.InterfaceField(_, tpe) => referencedTypeNames(tpe, currType) }

/** Parses the types that a scala `object` refers to */
private def referencedTypeNamesObject(obj: TsModel.Object, currType: TypeName): Map[TypeName, Set[TypeName]] =
obj.fields.foldMap { case TsModel.ObjectField(_, tpe, _) => referencedTypeNames(tpe, currType) }

/** Parses the types that a scala `enum`/`sealed trait`/`sealed class` refers to */
private def referencedTypeNamesUnion(union: TsModel.Union, currType: TypeName): Map[TypeName, Set[TypeName]] = {
import cats.syntax.semigroup.*
union.possibilities.foldMap(x => Map(currType -> Set(x.typeName)) |+| (x match {
case i: TsModel.Interface => referencedTypeNamesInterface(i, i.typeName)
case o: TsModel.Object => referencedTypeNamesObject(o, o.typeName)
}))
}

/** Parses the types that a given [[scalats.TsModel]] refers to */
private def referencedTypeNames(model: TsModel, currType: TypeName): Map[TypeName, Set[TypeName]] = {
import cats.syntax.semigroup.*

model match {
case (
TsModel.TypeParam(_)
| TsModel.Json(_)
| TsModel.Number(_)
| TsModel.BigNumber(_)
| TsModel.Boolean(_)
| TsModel.String(_)
| TsModel.LocalDate(_)
| TsModel.DateTime(_)
) => Map.empty

case TsModel.Literal(tpe, _) => referencedTypeNames(tpe, currType)
case TsModel.Eval(_, tpe) => referencedTypeNames(tpe, currType)
case TsModel.Array(_, tpe, _) => referencedTypeNames(tpe, currType)
case TsModel.Set(_, tpe) => referencedTypeNames(tpe, currType)
case TsModel.NonEmptyArray(_, tpe, _) => referencedTypeNames(tpe, currType)
case TsModel.Option(_, tpe) => referencedTypeNames(tpe, currType)
case TsModel.Either(_, left, right, _) => referencedTypeNames(left, currType) |+| referencedTypeNames(right, currType)
case TsModel.Ior(_, left, right, _) => referencedTypeNames(left, currType) |+| referencedTypeNames(right, currType)
case TsModel.Map(_, key, value) => referencedTypeNames(key, currType) |+| referencedTypeNames(value, currType)
case TsModel.Tuple(_, tpes) => tpes.foldMap(referencedTypeNames(_, currType))
case i: TsModel.Interface => referencedTypeNamesInterface(i, currType)
case TsModel.InterfaceRef(typeName, typeArgs) => Map(currType -> Set(typeName)) |+| typeArgs.foldMap(referencedTypeNames(_, currType))
case o: TsModel.Object => referencedTypeNamesObject(o, currType)
case TsModel.ObjectRef(typeName) => Map(currType -> Set(typeName))
case u: TsModel.Union => referencedTypeNamesUnion(u, currType)
case TsModel.UnionRef(typeName, typeArgs, _) => Map(currType -> Set(typeName)) |+| typeArgs.foldMap(referencedTypeNames(_, currType))
case TsModel.Unknown(typeName, typeArgs) => Map(currType -> Set(typeName)) |+| typeArgs.foldMap(referencedTypeNames(_, currType))
}
}

/**
* Produces a [[scalats.ReferencedTypes]] containing the list of [[scalats.TypeName]]s
* that the given [[scalats.TypeName]] refers to
*
* @param model The [[scalats.TsModel]] to parse references for
* @return The types the model references
*/
def referencedTypes(model: TsModel): ReferencedTypes =
new ReferencedTypes(referencedTypeNames(model, model.typeName))
}
91 changes: 1 addition & 90 deletions src/main/scala/scalats/TsParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -216,39 +216,7 @@ final class TsParser()(using override val ctx: Quotes) extends ReflectionUtils {
case t => t
}

sym.tree match {
case ValDef(_, _, Some(term)) => unAnd(term.tpe)
case ValDef(_, tpe, None) =>
println(s"""
**********************************************
typeRepr: $typeRepr

tpe: ${tpe.tpe}

typeRepr.show: ${typeRepr.show}

typeRepr.typeSymbol.primaryConstructor.tree: ${typeRepr.typeSymbol.primaryConstructor.tree}

typeRepr.typeSymbol.typeRef: ${typeRepr.typeSymbol.typeRef}

typeRepr.typeSymbol.termRef: ${typeRepr.typeSymbol.termRef}

typeRepr.typeSymbol.flags.is(Flags.Enum): ${typeRepr.typeSymbol.flags.is(Flags.Enum)}

typeRepr.typeSymbol.owner.termRef.show: ${typeRepr.typeSymbol.owner.termRef.show}

t1: ${typeRepr.select(sym).dealias}

t2: ${unAnd(typeRepr.memberType(sym))}
**********************************************
""")

// unAnd(typeRepr.memberType(sym))
// typeRepr.select(sym)
unAnd(tpe.tpe)

case _ => report.errorAndAbort(s"Failed to get type of ${typeRepr.show}#${sym.name}")
}
unAnd(typeRepr.memberType(sym))
}

/** Parse an `object` definiton into its [[scalats.TsModel]] representation */
Expand All @@ -272,60 +240,3 @@ t2: ${unAnd(typeRepr.memberType(sym))}
private def parseObjectRef[A: Type]: Expr[TsModel.ObjectRef] =
'{ TsModel.ObjectRef(${ mkTypeName(TypeRepr.of[A]) }) }
}

private def parseValueImpl[A: Type](value: Expr[A])(using Quotes): Expr[TsModel] =
new TsParser().parse[A](false)

inline def parseValue[A](value: A): TsModel = ${ parseValueImpl[A]('value) }

private case class PrettyPrinter(level: Int, inQuotes: Boolean, backslashed: Boolean) {
val indent = List.fill(level)(" ").mkString

def transform(char: Char): (PrettyPrinter, String) = {
val woSlash = copy(backslashed = false)
val (pp, f): (PrettyPrinter, PrettyPrinter => String) = char match {
case '"' if inQuotes && !backslashed => (woSlash.copy(inQuotes = false), (_: PrettyPrinter) => s"$char")
case '"' if !inQuotes => (woSlash.copy(inQuotes = true), (_: PrettyPrinter) => s"$char")
case '\\' if inQuotes && !backslashed => (copy(backslashed = true), (_: PrettyPrinter) => s"$char")

case ',' if !inQuotes => (woSlash, (p: PrettyPrinter) => s",\n${p.indent}")
case '(' if !inQuotes => (woSlash.copy(level = level + 1), (p: PrettyPrinter) => s"(\n${p.indent}")
case ')' if !inQuotes => (woSlash.copy(level = level - 1), (p: PrettyPrinter) => s"\n${p.indent})")
case _ => (woSlash, (_: PrettyPrinter) => s"$char")
}
(pp, f(pp))
}
}

private def prettyPrint(raw: String): String =
raw.foldLeft((PrettyPrinter(0, false, false), new StringBuilder(""))) { case ((pp, sb), char) =>
val (newPP, res) = pp.transform(char)
(newPP, sb.append(res))
}._2.toString.replaceAll("""\(\s+\)""", "()")

private def getTypeOfXImpl[A: Type](using q: Quotes): Expr[Unit] = {
import q.reflect.*

val tpe = TypeRepr.of[A]
val sym = tpe.typeSymbol
val x = sym.fieldMembers.find(_.name == "x").get
val valDef = x.tree.asInstanceOf[ValDef]
val ctor = sym.primaryConstructor.tree.asInstanceOf[DefDef]
println(s"""

tpe.memberType(x): ${tpe.memberType(x)}

valDef.tpt.tpe: ${valDef.tpt.tpe}

ctor: $ctor

ctor.paramss.head.params.head.rhs: ${ctor.paramss.head.params.head.asInstanceOf[ValDef].tpt}

sym.tree: ${prettyPrint(sym.tree.toString)}

""")

'{ () }
}

inline def getTypeOfX[A] = ${ getTypeOfXImpl[A] }
6 changes: 1 addition & 5 deletions src/main/scala/scalats/TypeName.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,10 @@ package scalats
* typeName.base // Baz
* ```
*/
class TypeName private (val raw: String) {
case class TypeName(raw: String) {
final val full: String = raw.split('[').head.stripSuffix(".type")
final val base: String = full.split('.').last
override final lazy val toString: String = raw
}

object TypeName {
def apply(raw: String): TypeName = new TypeName(raw) {}
}

object T { def apply(raw: String): TypeName = TypeName(raw) }
54 changes: 30 additions & 24 deletions src/main/scala/scalats/TypeSorter.scala
Original file line number Diff line number Diff line change
@@ -1,43 +1,49 @@
package scalats

import cats.Eval
import scala.util.matching.Regex

/**
* A helper to sort generated types, see [[scalats.TypeSorter.sort]]
*/
object TypeSorter {
private case class GeneratedWithIndex(gen: Generated, index: Int)

private val constRx = """(?:const ([^\s]+))""".r
private val functionRx = """(?:function ([^\s<\(]+))""".r
private val typeRx = """(?:type ([^\s<]+))""".r

private def extractNames(gen: GeneratedWithIndex, rx: Regex): Map[String, GeneratedWithIndex] =
rx.findAllMatchIn(gen.gen.code).toList.map(_.group(1)).map((_, gen)).toMap

/**
* Sort a set of generated types based on references between them
*
* @param types The generated types
* @params refs The mapping of type to other types it refers to
* @return A list of generated code sorted properly
*/
def sort(types: List[(Option[TypeName], Generated)], refs: ReferencedTypes): List[Generated] = {
val all = types.flatMap { case (typeName, generated) => typeName.map(n => (n.full, generated)) }.toMap
def sort(types: List[(Option[TypeName], Generated)]): List[Generated] = {
val gwis = types.zipWithIndex.map { case ((_, gen), idx) => GeneratedWithIndex(gen, idx) }
val definitions = gwis.flatMap(gen =>
extractNames(gen, constRx) ++ extractNames(gen, functionRx) ++ extractNames(gen, typeRx)
).toMap
// Names can be referenced either as `foo` or `importedN_foo` where `N` is a number
val definitionNamesRx = ("""(?:imported\d+_|\b)(""" ++
definitions.keys.map(name => Regex.quote(name) ++ "\\b").mkString("|") ++
")").r

def addType(
acc: Eval[(List[Generated], Set[String])],
typeName: TypeName,
generated: Generated,
): Eval[(List[Generated], Set[String])] =
acc.flatMap { case (accGen, accSkip) =>
if (accSkip.contains(typeName.full))
Eval.now((accGen, accSkip))
else
refs.get(typeName) match {
case Some(types) =>
types.foldLeft(Eval.now((accGen, accSkip + typeName.full)))(
(acc, ref) => all.get(ref.full).fold(acc)(addType(acc, ref, _))
).map { case (x, y) => (x :+ generated, y) }
case None =>
Eval.now((accGen :+ generated, accSkip + typeName.full))
}
def addGenerated(acc: Eval[(List[Generated], Set[Int])], gen: GeneratedWithIndex): Eval[(List[Generated], Set[Int])] =
acc.flatMap { case acc @ (accGen, accSkip) =>
if (accSkip.contains(gen.index))
Eval.now(acc)
else {
val refNames = definitionNamesRx.findAllMatchIn(gen.gen.code).map(_.group(1)).toSet
refNames.foldLeft(Eval.now((accGen, accSkip + gen.index)))(
(acc, refName) => definitions.get(refName).fold(acc)(addGenerated(acc, _))
).map { case (x, y) => (x :+ gen.gen, y) }
}
}

types.foldLeft(Eval.now(List.empty[Generated], Set.empty[String])) {
case (acc, (Some(typeName), generated)) => addType(acc, typeName, generated)
case (acc, (None, generated)) => acc.map { case (accGen, accSkip) => (accGen :+ generated, accSkip) }
}.value._1
gwis.foldLeft(Eval.now(List.empty[Generated], Set.empty[Int]))(addGenerated).value._1
}
}
18 changes: 9 additions & 9 deletions src/main/scala/scalats/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,30 @@ def generateAll(all: Map[File, List[TsModel]], debug: Boolean = false, debugFilt
using customType: TsCustomType,
customOrd: TsCustomOrd,
imports: TsImports.Available,
): List[(File, (List[(Option[TypeName], Generated)], ReferencedTypes))] = {
): List[(File, List[(Option[TypeName], Generated)])] = {
val generator = new TsGenerator(customType, customOrd, imports, debug, debugFilter)
all.toList.foldMap { case (file, models) =>
Map(file.toString -> (models.flatMap(generator.generateTopLevel), models.foldMap(generator.referencedTypes)))
}.toList.map { case (file, (types, refs)) => (new File(file), (types.distinctBy(_._1.map(_.full)), refs)) }
Map(file.toString -> models.flatMap(generator.generateTopLevel))
}.toList.map { case (file, types) => (new File(file), types.distinctBy(_._1.map(_.full))) }
}

private def withFileWriter[A](file: File)(f: PrintStream => A): A = {
new File(file.getParent).mkdirs
Using.resource(new PrintStream(file))(f)
}

private def mkAllTypesByFile(all: List[(File, (List[(Option[TypeName], Generated)], ReferencedTypes))]): Map[String, Set[TypeName]] =
all.foldMap { case (f, ts) => Map(f.toString -> ts._1.flatMap(_._1).toSet) }
private def mkAllTypesByFile(all: List[(File, List[(Option[TypeName], Generated)])]): Map[String, Set[TypeName]] =
all.foldMap { case (f, ts) => Map(f.toString -> ts.flatMap(_._1).toSet) }

/**
* Write the types generated by [[scalats.generateAll]] to the file system
*
* @param all The generated types
*/
def writeAll(all: List[(File, (List[(Option[TypeName], Generated)], ReferencedTypes))]): Unit = {
def writeAll(all: List[(File, List[(Option[TypeName], Generated)])]): Unit = {
val allTypesByFile = mkAllTypesByFile(all)
val res = all.map { case (file, (types, refs)) =>
val (allImports, allCode) = TypeSorter.sort(types, refs).foldMap { case Generated(imports, code) =>
val res = all.map { case (file, types) =>
val (allImports, allCode) = TypeSorter.sort(types).foldMap { case Generated(imports, code) =>
val (updImports, updCode) = imports.resolve(file.toString, allTypesByFile, code)
(updImports, updCode + "\n\n")
}
Expand Down Expand Up @@ -91,7 +91,7 @@ def writeAll(all: Map[File, List[TsModel]], debug: Boolean = false, debugFilter:
def resolve(
currFile: File,
generated: Generated,
allGenerated: List[(File, (List[(Option[TypeName], Generated)], ReferencedTypes))],
allGenerated: List[(File, List[(Option[TypeName], Generated)])],
): (Map[TsImport.Resolved, TsImport], String) = {
val allTypesByFile = mkAllTypesByFile(allGenerated)
val Generated(imports, code) = generated
Expand Down

0 comments on commit d5f86d3

Please sign in to comment.