From 972458c6415a2684172468c6cee5d848ddd9f196 Mon Sep 17 00:00:00 2001
From: Jacek Lewandowski
Date: Mon, 6 May 2024 17:36:54 +0200
Subject: [PATCH] SPARKC-706: Add basic support for Cassandra vectors

---
 .../SparkCassandraITFlatSpecBase.scala        |  17 +-
 .../spark/connector/cql/SchemaSpec.scala      |   6 +-
 .../connector/rdd/CassandraRDDSpec.scala      |   4 +-
 .../spark/connector/rdd/RDDSpec.scala         |   4 +-
 .../rdd/typeTests/VectorTypeTest.scala        | 242 ++++++++++++++++++
 .../datasource/CassandraSourceUtil.scala      |   3 +-
 .../sql/cassandra/DataTypeConverter.scala     |   1 +
 .../GettableDataToMappedTypeConverter.scala   |   4 +
 .../MappedToGettableDataConverter.scala       |   4 +
 .../spark/connector/types/ColumnType.scala    |   4 +-
 .../spark/connector/types/TypeConverter.scala |  29 ++-
 .../spark/connector/types/VectorType.scala    |  20 ++
 12 files changed, 324 insertions(+), 14 deletions(-)
 create mode 100644 connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala
 create mode 100644 driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala

diff --git a/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala b/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala
index b1c77a0ba..92b014e34 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/SparkCassandraITFlatSpecBase.scala
@@ -131,7 +131,7 @@ trait SparkCassandraITSpecBase
 
   def pv = conn.withSessionDo(_.getContext.getProtocolVersion)
 
-  def report(message: String): Unit = alert(message)
+  def report(message: String): Unit = cancel(message)
 
   val ks = getKsName
 
@@ -147,16 +147,23 @@ trait SparkCassandraITSpecBase
   /** Skips the given test if the Cluster Version is lower or equal to the given `cassandra` Version or `dse` Version
     * (if this is a DSE cluster) */
-  def from(cassandra: Version, dse: Version)(f: => Unit): Unit = {
+  def from(cassandra: Version, dse: Version)(f: => Unit): Unit = from(Some(cassandra), Some(dse))(f)
+  def from(cassandra: Option[Version] = None, dse: Option[Version] = None)(f: => Unit): Unit = {
     if (isDse(conn)) {
-      from(dse)(f)
+      dse match {
+        case Some(dseVersion) => from(dseVersion)(f)
+        case None => report(s"Skipped because not DSE")
+      }
     } else {
-      from(cassandra)(f)
+      cassandra match {
+        case Some(cassandraVersion) => from(cassandraVersion)(f)
+        case None => report(s"Skipped because not Cassandra")
+      }
     }
   }
 
   /** Skips the given test if the Cluster Version is lower or equal to the given version */
-  def from(version: Version)(f: => Unit): Unit = {
+  private def from(version: Version)(f: => Unit): Unit = {
     skip(cluster.getCassandraVersion, version) { f }
   }
 
diff --git a/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
index f30f09d0b..a3b30aef9 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/cql/SchemaSpec.scala
@@ -40,6 +40,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster
         |  d14_varchar varchar,
         |  d15_varint varint,
         |  d16_address frozen<address>,
+        |  d17_vector frozen<vector<int, 3>>,
         |  PRIMARY KEY ((k1, k2, k3), c1, c2, c3)
         |)
       """.stripMargin)
@@ -111,12 +112,12 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster
 
     "allow to read regular column definitions" in {
       val columns = table.regularColumns
-      columns.size shouldBe 16
+      columns.size shouldBe 17
       columns.map(_.columnName).toSet shouldBe Set(
         "d1_blob", "d2_boolean", "d3_decimal", "d4_double", "d5_float", "d6_inet",
         "d7_int", "d8_list", "d9_map", "d10_set", "d11_timestamp", "d12_uuid", "d13_timeuuid", "d14_varchar",
-        "d15_varint", "d16_address")
+        "d15_varint", "d16_address", "d17_vector")
     }
 
     "allow to read proper types of columns" in {
@@ -136,6 +137,7 @@ class SchemaSpec extends SparkCassandraITWordSpecBase with DefaultCluster
       table.columnByName("d14_varchar").columnType shouldBe VarCharType
       table.columnByName("d15_varint").columnType shouldBe VarIntType
       table.columnByName("d16_address").columnType shouldBe a [UserDefinedType]
+      table.columnByName("d17_vector").columnType shouldBe VectorType[Int](IntType, 3)
     }
 
     "allow to list fields of a user defined type" in {
diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala
index 3a9ac7e90..7bf60fe28 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/CassandraRDDSpec.scala
@@ -9,7 +9,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption
 import com.datastax.oss.driver.api.core.cql.SimpleStatement
 import com.datastax.oss.driver.api.core.cql.SimpleStatement._
 import com.datastax.spark.connector._
-import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V6_7_0, V3_6_0}
+import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, DSE_V6_7_0, V3_6_0}
 import com.datastax.spark.connector.cluster.DefaultCluster
 import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf}
 import com.datastax.spark.connector.mapper.{DefaultColumnMapper, JavaBeanColumnMapper, JavaTestBean, JavaTestUDTBean}
@@ -794,7 +794,7 @@ class CassandraRDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster
     results should contain ((KeyGroup(3, 300), (3, 300, "0003")))
   }
 
-  it should "allow the use of PER PARTITION LIMITs " in from(V3_6_0) {
+  it should "allow the use of PER PARTITION LIMITs " in from(cassandra = V3_6_0, dse = DSE_V5_1_0) {
     val result = sc.cassandraTable(ks, "clustering_time").perPartitionLimit(1).collect
     result.length should be (1)
   }
diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
index d882cbbd6..5175173d3 100644
--- a/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
+++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/RDDSpec.scala
@@ -5,7 +5,7 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption._
 import com.datastax.oss.driver.api.core.cql.{AsyncResultSet, BoundStatement}
 import com.datastax.oss.driver.api.core.{DefaultConsistencyLevel, DefaultProtocolVersion}
 import com.datastax.spark.connector._
-import com.datastax.spark.connector.ccm.CcmConfig.V3_6_0
+import com.datastax.spark.connector.ccm.CcmConfig.{DSE_V5_1_0, V3_6_0}
 import com.datastax.spark.connector.cluster.DefaultCluster
 import com.datastax.spark.connector.cql.CassandraConnector
 import com.datastax.spark.connector.embedded.SparkTemplate._
@@ -425,7 +425,7 @@ class RDDSpec extends SparkCassandraITFlatSpecBase with DefaultCluster
   }
 
-  it should "should be joinable with a PER PARTITION LIMIT limit" in from(V3_6_0){
+  it should "should be joinable with a PER PARTITION LIMIT limit" in from(cassandra = V3_6_0, dse = DSE_V5_1_0){
     val source = sc.parallelize(keys).map(x => (x, x * 100))
     val someCass = source
       .joinWithCassandraTable(ks, wideTable, joinColumns = SomeColumns("key", "group"))
diff --git a/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala b/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala
new file mode 100644
index 000000000..af860d84d
--- /dev/null
+++ b/connector/src/it/scala/com/datastax/spark/connector/rdd/typeTests/VectorTypeTest.scala
@@ -0,0 +1,242 @@
+package com.datastax.spark.connector.rdd.typeTests
+
+import com.datastax.oss.driver.api.core.cql.Row
+import com.datastax.oss.driver.api.core.{CqlSession, Version}
+import com.datastax.spark.connector._
+import com.datastax.spark.connector.ccm.CcmConfig
+import com.datastax.spark.connector.cluster.DefaultCluster
+import com.datastax.spark.connector.cql.CassandraConnector
+import com.datastax.spark.connector.datasource.CassandraCatalog
+import com.datastax.spark.connector.mapper.ColumnMapper
+import com.datastax.spark.connector.rdd.{ReadConf, ValidRDDType}
+import com.datastax.spark.connector.rdd.reader.RowReaderFactory
+import com.datastax.spark.connector.types.TypeConverter
+import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.cassandra.{DataFrameReaderWrapper, DataFrameWriterWrapper}
+
+import scala.collection.convert.ImplicitConversionsToScala._
+import scala.collection.immutable
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe._
+import scala.reflect._
+
+
+abstract class VectorTypeTest[
+  ScalaType: ClassTag : TypeTag,
+  DriverType <: Number : ClassTag,
+  CaseClassType <: Product : ClassTag : TypeTag : ColumnMapper: RowReaderFactory : ValidRDDType](typeName: String) extends SparkCassandraITFlatSpecBase with DefaultCluster
+{
+  /** Skips the given test if the cluster is not Cassandra */
+  override def cassandraOnly(f: => Unit): Unit = super.cassandraOnly(f)
+
+  override lazy val conn = CassandraConnector(sparkConf)
+
+  val VectorTable = "vectors"
+
+  def createVectorTable(session: CqlSession, table: String): Unit = {
+    session.execute(
+      s"""CREATE TABLE IF NOT EXISTS $ks.$table (
+         |  id INT PRIMARY KEY,
+         |  v VECTOR<$typeName, 3>
+         |)""".stripMargin)
+  }
+
+  def minCassandraVersion: Option[Version] = Some(Version.parse("5.0-beta1"))
+
+  def minDSEVersion: Option[Version] = None
+
+  def vectorFromInts(ints: Int*): Seq[ScalaType]
+
+  def vectorItem(id: Int, v: Seq[ScalaType]): CaseClassType
+
+  override lazy val spark = SparkSession.builder()
+    .config(sparkConf)
+    .config("spark.sql.catalog.casscatalog", "com.datastax.spark.connector.datasource.CassandraCatalog")
+    .withExtensions(new CassandraSparkExtensions).getOrCreate().newSession()
+
+  override def beforeClass() {
+    conn.withSessionDo { session =>
+      session.execute(
+        s"""CREATE KEYSPACE IF NOT EXISTS $ks
+           |WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }"""
+          .stripMargin)
+    }
+  }
+
+  private def hasVectors(rows: List[Row], expectedVectors: Seq[Seq[ScalaType]]): Unit = {
+    val returnedVectors = for (i <- expectedVectors.indices) yield {
+      rows.find(_.getInt("id") == i + 1).get.getVector("v",
+        implicitly[ClassTag[DriverType]].runtimeClass.asInstanceOf[Class[Number]]).iterator().toSeq
+    }
+
+    returnedVectors should contain theSameElementsInOrderAs expectedVectors
+  }
+
+  "SCC" should s"write case class instances with $typeName vector using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
+    val table = s"${typeName.toLowerCase}_write_caseclass_to_df"
+    conn.withSessionDo { session =>
+      createVectorTable(session, table)
+
+      spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
+        .write
+        .cassandraFormat(table, ks)
+        .mode(SaveMode.Append)
+        .save()
+      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
+
+      spark.createDataFrame(Seq(vectorItem(2, vectorFromInts(6, 5, 4)), vectorItem(3, vectorFromInts(7, 8, 9))))
+        .write
+        .cassandraFormat(table, ks)
+        .mode(SaveMode.Append)
+        .save()
+      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+        Seq(vectorFromInts(1, 2, 3), vectorFromInts(6, 5, 4), vectorFromInts(7, 8, 9)))
+
+      spark.createDataFrame(Seq(vectorItem(1, vectorFromInts(9, 8, 7)), vectorItem(2, vectorFromInts(10, 11, 12))))
+        .write
+        .cassandraFormat(table, ks)
+        .mode(SaveMode.Overwrite)
+        .option("confirm.truncate", value = true)
+        .save()
+      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+        Seq(vectorFromInts(9, 8, 7), vectorFromInts(10, 11, 12)))
+    }
+  }
+
+  it should s"write tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
+    val table = s"${typeName.toLowerCase}_write_tuple_to_df"
+    conn.withSessionDo { session =>
+      createVectorTable(session, table)
+
+      spark.createDataFrame(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+        .toDF("id", "v")
+        .write
+        .cassandraFormat(table, ks)
+        .mode(SaveMode.Append)
+        .save()
+      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
+    }
+  }
+
+  it should s"write case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
+    val table = s"${typeName.toLowerCase}_write_caseclass_to_rdd"
+    conn.withSessionDo { session =>
+      createVectorTable(session, table)
+
+      spark.sparkContext.parallelize(Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6))))
+        .saveToCassandra(ks, table)
+      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
+    }
+  }
+
+  it should s"write tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
+    val table = s"${typeName.toLowerCase}_write_tuple_to_rdd"
+    conn.withSessionDo { session =>
+      createVectorTable(session, table)
+
+      spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+        .saveToCassandra(ks, table)
+      hasVectors(session.execute(s"SELECT * FROM $ks.$table").all().iterator().toList,
+        Seq(vectorFromInts(1, 2, 3), vectorFromInts(4, 5, 6)))
+    }
+  }
+
+  it should s"read case class instances with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
+    val table = s"${typeName.toLowerCase}_read_caseclass_from_df"
+    conn.withSessionDo { session =>
+      createVectorTable(session, table)
+    }
+    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+      .saveToCassandra(ks, table)
+
+    import spark.implicits._
+    spark.read.cassandraFormat(table, ks).load().as[CaseClassType].collect() should contain theSameElementsAs
+      Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
+  }
+
+  it should s"read tuples with $typeName vectors using DataFrame API" in from(minCassandraVersion, minDSEVersion) {
+    val table = s"${typeName.toLowerCase}_read_tuple_from_df"
+    conn.withSessionDo { session =>
+      createVectorTable(session, table)
+    }
+    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+      .saveToCassandra(ks, table)
+
+    import spark.implicits._
+    spark.read.cassandraFormat(table, ks).load().as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
+      Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
+  }
+
+  it should s"read case class instances with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
+    val table = s"${typeName.toLowerCase}_read_caseclass_from_rdd"
+    conn.withSessionDo { session =>
+      createVectorTable(session, table)
+    }
+    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+      .saveToCassandra(ks, table)
+
+    spark.sparkContext.cassandraTable[CaseClassType](ks, table).collect() should contain theSameElementsAs
+      Seq(vectorItem(1, vectorFromInts(1, 2, 3)), vectorItem(2, vectorFromInts(4, 5, 6)))
+  }
+
+  it should s"read tuples with $typeName vectors using RDD API" in from(minCassandraVersion, minDSEVersion) {
+    val table = s"${typeName.toLowerCase}_read_tuple_from_rdd"
+    conn.withSessionDo { session =>
+      createVectorTable(session, table)
+    }
+    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+      .saveToCassandra(ks, table)
+
+    spark.sparkContext.cassandraTable[(Int, Seq[ScalaType])](ks, table).collect() should contain theSameElementsAs
+      Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
+  }
+
+  it should s"read rows with $typeName vectors using SQL API" in from(minCassandraVersion, minDSEVersion) {
+    val table = s"${typeName.toLowerCase}_read_rows_from_sql"
+    conn.withSessionDo { session =>
+      createVectorTable(session, table)
+    }
+    spark.sparkContext.parallelize(Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6))))
+      .saveToCassandra(ks, table)
+
+    import spark.implicits._
+    spark.sql(s"SELECT * FROM casscatalog.$ks.$table").as[(Int, Seq[ScalaType])].collect() should contain theSameElementsAs
+      Seq((1, vectorFromInts(1, 2, 3)), (2, vectorFromInts(4, 5, 6)))
+  }
+
+}
+
+class IntVectorTypeTest extends VectorTypeTest[Int, Integer, IntVectorItem]("INT") {
+  override def vectorFromInts(ints: Int*): Seq[Int] = ints
+
+  override def vectorItem(id: Int, v: Seq[Int]): IntVectorItem = IntVectorItem(id, v)
+}
+
+case class IntVectorItem(id: Int, v: Seq[Int])
+
+class LongVectorTypeTest extends VectorTypeTest[Long, java.lang.Long, LongVectorItem]("BIGINT") {
+  override def vectorFromInts(ints: Int*): Seq[Long] = ints.map(_.toLong)
+
+  override def vectorItem(id: Int, v: Seq[Long]): LongVectorItem = LongVectorItem(id, v)
+}
+
+case class LongVectorItem(id: Int, v: Seq[Long])
+
+class FloatVectorTypeTest extends VectorTypeTest[Float, java.lang.Float, FloatVectorItem]("FLOAT") {
+  override def vectorFromInts(ints: Int*): Seq[Float] = ints.map(_.toFloat + 0.1f)
+
+  override def vectorItem(id: Int, v: Seq[Float]): FloatVectorItem = FloatVectorItem(id, v)
+}
+
+case class FloatVectorItem(id: Int, v: Seq[Float])
+
+class DoubleVectorTypeTest extends VectorTypeTest[Double, java.lang.Double, DoubleVectorItem]("DOUBLE") {
+  override def vectorFromInts(ints: Int*): Seq[Double] = ints.map(_.toDouble + 0.1d)
+
+  override def vectorItem(id: Int, v: Seq[Double]): DoubleVectorItem = DoubleVectorItem(id, v)
+}
+
+case class DoubleVectorItem(id: Int, v: Seq[Double])
+
diff --git a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
index 19764a0a6..657fab249 100644
--- a/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
+++ b/connector/src/main/scala/com/datastax/spark/connector/datasource/CassandraSourceUtil.scala
@@ -2,7 +2,7 @@ package com.datastax.spark.connector.datasource
 
 import java.util.Locale
 import com.datastax.oss.driver.api.core.ProtocolVersion
-import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType}
+import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes, ListType, MapType, SetType, TupleType, UserDefinedType, VectorType}
 import com.datastax.oss.driver.api.core.`type`.DataTypes._
 import com.datastax.dse.driver.api.core.`type`.DseDataTypes._
 import com.datastax.oss.driver.api.core.metadata.schema.{ColumnMetadata, RelationMetadata, TableMetadata}
@@ -167,6 +167,7 @@ object CassandraSourceUtil extends Logging {
       case m: MapType => SparkSqlMapType(catalystDataType(m.getKeyType, nullable), catalystDataType(m.getValueType, nullable), nullable)
       case udt: UserDefinedType => fromUdt(udt)
       case t: TupleType => fromTuple(t)
+      case v: VectorType => ArrayType(catalystDataType(v.getElementType, nullable), nullable)
       case VARINT =>
         logWarning("VarIntType is mapped to catalystTypes.DecimalType with unlimited values.")
         primitiveCatalystDataType(cassandraType)
diff --git a/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala b/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
index 1287bf23d..0f279215d 100644
--- a/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
+++ b/connector/src/main/scala/org/apache/spark/sql/cassandra/DataTypeConverter.scala
@@ -59,6 +59,7 @@ object DataTypeConverter extends Logging {
     cassandraType match {
       case connector.types.SetType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
       case connector.types.ListType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
+      case connector.types.VectorType(et, _) => catalystTypes.ArrayType(catalystDataType(et, nullable), nullable)
       case connector.types.MapType(kt, vt, _) => catalystTypes.MapType(catalystDataType(kt, nullable), catalystDataType(vt, nullable), nullable)
       case connector.types.UserDefinedType(_, fields, _) => catalystTypes.StructType(fields.map(catalystStructField))
       case connector.types.TupleType(fields @ _* ) => catalystTypes.StructType(fields.map(catalystStructFieldFromTuple))
diff --git a/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
index b4afa5da6..3c68b5f52 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/mapper/GettableDataToMappedTypeConverter.scala
@@ -93,6 +93,10 @@ class GettableDataToMappedTypeConverter[T : TypeTag : ColumnMapper](
       val argConverter = converter(argColumnType, argScalaType)
       TypeConverter.forType[U](Seq(argConverter))
 
+    case (VectorType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
+      val argConverter = converter(argColumnType, argScalaType)
+      TypeConverter.forType[U](Seq(argConverter))
+
     case (SetType(argColumnType, _), TypeRef(_, _, List(argScalaType))) =>
       val argConverter = converter(argColumnType, argScalaType)
       TypeConverter.forType[U](Seq(argConverter))
diff --git a/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
index 3b69267ae..05f682f47 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/mapper/MappedToGettableDataConverter.scala
@@ -82,6 +82,10 @@ object MappedToGettableDataConverter extends Logging {
         val valueConverter = converter(valueColumnType, valueScalaType)
         TypeConverter.javaHashMapConverter(keyConverter, valueConverter)
 
+      case (VectorType(argColumnType, dimension), TypeRef(_, _, List(argScalaType))) =>
+        val argConverter = converter(argColumnType, argScalaType)
+        TypeConverter.cqlVectorConverter(dimension)(argConverter.asInstanceOf[TypeConverter[Number]])
+
       case (tt @ TupleType(argColumnType1, argColumnType2),
           TypeRef(_, Symbols.PairSymbol, List(argScalaType1, argScalaType2))) =>
         val c1 = converter(argColumnType1.columnType, argScalaType1)
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala b/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
index 99a3f9fb3..b5aaf57fb 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/ColumnType.scala
@@ -7,7 +7,7 @@ import java.util.{Date, UUID}
 import com.datastax.dse.driver.api.core.`type`.DseDataTypes
 import com.datastax.oss.driver.api.core.DefaultProtocolVersion.V4
 import com.datastax.oss.driver.api.core.ProtocolVersion
-import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType}
+import com.datastax.oss.driver.api.core.`type`.{DataType, DataTypes => DriverDataTypes, ListType => DriverListType, MapType => DriverMapType, SetType => DriverSetType, TupleType => DriverTupleType, UserDefinedType => DriverUserDefinedType, VectorType => DriverVectorType}
 import com.datastax.spark.connector.util._
 
@@ -77,6 +77,7 @@ object ColumnType {
     case mapType: DriverMapType => MapType(fromDriverType(mapType.getKeyType), fromDriverType(mapType.getValueType), mapType.isFrozen)
     case userType: DriverUserDefinedType => UserDefinedType(userType)
     case tupleType: DriverTupleType => TupleType(tupleType)
+    case vectorType: DriverVectorType => VectorType(fromDriverType(vectorType.getElementType), vectorType.getDimensions)
     case dataType => primitiveTypeMap(dataType)
   }
 
@@ -153,6 +154,7 @@ object ColumnType {
     val converter: TypeConverter[_] = dataType match {
       case list: DriverListType => TypeConverter.javaArrayListConverter(converterToCassandra(list.getElementType))
+      case vec: DriverVectorType => TypeConverter.cqlVectorConverter(vec.getDimensions)(converterToCassandra(vec.getElementType).asInstanceOf[TypeConverter[Number]])
       case set: DriverSetType => TypeConverter.javaHashSetConverter(converterToCassandra(set.getElementType))
       case map: DriverMapType => TypeConverter.javaHashMapConverter(converterToCassandra(map.getKeyType), converterToCassandra(map.getValueType))
       case udt: DriverUserDefinedType => new UserDefinedType.DriverUDTValueConverter(udt)
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala b/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
index 58615965b..ea13093e9 100644
--- a/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/TypeConverter.scala
@@ -9,7 +9,7 @@ import java.util.{Calendar, Date, GregorianCalendar, TimeZone, UUID}
 
 import com.datastax.dse.driver.api.core.data.geometry.{LineString, Point, Polygon}
 import com.datastax.dse.driver.api.core.data.time.DateRange
-import com.datastax.oss.driver.api.core.data.CqlDuration
+import com.datastax.oss.driver.api.core.data.{CqlDuration, CqlVector}
 import com.datastax.spark.connector.TupleValue
 import com.datastax.spark.connector.UDTValue.UDTValueConverter
 import com.datastax.spark.connector.util.ByteBufferUtil
@@ -700,6 +700,7 @@ object TypeConverter {
       case x: java.util.List[_] => newCollection(x.asScala)
       case x: java.util.Set[_] => newCollection(x.asScala)
      case x: java.util.Map[_, _] => newCollection(x.asScala)
+      case x: CqlVector[_] => newCollection(x.asScala)
       case x: Iterable[_] => newCollection(x)
     }
   }
@@ -768,6 +769,29 @@ object TypeConverter {
     }
   }
 
+  class CqlVectorConverter[T <: Number : TypeConverter](dimension: Int) extends TypeConverter[CqlVector[T]] {
+    val elemConverter = implicitly[TypeConverter[T]]
+
+    implicit def elemTypeTag: TypeTag[T] = elemConverter.targetTypeTag
+
+    @transient
+    lazy val targetTypeTag = {
+      implicitly[TypeTag[CqlVector[T]]]
+    }
+
+    def newCollection(items: Iterable[Any]): java.util.List[T] = {
+      val buf = new java.util.ArrayList[T](dimension)
+      for (item <- items) buf.add(elemConverter.convert(item))
+      buf
+    }
+
+    def convertPF = {
+      case x: CqlVector[_] => x.asInstanceOf[CqlVector[T]] // it is an optimization - should we skip converting the elements?
+      case x: java.lang.Iterable[_] => CqlVector.newInstance[T](newCollection(x.asScala))
+      case x: Iterable[_] => CqlVector.newInstance[T](newCollection(x))
+    }
+  }
+
   class JavaArrayListConverter[T : TypeConverter] extends CollectionConverter[java.util.ArrayList[T], T] {
     @transient
     lazy val targetTypeTag = {
@@ -869,6 +893,9 @@ object TypeConverter {
   implicit def javaArrayListConverter[T : TypeConverter]: JavaArrayListConverter[T] =
     new JavaArrayListConverter[T]
 
+  implicit def cqlVectorConverter[T <: Number : TypeConverter](dimension: Int): CqlVectorConverter[T] =
+    new CqlVectorConverter[T](dimension)
+
   implicit def javaSetConverter[T : TypeConverter]: JavaSetConverter[T] =
     new JavaSetConverter[T]
 
diff --git a/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala b/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala
new file mode 100644
index 000000000..8060fe225
--- /dev/null
+++ b/driver/src/main/scala/com/datastax/spark/connector/types/VectorType.scala
@@ -0,0 +1,20 @@
+package com.datastax.spark.connector.types
+
+import scala.language.existentials
+import scala.reflect.runtime.universe._
+
+case class VectorType[T](elemType: ColumnType[T], dimension: Int) extends ColumnType[Seq[T]] {
+
+  override def isCollection: Boolean = false
+
+  @transient
+  lazy val scalaTypeTag = {
+    implicit val elemTypeTag = elemType.scalaTypeTag
+    implicitly[TypeTag[Seq[T]]]
+  }
+
+  def cqlTypeName = s"vector<${elemType.cqlTypeName}, ${dimension}>"
+
+  override def converterToCassandra: TypeConverter[_ <: AnyRef] =
+    new TypeConverter.OptionToNullConverter(TypeConverter.seqConverter(elemType.converterToCassandra))
+}
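
Reviewer note (not part of the patch): the sketch below illustrates how the new vector support is expected to be used end to end, mirroring what VectorTypeTest exercises above. The keyspace and table names ("test", "embeddings") and the session setup are made up for illustration; the only entry points assumed are the connector APIs touched by this change (saveToCassandra, cassandraTable, cassandraFormat, CassandraSparkExtensions).

    // Assumes an existing table:
    //   CREATE TABLE test.embeddings (id int PRIMARY KEY, v vector<float, 3>)
    import com.datastax.spark.connector._
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.cassandra._

    object VectorUsageSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .config("spark.sql.extensions", "com.datastax.spark.connector.CassandraSparkExtensions")
          .getOrCreate()

        // RDD API: a vector<float, 3> column is written from, and read back into, a Seq[Float]
        spark.sparkContext
          .parallelize(Seq((1, Seq(0.1f, 0.2f, 0.3f)), (2, Seq(0.4f, 0.5f, 0.6f))))
          .saveToCassandra("test", "embeddings")

        spark.sparkContext
          .cassandraTable[(Int, Seq[Float])]("test", "embeddings")
          .collect()
          .foreach(println)

        // DataFrame API: per DataTypeConverter above, the vector column surfaces as ArrayType(FloatType)
        val df = spark.read.cassandraFormat("embeddings", "test").load()
        df.printSchema()
        df.show()
      }
    }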