% mkdir data
% cd data
! git clone https://github.com/iamyuanchung/TOEFL-QA.git
% cd ..

! pip install transformers
! pip install sentencepiece
! pip install rouge-score
! pip install -U nltk

# Imports

In [1]:
import os
import sys
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import spacy
from tqdm import tqdm
import re
from pprint import pprint
import sentencepiece
import nltk
from rouge_score import rouge_scorer
nltk.download('all')
from nltk.translate import meteor_score

[nltk_data] Downloading collection 'all'
[nltk_data]    | 
[nltk_data]    | Downloading package abc to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package abc is already up-to-date!
[nltk_data]    | Downloading package alpino to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package alpino is already up-to-date!
[nltk_data]    | Downloading package biocreative_ppi to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package biocreative_ppi is already up-to-date!
[nltk_data]    | Downloading package brown to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package brown is already up-to-date!
[nltk_data]    | Downloading package brown_tei to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package brown_tei is already up-to-date!
[nltk_data]    | Downloading package cess_cat to
[nltk_data]    |     C:\Users\K

[nltk_data]    |   Package ptb is already up-to-date!
[nltk_data]    | Downloading package product_reviews_1 to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package product_reviews_1 is already up-to-date!
[nltk_data]    | Downloading package product_reviews_2 to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package product_reviews_2 is already up-to-date!
[nltk_data]    | Downloading package pros_cons to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package pros_cons is already up-to-date!
[nltk_data]    | Downloading package qc to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package qc is already up-to-date!
[nltk_data]    | Downloading package reuters to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package reuters is already up-to-date!
[nltk_data]    | Downloading package rte to
[nltk_data]  

[nltk_data]    |   Package word2vec_sample is already up-to-date!
[nltk_data]    | Downloading package panlex_swadesh to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package panlex_swadesh is already up-to-date!
[nltk_data]    | Downloading package mte_teip5 to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package mte_teip5 is already up-to-date!
[nltk_data]    | Downloading package averaged_perceptron_tagger to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package averaged_perceptron_tagger is already up-
[nltk_data]    |       to-date!
[nltk_data]    | Downloading package averaged_perceptron_tagger_ru to
[nltk_data]    |     C:\Users\Kevin\AppData\Roaming\nltk_data...
[nltk_data]    |   Package averaged_perceptron_tagger_ru is already
[nltk_data]    |       up-to-date!
[nltk_data]    | Downloading package perluniprops to
[nltk_data]    |     C:\Users\Kevin\AppData\Roamin

In [2]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config

In [3]:
import importlib
util = importlib.import_module("data.TOEFL-QA.utils")
TOEFL_PATH = "./data/TOEFL-QA/data/"
raw = util.load_data(TOEFL_PATH)
train_raw, dev_raw, test_raw = tuple(raw)

# Options

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
print('Using device:', device)

Using device: cuda


In [15]:
PRETRAINED_MODEL = 't5-base'
DIR = "question_generator/toeflqa_finetune_hf_withanswer/"
BATCH_SIZE = 1
SEQ_LENGTH = 512
EPOCHS = 200
FNAME = "toeflqa_finetune_withanswer.pt.epoch100"
USE_ANSWERS = False

In [6]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
qg_model = T5ForConditionalGeneration.from_pretrained('t5-base')

tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_MODEL)
tokenizer.add_special_tokens(
    {'additional_special_tokens': ['<answer>', '<context>']}
);


qg_model.resize_token_embeddings(len(tokenizer)) # to account for new special tokens
trained = torch.load(DIR + FNAME)
qg_model.load_state_dict(trained["model_state_dict"])
qg_model = qg_model.to(device)

In [7]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
  
causal_tokenizer = AutoTokenizer.from_pretrained("noahjadallah/cause-effect-detection")

causal_model = AutoModelForTokenClassification.from_pretrained("noahjadallah/cause-effect-detection")

# https://colab.research.google.com/drive/14V9Ooy3aNPsRfTK88krwsereia8cfSPc?usp=sharing#scrollTo=eqYFDe_2HfQ7


# 2.2. Cause Begin -> B-Cause -> 1
# 2.3. Cause Inside -> I-Cause -> 2
# 2.4. Effect Begin -> B-Effect -> 3
# 2.5. Effect Inside -> I-Effect -> 4

label_list = ['O', 'B-CAUSE', 'I-CAUSE', 'B-EFFECT', 'I-EFFECT']

# Utility Functions

In [8]:
def get_sent_str(sentence_list):
    sent = " ".join(sentence_list)
    sent = re.sub(r" (?P<punc>[.?,])", r"\1", sent)
    return sent

def get_sent_list(sentences):
    sent_list = []
    for sent in sentences:
        sent_list.append(get_sent_str(sent))
    return sent_list

In [9]:
def set_fuzzy_context(key, raw_data):
    question = [raw_data[key]["question"]]
    results = []
    for ref in get_sent_list(raw_data[key]["sentences"]):
        results.append(bertscore.compute(predictions=question, references=[ref], lang='en'))
    idx = np.argsort(-1 * np.array([i["precision"] for i in results]).ravel())
    top5 = idx[:5]
    sent_list = get_sent_list(raw_data[key]["sentences"])
    raw_data[key]["context"] = " ".join([sent_list[i] for i in sorted(top5)]) # reorder sentences

In [10]:
def get_causation_prediction(sequence: str):
    tokens = causal_tokenizer.tokenize(causal_tokenizer.decode(causal_tokenizer.encode(sequence)))
    inputs = causal_tokenizer.encode(sequence, return_tensors="pt")

    outputs = causal_model(inputs).logits
    predictions = torch.argmax(outputs, dim=2).numpy()
    effects = [tokens[i] for i in range(len(tokens)) if predictions[0][i] > 2]
    return effects

def get_contexts(sentences):
    out = []
    for i in range(4, len(sentences)+1):
        center = i - 2
        effect = get_causation_prediction(get_sent_str(sentences[center]))
        if len(effect) < 2: # It's possible there's no effect in this sentencee
            continue
        effect = " ".join(effect)
        context = " ".join([get_sent_str(sent) for sent in sentences[center-2:center+2]])
        out.append('<answer> ' + effect + " <context> " + context)
    return out

def encode_contexts(inputs, answers=None):
    out = []
    for i in range(len(inputs)):
        s = ""
        s = inputs[i]
        out.append(tokenizer(
            s, 
            pad_to_max_length=True, 
            max_length=SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        ).to(device))
    return out

# Evaluation of Model

In [11]:
def all_tpos(raw_data):
    result = dict()
    for sentence in raw_data.keys():
        digits = re.findall(r'\d+', sentence)
        types = 'conversation' if 'conversation' in sentence else 'lecture'
        name = 'tpo_' + digits[0] + "-" + types + "_" + digits[1]
        if name in result.keys():
            result[name] = result[name] + [digits[2]]
        else:
            result[name] = [digits[2]]
    return result

def evaluate_model(model, raw, print_detail = False):
    results = {}
    model.to(device)
    model.eval()
    raw_tpos = all_tpos(raw)
    bleu_total = []
    meteor_total = []
    rouge_total = []
    scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)
    for key_base in raw_tpos.keys():
        question = raw_tpos[key_base][0]
        raw_key = key_base + "_" + question
        contexts = get_contexts(raw[raw_key]["sentences"])
        encoded_contexts = encode_contexts(contexts)
        questions = []
        for i in encoded_contexts:
            question = model.generate(input_ids=i["input_ids"])
            questions.append(tokenizer.decode(question[0], skip_special_tokens=True))
            # print(tokenizer.decode(question[0], skip_special_tokens=True))
        ground_truth = [raw[i]['question'] for i in raw.keys() if i.startswith(key_base)]
        bleus = []
        meteors = []
        rouges = []
        for generated in tqdm(questions):
            highest_bleu = 0.0
            highest_meteor = 0.0
            highest_rouge = 0.0
            for truth_split in ground_truth:
                truth = " ".join(truth_split)
                generated_split = generated.split(" ")
                r_score = scorer.score(truth, generated)
                rouge = r_score['rouge1'][2]
                bleu = nltk.translate.bleu_score.sentence_bleu([truth_split], generated_split)
                meteor = nltk.translate.meteor_score.meteor_score([truth_split], generated_split)
                if bleu > highest_bleu:
                    highest_bleu = bleu
                if meteor > highest_meteor:
                    highest_meteor = meteor
                if rouge > highest_rouge:
                    highest_rouge = rouge
            bleus.append(highest_bleu)
            meteors.append(highest_meteor)
            rouges.append(highest_rouge)
        results[key_base] = {
            "questions": questions,
            "bleu": bleus,
            "meteor": meteors,
            "rouge": rouges,
            "ground_truth": ground_truth
        }
        bleu_total.append(sum(bleus) / len(bleus))
        meteor_total.append(sum(meteors) / len(meteors))
        rouge_total.append(sum(rouges) / len(rouges))
    return results, bleu_total, meteor_total, rouge_total

In [13]:
results, bleu_total, meteor_total, rouge_total = evaluate_model(qg_model, test_raw, True)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=29.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=37.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=27.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=34.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=36.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=38.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=28.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=34.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=31.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=31.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=33.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=31.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=28.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=34.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=24.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=26.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=38.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=36.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=37.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=36.0), HTML(value='')))


{'tpo_22-conversation_1': {'questions': ["what is the professor's opinion on the paper office", "what is the professor's answer to the question", "what is the professor's answer to the question", "what is the professor's answer to the question", 'what does the professor imply about the statue at the main entrance of the university', 'what does the professor imply about the statue at the main entrance of the university', 'what does the professor imply about the paper', "what is the professor's opinion on the matter", "what is the professor's opinion", 'what does the professor say about the statue', "what is the professor's opinion", 'what does the professor say about the sally koenig', "what is the professor's opinion on the subject", 'what does the professor imply about the government organization', "what is the professor's opinion on the article", "what is the professor's opinion on the article", "what is the professor's opinion on the article", "what is the professor's opinion on th

In [18]:
import pickle
with open(f"{FNAME}.results", "wb+") as f:
    pickle.dump(results, f)

In [19]:
results

{'tpo_22-conversation_1': {'questions': ["what is the professor's opinion on the paper office",
   "what is the professor's answer to the question",
   "what is the professor's answer to the question",
   "what is the professor's answer to the question",
   'what does the professor imply about the statue at the main entrance of the university',
   'what does the professor imply about the statue at the main entrance of the university',
   'what does the professor imply about the paper',
   "what is the professor's opinion on the matter",
   "what is the professor's opinion",
   'what does the professor say about the statue',
   "what is the professor's opinion",
   'what does the professor say about the sally koenig',
   "what is the professor's opinion on the subject",
   'what does the professor imply about the government organization',
   "what is the professor's opinion on the article",
   "what is the professor's opinion on the article",
   "what is the professor's opinion on the a

In [25]:
sorted(results.items(), key=lambda x: -max(x[1]["bleu"]))

[('tpo_22-lecture_3',
  {'questions': ['according to the professor, what is the definition of a mass extinction',
    'according to the professor, the current loss of bio diversity can be traced to human to human',
    'according to the professor, humans have been eliminating species and altering ecosystems with as ##',
    'according to the professor, what is the reason why the species are disappearing',
    'according to the professor, megafa ##una include elephants, wild horses,',
    'what does the professor imply about megafauna',
    'what does the professor mean by the proposal focuses on a particular subset of megaf',
    'what is the lecture mainly about',
    "what is the professor's opinion on the pleistocene",
    "what is the professor's idea",
    'what does the professor imply about the pleistocene rewilding',
    'what does the professor imply about the pleistocene rewilding',
    'what does the professor mean by pleistocene rewilding',
    'what does the professor impl

In [23]:
results.items()

dict_items([('tpo_22-conversation_1', {'questions': ["what is the professor's opinion on the paper office", "what is the professor's answer to the question", "what is the professor's answer to the question", "what is the professor's answer to the question", 'what does the professor imply about the statue at the main entrance of the university', 'what does the professor imply about the statue at the main entrance of the university', 'what does the professor imply about the paper', "what is the professor's opinion on the matter", "what is the professor's opinion", 'what does the professor say about the statue', "what is the professor's opinion", 'what does the professor say about the sally koenig', "what is the professor's opinion on the subject", 'what does the professor imply about the government organization', "what is the professor's opinion on the article", "what is the professor's opinion on the article", "what is the professor's opinion on the article", "what is the professor's op