Skip to content

Commit

Permalink
Add scala fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
jelmerk committed Jan 5, 2024
1 parent 11857ec commit 51c1711
Show file tree
Hide file tree
Showing 23 changed files with 1,328 additions and 954 deletions.
12 changes: 5 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ permissions:

on:
pull_request:
paths:
- '**'
paths-ignore:
- '**.md'
push:
branches:
- '*'
branches-ignore:
- '!master'
tags-ignore:
- 'v[0-9]+.[0-9]+.[0-9]+'
paths-ignore:
- '**.md'

jobs:
ci-pipeline:
Expand Down Expand Up @@ -45,7 +43,7 @@ jobs:
3.9
- name: Build and test
run: |
sbt -java-home "$JAVA_HOME_8_X64" clean +test -DsparkVersion="$SPARK_VERSION"
sbt -java-home "$JAVA_HOME_8_X64" clean scalafmtCheckAll +test -DsparkVersion="$SPARK_VERSION"
- name: Publish Unit test results
uses: mikepenz/action-junit-report@v4
with:
Expand Down
21 changes: 21 additions & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
version = 3.5.8
runner.dialect = scala212
maxColumn = 120
style = defaultWithAlign
danglingParentheses.preset = true
indentOperator.preset = spray
rewrite.rules = [RedundantParens, Imports, PreferCurlyFors]
binPack.literalArgumentLists = false
align.arrowEnumeratorGenerator = false
align.tokenCategory = {
Equals = Assign
LeftArrow = Assign
}

rewrite.imports.sort = scalastyle
rewrite.imports.groups = [
["javax?\\..*"],
["scala\\..*"],
["^(?!com\\.miro).+"],
["^com\\.miro.*"]
]
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ import com.github.jelmerk.knn.scalalike.hnsw.HnswIndex
import com.github.jelmerk.spark.util.SerializableConfiguration
import org.apache.spark.serializer.KryoRegistrator

/**
* Implementation of KryoRegistrator that registers hnswlib classes with spark.
* Can be registered by setting spark.kryo.registrator to com.github.jelmerk.spark.HnswLibKryoRegistrator
*/
/** Implementation of KryoRegistrator that registers hnswlib classes with spark. Can be registered by setting
* spark.kryo.registrator to com.github.jelmerk.spark.HnswLibKryoRegistrator
*/
class HnswLibKryoRegistrator extends KryoRegistrator {
override def registerClasses(kryo: Kryo): Unit = {
kryo.register(classOf[HnswIndex[_, _, _, _]], new JavaSerializer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,15 @@ import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType, FloatType, StructType}

/**
* Companion class for VectorConverter.
/** Companion class for VectorConverter.
*/
object VectorConverter extends DefaultParamsReadable[Normalizer] {
override def load(path: String): Normalizer = super.load(path)
}

private[conversion] trait VectorConverterParams extends HasInputCol with HasOutputCol {

/**
* Param for the type of vector to produce. one of array<float>, array<double>, vector
* Default: "array<float>"
/** Param for the type of vector to produce. one of array<float>, array<double>, vector Default: "array<float>"
*
* @group param
*/
Expand All @@ -36,13 +33,16 @@ private[conversion] trait VectorConverterParams extends HasInputCol with HasOutp
setDefault(outputType -> "array<float>")
}

/**
* Converts the input vector to a vector of another type.
/** Converts the input vector to a vector of another type.
*
* @param uid identifier
* @param uid
* identifier
*/
class VectorConverter(override val uid: String)
extends Transformer with VectorConverterParams with Logging with DefaultParamsWritable {
extends Transformer
with VectorConverterParams
with Logging
with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("conv"))

Expand All @@ -57,18 +57,21 @@ class VectorConverter(override val uid: String)

override def transform(dataset: Dataset[_]): DataFrame = {

dataset.withColumn(getOutputCol, (dataset.schema(getInputCol).dataType, getOutputType) match {
case (ArrayType(FloatType, _), "array<double>") => floatArrayToDoubleArray(col(getInputCol))
case (ArrayType(FloatType, _), "vector") => floatArrayToVector(col(getInputCol))
dataset.withColumn(
getOutputCol,
(dataset.schema(getInputCol).dataType, getOutputType) match {
case (ArrayType(FloatType, _), "array<double>") => floatArrayToDoubleArray(col(getInputCol))
case (ArrayType(FloatType, _), "vector") => floatArrayToVector(col(getInputCol))

case (ArrayType(DoubleType, _), "array<float>") => doubleArrayToFloatArray(col(getInputCol))
case (ArrayType(DoubleType, _), "vector") => doubleArrayToVector(col(getInputCol))
case (ArrayType(DoubleType, _), "array<float>") => doubleArrayToFloatArray(col(getInputCol))
case (ArrayType(DoubleType, _), "vector") => doubleArrayToVector(col(getInputCol))

case (VectorType, "array<float>") => vectorToFloatArray(col(getInputCol))
case (VectorType, "array<double>") => vectorToDoubleArray(col(getInputCol))
case (VectorType, "array<float>") => vectorToFloatArray(col(getInputCol))
case (VectorType, "array<double>") => vectorToDoubleArray(col(getInputCol))

case _ => throw new IllegalArgumentException("Cannot convert vector")
})
case _ => throw new IllegalArgumentException("Cannot convert vector")
}
)
}

override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
Expand All @@ -85,9 +88,9 @@ class VectorConverter(override val uid: String)
val inputColumnSchema = schema(getInputCol)

val inputColHasValidDataType = inputColumnSchema.dataType match {
case VectorType => true
case VectorType => true
case ArrayType(DoubleType, _) => true
case _ => false
case _ => false
}

if (!inputColHasValidDataType) {
Expand All @@ -96,8 +99,8 @@ class VectorConverter(override val uid: String)

val outputType: DataType = getOutputType match {
case "array<double>" => ArrayType(DoubleType)
case "array<float>" => ArrayType(FloatType)
case "vector" => VectorType
case "array<float>" => ArrayType(FloatType)
case "vector" => VectorType
}

schema
Expand All @@ -108,12 +111,16 @@ class VectorConverter(override val uid: String)

private val doubleArrayToFloatArray: UserDefinedFunction = udf { vector: Seq[Double] => vector.map(_.toFloat) }

private val floatArrayToDoubleArray: UserDefinedFunction = udf { vector: Seq[Float] => vector.toArray.map(_.toDouble) }
private val floatArrayToDoubleArray: UserDefinedFunction = udf { vector: Seq[Float] =>
vector.toArray.map(_.toDouble)
}

private val vectorToDoubleArray: UserDefinedFunction = udf { vector: Vector => vector.toArray }

private val floatArrayToVector: UserDefinedFunction = udf { vector: Seq[Float] => Vectors.dense(vector.map(_.toDouble).toArray) }
private val floatArrayToVector: UserDefinedFunction = udf { vector: Seq[Float] =>
Vectors.dense(vector.map(_.toDouble).toArray)
}

private val doubleArrayToVector: UserDefinedFunction = udf { vector: Seq[Double] => Vectors.dense(vector.toArray) }

}
}

0 comments on commit 51c1711

Please sign in to comment.