In [7]:
from itertools import combinations

import numpy as np
import scipy as sp

def _validate_parameters(dm, num_prototypes, seedset=None):
    '''Validate the parameters shared by the prototype selection algorithms.

    Parameters
    ----------
    dm: skbio.stats.distance.DistanceMatrix
        Pairwise distances for all elements in the full set S. Only its
        ``shape`` and ``ids`` attributes are inspected here.
    num_prototypes: int
        Number of prototypes to select for distance matrix.
        Must be >= 2, since a single prototype is useless.
        Must be smaller than the number of elements in the distance matrix,
        otherwise no reduction is necessary.
    seedset: collection of str, optional
        A set of element IDs that are pre-selected as prototypes. Remaining
        prototypes are then recruited with the prototype selection algorithm.
        It is measured with ``len()``, so it must be a sized collection
        (list, tuple, set, ...) rather than a one-shot iterator.
        Warning: It will most likely violate the global objective function.

    Returns
    -------
    None
        The function only raises on invalid input.

    Raises
    ------
    ValueError
        If 'num_prototypes' is smaller than 2, or is not smaller than the
        number of elements in the distance matrix.
        If the IDs in the seed set are not unique, or are not all present in
        the distance matrix.
        If the size of the seed set is not smaller than the number of
        prototypes to be found.
    '''
    if num_prototypes < 2:
        raise ValueError("'num_prototypes' must be >= 2, since a single "
                         "prototype is useless.")
    if num_prototypes >= dm.shape[0]:
        raise ValueError("'num_prototypes' must be smaller than the number of "
                         "elements in the distance matrix, otherwise no "
                         "reduction is necessary.")
    if seedset is None:
        return

    seeds = set(seedset)
    # a set drops repeated IDs, so a shrinking size reveals duplicates
    if len(seeds) < len(seedset):
        raise ValueError("There are duplicated IDs in 'seedset'.")
    # `<=` tests "subset"; the proper-subset operator `<` would wrongly
    # report a seed set naming *every* ID as "not a subset" instead of
    # letting the size check below produce the accurate message
    if not seeds <= set(dm.ids):
        raise ValueError("'seedset' is not a subset of the element IDs in "
                         "the distance matrix.")
    if len(seeds) >= num_prototypes:
        raise ValueError("Size of 'seedset' must be smaller than the "
                         "number of prototypes to select.")

def prototype_selection_constructive_maxdist(dm, num_prototypes, seedset=None):
    '''Heuristically select k prototypes for given distance matrix.

       Prototype selection is NP-hard. This is an implementation of a greedy
       constructive heuristic. Without a seed set, start with the two
       elements that are globally most distant from each other; with a seed
       set, start with the seed elements. The set of prototypes is then grown
       one element per iteration by adding the non-prototype element that has
       the largest sum of distances to the prototypes selected so far.

    Parameters
    ----------
    dm: skbio.stats.distance.DistanceMatrix
        Pairwise distances for all elements in the full set S.
    num_prototypes: int
        Number of prototypes to select for distance matrix.
        Must be >= 2, since a single prototype is useless.
        Must be smaller than the number of elements in the distance matrix,
        otherwise no reduction is necessary.
    seedset: iterable of str
        A set of element IDs that are pre-selected as prototypes. Remaining
        prototypes are then recruited with the prototype selection algorithm.
        Warning: It will most likely violate the global objective function.

    Returns
    -------
    list of str
        A sequence holding selected prototypes, i.e. a sub-set of the
        IDs of the elements in the distance matrix. The IDs are listed in
        the order in which they occur in ``dm.ids``, not in the order in
        which they were selected.

    Raises
    ------
    ValueError
        The number of prototypes to be found should be at least 2 and at most
        one element smaller than elements in the distance matrix. Otherwise, a
        ValueError is raised. Invalid seed sets are rejected by the parameter
        validation as well.

    Notes
    -----
    Timing (as reported by the original author):
            %timeit -n 100 prototype_selection_constructive_maxdist(dm, 100)
            100 loops, best of 3: 1.43 s per loop
            where the dm holds 27,398 elements
    function signature with type annotation for future use with python >= 3.5:
    def prototype_selection_constructive_maxdist(dm: DistanceMatrix,
    num_prototypes: int, seedset: List[str]) -> List[str]:
    '''
    if seedset is not None:
        # materialise once: the seed set is traversed by the validation and
        # again below, which would exhaust a one-shot iterator
        seedset = list(seedset)
    _validate_parameters(dm, num_prototypes, seedset)

    # initially mark all elements as uncovered, i.e. as not being a prototype
    uncovered = np.ones(dm.shape[0], dtype=bool)

    if seedset is not None:
        # the seed elements are the initial prototypes
        seeds = set(seedset)
        res_set = [idx for idx, id_ in enumerate(dm.ids) if id_ in seeds]
    else:
        # the first two prototypes are those elements that have the globally
        # maximal distance in the distance matrix
        res_set = list(np.unravel_index(dm.data.argmax(), dm.data.shape))
    uncovered[res_set] = False

    # repeat until enough prototypes have been selected. Counting the covered
    # elements (instead of the length of res_set) stays correct even if the
    # initial pair collapses to a single index, e.g. for an all-zero matrix.
    while np.count_nonzero(~uncovered) < num_prototypes:
        # for every element: sum of its distances to the current prototypes
        dist_to_prototypes = dm.data[res_set, :].sum(axis=0)
        # exclude already selected elements with -inf rather than by
        # multiplying with the mask: a product of 0 could still win the
        # argmax when all remaining candidates also sum to 0, which would
        # re-select an existing prototype and return too few elements
        candidate = int(
            np.where(uncovered, dist_to_prototypes, -np.inf).argmax())
        uncovered[candidate] = False
        res_set.append(candidate)

    # return the ids of the selected prototype elements
    return [id_ for id_, is_free in zip(dm.ids, uncovered) if not is_free]


In [50]:
from skbio.stats.distance import DistanceMatrix
from skbio.util import get_data_path
# Load the pairwise distance matrix from its absolute location on disk.
# `get_data_path` is meant for locating data files shipped next to the calling
# module; wrapped around an absolute path it adds nothing, so read directly.
dm = DistanceMatrix.read(
    '/mnt/home/djin/ceph/snakemake/data/Wang2020/beta/data/distance-matrix.tsv')

In [51]:
# Pick 30 prototypes from the loaded matrix; no seed set is supplied.
prototype_selection_constructive_maxdist(dm, num_prototypes=30)

['ERR2608608',
 'ERR2608657',
 'ERR2608613',
 'ERR2608671',
 'ERR2608662',
 'ERR2608655',
 'ERR2608663',
 'ERR2608620',
 'ERR2608609',
 'ERR2608665',
 'ERR2608628',
 'ERR2608632',
 'ERR2608672',
 'ERR2608658',
 'ERR2608659',
 'ERR2608637',
 'ERR2608634',
 'ERR2608984',
 'ERR2608656',
 'ERR2608986',
 'ERR2608646',
 'ERR2608645',
 'ERR2608635',
 'ERR2608661',
 'ERR2608644',
 'ERR2608648',
 'ERR2608618',
 'ERR2608611',
 'ERR2608615',
 'ERR2608654']