# Lesson Notebook 10 - Embedding-Based Retrieval

In this notebook, we'll explore retrieving and ranking news headlines in response to a query. We'll use an encoder model, similar to the Universal Sentence Encoder, to create a vector for each headline and then a vector for the query.  We'll use a library called [SentenceTransformers](https://www.sbert.net/), which has a large number of underlying model weight sets on Hugging Face.  Sentence Transformers are designed to take a sequence of words, such as a sentence, as input and generate a representative vector, an embedding, as output. Note that Sentence Transformers are only available in PyTorch, but that won't affect our use here thanks to the Hugging Face API.

First, we'll generate vectors for our headlines and hold on to them. Then we'll generate an embedding for our query and simply run a Nearest Neighbors search over the full set of news headlines. Finally, we'll cluster the news headline embeddings ahead of time and apply Nearest Neighbors only to the top k clusters whose centroids are most similar to the query embedding.

If we were trying to build a system that needed to scale, we would use something like the ScaNN library to hold our embeddings and perform our searches.

<a id = 'returnToTop'></a>

## Notebook Contents

  * 1. [Setup](#setup)
  * 2. [Data Preparation](#dataPrep)
  * 3. [Encode Embeddings](#encodeData)
  * 4. [Query and Retrieval](#queryRet)
  * 5. [Retrieval via Clusters](#clusterRet)
  * 6. [Answers](#answers)      









[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/datasci-w266/2022-fall-main/blob/master/materials/lesson_notebooks/lesson_10_embedding_based_retrieval.ipynb)

[Return to Top](#returnToTop)  
<a id = 'setup'></a>

### 1. Setup

In [1]:
!pip install -U sentence-transformers
!pip install -U datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
Collecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.23.1-py3-none-any.whl (5.3 MB)
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
Collecting huggingface-hub>=0.4.0
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
Building wheels for collected 

In [2]:
import os
import time
import numpy as np
from datasets import load_dataset

from scipy.spatial.distance import cosine
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sentence_transformers import SentenceTransformer

[Return to Top](#returnToTop)  
<a id = 'dataPrep'></a>

### 2. Data Preparation 

For our data, we'll use the test portion of the XSum summarization data set.  The goal of XSum is to generate a one-line summary of the input article.  We'll grab the 'summary' field, as this will be an excellent set of "sentences" for our retrieval experiment.  It takes about a minute to process the data and get us the test records.

In [3]:
dataset = load_dataset('xsum', split='test')

Downloading and preparing dataset xsum/default (download: 245.38 MiB, generated: 507.60 MiB, post-processed: Unknown size, total: 752.98 MiB) to /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934...


Generating train split:   0%|          | 0/204045 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11332 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11334 [00:00<?, ? examples/s]

Dataset xsum downloaded and prepared to /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934. Subsequent calls will reuse this data.


In [4]:
len(dataset)

11334

In [5]:
dataset[0]

{'document': 'Prison Link Cymru had 1,099 referrals in 2015-16 and said some ex-offenders were living rough for up to a year before finding suitable accommodation.\nWorkers at the charity claim investment in housing would be cheaper than jailing homeless repeat offenders.\nThe Welsh Government said more people than ever were getting help to address housing problems.\nChanges to the Housing Act in Wales, introduced in 2015, removed the right for prison leavers to be given priority for accommodation.\nPrison Link Cymru, which helps people find accommodation after their release, said things were generally good for women because issues such as children or domestic violence were now considered.\nHowever, the same could not be said for men, the charity said, because issues which often affect them, such as post traumatic stress disorder or drug dependency, were often viewed as less of a priority.\nAndrew Stevens, who works in Welsh prisons trying to secure housing for prison leavers, said the

[Return to Top](#returnToTop)  
<a id = 'encodeData'></a>

### 3. Encode Embeddings

We'll load a smaller Sentence Transformers model so that it can run quickly in the live session.  You can experiment with others to see the trade-off between size, processing time, and quality.  For example, you could load the sentence transformer with `'sentence-transformers/all-roberta-large-v1'` and leverage the improvements that come with using a large RoBERTa model.  You can see [a full listing of models](https://huggingface.co/models?library=sentence-transformers&sort=downloads) at Hugging Face.

In [6]:
encoder_model = SentenceTransformer('johngiorgi/declutr-base')

Some weights of the model checkpoint at /root/.cache/torch/sentence_transformers/johngiorgi_declutr-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.bias', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
# Encode only the first 7,500 headlines so it doesn't take too long during the live session

news_headlines = [x['summary'] for x in dataset]
news_embeddings = encoder_model.encode(news_headlines[:7500])

[Return to Top](#returnToTop)  
<a id = 'queryRet'></a>

### 4. Query and Retrieval

First, we'll need to create a query and generate an embedding to represent it. Then we'll compare that query embedding against *all* of the headline embeddings and find the 10 nearest neighbors.  This approach does not scale: the cost of each query grows with the number of headlines, so we can't afford to examine every headline for every query once the collection gets large.

In [8]:
# Let's try a query for some news we might be looking for

query = 'Tiger Woods did not make the cut at golf tournament.' 
query_embedding = encoder_model.encode([query])

In [9]:
# We'll start by loading all of the news embeddings into a Nearest Neighbors model

knn_model = NearestNeighbors(n_neighbors=10)
knn_model.fit(news_embeddings)

NearestNeighbors(n_neighbors=10)

In [10]:
# We'll keep track of the time it takes to find the top 10 nearest headlines

start = time.time()
dists, topk_idx = knn_model.kneighbors(query_embedding)
for d, i in zip(dists[0], topk_idx[0]):
    print(d, news_headlines[i])

print('\nTime:', time.time() - start)

# (We're using a small number of headlines so it's fast for the live session,
# but it'll still go even faster if we narrow the likely candidates first.)

4.1519513 Tiger Woods missed the cut at the Farmers Insurance Open, as England's Justin Rose maintained a one-shot lead.
5.2124724 Patrick Reed won his first tournament of the season at The Barclays to seal his spot on the USA Ryder Cup team.
5.224104 Jordan Spieth will begin the final day at the Masters with a one-shot lead but playing partner Rory McIlroy's bid faltered on day three at Augusta.
5.240151 Australia's Scott Hend eagled the 18th hole for the second successive day to take a one-shot lead over England's Tyrrell Hatton at the PGA Championship.
5.317067 Rory McIlroy hopes to play in the WGC-HSBC Champions event in Shanghai despite suffering with food poisoning.
5.333722 World number one Dustin Johnson is out of the Masters at Augusta National after suffering a back injury in a fall at his rental home on Wednesday.
5.3446345 Graeme McDowell was at the back of the 18th green at Castle Stuart talking about "the dangling carrot" that is September's Ryder Cup at Hazeltine in Minn

Because we have a small number of headlines, we can get the 10 closest headlines in about 3 hundredths of a second.
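
To make the exhaustive search concrete, here is a minimal sketch of what a brute-force nearest-neighbor lookup does, written with only the standard library. It assumes Euclidean distance, which is the default metric for scikit-learn's `NearestNeighbors`. The helper names (`euclidean`, `brute_force_knn`) and the toy 2-D vectors are made up for illustration; they are not part of the notebook's pipeline.

```python
import math

def euclidean(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def brute_force_knn(query, vectors, k):
    """Score every vector against the query and keep the k closest.

    Returns (distance, index) pairs, nearest first. The work grows
    linearly with len(vectors), which is why this does not scale.
    """
    scored = [(euclidean(query, vec), idx) for idx, vec in enumerate(vectors)]
    scored.sort()
    return scored[:k]

# Toy 2-D "embeddings" (hypothetical values, not real model output)
toy_vectors = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0], [0.9, 1.2]]
toy_query = [1.0, 1.0]

for dist, idx in brute_force_knn(toy_query, toy_vectors, k=2):
    print(round(dist, 4), idx)
```

Every query touches every stored vector, so doubling the number of headlines roughly doubles the search time.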

[Return to Top](#returnToTop)  
<a id = 'clusterRet'></a>

### 5. Retrieval via Clusters

If we cluster the embeddings first, we can speed up and scale the retrieval process.  We first find the clusters whose centroids are "close" to our query. Then we examine the embeddings only within the few clusters that seem most responsive to the query.
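
The idea can be sketched on toy data without any libraries. The snippet below is a minimal illustration, not the notebook's implementation: the 2-D vectors and their cluster labels are made up by hand (the cells that follow use KMeans on the real embeddings), `clustered_search` and `n_probe` are hypothetical names, and Euclidean distance is used for both steps to keep it short.

```python
import math

def euclidean(u, v):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def centroid(points):
    """Component-wise mean of a list of vectors."""
    return [sum(col) / len(points) for col in zip(*points)]

# Toy 2-D vectors with hand-assigned cluster labels (hypothetical data)
vectors = [[0.0, 0.1], [0.2, 0.0], [5.0, 5.1], [5.2, 4.9], [9.0, 0.0], [9.1, 0.3]]
labels = [0, 0, 1, 1, 2, 2]

# Group vector indices by cluster, then compute one centroid per cluster
members = {}
for idx, label in enumerate(labels):
    members.setdefault(label, []).append(idx)
centroids = {c: centroid([vectors[i] for i in ids]) for c, ids in members.items()}

def clustered_search(query, k, n_probe):
    """Search only the n_probe clusters whose centroids are nearest the query.

    Returns the k best (distance, index) pairs and the number of
    vectors that were actually examined.
    """
    nearest = sorted(centroids, key=lambda c: euclidean(query, centroids[c]))[:n_probe]
    candidates = [i for c in nearest for i in members[c]]
    scored = sorted((euclidean(query, vectors[i]), i) for i in candidates)
    return scored[:k], len(candidates)

results, n_examined = clustered_search([5.0, 5.0], k=2, n_probe=1)
print([i for _, i in results], n_examined)
```

With `n_probe=1` only two of the six vectors are compared against the query. Raising `n_probe` examines more candidates and lowers the chance of missing a true neighbor that sits in a different cluster.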

In [11]:
# Now let's try clustering the news headlines beforehand. This takes time,
# but we only need to do it once, then re-use it for different queries.

cluster_model = KMeans(n_clusters=50)
news_clusters = cluster_model.fit_predict(news_embeddings)

In [12]:
cluster_news_ids = {i: [] for i in range(50)}
for i, c in enumerate(news_clusters):
    cluster_news_ids[c].append(i)

In [13]:
# Compute the distance from the query embedding to each cluster centroid

query_cluster_dists = [cosine(query_embedding[0], cluster_model.cluster_centers_[c])
                       for c in range(50)]
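
Note that the cell above ranks clusters by cosine distance, while `NearestNeighbors` uses Euclidean distance by default. As a quick reminder of what `scipy.spatial.distance.cosine` computes (one minus the cosine similarity), here is a small standard-library sketch. The function name `cosine_distance` and the example vectors are made up for illustration.

```python
import math

def cosine_distance(u, v):
    """1 - cosine similarity: depends on the angle between u and v, not their lengths."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)

a = [1.0, 0.0]
b = [3.0, 0.0]   # same direction as a, three times longer
c = [0.0, 2.0]   # perpendicular to a

print(cosine_distance(a, b))  # 0.0
print(cosine_distance(a, c))  # 1.0
```

Vectors `a` and `b` are 2 units apart in Euclidean terms but have zero cosine distance, so the two metrics can rank candidates differently. If embeddings are scaled to unit length, squared Euclidean distance equals twice the cosine distance and the two rankings agree.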

In [14]:
# Get the top k nearest clusters and retrieve their document ids
# (You can try different numbers of top clusters, to see the trade-off between
# speed and recall of all the best articles we found above.)

top_clusters = np.argsort(query_cluster_dists)[:2]
candidate_news_ids = [i for c in top_clusters for i in cluster_news_ids[c]]
len(candidate_news_ids)

266

In [15]:
# Now use Nearest Neighbors only on the top cluster candidates

candidate_news_embeds = [news_embeddings[i] for i in candidate_news_ids]

knn_model = NearestNeighbors(n_neighbors=10)
knn_model.fit(candidate_news_embeds)

start = time.time()
dists, topk_idx = knn_model.kneighbors(query_embedding)
for d, i in zip(dists[0], topk_idx[0]):
    orig_i = candidate_news_ids[i]
    print(d, news_headlines[orig_i])

print('\nTime:', time.time() - start)

4.1519513 Tiger Woods missed the cut at the Farmers Insurance Open, as England's Justin Rose maintained a one-shot lead.
5.2124724 Patrick Reed won his first tournament of the season at The Barclays to seal his spot on the USA Ryder Cup team.
5.224104 Jordan Spieth will begin the final day at the Masters with a one-shot lead but playing partner Rory McIlroy's bid faltered on day three at Augusta.
5.240151 Australia's Scott Hend eagled the 18th hole for the second successive day to take a one-shot lead over England's Tyrrell Hatton at the PGA Championship.
5.317067 Rory McIlroy hopes to play in the WGC-HSBC Champions event in Shanghai despite suffering with food poisoning.
5.333722 World number one Dustin Johnson is out of the Masters at Augusta National after suffering a back injury in a fall at his rental home on Wednesday.
5.3446345 Graeme McDowell was at the back of the 18th green at Castle Stuart talking about "the dangling carrot" that is September's Ryder Cup at Hazeltine in Minn

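The clustered approach returns the same headlines, and it takes only seven thousandths of a second.  That time saving will be meaningful when we have millions or billions of records to search.  Keep in mind that the results match only when the true nearest neighbors fall inside the clusters we chose to search; examining fewer clusters is faster but can miss some of them.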
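
One way to quantify the speed versus recall trade-off mentioned in the code comments above is to measure what fraction of the exhaustive top-k results the clustered search also returns. The helper below is a minimal sketch. The name `recall_at_k` and the example id lists are hypothetical, not values from this notebook.

```python
def recall_at_k(exact_ids, approx_ids):
    """Fraction of the exact top-k ids that the approximate search also returned."""
    exact = set(exact_ids)
    if not exact:
        return 0.0
    return len(exact & set(approx_ids)) / len(exact)

# Hypothetical id lists standing in for the two searches' results
exact_top = [12, 7, 93, 41, 5]
approx_top = [12, 7, 41, 88, 5]

print(recall_at_k(exact_top, approx_top))  # 0.8
```

To apply this here, you would save `topk_idx[0]` from the full search before it is overwritten, then compare it with the `orig_i` ids recovered from the clustered search.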

In practice, instead of the clustering approach, you would want to use something like [ScaNN](https://github.com/google-research/google-research/tree/master/scann).