This notebook has 3 purposes:
* Understanding the structure of the data
* Constructing the knowledge graph (KG)
* Producing the positive and negative samples for training

# Data Structure

In [1]:
import pandas as pd
import os

In [2]:
path = os.path.dirname(os.getcwd())

Nodes

In [3]:
nodes = pd.read_csv(path + r'\data\nodes.csv')
nodes['node_type']= nodes['node_type'].apply(lambda x: x.replace("/","_"))
print(nodes.shape)
nodes.head()

(129375, 5)


Unnamed: 0,node_index,node_id,node_type,node_name,node_source
0,0,9796,gene_protein,PHYHIP,NCBI
1,1,7918,gene_protein,GPANK1,NCBI
2,2,8233,gene_protein,ZRSR2,NCBI
3,3,4899,gene_protein,NRF1,NCBI
4,4,5297,gene_protein,PI4KA,NCBI


Edges

In [4]:
edges = pd.read_csv(path + r'\data\edges.csv')
print(edges.shape)
edges.head()

(8100498, 4)


Unnamed: 0,relation,display_relation,x_index,y_index
0,protein_protein,ppi,0,8889
1,protein_protein,ppi,1,2798
2,protein_protein,ppi,2,5646
3,protein_protein,ppi,3,11592
4,protein_protein,ppi,4,2122


Knowledge Graph

In [5]:
kg = pd.read_csv(path + r'\data\kg.csv')
kg['x_type']= kg['x_type'].apply(lambda x: x.replace("/","_"))
kg['y_type']= kg['y_type'].apply(lambda x: x.replace("/","_"))
kg['relation']= kg['relation'].apply(lambda x: x.replace("-","_"))
kg['relation']= kg['relation'].apply(lambda x: x.replace(" ","_"))
print(kg.shape)
kg.head()

  kg = pd.read_csv(path + r'\data\kg.csv')


(8100498, 12)


Unnamed: 0,relation,display_relation,x_index,x_id,x_type,x_name,x_source,y_index,y_id,y_type,y_name,y_source
0,protein_protein,ppi,0,9796,gene_protein,PHYHIP,NCBI,8889,56992,gene_protein,KIF15,NCBI
1,protein_protein,ppi,1,7918,gene_protein,GPANK1,NCBI,2798,9240,gene_protein,PNMA1,NCBI
2,protein_protein,ppi,2,8233,gene_protein,ZRSR2,NCBI,5646,23548,gene_protein,TTC33,NCBI
3,protein_protein,ppi,3,4899,gene_protein,NRF1,NCBI,11592,11253,gene_protein,MAN1B1,NCBI
4,protein_protein,ppi,4,5297,gene_protein,PI4KA,NCBI,2122,8601,gene_protein,RGS20,NCBI


# KG Construction

In [7]:
from torch_geometric.data import HeteroData
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
import torch
from copy import deepcopy

Create HeteroData object

In [8]:
data = HeteroData()

display(data)

HeteroData()

Proving that both x and y indices span the entire range

In [9]:
for i in range(max(kg['x_index'])+1):
    if i not in kg['x_index'] or i not in kg['y_index']:
        display(i)

print(max(kg['x_index']) == max(kg['y_index']))

True


Get the amount of each type of node

In [10]:
nodes = deepcopy(kg)
nodes.drop_duplicates(subset=['x_index'], keep='first', inplace=True)
node_dict = {}

for node_type in nodes['x_type'].unique():
    split = nodes.loc[kg['x_type'] == node_type]
    node_dict[node_type] = len(split['x_index'].unique())

display(node_dict)

{'gene_protein': 27671,
 'drug': 7957,
 'effect_phenotype': 15311,
 'disease': 17080,
 'biological_process': 28642,
 'molecular_function': 11169,
 'cellular_component': 4176,
 'exposure': 818,
 'pathway': 2516,
 'anatomy': 14035}

Randomize a feature tensor for each type of node and add to data object

In [11]:
for node_type in node_dict.keys():
    data[node_type].x = torch.randn(node_dict[node_type], 16)

display(data)

HeteroData(
  gene_protein={ x=[27671, 16] },
  drug={ x=[7957, 16] },
  effect_phenotype={ x=[15311, 16] },
  disease={ x=[17080, 16] },
  biological_process={ x=[28642, 16] },
  molecular_function={ x=[11169, 16] },
  cellular_component={ x=[4176, 16] },
  exposure={ x=[818, 16] },
  pathway={ x=[2516, 16] },
  anatomy={ x=[14035, 16] }
)

Create edge dict to keep track of within-group indices

In [12]:
temp = deepcopy(kg)
temp.drop_duplicates(subset=['x_index'], keep='first', inplace=True)
temp['group_idx'] = temp.groupby('x_type').cumcount()
idx_to_group = dict(zip(temp['x_index'], temp['group_idx']))

Get the edge connections and add to data object

In [13]:
edges = deepcopy(kg)

# Apply edge dictionary
edges['group_x'] = edges['x_index'].map(idx_to_group)
edges['group_y'] = edges['y_index'].map(idx_to_group)

# There are inconsistencies in the data, so we need to make sure that the group indices are stored to the correct side of the edge list
exceptions = ['pathway_protein', 'drug_effect', 'drug_protein', 'exposure_molfunc', 'exposure_protein', 'molfunc_protein']

# Group by relation
groups = edges.groupby('relation')

for relation, group in groups:
    
    # Done to see if the group is consistent
    subgroups = group.groupby('x_type')
    
    if subgroups.ngroups == 1:
        
        # If consistent, make a simple edge list
        x_indices = torch.tensor(group['group_x'].values, dtype=torch.long)
        y_indices = torch.tensor(group['group_y'].values, dtype=torch.long)
    
        edge_list = torch.stack([x_indices, y_indices], dim=0)
        
    else:
        
        # If not, then we need to make sure that the group indices are stored to the correct side of the edge list
        edge_list = [[],[]]
        
        if relation in exceptions:
            edge_list[0].extend(list(subgroups.get_group(list(subgroups.groups)[0])['group_x']))
            edge_list[1].extend(list(subgroups.get_group(list(subgroups.groups)[0])['group_y']))
            
            edge_list[0].extend(list(subgroups.get_group(list(subgroups.groups)[1])['group_y']))
            edge_list[1].extend(list(subgroups.get_group(list(subgroups.groups)[1])['group_x']))
            
        else:
            edge_list[0].extend(list(subgroups.get_group(list(subgroups.groups)[0])['group_y']))
            edge_list[1].extend(list(subgroups.get_group(list(subgroups.groups)[0])['group_x']))
            
            edge_list[0].extend(list(subgroups.get_group(list(subgroups.groups)[1])['group_x']))
            edge_list[1].extend(list(subgroups.get_group(list(subgroups.groups)[1])['group_y']))

        # Convert to useable type
        edge_list = torch.Tensor(edge_list)
        edge_list = edge_list.type(torch.int64)
        
    # Store in data
    data[group['x_type'].values[0], relation, group['y_type'].values[0]].edge_index = edge_list
    
# Data must be undirected for proper GNN Message Passing
data = T.ToUndirected()(data)
display(data)

HeteroData(
  gene_protein={ x=[27671, 16] },
  drug={ x=[7957, 16] },
  effect_phenotype={ x=[15311, 16] },
  disease={ x=[17080, 16] },
  biological_process={ x=[28642, 16] },
  molecular_function={ x=[11169, 16] },
  cellular_component={ x=[4176, 16] },
  exposure={ x=[818, 16] },
  pathway={ x=[2516, 16] },
  anatomy={ x=[14035, 16] },
  (anatomy, anatomy_anatomy, anatomy)={ edge_index=[2, 28064] },
  (gene_protein, anatomy_protein_absent, anatomy)={ edge_index=[2, 39774] },
  (gene_protein, anatomy_protein_present, anatomy)={ edge_index=[2, 3036406] },
  (biological_process, bioprocess_bioprocess, biological_process)={ edge_index=[2, 105772] },
  (gene_protein, bioprocess_protein, biological_process)={ edge_index=[2, 289610] },
  (cellular_component, cellcomp_cellcomp, cellular_component)={ edge_index=[2, 9690] },
  (gene_protein, cellcomp_protein, cellular_component)={ edge_index=[2, 166804] },
  (drug, contraindication, disease)={ edge_index=[2, 61350] },
  (disease, disease_dis

Finally, we use PyG's validation function to see if our dataset is valid

In [14]:
data.validate()

True

# Training Data

We organize the data into two main datasets, one with all edges and their relationships and the other with just drugs and diseases

In [15]:
pretrain_head_idx = []
pretrain_relation = []
pretrain_tail_idx = []

finetune_head_idx = []
finetine_relation = []
finetine_tail_idx = []

for i,edge_type in enumerate(data.edge_types):
    pretrain_head_idx.extend(data[edge_type].edge_index[0].tolist())
    pretrain_relation.extend([i]*data[edge_type].edge_index.shape[1])
    pretrain_tail_idx.extend(data[edge_type].edge_index[1].tolist())
    
    if edge_type[1] == 'contraindication' or edge_type[1] == 'rev_contraindication':
        finetune_head_idx.extend(data[edge_type].edge_index[0].tolist())
        finetine_relation.extend([0]*data[edge_type].edge_index.shape[1])
        finetine_tail_idx.extend(data[edge_type].edge_index[1].tolist())
    elif edge_type[1] == 'indication' or edge_type[1] == 'rev_indication':
        finetune_head_idx.extend(data[edge_type].edge_index[0].tolist())
        finetine_relation.extend([1]*data[edge_type].edge_index.shape[1])
        finetine_tail_idx.extend(data[edge_type].edge_index[1].tolist())
    elif edge_type[1] == 'off-label use' or edge_type[1] == 'rev_off-label use':
        finetune_head_idx.extend(data[edge_type].edge_index[0].tolist())
        finetine_relation.extend([2]*data[edge_type].edge_index.shape[1])
        finetine_tail_idx.extend(data[edge_type].edge_index[1].tolist())
        
pretrain_head_idx = torch.tensor(pretrain_head_idx)
pretrain_relation = torch.tensor(pretrain_relation)
pretrain_tail_idx = torch.tensor(pretrain_tail_idx)

finetune_head_idx = torch.tensor(finetune_head_idx)
finetine_relation = torch.tensor(finetine_relation)
finetine_tail_idx = torch.tensor(finetine_tail_idx)

display(pretrain_head_idx.shape, pretrain_relation.shape, pretrain_tail_idx.shape)
display(finetune_head_idx.shape, finetine_relation.shape, finetine_tail_idx.shape)

torch.Size([12604474])

torch.Size([12604474])

torch.Size([12604474])

torch.Size([160252])

torch.Size([160252])

torch.Size([160252])

Next we divide these edges into dataloaders

In [14]:
pdata = torch.stack([pretrain_head_idx, pretrain_relation, pretrain_tail_idx], dim=1)
fdata = torch.stack([finetune_head_idx, finetine_relation, finetine_tail_idx], dim=1)

psplit = torch.utils.data.random_split(pdata, [0.8,0.1,0.1])
fsplit = torch.utils.data.random_split(fdata, [0.8,0.1,0.1])

ptrain_loader = DataLoader(psplit[0], batch_size=128, shuffle=True)
pval_loader = DataLoader(psplit[1], batch_size=128, shuffle=True)
ptest_loader = DataLoader(psplit[2], batch_size=128, shuffle=True)

ftrain_loader = DataLoader(fsplit[0], batch_size=128, shuffle=True)
fval_loader = DataLoader(fsplit[1], batch_size=128, shuffle=True)
ftest_loader = DataLoader(fsplit[2], batch_size=128, shuffle=True)

This should be all the data preparation that is necessary for simple pretraining and fine tuning (I hope), a .py file will be created for the simple creation of these dataloaders