Skip to content

Commit

Permalink
spray directives: reject requests with malformed tracing headers
Browse files Browse the repository at this point in the history
  • Loading branch information
levkhomich committed Dec 27, 2014
1 parent 0410d70 commit cf7539c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class TracingSupportSpecification extends Specification with AkkaTracingSpecific
import pattern.ask
val childActor = {
val ref = TestActorRef(new Actor {
def receive = {
def receive = {
case _: TracingSupport => sender ! "ok"
}
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.github.levkhomich.akka.tracing.http

import scala.util.{ Random, Try, Success }
import scala.util.{ Failure, Random, Try, Success }

import akka.actor.Actor
import shapeless._
Expand All @@ -38,18 +38,16 @@ trait BaseTracingDirectives {

private[this] def tracedEntity[T <: TracingSupport](service: String)(implicit um: FromRequestUnmarshaller[T]): Directive[T :: HNil] =
hextract(ctx => ctx.request.as(um) :: extractSpan(ctx.request) :: ctx.request :: HNil).hflatMap[T :: HNil] {
case Right(value) :: Success(optSpan) :: request :: HNil =>
case Right(value) :: Right(optSpan) :: request :: HNil =>
optSpan.foreach(s => value.init(s.$spanId, s.$traceId.get, s.$parentId))
if (optSpan.map(_.forceSampling).getOrElse(false))
trace.forcedSample(value, service)
else
trace.sample(value, service)
addHttpAnnotations(value, request)
hprovide(value :: HNil)
case Right(value) :: _ :: request :: HNil =>
trace.sample(value, service)
addHttpAnnotations(value, request)
hprovide(value :: HNil)
case Right(value) :: Left(headerName) :: request :: HNil =>
reject(MalformedHeaderRejection(headerName, "invalid value"))
case Left(ContentExpected) :: _ => reject(RequestEntityExpectedRejection)
case Left(UnsupportedContentType(supported)) :: _ => reject(UnsupportedRequestContentTypeRejection(supported))
case Left(MalformedContent(errorMsg, cause)) :: _ => reject(MalformedRequestContentRejection(errorMsg, cause))
Expand Down Expand Up @@ -86,7 +84,7 @@ trait BaseTracingDirectives {
new StandardRoute {
def apply(ctx: RequestContext): Unit = {
extractSpan(ctx.request) match {
case Success(Some(span)) =>
case Right(Some(span)) =>
// only requests with explicit tracing headers can be traced here, because we don't have
// any clues about spanId generated for unmarshalled entity
if (span.forceSampling)
Expand Down Expand Up @@ -164,30 +162,35 @@ private[http] object TracingDirectives {

private[this] val DebugFlag = 1L

def extractSpan(message: HttpMessage): Try[Option[Span]] = {
def extractSpan(message: HttpMessage): Either[String, Option[Span]] = {
def headerStringValue(name: String): Option[String] =
message.headers.find(_.name == name).map(_.value)
def headerLongValue(name: String): Try[Option[Long]] =
Try(headerStringValue(name).map(Span.fromString))
def headerLongValue(name: String): Either[String, Option[Long]] =
Try(headerStringValue(name).map(Span.fromString)) match {
case Failure(e) =>
Left(name)
case Success(v) =>
Right(v)
}
def isFlagSet(v: String, flag: Long): Boolean =
(java.lang.Long.parseLong(v) & flag) == flag
// debug flag forces sampling (see http://git.io/hdEVug)
def forceSampling: Boolean =
headerStringValue(Flags).exists(isFlagSet(_, DebugFlag)) ||
headerStringValue(Sampled).filter(_ == "true").isDefined
def spanId: Long =
headerLongValue(SpanId).toOption.flatten.getOrElse(Random.nextLong)
headerLongValue(SpanId).right.toOption.flatten.getOrElse(Random.nextLong)

headerLongValue(TraceId).flatMap {
headerLongValue(TraceId).right.map({
case Some(traceId) =>
headerLongValue(ParentSpanId).map { parentId =>
headerLongValue(ParentSpanId).right.map { parentId =>
Some(Span(traceId, spanId, parentId, forceSampling))
}
case None if forceSampling =>
Success(Some(Span(Random.nextLong, spanId, None, true)))
Right(Some(Span(Random.nextLong, spanId, None, true)))
case _ =>
Success(None)
}
Right(None)
}).joinRight
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ import scala.collection.JavaConversions._
import scala.concurrent.duration._
import scala.util.Random

import org.specs2.mutable.Specification
import spray.http._
import spray.httpx.unmarshalling.{ Deserialized, FromRequestUnmarshaller }
import spray.routing.HttpService
import spray.testkit.Specs2RouteTest

import org.specs2.matcher.MatchResult
import org.specs2.mutable.Specification

import com.github.levkhomich.akka.tracing._

class TracingDirectivesSpec extends Specification with AkkaTracingSpecification
Expand Down Expand Up @@ -73,9 +71,9 @@ class TracingDirectivesSpec extends Specification with AkkaTracingSpecification
}
}

val spanId = Random.nextLong
val parentId = Random.nextLong
"propagate tracing headers" in {
val spanId = Random.nextLong
val parentId = Random.nextLong
Get(testPath).withHeaders(
HttpHeaders.RawHeader(TracingHeaders.TraceId, Span.asString(spanId)) ::
HttpHeaders.RawHeader(TracingHeaders.ParentSpanId, Span.asString(parentId)) ::
Expand All @@ -87,6 +85,28 @@ class TracingDirectivesSpec extends Specification with AkkaTracingSpecification
checkBinaryAnnotation(span, "request.headers." + TracingHeaders.ParentSpanId, Span.asString(parentId))
}
}

val MalformedHeaderRejection = "The value of HTTP header '%s' was malformed:\ninvalid value"
"reject requests with malformed X-B3-TraceId header" in {
Get(testPath).withHeaders(
HttpHeaders.RawHeader(TracingHeaders.TraceId, "malformed") :: Nil
) ~> sealRoute(tracedHandleWithRoute) ~> check {
response.status mustEqual StatusCodes.BadRequest
responseAs[String] mustEqual (MalformedHeaderRejection format TracingHeaders.TraceId)
}
}

"reject requests with malformed X-B3-ParentTraceId header" in {
val spanId = Random.nextLong
Get(testPath).withHeaders(
HttpHeaders.RawHeader(TracingHeaders.TraceId, Span.asString(spanId)) ::
HttpHeaders.RawHeader(TracingHeaders.ParentSpanId, "malformed") ::
Nil
) ~> sealRoute(tracedHandleWithRoute) ~> check {
response.status mustEqual StatusCodes.BadRequest
responseAs[String] mustEqual (MalformedHeaderRejection format TracingHeaders.ParentSpanId)
}
}
}

"tracedComplete directive" should {
Expand Down

0 comments on commit cf7539c

Please sign in to comment.