diff --git a/fvdb/nn/simple_unet.py b/fvdb/nn/simple_unet.py index 40fb1f25..61b4af40 100644 --- a/fvdb/nn/simple_unet.py +++ b/fvdb/nn/simple_unet.py @@ -30,23 +30,13 @@ such as 3D semantic segmentation, shape completion, or volumetric reconstruction. """ -import math -from typing import Any, Sequence import fvdb.nn as fvnn import torch import torch.nn as nn -from fvdb.types import ( - NumericMaxRank1, - NumericMaxRank2, - ValueConstraint, - to_Vec3i, - to_Vec3iBroadcastable, -) -from torch.profiler import record_function - -import fvdb -from fvdb import ConvolutionPlan, Grid, GridBatch, JaggedTensor +from fvdb.types import NumericMaxRank1 + +from fvdb import ConvolutionPlan, GridBatch, JaggedTensor, relu_ from .modules import fvnn_module @@ -87,7 +77,6 @@ def __init__( self.conv = fvnn.SparseConv3d(in_channels, out_channels, kernel_size=kernel_size, stride=1, bias=False) self.batch_norm = fvnn.BatchNorm(out_channels, momentum=momentum) - self.relu = fvnn.ReLU(inplace=True) def extra_repr(self) -> str: return ( @@ -107,7 +96,7 @@ def forward( x = self.conv(data, plan) out_grid = plan.target_grid_batch x = self.batch_norm(x, out_grid) - x = self.relu(x, out_grid) + x = relu_(x) return x @@ -170,8 +159,6 @@ def __init__( self.blocks = nn.ModuleList(layers) - self.final_relu = fvnn.ReLU(inplace=True) - def extra_repr(self) -> str: return ( f"in_channels={self.in_channels}, mid_channels={self.mid_channels}, out_channels={self.out_channels}, " @@ -199,7 +186,7 @@ def forward( data = block(data, plan) data = data + residual - data = self.final_relu(data, plan.target_grid_batch) + data = relu_(data) return data @@ -556,7 +543,7 @@ def reset_parameters(self) -> None: def forward(self, data: JaggedTensor, padded_grid: GridBatch, grid: GridBatch) -> JaggedTensor: plan = ConvolutionPlan.from_grid_batch_transposed( - kernel_size=self.kernel_size, stride=1, source_grid=padded_grid, target_grid=grid + kernel_size=self.kernel_size, stride=1, source_grid=grid, target_grid=padded_grid ) return self.deconv(data, plan) diff --git a/src/fvdb/SparseConvPackInfo.cpp b/src/fvdb/SparseConvPackInfo.cpp index 1d6dce17..3ef7c662 100644 --- a/src/fvdb/SparseConvPackInfo.cpp +++ b/src/fvdb/SparseConvPackInfo.cpp @@ -48,9 +48,6 @@ SparseConvPackInfo::SparseConvPackInfo(Vec3iOrScalar kernelsize, "Source and target grids must both be on the same device"); TORCH_CHECK(srcGrid.device() == targetGrid.device(), "Device should match between this grid and target grid."); - TORCH_CHECK(!(kernelsize.value() == Vec3iOrScalar(1).value() && - stride.value() == Vec3iOrScalar(1).value()), - "1x1 conv does not need kernel map to be built!"); mStride = stride; mKernelSize = kernelsize; diff --git a/src/fvdb/detail/autograd/SparseConvolutionKernelMap.cpp b/src/fvdb/detail/autograd/SparseConvolutionKernelMap.cpp index 4f024d47..b58fe065 100644 --- a/src/fvdb/detail/autograd/SparseConvolutionKernelMap.cpp +++ b/src/fvdb/detail/autograd/SparseConvolutionKernelMap.cpp @@ -56,41 +56,23 @@ SparseConvolutionKernelMap::forward(AutogradContext *ctx, 3, }, opt); - if (!transposed) { - TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[0], - "The number of input features must match the number of voxels"); - TORCH_CHECK_VALUE( - kernels.dim() == 5, - std::string( - "Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") + - std::to_string(kernels.dim()) + " dimensions"); - TORCH_CHECK_VALUE( - kernels.size(1) == inFeatures.size(1), - "Expected input channels of kernels (" + std::to_string(kernels.size(1)) + - ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); - const int outC = kernels.size(0), inC = kernels.size(1); - kWidth[0] = kernels.size(2); - kWidth[1] = kernels.size(3); - kWidth[2] = kernels.size(4); - kernels = kernels.permute({2, 3, 4, 1, 0}).reshape({-1, inC, outC}).contiguous(); - } else { - TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[1], - "The number of input features must match the number of voxels"); - TORCH_CHECK_VALUE( - kernels.dim() == 5, - std::string( - "Expected kernels to have 5 dimensions (shape (in_ch, out_ch, d, h, w)) but got ") + - std::to_string(kernels.dim()) + " dimensions"); - TORCH_CHECK_VALUE( - kernels.size(0) == inFeatures.size(1), - "Expected input channels of kernels (" + std::to_string(kernels.size(0)) + - ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); - const int inC = kernels.size(0), outC = kernels.size(1); - kWidth[0] = kernels.size(2); - kWidth[1] = kernels.size(3); - kWidth[2] = kernels.size(4); - kernels = kernels.permute({2, 3, 4, 0, 1}).reshape({-1, inC, outC}).contiguous(); - } + + TORCH_CHECK_VALUE(!transposed ? inFeatures.size(0) == sizes[0] : inFeatures.size(0) == sizes[1], + "The number of input features must match the number of voxels"); + TORCH_CHECK_VALUE( + kernels.dim() == 5, + std::string( + "Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") + + std::to_string(kernels.dim()) + " dimensions"); + TORCH_CHECK_VALUE( + kernels.size(1) == inFeatures.size(1), + "Expected input channels of kernels (" + std::to_string(kernels.size(1)) + + ") to equal input channels of features: " + std::to_string(inFeatures.size(1))); + const int outC = kernels.size(0), inC = kernels.size(1); + kWidth[0] = kernels.size(2); + kWidth[1] = kernels.size(3); + kWidth[2] = kernels.size(4); + kernels = kernels.permute({2, 3, 4, 1, 0}).reshape({-1, inC, outC}).contiguous(); // Save for backward ctx->save_for_backward({inFeatures, kernels, nbmaps, nbsizes}); @@ -169,23 +151,13 @@ SparseConvolutionKernelMap::backward(AutogradContext *ctx, variable_list grad_ou } const int outC = gradWeight.size(-1), inC = gradWeight.size(-2); - if (!transposed) { - gradWeight = gradWeight - .reshape({kWidth[2].item(), - kWidth[1].item(), - kWidth[0].item(), - inC, - outC}) - .permute({4, 3, 2, 1, 0}); - } else { - gradWeight = gradWeight - .reshape({kWidth[2].item(), - kWidth[1].item(), - kWidth[0].item(), - inC, - outC}) - .permute({3, 4, 2, 1, 0}); - } + gradWeight = gradWeight + .reshape({kWidth[2].item(), + kWidth[1].item(), + kWidth[0].item(), + inC, + outC}) + .permute({4, 3, 2, 1, 0}); return { gradInput, gradWeight, torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor()}; }