Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

onnx embedding support for field_match #989

Merged
merged 3 commits into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ api/target
.run
*.crt
*.key
*.csr
*.csr
src/test/scala/ai/metarank/tool/esci
13 changes: 10 additions & 3 deletions doc/configuration/features/text.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ Both term and ngram method share the same approach to the text analysis:
* for non-generic languages each term is stemmed
* then terms/ngrams from item and ranking are scored using intersection/union method.

### BERT LLM embedding similarity
### Transformer LLM embedding similarity

With the following config snippet, we can compute the cosine distance between the title and query embeddings:

Expand All @@ -135,7 +135,14 @@ Then with the following config snippet we can compute a cosine distance between
itemField: item.title
distance: cos # optional, default cos, options: cos/dot
method:
type: bert # the only one supported for now
model: sentence-transformer/all-MiniLM-L6-v2 # the only one supported now
type: transformer
model: metarank/all-MiniLM-L6-v2
```

Metarank supports two embedding methods:
* `transformer`: ONNX-encoded versions of the [sentence-transformers](https://sbert.net/docs/pretrained_models.html) models. See the [metarank HuggingFace namespace](https://huggingface.co/metarank) for a list of currently supported models.
* `csv`: a comma-separated file with precomputed embeddings, where the first row is the source sentence. Useful for externally-generated embeddings produced with platforms like OpenAI and Cohere.

For `transformer` models, Metarank supports fetching the model directly from the HuggingFace Hub, or loading it from a local directory, depending on the model handle format:
* `namespace/model`: fetch model from the HFHub
* `file:///<path>/<to>/<model dir>`: load an ONNX-encoded embedding model from a local directory.

This file was deleted.

94 changes: 94 additions & 0 deletions src/main/scala/ai/metarank/ml/onnx/HuggingFaceClient.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package ai.metarank.ml.onnx

import ai.metarank.flow.PrintProgress
import ai.metarank.ml.onnx.HuggingFaceClient.ModelResponse
import ai.metarank.ml.onnx.HuggingFaceClient.ModelResponse.Sibling
import ai.metarank.ml.onnx.ModelHandle.HuggingFaceHandle
import ai.metarank.util.Logging
import cats.effect.IO
import cats.effect.kernel.Resource
import io.circe.Codec
import io.circe.generic.semiauto.{deriveCodec, deriveDecoder}
import org.http4s.{EntityDecoder, Request, Uri}
import org.http4s.blaze.client.BlazeClientBuilder
import org.http4s.client.Client
import org.http4s.circe._
import fs2.Stream
import org.typelevel.ci.CIString

import java.io.ByteArrayOutputStream
import scala.concurrent.duration._

/** Minimal HTTP client for the HuggingFace Hub, used to resolve model metadata
  * and download individual model files.
  *
  * @param client   underlying http4s client used for all requests
  * @param endpoint base URI of the Hub (normally https://huggingface.co)
  */
case class HuggingFaceClient(client: Client[IO], endpoint: Uri) extends Logging {

// circe-backed JSON decoder for the /api/models response payload
implicit val modelResponseDecoder: EntityDecoder[IO, ModelResponse] = jsonOf[IO, ModelResponse]

// Fetch model metadata (model id + list of sibling files) from
// GET <endpoint>/api/models/<ns>/<name>.
def model(handle: HuggingFaceHandle) = for {
request <- IO(Request[IO](uri = endpoint / "api" / "models" / handle.ns / handle.name))
_ <- info(s"sending HuggingFace API request $request")
response <- client.expect[ModelResponse](request)
} yield {
response
}

// Download a single file of a model from the "main" revision, returning raw bytes.
def modelFile(handle: HuggingFaceHandle, fileName: String): IO[Array[Byte]] = {
get(endpoint / handle.ns / handle.name / "resolve" / "main" / fileName)
}

// Perform a GET and buffer the response body fully in memory:
// * 200 -> body is folded chunk-by-chunk into a ByteArrayOutputStream (with progress logging)
// * 302 -> recursively follows the Location header (the Hub redirects file downloads)
// * any other status fails the IO with an exception
// NOTE(review): the whole payload is held in memory, so the max downloadable model
// size is bounded by heap — confirm this is acceptable for the largest supported models.
def get(uri: Uri): IO[Array[Byte]] =
client
.stream(Request[IO](uri = uri))
.evalTap(_ => info(s"sending HuggingFace API request for a file $uri"))
.evalMap(response =>
response.status.code match {
case 200 =>
info("HuggingFace API: HTTP 200") *> response.entity.body
.through(PrintProgress.tap(None, "bytes"))
.compile
.foldChunks(new ByteArrayOutputStream())((acc, c) => {
acc.writeBytes(c.toArray)
acc
})
.map(_.toByteArray)
case 302 =>
// Location header may hold several values; only the first one is used here.
response.headers.get(CIString("Location")) match {
case Some(locations) =>
Uri.fromString(locations.head.value) match {
case Left(value) => IO.raiseError(value)
case Right(uri) => info("302 redirect") *> get(uri)
}
case None => IO.raiseError(new Exception(s"Got 302 redirect without location"))
}
case other => IO.raiseError(new Exception(s"HTTP code $other"))
}
)
.compile
// the outer fold concatenates the byte arrays produced per response (normally just one)
.fold(new ByteArrayOutputStream())((acc, c) => {
acc.writeBytes(c)
acc
})
.map(_.toByteArray)
}

/** Companion: response payload types, their JSON codecs, and a managed constructor. */
object HuggingFaceClient {
/** Default public endpoint of the HuggingFace Hub. */
val HUGGINGFACE_API_ENDPOINT = "https://huggingface.co"

/** Subset of the Hub /api/models response: model id plus the files it contains. */
case class ModelResponse(id: String, siblings: List[Sibling])
object ModelResponse {
/** A single file belonging to the model repo. */
case class Sibling(rfilename: String)
}

implicit val modelSiblingCodec: Codec[Sibling] = deriveCodec[Sibling]
implicit val modelResponseCodec: Codec[ModelResponse] = deriveCodec[ModelResponse]

/** Build a [[HuggingFaceClient]] wrapped in a Resource that manages the
  * underlying Blaze client lifecycle.
  *
  * @param endpoint Hub base URL; fails the Resource if it cannot be parsed as a Uri
  */
def create(endpoint: String = HUGGINGFACE_API_ENDPOINT): Resource[IO, HuggingFaceClient] =
Resource.eval(IO.fromEither(Uri.fromString(endpoint))).flatMap { parsedEndpoint =>
// generous timeouts: model files are large and downloads can be slow
BlazeClientBuilder[IO]
.withRequestTimeout(120.second)
.withConnectTimeout(120.second)
.withIdleTimeout(200.seconds)
.resource
.map(httpClient => HuggingFaceClient(httpClient, parsedEndpoint))
}

}
39 changes: 39 additions & 0 deletions src/main/scala/ai/metarank/ml/onnx/ModelHandle.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package ai.metarank.ml.onnx

import io.circe.{Decoder, Encoder, Json}

import scala.util.{Failure, Success}

/** Reference to an embedding model: either a HuggingFace Hub handle ("ns/name")
  * or a local directory ("file:///path"). See [[ModelHandle.modelHandleDecoder]]
  * for the accepted string formats.
  */
sealed trait ModelHandle {
// human-readable model name (for local handles this is the directory path)
def name: String
// handle split into path segments, e.g. for building cache directory names
def asList: List[String]
}

object ModelHandle {

// Convenience constructor: a bare (namespace, name) pair is a HuggingFace handle.
def apply(ns: String, name: String) = HuggingFaceHandle(ns, name)

/** Model hosted on the HuggingFace Hub under "<ns>/<name>". */
case class HuggingFaceHandle(ns: String, name: String) extends ModelHandle {
override def asList: List[String] = List(ns, name)
}
/** Model stored in a local directory. */
case class LocalModelHandle(dir: String) extends ModelHandle {
override def name: String = dir
override def asList: List[String] = List(dir)
}

// NOTE: a Regex used as a pattern against a String requires a FULL match, so these
// alternatives cannot accidentally match a substring of a longer handle.
val huggingFacePattern = "([a-zA-Z0-9\\-]+)/([0-9A-Za-z\\-_]+)".r
// ORDER MATTERS below: for "file:///x" localPattern2 captures "/x", while
// localPattern1 would also match but capture "//x" — keep pattern2 tried first.
val localPattern1 = "file:/(/.+)".r
val localPattern2 = "file://(/.+)".r

// Parse a handle string: "ns/name" -> HuggingFaceHandle, "file://..." -> LocalModelHandle.
implicit val modelHandleDecoder: Decoder[ModelHandle] = Decoder.decodeString.emapTry {
case huggingFacePattern(ns, name) => Success(HuggingFaceHandle(ns, name))
case localPattern2(path) => Success(LocalModelHandle(path))
case localPattern1(path) => Success(LocalModelHandle(path))
case other => Failure(new Exception(s"cannot parse model handle '$other'"))
}

// Inverse of the decoder: local paths start with "/", so "file://$path" yields
// "file:///..." which round-trips through localPattern2 above.
implicit val modelHandleEncoder: Encoder[ModelHandle] = Encoder.instance {
case HuggingFaceHandle(ns, name) => Json.fromString(s"$ns/$name")
case LocalModelHandle(path) => Json.fromString(s"file://$path")
}
}
56 changes: 48 additions & 8 deletions src/main/scala/ai/metarank/ml/onnx/encoder/BertEncoder.scala
Original file line number Diff line number Diff line change
@@ -1,22 +1,62 @@
package ai.metarank.ml.onnx.encoder

import ai.metarank.ml.onnx.SBERT
import ai.metarank.ml.onnx.ModelHandle.{HuggingFaceHandle, LocalModelHandle}
import ai.metarank.ml.onnx.{HuggingFaceClient, ModelHandle, SBERT}
import ai.metarank.model.Identifier.ItemId
import ai.metarank.util.{LocalCache, Logging}
import cats.effect.IO
import org.apache.commons.io.IOUtils

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileInputStream}

case class BertEncoder(sbert: SBERT) extends Encoder {
override def dim: Int = sbert.dim

override def encode(str: String): Array[Float] = sbert.embed(str)
override def encode(str: String): Array[Float] = sbert.embed(str)
override def encode(id: ItemId, str: String): Array[Float] = sbert.embed(str)
}

object BertEncoder {
def create(model: String): IO[BertEncoder] = IO {
val sbert = SBERT(
model = this.getClass.getResourceAsStream(s"/sbert/$model.onnx"),
dic = this.getClass.getResourceAsStream("/sbert/sentence-transformer/vocab.txt")
object BertEncoder extends Logging {
def create(model: ModelHandle, modelFile: String, vocabFile: String): IO[BertEncoder] = model match {
case hf: HuggingFaceHandle => loadFromHuggingFace(hf, modelFile, vocabFile).map(BertEncoder.apply)
case local: LocalModelHandle => loadFromLocalDir(local, modelFile, vocabFile).map(BertEncoder.apply)
}

def loadFromHuggingFace(handle: HuggingFaceHandle, modelFile: String, vocabFile: String): IO[SBERT] = for {
cache <- LocalCache.create()
modelDirName <- IO(handle.asList.mkString(File.separator))
sbert <- HuggingFaceClient
.create()
.use(hf =>
for {
modelBytes <- cache.getIfExists(modelDirName, modelFile).flatMap {
case Some(bytes) => IO.pure(bytes)
case None => hf.modelFile(handle, modelFile).flatTap(bytes => cache.put(modelDirName, modelFile, bytes))
}
vocabBytes <- cache.getIfExists(modelDirName, vocabFile).flatMap {
case Some(bytes) => IO.pure(bytes)
case None => hf.modelFile(handle, vocabFile).flatTap(bytes => cache.put(modelDirName, vocabFile, bytes))
}
} yield {
SBERT(
model = new ByteArrayInputStream(modelBytes),
dic = new ByteArrayInputStream(vocabBytes)
)
}
)
} yield {
sbert
}

def loadFromLocalDir(handle: LocalModelHandle, modelFile: String, vocabFile: String): IO[SBERT] = for {
_ <- info(s"loading $modelFile from $handle")
modelBytes <- IO(IOUtils.toByteArray(new FileInputStream(new File(handle.dir + File.separator + modelFile))))
vocabBytes <- IO(IOUtils.toByteArray(new FileInputStream(new File(handle.dir + File.separator + vocabFile))))
} yield {
SBERT(
model = new ByteArrayInputStream(modelBytes),
dic = new ByteArrayInputStream(vocabBytes)
)
BertEncoder(sbert)
}

}
4 changes: 1 addition & 3 deletions src/main/scala/ai/metarank/ml/onnx/encoder/Encoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ trait Encoder {

object Encoder {
def create(conf: EncoderType): IO[Encoder] = conf match {
case EncoderType.BertEncoderType(model) => BertEncoder.create(model)
case EncoderType.BertEncoderType(model, modelFile, vocabFile) => BertEncoder.create(model, modelFile, vocabFile)
case EncoderType.CsvEncoderType(path) => CsvEncoder.create(path)

}

}
29 changes: 23 additions & 6 deletions src/main/scala/ai/metarank/ml/onnx/encoder/EncoderType.scala
Original file line number Diff line number Diff line change
@@ -1,26 +1,43 @@
package ai.metarank.ml.onnx.encoder

import ai.metarank.ml.onnx.ModelHandle
import io.circe.generic.semiauto._
import io.circe.{Decoder, DecodingFailure}

sealed trait EncoderType

object EncoderType {
case class BertEncoderType(model: String) extends EncoderType
case class BertEncoderType(
model: ModelHandle,
modelFile: String = "pytorch_model.onnx",
vocabFile: String = "vocab.txt"
) extends EncoderType

case class CsvEncoderType(path: String) extends EncoderType

implicit val bertDecoder: Decoder[BertEncoderType] = deriveDecoder[BertEncoderType]
implicit val bertDecoder: Decoder[BertEncoderType] = Decoder.instance(c =>
for {
model <- c.downField("model").as[ModelHandle]
modelFile <- c.downField("modelFile").as[Option[String]]
vocabFile <- c.downField("vocabFile").as[Option[String]]
} yield {
BertEncoderType(
model,
modelFile = modelFile.getOrElse("pytorch_model.onnx"),
vocabFile = vocabFile.getOrElse("vocab.txt")
)
}
)
implicit val bertEncoder: io.circe.Encoder[BertEncoderType] = deriveEncoder[BertEncoderType]
implicit val csvDecoder: Decoder[CsvEncoderType] = deriveDecoder[CsvEncoderType]
implicit val csvEncoder: io.circe.Encoder[CsvEncoderType] = deriveEncoder[CsvEncoderType]

implicit val encoderTypeDecoder: Decoder[EncoderType] = Decoder.instance(c =>
c.downField("type").as[String] match {
case Left(err) => Left(err)
case Right("bert") => bertDecoder.tryDecode(c)
case Right("csv") => csvDecoder.tryDecode(c)
case Right(other) => Left(DecodingFailure(s"cannot decode embedding type $other", c.history))
case Left(err) => Left(err)
case Right("transformer") => bertDecoder.tryDecode(c)
case Right("csv") => csvDecoder.tryDecode(c)
case Right(other) => Left(DecodingFailure(s"cannot decode embedding type $other", c.history))
}
)
}
2 changes: 1 addition & 1 deletion src/main/scala/ai/metarank/model/AnalyticsPayload.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ object AnalyticsPayload {
case NoopConfig(_) => "noop"
case TrendingConfig(_, _) => "trending"
case _: ALSConfig => "als"
case _: BertSemanticModelConfig => "bert"
case _: BertSemanticModelConfig => "semantic"
}.toList,
usedFeatures = config.features.map {
case f: RateFeatureSchema => UsedFeature(f.name, "rate")
Expand Down
14 changes: 7 additions & 7 deletions src/main/scala/ai/metarank/model/FeatureSchema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ object FeatureSchema {
val biencoder = implicitly[Decoder[FieldMatchBiencoderSchema]]
val term = implicitly[Decoder[FieldMatchSchema]]
c.downField("method").downField("type").as[String] match {
case Left(err) => Left(err)
case Right("bert") => biencoder.apply(c)
case Right("csv") => biencoder.apply(c)
case Right("term") => term.apply(c)
case Right("ngram") => term.apply(c)
case Right("bm25") => term.apply(c)
case Right(other) => Left(DecodingFailure(s"term matching method $other is not supported", c.history))
case Right("transformer") => biencoder.apply(c)
case Right("csv") => biencoder.apply(c)
case Right("term") => term.apply(c)
case Right("ngram") => term.apply(c)
case Right("bm25") => term.apply(c)
case Right(other) => Left(DecodingFailure(s"term matching method $other is not supported", c.history))
case Left(err) => Left(err)
}
case "referer" => implicitly[Decoder[RefererSchema]].apply(c)
case "position" => implicitly[Decoder[PositionFeatureSchema]].apply(c)
Expand Down
Loading