In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial

In [2]:
class GraphConvolution(nn.Module):
    features: int # Output feature dimensionality for nodes

    @nn.compact
    def __call__(self, node_features, adj):
        # node_features shape: (num_nodes, input_features)
        # adj shape: (num_nodes, num_nodes)

        # BUG 1: Flawed / Incomplete Graph Convolution
        # This transformation only considers the node itself, ignoring neighbors.
        # A proper GCN layer should aggregate features from neighbors using 'adj'.
        # Example: `jnp.dot(adj, transformed_features)` would sum neighbor features.
        transformed_features = nn.Dense(features=self.features)(node_features)

        # This update rule doesn't actually use the graph structure (adj matrix)!
        # It's essentially just applying an MLP to each node independently.
        updated_features = nn.relu(transformed_features)

        # Shape: (num_nodes, self.features)
        return updated_features

In [3]:
class SimpleGNN(nn.Module):
    hidden_dim: int
    num_layers: int

    @nn.compact
    def __call__(self, node_features, adj, training: bool):
        # node_features: (batch_size, num_nodes, node_feature_dim)
        # adj: (batch_size, num_nodes, num_nodes)
        # Note: For simplicity, assuming all graphs in batch have same num_nodes (padding)

        x = node_features
        for _ in range(self.num_layers):
            # Pass adjacency matrix, even though the buggy layer doesn't use it
            x_new = GraphConvolution(features=self.hidden_dim)(x, adj)
            x = x + x_new # Basic residual connection
            x = nn.LayerNorm()(x)

        # BUG 4: Graph Pooling / Readout - Summation Issue
        # Summing node features can lead to embeddings that scale with graph size.
        # Mean pooling (jnp.mean) is often more robust if graph sizes vary.
        # graph_embedding = jnp.mean(x, axis=1) # axis=1 is the node dimension
        graph_embedding = jnp.sum(x, axis=1) # Shape: (batch_size, hidden_dim)

        # Prediction head for regression task
        output = nn.Dense(features=1)(graph_embedding) # Predict a single continuous value
        # Shape: (batch_size, 1)
        return output


In [4]:
def calculate_loss(prediction, target):
    # prediction: (1,) - model output for one graph
    # target: (1,) - true value for one graph
    # This expects logits and binary labels (0 or 1).
    # return optax.squared_error(prediction, target).squeeze() # Correct loss for regression
    return optax.sigmoid_binary_cross_entropy(prediction, target.astype(jnp.int32)).squeeze()

In [5]:
def train_step(params, opt_state, batch, key, model_apply_fn):
    node_features, adj, targets = batch

    def loss_fn(params):
        predictions = model_apply_fn({'params': params}, node_features, adj, training=True, rngs={'dropout': key}) # Assuming dropout might be added later
        # Ensure target is float for potential MSE loss later
        loss = jnp.mean(calculate_loss(predictions, targets.astype(jnp.float32)))
        return loss

    loss_val, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

In [6]:
def generate_dummy_graph_batch(key, batch_size, num_nodes, node_feature_dim):
    key, key_nodes, key_adj, key_targets = jax.random.split(key, 4)

    # Random node features (e.g., atom properties)
    node_features = jax.random.normal(key_nodes, (batch_size, num_nodes, node_feature_dim))

    # Random adjacency matrix (simplified: symmetric, no self-loops for this example)
    adj_dense = jax.random.randint(key_adj, (batch_size, num_nodes, num_nodes), 0, 2)
    adj_symm = jnp.maximum(adj_dense, jnp.transpose(adj_dense, (0, 2, 1))) # Make symmetric
    adj = jnp.where(jnp.eye(num_nodes, dtype=bool), 0, adj_symm) # Remove self-loops

    # BUG 3: Target values potentially badly scaled
    # Generating targets in a large range (e.g., 0-1000) when model outputs
    # might initially be small (e.g., near 0) can lead to huge losses/gradients.
    # Targets = jax.random.uniform(key_targets, (batch_size, 1), minval=0.0, maxval=10.0) # A more reasonable scale
    targets = jax.random.uniform(key_targets, (batch_size, 1), minval=0.0, maxval=1000.0)

    return (node_features, adj, targets), key

In [7]:
NODE_FEATURE_DIM = 16
HIDDEN_DIM = 64
NUM_LAYERS = 3
NUM_NODES = 30 # Assume fixed size via padding for simplicity
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
NUM_STEPS = 1000

In [8]:
# Initialization
key = jax.random.PRNGKey(42)
model_key, params_key, data_key, dropout_key, loop_key = jax.random.split(key, 5)

# Instantiate model
model = SimpleGNN(hidden_dim=HIDDEN_DIM, num_layers=NUM_LAYERS)

# Initialize parameters
dummy_nodes = jnp.ones((BATCH_SIZE, NUM_NODES, NODE_FEATURE_DIM))
dummy_adj = jnp.ones((BATCH_SIZE, NUM_NODES, NUM_NODES))
params = model.init(params_key, dummy_nodes, dummy_adj, training=False)['params']

# Initialize optimizer
optimizer = optax.adam(learning_rate=LEARNING_RATE)
opt_state = optimizer.init(params)

TypeError: add got incompatible shapes for broadcasting: (8, 30, 16), (8, 30, 64).

In [None]:
print("Starting GNN training...")
# Training loop
for step in range(NUM_STEPS):
    loop_key, data_key, dropout_key = jax.random.split(loop_key, 3)
    batch, data_key = generate_dummy_graph_batch(data_key, BATCH_SIZE, NUM_NODES, NODE_FEATURE_DIM)

    params, opt_state, loss = train_step(
        params,
        opt_state,
        batch,
        dropout_key, # Pass key even if dropout not used yet
        model.apply
    )

    if step % 100 == 0:
        print(f"Step: {step}, Loss: {loss:.4f}")

print("Training finished.")