# Augmenting STS Dataset with SBERT

## Requirements

In [1]:
import torch
import torch.nn.functional as F
import pandas as pd
import csv

from sentence_transformers import SentenceTransformer
from scipy.stats import pearsonr
from pathlib import Path

# Location of the STS benchmark files; stop immediately if it is not present.
src_dir = Path('data') / 'stsbenchmark'
assert src_dir.exists()

# Output directory, created (with parents) when it does not exist yet.
output_dir = Path('data')
output_dir.mkdir(parents=True, exist_ok=True)

# Prefer the first CUDA device when one is available, otherwise use the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Load STS Benchmark Dataset

In [2]:
%%capture
df_sts = pd.read_csv(src_dir / 'sts-test.csv', error_bad_lines=False, header = None, delimiter="\t", quoting=csv.QUOTE_NONE, encoding='utf-8')
df_sts = df_sts.rename(columns={0: "genre", 1: "filename", 2: "year", 3: "trash", 4: "score", 5: "s1", 6: "s2"})

In [3]:
# Convert the three metadata columns to pandas categoricals in one pass.
# (The original cell converted `genre` twice; the repeated conversion was a no-op.)
# Building a new frame via `astype` keeps this cell idempotent on re-run.
df_sts = df_sts.astype({"genre": "category", "filename": "category", "year": "category"})

In [4]:
# Display the loaded frame; the rendered output is truncated to the leading and trailing rows.
df_sts

Unnamed: 0,genre,filename,year,trash,score,s1,s2
0,main-captions,MSRvid,2012test,24,2.5,A girl is styling her hair.,A girl is brushing her hair.
1,main-captions,MSRvid,2012test,33,3.6,A group of men play soccer on the beach.,A group of boys are playing soccer on the beach.
2,main-captions,MSRvid,2012test,45,5.0,One woman is measuring another woman's ankle.,A woman measures another woman's ankle.
3,main-captions,MSRvid,2012test,63,4.2,A man is cutting up a cucumber.,A man is slicing a cucumber.
4,main-captions,MSRvid,2012test,66,1.5,A man is playing a harp.,A man is playing a keyboard.
...,...,...,...,...,...,...,...
1090,main-news,headlines,2015,1438,0.4,"US, China fail to paper over cracks in ties",China: Relief in focus as hope for missing fades
1091,main-news,headlines,2015,1454,1.4,World Cup live: France 0-0 Germany,World Cup live: Germany 0-0 Ghana
1092,main-news,headlines,2015,1456,4.8,Tokyo to host 2020 Games,Tokyo wins race to host 2020 Olympic Games
1093,main-news,headlines,2015,1463,4.4,France warns of extremists benefiting from Egy...,France fears extremists will benefit from Egyp...


## Compute Embeddings

In [5]:
# Instantiate the sentence-embedding model identified by name.
# The shapes printed in the embedding cell show that it yields 384-dimensional vectors.
model = SentenceTransformer('all-MiniLM-L6-v2')

In [6]:
# Encode both sentence columns. `encode()` is documented to take a string or a
# list of strings, so the Series are converted to plain lists; this avoids
# depending on the Series index labels matching positional order.
s1_embeds = model.encode(df_sts.s1.tolist(), convert_to_tensor=True, device=device)
s2_embeds = model.encode(df_sts.s2.tolist(), convert_to_tensor=True, device=device)
# L2-normalize each embedding row, then reshape to (N, 1, D) and (N, D, 1) so a
# batched matrix product later yields one cosine similarity per sentence pair.
s1_embeds = F.normalize(s1_embeds, dim=1).unsqueeze(1)
s2_embeds = F.normalize(s2_embeds, dim=1).unsqueeze(2)
# print shapes
print(s1_embeds.shape, s2_embeds.shape)

torch.Size([1095, 1, 384]) torch.Size([1095, 384, 1])


In [7]:
# Batched (1 x D) @ (D x 1) products of the normalized embeddings give one
# cosine similarity per sentence pair, shaped (N, 1, 1).
# Flatten explicitly to (N,): an argument-less `.squeeze()` would also drop the
# batch dimension and return a 0-d tensor when there is only a single pair.
scores = (s1_embeds @ s2_embeds).reshape(-1).cpu()

In [10]:
# Pearson correlation between the computed similarities and the dataset's
# `score` column; both inputs are handed to SciPy as NumPy arrays.
pearsonr(scores.numpy(), df_sts["score"].to_numpy())

(0.8663615563747435, 0.0)