In [None]:
!wget https://raw.githubusercontent.com/dgront/chem-ml/refs/heads/main/COURSE/Lab5-GNN_xlogp_prediction/Lab5a_GNN_SMILES_processing.ipynb
%run Lab5a_GNN_SMILES_processing.ipynb

In [3]:
import tensorflow as tf

def edge_enhanced_gnn_layer(node_features, edge_index, edge_features,
                            node_mlp_weights, message_mlp_weights):
    """
    Single message-passing step whose messages depend on both the sending
    node and the edge they travel along.

    For every edge a message is computed from [sender features, edge features],
    the messages are summed per receiving node, and the sums are passed through
    the node-update network. Nodes without incoming edges aggregate to zeros.

    Args:
        node_features: Tensor of shape (N, node_feat_dim); N = number of nodes
        edge_index: Tensor of shape (2, M); row 0 holds the sender index and
            row 1 the receiver index of each of the M edges
        edge_features: Tensor of shape (M, edge_feat_dim)
        node_mlp_weights: callable (e.g. a Dense layer) applied to the
            aggregated messages
        message_mlp_weights: callable (e.g. a Dense layer) applied to the
            concatenated sender/edge features

    Returns:
        Updated node features: Tensor of shape (N, output_dim), where
        output_dim is whatever node_mlp_weights produces.
    """
    senders, receivers = edge_index[0], edge_index[1]  # each (M,)

    # Build one input row per edge: sender features followed by edge features.
    per_edge_input = tf.concat(
        [tf.gather(node_features, senders), edge_features],
        axis=-1,
    )  # (M, node_feat_dim + edge_feat_dim)

    # Turn every edge input into a message vector.
    per_edge_message = message_mlp_weights(per_edge_input)  # (M, hidden_dim)

    # Start from an all-zero (N, hidden_dim) buffer and add each message to
    # the row of its receiving node.
    num_nodes = tf.shape(node_features)[0]
    message_dim = per_edge_message.shape[-1]
    accumulator = tf.zeros((num_nodes, message_dim), dtype=per_edge_message.dtype)
    scatter_rows = tf.expand_dims(receivers, axis=1)  # (M, 1) row indices in [0, N)
    summed_messages = tf.tensor_scatter_nd_add(
        accumulator, scatter_rows, per_edge_message
    )  # (N, hidden_dim)

    # Node update network on the aggregated messages.
    return node_mlp_weights(summed_messages)


In [4]:
# Shapes of the first training graph, one per line:
# node features, edge index, edge features.
print(V_train[0].shape, E_i_train[0].shape, E_f_train[0].shape, sep="\n")

(4, 41)
(2, 3)
(3, 5)


In [5]:
# Dimensions
node_dim = V_train[0].shape[1]        # The number of features for an atom
edge_dim = E_f_train[0].shape[1]      # The number of features for a bond
gnn_input_dim = node_dim + edge_dim   # Width of one message input: [atom features, bond features]
gnn_hidden_dim = 64                   # Width of messages and of updated node embeddings
hidden_layer_dim = 32                 # Width of the dense layer applied to the pooled graph embedding

# Create MLPs
# No `input_shape=` here: the layers are called directly (not placed in a
# Sequential model), so they infer their input width on the first call.
# Passing `input_shape` to a layer only triggers a Keras UserWarning.
message_mlp = tf.keras.layers.Dense(gnn_hidden_dim, activation='relu')   # (M, gnn_input_dim) -> (M, gnn_hidden_dim)
node_mlp = tf.keras.layers.Dense(gnn_hidden_dim, activation='relu')      # (N, gnn_hidden_dim) -> (N, gnn_hidden_dim)
hidden_layer = tf.keras.layers.Dense(hidden_layer_dim, activation='relu')
output_layer = tf.keras.layers.Dense(1, activation=None)  # Predicting a real number

# Define the full forward function
def forward_pass(node_features, edge_index, edge_features,
                 message_mlp, node_mlp, hidden_layer, output_layer):
    """
    Predict one value for a single graph: GNN layer -> mean pooling over
    nodes -> hidden dense layer -> linear output layer.

    Args:
        node_features, edge_index, edge_features: the graph, forwarded
            unchanged to edge_enhanced_gnn_layer
        message_mlp: message network of the GNN layer
        node_mlp: node-update network of the GNN layer
        hidden_layer: dense layer applied to the pooled graph embedding
        output_layer: final layer producing the prediction

    Returns:
        Tensor of shape (1,) holding the prediction. Note this is a
        one-element vector, not a rank-0 scalar: only the leading axis of the
        (1, 1) output is removed.
    """
    # 1. Message passing: one embedding per node.
    node_embeddings = edge_enhanced_gnn_layer(
        node_features, edge_index, edge_features, node_mlp, message_mlp
    )

    # 2. Average over the node axis, keeping it as a size-1 batch axis.
    pooled = tf.reduce_mean(node_embeddings, axis=0, keepdims=True)  # (1, hidden_dim)

    # 3.-4. Dense readout followed by the output layer.
    readout = output_layer(hidden_layer(pooled))  # (1, 1)

    return tf.squeeze(readout, axis=0)  # (1,)

# One layer call as a test: run the GNN layer alone on the first training graph
updated_nodes = edge_enhanced_gnn_layer(V_train[0], E_i_train[0], E_f_train[0], node_mlp, message_mlp)

print(updated_nodes.shape)  # (N, 64): one 64-dim embedding per atom; (4, 64) for the 4-atom V_train[0]

# Full forward pass on the same graph; prints a tensor of shape (1,)
val = forward_pass(V_train[0], E_i_train[0], E_f_train[0], message_mlp, node_mlp, hidden_layer, output_layer)
print(val)


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


(4, 64)
tf.Tensor([-0.03424795], shape=(1,), dtype=float32)


In [None]:
num_graphs = len(V_train)
epochs = 50
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()

# Per-epoch metric history.
# NOTE: the "*_accuracies" lists hold the mean absolute error (lower is
# better), not a classification accuracy; the names are kept because the
# plotting cell reads them.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(epochs):
    total_loss = 0.0
    total_accuracy = 0.0
    graphs_trained = 0  # graphs that actually contributed to this epoch's sums

    # --- Training loop: one optimizer step per molecular graph (batch size 1) ---
    for node_features, edge_index, edge_features, target, name, formula in zip(
            V_train, E_i_train, E_f_train, Y_train, names_train, formulas_train):
        with tf.GradientTape() as tape:
            try:
                prediction = forward_pass(node_features, edge_index, edge_features,
                                          message_mlp, node_mlp, hidden_layer, output_layer)
            except Exception as exc:
                # Report the offending molecule and move on to the next one.
                # (A bare `except` + `break` would silently drop the rest of
                # the epoch and also swallow KeyboardInterrupt.)
                print(name, formula, node_features.shape, repr(exc))
                continue

            loss = loss_fn(target, prediction)

        # Update model weights
        trainable_vars = (message_mlp.trainable_variables +
                          node_mlp.trainable_variables +
                          hidden_layer.trainable_variables +
                          output_layer.trainable_variables)
        gradients = tape.gradient(loss, trainable_vars)
        optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Accumulate plain Python/NumPy numbers (not tensors), matching the
        # validation loop below.
        total_loss += loss.numpy()
        total_accuracy += tf.reduce_mean(tf.abs(target - prediction)).numpy()
        graphs_trained += 1

    # Average over the graphs that were really processed; NaN if there were none.
    avg_train_loss = total_loss / graphs_trained if graphs_trained else float("nan")
    avg_train_acc = total_accuracy / graphs_trained if graphs_trained else float("nan")
    train_losses.append(avg_train_loss)
    train_accuracies.append(avg_train_acc)

    # --- Validation loop (no gradient tape, no weight updates) ---
    val_total_loss = 0.0
    val_total_accuracy = 0.0
    val_num_graphs = len(V_val)

    for node_features, edge_index, edge_features, target in zip(V_val, E_i_val, E_f_val, Y_val):
        prediction = forward_pass(node_features, edge_index, edge_features,
                                  message_mlp, node_mlp, hidden_layer, output_layer)
        val_loss = loss_fn(target, prediction)
        val_acc = tf.reduce_mean(tf.abs(target - prediction))
        val_total_loss += val_loss.numpy()
        val_total_accuracy += val_acc.numpy()

    # Guard against an empty validation set instead of dividing by zero.
    avg_val_loss = val_total_loss / val_num_graphs if val_num_graphs else float("nan")
    avg_val_acc = val_total_accuracy / val_num_graphs if val_num_graphs else float("nan")
    val_losses.append(avg_val_loss)
    val_accuracies.append(avg_val_acc)

    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, Val Acc: {avg_val_acc:.4f}")


In [None]:
import matplotlib.pyplot as plt

# x-axis values (1-based epoch numbers). Stored under its own name so the
# `epochs` hyperparameter set in the training cell is not overwritten.
epoch_axis = range(1, len(train_losses) + 1)

fig, (ax_loss, ax_mae) = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: mean squared error (the training loss)
ax_loss.plot(epoch_axis, train_losses, label='Training MSE')
ax_loss.plot(epoch_axis, val_losses, label='Validation MSE')
ax_loss.set_title('Loss (MSE)')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

# Right panel: mean absolute error. The lists are named "*_accuracies" but
# hold MAE values, so the axis is labelled accordingly.
ax_mae.plot(epoch_axis, train_accuracies, label='Training MAE')
ax_mae.plot(epoch_axis, val_accuracies, label='Validation MAE')
ax_mae.set_title('Mean absolute error')
ax_mae.set_xlabel('Epoch')
ax_mae.set_ylabel('MAE')
ax_mae.legend()

fig.tight_layout()
plt.show()
