## Case Studies Analysis

In [None]:
## import python package
import os, sys
sys.path.append('./scripts')
import graph_tool.all as gt
import pickle
import torch
import numpy as np
import utils
import joblib
from glob import glob
import pandas as pd
from node_synonymizer import NodeSynonymizer
synonymizer = NodeSynonymizer()
import collections
import itertools
import networkx as nx
%matplotlib inline
import matplotlib.pyplot as plt
from networkx.drawing.nx_agraph import graphviz_layout

In [None]:
## define some functions
def get_name(graph_id):
    """Return the preferred name the node synonymizer reports for ``graph_id``.

    Returns None when the synonymizer has no (or an empty) entry for the
    identifier. The synonymizer is queried once per call.
    """
    normalized = synonymizer.get_canonical_curies(graph_id)[graph_id]
    if normalized:
        return normalized['preferred_name']
    return None

def make_path(rel_ent_score):
    """Render one (relation ids, entity ids, score) triple as ``[path_string, score]``.

    The path string alternates entity names and relation names joined by '->'.
    The relation at position 0 of the relation vector is skipped; relation
    ``i + 1`` is the one written between entity ``i`` and entity ``i + 1``.
    """
    rel_vec, ent_vec, score = rel_ent_score
    pieces = []
    for position, entity_id in enumerate(ent_vec[:-1]):
        pieces.append(id_to_name[id2entity[entity_id]])
        pieces.append(id2relation[rel_vec[position + 1]])
    pieces.append(id_to_name[id2entity[ent_vec[-1]]])
    return ['->'.join(pieces), score]

class knowledge_graph:
    """Knowledge graph loaded from ``data_dir`` with a per-node pruned action space.

    Attributes set by ``__init__``: ``entity2id`` / ``id2entity`` and
    ``relation2id`` / ``id2relation`` (from ``utils.load_index``), ``adj_list``
    (unpickled), ``page_rank_scores`` and ``graph`` (entity id -> list of
    ``(relation, target)`` actions).
    """

    def __init__(self, data_dir, bandwidth=3000):
        self.bandwidth = bandwidth

        # Vocabulary files
        entity_index_file = os.path.join(data_dir, 'entity2freq.txt')
        self.entity2id, self.id2entity = utils.load_index(entity_index_file)
        self.num_entities = len(self.entity2id)
        relation_index_file = os.path.join(data_dir, 'relation2freq.txt')
        self.relation2id, self.id2relation = utils.load_index(relation_index_file)

        # Adjacency structure: adj_list[source][relation] -> iterable of targets
        with open(os.path.join(data_dir, 'adj_list.pkl'), 'rb') as handle:
            self.adj_list = pickle.load(handle)

        self.page_rank_scores = self.load_page_rank_scores(os.path.join(data_dir, 'kg.pgrk'))

        # Pre-compute the (possibly pruned) outgoing actions of every entity id
        self.graph = {}
        for entity_id in range(self.num_entities):
            self.graph[entity_id] = self.get_action_space(entity_id)

    def load_page_rank_scores(self, input_path):
        """Read tab-separated ``entity<TAB>score`` lines into an id -> score mapping.

        Entities missing from the file read as 0.0 (defaultdict(float)).
        """
        scores = collections.defaultdict(float)
        with open(input_path) as handle:
            for raw_line in handle:
                entity, score = raw_line.strip().split('\t')
                entity_id = self.entity2id[entity.strip()]
                scores[entity_id] = float(score)
        return scores

    def get_action_space(self, source):
        """Return the ``(relation, target)`` actions leaving ``source``.

        When the number of actions plus one reaches ``bandwidth``, only the
        ``bandwidth`` actions whose targets have the highest page-rank score
        are kept.
        """
        if source not in self.adj_list:
            return []
        outgoing = self.adj_list[source]
        action_space = [
            (relation, target)
            for relation in outgoing
            for target in outgoing[relation]
        ]
        if len(action_space) + 1 >= self.bandwidth:
            # Base graph pruning
            action_space.sort(key=lambda action: self.page_rank_scores[action[1]], reverse=True)
            action_space = action_space[:self.bandwidth]
        return action_space

def check_curie(curie):
    """Map ``curie`` to ``(preferred_curie, entity_id)``.

    ``preferred_curie`` is None when ``curie`` is None or the synonymizer has
    no entry for it; ``entity_id`` is None when the preferred curie is not a
    key of ``kg.entity2id``. The synonymizer is queried once per call.
    """
    if curie is None:
        return (None, None)
    normalized = synonymizer.get_canonical_curies(curie)[curie]
    preferred_curie = normalized['preferred_curie'] if normalized is not None else None
    if preferred_curie in kg.entity2id:
        return (preferred_curie, kg.entity2id[preferred_curie])
    return (preferred_curie, None)

    
def generate_graphml_files(pair):
    """Write the path edges stored in ``path_res_dict[pair]`` to a GraphML file.

    ``path_res_dict[pair]`` is read as a DataFrame whose columns 0, 1 and 2
    hold source name, target name and relation. The file is saved under
    ``graphml_files/`` and named after the preferred names of the two
    identifiers in ``pair`` (falling back to the identifier itself when no
    preferred name is available).
    """
    edges = path_res_dict[pair]
    # Sorted so that vertex numbering in the output is reproducible between runs.
    nodes = sorted(set(edges[0]) | set(edges[1]))

    g = gt.Graph(directed=True)
    node_name = g.new_vertex_property("string")
    name_to_vertex = dict()
    for name in nodes:
        v = g.add_vertex()
        node_name[v] = name
        name_to_vertex[name] = v

    edge_relation = g.new_edge_property("string")
    for source, target, relation in zip(edges[0], edges[1], edges[2]):
        e = g.add_edge(name_to_vertex[source], name_to_vertex[target])
        edge_relation[e] = relation.replace('biolink:','')

    os.makedirs('graphml_files', exist_ok=True)

    g.vertex_properties["node name"] = node_name
    g.edge_properties["edge relation"] = edge_relation
    # get_name may return None; use the raw identifier in that case instead of crashing in join().
    name_parts = [get_name(x) or x for x in pair]
    g.save(os.path.join('graphml_files','_'.join(name_parts).replace(' ','_').lower()+'.graphml'), fmt='graphml')


def my_draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
    rad=0
):
    """Draw edge labels placed on the midpoint of curved edges.

    For each labelled edge the label is positioned at the t=0.5 point of a
    quadratic Bezier curve whose control point is the chord midpoint displaced
    perpendicular to the chord by ``rad`` times the chord vector (computed in
    display coordinates). With ``rad=0`` this is the straight-line midpoint.

    Parameters
    ----------
    G : graph whose ``edges(data=True)`` is used when ``edge_labels`` is None.
    pos : mapping from node to an (x, y) position in data coordinates.
    edge_labels : mapping ``(n1, n2) -> label``; when None, each edge is
        labelled with its edge-data dict.
    label_pos : used for an initial straight-line position that is then
        overwritten by the Bezier midpoint, so it has no effect on the
        final label position.
    rotate : when True, the label is rotated to follow the straight chord
        between the two nodes.
    rad : curvature used for the Bezier control point.
    ax : axes to draw on; defaults to ``plt.gca()``.
    The remaining keyword arguments are forwarded to ``ax.text``.

    Returns
    -------
    dict mapping ``(n1, n2)`` to the object returned by ``ax.text``.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    if ax is None:
        ax = plt.gca()
    if edge_labels is None:
        labels = {(u, v): d for u, v, d in G.edges(data=True)}
    else:
        labels = edge_labels
    text_items = {}
    for (n1, n2), label in labels.items():
        (x1, y1) = pos[n1]
        (x2, y2) = pos[n2]
        # Straight-line interpolation; (x, y) is reassigned below from the Bezier midpoint.
        (x, y) = (
            x1 * label_pos + x2 * (1.0 - label_pos),
            y1 * label_pos + y2 * (1.0 - label_pos),
        )
        # Work in display coordinates so the perpendicular offset is computed on screen.
        pos_1 = ax.transData.transform(np.array(pos[n1]))
        pos_2 = ax.transData.transform(np.array(pos[n2]))
        linear_mid = 0.5*pos_1 + 0.5*pos_2
        d_pos = pos_2 - pos_1
        # Maps (dx, dy) to (dy, -dx): a vector perpendicular to the chord.
        rotation_matrix = np.array([(0,1), (-1,0)])
        # Control point of the quadratic Bezier curve.
        ctrl_1 = linear_mid + rad*rotation_matrix@d_pos
        # De Casteljau step: midpoint of the two half-way points is the curve point at t=0.5.
        ctrl_mid_1 = 0.5*pos_1 + 0.5*ctrl_1
        ctrl_mid_2 = 0.5*pos_2 + 0.5*ctrl_1
        bezier_mid = 0.5*ctrl_mid_1 + 0.5*ctrl_mid_2
        # Back to data coordinates for ax.text (which uses transform=ax.transData).
        (x, y) = ax.transData.inverted().transform(bezier_mid)

        if rotate:
            # in degrees
            angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
            # make label orientation "right-side-up"
            if angle > 90:
                angle -= 180
            if angle < -90:
                angle += 180
            # transform data coordinate angle to screen coordinate angle
            xy = np.array((x, y))
            trans_angle = ax.transData.transform_angles(
                np.array((angle,)), xy.reshape((1, 2))
            )[0]
        else:
            trans_angle = 0.0
        # use default box of white with white border
        # (set on the first iteration that finds bbox None and reused for later labels)
        if bbox is None:
            bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0))
        if not isinstance(label, str):
            label = str(label)  # this makes "1" and 1 labeled the same

        t = ax.text(
            x,
            y,
            label,
            size=font_size,
            color=font_color,
            family=font_family,
            weight=font_weight,
            alpha=alpha,
            horizontalalignment=horizontalalignment,
            verticalalignment=verticalalignment,
            rotation=trans_angle,
            transform=ax.transData,
            bbox=bbox,
            zorder=1,
            clip_on=clip_on,
        )
        text_items[(n1, n2)] = t

    # Hide ticks and tick labels on both axes.
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )

    return text_items    

def reorganize_df(edges):
    """Collapse parallel edges into one row per (source, target) pair.

    ``edges`` is a DataFrame whose columns 0, 1 and 2 hold source, target and
    relation. The 'biolink:' prefix is removed from every relation and the
    distinct relations of a pair are joined with '/' in sorted order, so the
    label is the same on every run (set iteration order is not). Pairs keep
    the order of their first appearance. Returns a DataFrame with columns
    0 (source), 1 (target), 2 (joined relations); empty input yields an empty
    DataFrame.
    """
    if len(edges) == 0:
        return pd.DataFrame([])

    relations_by_pair = dict()
    for source, target, relation in zip(edges[0], edges[1], edges[2]):
        relations_by_pair.setdefault((source, target), set()).add(relation.replace('biolink:',''))

    return pd.DataFrame([
        (source, target, '/'.join(sorted(relations)))
        for (source, target), relations in relations_by_pair.items()
    ])

def path_graph(pair, title, tp_score, is_in_train_set, is_in_not_train_set, rad=0.07):
    """Plot the path graph stored for ``pair`` and return ``[figure, graph]``.

    Edges come from ``reorganize_df(path_res_dict[pair])``; the graph is laid
    out left-to-right with graphviz 'dot', drawn with curved edges (curvature
    ``rad``) and annotated with the score, the train / val-test flags and the
    drug / disease ids. ``title`` is expected in the form "<drug> - <disease>".
    """
    merged_edges = reorganize_df(path_res_dict[pair])
    g = nx.DiGraph()
    for row in merged_edges.to_numpy():
        g.add_edge(row[0], row[1], relation=row[2].replace('biolink:',''))
    labels = nx.get_edge_attributes(g, 'relation')

    f = plt.figure(figsize=(18, 12))
    ax = f.add_subplot(111)
    pos = graphviz_layout(g, prog='dot', args='-Grankdir="LR"')
    nx.draw_networkx_nodes(g, pos, node_size=10, linewidths=0.5, alpha=0.9)
    nx.draw_networkx_edges(g, pos, connectionstyle=f'arc3,rad={rad}', width=0.5)
    nx.draw_networkx_labels(g, pos, font_size=12, font_color='blue')
    # Same rad as the edges so every label sits on its curved edge.
    my_draw_networkx_edge_labels(g, pos, edge_labels=labels, rotate=True, font_size=6, font_weight='bold', rad=rad)

    # Annotations are positioned in axes-fraction coordinates.
    text_style = dict(size=10, color='black', fontweight='bold', transform=ax.transAxes)
    plt.text(0.005, 0.06, f"Predicted Score: {tp_score}", **text_style)
    plt.text(0.005, 0.03, f"Is in train set: {is_in_train_set}", **text_style)
    plt.text(0.005, 0, f"Is in val/test set: {is_in_not_train_set}", **text_style)
    drug_label = title.split(' - ')[0]
    plt.text(0.005, 0.97, f"Drug ID in KG2c: {pair[0]} ({drug_label})", **text_style)
    disease_label = title.split(' - ')[1]
    plt.text(0.005, 0.94, f"Disease ID in KG2c: {pair[1]} ({disease_label})", **text_style)
    return [f, g]

In [None]:
## set up some variables
## data_dir: root folder of the input files read by the cells below
data_dir = './data'
## folder_name / model_name: joined under 'models' to locate the saved prediction model
folder_name = 'RF_model_3class'
model_name = 'RF_model.pt'

In [None]:
## read unsupervised GraphSage embedding vectors
import ast
with open(os.path.join(data_dir, 'graphsage_output', 'unsuprvised_graphsage_entity_embeddings_1epoch_10000.pkl'), 'rb') as infile:
    entity_embeddings_dict = pickle.load(infile)
## generate mapping dictionaries
entity2id, id2entity = utils.load_index(os.path.join(data_dir, 'entity2freq.txt'))
relation2id, id2relation = utils.load_index(os.path.join(data_dir, 'relation2freq.txt'))
all_graph_nodes_info = pd.read_csv(os.path.join(data_dir, 'all_graph_nodes_info.txt'), sep='\t', header=0)
## 'all_names' is stored as text; ast.literal_eval parses Python literals only,
## so (unlike eval) content of the file can never be executed as code
all_graph_nodes_info['all_names'] = all_graph_nodes_info['all_names'].apply(ast.literal_eval)
## node id -> display name: the 'name' column when it is a string, otherwise the first entry of 'all_names'
id_to_name = dict()
for node_id, name, all_names in zip(all_graph_nodes_info['id'], all_graph_nodes_info['name'], all_graph_nodes_info['all_names']):
    id_to_name[node_id] = name if isinstance(name, str) else all_names[0]

In [None]:
## load the drug repurposing prediction model saved at models/<folder_name>/<model_name>
## (used below through fitModel.predict_proba)
fitModel = joblib.load(os.path.join('models', folder_name, model_name))

In [None]:
## set up the disease identifiers used by the two case studies below
## (more information on each identifier is available from the 'synonyms' service on https://arax.ncats.io/)
hemophilia_b_id = 'MONDO:0010604'
huntington_disease_id = 'MONDO:0007739'

In [None]:
## collect the ids of every entity whose type is one of the drug types
type2id, id2type = utils.load_index(os.path.join(data_dir, 'type2freq.txt'))
with open(os.path.join(data_dir, 'entity2typeid.pkl'), 'rb') as infile:
    entity2typeid = pickle.load(infile)
drug_type = ['biolink:Drug', 'biolink:SmallMolecule']
drug_type_ids = [type2id[type_name] for type_name in drug_type]
drug_ids = [
    id2entity[entity_index]
    for entity_index, entity_type_id in enumerate(entity2typeid)
    if entity_type_id in drug_type_ids
]
## load the training, validation and test pair tables (tab-separated, with header)
pairs_dir = os.path.join(data_dir, 'pretrain_reward_shaping_model_train_val_test_random_data_3class')
train_rf_3class = pd.read_csv(os.path.join(pairs_dir, 'train_pairs.txt'), sep='\t', header=0)
val_rf_3class = pd.read_csv(os.path.join(pairs_dir, 'val_pairs.txt'), sep='\t', header=0)
test_rf_3class = pd.read_csv(os.path.join(pairs_dir, 'test_pairs.txt'), sep='\t', header=0)

In [None]:
## load the customized RTX-KG2c and mirror it as a graph-tool graph
kg = knowledge_graph(data_dir, bandwidth=3000)
G = gt.Graph()
## group the relations of every (source, target) pair so each pair becomes a single edge
kg_tmp = dict()
for source, actions in kg.graph.items():
    for relation, target in actions:
        kg_tmp.setdefault((source, target), set()).add(relation)
## every edge carries the set of its relation ids in the 'edge_type' property
etype = G.new_edge_property('object')
for (source, target), relations in kg_tmp.items():
    e = G.add_edge(source, target)
    etype[e] = relations
G.edge_properties['edge_type'] = etype

### Case Studies Analysis with Hemophilia B

In [None]:
## score every drug against Hemophilia B: each feature row is [drug embedding, disease embedding]
X = np.vstack([np.hstack([entity_embeddings_dict[drug_id],entity_embeddings_dict[hemophilia_b_id]]) for drug_id in drug_ids])
res_temp = fitModel.predict_proba(X)
res = pd.DataFrame(res_temp, columns=['tn_score','tp_score','unknown_score'])
res.insert(0, 'disease_id', hemophilia_b_id)
res.insert(0, 'drug_id', drug_ids)
## best 'tp_score' first
res = res.sort_values(by=['tp_score'],ascending=False).reset_index(drop=True)

In [None]:
## flag, for every candidate drug, whether its pair with Hemophilia B appears in the
## training set or in the validation/test sets with label y == 1 ('tp'), y == 0 ('tn') or y == 2 ('random_pairs')
def _known_sources(pairs, disease_id, label):
    """Return the set of 'source' ids paired with ``disease_id`` under label ``y == label``."""
    return set(pairs.loc[(pairs['target'] == disease_id) & (pairs['y'] == label), 'source'])

for prefix, label in [('tp', 1), ('tn', 0), ('random_pairs', 2)]:
    in_train = _known_sources(train_rf_3class, hemophilia_b_id, label)
    not_in_train = _known_sources(val_rf_3class, hemophilia_b_id, label) | _known_sources(test_rf_3class, hemophilia_b_id, label)
    res[f'{prefix}_in_train_set'] = res['drug_id'].isin(in_train)
    res[f'{prefix}_not_in_train_set'] = res['drug_id'].isin(not_in_train)

## add names next to the ids; columns are addressed by label (positional row[0] is deprecated in pandas)
## and the disease name is looked up once instead of once per drug row
res.insert(1, 'drug_name', res['drug_id'].map(get_name).str.lower())
res.insert(3, 'disease_name', get_name(hemophilia_b_id))
res = res.reset_index(drop=True)
## resulting column order:
## ['drug_id', 'drug_name', 'disease_id', 'disease_name', 'tn_score', 'tp_score', 'unknown_score', 'tp_in_train_set', 'tp_not_in_train_set', 'tn_in_train_set','tn_not_in_train_set','random_pairs_in_train_set','random_pairs_not_in_train_set']

In [None]:
## show top 50 predicted results with top 5 in training set
## (at most 5 pairs already known as true positives in the training set, followed by all other pairs)
temp1 = pd.concat([res.loc[res['tp_in_train_set'],:][:5],res.loc[~res['tp_in_train_set'],:]]).reset_index(drop=True)
## head(50) keeps exactly 50 rows; the label slice .loc[:50] is inclusive and kept 51
hemophilia_b_top50 = temp1.head(50).reset_index(drop=True)

In [None]:
## extract all kg-based paths between the top 50 drug-disease pairs (Hemophilia B)
filtered_res_all_paths = dict()
## NOTE(review): 'biolink:biolink:part_of' has a doubled prefix - confirm it is the key actually present in relation2id
filter_edges = [kg.relation2id[edge] for edge in ['biolink:related_to','biolink:biolink:part_of','biolink:coexists_with','biolink:contraindicated_for']]
self_loop_relation = kg.relation2id['SELF_LOOP_RELATION']
## iterate the Hemophilia B table (this cell previously looped over huntington_top50)
for index1 in range(len(hemophilia_b_top50)):
    print(index1)
    source, target = hemophilia_b_top50.loc[index1,['drug_id','disease_id']]
    entity_paths = []
    relation_paths = []
    for path in gt.all_paths(G, check_curie(source)[1], check_curie(target)[1], cutoff=3):
        path_nodes = list(path)
        ## one list of candidate relation ids per hop; their product enumerates every relation assignment
        hop_relations = [list(etype[G.edge(u, v)]) for u, v in zip(path_nodes[:-1], path_nodes[1:])]
        for relations in itertools.product(*hop_relations):
            if len(relations) == 3:
                relation_paths.append([self_loop_relation, *relations])
                entity_paths.append(path_nodes)
            elif len(relations) == 2:
                ## pad 2-hop paths to length 3 with a self loop on the last node
                relation_paths.append([self_loop_relation, *relations, self_loop_relation])
                entity_paths.append(path_nodes + [path_nodes[-1]])
            else:
                ## e.g. a direct 1-hop edge; reported with print because no logger is defined in this notebook
                print(f"Found weird path: nodes={path_nodes} relations={relations}")
    if relation_paths:
        edge_mat = torch.tensor(relation_paths)
        node_mat = torch.tensor(np.array(entity_paths).astype(int))
    else:
        ## no usable path for this pair: keep empty (0, 4) matrices instead of failing on the filter below
        edge_mat = torch.empty((0, 4), dtype=torch.long)
        node_mat = torch.empty((0, 4), dtype=torch.long)
    ## drop every path that uses one of the filtered relations in any of its three hops
    keep_mask = torch.from_numpy(~np.isin(edge_mat[:, 1:4].numpy(), filter_edges).any(axis=1))
    filtered_res_all_paths[(source,target)] = [edge_mat[keep_mask],node_mat[keep_mask]]

In [None]:
## save paths and calculate path probability based on the trained ADAC-based RL model
os.makedirs(os.path.join('case_study_data','hemophilia_b'), exist_ok=True)
with open(os.path.join('case_study_data','hemophilia_b','paths.pkl'),'wb') as outfile:
    pickle.dump(filtered_res_all_paths, outfile)
## run the following python script (see main.sh); its output file is the one read by the next cell
## python ${work_folder}/scripts/calculate_path_prob.py --log_dir ${work_folder}/log_folder --log_name case_study_hemophilia_b.log --data_dir ${work_folder}/data --policy_net_file ${work_folder}/models/ADAC_model/policy_net/policy_model_epoch51.pt --target_paths_file ${work_folder}/case_study_data/hemophilia_b/paths.pkl --output_file ${work_folder}/case_study_data/hemophilia_b/paths_prob.pkl --max_path 3 --bandwidth 3000 --bucket_interval 50 --pretrain_model_path ${work_folder}/models/RF_model_3class/RF_model.pt --use_gpu --state_history 2 --ac_hidden 512 512 --disc_hidden 512 512 --metadisc_hidden 512 256 --batch_size 1000 --factor 0.9 

In [None]:
## for every pair keep the highest-probability paths covering the first N distinct entity
## sequences, render them as strings and append the model's class probabilities for the pair
path_res = dict()
N = 10

with open(os.path.join('case_study_data','hemophilia_b','paths_prob.pkl'),'rb') as infile:
    paths_temp = pickle.load(infile)
for (source, target), batch_paths in paths_temp.items():
    if len(batch_paths[1]) == 0:
        continue
    relation_mat, entity_mat, pred_prob_scores = batch_paths[0], batch_paths[1], batch_paths[2]
    sorted_scores, indices = torch.sort(pred_prob_scores, descending=True)
    relation_sorted, entity_sorted = relation_mat[indices], entity_mat[indices]
    ## walk down the ranking until N distinct entity sequences were seen; rows repeating an
    ## already-seen entity sequence (same nodes, other relations) are kept as well
    seen_entity_paths = set()
    n_rows = 0
    for entity_path in entity_sorted.numpy():
        seen_entity_paths.add(tuple(entity_path))
        n_rows += 1
        if len(seen_entity_paths) == N:
            break
    ## named top_paths so the module-level `res` prediction table is not overwritten
    top_paths = zip(relation_sorted[:n_rows], entity_sorted[:n_rows], sorted_scores[:n_rows])
    pair_features = np.hstack([entity_embeddings_dict[source],entity_embeddings_dict[target]]).reshape(1,-1)
    path_res[(source, target)] = [make_path([rel.numpy(), ent.numpy(), score.item()]) for rel, ent, score in top_paths] + [fitModel.predict_proba(pair_features)[0]]

In [None]:
## split every path string into (source, target, relation) edges and plot one graph per pair
path_res_dict = dict()
for pair in path_res:
    path_segment_set = set()
    ## the last element of path_res[pair] holds the class probabilities, not a path
    for path in path_res[pair][:-1]:
        path_segment = path[0].split('->')
        path_segment_set.update((path_segment[index],path_segment[index+2],path_segment[index+1]) for index in range(0,len(path_segment)-2,2))
    ## sorted so the edge order (and therefore the drawing) is the same on every run
    path_res_dict[pair] = pd.DataFrame(sorted(path_segment_set))

## `pred_res` is not defined anywhere in this notebook; the names, scores and flags of the
## plotted pairs are taken from the Hemophilia B top-50 table
pred_res = hemophilia_b_top50
for pair in path_res_dict:
    if path_res_dict[pair].shape[0] == 0:
        continue
    ## single lookup of the row describing this pair; .item() raises unless exactly one row matches
    pair_row = pred_res.loc[(pred_res['drug_id']==pair[0]) & (pred_res['disease_id']==pair[1])]
    drug_name = pair_row['drug_name'].to_numpy().item().capitalize()
    disease_name = pair_row['disease_name'].to_numpy().item().capitalize()
    title = f"{drug_name} - {disease_name}"
    tp_score = round(pair_row['tp_score'].to_numpy().item(),6)
    is_in_train_set = str(pair_row['tp_in_train_set'].to_numpy().item())
    is_in_not_train_set = str(pair_row['tp_not_in_train_set'].to_numpy().item())
    fig, g = path_graph(pair, title, tp_score, is_in_train_set, is_in_not_train_set)

### Case Studies Analysis with Huntington disease

In [None]:
## score every drug against Huntington disease: each feature row is [drug embedding, disease embedding]
X = np.vstack([np.hstack([entity_embeddings_dict[drug_id],entity_embeddings_dict[huntington_disease_id]]) for drug_id in drug_ids])
res_temp = fitModel.predict_proba(X)
res = pd.DataFrame(res_temp, columns=['tn_score','tp_score','unknown_score'])
res.insert(0, 'disease_id', huntington_disease_id)
res.insert(0, 'drug_id', drug_ids)
## best 'tp_score' first
res = res.sort_values(by=['tp_score'],ascending=False).reset_index(drop=True)

In [None]:
## flag, for every candidate drug, whether its pair with Huntington disease appears in the
## training set or in the validation/test sets with label y == 1 ('tp'), y == 0 ('tn') or y == 2 ('random_pairs')
def _known_sources(pairs, disease_id, label):
    """Return the set of 'source' ids paired with ``disease_id`` under label ``y == label``."""
    return set(pairs.loc[(pairs['target'] == disease_id) & (pairs['y'] == label), 'source'])

for prefix, label in [('tp', 1), ('tn', 0), ('random_pairs', 2)]:
    in_train = _known_sources(train_rf_3class, huntington_disease_id, label)
    not_in_train = _known_sources(val_rf_3class, huntington_disease_id, label) | _known_sources(test_rf_3class, huntington_disease_id, label)
    res[f'{prefix}_in_train_set'] = res['drug_id'].isin(in_train)
    res[f'{prefix}_not_in_train_set'] = res['drug_id'].isin(not_in_train)

## add names next to the ids; columns are addressed by label (positional row[0] is deprecated in pandas)
## and the disease name is looked up once instead of once per drug row
res.insert(1, 'drug_name', res['drug_id'].map(get_name).str.lower())
res.insert(3, 'disease_name', get_name(huntington_disease_id))
res = res.reset_index(drop=True)
## resulting column order:
## ['drug_id', 'drug_name', 'disease_id', 'disease_name', 'tn_score', 'tp_score', 'unknown_score', 'tp_in_train_set', 'tp_not_in_train_set', 'tn_in_train_set','tn_not_in_train_set','random_pairs_in_train_set','random_pairs_not_in_train_set']

In [None]:
## show top 50 predicted results with top 5 in training set
## (at most 5 pairs already known as true positives in the training set, followed by all other pairs)
temp1 = pd.concat([res.loc[res['tp_in_train_set'],:][:5],res.loc[~res['tp_in_train_set'],:]]).reset_index(drop=True)
## head(50) keeps exactly 50 rows; the label slice .loc[:50] is inclusive and kept 51
huntington_top50 = temp1.head(50).reset_index(drop=True)

In [None]:
## extract all kg-based paths between the top 50 drug-disease pairs (Huntington disease)
filtered_res_all_paths = dict()
## NOTE(review): 'biolink:biolink:part_of' has a doubled prefix - confirm it is the key actually present in relation2id
filter_edges = [kg.relation2id[edge] for edge in ['biolink:related_to','biolink:biolink:part_of','biolink:coexists_with','biolink:contraindicated_for']]
self_loop_relation = kg.relation2id['SELF_LOOP_RELATION']
for index1 in range(len(huntington_top50)):
    print(index1)
    source, target = huntington_top50.loc[index1,['drug_id','disease_id']]
    entity_paths = []
    relation_paths = []
    for path in gt.all_paths(G, check_curie(source)[1], check_curie(target)[1], cutoff=3):
        path_nodes = list(path)
        ## one list of candidate relation ids per hop; their product enumerates every relation assignment
        hop_relations = [list(etype[G.edge(u, v)]) for u, v in zip(path_nodes[:-1], path_nodes[1:])]
        for relations in itertools.product(*hop_relations):
            if len(relations) == 3:
                relation_paths.append([self_loop_relation, *relations])
                entity_paths.append(path_nodes)
            elif len(relations) == 2:
                ## pad 2-hop paths to length 3 with a self loop on the last node
                relation_paths.append([self_loop_relation, *relations, self_loop_relation])
                entity_paths.append(path_nodes + [path_nodes[-1]])
            else:
                ## e.g. a direct 1-hop edge; reported with print because no logger is defined in this notebook
                print(f"Found weird path: nodes={path_nodes} relations={relations}")
    if relation_paths:
        edge_mat = torch.tensor(relation_paths)
        node_mat = torch.tensor(np.array(entity_paths).astype(int))
    else:
        ## no usable path for this pair: keep empty (0, 4) matrices instead of failing on the filter below
        edge_mat = torch.empty((0, 4), dtype=torch.long)
        node_mat = torch.empty((0, 4), dtype=torch.long)
    ## drop every path that uses one of the filtered relations in any of its three hops
    keep_mask = torch.from_numpy(~np.isin(edge_mat[:, 1:4].numpy(), filter_edges).any(axis=1))
    filtered_res_all_paths[(source,target)] = [edge_mat[keep_mask],node_mat[keep_mask]]

In [None]:
## save paths and calculate path probability based on the trained ADAC-based RL model
os.makedirs(os.path.join('case_study_data','huntington_disease'), exist_ok=True)
with open(os.path.join('case_study_data','huntington_disease','paths.pkl'),'wb') as outfile:
    pickle.dump(filtered_res_all_paths, outfile)
## run the following python script (see main.sh); its output file is the one read by the next cell
## python ${work_folder}/scripts/calculate_path_prob.py --log_dir ${work_folder}/log_folder --log_name case_study_Huntington.log --data_dir ${work_folder}/data --policy_net_file ${work_folder}/models/ADAC_model/policy_net/policy_model_epoch51.pt --target_paths_file ${work_folder}/case_study_data/huntington_disease/paths.pkl --output_file ${work_folder}/case_study_data/huntington_disease/paths_prob.pkl --max_path 3 --bandwidth 3000 --bucket_interval 50 --pretrain_model_path ${work_folder}/models/RF_model_3class/RF_model.pt --use_gpu --state_history 2 --ac_hidden 512 512 --disc_hidden 512 512 --metadisc_hidden 512 256 --batch_size 1000 --factor 0.9 

In [None]:
## for every pair keep the highest-probability paths covering the first N distinct entity
## sequences, render them as strings and append the model's class probabilities for the pair
path_res = dict()
N = 10

with open(os.path.join('case_study_data','huntington_disease','paths_prob.pkl'),'rb') as infile:
    paths_temp = pickle.load(infile)
for (source, target), batch_paths in paths_temp.items():
    if len(batch_paths[1]) == 0:
        continue
    relation_mat, entity_mat, pred_prob_scores = batch_paths[0], batch_paths[1], batch_paths[2]
    sorted_scores, indices = torch.sort(pred_prob_scores, descending=True)
    relation_sorted, entity_sorted = relation_mat[indices], entity_mat[indices]
    ## walk down the ranking until N distinct entity sequences were seen; rows repeating an
    ## already-seen entity sequence (same nodes, other relations) are kept as well
    seen_entity_paths = set()
    n_rows = 0
    for entity_path in entity_sorted.numpy():
        seen_entity_paths.add(tuple(entity_path))
        n_rows += 1
        if len(seen_entity_paths) == N:
            break
    ## named top_paths so the module-level `res` prediction table is not overwritten
    top_paths = zip(relation_sorted[:n_rows], entity_sorted[:n_rows], sorted_scores[:n_rows])
    pair_features = np.hstack([entity_embeddings_dict[source],entity_embeddings_dict[target]]).reshape(1,-1)
    path_res[(source, target)] = [make_path([rel.numpy(), ent.numpy(), score.item()]) for rel, ent, score in top_paths] + [fitModel.predict_proba(pair_features)[0]]

In [None]:
## split every path string into (source, target, relation) edges and plot one graph per pair
path_res_dict = dict()
for pair in path_res:
    path_segment_set = set()
    ## the last element of path_res[pair] holds the class probabilities, not a path
    for path in path_res[pair][:-1]:
        path_segment = path[0].split('->')
        path_segment_set.update((path_segment[index],path_segment[index+2],path_segment[index+1]) for index in range(0,len(path_segment)-2,2))
    ## sorted so the edge order (and therefore the drawing) is the same on every run
    path_res_dict[pair] = pd.DataFrame(sorted(path_segment_set))

## `pred_res` is not defined anywhere in this notebook; the names, scores and flags of the
## plotted pairs are taken from the Huntington disease top-50 table
pred_res = huntington_top50
for pair in path_res_dict:
    if path_res_dict[pair].shape[0] == 0:
        continue
    ## single lookup of the row describing this pair; .item() raises unless exactly one row matches
    pair_row = pred_res.loc[(pred_res['drug_id']==pair[0]) & (pred_res['disease_id']==pair[1])]
    drug_name = pair_row['drug_name'].to_numpy().item().capitalize()
    disease_name = pair_row['disease_name'].to_numpy().item().capitalize()
    title = f"{drug_name} - {disease_name}"
    tp_score = round(pair_row['tp_score'].to_numpy().item(),6)
    is_in_train_set = str(pair_row['tp_in_train_set'].to_numpy().item())
    is_in_not_train_set = str(pair_row['tp_not_in_train_set'].to_numpy().item())
    fig, g = path_graph(pair, title, tp_score, is_in_train_set, is_in_not_train_set)