In [None]:
% mkdir data
% cd data
! git clone https://github.com/iamyuanchung/TOEFL-QA.git
% cd ..

/content/data
Cloning into 'TOEFL-QA'...
remote: Enumerating objects: 1037, done.[K
remote: Total 1037 (delta 0), reused 0 (delta 0), pack-reused 1037[K
Receiving objects: 100% (1037/1037), 1.91 MiB | 3.63 MiB/s, done.
Resolving deltas: 100% (21/21), done.
/content


In [None]:
! pip install transformers
! pip install sentencepiece
! pip install rouge-score
! pip install -U nltk
! pip install datasets
! pip install bert_score

# Imports

In [None]:
import os
import sys
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import spacy
from tqdm.notebook import tqdm
import re
from pprint import pprint
import sentencepiece
import nltk
from rouge_score import rouge_scorer
# Only WordNet is needed: nltk's meteor_score uses it for synonym matching. 'omw-1.4' is
# fetched as well because some nltk releases require it to load WordNet.
# (Previously nltk.download('all') fetched the entire NLTK data collection.)
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)
from nltk.translate import meteor_score
from datasets import load_dataset
from datasets import load_metric
import statistics
import string

[nltk_data] Downloading collection 'all'
[nltk_data]    | 
[nltk_data]    | Downloading package abc to /root/nltk_data...
[nltk_data]    |   Unzipping corpora/abc.zip.
[nltk_data]    | Downloading package alpino to /root/nltk_data...
[nltk_data]    |   Unzipping corpora/alpino.zip.
[nltk_data]    | Downloading package biocreative_ppi to
[nltk_data]    |     /root/nltk_data...
[nltk_data]    |   Unzipping corpora/biocreative_ppi.zip.
[nltk_data]    | Downloading package brown to /root/nltk_data...
[nltk_data]    |   Unzipping corpora/brown.zip.
[nltk_data]    | Downloading package brown_tei to /root/nltk_data...
[nltk_data]    |   Unzipping corpora/brown_tei.zip.
[nltk_data]    | Downloading package cess_cat to /root/nltk_data...
[nltk_data]    |   Unzipping corpora/cess_cat.zip.
[nltk_data]    | Downloading package cess_esp to /root/nltk_data...
[nltk_data]    |   Unzipping corpora/cess_esp.zip.
[nltk_data]    | Downloading package chat80 to /root/nltk_data...
[nltk_data]    |   Unzipp

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config

In [None]:
import importlib

# The TOEFL-QA checkout ships its own loader (data/TOEFL-QA/utils.py). "TOEFL-QA" is not a
# valid identifier, so the module can only be reached through importlib, not `import`.
TOEFL_PATH = "./data/TOEFL-QA/data/"
util = importlib.import_module("data.TOEFL-QA.utils")
raw = util.load_data(TOEFL_PATH)
# load_data's result is unpacked as (train, dev, test) splits.
train_raw, dev_raw, test_raw = tuple(raw)

# RACE, "high" configuration, via Hugging Face datasets.
dataset = load_dataset("race", "high")

# Options

In [None]:
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('Using device:', device)

Using device: cuda


In [None]:
# --- model / tokenization ---------------------------------------------------
PRETRAINED_MODEL = 't5-base'   # checkpoint name passed to T5Tokenizer.from_pretrained
SEQ_LENGTH = 512               # pad/truncate length for encoded inputs and target questions
USE_ANSWERS = False            # when True, inputs are built as "<answer> ... <context> ..."

# --- data loading / training --------------------------------------------------
BATCH_SIZE = 1                 # DataLoader batch size
EPOCHS = 200                   # not referenced by the cells shown here
DIR = "question_generator/"    # not referenced by the cells shown here

# --- saved checkpoints (Colab / Google Drive paths) ---------------------------
TOEFL_epoch10 = "/content/drive/MyDrive/toeflqa_finetune.pt.epoch10"
TOEFL_epoch50 = "/content/drive/MyDrive/toeflqa_finetune.pt.epoch50"
TOEFL_epoch100 = "/content/toeflqa_finetune_epoch100.pt"
RACE_epoch5 = "/content/drive/MyDrive/qg_pretrained_t5_model_race_5epoch.pth"

In [None]:
# Tokenizer and model are both built from PRETRAINED_MODEL so their vocabularies always match
# (the model used to be loaded from a hard-coded 't5-base').
tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_MODEL)
# Marker tokens used to build "<answer> ... <context> ..." inputs when USE_ANSWERS is set.
tokenizer.add_special_tokens(
    {'additional_special_tokens': ['<answer>', '<context>']}
);
qg_model = T5ForConditionalGeneration.from_pretrained(PRETRAINED_MODEL)
# Resize the embedding matrix to the tokenizer's vocabulary, which now includes the two markers.
qg_model.resize_token_embeddings(len(tokenizer))
# Optional: start from the RACE fine-tuned checkpoint. While this stays commented out,
# the evaluation cell scores the plain pretrained model.
#qg_model.load_state_dict(torch.load(RACE_epoch5)["model_state_dict"])

# Utility Functions

In [None]:
def get_sent_str(sentence_list):
    """Render a token list as one sentence string.

    Tokens are joined with single spaces, then the space in front of '.', '?' and ','
    is removed (other punctuation keeps its leading space).
    """
    return re.sub(r" (?P<punc>[.?,])", r"\1", " ".join(sentence_list))

def get_sent_list(sentences):
    """Render every token list in ``sentences`` to a string with ``get_sent_str``."""
    return [get_sent_str(tokens) for tokens in sentences]

In [None]:
def get_contexts(sentences, window=3):
    """Build overlapping passages from consecutive sentences.

    Args:
        sentences: list of token lists; each is rendered with ``get_sent_str``.
        window: number of consecutive sentences per passage (default 3).

    Returns:
        One string per sliding-window position (``len(sentences) - window + 1`` strings).
        When there are fewer than ``window`` sentences, a single passage containing all of
        them is returned instead of nothing, so short texts still produce a context.
        An empty ``sentences`` returns ``[]``.

    Raises:
        ValueError: if ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be >= 1")
    # Render each sentence once rather than once per window it appears in.
    rendered = [get_sent_str(sent) for sent in sentences]
    if not rendered:
        return []
    if len(rendered) < window:
        return [" ".join(rendered)]
    return [
        " ".join(rendered[start:start + window])
        for start in range(len(rendered) - window + 1)
    ]

def encode_contexts(inputs, answers=None):
    """Tokenize each context string into a padded PyTorch encoding of length SEQ_LENGTH.

    Args:
        inputs: list of context strings (e.g. the output of ``get_contexts``).
        answers: answer strings aligned with ``inputs``; only used when ``USE_ANSWERS``.

    Returns:
        List with one tokenizer output (batch of one, ``return_tensors="pt"``) per input.

    Raises:
        ValueError: if ``USE_ANSWERS`` is set but fewer answers than inputs are given.
    """
    if USE_ANSWERS and (answers is None or len(answers) < len(inputs)):
        raise ValueError("USE_ANSWERS requires one answer per context")
    out = []
    for i, context in enumerate(inputs):
        if USE_ANSWERS:
            # Same layout as make_text: the answer follows <answer>, the passage follows
            # <context>. (These two fields used to be swapped here.)
            s = '<answer> ' + answers[i] + ' <context> ' + context
        else:
            s = context
        out.append(tokenizer(
            s,
            padding="max_length",  # equivalent to the deprecated pad_to_max_length=True
            max_length=SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        ))
    return out

# Encode the RACE Dataset

In [None]:
def _race_answer_text(row):
    """Return the answer text used for the "<answer> ..." input prefix of a RACE row.

    NOTE(review): in the Hugging Face ``race`` dataset, 'answer' holds the option letter
    ('A'-'D') and 'options' the option texts, so a letter is mapped to its option text.
    Any other value is returned unchanged.
    """
    answer = row['answer']
    options = row.get('options') or []
    if len(answer) == 1 and 0 <= ord(answer) - ord('A') < len(options):
        return options[ord(answer) - ord('A')]
    return answer


def make_text(row):
    """Encode one RACE row into fixed-length model inputs and a question target.

    Args:
        row: mapping with 'article' and 'question' (plus 'answer'/'options' when
            USE_ANSWERS is set).

    Returns:
        dict with 'input_ids' and 'attention_mask' for the (optionally answer-prefixed)
        article and 'input_ids_question' for the question, each padded/truncated to
        SEQ_LENGTH and squeezed from a batch of one.
    """
    if USE_ANSWERS:
        s = '<answer> ' + _race_answer_text(row) + ' <context> ' + row['article']
    else:
        s = row['article']
    encoded_text = tokenizer(
        s,
        padding="max_length",  # equivalent to the deprecated pad_to_max_length=True
        max_length=SEQ_LENGTH,
        truncation=True,
        return_tensors="pt"
    )
    encoded_question = tokenizer(
        row['question'],
        padding="max_length",
        max_length=SEQ_LENGTH,
        truncation=True,
        return_tensors='pt'
    )
    # The tokenizer returns a batch of one; squeeze to per-example tensors for datasets.map.
    return {
        'input_ids': torch.squeeze(encoded_text['input_ids']),
        'attention_mask': torch.squeeze(encoded_text['attention_mask']),
        'input_ids_question': torch.squeeze(encoded_question['input_ids']),
    }

# Encode every split with make_text; the three encoded columns are then returned as torch tensors.
dataset = dataset.map(function=make_text)
dataset.set_format(
    type='torch',
    columns=['input_ids', 'attention_mask', 'input_ids_question'],
)
# Shuffled loader over the encoded validation split.
valid_loader = DataLoader(
    dataset=dataset["validation"],
    batch_size=BATCH_SIZE,
    shuffle=True,
)

  0%|          | 0/3498 [00:00<?, ?ex/s]



  0%|          | 0/62445 [00:00<?, ?ex/s]

  0%|          | 0/3451 [00:00<?, ?ex/s]

# Evaluation of Model

In [None]:
def all_tpos(raw_data):
  """Group TOEFL-QA keys by passage.

  The first three integer runs in each key are read as (tpo, passage, question)
  numbers; a key counts as 'conversation' if it contains that word, else 'lecture'.

  Returns:
    dict mapping 'tpo_<tpo>-<conversation|lecture>_<passage>' to the list of
    question-number strings for that passage, in key order.
  """
  grouped = {}
  for key in raw_data.keys():
    numbers = re.findall(r'\d+', key)
    kind = 'conversation' if 'conversation' in key else 'lecture'
    passage_id = 'tpo_' + numbers[0] + "-" + kind + "_" + numbers[1]
    question_no = numbers[2]
    grouped.setdefault(passage_id, []).append(question_no)
  return grouped

def all_race_passage(raw_data):
  """Group RACE questions by the article they belong to.

  Args:
    raw_data: a RACE split (or any mapping of column name -> sequence) with
      'article' and 'question' columns.

  Returns:
    dict mapping each distinct article text to the list of its questions, in
    first-seen order.
  """
  by_article = {}
  # Read each column once; indexing raw_data['article'][i] inside the loop
  # rebuilds the whole column on every iteration for a datasets.Dataset.
  for article, question in zip(raw_data['article'], raw_data['question']):
    by_article.setdefault(article, []).append(question)
  return by_article

# Usage:
#   RACE:  evaluate_model(model, dataset[split], TOEFL=False) with split in {'validation', 'test'}.
#   TOEFL: evaluate_model(model, dev_raw or test_raw, TOEFL=True).
#   print_detail=True prints the generated question(s) for every evaluated item.

def _strip_punctuation(text):
  """Return ``text`` with every ``string.punctuation`` character removed."""
  return text.translate(str.maketrans("", "", string.punctuation))


def _best_scores(generated, ground_truth, scorer, metric_bert):
  """Score one generated question against all references, keeping each metric's best.

  Args:
    generated: the generated question string.
    ground_truth: list of references, each a list of tokens.
    scorer: a ``rouge_scorer.RougeScorer`` configured with 'rouge1'.
    metric_bert: the loaded "bertscore" metric.

  Returns:
    (bleu, meteor, rouge1 F-measure, BERTScore F1): each the maximum over the
    references, floored at 0.0 (all 0.0 when there are no references).
  """
  if not ground_truth:
    return 0.0, 0.0, 0.0, 0.0
  truths = [" ".join(tokens) for tokens in ground_truth]
  generated_split = generated.split(" ")
  # One batched BERTScore call for all references instead of one call per pair.
  # lang must be a language code: bert_score maps unrecognised values (such as
  # the former "English") to its multilingual fallback model, not the English one.
  bert_f1 = metric_bert.compute(
      predictions=[generated] * len(truths), references=truths, lang="en"
  )["f1"]
  best_bleu = best_meteor = best_rouge = best_bert = 0.0
  for tokens, truth, bert in zip(ground_truth, truths, bert_f1):
    bleu = nltk.translate.bleu_score.sentence_bleu([tokens], generated_split)
    meteor = nltk.translate.meteor_score.meteor_score([tokens], generated_split)
    rouge = scorer.score(truth, generated)['rouge1'].fmeasure
    best_bleu = max(best_bleu, bleu)
    best_meteor = max(best_meteor, meteor)
    best_rouge = max(best_rouge, rouge)
    best_bert = max(best_bert, bert)
  return best_bleu, best_meteor, best_rouge, best_bert


def evaluate_model(model, dev_raw, print_detail = False, TOEFL= False):
  """Generate questions with ``model`` and score them against the reference questions.

  Args:
    model: seq2seq model with ``generate``; moved to CUDA when available (else CPU)
      and put in eval mode.
    dev_raw: RACE: an encoded, torch-formatted split of ``dataset`` (uses the
      'input_ids', 'article' and 'question' columns). TOEFL: the dict returned by the
      TOEFL-QA loader (``dev_raw``/``test_raw``).
    print_detail: print the generated question(s) for every evaluated item.
    TOEFL: use the TOEFL-QA branch instead of the RACE branch.

  Returns:
    (bleu_total, meteor_total, rouge_total, bert_total, results): four lists with one
    mean score per evaluated item, and ``results`` mapping each item key to its
    generated questions, per-question scores and reference questions.
  """
  results = {}
  bleu_total, meteor_total, rouge_total, bert_total = [], [], [], []
  # The device used to be hard-coded ("cuda" for RACE, which crashed on CPU-only
  # runtimes; "cpu" for TOEFL).
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model.to(device)
  model.eval()
  if TOEFL:
    passages = all_tpos(dev_raw)
    iterate = list(passages)
  else:
    lookup = all_race_passage(dev_raw)
    # Read columns once: dev_raw[col][i] inside the loop rebuilds the whole column
    # (for input_ids, every row's tensor) on each iteration.
    all_input_ids = dev_raw["input_ids"]
    articles = dev_raw["article"]
    iterate = range(len(dev_raw))
  scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)
  metric_bert = load_metric("bertscore")
  seen = set()  # RACE: generated questions that were already scored
  with tqdm(total=len(iterate)) as bar1:
    for item in iterate:
      bar1.update(1)  # advance first so skipped items are counted too
      if TOEFL:
        question_numbers = passages[item]
        # Results stay keyed as before: "<passage>_<first question number>".
        key = item + "_" + question_numbers[0]
        questions = []
        seen_toefl = set()
        for encoded in encode_contexts(get_contexts(dev_raw[key]["sentences"])):
          output = model.generate(input_ids=encoded["input_ids"].to(device))
          question = _strip_punctuation(
              tokenizer.decode(output[0].to('cpu'), skip_special_tokens=True))
          if question not in seen_toefl:
            seen_toefl.add(question)
            questions.append(question)
        # References = every question of this passage, looked up with the same
        # "<passage>_<number>" key form used above. The old prefix test
        # startswith("<passage>_<first number>") matched question 1 and 10-19 only.
        # NOTE(review): each 'question' is presumably a token list (it is joined with
        # spaces below) — confirm against the TOEFL-QA loader.
        ground_truth = [dev_raw[item + "_" + n]['question'] for n in question_numbers]
      else:
        key = item
        input_ids = all_input_ids[item].to(device).unsqueeze(0)
        output = model.generate(input_ids=input_ids)
        question = tokenizer.decode(output[0].to('cpu'), skip_special_tokens=True)
        if question in seen:
          continue
        seen.add(question)
        questions = [question]
        ground_truth = [q.split(" ") for q in lookup[articles[item]]]
      if not questions:
        # No context could be built for this passage: nothing to score.
        continue
      scores = [_best_scores(q, ground_truth, scorer, metric_bert) for q in questions]
      bleus, meteors, rouges, berts = (list(column) for column in zip(*scores))
      results[key] = {
          "questions": questions,
          "bleu": bleus,
          "meteor": meteors,
          "rouge": rouges,
          "bert": berts,
          "ground_truth": [" ".join(x) for x in ground_truth]
      }
      if print_detail:
        print(questions)
      bleu_total.append(statistics.mean(bleus))
      meteor_total.append(statistics.mean(meteors))
      rouge_total.append(statistics.mean(rouges))
      bert_total.append(statistics.mean(berts))
  return bleu_total, meteor_total, rouge_total, bert_total, results

In [None]:
# Score the current qg_model on the RACE test split (TOEFL=False selects the RACE branch).
bleu_total, meteor_total, rouge_total, bert_total, results = evaluate_model(
    qg_model,
    dataset['test'],
    print_detail=False,
    TOEFL=False,
)

In [None]:
# Summarise each metric: the mean over all evaluated items (the headline number) plus the
# five highest per-item scores. Printing only the top five overstated model quality.
for metric_name, scores in (("BLEU", bleu_total), ("METEOR", meteor_total),
                            ("ROUGE-1 F", rouge_total), ("BERTScore F1", bert_total)):
    mean_str = f"{statistics.mean(scores):.4f}" if scores else "n/a"
    print(f"{metric_name}: mean={mean_str} top5={sorted(scores)[-5:]}")

0.5176526474654245
0.6237762237762239
0.8
0.8235591053962708
