In [None]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import torch

# Run on the GPU whenever torch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# One checkpoint name is shared by the translation model and its tokenizer.
model_name = 'facebook/mbart-large-50-many-to-many-mmt'
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

def get_backtranslation(text, src='en_XX', tgt='ru_RU'):
    """Round-trip `text` through a pivot language and back.

    `text` is translated from `src` to `tgt` and the result is translated
    from `tgt` back to `src`; the second translation is returned.
    """
    return get_translation(get_translation(text, src, tgt), tgt, src)

def get_translation(text, src, tgt):
    """Translate `text` from language code `src` to language code `tgt`.

    Uses the module-level `tokenizer` and `model`. Inputs are padded and
    truncated to at most 60 tokens before generation. Returns whatever
    `tokenizer.batch_decode` produces for the generated ids, with special
    tokens stripped.
    """
    # The tokenizer reads the source language from this attribute.
    tokenizer.src_lang = src
    model_inputs = tokenizer(
        text,
        return_tensors='pt',
        max_length=60,
        truncation=True,
        padding=True,
    ).to(device)
    # Forcing the first generated token to the target language code selects
    # the output language.
    target_lang_id = tokenizer.lang_code_to_id[tgt]
    generated_ids = model.generate(**model_inputs, forced_bos_token_id=target_lang_id)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
import pylev
from sentence_transformers import SentenceTransformer, util
import language_tool_python

# Sentence-embedding model, placed on the same device as the translation model.
sen_model = SentenceTransformer('paraphrase-MiniLM-L3-v2')
sen_model = sen_model.to(device)

# LanguageTool checker configured for US English.
tool = language_tool_python.LanguageTool('en-US')

def get_distance(src_txt, paraphrased_txt):
    """Word-level edit distance between the source text and its paraphrase.

    Both strings are split on whitespace and the two token lists are handed
    to pylev.
    """
    src_tokens = src_txt.split()
    paraphrase_tokens = paraphrased_txt.split()
    return pylev.levenschtein(src_tokens, paraphrase_tokens)

def get_similarity(src_txt, paraphrased_txt):
    """Cosine similarity of the sentence embeddings of source and paraphrase.

    Each text is encoded separately with the module-level `sen_model`; the
    similarity is returned as a plain Python number via `.item()`.
    """
    src_vec, paraphrase_vec = (
        sen_model.encode(sentence, convert_to_tensor=True)
        for sentence in (src_txt, paraphrased_txt)
    )
    return util.pytorch_cos_sim(src_vec, paraphrase_vec).item()

def get_num_grammatical_errors(paraphrased_txt):
    """Count the issues LanguageTool reports for `paraphrased_txt`."""
    reported_issues = tool.check(paraphrased_txt)
    return len(reported_issues)

In [None]:
def get_distance_label(source, paraphrase):
    """Returns label that indicates how large the changes were between source and paraphrase text

    The word-level edit distance is divided by the number of words in
    `source` and bucketed:

        ratio <= 0.25 -> 'small'
        ratio <= 0.5  -> 'medium'
        ratio <= 0.75 -> 'large'
        otherwise     -> 'gigantic'

    A `source` with no words cannot be used as a denominator. In that case
    the label is 'small' when the paraphrase is also empty (nothing changed)
    and 'gigantic' otherwise (everything is new).
    """
    source_word_count = len(source.split())
    if source_word_count == 0:
        # Guard against ZeroDivisionError for empty / whitespace-only sources.
        return 'gigantic' if paraphrase.split() else 'small'

    distance = get_distance(source, paraphrase)
    dist_percentage = distance / source_word_count
    if dist_percentage <= 0.25:
        return 'small'
    elif dist_percentage <= 0.5:
        return 'medium'
    elif dist_percentage <= 0.75:
        return 'large'
    else:
        return 'gigantic'

        
def get_length_label(source, paraphrase):
    """Say whether the paraphrase reduces, matches, or expands the source.

    Lengths are whitespace-separated word counts. Returns 'reduce' when the
    paraphrase has fewer words than the source, 'expand' when it has more,
    and 'match' when the counts are equal.
    """
    word_delta = len(paraphrase.split()) - len(source.split())
    if word_delta < 0:
        return 'reduce'
    if word_delta > 0:
        return 'expand'
    return 'match'

def get_prefix_text(source, paraphrase):
    """Build the prompt string for a (source, paraphrase) pair.

    The length label and the distance label are computed from the pair and
    placed in front of the source text.
    """
    size_hint = get_length_label(source, paraphrase)
    change_hint = get_distance_label(source, paraphrase)
    return f"Paraphrase: {change_hint} changes, {size_hint} input. {source}"

In [None]:
import json
from tqdm import tqdm


def _accepted_pairs(batch):
    """Back-translate one batch and yield the pairs that pass the quality filter."""
    # gets a batch of back-translated text that are potential paraphrases
    paraphrases = get_backtranslation(batch)
    for src, paraphrase in zip(batch, paraphrases):
        # Keep a pair only if it differs by at least 3 word edits, has an
        # embedding similarity above 0.8, and has no LanguageTool findings.
        # The checks short-circuit in this order.
        if (
            get_distance(src, paraphrase) >= 3
            and get_similarity(src, paraphrase) > 0.8
            and not get_num_grammatical_errors(paraphrase)
        ):
            # The prefix is only built for pairs that are actually kept.
            yield {"Source": get_prefix_text(src, paraphrase), 'Target': paraphrase}


def generate_paraphrases():
    """Yield {"Source", "Target"} training pairs built by back-translation.

    Sentences are read from 'open_stax_sentences.json' and back-translated in
    batches of 32. Every sentence is processed, including the final batch
    when the number of sentences is not a multiple of the batch size.
    """
    with open('open_stax_sentences.json', 'r', encoding='utf-8') as r:
        open_stax_sentences = json.load(r)

    batch = []  # batching to speed up the translation model
    batch_size = 32
    for sentence in tqdm(open_stax_sentences):
        batch.append(sentence)
        if len(batch) == batch_size:
            yield from _accepted_pairs(batch)
            batch = []

    # Flush the leftover sentences that did not fill a whole batch.
    if batch:
        yield from _accepted_pairs(batch)


# Drain the generator so every accepted pair is held in a list.
paraphrase_data = [pair for pair in generate_paraphrases()]


In [None]:
from datasets import load_dataset
import re

def clean_spaces(sentence):
    """Remove the single space that directly precedes a punctuation mark.

    Punctuation handled: . , ; : ? ! and the hyphen. The hyphen is placed
    last in the character class so it is a literal '-' rather than the start
    of a range (':-?' would also match '<', '=' and '>').
    Spaces after punctuation are left untouched.
    """
    return re.sub(r' ([.,;:?!-])', r'\1', sentence)


# Add PAWS pairs: only rows whose label is 1 are turned into examples.
paws_data = load_dataset('paws', 'labeled_final')['train']
for item in paws_data:
    if item['label'] == 1:
        paws_source = clean_spaces(item['sentence1'])
        paws_target = clean_spaces(item['sentence2'])
        # get_prefix_text needs both sentences to derive its labels.
        # append() stores the record itself; extend() on a dict would add
        # only its keys ("Source", "Target") as strings.
        paraphrase_data.append({
            "Source": get_prefix_text(paws_source, paws_target),
            "Target": paws_target,
        })

In [None]:
import random

# A fixed seed makes the shuffle, and therefore the train/test split,
# reproducible for the same input list. A private Random instance is used so
# the global `random` state is left alone.
SPLIT_SEED = 42
TEST_SIZE = 1000  # number of examples held out for testing

random.Random(SPLIT_SEED).shuffle(paraphrase_data)  # shuffles the list in place

# NOTE: with TEST_SIZE or fewer examples in total, train_ds ends up empty.
train_ds = paraphrase_data[:-TEST_SIZE]
test_ds = paraphrase_data[-TEST_SIZE:]

with open('train_ds.json', 'w', encoding='utf-8') as w:
    json.dump(train_ds, w, ensure_ascii=False, indent=2)

with open('test_ds.json', 'w', encoding='utf-8') as w:
    json.dump(test_ds, w, ensure_ascii=False, indent=2)
    