In [1]:
from typing import List, NamedTuple, Union

import numpy as np
from sklearn.datasets import make_blobs

In [2]:
# Synthetic benchmark data: 1,000,000 samples with 100 features generated
# around 15 blob centers; random_state pins the draw so reruns match.
X, y = make_blobs(n_samples=1_000_000, n_features=100, centers=15, random_state=42)

In [3]:
class Leaf(NamedTuple):
    """Terminal box of the KD-tree, summarising the samples that fell into it."""
    # Per-feature minima stacked on top of per-feature maxima of the samples
    # (as filled in by the tree builders in this file).
    bounds: np.ndarray
    # Per-feature mean of the samples (the builders keep the leading axis).
    centroid: np.ndarray
    # Number of samples summarised by this leaf.
    count: int = 0

# A tree is either an inner Node or a terminal Leaf; 'Node' is a forward
# reference because the class is only defined below.
KDTree = Union['Node', Leaf]

class Node(NamedTuple):
    """Inner node of the KD-tree: one split on a single feature."""
    # Value the chosen feature was compared against.
    pivot_value: float
    # Column index of the feature used for the split.
    pivot_feature: int
    # Subtree built from the samples on the low side of the pivot.
    left: KDTree = None
    # Subtree built from the samples on the high side of the pivot.
    right: KDTree = None

In [4]:
def make_median_tree(X, depth: int=4) -> KDTree:
    """Make KDTree using median as pivot

    Construct a KDTree out of data using the median of the most variant
    feature as the pivoting element. Each split makes two segments: samples
    strictly below the median go left, all remaining samples (including the
    ones equal to the median) go right, so no sample is ever dropped.
    The result doesn't contain the original data, just the splitting points,
    bounds of leaves, centroids in each box and count of items.

    Parameters
    ==========
    X : array_like, (n_samples, n_features)
        Set of observations to divide into boxes. Must contain at least
        one sample.

    depth : int, optional (default 4)
        Maximum number of nested splits. A branch stops earlier when a split
        would leave one side empty (e.g. all values of the chosen feature
        are tied with the median). Values <= 0 produce a single leaf.

    Returns
    =======
    tree : KDTree
        Lightweight KD-Tree over the data

    Raises
    ======
    ValueError
        If `X` contains no samples.
    """
    X = np.asanyarray(X)
    if X.shape[0] == 0:
        # min/max/mean of an empty box are undefined; fail with a clear message
        # instead of numpy's "zero-size array" reduction error.
        raise ValueError('X must contain at least one sample')
    if depth > 0:
        most_variant = X.var(axis=0).argmax()
        feature = X[:, most_variant]
        med_feature = np.median(feature)
        # Complementary masks: every sample lands on exactly one side, also
        # the ones that are equal to the median (odd sample count or ties).
        goes_left = feature < med_feature
        left, right = X[goes_left], X[~goes_left]
        # A one-sided split cannot be refined any further; fall through and
        # summarise the whole box as a leaf instead of recursing on nothing.
        if left.shape[0] > 0 and right.shape[0] > 0:
            return Node(
                pivot_value=med_feature, pivot_feature=most_variant,
                left=make_median_tree(left, depth-1),
                right=make_median_tree(right, depth-1),
            )
    # Row 0 holds the per-feature minima, row 1 the per-feature maxima.
    bounds = np.vstack([X.min(axis=0, keepdims=True),
                        X.max(axis=0, keepdims=True)])
    centroid = X.mean(axis=0, keepdims=True)
    return Leaf(bounds, centroid, X.shape[0])

In [5]:
# Time a full build over X with the default depth of 4.
%timeit make_median_tree(X)

4.12 s ± 13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
def get_leaves(tree: KDTree) -> List[Leaf]:
    """Collect all leaves of a KDTree

    Walks the tree with an explicit stack and gathers the leaves in
    left-to-right order.

    Parameters
    ==========
    tree : KDTree
        KDTree constructed on the data

    Returns
    =======
    leaves : list of Leaf
        All the leaves from the full depth of the tree
    """
    collected = []
    pending = [tree]
    while pending:
        current = pending.pop()
        if isinstance(current, Leaf):
            collected.append(current)
            continue
        lower, upper = current.left, current.right
        # The stack is LIFO: push the right child first so that the left
        # subtree is fully visited before the right one.
        pending.append(upper)
        pending.append(lower)
    return collected

In [7]:
# Build a shallow tree (two levels of median splits) and flatten it to its leaves.
leaves = get_leaves(make_median_tree(X, depth=2))

In [8]:
# Two levels of binary splits are expected to give 2 ** 2 = 4 leaves;
# the actual number is shown in the assertion message on failure.
assert len(leaves) == 4, len(leaves)

In [9]:
# The leaf counts must add up to the number of samples, i.e. every sample
# ended up in exactly one leaf.
assert sum(l.count for l in leaves) == X.shape[0]

In [10]:
# Select the samples lying inside the bounding box of the first leaf:
# bounds[0] holds the per-feature minima, bounds[1] the per-feature maxima.
lb_approved = (leaves[0].bounds[0] <= X).all(axis=1)
ub_approved = (leaves[0].bounds[1] >= X).all(axis=1)
is_first_leave = np.logical_and(lb_approved, ub_approved)
# The box must contain exactly as many samples as the leaf recorded.
assert is_first_leave.sum() == leaves[0].count

In [11]:
# The stored centroid must match the mean of the samples inside the first
# leaf's box; keepdims=True mirrors how the builder computes Leaf.centroid.
np.testing.assert_almost_equal(X[is_first_leave].mean(axis=0, keepdims=True),
                               leaves[0].centroid)

In [12]:
from skimage.filters import threshold_otsu

In [13]:
# Otsu threshold computed over the values of the first sample.
# NOTE(review): X[0] is a single row (one sample across all features), while
# the tree builders apply the pivot to a feature column (X[:, j]) - confirm
# that a row was intended here.
thr = threshold_otsu(X[0])
thr

-1.0802173354994267

In [14]:
# Baseline: cost of np.median on the same row that is passed to threshold_otsu.
%timeit np.median(X[0])

46.6 µs ± 550 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [15]:
# Cost of the Otsu threshold on the same row, for comparison with np.median.
%timeit threshold_otsu(X[0])

180 µs ± 4.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [16]:
def make_tree(X, leaf_size: Union[int, float]=0.01, pivot=np.median) -> KDTree:
    """Make KDTree out of the data

    Construct a KDTree out of data, splitting on the most variant feature at
    the value returned by `pivot` (the median by default; any callable that
    maps a 1-D array to a threshold, such as an Otsu threshold, works).
    Each split makes two segments: samples strictly below the pivot go left,
    all remaining samples (including the ones equal to the pivot) go right,
    so no sample is ever dropped. The result doesn't contain the original
    data, just the splitting points, bounds of leaves, centroids in each box
    and count of items.

    Parameters
    ==========
    X : array_like, (n_samples, n_features)
        Set of observations to divide into boxes. Must contain at least
        one sample.

    leaf_size : int or float, optional (default 0.01)
        Desired leaf size. A box holding fewer than `2 * leaf_size` samples
        is not split any further. When float, it is a fraction of
        `n_samples` (resolved once against the full data and never below
        one sample). With a balanced pivot such as the median on distinct
        values, leaves end up between `leaf_size` and `2 * leaf_size`
        samples; an unbalanced pivot can produce smaller leaves, and a box
        whose split would leave one side empty is kept as a single leaf
        regardless of its size.

    pivot : callable, optional (default np.median)
        Method to find the pivot element

    Returns
    =======
    tree : KDTree
        Lightweight KD-Tree over the data

    Raises
    ======
    ValueError
        If `X` contains no samples, if a float `leaf_size` is outside
        [0, 1], or if an int `leaf_size` is smaller than 1.
    """
    X = np.asanyarray(X)
    n_samples = X.shape[0]
    if n_samples == 0:
        # min/max/mean of an empty box are undefined; fail with a clear message
        # instead of numpy's "zero-size array" reduction error.
        raise ValueError('X must contain at least one sample')
    if isinstance(leaf_size, float):
        if 0 <= leaf_size <= 1:
            # Recursive calls receive the resolved int, so the fraction always
            # refers to the full data set. Clamp to 1: a size of 0 would make
            # the stopping condition below unreachable.
            leaf_size = max(1, int(leaf_size * n_samples))
        else:
            raise ValueError('leaf_size must be between 0 and 1 when float')
    elif leaf_size < 1:
        raise ValueError('leaf_size must be at least 1 when int')
    if n_samples >= 2 * leaf_size:
        most_variant = X.var(axis=0).argmax()
        feature = X[:, most_variant]
        thr = pivot(feature)
        # Complementary masks: every sample lands on exactly one side, also
        # the ones that are equal to the pivot value.
        goes_left = feature < thr
        left, right = X[goes_left], X[~goes_left]
        # A one-sided split makes no progress; fall through and summarise the
        # whole box as a leaf instead of recursing forever or on nothing.
        if left.shape[0] > 0 and right.shape[0] > 0:
            return Node(
                pivot_value=thr, pivot_feature=most_variant,
                left=make_tree(left, leaf_size=leaf_size, pivot=pivot),
                right=make_tree(right, leaf_size=leaf_size, pivot=pivot),
            )
    # Row 0 holds the per-feature minima, row 1 the per-feature maxima.
    bounds = np.vstack([X.min(axis=0, keepdims=True),
                        X.max(axis=0, keepdims=True)])
    centroid = X.mean(axis=0, keepdims=True)
    return Leaf(bounds, centroid, n_samples)

In [17]:
# Time a build with an absolute leaf size of 25,000 samples and the Otsu
# threshold as the pivot function.
%timeit make_tree(X, leaf_size=25_000, pivot=threshold_otsu)

5.02 s ± 180 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
