Skip to content

Commit

Permalink
Fix in result data structure, better unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
janothan committed Jan 15, 2021
1 parent fcb83e4 commit 77da502
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
2 changes: 1 addition & 1 deletion kbc_evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def write_result_object_to_file(
+ f"Relative Hits at {result_object.n}: {result_object.non_filtered_hits_at_n_relative}\n"
+ f"Mean rank (Heads): {result_object.non_filtered_mean_rank_heads}\n"
+ f"Mean rank (Tails): {result_object.non_filtered_mean_rank_tails}\n"
+ f"Mean rank (All): {result_object.evaluated_file}\n"
+ f"Mean rank (All): {result_object.non_filtered_mean_rank_all}\n"
)

filtered_text = (
Expand Down
50 changes: 30 additions & 20 deletions tests/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ def test_hits_at(self):
os.chdir("./..")
assert os.path.isfile(test_file_path)

evaluator = EvaluationRunner(file_to_be_evaluated=test_file_path)
assert evaluator.calculate_hits_at(1)[2] == 2
assert evaluator.calculate_hits_at(3)[2] == 3
assert evaluator.calculate_hits_at(10)[2] == 4
runner = EvaluationRunner(file_to_be_evaluated=test_file_path)
assert runner.calculate_hits_at(1)[2] == 2
assert runner.calculate_hits_at(3)[2] == 3
assert runner.calculate_hits_at(10)[2] == 4

def test_hits_at_with_confidence(self):
test_file_path = "./tests/test_resources/eval_test_file_with_confidences.txt"
if not os.path.isfile(test_file_path):
os.chdir("./..")
assert os.path.isfile(test_file_path)

evaluator = EvaluationRunner(file_to_be_evaluated=test_file_path)
assert evaluator.calculate_hits_at(1)[2] == 2
assert evaluator.calculate_hits_at(3)[2] == 3
assert evaluator.calculate_hits_at(10)[2] == 4
runner = EvaluationRunner(file_to_be_evaluated=test_file_path)
assert runner.calculate_hits_at(1)[2] == 2
assert runner.calculate_hits_at(3)[2] == 3
assert runner.calculate_hits_at(10)[2] == 4

def test_calculate_results_no_filtering(self):
test_file_path = "./tests/test_resources/eval_test_file_with_confidences.txt"
Expand All @@ -45,18 +45,28 @@ def test_calculate_results_no_filtering(self):
file_to_be_evaluated=test_file_path, data_set=DataSet.WN18, n=1
)
assert results.filtered_hits_at_n_all == 2
assert results.filtered_hits_at_n_all >= results.filtered_hits_at_n_all
assert results.n == 1

# simple type assertions
assert type(results.evaluated_file) == str
assert type(results.n) == int
assert type(results.filtered_hits_at_n_heads) == int
assert type(results.filtered_hits_at_n_tails) == int
assert type(results.filtered_hits_at_n_all) == int

results = Evaluator.calculate_results(
file_to_be_evaluated=test_file_path, data_set=DataSet.WN18, n=3
)
assert results.filtered_hits_at_n_all == 3
assert results.filtered_hits_at_n_all >= results.filtered_hits_at_n_all
assert results.n == 3

results = Evaluator.calculate_results(
file_to_be_evaluated=test_file_path, data_set=DataSet.WN18, n=10
)
assert results.filtered_hits_at_n_all == 4
assert results.filtered_hits_at_n_all >= results.filtered_hits_at_n_all
assert results.n == 10

def test_hits_at_filtering(self):
Expand All @@ -65,12 +75,12 @@ def test_hits_at_filtering(self):
os.chdir("./..")
assert os.path.isfile(test_file_path)

evaluator = EvaluationRunner(
runner = EvaluationRunner(
file_to_be_evaluated=test_file_path, is_apply_filtering=True
)
assert evaluator.calculate_hits_at(1)[2] == 2
assert evaluator.calculate_hits_at(3)[2] == 4
assert evaluator.calculate_hits_at(10)[2] == 6
assert runner.calculate_hits_at(1)[2] == 2
assert runner.calculate_hits_at(3)[2] == 4
assert runner.calculate_hits_at(10)[2] == 6

def test_hits_at_filtering_with_confidence(self):
test_file_path = (
Expand All @@ -80,24 +90,24 @@ def test_hits_at_filtering_with_confidence(self):
os.chdir("./..")
assert os.path.isfile(test_file_path)

evaluator = EvaluationRunner(
runner = EvaluationRunner(
file_to_be_evaluated=test_file_path, is_apply_filtering=True
)
assert evaluator.calculate_hits_at(1)[2] == 2
assert evaluator.calculate_hits_at(3)[2] == 4
assert evaluator.calculate_hits_at(10)[2] == 6
assert runner.calculate_hits_at(1)[2] == 2
assert runner.calculate_hits_at(3)[2] == 4
assert runner.calculate_hits_at(10)[2] == 6

def test_mean_rank(self):
test_file_path = "./tests/test_resources/eval_test_file.txt"
assert os.path.isfile(test_file_path)
evaluator = EvaluationRunner(file_to_be_evaluated=test_file_path)
assert evaluator.mean_rank()[2] == 3
runner = EvaluationRunner(file_to_be_evaluated=test_file_path)
assert runner.mean_rank()[2] == 3

def test_mean_rank_with_confidence(self):
test_file_path = "./tests/test_resources/eval_test_file_with_confidences.txt"
assert os.path.isfile(test_file_path)
evaluator = EvaluationRunner(file_to_be_evaluated=test_file_path)
assert evaluator.mean_rank()[2] == 3
runner = EvaluationRunner(file_to_be_evaluated=test_file_path)
assert runner.mean_rank()[2] == 3

def test_write_results_to_file(self):
test_file_path = "./tests/test_resources/eval_test_file.txt"
Expand Down

0 comments on commit 77da502

Please sign in to comment.