optimize redis hash value (#109)
Signed-off-by: Khor Shu Heng <khor.heng@gojek.com>

Co-authored-by: Khor Shu Heng <khor.heng@gojek.com>
khorshuheng committed Feb 17, 2022
1 parent 73b09f6 commit 642d2c9
Showing 7 changed files with 119 additions and 120 deletions.
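
At a glance, the change shrinks each entity's Redis hash: the event-timestamp field name is now stored as its 4-byte murmur3 hash rather than the raw "_ts:<namespace>" string, the per-namespace "_ex:<namespace>" expiry fields collapse into a single "_ex" field, and keys with no max age are persisted without an expiry instead of carrying a far-future timestamp. A minimal sketch of the new field-key scheme, assuming only Guava's Hashing utility (the same one the changed code relies on):

import java.nio.charset.StandardCharsets
import com.google.common.hash.Hashing

object FieldKeySketch {
  // Feature fields were already murmur3-hashed; the timestamp field key now is too.
  def timestampFieldKey(timestampPrefix: String, namespace: String): Array[Byte] =
    Hashing.murmur3_32
      .hashString(s"$timestampPrefix:$namespace", StandardCharsets.UTF_8)
      .asBytes

  // The expiry field is no longer namespaced; it is just the prefix bytes (e.g. "_ex").
  def expiryFieldKey(expiryPrefix: String): Array[Byte] =
    expiryPrefix.getBytes(StandardCharsets.UTF_8)
}
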
3 changes: 3 additions & 0 deletions .gitignore
@@ -189,3 +189,6 @@ sdk/python/docs/html
.bloop
.metals
*.code-workspace

spark/ingestion/src/test/resources/python/libs.tar.gz
spark/ingestion/src/test/resources/python/udf.pickle
@@ -39,7 +39,7 @@ class HashTypePersistence(config: SparkRedisConfig) extends Persistence with Ser

private def encodeRow(
value: Row,
maxExpiryTimestamp: java.sql.Timestamp
expiryTimestamp: Option[java.sql.Timestamp]
): Map[Array[Byte], Array[Byte]] = {
val fields = value.schema.fields.map(_.name)
val types = value.schema.fields.map(f => (f.name, f.dataType)).toMap
@@ -60,25 +60,23 @@ class HashTypePersistence(config: SparkRedisConfig) extends Persistence with Ser

val timestampHash = Seq(
(
timestampHashKey(config.namespace).getBytes,
timestampHashKey(config.namespace),
encodeValue(value.getAs[Timestamp](config.timestampColumn), TimestampType)
)
)

val expiryUnixTimestamp = {
if (config.maxAge > 0)
value.getAs[java.sql.Timestamp](config.timestampColumn).getTime + config.maxAge * 1000
else maxExpiryTimestamp.getTime
expiryTimestamp match {
case Some(expiry) =>
val expiryTimestampHash = Seq(
(
expiryTimestampHashKey(config.namespace),
encodeValue(expiry, TimestampType)
)
)
values ++ timestampHash ++ expiryTimestampHash
case None => values ++ timestampHash
}
val expiryTimestamp = new java.sql.Timestamp(expiryUnixTimestamp)
val expiryTimestampHash = Seq(
(
expiryTimestampHashKey(config.namespace).getBytes,
encodeValue(expiryTimestamp, TimestampType)
)
)

values ++ timestampHash ++ expiryTimestampHash
}

private def encodeValue(value: Any, `type`: DataType): Array[Byte] = {
@@ -90,12 +88,14 @@ class HashTypePersistence(config: SparkRedisConfig) extends Persistence with Ser
Hashing.murmur3_32.hashString(fullFeatureReference, StandardCharsets.UTF_8).asBytes()
}

private def timestampHashKey(namespace: String): String = {
s"${config.timestampPrefix}:${namespace}"
private def timestampHashKey(namespace: String): Array[Byte] = {
Hashing.murmur3_32
.hashString(s"${config.timestampPrefix}:${namespace}", StandardCharsets.UTF_8)
.asBytes
}

private def expiryTimestampHashKey(namespace: String): String = {
s"${config.expiryPrefix}:${namespace}"
private def expiryTimestampHashKey(namespace: String): Array[Byte] = {
config.expiryPrefix.getBytes()
}

private def decodeTimestamp(encodedTimestamp: Array[Byte]): java.sql.Timestamp = {
@@ -106,15 +106,16 @@ class HashTypePersistence(config: SparkRedisConfig) extends Persistence with Ser
pipeline: PipelineBinaryCommands,
key: Array[Byte],
row: Row,
expiryTimestamp: java.sql.Timestamp,
maxExpiryTimestamp: java.sql.Timestamp
expiryTimestamp: Option[java.sql.Timestamp]
): Unit = {
val value = encodeRow(row, maxExpiryTimestamp).asJava
val value = encodeRow(row, expiryTimestamp).asJava
pipeline.hset(key, value)
if (expiryTimestamp.equals(maxExpiryTimestamp)) {
pipeline.persist(key)
} else {
pipeline.expireAt(key, expiryTimestamp.getTime / 1000)

expiryTimestamp match {
case Some(expiry) =>
pipeline.expireAt(key, expiry.getTime / 1000)
case None =>
pipeline.persist(key)
}
}

@@ -130,7 +131,7 @@ class HashTypePersistence(config: SparkRedisConfig) extends Persistence with Ser
): Option[java.sql.Timestamp] = {
value.asScala.toMap
.map { case (key, value) =>
(key.map(_.toChar).mkString, value)
(wrapByteArray(key), value)
}
.get(timestampHashKey(config.namespace))
.map(value => decodeTimestamp(value))
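
Note that the getTimestamp lookup above only matches because each raw Array[Byte] field key is passed through Scala's wrapByteArray (from Predef), which yields a wrapper that compares by content rather than by reference. A tiny standalone sketch of why that wrapping is needed:

object ByteKeyLookupSketch {
  def main(args: Array[String]): Unit = {
    val a = Array[Byte](1, 2, 3)
    val b = Array[Byte](1, 2, 3)

    // Raw byte arrays use reference equality, so this lookup misses:
    println(Map(a -> "ts").get(b)) // None
    // wrapByteArray yields content-equal wrappers, so this lookup hits:
    println(Map(wrapByteArray(a) -> "ts").get(wrapByteArray(b))) // Some(ts)
  }
}
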
@@ -41,8 +41,7 @@ trait Persistence {
pipeline: PipelineBinaryCommands,
key: Array[Byte],
row: Row,
expiryTimestamp: Timestamp,
maxExpiryTimestamp: Timestamp
expiryTimestamp: Option[Timestamp]
): Unit

/**
@@ -16,7 +16,7 @@
*/
package feast.ingestion.stores.redis

import java.util
import java.{sql, util}
import com.google.protobuf.Timestamp
import com.google.protobuf.util.Timestamps
import feast.ingestion.utils.TypeConversion
@@ -49,8 +49,6 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC

override def schema: StructType = ???

val MAX_EXPIRED_TIMESTAMP = new java.sql.Timestamp(Timestamps.MAX_VALUE.getSeconds * 1000)

val persistence: Persistence = new HashTypePersistence(config)

val sparkConf: SparkConf = sqlContext.sparkContext.getConf
@@ -118,8 +116,7 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
writePipeline,
key.getBytes(),
row,
expiryTimestampByKey(key),
MAX_EXPIRED_TIMESTAMP
expiryTimestampByKey(key)
)
}
}
@@ -170,27 +167,29 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
private def newExpiryTimestamp(
row: Row,
value: util.Map[Array[Byte], Array[Byte]]
): java.sql.Timestamp = {
val maxExpiryOtherFeatureTables: Long = value.asScala.toMap
): Option[java.sql.Timestamp] = {
val currentMaxExpiry: Option[Long] = value.asScala.toMap
.map { case (key, value) =>
(key.map(_.toChar).mkString, value)
(wrapByteArray(key), value)
}
.filterKeys(_.startsWith(config.expiryPrefix))
.filterKeys(_.split(":").last != config.namespace)
.values
.map(value => Timestamp.parseFrom(value).getSeconds * 1000)
.reduceOption(_ max _)
.getOrElse(0)

val rowExpiry: Long =
if (config.maxAge > 0)
(row
.getAs[java.sql.Timestamp](config.timestampColumn)
.getTime + config.maxAge * 1000)
else MAX_EXPIRED_TIMESTAMP.getTime
.get(config.expiryPrefix.getBytes())
.map(Timestamp.parseFrom(_).getSeconds * 1000)

val maxExpiry = maxExpiryOtherFeatureTables max rowExpiry
new java.sql.Timestamp(maxExpiry)
val rowExpiry: Option[Long] =
if (config.maxAge > 0)
Some(
row
.getAs[java.sql.Timestamp](config.timestampColumn)
.getTime + config.maxAge * 1000
)
else None

(currentMaxExpiry, rowExpiry) match {
case (_, None) => None
case (None, Some(expiry)) => Some(new sql.Timestamp(expiry))
case (Some(currentExpiry), Some(newExpiry)) =>
Some(new sql.Timestamp(currentExpiry max newExpiry))
}

}
}
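
The match at the end of newExpiryTimestamp is the merge rule the updated tests below depend on: an incoming row with no max age clears the TTL, and otherwise the key keeps the later of the stored expiry and the freshly computed one, so re-ingesting with a smaller max age never shortens an existing TTL. A standalone sketch of that rule over epoch milliseconds (helper name hypothetical):

// Sketch only; mirrors the (currentMaxExpiry, rowExpiry) match above.
def mergeExpiryMillis(storedMax: Option[Long], incoming: Option[Long]): Option[Long] =
  (storedMax, incoming) match {
    case (_, None)                    => None // no max age: key should live forever
    case (None, Some(expiry))         => Some(expiry)
    case (Some(current), Some(fresh)) => Some(current max fresh)
  }
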
@@ -16,30 +16,26 @@
*/
package feast.ingestion

import java.nio.file.Paths
import java.sql.Timestamp
import collection.JavaConverters._
import com.dimafeng.testcontainers.{ForAllTestContainer, GenericContainer}
import com.google.protobuf.util.Timestamps
import feast.proto.types.ValueProto.ValueType
import org.apache.spark.{SparkConf, SparkEnv}
import org.joda.time.{DateTime, Seconds}
import org.scalacheck._
import org.scalatest._
import redis.clients.jedis.Jedis
import feast.ingestion.helpers.RedisStorageHelper._
import feast.ingestion.helpers.DataHelper._
import feast.ingestion.helpers.RedisStorageHelper._
import feast.ingestion.helpers.TestRow
import feast.ingestion.metrics.StatsDStub
import feast.ingestion.utils.TypeConversion
import feast.proto.storage.RedisProto.RedisKeyV2
import feast.proto.types.ValueProto
import feast.proto.types.ValueProto.ValueType
import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.{SparkConf, SparkEnv}
import org.joda.time.DateTime
import org.scalacheck._
import redis.clients.jedis.Jedis

import java.nio.file.Paths
import java.sql.Timestamp
import java.time.Instant
import java.time.temporal.ChronoUnit
import scala.collection.JavaConverters._

class BatchPipelineIT extends SparkSpec with ForAllTestContainer {

@@ -111,10 +107,9 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap
storedValues should beStoredRow(
Map(
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
"_ts:test-fs" -> r.eventTimestamp,
"_ex:test-fs" -> new Timestamp(Timestamps.MAX_VALUE.getSeconds * 1000)
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp
)
)
val keyTTL = jedis.ttl(encodedEntityKey).toInt
@@ -158,10 +153,10 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
new java.sql.Timestamp(r.eventTimestamp.getTime + 1000 * maxAge)
storedValues should beStoredRow(
Map(
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
"_ts:test-fs" -> r.eventTimestamp,
"_ex:test-fs" -> expectedExpiryTimestamp
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp,
"_ex" -> expectedExpiryTimestamp
)
)
val keyTTL = jedis.ttl(encodedEntityKey)
@@ -199,10 +194,9 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
featureKeyEncoder("feature2") -> r.feature2,
featureKeyEncoderSecondTable("feature1") -> r.feature1,
featureKeyEncoderSecondTable("feature2") -> r.feature2,
"_ts:test-fs" -> r.eventTimestamp,
"_ts:test-fs-2" -> r.eventTimestamp,
"_ex:test-fs" -> expectedExpiryTimestamp1,
"_ex:test-fs-2" -> expectedExpiryTimestamp2
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp,
murmurHashHexString("_ts:test-fs-2") -> r.eventTimestamp,
"_ex" -> expectedExpiryTimestamp2
)
)
val keyTTL = jedis.ttl(encodedEntityKey)
@@ -259,21 +253,19 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
featureKeyEncoder("feature2") -> r.feature2,
featureKeyEncoderSecondTable("feature1") -> r.feature1,
featureKeyEncoderSecondTable("feature2") -> r.feature2,
"_ts:test-fs" -> r.eventTimestamp,
"_ts:test-fs-2" -> r.eventTimestamp,
"_ex:test-fs" -> expectedExpiryTimestamp1,
"_ex:test-fs-2" -> expectedExpiryTimestamp2
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp,
murmurHashHexString("_ts:test-fs-2") -> r.eventTimestamp,
"_ex" -> expectedExpiryTimestamp1
)
)
val keyTTL = jedis.ttl(encodedEntityKey)
val toleranceMs = 10
val keyTTL = jedis.ttl(encodedEntityKey)
keyTTL should (be <= (expectedExpiryTimestamp1.getTime - ingestionTimeUnix) / 1000 and
be > (expectedExpiryTimestamp2.getTime - ingestionTimeUnix) / 1000)

})
}

"Redis key TTL" should "be updated, when the same feature table is re-ingested, with a smaller max age" in new Scope {
"Redis key TTL" should "not be updated, when the same feature table is re-ingested, with a smaller max age" in new Scope {
val startDate = new DateTime().minusDays(1).withTimeAtStartOfDay()
val endDate = new DateTime().withTimeAtStartOfDay()
val gen = rowGenerator(startDate, endDate)
@@ -308,13 +300,13 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
val encodedEntityKey = encodeEntityKey(r, config.featureTable)
val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap
val expiryTimestampAfterUpdate =
new java.sql.Timestamp(r.eventTimestamp.getTime + 1000 * reducedMaxAge)
new java.sql.Timestamp(r.eventTimestamp.getTime + 1000 * maxAge)
storedValues should beStoredRow(
Map(
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
"_ts:test-fs" -> r.eventTimestamp,
"_ex:test-fs" -> expiryTimestampAfterUpdate
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp,
"_ex" -> expiryTimestampAfterUpdate
)
)
val keyTTL = jedis.ttl(encodedEntityKey)
@@ -354,10 +346,9 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
val storedValues = jedis.hgetAll(encodedEntityKey).asScala.toMap
storedValues should beStoredRow(
Map(
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
"_ts:test-fs" -> r.eventTimestamp,
"_ex:test-fs" -> new Timestamp(Timestamps.MAX_VALUE.getSeconds * 1000)
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp
)
)
val keyTTL = jedis.ttl(encodedEntityKey).toInt
@@ -395,9 +386,9 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
val storedValues = jedis.hgetAll(encodeEntityKey(r, config.featureTable)).asScala.toMap
storedValues should beStoredRow(
Map(
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
"_ts:test-fs" -> r.eventTimestamp
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp
)
)
})
@@ -436,9 +427,9 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
val storedValues = jedis.hgetAll(encodeEntityKey(r, config.featureTable)).asScala.toMap
storedValues should beStoredRow(
Map(
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
"_ts:test-fs" -> r.eventTimestamp
featureKeyEncoder("feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp
)
)
})
@@ -509,9 +500,9 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
jedis.hgetAll(encodeEntityKey(r, configWithMapping.featureTable)).asScala.toMap
storedValues should beStoredRow(
Map(
featureKeyEncoder("new_feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
"_ts:test-fs" -> r.eventTimestamp
featureKeyEncoder("new_feature1") -> r.feature1,
featureKeyEncoder("feature2") -> r.feature2,
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp
)
)
})
@@ -543,9 +534,9 @@ class BatchPipelineIT extends SparkSpec with ForAllTestContainer {
jedis.hgetAll(encodeEntityKey(r, configWithMapping.featureTable)).asScala.toMap
storedValues should beStoredRow(
Map(
featureKeyEncoder("feature1") -> (r.feature1 + 1),
featureKeyEncoder("feature2") -> (r.feature1 + r.feature2 * 2),
"_ts:test-fs" -> r.eventTimestamp
featureKeyEncoder("feature1") -> (r.feature1 + 1),
featureKeyEncoder("feature2") -> (r.feature1 + r.feature2 * 2),
murmurHashHexString("_ts:test-fs") -> r.eventTimestamp
)
)
})
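
The assertions above use a murmurHashHexString helper whose definition lives in feast.ingestion.helpers.RedisStorageHelper and is not part of this diff; presumably it maps a field name such as "_ts:test-fs" to the lowercase-hex form of its murmur3 hash, so the expected map keys line up with the byte field keys written by HashTypePersistence. A hedged sketch under that assumption:

import java.nio.charset.StandardCharsets
import com.google.common.hash.Hashing

// Assumption, not the actual helper: hex-encode the same murmur3_32 hash used for field keys.
def murmurHashHexString(fieldName: String): String =
  Hashing.murmur3_32.hashString(fieldName, StandardCharsets.UTF_8).toString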