Skip to content

Commit

Permalink
Merge pull request #153 from arraiyopensource/feat/denormalize_points
Browse files Browse the repository at this point in the history
Feat/denormalize points
  • Loading branch information
edgarriba committed Jun 4, 2019
2 parents f940ab8 + e21e3cd commit 5436260
Show file tree
Hide file tree
Showing 30 changed files with 237 additions and 252 deletions.
1 change: 1 addition & 0 deletions docs/source/geometry.conversions.rst
Expand Up @@ -14,3 +14,4 @@ kornia.geometry.conversions
.. autofunction:: angle_axis_to_quaternion
.. autofunction:: rtvec_to_pose
.. autofunction:: normalize_pixel_coordinates
.. autofunction:: denormalize_pixel_coordinates
41 changes: 35 additions & 6 deletions kornia/geometry/conversions.py
Expand Up @@ -15,6 +15,7 @@
"angle_axis_to_quaternion",
"rtvec_to_pose",
"normalize_pixel_coordinates",
"denormalize_pixel_coordinates",
]

EPS = 1e-6
Expand Down Expand Up @@ -441,7 +442,7 @@ def normalize_pixel_coordinates(
height (int): the maximum height in the y-axis.
Return:
torch.Tensor: the nornmalized pixel coordinates.
torch.Tensor: the normalized pixel coordinates.
"""
if pixel_coordinates.shape[-1] != 2:
raise ValueError("Input pixel_coordinates must be of shape (*, 2). "
Expand All @@ -451,12 +452,40 @@ def normalize_pixel_coordinates(
torch.tensor(width), torch.tensor(height)
]).to(pixel_coordinates.device).to(pixel_coordinates.dtype)

factor: torch.Tensor = torch.tensor(2.) / (hw - torch.tensor(1.))
factor: torch.Tensor = torch.tensor(2.) / (hw - 1)

# normalize coordinates and return
pixel_coordinates_norm: torch.Tensor = \
factor * pixel_coordinates - torch.tensor(1.)
return pixel_coordinates_norm
return factor * pixel_coordinates - 1


def denormalize_pixel_coordinates(
pixel_coordinates: torch.Tensor,
height: int,
width: int) -> torch.Tensor:
r"""Denormalize pixel coordinates.
The input is assumed to be -1 if on extreme left, 1 if on
extreme right (x = w-1).
Args:
pixel_coordinate (torch.Tensor): the normalized grid coordinates.
Shape can be :math:`(*, 2)`.
width (int): the maximum width in the x-axis.
height (int): the maximum height in the y-axis.
Return:
torch.Tensor: the denormalized pixel coordinates.
"""
if pixel_coordinates.shape[-1] != 2:
raise ValueError("Input pixel_coordinates must be of shape (*, 2). "
"Got {}".format(pixel_coordinates.shape))
# compute normalization factor
hw: torch.Tensor = torch.stack([
torch.tensor(width), torch.tensor(height)
]).to(pixel_coordinates.device).to(pixel_coordinates.dtype)

factor: torch.Tensor = torch.tensor(2.) / (hw - 1)

return torch.tensor(1.) / factor * (pixel_coordinates + 1)


# TODO: add below funtionalities
Expand Down
10 changes: 8 additions & 2 deletions test/utils.py → kornia/testing/__init__.py
@@ -1,8 +1,14 @@
"""
The testing package contains testing-specific utilities.
"""


import torch
import numpy as np


# test utilites
__all__ = [
'tensor_to_gradcheck_var', 'create_eye_batch',
]


def create_pinhole(fx, fy, cx, cy, height, width, rx, ry, rz, tx, ty, tz):
Expand Down
Empty file added test/__init__.py
Empty file.
20 changes: 10 additions & 10 deletions test/color/test_gray.py
@@ -1,35 +1,35 @@
import pytest

import kornia
import kornia.testing as utils # test utils
from test.common import device_type

import torch
from torch.autograd import gradcheck
from torch.testing import assert_allclose
from common import device_type

import kornia as K
import kornia.color as color
import utils


class TestRgbToGrayscale:
def test_rgb_to_grayscale(self):
channels, height, width = 3, 4, 5
img = torch.ones(channels, height, width)
assert K.rgb_to_grayscale(img).shape == (1, height, width)
assert kornia.rgb_to_grayscale(img).shape == (1, height, width)

def test_rgb_to_grayscale_batch(self):
batch_size, channels, height, width = 2, 3, 4, 5
img = torch.ones(batch_size, channels, height, width)
assert K.rgb_to_grayscale(img).shape == \
assert kornia.rgb_to_grayscale(img).shape == \
(batch_size, 1, height, width)

def test_gradcheck(self):
batch_size, channels, height, width = 2, 3, 4, 5
img = torch.ones(batch_size, channels, height, width)
img = utils.tensor_to_gradcheck_var(img) # to var
assert gradcheck(K.rgb_to_grayscale, (img,), raise_exception=True)
assert gradcheck(kornia.rgb_to_grayscale, (img,), raise_exception=True)

def test_jit(self):
batch_size, channels, height, width = 2, 3, 64, 64
img = torch.ones(batch_size, channels, height, width)
gray = color.RgbToGrayscale()
gray_traced = torch.jit.trace(color.RgbToGrayscale(), img)
gray = kornia.color.RgbToGrayscale()
gray_traced = torch.jit.trace(kornia.color.RgbToGrayscale(), img)
assert_allclose(gray(img), gray_traced(img))
27 changes: 14 additions & 13 deletions test/color/test_hsv.py
@@ -1,11 +1,12 @@
import pytest

import kornia
import kornia.testing as utils # test utils
from test.common import device_type

import torch
from torch.autograd import gradcheck
from torch.testing import assert_allclose
from common import device_type

import kornia.color as color
import utils


class TestRgbToHsv:
Expand All @@ -30,7 +31,7 @@ def test_rgb_to_hsv(self):
[[21.0000 / 255, 22.0000 / 255],
[22.0000 / 255, 22.0000 / 255]]])

f = color.RgbToHsv()
f = kornia.color.RgbToHsv()
assert_allclose(f(data / 255), expected, atol=1e-4, rtol=1e-5)

def test_batch_rgb_to_hsv(self):
Expand All @@ -52,7 +53,7 @@ def test_batch_rgb_to_hsv(self):

[[21.0000 / 255, 22.0000 / 255],
[22.0000 / 255, 22.0000 / 255]]]) # 3x2x2
f = color.RgbToHsv()
f = kornia.color.RgbToHsv()
data = data.repeat(2, 1, 1, 1) # 2x3x2x2
expected = expected.repeat(2, 1, 1, 1) # 2x3x2x2
assert_allclose(f(data / 255), expected, atol=1e-4, rtol=1e-5)
Expand All @@ -70,14 +71,14 @@ def test_gradcheck(self):

data = utils.tensor_to_gradcheck_var(data) # to var

assert gradcheck(color.RgbToHsv(), (data,),
assert gradcheck(kornia.color.RgbToHsv(), (data,),
raise_exception=True)

def test_jit(self):
@torch.jit.script
def op_script(data: torch.Tensor) -> torch.Tensor:

return color.rgb_to_hsv(data)
return kornia.rgb_to_hsv(data)
data = torch.tensor([[[[21., 22.],
[22., 22.]],

Expand All @@ -88,7 +89,7 @@ def op_script(data: torch.Tensor) -> torch.Tensor:
[8., 8.]]]]) # 3x2x2

actual = op_script(data)
expected = color.rgb_to_hsv(data)
expected = kornia.rgb_to_hsv(data)
assert_allclose(actual, expected)


Expand All @@ -114,7 +115,7 @@ def test_hsv_to_rgb(self):
[[21.0000 / 255, 22.0000 / 255],
[22.0000 / 255, 22.0000 / 255]]])

f = color.HsvToRgb()
f = kornia.color.HsvToRgb()
assert_allclose(f(data), expected / 255, atol=1e-3, rtol=1e-3)

def test_batch_hsv_to_rgb(self):
Expand All @@ -137,15 +138,15 @@ def test_batch_hsv_to_rgb(self):
[[21.0000 / 255, 22.0000 / 255],
[22.0000 / 255, 22.0000 / 255]]]) # 3x2x2

f = color.HsvToRgb()
f = kornia.color.HsvToRgb()
data = data.repeat(2, 1, 1, 1) # 2x3x2x2
expected = expected.repeat(2, 1, 1, 1) # 2x3x2x2
assert_allclose(f(data), expected / 255, atol=1e-3, rtol=1e-3)

def test_jit(self):
@torch.jit.script
def op_script(data: torch.Tensor) -> torch.Tensor:
return color.hsv_to_rgb(data)
return kornia.hsv_to_rgb(data)

data = torch.tensor([[[[21., 22.],
[22., 22.]],
Expand All @@ -157,5 +158,5 @@ def op_script(data: torch.Tensor) -> torch.Tensor:
[8., 8.]]]]) # 3x2x2

actual = op_script(data)
expected = color.hsv_to_rgb(data)
expected = kornia.hsv_to_rgb(data)
assert_allclose(actual, expected)
22 changes: 11 additions & 11 deletions test/color/test_normalize.py
@@ -1,19 +1,20 @@
import pytest

import kornia
import kornia.testing as utils # test utils
from test.common import device_type

import torch
from torch.autograd import gradcheck
from torch.testing import assert_allclose
from common import device_type

import kornia.color as color
import utils


class TestNormalize:
def test_smoke(self):
mean = [0.5]
std = [0.1]
repr = 'Normalize(mean=[0.5], std=[0.1])'
assert str(color.Normalize(mean, std)) == repr
assert str(kornia.color.Normalize(mean, std)) == repr

def test_normalize(self):

Expand All @@ -25,7 +26,7 @@ def test_normalize(self):
# expected output
expected = torch.tensor([0.25]).repeat(1, 2, 2).view_as(data)

f = color.Normalize(mean, std)
f = kornia.color.Normalize(mean, std)
assert_allclose(f(data), expected)

def test_broadcast_normalize(self):
Expand All @@ -40,7 +41,7 @@ def test_broadcast_normalize(self):
# expected output
expected = torch.tensor([1.25, 1, 0.5]).repeat(2, 1, 1).view_as(data)

f = color.Normalize(mean, std)
f = kornia.color.Normalize(mean, std)
assert_allclose(f(data), expected)

def test_batch_normalize(self):
Expand All @@ -55,15 +56,14 @@ def test_batch_normalize(self):
# expected output
expected = torch.tensor([1.25, 1, 0.5]).repeat(2, 1, 1).view_as(data)

f = color.Normalize(mean, std)
f = kornia.color.Normalize(mean, std)
assert_allclose(f(data), expected)

def test_jit(self):
@torch.jit.script
def op_script(data: torch.Tensor, mean: torch.Tensor,
std: torch.Tensor) -> torch.Tensor:

return color.normalize(data, mean, std)
return kornia.normalize(data, mean, std)

data = torch.ones(2, 3, 1, 1)
data += 2
Expand All @@ -85,5 +85,5 @@ def test_gradcheck(self):

data = utils.tensor_to_gradcheck_var(data) # to var

assert gradcheck(color.Normalize(mean, std), (data,),
assert gradcheck(kornia.color.Normalize(mean, std), (data,),
raise_exception=True)
31 changes: 15 additions & 16 deletions test/color/test_rgb.py
@@ -1,11 +1,12 @@
import pytest

import kornia
import kornia.testing as utils # test utils
from test.common import device_type

import torch
from torch.autograd import gradcheck
from torch.testing import assert_allclose
from common import device_type

import kornia.color as color
import utils


class TestBgrToRgb:
Expand All @@ -31,7 +32,7 @@ def test_bgr_to_rgb(self):
[[1., 1.],
[1., 1.]]]) # 3x2x2

f = color.BgrToRgb()
f = kornia.color.BgrToRgb()
assert_allclose(f(data), expected)

def test_batch_bgr_to_rgb(self):
Expand Down Expand Up @@ -73,7 +74,7 @@ def test_batch_bgr_to_rgb(self):
[[1., 1.],
[1., 1.]]]]) # 2x3x2x2

f = color.BgrToRgb()
f = kornia.color.BgrToRgb()
out = f(data)
assert_allclose(out, expected)

Expand All @@ -91,14 +92,13 @@ def test_gradcheck(self):

data = utils.tensor_to_gradcheck_var(data) # to var

assert gradcheck(color.BgrToRgb(), (data,),
assert gradcheck(kornia.color.BgrToRgb(), (data,),
raise_exception=True)

def test_jit(self):
@torch.jit.script
def op_script(data: torch.Tensor) -> torch.Tensor:

return color.bgr_to_rgb(data)
return kornia.bgr_to_rgb(data)

data = torch.tensor([[[1., 1.],
[1., 1.]],
Expand All @@ -110,7 +110,7 @@ def op_script(data: torch.Tensor) -> torch.Tensor:
[3., 3.]]]) # 3x2x2

actual = op_script(data)
expected = color.bgr_to_rgb(data)
expected = kornia.bgr_to_rgb(data)
assert_allclose(actual, expected)


Expand All @@ -137,7 +137,7 @@ def test_rgb_to_bgr(self):
[[1., 1.],
[1., 1.]]]) # 3x2x2

f = color.RgbToBgr()
f = kornia.color.RgbToBgr()
assert_allclose(f(data), expected)

def test_gradcheck(self):
Expand All @@ -154,14 +154,13 @@ def test_gradcheck(self):

data = utils.tensor_to_gradcheck_var(data) # to var

assert gradcheck(color.RgbToBgr(), (data,),
assert gradcheck(kornia.color.RgbToBgr(), (data,),
raise_exception=True)

def test_jit(self):
@torch.jit.script
def op_script(data: torch.Tensor) -> torch.Tensor:

return color.rgb_to_bgr(data)
return kornia.rgb_to_bgr(data)

data = torch.tensor([[[1., 1.],
[1., 1.]],
Expand All @@ -173,7 +172,7 @@ def op_script(data: torch.Tensor) -> torch.Tensor:
[3., 3.]]]) # 3x2x2

actual = op_script(data)
expected = color.rgb_to_bgr(data)
expected = kornia.rgb_to_bgr(data)
assert_allclose(actual, expected)

def test_batch_rgb_to_bgr(self):
Expand Down Expand Up @@ -215,6 +214,6 @@ def test_batch_rgb_to_bgr(self):
[[1., 1.],
[1., 1.]]]]) # 2x3x2x2

f = color.RgbToBgr()
f = kornia.color.RgbToBgr()
out = f(data)
assert_allclose(out, expected)

0 comments on commit 5436260

Please sign in to comment.