<a href="https://colab.research.google.com/github/krishnarevi/TSAI_END2.0_Session14/blob/main/part3_session14_BART.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install simpletransformers


We will use the Quora duplicate-questions dataset, in which pairs of questions are labeled as duplicates or not, as training data for our BART paraphrasing model.

In [None]:
!pip install wandb



In [None]:
!wget http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv -P data

--2021-08-08 14:20:11--  http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv
Resolving qim.fs.quoracdn.net (qim.fs.quoracdn.net)... 151.101.1.2, 151.101.65.2, 151.101.129.2, ...
Connecting to qim.fs.quoracdn.net (qim.fs.quoracdn.net)|151.101.1.2|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58176133 (55M) [text/tab-separated-values]
Saving to: ‘data/quora_duplicate_questions.tsv’


2021-08-08 14:20:13 (96.8 MB/s) - ‘data/quora_duplicate_questions.tsv’ saved [58176133/58176133]



We also have a couple of helper functions, one to load data, and one to clean unnecessary spaces in the training data. 

In [None]:
import warnings

import pandas as pd


def load_data(
    file_path, input_text_column, target_text_column, label_column, keep_label=1
):
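    # Skip malformed rows. pandas >= 1.3 uses on_bad_lines; older releases
    # used the now-removed error_bad_lines=False for the same purpose.
    df = pd.read_csv(file_path, sep="\t", on_bad_lines="skip")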
    df = df.loc[df[label_column] == keep_label]
    df = df.rename(
        columns={input_text_column: "input_text", target_text_column: "target_text"}
    )
    df = df[["input_text", "target_text"]]
    df["prefix"] = "paraphrase"

    return df


def clean_unnecessary_spaces(out_string):
    if not isinstance(out_string, str):
        warnings.warn(f">>> {out_string} <<< is not a string.")
        out_string = str(out_string)
    out_string = (
        out_string.replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ,", ",")
        .replace(" ' ", "'")
        .replace(" n't", "n't")
        .replace(" 'm", "'m")
        .replace(" 's", "'s")
        .replace(" 've", "'ve")
        .replace(" 're", "'re")
    )
    return out_string
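
To see what the cleaning helper does, here is a small sketch that applies the same replacements, in the same order, to a few tokenizer-style strings. It is a condensed stand-in for `clean_unnecessary_spaces` (it omits the non-string warning), and the sample sentences are made up for illustration.

```python
# Same (old, new) pairs, in the same order, as the helper above.
REPLACEMENTS = [
    (" .", "."), (" ?", "?"), (" !", "!"), (" ,", ","), (" ' ", "'"),
    (" n't", "n't"), (" 'm", "'m"), (" 's", "'s"), (" 've", "'ve"), (" 're", "'re"),
]


def clean_spaces(text):
    """Condensed, string-only equivalent of clean_unnecessary_spaces."""
    for old, new in REPLACEMENTS:
        text = text.replace(old, new)
    return text


# Hypothetical inputs with detached punctuation and contractions.
samples = [
    "What is BART ?",
    "I do n't know , do you ?",
    "It 's what they 've said .",
]
cleaned = [clean_spaces(s) for s in samples]
for before, after in zip(samples, cleaned):
    print(f"{before!r} -> {after!r}")
```

Only spacing around punctuation and contractions changes; the words themselves are left untouched.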

Now we import the remaining modules and set up logging.

In [None]:
import os
from datetime import datetime
import logging

import pandas as pd
from sklearn.model_selection import train_test_split
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs



logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.ERROR)

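Next, we load the dataset, keep only the question pairs marked as duplicates, and split them into training and evaluation sets.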

In [None]:
# Quora Data

# The Quora Dataset is not separated into train/test, so we do it manually the first time.
df = load_data(
    "data/quora_duplicate_questions.tsv", "question1", "question2", "is_duplicate"
)
q_train, q_test = train_test_split(df)

q_train.to_csv("data/quora_train.tsv", sep="\t")
q_test.to_csv("data/quora_test.tsv", sep="\t")

INFO:numexpr.utils:NumExpr defaulting to 4 threads.
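
Because `train_test_split(df)` is called without a `random_state`, each run produces a different split, so the saved TSV files will differ between runs. Below is a minimal sketch of a reproducible split, using a toy list in place of the Quora DataFrame. The seed value 42 is an arbitrary choice for illustration, not something taken from this notebook.

```python
from sklearn.model_selection import train_test_split

# Toy stand-in for the rows of the Quora DataFrame.
pairs = [f"question {i}" for i in range(8)]

# With no size arguments, train_test_split holds out 25% of the rows.
# Fixing random_state makes the shuffle repeatable.
train_a, test_a = train_test_split(pairs, random_state=42)
train_b, test_b = train_test_split(pairs, random_state=42)

print(len(train_a), len(test_a))
print(train_a == train_b and test_a == test_b)
```

The same default 75/25 ratio is what yields the 111947 training rows and 37316 evaluation rows reported in the cell outputs of this notebook.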


In [None]:

train_df = q_train
eval_df = q_test

train_df = train_df[["prefix", "input_text", "target_text"]]
eval_df = eval_df[["prefix", "input_text", "target_text"]]

train_df = train_df.dropna().reset_index(drop=True)
eval_df = eval_df.dropna().reset_index(drop=True)

train_df["input_text"] = train_df["input_text"].apply(clean_unnecessary_spaces)
train_df["target_text"] = train_df["target_text"].apply(clean_unnecessary_spaces)

eval_df["input_text"] = eval_df["input_text"].apply(clean_unnecessary_spaces)
eval_df["target_text"] = eval_df["target_text"].apply(clean_unnecessary_spaces)

print(train_df.head())
len(train_df)

       prefix  ...                                        target_text
0  paraphrase  ...  Which is the best site for free sex chats with...
1  paraphrase  ...  What are the safety precautions on handling sh...
2  paraphrase  ...         What is the ideal age for getting married?
3  paraphrase  ...  What are the courses to do in the USA after th...
4  paraphrase  ...  Why is Manaphy bipolar in Pokémon ranger and t...

[5 rows x 3 columns]


111947

Then, we set up the model and its hyperparameters. We start from the pre-trained facebook/bart-large checkpoint and fine-tune it on the Quora duplicate pairs.
Finally, we generate paraphrases for each sentence in the test data and write them to a file.

In [None]:
model_args = Seq2SeqArgs()
model_args.do_sample = True
model_args.eval_batch_size = 64
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 2500
model_args.evaluate_during_training_verbose = True
model_args.fp16 = False
model_args.learning_rate = 5e-5
model_args.max_length = 50
model_args.max_seq_length = 50
model_args.num_beams = 5
model_args.num_return_sequences = 3
model_args.num_train_epochs = 2
model_args.overwrite_output_dir = True
model_args.reprocess_input_data = True
model_args.save_eval_checkpoints = False
model_args.save_steps = -1
model_args.top_k = 50
model_args.top_p = 0.95
model_args.train_batch_size = 8
model_args.use_multiprocessing = False
model_args.wandb_project = "Paraphrasing with BART"



model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large",
    args=model_args,
)

model.train_model(train_df, eval_data=eval_df)

to_predict = [
    prefix + ": " + str(input_text)
    for prefix, input_text in zip(eval_df["prefix"].tolist(), eval_df["input_text"].tolist())
]
truth = eval_df["target_text"].tolist()

preds = model.predict(to_predict)

# Saving the predictions if needed
os.makedirs("predictions", exist_ok=True)

with open(f"predictions/predictions_{datetime.now():%Y%m%d_%H%M%S}.txt", "w") as f:
    for i, text in enumerate(eval_df["input_text"].tolist()):
        f.write(str(text) + "\n\n")

        f.write("Truth:\n")
        f.write(truth[i] + "\n\n")

        f.write("Prediction:\n")
        for pred in preds[i]:
            f.write(str(pred) + "\n")
        f.write(
            "________________________________________________________________________________\n"
        )

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/111947 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model: Training started


Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Running Epoch 0 of 2:   0%|          | 0/13994 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/37316 [00:00<?, ?it/s]

Exception in thread Thread-23:
Traceback (most recent call last):
  File "/usr/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.7/multiprocessing/pool.py", line 470, in _handle_results
    task = get()
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 251, in recv
    return _ForkingPickler.loads(buf.getbuffer())
  File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/reductions.py", line 294, in rebuild_storage_fd
    storage = cls._new_shared_fd(fd, size)
RuntimeError: unable to mmap 400 bytes from file <filename not specified>: Cannot allocate memory (12)
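
The step counts in the progress bars above can be checked from the dataset and batch sizes. This is a small worked calculation, assuming no gradient accumulation (one optimizer step per batch). It counts only evaluations triggered by `evaluate_during_training_steps`, not any additional evaluation the library may run at the end of an epoch.

```python
import math

# Values taken from the notebook's outputs and model_args.
train_examples = 111947
eval_examples = 37316
train_batch_size = 8
eval_batch_size = 64
num_train_epochs = 2
evaluate_every = 2500  # evaluate_during_training_steps

# One step per batch; the last, partially filled batch still counts.
steps_per_epoch = math.ceil(train_examples / train_batch_size)
total_steps = steps_per_epoch * num_train_epochs

# Step-triggered evaluations over the whole run.
step_evaluations = total_steps // evaluate_every

# Batches needed for one pass over the evaluation set.
eval_batches = math.ceil(eval_examples / eval_batch_size)

print(steps_per_epoch, total_steps, step_evaluations, eval_batches)
```

The computed `steps_per_epoch` matches the `0/13994` shown in the "Running Epoch 0 of 2" progress bar.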



You can use the script below to test the model on any sentence.

In [None]:
import logging

from simpletransformers.seq2seq import Seq2SeqModel


logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.ERROR)

model = Seq2SeqModel(
    encoder_decoder_type="bart", encoder_decoder_name="outputs"
)


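while True:
    original = input("Enter text to paraphrase (blank line to quit): ")
    if not original.strip():
        break  # Stop on empty input instead of looping forever.
    to_predict = [original]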

    preds = model.predict(to_predict)

    print("---------------------------------------------------------")
    print(original)

    print()
    print("Predictions >>>")
    for pred in preds[0]:
        print(pred)

    print("---------------------------------------------------------")
    print()
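
Each input yields several different candidates because the training configuration enables sampling (`do_sample = True`) with `top_k = 50`, `top_p = 0.95`, and `num_return_sequences = 3`. The sketch below illustrates the general idea of top-k followed by top-p (nucleus) filtering on a toy next-token distribution. It is a simplified illustration, not the library's implementation; the function name, the toy vocabulary, and the smaller `top_k`/`top_p` values are all hypothetical.

```python
import random


def top_k_top_p_filter(probs, top_k, top_p):
    """Keep the top_k most likely tokens, then the smallest prefix of them
    whose renormalized probability mass reaches top_p."""
    # Top-k: keep only the k most probable tokens and renormalize.
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    total = sum(p for _, p in ranked)
    ranked = [(token, p / total) for token, p in ranked]

    # Top-p: walk down the ranking until the cumulative mass reaches top_p.
    kept, cumulative = [], 0.0
    for token, p in ranked:
        kept.append((token, p))
        cumulative += p
        if cumulative >= top_p:
            break

    # Renormalize the surviving tokens so they form a distribution again.
    total = sum(p for _, p in kept)
    return {token: p / total for token, p in kept}


# Hypothetical next-token probabilities.
toy_probs = {"how": 0.5, "what": 0.3, "why": 0.15, "when": 0.04, "who": 0.01}
filtered = top_k_top_p_filter(toy_probs, top_k=4, top_p=0.9)

# Sample the next token from the filtered distribution.
rng = random.Random(0)
next_token = rng.choices(list(filtered), weights=list(filtered.values()))[0]
print(filtered)
print(next_token)
```

Low-probability tokens are dropped before sampling, which keeps outputs varied while avoiding very unlikely continuations.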