# Load models

## Original BART-SWiPE model (full)

In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BartTokenizerFast

# Full SWiPE checkpoint released by Salesforce. The checkpoint's own tokenizer is not
# loaded; generation uses the base BART-large fast tokenizer instead.
model_orig = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/bart-large-swipe")
tokenizer_ft = BartTokenizerFast.from_pretrained('facebook/bart-large')

## Original BART-SWiPE-clean model

In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BartTokenizerFast

# "Clean" SWiPE checkpoint released by Salesforce, used with the base BART-large
# fast tokenizer (the checkpoint's own tokenizer is not loaded).
model_orig_clean = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/bart-large-swipe-clean")
tokenizer_ft = BartTokenizerFast.from_pretrained('facebook/bart-large')

## Fine-tuned BART model (swipe-full)

In [2]:
from transformers import BartForConditionalGeneration, BartTokenizer, GenerationConfig

# Locally fine-tuned BART (SWiPE full). Weights and generation settings come from the
# same saved directory; the tokenizer is the base (slow) BART-large tokenizer.
ft_full_dir = "../models/bart-swipe-ft/model-swipe-full"

tokenizer_ft_full = BartTokenizer.from_pretrained('facebook/bart-large')
model_ft_full = BartForConditionalGeneration.from_pretrained(ft_full_dir)
generation_config = GenerationConfig.from_pretrained(ft_full_dir)
model_ft_full.generation_config = generation_config

## Fine-tuned BART model (swipe-cleaned)

In [38]:
from transformers import BartForConditionalGeneration, BartTokenizerFast, GenerationConfig, BartTokenizer

# Locally fine-tuned BART (SWiPE clean); the directory name suggests a 512-token
# training length — confirm against the training run.
FT_CLEAN_DIR = "../models/bart-swipe-ft/model-swipe-clean-bart-tokenizer-512"

# saved model
model_ft = BartForConditionalGeneration.from_pretrained(FT_CLEAN_DIR)
# tokenizer: `use_fast=True` is only honoured by AutoTokenizer; passing it to the slow
# BartTokenizer class is silently ignored. Load the fast class directly instead.
tokenizer_ft = BartTokenizerFast.from_pretrained('facebook/bart-large')
# generation config saved alongside the fine-tuned weights
generation_config = GenerationConfig.from_pretrained(FT_CLEAN_DIR)
model_ft.generation_config = generation_config

# Load dataset

## SWiPE - clean

In [2]:
from datasets import load_from_disk, DatasetDict
import pandas as pd

# Keep only the 'test_id' split of the cleaned SWiPE data, re-exposed under the key
# 'test' so the generation cells can uniformly index dataset['test'].
swipe_clean_dataset = DatasetDict(
    test=load_from_disk("../data/swipe_clean")['test_id'],
)

In [5]:
# Display the retained data: a single 'test' split with 'r_content' (source) and 's_content' columns.
swipe_clean_dataset

DatasetDict({
    test: Dataset({
        features: ['r_content', 's_content'],
        num_rows: 483
    })
})

## ASSET - test 

In [39]:
from datasets import load_dataset

# ASSET benchmark, 'simplification' config: each row has an 'original' sentence
# and a list of reference 'simplifications'.
asset_dataset = load_dataset(path="facebook/asset", name="simplification")
asset_dataset

DatasetDict({
    validation: Dataset({
        features: ['original', 'simplifications'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['original', 'simplifications'],
        num_rows: 359
    })
})

# Generate

In [3]:
# Choose the tokenizer / model / evaluation data used by the generation cells below.
# Swap a right-hand side (e.g. model_ft, model_orig_clean, asset_dataset) to evaluate
# a different combination.
tokenizer, model, dataset = tokenizer_ft, model_orig, swipe_clean_dataset

## Tokenizer adjustments

In [4]:
from transformers import AddedToken


def _bart_special_token(text, lstrip=False):
    """Wrap `text` as a non-normalized special AddedToken; only left-stripping is configurable."""
    return AddedToken(text, rstrip=False, lstrip=lstrip, single_word=False, normalized=False, special=True)


# Re-register BART's special tokens explicitly. Only <mask> strips whitespace on its left.
special_tokens_dict = {
    role: _bart_special_token(text, lstrip=(role == 'mask_token'))
    for role, text in (
        ('bos_token', "<s>"),
        ('eos_token', "</s>"),
        ('unk_token', "<unk>"),
        ('sep_token', "</s>"),
        ('pad_token', "<pad>"),
        ('cls_token', "<s>"),
        ('mask_token', "<mask>"),
    )
}

# Register them on the tokenizer; the cell displays the count returned by add_special_tokens.
tokenizer.add_special_tokens(special_tokens_dict)

5

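## SWiPE

Run this cell with `dataset = swipe_clean_dataset` in the selection cell; it reads the `r_content` column.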

In [5]:
from tqdm import tqdm
import torch

# SWiPE keeps the source text in 'r_content'. Fail early with a clear message if the
# selection cell points `dataset` at something else (e.g. ASSET, which uses 'original').
TEXT_COLUMN = 'r_content'
if TEXT_COLUMN not in dataset['test'].column_names:
    raise KeyError(
        f"Expected a SWiPE dataset with a {TEXT_COLUMN!r} column, "
        f"got columns {dataset['test'].column_names}"
    )

# Decoding samples (do_sample=True, temperature=1.5), so seed it for reproducible predictions.
SEED = 42
torch.manual_seed(SEED)

predictions = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for example in tqdm(dataset['test']):
    # Truncate to the tokenizer's model_max_length (1024 for facebook/bart-large):
    # longer inputs would run past BART's learned position embeddings and crash.
    encoded = tokenizer(example[TEXT_COLUMN], return_tensors="pt", truncation=True)
    # move input_ids / attention_mask to the same device as the model
    encoded = {key: value.to(device) for key, value in encoded.items()}

    # generate prediction
    output_ids = model.generate(**encoded, max_length=200, temperature=1.5, num_beams=4, num_return_sequences=1, do_sample=True)
    simplified_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    predictions.append(simplified_text)

100%|███████████████████████████████████████████| 483/483 [03:13<00:00,  2.49it/s]


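## ASSET

Run this cell with `dataset = asset_dataset` in the selection cell; it reads the `original` column.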

In [41]:
from tqdm import tqdm
import torch

# ASSET keeps the source sentence in 'original'. This cell reads `dataset` from the
# selection cell, so check that it actually holds ASSET before spending minutes generating.
TEXT_COLUMN = 'original'
if TEXT_COLUMN not in dataset['test'].column_names:
    raise KeyError(
        f"Expected the ASSET dataset with a {TEXT_COLUMN!r} column, "
        f"got columns {dataset['test'].column_names}; set `dataset = asset_dataset` first"
    )

# Decoding samples (do_sample=True, temperature=1.5), so seed it for reproducible predictions.
SEED = 42
torch.manual_seed(SEED)

predictions = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for example in tqdm(dataset['test']):
    # Truncate to the tokenizer's model_max_length so overlong inputs cannot run past
    # BART's learned position embeddings.
    encoded = tokenizer(example[TEXT_COLUMN], return_tensors="pt", truncation=True)
    # move input_ids / attention_mask to the same device as the model
    encoded = {key: value.to(device) for key, value in encoded.items()}

    # generate prediction (note: 5 beams here vs. 4 in the SWiPE cell)
    output_ids = model.generate(**encoded, max_length=200, temperature=1.5, num_beams=5, num_return_sequences=1, do_sample=True)
    simplified_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    predictions.append(simplified_text)

100%|███████████████████████████████████████████| 359/359 [01:27<00:00,  4.12it/s]


# Collect results

In [6]:
import pandas as pd

# Pair each source text with its prediction. The source column is taken from the
# dataset itself: 'r_content' for SWiPE, 'original' for ASSET.
test_split = dataset['test']
text_column = 'r_content' if 'r_content' in test_split.column_names else 'original'

# `predictions` holds the output of whichever generation cell ran last; it must line up with `dataset`.
if len(predictions) != len(test_split):
    raise ValueError(
        f"{len(predictions)} predictions for {len(test_split)} test examples; "
        "re-run the generation cell that matches the selected `dataset`."
    )

df_results = pd.DataFrame({
    'text': test_split[text_column],
    'prediction': predictions,
})

df_results.head()

Unnamed: 0,text,prediction
0,The term jargon may have the following meaning...,Jargon is a word that can mean many different ...
1,"Russian (Russkij yazyk, Русский язык) is the m...","Russian (Russkij yazyk, Русский язык) is the m..."
2,"Great Britain, also called Britain, is an isla...","Great Britain, also called Britain, is an isla..."
3,"Transport, or transportation (as it is called ...","Transport, or transportation (as it is called ..."
4,Stockholm (help·info) (IPA: ['stɔkhɔlm]; UN/LO...,Stockholm is the capital city of Sweden. It is...


# Save results

In [7]:
from pathlib import Path

# NOTE(review): the filename encodes the model / tokenizer / decoding / dataset choice by
# hand — update it whenever the selection or generation cells change.
results_path = Path("../data/gen_predictions/predictions_bart-large-swipe-bart-tokenizer-fast-adjust-num-beams4-no-use-fast!_swipe-clean-test.csv")
# DataFrame.to_csv does not create missing directories, so make sure the folder exists.
results_path.parent.mkdir(parents=True, exist_ok=True)
df_results.to_csv(results_path, index=False)