Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions benchmark/distance_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import torch
import torch.utils.benchmark as benchmark
import scipy.ndimage as ndi
import numpy as np
from prettytable import PrettyTable
import torchmorph as tm

sizes = [64, 128, 256, 512, 1024]
batches = [1, 4, 8, 16]
dtype = torch.float32
device = "cuda"
MIN_RUN = 1.0  # seconds per measurement


def _median_ms_per_image(stmt, env, batch_size):
    """Time ``stmt`` and return its median run time in milliseconds per image.

    Args:
        stmt: Python source snippet to benchmark.
        env: Mapping of names the snippet needs. It is passed through the
            Timer's ``globals`` argument so the benchmark does not depend on
            this file being executed as ``__main__``.
        batch_size: Number of images one execution of ``stmt`` processes.

    Returns:
        Median wall time of one ``stmt`` execution, in milliseconds, divided
        by ``batch_size``.
    """
    measurement = benchmark.Timer(
        stmt=stmt,
        globals=env,
        num_threads=torch.get_num_threads(),
    ).blocked_autorange(min_run_time=MIN_RUN)
    return (measurement.median * 1e3) / batch_size


def main():
    """Benchmark SciPy's EDT against ``tm.distance_transform`` and print tables.

    For every batch size in ``batches`` one table is printed with a row per
    image size in ``sizes``. All timings are reported per image.
    """
    if not torch.cuda.is_available():
        # Fail early with a clear message instead of an error from deep inside
        # the first CUDA tensor allocation.
        raise SystemExit("This benchmark requires a CUDA device.")

    for B in batches:
        table = PrettyTable()
        table.field_names = [
            "Size",
            "SciPy (ms/img)",
            "Torch 1× (ms/img)",
            "Torch batch (ms/img)",
            "Speedup 1×",
            "Speedup batch",
        ]
        for c in table.field_names:
            table.align[c] = "r"

        for s in sizes:
            # Inputs: a random binary mask, shared by all three measurements.
            x = (torch.randn(B, 1, s, s, device=device) > 0).to(dtype)
            x_np_list = [img[0].detach().cpu().numpy() for img in x]
            x_imgs = [x[i:i + 1] for i in range(B)]

            # SciPy (CPU, one-by-one)
            scipy_per_img_ms = _median_ms_per_image(
                "out = [ndi.distance_transform_edt(arr) for arr in x_np_list]",
                {"x_np_list": x_np_list, "ndi": ndi},
                B,
            )

            # Torch (CUDA, one-by-one)
            torch1_per_img_ms = _median_ms_per_image(
                "for xi in x_imgs:\n    tm.distance_transform(xi)",
                {"x_imgs": x_imgs, "tm": tm},
                B,
            )

            # Torch (CUDA, batched)
            torchB_per_img_ms = _median_ms_per_image(
                "tm.distance_transform(x)",
                {"x": x, "tm": tm},
                B,
            )

            # Speedups relative to SciPy
            speed1 = scipy_per_img_ms / torch1_per_img_ms
            speedB = scipy_per_img_ms / torchB_per_img_ms

            table.add_row([
                s,
                f"{scipy_per_img_ms:.3f}",
                f"{torch1_per_img_ms:.3f}",
                f"{torchB_per_img_ms:.3f}",
                f"{speed1:.1f}×",
                f"{speedB:.1f}×",
            ])

        print(f"\n=== Batch Size: {B} ===")
        print(table)


if __name__ == "__main__":
    main()

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ max-line-length = 100
extend-ignore = ["E203", "W503"]

[tool.pytest.ini_options]
addopts = "-v"
addopts = "-v --import-mode=importlib"
testpaths = ["test"]

[build-system]
requires = ["setuptools>=61.0", "wheel", "torch"]
build-backend = "setuptools.build_meta"
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ flake8>=6.0
setuptools>=65.0
wheel>=0.40
ninja>=1.11 # optional, speeds up torch extension builds

prettytable>=3.16.0
101 changes: 90 additions & 11 deletions test/test_distance_transform.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,99 @@
import torch
import pytest
import torchmorph as tm
from scipy.ndimage import distance_transform_edt as dte
import torchmorph as tm
import numpy as np


def batch_distance_transform_edt(batch_numpy):
    """Apply SciPy's Euclidean distance transform to each sample of a batch.

    Arrays with ``ndim <= 2`` are treated as ONE sample: a leading axis is
    added before processing and removed again from the result. Any other
    array is iterated along its first axis and the per-sample results are
    stacked back along axis 0.

    Note that a 2-D array of shape ``(B, L)`` is therefore processed as a
    single 2-D image, not as ``B`` separate 1-D signals.
    """
    unbatched = batch_numpy.ndim <= 2
    # (H, W) -> (1, H, W) so the per-sample loop below is uniform.
    samples = batch_numpy[np.newaxis, ...] if unbatched else batch_numpy

    stacked = np.stack([dte(sample) for sample in samples], axis=0)

    # (1, H, W) -> (H, W) for inputs that arrived without a batch axis.
    return stacked[0] if unbatched else stacked


# Case 1: a batch of two 2-D images.
case_batch_2d = np.array(
    [
        # image 1
        [[0, 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 0]],
        # image 2
        [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]],
    ],
    dtype=np.float32,
)

# Case 2: a batch of two 3-D volumes, all ones except a few zero voxels.
case_3d_sample1 = np.ones((4, 5, 6), dtype=np.float32)
case_3d_sample1[1, 1, 1] = 0.0
case_3d_sample1[2, 3, 4] = 0.0
case_3d_sample2 = np.ones((4, 5, 6), dtype=np.float32)
case_3d_sample2[0, 0, 0] = 0.0
case_batch_3d = np.stack([case_3d_sample1, case_3d_sample2], axis=0)

# Case 3: a single 2-D image (implicit batch, no leading batch axis).
case_single_2d = np.array(
    [
        [0, 1, 0, 1],
        [1, 0, 1, 0],
        [0, 1, 0, 1],
        [1, 0, 1, 0],
    ],
    dtype=np.float32,
)


# Case 4: the same single 2-D image with an explicit batch axis of size 1.
case_explicit_batch_one = case_single_2d[np.newaxis, ...]

# Case 5: a batch containing a singleton dimension (two 5x1 images).
case_dim_one = np.ones((2, 5, 1), dtype=np.float32)
case_dim_one[0, 2, 0] = 0.0
case_dim_one[1, 4, 0] = 0.0

# Case 6: a batch of 1-D signals, shape (2, 6).
# NOTE(review): batch_distance_transform_edt treats any array with ndim <= 2
# as ONE sample, so the SciPy reference for this case is a single 2-D
# transform rather than two 1-D ones. Confirm this is the intended comparison.
case_batch_1d = np.array([[1, 1, 0, 1, 0, 1], [0, 1, 1, 1, 1, 0]], dtype=np.float32)

@pytest.mark.parametrize(
    "input_numpy",
    [
        pytest.param(case_batch_2d, id="批处理2D图像"),
        pytest.param(case_batch_3d, id="批处理3D图像"),
        pytest.param(case_single_2d, id="单张2D图像(隐式批处理)"),
        pytest.param(case_explicit_batch_one, id="单张2D图像(显式批处理)"),
        pytest.param(case_dim_one, id="含幺元维度的批处理"),
        pytest.param(case_batch_1d, id="批处理1D数据"),
    ],
)
def test_batch_processing(input_numpy, request):
    """Check ``tm.distance_transform`` on CUDA against the SciPy reference.

    The input array is moved to the GPU and transformed. The returned
    distance tensor must match ``batch_distance_transform_edt`` of the same
    array in both shape and value (``atol=1e-3``, ``rtol=1e-3``). The test is
    skipped when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    x_numpy_contiguous = np.ascontiguousarray(input_numpy)
    x = torch.from_numpy(x_numpy_contiguous).cuda()

    print(f"\n\n--- 正在运行测试: {request.node.callspec.id} ---")
    print(f"输入张量形状: {x.shape}")
    # A clone is passed so the original input tensor is left untouched.
    dist_cuda, idx_cuda = tm.distance_transform(x.clone())
    print(f"Output index shape: {idx_cuda.shape}.")

    # SciPy returns float64; cast to float32 for the comparison below.
    y_ref_numpy = batch_distance_transform_edt(x_numpy_contiguous)
    y_ref = torch.from_numpy(y_ref_numpy).to(torch.float32).cuda()

    assert (
        dist_cuda.shape == y_ref.shape
    ), f"形状不匹配! CUDA输出: {dist_cuda.shape}, SciPy应为: {y_ref.shape}"
    print("CUDA 和 SciPy 输出形状匹配。")

    torch.testing.assert_close(dist_cuda, y_ref, atol=1e-3, rtol=1e-3)
    print("--- 断言通过 (数值接近) ---")
Loading