From 3806340c1db4ffde20c556548763939ff8597bdc Mon Sep 17 00:00:00 2001
From: Jonathan Swartz
Date: Mon, 17 Nov 2025 13:39:43 +1300
Subject: [PATCH 1/4] Using torch.nn.ReLU instead of fvdb.nn.ReLU; some import
 cleanups

Signed-off-by: Jonathan Swartz
---
 fvdb/nn/simple_unet.py | 24 +++++++-----------------
 1 file changed, 7 insertions(+), 17 deletions(-)

diff --git a/fvdb/nn/simple_unet.py b/fvdb/nn/simple_unet.py
index 40fb1f25..50546da2 100644
--- a/fvdb/nn/simple_unet.py
+++ b/fvdb/nn/simple_unet.py
@@ -30,23 +30,13 @@
 such as 3D semantic segmentation, shape completion, or volumetric
 reconstruction.
 """
 
-import math
-from typing import Any, Sequence
 
 import fvdb.nn as fvnn
 import torch
 import torch.nn as nn
-from fvdb.types import (
-    NumericMaxRank1,
-    NumericMaxRank2,
-    ValueConstraint,
-    to_Vec3i,
-    to_Vec3iBroadcastable,
-)
-from torch.profiler import record_function
-
-import fvdb
-from fvdb import ConvolutionPlan, Grid, GridBatch, JaggedTensor
+from fvdb.types import NumericMaxRank1
+
+from fvdb import ConvolutionPlan, GridBatch, JaggedTensor
 
 from .modules import fvnn_module
@@ -87,7 +77,7 @@ def __init__(
 
         self.conv = fvnn.SparseConv3d(in_channels, out_channels, kernel_size=kernel_size, stride=1, bias=False)
         self.batch_norm = fvnn.BatchNorm(out_channels, momentum=momentum)
-        self.relu = fvnn.ReLU(inplace=True)
+        self.relu = torch.nn.ReLU(inplace=True)
 
     def extra_repr(self) -> str:
         return (
@@ -107,7 +97,7 @@ def forward(
         x = self.conv(data, plan)
         out_grid = plan.target_grid_batch
         x = self.batch_norm(x, out_grid)
-        x = self.relu(x, out_grid)
+        x = self.relu(x)
 
         return x
 
@@ -170,7 +160,7 @@ def __init__(
 
         self.blocks = nn.ModuleList(layers)
 
-        self.final_relu = fvnn.ReLU(inplace=True)
+        self.final_relu = torch.nn.ReLU(inplace=True)
 
     def extra_repr(self) -> str:
         return (
@@ -199,7 +189,7 @@ def forward(
             data = block(data, plan)
 
         data = data + residual
-        data = self.final_relu(data, plan.target_grid_batch)
+        data = self.final_relu(data)
 
         return data

From 66c2373cb1ea27935b11241eb956fdee15022d6a Mon Sep 17 00:00:00 2001
From: Jonathan Swartz
Date: Tue, 18 Nov 2025 10:12:07 +1300
Subject: [PATCH 2/4] Switch to fvdb.relu_ for functional inplace ReLU

Signed-off-by: Jonathan Swartz
---
 fvdb/nn/simple_unet.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/fvdb/nn/simple_unet.py b/fvdb/nn/simple_unet.py
index 50546da2..42eebfd8 100644
--- a/fvdb/nn/simple_unet.py
+++ b/fvdb/nn/simple_unet.py
@@ -36,7 +36,7 @@
 import torch
 import torch.nn as nn
 from fvdb.types import NumericMaxRank1
 
-from fvdb import ConvolutionPlan, GridBatch, JaggedTensor
+from fvdb import ConvolutionPlan, GridBatch, JaggedTensor, relu_
 
 from .modules import fvnn_module
@@ -77,7 +77,6 @@ def __init__(
 
         self.conv = fvnn.SparseConv3d(in_channels, out_channels, kernel_size=kernel_size, stride=1, bias=False)
         self.batch_norm = fvnn.BatchNorm(out_channels, momentum=momentum)
-        self.relu = torch.nn.ReLU(inplace=True)
 
     def extra_repr(self) -> str:
         return (
@@ -97,7 +96,7 @@ def forward(
         x = self.conv(data, plan)
         out_grid = plan.target_grid_batch
         x = self.batch_norm(x, out_grid)
-        x = self.relu(x)
+        x = relu_(x)
 
         return x
 
@@ -160,8 +159,6 @@ def __init__(
 
         self.blocks = nn.ModuleList(layers)
 
-        self.final_relu = torch.nn.ReLU(inplace=True)
-
     def extra_repr(self) -> str:
         return (
@@ -189,7 +186,7 @@ def forward(
             data = block(data, plan)
 
         data = data + residual
-        data = self.final_relu(data)
+        data = relu_(data)
 
         return data

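[Editor's note] PATCH 2 drops the stored ReLU modules in favor of fvdb.relu_, the functional in-place activation imported in the hunk above. Below is a minimal sketch of the resulting pattern, using only the API surface visible in these diffs (fvnn.SparseConv3d, fvnn.BatchNorm, ConvolutionPlan, relu_). The class is a hypothetical reduction of the patched block in fvdb/nn/simple_unet.py, not code from the patches, and the channel/momentum values are arbitrary:

import fvdb.nn as fvnn
import torch.nn as nn
from fvdb import ConvolutionPlan, JaggedTensor, relu_

class MiniConvBlock(nn.Module):
    # No ReLU module is stored: relu_ is stateless, mutates the JaggedTensor's
    # features in place, and (unlike the old fvnn.ReLU call) needs no grid.
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv = fvnn.SparseConv3d(in_channels, out_channels, kernel_size=3, stride=1, bias=False)
        self.batch_norm = fvnn.BatchNorm(out_channels, momentum=0.1)

    def forward(self, data: JaggedTensor, plan: ConvolutionPlan) -> JaggedTensor:
        x = self.conv(data, plan)
        x = self.batch_norm(x, plan.target_grid_batch)
        return relu_(x)  # functional in-place ReLU, as in PATCH 2
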
From c3299a06c11d213446543ed723d22d4c9facaff3 Mon Sep 17 00:00:00 2001
From: Jonathan Swartz
Date: Tue, 18 Nov 2025 11:57:09 +1300
Subject: [PATCH 3/4] Remove 1x1 conv kmap building error

Signed-off-by: Jonathan Swartz
---
 src/fvdb/SparseConvPackInfo.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/fvdb/SparseConvPackInfo.cpp b/src/fvdb/SparseConvPackInfo.cpp
index 1d6dce17..3ef7c662 100644
--- a/src/fvdb/SparseConvPackInfo.cpp
+++ b/src/fvdb/SparseConvPackInfo.cpp
@@ -48,9 +48,6 @@ SparseConvPackInfo::SparseConvPackInfo(Vec3iOrScalar kernelsize,
                 "Source and target grids must both be on the same device");
     TORCH_CHECK(srcGrid.device() == targetGrid.device(),
                 "Device should match between this grid and target grid.");
-    TORCH_CHECK(!(kernelsize.value() == Vec3iOrScalar(1).value() &&
-                  stride.value() == Vec3iOrScalar(1).value()),
-                "1x1 conv does not need kernel map to be built!");
 
     mStride = stride;
     mKernelSize = kernelsize;

From aea357203243bc67069bbf10368e2d20e6ca1aeb Mon Sep 17 00:00:00 2001
From: Jonathan Swartz
Date: Wed, 19 Nov 2025 12:26:22 +1300
Subject: [PATCH 4/4] Transposed conv fixes

After reading PyTorch's transposed conv documentation, I believe the
in/out channels should be expressed as the desired in/out channels (not
swapped), and the topology parameters (such as stride and the
source/target grids) should be expressed as if they describe the
'original' conv operator this transposed conv is meant to invert.

Signed-off-by: Jonathan Swartz
---
 fvdb/nn/simple_unet.py                        |  2 +-
 .../autograd/SparseConvolutionKernelMap.cpp   | 76 ++++++-------------
 2 files changed, 25 insertions(+), 53 deletions(-)

diff --git a/fvdb/nn/simple_unet.py b/fvdb/nn/simple_unet.py
index 42eebfd8..61b4af40 100644
--- a/fvdb/nn/simple_unet.py
+++ b/fvdb/nn/simple_unet.py
@@ -543,7 +543,7 @@ def reset_parameters(self) -> None:
 
     def forward(self, data: JaggedTensor, padded_grid: GridBatch, grid: GridBatch) -> JaggedTensor:
         plan = ConvolutionPlan.from_grid_batch_transposed(
-            kernel_size=self.kernel_size, stride=1, source_grid=padded_grid, target_grid=grid
+            kernel_size=self.kernel_size, stride=1, source_grid=grid, target_grid=padded_grid
         )
         return self.deconv(data, plan)
 
diff --git a/src/fvdb/detail/autograd/SparseConvolutionKernelMap.cpp b/src/fvdb/detail/autograd/SparseConvolutionKernelMap.cpp
index 4f024d47..b58fe065 100644
--- a/src/fvdb/detail/autograd/SparseConvolutionKernelMap.cpp
+++ b/src/fvdb/detail/autograd/SparseConvolutionKernelMap.cpp
@@ -56,41 +56,23 @@ SparseConvolutionKernelMap::forward(AutogradContext *ctx,
                 3,
             },
             opt);
-        if (!transposed) {
-            TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[0],
-                              "The number of input features must match the number of voxels");
-            TORCH_CHECK_VALUE(
-                kernels.dim() == 5,
-                std::string(
-                    "Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") +
-                    std::to_string(kernels.dim()) + " dimensions");
-            TORCH_CHECK_VALUE(
-                kernels.size(1) == inFeatures.size(1),
-                "Expected input channels of kernels (" + std::to_string(kernels.size(1)) +
-                    ") to equal input channels of features: " + std::to_string(inFeatures.size(1)));
-            const int outC = kernels.size(0), inC = kernels.size(1);
-            kWidth[0] = kernels.size(2);
-            kWidth[1] = kernels.size(3);
-            kWidth[2] = kernels.size(4);
-            kernels = kernels.permute({2, 3, 4, 1, 0}).reshape({-1, inC, outC}).contiguous();
-        } else {
-            TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[1],
-                              "The number of input features must match the number of voxels");
-            TORCH_CHECK_VALUE(
-                kernels.dim() == 5,
-                std::string(
-                    "Expected kernels to have 5 dimensions (shape (in_ch, out_ch, d, h, w)) but got ") +
-                    std::to_string(kernels.dim()) + " dimensions");
-            TORCH_CHECK_VALUE(
-                kernels.size(0) == inFeatures.size(1),
-                "Expected input channels of kernels (" + std::to_string(kernels.size(0)) +
-                    ") to equal input channels of features: " + std::to_string(inFeatures.size(1)));
-            const int inC = kernels.size(0), outC = kernels.size(1);
-            kWidth[0] = kernels.size(2);
-            kWidth[1] = kernels.size(3);
-            kWidth[2] = kernels.size(4);
-            kernels = kernels.permute({2, 3, 4, 0, 1}).reshape({-1, inC, outC}).contiguous();
-        }
+
+        TORCH_CHECK_VALUE(!transposed ? inFeatures.size(0) == sizes[0] : inFeatures.size(0) == sizes[1],
+                          "The number of input features must match the number of voxels");
+        TORCH_CHECK_VALUE(
+            kernels.dim() == 5,
+            std::string(
+                "Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") +
+                std::to_string(kernels.dim()) + " dimensions");
+        TORCH_CHECK_VALUE(
+            kernels.size(1) == inFeatures.size(1),
+            "Expected input channels of kernels (" + std::to_string(kernels.size(1)) +
+                ") to equal input channels of features: " + std::to_string(inFeatures.size(1)));
+        const int outC = kernels.size(0), inC = kernels.size(1);
+        kWidth[0] = kernels.size(2);
+        kWidth[1] = kernels.size(3);
+        kWidth[2] = kernels.size(4);
+        kernels = kernels.permute({2, 3, 4, 1, 0}).reshape({-1, inC, outC}).contiguous();
 
         // Save for backward
         ctx->save_for_backward({inFeatures, kernels, nbmaps, nbsizes});
@@ -169,23 +151,13 @@ SparseConvolutionKernelMap::backward(AutogradContext *ctx, variable_list grad_ou
     }
 
     const int outC = gradWeight.size(-1), inC = gradWeight.size(-2);
-    if (!transposed) {
-        gradWeight = gradWeight
-                         .reshape({kWidth[2].item(),
-                                   kWidth[1].item(),
-                                   kWidth[0].item(),
-                                   inC,
-                                   outC})
-                         .permute({4, 3, 2, 1, 0});
-    } else {
-        gradWeight = gradWeight
-                         .reshape({kWidth[2].item(),
-                                   kWidth[1].item(),
-                                   kWidth[0].item(),
-                                   inC,
-                                   outC})
-                         .permute({3, 4, 2, 1, 0});
-    }
+    gradWeight = gradWeight
+                     .reshape({kWidth[2].item(),
+                               kWidth[1].item(),
+                               kWidth[0].item(),
+                               inC,
+                               outC})
+                     .permute({4, 3, 2, 1, 0});
     return {
         gradInput, gradWeight, torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor()};
 }
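
[Editor's note] A minimal usage sketch of the transposed-conv convention PATCH 4 adopts, using only the API surface visible in these diffs (fvnn.SparseConv3d, ConvolutionPlan.from_grid_batch_transposed). The grids, channel counts, and the assumption that an fvnn.SparseConv3d module consumes the transposed plan (as self.deconv does in simple_unet.py) are hypothetical illustrations, not code from the patches:

import fvdb.nn as fvnn
from fvdb import ConvolutionPlan, GridBatch, JaggedTensor

def transposed_conv(feats: JaggedTensor, grid: GridBatch, padded_grid: GridBatch) -> JaggedTensor:
    # Channels are the desired in/out channels of this transposed conv itself
    # (here an arbitrary 64 -> 32), not the (out, in) pair of the conv being
    # inverted.
    deconv = fvnn.SparseConv3d(in_channels=64, out_channels=32, kernel_size=3, stride=1, bias=False)
    # Topology arguments describe the 'original' conv being inverted:
    # source_grid/target_grid are that conv's source and target grids, exactly
    # as PATCH 4 rewrites the call in simple_unet.py.
    plan = ConvolutionPlan.from_grid_batch_transposed(
        kernel_size=3, stride=1, source_grid=grid, target_grid=padded_grid
    )
    return deconv(feats, plan)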