From 1cb8a348046f4e8c13a0be8f537fd6f8b9ea91de Mon Sep 17 00:00:00 2001 From: Grebennikov Roman Date: Sat, 1 Apr 2023 11:41:34 +0200 Subject: [PATCH 1/3] onnx embedding support for field_match --- doc/configuration/features/text.md | 13 ++- .../cross-encoder/ms-marco-MiniLM-L-6-v2.onnx | 3 - .../metarank/ml/onnx/HuggingFaceClient.scala | 87 +++++++++++++++++++ .../ai/metarank/ml/onnx/ModelHandle.scala | 39 +++++++++ .../ml/onnx/encoder/BertEncoder.scala | 56 ++++++++++-- .../ai/metarank/ml/onnx/encoder/Encoder.scala | 4 +- .../ml/onnx/encoder/EncoderType.scala | 29 +++++-- .../ai/metarank/model/AnalyticsPayload.scala | 2 +- .../ai/metarank/model/FeatureSchema.scala | 15 ++-- .../scala/ai/metarank/util/LocalCache.scala | 74 ++++++++++++++++ .../all-MiniLM-L6-v2.onnx | 0 .../sbert/sentence-transformer/vocab.txt | 0 .../FieldMatchBiencoderFeatureTest.scala | 9 +- .../ai/metarank/ml/onnx/ModelHandleTest.scala | 32 +++++++ .../BertSemanticRecommenderTest.scala | 9 +- .../metarank/util/HuggingFaceClientTest.scala | 27 ++++++ .../ai/metarank/util/LocalCacheTest.scala | 24 +++++ 17 files changed, 384 insertions(+), 39 deletions(-) delete mode 100644 src/main/resources/sbert/cross-encoder/ms-marco-MiniLM-L-6-v2.onnx create mode 100644 src/main/scala/ai/metarank/ml/onnx/HuggingFaceClient.scala create mode 100644 src/main/scala/ai/metarank/ml/onnx/ModelHandle.scala create mode 100644 src/main/scala/ai/metarank/util/LocalCache.scala rename src/{main => test}/resources/sbert/sentence-transformer/all-MiniLM-L6-v2.onnx (100%) rename src/{main => test}/resources/sbert/sentence-transformer/vocab.txt (100%) create mode 100644 src/test/scala/ai/metarank/ml/onnx/ModelHandleTest.scala create mode 100644 src/test/scala/ai/metarank/util/HuggingFaceClientTest.scala create mode 100644 src/test/scala/ai/metarank/util/LocalCacheTest.scala diff --git a/doc/configuration/features/text.md b/doc/configuration/features/text.md index f560a3216..0328a573b 100644 --- 
a/doc/configuration/features/text.md +++ b/doc/configuration/features/text.md @@ -124,7 +124,7 @@ Both term and ngram method share the same approach to the text analysis: * for non-generic languages each term is stemmed * then terms/ngrams from item and ranking are scored using intersection/union method. -### BERT LLM embedding similarity +### Transformer LLM embedding similarity Then with the following config snippet we can compute a cosine distance between title and query embeddings: @@ -135,7 +135,14 @@ Then with the following config snippet we can compute a cosine distance between itemField: item.title distance: cos # optional, default cos, options: cos/dot method: - type: bert # the only one supported for now - model: sentence-transformer/all-MiniLM-L6-v2 # the only one supported now + type: transformer + model: metarank/all-MiniLM-L6-v2 ``` +Metarank supports two embedding methods: +* `transformer`: ONNX-encoded versions of the [sentence-transformers](https://sbert.net/docs/pretrained_models.html) models. See the [metarank HuggingFace namespace](https://huggingface.co/metarank) for a list of currently supported models. +* `csv`: a comma-separated file with precomputed embeddings, where the first row is the source sentence. Useful for externally-generated embeddings with platforms like OpenAI and Cohere. + +For `transformer` models, Metarank supports fetching the model directly from the HuggingFace Hub, or loading it from a local dir, depending on the model handle format: +* `namespace/model`: fetch the model from the HuggingFace Hub +* `file:///<path>`: load an ONNX-encoded embedding model from a local dir. 
diff --git a/src/main/resources/sbert/cross-encoder/ms-marco-MiniLM-L-6-v2.onnx b/src/main/resources/sbert/cross-encoder/ms-marco-MiniLM-L-6-v2.onnx deleted file mode 100644 index 5d4d2fa4e..000000000 --- a/src/main/resources/sbert/cross-encoder/ms-marco-MiniLM-L-6-v2.onnx +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e247da66458ed13a2a7e4436f99a564a41b2fb59d8ebfa955e8dcfa8a80a73f8 -size 90979431 diff --git a/src/main/scala/ai/metarank/ml/onnx/HuggingFaceClient.scala b/src/main/scala/ai/metarank/ml/onnx/HuggingFaceClient.scala new file mode 100644 index 000000000..19282b789 --- /dev/null +++ b/src/main/scala/ai/metarank/ml/onnx/HuggingFaceClient.scala @@ -0,0 +1,87 @@ +package ai.metarank.ml.onnx + +import ai.metarank.ml.onnx.HuggingFaceClient.ModelResponse +import ai.metarank.ml.onnx.HuggingFaceClient.ModelResponse.Sibling +import ai.metarank.ml.onnx.ModelHandle.HuggingFaceHandle +import ai.metarank.util.Logging +import cats.effect.IO +import cats.effect.kernel.Resource +import io.circe.Codec +import io.circe.generic.semiauto.{deriveCodec, deriveDecoder} +import org.http4s.{EntityDecoder, Request, Uri} +import org.http4s.blaze.client.BlazeClientBuilder +import org.http4s.client.Client +import org.http4s.circe._ +import fs2.Stream +import org.typelevel.ci.CIString + +import java.io.ByteArrayOutputStream +import scala.concurrent.duration._ + +case class HuggingFaceClient(client: Client[IO], endpoint: Uri) extends Logging { + + implicit val modelResponseDecoder: EntityDecoder[IO, ModelResponse] = jsonOf[IO, ModelResponse] + + def model(handle: HuggingFaceHandle) = for { + request <- IO(Request[IO](uri = endpoint / "api" / "models" / handle.ns / handle.name)) + _ <- info(s"sending HuggingFace API request $request") + response <- client.expect[ModelResponse](request) + } yield { + response + } + + def modelFile(handle: HuggingFaceHandle, fileName: String): IO[Array[Byte]] = { + get(endpoint / handle.ns / handle.name / "resolve" / 
"main" / fileName) + } + + def get(uri: Uri): IO[Array[Byte]] = + client + .stream(Request[IO](uri = uri)) + .evalTap(_ => info(s"sending HuggingFace API request for a file $uri")) + .evalMap(response => + response.status.code match { + case 200 => + info("HuggingFace API: HTTP 200") *> response.entity.body.compile + .foldChunks(new ByteArrayOutputStream())((acc, c) => { + acc.writeBytes(c.toArray) + acc + }) + .map(_.toByteArray) + case 302 => + response.headers.get(CIString("Location")) match { + case Some(locations) => + Uri.fromString(locations.head.value) match { + case Left(value) => IO.raiseError(value) + case Right(uri) => info("302 redirect") *> get(uri) + } + case None => IO.raiseError(new Exception(s"Got 302 redirect without location")) + } + case other => IO.raiseError(new Exception(s"HTTP code $other")) + } + ) + .compile + .fold(new ByteArrayOutputStream())((acc, c) => { + acc.writeBytes(c) + acc + }) + .map(_.toByteArray) +} + +object HuggingFaceClient { + val HUGGINGFACE_API_ENDPOINT = "https://huggingface.co" + case class ModelResponse(id: String, siblings: List[Sibling]) + object ModelResponse { + case class Sibling(rfilename: String) + } + + implicit val modelSiblingCodec: Codec[Sibling] = deriveCodec[Sibling] + implicit val modelResponseCodec: Codec[ModelResponse] = deriveCodec[ModelResponse] + + def create(endpoint: String = HUGGINGFACE_API_ENDPOINT): Resource[IO, HuggingFaceClient] = for { + uri <- Resource.eval(IO.fromEither(Uri.fromString(endpoint))) + client <- BlazeClientBuilder[IO].withRequestTimeout(120.second).withConnectTimeout(120.second).resource + } yield { + HuggingFaceClient(client, uri) + } + +} diff --git a/src/main/scala/ai/metarank/ml/onnx/ModelHandle.scala b/src/main/scala/ai/metarank/ml/onnx/ModelHandle.scala new file mode 100644 index 000000000..ba850960a --- /dev/null +++ b/src/main/scala/ai/metarank/ml/onnx/ModelHandle.scala @@ -0,0 +1,39 @@ +package ai.metarank.ml.onnx + +import io.circe.{Decoder, Encoder, Json} + 
+import scala.util.{Failure, Success} + +sealed trait ModelHandle { + def name: String + def asList: List[String] +} + +object ModelHandle { + + def apply(ns: String, name: String) = HuggingFaceHandle(ns, name) + + case class HuggingFaceHandle(ns: String, name: String) extends ModelHandle { + override def asList: List[String] = List(ns, name) + } + case class LocalModelHandle(dir: String) extends ModelHandle { + override def name: String = dir + override def asList: List[String] = List(dir) + } + + val huggingFacePattern = "([a-zA-Z0-9\\-]+)/([0-9A-Za-z\\-_]+)".r + val localPattern1 = "file:/(/.+)".r + val localPattern2 = "file://(/.+)".r + + implicit val modelHandleDecoder: Decoder[ModelHandle] = Decoder.decodeString.emapTry { + case huggingFacePattern(ns, name) => Success(HuggingFaceHandle(ns, name)) + case localPattern2(path) => Success(LocalModelHandle(path)) + case localPattern1(path) => Success(LocalModelHandle(path)) + case other => Failure(new Exception(s"cannot parse model handle '$other'")) + } + + implicit val modelHandleEncoder: Encoder[ModelHandle] = Encoder.instance { + case HuggingFaceHandle(ns, name) => Json.fromString(s"$ns/$name") + case LocalModelHandle(path) => Json.fromString(s"file://$path") + } +} diff --git a/src/main/scala/ai/metarank/ml/onnx/encoder/BertEncoder.scala b/src/main/scala/ai/metarank/ml/onnx/encoder/BertEncoder.scala index d56826567..37713d7f4 100644 --- a/src/main/scala/ai/metarank/ml/onnx/encoder/BertEncoder.scala +++ b/src/main/scala/ai/metarank/ml/onnx/encoder/BertEncoder.scala @@ -1,22 +1,62 @@ package ai.metarank.ml.onnx.encoder -import ai.metarank.ml.onnx.SBERT +import ai.metarank.ml.onnx.ModelHandle.{HuggingFaceHandle, LocalModelHandle} +import ai.metarank.ml.onnx.{HuggingFaceClient, ModelHandle, SBERT} import ai.metarank.model.Identifier.ItemId +import ai.metarank.util.{LocalCache, Logging} import cats.effect.IO +import org.apache.commons.io.IOUtils + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, 
FileInputStream} case class BertEncoder(sbert: SBERT) extends Encoder { override def dim: Int = sbert.dim - override def encode(str: String): Array[Float] = sbert.embed(str) + override def encode(str: String): Array[Float] = sbert.embed(str) override def encode(id: ItemId, str: String): Array[Float] = sbert.embed(str) } -object BertEncoder { - def create(model: String): IO[BertEncoder] = IO { - val sbert = SBERT( - model = this.getClass.getResourceAsStream(s"/sbert/$model.onnx"), - dic = this.getClass.getResourceAsStream("/sbert/sentence-transformer/vocab.txt") +object BertEncoder extends Logging { + def create(model: ModelHandle, modelFile: String, vocabFile: String): IO[BertEncoder] = model match { + case hf: HuggingFaceHandle => loadFromHuggingFace(hf, modelFile, vocabFile).map(BertEncoder.apply) + case local: LocalModelHandle => loadFromLocalDir(local, modelFile, vocabFile).map(BertEncoder.apply) + } + + def loadFromHuggingFace(handle: HuggingFaceHandle, modelFile: String, vocabFile: String): IO[SBERT] = for { + cache <- LocalCache.create() + modelDirName <- IO(handle.asList.mkString(File.separator)) + sbert <- HuggingFaceClient + .create() + .use(hf => + for { + modelBytes <- cache.getIfExists(modelDirName, modelFile).flatMap { + case Some(bytes) => IO.pure(bytes) + case None => hf.modelFile(handle, modelFile).flatTap(bytes => cache.put(modelDirName, modelFile, bytes)) + } + vocabBytes <- cache.getIfExists(modelDirName, vocabFile).flatMap { + case Some(bytes) => IO.pure(bytes) + case None => hf.modelFile(handle, vocabFile).flatTap(bytes => cache.put(modelDirName, vocabFile, bytes)) + } + } yield { + SBERT( + model = new ByteArrayInputStream(modelBytes), + dic = new ByteArrayInputStream(vocabBytes) + ) + } + ) + } yield { + sbert + } + + def loadFromLocalDir(handle: LocalModelHandle, modelFile: String, vocabFile: String): IO[SBERT] = for { + _ <- info(s"loading $modelFile from $handle") + modelBytes <- IO(IOUtils.toByteArray(new FileInputStream(new 
File(handle.dir + File.separator + modelFile)))) + vocabBytes <- IO(IOUtils.toByteArray(new FileInputStream(new File(handle.dir + File.separator + vocabFile)))) + } yield { + SBERT( + model = new ByteArrayInputStream(modelBytes), + dic = new ByteArrayInputStream(vocabBytes) ) - BertEncoder(sbert) } + } diff --git a/src/main/scala/ai/metarank/ml/onnx/encoder/Encoder.scala b/src/main/scala/ai/metarank/ml/onnx/encoder/Encoder.scala index 5884a7952..5e84f74c9 100644 --- a/src/main/scala/ai/metarank/ml/onnx/encoder/Encoder.scala +++ b/src/main/scala/ai/metarank/ml/onnx/encoder/Encoder.scala @@ -11,9 +11,7 @@ trait Encoder { object Encoder { def create(conf: EncoderType): IO[Encoder] = conf match { - case EncoderType.BertEncoderType(model) => BertEncoder.create(model) + case EncoderType.BertEncoderType(model, modelFile, vocabFile) => BertEncoder.create(model, modelFile, vocabFile) case EncoderType.CsvEncoderType(path) => CsvEncoder.create(path) - } - } diff --git a/src/main/scala/ai/metarank/ml/onnx/encoder/EncoderType.scala b/src/main/scala/ai/metarank/ml/onnx/encoder/EncoderType.scala index 9d640ab4d..e067f6c44 100644 --- a/src/main/scala/ai/metarank/ml/onnx/encoder/EncoderType.scala +++ b/src/main/scala/ai/metarank/ml/onnx/encoder/EncoderType.scala @@ -1,26 +1,43 @@ package ai.metarank.ml.onnx.encoder +import ai.metarank.ml.onnx.ModelHandle import io.circe.generic.semiauto._ import io.circe.{Decoder, DecodingFailure} sealed trait EncoderType object EncoderType { - case class BertEncoderType(model: String) extends EncoderType + case class BertEncoderType( + model: ModelHandle, + modelFile: String = "pytorch_model.onnx", + vocabFile: String = "vocab.txt" + ) extends EncoderType case class CsvEncoderType(path: String) extends EncoderType - implicit val bertDecoder: Decoder[BertEncoderType] = deriveDecoder[BertEncoderType] + implicit val bertDecoder: Decoder[BertEncoderType] = Decoder.instance(c => + for { + model <- c.downField("model").as[ModelHandle] + modelFile <- 
c.downField("modelFile").as[Option[String]] + vocabFile <- c.downField("vocabFile").as[Option[String]] + } yield { + BertEncoderType( + model, + modelFile = modelFile.getOrElse("pytorch_model.onnx"), + vocabFile = vocabFile.getOrElse("vocab.txt") + ) + } + ) implicit val bertEncoder: io.circe.Encoder[BertEncoderType] = deriveEncoder[BertEncoderType] implicit val csvDecoder: Decoder[CsvEncoderType] = deriveDecoder[CsvEncoderType] implicit val csvEncoder: io.circe.Encoder[CsvEncoderType] = deriveEncoder[CsvEncoderType] implicit val encoderTypeDecoder: Decoder[EncoderType] = Decoder.instance(c => c.downField("type").as[String] match { - case Left(err) => Left(err) - case Right("bert") => bertDecoder.tryDecode(c) - case Right("csv") => csvDecoder.tryDecode(c) - case Right(other) => Left(DecodingFailure(s"cannot decode embedding type $other", c.history)) + case Left(err) => Left(err) + case Right("transformer") => bertDecoder.tryDecode(c) + case Right("csv") => csvDecoder.tryDecode(c) + case Right(other) => Left(DecodingFailure(s"cannot decode embedding type $other", c.history)) } ) } diff --git a/src/main/scala/ai/metarank/model/AnalyticsPayload.scala b/src/main/scala/ai/metarank/model/AnalyticsPayload.scala index b89146200..86a594acc 100644 --- a/src/main/scala/ai/metarank/model/AnalyticsPayload.scala +++ b/src/main/scala/ai/metarank/model/AnalyticsPayload.scala @@ -119,7 +119,7 @@ object AnalyticsPayload { case NoopConfig(_) => "noop" case TrendingConfig(_, _) => "trending" case _: ALSConfig => "als" - case _: BertSemanticModelConfig => "bert" + case _: BertSemanticModelConfig => "semantic" }.toList, usedFeatures = config.features.map { case f: RateFeatureSchema => UsedFeature(f.name, "rate") diff --git a/src/main/scala/ai/metarank/model/FeatureSchema.scala b/src/main/scala/ai/metarank/model/FeatureSchema.scala index 306c02a07..e0a97d90a 100644 --- a/src/main/scala/ai/metarank/model/FeatureSchema.scala +++ b/src/main/scala/ai/metarank/model/FeatureSchema.scala @@ 
-57,13 +57,14 @@ object FeatureSchema { val biencoder = implicitly[Decoder[FieldMatchBiencoderSchema]] val term = implicitly[Decoder[FieldMatchSchema]] c.downField("method").downField("type").as[String] match { - case Left(err) => Left(err) - case Right("bert") => biencoder.apply(c) - case Right("csv") => biencoder.apply(c) - case Right("term") => term.apply(c) - case Right("ngram") => term.apply(c) - case Right("bm25") => term.apply(c) - case Right(other) => Left(DecodingFailure(s"term matching method $other is not supported", c.history)) + case Right(other) => Left(DecodingFailure(s"term matching method $other is not supported", c.history)) + case Left(err) => Left(err) + case Right("transformer") => biencoder.apply(c) + case Right("csv") => biencoder.apply(c) + case Right("term") => term.apply(c) + case Right("ngram") => term.apply(c) + case Right("bm25") => term.apply(c) + case Right(other) => Left(DecodingFailure(s"term matching method $other is not supported", c.history)) } case "referer" => implicitly[Decoder[RefererSchema]].apply(c) case "position" => implicitly[Decoder[PositionFeatureSchema]].apply(c) diff --git a/src/main/scala/ai/metarank/util/LocalCache.scala b/src/main/scala/ai/metarank/util/LocalCache.scala new file mode 100644 index 000000000..b2233635f --- /dev/null +++ b/src/main/scala/ai/metarank/util/LocalCache.scala @@ -0,0 +1,74 @@ +package ai.metarank.util + +import cats.effect.IO +import org.apache.commons.lang3.SystemUtils +import fs2.io.file.Files +import fs2.io.readInputStream +import org.apache.commons.io.IOUtils + +import java.io.{ByteArrayInputStream, File, FileInputStream} +import java.nio.file.Path + +case class LocalCache(cacheDir: String) extends Logging { + def getIfExists(dir: String, file: String): IO[Option[Array[Byte]]] = for { + targetFile <- IO(new File(cacheDir + File.separator + dir + File.separator + file)) + contents <- targetFile.exists() match { + case true => IO(Some(IOUtils.toByteArray(new 
FileInputStream(targetFile)))) + case false => IO.none + } + } yield { + contents + } + + def put(dir: String, file: String, bytes: Array[Byte]): IO[Unit] = for { + targetDir <- IO(new File(cacheDir + File.separator + dir)) + targetFile <- IO(new File(cacheDir + File.separator + dir + File.separator + file)) + _ <- IO.whenA(!targetDir.exists())(IO(targetDir.mkdirs())) + _ <- info(s"writing cache file $dir/$file") + _ <- readInputStream[IO](IO(new ByteArrayInputStream(bytes)), 1024) + .through(Files[IO].writeAll(fs2.io.file.Path(targetFile.toString))) + .compile + .drain + } yield {} +} + +object LocalCache extends Logging { + def create() = { + for { + topDir <- cacheDir() + metarankDir <- IO(new File(topDir.toString + File.separator + "metarank")) + _ <- IO.whenA(!metarankDir.exists())( + info(s"cache dir $metarankDir is not present, creating") *> IO(metarankDir.mkdirs()) + ) + _ <- info(s"using $metarankDir as local cache dir") + } yield { + LocalCache(metarankDir.toString) + } + } + + def cacheDir() = IO { + val fallback = Path.of(System.getProperty("java.io.tmpdir")) + val dir = if (SystemUtils.IS_OS_WINDOWS) { + Option(System.getenv("LOCALAPPDATA")).map(path => Path.of(path)).filter(_.toFile.exists()).getOrElse(fallback) + } else if (SystemUtils.IS_OS_MAC) { + Option(System.getProperty("user.home")) + .map(home => s"$home/Library/Caches") + .map(path => Path.of(path)) + .filter(_.toFile.exists()) + .getOrElse(fallback) + } else if (SystemUtils.IS_OS_LINUX) { + val default = Option(System.getProperty("user.home")) + .map(home => s"$home/.cache") + .map(path => Path.of(path)) + .filter(_.toFile.exists()) + Option(System.getenv("XDG_CACHE_HOME")) + .map(path => Path.of(path)) + .filter(_.toFile.exists()) + .orElse(default) + .getOrElse(fallback) + } else { + fallback + } + dir + } +} diff --git a/src/main/resources/sbert/sentence-transformer/all-MiniLM-L6-v2.onnx b/src/test/resources/sbert/sentence-transformer/all-MiniLM-L6-v2.onnx similarity index 100% rename 
from src/main/resources/sbert/sentence-transformer/all-MiniLM-L6-v2.onnx rename to src/test/resources/sbert/sentence-transformer/all-MiniLM-L6-v2.onnx diff --git a/src/main/resources/sbert/sentence-transformer/vocab.txt b/src/test/resources/sbert/sentence-transformer/vocab.txt similarity index 100% rename from src/main/resources/sbert/sentence-transformer/vocab.txt rename to src/test/resources/sbert/sentence-transformer/vocab.txt diff --git a/src/test/scala/ai/metarank/feature/FieldMatchBiencoderFeatureTest.scala b/src/test/scala/ai/metarank/feature/FieldMatchBiencoderFeatureTest.scala index d83edd8d6..f98a10f29 100644 --- a/src/test/scala/ai/metarank/feature/FieldMatchBiencoderFeatureTest.scala +++ b/src/test/scala/ai/metarank/feature/FieldMatchBiencoderFeatureTest.scala @@ -2,6 +2,7 @@ package ai.metarank.feature import ai.metarank.feature.FieldMatchBiencoderFeature.FieldMatchBiencoderSchema import ai.metarank.fstore.memory.MemPersistence +import ai.metarank.ml.onnx.ModelHandle import ai.metarank.ml.onnx.distance.DistanceFunction.CosineDistance import ai.metarank.ml.onnx.encoder.EncoderType.BertEncoderType import ai.metarank.model.Event.ItemEvent @@ -20,10 +21,10 @@ import org.scalatest.matchers.should.Matchers class FieldMatchBiencoderFeatureTest extends AnyFlatSpec with Matchers with FeatureTest { val schema = FieldMatchBiencoderSchema( name = FeatureName("foo"), - rankingField = FieldName(Ranking,"query"), + rankingField = FieldName(Ranking, "query"), itemField = FieldName(Item, "title"), distance = CosineDistance, - method = BertEncoderType("sentence-transformer/all-MiniLM-L6-v2") + method = BertEncoderType(ModelHandle("metarank", "all-MiniLM-L6-v2")) ) lazy val feature = schema.create().unsafeRunSync().asInstanceOf[FieldMatchBiencoderFeature] @@ -43,8 +44,8 @@ class FieldMatchBiencoderFeatureTest extends AnyFlatSpec with Matchers with Feat |itemField: item.title |distance: cosine |method: - | type: bert - | model: sentence-transformer/all-MiniLM-L6-v2 + | 
type: transformer + | model: metarank/all-MiniLM-L6-v2 | """.stripMargin val decoded = io.circe.yaml.parser.parse(yaml).flatMap(_.as[FeatureSchema]) decoded shouldBe Right(schema) diff --git a/src/test/scala/ai/metarank/ml/onnx/ModelHandleTest.scala b/src/test/scala/ai/metarank/ml/onnx/ModelHandleTest.scala new file mode 100644 index 000000000..6182f32f9 --- /dev/null +++ b/src/test/scala/ai/metarank/ml/onnx/ModelHandleTest.scala @@ -0,0 +1,32 @@ +package ai.metarank.ml.onnx + +import ai.metarank.ml.onnx.ModelHandle.{HuggingFaceHandle, LocalModelHandle} +import ai.metarank.ml.onnx.ModelHandleTest.HandleTest +import io.circe.Decoder +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import io.circe.parser._ +import io.circe.generic.semiauto._ + +class ModelHandleTest extends AnyFlatSpec with Matchers { + it should "decode HF handle" in { + parse("metarank/foo") shouldBe Right(HuggingFaceHandle("metarank", "foo")) + } + + it should "decode local handle with single slash" in { + parse("file://tmp/file") shouldBe Right(LocalModelHandle("/tmp/file")) + } + + it should "decode local handle with double slash" in { + parse("file:///tmp/file") shouldBe Right(LocalModelHandle("/tmp/file")) + } + + def parse(handle: String): Either[Throwable, ModelHandle] = { + decode[HandleTest](s"""{"handle": "$handle"}""").map(_.handle) + } +} + +object ModelHandleTest { + case class HandleTest(handle: ModelHandle) + implicit val handleDecoder: Decoder[HandleTest] = deriveDecoder +} diff --git a/src/test/scala/ai/metarank/ml/recommend/BertSemanticRecommenderTest.scala b/src/test/scala/ai/metarank/ml/recommend/BertSemanticRecommenderTest.scala index 6ec35d32f..a0e57f45a 100644 --- a/src/test/scala/ai/metarank/ml/recommend/BertSemanticRecommenderTest.scala +++ b/src/test/scala/ai/metarank/ml/recommend/BertSemanticRecommenderTest.scala @@ -2,6 +2,7 @@ package ai.metarank.ml.recommend import ai.metarank.config.ModelConfig import 
ai.metarank.config.Selector.AcceptSelector +import ai.metarank.ml.onnx.ModelHandle import ai.metarank.ml.onnx.encoder.CsvEncoder import ai.metarank.ml.onnx.encoder.EncoderType.BertEncoderType import ai.metarank.ml.recommend.BertSemanticRecommender.{BertSemanticModelConfig, BertSemanticPredictor} @@ -22,7 +23,7 @@ import java.nio.charset.StandardCharsets class BertSemanticRecommenderTest extends AnyFlatSpec with Matchers { it should "train the model" in { val conf = BertSemanticModelConfig( - encoder = BertEncoderType("sentence-transformer/all-MiniLM-L6-v2"), + encoder = BertEncoderType(ModelHandle("metarank", "all-MiniLM-L6-v2")), itemFields = List("title", "description"), store = HnswConfig() ) @@ -35,13 +36,13 @@ class BertSemanticRecommenderTest extends AnyFlatSpec with Matchers { val yaml = """type: semantic |encoder: - | type: bert - | model: sentence-transformer/all-MiniLM-L6-v2 + | type: transformer + | model: metarank/all-MiniLM-L6-v2 |itemFields: [title, description]""".stripMargin val decoded = io.circe.yaml.parser.parse(yaml).flatMap(_.as[ModelConfig]) decoded shouldBe Right( BertSemanticModelConfig( - encoder = BertEncoderType("sentence-transformer/all-MiniLM-L6-v2"), + encoder = BertEncoderType(ModelHandle("metarank", "all-MiniLM-L6-v2")), itemFields = List("title", "description"), store = HnswConfig() ) diff --git a/src/test/scala/ai/metarank/util/HuggingFaceClientTest.scala b/src/test/scala/ai/metarank/util/HuggingFaceClientTest.scala new file mode 100644 index 000000000..89a0be044 --- /dev/null +++ b/src/test/scala/ai/metarank/util/HuggingFaceClientTest.scala @@ -0,0 +1,27 @@ +package ai.metarank.util + +import ai.metarank.ml.onnx.{HuggingFaceClient, ModelHandle} +import ai.metarank.ml.onnx.HuggingFaceClient.ModelResponse +import ai.metarank.ml.onnx.HuggingFaceClient.ModelResponse.Sibling +import cats.effect.unsafe.implicits.global +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class 
HuggingFaceClientTest extends AnyFlatSpec with Matchers { + it should "fetch metadata" in { + val model = HuggingFaceClient + .create() + .use(client => client.model(ModelHandle("metarank", "all-MiniLM-L6-v2"))) + .unsafeRunSync() + model.siblings should contain(Sibling("vocab.txt")) + } + + it should "fetch files" in { + val vocab = + HuggingFaceClient + .create() + .use(_.modelFile(ModelHandle("metarank", "all-MiniLM-L6-v2"), "vocab.txt")) + .unsafeRunSync() + vocab.length should be > (100) + } +} diff --git a/src/test/scala/ai/metarank/util/LocalCacheTest.scala b/src/test/scala/ai/metarank/util/LocalCacheTest.scala new file mode 100644 index 000000000..0edab995e --- /dev/null +++ b/src/test/scala/ai/metarank/util/LocalCacheTest.scala @@ -0,0 +1,24 @@ +package ai.metarank.util + +import cats.effect.unsafe.implicits.global +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class LocalCacheTest extends AnyFlatSpec with Matchers { + it should "init with default dir" in { + val cache = LocalCache.create().unsafeRunSync() + } + + it should "cache a file and read it" in { + val cache = LocalCache.create().unsafeRunSync() + cache.put("test", "test.bin", "test".getBytes()).unsafeRunSync() + val back = cache.getIfExists("test", "test.bin").unsafeRunSync() + back.map(new String(_)) shouldBe Some("test") + } + + it should "load empty on not found" in { + val cache = LocalCache.create().unsafeRunSync() + val back = cache.getIfExists("test", "404.bin").unsafeRunSync() + back shouldBe None + } +} From 45c003cc9cf7136a040b999a28bf76a7baf40c11 Mon Sep 17 00:00:00 2001 From: Grebennikov Roman Date: Sat, 1 Apr 2023 11:54:15 +0200 Subject: [PATCH 2/3] compile fix --- src/main/scala/ai/metarank/model/FeatureSchema.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/main/scala/ai/metarank/model/FeatureSchema.scala b/src/main/scala/ai/metarank/model/FeatureSchema.scala index e0a97d90a..bfbb1c6e8 100644 --- 
a/src/main/scala/ai/metarank/model/FeatureSchema.scala +++ b/src/main/scala/ai/metarank/model/FeatureSchema.scala @@ -57,14 +57,13 @@ object FeatureSchema { val biencoder = implicitly[Decoder[FieldMatchBiencoderSchema]] val term = implicitly[Decoder[FieldMatchSchema]] c.downField("method").downField("type").as[String] match { - case Right(other) => Left(DecodingFailure(s"term matching method $other is not supported", c.history)) - case Left(err) => Left(err) case Right("transformer") => biencoder.apply(c) case Right("csv") => biencoder.apply(c) case Right("term") => term.apply(c) case Right("ngram") => term.apply(c) case Right("bm25") => term.apply(c) case Right(other) => Left(DecodingFailure(s"term matching method $other is not supported", c.history)) + case Left(err) => Left(err) } case "referer" => implicitly[Decoder[RefererSchema]].apply(c) case "position" => implicitly[Decoder[PositionFeatureSchema]].apply(c) From a329f92d963cb2c2e02e6bab8dc46940f4ad8b14 Mon Sep 17 00:00:00 2001 From: Grebennikov Roman Date: Sat, 1 Apr 2023 12:10:21 +0200 Subject: [PATCH 3/3] better logging while fetching the model from HF --- .gitignore | 3 ++- .../ai/metarank/ml/onnx/HuggingFaceClient.scala | 13 ++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 93971b422..b28613de5 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ api/target .run *.crt *.key -*.csr \ No newline at end of file +*.csr +src/test/scala/ai/metarank/tool/esci \ No newline at end of file diff --git a/src/main/scala/ai/metarank/ml/onnx/HuggingFaceClient.scala b/src/main/scala/ai/metarank/ml/onnx/HuggingFaceClient.scala index 19282b789..fbb926b73 100644 --- a/src/main/scala/ai/metarank/ml/onnx/HuggingFaceClient.scala +++ b/src/main/scala/ai/metarank/ml/onnx/HuggingFaceClient.scala @@ -1,5 +1,6 @@ package ai.metarank.ml.onnx +import ai.metarank.flow.PrintProgress import ai.metarank.ml.onnx.HuggingFaceClient.ModelResponse import 
ai.metarank.ml.onnx.HuggingFaceClient.ModelResponse.Sibling import ai.metarank.ml.onnx.ModelHandle.HuggingFaceHandle @@ -41,7 +42,9 @@ case class HuggingFaceClient(client: Client[IO], endpoint: Uri) extends Logging .evalMap(response => response.status.code match { case 200 => - info("HuggingFace API: HTTP 200") *> response.entity.body.compile + info("HuggingFace API: HTTP 200") *> response.entity.body + .through(PrintProgress.tap(None, "bytes")) + .compile .foldChunks(new ByteArrayOutputStream())((acc, c) => { acc.writeBytes(c.toArray) acc @@ -78,8 +81,12 @@ object HuggingFaceClient { implicit val modelResponseCodec: Codec[ModelResponse] = deriveCodec[ModelResponse] def create(endpoint: String = HUGGINGFACE_API_ENDPOINT): Resource[IO, HuggingFaceClient] = for { - uri <- Resource.eval(IO.fromEither(Uri.fromString(endpoint))) - client <- BlazeClientBuilder[IO].withRequestTimeout(120.second).withConnectTimeout(120.second).resource + uri <- Resource.eval(IO.fromEither(Uri.fromString(endpoint))) + client <- BlazeClientBuilder[IO] + .withRequestTimeout(120.second) + .withConnectTimeout(120.second) + .withIdleTimeout(200.seconds) + .resource } yield { HuggingFaceClient(client, uri) }