In [1]:
%matplotlib qt
import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.pool import knn_graph
import numpy as np

def fibonacci_sphere(samples=1000):
    """Return ``samples`` near-uniformly spread points on the unit sphere.

    Points follow a Fibonacci (golden-spiral) lattice: the k-th point has polar
    angle ``arccos(1 - 2(k + 0.5)/samples)`` and an azimuth that advances by a
    golden-ratio multiple of pi per point.

    Returns:
        float32 array of shape (samples, 3), one (x, y, z) row per point.
    """
    k = np.arange(samples, dtype=float) + 0.5
    polar = np.arccos(1 - 2 * k / samples)
    azimuth = np.pi * (1 + 5**0.5) * k
    sin_polar = np.sin(polar)
    points = np.column_stack((np.cos(azimuth) * sin_polar,
                              np.sin(azimuth) * sin_polar,
                              np.cos(polar)))
    return points.astype(np.float32)

# Fix the torch RNG so the random node field below (and, when cells run in
# order, every later random draw such as model initialisation) is reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Uniform mesh: 2000 Fibonacci-lattice points joined by a k=8 nearest-neighbour graph.
coords = torch.tensor(fibonacci_sphere(2000))
edge_index = knn_graph(coords, k=8)

# One random scalar per vertex: the field that the diffusion step acts on.
data_uniform = Data(x=torch.randn(len(coords), 1), pos=coords, edge_index=edge_index)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def stretch_sphere(coords, stretch_factor=3.0):
    """Smoothly and non-uniformly deform points lying on the unit sphere.

    Each point ``(x, y, z)`` keeps its ``x`` component while ``y`` and ``z`` are
    scaled by ``w = exp(stretch_factor * (x - 1))``. The result is then projected
    back onto the unit sphere.

    For ``stretch_factor > 0``, ``w`` is 1 at the +x pole and decreases towards
    the -x pole. So points with ``x > 0`` drift towards +x, points with
    ``x < 0`` drift (more strongly, since ``w`` is smaller there) towards -x, and
    points with ``x == 0`` stay on that great circle. Vertices near the +x pole
    barely move.

    Args:
        coords: (N, 3) tensor of points on the unit sphere. It is not modified.
        stretch_factor: strength of the deformation.

    Returns:
        (N, 3) tensor of deformed unit vectors.
    """
    x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
    # Algebraically equal to exp(s*x) / exp(s), but it never forms exp(s).
    # That term overflows float32 (inf/inf -> NaN) once s exceeds ~88.
    weight = torch.exp(stretch_factor * (x - 1.0))
    stretched = torch.stack([x, weight * y, weight * z], dim=1)
    # Renormalise back onto the unit sphere.
    return stretched / stretched.norm(dim=1, keepdim=True)

# Deform the same vertices and rebuild the k=8 nearest-neighbour graph on the new geometry.
coords_stretched = stretch_sphere(coords.clone(), stretch_factor=2.0)
edge_index_stretched = knn_graph(coords_stretched, k=8)

# Reuse the uniform mesh's node field: vertex i keeps its value after the
# deformation. The cross-grid evaluation then changes only the geometry, not
# the input signal, so any error increase is attributable to the mesh.
data_stretched = Data(x=data_uniform.x.clone(),
                      pos=coords_stretched,
                      edge_index=edge_index_stretched)


In [3]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

class DiffusionGNN(MessagePassing):
    """Learned one-step diffusion operator on a point-cloud graph.

    For every edge, an MLP maps the feature difference ``x_j - x_i`` plus
    geometric edge features to a message. Messages are mean-aggregated per
    target node and added residually to ``x``.

    Args:
        hidden: width of the edge MLP's hidden layers.
        in_channels: number of node-feature channels ``C``. The MLP emits ``C``
            values per edge, so the residual update preserves the shape of
            ``x``. The default of 1 matches the scalar-field setting.
        include_dist: if True, append the edge length to the relative-position
            edge features.
    """

    def __init__(self, hidden=64, in_channels=1, include_dist=False):
        super().__init__(aggr='mean')
        self.include_dist = include_dist
        # 3 relative-position components (assumes 3-D pos), plus optionally |rel|.
        edge_dim = 3 + (1 if include_dist else 0)
        self.edge_mlp = nn.Sequential(
            nn.Linear(in_channels + edge_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, in_channels),
        )

    def forward(self, x, pos, edge_index):
        """Apply one update step. Returns a tensor shaped like ``x``."""
        edge_attr = self._compute_edge_attr(pos, edge_index)
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, x_i, edge_attr):
        # Per PyG convention (default source_to_target flow), x_j comes from the
        # source node and x_i from the target node of each edge.
        rel_feat = torch.cat([x_j - x_i, edge_attr], dim=-1)
        return self.edge_mlp(rel_feat)

    def update(self, aggr_out, x):
        return x + aggr_out  # residual diffusion update

    def _compute_edge_attr(self, pos, edge_index):
        """Geometric edge features: ``pos[dst] - pos[src]``, plus its norm if enabled."""
        src, dst = edge_index
        # Target-minus-source offset. This sign is kept so that existing trained
        # weights remain valid.
        rel_vec = pos[dst] - pos[src]
        if not self.include_dist:
            return rel_vec
        return torch.cat([rel_vec, rel_vec.norm(dim=-1, keepdim=True)], dim=-1)


In [4]:
def generate_diffusion_targets(data, kappa=0.1, dt=0.05):
    """Compute one explicit diffusion step, used as ground-truth training target.

    For each node ``i`` the Laplacian is approximated as the mean, over edges
    whose *source* is ``i``, of ``w_ij * (x_j - x_i)``. The weights are inverse
    distances, ``w_ij = 1 / (|p_i - p_j| + 1e-6)``. The target is
    ``x + kappa * dt * lap``.

    Nodes with no outgoing edge get ``lap = 0``, so they keep their value.
    Averaging is done per feature channel.

    Args:
        data: object with ``x`` (N, C) node features, ``edge_index`` (2, E)
            integer tensor, and ``pos`` (N, D) node positions.
        kappa: diffusion coefficient.
        dt: time step.

    Returns:
        Tensor with the same shape and dtype as ``data.x``.
    """
    x, edge_index, pos = data.x, data.edge_index, data.pos
    src, dst = edge_index
    dist = (pos[src] - pos[dst]).norm(dim=-1)
    w = 1.0 / (dist + 1e-6)

    # Weighted differences per edge, scatter-summed onto each source node
    # (a single O(E) pass rather than one boolean mask per node).
    contrib = (w[:, None] * (x[dst] - x[src])).to(x.dtype)
    lap_sum = torch.zeros_like(x).index_add_(0, src, contrib)

    # Out-degree per node. Clamping to 1 turns "mean of nothing" into 0
    # instead of NaN for nodes that are never an edge source.
    deg = torch.bincount(src, minlength=x.shape[0]).clamp_min(1)
    lap = lap_sum / deg.to(x.dtype)[:, None]
    return x + kappa * dt * lap


In [5]:
import torch.optim as optim

# Full-batch fit of the GNN to a single one-step diffusion target on the uniform mesh.
data = data_uniform
y_true = generate_diffusion_targets(data)

model = DiffusionGNN()
opt = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1000):
    opt.zero_grad()
    y_pred = model(data.x, data.pos, data.edge_index)
    loss = F.mse_loss(y_pred, y_true)
    loss.backward()
    opt.step()
    if epoch % 50:
        continue
    print(f"Epoch {epoch}: {loss.item():.6f}")


Epoch 0: 0.020179
Epoch 50: 0.000046
Epoch 100: 0.000028
Epoch 150: 0.000026
Epoch 200: 0.000026
Epoch 250: 0.000025
Epoch 300: 0.000025
Epoch 350: 0.000025
Epoch 400: 0.000024
Epoch 450: 0.000024
Epoch 500: 0.000024
Epoch 550: 0.000024
Epoch 600: 0.000024
Epoch 650: 0.000024
Epoch 700: 0.000024
Epoch 750: 0.000024
Epoch 800: 0.000024
Epoch 850: 0.000023
Epoch 900: 0.000023
Epoch 950: 0.000023


In [6]:
# Zero-shot evaluation of the uniform-trained model on the stretched mesh.
with torch.no_grad():
    true_stretched = generate_diffusion_targets(data_stretched)
    pred_stretched = model(data_stretched.x, data_stretched.pos, data_stretched.edge_index)
    err = F.mse_loss(pred_stretched, true_stretched)
print("Cross-grid (stretched) MSE:", err.item())


Cross-grid (stretched) MSE: 0.44955042004585266


In [8]:
import matplotlib.pyplot as plt

def plot_sphere(data, values, title):
    """Show a 3-D scatter of per-node ``values`` at the node positions ``data.pos``.

    Args:
        data: graph whose ``pos`` tensor has one (x, y, z) row per node.
        values: per-node tensor with one entry per node (e.g. shape (N,) or (N, 1)).
        title: figure title.
    """
    # detach()/cpu() so tensors that require grad or live on a GPU can be plotted;
    # a bare .numpy() raises for both.
    pos = data.pos.detach().cpu().numpy()
    colors = values.detach().cpu().squeeze().numpy()

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='3d')
    sc = ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], c=colors, cmap='coolwarm', s=5)
    fig.colorbar(sc, ax=ax)
    ax.set(xlabel='x', ylabel='y', zlabel='z', title=title)
    plt.show()

# Ground truth on the training mesh, then truth vs. prediction on the SAME
# stretched mesh, so the last two figures are directly comparable.
plot_sphere(data_uniform, y_true, "True (uniform)")
plot_sphere(data_stretched, true_stretched, "True (stretched)")
plot_sphere(data_stretched, pred_stretched, "Predicted (stretched)")


In [8]:
# plot coords
%matplotlib qt
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Overlay the original (blue) and stretched (red) vertex positions.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], s=1, c='b', label='uniform')
ax.scatter(coords_stretched[:, 0], coords_stretched[:, 1], coords_stretched[:, 2],
           s=1, c='r', label='stretched')
ax.set(xlabel='x', ylabel='y', zlabel='z', title='Uniform vs. stretched mesh vertices')
ax.legend(markerscale=5)  # s=1 markers are too small to read in the legend otherwise
plt.show()

Data(x=[2000, 1], edge_index=[2, 16000], pos=[2000, 3])

