-
-
Notifications
You must be signed in to change notification settings - Fork 947
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #100 from arraiyopensource/feat/confusion_matrix
create metrics module and implement confusion_matrix and mean_iou
- Loading branch information
Showing
10 changed files
with
389 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ TGM focuses on Image and tensor warping functions such as: | |
core | ||
image | ||
losses | ||
metrics | ||
contrib | ||
utils | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
torchgeometry.metrics | ||
===================== | ||
|
||
.. currentmodule:: torchgeometry.metrics | ||
|
||
.. autofunction:: confusion_matrix | ||
.. autofunction:: mean_iou |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
import pytest | ||
|
||
import torch | ||
import torchgeometry as tgm | ||
from torch.autograd import gradcheck | ||
|
||
import utils | ||
from common import device_type | ||
|
||
|
||
class TestMeanIoU: | ||
def test_two_classes_perfect(self): | ||
batch_size = 1 | ||
num_classes = 2 | ||
actual = torch.tensor( | ||
[[1, 1, 1, 1, 0, 0, 0, 0]]) | ||
predicted = torch.tensor( | ||
[[1, 1, 1, 1, 0, 0, 0, 0]]) | ||
|
||
mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes) | ||
mean_iou_real = torch.tensor( | ||
[[1.0, 1.0]], dtype=torch.float32) | ||
assert mean_iou.shape == (batch_size, num_classes) | ||
assert utils.check_equal_torch(mean_iou, mean_iou_real) | ||
|
||
def test_two_classes_perfect_batch2(self): | ||
batch_size = 2 | ||
num_classes = 2 | ||
actual = torch.tensor( | ||
[[1, 1, 1, 1, 0, 0, 0, 0]]).repeat(batch_size, 1) | ||
predicted = torch.tensor( | ||
[[1, 1, 1, 1, 0, 0, 0, 0]]).repeat(batch_size, 1) | ||
|
||
mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes) | ||
mean_iou_real = torch.tensor( | ||
[[1.0, 1.0]], dtype=torch.float32) | ||
assert mean_iou.shape == (batch_size, num_classes) | ||
assert utils.check_equal_torch(mean_iou, mean_iou_real) | ||
|
||
def test_two_classes(self): | ||
batch_size = 1 | ||
num_classes = 2 | ||
actual = torch.tensor( | ||
[[1, 1, 1, 1, 0, 0, 0, 0]]) | ||
predicted = torch.tensor( | ||
[[1, 1, 1, 1, 0, 0, 0, 1]]) | ||
|
||
mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes) | ||
mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes) | ||
mean_iou_real = torch.tensor( | ||
[[0.75, 0.80]], dtype=torch.float32) | ||
assert mean_iou.shape == (batch_size, num_classes) | ||
assert utils.check_equal_torch(mean_iou, mean_iou_real) | ||
|
||
def test_four_classes_2d_perfect(self): | ||
batch_size = 1 | ||
num_classes = 4 | ||
actual = torch.tensor( | ||
[[[0, 0, 1, 1], | ||
[0, 0, 1, 1], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
predicted = torch.tensor( | ||
[[[0, 0, 1, 1], | ||
[0, 0, 1, 1], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
|
||
mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes) | ||
mean_iou_real = torch.tensor( | ||
[[1.0, 1.0, 1.0, 1.0]], dtype=torch.float32) | ||
assert mean_iou.shape == (batch_size, num_classes) | ||
assert utils.check_equal_torch(mean_iou, mean_iou_real) | ||
|
||
def test_four_classes_2d_one_class_no_predicted(self): | ||
batch_size = 1 | ||
num_classes = 4 | ||
actual = torch.tensor( | ||
[[[0, 0, 0, 0], | ||
[0, 0, 0, 0], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
predicted = torch.tensor( | ||
[[[3, 3, 2, 2], | ||
[3, 3, 2, 2], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
|
||
mean_iou = tgm.metrics.mean_iou(predicted, actual, num_classes) | ||
mean_iou_real = torch.tensor( | ||
[[0.0, 0.0, 0.5, 0.5]], dtype=torch.float32) | ||
assert mean_iou.shape == (batch_size, num_classes) | ||
assert utils.check_equal_torch(mean_iou, mean_iou_real) | ||
|
||
|
||
class TestConfusionMatrix: | ||
def test_two_classes(self): | ||
num_classes = 2 | ||
actual = torch.tensor( | ||
[[1, 1, 1, 1, 0, 0, 0, 0]]) | ||
predicted = torch.tensor( | ||
[[1, 1, 1, 1, 0, 0, 0, 1]]) | ||
|
||
conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes) | ||
conf_mat_real = torch.tensor( | ||
[[[3, 1], | ||
[0, 4]]], dtype=torch.float32) | ||
assert utils.check_equal_torch(conf_mat, conf_mat_real) | ||
|
||
def test_two_classes_batch2(self): | ||
batch_size = 2 | ||
num_classes = 2 | ||
actual = torch.tensor( | ||
[[1, 1, 1, 1, 0, 0, 0, 0]]).repeat(batch_size, 1) | ||
predicted = torch.tensor( | ||
[[1, 1, 1, 1, 0, 0, 0, 1]]).repeat(batch_size, 1) | ||
|
||
conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes) | ||
conf_mat_real = torch.tensor( | ||
[[[3, 1], | ||
[0, 4]]], dtype=torch.float32) | ||
assert utils.check_equal_torch(conf_mat, conf_mat_real) | ||
|
||
def test_three_classes(self): | ||
num_classes = 3 | ||
actual = torch.tensor( | ||
[[2, 2, 0, 0, 1, 0, 0, 2, 1, 1, 0, 0, 1, 2, 1, 0]]) | ||
predicted = torch.tensor( | ||
[[2, 1, 0, 0, 0, 0, 0, 1, 0, 2, 2, 1, 0, 0, 2, 2]]) | ||
|
||
conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes) | ||
conf_mat_real = torch.tensor( | ||
[[[4, 1, 2], | ||
[3, 0, 2], | ||
[1, 2, 1]]], dtype=torch.float32) | ||
assert utils.check_equal_torch(conf_mat, conf_mat_real) | ||
|
||
def test_three_classes_normalized(self): | ||
num_classes = 3 | ||
normalized = True | ||
actual = torch.tensor( | ||
[[2, 2, 0, 0, 1, 0, 0, 2, 1, 1, 0, 0, 1, 2, 1, 0]]) | ||
predicted = torch.tensor( | ||
[[2, 1, 0, 0, 0, 0, 0, 1, 0, 2, 2, 1, 0, 0, 2, 2]]) | ||
|
||
conf_mat = tgm.metrics.confusion_matrix( | ||
predicted, actual, num_classes, normalized) | ||
|
||
conf_mat_real = torch.tensor( | ||
[[[0.5000, 0.3333, 0.4000], | ||
[0.3750, 0.0000, 0.4000], | ||
[0.1250, 0.6667, 0.2000]]], dtype=torch.float32) | ||
assert utils.check_equal_torch(conf_mat, conf_mat_real) | ||
|
||
def test_four_classes_2d_perfect(self): | ||
num_classes = 4 | ||
actual = torch.tensor( | ||
[[[0, 0, 1, 1], | ||
[0, 0, 1, 1], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
predicted = torch.tensor( | ||
[[[0, 0, 1, 1], | ||
[0, 0, 1, 1], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
|
||
conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes) | ||
conf_mat_real = torch.tensor( | ||
[[[4, 0, 0, 0], | ||
[0, 4, 0, 0], | ||
[0, 0, 4, 0], | ||
[0, 0, 0, 4]]], dtype=torch.float32) | ||
assert utils.check_equal_torch(conf_mat, conf_mat_real) | ||
|
||
def test_four_classes_2d_one_class_nonperfect(self): | ||
num_classes = 4 | ||
actual = torch.tensor( | ||
[[[0, 0, 1, 1], | ||
[0, 0, 1, 1], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
predicted = torch.tensor( | ||
[[[0, 0, 1, 1], | ||
[0, 3, 0, 1], | ||
[2, 2, 1, 3], | ||
[2, 2, 3, 3]]]) | ||
|
||
conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes) | ||
conf_mat_real = torch.tensor( | ||
[[[3, 0, 0, 1], | ||
[1, 3, 0, 0], | ||
[0, 0, 4, 0], | ||
[0, 1, 0, 3]]], dtype=torch.float32) | ||
assert utils.check_equal_torch(conf_mat, conf_mat_real) | ||
|
||
def test_four_classes_2d_one_class_missing(self): | ||
num_classes = 4 | ||
actual = torch.tensor( | ||
[[[0, 0, 1, 1], | ||
[0, 0, 1, 1], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
predicted = torch.tensor( | ||
[[[3, 3, 1, 1], | ||
[3, 3, 1, 1], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
|
||
conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes) | ||
conf_mat_real = torch.tensor( | ||
[[[0, 0, 0, 4], | ||
[0, 4, 0, 0], | ||
[0, 0, 4, 0], | ||
[0, 0, 0, 4]]], dtype=torch.float32) | ||
assert utils.check_equal_torch(conf_mat, conf_mat_real) | ||
|
||
def test_four_classes_2d_one_class_no_predicted(self): | ||
num_classes = 4 | ||
actual = torch.tensor( | ||
[[[0, 0, 0, 0], | ||
[0, 0, 0, 0], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
predicted = torch.tensor( | ||
[[[3, 3, 2, 2], | ||
[3, 3, 2, 2], | ||
[2, 2, 3, 3], | ||
[2, 2, 3, 3]]]) | ||
|
||
conf_mat = tgm.metrics.confusion_matrix(predicted, actual, num_classes) | ||
conf_mat_real = torch.tensor( | ||
[[[0, 0, 4, 4], | ||
[0, 0, 0, 0], | ||
[0, 0, 4, 0], | ||
[0, 0, 0, 4]]], dtype=torch.float32) | ||
assert utils.check_equal_torch(conf_mat, conf_mat_real) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .confusion_matrix import confusion_matrix | ||
from .mean_iou import mean_iou |
Oops, something went wrong.