#OSC NLP 2.4 - BERT Nearest Neighbor Search

##Introduction

In this lab, you'll build a "search engine" using BERT encodings with nearest neighbor algorithms.  You'll start with a basic working system and tune it on your own to improve search evaluation metrics against a test judgement set!



##Install Dependencies

We're using the transformers library, provided by https://huggingface.co

Installation should only take a moment.

In [0]:
!pip install transformers

###Setup g-drive file access

Be sure to set the ```path``` value below to the folder where you uploaded the labs folder from this GitHub repository.

In [0]:
path = '/content/drive/My Drive/Colab Notebooks/labs/'

In [0]:
#Grant this notebook access to your Google Drive
from google.colab import drive
drive.mount('/content/drive/')

### Import and initialize dependencies

In [0]:
%tensorflow_version 2.x

In [0]:
import torch #Because we're using pytorch model architectures
import json #Because we're loading blog-posts.json
import numpy #Because we're working with tensors
import tqdm #Because we'll see a pretty progress bar in the notebook
from transformers import (
  BertConfig,
  BertTokenizer,
  BertModel
)
device = torch.device("cuda:0")
n_gpu = torch.cuda.device_count()

# Show the GPU that you're borrowing from Google.
# Usually either a 'Tesla T4' or 'Tesla P100-PCIE-16GB'
# If you get a different one, please let us know!
torch.cuda.get_device_name()

###Load Content Data

For this lab, we'll use the blog content from http://o19s.com/blog (data captured as of 2020-01-20)

It contains titles, summaries (two to five sentences), and content (long written text)

In [0]:
#Open the file in a with-block so the file handle is closed after loading
with open(path + 'blog-posts.json') as f:
  posts = json.load(f)
print('Loaded',len(posts),'blog posts')

##Tokenize

We now need to convert the blog data into BERT encodings.  We'll use these encodings to build an index of vectors and metadata that can be explored with nearest-neighbor and other tensor-based retrieval algorithms.
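
As a rough illustration of what "nearest neighbor" means over an index of vectors, here is a minimal brute-force sketch. It is not part of the lab code: the function name `nearest_neighbors`, the toy 4-dimensional vectors, and the choice of cosine similarity are all assumptions made for the example. Real BERT encodings from `bert-base-uncased` have 768 dimensions per token.

```python
import numpy as np

def nearest_neighbors(query_vec, doc_vecs, k=3):
    """Return (row index, cosine similarity) pairs for the k closest rows of doc_vecs."""
    # Normalize to unit length so that a dot product equals cosine similarity.
    q = query_vec / np.linalg.norm(query_vec)
    d = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    sims = d @ q
    # argsort is ascending, so reverse it and keep the first k indices.
    top = np.argsort(sims)[::-1][:k]
    return [(int(i), float(sims[i])) for i in top]

# Hypothetical toy vectors standing in for document encodings.
toy_docs = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
])
toy_query = np.array([1.0, 0.0, 0.1, 0.0])

neighbors = nearest_neighbors(toy_query, toy_docs, k=2)
print(neighbors)
```

Brute force compares the query against every vector, which is fine for a few hundred blog posts. Larger collections usually call for an approximate nearest-neighbor index instead.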

###Using Huggingface Transformers

We'll be using the library available here: https://huggingface.co/transformers/ 

For BERT, we're using the smaller ```bert-base-uncased``` model also provided by Huggingface.  You can learn more about this model here: https://huggingface.co/bert-base-uncased and in a list of all models here: https://huggingface.co/models

In [0]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
print('tokenizer and model are ready')

In [0]:
def encode(text):
  #Tokenize the text (adding the [CLS] and [SEP] special tokens) and add a batch dimension of 1
  input_ids = torch.tensor(tokenizer.encode(text, add_special_tokens=True)).unsqueeze(0)
  #Run the encoder without tracking gradients, since we only need its outputs
  with torch.no_grad():
    outputs = model(input_ids)
  #outputs[0] is the last hidden state, shaped (1, N, 768): N encoded tokens, each with 768 output dimensions.
  #Drop the batch dimension and convert the result to a numpy array.
  flat = outputs[0][0].detach().numpy()
  #Make the N by 768 shape explicit (a no-op when the array already has that shape)
  encodings = numpy.reshape(flat,(flat.size//768,768))
  return encodings

In [0]:
#Test the encoder!  You should get output similar to:
"""
[[-0.2130882  -0.12260079 -0.29699266 ... -0.3321391   0.7823828
   0.17157161]
 [ 0.3834221   0.14510818 -0.48845595 ... -0.2496235   1.3287714
   0.08275799]
 [ 0.32850048  0.7211096   0.57169473 ... -0.631611    0.7564233
   0.48113757]
 ...
 [ 0.01067545 -1.0143113   0.69491935 ...  0.40127933  0.39096206
  -0.52970445]
 [-0.42570952 -0.8559151   0.23692301 ...  0.36261156  0.24261396
  -0.5436267 ]
 [ 0.20549215  0.31845802  0.30723798 ...  0.39721662  0.03817794
  -0.0633641 ]]
"""
print(encode(posts[0]['summary']))

###Build the index!

Index all the things!  Take a guess as to how long this will take for fewer than 700 blog post titles and summaries.

In [0]:
def encodeIndex(posts):
  index = []
  for i in tqdm.tqdm(range(len(posts))):
    post = posts[i]

    title_encodings = encode(post['title'])
    summary_encodings = encode(post['summary'])

    """
     Uncomment at your own risk!
     This will take a long while and fail due to an error during encoding time.
     Bonus points: why will this happen?
    """
    #content_encodings = encode(post['content']) 

    index.append({
      "id":post['id'],
      "url":post['url'],
      "title":post['title'],
      "summary":post['summary'],
      "title_encodings":title_encodings,
      "summary_encodings":summary_encodings
      #"content":post['content'],
      #"content_encodings":content_encodings
    })
  return index
index = encodeIndex(posts)

## Let's build an encoding similarity search engine!

We will have a similarity function and a search function.  Using the title_encodings and summary_encodings, write a new similarity function to replace 'weak_similarity', or change the boosts on the title and summary fields.  Who can get the best recall at k=10?
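
One possible starting point for a replacement similarity function is sketched below. It mean-pools each N by D encoding into a single vector and compares the two vectors by cosine similarity. The name `mean_pooled_cosine` and the toy inputs are assumptions for illustration. Whether this improves the metrics on the lab's test set is something you would need to measure yourself.

```python
import numpy as np

def mean_pooled_cosine(encoding1, encoding2):
    """Cosine similarity between the mean token vectors of two N x D encodings."""
    # Collapse each token matrix into one vector by averaging over tokens.
    v1 = np.asarray(encoding1).mean(axis=0)
    v2 = np.asarray(encoding2).mean(axis=0)
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    if denom == 0.0:
        # Avoid dividing by zero when either pooled vector is all zeros.
        return 0.0
    return float(v1.dot(v2) / denom)

# Hypothetical 2-dimensional "encodings" to show the behaviour.
a = np.array([[1.0, 0.0], [3.0, 0.0]])  # mean vector [2, 0]
b = np.array([[10.0, 0.0]])             # same direction as a, larger magnitude
c = np.array([[0.0, 5.0], [0.0, 1.0]])  # orthogonal to a

same_direction = mean_pooled_cosine(a, b)
orthogonal = mean_pooled_cosine(a, c)
print(same_direction, orthogonal)
```

It takes the same two arguments as `weak_similarity`, so it can be passed to `berty_searchy` in its place. Unlike a raw dot product, cosine similarity ignores vector magnitude, so longer or "louder" encodings do not automatically score higher.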

###Printing and Measuring methods

In [0]:
#Print the top k results for a single query!
def print_results(querypair,results,metrics=[],k=10):
  print('QUERY:',querypair[0])
  print('SCORES:',metrics)
  print('RELEVANT DOCS:',querypair[1])
  for result in results[0:k]:
    score = result[0]
    doc = result[1]
    if doc['url'] in querypair[1]:
      found = 'FOUND!'
    else:
      found = '      '
    print('..',found,score,doc["title"],doc["url"])
  print('-----------------------------')

In [0]:
#Relevance metrics!  Use these to score the results returned
def precision(results,judgements,k=10):
  tp=0
  fp=0
  if len(judgements)==0:
    return 0.0
  for result in results[0:k]:
    if result[1]['url'] in judgements:
      tp+=1
    else:
      fp+=1
  return tp/(tp+fp)

#Recall: the fraction of the judged-relevant documents that appear in the top k results
def recall(results,judgements,k=10):
  tp=0
  if len(judgements)==0:
    return 0.0
  for result in results[0:k]:
    if result[1]['url'] in judgements:
      tp+=1
  return tp/len(judgements)

def f1(p,r):
  n = p*r
  d = p+r
  if d==0.0:
    return 0.0
  return 2*(n/d)

def measure_results(results,judgements,k=10):
  p = precision(results,judgements,k)
  r = recall(results,judgements,k)
  f = f1(p,r)
  return p,r,f

###The retrieval and scoring functions

These functions are used to query the index, score the results, and return the ranked list.  Tune these functions to improve the relevance metrics!
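
To make the effect of the field boosts concrete, here is a small sketch of the blending step with the boosts exposed as parameters. The helper names `blend_scores` and `rank`, the document names, and the similarity values are hypothetical. Only the default title boost of 1.2 comes from the lab's scorer.

```python
def blend_scores(title_similarity, summary_similarity,
                 title_boost=1.2, summary_boost=1.0):
    """Weighted sum of the per-field similarity scores."""
    return title_similarity * title_boost + summary_similarity * summary_boost

def rank(scored_records):
    """Sort [score, record] pairs by descending score."""
    return sorted(scored_records, key=lambda pair: pair[0], reverse=True)

# Hypothetical (name, title similarity, summary similarity) triples.
toy = [("doc-a", 0.2, 0.9), ("doc-b", 0.8, 0.1), ("doc-c", 0.5, 0.5)]

default_ranking = rank([[blend_scores(t, s), name] for name, t, s in toy])
title_heavy = rank([[blend_scores(t, s, title_boost=3.0), name] for name, t, s in toy])

print([name for _, name in default_ranking])
print([name for _, name in title_heavy])
```

With the default boosts the summary-heavy document ranks first. Raising the title boost to 3.0 reverses the order, which shows how much a single weight can change the results.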

In [0]:
#Our weak_similarity function takes the dot product of every pair of token vectors (one from each encoding) and averages the results
def weak_similarity(encoding1,encoding2):
  total = 0
  dims = 0
  for a in encoding1:
    for b in encoding2:
      total+=a.dot(b)
      dims+=1
  return total/dims
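
The nested loops in `weak_similarity` can also be written as one matrix product, which helps explain why the function is "weak". The sketch below uses random stand-in encodings and hypothetical function names. It checks that the loop version, the matrix version, and a plain dot product of the two mean-pooled vectors all give the same number.

```python
import numpy as np

def weak_similarity_loop(encoding1, encoding2):
    # Same logic as the notebook's weak_similarity: average of all pairwise dot products.
    total = 0.0
    pairs = 0
    for a in encoding1:
        for b in encoding2:
            total += a.dot(b)
            pairs += 1
    return total / pairs

def weak_similarity_matrix(encoding1, encoding2):
    # (N1 x D) @ (D x N2) gives the N1 x N2 matrix of pairwise dot products.
    return float((encoding1 @ encoding2.T).mean())

# Random stand-ins for two token-encoding matrices (8 dimensions instead of 768).
rng = np.random.default_rng(0)
e1 = rng.normal(size=(5, 8))
e2 = rng.normal(size=(7, 8))

loop_score = weak_similarity_loop(e1, e2)
matrix_score = weak_similarity_matrix(e1, e2)
# By linearity, the mean of all pairwise dot products equals
# the dot product of the two mean vectors.
pooled_score = float(e1.mean(axis=0).dot(e2.mean(axis=0)))

print(loop_score, matrix_score, pooled_score)
```

So `weak_similarity` is equivalent to mean-pooling each text into one vector and taking an unnormalized dot product. Token-level detail is averaged away, and vector magnitude influences the score.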

In [0]:
#Our search function takes a querystring, the index, and the similarity function
#It returns a ranked resultset ranked by descending score
def berty_searchy(querystring,index,similarity):
  query_encodings = encode(querystring)
  resultset = []
  for i in range(len(index)):
    record = index[i]
    
    #Get the similarity score between query and title
    title_similarity = similarity(query_encodings,record["title_encodings"])

    #Get the similarity score between query and summary
    summary_similarity = similarity(query_encodings,record["summary_encodings"])
    
    #This is the scorer that blends the similarities!
    score = title_similarity * 1.2 + summary_similarity

    resultset.append([score,record])

  #Rerank the resultset by score descending
  reranked = sorted(resultset, reverse=True, key=lambda k: k[0])
  return reranked

In [0]:
#Test the weak_similarity function with some comparisons:
A = encode("Apples are very tasty.")
B = encode("Apple stock is high.")
C = encode("I bought a new iPhone today.")
D = encode("I ate some fruit.")

print("~(A,B) ==", weak_similarity(A,B))
print("~(B,C) ==", weak_similarity(B,C))
print("~(A,C) ==", weak_similarity(A,C))
print("~(A,D) ==", weak_similarity(A,D))
print("~(B,D) ==", weak_similarity(B,D))
print("~(C,D) ==", weak_similarity(C,D))

###Test Set

This is our test set of ten queries and judgements of the most relevant documents per query.  We will use this set to evaluate relevance metrics of our similarity and scoring methods!
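
As a worked example of how a ranked result list is scored against a judgement list, the sketch below computes precision, recall, and F1 at k for one made-up query. The function name `precision_recall_f1_at_k` and the placeholder identifiers `u1` to `u9` are assumptions. In the lab itself the judgements are blog post URLs.

```python
def precision_recall_f1_at_k(ranked_urls, relevant_urls, k=10):
    """Score the top k ranked URLs against a set of judged-relevant URLs."""
    top_k = ranked_urls[:k]
    if not top_k or not relevant_urls:
        return 0.0, 0.0, 0.0
    hits = sum(1 for url in top_k if url in relevant_urls)
    p = hits / len(top_k)          # share of returned results that are relevant
    r = hits / len(relevant_urls)  # share of relevant documents that were returned
    f = 0.0 if p + r == 0 else 2 * p * r / (p + r)
    return p, r, f

# Hypothetical ranking and judgements: 2 of the 5 results are relevant,
# and 2 of the 4 relevant documents were found.
ranked = ["u1", "u7", "u3", "u9", "u2"]
relevant = {"u1", "u2", "u4", "u5"}

p, r, f = precision_recall_f1_at_k(ranked, relevant, k=5)
print(p, r, f)
```

Here precision is 2/5, recall is 2/4, and F1 is their harmonic mean. Because recall divides by the number of judged documents, a query with more than k judgements can never reach a recall of 1.0 at that k.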

In [0]:
queries = [
  ("Why should I come to Haystack?",["https://opensourceconnections.com/blog/2019/01/07/haystack-democratize-relevance/","https://opensourceconnections.com/blog/2020/01/14/why-you-should-submit-a-talk-to-haystack-the-search-relevance-conference/","https://opensourceconnections.com/blog/2018/04/10/haystack-search-relevance/","https://opensourceconnections.com/blog/2019/04/24/haystack/","https://opensourceconnections.com/blog/2018/01/04/why-a-search-relevance-conference/"]),
  ("programming with java",["https://opensourceconnections.com/blog/2008/01/02/is-java-the-new-cobol-2/","https://opensourceconnections.com/blog/2015/12/22/exploring-custom-typecodecs-in-the-cassandra-java-driver/","https://opensourceconnections.com/blog/2013/02/12/using-solr-join-to-find-the-best-time-to-ask-questions-on-stackoverflow/"]),
  ("OCR example",["https://opensourceconnections.com/blog/2019/12/03/using-tika-and-tesseract-as-an-api-exposed-by-solr-via-extractingrequesthandler/","https://opensourceconnections.com/blog/2019/11/26/tika-and-tesseract-outside-of-solr/","https://opensourceconnections.com/blog/2019/10/24/it-s-okay-to-run-tika-inside-of-solr-if-and-only-if/","https://opensourceconnections.com/blog/2019/12/10/tesseract-3-and-tika/"]),
  ("searching PDFs",["https://opensourceconnections.com/blog/2019/10/01/solr-meetup-mimecast/","https://opensourceconnections.com/blog/2019/12/17/parsing-tika-tesseract-output-inside-of-solr-via-statelessscriptupdateprocessorfactory/","https://opensourceconnections.com/blog/2019/11/22/it-s-time-for-tika-tuesdays/","https://opensourceconnections.com/blog/2019/10/24/it-s-okay-to-run-tika-inside-of-solr-if-and-only-if/"]),
  ("learning tika",["https://opensourceconnections.com/blog/2019/11/22/it-s-time-for-tika-tuesdays/","https://opensourceconnections.com/blog/2019/11/26/tika-and-tesseract-outside-of-solr/","https://opensourceconnections.com/blog/2019/12/03/using-tika-and-tesseract-as-an-api-exposed-by-solr-via-extractingrequesthandler/","https://opensourceconnections.com/blog/2013/04/28/indexing-millions-of-documents-using-tika-and-atomic-update/","https://opensourceconnections.com/blog/2019/10/24/it-s-okay-to-run-tika-inside-of-solr-if-and-only-if/","https://opensourceconnections.com/blog/2019/12/10/tesseract-3-and-tika/"]),
  ("diversity in ecommerce",["https://opensourceconnections.com/blog/2019/06/26/catching-mices-in-berlin-for-ecommerce-search/","https://opensourceconnections.com/blog/2019/09/05/diversity-vs-relevance/"]),
  ("morelikethis",["https://opensourceconnections.com/blog/2019/05/02/london-solr-meetup-k8s-solr/","https://opensourceconnections.com/blog/2016/09/13/search-engines-are-the-future-of-recsys/","https://opensourceconnections.com/blog/2016/08/21/recommendations-systems-not-as-cool-as-friends/","https://opensourceconnections.com/blog/2016/06/06/recommender-systems-101-basket-analysis/","https://opensourceconnections.com/blog/2016/10/05/elastic-graph-recommendor/","https://opensourceconnections.com/blog/2013/10/05/search-aware-product-recommendation-in-solr/","https://opensourceconnections.com/blog/2013/07/04/friend-recommendations-using-mapreduce/"]),
  ("tokenization and analyzers",["https://opensourceconnections.com/blog/2015/09/18/the-simple-power-of-elasticsearch-analyzers/","https://opensourceconnections.com/blog/2015/09/22/elyzer-step-by-step-elasticsearch-analyzer-debugging/"]),
  ("Is Solr better than Elasticsearch?",["https://opensourceconnections.com/blog/2015/12/15/solr-vs-elasticsearch-relevance-part-one/","https://opensourceconnections.com/blog/2016/01/22/solr-vs-elasticsearch-relevance-part-two/","https://opensourceconnections.com/blog/2019/02/28/stop-worrying-solr-elasticsearch/","https://opensourceconnections.com/blog/2016/06/01/thoughts-on-algolia/"]),
  ("quepid judgements",["https://opensourceconnections.com/blog/2014/06/10/what-is-search-relevancy/","https://opensourceconnections.com/blog/2019/07/25/2019-07-22-quepid-is-now-open-source/","https://opensourceconnections.com/blog/2013/10/07/quepid-give-your-search-queries-some-love/","https://opensourceconnections.com/blog/2015/07/15/quepid-v0.2.0-released/"])
]

##Evaluate!

Run a "search" for each query, and evaluate the relevance metrics for the returned results.  Change the similarity method and the scorer to get a higher F1 metric!

In [0]:
def evaluate(querypairs,k=10):
  measurements = []
  for querypair in querypairs:
    query = querypair[0]
    judgements = querypair[1]
    results = berty_searchy(query,index,weak_similarity)
    #Pass k through so the metrics are measured on the same top k results that are printed
    p,r,f = measure_results(results,judgements,k)
    measurements.append([p,r,f])
    print_results(querypair,results,metrics=[p,r,f],k=k)
  print('---------------------------------------')
  print('================================================')
  #Each value below is the mean of that metric across all queries
  print('MEAN PRECISION ..',sum([p[0] for p in measurements])/len(measurements))
  print('MEAN RECALL    ..',sum([r[1] for r in measurements])/len(measurements))
  print('MEAN F1        ..',sum([f[2] for f in measurements])/len(measurements))

# Run the search and evaluate the results!
# 'k' is the number of top results that will be shown and evaluated
evaluate(queries,k=10)