In [1]:
import scanpy as sc
adata = sc.read_h5ad("/home/jhaberbe/Projects/spatial-indian-buffet-process/data/new_annotations.h5ad")

In [3]:
import torch
counts = adata.X.todense()
counts = torch.tensor(counts)

## Reading Blei's Paper first.

In [None]:
import pyro.distributions as dist

In [None]:
# Dirichlet Distributions
# Simplex of nonnegative real numbers that sum to one
dist.Dirichlet(torch.ones(10)).sample(), dist.Dirichlet(torch.ones(10)).sample().sum()

# for a random vector U, distributed as a dirichlet random variable on the K-simplex, with alpha > 0
# the mean of u is proportional to the parameters
torch.ones(10) / torch.ones(10).sum()

tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000])

### Chinese Restaurant Process

Just a demonstration of the generative process that is assumed to be underlying it.

Part of our inferential process is the idea that each of these tables has some beta parameter underlying it, drawn from a distribution G_0.

In [None]:
gamma = 1
tables = []

for i in tqdm(range(1_000_000)):
    if i == 0:
        tables.append([i])
    else:
        totals = np.array([len(x) for x in tables])
        probabilities = totals / (gamma + i)
        table_draw = np.random.multinomial(1, np.array(probabilities.tolist() + [(gamma / (gamma + i))])).argmax()
        if (table_draw + 1) > len(tables):
            tables.append([i])
        else:
            tables[table_draw].append(i)

100%|██████████| 1000000/1000000 [00:03<00:00, 294598.37it/s]


### Stick Breaking Construction

Its defined as breaking sticks along the [0, 1] domain recursively through draws of a Beta distribution.
$$
V_1 \sim \text{Beta}(\alpha, \beta)
$$


Which broadly define the following stick breaking procedure.
$$
\theta_i = V_i \prod_{j=1}^{i-1}(1-V_j)
$$





In [177]:
gamma = 2.
stick_breaking = dist.Beta(1., gamma).sample_n(100)
theta = stick_breaking * torch.cat((torch.ones(1), (1-stick_breaking).cumprod(dim=0)[:-1]))
theta

  stick_breaking = dist.Beta(1., gamma).sample_n(100)


tensor([3.0311e-01, 2.4515e-01, 1.5978e-01, 5.6117e-02, 1.2349e-02, 9.6796e-02,
        8.9983e-02, 5.3761e-03, 9.3138e-03, 1.8915e-02, 3.4388e-05, 1.4186e-04,
        4.7850e-04, 1.6749e-03, 1.5183e-04, 2.4349e-04, 1.3128e-04, 1.8304e-05,
        1.6316e-04, 3.5574e-05, 2.3442e-06, 1.9440e-05, 1.1627e-05, 4.1088e-06,
        2.0119e-07, 7.3713e-07, 9.1857e-09, 6.4655e-07, 4.8955e-08, 3.3182e-08,
        2.2598e-08, 1.4204e-08, 1.8048e-08, 1.7206e-08, 5.3102e-09, 2.7904e-10,
        5.0013e-10, 5.5598e-10, 1.2019e-09, 4.2880e-10, 1.2739e-10, 2.5140e-11,
        1.5899e-11, 3.4160e-11, 2.2752e-11, 6.1126e-11, 6.7262e-11, 1.6722e-11,
        1.3781e-11, 5.3673e-12, 1.0974e-12, 9.1572e-14, 4.7706e-13, 8.1917e-14,
        2.7165e-13, 1.2272e-13, 8.3815e-14, 3.5911e-14, 6.4449e-14, 3.2909e-14,
        8.2922e-14, 8.8392e-15, 2.7986e-14, 4.2185e-15, 7.5553e-16, 9.1296e-16,
        1.3687e-15, 1.2551e-15, 2.5918e-16, 4.3121e-16, 1.3772e-16, 1.1296e-16,
        1.1877e-17, 4.6307e-18, 7.1279e-

### 2.4 Connections

The major connection is that a paper by Pitman in 2002 shows that the way CRP and GEM partition distribution is the same.

We can sort of get at what G is by doing the following:

$$
G = \sum_{i=1}^{\infty}\theta_i\delta_{\beta_i}
$$

Our $\beta_i$ is an atom that is at location $\beta_i$

They define a special case of the GEM, out of convenience, which gives control over the mean and variance of the stick lengths

In [188]:
def adjusted_gem_distribution(m, pi, n_samples = 1):
    assert 0 < m < 1
    assert 0 < pi
    return dist.Beta(m*pi, (1-m)*pi).sample_n(n_samples)

adjusted_gem_distribution(.5, .5)

  return dist.Beta(m*pi, (1-m)*pi).sample_n(n_samples)


tensor([0.2242])

### 3. Nested Chinese Restaurant Process

- for each of the tables in the infinite tree:
    - draw a topic b_k ~ Diriclet(eta)

In [191]:
eta = torch.ones(10)
dist.Dirichlet(eta).sample_n(100)

  dist.Dirichlet(eta).sample_n(100)


tensor([[1.0161e-01, 1.4334e-01, 1.2585e-01, 3.3448e-02, 1.0561e-01, 3.5629e-02,
         4.7144e-02, 1.7794e-01, 2.2239e-01, 7.0455e-03],
        [7.7685e-02, 5.0154e-02, 1.4711e-01, 2.9614e-02, 2.0832e-01, 4.5049e-02,
         2.0489e-02, 1.8380e-01, 2.0290e-02, 2.1748e-01],
        [1.5905e-01, 7.2145e-02, 1.0704e-01, 1.7331e-01, 1.9187e-02, 3.2230e-02,
         7.9466e-02, 1.7980e-01, 6.5757e-02, 1.1201e-01],
        [4.6743e-02, 5.1934e-02, 3.9134e-02, 2.0828e-03, 4.7203e-01, 8.5971e-02,
         1.2111e-01, 3.6458e-02, 8.6001e-02, 5.8542e-02],
        [1.8867e-03, 2.9350e-03, 4.4882e-02, 4.8009e-02, 8.1955e-02, 5.5452e-01,
         4.4684e-02, 3.7808e-02, 4.0455e-02, 1.4286e-01],
        [1.4991e-01, 1.5454e-02, 6.0992e-03, 2.7044e-01, 5.6654e-03, 1.4033e-01,
         8.8273e-02, 1.3041e-02, 1.4376e-01, 1.6703e-01],
        [3.6220e-02, 1.2263e-03, 2.8848e-02, 1.9247e-01, 6.9255e-03, 5.7170e-02,
         8.9803e-02, 2.8273e-01, 5.8561e-02, 2.4605e-01],
        [1.1648e-01, 1.1843

In [238]:
import pyro
import pyro.nn as pnn
import torch.nn as nn

eta = torch.ones(5)

class TreeNode:
    def __init__(self, name=None, parent=None):
        super().__init__()
        self.name = name
        self.parent = parent
        self.children = []
        self.topic = None
        self.draw_topic(eta)
        self.documents = []

    def add_child(self, name):
        self.children.append(TreeNode(name="name", parent=self))

    def forward(self, *args, **kwargs):
        raise NotImplementedError("Override `forward` in subclasses.")
    
    def draw_topic(self, eta):
        self.topic = dist.Dirichlet(eta).sample()

## Setting up the tree.

In [184]:
import torch
import pyro
import pyro.distributions as dist
from collections import defaultdict
import random

class GibbsSampler:
    def __init__(self, root, alpha=1.0, max_depth=3):
        self.root = root
        self.alpha = alpha
        self.max_depth = max_depth
        self.path_registry = defaultdict(list)  # {index: [node1, node2, ...]}

    def remove_index(self, idx):
        for node in self.path_registry[idx]:
            node.remove_index(idx, self.path_registry)

            # If we have an empty tree, we now remove the sucker.
            if len(node.indices) == 0:
                del(node)
        del self.path_registry[idx]

    def sample_path(self, idx, counts=counts):
        current_node = self.root
        self.path_registry[idx] = []

        for depth in range(self.max_depth):
            children = list(current_node.children.values())
            num_children = len(children)

            # Gather CRP prior weights
            crp_log_weights = torch.tensor([
                torch.log(torch.tensor(len(child.indices), dtype=torch.float32) + 1e-8)  # log(n_k)
                for child in children
            ] + [torch.log(torch.tensor(self.alpha))])  # log(γ) for new node

            # Compute log-likelihoods
            log_liks = torch.tensor([
                dist.Categorical(child.params).log_prob(counts[idx]).sum()
                for child in children
            ] + [dist.Categorical(current_node.params).log_prob(counts[idx]).sum()])

            # Combine prior + likelihood
            log_probs = crp_log_weights + log_liks

            # Normalize and sample
            probs = torch.softmax(log_probs, dim=0)
            choice = dist.Categorical(probs).sample()

            if choice == num_children:
                name = f"node_{depth}_{len(children)}"
                child = current_node.add_child(name)
            else:
                child = children[choice]

            child.add_index(idx, self.path_registry)
            current_node = child


    def resample_all(self, indices):
        for idx in tqdm(indices):
            if idx in self.path_registry:
                self.remove_index(idx)
            self.sample_path(idx)

import pyro
import pyro.nn as pnn
import torch.nn as nn
from collections import defaultdict

# Giving counts is BAD, STINKY
class TreeNode:
    def __init__(self, name=None, parent=None, counts = counts):
        self.name = name
        self.parent = parent
        self.children = dict()
        self.indices = set()
        # TODO: Figure out this.
        self.params = None
        self.estimate_dirichlet_alpha(counts)

    def add_child(self, name):
        if name in self.children:
            raise ValueError(f"Child '{name}' already exists.")
        child = TreeNode(name=name, parent=self)
        self.children[name] = child
        return child

    def get_or_create_child(self, name):
        return self.children.get(name) or self.add_child(name)

    # counts here is also BAD BAD BAD
    def add_index(self, idx, path_registry, counts = counts):
        self.indices.add(idx)
        path_registry[idx].append(self)
        self.estimate_dirichlet_alpha(counts)

    def remove_index(self, idx, path_registry):
        self.indices.discard(idx)
        if self in path_registry[idx]:
            path_registry[idx].remove(self)

    def is_leaf(self):
        return len(self.children) == 0

    def estimate_dirichlet_alpha(self, counts):
        if len(self.indices) == 0:
            self.params = torch.ones(counts.shape[1])
        else:
            self.params = fit_dirichlet_multinomial(counts[list(self.indices)])

    def __repr__(self):
        return f"TreeNode(name={self.name}, num_indices={len(self.indices)}, num_children={len(self.children)})"


In [180]:
root = TreeNode("root", counts=counts)
sampler = GibbsSampler(root, alpha=1.0, max_depth=3)
root.params = alpha

In [181]:
from tqdm import tqdm
indices = list(range(1000))
for idx in tqdm(indices):
    sampler.sample_path(idx)

  0%|          | 0/1000 [00:00<?, ?it/s]

{0}
tensor([0.])
tensor([0.])





ValueError: not enough values to unpack (expected 2, got 1)

In [168]:
sampler.resample_all(indices)

  0%|          | 0/1000 [00:00<?, ?it/s]

tensor([0.])





ValueError: not enough values to unpack (expected 2, got 1)

In [169]:
root.children

{'node_0_0': TreeNode(name=node_0_0, num_indices=0, num_children=0),
 'node_0_1': TreeNode(name=node_0_1, num_indices=1, num_children=0)}

In [85]:
def get_max_depth(node):
    if not node.children:
        return 1
    return 1 + max(get_max_depth(child) for child in node.children.values())

get_max_depth(root)

4

In [86]:
root.children

{'node_0_0': TreeNode(name=node_0_0, num_indices=13, num_children=2),
 'node_0_1': TreeNode(name=node_0_1, num_indices=9, num_children=2),
 'node_0_2': TreeNode(name=node_0_2, num_indices=1, num_children=1),
 'node_0_3': TreeNode(name=node_0_3, num_indices=112, num_children=3),
 'node_0_4': TreeNode(name=node_0_4, num_indices=1, num_children=1),
 'node_0_5': TreeNode(name=node_0_5, num_indices=7, num_children=2),
 'node_0_6': TreeNode(name=node_0_6, num_indices=1, num_children=1),
 'node_0_7': TreeNode(name=node_0_7, num_indices=1, num_children=1),
 'node_0_8': TreeNode(name=node_0_8, num_indices=1, num_children=1),
 'node_0_9': TreeNode(name=node_0_9, num_indices=4, num_children=1),
 'node_0_10': TreeNode(name=node_0_10, num_indices=1, num_children=1),
 'node_0_11': TreeNode(name=node_0_11, num_indices=5, num_children=1),
 'node_0_12': TreeNode(name=node_0_12, num_indices=1, num_children=1),
 'node_0_13': TreeNode(name=node_0_13, num_indices=1, num_children=1),
 'node_0_14': TreeNode(