# Train - Test - Validation Splits

In [1]:
import pandas as pd
import os.path as osp
import os
import networkx as nx
from copy import deepcopy

## Load in the final KGs

In [2]:
KG_DIR = '../data/kg'

In [3]:
kg = pd.read_csv(osp.join(KG_DIR, 'final_kg.tsv'), sep='\t')
kg.drop_duplicates(inplace=True)

kg_protclass = pd.read_csv(osp.join(KG_DIR, 'final_kg_subclassed.tsv'), sep='\t')
kg_protclass.drop_duplicates(inplace=True)

small_kg = pd.read_csv(osp.join(KG_DIR, 'small_kg.tsv'), sep='\t')
small_kg.drop_duplicates(inplace=True)

xsmall_kg = pd.read_csv(osp.join(KG_DIR, 'xsmall_kg.tsv'), sep='\t')
xsmall_kg.drop_duplicates(inplace=True)

In [4]:
drug_bp_pairs = kg.loc[kg['edge_type'] == 'induces']

drug_bp_pairs_protclass = kg_protclass.loc[kg_protclass['edge_type'] == 'induces']

drug_bp_pairs_small = small_kg.loc[small_kg['edge_type'] == 'induces']

drug_bp_pairs_xsmall = xsmall_kg.loc[xsmall_kg['edge_type'] == 'induces']

In [5]:
print(f"There are {len(drug_bp_pairs)} positive drug-BP pairs in the final KG")

print(f"There are {len(drug_bp_pairs_protclass)} positive drug-BP pairs in the protein-classed KG")

print(f"There are {len(drug_bp_pairs_small)} positive drug-BP pairs in the small KG")

print(f"There are {len(drug_bp_pairs_xsmall)} positive drug-BP pairs in the xsmall KG")

There are 1622 positive drug-BP pairs in the final KG
There are 1622 positive drug-BP pairs in the protein-classed KG
There are 457 positive drug-BP pairs in the small KG
There are 291 positive drug-BP pairs in the xsmall KG


## Load in the DrugMechDB pairs that go in the test set

In [6]:
dm_db_pairs = pd.read_csv(osp.join(KG_DIR, 'drugmechdb_triples.tsv'), sep='\t')
dm_db_pairs.drop_duplicates(inplace=True)

print(f"{len(dm_db_pairs)} additional drug-BP pairs come from DrugMechDB, constituting {len(dm_db_pairs)/len(drug_bp_pairs)*100}%")

48 additional drug-BP pairs come from DrugMechDB, constituting 2.9593094944512948%


In [7]:
dm_db_pairs_small = pd.read_csv(osp.join(KG_DIR, 'drugmechdb_triples_small.tsv'), sep='\t')
dm_db_pairs_small.drop_duplicates(inplace=True)

print(f"{len(dm_db_pairs_small)} additional drug-BP pairs for MoA-net-Small come from DrugMechDB, constituting {len(dm_db_pairs_small)/len(drug_bp_pairs_small)*100}%")

48 additional drug-BP pairs for MoA-net-Small come from DrugMechDB, constituting 10.50328227571116%


In [8]:
dm_db_pairs_xsmall = pd.read_csv(osp.join(KG_DIR, 'drugmechdb_triples_xsmall.tsv'), sep='\t')
dm_db_pairs_xsmall.drop_duplicates(inplace=True)

print(f"{len(dm_db_pairs_xsmall)} additional drug-BP pairs for MoA-net-XSmall come from DrugMechDB, constituting {len(dm_db_pairs_xsmall)/len(drug_bp_pairs_xsmall)*100}%")

28 additional drug-BP pairs for MoA-net-XSmall come from DrugMechDB, constituting 9.621993127147768%


## Specify which pairs are able to be matched:

MARS / MARS can match pairs as long as they are 4 or less hops apart (our hyperparameter setting) and not connected via an inverse _CtBP edge.

In [9]:
G = nx.DiGraph()

for i, row in kg.iterrows():
    if row['edge_type'] == 'induces':
        continue
    src_id = row['source']
    trgt_id = row['target']
    if src_id not in G.nodes:
        G.add_node(src_id, type=row['source_node_type'])
    if trgt_id not in G.nodes:
        G.add_node(trgt_id, type=row['target_node_type'])
    G.add_edge(src_id, trgt_id, type=row['edge_type'])
    # if row['source_node_type'] != 'Compound':
    G.add_edge(trgt_id, src_id, type=f"_{row['edge_type']}")

Gp = nx.DiGraph()

for i, row in kg_protclass.iterrows():
    if row['edge_type'] == 'induces':
        continue
    src_id = row['source']
    trgt_id = row['target']
    if src_id not in Gp.nodes:
        Gp.add_node(src_id, type=row['source_node_type'])
    if trgt_id not in Gp.nodes:
        Gp.add_node(trgt_id, type=row['target_node_type'])
    Gp.add_edge(src_id, trgt_id, type=row['edge_type'])
    # if row['source_node_type'] != 'Compound':
    Gp.add_edge(trgt_id, src_id, type=f"_{row['edge_type']}")

G_small = nx.DiGraph()

for i, row in small_kg.iterrows():
    if row['edge_type'] == 'induces':
        continue
    src_id = row['source']
    trgt_id = row['target']
    if src_id not in G_small.nodes:
        G_small.add_node(src_id, type=row['source_node_type'])
    if trgt_id not in G_small.nodes:
        G_small.add_node(trgt_id, type=row['target_node_type'])
    G_small.add_edge(src_id, trgt_id, type=row['edge_type'])
    # if row['source_node_type'] != 'Compound':
    G_small.add_edge(trgt_id, src_id, type=f"_{row['edge_type']}")

G_xsmall = nx.DiGraph()

for i, row in xsmall_kg.iterrows():
    if row['edge_type'] == 'induces':
        continue
    src_id = row['source']
    trgt_id = row['target']
    if src_id not in G_xsmall.nodes:
        G_xsmall.add_node(src_id, type=row['source_node_type'])
    if trgt_id not in G_xsmall.nodes:
        G_xsmall.add_node(trgt_id, type=row['target_node_type'])
    G_xsmall.add_edge(src_id, trgt_id, type=row['edge_type'])
    # if row['source_node_type'] != 'Compound':
    G_xsmall.add_edge(trgt_id, src_id, type=f"_{row['edge_type']}")


Check the unmatched pairs for the regular, full KG:

In [10]:
unmatched_pairs = set()
unmatched_dm_db_pairs = set()

for i, row in drug_bp_pairs.iterrows():
    if not nx.has_path(G, row['source'], row['target']):
        unmatched_pairs.add(i)

for i, row in dm_db_pairs.iterrows():
    if not nx.has_path(G, row['source'], row['target']):
        unmatched_dm_db_pairs.add(i)

In [11]:
print(f"{len(unmatched_pairs)} drug-BP pairs in the final KG are not connected via 4 hops or less")
print(f"{len(unmatched_dm_db_pairs)} drug-BP pairs from DrugMechDB are not connected via 4 hops or less")

0 drug-BP pairs in the final KG are not connected via 4 hops or less
0 drug-BP pairs from DrugMechDB are not connected via 4 hops or less


Check the unmatched pairs for the protein-subclassed KG:

In [12]:
unmatched_pairs_p = set()
unmatched_dm_db_pairs_p = set()

for i, row in drug_bp_pairs_protclass.iterrows():
    if not nx.has_path(Gp, row['source'], row['target']):
        unmatched_pairs_p.add(i)

for i, row in dm_db_pairs.iterrows():
    if not nx.has_path(Gp, row['source'], row['target']):
        unmatched_dm_db_pairs_p.add(i)

In [13]:
print(f"{len(unmatched_pairs_p)} drug-BP pairs in the protein-classed KG are not connected via 4 hops or less")
print(f"{len(unmatched_dm_db_pairs_p)} drug-BP pairs from DrugMechDB are not connected via 4 hops or less")

0 drug-BP pairs in the protein-classed KG are not connected via 4 hops or less
0 drug-BP pairs from DrugMechDB are not connected via 4 hops or less


Check the unmatched pairs for the small KG:

In [14]:
unmatched_pairs_small = set()
unmatched_dm_db_pairs_small = set()

for i, row in drug_bp_pairs_small.iterrows():
    if not nx.has_path(G_small, row['source'], row['target']):
        unmatched_pairs_small.add(i)

for i, row in dm_db_pairs_small.iterrows():
    if not nx.has_path(G_small, row['source'], row['target']):
        unmatched_dm_db_pairs_small.add(i)

In [15]:
print(f"{len(unmatched_pairs_small)} drug-BP pairs in the MoA-net-Small are not connected via 4 hops or less")
print(f"{len(unmatched_dm_db_pairs_small)} drug-BP pairs from DrugMechDB are not connected via 4 hops or less")

0 drug-BP pairs in the MoA-net-Small are not connected via 4 hops or less
0 drug-BP pairs from DrugMechDB are not connected via 4 hops or less


Check the unmatched pairs for the x-small KG:

In [16]:
unmatched_pairs_xsmall = set()
unmatched_dm_db_pairs_xsmall = set()

for i, row in drug_bp_pairs_xsmall.iterrows():
    if not nx.has_path(G_xsmall, row['source'], row['target']):
        unmatched_pairs_xsmall.add(i)

for i, row in dm_db_pairs_xsmall.iterrows():
    if not nx.has_path(G_xsmall, row['source'], row['target']):
        unmatched_dm_db_pairs_xsmall.add(i)

In [17]:
print(f"{len(unmatched_pairs_xsmall)} drug-BP pairs in the MoA-net-XSmall are not connected via 4 hops or less")
print(f"{len(unmatched_dm_db_pairs_xsmall)} drug-BP pairs from DrugMechDB are not connected via 4 hops or less")

0 drug-BP pairs in the MoA-net-XSmall are not connected via 4 hops or less
0 drug-BP pairs from DrugMechDB are not connected via 4 hops or less


In [18]:
# Remove unmatched pairs from the final KG
drug_bp_pairs.drop(list(unmatched_pairs), inplace=True)
drug_bp_pairs.reset_index(drop=True, inplace=True)

# Remove unmatched pairs from the protein-classed KG
drug_bp_pairs_protclass.drop(list(unmatched_pairs_p), inplace=True)
drug_bp_pairs_protclass.reset_index(drop=True, inplace=True)

# Remove unmatched pairs from MoA-net-Small
drug_bp_pairs_small.drop(list(unmatched_pairs_small), inplace=True)
drug_bp_pairs_small.reset_index(drop=True, inplace=True)

# Remove unmatched pairs from MoA-net-XSmall
drug_bp_pairs_xsmall.drop(list(unmatched_pairs_xsmall), inplace=True)
drug_bp_pairs_xsmall.reset_index(drop=True, inplace=True)

# Remove unmatched pairs from DrugMechDB
# make copies of it for multiple KG variants:
dm_db_pairs_protclass = deepcopy(dm_db_pairs)

dm_db_pairs.drop(list(unmatched_dm_db_pairs), inplace=True)
dm_db_pairs.reset_index(drop=True, inplace=True)

dm_db_pairs_protclass.drop(list(unmatched_dm_db_pairs_p), inplace=True)
dm_db_pairs_protclass.reset_index(drop=True, inplace=True)

dm_db_pairs_small.drop(list(unmatched_dm_db_pairs_small), inplace=True)
dm_db_pairs_small.reset_index(drop=True, inplace=True)

dm_db_pairs_xsmall.drop(list(unmatched_dm_db_pairs_xsmall), inplace=True)
dm_db_pairs_xsmall.reset_index(drop=True, inplace=True)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  drug_bp_pairs.drop(list(unmatched_pairs), inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  drug_bp_pairs_protclass.drop(list(unmatched_pairs_p), inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  drug_bp_pairs_small.drop(list(unmatched_pairs_small), inplace=True)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning

## Create the Splits

In the original PoLo example, they use different proportions for the splits, but let's go with something most similar to their Hetionet example, in which they do an approximate 60/20/20% split.

Note that the DrugMechDB examples need to be in the test set.

First, we'll exclude the subset of DrugMechDB examples that are in the test set.

Then, we'll split the remaining examples into train, validation, and test sets, with the test set accumulating to 20% with the DrugMechDB examples.

In [19]:
total_positives = len(drug_bp_pairs) + len(dm_db_pairs)

total_positives_pc = len(drug_bp_pairs_protclass) + len(dm_db_pairs_protclass)

total_positives_small = len(drug_bp_pairs_small) + len(dm_db_pairs_small)

total_positives_xsmall = len(drug_bp_pairs_xsmall) + len(dm_db_pairs_xsmall)

So we need to get the following numbers from the KG positive examples for train, validation, and test sets:

In [20]:
proportions = round(0.6 * total_positives), round(0.2 * total_positives), round(0.2 * total_positives) - len(dm_db_pairs)
print(proportions)

proportions_pc = round(0.6 * total_positives_pc), round(0.2 * total_positives_pc), round(0.2 * total_positives_pc) - len(dm_db_pairs_protclass)
print(proportions_pc)

proportions_small = round(0.6 * total_positives_small), round(0.2 * total_positives_small), round(0.2 * total_positives_small) - len(dm_db_pairs_small)
print(proportions_small)

proportions_xsmall = round(0.6 * total_positives_xsmall), round(0.2 * total_positives_xsmall), round(0.2 * total_positives_xsmall) - len(dm_db_pairs_xsmall)
print(proportions_xsmall)

(1002, 334, 286)
(1002, 334, 286)
(303, 101, 53)
(191, 64, 36)


In [21]:
# write a function which separates the dataframe into train, val and test sets of defined sizes
def train_test_split(df, train_size, val_size, test_size):
    df = df.sample(frac=1, random_state=7).reset_index(drop=True)
    train = df[:train_size]
    val = df[train_size:train_size+val_size]
    test = df[train_size+val_size:train_size+val_size+test_size]
    return train, val, test

In [22]:
train, val, test = train_test_split(drug_bp_pairs, proportions[0], proportions[1], proportions[2])

train_pc, val_pc, test_pc = train_test_split(drug_bp_pairs_protclass, proportions_pc[0], proportions_pc[1], proportions_pc[2])

train_small, val_small, test_small = train_test_split(drug_bp_pairs_small, proportions_small[0], proportions_small[1], proportions_small[2])

train_xsmall, val_xsmall, test_xsmall = train_test_split(drug_bp_pairs_xsmall, proportions_xsmall[0], proportions_xsmall[1], proportions_xsmall[2])

Check it did what we want:

In [23]:
len(train), len(val), len(test)

(1002, 334, 286)

In [24]:
len(train_pc), len(val_pc), len(test_pc)

(1002, 334, 286)

In [25]:
len(train_small), len(val_small), len(test_small)

(303, 101, 53)

In [26]:
len(train_xsmall), len(val_xsmall), len(test_xsmall)

(191, 64, 36)

No overlap?

In [27]:
train_pairs = {(row['source'], row['target']) for i, row in train.iterrows()}
test_pairs = {(row['source'], row['target']) for i, row in test.iterrows()}
val_pairs = {(row['source'], row['target']) for i, row in val.iterrows()}

In [28]:
train_pairs & test_pairs

set()

In [29]:
train_pairs & val_pairs

set()

In [30]:
test_pairs & val_pairs

set()

Overlap in the protein-classed one?

In [31]:
train_pairs = {(row['source'], row['target']) for i, row in train_pc.iterrows()}
test_pairs = {(row['source'], row['target']) for i, row in test_pc.iterrows()}
val_pairs = {(row['source'], row['target']) for i, row in val_pc.iterrows()}

In [32]:
train_pairs & test_pairs

set()

In [33]:
train_pairs & val_pairs

set()

In [34]:
test_pairs & val_pairs

set()

Overlap in the small one?

In [35]:
train_pairs = {(row['source'], row['target']) for i, row in train_small.iterrows()}
test_pairs = {(row['source'], row['target']) for i, row in test_small.iterrows()}
val_pairs = {(row['source'], row['target']) for i, row in val_small.iterrows()}

In [36]:
train_pairs & test_pairs

set()

In [37]:
train_pairs & val_pairs

set()

In [38]:
test_pairs & val_pairs

set()

Overlap in the x-small one?

In [39]:
train_pairs = {(row['source'], row['target']) for i, row in train_xsmall.iterrows()}
test_pairs = {(row['source'], row['target']) for i, row in test_xsmall.iterrows()}
val_pairs = {(row['source'], row['target']) for i, row in val_xsmall.iterrows()}

In [40]:
train_pairs & test_pairs

set()

In [41]:
train_pairs & val_pairs

set()

In [42]:
test_pairs & val_pairs

set()

Good, no overlap. Add the DrugMechDB examples to the test set:

In [43]:
test = pd.concat([test, dm_db_pairs]).sample(frac=1, random_state=7).reset_index(drop=True)

test_pc = pd.concat([test_pc, dm_db_pairs_protclass]).sample(frac=1, random_state=7).reset_index(drop=True)

test_small = pd.concat([test_small, dm_db_pairs_small]).sample(frac=1, random_state=7).reset_index(drop=True)

test_xsmall = pd.concat([test_xsmall, dm_db_pairs_xsmall]).sample(frac=1, random_state=7).reset_index(drop=True)

In [44]:
print(len(test))
print(len(test_pc))
print(len(test_small))
print(len(test_xsmall))

334
334
101
64


Take the test and validation sets out the KG:

In [45]:
kg = kg.loc[kg['edge_type'] != 'induces']
kg_mars = pd.concat([kg, train]).sample(frac=1, random_state=7).reset_index(drop=True)

kg_protclass = kg_protclass.loc[kg_protclass['edge_type'] != 'induces']
kg_mars_pc = pd.concat([kg_protclass, train_pc]).sample(frac=1, random_state=7).reset_index(drop=True)

small_kg = small_kg.loc[small_kg['edge_type'] != 'induces']
kg_mars_small = pd.concat([small_kg, train_small]).sample(frac=1, random_state=7).reset_index(drop=True)

xsmall_kg = xsmall_kg.loc[xsmall_kg['edge_type'] != 'induces']
kg_mars_xsmall = pd.concat([xsmall_kg, train_xsmall]).sample(frac=1, random_state=7).reset_index(drop=True)

In [46]:
len(kg_mars.loc[kg_mars['edge_type'] == 'induces']) == len(train)

True

In [47]:
len(kg_mars_pc.loc[kg_mars_pc['edge_type'] == 'induces']) == len(train_pc)

True

In [48]:
len(kg_mars_small.loc[kg_mars_small['edge_type'] == 'induces']) == len(train_small)

True

In [49]:
len(kg_mars_small.loc[kg_mars_small['edge_type'] == 'induces']) == len(train_small)

True

Write everything to files:

In [50]:
SPLITS_DIR = osp.join(KG_DIR, 'splits')
MOA_NET = osp.join(SPLITS_DIR, 'MoA-net')
MOA_NET_PROTCLASSED = osp.join(SPLITS_DIR, 'MoA-net-protclass')
MOA_NET_SMALL = osp.join(SPLITS_DIR, 'MoA-net-Small')
MOA_NET_XSMALL = osp.join(SPLITS_DIR, 'MoA-net-XSmall')

# create the above directories if they do not exist:
for dir in [SPLITS_DIR, MOA_NET, MOA_NET_PROTCLASSED, MOA_NET_SMALL, MOA_NET_XSMALL]:
    if not osp.exists(dir):
        os.makedirs(dir)

In [51]:
kg.to_csv(osp.join(MOA_NET, 'kg_no_cmp_bp.tsv'), sep='\t', index=False)
kg_mars.to_csv(osp.join(MOA_NET, 'kg_with_train_smpls.tsv'), sep='\t', index=False)

train.to_csv(osp.join(MOA_NET, 'train.tsv'), sep='\t', index=False)
val.to_csv(osp.join(MOA_NET, 'dev.tsv'), sep='\t', index=False)
test.to_csv(osp.join(MOA_NET, 'test.tsv'), sep='\t', index=False)

In [52]:
kg_protclass.to_csv(osp.join(MOA_NET_PROTCLASSED, 'kg_no_cmp_bp.tsv'), sep='\t', index=False)
kg_mars_pc.to_csv(osp.join(MOA_NET_PROTCLASSED, 'kg_with_train_smpls.tsv'), sep='\t', index=False)

train_pc.to_csv(osp.join(MOA_NET_PROTCLASSED, 'train.tsv'), sep='\t', index=False)
val_pc.to_csv(osp.join(MOA_NET_PROTCLASSED, 'dev.tsv'), sep='\t', index=False)
test_pc.to_csv(osp.join(MOA_NET_PROTCLASSED, 'test.tsv'), sep='\t', index=False)

In [53]:
small_kg.to_csv(osp.join(MOA_NET_SMALL, 'kg_no_cmp_bp.tsv'), sep='\t', index=False)
kg_mars_small.to_csv(osp.join(MOA_NET_SMALL, 'kg_with_train_smpls.tsv'), sep='\t', index=False)

train_small.to_csv(osp.join(MOA_NET_SMALL, 'train.tsv'), sep='\t', index=False)
val_small.to_csv(osp.join(MOA_NET_SMALL, 'dev.tsv'), sep='\t', index=False)
test_small.to_csv(osp.join(MOA_NET_SMALL, 'test.tsv'), sep='\t', index=False)

In [54]:
xsmall_kg.to_csv(osp.join(MOA_NET_XSMALL, 'kg_no_cmp_bp.tsv'), sep='\t', index=False)
kg_mars_xsmall.to_csv(osp.join(MOA_NET_XSMALL, 'kg_with_train_smpls.tsv'), sep='\t', index=False)

train_xsmall.to_csv(osp.join(MOA_NET_XSMALL, 'train.tsv'), sep='\t', index=False)
val_xsmall.to_csv(osp.join(MOA_NET_XSMALL, 'dev.tsv'), sep='\t', index=False)
test_xsmall.to_csv(osp.join(MOA_NET_XSMALL, 'test.tsv'), sep='\t', index=False)