# Graph Convolutional Networks (GCNs)

In [60]:
import jax
from clu import metrics
from flax import struct
from flax import nnx
from flax.training import train_state
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp  # Useful for handling sparse adjacency matrices
from typing import Sequence
import optax
from tqdm import tqdm

# Data

In [2]:
def print_file_head(filename, num_lines=5):
    """Print a banner followed by the first ``num_lines`` lines of a text file.

    Each line is printed with surrounding whitespace stripped. A missing file
    is reported with a printed message instead of raising.

    Args:
        filename: Path of the file to preview.
        num_lines: Maximum number of lines to print.
    """
    print(f"\n--- Head of {filename} {num_lines} lines ---")
    try:
        with open(filename, 'r') as handle:
            printed = 0
            for raw_line in handle:
                if printed >= num_lines:
                    break
                print(raw_line.strip())
                printed += 1
    except FileNotFoundError:
        print(f"Error: {filename} not found")


# Preview the two raw Cora files: five lines of the citation list and a
# single line of the content table.
print_file_head("./dataset/cora/cora.cites", 5)
print_file_head("./dataset/cora/cora.content", 1)


--- Head of ./dataset/cora/cora.cites 5 lines ---
35	1033
35	103482
35	103515
35	1050679
35	1103960

--- Head of ./dataset/cora/cora.content 1 lines ---
31336	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	1	0	0	0	0	0	0	1	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	1	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	1	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	1	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	0	

## Preprocessing

In [3]:
def _load_content_file(filepath: str):
    """Loads features and labels from the .content file."""
    print(f"Loading content file: {filepath}")
    idx_features_labels = np.genfromtxt(
        filepath,
        dtype=np.dtype(str)
    )
    features = sp.csr_matrix(
        idx_features_labels[:, 1:-1],
        dtype=np.float32
    )
    labels = idx_features_labels[:, -1]
    paper_ids = np.array(
        idx_features_labels[:, 0],
        dtype=np.int32
    )
    return features, labels, paper_ids

In [4]:
# Read the raw content table once for interactive inspection in the next cells.
file_path = "./dataset/cora/cora.content"
features, labels, paper_ids = _load_content_file(file_path)

Loading content file: ./dataset/cora/cora.content


In [5]:
# Show the feature matrix shape, then the shape and contents of its first row.
print(
    "--- Features Matrix ---\n",
    features.shape,
    features[0].shape,
    features[0],
    sep="\n",
)

--- Features Matrix ---

(2708, 1433)
(1, 1433)
  (0, 0)	0.0
  (0, 1)	0.0
  (0, 2)	0.0
  (0, 3)	0.0
  (0, 4)	0.0
  (0, 5)	0.0
  (0, 6)	0.0
  (0, 7)	0.0
  (0, 8)	0.0
  (0, 9)	0.0
  (0, 10)	0.0
  (0, 11)	0.0
  (0, 12)	0.0
  (0, 13)	0.0
  (0, 14)	0.0
  (0, 15)	0.0
  (0, 16)	0.0
  (0, 17)	0.0
  (0, 18)	0.0
  (0, 19)	0.0
  (0, 20)	0.0
  (0, 21)	0.0
  (0, 22)	0.0
  (0, 23)	0.0
  (0, 24)	0.0
  :	:
  (0, 1408)	0.0
  (0, 1409)	0.0
  (0, 1410)	0.0
  (0, 1411)	0.0
  (0, 1412)	0.0
  (0, 1413)	0.0
  (0, 1414)	0.0
  (0, 1415)	0.0
  (0, 1416)	0.0
  (0, 1417)	0.0
  (0, 1418)	0.0
  (0, 1419)	0.0
  (0, 1420)	0.0
  (0, 1421)	0.0
  (0, 1422)	0.0
  (0, 1423)	0.0
  (0, 1424)	0.0
  (0, 1425)	0.0
  (0, 1426)	1.0
  (0, 1427)	0.0
  (0, 1428)	0.0
  (0, 1429)	0.0
  (0, 1430)	0.0
  (0, 1431)	0.0
  (0, 1432)	0.0


In [6]:
# Show how many raw string labels were read and a preview of them.
print("--- Labels Matrix ---\n", labels.shape, labels, sep="\n")

--- Labels Matrix ---

(2708,)
['Neural_Networks' 'Rule_Learning' 'Reinforcement_Learning' ...
 'Genetic_Algorithms' 'Case_Based' 'Neural_Networks']


In [7]:
# Show how many paper ids were read and a preview of them.
print("--- Paper IDs ---\n", paper_ids.shape, paper_ids, sep="\n")

--- Paper IDs ---

(2708,)
[  31336 1061127 1106406 ... 1128978  117328   24043]


In [8]:
def _load_cites_file(filepath: str, idx_map: dict):
    """Loads citation links and builds the adjacency matrix."""
    print(f"Loading cites file: {filepath}")
    edges_unordered = np.genfromtxt(
        filepath,
        dtype=np.int32
    )

    # Convert paper IDs in edges to our new integer indices
    edges = np.array(
        list(map(
            idx_map.get,
            edges_unordered.flatten())
        ), dtype=np.int32,
    ).reshape(edges_unordered.shape)

    # Create COO (Coordinate) format sparse adjacency matrix
    adj = sp.coo_matrix(
        (
            np.ones(edges.shape[0]),
            (edges[:, 0], edges[:, 1])
        ),
        # Use len(idx_map) for the number of nodes
        shape=(len(idx_map), len(idx_map)),
        dtype=np.float32
    )

    # Build symmetric adjacency matrix
    adj = (
            adj +
            adj.T.multiply(adj.T > adj) -
            adj.multiply(adj.T > adj)
    )
    return adj

In [9]:
# Map each paper id to its row position in the content table, then build and
# preview the adjacency matrix.
file_path = "./dataset/cora/cora.cites"
idx_map = dict(zip(paper_ids, range(len(paper_ids))))
adj = _load_cites_file(file_path, idx_map)
print(
    "-- Adjacency Matrix ---",
    adj.shape,
    adj[:2].shape,
    adj[:2],
    sep="\n",
)

Loading cites file: ./dataset/cora/cora.cites
-- Adjacency Matrix ---
(2708, 2708)
(2, 2708)
  (0, 8)	1.0
  (0, 14)	1.0
  (0, 258)	1.0
  (0, 435)	1.0
  (0, 544)	1.0
  (1, 344)	1.0


In [10]:
def _normalize_features(features: np.ndarray):
    """Normalizes node features."""
    features = features.todense()
    features /= features.sum(1).reshape(-1, 1)
    return features

In [11]:
def _one_hot_encode_labels(labels: np.ndarray):
    """Converts categorical labels to one-hot encoded format."""
    classes = sorted(list(set(labels)))
    class_to_idx = {
        c: i for i, c in enumerate(classes)
    }

    # Corrected labels_one_hot initialization shape
    labels_one_hot = np.zeros(
        (len(labels), len(classes)),
        dtype=np.float32
    )
    for i, label in enumerate(labels):
        labels_one_hot[i, class_to_idx[label]] = 1
    return labels_one_hot

In [12]:
def _create_masks(num_nodes: int):
    """Defines standard train, validation, and test masks."""
    idx_train = jnp.arange(140)
    idx_val = jnp.arange(140, 140 + 500)
    idx_test = jnp.arange(140 + 500, 140 + 500 + 1000)

    train_mask = jnp.zeros(
        num_nodes, dtype=bool
    ).at[idx_train].set(True)
    val_mask = jnp.zeros(
        num_nodes, dtype=bool
    ).at[idx_val].set(True)
    test_mask = jnp.zeros(
        num_nodes, dtype=bool
    ).at[idx_test].set(True)

    return train_mask, val_mask, test_mask

In [13]:
def load_cora(path: str = "./dataset/cora/"):
    """
    Loads and preprocesses the Cora dataset
    from raw .content and .cites files.

    Args:
        path (str): Directory where cora.content and cora.cites are located,
            with or without a trailing separator.
            Defaults to "./dataset/cora/".

    Returns:
        tuple:
        (adj, features, labels, train_mask, val_mask, test_mask)
        where adj (dense, symmetric, not yet normalized), features
        (row-normalized) and labels (one-hot) are float32 JAX arrays and
        the masks are boolean JAX arrays.
    """
    # Local import: os is not part of the notebook's import cell.
    import os

    # Joining instead of concatenating keeps paths valid when ``path`` has
    # no trailing separator.
    content_path = os.path.join(path, "cora.content")
    cites_path = os.path.join(path, "cora.cites")

    # 1. Load features and labels, and create paper ID to index mapping
    features_sparse, labels_raw, paper_ids = _load_content_file(content_path)
    idx_map = {paper_id: idx for idx, paper_id in enumerate(paper_ids)}

    # 2. Build adjacency matrix
    adj_sparse = _load_cites_file(cites_path, idx_map)

    # 3. Preprocess features
    features_normalized = _normalize_features(features_sparse)

    # 4. One-hot encode labels
    labels_one_hot = _one_hot_encode_labels(labels_raw)

    # 5. Create masks for data splits (one mask entry per labelled row)
    num_nodes = labels_raw.shape[0]
    train_mask, val_mask, test_mask = _create_masks(num_nodes)

    # 6. Convert to JAX numpy arrays
    adj_jax = jnp.asarray(adj_sparse.todense(), dtype=jnp.float32)
    features_jax = jnp.asarray(features_normalized, dtype=jnp.float32)
    labels_jax = jnp.asarray(labels_one_hot, dtype=jnp.float32)

    return adj_jax, features_jax, labels_jax, train_mask, val_mask, test_mask

In [15]:
# NOTE: this rebinds `features` and `labels` from the raw inspection cells
# above to the preprocessed JAX arrays returned by load_cora().
adj_raw, features, labels, train_mask, val_mask, test_mask = load_cora()

print(
    f"\nCora dataset loaded:",
    f"Number of nodes: {adj_raw.shape[0]}",
    f"Number of features per node: {features.shape[1]}",
    f"Number of classes: {labels.shape[1]}",
    f"Features shape: {features.shape}",
    f"Adjacency matrix shape: {adj_raw.shape}",
    f"Labels shape (one-hot encoded): {labels.shape}",
    f"Train nodes: {jnp.sum(train_mask)}",
    f"Validation nodes: {jnp.sum(val_mask)}",
    f"Test nodes: {jnp.sum(test_mask)}",
    sep="\n",
)


Loading content file: ./dataset/cora/cora.content
Loading cites file: ./dataset/cora/cora.cites

Cora dataset loaded:
Number of nodes: 2708
Number of features per node: 1433
Number of classes: 7
Features shape: (2708, 1433)
Adjacency matrix shape: (2708, 2708)
Labels shape (one-hot encoded): (2708, 7)
Train nodes: 140
Validation nodes: 500
Test nodes: 1000


In [89]:
def preprocess_adjacency_matrix(adj):
    """
    Adds self-loops and symmetrically normalizes the adjacency matrix.

    Computes ``D'^-0.5 (A + I) D'^-0.5`` where ``D'`` is the diagonal degree
    matrix of ``A + I``.

    Args:
        adj (jnp.ndarray): Square adjacency matrix of shape (N, N).

    Returns:
        jnp.ndarray: The normalized adjacency matrix, shape (N, N).
    """
    # Step 1: Add self-loops (A' = A + I) so each node keeps its own features
    # during aggregation.
    adj_self_loops = adj + jnp.eye(adj.shape[0])

    # Step 2: Per-node normalization factor d_i^-0.5 from the row sums of A'.
    deg_inv_sqrt = jnp.sum(adj_self_loops, axis=1) ** -0.5

    # Step 3: Scale row i by d_i^-0.5 and column j by d_j^-0.5. Broadcasting
    # gives the same entries as multiplying by diagonal matrices on both
    # sides, without building them or running two dense (N, N) matmuls.
    A_hat = deg_inv_sqrt[:, None] * adj_self_loops * deg_inv_sqrt[None, :]

    return A_hat

In [91]:
# Add self-loops to the raw adjacency from load_cora() and symmetrically
# normalize it; this is the matrix later packed into the training batches.
adj_normalized = preprocess_adjacency_matrix(adj_raw)

## Comparison with the actual preprocessing

In [16]:
from torch_geometric.datasets import Planetoid

# Reference copy of Cora from PyTorch Geometric, used below only to compare
# split sizes. NOTE(review): presumably stored under /tmp/Cora - confirm the
# path is acceptable on the target machine.
cora_dataset = Planetoid(root='/tmp/Cora', name='Cora')

In [17]:
# Display the dataset's first entry (index 0) as the cell output.
cora_dataset[0]

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [18]:
# Count the nodes selected by each reference mask.
print(
    cora_dataset[0]['train_mask'].count_nonzero(),
    cora_dataset[0]['val_mask'].count_nonzero(),
    cora_dataset[0]['test_mask'].count_nonzero(),
    sep="\n",
)

tensor(140)
tensor(500)
tensor(1000)


# Building the GCN Layer

In [77]:
class GCNLayer(nnx.Module):
    """One graph-convolution layer: dropout -> linear -> ``A_hat`` aggregation -> ReLU."""

    def __init__(
            self,
            input_features: int,
            output_features: int,
            dropout_rate: float,
            *, rngs: nnx.Rngs
    ):
        """
        Args:
            input_features (int): Size of each incoming node embedding.
            output_features (int): Size of each outgoing node embedding.
            dropout_rate (float): Dropout rate applied to the layer input.
            rngs (nnx.Rngs): Random streams used to initialize the linear
                weights and handed to the dropout module.
        """
        self.linear = nnx.Linear(
            input_features,
            output_features,
            kernel_init=nnx.initializers.glorot_uniform(),
            rngs=rngs
        )
        self.dropout = nnx.Dropout(
            rate=dropout_rate,
            rngs=rngs
        )

    def __call__(
            self,
            A_hat: jnp.ndarray,
            H: jnp.ndarray,
            *, rngs: "nnx.Rngs | None" = None,
    ):
        """
        Applies a single Graph Convolutional
        Network layer using NNX.

        Args:
            A_hat (jnp.ndarray): The symmetrically
                normalized adjacency matrix.
            H (jnp.ndarray): The input node embeddings/
                features from the previous layer.
            rngs (nnx.Rngs | None): Optional random streams
                forwarded to dropout. When omitted (None), the
                dropout module falls back to the streams it was
                constructed with, so callers do not have to
                build a fresh ``nnx.Rngs`` on every call.

        Returns:
            jnp.ndarray: The output node embeddings/
                features for the current layer.
        """
        # 1. Dropout for regularization. Whether it is active is decided by
        #    the module's train/eval mode, not by this method.
        H_dropped = self.dropout(H, rngs=rngs)

        # 2. Linear Transformation (H^(l) * W^(l))
        H_linear = self.linear(H_dropped)

        # 3. Graph Convolution (A_hat * H_linear)
        H_aggregated = A_hat @ H_linear

        # 4. Activation Function
        H_activated = nnx.relu(H_aggregated)

        return H_activated

In [79]:
class GCN(nnx.Module):
    """A Graph Convolutional Network: a stack of GCN layers and a linear head."""

    def __init__(
            self,
            input_features: int,
            output_classes: int,
            hidden_features: Sequence[int],
            dropout_rate: float,
            *, rngs: nnx.Rngs
    ):
        """
        Args:
            input_features (int): Number of features per input node.
            output_classes (int): Number of logits produced per node.
            hidden_features (Sequence[int]): Output width of each hidden
                GCN layer, in order. May be empty, in which case the
                model is just the linear head.
            dropout_rate (float): Dropout rate used by every GCN layer.
            rngs (nnx.Rngs): Random streams used to build all sub-modules.
        """
        self.gcn_layers = []
        current_features = input_features

        # Create the hidden GCN layers
        for hidden_dim in hidden_features:
            self.gcn_layers.append(
                GCNLayer(
                    input_features=current_features,
                    output_features=hidden_dim,
                    dropout_rate=dropout_rate,
                    rngs=rngs
                )
            )
            # The input of the next layer is the
            #   output of this one
            current_features = hidden_dim

        # The final layer maps node embeddings to
        #   output classes (plain linear map, no aggregation)
        self.output_layer = nnx.Linear(
            in_features=current_features,
            out_features=output_classes,
            rngs=rngs
        )

    def __call__(
            self,
            A_hat: jnp.ndarray,
            H: jnp.ndarray,
            *, rngs: "nnx.Rngs | None" = None,
    ):
        """Performs the forward pass of the GCN.

        Args:
            A_hat (jnp.ndarray): Normalized adjacency matrix.
            H (jnp.ndarray): Input node features.
            rngs (nnx.Rngs | None): Optional random streams forwarded to
                every layer. Leave as None to let each layer's dropout use
                the streams it was constructed with.

        Returns:
            jnp.ndarray: Per-node class logits.
        """
        # Pass data through all the GCN layers
        for layer in self.gcn_layers:
            H = layer(A_hat, H, rngs=rngs)

        # Apply the final linear layer to
        #   get logits
        logits = self.output_layer(H)
        return logits

# Training the GCN Model

In [55]:
# --- Model & Hyperparameters ---
# Step size handed to the Adam optimizer.
LEARNING_RATE = 0.01
# Number of repeats of the training batch requested from the dataset iterator.
EPOCHS = 200
# Output width of each hidden GCN layer, in order.
HIDDEN_FEATURES = [64, 32]  # Two hidden layers
# Dropout rate used inside every GCN layer.
DROPOUT_RATE = 0.5

In [83]:
# --- Initialize Model ---
# Input width and class count are read from the loaded data; the optimizer
# is created in a separate cell further down.
model = GCN(
    input_features=features.shape[1],
    output_classes=labels.shape[1],
    hidden_features=HIDDEN_FEATURES,
    dropout_rate=DROPOUT_RATE,
    rngs=nnx.Rngs(0),  # fixed seed 0 for initialization
)

In [86]:
# Split the model into its graph definition and its state, then display
# both for inspection. Neither value is used by the training cells below.
graphdef, state = nnx.split(model)
nnx.display(graphdef, state)

# Metric

In [94]:
# --- Loss & Accuracy Functions ---
def categorical_loss(logits, labels, mask):
    """Calculates masked cross-entropy loss.

    Args:
        logits: Per-node class scores, shape (num_nodes, num_classes).
        labels: One-hot labels with the same shape as ``logits``.
        mask: Boolean array of shape (num_nodes,) selecting the nodes that
            contribute to the loss.

    Returns:
        Scalar mean cross-entropy over the selected nodes (0 when the mask
        selects no node).
    """
    # Convert one-hot labels to integer labels for this loss function
    int_labels = jnp.argmax(labels, axis=1)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, int_labels)

    # Boolean indexing (loss[mask]) has a data-dependent shape and fails when
    # this function is traced under jit, so the masked mean is written as a
    # fixed-shape select / sum / count.
    selected = jnp.asarray(mask, dtype=bool)
    masked_total = jnp.sum(jnp.where(selected, loss, 0.0))
    num_selected = jnp.maximum(jnp.sum(selected), 1)
    return masked_total / num_selected

# Metric accumulators updated from the train/eval steps. Every entry must be
# a metric object exposing update/compute/reset; a bare loss function does
# not, so the loss is tracked as a running average of the `loss=` values
# passed to `metrics.update(...)`.
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

In [87]:
# Adam optimizer bound to the model's parameters.
# NOTE(review): newer Flax NNX releases appear to expect a `wrt=` argument
# here and `optimizer.update(model, grads)` in the train step - confirm
# against the installed Flax version.
optimizer = nnx.Optimizer(model, optax.adam(LEARNING_RATE))

In [104]:
def loss_fn(model: GCN, batch):
    """Runs the forward pass and returns ``(loss, logits)``.

    Args:
        model (GCN): The model to evaluate.
        batch: Mapping with the keys "adj", "features", "label" and "mask".

    Returns:
        tuple: The masked cross-entropy loss and the per-node logits
        (logits are returned as the auxiliary output for metric updates).
    """
    # rngs=None lets each dropout layer draw from the random streams it was
    # constructed with. Building nnx.Rngs(42) here on every call would
    # restart the stream each time and reuse the same dropout mask at every
    # training step.
    logits = model(
        batch["adj"],
        batch["features"],
        rngs=None
    )
    loss = categorical_loss(
        logits,
        batch["label"],
        batch["mask"]
    )
    return loss, logits

In [105]:
@nnx.jit
def train_step(
        model: GCN,
        optimizer,
        metrics: nnx.MultiMetric,
        batch
):
    """Performs one training step: computes grads, updates metrics and model.

    Args:
        model (GCN): Model being trained (updated in place).
        optimizer: Optimizer wrapping the model (updated in place).
        metrics (nnx.MultiMetric): Accumulators updated with this step's
            loss, logits and labels.
        batch: Mapping with the keys "adj", "features", "label" and "mask".
    """
    grad_fn = nnx.value_and_grad(
        loss_fn,
        has_aux=True,
    )
    (loss, logits), grads = grad_fn(model, batch)

    # The accuracy metric compares argmax(logits) with integer class ids, so
    # the one-hot labels are converted first (passing them as-is is rejected
    # because their rank equals that of the logits).
    # NOTE(review): accuracy is therefore measured over every node in the
    # batch, not only the nodes selected by batch["mask"].
    int_labels = jnp.argmax(batch["label"], axis=1).astype(jnp.int32)
    metrics.update(
        loss=loss,
        logits=logits,
        labels=int_labels,
    )
    optimizer.update(grads)  # In-place updates

In [99]:
@nnx.jit
def eval_step(
        model: GCN,
        metrics: nnx.MultiMetric,
        batch
):
    """Evaluates the model on ``batch`` and accumulates the metrics.

    Args:
        model (GCN): Model to evaluate (parameters are not updated).
        metrics (nnx.MultiMetric): Accumulators updated with the loss,
            logits and labels of this batch.
        batch: Mapping with the keys "adj", "features", "label" and "mask".
    """
    loss, logits = loss_fn(model, batch)
    # The accuracy metric expects integer class ids, not one-hot rows.
    # NOTE(review): accuracy is measured over every node in the batch, not
    # only the nodes selected by batch["mask"].
    int_labels = jnp.argmax(batch["label"], axis=1).astype(jnp.int32)
    metrics.update(
        loss=loss,
        logits=logits,
        labels=int_labels,
    )


In [103]:
# Create batch format for GCN training
def create_batch(adj, features, labels, mask):
    """Creates a batch in the format expected by the training loop.

    Args:
        adj: Adjacency matrix passed to the model.
        features: Node feature matrix passed to the model.
        labels: Node labels used by the loss and the metrics.
        mask: Boolean node mask selecting the nodes that enter the loss.

    Returns:
        dict: ``{"adj", "label", "mask", "features"}`` holding the arguments
        unchanged.
    """
    # Each key appears exactly once ("adj" was previously listed twice).
    return {
        "adj": adj,
        "label": labels,
        "mask": mask,
        "features": features
    }

# Create train and test batches. Both hold the same full graph (normalized
# adjacency, features, labels); only the node mask differs.
train_batch = create_batch(adj_normalized, features, labels, train_mask)
test_batch = create_batch(adj_normalized, features, labels, test_mask)

# Full-graph training has a single batch, so an "epoch" is one repeat of it.
def create_dataset_iterator(batch, num_repeats=1):
    """Generator producing ``batch`` itself ``num_repeats`` times in a row.

    The same object is yielded each time (no copy). The generator is
    single-use: once exhausted it yields nothing more.
    """
    yield from (batch for _ in range(num_repeats))

# Build the (single-use) generators: EPOCHS repeats of the training batch and
# one pass over the test batch, then print a short summary of the batch.
train_ds, test_ds = (
    create_dataset_iterator(train_batch, EPOCHS),
    create_dataset_iterator(test_batch, 1),
)

print(
    "Batch format created successfully!",
    f"Train batch keys: {list(train_batch.keys())}",
    f"Features shape: {train_batch['features'].shape}",
    f"Labels shape: {train_batch['label'].shape}",
    f"Adjacency shape: {train_batch['adj'].shape}",
    f"Train mask sum: {jnp.sum(train_batch['mask'])}",
    sep="\n",
)

Batch format created successfully!
Train batch keys: ['adj', 'label', 'mask', 'features']
Features shape: (2708, 1433)
Labels shape: (2708, 7)
Adjacency shape: (2708, 2708)
Train mask sum: 140


In [None]:
from IPython.display import clear_output
import matplotlib.pyplot as plt

# Evaluate and redraw the curves every EVAL_EVERY epochs (and after the last one).
EVAL_EVERY = 10

metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}

# Start from empty accumulators so re-running the cell does not mix in
# values left over from a previous run.
metrics.reset()

# The dataset helper returns a plain Python generator (there is no
# `.as_numpy_iterator()`), and it is single-use, so a fresh one is created
# here to keep this cell re-runnable.
for step, batch in enumerate(create_dataset_iterator(train_batch, EPOCHS)):
    # Run one optimization step; this statefully updates the model
    # parameters, the optimizer state and the training metrics.
    model.train()  # enable dropout
    train_step(model, optimizer, metrics, batch)

    if step > 0 and (step % EVAL_EVERY == 0 or step == EPOCHS - 1):
        # Log the training metrics accumulated since the last evaluation.
        for metric, value in metrics.compute().items():
            metrics_history[f'train_{metric}'].append(float(value))
        metrics.reset()  # Reset the metrics for the test set.

        # Compute the metrics on the test nodes with dropout disabled.
        model.eval()
        eval_step(model, metrics, test_batch)

        # Log the test metrics.
        for metric, value in metrics.compute().items():
            metrics_history[f'test_{metric}'].append(float(value))
        metrics.reset()  # Reset the metrics for the next training period.

        clear_output(wait=True)
        # Plot loss and accuracy in subplots
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        ax1.set_title('Loss')
        ax2.set_title('Accuracy')
        for dataset in ('train', 'test'):
            ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
            ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
        for ax in (ax1, ax2):
            ax.set_xlabel('evaluation #')
            ax.legend()
        plt.show()