In [3]:
import torch

In [27]:
from torch import allclose, float32, device as torch_device, ones, randn, tensor
from composable_mapping.affine import Affine
from composable_mapping.coordinate_system import CoordinateSystem
from composable_mapping.interpolator import LinearInterpolator, NearestInterpolator
from composable_mapping.mappable_tensor.affine_transformation import HostAffineTransformation
from composable_mapping.mappable_tensor.mappable_tensor import PlainTensor, VoxelGrid
from composable_mapping.util import get_spatial_shape

# Sample a masked random volume with a nearest-neighbour interpolator, once on
# the grid itself and once on ``grid.reduce()``, and print the first
# batch/channel/leading-spatial slice of each result for comparison.
cpu = torch_device("cpu")

# Voxel coordinate system of spatial shape (3, 4, 5) after ``shift_voxel(0.5)``.
voxel_coordinates = CoordinateSystem.voxel(
    spatial_shape=(3, 4, 5), dtype=float32, device=cpu
)
grid = voxel_coordinates.shift_voxel(0.5).grid()

# Single-channel mask with the first four slices along dimension 2 zeroed.
mask = ones((2, 1, 14, 15, 16), dtype=float32, device=cpu)
mask[:, :, :4] = 0.0

# Random three-channel volume carrying the mask above.
volume_values = randn((2, 3, 14, 15, 16), dtype=float32, device=cpu)
test_volume = PlainTensor(volume_values, mask=mask)

interpolator = NearestInterpolator(extrapolation_mode="zeros")
# Uncomment to inspect the raw volume values:
# print(test_volume.generate_values()[0, 0, 0])
sampled_on_grid = interpolator(test_volume, grid)
print(sampled_on_grid.generate_values()[0, 0, 0])
# NOTE(review): the recorded outputs of the two prints differ; confirm whether
# ``grid.reduce()`` is expected to sample the same locations.
sampled_on_reduced_grid = interpolator(test_volume, grid.reduce())
print(sampled_on_reduced_grid.generate_values()[0, 0, 0])

[(0, -11), (0, -11), (0, -11)]
[tensor([1.]), tensor([1.]), tensor([1.])]
tensor([[ 1.9768,  0.8790,  0.4095,  2.0108, -0.8797],
        [ 1.7850,  0.0890,  1.2031, -0.8355, -0.8218],
        [-1.0326, -0.8467, -0.5697,  0.3434,  0.4706],
        [ 0.7503, -1.1749, -0.8391,  0.5067,  0.7121]])
tensor([[ 1.7850,  0.0890, -0.8355, -0.8218, -0.8218],
        [-1.0326, -0.8467,  0.3434,  0.4706,  0.4706],
        [-1.0326, -0.8467,  0.3434,  0.4706,  0.4706],
        [-0.1115, -0.5640, -1.8770, -1.8803, -1.8803]])


In [5]:
# Demonstration: circular padding refuses to wrap around a dimension more than
# once, so padding 20 elements onto dimensions of length 5 is rejected.
# The error is the point of this cell, so it is caught and printed instead of
# aborting a top-to-bottom run of the notebook.
try:
    torch.nn.functional.pad(
        torch.zeros((1, 1, 5, 5, 5)), pad=(20, 20, 20, 20, 20, 20), mode="circular"
    )
except RuntimeError as error:
    print(f"RuntimeError: {error}")

RuntimeError: Padding value causes wrapping around more than once.

In [25]:
# Constant padding also works on boolean tensors: pad a 1x1x1x1 tensor by one
# element on each side of its last two dimensions, filling with False.
torch.nn.functional.pad(
    torch.zeros(1, 1, 1, 1, dtype=torch.bool),
    pad=(1, 1, 1, 1),
    mode="constant",
    value=False,
)

tensor([[[[False, False, False],
          [False, False, False],
          [False, False, False]]]])

In [2]:
# Pair of CUDA events with timing enabled; the benchmark cells record them
# around the measured loop and read the elapsed time between them.
start, end = (torch.cuda.Event(enable_timing=True) for _ in range(2))

In [3]:
from typing import Optional, Tuple
from torch import Tensor
from numpy import ndindex

from composable_mapping.util import optional_add


def _zero_to_none(integer: int) -> Optional[int]:
    return integer if integer != 0 else None


def _build_slices(kernel_size: int) -> Tuple[slice, ...]:
    """Build one slice per kernel position along a single dimension.

    The slice for position ``shift`` starts at ``shift`` and stops
    ``kernel_size - 1 - shift`` elements before the end, so all returned
    slices select windows of equal length from the same sequence (provided the
    sequence has at least ``kernel_size - 1`` elements).

    Args:
        kernel_size: Number of kernel positions along the dimension.

    Returns:
        Tuple of ``kernel_size`` slices, ordered by ``shift``.
    """
    windows = []
    for shift in range(kernel_size):
        first = _zero_to_none(shift)
        # A bound of 0 is turned into None so the slice stays open ended
        # instead of becoming empty.
        last = _zero_to_none(shift - kernel_size + 1)
        windows.append(slice(first, last))
    return tuple(windows)


def conv_nd_slicing(volume: Tensor, kernel: Tensor) -> Tensor:
    """Filter the trailing dimensions of ``volume`` with ``kernel`` via slicing.

    For every kernel element, the correspondingly shifted window of ``volume``
    is multiplied by that element and the products are accumulated. The kernel
    is applied as-is (not flipped) and only positions where it fits entirely
    inside ``volume`` are produced.

    Args:
        volume: Tensor whose last ``kernel.ndim`` dimensions are filtered;
            any leading dimensions are kept unchanged.
        kernel: Weights, with one dimension per filtered dimension.

    Returns:
        The filtered tensor; each filtered dimension is ``kernel_size - 1``
        elements shorter than in ``volume``.

    Raises:
        ValueError: If ``kernel`` has no elements, in which case there is
            nothing to accumulate.
    """
    # Per dimension: one window slice for each kernel position.
    slices = [_build_slices(kernel_size) for kernel_size in kernel.shape]
    output = None
    for index in ndindex(*kernel.shape):
        # Pick, for each dimension, the window belonging to this kernel element.
        slice_tuple = tuple(slice_[i] for slice_, i in zip(slices, index))
        output = optional_add(output, volume[(...,) + slice_tuple] * kernel[index])
    if output is None:
        # The loop never ran: some kernel dimension has size zero.
        raise ValueError(
            f"kernel must contain at least one element, got shape {tuple(kernel.shape)}"
        )
    return output

In [33]:
# Benchmark inputs on the first GPU: a 128^3 single-channel volume that tracks
# gradients, and a 1x1x1 kernel.
test_volume = torch.randn(1, 1, 128, 128, 128, device="cuda:0")
test_volume.requires_grad_(True)
conv_kernel = torch.randn(1, 1, 1, device="cuda:0")

In [14]:
from itertools import product

from composable_mapping.util import optional_add


# Benchmark the slicing-based convolution: 100 forward/backward passes timed
# with the CUDA events ``start``/``end``, plus peak GPU memory.

# Reset the peak counter on the same device it is read from below; without an
# explicit device the reset applies to whichever device is current.
torch.cuda.reset_peak_memory_stats(device="cuda:0")

start.record()
for _ in range(100):
    output_1 = conv_nd_slicing(test_volume, conv_kernel)
    output_1.sum().backward()
end.record()
print(output_1.shape)
# Wait for the recorded work to finish before reading the elapsed time.
torch.cuda.synchronize()
print(start.elapsed_time(end))
# Peak allocation converted from bytes to GiB.
print(torch.cuda.max_memory_allocated(device="cuda:0") / 1024 ** 3)

torch.Size([1, 1, 255, 255, 255])
1123.42431640625
0.8255114555358887


In [40]:
# Reference benchmark: the same 100 forward/backward passes using PyTorch's
# built-in conv3d, timed with the CUDA events ``start``/``end``.
torch.cuda.reset_peak_memory_stats(device="cuda:0")
start.record()
for _ in range(100):
    # Fold the two leading dimensions into a batch of single-channel volumes
    # and give the kernel explicit output/input channel dimensions.
    output_2 = torch.nn.functional.conv3d(
        input=test_volume.view(-1, 1, *test_volume.shape[2:]),
        weight=conv_kernel[None, None],
    )
    # Restore the original leading dimensions in front of the filtered ones.
    output_2 = output_2.view(test_volume.shape[:2] + output_2.shape[2:])
    output_2.sum().backward()
end.record()
print(output_2.shape)
# Wait for the recorded work to finish before reading the elapsed time.
torch.cuda.synchronize()
print(start.elapsed_time(end))
# Peak allocation converted from bytes to GiB.
print(torch.cuda.max_memory_allocated(device="cuda:0") / 1024 ** 3)

torch.Size([1, 1, 128, 128, 128])
36.58659362792969
0.6931252479553223


In [16]:
# Check that the slicing-based result matches the conv3d reference using the
# default tolerances of assert_close.
# NOTE(review): the recorded run failed with a greatest absolute difference of
# about 6e-3 (1e-05 allowed); confirm whether that is accumulation error and
# whether explicit rtol/atol should be passed.
torch.testing.assert_close(actual=output_1, expected=output_2)

AssertionError: Tensor-likes are not close!

Mismatched elements: 16324683 / 16581375 (98.5%)
Greatest absolute difference: 0.005977630615234375 at index (0, 0, 86, 175, 143) (up to 1e-05 allowed)
Greatest relative difference: 2948.333251953125 at index (0, 0, 59, 125, 245) (up to 1.3e-06 allowed)