Drug_API.ipynb

**PyTorch Geometric (PyG) API Demo: Data, Batching, GATConv, Pooling, and a Thin Wrapper**

This notebook is an **API-style demonstration of PyTorch Geometric**.  
It is intentionally **tool-focused** and avoids any domain- or project-specific content.

What is covered:
- Constructing graphs using `torch_geometric.data.Data`
- Batching multiple graphs using `torch_geometric.loader.DataLoader`
- Using `GATConv` for message passing
- Using `global_mean_pool` to obtain graph-level embeddings
- A minimal wrapper module that exposes a clean `forward()` for:
  - graph encoding (graph → embedding)
  - pair scoring (graph A + graph B → score)

Everything uses **tiny synthetic graphs** so it runs fast.


In [8]:
# Imports
import torch
import torch.nn as nn

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool

from sklearn.metrics import roc_auc_score

Core PyG objects

**1) PyG Graph Representation with `Data`**

A PyG graph is commonly represented as:
- `x`: node feature matrix of shape `[num_nodes, num_node_features]`
- `edge_index`: connectivity in COO format, shape `[2, num_edges]`; each column is one directed edge `(source, target)`

One additional field matters once graphs are batched:
- `batch`: a vector of shape `[total_nodes]`, created by `DataLoader`, that records which graph each node belongs to
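
To make the `edge_index` layout concrete, here is a minimal sketch in plain PyTorch (no PyG import needed). It turns an undirected edge list into the `[2, num_edges]` form by listing every edge in both directions. The helper name `to_bidirectional_edge_index` is hypothetical and only for illustration; PyG ships its own utility, `torch_geometric.utils.to_undirected`, for real use.

```python
import torch

def to_bidirectional_edge_index(undirected_edges):
    """Hypothetical helper: list every undirected edge in both directions."""
    # Row 0 holds source nodes, row 1 holds target nodes.
    src = [u for u, v in undirected_edges] + [v for u, v in undirected_edges]
    dst = [v for u, v in undirected_edges] + [u for u, v in undirected_edges]
    return torch.tensor([src, dst], dtype=torch.long)

edge_index = to_bidirectional_edge_index([(0, 1), (1, 2)])
print(tuple(edge_index.shape))  # (2, 4)
print(edge_index.tolist())      # [[0, 1, 1, 2], [1, 2, 0, 1]]
```

Each column is one directed edge, so the undirected edge 0-1 appears as the columns `(0, 1)` and `(1, 0)`.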

Helper: build a tiny synthetic graph

In [9]:
def make_toy_graph(num_nodes: int, num_node_features: int, edges: list[tuple[int, int]]) -> Data:
    """
    Create a tiny synthetic graph for API demonstration.

    Parameters
    ----------
    num_nodes : int
        Number of nodes in the graph.
    num_node_features : int
        Dimensionality of node features.
    edges : list of (src, dst)
        Directed edges. If you want an undirected graph, include both directions.

    Returns
    -------
    Data
        PyG graph object with fields `x` and `edge_index`.
    """
    x = torch.randn(num_nodes, num_node_features, dtype=torch.float)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return Data(x=x, edge_index=edge_index)

Create a few graphs

In [10]:
num_node_features = 5

g1 = make_toy_graph(
    num_nodes=3,
    num_node_features=num_node_features,
    edges=[(0, 1), (1, 0), (1, 2), (2, 1)]
)

g2 = make_toy_graph(
    num_nodes=4,
    num_node_features=num_node_features,
    edges=[(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]
)

g3 = make_toy_graph(
    num_nodes=5,
    num_node_features=num_node_features,
    edges=[(0, 2), (2, 0), (1, 2), (2, 1), (2, 4), (4, 2)]
)

g1, g2, g3

(Data(x=[3, 5], edge_index=[2, 4]),
 Data(x=[4, 5], edge_index=[2, 6]),
 Data(x=[5, 5], edge_index=[2, 6]))

Batching with DataLoader

**2) Batching multiple graphs with `DataLoader`**

PyG batches graphs by concatenating their nodes and edges into one large disconnected graph,
shifting each graph's node indices in `edge_index` so they stay unique. It also provides a
`batch` vector that maps each node to the index of its original graph.

This is the standard pattern for graph-level tasks.
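
The idea behind this batching can be sketched by hand in plain PyTorch. The function `collate_graphs` below is a hypothetical illustration, not PyG's actual implementation: it concatenates node features, shifts each graph's edge indices by the number of nodes that came before it, and builds the node-to-graph vector.

```python
import torch

def collate_graphs(xs, edge_indices):
    """Hypothetical helper mimicking the idea behind PyG batching."""
    offset, batch_parts, shifted_edges = 0, [], []
    for graph_id, (x, ei) in enumerate(zip(xs, edge_indices)):
        shifted_edges.append(ei + offset)  # keep node ids unique across graphs
        batch_parts.append(torch.full((x.size(0),), graph_id, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs, dim=0), torch.cat(shifted_edges, dim=1), torch.cat(batch_parts)

# Two toy graphs with 3 and 4 nodes and 5 features each.
xs = [torch.randn(3, 5), torch.randn(4, 5)]
eis = [torch.tensor([[0, 1], [1, 0]]), torch.tensor([[0, 3], [3, 0]])]
big_x, big_edge_index, batch_vec = collate_graphs(xs, eis)

print(tuple(big_x.shape))       # (7, 5)
print(big_edge_index.tolist())  # [[0, 1, 3, 6], [1, 0, 6, 3]]
print(batch_vec.tolist())       # [0, 0, 0, 1, 1, 1, 1]
```

The second graph's edge `0 -> 3` becomes `3 -> 6` because three nodes precede it in the concatenated node matrix.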


DataLoader batch demo

In [11]:
loader = DataLoader([g1, g2, g3], batch_size=3, shuffle=False)
batch = next(iter(loader))

print("Batch.x shape       :", tuple(batch.x.shape))
print("Batch.edge_index shape:", tuple(batch.edge_index.shape))
print("Batch.batch shape   :", tuple(batch.batch.shape))
print("Graphs in batch     :", int(batch.batch.max().item()) + 1)

Batch.x shape       : (12, 5)
Batch.edge_index shape: (2, 16)
Batch.batch shape   : (12,)
Graphs in batch     : 3


Native PyG: GATConv forward

**3) Native PyG layer demo: `GATConv`**

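`GATConv` performs attention-based message passing.
Each node's new embedding is a weighted sum of its neighbors' projected features, where the
weights are learned attention scores normalized over that node's neighborhood.
With `concat=True`, the outputs of all heads are concatenated, so the output width is `out_channels * heads`.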
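
For intuition, a single attention head can be written out in plain PyTorch. This is a simplified sketch of the graph attention computation, not PyG's implementation: the names `W`, `a_src`, and `a_dst` are illustrative, self-loops are listed explicitly in `edge_index`, and features such as multiple heads, bias, and dropout are left out.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, in_dim, out_dim = 4, 5, 8
h = torch.randn(num_nodes, in_dim)

# Columns are (source, target) edges; the last four columns are self-loops.
edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2, 3],
                           [1, 0, 2, 1, 0, 1, 2, 3]])
W = torch.randn(in_dim, out_dim)  # shared linear projection
a_src = torch.randn(out_dim)      # attention vector, source half
a_dst = torch.randn(out_dim)      # attention vector, target half

z = h @ W
src, dst = edge_index
# Unnormalized attention score per edge.
scores = F.leaky_relu((z[src] * a_src).sum(-1) + (z[dst] * a_dst).sum(-1),
                      negative_slope=0.2)

# Softmax over the incoming edges of each target node.
alpha = torch.zeros_like(scores)
for node in range(num_nodes):
    mask = dst == node
    alpha[mask] = torch.softmax(scores[mask], dim=0)

# Weighted sum of projected neighbor features for each target node.
out = torch.zeros(num_nodes, out_dim).index_add_(0, dst, alpha.unsqueeze(-1) * z[src])
print(tuple(out.shape))  # (4, 8)
```

Node 3 has only its self-loop, so its attention weight is 1 and its output equals its own projected feature `z[3]`.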



Native GATConv on batched data

In [12]:
in_channels = num_node_features
out_channels = 8
heads = 2

gat = GATConv(in_channels=in_channels, out_channels=out_channels, heads=heads, concat=True)

node_emb = gat(batch.x, batch.edge_index)
print("Node embeddings shape:", tuple(node_emb.shape))  # [total_nodes_in_batch, out_channels * heads]

Node embeddings shape: (12, 16)


Pooling to get graph embeddings

**4) Graph-level embeddings using `global_mean_pool`**

Many pipelines need a single vector per graph. Pooling aggregates node embeddings:
- `global_mean_pool` computes the mean of node embeddings for each graph in the batch.
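
The mean pooling step can be sketched in plain PyTorch as a sum per graph divided by a node count per graph. The function `mean_pool` is a hypothetical stand-in for illustration; in practice `global_mean_pool` does this for you.

```python
import torch

def mean_pool(node_emb, batch_vec, num_graphs):
    """Hypothetical helper: average node embeddings within each graph."""
    # Sum the rows that belong to each graph id.
    sums = torch.zeros(num_graphs, node_emb.size(1)).index_add_(0, batch_vec, node_emb)
    # Count nodes per graph; clamp avoids division by zero for empty graphs.
    counts = torch.bincount(batch_vec, minlength=num_graphs).clamp(min=1).unsqueeze(1)
    return sums / counts

toy_emb = torch.tensor([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
toy_batch = torch.tensor([0, 0, 1])
pooled = mean_pool(toy_emb, toy_batch, num_graphs=2)
print(pooled.tolist())  # [[2.0, 3.0], [10.0, 20.0]]
```

The first two nodes belong to graph 0 and are averaged; the third node is the only member of graph 1, so it passes through unchanged.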


Pooling demo

In [13]:
graph_emb = global_mean_pool(node_emb, batch.batch)
print("Graph embeddings shape:", tuple(graph_emb.shape))  # [num_graphs_in_batch, out_channels * heads]

Graph embeddings shape: (3, 16)


Thin wrapper layer

**5) Thin wrapper layer (clean forward pass)**

This wrapper is intentionally minimal:
- `GATGraphEncoder`: graph batch → graph embeddings
- `GraphPairScorer`: (graph A, graph B) → logit

It is still PyG-driven (Data, DataLoader batching, GATConv, pooling), just packaged
in a reusable way.


Wrapper: encoder

In [14]:
class GATGraphEncoder(nn.Module):
    """
    Minimal graph encoder using PyG's GATConv + global mean pooling.

    Input:  batched Data object with fields x, edge_index, batch
    Output: graph embeddings of shape [num_graphs, emb_dim]
    """
    def __init__(self, in_dim: int, hidden_dim: int, emb_dim: int, heads: int = 2):
        super().__init__()
        self.gat1 = GATConv(in_channels=in_dim, out_channels=hidden_dim, heads=heads, concat=True)
        self.gat2 = GATConv(in_channels=hidden_dim * heads, out_channels=emb_dim, heads=1, concat=True)
        self.act = nn.ELU()

    def forward(self, data: Data) -> torch.Tensor:
        x = self.gat1(data.x, data.edge_index)
        x = self.act(x)
        x = self.gat2(x, data.edge_index)

        # graph-level embedding
        g = global_mean_pool(x, data.batch)
        return g

Wrapper: pair scorer

In [15]:
class GraphPairScorer(nn.Module):
    """
    Scores a pair of graphs using an encoder + small MLP head.

    Strategy:
    - encode A -> ea
    - encode B -> eb
    - combine -> [ea, eb, |ea-eb|, ea*eb]
    - MLP -> logit
    """
    def __init__(self, encoder: nn.Module, emb_dim: int):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Sequential(
            nn.Linear(emb_dim * 4, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, batch_a: Data, batch_b: Data) -> torch.Tensor:
        ea = self.encoder(batch_a)
        eb = self.encoder(batch_b)

        combined = torch.cat([ea, eb, torch.abs(ea - eb), ea * eb], dim=1)
        logit = self.head(combined).squeeze(-1)
        return logit
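
The feature combination used in `forward` can be examined on its own. The sketch below uses random tensors in place of real graph embeddings, and `combine_pair` is a hypothetical name. It shows why the head expects `emb_dim * 4` inputs, and that only the last two blocks (`|ea-eb|` and `ea*eb`) are unchanged when the two graphs swap places.

```python
import torch

def combine_pair(ea, eb):
    """Hypothetical helper reproducing the [ea, eb, |ea-eb|, ea*eb] features."""
    return torch.cat([ea, eb, torch.abs(ea - eb), ea * eb], dim=1)

torch.manual_seed(0)
emb_dim = 16
ea, eb = torch.randn(6, emb_dim), torch.randn(6, emb_dim)
ab, ba = combine_pair(ea, eb), combine_pair(eb, ea)

print(tuple(ab.shape))                                        # (6, 64)
print(torch.equal(ab[:, 2 * emb_dim:], ba[:, 2 * emb_dim:]))  # True
print(torch.equal(ab[:, :2 * emb_dim], ba[:, :2 * emb_dim]))  # False
```

Because the `[ea, eb]` part depends on order, a scorer built this way will generally give different logits for (A, B) and (B, A). If order should not matter, one option is to average the two scores.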

Pair batching

**6) Pair batching with PyG**

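To score many pairs efficiently, we batch all "left" graphs together and all "right"
graphs together using two DataLoaders, then score every pair in a single forward pass.
Both loaders use `shuffle=False` so that the i-th left graph, the i-th right graph, and
the i-th label stay aligned.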
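
When there are too many pairs for one pass, one way to keep left and right graphs matched is to split the pair list into chunks first and then batch each chunk's left and right graphs separately. The sketch below is plain Python; `iter_pair_chunks` is a hypothetical helper, and the strings stand in for `Data` objects.

```python
def iter_pair_chunks(pairs, chunk_size):
    """Hypothetical helper: yield aligned (lefts, rights, labels) chunks."""
    for start in range(0, len(pairs), chunk_size):
        chunk = pairs[start:start + chunk_size]
        lefts, rights, labels = zip(*chunk)
        yield list(lefts), list(rights), list(labels)

# Placeholder strings instead of real graphs, for illustration only.
toy_pairs = [("g1", "g2", 1), ("g1", "g1", 0), ("g2", "g3", 0),
             ("g3", "g1", 1), ("g2", "g1", 1)]
chunks = list(iter_pair_chunks(toy_pairs, chunk_size=2))

print(len(chunks))  # 3
print(chunks[0])    # (['g1', 'g1'], ['g2', 'g1'], [1, 0])
print(chunks[2])    # (['g2'], ['g1'], [1])
```

To randomize training order under this scheme, shuffle the list of pairs before chunking rather than shuffling the two sides independently.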


Build synthetic pairs and batch them

In [16]:
# Synthetic pairs: (left_graph, right_graph, label)
pairs = [
    (g1, g2, 1),
    (g1, g1, 0),
    (g2, g3, 0),
    (g3, g1, 1),
    (g2, g1, 1),
    (g3, g3, 0),
]

left_graphs  = [p[0] for p in pairs]
right_graphs = [p[1] for p in pairs]
labels = torch.tensor([p[2] for p in pairs], dtype=torch.float)

left_loader  = DataLoader(left_graphs, batch_size=len(left_graphs), shuffle=False)
right_loader = DataLoader(right_graphs, batch_size=len(right_graphs), shuffle=False)

batch_left = next(iter(left_loader))
batch_right = next(iter(right_loader))

print("Left batch graphs :", int(batch_left.batch.max().item()) + 1)
print("Right batch graphs:", int(batch_right.batch.max().item()) + 1)
print("Labels shape      :", tuple(labels.shape))


Left batch graphs : 6
Right batch graphs: 6
Labels shape      : (6,)


Run model + ROC-AUC sanity check

In [17]:
encoder = GATGraphEncoder(in_dim=num_node_features, hidden_dim=8, emb_dim=16, heads=2)
model = GraphPairScorer(encoder=encoder, emb_dim=16)

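# The model is untrained, so the scores below come from random initialization
# and will differ between runs unless a seed is set. This cell only checks
# that the pipeline runs end to end.
model.eval()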
with torch.no_grad():
    logits = model(batch_left, batch_right)
    probs = torch.sigmoid(logits)

print("Logits:", [round(v, 4) for v in logits.tolist()])
print("Probs :", [round(v, 4) for v in probs.tolist()])
print("ROC-AUC (synthetic):", round(roc_auc_score(labels.numpy(), probs.numpy()), 4))

Logits: [0.1037, 0.1027, 0.0391, 0.0569, 0.0781, 0.0358]
Probs : [0.5259, 0.5256, 0.5098, 0.5142, 0.5195, 0.5089]
ROC-AUC (synthetic): 0.7778
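
The ROC-AUC value can be read as the fraction of (positive, negative) pairs in which the positive example gets the higher score. The plain-Python sketch below recomputes it from the labels and the rounded probabilities printed above. `pairwise_auc` is a hypothetical helper for illustration, not how scikit-learn computes the metric internally.

```python
def pairwise_auc(labels, scores):
    """Hypothetical helper: AUC as the share of correctly ordered pos/neg pairs."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    # A tie between a positive and a negative counts as half a win.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

pair_labels = [1, 0, 0, 1, 1, 0]
pair_probs = [0.5259, 0.5256, 0.5098, 0.5142, 0.5195, 0.5089]
auc = pairwise_auc(pair_labels, pair_probs)
print(round(auc, 4))  # 0.7778
```

Here 7 of the 9 positive/negative pairs are ordered correctly. With only six examples, a single swapped pair moves the metric by about 0.11.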


Minimal perturbation sensitivity

**7) Minimal sensitivity demo (perturbation baseline)**

A quick, tool-level baseline to probe sensitivity:
- perturb one node feature column in a graph batch
- see how pair probabilities shift

This is not a full interpretability framework; it's a lightweight sanity tool.

Perturb + measure change

In [19]:
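def perturb_feature_column(data: Data, col_idx: int, delta: float = 0.25) -> Data:
    """Return a copy of a batched graph with `delta` added to one feature column."""
    # Clone every tensor so the original batch is left untouched.
    out = Data(
        x=data.x.clone(),
        edge_index=data.edge_index.clone(),
        batch=data.batch.clone()
    )
    out.x[:, col_idx] += delta  # shift the chosen column for every node
    return out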

model.eval()
with torch.no_grad():
    base_probs = torch.sigmoid(model(batch_left, batch_right))

col_to_test = 0
with torch.no_grad():
    pert_left = perturb_feature_column(batch_left, col_idx=col_to_test, delta=0.25)
    pert_probs = torch.sigmoid(model(pert_left, batch_right))

diff = (pert_probs - base_probs).abs()

print(f"Tested feature column: {col_to_test}")
print("Mean |Δprob|:", float(diff.mean()))
print("Max  |Δprob|:", float(diff.max()))

Tested feature column: 0
Mean |Δprob|: 0.0016743242740631104
Max  |Δprob|: 0.002989351749420166
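
The same check can be repeated for every feature column to get a rough ranking. The sketch below is self-contained and uses a toy linear scorer instead of the notebook's model. `column_sensitivity` is a hypothetical helper that takes any function mapping a feature matrix to scores.

```python
import torch

def column_sensitivity(score_fn, x, delta=0.25):
    """Hypothetical helper: mean |change| in score_fn output per feature column."""
    changes = []
    with torch.no_grad():
        base = score_fn(x)
        for col in range(x.size(1)):
            x_pert = x.clone()
            x_pert[:, col] += delta  # perturb one column at a time
            changes.append((score_fn(x_pert) - base).abs().mean().item())
    return changes

# Toy linear scorer: column 0 has weight 2, column 1 is ignored, column 2 has weight -1.
w = torch.tensor([2.0, 0.0, -1.0])
toy_x = torch.zeros(4, 3)
sens = column_sensitivity(lambda t: t @ w, toy_x, delta=0.25)
print(sens)  # [0.5, 0.0, 0.25]
```

For a linear scorer the shift is exactly `delta * |weight|`, which makes the toy case easy to verify. For the notebook's model, `score_fn` could wrap a call that perturbs the left batch and returns sigmoid probabilities.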


## Summary

This notebook demonstrated core PyTorch Geometric APIs:
- `Data(x, edge_index)` for graph representation
- `DataLoader` batching with the `batch` vector
- `GATConv` for attention-based message passing
- `global_mean_pool` for graph-level embeddings

It also included a thin wrapper that packages the same PyG operations into:
- a reusable encoder (graph → embedding)
- a reusable pair scorer (graph pair → logit)

All data was synthetic to keep the notebook fast and independent of any project context.
