Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions fvdb/nn/simple_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,13 @@
such as 3D semantic segmentation, shape completion, or volumetric reconstruction.
"""

import math
from typing import Any, Sequence

import fvdb.nn as fvnn
import torch
import torch.nn as nn
from fvdb.types import (
NumericMaxRank1,
NumericMaxRank2,
ValueConstraint,
to_Vec3i,
to_Vec3iBroadcastable,
)
from torch.profiler import record_function

import fvdb
from fvdb import ConvolutionPlan, Grid, GridBatch, JaggedTensor
from fvdb.types import NumericMaxRank1

from fvdb import ConvolutionPlan, GridBatch, JaggedTensor, relu_

from .modules import fvnn_module

Expand Down Expand Up @@ -87,7 +77,6 @@ def __init__(

self.conv = fvnn.SparseConv3d(in_channels, out_channels, kernel_size=kernel_size, stride=1, bias=False)
self.batch_norm = fvnn.BatchNorm(out_channels, momentum=momentum)
self.relu = fvnn.ReLU(inplace=True)

def extra_repr(self) -> str:
return (
Expand All @@ -107,7 +96,7 @@ def forward(
x = self.conv(data, plan)
out_grid = plan.target_grid_batch
x = self.batch_norm(x, out_grid)
x = self.relu(x, out_grid)
x = relu_(x)
return x


Expand Down Expand Up @@ -170,8 +159,6 @@ def __init__(

self.blocks = nn.ModuleList(layers)

self.final_relu = fvnn.ReLU(inplace=True)

def extra_repr(self) -> str:
return (
f"in_channels={self.in_channels}, mid_channels={self.mid_channels}, out_channels={self.out_channels}, "
Expand Down Expand Up @@ -199,7 +186,7 @@ def forward(
data = block(data, plan)

data = data + residual
data = self.final_relu(data, plan.target_grid_batch)
data = relu_(data)

return data

Expand Down Expand Up @@ -556,7 +543,7 @@ def reset_parameters(self) -> None:

def forward(self, data: JaggedTensor, padded_grid: GridBatch, grid: GridBatch) -> JaggedTensor:
plan = ConvolutionPlan.from_grid_batch_transposed(
kernel_size=self.kernel_size, stride=1, source_grid=padded_grid, target_grid=grid
kernel_size=self.kernel_size, stride=1, source_grid=grid, target_grid=padded_grid
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks backwards to me, and counter to the intention. I need to look at the torch.nn.ConvTranspose3d code more carefully and create a testing framework that confirms the ordering. Generally speaking, we should not be reordering arguments like this; it indicates a semantic mismatch. If the transpose plans are wrong compared to PyTorch, then we should fix it in the plan, not in the unet.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, with source_grid=padded_grid, target_grid=grid, an exception is thrown, as if the order were incorrect. Whether that is a mistake in the transposed convolution code or in the ordering of the arguments is the question I raised in the PR description above, so that we can decide which direction to take this in. If you look at the 'topological' arguments to a PyTorch transposed conv (like stride), a stride of 2 will upsample (insert zeros in the input), which is the inverse of how a stride of 2 in a conv operator will downsample. So does that imply that we should order our source and target arguments to the transposed conv as if they were the source and target of the conv operator (since we'd also use stride=2, etc. to have the same meaning and not expect stride=1/2)?

I'm fine with not doing this and letting source/target take their natural meanings; I'm just trying to determine the intent of both PyTorch and the current transposed conv kmap code. You could also read the PyTorch docs for transposed conv as redefining what 'stride' means in a transposed conv, so that it is effectively an 'inverse stride'.

)
return self.deconv(data, plan)

Expand Down
3 changes: 0 additions & 3 deletions src/fvdb/SparseConvPackInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ SparseConvPackInfo::SparseConvPackInfo(Vec3iOrScalar kernelsize,
"Source and target grids must both be on the same device");
TORCH_CHECK(srcGrid.device() == targetGrid.device(),
"Device should match between this grid and target grid.");
TORCH_CHECK(!(kernelsize.value() == Vec3iOrScalar(1).value() &&
stride.value() == Vec3iOrScalar(1).value()),
"1x1 conv does not need kernel map to be built!");

mStride = stride;
mKernelSize = kernelsize;
Expand Down
76 changes: 24 additions & 52 deletions src/fvdb/detail/autograd/SparseConvolutionKernelMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,41 +56,23 @@ SparseConvolutionKernelMap::forward(AutogradContext *ctx,
3,
},
opt);
if (!transposed) {
TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[0],
"The number of input features must match the number of voxels");
TORCH_CHECK_VALUE(
kernels.dim() == 5,
std::string(
"Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") +
std::to_string(kernels.dim()) + " dimensions");
TORCH_CHECK_VALUE(
kernels.size(1) == inFeatures.size(1),
"Expected input channels of kernels (" + std::to_string(kernels.size(1)) +
") to equal input channels of features: " + std::to_string(inFeatures.size(1)));
const int outC = kernels.size(0), inC = kernels.size(1);
kWidth[0] = kernels.size(2);
kWidth[1] = kernels.size(3);
kWidth[2] = kernels.size(4);
kernels = kernels.permute({2, 3, 4, 1, 0}).reshape({-1, inC, outC}).contiguous();
} else {
TORCH_CHECK_VALUE(inFeatures.size(0) == sizes[1],
"The number of input features must match the number of voxels");
TORCH_CHECK_VALUE(
kernels.dim() == 5,
std::string(
"Expected kernels to have 5 dimensions (shape (in_ch, out_ch, d, h, w)) but got ") +
std::to_string(kernels.dim()) + " dimensions");
TORCH_CHECK_VALUE(
kernels.size(0) == inFeatures.size(1),
"Expected input channels of kernels (" + std::to_string(kernels.size(0)) +
") to equal input channels of features: " + std::to_string(inFeatures.size(1)));
const int inC = kernels.size(0), outC = kernels.size(1);
kWidth[0] = kernels.size(2);
kWidth[1] = kernels.size(3);
kWidth[2] = kernels.size(4);
kernels = kernels.permute({2, 3, 4, 0, 1}).reshape({-1, inC, outC}).contiguous();
}

TORCH_CHECK_VALUE(!transposed ? inFeatures.size(0) == sizes[0] : inFeatures.size(0) == sizes[1],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks backwards to me. I need to confirm what is expected in torch before agreeing that this is how it should behave. It may be that we need to switch how our transposed convolution works.

"The number of input features must match the number of voxels");
TORCH_CHECK_VALUE(
kernels.dim() == 5,
std::string(
"Expected kernels to have 5 dimensions (shape (out_ch, in_ch, d, h, w)) but got ") +
std::to_string(kernels.dim()) + " dimensions");
TORCH_CHECK_VALUE(
kernels.size(1) == inFeatures.size(1),
"Expected input channels of kernels (" + std::to_string(kernels.size(1)) +
") to equal input channels of features: " + std::to_string(inFeatures.size(1)));
const int outC = kernels.size(0), inC = kernels.size(1);
kWidth[0] = kernels.size(2);
kWidth[1] = kernels.size(3);
kWidth[2] = kernels.size(4);
kernels = kernels.permute({2, 3, 4, 1, 0}).reshape({-1, inC, outC}).contiguous();

// Save for backward
ctx->save_for_backward({inFeatures, kernels, nbmaps, nbsizes});
Expand Down Expand Up @@ -169,23 +151,13 @@ SparseConvolutionKernelMap::backward(AutogradContext *ctx, variable_list grad_ou
}

const int outC = gradWeight.size(-1), inC = gradWeight.size(-2);
if (!transposed) {
gradWeight = gradWeight
.reshape({kWidth[2].item<int32_t>(),
kWidth[1].item<int32_t>(),
kWidth[0].item<int32_t>(),
inC,
outC})
.permute({4, 3, 2, 1, 0});
} else {
gradWeight = gradWeight
.reshape({kWidth[2].item<int32_t>(),
kWidth[1].item<int32_t>(),
kWidth[0].item<int32_t>(),
inC,
outC})
.permute({3, 4, 2, 1, 0});
}
gradWeight = gradWeight
.reshape({kWidth[2].item<int32_t>(),
kWidth[1].item<int32_t>(),
kWidth[0].item<int32_t>(),
inC,
outC})
.permute({4, 3, 2, 1, 0});
return {
gradInput, gradWeight, torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor()};
}
Expand Down
Loading