Skip to content

Commit

Permalink
Avoid generating invalid Scala identifiers when printing schemas and …
Browse files Browse the repository at this point in the history
…protocols (#178)

* Avoid generating invalid Scala identifiers from OpenAPI specs

There is already a `normalize` function which makes an attempt to
normalize strings into valid Scala identifiers, but it does not handle
all cases correctly, e.g. it does not handle reserved words.

Added an `ident` function which does a better job at converting an
arbitrary string into a valid identifier, by delegating the logic to
scalameta. Rather than fiddling with the contents of the string to make
it valid, we just wrap it in backticks if required.

Unfortunately, `normalize` is used for a lot of different things, not
only generation of valid identifiers, so I then had to go through the
code and switch from `normalize` to `ident` in *just* the right places
to make the tests pass.

The codegen behaviour has changed slightly (for the better I think) in a
few cases. See the updated unit tests for the details.

* Avoid invalid Scala identifiers when printing Mu schemas/protocols

Similar to the previous commit, just wrapping strings in
`toValidIdentifier` in all the right places.

Extended the existing protobuf printing test to cover this new
behaviour.

* Use constant for scalameta version

Co-Authored-By: Juan Pedro Moreno <4879373+juanpedromoreno@users.noreply.github.com>
  • Loading branch information
cb372 and juanpedromoreno committed Dec 5, 2019
1 parent eac3f8d commit 15e408b
Show file tree
Hide file tree
Showing 13 changed files with 192 additions and 79 deletions.
6 changes: 4 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ val V = new {
val droste = "0.8.0"
val kindProjector = "0.10.3"
val macroParadise = "2.1.1"
val meta = "4.3.0"
val scala212 = "2.12.10"
val scalacheck = "1.14.2"
val specs2 = "4.8.1"
Expand Down Expand Up @@ -103,18 +104,19 @@ lazy val commonSettings = Seq(
crossScalaVersions := Seq(V.scala212),
ThisBuild / scalacOptions -= "-Xplugin-require:macroparadise",
libraryDependencies ++= Seq(
%%("cats-laws", V.cats) % Test,
%%("cats-core", V.cats),
"io.higherkindness" %% "droste-core" % V.droste,
"io.higherkindness" %% "droste-macros" % V.droste,
"org.apache.avro" % "avro" % V.avro,
"com.github.os72" % "protoc-jar" % V.protoc,
"com.google.protobuf" % "protobuf-java" % V.protobuf,
"io.circe" %% "circe-yaml" % V.circeYaml,
"io.circe" %% "circe-testing" % V.circe % Test,
%%("cats-effect", V.catsEffect),
%%("circe-core", V.circe),
%%("circe-parser", V.circe),
"org.scalameta" %% "scalameta" % V.meta,
%%("cats-laws", V.cats) % Test,
"io.circe" %% "circe-testing" % V.circe % Test,
%%("scalacheck", V.scalacheck) % Test,
%%("specs2-core", V.specs2) % Test,
"org.typelevel" %% "discipline-specs2" % V.disciplineSpecs2 % Test,
Expand Down
12 changes: 12 additions & 0 deletions src/main/scala/higherkindness/skeuomorph/Printer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ object Printer {

val string: Printer[String] = print(identity)

val identifier: Printer[String] = print(toValidIdentifier)

val unit: Printer[Unit] = print(_ => "")

object avoid {
Expand All @@ -80,6 +82,16 @@ object Printer {
case xs => xs.map(p.print).mkString(sep)
}

/*
* The logic to decide whether a given string needs to be wrapped in backticks
* to be a valid Scala identifier is really complicated (the spec is at
* https://scala-lang.org/files/archive/spec/2.13/01-lexical-syntax.html#identifiers,
* but it's quite wrong/misleading). So we let scalameta take care of it for us.
*/
def toValidIdentifier(string: String): String =
if (string.isEmpty) string
else scala.meta.Term.Name(string).syntax

implicit val divisiblePrinter: Decidable[Printer] = new Decidable[Printer] {
def unit: Printer[Unit] = Printer.unit
def product[A, B](fa: Printer[A], fb: Printer[B]): Printer[(A, B)] = new Printer[(A, B)] {
Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/higherkindness/skeuomorph/mu/print.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ object print {
case TBoolean() => "Boolean"
case TString() => "String"
case TByteArray() => "Array[Byte]"
case TNamedType(name) => name
case TNamedType(name) => toValidIdentifier(name)
case TOption(value) => s"Option[$value]"
case TEither(a, b) => s"Either[$a, $b]"
case TMap(Some(key), value) => s"Map[$key, $value]"
Expand All @@ -60,8 +60,8 @@ object print {
|}
""".stripMargin
case TProduct(name, fields) =>
val printFields = fields.map(f => s"${f.name}: ${f.tpe}").mkString(", ")
s"@message final case class $name($printFields)"
val printFields = fields.map(f => s"${toValidIdentifier(f.name)}: ${f.tpe}").mkString(", ")
s"@message final case class ${toValidIdentifier(name)}($printFields)"
}

Printer.print(scheme.cata(algebra))
Expand Down Expand Up @@ -184,7 +184,7 @@ object print {
def depImport[T](implicit T: Basis[MuF, T]): Printer[DependentImport[T]] =
(
konst("import ") *< string,
konst(".") *< string,
konst(".") *< identifier,
konst(".") *< Printer.print(namedTypes[T] >>> schema.print)
).contramapN(importTuple)

Expand All @@ -198,7 +198,7 @@ object print {
konst("package ") *< optional(string) >* newLine >* newLine,
sepBy(option, lineFeed),
sepBy(depImport, lineFeed) >* newLine >* newLine,
konst("object ") *< string >* konst(" { ") >* newLine >* newLine,
konst("object ") *< identifier >* konst(" { ") >* newLine >* newLine,
sepBy(Printer.print(nestedOptionInCoproduct[T] >>> schema.print), lineFeed) >* newLine,
sepBy(service, doubleLineFeed) >* (newLine >* newLine >* konst("}"))
).contramapN(protoTuple)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ package object circe {

private def codecsTypes[T](name: String): ((String, Tpe[T]), (String, Tpe[T])) = {
val tpe = Tpe[T](name)
(name -> tpe) -> (s"Option${name}" -> tpe.copy(required = false))
(name -> tpe) -> (s"Option${normalize(name)}" -> tpe.copy(required = false))
}
protected def isIdentifier: String => Boolean =
field => field.headOption.exists(x => x.isLetter && x.isLower) && field.forall(_.isLetterOrDigit)
Expand Down Expand Up @@ -84,7 +84,7 @@ package object circe {
val (default, optionType) = codecsTypes[T](name)
(
enumPackages,
(name -> name, values.map(x => normalize(x) -> x)).some,
(name -> name, values.map(x => ident(x) -> x)).some,
name.asRight.asLeft,
(name, values).asRight.asLeft,
default,
Expand All @@ -107,10 +107,10 @@ package object circe {
}

private def decoderDef[T: Basis[JsonSchemaF, ?], B](body: Printer[B]): Printer[(String, Tpe[T], B)] =
implicitVal(body).contramap { case (x, y, z) => (x, "Decoder", y, z) }
implicitVal(body).contramap { case (x, y, z) => (normalize(x), "Decoder", y, z) }

private def encoderDef[T: Basis[JsonSchemaF, ?], B](body: Printer[B]): Printer[(String, Tpe[T], B)] =
implicitVal(body).contramap { case (x, y, z) => (x, "Encoder", y, z) }
implicitVal(body).contramap { case (x, y, z) => (normalize(x), "Encoder", y, z) }

def enumCirceEncoder[T: Basis[JsonSchemaF, ?], B]: Printer[String] =
encoderDef(κ("Encoder.encodeString.contramap(_.show)")).contramap(x => (x, Tpe[T](x), ()))
Expand All @@ -125,7 +125,7 @@ package object circe {
"\n"
) >* newLine,
κ(""" case x => s"$x is not valid """) *< string >* κ("""".asLeft""") *< newLine *< κ("}") *< newLine
).contramapN(x => flip(second(x)(_.map(x => x -> normalize(x)))))).contramap {
).contramapN(x => flip(second(x)(_.map(x => x -> ident(x)))))).contramap {
case (x, xs) => (x, Tpe[T](x), x -> xs)
}

Expand Down Expand Up @@ -184,18 +184,18 @@ package object circe {
κ("implicit def ") *< string >* κ("EntityDecoder[F[_]:Sync]: "),
κ("EntityDecoder[F, ") *< tpe[T] >* κ("] = "),
κ("jsonOf[F, ") *< tpe[T] >* κ("]"))
.contramapN(x => (x._1, x._2, x._2))
.contramapN(x => (normalize(x._1), x._2, x._2))

def entityEncoder[T: Basis[JsonSchemaF, ?]]: Printer[(String, Tpe[T])] =
(
κ("implicit def ") *< string >* κ("EntityEncoder[F[_]:Applicative]: "),
κ("EntityEncoder[F, ") *< tpe[T] >* κ("] = "),
κ("jsonEncoderOf[F, ") *< tpe[T] >* κ("]"))
.contramapN(x => (x._1, x._2, x._2))
.contramapN(x => (normalize(x._1), x._2, x._2))

private def showEnum[T: Basis[JsonSchemaF, ?]]: Printer[((String, String), List[(String, String)])] =
divBy(
implicitVal(κ("Show.show {")).contramap { case (x, y) => (x, "Show", Tpe[T](y), ()) },
implicitVal(κ("Show.show {")).contramap { case (x, y) => (normalize(x), "Show", Tpe[T](y), ()) },
newLine,
sepBy(divBy(space *< space *< κ("case ") *< string, κ(" => "), κ("\"") *< string >* κ("\"")), "\n") >* newLine >* κ(
"}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object print {
openApi)
}

val listEnconderPrinter: Printer[Unit] = κ(
val listEncoderPrinter: Printer[Unit] = κ(
"implicit def listEntityEncoder[T: Encoder]: EntityEncoder[F, List[T]] = jsonEncoderOf[F, List[T]]")
val listDecoderPrinter: Printer[Unit] = κ(
"implicit def listEntityDecoder[T: Decoder]: EntityDecoder[F, List[T]] = jsonOf[F, List[T]]")
Expand All @@ -66,7 +66,7 @@ object print {
")") *< κ(": ") *< show[TraitName] >* κ("[F]"),
κ(" = new ") *< show[TraitName] >* κ("[F] {") >* newLine,
twoSpaces *< twoSpaces *< sepBy(importDef, "\n") >* newLine *<
twoSpaces *< twoSpaces *< listEnconderPrinter *< newLine *<
twoSpaces *< twoSpaces *< listEncoderPrinter *< newLine *<
twoSpaces *< twoSpaces *< listDecoderPrinter *< newLine *<
twoSpaces *< twoSpaces *< optionListEncoderPrinter *< newLine *<
twoSpaces *< twoSpaces *< optionListDecoderPrinter *< newLine *<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ object print {
val params = varPaths ++ queries
OperationId(
decapitalize(
normalize(operation.operationId
ident(operation.operationId
.getOrElse {
val printVerb = Http.Verb.methodFrom(verb)
val printParamsIfRequired = if (params.nonEmpty) s"By${params.mkString}" else ""
Expand Down
30 changes: 17 additions & 13 deletions src/main/scala/higherkindness/skeuomorph/openapi/print.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ object print {
val parametersRegex = """#/components/parameters/(.+)""".r

def schemaWithName[T: Basis[JsonSchemaF, ?]](implicit codecs: Printer[Codecs]): Printer[(String, T)] = Printer {
case (name, t) if (isBasicType(t)) => typeAliasDef(schema[T]()).print((normalize(name), t, none))
case (name, t) if (isBasicType(t)) => typeAliasDef(schema[T]()).print((ident(name), t, none))
case (name, t) if (isArray(t)) =>
typeAliasDef(schema[T]()).print((normalize(name), t, none))
case (name, t) => schema[T](normalize(name).some).print(t)
typeAliasDef(schema[T]()).print((ident(name), t, none))
case (name, t) => schema[T](ident(name).some).print(t)
}

protected[openapi] def schema[T: Basis[JsonSchemaF, ?]](name: Option[String] = None)(
Expand Down Expand Up @@ -76,9 +76,9 @@ object print {
case (ArrayF(x), _) => listDef.print(x)
case (EnumF(fields), Some(name)) => sealedTraitDef.print(name -> fields)
case (SumF(cases), Some(name)) => sumDef.print(name -> cases)
case (ReferenceF(schemasRegex(ref)), _) => normalize(ref)
case (ReferenceF(parametersRegex(ref)), _) => normalize(ref)
case (ReferenceF(ref), _) => normalize(ref)
case (ReferenceF(schemasRegex(ref)), _) => ident(ref)
case (ReferenceF(parametersRegex(ref)), _) => ident(ref)
case (ReferenceF(ref), _) => ident(ref)
}
}
Printer.print(scheme.cata(algebra))
Expand Down Expand Up @@ -127,7 +127,7 @@ object print {
openApi.components.toList
.flatMap(_.schemas.toList)
.filter { case (_, t) => isSum(t) }
.map { case (x, _) => normalize(x) }
.map { case (x, _) => ident(x) }
}

sealed trait Codecs
Expand All @@ -148,7 +148,7 @@ object print {
x =>
tpe.nestedTypes.headOption.map(_ => s"${tpe.nestedTypes.mkString(".")}.").getOrElse("") +
Printer
.print(Optimize.namedTypes[T](normalize(tpe.description)) >>> schema(none).print)
.print(Optimize.namedTypes[T](ident(tpe.description)) >>> schema(none).print)
.print(x)
)
def option[T: Basis[JsonSchemaF, ?]](tpe: Tpe[T]): Either[String, String] =
Expand All @@ -165,7 +165,7 @@ object print {

final case class Var(name: String) extends AnyVal
object Var {
implicit val varShow: Show[Var] = Show.show(x => decapitalize(normalize(x.name)))
implicit val varShow: Show[Var] = Show.show(x => decapitalize(ident(x.name)))
}
final case class VarWithType[T](name: Var, tpe: Tpe[T])
object VarWithType {
Expand Down Expand Up @@ -199,7 +199,7 @@ object print {
).contramap(x => (x, ((x._1, none), List.empty, SumCodecs.apply _ tupled (x))))

private def caseObjectDef: Printer[(String, String)] =
(κ("final case object ") *< string >* κ(" extends "), string).contramapN { case (x, y) => (normalize(x), y) }
(κ("final case object ") *< string >* κ(" extends "), string).contramapN { case (x, y) => (ident(x), y) }

private def sealedTraitCompanionObjectDef(
implicit codecs: Printer[Codecs]): Printer[(List[(String, String)], Codecs)] =
Expand Down Expand Up @@ -245,12 +245,16 @@ object print {
body >* newLine *< κ("}")
).contramap { case (x, y, z) => (x -> y, z) }

def normalize(value: String): String =
value
def normalize(value: String): String = {
val withoutBackticks = value.stripPrefix("`").stripSuffix("`")
withoutBackticks
.dropWhile(_.isDigit)
.split("[ _-]")
.map(_.filter(x => x.isLetterOrDigit).capitalize)
.mkString ++ value.takeWhile(_.isDigit)
.mkString ++ withoutBackticks.takeWhile(_.isDigit)
}

def ident(value: String): String = Printer.toValidIdentifier(value)

def divBy[A, B](p1: Printer[A], sep: Printer[Unit], p2: Printer[B]): Printer[(A, B)] =
(p1, sep, p2).contramapN[(A, B)] { case (x, y) => (x, (), y) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ class JsonSchemaPrintSpecification extends org.specs2.mutable.Specification {
|}""".stripMargin)
}

"when object is provided whose name is a Scala reserved word" >> {
schemaWithName
.print(
"=>" ->
Fixed.`object`(List("name" -> Fixed.string()), List("name"))) must ===(
s"""|final case class `=>`(name: String)
|object `=>` {
|
|
|}""".stripMargin)
}

"when object is provided without required fields" >> {
schemaWithName
.print(
Expand Down Expand Up @@ -109,6 +121,22 @@ class JsonSchemaPrintSpecification extends org.specs2.mutable.Specification {
|}""".stripMargin)
}

"when object is provided with a field name which is a Scala reserved word" >> {
schemaWithName
.print(
"Person" ->
Fixed.`object`(
List(
"name" -> Fixed.string(),
"type" -> Fixed.string()
),
List("name", "type"))) must ===(s"""|final case class Person(name: String, `type`: String)
|object Person {
|
|
|}""".stripMargin)
}

"when enum is provided" >> {
schemaWithName
.print("Color" -> Fixed.enum(List("Blue", "Red", "Yellow"))) must
Expand Down
Loading

0 comments on commit 15e408b

Please sign in to comment.