<a href="https://colab.research.google.com/github/jeanlucjackson/w266_final_project/blob/main/code/Evaluation/probe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
!pip install -q transformers

[K     |████████████████████████████████| 5.5 MB 5.6 MB/s 
[K     |████████████████████████████████| 182 kB 37.9 MB/s 
[K     |████████████████████████████████| 7.6 MB 24.6 MB/s 
[?25h

In [23]:
!pip install -q sentencepiece

In [52]:
import os
import numpy as np
import pandas as pd

import torch

from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import T5Tokenizer, T5ForConditionalGeneration

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)


In [34]:
# Mount Google Drive when running on Colab. On a local machine google.colab is
# not installed, so the mount is skipped and files are read from local paths.
try:
  from google.colab import drive
except ImportError:
  drive = None
  print("google.colab not available; skipping Google Drive mount.")
else:
  drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Run configuration -- adjust for each model / validation-set combination.

# Which fine-tuned model to evaluate, and on which validation sets.
model_name = "bart_base_pt_long.nq"
validation_dataset_names = ["nq", "quac", "squad", "triviaqa"]

# Output handling.
save_predictions = True
save_mode = 'w'  # 'w' overwrites the prediction file, 'a' appends to it

# Tokenization and batching.
max_length = 1024  # 1024 for the long model, 512 for the others
batch_size = 50

# Optional sample window: None means "start at 0" / "run to the end of the set".
start_sample = None
end_sample = None

### Generate Predictions for Each Validation Dataset

This notebook assumes a fine-tuned BART PyTorch model (the `Probe` class below can also load T5 models).

Setting the constants in the configuration cell above should be all that is necessary to run the validation sets.

### Probe Class

In [53]:
class Probe:
  """Lazily load and cache fine-tuned BART/T5 models and their tokenizers.

  Fine-tuned models are loaded from ``model_root`` under the folder name
  ``{base_model}_base_pt_long.{training_dataset}`` and cached per
  ``(base_model, training_dataset)`` pair; tokenizers are cached per
  ``base_model``.
  """

  project_root = "/content/drive/MyDrive/w266 NLP Final Project/"
  model_root = project_root + "Models/"

  def __init__(self):
    self.models = {}      # (base_model, training_dataset) -> loaded model
    self.tokenizers = {}  # base_model -> loaded tokenizer

  def predict(self, context, question, 
              base_model='bart', training_dataset='amalgam', 
              num_beams=1, early_stopping=False, no_repeat_ngram_size=0, 
              maximum_input_length = 1024, maximum_target_length = 50):
    """Generate the model's output text for a single (context, question) pair.

    Args:
      context: passage text, placed after the ``</s>`` separator.
      question: text placed before the ``</s>`` separator.
      base_model: 'bart' or 'T5'; selects the tokenizer and model class.
      training_dataset: suffix of the fine-tuned model folder to load.
      num_beams, early_stopping, no_repeat_ngram_size: forwarded to
        ``model.generate``.
      maximum_input_length: token limit for the input; longer inputs are
        truncated.
      maximum_target_length: ``max_length`` passed to ``model.generate``.

    Returns:
      The decoded generated text with special tokens removed.

    Raises:
      ValueError: if ``base_model`` is not 'bart' or 'T5'.
    """
    tokenizer = self.retrieve_tokenizer(base_model)
    model = self.retrieve_model(base_model, training_dataset)

    # NOTE(review): the validation pairs in this notebook are laid out as
    # "<first segment></s><context>" (see the 'orig' column used by the
    # prediction loop); `question` is placed in the first slot here -- confirm
    # this matches the input format the models were fine-tuned on.
    source = f"{question}</s>{context}"
    inputs = tokenizer([source], return_tensors='pt',
                       max_length=maximum_input_length, truncation=True)

    # Keep the inputs on the same device the model was moved to.
    device = model.device
    output_ids = model.generate(
        inputs['input_ids'].to(device),
        attention_mask=inputs['attention_mask'].to(device),
        max_length=maximum_target_length,
        num_beams=num_beams,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=early_stopping)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

  def retrieve_tokenizer(self, base_model='bart'):
    """Return the pretrained tokenizer for ``base_model``, loading it once.

    Raises:
      ValueError: if ``base_model`` is not 'bart' or 'T5'.
    """
    tokenizer = self.tokenizers.get(base_model)
    if tokenizer is not None:
      return tokenizer

    if base_model=='bart':
      tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    elif base_model=="T5":
      tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-base")
    else:
      raise ValueError ("invalid base model")

    self.tokenizers[base_model] = tokenizer
    return tokenizer

  def retrieve_model(self, base_model='bart', training_dataset='amalgam'):
    """Return the fine-tuned model for the pair, loading and caching it once.

    The model is moved to the first CUDA device when one is available and to
    the CPU otherwise.

    Raises:
      ValueError: if ``base_model`` is not 'bart' or 'T5'.
    """
    model_tuple = (base_model, training_dataset)
    model = self.models.get(model_tuple)
    if model is not None:
      return model

    model_dir = f"{self.model_root}{base_model}_base_pt_long.{training_dataset}"

    if base_model=='bart':
      model = BartForConditionalGeneration.from_pretrained(model_dir)
    elif base_model=='T5':
      model = T5ForConditionalGeneration.from_pretrained(model_dir)
    else:
      raise ValueError ("invalid base model")

    # Fall back to the CPU when CUDA is unavailable instead of failing.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    self.models[model_tuple] = model
    return model


In [54]:
# Single shared Probe; tokenizers and models are loaded on first request and cached.
p = Probe()

In [55]:
# Load (and cache) the default BART model fine-tuned on the amalgam set; the
# trailing semicolon keeps the very long module repr out of the cell output.
p.retrieve_model(base_model='bart', training_dataset='amalgam');

RuntimeError: ignored

### Generate Predictions

In [None]:
# Some important file locations and constants.
# The project root can be overridden with the W266_PROJECT_ROOT environment
# variable; otherwise the Google Drive copy is used when it is mounted (Colab)
# and the local checkout is used as a fallback.
colab_project_root = "/content/drive/MyDrive/w266 NLP Final Project/"
local_project_root = "/home/localadmin/Documents/w266_NLP_Final_Project/"

if os.environ.get("W266_PROJECT_ROOT"):
  # os.path.join(..., "") guarantees the trailing separator the paths below rely on.
  project_root = os.path.join(os.environ["W266_PROJECT_ROOT"], "")
elif os.path.isdir(colab_project_root):
  project_root = colab_project_root
else:
  project_root = local_project_root

dataset_root = project_root + "Data/"
model_root = project_root + "Models/"
prediction_folder = project_root + "Predictions/checkpoint/"

In [None]:
# Get the model and tokenizer.
# NOTE(review): this cell previously referenced undefined names (`tokenizer`,
# `model_folder`). The stock facebook/bart-base tokenizer is assumed (the same
# one the Probe class uses for BART), and the fine-tuned weights are assumed to
# live in {model_root}{model_name}, matching the Probe folder naming -- confirm.
tokenizer_checkpoint = "facebook/bart-base"
model_folder = f"{model_root}{model_name}"

bart_tokenizer = BartTokenizer.from_pretrained(tokenizer_checkpoint)
bart_model = BartForConditionalGeneration.from_pretrained(model_folder)
# Use the first GPU when available, otherwise fall back to the CPU.
# Assigning the result also keeps the long module repr out of the cell output.
bart_model = bart_model.to(torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))

In [None]:
# Beam-search settings used for every validation set (the ".beams" suffix of
# the prediction file name refers to these).
GEN_MAX_LENGTH = 50
GEN_NUM_BEAMS = 5
GEN_NO_REPEAT_NGRAM_SIZE = 2

for dataset_name in validation_dataset_names:
  # SQuAD is stored in a differently named folder from the other datasets.
  dataset_folder = "squad.hf" if dataset_name == "squad" else dataset_name
  validation_data_file = f"{dataset_root}{dataset_folder}/bart_valid_pairs.csv"
  print(validation_data_file)
  validation_df = pd.read_csv(validation_data_file)
  prediction_file = f"{prediction_folder}predictions.{model_name}.{dataset_name}.beams.csv" 

  # Resolve the configured sample window per dataset without overwriting the
  # global start_sample / end_sample settings, so one dataset's length never
  # leaks into the next and the config cell is actually honored.
  n_rows = validation_df.shape[0]
  first = 0 if start_sample is None else start_sample
  last = n_rows if end_sample is None else min(end_sample, n_rows)
  sources = validation_df['orig'].iloc[first:last].to_list()

  predictions = []
  print(f"Generating predictions using {dataset_name} from {first} to {last}:")
  for offset in range(0, len(sources), batch_size):
    batch = sources[offset:offset + batch_size]
    to = first + offset + len(batch)
    inputs = bart_tokenizer(batch, return_tensors='pt', max_length=max_length, truncation=True, padding=True)
    # Pass the attention mask explicitly so padded positions are ignored.
    output_ids = bart_model.generate(
        inputs['input_ids'].to(bart_model.device),
        attention_mask=inputs['attention_mask'].to(bart_model.device),
        max_length=GEN_MAX_LENGTH,
        num_beams=GEN_NUM_BEAMS,
        no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM_SIZE,
        early_stopping=True)
    predictions.extend(bart_tokenizer.batch_decode(output_ids, skip_special_tokens=True))
    print(f"{to} ", end="")
    if to % 1000 == 0: print()
  print("\nPredictions generated.")

  # Each 'orig' entry is "<answer></s><context>". Targets are taken from the
  # same row window as the inputs (positionally) so they line up with the
  # predictions even when the window does not start at row 0.
  segments = [text.split('</s>') for text in sources]
  df = pd.DataFrame({
      'context': [parts[1] for parts in segments],
      'answer': [parts[0] for parts in segments],
      'target': validation_df['target'].iloc[first:last].to_list(),
      'prediction': predictions,
  })

  if save_predictions:
    # When appending to an existing file, do not repeat the CSV header.
    write_header = save_mode != 'a' or not os.path.exists(prediction_file)
    df.to_csv(prediction_file, mode=save_mode, header=write_header)
    print(f"Write: {prediction_file}")

/home/localadmin/Documents/w266_NLP_Final_Project/Data/nq/bart_valid_pairs.csv
Generating predictions using nq from 0 to 2356:
50 100 150 200 250 300 350 400 450 500 550 600 650 700 750 800 850 900 950 1000 
1050 1100 1150 1200 1250 1300 1350 1400 1450 1500 1550 1600 1650 1700 1750 1800 1850 1900 1950 2000 
2050 2100 2150 2200 2250 2300 2350 2356 
Predictions generated.
Write: /home/localadmin/Documents/w266_NLP_Final_Project/Predictions/checkpoint/predictions.bart_base_pt_long.nq.nq.beams.csv
/home/localadmin/Documents/w266_NLP_Final_Project/Data/quac/bart_valid_pairs.csv
Generating predictions using quac from 0 to 5868:
50 100 150 200 250 300 350 400 450 500 550 600 650 700 750 800 850 900 950 1000 
1050 1100 1150 1200 1250 1300 1350 1400 1450 1500 1550 1600 1650 1700 1750 1800 1850 1900 1950 2000 
2050 2100 2150 2200 2250 2300 2350 2400 2450 2500 2550 2600 2650 2700 2750 2800 2850 2900 2950 3000 
3050 3100 3150 3200 3250 3300 3350 3400 3450 3500 3550 3600 3650 3700 3750 3800 3850 39