# Load your own PyTorch BERT model

In the previous [example](https://github.com/deepjavalibrary/djl/blob/master/jupyter/BERTQA.ipynb), you ran BERT inference with a model from the Model Zoo. You can also load your own pre-trained BERT model and use custom classes as the input and output.

In general, the PyTorch BERT model from [HuggingFace](https://github.com/huggingface/transformers) requires these three inputs:

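- word indices: The vocabulary index of each word (token) in the sentence.
- word types: The type index of each word, which marks the segment it belongs to (for example, the question or the paragraph).
- attention mask: A mask that tells the model which tokens to attend to and which to ignore, such as the padding added when sequences are batched together.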

We will dive deep into these details later.
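As a rough illustration of these three inputs, the following sketch builds them for a toy question and paragraph. It assumes a whitespace tokenizer and a tiny made-up vocabulary (`BertInputSketch`, `toyVocab`, and `encode` are hypothetical stand-ins, not the real `BertTokenizer` or vocabulary file), and it follows the usual BERT layout of `[CLS] question [SEP] paragraph [SEP]` followed by padding.

```scala
// Illustrative sketch only: a whitespace tokenizer and a made-up vocabulary,
// not the real BertTokenizer or the bert-base-uncased vocabulary.
object BertInputSketch {
  final case class Encoded(
      tokens: Vector[String],
      indices: Vector[Long],
      tokenTypes: Vector[Long],
      attentionMask: Vector[Long])

  // Hypothetical toy vocabulary; real indices come from the vocab file.
  val toyVocab: Map[String, Long] = Map(
    "[PAD]" -> 0L, "[UNK]" -> 1L, "[CLS]" -> 2L, "[SEP]" -> 3L,
    "who" -> 4L, "wrote" -> 5L, "it" -> 6L, "alice" -> 7L)

  def encode(question: String, paragraph: String,
             vocab: Map[String, Long], maxLength: Int): Encoded = {
    val q = question.toLowerCase.split("\\s+").toVector
    val p = paragraph.toLowerCase.split("\\s+").toVector
    // [CLS] question [SEP] paragraph [SEP]
    val real = ("[CLS]" +: q :+ "[SEP]") ++ (p :+ "[SEP]")
    require(real.length <= maxLength, "this sketch does not truncate")
    val padCount = maxLength - real.length
    val tokens = real ++ Vector.fill(padCount)("[PAD]")

    // word indices: position of each token in the vocabulary
    val unk = vocab("[UNK]")
    val indices = tokens.map(t => vocab.getOrElse(t, unk))
    // word types: 0 for the question segment, 1 for the paragraph, 0 for padding
    val tokenTypes = Vector.fill(q.length + 2)(0L) ++
      Vector.fill(p.length + 1)(1L) ++ Vector.fill(padCount)(0L)
    // attention mask: 1 for real tokens, 0 for padding
    val attentionMask = Vector.fill(real.length)(1L) ++ Vector.fill(padCount)(0L)

    Encoded(tokens, indices, tokenTypes, attentionMask)
  }
}
```

Calling `BertInputSketch.encode("Who wrote it", "Alice wrote it", BertInputSketch.toyVocab, 10)` produces nine real tokens plus one `[PAD]`, so the attention mask is nine ones followed by a zero and the token types are five zeros, four ones, and a trailing zero.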

## Preparation

In [1]:
import $ivy.`ai.djl:api:0.17.0`
import $ivy.`ai.djl.pytorch:pytorch-model-zoo:0.17.0`
import $ivy.`ai.djl.pytorch:pytorch-engine:0.17.0`
import $ivy.`org.slf4j:slf4j-api:1.7.36`
import $ivy.`org.slf4j:slf4j-simple:1.7.36`

## Import Java packages

In [2]:
import java.io.IOException
import java.nio.file.Paths
import java.util

import ai.djl.modality.nlp.DefaultVocabulary
import ai.djl.modality.nlp.bert.BertTokenizer
import ai.djl.modality.nlp.qa.QAInput
import ai.djl.ndarray.NDList
import ai.djl.repository.zoo.Criteria
import ai.djl.training.util.{DownloadUtils, ProgressBar}
import ai.djl.translate.{Batchifier, Translator, TranslatorContext}

## Download model and vocab files

In [3]:
DownloadUtils.download(
  "https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/pytorch/bertqa/0.0.1/bert-base-uncased-vocab.txt.gz",
  "build/pytorch/bert-qa/vocab.txt",
  new ProgressBar
)

DownloadUtils.download(
  "https://djl-ai.s3.amazonaws.com/mlrepo/model/nlp/question_answer/ai/djl/pytorch/bertqa/0.0.1/trace_bertqa.pt.gz",
  "build/pytorch/bert-qa/bert-qa.pt",
  new ProgressBar
)

## Create translator

In [4]:
class BertQuAnTranslator extends Translator[QAInput, String] {

  private var vocabulary: DefaultVocabulary = _
  private var tokenizer: BertTokenizer = _
  private var tokenList: util.List[String] = _

  override def prepare(ctx: TranslatorContext): Unit = {
    val path = Paths.get("build/pytorch/bert-qa/vocab.txt")
    vocabulary = DefaultVocabulary.builder
      .optMinFrequency(1)
      .addFromTextFile(path)
      .optUnknownToken("[UNK]")
      .build
    tokenizer = new BertTokenizer()
  }

  override def processInput(ctx: TranslatorContext, input: QAInput): NDList = {
    // Encode the lowercased question and paragraph with a maximum length of 384
    val token = tokenizer.encode(input.getQuestion.toLowerCase, input.getParagraph.toLowerCase, 384)
    // Keep the encoded tokens; processOutput uses them to recover the answer text
    tokenList = token.getTokens

    val manager = ctx.getNDManager
    // Map the tokens (String) to vocabulary indices (long)
    val indices = tokenList.stream().mapToLong(vocabulary.getIndex).toArray
    val attentionMask = token.getAttentionMask.stream().mapToLong(i => i).toArray
    val tokenType = token.getTokenTypes.stream().mapToLong(i => i).toArray
    val indicesArray = manager.create(indices)
    val attentionMaskArray = manager.create(attentionMask)
    val tokenTypeArray = manager.create(tokenType)

    // The order of the inputs matters to the traced model
    new NDList(indicesArray, attentionMaskArray, tokenTypeArray)
  }

  override def processOutput(ctx: TranslatorContext, list: NDList): String = {
    // The model returns one score per token for the answer start and the answer end
    val startLogits = list.get(0)
    val endLogits = list.get(1)
    val startIdx = startLogits.argMax().getLong().toInt
    val endIdx = endLogits.argMax().getLong().toInt
    // The answer is the span of tokens between the two positions, inclusive
    tokenList.subList(startIdx, endIdx + 1).toString
  }

  override def getBatchifier: Batchifier = Batchifier.STACK
}


defined class BertQuAnTranslator
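The span selection in `processOutput` can be sketched in plain Scala without DJL: take the highest-scoring start position and the highest-scoring end position, then slice the tokens between them. `AnswerSpanSketch`, `argMax`, and `extract` are hypothetical names used only for illustration. Unlike the translator above, this sketch returns an empty result when the predicted end falls before the start, which is one possible way to handle that case.

```scala
// Illustrative sketch of answer-span selection; not part of DJL.
object AnswerSpanSketch {
  // Index of the largest value (the first one if several are equal).
  def argMax(xs: Seq[Float]): Int = {
    require(xs.nonEmpty, "logits must not be empty")
    var best = 0
    for (i <- 1 until xs.length) if (xs(i) > xs(best)) best = i
    best
  }

  // Tokens from the best start position to the best end position, inclusive.
  def extract(tokens: Seq[String],
              startLogits: Seq[Float],
              endLogits: Seq[Float]): Seq[String] = {
    val start = argMax(startLogits)
    val end = argMax(endLogits)
    // Assumption of this sketch: an inverted span yields no answer.
    if (end < start) Seq.empty else tokens.slice(start, end + 1)
  }
}
```

With tokens `[CLS], when, [SEP], december, 2004, [SEP]`, start scores that peak at position 3, and end scores that peak at position 4, `extract` returns `december, 2004`.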

## Load PyTorch model

In [5]:
val translator = new BertQuAnTranslator()
val criteria = Criteria.builder
  .setTypes(classOf[QAInput], classOf[String])
  .optModelPath(Paths.get("build/pytorch/bert-qa/"))
  .optTranslator(translator)
  .optProgress(new ProgressBar)
  .build

val model = criteria.loadModel()
val predictor = model.newPredictor(translator)

Loading:     100% |████████████████████████████████████████|


[scala-interpreter-1] INFO ai.djl.pytorch.engine.PtEngine - Number of inter-op threads is 4
[scala-interpreter-1] INFO ai.djl.pytorch.engine.PtEngine - Number of intra-op threads is 4

translator: BertQuAnTranslator = ammonite.$sess.cmd3$Helper$BertQuAnTranslator@221769bc
criteria: Criteria[QAInput, String] = Criteria:
	Application: UNDEFINED
	Input: class ai.djl.modality.nlp.qa.QAInput
	Output: class java.lang.String
	ModelZoo: ai.djl.localmodelzoo

model: ai.djl.repository.zoo.ZooModel[QAInput, String] = ai.djl.repository.zoo.ZooModel@7a1ac919
predictor: ai.djl.inference.Predictor[QAInput, String] = ai.djl.inference.Predictor@547dbfd

## Test model

In [6]:
val question = "When did BBC Japan start broadcasting?"
val context = "BBC Japan was a general entertainment Channel.\n" +
    "Which operated between December 2004 and April 2006.\n" +
    "It ceased operations after its Japanese distributor folded."
val input = new QAInput(question, context)
val predictResult = predictor.predict(input)

question: String = "When did BBC Japan start broadcasting?"
context: String = """BBC Japan was a general entertainment Channel.
Which operated between December 2004 and April 2006.
It ceased operations after its Japanese distributor folded."""
input: QAInput = ai.djl.modality.nlp.qa.QAInput@2cb677a9
predictResult: String = "[december, 2004]"

**A slightly more complicated example**

In [7]:
val context = """The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct cultural and ethnic identity of the Normans emerged initially in the first half of the 10th century, and it continued to evolve over the succeeding centuries."""

In [8]:
val q1 = "When were the Normans in Normandy?"
val q2 = "In what country is Normandy located?"

q1: String = "When were the Normans in Normandy?"
q2: String = "In what country is Normandy located?"

In [9]:
predictor.predict(new QAInput(q1, context))

res8: String = "[10th, and, 11th, centuries]"

In [10]:
predictor.predict(new QAInput(q2, context))

res9: String = "[france]"

In [11]:
predictor.close()
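
The predictions above come back as the string form of a token list, such as `[10th, and, 11th, centuries]`. If a readable answer is needed, one option is to join the tokens with spaces and merge WordPiece continuation pieces, which BERT vocabularies mark with a `##` prefix. The helper below is a minimal sketch of that idea; `WordPieceJoinSketch` and `join` are hypothetical names, not part of DJL, and the sketch does not try to restore punctuation spacing or casing.

```scala
// Illustrative sketch: turn a list of WordPiece tokens into readable text.
object WordPieceJoinSketch {
  def join(tokens: Seq[String]): String =
    tokens.foldLeft(new StringBuilder) { (sb, token) =>
      if (token.startsWith("##")) {
        // Continuation piece: glue it to the previous token without a space
        sb.append(token.drop(2))
      } else {
        // New word: separate it from the previous one with a single space
        if (sb.nonEmpty) sb.append(' ')
        sb.append(token)
      }
    }.toString
}
```

For example, `WordPieceJoinSketch.join(Seq("10th", "and", "11th", "centuries"))` gives `10th and 11th centuries`. To use it with the translator, `processOutput` would need to return the token list (or the joined string) rather than calling `toString` on the sublist.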