Skip to content

Commit

Permalink
Util: Label-based prediction output.
Browse files Browse the repository at this point in the history
  • Loading branch information
janothan committed Feb 1, 2021
1 parent b004138 commit 22e3446
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 7 deletions.
18 changes: 13 additions & 5 deletions kbc_evaluation/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def __init__(
self,
file_to_be_evaluated: str,
is_apply_filtering: bool = False,
triples_to_read: int = None,
is_stop_early: bool = True,
):
"""Constructor. Note that the file is immediately parsed.
Expand All @@ -204,13 +204,15 @@ def __init__(
Path to the file that shall be evaluated (in the prediction file format).
is_apply_filtering : bool
True if filtering shall be applied as described in Bordes et al.
triples_to_read : int
Optional limit on the number of triples to be read. If None, all triples will be read.
is_stop_early : bool
True by default. Stop parsing once the correct prediction has been found. This greatly reduces memory and
disk consumption. In some cases (debugging, analyzing results), it may make sense not to stop early.
"""
self.file_to_be_evaluated = file_to_be_evaluated
self.is_apply_filtering = is_apply_filtering
self.total_prediction_tasks = 0
self.triple_predictions = {}
self.is_stop_early = is_stop_early

# initialize lookup datastructures for filtering (contains only true statements)
# the maps are typed as follows:
Expand Down Expand Up @@ -266,7 +268,10 @@ def _apply_filtering(self) -> None:
for predicted_head in heads:
if predicted_head == truth[0]:
new_heads.append(predicted_head)
break # before continue
if self.is_stop_early:
break
else:
continue
if predicted_head not in correct_heads:
new_heads.append(predicted_head)

Expand All @@ -280,7 +285,10 @@ def _apply_filtering(self) -> None:
for predicted_tail in tails:
if predicted_tail == truth[2]:
new_tails.append(predicted_tail)
break # before continue
if self.is_stop_early:
break
else:
continue
if predicted_tail not in correct_tails:
new_tails.append(predicted_tail)

Expand Down
25 changes: 24 additions & 1 deletion kbc_evaluation/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,31 @@ def write_sample_predictions(
top_predictions: int = 10,
number_of_triples: int = 100,
) -> None:
"""Method to write human-understandable predictions.
Parameters
----------
prediction_file : str
The file with the predictions.
file_to_be_written : str
The evaluation file that shall be written.
data_set : DataSet
The dataset that is used.
is_apply_filtering : bool
True if filtered predictions shall be shown, else false.
top_predictions : int
The number of top-ranked predictions (top N) to include in the output.
number_of_triples : int
The number of triples to be evaluated (in most cases 100 or 1000 may be sufficient).
Returns
-------
None
"""
predictions_set = ParsedSet(
file_to_be_evaluated=prediction_file, is_apply_filtering=is_apply_filtering
file_to_be_evaluated=prediction_file,
is_apply_filtering=is_apply_filtering,
is_stop_early=False,
)

definitions_map = data_set.definitions_map()
Expand Down
3 changes: 2 additions & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ def test_definitions_file(self):
fb15k_map = DataSet.FB15K.definitions_map()
assert fb15k_map is not None

# concept check
# concept check wn
assert wn_map["08293982"][0] == "__coalition_NN_1"
assert (
wn_map["08293982"][1]
== "an organization of people (or countries) involved in a pact or treaty"
)

# concept check fb15k
assert fb15k_map["/m/0102t4"][0] == "Marshall"
assert fb15k_map["/m/0102t4"][1] == "city in Texas, USA"

Expand Down

0 comments on commit 22e3446

Please sign in to comment.