Skip to content

Commit

Permalink
update miou and acc formulas
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenmt committed Nov 22, 2020
1 parent 3b6eb5d commit f410f88
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 56 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Download our pre-processed `NYUv2` dataset [here](https://www.dropbox.com/sh/86n

**Update - July 2020**: We have further improved the readability and updated all implementations in `im2im_pred` to comply the current latest version PyTorch 1.5. We fixed a bug to exclude non-defined pixel predictions for a more accurate mean IoU computation in semantic segmentation tasks. We also provided an additional option for users applying data augmentation in NYUv2 to avoid over-fitting and achieve better performances.

**Update - Nov 2020 [IMPORTANT!]**: We have updated mIoU and Acc. formulas to be consistent with the standard benchmark from the [official COCO segmentation scripts](https://github.com/pytorch/vision/tree/master/references/segmentation). The mIoU for all methods are now expected to improve approximately 8%. The new formulas compute mIoU and Acc. based on the accumulated pixel predictions across all images, while the original formulas compute mIoU and Acc. based on pixel predictions averaged in each image across all images.

All models (files) built with SegNet (proposed in the original paper), are described in the following table:

| File Name | Type | Flags | Comments |
Expand Down
149 changes: 93 additions & 56 deletions im2im_pred/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,52 +31,75 @@ def model_fit(x_pred, x_output, task_type):

return loss


def compute_miou(x_pred, x_output):
_, x_pred_label = torch.max(x_pred, dim=1)
x_output_label = x_output
batch_size = x_pred.size(0)
class_nb = x_pred.size(1)
device = x_pred.device
for i in range(batch_size):
true_class = 0
first_switch = True
invalid_mask = (x_output[i] >= 0).float()
for j in range(class_nb):
pred_mask = torch.eq(x_pred_label[i], j * torch.ones(x_pred_label[i].shape).long().to(device))
true_mask = torch.eq(x_output_label[i], j * torch.ones(x_output_label[i].shape).long().to(device))
mask_comb = pred_mask.float() + true_mask.float()
union = torch.sum((mask_comb > 0).float() * invalid_mask) # remove non-defined pixel predictions
intsec = torch.sum((mask_comb > 1).float())
if union == 0:
continue
if first_switch:
class_prob = intsec / union
first_switch = False
else:
class_prob = intsec / union + class_prob
true_class += 1
if i == 0:
batch_avg = class_prob / true_class
else:
batch_avg = class_prob / true_class + batch_avg
return batch_avg / batch_size


def compute_iou(x_pred, x_output):
_, x_pred_label = torch.max(x_pred, dim=1)
x_output_label = x_output
batch_size = x_pred.size(0)
for i in range(batch_size):
if i == 0:
pixel_acc = torch.div(
torch.sum(torch.eq(x_pred_label[i], x_output_label[i]).float()),
torch.sum((x_output_label[i] >= 0).float()))
else:
pixel_acc = pixel_acc + torch.div(
torch.sum(torch.eq(x_pred_label[i], x_output_label[i]).float()),
torch.sum((x_output_label[i] >= 0).float()))
return pixel_acc / batch_size
# Legacy: compute mIoU and Acc. for each image and average across all images.

# def compute_miou(x_pred, x_output):
# _, x_pred_label = torch.max(x_pred, dim=1)
# x_output_label = x_output
# batch_size = x_pred.size(0)
# class_nb = x_pred.size(1)
# device = x_pred.device
# for i in range(batch_size):
# true_class = 0
# first_switch = True
# invalid_mask = (x_output[i] >= 0).float()
# for j in range(class_nb):
# pred_mask = torch.eq(x_pred_label[i], j * torch.ones(x_pred_label[i].shape).long().to(device))
# true_mask = torch.eq(x_output_label[i], j * torch.ones(x_output_label[i].shape).long().to(device))
# mask_comb = pred_mask.float() + true_mask.float()
# union = torch.sum((mask_comb > 0).float() * invalid_mask) # remove non-defined pixel predictions
# intsec = torch.sum((mask_comb > 1).float())
# if union == 0:
# continue
# if first_switch:
# class_prob = intsec / union
# first_switch = False
# else:
# class_prob = intsec / union + class_prob
# true_class += 1
# if i == 0:
# batch_avg = class_prob / true_class
# else:
# batch_avg = class_prob / true_class + batch_avg
# return batch_avg / batch_size
#
#
# def compute_iou(x_pred, x_output):
# _, x_pred_label = torch.max(x_pred, dim=1)
# x_output_label = x_output
# batch_size = x_pred.size(0)
# for i in range(batch_size):
# if i == 0:
# pixel_acc = torch.div(
# torch.sum(torch.eq(x_pred_label[i], x_output_label[i]).float()),
# torch.sum((x_output_label[i] >= 0).float()))
# else:
# pixel_acc = pixel_acc + torch.div(
# torch.sum(torch.eq(x_pred_label[i], x_output_label[i]).float()),
# torch.sum((x_output_label[i] >= 0).float()))
# return pixel_acc / batch_size


# New mIoU and Acc. formula: accumulate every pixel and average across all pixels in all images
class ConfMatrix(object):
def __init__(self, num_classes):
self.num_classes = num_classes
self.mat = None

def update(self, pred, target):
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
with torch.no_grad():
k = (target >= 0) & (target < n)
inds = n * target[k].to(torch.int64) + pred[k]
self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)

def get_metrics(self):
h = self.mat.float()
acc = torch.diag(h).sum() / h.sum()
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
return torch.mean(iu), acc


def depth_error(x_pred, x_output):
Expand Down Expand Up @@ -126,6 +149,7 @@ def multi_task_trainer(train_loader, test_loader, multi_task_model, device, opti
# iteration for all batches
multi_task_model.train()
train_dataset = iter(train_loader)
conf_mat = ConfMatrix(multi_task_model.class_nb)
for k in range(train_batch):
train_data, train_label, train_depth, train_normal = train_dataset.next()
train_data, train_label = train_data.to(device), train_label.long().to(device)
Expand All @@ -146,17 +170,22 @@ def multi_task_trainer(train_loader, test_loader, multi_task_model, device, opti
loss.backward()
optimizer.step()

# accumulate label prediction for every pixel in training images
conf_mat.update(train_pred[0].argmax(1).flatten(), train_label.flatten())

cost[0] = train_loss[0].item()
cost[1] = compute_miou(train_pred[0], train_label).item()
cost[2] = compute_iou(train_pred[0], train_label).item()
cost[3] = train_loss[1].item()
cost[4], cost[5] = depth_error(train_pred[1], train_depth)
cost[6] = train_loss[2].item()
cost[7], cost[8], cost[9], cost[10], cost[11] = normal_error(train_pred[2], train_normal)
avg_cost[index, :12] += cost[:12] / train_batch

# compute mIoU and acc
avg_cost[index, 1:3] = conf_mat.get_metrics()

# evaluating test data
multi_task_model.eval()
conf_mat = ConfMatrix(multi_task_model.class_nb)
with torch.no_grad(): # operations inside don't track history
test_dataset = iter(test_loader)
for k in range(test_batch):
Expand All @@ -169,16 +198,18 @@ def multi_task_trainer(train_loader, test_loader, multi_task_model, device, opti
model_fit(test_pred[1], test_depth, 'depth'),
model_fit(test_pred[2], test_normal, 'normal')]

conf_mat.update(test_pred[0].argmax(1).flatten(), test_label.flatten())

cost[12] = test_loss[0].item()
cost[13] = compute_miou(test_pred[0], test_label).item()
cost[14] = compute_iou(test_pred[0], test_label).item()
cost[15] = test_loss[1].item()
cost[16], cost[17] = depth_error(test_pred[1], test_depth)
cost[18] = test_loss[2].item()
cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(test_pred[2], test_normal)

avg_cost[index, 12:] += cost[12:] / test_batch

# compute mIoU and acc
avg_cost[index, 13:15] = conf_mat.get_metrics()

scheduler.step()
print('Epoch: {:04d} | TRAIN: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} ||'
'TEST: {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} | {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} {:.4f} '
Expand All @@ -195,7 +226,6 @@ def multi_task_trainer(train_loader, test_loader, multi_task_model, device, opti


def single_task_trainer(train_loader, test_loader, single_task_model, device, optimizer, scheduler, opt, total_epoch=200):
total_epoch = 200
train_batch = len(train_loader)
test_batch = len(test_loader)
avg_cost = np.zeros([total_epoch, 24], dtype=np.float32)
Expand All @@ -205,6 +235,7 @@ def single_task_trainer(train_loader, test_loader, single_task_model, device, op
# iteration for all batches
single_task_model.train()
train_dataset = iter(train_loader)
conf_mat = ConfMatrix(single_task_model.class_nb)
for k in range(train_batch):
train_data, train_label, train_depth, train_normal = train_dataset.next()
train_data, train_label = train_data.to(device), train_label.long().to(device)
Expand All @@ -217,9 +248,9 @@ def single_task_trainer(train_loader, test_loader, single_task_model, device, op
train_loss = model_fit(train_pred, train_label, opt.task)
train_loss.backward()
optimizer.step()

conf_mat.update(train_pred.argmax(1).flatten(), train_label.flatten())
cost[0] = train_loss.item()
cost[1] = compute_miou(train_pred, train_label).item()
cost[2] = compute_iou(train_pred, train_label).item()

if opt.task == 'depth':
train_loss = model_fit(train_pred, train_depth, opt.task)
Expand All @@ -237,8 +268,12 @@ def single_task_trainer(train_loader, test_loader, single_task_model, device, op

avg_cost[index, :12] += cost[:12] / train_batch

if opt.task == 'semantic':
avg_cost[index, 1:3] = conf_mat.get_metrics()

# evaluating test data
single_task_model.eval()
conf_mat = ConfMatrix(single_task_model.class_nb)
with torch.no_grad(): # operations inside don't track history
test_dataset = iter(test_loader)
for k in range(test_batch):
Expand All @@ -250,9 +285,9 @@ def single_task_trainer(train_loader, test_loader, single_task_model, device, op

if opt.task == 'semantic':
test_loss = model_fit(test_pred, test_label, opt.task)

conf_mat.update(test_pred.argmax(1).flatten(), test_label.flatten())
cost[12] = test_loss.item()
cost[13] = compute_miou(test_pred, test_label).item()
cost[14] = compute_iou(test_pred, test_label).item()

if opt.task == 'depth':
test_loss = model_fit(test_pred, test_depth, opt.task)
Expand All @@ -265,6 +300,8 @@ def single_task_trainer(train_loader, test_loader, single_task_model, device, op
cost[19], cost[20], cost[21], cost[22], cost[23] = normal_error(test_pred, test_normal)

avg_cost[index, 12:] += cost[12:] / test_batch
if opt.task == 'semantic':
avg_cost[index, 13:15] = conf_mat.get_metrics()

scheduler.step()
if opt.task == 'semantic':
Expand Down

0 comments on commit f410f88

Please sign in to comment.