diff --git a/checkstyle.xml b/checkstyle.xml index 498b4837d..eaf475b18 100644 --- a/checkstyle.xml +++ b/checkstyle.xml @@ -34,7 +34,7 @@ - + @@ -46,7 +46,23 @@ + + + + + + + + + + + + + + + @@ -223,7 +239,7 @@ - + - - - - - diff --git a/pom.xml b/pom.xml index 3c986b59e..9d1b0e0f9 100644 --- a/pom.xml +++ b/pom.xml @@ -114,12 +114,34 @@ ${junit.jupiter.version} test + org.junit.platform junit-platform-suite 1.8.1 test + + + org.mockito + mockito-inline + 3.12.4 + test + + + + uk.org.webcompere + system-stubs-core + 2.0.1 + test + + + + uk.org.webcompere + system-stubs-jupiter + 2.0.1 + test + diff --git a/providers/flagd/README.md b/providers/flagd/README.md index 5c344a7f8..cfa90756b 100644 --- a/providers/flagd/README.md +++ b/providers/flagd/README.md @@ -23,3 +23,33 @@ The `FlagdProvider` communicates with flagd via the gRPC protocol. Instantiate a FlagdProvider provider = new FlagdProvider(Protocol.HTTP, "localhost", 8013); OpenFeatureAPI.getInstance().setProvider(provider); ``` + +Options can be defined in the constructor or as environment variables, with constructor options having the highest precedence. + +| Option name | Environment variable name | Type | Default | +| ----------- | ------------------------- | ------- | --------- | +| host | FLAGD_HOST | string | localhost | +| port | FLAGD_PORT | number | 8013 | +| tls | FLAGD_TLS | boolean | false | +| socketPath | FLAGD_SOCKET_PATH | string | - | +| certPath | FLAGD_SERVER_CERT_PATH | string | - | + +### Unix socket support + +Unix socket communication with flag is facilitated via usage of the linux-native `epoll` library on `linux-x86_64` only (ARM support is pending relase of `netty-transport-native-epoll` v5). Unix sockets are not supported on other platforms or architectures. + +### Reconnection + +Reconnection is supported by the underlying GRPCBlockingStub. If connection to flagd is lost, it will reconnect automatically. + +### Deadline (gRPC call timeout) + +The deadline for an individual flag evaluation can be configured by calling `setDeadline(< deadline in millis >)`. +If the gRPC call is not completed within this deadline, the gRPC call is terminated with the error `DEADLINE_EXCEEDED` and the evaluation will default. +The default deadline is 500ms, though evaluations typically take on the order of 10ms. + +### TLS + +Though not required in deployments where flagd runs on the same host as the workload, TLS is available. + +:warning: Note that there's a [vulnerability](https://security.snyk.io/vuln/SNYK-JAVA-IONETTY-1042268) in [netty](https://github.com/netty/netty), a transitive dependency of the underlying gRPC libraries used in the flagd-provider that fails to correctly validate certificates. This will be addressed in netty v5. diff --git a/providers/flagd/pom.xml b/providers/flagd/pom.xml index 3c9ae1a3e..7e61bc9ca 100644 --- a/providers/flagd/pom.xml +++ b/providers/flagd/pom.xml @@ -29,20 +29,31 @@ io.grpc - grpc-netty-shaded - 1.48.1 - runtime + grpc-netty + 1.51.0 + + + + io.netty + netty-transport-native-epoll + 4.1.85.Final + + linux-x86_64 + + io.grpc grpc-protobuf 1.48.2 + io.grpc grpc-stub 1.48.1 + org.apache.tomcat diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java index a31890038..fbf798e3a 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java @@ -1,10 +1,14 @@ package dev.openfeature.contrib.providers.flagd; +import java.io.File; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; +import javax.net.ssl.SSLException; + import dev.openfeature.flagd.grpc.Schema.ResolveBooleanRequest; import dev.openfeature.flagd.grpc.Schema.ResolveBooleanResponse; import dev.openfeature.flagd.grpc.Schema.ResolveFloatRequest; @@ -23,52 +27,66 @@ import dev.openfeature.sdk.MutableStructure; import dev.openfeature.sdk.ProviderEvaluation; import dev.openfeature.sdk.Value; -import io.grpc.ManagedChannelBuilder; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.epoll.EpollDomainSocketChannel; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.handler.ssl.SslContextBuilder; import lombok.extern.slf4j.Slf4j; - /** * OpenFeature provider for flagd. */ @Slf4j public class FlagdProvider implements FeatureProvider { - private ServiceBlockingStub serviceStub; static final String PROVIDER_NAME = "flagD Provider"; - static final int DEFAULT_PORT = 8013; - static final String DEFAULT_HOST = "localhost"; - + static final String DEFAULT_PORT = "8013"; + static final String DEFAULT_TLS = "false"; + static final String DEFAULT_HOST = "localhost"; + static final int DEFAULT_DEADLINE = 500; + + static final String HOST_ENV_VAR_NAME = "FLAGD_HOST"; + static final String PORT_ENV_VAR_NAME = "FLAGD_PORT"; + static final String TLS_ENV_VAR_NAME = "FLAGD_TLS"; + static final String SOCKET_PATH_ENV_VAR_NAME = "FLAGD_SOCKET_PATH"; + static final String SERVER_CERT_PATH_ENV_VAR_NAME = "FLAGD_SERVER_CERT_PATH"; + + private long deadline = DEFAULT_DEADLINE; + private ServiceBlockingStub serviceStub; + /** * Create a new FlagdProvider instance. * - * @param protocol transport protocol, "http" or "https" - * @param host flagd host, defaults to localhost - * @param port flagd port, defaults to 8013 + * @param socketPath unix socket path */ - public FlagdProvider(Protocol protocol, String host, int port) { - - this(Protocol.HTTPS == protocol - ? ServiceGrpc.newBlockingStub(ManagedChannelBuilder.forAddress(host, port) - .useTransportSecurity() - .build()) : - ServiceGrpc.newBlockingStub(ManagedChannelBuilder.forAddress(host, port) - .usePlaintext() - .build())); + public FlagdProvider(String socketPath) { + this(buildServiceStub(null, null, null, null, socketPath)); } /** * Create a new FlagdProvider instance. + * + * @param host flagd server host, defaults to "localhost" + * @param port flagd server port, defaults to 8013 + * @param tls use TLS, defaults to false + * @param certPath path for server certificate, defaults to null to, using + * system certs + * */ - public FlagdProvider() { - this(Protocol.HTTP, DEFAULT_HOST, DEFAULT_PORT); + public FlagdProvider(String host, int port, boolean tls, String certPath) { + this(buildServiceStub(host, port, tls, certPath, null)); } /** * Create a new FlagdProvider instance. - * - * @param serviceStub service stub instance to use */ - public FlagdProvider(ServiceBlockingStub serviceStub) { + public FlagdProvider() { + this(buildServiceStub(null, null, null, null, null)); + } + + FlagdProvider(ServiceBlockingStub serviceStub) { this.serviceStub = serviceStub; } @@ -84,76 +102,92 @@ public String getName() { @Override public ProviderEvaluation getBooleanEvaluation(String key, Boolean defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { ResolveBooleanRequest request = ResolveBooleanRequest.newBuilder() - .setFlagKey(key) - .setContext(this.convertContext(ctx)) - .build(); - ResolveBooleanResponse r = this.serviceStub.resolveBoolean(request); + .setFlagKey(key) + .setContext(this.convertContext(ctx)) + .build(); + ResolveBooleanResponse r = this.serviceStub.withDeadlineAfter(this.deadline, TimeUnit.MILLISECONDS) + .resolveBoolean(request); return ProviderEvaluation.builder() - .value(r.getValue()) - .variant(r.getVariant()) - .reason(r.getReason()) - .build(); + .value(r.getValue()) + .variant(r.getVariant()) + .reason(r.getReason()) + .build(); } @Override public ProviderEvaluation getStringEvaluation(String key, String defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { ResolveStringRequest request = ResolveStringRequest.newBuilder() - .setFlagKey(key) - .setContext(this.convertContext(ctx)).build(); - ResolveStringResponse r = this.serviceStub.resolveString(request); + .setFlagKey(key) + .setContext(this.convertContext(ctx)).build(); + ResolveStringResponse r = this.serviceStub.withDeadlineAfter(this.deadline, TimeUnit.MILLISECONDS) + .resolveString(request); return ProviderEvaluation.builder().value(r.getValue()) - .variant(r.getVariant()) - .reason(r.getReason()) - .build(); + .variant(r.getVariant()) + .reason(r.getReason()) + .build(); } @Override public ProviderEvaluation getDoubleEvaluation(String key, Double defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { ResolveFloatRequest request = ResolveFloatRequest.newBuilder() - .setFlagKey(key) - .setContext(this.convertContext(ctx)) - .build(); - ResolveFloatResponse r = this.serviceStub.resolveFloat(request); + .setFlagKey(key) + .setContext(this.convertContext(ctx)) + .build(); + ResolveFloatResponse r = this.serviceStub.withDeadlineAfter(this.deadline, TimeUnit.MILLISECONDS) + .resolveFloat(request); return ProviderEvaluation.builder() - .value(r.getValue()) - .variant(r.getVariant()) - .reason(r.getReason()) - .build(); + .value(r.getValue()) + .variant(r.getVariant()) + .reason(r.getReason()) + .build(); } @Override public ProviderEvaluation getIntegerEvaluation(String key, Integer defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { ResolveIntRequest request = ResolveIntRequest.newBuilder() - .setFlagKey(key) - .setContext(this.convertContext(ctx)) - .build(); - ResolveIntResponse r = this.serviceStub.resolveInt(request); + .setFlagKey(key) + .setContext(this.convertContext(ctx)) + .build(); + ResolveIntResponse r = this.serviceStub.withDeadlineAfter(this.deadline, TimeUnit.MILLISECONDS) + .resolveInt(request); return ProviderEvaluation.builder() - .value((int) r.getValue()) - .variant(r.getVariant()) - .reason(r.getReason()) - .build(); + .value((int) r.getValue()) + .variant(r.getVariant()) + .reason(r.getReason()) + .build(); } @Override public ProviderEvaluation getObjectEvaluation(String key, Value defaultValue, - EvaluationContext ctx) { + EvaluationContext ctx) { ResolveObjectRequest request = ResolveObjectRequest.newBuilder() - .setFlagKey(key) - .setContext(this.convertContext(ctx)) - .build(); - ResolveObjectResponse r = this.serviceStub.resolveObject(request); + .setFlagKey(key) + .setContext(this.convertContext(ctx)) + .build(); + ResolveObjectResponse r = this.serviceStub.withDeadlineAfter(this.deadline, TimeUnit.MILLISECONDS) + .resolveObject(request); return ProviderEvaluation.builder() - .value(this.convertObjectResponse(r.getValue())) - .variant(r.getVariant()) - .reason(r.getReason()) - .build(); + .value(this.convertObjectResponse(r.getValue())) + .variant(r.getVariant()) + .reason(r.getReason()) + .build(); + } + + /** + * Sets how long to wait for an evaluation. + * + * @param deadlineMs time to wait before gRPC call is cancelled. Defaults to 10ms. + * @return FlagdProvider + */ + FlagdProvider setDeadline(long deadlineMs) { + this.deadline = deadlineMs; + return this; } /** @@ -207,7 +241,7 @@ private com.google.protobuf.Value convertMap(Map map) { values.put(key, this.convertAny(value)); }); com.google.protobuf.Struct struct = com.google.protobuf.Struct.newBuilder() - .putAllFields(values).build(); + .putAllFields(values).build(); return com.google.protobuf.Value.newBuilder().setStructValue(struct).build(); } @@ -225,12 +259,13 @@ private Value convertProtobufMap(Map map) { } /** - * Convert openfeature list to protobuf list. + * Convert openfeature list to protobuf list. */ private com.google.protobuf.Value convertList(List values) { com.google.protobuf.ListValue list = com.google.protobuf.ListValue.newBuilder() - .addAllValues(values.stream() - .map(v -> this.convertAny(v)).collect(Collectors.toList())).build(); + .addAllValues(values.stream() + .map(v -> this.convertAny(v)).collect(Collectors.toList())) + .build(); return com.google.protobuf.Value.newBuilder().setListValue(list).build(); } @@ -275,4 +310,51 @@ private Value convertPrimitive(com.google.protobuf.Value protobuf) { } return value; } + + private static ServiceBlockingStub buildServiceStub(String host, Integer port, Boolean tls, String certPath, + String socketPath) { + + host = host != null ? host : fallBackToEnvOrDefault(HOST_ENV_VAR_NAME, DEFAULT_HOST); + port = port != null ? port : Integer.parseInt(fallBackToEnvOrDefault(PORT_ENV_VAR_NAME, DEFAULT_PORT)); + tls = tls != null ? tls : Boolean.parseBoolean(fallBackToEnvOrDefault(TLS_ENV_VAR_NAME, DEFAULT_TLS)); + certPath = certPath != null ? certPath : fallBackToEnvOrDefault(SERVER_CERT_PATH_ENV_VAR_NAME, null); + socketPath = socketPath != null ? socketPath : fallBackToEnvOrDefault(SOCKET_PATH_ENV_VAR_NAME, null); + + // we have a socket path specified, build a channel with a unix socket + if (socketPath != null) { + return ServiceGrpc.newBlockingStub(NettyChannelBuilder + .forAddress(new DomainSocketAddress(socketPath)) + .eventLoopGroup(new EpollEventLoopGroup()) + .channelType(EpollDomainSocketChannel.class) + .usePlaintext() + .build()); + } + + // build a TCP socket + try { + NettyChannelBuilder builder = NettyChannelBuilder + .forAddress(host, port); + + if (tls) { + SslContextBuilder sslContext = GrpcSslContexts.forClient(); + if (certPath != null) { + sslContext.trustManager(new File(certPath)); + } + builder.sslContext(sslContext.build()); + } else { + builder.usePlaintext(); + } + + return ServiceGrpc + .newBlockingStub(builder.build()); + } catch (SSLException ssle) { + SslConfigException sslConfigException = new SslConfigException("Error with SSL configuration."); + sslConfigException.initCause(ssle); + throw sslConfigException; + } + } + + private static String fallBackToEnvOrDefault(String key, String defaultValue) { + return System.getenv(key) != null ? System.getenv(key) : defaultValue; + } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/SslConfigException.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/SslConfigException.java new file mode 100644 index 000000000..8d7fe2b2e --- /dev/null +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/SslConfigException.java @@ -0,0 +1,7 @@ +package dev.openfeature.contrib.providers.flagd; + +class SslConfigException extends RuntimeException { + public SslConfigException(String message) { + super(message); + } +} diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java index 1aa7432c3..3d2487146 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java @@ -3,15 +3,24 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; import dev.openfeature.flagd.grpc.Schema.ResolveBooleanRequest; import dev.openfeature.flagd.grpc.Schema.ResolveBooleanResponse; @@ -19,8 +28,8 @@ import dev.openfeature.flagd.grpc.Schema.ResolveIntResponse; import dev.openfeature.flagd.grpc.Schema.ResolveObjectResponse; import dev.openfeature.flagd.grpc.Schema.ResolveStringResponse; +import dev.openfeature.flagd.grpc.ServiceGrpc; import dev.openfeature.flagd.grpc.ServiceGrpc.ServiceBlockingStub; -import dev.openfeature.sdk.EvaluationContext; import dev.openfeature.sdk.FlagEvaluationDetails; import dev.openfeature.sdk.MutableContext; import dev.openfeature.sdk.MutableStructure; @@ -28,6 +37,12 @@ import dev.openfeature.sdk.Reason; import dev.openfeature.sdk.Structure; import dev.openfeature.sdk.Value; +import io.grpc.ManagedChannel; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.unix.DomainSocketAddress; +import uk.org.webcompere.systemstubs.environment.EnvironmentVariables; class FlagdProviderTest { @@ -42,12 +57,15 @@ class FlagdProviderTest { static final Double DOUBLE_VALUE = .5d; static final String INNER_STRUCT_KEY = "inner_key"; static final String INNER_STRUCT_VALUE = "inner_value"; - static final Structure OBJECT_VALUE = new MutableStructure() {{ + static final Structure OBJECT_VALUE = new MutableStructure() { + { add(INNER_STRUCT_KEY, INNER_STRUCT_VALUE); - }}; + } + }; static final com.google.protobuf.Struct PROTOBUF_STRUCTURE_VALUE = com.google.protobuf.Struct.newBuilder() - .putFields(INNER_STRUCT_KEY, com.google.protobuf.Value.newBuilder().setStringValue(INNER_STRUCT_VALUE).build()) - .build(); + .putFields(INNER_STRUCT_KEY, + com.google.protobuf.Value.newBuilder().setStringValue(INNER_STRUCT_VALUE).build()) + .build(); static final String STRING_VALUE = "hi!"; static OpenFeatureAPI api; @@ -57,52 +75,174 @@ public static void init() { api = OpenFeatureAPI.getInstance(); } + @Test + void path_arg_should_build_domain_socket_with_correct_path() { + final String path = "/some/path"; + + ServiceBlockingStub mockStub = mock(ServiceBlockingStub.class); + NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); + + try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { + mockStaticService.when(() -> ServiceGrpc.newBlockingStub(any(ManagedChannel.class))) + .thenReturn(mockStub); + + try (MockedStatic mockStaticChannelBuilder = mockStatic(NettyChannelBuilder.class)) { + + try (MockedConstruction mockEpollEventLoopGroup = mockConstruction( + EpollEventLoopGroup.class, + (mock, context) -> { + })) { + mockStaticChannelBuilder.when(() -> NettyChannelBuilder + .forAddress(any(DomainSocketAddress.class))).thenReturn(mockChannelBuilder); + new FlagdProvider(path); + + // verify path matches + mockStaticChannelBuilder.verify(() -> NettyChannelBuilder + .forAddress(argThat((DomainSocketAddress d) -> { + assertEquals(d.path(), path); // path should match + return true; + }))); + } + } + } + } + + @Test + void no_args_socket_env_should_build_domain_socket_with_correct_path() throws Exception { + final String path = "/some/other/path"; + + new EnvironmentVariables("FLAGD_SOCKET_PATH", path).execute(() -> { + + ServiceBlockingStub mockStub = mock(ServiceBlockingStub.class); + NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); + + try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { + mockStaticService.when(() -> ServiceGrpc.newBlockingStub(any(ManagedChannel.class))) + .thenReturn(mockStub); + + try (MockedStatic mockStaticChannelBuilder = mockStatic( + NettyChannelBuilder.class)) { + + try (MockedConstruction mockEpollEventLoopGroup = mockConstruction( + EpollEventLoopGroup.class, + (mock, context) -> { + })) { + mockStaticChannelBuilder.when(() -> NettyChannelBuilder + .forAddress(any(DomainSocketAddress.class))).thenReturn(mockChannelBuilder); + new FlagdProvider(); + + //verify path matches + mockStaticChannelBuilder.verify(() -> NettyChannelBuilder + .forAddress(argThat((DomainSocketAddress d) -> { + return d.path() == path; + }))); + } + } + } + }); + } + + @Test + void host_and_port_arg_should_build_tcp_socket() { + final String host = "host.com"; + final int port = 1234; + + ServiceBlockingStub mockStub = mock(ServiceBlockingStub.class); + NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); + + try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { + mockStaticService.when(() -> ServiceGrpc.newBlockingStub(any(ManagedChannel.class))) + .thenReturn(mockStub); + + try (MockedStatic mockStaticChannelBuilder = mockStatic(NettyChannelBuilder.class)) { + + mockStaticChannelBuilder.when(() -> NettyChannelBuilder + .forAddress(anyString(), anyInt())).thenReturn(mockChannelBuilder); + new FlagdProvider(host, port, false, null); + + // verify host/port matches + mockStaticChannelBuilder.verify(() -> NettyChannelBuilder + .forAddress(host, port)); + } + } + } + + @Test + void no_args_host_and_port_env_set_should_build_tcp_socket() throws Exception { + final String host = "server.com"; + final int port = 4321; + + new EnvironmentVariables("FLAGD_HOST", host, "FLAGD_PORT", String.valueOf(port)).execute(() -> { + ServiceBlockingStub mockStub = mock(ServiceBlockingStub.class); + NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); + + try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { + mockStaticService.when(() -> ServiceGrpc.newBlockingStub(any(ManagedChannel.class))) + .thenReturn(mockStub); + + try (MockedStatic mockStaticChannelBuilder = mockStatic( + NettyChannelBuilder.class)) { + + mockStaticChannelBuilder.when(() -> NettyChannelBuilder + .forAddress(anyString(), anyInt())).thenReturn(mockChannelBuilder); + new FlagdProvider(); + + // verify host/port matches + mockStaticChannelBuilder.verify(() -> NettyChannelBuilder + .forAddress(host, port)); + } + } + }); + } + @Test void resolvers_call_grpc_service_and_return_details() { ResolveBooleanResponse booleanResponse = ResolveBooleanResponse.newBuilder() - .setValue(true) - .setVariant(BOOL_VARIANT) - .setReason(DEFAULT.toString()) - .build(); + .setValue(true) + .setVariant(BOOL_VARIANT) + .setReason(DEFAULT.toString()) + .build(); ResolveStringResponse stringResponse = ResolveStringResponse.newBuilder() - .setValue(STRING_VALUE) - .setVariant(STRING_VARIANT) - .setReason(DEFAULT.toString()) - .build(); + .setValue(STRING_VALUE) + .setVariant(STRING_VARIANT) + .setReason(DEFAULT.toString()) + .build(); ResolveIntResponse intResponse = ResolveIntResponse.newBuilder() - .setValue(INT_VALUE) - .setVariant(INT_VARIANT) - .setReason(DEFAULT.toString()) - .build(); + .setValue(INT_VALUE) + .setVariant(INT_VARIANT) + .setReason(DEFAULT.toString()) + .build(); ResolveFloatResponse floatResponse = ResolveFloatResponse.newBuilder() - .setValue(DOUBLE_VALUE) - .setVariant(DOUBLE_VARIANT) - .setReason(DEFAULT.toString()) - .build(); + .setValue(DOUBLE_VALUE) + .setVariant(DOUBLE_VARIANT) + .setReason(DEFAULT.toString()) + .build(); ResolveObjectResponse objectResponse = ResolveObjectResponse.newBuilder() - .setValue(PROTOBUF_STRUCTURE_VALUE) - .setVariant(OBJECT_VARIANT) - .setReason(DEFAULT.toString()) - .build(); + .setValue(PROTOBUF_STRUCTURE_VALUE) + .setVariant(OBJECT_VARIANT) + .setReason(DEFAULT.toString()) + .build(); ServiceBlockingStub serviceBlockingStubMock = mock(ServiceBlockingStub.class); + when(serviceBlockingStubMock.withDeadlineAfter(anyLong(), any(TimeUnit.class))) + .thenReturn(serviceBlockingStubMock); when(serviceBlockingStubMock - .resolveBoolean(argThat(x -> FLAG_KEY.equals(x.getFlagKey())))).thenReturn(booleanResponse); + .resolveBoolean(argThat(x -> FLAG_KEY.equals(x.getFlagKey())))).thenReturn(booleanResponse); when(serviceBlockingStubMock - .resolveFloat(argThat(x -> FLAG_KEY.equals(x.getFlagKey())))).thenReturn(floatResponse); + .resolveFloat(argThat(x -> FLAG_KEY.equals(x.getFlagKey())))).thenReturn(floatResponse); when(serviceBlockingStubMock - .resolveInt(argThat(x -> FLAG_KEY.equals(x.getFlagKey())))).thenReturn(intResponse); + .resolveInt(argThat(x -> FLAG_KEY.equals(x.getFlagKey())))).thenReturn(intResponse); when(serviceBlockingStubMock - .resolveString(argThat(x -> FLAG_KEY.equals(x.getFlagKey())))).thenReturn(stringResponse); + .resolveString(argThat(x -> FLAG_KEY.equals(x.getFlagKey())))).thenReturn(stringResponse); when(serviceBlockingStubMock - .resolveObject(argThat(x -> FLAG_KEY.equals(x.getFlagKey())))).thenReturn(objectResponse); + .resolveObject(argThat(x -> FLAG_KEY.equals(x.getFlagKey())))).thenReturn(objectResponse); OpenFeatureAPI.getInstance().setProvider(new FlagdProvider(serviceBlockingStubMock)); - + FlagEvaluationDetails booleanDetails = api.getClient().getBooleanDetails(FLAG_KEY, false); assertTrue(booleanDetails.getValue()); assertEquals(BOOL_VARIANT, booleanDetails.getVariant()); @@ -117,7 +257,7 @@ void resolvers_call_grpc_service_and_return_details() { assertEquals(INT_VALUE, intDetails.getValue()); assertEquals(INT_VARIANT, intDetails.getVariant()); assertEquals(DEFAULT.toString(), intDetails.getReason()); - + FlagEvaluationDetails floatDetails = api.getClient().getDoubleDetails(FLAG_KEY, 0.1); assertEquals(DOUBLE_VALUE, floatDetails.getValue()); assertEquals(DOUBLE_VARIANT, floatDetails.getVariant()); @@ -125,7 +265,7 @@ void resolvers_call_grpc_service_and_return_details() { FlagEvaluationDetails objectDetails = api.getClient().getObjectDetails(FLAG_KEY, new Value()); assertEquals(INNER_STRUCT_VALUE, objectDetails.getValue().asStructure() - .asMap().get(INNER_STRUCT_KEY).asString()); + .asMap().get(INNER_STRUCT_KEY).asString()); assertEquals(OBJECT_VARIANT, objectDetails.getVariant()); assertEquals(DEFAULT.toString(), objectDetails.getReason()); } @@ -145,33 +285,38 @@ void context_is_parsed_and_passed_to_grpc_service() { final Integer INT_ATTR_VALUE = 1; final String STRING_ATTR_VALUE = "str"; final Double DOUBLE_ATTR_VALUE = 0.5d; - final List LIST_ATTR_VALUE = new ArrayList() {{ + final List LIST_ATTR_VALUE = new ArrayList() { + { add(new Value(1)); - }}; + } + }; final String STRUCT_ATTR_INNER_VALUE = "struct-inner-value"; final Structure STRUCT_ATTR_VALUE = new MutableStructure().add(STRUCT_ATTR_INNER_KEY, STRUCT_ATTR_INNER_VALUE); final String DEFAULT_STRING = "DEFAULT"; ResolveBooleanResponse booleanResponse = ResolveBooleanResponse.newBuilder() - .setValue(true) - .setVariant(BOOL_VARIANT) - .setReason(DEFAULT_STRING.toString()) - .build(); + .setValue(true) + .setVariant(BOOL_VARIANT) + .setReason(DEFAULT_STRING.toString()) + .build(); ServiceBlockingStub serviceBlockingStubMock = mock(ServiceBlockingStub.class); - when(serviceBlockingStubMock.resolveBoolean(argThat(x -> - STRING_ATTR_VALUE.equals(x.getContext().getFieldsMap().get(STRING_ATTR_KEY).getStringValue()) - && INT_ATTR_VALUE == x.getContext().getFieldsMap().get(INT_ATTR_KEY).getNumberValue() - && DOUBLE_ATTR_VALUE == x.getContext().getFieldsMap().get(DOUBLE_ATTR_KEY).getNumberValue() - && LIST_ATTR_VALUE.get(0).asInteger() == x.getContext().getFieldsMap() - .get(LIST_ATTR_KEY).getListValue().getValuesList().get(0).getNumberValue() - && x.getContext().getFieldsMap().get(BOOLEAN_ATTR_KEY).getBoolValue() - && STRUCT_ATTR_INNER_VALUE.equals(x.getContext().getFieldsMap() - .get(STRUCT_ATTR_KEY).getStructValue().getFieldsMap().get(STRUCT_ATTR_INNER_KEY).getStringValue()) - ))).thenReturn(booleanResponse); + when(serviceBlockingStubMock.withDeadlineAfter(anyLong(), any(TimeUnit.class))) + .thenReturn(serviceBlockingStubMock); + when(serviceBlockingStubMock.resolveBoolean(argThat( + x -> STRING_ATTR_VALUE.equals(x.getContext().getFieldsMap().get(STRING_ATTR_KEY).getStringValue()) + && INT_ATTR_VALUE == x.getContext().getFieldsMap().get(INT_ATTR_KEY).getNumberValue() + && DOUBLE_ATTR_VALUE == x.getContext().getFieldsMap().get(DOUBLE_ATTR_KEY).getNumberValue() + && LIST_ATTR_VALUE.get(0).asInteger() == x.getContext().getFieldsMap() + .get(LIST_ATTR_KEY).getListValue().getValuesList().get(0).getNumberValue() + && x.getContext().getFieldsMap().get(BOOLEAN_ATTR_KEY).getBoolValue() + && STRUCT_ATTR_INNER_VALUE.equals(x.getContext().getFieldsMap() + .get(STRUCT_ATTR_KEY).getStructValue().getFieldsMap().get(STRUCT_ATTR_INNER_KEY) + .getStringValue())))) + .thenReturn(booleanResponse); OpenFeatureAPI.getInstance().setProvider(new FlagdProvider(serviceBlockingStubMock)); - + MutableContext context = new MutableContext(); context.add(BOOLEAN_ATTR_KEY, BOOLEAN_ATTR_VALUE); context.add(INT_ATTR_KEY, INT_ATTR_VALUE); @@ -186,22 +331,54 @@ void context_is_parsed_and_passed_to_grpc_service() { assertEquals(DEFAULT.toString(), booleanDetails.getReason()); } - //TODO: update this to be able unknown codes + @Test + void set_deadline_deadline_send_in_grpc() { + long deadline = 1300; + + ResolveBooleanResponse badReasonResponse = ResolveBooleanResponse.newBuilder() + .setValue(true) + .build(); + + ServiceBlockingStub serviceBlockingStubMock = mock(ServiceBlockingStub.class); + when(serviceBlockingStubMock.withDeadlineAfter(anyLong(), any(TimeUnit.class))) + .thenReturn(serviceBlockingStubMock); + when(serviceBlockingStubMock.resolveBoolean(any(ResolveBooleanRequest.class))).thenReturn(badReasonResponse); + + FlagdProvider provider = new FlagdProvider(serviceBlockingStubMock); + provider.setDeadline(deadline); + OpenFeatureAPI.getInstance().setProvider(provider); + + + api.getClient().getBooleanDetails(FLAG_KEY, false, new MutableContext()); + verify(serviceBlockingStubMock).withDeadlineAfter(deadline, TimeUnit.MILLISECONDS); + } + @Test void reason_mapped_correctly_if_unknown() { ResolveBooleanResponse badReasonResponse = ResolveBooleanResponse.newBuilder() - .setValue(true) - .setVariant(BOOL_VARIANT) - .setReason("UNKNOWN") // set an invalid reason string - .build(); + .setValue(true) + .setVariant(BOOL_VARIANT) + .setReason("UNKNOWN") // set an invalid reason string + .build(); ServiceBlockingStub serviceBlockingStubMock = mock(ServiceBlockingStub.class); + when(serviceBlockingStubMock.withDeadlineAfter(anyLong(), any(TimeUnit.class))) + .thenReturn(serviceBlockingStubMock); when(serviceBlockingStubMock.resolveBoolean(any(ResolveBooleanRequest.class))).thenReturn(badReasonResponse); OpenFeatureAPI.getInstance().setProvider(new FlagdProvider(serviceBlockingStubMock)); FlagEvaluationDetails booleanDetails = api.getClient() - .getBooleanDetails(FLAG_KEY, false, new MutableContext()); + .getBooleanDetails(FLAG_KEY, false, new MutableContext()); assertEquals(Reason.UNKNOWN.toString(), booleanDetails.getReason()); // reason should be converted to UNKNOWN } + + private NettyChannelBuilder getMockChannelBuilderSocket() { + NettyChannelBuilder mockChannelBuilder = mock(NettyChannelBuilder.class); + when(mockChannelBuilder.eventLoopGroup(any(EventLoopGroup.class))).thenReturn(mockChannelBuilder); + when(mockChannelBuilder.channelType(any(Class.class))).thenReturn(mockChannelBuilder); + when(mockChannelBuilder.usePlaintext()).thenReturn(mockChannelBuilder); + when(mockChannelBuilder.build()).thenReturn(null); + return mockChannelBuilder; + } } \ No newline at end of file