SPARKC-706: Add basic support for Cassandra vectors
jacek-lewandowski committed May 13, 2024
1 parent 6c6ce1b commit 972458c
Showing 12 changed files with 324 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ trait SparkCassandraITSpecBase

def pv = conn.withSessionDo(_.getContext.getProtocolVersion)

def report(message: String): Unit = alert(message)
def report(message: String): Unit = cancel(message)

val ks = getKsName

Expand All @@ -147,16 +147,23 @@ trait SparkCassandraITSpecBase

/** Skips the given test if the Cluster Version is lower or equal to the given `cassandra` Version or `dse` Version
* (if this is a DSE cluster) */
def from(cassandra: Version, dse: Version)(f: => Unit): Unit = {
def from(cassandra: Version, dse: Version)(f: => Unit): Unit = from(Some(cassandra), Some(dse))(f)
/**
 * Runs the given test against the minimum version declared for the kind of cluster in use.
 *
 * On a DSE cluster the `dse` version is handed to the single-version `from` overload, otherwise the
 * `cassandra` version is. When no minimum version is declared for the kind of cluster in use, the
 * test is not run and the reason is passed to `report`.
 *
 * @param cassandra minimum Cassandra version, or `None` when none is declared for Cassandra clusters
 * @param dse       minimum DSE version, or `None` when none is declared for DSE clusters
 * @param f         test body (by-name, so it is only evaluated if it is actually run)
 */
def from(cassandra: Option[Version] = None, dse: Option[Version] = None)(f: => Unit): Unit = {
  if (isDse(conn)) {
    dse match {
      case Some(dseVersion) => from(dseVersion)(f)
      // This branch is only reached on a DSE cluster, so the former "not DSE" message was misleading.
      case None => report("Skipped because the cluster is DSE and no minimum DSE version was given")
    }
  } else {
    cassandra match {
      case Some(cassandraVersion) => from(cassandraVersion)(f)
      // This branch is only reached on a non-DSE cluster, so the former "not Cassandra" message was misleading.
      case None => report("Skipped because the cluster is not DSE and no minimum Cassandra version was given")
    }
  }
}

/** Skips the given test if the Cluster Version is lower or equal to the given version */
def from(version: Version)(f: => Unit): Unit = {
private def from(version: Version)(f: => Unit): Unit = {
skip(cluster.getCassandraVersion, version) { f }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
| d14_varchar varchar,
| d15_varint varint,
| d16_address frozen<address>,
| d17_vector frozen<vector<int,3>>,
| PRIMARY KEY ((k1, k2, k3), c1, c2, c3)
Expand Down Expand Up @@ -111,12 +112,12 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {

"allow to read regular column definitions" in {
val columns = table.regularColumns
columns.size shouldBe 16
columns.size shouldBe 17 shouldBe Set(
"d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float",
"d6_inet", "d7_int", "d8_list", "d9_map", "d10_set",
"d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar",
"d15_varint", "d16_address")
"d15_varint", "d16_address", "d17_vector")

"allow to read proper types of columns" in {
Expand All @@ -136,6 +137,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
table.columnByName("d14_varchar").columnType shouldBe VarCharType
table.columnByName("d15_varint").columnType shouldBe VarIntType
table.columnByName("d16_address").columnType shouldBe a [UserDefinedType]
table.columnByName("d17_vector").columnType shouldBe VectorType[Int](IntType, 3)

"allow to list fields of a user defined type" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption
import com.datastax.oss.driver.api.core.cql.SimpleStatement
import com.datastax.oss.driver.api.core.cql.SimpleStatement._
import com.datastax.spark.connector._
import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V6_7_0, V3_6_0}
import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, DSE_V6_7_0, V3_6_0}
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf}
import com.datastax.spark.connector.mapper.{DefaultColumnMapper, JavaBeanColumnMapper, JavaTestBean, JavaTestUDTBean}
Expand Down Expand Up @@ -794,7 +794,7 @@ class CassandraRDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster
results should contain ((KeyGroup(3, 300), (3, 300, "0003")))

it should "allow the use of PER PARTITION LIMITs " in from(V3_6_0) {
it should "allow the use of PER PARTITION LIMITs " in from(cassandra = V3_6_0, dse = DSE_V5_1_0) {
val result = sc.cassandraTable(ks, "clustering_time").perPartitionLimit(1).collect
result.length should be (1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption._
import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, BoundStatement}
import com.datastax.oss.driver.api.core.{DefaultConsistencyLevel, DefaultProtocolVersion}
import com.datastax.spark.connector._
import com.datastax.spark.connector.ccm.CcmConfig.V3_6_0
import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, V3_6_0}
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.embedded.SparkTemplate._
Expand Down Expand Up @@ -425,7 +425,7 @@ class RDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster {


it should "should be joinable with a PER PARTITION LIMIT limit" in from(V3_6_0){
it should "should be joinable with a PER PARTITION LIMIT limit" in from(cassandra = V3_6_0, dse = DSE_V5_1_0){
val source = sc.parallelize(keys).map(x => (x, x * 100))
val someCass = source
.joinWithCassandraTable(ks, wideTable, joinColumns = SomeColumns("key", "group"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
package com.datastax.spark.connector.rdd.typeTests

import com.datastax.oss.driver.api.core.cql.Row
import com.datastax.oss.driver.api.core.{CqlSession, Version}
import com.datastax.spark.connector._
import com.datastax.spark.connector.ccm.CcmConfig
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.datasource.CassandraCatalog
import com.datastax.spark.connector.mapper.ColumnMapper
import com.datastax.spark.connector.rdd.{ReadConf, ValidRDDType}
import com.datastax.spark.connector.rdd.reader.RowReaderFactory
import com.datastax.spark.connector.types.TypeConverter
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.cassandra.{DataFrameReaderWrapper, DataFrameWriterWrapper}

import scala.collection.convert.ImplicitConversionsToScala._
import scala.collection.immutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe._
import scala.reflect._

abstract class VectorTypeTest[
ScalaType: ClassTag : TypeTag,
DriverType <: Number : ClassTag,
CaseClassType <: Product : ClassTag : TypeTag : ColumnMapper: RowReaderFactory : ValidRDDType](typeName: String) extends SparkCassandraITFlatSpecBase with DefaultCluster
/**
 * Skips the given test if the cluster is not Cassandra.
 *
 * This override only forwards to the inherited implementation.
 * NOTE(review): no behavior is added here, so the override looks redundant — confirm whether it is needed.
 */
override def cassandraOnly(f: => Unit): Unit = super.cassandraOnly(f)

/** Connector built from `sparkConf`; declared lazy, so it is only created on first access. */
override lazy val conn = CassandraConnector(sparkConf)

val VectorTable = "vectors"

def createVectorTable(session: CqlSession, table: String): Unit = {
| v VECTOR<$typeName, 3>

/** Minimum Cassandra version that the tests in this class pass to `from(...)`; defaults to 5.0-beta1. */
def minCassandraVersion: Option[Version] = Some(Version.parse("5.0-beta1"))

/** Minimum DSE version that the tests in this class pass to `from(...)`; `None` by default, i.e. none is declared. */
def minDSEVersion: Option[Version] = None

/** Builds a vector of the element type under test from the given integers; implemented by each concrete test class. */
def vectorFromInts(ints: Int*): Seq[ScalaType]

/** Builds the case class instance for the given id and vector; implemented by each concrete test class. */
def vectorItem(id: Int, v: Seq[ScalaType]): CaseClassType

/**
 * Spark session used by these tests: registers `CassandraCatalog` under the catalog name `casscatalog`
 * (the SQL read test queries `casscatalog.<keyspace>.<table>`), adds `CassandraSparkExtensions`, and
 * takes a new session from the obtained one via `newSession()`.
 */
override lazy val spark = SparkSession.builder()
.config("spark.sql.catalog.casscatalog", "com.datastax.spark.connector.datasource.CassandraCatalog")
.withExtensions(new CassandraSparkExtensions).getOrCreate().newSession()

override def beforeClass() {
conn.withSessionDo { session =>
|WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }"""

/**
 * Asserts that the given rows hold the expected vectors.
 *
 * The i-th expected vector (0-based) is compared with the `v` column of the row whose `id` is `i + 1`,
 * so callers must list the expected vectors in ascending id order starting at id 1.
 */
private def hasVectors(rows: List[Row], expectedVectors: Seq[Seq[ScalaType]]): Unit = {
val returnedVectors = for (i <- expectedVectors.indices) yield {
// `.get` throws NoSuchElementException when no row with that id was returned.
// The element class handed to `getVector` is taken from DriverType's ClassTag.
rows.find(_.getInt("id") == i + 1).get.getVector("v", implicitly[ClassTag[DriverType]].runtimeClass.asInstanceOf[Class[Number]]).iterator().toSeq

returnedVectors should contain theSameElementsInOrderAs expectedVectors

"SCC" should s"write case class instances with $typeName vector using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_caseclass_to_df"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
.cassandraFormat(table, ks)
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))

spark.createDataFrame(Seq(vectorItem(2, vectorFromInts(6, 5, 4)), vectorItem(3, vectorFromInts(7, 8, 9))))
.cassandraFormat(table, ks)
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(6, 5, 4), vectorFromInts(7, 8, 9)))

spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(9, 8, 7)), vectorItem(2, vectorFromInts(10, 11, 12))))
.cassandraFormat(table, ks)
.option("confirm.truncate", value = true)
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(9, 8, 7), vectorFromInts(10, 11, 12)))

it should s"write tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_tuple_to_df"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.createDataFrame(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.toDF("id", "v")
.cassandraFormat(table, ks)
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))

it should s"write case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_caseclass_to_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.sparkContext.parallelize(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))

it should s"write tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_tuple_to_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))

it should s"read case class instances with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_caseclass_from_df"
conn.withSessionDo { session =>
createVectorTable(session, table)
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

// Encoders for `.as[CaseClassType]` come from the session's implicits.
import spark.implicits._
// Read the table back through the DataFrame reader and compare it with the expected case class instances.
spark.read.cassandraFormat(table, ks).load().as[CaseClassType].collect() should contain theSameElementsAs
Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))

it should s"read tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_tuple_from_df"
conn.withSessionDo { session =>
createVectorTable(session, table)
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

// Encoders for `.as[(Int, Seq[ScalaType])]` come from the session's implicits.
import spark.implicits._
// Read the table back through the DataFrame reader and compare it with the expected (id, vector) tuples.
spark.read.cassandraFormat(table, ks).load().as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))

it should s"read case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_caseclass_from_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

spark.sparkContext.cassandraTable[CaseClassType](ks, table).collect() should contain theSameElementsAs
Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))

it should s"read tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_tuple_from_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

spark.sparkContext.cassandraTable[(Int, Seq[ScalaType])](ks, table).collect() should contain theSameElementsAs
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))

it should s"read rows with $typeName vectors using SQL API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_rows_from_sql"
conn.withSessionDo { session =>
createVectorTable(session, table)
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

import spark.implicits._
spark.sql(s"SELECT * FROM casscatalog.$ks.$table").as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))


class IntVectorTypeTest extends VectorTypeTest[Int, Integer, IntVectorItem]("INT") {
override def vectorFromInts(ints: Int*): Seq[Int] = ints

override def vectorItem(id: Int, v: Seq[Int]): IntVectorItem = IntVectorItem(id, v)

case class IntVectorItem(id: Int, v: Seq[Int])

class LongVectorTypeTest extends VectorTypeTest[Long, java.lang.Long, LongVectorItem]("BIGINT") {
/** Converts each given integer to a `Long` to form a BIGINT vector. */
override def vectorFromInts(ints: Int*): Seq[Long] = ints.map(_.toLong)

override def vectorItem(id: Int, v: Seq[Long]): LongVectorItem = LongVectorItem(id, v)

case class LongVectorItem(id: Int, v: Seq[Long])

class FloatVectorTypeTest extends VectorTypeTest[Float, java.lang.Float, FloatVectorItem]("FLOAT") {
/** Converts each given integer to a `Float` and adds 0.1, so the vector elements are not whole numbers. */
override def vectorFromInts(ints: Int*): Seq[Float] = ints.map(_.toFloat + 0.1f)

override def vectorItem(id: Int, v: Seq[Float]): FloatVectorItem = FloatVectorItem(id, v)

case class FloatVectorItem(id: Int, v: Seq[Float])

class DoubleVectorTypeTest extends VectorTypeTest[Double, java.lang.Double, DoubleVectorItem]("DOUBLE") {
/** Converts each given integer to a `Double` and adds 0.1, so the vector elements are not whole numbers. */
override def vectorFromInts(ints: Int*): Seq[Double] = ints.map(_.toDouble + 0.1d)

override def vectorItem(id: Int, v: Seq[Double]): DoubleVectorItem = DoubleVectorItem(id, v)

case class DoubleVectorItem(id: Int, v: Seq[Double])

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.datastax.spark.connector.datasource

import java.util.Locale
import com.datastax.oss.driver.api.core.ProtocolVersion
import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType}
import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType, VectorType}
import com.datastax.oss.driver.api.core.`type`.DataTypes._
import com.datastax.dse.driver.api.core.`type`.DseDataTypes._
import com.datastax.oss.driver.api.core.metadata.schema.{ColumnMetadata, RelationMetadata, TableMetadata}
Expand Down Expand Up @@ -167,6 +167,7 @@ object CassandraSourceUtil extends Logging {
case m: MapType => SparkSqlMapType(catalystDataType(m.getKeyType, nullable), catalystDataType(m.getValueType, nullable), nullable)
case udt: UserDefinedType => fromUdt(udt)
case t: TupleType => fromTuple(t)
case v: VectorType => ArrayType(catalystDataType(v.getElementType, nullable), nullable)
case VARINT =>
logWarning("VarIntType is mapped to catalystTypes.DecimalType with unlimited values.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ object DataTypeConverter extends Logging {
cassandraType match {
case connector.types.SetType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.ListType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.VectorType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.MapType(kt, vt, _) => catalystTypes.MapType(catalystDataType(kt, nullable), catalystDataType(vt, nullable), nullable)
case connector.types.UserDefinedType(_, fields, _) => catalystTypes.StructType(
case connector.types.TupleType(fields @ _* ) => catalystTypes.StructType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ class GettableDataToMappedTypeConverter[T : TypeTag : ColumnMapper](
val argConverter = converter(argColumnType, argScalaType)

case (VectorType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)

case (SetType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ object MappedToGettableDataConverter extends Logging {
val valueConverter = converter(valueColumnType, valueScalaType)
TypeConverter.javaHashMapConverter(keyConverter, valueConverter)

case (VectorType(argColumnType, dimension), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)

case (tt @ TupleType(argColumnType1, argColumnType2),
TypeRef(_, Symbols.PairSymbol, List(argScalaType1, argScalaType2))) =>
val c1 = converter(argColumnType1.columnType, argScalaType1)
Expand Down

0 comments on commit 972458c

