## Embedding-Based Retrieval

In this notebook, we'll explore retrieving and ranking news headlines in response to a query. We'll use a pretrained sentence encoder, `johngiorgi/declutr-base` (a RoBERTa-based DeCLUTR model), to create a vector for each headline and a vector for the query.

First, we'll run Nearest Neighbors search over the full set of headline embeddings. Then we'll cluster the headline embeddings ahead of time and, for each query, apply Nearest Neighbors only to the headlines in the top k clusters whose centroids are most similar to the query embedding.

In [None]:
!pip install sentence-transformers
!pip install datasets

In [2]:
import time
import numpy as np
from datasets import load_dataset

from scipy.spatial.distance import cosine
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sentence_transformers import SentenceTransformer

In [None]:
dataset = load_dataset('Fraser/news-category-dataset', split='test')

In [4]:
len(dataset)

30128

In [5]:
dataset[0]

{'authors': 'Cristian Farias',
 'category': 'POLITICS',
 'category_num': 0,
 'date': '2015-10-11',
 'headline': 'The Supreme Court Let A Man Die. He Was Executed With The Wrong Drug.',
 'link': 'https://www.huffingtonpost.com/entry/supreme-court-oklahoma-death-penalty_us_5616a1a2e4b0dbb8000d7860',
 'short_description': "The court placed far too much faith in Oklahoma's disastrous lethal injection protocol in January and in June."}

In [6]:
encoder_model = SentenceTransformer('johngiorgi/declutr-base')

No sentence-transformers model found with name /root/.cache/torch/sentence_transformers/johngiorgi_declutr-base. Creating a new one with MEAN pooling.
Some weights of the model checkpoint at /root/.cache/torch/sentence_transformers/johngiorgi_declutr-base were not used when initializing RobertaModel: ['lm_head.decoder.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
# Just encoding a subset so it doesn't take too long during the live session

news_headlines = [x['headline'] for x in dataset]
news_embeddings = encoder_model.encode(news_headlines[:5000])

In [18]:
# Let's try a query for some news we might be looking for

query = 'Golden State Warriors won NBA finals'
query_embedding = encoder_model.encode([query])

In [None]:
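# We'll start by loading all of the news embeddings into a Nearest Neighbors model.
# NearestNeighbors defaults to Euclidean distance (Minkowski metric with p=2),
# so the scores printed below are Euclidean distances: smaller means closer.

knn_model = NearestNeighbors(n_neighbors=10)
knn_model.fit(news_embeddings)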
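This index ranks by Euclidean distance, while the clustering step later ranks centroids by cosine distance, so it helps to see how the two relate. For L2-normalized vectors, squared Euclidean distance equals 2 - 2 × cosine similarity, so both metrics give the same ranking. The distances of 6 to 9 printed below show these embeddings are not unit length (unit vectors are never more than 2 apart), so here the two rankings can differ. A minimal numpy check, using random vectors as stand-ins for the embeddings (the 768 dimensions are an assumption chosen to match RoBERTa-base's hidden size):

```python
import numpy as np

# Random vectors standing in for sentence embeddings (assumed 768 dims).
rng = np.random.default_rng(0)
docs = rng.normal(size=(100, 768))
query = rng.normal(size=768)

# L2-normalize so every vector has unit length.
docs_unit = docs / np.linalg.norm(docs, axis=1, keepdims=True)
query_unit = query / np.linalg.norm(query)

cos_sim = docs_unit @ query_unit                            # cosine similarity per doc
sq_euclid = np.sum((docs_unit - query_unit) ** 2, axis=1)   # squared Euclidean distance

# For unit vectors: ||a - b||^2 = 2 - 2 * cos(a, b)
print(np.allclose(sq_euclid, 2 - 2 * cos_sim))

# So "nearest by Euclidean" and "most similar by cosine" give the same order.
print(np.array_equal(np.argsort(sq_euclid), np.argsort(-cos_sim)))
```

If you want the two stages to agree on what "close" means, normalizing the embeddings before indexing and clustering is one simple option.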

In [19]:
# We'll keep track of the time it takes to find the top 10 nearest headlines

start = time.time()
dists, topk_idx = knn_model.kneighbors(query_embedding)
for d, i in zip(dists[0], topk_idx[0]):
    print(d, news_headlines[i])

print('\nTime:', time.time() - start)

# (We're searching only a few thousand headlines, so this is already fast.
# Narrowing the search to the most likely candidates first should make each
# query cheaper still, and the savings grow with the size of the collection.)

6.3696165 Golden State Humiliates Cleveland In Game 2 Of The NBA Finals
8.218968 Justin Timberlake And Jessica Biel Are #CoupleGoals At Grizzlies vs. Lakers Game
8.329109 Meek Mill Opens 76ers' NBA Playoff Game Immediately After Prison Release
8.560379 Novak Djokovic Loses To Sam Querrey In Upset At Wimbledon
8.58611 Atomic Liquors Las Vegas:  A "BLAST" From the Past
8.596141 Canada wildfire: why a sleeping giant awoke in Alberta and became relentless
8.606501 Dodgers Co-Owner Magic Johnson Goes Bonkers Watching Team Romp To World Series
8.671714 Vogelsong, Sandoval Lead the Giants To Victory
8.683214 LIVE: World Cup Championship Rematch
8.706457 Georgetown, N.C. State Square Off In NCAA Tournament Third Round

Time: 0.023498058319091797
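An exact k-nearest-neighbor query has to compare the query against every stored vector. scikit-learn may use a tree structure or brute force depending on the data shape and library version, but either way the answer is the exact top k. A brute-force version is a few lines of numpy; the random 5,000 × 64 matrix below is a hypothetical stand-in for the headline embeddings:

```python
import numpy as np

def exhaustive_knn(query, corpus, k):
    """Return (distances, indices) of the k rows of `corpus` closest to `query`
    in Euclidean distance, sorted nearest first. Assumes k < len(corpus)."""
    dists = np.linalg.norm(corpus - query, axis=1)   # one distance per stored row
    idx = np.argpartition(dists, k)[:k]              # the k smallest, in no particular order
    idx = idx[np.argsort(dists[idx])]                # order those k nearest first
    return dists[idx], idx

rng = np.random.default_rng(1)
corpus = rng.normal(size=(5000, 64))   # hypothetical stand-in for 5,000 headline embeddings
query = rng.normal(size=64)

d, i = exhaustive_knn(query, corpus, k=10)
print(i)
```

`np.argpartition` avoids fully sorting all 5,000 distances, but the scan itself still touches every row, so its cost grows linearly with the collection. That linear scan is what clustering the embeddings tries to avoid.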


In [20]:
# Now let's try clustering the news headlines beforehand. This takes time,
# but we only need to do it once, then re-use it for different queries.

cluster_model = KMeans(n_clusters=50)
news_clusters = cluster_model.fit_predict(news_embeddings)
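`KMeans` alternates two steps: assign each embedding to its nearest centroid, then move each centroid to the mean of the embeddings assigned to it. scikit-learn adds k-means++ seeding, multiple restarts and a convergence check; the sketch below (Lloyd's algorithm with plain random initialization, on random data standing in for `news_embeddings`) leaves all of that out and is only meant to show the core loop:

```python
import numpy as np

def assign_to_nearest(X, centroids):
    """Index of the nearest centroid (squared Euclidean distance) for each row of X."""
    sq_dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return sq_dists.argmin(axis=1)

def lloyd_kmeans(X, n_clusters, n_iter=20, seed=0):
    """Minimal k-means via Lloyd's algorithm: random init, fixed iteration count,
    no k-means++ seeding, no restarts, no convergence check."""
    rng = np.random.default_rng(seed)
    # Start from n_clusters distinct data points chosen at random.
    centroids = X[rng.choice(len(X), size=n_clusters, replace=False)].copy()
    for _ in range(n_iter):
        labels = assign_to_nearest(X, centroids)     # assignment step
        for c in range(n_clusters):                  # update step
            members = X[labels == c]
            if len(members) > 0:                     # keep the old centroid if a cluster empties
                centroids[c] = members.mean(axis=0)
    # Final assignment so the returned labels match the returned centroids.
    return centroids, assign_to_nearest(X, centroids)

rng = np.random.default_rng(2)
X = rng.normal(size=(500, 16))   # hypothetical stand-in for news_embeddings
centroids, labels = lloyd_kmeans(X, n_clusters=5)
print(centroids.shape, np.bincount(labels, minlength=5))
```

The broadcast distance computation builds an (n_points × n_clusters × dim) array. That is fine for this toy example but memory-hungry at 5,000 × 50 × 768; a practical implementation would compute distances in chunks or via the expansion ||x||² - 2x·c + ||c||².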

In [21]:
cluster_news_ids = {i: [] for i in range(50)}
for i, c in enumerate(news_clusters):
    cluster_news_ids[c].append(i)
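This mapping from cluster id to headline ids is what IVF-style indexes call an inverted list. The same grouping can be built without a Python loop by stable-sorting the labels once and splitting at the cluster boundaries; the stable sort keeps the ids within each cluster in ascending order, just like the loop above. A sketch, with a small hypothetical label array in place of `news_clusters`:

```python
import numpy as np

def build_inverted_lists(labels, n_clusters):
    """Map each cluster id to an ascending array of the row ids assigned to it.
    Clusters with no members map to an empty array."""
    order = np.argsort(labels, kind="stable")            # row ids grouped by cluster
    counts = np.bincount(labels, minlength=n_clusters)   # number of rows per cluster
    boundaries = np.cumsum(counts)[:-1]                  # where one cluster's ids end
    return dict(enumerate(np.split(order, boundaries)))

labels = np.array([2, 0, 2, 1, 0, 2])   # hypothetical cluster assignments
lists = build_inverted_lists(labels, n_clusters=4)
print({c: ids.tolist() for c, ids in lists.items()})
```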

In [22]:
# Compute the distance from the query embedding to each cluster centroid

query_cluster_dists = [cosine(query_embedding[0], cluster_model.cluster_centers_[c])
                       for c in range(50)]
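The list comprehension calls `scipy.spatial.distance.cosine` once per centroid. Normalizing the query and the centroids first turns all 50 cosine distances into a single matrix-vector product. A sketch with random arrays standing in for `query_embedding[0]` and `cluster_model.cluster_centers_`, checked against the per-centroid loop:

```python
import numpy as np
from scipy.spatial.distance import cosine

def cosine_distances_to_centroids(query, centroids):
    """Cosine distance (1 - cosine similarity) from one query vector to every centroid."""
    q = query / np.linalg.norm(query)
    C = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)
    return 1.0 - C @ q

rng = np.random.default_rng(3)
centroids = rng.normal(size=(50, 768))   # stand-in for cluster_model.cluster_centers_
query = rng.normal(size=768)             # stand-in for query_embedding[0]

fast = cosine_distances_to_centroids(query, centroids)
slow = np.array([cosine(query, c) for c in centroids])   # the per-centroid loop
print(np.allclose(fast, slow))

top_clusters = np.argsort(fast)[:2]      # the 2 closest centroids
```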

In [27]:
# Get the top k nearest clusters and retrieve their document ids
# (You can try different numbers of top clusters, to see the trade-off between
# speed and recall of all the best articles we found above.)

top_clusters = np.argsort(query_cluster_dists)[:2]
candidate_news_ids = [i for c in top_clusters for i in cluster_news_ids[c]]
len(candidate_news_ids)

259

In [28]:
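# Now use Nearest Neighbors only on the candidates from the top clusters
# (this rebinds knn_model, replacing the index over all 5,000 headlines)

candidate_news_embeds = news_embeddings[candidate_news_ids]  # NumPy fancy indexing

knn_model = NearestNeighbors(n_neighbors=10)
knn_model.fit(candidate_news_embeds)

# As before, only the kneighbors call and the printing are timed; ranking the
# centroids and building this candidate index happen outside the timer.
start = time.time()
dists, topk_idx = knn_model.kneighbors(query_embedding)
for d, i in zip(dists[0], topk_idx[0]):
    orig_i = candidate_news_ids[i]  # map the candidate position back to a headline id
    print(d, news_headlines[orig_i])

print('\nTime:', time.time() - start)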

6.3696165 Golden State Humiliates Cleveland In Game 2 Of The NBA Finals
8.218968 Justin Timberlake And Jessica Biel Are #CoupleGoals At Grizzlies vs. Lakers Game
8.329109 Meek Mill Opens 76ers' NBA Playoff Game Immediately After Prison Release
8.560379 Novak Djokovic Loses To Sam Querrey In Upset At Wimbledon
8.596141 Canada wildfire: why a sleeping giant awoke in Alberta and became relentless
8.606501 Dodgers Co-Owner Magic Johnson Goes Bonkers Watching Team Romp To World Series
8.706457 Georgetown, N.C. State Square Off In NCAA Tournament Third Round
8.749657 Don’t Let Stephen Curry Overshadow Russell Westbrook’s Historic Season
8.755055 Clint Dempsey Talks World Cup, Jurgen Klinsmann And The Pressures Facing Team USA
8.856769 Here's A Glimpse Of Who Celebrated 4/20 In San Francisco

Time: 0.01309347152709961
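Comparing the two result lists shows the speed/recall trade-off. Seven of the ten headlines from the exhaustive search also appear here; the Atomic Liquors, Vogelsong/Giants and "LIVE: World Cup Championship Rematch" headlines fell outside the two probed clusters, so recall@10 for this query is 0.7. This cluster-then-search scheme is the idea behind inverted-file (IVF) indexes such as FAISS's `IndexIVFFlat`, where the number of probed clusters is usually called `nprobe`. The self-contained sketch below measures recall@10 as the number of probed clusters grows. To stay dependency-free, its "centroids" are simply randomly sampled data points rather than a fitted `KMeans`, it probes clusters by Euclidean rather than cosine distance, and the data is random, not the headline embeddings:

```python
import numpy as np

def recall_at_k(exact_ids, approx_ids):
    """Fraction of the exact top-k ids that the approximate search also returned."""
    return len(set(exact_ids) & set(approx_ids)) / len(exact_ids)

def top_k(query, vectors, ids, k):
    """Exact Euclidean top-k among `vectors`; returns the matching entries of `ids`."""
    dists = np.linalg.norm(vectors - query, axis=1)
    return ids[np.argsort(dists)[:k]]

rng = np.random.default_rng(4)
X = rng.normal(size=(5000, 32))      # hypothetical stand-in for headline embeddings
query = rng.normal(size=32)
all_ids = np.arange(len(X))
k, n_clusters = 10, 50

# Coarse quantizer (assumption: random data points as centroids, not a fitted
# KMeans); each vector joins the cluster of its nearest centroid.
centroids = X[rng.choice(len(X), size=n_clusters, replace=False)]
labels = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2).argmin(axis=1)

exact = top_k(query, X, all_ids, k)                              # exhaustive search
probe_order = np.argsort(np.linalg.norm(centroids - query, axis=1))

recalls = {}
for n_probe in (1, 2, 5, 10, n_clusters):
    mask = np.isin(labels, probe_order[:n_probe])   # vectors in the probed clusters
    approx = top_k(query, X[mask], all_ids[mask], k)
    recalls[n_probe] = recall_at_k(exact, approx)
    print(f"n_probe={n_probe:>2}  candidates={int(mask.sum()):>4}  recall@{k}={recalls[n_probe]:.1f}")
```

Because probing more clusters only ever adds candidates, recall can only stay the same or rise as `n_probe` grows, and at `n_probe = n_clusters` the search is exhaustive again with recall 1.0. The speed gain comes from scoring only the probed clusters' members.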
