In [3]:
import itertools
import torch

In [None]:

def enumerate_paths(n, n_branches):
    """
    Enumerate every root-to-depth-n decision path of a perfect k-ary tree.

    The tree is stored in "heap order":
       - The root is at index 0.
       - For a node at index i, its b-th child (b in {0, ..., n_branches-1})
         is at index: n_branches * i + (b + 1).

    Paths are listed in the same order as
    ``itertools.product(range(n_branches), repeat=n)``: lexicographic, with the
    last decision varying fastest.

    Parameters:
      n (int): Number of decision layers (path length), >= 0.
      n_branches (int): Number of children per node, >= 1 (1 gives a chain).

    Returns:
      node_paths (torch.LongTensor): shape (n_branches**n, n). Entry [p, d] is the
          heap index of the node where path p makes its d-th decision.
      branch_paths (torch.LongTensor): shape (n_branches**n, n). Entry [p, d] is
          the branch chosen at that node.

    Raises:
      ValueError: if n < 0 or n_branches < 1.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if n_branches < 1:
        raise ValueError(f"n_branches must be >= 1, got {n_branches}")

    n_paths = n_branches ** n
    path_ids = torch.arange(n_paths, dtype=torch.long)

    # Path p's branch sequence is the base-n_branches representation of p,
    # most-significant digit first. This reproduces itertools.product order
    # without materialising Python tuples.
    place_values = torch.tensor(
        [n_branches ** (n - 1 - depth) for depth in range(n)], dtype=torch.long
    )
    # (n_paths, 1) // (n,) broadcasts to (n_paths, n). For n == 0 this yields (1, 0).
    branch_paths = (path_ids.unsqueeze(1) // place_values) % n_branches

    # The first decision is always at the root (index 0). Each later decision
    # node is the child selected by the previous branch choice.
    node_paths = torch.zeros((n_paths, n), dtype=torch.long)
    for depth in range(1, n):
        node_paths[:, depth] = (
            n_branches * node_paths[:, depth - 1] + branch_paths[:, depth - 1] + 1
        )

    return node_paths, branch_paths

# Example usage: enumerate every path of a small tree and show both index tables.
if __name__ == "__main__":
    # 3 decision layers in a ternary tree (you can set n_branches to any integer >= 2).
    n, n_branches = 3, 3

    node_ids, branch_ids = enumerate_paths(n, n_branches)
    print("Node indices:\n", node_ids)
    print("Branch choices:\n", branch_ids)


In [10]:
def estimate_memory(n_branch, n_layers, dtype=torch.int8):
    """
    Estimate the memory of ONE (n_paths, n_layers) path-index tensor.

    A perfect tree with ``n_branch`` children per node and ``n_layers`` decision
    layers has ``n_branch ** n_layers`` root-to-leaf paths. Storing one index per
    decision therefore takes ``n_branch ** n_layers * n_layers`` elements. This
    is the number of index entries, not the number of experts.

    Parameters:
      n_branch (int): Branches per node.
      n_layers (int): Decision layers (path length).
      dtype (torch.dtype): Element type used for the estimate. NOTE: int8 holds
          values only up to 127. That fits branch choices but generally not node
          indices, and `enumerate_paths` in this notebook builds int64 tensors
          (8 bytes per element, and it returns two of them).

    Returns:
      str: The size in MiB (1024**2 bytes), formatted as "<x.xx> MB".

    Side effects:
      Prints the path and index-element counts.
    """
    n_paths = n_branch ** n_layers
    elements = n_paths * n_layers
    bytes_per_element = torch.tensor(0, dtype=dtype).element_size()
    print(f'number of index elements ({n_paths} paths x {n_layers} layers): {elements}')
    return f"{(elements * bytes_per_element)/(1024**2):.2f} MB"
estimate_memory(3,12)

number of experts: 6377292


'6.08 MB'

In [10]:
def select_probabilities(prob_tensor, node_paths, branch_paths):
    """
    Gather, for every token, the probability of each decision along each path.

    Parameters:
      prob_tensor (torch.Tensor): Indexed as prob_tensor[token, node, branch].
          The demo cell builds it with shape (n_tokens, n_nodes, n_branches).
      node_paths (torch.LongTensor): Node index of every decision, e.g.
          (n_paths, n_layers).
      branch_paths (torch.LongTensor): Branch chosen at every decision. It must
          broadcast with node_paths.

    Returns:
      torch.Tensor: Entry [t, p, l] equals
      prob_tensor[t, node_paths[p, l], branch_paths[p, l]]. The shape is
      (n_tokens, *broadcast_shape(node_paths, branch_paths)).
    """
    every_token = slice(None)  # same as ':' -- keep the whole token axis
    gathered = prob_tensor[every_token, node_paths, branch_paths]
    return gathered

In [None]:
if __name__ == '__main__':
    # Parameters for the tree.
    n_layers = 3        # number of decision layers (path length)
    n_branches = 2      # number of branches per node (2 -> binary tree)
    n_all_paths = n_branches ** n_layers

    # Get the index tensors from enumerate_paths.
    node_paths, branch_paths = enumerate_paths(n_layers, n_branches)
    print("node_paths shape:", node_paths.shape)
    print("branch_paths shape:", branch_paths.shape)
    assert node_paths.shape == (n_all_paths, n_layers)
    assert branch_paths.shape == (n_all_paths, n_layers)

    # Create an example probability tensor for n_tokens tokens.
    n_tokens = 1
    # The node axis must cover every node index that appears in node_paths.
    n_nodes = node_paths.max().item() + 1
    # Shape (n_tokens, n_nodes, n_branches). Unseeded, so the values change each run.
    prob_tensor = torch.rand(n_tokens, n_nodes, n_branches)

    # Flat expert id for each (node, branch) decision: every node owns
    # n_branches consecutive expert ids, numbered in heap order.
    expert_indices = node_paths * n_branches + branch_paths

    # Use the select_probabilities function to get the tensor of selected probabilities.
    selected_probs = select_probabilities(prob_tensor, node_paths, branch_paths)
    print("selected_probs shape:", selected_probs.shape)
    assert selected_probs.shape == (n_tokens, n_all_paths, n_layers)
    print(expert_indices)

In [None]:
import torch
import torch.nn.functional as F

def compute_token_expert_weights(expert_indices, rel_weights, n_experts):
    """
    Compute per-token weights for each expert.

    For each token, the weights of all routing decisions (over every path and
    layer) that selected the same expert are summed. This uses a scatter-add
    instead of a one-hot product. The one-hot product would materialise an
    (n_tokens, n_paths, n_layers, n_experts) tensor, which is infeasible for
    trees with many experts.

    Parameters:
      expert_indices (torch.Tensor): Integer tensor, shape (n_tokens, n_paths, n_layers).
          Each element is the index (in [0, n_experts)) of the expert chosen.
          Any integer dtype is accepted.
      rel_weights (torch.Tensor): Shape (n_tokens, n_paths, n_layers), or any
          shape that broadcasts with expert_indices. It holds the relative
          weight of each routing decision.
      n_experts (int): Total number of experts.

    Returns:
      token_expert_weights (torch.Tensor): Shape (n_tokens, n_experts). The dtype
          is promote_types(int64, rel_weights.dtype), so float weights keep
          their float dtype.

    Raises:
      RuntimeError: If an index lies outside [0, n_experts). On CPU this comes
          from torch's scatter bounds check.
    """
    # Broadcast the two inputs against each other, as the element-wise product
    # in a one-hot formulation would.
    indices, weights = torch.broadcast_tensors(expert_indices, rel_weights)
    n_tokens = indices.shape[0]

    # Flatten all routing decisions of a token into a single axis.
    # flatten(1) (unlike reshape(n_tokens, -1)) also handles n_tokens == 0.
    flat_indices = indices.flatten(start_dim=1).to(torch.long)

    # Same result dtype as (int64 one-hot) * weights.
    out_dtype = torch.promote_types(torch.long, weights.dtype)
    flat_weights = weights.flatten(start_dim=1).to(out_dtype)

    token_expert_weights = torch.zeros(
        (n_tokens, n_experts), dtype=out_dtype, device=weights.device
    )
    # out[t, flat_indices[t, j]] += flat_weights[t, j]
    token_expert_weights.scatter_add_(1, flat_indices, flat_weights)

    return token_expert_weights

# -----------------------------------------------------------------
# Example usage: random routing for a tiny batch, then aggregate it.
# -----------------------------------------------------------------
if __name__ == "__main__":
    # Toy sizes, small enough to read the printed result by eye.
    n_tokens, n_paths, n_layers, n_experts = 2, 2, 2, 2

    # One expert id in [0, n_experts) per (token, path, layer) decision...
    expert_indices = torch.randint(0, n_experts, (n_tokens, n_paths, n_layers))
    # ...and a matching relative weight drawn from U[0, 1).
    rel_weights = torch.rand(n_tokens, n_paths, n_layers)

    token_expert_weights = compute_token_expert_weights(expert_indices, rel_weights, n_experts)
    print("Token-Expert Weights Shape:", token_expert_weights.shape)
    print(token_expert_weights)


In [None]:
# NOTE(review): this displays whichever `expert_indices` was assigned last. On a
# top-to-bottom run, that is the random torch.randint tensor from the example
# cell above. It has overwritten the one built from node_paths/branch_paths in
# the earlier demo cell.
expert_indices

In [None]:
# Display the relative routing weights drawn with torch.rand in the example
# cell above. No seed is set, so the values differ on every run.
rel_weights

In [None]:
import math

def get_expert_layer(index, branch_size):
    """
    Return the 1-based layer of the expert with the given 0-based index.

    For branch_size n > 1, layer L contains n^L experts, and the cumulative
    number of experts up to and including layer L is

        T(L) = n + n^2 + ... + n^L = n * (n^L - 1) / (n - 1)

    The expert's layer is the smallest L with index < T(L). This is computed
    by accumulating T(L) with exact integer arithmetic. The equivalent closed
    form floor(log_n(1 + (n - 1) * index / n)) + 1 is not used, because
    floating-point logarithms are inexact at exact powers:
    math.log(243, 3) == 4.999999999999999, which would put expert 363 (the
    first expert of layer 6 when n = 3) in layer 5.

    If branch_size == 1, the tree is a chain, and the expert at index i is in
    layer i + 1.

    Parameters:
      index (int): 0-based expert index, >= 0.
      branch_size (int): Branches per node, >= 1.

    Returns:
      int: The 1-based layer number.

    Raises:
      ValueError: If index < 0 or branch_size < 1.
    """
    if index < 0:
        raise ValueError(f"index must be non-negative, got {index}")
    if branch_size < 1:
        raise ValueError(f"branch_size must be >= 1, got {branch_size}")
    if branch_size == 1:
        return index + 1  # Each layer has 1 expert in this special case.

    layer = 1
    layer_size = branch_size   # n^L experts in the current layer L
    cumulative = branch_size   # T(L): experts in layers 1..L
    # Loop runs O(log_n(index)) times.
    while index >= cumulative:
        layer += 1
        layer_size *= branch_size
        cumulative += layer_size
    return layer

# Example usage: print the layer of the first few experts.
if __name__ == "__main__":
    # branch_size = 1 is the degenerate chain (one expert per layer), so expert i
    # lands in layer i + 1. Set it to e.g. 3 to see the k-ary tree layout.
    branch_size = 1
    test_indices = list(range(17))  # experts 0..16

    for i in test_indices:
        layer = get_expert_layer(i, branch_size)
        print(f"Expert with index {i} is in layer {layer}.")