<a href="https://colab.research.google.com/github/jeanlucjackson/w266_final_project/blob/main/code/inference/awesome_T5_pt_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Generate Predictions From An Awesome Validation Dataset

This notebook assumes a T5 PyTorch model.

Setting the constants in the next cell should be all that is needed to run a given model against a given validation set.

In [1]:
# Set these constants for each model and validation dataset combination

model_name = "T5_base_pt_long.quac"
validation_dataset_name = "triviaqa"
save_predictions = True

### Generate Predictions

In [2]:
!pip install -q transformers


In [3]:
!pip install -q sentencepiece


In [7]:
import os
import numpy as np
import pandas as pd

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

from google.colab import data_table
data_table.enable_dataframe_formatter()

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
# Some important file locations and constants

project_root = "/content/drive/MyDrive/w266 NLP Final Project/"
dataset_root = project_root + "Data/"
model_root = project_root + "Models/"
prediction_folder = project_root + "Predictions/"

tokenizer = "google/t5-v1_1-base"

model_folder = model_root + model_name

validation_data_file = f"{dataset_root}squad.hf/valid_pairs.csv"
if validation_dataset_name != "squad":
  validation_data_file = f"{dataset_root}{validation_dataset_name}/valid_pairs.csv"

prediction_file = f"{prediction_folder}predictions.{model_name}.{validation_dataset_name}.csv"

max_length = 512
batch_size = 125
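The cell above special-cases SQuAD, whose validation pairs live under `squad.hf/` rather than a folder named after the dataset. The following is a minimal sketch that folds the same path logic into one function so each (model, dataset) combination can be checked without touching Drive. `resolve_paths` is a hypothetical helper, not part of the project's code, and the folder layout is taken directly from the constants above.

```python
def resolve_paths(project_root: str, model_name: str, dataset_name: str) -> dict:
    """Build the model, validation-data, and prediction paths used by this notebook.

    SQuAD is stored under 'squad.hf/'; every other dataset uses its own name as the folder.
    """
    dataset_root = project_root + "Data/"
    dataset_folder = "squad.hf" if dataset_name == "squad" else dataset_name
    return {
        "model_folder": project_root + "Models/" + model_name,
        "validation_data_file": f"{dataset_root}{dataset_folder}/valid_pairs.csv",
        # One prediction file per (model, dataset) pair, so runs never overwrite each other
        "prediction_file": f"{project_root}Predictions/predictions.{model_name}.{dataset_name}.csv",
    }
```

Keying the prediction file name on both the model and the dataset is what lets the same notebook be rerun for every combination without clobbering earlier results.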

In [8]:
validation_df = pd.read_csv(validation_data_file)
validation_df[['orig', 'target']][:2]

|   | orig | target |
|---|------|--------|
| 0 | generate question: answer: one context: Goliat... | When David killed Goliath, how many of his fiv... |
| 1 | generate question: answer: Apaches context: Ge... | Of which tribe of Red Indians was Geronimo a c... |
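Each `orig` value packs the task prefix, the answer span, and the passage into a single T5 input string. The sketch below rebuilds that string from its parts; the exact layout is inferred from the two sample rows above (an assumption, since the preprocessing code is not shown here), and `build_prompt` is a hypothetical helper name.

```python
# Layout inferred from the sample rows: "generate question: answer: <answer> context: <context>"
PREFIX = "generate question: answer: "
CONTEXT_MARKER = " context: "


def build_prompt(answer: str, context: str) -> str:
    """Assemble one question-generation input for the T5 model."""
    return f"{PREFIX}{answer.strip()}{CONTEXT_MARKER}{context.strip()}"
```

The model presumably saw this same layout during fine-tuning, so any new inputs fed to it at inference time should follow it exactly.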


In [9]:
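# Load the tokenizer from the Hugging Face Hub and the fine-tuned model from Drive,
# then move the model to the GPU

t5_tokenizer = T5Tokenizer.from_pretrained(tokenizer)
t5_model = T5ForConditionalGeneration.from_pretrained(model_folder)
t5_model = t5_model.to(torch.device('cuda:0'))  # assigning suppresses the model summary in the cell output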


In [11]:
predictions = []
count = validation_df.shape[0]

print("Generating predictions:")
for start in range(0, count, batch_size):
  stop = min(count, start + batch_size)
  # Tokenize one batch of prompts, padding to the longest prompt in the batch, and move it to the GPU
  inputs = t5_tokenizer(validation_df['orig'][start:stop].to_list(), return_tensors='pt',
                        max_length=max_length, truncation=True, padding=True).to('cuda:0')
  # Pass the attention mask so the encoder ignores padded positions
  output_ids = t5_model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                 max_length=max_length)
  prediction_batch = t5_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
  predictions.extend(prediction_batch)
  print(f"{stop} ", end="")
  if stop % 1000 == 0: print()
print("Predictions generated")

Generating predictions:


RuntimeError: ignored
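Colab hid the traceback, so the cause of this `RuntimeError` is not recorded. If it was a CUDA out-of-memory error, which is plausible with `batch_size = 125` and inputs of up to 512 tokens but is only an assumption here, one option is to shrink the batch and retry instead of restarting the whole run. The sketch below shows that back-off pattern; `generate_with_backoff` is a hypothetical helper, and `generate_batch` stands in for the tokenize, `generate`, and `batch_decode` steps of the loop.

```python
from typing import Callable, List, Sequence


def generate_with_backoff(
    texts: Sequence[str],
    generate_batch: Callable[[List[str]], List[str]],
    batch_size: int,
    min_batch_size: int = 1,
) -> List[str]:
    """Apply generate_batch to consecutive slices of texts, in order.

    If a slice raises a RuntimeError whose message mentions "out of memory",
    halve batch_size and retry the same slice; any other error is re-raised.
    """
    if min(batch_size, min_batch_size) < 1:
        raise ValueError("batch sizes must be at least 1")
    outputs: List[str] = []
    start = 0
    while start < len(texts):
        stop = min(len(texts), start + batch_size)
        try:
            batch_outputs = generate_batch(list(texts[start:stop]))
        except RuntimeError as err:
            # Give up on non-OOM errors, or when the batch cannot shrink any further
            if "out of memory" not in str(err) or batch_size <= min_batch_size:
                raise
            batch_size = max(min_batch_size, batch_size // 2)
            continue  # retry from the same start index with the smaller batch
        outputs.extend(batch_outputs)
        start = stop
    return outputs
```

Because a failed slice is retried from the same `start`, the returned list stays aligned row-for-row with `validation_df`, which the next cell relies on when it assigns `df['prediction']`. Inside the real `generate_batch` wrapper, calling `torch.cuda.empty_cache()` before re-raising may also help free cached GPU memory for the retry.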

In [None]:
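# Split each 'orig' prompt ("generate question: answer: <answer> context: <context>")
# back into its answer and context so they can be compared with the target and prediction
answer_prefix = "generate question: answer: "
parts = [text.split('context: ', 1) for text in validation_df['orig']]

df = pd.DataFrame()
df['context'] = [context.strip() for _, context in parts]
df['answer'] = [head[len(answer_prefix):].strip() for head, _ in parts]
df['target'] = validation_df['target']
df['prediction'] = predictions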
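The cell above recovers the answer by slicing off a fixed-length prefix, which assumes every `orig` string starts with exactly `generate question: answer: `. A stricter alternative is a regular expression that rejects rows not matching that layout instead of silently producing a wrong answer field. This is a sketch under the layout inferred from the sample rows; `split_prompt` and `PROMPT_RE` are hypothetical names.

```python
import re
from typing import Tuple

# Non-greedy answer group: split at the first " context: ", so the passage may itself contain "context: "
PROMPT_RE = re.compile(
    r"^generate question: answer: (?P<answer>.*?) context: (?P<context>.*)$",
    re.DOTALL,
)


def split_prompt(orig: str) -> Tuple[str, str]:
    """Return (answer, context) for one prompt; raise ValueError on an unexpected format."""
    match = PROMPT_RE.match(orig)
    if match is None:
        raise ValueError(f"unexpected prompt format: {orig[:60]!r}")
    return match.group("answer").strip(), match.group("context").strip()
```

For a whole column, `validation_df['orig'].str.extract(PROMPT_RE.pattern, flags=re.DOTALL)` applies the same pattern and returns a DataFrame with `answer` and `context` columns named after the groups; non-matching rows come back as `NaN` rather than raising.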

In [None]:
df[:10]

|   | context | answer | target | prediction |
|---|---------|--------|--------|------------|
| 0 | Following the unification of the Hejaz and Nej... | al-Mamlakah al-ʻArabīyah as-Suʻūdīyah | what was the real name of saudi arabia | What was the name of the new state? |
| 1 | This list contains the top ten pictures with t... | the name announcement of Kylie Jenner's first... | whats the most liked picture on instagram 2018 | What is the name of her first child? |
| 2 | Alice Bowman (Meg Ryan) moves to the (fictiona... | the (fictional) South American country of Tec... | where does the movie proof of life take place | What country is Alice Bowman living in? |
| 3 | A common synonym for net profit when discussin... | on the bottom line of the report | where is net profit on the balance sheet | What is the bottom line of the report? |
| 4 | Human fingerprints are detailed, nearly unique... | the early 20th century | when was fingerprinting first used by the police | What year was the fingerprint analysis? |
| 5 | The Los Angeles Lakers are an American profess... | in 2010 | when was the last time the los angeles lakers ... | What year did they win the championship? |
| 6 | Myofascial trigger points, also known as trigg... | hyperirritable spots in the fascia surroundin... | where are trigger points located in the body | What is myofascial trigger points? |
| 7 | USS Maine (ACR-1) is an American naval ship th... | Havana Harbor | where was the u.s.s maine when it exploded in ... | What was the location of the wreck? |
| 8 | Patrick Walshe (July 26, 1900 – December 11, 1... | Patrick Walshe | who plays nikko in the wizard of oz | What was his name? |
| 9 | The Winter Olympics has been hosted on three c... | four | how many times have the winter olympics been i... | What countries have the Winter Olympics been h... |


In [None]:
if save_predictions:
  df.to_csv(prediction_file)