In [None]:
import argparse
import collections
import re
import string
import tqdm
import json
import numpy as np
import pickle

In [None]:
# Precomputed confidence / F1 scores from an earlier run. The saved array is
# unpacked along its first axis, so it must hold exactly two entries.
input_file = "results/processed.npy"
conf_scores, f1_scores = np.load(file=input_file)

In [None]:
# Raw prediction records, presumably one per question (TODO confirm schema).
filename = "results/predictions.json"
# JSON text is UTF-8 by specification (RFC 8259); pass the encoding explicitly
# instead of relying on the platform's locale default (e.g. cp1252 on Windows).
with open(filename, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

In [None]:
def normalize_answer(s):
    """Canonicalize an answer string before comparing it with another.

    Applied in this order: lowercase, delete every character in
    `string.punctuation`, replace the whole words "a", "an" and "the" with a
    space, then collapse all whitespace runs into single spaces (trimming both
    ends).
    """
    lowered = s.lower()
    # A deletion-only translation table removes every punctuation character.
    no_punct = lowered.translate(str.maketrans('', '', string.punctuation))
    # \b anchors keep words like "theater" or "another" intact.
    no_articles = re.sub(r'\b(a|an|the)\b', ' ', no_punct)
    return ' '.join(no_articles.split())

def f1_score(prediction, answer):
    """Token-overlap F1 between a single prediction and a single gold answer.

    Both strings are canonicalized with `normalize_answer` and split on
    whitespace. Shared tokens are counted with multiplicity (multiset
    intersection).

    Returns:
        0 (an int) when no token is shared, which includes the case where either
        side normalizes to an empty string. Otherwise returns the harmonic mean
        of precision and recall as a float.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(answer).split()
    shared = collections.Counter(pred_tokens) & collections.Counter(gold_tokens)
    overlap = sum(shared.values())
    if overlap == 0:
        return 0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def metric_max_over_answers(metric_fn, prediction, answer_set):
    """Score `prediction` against every gold answer and keep the best score.

    Args:
        metric_fn: callable taking (prediction, answer) and returning a score.
        prediction: predicted answer, forwarded unchanged to `metric_fn`.
        answer_set: iterable of gold answers.

    Returns:
        The largest score, or -inf when `answer_set` is empty.
    """
    candidate_scores = [metric_fn(prediction, gold) for gold in answer_set]
    # Seeding with -inf reproduces a running max and handles the empty case.
    return max([-float("inf"), *candidate_scores])

In [None]:
top_docs = 100  # evaluate the prediction list recorded for this many retrieved documents
answers_list = []
questions_list = []
dataset = []


def _select_predictions(entry, top_k):
    """Return the first prediction list in `entry["predictions"]` whose "top_k" equals `top_k`.

    Raises:
        RuntimeError: if no item matches, or the matching item's "predictions" is None.
    """
    predictions = next(
        (p["predictions"] for p in entry["predictions"] if p["top_k"] == top_k),
        None,
    )
    if predictions is None:
        raise RuntimeError(f"Could not find entry corresponding to top_k={top_k}.")
    return predictions


def _score_predictions(predictions, answers):
    """Deduplicate predictions by normalized text, score them, and rank them by confidence.

    Returns:
        A list of (confidence, f1) pairs sorted by descending confidence. The
        confidence is p["score"] + p["relevance_score"]. These are presumably
        log-space span and document scores, so the sum would correspond to a
        joint doc * span score (TODO confirm). The f1 is the best token F1
        against any gold answer, or -inf when `answers` is empty.
    """
    seen = set()
    scored = []
    for p in predictions:
        normalized = normalize_answer(p["text"])  # computed once per prediction
        if normalized in seen:
            continue
        seen.add(normalized)
        f1 = metric_max_over_answers(f1_score, p["text"], answers)
        scored.append((p["score"] + p["relevance_score"], f1))
    # reverse=True preserves sort stability, so ties keep their original order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored


for entry in tqdm.tqdm(raw_data, desc="processing data..."):
    questions_list.append(entry["question"])
    answers = entry["gold_answers"]
    answers_list.append(answers)

    scored_predictions = _score_predictions(_select_predictions(entry, top_docs), answers)
    if not scored_predictions:
        # zip(*[]) would otherwise fail with an opaque unpacking error.
        raise RuntimeError(f"No predictions found for question: {entry['question']!r}")

    # Use entry-local names so the globals loaded from processed.npy
    # (conf_scores / f1_scores) are not silently overwritten by this loop.
    entry_conf_scores, entry_f1_scores = zip(*scored_predictions)
    # Running max: the best F1 reachable using the top-k ranked predictions, for each k.
    set_scores = np.maximum.accumulate(entry_f1_scores).tolist()
    dataset.append((entry_conf_scores, set_scores))

In [None]:
# Persist the question texts to disk with pickle.
questions_path = "results/questions_list"
with open(questions_path, "wb") as out_file:
    pickle.dump(questions_list, out_file)

In [None]:
# Reload the pickled question texts. pickle.load can execute arbitrary code,
# so only load files that this pipeline wrote itself.
with open("results/questions_list", "rb") as fp:
    questions_list = pickle.load(fp)
# Show the count and a short preview rather than dumping every question into the output.
len(questions_list), questions_list[:5]

In [None]:
# Build a question-question similarity matrix, min-max scaled to [0, 1], and save it.

from sentence_transformers import SentenceTransformer, util

EMBEDDING_MODEL = 'multi-qa-mpnet-base-dot-v1'
model = SentenceTransformer(EMBEDDING_MODEL)

# Questions are compared against themselves, so one encoding pass is enough
# (the previous second encode of the same list was unused work).
query_embedding = model.encode(questions_list)

# Pairwise dot-product similarity of shape (len(questions_list), len(questions_list)).
W_all = util.dot_score(query_embedding, query_embedding).numpy()

# Min-max normalize. A constant matrix (e.g. a single question) would otherwise
# divide by zero and fill W_all with NaNs. In that case every pair is equally
# similar, so we assign uniform maximal weight.
min_value = np.min(W_all)
max_value = np.max(W_all)
value_range = max_value - min_value
if value_range > 0:
    W_all = (W_all - min_value) / value_range
else:
    W_all = np.ones_like(W_all)

# Same file as before: np.save appends ".npy" when it is missing.
np.save(f"results/w-normalized-{EMBEDDING_MODEL}.npy", W_all)