In [20]:
import requests
import numpy as np
from io import StringIO
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss

In [2]:

res = requests.get('https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/sick2014/SICK_train.txt')
# create dataframe
data = pd.read_csv(StringIO(res.text), sep='\t')
data.head()

Unnamed: 0,pair_ID,sentence_A,sentence_B,relatedness_score,entailment_judgment
0,1,A group of kids is playing in a yard and an ol...,A group of boys in a yard is playing and a man...,4.5,NEUTRAL
1,2,A group of children is playing in the house an...,A group of kids is playing in a yard and an ol...,3.2,NEUTRAL
2,3,The young boys are playing outdoors and the ma...,The kids are playing outdoors near a man with ...,4.7,ENTAILMENT
3,5,The kids are playing outdoors near a man with ...,A group of kids is playing in a yard and an ol...,3.4,NEUTRAL
4,9,The young boys are playing outdoors and the ma...,A group of kids is playing in a yard and an ol...,3.7,NEUTRAL


In [3]:

# we take all samples from both sentence A and B
sentences = data['sentence_A'].tolist()
sentences[:5]

['A group of kids is playing in a yard and an old man is standing in the background',
 'A group of children is playing in the house and there is no man standing in the background',
 'The young boys are playing outdoors and the man is smiling nearby',
 'The kids are playing outdoors near a man with a smile',
 'The young boys are playing outdoors and the man is smiling nearby']

In [4]:

# we take all samples from both sentence A and B
sentences = data['sentence_A'].tolist()
sentence_b = data['sentence_B'].tolist()
sentences.extend(sentence_b)  # merge them
len(set(sentences))  # together we have ~4.5K unique sentences

4802

In [5]:
urls = [
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.train.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2013/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/images.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2015/images.test.tsv'
]

In [6]:
# each of these dataset have the same structure, so we loop through each creating our sentences data
for url in urls:
    res = requests.get(url)
    # extract to dataframe
    data = pd.read_csv(StringIO(res.text), sep='\t', header=None, error_bad_lines=False)
    # add to columns 1 and 2 to sentences list
    sentences.extend(data[1].tolist())
    sentences.extend(data[2].tolist())

b'Skipping line 191: expected 3 fields, saw 4\nSkipping line 206: expected 3 fields, saw 4\nSkipping line 295: expected 3 fields, saw 4\nSkipping line 695: expected 3 fields, saw 4\nSkipping line 699: expected 3 fields, saw 4\n'
b'Skipping line 104: expected 3 fields, saw 4\nSkipping line 181: expected 3 fields, saw 4\nSkipping line 317: expected 3 fields, saw 4\nSkipping line 412: expected 3 fields, saw 5\nSkipping line 508: expected 3 fields, saw 4\n'


In [7]:
len(set(sentences))

14505

In [9]:

# remove duplicates and NaN
sentences = [word for word in list(set(sentences)) if type(word) is str]

In [10]:

# initialize sentence transformer model
model = SentenceTransformer('bert-base-nli-mean-tokens')
# create sentence embeddings
sentence_embeddings = model.encode(sentences)
sentence_embeddings.shape

(14504, 768)

In [11]:
d = sentence_embeddings.shape[1]
d

768

In [12]:
# IndexFlatL2 measures the L2 (or Euclidean) distance between all given points between our query vector, and the vectors loaded into the index
index = faiss.IndexFlatL2(d)

In [13]:

index.is_trained

True

In [14]:
index.add(sentence_embeddings)
index.ntotal

14504

In [15]:
# Then search given a query xq and number of nearest neigbors to return k.

k = 4
xq = model.encode(["Someone sprints with a football"])

In [16]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[10784  6858  1343  2809]]
Wall time: 23 ms


In [18]:
for i in I[0]:
    print(sentences[i])

A group of football players is running in the field
A group of people playing football is running in the field
Two groups of people are playing football
A person playing football is running past an official carrying a football


In [21]:
# Now, if we’d rather extract the numerical vectors from Faiss, we can do that too.
# we have 4 vectors to return (k) - so we initialize a zero array to hold them
vecs = np.zeros((k, d))
# then iterate through each ID from I and add the reconstructed vector to our zero-array
for i, val in enumerate(I[0].tolist()):
    vecs[i, :] = index.reconstruct(val)

In [22]:
vecs.shape

(4, 768)

In [23]:
# So, we are reducing the scope of our search, producing an approximate answer, rather than exact (as produced through exhaustive search).
nlist = 50  # how many cells
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)

In [24]:
index.is_trained

False

In [25]:
index.train(sentence_embeddings)
index.is_trained  # check if index is now trained

True

In [26]:
index.add(sentence_embeddings)
index.ntotal  # number of embeddings indexed

14504

In [27]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[10784  6858  1343  2809]]
Wall time: 11 ms


In [29]:
for i in I[0]:
    print(sentences[i])

A group of football players is running in the field
A group of people playing football is running in the field
Two groups of people are playing football
A person playing football is running past an official carrying a football


In [30]:
index.nprobe = 10

In [31]:

index.make_direct_map()

In [32]:
index.reconstruct(10784  )

array([ 1.62704606e-02,  2.23259032e-01, -1.50373921e-01, -3.07472467e-01,
       -2.71224350e-01, -1.05931818e-01, -6.46094307e-02,  4.73814942e-02,
       -7.33490527e-01, -3.76576871e-01, -7.67627954e-01,  1.69028684e-01,
        5.31076372e-01,  5.11766613e-01,  1.14415836e+00, -8.56293514e-02,
       -6.72400951e-01, -9.66370761e-01,  2.54540890e-02, -2.15598449e-01,
       -1.25656593e+00, -8.29821765e-01, -9.82496515e-02, -2.18508437e-01,
        5.06102383e-01,  1.05279125e-01,  5.03968894e-01,  6.52429521e-01,
       -1.39458716e+00,  6.58474565e-01, -2.15253279e-01, -2.24874988e-01,
        8.18183720e-01,  8.46426561e-02, -7.61416674e-01, -2.89282888e-01,
       -9.82582867e-02, -7.30461836e-01,  7.85579160e-02, -8.43546569e-01,
       -5.92421174e-01,  7.74713695e-01, -1.20920563e+00, -2.27579549e-01,
       -1.30733597e+00, -2.30814755e-01, -1.31322527e+00,  1.62906349e-02,
       -9.72854614e-01,  1.93081975e-01,  4.74245638e-01,  1.18920863e+00,
       -1.96741295e+00, -