Commit

Refactor Spark ingestion to use Jedis instead of Spark Redis (#107)
* Replace spark-redis with the Jedis library

* Unify the cluster and single-node interfaces

Signed-off-by: Khor Shu Heng <khor.heng@go-jek.com>

* Use intersection type to simplify the common interface

Signed-off-by: Khor Shu Heng <khor.heng@go-jek.com>

* Handle the case where the config is not set

Signed-off-by: Khor Shu Heng <khor.heng@go-jek.com>

* Modify docstring

Signed-off-by: Khor Shu Heng <khor.heng@go-jek.com>
khorshuheng committed Feb 8, 2022
1 parent b2e7426 commit 0f4e27c
Showing 11 changed files with 199 additions and 90 deletions.
6 changes: 3 additions & 3 deletions spark/ingestion/pom.xml
@@ -172,9 +172,9 @@
</dependency>

<dependency>
<groupId>com.redislabs</groupId>
<artifactId>spark-redis_${scala.version}</artifactId>
<version>2.5.0</version>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
<version>4.1.1</version>
</dependency>

<dependency>
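A minimal sanity check for the swapped dependency might look like this (the object name, host, and port are placeholders, not part of the commit):

import redis.clients.jedis.Jedis

object JedisSmokeTest {
  def main(args: Array[String]): Unit = {
    val jedis = new Jedis("localhost", 6379) // placeholder endpoint
    try println(jedis.ping())                // expect "PONG" if Redis is reachable
    finally jedis.close()
  }
}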
@@ -42,7 +42,7 @@ object BasePipeline {
conf
.set("spark.redis.host", host)
.set("spark.redis.port", port.toString)
.set("spark.redis.auth", password)
.set("spark.redis.password", password)
.set("spark.redis.ssl", ssl.toString)
case BigTableConfig(projectId, instanceId) =>
conf
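spark-redis read the credential from spark.redis.auth, while the Jedis-based sink below reads spark.redis.password, hence the key rename. A hypothetical configuration for a password-protected instance (all values are placeholders):

import org.apache.spark.SparkConf

// An empty password means no AUTH command is sent.
val conf = new SparkConf()
  .set("spark.redis.host", "localhost")
  .set("spark.redis.port", "6379")
  .set("spark.redis.password", "s3cret")
  .set("spark.redis.ssl", "false")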
@@ -39,6 +39,8 @@ object BatchPipeline extends BasePipeline {
sparkSession: SparkSession,
config: IngestionJobConfig
): Option[StreamingQuery] = {
java.security.Security.setProperty("networkaddress.cache.ttl", "0");
java.security.Security.setProperty("networkaddress.cache.negative.ttl", "0");
val featureTable = config.featureTable
val projection =
BasePipeline.inputProjection(config.source, featureTable.features, featureTable.entities)
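The two networkaddress properties disable the JVM's positive and negative DNS caches, presumably so a Redis endpoint that moves behind a stable DNS name (for example a Kubernetes Service) is re-resolved on every lookup instead of being pinned to a stale address; the commit does not state the motivation. A sketch of the effect, with a hypothetical hostname:

import java.net.InetAddress
import java.security.Security

// "0" = never cache resolved (or failed) lookups; ideally set before the first lookup.
Security.setProperty("networkaddress.cache.ttl", "0")
Security.setProperty("networkaddress.cache.negative.ttl", "0")

// Each call now performs a fresh DNS query.
val addr = InetAddress.getByName("redis.default.svc.cluster.local")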
@@ -0,0 +1,41 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2018-2022 The Feast Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package feast.ingestion.stores.redis

import redis.clients.jedis.{ClusterPipeline, DefaultJedisClientConfig, HostAndPort}
import redis.clients.jedis.providers.ClusterConnectionProvider

import scala.collection.JavaConverters._

/**
* Provide pipeline for Redis cluster.
*/
case class ClusterPipelineProvider(endpoint: RedisEndpoint) extends PipelineProvider {

val nodes = Set(new HostAndPort(endpoint.host, endpoint.port)).asJava
val DEFAULT_CLIENT_CONFIG = DefaultJedisClientConfig
.builder()
.password(endpoint.password)
.build()
val provider = new ClusterConnectionProvider(nodes, DEFAULT_CLIENT_CONFIG)

/**
* @return a cluster pipeline
*/
override def pipeline(): UnifiedPipeline = new ClusterPipeline(provider)

}
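A hypothetical call site for the new provider, mirroring how RedisSinkRelation drives pipelines below (queue commands, close() to flush, then read the responses; the endpoint is a placeholder):

import java.nio.charset.StandardCharsets

val provider = ClusterPipelineProvider(RedisEndpoint("redis-cluster.internal", 6379, password = ""))
val pipeline = provider.pipeline()
val reply    = pipeline.hgetAll("feature-row-key".getBytes(StandardCharsets.UTF_8))
pipeline.close()           // flushes queued commands; responses become readable
val fields   = reply.get() // java.util.Map[Array[Byte], Array[Byte]]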
@@ -18,13 +18,13 @@ package feast.ingestion.stores.redis

import java.nio.charset.StandardCharsets
import java.util

import com.google.common.hash.Hashing
import com.google.protobuf.Timestamp
import feast.ingestion.utils.TypeConversion
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import redis.clients.jedis.{Pipeline, Response}
import redis.clients.jedis.commands.PipelineBinaryCommands
import redis.clients.jedis.Response

import scala.jdk.CollectionConverters._

@@ -103,7 +103,7 @@ class HashTypePersistence(config: SparkRedisConfig) extends Persistence with Ser
}

override def save(
pipeline: Pipeline,
pipeline: PipelineBinaryCommands,
key: Array[Byte],
row: Row,
expiryTimestamp: java.sql.Timestamp,
@@ -119,7 +119,7 @@
}

override def get(
pipeline: Pipeline,
pipeline: PipelineBinaryCommands,
key: Array[Byte]
): Response[util.Map[Array[Byte], Array[Byte]]] = {
pipeline.hgetAll(key)
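Widening the parameter from the concrete Pipeline class to the PipelineBinaryCommands interface is what lets one persistence implementation drive both single-node and cluster pipelines. A sketch of the resulting calling convention (names as used in RedisSinkRelation below):

val pipeline = pipelineProvider.pipeline()
val reply    = persistence.get(pipeline, key.toByteArray) // only queued at this point
pipeline.close()                                          // flush: commands execute
val stored   = reply.get()                                // now safe to read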
@@ -18,9 +18,9 @@ package feast.ingestion.stores.redis

import java.sql.Timestamp
import java.util

import org.apache.spark.sql.Row
import redis.clients.jedis.{Pipeline, Response}
import redis.clients.jedis.commands.PipelineBinaryCommands
import redis.clients.jedis.Response

/**
* Determine how a Spark row should be serialized and stored on Redis.
@@ -38,7 +38,7 @@ trait Persistence {
* is equal to the maxExpiryTimestamp
*/
def save(
pipeline: Pipeline,
pipeline: PipelineBinaryCommands,
key: Array[Byte],
row: Row,
expiryTimestamp: Timestamp,
@@ -56,7 +56,7 @@
* @return Redis response representing the row value
*/
def get(
pipeline: Pipeline,
pipeline: PipelineBinaryCommands,
key: Array[Byte]
): Response[util.Map[Array[Byte], Array[Byte]]]

@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2018-2022 The Feast Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package feast.ingestion.stores.redis

import redis.clients.jedis.commands.PipelineBinaryCommands

import java.io.Closeable

/**
* Provide either a pipeline or cluster pipeline to read and write data into Redis.
*/
trait PipelineProvider {

type UnifiedPipeline = PipelineBinaryCommands with Closeable

/**
* @return an interface for executing pipeline commands
*/
def pipeline(): UnifiedPipeline
}
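UnifiedPipeline is the "intersection type" from the commit message: no single supertype of Jedis's Pipeline and ClusterPipeline exposes the pipelined binary commands together with close(), but both satisfy PipelineBinaryCommands with Closeable, so the compound type lets callers stay agnostic of the deployment mode. A small loan-pattern helper one could layer on top (hypothetical, not part of the commit):

// Runs f against a fresh pipeline and always flushes/closes it afterwards.
def withPipeline[A](provider: PipelineProvider)(f: provider.UnifiedPipeline => A): A = {
  val pipeline = provider.pipeline()
  try f(pipeline)
  finally pipeline.close()
}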
@@ -0,0 +1,19 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2018-2022 The Feast Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package feast.ingestion.stores.redis

case class RedisEndpoint(host: String, port: Int, password: String)
@@ -19,8 +19,6 @@ package feast.ingestion.stores.redis
import java.util
import com.google.protobuf.Timestamp
import com.google.protobuf.util.Timestamps
import com.redislabs.provider.redis.util.PipelineUtils.{foreachWithPipeline, mapWithPipeline}
import com.redislabs.provider.redis.{ReadWriteConfig, RedisConfig, RedisEndpoint, RedisNode}
import feast.ingestion.utils.TypeConversion
import feast.proto.storage.RedisProto.RedisKeyV2
import feast.proto.types.ValueProto
@@ -30,7 +28,7 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import redis.clients.jedis.util.JedisClusterCRC16
import redis.clients.jedis.Jedis

import scala.collection.JavaConverters._

@@ -68,66 +66,65 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
else data

dataToStore.foreachPartition { partition: Iterator[Row] =>
// refresh redis cluster topology for each batch
implicit val redisConfig: RedisConfig = {
new RedisConfig(
new RedisEndpoint(sparkConf)
)
}

implicit val readWriteConfig: ReadWriteConfig = {
ReadWriteConfig.fromSparkConf(sparkConf)
val endpoint = RedisEndpoint(
host = sparkConf.get("spark.redis.host"),
port = sparkConf.get("spark.redis.port").toInt,
password = sparkConf.get("spark.redis.password", "")
)
val jedis = new Jedis(endpoint.host, endpoint.port)
if (endpoint.password.nonEmpty) {
jedis.auth(endpoint.password)
}
val clusterEnabled =
jedis.configGet("cluster-enabled").asScala.toList.reverse.headOption.contains("yes")
val pipelineProvider =
if (clusterEnabled) ClusterPipelineProvider(endpoint) else SingleNodePipelineProvider(jedis)

// grouped iterator to only allocate memory for a portion of rows
partition.grouped(config.iteratorGroupingSize).foreach { batch =>
// group by key and keep only latest row per each key
val rowsWithKey: Map[RedisKeyV2, Row] =
compactRowsToLatestTimestamp(batch.map(row => dataKeyId(row) -> row)).toMap

groupKeysByNode(redisConfig.hosts, rowsWithKey.keysIterator).foreach { case (node, keys) =>
val conn = node.connect()
// retrieve latest stored values
val storedValues = mapWithPipeline(conn, keys) { (pipeline, key) =>
persistence.get(pipeline, key.toByteArray)
}.map(_.asInstanceOf[util.Map[Array[Byte], Array[Byte]]])

val timestamps = storedValues.map(persistence.storedTimestamp)
val timestampByKey = keys.zip(timestamps).toMap

val expiryTimestampByKey = keys
.zip(storedValues)
.map { case (key, storedValue) =>
(key, newExpiryTimestamp(rowsWithKey(key), storedValue))
}
.toMap

foreachWithPipeline(conn, keys) { (pipeline, key) =>
val row = rowsWithKey(key)

timestampByKey(key) match {
case Some(t) if (t.after(row.getAs[java.sql.Timestamp](config.timestampColumn))) =>
()
case _ =>
if (metricSource.nonEmpty) {
val lag = System.currentTimeMillis() - row
.getAs[java.sql.Timestamp](config.timestampColumn)
.getTime

metricSource.get.METRIC_TOTAL_ROWS_INSERTED.inc()
metricSource.get.METRIC_ROWS_LAG.update(lag)
}
persistence.save(
pipeline,
key.toByteArray,
row,
expiryTimestampByKey(key),
MAX_EXPIRED_TIMESTAMP
)
}
val keys = rowsWithKey.keysIterator.toList
val readPipeline = pipelineProvider.pipeline()
val readResponses =
keys.map(key => persistence.get(readPipeline, key.toByteArray))
readPipeline.close()
val storedValues = readResponses.map(_.get())
val timestamps = storedValues.map(persistence.storedTimestamp)
val timestampByKey = keys.zip(timestamps).toMap
val expiryTimestampByKey = keys
.zip(storedValues)
.map { case (key, storedValue) =>
(key, newExpiryTimestamp(rowsWithKey(key), storedValue))
}
.toMap

val writePipeline = pipelineProvider.pipeline()
rowsWithKey.foreach { case (key, row) =>
timestampByKey(key) match {
case Some(t) if (t.after(row.getAs[java.sql.Timestamp](config.timestampColumn))) =>
()
case _ =>
if (metricSource.nonEmpty) {
val lag = System.currentTimeMillis() - row
.getAs[java.sql.Timestamp](config.timestampColumn)
.getTime

metricSource.get.METRIC_TOTAL_ROWS_INSERTED.inc()
metricSource.get.METRIC_ROWS_LAG.update(lag)
}
persistence.save(
writePipeline,
key.toByteArray,
row,
expiryTimestampByKey(key),
MAX_EXPIRED_TIMESTAMP
)
}
conn.close()
}
writePipeline.close()
}
}
dataToStore.unpersist()
@@ -175,24 +172,6 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
}
}

private def groupKeysByNode(
nodes: Array[RedisNode],
keys: Iterator[RedisKeyV2]
): Iterator[(RedisNode, Array[RedisKeyV2])] = {
keys
.map(key => (getMasterNode(nodes, key), key))
.toArray
.groupBy(_._1)
.map(x => (x._1, x._2.map(_._2)))
.iterator
}

private def getMasterNode(nodes: Array[RedisNode], key: RedisKeyV2): RedisNode = {
val slot = JedisClusterCRC16.getSlot(key.toByteArray)

nodes.filter { node => node.startSlot <= slot && node.endSlot >= slot }.filter(_.idx == 0)(0)
}

private def newExpiryTimestamp(
row: Row,
value: util.Map[Array[Byte], Array[Byte]]
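The deleted groupKeysByNode and getMasterNode helpers routed each key to its master node by CRC16 hash slot; Jedis's ClusterPipeline performs that routing internally, which is why the new write loop can iterate rowsWithKey directly. For reference, the slot computation the removed helpers relied on:

import redis.clients.jedis.util.JedisClusterCRC16

// Redis Cluster assigns every key to one of 16384 slots via CRC16 mod 16384.
val slot: Int = JedisClusterCRC16.getSlot(key.toByteArray)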
@@ -0,0 +1,31 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2018-2022 The Feast Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package feast.ingestion.stores.redis

import redis.clients.jedis.Jedis

/**
* Provide pipeline for single node Redis.
*/
case class SingleNodePipelineProvider(jedis: Jedis) extends PipelineProvider {

/**
* @return a single node redis pipeline
*/
override def pipeline(): UnifiedPipeline = jedis.pipelined()

}
