SPARKC-706: Add basic support for Cassandra vectors
jacek-lewandowski committed May 13, 2024
1 parent 6c6ce1b commit 972458c
Showing 12 changed files with 324 additions and 14 deletions.
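For orientation, this change teaches the connector to map CQL vector columns to Spark array columns and Scala sequences. Below is a minimal usage sketch of the kind of round trip the new tests exercise; it is not code from this commit, and the keyspace test and table items (id int PRIMARY KEY, v vector<float, 3>) are assumed to exist on a cluster the session is already configured for.

import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.cassandra.{DataFrameReaderWrapper, DataFrameWriterWrapper}

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// Write: each Seq[Float] lands in the vector<float, 3> column "v".
Seq((1, Seq(1.1f, 2.1f, 3.1f)), (2, Seq(4.1f, 5.1f, 6.1f)))
  .toDF("id", "v")
  .write
  .cassandraFormat("items", "test")
  .mode(SaveMode.Append)
  .save()

// Read: the vector column comes back as an array column (one Seq[Float] per row).
val df = spark.read.cassandraFormat("items", "test").load()
df.show()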
@@ -131,7 +131,7 @@ trait SparkCassandraITSpecBase

def pv = conn.withSessionDo(_.getContext.getProtocolVersion)

-  def report(message: String): Unit = alert(message)
+  def report(message: String): Unit = cancel(message)

val ks = getKsName

@@ -147,16 +147,23 @@ trait SparkCassandraITSpecBase

/** Skips the given test if the Cluster Version is lower or equal to the given `cassandra` Version or `dse` Version
* (if this is a DSE cluster) */
-  def from(cassandra: Version, dse: Version)(f: => Unit): Unit = {
+  def from(cassandra: Version, dse: Version)(f: => Unit): Unit = from(Some(cassandra), Some(dse))(f)
+  def from(cassandra: Option[Version] = None, dse: Option[Version] = None)(f: => Unit): Unit = {
     if (isDse(conn)) {
-      from(dse)(f)
+      dse match {
+        case Some(dseVersion) => from(dseVersion)(f)
+        case None => report(s"Skipped because not DSE")
+      }
     } else {
-      from(cassandra)(f)
+      cassandra match {
+        case Some(cassandraVersion) => from(cassandraVersion)(f)
+        case None => report(s"Skipped because not Cassandra")
+      }
     }
   }

   /** Skips the given test if the Cluster Version is lower or equal to the given version */
-  def from(version: Version)(f: => Unit): Unit = {
+  private def from(version: Version)(f: => Unit): Unit = {
     skip(cluster.getCassandraVersion, version) { f }
   }
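With the reworked helper, a test can be gated on a minimum Cassandra version while declaring no DSE support at all; on a DSE cluster the test is then reported as cancelled instead of being run or failed. An illustrative fragment (inside a spec that mixes in this trait; the version literal mirrors the new vector tests later in this diff):

import com.datastax.oss.driver.api.core.Version

// Illustrative only: run from Cassandra 5.0-beta1 upwards, cancel on any DSE cluster.
val minCassandraVersion: Option[Version] = Some(Version.parse("5.0-beta1"))
val minDSEVersion: Option[Version] = None

it should "read vector columns" in from(minCassandraVersion, minDSEVersion) {
  // test body
}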

@@ -40,6 +40,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
| d14_varchar varchar,
| d15_varint varint,
| d16_address frozen<address>,
+        | d17_vector frozen<vector<int,3>>,
| PRIMARY KEY ((k1, k2, k3), c1, c2, c3)
|)
""".stripMargin)
@@ -111,12 +112,12 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {

"allow to read regular column definitions" in {
val columns = table.regularColumns
-      columns.size shouldBe 16
+      columns.size shouldBe 17
columns.map(_.columnName).toSet shouldBe Set(
"d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float",
"d6_inet", "d7_int", "d8_list", "d9_map", "d10_set",
"d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar",
"d15_varint", "d16_address")
"d15_varint", "d16_address", "d17_vector")
}

"allow to read proper types of columns" in {
@@ -136,6 +137,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster {
table.columnByName("d14_varchar").columnType shouldBe VarCharType
table.columnByName("d15_varint").columnType shouldBe VarIntType
table.columnByName("d16_address").columnType shouldBe a [UserDefinedType]
table.columnByName("d17_vector").columnType shouldBe VectorType[Int](IntType, 3)
}

"allow to list fields of a user defined type" in {
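For reference, the VectorType used in the new assertion is the connector-side model of a vector column: an element type plus a fixed dimension. A small sketch of its construction and of the destructuring pattern also used by the converter changes later in this commit:

import com.datastax.spark.connector.types.{IntType, VectorType}

// The type asserted for d17_vector above: a 3-dimensional vector of ints.
val d17Type = VectorType(IntType, 3)

d17Type match {
  case VectorType(elementType, dimension) =>
    println(s"element type: $elementType, dimension: $dimension") // IntType, 3
}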
@@ -9,7 +9,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption
import com.datastax.oss.driver.api.core.cql.SimpleStatement
import com.datastax.oss.driver.api.core.cql.SimpleStatement._
import com.datastax.spark.connector._
-import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V6_7_0, V3_6_0}
+import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, DSE_V6_7_0, V3_6_0}
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf}
import com.datastax.spark.connector.mapper.{DefaultColumnMapper, JavaBeanColumnMapper, JavaTestBean, JavaTestUDTBean}
@@ -794,7 +794,7 @@ class CassandraRDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster
results should contain ((KeyGroup(3, 300), (3, 300, "0003")))
}

it should "allow the use of PER PARTITION LIMITs " in from(V3_6_0) {
it should "allow the use of PER PARTITION LIMITs " in from(cassandra = V3_6_0, dse = DSE_V5_1_0) {
val result = sc.cassandraTable(ks, "clustering_time").perPartitionLimit(1).collect
result.length should be (1)
}
@@ -5,7 +5,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption._
import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, BoundStatement}
import com.datastax.oss.driver.api.core.{DefaultConsistencyLevel, DefaultProtocolVersion}
import com.datastax.spark.connector._
-import com.datastax.spark.connector.ccm.CcmConfig.V3_6_0
+import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, V3_6_0}
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.embedded.SparkTemplate._
@@ -425,7 +425,7 @@ class RDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster {

}

it should "should be joinable with a PER PARTITION LIMIT limit" in from(V3_6_0){
it should "should be joinable with a PER PARTITION LIMIT limit" in from(cassandra = V3_6_0, dse = DSE_V5_1_0){
val source = sc.parallelize(keys).map(x => (x, x * 100))
val someCass = source
.joinWithCassandraTable(ks, wideTable, joinColumns = SomeColumns("key", "group"))
@@ -0,0 +1,242 @@
package com.datastax.spark.connector.rdd.typeTests

import com.datastax.oss.driver.api.core.cql.Row
import com.datastax.oss.driver.api.core.{CqlSession, Version}
import com.datastax.spark.connector._
import com.datastax.spark.connector.ccm.CcmConfig
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import com.datastax.spark.connector.datasource.CassandraCatalog
import com.datastax.spark.connector.mapper.ColumnMapper
import com.datastax.spark.connector.rdd.{ReadConf, ValidRDDType}
import com.datastax.spark.connector.rdd.reader.RowReaderFactory
import com.datastax.spark.connector.types.TypeConverter
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.apache.spark.sql.cassandra.{DataFrameReaderWrapper, DataFrameWriterWrapper}

import scala.collection.convert.ImplicitConversionsToScala._
import scala.collection.immutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe._
import scala.reflect._


abstract class VectorTypeTest[
ScalaType: ClassTag : TypeTag,
DriverType <: Number : ClassTag,
CaseClassType <: Product : ClassTag : TypeTag : ColumnMapper: RowReaderFactory : ValidRDDType](typeName: String) extends SparkCassandraITFlatSpecBase with DefaultCluster
{
/** Skips the given test if the cluster is not Cassandra */
override def cassandraOnly(f: => Unit): Unit = super.cassandraOnly(f)

override lazy val conn = CassandraConnector(sparkConf)

val VectorTable = "vectors"

def createVectorTable(session: CqlSession, table: String): Unit = {
session.execute(
s"""CREATE TABLE IF NOT EXISTS $ks.$table (
| id INT PRIMARY KEY,
| v VECTOR<$typeName, 3>
|)""".stripMargin)
}

def minCassandraVersion: Option[Version] = Some(Version.parse("5.0-beta1"))

def minDSEVersion: Option[Version] = None

def vectorFromInts(ints: Int*): Seq[ScalaType]

def vectorItem(id: Int, v: Seq[ScalaType]): CaseClassType

override lazy val spark = SparkSession.builder()
.config(sparkConf)
.config("spark.sql.catalog.casscatalog", "com.datastax.spark.connector.datasource.CassandraCatalog")
.withExtensions(new CassandraSparkExtensions).getOrCreate().newSession()

override def beforeClass() {
conn.withSessionDo { session =>
session.execute(
s"""CREATE KEYSPACE IF NOT EXISTS $ks
|WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }"""
.stripMargin)
}
}

private def hasVectors(rows: List[Row], expectedVectors: Seq[Seq[ScalaType]]): Unit = {
val returnedVectors = for (i <- expectedVectors.indices) yield {
rows.find(_.getInt("id") == i + 1).get.getVector("v", implicitly[ClassTag[DriverType]].runtimeClass.asInstanceOf[Class[Number]]).iterator().toSeq
}

returnedVectors should contain theSameElementsInOrderAs expectedVectors
}

"SCC" should s"write case class instances with $typeName vector using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_caseclass_to_df"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
.write
.cassandraFormat(table, ks)
.mode(SaveMode.Append)
.save()
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))

spark.createDataFrame(Seq(vectorItem(2, vectorFromInts(6, 5, 4)), vectorItem(3, vectorFromInts(7, 8, 9))))
.write
.cassandraFormat(table, ks)
.mode(SaveMode.Append)
.save()
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(6, 5, 4), vectorFromInts(7, 8, 9)))

spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(9, 8, 7)), vectorItem(2, vectorFromInts(10, 11, 12))))
.write
.cassandraFormat(table, ks)
.mode(SaveMode.Overwrite)
.option("confirm.truncate", value = true)
.save()
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(9, 8, 7), vectorFromInts(10, 11, 12)))
}
}

it should s"write tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_tuple_to_df"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.createDataFrame(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.toDF("id", "v")
.write
.cassandraFormat(table, ks)
.mode(SaveMode.Append)
.save()
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
}
}

it should s"write case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_caseclass_to_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.sparkContext.parallelize(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
}
}

it should s"write tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_write_tuple_to_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)

spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)
hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
}
}

it should s"read case class instances with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_caseclass_from_df"
conn.withSessionDo { session =>
createVectorTable(session, table)
}
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

import spark.implicits._
spark.read.cassandraFormat(table, ks).load().as[CaseClassType].collect() should contain theSameElementsAs
Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
}

it should s"read tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_tuple_from_df"
conn.withSessionDo { session =>
createVectorTable(session, table)
}
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

import spark.implicits._
spark.read.cassandraFormat(table, ks).load().as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
}

it should s"read case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_caseclass_from_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)
}
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

spark.sparkContext.cassandraTable[CaseClassType](ks, table).collect() should contain theSameElementsAs
Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
}

it should s"read tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_tuple_from_rdd"
conn.withSessionDo { session =>
createVectorTable(session, table)
}
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

spark.sparkContext.cassandraTable[(Int, Seq[ScalaType])](ks, table).collect() should contain theSameElementsAs
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
}

it should s"read rows with $typeName vectors using SQL API" in from(minCassandraVersion, minDSEVersion) {
val table = s"${typeName.toLowerCase}_read_rows_from_sql"
conn.withSessionDo { session =>
createVectorTable(session, table)
}
spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
.saveToCassandra(ks, table)

import spark.implicits._
spark.sql(s"SELECT * FROM casscatalog.$ks.$table").as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
}

}

class IntVectorTypeTest extends VectorTypeTest[Int, Integer, IntVectorItem]("INT") {
override def vectorFromInts(ints: Int*): Seq[Int] = ints

override def vectorItem(id: Int, v: Seq[Int]): IntVectorItem = IntVectorItem(id, v)
}

case class IntVectorItem(id: Int, v: Seq[Int])

class LongVectorTypeTest extends VectorTypeTest[Long, java.lang.Long, LongVectorItem]("BIGINT") {
override def vectorFromInts(ints: Int*): Seq[Long] = ints.map(_.toLong)

override def vectorItem(id: Int, v: Seq[Long]): LongVectorItem = LongVectorItem(id, v)
}

case class LongVectorItem(id: Int, v: Seq[Long])

class FloatVectorTypeTest extends VectorTypeTest[Float, java.lang.Float, FloatVectorItem]("FLOAT") {
override def vectorFromInts(ints: Int*): Seq[Float] = ints.map(_.toFloat + 0.1f)

override def vectorItem(id: Int, v: Seq[Float]): FloatVectorItem = FloatVectorItem(id, v)
}

case class FloatVectorItem(id: Int, v: Seq[Float])

class DoubleVectorTypeTest extends VectorTypeTest[Double, java.lang.Double, DoubleVectorItem]("DOUBLE") {
override def vectorFromInts(ints: Int*): Seq[Double] = ints.map(_.toDouble + 0.1d)

override def vectorItem(id: Int, v: Seq[Double]): DoubleVectorItem = DoubleVectorItem(id, v)
}

case class DoubleVectorItem(id: Int, v: Seq[Double])
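The hasVectors assertion above reads vectors back through the Java driver's getVector accessor. A stand-alone sketch of that read path, assuming an open session and a table with an int primary key id and an int vector column v (the helper name is illustrative):

import com.datastax.oss.driver.api.core.CqlSession
import scala.collection.convert.ImplicitConversionsToScala._

// Reads the vector column of one row back as a Scala Seq, the same way
// hasVectors does above; the keyspace, table and id are assumed to exist.
def readIntVector(session: CqlSession, ks: String, table: String, id: Int): Seq[Integer] =
  session.execute(s"SELECT v FROM $ks.$table WHERE id = $id")
    .one()
    .getVector("v", classOf[Integer])
    .iterator()
    .toSeq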

@@ -2,7 +2,7 @@ package com.datastax.spark.connector.datasource

import java.util.Locale
import com.datastax.oss.driver.api.core.ProtocolVersion
-import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType}
+import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType, VectorType}
import com.datastax.oss.driver.api.core.`type`.DataTypes._
import com.datastax.dse.driver.api.core.`type`.DseDataTypes._
import com.datastax.oss.driver.api.core.metadata.schema.{ColumnMetadata, RelationMetadata, TableMetadata}
@@ -167,6 +167,7 @@ object CassandraSourceUtil extends Logging {
case m: MapType => SparkSqlMapType(catalystDataType(m.getKeyType, nullable), catalystDataType(m.getValueType, nullable), nullable)
case udt: UserDefinedType => fromUdt(udt)
case t: TupleType => fromTuple(t)
+      case v: VectorType => ArrayType(catalystDataType(v.getElementType, nullable), nullable)
case VARINT =>
logWarning("VarIntType is mapped to catalystTypes.DecimalType with unlimited values.")
primitiveCatalystDataType(cassandraType)
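The net effect of the new case is that a vector column surfaces in a Spark schema as an array of its element type. A small sketch of the Catalyst type expected for a column declared as vector<float, 3>; the nullability flags are assumed to follow the surrounding cases.

import org.apache.spark.sql.types.{ArrayType, DataType, FloatType}

// Catalyst type produced for a CQL vector<float, 3> column by the mapping above,
// assuming nullable = true as for ordinary reads.
val vectorColumnType: DataType = ArrayType(FloatType, containsNull = true)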
@@ -59,6 +59,7 @@ object DataTypeConverter extends Logging {
cassandraType match {
case connector.types.SetType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.ListType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
+      case connector.types.VectorType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
case connector.types.MapType(kt, vt, _) => catalystTypes.MapType(catalystDataType(kt, nullable), catalystDataType(vt, nullable), nullable)
case connector.types.UserDefinedType(_, fields, _) => catalystTypes.StructType(fields.map(catalystStructField))
case connector.types.TupleType(fields @ _* ) => catalystTypes.StructType(fields.map(catalystStructFieldFromTuple))
@@ -93,6 +93,10 @@ class GettableDataToMappedTypeConverter[T : TypeTag : ColumnMapper](
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))

+      case (VectorType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
+        val argConverter = converter(argColumnType, argScalaType)
+        TypeConverter.forType[U](Seq(argConverter))
+
case (SetType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
val argConverter = converter(argColumnType, argScalaType)
TypeConverter.forType[U](Seq(argConverter))
@@ -82,6 +82,10 @@ object MappedToGettableDataConverter extends Logging {
val valueConverter = converter(valueColumnType, valueScalaType)
TypeConverter.javaHashMapConverter(keyConverter, valueConverter)

+      case (VectorType(argColumnType, dimension), TypeRef(_, _, List(argScalaType))) =>
+        val argConverter = converter(argColumnType, argScalaType)
+        TypeConverter.cqlVectorConverter(dimension)(argConverter.asInstanceOf[TypeConverter[Number]])
+
case (tt @ TupleType(argColumnType1, argColumnType2),
TypeRef(_, Symbols.PairSymbol, List(argScalaType1, argScalaType2))) =>
val c1 = converter(argColumnType1.columnType, argScalaType1)
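Taken together, the two converter changes cover both directions of the RDD API: MappedToGettableDataConverter turns a Scala sequence into a CQL vector value on write (via TypeConverter.cqlVectorConverter), and GettableDataToMappedTypeConverter maps the vector back to a sequence on read. A hedged round-trip sketch, assuming a keyspace test, a table vectors (id int PRIMARY KEY, v vector<float, 3>) and an active SparkContext sc:

import com.datastax.spark.connector._

// Hypothetical mapped class; the commit's tests use equivalent ones
// (e.g. FloatVectorItem(id: Int, v: Seq[Float])).
case class FloatVectorRow(id: Int, v: Seq[Float])

// Write: each Seq[Float] is converted into the CQL vector value.
sc.parallelize(Seq(FloatVectorRow(1, Seq(1.1f, 2.1f, 3.1f)), FloatVectorRow(2, Seq(4.1f, 5.1f, 6.1f))))
  .saveToCassandra("test", "vectors")

// Read: the vector column is converted back into Seq[Float].
val rows = sc.cassandraTable[FloatVectorRow]("test", "vectors").collect()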
