Support eval for classification
Also, extracted some eval code that is shared between
classification and object detection.
lewfish committed Mar 30, 2018
1 parent 99c7f57 commit 871f2f9
Showing 8 changed files with 157 additions and 87 deletions.
56 changes: 41 additions & 15 deletions src/rastervision/core/evaluation.py
@@ -1,27 +1,39 @@
 from abc import ABC, abstractmethod
+import json
+
+from rastervision.utils.files import str_to_file
 
 
 class Evaluation(ABC):
     """An evaluation of the predictions for a set of projects."""
 
-    @abstractmethod
+    def __init__(self):
+        self.clear()
+
     def clear(self):
         """Clear the Evaluation."""
-        pass
+        self.class_to_eval_item = {}
+        self.avg_item = None
 
-    @abstractmethod
-    def compute(ground_truth_label_store, prediction_label_store):
-        """Compute metrics for a single project.
-
-        Args:
-            ground_truth_label_store: LabelStore with the ground
-                truth
-            prediction_label_store: LabelStore with the
-                corresponding predictions
-        """
-        pass
+    def get_by_id(self, class_id):
+        return self.class_to_eval_item[class_id]
+
+    def to_json(self):
+        json_rep = []
+        for eval_item in self.class_to_eval_item.values():
+            json_rep.append(eval_item.to_json())
+        json_rep.append(self.avg_item.to_json())
+        return json_rep
+
+    def save(self, output_uri):
+        """Save this Evaluation to a file.
+
+        Args:
+            output_uri: string URI for the file to write.
+        """
+        json_str = json.dumps(self.to_json(), indent=4)
+        str_to_file(json_str, output_uri)
 
     @abstractmethod
     def merge(self, evaluation):
         """Merge Evaluation for another Project into this one.
@@ -31,13 +43,27 @@ def merge(self, evaluation):
         Args:
             evaluation: Evaluation to merge into this one
         """
-        pass
+        if len(self.class_to_eval_item) == 0:
+            self.class_to_eval_item = evaluation.class_to_eval_item
+        else:
+            for class_id, other_eval_item in evaluation.class_to_eval_item.items():
+                self.get_by_id(class_id).merge(other_eval_item)
+
+        self.compute_avg()
 
     @abstractmethod
-    def save(self, output_uri):
-        """Save this Evaluation to a file.
+    def compute(self, ground_truth_label_store, prediction_label_store):
+        """Compute metrics for a single project.
 
         Args:
-            output_uri: string URI for the file to write.
+            ground_truth_label_store: LabelStore with the ground
+                truth
+            prediction_label_store: LabelStore with the
+                corresponding predictions
        """
         pass
+
+    @abstractmethod
+    def compute_avg(self):
+        """Compute average metrics over classes."""
+        pass
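
With this change, Evaluation becomes a template: subclasses implement compute() and compute_avg(), while merge(), to_json(), and save() are shared. A minimal usage sketch, combining per-project evaluations via the ClassificationEvaluation subclass added below (the projects list, class_map, and output URI are hypothetical, not from this commit):

    # Hedged sketch of the intended workflow; projects, class_map, and
    # the URI are made up. Only the Evaluation API comes from this commit.
    combined = ClassificationEvaluation()
    for gt_store, pred_store in projects:
        project_eval = ClassificationEvaluation()
        project_eval.compute(class_map, gt_store, pred_store)
        combined.merge(project_eval)  # gt_count-weighted per class
    combined.save('s3://my-bucket/eval.json')  # json.dumps + str_to_file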
30 changes: 30 additions & 0 deletions src/rastervision/core/evaluation_item.py
@@ -0,0 +1,30 @@
+class EvaluationItem(object):
+    """Evaluation metrics for a single class."""
+    def __init__(self, precision, recall, f1, gt_count=None,
+                 class_id=None, class_name=None):
+        self.precision = precision
+        self.recall = recall
+        self.f1 = f1
+
+        self.gt_count = gt_count
+        self.class_id = class_id
+        self.class_name = class_name
+
+    def merge(self, other):
+        total_gt_count = self.gt_count + other.gt_count
+        self_ratio = 0
+        other_ratio = 0
+        if total_gt_count > 0:
+            self_ratio = self.gt_count / total_gt_count
+            other_ratio = other.gt_count / total_gt_count
+
+        def avg(self_val, other_val):
+            return self_ratio * self_val + other_ratio * other_val
+
+        self.precision = avg(self.precision, other.precision)
+        self.recall = avg(self.recall, other.recall)
+        self.f1 = avg(self.f1, other.f1)
+        self.gt_count = total_gt_count
+
+    def to_json(self):
+        return self.__dict__
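
EvaluationItem.merge() is a gt_count-weighted average, so metrics from projects with more ground truth dominate. A quick worked example (hypothetical numbers, not part of the commit):

    # Two per-project items for the same class.
    a = EvaluationItem(0.9, 0.8, 0.85, gt_count=30)
    b = EvaluationItem(0.5, 0.6, 0.55, gt_count=10)
    a.merge(b)
    # Weights are 30/40 = 0.75 and 10/40 = 0.25, so:
    # a.precision = 0.75 * 0.9  + 0.25 * 0.5  = 0.8
    # a.recall    = 0.75 * 0.8  + 0.25 * 0.6  = 0.75
    # a.f1        = 0.75 * 0.85 + 0.25 * 0.55 = 0.775
    # a.gt_count  = 40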
0 changes: 0 additions & 0 deletions src/rastervision/evaluation_items/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions src/rastervision/evaluation_items/classification_evaluation_item.py
@@ -0,0 +1,5 @@
+from rastervision.core.evaluation_item import EvaluationItem
+
+
+class ClassificationEvaluationItem(EvaluationItem):
+    pass
22 changes: 22 additions & 0 deletions src/rastervision/evaluation_items/object_detection_evaluation_item.py
@@ -0,0 +1,22 @@
+from rastervision.core.evaluation_item import EvaluationItem
+
+
+class ObjectDetectionEvaluationItem(EvaluationItem):
+    def __init__(self, precision, recall, f1, count_error, gt_count=None,
+                 class_id=None, class_name=None):
+        super().__init__(
+            precision, recall, f1, gt_count=gt_count, class_id=class_id,
+            class_name=class_name)
+        self.count_error = count_error
+
+    def merge(self, other):
+        super().merge(other)
+
+        total_gt_count = self.gt_count + other.gt_count
+        self_ratio = self.gt_count / total_gt_count
+        other_ratio = other.gt_count / total_gt_count
+
+        def avg(self_val, other_val):
+            return self_ratio * self_val + other_ratio * other_val
+
+        self.count_error = avg(self.count_error, other.count_error)
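
One subtlety worth noting: by the time the ratios above are computed, super().merge(other) has already folded other.gt_count into self.gt_count, so other's count is effectively added twice, count_error is weighted differently from precision/recall/f1, and the base class's guard against total_gt_count == 0 is bypassed. A sketch of a variant that keeps the weights consistent, assuming that is the intent (not part of the commit):

    def merge(self, other):
        # Snapshot before the base-class merge mutates self.gt_count.
        self_gt_count = self.gt_count
        super().merge(other)
        total = self_gt_count + other.gt_count
        if total > 0:
            self.count_error = (
                self_gt_count / total * self.count_error +
                other.gt_count / total * other.count_error)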
59 changes: 51 additions & 8 deletions src/rastervision/evaluations/classification_evaluation.py
@@ -1,15 +1,58 @@
+import numpy as np
+from sklearn import metrics
+
 from rastervision.core.evaluation import Evaluation
+from rastervision.evaluation_items.classification_evaluation_item import (
+    ClassificationEvaluationItem)
+
+
+def compute_eval_items(gt_labels, pred_labels, class_map):
+    nb_classes = len(class_map)
+    class_to_eval_item = {}
+
+    gt_class_ids = []
+    pred_class_ids = []
+
+    gt_cells = gt_labels.get_cells()
+    for gt_cell in gt_cells:
+        gt_class_id = gt_labels.get_cell_class_id(gt_cell)
+        pred_class_id = pred_labels.get_cell_class_id(gt_cell)
+
+        if gt_class_id is not None and pred_class_id is not None:
+            gt_class_ids.append(gt_class_id)
+            pred_class_ids.append(pred_class_id)
+
+    # Add 1 because class_ids start at 1.
+    sklabels = np.arange(1 + nb_classes)
+    precision, recall, f1, support = metrics.precision_recall_fscore_support(
+        gt_class_ids, pred_class_ids, labels=sklabels, warn_for=())
+
+    for class_map_item in class_map.get_items():
+        class_id = class_map_item.id
+        class_name = class_map_item.name
+
+        eval_item = ClassificationEvaluationItem(
+            float(precision[class_id]), float(recall[class_id]),
+            float(f1[class_id]), gt_count=float(support[class_id]),
+            class_id=class_id, class_name=class_name)
+        class_to_eval_item[class_id] = eval_item
+
+    return class_to_eval_item
 
 
 class ClassificationEvaluation(Evaluation):
-    def clear(self):
-        pass
+    def compute(self, class_map, ground_truth_label_store,
+                prediction_label_store):
+        gt_labels = ground_truth_label_store.get_all_labels()
+        pred_labels = prediction_label_store.get_all_labels()
 
-    def compute(ground_truth_label_store, prediction_label_store):
-        pass
+        self.class_to_eval_item = compute_eval_items(
+            gt_labels, pred_labels, class_map)
 
-    def merge(self, evaluation):
-        pass
+        self.compute_avg()
 
-    def save(self, output_uri):
-        pass
+    def compute_avg(self):
+        self.avg_item = ClassificationEvaluationItem(
+            0, 0, 0, gt_count=0, class_name='average')
+        for eval_item in self.class_to_eval_item.values():
+            self.avg_item.merge(eval_item)
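
The sklabels array includes an unused 0 slot so that precision[class_id] indexes directly by class id (ids start at 1), and warn_for=() silences sklearn's warnings for classes with no samples. A small standalone check with made-up labels (not from the commit):

    import numpy as np
    from sklearn import metrics

    gt_class_ids = [1, 1, 2, 2]
    pred_class_ids = [1, 2, 2, 2]
    sklabels = np.arange(1 + 2)  # [0, 1, 2]; index 0 is padding
    precision, recall, f1, support = metrics.precision_recall_fscore_support(
        gt_class_ids, pred_class_ids, labels=sklabels, warn_for=())
    # precision ~= [0.0, 1.0, 0.667], recall = [0.0, 0.5, 1.0]
    # support = [0, 2, 2], which is where gt_count comes from above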
68 changes: 5 additions & 63 deletions src/rastervision/evaluations/object_detection_evaluation.py
@@ -4,36 +4,8 @@
 
 from rastervision.core.evaluation import Evaluation
 from rastervision.utils.files import str_to_file
-
-
-class EvaluationItem(object):
-    def __init__(self, precision, recall, f1, count_error,
-                 gt_count=None, class_id=None, class_name=None):
-        self.precision = precision
-        self.recall = recall
-        self.f1 = f1
-        self.count_error = count_error
-
-        self.gt_count = gt_count
-        self.class_id = class_id
-        self.class_name = class_name
-
-    def merge(self, other):
-        total_gt_count = self.gt_count + other.gt_count
-        self_ratio = self.gt_count / total_gt_count
-        other_ratio = other.gt_count / total_gt_count
-
-        def avg(self_val, other_val):
-            return self_ratio * self_val + other_ratio * other_val
-
-        self.precision = avg(self.precision, other.precision)
-        self.recall = avg(self.recall, other.recall)
-        self.f1 = avg(self.f1, other.f1)
-        self.count_error = avg(self.count_error, other.count_error)
-        self.gt_count = total_gt_count
-
-    def to_json(self):
-        return self.__dict__
+from rastervision.evaluation_items.object_detection_evaluation_item import (
+    ObjectDetectionEvaluationItem)
 
 
 def compute_od_eval(ground_truth_labels, prediction_labels):
@@ -71,7 +43,7 @@ def parse_od_eval(od_eval, class_map):
         count_error = pred_count - gt_count
         norm_count_error = count_error / gt_count
 
-        eval_item = EvaluationItem(
+        eval_item = ObjectDetectionEvaluationItem(
             precision, recall, f1, norm_count_error, gt_count=gt_count,
             class_id=class_id, class_name=class_name)
         class_to_eval_item[class_id] = eval_item
@@ -80,16 +52,6 @@ def parse_od_eval(od_eval, class_map):
 
 
 class ObjectDetectionEvaluation(Evaluation):
-    def __init__(self):
-        self.clear()
-
-    def clear(self):
-        self.class_to_eval_item = {}
-        self.avg_item = None
-
-    def get_by_id(self, class_id):
-        return self.class_to_eval_item[class_id]
-
     def compute(self, class_map, ground_truth_label_store,
                 prediction_label_store):
         gt_labels = ground_truth_label_store.get_all_labels()
@@ -101,27 +63,7 @@ def compute(self, class_map,
         self.compute_avg()
 
     def compute_avg(self):
-        self.avg_item = EvaluationItem(0, 0, 0, 0, gt_count=0,
-                                       class_name='average')
+        self.avg_item = ObjectDetectionEvaluationItem(
+            0, 0, 0, 0, gt_count=0, class_name='average')
         for eval_item in self.class_to_eval_item.values():
             self.avg_item.merge(eval_item)
-
-    def merge(self, evaluation):
-        if len(self.class_to_eval_item) == 0:
-            self.class_to_eval_item = evaluation.class_to_eval_item
-        else:
-            for class_id, other_eval_item in evaluation.class_to_eval_item.items():
-                self.get_by_id(class_id).merge(other_eval_item)
-
-        self.compute_avg()
-
-    def to_json(self):
-        json_rep = []
-        for eval_item in self.class_to_eval_item.values():
-            json_rep.append(eval_item.to_json())
-        json_rep.append(self.avg_item.to_json())
-        return json_rep
-
-    def save(self, output_uri):
-        json_str = json.dumps(self.to_json(), indent=4)
-        str_to_file(json_str, output_uri)
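
After these deletions the class keeps only detection-specific logic; clear(), get_by_id(), merge(), to_json(), and save() are inherited from Evaluation. Roughly, the end state (condensed sketch, not the verbatim file):

    class ObjectDetectionEvaluation(Evaluation):
        def compute(self, class_map, ground_truth_label_store,
                    prediction_label_store):
            # Builds self.class_to_eval_item via compute_od_eval /
            # parse_od_eval, then calls self.compute_avg().
            ...

        def compute_avg(self):
            self.avg_item = ObjectDetectionEvaluationItem(
                0, 0, 0, 0, gt_count=0, class_name='average')
            for eval_item in self.class_to_eval_item.values():
                self.avg_item.merge(eval_item)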
4 changes: 3 additions & 1 deletion src/rastervision/ml_tasks/classification.py
@@ -1,4 +1,6 @@
 from rastervision.core.ml_task import MLTask
+from rastervision.evaluations.classification_evaluation import (
+    ClassificationEvaluation)
 
 
 class Classification(MLTask):
@@ -17,4 +19,4 @@ def get_predict_windows(self, extent, options):
         return extent.get_windows(chip_size, stride)
 
     def get_evaluation(self):
-        pass
+        return ClassificationEvaluation()
