fix(tests): fix attribute error
Sharad Sirsat committed Jul 5, 2023
1 parent 1dcd6fa commit 3f37077
Showing 2 changed files with 44 additions and 4 deletions.
26 changes: 23 additions & 3 deletions mmocr/evaluation/evaluator/multi_datasets_evaluator.py
@@ -92,9 +92,29 @@ def evaluate(self, size: int) -> dict:
metrics_results.update(metric_results)
metric.results.clear()
if is_main_process():
metrics_results = [metrics_results]
averaged_results = self.average_results(metrics_results)
else:
metrics_results = [None] # type: ignore
averaged_results = None

metrics_results = [metrics_results]
broadcast_object_list(metrics_results)
broadcast_object_list([averaged_results])

return metrics_results[0], averaged_results

def average_results(self, metrics_results):
"""Compute the average of metric results across all datasets.
Args:
metrics_results (dict): Evaluation results of all metrics.
Returns:
dict: Average evaluation results of all metrics.
"""
averaged_results = {}
num_datasets = len(self.dataset_prefixes)
for metric_name, metric_result in metrics_results.items():
metric_avg = metric_result / num_datasets
averaged_results[metric_name] = metric_avg

return metrics_results[0]
return averaged_results
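
For quick reference, a minimal standalone sketch of what the new average_results helper computes — not code from the commit, and the values below are illustrative assumptions: each accumulated metric value is divided by the number of entries in dataset_prefixes.

# Illustrative sketch of the averaging step added above (assumed values, not from the commit).
dataset_prefixes = ['Fake/Toy', 'Fake']               # assumed: two evaluated datasets
metrics_results = {'accuracy': 1.8, 'f1_score': 1.5}  # assumed: per-metric totals to average

num_datasets = len(dataset_prefixes)
averaged_results = {name: value / num_datasets
                    for name, value in metrics_results.items()}
print(averaged_results)  # {'accuracy': 0.9, 'f1_score': 0.75}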
22 changes: 21 additions & 1 deletion
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

import math
from collections import OrderedDict
from typing import Dict, List, Optional
from unittest import TestCase

@@ -75,7 +76,7 @@ def generate_test_results(size, batch_size, pred, label):
predictions = [
BaseDataElement(pred=pred, label=label) for _ in range(bs)
]
yield (data_batch, predictions)
yield data_batch, predictions


class TestMultiDatasetsEvaluator(TestCase):
@@ -126,3 +127,22 @@ def test_composed_metrics(self):
metrics = evaluator.evaluate(size=size)
self.assertIn('Fake/Toy/accuracy', metrics)
self.assertIn('Fake/accuracy', metrics)

metrics_results = OrderedDict({
'dataset1/metric1/accuracy': 0.9,
'dataset1/metric2/f1_score': 0.8,
'dataset2/metric1/accuracy': 0.85,
'dataset2/metric2/f1_score': 0.75
})

evaluator = MultiDatasetsEvaluator([], [])
averaged_results = evaluator.average_results(metrics_results)

expected_averaged_results = {
'dataset1/metric1/accuracy': 0.9,
'dataset1/metric2/f1_score': 0.8,
'dataset2/metric1/accuracy': 0.85,
'dataset2/metric2/f1_score': 0.75
}

self.assertEqual(averaged_results, expected_averaged_results)
