In [1]:
import pandas as pd

In [2]:
from src.pycocoevalcap.meteor.meteor import Meteor
from src.pycocoevalcap.rouge.rouge import Rouge
from src.pycocoevalcap.bleu.bleu import Bleu
from sentence_transformers import SentenceTransformer, util

# Metric scorers from the vendored pycocoevalcap package, created once at module level
# and reused by every test_question_prediction call.
# NOTE(review): meteor_obj is never referenced in this notebook. Upstream pycocoevalcap's
# Meteor presumably launches a Java subprocess on construction; confirm for this vendored
# copy and consider dropping it if METEOR is not needed.
meteor_obj = Meteor()
rouge_obj = Rouge()
# Bleu(4): test_question_prediction unpacks its corpus score into four values (BLEU-1..BLEU-4).
bleu_obj = Bleu(4)

In [3]:
from datasets import load_dataset
import random

In [4]:
def white_space_fix(text):
    """Collapse each run of whitespace in ``text`` into one space and trim both ends.

    For example, ``"  a   b "`` becomes ``"a b"``.
    """
    tokens = text.split()
    return " ".join(tokens)

def q_process_squad_row(row):
    """Turn one SQuAD v2 row into a question-generation (source, target) pair.

    The source (``"article"``) is ``"answer: <answer> context: <context> </s>"``, built
    from one gold answer picked with ``random.choice``. The target (``"answer"``) is
    ``"<question> </s>"``. Both are whitespace-normalized. A row whose
    ``answers["text"]`` list is empty gets ``"NONE"`` in both fields, so a later
    filter can drop it.
    """
    context = row["context"]
    question = row["question"]
    gold_answers = row["answers"]["text"]

    if not gold_answers:
        return {"article": "NONE", "answer": "NONE"}

    picked = random.choice(gold_answers)
    source = "answer: " + picked + " context: " + context + " </s>"
    target = question + " </s>"
    return {
        "article": white_space_fix(source),
        "answer": white_space_fix(target),
    }

def _load_normalized_predictions(pred_file_name):
    """Read generated questions from ``pred_file_name`` and strip generation markup.

    Reads the ``predictions_str`` column. Each entry is whitespace-normalized, then a
    trailing ``" </s>"`` and a leading ``"question: "`` are removed.
    """
    # keep_default_na=False keeps empty predictions (and ones reading e.g. "NA" or
    # "null") as strings. Otherwise pandas turns them into NaN, and astype(str) would
    # then score them as the word "nan".
    pred_df = pd.read_csv(pred_file_name, dtype=str, keep_default_na=False)
    return [
        white_space_fix(pred).removesuffix(' </s>').removeprefix('question: ')
        for pred in pred_df["predictions_str"]
    ]


def _load_gold_questions(dataset):
    """Return the normalized gold questions of ``dataset``'s validation split, in order."""
    if dataset == 'squad_v2':
        processed = load_dataset("squad_v2", split="validation").map(
            q_process_squad_row,
            remove_columns=["id", "title", "context", "question", "answers"],
        )
    elif dataset == 'drop':
        # NOTE(review): q_process_drop_row is not defined anywhere in this notebook.
        # It must be defined before calling with dataset='drop', otherwise this raises
        # NameError.
        processed = load_dataset("drop", split="validation").map(
            q_process_drop_row,
            remove_columns=["passage", "question", "answers_spans"],
        )
    else:
        raise ValueError(f"Unsupported dataset {dataset!r}; expected 'squad_v2' or 'drop'.")

    # q_process_squad_row marks rows without a gold answer with article == "NONE"
    # (q_process_drop_row presumably follows the same convention). Those rows are not
    # scored. Note that the "answer" field holds the target *question*.
    kept = processed.filter(lambda row: row["article"] != "NONE")
    return [white_space_fix(row["answer"].strip()).removesuffix(' </s>') for row in kept]


def _mean_pairwise_cosine(model_name, gold_lines, predictions):
    """Mean SentenceTransformer cosine similarity between gold i and prediction i."""
    if not gold_lines:
        return float("nan")
    model = SentenceTransformer(model_name)
    gold_emb = model.encode(gold_lines, convert_to_tensor=True)
    pred_emb = model.encode(predictions, convert_to_tensor=True)
    # Only the diagonal (each gold question vs. its own prediction) is of interest.
    cosine_scores = util.pytorch_cos_sim(gold_emb, pred_emb)
    return cosine_scores.diagonal().mean().item()


def test_question_prediction(pred_file_name, dataset='squad_v2', similarity_model=None):
    """Score generated questions in a prediction CSV against the gold dev-set questions.

    Args:
        pred_file_name: Path to a CSV with a ``predictions_str`` column. It holds one
            generated question per row, in the same order as the filtered
            validation split.
        dataset: ``'squad_v2'`` or ``'drop'``. Selects the Hugging Face validation
            split that supplies the gold questions.
        similarity_model: Optional SentenceTransformer model name (for example
            ``'stsb-roberta-large'``). When given, the mean cosine similarity of each
            gold/prediction pair is added to the result under ``"COS"``. This is off
            by default because it loads and runs an embedding model.

    Returns:
        dict with ``"ROUGE-L"``, ``"BLEU-1"`` and ``"BLEU-4"`` scores, plus ``"COS"``
        when ``similarity_model`` is set.

    Raises:
        ValueError: if ``dataset`` is unsupported, or if the number of predictions
            differs from the number of gold questions.
    """
    # Fail fast, before reading files or downloading data.
    if dataset not in ('squad_v2', 'drop'):
        raise ValueError(f"Unsupported dataset {dataset!r}; expected 'squad_v2' or 'drop'.")

    normal_preds = _load_normalized_predictions(pred_file_name)
    gold_lines = _load_gold_questions(dataset)

    # An explicit check (not assert) so it still runs under `python -O`.
    if len(gold_lines) != len(normal_preds):
        raise ValueError(
            f"{pred_file_name} has {len(normal_preds)} predictions but the {dataset} "
            f"validation split has {len(gold_lines)} gold questions."
        )

    # The pycocoevalcap scorers take {example_id: [sentence, ...]} dicts.
    word_target_dict = {i: [gold] for i, gold in enumerate(gold_lines)}
    word_response_dict = {i: [pred] for i, pred in enumerate(normal_preds)}

    bleu_score, _ = bleu_obj.compute_score(word_target_dict, word_response_dict)
    bleu1_score, _, _, bleu4_score = bleu_score

    rouge_score, _ = rouge_obj.compute_score(word_target_dict, word_response_dict)

    results = {"ROUGE-L": rouge_score, "BLEU-1": bleu1_score, "BLEU-4": bleu4_score}
    if similarity_model is not None:
        results["COS"] = _mean_pairwise_cosine(similarity_model, gold_lines, normal_preds)
    return results

In [5]:
# Score the predictions stored in squad_dev.epoch0.csv.
# NOTE(review): absolute, machine-specific path; consider a single config constant.
main_path = "/Users/saeed/Desktop/codes/repos/dreamscape-qa/pretrained_models/august_25_runs/re_question_generation_model/"
test_question_prediction(f"{main_path}squad_dev.epoch0.csv")


Downloading:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Downloading and preparing dataset squad_v2/squad_v2 (download: 44.34 MiB, generated: 122.41 MiB, post-processed: Unknown size, total: 166.75 MiB) to /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a...


Downloading:   0%|          | 0.00/9.55M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/801k [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset squad_v2 downloaded and prepared to /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a. Subsequent calls will reuse this data.


  0%|          | 0/11873 [00:00<?, ?ex/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

{'testlen': 54538, 'reflen': 61434, 'guess': [54538, 48610, 42682, 36754], 'correct': [23633, 11089, 5976, 3297]}
ratio: 0.8877494546993377


{'ROUGE-L': 0.3943826107626814,
 'BLEU-1': 0.381861380939652,
 'BLEU-4': 0.16541608378455783}

In [6]:
# Score the predictions stored in squad_dev.epoch1.csv.
# NOTE(review): absolute, machine-specific path; consider a single config constant.
main_path = "/Users/saeed/Desktop/codes/repos/dreamscape-qa/pretrained_models/august_25_runs/re_question_generation_model/"
test_question_prediction(f"{main_path}squad_dev.epoch1.csv")

Reusing dataset squad_v2 (/Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a)
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-fb385097f0d48656.arrow
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-b76ee7e16ed15846.arrow


{'testlen': 55177, 'reflen': 61434, 'guess': [55177, 49249, 43321, 37393], 'correct': [24258, 11607, 6387, 3610]}
ratio: 0.8981508610866801


{'ROUGE-L': 0.4028527235580784,
 'BLEU-1': 0.39250796905803037,
 'BLEU-4': 0.17495859144948325}

In [7]:
# Score the predictions stored in squad_dev.epoch2.csv.
# NOTE(review): absolute, machine-specific path; consider a single config constant.
main_path = "/Users/saeed/Desktop/codes/repos/dreamscape-qa/pretrained_models/august_25_runs/re_question_generation_model/"
test_question_prediction(f"{main_path}squad_dev.epoch2.csv")

Reusing dataset squad_v2 (/Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a)
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-fb385097f0d48656.arrow
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-b76ee7e16ed15846.arrow


{'testlen': 54717, 'reflen': 61434, 'guess': [54717, 48789, 42861, 36933], 'correct': [23821, 11496, 6325, 3575]}
ratio: 0.8906631506982959


{'ROUGE-L': 0.39641388970731256,
 'BLEU-1': 0.3850562598604746,
 'BLEU-4': 0.17304818039824393}

In [8]:
# Score the predictions stored in squad_dev.epoch3.csv.
# NOTE(review): absolute, machine-specific path; consider a single config constant.
main_path = "/Users/saeed/Desktop/codes/repos/dreamscape-qa/pretrained_models/august_25_runs/re_question_generation_model/"
test_question_prediction(f"{main_path}squad_dev.epoch3.csv")

Reusing dataset squad_v2 (/Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a)
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-fb385097f0d48656.arrow
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-b76ee7e16ed15846.arrow


{'testlen': 56229, 'reflen': 61434, 'guess': [56229, 50301, 44373, 38445], 'correct': [24612, 11831, 6424, 3620]}
ratio: 0.915274929192289


{'ROUGE-L': 0.4046596652099174,
 'BLEU-1': 0.3990109570545346,
 'BLEU-4': 0.1764391602821799}

In [9]:
# Score the predictions stored in squad_dev.epoch4.csv.
# NOTE(review): absolute, machine-specific path; consider a single config constant.
main_path = "/Users/saeed/Desktop/codes/repos/dreamscape-qa/pretrained_models/august_25_runs/re_question_generation_model/"
test_question_prediction(f"{main_path}squad_dev.epoch4.csv")

Reusing dataset squad_v2 (/Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a)
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-fb385097f0d48656.arrow
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-b76ee7e16ed15846.arrow


{'testlen': 56525, 'reflen': 61434, 'guess': [56525, 50597, 44669, 38742], 'correct': [24274, 11659, 6334, 3584]}
ratio: 0.9200931080509015


{'ROUGE-L': 0.39845488368514137,
 'BLEU-1': 0.39371667491332285,
 'BLEU-4': 0.17402306291567513}

In [10]:
# Score the predictions stored in squad_dev.epoch5.csv.
# NOTE(review): absolute, machine-specific path; consider a single config constant.
main_path = "/Users/saeed/Desktop/codes/repos/dreamscape-qa/pretrained_models/august_25_runs/re_question_generation_model/"
test_question_prediction(f"{main_path}squad_dev.epoch5.csv")

Reusing dataset squad_v2 (/Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a)
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-fb385097f0d48656.arrow
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-b76ee7e16ed15846.arrow


{'testlen': 55817, 'reflen': 61434, 'guess': [55817, 49889, 43961, 38033], 'correct': [24168, 11729, 6422, 3616]}
ratio: 0.9085685451053015


{'ROUGE-L': 0.39808810588953436,
 'BLEU-1': 0.3915345714993328,
 'BLEU-4': 0.17534611178575313}

In [11]:
# Score the predictions stored in squad_dev.web.csv.
# NOTE(review): absolute, machine-specific path; consider a single config constant.
main_path = "/Users/saeed/Desktop/codes/repos/dreamscape-qa/pretrained_models/august_25_runs/re_question_generation_model/"
test_question_prediction(f"{main_path}squad_dev.web.csv")

Reusing dataset squad_v2 (/Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a)
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-fb385097f0d48656.arrow
Loading cached processed dataset at /Users/saeed/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/de2e67b822b2ef3f4b137148d0758f48075e3892c359c50271ef6c9add7e794a/cache-b76ee7e16ed15846.arrow


{'testlen': 55336, 'reflen': 61434, 'guess': [55336, 49408, 43480, 37554], 'correct': [25144, 12493, 7058, 4123]}
ratio: 0.9007390044600564


{'ROUGE-L': 0.41789665604499515,
 'BLEU-1': 0.40697484544194795,
 'BLEU-4': 0.19052514121187367}