Skip to content

Commit

Permalink
fix(metrics): fix simple_accuracy for nested lists of diff lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
tholor committed Aug 2, 2019
1 parent 9d5a64f commit 012e158
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
11 changes: 6 additions & 5 deletions farm/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
from scipy.stats import pearsonr, spearmanr
from seqeval.metrics import f1_score as seq_f1_score
from sklearn.metrics import matthews_corrcoef, f1_score

from farm.utils import flatten_list

def simple_accuracy(preds, labels):
# TODO: THIS HACKY TRY CATCH IS FOR GNAD
try:
preds = np.array(preds)
labels = np.array(labels)
correct = preds == labels
# works also with nested lists of different lengths (needed for masked LM task)
flat_preds = np.array(list(flatten_list(preds)))
flat_labels = np.array(list(flatten_list(labels)))
correct = flat_preds == flat_labels
return {"acc": correct.mean()}
except TypeError:
# TODO: THIS HACKY TRY CATCH IS FOR GNAD
return {"acc": (preds == labels.numpy()).mean()}


Expand Down
20 changes: 20 additions & 0 deletions farm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import torch
import mlflow
from copy import deepcopy

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -161,3 +162,22 @@ def convert_iob_to_simple_tags(preds, spans):
simple_tags.append(cur_tag)
open_tag = False
return simple_tags, merged_spans


def flatten_list(nested_list):
"""Flatten an arbitrarily nested list, without recursion (to avoid
stack overflows). Returns a new list, the original list is unchanged.
>> list(flatten_list([1, 2, 3, [4], [], [[[[[[[[[5]]]]]]]]]]))
[1, 2, 3, 4, 5]
>> list(flatten_list([[1, 2], 3]))
[1, 2, 3]
"""
nested_list = deepcopy(nested_list)

while nested_list:
sublist = nested_list.pop(0)

if isinstance(sublist, list):
nested_list = sublist + nested_list
else:
yield sublist

0 comments on commit 012e158

Please sign in to comment.