-
Notifications
You must be signed in to change notification settings - Fork 377
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Also, extracted some eval code that is shared between classification and object detection.
- Loading branch information
Showing
8 changed files
with
157 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
class EvaluationItem(object): | ||
"""Evaluation metrics for a single class.""" | ||
def __init__(self, precision, recall, f1, gt_count=None, | ||
class_id=None, class_name=None): | ||
self.precision = precision | ||
self.recall = recall | ||
self.f1 = f1 | ||
|
||
self.gt_count = gt_count | ||
self.class_id = class_id | ||
self.class_name = class_name | ||
|
||
def merge(self, other): | ||
total_gt_count = self.gt_count + other.gt_count | ||
self_ratio = 0 | ||
other_ratio = 0 | ||
if total_gt_count > 0: | ||
self_ratio = self.gt_count / total_gt_count | ||
other_ratio = other.gt_count / total_gt_count | ||
|
||
def avg(self_val, other_val): | ||
return self_ratio * self_val + other_ratio * other_val | ||
|
||
self.precision = avg(self.precision, other.precision) | ||
self.recall = avg(self.recall, other.recall) | ||
self.f1 = avg(self.f1, other.f1) | ||
self.gt_count = total_gt_count | ||
|
||
def to_json(self): | ||
return self.__dict__ |
Empty file.
5 changes: 5 additions & 0 deletions
5
src/rastervision/evaluation_items/classification_evaluation_item.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from rastervision.core.evaluation_item import EvaluationItem | ||
|
||
|
||
class ClassificationEvaluationItem(EvaluationItem): | ||
pass |
22 changes: 22 additions & 0 deletions
22
src/rastervision/evaluation_items/object_detection_evaluation_item.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from rastervision.core.evaluation_item import EvaluationItem | ||
|
||
|
||
class ObjectDetectionEvaluationItem(EvaluationItem): | ||
def __init__(self, precision, recall, f1, count_error, gt_count=None, | ||
class_id=None, class_name=None): | ||
super().__init__( | ||
precision, recall, f1, gt_count=gt_count, class_id=class_id, | ||
class_name=class_name) | ||
self.count_error = count_error | ||
|
||
def merge(self, other): | ||
super().merge(other) | ||
|
||
total_gt_count = self.gt_count + other.gt_count | ||
self_ratio = self.gt_count / total_gt_count | ||
other_ratio = other.gt_count / total_gt_count | ||
|
||
def avg(self_val, other_val): | ||
return self_ratio * self_val + other_ratio * other_val | ||
|
||
self.count_error = avg(self.count_error, other.count_error) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,58 @@ | ||
import numpy as np | ||
from sklearn import metrics | ||
|
||
from rastervision.core.evaluation import Evaluation | ||
from rastervision.evaluation_items.classification_evaluation_item import ( | ||
ClassificationEvaluationItem) | ||
|
||
|
||
def compute_eval_items(gt_labels, pred_labels, class_map): | ||
nb_classes = len(class_map) | ||
class_to_eval_item = {} | ||
|
||
gt_class_ids = [] | ||
pred_class_ids = [] | ||
|
||
gt_cells = gt_labels.get_cells() | ||
for gt_cell in gt_cells: | ||
gt_class_id = gt_labels.get_cell_class_id(gt_cell) | ||
pred_class_id = pred_labels.get_cell_class_id(gt_cell) | ||
|
||
if gt_class_id is not None and pred_class_id is not None: | ||
gt_class_ids.append(gt_class_id) | ||
pred_class_ids.append(pred_class_id) | ||
|
||
# Add 1 because class_ids start at 1. | ||
sklabels = np.arange(1 + nb_classes) | ||
precision, recall, f1, support = metrics.precision_recall_fscore_support( | ||
gt_class_ids, pred_class_ids, labels=sklabels, warn_for=()) | ||
|
||
for class_map_item in class_map.get_items(): | ||
class_id = class_map_item.id | ||
class_name = class_map_item.name | ||
|
||
eval_item = ClassificationEvaluationItem( | ||
float(precision[class_id]), float(recall[class_id]), | ||
float(f1[class_id]), gt_count=float(support[class_id]), | ||
class_id=class_id, class_name=class_name) | ||
class_to_eval_item[class_id] = eval_item | ||
|
||
return class_to_eval_item | ||
|
||
|
||
class ClassificationEvaluation(Evaluation): | ||
def clear(self): | ||
pass | ||
def compute(self, class_map, ground_truth_label_store, | ||
prediction_label_store): | ||
gt_labels = ground_truth_label_store.get_all_labels() | ||
pred_labels = prediction_label_store.get_all_labels() | ||
|
||
def compute(ground_truth_label_store, prediction_label_store): | ||
pass | ||
self.class_to_eval_item = compute_eval_items( | ||
gt_labels, pred_labels, class_map) | ||
|
||
def merge(self, evaluation): | ||
pass | ||
self.compute_avg() | ||
|
||
def save(self, output_uri): | ||
pass | ||
def compute_avg(self): | ||
self.avg_item = ClassificationEvaluationItem( | ||
0, 0, 0, gt_count=0, class_name='average') | ||
for eval_item in self.class_to_eval_item.values(): | ||
self.avg_item.merge(eval_item) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters