Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions belar/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from belar.metrics.base import Evaluation, Metric
from belar.metrics.factual import EntailmentScore
from belar.metrics.similarity import SBERTScore
from belar.metrics.simple import (BLUE, EditDistance, EditRatio, Rouge1,
Rouge2, RougeL)
from belar.metrics.simple import BLUE, EditDistance, EditRatio, Rouge1, Rouge2, RougeL

__all__ = [
"Evaluation",
Expand Down
10 changes: 7 additions & 3 deletions belar/metrics/factual.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,16 @@ def is_batchable(self):

def batch_infer(self, inputs: dict):
predictions = []
input_ids = inputs["input_ids"]
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
label2id = {value.lower(): key for key, value in self.id2label.items()}

for idx in range(0, len(input_ids), self.batch_size):
batch_ids = input_ids[idx : idx + self.batch_size]
output = self.model(batch_ids.to(self.device))
batch_inp_ids = input_ids[idx : idx + self.batch_size]
batch_attn_mask = attention_mask[idx : idx + self.batch_size]

output = self.model(
batch_inp_ids.to(self.device), batch_attn_mask.to(self.device)
)
pred = output.logits.softmax(axis=-1).detach().cpu()
predictions.extend(pred[:, label2id["entailment"]].tolist())

Expand Down
5 changes: 2 additions & 3 deletions belar/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations
import typing as t
from warnings import warn

import torch

if t.TYPE_CHECKING:
from torch import device as Device
from torch import device as Device

DEVICES = ["cpu", "cuda"]

Expand Down
12 changes: 10 additions & 2 deletions tests/benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@
from tqdm import tqdm
from utils import print_table, timeit

from belar.metrics import (EditDistance, EditRatio, EntailmentScore,
Evaluation, Rouge1, Rouge2, RougeL, SBERTScore)
from belar.metrics import (
EditDistance,
EditRatio,
EntailmentScore,
Evaluation,
Rouge1,
Rouge2,
RougeL,
SBERTScore,
)

DEVICE = "cuda" if is_available() else "cpu"
BATCHES = [0, 1]
Expand Down