Skip to content

Commit

Permalink
Merge pull request #1908 from blast-hardcheese/lookupStatusCode-tracker
Browse files Browse the repository at this point in the history
Including Tracker info in lookupStatusCode
  • Loading branch information
blast-hardcheese committed Dec 31, 2023
2 parents 910781b + e3813ba commit e3c9b94
Show file tree
Hide file tree
Showing 10 changed files with 33 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object Responses {
instances <- responses
.traverse { case (key, resp) =>
for {
httpCode <- lookupStatusCode(key)
httpCode <- lookupStatusCode(Tracker.cloneHistory(resp, key))
(statusCode, statusCodeName) = httpCode
valueTypes <- (for {
(rawContentType, content) <- resp.downField("content", _.getContent()).indexedDistribute.value
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
package dev.guardrail.terms.framework

import dev.guardrail.languages.LA
import dev.guardrail.core.Tracker

abstract class FrameworkTerms[L <: LA, F[_]] { self =>
def getFrameworkImports(tracing: Boolean): F[List[L#Import]]
def getFrameworkImplicits(): F[Option[(L#TermName, L#ObjectDefinition)]]
def getFrameworkDefinitions(tracing: Boolean): F[List[(L#TermName, List[L#Definition])]]
def lookupStatusCode(key: String): F[(Int, L#TermName)]
def lookupStatusCode(key: Tracker[String]): F[(Int, L#TermName)]
def fileType(format: Option[String]): F[L#Type]
def objectType(format: Option[String]): F[L#Type]

def copy(
getFrameworkImports: Boolean => F[List[L#Import]] = self.getFrameworkImports _,
getFrameworkImplicits: () => F[Option[(L#TermName, L#ObjectDefinition)]] = self.getFrameworkImplicits _,
getFrameworkDefinitions: Boolean => F[List[(L#TermName, List[L#Definition])]] = self.getFrameworkDefinitions _,
lookupStatusCode: String => F[(Int, L#TermName)] = self.lookupStatusCode _,
lookupStatusCode: Tracker[String] => F[(Int, L#TermName)] = self.lookupStatusCode _,
fileType: Option[String] => F[L#Type] = self.fileType _,
objectType: Option[String] => F[L#Type] = self.objectType _
) = {
Expand All @@ -29,7 +30,7 @@ abstract class FrameworkTerms[L <: LA, F[_]] { self =>
def getFrameworkImports(tracing: Boolean) = newGetFrameworkImports(tracing)
def getFrameworkImplicits() = newGetFrameworkImplicits()
def getFrameworkDefinitions(tracing: Boolean) = newGetFrameworkDefinitions(tracing)
def lookupStatusCode(key: String) = newLookupStatusCode(key)
def lookupStatusCode(key: Tracker[String]) = newLookupStatusCode(key)
def fileType(format: Option[String]) = newFileType(format)
def objectType(format: Option[String]) = newObjectType(format)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import com.github.javaparser.ast.expr._
import scala.reflect.runtime.universe.typeTag

import dev.guardrail.Target
import dev.guardrail.core.SupportDefinition
import dev.guardrail.core.{ SupportDefinition, Tracker }
import dev.guardrail.generators.java.JavaCollectionsGenerator
import dev.guardrail.generators.java.JavaLanguage
import dev.guardrail.generators.java.JavaVavrCollectionsGenerator
Expand Down Expand Up @@ -58,10 +58,10 @@ class DropwizardGenerator private extends FrameworkTerms[JavaLanguage, Target] {
)
)

def lookupStatusCode(key: String) = {
def lookupStatusCode(key: Tracker[String]) = {
def parseStatusCode(code: Int, termName: String): Target[(Int, Name)] =
safeParseName(termName).map(name => (code, name))
key match {
key.unwrapTracker match {
case "100" => parseStatusCode(100, "Continue")
case "101" => parseStatusCode(101, "SwitchingProtocols")
case "102" => parseStatusCode(102, "Processing")
Expand Down Expand Up @@ -128,7 +128,7 @@ class DropwizardGenerator private extends FrameworkTerms[JavaLanguage, Target] {
case "511" => parseStatusCode(511, "NetworkAuthenticationRequired")
case "598" => parseStatusCode(598, "NetworkReadTimeout")
case "599" => parseStatusCode(599, "NetworkConnectTimeout")
case _ => Target.raiseUserError(s"Unknown HTTP status code: ${key}")
case code => Target.raiseUserError(s"Unknown HTTP status code: ${code} (${key.showHistory})")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.github.javaparser.ast.expr.Name
import scala.reflect.runtime.universe.typeTag

import dev.guardrail.Target
import dev.guardrail.core.Tracker
import dev.guardrail.generators.java.JavaCollectionsGenerator
import dev.guardrail.generators.java.JavaLanguage
import dev.guardrail.generators.java.JavaVavrCollectionsGenerator
Expand Down Expand Up @@ -40,11 +41,11 @@ class SpringMvcGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage,
def getFrameworkDefinitions(tracing: Boolean) =
Target.pure(List.empty)

def lookupStatusCode(key: String) = {
def lookupStatusCode(key: Tracker[String]) = {
def parseStatusCode(code: Int, termName: String): Target[(Int, Name)] =
safeParseName(termName).map(name => (code, name))

key match {
key.unwrapTracker match {
case "100" => parseStatusCode(100, "Continue")
case "101" => parseStatusCode(101, "SwitchingProtocols")
case "102" => parseStatusCode(102, "Processing")
Expand Down Expand Up @@ -111,7 +112,7 @@ class SpringMvcGenerator private (implicit Cl: CollectionsLibTerms[JavaLanguage,
case "511" => parseStatusCode(511, "NetworkAuthenticationRequired")
case "598" => parseStatusCode(598, "NetworkReadTimeout")
case "599" => parseStatusCode(599, "NetworkConnectTimeout")
case _ => Target.raiseUserError(s"Unknown HTTP status code: ${key}")
case code => Target.raiseUserError(s"Unknown HTTP status code: ${code} (${key.showHistory})")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import scala.meta._
import scala.reflect.runtime.universe.typeTag

import dev.guardrail.{ RuntimeFailure, Target }
import dev.guardrail.core.Tracker
import dev.guardrail.generators.scala.{ CirceModelGenerator, CirceRefinedModelGenerator, JacksonModelGenerator, ModelGeneratorType, ScalaLanguage }
import dev.guardrail.generators.spi.{ FrameworkGeneratorLoader, ModuleLoadResult, ProtocolGeneratorLoader }
import dev.guardrail.terms.framework._
Expand Down Expand Up @@ -383,8 +384,8 @@ class AkkaHttpGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGenerato
override def getFrameworkDefinitions(tracing: Boolean) =
Target.pure(List.empty)

override def lookupStatusCode(key: String) =
key match {
override def lookupStatusCode(key: Tracker[String]) =
key.unwrapTracker match {
case "100" => Target.pure((100, q"Continue"))
case "101" => Target.pure((101, q"SwitchingProtocols"))
case "102" => Target.pure((102, q"Processing"))
Expand Down Expand Up @@ -464,6 +465,6 @@ class AkkaHttpGenerator private (akkaHttpVersion: AkkaHttpVersion, modelGenerato
case "511" => Target.pure((511, q"NetworkAuthenticationRequired"))
case "598" => Target.pure((598, q"NetworkReadTimeout"))
case "599" => Target.pure((599, q"NetworkConnectTimeout"))
case _ => Target.raiseUserError(s"Unknown HTTP status code: ${key}")
case code => Target.raiseUserError(s"Unknown HTTP status code: ${code} (${key.showHistory})")
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.guardrail.generators.scala.dropwizard

import dev.guardrail.{ Target, UserError }
import dev.guardrail.core.Tracker
import dev.guardrail.generators.scala.ScalaLanguage
import dev.guardrail.terms.framework.FrameworkTerms
import dev.guardrail.generators.spi.{ FrameworkGeneratorLoader, ModuleLoadResult }
Expand Down Expand Up @@ -30,8 +31,8 @@ class DropwizardGenerator private extends FrameworkTerms[ScalaLanguage, Target]

// jaxrs has a Status enum, but it is missing a _lot_ of codes,
// so we'll make our own here and use ints in the generated code
override def lookupStatusCode(key: String): Target[(Int, Term.Name)] =
key match {
override def lookupStatusCode(key: Tracker[String]): Target[(Int, Term.Name)] =
key.unwrapTracker match {
case "100" => Target.pure((100, q"Continue"))
case "101" => Target.pure((101, q"SwitchingProtocols"))
case "102" => Target.pure((102, q"Processing"))
Expand Down Expand Up @@ -108,6 +109,6 @@ class DropwizardGenerator private extends FrameworkTerms[ScalaLanguage, Target]
.filter(code => code >= 100 && code <= 599)
.map((_, Term.Name(s"StatusCode$custom")))
.toOption
Target.fromOption(customCode, UserError(s"'$custom' is not a valid HTTP status code"))
Target.fromOption(customCode, UserError(s"'$custom' is not a valid HTTP status code (${key.showHistory})"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import scala.meta._
import scala.reflect.runtime.universe.typeTag

import dev.guardrail.Target
import dev.guardrail.core.Tracker
import dev.guardrail.generators.scala._
import dev.guardrail.generators.spi.{ FrameworkGeneratorLoader, ModuleLoadResult }
import dev.guardrail.terms.framework._
Expand Down Expand Up @@ -89,8 +90,8 @@ class Http4sGenerator private extends FrameworkTerms[ScalaLanguage, Target] {
def getFrameworkDefinitions(tracing: Boolean) =
Target.pure(List.empty)

def lookupStatusCode(key: String): Target[(Int, scala.meta.Term.Name)] =
key match {
def lookupStatusCode(key: Tracker[String]): Target[(Int, scala.meta.Term.Name)] =
key.unwrapTracker match {
case "100" => Target.pure((100, q"Continue"))
case "101" => Target.pure((101, q"SwitchingProtocols"))
case "102" => Target.pure((102, q"Processing"))
Expand Down Expand Up @@ -149,6 +150,6 @@ class Http4sGenerator private extends FrameworkTerms[ScalaLanguage, Target] {
case "508" => Target.pure((508, q"LoopDetected"))
case "510" => Target.pure((510, q"NotExtended"))
case "511" => Target.pure((511, q"NetworkAuthenticationRequired"))
case _ => Target.raiseUserError(s"Unknown HTTP status code: ${key}")
case code => Target.raiseUserError(s"Unknown HTTP status code: ${code} (${key.showHistory})")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ class ArrayValidationTest extends AnyFreeSpec with Matchers with SwaggerSpecRunn
def getFrameworkImplicits(): dev.guardrail.Target[Option[
(dev.guardrail.generators.scala.ScalaLanguage#TermName, dev.guardrail.generators.scala.ScalaLanguage#ObjectDefinition)
]] = ???
def getFrameworkImports(tracing: Boolean): dev.guardrail.Target[List[dev.guardrail.generators.scala.ScalaLanguage#Import]] = ???
def lookupStatusCode(key: String): dev.guardrail.Target[(Int, dev.guardrail.generators.scala.ScalaLanguage#TermName)] = ???
def objectType(format: Option[String]): dev.guardrail.Target[dev.guardrail.generators.scala.ScalaLanguage#Type] = Target.pure(t"io.circe.Json")
def getFrameworkImports(tracing: Boolean): dev.guardrail.Target[List[dev.guardrail.generators.scala.ScalaLanguage#Import]] = ???
def lookupStatusCode(key: Tracker[String]): dev.guardrail.Target[(Int, dev.guardrail.generators.scala.ScalaLanguage#TermName)] = ???
def objectType(format: Option[String]): dev.guardrail.Target[dev.guardrail.generators.scala.ScalaLanguage#Type] = Target.pure(t"io.circe.Json")
}
implicit val circeProtocolGenerator: ProtocolTerms[ScalaLanguage, Target] = CirceRefinedProtocolGenerator(CirceRefinedModelGenerator.V012)
implicit val scalaGenerator = ScalaGenerator()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ class BigObjectSpec extends AnyFunSuite with Matchers with SwaggerSpecRunner wit
def getFrameworkImplicits(): dev.guardrail.Target[Option[
(dev.guardrail.generators.scala.ScalaLanguage#TermName, dev.guardrail.generators.scala.ScalaLanguage#ObjectDefinition)
]] = ???
def getFrameworkImports(tracing: Boolean): dev.guardrail.Target[List[dev.guardrail.generators.scala.ScalaLanguage#Import]] = ???
def lookupStatusCode(key: String): dev.guardrail.Target[(Int, dev.guardrail.generators.scala.ScalaLanguage#TermName)] = ???
def objectType(format: Option[String]): dev.guardrail.Target[dev.guardrail.generators.scala.ScalaLanguage#Type] = Target.pure(t"io.circe.Json")
def getFrameworkImports(tracing: Boolean): dev.guardrail.Target[List[dev.guardrail.generators.scala.ScalaLanguage#Import]] = ???
def lookupStatusCode(key: Tracker[String]): dev.guardrail.Target[(Int, dev.guardrail.generators.scala.ScalaLanguage#TermName)] = ???
def objectType(format: Option[String]): dev.guardrail.Target[dev.guardrail.generators.scala.ScalaLanguage#Type] = Target.pure(t"io.circe.Json")
}
implicit val circeProtocolGenerator = CirceProtocolGenerator(CirceModelGenerator.V012)
implicit val scalaGenerator = ScalaGenerator()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ class ValidationTest extends AnyFreeSpec with Matchers with SwaggerSpecRunner wi
def getFrameworkImplicits(): dev.guardrail.Target[Option[
(dev.guardrail.generators.scala.ScalaLanguage#TermName, dev.guardrail.generators.scala.ScalaLanguage#ObjectDefinition)
]] = ???
def getFrameworkImports(tracing: Boolean): dev.guardrail.Target[List[dev.guardrail.generators.scala.ScalaLanguage#Import]] = ???
def lookupStatusCode(key: String): dev.guardrail.Target[(Int, dev.guardrail.generators.scala.ScalaLanguage#TermName)] = ???
def getFrameworkImports(tracing: Boolean): dev.guardrail.Target[List[dev.guardrail.generators.scala.ScalaLanguage#Import]] = ???
def lookupStatusCode(key: Tracker[String]): dev.guardrail.Target[(Int, dev.guardrail.generators.scala.ScalaLanguage#TermName)] = ???
def objectType(format: Option[String]): dev.guardrail.Target[dev.guardrail.generators.scala.ScalaLanguage#Type] = Target.pure(t"io.circe.Json")
}
implicit val circeProtocolGenerator: ProtocolTerms[ScalaLanguage, Target] = CirceRefinedProtocolGenerator(CirceRefinedModelGenerator.V012)
Expand Down

0 comments on commit e3c9b94

Please sign in to comment.