Skip to content

Commit

Permalink
Merge pull request #100 from arraiyopensource/feat/confusion_matrix
Browse files Browse the repository at this point in the history
create metrics module and implement confusion_matrix and mean_iou
  • Loading branch information
edgarriba committed Mar 26, 2019
2 parents 189fb34 + 2de47a1 commit 3678dcf
Show file tree
Hide file tree
Showing 10 changed files with 389 additions and 26 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ TGM focuses on Image and tensor warping functions such as:
core
image
losses
metrics
contrib
utils

Expand Down
7 changes: 7 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
torchgeometry.metrics
=====================

.. currentmodule:: torchgeometry.metrics

.. autofunction:: confusion_matrix
.. autofunction:: mean_iou
2 changes: 2 additions & 0 deletions mypy_files.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ torchgeometry/losses/tversky.py
torchgeometry/losses/depth_smooth.py
torchgeometry/contrib/spatial_soft_argmax2d.py
torchgeometry/contrib/extract_patches.py
torchgeometry/metrics/confusion_matrix.py
torchgeometry/metrics/mean_iou.py
2 changes: 1 addition & 1 deletion setup_travis_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fi

# Install CPU-PyTorch
$sdk_dir/.dev_env/bin/conda install -y \
pytorch-nightly \
pytorch-cpu==1.0.1 \
-c pytorch

# Tests dependencies
Expand Down
37 changes: 12 additions & 25 deletions test/test_contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@


class TestExtractTensorPatches:
def _test_smoke(self):
def test_smoke(self):
input = torch.arange(16.).view(1, 1, 4, 4)
m = tgm.contrib.ExtractTensorPatches(3)
assert m(input).shape == (1, 4, 1, 3, 3)

def _test_b1_ch1_h4w4_ws3(self):
def test_b1_ch1_h4w4_ws3(self):
input = torch.arange(16.).view(1, 1, 4, 4)
m = tgm.contrib.ExtractTensorPatches(3)
patches = m(input)
Expand All @@ -23,7 +23,7 @@ def _test_b1_ch1_h4w4_ws3(self):
assert utils.check_equal_torch(input[0, :, 1:, :3], patches[0, 2])
assert utils.check_equal_torch(input[0, :, 1:, 1:], patches[0, 3])

def _test_b1_ch2_h4w4_ws3(self):
def test_b1_ch2_h4w4_ws3(self):
input = torch.arange(16.).view(1, 1, 4, 4)
input = input.expand(-1, 2, -1, -1) # copy all channels
m = tgm.contrib.ExtractTensorPatches(3)
Expand All @@ -34,7 +34,7 @@ def _test_b1_ch2_h4w4_ws3(self):
assert utils.check_equal_torch(input[0, :, 1:, :3], patches[0, 2])
assert utils.check_equal_torch(input[0, :, 1:, 1:], patches[0, 3])

def _test_b1_ch1_h4w4_ws2(self):
def test_b1_ch1_h4w4_ws2(self):
input = torch.arange(16.).view(1, 1, 4, 4)
m = tgm.contrib.ExtractTensorPatches(2)
patches = m(input)
Expand All @@ -44,7 +44,7 @@ def _test_b1_ch1_h4w4_ws2(self):
assert utils.check_equal_torch(input[0, :, 1:3, 1:3], patches[0, 4])
assert utils.check_equal_torch(input[0, :, 2:4, 1:3], patches[0, 7])

def _test_b1_ch1_h4w4_ws2_stride2(self):
def test_b1_ch1_h4w4_ws2_stride2(self):
input = torch.arange(16.).view(1, 1, 4, 4)
m = tgm.contrib.ExtractTensorPatches(2, stride=2)
patches = m(input)
Expand All @@ -54,7 +54,7 @@ def _test_b1_ch1_h4w4_ws2_stride2(self):
assert utils.check_equal_torch(input[0, :, 2:4, 0:2], patches[0, 2])
assert utils.check_equal_torch(input[0, :, 2:4, 2:4], patches[0, 3])

def _test_b1_ch1_h4w4_ws2_stride21(self):
def test_b1_ch1_h4w4_ws2_stride21(self):
input = torch.arange(16.).view(1, 1, 4, 4)
m = tgm.contrib.ExtractTensorPatches(2, stride=(2, 1))
patches = m(input)
Expand All @@ -64,7 +64,7 @@ def _test_b1_ch1_h4w4_ws2_stride21(self):
assert utils.check_equal_torch(input[0, :, 2:4, 0:2], patches[0, 3])
assert utils.check_equal_torch(input[0, :, 2:4, 2:4], patches[0, 5])

def _test_b1_ch1_h3w3_ws2_stride1_padding1(self):
def test_b1_ch1_h3w3_ws2_stride1_padding1(self):
input = torch.arange(9.).view(1, 1, 3, 3)
m = tgm.contrib.ExtractTensorPatches(2, stride=1, padding=1)
patches = m(input)
Expand All @@ -74,7 +74,7 @@ def _test_b1_ch1_h3w3_ws2_stride1_padding1(self):
assert utils.check_equal_torch(input[0, :, 1:3, 0:2], patches[0, 9])
assert utils.check_equal_torch(input[0, :, 1:3, 1:3], patches[0, 10])

def _test_b2_ch1_h3w3_ws2_stride1_padding1(self):
def test_b2_ch1_h3w3_ws2_stride1_padding1(self):
batch_size = 2
input = torch.arange(9.).view(1, 1, 3, 3)
input = input.expand(batch_size, -1, -1, -1)
Expand All @@ -91,15 +91,15 @@ def _test_b2_ch1_h3w3_ws2_stride1_padding1(self):
assert utils.check_equal_torch(
input[i, :, 1:3, 1:3], patches[i, 10])

def _test_b1_ch1_h3w3_ws23(self):
def test_b1_ch1_h3w3_ws23(self):
input = torch.arange(9.).view(1, 1, 3, 3)
m = tgm.contrib.ExtractTensorPatches((2, 3))
patches = m(input)
assert patches.shape == (1, 2, 1, 2, 3)
assert utils.check_equal_torch(input[0, :, 0:2, 0:3], patches[0, 0])
assert utils.check_equal_torch(input[0, :, 1:3, 0:3], patches[0, 1])

def _test_b1_ch1_h3w4_ws23(self):
def test_b1_ch1_h3w4_ws23(self):
input = torch.arange(12.).view(1, 1, 3, 4)
m = tgm.contrib.ExtractTensorPatches((2, 3))
patches = m(input)
Expand All @@ -110,28 +110,15 @@ def _test_b1_ch1_h3w4_ws23(self):
assert utils.check_equal_torch(input[0, :, 1:3, 1:4], patches[0, 3])

# TODO: implement me
def _test_jit(self):
def test_jit(self):
pass

def _test_gradcheck(self):
def test_gradcheck(self):
input = torch.rand(2, 3, 4, 4)
input = utils.tensor_to_gradcheck_var(input) # to var
assert gradcheck(tgm.contrib.extract_tensor_patches,
(input, 3,), raise_exception=True)

def test_run_all(self):
self._test_smoke()
self._test_b1_ch1_h4w4_ws3()
self._test_b1_ch2_h4w4_ws3()
self._test_b1_ch1_h4w4_ws2()
self._test_b1_ch1_h3w3_ws23()
self._test_b1_ch1_h3w4_ws23()
self._test_b1_ch1_h4w4_ws2_stride2()
self._test_b1_ch1_h4w4_ws2_stride21()
self._test_b1_ch1_h3w3_ws2_stride1_padding1()
self._test_b2_ch1_h3w3_ws2_stride1_padding1()
self._test_gradcheck()


class TestSoftArgmax2d:
def _test_smoke(self):
Expand Down
237 changes: 237 additions & 0 deletions test/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import pytest

import torch
import torchgeometry as tgm
from torch.autograd import gradcheck

import utils
from common import device_type


class TestMeanIoU:
def test_two_classes_perfect(self):
batch_size = 1
num_classes = 2
actual = torch.tensor(
[[1, 1, 1, 1, 0, 0, 0, 0]])
predicted = torch.tensor(
[[1, 1, 1, 1, 0, 0, 0, 0]])

mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes)
mean_iou_real = torch.tensor(
[[1.0, 1.0]], dtype=torch.float32)
assert mean_iou.shape == (batch_size, num_classes)
assert utils.check_equal_torch(mean_iou, mean_iou_real)

def test_two_classes_perfect_batch2(self):
batch_size = 2
num_classes = 2
actual = torch.tensor(
[[1, 1, 1, 1, 0, 0, 0, 0]]).repeat(batch_size, 1)
predicted = torch.tensor(
[[1, 1, 1, 1, 0, 0, 0, 0]]).repeat(batch_size, 1)

mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes)
mean_iou_real = torch.tensor(
[[1.0, 1.0]], dtype=torch.float32)
assert mean_iou.shape == (batch_size, num_classes)
assert utils.check_equal_torch(mean_iou, mean_iou_real)

def test_two_classes(self):
batch_size = 1
num_classes = 2
actual = torch.tensor(
[[1, 1, 1, 1, 0, 0, 0, 0]])
predicted = torch.tensor(
[[1, 1, 1, 1, 0, 0, 0, 1]])

mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes)
mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes)
mean_iou_real = torch.tensor(
[[0.75, 0.80]], dtype=torch.float32)
assert mean_iou.shape == (batch_size, num_classes)
assert utils.check_equal_torch(mean_iou, mean_iou_real)

def test_four_classes_2d_perfect(self):
batch_size = 1
num_classes = 4
actual = torch.tensor(
[[[0, 0, 1, 1],
[0, 0, 1, 1],
[2, 2, 3, 3],
[2, 2, 3, 3]]])
predicted = torch.tensor(
[[[0, 0, 1, 1],
[0, 0, 1, 1],
[2, 2, 3, 3],
[2, 2, 3, 3]]])

mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes)
mean_iou_real = torch.tensor(
[[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32)
assert mean_iou.shape == (batch_size, num_classes)
assert utils.check_equal_torch(mean_iou, mean_iou_real)

def test_four_classes_2d_one_class_no_predicted(self):
batch_size = 1
num_classes = 4
actual = torch.tensor(
[[[0, 0, 0, 0],
[0, 0, 0, 0],
[2, 2, 3, 3],
[2, 2, 3, 3]]])
predicted = torch.tensor(
[[[3, 3, 2, 2],
[3, 3, 2, 2],
[2, 2, 3, 3],
[2, 2, 3, 3]]])

mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes)
mean_iou_real = torch.tensor(
[[0.0, 0.0, 0.5, 0.5]], dtype=torch.float32)
assert mean_iou.shape == (batch_size, num_classes)
assert utils.check_equal_torch(mean_iou, mean_iou_real)


class TestConfusionMatrix:
def test_two_classes(self):
num_classes = 2
actual = torch.tensor(
[[1, 1, 1, 1, 0, 0, 0, 0]])
predicted = torch.tensor(
[[1, 1, 1, 1, 0, 0, 0, 1]])

conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes)
conf_mat_real = torch.tensor(
[[[3, 1],
[0, 4]]], dtype=torch.float32)
assert utils.check_equal_torch(conf_mat, conf_mat_real)

def test_two_classes_batch2(self):
batch_size = 2
num_classes = 2
actual = torch.tensor(
[[1, 1, 1, 1, 0, 0, 0, 0]]).repeat(batch_size, 1)
predicted = torch.tensor(
[[1, 1, 1, 1, 0, 0, 0, 1]]).repeat(batch_size, 1)

conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes)
conf_mat_real = torch.tensor(
[[[3, 1],
[0, 4]]], dtype=torch.float32)
assert utils.check_equal_torch(conf_mat, conf_mat_real)

def test_three_classes(self):
num_classes = 3
actual = torch.tensor(
[[2, 2, 0, 0, 1, 0, 0, 2, 1, 1, 0, 0, 1, 2, 1, 0]])
predicted = torch.tensor(
[[2, 1, 0, 0, 0, 0, 0, 1, 0, 2, 2, 1, 0, 0, 2, 2]])

conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes)
conf_mat_real = torch.tensor(
[[[4, 1, 2],
[3, 0, 2],
[1, 2, 1]]], dtype=torch.float32)
assert utils.check_equal_torch(conf_mat, conf_mat_real)

def test_three_classes_normalized(self):
num_classes = 3
normalized = True
actual = torch.tensor(
[[2, 2, 0, 0, 1, 0, 0, 2, 1, 1, 0, 0, 1, 2, 1, 0]])
predicted = torch.tensor(
[[2, 1, 0, 0, 0, 0, 0, 1, 0, 2, 2, 1, 0, 0, 2, 2]])

conf_mat = tgm.metrics.confusion_matrix(
predicted, actual, num_classes, normalized)

conf_mat_real = torch.tensor(
[[[0.5000, 0.3333, 0.4000],
[0.3750, 0.0000, 0.4000],
[0.1250, 0.6667, 0.2000]]], dtype=torch.float32)
assert utils.check_equal_torch(conf_mat, conf_mat_real)

def test_four_classes_2d_perfect(self):
num_classes = 4
actual = torch.tensor(
[[[0, 0, 1, 1],
[0, 0, 1, 1],
[2, 2, 3, 3],
[2, 2, 3, 3]]])
predicted = torch.tensor(
[[[0, 0, 1, 1],
[0, 0, 1, 1],
[2, 2, 3, 3],
[2, 2, 3, 3]]])

conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes)
conf_mat_real = torch.tensor(
[[[4, 0, 0, 0],
[0, 4, 0, 0],
[0, 0, 4, 0],
[0, 0, 0, 4]]], dtype=torch.float32)
assert utils.check_equal_torch(conf_mat, conf_mat_real)

def test_four_classes_2d_one_class_nonperfect(self):
num_classes = 4
actual = torch.tensor(
[[[0, 0, 1, 1],
[0, 0, 1, 1],
[2, 2, 3, 3],
[2, 2, 3, 3]]])
predicted = torch.tensor(
[[[0, 0, 1, 1],
[0, 3, 0, 1],
[2, 2, 1, 3],
[2, 2, 3, 3]]])

conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes)
conf_mat_real = torch.tensor(
[[[3, 0, 0, 1],
[1, 3, 0, 0],
[0, 0, 4, 0],
[0, 1, 0, 3]]], dtype=torch.float32)
assert utils.check_equal_torch(conf_mat, conf_mat_real)

def test_four_classes_2d_one_class_missing(self):
num_classes = 4
actual = torch.tensor(
[[[0, 0, 1, 1],
[0, 0, 1, 1],
[2, 2, 3, 3],
[2, 2, 3, 3]]])
predicted = torch.tensor(
[[[3, 3, 1, 1],
[3, 3, 1, 1],
[2, 2, 3, 3],
[2, 2, 3, 3]]])

conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes)
conf_mat_real = torch.tensor(
[[[0, 0, 0, 4],
[0, 4, 0, 0],
[0, 0, 4, 0],
[0, 0, 0, 4]]], dtype=torch.float32)
assert utils.check_equal_torch(conf_mat, conf_mat_real)

def test_four_classes_2d_one_class_no_predicted(self):
num_classes = 4
actual = torch.tensor(
[[[0, 0, 0, 0],
[0, 0, 0, 0],
[2, 2, 3, 3],
[2, 2, 3, 3]]])
predicted = torch.tensor(
[[[3, 3, 2, 2],
[3, 3, 2, 2],
[2, 2, 3, 3],
[2, 2, 3, 3]]])

conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes)
conf_mat_real = torch.tensor(
[[[0, 0, 4, 4],
[0, 0, 0, 0],
[0, 0, 4, 0],
[0, 0, 0, 4]]], dtype=torch.float32)
assert utils.check_equal_torch(conf_mat, conf_mat_real)
1 change: 1 addition & 0 deletions torchgeometry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torchgeometry import losses
from torchgeometry import contrib
from torchgeometry import utils
from torchgeometry import metrics

# Exposes ``torchgeometry.core`` package to top level
from .core.homography_warper import HomographyWarper, homography_warp
Expand Down
2 changes: 2 additions & 0 deletions torchgeometry/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .confusion_matrix import confusion_matrix
from .mean_iou import mean_iou
Loading

0 comments on commit 3678dcf

Please sign in to comment.