In [5]:
from sentence_transformers import SentenceTransformer
import json
import numpy as np
import faiss

## 1. Load models to test out

In [14]:
cosine_similarity_model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-v4')

In [15]:
dotprod_model = SentenceTransformer('sentence-transformers/msmarco-distilbert-base-tas-b')

## 2. Load data

In [22]:
with open('processed_books.json') as f:
    data = json.load(f)

In [23]:
texts = list(map(lambda x: x['Text'], data))

## 3. Generate embeddings

In [84]:
cosine_similarity_embeddings = cosine_similarity_model.encode(texts, show_progress_bar=True)

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

In [85]:
dotprod_embeddings = dotprod_model.encode(texts, show_progress_bar=True)

Batches:   0%|          | 0/3 [00:00<?, ?it/s]

In [86]:
cosine_similarity_embeddings.shape, dotprod_embeddings.shape

((78, 768), (78, 768))

In [116]:
# with open('msmarco-distilbert-base-v4_emb.npy', 'wb') as f:
#     np.save(f, cosine_similarity_embeddings)
# with open('msmarco-distilbert-base-tas-b_emb.npy', 'wb') as f:
#     np.save(f, dotprod_embeddings)

## 4. Try out semantic search over the embeddings

In [6]:
cosine_similarity_embeddings = np.load('msmarco-distilbert-base-v4_emb.npy')
dotprod_embeddings = np.load('msmarco-distilbert-base-tas-b_emb.npy')

### 4a. Inner product model

In [24]:
dotprod_index = faiss.IndexFlatIP(768)

In [25]:
dotprod_index.add(dotprod_embeddings)

In [26]:
dotprod_index.ntotal

78

In [27]:
dotprod_index.search(dotprod_embeddings[0:5], 5)

(array([[131.48384 , 110.014946, 108.7542  , 107.00133 , 105.67548 ],
        [138.61638 , 113.56459 , 110.47831 , 109.51258 , 108.82706 ],
        [143.60858 , 119.47456 , 115.84654 , 115.78647 , 114.877655],
        [131.46783 , 112.02878 , 109.30339 , 109.08049 , 108.019325],
        [136.34142 , 111.849014, 109.536316, 108.93637 , 108.49289 ]],
       dtype=float32),
 array([[ 0, 17, 62, 38, 58],
        [ 1, 71, 31, 63, 58],
        [ 2, 65,  7, 16, 46],
        [ 3, 71, 44, 34, 65],
        [ 4, 17, 46, 11, 29]]))

In [28]:
query = "Book about adventure in the jungle"

In [29]:
query_embed = dotprod_model.encode([query])
query_dists, query_nnids = dotprod_index.search(query_embed, 5)
for id in query_nnids[0]:
    print(data[id]['Name'], data[id]['Author'])

The Jungle Book Rudyard Kipling
The Wonderful Wizard of Oz L. Frank Baum
Tarzan and the Lost Empire Edgar Rice Burroughs
Treasure Island Robert Louis Stevenson
A Journey to the Centre of the Earth Jules Verne


In [30]:
query_nnids[0]

array([67, 69, 24, 45, 32])

### 4b. Cosine similarity model

In [31]:
cossim_index = faiss.IndexFlatIP(768)

In [34]:
cosine_similarity_embeddings.shape

(78, 768)

In [35]:
cosine_similarity_embeddings_normalized = cosine_similarity_embeddings / np.linalg.norm(cosine_similarity_embeddings, axis=1)[:, None]

In [36]:
np.linalg.norm(cosine_similarity_embeddings_normalized, axis=1)

array([1.        , 1.        , 1.        , 1.        , 0.99999994,
       0.99999994, 1.        , 0.99999994, 1.        , 1.        ,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       0.99999994, 1.        , 0.99999994, 1.        , 1.        ,
       1.        , 1.        , 1.        , 0.99999994, 0.99999994,
       1.        , 1.        , 0.99999994, 1.        , 0.99999994,
       1.        , 1.        , 1.        , 1.        , 1.        ,
       0.99999994, 1.        , 1.        , 1.        , 0.99999994,
       0.99999994, 0.99999994, 1.        , 0.99999994, 1.        ,
       0.99999994, 1.        , 0.99999994, 1.        , 1.        ,
       0.99999994, 0.99999994, 1.        , 0.99999994, 1.        ,
       1.0000001 , 0.99999994, 0.99999994, 1.        , 1.        ,
       1.        , 1.        , 1.        , 0.99999994, 0.99999994,
       0.99999994, 1.        , 1.        , 1.        , 0.99999994,
       0.99999994, 1.        , 1.        , 1.        , 1.     

In [37]:
cossim_index.add(cosine_similarity_embeddings_normalized)

In [40]:
query = "Book about poor kid"

In [41]:
query_embed = cosine_similarity_model.encode([query])
query_embed_normalized = query_embed / np.linalg.norm(query_embed, axis=1)[:, None]
query_dists, query_nnids = dotprod_index.search(query_embed_normalized, 5)
for id in query_nnids[0]:
    print(data[id]['Name'], data[id]['Author'])

The Wonderful Wizard of Oz L. Frank Baum
Little Women Louisa May Alcott
Grimm's Fairy Tales Jacob Grimm and Wilhelm Grimm
Alice's Adventures in Wonderland Lewis Carroll
A Room with a View E.M. Foster
