In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import argparse
import random
from collections import defaultdict
from time import time

import numpy as np
import pandas as pd
import tqdm
import torch
import torch.nn as nn
from torch import cuda
from torch.optim import Adam
from torch.utils.data import Dataset

from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils import MarginLoss,DataLoader
from torchkge import KnowledgeGraph,DistMultModel,TransEModel,TransRModel
from torchkge.models.bilinear import HolEModel,ComplExModel

from utils import *

# Bound after the star import so a `tqdm` name re-exported by `utils` cannot shadow it.
from tqdm.auto import tqdm as progress_bar

  from tqdm.autonotebook import tqdm


In [8]:
# ---- User-defined settings: edit this cell for a new dataset / experiment ----

# Embedding dimension passed to DistMultModel when rebuilding the trained replicas.
h_dim = 300

# Directories. Each ends in '/' because file names are appended to them by
# plain string concatenation further down.
data_path = "../processed_data/ddr1/"
save_model_path = '../best_model/ddr1/'
output_path = "../results/ddr1/"

# Query entities whose candidate targets are scored and written out.
ent_list = [
    'CP1_B+10_B',
    'CP1_SP+10_SP',
    'CP1_B+LPS+10_B+LPS',
    'CP1_SP+LPS+10_SP+LPS',
    'CP1_T+10_T',
    'CP1_random',
]

In [9]:
# Load the processed artifacts produced alongside training.
cause = pd.read_csv(data_path + 'cause.txt', sep='\t', names=['from', 'rel', 'to'])
# NOTE: allow_pickle=True unpickles the stored dicts -- only load trusted, locally produced files.
ent2id = np.load(data_path + 'ent2id.npy', allow_pickle=True).item()  # entity name -> id
rel2id = np.load(data_path + 'rel2id.npy', allow_pickle=True).item()  # relation name -> id

# Candidate heads: entities named 'CID:...' (presumably compounds).
# Ids are looked up from the name list, so *_ent[i] and the id list [i] always match.
h_cand_ent = [name for name in ent2id if name.startswith('CID:')]
h_cand = [ent2id[name] for name in h_cand_ent]

# Candidate tails: entities named 'Protein:...', 'TF:...' or 'RBP:...'.
t_cand_ent = [name for name in ent2id if name.startswith(('Protein:', 'TF:', 'RBP:'))]
t_cand = [ent2id[name] for name in t_cand_ent]

In [None]:
# Target-inference scores: for every query entity in `ent_list`, score the
# candidate targets with each trained PertKG replica and average the scores
# element-wise across replicas (one averaged score per entry of `t_cand`).
n_models = 5  # replicas saved as pertkg0.pt ... pertkg4.pt

# Load every replica ONCE and keep only its embeddings. The embeddings do not
# depend on the query entity, so reloading all replicas inside the per-entity
# loop (len(ent_list) * n_models loads) was redundant.
# NOTE(review): assumes `inference` does not modify the embedding tensors in
# place -- TODO confirm in utils.
replica_embeddings = []
for model_idx in range(n_models):
    model = DistMultModel(h_dim, len(ent2id), len(rel2id))
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts;
    # the model is moved to the GPU right below when one is available.
    checkpoint = torch.load(save_model_path + "pertkg{}.pt".format(model_idx),
                            map_location='cpu')
    model.load_state_dict(checkpoint)
    if cuda.is_available():
        cuda.empty_cache()
        model.cuda()
    model.normalize_parameters()
    model.eval()
    with torch.no_grad():
        replica_embeddings.append(model.get_embeddings())  # (ent_emb, rel_emb)

ti_dict = {}
for ent in progress_bar(ent_list):
    results = []
    with torch.no_grad():
        for ent_emb, rel_emb in replica_embeddings:
            results.append(inference(ent,
                                     ent2id, rel2id,
                                     ent_emb, rel_emb,
                                     h_cand, t_cand,
                                     'target_inference'))
    # Element-wise mean over replicas.
    ti_dict[str(ent)] = [sum(x) / len(x) for x in zip(*results)]

In [None]:
# Confidence of the target-inference scores.
# 1. Score every candidate head (`h_cand_ent`) against the candidate targets
#    with each replica, and average across replicas.
# 2. Rank each compound's scores with `get_rank`.
# 3. For a query and a target, confidence is the fraction of compounds whose
#    rank for that target is at least `rank_margin` larger than the query's rank.
n_models = 5      # replicas saved as pertkg0.pt ... pertkg4.pt
rank_margin = 50  # "correct factor" from the original analysis; derivation not shown here

results = []
for model_idx in range(n_models):
    model = DistMultModel(h_dim, len(ent2id), len(rel2id))
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    checkpoint = torch.load(save_model_path + "pertkg{}.pt".format(model_idx),
                            map_location='cpu')
    model.load_state_dict(checkpoint)
    if cuda.is_available():
        cuda.empty_cache()
        model.cuda()
    model.normalize_parameters()
    model.eval()
    with torch.no_grad():
        ent_emb, rel_emb = model.get_embeddings()
        results.append(inference(h_cand_ent,
                                 ent2id, rel2id,
                                 ent_emb, rel_emb,
                                 h_cand, t_cand,
                                 'batch_target_inference'))

# Element-wise mean over all replicas (replaces the hard-coded arr1..arr5).
average_arr = np.mean([np.array(r) for r in results], axis=0).tolist()
assert len(average_arr) == len(h_cand_ent), "expected one score row per candidate compound"
ti_dict_n_n = {str(comp): row for comp, row in zip(h_cand_ent, average_arr)}

# The compound ranks do not depend on the query, so compute them once instead of
# once per query (previously recomputed inside the loop below).
comp_ranks = [get_rank(ti_dict_n_n[comp]) for comp in h_cand_ent]
# packed_ranks[idx] = the idx-th rank of every compound -- presumably one rank per
# candidate target; TODO confirm get_rank's output layout in utils.
packed_ranks = list(zip(*comp_ranks))

ti_percent_dict = {}
for k, v in progress_bar(ti_dict.items()):
    k_ranks = get_rank(v)
    ti_percent_dict[k] = [
        sum(1 for r in packed_ranks[idx] if r >= x + rank_margin) / len(packed_ranks[idx])
        for idx, x in enumerate(k_ranks)
    ]


In [None]:
# Write one tab-separated file per query entity. Each file pairs every candidate
# target name with its averaged target-inference score and its confidence.
# DataFrame.to_csv does not create missing directories, so make sure it exists.
os.makedirs(output_path, exist_ok=True)
for k, ti_score in ti_dict.items():
    df = pd.DataFrame({'target': t_cand_ent,
                       'ti_score': ti_score,
                       'confidence': ti_percent_dict[k]})
    df.to_csv(os.path.join(output_path, '{}.txt'.format(k)),
              sep='\t', index=False, header=True)