# Benchmark analysis

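This notebook compares simple link-prediction benchmarks (common neighbors and shortest path) across the knowledge-graph (KG) variants `medium`, `small`, `xsmall` and `splits`, reporting each method's precision next to a random-chance baseline.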

# Imports

In [1]:
import os
import math
from collections import Counter

import pandas as pd
import numpy as np
from tqdm import tqdm

import networkx as nx
from networkx import DiGraph, connected_components

# Generate graph

In [2]:
# Root data directory; each KG variant is read from its own sub-directory below it.
KG_DATA_PATH = "../data/kg"

In [3]:
def create_graph_from_df(
    graph_df
) -> DiGraph:
    """Build a directed graph from an edge list and keep only its largest connected component.

    Each row of ``graph_df`` becomes one directed edge ``source -> target`` with the row's
    ``edge_type`` stored as the ``polarity`` edge attribute. Because ``DiGraph`` holds at most
    one edge per ordered node pair, a repeated (source, target) row keeps the last edge_type.

    Connectivity is evaluated on the undirected version of the graph (i.e. weakly connected
    components); only the largest component is returned.

    :param graph_df: DataFrame with ``source``, ``target`` and ``edge_type`` columns.
    :returns: ``graph.subgraph`` restricted to the nodes of the largest component.
    :raises ValueError: if ``graph_df`` contains no edges.
    """
    graph = DiGraph()

    # Select the columns by name: positional unpacking of ``graph_df.values`` silently swaps
    # source/target/edge_type whenever the file stores the columns in a different order.
    edges = graph_df[['source', 'target', 'edge_type']].itertuples(index=False, name=None)
    for sub_name, obj_name, relation in edges:
        graph.add_edge(
            sub_name,
            obj_name,
            polarity=relation,
        )

    print(f"Report on the number of relations: {dict(Counter(graph_df.edge_type))}")

    if graph.number_of_nodes() == 0:
        raise ValueError('Cannot build a graph from an empty edge list.')

    # Only the largest component is needed, so take the max instead of sorting all of them
    # (max returns the first largest one, matching the previous stable reverse sort).
    largest_component = max(
        connected_components(graph.to_undirected()),
        key=len,
    )

    return graph.subgraph(largest_component)

In [4]:
dataset_pairs = []

for graph_type in tqdm(['medium', 'small', 'xsmall', 'splits']):
    # Every KG variant sits in its own sub-directory with identical file names.
    graph_dir = os.path.join(KG_DATA_PATH, graph_type)

    # Edge list of this KG variant.
    graph_df = pd.read_csv(
        os.path.join(graph_dir, 'kg_with_train_smpls.tsv'),
        usecols=['source', 'target', 'edge_type'],
        sep='\t',
    )
    graph = create_graph_from_df(graph_df)

    # Gold-standard (test) pairs, encoded as "<source>_<target>" strings for matching.
    gold_standard = pd.read_csv(
        os.path.join(graph_dir, 'test.tsv'),
        usecols=['source', 'target'],
        sep='\t',
    )
    gold_standard['pairs'] = gold_standard.source + '_' + gold_standard.target

    dataset_pairs.append([graph_type, graph, gold_standard['pairs'].tolist()])



  0%|          | 0/4 [00:00<?, ?it/s]

 25%|██▌       | 1/4 [00:00<00:01,  1.98it/s]

Report on the number of relations: {'interacts': 20000, 'downregulates': 2196, 'upregulates': 1627, 'participates': 2605, 'induces': 792}


 50%|█████     | 2/4 [00:00<00:00,  3.10it/s]

Report on the number of relations: {'interacts': 10000, 'induces': 707, 'participates': 2603, 'upregulates': 1624, 'downregulates': 2196}
Report on the number of relations: {'downregulates': 844, 'upregulates': 720, 'interacts': 5264, 'participates': 1325, 'induces': 562}


 75%|███████▌  | 3/4 [00:00<00:00,  4.52it/s]

Report on the number of relations: {'interacts': 86786, 'participates': 4351, 'downregulates': 2205, 'upregulates': 1631, 'induces': 1111}


100%|██████████| 4/4 [00:02<00:00,  2.00it/s]


# Value by chance for each KG

What is the chance of hitting a gold-standard pair when drawing uniformly at random from all possible drug–BP pair combinations? (Reported as a percentage.)

In [5]:
val_by_chance_dict = {}

for graph_type, graph, gold_standard in tqdm(dataset_pairs):
    # "<drug>_<bp>" -> [drug, bp]; the two-name unpacking below fails loudly on malformed pairs.
    split_pairs = [drug_bp_pair.split('_') for drug_bp_pair in gold_standard]
    drugs = {drug for drug, bp in split_pairs}
    bps = {bp for drug, bp in split_pairs}

    # Share (in %) of all drug x BP combinations that are gold-standard pairs.
    n_combinations = len(drugs) * len(bps)
    val_by_chance_dict[graph_type] = round(len(gold_standard) / n_combinations * 100, 3)

val_by_chance_dict

100%|██████████| 4/4 [00:00<00:00, 2437.49it/s]


{'medium': 3.249, 'small': 3.409, 'xsmall': 5.269, 'splits': 3.379}

In [6]:
for graph_type, graph, gold_standard in tqdm(dataset_pairs):
    split_pairs = [drug_bp_pair.split('_') for drug_bp_pair in gold_standard]

    # Only count gold-standard entities that are actually nodes of this graph.
    drugs = {drug for drug, bp in split_pairs if drug in graph.nodes()}
    bps = {bp for drug, bp in split_pairs if bp in graph.nodes()}

    print(f"{graph_type}: {len(drugs)} drugs, {len(bps)} bps")

100%|██████████| 4/4 [00:00<00:00, 577.79it/s]

medium: 125 drugs, 65 bps
small: 113 drugs, 61 bps
xsmall: 91 drugs, 39 bps
splits: 150 drugs, 73 bps





# Helper Functions

In [7]:
# NOTE(review): neither dict is referenced anywhere else in this notebook — candidates for removal.
score_actual, kg_dfs = {}, {}

In [8]:
def khop(
    nodeA: str, 
    nodeB: str, 
    graph: nx.Graph, 
    total: bool
) -> tuple:
    """Combine the direct (1-hop) neighbourhoods of two nodes.

    Despite the name there is no distance parameter: only ``graph.neighbors`` of each
    node is used.

    :param nodeA: first node; must be in ``graph``.
    :param nodeB: second node; must be in ``graph``.
    :param graph: graph whose neighbourhoods are read.
    :param total: ``True`` -> union of both neighbourhoods, ``False`` -> intersection
        (i.e. the common neighbours).
    :returns: ``(combined, neighbours_of_A, neighbours_of_B)`` — ``combined`` is a list,
        the other two are sets.
    """
    neighbours_a = set(graph.neighbors(nodeA))
    neighbours_b = set(graph.neighbors(nodeB))

    combine = set.union if total else set.intersection
    return list(combine(neighbours_a, neighbours_b)), neighbours_a, neighbours_b

In [9]:
def _common_neighbor_count(graph, drug, bp):
    """Number of direct neighbours shared by ``drug`` and ``bp`` in ``graph``."""
    shared_nodes, _, _ = khop(nodeA=drug, nodeB=bp, graph=graph, total=False)
    return len(shared_nodes)


def _shortest_path_node_count(graph, drug, bp):
    """Node count of the shortest directed path ``drug -> bp``, or ``None`` when no path exists."""
    try:
        # nx.shortest_path returns the node list, so both endpoints are counted.
        return len(nx.shortest_path(graph, source=drug, target=bp))
    except nx.NetworkXNoPath:
        return None


def get_dict_df(
    bps, 
    drugs, 
    undirected_kg_graph, 
    di_kg_graph,
    similarity_type,
    similarity_name
):
    """Predict, for every BP, the drug(s) with the best similarity score.

    Each drug in ``drugs`` is scored against each BP in ``bps`` and every drug tied for
    the best score becomes a predicted (drug, BP) pair:

    * ``'cn'`` — number of common neighbours in ``undirected_kg_graph``; higher is better.
      BPs for which no drug shares a neighbour yield no prediction.
    * ``'sp'`` — node count of the shortest directed path drug -> BP in ``di_kg_graph``;
      lower is better. BPs unreachable from every drug yield no prediction.

    :param bps: iterable of BP node ids (presumably biological processes — confirm).
    :param drugs: iterable of drug node ids.
    :param undirected_kg_graph: graph used for the ``'cn'`` score.
    :param di_kg_graph: directed graph used for the ``'sp'`` score.
    :param similarity_type: ``'cn'`` or ``'sp'``.
    :param similarity_name: column name under which the score is stored.
    :returns: DataFrame with columns ``source`` (drug), ``target`` (BP) and
        ``similarity_name``; an empty DataFrame when nothing is predicted.
    :raises ValueError: if ``similarity_type`` is not ``'cn'`` or ``'sp'``.
    """
    if similarity_type not in ('cn', 'sp'):
        raise ValueError(
            f"Unknown similarity_type {similarity_type!r}; expected 'cn' or 'sp'"
        )

    # Materialise once so each score can be zipped back to its drug name.
    drugs = list(drugs)
    rows = []

    for bp in bps:
        if similarity_type == 'cn':
            scores = [
                _common_neighbor_count(undirected_kg_graph, drug, bp)
                for drug in drugs
            ]
            # A score of 0 carries no signal: if no drug shares a neighbour, skip this BP.
            candidates = [score for score in scores if score > 0]
            best = max(candidates) if candidates else None
        else:
            scores = [
                _shortest_path_node_count(di_kg_graph, drug, bp)
                for drug in drugs
            ]
            # ``None`` means "no path". A BP no drug can reach is skipped rather than
            # predicting every drug with a placeholder distance.
            candidates = [score for score in scores if score is not None]
            best = min(candidates) if candidates else None

        if best is None:
            continue

        rows.extend(
            {'source': drug, 'target': bp, similarity_name: score}
            for drug, score in zip(drugs, scores)
            if score == best
        )

    return pd.DataFrame(rows)

In [10]:
def get_precision(
    gold_standard_pairs: list, 
    predicted: list,
) -> tuple:
    """Precision (in %) of predicted pairs against a gold standard.

    :param gold_standard_pairs: iterable of true pair identifiers.
    :param predicted: predicted pair identifiers (sized); every entry, including
        repeats, counts towards the denominator.
    :returns: ``(precision_percent, n_correct, n_predicted)``. The precision is rounded
        to 3 decimals and is ``0.0`` when ``predicted`` is empty.
    """
    # A set gives O(1) membership tests instead of scanning the gold-standard list per pair.
    gold = set(gold_standard_pairs)

    total = len(predicted)
    pos = sum(1 for pair in predicted if pair in gold)

    # Precision is undefined without predictions; report 0 instead of dividing by zero.
    if total == 0:
        return 0.0, pos, total

    return round(pos / total * 100, 3), pos, total


# Different benchmark methods

In [11]:
# Similarity measures to benchmark: code accepted by ``get_dict_df`` -> display name.
sim_scores = dict(
    cn='Common Neighbors',
    sp='Shortest Path',
)

In [12]:
# One result dict per (graph variant, algorithm); turned into a DataFrame after scoring.
score_df = list()

In [13]:
for graph_type, graph, gold_standard in dataset_pairs:
    split_pairs = [drug_bp_pair.split('_') for drug_bp_pair in gold_standard]

    # Candidates: gold-standard drugs / BPs that are nodes of this graph.
    drugs = {drug for drug, bp in split_pairs if drug in graph.nodes()}
    bps = {bp for drug, bp in split_pairs if bp in graph.nodes()}

    # Common-neighbour scoring works on the undirected view; shortest path uses the DiGraph.
    undirected_kg_graph = graph.to_undirected()

    progress = tqdm(sim_scores.items(), desc=f'Calculating scores for algorithms - {graph_type}')
    for algo, algo_name in progress:
        full_df = get_dict_df(
            bps=list(bps),
            drugs=list(drugs),
            undirected_kg_graph=undirected_kg_graph,
            di_kg_graph=graph,
            similarity_type=algo,
            similarity_name=algo_name,
        )

        if full_df.empty:
            print(f'No results for {algo_name}')
            continue

        # Encode predictions like the gold standard ("<drug>_<bp>") so they can be matched.
        full_df['pair'] = full_df['source'] + '_' + full_df['target']

        precision, pos, total = get_precision(
            gold_standard_pairs=gold_standard,
            predicted=list(full_df['pair'].unique()),
        )

        result_row = {
            'graph_type': graph_type,
            'algo_name': algo_name,
            'precision': precision,
            'val_by_chance': val_by_chance_dict[graph_type],
            '# pairs': f'{pos}/{total}',
        }
        score_df.append(result_row)

Calculating scores for algorithms - medium: 100%|██████████| 2/2 [00:02<00:00,  1.03s/it]
Calculating scores for algorithms - small: 100%|██████████| 2/2 [00:02<00:00,  1.26s/it]
Calculating scores for algorithms - xsmall: 100%|██████████| 2/2 [00:01<00:00,  1.78it/s]
Calculating scores for algorithms - splits: 100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


In [14]:
# Collect the per-(graph, algorithm) result rows into one summary table.
scores = pd.DataFrame(data=score_df)

In [15]:
# Precision of each benchmark per KG variant, next to the random-chance baseline (val_by_chance).
scores

Unnamed: 0,graph_type,algo_name,precision,val_by_chance,# pairs
0,medium,Common Neighbors,14.957,3.249,35/234
1,medium,Shortest Path,3.936,3.249,32/813
2,small,Common Neighbors,16.827,3.409,35/208
3,small,Shortest Path,4.216,3.409,32/759
4,xsmall,Common Neighbors,13.274,5.269,15/113
5,xsmall,Shortest Path,1.61,5.269,9/559
6,splits,Common Neighbors,16.949,3.379,40/236
7,splits,Shortest Path,3.139,3.379,35/1115
