In [1]:
%load_ext autoreload  
%autoreload 2 
%matplotlib inline

In [2]:
import torch 
from tqdm import tqdm
import json
from datasets import load_dataset
from sentence_transformers import CrossEncoder

In [3]:
# get instances
dataset = load_dataset(
    'wiki_auto', 'auto_acl')

Reusing dataset wiki_auto (/Users/garylai/.cache/huggingface/datasets/wiki_auto/auto_acl/1.0.0/5ffdd9fc62422d29bd02675fb9606f77c1251ee17169ac10b143ce07ef2f4db8)
100%|██████████| 1/1 [00:00<00:00, 54.81it/s]


In [4]:
model = CrossEncoder('cross-encoder/stsb-roberta-large')

In [5]:
# get paraphrase simliarity scores
scores = []
for sample in tqdm(zip(dataset['full'][:500]['normal_sentence'], dataset['full'][:500]['simple_sentence'])):
    normal_sentence, simple_sentence = sample
    score = model.predict([normal_sentence, simple_sentence])
    scores.append(score)

500it [03:03,  2.73it/s]


In [6]:
# get top k 
score_tensor = torch.tensor(scores); score_tensor
topk_scores, topk_indices = torch.topk(score_tensor, 500)

In [7]:
# utility
index_to_score = {}
for sample in zip(topk_indices, topk_scores):
    index, score = sample
    index_to_score[int(index)] = float(score)

In [8]:
def filter_indices(difference=80, similarity_threshold=0.8):
    """
    Args: 
        - difference: normal sentence must be at least `difference` characters longer to be selected
        - similarity_threshold: normal sentence and simple sentence must be at least this similarity and above
    """
    filtered_indices = []
    for i in topk_indices:
        index = int(i)
        # check normal sentence is sufficiently longer
        if (len(dataset['full'][index]['normal_sentence']) - len(dataset['full'][index]['simple_sentence'])) < difference:
            continue
        # check two sentences are sufficiently similar
        if index_to_score[index] < similarity_threshold:
            continue         
        filtered_indices.append(index)

    return filtered_indices

In [9]:
filtered_indices = filter_indices(difference=80, similarity_threshold=0.85)
filtered_ds = dataset['full'].select(filtered_indices)

In [10]:
filtered_ds['normal_sentence'][0]

'A tank car -LRB- International Union of Railways -LRB- UIC -RRB- : tank wagon -RRB- is a type of railroad car -LRB- UIC : railway car -RRB- or rolling stock designed to transport liquid and gaseous commodities .\n'

In [31]:
def clean_sentence(sentence):
    replacement = {
        "-LRB-": "(",
        "-RRB-": ")",
        "\\n": ""
    }

    cleaned_sentence = sentence
    for substring in replacement.keys():
        if substring in cleaned_sentence:
            cleaned_sentence = cleaned_sentence.replace(substring, replacement[substring])
    return cleaned_sentence

# print(clean_sentence(filtered_ds['normal_sentence'][0]))

In [32]:
# check
for i, sample in enumerate(zip(filtered_ds['normal_sentence'], filtered_ds['simple_sentence'])):
    normal_sentence, simple_sentence = sample
    normal_sentence, simple_sentence = clean_sentence(normal_sentence), clean_sentence(simple_sentence)
    print("-" * 80)
    print(f"normal_sentence: {normal_sentence} \nsimple_sentence: {simple_sentence} \noriginal_index: {filtered_indices[i]}" )
    print("score: ", index_to_score[filtered_indices[i]])

--------------------------------------------------------------------------------
normal_sentence: A tank car ( International Union of Railways ( UIC ) : tank wagon ) is a type of railroad car ( UIC : railway car ) or rolling stock designed to transport liquid and gaseous commodities .
 
simple_sentence: A tank car or tank wagon is a type of railroad car designed to transport liquids or gases .
 
original_index: 35
score:  0.9384129047393799
--------------------------------------------------------------------------------
normal_sentence: Industrial waste is the waste produced by industrial activity which includes any material that is rendered useless during a manufacturing process such as that of factories , industries , mills , and mining operations .
 
simple_sentence: Industrial waste is the waste produced by industrial activity , such as that of factories , mills and mines .
 
original_index: 197
score:  0.9271243214607239
------------------------------------------------------------