In [1]:
import torch
from torch.distributions.multivariate_normal import _batch_trtrs_lower

In [2]:
def _batch_mahalanobis(bL, bx, fixed=False):
    """Sum of squares of `_batch_trtrs_lower(x, L)` for every vector in `bx`.

    Presumably this is the squared Mahalanobis distance of `bx` under the
    lower-triangular factor `bL` -- confirm against `_batch_trtrs_lower`,
    which is defined outside this notebook.

    Two ways of lining up the batch dimensions are available:

    * ``fixed=False``: `bL` is expanded so that there is one ``n x n`` matrix
      for every batch entry of `bx`, and both are then flattened.
    * ``fixed=True``: `bL` is left untouched. Each trailing batch dimension of
      `bx` (size ``S``) is split against the matching batch dimension of `bL`
      (size ``s``) into ``(S // s, s)`` and the axes are permuted so that all
      the ``s`` axes sit together. The reshape only succeeds when ``S`` is a
      multiple of ``s``.

    Args:
        bL: tensor whose last two dimensions are ``(n, n)``.
        bx: tensor whose last dimension is ``n``.
        fixed: selects the reshape/permute strategy instead of expanding `bL`.

    Returns:
        Tensor with the shape of ``bx.shape[:-1]`` as passed in.
    """
    n = bx.size(-1)
    if not fixed:
        # Baseline strategy: broadcast bL up to the batch shape of bx.
        lead = bx.dim() - bL.dim() + 1
        bL = bL.expand(bx.shape[lead:] + (n,))
    else:
        # assume bL.shape = (i, j, n, n), bx.shape = (K, I, J, n)
        bx_batch_shape = bx.shape[:-1]
        bL_batch_dims = bL.dim() - 2
        outer_batch_dims = len(bx_batch_shape) - bL_batch_dims
        old_batch_dims = outer_batch_dims + bL_batch_dims
        new_batch_dims = outer_batch_dims + 2 * bL_batch_dims

        # Step 1: reshape bx to (K, I/i, i, J/j, j, n).
        split_shape = list(bx.shape[:outer_batch_dims])
        for l_size, x_size in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
            split_shape.extend((x_size // l_size, l_size))
        split_shape.append(n)
        bx = bx.reshape(split_shape)

        # Step 2: permute to (K, I/i, J/j, i, j, n).
        outer_axes = list(range(outer_batch_dims))
        quotient_axes = list(range(outer_batch_dims, new_batch_dims, 2))
        matched_axes = list(range(outer_batch_dims + 1, new_batch_dims, 2))
        bx = bx.permute(outer_axes + quotient_axes + matched_axes + [new_batch_dims])

    flat_L = bL.reshape(-1, n, n)  # b x n x n
    # c x b x n, then swapped to b x n x c so each matrix sees its c vectors as columns
    flat_x_swap = bx.reshape(-1, flat_L.size(0), n).permute(1, 2, 0)
    solved = _batch_trtrs_lower(flat_x_swap, flat_L)
    M = solved.pow(2).sum(-2).t()  # b x c, transposed to c x b

    if not fixed:
        return M.reshape(bx.shape[:-1])

    # Undo the permutation: interleave each quotient axis with its matched
    # axis again, then merge every pair back into the original batch shape.
    inverse_axes = list(range(outer_batch_dims))
    for offset in range(bL_batch_dims):
        inverse_axes.extend((outer_batch_dims + offset, old_batch_dims + offset))
    return M.reshape(bx.shape[:-1]).permute(inverse_axes).reshape(bx_batch_shape)

In [3]:
# Seed the global RNG so bx -- and every comparison derived from it below --
# is identical on each Restart & Run All.
torch.manual_seed(0)
# Identity matrices with batch shape (6, 1) and event size 8; gradients with
# respect to bL are requested by the forward+backward cells further down.
bL = torch.eye(8).expand(6, 1, 8, 8).contiguous().requires_grad_()
# Random vectors with batch shape (6, 8000) and event size 8.
bx = torch.randn(6, 8000, 8)

### compare forward speed

In [4]:
%%time
# Forward pass through the default fixed=False path, which expands bL to the batch shape of bx.
a = _batch_mahalanobis(bL, bx)

CPU times: user 2.72 s, sys: 47.7 ms, total: 2.77 s
Wall time: 586 ms


In [5]:
%%time
# Forward pass through the fixed=True path, which reshapes/permutes bx and leaves bL unexpanded.
b = _batch_mahalanobis(bL, bx, fixed=True)

CPU times: user 46.5 ms, sys: 0 ns, total: 46.5 ms
Wall time: 2.61 ms


### compare forward+backward speed

In [6]:
%%time
# Forward + backward: gradient of the summed fixed=False output with respect to bL.
c = torch.autograd.grad(_batch_mahalanobis(bL, bx).sum(), (bL,))[0]

CPU times: user 4min 28s, sys: 11min 39s, total: 16min 7s
Wall time: 1min 20s


In [7]:
%%time
# Forward + backward: gradient of the summed fixed=True output with respect to bL.
d = torch.autograd.grad(_batch_mahalanobis(bL, bx, fixed=True).sum(), (bL,))[0]

CPU times: user 91.5 ms, sys: 220 ms, total: 312 ms
Wall time: 15.6 ms


### verify consistency

In [8]:
# Exact element-wise comparison (no tolerance) of the two forward results.
torch.eq(a, b).all()

tensor(1, dtype=torch.uint8)

In [9]:
# Exact element-wise comparison (no tolerance) of the two gradients with respect to bL.
torch.eq(c, d).all()

tensor(1, dtype=torch.uint8)

### 0.4 version (legacy implementation based on `torch.inverse`)

In [10]:
def _legacy_batch_mahalanobis(L, x):
    flat_L = L.unsqueeze(0).reshape((-1,) + L.shape[-2:])
    L_inv = torch.stack([torch.inverse(Li.t()) for Li in flat_L]).view(L.shape)
    return (x.unsqueeze(-1) * L_inv).sum(-2).pow(2.0).sum(-1)

In [11]:
%%time
# Forward pass of the inverse-based legacy implementation on the current bL, bx.
e = _legacy_batch_mahalanobis(bL, bx)

CPU times: user 26.3 ms, sys: 109 ms, total: 135 ms
Wall time: 6.3 ms


In [12]:
%%timeit
# Forward + backward of the legacy implementation: gradient of the summed output with respect to bL.
f = torch.autograd.grad(_legacy_batch_mahalanobis(bL, bx).sum(), (bL,))[0]

7.73 ms ± 219 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### inverse vs trtrs with event_dim=2

In [13]:
# Event dimension for this comparison, named once instead of repeating the
# literal in every shape (same pattern as the later setup cell that uses N).
N = 2
# Identity matrices with batch shape (6, 1); gradients with respect to bL are
# taken by the timing cells that follow.
bL = torch.eye(N).expand(6, 1, N, N).contiguous().requires_grad_()
# Random vectors with batch shape (6, 8000) and event size N.
bx = torch.randn(6, 8000, N)

In [14]:
%%timeit
# Forward + backward through the fixed=True path: gradient of the summed output with respect to bL.
g = torch.autograd.grad(_batch_mahalanobis(bL, bx, fixed=True).sum(), (bL,))[0]

5.92 ms ± 24.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
%%timeit
# Forward + backward through the inverse-based legacy implementation on the same bL, bx.
h = torch.autograd.grad(_legacy_batch_mahalanobis(bL, bx).sum(), (bL,))[0]

2.46 ms ± 12.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### inverse vs trtrs with event_dim=8

In [28]:
# Event dimension shared by both tensors built below.
N = 8
# Identity matrices with batch shape (6, 1), flagged so autograd tracks them.
bL = torch.eye(N).expand(6, 1, N, N).contiguous()
bL.requires_grad_()
# Random vectors with batch shape (6, 8000) and event size N.
bx = torch.randn(6, 8000, N)

In [29]:
%%timeit
# Forward + backward through the fixed=True path with the bL, bx built from N above.
g = torch.autograd.grad(_batch_mahalanobis(bL, bx, fixed=True).sum(), (bL,))[0]

6.34 ms ± 69.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [30]:
%%timeit
# Forward + backward through the inverse-based legacy implementation with the bL, bx built from N above.
h = torch.autograd.grad(_legacy_batch_mahalanobis(bL, bx).sum(), (bL,))[0]

6.03 ms ± 96 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
