In [None]:
!pip install tensorflow_hub
!pip install bert-for-tf2
!pip install sentencepiece

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
# Report library versions so results can be tied to a specific environment.
for label, module in (("TF version: ", tf), ("Hub version: ", hub)):
    print(label, module.__version__)

In [1]:
import tensorflow_hub as hub
import tensorflow as tf
import bert
FullTokenizer = bert.bert_tokenization.FullTokenizer
from tensorflow.keras.models import Model  
import math

In [2]:
import os
import ssl
import warnings

# Globally disabling TLS certificate verification lets anyone on the network path
# tamper with the model downloaded from tfhub.dev, so it must never be the default.
# Prefer fixing the local CA store (e.g. running Python's "Install Certificates"
# script on macOS). Set ALLOW_INSECURE_SSL=1 only as a last-resort workaround.
if os.environ.get("ALLOW_INSECURE_SSL") == "1":
    warnings.warn("TLS certificate verification is DISABLED for this process.")
    ssl._create_default_https_context = ssl._create_unverified_context

In [3]:
max_seq_length = 32  # Your choice here.


def _bert_int_input(layer_name):
    """Return a Keras int32 input holding `max_seq_length` values per example."""
    return tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name=layer_name
    )


# The three inputs passed to the TF-Hub BERT layer, in call order.
input_word_ids = _bert_int_input("input_word_ids")
input_mask = _bert_int_input("input_mask")
segment_ids = _bert_int_input("segment_ids")

bert_layer = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1",
    trainable=True,
)
# This module version returns a (pooled_output, sequence_output) pair.
pooled_output, sequence_output = bert_layer(
    [input_word_ids, input_mask, segment_ids]
)

In [4]:
# Wrap the BERT layer in a functional model exposing both of its outputs.
model = Model(
    inputs=[input_word_ids, input_mask, segment_ids],
    outputs=[pooled_output, sequence_output],
)

In [5]:

def get_masks(tokens, max_seq_length):
    """Build the padding mask for a fixed-length BERT input.

    Args:
        tokens: Token sequence (already including any special tokens).
        max_seq_length: Target length after padding.

    Returns:
        A list of ``max_seq_length`` ints: 1 for each real token, 0 for padding.

    Raises:
        IndexError: If ``tokens`` is longer than ``max_seq_length``.
    """
    n_real = len(tokens)
    if n_real > max_seq_length:
        raise IndexError("Token length more than max seq length!")
    return [1 if pos < n_real else 0 for pos in range(max_seq_length)]


def get_segments(tokens, max_seq_length):
    """Build segment ids: 0 up to and including the first "[SEP]", 1 afterwards.

    Args:
        tokens: Token sequence (already including any special tokens).
        max_seq_length: Target length after padding.

    Returns:
        A list of ``max_seq_length`` ints; padding positions are 0.

    Raises:
        IndexError: If ``tokens`` is longer than ``max_seq_length``.
    """
    if len(tokens) > max_seq_length:
        raise IndexError("Token length more than max seq length!")
    past_first_sep = False
    ids = []
    for tok in tokens:
        # The "[SEP]" token itself still belongs to the segment it closes.
        ids.append(1 if past_first_sep else 0)
        past_first_sep = past_first_sep or tok == "[SEP]"
    ids.extend([0] * (max_seq_length - len(tokens)))
    return ids


def get_ids(tokens, tokenizer, max_seq_length):
    """Map tokens to vocabulary ids and right-pad with 0 to ``max_seq_length``.

    Args:
        tokens: Token sequence (already including any special tokens).
        tokenizer: Object providing ``convert_tokens_to_ids(tokens)``.
        max_seq_length: Target length after padding.

    Returns:
        A list of ``max_seq_length`` ids; positions past the real tokens are 0.

    Raises:
        IndexError: If ``tokens`` is longer than ``max_seq_length``. Without this
            check the padding count goes negative and an over-long id list is
            returned silently; ``get_masks`` and ``get_segments`` already reject
            such input the same way.
    """
    if len(tokens) > max_seq_length:
        raise IndexError("Token length more than max seq length!")
    token_ids = list(tokenizer.convert_tokens_to_ids(tokens))
    return token_ids + [0] * (max_seq_length - len(token_ids))

In [6]:
# Build the tokenizer from the vocab file and casing flag shipped inside the
# TF-Hub module, so tokenization is consistent with the loaded BERT layer.
hub_module = bert_layer.resolved_object
vocab_file = hub_module.vocab_file.asset_path.numpy()
do_lower_case = hub_module.do_lower_case.numpy()
tokenizer = FullTokenizer(vocab_file, do_lower_case)

In [175]:
# Sample sentence to embed; the trailing space does not produce a token.
s = "I am going to s "

In [176]:
# Split the sentence into tokens from the BERT vocab (no [CLS]/[SEP] yet).
stokens = tokenizer.tokenize(s)

In [177]:
stokens  # tokens before [CLS]/[SEP] are added

['i', 'am', 'going', 'to', 's']

In [178]:
# Rebuild from `s` rather than wrapping the previous `stokens`, so re-running
# this cell does not stack extra [CLS]/[SEP] markers.
stokens = ["[CLS]"] + tokenizer.tokenize(s) + ["[SEP]"]

In [179]:
stokens  # tokens wrapped in [CLS] ... [SEP]

['[CLS]', 'i', 'am', 'going', 'to', 's', '[SEP]']

In [180]:
import numpy as np

In [181]:
# Turn the padded Python lists into 1-D NumPy arrays (printed in the next cell).
input_ids = np.asarray(get_ids(stokens, tokenizer, max_seq_length))
input_masks = np.asarray(get_masks(stokens, max_seq_length))
input_segments = np.asarray(get_segments(stokens, max_seq_length))

In [182]:
# Show the tokens and the three padded model inputs, one per line.
for shown in (stokens, input_ids, input_masks, input_segments):
    print(shown)

['[CLS]', 'i', 'am', 'going', 'to', 's', '[SEP]']
[ 101 1045 2572 2183 2000 1055  102    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0]
[1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]


In [183]:
pool_embs, all_embs = model.predict([[input_ids],[input_masks],[input_segments]])

In [184]:
pool_embs[31]

array([ 0.34569314,  0.43720222, -0.8904877 ,  0.52608985,  0.63265824,
        0.16227958, -0.15965639, -0.10834861, -0.94056606,  0.84298307,
       -0.9640128 ,  0.9782785 , -0.6600977 ,  0.72650766, -0.27175733,
       -0.22856599, -0.78187704,  0.2902958 ,  0.04347978,  0.70576525,
        0.05682009,  0.99219793, -0.7601431 ,  0.01237242, -0.27132034,
        0.9933226 , -0.52691495, -0.38232714, -0.35005695, -0.5067427 ,
       -0.1927838 , -0.04245736, -0.00130793,  0.2902257 , -0.9302098 ,
        0.6470284 ,  0.09858067,  0.44149256,  0.22857212, -0.09670964,
       -0.03347912, -0.34896517,  0.44009525,  0.33045903,  0.46741658,
        0.26660106, -0.34482044, -0.06870252,  0.48484105,  0.96824723,
        0.80231476,  0.99090356, -0.12057354, -0.09168582, -0.02806631,
       -0.4699025 , -0.14151902, -0.14270502,  0.31613672,  0.40287217,
        0.4118508 , -0.45021966, -0.86051637,  0.45897365,  0.9874283 ,
        0.91373074, -0.13571554,  0.09640772,  0.19785355, -0.61

In [185]:
# Sequence-output vector at position 0 ([CLS], id 101) of the first example.
all_embs[0, 0]

array([-7.92926133e-01,  3.15708458e-01,  1.86620891e-01, -2.07932398e-01,
       -7.76911229e-02, -6.35825455e-01,  5.69271863e-01, -8.94501433e-02,
       -1.10198073e-02, -9.46070194e-01, -6.07392192e-01,  2.88641393e-01,
       -9.02552754e-02,  1.83572307e-01,  1.44892603e-01,  1.49497271e+00,
        1.10081720e+00, -3.77546161e-01,  5.23154020e-01, -6.78690791e-01,
       -1.59130692e-01, -2.00592205e-01, -2.08860144e-01, -8.75420943e-02,
        2.75908977e-01,  2.46143147e-01,  3.31303388e-01,  9.13336813e-01,
        2.85063267e-01,  1.67783111e-01,  2.47712716e-01,  6.19767420e-02,
       -1.99691996e-01,  3.98126453e-01,  1.68219015e-01, -1.49401233e-01,
       -5.13162911e-01, -1.00533080e+00,  3.15520078e-01, -3.04535627e-02,
       -5.00586629e-02, -2.40956798e-01,  8.05944502e-02, -5.74232996e-01,
       -1.87293917e-01, -5.89798629e-01, -7.98157156e-01, -2.21477479e-01,
       -4.05072778e-01,  6.71447366e-02,  2.01702453e-02,  1.79069385e-01,
        8.24961960e-02,  