Skip to content

Commit

Permalink
Merge pull request #1037 from square/jwilson.0605.concurrent_mock_tracer
Browse files Browse the repository at this point in the history
Fix races in ClientServerTraceTest
  • Loading branch information
swankjesse committed Jun 5, 2019
2 parents bf66d14 + 5b96ab7 commit 46abda2
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 64 deletions.
27 changes: 27 additions & 0 deletions misk-testing/src/main/kotlin/misk/testing/ConcurrentMockTracer.kt
@@ -0,0 +1,27 @@
package misk.testing

import io.opentracing.mock.MockSpan
import io.opentracing.mock.MockTracer
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.TimeUnit
import javax.inject.Inject
import javax.inject.Singleton

/**
* Extends [MockTracer] for use in concurrent environments, such as a web server and test client.
* Prefer this wherever you'd otherwise use [MockTracer].
*/
@Singleton
class ConcurrentMockTracer @Inject constructor() : MockTracer() {
private val queue = LinkedBlockingDeque<MockSpan>()

/** Awaits a span, removes it, and returns it. */
fun take(): MockSpan {
return queue.poll(500, TimeUnit.MILLISECONDS) ?: throw IllegalArgumentException("no spans!")
}

override fun onSpanFinished(mockSpan: MockSpan) {
super.onSpanFinished(mockSpan)
queue.put(mockSpan)
}
}
Expand Up @@ -6,7 +6,7 @@ import misk.inject.KAbstractModule

class MockTracingBackendModule : KAbstractModule() {
override fun configure() {
bind<MockTracer>().toInstance(MockTracer())
bind<Tracer>().to<MockTracer>()
bind<MockTracer>().to<ConcurrentMockTracer>()
bind<Tracer>().to<ConcurrentMockTracer>()
}
}
36 changes: 19 additions & 17 deletions misk/src/test/kotlin/misk/tracing/ClientServerTraceTest.kt
Expand Up @@ -15,6 +15,7 @@ import misk.client.TypedHttpClientModule
import misk.inject.KAbstractModule
import misk.inject.getInstance
import misk.inject.keyOf
import misk.testing.ConcurrentMockTracer
import misk.testing.MiskTest
import misk.testing.MiskTestModule
import misk.testing.MockTracingBackendModule
Expand Down Expand Up @@ -45,7 +46,7 @@ internal class ClientServerTraceTest {
private lateinit var jetty: JettyService

@Inject
private lateinit var tracer: Tracer
private lateinit var serverTracer: ConcurrentMockTracer

@Inject
private lateinit var serverInjector: Injector
Expand All @@ -65,14 +66,11 @@ internal class ClientServerTraceTest {
val client = clientInjector.getInstance<ReturnADinosaur>(Names.named("dinosaur"))
client.getDinosaur(dinosaurRequest).execute()

val serverTracer = tracer as MockTracer
assertThat(serverTracer.finishedSpans().size).isEqualTo(1)

val serverSpan = serverTracer.finishedSpans().first()
val serverSpan = serverTracer.take()
// Parent ID of 0 means there is no parent span
assertThat(serverSpan.parentId()).isGreaterThan(0)

val clientTracer = clientInjector.getInstance(Tracer::class.java) as MockTracer
val clientTracer = clientInjector.getInstance(MockTracer::class.java)
// Two spans here because one is created at the app level and another at the network interceptor
// level.
assertThat(clientTracer.finishedSpans().size).isEqualTo(2)
Expand All @@ -89,9 +87,8 @@ internal class ClientServerTraceTest {

client.getDinosaur(dinosaurRequest).execute()

val serverTracer = tracer as MockTracer
assertThat(serverTracer.finishedSpans().size).isEqualTo(1)
assertThat(serverTracer.finishedSpans().first().parentId()).isEqualTo(0)
val span = serverTracer.take()
assertThat(span.parentId()).isEqualTo(0)

assertThat(clientInjector.allBindings.filter { it.key == keyOf<Tracer>() }).isEmpty()
}
Expand All @@ -105,21 +102,26 @@ internal class ClientServerTraceTest {
val client = clientInjector.getInstance<RoarLikeDinosaur>(Names.named("roar"))
client.doRoar(dinosaurRequest).execute()

val serverTracer = tracer as MockTracer
assertThat(serverTracer.finishedSpans().size).isEqualTo(4)
// Expect 4 spans on the server.
val serverSpans = listOf(
serverTracer.take(),
serverTracer.take(),
serverTracer.take(),
serverTracer.take()
)

val spanIds = serverTracer.finishedSpans().map { it.context().spanId() }.toSet()
val traceId = serverTracer.finishedSpans().first().context().traceId()
val spanIds = serverSpans.map { it.context().spanId() }.toSet()
val traceId = serverSpans[0].context().traceId()

var initialServerSpan: MockSpan? = null
serverTracer.finishedSpans().forEach {
for (span in serverSpans) {
// Parent ID of 0 means there is no parent span
assertThat(it.parentId()).isGreaterThan(0)
assertThat(span.parentId()).isGreaterThan(0)

// Assert trace IDs are all the same (i.e. no new traces, new spans added as children)
assertThat(it.context().traceId()).isEqualTo(traceId)
assertThat(span.context().traceId()).isEqualTo(traceId)

if (!spanIds.contains(it.parentId())) initialServerSpan = it
if (!spanIds.contains(span.parentId())) initialServerSpan = span
}

assertThat(initialServerSpan).isNotNull()
Expand Down
45 changes: 18 additions & 27 deletions misk/src/test/kotlin/misk/tracing/TracerExtTest.kt
@@ -1,11 +1,10 @@
package misk.tracing

import io.opentracing.Tracer
import io.opentracing.mock.MockSpan
import io.opentracing.mock.MockTracer
import io.opentracing.tag.Tags
import misk.exceptions.ActionException
import misk.exceptions.StatusCode
import misk.testing.ConcurrentMockTracer
import misk.testing.MiskTest
import misk.testing.MiskTestModule
import misk.testing.MockTracingBackendModule
Expand All @@ -19,58 +18,50 @@ class TracerExtTest {
@MiskTestModule
val module = MockTracingBackendModule()

@Inject private lateinit var tracer: Tracer
@Inject private lateinit var tracer: ConcurrentMockTracer

@Test
fun traceTracedMethod() {
val mockTracer = tracer as MockTracer

assertThat(mockTracer.finishedSpans().size).isEqualTo(0)
assertThat(tracer.finishedSpans().size).isEqualTo(0)
val spanUsed = tracer.traceWithSpan("traceMe") { span -> span }
assertThat(mockTracer.finishedSpans().size).isEqualTo(1)
assertThat(spanUsed).isEqualTo(mockTracer.finishedSpans().first())
assertThat(mockTracer.finishedSpans().first().tags()).isEmpty()
val span = tracer.take()
assertThat(spanUsed).isEqualTo(span)
assertThat(span.tags()).isEmpty()
}

@Test
fun traceTracedMethodWithTags() {
val mockTracer = tracer as MockTracer

val tags = mapOf("a" to "b", "x" to "y")

assertThat(mockTracer.finishedSpans().size).isEqualTo(0)
assertThat(tracer.finishedSpans().size).isEqualTo(0)
val spanUsed = tracer.traceWithSpan("traceMe", tags) { span -> span }
assertThat(mockTracer.finishedSpans().size).isEqualTo(1)
assertThat(spanUsed).isEqualTo(mockTracer.finishedSpans().first())

assertThat(mockTracer.finishedSpans().first().tags()).isEqualTo(tags)
val span = tracer.take()
assertThat(spanUsed).isEqualTo(span)
assertThat(span.tags()).isEqualTo(tags)
}

@Test
fun tagTracingFailures() {
val mockTracer = tracer as MockTracer

assertThat(mockTracer.finishedSpans().size).isEqualTo(0)
assertThat(tracer.finishedSpans().size).isEqualTo(0)
assertFailsWith<ActionException> {
tracer.trace("failedTrace") {
throw ActionException(StatusCode.BAD_REQUEST, "sadness")
}
}
assertThat(mockTracer.finishedSpans().size).isEqualTo(1)
assertThat(mockTracer.finishedSpans().get(0).tags().get(Tags.ERROR.key)).isEqualTo(true)
val span = tracer.take()
assertThat(span.tags()[Tags.ERROR.key]).isEqualTo(true)
}

@Test
fun nestedTracing() {
val mockTracer = tracer as MockTracer

assertThat(mockTracer.finishedSpans().size).isEqualTo(0)
assertThat(tracer.finishedSpans().size).isEqualTo(0)
val (parentSpan, childSpan) = tracer.traceWithSpan("parent") { span1 ->
span1 to tracer.traceWithSpan("child") { span2 -> span2 }
}
assertThat(mockTracer.finishedSpans().size).isEqualTo(2)
assertThat(mockTracer.finishedSpans()[0]).isEqualTo(childSpan)
assertThat(mockTracer.finishedSpans()[1]).isEqualTo(parentSpan)
val span0 = tracer.take()
val span1 = tracer.take()
assertThat(span0).isEqualTo(childSpan)
assertThat(span1).isEqualTo(parentSpan)

val parentContext = parentSpan.context() as MockSpan.MockContext
val childContext = childSpan.context() as MockSpan.MockContext
Expand Down
@@ -1,13 +1,12 @@
package misk.web.interceptors

import com.google.inject.Guice
import io.opentracing.Tracer
import io.opentracing.mock.MockTracer
import io.opentracing.tag.Tags
import misk.asAction
import misk.exceptions.ActionException
import misk.exceptions.StatusCode
import misk.inject.KAbstractModule
import misk.testing.ConcurrentMockTracer
import misk.testing.MiskTest
import misk.testing.MiskTestModule
import misk.testing.MockTracingBackendModule
Expand Down Expand Up @@ -37,7 +36,7 @@ class TracingInterceptorTest {

@Inject private lateinit var tracingInterceptorFactory: TracingInterceptor.Factory
@Inject private lateinit var tracingTestAction: TracingTestAction
@Inject private lateinit var tracer: Tracer
@Inject private lateinit var tracer: ConcurrentMockTracer
@Inject private lateinit var jettyService: JettyService

@Test
Expand All @@ -53,10 +52,7 @@ class TracingInterceptorTest {

chain.proceed(chain.request)

val mockTracer = tracer as MockTracer
assertThat(mockTracer.finishedSpans().size).isEqualTo(1)
assertThat(mockTracer.finishedSpans().first().parentId()).isEqualTo(0)
val span = mockTracer.finishedSpans().first()
val span = tracer.take()
assertThat(span.parentId()).isEqualTo(0)
assertThat(span.tags()).isEqualTo(mapOf(
"http.method" to "GET",
Expand All @@ -79,19 +75,15 @@ class TracingInterceptorTest {

chain.proceed(chain.request)

val mockTracer = tracer as MockTracer
assertThat(mockTracer.finishedSpans().size).isEqualTo(1)
assertThat(mockTracer.finishedSpans().first().parentId()).isEqualTo(1)
val span = tracer.take()
assertThat(span.parentId()).isEqualTo(1)
}

@Test
fun failedTrace() {
get("/failed_trace")

val mockTracer = tracer as MockTracer
assertThat(mockTracer.finishedSpans().size).isEqualTo(1)

val span = mockTracer.finishedSpans().first()
val span = tracer.take()
assertThat(span.tags().get(Tags.ERROR.key)).isEqualTo(true)
assertThat(span.tags().get(Tags.HTTP_STATUS.key)).isEqualTo(400)
}
Expand All @@ -100,10 +92,7 @@ class TracingInterceptorTest {
fun failedTraceWithException() {
get("/exception_trace")

val mockTracer = tracer as MockTracer
assertThat(mockTracer.finishedSpans().size).isEqualTo(1)

val span = mockTracer.finishedSpans().first()
val span = tracer.take()
assertThat(span.tags().get(Tags.ERROR.key)).isEqualTo(true)
assertThat(span.tags().get(Tags.HTTP_STATUS.key)).isEqualTo(420)
}
Expand Down

0 comments on commit 46abda2

Please sign in to comment.