diff --git a/belar/metrics/__init__.py b/belar/metrics/__init__.py
index e9bef643f..70fe4a3e5 100644
--- a/belar/metrics/__init__.py
+++ b/belar/metrics/__init__.py
@@ -1,7 +1,8 @@
 from belar.metrics.base import Evaluation, Metric
 from belar.metrics.factual import EntailmentScore
 from belar.metrics.similarity import SBERTScore
-from belar.metrics.simple import BLUE, EditDistance, EditRatio, Rouge1, Rouge2, RougeL
+from belar.metrics.simple import (BLUE, EditDistance, EditRatio, Rouge1,
+                                  Rouge2, RougeL)
 
 __all__ = [
     "Evaluation",
diff --git a/belar/metrics/factual.py b/belar/metrics/factual.py
index 670a15c91..3a9769280 100644
--- a/belar/metrics/factual.py
+++ b/belar/metrics/factual.py
@@ -1,9 +1,16 @@
 from __future__ import annotations
 
+import json
+import re
+import string
 import typing as t
 from dataclasses import dataclass
 
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+import numpy as np
+import spacy
+import transformers
+from transformers import (AutoConfig, AutoModelForSequenceClassification,
+                          AutoTokenizer, PreTrainedModel)
 
 from belar.metrics import Metric
 from belar.utils import device_check
@@ -11,6 +18,20 @@ if t.TYPE_CHECKING:
     from torch import device as Device
 
+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+    MODEL_WITH_LM_HEAD_MAPPING_NAMES)
+
+MODEL_MAPPINGS_NAMES = [
+    MODEL_WITH_LM_HEAD_MAPPING_NAMES,
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
+]
+
+DEVICES = ["cpu", "cuda"]
+SPACY_MODEL = "en_core_web_sm"
+LABEL2SCORE = {"entailment": 1, "contradiction": 0, "neutral": 0.5}
+EPS = 1e-8
+
 
 @dataclass
 class EntailmentScore(Metric):
@@ -47,6 +68,20 @@ def name(self):
     def is_batchable(self):
         return True
 
+    def infer(self, ground_truth: str, generated_text: str):
+        encodings = self.tokenizer(
+            ground_truth,
+            generated_text,
+            truncation=True,
+            return_tensors="pt",
+            max_length=self.max_length,
+            padding="max_length",
+        )
+        label2id = {value.lower(): key for key, value in self.id2label.items()}
+        output = self.model(**encodings)
+        pred = output.logits.softmax(axis=-1).detach().cpu().squeeze()
+        return {label: pred[idx].item() for label, idx in label2id.items()}
+
     def batch_infer(self, inputs: dict):
         predictions = []
         input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
@@ -87,3 +122,218 @@ def score(
         score = self.batch_infer(encodings)
 
         return score
+
+
+class QAGQ:
+    def __init__(
+        self,
+        model: t.Type[PreTrainedModel],
+        model_name_or_path: str,
+        device: t.Literal["cpu", "cuda"] | Device = "cpu",
+    ):
+        self.model = model.from_pretrained(model_name_or_path)
+        self.model.eval()  # type: ignore
+        self.device = device_check(device)
+        self.model.to(self.device)  # type: ignore
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
+    @classmethod
+    def from_pretrained(cls, model_name_or_path):
+        config = AutoConfig.from_pretrained(model_name_or_path)
+        model_mappings = [
+            arch for model_type in MODEL_MAPPINGS_NAMES for arch in model_type.values()
+        ]
+        architecture = np.intersect1d(model_mappings, config.architectures)
+        if len(architecture) == 0:
+            raise ValueError("Model doesn't support QA or LM architecture")
+        model = getattr(transformers, architecture[0])
+        return cls(model, model_name_or_path)
+
+    def batch_generate_question(self, answers: list[str], context: str, **kwargs):
+        input_texts = [
+            "answer: %s context: %s " % (ans, context) for ans in answers
+        ]
+        max_length = kwargs.pop("input_max_length", 512)
+        encodings = self.tokenizer(
+            input_texts,
+            padding="max_length",
+            truncation=True,
+            max_length=max_length,
+            return_tensors="pt",
+        )
+        encodings = {k: v.to(self.device) for k, v in encodings.items()}
+        outputs = self.model.generate(**encodings, **kwargs)  # type: ignore
+        outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return [output.replace("question:", "").strip() for output in outputs]
+
+    def batch_generate_answers(self, questions: list[str], context: str, **kwargs):
+        max_length = kwargs.pop("input_max_length", 512)
+        encodings = self.tokenizer(
+            questions,
+            [context] * len(questions),
+            padding="max_length",
+            truncation=True,
+            max_length=max_length,
+            return_tensors="pt",
+        )
+        encodings = {
+            k: v.view(-1, max_length).to(self.device) for k, v in encodings.items()
+        }
+        poss_ans_starts, poss_ans_ends = self.model(
+            **encodings, return_dict=False
+        )  # type: ignore
+        best_start = poss_ans_starts.argmax(1)
+        best_ends = poss_ans_ends.argmax(1)
+        answers = [
+            encodings["input_ids"][i][start : end + 1]
+            for i, (start, end) in enumerate(zip(best_start, best_ends))
+        ]
+        answers = self.tokenizer.batch_decode(answers)
+        return answers
+
+
+@dataclass
+class Qsquare(Metric):
+    qa_model_name: str = "consciousAI/question-answering-roberta-base-s"
+    qg_model_name: str = "mrm8488/t5-base-finetuned-question-generation-ap"
+    device: t.Literal["cpu", "cuda"] = "cpu"
+    max_answers: int = 10
+    crosscheck_candidates: bool = True
+    load_single: bool = False
+    batch_size: int = 4
+    include_nouns: bool = True
+    save_results: bool = False
+
+    def __post_init__(
+        self,
+    ):
+        self.nlp = spacy.load(SPACY_MODEL)
+        self.qa = QAGQ.from_pretrained(self.qa_model_name)
+        self.qg = QAGQ.from_pretrained(self.qg_model_name)
+
+    @property
+    def name(self):
+        return "Q^2"
+
+    @property
+    def is_batchable(self):
+        return True
+
+    def generate_candidates(self, text: str):
+        text = text.strip()
+        nouns = [
+            i.text.lower()
+            for i in self.nlp(text).noun_chunks
+            if i.text.lower() not in self.nlp.Defaults.stop_words
+        ]
+        entities = set([ent.text.lower() for ent in self.nlp(text).ents])
+        num_nouns = max(0, self.max_answers - len(entities))
+        nouns = list(np.setdiff1d(nouns, list(entities)))
+        if nouns and self.include_nouns:
+            nouns = np.random.choice(nouns, size=num_nouns).tolist()
+        else:
+            nouns = []
+
+        return list(entities.union(set(nouns)))
+
+    def generate_questions(self, candidates: list[str], context: str, **kwargs):
+        questions = []
+        for idx in range(0, len(candidates), self.batch_size):
+            batch_questions = self.qg.batch_generate_question(
+                candidates[idx : idx + self.batch_size], context, **kwargs
+            )
+            questions.extend(
+                [qstn if qstn.endswith("?") else f"{qstn}?"
+                 for qstn in batch_questions]
+            )
+        assert len(questions) == len(candidates), "Missing question for some candidates"
+        return questions
+
+    def generate_answers(self, questions: list[str], context: str):
+        answers = []
+        for idx in range(0, len(questions), self.batch_size):
+            batch_answers = self.qa.batch_generate_answers(
+                questions[idx : idx + self.batch_size], context
+            )
+            answers.extend(batch_answers)
+        assert len(answers) == len(questions), "Missing answers for some questions"
+        return answers
+
+    def filter_candidates(
+        self, questions: list[str], candidates: list[str], gen_answers: list[str]
+    ):
+        final_questions = []
+        final_candidates = []
+        for qstn, ans1, ans2 in zip(questions, candidates, gen_answers):
+            if self.clean_candidate(ans1) == self.clean_candidate(ans2):
+                final_candidates.append(ans1)
+                final_questions.append(qstn)
+
+        return final_questions, final_candidates
+
+    def clean_candidate(self, text):
+        text = text.strip().lower()
+        text = text.translate(str.maketrans("", "", string.punctuation))
+        text = " ".join(re.sub(r"\b(a|an|the|in|our)\b", " ", text).split())
+
+        return text
+
+    def score_candidates(self, ques_ans_dict: dict):
+        nli = EntailmentScore()
+        for qas in ques_ans_dict.values():
+            for item in qas:
+                item["answer"] = self.clean_candidate(item["answer"])
+                item["predicted_answer"] = self.clean_candidate(
+                    item["predicted_answer"]
+                )
+                if item["answer"] == item["predicted_answer"]:
+                    item.update({"score": 1})
+                else:
+                    qstn = item.get("question")
+                    score_dict = nli.infer(
+                        f'{qstn} {item.get("answer")}',
+                        f'{qstn} {item.get("predicted_answer")}',
+                    )
+                    label = max(zip(score_dict.values(), score_dict.keys()))[1]
+                    item.update({"score": LABEL2SCORE[label]})
+
+        return ques_ans_dict
+
+    def score(self, ground_truth: list[str], generated_text: list[str], **kwargs):
+        gnd_qans = {}
+        ans_candidates = [self.generate_candidates(text) for text in ground_truth]
+        for i, (candidates, context) in enumerate(zip(ans_candidates, ground_truth)):
+            questions = self.generate_questions(candidates, context, **kwargs)
+            gen_answers = self.generate_answers(questions, context)
+            if self.crosscheck_candidates:
+                questions, candidates = self.filter_candidates(
+                    questions, candidates, gen_answers
+                )
+            gnd_qans[i] = [
+                {"question": qstn, "answer": ans}
+                for qstn, ans in zip(questions, candidates)
+            ]
+
+        for i, gen_text in enumerate(generated_text):
+            questions = [item["question"] for item in gnd_qans[i]]
+            gen_answers = self.generate_answers(questions, gen_text)
+            _ = [
+                item.update({"predicted_answer": ans})
+                for item, ans in zip(gnd_qans[i], gen_answers)
+            ]
+
+        del self.qa
+        del self.qg
+
+        gnd_qans = self.score_candidates(gnd_qans)
+
+        if self.save_results:
+            with open("qa-qj-intermediate.json", "w") as file:
+                json.dump(gnd_qans, file, indent=4)
+
+        scores = [[dic["score"] for dic in item] for item in gnd_qans.values()]
+        scores = [sum(sublist) / (len(sublist) + EPS) for sublist in scores]
+        return scores
+
+
+ENTScore = EntailmentScore()
+Q2Score = Qsquare()
diff --git a/belar/utils.py b/belar/utils.py
index 403a1cc98..cc9eb84fa 100644
--- a/belar/utils.py
+++ b/belar/utils.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+
 import typing as t
 from warnings import warn
 
diff --git a/pyproject.toml b/pyproject.toml
index e1af83e8d..2c84c3a99 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,6 +8,7 @@ dependencies = [
     "sentence-transformers",
     "nltk",
     "datasets",
+    "spacy",
 ]
 
 dynamic = ["version", "readme"]
diff --git a/tests/benchmarks/benchmark.py b/tests/benchmarks/benchmark.py
index 12732828b..50d63beb5 100644
--- a/tests/benchmarks/benchmark.py
+++ b/tests/benchmarks/benchmark.py
@@ -5,16 +5,8 @@
 from tqdm import tqdm
 from utils import print_table, timeit
 
-from belar.metrics import (
-    EditDistance,
-    EditRatio,
-    EntailmentScore,
-    Evaluation,
-    Rouge1,
-    Rouge2,
-    RougeL,
-    SBERTScore,
-)
+from belar.metrics import (EditDistance, EditRatio, EntailmentScore,
+                           Evaluation, Rouge1, Rouge2, RougeL, SBERTScore)
 
 DEVICE = "cuda" if is_available() else "cpu"
 BATCHES = [0, 1]
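
Reviewer note (not part of the diff): a minimal usage sketch of the new Qsquare metric added in belar/metrics/factual.py, assuming the spaCy model en_core_web_sm is installed and the default QG/QA checkpoints can be downloaded from the Hugging Face Hub on first use. The example strings are invented.

    from belar.metrics.factual import Qsquare

    # __post_init__ loads spaCy plus the question-generation and
    # question-answering models, so construction is the slow step.
    q2 = Qsquare(device="cpu", max_answers=5)

    # One ground-truth passage and one generated answer (hypothetical).
    ground_truth = ["Paris is the capital and largest city of France."]
    generated = ["The capital of France is Paris."]

    # Returns one score per pair, in [0, 1]. Note that score() deletes the
    # QG/QA models when it finishes, so reuse requires a fresh instance.
    print(q2.score(ground_truth, generated))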