In [15]:
import torch

In [16]:
# Problem sizes: a is (batch_size, n, r) and b is (batch_size, r, m), so a @ b is (batch_size, n, m).
batch_size, n, m, r = 2, 3, 4, 5

In [17]:
# Seed the RNG so Restart & Run All reproduces the same inputs (and the same pass/fail).
torch.manual_seed(0)
a = torch.randn(batch_size, n, r)  # (batch_size, n, r)
b = torch.randn(batch_size, r, m)  # (batch_size, r, m)

In [18]:
# Make every entry non-negative: log() of a negative number is NaN, log(0) is -inf.
a = a.abs()
b = b.abs()

In [19]:
b[0, 0, 2] = 0  # plant an exact zero so log(b) contains -inf and the log-space path must handle it

In [20]:
expected = a.bmm(b)  # linear-space reference result, shape (batch_size, n, m)

In [21]:
def log_bmm2(log_A, log_B):
    """
    Performs a batch matrix multiplication in log space.

    Computes ``log(exp(log_A) @ exp(log_B))`` without leaving log space: the
    products become sums of logs and the reduction over the shared dimension
    is done with ``torch.logsumexp``. Zero entries of A or B (``log == -inf``)
    are therefore allowed.

    Args:
        log_A: A tensor of shape (..., m, n) representing log(A). The usual
            case is (b, m, n). Any number of leading batch dims (including
            none) is accepted as long as they broadcast against log_B's.
        log_B: A tensor of shape (..., n, p) representing log(B).

    Returns:
        A tensor of shape (..., m, p) representing log(A @ B).

    Raises:
        ValueError: if either input has fewer than 2 dimensions.
        RuntimeError: if the inner dimensions (n) of log_A and log_B differ
            (same exception type torch.bmm raises for this case).

    Note:
        Materializes an intermediate of shape (..., m, n, p), so memory grows
        with m * n * p per batch element.
    """
    if log_A.dim() < 2 or log_B.dim() < 2:
        raise ValueError(
            "log_bmm2 expects inputs with at least 2 dims, got shapes "
            f"{tuple(log_A.shape)} and {tuple(log_B.shape)}"
        )
    if log_A.shape[-1] != log_B.shape[-2]:
        # Must be checked explicitly: if one side has n == 1, the broadcasting
        # addition below would silently succeed and return a wrong result.
        raise RuntimeError(
            "log_bmm2 inner dimensions differ: "
            f"{tuple(log_A.shape)} vs {tuple(log_B.shape)}"
        )

    # (..., m, n, 1) + (..., 1, n, p) -> (..., m, n, p):
    # log(A[i, k] * B[k, j]) == log_A[i, k] + log_B[k, j]
    log_product = log_A.unsqueeze(-1) + log_B.unsqueeze(-3)

    # Summing over k in linear space == logsumexp over the n axis in log space.
    return torch.logsumexp(log_product, dim=-2)  # (..., m, p)

In [22]:
got = torch.exp(log_bmm2(a.log(), b.log()))  # back to linear space for comparison with `expected`

In [23]:
got.size()

torch.Size([2, 3, 4])

In [24]:
expected.size()

torch.Size([2, 3, 4])

In [25]:
# torch.allclose broadcasts its inputs, so a wrongly-shaped result could still
# compare "close" -- check the shapes explicitly before the values.
assert got.shape == expected.shape, f"shape mismatch: {tuple(got.shape)} vs {tuple(expected.shape)}"
assert torch.allclose(got, expected), f"max abs error {(got - expected).abs().max().item():.3e}"