Commit b6526af

[Fix] fix

acdart committed Oct 17, 2021
1 parent a7ae896 commit b6526af
Showing 4 changed files with 42 additions and 27 deletions.
22 changes: 9 additions & 13 deletions mmseg/core/evaluation/metrics.py
@@ -36,7 +36,8 @@ def calc_adaptive_fm(pred_label, label, beta=0.3):
     else:
         precision = area_intersection / torch.count_nonzero(binary_pred_label)
         recall = area_intersection / torch.count_nonzero(label)
-        adaptive_fm = (1 + beta) * precision * recall / (beta * precision + recall)
+        adaptive_fm = (1 + beta) * precision * recall / (
+            beta * precision + recall)
     return adaptive_fm
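
Note: only the line wrapping changes in this hunk. The expression is the adaptive F-measure, F_beta = (1 + beta^2) * P * R / (beta^2 * P + R), with the `beta` argument standing in for beta^2 (0.3 is the customary value in the salient-object-detection literature). A standalone sketch, not part of the commit; the adaptive threshold is assumed to be twice the mean saliency, the usual convention, since the thresholding code sits above this hunk and is not shown:

import torch


def adaptive_fm_sketch(pred, gt, beta=0.3):
    # `beta` plays the role of beta^2 in
    # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R).
    thr = min(2 * pred.mean().item(), 1.0)  # assumed adaptive threshold
    binary_pred = pred > thr
    inter = torch.count_nonzero(binary_pred & (gt > 0))
    if inter == 0:
        return torch.tensor(0.0)
    precision = inter / torch.count_nonzero(binary_pred)
    recall = inter / torch.count_nonzero(gt)
    return (1 + beta) * precision * recall / (beta * precision + recall)


pred = torch.tensor([0.9, 0.8, 0.1, 0.0, 0.0, 0.0])  # toy saliency map
gt = torch.tensor([1, 1, 0, 0, 0, 0])
print(adaptive_fm_sketch(pred, gt))  # tensor(1.) here, since P = R = 1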


@@ -310,8 +311,7 @@ def eval_metrics(results,
     return ret_metrics


-def calc_sod_metrics(pred_label,
-                     label):
+def calc_sod_metrics(pred_label, label):
     if isinstance(pred_label, str):
         pred_label = torch.from_numpy(np.load(pred_label))
     else:
@@ -323,27 +323,25 @@ def pre_eval_to_sod_metrics(pre_eval_results,
     else:
         label = torch.from_numpy(label)

+    pred_label = pred_label.float()
     if pred_label.max() != pred_label.min():
-        pred_label = (pred_label - pred_label.min()) / (pred_label.max() - pred_label.min())
+        pred_label = (pred_label - pred_label.min()) / (
+            pred_label.max() - pred_label.min())

     mae = calc_mae(pred_label, label)
     adaptive_fm = calc_adaptive_fm(pred_label, label)

     return mae, adaptive_fm
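
Note: the substantive change here is the added `pred_label = pred_label.float()`: predictions that arrive as integer maps (e.g. uint8 saliency maps) are cast to float before the min-max rescaling to [0, 1]. A toy illustration of that normalization, not part of the commit:

import torch

# Min-max rescaling as in calc_sod_metrics; the new .float() cast keeps
# the arithmetic in floating point when the map arrives as uint8.
pred_label = torch.tensor([0, 128, 255], dtype=torch.uint8)
pred_label = pred_label.float()
if pred_label.max() != pred_label.min():
    pred_label = (pred_label - pred_label.min()) / (
        pred_label.max() - pred_label.min())
print(pred_label)  # tensor([0.0000, 0.5020, 1.0000])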


-def pre_eval_to_sod_metrics(pre_eval_results,
-                            nan_to_num=None):
+def pre_eval_to_sod_metrics(pre_eval_results, nan_to_num=None):
     pre_eval_results = tuple(zip(*pre_eval_results))
     assert len(pre_eval_results) == 2

     mae = sum(pre_eval_results[0]) / len(pre_eval_results[0])
     adp_fm = sum(pre_eval_results[1]) / len(pre_eval_results[1])

-    ret_metrics = OrderedDict({
-        'MAE': mae.numpy(),
-        'adpFm': adp_fm.numpy()
-    })
+    ret_metrics = OrderedDict({'MAE': mae.numpy(), 'adpFm': adp_fm.numpy()})
     if nan_to_num is not None:
         ret_metrics = OrderedDict({
             metric: np.nan_to_num(metric_value, nan=nan_to_num)
@@ -352,9 +350,7 @@ def pre_eval_to_sod_metrics(pre_eval_results,
     return ret_metrics


-def eval_sod_metrics(results,
-                     gt_seg_maps,
-                     nan_to_num=None):
+def eval_sod_metrics(results, gt_seg_maps, nan_to_num=None):
     maes = []
     adp_fms = []
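
Note: both `eval_sod_metrics` and `pre_eval_to_sod_metrics` reduce per-image (MAE, adaptive F-measure) pairs to dataset-level averages. A minimal sketch of that reduction, with toy values standing in for real per-image results:

from collections import OrderedDict

import torch

# Toy per-image results, shaped like calc_sod_metrics output: (mae, adp_fm).
pre_eval_results = [(torch.tensor(0.10), torch.tensor(0.80)),
                    (torch.tensor(0.20), torch.tensor(0.60))]

# The same unzip-and-average reduction as pre_eval_to_sod_metrics.
maes, adp_fms = tuple(zip(*pre_eval_results))
mae = sum(maes) / len(maes)
adp_fm = sum(adp_fms) / len(adp_fms)
print(OrderedDict({'MAE': mae.numpy(), 'adpFm': adp_fm.numpy()}))
# prints the dataset-level MAE and adpFm as numpy scalars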
12 changes: 11 additions & 1 deletion mmseg/datasets/custom.py
@@ -5,6 +5,7 @@

 import mmcv
 import numpy as np
+import torch.nn.functional as F
 from mmcv.utils import print_log
 from prettytable import PrettyTable
 from torch.utils.data import Dataset
@@ -262,7 +263,7 @@ def get_gt_seg_maps(self, efficient_test=None):
             self.gt_seg_map_loader(results)
             yield results['gt_semantic_seg']

-    def pre_eval(self, preds, indices):
+    def pre_eval(self, preds, indices, return_logit=False):
         """Collect eval result from each iteration.

         Args:
@@ -284,6 +285,14 @@ def pre_eval(self, preds, indices):
         pre_eval_results = []

         for pred, index in zip(preds, indices):
+            if return_logit:
+                if pred.shape[0] >= 2:
+                    pred = F.softmax(pred, dim=0)
+                    pred = pred.argmax(dim=0)
+                else:
+                    pred = F.sigmoid(pred)
+                    pred = pred.squeeze(0)
+                    pred = (pred > 0.5).int()
             seg_map = self.get_gt_seg_map_by_idx(index)
             pre_eval_results.append(
                 intersect_and_union(pred, seg_map, len(self.CLASSES),
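
Note: the new `return_logit` path lets `pre_eval` accept raw logits of shape (C, H, W): with two or more channels it takes the softmax argmax (equivalent to an argmax over the raw logits, since softmax is monotonic); with a single channel it applies a sigmoid and thresholds at 0.5. A standalone sketch of that conversion; `logits_to_label` and `thr` are illustrative names, not from the commit, and `torch.sigmoid` is used in place of the deprecated `F.sigmoid`:

import torch
import torch.nn.functional as F


def logits_to_label(pred, thr=0.5):
    # Mirrors the return_logit branch above.
    if pred.shape[0] >= 2:
        # argmax of the softmax == argmax of the raw logits
        return F.softmax(pred, dim=0).argmax(dim=0)
    pred = torch.sigmoid(pred).squeeze(0)
    return (pred > thr).int()


print(logits_to_label(torch.randn(3, 2, 2)).shape)  # torch.Size([2, 2])
print(logits_to_label(torch.randn(1, 2, 2)).dtype)  # torch.int32
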
@@ -358,6 +367,7 @@ def get_palette_for_custom_classes(self, class_names, palette=None):
     def evaluate(self,
                  results,
                  metric='mIoU',
+                 return_logit=False,
                  logger=None,
                  gt_seg_maps=None,
                  **kwargs):
22 changes: 14 additions & 8 deletions mmseg/datasets/sod_custom.py
@@ -3,11 +3,12 @@

 import mmcv
 import numpy as np
+import torch.nn.functional as F
 from mmcv.utils import print_log
 from prettytable import PrettyTable

-from mmseg.core import calc_sod_metrics, eval_sod_metrics, \
-    pre_eval_to_sod_metrics
+from mmseg.core import (calc_sod_metrics, eval_sod_metrics,
+                        pre_eval_to_sod_metrics)
 from . import CustomDataset
 from .builder import DATASETS
@@ -21,7 +22,7 @@ class SODCustomDataset(CustomDataset):
     def __init__(self, **kwargs):
         super(SODCustomDataset, self).__init__(**kwargs)

-    def pre_eval(self, preds, indices):
+    def pre_eval(self, preds, indices, return_logit=False):
         """Collect eval result from each iteration.

         Args:
@@ -43,15 +44,22 @@ def pre_eval(self, preds, indices):
         pre_eval_results = []

         for pred, index in zip(preds, indices):
+            if return_logit:
+                if pred.shape[0] >= 2:
+                    pred = F.softmax(pred, dim=0)
+                    pred = pred[1]
+                else:
+                    pred = F.sigmoid(pred)
+                    pred = pred.squeeze(0)
             seg_map = self.get_gt_seg_map_by_idx(index)
-            pre_eval_results.append(
-                calc_sod_metrics(pred, seg_map))
+            pre_eval_results.append(calc_sod_metrics(pred, seg_map))

         return pre_eval_results

     def evaluate(self,
                  results,
                  logger=None,
+                 return_logit=False,
                  gt_seg_maps=None,
                  **kwargs):
         """Evaluate the dataset.
@@ -77,9 +85,7 @@ def evaluate(self,
                 results, str):
             if gt_seg_maps is None:
                 gt_seg_maps = self.get_gt_seg_maps()
-            ret_metrics = eval_sod_metrics(
-                results,
-                gt_seg_maps)
+            ret_metrics = eval_sod_metrics(results, gt_seg_maps)
         # test a list of pre_eval_results
         else:
             ret_metrics = pre_eval_to_sod_metrics(results)
13 changes: 8 additions & 5 deletions mmseg/models/losses/cross_entropy_loss.py
@@ -36,14 +36,14 @@ def cross_entropy(pred,
 def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
     """Expand onehot labels to match the size of prediction."""
     bin_labels = labels.new_zeros(target_shape)
-    valid_mask = (labels > 0) & (labels != ignore_index)
+    valid_mask = (labels >= 0) & (labels != ignore_index)
     inds = torch.nonzero(valid_mask, as_tuple=True)

     if inds[0].numel() > 0:
         if labels.dim() == 3:
-            bin_labels[inds[0], labels[valid_mask] - 1, inds[1], inds[2]] = 1
+            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
         else:
-            bin_labels[inds[0], labels[valid_mask] - 1] = 1
+            bin_labels[inds[0], labels[valid_mask]] = 1

     valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
     if label_weights is None:
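
Note: both fixes make the one-hot expansion treat labels as 0-based class indices. The old `labels > 0` mask plus the `- 1` offset silently dropped class 0, which is fatal when 0 is a real class (e.g. binary masks). A toy check of the fixed indexing, not part of the commit (assumed shapes: labels (N, H, W) -> bin_labels (N, C, H, W), ignore_index 255):

import torch

labels = torch.tensor([[[0, 1], [2, 255]]])   # one 2x2 label map
bin_labels = labels.new_zeros((1, 3, 2, 2))   # N=1, C=3 classes
valid_mask = (labels >= 0) & (labels != 255)
inds = torch.nonzero(valid_mask, as_tuple=True)
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
print(bin_labels[0, :, 0, 0])  # tensor([1, 0, 0]): class 0 is now kept
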
@@ -83,8 +83,11 @@ def binary_cross_entropy(pred,
             pred.dim() == 4 and label.dim() == 3), \
         'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
         'H, W], label shape [N, H, W] are supported'
-        label, weight = _expand_onehot_labels(label, weight, pred.shape,
-                                              ignore_index)
+        if pred.shape[1] == 1:
+            pred = pred.squeeze(1)
+        else:
+            label, weight = _expand_onehot_labels(label, weight, pred.shape,
+                                                  ignore_index)

     # weighted element-wise losses
     if weight is not None:
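
Note: with this branch, a single-channel prediction skips the one-hot expansion entirely: its channel dimension is squeezed so the logits align element-wise with the binary label, the natural input for BCE-with-logits. A minimal sketch of that path; the actual loss call sits further down in `binary_cross_entropy` and is not shown in the hunk:

import torch
import torch.nn.functional as F

pred = torch.randn(2, 1, 4, 4)                  # N=2, C=1 logits
label = torch.randint(0, 2, (2, 4, 4)).float()  # binary mask
if pred.shape[1] == 1:
    pred = pred.squeeze(1)                      # -> (2, 4, 4), matches label
loss = F.binary_cross_entropy_with_logits(pred, label)
print(loss)  # scalar mean loss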
