In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_cluster import radius
from torch_scatter import scatter
from time import time
from torch_geometric.utils import add_self_loops

class ContinuousConv(nn.Module):
    def __init__(self, in_channels, out_channels, filter_resolution=4, radius=0.5):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.radius = radius
        self.filter_resolution = filter_resolution
        # Learnable filter grid: shape (D, D, D, in_channels, out_channels)
        self.filters = nn.Parameter(torch.randn(filter_resolution, filter_resolution, filter_resolution, in_channels, out_channels))
        
    def ball_to_cube(self, r):
        """
        Maps a relative position r (shape: (E,3)) from the ball to a cube coordinate in [-1, 1]^3.
        Here we use a simplified mapping. For the full mapping, see Griepentrog et al. (2008).
        """
        norm = torch.norm(r, dim=-1, keepdim=True)  # (E,1)
        r_unit = r / (norm + 1e-8)                  # (E,3)
        return r_unit * torch.tanh(norm)            # (E,3)
    
    def trilinear_interpolate(self, coords):
        """
        Performs trilinear interpolation on the filter grid.
        coords: (E, 3) continuous coordinates in grid space, in [0, D-1]
        Returns: interpolated filters of shape (E, in_channels, out_channels)
        """
        D = self.filter_resolution
        # coords: (E,3) -> split into x, y, z
        x = coords[:, 0]
        y = coords[:, 1]
        z = coords[:, 2]
        
        # Get floor and ceiling indices
        x0 = x.floor().long()
        y0 = y.floor().long()
        z0 = z.floor().long()
        x1 = (x0 + 1).clamp(max=D-1)
        y1 = (y0 + 1).clamp(max=D-1)
        z1 = (z0 + 1).clamp(max=D-1)
        
        # Compute distances for interpolation weights
        xd = (x - x0.float()).view(-1, 1, 1)  # shape (E,1,1)
        yd = (y - y0.float()).view(-1, 1, 1)
        zd = (z - z0.float()).view(-1, 1, 1)
        
        # Gather filter values at the 8 corners.
        c000 = self.filters[x0, y0, z0]  # shape (E, in_channels, out_channels)
        c001 = self.filters[x0, y0, z1]
        c010 = self.filters[x0, y1, z0]
        c011 = self.filters[x0, y1, z1]
        c100 = self.filters[x1, y0, z0]
        c101 = self.filters[x1, y0, z1]
        c110 = self.filters[x1, y1, z0]
        c111 = self.filters[x1, y1, z1]
        
        # Trilinear interpolation
        c00 = c000 * (1 - zd) + c001 * zd
        c01 = c010 * (1 - zd) + c011 * zd
        c10 = c100 * (1 - zd) + c101 * zd
        c11 = c110 * (1 - zd) + c111 * zd
        
        c0 = c00 * (1 - yd) + c01 * yd
        c1 = c10 * (1 - yd) + c11 * yd
        
        c = c0 * (1 - xd) + c1 * xd  # (E, in_channels, out_channels)
        return c

        
    def forward(self, positions, features):
        """
        positions: Tensor of shape (N, 3) representing particle positions
        features: Tensor of shape (N, in_channels) representing particle features
        Returns:
            output: Tensor of shape (N, out_channels)
        """
        N = positions.shape[0] if positions.dim() == 2 else positions.shape[1]
        # Compute neighbor indices using radius search.
        # edge_index: Tensor of shape (2, num_edges)
        edge_index = radius(positions, positions, self.radius, max_num_neighbors=32)
        # Ensure row and col are 1D
        row, col = edge_index
        row = row.flatten()
        col = col.flatten()
        
        # Compute relative positions: r = x_neighbor - x_target, shape: (E, 3)
        r = positions[col] - positions[row]
        
        # Compute window function: a(xi, x) = (1 - ||r||^2/R^2)^3 for ||r||^2 < R^2, else 0.
        dist2 = (r ** 2).sum(dim=-1)  # (E,)
        valid = (dist2 < self.radius**2).float()
        window = ((1 - dist2 / (self.radius**2)) ** 3) * valid  # (E,)
        
        # Map relative coordinates via ball_to_cube.
        mapped = self.ball_to_cube(r)  # should have shape (E,3)
        # Convert from [-1, 1] to grid coordinates: -1 -> 0, 1 -> D-1.
        grid_coords = (mapped + 1) * ((self.filter_resolution - 1) / 2)
        
        # Trilinear interpolate the filters at these continuous coordinates.
        filt = self.trilinear_interpolate(grid_coords)  # (E, in_channels, out_channels)
        
        # Multiply the interpolated filter with source features.
        # features[col] has shape (E, in_channels); we sum over the in_channels.
        conv_edge = torch.einsum('eio,ei->eo', filt, features[col])  # (E, out_channels)
        
        # Weight by the window function.
        conv_edge = conv_edge * window.unsqueeze(1)
        
        # Aggregate contributions for each target particle.
        output = scatter(conv_edge, row, dim=0, dim_size=N, reduce="mean")
        
        # We use ψ(x)=1 (no additional normalization) as per the paper.
        return output

# Example usage:
if __name__ == "__main__":
    torch.manual_seed(0)
    # Create 3 particles with random positions (N,3) and 16 feature channels.
    pos = torch.randn(10000, 3)
    feats = torch.randn(10000, 7)
    
    conv = ContinuousConv(7, 3, filter_resolution=4, radius=0.5)
    start = time()
    out = conv(pos, feats)
    end = time()
    print("Forward pass took", end - start, "seconds.")
    print("Output shape:", out.shape)
    print(out)


Forward pass took 0.1304917335510254 seconds.
Output shape: torch.Size([10000, 3])
tensor([[-0.1385, -0.0430, -0.1563],
        [-0.0022,  0.0188, -0.0830],
        [ 0.0561, -0.0546, -0.0832],
        ...,
        [-0.0430,  0.0866, -0.0392],
        [-0.0804,  0.0928,  0.0526],
        [ 0.0265,  0.0318,  0.1496]], grad_fn=<DivBackward0>)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_cluster import radius
from torch_scatter import scatter
from time import time

class ContinuousConv(nn.Module):
    def __init__(self, in_channels, out_channels, filter_resolution=4, radius=0.5):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.radius = radius
        self.filter_resolution = filter_resolution
        # Learnable filter grid: shape (D, D, D, in_channels, out_channels)
        self.filters = nn.Parameter(torch.randn(filter_resolution, filter_resolution, filter_resolution, in_channels, out_channels))
        
    def ball_to_cube(self, r):
        """
        Maps a relative position r (shape: (E,3)) from the ball to a cube coordinate in [-1, 1]^3.
        Here we use a simplified mapping. For the full mapping, see Griepentrog et al. (2008).
        """
        norm = torch.norm(r, dim=-1, keepdim=True)  # (E,1)
        r_unit = r / (norm + 1e-8)                  # (E,3)
        return r_unit * torch.tanh(norm)            # (E,3)
    
    def trilinear_interpolate(self, coords):
        """
        Performs trilinear interpolation on the filter grid.
        coords: (E, 3) continuous coordinates in grid space, in [0, D-1]
        Returns: interpolated filters of shape (E, in_channels, out_channels)
        """
        D = self.filter_resolution
        # coords: (E,3) -> split into x, y, z
        x = coords[:, 0]
        y = coords[:, 1]
        z = coords[:, 2]
        
        # Get floor and ceiling indices
        x0 = x.floor().long()
        y0 = y.floor().long()
        z0 = z.floor().long()
        x1 = (x0 + 1).clamp(max=D-1)
        y1 = (y0 + 1).clamp(max=D-1)
        z1 = (z0 + 1).clamp(max=D-1)
        
        # Compute distances for interpolation weights
        xd = (x - x0.float()).view(-1, 1, 1)  # shape (E,1,1)
        yd = (y - y0.float()).view(-1, 1, 1)
        zd = (z - z0.float()).view(-1, 1, 1)
        
        # Gather filter values at the 8 corners.
        c000 = self.filters[x0, y0, z0]  # shape (E, in_channels, out_channels)
        c001 = self.filters[x0, y0, z1]
        c010 = self.filters[x0, y1, z0]
        c011 = self.filters[x0, y1, z1]
        c100 = self.filters[x1, y0, z0]
        c101 = self.filters[x1, y0, z1]
        c110 = self.filters[x1, y1, z0]
        c111 = self.filters[x1, y1, z1]
        
        # Trilinear interpolation
        c00 = c000 * (1 - zd) + c001 * zd
        c01 = c010 * (1 - zd) + c011 * zd
        c10 = c100 * (1 - zd) + c101 * zd
        c11 = c110 * (1 - zd) + c111 * zd
        
        c0 = c00 * (1 - yd) + c01 * yd
        c1 = c10 * (1 - yd) + c11 * yd
        
        c = c0 * (1 - xd) + c1 * xd  # (E, in_channels, out_channels)
        return c
        
    def forward(self, positions, features):
        """
        positions: Tensor of shape (N, 3) representing particle positions
        features: Tensor of shape (N, in_channels) representing particle features
        Returns:
            output: Tensor of shape (N, out_channels)
        """
        N = positions.shape[0] if positions.dim() == 2 else positions.shape[1]
        # Compute neighbor indices using radius search.
        # edge_index: Tensor of shape (2, num_edges)
        edge_index = radius(positions, positions, self.radius, max_num_neighbors=32)
        # Ensure row and col are 1D
        row, col = edge_index
        row = row.flatten()
        col = col.flatten()
        
        # Compute relative positions: r = x_neighbor - x_target, shape: (E, 3)
        r = positions[col] - positions[row]
        
        # Compute window function: a(xi, x) = (1 - ||r||^2/R^2)^3 for ||r||^2 < R^2, else 0.
        dist2 = (r ** 2).sum(dim=-1)  # (E,)
        valid = (dist2 < self.radius**2).float()
        window = ((1 - dist2 / (self.radius**2)) ** 3) * valid  # (E,)
        
        # Map relative coordinates via ball_to_cube.
        mapped = self.ball_to_cube(r)  # should have shape (E,3)
        # Convert from [-1, 1] to grid coordinates: -1 -> 0, 1 -> D-1.
        grid_coords = (mapped + 1) * ((self.filter_resolution - 1) / 2)
        
        # Trilinear interpolate the filters at these continuous coordinates.
        filt = self.trilinear_interpolate(grid_coords)  # (E, in_channels, out_channels)
        
        # Multiply the interpolated filter with source features.
        # features[col] has shape (E, in_channels); we sum over the in_channels.
        conv_edge = torch.einsum('eio,ei->eo', filt, features[col])  # (E, out_channels)
        
        # Weight by the window function.
        conv_edge = conv_edge * window.unsqueeze(1)
        
        # Aggregate contributions for each target particle.
        output = scatter(conv_edge, row, dim=0, dim_size=N, reduce="mean")
        
        # We use ψ(x)=1 (no additional normalization) as per the paper.
        return output
    

        


# Example usage:
if __name__ == "__main__":
    torch.manual_seed(0)
    # Create 3 particles with random positions (N,3) and 16 feature channels.
    pos = torch.randn(10000, 3)
    feats = torch.randn(10000, 7)
    
    conv = ContinuousConv(7, 3, filter_resolution=4, radius=0.5)
    start = time()
    out = conv(pos, feats)
    end = time()
    print("Forward pass took", end - start, "seconds.")
    print("Output shape:", out.shape)
    print(out)


In [32]:
class ContinuousConv(nn.Module):
    def __init__(self, in_channels, out_channels, filter_resolution=4, radius=0.5, self_connection=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.radius = radius
        self.filter_resolution = filter_resolution
        self.self_connection = self_connection
        self.filters = nn.Parameter(torch.randn(filter_resolution, filter_resolution, filter_resolution, in_channels, out_channels))
        
    def ball_to_cube(self, r):
        norm = torch.norm(r, dim=-1, keepdim=True)
        r_unit = r / (norm + 1e-8)
        return r_unit * torch.tanh(norm)
    
    def trilinear_interpolate(self, coords):
        D = self.filter_resolution
        x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
        x0, y0, z0 = x.floor().long(), y.floor().long(), z.floor().long()
        x1, y1, z1 = (x0 + 1).clamp(max=D-1), (y0 + 1).clamp(max=D-1), (z0 + 1).clamp(max=D-1)
        xd, yd, zd = (x - x0.float()).view(-1, 1, 1), (y - y0.float()).view(-1, 1, 1), (z - z0.float()).view(-1, 1, 1)
        
        c000, c001 = self.filters[x0, y0, z0], self.filters[x0, y0, z1]
        c010, c011 = self.filters[x0, y1, z0], self.filters[x0, y1, z1]
        c100, c101 = self.filters[x1, y0, z0], self.filters[x1, y0, z1]
        c110, c111 = self.filters[x1, y1, z0], self.filters[x1, y1, z1]
        
        c00, c01, c10, c11 = c000 * (1 - zd) + c001 * zd, c010 * (1 - zd) + c011 * zd, c100 * (1 - zd) + c101 * zd, c110 * (1 - zd) + c111 * zd
        c0, c1 = c00 * (1 - yd) + c01 * yd, c10 * (1 - yd) + c11 * yd
        return c0 * (1 - xd) + c1 * xd
    
    def forward(self, data):
        positions, features = data.pos, data.x
        batch = data.batch if hasattr(data, 'batch') else torch.zeros(positions.size(0), dtype=torch.long, device=positions.device)
        
        edge_index = radius(positions, positions, self.radius, batch_x=batch, batch_y=batch, max_num_neighbors=32)
        print(edge_index.shape)
        if self.self_connection:
            edge_index, _ = add_self_loops(edge_index)
        print(edge_index.shape)
        row, col = edge_index[0], edge_index[1]
        
        r = positions[col] - positions[row]
        print(r.shape)
        dist2 = (r ** 2).sum(dim=-1)
        valid = (dist2 < self.radius**2).float()
        window = ((1 - dist2 / (self.radius**2)) ** 3) * valid
        
        mapped = self.ball_to_cube(r)
        grid_coords = (mapped + 1) * ((self.filter_resolution - 1) / 2)
        print(grid_coords.shape)
        filt = self.trilinear_interpolate(grid_coords)
        print(filt.shape)
        conv_edge = torch.einsum('eio,ei->eo', filt, features[col])
        conv_edge = conv_edge * window.unsqueeze(1)
        
        output = scatter(conv_edge, row, dim=0, dim_size=positions.size(0), reduce="mean")
        return output
    
from torch_geometric.data import Data, Batch

# Create sample data for two simple graphs

# Graph 1: 10 nodes with random 3D positions and 3-dimensional features.
pos1 = torch.rand(10, 3)
x1 = torch.rand(10, 5)  # assuming in_channels = 3
acc = torch.rand(10, 3)
data1 = Data(pos=pos1, x=x1, acc=acc)

# Graph 2: 15 nodes with random 3D positions and 3-dimensional features.
pos2 = torch.rand(15, 3)
x2 = torch.rand(15, 5)  # assuming in_channels = 3
acc = torch.rand(15, 3)
data2 = Data(pos=pos2, x=x2, acc=acc)

# Batch the graphs into a single Batch object.
batch_data = Batch.from_data_list([data1, data2])
model = ContinuousConv(in_channels=5, out_channels=3, filter_resolution=4, radius=0.5)

print(model(batch_data))


torch.Size([2, 143])
torch.Size([2, 168])
torch.Size([168, 3])
torch.Size([168, 3])
torch.Size([168, 5, 3])
tensor([[-0.0533, -0.1668, -0.0093],
        [-0.0381, -0.2061, -0.0651],
        [-0.0129, -0.2045, -0.1225],
        [-0.1168, -0.3230,  0.1320],
        [-0.0642, -0.2461, -0.0149],
        [-0.0053, -0.1680,  0.0161],
        [-0.1595, -0.4142,  0.0877],
        [-0.0611, -0.2656,  0.0544],
        [-0.1566, -0.3811, -0.0055],
        [-0.0204, -0.2228, -0.0267],
        [-0.0574, -0.2421,  0.0308],
        [-0.0601, -0.3400, -0.0209],
        [-0.1851, -0.2070,  0.1003],
        [-0.2794, -0.4179,  0.1564],
        [-0.0259, -0.1924,  0.0774],
        [-0.0833, -0.1909,  0.0717],
        [-0.0355, -0.1606,  0.0351],
        [-0.1438, -0.2334,  0.1545],
        [-0.0434, -0.3050, -0.0505],
        [-0.0776, -0.3524,  0.0103],
        [ 0.0044, -0.1516,  0.0775],
        [-0.1626, -0.2761,  0.1048],
        [-0.0843, -0.4863, -0.0470],
        [-0.1551, -0.3526,  0.0592],
    