In [2]:
import torch
import numpy as np

In [68]:
def forward(X, A, V, mask):
    M = (A[..., :, None] - A[..., None, :]).masked_fill(~mask, -torch.inf).exp()
    O = M * X
    return O @ V

def manual_grad(X, A, V, mask):
    X = X.clone().detach()
    A = A.clone().detach().requires_grad_()
    V = V.clone().detach()
    mask = mask.clone().detach()
    out = forward(X, A, V, mask)
    out.sum().backward()
    return A.grad

class Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, A, V, mask):
        ctx.save_for_backward(X, A, V, mask)
        return forward(X, A, V, mask)

    @staticmethod
    def backward(ctx, prev_grad):
        X, A, V, mask = ctx.saved_tensors
        
        M = (A[..., :, None] - A[..., None, :]).masked_fill(~mask, -torch.inf).exp()
        da = prev_grad @ V.mT
        vals = da * M * X
        
        A_grad = vals.sum(-1) - vals.sum(-2)
        
        return None, A_grad, None, None

In [69]:
N = 4
H = 4
S = 1024
d = 64

X = torch.randn(N, H, S, S).cuda().detach()
A = (-torch.nn.functional.softplus(torch.randn(N, H, S) * 0.1).cuda().cumsum(-1)).detach().requires_grad_()
V = torch.randn(N, H, S, d).cuda().detach()
mask = torch.tril(torch.ones(S, S)).bool()[None, None, ...].cuda().detach()

In [70]:
A_grad = manual_grad(X, A, V, mask)
A_grad

tensor([[[  1.1771,  -2.4842,  -0.8229,  ...,  -2.9013,  -8.1079,   5.7931],
         [  0.7770,   0.4071,   6.8530,  ...,   3.4741,  -8.4213,   3.5603],
         [ -4.3483,   1.9367,   2.1912,  ..., -14.8011,  12.6799,  -0.6350],
         [-18.0781,  15.1313,   4.7784,  ...,  10.2108,  -9.6315,  15.0513]],

        [[ -3.0058,   3.5219,  -0.3513,  ...,  -0.0805,  -1.3040,   2.3503],
         [ -2.6475,   2.8694,   0.9606,  ...,  -0.8401,   1.6273,  -0.9460],
         [  1.6462,   0.6906,  -2.9011,  ...,   4.2723,  13.6414, -14.4975],
         [ -1.9470,   1.1767,   0.8330,  ...,  -6.4973,   7.4139,   2.7352]],

        [[  1.9239,   7.2508,  -9.1321,  ...,   6.5867,  -0.0471,   2.4892],
         [  2.3440,  -5.6813,   4.9629,  ...,  -3.3493,  -2.7129,   1.0615],
         [ -2.6036,   0.2488,   0.3786,  ...,   1.3992,  -2.6434,   0.8331],
         [  3.0142,  -0.6237,  -0.4436,  ...,  -0.0514,  -1.9302,  -1.3114]],

        [[  0.5355,  -4.7543,   2.7631,  ...,  -1.5342,  10.6232, -10.

In [71]:
out = Function.apply(X, A, V, mask)

In [72]:
out.sum().backward()

In [73]:
A_grad_ = A.grad
A_grad_

tensor([[[  1.1771,  -2.4842,  -0.8229,  ...,  -2.9013,  -8.1079,   5.7931],
         [  0.7770,   0.4071,   6.8530,  ...,   3.4741,  -8.4213,   3.5603],
         [ -4.3483,   1.9367,   2.1912,  ..., -14.8011,  12.6799,  -0.6350],
         [-18.0781,  15.1313,   4.7784,  ...,  10.2108,  -9.6315,  15.0513]],

        [[ -3.0058,   3.5219,  -0.3513,  ...,  -0.0805,  -1.3040,   2.3503],
         [ -2.6475,   2.8694,   0.9606,  ...,  -0.8401,   1.6273,  -0.9460],
         [  1.6462,   0.6906,  -2.9011,  ...,   4.2723,  13.6414, -14.4975],
         [ -1.9471,   1.1767,   0.8330,  ...,  -6.4973,   7.4139,   2.7352]],

        [[  1.9239,   7.2508,  -9.1321,  ...,   6.5867,  -0.0471,   2.4892],
         [  2.3440,  -5.6813,   4.9629,  ...,  -3.3493,  -2.7129,   1.0615],
         [ -2.6036,   0.2488,   0.3786,  ...,   1.3992,  -2.6434,   0.8331],
         [  3.0142,  -0.6237,  -0.4436,  ...,  -0.0514,  -1.9302,  -1.3114]],

        [[  0.5355,  -4.7543,   2.7631,  ...,  -1.5342,  10.6232, -10.

In [76]:
(A_grad - A_grad_).abs().max()

tensor(7.6294e-06, device='cuda:0')

In [75]:
torch.autograd.gradcheck(Function.apply, (X, A, V, mask), eps=1e-4)



OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 GiB. GPU 0 has a total capacity of 79.25 GiB of which 13.55 GiB is free. Process 1737049 has 988.00 MiB memory in use. Including non-PyTorch memory, this process has 64.72 GiB memory in use. Of the allocated memory 64.16 GiB is allocated by PyTorch, and 66.44 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)