In [1]:
import numpy as np
import torch
from torchviz import make_dot, make_dot_from_trace
from IPython.core.debugger import set_trace

In [2]:
# Seed the NumPy and PyTorch global RNGs so repeated runs draw the same values.
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f8f043bb0d0>

In [3]:
def self_attention_before(n, w_1, w_2, x):
    """NumPy reference for a simplified multi-head self-attention forward pass.

    Scaling, masking, softmax and dropout are deliberately left out, so each
    head computes (Q K^T) V, and the concatenated heads are then projected
    with ``w_2``.

    Args:
        n: number of attention heads.
        w_1: fused Q/K/V projection weights, shape (h, 3h).
        w_2: output projection weights, shape (h, h).
        x: input activations, shape (s, b, h).

    Returns:
        Array of shape (s, b, h).
    """
    seq_len, batch, hidden = x.shape
    head_dim = hidden // n
    flat_heads = (seq_len, batch * n, head_dim)

    # Fused projection, then cut each head's slice into equal Q, K and V parts.
    fused = np.matmul(x, w_1).reshape(seq_len, batch, n, 3 * hidden // n)
    query, key, value = np.split(fused, 3, axis=-1)  # each (s, b, n, h/n)

    # Merge the batch and head axes, then bring the merged axis to the front.
    query = query.reshape(flat_heads).transpose(1, 0, 2)  # (bn, s, h/n)
    key = key.reshape(flat_heads).transpose(1, 2, 0)  # (bn, h/n, s)
    value = value.reshape(flat_heads).transpose(1, 0, 2)  # (bn, s, h/n)

    # Scale, mask, softmax and dropout are intentionally skipped here.
    scores = query @ key  # (bn, s, s)
    context = scores @ value  # (bn, s, h/n)

    # Same layout shuffle Megatron performs before the output projection.
    context = context.reshape(batch, n, seq_len, head_dim)
    context = context.transpose(2, 0, 1, 3)  # (s, b, n, h/n)
    merged = context.reshape(seq_len, batch, hidden)

    # RowParallelLinear
    return np.matmul(merged, w_2)  # (s, b, h)

In [4]:
def attention_pytorch(s=3, b=4, n=2, h=6, reference=None):
    """Check a PyTorch self-attention forward pass against a NumPy reference.

    Draws a random input and random weights, runs the simplified attention
    forward pass (no scaling, mask, softmax or dropout) in PyTorch, and asserts
    that the result matches the reference implementation.

    Args:
        s: sequence length.
        b: batch size.
        n: number of attention heads; must be positive and divide ``h``.
        h: hidden size.
        reference: callable ``(n, w_1, w_2, x) -> ndarray`` that takes NumPy
            arrays and returns the expected ``(s, b, h)`` output. Defaults to
            ``self_attention_before``.

    Returns:
        None.

    Raises:
        ValueError: if ``n`` is not positive or does not divide ``h``.
        AssertionError: if the PyTorch and reference outputs disagree.
    """
    if n <= 0 or h % n != 0:
        raise ValueError(
            "hidden size h=%r must be divisible by a positive head count n=%r"
            % (h, n)
        )
    if reference is None:
        reference = self_attention_before
    head_dim = h // n

    def attention(x, w_1, w_2):
        """Forward pass only; returns z with shape (s, b, h)."""
        x1 = torch.matmul(x, w_1)  # (s, b, 3h)

        x2 = torch.reshape(x1, (s, b, n, 3 * head_dim))
        # torch.split takes the *size* of each chunk (unlike np.split, which
        # takes the number of sections), so split on head_dim to get Q, K, V.
        q, k, v = torch.split(x2, head_dim, dim=3)  # (s, b, n, h/n)

        q = torch.reshape(q, (s, b * n, head_dim))
        q = q.permute(1, 0, 2)  # (bn, s, h/n)
        k = torch.reshape(k, (s, b * n, head_dim))
        k = k.permute(1, 2, 0)  # (bn, h/n, s)
        v = torch.reshape(v, (s, b * n, head_dim))
        v = v.permute(1, 0, 2)  # (bn, s, h/n)

        y1 = torch.matmul(q, k)  # (bn, s, s)

        # Ignoring a scale mask, softmax, and dropout here

        y2 = torch.matmul(y1, v)  # (bn, s, h/n)
        # Same layout changes Megatron performs before the output projection:
        y2 = torch.reshape(y2, (b, n, s, head_dim))
        y2 = y2.permute(2, 0, 1, 3)  # (s, b, n, h/n)
        y2 = torch.reshape(y2, (s, b, h))

        # RowParallelLinear
        return torch.matmul(y2, w_2)  # (s, b, h)

    # Draw order (x, then w_1, then w_2) is kept so a seeded run is unchanged.
    x = torch.randn((s, b, h))  # model input/output of previous layer
    w_1 = torch.randn((h, 3 * h), requires_grad=True)
    w_2 = torch.randn((h, h), requires_grad=True)
    z = attention(x, w_1, w_2)

    # Output magnitude grows with s and h, so pair the old absolute bound
    # (decimal=4, i.e. 1.5e-4) with a relative tolerance.
    np.testing.assert_allclose(
        z.detach().numpy(),
        reference(
            n, w_1.detach().numpy(), w_2.detach().numpy(), x.detach().numpy()
        ),
        rtol=1e-4,
        atol=1.5e-4,
    )

    # To export the forward pass:
    #     traced_attention = torch.jit.trace(attention, (x, w_1, w_2))
    #     torch.jit.save(traced_attention, "attention.pt")

In [5]:
# Run the PyTorch-vs-NumPy attention check; a mismatch raises AssertionError.
attention_pytorch()