# BERT with TensorFlow/Keras

# Generating word embeddings

First, add the necessary imports and load the BERT model from tfhub.dev.

In [1]:
# Uncomment the line below to install the TensorFlow packages this notebook imports.
# NOTE(review): scikit-learn (imported below for cosine_similarity) is not in this list.
#!python3 -m pip install tensorflow tensorflow-hub tensorflow-text

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from sklearn.metrics.pairwise import cosine_similarity

# Build a Keras model that maps a batch of raw strings to the BERT encoder's
# "pooled_output".

# Scalar string input: each example is one piece of text.
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)

# The preprocessing layer turns raw strings into the inputs the encoder consumes.
preprocessor = hub.KerasLayer(
    handle="https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3",
)
encoder_inputs = preprocessor(text_input)

# NOTE(review): trainable=True is requested although the visible cells only run
# inference -- confirm whether fine-tuning is intended elsewhere.
encoder = hub.KerasLayer(
    handle="https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/3",
    trainable=True,
)
outputs = encoder(encoder_inputs)

# The encoder returns a dict; keep both entries available to later cells.
sequence_output = outputs["sequence_output"]
pooled_output = outputs["pooled_output"]

embedding_model = tf.keras.Model(inputs=text_input, outputs=pooled_output)

Now we can encode some words using the BERT model.

In [None]:
# Embed the search query. The model is called on a one-element string batch and
# the result is converted to a NumPy array for use with scikit-learn later on.
query = tf.constant(value=["neuroscientist"])
query_output = embedding_model(query)
query_embedding = query_output.numpy()
print(query_embedding)

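# Sentence Similarity

Embed a small set of documents with the same model and rank them by cosine similarity to the query embedding.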


In [None]:
# Toy corpus to rank against the query: each entry has an integer id and its text.
documents = [
    dict(id=1, text="cardiac surgeon"),
    dict(id=2, text="neuroscientist"),
]
print(documents)

In [None]:
# Embed every document with the same model that produced `query_embedding`.
#
# Each result keeps the document id and stores the embedding under the 'text'
# key, because the similarity cell reads the embedding back via doc['text'].
document_embeddings = [
    {
        'id': doc['id'],
        # .numpy() belongs on the model *output* (as for the query embedding),
        # not on the input constant, so the stored value is a NumPy array.
        'text': embedding_model(tf.constant([doc['text']])).numpy(),
    }
    for doc in documents
]
print(document_embeddings)

In [None]:
# Score each document against the query. cosine_similarity returns a matrix;
# [0][0] takes its first entry (presumably the only one, since each side is
# expected to hold a single embedding -- confirm).
cosine_similarities = [
    {
        'id': entry['id'],
        'score': cosine_similarity(query_embedding, entry['text'])[0][0],
    }
    for entry in document_embeddings
]
print(cosine_similarities)

In [None]:
# Order the scores from most to least similar (highest cosine similarity first).
cosine_similarities = sorted(
    cosine_similarities,
    key=lambda entry: entry['score'],
    reverse=True,
)
print(cosine_similarities)

In [None]:
# Show the original documents, then the same documents in similarity order.
print("Documents:")
print(documents)

# Index the texts by id once instead of re-scanning `documents` for every score.
# Iterating in reverse means that, if an id were repeated, the first document
# with that id wins -- the same entry a first-match scan would pick.
text_by_id = {doc['id']: doc['text'] for doc in reversed(documents)}

# `cosine_similarities` is expected to be sorted already; keep its order.
results = [
    {'id': score['id'], 'text': text_by_id[score['id']]}
    for score in cosine_similarities
]

# TensorFlow string tensors hold raw bytes (UTF-8 for text), so decode as UTF-8;
# decoding as ASCII would raise UnicodeDecodeError on any non-ASCII query.
query_text = query.numpy()[0].decode('utf-8')

print("")
print("Ranked by most similar to search query: '" + query_text + "'")
print(results)