fixed metric "estimation" (batch aggregation issue)
civodlu committed Feb 11, 2020
1 parent 2d7fc20 commit 1283082
Showing 5 changed files with 156 additions and 50 deletions.
15 changes: 10 additions & 5 deletions src/trw/train/losses.py
@@ -8,11 +8,12 @@ class LossDiceMulticlass(nn.Module):
If multi-class, compute the loss for each class then average the losses
"""
def __init__(self, normalization_fn=nn.Sigmoid, eps=0.0001):
def __init__(self, normalization_fn=nn.Sigmoid, eps=0.0001, return_dice_by_class=False):
super().__init__()

self.eps = eps
self.normalization = None
self.return_dice_by_class = return_dice_by_class

if normalization_fn is not None:
self.normalization = normalization_fn()
@@ -25,7 +26,8 @@ def forward(self, output, target):
target: must have W x d0 x ... x dn shape
Returns:
The dice score
if return_dice_by_class is False, return 1 - dice score suitable for optimization.
Else, return the average dice score by class
"""
assert len(output.shape) > 2
assert len(output.shape) == len(target.shape) + 1, 'output: must have W x C x d0 x ... x dn shape and target: must have W x d0 x ... x dn shape'
@@ -43,6 +45,9 @@ def forward(self, output, target):
numerator = 2 * intersection.sum(indices_to_sum)
denominator = output + encoded_target
denominator = denominator.sum(indices_to_sum) + self.eps

loss_per_channerl = 1 - numerator / denominator
return loss_per_channerl.sum(1) / output.shape[1]

if not self.return_dice_by_class:
loss_per_channerl = 1 - numerator / denominator
return loss_per_channerl.sum(1) / output.shape[1] # average over channels
else:
return (numerator / denominator).sum(0) / output.shape[0] # average over samples
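
Example (editor's note, not part of the commit): a minimal usage sketch of the two return modes of LossDiceMulticlass, assuming the class is importable as trw.train.losses.LossDiceMulticlass and that shapes follow the docstring above (output: N x C x d0 x ... x dn, target: N x d0 x ... x dn).

import torch
from trw.train import losses  # assumed import path, matching src/trw/train/losses.py

output = torch.randn(4, 3, 16, 16)         # 4 samples, 3 classes, 16x16 maps
target = torch.randint(0, 3, (4, 16, 16))  # integer class map per sample

# default mode: 1 - dice averaged over classes, one value per sample (for optimization)
loss_fn = losses.LossDiceMulticlass()
loss_per_sample = loss_fn(output, target)  # shape (4,)

# new mode: dice averaged over samples, one value per class (for metric reporting)
dice_fn = losses.LossDiceMulticlass(return_dice_by_class=True)
dice_by_class = dice_fn(output, target)    # shape (3,)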
152 changes: 116 additions & 36 deletions src/trw/train/metrics.py
@@ -2,6 +2,7 @@
from trw.train import utilities
from sklearn import metrics
from trw.train import losses
import collections
import torch


@@ -13,14 +14,27 @@ class Metric:
"""
def __call__(self, outputs):
"""
Calculate a metric from the `outputs`
:param outputs: the data required to calculate the metric from
:return: a tuple (`metric name`, `metric value`) or `None`
Args:
outputs:
the outputs of a batch
Returns:
a dictionary of metric names/values or None
"""
metric_value = 0
return {
'metric_name': metric_value
}
return Metric, {'metric_name': metric_value}

@staticmethod
def aggregate_metrics(metric_by_batch):
"""
Args:
metric_by_batch: a list of metrics, one for each batch
Returns:
a dictionary of result name and value
"""
raise NotImplementedError()
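
Editor's note (not part of the commit): the base class now defines a two-phase contract: __call__ returns a small dictionary of intermediate values for one batch, and aggregate_metrics reduces the list of those dictionaries once per epoch. A minimal hypothetical subclass, assuming the base class is importable as trw.train.metrics.Metric, could look like this:

import collections
import numpy as np
from trw.train.metrics import Metric  # assumed import path


class MetricMeanOutput(Metric):
    # hypothetical metric: mean value of the 'output' array over the whole epoch
    def __call__(self, outputs):
        values = outputs.get('output')
        if values is None:
            return None  # discarded by the caller, like the metrics below
        return {'sum': float(np.sum(values)), 'count': int(np.size(values))}

    @staticmethod
    def aggregate_metrics(metric_by_batch):
        total = sum(m['sum'] for m in metric_by_batch)
        count = sum(m['count'] for m in metric_by_batch)
        return collections.OrderedDict([('mean output', total / count)])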


class MetricLoss(Metric):
@@ -30,11 +44,16 @@
def __call__(self, outputs):
loss = utilities.to_value(outputs.get('loss'))
if loss is not None:
return {
'loss': float(loss)
}
return {'loss': float(loss)}
return None

@staticmethod
def aggregate_metrics(metric_by_batch):
loss = 0.0
for m in metric_by_batch:
loss += m['loss']
return {'loss': loss}


class MetricClassificationError(Metric):
"""
@@ -45,17 +64,27 @@ def __call__(self, outputs):
found = utilities.to_value(outputs.get('output'))
if truth is not None and found is not None:
return {
'classification error': 1.0 - np.sum(found == truth) / len(truth)
'nb_trues': np.sum(found == truth),
'total': len(truth)
}
return None

@staticmethod
def aggregate_metrics(metric_by_batch):
nb_trues = 0
total = 0
for m in metric_by_batch:
nb_trues += m['nb_trues']
total += m['total']
return {'classification error': 1.0 - nb_trues / total}
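
Editor's note (not part of the commit): returning counts rather than a per-batch error rate is what fixes the aggregation. Averaging per-batch rates is biased when batch sizes differ, while summing the counts and dividing once recovers the exact epoch-level error. A self-contained numpy sketch of the difference:

import numpy as np

# two batches: 100 samples with 10 errors, then 2 samples with 2 errors
batches = [(90, 100), (0, 2)]  # (nb_trues, total) per batch

# old-style aggregation: mean of per-batch error rates
biased = np.mean([1.0 - nb_trues / total for nb_trues, total in batches])  # 0.55

# new-style aggregation, as in aggregate_metrics above: sum counts, divide once
nb_trues = sum(t for t, _ in batches)
total = sum(n for _, n in batches)
exact = 1.0 - nb_trues / total  # 12 / 102, about 0.118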


class MetricSegmentationDice(Metric):
"""
Calculate the average dice score of a segmentation map 'output_truth' and class
segmentation probabilities 'output_raw'
"""
def __init__(self, dice_fn=losses.LossDiceMulticlass()):
def __init__(self, dice_fn=losses.LossDiceMulticlass(return_dice_by_class=True)):
self.dice_fn = dice_fn

def __call__(self, outputs):
@@ -65,17 +94,35 @@ def __call__(self, outputs):
found = outputs.get('output_raw')

if found is None or truth is None:
return {}
return None

assert len(found.shape) == len(truth.shape) + 1, f'expecting dim={len(truth.shape)}, got={len(found.shape)}'
with torch.no_grad():
one_minus_dices = self.dice_fn(found, truth)
mean_dices = utilities.to_value(torch.mean(one_minus_dices))
dice_by_class = utilities.to_value(self.dice_fn(found, truth))

return {
'1-dice': mean_dices
'dice_by_class': dice_by_class
}

@staticmethod
def aggregate_metrics(metric_by_batch):
sum_dices = metric_by_batch[0]['dice_by_class']
for m in metric_by_batch[1:]:
sum_dices += m['dice_by_class']

nb_batches = len(metric_by_batch)
if nb_batches > 0:
# calculate the dice score by class
one_minus_dice = 1 - sum_dices / len(metric_by_batch)
r = collections.OrderedDict()
for c in range(len(sum_dices)):
r[f'1-dice[class={c}]'] = one_minus_dice[c]
r['1-dice'] = np.average(one_minus_dice)

return r

return {'1-dice': 1}
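
Editor's note (not part of the commit): a self-contained sketch of what this aggregation computes. The per-class dice vectors collected from each batch are averaged, then reported per class and as an overall mean:

import collections
import numpy as np

# per-class dice as returned by LossDiceMulticlass(return_dice_by_class=True), one vector per batch
metric_by_batch = [
    {'dice_by_class': np.asarray([0.90, 0.70])},
    {'dice_by_class': np.asarray([0.80, 0.60])},
]

sum_dices = sum(m['dice_by_class'] for m in metric_by_batch)
one_minus_dice = 1 - sum_dices / len(metric_by_batch)  # [0.15, 0.35]

r = collections.OrderedDict()
for c, value in enumerate(one_minus_dice):
    r[f'1-dice[class={c}]'] = value
r['1-dice'] = np.average(one_minus_dice)  # 0.25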


class MetricClassificationSensitivitySpecificity(Metric):
"""
@@ -93,34 +140,67 @@ def __call__(self, outputs):
if truth is not None and found is not None:
cm = metrics.confusion_matrix(y_pred=found, y_true=truth)
if len(cm) == 2:
# special case: binary classification
# special case: only binary classification
tn, fp, fn, tp = cm.ravel()

if tp + fn > 0:
one_minus_sensitivity = 1.0 - tp / (tp + fn)
else:
# invalid! `None` will be discarded
one_minus_sensitivity = None

if fp + tn > 0:
one_minus_specificity = 1.0 - tn / (fp + tn)
else:
# invalid! `None` will be discarded
one_minus_specificity = None

return {
# we return the 1.0 - metric, since in the history we always keep the smallest number
'1-sensitivity': one_minus_sensitivity,
'1-specificity': one_minus_specificity,
'tn': tn,
'fn': fn,
'fp': fp,
'tp': tp,
}
else:
return {
# this is perfect classification
'1-sensitivity': 0.0,
'1-specificity': 0.0,
}
if truth[0] == 0:
# 0, means perfect classification of the negative
return {
'tn': cm[0, 0],
'fn': 0,
'fp': 0,
'tp': 0,
}
else:
# 1, means perfect classification of the positive
return {
'tp': cm[0, 0],
'fn': 0,
'fp': 0,
'tn': 0,
}

# something is missing, don't calculate the stats
return None

@staticmethod
def aggregate_metrics(metric_by_batch):
tn = 0
fp = 0
fn = 0
tp = 0

for m in metric_by_batch:
tn += m['tn']
fn += m['fn']
tp += m['tp']
fp += m['fp']

if tp + fn > 0:
one_minus_sensitivity = 1.0 - tp / (tp + fn)
else:
# invalid! `None` will be discarded
one_minus_sensitivity = None

if fp + tn > 0:
one_minus_specificity = 1.0 - tn / (fp + tn)
else:
# invalid! `None` will be discarded
one_minus_specificity = None

return {
# we return the 1.0 - metric, since in the history we always keep the smallest number
'1-sensitivity': one_minus_sensitivity,
'1-specificity': one_minus_specificity,
}
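
Editor's note (not part of the commit): a small sketch of the aggregation above with hypothetical per-batch confusion counts. Summing tn/fp/fn/tp over the epoch and computing the ratios once avoids the per-batch bias; None marks a quantity that cannot be computed (for example, no positive samples seen at all):

metric_by_batch = [
    {'tp': 2, 'fn': 0, 'tn': 3, 'fp': 1},
    {'tp': 1, 'fn': 1, 'tn': 2, 'fp': 1},
]

tp = sum(m['tp'] for m in metric_by_batch)  # 3
fn = sum(m['fn'] for m in metric_by_batch)  # 1
tn = sum(m['tn'] for m in metric_by_batch)  # 5
fp = sum(m['fp'] for m in metric_by_batch)  # 2

one_minus_sensitivity = 1.0 - tp / (tp + fn) if tp + fn > 0 else None  # 0.25
one_minus_specificity = 1.0 - tn / (fp + tn) if fp + tn > 0 else None  # about 0.29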


def default_classification_metrics():
""""
10 changes: 5 additions & 5 deletions src/trw/train/outputs.py
@@ -67,7 +67,6 @@ def loss_term_cleanup(self, loss_term):
dict_torch_values_to_numpy(metrics_results)



def extract_metrics(metrics_outputs, outputs):
"""
Extract metrics from an output
@@ -81,10 +80,11 @@ def extract_metrics(metrics_outputs, outputs):
"""
history = collections.OrderedDict()
for metric in metrics_outputs:
r = metric(outputs)
if r is not None:
assert isinstance(r, collections.Mapping), 'must be a dict like structure'
history.update(r)
metric_result = metric(outputs)
if metric_result is not None:
metric_type = type(metric)
assert isinstance(metric_result, collections.Mapping), 'must be a dict like structure'
history.update({metric_type: metric_result})
return history
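
Editor's note (not part of the commit): with this change the per-batch history is keyed by the metric's class instead of being flattened, so the trainer can later dispatch to the matching aggregate_metrics. A hypothetical call, assuming extract_metrics is importable from trw.train.outputs, the metric classes take no constructor arguments, and utilities.to_value converts torch tensors to plain values:

import numpy as np
import torch
from trw.train import metrics
from trw.train.outputs import extract_metrics  # assumed import path

batch_outputs = {
    'loss': torch.tensor(0.25),
    'output': np.asarray([1, 0, 1, 1]),        # predicted classes
    'output_truth': np.asarray([1, 0, 0, 1]),  # true classes
}
history = extract_metrics([metrics.MetricLoss(), metrics.MetricClassificationError()], batch_outputs)
# history maps metric class to per-batch values, e.g.
# {MetricLoss: {'loss': 0.25}, MetricClassificationError: {'nb_trues': 3, 'total': 4}}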


17 changes: 15 additions & 2 deletions src/trw/train/trainer.py
@@ -125,11 +125,24 @@ def aggregate_list_of_dicts(list_of_dicts):
for key in keys:
values = [dict[key] for dict in list_of_dicts]
values = [v for v in values if v is not None]
aggregate_values(values)
aggregated[key] = aggregate_values(values)
return aggregated


def aggregate_list_of_metrics(list_of_metrics):
if len(list_of_metrics) == 0:
return {}

keys = list_of_metrics[0].keys()
aggregated = collections.OrderedDict()
for key in keys:
values = [dict[key] for dict in list_of_metrics]
aggregated_values = key.aggregate_metrics(values)
for name, value in aggregated_values.items():
aggregated[name] = value
return aggregated
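
Editor's note (not part of the commit): a usage sketch of the new aggregation path, assuming the modules are importable as trw.train.metrics and trw.train.trainer (as in the tests below). Each per-batch dictionary is keyed by the metric class, and that class's aggregate_metrics reduces the list:

from trw.train import metrics, trainer  # assumed import paths

# per-batch metric results, keyed by metric class (as produced by extract_metrics)
batch_1 = {metrics.MetricClassificationError: {'nb_trues': 90, 'total': 100}}
batch_2 = {metrics.MetricClassificationError: {'nb_trues': 0, 'total': 2}}

aggregated = trainer.aggregate_list_of_metrics([batch_1, batch_2])
# {'classification error': 0.1176...}, i.e. 12 errors over 102 samples,
# rather than the mean of the two per-batch error rates (0.55)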


def generic_aggregate_loss_terms(loss_terms_history):
"""
Aggregate the loss terms for all the internal_nodes of an epoch
@@ -164,7 +177,7 @@ def generic_aggregate_loss_terms(loss_terms_history):
loss_term_outputs.append(loss_term_output)

aggregated_outputs[output_name] = aggregate_list_of_dicts(loss_term_outputs)
aggregated_metrics[output_name] = aggregate_list_of_dicts(loss_term_metrics_results)
aggregated_metrics[output_name] = aggregate_list_of_metrics(loss_term_metrics_results)

# keep the `overall_loss` in the metrics
overall_losses = []
12 changes: 10 additions & 2 deletions tests/test_metrics.py
@@ -14,6 +14,7 @@ def test_classification_accuracy_100(self):
r = o.evaluate_batch(batch, False)
history = r['metrics_results']

history = trw.train.trainer.aggregate_list_of_metrics([history])
assert history['classification error'] == 0.0
assert history['loss'] == 0.0

@@ -26,6 +27,7 @@ def test_classification_accuracy_33(self):
r = o.evaluate_batch(batch, False)
history = r['metrics_results']

history = trw.train.trainer.aggregate_list_of_metrics([history])
assert abs(history['classification error'] - 0.33333) < 1e-4
assert abs(history['loss'] - 0.33333 * 100) < 1e-2

@@ -49,6 +51,7 @@ def test_classification_sensitivity_1_specificity_0(self):
r = o.evaluate_batch(batch, False)
history = r['metrics_results']

history = trw.train.trainer.aggregate_list_of_metrics([history])
assert abs(history['classification error'] - 1.0 / (4)) < 1e-4
assert abs(history['1-sensitivity'] - (1 - 1.0)) < 1e-4
assert abs(history['1-specificity'] - (1 - 2.0 / 3)) < 1e-4
@@ -73,6 +76,7 @@ def test_classification_sensitivity_0_specificity_1(self):
r = o.evaluate_batch(batch, False)
history = r['metrics_results']

history = trw.train.trainer.aggregate_list_of_metrics([history])
assert abs(history['classification error'] - 1.0 / 4) < 1e-4
assert abs(history['1-specificity'] - (1 - 1.0)) < 1e-4
assert abs(history['1-sensitivity'] - (1 - 1.0 / 2)) < 1e-4
@@ -86,9 +90,10 @@ def test_metrics_sensitivity_specificity_perfect(self):
r = o.evaluate_batch(batch, False)
history = r['metrics_results']

history = trw.train.trainer.aggregate_list_of_metrics([history])
assert abs(history['classification error']) < 1e-4
assert abs(history['1-specificity']) < 1e-4
assert abs(history['1-sensitivity']) < 1e-4
assert history['1-sensitivity'] is None

def test_metrics_sensitivity_specificity_all_wrong(self):
input_values = torch.from_numpy(np.asarray([[1, 0], [1, 0], [1, 0], [1, 0]], dtype=float))
@@ -99,6 +104,7 @@ def test_metrics_sensitivity_specificity_all_wrong(self):
r = o.evaluate_batch(batch, False)
history = r['metrics_results']

history = trw.train.trainer.aggregate_list_of_metrics([history])
assert abs(history['classification error'] - 1.0) < 1e-4
assert history['1-specificity'] is None
assert abs(history['1-sensitivity'] - 1.0) < 1e-4
@@ -112,6 +118,7 @@ def test_metrics_sensitivity_specificity_all_wrong_specificity_none(self):
r = o.evaluate_batch(batch, False)
history = r['metrics_results']

history = trw.train.trainer.aggregate_list_of_metrics([history])
assert abs(history['classification error'] - 1.0) < 1e-4
assert history['1-specificity'] is None
assert abs(history['1-sensitivity'] - 1.0) < 1e-4
@@ -125,6 +132,7 @@ def test_metrics_sensitivity_specificity_all_wrong_sensitivity_none(self):
r = o.evaluate_batch(batch, False)
history = r['metrics_results']

history = trw.train.trainer.aggregate_list_of_metrics([history])
assert abs(history['classification error'] - 1.0) < 1e-4
assert history['1-sensitivity'] is None
assert abs(history['1-specificity'] - 1.0) < 1e-4
@@ -144,7 +152,7 @@ def test_metrics_with_none_aggregated(self):
batch = {'target': target_values}
r2 = o.evaluate_batch(batch, False)

r = trw.train.trainer.aggregate_list_of_dicts([r1['metrics_results'], r2['metrics_results']])
r = trw.train.trainer.aggregate_list_of_metrics([r1['metrics_results'], r2['metrics_results']])

# make sure we can aggregate appropriately the metrics, even if there is a `None` value
assert abs(r['classification error'] - 1.0) < 1e-4
