Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add default support in several locations #358

Merged
merged 11 commits into from
Aug 18, 2022
7 changes: 5 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ lazy val core = projectMatrix
(ThisBuild / baseDirectory).value / "sampleSpecs" / "errors.smithy",
(ThisBuild / baseDirectory).value / "sampleSpecs" / "example.smithy",
(ThisBuild / baseDirectory).value / "sampleSpecs" / "adtMember.smithy",
(ThisBuild / baseDirectory).value / "sampleSpecs" / "enums.smithy"
(ThisBuild / baseDirectory).value / "sampleSpecs" / "enums.smithy",
(ThisBuild / baseDirectory).value / "sampleSpecs" / "defaults.smithy"
),
(Test / sourceGenerators) := Seq(genSmithyScala(Test).taskValue),
testFrameworks += new TestFramework("weaver.framework.CatsEffect"),
Expand Down Expand Up @@ -394,6 +395,7 @@ lazy val protocol = projectMatrix
.filterNot { case (file, path) =>
path.equalsIgnoreCase("META-INF/smithy/manifest")
},
resolvers += Resolver.mavenLocal,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remove when changing to use official smithy 2 release

libraryDependencies += Dependencies.Smithy.model,
javacOptions ++= Seq("--release", "8")
)
Expand Down Expand Up @@ -624,7 +626,8 @@ lazy val example = projectMatrix
(ThisBuild / baseDirectory).value / "sampleSpecs" / "brandscommon.smithy",
(ThisBuild / baseDirectory).value / "sampleSpecs" / "refined.smithy",
(ThisBuild / baseDirectory).value / "sampleSpecs" / "enums.smithy",
(ThisBuild / baseDirectory).value / "sampleSpecs" / "mixins.smithy"
(ThisBuild / baseDirectory).value / "sampleSpecs" / "mixins.smithy",
(ThisBuild / baseDirectory).value / "sampleSpecs" / "defaults.smithy"
),
Compile / resourceDirectory := (ThisBuild / baseDirectory).value / "modules" / "example" / "resources",
isCE3 := true,
Expand Down
18 changes: 11 additions & 7 deletions modules/codegen/src/smithy4s/codegen/CollisionAvoidance.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ package smithy4s.codegen

import cats.syntax.all._
import cats.~>
import smithy4s.codegen.Hint.Constraint
import smithy4s.codegen.Hint.Native
import smithy4s.codegen.Type.Alias
import smithy4s.codegen.Type.PrimitiveType
import smithy4s.codegen.TypedNode._
Expand Down Expand Up @@ -143,15 +141,21 @@ object CollisionAvoidance {
}

private def modRef(ref: Type.Ref): Type.Ref =
Type.Ref(ref.namespace, ref.name.capitalize)
Type.Ref(ref.namespace, protect(ref.name.capitalize))

private def modNativeHint(hint: Hint.Native): Hint.Native =
Native(smithy4s.recursion.preprocess(modTypedNode)(hint.typedNode))
Hint.Native(smithy4s.recursion.preprocess(modTypedNode)(hint.typedNode))

private def modHint(hint: Hint): Hint = hint match {
case n: Native => modNativeHint(n)
case Constraint(tr, nat) => Constraint(modRef(tr), modNativeHint(nat))
case other => other
case n: Hint.Native => modNativeHint(n)
case Hint.Constraint(tr, nat) =>
Hint.Constraint(modRef(tr), modNativeHint(nat))
case df: Hint.Default => modDefault(df)
case other => other
}

private def modDefault(hint: Hint.Default): Hint.Default = {
Hint.Default(smithy4s.recursion.preprocess(modTypedNode)(hint.typedNode))
}

private def modProduct(p: Product): Product = {
Expand Down
82 changes: 63 additions & 19 deletions modules/codegen/src/smithy4s/codegen/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package smithy4s.codegen

import cats.Functor
import cats.data.NonEmptyList
import cats.syntax.all._
import smithy4s.codegen.TypedNode.FieldTN.OptionalNoneTN
Expand All @@ -29,14 +28,17 @@ import smithy4s.codegen.TypedNode.AltValueTN.TypeAltTN
import smithy4s.codegen.UnionMember._
import smithy4s.codegen.LineSegment.{NameDef, NameRef}
import cats.kernel.Eq
import cats.Traverse
import cats.Applicative
import cats.Eval

case class CompilationUnit(namespace: String, declarations: List[Decl])

sealed trait Decl {
def name: String
def hints: List[Hint]
def nameDef: NameDef = NameDef(name)
def nameRef: NameRef = NameRef(List.empty,name)
def nameRef: NameRef = NameRef(List.empty, name)
}

case class Service(
Expand Down Expand Up @@ -229,6 +231,7 @@ object Hint {
case object ErrorMessage extends Hint
case class Constraint(tr: Type.Ref, native: Native) extends Hint
case class Protocol(traits: List[Type.Ref]) extends Hint
case class Default(typedNode: Fix[TypedNode]) extends Hint
// traits that get rendered generically
case class Native(typedNode: Fix[TypedNode]) extends Hint
case object IntEnum extends Hint
Expand Down Expand Up @@ -273,6 +276,21 @@ object TypedNode {
}
}
object FieldTN {
implicit val fieldTNTraverse: Traverse[FieldTN] = new Traverse[FieldTN] {
def traverse[G[_]: Applicative, A, B](
fa: FieldTN[A]
)(f: A => G[B]): G[FieldTN[B]] =
fa match {
case RequiredTN(value) => f(value).map(RequiredTN(_))
case OptionalSomeTN(value) => f(value).map(OptionalSomeTN(_))
case OptionalNoneTN => Applicative[G].pure(OptionalNoneTN)
}
def foldLeft[A, B](fa: FieldTN[A], b: B)(f: (B, A) => B): B = ???
def foldRight[A, B](fa: FieldTN[A], lb: Eval[B])(
f: (A, Eval[B]) => Eval[B]
): Eval[B] = ???
}

case class RequiredTN[A](value: A) extends FieldTN[A]
case class OptionalSomeTN[A](value: A) extends FieldTN[A]
case object OptionalNoneTN extends FieldTN[Nothing]
Expand All @@ -284,28 +302,54 @@ object TypedNode {
}
}
object AltValueTN {
implicit val altValueTNTraverse: Traverse[AltValueTN] =
new Traverse[AltValueTN] {
def traverse[G[_]: Applicative, A, B](
fa: AltValueTN[A]
)(f: A => G[B]): G[AltValueTN[B]] =
fa match {
case ProductAltTN(value) => f(value).map(ProductAltTN(_))
case TypeAltTN(value) => f(value).map(TypeAltTN(_))
}
def foldLeft[A, B](fa: AltValueTN[A], b: B)(f: (B, A) => B): B = ???
def foldRight[A, B](fa: AltValueTN[A], lb: Eval[B])(
f: (A, Eval[B]) => Eval[B]
): Eval[B] = ???
}

case class ProductAltTN[A](value: A) extends AltValueTN[A]
case class TypeAltTN[A](value: A) extends AltValueTN[A]
}

implicit val typedNodeFunctor: Functor[TypedNode] = new Functor[TypedNode] {
def map[A, B](fa: TypedNode[A])(f: A => B): TypedNode[B] = fa match {
case EnumerationTN(ref, value, ordinal, name) =>
EnumerationTN(ref, value, ordinal, name)
case StructureTN(ref, fields) =>
StructureTN(ref, fields.map(_.map(_.map(f))))
case NewTypeTN(ref, target) =>
NewTypeTN(ref, f(target))
case AltTN(ref, altName, alt) =>
AltTN(ref, altName, alt.map(f))
case MapTN(values) =>
MapTN(values.map(_.leftMap(f).map(f)))
case CollectionTN(collectionType, values) =>
CollectionTN(collectionType, values.map(f))
case PrimitiveTN(prim, value) =>
PrimitiveTN(prim, value)
implicit val typedNodeTraverse: Traverse[TypedNode] =
new Traverse[TypedNode] {
def traverse[G[_], A, B](
fa: TypedNode[A]
)(f: A => G[B])(implicit F: Applicative[G]): G[TypedNode[B]] = fa match {
case EnumerationTN(ref, value, ordinal, name) =>
F.pure(EnumerationTN(ref, value, ordinal, name))
case StructureTN(ref, fields) =>
fields.traverse(_.traverse(_.traverse(f))).map(StructureTN(ref, _))
case NewTypeTN(ref, target) =>
f(target).map(NewTypeTN(ref, _))
case AltTN(ref, altName, alt) =>
alt.traverse(f).map(AltTN(ref, altName, _))
case MapTN(values) =>
values
.traverse { case (k, v) =>
(f(k), f(v)).tupled
}
.map(MapTN(_))
case CollectionTN(collectionType, values) =>
values.traverse(f).map(CollectionTN(collectionType, _))
case PrimitiveTN(prim, value) =>
F.pure(PrimitiveTN(prim, value))
}
def foldLeft[A, B](fa: TypedNode[A], b: B)(f: (B, A) => B): B = ???
def foldRight[A, B](fa: TypedNode[A], lb: Eval[B])(
f: (A, Eval[B]) => Eval[B]
): Eval[B] = ???
}
}

case class EnumerationTN(
ref: Type.Ref,
Expand Down
3 changes: 1 addition & 2 deletions modules/codegen/src/smithy4s/codegen/LineSegment.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import cats.Show
import cats.implicits._
import cats.data.Chain


// LineSegment models segments of a line of code.
sealed trait LineSegment { self =>
def toLine: Line = Line(Chain.one(self))
Expand Down Expand Up @@ -70,4 +69,4 @@ object LineSegment {
implicit def chainShow[A: Show]: Show[Chain[A]] = Show.show { chain =>
chain.toList.map(_.show).mkString
}
}
}
30 changes: 26 additions & 4 deletions modules/codegen/src/smithy4s/codegen/Renderer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,14 @@ private[codegen] class Renderer(compilationUnit: CompilationUnit) { self =>

val renderedErrorUnion = errorUnion.foldMap {
case union @ Union(_, originalName, alts, recursive, hints) =>
renderUnion(union.nameRef, originalName, alts, recursive, hints, error = true)
renderUnion(
union.nameRef,
originalName,
alts,
recursive,
hints,
error = true
)
}

lines(
Expand Down Expand Up @@ -629,11 +636,18 @@ private[codegen] class Renderer(compilationUnit: CompilationUnit) { self =>
noDefault: Boolean = false
): Line = {
field match {
case Field(name, _, tpe, required, _) =>
case Field(name, _, tpe, required, hints) =>
val line = line"$tpe"
line"$name: " + (if (required) line
else Line.optional(line, !noDefault))
val tpeAndDefault = if (required) {
val maybeDefault = hints
.collectFirst { case d @ Hint.Default(_) => d }
.map(renderDefault)
Line.required(line, maybeDefault)
} else {
Line.optional(line, !noDefault)
}

line"$name: " + tpeAndDefault
}
}
private def renderArgs(fields: List[Field]): Line = fields
Expand Down Expand Up @@ -792,6 +806,14 @@ private[codegen] class Renderer(compilationUnit: CompilationUnit) { self =>
._2
)

private def renderDefault(hint: Hint.Default): Line =
Line(
smithy4s.recursion
.cata(renderTypedNode)(hint.typedNode)
.run(true)
._2
)

private def renderHint(hint: Hint): Option[Line] = hint match {
case h: Hint.Native => renderNativeHint(h).some
case Hint.IntEnum => line"${NameRef("smithy4s", "IntEnum")}()".some
Expand Down
45 changes: 42 additions & 3 deletions modules/codegen/src/smithy4s/codegen/SmithyToIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ import software.amazon.smithy.model.Model
import software.amazon.smithy.model.node.Node
import software.amazon.smithy.model.shapes._
import software.amazon.smithy.model.traits.RequiredTrait
import software.amazon.smithy.model.traits.DefaultTrait
import software.amazon.smithy.model.traits._

import scala.jdk.CollectionConverters._
import software.amazon.smithy.model.selector.PathFinder
import scala.annotation.nowarn
import smithy4s.meta.ErrorMessageTrait
import smithy4s.codegen.Type.Alias

object SmithyToIR {

Expand Down Expand Up @@ -103,6 +105,9 @@ private[codegen] class SmithyToIR(model: Model, namespace: String) {

override def listShape(x: ListShape): Option[Decl] = getDefault(x)

@annotation.nowarn("msg=class SetShape in package shapes is deprecated")
override def setShape(x: SetShape): Option[Decl] = getDefault(x)

override def mapShape(x: MapShape): Option[Decl] = getDefault(x)

override def byteShape(x: ByteShape): Option[Decl] = getDefault(x)
Expand Down Expand Up @@ -572,6 +577,32 @@ private[codegen] class SmithyToIR(model: Model, namespace: String) {
}
}

// Captures the data representing the default value of a member shape.
private def maybeDefault(shape: MemberShape): List[Hint.Default] = {
val maybeTrait = shape.getTrait(classOf[DefaultTrait])
if (maybeTrait.isPresent()) {
val tr = maybeTrait.get()
// We're short-circuiting when encountering any external type,
// as we do not have the means to instantiate them in a safe manner.
def unfoldNodeAndTypeIfNotExternal(nodeAndType: NodeAndType) = {
nodeAndType.tpe match {
case _: Type.ExternalType => None
case _ => Some(unfoldNodeAndType(nodeAndType))
}
}
val targetTpe = shape.getTarget.tpe.get
// Constructing the initial value for the refold
val nodeAndType = targetTpe match {
case Alias(_, _, tpe, true) => NodeAndType(tr.toNode(), tpe)
case _ => NodeAndType(tr.toNode(), targetTpe)
}
val maybeTree = anaM(unfoldNodeAndTypeIfNotExternal)(nodeAndType)
maybeTree.map(Hint.Default(_)).toList
} else {
List.empty
}
}

@annotation.nowarn(
"msg=class UniqueItemsTrait in package traits is deprecated"
)
Expand Down Expand Up @@ -624,8 +655,9 @@ private[codegen] class SmithyToIR(model: Model, namespace: String) {
(
member.getMemberName(),
member.tpe,
member.hasTrait(classOf[RequiredTrait]),
hints(member)
member.hasTrait(classOf[RequiredTrait]) ||
member.hasTrait(classOf[DefaultTrait]),
hints(member) ++ maybeDefault(member)
)
}
.collect { case (name, Some(tpe), required, hints) =>
Expand Down Expand Up @@ -764,7 +796,14 @@ private[codegen] class SmithyToIR(model: Model, namespace: String) {
val fieldNames = struct.fields.map(_.name)
val fields: List[TypedNode.FieldTN[NodeAndType]] = structFields.map {
case Field(_, realName, tpe, true, _) =>
val node = map(realName) // validated by smithy
val node = map.get(realName).getOrElse {
struct
.getMember(realName)
.get
.getTrait(classOf[DefaultTrait])
.get
.toNode
} // value or default must be present on required field
TypedNode.FieldTN.RequiredTN(NodeAndType(node, tpe))
case Field(_, realName, tpe, false, _) =>
map.get(realName) match {
Expand Down
7 changes: 7 additions & 0 deletions modules/codegen/src/smithy4s/codegen/ToLine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,13 @@ case class Line(segments: Chain[LineSegment]) {

object Line {

def required(line: Line, default: Option[Line]): Line = {
default match {
case None => line
case Some(value) => line + Literal(" = ") + value
}
}

def optional(line: Line, default: Boolean = false): Line = {
val option =
NameRef("Option").toLine + Literal("[") + line + Literal("]")
Expand Down
Loading