# Trees!

Oftentimes, we don't want to settle for a flat representation of our dataset; we would also like to learn hierarchical structure as part of the inc 

In [38]:
import numpy as np
from crp.table import ChineseRestaurantTable, DirichletMultinomialTable, NegativeBinomialTable

In [None]:
import numpy as np
import scanpy as sc

# Load the annotated dataset from disk.
adata = sc.read_h5ad("/home/jhaberbe/Data/choroid-plexus/new_annotations.h5ad")

# Drop observations labelled "Doublet", then keep only rows of X whose sum
# exceeds 300.
adata = adata[adata.obs["Cell.Subtype"].ne("Doublet")]
# Subsetting yields a view; copy it so the HVG step below writes its results
# into a real object rather than triggering an implicit copy of the view.
adata = adata[adata.X.sum(axis=1) > 300].copy()

# Name the flavor once so the call and the bookkeeping entry cannot drift apart.
flavor = "seurat_v3"
sc.pp.highly_variable_genes(adata, flavor=flavor)

counts = adata.X.todense()
# Log size factor per row: log(row total / mean row total).
row_totals = counts.sum(axis=1)
size_factors = np.log(row_totals / row_totals.mean())

# Record which HVG flavor was used.
adata.uns["hvg"] = {"flavor": flavor}


### Transforming our Tables into Nodes

In [None]:
# Take every 100th row of `counts`, keep only the columns flagged in
# `adata.var.highly_variable`, and convert the result to a plain ndarray.
X = np.array(counts[::100][:, adata.var.highly_variable])

In [338]:
import numpy as np

class CRPNode:
    """Node in a tree of nested Chinese-restaurant-process tables.

    Each node owns one table (``table_class(data)``), the set of sample
    indices that were routed through it, and a dict of child nodes keyed by
    the smallest unused non-negative integer.

    Args:
        data: Object with a ``shape`` attribute; handed unchanged to
            ``table_class``. Only ``shape[0]`` is read here.
        depth: Depth of this node in the tree (0 for the root).
        parent: Parent node, or ``None`` for the root.
        table_class: Callable invoked as ``table_class(data)``. The object it
            returns must expose ``data``, ``add_member(index)``,
            ``remove_member(index)`` and
            ``log_likelihood(index, posterior=True)``.
        expected_number_of_classes: Numerator of the concentration parameter
            ``alpha``; passed on to every child node.

    Raises:
        TypeError: If ``table_class`` is not callable.
    """

    def __init__(self, data, depth=0, parent=None, table_class=None, expected_number_of_classes: int = 1):
        if not callable(table_class):
            raise TypeError(
                "table_class must be a callable that builds a table from the data"
            )
        self.data = data
        self.depth = depth
        self.parent = parent
        self.children = {}
        self.members = set()
        self.table = table_class(data)
        self.table_class = table_class
        self.expected_number_of_classes = expected_number_of_classes
        # alpha = K / log(N). N is clamped to 2 because log(1) == 0 and
        # log(0) == -inf would give an infinite / negative-zero alpha, which
        # turns the sampling probabilities in `sample_path` into NaN.
        n_rows = max(self.data.shape[0], 2)
        self.alpha = expected_number_of_classes / np.log(n_rows)

    def _next_child_key(self):
        """Return the smallest non-negative integer not yet used as a child key."""
        key = 0
        while key in self.children:
            key += 1
        return key

    def _new_child(self, data):
        """Build, without attaching, a child node that inherits this node's settings."""
        return CRPNode(
            data,
            depth=self.depth + 1,
            parent=self,
            table_class=self.table_class,
            expected_number_of_classes=self.expected_number_of_classes,
        )

    def _attach_child(self, child):
        """Store ``child`` under the next free key and return that key."""
        key = self._next_child_key()
        self.children[key] = child
        return key

    def add_child(self, data):
        """Create a child node over ``data``, attach it, and return it."""
        child = self._new_child(data)
        self._attach_child(child)
        return child

    def add_member(self, index):
        """Register ``index`` with this node and with its table."""
        self.members.add(index)
        self.table.add_member(index)

    def remove_member(self, index):
        """Remove ``index`` from this node (no error if absent) and from its table."""
        self.members.discard(index)
        self.table.remove_member(index)

    def has_member(self, index):
        """Return whether ``index`` is currently a member of this node."""
        return index in self.members

    @staticmethod
    def sample_path(node, index, depth=0, max_depth=4):
        """Sample a root-to-descendant path for one sample and register it.

        ``index`` is added to ``node``. Every existing child is scored as
        ``log_likelihood + log1p(len(members))`` and a fresh, empty child as
        ``log_likelihood + log(node.alpha)``; one option is drawn from the
        normalised scores with ``np.random.choice``.

        * If the fresh child wins, it is attached to ``node``, ``index`` is
          registered with it, and the path ends there.
        * If an existing child wins and ``depth + 1 < max_depth``, sampling
          recurses into that child; otherwise ``index`` is registered with
          the child and the path ends.

        Args:
            node: Node to start from.
            index: Sample index to route.
            depth: Number of levels already descended in this call chain.
            max_depth: Recursion stops once ``depth + 1`` reaches this value.

        Returns:
            List of child keys followed from ``node`` downward. Every node
            on the path has ``index`` as a member.
        """
        node.add_member(index)
        existing_children = list(node.children.items())

        # Existing children: likelihood plus a prior that favours larger children.
        log_posteriors = [
            child.table.log_likelihood(index, posterior=True) + np.log1p(len(child.members))
            for _, child in existing_children
        ]

        # Candidate new child: scored against an empty table, weighted by alpha.
        candidate = node._new_child(node.table.data)
        ll_new = candidate.table.log_likelihood(index, posterior=True)
        log_posteriors.append(ll_new + np.log(node.alpha))

        # Softmax with the max subtracted for numerical stability.
        log_posteriors = np.array(log_posteriors)
        probs = np.exp(log_posteriors - np.max(log_posteriors))
        probs /= probs.sum()

        choice = np.random.choice(len(probs), p=probs)

        if choice == len(existing_children):
            # The sample that opens a new child must also become its first
            # member; otherwise the child is attached with an empty table and
            # the last node of the returned path does not contain `index`.
            candidate.add_member(index)
            return [node._attach_child(candidate)]

        child_key, child = existing_children[choice]
        if depth + 1 < max_depth:
            # Keep descending through the chosen child.
            return [child_key] + CRPNode.sample_path(
                child, index, depth=depth + 1, max_depth=max_depth
            )

        # At max depth the chosen child is the end of the path.
        child.add_member(index)
        return [child_key]


In [None]:
# Build the tree by routing every row of X, in order, down from a fresh root.
root = CRPNode(X, table_class=NegativeBinomialTable)
for idx in range(X.shape[0]):
    # Optionally remove idx from current assignments if doing iterative inference
    path = CRPNode.sample_path(root, idx, max_depth=3)
    # Report each assignment; the sampled path was previously discarded even
    # though the cell's recorded output lists one such line per sample.
    print(f"Sample {idx} assigned to path {path}")

Sample 0 assigned to path [0]
Sample 1 assigned to path [0, 0]
Sample 2 assigned to path [0, 0, 0]
Sample 3 assigned to path [0, 0, 1]
Sample 4 assigned to path [0, 0, 0]
Sample 5 assigned to path [0, 0, 0]
Sample 6 assigned to path [0, 0, 1]
Sample 7 assigned to path [0, 0, 0]
Sample 8 assigned to path [0, 0, 0]
Sample 9 assigned to path [0, 0, 0]
Sample 10 assigned to path [0, 0, 0]
Sample 11 assigned to path [0, 0, 1]
Sample 12 assigned to path [0, 0, 0]
Sample 13 assigned to path [0, 0, 0]
Sample 14 assigned to path [0, 0, 0]
Sample 15 assigned to path [0, 0, 0]
Sample 16 assigned to path [0, 0, 0]
Sample 17 assigned to path [0, 0, 0]
Sample 18 assigned to path [0, 0, 0]
Sample 19 assigned to path [0, 0, 0]
Sample 20 assigned to path [0, 0, 0]
Sample 21 assigned to path [0, 0, 0]
Sample 22 assigned to path [0, 0, 0]
Sample 23 assigned to path [0, 0, 0]
Sample 24 assigned to path [0, 0, 0]
Sample 25 assigned to path [0, 0, 0]
Sample 26 assigned to path [0, 0, 0]
Sample 27 assigned t