diff --git a/storage/connectors/redis/pom.xml b/storage/connectors/redis/pom.xml index d0e127cde8..ca6e8d42ad 100644 --- a/storage/connectors/redis/pom.xml +++ b/storage/connectors/redis/pom.xml @@ -89,7 +89,12 @@ 4.12 test - + + org.apache.beam + beam-sdks-java-extensions-protobuf + ${org.apache.beam.version} + test + org.slf4j slf4j-simple diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/BatchDoFnWithRedis.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/BatchDoFnWithRedis.java new file mode 100644 index 0000000000..d6c83c3a54 --- /dev/null +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/BatchDoFnWithRedis.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.storage.connectors.redis.writer; + +import feast.storage.common.retry.Retriable; +import io.lettuce.core.RedisException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.function.Function; +import org.apache.beam.sdk.transforms.DoFn; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Base class for redis-related DoFns. Assumes that operations will be batched. Prepares redisClient + * on DoFn.Setup stage and close it on DoFn.Teardown stage. + * + * @param + * @param + */ +public class BatchDoFnWithRedis extends DoFn { + private static final Logger log = LoggerFactory.getLogger(BatchDoFnWithRedis.class); + + private final RedisIngestionClient redisIngestionClient; + + BatchDoFnWithRedis(RedisIngestionClient redisIngestionClient) { + this.redisIngestionClient = redisIngestionClient; + } + + @Setup + public void setup() { + this.redisIngestionClient.setup(); + } + + @StartBundle + public void startBundle() { + try { + redisIngestionClient.connect(); + } catch (RedisException e) { + log.error("Connection to redis cannot be established: %s", e); + } + } + + void executeBatch(Function>> executor) + throws Exception { + this.redisIngestionClient + .getBackOffExecutor() + .execute( + new Retriable() { + @Override + public void execute() throws ExecutionException, InterruptedException { + if (!redisIngestionClient.isConnected()) { + redisIngestionClient.connect(); + } + + Iterable> futures = executor.apply(redisIngestionClient); + redisIngestionClient.sync(futures); + } + + @Override + public Boolean isExceptionRetriable(Exception e) { + return e instanceof RedisException; + } + + @Override + public void cleanUpAfterFailure() {} + }); + } + + @Teardown + public void teardown() { + redisIngestionClient.shutdown(); + } +} diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java index 389db4be3a..f36d70563e 100644 --- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java @@ -20,7 +20,6 @@ import feast.proto.core.StoreProto; import feast.storage.common.retry.BackOffExecutor; import io.lettuce.core.LettuceFutures; -import io.lettuce.core.RedisFuture; import io.lettuce.core.RedisURI; import io.lettuce.core.cluster.RedisClusterClient; import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; @@ -28,6 +27,8 @@ import io.lettuce.core.codec.ByteArrayCodec; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.joda.time.Duration; @@ -39,7 +40,6 @@ public class RedisClusterIngestionClient implements RedisIngestionClient { private transient RedisClusterClient clusterClient; private StatefulRedisClusterConnection connection; private RedisAdvancedClusterAsyncCommands commands; - private List futures = Lists.newArrayList(); public RedisClusterIngestionClient(StoreProto.Store.RedisClusterConfig redisClusterConfig) { this.uriList = @@ -55,7 +55,6 @@ public RedisClusterIngestionClient(StoreProto.Store.RedisClusterConfig redisClus redisClusterConfig.getInitialBackoffMs() > 0 ? redisClusterConfig.getInitialBackoffMs() : 1; this.backOffExecutor = new BackOffExecutor(redisClusterConfig.getMaxRetries(), Duration.millis(backoffMs)); - this.clusterClient = RedisClusterClient.create(uriList); } @Override @@ -78,6 +77,10 @@ public void connect() { if (!isConnected()) { this.connection = clusterClient.connect(new ByteArrayCodec()); this.commands = connection.async(); + + // despite we're using async API client still flushes after each command by default + // which we don't want since we produce all commands in batches + this.commands.setAutoFlushCommands(false); } } @@ -87,46 +90,20 @@ public boolean isConnected() { } @Override - public void sync() { - try { - LettuceFutures.awaitAll(60, TimeUnit.SECONDS, futures.toArray(new RedisFuture[0])); - } finally { - futures.clear(); - } - } - - @Override - public void pexpire(byte[] key, Long expiryMillis) { - futures.add(commands.pexpire(key, expiryMillis)); - } - - @Override - public void append(byte[] key, byte[] value) { - futures.add(commands.append(key, value)); - } - - @Override - public void set(byte[] key, byte[] value) { - futures.add(commands.set(key, value)); - } + public void sync(Iterable> futures) { + this.connection.flushCommands(); - @Override - public void lpush(byte[] key, byte[] value) { - futures.add(commands.lpush(key, value)); - } - - @Override - public void rpush(byte[] key, byte[] value) { - futures.add(commands.rpush(key, value)); + LettuceFutures.awaitAll( + 60, TimeUnit.SECONDS, Lists.newArrayList(futures).toArray(new Future[0])); } @Override - public void sadd(byte[] key, byte[] value) { - futures.add(commands.sadd(key, value)); + public CompletableFuture set(byte[] key, byte[] value) { + return commands.set(key, value).toCompletableFuture(); } @Override - public void zadd(byte[] key, Long score, byte[] value) { - futures.add(commands.zadd(key, score, value)); + public CompletableFuture get(byte[] key) { + return commands.get(key).toCompletableFuture(); } } diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java index f73c458d78..c42cff7bd0 100644 --- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java @@ -17,8 +17,9 @@ package feast.storage.connectors.redis.writer; import com.google.common.collect.Iterators; -import com.google.common.collect.Lists; +import com.google.common.collect.Streams; import com.google.common.hash.Hashing; +import com.google.protobuf.InvalidProtocolBufferException; import feast.proto.core.FeatureSetProto.EntitySpec; import feast.proto.core.FeatureSetProto.FeatureSetSpec; import feast.proto.core.FeatureSetProto.FeatureSpec; @@ -29,20 +30,20 @@ import feast.proto.types.ValueProto; import feast.storage.api.writer.FailedElement; import feast.storage.api.writer.WriteResult; -import feast.storage.common.retry.Retriable; import feast.storage.connectors.redis.retriever.FeatureRowDecoder; -import io.lettuce.core.RedisException; import java.nio.charset.StandardCharsets; +import java.util.*; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutionException; +import java.util.function.BinaryOperator; import java.util.stream.Collectors; import org.apache.beam.sdk.transforms.*; import org.apache.beam.sdk.transforms.windowing.*; import org.apache.beam.sdk.values.*; import org.apache.commons.lang3.exception.ExceptionUtils; import org.apache.commons.lang3.tuple.ImmutablePair; +import org.joda.time.DateTime; import org.joda.time.Duration; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -116,63 +117,23 @@ public void process(ProcessContext c) { redisWrite.get(failedInsertsTupleTag)); } - public static class WriteDoFn extends DoFn, FeatureRow> { - private PCollectionView>> featureSetSpecsView; - private RedisIngestionClient redisIngestionClient; + /** + * Writes batch of {@link FeatureRow} to Redis. Only latest values should be written. In order + * to guarantee that we first fetch all existing values (first batch operation), compare with + * current batch by eventTimestamp, and send to redis values (second batch operation) that were + * confirmed to be most recent. + */ + public static class WriteDoFn extends BatchDoFnWithRedis, FeatureRow> { + private final PCollectionView>> featureSetSpecsView; WriteDoFn( RedisIngestionClient redisIngestionClient, PCollectionView>> featureSetSpecsView) { - this.redisIngestionClient = redisIngestionClient; + super(redisIngestionClient); this.featureSetSpecsView = featureSetSpecsView; } - @Setup - public void setup() { - this.redisIngestionClient.setup(); - } - - @StartBundle - public void startBundle() { - try { - redisIngestionClient.connect(); - } catch (RedisException e) { - log.error("Connection to redis cannot be established ", e); - } - } - - private void executeBatch( - Iterable featureRows, Map featureSetSpecs) - throws Exception { - this.redisIngestionClient - .getBackOffExecutor() - .execute( - new Retriable() { - @Override - public void execute() throws ExecutionException, InterruptedException { - if (!redisIngestionClient.isConnected()) { - redisIngestionClient.connect(); - } - featureRows.forEach( - row -> { - redisIngestionClient.set( - getKey(row, featureSetSpecs.get(row.getFeatureSet())), - getValue(row, featureSetSpecs.get(row.getFeatureSet()))); - }); - redisIngestionClient.sync(); - } - - @Override - public Boolean isExceptionRetriable(Exception e) { - return e instanceof RedisException; - } - - @Override - public void cleanUpAfterFailure() {} - }); - } - private FailedElement toFailedElement( FeatureRow featureRow, Exception exception, String jobName) { return FailedElement.newBuilder() @@ -184,7 +145,7 @@ private FailedElement toFailedElement( .build(); } - private byte[] getKey(FeatureRow featureRow, FeatureSetSpec spec) { + private RedisKey getKey(FeatureRow featureRow, FeatureSetSpec spec) { List entityNames = spec.getEntitiesList().stream() .map(EntitySpec::getName) @@ -203,7 +164,7 @@ private byte[] getKey(FeatureRow featureRow, FeatureSetSpec spec) { for (String entityName : entityNames) { redisKeyBuilder.addEntities(entityFields.get(entityName)); } - return redisKeyBuilder.build().toByteArray(); + return redisKeyBuilder.build(); } /** @@ -212,7 +173,7 @@ private byte[] getKey(FeatureRow featureRow, FeatureSetSpec spec) { * names and not unsetting the feature set reference. {@link FeatureRowDecoder} is * rensponsible for reversing this "encoding" step. */ - private byte[] getValue(FeatureRow featureRow, FeatureSetSpec spec) { + private FeatureRow getValue(FeatureRow featureRow, FeatureSetSpec spec) { List featureNames = spec.getFeaturesList().stream().map(FeatureSpec::getName).collect(Collectors.toList()); @@ -250,35 +211,101 @@ private byte[] getValue(FeatureRow featureRow, FeatureSetSpec spec) { return FeatureRow.newBuilder() .setEventTimestamp(featureRow.getEventTimestamp()) .addAllFields(values) - .build() - .toByteArray(); + .build(); } @ProcessElement public void processElement(ProcessContext context) { - List featureRows = Lists.newArrayList(context.element().iterator()); - + List filteredFeatureRows = Collections.synchronizedList(new ArrayList<>()); Map latestSpecs = - context.sideInput(featureSetSpecsView).entrySet().stream() - .map(e -> ImmutablePair.of(e.getKey(), Iterators.getLast(e.getValue().iterator()))) - .collect(Collectors.toMap(ImmutablePair::getLeft, ImmutablePair::getRight)); + getLatestSpecs(context.sideInput(featureSetSpecsView)); + + Map deduplicatedRows = + deduplicateRows(context.element(), latestSpecs); try { - executeBatch(featureRows, latestSpecs); - featureRows.forEach(row -> context.output(successfulInsertsTag, row)); + executeBatch( + (redisIngestionClient) -> + deduplicatedRows.entrySet().stream() + .map( + entry -> + redisIngestionClient + .get(entry.getKey().toByteArray()) + .thenAccept( + currentValue -> { + FeatureRow newRow = entry.getValue(); + if (rowShouldBeWritten(newRow, currentValue)) { + filteredFeatureRows.add(newRow); + } + })) + .collect(Collectors.toList())); + + executeBatch( + redisIngestionClient -> + filteredFeatureRows.stream() + .map( + row -> + redisIngestionClient.set( + getKey(row, latestSpecs.get(row.getFeatureSet())).toByteArray(), + getValue(row, latestSpecs.get(row.getFeatureSet())) + .toByteArray())) + .collect(Collectors.toList())); + + filteredFeatureRows.forEach(row -> context.output(successfulInsertsTag, row)); } catch (Exception e) { - featureRows.forEach( - failedMutation -> { - FailedElement failedElement = - toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName()); - context.output(failedInsertsTupleTag, failedElement); - }); + deduplicatedRows + .values() + .forEach( + failedMutation -> { + FailedElement failedElement = + toFailedElement( + failedMutation, e, context.getPipelineOptions().getJobName()); + context.output(failedInsertsTupleTag, failedElement); + }); } } - @Teardown - public void teardown() { - redisIngestionClient.shutdown(); + boolean rowShouldBeWritten(FeatureRow newRow, byte[] currentValue) { + if (currentValue == null) { + // nothing to compare with + return true; + } + FeatureRow currentRow; + try { + currentRow = FeatureRow.parseFrom(currentValue); + } catch (InvalidProtocolBufferException e) { + // definitely need to replace current value + return true; + } + + // check whether new row has later eventTimestamp + return new DateTime(currentRow.getEventTimestamp().getSeconds() * 1000L) + .isBefore(new DateTime(newRow.getEventTimestamp().getSeconds() * 1000L)); + } + + /** Deduplicate rows by key within batch. Keep only latest eventTimestamp */ + Map deduplicateRows( + Iterable rows, Map latestSpecs) { + Comparator byEventTimestamp = + Comparator.comparing(r -> r.getEventTimestamp().getSeconds()); + + FeatureRow identity = + FeatureRow.newBuilder() + .setEventTimestamp( + com.google.protobuf.Timestamp.newBuilder().setSeconds(-1).build()) + .build(); + + return Streams.stream(rows) + .collect( + Collectors.groupingBy( + row -> getKey(row, latestSpecs.get(row.getFeatureSet())), + Collectors.reducing(identity, BinaryOperator.maxBy(byEventTimestamp)))); + } + + Map getLatestSpecs(Map> specs) { + return specs.entrySet().stream() + .map(e -> ImmutablePair.of(e.getKey(), Iterators.getLast(e.getValue().iterator()))) + .collect(Collectors.toMap(ImmutablePair::getLeft, ImmutablePair::getRight)); } } } diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java index 6616a79aac..e9b1a5dc44 100644 --- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java @@ -18,6 +18,8 @@ import feast.storage.common.retry.BackOffExecutor; import java.io.Serializable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; public interface RedisIngestionClient extends Serializable { @@ -31,19 +33,9 @@ public interface RedisIngestionClient extends Serializable { boolean isConnected(); - void sync(); + void sync(Iterable> futures); - void pexpire(byte[] key, Long expiryMillis); + CompletableFuture set(byte[] key, byte[] value); - void append(byte[] key, byte[] value); - - void set(byte[] key, byte[] value); - - void lpush(byte[] key, byte[] value); - - void rpush(byte[] key, byte[] value); - - void sadd(byte[] key, byte[] value); - - void zadd(byte[] key, Long score, byte[] value); + CompletableFuture get(byte[] key); } diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java index 24591a1dc0..f0a2054b9b 100644 --- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java @@ -21,12 +21,12 @@ import feast.storage.common.retry.BackOffExecutor; import io.lettuce.core.LettuceFutures; import io.lettuce.core.RedisClient; -import io.lettuce.core.RedisFuture; import io.lettuce.core.RedisURI; import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.codec.ByteArrayCodec; -import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import org.joda.time.Duration; @@ -38,7 +38,6 @@ public class RedisStandaloneIngestionClient implements RedisIngestionClient { private static final int DEFAULT_TIMEOUT = 2000; private StatefulRedisConnection connection; private RedisAsyncCommands commands; - private List futures = Lists.newArrayList(); public RedisStandaloneIngestionClient(StoreProto.Store.RedisConfig redisConfig) { this.host = redisConfig.getHost(); @@ -69,6 +68,9 @@ public void connect() { if (!isConnected()) { this.connection = this.redisclient.connect(new ByteArrayCodec()); this.commands = connection.async(); + + // enable pipelining of commands + this.commands.setAutoFlushCommands(false); } } @@ -78,48 +80,20 @@ public boolean isConnected() { } @Override - public void sync() { - // Wait for some time for futures to complete - // TODO: should this be configurable? - try { - LettuceFutures.awaitAll(60, TimeUnit.SECONDS, futures.toArray(new RedisFuture[0])); - } finally { - futures.clear(); - } - } - - @Override - public void pexpire(byte[] key, Long expiryMillis) { - commands.pexpire(key, expiryMillis); - } - - @Override - public void append(byte[] key, byte[] value) { - futures.add(commands.append(key, value)); - } - - @Override - public void set(byte[] key, byte[] value) { - futures.add(commands.set(key, value)); - } + public void sync(Iterable> futures) { + this.connection.flushCommands(); - @Override - public void lpush(byte[] key, byte[] value) { - futures.add(commands.lpush(key, value)); - } - - @Override - public void rpush(byte[] key, byte[] value) { - futures.add(commands.rpush(key, value)); + LettuceFutures.awaitAll( + 60, TimeUnit.SECONDS, Lists.newArrayList(futures).toArray(new Future[0])); } @Override - public void sadd(byte[] key, byte[] value) { - futures.add(commands.sadd(key, value)); + public CompletableFuture set(byte[] key, byte[] value) { + return commands.set(key, value).toCompletableFuture(); } @Override - public void zadd(byte[] key, Long score, byte[] value) { - futures.add(commands.zadd(key, score, value)); + public CompletableFuture get(byte[] key) { + return commands.get(key).toCompletableFuture(); } } diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisClusterFeatureSinkTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisClusterFeatureSinkTest.java deleted file mode 100644 index 62ddfff3a7..0000000000 --- a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisClusterFeatureSinkTest.java +++ /dev/null @@ -1,539 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package feast.storage.connectors.redis.writer; - -import static feast.storage.common.testing.TestUtil.field; -import static feast.storage.common.testing.TestUtil.hash; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.MatcherAssert.assertThat; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.protobuf.Timestamp; -import feast.common.models.FeatureSetReference; -import feast.proto.core.FeatureSetProto.EntitySpec; -import feast.proto.core.FeatureSetProto.FeatureSetSpec; -import feast.proto.core.FeatureSetProto.FeatureSpec; -import feast.proto.core.StoreProto.Store.RedisClusterConfig; -import feast.proto.storage.RedisProto.RedisKey; -import feast.proto.types.FeatureRowProto.FeatureRow; -import feast.proto.types.FieldProto.Field; -import feast.proto.types.ValueProto.Value; -import feast.proto.types.ValueProto.ValueType.Enum; -import io.lettuce.core.RedisURI; -import io.lettuce.core.cluster.RedisClusterClient; -import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; -import io.lettuce.core.cluster.api.sync.RedisClusterCommands; -import io.lettuce.core.codec.ByteArrayCodec; -import java.io.File; -import java.io.IOException; -import java.nio.file.Paths; -import java.util.*; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.ScheduledThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import net.ishiis.redis.unit.RedisCluster; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.View; -import org.apache.beam.sdk.transforms.windowing.*; -import org.apache.beam.sdk.values.PCollection; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; - -public class RedisClusterFeatureSinkTest { - @Rule public transient TestPipeline p = TestPipeline.create(); - - private static String REDIS_CLUSTER_HOST = "localhost"; - private static int REDIS_CLUSTER_PORT1 = 6380; - private static int REDIS_CLUSTER_PORT2 = 6381; - private static int REDIS_CLUSTER_PORT3 = 6382; - private static String CONNECTION_STRING = "localhost:6380,localhost:6381,localhost:6382"; - private RedisCluster redisCluster; - private RedisClusterClient redisClusterClient; - private RedisClusterCommands redisClusterCommands; - - private RedisFeatureSink redisClusterFeatureSink; - - @Before - public void setUp() throws IOException { - redisCluster = new RedisCluster(REDIS_CLUSTER_PORT1, REDIS_CLUSTER_PORT2, REDIS_CLUSTER_PORT3); - redisCluster.start(); - redisClusterClient = - RedisClusterClient.create( - Arrays.asList( - RedisURI.create(REDIS_CLUSTER_HOST, REDIS_CLUSTER_PORT1), - RedisURI.create(REDIS_CLUSTER_HOST, REDIS_CLUSTER_PORT2), - RedisURI.create(REDIS_CLUSTER_HOST, REDIS_CLUSTER_PORT3))); - StatefulRedisClusterConnection connection = - redisClusterClient.connect(new ByteArrayCodec()); - redisClusterCommands = connection.sync(); - redisClusterCommands.setTimeout(java.time.Duration.ofMillis(600000)); - - FeatureSetSpec spec1 = - FeatureSetSpec.newBuilder() - .setName("fs") - .setProject("myproject") - .addEntities(EntitySpec.newBuilder().setName("entity").setValueType(Enum.INT64).build()) - .addFeatures( - FeatureSpec.newBuilder().setName("feature").setValueType(Enum.STRING).build()) - .build(); - - FeatureSetSpec spec2 = - FeatureSetSpec.newBuilder() - .setName("feature_set") - .setProject("myproject") - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_primary") - .setValueType(Enum.INT32) - .build()) - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_secondary") - .setValueType(Enum.STRING) - .build()) - .addFeatures( - FeatureSpec.newBuilder().setName("feature_1").setValueType(Enum.STRING).build()) - .addFeatures( - FeatureSpec.newBuilder().setName("feature_2").setValueType(Enum.INT64).build()) - .build(); - - Map specMap = - ImmutableMap.of( - FeatureSetReference.of("myproject", "fs", 1), spec1, - FeatureSetReference.of("myproject", "feature_set", 1), spec2); - RedisClusterConfig redisClusterConfig = - RedisClusterConfig.newBuilder() - .setConnectionString(CONNECTION_STRING) - .setInitialBackoffMs(2000) - .setMaxRetries(4) - .build(); - - redisClusterFeatureSink = - RedisFeatureSink.builder().setRedisClusterConfig(redisClusterConfig).build(); - redisClusterFeatureSink.prepareWrite(p.apply("Specs-1", Create.of(specMap))); - } - - static boolean deleteDirectory(File directoryToBeDeleted) { - File[] allContents = directoryToBeDeleted.listFiles(); - if (allContents != null) { - for (File file : allContents) { - deleteDirectory(file); - } - } - return directoryToBeDeleted.delete(); - } - - @After - public void teardown() { - redisCluster.stop(); - redisClusterClient.shutdown(); - deleteDirectory(new File(String.valueOf(Paths.get(System.getProperty("user.dir"), ".redis")))); - } - - @Test - public void shouldWriteToRedis() { - - HashMap kvs = new LinkedHashMap<>(); - kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields( - Field.newBuilder() - .setName(hash("feature")) - .setValue(Value.newBuilder().setStringVal("one"))) - .build()); - kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 2, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields( - Field.newBuilder() - .setName(hash("feature")) - .setValue(Value.newBuilder().setStringVal("two"))) - .build()); - - List featureRows = - ImmutableList.of( - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build(), - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 2, Enum.INT64)) - .addFields(field("feature", "two", Enum.STRING)) - .build()); - - p.apply(Create.of(featureRows)).apply(redisClusterFeatureSink.writer()); - p.run(); - - kvs.forEach( - (key, value) -> { - byte[] actual = redisClusterCommands.get(key.toByteArray()); - assertThat(actual, equalTo(value.toByteArray())); - }); - } - - @Test(timeout = 15000) - public void shouldRetryFailConnection() throws InterruptedException { - HashMap kvs = new LinkedHashMap<>(); - kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields( - Field.newBuilder() - .setName(hash("feature")) - .setValue(Value.newBuilder().setStringVal("one"))) - .build()); - - List featureRows = - ImmutableList.of( - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build()); - - PCollection failedElementCount = - p.apply(Create.of(featureRows)) - .apply(redisClusterFeatureSink.writer()) - .getFailedInserts() - .apply(Count.globally()); - - redisCluster.stop(); - final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1); - ScheduledFuture scheduledRedisRestart = - redisRestartExecutor.schedule( - () -> { - redisCluster.start(); - }, - 3, - TimeUnit.SECONDS); - - PAssert.that(failedElementCount).containsInAnyOrder(0L); - p.run(); - scheduledRedisRestart.cancel(true); - - kvs.forEach( - (key, value) -> { - byte[] actual = redisClusterCommands.get(key.toByteArray()); - assertThat(actual, equalTo(value.toByteArray())); - }); - } - - @Test - public void shouldProduceFailedElementIfRetryExceeded() { - RedisClusterConfig redisClusterConfig = - RedisClusterConfig.newBuilder() - .setConnectionString(CONNECTION_STRING) - .setInitialBackoffMs(2000) - .setMaxRetries(1) - .build(); - - FeatureSetSpec spec1 = - FeatureSetSpec.newBuilder() - .setName("fs") - .setProject("myproject") - .addEntities(EntitySpec.newBuilder().setName("entity").setValueType(Enum.INT64).build()) - .addFeatures( - FeatureSpec.newBuilder().setName("feature").setValueType(Enum.STRING).build()) - .build(); - Map specMap = ImmutableMap.of("myproject/fs", spec1); - redisClusterFeatureSink = - RedisFeatureSink.builder() - .setRedisClusterConfig(redisClusterConfig) - .build() - .withSpecsView(p.apply("Specs-2", Create.of(specMap)).apply("View", View.asMultimap())); - - redisCluster.stop(); - - List featureRows = - ImmutableList.of( - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build()); - - PCollection failedElementCount = - p.apply(Create.of(featureRows)) - .apply("modifiedSink", redisClusterFeatureSink.writer()) - .getFailedInserts() - .apply(Count.globally()); - - PAssert.that(failedElementCount).containsInAnyOrder(1L); - p.run(); - } - - @Test - public void shouldConvertRowWithDuplicateEntitiesToValidKey() { - - FeatureRow offendingRow = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(2))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); - - RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); - - FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName(hash("feature_1")) - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName(hash("feature_2")) - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); - - p.apply(Create.of(offendingRow)).apply(redisClusterFeatureSink.writer()); - - p.run(); - - byte[] actual = redisClusterCommands.get(expectedKey.toByteArray()); - assertThat(actual, equalTo(expectedValue.toByteArray())); - } - - @Test - public void shouldConvertRowWithOutOfOrderFieldsToValidKey() { - FeatureRow offendingRow = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .build(); - - RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); - - List expectedFields = - Arrays.asList( - Field.newBuilder() - .setName(hash("feature_1")) - .setValue(Value.newBuilder().setStringVal("strValue1")) - .build(), - Field.newBuilder() - .setName(hash("feature_2")) - .setValue(Value.newBuilder().setInt64Val(1001)) - .build()); - FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addAllFields(expectedFields) - .build(); - - p.apply(Create.of(offendingRow)).apply(redisClusterFeatureSink.writer()); - - p.run(); - - byte[] actual = redisClusterCommands.get(expectedKey.toByteArray()); - assertThat(actual, equalTo(expectedValue.toByteArray())); - } - - @Test - public void shouldMergeDuplicateFeatureFields() { - FeatureRow featureRowWithDuplicatedFeatureFields = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); - - RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); - - FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName(hash("feature_1")) - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName(hash("feature_2")) - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); - - p.apply(Create.of(featureRowWithDuplicatedFeatureFields)) - .apply(redisClusterFeatureSink.writer()); - - p.run(); - - byte[] actual = redisClusterCommands.get(expectedKey.toByteArray()); - assertThat(actual, equalTo(expectedValue.toByteArray())); - } - - @Test - public void shouldPopulateMissingFeatureValuesWithDefaultInstance() { - FeatureRow featureRowWithDuplicatedFeatureFields = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .build(); - - RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); - - FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName(hash("feature_1")) - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder().setName(hash("feature_2")).setValue(Value.getDefaultInstance())) - .build(); - - p.apply(Create.of(featureRowWithDuplicatedFeatureFields)) - .apply(redisClusterFeatureSink.writer()); - - p.run(); - - byte[] actual = redisClusterCommands.get(expectedKey.toByteArray()); - assertThat(actual, equalTo(expectedValue.toByteArray())); - } -} diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java index 948b8d0fda..12377fd1d1 100644 --- a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java +++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java @@ -20,63 +20,112 @@ import static feast.storage.common.testing.TestUtil.hash; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.collection.IsCollectionWithSize.hasSize; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.protobuf.Message; import com.google.protobuf.Timestamp; import feast.common.models.FeatureSetReference; import feast.proto.core.FeatureSetProto.EntitySpec; import feast.proto.core.FeatureSetProto.FeatureSetSpec; import feast.proto.core.FeatureSetProto.FeatureSpec; import feast.proto.core.StoreProto; +import feast.proto.core.StoreProto.Store.RedisClusterConfig; import feast.proto.core.StoreProto.Store.RedisConfig; import feast.proto.storage.RedisProto.RedisKey; import feast.proto.types.FeatureRowProto.FeatureRow; import feast.proto.types.FieldProto.Field; import feast.proto.types.ValueProto.Value; import feast.proto.types.ValueProto.ValueType.Enum; +import io.lettuce.core.AbstractRedisClient; import io.lettuce.core.RedisClient; import io.lettuce.core.RedisURI; -import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.sync.RedisStringCommands; +import io.lettuce.core.cluster.RedisClusterClient; import io.lettuce.core.codec.ByteArrayCodec; -import java.io.IOException; import java.util.*; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import net.ishiis.redis.unit.Redis; +import net.ishiis.redis.unit.RedisCluster; +import net.ishiis.redis.unit.RedisServer; +import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import redis.embedded.Redis; -import redis.embedded.RedisServer; +import org.junit.*; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +@RunWith(Parameterized.class) public class RedisFeatureSinkTest { @Rule public transient TestPipeline p = TestPipeline.create(); private static String REDIS_HOST = "localhost"; - private static int REDIS_PORT = 51234; - private Redis redis; - private RedisClient redisClient; - private RedisStringCommands sync; + private static int REDIS_PORT = 51233; + private static Integer[] REDIS_CLUSTER_PORTS = {6380, 6381, 6382}; + private RedisStringCommands sync; private RedisFeatureSink redisFeatureSink; private Map specMap; - @Before - public void setUp() throws IOException { - redis = new RedisServer(REDIS_PORT); - redis.start(); - redisClient = + @Parameterized.Parameters + public static Iterable backends() { + Redis redis = new RedisServer(REDIS_PORT); + RedisClient client = RedisClient.create(new RedisURI(REDIS_HOST, REDIS_PORT, java.time.Duration.ofMillis(2000))); - StatefulRedisConnection connection = redisClient.connect(new ByteArrayCodec()); - sync = connection.sync(); + + Redis redisCluster = new RedisCluster(REDIS_CLUSTER_PORTS); + RedisClusterClient clientCluster = + RedisClusterClient.create( + Lists.newArrayList(REDIS_CLUSTER_PORTS).stream() + .map(port -> RedisURI.create(REDIS_HOST, port)) + .collect(Collectors.toList())); + + StoreProto.Store.RedisConfig redisConfig = + StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build(); + + StoreProto.Store.RedisClusterConfig redisClusterConfig = + StoreProto.Store.RedisClusterConfig.newBuilder() + .setConnectionString( + Lists.newArrayList(REDIS_CLUSTER_PORTS).stream() + .map(port -> String.format("%s:%d", REDIS_HOST, port)) + .collect(Collectors.joining(","))) + .setInitialBackoffMs(2000) + .setMaxRetries(4) + .build(); + + return Arrays.asList( + new Object[] {redis, client, redisConfig}, + new Object[] {redisCluster, clientCluster, redisClusterConfig}); + } + + @Parameterized.Parameter(0) + public Redis redisServer; + + @Parameterized.Parameter(1) + public AbstractRedisClient redisClient; + + @Parameterized.Parameter(2) + public Message redisConfig; + + @Before + public void setUp() { + redisServer.start(); + + if (redisClient instanceof RedisClient) { + sync = ((RedisClient) redisClient).connect(new ByteArrayCodec()).sync(); + } else { + sync = ((RedisClusterClient) redisClient).connect(new ByteArrayCodec()).sync(); + } FeatureSetSpec spec1 = FeatureSetSpec.newBuilder() @@ -111,17 +160,42 @@ public void setUp() throws IOException { ImmutableMap.of( FeatureSetReference.of("myproject", "fs", 1), spec1, FeatureSetReference.of("myproject", "feature_set", 1), spec2); - StoreProto.Store.RedisConfig redisConfig = - StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build(); - redisFeatureSink = RedisFeatureSink.builder().setRedisConfig(redisConfig).build(); + RedisFeatureSink.Builder builder = RedisFeatureSink.builder(); + if (redisConfig instanceof RedisConfig) { + builder = builder.setRedisConfig((RedisConfig) redisConfig); + } else { + builder = builder.setRedisClusterConfig((RedisClusterConfig) redisConfig); + } + redisFeatureSink = builder.build(); redisFeatureSink.prepareWrite(p.apply("Specs-1", Create.of(specMap))); } @After - public void teardown() { - redisClient.shutdown(); - redis.stop(); + public void tearDown() { + if (redisServer.isActive()) { + redisServer.stop(); + } + } + + private RedisKey createRedisKey(String featureSetRef, Field... fields) { + return RedisKey.newBuilder() + .setFeatureSet(featureSetRef) + .addAllEntities(Lists.newArrayList(fields)) + .build(); + } + + private FeatureRow createFeatureRow(String featureSetRef, Timestamp timestamp, Field... fields) { + FeatureRow.Builder builder = FeatureRow.newBuilder(); + if (featureSetRef != null) { + builder.setFeatureSet(featureSetRef); + } + + if (timestamp != null) { + builder.setEventTimestamp(timestamp); + } + + return builder.addAllFields(Lists.newArrayList(fields)).build(); } @Test @@ -129,42 +203,26 @@ public void shouldWriteToRedis() { HashMap kvs = new LinkedHashMap<>(); kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields( - Field.newBuilder() - .setName(hash("feature")) - .setValue(Value.newBuilder().setStringVal("one"))) - .build()); + createRedisKey("myproject/fs", field("entity", 1, Enum.INT64)), + createFeatureRow( + null, Timestamp.getDefaultInstance(), field(hash("feature"), "one", Enum.STRING))); kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 2, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields( - Field.newBuilder() - .setName(hash("feature")) - .setValue(Value.newBuilder().setStringVal("two"))) - .build()); + createRedisKey("myproject/fs", field("entity", 2, Enum.INT64)), + createFeatureRow( + null, Timestamp.getDefaultInstance(), field(hash("feature"), "two", Enum.STRING))); List featureRows = ImmutableList.of( - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build(), - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 2, Enum.INT64)) - .addFields(field("feature", "two", Enum.STRING)) - .build()); + createFeatureRow( + "myproject/fs", + null, + field("entity", 1, Enum.INT64), + field("feature", "one", Enum.STRING)), + createFeatureRow( + "myproject/fs", + null, + field("entity", 2, Enum.INT64), + field("feature", "two", Enum.STRING))); p.apply(Create.of(featureRows)).apply(redisFeatureSink.writer()); p.run(); @@ -176,7 +234,7 @@ public void shouldWriteToRedis() { }); } - @Test(timeout = 10000) + @Test(timeout = 30000) public void shouldRetryFailConnection() throws InterruptedException { RedisConfig redisConfig = RedisConfig.newBuilder() @@ -194,25 +252,17 @@ public void shouldRetryFailConnection() throws InterruptedException { HashMap kvs = new LinkedHashMap<>(); kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields( - Field.newBuilder() - .setName(hash("feature")) - .setValue(Value.newBuilder().setStringVal("one"))) - .build()); + createRedisKey("myproject/fs", field("entity", 1, Enum.INT64)), + createFeatureRow( + "", Timestamp.getDefaultInstance(), field(hash("feature"), "one", Enum.STRING))); List featureRows = ImmutableList.of( - FeatureRow.newBuilder() - .setFeatureSet("myproject/fs") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build()); + createFeatureRow( + "myproject/fs", + null, + field("entity", 1, Enum.INT64), + field("feature", "one", Enum.STRING))); PCollection failedElementCount = p.apply(Create.of(featureRows)) @@ -220,12 +270,12 @@ public void shouldRetryFailConnection() throws InterruptedException { .getFailedInserts() .apply(Count.globally()); - redis.stop(); + redisServer.stop(); final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1); ScheduledFuture scheduledRedisRestart = redisRestartExecutor.schedule( () -> { - redis.start(); + redisServer.start(); }, 3, TimeUnit.SECONDS); @@ -255,17 +305,9 @@ public void shouldProduceFailedElementIfRetryExceeded() { HashMap kvs = new LinkedHashMap<>(); kvs.put( - RedisKey.newBuilder() - .setFeatureSet("myproject/fs") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.getDefaultInstance()) - .addFields( - Field.newBuilder() - .setName(hash("feature")) - .setValue(Value.newBuilder().setStringVal("one"))) - .build()); + createRedisKey("myproject/fs", field("entity", 1, Enum.INT64)), + createFeatureRow( + "", Timestamp.getDefaultInstance(), field(hash("feature"), "one", Enum.STRING))); List featureRows = ImmutableList.of( @@ -281,7 +323,7 @@ public void shouldProduceFailedElementIfRetryExceeded() { .getFailedInserts() .apply(Count.globally()); - redis.stop(); + redisServer.stop(); PAssert.that(failedElementCount).containsInAnyOrder(1L); p.run(); } @@ -290,56 +332,27 @@ public void shouldProduceFailedElementIfRetryExceeded() { public void shouldConvertRowWithDuplicateEntitiesToValidKey() { FeatureRow offendingRow = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(2))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_primary", 2, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_1", "strValue1", Enum.STRING), + field("feature_2", 1001, Enum.INT64)); RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName(hash("feature_1")) - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName(hash("feature_2")) - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(10).build(), + field(hash("feature_1"), "strValue1", Enum.STRING), + field(hash("feature_2"), 1001, Enum.INT64)); p.apply(Create.of(offendingRow)).apply(redisFeatureSink.writer()); @@ -352,55 +365,26 @@ public void shouldConvertRowWithDuplicateEntitiesToValidKey() { @Test public void shouldConvertRowWithOutOfOrderFieldsToValidKey() { FeatureRow offendingRow = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .build(); + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_secondary", "a", Enum.STRING), + field("entity_id_primary", 1, Enum.INT32), + field("feature_2", 1001, Enum.INT64), + field("feature_1", "strValue1", Enum.STRING)); RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); - List expectedFields = - Arrays.asList( - Field.newBuilder() - .setName(hash("feature_1")) - .setValue(Value.newBuilder().setStringVal("strValue1")) - .build(), - Field.newBuilder() - .setName(hash("feature_2")) - .setValue(Value.newBuilder().setInt64Val(1001)) - .build()); FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addAllFields(expectedFields) - .build(); + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(10).build(), + field(hash("feature_1"), "strValue1", Enum.STRING), + field(hash("feature_2"), 1001, Enum.INT64)); p.apply(Create.of(offendingRow)).apply(redisFeatureSink.writer()); @@ -413,56 +397,27 @@ public void shouldConvertRowWithOutOfOrderFieldsToValidKey() { @Test public void shouldMergeDuplicateFeatureFields() { FeatureRow featureRowWithDuplicatedFeatureFields = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_2", 1001, Enum.INT64), + field("feature_1", "strValue1", Enum.STRING), + field("feature_1", "strValue1", Enum.STRING)); RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName(hash("feature_1")) - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName(hash("feature_2")) - .setValue(Value.newBuilder().setInt64Val(1001))) - .build(); + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(10).build(), + field(hash("feature_1"), "strValue1", Enum.STRING), + field(hash("feature_2"), 1001, Enum.INT64)); p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.writer()); @@ -475,46 +430,28 @@ public void shouldMergeDuplicateFeatureFields() { @Test public void shouldPopulateMissingFeatureValuesWithDefaultInstance() { FeatureRow featureRowWithDuplicatedFeatureFields = - FeatureRow.newBuilder() - .setFeatureSet("myproject/feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .build(); + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_1", "strValue1", Enum.STRING)); RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("myproject/feature_set") - .addEntities( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addEntities( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .build(); + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); FeatureRow expectedValue = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName(hash("feature_1")) - .setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder().setName(hash("feature_2")).setValue(Value.getDefaultInstance())) - .build(); + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(10).build(), + field(hash("feature_1"), "strValue1", Enum.STRING), + Field.newBuilder() + .setName(hash("feature_2")) + .setValue(Value.getDefaultInstance()) + .build()); p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.writer()); @@ -523,4 +460,206 @@ public void shouldPopulateMissingFeatureValuesWithDefaultInstance() { byte[] actual = sync.get(expectedKey.toByteArray()); assertThat(actual, equalTo(expectedValue.toByteArray())); } + + @Test + public void shouldDeduplicateRowsWithinBatch() { + TestStream featureRowTestStream = + TestStream.create(ProtoCoder.of(FeatureRow.class)) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(20).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_2", 111, Enum.INT32))) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_2", 222, Enum.INT32))) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.getDefaultInstance(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_2", 333, Enum.INT32))) + .advanceWatermarkToInfinity(); + + p.apply(featureRowTestStream).apply(redisFeatureSink.writer()); + p.run(); + + RedisKey expectedKey = + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); + + FeatureRow expectedValue = + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(20).build(), + Field.newBuilder() + .setName(hash("feature_1")) + .setValue(Value.getDefaultInstance()) + .build(), + field(hash("feature_2"), 111, Enum.INT32)); + + byte[] actual = sync.get(expectedKey.toByteArray()); + assertThat(actual, equalTo(expectedValue.toByteArray())); + } + + @Test + public void shouldWriteWithLatterTimestamp() { + TestStream featureRowTestStream = + TestStream.create(ProtoCoder.of(FeatureRow.class)) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(20).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_2", 111, Enum.INT32))) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(20).build(), + field("entity_id_primary", 2, Enum.INT32), + field("entity_id_secondary", "b", Enum.STRING), + field("feature_2", 222, Enum.INT32))) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(10).build(), + field("entity_id_primary", 3, Enum.INT32), + field("entity_id_secondary", "c", Enum.STRING), + field("feature_2", 333, Enum.INT32))) + .advanceWatermarkToInfinity(); + + RedisKey keyA = + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); + + RedisKey keyB = + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 2, Enum.INT32), + field("entity_id_secondary", "b", Enum.STRING)); + + RedisKey keyC = + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 3, Enum.INT32), + field("entity_id_secondary", "c", Enum.STRING)); + + sync.set( + keyA.toByteArray(), + createFeatureRow("", Timestamp.newBuilder().setSeconds(30).build()).toByteArray()); + + sync.set( + keyB.toByteArray(), + createFeatureRow("", Timestamp.newBuilder().setSeconds(10).build()).toByteArray()); + + sync.set( + keyC.toByteArray(), + createFeatureRow("", Timestamp.newBuilder().setSeconds(10).build()).toByteArray()); + + p.apply(featureRowTestStream).apply(redisFeatureSink.writer()); + p.run(); + + assertThat( + sync.get(keyA.toByteArray()), + equalTo(createFeatureRow("", Timestamp.newBuilder().setSeconds(30).build()).toByteArray())); + + assertThat( + sync.get(keyB.toByteArray()), + equalTo( + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(20).build(), + Field.newBuilder() + .setName(hash("feature_1")) + .setValue(Value.getDefaultInstance()) + .build(), + field(hash("feature_2"), 222, Enum.INT32)) + .toByteArray())); + + assertThat( + sync.get(keyC.toByteArray()), + equalTo(createFeatureRow("", Timestamp.newBuilder().setSeconds(10).build()).toByteArray())); + } + + @Test + public void shouldOverwriteInvalidRows() { + TestStream featureRowTestStream = + TestStream.create(ProtoCoder.of(FeatureRow.class)) + .addElements( + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(20).build(), + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_1", "text", Enum.STRING), + field("feature_2", 111, Enum.INT32))) + .advanceWatermarkToInfinity(); + + RedisKey expectedKey = + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", 1, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)); + + sync.set(expectedKey.toByteArray(), "some-invalid-data".getBytes()); + + p.apply(featureRowTestStream).apply(redisFeatureSink.writer()); + p.run(); + + FeatureRow expectedValue = + createFeatureRow( + "", + Timestamp.newBuilder().setSeconds(20).build(), + field(hash("feature_1"), "text", Enum.STRING), + field(hash("feature_2"), 111, Enum.INT32)); + + byte[] actual = sync.get(expectedKey.toByteArray()); + assertThat(actual, equalTo(expectedValue.toByteArray())); + } + + @Test + public void loadTest() { + List rows = + IntStream.range(0, 10000) + .mapToObj( + i -> + createFeatureRow( + "myproject/feature_set", + Timestamp.newBuilder().setSeconds(20).build(), + field("entity_id_primary", i, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING), + field("feature_1", "text", Enum.STRING), + field("feature_2", 111, Enum.INT32))) + .collect(Collectors.toList()); + + p.apply(Create.of(rows)).apply(redisFeatureSink.writer()); + p.run(); + + List outcome = + IntStream.range(0, 10000) + .mapToObj( + i -> + createRedisKey( + "myproject/feature_set", + field("entity_id_primary", i, Enum.INT32), + field("entity_id_secondary", "a", Enum.STRING)) + .toByteArray()) + .map(sync::get) + .collect(Collectors.toList()); + + assertThat(outcome, hasSize(10000)); + assertThat("All rows were stored", outcome.stream().allMatch(Objects::nonNull)); + } } diff --git a/tests/e2e/redis/basic-ingest-redis-serving.py b/tests/e2e/redis/basic-ingest-redis-serving.py index 1fcae69ed3..c1e25508d4 100644 --- a/tests/e2e/redis/basic-ingest-redis-serving.py +++ b/tests/e2e/redis/basic-ingest-redis-serving.py @@ -4,7 +4,7 @@ import tempfile import time import uuid -from datetime import datetime +from datetime import datetime, timedelta import grpc import numpy as np @@ -821,6 +821,8 @@ def all_types_dataframe(): @pytest.mark.timeout(45) @pytest.mark.run(order=20) def test_all_types_register_feature_set_success(client): + client.set_project(PROJECT_NAME) + all_types_fs_expected = FeatureSet( name="all_types", entities=[Entity(name="user_id", dtype=ValueType.INT64)], @@ -930,9 +932,11 @@ def try_get_features(): @pytest.mark.timeout(300) -@pytest.mark.run(order=29) +@pytest.mark.run(order=35) def test_all_types_ingest_jobs(client, all_types_dataframe): # list ingestion jobs given featureset + client.set_project(PROJECT_NAME) + all_types_fs = client.get_feature_set(name="all_types") ingest_jobs = client.list_ingest_jobs( feature_set_ref=FeatureSetRef.from_feature_set(all_types_fs) @@ -990,7 +994,7 @@ def large_volume_dataframe(): @pytest.mark.timeout(45) -@pytest.mark.run(order=30) +@pytest.mark.run(order=40) def test_large_volume_register_feature_set_success(client): cust_trans_fs_expected = FeatureSet.from_yaml( f"{DIR_PATH}/large_volume/cust_trans_large_fs.yaml" @@ -1016,7 +1020,7 @@ def test_large_volume_register_feature_set_success(client): @pytest.mark.timeout(300) -@pytest.mark.run(order=31) +@pytest.mark.run(order=41) def test_large_volume_ingest_success(client, large_volume_dataframe): # Get large volume feature set cust_trans_fs = client.get_feature_set(name="customer_transactions_large") @@ -1026,7 +1030,7 @@ def test_large_volume_ingest_success(client, large_volume_dataframe): @pytest.mark.timeout(90) -@pytest.mark.run(order=32) +@pytest.mark.run(order=42) def test_large_volume_retrieve_online_success(client, large_volume_dataframe): # Poll serving for feature values until the correct values are returned feature_refs = [ @@ -1112,7 +1116,7 @@ def all_types_parquet_file(): @pytest.mark.timeout(300) -@pytest.mark.run(order=40) +@pytest.mark.run(order=50) def test_all_types_parquet_register_feature_set_success(client): # Load feature set from file all_types_parquet_expected = FeatureSet.from_yaml( @@ -1140,7 +1144,7 @@ def test_all_types_parquet_register_feature_set_success(client): @pytest.mark.timeout(600) -@pytest.mark.run(order=41) +@pytest.mark.run(order=51) def test_all_types_infer_register_ingest_file_success(client, all_types_parquet_file): # Get feature set all_types_fs = client.get_feature_set(name="all_types_parquet") @@ -1150,7 +1154,7 @@ def test_all_types_infer_register_ingest_file_success(client, all_types_parquet_ @pytest.mark.timeout(200) -@pytest.mark.run(order=50) +@pytest.mark.run(order=60) def test_list_entities_and_features(client): customer_entity = Entity("customer_id", ValueType.INT64) driver_entity = Entity("driver_id", ValueType.INT64) @@ -1225,7 +1229,7 @@ def test_list_entities_and_features(client): @pytest.mark.timeout(900) -@pytest.mark.run(order=60) +@pytest.mark.run(order=70) def test_sources_deduplicate_ingest_jobs(client): source = KafkaSource("localhost:9092", "feast-features") alt_source = KafkaSource("localhost:9092", "feast-data") @@ -1273,6 +1277,58 @@ def get_running_jobs(): time.sleep(1) +@pytest.mark.run(order=30) +def test_sink_writes_only_recent_rows(client): + client.set_project("default") + + feature_refs = ["driver:rating", "driver:cost"] + + later_df = basic_dataframe( + entities=["driver_id"], + features=["rating", "cost"], + ingest_time=datetime.utcnow(), + n_size=5, + ) + + earlier_df = basic_dataframe( + entities=["driver_id"], + features=["rating", "cost"], + ingest_time=datetime.utcnow() - timedelta(minutes=5), + n_size=5, + ) + + def try_get_features(): + response = client.get_online_features( + entity_rows=[ + GetOnlineFeaturesRequest.EntityRow( + fields={"driver_id": Value(int64_val=later_df.iloc[0]["driver_id"])} + ) + ], + feature_refs=feature_refs, + ) # type: GetOnlineFeaturesResponse + is_ok = all( + [check_online_response(ref, later_df, response) for ref in feature_refs] + ) + return response, is_ok + + # test compaction within batch + client.ingest("driver", pd.concat([earlier_df, later_df])) + wait_retry_backoff( + retry_fn=try_get_features, + timeout_secs=90, + timeout_msg="Timed out trying to get online feature values", + ) + + # test read before write + client.ingest("driver", earlier_df) + time.sleep(10) + wait_retry_backoff( + retry_fn=try_get_features, + timeout_secs=90, + timeout_msg="Timed out trying to get online feature values", + ) + + # TODO: rewrite these using python SDK once the labels are implemented there class TestsBasedOnGrpc: GRPC_CONNECTION_TIMEOUT = 3