In [63]:
import os
import json
import pandas as pd


1. Collect all drug–disease indication links.
2. Build the drug–gene and gene–disease DataFrames.
3. Join them through the shared gene and merge the result with the indication DataFrame.
4. Use the resulting drug–gene–disease indication DataFrame as the positive training set.

Link (relation) types used:
    "disease_protein",
    "drug_effect",
    "drug_protein",
    "disease_phenotype_positive"
    

# Init the folder

In [64]:
# Initialize the output folder: the to_csv calls below all write under
# ./data_analysis/, which must exist before the first write or a fresh
# Restart & Run All fails.
os.makedirs("./data_analysis", exist_ok=True)

# Load the PrimeKG node table and summarise its node types.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
NODE_FILE = "/playpen/jesse/drug_repurpose/PrimeKG/nodes.csv"
df_node = pd.read_csv(NODE_FILE)

unique_node_types = df_node['node_type'].nunique()
print(f"Number of unique node types: {unique_node_types}")

print("Unique node types:")
print(df_node['node_type'].unique())

print("\nCount of each node type:")
print(df_node['node_type'].value_counts())

df_node

Number of unique node types: 10
Unique node types:
['gene/protein' 'drug' 'effect/phenotype' 'disease' 'biological_process'
 'molecular_function' 'cellular_component' 'exposure' 'pathway' 'anatomy']

Count of each node type:
node_type
biological_process    28642
gene/protein          27671
disease               17080
effect/phenotype      15311
anatomy               14035
molecular_function    11169
drug                   7957
cellular_component     4176
pathway                2516
exposure                818
Name: count, dtype: int64


Unnamed: 0,node_index,node_id,node_type,node_name,node_source
0,0,9796,gene/protein,PHYHIP,NCBI
1,1,7918,gene/protein,GPANK1,NCBI
2,2,8233,gene/protein,ZRSR2,NCBI
3,3,4899,gene/protein,NRF1,NCBI
4,4,5297,gene/protein,PI4KA,NCBI
...,...,...,...,...,...
129370,129370,R-HSA-936837,pathway,Ion transport by P-type ATPases,REACTOME
129371,129371,R-HSA-997272,pathway,Inhibition of voltage gated Ca2+ channels via...,REACTOME
129372,129372,1062,anatomy,anatomical entity,UBERON
129373,129373,468,anatomy,multi-cellular organism,UBERON


In [65]:
# Number of nodes for every (node_type, node_source) combination.
source_counts = (
    df_node
    .groupby(['node_type', 'node_source'])
    .size()
    .reset_index(name='count')
)
print(source_counts)

             node_type    node_source  count
0              anatomy         UBERON  14035
1   biological_process             GO  28642
2   cellular_component             GO   4176
3              disease          MONDO  15813
4              disease  MONDO_grouped   1267
5                 drug       DrugBank   7957
6     effect/phenotype            HPO  15311
7             exposure            CTD    818
8         gene/protein           NCBI  27671
9   molecular_function             GO  11169
10             pathway       REACTOME   2516


In [66]:
# Load the PrimeKG edge table and summarise its relation types.
EDGE_FILE = "/playpen/jesse/drug_repurpose/PrimeKG/edges.csv"
df_edges = pd.read_csv(EDGE_FILE)

edge_relations = df_edges['relation']
unique_relations = edge_relations.nunique()

print(f"Number of unique relation types: {unique_relations}")
print("\nUnique relation types:")
print(edge_relations.unique())
print("\nRelation counts:")
print(edge_relations.value_counts())

Number of unique relation types: 30

Unique relation types:
['protein_protein' 'drug_protein' 'contraindication' 'indication'
 'off-label use' 'drug_drug' 'phenotype_protein' 'phenotype_phenotype'
 'disease_phenotype_negative' 'disease_phenotype_positive'
 'disease_protein' 'disease_disease' 'drug_effect' 'bioprocess_bioprocess'
 'molfunc_molfunc' 'cellcomp_cellcomp' 'molfunc_protein'
 'cellcomp_protein' 'bioprocess_protein' 'exposure_protein'
 'exposure_disease' 'exposure_exposure' 'exposure_bioprocess'
 'exposure_molfunc' 'exposure_cellcomp' 'pathway_pathway'
 'pathway_protein' 'anatomy_anatomy' 'anatomy_protein_present'
 'anatomy_protein_absent']

Relation counts:
relation
anatomy_protein_present       3036406
drug_drug                     2672628
protein_protein                642150
disease_phenotype_positive     300634
bioprocess_protein             289610
cellcomp_protein               166804
disease_protein                160822
molfunc_protein                139060
drug_effect

# Step 1: Get all indication edges

In [158]:
indication_relation = "indication"

# Load known indications: endpoint indices of every "indication" edge.
indication_mask = df_edges['relation'] == indication_relation
indication_edges = df_edges.loc[indication_mask, ['x_index', 'y_index']]
print(indication_edges)

         x_index  y_index
346730     16687    33577
346731     16687    36035
346764     20297    33577
346765     20297    36035
346768     16693    33577
...          ...      ...
5776153    84333    14471
5776154    27527    16634
5776155    38622    16634
5776156    28673    16634
5776158    39497    17237

[18776 rows x 2 columns]


In [178]:
# Sanity check: each raw (x_index, y_index) indication edge should occur once.
pair_sizes = indication_edges.groupby(['x_index', 'y_index']).size()
dupes = pair_sizes.reset_index(name='count')

print(dupes.loc[dupes['count'] > 1])

Empty DataFrame
Columns: [x_index, y_index, count]
Index: []


In [188]:
# Attach the name and type of both endpoints so we can tell which side is the drug.
node_lookup = df_node[['node_index', 'node_name', 'node_type']]

merged = indication_edges
for side in ('x', 'y'):
    merged = (
        merged
        .merge(node_lookup, left_on=f'{side}_index', right_on='node_index', how='left')
        .rename(columns={'node_name': f'{side}_name', 'node_type': f'{side}_type'})
        .drop(columns=['node_index'])
    )


def clarify_roles(row):
    """Reorient one edge as (drug, disease).

    If the x endpoint is a drug it is used as the drug side; otherwise the
    y endpoint is taken as the drug and the x endpoint as the disease.
    """
    drug_side, disease_side = ('x', 'y') if row['x_type'] == 'drug' else ('y', 'x')
    return pd.Series({
        'drug_index': row[f'{drug_side}_index'],
        'drug_name': row[f'{drug_side}_name'],
        'disease_index': row[f'{disease_side}_index'],
        'disease_name': row[f'{disease_side}_name'],
    })


indication_edges_ordered = merged.apply(clarify_roles, axis=1)[['drug_index', 'disease_index']]
indication_edges_ordered

Unnamed: 0,drug_index,disease_index
0,16687,33577
1,16687,36035
2,20297,33577
3,20297,36035
4,16693,33577
...,...,...
18771,14471,84333
18772,16634,27527
18773,16634,38622
18774,16634,28673


In [189]:
# After reorientation, edges stored in both directions collapse onto the
# same (drug, disease) pair, so duplicates show up here.
pair_sizes = indication_edges_ordered.groupby(['drug_index', 'disease_index']).size()
dupes = pair_sizes.reset_index(name='count')

print(dupes.loc[dupes['count'] > 1])

      drug_index  disease_index  count
0          14014          33675      2
1          14014          37888      2
2          14014          39773      2
3          14014          83781      2
4          14014          83791      2
...          ...            ...    ...
9383       20651          28518      2
9384       20651          32680      2
9385       20651          35475      2
9386       20652          32483      2
9387       20653          28197      2

[9388 rows x 3 columns]


In [190]:
# Keep one row per (drug, disease) indication pair.
indication_edges_ordered = indication_edges_ordered.drop_duplicates()
print(indication_edges_ordered)

# Re-run the duplicate check; it should now be empty.
dupes = (
    indication_edges_ordered
    .groupby(['drug_index', 'disease_index'])
    .size()
    .reset_index(name='count')
)
print(dupes.loc[dupes['count'] > 1])

      drug_index  disease_index
0          16687          33577
1          16687          36035
2          20297          33577
3          20297          36035
4          16693          33577
...          ...            ...
9383       14471          84333
9384       16634          27527
9385       16634          38622
9386       16634          28673
9387       17237          39497

[9388 rows x 2 columns]
Empty DataFrame
Columns: [drug_index, disease_index, count]
Index: []


In [221]:
# Save the deduplicated drug–disease indication pairs.
indication_edges_ordered.to_csv(
    path_or_buf='./data_analysis/indication_edges.csv',
    index=False,
)

# Step 2: Build the two bridging DataFrames for indications (drug–gene and gene–disease)

In [191]:
# Relations that connect a drug and a disease through a shared gene/protein.
indication_cot = ["drug_protein", "disease_protein"]

is_bridge_relation = df_edges['relation'].isin(indication_cot)
indication_cot_edges = df_edges.loc[is_bridge_relation]
print(indication_cot_edges)

                relation display_relation  x_index  y_index
321075      drug_protein          carrier    14012     7183
321076      drug_protein          carrier    14012     8256
321077      drug_protein          carrier    14013     4107
321078      drug_protein          carrier    14014     1424
321079      drug_protein          carrier    14015     1424
...                  ...              ...      ...      ...
6030122  disease_protein  associated with    28149     4152
6030123  disease_protein  associated with    28181       59
6030124  disease_protein  associated with    31190     5826
6030125  disease_protein  associated with    33606    10422
6030126  disease_protein  associated with    30032    33802

[212128 rows x 4 columns]


In [192]:
# Split the bridging edges by relation, keeping only the endpoint indices.
edge_columns = ['x_index', 'y_index']
relation_of = indication_cot_edges['relation']

disease_gene = indication_cot_edges.loc[relation_of == 'disease_protein', edge_columns]
drug_gene = indication_cot_edges.loc[relation_of == 'drug_protein', edge_columns]

print(disease_gene.head(5))
print(drug_gene.head(5))

         x_index  y_index
3235582     7097    28313
3235583     2174    28313
3235584     8038    28313
3235585     5925    28313
3235586      238    28313
        x_index  y_index
321075    14012     7183
321076    14012     8256
321077    14013     4107
321078    14014     1424
321079    14015     1424


In [202]:
# Attach the name and type of both endpoints so we can tell which side is the gene.
node_lookup = df_node[['node_index', 'node_name', 'node_type']]

merged = disease_gene
for side in ('x', 'y'):
    merged = (
        merged
        .merge(node_lookup, left_on=f'{side}_index', right_on='node_index', how='left')
        .rename(columns={'node_name': f'{side}_name', 'node_type': f'{side}_type'})
        .drop(columns=['node_index'])
    )


def clarify_roles(row):
    """Reorient one edge as (gene, disease).

    If the x endpoint is a gene/protein it is used as the gene side;
    otherwise the y endpoint is the gene and the x endpoint the disease.
    """
    gene_side, disease_side = ('x', 'y') if row['x_type'] == 'gene/protein' else ('y', 'x')
    return pd.Series({
        'gene_index': row[f'{gene_side}_index'],
        'gene_name': row[f'{gene_side}_name'],
        'disease_index': row[f'{disease_side}_index'],
        'disease_name': row[f'{disease_side}_name'],
    })


disease_gene_orderd = (
    merged.apply(clarify_roles, axis=1)[['gene_index', 'disease_index']]
    .drop_duplicates()
)

In [203]:

# Confirm each (gene, disease) pair occurs only once.
pair_sizes = disease_gene_orderd.groupby(['gene_index', 'disease_index']).size()
dupes = pair_sizes.reset_index(name='count')

print(dupes.loc[dupes['count'] > 1])

Empty DataFrame
Columns: [gene_index, disease_index, count]
Index: []


In [222]:
# Save the oriented gene–disease pairs.
disease_gene_orderd.to_csv(
    path_or_buf='./data_analysis/disease_gene.csv',
    index=False,
)

In [206]:
# Attach the name and type of both endpoints so we can tell which side is the gene.
node_lookup = df_node[['node_index', 'node_name', 'node_type']]

merged = drug_gene
for side in ('x', 'y'):
    merged = (
        merged
        .merge(node_lookup, left_on=f'{side}_index', right_on='node_index', how='left')
        .rename(columns={'node_name': f'{side}_name', 'node_type': f'{side}_type'})
        .drop(columns=['node_index'])
    )


def clarify_roles(row):
    """Reorient one edge as (gene, drug).

    If the x endpoint is a gene/protein it is used as the gene side;
    otherwise the y endpoint is the gene and the x endpoint the drug.
    """
    gene_side, drug_side = ('x', 'y') if row['x_type'] == 'gene/protein' else ('y', 'x')
    return pd.Series({
        'gene_index': row[f'{gene_side}_index'],
        'gene_name': row[f'{gene_side}_name'],
        'drug_index': row[f'{drug_side}_index'],
        'drug_name': row[f'{drug_side}_name'],
    })


drug_gene_ordered = (
    merged.apply(clarify_roles, axis=1)[['gene_index', 'drug_index']]
    .drop_duplicates()
)
drug_gene_ordered

Unnamed: 0,gene_index,drug_index
0,7183,14012
1,8256,14012
2,4107,14013
3,1424,14014
4,1424,14015
...,...,...
25648,6069,14491
25649,12188,14490
25650,12188,14491
25651,10452,14712


In [207]:
# Confirm each (gene, drug) pair occurs only once.
pair_sizes = drug_gene_ordered.groupby(['gene_index', 'drug_index']).size()
dupes = pair_sizes.reset_index(name='count')

print(dupes.loc[dupes['count'] > 1])

Empty DataFrame
Columns: [gene_index, drug_index, count]
Index: []


In [223]:
# Save the oriented gene–drug pairs.
drug_gene_ordered.to_csv(
    path_or_buf='./data_analysis/drug_gene.csv',
    index=False,
)

In [208]:
# ---- Inference Path 1: Disease → Protein → Drug ----
# Join disease–gene and drug–gene links on the shared gene to get candidate
# (gene, disease, drug) triples.
all_inferred_disease_drug_protein = disease_gene_orderd.merge(drug_gene_ordered, on='gene_index')
print(all_inferred_disease_drug_protein)

all_inferred_disease_drug_protein = all_inferred_disease_drug_protein.drop_duplicates()
print(all_inferred_disease_drug_protein)

        gene_index  disease_index  drug_index
0             7097          28313       14679
1             7097          28313       14012
2             7097          28313       14680
3             2174          28313       15852
4             2174          28313       14280
...            ...            ...         ...
556892        4152          28149       14499
556893        4152          28149       14050
556894        4152          28149       15752
556895        5826          31190       14490
556896        5826          31190       14491

[556897 rows x 3 columns]
        gene_index  disease_index  drug_index
0             7097          28313       14679
1             7097          28313       14012
2             7097          28313       14680
3             2174          28313       15852
4             2174          28313       14280
...            ...            ...         ...
556892        4152          28149       14499
556893        4152          28149       14050
556894 

In [224]:
# Save every gene-bridged (gene, disease, drug) triple.
all_inferred_disease_drug_protein.to_csv(
    path_or_buf='./data_analysis/drug_gene_disease.csv',
    index=False,
)

In [209]:
# Confirm each (gene, disease, drug) triple occurs only once.
triple_keys = ['gene_index', 'disease_index', 'drug_index']
triple_sizes = all_inferred_disease_drug_protein.groupby(triple_keys).size()
dupes = triple_sizes.reset_index(name='count')

print(dupes.loc[dupes['count'] > 1])

Empty DataFrame
Columns: [gene_index, disease_index, drug_index, count]
Index: []


In [257]:
# Keep only inferred triples whose (drug, disease) pair is a known indication.
valid_ddprotein_index = all_inferred_disease_drug_protein.merge(
    indication_edges_ordered,
    on=['drug_index', 'disease_index'],
)
print(valid_ddprotein_index)

      gene_index  disease_index  drug_index
0           1659          28313       14140
1           1659          28313       14161
2           1659          28313       16011
3           1659          28313       15098
4           1659          28313       14178
...          ...            ...         ...
4731        8192          33605       15417
4732         361          33605       14641
4733        2329          33605       14807
4734       11665          33625       20290
4735       22105          33677       14578

[4736 rows x 3 columns]


In [258]:
# Save the indication-validated gene-bridged triples (indices only).
valid_ddprotein_index.to_csv(
    path_or_buf='./data_analysis/valid_ddprotein_index.csv',
    index=False,
)

In [259]:
# Number of bridging genes supporting each (disease, drug) indication pair.
# A groupby size is always >= 1, so this lists every pair, not only duplicates.
genes_per_pair = valid_ddprotein_index.groupby(['disease_index', 'drug_index']).size()
dupes = genes_per_pair.reset_index(name='count')

print(dupes)

      disease_index  drug_index  count
0             27219       14208      1
1             27219       14478      1
2             27249       14208      1
3             27275       14449      1
4             27285       14560      1
...             ...         ...    ...
1926          84305       14207      1
1927          84310       14223      3
1928          84310       14323     11
1929          84326       15018      1
1930          84327       15018      1

[1931 rows x 3 columns]


# Step 3: Build the positive training set from valid drug–disease–gene triples

In [260]:
# Re-display the validated (gene, disease, drug) triples that this step builds on.
triples_to_show = valid_ddprotein_index
print(triples_to_show)

      gene_index  disease_index  drug_index
0           1659          28313       14140
1           1659          28313       14161
2           1659          28313       16011
3           1659          28313       15098
4           1659          28313       14178
...          ...            ...         ...
4731        8192          33605       15417
4732         361          33605       14641
4733        2329          33605       14807
4734       11665          33625       20290
4735       22105          33677       14578

[4736 rows x 3 columns]


In [261]:
# Attach node name and type for the drug, disease and gene of each triple.
node_lookup = df_node[['node_index', 'node_name', 'node_type']]

valid_ddprotein = valid_ddprotein_index
for role in ('drug', 'disease', 'gene'):
    valid_ddprotein = (
        valid_ddprotein
        .merge(node_lookup, left_on=f'{role}_index', right_on='node_index', how='left')
        .rename(columns={'node_name': f'{role}_name', 'node_type': f'{role}_type'})
        .drop(columns=['node_index'])
    )

valid_ddprotein = valid_ddprotein.drop_duplicates()
print(valid_ddprotein.head(10))

   gene_index  disease_index  drug_index        drug_name drug_type  \
0        1659          28313       14140      Ziprasidone      drug   
1        1659          28313       14161       Olanzapine      drug   
2        1659          28313       16011         Loxapine      drug   
3        1659          28313       15098        Promazine      drug   
4        1659          28313       14178   Chlorpromazine      drug   
5        1659          28313       14952      Haloperidol      drug   
6        1659          28313       14223      Risperidone      drug   
7        1659          28313       14913  Trifluoperazine      drug   
8        1659          28313       14829      Flupentixol      drug   
9        1659          28313       14323     Aripiprazole      drug   

    disease_name disease_type gene_name     gene_type  
0  schizophrenia      disease    ADRA1A  gene/protein  
1  schizophrenia      disease    ADRA1A  gene/protein  
2  schizophrenia      disease    ADRA1A  gene/prot

In [262]:
# Save the validated triples together with their names and types.
valid_ddprotein.to_csv(
    path_or_buf="./data_analysis/valid_ddprotein.csv",
    index=False,
)

In [None]:
# Draw 500 validated triples as gene-path positives (seeded for reproducibility).
N_TRAIN_PROTEIN = 500
train_positive_protein = (
    valid_ddprotein
    .sample(n=N_TRAIN_PROTEIN, random_state=42)
    .reset_index(drop=True)
)
train_positive_protein_index = train_positive_protein[['drug_index', 'disease_index', 'gene_index']]


In [None]:
# Save the sampled gene-path training positives.
train_positive_protein.to_csv(
    path_or_buf="./data_analysis/train_positive_protein.csv",
    index=False,
)

# Step 4: Build the drug–disease–phenotype positive training set
For each known indication, pair the drug and disease with every phenotype linked to that disease by a `disease_phenotype_positive` edge.

In [308]:
# Relation linking a disease to a phenotype it presents with.
phenotype_cot = ["disease_phenotype_positive"]

phenotype_cot_edges = df_edges.loc[df_edges['relation'].isin(phenotype_cot)]
is_positive_pheno = phenotype_cot_edges['relation'] == 'disease_phenotype_positive'
disease_pheno = phenotype_cot_edges.loc[is_positive_pheno, ['x_index', 'y_index']]
disease_pheno

Unnamed: 0,x_index,y_index
3085246,27472,24442
3085247,27158,22784
3085248,27158,84344
3085249,27158,22488
3085250,27158,22272
...,...,...
5949711,33713,25218
5949712,32207,22574
5949713,32207,26287
5949714,33561,22204


In [309]:
# Attach the name and type of both endpoints so we can tell which side is the phenotype.
node_lookup = df_node[['node_index', 'node_name', 'node_type']]

merged = disease_pheno
for side in ('x', 'y'):
    merged = (
        merged
        .merge(node_lookup, left_on=f'{side}_index', right_on='node_index', how='left')
        .rename(columns={'node_name': f'{side}_name', 'node_type': f'{side}_type'})
        .drop(columns=['node_index'])
    )


def clarify_roles(row):
    """Reorient one edge as (phenotype, disease).

    If the x endpoint is an effect/phenotype it is used as the phenotype
    side; otherwise the y endpoint is the phenotype and x the disease.
    """
    pheno_side, disease_side = ('x', 'y') if row['x_type'] == 'effect/phenotype' else ('y', 'x')
    return pd.Series({
        'phenotype_index': row[f'{pheno_side}_index'],
        'phenotype_name': row[f'{pheno_side}_name'],
        'disease_index': row[f'{disease_side}_index'],
        'disease_name': row[f'{disease_side}_name'],
    })


disease_pheno_ordered = (
    merged.apply(clarify_roles, axis=1)[['phenotype_index', 'disease_index']]
    .drop_duplicates()
)


In [310]:
# Confirm each (phenotype, disease) pair occurs only once, then show the pairs.
pair_sizes = disease_pheno_ordered.groupby(['phenotype_index', 'disease_index']).size()
dupes = pair_sizes.reset_index(name='count')

print(dupes.loc[dupes['count'] > 1])
print(disease_pheno_ordered)

Empty DataFrame
Columns: [phenotype_index, disease_index, count]
Index: []
        phenotype_index  disease_index
0                 24442          27472
1                 22784          27158
2                 84344          27158
3                 22488          27158
4                 22272          27158
...                 ...            ...
150331            25218          33713
150332            22574          32207
150333            26287          32207
150334            22204          33561
150335            24588          38100

[150317 rows x 2 columns]


In [311]:
# Re-display the deduplicated indication pairs used to build phenotype triples.
pairs_to_show = indication_edges_ordered
print(pairs_to_show)

      drug_index  disease_index
0          16687          33577
1          16687          36035
2          20297          33577
3          20297          36035
4          16693          33577
...          ...            ...
9383       14471          84333
9384       16634          27527
9385       16634          38622
9386       16634          28673
9387       17237          39497

[9388 rows x 2 columns]


In [312]:
# Pair every known indication with each positive phenotype of its disease.
valid_ddphenotype_index = (
    indication_edges_ordered
    .merge(disease_pheno_ordered, on='disease_index', how='inner')
    .drop_duplicates()
)
valid_ddphenotype_index.to_csv("./data_analysis/valid_ddphenotype_index.csv", index=False)

In [313]:
# Attach node name and type for the drug, disease and phenotype of each triple.
node_lookup = df_node[['node_index', 'node_name', 'node_type']]

valid_ddphenotype = valid_ddphenotype_index
for role in ('drug', 'disease', 'phenotype'):
    valid_ddphenotype = (
        valid_ddphenotype
        .merge(node_lookup, left_on=f'{role}_index', right_on='node_index', how='left')
        .rename(columns={'node_name': f'{role}_name', 'node_type': f'{role}_type'})
        .drop(columns=['node_index'])
    )

valid_ddphenotype = valid_ddphenotype.drop_duplicates()
print(valid_ddphenotype.head(10))

   drug_index  disease_index  phenotype_index   drug_name drug_type  \
0       16687          33577            94180  Fosinopril      drug   
1       16687          33577            23513  Fosinopril      drug   
2       16687          33577            94390  Fosinopril      drug   
3       16687          33577            22741  Fosinopril      drug   
4       20297          33577            94180   Imidapril      drug   
5       20297          33577            23513   Imidapril      drug   
6       20297          33577            94390   Imidapril      drug   
7       20297          33577            22741   Imidapril      drug   
8       16693          33577            94180  Cilazapril      drug   
9       16693          33577            23513  Cilazapril      drug   

            disease_name disease_type  \
0  hypertensive disorder      disease   
1  hypertensive disorder      disease   
2  hypertensive disorder      disease   
3  hypertensive disorder      disease   
4  hypertensi

In [314]:
# Save the phenotype triples together with their names and types.
valid_ddphenotype.to_csv(
    path_or_buf="./data_analysis/valid_ddphenotype.csv",
    index=False,
)

In [315]:
# Draw 500 phenotype triples for training (seeded for reproducibility).
N_TRAIN_PHENOTYPE = 500
train_positive_phenotype = (
    valid_ddphenotype
    .sample(n=N_TRAIN_PHENOTYPE, random_state=42)
    .reset_index(drop=True)
)
train_positive_phenotype_index = train_positive_phenotype[['drug_index', 'disease_index', 'phenotype_index']]

In [316]:
# Save the sampled phenotype-path training positives.
train_positive_phenotype.to_csv(
    path_or_buf="./data_analysis/train_positive_phenotype.csv",
    index=False,
)

# Step 5: Take the indication edges not used for training and sample 500 of them as the test set

In [296]:
# Union of the (drug, disease) pairs used by either training set.
pair_cols = ['drug_index', 'disease_index']
phenotype_pairs = train_positive_phenotype_index[pair_cols]
protein_pairs = train_positive_protein_index[pair_cols]

merged_train_index = (
    pd.concat([phenotype_pairs, protein_pairs])
    .drop_duplicates()
    .reset_index(drop=True)
)
# Already deduplicated above, so display it directly.
merged_train_index

Unnamed: 0,drug_index,disease_index
0,15772,28547
1,20597,33643
2,14042,29078
3,15860,32325
4,14390,33609
...,...,...
769,15146,38957
770,14956,38242
771,14310,37703
772,17605,29375


In [297]:
# Anti-join: keep only indication pairs that appear in neither training set.
pair_keys = ['drug_index', 'disease_index']
remaining_indication_index = indication_edges_ordered.merge(
    merged_train_index, on=pair_keys, how='left', indicator=True
)
not_in_train = remaining_indication_index['_merge'] == 'left_only'
remaining_indication_index = remaining_indication_index.loc[not_in_train].drop(columns=['_merge'])
remaining_indication_index

Unnamed: 0,drug_index,disease_index
0,16687,33577
1,16687,36035
2,20297,33577
3,20297,36035
4,16693,33577
...,...,...
9383,14471,84333
9384,16634,27527
9385,16634,38622
9386,16634,28673


In [335]:
# Edge subsets used to look up each disease's phenotypes and proteins.
edge_relation = df_edges['relation']
phenotype_edges = df_edges[edge_relation == 'disease_phenotype_positive']
protein_edges = df_edges[edge_relation == 'disease_protein']

In [337]:
def get_related_items(disease_id, edges, target='phenotype'):
    """Return the distinct nodes linked to ``disease_id`` by any edge in ``edges``.

    Edges are matched in both directions: ``disease_id`` may appear as either
    ``x_index`` or ``y_index``.

    Parameters
    ----------
    disease_id : node index of the disease to look up.
    edges : DataFrame with ``x_index`` and ``y_index`` columns.
    target : not used by the lookup; kept so existing call sites keep working.

    Returns
    -------
    list
        Unique neighbour indices (order not guaranteed; empty if none).
    """
    matches_x = edges.loc[edges['x_index'] == disease_id, 'y_index'].tolist()
    matches_y = edges.loc[edges['y_index'] == disease_id, 'x_index'].tolist()
    return list(set(matches_x + matches_y))


# Take an explicit copy. A plain assignment only binds a second name to the
# same frame, so the list-valued columns added below would also silently
# appear on remaining_indication_index.
remaining_indication_pg = remaining_indication_index.copy()
disease_ids = remaining_indication_pg['disease_index'].unique()

disease_to_phenotypes = {
    d: get_related_items(d, phenotype_edges, target='phenotype') for d in disease_ids
}
disease_to_proteins = {
    d: get_related_items(d, protein_edges, target='protein') for d in disease_ids
}

remaining_indication_pg['related_phenotypes'] = remaining_indication_pg['disease_index'].map(disease_to_phenotypes)
remaining_indication_pg['related_proteins'] = remaining_indication_pg['disease_index'].map(disease_to_proteins)

# Keep only pairs whose disease has at least one positive phenotype and one protein.
filtered_remaining_indication = remaining_indication_pg[
    remaining_indication_pg['related_phenotypes'].apply(len).gt(0)
    & remaining_indication_pg['related_proteins'].apply(len).gt(0)
].reset_index(drop=True)
filtered_remaining_indication

Unnamed: 0,drug_index,disease_index,related_phenotypes,related_proteins
0,16687,33577,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."
1,20297,33577,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."
2,16693,33577,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."
3,14944,33577,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."
4,15762,33577,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."
...,...,...,...,...
2549,15858,31384,"[91907, 91908, 24582, 84361, 22282, 23179, 926...","[10376, 4113, 6931, 2967, 13210, 3106, 10021, ..."
2550,15858,30905,"[85504, 22788, 22791, 85005, 22675, 22681, 244...","[10376, 4113, 6931, 2967, 13210, 3106, 10021, ..."
2551,17286,27446,"[22658, 22788, 24199, 23816, 86664, 87179, 848...","[9001, 7684, 12150]"
2552,16634,27527,"[90898, 85019, 90273, 88994, 84774, 84519, 230...","[2515, 11356]"


In [339]:
# Draw 500 held-out indication pairs as the positive test set (seeded).
N_TEST_POSITIVE = 500
test_positive = (
    filtered_remaining_indication
    .sample(n=N_TEST_POSITIVE, random_state=42)
    .reset_index(drop=True)
)
test_positive.to_csv("./data_analysis/test_positive.csv", index=False)

In [341]:
# Attach drug and disease names/types to the sampled test pairs. Only the
# three lookup columns are merged, as in the other cells. Merging the whole
# node table twice pulled in the unused node_id/node_source columns twice,
# with _x/_y suffixes, only to discard them below.
node_lookup = df_node[['node_index', 'node_name', 'node_type']]

test_positive_with_drug = test_positive.merge(
    node_lookup,
    left_on='drug_index',
    right_on='node_index',
    how='left'
).rename(columns={
    'node_name': 'drug_name',
    'node_type': 'drug_type'
}).drop(columns=['node_index'])

test_positive_full = test_positive_with_drug.merge(
    node_lookup,
    left_on='disease_index',
    right_on='node_index',
    how='left'
).rename(columns={
    'node_name': 'disease_name',
    'node_type': 'disease_type'
}).drop(columns=['node_index'])

columns_to_keep = [
    'drug_index',
    'disease_index',
    'drug_type',
    'drug_name',
    'disease_type',
    'disease_name',
    'related_phenotypes',
    'related_proteins'
]

cleaned_test_positive = test_positive_full[columns_to_keep]
cleaned_test_positive

Unnamed: 0,drug_index,disease_index,drug_type,drug_name,disease_type,disease_name,related_phenotypes,related_proteins
0,14520,33651,drug,Allantoin,disease,seborrheic dermatitis,"[94334, 22550]","[1641, 150]"
1,15098,33683,drug,Promazine,disease,urticaria (disease),"[94533, 22540, 25518, 94454, 94520]","[3552, 417, 8834, 5667, 4968, 2889, 7083, 1004..."
2,15493,29637,drug,Bromocriptine,disease,prolactin producing pituitary gland tumor,"[25479, 84488, 90495, 25620, 25599, 84764, 225...","[3104, 2401, 2789, 3046, 1884, 22023, 5320, 10..."
3,14275,30654,drug,Doxorubicin,disease,acute lymphoblastic/lymphocytic leukemia,"[23970, 86308, 22757, 84890]","[2, 9221, 1549, 4111, 13840, 4114, 1558, 1559,..."
4,20375,33714,drug,Aluminum hydroxide,disease,esophagitis (disease),"[93737, 22987, 22988, 94534]","[2978, 9029, 4425, 12842, 1740, 12766, 4959]"
...,...,...,...,...,...,...,...,...
495,15778,28208,drug,Tolazamide,disease,type 2 diabetes mellitus,"[22759, 24491, 22483, 91382, 85559]","[34295, 13830, 7688, 10248, 34312, 34317, 3431..."
496,15179,33703,drug,Nicotinamide,disease,acne (disease),"[22539, 94372, 94243, 94244]","[12305, 1587, 6229, 1485]"
497,14269,32531,drug,Methylprednisolone,disease,mantle cell lymphoma,"[22952, 25547, 23214, 24143, 85040, 24337, 256...","[769, 134, 6926, 4369, 7186, 7441, 8212, 1558,..."
498,15451,33577,drug,Bisoprolol,disease,hypertensive disorder,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."


In [342]:
# Save the named test positives with their related phenotypes and proteins.
cleaned_test_positive.to_csv(
    path_or_buf="./data_analysis/test_positive_full.csv",
    index=False,
)

# Step 6: Get all negative samples from the contraindication edges

In [287]:
contraindication_relation = "contraindication"

# Endpoint indices of every contraindication edge.
contra_mask = df_edges['relation'] == contraindication_relation
contraindication_edges = df_edges.loc[contra_mask, ['x_index', 'y_index']]
print(contraindication_edges)

         x_index  y_index
346728     15193    33577
346729     15193    36035
346732     14483    33577
346733     14483    36035
346734     16476    33577
...          ...      ...
5776145    35751    14251
5776146    35846    20456
5776147    35751    20456
5776148    27446    17286
5776157    84334    18277

[61350 rows x 2 columns]


In [288]:
# Sanity check: each raw (x_index, y_index) contraindication edge should occur once.
pair_sizes = contraindication_edges.groupby(['x_index', 'y_index']).size()
dupes = pair_sizes.reset_index(name='count')

print(dupes.loc[dupes['count'] > 1])

Empty DataFrame
Columns: [x_index, y_index, count]
Index: []


In [289]:
merged = contraindication_edges.merge(
    df_node[['node_index', 'node_name', 'node_type']],
    left_on='x_index',
    right_on='node_index',
    how='left'
).rename(columns={
    'node_name': 'x_name',
    'node_type': 'x_type'
}).drop(columns=['node_index'])

merged = merged.merge(
    df_node[['node_index', 'node_name', 'node_type']],
    left_on='y_index',
    right_on='node_index',
    how='left'
).rename(columns={
    'node_name': 'y_name',
    'node_type': 'y_type'
}).drop(columns=['node_index'])

def clarify_roles(row):
    """Re-express one edge row as a (drug, disease) record.

    If ``x_type`` is ``'drug'`` the x endpoint is treated as the drug and the y
    endpoint as the disease; otherwise the roles are swapped.

    Parameters
    ----------
    row : pd.Series
        Must carry ``x_index``, ``x_name``, ``x_type``, ``y_index`` and ``y_name``.

    Returns
    -------
    pd.Series
        Indexed by ``drug_index``, ``drug_name``, ``disease_index``, ``disease_name``.
    """
    drug_side, disease_side = ('x', 'y') if row['x_type'] == 'drug' else ('y', 'x')
    return pd.Series({
        'drug_index': row[f'{drug_side}_index'],
        'drug_name': row[f'{drug_side}_name'],
        'disease_index': row[f'{disease_side}_index'],
        'disease_name': row[f'{disease_side}_name'],
    })

# Normalise every edge to (drug_index, disease_index) and drop repeated pairs
# (e.g. the same pair stored in both directions).
contraindication_edges_ordered = (
    merged.apply(clarify_roles, axis=1)[['drug_index', 'disease_index']]
    .drop_duplicates()
)

In [290]:
print(contraindication_edges_ordered)

# Sanity check on the contraindication frame printed above (not the indication frame):
# after drop_duplicates every (drug, disease) pair should appear exactly once.
dupes = contraindication_edges_ordered.groupby(
    ['drug_index', 'disease_index']
).size().reset_index(name='count')

# Filter for duplicates
print(dupes[dupes['count'] > 1])

       drug_index  disease_index
0           15193          33577
1           15193          36035
2           14483          33577
3           14483          36035
4           16476          33577
...           ...            ...
30670       14251          35751
30671       20456          35846
30672       20456          35751
30673       17286          27446
30674       18277          84334

[30675 rows x 2 columns]
Empty DataFrame
Columns: [drug_index, disease_index, count]
Index: []


In [291]:
# Persist the normalised (drug_index, disease_index) contraindication pairs.
contraindication_edges_ordered.to_csv(path_or_buf='./data_analysis/contraindication_edges.csv', index=False)

# Step7: Get all drug-disease-phenotype triples for the negative training set

In [302]:
# Expand each contraindicated (drug, disease) pair with the phenotypes that
# `disease_pheno_ordered` associates with that disease, then save the triples.
contra_ddphenotype_index = (
    contraindication_edges_ordered
    .merge(disease_pheno_ordered, on='disease_index', how='inner')
    .drop_duplicates()
)
contra_ddphenotype_index.to_csv("./data_analysis/contra_ddphenotype_index.csv", index=False)
contra_ddphenotype_index

Unnamed: 0,drug_index,disease_index,phenotype_index
0,15193,33577,94180
1,15193,33577,23513
2,15193,33577,94390
3,15193,33577,22741
4,14483,33577,94180
...,...,...,...
179779,17286,27446,22354
179780,17286,27446,22699
179781,17286,27446,23101
179782,17286,27446,22658


In [304]:
# Reproducibly draw 500 (drug, disease, phenotype) rows for the negative training set.
train_negative_phenotype_index = (
    contra_ddphenotype_index
    .sample(n=500, random_state=42)
    .reset_index(drop=True)
)
train_negative_phenotype_index

Unnamed: 0,drug_index,disease_index,phenotype_index
0,17064,27626,84771
1,14249,32719,25974
2,14976,31650,84675
3,15116,33091,22437
4,14561,31502,24199
...,...,...,...
495,16071,28158,23157
496,20303,31122,23368
497,14026,32899,89372
498,14385,32909,24227


In [307]:
# Attach a <role>_name / <role>_type pair for the drug, disease and phenotype of each sampled row.
node_lookup = df_node[['node_index', 'node_name', 'node_type']]

train_negative_phenotype = train_negative_phenotype_index
for role in ('drug', 'disease', 'phenotype'):
    train_negative_phenotype = (
        train_negative_phenotype
        .merge(node_lookup, left_on=f'{role}_index', right_on='node_index', how='left')
        .rename(columns={'node_name': f'{role}_name', 'node_type': f'{role}_type'})
        .drop(columns=['node_index'])
    )

train_negative_phenotype = train_negative_phenotype.drop_duplicates()
print(train_negative_phenotype.head(10))

   drug_index  disease_index  phenotype_index             drug_name drug_type  \
0       17064          27626            84771           Meprobamate      drug   
1       14249          32719            25974       Pseudoephedrine      drug   
2       14976          31650            84675             Triazolam      drug   
3       15116          33091            22437            Paroxetine      drug   
4       14561          31502            24199       Cholecalciferol      drug   
5       20157          27703            22859           Guaifenesin      drug   
6       15821          28052            23011  Isosorbide dinitrate      drug   
7       14861          32714            25663           Neostigmine      drug   
8       15908          32967            22484         Glutamic acid      drug   
9       14833          31033            94148            Riboflavin      drug   

                                 disease_name disease_type  \
0                            thrombocytopenia 

In [317]:
# Persist the named negative training (drug, disease, phenotype) sample.
train_negative_phenotype.to_csv(path_or_buf="./data_analysis/train_negative_phenotype.csv", index=False)

# Step8: Get all drug-disease-gene triples for the negative training set

In [318]:
# Expand each contraindicated (drug, disease) pair with the genes that
# `disease_gene_orderd` associates with that disease, then save the triples.
contra_ddgene_index = (
    contraindication_edges_ordered
    .merge(disease_gene_orderd, on='disease_index', how='inner')
    .drop_duplicates()
)
contra_ddgene_index.to_csv("./data_analysis/contra_ddgene_index.csv", index=False)
contra_ddgene_index

Unnamed: 0,drug_index,disease_index,gene_index
0,15193,33577,3912
1,15193,33577,8805
2,15193,33577,2667
3,15193,33577,4794
4,15193,33577,10901
...,...,...,...
1308217,20456,35751,35031
1308218,17286,27446,9001
1308219,17286,27446,7684
1308220,17286,27446,12150


In [319]:
# Reproducibly draw 500 (drug, disease, gene) rows for the negative training set.
train_negative_gene_index = (
    contra_ddgene_index
    .sample(n=500, random_state=42)
    .reset_index(drop=True)
)
train_negative_gene_index

Unnamed: 0,drug_index,disease_index,gene_index
0,14198,36187,4497
1,15495,37703,1876
2,15078,36035,10901
3,14806,83760,10047
4,14286,27933,2551
...,...,...,...
495,18277,83760,3466
496,15783,35641,1863
497,15776,37703,2591
498,14476,35966,1497


In [321]:
# Attach a <role>_name / <role>_type pair for the drug, disease and gene of each sampled row.
node_lookup = df_node[['node_index', 'node_name', 'node_type']]

train_negative_gene = train_negative_gene_index
for role in ('drug', 'disease', 'gene'):
    train_negative_gene = (
        train_negative_gene
        .merge(node_lookup, left_on=f'{role}_index', right_on='node_index', how='left')
        .rename(columns={'node_name': f'{role}_name', 'node_type': f'{role}_type'})
        .drop(columns=['node_index'])
    )

train_negative_gene = train_negative_gene.drop_duplicates()
print(train_negative_gene.head(10))

   drug_index  disease_index  gene_index               drug_name drug_type  \
0       14198          36187        4497               Clonidine      drug   
1       15495          37703        1876               Estazolam      drug   
2       15078          36035       10901                Tramadol      drug   
3       14806          83760       10047              Entacapone      drug   
4       14286          27933        2551              Rifampicin      drug   
5       14476          36527         554  Testosterone cypionate      drug   
6       14235          32617        2423               Estradiol      drug   
7       14293          28407        8336               Glipizide      drug   
8       14629          27933        8757              Tryptophan      drug   
9       14990          33605        8805        Chlorpheniramine      drug   

                                disease_name disease_type gene_name  \
0                    coronary artery disease      disease      NOS3   

In [None]:
# Persist the named negative training (drug, disease, gene) sample.
train_negative_gene.to_csv(path_or_buf="./data_analysis/train_negative_gene.csv", index=False)

# Step9: Get all remaining contraindication edges for the negative test set

In [329]:
# Collect the distinct (drug, disease) pairs already used by either negative training sample.
pair_columns = ['drug_index', 'disease_index']
gene_neg_pairs = train_negative_gene_index[pair_columns]
phenotype_neg_pairs = train_negative_phenotype_index[pair_columns]

merged_neg_train_index = (
    pd.concat([phenotype_neg_pairs, gene_neg_pairs])
    .drop_duplicates()
    .reset_index(drop=True)
)
merged_neg_train_index

Unnamed: 0,drug_index,disease_index
0,17064,27626
1,14249,32719
2,14976,31650
3,15116,33091
4,14561,31502
...,...,...
937,18277,83760
938,15783,35641
939,15776,37703
940,14476,35966


In [330]:
# Anti-join: keep only contraindication pairs that were not used for negative training.
remaining_contraindication_index = (
    contraindication_edges_ordered
    .merge(
        merged_neg_train_index,
        on=['drug_index', 'disease_index'],
        how='left',
        indicator=True,
    )
    .loc[lambda frame: frame['_merge'] == 'left_only']
    .drop(columns=['_merge'])
)
remaining_contraindication_index

Unnamed: 0,drug_index,disease_index
0,15193,33577
1,15193,36035
2,14483,33577
3,14483,36035
4,16476,33577
...,...,...
30670,14251,35751
30671,20456,35846
30672,20456,35751
30673,17286,27446


In [343]:
def get_related_items(disease_id, edges, target='phenotype'):
    """Return the distinct neighbours of ``disease_id`` in an edge table.

    An edge counts in either orientation: when the disease is in ``x_index`` the
    neighbour is ``y_index``, and vice versa.

    Parameters
    ----------
    disease_id : node index to look up.
    edges : pd.DataFrame with ``x_index`` and ``y_index`` columns.
    target : str
        Descriptive label only; it is not used in the computation.

    Returns
    -------
    list
        De-duplicated neighbour indices (an empty list if there are none).
    """
    forward = edges.loc[edges['x_index'] == disease_id, 'y_index'].tolist()
    backward = edges.loc[edges['y_index'] == disease_id, 'x_index'].tolist()
    return list(set(forward + backward))

# Take an explicit copy: a plain assignment would only alias the frame, so the column
# assignments below would also mutate `remaining_contraindication_index` (a boolean-filtered
# frame, which can also raise SettingWithCopyWarning).
remaining_contraindication_pg = remaining_contraindication_index.copy()
disease_ids = remaining_contraindication_pg['disease_index'].unique()

# Look up each distinct disease once, then broadcast the neighbour lists to every pair.
disease_to_phenotypes = {
    d: get_related_items(d, phenotype_edges, target='phenotype') for d in disease_ids
}
disease_to_proteins = {
    d: get_related_items(d, protein_edges, target='protein') for d in disease_ids
}

remaining_contraindication_pg['related_phenotypes'] = remaining_contraindication_pg['disease_index'].map(disease_to_phenotypes)
remaining_contraindication_pg['related_proteins'] = remaining_contraindication_pg['disease_index'].map(disease_to_proteins)

# Keep only pairs whose disease has at least one phenotype AND at least one protein neighbour.
has_phenotypes = remaining_contraindication_pg['related_phenotypes'].map(len) > 0
has_proteins = remaining_contraindication_pg['related_proteins'].map(len) > 0
filtered_remaining_contraindication = remaining_contraindication_pg[
    has_phenotypes & has_proteins
].reset_index(drop=True)

filtered_remaining_contraindication

Unnamed: 0,drug_index,disease_index,related_phenotypes,related_proteins
0,15193,33577,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."
1,14483,33577,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."
2,16476,33577,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."
3,20148,33577,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."
4,15087,33577,"[23513, 94180, 22741, 94390]","[8805, 3912, 11080, 2667, 33776, 4497, 22002, ..."
...,...,...,...,...
10538,16741,33579,"[94337, 94338, 22570, 26061, 24371, 94421, 243...","[2002, 2228, 34124]"
10539,16741,27462,"[22569, 22759, 24389, 33725]","[2002, 2228, 34124]"
10540,14919,27761,"[23172, 26375, 85515, 85131, 86929, 89121, 234...","[7665, 1682, 2103]"
10541,16593,28591,"[23168, 25608, 25865, 25596, 24461, 26127, 262...",[1722]


In [344]:
# Reproducibly draw 500 of the remaining contraindication pairs as the negative test set and save them.
test_negative = (
    filtered_remaining_contraindication
    .sample(n=500, random_state=42)
    .reset_index(drop=True)
)
test_negative.to_csv(path_or_buf="./data_analysis/test_negative.csv", index=False)

In [345]:
# Display the sampled negative test pairs together with their related phenotype/protein lists.
test_negative

Unnamed: 0,drug_index,disease_index,related_phenotypes,related_proteins
0,20391,28811,"[25699, 26180, 22759, 22472, 85642, 85419, 850...","[7457, 12289, 5829, 7302, 13706, 34890, 10413,..."
1,15294,28547,"[25858, 84357, 24327, 25224, 24331, 23564, 226...","[8192, 22017, 519, 11, 8212, 9237, 1567, 6175,..."
2,14783,33679,"[94408, 23033]","[1760, 8805, 12165, 172, 213, 1494, 509]"
3,20239,29188,"[22124, 22933, 26180]","[6671, 6002, 1908, 8733, 3231]"
4,17243,27326,"[89225, 94223, 26257, 84369, 94225, 84887, 238...","[7177, 9233, 34962, 7198, 1968, 177, 10417, 35..."
...,...,...,...,...
495,15542,28456,"[84363, 89132, 85457, 84885, 27094, 23000, 894...",[7937]
496,15584,31790,"[25522, 88486]","[13976, 22077, 13230]"
497,14168,31493,"[24224, 87330, 26180, 22759, 86573, 23950, 861...","[12165, 33925, 33926, 10376, 33929, 2698, 3405..."
498,20427,31918,"[88450, 85125, 22277, 84616, 22156, 24462, 245...","[9905, 10181]"


In [346]:
# Attach drug and disease names/types to the sampled negative test pairs.
# Only the index/name/type columns are taken from `df_node`, like the other name lookups in
# this notebook. This keeps node_id/node_source (and their _x/_y suffixed copies from the
# second merge) out of the intermediate frames.
node_lookup = df_node[['node_index', 'node_name', 'node_type']]

test_negative_with_drug = test_negative.merge(
    node_lookup,
    left_on='drug_index',
    right_on='node_index',
    how='left'
).rename(columns={
    'node_name': 'drug_name',
    'node_type': 'drug_type'
}).drop(columns=['node_index'])

test_negative_full = test_negative_with_drug.merge(
    node_lookup,
    left_on='disease_index',
    right_on='node_index',
    how='left'
).rename(columns={
    'node_name': 'disease_name',
    'node_type': 'disease_type'
}).drop(columns=['node_index'])

columns_to_keep = [
    'drug_index',
    'disease_index',
    'drug_type',
    'drug_name',
    'disease_type',
    'disease_name',
    'related_phenotypes',
    'related_proteins'
]

cleaned_test_negative = test_negative_full[columns_to_keep]
cleaned_test_negative

Unnamed: 0,drug_index,disease_index,drug_type,drug_name,disease_type,disease_name,related_phenotypes,related_proteins
0,20391,28811,drug,Silver sulfadiazine,disease,cystinuria,"[25699, 26180, 22759, 22472, 85642, 85419, 850...","[7457, 12289, 5829, 7302, 13706, 34890, 10413,..."
1,15294,28547,drug,Tacrine,disease,Parkinson disease,"[25858, 84357, 24327, 25224, 24331, 23564, 226...","[8192, 22017, 519, 11, 8212, 9237, 1567, 6175,..."
2,14783,33679,drug,Caffeine,disease,potassium deficiency disease,"[94408, 23033]","[1760, 8805, 12165, 172, 213, 1494, 509]"
3,20239,29188,drug,Penicillamine,disease,autosomal recessive severe congenital neutrope...,"[22124, 22933, 26180]","[6671, 6002, 1908, 8733, 3231]"
4,17243,27326,drug,Alimemazine,disease,long QT syndrome,"[89225, 94223, 26257, 84369, 94225, 84887, 238...","[7177, 9233, 34962, 7198, 1968, 177, 10417, 35..."
...,...,...,...,...,...,...,...,...
495,15542,28456,drug,Milnacipran,disease,nephrogenic syndrome of inappropriate antidiur...,"[84363, 89132, 85457, 84885, 27094, 23000, 894...",[7937]
496,15584,31790,drug,Lanreotide,disease,cholelithiasis,"[25522, 88486]","[13976, 22077, 13230]"
497,14168,31493,drug,Spironolactone,disease,pancreatitis,"[24224, 87330, 26180, 22759, 86573, 23950, 861...","[12165, 33925, 33926, 10376, 33929, 2698, 3405..."
498,20427,31918,drug,Cholestyramine,disease,phenylketonuria,"[88450, 85125, 22277, 84616, 22156, 24462, 245...","[9905, 10181]"


In [347]:
# Persist the named negative test set, omitting the pandas row index.
cleaned_test_negative.to_csv(path_or_buf="./data_analysis/test_negative_full.csv", index=False)