In [None]:
from typing import Optional

import numpy as np

import compyute as cp
from compyute.nn import Container, ParallelConcat, Linear, Sequential, Dropout, ReLU, SkipConnection, Layernorm, Embedding
from compyute.nn.functional import softmax
from compyute.dtypes import _DtypeLike
from compyute.base_tensor import Tensor


class MultiHeadAttention(Sequential):
    """Multi-head self-attention assembled from independent ``AttentionHead`` modules.

    The layer is a ``Sequential`` of:
      1. ``n_heads`` attention heads wrapped in a ``ParallelConcat`` (label "Heads"),
      2. a linear output projection ``embedding_dim -> embedding_dim`` (label "OutProjection"),
      3. an optional ``Dropout`` layer when ``dropout`` is not ``None``.

    Args:
        embedding_dim: Number of input/output features. Must be divisible by
            ``n_heads`` so that the ``n_heads`` heads of size
            ``embedding_dim // n_heads`` together produce exactly
            ``embedding_dim`` features for the output projection.
        n_heads: Number of attention heads; must be positive.
        mask: Optional additive attention mask passed to every head.
        dropout: Dropout probability used inside each head and after the output
            projection; ``None`` disables the output dropout layer.
        bias: Whether the linear projections use a bias term.
        dtype: Data type of the layer parameters.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``embedding_dim``.
    """

    def __init__(
        self,
        embedding_dim: int,
        n_heads: int,
        mask: Optional[Tensor] = None,
        dropout: Optional[float] = None,
        bias: bool = False,
        dtype: _DtypeLike = "float32"
    ) -> None:
        # Fail early with a clear message instead of a ZeroDivisionError or a
        # shape mismatch inside the output projection at forward time.
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}.")
        if embedding_dim % n_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by n_heads ({n_heads})."
            )
        head_size = embedding_dim // n_heads

        heads = ParallelConcat(
            *[
                AttentionHead(
                    embedding_dim=embedding_dim,
                    head_size=head_size,
                    mask=mask,
                    dropout=dropout,
                    bias=bias,
                    dtype=dtype,
                )
                for _ in range(n_heads)
            ],
            label="Heads",
        )
        out_projection = Linear(embedding_dim, embedding_dim, bias, dtype, label="OutProjection")

        layers = [heads, out_projection]
        if dropout is not None:
            layers.append(Dropout(p=dropout))
        super().__init__(*layers)


class AttentionHead(Container):
    """Single scaled dot-product self-attention head with a hand-written backward pass.

    Forward computes ``dropout(softmax(q @ k.T * head_size**-0.5 + mask)) @ v``,
    where ``q``, ``k`` and ``v`` are linear projections of the same input ``x``.

    Args:
        embedding_dim: Number of input features of each projection.
        head_size: Number of output features of each projection.
        mask: Optional additive mask added to the attention scores before the
            softmax. A 2-D mask larger than the score matrix is cropped to its
            top-left block matching the scores' last two dimensions (e.g. a
            causal mask built for a maximum context length can be reused for
            shorter inputs). Masks of other ranks are added unchanged.
        dropout: Dropout probability applied to the attention weights; ``None``
            or ``0`` disables it.
        bias: Whether the projections use a bias term.
        dtype: Data type of the projection parameters; incoming gradients are
            cast to it in the backward pass.
    """

    def __init__(
        self,
        embedding_dim: int,
        head_size: int,
        mask: Optional[Tensor] = None,
        dropout: Optional[float] = None,
        bias: bool = False,
        dtype: _DtypeLike = "float32"
    ) -> None:
        super().__init__()
        self.q = Linear(embedding_dim, head_size, bias, dtype, label="QueryProjection")
        self.k = Linear(embedding_dim, head_size, bias, dtype, label="KeyProjection")
        self.v = Linear(embedding_dim, head_size, bias, dtype, label="ValueProjection")
        self.dropout = Dropout(p=dropout) if dropout else None
        self.head_size = head_size
        self.mask = mask
        self.dtype = dtype

    def forward(self, x: cp.Tensor) -> cp.Tensor:
        """Apply the attention head to ``x``.

        When the module is in training mode, also installs ``self._backward``,
        which maps the output gradient to the input gradient and triggers the
        backward passes of the q/k/v projections (and dropout, if any).
        """
        # input projections
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        scale = self.head_size**-0.5

        # attention scores
        qk = q @ k.T * scale
        if self.mask is not None:
            mask = self.mask
            if len(mask.shape) == 2:
                # crop an oversized (e.g. max-context causal) mask to the score
                # matrix; a no-op when the sizes already match
                mask = mask[: qk.shape[-2], : qk.shape[-1]]
            qk += mask
        # NOTE(review): second argument presumably requests the softmax backward
        # closure (sm_backward) — confirm against compyute.nn.functional.softmax
        sm, sm_backward = softmax(qk, self._training)
        if self.dropout is not None:
            # from here on `sm` holds the dropped-out weights, which is what
            # both `y` and the value-gradient below must use
            sm = self.dropout(sm)
        y = sm @ v

        if self.training:

            def _backward(dy: cp.Tensor) -> cp.Tensor:
                dy = dy.as_type(self.dtype)

                # gradient w.r.t. the (dropped-out) attention weights
                dsm = dy @ v.T
                if self.dropout is not None:
                    dsm = self.dropout.backward(dsm)

                # back through the softmax and the score scaling
                dqk = sm_backward(dsm) * scale
                dq = self.q.backward(dqk @ k)
                dk = self.k.backward(dqk.T @ q)
                dv = self.v.backward(sm.T @ dy)

                # q, k and v are all projections of the same x, so the input
                # gradients of the three branches add up
                return dq + dk + dv

            self._backward = _backward

        return y

In [None]:
# Problem size for the comparison: x has shape (B, T, C) and the attention layer
# uses C as embedding_dim and H heads (head size C // H).
# NOTE(review): with B = 1 the check cannot expose bugs that only show up with
# several batch entries; consider also testing B > 1.
B = 1
T = 5
C = 8
H = 4
x = cp.random.uniform((B, T, C), dtype="float32")

In [None]:
# Model under test. Training mode is required: AttentionHead only installs its
# backward closure when `self.training` is set.
mha = MultiHeadAttention(C, H)
mha.set_training(True)

In [None]:
import torch

# PyTorch reference layer with matching size and no biases (the compyute model is
# built with the default bias=False); batch_first=True so it accepts x as (B, T, C).
mha_torch = torch.nn.MultiheadAttention(C, H, bias=False, batch_first=True)

# torch packs the input projections into one (3 * C, C) matrix: all query rows, then
# all key rows, then all value rows, with each head's rows contiguous in each group.
# Built from the actual list of heads so it stays correct for any H.
# Assumes compyute's Linear.w is stored as (out_features, in_features) like torch;
# the shape printouts below check this.
heads = mha.modules[0].modules
in_proj_weights = cp.concatenate(
    [getattr(head, proj).w for proj in ("q", "k", "v") for head in heads],
    axis=0,
)
out_proj_weights = mha.modules[1].w

# torch.tensor copies the data, so the two models do not share weight storage;
# Parameter already sets requires_grad=True.
mha_torch.in_proj_weight = torch.nn.Parameter(torch.tensor(in_proj_weights.to_numpy()))
mha_torch.out_proj.weight = torch.nn.Parameter(torch.tensor(out_proj_weights.to_numpy()))

In [None]:
# shapes of the compyute weights as packed for torch
for packed_weight in (in_proj_weights, out_proj_weights):
    print(packed_weight.shape)

In [None]:
# the torch parameters should report the same shapes as the cell above
for torch_param in (mha_torch.in_proj_weight, mha_torch.out_proj.weight):
    print(torch_param.shape)

In [None]:
# forward pass of both implementations on the same input
out = mha(x)

# leaf tensor with requires_grad so x_torch.grad is filled by the backward pass later
x_torch = torch.tensor(x.to_numpy(), requires_grad=True)
# self-attention: query, key and value are all x; only the output is needed
out_torch = mha_torch(x_torch, x_torch, x_torch, need_weights=False)[0]

In [None]:
# Absolute and relative tolerance for every comparison in this notebook.
# NOTE(review): 1e-6 is tight for float32 arithmetic; loosen if comparisons fail spuriously.
tol: float = 1e-6

In [None]:
# forward outputs agree?
np.allclose(
    out.to_numpy(),
    out_torch.detach().numpy(),
    rtol=tol,
    atol=tol,
)

In [None]:
# Random upstream gradient. Cast explicitly to float32 so it matches x (float32) and
# both models; without a dtype the generator may return a different float dtype and
# the two backward passes would then run at different precisions.
dy = cp.random.normal(out.shape).as_type("float32")
dx = mha.backward(dy)
out_torch.backward(torch.tensor(dy.to_numpy()))

# input gradients agree?
np.allclose(dx.to_numpy(), x_torch.grad.detach().numpy(), atol=tol, rtol=tol)

In [None]:
# Gather the per-head q/k/v weight gradients in torch's packed in_proj layout
# (all query rows, then key rows, then value rows). Built from the actual list of
# heads so it stays correct for any H.
heads = mha.modules[0].modules
in_proj_weight_grads = cp.concatenate(
    [getattr(head, proj).w.grad for proj in ("q", "k", "v") for head in heads],
    axis=0,
)
out_proj_weight_grads = mha.modules[1].w.grad

In [None]:
# packed q/k/v projection weight gradients agree?
np.allclose(
    in_proj_weight_grads.to_numpy(),
    mha_torch.in_proj_weight.grad.detach().numpy(),
    rtol=tol,
    atol=tol,
)

In [None]:
# output projection weight gradients agree?
np.allclose(
    out_proj_weight_grads.to_numpy(),
    mha_torch.out_proj.weight.grad.detach().numpy(),
    rtol=tol,
    atol=tol,
)