In [None]:
import matplotlib.pyplot as plt
import numpy
import torch

from scipy.ndimage import affine_transform
from tifffile import imread

In [None]:
def combine(a, b):
    """Interleave two equally shaped 2-D arrays in a checkerboard pattern.

    Cells whose ``row + col`` index sum is even are copied from ``a``; every
    other cell receives the *negated* value of ``b``.  Shown with ``imshow``,
    the result overlays both arrays at pixel level, with ``b`` told apart by
    its sign.

    Parameters
    ----------
    a, b : numpy.ndarray
        2-D arrays of identical shape.

    Returns
    -------
    numpy.ndarray
        New array with the shape and dtype of ``a``.
    """
    assert len(a.shape) == 2
    assert a.shape == b.shape
    rows, cols = numpy.indices(a.shape)
    take_a = (rows + cols) % 2 == 0  # True on the (even, even) and (odd, odd) cells
    mixed = numpy.zeros_like(a)  # keeps a's dtype; -b values are cast into it
    mixed[take_a] = a[take_a]
    mixed[~take_a] = -b[~take_a]
    return mixed

In [None]:
def trf_scipy(ipt, matrix, output_shape):
    """Nearest-neighbour affine resampling with ``scipy.ndimage.affine_transform``.

    ``matrix`` is a homogeneous affine matrix mapping *output* voxel indices to
    *input* voxel indices (``input = matrix @ [*output, 1]``), with axes in the
    same order as ``ipt.shape``.

    Parameters
    ----------
    ipt : array_like
        Input array of any dimensionality ``ndim``.
    matrix : array_like
        ``(ndim + 1, ndim + 1)`` matrix whose last row is exactly ``[0, ..., 0, 1]``.
    output_shape : tuple of int
        Shape of the returned array.

    Returns
    -------
    numpy.ndarray
        Resampled array of shape ``output_shape`` (``order=0``; scipy's default
        ``mode="constant"`` with ``cval=0`` is used outside ``ipt``).

    Raises
    ------
    AssertionError
        If ``matrix`` is not a homogeneous affine matrix matching ``ipt.ndim``.
    """
    dim = numpy.ndim(ipt)
    matrix = numpy.asarray(matrix)
    assert matrix.shape == (dim + 1, dim + 1), (
        f"expected a {(dim + 1, dim + 1)} homogeneous matrix, got {matrix.shape}"
    )
    # scipy does not validate the bottom row of a homogeneous matrix, so reject
    # projective matrices here instead of silently treating them as affine.
    assert numpy.all(matrix[dim, :dim] == 0), "last row of matrix must be [0, ..., 0, 1]"
    assert matrix[dim, dim] == 1, "last row of matrix must be [0, ..., 0, 1]"
    return affine_transform(ipt, matrix, output_shape=output_shape, order=0)

In [None]:
def trf_torch(ipt, scipy_form, output_shape):
    """Nearest-neighbour affine resampling with PyTorch, following scipy conventions.

    ``scipy_form`` follows ``scipy.ndimage.affine_transform`` conventions. It is a
    homogeneous ``(ndim + 1, ndim + 1)`` matrix mapping *output* voxel indices to
    *input* voxel indices, with axes ordered like ``ipt.shape``. The matrix is
    converted into the normalised ``theta`` expected by
    ``torch.nn.functional.affine_grid``. The input is then sampled with
    ``torch.nn.functional.grid_sample`` (``mode="nearest"``, zero padding).

    Parameters
    ----------
    ipt : numpy.ndarray
        2-D ``(h, w)`` or 3-D ``(t, h, w)`` array of a floating dtype.
    scipy_form : numpy.ndarray
        Homogeneous affine matrix whose last row is ``[0, ..., 0, 1]``.
    output_shape : sequence of int
        Shape of the result (``ndim`` entries); may differ from ``ipt.shape``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``output_shape``; samples falling outside ``ipt`` are 0.

    Raises
    ------
    AssertionError
        If ``ipt`` is not 2-D/3-D, ``output_shape`` has the wrong length, or
        ``scipy_form`` is not a homogeneous affine matrix.

    Notes
    -----
    A sample landing exactly half-way between two voxels may be rounded
    differently by torch and scipy.
    """
    ipt = numpy.asarray(ipt)
    dim = ipt.ndim
    output_shape = tuple(output_shape)
    assert dim in (2, 3), f"grid_sample supports 2-D or 3-D inputs, got {dim}-D"
    assert len(output_shape) == dim, f"output_shape {output_shape} does not match {dim}-D input"

    scipy_form = numpy.asarray(scipy_form)
    assert scipy_form.shape == (dim + 1, dim + 1)
    for i in range(dim):
        assert scipy_form[dim, i] == 0
    assert scipy_form[dim, dim] == 1

    # scipy orders axes like the array (t, h, w); torch's theta uses (x, y, z) = (w, h, t),
    # so reverse both the rows and the columns of the linear part and the translation.
    lin = scipy_form[:dim, :dim][::-1, ::-1].astype("float64")
    shift = scipy_form[:dim, dim][::-1].astype("float64")
    size_in = numpy.array(ipt.shape[::-1], dtype="float64")
    size_out = numpy.array(output_shape[::-1], dtype="float64")

    # With align_corners=False, torch places the centre of voxel k along an axis of size s
    # at the normalised coordinate (2k + 1) / s - 1, while scipy uses the index k itself.
    # For output index o and input index i = A @ o + c this yields
    #     i_n = D_in A D_out^-1 o_n + D_in (A s_out - A 1 + 2 c + 1) - 1,   D = diag(1 / s).
    # Normalising as 2k / s - 1 instead (dropping "- A 1 + 1") would shift every
    # sample by (1 - row sum of A) / 2 voxels, i.e. break any scaling or shear.
    theta_lin = lin * size_out[None, :] / size_in[:, None]
    theta_shift = (lin @ size_out - lin.sum(axis=1) + 2 * shift + 1) / size_in - 1
    theta = numpy.concatenate([theta_lin, theta_shift[:, None]], axis=1)

    grid = torch.nn.functional.affine_grid(
        theta=torch.from_numpy(theta[None, ...]),
        size=(1, 1) + output_shape,
        align_corners=False,
    )
    ipt_t = torch.from_numpy(ipt[None, None, ...])
    if ipt_t.is_floating_point():
        # grid_sample requires input and grid to share a dtype (e.g. float32 volumes).
        grid = grid.to(ipt_t.dtype)
    return torch.nn.functional.grid_sample(
        ipt_t, grid, align_corners=False, mode="nearest", padding_mode="zeros"
    ).numpy()[0, 0]

In [None]:
# Homogeneous 4x4 affine in array-axis order: scaling on the diagonal, shear in
# every off-diagonal pair and a translation of 10 voxels along each axis.
# It is inverted because affine_transform (and trf_torch) interpret the matrix as
# the output -> input index mapping.
matrix = numpy.linalg.inv(
    numpy.array(
        [
            [0.7, 0.8, 0.8, 10.0],
            [1.6, 0.8, 0.8, 10.0],
            [1.6, 1.6, 0.9, 10.0],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype="float64",
    )
)
print(matrix)

In [None]:
# Run both implementations on the same all-ones volume: each output is 1 where it
# sampled inside the input and 0 elsewhere, so the two sampled regions can be compared.
ipt = numpy.ones((100, 200, 300), dtype="float64")
out_scipy = trf_scipy(ipt, matrix, ipt.shape)
out_torch = trf_torch(ipt, matrix, ipt.shape)
# Fraction of voxels on which the two implementations disagree (0.0 means identical).
numpy.mean(out_scipy != out_torch)

In [None]:
# Maximum projections along each axis, interleaved by combine():
# scipy samples appear positive, torch samples negative.
fig, ax = plt.subplots(ncols=3, figsize=(15, 5))
for i in range(3):
    ax[i].imshow(combine(out_scipy.max(i), out_torch.max(i)))
    ax[i].set_title(f"max over axis {i}: scipy (+) / torch (-)")

In [None]:
# Checkerboard of the out[:, 0] slices: scipy positive, torch negated.
plt.imshow(combine(out_scipy[:, 0], out_torch[:, 0]))
plt.title("out[:, 0]: scipy (+) / torch (-)")
plt.colorbar()

In [None]:
# Checkerboard of the out[:, :, 0] slices: scipy positive, torch negated.
plt.imshow(combine(out_scipy[:, :, 0], out_torch[:, :, 0]))
plt.title("out[:, :, 0]: scipy (+) / torch (-)")
plt.colorbar()

In [None]:
# First slice along axis 0 of the scipy result.
plt.imshow(out_scipy[0])
plt.title("scipy: out[0]")
plt.colorbar()

In [None]:
# First slice along axis 0 of the torch result.
plt.imshow(out_torch[0])
plt.title("torch: out[0]")
plt.colorbar()

In [None]:
# First slice along axis 1 of the scipy result.
plt.imshow(out_scipy[:, 0])
plt.title("scipy: out[:, 0]")
plt.colorbar()

In [None]:
# First slice along axis 1 of the torch result.
plt.imshow(out_torch[:, 0])
plt.title("torch: out[:, 0]")
plt.colorbar()

In [None]:
# First slice along axis 2 of the scipy result.
plt.imshow(out_scipy[:, :, 0])
plt.title("scipy: out[:, :, 0]")
plt.colorbar()

In [None]:
# First slice along axis 2 of the torch result.
plt.imshow(out_torch[:, :, 0])
plt.title("torch: out[:, :, 0]")
plt.colorbar()

In [None]:
# Difference of the out[0] slices (non-zero pixels mark disagreement); this cell
# previously re-plotted the raw scipy slice, which an earlier cell already shows.
plt.imshow(out_scipy[0] - out_torch[0])
plt.title("scipy - torch: out[0]")
plt.colorbar()

In [None]:
# Difference of the out[:, :, 0] slices (non-zero pixels mark disagreement); this cell
# previously re-plotted the raw scipy slice, which an earlier cell already shows.
plt.imshow(out_scipy[:, :, 0] - out_torch[:, :, 0])
plt.title("scipy - torch: out[:, :, 0]")
plt.colorbar()