Cleanup & fix incorrect schema in VectorConverter when output is not …
Browse files Browse the repository at this point in the history
…float array.
  • Loading branch information
Jelmer Kuperus committed Jul 9, 2022
1 parent 8312a13 commit d9cee73
Showing 3 changed files with 35 additions and 20 deletions.
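
The bug being fixed: `transformSchema` always declared the output column as `array<float>`, even when the converter was configured to produce `array<double>` or `vector`, so the advertised schema disagreed with the data the transformer actually emitted. A minimal sketch of the behaviour after this commit; the `setInputCol`/`setOutputCol`/`setOutputType` setter names are assumptions, since only the getters appear in the diff:

    import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
    import org.apache.spark.sql.types.{ArrayType, DoubleType, StructField, StructType}

    val inputSchema = StructType(Seq(
      StructField("featuresAsDoubles", ArrayType(DoubleType), nullable = false)))

    val converter = new VectorConverter()  // the transformer changed in this commit
      .setInputCol("featuresAsDoubles")
      .setOutputCol("features")
      .setOutputType("vector")

    // Previously this field came back as ArrayType(FloatType) regardless of
    // outputType; with the fix it matches the requested output type.
    assert(converter.transformSchema(inputSchema)("features").dataType == VectorType)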
@@ -4,13 +4,14 @@ import com.github.jelmerk.spark.linalg.Normalizer
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.linalg.SQLDataTypes._
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, FloatType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, FloatType, StructType}
 
 /**
  * Companion class for VectorConverter.
@@ -63,8 +64,8 @@ class VectorConverter(override val uid: String)
       case (ArrayType(DoubleType, _), "array<float>") => doubleArrayToFloatArray(col(getInputCol))
       case (ArrayType(DoubleType, _), "vector") => doubleArrayToVector(col(getInputCol))
 
-      case (dataType, "array<float>") if dataType.typeName == "vector" => vectorToFloatArray(col(getInputCol))
-      case (dataType, "array<double>") if dataType.typeName == "vector" => vectorToDoubleArray(col(getInputCol))
+      case (VectorType, "array<float>") => vectorToFloatArray(col(getInputCol))
+      case (VectorType, "array<double>") => vectorToDoubleArray(col(getInputCol))
 
      case _ => throw new IllegalArgumentException("Cannot convert vector")
    })
@@ -84,7 +85,7 @@ class VectorConverter(override val uid: String)
     val inputColumnSchema = schema(getInputCol)
 
     val inputColHasValidDataType = inputColumnSchema.dataType match {
-      case dataType: DataType if dataType.typeName == "vector" => true
+      case VectorType => true
       case ArrayType(DoubleType, _) => true
       case _ => false
     }
@@ -93,8 +94,14 @@
       throw new IllegalArgumentException(s"Input column $getInputCol must be a double array or vector.")
     }
 
-    val outputFields = schema.fields :+ StructField(getOutputCol, ArrayType(FloatType), inputColumnSchema.nullable)
-    StructType(outputFields)
+    val outputType: DataType = getOutputType match {
+      case "array<double>" => ArrayType(DoubleType)
+      case "array<float>" => ArrayType(FloatType)
+      case "vector" => VectorType
+    }
+
+    schema
+      .add(getOutputCol, outputType, inputColumnSchema.nullable)
   }
 
   private val vectorToFloatArray: UserDefinedFunction = udf { vector: Vector => vector.toArray.map(_.toFloat) }
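
The pattern repeated throughout this commit: instead of comparing `dataType.typeName` against the string `"vector"`, the code matches on `VectorType` from `org.apache.spark.ml.linalg.SQLDataTypes`, the public `DataType` constant for the ML vector UDT. Because `VectorType` is a stable identifier, it can be used directly as a pattern. A self-contained illustration of the same dispatch style:

    import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
    import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, FloatType}

    // Stable identifiers starting with an upper-case letter are matched by
    // equality, so VectorType works as a pattern just like ArrayType does.
    def describe(dataType: DataType): String = dataType match {
      case VectorType               => "ml vector"
      case ArrayType(FloatType, _)  => "float array"
      case ArrayType(DoubleType, _) => "double array"
      case other                    => s"unsupported: ${other.simpleString}"
    }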
@@ -3,7 +3,6 @@ package com.github.jelmerk.spark.knn
 import java.io.InputStream
 import java.net.InetAddress
 import java.util.concurrent.{CountDownLatch, ExecutionException, FutureTask, LinkedBlockingQueue, ThreadLocalRandom, ThreadPoolExecutor, TimeUnit}
-
 import com.github.jelmerk.knn.ObjectSerializer
 
 import scala.language.{higherKinds, implicitConversions}
@@ -28,6 +27,7 @@ import com.github.jelmerk.knn.scalalike._
 import com.github.jelmerk.knn.util.NamedThreadFactory
 import com.github.jelmerk.spark.linalg.functions.VectorDistanceFunctions
 import com.github.jelmerk.spark.util.SerializableConfiguration
+import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 
 import scala.annotation.tailrec
@@ -187,17 +187,24 @@ private[knn] trait KnnModelParams extends Params with HasFeaturesCol with HasPredictionCol
       case _ => DoubleType
     }
 
-    val neighborsField = StructField(getPredictionCol, ArrayType(StructType(Seq(StructField("neighbor", identifierDataType, nullable = false), StructField("distance", distanceType)))))
+    val predictionStruct = new StructType()
+      .add("neighbor", identifierDataType, nullable = false)
+      .add("distance", distanceType, nullable = false)
+
+    val neighborsField = StructField(getPredictionCol, new ArrayType(predictionStruct, containsNull = false))
 
     getOutputFormat match {
       case "minimal" if !isSet(queryIdentifierCol) => throw new IllegalArgumentException("queryIdentifierCol must be set when using outputFormat minimal.")
-      case "minimal" => StructType(Array(schema(getQueryIdentifierCol), neighborsField))
+      case "minimal" =>
+        new StructType()
+          .add(schema(getQueryIdentifierCol))
+          .add(neighborsField)
       case _ =>
         if (schema.fieldNames.contains(getPredictionCol)) {
          throw new IllegalArgumentException(s"Output column $getPredictionCol already exists.")
         }
 
-        StructType(schema.fields :+ neighborsField)
+        schema
+          .add(neighborsField)
    }
  }
}
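
The prediction column declared above is an array of `(neighbor, distance)` structs, now marked non-nullable end to end: `containsNull = false` on the array and `nullable = false` on both struct fields. For reference, this is the shape the schema takes when the identifier is a long and distances are doubles (the column name and type choices here are illustrative):

    import org.apache.spark.sql.types._

    val predictionStruct = new StructType()
      .add("neighbor", LongType, nullable = false)
      .add("distance", DoubleType, nullable = false)

    // i.e. prediction: array<struct<neighbor: bigint, distance: double>>
    val predictionField =
      StructField("prediction", new ArrayType(predictionStruct, containsNull = false))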
@@ -737,14 +744,15 @@ private[knn] abstract class KnnAlgorithm[TModel <: Model[TModel]](override val uid: String)
     val model = (identifierType, vectorType) match {
       case (IntegerType, ArrayType(FloatType, _)) => typedFit[Int, Array[Float], IntFloatArrayIndexItem, Float](dataset)
       case (IntegerType, ArrayType(DoubleType, _)) => typedFit[Int, Array[Double], IntDoubleArrayIndexItem, Double](dataset)
-      case (IntegerType, t) if t.typeName == "vector" => typedFit[Int, Vector, IntVectorIndexItem, Double](dataset)
+      case (IntegerType, VectorType) => typedFit[Int, Vector, IntVectorIndexItem, Double](dataset)
       case (LongType, ArrayType(FloatType, _)) => typedFit[Long, Array[Float], LongFloatArrayIndexItem, Float](dataset)
       case (LongType, ArrayType(DoubleType, _)) => typedFit[Long, Array[Double], LongDoubleArrayIndexItem, Double](dataset)
-      case (LongType, t) if t.typeName == "vector" => typedFit[Long, Vector, LongVectorIndexItem, Double](dataset)
+      case (LongType, VectorType) => typedFit[Long, Vector, LongVectorIndexItem, Double](dataset)
       case (StringType, ArrayType(FloatType, _)) => typedFit[String, Array[Float], StringFloatArrayIndexItem, Float](dataset)
       case (StringType, ArrayType(DoubleType, _)) => typedFit[String, Array[Double], StringDoubleArrayIndexItem, Double](dataset)
-      case (StringType, t) if t.typeName == "vector" => typedFit[String, Vector, StringVectorIndexItem, Double](dataset)
-      case _ => throw new IllegalArgumentException(s"Cannot create index for items with identifier of type " +
+      case (StringType, VectorType) => typedFit[String, Vector, StringVectorIndexItem, Double](dataset)
+      case _ =>
+        throw new IllegalArgumentException(s"Cannot create index for items with identifier of type " +
          s"${identifierType.simpleString} and vector of type ${vectorType.simpleString}. " +
          s"Supported identifiers are string, int, long and string. Supported vectors are array<float>, array<double> and vector ")
     }
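
The dispatch above picks a concrete `typedFit` instantiation from the runtime types of the identifier and vector columns; with this change the three vector cases match on `VectorType` the same way the array cases match on `ArrayType`. A hedged usage sketch of input that selects the `(IntegerType, VectorType)` branch, assuming an active `SparkSession` named `spark`; the `HnswSimilarity` class name and its setters are assumptions about this project's concrete `KnnAlgorithm` subclass, not shown in this diff:

    import org.apache.spark.ml.linalg.Vectors

    // An int identifier plus an ml vector column routes to
    // typedFit[Int, Vector, IntVectorIndexItem, Double].
    val items = spark.createDataFrame(Seq(
      (1, Vectors.dense(0.1, 0.9)),
      (2, Vectors.dense(0.8, 0.2))
    )).toDF("id", "features")

    val model = new HnswSimilarity()
      .setIdentifierCol("id")
      .setFeaturesCol("features")
      .fit(items)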
@@ -4,6 +4,7 @@ import com.github.jelmerk.knn.util.VectorUtils
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
+import org.apache.spark.ml.linalg.SQLDataTypes._
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
@@ -36,8 +37,7 @@ class Normalizer(override val uid: String)
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transform(dataset: Dataset[_]): DataFrame = dataset.schema(getInputCol).dataType match {
-    case dataType: DataType if dataType.typeName == "vector" =>
-      dataset.withColumn(getOutputCol, normalizeVector(col(getInputCol)))
+    case VectorType => dataset.withColumn(getOutputCol, normalizeVector(col(getInputCol)))
     case ArrayType(FloatType, _) => dataset.withColumn(getOutputCol, normalizeFloatArray(col(getInputCol)))
     case ArrayType(DoubleType, _) => dataset.withColumn(getOutputCol, normalizeDoubleArray(col(getInputCol)))
   }
@@ -56,7 +56,7 @@ class Normalizer(override val uid: String)
     val inputColumnSchema = schema(getInputCol)
 
     val inputColHasValidDataType = inputColumnSchema.dataType match {
-      case dataType: DataType if dataType.typeName == "vector" => true
+      case VectorType => true
       case ArrayType(FloatType, _) => true
       case ArrayType(DoubleType, _) => true
       case _ => false
@@ -66,8 +66,8 @@
       throw new IllegalArgumentException(s"Input column $getInputCol must be a float array, double array or vector.")
     }
 
-    val outputFields = schema.fields :+ StructField(getOutputCol, inputColumnSchema.dataType, inputColumnSchema.nullable)
-    StructType(outputFields)
+    schema
+      .add(getOutputCol, inputColumnSchema.dataType, inputColumnSchema.nullable)
   }
 
   private def magnitude(vector: Vector): Double = {
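
`Normalizer`'s schema handling mirrors `VectorConverter`'s: the output column simply takes the input column's data type, since normalization does not change the representation. The transformation itself divides every component by the vector's Euclidean magnitude; a minimal standalone sketch of that computation for a double array (the real UDFs also cover float arrays and ml vectors):

    // L2-normalize: divide each component by sqrt(sum of squared components).
    def normalize(values: Array[Double]): Array[Double] = {
      val magnitude = math.sqrt(values.map(v => v * v).sum)
      values.map(_ / magnitude)
    }

    // e.g. normalize(Array(3.0, 4.0)) == Array(0.6, 0.8)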
