diff --git a/belar/metrics/__init__.py b/belar/metrics/__init__.py index 70fe4a3e5..e9bef643f 100644 --- a/belar/metrics/__init__.py +++ b/belar/metrics/__init__.py @@ -1,8 +1,7 @@ from belar.metrics.base import Evaluation, Metric from belar.metrics.factual import EntailmentScore from belar.metrics.similarity import SBERTScore -from belar.metrics.simple import (BLUE, EditDistance, EditRatio, Rouge1, - Rouge2, RougeL) +from belar.metrics.simple import BLUE, EditDistance, EditRatio, Rouge1, Rouge2, RougeL __all__ = [ "Evaluation", diff --git a/belar/metrics/factual.py b/belar/metrics/factual.py index 8999f6d9f..670a15c91 100644 --- a/belar/metrics/factual.py +++ b/belar/metrics/factual.py @@ -49,12 +49,16 @@ def is_batchable(self): def batch_infer(self, inputs: dict): predictions = [] - input_ids = inputs["input_ids"] + input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"] label2id = {value.lower(): key for key, value in self.id2label.items()} for idx in range(0, len(input_ids), self.batch_size): - batch_ids = input_ids[idx : idx + self.batch_size] - output = self.model(batch_ids.to(self.device)) + batch_inp_ids = input_ids[idx : idx + self.batch_size] + batch_attn_mask = attention_mask[idx : idx + self.batch_size] + + output = self.model( + batch_inp_ids.to(self.device), batch_attn_mask.to(self.device) + ) pred = output.logits.softmax(axis=-1).detach().cpu() predictions.extend(pred[:, label2id["entailment"]].tolist()) diff --git a/belar/utils.py b/belar/utils.py index 913e827b1..403a1cc98 100644 --- a/belar/utils.py +++ b/belar/utils.py @@ -1,10 +1,9 @@ +from __future__ import annotations import typing as t from warnings import warn import torch - -if t.TYPE_CHECKING: - from torch import device as Device +from torch import device as Device DEVICES = ["cpu", "cuda"] diff --git a/tests/benchmarks/benchmark.py b/tests/benchmarks/benchmark.py index 50d63beb5..12732828b 100644 --- a/tests/benchmarks/benchmark.py +++ b/tests/benchmarks/benchmark.py @@ -5,8 +5,16 @@ from tqdm import tqdm from utils import print_table, timeit -from belar.metrics import (EditDistance, EditRatio, EntailmentScore, - Evaluation, Rouge1, Rouge2, RougeL, SBERTScore) +from belar.metrics import ( + EditDistance, + EditRatio, + EntailmentScore, + Evaluation, + Rouge1, + Rouge2, + RougeL, + SBERTScore, +) DEVICE = "cuda" if is_available() else "cpu" BATCHES = [0, 1]