Merge pull request #81 from kubukoz/more-schemas
kubukoz committed Aug 13, 2022
2 parents a3ccbdd + 9913b3f commit a7e67a7
Showing 16 changed files with 322 additions and 71 deletions.
@@ -67,6 +67,7 @@ object DocumentSymbolProvider {
int = _ => Nil,
listed = list => findInList(node.copy(value = list)),
bool = _ => Nil,
nul = _ => Nil,
)

private def findInList(
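Note on the change above: InputNode#fold takes one handler per node kind, so introducing NullLiteral is a source-breaking change that the compiler flags at every fold call site; the document-symbol provider simply maps null to no symbols. A minimal sketch of the same pattern, with handler names assumed from this hunk (the struct and string branches fall outside the visible context):

  // Hypothetical call site; only the handlers visible in the hunk are shown by name.
  node.fold(
    struct = _ => Nil,
    string = _ => Nil,
    int = _ => Nil,
    listed = list => findInList(node.copy(value = list)),
    bool = _ => Nil,
    nul = _ => Nil, // new: null literals contribute no document symbols
  )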
59 changes: 29 additions & 30 deletions core/src/main/scala/playground/NodeEncoderVisitor.scala
Expand Up @@ -20,7 +20,13 @@ import smithy4s.Hints
import smithy4s.Lazy
import smithy4s.Refinement
import smithy4s.ShapeId
import smithy4s.capability.EncoderK
import smithy4s.schema.Alt
import smithy4s.schema.CollectionTag
import smithy4s.schema.CollectionTag.IndexedSeqTag
import smithy4s.schema.CollectionTag.ListTag
import smithy4s.schema.CollectionTag.SetTag
import smithy4s.schema.CollectionTag.VectorTag
import smithy4s.schema.EnumValue
import smithy4s.schema.Field
import smithy4s.schema.Primitive
@@ -42,12 +48,8 @@ import smithy4s.schema.Primitive.PUnit
import smithy4s.schema.Schema
import smithy4s.schema.SchemaField
import smithy4s.schema.SchemaVisitor
import smithy4s.schema.CollectionTag
import smithy4s.capability.EncoderK
import smithy4s.schema.CollectionTag.IndexedSeqTag
import smithy4s.schema.CollectionTag.ListTag
import smithy4s.schema.CollectionTag.SetTag
import smithy4s.schema.CollectionTag.VectorTag
import smithy4s.ByteArray
import playground.smithyql.NullLiteral

trait NodeEncoder[A] {
def toNode(a: A): InputNode[Id]
@@ -87,19 +89,19 @@ object NodeEncoderVisitor extends SchemaVisitor[NodeEncoder] { self =>
def primitive[P](shapeId: ShapeId, hints: Hints, tag: Primitive[P]): NodeEncoder[P] =
tag match {
case PInt => int
case PShort => unsupported("short")
case PLong => int.contramap(_.toInt) // todo: wraps
case PShort => short
case PLong => long
case PString => string
case PBigInt => unsupported("bigint")
case PBigInt => bigint
case PBoolean => boolean
case PBigDecimal => bigdecimal
case PBlob => string.contramap(_.toString) // todo this only works for UTF-8 text
case PDouble => int.contramap(_.toInt) // todo: wraps
case PBlob => string.contramap((_: ByteArray).toString)
case PDouble => double
case PDocument => document
case PFloat => unsupported("float")
case PUnit => struct(shapeId, hints, Vector.empty, _ => ())
case PFloat => float
case PUnit => _ => obj(Nil)
case PUUID => string.contramap(_.toString())
case PByte => unsupported("byte")
case PByte => byte
case PTimestamp => string.contramap(_.toString)
}

@@ -205,11 +207,15 @@ object NodeEncoderVisitor extends SchemaVisitor[NodeEncoder] { self =>
value => mapped.value.toNode(value)
}

def unsupported[A](tag: String): NodeEncoder[A] =
v => throw new Exception(s"Unsupported operation: $tag for value $v")

val bigdecimal: NodeEncoder[BigDecimal] = unsupported("bigdecimal")
val int: NodeEncoder[Int] = IntLiteral(_)
private val number: NodeEncoder[String] = IntLiteral(_)
val bigdecimal: NodeEncoder[BigDecimal] = number.contramap(_.toString)
val bigint: NodeEncoder[BigInt] = number.contramap(_.toString)
val long: NodeEncoder[Long] = number.contramap(_.toString)
val int: NodeEncoder[Int] = number.contramap(_.toString)
val short: NodeEncoder[Short] = number.contramap(_.toString)
val byte: NodeEncoder[Byte] = number.contramap(_.toString)
val float: NodeEncoder[Float] = number.contramap(_.toString)
val double: NodeEncoder[Double] = number.contramap(_.toString)

val string: NodeEncoder[String] = StringLiteral(_)

@@ -228,17 +234,10 @@ object NodeEncoderVisitor extends SchemaVisitor[NodeEncoder] { self =>
doc match {
case DArray(value) => document.listed.toNode(value.toList)
case DBoolean(value) => boolean.toNode(value)
case DNumber(value) =>
if (value.isValidInt)
int.toNode(value.toInt)
else
// todo other numbers
bigdecimal.toNode(value)
case DNull =>
// todo nul???
unsupported[Null]("null").toNode(null)
case DString(value) => string.toNode(value)
case DObject(value) => obj(value.toList.map(_.map(document.toNode)))
case DNumber(value) => number.toNode(value.toString())
case DNull => NullLiteral()
case DString(value) => string.toNode(value)
case DObject(value) => obj(value.toList.map(_.map(document.toNode)))
}

}
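All numeric primitives in NodeEncoderVisitor now funnel through the single string-backed number encoder and adapt to it with contramap, so the encoder keeps the digits of the source value exactly instead of truncating through Int as the old PLong/PDouble branches did; PUnit likewise encodes as an empty object rather than going through the struct schema. A self-contained sketch of the contramap pattern, with simplified types that only stand in for the playground API:

  // Minimal stand-in for NodeEncoder: a single abstract method plus contramap.
  trait Encoder[A] { self =>
    def encode(a: A): String
    def contramap[B](f: B => A): Encoder[B] = b => self.encode(f(b))
  }

  // One canonical textual encoder, reused for every numeric width:
  val number: Encoder[String] = s => s
  val long: Encoder[Long] = number.contramap(_.toString)
  val big: Encoder[BigDecimal] = number.contramap(_.toString)

  // long.encode(Long.MaxValue) == "9223372036854775807"
  // (the previous int.contramap(_.toInt) route would have wrapped this value)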
87 changes: 57 additions & 30 deletions core/src/main/scala/playground/QueryCompiler.scala
@@ -48,12 +48,13 @@ import java.util.UUID
import types._
import util.chaining._
import PartialCompiler.WAST
import smithy4s.ByteArray
import java.util.Base64

trait PartialCompiler[A] {
final def emap[B](f: A => PartialCompiler.Result[B]): PartialCompiler[B] =
ast => compile(ast).flatMap(f)

// TODO: Actually use the powers of Ior. Maybe a custom monad for errors / warnings? Diagnosed[A]? Either+Writer composition?
def compile(ast: WAST): PartialCompiler.Result[A]

}
@@ -163,6 +164,8 @@ sealed trait CompilationErrorDetails extends Product with Serializable {
case Message(text) => text
case DeprecatedItem(info) => "Deprecated" + CompletionItem.deprecationString(info)
case InvalidUUID => "Invalid UUID"
case InvalidBlob => "Invalid blob, expected base64-encoded string"
case NumberOutOfRange(value, expectedType) => s"Number out of range for $expectedType: $value"
case EnumFallback(enumName) =>
s"""Matching enums by value is deprecated and may be removed in the future. Use $enumName instead.""".stripMargin
case DuplicateItem => "Duplicate item - some entries will be dropped to fit in a set shape."
@@ -181,8 +184,6 @@

case TypeMismatch(expected, actual) => s"Type mismatch: expected $expected, got $actual."

case UnsupportedNode(tag) => s"Unsupported operation: $tag"

case OperationNotFound(name, validOperations) =>
s"Operation ${name.text} not found. Available operations: ${validOperations.map(_.text).mkString_(", ")}."

@@ -250,6 +251,9 @@ object CompilationErrorDetails {

case object InvalidUUID extends CompilationErrorDetails

final case class NumberOutOfRange(numberValue: String, typeName: String)
extends CompilationErrorDetails

final case class InvalidTimestampFormat(expected: TimestampFormat) extends CompilationErrorDetails

final case class MissingDiscriminator(possibleValues: NonEmptyList[String])
@@ -271,36 +275,64 @@ object CompilationErrorDetails {

final case class RefinementFailure(msg: String) extends CompilationErrorDetails

final case class UnsupportedNode(tag: String) extends CompilationErrorDetails

case object DuplicateItem extends CompilationErrorDetails

case object InvalidBlob extends CompilationErrorDetails

case class DeprecatedItem(info: api.Deprecated) extends CompilationErrorDetails

final case class EnumFallback(enumName: String) extends CompilationErrorDetails
}

object QueryCompiler extends SchemaVisitor[PartialCompiler] {

private def checkRange[A, B](
pc: PartialCompiler[A]
)(
tag: String
)(
matchToRange: A => Option[B]
) = (pc, PartialCompiler.pos).tupled.emap { case (i, range) =>
matchToRange(i)
.toRightIor(
CompilationError
.error(NumberOutOfRange(i.toString, tag), range)
)
.toIorNec
}

def primitive[P](shapeId: ShapeId, hints: Hints, tag: Primitive[P]): PartialCompiler[P] =
tag match {
case PString => string
case PBoolean =>
PartialCompiler
.typeCheck(NodeKind.Bool) { case b @ BooleanLiteral(_) => b }
.map(_.value.value)
case PUnit => struct(shapeId, hints, Vector.empty, _ => ())
case PInt =>
PartialCompiler
.typeCheck(NodeKind.IntLiteral) { case i @ IntLiteral(_) => i }
.map(_.value.value)
case PDocument => document
case PShort => unsupported("short")
case PBlob => unsupported("blob")
case PByte => unsupported("byte")
case PBigDecimal => unsupported("bigDecimal")
case PDouble => unsupported("double")
case PBigInt => unsupported("bigint")
case PUnit => struct(shapeId, hints, Vector.empty, _ => ())
case PLong => checkRange(integer)("int")(_.toLongOption)
case PInt => checkRange(integer)("int")(_.toIntOption)
case PShort => checkRange(integer)("short")(_.toShortOption)
case PByte => checkRange(integer)("byte")(_.toByteOption)
case PFloat => checkRange(integer)("float")(_.toFloatOption)
case PDouble => checkRange(integer)("double")(_.toDoubleOption)
case PDocument => document
case PBlob =>
(string, PartialCompiler.pos).tupled.emap { case (s, range) =>
Either
.catchNonFatal(Base64.getDecoder().decode(s))
.map(ByteArray(_))
.leftMap(_ => CompilationError.error(CompilationErrorDetails.InvalidBlob, range))
.toIor
.toIorNec
}
case PBigDecimal =>
checkRange(integer)("bigdecimal") { s =>
Either.catchNonFatal(BigDecimal(s)).toOption
}
case PBigInt =>
checkRange(integer)("bigint") { s =>
Either.catchNonFatal(BigInt(s)).toOption
}
case PUUID =>
stringLiteral.emap { s =>
Either
@@ -310,8 +342,6 @@ object QueryCompiler extends SchemaVisitor[PartialCompiler] {
.toIorNec
}

case PLong => unsupported("long")
case PFloat => unsupported("float")
case PTimestamp =>
stringLiteral.emap { s =>
// We don't support other formats for the simple reason that it's not necessary:
@@ -331,6 +361,10 @@ object QueryCompiler extends SchemaVisitor[PartialCompiler] {
}
}

private val integer: PartialCompiler[String] = PartialCompiler
.typeCheck(NodeKind.IntLiteral) { case i @ IntLiteral(_) => i }
.map(_.value.value)

def collection[C[_], A](
shapeId: ShapeId,
hints: Hints,
@@ -570,23 +604,15 @@ object QueryCompiler extends SchemaVisitor[PartialCompiler] {
it.value.compile(_)
}

def unsupported[A](ctx: String): PartialCompiler[A] =
ast =>
Ior.leftNec(
CompilationError.error(
UnsupportedNode(ctx),
ast.range,
)
)

val stringLiteral =
PartialCompiler.typeCheck(NodeKind.StringLiteral) { case StringLiteral(s) => s }

val document: PartialCompiler[Document] =
_.value match {
case BooleanLiteral(value) => Document.fromBoolean(value).pure[PartialCompiler.Result]
case IntLiteral(value) => Document.fromInt(value).pure[PartialCompiler.Result]
case StringLiteral(value) => Document.fromString(value).pure[PartialCompiler.Result]
case IntLiteral(value) =>
Document.fromBigDecimal(BigDecimal(value)).pure[PartialCompiler.Result]
case StringLiteral(value) => Document.fromString(value).pure[PartialCompiler.Result]
// parTraverse in this file isn't going to work like you think it will
case Listed(values) => values.value.parTraverse(document.compile(_)).map(Document.array(_))
case Struct(fields) =>
@@ -595,6 +621,7 @@ object QueryCompiler extends SchemaVisitor[PartialCompiler] {
.value
.parTraverse { case (key, value) => document.compile(value).tupleLeft(key.value.text) }
.map(Document.obj(_: _*))
case NullLiteral() => Document.nullDoc.rightIor
}

val string = stringLiteral.map(_.value)
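On the decoding side, checkRange reads each numeric literal as raw text and only then narrows it with an Option-returning conversion, so an out-of-range value becomes a NumberOutOfRange diagnostic at the literal's position rather than a silent wrap-around; blobs follow the same shape with base64 decoding and InvalidBlob. A standalone sketch of both ideas, with the Ior/CompilationError machinery reduced to Either for brevity:

  import java.util.Base64
  import scala.util.Try

  // Numbers: parse as text, then narrow; None means "out of range for this type".
  def checkRange[B](raw: String, typeName: String)(narrow: String => Option[B]): Either[String, B] =
    narrow(raw).toRight(s"Number out of range for $typeName: $raw")

  checkRange("300", "short")(_.toShortOption) // Right(300)
  checkRange("300", "byte")(_.toByteOption)   // Left("Number out of range for byte: 300")

  // Blobs: accept a base64-encoded string, turning decode failures into a diagnostic.
  def decodeBlob(s: String): Either[String, Array[Byte]] =
    Try(Base64.getDecoder.decode(s)).toEither
      .left.map(_ => "Invalid blob, expected base64-encoded string")

The BigDecimal and BigInt branches reuse the same helper with Either.catchNonFatal(...).toOption, and the document compiler now routes IntLiteral through Document.fromBigDecimal so large values survive untouched.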
1 change: 0 additions & 1 deletion core/src/main/scala/playground/run.scala
@@ -86,7 +86,6 @@ object Compiler {
dsi
.allServices
.map { svc =>
// todo: deprecated services (here / in completions)
QualifiedIdentifier
.forService(svc.service) -> Compiler.fromService[svc.Alg, svc.Op](svc.service)
}
10 changes: 9 additions & 1 deletion core/src/main/scala/playground/smithyql/AST.scala
@@ -30,13 +30,15 @@ sealed trait InputNode[F[_]] extends AST[F] {
int: IntLiteral[F] => A,
listed: Listed[F] => A,
bool: BooleanLiteral[F] => A,
nul: NullLiteral[F] => A,
): A =
this match {
case s @ Struct(_) => struct(s)
case i @ IntLiteral(_) => int(i)
case b @ BooleanLiteral(_) => bool(b)
case s @ StringLiteral(_) => string(s)
case l @ Listed(_) => listed(l)
case n @ NullLiteral() => nul(n)
}

def mapK[G[_]: Functor](fk: F ~> G): InputNode[G]
@@ -158,7 +160,12 @@ object Struct {

}

final case class IntLiteral[F[_]](value: Int) extends InputNode[F] {
final case class NullLiteral[F[_]]() extends InputNode[F] {
def kind: NodeKind = NodeKind.NullLiteral
def mapK[G[_]: Functor](fk: F ~> G): InputNode[G] = copy()
}

final case class IntLiteral[F[_]](value: String) extends InputNode[F] {
def kind: NodeKind = NodeKind.IntLiteral
def mapK[G[_]: Functor](fk: F ~> G): InputNode[G] = copy()
}
@@ -188,6 +195,7 @@ sealed trait NodeKind extends Product with Serializable
object NodeKind {
case object Struct extends NodeKind
case object IntLiteral extends NodeKind
case object NullLiteral extends NodeKind
case object StringLiteral extends NodeKind
case object Query extends NodeKind
case object Listed extends NodeKind
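Two AST changes carry all of the above: IntLiteral now stores the literal's raw text instead of an Int, and NullLiteral is a first-class input node with its own NodeKind. A hedged sketch of working with the widened nodes, assuming cats.Id and the playground.smithyql types as used elsewhere in this diff:

  import cats.Id
  import playground.smithyql._

  // The digits are kept verbatim, so values wider than Int (or fractional ones) are representable:
  val big: InputNode[Id] = IntLiteral[Id]("9223372036854775808")

  def describe(node: InputNode[Id]): String =
    node match {
      case IntLiteral(digits)  => s"number $digits"
      case NullLiteral()       => "null"
      case StringLiteral(text) => s"string $text"
      case BooleanLiteral(b)   => s"boolean $b"
      case other               => other.kind.toString
    }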
7 changes: 6 additions & 1 deletion core/src/main/scala/playground/smithyql/DSL.scala
@@ -33,7 +33,12 @@ object DSL {
): Struct[Id] = Struct[Id](Struct.Fields.fromSeq[Id](args.map(_.leftMap(Struct.Key(_)))))

implicit def stringToAST(s: String): StringLiteral[Id] = StringLiteral[Id](s)
implicit def intToAST(i: Int): IntLiteral[Id] = IntLiteral[Id](i)
implicit def intToAST(i: Int): IntLiteral[Id] = IntLiteral[Id](i.toString)
implicit def longToAST(i: Long): IntLiteral[Id] = IntLiteral[Id](i.toString)
implicit def floatToAST(i: Float): IntLiteral[Id] = IntLiteral[Id](i.toString)
implicit def doubleToAST(i: Double): IntLiteral[Id] = IntLiteral[Id](i.toString)
implicit def bigIntToAST(i: BigInt): IntLiteral[Id] = IntLiteral[Id](i.toString)
implicit def bigDecimalToAST(i: BigDecimal): IntLiteral[Id] = IntLiteral[Id](i.toString)
implicit def boolToAST(b: Boolean): BooleanLiteral[Id] = BooleanLiteral[Id](b)

implicit def listToAST[A](
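The DSL gains matching conversions, so tests and examples can use Long, Float, Double, BigInt, and BigDecimal literals directly and obtain the new string-backed IntLiteral. A small sketch of what that enables, assuming only the implicits shown in this hunk:

  import cats.Id
  import playground.smithyql.IntLiteral
  import playground.smithyql.DSL._

  val asLong: IntLiteral[Id] = 9223372036854775807L // via longToAST
  val asBig: IntLiteral[Id] = BigDecimal("123.456")  // via bigDecimalToAST
  // asLong.value == "9223372036854775807", asBig.value == "123.456"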
1 change: 1 addition & 0 deletions core/src/main/scala/playground/smithyql/Formatter.scala
@@ -25,6 +25,7 @@ object Formatter {
case BooleanLiteral(b) => Doc.text(b.toString())
case StringLiteral(s) => Doc.text(renderStringLiteral(s))
case l @ Listed(_) => renderSequence(l)
case NullLiteral() => Doc.text("null")
}

def renderOperationName(o: OperationName[WithSource]): Doc = Doc.text(o.text)
8 changes: 5 additions & 3 deletions core/src/main/scala/playground/smithyql/Parser.scala
@@ -97,9 +97,7 @@ object SmithyQLParser {
(Rfc5234.alpha ~ Parser.charsWhile0(ch => ch.isLetterOrDigit || "_".contains(ch)))
.map { case (ch, s) => s.prepended(ch) }

val number: Parser[Int] = Numbers
.signedIntString
.map(_.toInt)
val number: Parser[String] = Numbers.jsonNumber

val bool: Parser[Boolean] = string("true").as(true).orElse(string("false").as(false))

@@ -110,6 +108,8 @@
.with1
.surroundedBy(char('"'))

val nullLiteral: Parser[Unit] = string("null")

def punctuation(c: Char): Parser[Unit] = char(c)

val equalsSign = punctuation('=')
@@ -152,11 +152,13 @@
val boolLiteral = tokens.bool.map(BooleanLiteral[T](_))

val stringLiteral = tokens.stringLiteral.map(StringLiteral[T](_))
val nullLiteral = tokens.nullLiteral.map(_ => NullLiteral[T]())

lazy val node: Parser[InputNode[T]] = Parser.defer {
intLiteral |
boolLiteral |
stringLiteral |
nullLiteral |
struct |
listed
}
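Finally, the parser keeps the whole numeric token by switching from signedIntString.map(_.toInt) to cats-parse's Numbers.jsonNumber, and it recognises the bare null keyword as a NullLiteral; range and type checking are deferred to QueryCompiler above, so parsing itself can no longer overflow. A quick sketch of what the underlying number parser now accepts (plain cats-parse, independent of SmithyQL):

  import cats.parse.Numbers

  Numbers.jsonNumber.parseAll("9223372036854775808") // Right("9223372036854775808"): wider than Long, kept as text
  Numbers.jsonNumber.parseAll("1.5e10")              // Right("1.5e10"): fractional and exponent forms now parse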
