# **Applying the RGCN model to real-world data**
After experimenting with a simple, synthetic dataset to understand the fundamentals, we now apply the RGCN model to a real-world biomedical dataset!

In this notebook, we:
1. Load a real-world biomedical dataset.
2. Convert it into a PyTorch Geometric `Data` object.
3. Stratify-split the edges by relation type into train/val/test.
4. Train and evaluate an RGCN for link prediction.

In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import RGCNConv
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


## 1. About the Pharmacotherapy Dataset

The PharmacotherapyDB dataset links diseases/conditions to drug treatments and specifies the type of relationship (for example, "treats", "palliates", or "neither treats nor palliates").

This dataset has the following structure:
- The column `doid_id` is the disease identifier (often from the Disease Ontology), such as DOID:1234.
- The column `drugbank_id` is the drug identifier, such as DB00123 (DrugBank format).
- The column `Y` indicates the relation between them, encoded as an integer: 0 for "treats", 1 for "palliates", 2 for "neither treats nor palliates".

From this dataset, we will construct a graph where each disease and each drug is a node, and an edge connects them according to the specified `Y` relationship. We can then train a Relational Graph Convolutional Network (RGCN) for link prediction, where the goal is to predict missing or potential relationships between nodes.

## 2. Data structure and loading
The function below reads the CSV and creates:
- A mapping of entity strings to integer indices (for both `doid_id` and `drugbank_id`).
- A mapping of relation strings to integer relation indices.
- A `Data` object from PyG, with:
  - `data.x`: random embeddings for each unique entity (size `[num_nodes, embedding_dim]`).
  - `data.edge_index`: shape `[2, num_edges]`.
  - `data.edge_type`: shape `[num_edges]` (identifies each edge's relation).
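
To make these shapes concrete, here is a minimal sketch on three toy rows. The identifiers are made up for illustration and are not taken from the dataset; the mapping logic mirrors what the loader does (sorted entities, one index per entity, one directed disease-to-drug edge per row).

```python
import torch

# Toy rows in the same (disease, drug, relation) layout; the IDs are made up.
rows = [
    ("DOID:0001", "DB00001", 0),
    ("DOID:0001", "DB00002", 1),
    ("DOID:0002", "DB00001", 2),
]

# Diseases and drugs share a single node index space.
entities = sorted({e for d, g, _ in rows for e in (d, g)})
entity_to_idx = {e: i for i, e in enumerate(entities)}
relations = sorted({r for _, _, r in rows})
relation_to_idx = {r: i for i, r in enumerate(relations)}

# One column per edge: row 0 holds source nodes, row 1 holds target nodes.
edge_index = torch.tensor(
    [[entity_to_idx[d], entity_to_idx[g]] for d, g, _ in rows], dtype=torch.long
).t().contiguous()
edge_type = torch.tensor([relation_to_idx[r] for _, _, r in rows], dtype=torch.long)

print(entity_to_idx)
print(edge_index.shape, edge_type.shape)
```

Note that each row yields a single directed edge (disease to drug), so `edge_index` has exactly as many columns as the CSV has rows.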

In [4]:
def load_pharma_data(
    csv_file,
    doid_col='doid_id',
    drug_col='drugbank_id',
    rel_col='Y',
    embedding_dim=32
):
    """
    Reads the Pharmacotherapy CSV (with columns [doid_col, drug_col, rel_col])
    and produces a PyG Data object.
    """
    # 1) Read CSV
    df = pd.read_csv(csv_file)

    # 2) Gather entities & relations
    unique_entities = set()
    unique_relations = set()
    edges = []

    for row in df.itertuples(index=False):
        doid_id = getattr(row, doid_col)
        drug_id = getattr(row, drug_col)
        rel = getattr(row, rel_col)

        unique_entities.add(doid_id)
        unique_entities.add(drug_id)
        unique_relations.add(rel)

        edges.append((doid_id, drug_id, rel))

    # 3) Sort and map to indices
    entity_list = sorted(list(unique_entities))
    entity_to_idx = {ent: i for i, ent in enumerate(entity_list)}

    relation_list = sorted(list(unique_relations))
    relation_to_idx = {r: i for i, r in enumerate(relation_list)}

    # 4) Build edge_index & edge_type
    edge_index_list = []
    edge_type_list = []

    for (doid_id, drug_id, rel) in edges:
        src = entity_to_idx[doid_id]
        dst = entity_to_idx[drug_id]
        edge_index_list.append([src, dst])
        edge_type_list.append(relation_to_idx[rel])

    edge_index = torch.tensor(edge_index_list, dtype=torch.long).t().contiguous()
    edge_type = torch.tensor(edge_type_list, dtype=torch.long)

    # 5) Random embeddings for each node
    num_nodes = len(entity_to_idx)
    x = torch.randn((num_nodes, embedding_dim), dtype=torch.float)

    # 6) Create PyG Data
    data = Data(
        x=x,
        edge_index=edge_index,
        edge_type=edge_type
    )

    return data, entity_to_idx, relation_to_idx

## 3. Stratified Splitting of Edges

In a multi-relation graph, some relation types might be much more common than others. A purely random
split could end up with certain relation types over-represented in one set and under-represented
in another. **Stratification** ensures that each relation type is present in similar proportions 
across **train, val, and test** sets, helping prevent issues where a model never sees certain
relation types during training or sees them only in the test set.

We'll split 80% train, 10% val, 10% test.
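
As a quick illustration of what stratification buys us, the sketch below splits a toy, deliberately imbalanced label vector with scikit-learn's `train_test_split`. The label counts are invented for the example; only the `stratify` argument matters here.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy, imbalanced relation labels: 80 / 15 / 5 edges of types 0 / 1 / 2.
labels = np.array([0] * 80 + [1] * 15 + [2] * 5)
edge_ids = np.arange(len(labels))

train_ids, held_out_ids, y_train, y_held_out = train_test_split(
    edge_ids, labels, test_size=0.2, stratify=labels, random_state=42
)

# Per-relation counts in each part of the split.
train_counts = np.bincount(y_train, minlength=3)
held_out_counts = np.bincount(y_held_out, minlength=3)
print("train proportions:   ", train_counts / train_counts.sum())
print("held-out proportions:", held_out_counts / held_out_counts.sum())
```

Both parts keep the 80/15/5 ratio, so even the rarest relation type appears on both sides of the split.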



In [5]:
def stratified_split_edges(data, test_size=0.2, val_size=0.1, random_seed=42):
    """
    Splits data.edge_index/data.edge_type into train/val/test, stratified by relation type.

    test_size: fraction of all edges held out from training (val + test combined).
    val_size:  fraction of all edges used for validation (must be smaller than test_size).
    Defaults give 80% train, 10% val, 10% test.
    Returns train/val/test edges + their relation types.
    """
    edge_index_np = data.edge_index.cpu().numpy()
    edge_type_np = data.edge_type.cpu().numpy()

    num_edges = edge_index_np.shape[1]
    all_edge_ids = np.arange(num_edges)
    all_edge_labels = edge_type_np  # for stratification

    # Step 1: train vs. temp split
    X_train, X_temp, y_train, y_temp = train_test_split(
        all_edge_ids,
        all_edge_labels,
        test_size=test_size,
        stratify=all_edge_labels,
        random_state=random_seed
    )

    # Step 2: val vs. test from X_temp
    # X_temp holds `test_size` of all edges; `val_size` of all edges goes to val,
    # so the remaining share of X_temp goes to test.
    # Defaults: 1 - 0.1 / 0.2 = 0.5 => 10% val and 10% test overall.
    test_fraction_of_temp = 1.0 - val_size / test_size
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp,
        y_temp,
        test_size=test_fraction_of_temp,
        stratify=y_temp,
        random_state=random_seed
    )

    def to_edge_tensors(edge_ids):
        sub_edge_index = torch.tensor(edge_index_np[:, edge_ids], dtype=torch.long)
        sub_edge_type = torch.tensor(edge_type_np[edge_ids], dtype=torch.long)
        return sub_edge_index, sub_edge_type

    train_edge_index, train_edge_type = to_edge_tensors(X_train)
    val_edge_index, val_edge_type = to_edge_tensors(X_val)
    test_edge_index, test_edge_type = to_edge_tensors(X_test)

    return (train_edge_index, train_edge_type,
            val_edge_index, val_edge_type,
            test_edge_index, test_edge_type)

## 4. Example RGCNLinkPredictor
If you already defined this in your original tutorial, you can skip this cell.
Otherwise, here's a minimal version of an RGCN-based link predictor.

In [6]:
class RGCNLinkPredictor(nn.Module):
    def __init__(
        self,
        num_nodes,
        in_channels,
        out_channels,
        num_relations,
        num_layers=2,
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_layers = num_layers

        # Build RGCN layers
        self.convs = nn.ModuleList()
        self.convs.append(RGCNConv(in_channels, out_channels, num_relations, num_bases=4))

        for _ in range(num_layers - 1):
            self.convs.append(RGCNConv(out_channels, out_channels, num_relations, num_bases=4))


    def forward(self, x, edge_index, edge_type):
        # x: [num_nodes, in_channels]
        # edge_index: [2, E]
        # edge_type: [E]
        for conv in self.convs:
            x = conv(x, edge_index, edge_type)
            x = F.relu(x)
        return x  # node embeddings

    def predict(self, node_embeddings, edge_index):
        # Typically, for link prediction, we do a dot-product
        # edge_index: [2, E]
        src = node_embeddings[edge_index[0]]  # [E, out_channels]
        dst = node_embeddings[edge_index[1]]  # [E, out_channels]
        score = (src * dst).sum(dim=-1)  # dot product => shape [E]
        # Return a sigmoid to get a score between 0 and 1
        return torch.sigmoid(score)
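
One property of this design is worth checking before training. `forward` applies ReLU after every layer, including the last one, so the returned embeddings are non-negative. The dot product of two non-negative vectors is never negative, which means `predict` can only return scores of 0.5 or higher. The sketch below demonstrates this with random tensors standing in for the last layer's output (an assumption made for illustration; no RGCN is run here).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for the output of the last RGCN layer before the activation.
raw = torch.randn(100, 8)
emb_relu = F.relu(raw)  # what forward() returns: every entry is >= 0

# Score 500 random node pairs the same way predict() does.
pairs = torch.randint(0, 100, (2, 500))

def score(emb):
    return torch.sigmoid((emb[pairs[0]] * emb[pairs[1]]).sum(dim=-1))

print("min score with final ReLU:   ", score(emb_relu).min().item())
print("min score without final ReLU:", score(raw).min().item())
```

With a decision threshold of 0.5, a model built this way would label every candidate edge as positive. One common remedy, not applied in this notebook, is to skip the activation after the final layer.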

## 5. Negative Sampling & Training/Eval Functions

In [7]:
def negative_sampling(num_neg_samples, num_nodes, device=device):
    i = torch.randint(0, num_nodes, (num_neg_samples,), device=device)
    j = torch.randint(0, num_nodes, (num_neg_samples,), device=device)
    return torch.stack([i, j], dim=0)

def train(model, optimizer, data, train_edge_index, train_edge_type):
    model.train()
    optimizer.zero_grad()

    node_embeddings = model(data.x, train_edge_index, train_edge_type)

    # Positive
    pos_score = model.predict(node_embeddings, train_edge_index)

    # Negative
    neg_edge_index = negative_sampling(
        num_neg_samples=train_edge_index.size(1),
        num_nodes=data.num_nodes,
        device=data.x.device
    )
    neg_score = model.predict(node_embeddings, neg_edge_index)

    # Link prediction loss
    loss_pos = -torch.log(pos_score + 1e-15).mean()
    loss_neg = -torch.log(1. - neg_score + 1e-15).mean()
    loss = loss_pos + loss_neg

    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(edge_index, edge_type, threshold=0.5):
    # Relies on the global `model` and `data` defined in the main flow.
    model.eval()

    # Obtain node embeddings.
    # Note: message passing runs over the same edges that are scored below;
    # a stricter protocol would compute embeddings from the training edges only.
    node_embeddings = model(data.x, edge_index, edge_type)

    # Positive (real) edge scores
    pos_score = model.predict(node_embeddings, edge_index)

    # Negative edges (sample the same number as positive)
    neg_edge_index = negative_sampling(
        edge_index.size(1), data.num_nodes, device=data.x.device
    )
    neg_score = model.predict(node_embeddings, neg_edge_index)

    # Ground-truth labels: 1 for real edges, 0 for negatives
    y_true = torch.cat([
        torch.ones(pos_score.size(0), device=pos_score.device),
        torch.zeros(neg_score.size(0), device=pos_score.device)
    ], dim=0)

    # Combine predicted scores
    y_scores = torch.cat([pos_score, neg_score], dim=0)

    # Binarize predictions
    y_pred = (y_scores >= threshold).float()

    # Convert to NumPy
    y_true_np = y_true.cpu().numpy()
    y_pred_np = y_pred.cpu().numpy()

    # Compute metrics
    accuracy = accuracy_score(y_true_np, y_pred_np)
    f1 = f1_score(y_true_np, y_pred_np)
    kappa = cohen_kappa_score(y_true_np, y_pred_np)

    return {
        'accuracy': accuracy,
        'f1': f1,
        'kappa': kappa
    }
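
The `negative_sampling` helper above draws node pairs uniformly at random, so it can occasionally return a real edge or a self-loop and label it as a negative. Below is a minimal sketch of a filtered variant. The function name `filtered_negative_sampling` and its `seed` parameter are hypothetical and not part of this notebook or of PyG; it uses simple rejection sampling and assumes the graph is sparse enough that valid negatives are easy to find.

```python
import torch

def filtered_negative_sampling(pos_edge_index, num_nodes, num_neg_samples, seed=None):
    """Hypothetical helper: sample (src, dst) pairs that are neither known edges nor self-loops."""
    generator = torch.Generator()
    if seed is not None:
        generator.manual_seed(seed)

    # Known positive edges as a set of (src, dst) tuples for fast lookup.
    known = set(map(tuple, pos_edge_index.t().tolist()))

    negatives = []
    # Rejection sampling: would loop forever if no valid negative pair exists.
    while len(negatives) < num_neg_samples:
        i, j = torch.randint(0, num_nodes, (2,), generator=generator).tolist()
        if i != j and (i, j) not in known:
            negatives.append((i, j))

    # Shape [2, num_neg_samples], matching the layout of edge_index (on CPU).
    return torch.tensor(negatives, dtype=torch.long).t().contiguous()

# Usage on a toy graph with 5 nodes and 3 edges.
pos = torch.tensor([[0, 1, 2], [1, 2, 3]])
neg = filtered_negative_sampling(pos, num_nodes=5, num_neg_samples=10, seed=0)
print(neg)
```

The result lives on the CPU, so it would need a `.to(device)` call before being scored on a GPU. PyG also provides `torch_geometric.utils.negative_sampling`, which may be a better fit for larger graphs.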

## 6. Main Flow

Load data, do a stratified split, train the RGCN, and evaluate performance.

In [14]:
csv_file = "/Users/ioanna/Documents/Projects/singh-lab/courses/brown-ml-in-health-biology/GNN-tutorial/data/pharmacotherapyDB.csv"

# Step A: Load & Preprocess
data, entity_to_idx, relation_to_idx = load_pharma_data(
    csv_file,
    doid_col='doid_id',
    drug_col='drugbank_id',
    rel_col='Y',
    embedding_dim=16  # play around with this value
)
print(data)
print("Number of nodes:", data.num_nodes)
print("Number of edges:", data.num_edges)
print("Example of data.x shape:", data.x.shape)

# Step B: Stratified Split
(
    train_edge_index,
    train_edge_type,
    val_edge_index,
    val_edge_type,
    test_edge_index,
    test_edge_type
) = stratified_split_edges(data, test_size=0.2, val_size=0.1, random_seed=42)

# Step C: Build the RGCN model
num_nodes = data.num_nodes
num_relations = int(torch.max(data.edge_type)) + 1
in_channels = data.x.size(1)
out_channels = 8

model = RGCNLinkPredictor(
    num_nodes=num_nodes,
    in_channels=in_channels,
    out_channels=out_channels,
    num_relations=num_relations,
    num_layers=2
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Step D: Move data + edges to the device
data = data.to(device)
train_edge_index = train_edge_index.to(device)
train_edge_type = train_edge_type.to(device)
val_edge_index   = val_edge_index.to(device)
val_edge_type   = val_edge_type.to(device)
test_edge_index  = test_edge_index.to(device)
test_edge_type  = test_edge_type.to(device)

# Step E: Train
EPOCHS = 200
for epoch in range(1, EPOCHS + 1):
    train_loss = train(model, optimizer, data, train_edge_index, train_edge_type)
    
    if epoch % 5 == 0:  # Evaluate every 5 epochs
        val_metrics = evaluate(val_edge_index, val_edge_type, threshold=0.5)
        print(f"Epoch {epoch} | Train Loss: {train_loss:.4f}")
        print("Validation:", val_metrics)

# Step F: Final test evaluation
test_metrics = evaluate(test_edge_index, test_edge_type, threshold=0.5)
print("Test:", test_metrics)

Data(x=[647, 16], edge_index=[2, 1136], edge_type=[1136])
Number of nodes: 647
Number of edges: 1136
Example of data.x shape: torch.Size([647, 16])
Epoch 5 | Train Loss: 1.4073
Validation: {'accuracy': 0.5, 'f1': 0.6666666666666666, 'kappa': np.float64(0.0)}
Epoch 10 | Train Loss: 1.3832
Validation: {'accuracy': 0.5, 'f1': 0.6666666666666666, 'kappa': np.float64(0.0)}
Epoch 15 | Train Loss: 1.3807
Validation: {'accuracy': 0.5, 'f1': 0.6666666666666666, 'kappa': np.float64(0.0)}
Epoch 20 | Train Loss: 1.3725
Validation: {'accuracy': 0.5, 'f1': 0.6666666666666666, 'kappa': np.float64(0.0)}
Epoch 25 | Train Loss: 1.3631
Validation: {'accuracy': 0.5, 'f1': 0.6666666666666666, 'kappa': np.float64(0.0)}
Epoch 30 | Train Loss: 1.3624
Validation: {'accuracy': 0.5, 'f1': 0.6666666666666666, 'kappa': np.float64(0.0)}
Epoch 35 | Train Loss: 1.3426
Validation: {'accuracy': 0.5, 'f1': 0.6666666666666666, 'kappa': np.float64(0.0)}
Epoch 40 | Train Loss: 1.3414
Validation: {'accuracy': 0.5, 'f1': 0.6
... (output truncated)

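Note: While training loss decreases slightly (from about 1.41 to 1.34 over the first 40 epochs), performance on the validation and test sets remains flat. The validation metrics never move from accuracy 0.5, F1 ≈ 0.667, and kappa 0, which are the values obtained when every candidate edge is predicted as positive on a balanced set of real and sampled edges. In other words, the model is not yet separating real edges from negatives. This usually indicates that the node representations are not informative enough for the task.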
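
As a sanity check on the numbers printed above, the sketch below works out by hand what a degenerate predictor would score if it labelled every candidate edge as positive, and what the training loss would be if every score were exactly 0.5. The count `n` is an arbitrary illustrative value; the results do not depend on it.

```python
import math

n = 100  # number of real edges; the same number of negatives is sampled
tp, fp, fn, tn = n, n, 0, 0  # every candidate edge is predicted positive
total = tp + fp + fn + tn

accuracy = (tp + tn) / total
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

# Cohen's kappa: observed agreement corrected for agreement expected by chance.
p_observed = accuracy
p_expected = ((tp + fp) / total) * ((tp + fn) / total) \
    + ((fn + tn) / total) * ((fp + tn) / total)
kappa = (p_observed - p_expected) / (1 - p_expected)

# Loss used in train() when every score equals 0.5: -log(0.5) - log(1 - 0.5).
chance_loss = -math.log(0.5) - math.log(1 - 0.5)

print(accuracy, f1, kappa, chance_loss)
```

These values (0.5, about 0.667, 0.0, and about 1.386) line up with the logged metrics and with a training loss that hovers just below 1.39, which suggests the model is barely better than an uninformed scorer.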

Where do node representations (`data.x`) come from?

- They can be raw features (e.g., gene expression levels, molecular fingerprints).
- They might be pretrained embeddings (e.g., using a language model or a separate encoder).
- They could be one-hot encodings (if no features are available).

## Homework: Improving Node Representations!

Currently, our node features (`data.x`) are initialized randomly:
`x = torch.randn((num_nodes, embedding_dim), dtype=torch.float)`

While this works for testing the pipeline, **random embeddings do not capture any domain knowledge**.
This is a likely reason why our model is not improving on validation/test performance.

## Try the following to improve your node features:

### 1. One-Hot Encoding
Encode each node with a one-hot vector (unique to each entity).
This gives the model a stable identity feature for each node.
(*Tip: use `torch.eye(num_nodes)`*)
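
A minimal sketch of the tip above, using a toy node count in place of `data.num_nodes`. Keep in mind that with one-hot features the input dimension of the first RGCN layer equals the number of nodes.

```python
import torch

num_nodes = 5  # toy value; in the notebook this would be data.num_nodes

# Row i is the one-hot vector for node i.
x_one_hot = torch.eye(num_nodes)

print(x_one_hot.shape)
print(x_one_hot[2])
```

To use this in the main flow, one would assign the matrix to `data.x` before building the model, so that `in_channels = data.x.size(1)` picks up the new width.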

### 2. Pretrained Embeddings
- For drugs: Use molecular fingerprints (e.g., PubChem, Morgan fingerprints) or text-based embeddings (e.g., BioBERT).
- For diseases: Use text-based embeddings (e.g., BioBERT) or graph embeddings (e.g., node2vec on a disease ontology).

You'll need to map each entity ID to its corresponding vector.

### 3. Feature Concatenation
Combine multiple feature types (e.g., one-hot + pretrained) to increase richness.
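
One possible way to wire this up is sketched below. Everything in it is a stand-in: the entity IDs are made up, and the `pretrained` dictionary holds invented vectors where real fingerprints or text embeddings would go. Entities without a pretrained vector fall back to zeros, which is a simple choice rather than a recommendation.

```python
import torch

# Toy stand-in for the entity_to_idx mapping returned by load_pharma_data.
entity_to_idx = {"DB00001": 0, "DB00002": 1, "DOID:0001": 2}

# Hypothetical pretrained vectors keyed by entity ID (values are made up).
pretrained_dim = 4
pretrained = {
    "DB00001": torch.tensor([0.1, 0.2, 0.3, 0.4]),
    "DOID:0001": torch.tensor([0.5, 0.6, 0.7, 0.8]),
}

num_nodes = len(entity_to_idx)

# Place each pretrained vector in the row of its node; missing entities stay zero.
x_pre = torch.zeros(num_nodes, pretrained_dim)
for entity, idx in entity_to_idx.items():
    if entity in pretrained:
        x_pre[idx] = pretrained[entity]

# Concatenate identity features and pretrained features along the feature axis.
x_one_hot = torch.eye(num_nodes)
x = torch.cat([x_one_hot, x_pre], dim=1)

print(x.shape)  # [num_nodes, num_nodes + pretrained_dim]
```

If the feature types have very different scales, normalizing each block before concatenating may help, though that is worth testing rather than assuming.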
