# Train - Test - Validation Splits

In [1]:
import pandas as pd
import os.path as osp
import networkx as nx

## Load in the final KG

In [2]:
KG_DIR = '../data/kg'

In [3]:
kg = pd.read_csv(osp.join(KG_DIR, 'final_kg_subclassed.tsv'), sep='\t')
kg.drop_duplicates(inplace=True)

In [4]:
drug_bp_pairs = kg.loc[kg['edge_type'] == 'induces']

In [5]:
print(f"There are {len(drug_bp_pairs)} positive drug-BP pairs in the final KG")

There are 1622 positive drug-BP pairs in the final KG


## Load in the DrugMechDB pairs that go in the test set

In [6]:
dm_db_pairs = pd.read_csv(osp.join(KG_DIR, 'test.tsv'), sep='\t')
dm_db_pairs.drop_duplicates(inplace=True)

In [7]:
print(f"{len(dm_db_pairs)} additional drug-BP pairs come from DrugMechDB, constituting {len(dm_db_pairs)/len(drug_bp_pairs)*100}%")

48 additional drug-BP pairs come from DrugMechDB, constituting 2.9593094944512948%


## Specify which pairs are able to be matched:

PoLo / MARS can match pairs as long as they are 4 or less hops apart (our hyperparameter setting) and not connected via an inverse _CtBP edge.

In [8]:
G = nx.DiGraph()

for i, row in kg.iterrows():
    if row['edge_type'] == 'induces':
        continue
    src_id = row['source']
    trgt_id = row['target']
    if src_id not in G.nodes:
        G.add_node(src_id, type=row['source_node_type'])
    if trgt_id not in G.nodes:
        G.add_node(trgt_id, type=row['target_node_type'])
    G.add_edge(src_id, trgt_id, type=row['edge_type'])
    # if row['source_node_type'] != 'Compound':
    G.add_edge(trgt_id, src_id, type=f"_{row['edge_type']}")


In [9]:
unmatched_pairs = set()
unmatched_dm_db_pairs = set()

for i, row in drug_bp_pairs.iterrows():
    try:
        spath = nx.shortest_path(G, row['source'], row['target'])
        if len(spath) > 5: # 4 hops, 5 nodes
            unmatched_pairs.add(i)
    except nx.NetworkXNoPath:
        unmatched_pairs.add(i)

for i, row in dm_db_pairs.iterrows():
    try:
        spath = nx.shortest_path(G, row['source'], row['target'])
        if len(spath) > 5: # 4 hops, 5 nodes
            unmatched_dm_db_pairs.add(i)
    except nx.NetworkXNoPath:
        unmatched_dm_db_pairs.add(i)

In [10]:
print(f"{len(unmatched_pairs)} drug-BP pairs in the final KG are not connected via 4 hops or less")
print(f"{len(unmatched_dm_db_pairs)} drug-BP pairs from DrugMechDB are not connected via 4 hops or less")

69 drug-BP pairs in the final KG are not connected via 4 hops or less
0 drug-BP pairs from DrugMechDB are not connected via 4 hops or less


In [51]:
# Remove unmatched pairs from the final KG
drug_bp_pairs.drop(list(unmatched_pairs), inplace=True)
drug_bp_pairs.reset_index(drop=True, inplace=True)

# Remove unmatched pairs from DrugMechDB
dm_db_pairs.drop(list(unmatched_dm_db_pairs), inplace=True)
dm_db_pairs.reset_index(drop=True, inplace=True)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  errors=errors,


In [55]:
shortest_path = nx.shortest_path(G, source='pubchem.compound:5717', target='GO:0002376')
print("Shortest path:", shortest_path)

Shortest path: ['pubchem.compound:5717', 'ncbigene:10800', 'ncbigene:1906', 'ncbigene:4772', 'ncbigene:159', 'GO:0002376']


## Create the Splits

In the original PoLo example, they use different proportions for the splits, but let's go with something most similar to their Hetionet example, in which they do an approximate 60/20/20% split.

Note that the DrugMechDB examples need to be in the test set.

First, we'll exclude the subset of DrugMechDB examples that are in the test set.

Then, we'll split the remaining examples into train, validation, and test sets, with the test set accumulating to 20% with the DrugMechDB examples.

In [13]:
total_positives = len(drug_bp_pairs) + len(dm_db_pairs)

So we need to get the following numbers from the KG positive examples for train, validation, and test sets:

In [14]:
proportions = round(0.6 * total_positives), round(0.2 * total_positives), round(0.2 * total_positives) - len(dm_db_pairs)
proportions

(1000, 333, 289)

In [15]:
# write a function which separates the dataframe into train, val and test sets of defined sizes
def train_test_split(df, train_size, val_size, test_size):
    df = df.sample(frac=1, random_state=7).reset_index(drop=True)
    train = df[:train_size]
    val = df[train_size:train_size+val_size]
    test = df[train_size+val_size:train_size+val_size+test_size]
    return train, val, test

In [16]:
train, val, test = train_test_split(drug_bp_pairs, proportions[0], proportions[1], proportions[2])

Check it did what we want:

In [17]:
len(train), len(val), len(test)

(1000, 333, 289)

No overlap?

In [18]:
train_pairs = {(row['source'], row['target']) for i, row in train.iterrows()}
test_pairs = {(row['source'], row['target']) for i, row in test.iterrows()}
val_pairs = {(row['source'], row['target']) for i, row in val.iterrows()}

In [19]:
train_pairs & test_pairs

set()

In [20]:
train_pairs & val_pairs

set()

In [21]:
test_pairs & val_pairs

set()

Good, no overlap. Add the drugmechDB examples to the test set:

In [22]:
test = pd.concat([test, dm_db_pairs]).sample(frac=1, random_state=7).reset_index(drop=True)

In [23]:
len(test)

333

Take the test and validation sets out the KG:

In [24]:
kg = kg.loc[kg['edge_type'] != 'induces']
kg_polo = pd.concat([kg, train]).sample(frac=1, random_state=7).reset_index(drop=True)

In [25]:
len(kg_polo.loc[kg_polo['edge_type'] == 'induces']) == len(train)

True

Write everything to files:

In [26]:
SPLITS_DIR = osp.join(KG_DIR, 'splits')

In [27]:
kg.to_csv(osp.join(SPLITS_DIR, 'kg_no_cmp_bp.tsv'), sep='\t', index=False)
kg_polo.to_csv(osp.join(SPLITS_DIR, 'kg_with_train_smpls.tsv'), sep='\t', index=False)

train.to_csv(osp.join(SPLITS_DIR, 'train.tsv'), sep='\t', index=False)
val.to_csv(osp.join(SPLITS_DIR, 'dev.tsv'), sep='\t', index=False)
test.to_csv(osp.join(SPLITS_DIR, 'test.tsv'), sep='\t', index=False)

For the PoLo files, let's get it into a format suitable for PoLo:

In [28]:
POLO_DIR = osp.join(SPLITS_DIR, 'PoLo')