In [20]:
import torch
import torch.nn as nn

# Seed the global RNG so that the random input batch and the randomly
# initialised convolution weights are identical on every Restart & Run All.
torch.manual_seed(0)

# Per-sample layout: (channels, height, width).
input_size = (4, 512, 512)
# Random batch of 5 samples.  NOTE: `input` shadows the Python builtin of the
# same name; the name is kept so that any other cell referring to it still works.
input = torch.randn((5,) + input_size)
print(f"input.shape = {input.shape}")
print(f"input.min(), .max() = {input.min()}, {input.max()}")

# Localization network: three (conv -> 2x2 max-pool -> ReLU) stages.
# The convolutions use no padding, so for a 512x512 input the spatial size
# shrinks 512 -> 506 -> 253 -> 249 -> 124 -> 122 -> 61.
localization = nn.Sequential(
    nn.Conv2d(input_size[0], 8, kernel_size=7),
    nn.MaxPool2d(2, stride=2),
    # nn.BatchNorm2d(8),
    nn.ReLU(True),
    nn.Conv2d(8, 16, kernel_size=5),
    nn.MaxPool2d(2, stride=2),
    # nn.BatchNorm2d(16),  # num_features has to match the 16 conv channels
    nn.ReLU(True),
    nn.Conv2d(16, 32, kernel_size=3),
    nn.MaxPool2d(2, stride=2),
    # nn.BatchNorm2d(32),  # num_features has to match the 32 conv channels
    nn.ReLU(True),
)
loc_out = localization(input)
print(f"loc_out.shape = {loc_out.shape}")
print(f"loc_out[0].numel() = {loc_out[0].numel()}")

# Flatten every sample's feature map into a single row: (batch, C * H * W).
loc_out = loc_out.view(-1, loc_out[0].numel())
print(f"loc_out.shape = {loc_out.shape}")

input.shape = torch.Size([5, 4, 512, 512])
input.min(), .max() = -5.513580799102783, 5.091245174407959
loc_out.shape = torch.Size([5, 32, 61, 61])
loc_out[0].numel() = 119072
loc_out.shape = torch.Size([5, 119072])


In [24]:
# Translation head: flattened localization features -> 32 hidden units -> 2 outputs.
fc_xlate = nn.Sequential(
    nn.Linear(in_features=loc_out[0].numel(), out_features=32),
    nn.BatchNorm1d(num_features=32),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=32, out_features=2),
)
# Scale the raw outputs by 0.1 in place, then make the (N, 2) shape explicit.
fc_xlate_out = fc_xlate(loc_out).mul_(0.1).view(-1, 2)
print(f"fc_xlate_out.shape = {fc_xlate_out.shape}")
print(f"fc_xlate_out = {fc_xlate_out}")

fc_xlate_out.shape = torch.Size([5, 2])
fc_xlate_out = tensor([[0.1160, 0.0276],
        [0.0182, 0.0271],
        [0.0780, 0.0421],
        [0.0081, 0.0379],
        [0.0816, 0.0503]], grad_fn=<ViewBackward0>)


In [25]:
# Rotation head: flattened localization features -> 32 hidden units -> 1 output.
fc_rotate = nn.Sequential(
    nn.Linear(in_features=loc_out[0].numel(), out_features=32),
    nn.BatchNorm1d(num_features=32),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=32, out_features=1),
)
# Scale the raw output by 10 in place and make the (N, 1) shape explicit.
# NOTE(review): the scaled value is later fed straight into torch.sin / torch.cos,
# i.e. it is treated as radians -- confirm that a x10 scale is intended.
fc_rotate_out = fc_rotate(loc_out).mul_(10.0).view(-1, 1)
print(f"fc_rotate.shape = {fc_rotate_out.shape}")
print(f"fc_rotate = {fc_rotate_out}")

fc_rotate.shape = torch.Size([5, 1])
fc_rotate = tensor([[11.8885],
        [ 4.6743],
        [ 6.1093],
        [ 5.4224],
        [ 3.3127]], grad_fn=<ViewBackward0>)


In [41]:
# Cosine / sine of the predicted rotation value, each of shape (N, 1).
ca = torch.cos(fc_rotate_out)
sa = torch.sin(fc_rotate_out)

# Build one 2x3 matrix per sample, [[ca, -sa, tx], [sa, ca, ty]], by laying the
# six (N, 1) entries side by side in row-major order and reshaping.
# NOTE(review): tx / ty here are (x - 0.5) + 0.5 * (sa + ca) and
# (y - 0.5) + 0.5 * (ca - sa); the later fc_xform cell uses a different
# expression for the same columns -- confirm which convention is intended.
matrix = torch.cat(
    [
        ca,
        -sa,
        (fc_xlate_out[:, 0] - 0.5).view(-1, 1) + 0.5 * (sa + ca),
        sa,
        ca,
        (fc_xlate_out[:, 1] - 0.5).view(-1, 1) + 0.5 * (ca - sa),
    ],
    dim=-1,
).view(-1, 2, 3)
print(matrix)

tensor([[[ 0.7789,  0.6271, -0.3081],
         [-0.6271,  0.7789,  0.2307]],

        [[-0.0381,  0.9993, -1.0005],
         [-0.9993, -0.0381,  0.0077]],

        [[ 0.9849,  0.1730, -0.0161],
         [-0.1730,  0.9849,  0.1211]],

        [[ 0.6518,  0.7584, -0.5451],
         [-0.7584,  0.6518,  0.2430]],

        [[-0.9854,  0.1703, -0.9962],
         [-0.1703, -0.9854, -0.8573]]], grad_fn=<ViewBackward0>)


In [40]:
# Scratch check: the x-translation entry of the matrix on its own, shape (N, 1).
0.5 * (sa + ca) + (fc_xlate_out[:, 0] - 0.5).view(-1, 1)

tensor([[-0.3081],
        [-1.0005],
        [-0.0161],
        [-0.5451],
        [-0.9962]], grad_fn=<AddBackward0>)

In [38]:
# Scratch check: the rotation-dependent part of the x-translation entry.
(sa + ca) * 0.5

tensor([[ 0.0759],
        [-0.5187],
        [ 0.4060],
        [-0.0533],
        [-0.5778]], grad_fn=<MulBackward0>)

In [57]:
# Combined head: flattened localization features -> 32 hidden units -> 3 outputs.
# The following cell reads columns 0 and 1 as x / y shifts and column 2 as the angle.
fc_xform = nn.Sequential(
    nn.Linear(in_features=loc_out[0].numel(), out_features=32),
    nn.BatchNorm1d(num_features=32),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=32, out_features=3),
)
fc_xform_out = fc_xform(loc_out)
print(f"fc_xform_out.shape = {fc_xform_out.shape}")
print(f"fc_xform_out = {fc_xform_out}")

fc_xform_out.shape = torch.Size([5, 3])
fc_xform_out = tensor([[-0.2356, -0.0443,  0.4960],
        [ 0.3811,  0.2766,  0.3428],
        [-0.1277,  0.0818,  0.3027],
        [-0.0677,  0.3057,  0.1889],
        [-0.1496,  0.1758,  0.2195]], grad_fn=<AddmmBackward0>)


In [81]:
# Scale factors applied to the raw head outputs (both currently 1.0).
xlate_factor = 1.0
angle_factor = 1.0

# Column 2 of fc_xform_out is used as the rotation angle; its cosine and sine
# are reshaped to (N, 1) so they line up with the shift columns below.
ca = torch.cos(angle_factor * fc_xform_out[:, 2]).view(-1, 1)
sa = torch.sin(angle_factor * fc_xform_out[:, 2]).view(-1, 1)
print(f"ca = {ca}")
print(f"sa = {sa}")

# Columns 0 and 1 are the x / y shifts, each reshaped to (N, 1).
x_shift = (xlate_factor * fc_xform_out[:, 0]).view(-1, 1)
y_shift = (xlate_factor * fc_xform_out[:, 1]).view(-1, 1)

# One 2x3 matrix per sample, [[ca, -sa, tx], [sa, ca, ty]], assembled from the
# six (N, 1) entries in row-major order.  The constant part of tx / ty,
# 0.5 - 0.5 * (ca - sa) and 0.5 - 0.5 * (sa + ca), is the offset of a rotation
# about the point (0.5, 0.5); the predicted shift is added on top.
matrix = torch.cat(
    [
        ca,
        -sa,
        -0.5 * (ca - sa) + x_shift + 0.5,
        sa,
        ca,
        -0.5 * (sa + ca) + y_shift + 0.5,
    ],
    dim=-1,
).view(-1, 2, 3)
print(matrix)

ca = tensor([[0.8795],
        [0.9418],
        [0.9545],
        [0.9822],
        [0.9760]], grad_fn=<ViewBackward0>)
sa = tensor([[0.4759],
        [0.3361],
        [0.2981],
        [0.1878],
        [0.2177]], grad_fn=<ViewBackward0>)
tensor([[[ 0.8795, -0.4759,  0.0626],
         [ 0.4759,  0.8795, -0.2220]],

        [[ 0.9418, -0.3361,  0.5783],
         [ 0.3361,  0.9418,  0.1376]],

        [[ 0.9545, -0.2981,  0.0441],
         [ 0.2981,  0.9545, -0.0446]],

        [[ 0.9822, -0.1878,  0.0351],
         [ 0.1878,  0.9822,  0.2207]],

        [[ 0.9760, -0.2177, -0.0288],
         [ 0.2177,  0.9760,  0.0789]]], grad_fn=<ViewBackward0>)
