# Load Huggingface Fill-Mask PyTorch Model

In [1]:
import $ivy.`ai.djl:api:0.17.0`
import $ivy.`ai.djl.huggingface:tokenizers:0.17.0`
import $ivy.`ai.djl.pytorch:pytorch-model-zoo:0.17.0`
import $ivy.`ai.djl.pytorch:pytorch-engine:0.17.0`
import $ivy.`org.slf4j:slf4j-api:1.7.36`
import $ivy.`org.slf4j:slf4j-simple:1.7.36`

[32mimport [39m[36m$ivy.$                  
[39m
[32mimport [39m[36m$ivy.$                                     
[39m
[32mimport [39m[36m$ivy.$                                        
[39m
[32mimport [39m[36m$ivy.$                                     
[39m
[32mimport [39m[36m$ivy.$                           
[39m
[32mimport [39m[36m$ivy.$                              [39m

In [2]:
import java.io.IOException
import java.nio.file.{Files, Paths}
import java.util

import ai.djl.modality.Classifications
import ai.djl.modality.nlp.DefaultVocabulary
import ai.djl.modality.nlp.bert.BertTokenizer
import ai.djl.modality.nlp.Vocabulary
import ai.djl.ndarray.NDList
import ai.djl.repository.zoo.Criteria
import ai.djl.training.util.{DownloadUtils, ProgressBar}
import ai.djl.translate.{Batchifier, Translator, TranslatorContext}
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer

[32mimport [39m[36mjava.io.IOException
[39m
[32mimport [39m[36mjava.nio.file.{Files, Paths}
[39m
[32mimport [39m[36mjava.util

[39m
[32mimport [39m[36mai.djl.modality.Classifications
[39m
[32mimport [39m[36mai.djl.modality.nlp.DefaultVocabulary
[39m
[32mimport [39m[36mai.djl.modality.nlp.bert.BertTokenizer
[39m
[32mimport [39m[36mai.djl.modality.nlp.Vocabulary
[39m
[32mimport [39m[36mai.djl.ndarray.NDList
[39m
[32mimport [39m[36mai.djl.repository.zoo.Criteria
[39m
[32mimport [39m[36mai.djl.training.util.{DownloadUtils, ProgressBar}
[39m
[32mimport [39m[36mai.djl.translate.{Batchifier, Translator, TranslatorContext}
[39m
[32mimport [39m[36mai.djl.huggingface.tokenizers.HuggingFaceTokenizer[39m

In [3]:
/**
 * A single candidate token paired with the score assigned to it.
 *
 * @param token the candidate token text
 * @param score the score associated with the token
 */
case class PredictedToken(
  token: String,
  score: Double
)

defined [32mclass[39m [36mPredictedToken[39m

In [4]:
/**
 * Translator for a BERT fill-mask model.
 *
 * `processInput` tokenizes a sentence containing `[MASK]` and builds the token-index and
 * attention-mask arrays; `processOutput` reads the model output at the mask position and
 * returns up to `TopK` candidate tokens, best first.
 *
 * The tokens of the most recent input are kept in an instance field between `processInput`
 * and `processOutput`, so a single instance must not be used for overlapping predictions.
 */
class HFBertFillMaskTranslator extends Translator[String, Seq[PredictedToken]] {

  private var vocabulary: DefaultVocabulary = _
  private var tokenizer: HuggingFaceTokenizer = _
  // Starts empty (not null) so processOutput returns no predictions instead of throwing
  // when it is invoked before any input has been processed.
  private var tokenList: Array[String] = Array.empty[String]
  private final val MaskToken = "[MASK]"
  private final val TopK = 5

  /** Loads the vocabulary file and creates the tokenizer used by the other methods. */
  override def prepare(ctx: TranslatorContext): Unit = {
    val path = Paths.get("build/huggingface/fill_mask/pytorch/bert-base-uncased/vocab.txt")
    vocabulary = DefaultVocabulary.builder
      .optMinFrequency(1)
      .addFromTextFile(path)
      .optUnknownToken("[UNK]")
      .build
    tokenizer = HuggingFaceTokenizer.newInstance("bert-base-uncased")
  }

  /**
   * Lower-cases the input (keeping the mask marker upper-case), tokenizes it and returns
   * the token indices followed by the attention mask.
   *
   * @param ctx   translator context supplying the NDManager
   * @param input sentence to complete; the mask marker may be written in any case
   * @return an NDList of (token indices, attention mask)
   */
  override def processInput(ctx: TranslatorContext, input: String): NDList = {
    // Locale.ROOT keeps the lower-casing independent of the JVM's default locale.
    val normalized = input
      .toLowerCase(java.util.Locale.ROOT)
      .replace(MaskToken.toLowerCase(java.util.Locale.ROOT), MaskToken)
    val encoding = tokenizer.encode(normalized)
    // remember the encoded tokens; processOutput needs them to locate the mask position
    tokenList = encoding.getTokens

    val manager = ctx.getNDManager
    // map the tokens (String) to vocabulary indices (long)
    val indices = tokenList.map(vocabulary.getIndex)
    val attentionMask = encoding.getAttentionMask
    val indicesArray = manager.create(indices)
    val attentionMaskArray = manager.create(attentionMask)

    new NDList(indicesArray, attentionMaskArray)
  }

  /**
   * Returns the best-scoring candidates for the mask position, highest score first.
   *
   * @param ctx  translator context (unused)
   * @param list model output; the first array is indexed as [token position, vocabulary index]
   * @return at most `TopK` predictions, or an empty sequence when the last input had no mask
   */
  override def processOutput(ctx: TranslatorContext, list: NDList): Seq[PredictedToken] = {
    val maskIndex = tokenList.indexOf(MaskToken)
    if (maskIndex < 0) {
      Seq.empty[PredictedToken]
    } else {
      val ndArray = list.get(0)
      val len = ndArray.getShape.get(1)
      // Sort the mask row once; the best candidates are then read from the end of the
      // ranking (argSort is expected to order ascending).
      val ranking = ndArray.get(maskIndex).argSort()
      // Never ask for more candidates than the row holds, so `len - i` stays non-negative.
      val k = math.min(TopK.toLong, len).toInt

      (1 to k).map { i =>
        val out = ranking.getLong(len - i)
        PredictedToken(vocabulary.getToken(out), ndArray.getFloat(maskIndex, out))
      }
    }
  }

  override def getBatchifier: Batchifier = Batchifier.STACK
}

defined [32mclass[39m [36mHFBertFillMaskTranslator[39m

In [5]:
// Sample sentence: lower-cased, with the mask marker restored to its upper-case form.
val input = "Paris is the [MASK] of France."
  .toLowerCase
  .replace("[mask]", "[MASK]")

val translator = new HFBertFillMaskTranslator

// Describe the model to load: String in, predicted tokens out, read from the local directory.
val criteria = Criteria.builder()
  .setTypes(classOf[String], classOf[Seq[PredictedToken]])
  .optModelPath(Paths.get("build/huggingface/fill_mask/pytorch/bert-base-uncased/"))
  .optTranslator(translator)
  .optProgress(new ProgressBar)
  .build()

val model = criteria.loadModel()
val predictor = model.newPredictor(translator)

// Run the prediction and print one candidate per line.
val predictResult = predictor.predict(input)
for (prediction <- predictResult) println(prediction)

Loading:     100% |████████████████████████████████████████|


[scala-interpreter-1] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 4
[scala-interpreter-1] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 4


PredictedToken(capital,18.19974136352539)
PredictedToken(heart,10.769936561584473)
PredictedToken(center,10.469231605529785)
PredictedToken(centre,10.209856986999512)
PredictedToken(city,9.985564231872559)


[36minput[39m: [32mString[39m = [32m"paris is the [MASK] of france."[39m
[36mtranslator[39m: [32mHFBertFillMaskTranslator[39m = ammonite.$sess.cmd3$Helper$HFBertFillMaskTranslator@56774091
[36mcriteria[39m: [32mCriteria[39m[[32mString[39m, [32mSeq[39m[[32mPredictedToken[39m]] = Criteria:
	Application: UNDEFINED
	Input: class java.lang.String
	Output: interface scala.collection.Seq
	ModelZoo: ai.djl.localmodelzoo

[36mmodel[39m: [32mai[39m.[32mdjl[39m.[32mrepository[39m.[32mzoo[39m.[32mZooModel[39m[[32mString[39m, [32mSeq[39m[[32mPredictedToken[39m]] = ai.djl.repository.zoo.ZooModel@5ab918f6
[36mpredictor[39m: [32mai[39m.[32mdjl[39m.[32minference[39m.[32mPredictor[39m[[32mString[39m, [32mSeq[39m[[32mPredictedToken[39m]] = ai.djl.inference.Predictor@51b24605
[36mpredictResult[39m: [32mSeq[39m[[32mPredictedToken[39m] = [33mVector[39m(
  [33mPredictedToken[39m([32m"capital"[39m, [32m18.19974136352539[39m),
  [33mPredicted