In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_cluster import radius
from torch_scatter import scatter

class ContinuousConv(nn.Module):
    def __init__(self, in_channels, out_channels, filter_resolution=4, radius=0.5):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.radius = radius
        self.filter_resolution = filter_resolution
        # Learnable filter grid: shape (D, D, D, in_channels, out_channels)
        self.filters = nn.Parameter(torch.randn(filter_resolution, filter_resolution, filter_resolution, in_channels, out_channels))
        
    def ball_to_cube(self, r):
        """
        Maps a relative position r (shape: (E,3)) from the ball to a cube coordinate in [-1, 1]^3.
        Here we use a simplified mapping. For the full mapping, see Griepentrog et al. (2008).
        """
        norm = torch.norm(r, dim=-1, keepdim=True)  # (E,1)
        r_unit = r / (norm + 1e-8)                  # (E,3)
        return r_unit * torch.tanh(norm)            # (E,3)
    
    def trilinear_interpolate(self, coords):
        """
        Performs trilinear interpolation on the filter grid.
        coords: (E, 3) continuous coordinates in grid space, in [0, D-1]
        Returns: interpolated filters of shape (E, in_channels, out_channels)
        """
        D = self.filter_resolution
        # coords: (E,3) -> split into x, y, z
        x = coords[:, 0]
        y = coords[:, 1]
        z = coords[:, 2]
        
        # Get floor and ceiling indices
        x0 = x.floor().long()
        y0 = y.floor().long()
        z0 = z.floor().long()
        x1 = (x0 + 1).clamp(max=D-1)
        y1 = (y0 + 1).clamp(max=D-1)
        z1 = (z0 + 1).clamp(max=D-1)
        
        # Compute distances for interpolation weights
        xd = (x - x0.float()).view(-1, 1, 1)  # shape (E,1,1)
        yd = (y - y0.float()).view(-1, 1, 1)
        zd = (z - z0.float()).view(-1, 1, 1)
        
        # Gather filter values at the 8 corners.
        c000 = self.filters[x0, y0, z0]  # shape (E, in_channels, out_channels)
        c001 = self.filters[x0, y0, z1]
        c010 = self.filters[x0, y1, z0]
        c011 = self.filters[x0, y1, z1]
        c100 = self.filters[x1, y0, z0]
        c101 = self.filters[x1, y0, z1]
        c110 = self.filters[x1, y1, z0]
        c111 = self.filters[x1, y1, z1]
        
        # Trilinear interpolation
        c00 = c000 * (1 - zd) + c001 * zd
        c01 = c010 * (1 - zd) + c011 * zd
        c10 = c100 * (1 - zd) + c101 * zd
        c11 = c110 * (1 - zd) + c111 * zd
        
        c0 = c00 * (1 - yd) + c01 * yd
        c1 = c10 * (1 - yd) + c11 * yd
        
        c = c0 * (1 - xd) + c1 * xd  # (E, in_channels, out_channels)
        return c

        
    def forward(self, positions, features):
        """
        positions: Tensor of shape (N, 3) representing particle positions
        features: Tensor of shape (N, in_channels) representing particle features
        Returns:
            output: Tensor of shape (N, out_channels)
        """
        N = positions.shape[0] if positions.dim() == 2 else positions.shape[1]
        # Compute neighbor indices using radius search.
        # edge_index: Tensor of shape (2, num_edges)
        edge_index = radius(positions, positions, self.radius, max_num_neighbors=32)
        # Ensure row and col are 1D
        row, col = edge_index
        row = row.flatten()
        col = col.flatten()
        
        # Compute relative positions: r = x_neighbor - x_target, shape: (E, 3)
        r = positions[col] - positions[row]
        
        # Compute window function: a(xi, x) = (1 - ||r||^2/R^2)^3 for ||r||^2 < R^2, else 0.
        dist2 = (r ** 2).sum(dim=-1)  # (E,)
        valid = (dist2 < self.radius**2).float()
        window = ((1 - dist2 / (self.radius**2)) ** 3) * valid  # (E,)
        
        # Map relative coordinates via ball_to_cube.
        mapped = self.ball_to_cube(r)  # should have shape (E,3)
        # Convert from [-1, 1] to grid coordinates: -1 -> 0, 1 -> D-1.
        grid_coords = (mapped + 1) * ((self.filter_resolution - 1) / 2)
        
        # Trilinear interpolate the filters at these continuous coordinates.
        filt = self.trilinear_interpolate(grid_coords)  # (E, in_channels, out_channels)
        
        # Multiply the interpolated filter with source features.
        # features[col] has shape (E, in_channels); we sum over the in_channels.
        conv_edge = torch.einsum('eio,ei->eo', filt, features[col])  # (E, out_channels)
        
        # Weight by the window function.
        conv_edge = conv_edge * window.unsqueeze(1)
        
        # Aggregate contributions for each target particle.
        output = scatter(conv_edge, row, dim=0, dim_size=N, reduce="mean")
        
        # We use ψ(x)=1 (no additional normalization) as per the paper.
        return output

# Example usage:
if __name__ == "__main__":
    torch.manual_seed(0)
    # Create 3 particles with random positions (N,3) and 16 feature channels.
    pos = torch.randn(3, 250, 3)
    feats = torch.randn(3, 250, 7)
    
    conv = ContinuousConv(7, 3, filter_resolution=4, radius=0.5)
    out = conv(pos, feats)
    print("Output shape:", out.shape)
    print(out)


RuntimeError: x.dim() == 2 INTERNAL ASSERT FAILED at "csrc/cpu/radius_cpu.cpp":13, please report a bug to PyTorch. Input mismatch