
Involution on sparse point clouds #56

Open

HuguesTHOMAS opened this issue Aug 10, 2022 · 2 comments
Hi,

I found your work on involution very interesting, and it relates to other ideas I am working on, like deformable convolutions. So I tried reimplementing your idea for sparse point clouds using the KPConv framework.

KPConv is very similar to an image convolution, except that the input features are located at neighbor points, which do not coincide with the kernel points where the convolution weights are defined. We therefore simply use a correlation matrix to project the features from the neighbor points onto the kernel points. Simple pseudo-code for the whole convolution looks like this:

# Dimensions:
#    N = number of points (equivalent to H*W the number of pixels in images)
#    H = number of neighbors per point (equivalent to K*K convolution input patch)
#    K = number of kernel points (equivalent to K*K convolution kernel size)
#   C1 = number of input feature channels
#   C2 = number of output feature channels

# Inputs:
#   input_features (N, H, C1)
#   neighbor_weights (N, K, H)
#   conv_weights (K, C1, C2)

# Code:

# Project feature from neighbors to kernel points
weighted_feats = torch.matmul(neighbor_weights, input_features)  # (N, K, H) x (N, H, C1) -> (N, K, C1)

# Apply convolution weights and sum over the whole kernel
output_feats = torch.einsum("nkc,kcd->nd", weighted_feats, conv_weights)  # (N, K, C1) x (K, C1, C2) -> (N, C2)

# Outputs:
#  output_feats (N, C2)
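
For reference, a self-contained shape check of this pseudo-code (the sizes below are hypothetical, chosen only to exercise the shapes; this is a sketch, not the actual KPConv code):

import torch

# Hypothetical sizes, for illustration only
N, H, K, C1, C2 = 1024, 26, 15, 64, 128

input_features = torch.randn(N, H, C1)
neighbor_weights = torch.randn(N, K, H)
conv_weights = torch.randn(K, C1, C2)

# Project features from neighbor points to kernel points (batched matmul)
weighted_feats = torch.matmul(neighbor_weights, input_features)  # (N, K, C1)

# Apply convolution weights and sum over kernel points and input channels
output_feats = torch.einsum("nkc,kcd->nd", weighted_feats, conv_weights)  # (N, C2)

assert output_feats.shape == (N, C2)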

KPConv is written with simple PyTorch operations, so for involution I naturally used an implementation similar to your naive PyTorch implementation:

# Dimensions:
#    N = number of points (equivalent to H*W the number of pixels in images)
#    H = number of neighbors per point (equivalent to K*K convolution input patch)
#    K = number of kernel points (equivalent to K*K convolution kernel size)
#    C = number of input and output feature channels
#    G = number of groups

# Inputs:
#   input_features (N, H, C)
#   neighbor_weights (N, K, H)

# Code:

# Get features at our point locations (like your 2D average pooling)
center_features = torch.mean(input_features, dim=1)  # average across neighbors (N, H, C) -> (N, C)

# Generate convolution weights
conv_weights = gen_mlp(reduce_mlp(center_features))  # (N, C) -> (N, C//r) -> (N, K*G)

# Project feature from neighbors to kernel points
weighted_feats = torch.matmul(neighbor_weights, input_features)  # (N, K, H) x (N, H, C) -> (N, K, C)

# Apply convolution weights and sum over the whole kernel
weighted_feats = weighted_feats.view(-1, K, G, C//G)  # (N, K, C) -> (N, K, G, C//G)
conv_weights = conv_weights.view(-1, K, G)  # (N, K*G) -> (N, K, G)
output_feats = torch.einsum("nkgc,nkg->ngc", weighted_feats, conv_weights)  # (N, K, G, C//G) x (N, K, G) -> (N, G, C//G)
output_feats = output_feats.view(-1, C)  # (N, G, C//G) -> (N, C)

# Outputs:
#  output_feats (N, C)
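
For completeness, here is a minimal runnable sketch of this involution as a module. It assumes reduce_mlp and gen_mlp are single Linear layers with a ReLU in between (an assumption on my side; the actual MLPs may include normalization or other layers):

import torch
import torch.nn as nn

class SparseInvolution(nn.Module):
    # Minimal sketch of the involution above. reduce_mlp and gen_mlp are
    # assumed to be plain Linear layers with a ReLU in between (assumption).
    def __init__(self, channels, kernel_points, groups, reduc_ratio=4):
        super().__init__()
        self.K, self.G, self.C = kernel_points, groups, channels
        self.reduce_mlp = nn.Linear(channels, channels // reduc_ratio)
        self.gen_mlp = nn.Linear(channels // reduc_ratio, kernel_points * groups)

    def forward(self, input_features, neighbor_weights):
        # input_features: (N, H, C), neighbor_weights: (N, K, H)
        N = input_features.shape[0]
        # Features at the point locations (analogue of the 2D average pooling)
        center_features = input_features.mean(dim=1)  # (N, C)
        # Generate one kernel per point
        conv_weights = self.gen_mlp(torch.relu(self.reduce_mlp(center_features)))  # (N, K*G)
        # Project features from neighbor points to kernel points
        weighted_feats = torch.matmul(neighbor_weights, input_features)  # (N, K, C)
        # Apply the generated kernel, shared within each channel group
        weighted_feats = weighted_feats.view(N, self.K, self.G, self.C // self.G)
        conv_weights = conv_weights.view(N, self.K, self.G)
        output_feats = torch.einsum("nkgc,nkg->ngc", weighted_feats, conv_weights)
        return output_feats.reshape(N, self.C)  # (N, C)

As a quick shape check, SparseInvolution(channels=64, kernel_points=15, groups=4) applied to input_features of shape (1024, 26, 64) and neighbor_weights of shape (1024, 15, 26) returns a (1024, 64) tensor.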
shangshang0912 commented Aug 10, 2022 via email

HuguesTHOMAS (Author)

Now, I tried this with the following involution parameters:

channels_per_group = 16
reduc_ratio = 4

and I noticed:

  • the involution drastically reduces the number of parameters in the network (a rough count is sketched after this list);
  • the GPU memory consumption is equivalent between convolution and involution;
  • the involution is slower (by ~20%) than the convolution;
  • the scores obtained by the involution are very low.
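
To make the first point concrete, here is a rough per-layer parameter count under hypothetical sizes (biases omitted; the actual layer sizes in my network may differ):

# Hypothetical per-layer sizes, for illustration only
K, C, r = 15, 256, 4
G = C // 16  # channels_per_group = 16  ->  G = 16

conv_params = K * C * C                        # conv_weights (K, C1, C2) with C1 = C2 = C
inv_params = C * (C // r) + (C // r) * K * G   # reduce_mlp + gen_mlp

print(conv_params)  # 983040
print(inv_params)   # 31744, roughly 31x fewer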

I have a few questions:

  1. First, about memory and computation speed: in my implementation, both the convolution and the involution are coded with naive PyTorch operations and are not optimized, so they should be comparable. Is it normal that the involution takes as much GPU memory as the convolution and is slower?

  2. About the scores: do you see a reason why I would get very low scores? Did you have to use a specific optimizer or learning-rate schedule for involution? Is there anything particular to know when training involution networks compared to convolutional ones?

  3. Does the involution network need a very big dataset to learn its parameters? Did you notice lower scores when training on easier tasks or datasets?

Thanks a lot for taking the time to read this; I hope you can enlighten me, because I can't figure out the problem. I have verified my implementation several times and did not find a bug in it.

Best,
Hugues
