Skip to content

Commit

Permalink
improve spray directives test coverage (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
levkhomich committed Dec 26, 2014
1 parent a02fa59 commit 2f72f92
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ trait MockCollector {
def startCollector(): TServer = {
val handler = new thrift.Scribe.Iface {
override def Log(messages: util.List[LogEntry]): ResultCode = {
println(s"collector: received ${messages.size} messages")
println(s"collector: received ${messages.size} message${if (messages.size > 1) "s" else ""}")
results.addAll(messages)
thrift.ResultCode.OK
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ trait BaseTracingDirectives {
import spray.routing.directives.MiscDirectives._
import TracingDirectives._

private[this] def tracedEntity[T <: TracingSupport](service: String)(implicit um: FromRequestUnmarshaller[T]): Directive[T :: BaseTracingSupport :: HNil] =
hextract(ctx => ctx.request.as(um) :: extractSpan(ctx.request) :: ctx.request :: HNil).hflatMap[T :: BaseTracingSupport :: HNil] {
private[this] def tracedEntity[T <: TracingSupport](service: String)(implicit um: FromRequestUnmarshaller[T]): Directive[T :: HNil] =
hextract(ctx => ctx.request.as(um) :: extractSpan(ctx.request) :: ctx.request :: HNil).hflatMap[T :: HNil] {
case Right(value) :: Success(optSpan) :: request :: HNil =>
optSpan.foreach(s => value.init(s.$spanId, s.$traceId.get, s.$parentId))
if (optSpan.map(_.forceSampling).getOrElse(false))
trace.forcedSample(value, service)
else
trace.sample(value, service)
addHttpAnnotations(value, request)
hprovide(value :: optSpan.getOrElse(value) :: HNil)
hprovide(value :: HNil)
case Right(value) :: _ :: request :: HNil =>
trace.sample(value, service)
addHttpAnnotations(value, request)
hprovide(value :: value :: HNil)
hprovide(value :: HNil)
case Left(ContentExpected) :: _ => reject(RequestEntityExpectedRejection)
case Left(UnsupportedContentType(supported)) :: _ => reject(UnsupportedRequestContentTypeRejection(supported))
case Left(MalformedContent(errorMsg, cause)) :: _ => reject(MalformedRequestContentRejection(errorMsg, cause))
Expand All @@ -68,10 +68,10 @@ trait BaseTracingDirectives {
*/
def tracedHandleWith[A <: TracingSupport, B](service: String)(f: A => B)(implicit um: FromRequestUnmarshaller[A], m: ToResponseMarshaller[B]): Route =
tracedEntity(service)(um) {
case (a, span) =>
case ts =>
new StandardRoute {
def apply(ctx: RequestContext): Unit =
ctx.complete(f(a))(traceServerSend(span))
ctx.complete(f(ts))(traceServerSend(ts))
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,28 @@
package com.github.levkhomich.akka.tracing.http

import java.util.concurrent.TimeoutException

import scala.collection.JavaConversions._
import scala.concurrent.duration._
import scala.util.Random

import org.specs2.matcher.MatchResult
import spray.http.{ HttpRequest, StatusCodes, HttpResponse }
import spray.http._
import spray.httpx.unmarshalling.{ Deserialized, FromRequestUnmarshaller }
import spray.routing.HttpService
import spray.testkit.Specs2RouteTest

import com.github.levkhomich.akka.tracing._

import scala.concurrent.duration._

class TracingDirectivesSpec extends AkkaTracingSpecification with BaseTracingDirectives
with MockCollector with Specs2RouteTest with HttpService {

sequential

override implicit val system = testActorSystem()
def actorRefFactory = system
override val actorRefFactory = system
val serviceName = "testService"

final case class TestRequest(text: String) extends TracingSupport

object TestRequest {
implicit def um: FromRequestUnmarshaller[TestRequest] =
new FromRequestUnmarshaller[TestRequest] {
override def apply(request: HttpRequest): Deserialized[TestRequest] =
Right(TestRequest(request.entity.asString))
}
}
val rpcName = "testRpc"
val testPath = "/test-path"

override protected def trace: TracingExtensionImpl =
TracingExtension(system)
Expand All @@ -43,36 +34,162 @@ class TracingDirectivesSpec extends AkkaTracingSpecification with BaseTracingDir
}
}

"tracedHandleWith" should {
val tracedCompleteRoute =
get {
tracedComplete(serviceName, rpcName)(HttpResponse(StatusCodes.OK))
}

"sample requests and annotate them using HttpRequest data" in {
Get("/test-path") ~> tracedHandleWithRoute ~> check {
"tracedHandleWith directive" should {
"sample requests" in {
Get(testPath) ~> tracedHandleWithRoute ~> check {
response.status mustEqual StatusCodes.OK
Thread.sleep(3000)
val span = receiveSpan()
checkBinaryAnnotation(span, "request.path", testPath)
checkBinaryAnnotation(span, "request.uri", "http://example.com/test-path")
checkBinaryAnnotation(span, "request.method", "GET")
checkBinaryAnnotation(span, "request.proto", "HTTP/1.1")
}
}

val spans = results.map(e => decodeSpan(e.message))
spans.size mustEqual 1
"annotate sampled requests (general)" in {
Get(testPath) ~> tracedHandleWithRoute ~> check {
response.status mustEqual StatusCodes.OK
val span = receiveSpan()
checkBinaryAnnotation(span, "request.path", testPath)
checkBinaryAnnotation(span, "request.uri", "http://example.com/test-path")
checkBinaryAnnotation(span, "request.method", "GET")
checkBinaryAnnotation(span, "request.proto", "HTTP/1.1")
}
}

val span = spans.head
"annotate sampled requests (query params, headers)" in {
Get(Uri.from(path = testPath, query = Uri.Query("key" -> "value"))).withHeaders(
HttpHeaders.`Content-Type`(ContentTypes.`text/plain`) ::
Nil
) ~> tracedHandleWithRoute ~> check {
response.status mustEqual StatusCodes.OK
val span = receiveSpan()
checkBinaryAnnotation(span, "request.headers." + HttpHeaders.`Content-Type`.name, ContentTypes.`text/plain`.toString)
checkBinaryAnnotation(span, "request.query.key", "value")
}
}

def checkBinaryAnnotation(key: String, expValue: String): MatchResult[Any] = {
val ba = span.binary_annotations.find(_.get_key == key)
ba.isDefined mustEqual true
val actualValue = new String(ba.get.get_value, "UTF-8")
actualValue mustEqual expValue
val spanId = Random.nextLong
val parentId = Random.nextLong
"propagate tracing headers" in {
Get(testPath).withHeaders(
HttpHeaders.RawHeader(TracingHeaders.TraceId, Span.asString(spanId)) ::
HttpHeaders.RawHeader(TracingHeaders.ParentSpanId, Span.asString(parentId)) ::
Nil
) ~> tracedHandleWithRoute ~> check {
response.status mustEqual StatusCodes.OK
val span = receiveSpan()
checkBinaryAnnotation(span, "request.headers." + TracingHeaders.TraceId, Span.asString(spanId))
checkBinaryAnnotation(span, "request.headers." + TracingHeaders.ParentSpanId, Span.asString(parentId))
}
}
}

checkBinaryAnnotation("request.path", "/test-path")
checkBinaryAnnotation("request.uri", "http://example.com/test-path")
checkBinaryAnnotation("request.method", "GET")
checkBinaryAnnotation("request.proto", "HTTP/1.1")
"tracedComplete directive" should {
"not sample requests without tracing headers" in {
Get(testPath) ~> tracedCompleteRoute ~> check {
response.status mustEqual StatusCodes.OK
Thread.sleep(3000)
results.size mustEqual 0
}
}

"shutdown correctly" in {
system.shutdown()
collector.stop()
system.awaitTermination(FiniteDuration(5, SECONDS)) must not(throwA[TimeoutException])
"sample requests with tracing headers" in {
val spanId = Random.nextLong
Get(testPath).withHeaders(
HttpHeaders.RawHeader(TracingHeaders.TraceId, Span.asString(spanId)) ::
Nil
) ~> tracedCompleteRoute ~> check {
response.status mustEqual StatusCodes.OK
val span = receiveSpan()
span.get_trace_id mustEqual spanId
span.get_name mustEqual rpcName
span.get_annotations.head.get_host.get_service_name mustEqual serviceName
}
}

"annotate sampled requests (general)" in {
val spanId = Random.nextLong
Get(testPath).withHeaders(
HttpHeaders.RawHeader(TracingHeaders.TraceId, Span.asString(spanId)) ::
Nil
) ~> tracedCompleteRoute ~> check {
response.status mustEqual StatusCodes.OK
val span = receiveSpan()
checkBinaryAnnotation(span, "request.path", testPath)
checkBinaryAnnotation(span, "request.uri", "http://example.com/test-path")
checkBinaryAnnotation(span, "request.method", "GET")
checkBinaryAnnotation(span, "request.proto", "HTTP/1.1")
}
}

"annotate sampled requests (query params, headers)" in {
val spanId = Random.nextLong
Get(Uri.from(path = testPath, query = Uri.Query("key" -> "value"))).withHeaders(
HttpHeaders.RawHeader(TracingHeaders.TraceId, Span.asString(spanId)) ::
HttpHeaders.`Content-Type`(ContentTypes.`text/plain`) ::
Nil
) ~> tracedCompleteRoute ~> check {
response.status mustEqual StatusCodes.OK
val span = receiveSpan()
checkBinaryAnnotation(span, "request.headers." + HttpHeaders.`Content-Type`.name, ContentTypes.`text/plain`.toString)
checkBinaryAnnotation(span, "request.query.key", "value")
}
}

"propagate tracing headers" in {
val spanId = Random.nextLong
val parentId = Random.nextLong
Get(testPath).withHeaders(
HttpHeaders.RawHeader(TracingHeaders.TraceId, Span.asString(spanId)) ::
HttpHeaders.RawHeader(TracingHeaders.ParentSpanId, Span.asString(parentId)) ::
Nil
) ~> tracedCompleteRoute ~> check {
response.status mustEqual StatusCodes.OK
val span = receiveSpan()
checkBinaryAnnotation(span, "request.headers." + TracingHeaders.TraceId, Span.asString(spanId))
checkBinaryAnnotation(span, "request.headers." + TracingHeaders.ParentSpanId, Span.asString(parentId))
}
}
}

"shutdown correctly" in {
system.shutdown()
collector.stop()
system.awaitTermination(FiniteDuration(5, SECONDS)) must not(throwA[TimeoutException])
}

private[this] def checkBinaryAnnotation(span: thrift.Span, key: String, expValue: String): MatchResult[Any] = {
span.binary_annotations.find(_.get_key == key) match {
case Some(ba) =>
val actualValue = new String(ba.get_value, "UTF-8")
actualValue mustEqual expValue
case _ =>
ko(key + " = " + expValue + " not found")
}
}

final case class TestRequest(text: String) extends TracingSupport

private[this] def receiveSpan(): thrift.Span = {
Thread.sleep(3000)
val spans = results.map(e => decodeSpan(e.message))
spans.size mustEqual 1
results.clear()
spans.head
}

object TestRequest {
implicit def um: FromRequestUnmarshaller[TestRequest] =
new FromRequestUnmarshaller[TestRequest] {
override def apply(request: HttpRequest): Deserialized[TestRequest] =
Right(TestRequest(request.entity.asString))
}
}

}

0 comments on commit 2f72f92

Please sign in to comment.