From 727e2bdf42d693b5f70f1d4e7f8278d57160c44f Mon Sep 17 00:00:00 2001 From: Roman Grebennikov Date: Thu, 16 Mar 2023 10:27:05 +0100 Subject: [PATCH] qdrant support for vector knn search (#971) * qdrant support for vector knn search * unbreak connector e2e tests * ci win fix --- .github/compose-connectors.yaml | 7 +- .github/workflows/integration.yml | 4 +- build.sbt | 2 +- build_docker.sh | 4 +- doc/configuration/recommendations.md | 4 +- .../scala/ai/metaranke2e/knn/QdrantTest.scala | 26 +++ .../scala/ai/metarank/FeatureMapping.scala | 10 +- .../scala/ai/metarank/flow/TrainBuffer.scala | 18 +- .../fstore/redis/RedisModelStore.scala | 14 +- src/main/scala/ai/metarank/ml/Predictor.scala | 2 +- .../metarank/ml/rank/LambdaMARTRanker.scala | 4 +- .../ai/metarank/ml/rank/NoopRanker.scala | 2 +- .../ai/metarank/ml/rank/ShuffleRanker.scala | 3 +- .../recommend/BertSemanticRecommender.scala | 26 +-- .../ai/metarank/ml/recommend/KnnConfig.scala | 40 +++++ .../metarank/ml/recommend/MFRecommender.scala | 18 +- .../ml/recommend/RandomRecommender.scala | 3 +- .../ml/recommend/TrendingRecommender.scala | 3 +- .../recommend/embedding/HnswJavaIndex.scala | 121 +++++++------ .../ml/recommend/embedding/KnnIndex.scala | 32 ++++ .../ml/recommend/embedding/QdrantIndex.scala | 159 ++++++++++++++++++ .../metarank/ml/recommend/mf/ALSRecImpl.scala | 11 +- .../metarank/ml/recommend/mf/MFRecImpl.scala | 4 +- .../ai/metarank/model/AnalyticsPayload.scala | 4 +- .../resources/ranklens/events/events.jsonl.gz | 4 +- .../codec/impl/TrainValuesCodecTest.scala | 2 +- .../scala/ai/metarank/ml/PredictorSuite.scala | 2 +- .../ml/rank/LambdaMARTRankerTest.scala | 4 +- .../BertSemanticRecommenderTest.scala | 7 +- .../EmbeddingSimilarityModelTest.scala | 34 ++++ .../embedding/HnswJavaIndexTest.scala | 7 +- 31 files changed, 458 insertions(+), 123 deletions(-) create mode 100644 src/it/scala/ai/metaranke2e/knn/QdrantTest.scala create mode 100644 src/main/scala/ai/metarank/ml/recommend/KnnConfig.scala create mode 100644 src/main/scala/ai/metarank/ml/recommend/embedding/KnnIndex.scala create mode 100644 src/main/scala/ai/metarank/ml/recommend/embedding/QdrantIndex.scala create mode 100644 src/test/scala/ai/metarank/ml/recommend/EmbeddingSimilarityModelTest.scala diff --git a/.github/compose-connectors.yaml b/.github/compose-connectors.yaml index 79da5854c..13e503c2b 100644 --- a/.github/compose-connectors.yaml +++ b/.github/compose-connectors.yaml @@ -65,4 +65,9 @@ services: - 6650:6650 environment: PULSAR_MEM: " -Xms512m -Xmx512m -XX:MaxDirectMemorySize=1g" - command: bin/pulsar standalone --wipe-data \ No newline at end of file + command: bin/pulsar standalone --wipe-data + + qdrant: + image: qdrant/qdrant:v1.0.3 + ports: + - 6333:6333 diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index e08a8a305..521f459e2 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -40,13 +40,13 @@ jobs: run: docker-compose -f .github/compose-connectors.yaml up -d - name: Run integration tests - run: sbt -mem 3000 IntegrationTest/test + run: sbt -mem 5000 IntegrationTest/test - name: Stop docker-compose run: docker-compose -f .github/compose-connectors.yaml down - name: Make docker image - run: sbt -mem 3000 assembly docker + run: sbt -mem 5000 assembly docker - name: Run docker image tests run: .github/test_docker.sh diff --git a/build.sbt b/build.sbt index e7650d5ae..efd8dd88b 100644 --- a/build.sbt +++ b/build.sbt @@ -4,7 +4,7 @@ lazy val PLATFORM = 
Option(System.getenv("PLATFORM")).getOrElse("amd64") ThisBuild / organization := "ai.metarank" ThisBuild / scalaVersion := "2.13.10" -ThisBuild / version := "0.6.4" +ThisBuild / version := "0.7.0-M1" lazy val root = (project in file(".")) .enablePlugins(DockerPlugin) diff --git a/build_docker.sh b/build_docker.sh index 1a92d9231..4652fbf7e 100755 --- a/build_docker.sh +++ b/build_docker.sh @@ -6,8 +6,8 @@ V=$1 docker run --rm --privileged multiarch/qemu-user-static --reset -p yes -PLATFORM=amd64 VERSION=$V sbt -mem 3000 dockerBuildAndPush -PLATFORM=arm64 VERSION=$V sbt -mem 3000 dockerBuildAndPush +PLATFORM=amd64 VERSION=$V sbt -mem 5000 dockerBuildAndPush +PLATFORM=arm64 VERSION=$V sbt -mem 5000 dockerBuildAndPush docker manifest create metarank/metarank:$V metarank/metarank:$V-arm64 metarank/metarank:$V-amd64 docker manifest rm metarank/metarank:latest diff --git a/doc/configuration/recommendations.md b/doc/configuration/recommendations.md index 1a048c507..7dc66fba0 100644 --- a/doc/configuration/recommendations.md +++ b/doc/configuration/recommendations.md @@ -1,6 +1,6 @@ # Recommendations in Metarank -Starting from version `0.6.x`, Metarank supports two types of recommendations: +Starting from version `0.6.x`, Metarank supports three types of recommendations: * [Trending](recommendations/trending.md): popularity-sorted list of items with customized ordering. * [Similar items](recommendations/similar.md): matrix-factorization collaborative filtering recommender of items you may also like. -* [Semantin](recommendations/semantic.md): a content-based semantic similarity recommender, based on neural embeddings. \ No newline at end of file +* [Semantic](recommendations/semantic.md): a content-based semantic similarity recommender, based on neural embeddings. 
\ No newline at end of file diff --git a/src/it/scala/ai/metaranke2e/knn/QdrantTest.scala b/src/it/scala/ai/metaranke2e/knn/QdrantTest.scala new file mode 100644 index 000000000..fb396580e --- /dev/null +++ b/src/it/scala/ai/metaranke2e/knn/QdrantTest.scala @@ -0,0 +1,26 @@ +package ai.metaranke2e.knn + +import ai.metarank.ml.recommend.KnnConfig.QdrantConfig +import ai.metarank.ml.recommend.embedding.{EmbeddingMap, KnnIndex} +import ai.metarank.model.Identifier.ItemId +import cats.effect.unsafe.implicits.global +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class QdrantTest extends AnyFlatSpec with Matchers { + it should "write embeddings" in { + val map = EmbeddingMap( + ids = Array("1", "2", "3"), + embeddings = Array( + Array(1.0, 2.0, 3.0), + Array(1.0, 2.0, 1.0), + Array(1.0, 1.0, 0.0) + ), + rows = 3, + cols = 3 + ) + val store = KnnIndex.write(map, QdrantConfig("http://localhost:6333", "test", 3, "Cosine")).unsafeRunSync() + val response = store.lookup(List(ItemId("1")), 3).unsafeRunSync() + response.map(_.item.value) shouldBe List("2", "3") + } +} diff --git a/src/main/scala/ai/metarank/FeatureMapping.scala b/src/main/scala/ai/metarank/FeatureMapping.scala index d59765299..d3da9f6d7 100644 --- a/src/main/scala/ai/metarank/FeatureMapping.scala +++ b/src/main/scala/ai/metarank/FeatureMapping.scala @@ -26,6 +26,7 @@ import ai.metarank.model.{Dimension, FeatureSchema, FieldName, Key, MValue, Sche import ai.metarank.ml.rank.LambdaMARTRanker.{LambdaMARTConfig, LambdaMARTModel, LambdaMARTPredictor} import ai.metarank.ml.rank.NoopRanker.{NoopConfig, NoopModel, NoopPredictor} import ai.metarank.ml.rank.ShuffleRanker.{ShuffleConfig, ShuffleModel, ShufflePredictor} +import ai.metarank.ml.recommend.BertSemanticRecommender.{BertSemanticModelConfig, BertSemanticPredictor} import ai.metarank.ml.recommend.MFRecommender.MFPredictor import ai.metarank.ml.recommend.TrendingRecommender.{TrendingConfig, TrendingPredictor} import ai.metarank.ml.recommend.mf.ALSRecImpl @@ -87,10 +88,11 @@ object FeatureMapping extends Logging { } name -> LambdaMARTPredictor(name, conf, makeDatasetDescriptor(modelFeatures)) - case (name, conf: NoopConfig) => name -> NoopPredictor(name, conf) - case (name, conf: ShuffleConfig) => name -> ShufflePredictor(name, conf) - case (name, conf: TrendingConfig) => name -> TrendingPredictor(name, conf) - case (name, conf: ALSConfig) => name -> MFPredictor(name, conf, ALSRecImpl(conf)) + case (name, conf: NoopConfig) => name -> NoopPredictor(name, conf) + case (name, conf: ShuffleConfig) => name -> ShufflePredictor(name, conf) + case (name, conf: TrendingConfig) => name -> TrendingPredictor(name, conf) + case (name, conf: ALSConfig) => name -> MFPredictor(name, conf, ALSRecImpl(conf)) + case (name, conf: BertSemanticModelConfig) => name -> BertSemanticPredictor(name, conf) } new FeatureMapping( diff --git a/src/main/scala/ai/metarank/flow/TrainBuffer.scala b/src/main/scala/ai/metarank/flow/TrainBuffer.scala index 2261d7b89..52735012a 100644 --- a/src/main/scala/ai/metarank/flow/TrainBuffer.scala +++ b/src/main/scala/ai/metarank/flow/TrainBuffer.scala @@ -85,8 +85,22 @@ case class TrainBuffer( case Some(id) => IO(cache.getIfPresent(id.value)).flatMap { case None => - // ranking already gone, nothing to do - IO.unit + // ranking already gone, or it never was present + IO { + queue.add( + ClickthroughValues( + Clickthrough( + id = event.id, + ts = event.timestamp, + user = event.user, + session = event.session, + items = List(event.item), + 
interactions = List(TypedInteraction(event.item, event.`type`)) + ), + Nil + ) + ) + } // warn(s"ranking $id is present in interaction, but missing in cache") case Some(ctv) => IO { diff --git a/src/main/scala/ai/metarank/fstore/redis/RedisModelStore.scala b/src/main/scala/ai/metarank/fstore/redis/RedisModelStore.scala index 6995bfb98..088a61691 100644 --- a/src/main/scala/ai/metarank/fstore/redis/RedisModelStore.scala +++ b/src/main/scala/ai/metarank/fstore/redis/RedisModelStore.scala @@ -27,19 +27,11 @@ case class RedisModelStore(client: RedisClient, prefix: String)(implicit kc: KCo bytesOption <- client.get(kc.encode(prefix, key)) model <- bytesOption match { case None => - pred.load(None) match { - case Left(error) => IO.raiseError(error) - case Right(value) => IO(Some(value)) - } + pred.load(None).map(Some.apply) case Some(bytes) => vc.decode(bytes) match { - case Left(err) => IO.raiseError(err) - case Right(decodedBytes) => - pred.load(Some(decodedBytes)) match { - case Left(error) => IO.raiseError(error) - case Right(value) => IO(Some(value)) - } - + case Left(err) => IO.raiseError(err) + case Right(decodedBytes) => pred.load(Some(decodedBytes)).map(Some.apply) } } } yield { diff --git a/src/main/scala/ai/metarank/ml/Predictor.scala b/src/main/scala/ai/metarank/ml/Predictor.scala index 9cb655e40..94204c179 100644 --- a/src/main/scala/ai/metarank/ml/Predictor.scala +++ b/src/main/scala/ai/metarank/ml/Predictor.scala @@ -11,7 +11,7 @@ sealed trait Predictor[C <: ModelConfig, T <: Context, M <: Model[T]] { def config: C def name: String def fit(data: fs2.Stream[IO, TrainValues]): IO[M] - def load(bytes: Option[Array[Byte]]): Either[Throwable, M] + def load(bytes: Option[Array[Byte]]): IO[M] } object Predictor { diff --git a/src/main/scala/ai/metarank/ml/rank/LambdaMARTRanker.scala b/src/main/scala/ai/metarank/ml/rank/LambdaMARTRanker.scala index 9d2603d25..fcf189cc7 100644 --- a/src/main/scala/ai/metarank/ml/rank/LambdaMARTRanker.scala +++ b/src/main/scala/ai/metarank/ml/rank/LambdaMARTRanker.scala @@ -93,7 +93,9 @@ object LambdaMARTRanker { } } - override def load(bytes: Option[Array[Byte]]): Either[Throwable, LambdaMARTModel] = bytes match { + override def load(bytes: Option[Array[Byte]]): IO[LambdaMARTModel] = IO.fromEither(loadSync(bytes)) + + def loadSync(bytes: Option[Array[Byte]]): Either[Throwable, LambdaMARTModel] = bytes match { case None => Left(new Exception(s"cannot load model: not found, maybe you forgot to run train?")) case Some(blob) => val stream = new DataInputStream(new ByteArrayInputStream(blob)) diff --git a/src/main/scala/ai/metarank/ml/rank/NoopRanker.scala b/src/main/scala/ai/metarank/ml/rank/NoopRanker.scala index 846443abe..3a1e3ec2c 100644 --- a/src/main/scala/ai/metarank/ml/rank/NoopRanker.scala +++ b/src/main/scala/ai/metarank/ml/rank/NoopRanker.scala @@ -13,7 +13,7 @@ import io.circe.{Decoder, Encoder} object NoopRanker { case class NoopConfig(selector: Selector = AcceptSelector()) extends ModelConfig case class NoopPredictor(name: String, config: NoopConfig) extends RankPredictor[NoopConfig, NoopModel] { - override def load(bytes: Option[Array[Byte]]): Either[Throwable, NoopModel] = Right(NoopModel(name, config)) + override def load(bytes: Option[Array[Byte]]): IO[NoopModel] = IO.pure(NoopModel(name, config)) override def fit(data: fs2.Stream[IO, TrainValues]): IO[NoopModel] = IO.pure(NoopModel(name, config)) diff --git a/src/main/scala/ai/metarank/ml/rank/ShuffleRanker.scala b/src/main/scala/ai/metarank/ml/rank/ShuffleRanker.scala index 
aaf67643b..207ebb5d3 100644 --- a/src/main/scala/ai/metarank/ml/rank/ShuffleRanker.scala +++ b/src/main/scala/ai/metarank/ml/rank/ShuffleRanker.scala @@ -16,8 +16,7 @@ object ShuffleRanker { case class ShuffleConfig(maxPositionChange: Int, selector: Selector = AcceptSelector()) extends ModelConfig case class ShufflePredictor(name: String, config: ShuffleConfig) extends RankPredictor[ShuffleConfig, ShuffleModel] { - override def load(bytes: Option[Array[Byte]]): Either[Throwable, ShuffleModel] = - Right(ShuffleModel(name, config)) + override def load(bytes: Option[Array[Byte]]): IO[ShuffleModel] = IO.pure(ShuffleModel(name, config)) override def fit(data: fs2.Stream[IO, TrainValues]): IO[ShuffleModel] = IO.pure(ShuffleModel(name, config)) } diff --git a/src/main/scala/ai/metarank/ml/recommend/BertSemanticRecommender.scala b/src/main/scala/ai/metarank/ml/recommend/BertSemanticRecommender.scala index 1dbdce375..1c2ecbac0 100644 --- a/src/main/scala/ai/metarank/ml/recommend/BertSemanticRecommender.scala +++ b/src/main/scala/ai/metarank/ml/recommend/BertSemanticRecommender.scala @@ -4,8 +4,10 @@ import ai.metarank.config.Selector.AcceptSelector import ai.metarank.config.{ModelConfig, Selector} import ai.metarank.ml.Predictor.RecommendPredictor import ai.metarank.ml.onnx.SBERT +import ai.metarank.ml.recommend.KnnConfig.HnswConfig import ai.metarank.ml.recommend.MFRecommender.EmbeddingSimilarityModel -import ai.metarank.ml.recommend.embedding.{EmbeddingMap, HnswJavaIndex} +import ai.metarank.ml.recommend.embedding.HnswJavaIndex.{HnswIndexReader, HnswIndexWriter, HnswOptions} +import ai.metarank.ml.recommend.embedding.{EmbeddingMap, HnswJavaIndex, KnnIndex} import ai.metarank.model.Field.{StringField, StringListField} import ai.metarank.model.Identifier.ItemId import ai.metarank.model.{FieldName, TrainValues} @@ -26,14 +28,15 @@ object BertSemanticRecommender { items <- data.collect { case item: ItemValues => item }.compile.toList _ <- info(s"Loaded ${items.size} items") embeddings <- embed(items, fieldSet, encoder) - index <- IO(HnswJavaIndex.create(embeddings, config.m, config.ef)) + index <- KnnIndex.write(embeddings, config.store) } yield { EmbeddingSimilarityModel(name, index) } - override def load(bytes: Option[Array[Byte]]): Either[Throwable, EmbeddingSimilarityModel] = bytes match { - case Some(value) => Right(EmbeddingSimilarityModel(name, HnswJavaIndex.load(value))) - case None => Left(new Exception(s"cannot load index $name: not found")) + override def load(bytes: Option[Array[Byte]]): IO[EmbeddingSimilarityModel] = bytes match { + case Some(value) => + KnnIndex.load(value, config.store).map(index => EmbeddingSimilarityModel(name, index)) + case None => IO.raiseError(new Exception(s"cannot load index $name: not found")) } def embed(items: List[ItemValues], fieldSet: Set[String], encoder: Encoder): IO[EmbeddingMap] = IO { @@ -74,8 +77,7 @@ object BertSemanticRecommender { case class BertSemanticModelConfig( encoder: EncoderType, itemFields: List[String], - m: Int = 32, - ef: Int = 200, + store: KnnConfig, selector: Selector = Selector.AcceptSelector() ) extends ModelConfig @@ -100,15 +102,13 @@ object BertSemanticRecommender { for { encoder <- c.downField("encoder").as[EncoderType] itemFields <- c.downField("itemFields").as[List[String]] - m <- c.downField("m").as[Option[Int]] - ef <- c.downField("ef").as[Option[Int]] + store <- c.downField("store").as[Option[KnnConfig]] selector <- c.downField("selector").as[Option[Selector]] } yield { BertSemanticModelConfig( encoder = encoder, 
itemFields = itemFields, - m = m.getOrElse(32), - ef = ef.getOrElse(200), + store = store.getOrElse(HnswConfig()), selector = selector.getOrElse(AcceptSelector()) ) } @@ -141,12 +141,14 @@ object BertSemanticRecommender { } } - object CsvEncoder { + object CsvEncoder extends Logging { def create(lines: fs2.Stream[IO, String]) = for { dic <- lines + .filter(_.nonEmpty) .evalMap(line => IO.fromEither(parseLine(line))) .compile .toList + _ <- info(s"loaded ${dic.size} embeddings") size <- IO(dic.map(_._2.length).distinct).flatMap { case Nil => IO.raiseError(new Exception("no embeddings found")) case one :: Nil => IO.pure(one) diff --git a/src/main/scala/ai/metarank/ml/recommend/KnnConfig.scala b/src/main/scala/ai/metarank/ml/recommend/KnnConfig.scala new file mode 100644 index 000000000..46b150834 --- /dev/null +++ b/src/main/scala/ai/metarank/ml/recommend/KnnConfig.scala @@ -0,0 +1,40 @@ +package ai.metarank.ml.recommend + +import io.circe.{Decoder, DecodingFailure, Encoder, Json, JsonObject} +import io.circe.generic.semiauto.{deriveDecoder, deriveEncoder} + +sealed trait KnnConfig + +object KnnConfig { + case class HnswConfig(m: Int = 32, ef: Int = 200) extends KnnConfig + case class QdrantConfig(endpoint: String, collection: String, size: Int, distance: String) extends KnnConfig + + implicit val hnswDecoder: Decoder[HnswConfig] = Decoder.instance(c => + for { + m <- c.downField("m").as[Option[Int]] + ef <- c.downField("ef").as[Option[Int]] + } yield { + HnswConfig(m.getOrElse(32), ef.getOrElse(200)) + } + ) + implicit val hnswEncoder: Encoder[HnswConfig] = deriveEncoder + + implicit val qdrantDecoder: Decoder[QdrantConfig] = deriveDecoder + implicit val qdrantEncoder: Encoder[QdrantConfig] = deriveEncoder + + implicit val knnDecoder: Decoder[KnnConfig] = Decoder.instance(c => + c.downField("type").as[String] match { + case Left(value) => Left(value) + case Right("hnsw") => hnswDecoder.tryDecode(c) + case Right("qdrant") => qdrantDecoder.tryDecode(c) + case Right(other) => Left(DecodingFailure(s"knn index type '$other' is not supported", c.history)) + } + ) + + implicit val knnEncoder: Encoder[KnnConfig] = Encoder.instance { + case c: HnswConfig => hnswEncoder(c).deepMerge(withType("hnsw")) + case c: QdrantConfig => qdrantEncoder(c).deepMerge(withType("qdrant")) + } + + def withType(tpe: String) = Json.fromJsonObject(JsonObject.fromIterable(List("type" -> Json.fromString(tpe)))) +} diff --git a/src/main/scala/ai/metarank/ml/recommend/MFRecommender.scala b/src/main/scala/ai/metarank/ml/recommend/MFRecommender.scala index b6437241a..eff307b67 100644 --- a/src/main/scala/ai/metarank/ml/recommend/MFRecommender.scala +++ b/src/main/scala/ai/metarank/ml/recommend/MFRecommender.scala @@ -3,7 +3,8 @@ package ai.metarank.ml.recommend import ai.metarank.ml.Model import ai.metarank.ml.Model.{RecommendModel, Response} import ai.metarank.ml.Predictor.RecommendPredictor -import ai.metarank.ml.recommend.embedding.{EmbeddingMap, HnswJavaIndex} +import ai.metarank.ml.recommend.embedding.KnnIndex.KnnIndexReader +import ai.metarank.ml.recommend.embedding.{EmbeddingMap, HnswJavaIndex, KnnIndex} import ai.metarank.ml.recommend.mf.MFRecImpl import ai.metarank.ml.recommend.mf.MFRecImpl.MFModelConfig import ai.metarank.model.Clickthrough.TypedInteraction @@ -28,16 +29,16 @@ object MFRecommender { _ <- info(s"writing training dataset to $file") _ <- writeUIRT(data, file) embeddings <- IO(mf.train(file)) - index <- IO(HnswJavaIndex.create(embeddings, config.m, config.ef)) + index <- KnnIndex.write(embeddings, 
config.store) } yield { EmbeddingSimilarityModel(name, index) } ) } - override def load(bytes: Option[Array[Byte]]): Either[Throwable, EmbeddingSimilarityModel] = bytes match { - case Some(value) => Right(EmbeddingSimilarityModel(name, HnswJavaIndex.load(value))) - case None => Left(new Exception(s"cannot load index $name: not found")) + override def load(bytes: Option[Array[Byte]]): IO[EmbeddingSimilarityModel] = bytes match { + case Some(value) => KnnIndex.load(value, config.store).map(index => EmbeddingSimilarityModel(name, index)) + case None => IO.raiseError(new Exception(s"cannot load index $name: not found")) } def writeUIRT(source: fs2.Stream[IO, TrainValues], dest: Path): IO[Unit] = { @@ -62,14 +63,15 @@ object MFRecommender { } } - case class EmbeddingSimilarityModel(name: String, index: HnswJavaIndex) extends RecommendModel { + case class EmbeddingSimilarityModel(name: String, index: KnnIndexReader) extends RecommendModel { override def predict(request: RecommendRequest): IO[Model.Response] = for { _ <- request.items match { case Nil => IO.raiseError(new Exception("similar items recommender requires request.items to be non-empty")) case _ => IO.unit } - response <- IO(index.lookup(request.items, request.count)) - items <- IO.fromOption(NonEmptyList.fromList(response))(new Exception("empty response from the recommender")) + response <- index.lookup(request.items, request.count + request.items.size) + filtered <- IO(response.filterNot(is => request.items.contains(is.item)).take(request.count)) + items <- IO.fromOption(NonEmptyList.fromList(filtered))(new Exception("empty response from the recommender")) } yield { Response(items) } diff --git a/src/main/scala/ai/metarank/ml/recommend/RandomRecommender.scala b/src/main/scala/ai/metarank/ml/recommend/RandomRecommender.scala index 2740f5b50..1014f1614 100644 --- a/src/main/scala/ai/metarank/ml/recommend/RandomRecommender.scala +++ b/src/main/scala/ai/metarank/ml/recommend/RandomRecommender.scala @@ -60,7 +60,8 @@ object RandomRecommender { case class RandomPredictor(name: String, config: RandomConfig) extends RecommendPredictor[RandomConfig, RandomModel] with Logging { - override def load(bytes: Option[Array[Byte]]): Either[Throwable, RandomModel] = bytes match { + override def load(bytes: Option[Array[Byte]]): IO[RandomModel] = IO.fromEither(loadSync(bytes)) + def loadSync(bytes: Option[Array[Byte]]): Either[Throwable, RandomModel] = bytes match { case None => Left(new Exception("Cannot load model from store: not found. 
Did you train it before?")) case Some(bytes) => val stream = new DataInputStream(new ByteArrayInputStream(bytes)) diff --git a/src/main/scala/ai/metarank/ml/recommend/TrendingRecommender.scala b/src/main/scala/ai/metarank/ml/recommend/TrendingRecommender.scala index 8ca65f50a..bf0b7b858 100644 --- a/src/main/scala/ai/metarank/ml/recommend/TrendingRecommender.scala +++ b/src/main/scala/ai/metarank/ml/recommend/TrendingRecommender.scala @@ -86,7 +86,8 @@ object TrendingRecommender { } - override def load(bytesOption: Option[Array[Byte]]): Either[Throwable, TrendingModel] = { + override def load(bytes: Option[Array[Byte]]): IO[TrendingModel] = IO.fromEither(loadSync(bytes)) + def loadSync(bytesOption: Option[Array[Byte]]): Either[Throwable, TrendingModel] = { bytesOption match { case None => Left(new Exception("cannot load trending model: not found")) case Some(bytes) => diff --git a/src/main/scala/ai/metarank/ml/recommend/embedding/HnswJavaIndex.scala b/src/main/scala/ai/metarank/ml/recommend/embedding/HnswJavaIndex.scala index f41a86d5d..abf1f8343 100644 --- a/src/main/scala/ai/metarank/ml/recommend/embedding/HnswJavaIndex.scala +++ b/src/main/scala/ai/metarank/ml/recommend/embedding/HnswJavaIndex.scala @@ -1,8 +1,10 @@ package ai.metarank.ml.recommend.embedding import ai.metarank.ml.Model.ItemScore +import ai.metarank.ml.recommend.embedding.KnnIndex.{KnnIndexReader, KnnIndexWriter} import ai.metarank.model.Identifier.ItemId import ai.metarank.util.Logging +import cats.effect.IO import com.github.jelmerk.knn.{DistanceFunctions, Index, Item, ProgressListener, SearchResult} import com.github.jelmerk.knn.hnsw.HnswIndex @@ -10,70 +12,85 @@ import scala.jdk.CollectionConverters._ import scala.jdk.OptionConverters._ import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.util +import fs2.{Chunk, Stream} -case class HnswJavaIndex(index: HnswIndex[String, Array[Double], Embedding, java.lang.Double]) { - def lookup(items: List[ItemId], n: Int): List[ItemScore] = items match { - case head :: Nil => - makeResponse(index.findNeighbors(head.value, n)) - case Nil => - Nil - case _ => - val embeddings = items.flatMap(item => index.get(item.value).toScala.map(_.vector)).toArray - val center = centroid(embeddings) - makeResponse(index.findNearest(center, n)) +object HnswJavaIndex extends Logging { + class LoggerListener extends ProgressListener { + override def updateProgress(workDone: Int, max: Int): Unit = + logger.info(s"indexed $workDone of $max items: ${math.round(100.0 * workDone / max)}%") } - def centroid(items: Array[Array[Double]]): Array[Double] = { - val result = new Array[Double](index.getDimensions) - var i = 0 - while (i < result.length) { - var sum = 0.0 - var itemIndex = 0 - while (itemIndex < items.length) { - sum += items(itemIndex)(i) - itemIndex += 1 + case class HnswIndexReader(index: HnswIndex[String, Array[Double], Embedding, java.lang.Double]) + extends KnnIndexReader { + override def lookup(items: List[ItemId], n: Int): IO[List[ItemScore]] = IO { + items match { + case Nil => Nil + case head :: Nil => + index.get(head.value).toScala match { + case Some(value) => lookupOne(value.vector, n) + case None => Nil + } + case _ => + val embeddings = items.flatMap(item => index.get(item.value).toScala.map(_.vector)).toArray + val center = centroid(embeddings) + lookupOne(center, n) } - result(i) = sum / items.length - i += 1 } - result - } - def makeResponse(result: util.List[SearchResult[Embedding, java.lang.Double]]): List[ItemScore] = - result.asScala.toList.map(sr => 
ItemScore(ItemId(sr.item().id), sr.distance())) + def centroid(items: Array[Array[Double]]): Array[Double] = { + val result = new Array[Double](index.getDimensions) + var i = 0 + while (i < result.length) { + var sum = 0.0 + var itemIndex = 0 + while (itemIndex < items.length) { + sum += items(itemIndex)(i) + itemIndex += 1 + } + result(i) = sum / items.length + i += 1 + } + result + } - def save(): Array[Byte] = { - val stream = new ByteArrayOutputStream() - index.save(stream) - stream.toByteArray - } -} + def lookupOne(vector: Array[Double], n: Int): List[ItemScore] = { + val result = index.findNearest(vector, n) + result.asScala.toList.map(sr => ItemScore(ItemId(sr.item().id), sr.distance())) + } -object HnswJavaIndex extends Logging { - class LoggerListener extends ProgressListener { - override def updateProgress(workDone: Int, max: Int): Unit = - logger.info(s"indexed $workDone of $max items: ${math.round(100.0 * workDone / max)}%") + override def save(): Array[Byte] = { + val stream = new ByteArrayOutputStream() + index.save(stream) + stream.toByteArray + } } - def create(source: EmbeddingMap, m: Int, ef: Int): HnswJavaIndex = { - val builder = HnswIndex - .newBuilder(source.cols, DistanceFunctions.DOUBLE_COSINE_DISTANCE, source.rows) - .withM(m) - .withEf(ef) - .withEfConstruction(ef) - .build[String, Embedding]() - - val embeddings = for { - (embedding, id) <- source.embeddings.zip(source.ids) + object HnswIndexWriter extends KnnIndexWriter[HnswIndexReader, HnswOptions] { + override def write(source: EmbeddingMap, options: HnswOptions): IO[HnswIndexReader] = for { + index <- IO( + HnswIndex + .newBuilder(source.cols, DistanceFunctions.DOUBLE_COSINE_DISTANCE, source.rows) + .withM(options.m) + .withEf(options.ef) + .withEfConstruction(options.ef) + .build[String, Embedding]() + ) + emb <- Stream + .chunk[IO, Array[Double]](Chunk.array(source.embeddings)) + .zip(Stream.chunk[IO, String](Chunk.array(source.ids))) + .map { case (embedding, id) => Embedding(id, embedding) } + .compile + .toList + _ <- IO(index.addAll(util.Arrays.asList(emb: _*), 1, new LoggerListener(), 256)) } yield { - Embedding(id, embedding) + HnswIndexReader(index) } - builder.addAll(util.Arrays.asList(embeddings: _*), 1, new LoggerListener(), 256) - new HnswJavaIndex(builder) - } - def load(bytes: Array[Byte]): HnswJavaIndex = { - val index = HnswIndex.load[String, Array[Double], Embedding, java.lang.Double](new ByteArrayInputStream(bytes)) - new HnswJavaIndex(index) + override def load(bytes: Array[Byte], options: HnswOptions): IO[HnswIndexReader] = IO { + val index = HnswIndex.load[String, Array[Double], Embedding, java.lang.Double](new ByteArrayInputStream(bytes)) + HnswIndexReader(index) + } } + + case class HnswOptions(m: Int, ef: Int) } diff --git a/src/main/scala/ai/metarank/ml/recommend/embedding/KnnIndex.scala b/src/main/scala/ai/metarank/ml/recommend/embedding/KnnIndex.scala new file mode 100644 index 000000000..6709bc075 --- /dev/null +++ b/src/main/scala/ai/metarank/ml/recommend/embedding/KnnIndex.scala @@ -0,0 +1,32 @@ +package ai.metarank.ml.recommend.embedding + +import ai.metarank.ml.Model.ItemScore +import ai.metarank.ml.recommend.KnnConfig +import ai.metarank.ml.recommend.embedding.HnswJavaIndex.{HnswIndexWriter, HnswOptions} +import ai.metarank.ml.recommend.embedding.QdrantIndex.{QdrantIndexWriter, QdrantOptions} +import ai.metarank.model.Identifier.ItemId +import cats.effect.IO + +object KnnIndex { + trait KnnIndexReader { + def lookup(items: List[ItemId], n: Int): IO[List[ItemScore]] + def 
save(): Array[Byte] + } + + trait KnnIndexWriter[R <: KnnIndexReader, O] { + def load(bytes: Array[Byte], options: O): IO[R] + def write(embeddings: EmbeddingMap, options: O): IO[R] + } + + def write(source: EmbeddingMap, config: KnnConfig): IO[KnnIndexReader] = config match { + case KnnConfig.HnswConfig(m, ef) => HnswIndexWriter.write(source, HnswOptions(m, ef)) + case KnnConfig.QdrantConfig(endpoint, collection, dim, dist) => + QdrantIndexWriter.write(source, QdrantOptions(endpoint, collection, dim, dist)) + } + + def load(bytes: Array[Byte], config: KnnConfig): IO[KnnIndexReader] = config match { + case KnnConfig.HnswConfig(m, ef) => HnswIndexWriter.load(bytes, HnswOptions(m, ef)) + case KnnConfig.QdrantConfig(endpoint, collection, dim, dist) => + QdrantIndexWriter.load(bytes, QdrantOptions(endpoint, collection, dim, dist)) + } +} diff --git a/src/main/scala/ai/metarank/ml/recommend/embedding/QdrantIndex.scala b/src/main/scala/ai/metarank/ml/recommend/embedding/QdrantIndex.scala new file mode 100644 index 000000000..af8055ea0 --- /dev/null +++ b/src/main/scala/ai/metarank/ml/recommend/embedding/QdrantIndex.scala @@ -0,0 +1,159 @@ +package ai.metarank.ml.recommend.embedding + +import ai.metarank.ml.Model +import ai.metarank.ml.Model.ItemScore +import ai.metarank.ml.recommend.embedding.KnnIndex.{KnnIndexReader, KnnIndexWriter} +import ai.metarank.model.Identifier +import ai.metarank.model.Identifier.ItemId +import ai.metarank.util.Logging +import cats.effect.IO +import fs2.Chunk +import io.circe.Codec +import io.circe.generic.semiauto.deriveCodec +import org.http4s.{EntityDecoder, EntityEncoder, Method, Request, Uri} +import org.http4s.blaze.client.BlazeClientBuilder +import org.http4s.circe._ +import org.http4s.client.Client +import io.circe.syntax._ +import io.lettuce.core.output.ByteArrayOutput + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.util.UUID +import scala.concurrent.duration._ + +object QdrantIndex extends Logging { + val BITSTREAM_VERSION = 1 + + case class QdrantIndexReader(endpoint: Uri, client: Client[IO], ids: Map[String, UUID], idrev: Map[UUID, String]) + extends KnnIndexReader { + override def lookup(items: List[Identifier.ItemId], n: Int): IO[List[Model.ItemScore]] = { + val request = SearchRequest(items.flatMap(x => ids.get(x.value)), n) + client + .expect[SearchResponse]( + Request[IO](uri = endpoint / "points" / "recommend", method = Method.POST).withEntity(request) + ) + .flatTap(resp => info(s"request: ${request}") *> info(s"response: ${resp}")) + .map(response => response.result.flatMap(is => idrev.get(is.id).map(id => ItemScore(ItemId(id), is.score)))) + } + + override def save(): Array[Byte] = { + val out = new ByteArrayOutputStream() + val stream = new DataOutputStream(out) + stream.writeByte(BITSTREAM_VERSION) + stream.writeInt(ids.size) + ids.foreach { case (key, value) => + stream.writeUTF(key) + stream.writeLong(value.getLeastSignificantBits) + stream.writeLong(value.getMostSignificantBits) + } + out.toByteArray + } + } + + object QdrantIndexWriter extends KnnIndexWriter[QdrantIndexReader, QdrantOptions] { + override def load(bytes: Array[Byte], options: QdrantOptions): IO[QdrantIndexReader] = for { + clientTuple <- makeClient().allocated + (client, close) = clientTuple + ids <- loadIdMap(bytes) + endpoint <- IO.fromEither(Uri.fromString(options.endpoint)) + } yield { + QdrantIndexReader(endpoint / "collections" / options.collection, client, ids, ids.map(kv => kv._2 -> kv._1)) + } + + private 
def loadIdMap(bytes: Array[Byte]): IO[Map[String, UUID]] = IO { + val in = new ByteArrayInputStream(bytes) + val stream = new DataInputStream(in) + val version = stream.readByte() + val size = stream.readInt() + val pairs = (0 until size).map(_ => { + val key = stream.readUTF() + val least = stream.readLong() + val most = stream.readLong() + key -> new UUID(most, least) + }) + pairs.toMap + } + + override def write(embeddings: EmbeddingMap, options: QdrantOptions): IO[QdrantIndexReader] = for { + clientTuple <- makeClient().allocated + (client, _) = clientTuple + endpoint <- IO.fromEither(Uri.fromString(options.endpoint)) + uri = endpoint / "collections" / options.collection + exists <- collectionExists(client, uri) + _ <- IO.whenA(!exists)(createCollection(client, uri, options.dim, options.distance).void) + ids <- fs2.Stream + .chunk[IO, String](Chunk.array(embeddings.ids)) + .zip(fs2.Stream.chunk(Chunk.array(embeddings.embeddings))) + .map(x => Embedding(x._1, x._2)) + .groupWithin(128, 1.second) + .evalMap(batch => putBatch(client, uri, batch.toList)) + .flatMap(list => fs2.Stream(list: _*)) + .compile + .toList + _ <- info(s"uploaded ${ids.size} vectors") + } yield { + QdrantIndexReader(uri, client, ids.toMap, ids.map(x => x._2 -> x._1).toMap) + } + + def makeClient() = BlazeClientBuilder[IO] + .withRequestTimeout(10.second) + .withConnectTimeout(10.second) + .resource + + def createCollection(client: Client[IO], uri: Uri, dim: Int, dist: String): IO[QdrantResponse] = { + info(s"creating collection $uri") *> client.expect[QdrantResponse]( + Request[IO](method = Method.PUT, uri = uri.withQueryParam("wait", "true")) + .withEntity(CreateCollectionRequest(CreateCollectionVectors(dim, dist))) + ) + } + + def collectionExists(client: Client[IO], collection: Uri): IO[Boolean] = + client.get(collection)(response => + response.status.code match { + case 200 => info(s"collection $collection exists") *> IO(true) + case other => info(s"collection $collection is missing, status=$other") *> IO(false) + } + ) + + def putBatch(client: Client[IO], collection: Uri, batch: List[Embedding]): IO[List[(String, UUID)]] = + client + .expect[QdrantResponse]( + Request[IO](method = Method.PUT, uri = (collection / "points").withQueryParam("wait", "true")) + .withEntity( + QdrantPointsRequest(batch.map(e => QdrantPoint(UUID.nameUUIDFromBytes(e.id.getBytes()), e.vector))) + ) + ) + .map(r => batch.map(e => e.id -> UUID.nameUUIDFromBytes(e.id.getBytes()))) + .flatTap(batch => info(s"wrote batch of ${batch.size} items")) + + } + + case class CreateCollectionRequest(vectors: CreateCollectionVectors) + case class CreateCollectionVectors(size: Int, distance: String) + implicit val vectorsCodec: Codec[CreateCollectionVectors] = deriveCodec + implicit val createCodec: Codec[CreateCollectionRequest] = deriveCodec + implicit val createJson: EntityEncoder[IO, CreateCollectionRequest] = jsonEncoderOf[CreateCollectionRequest] + + case class QdrantOptions(endpoint: String, collection: String, dim: Int, distance: String) + + case class QdrantResponse(status: String) + implicit val responseCodec: Codec[QdrantResponse] = deriveCodec + implicit val responseJson: EntityDecoder[IO, QdrantResponse] = jsonOf[IO, QdrantResponse] + + case class QdrantPointsRequest(points: List[QdrantPoint]) + case class QdrantPoint(id: UUID, vector: Array[Double]) + implicit val pointCodec: Codec[QdrantPoint] = deriveCodec + implicit val pointRequestCodec: Codec[QdrantPointsRequest] = deriveCodec + + implicit val pointRequestJson: EntityEncoder[IO, 
QdrantPointsRequest] = jsonEncoderOf[QdrantPointsRequest] + + case class SearchRequest(positive: List[UUID], limit: Int) + implicit val searchCodec: Codec[SearchRequest] = deriveCodec + implicit val searchJson: EntityEncoder[IO, SearchRequest] = jsonEncoderOf[SearchRequest] + + case class SearchResponse(result: List[IdScore], status: String) + case class IdScore(id: UUID, score: Double) + implicit val idscoreCodec: Codec[IdScore] = deriveCodec + implicit val searchResponseCodec: Codec[SearchResponse] = deriveCodec + implicit val searchResponseJson: EntityDecoder[IO, SearchResponse] = jsonOf[IO, SearchResponse] +} diff --git a/src/main/scala/ai/metarank/ml/recommend/mf/ALSRecImpl.scala b/src/main/scala/ai/metarank/ml/recommend/mf/ALSRecImpl.scala index 62ecbed00..3042b34e4 100644 --- a/src/main/scala/ai/metarank/ml/recommend/mf/ALSRecImpl.scala +++ b/src/main/scala/ai/metarank/ml/recommend/mf/ALSRecImpl.scala @@ -1,6 +1,8 @@ package ai.metarank.ml.recommend.mf import ai.metarank.config.{ModelConfig, Selector} +import ai.metarank.ml.recommend.KnnConfig +import ai.metarank.ml.recommend.KnnConfig.HnswConfig import ai.metarank.ml.recommend.embedding.EmbeddingMap import ai.metarank.ml.recommend.mf.ALSRecImpl.{ALSConfig, EALSRecommenderWrapper} import ai.metarank.ml.recommend.mf.MFRecImpl.MFModelConfig @@ -47,8 +49,7 @@ object ALSRecImpl { factors: Int = 100, userReg: Float = 0.01f, itemReg: Float = 0.01f, - m: Int = 32, - ef: Int = 200, + store: KnnConfig = HnswConfig(), selector: Selector = Selector.AcceptSelector() ) extends MFModelConfig @@ -63,8 +64,7 @@ object ALSRecImpl { factors <- c.downField("factors").as[Option[Int]] userReg <- c.downField("userReg").as[Option[Float]] itemReg <- c.downField("itemRef").as[Option[Float]] - m <- c.downField("m").as[Option[Int]] - ef <- c.downField("ef").as[Option[Int]] + store <- c.downField("store").as[Option[KnnConfig]] selector <- c.downField("selector").as[Option[Selector]] } yield { val d = ALSConfig() @@ -74,8 +74,7 @@ object ALSRecImpl { factors = factors.getOrElse(d.factors), userReg = userReg.getOrElse(d.userReg), itemReg = itemReg.getOrElse(d.itemReg), - m = m.getOrElse(d.m), - ef = ef.getOrElse(d.ef), + store = store.getOrElse(HnswConfig()), selector = selector.getOrElse(Selector.AcceptSelector()) ) } diff --git a/src/main/scala/ai/metarank/ml/recommend/mf/MFRecImpl.scala b/src/main/scala/ai/metarank/ml/recommend/mf/MFRecImpl.scala index 0f0d59b85..2979a2f20 100644 --- a/src/main/scala/ai/metarank/ml/recommend/mf/MFRecImpl.scala +++ b/src/main/scala/ai/metarank/ml/recommend/mf/MFRecImpl.scala @@ -1,6 +1,7 @@ package ai.metarank.ml.recommend.mf import ai.metarank.config.{ModelConfig, Selector} +import ai.metarank.ml.recommend.KnnConfig import ai.metarank.ml.recommend.embedding.EmbeddingMap import fs2.io.file.Path @@ -15,8 +16,7 @@ object MFRecImpl { def factors: Int def userReg: Float def itemReg: Float - def m: Int - def ef: Int + def store: KnnConfig def selector: Selector } } diff --git a/src/main/scala/ai/metarank/model/AnalyticsPayload.scala b/src/main/scala/ai/metarank/model/AnalyticsPayload.scala index f5a290344..249a9ced5 100644 --- a/src/main/scala/ai/metarank/model/AnalyticsPayload.scala +++ b/src/main/scala/ai/metarank/model/AnalyticsPayload.scala @@ -35,6 +35,7 @@ import ai.metarank.main.CliArgs.{ import ai.metarank.ml.rank.LambdaMARTRanker.LambdaMARTConfig import ai.metarank.ml.rank.NoopRanker.NoopConfig import ai.metarank.ml.rank.ShuffleRanker.ShuffleConfig +import 
ai.metarank.ml.recommend.BertSemanticRecommender.BertSemanticModelConfig import ai.metarank.ml.recommend.TrendingRecommender.TrendingConfig import ai.metarank.ml.recommend.mf.ALSRecImpl.ALSConfig import ai.metarank.model.AnalyticsPayload.{SystemParams, UsedFeature} @@ -116,8 +117,9 @@ object AnalyticsPayload { case NoopConfig(_) => "noop" case TrendingConfig(_, _) => "trending" case _: ALSConfig => "als" + case _: BertSemanticModelConfig => "bert" }.toList, - usedFeatures = config.features.toList.map { + usedFeatures = config.features.map { case f: RateFeatureSchema => UsedFeature(f.name, "rate") case f: BooleanFeatureSchema => UsedFeature(f.name, "boolean") case f: FieldMatchSchema => UsedFeature(f.name, "field_match") diff --git a/src/test/resources/ranklens/events/events.jsonl.gz b/src/test/resources/ranklens/events/events.jsonl.gz index e7f6ec518..746badd82 100644 --- a/src/test/resources/ranklens/events/events.jsonl.gz +++ b/src/test/resources/ranklens/events/events.jsonl.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2a43f8799fb9ea224b0b196b2b9361f40f59ef4744382f9629b9c786199c38da -size 3752998 +oid sha256:d4009c698bcc23f0648107dc167f2a3c797ebcb1737e68ec727c7057558a99aa +size 4023025 diff --git a/src/test/scala/ai/metarank/fstore/codec/impl/TrainValuesCodecTest.scala b/src/test/scala/ai/metarank/fstore/codec/impl/TrainValuesCodecTest.scala index 87b85a4f2..cf09f98e4 100644 --- a/src/test/scala/ai/metarank/fstore/codec/impl/TrainValuesCodecTest.scala +++ b/src/test/scala/ai/metarank/fstore/codec/impl/TrainValuesCodecTest.scala @@ -60,7 +60,7 @@ class TrainValuesCodecTest extends AnyFlatSpec with Matchers { it should "roundtrip ctv" in { val out = new ByteArrayOutputStream() TrainValuesCodec.write(ctv, new DataOutputStream(out)) - val temp = File("/tmp/ctv.bin").writeByteArray(out.toByteArray) + // val temp = File("/tmp/ctv.bin").writeByteArray(out.toByteArray) val decoded = TrainValuesCodec.read(new DataInputStream(new ByteArrayInputStream(out.toByteArray))) decoded shouldBe ctv } diff --git a/src/test/scala/ai/metarank/ml/PredictorSuite.scala b/src/test/scala/ai/metarank/ml/PredictorSuite.scala index e1b25b249..d7e67fef2 100644 --- a/src/test/scala/ai/metarank/ml/PredictorSuite.scala +++ b/src/test/scala/ai/metarank/ml/PredictorSuite.scala @@ -29,7 +29,7 @@ trait PredictorSuite[C <: ModelConfig, T <: Context, M <: Model[T]] extends AnyF val rec = predictor.fit(fs2.Stream.apply(cts: _*)).unsafeRunSync() val bytes = rec.save() val restore = predictor.load(bytes) - val req = restore.map(_.predict(request(10)).unsafeRunSync()).toOption.get + val req = restore.flatMap(_.predict(request(10))).unsafeRunSync() req.items.toList shouldNot be(empty) } diff --git a/src/test/scala/ai/metarank/ml/rank/LambdaMARTRankerTest.scala b/src/test/scala/ai/metarank/ml/rank/LambdaMARTRankerTest.scala index 4399e4a9d..5ff20c5d8 100644 --- a/src/test/scala/ai/metarank/ml/rank/LambdaMARTRankerTest.scala +++ b/src/test/scala/ai/metarank/ml/rank/LambdaMARTRankerTest.scala @@ -57,7 +57,7 @@ class LambdaMARTRankerTest extends PredictorSuite[LambdaMARTConfig, QueryRequest val pred2 = LambdaMARTPredictor("foo", conf, desc) val model = predictor.fit(fs2.Stream(cts: _*)).unsafeRunSync() val blob = model.save() - val result = pred2.load(blob) - result.isLeft shouldBe true + val result = Try(pred2.load(blob).unsafeRunSync()) + result.isSuccess shouldBe false } } diff --git a/src/test/scala/ai/metarank/ml/recommend/BertSemanticRecommenderTest.scala 
b/src/test/scala/ai/metarank/ml/recommend/BertSemanticRecommenderTest.scala index 17b7e05d3..aea245411 100644 --- a/src/test/scala/ai/metarank/ml/recommend/BertSemanticRecommenderTest.scala +++ b/src/test/scala/ai/metarank/ml/recommend/BertSemanticRecommenderTest.scala @@ -6,6 +6,7 @@ import ai.metarank.ml.recommend.BertSemanticRecommender.Encoder.CsvEncoder import ai.metarank.ml.recommend.BertSemanticRecommender.{BertSemanticModelConfig, BertSemanticPredictor} import ai.metarank.ml.recommend.BertSemanticRecommender.EncoderType.BertEncoderType import ai.metarank.ml.recommend.BertSemanticRecommenderTest.Movie +import ai.metarank.ml.recommend.KnnConfig.HnswConfig import ai.metarank.model.Event.ItemEvent import ai.metarank.model.Field.StringField import ai.metarank.model.Identifier.ItemId @@ -23,7 +24,8 @@ class BertSemanticRecommenderTest extends AnyFlatSpec with Matchers { it should "train the model" in { val conf = BertSemanticModelConfig( encoder = BertEncoderType("sentence-transformer/all-MiniLM-L6-v2"), - itemFields = List("title", "description") + itemFields = List("title", "description"), + store = HnswConfig() ) val model = BertSemanticPredictor("foo", conf) val events: List[ItemValues] = RanklensEvents.apply().collect { case e: ItemEvent => ItemValues(e) } @@ -41,7 +43,8 @@ class BertSemanticRecommenderTest extends AnyFlatSpec with Matchers { decoded shouldBe Right( BertSemanticModelConfig( encoder = BertEncoderType("sentence-transformer/all-MiniLM-L6-v2"), - itemFields = List("title", "description") + itemFields = List("title", "description"), + store = HnswConfig() ) ) } diff --git a/src/test/scala/ai/metarank/ml/recommend/EmbeddingSimilarityModelTest.scala b/src/test/scala/ai/metarank/ml/recommend/EmbeddingSimilarityModelTest.scala new file mode 100644 index 000000000..36a496ea7 --- /dev/null +++ b/src/test/scala/ai/metarank/ml/recommend/EmbeddingSimilarityModelTest.scala @@ -0,0 +1,34 @@ +package ai.metarank.ml.recommend + +import ai.metarank.ml.Model +import ai.metarank.ml.Model.ItemScore +import ai.metarank.ml.recommend.MFRecommender.EmbeddingSimilarityModel +import ai.metarank.ml.recommend.embedding.KnnIndex.KnnIndexReader +import ai.metarank.model.Identifier +import ai.metarank.model.Identifier.ItemId +import cats.effect.IO +import cats.effect.unsafe.implicits.global +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class EmbeddingSimilarityModelTest extends AnyFlatSpec with Matchers { + it should "not return the same item" in { + val index = new KnnIndexReader { + override def lookup(items: List[Identifier.ItemId], n: Int): IO[List[Model.ItemScore]] = + IO.pure( + List( + ItemScore(ItemId("p1"), 1.0), + ItemScore(ItemId("p2"), 1.0), + ItemScore(ItemId("p3"), 1.0), + ItemScore(ItemId("p4"), 1.0), + ItemScore(ItemId("p5"), 1.0) + ) + ) + + override def save(): Array[Byte] = ??? 
+ } + val model = EmbeddingSimilarityModel("foo", index) + val response = model.predict(RecommendRequest(3, items = List(ItemId("p2"), ItemId("p4")))).unsafeRunSync() + response.items.toList.map(_.item.value) shouldBe List("p1", "p3", "p5") + } +} diff --git a/src/test/scala/ai/metarank/ml/recommend/embedding/HnswJavaIndexTest.scala b/src/test/scala/ai/metarank/ml/recommend/embedding/HnswJavaIndexTest.scala index 7a9dc8e09..3ab55f825 100644 --- a/src/test/scala/ai/metarank/ml/recommend/embedding/HnswJavaIndexTest.scala +++ b/src/test/scala/ai/metarank/ml/recommend/embedding/HnswJavaIndexTest.scala @@ -1,8 +1,11 @@ package ai.metarank.ml.recommend.embedding +import ai.metarank.ml.recommend.embedding.HnswJavaIndex.{HnswIndexWriter, HnswOptions} import ai.metarank.model.Identifier.ItemId +import cats.effect.unsafe.implicits.global import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers + import scala.util.Random class HnswJavaIndexTest extends AnyFlatSpec with Matchers { @@ -14,8 +17,8 @@ class HnswJavaIndexTest extends AnyFlatSpec with Matchers { rows = 1000, cols = 100 ) - val index = HnswJavaIndex.create(map, 32, 200) - val similar = index.lookup(List(ItemId("75")), 10) + val index = HnswIndexWriter.write(map, HnswOptions(32, 200)).unsafeRunSync() + val similar = index.lookup(List(ItemId("75")), 10).unsafeRunSync() similar.size shouldBe 10 }
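
For reference, the `store` option introduced in `KnnConfig.scala` above accepts two shapes, selected by a `type` field. Below is a minimal sketch of the corresponding YAML, assuming the model configuration is decoded with the circe decoders added in this patch; the concrete collection name and vector size are placeholder values, not part of the change:

```yaml
# Sketch only: field names follow the decoders in KnnConfig.scala; values are placeholders.
store:
  type: hnsw          # in-memory HNSW index; also used by default when `store` is omitted
  m: 32               # optional, defaults to 32
  ef: 200             # optional, defaults to 200
---
store:
  type: qdrant        # external Qdrant index
  endpoint: http://localhost:6333
  collection: items   # placeholder collection name
  size: 384           # vector dimensionality; must match the embedding size
  distance: Cosine    # distance metric, as used in QdrantTest
```

Note that with the Qdrant variant the vectors live in the external collection and `save()` only serializes the item-id to UUID mapping, so the Qdrant collection must remain available when the model is loaded for inference.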