# Hands-on: Training and deploying Question Answering with BERT

Pre-trained language representations have been shown to improve many downstream NLP tasks, such as question answering and natural language inference. Devlin et al. proposed BERT [1] (Bidirectional Encoder Representations from Transformers), which fine-tunes deep bidirectional representations on a wide range of tasks with minimal task-specific parameters and obtains state-of-the-art results.

In this tutorial, we will focus on adapting the BERT model for the question answering task on the SQuAD dataset. Specifically, we will:

- understand how to pre-process the SQuAD dataset to leverage the learnt representation in BERT,
- adapt the BERT model to the question answering task, and
- load a trained model to perform inference on the SQuAD dataset.

## SageMaker configuration

This notebook requires mxnet-cu101 >= 1.6.0b20191102 and gluonnlp >= 0.8.1.
We can create a SageMaker notebook instance with the lifecycle configuration file `sagemaker-lifecycle.config`.

In [1]:
# One time script
# !bash sagemaker-lifecycle.config


In [2]:
!pip list | grep mxnet
!pip list | grep gluonnlp

keras-mxnet                        2.2.4.2       
mxnet-cu101                        1.6.0b20191122
mxnet-mkl                          1.5.0         
mxnet-model-server                 1.0.5
gluonnlp                           0.9.0.dev0


## Load MXNet and GluonNLP

We first import the libraries:

In [3]:
import argparse, collections, time, logging
import json
import os
import io
import copy
import random
import warnings

import numpy as np
import gluonnlp as nlp
import mxnet as mx
import bert
import qa_utils

from gluonnlp.data import SQuAD
from bert.model.qa import BertForQALoss, BertForQA
from bert.data.qa import SQuADTransform, preprocess_dataset
from bert.bert_qa_evaluate import get_F1_EM, predict, PredResult

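# Hyperparameters
parser = argparse.ArgumentParser('BERT finetuning')
# Declare explicit types so that values passed on the command line are not kept as strings
parser.add_argument('--epochs', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_epochs', type=int, default=1)
parser.add_argument('--lr', type=float, default=5e-5)

# Parse an empty argument list so that the defaults are used inside the notebook
args = parser.parse_args([])

epochs = args.epochs
batch_size = args.batch_size
num_epochs = args.num_epochs
lr = args.lr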

# output_dir = args.output_dir
# if not os.path.exists(output_dir):
#     os.mkdir(output_dir)
# test_batch_size = args.test_batch_size
# optimizer = args.optimizer
# accumulate = args.accumulate
# warmup_ratio = args.warmup_ratio
# log_interval = args.log_interval
# max_seq_length = args.max_seq_length
# doc_stride = args.doc_stride
# max_query_length = args.max_query_length
# n_best_size = args.n_best_size

## Inspect the SQuAD Dataset

Next, we take a look at the Stanford Question Answering Dataset (SQuAD). The dataset can be downloaded using the `nlp.data.SQuAD` API. In this tutorial, we create a small dataset with 3 samples from the SQuAD dataset for demonstration purposes.

The question answering task on the SQuAD dataset is set up as follows. For each sample in the dataset, a context is provided. The context is usually a long paragraph that contains a lot of information. A question is then asked based on the context. The goal is to find the text span in the context that answers the question.

In [4]:
full_data = nlp.data.SQuAD(segment='dev', version='1.1')
# loading a subset of the dev set of SQuAD
num_target_samples = 3
target_samples = [full_data[i] for i in range(num_target_samples)]
dataset = mx.gluon.data.SimpleDataset(target_samples)
print('Number of samples in the created dataset subsampled from SQuAD = %d'%len(dataset))

Downloading /home/ec2-user/.mxnet/datasets/squad/dev-v1.1.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/squad/dev-v1.1.zip...
Number of samples in the created dataset subsampled from SQuAD = 3


In [5]:
target_samples[0]

(0,
 '56be4db0acb8001400a502ec',
 'Which NFL team represented the AFC at Super Bowl 50?',
 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.',
 ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 [177, 177, 177])

Let's take a look at a sample from the dataset. In this sample, the question is about the location of the game, and the context is a description of the Super Bowl 50 game. Note that three different answer spans are correct for this question, and they start at character indices 403, 355, and 355 in the context, respectively.

In [6]:
sample = dataset[2]

context_idx = 3

print('\nContext:\n')
print(sample[context_idx])


Context:

Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.


In [7]:
question_idx = 2
answer_idx = 4
answer_pos_idx = 5

print("\nQuestion")
print(sample[question_idx])
print("\nCorrect Answer Spans")
print(sample[answer_idx])
print("\nAnswer Span Start Indices:")
print(sample[answer_pos_idx])


Question
Where did Super Bowl 50 take place?

Correct Answer Spans
['Santa Clara, California', "Levi's Stadium", "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."]

Answer Span Start Indices:
[403, 355, 355]
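
The start indices are character offsets into the context string, so an answer span can be recovered by slicing the context. The following is a minimal sketch under stated assumptions: the context is a shortened, hypothetical string rather than the full SQuAD paragraph, the offsets are computed with `str.index` instead of being read from the dataset, and `extract_span` is an illustrative helper, not part of GluonNLP.

```python
# Toy context, shortened for illustration (not the full SQuAD paragraph).
context = ("The game was played on February 7, 2016, at Levi's Stadium "
           "in the San Francisco Bay Area at Santa Clara, California.")
answers = ["Santa Clara, California", "Levi's Stadium"]

# Character offset at which each answer starts in the context.
answer_starts = [context.index(answer) for answer in answers]


def extract_span(context, start, answer):
    """Slice the answer text out of the context using its start offset."""
    return context[start:start + len(answer)]


recovered = [extract_span(context, start, answer)
             for start, answer in zip(answer_starts, answers)]
print(recovered)
```

Because several annotators label each question, the same question can carry multiple valid spans with different start offsets.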


## Data Pre-processing for QA with BERT

Recall that during pre-training, BERT takes a sentence pair as the input, separated by the special `[SEP]` token. For SQuAD, we can feed the question-context pair as the sentence pair input. To use BERT to predict the start and end of the answer span, we can add a classification layer on top of each token in the context to predict whether that token is the start or the end of the answer span.

![qa](natural_language_understanding/qa.png)

In the next few code blocks, we will pre-process the samples in the SQuAD dataset into the desired format with these special separators.


### Get Pre-trained BERT Model

First, let's use the *get_model* API in GluonNLP to get the model definition for BERT, and the vocabulary used for the BERT model. Note that we discard the pooler and classifier layers used for the next sentence prediction task, as well as the decoder layers for the masked language model task during the BERT pre-training phase. These layers are not useful for predicting the starting and ending indices of the answer span.

The list of pre-trained BERT models available in GluonNLP can be found [here](http://gluon-nlp.mxnet.io/model_zoo/bert/index.html).

In [8]:
bert_model, vocab = nlp.model.get_model('bert_12_768_12',
                                        dataset_name='book_corpus_wiki_en_uncased',
                                        use_classifier=False,
                                        use_decoder=False,
                                        use_pooler=False,
                                        pretrained=False)

Vocab file is not found. Downloading.
Downloading /home/ec2-user/.mxnet/models/1578960940.0694501book_corpus_wiki_en_uncased-a6607397.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/vocab/book_corpus_wiki_en_uncased-a6607397.zip...


Note that there are several special tokens in the vocabulary for BERT. In particular, the `[SEP]` token is used for separating the sentence pairs, and the `[CLS]` token is added at the beginning of the sentence pairs. They will be used to pre-process the SQuAD dataset later.

In [9]:
print(vocab)

Vocab(size=30522, unk="[UNK]", reserved="['[CLS]', '[SEP]', '[MASK]', '[PAD]']")


### Tokenization

The second step is to process the samples using the same tokenizer used for BERT, which is provided as the `BERTTokenizer` API in GluonNLP. Note that instead of word-level or character-level representations, BERT uses subwords to represent a word. Subwords that continue a word are marked with the `##` prefix.

In the following example, the word `suspending` is tokenized as two subwords (`suspend` and `##ing`), and `numerals` is tokenized as three subwords (`nu`, `##meral`, `##s`).

In [10]:
tokenizer = nlp.data.BERTTokenizer(vocab=vocab, lower=True)

tokenizer("as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals")

['as',
 'well',
 'as',
 'temporarily',
 'suspend',
 '##ing',
 'the',
 'tradition',
 'of',
 'naming',
 'each',
 'super',
 'bowl',
 'game',
 'with',
 'roman',
 'nu',
 '##meral',
 '##s']
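
The subword splits above are consistent with greedy longest-match-first WordPiece tokenization. The sketch below illustrates that idea on a tiny, hypothetical vocabulary. It is not the `BERTTokenizer` implementation, and `wordpiece` and `toy_vocab` are names introduced here for illustration only.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Split a single lower-cased word into subwords by greedy longest match."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink from the right.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation of a word
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No subword fits: the whole word maps to the unknown token.
            return [unk]
        pieces.append(match)
        start = end
    return pieces


# Tiny hypothetical vocabulary; the real BERT vocabulary has 30522 entries.
toy_vocab = {"suspend", "##ing", "nu", "##meral", "##s", "the"}

print(wordpiece("suspending", toy_vocab))
print(wordpiece("numerals", toy_vocab))
```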

### Sentence Pair Composition

With the tokenizer in place, we are ready to process the question-context texts and compose sentence pairs. This functionality is available via the `SQuADTransform` API.

In [11]:
transform = bert.data.qa.SQuADTransform(tokenizer, is_pad=False, is_training=False, do_lookup=False)
dev_data_transform, _ = preprocess_dataset(dataset, transform)
logging.info('The number of examples after preprocessing:{}'.format(len(dev_data_transform)))

Done! Transform dataset costs 0.14 seconds.


Let's take a look at the sample after the transformation:

In [12]:
sample = dev_data_transform[2]
print('\nsegment type: \n' + str(sample[2]))
print('\ntext length: ' + str(sample[3]))
print('\nsentence pair: \n' + str(sample[1]))


segment type: 
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

text length: 168

sentence pair: 
['[CLS]', 'where', 'did', 'super', 'bowl', '50', 'take', 'place', '?', '[SEP]', 'super', 'bowl', '50', 'was', 'an', 'american', 'football', 'game', 'to', 'determine', 'the', 'champion', 'of', 'the', 'national', 'football', 'league', '(', 'nfl', ')', 'for', 'the', '2015', 'season', '.', 'the', 'american', 'football', 'conference', '(', 'afc', ')', 'champion', 'denver', 'broncos', 'defeated', 'the', 'national', 'football', 'conference', '(', 
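
The layout of the sample above can be reproduced by hand. The following is a minimal sketch, assuming already-tokenized inputs; the helper name `compose_pair` is hypothetical, and it leaves out what a full transform would also need to handle, such as truncating long questions or splitting long contexts.

```python
def compose_pair(question_tokens, context_tokens):
    """Build [CLS] question [SEP] context [SEP] plus segment ids and length."""
    tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + context_tokens + ["[SEP]"]
    # Segment 0 covers [CLS], the question and the first [SEP];
    # segment 1 covers the context and the final [SEP].
    segment_ids = [0] * (len(question_tokens) + 2) + [1] * (len(context_tokens) + 1)
    return tokens, segment_ids, len(tokens)


question = ["where", "did", "super", "bowl", "50", "take", "place", "?"]
context = ["super", "bowl", "50", "was", "an", "american", "football", "game"]

tokens, segment_ids, valid_length = compose_pair(question, context)
print(tokens)
print(segment_ids)
print(valid_length)
```

The eight question tokens plus `[CLS]` and the first `[SEP]` account for the ten leading zeros in the segment type output shown earlier.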

### Vocabulary Lookup

Finally, we convert the transformed texts to subword indices, which are used to construct the NDArrays that serve as inputs to the model.

In [13]:
def vocab_lookup(example_id, subwords, type_ids, length, start, end):
    indices = vocab[subwords]
    return example_id, indices, type_ids, length, start, end

dev_data_transform = dev_data_transform.transform(vocab_lookup, lazy=False)
print(dev_data_transform[2][1])

[2, 2073, 2106, 3565, 4605, 2753, 2202, 2173, 1029, 3, 3565, 4605, 2753, 2001, 2019, 2137, 2374, 2208, 2000, 5646, 1996, 3410, 1997, 1996, 2120, 2374, 2223, 1006, 5088, 1007, 2005, 1996, 2325, 2161, 1012, 1996, 2137, 2374, 3034, 1006, 10511, 1007, 3410, 7573, 14169, 3249, 1996, 2120, 2374, 3034, 1006, 22309, 1007, 3410, 3792, 12915, 2484, 1516, 2184, 2000, 7796, 2037, 2353, 3565, 4605, 2516, 1012, 1996, 2208, 2001, 2209, 2006, 2337, 1021, 1010, 2355, 1010, 2012, 11902, 1005, 1055, 3346, 1999, 1996, 2624, 3799, 3016, 2181, 2012, 4203, 10254, 1010, 2662, 1012, 2004, 2023, 2001, 1996, 12951, 3565, 4605, 1010, 1996, 2223, 13155, 1996, 1000, 3585, 5315, 1000, 2007, 2536, 2751, 1011, 11773, 11107, 1010, 2004, 2092, 2004, 8184, 28324, 2075, 1996, 4535, 1997, 10324, 2169, 3565, 4605, 2208, 2007, 3142, 16371, 28990, 2015, 1006, 2104, 2029, 1996, 2208, 2052, 2031, 2042, 2124, 2004, 1000, 3565, 4605, 1048, 1000, 1007, 1010, 2061, 2008, 1996, 8154, 2071, 14500, 3444, 1996, 5640, 16371, 28990, 2015
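
Conceptually, the lookup is a dictionary access with a fallback for unknown tokens. The sketch below uses a plain `dict` in place of the GluonNLP `Vocab` object; the indices for the known tokens are copied from the output above, while the index chosen for `[UNK]` and the helper name `lookup` are assumptions for illustration.

```python
# Indices for known tokens are taken from the printed output above;
# the index for [UNK] is an assumption made for this toy example.
toy_vocab = {
    "[UNK]": 0,
    "[CLS]": 2,
    "[SEP]": 3,
    "where": 2073,
    "did": 2106,
    "super": 3565,
    "bowl": 4605,
    "50": 2753,
}


def lookup(tokens, vocab, unk="[UNK]"):
    """Map each token to its index, falling back to the unknown-token index."""
    return [vocab.get(token, vocab[unk]) for token in tokens]


print(lookup(["[CLS]", "where", "did", "super", "bowl", "50", "[SEP]"], toy_vocab))
print(lookup(["zzzz"], toy_vocab))
```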

## Model Definition

After the data is processed, we can define the model that uses the representation produced by BERT for predicting the starting and ending positions of the answer span.

We download a BERT model trained on the SQuAD dataset and prepare the dataloader.

In [14]:
net = BertForQA(bert_model)
ctx = mx.gpu(0)
ckpt = qa_utils.download_qa_ckpt()
net.load_parameters(ckpt, ctx=ctx)

batch_size = 1
dev_dataloader = mx.gluon.data.DataLoader(
    dev_data_transform, batch_size=batch_size, shuffle=False)

Downloading ./bert_qa-7eb11865.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/bert_qa-7eb11865.zip...
Downloaded checkpoint to ./bert_qa-7eb11865.params


MXNetError: [00:15:55] src/engine/threaded_engine.cc:331: Check failed: device_count_ > 0 (-1 vs. 0) : GPU usage requires at least 1 GPU



In [None]:
all_results = collections.defaultdict(list)

total_num = 0
for data in dev_dataloader:
    example_ids, inputs, token_types, valid_length, _, _ = data
    total_num += len(inputs)
    batch_size = inputs.shape[0]
    output = net(inputs.astype('float32').as_in_context(ctx),
                               token_types.astype('float32').as_in_context(ctx),
                               valid_length.astype('float32').as_in_context(ctx))
    pred_start, pred_end = mx.nd.split(output, axis=2, num_outputs=2)
    example_ids = example_ids.asnumpy().tolist()
    pred_start = pred_start.reshape(batch_size, -1).asnumpy()
    pred_end = pred_end.reshape(batch_size, -1).asnumpy()
    
    for example_id, start, end in zip(example_ids, pred_start, pred_end):
        all_results[example_id].append(PredResult(start=start, end=end))

In [None]:
qa_utils.predict(dataset, all_results, vocab)
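
The model emits one start score and one end score per token, and the answer is the span whose combined score is highest. The following is a simplified sketch of that selection step with made-up scores; the function name `best_span` and the `max_answer_length` default are assumptions, and the tutorial's own prediction helper additionally produces n-best lists and maps tokens back to the original text.

```python
def best_span(start_scores, end_scores, max_answer_length=30):
    """Return (start, end, score) maximizing start_scores[start] + end_scores[end]."""
    best_start, best_end, best_score = 0, 0, float("-inf")
    for start, start_score in enumerate(start_scores):
        # The end must not precede the start, and the span length is capped.
        last = min(start + max_answer_length, len(end_scores))
        for end in range(start, last):
            score = start_score + end_scores[end]
            if score > best_score:
                best_start, best_end, best_score = start, end, score
    return best_start, best_end, best_score


# Made-up scores for a seven-token input (illustration only).
tokens = ["[CLS]", "where", "?", "[SEP]", "santa", "clara", "[SEP]"]
start_scores = [0.1, 0.0, 0.0, 0.0, 3.0, 0.5, 0.0]
end_scores = [0.0, 0.0, 0.1, 0.0, 0.5, 2.5, 0.0]

start, end, score = best_span(start_scores, end_scores)
print(tokens[start:end + 1])
```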

### Let's Train the Model

Now we can put all the pieces together and start fine-tuning the model for a few epochs.

In [None]:
# net = BertForQA(bert=bert_model)
# nlp.utils.load_parameters(net, pretrained_bert_parameters, ctx=ctx,
#                           ignore_extra=True, cast_dtype=True)
net.span_classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
net.hybridize(static_alloc=True)

loss_function = BertForQALoss()
loss_function.hybridize(static_alloc=True)
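
The internals of `BertForQALoss` are not shown in this tutorial. A common formulation for span prediction, and the one assumed in this sketch, averages the softmax cross-entropy of the true start position and of the true end position. The function names below are hypothetical and the code uses plain Python lists rather than NDArrays.

```python
import math


def softmax_cross_entropy(scores, label):
    """Negative log-probability of `label` under softmax(scores)."""
    highest = max(scores)  # subtract the maximum for numerical stability
    log_sum_exp = highest + math.log(sum(math.exp(s - highest) for s in scores))
    return log_sum_exp - scores[label]


def span_loss(start_scores, end_scores, start_label, end_label):
    """Average of the start-position and end-position cross-entropy losses."""
    start_loss = softmax_cross_entropy(start_scores, start_label)
    end_loss = softmax_cross_entropy(end_scores, end_label)
    return 0.5 * (start_loss + end_loss)


# A confident, correct prediction yields a lower loss than a uniform guess.
confident = span_loss([0.0, 5.0, 0.0, 0.0], [0.0, 0.0, 5.0, 0.0], 1, 2)
uniform = span_loss([0.0] * 4, [0.0] * 4, 1, 2)
print(confident < uniform)
```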

## Deploy on SageMaker

Deployment consists of four steps:

1. Preparing functions for inference
2. Saving the model parameters
3. Building a docker container with dependencies installed
4. Launching a serving end-point with SageMaker SDK

### 1. Preparing functions for inference

The serving script defines two functions:
1. `model_fn()` to load the model parameters
2. `transform_fn()` to run model inference on a given input

In [None]:
%%writefile code/serve.py
import collections, json, logging, warnings
import multiprocessing as mp
from functools import partial

import gluonnlp as nlp
import mxnet as mx
from mxnet.gluon import Block, nn
import bert
from bert.data.qa import preprocess_dataset, SQuADTransform
import bert_qa_evaluate



class BertForQA(Block):
    """Model for SQuAD task with BERT.
    The model feeds token ids and token type ids into BERT to get the
    pooled BERT sequence representation, then apply a Dense layer for QA task.
    Parameters
    ----------
    bert: BERTModel
        Bidirectional encoder with transformer.
    prefix : str or None
        See document of `mx.gluon.Block`.
    params : ParameterDict or None
        See document of `mx.gluon.Block`.
    """

    def __init__(self, bert, prefix=None, params=None):
        super(BertForQA, self).__init__(prefix=prefix, params=params)
        self.bert = bert
        with self.name_scope():
            self.span_classifier = nn.Dense(units=2, flatten=False)

    def forward(self, inputs, token_types, valid_length=None):  # pylint: disable=arguments-differ
        """Generate the unnormalized scores for the given input sequences.
        Parameters
        ----------
        inputs : NDArray, shape (batch_size, seq_length)
            Input words for the sequences.
        token_types : NDArray, shape (batch_size, seq_length)
            Token types for the sequences, used to indicate whether the word belongs to the
            first sentence or the second one.
        valid_length : NDArray or None, shape (batch_size,)
            Valid length of the sequence. This is used to mask the padded tokens.
        Returns
        -------
        outputs : NDArray
            Shape (batch_size, seq_length, 2)
        """
        bert_output = self.bert(inputs, token_types, valid_length)
        output = self.span_classifier(bert_output)
        return output
    
    
def get_all_results(net, vocab, squadTransform, test_dataset, ctx = mx.cpu()):
    all_results = collections.defaultdict(list)
    
    def _vocab_lookup(example_id, subwords, type_ids, length, start, end):
        indices = vocab[subwords]
        return example_id, indices, type_ids, length, start, end
    
    dev_data_transform, _ = preprocess_dataset(test_dataset, squadTransform)
    dev_data_transform = dev_data_transform.transform(_vocab_lookup, lazy=False)
    dev_dataloader = mx.gluon.data.DataLoader(dev_data_transform, batch_size=1, shuffle=False)
    
    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _ = data
        batch_size = inputs.shape[0]
        output = net(inputs.astype('float32').as_in_context(ctx),
                     token_types.astype('float32').as_in_context(ctx),
                     valid_length.astype('float32').as_in_context(ctx))
        pred_start, pred_end = mx.nd.split(output, axis=2, num_outputs=2)
        example_ids = example_ids.asnumpy().tolist()
        pred_start = pred_start.reshape(batch_size, -1).asnumpy()
        pred_end = pred_end.reshape(batch_size, -1).asnumpy()

        for example_id, start, end in zip(example_ids, pred_start, pred_end):
            all_results[example_id].append(bert_qa_evaluate.PredResult(start=start, end=end))
    return all_results


def _test_example_transform(test_examples):
    test_examples_tuples = []
    i = 0
    for test in test_examples:
        tup = (i, "", test[0], test[1], [], [])
        test_examples_tuples.append(tup)
        i += 1
    return test_examples_tuples


def model_fn(model_dir = "", params_path = "bert_qa-7eb11865.params"):
    """
    Load the Gluon model. Called once when the hosting service starts.
    :param model_dir: The directory where model files are stored.
    :param params_path: File name of the saved model parameters.
    :return: the Gluon model, the vocabulary, and the SQuAD transform
    """
    bert_model, vocab = nlp.model.get_model('bert_12_768_12',
                                        dataset_name='book_corpus_wiki_en_uncased',
                                        use_classifier=False,
                                        use_decoder=False,
                                        use_pooler=False,
                                        pretrained=False)
    net = BertForQA(bert_model)
    if len(model_dir) > 0:
        params_path = model_dir + "/" + params_path
    net.load_parameters(params_path, ctx=mx.cpu())
    
    tokenizer = nlp.data.BERTTokenizer(vocab,  lower=True)
    transform = SQuADTransform(tokenizer, is_pad=False, is_training=False, do_lookup=False)
    return net, vocab, transform



def transform_fn(model, question_json, content_json, input_content_type=None, output_content_type=None):
    """
    Transform a request using the Gluon model. Called once per request.
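    :param model: The (net, vocab, transform) tuple returned by model_fn
    :param question_json: The question text
    :param content_json: The context passage in which to look for the answer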
    
        Example:
        ## (example_id, [question, content], ques_cont_token_types, valid_length, _, _)


        (2, 
        '56be4db0acb8001400a502ee', 
        'Where did Super Bowl 50 take place?', 
        
        'Super Bowl 50 was an American football game to determine the champion of the National 
        Football League (NFL) for the 2015 season. The American Football Conference (AFC) 
        champion Denver Broncos defeated the National Football Conference (NFC) champion 
        Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played 
        on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, 
        California. As this was the 50th Super Bowl, the league emphasized the "golden 
        anniversary" with various gold-themed initiatives, as well as temporarily suspending 
        the tradition of naming each Super Bowl game with Roman numerals (under which the 
        game would have been known as "Super Bowl L"), so that the logo could prominently 
        feature the Arabic numerals 50.', 
        
        ['Santa Clara, California', "Levi's Stadium", "Levi's Stadium 
        in the San Francisco Bay Area at Santa Clara, California."], 
        
        [403, 355, 355])

    :param input_content_type: The request content type, assume json
    :param output_content_type: The (desired) response content type, assume json
    :return: response payload and content type.
    """
    net, vocab, squadTransform = model
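    # NOTE: the slice [-4:] is four characters long, so it never equals the
    # five-character string ".json". The else branch therefore always runs
    # and both inputs are used as plain strings.
    if question_json[-4:] == ".json":
        question = json.loads(question_json)
        content = json.loads(content_json)
    else:
        question = question_json
        content = content_json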
    test_examples_tuples = [(0, "", question, content, [], [])]

    test_dataset = mx.gluon.data.SimpleDataset(test_examples_tuples)
    all_results = get_all_results(net, vocab, squadTransform, test_dataset, ctx=mx.cpu())

    
    all_predictions = collections.defaultdict(list) # collections.OrderedDict()
    data_transform = test_dataset.transform(squadTransform._transform)
    for features in data_transform:
        f_id = features[0].example_id
        results = all_results[f_id]
        prediction, nbest = bert_qa_evaluate.predict(
            features=features,
            results=results,
            tokenizer=nlp.data.BERTBasicTokenizer(vocab))        
        nbest_prediction = [] 
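        # Report up to three of the highest-ranked candidate answers;
        # guard against n-best lists with fewer than three entries.
        for i in range(min(3, len(nbest))):
            nbest_prediction.append('%.2f%% \t %s'%(nbest[i][1] * 100, nbest[i][0]))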
        all_predictions[f_id] = nbest_prediction
    response_body = json.dumps(all_predictions)
    return response_body, output_content_type
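
SageMaker hosting containers generally hand `transform_fn` a single serialized request body rather than separate question and context arguments. As a hedged sketch of that variant, the code below parses one JSON payload and delegates to a model callable. The names `parse_request`, `toy_transform_fn`, and `fake_model`, as well as the `question`/`context` keys, are assumptions for illustration; `fake_model` stands in for the BERT pipeline and is not part of the tutorial code.

```python
import json


def parse_request(payload):
    """Extract (question, context) from a JSON request body."""
    data = json.loads(payload)
    if isinstance(data, dict):
        return data["question"], data["context"]
    question, context = data  # also accept a two-element list
    return question, context


def toy_transform_fn(model, payload, input_content_type=None, output_content_type=None):
    """Parse the request, run the model, and serialize the predictions."""
    question, context = parse_request(payload)
    predictions = model(question, context)
    return json.dumps(predictions), output_content_type


def fake_model(question, context):
    # Stand-in for the BERT pipeline: returns the first word of the context.
    return {"0": [context.split()[0]]}


body = json.dumps({
    "question": "Where did Super Bowl 50 take place?",
    "context": "Levi's Stadium hosted the game.",
})
response, content_type = toy_transform_fn(
    fake_model, body, "application/json", "application/json")
print(response)
```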

### 2. Saving the model parameters

In [None]:
## save parameters, inference code and vocabulary in a tar.gz archive

# output_dir = "model_outputs"
# if not os.path.exists(output_dir):
#     os.mkdir(output_dir)

with open('vocab.json', 'w') as f:
    f.write(vocab.to_json())

import tarfile
with tarfile.open("model.tar.gz", "w:gz") as tar:
#     tar.add("Dockerfile")
    tar.add("code/serve.py")
    tar.add("bert/data/qa.py")
    tar.add("bert_qa_evaluate.py")
    tar.add("bert_qa-7eb11865.params")
    tar.add("vocab.json")
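
Before deploying, it can help to confirm what actually ended up in the archive. The sketch below builds a small archive from placeholder files in a temporary directory and lists its members with the standard `tarfile` module. The helper names `pack` and `list_members` and the placeholder files are assumptions for illustration, not the tutorial's real artifacts.

```python
import os
import tarfile
import tempfile


def pack(archive_path, files, base_dir):
    """Add each file to a gzip-compressed tar, keeping its relative path."""
    with tarfile.open(archive_path, "w:gz") as tar:
        for rel in files:
            tar.add(os.path.join(base_dir, rel), arcname=rel)


def list_members(archive_path):
    """Return the sorted member names stored in the archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        return sorted(tar.getnames())


with tempfile.TemporaryDirectory() as workdir:
    # Create placeholder files standing in for the real model artifacts.
    for rel in ("code/serve.py", "vocab.json"):
        path = os.path.join(workdir, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as f:
            f.write("placeholder\n")

    archive = os.path.join(workdir, "model.tar.gz")
    pack(archive, ["code/serve.py", "vocab.json"], workdir)
    members = list_members(archive)
    print(members)
```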

In [None]:
## test

test_example_0 = ('Which NFL team represented the AFC at Super Bowl 50?',
 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.')

test_example_1 = ('Where did Super Bowl 50 take place?',
 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.')

my_test_examples = (test_example_0, test_example_1)


## prepare test examples
import json

with open('my_test_examples.json', 'w') as f:
    json.dump(my_test_examples, f)
    
# with open('my_test_examples.json') as data_file:    
#     my_test_examples_json = json.load(data_file)

In [None]:
# ## test
# !cp code/serve.py serve.py
# import serve

# # ## if change serve.py, needs to reload
# # %load_ext autoreload
# # %autoreload serve
# import importlib
# importlib.reload(serve)

# mymodel = serve.model_fn()
# serve.transform_fn(mymodel, 'my_test_examples.json')

### 3. Building a docker container with dependencies installed

Let's prepare a Docker container with all the dependencies required for model inference. Here we build a container based on the SageMaker MXNet inference container; the list of all available inference containers is at https://docs.aws.amazon.com/sagemaker/latest/dg/pre-built-containers-frameworks-deep-learning.html

Here we use local mode for demonstration purposes. To deploy on actual instances, you need to log in to the Amazon Elastic Container Registry (ECR) service and push the container to ECR.

```
docker build -t $YOUR_ECR_DOCKER_TAG . -f Dockerfile
$(aws ecr get-login --no-include-email --region $YOUR_REGION)
docker push $YOUR_ECR_DOCKER_TAG
```

In [None]:
%%writefile Dockerfile

ARG REGION
FROM 763104351884.dkr.ecr.$REGION.amazonaws.com/mxnet-inference:1.4.1-gpu-py3

RUN pip install --upgrade --user --pre 'mxnet-mkl' 'https://github.com/dmlc/gluon-nlp/tarball/v0.9.x'

RUN pip list | grep mxnet

COPY *.py /opt/ml/model/code/
COPY bert/data/qa.py /opt/ml/model/code/bert/data/
COPY bert/bert_qa_evaluate.py /opt/ml/model/code/bert/

Log in to the ECR registry that hosts the base image, so that Docker can pull it:

In [None]:
!$(aws ecr get-login --no-include-email --region us-east-1 --registry-ids 763104351884)

In [None]:
!export REGION=$(wget -qO- http://169.254.169.254/latest/meta-data/placement/availability-zone) &&\
 docker build --no-cache --build-arg REGION=${REGION::-1} -t my-docker:inference . -f Dockerfile
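
The shell command above derives the region by dropping the last character of the availability zone name (`${REGION::-1}`). The same step can be sketched in Python as follows. The helper names are hypothetical, and the sketch assumes standard zone names of the form `us-east-1a`; other naming schemes would need different handling.

```python
def region_from_az(availability_zone):
    """Drop the trailing zone letter, e.g. 'us-east-1a' -> 'us-east-1'."""
    if not availability_zone or not availability_zone[-1].isalpha():
        raise ValueError("unexpected availability zone: %r" % availability_zone)
    return availability_zone[:-1]


def base_image(region, tag="1.4.1-gpu-py3"):
    """Compose the base image URI used in the Dockerfile's FROM line."""
    return "763104351884.dkr.ecr.{}.amazonaws.com/mxnet-inference:{}".format(region, tag)


region = region_from_az("us-east-1a")
print(region)
print(base_image(region))
```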

### 4. Launching a serving end-point with SageMaker SDK

We create an MXNet model that can be deployed later by specifying the Docker image and the entry point for the inference code. If `serve.py` does not work, use `dummy_hosting_module.py` for debugging.

In [None]:
import sagemaker
from sagemaker.mxnet.model import MXNetModel
sagemaker_model = MXNetModel(model_data='file:///home/ec2-user/SageMaker/ako2020-bert/tutorial/model.tar.gz',
                             image='my-docker:inference', # docker images
                             role=sagemaker.get_execution_role(), 
                             py_version='py3',            # python version
                             entry_point='serve.py',
                             source_dir='.')

We use 'local' mode to test our deployment code, where inference happens on the current instance.
When you are ready to deploy the model on a new instance, change the `instance_type` argument to a value such as `ml.c4.xlarge`, `ml.c5.2xlarge`, or `ml.p2.xlarge`.

In [None]:
predictor = sagemaker_model.deploy(initial_instance_count=1, instance_type='local')

In [None]:
output = predictor.predict(test_example_0[0], test_example_0[1])
print('\nPrediction output: {}\n\n'.format(output))
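
The serving code formats each candidate answer as a string of the form `'<probability>% \t <answer>'` and returns a JSON object keyed by example id. Below is a minimal sketch for turning such a response back into (probability, answer) pairs. The helper name `parse_nbest` and the sample response are made up for illustration, and the exact type returned by `predictor.predict` depends on how the predictor deserializes the response.

```python
import json


def parse_nbest(entry):
    """Split '98.76% \\t denver broncos' into (0.9876, 'denver broncos')."""
    percent_text, answer = entry.split("\t", 1)
    probability = float(percent_text.strip().rstrip("%")) / 100.0
    return probability, answer.strip()


# Hypothetical response body, built with the same format string as serve.py.
sample_response = json.dumps({
    "0": ["%.2f%% \t %s" % (98.76, "denver broncos"),
          "%.2f%% \t %s" % (1.02, "carolina panthers")],
})

for example_id, entries in json.loads(sample_response).items():
    for probability, answer in map(parse_nbest, entries):
        print(example_id, round(probability, 4), answer)
```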

### Clean Up

Remove the endpoint after we are done. 

In [None]:
# predictor.delete_endpoint()

## Side notes (Do NOT Run!)

In [None]:
# !wget -qO- http://169.254.169.254/latest/meta-data/placement/availability-zone
# !aws ecr get-login --no-include-email --region us-east-1 --registry-ids 763104351884
# !docker pull 763104351884.dkr.ecr.us-east-1.amazonaws.com/mxnet-inference:1.6.0-gpu-py3


        (2, 
        ['[CLS]', 'where', 'did', 'super', 'bowl', '50', 'take', 'place', '?', 
        '[SEP]', 'super', 'bowl', '50', 'was', 'an', 'american', 'football', 
        'game', 'to', 'determine', 'the', 'champion', 'of', 'the', 'national', 
        'football', 'league', '(', 'nfl', ')', 'for', 'the', '2015', 'season', '.', 
        'the', 'american', 'football', 'conference', '(', 'afc', ')', 'champion', 
        'denver', 'broncos', 'defeated', 'the', 'national', 'football', 'conference', 
        '(', 'nfc', ')', 'champion', 'carolina', 'panthers', '24', '–', '10', 'to', 
        'earn', 'their', 'third', 'super', 'bowl', 'title', '.', 'the', 'game', 'was', 
        'played', 'on', 'february', '7', ',', '2016', ',', 'at', 'levi', "'", 's', 'stadium', 
        'in', 'the', 'san', 'francisco', 'bay', 'area', 'at', 'santa', 'clara', ',', 
        'california', '.', 'as', 'this', 'was', 'the', '50th', 'super', 'bowl', ',', 
        'the', 'league', 'emphasized', 'the', '"', 'golden', 'anniversary', '"', 'with', 
        'various', 'gold', '-', 'themed', 'initiatives', ',', 'as', 'well', 'as', 
        'temporarily', 'suspend', '##ing', 'the', 'tradition', 'of', 'naming', 
        'each', 'super', 'bowl', 'game', 'with', 'roman', 'nu', '##meral', '##s', 
        '(', 'under', 'which', 'the', 'game', 'would', 'have', 'been', 'known', 
        'as', '"', 'super', 'bowl', 'l', '"', ')', ',', 'so', 'that', 'the', 'logo', 
        'could', 'prominently', 'feature', 'the', 'arabic', 'nu', '##meral', '##s', 
        '50', '.', '[SEP]'], 
        
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 
        
        168, 0, 0)

In [None]:

# ## prepare dataset json
# my_test_examples_tuples = test_example_transform(my_test_examples)
# with open('test_examples_tuples.json', 'w') as f:
#     json.dump(my_test_examples_tuples, f)

# test_dataset[1]



# with open('dataset.json') as data_file:    
#     data = json.load(data_file)