diff --git a/storage/connectors/redis/pom.xml b/storage/connectors/redis/pom.xml
index d0e127cde8..ca6e8d42ad 100644
--- a/storage/connectors/redis/pom.xml
+++ b/storage/connectors/redis/pom.xml
@@ -89,7 +89,12 @@
<version>4.12</version>
<scope>test</scope>
</dependency>
+<dependency>
+ <groupId>org.apache.beam</groupId>
+ <artifactId>beam-sdks-java-extensions-protobuf</artifactId>
+ <version>${org.apache.beam.version}</version>
+ <scope>test</scope>
+</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/BatchDoFnWithRedis.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/BatchDoFnWithRedis.java
new file mode 100644
index 0000000000..d6c83c3a54
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/BatchDoFnWithRedis.java
@@ -0,0 +1,88 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.writer;
+
+import feast.storage.common.retry.Retriable;
+import io.lettuce.core.RedisException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.function.Function;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Base class for Redis-related DoFns. Assumes that operations will be batched. Prepares the
+ * Redis client in the DoFn setup stage and closes it in the teardown stage.
+ *
+ * @param <InputT> type of input elements
+ * @param <OutputT> type of output elements
+ */
+public class BatchDoFnWithRedis<InputT, OutputT> extends DoFn<InputT, OutputT> {
+ private static final Logger log = LoggerFactory.getLogger(BatchDoFnWithRedis.class);
+
+ private final RedisIngestionClient redisIngestionClient;
+
+ BatchDoFnWithRedis(RedisIngestionClient redisIngestionClient) {
+ this.redisIngestionClient = redisIngestionClient;
+ }
+
+ @Setup
+ public void setup() {
+ this.redisIngestionClient.setup();
+ }
+
+ @StartBundle
+ public void startBundle() {
+ try {
+ redisIngestionClient.connect();
+ } catch (RedisException e) {
+ log.error("Connection to redis cannot be established: %s", e);
+ }
+ }
+
+ void executeBatch(Function<RedisIngestionClient, Iterable<Future<?>>> executor)
+ throws Exception {
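+ // Note: the whole batch is replayed on every retry attempt (RedisException is
+ // retriable), so the supplied executor must be safe to re-apply.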
+ this.redisIngestionClient
+ .getBackOffExecutor()
+ .execute(
+ new Retriable() {
+ @Override
+ public void execute() throws ExecutionException, InterruptedException {
+ if (!redisIngestionClient.isConnected()) {
+ redisIngestionClient.connect();
+ }
+
+ Iterable<Future<?>> futures = executor.apply(redisIngestionClient);
+ redisIngestionClient.sync(futures);
+ }
+
+ @Override
+ public Boolean isExceptionRetriable(Exception e) {
+ return e instanceof RedisException;
+ }
+
+ @Override
+ public void cleanUpAfterFailure() {}
+ });
+ }
+
+ @Teardown
+ public void teardown() {
+ redisIngestionClient.shutdown();
+ }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java
index 389db4be3a..f36d70563e 100644
--- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisClusterIngestionClient.java
@@ -20,7 +20,6 @@
import feast.proto.core.StoreProto;
import feast.storage.common.retry.BackOffExecutor;
import io.lettuce.core.LettuceFutures;
-import io.lettuce.core.RedisFuture;
import io.lettuce.core.RedisURI;
import io.lettuce.core.cluster.RedisClusterClient;
import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
@@ -28,6 +27,8 @@
import io.lettuce.core.codec.ByteArrayCodec;
import java.util.Arrays;
import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.joda.time.Duration;
@@ -39,7 +40,6 @@ public class RedisClusterIngestionClient implements RedisIngestionClient {
private transient RedisClusterClient clusterClient;
private StatefulRedisClusterConnection<byte[], byte[]> connection;
private RedisAdvancedClusterAsyncCommands<byte[], byte[]> commands;
- private List<RedisFuture> futures = Lists.newArrayList();
public RedisClusterIngestionClient(StoreProto.Store.RedisClusterConfig redisClusterConfig) {
this.uriList =
@@ -55,7 +55,6 @@ public RedisClusterIngestionClient(StoreProto.Store.RedisClusterConfig redisClus
redisClusterConfig.getInitialBackoffMs() > 0 ? redisClusterConfig.getInitialBackoffMs() : 1;
this.backOffExecutor =
new BackOffExecutor(redisClusterConfig.getMaxRetries(), Duration.millis(backoffMs));
- this.clusterClient = RedisClusterClient.create(uriList);
}
@Override
@@ -78,6 +77,10 @@ public void connect() {
if (!isConnected()) {
this.connection = clusterClient.connect(new ByteArrayCodec());
this.commands = connection.async();
+
+ // even though we're using the async API, the client still flushes after each
+ // command by default, which we don't want since we issue all commands in batches
+ this.commands.setAutoFlushCommands(false);
}
}
@@ -87,46 +90,20 @@ public boolean isConnected() {
}
@Override
- public void sync() {
- try {
- LettuceFutures.awaitAll(60, TimeUnit.SECONDS, futures.toArray(new RedisFuture[0]));
- } finally {
- futures.clear();
- }
- }
-
- @Override
- public void pexpire(byte[] key, Long expiryMillis) {
- futures.add(commands.pexpire(key, expiryMillis));
- }
-
- @Override
- public void append(byte[] key, byte[] value) {
- futures.add(commands.append(key, value));
- }
-
- @Override
- public void set(byte[] key, byte[] value) {
- futures.add(commands.set(key, value));
- }
+ public void sync(Iterable<Future<?>> futures) {
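+ // flush all buffered commands to the server as one pipelined batch, then
+ // block until every future completes or the 60s timeout elapses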
+ this.connection.flushCommands();
- @Override
- public void lpush(byte[] key, byte[] value) {
- futures.add(commands.lpush(key, value));
- }
-
- @Override
- public void rpush(byte[] key, byte[] value) {
- futures.add(commands.rpush(key, value));
+ LettuceFutures.awaitAll(
+ 60, TimeUnit.SECONDS, Lists.newArrayList(futures).toArray(new Future[0]));
}
@Override
- public void sadd(byte[] key, byte[] value) {
- futures.add(commands.sadd(key, value));
+ public CompletableFuture<String> set(byte[] key, byte[] value) {
+ return commands.set(key, value).toCompletableFuture();
}
@Override
- public void zadd(byte[] key, Long score, byte[] value) {
- futures.add(commands.zadd(key, score, value));
+ public CompletableFuture<byte[]> get(byte[] key) {
+ return commands.get(key).toCompletableFuture();
}
}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java
index f73c458d78..c42cff7bd0 100644
--- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java
@@ -17,8 +17,9 @@
package feast.storage.connectors.redis.writer;
import com.google.common.collect.Iterators;
-import com.google.common.collect.Lists;
+import com.google.common.collect.Streams;
import com.google.common.hash.Hashing;
+import com.google.protobuf.InvalidProtocolBufferException;
import feast.proto.core.FeatureSetProto.EntitySpec;
import feast.proto.core.FeatureSetProto.FeatureSetSpec;
import feast.proto.core.FeatureSetProto.FeatureSpec;
@@ -29,20 +30,20 @@
import feast.proto.types.ValueProto;
import feast.storage.api.writer.FailedElement;
import feast.storage.api.writer.WriteResult;
-import feast.storage.common.retry.Retriable;
import feast.storage.connectors.redis.retriever.FeatureRowDecoder;
-import io.lettuce.core.RedisException;
import java.nio.charset.StandardCharsets;
+import java.util.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.concurrent.ExecutionException;
+import java.util.function.BinaryOperator;
import java.util.stream.Collectors;
import org.apache.beam.sdk.transforms.*;
import org.apache.beam.sdk.transforms.windowing.*;
import org.apache.beam.sdk.values.*;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -116,63 +117,23 @@ public void process(ProcessContext c) {
redisWrite.get(failedInsertsTupleTag));
}
- public static class WriteDoFn extends DoFn<Iterable<FeatureRow>, FeatureRow> {
- private PCollectionView<Map<String, Iterable<FeatureSetSpec>>> featureSetSpecsView;
- private RedisIngestionClient redisIngestionClient;
+ /**
+ * Writes a batch of {@link FeatureRow} to Redis. Only the latest value per key should be
+ * written. To guarantee this we first fetch all existing values (first batch operation),
+ * compare them with the current batch by eventTimestamp, and write back (second batch
+ * operation) only those rows confirmed to be the most recent.
+ */
+ public static class WriteDoFn extends BatchDoFnWithRedis<Iterable<FeatureRow>, FeatureRow> {
+ private final PCollectionView<Map<String, Iterable<FeatureSetSpec>>> featureSetSpecsView;
WriteDoFn(
RedisIngestionClient redisIngestionClient,
PCollectionView<Map<String, Iterable<FeatureSetSpec>>> featureSetSpecsView) {
- this.redisIngestionClient = redisIngestionClient;
+ super(redisIngestionClient);
this.featureSetSpecsView = featureSetSpecsView;
}
- @Setup
- public void setup() {
- this.redisIngestionClient.setup();
- }
-
- @StartBundle
- public void startBundle() {
- try {
- redisIngestionClient.connect();
- } catch (RedisException e) {
- log.error("Connection to redis cannot be established ", e);
- }
- }
-
- private void executeBatch(
- Iterable<FeatureRow> featureRows, Map<String, FeatureSetSpec> featureSetSpecs)
- throws Exception {
- this.redisIngestionClient
- .getBackOffExecutor()
- .execute(
- new Retriable() {
- @Override
- public void execute() throws ExecutionException, InterruptedException {
- if (!redisIngestionClient.isConnected()) {
- redisIngestionClient.connect();
- }
- featureRows.forEach(
- row -> {
- redisIngestionClient.set(
- getKey(row, featureSetSpecs.get(row.getFeatureSet())),
- getValue(row, featureSetSpecs.get(row.getFeatureSet())));
- });
- redisIngestionClient.sync();
- }
-
- @Override
- public Boolean isExceptionRetriable(Exception e) {
- return e instanceof RedisException;
- }
-
- @Override
- public void cleanUpAfterFailure() {}
- });
- }
-
private FailedElement toFailedElement(
FeatureRow featureRow, Exception exception, String jobName) {
return FailedElement.newBuilder()
@@ -184,7 +145,7 @@ private FailedElement toFailedElement(
.build();
}
- private byte[] getKey(FeatureRow featureRow, FeatureSetSpec spec) {
+ private RedisKey getKey(FeatureRow featureRow, FeatureSetSpec spec) {
List<String> entityNames =
spec.getEntitiesList().stream()
.map(EntitySpec::getName)
@@ -203,7 +164,7 @@ private byte[] getKey(FeatureRow featureRow, FeatureSetSpec spec) {
for (String entityName : entityNames) {
redisKeyBuilder.addEntities(entityFields.get(entityName));
}
- return redisKeyBuilder.build().toByteArray();
+ return redisKeyBuilder.build();
}
/**
@@ -212,7 +173,7 @@ private byte[] getKey(FeatureRow featureRow, FeatureSetSpec spec) {
* names and not unsetting the feature set reference. {@link FeatureRowDecoder} is
* responsible for reversing this "encoding" step.
*/
- private byte[] getValue(FeatureRow featureRow, FeatureSetSpec spec) {
+ private FeatureRow getValue(FeatureRow featureRow, FeatureSetSpec spec) {
List<String> featureNames =
spec.getFeaturesList().stream().map(FeatureSpec::getName).collect(Collectors.toList());
@@ -250,35 +211,101 @@ private byte[] getValue(FeatureRow featureRow, FeatureSetSpec spec) {
return FeatureRow.newBuilder()
.setEventTimestamp(featureRow.getEventTimestamp())
.addAllFields(values)
- .build()
- .toByteArray();
+ .build();
}
@ProcessElement
public void processElement(ProcessContext context) {
- List<FeatureRow> featureRows = Lists.newArrayList(context.element().iterator());
-
+ List<FeatureRow> filteredFeatureRows = Collections.synchronizedList(new ArrayList<>());
Map<String, FeatureSetSpec> latestSpecs =
- context.sideInput(featureSetSpecsView).entrySet().stream()
- .map(e -> ImmutablePair.of(e.getKey(), Iterators.getLast(e.getValue().iterator())))
- .collect(Collectors.toMap(ImmutablePair::getLeft, ImmutablePair::getRight));
+ getLatestSpecs(context.sideInput(featureSetSpecsView));
+
+ Map<RedisKey, FeatureRow> deduplicatedRows =
+ deduplicateRows(context.element(), latestSpecs);
try {
- executeBatch(featureRows, latestSpecs);
- featureRows.forEach(row -> context.output(successfulInsertsTag, row));
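+ // first batch: read currently stored values and keep only rows that are newer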
+ executeBatch(
+ (redisIngestionClient) ->
+ deduplicatedRows.entrySet().stream()
+ .map(
+ entry ->
+ redisIngestionClient
+ .get(entry.getKey().toByteArray())
+ .thenAccept(
+ currentValue -> {
+ FeatureRow newRow = entry.getValue();
+ if (rowShouldBeWritten(newRow, currentValue)) {
+ filteredFeatureRows.add(newRow);
+ }
+ }))
+ .collect(Collectors.toList()));
+
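+ // second batch: write back only the rows confirmed to be most recent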
+ executeBatch(
+ redisIngestionClient ->
+ filteredFeatureRows.stream()
+ .map(
+ row ->
+ redisIngestionClient.set(
+ getKey(row, latestSpecs.get(row.getFeatureSet())).toByteArray(),
+ getValue(row, latestSpecs.get(row.getFeatureSet()))
+ .toByteArray()))
+ .collect(Collectors.toList()));
+
+ filteredFeatureRows.forEach(row -> context.output(successfulInsertsTag, row));
} catch (Exception e) {
- featureRows.forEach(
- failedMutation -> {
- FailedElement failedElement =
- toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName());
- context.output(failedInsertsTupleTag, failedElement);
- });
+ deduplicatedRows
+ .values()
+ .forEach(
+ failedMutation -> {
+ FailedElement failedElement =
+ toFailedElement(
+ failedMutation, e, context.getPipelineOptions().getJobName());
+ context.output(failedInsertsTupleTag, failedElement);
+ });
}
}
- @Teardown
- public void teardown() {
- redisIngestionClient.shutdown();
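+ /** Returns true if there is no valid stored value or newRow has a strictly later eventTimestamp. */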
+ boolean rowShouldBeWritten(FeatureRow newRow, byte[] currentValue) {
+ if (currentValue == null) {
+ // nothing to compare with
+ return true;
+ }
+ FeatureRow currentRow;
+ try {
+ currentRow = FeatureRow.parseFrom(currentValue);
+ } catch (InvalidProtocolBufferException e) {
+ // definitely need to replace current value
+ return true;
+ }
+
+ // check whether new row has later eventTimestamp
+ return new DateTime(currentRow.getEventTimestamp().getSeconds() * 1000L)
+ .isBefore(new DateTime(newRow.getEventTimestamp().getSeconds() * 1000L));
+ }
+
+ /** Deduplicates rows by key within a batch, keeping only the row with the latest eventTimestamp. */
+ Map<RedisKey, FeatureRow> deduplicateRows(
+ Iterable<FeatureRow> rows, Map<String, FeatureSetSpec> latestSpecs) {
+ Comparator<FeatureRow> byEventTimestamp =
+ Comparator.comparing(r -> r.getEventTimestamp().getSeconds());
+
+ FeatureRow identity =
+ FeatureRow.newBuilder()
+ .setEventTimestamp(
+ com.google.protobuf.Timestamp.newBuilder().setSeconds(-1).build())
+ .build();
+
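+ // the identity row carries eventTimestamp = -1, so it loses the maxBy
+ // comparison to any real row and never leaks into the result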
+ return Streams.stream(rows)
+ .collect(
+ Collectors.groupingBy(
+ row -> getKey(row, latestSpecs.get(row.getFeatureSet())),
+ Collectors.reducing(identity, BinaryOperator.maxBy(byEventTimestamp))));
+ }
+
+ Map<String, FeatureSetSpec> getLatestSpecs(Map<String, Iterable<FeatureSetSpec>> specs) {
+ return specs.entrySet().stream()
+ .map(e -> ImmutablePair.of(e.getKey(), Iterators.getLast(e.getValue().iterator())))
+ .collect(Collectors.toMap(ImmutablePair::getLeft, ImmutablePair::getRight));
}
}
}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java
index 6616a79aac..e9b1a5dc44 100644
--- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java
@@ -18,6 +18,8 @@
import feast.storage.common.retry.BackOffExecutor;
import java.io.Serializable;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
public interface RedisIngestionClient extends Serializable {
@@ -31,19 +33,9 @@ public interface RedisIngestionClient extends Serializable {
boolean isConnected();
- void sync();
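+ // awaits completion of the given futures (as returned by set/get below),
+ // flushing any buffered commands first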
+ void sync(Iterable<Future<?>> futures);
- void pexpire(byte[] key, Long expiryMillis);
+ CompletableFuture<String> set(byte[] key, byte[] value);
- void append(byte[] key, byte[] value);
-
- void set(byte[] key, byte[] value);
-
- void lpush(byte[] key, byte[] value);
-
- void rpush(byte[] key, byte[] value);
-
- void sadd(byte[] key, byte[] value);
-
- void zadd(byte[] key, Long score, byte[] value);
+ CompletableFuture<byte[]> get(byte[] key);
}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java
index 24591a1dc0..f0a2054b9b 100644
--- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java
@@ -21,12 +21,12 @@
import feast.storage.common.retry.BackOffExecutor;
import io.lettuce.core.LettuceFutures;
import io.lettuce.core.RedisClient;
-import io.lettuce.core.RedisFuture;
import io.lettuce.core.RedisURI;
import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.async.RedisAsyncCommands;
import io.lettuce.core.codec.ByteArrayCodec;
-import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.joda.time.Duration;
@@ -38,7 +38,6 @@ public class RedisStandaloneIngestionClient implements RedisIngestionClient {
private static final int DEFAULT_TIMEOUT = 2000;
private StatefulRedisConnection<byte[], byte[]> connection;
private RedisAsyncCommands<byte[], byte[]> commands;
- private List<RedisFuture> futures = Lists.newArrayList();
public RedisStandaloneIngestionClient(StoreProto.Store.RedisConfig redisConfig) {
this.host = redisConfig.getHost();
@@ -69,6 +68,9 @@ public void connect() {
if (!isConnected()) {
this.connection = this.redisclient.connect(new ByteArrayCodec());
this.commands = connection.async();
+
+ // enable pipelining of commands
+ this.commands.setAutoFlushCommands(false);
}
}
@@ -78,48 +80,20 @@ public boolean isConnected() {
}
@Override
- public void sync() {
- // Wait for some time for futures to complete
- // TODO: should this be configurable?
- try {
- LettuceFutures.awaitAll(60, TimeUnit.SECONDS, futures.toArray(new RedisFuture[0]));
- } finally {
- futures.clear();
- }
- }
-
- @Override
- public void pexpire(byte[] key, Long expiryMillis) {
- commands.pexpire(key, expiryMillis);
- }
-
- @Override
- public void append(byte[] key, byte[] value) {
- futures.add(commands.append(key, value));
- }
-
- @Override
- public void set(byte[] key, byte[] value) {
- futures.add(commands.set(key, value));
- }
+ public void sync(Iterable<Future<?>> futures) {
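+ // same contract as the cluster client: flush the pipeline, then await all futures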
+ this.connection.flushCommands();
- @Override
- public void lpush(byte[] key, byte[] value) {
- futures.add(commands.lpush(key, value));
- }
-
- @Override
- public void rpush(byte[] key, byte[] value) {
- futures.add(commands.rpush(key, value));
+ LettuceFutures.awaitAll(
+ 60, TimeUnit.SECONDS, Lists.newArrayList(futures).toArray(new Future[0]));
}
@Override
- public void sadd(byte[] key, byte[] value) {
- futures.add(commands.sadd(key, value));
+ public CompletableFuture<String> set(byte[] key, byte[] value) {
+ return commands.set(key, value).toCompletableFuture();
}
@Override
- public void zadd(byte[] key, Long score, byte[] value) {
- futures.add(commands.zadd(key, score, value));
+ public CompletableFuture<byte[]> get(byte[] key) {
+ return commands.get(key).toCompletableFuture();
}
}
diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisClusterFeatureSinkTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisClusterFeatureSinkTest.java
deleted file mode 100644
index 62ddfff3a7..0000000000
--- a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisClusterFeatureSinkTest.java
+++ /dev/null
@@ -1,539 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- * Copyright 2018-2019 The Feast Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package feast.storage.connectors.redis.writer;
-
-import static feast.storage.common.testing.TestUtil.field;
-import static feast.storage.common.testing.TestUtil.hash;
-import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.MatcherAssert.assertThat;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import com.google.protobuf.Timestamp;
-import feast.common.models.FeatureSetReference;
-import feast.proto.core.FeatureSetProto.EntitySpec;
-import feast.proto.core.FeatureSetProto.FeatureSetSpec;
-import feast.proto.core.FeatureSetProto.FeatureSpec;
-import feast.proto.core.StoreProto.Store.RedisClusterConfig;
-import feast.proto.storage.RedisProto.RedisKey;
-import feast.proto.types.FeatureRowProto.FeatureRow;
-import feast.proto.types.FieldProto.Field;
-import feast.proto.types.ValueProto.Value;
-import feast.proto.types.ValueProto.ValueType.Enum;
-import io.lettuce.core.RedisURI;
-import io.lettuce.core.cluster.RedisClusterClient;
-import io.lettuce.core.cluster.api.StatefulRedisClusterConnection;
-import io.lettuce.core.cluster.api.sync.RedisClusterCommands;
-import io.lettuce.core.codec.ByteArrayCodec;
-import java.io.File;
-import java.io.IOException;
-import java.nio.file.Paths;
-import java.util.*;
-import java.util.concurrent.ScheduledFuture;
-import java.util.concurrent.ScheduledThreadPoolExecutor;
-import java.util.concurrent.TimeUnit;
-import net.ishiis.redis.unit.RedisCluster;
-import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Count;
-import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.View;
-import org.apache.beam.sdk.transforms.windowing.*;
-import org.apache.beam.sdk.values.PCollection;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-
-public class RedisClusterFeatureSinkTest {
- @Rule public transient TestPipeline p = TestPipeline.create();
-
- private static String REDIS_CLUSTER_HOST = "localhost";
- private static int REDIS_CLUSTER_PORT1 = 6380;
- private static int REDIS_CLUSTER_PORT2 = 6381;
- private static int REDIS_CLUSTER_PORT3 = 6382;
- private static String CONNECTION_STRING = "localhost:6380,localhost:6381,localhost:6382";
- private RedisCluster redisCluster;
- private RedisClusterClient redisClusterClient;
- private RedisClusterCommands<byte[], byte[]> redisClusterCommands;
-
- private RedisFeatureSink redisClusterFeatureSink;
-
- @Before
- public void setUp() throws IOException {
- redisCluster = new RedisCluster(REDIS_CLUSTER_PORT1, REDIS_CLUSTER_PORT2, REDIS_CLUSTER_PORT3);
- redisCluster.start();
- redisClusterClient =
- RedisClusterClient.create(
- Arrays.asList(
- RedisURI.create(REDIS_CLUSTER_HOST, REDIS_CLUSTER_PORT1),
- RedisURI.create(REDIS_CLUSTER_HOST, REDIS_CLUSTER_PORT2),
- RedisURI.create(REDIS_CLUSTER_HOST, REDIS_CLUSTER_PORT3)));
- StatefulRedisClusterConnection<byte[], byte[]> connection =
- redisClusterClient.connect(new ByteArrayCodec());
- redisClusterCommands = connection.sync();
- redisClusterCommands.setTimeout(java.time.Duration.ofMillis(600000));
-
- FeatureSetSpec spec1 =
- FeatureSetSpec.newBuilder()
- .setName("fs")
- .setProject("myproject")
- .addEntities(EntitySpec.newBuilder().setName("entity").setValueType(Enum.INT64).build())
- .addFeatures(
- FeatureSpec.newBuilder().setName("feature").setValueType(Enum.STRING).build())
- .build();
-
- FeatureSetSpec spec2 =
- FeatureSetSpec.newBuilder()
- .setName("feature_set")
- .setProject("myproject")
- .addEntities(
- EntitySpec.newBuilder()
- .setName("entity_id_primary")
- .setValueType(Enum.INT32)
- .build())
- .addEntities(
- EntitySpec.newBuilder()
- .setName("entity_id_secondary")
- .setValueType(Enum.STRING)
- .build())
- .addFeatures(
- FeatureSpec.newBuilder().setName("feature_1").setValueType(Enum.STRING).build())
- .addFeatures(
- FeatureSpec.newBuilder().setName("feature_2").setValueType(Enum.INT64).build())
- .build();
-
- Map<FeatureSetReference, FeatureSetSpec> specMap =
- ImmutableMap.of(
- FeatureSetReference.of("myproject", "fs", 1), spec1,
- FeatureSetReference.of("myproject", "feature_set", 1), spec2);
- RedisClusterConfig redisClusterConfig =
- RedisClusterConfig.newBuilder()
- .setConnectionString(CONNECTION_STRING)
- .setInitialBackoffMs(2000)
- .setMaxRetries(4)
- .build();
-
- redisClusterFeatureSink =
- RedisFeatureSink.builder().setRedisClusterConfig(redisClusterConfig).build();
- redisClusterFeatureSink.prepareWrite(p.apply("Specs-1", Create.of(specMap)));
- }
-
- static boolean deleteDirectory(File directoryToBeDeleted) {
- File[] allContents = directoryToBeDeleted.listFiles();
- if (allContents != null) {
- for (File file : allContents) {
- deleteDirectory(file);
- }
- }
- return directoryToBeDeleted.delete();
- }
-
- @After
- public void teardown() {
- redisCluster.stop();
- redisClusterClient.shutdown();
- deleteDirectory(new File(String.valueOf(Paths.get(System.getProperty("user.dir"), ".redis"))));
- }
-
- @Test
- public void shouldWriteToRedis() {
-
- HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
- kvs.put(
- RedisKey.newBuilder()
- .setFeatureSet("myproject/fs")
- .addEntities(field("entity", 1, Enum.INT64))
- .build(),
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.getDefaultInstance())
- .addFields(
- Field.newBuilder()
- .setName(hash("feature"))
- .setValue(Value.newBuilder().setStringVal("one")))
- .build());
- kvs.put(
- RedisKey.newBuilder()
- .setFeatureSet("myproject/fs")
- .addEntities(field("entity", 2, Enum.INT64))
- .build(),
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.getDefaultInstance())
- .addFields(
- Field.newBuilder()
- .setName(hash("feature"))
- .setValue(Value.newBuilder().setStringVal("two")))
- .build());
-
- List<FeatureRow> featureRows =
- ImmutableList.of(
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/fs")
- .addFields(field("entity", 1, Enum.INT64))
- .addFields(field("feature", "one", Enum.STRING))
- .build(),
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/fs")
- .addFields(field("entity", 2, Enum.INT64))
- .addFields(field("feature", "two", Enum.STRING))
- .build());
-
- p.apply(Create.of(featureRows)).apply(redisClusterFeatureSink.writer());
- p.run();
-
- kvs.forEach(
- (key, value) -> {
- byte[] actual = redisClusterCommands.get(key.toByteArray());
- assertThat(actual, equalTo(value.toByteArray()));
- });
- }
-
- @Test(timeout = 15000)
- public void shouldRetryFailConnection() throws InterruptedException {
- HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
- kvs.put(
- RedisKey.newBuilder()
- .setFeatureSet("myproject/fs")
- .addEntities(field("entity", 1, Enum.INT64))
- .build(),
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.getDefaultInstance())
- .addFields(
- Field.newBuilder()
- .setName(hash("feature"))
- .setValue(Value.newBuilder().setStringVal("one")))
- .build());
-
- List<FeatureRow> featureRows =
- ImmutableList.of(
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/fs")
- .addFields(field("entity", 1, Enum.INT64))
- .addFields(field("feature", "one", Enum.STRING))
- .build());
-
- PCollection<Long> failedElementCount =
- p.apply(Create.of(featureRows))
- .apply(redisClusterFeatureSink.writer())
- .getFailedInserts()
- .apply(Count.globally());
-
- redisCluster.stop();
- final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1);
- ScheduledFuture<?> scheduledRedisRestart =
- redisRestartExecutor.schedule(
- () -> {
- redisCluster.start();
- },
- 3,
- TimeUnit.SECONDS);
-
- PAssert.that(failedElementCount).containsInAnyOrder(0L);
- p.run();
- scheduledRedisRestart.cancel(true);
-
- kvs.forEach(
- (key, value) -> {
- byte[] actual = redisClusterCommands.get(key.toByteArray());
- assertThat(actual, equalTo(value.toByteArray()));
- });
- }
-
- @Test
- public void shouldProduceFailedElementIfRetryExceeded() {
- RedisClusterConfig redisClusterConfig =
- RedisClusterConfig.newBuilder()
- .setConnectionString(CONNECTION_STRING)
- .setInitialBackoffMs(2000)
- .setMaxRetries(1)
- .build();
-
- FeatureSetSpec spec1 =
- FeatureSetSpec.newBuilder()
- .setName("fs")
- .setProject("myproject")
- .addEntities(EntitySpec.newBuilder().setName("entity").setValueType(Enum.INT64).build())
- .addFeatures(
- FeatureSpec.newBuilder().setName("feature").setValueType(Enum.STRING).build())
- .build();
- Map<String, FeatureSetSpec> specMap = ImmutableMap.of("myproject/fs", spec1);
- redisClusterFeatureSink =
- RedisFeatureSink.builder()
- .setRedisClusterConfig(redisClusterConfig)
- .build()
- .withSpecsView(p.apply("Specs-2", Create.of(specMap)).apply("View", View.asMultimap()));
-
- redisCluster.stop();
-
- List<FeatureRow> featureRows =
- ImmutableList.of(
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/fs")
- .addFields(field("entity", 1, Enum.INT64))
- .addFields(field("feature", "one", Enum.STRING))
- .build());
-
- PCollection<Long> failedElementCount =
- p.apply(Create.of(featureRows))
- .apply("modifiedSink", redisClusterFeatureSink.writer())
- .getFailedInserts()
- .apply(Count.globally());
-
- PAssert.that(failedElementCount).containsInAnyOrder(1L);
- p.run();
- }
-
- @Test
- public void shouldConvertRowWithDuplicateEntitiesToValidKey() {
-
- FeatureRow offendingRow =
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(2)))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .addFields(
- Field.newBuilder()
- .setName("feature_1")
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder()
- .setName("feature_2")
- .setValue(Value.newBuilder().setInt64Val(1001)))
- .build();
-
- RedisKey expectedKey =
- RedisKey.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .build();
-
- FeatureRow expectedValue =
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName(hash("feature_1"))
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder()
- .setName(hash("feature_2"))
- .setValue(Value.newBuilder().setInt64Val(1001)))
- .build();
-
- p.apply(Create.of(offendingRow)).apply(redisClusterFeatureSink.writer());
-
- p.run();
-
- byte[] actual = redisClusterCommands.get(expectedKey.toByteArray());
- assertThat(actual, equalTo(expectedValue.toByteArray()));
- }
-
- @Test
- public void shouldConvertRowWithOutOfOrderFieldsToValidKey() {
- FeatureRow offendingRow =
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addFields(
- Field.newBuilder()
- .setName("feature_2")
- .setValue(Value.newBuilder().setInt64Val(1001)))
- .addFields(
- Field.newBuilder()
- .setName("feature_1")
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .build();
-
- RedisKey expectedKey =
- RedisKey.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .build();
-
- List<Field> expectedFields =
- Arrays.asList(
- Field.newBuilder()
- .setName(hash("feature_1"))
- .setValue(Value.newBuilder().setStringVal("strValue1"))
- .build(),
- Field.newBuilder()
- .setName(hash("feature_2"))
- .setValue(Value.newBuilder().setInt64Val(1001))
- .build());
- FeatureRow expectedValue =
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addAllFields(expectedFields)
- .build();
-
- p.apply(Create.of(offendingRow)).apply(redisClusterFeatureSink.writer());
-
- p.run();
-
- byte[] actual = redisClusterCommands.get(expectedKey.toByteArray());
- assertThat(actual, equalTo(expectedValue.toByteArray()));
- }
-
- @Test
- public void shouldMergeDuplicateFeatureFields() {
- FeatureRow featureRowWithDuplicatedFeatureFields =
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .addFields(
- Field.newBuilder()
- .setName("feature_1")
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder()
- .setName("feature_1")
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder()
- .setName("feature_2")
- .setValue(Value.newBuilder().setInt64Val(1001)))
- .build();
-
- RedisKey expectedKey =
- RedisKey.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .build();
-
- FeatureRow expectedValue =
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName(hash("feature_1"))
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder()
- .setName(hash("feature_2"))
- .setValue(Value.newBuilder().setInt64Val(1001)))
- .build();
-
- p.apply(Create.of(featureRowWithDuplicatedFeatureFields))
- .apply(redisClusterFeatureSink.writer());
-
- p.run();
-
- byte[] actual = redisClusterCommands.get(expectedKey.toByteArray());
- assertThat(actual, equalTo(expectedValue.toByteArray()));
- }
-
- @Test
- public void shouldPopulateMissingFeatureValuesWithDefaultInstance() {
- FeatureRow featureRowWithDuplicatedFeatureFields =
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .addFields(
- Field.newBuilder()
- .setName("feature_1")
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .build();
-
- RedisKey expectedKey =
- RedisKey.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .build();
-
- FeatureRow expectedValue =
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName(hash("feature_1"))
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder().setName(hash("feature_2")).setValue(Value.getDefaultInstance()))
- .build();
-
- p.apply(Create.of(featureRowWithDuplicatedFeatureFields))
- .apply(redisClusterFeatureSink.writer());
-
- p.run();
-
- byte[] actual = redisClusterCommands.get(expectedKey.toByteArray());
- assertThat(actual, equalTo(expectedValue.toByteArray()));
- }
-}
diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java
index 948b8d0fda..12377fd1d1 100644
--- a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java
+++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java
@@ -20,63 +20,112 @@
import static feast.storage.common.testing.TestUtil.hash;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.collection.IsCollectionWithSize.hasSize;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.protobuf.Message;
import com.google.protobuf.Timestamp;
import feast.common.models.FeatureSetReference;
import feast.proto.core.FeatureSetProto.EntitySpec;
import feast.proto.core.FeatureSetProto.FeatureSetSpec;
import feast.proto.core.FeatureSetProto.FeatureSpec;
import feast.proto.core.StoreProto;
+import feast.proto.core.StoreProto.Store.RedisClusterConfig;
import feast.proto.core.StoreProto.Store.RedisConfig;
import feast.proto.storage.RedisProto.RedisKey;
import feast.proto.types.FeatureRowProto.FeatureRow;
import feast.proto.types.FieldProto.Field;
import feast.proto.types.ValueProto.Value;
import feast.proto.types.ValueProto.ValueType.Enum;
+import io.lettuce.core.AbstractRedisClient;
import io.lettuce.core.RedisClient;
import io.lettuce.core.RedisURI;
-import io.lettuce.core.api.StatefulRedisConnection;
import io.lettuce.core.api.sync.RedisStringCommands;
+import io.lettuce.core.cluster.RedisClusterClient;
import io.lettuce.core.codec.ByteArrayCodec;
-import java.io.IOException;
import java.util.*;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import net.ishiis.redis.unit.Redis;
+import net.ishiis.redis.unit.RedisCluster;
+import net.ishiis.redis.unit.RedisServer;
+import org.apache.beam.sdk.extensions.protobuf.ProtoCoder;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import redis.embedded.Redis;
-import redis.embedded.RedisServer;
+import org.junit.*;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+@RunWith(Parameterized.class)
public class RedisFeatureSinkTest {
@Rule public transient TestPipeline p = TestPipeline.create();
private static String REDIS_HOST = "localhost";
- private static int REDIS_PORT = 51234;
- private Redis redis;
- private RedisClient redisClient;
- private RedisStringCommands<byte[], byte[]> sync;
+ private static int REDIS_PORT = 51233;
+ private static Integer[] REDIS_CLUSTER_PORTS = {6380, 6381, 6382};
+ private RedisStringCommands<byte[], byte[]> sync;
private RedisFeatureSink redisFeatureSink;
private Map<FeatureSetReference, FeatureSetSpec> specMap;
- @Before
- public void setUp() throws IOException {
- redis = new RedisServer(REDIS_PORT);
- redis.start();
- redisClient =
+ @Parameterized.Parameters
+ public static Iterable<Object[]> backends() {
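+ // each parameter tuple bundles a test server, a matching client and a sink config,
+ // so every test below runs against both standalone Redis and a Redis cluster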
+ Redis redis = new RedisServer(REDIS_PORT);
+ RedisClient client =
RedisClient.create(new RedisURI(REDIS_HOST, REDIS_PORT, java.time.Duration.ofMillis(2000)));
- StatefulRedisConnection<byte[], byte[]> connection = redisClient.connect(new ByteArrayCodec());
- sync = connection.sync();
+
+ Redis redisCluster = new RedisCluster(REDIS_CLUSTER_PORTS);
+ RedisClusterClient clientCluster =
+ RedisClusterClient.create(
+ Lists.newArrayList(REDIS_CLUSTER_PORTS).stream()
+ .map(port -> RedisURI.create(REDIS_HOST, port))
+ .collect(Collectors.toList()));
+
+ StoreProto.Store.RedisConfig redisConfig =
+ StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build();
+
+ StoreProto.Store.RedisClusterConfig redisClusterConfig =
+ StoreProto.Store.RedisClusterConfig.newBuilder()
+ .setConnectionString(
+ Lists.newArrayList(REDIS_CLUSTER_PORTS).stream()
+ .map(port -> String.format("%s:%d", REDIS_HOST, port))
+ .collect(Collectors.joining(",")))
+ .setInitialBackoffMs(2000)
+ .setMaxRetries(4)
+ .build();
+
+ return Arrays.asList(
+ new Object[] {redis, client, redisConfig},
+ new Object[] {redisCluster, clientCluster, redisClusterConfig});
+ }
+
+ @Parameterized.Parameter(0)
+ public Redis redisServer;
+
+ @Parameterized.Parameter(1)
+ public AbstractRedisClient redisClient;
+
+ @Parameterized.Parameter(2)
+ public Message redisConfig;
+
+ @Before
+ public void setUp() {
+ redisServer.start();
+
+ if (redisClient instanceof RedisClient) {
+ sync = ((RedisClient) redisClient).connect(new ByteArrayCodec()).sync();
+ } else {
+ sync = ((RedisClusterClient) redisClient).connect(new ByteArrayCodec()).sync();
+ }
FeatureSetSpec spec1 =
FeatureSetSpec.newBuilder()
@@ -111,17 +160,42 @@ public void setUp() throws IOException {
ImmutableMap.of(
FeatureSetReference.of("myproject", "fs", 1), spec1,
FeatureSetReference.of("myproject", "feature_set", 1), spec2);
- StoreProto.Store.RedisConfig redisConfig =
- StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build();
- redisFeatureSink = RedisFeatureSink.builder().setRedisConfig(redisConfig).build();
+ RedisFeatureSink.Builder builder = RedisFeatureSink.builder();
+ if (redisConfig instanceof RedisConfig) {
+ builder = builder.setRedisConfig((RedisConfig) redisConfig);
+ } else {
+ builder = builder.setRedisClusterConfig((RedisClusterConfig) redisConfig);
+ }
+ redisFeatureSink = builder.build();
redisFeatureSink.prepareWrite(p.apply("Specs-1", Create.of(specMap)));
}
@After
- public void teardown() {
- redisClient.shutdown();
- redis.stop();
+ public void tearDown() {
+ if (redisServer.isActive()) {
+ redisServer.stop();
+ }
+ }
+
+ private RedisKey createRedisKey(String featureSetRef, Field... fields) {
+ return RedisKey.newBuilder()
+ .setFeatureSet(featureSetRef)
+ .addAllEntities(Lists.newArrayList(fields))
+ .build();
+ }
+
+ private FeatureRow createFeatureRow(String featureSetRef, Timestamp timestamp, Field... fields) {
+ FeatureRow.Builder builder = FeatureRow.newBuilder();
+ if (featureSetRef != null) {
+ builder.setFeatureSet(featureSetRef);
+ }
+
+ if (timestamp != null) {
+ builder.setEventTimestamp(timestamp);
+ }
+
+ return builder.addAllFields(Lists.newArrayList(fields)).build();
}
@Test
@@ -129,42 +203,26 @@ public void shouldWriteToRedis() {
HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
kvs.put(
- RedisKey.newBuilder()
- .setFeatureSet("myproject/fs")
- .addEntities(field("entity", 1, Enum.INT64))
- .build(),
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.getDefaultInstance())
- .addFields(
- Field.newBuilder()
- .setName(hash("feature"))
- .setValue(Value.newBuilder().setStringVal("one")))
- .build());
+ createRedisKey("myproject/fs", field("entity", 1, Enum.INT64)),
+ createFeatureRow(
+ null, Timestamp.getDefaultInstance(), field(hash("feature"), "one", Enum.STRING)));
kvs.put(
- RedisKey.newBuilder()
- .setFeatureSet("myproject/fs")
- .addEntities(field("entity", 2, Enum.INT64))
- .build(),
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.getDefaultInstance())
- .addFields(
- Field.newBuilder()
- .setName(hash("feature"))
- .setValue(Value.newBuilder().setStringVal("two")))
- .build());
+ createRedisKey("myproject/fs", field("entity", 2, Enum.INT64)),
+ createFeatureRow(
+ null, Timestamp.getDefaultInstance(), field(hash("feature"), "two", Enum.STRING)));
List<FeatureRow> featureRows =
ImmutableList.of(
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/fs")
- .addFields(field("entity", 1, Enum.INT64))
- .addFields(field("feature", "one", Enum.STRING))
- .build(),
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/fs")
- .addFields(field("entity", 2, Enum.INT64))
- .addFields(field("feature", "two", Enum.STRING))
- .build());
+ createFeatureRow(
+ "myproject/fs",
+ null,
+ field("entity", 1, Enum.INT64),
+ field("feature", "one", Enum.STRING)),
+ createFeatureRow(
+ "myproject/fs",
+ null,
+ field("entity", 2, Enum.INT64),
+ field("feature", "two", Enum.STRING)));
p.apply(Create.of(featureRows)).apply(redisFeatureSink.writer());
p.run();
@@ -176,7 +234,7 @@ public void shouldWriteToRedis() {
});
}
- @Test(timeout = 10000)
+ @Test(timeout = 30000)
public void shouldRetryFailConnection() throws InterruptedException {
RedisConfig redisConfig =
RedisConfig.newBuilder()
@@ -194,25 +252,17 @@ public void shouldRetryFailConnection() throws InterruptedException {
HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
kvs.put(
- RedisKey.newBuilder()
- .setFeatureSet("myproject/fs")
- .addEntities(field("entity", 1, Enum.INT64))
- .build(),
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.getDefaultInstance())
- .addFields(
- Field.newBuilder()
- .setName(hash("feature"))
- .setValue(Value.newBuilder().setStringVal("one")))
- .build());
+ createRedisKey("myproject/fs", field("entity", 1, Enum.INT64)),
+ createFeatureRow(
+ "", Timestamp.getDefaultInstance(), field(hash("feature"), "one", Enum.STRING)));
List<FeatureRow> featureRows =
ImmutableList.of(
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/fs")
- .addFields(field("entity", 1, Enum.INT64))
- .addFields(field("feature", "one", Enum.STRING))
- .build());
+ createFeatureRow(
+ "myproject/fs",
+ null,
+ field("entity", 1, Enum.INT64),
+ field("feature", "one", Enum.STRING)));
PCollection<Long> failedElementCount =
p.apply(Create.of(featureRows))
@@ -220,12 +270,12 @@ public void shouldRetryFailConnection() throws InterruptedException {
.getFailedInserts()
.apply(Count.globally());
- redis.stop();
+ redisServer.stop();
final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1);
ScheduledFuture<?> scheduledRedisRestart =
redisRestartExecutor.schedule(
() -> {
- redis.start();
+ redisServer.start();
},
3,
TimeUnit.SECONDS);
@@ -255,17 +305,9 @@ public void shouldProduceFailedElementIfRetryExceeded() {
HashMap<RedisKey, FeatureRow> kvs = new LinkedHashMap<>();
kvs.put(
- RedisKey.newBuilder()
- .setFeatureSet("myproject/fs")
- .addEntities(field("entity", 1, Enum.INT64))
- .build(),
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.getDefaultInstance())
- .addFields(
- Field.newBuilder()
- .setName(hash("feature"))
- .setValue(Value.newBuilder().setStringVal("one")))
- .build());
+ createRedisKey("myproject/fs", field("entity", 1, Enum.INT64)),
+ createFeatureRow(
+ "", Timestamp.getDefaultInstance(), field(hash("feature"), "one", Enum.STRING)));
List<FeatureRow> featureRows =
ImmutableList.of(
@@ -281,7 +323,7 @@ public void shouldProduceFailedElementIfRetryExceeded() {
.getFailedInserts()
.apply(Count.globally());
- redis.stop();
+ redisServer.stop();
PAssert.that(failedElementCount).containsInAnyOrder(1L);
p.run();
}
@@ -290,56 +332,27 @@ public void shouldProduceFailedElementIfRetryExceeded() {
public void shouldConvertRowWithDuplicateEntitiesToValidKey() {
FeatureRow offendingRow =
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(2)))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .addFields(
- Field.newBuilder()
- .setName("feature_1")
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder()
- .setName("feature_2")
- .setValue(Value.newBuilder().setInt64Val(1001)))
- .build();
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(10).build(),
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_primary", 2, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING),
+ field("feature_1", "strValue1", Enum.STRING),
+ field("feature_2", 1001, Enum.INT64));
RedisKey expectedKey =
- RedisKey.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .build();
+ createRedisKey(
+ "myproject/feature_set",
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING));
FeatureRow expectedValue =
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName(hash("feature_1"))
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder()
- .setName(hash("feature_2"))
- .setValue(Value.newBuilder().setInt64Val(1001)))
- .build();
+ createFeatureRow(
+ "",
+ Timestamp.newBuilder().setSeconds(10).build(),
+ field(hash("feature_1"), "strValue1", Enum.STRING),
+ field(hash("feature_2"), 1001, Enum.INT64));
p.apply(Create.of(offendingRow)).apply(redisFeatureSink.writer());
@@ -352,55 +365,26 @@ public void shouldConvertRowWithDuplicateEntitiesToValidKey() {
@Test
public void shouldConvertRowWithOutOfOrderFieldsToValidKey() {
FeatureRow offendingRow =
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addFields(
- Field.newBuilder()
- .setName("feature_2")
- .setValue(Value.newBuilder().setInt64Val(1001)))
- .addFields(
- Field.newBuilder()
- .setName("feature_1")
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .build();
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(10).build(),
+ field("entity_id_secondary", "a", Enum.STRING),
+ field("entity_id_primary", 1, Enum.INT32),
+ field("feature_2", 1001, Enum.INT64),
+ field("feature_1", "strValue1", Enum.STRING));
RedisKey expectedKey =
- RedisKey.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .build();
+ createRedisKey(
+ "myproject/feature_set",
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING));
- List<Field> expectedFields =
- Arrays.asList(
- Field.newBuilder()
- .setName(hash("feature_1"))
- .setValue(Value.newBuilder().setStringVal("strValue1"))
- .build(),
- Field.newBuilder()
- .setName(hash("feature_2"))
- .setValue(Value.newBuilder().setInt64Val(1001))
- .build());
FeatureRow expectedValue =
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addAllFields(expectedFields)
- .build();
+ createFeatureRow(
+ "",
+ Timestamp.newBuilder().setSeconds(10).build(),
+ field(hash("feature_1"), "strValue1", Enum.STRING),
+ field(hash("feature_2"), 1001, Enum.INT64));
p.apply(Create.of(offendingRow)).apply(redisFeatureSink.writer());
@@ -413,56 +397,27 @@ public void shouldConvertRowWithOutOfOrderFieldsToValidKey() {
@Test
public void shouldMergeDuplicateFeatureFields() {
FeatureRow featureRowWithDuplicatedFeatureFields =
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .addFields(
- Field.newBuilder()
- .setName("feature_1")
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder()
- .setName("feature_1")
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder()
- .setName("feature_2")
- .setValue(Value.newBuilder().setInt64Val(1001)))
- .build();
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(10).build(),
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING),
+ field("feature_2", 1001, Enum.INT64),
+ field("feature_1", "strValue1", Enum.STRING),
+ field("feature_1", "strValue1", Enum.STRING));
RedisKey expectedKey =
- RedisKey.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .build();
+ createRedisKey(
+ "myproject/feature_set",
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING));
FeatureRow expectedValue =
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName(hash("feature_1"))
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder()
- .setName(hash("feature_2"))
- .setValue(Value.newBuilder().setInt64Val(1001)))
- .build();
+ createFeatureRow(
+ "",
+ Timestamp.newBuilder().setSeconds(10).build(),
+ field(hash("feature_1"), "strValue1", Enum.STRING),
+ field(hash("feature_2"), 1001, Enum.INT64));
p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.writer());
@@ -475,46 +430,28 @@ public void shouldMergeDuplicateFeatureFields() {
@Test
public void shouldPopulateMissingFeatureValuesWithDefaultInstance() {
FeatureRow featureRowWithDuplicatedFeatureFields =
- FeatureRow.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addFields(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .addFields(
- Field.newBuilder()
- .setName("feature_1")
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .build();
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(10).build(),
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING),
+ field("feature_1", "strValue1", Enum.STRING));
RedisKey expectedKey =
- RedisKey.newBuilder()
- .setFeatureSet("myproject/feature_set")
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_primary")
- .setValue(Value.newBuilder().setInt32Val(1)))
- .addEntities(
- Field.newBuilder()
- .setName("entity_id_secondary")
- .setValue(Value.newBuilder().setStringVal("a")))
- .build();
+ createRedisKey(
+ "myproject/feature_set",
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING));
FeatureRow expectedValue =
- FeatureRow.newBuilder()
- .setEventTimestamp(Timestamp.newBuilder().setSeconds(10))
- .addFields(
- Field.newBuilder()
- .setName(hash("feature_1"))
- .setValue(Value.newBuilder().setStringVal("strValue1")))
- .addFields(
- Field.newBuilder().setName(hash("feature_2")).setValue(Value.getDefaultInstance()))
- .build();
+ createFeatureRow(
+ "",
+ Timestamp.newBuilder().setSeconds(10).build(),
+ field(hash("feature_1"), "strValue1", Enum.STRING),
+ Field.newBuilder()
+ .setName(hash("feature_2"))
+ .setValue(Value.getDefaultInstance())
+ .build());
p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.writer());
@@ -523,4 +460,206 @@ public void shouldPopulateMissingFeatureValuesWithDefaultInstance() {
byte[] actual = sync.get(expectedKey.toByteArray());
assertThat(actual, equalTo(expectedValue.toByteArray()));
}
+
+ @Test
+ public void shouldDeduplicateRowsWithinBatch() {
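+    // Three rows arrive in one batch for the same entity key; only the row with
+    // the latest event timestamp (seconds=20) should be compacted into the store.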
+ TestStream featureRowTestStream =
+ TestStream.create(ProtoCoder.of(FeatureRow.class))
+ .addElements(
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(20).build(),
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING),
+ field("feature_2", 111, Enum.INT32)))
+ .addElements(
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(10).build(),
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING),
+ field("feature_2", 222, Enum.INT32)))
+ .addElements(
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.getDefaultInstance(),
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING),
+ field("feature_2", 333, Enum.INT32)))
+ .advanceWatermarkToInfinity();
+
+ p.apply(featureRowTestStream).apply(redisFeatureSink.writer());
+ p.run();
+
+ RedisKey expectedKey =
+ createRedisKey(
+ "myproject/feature_set",
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING));
+
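+    // The winning row omits feature_1, so it is persisted with a default (empty) Value.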
+ FeatureRow expectedValue =
+ createFeatureRow(
+ "",
+ Timestamp.newBuilder().setSeconds(20).build(),
+ Field.newBuilder()
+ .setName(hash("feature_1"))
+ .setValue(Value.getDefaultInstance())
+ .build(),
+ field(hash("feature_2"), 111, Enum.INT32));
+
+ byte[] actual = sync.get(expectedKey.toByteArray());
+ assertThat(actual, equalTo(expectedValue.toByteArray()));
+ }
+
+ @Test
+  public void shouldWriteWithLaterTimestamp() {
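+    // Incoming rows should overwrite stored values only when their event timestamp
+    // is strictly later than the timestamp already stored under the key.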
+ TestStream featureRowTestStream =
+ TestStream.create(ProtoCoder.of(FeatureRow.class))
+ .addElements(
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(20).build(),
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING),
+ field("feature_2", 111, Enum.INT32)))
+ .addElements(
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(20).build(),
+ field("entity_id_primary", 2, Enum.INT32),
+ field("entity_id_secondary", "b", Enum.STRING),
+ field("feature_2", 222, Enum.INT32)))
+ .addElements(
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(10).build(),
+ field("entity_id_primary", 3, Enum.INT32),
+ field("entity_id_secondary", "c", Enum.STRING),
+ field("feature_2", 333, Enum.INT32)))
+ .advanceWatermarkToInfinity();
+
+ RedisKey keyA =
+ createRedisKey(
+ "myproject/feature_set",
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING));
+
+ RedisKey keyB =
+ createRedisKey(
+ "myproject/feature_set",
+ field("entity_id_primary", 2, Enum.INT32),
+ field("entity_id_secondary", "b", Enum.STRING));
+
+ RedisKey keyC =
+ createRedisKey(
+ "myproject/feature_set",
+ field("entity_id_primary", 3, Enum.INT32),
+ field("entity_id_secondary", "c", Enum.STRING));
+
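+    // Pre-populate the store: keyA already holds a newer timestamp (30s) than its
+    // incoming row (20s); keyB holds an older one (10s vs 20s); keyC holds an
+    // equal one (10s vs 10s), so only keyB should be overwritten.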
+ sync.set(
+ keyA.toByteArray(),
+ createFeatureRow("", Timestamp.newBuilder().setSeconds(30).build()).toByteArray());
+
+ sync.set(
+ keyB.toByteArray(),
+ createFeatureRow("", Timestamp.newBuilder().setSeconds(10).build()).toByteArray());
+
+ sync.set(
+ keyC.toByteArray(),
+ createFeatureRow("", Timestamp.newBuilder().setSeconds(10).build()).toByteArray());
+
+ p.apply(featureRowTestStream).apply(redisFeatureSink.writer());
+ p.run();
+
+ assertThat(
+ sync.get(keyA.toByteArray()),
+ equalTo(createFeatureRow("", Timestamp.newBuilder().setSeconds(30).build()).toByteArray()));
+
+ assertThat(
+ sync.get(keyB.toByteArray()),
+ equalTo(
+ createFeatureRow(
+ "",
+ Timestamp.newBuilder().setSeconds(20).build(),
+ Field.newBuilder()
+ .setName(hash("feature_1"))
+ .setValue(Value.getDefaultInstance())
+ .build(),
+ field(hash("feature_2"), 222, Enum.INT32))
+ .toByteArray()));
+
+ assertThat(
+ sync.get(keyC.toByteArray()),
+ equalTo(createFeatureRow("", Timestamp.newBuilder().setSeconds(10).build()).toByteArray()));
+ }
+
+ @Test
+ public void shouldOverwriteInvalidRows() {
+ TestStream featureRowTestStream =
+ TestStream.create(ProtoCoder.of(FeatureRow.class))
+ .addElements(
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(20).build(),
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING),
+ field("feature_1", "text", Enum.STRING),
+ field("feature_2", 111, Enum.INT32)))
+ .advanceWatermarkToInfinity();
+
+ RedisKey expectedKey =
+ createRedisKey(
+ "myproject/feature_set",
+ field("entity_id_primary", 1, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING));
+
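+    // Seed the key with bytes that cannot be parsed as a FeatureRow; the writer
+    // should treat the stored value as absent and overwrite it.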
+ sync.set(expectedKey.toByteArray(), "some-invalid-data".getBytes());
+
+ p.apply(featureRowTestStream).apply(redisFeatureSink.writer());
+ p.run();
+
+ FeatureRow expectedValue =
+ createFeatureRow(
+ "",
+ Timestamp.newBuilder().setSeconds(20).build(),
+ field(hash("feature_1"), "text", Enum.STRING),
+ field(hash("feature_2"), 111, Enum.INT32));
+
+ byte[] actual = sync.get(expectedKey.toByteArray());
+ assertThat(actual, equalTo(expectedValue.toByteArray()));
+ }
+
+ @Test
+ public void loadTest() {
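+    // Write 10000 rows with distinct primary entity ids and verify that every
+    // resulting key can be read back from Redis.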
+ List rows =
+ IntStream.range(0, 10000)
+ .mapToObj(
+ i ->
+ createFeatureRow(
+ "myproject/feature_set",
+ Timestamp.newBuilder().setSeconds(20).build(),
+ field("entity_id_primary", i, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING),
+ field("feature_1", "text", Enum.STRING),
+ field("feature_2", 111, Enum.INT32)))
+ .collect(Collectors.toList());
+
+ p.apply(Create.of(rows)).apply(redisFeatureSink.writer());
+ p.run();
+
+ List outcome =
+ IntStream.range(0, 10000)
+ .mapToObj(
+ i ->
+ createRedisKey(
+ "myproject/feature_set",
+ field("entity_id_primary", i, Enum.INT32),
+ field("entity_id_secondary", "a", Enum.STRING))
+ .toByteArray())
+ .map(sync::get)
+ .collect(Collectors.toList());
+
+ assertThat(outcome, hasSize(10000));
+ assertThat("All rows were stored", outcome.stream().allMatch(Objects::nonNull));
+ }
}
diff --git a/tests/e2e/redis/basic-ingest-redis-serving.py b/tests/e2e/redis/basic-ingest-redis-serving.py
index 1fcae69ed3..c1e25508d4 100644
--- a/tests/e2e/redis/basic-ingest-redis-serving.py
+++ b/tests/e2e/redis/basic-ingest-redis-serving.py
@@ -4,7 +4,7 @@
import tempfile
import time
import uuid
-from datetime import datetime
+from datetime import datetime, timedelta
import grpc
import numpy as np
@@ -821,6 +821,8 @@ def all_types_dataframe():
@pytest.mark.timeout(45)
@pytest.mark.run(order=20)
def test_all_types_register_feature_set_success(client):
+ client.set_project(PROJECT_NAME)
+
all_types_fs_expected = FeatureSet(
name="all_types",
entities=[Entity(name="user_id", dtype=ValueType.INT64)],
@@ -930,9 +932,11 @@ def try_get_features():
@pytest.mark.timeout(300)
-@pytest.mark.run(order=29)
+@pytest.mark.run(order=35)
def test_all_types_ingest_jobs(client, all_types_dataframe):
# list ingestion jobs given featureset
+ client.set_project(PROJECT_NAME)
+
all_types_fs = client.get_feature_set(name="all_types")
ingest_jobs = client.list_ingest_jobs(
feature_set_ref=FeatureSetRef.from_feature_set(all_types_fs)
@@ -990,7 +994,7 @@ def large_volume_dataframe():
@pytest.mark.timeout(45)
-@pytest.mark.run(order=30)
+@pytest.mark.run(order=40)
def test_large_volume_register_feature_set_success(client):
cust_trans_fs_expected = FeatureSet.from_yaml(
f"{DIR_PATH}/large_volume/cust_trans_large_fs.yaml"
@@ -1016,7 +1020,7 @@ def test_large_volume_register_feature_set_success(client):
@pytest.mark.timeout(300)
-@pytest.mark.run(order=31)
+@pytest.mark.run(order=41)
def test_large_volume_ingest_success(client, large_volume_dataframe):
# Get large volume feature set
cust_trans_fs = client.get_feature_set(name="customer_transactions_large")
@@ -1026,7 +1030,7 @@ def test_large_volume_ingest_success(client, large_volume_dataframe):
@pytest.mark.timeout(90)
-@pytest.mark.run(order=32)
+@pytest.mark.run(order=42)
def test_large_volume_retrieve_online_success(client, large_volume_dataframe):
# Poll serving for feature values until the correct values are returned
feature_refs = [
@@ -1112,7 +1116,7 @@ def all_types_parquet_file():
@pytest.mark.timeout(300)
-@pytest.mark.run(order=40)
+@pytest.mark.run(order=50)
def test_all_types_parquet_register_feature_set_success(client):
# Load feature set from file
all_types_parquet_expected = FeatureSet.from_yaml(
@@ -1140,7 +1144,7 @@ def test_all_types_parquet_register_feature_set_success(client):
@pytest.mark.timeout(600)
-@pytest.mark.run(order=41)
+@pytest.mark.run(order=51)
def test_all_types_infer_register_ingest_file_success(client, all_types_parquet_file):
# Get feature set
all_types_fs = client.get_feature_set(name="all_types_parquet")
@@ -1150,7 +1154,7 @@ def test_all_types_infer_register_ingest_file_success(client, all_types_parquet_
@pytest.mark.timeout(200)
-@pytest.mark.run(order=50)
+@pytest.mark.run(order=60)
def test_list_entities_and_features(client):
customer_entity = Entity("customer_id", ValueType.INT64)
driver_entity = Entity("driver_id", ValueType.INT64)
@@ -1225,7 +1229,7 @@ def test_list_entities_and_features(client):
@pytest.mark.timeout(900)
-@pytest.mark.run(order=60)
+@pytest.mark.run(order=70)
def test_sources_deduplicate_ingest_jobs(client):
source = KafkaSource("localhost:9092", "feast-features")
alt_source = KafkaSource("localhost:9092", "feast-data")
@@ -1273,6 +1277,58 @@ def get_running_jobs():
time.sleep(1)
+@pytest.mark.run(order=30)
+def test_sink_writes_only_recent_rows(client):
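+    # later_df and earlier_df cover the same driver ids, but later_df carries
+    # fresher event timestamps; the sink should end up serving later_df's values.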
+ client.set_project("default")
+
+ feature_refs = ["driver:rating", "driver:cost"]
+
+ later_df = basic_dataframe(
+ entities=["driver_id"],
+ features=["rating", "cost"],
+ ingest_time=datetime.utcnow(),
+ n_size=5,
+ )
+
+ earlier_df = basic_dataframe(
+ entities=["driver_id"],
+ features=["rating", "cost"],
+ ingest_time=datetime.utcnow() - timedelta(minutes=5),
+ n_size=5,
+ )
+
+ def try_get_features():
+ response = client.get_online_features(
+ entity_rows=[
+ GetOnlineFeaturesRequest.EntityRow(
+ fields={"driver_id": Value(int64_val=later_df.iloc[0]["driver_id"])}
+ )
+ ],
+ feature_refs=feature_refs,
+ ) # type: GetOnlineFeaturesResponse
+ is_ok = all(
+ [check_online_response(ref, later_df, response) for ref in feature_refs]
+ )
+ return response, is_ok
+
+    # compaction within a batch: older and newer rows are ingested together,
+    # and only the newer values should be written
+ client.ingest("driver", pd.concat([earlier_df, later_df]))
+ wait_retry_backoff(
+ retry_fn=try_get_features,
+ timeout_secs=90,
+ timeout_msg="Timed out trying to get online feature values",
+ )
+
+    # read-before-write: re-ingesting only the older rows must not overwrite
+    # the fresher values already stored
+ client.ingest("driver", earlier_df)
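+    # give ingestion time to process the older rows before polling serving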
+ time.sleep(10)
+ wait_retry_backoff(
+ retry_fn=try_get_features,
+ timeout_secs=90,
+ timeout_msg="Timed out trying to get online feature values",
+ )
+
+
# TODO: rewrite these using python SDK once the labels are implemented there
class TestsBasedOnGrpc:
GRPC_CONNECTION_TIMEOUT = 3