In [1]:
import torch
import numpy as np

In [None]:
def forward(QK, M, V):
    tmp = QK**2 * torch.exp(M)
    A = tmp / tmp.sum(-1, keepdim=True)
    return A @ V

def manual_grad(QK, M, V):
    QK_ = QK.clone().detach().requires_grad_()
    M_ = M.clone().detach().requires_grad_()
    V_ = V.clone().detach().requires_grad_()
    out = forward(QK_, M_, V_)
    out.sum().backward()
    return QK_.grad, M_.grad, V_.grad

class Squaremax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, QK, M, V):
        out = forward(QK, M, V)
        ctx.save_for_backward(QK, M, V, out)
        return out

    @staticmethod
    def backward(ctx, prev_grad):
        QK, M, V, out = ctx.saved_tensors
        M_ = torch.exp(M)
        tmp = QK**2 * M_
        out = tmp / tmp.sum(-1, keepdim=True)
        
        do = prev_grad @ V.mT
        
        # Z = sum_j x_j^2 along the last dimension
        Z = (tmp).sum(dim=-1, keepdim=True)

        # Using J^T v = (2/Z) * x * ( v - (vÂ·y) ), where y = out
        v_dot_y = (do * out).sum(dim=-1, keepdim=True)
        grad_tmp = (1 / Z) * M_ * (do - v_dot_y)
        grad_QK = 2 * QK * grad_tmp
        
        # Due to the exponential, we keep M_ in grad_tmp. Without
        # it, we would remvoe it for the M grad.
        grad_M = grad_tmp * QK**2

        return grad_QK, grad_M, None

In [10]:
dim1 = 256
dim2 = 256
QK = torch.randn(dim1, dim2).cuda().double().detach().requires_grad_()
mask = torch.tril(torch.ones(dim1, dim2)).bool().cuda().detach()
M = (torch.randn(dim1, dim2).cuda().double() * mask).detach().requires_grad_()
V = torch.randn(dim2, 128).cuda().double().detach().requires_grad_()

In [11]:
QK_grad, M_grad, V_grad = manual_grad(QK, M, V)
QK_grad

tensor([[-0.0547,  0.0172, -0.2347,  ..., -0.1243,  0.0523, -0.0247],
        [ 0.0791, -0.1117, -0.0240,  ...,  0.0339, -0.0589,  0.0626],
        [ 0.0352,  0.0638,  0.0155,  ...,  0.2923,  0.0679,  0.0813],
        ...,
        [-0.0088, -0.0155,  0.0156,  ...,  0.0343,  0.0423,  0.0812],
        [ 0.0101, -0.0179,  0.0101,  ..., -0.1450,  0.0500,  0.0641],
        [-0.0125,  0.0020,  0.0180,  ...,  0.1526, -0.0952, -0.0195]],
       device='cuda:0', dtype=torch.float64)

In [12]:
out = Squaremax.apply(QK, M, V)

In [13]:
out.sum().backward()

In [14]:
QK_grad_ = QK.grad
QK_grad_

tensor([[-0.0547,  0.0172, -0.2347,  ..., -0.1243,  0.0523, -0.0247],
        [ 0.0791, -0.1117, -0.0240,  ...,  0.0339, -0.0589,  0.0626],
        [ 0.0352,  0.0638,  0.0155,  ...,  0.2923,  0.0679,  0.0813],
        ...,
        [-0.0088, -0.0155,  0.0156,  ...,  0.0343,  0.0423,  0.0812],
        [ 0.0101, -0.0179,  0.0101,  ..., -0.1450,  0.0500,  0.0641],
        [-0.0125,  0.0020,  0.0180,  ...,  0.1526, -0.0952, -0.0195]],
       device='cuda:0', dtype=torch.float64)

In [15]:
(QK_grad - QK_grad_).abs().max()

tensor(8.8818e-16, device='cuda:0', dtype=torch.float64)

In [9]:
M_grad

tensor([[-1.3695e-02, -2.4378e-02, -1.1753e-02,  ..., -5.2944e-03,
         -1.6745e-01, -5.7915e-02],
        [-1.8094e-02, -9.2249e-02, -1.4489e-03,  ...,  4.9065e-04,
         -5.6104e-01, -7.7263e-02],
        [-1.8147e-03, -2.4071e-03, -1.2420e-02,  ..., -3.7256e-04,
         -6.1703e-03, -9.8866e-03],
        ...,
        [-5.5734e-03, -4.5801e-04, -2.6770e-04,  ...,  6.0660e-04,
         -1.6904e-02, -3.3013e-02],
        [-4.3928e-06, -6.2730e-03, -1.3781e-02,  ..., -3.0220e-02,
         -2.1008e-03, -8.8513e-03],
        [-3.2837e-02, -4.8665e-03, -2.8812e-04,  ..., -1.0095e-03,
         -5.0367e-02, -2.9068e-04]], device='cuda:0', dtype=torch.float64)

In [10]:
M_grad_ = M.grad
M_grad_

tensor([[-1.3695e-02, -2.4378e-02, -1.1753e-02,  ..., -5.2944e-03,
         -1.6745e-01, -5.7915e-02],
        [-1.8094e-02, -9.2249e-02, -1.4489e-03,  ...,  4.9065e-04,
         -5.6104e-01, -7.7263e-02],
        [-1.8147e-03, -2.4071e-03, -1.2420e-02,  ..., -3.7256e-04,
         -6.1703e-03, -9.8866e-03],
        ...,
        [-5.5734e-03, -4.5801e-04, -2.6770e-04,  ...,  6.0660e-04,
         -1.6904e-02, -3.3013e-02],
        [-4.3928e-06, -6.2730e-03, -1.3781e-02,  ..., -3.0220e-02,
         -2.1008e-03, -8.8513e-03],
        [-3.2837e-02, -4.8665e-03, -2.8812e-04,  ..., -1.0095e-03,
         -5.0367e-02, -2.9068e-04]], device='cuda:0', dtype=torch.float64)

In [11]:
(M_grad - M_grad_).abs().max()

tensor(8.8818e-16, device='cuda:0', dtype=torch.float64)

In [75]:
torch.autograd.gradcheck(Function.apply, (X, A, V, mask), eps=1e-4)



OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 GiB. GPU 0 has a total capacity of 79.25 GiB of which 13.55 GiB is free. Process 1737049 has 988.00 MiB memory in use. Including non-PyTorch memory, this process has 64.72 GiB memory in use. Of the allocated memory 64.16 GiB is allocated by PyTorch, and 66.44 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)