In [16]:
# data can be dowloaded using get_data in eda.ipynb !!



import torch
from torch_geometric.data import HeteroData
import pandas as pd
from torch_geometric.transforms import RandomLinkSplit, NormalizeFeatures, ToUndirected, ToDevice

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

disease = pd.read_table('../data/raw/D-MeshMiner_miner-disease.tsv')

disease_old_ids = disease['# MESH_ID'].unique()

disease_mapping = pd.DataFrame(data={
    'id': pd.RangeIndex(len(disease_old_ids)),
    'old_id': disease_old_ids
})

disease_ids = disease_mapping['id'].to_numpy()

edges = pd.read_table('../data/raw/DCh-Miner_miner-disease-chemical.tsv')
drug_old_ids = edges['Chemical'].unique()

drug_mapping = pd.DataFrame(data={
    'id': pd.RangeIndex(len(drug_old_ids)),
    'old_id': drug_old_ids
})

disease_mapped = disease.merge(disease_mapping, left_on='# MESH_ID', right_on='old_id', how='left').loc[:, ['id', 'Definitions']]

defs = disease_mapped.set_index('id')['Definitions'].str.findall(r'\b[A-Z]{2,}\b').str.join(',')

edges_mapped = edges\
                   .merge(disease_mapping, left_on='# Disease(MESH)', right_on='old_id', how='inner')\
                   .merge(drug_mapping, left_on='Chemical', right_on='old_id', how='inner', suffixes=('_disease', '_drug'))\
                   .loc[:, ['id_disease', 'id_drug']].to_numpy()


data = HeteroData()
data['disease'].node_id = torch.tensor(disease_ids)
data['drug'].node_id = torch.tensor(drug_mapping['id'].to_numpy())

data['disease'].x = torch.tensor(defs.str.get_dummies(',').to_numpy(), dtype=torch.float)

data['disease', 'healedby', 'drug'].edge_index = torch.tensor(edges_mapped.T, dtype=torch.int64)

normalize = NormalizeFeatures()

undirect = ToUndirected()

td = ToDevice(device)

transform = RandomLinkSplit(
    edge_types=('disease', 'healedby', 'drug'),
    rev_edge_types=('drug', 'rev_healedby', 'disease'),
    add_negative_train_samples=True,
    disjoint_train_ratio=.3)

train_data, val_data, test_data = transform(undirect(td(data)))

In [17]:
edges_mapped.shape

(457186, 2)

In [18]:
train_data['disease', 'healedby', 'drug'].num_edges

224022

In [19]:
 train_data['disease', 'healedby', 'drug'].edge_label.type(torch.int)

tensor([1, 1, 1,  ..., 0, 0, 0], device='cuda:0', dtype=torch.int32)

In [20]:
data.num_features

{'disease': 3698, 'drug': 0}

In [21]:
data['disease'].num_nodes

11332

In [22]:
data['drug'].num_nodes

1663

In [27]:
'''
first we make an embedding of the nodes for both diseases and drugs
then we calculate features for edges and finally perform binary classfication on them

drugs do not have features so we must make an embedding for them

and there are many diseases with no features so we also make an embedding

'''


from torch_geometric.nn import SAGEConv, to_hetero
# for embedding
class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = torch.nn.functional.dropout(x, .1)
        x = self.conv2(x, edge_index)
        return x

class FinalClassifier(torch.nn.Module):
    def forward(self, x_disease, x_drug, edge_label_index):
        edge_feat_disease = x_disease[edge_label_index[0]]
        edge_feat_drug = x_drug[edge_label_index[1]]
        return (edge_feat_disease * edge_feat_drug).sum(dim=-1)



class Model(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # disease_features have 3698 features
        self.disease_linear = torch.nn.Linear(data['disease'].num_features, in_channels)
        self.disease_embedding = torch.nn.Embedding(data['disease'].num_nodes, in_channels)
        self.drug_embedding = torch.nn.Embedding(data['drug'].num_nodes, in_channels)

        # our gnn for graph embeddings
        self.gnn = GNN(in_channels, hidden_channels, out_channels)
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        self.classifier = FinalClassifier()
    def forward(self, data):
        x_dict = {
            'disease': self.disease_embedding(data['disease'].node_id) + self.disease_linear(data['disease'].x),
            'drug': self.drug_embedding(data['drug'].node_id),
        }
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        pred = self.classifier(
            x_dict['disease'],
            x_dict['drug'],
            data['disease', 'healedby', 'drug'].edge_label_index,
        )
        return pred

In [30]:
model = Model(64, 128, 64)
print(model)

Model(
  (disease_linear): Linear(in_features=3698, out_features=64, bias=True)
  (disease_embedding): Embedding(11332, 64)
  (drug_embedding): Embedding(1663, 64)
  (gnn): GraphModule(
    (conv1): ModuleDict(
      (disease__healedby__drug): SAGEConv(64, 128, aggr=mean)
      (drug__rev_healedby__disease): SAGEConv(64, 128, aggr=mean)
    )
    (conv2): ModuleDict(
      (disease__healedby__drug): SAGEConv(128, 64, aggr=mean)
      (drug__rev_healedby__disease): SAGEConv(128, 64, aggr=mean)
    )
  )
  (classifier): FinalClassifier()
)


In [31]:
from sklearn.metrics import roc_auc_score

# training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_data = train_data.to(device)

print(f'device {device}')

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=.001)
criterion = torch.nn.CrossEntropyLoss()



for epoch in range(150):
    model.train()
    optimizer.zero_grad()
    pred = model.forward(train_data)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, train_data['disease', 'healedby', 'drug'].edge_label)
    loss.backward()
    optimizer.step()
    print(f'epoch: {epoch}, loss: {loss}')
    if epoch % 10 ==0:
        print(f'roc on test: {roc_auc_score(test_data["disease", "healedby", "drug"].edge_label.cpu().numpy(), model.forward(test_data).detach().cpu().numpy())}')


device cuda
epoch: 0, loss: 0.7758535146713257
roc on test: 0.6691371176235614
epoch: 1, loss: 0.6583035588264465
epoch: 2, loss: 0.5722266435623169
epoch: 3, loss: 0.5075036883354187
epoch: 4, loss: 0.46103599667549133
epoch: 5, loss: 0.43058016896247864
epoch: 6, loss: 0.4060705602169037
epoch: 7, loss: 0.386748343706131
epoch: 8, loss: 0.3699135482311249
epoch: 9, loss: 0.35859137773513794
epoch: 10, loss: 0.3482809066772461
roc on test: 0.909456134150225
epoch: 11, loss: 0.34332937002182007
epoch: 12, loss: 0.33444154262542725
epoch: 13, loss: 0.33053556084632874
epoch: 14, loss: 0.32051900029182434
epoch: 15, loss: 0.3143906891345978
epoch: 16, loss: 0.30408746004104614
epoch: 17, loss: 0.2971493899822235
epoch: 18, loss: 0.2858213782310486
epoch: 19, loss: 0.27908119559288025
epoch: 20, loss: 0.2744169235229492
roc on test: 0.947112755157058
epoch: 21, loss: 0.26433414220809937
epoch: 22, loss: 0.2609004080295563
epoch: 23, loss: 0.25208795070648193
epoch: 24, loss: 0.24802617728

In [33]:
roc_auc_score(test_data['disease', 'healedby', 'drug'].edge_label.cpu(), model.forward(test_data).detach().cpu().numpy())

0.9789007983573106