In [1]:
import jax
import jax.numpy as jnp
import jax.random as random
from jax.nn import relu, softmax
import optax

# Generate random data: random point charges in 1D
def generate_data(key, num_samples, num_points, length=10.0):
    """Generates random 1D point charge data."""
    key_positions, key_labels = random.split(key)
    positions = random.uniform(key_positions, shape=(num_samples, num_points), minval=0.0, maxval=length)
    positions = jnp.sort(positions, axis=1)  # Sort positions

    # Coulomb potential computation
    def compute_coulomb_potential(positions):
        num_points = positions.shape[0]
        potential = jnp.zeros(num_points)
        for i in range(num_points):
            for j in range(num_points):
                if i != j:
                    potential = potential.at[i].add(1.0 / jnp.abs(positions[i] - positions[j]))
        return potential

    labels = jax.vmap(compute_coulomb_potential)(positions)
    return positions, labels

# Define the graph U-CNN model
class GraphUCNN:
    def __init__(self, layers):
        self.layers = layers

    def __call__(self, positions, adjacency_matrix):
        """Forward pass of the U-CNN."""
        x = positions[:, None]  # Add feature dimension
        for layer in self.layers:
            x = layer(x, adjacency_matrix)
        return x.squeeze(-1)

class GraphConvLayer:
    def __init__(self, input_dim, output_dim, key):
        key_w, key_b = random.split(key)
        self.w = random.normal(key_w, (input_dim, output_dim)) * 0.1
        self.b = random.normal(key_b, (output_dim,)) * 0.1

    def __call__(self, x, adjacency_matrix):
        """Graph convolution."""
        x = jnp.einsum("ij,jk->ik", adjacency_matrix, x)  # Aggregate messages
        x = x @ self.w + self.b  # Linear transformation
        return relu(x)  # Activation

# Train the model
key = random.PRNGKey(42)
num_samples = 1000
num_points = 10
positions, labels = generate_data(key, num_samples, num_points)

# Create adjacency matrices for 1D graph connections
def create_adjacency_matrix(positions):
    n = len(positions)
    adjacency_matrix = jnp.zeros((n, n))
    for i in range(n - 1):
        adjacency_matrix = adjacency_matrix.at[i, i + 1].set(1.0)
        adjacency_matrix = adjacency_matrix.at[i + 1, i].set(1.0)
    return adjacency_matrix

adjacency_matrices = jax.vmap(create_adjacency_matrix)(positions)

# Model and optimizer setup
key_layers = random.split(key, num=num_points)
layers = [GraphConvLayer(1, 16, key_layers[0]), GraphConvLayer(16, 32, key_layers[1]), GraphConvLayer(32, 1, key_layers[2])]
model = GraphUCNN(layers)

optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(model)

@jax.jit
def loss_fn(params, positions, adjacency_matrices, labels):
    preds = jax.vmap(model)(positions, adjacency_matrices)
    return jnp.mean((preds - labels) ** 2)

@jax.jit
def train_step(params, opt_state, positions, adjacency_matrices, labels):
    loss, grads = jax.value_and_grad(loss_fn)(params, positions, adjacency_matrices, labels)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Training loop
num_epochs = 100
params = model  # In this example, model params are in model itself
for epoch in range(num_epochs):
    params, opt_state, train_loss = train_step(params, opt_state, positions, adjacency_matrices, labels)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {train_loss:.4f}")


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


TypeError: zeros_like requires ndarray or scalar arguments, got <class '__main__.GraphUCNN'> at position 0.