# Training DRKG Using TransE_L2
This notebook shows how to train DRKG embeddings using TransE_L2

Before training the model, you need the original DRKG source file in your local storage. The first code cell below calls `download_and_extract()` and then reads the file from `../data/drkg/drkg.tsv`.

## Install DGL-KE
Before training the model, we need to install the `torch`, `dgl` and `dglke` (DGL-KE) packages, which the cell below does with pip.

In [4]:
# Install the deep-learning backend, DGL and DGL-KE with the system pip.
# NOTE(review): the versions are not pinned; the saved output of this cell
# shows it resolved to torch 1.5.0 and dgl 0.4.3.post2 when it was last run.
!sudo pip3 install torch
!sudo pip3 install dgl 
!sudo pip3 install dglke

Collecting torch
[?25l  Downloading https://files.pythonhosted.org/packages/13/70/54e9fb010fe1547bc4774716f11ececb81ae5b306c05f090f4461ee13205/torch-1.5.0-cp36-cp36m-manylinux1_x86_64.whl (752.0MB)
[K     |████████████████████████████████| 752.0MB 79.9MB/s eta 0:00:01
Installing collected packages: torch
Successfully installed torch-1.5.0
You should consider upgrading via the 'pip install --upgrade pip' command.[0m
Collecting dgl
[?25l  Downloading https://files.pythonhosted.org/packages/c5/b4/84e4ebd70ef3985181ef5d2d2a366a45af0e3cd18d249fb212ac03f683cf/dgl-0.4.3.post2-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)
[K     |████████████████████████████████| 3.0MB 2.9MB/s eta 0:00:01
Collecting scipy>=1.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/dc/29/162476fd44203116e7980cfbd9352eef9db37c49445d1fec35509022f6aa/scipy-1.4.1-cp36-cp36m-manylinux1_x86_64.whl (26.1MB)
[K     |████████████████████████████████| 26.1MB 23.0MB/s eta 0:00:01
Installing collected packages: sci

In [1]:
import pandas as pd
import numpy as np
import sys
sys.path.insert(1, '../utils')
from utils import download_and_extract
# download_and_extract is imported from ../utils; presumably it fetches and
# unpacks DRKG so that the path below exists -- verify against that helper.
download_and_extract()
drkg_file = '../data/drkg/drkg.tsv'

# header=None is required: without it pandas consumes the first line of the
# file as column names, so the first triple would never reach `triples`.
df = pd.read_csv(drkg_file, sep="\t", header=None)
# One row per line of the file. The cells below index column 0 as the head
# entity, column 1 as the relation and column 2 as the tail entity.
triples = df.values

In [2]:
# Number of triples loaded above (rows of the `triples` array).
num_triples = triples.shape[0]
num_triples
# Please make sure the output directory exists.
import os
#os.mkdir('train_Hetionet')

In [8]:
# Relation whose triples are pulled out of the graph and used as labels.
tax_relation = 'GNBR::in_tax::Gene:Tax'

# Compare the relation column only (index 1). Testing the whole 2-D array
# with `!=` matches every row two or three times, because the head and tail
# entries never equal the relation string; that made `ntriples` a duplicated
# copy of *all* triples, tax ones included.
is_tax = triples[:, 1] == tax_relation
rowtax = np.where(is_tax)[0]
row = np.where(~is_tax)[0]

# ntriples: every triple except the tax ones; taxclass: only the tax triples.
ntriples = triples[row]
taxclass = triples[rowtax]

In [77]:
# Zero-initialised indicator matrix: one row per distinct head entity and one
# column per distinct tail entity found in the tax triples.
n_tax_genes = len(np.unique(taxclass[:, 0]))
n_tax_classes = len(np.unique(taxclass[:, 2]))
labels = np.zeros((n_tax_genes, n_tax_classes), dtype=int)

In [78]:
# Head entity (column 0) of every tax triple, in file order.
genes=taxclass[:,0]

In [79]:
# Tail entity (column 2) of every tax triple, aligned index-for-index with column 0.
classes=taxclass[:,2]

In [86]:
# Sorted unique names; position in each list is the row / column index used
# for the label matrix.
gene_id = sorted(set(genes.tolist()))
classes_id = sorted(set(classes.tolist()))

# name -> row index (genes) and name -> column index (classes).
d = {gene_name: row_idx for row_idx, gene_name in enumerate(gene_id)}
dc = {class_name: col_idx for col_idx, class_name in enumerate(classes_id)}

In [87]:
# Mark each (gene, class) pair that occurs in the tax triples.
for gene_name, class_name in zip(genes, classes):
    labels[d[gene_name], dc[class_name]] = 1

In [88]:
# Boolean mask over label columns: keep a class only if more than 100 genes
# carry it. ndarray.sum(axis=0) does the column reduction in NumPy (the
# builtin sum() adds the rows one by one in Python) and still yields one
# entry per column when the matrix has no rows, where sum() collapses to 0.
classes_retained_ind = labels.sum(axis=0) > 100

In [89]:
# Boolean mask over label rows: a gene is kept when it has at least one of
# the retained classes.
retained_block = labels[:, classes_retained_ind]
genes_retained_ind = retained_block.sum(axis=1) > 0

In [91]:
# Names of the classes selected by the column mask.
classes_retained = np.asarray(classes_id)[classes_retained_ind]

In [92]:
# Names of the genes selected by the row mask.
genes_retained = np.asarray(gene_id)[genes_retained_ind]

In [99]:
# Load the DRKG file again, this time as a plain Python list of rows.
download_and_extract()
drkg_file = '../data/drkg/drkg.tsv'
# header=None makes pandas treat every line, including the first, as data.
df = pd.read_csv(drkg_file, sep ="\t", header=None)
triplets = df.values.tolist()
# Nested mapping filled by insert_entry: entity type -> {entity name -> integer id}.
entity_dictionary = {}
def insert_entry(entry, ent_type, dic):
    """Register `entry` under `ent_type` in `dic` and return `dic`.

    `dic` maps an entity type to a dict of {entry: integer id}. A new entry
    receives the next free id for its type (the number of entries already
    stored for that type); an entry that is already present keeps its id.
    `dic` is modified in place.
    """
    ids_for_type = dic.setdefault(ent_type, {})
    if entry not in ids_for_type:
        ids_for_type[entry] = len(ids_for_type)
    return dic

# Relation whose triples are left out of the graph.
TAX_RELATION = 'GNBR::in_tax::Gene:Tax'

# (src_type, relation, dest_type) -> list of (src_id, dest_id) pairs.
edge_dictionary = {}

# Single pass over the triples. Ids are handed out in order of first
# appearance, and both endpoints of a triple are registered before they are
# looked up, so the resulting dictionaries are the same as when ids and
# edges were built in two separate passes.
for triple in triplets:
    src, relation, dest = triple[0], triple[1], triple[2]
    if relation == TAX_RELATION:
        continue

    # The entity type is the text before the first '::' of the entity name.
    src_type = src.split('::')[0]
    dest_type = dest.split('::')[0]

    insert_entry(src, src_type, entity_dictionary)
    insert_entry(dest, dest_type, entity_dictionary)

    pair = (entity_dictionary[src_type][src],
            entity_dictionary[dest_type][dest])
    etype = (src_type, relation, dest_type)
    edge_dictionary.setdefault(etype, []).append(pair)

In [100]:
import dgl
# Build a DGL heterogeneous graph from the per-edge-type lists of
# (source id, destination id) pairs; tax triples were skipped when
# edge_dictionary was filled.
graph_wo_tax = dgl.heterograph(edge_dictionary)

In [107]:
# Re-allocate `labels` for the graph: one row per 'Gene' entry in
# entity_dictionary, one column per distinct retained class.
n_graph_genes = len(entity_dictionary['Gene'])
n_retained_classes = len(np.unique(classes_retained))
labels = np.zeros((n_graph_genes, n_retained_classes), dtype=int)

In [108]:
# Retained class name -> column index, in sorted order.
dcc = {class_name: col_idx
       for col_idx, class_name in enumerate(sorted(set(classes_retained.tolist())))}

gene_node_ids = entity_dictionary['Gene']
# Gene node ids that received a label, one entry per labelled tax triple.
# NOTE(review): a gene with several retained classes is appended once per
# class -- confirm whether downstream code expects unique ids.
labeled_ind = []
for gene_name, class_name in zip(genes, classes):
    if class_name not in dcc:
        continue
    # NOTE(review): raises KeyError if a gene occurs only in tax triples,
    # since those were skipped when entity_dictionary was built.
    node_id = gene_node_ids[gene_name]
    labels[node_id, dcc[class_name]] = 1
    labeled_ind.append(node_id)

In [110]:
import pickle,os
# Persist the graph, the label matrix and the labelled node ids.
# Each file is opened in a `with` block so the handle is flushed and closed
# as soon as its dump finishes, instead of being left to garbage collection.
for obj, path in (
    (graph_wo_tax, "graph_wotax.pickle"),
    (labels, "labels.pickle"),
    (labeled_ind, "labeled_ind.pickle"),
):
    with open(path, "wb") as fh:
        pickle.dump(obj, fh, protocol=4)

In [112]:
# Display the graph summary (node and edge counts per type).
graph_wo_tax

Graph(num_nodes={'Anatomy': 400, 'Atc': 4048, 'Biological Process': 11381, 'Cellular Component': 1391, 'Compound': 24313, 'Disease': 5103, 'Gene': 39220, 'Molecular Function': 2884, 'Pathway': 1822, 'Pharmacologic Class': 345, 'Side Effect': 5701, 'Symptom': 415},
      num_edges={('Gene', 'bioarx::HumGenHumGen:Gene:Gene', 'Gene'): 58094, ('Gene', 'bioarx::VirGenHumGen:Gene:Gene', 'Gene'): 535, ('Compound', 'bioarx::DrugVirGen:Compound:Gene', 'Gene'): 1165, ('Compound', 'bioarx::DrugHumGen:Compound:Gene', 'Gene'): 24501, ('Disease', 'bioarx::Covid2_acc_host_gene::Disease:Gene', 'Gene'): 332, ('Disease', 'bioarx::Coronavirus_ass_host_gene::Disease:Gene', 'Gene'): 129, ('Gene', 'DGIDB::INHIBITOR::Gene:Compound', 'Compound'): 5971, ('Gene', 'DGIDB::ANTAGONIST::Gene:Compound', 'Compound'): 3006, ('Gene', 'DGIDB::OTHER::Gene:Compound', 'Compound'): 11070, ('Gene', 'DGIDB::AGONIST::Gene:Compound', 'Compound'): 3012, ('Gene', 'DGIDB::BINDER::Gene:Compound', 'Compound'): 143, ('Gene', 'DGIDB::