diff --git a/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/ContextualKey.java b/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/ContextualKey.java new file mode 100644 index 0000000..37eb3f7 --- /dev/null +++ b/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/ContextualKey.java @@ -0,0 +1,27 @@ +package org.hypertrace.core.grpcutils.context; + +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +public interface ContextualKey { + RequestContext getContext(); + + T getData(); + + /** + * Calls the function in the key's context and providing the key's data as an argument, returning + * any result + */ + R callInContext(Function function); + + R callInContext(Supplier supplier); + + /** + * Calls the function in the key's context and providing the key's data as an argument, returning + * no result + */ + void runInContext(Consumer consumer); + + void runInContext(Runnable runnable); +} diff --git a/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/DefaultContextualKey.java b/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/DefaultContextualKey.java new file mode 100644 index 0000000..5088718 --- /dev/null +++ b/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/DefaultContextualKey.java @@ -0,0 +1,84 @@ +package org.hypertrace.core.grpcutils.context; + +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +class DefaultContextualKey implements ContextualKey { + private final RequestContext context; + private final T data; + private final Map meaningfulContextHeaders; + + DefaultContextualKey(RequestContext context, T data) { + this.context = context; + this.data = data; + this.meaningfulContextHeaders = this.extractMeaningfulHeaders(context.getRequestHeaders()); + } + + @Override + public RequestContext getContext() { + return this.context; + } + + @Override + public T getData() { + return this.data; + } + + @Override + public R callInContext(Function function) { + return this.context.call(() -> function.apply(this.getData())); + } + + @Override + public R callInContext(Supplier supplier) { + return this.context.call(supplier::get); + } + + @Override + public void runInContext(Consumer consumer) { + this.context.run(() -> consumer.accept(this.getData())); + } + + @Override + public void runInContext(Runnable runnable) { + this.context.run(runnable); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DefaultContextualKey that = (DefaultContextualKey) o; + return Objects.equals(getData(), that.getData()) + && meaningfulContextHeaders.equals(that.meaningfulContextHeaders); + } + + @Override + public int hashCode() { + return Objects.hash(getData(), meaningfulContextHeaders); + } + + @Override + public String toString() { + return "DefaultContextualKey{" + + "data=" + + data + + ", meaningfulContextHeaders=" + + meaningfulContextHeaders + + '}'; + } + + private Map extractMeaningfulHeaders(Map allHeaders) { + return allHeaders.entrySet().stream() + .filter( + entry -> + RequestContextConstants.CACHE_MEANINGFUL_HEADERS.contains( + entry.getKey().toLowerCase())) + .collect(Collectors.toUnmodifiableMap(Entry::getKey, Entry::getValue)); + } +} diff --git a/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/RequestContext.java b/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/RequestContext.java index 196fc35..6116397 100644 --- a/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/RequestContext.java +++ b/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/RequestContext.java @@ -79,4 +79,12 @@ public V call(@Nonnull Callable callable) { public void run(@Nonnull Runnable runnable) { Context.current().withValue(RequestContext.CURRENT, this).run(runnable); } + + public ContextualKey buildContextualKey(T data) { + return new DefaultContextualKey<>(this, data); + } + + public ContextualKey buildContextualKey() { + return new DefaultContextualKey<>(this, null); + } } diff --git a/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/RequestContextConstants.java b/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/RequestContextConstants.java index 4926041..4eb9bd3 100644 --- a/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/RequestContextConstants.java +++ b/grpc-context-utils/src/main/java/org/hypertrace/core/grpcutils/context/RequestContextConstants.java @@ -1,13 +1,13 @@ package org.hypertrace.core.grpcutils.context; +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; + import io.grpc.Metadata; import java.util.Set; -import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; - /** - * GRPC request context constants used to propagate the tenantId, authorization token, tracing headers etc - * in the platform services. + * GRPC request context constants used to propagate the tenantId, authorization token, tracing + * headers etc in the platform services. */ public class RequestContextConstants { public static final String TENANT_ID_HEADER_KEY = "x-tenant-id"; @@ -17,10 +17,20 @@ public class RequestContextConstants { public static final String AUTHORIZATION_HEADER = "authorization"; + /** The values in this set are looked up with case insensitivity. */ + public static final Set HEADER_PREFIXES_TO_BE_PROPAGATED = + Set.of( + TENANT_ID_HEADER_KEY, + "X-B3-", + "grpc-trace-bin", + "traceparent", + "tracestate", + AUTHORIZATION_HEADER); + /** - * The values in this set are looked up with case insensitivity. + * These headers may affect returned results and should be accounted for in any cached remote + * results */ - public static final Set HEADER_PREFIXES_TO_BE_PROPAGATED = - Set.of(TENANT_ID_HEADER_KEY, "X-B3-", "grpc-trace-bin", - "traceparent", "tracestate", AUTHORIZATION_HEADER); + static final Set CACHE_MEANINGFUL_HEADERS = + Set.of(TENANT_ID_HEADER_KEY, AUTHORIZATION_HEADER); } diff --git a/grpc-context-utils/src/test/java/org/hypertrace/core/grpcutils/context/DefaultContextualKeyTest.java b/grpc-context-utils/src/test/java/org/hypertrace/core/grpcutils/context/DefaultContextualKeyTest.java new file mode 100644 index 0000000..2820ead --- /dev/null +++ b/grpc-context-utils/src/test/java/org/hypertrace/core/grpcutils/context/DefaultContextualKeyTest.java @@ -0,0 +1,85 @@ +package org.hypertrace.core.grpcutils.context; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import org.junit.jupiter.api.Test; + +class DefaultContextualKeyTest { + + @Test + void callsProvidedMethodsInContext() { + RequestContext testContext = RequestContext.forTenantId("test-tenant"); + ContextualKey key = new DefaultContextualKey<>(testContext, "input"); + + Function testFunction = + value -> + "returned: " + + value + + " for " + + RequestContext.CURRENT.get().getTenantId().orElseThrow(); + + assertEquals("returned: input for test-tenant", key.callInContext(testFunction)); + + Supplier testSupplier = + () -> "returned for " + RequestContext.CURRENT.get().getTenantId().orElseThrow(); + + assertEquals("returned for test-tenant", key.callInContext(testSupplier)); + } + + @Test + void runsProvidedMethodInContext() { + RequestContext testContext = RequestContext.forTenantId("test-tenant"); + ContextualKey key = new DefaultContextualKey<>(testContext, "input"); + + Consumer testConsumer = mock(Consumer.class); + + doAnswer( + invocation -> { + assertSame(testContext, RequestContext.CURRENT.get()); + return null; + }) + .when(testConsumer) + .accept(any()); + key.runInContext(testConsumer); + verify(testConsumer, times(1)).accept(eq("input")); + + Runnable testRunnable = mock(Runnable.class); + key.runInContext(testRunnable); + verify(testRunnable, times(1)).run(); + } + + @Test + void matchesEquivalentKeysOnly() { + RequestContext tenant1Context = RequestContext.forTenantId("first"); + RequestContext alternateTenant1Context = RequestContext.forTenantId("first"); + alternateTenant1Context.add("other", "value"); + RequestContext tenant2Context = RequestContext.forTenantId("second"); + + assertEquals( + new DefaultContextualKey<>(tenant1Context, "input"), + new DefaultContextualKey<>(tenant1Context, "input")); + + assertEquals( + new DefaultContextualKey<>(tenant1Context, "input"), + new DefaultContextualKey<>(alternateTenant1Context, "input")); + + assertNotEquals( + new DefaultContextualKey<>(tenant1Context, "input"), + new DefaultContextualKey<>(tenant2Context, "input")); + + assertNotEquals( + new DefaultContextualKey<>(tenant1Context, "input"), + new DefaultContextualKey<>(tenant1Context, "other input")); + } +}