Skip to content

Commit

Permalink
qdrant support for vector knn search (#971)
Browse files Browse the repository at this point in the history
* qdrant support for vector knn search

* unbreak connector e2e tests

* ci win fix
  • Loading branch information
shuttie committed Mar 16, 2023
1 parent 7b954a1 commit 727e2bd
Show file tree
Hide file tree
Showing 31 changed files with 458 additions and 123 deletions.
7 changes: 6 additions & 1 deletion .github/compose-connectors.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,9 @@ services:
- 6650:6650
environment:
PULSAR_MEM: " -Xms512m -Xmx512m -XX:MaxDirectMemorySize=1g"
command: bin/pulsar standalone --wipe-data
command: bin/pulsar standalone --wipe-data

qdrant:
image: qdrant/qdrant:v1.0.3
ports:
- 6333:6333
4 changes: 2 additions & 2 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ jobs:
run: docker-compose -f .github/compose-connectors.yaml up -d

- name: Run integration tests
run: sbt -mem 3000 IntegrationTest/test
run: sbt -mem 5000 IntegrationTest/test

- name: Stop docker-compose
run: docker-compose -f .github/compose-connectors.yaml down

- name: Make docker image
run: sbt -mem 3000 assembly docker
run: sbt -mem 5000 assembly docker

- name: Run docker image tests
run: .github/test_docker.sh
Expand Down
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ lazy val PLATFORM = Option(System.getenv("PLATFORM")).getOrElse("amd64")

ThisBuild / organization := "ai.metarank"
ThisBuild / scalaVersion := "2.13.10"
ThisBuild / version := "0.6.4"
ThisBuild / version := "0.7.0-M1"

lazy val root = (project in file("."))
.enablePlugins(DockerPlugin)
Expand Down
4 changes: 2 additions & 2 deletions build_docker.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ V=$1

docker run --rm --privileged multiarch/qemu-user-static --reset -p yes

PLATFORM=amd64 VERSION=$V sbt -mem 3000 dockerBuildAndPush
PLATFORM=arm64 VERSION=$V sbt -mem 3000 dockerBuildAndPush
PLATFORM=amd64 VERSION=$V sbt -mem 5000 dockerBuildAndPush
PLATFORM=arm64 VERSION=$V sbt -mem 5000 dockerBuildAndPush

docker manifest create metarank/metarank:$V metarank/metarank:$V-arm64 metarank/metarank:$V-amd64
docker manifest rm metarank/metarank:latest
Expand Down
4 changes: 2 additions & 2 deletions doc/configuration/recommendations.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Recommendations in Metarank

Starting from version `0.6.x`, Metarank supports two types of recommendations:
Starting from version `0.6.x`, Metarank supports three types of recommendations:
* [Trending](recommendations/trending.md): popularity-sorted list of items with customized ordering.
* [Similar items](recommendations/similar.md): matrix-factorization collaborative filtering recommender of items you may also like.
* [Semantin](recommendations/semantic.md): a content-based semantic similarity recommender, based on neural embeddings.
* [Semantic](recommendations/semantic.md): a content-based semantic similarity recommender, based on neural embeddings.
26 changes: 26 additions & 0 deletions src/it/scala/ai/metaranke2e/knn/QdrantTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package ai.metaranke2e.knn

import ai.metarank.ml.recommend.KnnConfig.QdrantConfig
import ai.metarank.ml.recommend.embedding.{EmbeddingMap, KnnIndex}
import ai.metarank.model.Identifier.ItemId
import cats.effect.unsafe.implicits.global
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

/** End-to-end check of the Qdrant-backed KNN index: writes a tiny embedding set and queries it back.
  * Talks to the endpoint configured below, so a Qdrant instance must be listening there.
  */
class QdrantTest extends AnyFlatSpec with Matchers {
  it should "write embeddings" in {
    // collection settings; size matches the number of columns of the embedding matrix below
    val qdrant = QdrantConfig(
      endpoint = "http://localhost:6333",
      collection = "test",
      size = 3,
      distance = "Cosine"
    )
    // three items, each with a 3-dimensional vector
    val vectors = Array(
      Array(1.0, 2.0, 3.0),
      Array(1.0, 2.0, 1.0),
      Array(1.0, 1.0, 0.0)
    )
    val embeddings = EmbeddingMap(ids = Array("1", "2", "3"), embeddings = vectors, rows = 3, cols = 3)

    val index      = KnnIndex.write(embeddings, qdrant).unsafeRunSync()
    val neighbours = index.lookup(List(ItemId("1")), 3).unsafeRunSync()

    // the queried item "1" is expected to be absent from its own neighbours
    val neighbourIds = neighbours.map(_.item.value)
    neighbourIds shouldBe List("2", "3")
  }
}
10 changes: 6 additions & 4 deletions src/main/scala/ai/metarank/FeatureMapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import ai.metarank.model.{Dimension, FeatureSchema, FieldName, Key, MValue, Sche
import ai.metarank.ml.rank.LambdaMARTRanker.{LambdaMARTConfig, LambdaMARTModel, LambdaMARTPredictor}
import ai.metarank.ml.rank.NoopRanker.{NoopConfig, NoopModel, NoopPredictor}
import ai.metarank.ml.rank.ShuffleRanker.{ShuffleConfig, ShuffleModel, ShufflePredictor}
import ai.metarank.ml.recommend.BertSemanticRecommender.{BertSemanticModelConfig, BertSemanticPredictor}
import ai.metarank.ml.recommend.MFRecommender.MFPredictor
import ai.metarank.ml.recommend.TrendingRecommender.{TrendingConfig, TrendingPredictor}
import ai.metarank.ml.recommend.mf.ALSRecImpl
Expand Down Expand Up @@ -87,10 +88,11 @@ object FeatureMapping extends Logging {
}
name -> LambdaMARTPredictor(name, conf, makeDatasetDescriptor(modelFeatures))

case (name, conf: NoopConfig) => name -> NoopPredictor(name, conf)
case (name, conf: ShuffleConfig) => name -> ShufflePredictor(name, conf)
case (name, conf: TrendingConfig) => name -> TrendingPredictor(name, conf)
case (name, conf: ALSConfig) => name -> MFPredictor(name, conf, ALSRecImpl(conf))
case (name, conf: NoopConfig) => name -> NoopPredictor(name, conf)
case (name, conf: ShuffleConfig) => name -> ShufflePredictor(name, conf)
case (name, conf: TrendingConfig) => name -> TrendingPredictor(name, conf)
case (name, conf: ALSConfig) => name -> MFPredictor(name, conf, ALSRecImpl(conf))
case (name, conf: BertSemanticModelConfig) => name -> BertSemanticPredictor(name, conf)
}

new FeatureMapping(
Expand Down
18 changes: 16 additions & 2 deletions src/main/scala/ai/metarank/flow/TrainBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,22 @@ case class TrainBuffer(
case Some(id) =>
IO(cache.getIfPresent(id.value)).flatMap {
case None =>
// ranking already gone, nothing to do
IO.unit
// ranking already gone, or it was never present
IO {
queue.add(
ClickthroughValues(
Clickthrough(
id = event.id,
ts = event.timestamp,
user = event.user,
session = event.session,
items = List(event.item),
interactions = List(TypedInteraction(event.item, event.`type`))
),
Nil
)
)
}
// warn(s"ranking $id is present in interaction, but missing in cache")
case Some(ctv) =>
IO {
Expand Down
14 changes: 3 additions & 11 deletions src/main/scala/ai/metarank/fstore/redis/RedisModelStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,11 @@ case class RedisModelStore(client: RedisClient, prefix: String)(implicit kc: KCo
bytesOption <- client.get(kc.encode(prefix, key))
model <- bytesOption match {
case None =>
pred.load(None) match {
case Left(error) => IO.raiseError(error)
case Right(value) => IO(Some(value))
}
pred.load(None).map(Some.apply)
case Some(bytes) =>
vc.decode(bytes) match {
case Left(err) => IO.raiseError(err)
case Right(decodedBytes) =>
pred.load(Some(decodedBytes)) match {
case Left(error) => IO.raiseError(error)
case Right(value) => IO(Some(value))
}

case Left(err) => IO.raiseError(err)
case Right(decodedBytes) => pred.load(Some(decodedBytes)).map(Some.apply)
}
}
} yield {
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/ai/metarank/ml/Predictor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ sealed trait Predictor[C <: ModelConfig, T <: Context, M <: Model[T]] {
def config: C
def name: String
def fit(data: fs2.Stream[IO, TrainValues]): IO[M]
def load(bytes: Option[Array[Byte]]): Either[Throwable, M]
def load(bytes: Option[Array[Byte]]): IO[M]
}

object Predictor {
Expand Down
4 changes: 3 additions & 1 deletion src/main/scala/ai/metarank/ml/rank/LambdaMARTRanker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ object LambdaMARTRanker {
}
}

override def load(bytes: Option[Array[Byte]]): Either[Throwable, LambdaMARTModel] = bytes match {
override def load(bytes: Option[Array[Byte]]): IO[LambdaMARTModel] = IO.fromEither(loadSync(bytes))

def loadSync(bytes: Option[Array[Byte]]): Either[Throwable, LambdaMARTModel] = bytes match {
case None => Left(new Exception(s"cannot load model: not found, maybe you forgot to run train?"))
case Some(blob) =>
val stream = new DataInputStream(new ByteArrayInputStream(blob))
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/ai/metarank/ml/rank/NoopRanker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import io.circe.{Decoder, Encoder}
object NoopRanker {
case class NoopConfig(selector: Selector = AcceptSelector()) extends ModelConfig
case class NoopPredictor(name: String, config: NoopConfig) extends RankPredictor[NoopConfig, NoopModel] {
override def load(bytes: Option[Array[Byte]]): Either[Throwable, NoopModel] = Right(NoopModel(name, config))
override def load(bytes: Option[Array[Byte]]): IO[NoopModel] = IO.pure(NoopModel(name, config))

override def fit(data: fs2.Stream[IO, TrainValues]): IO[NoopModel] =
IO.pure(NoopModel(name, config))
Expand Down
3 changes: 1 addition & 2 deletions src/main/scala/ai/metarank/ml/rank/ShuffleRanker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ object ShuffleRanker {
case class ShuffleConfig(maxPositionChange: Int, selector: Selector = AcceptSelector()) extends ModelConfig

case class ShufflePredictor(name: String, config: ShuffleConfig) extends RankPredictor[ShuffleConfig, ShuffleModel] {
override def load(bytes: Option[Array[Byte]]): Either[Throwable, ShuffleModel] =
Right(ShuffleModel(name, config))
override def load(bytes: Option[Array[Byte]]): IO[ShuffleModel] = IO.pure(ShuffleModel(name, config))
override def fit(data: fs2.Stream[IO, TrainValues]): IO[ShuffleModel] =
IO.pure(ShuffleModel(name, config))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import ai.metarank.config.Selector.AcceptSelector
import ai.metarank.config.{ModelConfig, Selector}
import ai.metarank.ml.Predictor.RecommendPredictor
import ai.metarank.ml.onnx.SBERT
import ai.metarank.ml.recommend.KnnConfig.HnswConfig
import ai.metarank.ml.recommend.MFRecommender.EmbeddingSimilarityModel
import ai.metarank.ml.recommend.embedding.{EmbeddingMap, HnswJavaIndex}
import ai.metarank.ml.recommend.embedding.HnswJavaIndex.{HnswIndexReader, HnswIndexWriter, HnswOptions}
import ai.metarank.ml.recommend.embedding.{EmbeddingMap, HnswJavaIndex, KnnIndex}
import ai.metarank.model.Field.{StringField, StringListField}
import ai.metarank.model.Identifier.ItemId
import ai.metarank.model.{FieldName, TrainValues}
Expand All @@ -26,14 +28,15 @@ object BertSemanticRecommender {
items <- data.collect { case item: ItemValues => item }.compile.toList
_ <- info(s"Loaded ${items.size} items")
embeddings <- embed(items, fieldSet, encoder)
index <- IO(HnswJavaIndex.create(embeddings, config.m, config.ef))
index <- KnnIndex.write(embeddings, config.store)
} yield {
EmbeddingSimilarityModel(name, index)
}

override def load(bytes: Option[Array[Byte]]): Either[Throwable, EmbeddingSimilarityModel] = bytes match {
case Some(value) => Right(EmbeddingSimilarityModel(name, HnswJavaIndex.load(value)))
case None => Left(new Exception(s"cannot load index $name: not found"))
override def load(bytes: Option[Array[Byte]]): IO[EmbeddingSimilarityModel] = bytes match {
case Some(value) =>
KnnIndex.load(value, config.store).map(index => EmbeddingSimilarityModel(name, index))
case None => IO.raiseError(new Exception(s"cannot load index $name: not found"))
}

def embed(items: List[ItemValues], fieldSet: Set[String], encoder: Encoder): IO[EmbeddingMap] = IO {
Expand Down Expand Up @@ -74,8 +77,7 @@ object BertSemanticRecommender {
case class BertSemanticModelConfig(
encoder: EncoderType,
itemFields: List[String],
m: Int = 32,
ef: Int = 200,
store: KnnConfig,
selector: Selector = Selector.AcceptSelector()
) extends ModelConfig

Expand All @@ -100,15 +102,13 @@ object BertSemanticRecommender {
for {
encoder <- c.downField("encoder").as[EncoderType]
itemFields <- c.downField("itemFields").as[List[String]]
m <- c.downField("m").as[Option[Int]]
ef <- c.downField("ef").as[Option[Int]]
store <- c.downField("store").as[Option[KnnConfig]]
selector <- c.downField("selector").as[Option[Selector]]
} yield {
BertSemanticModelConfig(
encoder = encoder,
itemFields = itemFields,
m = m.getOrElse(32),
ef = ef.getOrElse(200),
store = store.getOrElse(HnswConfig()),
selector = selector.getOrElse(AcceptSelector())
)
}
Expand Down Expand Up @@ -141,12 +141,14 @@ object BertSemanticRecommender {
}
}

object CsvEncoder {
object CsvEncoder extends Logging {
def create(lines: fs2.Stream[IO, String]) = for {
dic <- lines
.filter(_.nonEmpty)
.evalMap(line => IO.fromEither(parseLine(line)))
.compile
.toList
_ <- info(s"loaded ${dic.size} embeddings")
size <- IO(dic.map(_._2.length).distinct).flatMap {
case Nil => IO.raiseError(new Exception("no embeddings found"))
case one :: Nil => IO.pure(one)
Expand Down
40 changes: 40 additions & 0 deletions src/main/scala/ai/metarank/ml/recommend/KnnConfig.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package ai.metarank.ml.recommend

import io.circe.{Decoder, DecodingFailure, Encoder, Json, JsonObject}
import io.circe.generic.semiauto.{deriveDecoder, deriveEncoder}

/** Settings of the KNN index used to store and query item embeddings.
  * In JSON form the concrete variant is selected by the "type" field: "hnsw" or "qdrant".
  */
sealed trait KnnConfig

object KnnConfig {

  /** Settings for the "hnsw" index type.
    * NOTE(review): m and ef look like the usual HNSW graph parameters; confirm against the index implementation.
    */
  case class HnswConfig(m: Int = 32, ef: Int = 200) extends KnnConfig

  /** Settings for the "qdrant" index type.
    * @param endpoint   base URL of the Qdrant service, e.g. "http://localhost:6333"
    * @param collection name of the collection to use
    * @param size       presumably the embedding dimensionality; verify against the index writer
    * @param distance   distance function name, e.g. "Cosine"
    */
  case class QdrantConfig(endpoint: String, collection: String, size: Int, distance: String) extends KnnConfig

  /** Decodes HNSW settings; both fields are optional and fall back to the case class defaults. */
  implicit val hnswDecoder: Decoder[HnswConfig] = Decoder.instance { c =>
    // take the fallbacks from the case class itself, so the defaults are declared in one place only
    val defaults = HnswConfig()
    for {
      m  <- c.downField("m").as[Option[Int]]
      ef <- c.downField("ef").as[Option[Int]]
    } yield HnswConfig(m.getOrElse(defaults.m), ef.getOrElse(defaults.ef))
  }
  implicit val hnswEncoder: Encoder[HnswConfig] = deriveEncoder

  implicit val qdrantDecoder: Decoder[QdrantConfig] = deriveDecoder
  implicit val qdrantEncoder: Encoder[QdrantConfig] = deriveEncoder

  /** Dispatches on the "type" field: a missing or non-string "type" fails with the underlying
    * decoding error, an unknown value fails with an explicit "not supported" message.
    */
  implicit val knnDecoder: Decoder[KnnConfig] = Decoder.instance { c =>
    c.downField("type").as[String].flatMap {
      case "hnsw"   => hnswDecoder.tryDecode(c)
      case "qdrant" => qdrantDecoder.tryDecode(c)
      case other    => Left(DecodingFailure(s"knn index type '$other' is not supported", c.history))
    }
  }

  /** Encodes the variant-specific fields and merges in the matching "type" discriminator. */
  implicit val knnEncoder: Encoder[KnnConfig] = Encoder.instance {
    case c: HnswConfig   => hnswEncoder(c).deepMerge(withType("hnsw"))
    case c: QdrantConfig => qdrantEncoder(c).deepMerge(withType("qdrant"))
  }

  /** Builds the single-field JSON object {"type": tpe} used as the variant discriminator. */
  def withType(tpe: String): Json =
    Json.fromJsonObject(JsonObject.fromIterable(List("type" -> Json.fromString(tpe))))
}
18 changes: 10 additions & 8 deletions src/main/scala/ai/metarank/ml/recommend/MFRecommender.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ package ai.metarank.ml.recommend
import ai.metarank.ml.Model
import ai.metarank.ml.Model.{RecommendModel, Response}
import ai.metarank.ml.Predictor.RecommendPredictor
import ai.metarank.ml.recommend.embedding.{EmbeddingMap, HnswJavaIndex}
import ai.metarank.ml.recommend.embedding.KnnIndex.KnnIndexReader
import ai.metarank.ml.recommend.embedding.{EmbeddingMap, HnswJavaIndex, KnnIndex}
import ai.metarank.ml.recommend.mf.MFRecImpl
import ai.metarank.ml.recommend.mf.MFRecImpl.MFModelConfig
import ai.metarank.model.Clickthrough.TypedInteraction
Expand All @@ -28,16 +29,16 @@ object MFRecommender {
_ <- info(s"writing training dataset to $file")
_ <- writeUIRT(data, file)
embeddings <- IO(mf.train(file))
index <- IO(HnswJavaIndex.create(embeddings, config.m, config.ef))
index <- KnnIndex.write(embeddings, config.store)
} yield {
EmbeddingSimilarityModel(name, index)
}
)
}

override def load(bytes: Option[Array[Byte]]): Either[Throwable, EmbeddingSimilarityModel] = bytes match {
case Some(value) => Right(EmbeddingSimilarityModel(name, HnswJavaIndex.load(value)))
case None => Left(new Exception(s"cannot load index $name: not found"))
override def load(bytes: Option[Array[Byte]]): IO[EmbeddingSimilarityModel] = bytes match {
case Some(value) => KnnIndex.load(value, config.store).map(index => EmbeddingSimilarityModel(name, index))
case None => IO.raiseError(new Exception(s"cannot load index $name: not found"))
}

def writeUIRT(source: fs2.Stream[IO, TrainValues], dest: Path): IO[Unit] = {
Expand All @@ -62,14 +63,15 @@ object MFRecommender {
}
}

case class EmbeddingSimilarityModel(name: String, index: HnswJavaIndex) extends RecommendModel {
case class EmbeddingSimilarityModel(name: String, index: KnnIndexReader) extends RecommendModel {
override def predict(request: RecommendRequest): IO[Model.Response] = for {
_ <- request.items match {
case Nil => IO.raiseError(new Exception("similar items recommender requires request.items to be non-empty"))
case _ => IO.unit
}
response <- IO(index.lookup(request.items, request.count))
items <- IO.fromOption(NonEmptyList.fromList(response))(new Exception("empty response from the recommender"))
response <- index.lookup(request.items, request.count + request.items.size)
filtered <- IO(response.filterNot(is => request.items.contains(is.item)).take(request.count))
items <- IO.fromOption(NonEmptyList.fromList(filtered))(new Exception("empty response from the recommender"))
} yield {
Response(items)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ object RandomRecommender {
case class RandomPredictor(name: String, config: RandomConfig)
extends RecommendPredictor[RandomConfig, RandomModel]
with Logging {
override def load(bytes: Option[Array[Byte]]): Either[Throwable, RandomModel] = bytes match {
override def load(bytes: Option[Array[Byte]]): IO[RandomModel] = IO.fromEither(loadSync(bytes))
def loadSync(bytes: Option[Array[Byte]]): Either[Throwable, RandomModel] = bytes match {
case None => Left(new Exception("Cannot load model from store: not found. Did you train it before?"))
case Some(bytes) =>
val stream = new DataInputStream(new ByteArrayInputStream(bytes))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ object TrendingRecommender {

}

override def load(bytesOption: Option[Array[Byte]]): Either[Throwable, TrendingModel] = {
override def load(bytes: Option[Array[Byte]]): IO[TrendingModel] = IO.fromEither(loadSync(bytes))
def loadSync(bytesOption: Option[Array[Byte]]): Either[Throwable, TrendingModel] = {
bytesOption match {
case None => Left(new Exception("cannot load trending model: not found"))
case Some(bytes) =>
Expand Down

0 comments on commit 727e2bd

Please sign in to comment.