<a href="https://colab.research.google.com/github/jeanlucjackson/w266_final_project/blob/main/code/inference/BART_pt_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inference: BART trained on SQuAD

In [None]:
# Run configuration: set these constants for each model and validation dataset combination.

# Folder name of the fine-tuned model; also becomes part of the prediction file name
model_name = "bart_base_pt.squad"
# Validation set to run inference on; selects the input CSV and is part of the prediction file name
validation_dataset_name = "squad"

# When False, the predictions CSV is not written
save_predictions = True
save_mode = 'a' # File mode handed to DataFrame.to_csv: w for write, a for append

max_length = 512 # 1024 for long model and 512 otherwise
batch_size = 150 # 150 is the norm, but dial back when needed

# NOTE(review): in the visible cells, batch_size, start_sample and end_sample are only
# referenced by the commented-out batched loop under "Unused code" — confirm before relying on them.
start_sample = 10000  # If None, then 0 will be used
end_sample = None # If None, then the end of the set will be used

### Generate predictions

In [None]:
!pip install -q transformers

In [None]:
!pip install -q sentencepiece

In [None]:
# Mount Google Drive at /content/drive (Colab only); the project folders used by
# this notebook live below /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import numpy as np
import pandas as pd

import torch
from transformers import BartTokenizer, BartForConditionalGeneration

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from google.colab import data_table
data_table.enable_dataframe_formatter()

In [None]:
# Paths on the mounted Drive, plus the file names derived from the run configuration

project_root = "/content/drive/MyDrive/w266 NLP Final Project/"
dataset_root = f"{project_root}Data/"
model_root = f"{project_root}Models/"
prediction_folder = f"{project_root}Predictions/"

# Checkpoint name later handed to BartTokenizer.from_pretrained
tokenizer = "facebook/bart-base"

model_folder = f"{model_root}{model_name}"

# SQuAD pairs sit in the "squad.hf" folder; any other dataset uses its own name as the folder
if validation_dataset_name == "squad":
  validation_data_file = f"{dataset_root}squad.hf/bart_valid_pairs.csv"
else:
  validation_data_file = f"{dataset_root}{validation_dataset_name}/bart_valid_pairs.csv"

prediction_file = f"{prediction_folder}predictions.{model_name}.{validation_dataset_name}.csv"

In [None]:
# Read the validation pairs and preview the first two input/target rows
validation_df = pd.read_csv(validation_data_file)
validation_df.loc[:, ['orig', 'target']].head(2)

Unnamed: 0,orig,target
0,four </s> Prince Albert appears within the mai...,How many levels of galleries do the façades su...
1,"ink </s> When some species, including Bathycte...",What are the secretions commonly called?


In [None]:
# Number of validation rows
len(validation_df)

10570

In [None]:
# Load the tokenizer by checkpoint name and the fine-tuned model from its folder,
# then place the model on the first GPU.
# .to() hands back the model itself, so ending on an assignment keeps the cell output empty.

bart_tokenizer = BartTokenizer.from_pretrained(tokenizer)
bart_model = BartForConditionalGeneration.from_pretrained(model_folder).to(torch.device('cuda:0'))

In [None]:
# Generate one prediction per validation row, in row order.
#
# Inputs are truncated to `max_length` tokens (set in the configuration cell): without
# truncation an over-long passage can exceed the model's position limit and abort the run.
# The tokenizer's attention mask is forwarded to generate(), and tensors are moved to the
# model's own device instead of assuming CUDA.
predictions = []
for input_text in validation_df['orig']:
  inputs = bart_tokenizer(input_text, return_tensors='pt',
                          max_length=max_length, truncation=True).to(bart_model.device)
  output_ids = bart_model.generate(inputs['input_ids'],
                                   attention_mask=inputs['attention_mask'])
  # One row of token ids per returned sequence; decoded rows are concatenated as before
  decoded = bart_tokenizer.batch_decode(output_ids, skip_special_tokens=True,
                                        clean_up_tokenization_spaces=False)
  predictions.append("".join(decoded))

# One prediction per row, so the list lines up with the frame
validation_df['prediction'] = predictions



In [None]:
# Each 'orig' entry has the form "<answer> </s> <context>".
# partition() cuts at the FIRST separator only, so a context that itself contains '</s>'
# stays whole, and a row without a separator yields an empty context instead of an IndexError.
# Surrounding whitespace is kept exactly as in the input.
answer_context_parts = [text.partition('</s>') for text in validation_df['orig']]
validation_df['context'] = [context for _, _, context in answer_context_parts]
validation_df['answer'] = [answer for answer, _, _ in answer_context_parts]

In [None]:
# Keep the four reported columns, in reading order
df = validation_df.loc[:, ['context', 'answer', 'target', 'prediction']]

# Preview the first ten rows
df.head(10)

Unnamed: 0,context,answer,target,prediction
0,Prince Albert appears within the main arch ab...,four,How many levels of galleries do the façades su...,How many levels of galleries are there?
1,"When some species, including Bathyctena chuni...",ink,What are the secretions commonly called?,What are secretions?
2,The Grainger Market replaced an earlier marke...,1835,When did Newcastle's first indoor market open?,When was the Grainger Market opened?
3,Bills can be introduced to Parliament in a nu...,Bills,What may be presented to Parliament in various...,What can be introduced to Parliament in a numb...
4,Jacksonville is in the First Coast region of ...,the Timucua,"Prior to the arrival of the French, the area n...",Who originally inhabited Jacksonville?
5,"In addition to the Riemann hypothesis, many m...",1912,When did Landau propose his four conjectural p...,When were Landau's problems solved?
6,"In Marxian analysis, capitalist firms increas...",stagnant,What type of wages does mechanization and auto...,What type of wages did the substitution of cap...
7,The final major evolution of the steam engine...,90,What percentage of electrical power in the Uni...,What percentage of electric power is produced ...
8,"In 1968, ABC took advantage of new FCC owners...",1985,When was the ABC Pictures division eventually ...,When was ABC Motion Pictures dissolved?
9,The 2007 Lisbon Treaty explicitly recognised ...,the Charter of Fundamental Rights of the Europ...,What charter has become an important aspect of...,What document has become an integral part of E...


### Save predictions

In [None]:
if save_predictions:
  # When appending to a file that already has content, skip the header row; otherwise
  # every appended run would insert another header line into the middle of the CSV.
  appending_to_existing = (
      'a' in save_mode
      and os.path.exists(prediction_file)
      and os.path.getsize(prediction_file) > 0
  )
  df.to_csv(prediction_file, mode=save_mode, header=not appending_to_existing)

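### Unused code

The cells below are commented out and do not run; they hold a batched variant of the prediction loop.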

In [None]:
# predictions = []

# if start_sample is None:
#   start_sample = 0

# if end_sample is None:
#   end_sample = validation_df.shape[0]

# print(f"Generating predictions from {start_sample} to {end_sample}:")
# for start in range (start_sample, end_sample, batch_size):
#   to = min([end_sample, start + batch_size])
#   inputs = bart_tokenizer(validation_df['orig'][start:to].to_list(), return_tensors='pt', max_length=max_length, truncation=True, padding=True)
#   output_ids = bart_model.generate(inputs['input_ids'].cuda(), max_length=max_length)
#   prediction_batch = bart_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
#   predictions.extend(prediction_batch)
#   print (f"{to} ", end="")
#   if to%1000 == 0: print()
# print("Predictions generated.")

Generating predictions from 10000 to 10570:


In [None]:
# df=pd.DataFrame()
# df['context'] = [str.split('context: ')[1] for str in validation_df['orig'][start_sample:end_sample]]
# df['answer'] =  [str.split('context: ')[0][26: ] for str in validation_df['orig'][start_sample:end_sample]]
# df['target'] = validation_df['target']
# df['prediction'] = predictions