In [1]:
import os, urllib.request, zipfile, json
import numpy as np
import scipy.sparse as sp
from sklearn.metrics import f1_score
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dropout, Dense, BatchNormalization, LayerNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

In [2]:
DATA_URL = "https://snap.stanford.edu/graphsage/ppi.zip"
DATA_DIR = "ppi"

In [None]:
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)
    zip_path, _ = urllib.request.urlretrieve(DATA_URL)
    with zipfile.ZipFile(zip_path, 'r') as z:
        z.extractall(DATA_DIR)

In [4]:
X = np.load(os.path.join(DATA_DIR, "ppi/ppi-feats.npy"))
with open(os.path.join(DATA_DIR,"ppi/ppi-id_map.json")) as f: id_map = json.load(f)
with open(os.path.join(DATA_DIR,"ppi/ppi-class_map.json")) as f: class_map = json.load(f)
with open(os.path.join(DATA_DIR,"ppi/ppi-G.json")) as f: G_json = json.load(f)

In [5]:
n_classes = len(next(iter(class_map.values())))

In [None]:
Y = np.zeros((X.shape[0], n_classes), dtype=np.int32)
train_mask = np.zeros(X.shape[0], dtype=bool)
val_mask = np.zeros(X.shape[0], dtype=bool)
test_mask = np.zeros(X.shape[0], dtype=bool)

for node in G_json['nodes']:
    node_id = str(node['id'])
    idx = id_map[node_id]
    Y[idx] = class_map[node_id]

    if node.get('test', False):
        test_mask[idx] = True
    elif node.get('val', False):
        val_mask[idx] = True
    else:
        train_mask[idx] = True

N, F = X.shape

In [None]:
X = (X - X[train_mask].mean(axis=0)) / (X[train_mask].std(axis=0) + 1e-6)
X = np.nan_to_num(X)

raw_links = G_json["links"]
clean_links = [(l['source'], l['target']) for l in raw_links if l['source'] != l['target']]
unique_links = set(clean_links)

# Symmetrize
rows = [u for u,v in unique_links] + [v for u,v in unique_links]
cols = [v for u,v in unique_links] + [u for u,v in unique_links]
data = np.ones(len(rows), dtype=np.float32)

A = sp.coo_matrix((data, (rows, cols)), shape=(N, N))
A.setdiag(1.0)  # Add self-loops
A = A.tocsr()

In [8]:
deg = np.array(A.sum(1)).flatten()
d_inv_sqrt = np.power(deg, -0.5)
d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
D_inv_sqrt = sp.diags(d_inv_sqrt)
A_norm = D_inv_sqrt @ A @ D_inv_sqrt  # A_hat = D^(-1/2) A D^(-1/2)

A_norm = A_norm.tocoo()
A_sp = tf.SparseTensor(
    indices=np.vstack((A_norm.row, A_norm.col)).T,
    values=A_norm.data.astype(np.float32),
    dense_shape=A_norm.shape
)
A_sp = tf.sparse.reorder(A_sp)

In [None]:
class GraphConvLayer(tf.keras.layers.Layer):
    def __init__(self, out_dim, activation=None, use_bias=True,
                 dropout_rate=0.0, norm_type='layer', **kwargs):
        super().__init__(**kwargs)
        self.out_dim = out_dim
        self.activation = tf.keras.activations.get(activation)
        self.use_bias = use_bias
        self.dropout_rate = dropout_rate
        self.norm_type = norm_type

    def build(self, input_shape):
        F_in = input_shape[0][-1]

        self.W = self.add_weight(
            name="W",
            shape=(F_in, self.out_dim),
            initializer="glorot_uniform",
            trainable=True
        )

        if self.use_bias:
            self.b = self.add_weight(
                name="bias",
                shape=(self.out_dim,),
                initializer="zeros",
                trainable=True
            )

        if F_in != self.out_dim:
            self.W_skip = self.add_weight(
                name="W_skip",
                shape=(F_in, self.out_dim),
                initializer="glorot_uniform",
                trainable=True
            )
        else:
            self.W_skip = None

        if self.norm_type == 'batch':
            self.norm = BatchNormalization()
        elif self.norm_type == 'layer':
            self.norm = LayerNormalization()
        else:
            self.norm = None

        self.dropout = Dropout(self.dropout_rate)

    def call(self, inputs, training=False):
        X, A_norm = inputs  # X: (N,F), A_norm: sparse (N,N)

        # Graph convolution: A_norm @ X @ W
        supports = tf.matmul(X, self.W)  # (N, out_dim)
        output = tf.sparse.sparse_dense_matmul(A_norm, supports)

        
        if self.use_bias:
            output += self.b

        if self.norm is not None:
            output = self.norm(output)

        if self.W_skip is not None:
            skip = tf.matmul(X, self.W_skip)
        else:
            skip = X

        output = output + skip

        output = self.dropout(output, training=training)

        return self.activation(output) if self.activation else output

class AttentionLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        F = input_shape[-1]
        self.W_att = self.add_weight(
            name="W_attention",
            shape=(F, 1),
            initializer="glorot_uniform",
            trainable=True
        )

    def call(self, inputs):
        att_scores = tf.matmul(inputs, self.W_att)  # (N, 1)
        att_weights = tf.nn.sigmoid(att_scores)     # (N, 1)

        return inputs * att_weights

class PPIModel(Model):
    def __init__(self, hidden_dims=[512, 256], n_classes=121, dropout=0.2,
                 l2_reg=1e-5, norm_type='layer', use_attention=True):
        super().__init__()

        self.input_dropout = Dropout(dropout/2)
        self.use_attention = use_attention

        self.conv_layers = []
        for i, dim in enumerate(hidden_dims):
            self.conv_layers.append(
                GraphConvLayer(dim, activation="relu", dropout_rate=dropout,
                              norm_type=norm_type, name=f"gcn_{i}")
            )

        if use_attention:
            self.attention = AttentionLayer()

        self.classifier = Dense(
            n_classes,
            activation="sigmoid",
            kernel_regularizer=tf.keras.regularizers.l2(l2_reg)
        )

    def call(self, inputs, training=False):
        X, A = inputs

        h = self.input_dropout(X, training=training)

        for conv in self.conv_layers:
            h = conv([h, A], training=training)

        if self.use_attention:
            h = self.attention(h)

        return self.classifier(h)

In [None]:
model_params = {
    'hidden_dims': [512, 256],
    'n_classes': n_classes,
    'dropout': 0.2,
    'l2_reg': 1e-5,
    'norm_type': 'layer',
    'use_attention': False
}


In [None]:
model = PPIModel(**model_params)
optimizer = Adam(learning_rate=5e-3)
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)


In [12]:
def evaluate(X, A, Y, mask):
    preds = model([X, A], training=False)
    preds = tf.cast(preds >= 0.5, tf.int32).numpy()
    return f1_score(Y[mask], preds[mask], average="micro")

In [13]:
X_tf = tf.constant(X, dtype=tf.float32)
Y_tf = tf.constant(Y, dtype=tf.float32)

In [14]:
_ = model([X_tf, A_sp], training=False)

In [None]:
@tf.function
def train_step(X, A, Y, mask):
    with tf.GradientTape() as tape:
        preds = model([X, A], training=True)
        loss = loss_fn(
            Y,
            preds,
            sample_weight=tf.cast(mask, tf.float32)
        )
        loss += sum(model.losses)

    grads = tape.gradient(loss, model.trainable_variables)

    # Gradient clipping
    grads, _ = tf.clip_by_global_norm(grads, 5.0)

    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


In [None]:
max_epochs = 300
patience = 20
checkpoint_path = "best_ppi_gnn.weights.h5"

def lr_scheduler(epoch, lr):
    if epoch < 50:
        return lr
    elif epoch < 100:
        return lr * 0.5
    else:
        return lr * 0.1

best_val = 0.0
wait = 0

In [None]:
for epoch in range(1, max_epochs + 1):
    if epoch % 50 == 0 and epoch > 0:
        lr = optimizer.learning_rate.numpy()
        optimizer.learning_rate.assign(lr * 0.5)
        print(f"Learning rate decreased to {optimizer.learning_rate.numpy()}")

    loss = train_step(X_tf, A_sp, Y_tf, train_mask)

    val_f1 = evaluate(X_tf, A_sp, Y, val_mask)

    print(f"Epoch {epoch:03d}: loss={loss:.4f} val_f1={val_f1:.4f} lr={optimizer.learning_rate.numpy():.6f}")

    if val_f1 > best_val:
        best_val = val_f1
        wait = 0
        model.save_weights(checkpoint_path)
        print(f"Model improved, saving checkpoint (val_f1={val_f1:.4f})")
    else:
        wait += 1
        if wait >= patience:
            print(f"Early stopping after {epoch} epochs.")
            break

Epoch 001: loss=0.7194 val_f1=0.4267 lr=0.005000
Model improved, saving checkpoint (val_f1=0.4267)
Epoch 002: loss=0.5619 val_f1=0.4175 lr=0.005000
Epoch 003: loss=0.4900 val_f1=0.4153 lr=0.005000
Epoch 004: loss=0.4659 val_f1=0.4089 lr=0.005000
Epoch 005: loss=0.4606 val_f1=0.4181 lr=0.005000
Epoch 006: loss=0.4563 val_f1=0.4208 lr=0.005000
Epoch 007: loss=0.4506 val_f1=0.4169 lr=0.005000
Epoch 008: loss=0.4462 val_f1=0.4147 lr=0.005000
Epoch 009: loss=0.4426 val_f1=0.4178 lr=0.005000
Epoch 010: loss=0.4396 val_f1=0.4254 lr=0.005000
Epoch 011: loss=0.4372 val_f1=0.4308 lr=0.005000
Model improved, saving checkpoint (val_f1=0.4308)
Epoch 012: loss=0.4349 val_f1=0.4348 lr=0.005000
Model improved, saving checkpoint (val_f1=0.4348)
Epoch 013: loss=0.4330 val_f1=0.4382 lr=0.005000
Model improved, saving checkpoint (val_f1=0.4382)
Epoch 014: loss=0.4312 val_f1=0.4409 lr=0.005000
Model improved, saving checkpoint (val_f1=0.4409)
Epoch 015: loss=0.4295 val_f1=0.4439 lr=0.005000
Model improved,

In [None]:
model.load_weights(checkpoint_path)

val_f1 = evaluate(X_tf, A_sp, Y, val_mask)
test_f1 = evaluate(X_tf, A_sp, Y, test_mask)
print(f"Final results - Val Micro-F1: {val_f1:.4f}, Test Micro-F1: {test_f1:.4f}")

model.summary()

predictions = model([X_tf, A_sp], training=False).numpy()
binary_preds = (predictions >= 0.5).astype(np.int32)

from sklearn.metrics import classification_report

class_names = [f"class_{i}" for i in range(n_classes)]
print("\nTest set performance by class:")
report = classification_report(
    Y[test_mask],
    binary_preds[test_mask],
    target_names=class_names,
    zero_division=0
)
print(report)

Final results - Val Micro-F1: 0.7024, Test Micro-F1: 0.7179



Test set performance by class:
              precision    recall  f1-score   support

     class_0       0.88      0.87      0.88      3518
     class_1       0.83      0.34      0.49      1293
     class_2       0.84      0.26      0.39      1138
     class_3       0.75      0.49      0.59      1302
     class_4       0.79      0.24      0.36       721
     class_5       0.76      0.54      0.63      1041
     class_6       0.75      0.51      0.61      1321
     class_7       0.66      0.34      0.45      1916
     class_8       0.88      0.18      0.30       878
     class_9       0.88      0.72      0.79      1926
    class_10       0.91      0.65      0.76      1192
    class_11       0.80      0.47      0.59      1245
    class_12       0.75      0.95      0.84      3864
    class_13       0.71      0.18      0.29       988
    class_14       0.82      0.77      0.80       973
    class_15       0.74      0.59      0.66      1981
    class_16       0.90      0.36      0.51      