In [1]:
"""
Copyright 2019 Carlos Rodriguez

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

'\nCopyright 2019 Carlos Rodriguez\n\nLicensed under the Apache License, Version 2.0 (the "License");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an "AS IS" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n'

# How to build a Q&A Chat-bot from Scratch using the Universal Sentence Encoder and KNN Vector Search


# Getting Started

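Install and import the required libraries: TensorFlow 2.x, TensorFlow Hub (to load the Universal Sentence Encoder), and nmslib (for k-nearest-neighbor vector search).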

In [2]:
# Install the latest Tensorflow version.
!pip3 install --quiet "tensorflow>=2.0"
# Install TF-Hub.
!pip3 install --quiet tensorflow-hub
!pip3 install --quiet nmslib

In [18]:
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import os
import nmslib

import time
import sys
import datetime
import random

In [4]:
# Other TF-Hub encoders that can be swapped in (not used below):
#   USENC_2     -> "https://tfhub.dev/google/universal-sentence-encoder/2"
#   USENC_LITE2 -> "https://tfhub.dev/google/universal-sentence-encoder-lite/2"
# The encoder this notebook actually loads:
USENC_4 = "https://tfhub.dev/google/universal-sentence-encoder-large/4"

def load_encoder(module_url:str) -> hub.module.Module:
    """Load the TF-Hub model found at ``module_url``.

    Thin wrapper around ``hub.load``.

    NOTE(review): under TF2, ``hub.load`` returns a restored SavedModel object
    rather than a TF1 ``hub.Module``. The return annotation is kept as-is for
    compatibility but may not describe the real type; confirm before relying
    on it.
    """
    encoder = hub.load(module_url)
    return encoder

def encode(embed: "hub.module.Module", messages: list) -> np.ndarray:
    """Embed a batch of sentences with a loaded encoder.

    Args:
        embed: encoder returned by ``load_encoder``. Calling it on a list of
            strings must return a mapping with an ``"outputs"`` entry, which is
            how the USE-large/4 encoder used in this notebook is consumed.
        messages: sentences to encode. A single ``str`` is also accepted and
            treated as a one-element batch.

    Returns:
        A float32 ``np.ndarray`` with one embedding row per message.
    """
    if isinstance(messages, str):
        # Wrap a bare string so the encoder always receives a batch.
        messages = [messages]
    # The encoder's output is presumably a TensorFlow tensor. Convert it so the
    # result really is the annotated ``np.ndarray`` that nmslib consumes.
    return np.asarray(embed(messages)["outputs"], dtype=np.float32)

def create_index(embeddings: np.ndarray, method: str='hnsw',
                 space: str='cosinesimil', index_params: dict=None,
                 print_progress: bool=True) -> nmslib.dist.FloatIndex:
    """Build an nmslib k-NN index over ``embeddings``.

    Ref: https://github.com/nmslib/nmslib/blob/master/manual/methods.md

    Args:
        embeddings: one vector per row. Rows are added without explicit ids,
            so an id returned by ``knnQuery`` is the row's position here.
        method: nmslib search method (default ``'hnsw'``).
        space: nmslib distance space (default ``'cosinesimil'``, i.e. cosine).
        index_params: parameters passed to ``createIndex``. Defaults to
            ``{'post': 2}``, the value this notebook has always used.
        print_progress: forwarded to ``createIndex``.

    Returns:
        The built index.

    Raises:
        ValueError: if ``embeddings`` has no rows.
    """
    if len(embeddings) == 0:
        raise ValueError("cannot build a search index from zero embeddings")
    if index_params is None:
        # Build a fresh dict per call rather than sharing a mutable default.
        index_params = {'post': 2}

    # initialize a new index for the requested method and distance space
    search_index: nmslib.dist.FloatIndex = nmslib.init(method=method, space=space)
    search_index.addDataPointBatch(embeddings)
    search_index.createIndex(index_params, print_progress=print_progress)

    return search_index

def search(query_vector: np.ndarray, n_results:int = 3, index=None) -> tuple:
    """Look up the ``n_results`` nearest neighbours of ``query_vector``.

    Args:
        query_vector: the encoded query.
        n_results: number of neighbours to request (nmslib's ``k``).
        index: an nmslib index, or any object with ``knnQuery(vector, k=...)``.
            If omitted, the notebook-global ``search_index`` is used so
            existing calls keep working. Passing it explicitly avoids hidden
            global state.

    Returns:
        ``(ids, distances)`` exactly as returned by ``index.knnQuery``.
    """
    if index is None:
        # Resolved at call time: ``search_index`` is created in a later cell.
        index = search_index
    ids, distances = index.knnQuery(query_vector, k=n_results)
    return (ids, distances)

In [6]:
# Load the sentence encoder referenced by USENC_4 from TF-Hub.
print("Loading encoder...")
embed: hub.module.Module = load_encoder(module_url=USENC_4)

Loading encoder...


In [9]:
# Sample bot data.
#   queries: known query phrase -> intent key
#   answers: intent key -> candidate replies (one is picked at random)
qna: dict = dict(
    queries={
        "favorite baseball team": "fav_baseball",
        "best baseball team": "fav_baseball",
        "favorite basketball team": "fav_basketball",
        "best basketball team": "fav_basketball",
        "grew up": "hometown",
        "hometown": "hometown",
        "grow up": "hometown",
    },
    answers={
        'fav_baseball': ["NY Yankees, obviously", "have to say...Yankees"],
        'fav_basketball': ["Grew up in the Jordan era...Bulls", "Bulls", "Chicago"],
        'hometown': ["South Norwalk", "Connecticut", "Southern Connecticut right outside of NY"],
    },
)

In [10]:
def update_search_index(embed, queries: list) -> nmslib.dist.FloatIndex:
    """Encode the bot's query phrases and build a fresh search index from them.

    Args:
        embed: encoder passed through to ``encode``.
        queries: query phrases. Their list positions become the index ids.

    Returns:
        A newly built index. No existing index is modified, so callers must
        rebind their variable to the result.
    """
    print('Encoding bot data...')
    return create_index(encode(embed, queries))

In [11]:
# Known query phrases in insertion order. The position of each phrase in this
# list is the id it gets in the search index.
queries: list = [phrase for phrase in qna['queries']]

In [12]:
# Rebuild the k-NN index over the encoded query phrases.
search_index: nmslib.dist.FloatIndex = update_search_index(embed=embed, queries=queries)

Encoding bot data...


In [36]:
# credit https://gist.github.com/Y4suyuki/6805818
def bubbles(pause: int, delay: float = 0.1, stream=None):
    """Draw a text spinner to mimic the bot "typing".

    Args:
        pause: number of spinner frames to draw (a frame count, not seconds).
        delay: seconds to sleep before each frame. The total wait is
            ``pause * delay``; the default 0.1 matches the previous behaviour.
        stream: writable text stream. Defaults to ``sys.stdout``, looked up at
            call time so output redirection still works.
    """
    if stream is None:
        stream = sys.stdout
    animation = "|/-\\"

    for i in range(pause):
        time.sleep(delay)
        # '\r' returns to the start of the line so each frame overwrites the last.
        stream.write("\r" + animation[i % len(animation)])
        stream.flush()


In [39]:
def chat(message: str) -> "str | None":
    """Reply to ``message`` with a canned answer from ``qna``.

    Shows the typing spinner, encodes the message and looks up its nearest
    known query phrase. It then maps that phrase to its intent key and prints
    a random answer for that intent.

    Returns:
        The printed answer, or ``None`` if the index returned no neighbours.
    """
    bubbles(10)

    # encode the message as a batch of one and take its single embedding row
    query_vector = encode(embed, [message])[0]

    # get search results: ids are positions in ``queries``; idx[0] is the best hit
    idx, dist = search(query_vector)

    # Check the length, not ``idx.any()``: id 0 is a valid hit but is falsy,
    # so ``.any()`` would wrongly report "no match" when it is the only result.
    if len(idx) == 0:
        return None

    # traverse to answer: matched phrase -> intent key -> candidate replies
    search_result = queries[idx[0]]
    answer_key = qna['queries'][search_result]
    answer = random.choice(qna['answers'][answer_key])

    print("\n", answer, end="")
    return answer


In [40]:
# Ask about a known intent; the trailing ';' suppresses any Out[] echo.
chat("What's your favorite baseball team?");

-
 NY Yankees, obviously

In [41]:
# Paraphrased hometown question; the trailing ';' suppresses any Out[] echo.
chat("Where did you grow up?");

-
 Connecticut