# Joule BERT Inference Demo

## Introduction

In this tutorial, you'll walk through running inference with a BERT question-answering (QA) model trained with MXNet.
You provide a question and a paragraph containing the answer, and the model finds the best answer within that paragraph.

Example:
```text
Q: When did BBC Japan start broadcasting?
```

Answer paragraph:
```text
BBC Japan was a general entertainment channel, which operated between December 2004 and April 2006.
It ceased operations after its Japanese distributor folded.
```
The model picks the right answer:
```text
A: December 2004
```


### Step 1 Configure the Maven repository
The following command defines the repository from which the Joule package is fetched.

In [37]:
%mavenRepo s3 https://joule.s3.amazonaws.com/repo

### Step 2 Import the required library
Please run the following commands to load the Joule package and its dependencies.

In [38]:
%maven software.amazon.ai:joule-api:0.2.0-SNAPSHOT
%maven org.apache.mxnet:mxnet-joule:0.2.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven net.java.dev.jna:jna:5.3.0

Due to a problem with the Gradle integration in Jupyter, we need to load the MXNet package manually from a POM.

Please specify the MXNet package you would like to use by changing the `<classifier>` tag. The two options below cover Mac and Linux systems.

#### Mac OS
```
<classifier>osx-x86_64</classifier>
```

#### Ubuntu 16.04/CentOS 7/Amazon Linux
```
<classifier>linux-x86_64</classifier>
```
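
If you are unsure which classifier applies to your machine, the choice can also be derived from the JVM's `os.name` property. The following is a minimal sketch and not part of Joule: `ClassifierSketch` and `classifierFor` are hypothetical names, and the sketch assumes an x86_64 machine and covers only the two platforms listed above.

```java
import java.util.Locale;

/** Hypothetical helper (not part of Joule): maps an OS name to a classifier listed above. */
final class ClassifierSketch {

    private ClassifierSketch() {}

    /** Returns the classifier for the given value of the os.name system property. */
    static String classifierFor(String osName) {
        String os = osName.toLowerCase(Locale.ROOT);
        if (os.contains("mac")) {
            return "osx-x86_64";
        }
        if (os.contains("linux")) {
            return "linux-x86_64";
        }
        // Only the two platforms named in this tutorial are covered.
        throw new IllegalArgumentException("No classifier listed for OS: " + osName);
    }
}
```

Calling `ClassifierSketch.classifierFor(System.getProperty("os.name"))` returns the string to place inside the `<classifier>` tag.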

In [39]:
%%loadFromPOM
  <repositories>
    <repository>
      <id>joule</id>
      <url>https://joule.s3.amazonaws.com/repo</url>
    </repository>
  </repositories>

  <dependencies>
    <dependency>
      <groupId>org.apache.mxnet</groupId>
      <artifactId>mxnet-native-mkl</artifactId>
      <version>1.5.0-SNAPSHOT</version>
      <classifier>osx-x86_64</classifier>
    </dependency>
  </dependencies>


Import the libraries used in this tutorial.

In [82]:
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.util.*;
import java.util.regex.*;
import java.io.*;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.annotations.SerializedName;

import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import org.slf4j.Logger;
import software.amazon.ai.Context;
import software.amazon.ai.Model;
import software.amazon.ai.TranslateException;
import software.amazon.ai.Translator;
import software.amazon.ai.TranslatorContext;
import software.amazon.ai.inference.Predictor;
import software.amazon.ai.metric.Metrics;
import software.amazon.ai.ndarray.NDArray;
import software.amazon.ai.ndarray.NDList;
import software.amazon.ai.ndarray.NDManager;
import software.amazon.ai.ndarray.types.DataDesc;
import software.amazon.ai.ndarray.types.DataType;
import software.amazon.ai.ndarray.types.Shape;
import software.amazon.ai.util.Utils;

### Step 3 Load the BertDataParser
The `BertDataParser` class loads the vocabulary that the BERT embedding was trained on. Please do not change the following code.

In [83]:
/**
 * This is the Utility for pre-processing the data for Bert Model.
 *
 * <p>You can use this utility to parse vocabulary JSON into Java Array and Dictionary, clean and
 * tokenize sentences and pad the text
 */

public class BertDataParser {

    private static final Gson GSON = new GsonBuilder().create();
    private static final Pattern PATTERN = Pattern.compile("(\\S+?)([.,?!])?(\\s+|$)");

    @SerializedName("token_to_idx")
    private Map<String, Integer> token2idx;

    @SerializedName("idx_to_token")
    private List<String> idx2token;

    /**
     * Parse the vocabulary from a JSON file. [PAD], [CLS], [SEP], [MASK], and [UNK] are reserved
     * tokens.
     *
     * @param is the {@code InputStream} for the vocab.json
     * @return instance of {@code BertDataParser}
     * @throws IllegalStateException if failed read from {@code InputStream}
     */
    public static BertDataParser parse(InputStream is) {
        try (Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
            return GSON.fromJson(reader, BertDataParser.class);
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Tokenize the input, split all kinds of whitespace and Separate the end of sentence symbol: .
     * , ? !
     *
     * @param input The input string
     * @return List of tokens
     */
    public static List<String> tokenizer(String input) {
        List<String> ret = new LinkedList<>();

        Matcher m = PATTERN.matcher(input);
        while (m.find()) {
            ret.add(m.group(1));
            String token = m.group(2);
            if (token != null) {
                ret.add(token);
            }
        }

        return ret;
    }

    /**
     * Pad the tokens to the required length.
     *
     * @param <E> the type of the List
     * @param tokens input tokens
     * @param padItem things to pad at the end
     * @param num total length after padding
     * @return List of padded tokens
     */
    public static <E> List<E> pad(List<E> tokens, E padItem, int num) {
        if (tokens.size() >= num) {
            return tokens;
        }
        List<E> padded = new ArrayList<>(num);
        padded.addAll(tokens);
        for (int i = tokens.size(); i < num; ++i) {
            padded.add(padItem);
        }
        return padded;
    }

    /**
     * Form the token types List [0000...1111...000] where all questions are 0 and answers are 1.
     *
     * @param question question tokens
     * @param answer answer tokens
     * @param seqLength sequence length
     * @return List of tokenTypes
     */
    public static List<Float> getTokenTypes(
            List<String> question, List<String> answer, int seqLength) {
        List<Float> qaEmbedded = new ArrayList<>();
        qaEmbedded = pad(qaEmbedded, 0f, question.size() + 2);
        qaEmbedded.addAll(pad(new ArrayList<>(), 1f, answer.size()));
        return pad(qaEmbedded, 0f, seqLength);
    }

    /**
     * Form tokens with separation that can be used for BERT.
     *
     * @param question question tokens
     * @param answer answer tokens
     * @param seqLength sequence length
     * @return List of BERT-formatted tokens
     */
    public static List<String> formTokens(
            List<String> question, List<String> answer, int seqLength) {
        // make BERT pre-processing standard
        List<String> tokens = new ArrayList<>(question);
        tokens.add("[SEP]");
        tokens.add(0, "[CLS]");
        answer.add("[SEP]");
        tokens.addAll(answer);
        tokens.add("[SEP]");
        return pad(tokens, "[PAD]", seqLength);
    }

    /**
     * Convert tokens to indexes.
     *
     * @param tokens input tokens
     * @return List of indexes
     */
    public List<Integer> token2idx(List<String> tokens) {
        List<Integer> indexes = new ArrayList<>();
        for (String token : tokens) {
            if (token2idx.containsKey(token)) {
                indexes.add(token2idx.get(token));
            } else {
                indexes.add(token2idx.get("[UNK]"));
            }
        }
        return indexes;
    }

    /**
     * Convert indexes to tokens.
     *
     * @param indexes List of indexes
     * @return List of tokens
     */
    public List<String> idx2token(List<Integer> indexes) {
        List<String> tokens = new ArrayList<>();
        for (int index : indexes) {
            tokens.add(idx2token.get(index));
        }
        return tokens;
    }
}
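
To see what the tokenizer produces, the regular expression from `BertDataParser` can be tried in isolation. The sketch below copies that pattern into a hypothetical `TokenizerSketch` class (the class name and the `tokenize` method are illustrative and not part of the tutorial code) so that it can be run without the vocabulary file.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Standalone copy of the tokenizer regex, for experimentation only. */
final class TokenizerSketch {

    // Same pattern as BertDataParser: a word, an optional trailing . , ? !,
    // then whitespace or the end of the input.
    private static final Pattern PATTERN = Pattern.compile("(\\S+?)([.,?!])?(\\s+|$)");

    private TokenizerSketch() {}

    static List<String> tokenize(String input) {
        List<String> tokens = new ArrayList<>();
        Matcher m = PATTERN.matcher(input);
        while (m.find()) {
            tokens.add(m.group(1)); // the word itself
            if (m.group(2) != null) {
                tokens.add(m.group(2)); // trailing punctuation becomes its own token
            }
        }
        return tokens;
    }
}
```

For the question used in this tutorial, `TokenizerSketch.tokenize("When did BBC Japan start broadcasting?")` yields `[When, did, BBC, Japan, start, broadcasting, ?]`. Punctuation is only split off at the end of a word, so a string such as `a,b` stays a single token.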

At this point, all of the preparation is done. Let's start writing code to run inference with this example.

### Step 4 Preparing for the model and input

The model requires three inputs (word indices, word types, and valid length). The sequence length is a fixed setting that bounds the first two:

- word indices: The vocabulary index of each word in the sentence
- word types: The type index of each word. All question tokens are labelled 0 and all answer tokens are labelled 1.
- sequence length: The fixed length of the input. In our case, the length is 384.
- valid length: The length of the question and answer tokens
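
The layout of these inputs is easier to see on a tiny example. The sketch below is a simplified illustration and not the tutorial's `BertDataParser`: the class `BertInputSketch` and its methods are hypothetical, it assumes a single closing `[SEP]` that is labelled with the answer's type, and it takes the valid length to be the combined count of question and answer tokens as described above (this tutorial does not state whether special tokens should be counted).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Illustrative input layout for a small sequence length; not the tutorial's parser. */
final class BertInputSketch {

    private BertInputSketch() {}

    /** Builds [CLS] question [SEP] answer [SEP], padded with [PAD] up to seqLength. */
    static List<String> formTokens(List<String> question, List<String> answer, int seqLength) {
        List<String> tokens = new ArrayList<>();
        tokens.add("[CLS]");
        tokens.addAll(question);
        tokens.add("[SEP]");
        tokens.addAll(answer);
        tokens.add("[SEP]");
        while (tokens.size() < seqLength) {
            tokens.add("[PAD]");
        }
        return tokens;
    }

    /** 0 for [CLS], the question and its [SEP]; 1 for the answer and its [SEP]; 0 for padding. */
    static float[] tokenTypes(int questionSize, int answerSize, int seqLength) {
        float[] types = new float[seqLength];
        int start = questionSize + 2; // skip [CLS] + question + [SEP]
        int end = Math.min(seqLength, start + answerSize + 1); // answer + closing [SEP]
        if (start < end) {
            Arrays.fill(types, start, end, 1f);
        }
        return types;
    }

    /** Valid length as described above: question tokens plus answer tokens. */
    static int validLength(List<String> question, List<String> answer) {
        return question.size() + answer.size();
    }
}
```

With the question tokens `[Who, ?]`, the answer tokens `[Me, .]`, and a sequence length of 8, the tokens are `[CLS] Who ? [SEP] Me . [SEP] [PAD]`, the types are `0 0 0 0 1 1 1 0`, and the valid length is 4.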

**Firstly, let's load the input**


In [84]:
var question = "When did BBC Japan start broadcasting?";
var answerMaterial = "BBC Japan was a general entertainment Channel.\nWhich operated between December 2004 and April 2006.\nIt ceased operations after its Japanese distributor folded.";

**Secondly, we can download the model and all its artifacts**

The download may take a while, depending on your network speed.

In [85]:
public void download(String url, String fileName) throws IOException {
  URL downloadUrl = new URL(url);
  String tempDir = System.getProperty("java.io.tmpdir");
  Path tmp = Paths.get(tempDir).resolve("bert");
  Path dest = tmp.resolve(fileName);
  if (Files.exists(dest)) {
    return;
  }
  Files.createDirectories(tmp.toAbsolutePath());
  try (InputStream is = downloadUrl.openStream()) {
    Files.copy(is, dest);
  }
}

download("https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-0002.params", "static_bert_qa-0002.params");
download("https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/static_bert_qa-symbol.json", "static_bert_qa-symbol.json");
download("https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/BertQA/vocab.json", "vocab.json");
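
Because the `download` helper above returns early whenever the destination file already exists, a file left behind by an interrupted download would be skipped on the next run. A quick check before loading the model can catch that. The following is a minimal sketch with a hypothetical helper (`ArtifactCheck.missing` is not part of Joule) that reports which expected files are absent or empty.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: reports which expected files are absent or empty in a directory. */
final class ArtifactCheck {

    private ArtifactCheck() {}

    static List<String> missing(Path dir, List<String> fileNames) {
        List<String> missing = new ArrayList<>();
        for (String name : fileNames) {
            Path file = dir.resolve(name);
            // Treat a zero-byte file like a missing one; it may stem from an interrupted download.
            if (!Files.isRegularFile(file) || file.toFile().length() == 0) {
                missing.add(name);
            }
        }
        return missing;
    }
}
```

For this tutorial, the directory would be the `bert` folder under `java.io.tmpdir` and the file names would be the three names passed to `download`. Deleting any file that is reported and running the download cell again fetches it anew.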

Next, load the model and vocabulary. Create a variable `model` by calling `Model.loadModel(<model_directory>, <model_name>)`.

After that, use the `getArtifact("fileName", function)` method to load the vocabulary and create a `BertDataParser` instance for the pre-processing.

In [86]:
var modelName = "static_bert_qa";
var modelDir = Paths.get(System.getProperty("java.io.tmpdir")).resolve("bert");
// TODO: Add load model function here
var model = Model.loadModel(modelDir, modelName);

BertDataParser parser = model.getArtifact("vocab.json", BertDataParser::parse);

### Step 5 Creating the Translator

Inference in Deep Learning is the process of predicting the output for a given input based on a pre-defined model. 
Joule abstracts the whole process away from you. It can load the model, perform inference on the input, and provide 
output. Joule also allows you to provide user-defined inputs. The workflow looks like the following:

![image](../examples/doc/img/workFlow.png)

The red block ("Images") in the workflow is the input that Joule expects from you. The green block ("Images 
bounding box") is the output that you expect. Since Joule does not know what input to expect or what output format you prefer, it provides the `Translator` interface so you can define your own 
input and output.

The `Translator` interface encompasses the two white blocks: Pre-processing and Post-processing. The pre-processing 
component converts the user-defined input objects into an NDList, so that the `Predictor` in Joule can understand the 
input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the 
`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output 
format. 
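
The pre-process, predict, post-process flow described above can be sketched in plain Java. This is only an illustration of the pattern: `MiniTranslator` and `MiniPredictor` are hypothetical names, `float[]` stands in for `NDList`, and a simple function stands in for the loaded model. Joule's actual `Translator` and `Predictor` types differ.

```java
import java.util.function.UnaryOperator;

/** Hypothetical stand-in for the Translator idea; Joule's real interface works with NDList. */
interface MiniTranslator<I, O> {
    /** Pre-processing: user-defined object to numeric model input. */
    float[] processInput(I input);

    /** Post-processing: numeric model output to user-defined object. */
    O processOutput(float[] output);
}

/** Hypothetical stand-in for a predictor that wires a model to a translator. */
final class MiniPredictor<I, O> {
    private final UnaryOperator<float[]> model; // stands in for the loaded model
    private final MiniTranslator<I, O> translator;

    MiniPredictor(UnaryOperator<float[]> model, MiniTranslator<I, O> translator) {
        this.model = model;
        this.translator = translator;
    }

    O predict(I input) {
        float[] modelInput = translator.processInput(input);
        float[] modelOutput = model.apply(modelInput);
        return translator.processOutput(modelOutput);
    }
}
```

The caller only ever sees the types `I` and `O`. Everything numeric stays between the translator and the model, which is the same separation the `Translator` interface provides.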

#### Pre-processing

Now, we need to convert the sentences into tokens. You can use `BertDataParser.tokenizer` to convert the question and answer into tokens. Then, you can use `BertDataParser.formTokens` to create BERT-formatted tokens. Once we have properly formatted tokens, we can use `parser.token2idx` to create the indices.

In the code block below, you convert the question and answer defined earlier into BERT-formatted tokens and create word types for the tokens.

In [87]:
public static float[] toFloatArray(List<? extends Number> list) {
    float[] ret = new float[list.size()];
    int idx = 0;
    for (Number n : list) {
        ret[idx++] = n.floatValue();
    }
    return ret;
}

// Create token lists for question and answer
List<String> tokenQ = BertDataParser.tokenizer(question);
List<String> tokenA = BertDataParser.tokenizer(answerMaterial);

// Create Bert-formatted tokens
List<String> tokens = BertDataParser.formTokens(tokenQ, tokenA, 384);

// Convert tokens into indices in the vocabulary
List<Integer> indexes = parser.token2idx(tokens);
float[] indexesFloat = toFloatArray(indexes);

// Get token types
List<Float> tokenTypes = BertDataParser.getTokenTypes(tokenQ, tokenA, 384);
float[] types = toFloatArray(tokenTypes);

Now that you have everything you need, you can create an NDList and populate it with the inputs formatted earlier. That completes the pre-processing.

However, this processing must happen within an implementation of the `Translator` interface. Below is one implementation that we have started for you. Please complete the TODO sections in the `processInput` method below. (HINT: use the code snippets in the previous cell to help guide you.)

Every translator takes input and returns output in the form of generic objects. In this case, the translator takes input in the form of `QAInput` and returns output as a `String`. `QAInput` is just an object that holds the question, the answer, and `seqLength`.

In [95]:
public class QAInput {
    private String question;
    private String answer;
    private int seqLength;

    QAInput(String question, String answer, int seqLength) {
        this.question = question;
        this.answer = answer;
        this.seqLength = seqLength;
    }

    public String getQuestion() {
        return question;
    }
    
    public String getAnswer() {
        return answer;
    }

    public int getSeqLength() {
        return seqLength;
    }
}

public class BertTranslator implements Translator<QAInput, String> {
        private BertDataParser parser;
        private List<String> tokens;

        BertTranslator(BertDataParser parser) {
            this.parser = parser;
        }

        @Override
        public NDList processInput(TranslatorContext ctx, QAInput input) {
            // Pre-processing - tokenize sentence
            // TODO: Create token lists for question and answer
            
            
            
            // TODO: Create Bert-formatted tokens and store them in the tokens field (processOutput reads it)
            
            
            
            // TODO: Convert tokens into indices in the vocabulary
            
            
            
            // TODO: Get token types
            List<Float> tokenTypes = BertDataParser.getTokenTypes(tokenQ, tokenA, input.getSeqLength());
            float[] types = Utils.toFloatArray(tokenTypes);
            
            
            // TODO Calculate valid length
            int validLength = 0;

            NDManager manager = ctx.getNDManager();
            
            // TODO Using the manager created above, create NDArrays for the indices, types, and valid length, in that order. 
            NDArray data0 = null;
            NDArray data1 = null;
            NDArray data2 = null;

            NDList list = new NDList(3);
            list.add("data0", data0);
            list.add("data1", data1);
            list.add("data2", data2);

            return list;
        }

        @Override
        public String processOutput(TranslatorContext ctx, NDList list) {
            NDArray array = list.get(0);
            NDList output = array.split(2, 2);
            // Get the formatted logits result
            NDArray startLogits = output.get(0).reshape(new Shape(1, -1));
            NDArray endLogits = output.get(1).reshape(new Shape(1, -1));
            // Get Probability distribution
            float[] startProb = startLogits.softmax(-1).toFloatArray();
            float[] endProb = endLogits.softmax(-1).toFloatArray();
            int startIdx = argmax(startProb);
            int endIdx = argmax(endProb);
            return tokens.subList(startIdx, endIdx + 1).toString();
        }

        private static int argmax(float[] prob) {
            int maxIdx = 0;
            for (int i = 0; i < prob.length; i++) {
                if (prob[maxIdx] < prob[i]) {
                    maxIdx = i;
                }
            }
            return maxIdx;
        }
    }

Congrats! You have created your first Translator! As you can see above, we have pre-filled `processOutput()`, which converts the returned `NDList` into a format that is convenient for you. Together, `processInput()` and `processOutput()` give you the flexibility to get predictions from the model in any format you desire.
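
The span selection inside `processOutput()` can be followed on plain arrays. The sketch below is an illustration and not the translator's code: `SpanSketch` is a hypothetical class that works on `float[]` instead of `NDArray`, and the guard for an inverted span (end before start) is an added assumption that the `processOutput()` above does not include.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Plain-array sketch of the span selection done in processOutput; no NDArray involved. */
final class SpanSketch {

    private SpanSketch() {}

    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        float sum = 0f;
        float[] probs = new float[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) Math.exp(logits[i] - max); // subtract max for numerical stability
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    static int argmax(float[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    static List<String> answerSpan(List<String> tokens, float[] startLogits, float[] endLogits) {
        int start = argmax(softmax(startLogits));
        int end = argmax(softmax(endLogits));
        if (end < start) {
            // Added guard (an assumption of this sketch): an inverted span yields no answer.
            return Collections.emptyList();
        }
        return new ArrayList<>(tokens.subList(start, end + 1));
    }
}
```

The most probable start position and the most probable end position are chosen independently, and the tokens between them (inclusive) form the answer.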


With the Translator implemented, all that is left is to bring up the predictor and start making predictions.

In [None]:
protected void printProgress(int iteration, int index) {
    System.out.print(".");
    if (index % 80 == 79 || index == iteration - 1) {
        System.out.println();
    }
}

String predictResult = null;
QAInput input = new QAInput(question, answerMaterial, 384);
BertTranslator translator = new BertTranslator(parser);

// TODO: Create a Predictor and predict the output using the predictor


[IJava-executor-14] WARN org.apache.mxnet.engine.CachedOp - Input data2 not found, set NDArray to Shape(1) by default
[IJava-executor-14] WARN org.apache.mxnet.engine.CachedOp - Input data0 not found, set NDArray to Shape(1) by default
[IJava-executor-14] WARN org.apache.mxnet.engine.CachedOp - Input data1 not found, set NDArray to Shape(1) by default


That's it! It's that simple!