In [1]:
%matplotlib qt
import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.pool import knn_graph
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

def generate_grid(num_points_per_dim=50, size=1.0):
    """Generate approximately uniform points on a sphere."""
    x = np.linspace(-size, size, num_points_per_dim)
    y = np.linspace(-size, size, num_points_per_dim)
    xx, yy = np.meshgrid(x, y)
    return np.stack([xx.ravel(), yy.ravel()], axis=1).astype(np.float32)


def generate_stretched_grid(num_points_per_dim=50, size=1.0, stretch_factor=2.0):
    """Generate non-uniformly stretched points on a sphere."""
    coarse_coords = generate_grid(num_points_per_dim, size)
    # remove_coarse_coords = np.logical_and(
    #     np.abs(coarse_coords[:, 0]) < size / stretch_factor,
    #     np.abs(coarse_coords[:, 1]) < size / stretch_factor,
    # )
    # coarse_coords = coarse_coords[~remove_coarse_coords]
    fine_coords = generate_grid(num_points_per_dim, size / stretch_factor)
    coords = np.concatenate([coarse_coords, fine_coords], axis=0)
    return coords.astype(np.float32)

# uniform mesh
num_points_per_dim = 30
stretch_factor = 5
coords = torch.tensor(generate_grid(num_points_per_dim))
stretched_coords = torch.tensor(generate_stretched_grid(num_points_per_dim, stretch_factor=stretch_factor))
delta_init = torch.zeros(len(coords), 1)
delta_init_stretched = torch.zeros(len(stretched_coords), 1)
# find coord_index closest to (0, 0.5)
center_coord = torch.tensor([-1./stretch_factor, 1./stretch_factor])
distances = torch.norm(coords - center_coord, dim=1)
center_index = torch.argmin(distances)
delta_init[center_index] = 1.0  # delta init at center
distances_stretched = torch.norm(stretched_coords - center_coord, dim=1)
center_index_stretched = torch.argmin(distances_stretched)
delta_init_stretched[center_index_stretched] = 1.0  # delta init at center

edge_index = knn_graph(coords, k=8)
edge_index_stretched = knn_graph(stretched_coords, k=4)
edge_index_stretched = torch.concat([edge_index, edge_index_stretched], dim=1)
edge_index_stretched = torch.unique(edge_index_stretched, dim=1)
data_uniform = Data(x=delta_init, pos=coords, edge_index=edge_index)
data_stretched = Data(x=delta_init_stretched, pos=stretched_coords,
                        edge_index=edge_index_stretched)
# display graph
# matplotlib subplot
plt.figure(figsize=(6, 6))
plt.subplot(2, 1, 1)
G = nx.Graph()
for i in range(data_uniform.pos.size(0)):
    G.add_node(i, pos=(data_uniform.pos[i, 0].item(), data_uniform.pos[i, 1].item()))
for i, j in data_uniform.edge_index.t().tolist():
    G.add_edge(i, j)
pos = nx.get_node_attributes(G, 'pos')
nx.draw(G, pos, node_size=10, with_labels=False)
plt.show()

plt.subplot(2, 1, 2)
G_stretched = nx.Graph()
for i in range(data_stretched.pos.size(0)):
    G_stretched.add_node(i, pos=(data_stretched.pos[i, 0].item(), data_stretched.pos[i, 1].item()))
for i, j in data_stretched.edge_index.t().tolist():
    G_stretched.add_edge(i, j)
pos_stretched = nx.get_node_attributes(G_stretched, 'pos')
nx.draw(G_stretched, pos_stretched, node_size=10, with_labels=False, edge_color='r', node_color='r')
plt.show()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing

class DiffusionGNN(MessagePassing):
    def __init__(self, hidden=64):
        super().__init__(aggr='mean')
        self.edge_mlp = nn.Sequential(
            nn.Linear(1 + 3, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 1)
        )

    def forward(self, x, pos, edge_index):
        edge_attr = self._compute_edge_attr(pos, edge_index)
        return self.propagate(edge_index, x=x, edge_attr=edge_attr, pos=pos)

    def message(self, x_j, x_i, edge_attr):
        rel_feat = torch.cat([x_j - x_i, edge_attr], dim=-1)
        return self.edge_mlp(rel_feat)

    def update(self, aggr_out, x):
        return x + aggr_out  # residual diffusion update

    def _compute_edge_attr(self, pos, edge_index):
        src, dst = edge_index
        rel_vec = pos[dst] - pos[src]
        dist = rel_vec.norm(dim=-1, keepdim=True)
        return rel_vec  # can include dist as extra feature if desired


In [3]:
def generate_diffusion_targets(data, kappa=0.1, dt=0.05):
    """Generate next-step ground truth using Laplacian operator approximation."""
    x, edge_index, pos = data.x, data.edge_index, data.pos
    src, dst = edge_index
    dist = (pos[src] - pos[dst]).norm(dim=-1)
    w = 1.0 / (dist + 1e-7)
    lap = torch.zeros_like(x)
    for i in range(x.shape[0]):
        mask = (src == i)
        neighbors = dst[mask]
        weights = w[mask][:, None]
        lap[i] = (weights * (x[neighbors] - x[i])).mean()
    return x + kappa * dt * lap


In [4]:
from copy import copy
from tqdm import tqdm
data_sim = copy(data_uniform)
for step in tqdm(range(200)):
    y_true = generate_diffusion_targets(data_sim, kappa=1., dt=0.1)
    data_sim.x = y_true
color_sim = data_sim.x    
plt.figure(figsize=(6, 6))
plt.subplot(1, 1, 1)
plt.scatter(data_sim.pos[:, 0], data_sim.pos[:, 1], c=color_sim.squeeze().numpy(), cmap='viridis')
plt.title('Uniform Grid Diffusion Targets')
plt.colorbar()
plt.show()

100%|██████████| 200/200 [00:08<00:00, 22.57it/s]


In [5]:
from copy import copy
from tqdm import tqdm
data_sim_stretched = copy(data_stretched)
for step in tqdm(range(500)):
    y_true = generate_diffusion_targets(data_sim_stretched, kappa=1., dt=0.1)
    data_sim_stretched.x = y_true
color_sim = data_sim_stretched.x    
plt.figure(figsize=(6, 6))
plt.subplot(1, 1, 1)
plt.scatter(data_sim_stretched.pos[:, 0], data_sim_stretched.pos[:, 1], c=color_sim.squeeze().numpy(), cmap='viridis')
plt.title('Stretch Grid Diffusion Targets')
plt.colorbar()
plt.show()

100%|██████████| 500/500 [00:49<00:00, 10.17it/s]


In [68]:
from copy import copy
from tqdm import tqdm
data_sim = copy(data_uniform)
for step in tqdm(range(100)):
    y_true = generate_diffusion_targets(data_sim, kappa=0.1, dt=0.05)
    data_sim.x = y_true
color_sim = data_sim.x    
#color = generate_diffusion_targets(data_uniform, kappa=1.)
#color_stretched = generate_diffusion_targets(data_stretched)
plt.figure(figsize=(6, 6))
plt.subplot(1, 1, 1)
plt.scatter(data_sim.pos[:, 0], data_sim.pos[:, 1], c=color_sim.squeeze().numpy(), cmap='viridis')
plt.title('Uniform Grid Diffusion Targets')
plt.colorbar()
plt.show()

100%|██████████| 100/100 [00:03<00:00, 27.06it/s]


In [5]:
import torch.optim as optim

model = DiffusionGNN()
opt = optim.Adam(model.parameters(), lr=1e-3)

data = data_uniform
y_true = generate_diffusion_targets(data)

for epoch in range(1000):
    opt.zero_grad()
    y_pred = model(data.x, data.pos, data.edge_index)
    loss = F.mse_loss(y_pred, y_true)
    loss.backward()
    opt.step()
    if epoch % 50 == 0:
        print(f"Epoch {epoch}: {loss.item():.6f}")


Epoch 0: 0.020179
Epoch 50: 0.000046
Epoch 100: 0.000028
Epoch 150: 0.000026
Epoch 200: 0.000026
Epoch 250: 0.000025
Epoch 300: 0.000025
Epoch 350: 0.000025
Epoch 400: 0.000024
Epoch 450: 0.000024
Epoch 500: 0.000024
Epoch 550: 0.000024
Epoch 600: 0.000024
Epoch 650: 0.000024
Epoch 700: 0.000024
Epoch 750: 0.000024
Epoch 800: 0.000024
Epoch 850: 0.000023
Epoch 900: 0.000023
Epoch 950: 0.000023


In [6]:
with torch.no_grad():
    pred_stretched = model(data_stretched.x, data_stretched.pos, data_stretched.edge_index)
    true_stretched = generate_diffusion_targets(data_stretched)
    err = F.mse_loss(pred_stretched, true_stretched)
    print("Cross-grid (stretched) MSE:", err.item())


Cross-grid (stretched) MSE: 0.44955042004585266


In [8]:
import matplotlib.pyplot as plt

def plot_sphere(data, values, title):
    fig = plt.figure(figsize=(5,5))
    ax = fig.add_subplot(111, projection='3d')
    p = ax.scatter(data.pos[:,0], data.pos[:,1], data.pos[:,2],
                   c=values.squeeze().numpy(), cmap='coolwarm', s=5)
    plt.colorbar(p)
    ax.set_title(title)
    plt.show()

plot_sphere(data_uniform, y_true, "True (uniform)")
plot_sphere(data_stretched, pred_stretched, "Predicted (stretched)")


In [8]:
# plot coords
%matplotlib qt
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], s=1, c='b')
ax.scatter(coords_stretched[:, 0], coords_stretched[:, 1], coords_stretched[:, 2], s=1, c='r')
plt.show()

Data(x=[2000, 1], edge_index=[2, 16000], pos=[2000, 3])

