
Commit

show progress for embedding cache loading (#1001)
shuttie committed Apr 11, 2023
1 parent e3e5f7c commit fa8d0a2
Showing 2 changed files with 14 additions and 8 deletions.
src/main/scala/ai/metarank/ml/onnx/EmbeddingCache.scala (6 changes: 4 additions & 2 deletions)
@@ -1,5 +1,6 @@
 package ai.metarank.ml.onnx
 
+import ai.metarank.flow.PrintProgress
 import ai.metarank.util.CSVStream
 import cats.effect.IO

@@ -13,7 +14,8 @@ object EmbeddingCache {
   def empty(): EmbeddingCache = EmbeddingCache(Map.empty)
   def fromStream(stream: fs2.Stream[IO, Array[String]], dim: Int): IO[EmbeddingCache] =
     stream
-      .evalMapChunk(line => IO.fromEither(parseEmbedding(line, dim)))
+      .parEvalMapUnordered(8)(line => IO.fromEither(parseEmbedding(line, dim)))
+      .through(PrintProgress.tap(None, "embeddings"))
       .compile
       .toList
       .map(list => EmbeddingCache(list.map(e => e.key -> e.emb).toMap))
@@ -24,7 +26,7 @@ object EmbeddingCache {
 
   def parseEmbedding(line: Array[String], dim: Int): Either[Throwable, Embedding] = {
     if (line.length != dim + 1) {
-      Left(new Exception(s"dim mismatch for line ${line.toList}"))
+      Left(new Exception(s"dim mismatch for line ${line.toList}: expected $dim, got line with ${line.length} cols"))
     } else {
       val key = line(0)
       val buffer = new Array[Float](dim)
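Taken together, this hunk makes cache loading both parallel and observable: parsing runs on up to 8 concurrent fibers via parEvalMapUnordered(8), and the stream is routed through PrintProgress.tap so long CSV loads report progress instead of appearing to hang. Below is a minimal sketch of what such a progress tap can look like, assuming fs2 3.x and cats-effect 3; it is an illustration of the idea, not Metarank's actual PrintProgress implementation.

    import cats.effect.IO
    import fs2.Pipe

    object ProgressTap {
      // Count elements as they pass through and print a status line every
      // `every` items; the stream's elements are passed on unchanged.
      def tap[A](label: String, every: Int = 10000): Pipe[IO, A, A] =
        in =>
          in.zipWithIndex
            .evalTap { case (_, i) =>
              IO.whenA((i + 1) % every == 0)(IO.println(s"$label: processed ${i + 1} items"))
            }
            .map { case (a, _) => a }
    }

With a pipe like this in place, the load reads roughly as stream.parEvalMapUnordered(8)(parse).through(ProgressTap.tap("embeddings")).compile.toList, which is the shape fromStream has after this commit.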
src/main/scala/ai/metarank/ml/onnx/encoder/Encoder.scala (16 changes: 10 additions & 6 deletions)
@@ -2,6 +2,7 @@ package ai.metarank.ml.onnx.encoder
 
 import ai.metarank.ml.onnx.EmbeddingCache
 import ai.metarank.model.Identifier.ItemId
+import ai.metarank.util.Logging
 import cats.effect.IO
 
 trait Encoder {
@@ -10,16 +11,16 @@ trait Encoder {
   def dim: Int
 }
 
-object Encoder {
+object Encoder extends Logging {
   def create(conf: EncoderType): IO[Encoder] = conf match {
     case EncoderType.BertEncoderType(model, itemCache, rankCache, modelFile, vocabFile, dim) =>
       for {
-        items <- itemCache match {
-          case Some(path) => EmbeddingCache.fromCSV(path, ',', dim)
+        fields <- rankCache match {
+          case Some(path) => info("Loading ranking embeddings") *> EmbeddingCache.fromCSV(path, ',', dim)
           case None => IO.pure(EmbeddingCache.empty())
         }
-        fields <- rankCache match {
-          case Some(path) => EmbeddingCache.fromCSV(path, ',', dim)
+        items <- itemCache match {
+          case Some(path) => info("Loading item embeddings") *> EmbeddingCache.fromCSV(path, ',', dim)
           case None => IO.pure(EmbeddingCache.empty())
         }
         bert <- BertEncoder.create(model, modelFile, vocabFile)
@@ -29,8 +30,11 @@
 
     case EncoderType.CsvEncoderType(itemPath, fieldPath, dim) =>
      for {
-        items <- EmbeddingCache.fromCSV(itemPath, ',', dim)
+        _ <- info("Loading ranking embeddings")
         fields <- EmbeddingCache.fromCSV(fieldPath, ',', dim)
+        _ <- info("Loading item embeddings")
+        items <- EmbeddingCache.fromCSV(itemPath, ',', dim)
+
       } yield {
         CachedEncoder(items, fields, ZeroEncoder(dim))
       }
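The Encoder.scala change applies the same idea one level up: each cache load is now preceded by an info log line, and sequencing it with *> (or as its own for-comprehension step) guarantees the message is printed before the potentially slow CSV read starts, so the user can tell which file is currently loading. A standalone sketch of the pattern, assuming cats-effect 3 and using a hypothetical loadCache stand-in for EmbeddingCache.fromCSV and a plain IO.println stand-in for the Logging trait's info:

    import cats.effect.IO

    object LoadWithProgress {
      // Stand-in for Logging.info; the real trait writes through a logger.
      def info(msg: String): IO[Unit] = IO.println(msg)

      // Hypothetical placeholder for EmbeddingCache.fromCSV(path, ',', dim).
      def loadCache(path: String): IO[Map[String, Array[Float]]] = IO.pure(Map.empty)

      def loadAll(itemPath: String, fieldPath: String): IO[(Map[String, Array[Float]], Map[String, Array[Float]])] =
        for {
          _      <- info("Loading ranking embeddings")
          fields <- loadCache(fieldPath)
          _      <- info("Loading item embeddings")
          items  <- loadCache(itemPath)
        } yield (items, fields)
    }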
