In [None]:
import copy
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_geometric.transforms as T
from torch import nn

from madrigal.models import models
from madrigal.utils import DATA_DIR, BASE_DIR

00:15:19   Note: NumExpr detected 28 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.


In [None]:
# Raw PrimeKG edge table and the node-metadata table used to pick KG-view nodes.
# os.path.join works whether or not the base dirs end with a separator (or are Paths).
kg = pd.read_csv(os.path.join(BASE_DIR, 'PrimeKG', 'kg.csv'))
meta = pd.read_csv(os.path.join(DATA_DIR, 'views_features', 'combined_metadata_reindexed_ddi.csv'))

In [13]:
# Node ids of metadata rows flagged as present in the KG view.
in_kg = meta.loc[meta['view_kg'].eq(1), 'node_id']

In [16]:
# Keep only KG rows whose source node (x_id) is one of the KG-view nodes.
# NOTE(review): if x_id values are not unique across node types in kg.csv, ids from
# other node types could match here too — consider also filtering on the node type; confirm.
kg = kg.loc[kg['x_id'].isin(in_kg)]

In [22]:
# Number of remaining KG rows per source node, ordered by x_id.
counts = kg.groupby('x_id').size().tolist()

In [27]:
# Persist the per-node row counts (written relative to the current working directory).
np.save(file='counts.npy', arr=np.asarray(counts))

In [26]:
# Distribution of KG row counts per retained source node (see `counts` above).
fig, ax = plt.subplots()
ax.hist(counts)
ax.set(
    title='KG rows per node (view_kg == 1)',
    xlabel='Number of KG rows with this x_id',
    ylabel='Number of nodes',
)
plt.show()

  plt.show()


In [None]:
# Serialized heterogeneous KG graph consumed by the HGT link-prediction model below.
VIEWS_PATH = os.path.join(DATA_DIR, 'views_features', 'KG_data_hgt.pt')

In [20]:
# Load the pre-built heterogeneous graph; later cells use `graph.edge_types` and `transform(graph)`.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files. On torch
# releases where `weights_only=True` is the default, a pickled PyG graph may fail to load
# and `weights_only=False` would have to be passed explicitly — confirm the torch version.
graph = torch.load(VIEWS_PATH)
# Independent deep copy of the loaded graph. It is not referenced in the cells shown here;
# presumably kept as an untouched reference — TODO confirm it is still needed.
graph_copy = copy.deepcopy(graph)

In [22]:
# Relations to split: every forward relation. Reverse relations follow the
# 'rev_<rel>' naming convention (see manually_add_rev_labels) and receive their
# labels afterwards. A prefix test avoids misclassifying names that merely contain 'rev'.
edge_types = [edge for edge in graph.edge_types if not edge[1].startswith('rev_')]
# NOTE(review): with rev_edge_types=None, RandomLinkSplit leaves the reverse relations'
# edge_index untouched, so reverse copies of val/test edges remain available for message
# passing during training (possible leakage). Using (dst, 'rev_' + rel, src) for each
# forward edge instead of None would remove them — confirm None is intentional.
rev_edge_types = [None] * len(edge_types)

In [23]:
# Seed the RNGs so the edge split is reproducible on Restart & Run All. The split
# permutation comes from torch; PyG's negative sampling may also draw from `random`.
SPLIT_SEED = 0
random.seed(SPLIT_SEED)
torch.manual_seed(SPLIT_SEED)

transform = T.RandomLinkSplit(
    num_val=0.1,                # fraction of each split relation held out for validation
    num_test=0.1,               # fraction held out for test
    is_undirected=False,
    disjoint_train_ratio=0.1,   # fraction of train edges used only for supervision
    neg_sampling_ratio=2.0,     # negatives per positive where negatives are added
    add_negative_train_samples=False,  # train split therefore carries positives only
    edge_types=edge_types,
    rev_edge_types=rev_edge_types,
)
train_data, val_data, test_data = transform(graph)

In [24]:
def manually_add_rev_labels(data):
    """Copy link-prediction supervision from each forward relation to its reverse.

    Reverse relations are named ``'rev_<rel>'``. For every edge type
    ``(a, 'rev_<rel>', b)`` this sets its ``edge_label_index`` to the label index
    of the forward type ``(b, '<rel>', a)`` with the two rows swapped. It also
    reuses the forward ``edge_label`` tensor. ``data`` is modified in place.

    Raises:
        KeyError/AttributeError: if the matching forward relation has no labels.
    """
    prefix = 'rev_'
    for edge in data.edge_types:
        src_type, rel, dst_type = edge
        # Prefix test, not substring: e.g. 'prevents' is not a reverse relation.
        if not rel.startswith(prefix):
            continue
        forward = (dst_type, rel[len(prefix):], src_type)
        forward_index = data[forward]['edge_label_index']
        # Swap the two rows so indices line up with the reverse edge's node types.
        data[edge]['edge_label_index'] = forward_index.flip(0)
        data[edge]['edge_label'] = data[forward]['edge_label']

def manually_add_rev_labels_2(data):
    """Alternative to ``manually_add_rev_labels``: treat every existing reverse edge as positive.

    For each ``'rev_<rel>'`` edge type, uses that type's own ``edge_index`` as its
    ``edge_label_index`` and sets ``edge_label`` to all ones (no negatives).
    ``data`` is modified in place. Not called in the cells shown here.
    """
    for edge in data.edge_types:
        # Prefix test, not substring: e.g. 'prevents' is not a reverse relation.
        if not edge[1].startswith('rev_'):
            continue
        positives = data[edge]['edge_index']
        data[edge]['edge_label_index'] = positives
        # TODO: add sampled negatives here if reverse relations need negative supervision.
        data[edge]['edge_label'] = torch.ones(positives.shape[1])
            
# Give the reverse relations of every split the labels of their forward relations.
for _split in (train_data, val_data, test_data):
    manually_add_rev_labels(_split)

In [25]:
# Per split: (edge_label_index, edge_type) pairs in edge_types order, as consumed by HGTLinkPred.
train_edge_label_index, val_edge_label_index, test_edge_label_index = (
    [(split[edge]['edge_label_index'], edge) for edge in split.edge_types]
    for split in (train_data, val_data, test_data)
)

In [26]:
def check_data(data):
    """Print edge types whose indices fall outside their endpoint node type's range.

    For every edge type ``(src, rel, dst)``, source indices must lie in
    ``[0, num_nodes(src))`` and destination indices in ``[0, num_nodes(dst))``.
    Each violation is printed on one line. Nothing is raised, so all splits can
    be inspected in one pass. Returns None.
    """
    for edge in data.edge_types:
        src_type, _, dst_type = edge
        edge_index = data[edge]['edge_index']
        if edge_index.numel() == 0:
            continue  # nothing to check, and .max() on an empty tensor would raise
        for side, row, node_type in (('source', 0, src_type), ('destination', 1, dst_type)):
            # Attribute access lets the store infer num_nodes when it is not stored explicitly.
            num_nodes = data[node_type].num_nodes
            if num_nodes is None:
                print(f'{edge}: cannot check {side} indices, num_nodes unknown for {node_type!r}')
                continue
            lo = int(edge_index[row].min())
            hi = int(edge_index[row].max())
            # Valid indices are 0 .. num_nodes - 1, so hi == num_nodes is already out of range.
            if lo < 0 or hi >= num_nodes:
                print(f'{edge}: {side} indices span [{lo}, {hi}] '
                      f'but {node_type!r} has num_nodes={num_nodes}')

# Sanity-check edge indices of every split against their node counts.
for _split in (train_data, val_data, test_data):
    check_data(_split)

In [27]:
class HGTLinkPred(nn.Module):
    """HGT encoder plus decoder for multi-relational link prediction.

    ``models.HGT`` embeds every node type. Each supervised edge type is then
    scored by an entry of ``self.decoders``. All entries are the same
    ``BilinearDDIScorer`` instance, so decoder weights are shared across edge types.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, num_heads,
                 metadata, num_edge_types, group='sum'):
        super().__init__()

        self.encoder = models.HGT(in_channels, hidden_channels, out_channels,
                                  num_layers, num_heads, metadata, group)
        # Registered once as a submodule; the plain list below only aliases it, so its
        # parameters are counted (and optimised) a single time.
        self.decoder = models.BilinearDDIScorer(out_channels, out_channels, 1)
        self.decoders = [self.decoder] * num_edge_types

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        """Score candidate edges for each supervised edge type.

        Args:
            x_dict: node type -> input features, passed to the encoder.
            edge_index_dict: edge type -> edge_index, passed to the encoder.
            edge_label_index: list of ``(index, (src_type, rel, dst_type))`` pairs;
                ``index`` unpacks into (source indices, destination indices).

        Returns:
            1-D tensor of scores, concatenated in ``edge_label_index`` order.

        Raises:
            ValueError: if more edge types are passed than there are decoders
                (previously the extras were silently dropped by ``zip``).
        """
        if len(edge_label_index) > len(self.decoders):
            raise ValueError(
                f'got {len(edge_label_index)} edge types but only {len(self.decoders)} decoders'
            )
        z_dict = self.encoder(x_dict, edge_index_dict)
        preds = []
        for decoder, (edge_pred_index, (src_type, _, dst_type)) in zip(self.decoders, edge_label_index):
            node_in, node_out = edge_pred_index
            # NOTE(review): the decoder output is indexed as a (src, dst) score matrix, i.e. it
            # appears to score every node pair of the two types before gathering the requested
            # ones — potentially very large for big node types; confirm BilinearDDIScorer's output.
            scores = decoder(z_dict[src_type], z_dict[dst_type]).squeeze(0)
            preds.append(scores[node_in, node_out])
        return torch.cat(preds, 0)

    def save_checkpoint(self, PATH):
        """Save only the encoder's state_dict to ``PATH``."""
        torch.save(self.encoder.state_dict(), PATH)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = HGTLinkPred(in_channels=train_data.x_dict['drug'].shape[1],
                    hidden_channels=128,
                    out_channels=128,
                    num_layers=2,
                    num_heads=4,
                    num_edge_types=len(train_edge_label_index),
                    metadata=train_data.metadata()).to(device)
# The model lives on `device`, so the graph it consumes must live there too.
train_data = train_data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# NOTE(review): MSE on 0/1 link labels is unusual (BCEWithLogitsLoss is the common choice),
# and with add_negative_train_samples=False the training labels are presumably all ones —
# confirm this objective is intended.
loss_fn = nn.MSELoss()


def train():
    """Run one full-batch optimisation step on ``train_data``; return the loss as a float."""
    model.train()
    optimizer.zero_grad()
    edge_order = [edge for _, edge in train_edge_label_index]
    # Re-read label indices from train_data so they are the device-resident tensors.
    label_index = [(train_data[edge]['edge_label_index'], edge) for edge in edge_order]
    preds = model(train_data.x_dict, train_data.edge_index_dict, label_index)
    # Targets are the edge *labels* (not the node-index pairs), concatenated in the same
    # edge-type order in which the model concatenates its predictions.
    targets = torch.cat([train_data[edge]['edge_label'] for edge in edge_order], 0)
    targets = targets.to(device=preds.device, dtype=preds.dtype)
    loss = loss_fn(preds, targets)
    loss.backward()
    optimizer.step()
    return float(loss)

In [None]:
# Full-batch training; keep the per-epoch loss so convergence can be inspected afterwards.
NUM_EPOCHS = 300
LOG_EVERY = 10
loss_history = []
for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    loss_history.append(loss)
    if epoch == 1 or epoch % LOG_EVERY == 0:
        print(f'epoch {epoch:3d}  loss {loss:.4f}')