# Which DWWC to use?

`dwwc_recursive` with caching appears to be the fastest method for bulk DWWC (degree-weighted walk count) computation. However, `dwpc` (degree-weighted path count) makes many `dwwc` calls, often on very short segments, and for those it is not clear whether `dwwc_chain` or `dwwc_recursive` will be faster. To find out, we compute the same set of DWPCs several times, using the original `dwwc`, `dwwc_recursive`, or `dwwc_chain` as the DWWC function, and record the time taken for each metapath.

In [1]:
from hetmech.degree_weight import (
    _dwpc_disjoint, _dwpc_short_repeat, _dwpc_general_case, _dwpc_baab, _dwpc_baba,
    path_count_cache, categorize, get_segments, get_all_segments, dwwc, dwwc_chain, dwwc_recursive,
    _dwpc_repeat_around

)
import hetmech.hetmat
from hetmech.matrix import (
    copy_array,
    metaedge_to_adjacency_matrix,
    normalize,
)
import functools
import numpy as np
import numpy 
import tqdm
import time
import pandas as pd
from scipy import sparse

In [2]:
# Construct the HetMat object for the Hetionet v1.0 directory; the path is
# relative, so it is resolved against the current working directory.
hetmat = hetmech.hetmat.HetMat('../data/hetionet-v1.0.hetmat/')

In [3]:
# Gather the Compound -> Disease metapaths (max_length=4) to benchmark.
# Metapaths categorized as 'long_repeat' or 'other' are left out: the `dwpc`
# wrappers defined in this notebook raise NotImplementedError for those
# categories unless use_general=True is passed.
metapaths = [
    metapath
    for metapath in hetmat.metagraph.extract_metapaths(
        'Compound', 'Disease', max_length=4)
    if hetmech.degree_weight.categorize(metapath) not in {'long_repeat', 'other'}
]
# Last expression of the cell: shows how many metapaths were kept.
len(metapaths)

1172

## Original DWWC computation times

In [4]:
def category_to_function(category):
    """Return the routine registered for a metapath category.

    In this section the 'no_repeats' category is handled by the original
    `dwwc`; every other category maps to one of the `_dwpc_*` helpers.

    Parameters
    ----------
    category : str
        Category name, e.g. the value returned by `categorize(metapath)`.

    Raises
    ------
    KeyError
        If `category` is not one of the registered names.
    """
    handlers = {
        'no_repeats': dwwc,
        'disjoint': _dwpc_disjoint,
        'disjoint_groups': _dwpc_disjoint,
        'short_repeat': _dwpc_short_repeat,
        'long_repeat': _dwpc_general_case,
        'BAAB': _dwpc_baab,
        'BABA': _dwpc_baba,
        'repeat_around': _dwpc_repeat_around,
        'interior_complete_group': _dwpc_baba,
        'other': _dwpc_general_case,
    }
    return handlers[category]

@path_count_cache(metric='dwpc')
def dwpc(graph, metapath, damping=0.5, dense_threshold=0, use_general=False,
         dtype=numpy.float64):
    """Compute the DWPC for `metapath`, dispatching on its category.

    The category comes from `categorize(metapath)` and the routine from the
    notebook-level `category_to_function`, so redefining that function swaps
    the routine used here.

    Parameters
    ----------
    graph : object passed through unchanged to the selected routine
    metapath : metapath to evaluate
    damping : float, forwarded to the selected routine
    dense_threshold : forwarded to the routine (not to `_dwpc_general_case`)
    use_general : bool
        When true, 'long_repeat' and 'other' metapaths are computed with
        `_dwpc_general_case`; otherwise they raise NotImplementedError.
    dtype : forwarded to the routine (not to `_dwpc_general_case`)

    Returns
    -------
    tuple
        The three items produced by the routine (row names, column names,
        DWPC matrix).
    """
    category = categorize(metapath)
    handler = category_to_function(category)
    needs_general_case = category in ('long_repeat', 'other')

    # Guard: the general case is only run when explicitly requested.
    if needs_general_case and not use_general:
        raise NotImplementedError(
            'Metapath category will use _dwpc_general_case')

    if needs_general_case:
        rows, cols, matrix = _dwpc_general_case(graph, metapath, damping)
    else:
        rows, cols, matrix = handler(
            graph, metapath, damping,
            dense_threshold=dense_threshold, dtype=dtype)
    return rows, cols, matrix

In [5]:
# Benchmark `dwpc` as defined in this section, where the 'no_repeats'
# category is handled by the original `dwwc`.
# A new HetMat is created and given a PathCountPriorityCache
# (allocate_GB=5) before timing starts.
hetmat = hetmech.hetmat.HetMat('../data/hetionet-v1.0.hetmat/')
hetmat.path_counts_cache = hetmech.hetmat.PathCountPriorityCache(hetmat, allocate_GB=5)

# Each row is [metapath, elapsed seconds, label of the DWWC method].
# NOTE(review): a later timing cell rebinds `times` to a new list, so these
# rows do not reach the final DataFrame unless they are saved separately.
times = []
for metapath in tqdm.tqdm(metapaths):
    time1 = time.time()
    dwpc(hetmat, metapath, dense_threshold=1)
    time2 = time.time()
    # Label matches the method timed here (the original `dwwc`); it was
    # previously 'recursive', which collided with the recursive benchmark.
    times.append([metapath, time2 - time1, 'original'])

100%|██████████| 1172/1172 [29:53<00:00,  1.53s/it]


## Recursive DWWC computation times

In [6]:
def category_to_function(category):
    """Return the routine registered for a metapath category.

    In this section the 'no_repeats' category is handled by
    `dwwc_recursive`; every other category maps to a `_dwpc_*` helper.

    Parameters
    ----------
    category : str
        Category name, e.g. the value returned by `categorize(metapath)`.

    Raises
    ------
    KeyError
        If `category` is not one of the registered names.
    """
    handlers = {
        'no_repeats': dwwc_recursive,
        'disjoint': _dwpc_disjoint,
        'disjoint_groups': _dwpc_disjoint,
        'short_repeat': _dwpc_short_repeat,
        'long_repeat': _dwpc_general_case,
        'BAAB': _dwpc_baab,
        'BABA': _dwpc_baba,
        'repeat_around': _dwpc_repeat_around,
        'interior_complete_group': _dwpc_baba,
        'other': _dwpc_general_case,
    }
    return handlers[category]

@path_count_cache(metric='dwpc')
def dwpc(graph, metapath, damping=0.5, dense_threshold=0, use_general=False,
         dtype=numpy.float64):
    """Compute the DWPC for `metapath`, dispatching on its category.

    The routine is looked up through the notebook-level
    `category_to_function`, which in this section sends 'no_repeats'
    metapaths to `dwwc_recursive`.

    Parameters
    ----------
    graph : object passed through unchanged to the selected routine
    metapath : metapath to evaluate
    damping : float, forwarded to the selected routine
    dense_threshold : forwarded to the routine (not to `_dwpc_general_case`)
    use_general : bool
        When true, 'long_repeat' and 'other' metapaths are computed with
        `_dwpc_general_case`; otherwise they raise NotImplementedError.
    dtype : forwarded to the routine (not to `_dwpc_general_case`)

    Returns
    -------
    tuple
        The three items produced by the routine (row names, column names,
        DWPC matrix).
    """
    category = categorize(metapath)
    handler = category_to_function(category)
    needs_general_case = category in ('long_repeat', 'other')

    # Guard: the general case is only run when explicitly requested.
    if needs_general_case and not use_general:
        raise NotImplementedError(
            'Metapath category will use _dwpc_general_case')

    if needs_general_case:
        rows, cols, matrix = _dwpc_general_case(graph, metapath, damping)
    else:
        rows, cols, matrix = handler(
            graph, metapath, damping,
            dense_threshold=dense_threshold, dtype=dtype)
    return rows, cols, matrix

In [7]:
# Benchmark `dwpc` with `dwwc_recursive` handling the 'no_repeats' category.
# A separate HetMat with its own PathCountPriorityCache (allocate_GB=5) is
# used for this run.
hetmat_rec = hetmech.hetmat.HetMat('../data/hetionet-v1.0.hetmat/')
hetmat_rec.path_counts_cache = hetmech.hetmat.PathCountPriorityCache(
    hetmat_rec, allocate_GB=5)

# Start a new results list; rows are [metapath, elapsed seconds, label].
times = []
for metapath in tqdm.tqdm(metapaths):
    started = time.time()
    dwpc(hetmat_rec, metapath, dense_threshold=1)
    elapsed = time.time() - started
    times.append([metapath, elapsed, 'recursive'])

100%|██████████| 1172/1172 [28:40<00:00,  1.47s/it]


## Chain DWWC computation times

In [8]:
def category_to_function(category):
    """Return the routine registered for a metapath category.

    In this section the 'no_repeats' category is handled by `dwwc_chain`;
    every other category maps to a `_dwpc_*` helper.

    Parameters
    ----------
    category : str
        Category name, e.g. the value returned by `categorize(metapath)`.

    Raises
    ------
    KeyError
        If `category` is not one of the registered names.
    """
    handlers = {
        'no_repeats': dwwc_chain,
        'disjoint': _dwpc_disjoint,
        'disjoint_groups': _dwpc_disjoint,
        'short_repeat': _dwpc_short_repeat,
        'long_repeat': _dwpc_general_case,
        'BAAB': _dwpc_baab,
        'BABA': _dwpc_baba,
        'repeat_around': _dwpc_repeat_around,
        'interior_complete_group': _dwpc_baba,
        'other': _dwpc_general_case,
    }
    return handlers[category]

@path_count_cache(metric='dwpc')
def dwpc(graph, metapath, damping=0.5, dense_threshold=0, use_general=False,
         dtype=numpy.float64):
    """Compute the DWPC for `metapath`, dispatching on its category.

    The routine is looked up through the notebook-level
    `category_to_function`, which in this section sends 'no_repeats'
    metapaths to `dwwc_chain`.

    Parameters
    ----------
    graph : object passed through unchanged to the selected routine
    metapath : metapath to evaluate
    damping : float, forwarded to the selected routine
    dense_threshold : forwarded to the routine (not to `_dwpc_general_case`)
    use_general : bool
        When true, 'long_repeat' and 'other' metapaths are computed with
        `_dwpc_general_case`; otherwise they raise NotImplementedError.
    dtype : forwarded to the routine (not to `_dwpc_general_case`)

    Returns
    -------
    tuple
        The three items produced by the routine (row names, column names,
        DWPC matrix).
    """
    category = categorize(metapath)
    handler = category_to_function(category)
    needs_general_case = category in ('long_repeat', 'other')

    # Guard: the general case is only run when explicitly requested.
    if needs_general_case and not use_general:
        raise NotImplementedError(
            'Metapath category will use _dwpc_general_case')

    if needs_general_case:
        rows, cols, matrix = _dwpc_general_case(graph, metapath, damping)
    else:
        rows, cols, matrix = handler(
            graph, metapath, damping,
            dense_threshold=dense_threshold, dtype=dtype)
    return rows, cols, matrix

In [9]:
# Benchmark `dwpc` with `dwwc_chain` handling the 'no_repeats' category.
# A separate HetMat with its own PathCountPriorityCache (allocate_GB=5) is
# used for this run.
hetmat_chain = hetmech.hetmat.HetMat('../data/hetionet-v1.0.hetmat/')
hetmat_chain.path_counts_cache = hetmech.hetmat.PathCountPriorityCache(
    hetmat_chain, allocate_GB=5)

# `times` is not reset here: the 'chain' rows are appended to the list that
# already exists, so both methods end up in one table.
for metapath in tqdm.tqdm(metapaths):
    started = time.time()
    dwpc(hetmat_chain, metapath, dense_threshold=1)
    elapsed = time.time() - started
    times.append([metapath, elapsed, 'chain'])

100%|██████████| 1172/1172 [27:40<00:00,  1.42s/it]


In [None]:
# Tabulate the collected timings: one row per (metapath, method) measurement.
times_df = pd.DataFrame(
    data=times,
    columns=['metapath', 'dwpc-time', 'dwwc-method'],
)
# Last expression of the cell: preview the first five rows.
times_df.head(n=5)

## Baseline cache method using functools

In [4]:
def category_to_function(category):
    """Return the routine registered for a metapath category.

    In this section the 'no_repeats' category is handled by `dwwc_chain`;
    every other category maps to a `_dwpc_*` helper.

    Parameters
    ----------
    category : str
        Category name, e.g. the value returned by `categorize(metapath)`.

    Raises
    ------
    KeyError
        If `category` is not one of the registered names.
    """
    handlers = {
        'no_repeats': dwwc_chain,
        'disjoint': _dwpc_disjoint,
        'disjoint_groups': _dwpc_disjoint,
        'short_repeat': _dwpc_short_repeat,
        'long_repeat': _dwpc_general_case,
        'BAAB': _dwpc_baab,
        'BABA': _dwpc_baba,
        'repeat_around': _dwpc_repeat_around,
        'interior_complete_group': _dwpc_baba,
        'other': _dwpc_general_case,
    }
    return handlers[category]

@functools.lru_cache(maxsize=128)
def dwpc(graph, metapath, damping=0.5, dense_threshold=0, use_general=False,
         dtype=numpy.float64):
    """Compute the DWPC for `metapath`, memoized with `functools.lru_cache`.

    Baseline variant: results are memoized by `functools.lru_cache`
    (maxsize=128) rather than by `path_count_cache`. Because `lru_cache` keys
    on the call arguments, every argument must be hashable.

    Parameters
    ----------
    graph : object passed through unchanged to the selected routine
    metapath : metapath to evaluate
    damping : float, forwarded to the selected routine
    dense_threshold : forwarded to the routine (not to `_dwpc_general_case`)
    use_general : bool
        When true, 'long_repeat' and 'other' metapaths are computed with
        `_dwpc_general_case`; otherwise they raise NotImplementedError.
    dtype : forwarded to the routine (not to `_dwpc_general_case`)

    Returns
    -------
    tuple
        The three items produced by the routine (row names, column names,
        DWPC matrix).
    """
    category = categorize(metapath)
    handler = category_to_function(category)
    needs_general_case = category in ('long_repeat', 'other')

    # Guard: the general case is only run when explicitly requested.
    if needs_general_case and not use_general:
        raise NotImplementedError(
            'Metapath category will use _dwpc_general_case')

    if needs_general_case:
        rows, cols, matrix = _dwpc_general_case(graph, metapath, damping)
    else:
        rows, cols, matrix = handler(
            graph, metapath, damping,
            dense_threshold=dense_threshold, dtype=dtype)
    return rows, cols, matrix

In [5]:
# Baseline benchmark: `dwpc` in this section is wrapped with
# functools.lru_cache, and no PathCountPriorityCache is attached to the
# HetMat created here.
hetmat = hetmech.hetmat.HetMat('../data/hetionet-v1.0.hetmat/')

# Start a new results list; rows are [metapath, elapsed seconds, label].
times = []
for metapath in tqdm.tqdm(metapaths):
    started = time.time()
    dwpc(hetmat, metapath, dense_threshold=1)
    elapsed = time.time() - started
    times.append([metapath, elapsed, 'chain'])

100%|██████████| 1172/1172 [30:07<00:00,  1.54s/it]
