In [1]:
import numpy as np
import pandas as pd
from HierarchicalClusterTree import labelmat_to_tree, tree_to_labelmat, ClusterTreeNode, prune_tree

In [2]:
# Example data: a small three-level hierarchical label matrix plus a matching random embedding.
np.random.seed(42)  # fixed seed so the random embedding is reproducible across Restart & Run All

# Hierarchical clustering labels: one row per sample, one column per level (coarse -> fine).
# NOTE(review): the last row places level-3 cluster 'C1.2' under level-2 cluster 'C2';
# 'C2.1' was probably intended — confirm before relying on the C branch of this example.
labelmat = pd.DataFrame({
    'Level 1': ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C'],  # Main clusters
    'Level 2': ['A1', 'A2', 'A2', 'B1', 'B1', 'B2', 'C1', 'C2'],  # Sub-clusters
    'Level 3': ['A1.1', 'A2.1', 'A2.2', 'B1.1', 'B1.2', 'B2.1', 'C1.1', 'C1.2']  # Detailed sub-clusters
}).to_numpy()  # object ndarray; the DataFrame only serves as a readable column-wise constructor

# One embedding row per labelled sample (here 8 samples x 50 dimensions). Sizing it from
# labelmat keeps the two in sync; rows are assumed to be aligned with labelmat by index.
N_EMBEDDING_DIMS = 50
embedding = np.random.rand(labelmat.shape[0], N_EMBEDDING_DIMS)

labelmat

array([['A', 'A1', 'A1.1'],
       ['A', 'A2', 'A2.1'],
       ['A', 'A2', 'A2.2'],
       ['B', 'B1', 'B1.1'],
       ['B', 'B1', 'B1.2'],
       ['B', 'B2', 'B2.1'],
       ['C', 'C1', 'C1.1'],
       ['C', 'C2', 'C1.2']], dtype=object)

In [3]:
# Build the hierarchy from the label matrix. labelmat_to_tree hands back the root node,
# which the pruning and reconstruction cells below operate on.
print("Converting label matrix to cluster tree...")
cluster_tree = labelmat_to_tree(
    labelmat,
    embedding=embedding,
)
print(cluster_tree)  # root summary: name, level, member count, parent and direct children

Converting label matrix to cluster tree...
ClusterTreeNode(name='root', level=0, num_members=8, parent='None', children=[A, B, C])


In [4]:
# Define a function to compare two clusters and return a bool. True = keep, False = merge
def check_func(clusterA: ClusterTreeNode, clusterB: ClusterTreeNode) -> bool:
    """Decide whether a pair of clusters stays separate during pruning.

    Example rule: keep the pair only if ``clusterA`` has strictly more members
    than ``clusterB``. Otherwise, when ``clusterA`` has the same number of members
    as ``clusterB`` or fewer, the pair is merged.

    Args:
        clusterA: First cluster of the pair, as passed in by ``prune_tree``.
        clusterB: Second cluster of the pair, as passed in by ``prune_tree``.

    Returns:
        True to keep the clusters separate, False to merge them.
    """
    is_valid = len(clusterA.members) > len(clusterB.members)

    if not is_valid:
        # Log every merge decision so it shows up in the cell output.
        print(f'Merging {clusterA.name} and {clusterB.name}')

    return is_valid

print("Pruning cluster tree based on provided conditions...")
# No value is displayed for this call, and the next cell reads the pruned result from
# cluster_tree, so prune_tree appears to modify the tree in place.
# NOTE(review): confirm against prune_tree's documentation.
prune_tree(cluster_tree, check_func)

Pruning cluster tree based on provided conditions...
Merging A and B
Merging B1 and A2
Merging B1.2 and B1.1
Merging C2 and C1
Merging C1.1 and C1.2


In [5]:
# Convert the pruned tree back to a label matrix. In the stored output it gains a leading
# 'root' column ahead of the original levels, and merged clusters carry joined names.
print("Converting cluster tree back to label matrix...")
reconstructed_labelmat = tree_to_labelmat(cluster_tree)

# Pruning merges clusters but should not drop samples: expect one row per input row.
assert len(reconstructed_labelmat) == len(labelmat), "round trip changed the number of samples"
reconstructed_labelmat


Converting cluster tree back to label matrix...


array([['root', 'A_B', 'A1', 'A1.1'],
       ['root', 'A_B', 'B1_A2', 'A2.1'],
       ['root', 'A_B', 'B1_A2', 'A2.2'],
       ['root', 'A_B', 'B1_A2', 'B1.2_B1.1'],
       ['root', 'A_B', 'B1_A2', 'B1.2_B1.1'],
       ['root', 'A_B', 'B2', 'B2.1'],
       ['root', 'C', 'C2_C1', 'C1.1_C1.2'],
       ['root', 'C', 'C2_C1', 'C1.1_C1.2']], dtype=object)