<a href="https://colab.research.google.com/github/jeanlucjackson/w266_final_project/blob/main/code/BB/bb_bart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BART for Question Generation

In [80]:
import numpy as np
import pandas as pd
import os
import re

import json

# Make longer output readable without scrolling
from pprint import pprint

# Stop warning messages from showing
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import torch  # The BART classes imported below are the PyTorch implementations
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer



In [2]:
!pip install -q sentencepiece

[?25l[K     |▎                               | 10 kB 14.8 MB/s eta 0:00:01[K     |▌                               | 20 kB 9.8 MB/s eta 0:00:01[K     |▊                               | 30 kB 12.6 MB/s eta 0:00:01[K     |█                               | 40 kB 7.6 MB/s eta 0:00:01[K     |█▎                              | 51 kB 7.8 MB/s eta 0:00:01[K     |█▌                              | 61 kB 9.1 MB/s eta 0:00:01[K     |█▉                              | 71 kB 8.6 MB/s eta 0:00:01[K     |██                              | 81 kB 7.4 MB/s eta 0:00:01[K     |██▎                             | 92 kB 8.2 MB/s eta 0:00:01[K     |██▋                             | 102 kB 8.9 MB/s eta 0:00:01[K     |██▉                             | 112 kB 8.9 MB/s eta 0:00:01[K     |███                             | 122 kB 8.9 MB/s eta 0:00:01[K     |███▍                            | 133 kB 8.9 MB/s eta 0:00:01[K     |███▋                            | 143 kB 8.9 MB/s eta 0:00:01[K   

In [3]:
!pip install -q transformers


In [4]:
!pip install -q datasets


In [None]:
# !pip install -q evaluate
# import evaluate

In [6]:
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


## Data

In [19]:
dataset_root = "/content/drive/MyDrive/w266 NLP Final Project/Data/"
dataset_name = "squad"
dataset_folder = dataset_root+dataset_name+".hf"
# These names are used by the save, load, and data generator cells below
train_file = dataset_folder + '/train_pairs.csv'
valid_file = dataset_folder + '/valid_pairs.csv'

Load the data

In [20]:
from datasets import load_from_disk

dataset = load_from_disk(dataset_folder)

In [105]:
# Training data
count=100
training_data = dataset['train'].shuffle(seed=1962).select(range(count))
training_answers = [answer['text'][0] for answer in training_data['answers']]
training_context = training_data['context']
training_questions = training_data['question']

# Validation data
count=10
validation_data = dataset['validation'].shuffle(seed=1962).select(range(count))
validation_answers = [answer['text'][0] for answer in validation_data['answers']]
validation_context = validation_data['context']
validation_questions = validation_data['question']



Formatting input and output/target pairs

In [106]:
training_orig = [f"answer: {answer} context: {context}" for answer, context in zip (training_answers, training_context)]
training_target = training_questions
validation_orig = [f"answer: {answer} context: {context}" for answer, context in zip (validation_answers, validation_context)]
validation_target = validation_questions
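
The prompt format above can be wrapped in a small helper so that training and inference build their inputs the same way. This is only a sketch: `format_qg_input` is a hypothetical name that does not appear elsewhere in the notebook, and the example strings are illustrative rather than taken from SQuAD.

```python
def format_qg_input(answer: str, context: str) -> str:
    """Build the 'answer: ... context: ...' string used as model input."""
    return f"answer: {answer} context: {context}"


# Toy example (illustrative, not a SQuAD row)
answers = ["London"]
contexts = ["My name is Sarah and I live in London"]
inputs = [format_qg_input(a, c) for a, c in zip(answers, contexts)]
print(inputs[0])
```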

In [107]:
training_df = pd.DataFrame()
training_df['orig'] = training_orig
training_df['target'] = training_target
training_df

Unnamed: 0,orig,target
0,answer: biotech companies context: Prior to mo...,What type of businesses did Nickles want to at...
1,answer: Tytus Woyciechowski context: Four boar...,To whom did Chopin reveal in letters which par...
2,answer: the Endangered Species Committee conte...,"If a species may be harmed, who holds final sa..."
3,answer: China context: In Asian countries such...,What country has the dog as part of its 12 ani...
4,answer: 45 years context: Saint Athanasius of ...,How long did his episcopate last?
...,...,...
95,answer: objects context: Buddhist scriptures a...,Some schools venerate certain texts as religio...
96,answer: Sadat context: Following Anwar Sadat's...,Whose rise to the presidency of Egypt led to t...
97,answer: flogged context: Combining statements ...,How was Jesus tortured before he was crucified?
98,answer: Nick Fradiani context: Nick Fradiani w...,Who won American Idols fourteenth season?


In [108]:
validation_df = pd.DataFrame()
validation_df['orig'] = validation_orig
validation_df['target'] = validation_target
validation_df

Unnamed: 0,orig,target
0,answer: four context: Prince Albert appears wi...,How many levels of galleries do the façades su...
1,"answer: ink context: When some species, includ...",What are the secretions commonly called?
2,answer: 1835 context: The Grainger Market repl...,When did Newcastle's first indoor market open?
3,answer: Bills context: Bills can be introduced...,What may be presented to Parliament in various...
4,answer: the Timucua context: Jacksonville is i...,"Prior to the arrival of the French, the area n..."
5,answer: 1912 context: In addition to the Riema...,When did Landau propose his four conjectural p...
6,"answer: stagnant context: In Marxian analysis,...",What type of wages does mechanization and auto...
7,answer: 90 context: The final major evolution ...,What percentage of electrical power in the Uni...
8,"answer: 1985 context: In 1968, ABC took advant...",When was the ABC Pictures division eventually ...
9,answer: the Charter of Fundamental Rights of t...,What charter has become an important aspect of...


In [109]:
train_pairs = training_df.shape[0]
valid_pairs = validation_df.shape[0]

print(f"{train_pairs} training pairs")
print(f"{valid_pairs} validation pairs")

100 training pairs
10 validation pairs


In [110]:
# Save splits to separate csv files, to load only part at a time later
training_df.to_csv(train_file)
validation_df.to_csv(valid_file)

Load formatted inputs and outputs

In [111]:
training_df = pd.read_csv(train_file)
validation_df = pd.read_csv(valid_file)

In [112]:
training_df[:2]

Unnamed: 0.1,Unnamed: 0,orig,target
0,0,answer: biotech companies context: Prior to mo...,What type of businesses did Nickles want to at...
1,1,answer: Tytus Woyciechowski context: Four boar...,To whom did Chopin reveal in letters which par...


Preprocess data

In [113]:
def preprocess_data_pt(text_pair, tokenizer, max_length=1024):
    orig_text, target_text = text_pair
    orig_encoded = tokenizer.batch_encode_plus(
        [orig_text],
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )

    orig_input_ids = orig_encoded['input_ids'][0]
    orig_attention_mask = orig_encoded['attention_mask'][0]
    
    target_encoded = tokenizer.batch_encode_plus(
        [target_text],
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    
    label_ids = target_encoded['input_ids'][0]
    
    return {'input_ids': orig_input_ids,
            'attention_mask': orig_attention_mask,
            'labels': label_ids}
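
Because the target is padded to `max_length`, `labels` contains pad token ids, so the loss is also computed on padding positions. A common convention with Hugging Face seq2seq models is to set those positions to -100, which PyTorch's cross-entropy loss ignores by default. The sketch below shows the idea on plain Python lists. `mask_padding` is a hypothetical helper, and the default `pad_token_id=1` is an assumption about the BART vocabulary; check `bart_tokenizer.pad_token_id` before relying on it.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss


def mask_padding(label_ids, pad_token_id=1, ignore_index=IGNORE_INDEX):
    """Return a copy of label_ids with padding positions set to ignore_index."""
    return [ignore_index if tok == pad_token_id else tok for tok in label_ids]


# Illustrative ids: <s>=0, </s>=2, <pad>=1; the other ids are made up
labels = [0, 713, 16, 2, 1, 1, 1]
print(mask_padding(labels))  # [0, 713, 16, 2, -100, -100, -100]
```

On the tensor returned by the tokenizer, the equivalent would be `label_ids[label_ids == tokenizer.pad_token_id] = -100`. The run recorded in this notebook did not apply this masking.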

In [114]:
class QuestionDataGenerator(tf.keras.utils.Sequence):
    
    def __init__(self,
                 tokenizer,
                 n_examples,
                 data_filename,
                 max_length=1024,
                 shuffle=True):
        
        self.tokenizer = tokenizer
        self.n_examples = n_examples
        self.data_filename = data_filename
        self.max_length = max_length
        self.shuffle = shuffle
        
        # Initialize row order, call on_epoch_end to shuffle row indices
        self.row_order = np.arange(1, self.n_examples+1)
        self.on_epoch_end()
    
    def __len__(self):
        return self.n_examples
    
    def __getitem__(self, idx):
        row_to_load = self.row_order[idx]
        df = pd.read_csv(self.data_filename,
                         skiprows=range(1, row_to_load),
                         nrows=1)
        
        text_pairs = df[['orig', 'target']].values.astype(str)[0]
        
        batch_data = preprocess_data_pt(
            text_pairs,
            self.tokenizer,
            self.max_length
        )

        return batch_data
    
    def __call__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)
            
            if i == self.__len__()-1:
                self.on_epoch_end()
    
    def on_epoch_end(self):
        if self.shuffle:
            self.row_order = list(np.random.permutation(self.row_order))
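
`__getitem__` reads a single CSV row per call instead of holding the whole file in memory. The same idea can be sketched with only the standard library. `read_row` and the in-memory CSV are hypothetical stand-ins for the files on Drive, not part of the notebook.

```python
import csv
import io
from itertools import islice


def read_row(handle, row_number):
    """Return data row `row_number` (1-based, header excluded) as a dict."""
    reader = csv.DictReader(handle)
    # Skip the first row_number - 1 data rows, then take the next one
    return next(islice(reader, row_number - 1, None))


# Stand-in for train_pairs.csv (the first, unnamed column is the saved index)
csv_text = ",orig,target\n0,answer: a context: b,Q1?\n1,answer: c context: d,Q2?\n"
row = read_row(io.StringIO(csv_text), 2)
print(row["orig"], "->", row["target"])
```

Reading by row number keeps memory use flat, at the cost of rescanning the file from the top on every access.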

In [None]:
# Load the pretrained PyTorch model and its tokenizer

model_name = 'facebook/bart-base'
bart_tokenizer = BartTokenizer.from_pretrained(model_name)
bart_model = BartForConditionalGeneration.from_pretrained(model_name)

In [116]:
# Create the data iterators for train and validation data, pytorch version

max_length = 32

train_data_generator = QuestionDataGenerator(
    tokenizer=bart_tokenizer,
    n_examples=train_pairs,
    data_filename=train_file,
    max_length=max_length
)

valid_data_generator = QuestionDataGenerator(
    tokenizer=bart_tokenizer,
    n_examples=valid_pairs,
    data_filename=valid_file,
    max_length=max_length
)

In [117]:
# Specify batch size and other training arguments

batch_size = 8

# Modify this filepath to where you want to save the model after fine-tuning
dir_path = '/content/drive/MyDrive/w266 NLP Final Project/Models/BB BART/'
file_path = dir_path + 'bartbase-finetuned-squad'

args = Seq2SeqTrainingArguments(
    file_path,
    evaluation_strategy='epoch',
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [118]:
# Define the trainer, passing in the model, training args, and data generators

trainer = Seq2SeqTrainer(
    bart_model,
    args,
    train_dataset=train_data_generator,
    eval_dataset=valid_data_generator
)

In [119]:
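# Call train
# Note: resume_from_checkpoint=True expects a checkpoint-* folder to already
# exist in the output directory; omit the argument for a fresh run.

trainer.train(resume_from_checkpoint=True)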

***** Running training *****
  Num examples = 100
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 13


Epoch,Training Loss,Validation Loss
1,No log,7.628573


***** Running Evaluation *****
  Num examples = 10
  Batch size = 8


Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=13, training_loss=8.895165076622597, metrics={'train_runtime': 78.8822, 'train_samples_per_second': 1.268, 'train_steps_per_second': 0.165, 'total_flos': 1905426432000.0, 'train_loss': 8.895165076622597, 'epoch': 1.0})
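
The step count and throughput figures in the log follow from the batch size and runtime. A quick check, assuming a single device, no gradient accumulation, and that the last partial batch is kept (the variable names are only for this illustration):

```python
import math

num_examples = 100        # "Num examples" in the training log
batch_size = 8            # per-device batch size
num_epochs = 1
train_runtime = 78.8822   # seconds, from TrainOutput

steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs

print(total_steps)                                           # 13
print(round(num_examples * num_epochs / train_runtime, 3))   # 1.268
print(round(total_steps / train_runtime, 3))                 # 0.165
```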

Save model

In [120]:
dir_path = "/content/drive/MyDrive/w266 NLP Final Project/Models/bart_base_pt_squad/"
trainer.save_model(dir_path)

Saving model checkpoint to /content/drive/MyDrive/w266 NLP Final Project/Models/bart_base_pt_squad/
Configuration saved in /content/drive/MyDrive/w266 NLP Final Project/Models/bart_base_pt_squad/config.json
Model weights saved in /content/drive/MyDrive/w266 NLP Final Project/Models/bart_base_pt_squad/pytorch_model.bin


### Check to see if model worked

In [121]:
sample_df = validation_df[0:10].copy()


In [124]:
predictions = []
for input_text in sample_df['orig']:
  # Truncate to BART's 1024-token limit and keep tensors on the model's device
  inputs = bart_tokenizer(input_text, return_tensors='pt',
                          truncation=True, max_length=1024).to(bart_model.device)
  # Pass input_ids and attention_mask together
  output_ids = bart_model.generate(**inputs)
  prediction = "".join([bart_tokenizer.decode(out_ids, skip_special_tokens=True, 
                                            clean_up_tokenization_spaces=False) for out_ids in output_ids])
  predictions.append(prediction)

sample_df['prediction'] = predictions



In [125]:
sample_df

Unnamed: 0.1,Unnamed: 0,orig,target,prediction
0,0,answer: four context: Prince Albert appears wi...,How many levels of galleries do the façades su...,What was the design of the interior of the bui...
1,1,"answer: ink context: When some species, includ...",What are the secretions commonly called?,What is ctenophores' bioluminescence?
2,2,answer: 1835 context: The Grainger Market repl...,When did Newcastle's first indoor market open?,What was Grainger's first market?
3,3,answer: Bills context: Bills can be introduced...,What may be presented to Parliament in various...,What is a draft bill?
4,4,answer: the Timucua context: Jacksonville is i...,"Prior to the arrival of the French, the area n...",What is the population of Jacksonville?
5,5,answer: 1912 context: In addition to the Riema...,When did Landau propose his four conjectural p...,What is the Riemann conjecture?
6,6,"answer: stagnant context: In Marxian analysis,...",What type of wages does mechanization and auto...,What is the level of wage growth for the worki...
7,7,answer: 90 context: The final major evolution ...,What percentage of electrical power in the Uni...,What type of steam engine was the first to be ...
8,8,"answer: 1985 context: In 1968, ABC took advant...",When was the ABC Pictures division eventually ...,What was ABC's first major acquisition?
9,9,answer: the Charter of Fundamental Rights of t...,What charter has become an important aspect of...,What is the Charter of Fundamental Rights of t...
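
Reading the table by eye is enough for a smoke test, but a rough number can help when comparing runs. The sketch below computes a unigram-overlap F1 between a prediction and its target. This is an illustrative stand-in, not the metric used in this project, and `unigram_f1` is a hypothetical helper. The two pairs are the rows above whose text is shown in full.

```python
import re
from collections import Counter


def unigram_f1(prediction: str, reference: str) -> float:
    """F1 over lowercased word tokens shared by prediction and reference."""
    pred_tokens = re.findall(r"\w+", prediction.lower())
    ref_tokens = re.findall(r"\w+", reference.lower())
    # Counter intersection keeps the minimum count of each shared token
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


pairs = [
    ("What is ctenophores' bioluminescence?",
     "What are the secretions commonly called?"),
    ("What was Grainger's first market?",
     "When did Newcastle's first indoor market open?"),
]
for prediction, reference in pairs:
    print(f"{unigram_f1(prediction, reference):.2f}  {prediction}")
```

Word overlap says nothing about whether the generated question is actually answered by the given answer, so treat it as a sanity check only.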


## BART

Sources:


*  https://huggingface.co/facebook/bart-base
*  https://huggingface.co/transformers/v3.1.0/model_doc/bart.html#bartforquestionanswering
*  https://huggingface.co/valhalla/bart-large-finetuned-squadv1?context=My+name+is+Sarah+and+I+live+in+London&question=Where+do+I+live%3F
*  https://huggingface.co/docs/transformers/v4.23.1/en/model_doc/bart#transformers.BartForConditionalGeneration

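### Training

`facebook/bart-base` and `facebook/bart-large` are the pretrained checkpoints, trained with a denoising (text reconstruction) objective; they are not fine-tuned on a downstream task.

`facebook/bart-large-cnn` is the checkpoint fine-tuned on CNN/Daily Mail for summarization, and `facebook/bart-large-mnli` is the one fine-tuned on the MultiNLI (MNLI) dataset.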