# Graph Neural Network for DDI Prediction

In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch import Tensor

from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv

from torch_geometric.data import HeteroData

from torch_geometric.utils import negative_sampling

from ogb.linkproppred import Evaluator, PygLinkPropPredDataset

import matplotlib.pyplot as plt

### Load OGB Dataset

In [2]:
dataset = PygLinkPropPredDataset(name='ogbl-ddi', transform=T.ToSparseTensor())
    
split_edge = dataset.get_edge_split()
train_edge, valid_edge, test_edge = split_edge["train"], split_edge["valid"], split_edge["test"]

# Create hetero data
data_train = HeteroData()
data_train['drug'].x = train_edge['edge'].flatten().reshape(-1,1)
data_train['drug', 'interacts_with', 'drug'].edge_index = train_edge['edge'].T

data_valid = HeteroData()
data_valid['drug'].x = valid_edge['edge'].flatten().reshape(-1,1)
data_valid['drug', 'interacts_with', 'drug'].edge_index = valid_edge['edge'].T

data_valid_neg = HeteroData()
data_valid_neg['drug'].x = valid_edge['edge_neg'].flatten().reshape(-1,1)
data_valid_neg['drug', 'interacts_with', 'drug'].edge_index = valid_edge['edge_neg'].T

data_test = HeteroData()
data_test['drug'].x = test_edge['edge'].flatten().reshape(-1,1)
data_test['drug', 'interacts_with', 'drug'].edge_index = test_edge['edge'].T

data_test_neg = HeteroData()
data_test_neg['drug'].x = test_edge['edge_neg'].flatten().reshape(-1,1)
data_test_neg['drug', 'interacts_with', 'drug'].edge_index = test_edge['edge_neg'].T

# TODO: part of negative data as new relation - not interacts_with


In [3]:
data_train['drug'].num_nodes
# data_train.edge_index_dict

2135822

In [4]:
data_train["drug", "interacts_with", "drug"]

{'edge_index': tensor([[4039, 4039, 4039,  ...,  647,  708,  835],
        [2424,  225, 3901,  ...,  708,  338, 3554]])}

In [4]:
from torch_geometric.loader import LinkNeighborLoader

edge_label_index = data_train["drug", "interacts_with", "drug"].edge_index

train_loader = LinkNeighborLoader(
    data=data_train,
    num_neighbors=[15]*2,
#     neg_sampling_ratio=2,
    edge_label_index=(("drug", "interacts_with", "drug"), edge_label_index),
    batch_size=128,
    shuffle=True,
)

# Inspect a sample:
sampled_data = next(iter(train_loader))

print("Sampled mini-batch:")
print("===================")
print(sampled_data)

Sampled mini-batch:
HeteroData(
  [1mdrug[0m={ x=[2637, 1] },
  [1m(drug, interacts_with, drug)[0m={
    edge_index=[2, 25045],
    input_id=[128],
    edge_label_index=[2, 128]
  }
)


In [8]:
from torch_geometric.nn import SAGEConv, to_hetero


class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, dropout):
        super().__init__()

        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        
        self.dropout = dropout

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = self.conv1(x.float(), edge_index)
        x = F.relu(x)
#         x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        
        return x


class Classifier(torch.nn.Module):
    def __init__(self):
        
        hidden_channels = 32
        out_channels = 1
        dropout = 0.2
        
        super(Classifier, self).__init__()
        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(256, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout
        
    def forward(self, x_i: Tensor, x_j: Tensor, edge_label_index: Tensor) -> Tensor:
#         # Convert node embeddings to edge-level representations:
#         edge_x_i= x_i[edge_label_index[0]]
#         edge_x_j = x_j[edge_label_index[1]]
        

#         # TODO: linear layer + relu + dropout + sigmoid 
#         return (edge_x_i * edge_x_j).sum(dim=-1)

        x_i = x_i[edge_label_index[0]]
        x_j = x_j[edge_label_index[0]]
        
        x = x_i * x_j
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)  


class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.drug_emb = torch.nn.Embedding(data_train["drug"].num_nodes, hidden_channels)
#         self.drug_j_emb = torch.nn.Embedding(data_train["drug"].num_nodes, hidden_channels)

        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels, 0.2)

        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data_train.metadata())

        self.classifier = Classifier()

    def forward(self, data, neg_edges=None) -> Tensor:
#         x_dict = {
#           "drug": self.drug_emb(data["drug"].x),
#         } 

        # `x_dict` holds feature matrices of all node types
        # `edge_index_dict` holds all edge indices of all edge types
        
        if neg_edges is None: 
            x_dict = self.gnn(data.x_dict, data.edge_index_dict)
        else:
            neg_edges_dict = {("drug", "interacts_with", "drug"): neg_edges}
            x_dict = self.gnn(data.x_dict, neg_edges_dict)

        pred = self.classifier(
            x_dict["drug"],
            x_dict["drug"],
            data["drug", "interacts_with", "drug"].edge_label_index,
        )

        return pred

        
model = Model(hidden_channels=256)

print(model)

Model(
  (drug_emb): Embedding(2135822, 256)
  (gnn): GraphModule(
    (conv1): ModuleDict(
      (drug__interacts_with__drug): SAGEConv((-1, -1), 256, aggr=mean)
    )
    (conv2): ModuleDict(
      (drug__interacts_with__drug): SAGEConv((-1, -1), 256, aggr=mean)
    )
  )
  (classifier): Classifier(
    (lins): ModuleList(
      (0): Linear(in_features=256, out_features=32, bias=True)
      (1): Linear(in_features=32, out_features=1, bias=True)
    )
  )
)


In [9]:
import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()
for epoch in range(1, 6):
    total_loss = total_examples = 0
    for batch in tqdm.tqdm(train_loader):
        optimizer.zero_grad()
        
        batch.to(device)
        out = model(batch)
        loss = -torch.log(torch.sigmoid(out) + 1e-15).mean()
    
        neg_edge_index = negative_sampling(batch['interacts_with'].edge_index, method='dense') #num_nodes=batch.x_dict['drug'].size(0), num_neg_samples=batch.x_dict['drug'].size(0), method='dense')

        out_neg = model(batch, neg_edge_index)
        loss += -torch.log(1 - torch.sigmoid(out_neg) + 1e-15).mean()

        loss.backward()
        optimizer.step()
        total_loss += float(loss) * out.numel()
        total_examples += out.numel()
    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

100%|███████████████████████████████████████| 8344/8344 [02:51<00:00, 48.53it/s]


Epoch: 001, Loss: 1.3866


100%|███████████████████████████████████████| 8344/8344 [02:50<00:00, 49.05it/s]


Epoch: 002, Loss: 1.3863


100%|███████████████████████████████████████| 8344/8344 [02:49<00:00, 49.29it/s]


Epoch: 003, Loss: 1.3863


100%|███████████████████████████████████████| 8344/8344 [02:50<00:00, 48.90it/s]


Epoch: 004, Loss: 1.3863


100%|███████████████████████████████████████| 8344/8344 [02:49<00:00, 49.28it/s]

Epoch: 005, Loss: 1.3863





In [10]:
# test

edge_label_index = data_valid["drug", "interacts_with", "drug"].edge_index

valid_loader = LinkNeighborLoader(
    data=data_valid,
    num_neighbors=[15]*2,
#     neg_sampling_ratio=2,
    edge_label_index=(("drug", "interacts_with", "drug"), edge_label_index),
    batch_size=128,
    shuffle=True,
)

edge_label_index = data_valid_neg["drug", "interacts_with", "drug"].edge_index

valid_neg_loader = LinkNeighborLoader(
    data=data_valid_neg,
    num_neighbors=[15]*2,
#     neg_sampling_ratio=2,
    edge_label_index=(("drug", "interacts_with", "drug"), edge_label_index),
    batch_size=128,
    shuffle=True,
)

model.eval()
total_loss = total_examples = 0
for batch in tqdm.tqdm(valid_loader):
    with torch.no_grad():

        batch.to(device)
        out = model(batch)
        loss = -torch.log(torch.sigmoid(out) + 1e-15).mean()

        neg_edge_index = negative_sampling(batch['interacts_with'].edge_index, method='dense') #num_nodes=batch.x_dict['drug'].size(0), num_neg_samples=batch.x_dict['drug'].size(0), method='dense')

        out_neg = model(batch, neg_edge_index)
        loss += -torch.log(1 - torch.sigmoid(out_neg) + 1e-15).mean()

        total_loss += float(loss) * out.numel()
        total_examples += out.numel()
print(f"Validation Loss: {total_loss / total_examples:.4f}")
print()

pos_valid_preds = []
for batch in tqdm.tqdm(valid_loader):
    batch.to(device)
    pos_valid_preds += [model(batch).squeeze().cpu()]
pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

neg_valid_preds = []
for batch in tqdm.tqdm(valid_neg_loader):
    batch.to(device)
    neg_valid_preds += [model(batch).squeeze().cpu()]
neg_valid_pred = torch.cat(neg_valid_preds, dim=0)


evaluator = Evaluator(name='ogbl-ddi')

results = {}
for K in [10, 20, 30]:
    evaluator.K = K
    valid_hits = evaluator.eval({
        'y_pred_pos': pos_valid_pred,
        'y_pred_neg': neg_valid_pred,
    })[f'hits@{K}']

    results[f'Hits@{K}'] = (train_hits, valid_hits)

for key, result in results.items():
    train_hits, valid_hits = result
    print(key)
    print(f'Valid: {100 * valid_hits:.2f}%')
      

100%|███████████████████████████████████████| 1043/1043 [00:13<00:00, 79.88it/s]


Validation Loss: 1.3863



 57%|██████████████████████▏                | 592/1043 [00:01<00:00, 490.75it/s]


ValueError: Encountered a CUDA error. Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 2539) in your node feature matrix and try again.