#### a hierarchical n-ary tree

recursively clusters input ndarray of embeddings using faiss
embeddings must be numpy float32

resulting 'nodes' is list of cluster data :
{
   'parent':parent_node_index,  # index into list of nodes - e.g. the root id is 0
   'id':parent_node_index + j,  # index into list of nodes 
   'labels':rootlabels          # labels are integer row offsets into input ndarray
   'kids':[]                    # list of node_ids
}

It is assumed that there is an associated list of content identifiers of same length as number of rows in input
This for each cluster we can fetch associated content

This is also suffucient data to build a tree structure - e.g. network digraph
Noe that leaf nodes have empty kid lists


In [80]:
def kmeans_raw(X,K):
    """ cluster X"""
    import faiss
    c = faiss.Kmeans(X.shape[1],K,spherical = True,verbose = True)
    c.train(X)
    sims,labels = c.assign(X)
    return sims,labels

def split(X,x,xlabels = None,nodes = [],arity = 2,depth = 0,parent_node = {}):
    """
    X is root ndarray
    x is current ndarray to cluster
    offsets is list of offsets into root ndarray
    nodes is a list of data for items in tree
    arity is number of clusters
    """
    # clusters - cluster(X, k = arity)
    parent_node_index = parent_node['id']
    
    if depth > 4:
        return
    try:
        sims,labels = kmeans_raw(x,arity)      # labels is list of indices of clusters
    except Exception as error:
        print(error)
        return
    
    kids = []
        
    for j in range(arity):
        rootlabels = [xlabels[i] for i,k in enumerate(labels) if k == j] # map cluster labels to root labels
        if len(rootlabels) < (arity * 2):
            continue
        this_node_index = parent_node_index  + j + 1   
        child_node = {
            'parent':parent_node_index,
            'id':this_node_index,
            'labels':rootlabels
        }
        

        kids.append(this_node_index)
        nodes.append(child_node)
        _x = X[rootlabels]             # this is a view into source ndarray   
        split(X,_x,xlabels,nodes,arity, depth + 1,child_node)
    parent_node['kids'] = kids
        
        
        
def build_tree(X,arity):
    labels = list(range(X.shape[0]))
    root_node = {
        'parent':None,
        'id':0,
        'labels':labels
    }
    nodes = [root_node]
    split(X,X,labels,nodes,arity,depth = 0,parent_node = root_node)
    return nodes

    

In [81]:
from _utils_.tau_IO import *
ts = get_tauset('tauset_wiki_1300_centroids')
X = ts[:100].m
nodes  = build_tree(X,2)
for node in nodes:
    print(node)


{'parent': None, 'id': 0, 'labels': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], 'kids': [1, 2]}
{'parent': 0, 'id': 1, 'labels': [6, 7, 10, 22, 23, 43, 45, 46, 49, 55, 63, 69, 83, 86, 90], 'kids': [2]}
{'parent': 1, 'id': 2, 'labels': [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13], 'kids': [3, 4]}
{'parent': 2, 'id': 3, 'labels': [0, 1, 3, 4, 5, 7], 'kids': [5]}
{'parent': 3, 'id': 5, 'labels': [0, 2, 3, 4, 5], 'kids': []}
{'parent': 2, 'id': 4, 'labels': [2, 6, 8, 9, 10, 11], 'kids': [6]}
{'parent': 4, 'id': 6, 'labels': [0, 2, 3, 4, 5], 'kids': []}
{'parent': 0, 'id': 2, 'labels': [0, 1, 2, 3, 4, 5, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 