# Augmenting STS Dataset with SBERT

## Requirements

In [26]:
import pandas as pd
import torch
from pathlib import Path
import scipy.special as special
import csv
import numpy as np
import functools
from sentence_transformers import SentenceTransformer
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
from scipy.stats import pearsonr

src_dir = Path("stsbenchmark")
assert src_dir.exists()
output_dir = Path("data")
output_dir.mkdir(exist_ok=True, parents=True)

In [2]:
%%capture
df = pd.read_csv(src_dir / 'sts-train.csv', error_bad_lines=False, header = None, delimiter="\t", quoting=csv.QUOTE_NONE, encoding='utf-8')
df = df.rename(columns={0: "genre", 1: "filename", 2: "year", 3: "trash", 4: "score", 5: "s1", 6: "s2"})

In [3]:
df.genre = df.genre.astype("category")
df.filename = df.filename.astype("category")
df.year = df.year.astype("category")
df.genre = df.genre.astype("category")

In [4]:
df

Unnamed: 0,genre,filename,year,trash,score,s1,s2
0,main-captions,MSRvid,2012test,1,5.00,A plane is taking off.,An air plane is taking off.
1,main-captions,MSRvid,2012test,4,3.80,A man is playing a large flute.,A man is playing a flute.
2,main-captions,MSRvid,2012test,5,3.80,A man is spreading shreded cheese on a pizza.,A man is spreading shredded cheese on an uncoo...
3,main-captions,MSRvid,2012test,6,2.60,Three men are playing chess.,Two men are playing chess.
4,main-captions,MSRvid,2012test,9,4.25,A man is playing the cello.,A man seated is playing the cello.
...,...,...,...,...,...,...,...
5547,main-news,headlines,2015,1489,1.20,"Palestinian hunger striker, Israel reach deal",Palestinian activist detained in Israeli raid
5548,main-news,headlines,2015,1493,4.80,Assad says Syria will comply with UN arms reso...,Syria's Assad vows to comply with U.N. resolution
5549,main-news,headlines,2015,1496,4.60,South Korean President Sorry For Ferry Response,S. Korean president 'sorry' for ferry disaster
5550,main-news,headlines,2015,1498,0.00,Food price hikes raise concerns in Iran,American Chris Horner wins Tour of Spain


In [5]:
model = SentenceTransformer('all-MiniLM-L6-v2')

In [6]:
s1_embeds = model.encode(df.s1, device='cuda:0')
# normalize embeddings
s1_embeds = s1_embeds / np.linalg.norm(s1_embeds, axis=-1)[:, None]
s1_embeds.shape

(5552, 384)

In [7]:
s2_embeds = model.encode(df.s2, device='cuda:0')
# normalize embeddings
s2_embeds = s2_embeds / np.linalg.norm(s2_embeds, axis=-1)[:, None]
s2_embeds.shape

(5552, 384)

In [8]:
similarities = (s1_embeds[:, None, :] @ s2_embeds[:, :, None]).squeeze()

In [9]:
# normalize to 0 .. 1 range (should already be so)
similarities = (similarities - similarities.min()) / (similarities.max() - similarities.min())

print(similarities.shape)
similarities.min(), similarities.max()

(5552,)


(0.0, 1.0)

In [10]:
pearsonr(similarities, df.score / 5)

(0.8386030111528666, 0.0)

In [11]:
sentences = pd.concat([df.s1, df.s2]).drop_duplicates().reset_index(drop=True)
sentences

0                                   A plane is taking off.
1                          A man is playing a large flute.
2            A man is spreading shreded cheese on a pizza.
3                             Three men are playing chess.
4                              A man is playing the cello.
                               ...                        
10188        Palestinian activist detained in Israeli raid
10189    Syria's Assad vows to comply with U.N. resolution
10190       S. Korean president 'sorry' for ferry disaster
10191             American Chris Horner wins Tour of Spain
10192         Obama mulls limited military action in Syria
Length: 10193, dtype: object

In [12]:
model = SentenceTransformer('all-MiniLM-L6-v2')

In [13]:
embeddings = model.encode(sentences, device='cuda:0')
# normalize embeddings
embeddings = embeddings / np.linalg.norm(embeddings, axis=-1)[:, None]
embeddings.shape

(10193, 384)

In [14]:
n = len(sentences)  # 1000
index_pairs = np.concatenate(np.stack(np.mgrid[:n, :n], axis=2))
index_pairs.shape

(103897249, 2)

In [15]:
dataloader = torch.utils.data.DataLoader(index_pairs, batch_size=8192)
next(iter(dataloader)).shape

torch.Size([8192, 2])

In [None]:
similarities = []

for idx in tqdm(dataloader):
    s1_embeds = embeddings[idx[:, 0]]
    s2_embeds = embeddings[idx[:, 1]]
    # compute inner product
    similarities.append((s1_embeds[:, None, :] @ s2_embeds[:, :, None]).squeeze())

similarities = np.concatenate(similarities)

# normalize to 0 .. 1 range (should already be so)
similarities = (similarities - similarities.min()) / (similarities.max() - similarities.min())

print(similarities.shape)
similarities.min(), similarities.max()

  0%|          | 0/12683 [00:00<?, ?it/s]

In [None]:
plt.hist(similarities, bins=100)
plt.show()

In [None]:
# select evenly distributed
bins = np.histogram_bin_edges(similarities, bins=1000)
indices = np.digitize(similarities, bins)
indices.shape

In [None]:
N_SAMPLE = 1_000
sample_idxs = []

for i in tqdm(range(1000)):
    bin_indices = np.where(indices == i)[0]
    if len(bin_indices) > 1:
        sample_idxs.append(np.random.choice(bin_indices, size=N_SAMPLE))

sample_idxs = np.concatenate(sample_idxs)
subset = similarities[sample_idxs]
subset.shape

In [None]:
plt.hist(subset, bins=100)
plt.show()

In [None]:
s1 = sentences[index_pairs[:, 0]].iloc[sample_idxs]
s2 = sentences[index_pairs[:, 1]].iloc[sample_idxs]
subset = pd.Series(subset)
len(s1), len(s2), len(subset)

In [None]:
# Save to disk
df_scored = pd.DataFrame({
    "s1": s1.tolist(),
    "s2": s2.tolist(),
    "score": subset.tolist(),
})
df_scored

In [27]:
df_scored.to_feather(output_dir / "subset.feather")