In [1]:
import torch
import matplotlib.pyplot as plt

In [2]:
import time

from torch.cuda import Event, synchronize

event_1 = Event(enable_timing=True)
event_2 = Event(enable_timing=True)

In [21]:
from functools import lru_cache


@lru_cache(maxsize=1)
def func():
    print("computing")
    return 5

print(func())
print(func())

computing
5
5


In [3]:
from deformation_inversion_layer.interpolator import LinearInterpolator
from torch import Tensor

from composable_mapping.dense_deformation import interpolate

class CustomInterpolator(LinearInterpolator):
    def __init__(self, padding_mode: str = "border") -> None:
        super().__init__(padding_mode)
        self.count = 0

    def __call__(self, volume: Tensor, coordinates: Tensor) -> Tensor:
        self.count += 1
        return interpolate(
            volume=volume,
            grid=coordinates,
            mode="bilinear",
            padding_mode=self._padding_mode,
        )
    

In [4]:
import torch.nn.functional
from composable_mapping.affine import CPUAffineTransformation
from composable_mapping.grid_mapping import (
    InterpolationArgs,
    create_deformation_from_voxel_data,
    create_volume,
)
from composable_mapping.mapping_factory import SamplableMappingFactory
from composable_mapping.masked_tensor import MaskedTensor
from composable_mapping.voxel_coordinate_system import create_centered_normalized, create_voxel

random_image = torch.randn(1, 1, 512, 512, 512, device=torch.device("cuda:0"))
random_ddf = torch.randn(1, 3, 512, 512, 512, device=torch.device("cuda:0"))

normalized_coordinate_system = create_centered_normalized(
    shape=(512, 512, 512), device=torch.device("cuda:0")
)
voxel_coordinate_system = create_voxel(
    shape=(512, 512, 512), device=torch.device("cuda:0")
)

mapping_factory = SamplableMappingFactory(
    normalized_coordinate_system,
    interpolation_args=InterpolationArgs(mask_outside_fov=True),
)
image = mapping_factory.create_volume(random_image)
deformation = mapping_factory.create_deformation(random_ddf)

# create example affine matrix which is not identity
matrix = torch.tensor(
    [
        [1.1, 0, 0, 0],
        [0, 1.1, 0, 0],
        [0, 0, 1.1, 0],
        [0, 0, 0, 1],
    ]
)
affine = CPUAffineTransformation(transformation_matrix_on_cpu=matrix, device=torch.device("cuda:0")).pin_memory_if_target_not_cpu()
data = torch.randn(1, 3, 128, 128, 128, device=torch.device("cuda:0"))
random_tensor = torch.tensor([1.5, 1.5, 1.5], device=torch.device("cuda:0"))

In [16]:
import torch.nn.functional
from composable_mapping.grid_mapping import (
    InterpolationArgs,
    create_deformation_from_voxel_data,
    create_volume,
)
from composable_mapping.mapping_factory import SamplableMappingFactory
from composable_mapping.masked_tensor import MaskedTensor
from composable_mapping.voxel_coordinate_system import create_centered_normalized, create_voxel

random_image = torch.randn(1, 1, 128, 128, 128, device=torch.device("cuda:0"))
random_ddf = torch.randn(1, 3, 128, 128, 128, device=torch.device("cuda:0"))

normalized_coordinate_system = create_centered_normalized(
    shape=(128, 128, 128), device=torch.device("cuda:0")
)
voxel_coordinate_system = create_voxel(
    shape=(128, 128, 128), device=torch.device("cuda:0")
)

mapping_factory = SamplableMappingFactory(
    normalized_coordinate_system,
    interpolation_args=InterpolationArgs(mask_outside_fov=True),
)
image = mapping_factory.create_volume(random_image)
deformation = mapping_factory.create_deformation(random_ddf)
#deformation = deformation.resample()

normalized_grid = deformation.sample().generate_values()
voxel_grid = deformation.sample_as_displacement_field().generate_values() + voxel_coordinate_system.grid().generate_values()

synchronize()
start = time.time()
event_1.record()
for _ in range(40):
    # volume.compose(deformation)(coordinate_system.grid())
    # composition(voxel_coordinate_system.grid())
    # interpolator(random_image, grid)
    # volume(MaskedTensor(random_ddf))
    # jou = torch.tensor(5.0).pin_memory().to(torch.device("cuda:0"), non_blocking=True)
    # jou = torch.ones(1, device=torch.device("cuda:0")) * 5.0
    deformation.estimate_spatial_derivatives(0)
    """torch.nn.functional.grid_sample(
        random_image.permute((0, 1, 4, 3, 2)),
        normalized_grid.moveaxis(1, -1),
        align_corners=False,
        mode="bilinear",
        padding_mode="border",
    )"""
event_2.record()
end = time.time()
print(f"Time: {end - start}")

synchronize()
print(f"CUDA Time: {event_1.elapsed_time(event_2) / 1000}")

Time: 0.09380388259887695
CUDA Time: 0.10519551849365234


: 