
# Loading Google BERT models into Keras

You may have heard of Google's BERT. BERT-based models rank among the strongest on the SQuAD 2.0 question-answering benchmark, where they score above 80%.

Since this is a tutorial, we won't get into the nitty-gritty of how BERT works. Suffice it to say, for now, that BERT is bidirectional: when it predicts a word, it uses the context on both sides of that word instead of reading only left to right. Google trains it by replacing some of the words in each sentence with "masks" and asking the model to fill them in. Suppose we're training BERT on the string "I like dogs." BERT might see:
1) "[MASK] like dogs"
2) "I like [MASK]"
In each case, BERT learns to predict the hidden word ("I" or "dogs") from the words around it.
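
The masking idea can be sketched in a few lines of plain Python. This is only an illustration, not Google's training code: the `mask_tokens` helper, its `mask_rate` and `seed` parameters, and the whitespace-style token list are assumptions made for the example. BERT's actual procedure is more involved (it selects about 15% of tokens, and only most of those are replaced with `[MASK]`; the rest are swapped for a random token or left unchanged).

```python
import random


def mask_tokens(tokens, mask_rate=0.15, seed=0):
    """Replace a random subset of tokens with "[MASK]" (hypothetical helper).

    Returns the masked sequence and a dict mapping each masked position to
    its original token, i.e. the label a model would be trained to predict.
    """
    if not tokens:
        return [], {}
    rng = random.Random(seed)
    # Always mask at least one token so short sentences still yield a label.
    num_to_mask = max(1, round(len(tokens) * mask_rate))
    positions = rng.sample(range(len(tokens)), num_to_mask)
    masked = list(tokens)
    labels = {}
    for position in positions:
        labels[position] = masked[position]
        masked[position] = "[MASK]"
    return masked, labels


masked, labels = mask_tokens(["i", "like", "dogs"])
print(masked)   # the sentence with one token replaced by "[MASK]"
print(labels)   # which position was hidden, and what the original token was
```

The model sees `masked` as input and is scored on how well it recovers the entries in `labels`, which forces it to use the words on both sides of each gap.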

Enough talk, let's code!

We'll go ahead and import our modules:

In [1]:
!pip install -U bert-serving-server bert-serving-client # BERT server/client
!pip install wget # Used for downloading the BERT model.
import wget
import os
import urllib
import tensorflow as tf
from tensorflow import keras
import zipfile

Collecting bert-serving-server
  Downloading https://files.pythonhosted.org/packages/5e/3e/44d79e1a739b8619760051410c61af67f95477c87fbe43e3e9426427feb5/bert_serving_server-1.9.1-py3-none-any.whl (60kB)
Collecting bert-serving-client
  Downloading https://files.pythonhosted.org/packages/77/24/d17de2bfe84db45be0080f01f3819a821db4bfbd9b927d66c828277ebd02/bert_serving_client-1.9.1-py2.py3-none-any.whl
Collecting pyzmq>=17.1.0 (from bert-serving-server)
  Downloading https://files.pythonhosted.org/packages/5f/04/f6f0fa20b698b29c6e6b1d6b4b575c12607b0abf61810aab1df4099988c6/pyzmq-18.0.1-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
Collecting GPUtil>=1.3.0 (from bert-serving-server)
  Downloading https://files.pythonhosted.org/packages/ed/0e/5c61eedde9f6c87713e89d794f01e378cfd9565847d4576fa627d758c554/GPUtil-1.4.0.tar.gz
Building wheels for collected packages: GPUtil

Collecting wget
  Downloading https://files.pythonhosted.org/packages/47/6a/62e288da7bcda82b935ff0c6cfe542970f04e29c756b0e147251b2fb251f/wget-3.2.zip
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... done
  Stored in directory: /root/.cache/pip/wheels/40/15/30/7d8f7cea2902b4db79e3fea550d7d7b85ecb27ef992b618f3f
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2


Next, let's set everything up. Let's begin by defining what model we'd like to work with:

In [0]:
bert_model_name = "wwm_uncased_L-24_H-1024_A-16"
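
The model name itself encodes the architecture: `L-24` is the number of transformer layers, `H-1024` the hidden size, and `A-16` the number of attention heads, while the `wwm` prefix marks a whole-word-masking checkpoint and `uncased` means the text is lowercased. A small sketch of reading those fields out of the name follows; `parse_bert_model_name` is a hypothetical helper written for this tutorial, not part of any BERT library.

```python
import re


def parse_bert_model_name(name):
    """Extract architecture details from a Google BERT model name (hypothetical helper)."""
    match = re.search(r"L-(\d+)_H-(\d+)_A-(\d+)", name)
    if match is None:
        raise ValueError(f"Unrecognized BERT model name: {name!r}")
    layers, hidden_size, attention_heads = (int(group) for group in match.groups())
    return {
        "whole_word_masking": name.startswith("wwm_"),
        "uncased": "uncased" in name,
        "layers": layers,
        "hidden_size": hidden_size,
        "attention_heads": attention_heads,
    }


print(parse_bert_model_name("wwm_uncased_L-24_H-1024_A-16"))
```

The hidden size is worth remembering, since it determines the length of the vectors the model produces for each token.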

Okay, now let's go ahead and download that model from Google Cloud Storage:

In [3]:
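bert_model_file_name = bert_model_name + ".zip"
online_bert_path = "https://storage.googleapis.com/bert_models/2019_05_30/" + bert_model_file_name
# Download the archive only if it is not already on disk.
if not os.path.isfile(bert_model_file_name):
  wget.download(online_bert_path)
# Extract separately, so an existing but unextracted archive is still unpacked.
if not os.path.isdir(bert_model_name):
  with zipfile.ZipFile(bert_model_file_name, 'r') as archive:
    archive.extractall()
print(os.path.abspath(bert_model_name))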

/content/wwm_uncased_L-24_H-1024_A-16


If there are no errors here, great! We are ready to work!
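
The absence of errors does not quite prove the model was unpacked, because `os.path.abspath` prints a path whether or not it exists. One way to double-check is to look for the files the server will need. The sketch below assumes the usual layout of Google's BERT archives: a `bert_config.json`, a `vocab.txt`, and a TensorFlow checkpoint split across several files that share the `bert_model.ckpt` prefix. `find_missing_model_files` is a hypothetical helper, not part of any library.

```python
import os


def find_missing_model_files(model_dir):
    """Return the expected BERT files that are absent from model_dir (hypothetical helper)."""
    entries = os.listdir(model_dir) if os.path.isdir(model_dir) else []
    missing = [
        name for name in ("bert_config.json", "vocab.txt") if name not in entries
    ]
    # The checkpoint is stored as several files sharing one prefix,
    # so match on the prefix instead of an exact file name.
    if not any(entry.startswith("bert_model.ckpt") for entry in entries):
        missing.append("bert_model.ckpt*")
    return missing


missing = find_missing_model_files("wwm_uncased_L-24_H-1024_A-16")
print("Missing files:", missing if missing else "none")
```

An empty list means the directory looks complete; anything else suggests the download or extraction should be repeated.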

## Load pretrained model.

Let's start the BERT server.

In [0]:
!bert-serving-start -model_dir /content/wwm_uncased_L-24_H-1024_A-16 -num_worker=4

usage: /usr/local/bin/bert-serving-start -model_dir /content/wwm_uncased_L-24_H-1024_A-16 -num_worker=4
                 ARG   VALUE
__________________________________________________
           ckpt_name = bert_model.ckpt
         config_name = bert_config.json
                cors = *
                 cpu = False
          device_map = []
       do_lower_case = True
  fixed_embed_length = False
                fp16 = False
 gpu_memory_fraction = 0.5
       graph_tmp_dir = None
    http_max_connect = 10
           http_port = None
        mask_cls_sep = False
      max_batch_size = 256
         max_seq_len = 25
           model_dir = /content/wwm_uncased_L-24_H-1024_A-16
          num_worker = 4
       pooling_layer = [-2]
    pooling_strategy = REDUCE_MEAN
                port = 5555
            port_out = 5556
       prefetch_size = 10
 priority_batch_size = 16
show_tokens_to_client = False
     tuned_model_dir = None
             verbose = False
                 xla = False
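
Because `bert-serving-start` keeps running for as long as the server is up, the cell that launches it stays busy and the notebook cannot move on. One possible workaround, sketched below, is to launch the server as a background process and then talk to it with `BertClient` from the `bert-serving-client` package installed earlier. The helper names (`build_server_command`, `start_server_in_background`, `encode_sentences`) are hypothetical; only flags that appear in the argument table above are used, and the defaults are illustrative.

```python
import subprocess


def build_server_command(model_dir, num_worker=1, max_seq_len=25):
    """Build the argument list for bert-serving-start (hypothetical helper)."""
    return [
        "bert-serving-start",
        "-model_dir", str(model_dir),
        "-num_worker", str(num_worker),
        "-max_seq_len", str(max_seq_len),
    ]


def start_server_in_background(model_dir, **options):
    """Launch the server without blocking the notebook; returns the process handle."""
    command = build_server_command(model_dir, **options)
    return subprocess.Popen(
        command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
    )


def encode_sentences(sentences):
    """Send sentences to a running server and return their embeddings."""
    # Imported here so the helpers above work even without the client installed.
    from bert_serving.client import BertClient

    client = BertClient()  # by default, waits until the server responds
    try:
        return client.encode(sentences)
    finally:
        client.close()


print(build_server_command("/content/wwm_uncased_L-24_H-1024_A-16", num_worker=4))
```

In a notebook, this might be used as `server = start_server_in_background("/content/wwm_uncased_L-24_H-1024_A-16", num_worker=4)` followed, once the model has loaded, by `encode_sentences(["I like dogs"])`. With the default pooling settings shown above, each sentence should come back as a single fixed-length vector; call `server.terminate()` to stop the server when finished.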

I:[35