Skip to content

Commit

Permalink
Introduce Related Resource Auth Override
Browse files Browse the repository at this point in the history
- This change allows API developers to create related resources that are fetched with “internal” auths.

- This is done by adding an authOverride value to the reverse relation definition on the primary resource.

- The purpose of this is to give developers the option to do authentication in only one place, the entry/parent resource, and use internal auth for all the related/child resources. This allows the child resources to be a bit more re-usable in that they can be attached to multiple parent resources that may all have different auth schemes.
  • Loading branch information
mbarackman-coursera committed May 3, 2018
1 parent fae3a22 commit 4c2861b
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 107 deletions.
Expand Up @@ -10,6 +10,7 @@ import org.coursera.naptime.ari.graphql.SangriaGraphQlContext
import org.coursera.naptime.ari.graphql.schema.DataMapWithParent
import org.coursera.naptime.ari.graphql.schema.NaptimeResourceUtils
import org.coursera.naptime.ari.graphql.schema.ParentModel
import org.coursera.naptime.schema.AuthOverride
import play.api.libs.json.JsArray
import play.api.libs.json.JsNull
import play.api.libs.json.JsValue
Expand All @@ -25,7 +26,8 @@ case class NaptimeRequest(
resourceName: ResourceName,
arguments: Set[(String, JsValue)],
resourceSchema: RecordDataSchema,
paginationOverride: Option[ResponsePagination] = None)
paginationOverride: Option[ResponsePagination] = None,
authOverride: Option[AuthOverride] = None)

case class NaptimeResponse(
elements: List[DataMapWithParent],
Expand All @@ -44,20 +46,22 @@ case class DeferredNaptimeRequest(
resourceName: ResourceName,
arguments: Set[(String, JsValue)],
resourceSchema: RecordDataSchema,
paginationOverride: Option[ResponsePagination] = None)
paginationOverride: Option[ResponsePagination] = None,
authOverride: Option[AuthOverride] = None)
extends Deferred[Either[NaptimeError, NaptimeResponse]]
with DeferredNaptime {

def toNaptimeRequest(idx: Int): NaptimeRequest = {
NaptimeRequest(RequestId(idx), resourceName, arguments, resourceSchema, paginationOverride)
NaptimeRequest(RequestId(idx), resourceName, arguments, resourceSchema, paginationOverride, authOverride)
}
}

case class DeferredNaptimeElement(
resourceName: ResourceName,
idOpt: Option[JsValue],
arguments: Set[(String, JsValue)],
resourceSchema: RecordDataSchema)
resourceSchema: RecordDataSchema,
authOverride: Option[AuthOverride] = None)
extends Deferred[Either[NaptimeError, NaptimeResponse]]
with DeferredNaptime {

Expand All @@ -68,7 +72,9 @@ case class DeferredNaptimeElement(
arguments ++ idOpt
.map(id => List("ids" -> JsArray(List(id))))
.getOrElse(List.empty),
resourceSchema)
resourceSchema,
paginationOverride = None,
authOverride)
}
}

Expand All @@ -82,23 +88,29 @@ class NaptimeResolver extends DeferredResolver[SangriaGraphQlContext] with Stric
case (d: DeferredNaptime, idx: Int) => d.toNaptimeRequest(idx)
}

val dataByResource = naptimeRequests
.groupBy(_.resourceName)
.map {
case (resourceName, requests) =>
val (forwardRequests, reverseRequests) =
requests.partition(_.arguments.exists(_._1 == "ids"))
val dataByResource = naptimeRequests.groupBy(_.resourceName)
.map { case (resourceName, requests) =>

// Handle MultiGet and Non-Multigets differently, since multigets can be batched
val (forwardRequests, reverseRequests) =
requests.partition(_.arguments.exists(_._1 == "ids"))

// partition forward requests by auth type
val forwardRelations = Future.sequence(forwardRequests.groupBy(_.authOverride).map {
case (authOverride, selectedRequests) =>
fetchForwardRelations(selectedRequests, resourceName, ctx, authOverride)
}).map(_.flatten.toMap)

val reverseRelations =
fetchReverseRelations(reverseRequests, resourceName, ctx)

// Handle MultiGet and Non-Multigets differently, since multigets can be batched
val forwardRelations =
fetchForwardRelations(forwardRequests, resourceName, ctx)
val reverseRelations =
fetchReverseRelations(reverseRequests, resourceName, ctx)
Future
.sequence(List(forwardRelations, reverseRelations))
.map(_.flatten.toMap)
val allRelations = List(
forwardRelations,
reverseRelations)

Future.sequence(allRelations).map(_.flatten.toMap)
}

val allData = Future.sequence(dataByResource).map(_.flatten.toMap)

deferred.zipWithIndex.map {
Expand All @@ -123,46 +135,43 @@ class NaptimeResolver extends DeferredResolver[SangriaGraphQlContext] with Stric
* @return Map of request ids (indexes from the deferred request batching) to either a
* NaptimeError or NaptimeResponse
*/
def fetchForwardRelations(
private[this] def fetchForwardRelations(
requests: Vector[NaptimeRequest],
resourceName: ResourceName,
context: SangriaGraphQlContext)(implicit ec: ExecutionContext)
: Future[Map[RequestId, Either[NaptimeError, NaptimeResponse]]] = {

Future
.sequence {
mergeMultigetRequests(context.requestHeader, requests, resourceName)
.map {
case (request, sourceRequests) =>
context.fetcher.data(request, context.debugMode).map {
case Right(successfulResponse) =>
val parsedElements =
parseElements(request, successfulResponse, requests.head.resourceSchema)
val parsedElementsMap = parsedElements.map { element =>
val id = Option(element.element.get("id"))
.map(NaptimeResourceUtils.parseToJson)
.getOrElse(JsNull)
id -> element
}.toMap
sourceRequests.map { sourceRequest =>
// TODO(bryan): Clean this up
val elements = parseIds(sourceRequest)
.flatMap(parsedElementsMap.get)
.toList
val url = successfulResponse.url.getOrElse("???")
sourceRequest.idx ->
Right[NaptimeError, NaptimeResponse](
NaptimeResponse(elements, sourceRequest.paginationOverride, url, 200, None))
}.toMap
case Left(error) =>
sourceRequests.map { sourceRequest =>
sourceRequest.idx ->
Left(NaptimeError(error.url.getOrElse("???"), error.code, error.message))
}.toMap
}
context: SangriaGraphQlContext,
authOverride: Option[AuthOverride])
(implicit ec: ExecutionContext):
Future[Map[RequestId, Either[NaptimeError, NaptimeResponse]]] = {
Future.sequence {
mergeMultigetRequests(context.requestHeader, requests, resourceName, authOverride)
.map { case (request, sourceRequests) =>
context.fetcher.data(request, context.debugMode).map {
case Right(successfulResponse) =>
val parsedElements = parseElements(
request,
successfulResponse,
requests.head.resourceSchema)
val parsedElementsMap = parsedElements.map { element =>
val id = Option(element.element.get("id"))
.map(NaptimeResourceUtils.parseToJson)
.getOrElse(JsNull)
id -> element
}.toMap
sourceRequests.map { sourceRequest =>
// TODO(bryan): Clean this up
val elements = parseIds(sourceRequest).flatMap(parsedElementsMap.get).toList
val url = successfulResponse.url.getOrElse("???")
sourceRequest.idx ->
Right[NaptimeError, NaptimeResponse](
NaptimeResponse(elements, sourceRequest.paginationOverride, url, 200, None))
}.toMap
case Left(error) =>
sourceRequests.map { sourceRequest =>
sourceRequest.idx ->
Left(NaptimeError(error.url.getOrElse("???"), error.code, error.message))
}.toMap
}
}
.map(_.flatten.toMap)
}}.map(_.flatten.toMap)
}

/**
Expand All @@ -173,20 +182,21 @@ class NaptimeResolver extends DeferredResolver[SangriaGraphQlContext] with Stric
* @param requests A list of NaptimeRequests specifying the resource and arguments
* @return a map of TopLevelRequests -> list of NaptimeRequests that it fulfills
*/
def mergeMultigetRequests(
private[this] def mergeMultigetRequests(
header: RequestHeader,
requests: Vector[NaptimeRequest],
resourceName: ResourceName): Map[Request, Vector[NaptimeRequest]] = {

resourceName: ResourceName,
authOverride: Option[AuthOverride]):
Map[Request, Vector[NaptimeRequest]] = {
requests
.groupBy(_.arguments.filterNot(_._1 == "ids"))
.map {
case (nonIdArguments, innerRequests) =>
// TODO(bryan): Limit multiget requests by number of ids as well, to avoid http limits
Request(
header,
resourceName,
nonIdArguments + ("ids" -> JsArray(parseAndMergeIds(innerRequests)))) -> innerRequests
.map { case (nonIdArguments, innerRequests) =>
// TODO(bryan): Limit multiget requests by number of ids as well, to avoid http limits
Request(
header,
resourceName,
nonIdArguments + ("ids" -> JsArray(parseAndMergeIds(innerRequests))),
authOverride) -> innerRequests
}
}

Expand Down Expand Up @@ -220,31 +230,29 @@ class NaptimeResolver extends DeferredResolver[SangriaGraphQlContext] with Stric
def fetchReverseRelations(
requests: Vector[NaptimeRequest],
resourceName: ResourceName,
context: SangriaGraphQlContext)(implicit ec: ExecutionContext)
: Future[Map[RequestId, Either[NaptimeError, NaptimeResponse]]] = {
Future
.sequence {
requests.map { request =>
val fetcherRequest =
Request(context.requestHeader, resourceName, request.arguments)
context.fetcher
.data(fetcherRequest, context.debugMode)
.map {
case Right(response) =>
val elements = parseElements(fetcherRequest, response, requests.head.resourceSchema)
Right(
NaptimeResponse(
elements,
Some(response.pagination),
response.url.getOrElse("???")))
case Left(error) =>
Left(NaptimeError(error.url.getOrElse("???"), error.code, error.message))
}
.map(res => Map(request.idx -> res))
}
}
.map(_.flatten.toMap)
}
context: SangriaGraphQlContext)
(implicit ec: ExecutionContext):
Future[Map[RequestId, Either[NaptimeError, NaptimeResponse]]] = {
Future.sequence {
requests.map { request =>
val fetcherRequest = Request(
context.requestHeader,
resourceName,
request.arguments,
request.authOverride)
context.fetcher.data(fetcherRequest, context.debugMode).map {
case Right(response) =>
val elements = parseElements(fetcherRequest, response, requests.head.resourceSchema)
Right(
NaptimeResponse(
elements,
Some(response.pagination),
response.url.getOrElse("???")))
case Left(error) =>
Left(NaptimeError(error.url.getOrElse("???"), error.code, error.message))
}.map(res => Map(request.idx -> res))
}}.map(_.flatten.toMap)
}

/**
* Helper to parse the elements in a response into a map of JsValue -> DataMapWithParent
Expand Down Expand Up @@ -279,5 +287,4 @@ class NaptimeResolver extends DeferredResolver[SangriaGraphQlContext] with Stric
byResourceName.headOption.map(_._1)
}
}

}
Expand Up @@ -92,6 +92,8 @@ object NaptimePaginatedResourceField extends StrictLogging {
val providedArguments =
fieldRelationOpt.map(_.arguments.keySet).getOrElse(Set[String]())

val authOverride = fieldRelationOpt.flatMap(_.authOverride)

val arguments = NaptimeResourceUtils
.generateHandlerArguments(handler, includePagination = true)
.filterNot(_.name == "ids")
Expand Down Expand Up @@ -168,7 +170,8 @@ object NaptimePaginatedResourceField extends StrictLogging {
resourceName,
updatedArgs,
resourceMergedType,
paginationOverride))
paginationOverride,
authOverride))
.map {
case Left(error) =>
NaptimeResponse(
Expand Down
Expand Up @@ -119,9 +119,8 @@ object NaptimeResourceField extends StrictLogging {
.getOrElse {
Set("ids" -> NaptimeResourceUtils.parseToJson(context.arg("id")))
}
val args = context.args.raw
.mapValues(NaptimeResourceUtils.parseToJson)
.toSet ++ extraArguments
val args = context.args.raw.mapValues(NaptimeResourceUtils.parseToJson).toSet ++
extraArguments
val idArg = args.find(_._1 == "ids").map(_._2)
val nonIdArgs = args.filter(_._1 != "ids")

Expand All @@ -134,10 +133,18 @@ object NaptimeResourceField extends StrictLogging {
fieldRelationOpt.isEmpty) &&
idArg.forall(Utilities.jsValueIsEmpty)

val authOverride = fieldRelationOpt.flatMap(_.authOverride)

if (isForwardRelationButMissingId) {
Value(null)
} else {
DeferredValue(DeferredNaptimeElement(resourceName, idArg, nonIdArgs, resourceMergedType))
DeferredValue(
DeferredNaptimeElement(
resourceName,
idArg,
nonIdArgs,
resourceMergedType,
authOverride))
.map {
case Left(error) =>
throw NaptimeResolveException(error)
Expand Down
@@ -0,0 +1,6 @@
namespace org.coursera.naptime.schema

typeref AuthOverride = union [
record InternalAuth {}
]

Expand Up @@ -20,4 +20,12 @@ record ReverseRelationAnnotation {
*/
relationType: enum RelationType { FINDER, MULTI_GET, GET, SINGLE_ELEMENT_FINDER }

/**
* This is an optional field which will override the auth headers from the
* original request with whatever is indicated here when requesting the child resource. It was
* introduced to allow the API developer to use "internal auths" for specific related
* resources.
*/
authOverride: AuthOverride?

}
4 changes: 3 additions & 1 deletion naptime/src/main/scala/org/coursera/naptime/ari/models.scala
Expand Up @@ -21,6 +21,7 @@ import com.linkedin.data.schema.DataSchema
import com.linkedin.data.schema.RecordDataSchema
import org.coursera.naptime.ResourceName
import org.coursera.naptime.ResponsePagination
import org.coursera.naptime.schema.AuthOverride
import org.coursera.naptime.schema.Resource
import play.api.libs.json.JsValue
import play.api.mvc.RequestHeader
Expand Down Expand Up @@ -84,7 +85,8 @@ object FullSchema {
case class Request(
requestHeader: RequestHeader,
resource: ResourceName,
arguments: Set[(String, JsValue)])
arguments: Set[(String, JsValue)],
authOverride: Option[AuthOverride])

/**
* This model represents a response from a [[Request]], including elements and pagination
Expand Down

0 comments on commit 4c2861b

Please sign in to comment.