In [None]:
from typing import Optional

import compyute as cp
from compyute.nn import Container, ParallelConcat, Linear, Sequential, Dropout, ReLU, SkipConnection, Layernorm, Embedding
from compyute.nn.functional import softmax
from compyute._types import _DtypeLike
from compyute._tensor import Tensor
from compyute.nn import Buffer


class Transformer(Container):
    """Transformer language model built from stacked pre-norm blocks.

    Token ids are embedded and summed with learned positional embeddings
    (looked up by position index). The result is passed through ``n_layers``
    :class:`TransformerBlock` modules and projected to ``vocab_size`` logits.

    Args:
        vocab_size: Number of distinct token ids.
        embedding_dim: Width of the token/positional embeddings.
        feedforward_dim: Hidden width of each block's feed-forward network.
        n_heads: Number of attention heads per block.
        n_layers: Number of stacked transformer blocks.
        sequence_length: Maximum sequence length; size of the positional
            embedding table.
        mask: Optional additive attention mask passed to every block
            (e.g. a ``-inf`` upper-triangular causal mask).
        dropout: Dropout probability, or ``None`` to disable dropout.
        attention_bias: Whether the attention projections use a bias.
        feedforward_bias: Whether the feed-forward layers use a bias.
        layernorm_eps: Epsilon used by the layer normalizations.
        dtype: Dtype of the parameters.
        label: Optional module label.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        feedforward_dim: int,
        n_heads: int,
        n_layers: int,
        sequence_length: int,
        mask: Optional[Tensor] = None,
        dropout: Optional[float] = None,
        attention_bias: bool = False,
        feedforward_bias: bool = True,
        layernorm_eps: float = 1e-5,
        dtype: _DtypeLike = "float32",
        label: Optional[str] = None
    ) -> None:
        super().__init__(label=label)
        # Upper bound on the input length; the positional table has this many rows.
        self.sequence_length = sequence_length
        self.token_embedding = Embedding(vocab_size=vocab_size, embedding_dim=embedding_dim, dtype=dtype, label="TokenEmbedding")
        self.pos_embedding = Embedding(vocab_size=sequence_length, embedding_dim=embedding_dim, dtype=dtype, label="PosEmbedding")
        self.blocks = Sequential(
            *[TransformerBlock(
                embedding_dim=embedding_dim,
                feedforward_dim=feedforward_dim,
                n_heads=n_heads,
                sequence_length=sequence_length,
                mask=mask,
                dropout=dropout,
                attention_bias=attention_bias,
                feedforward_bias=feedforward_bias,
                layernorm_eps=layernorm_eps,
                dtype=dtype
            ) for _ in range(n_layers)]
        )
        # Use the same parameter dtype as every other layer of the model.
        self.out_projection = Linear(embedding_dim, vocab_size, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        """Compute vocabulary logits for a batch of token ids.

        Args:
            x: Integer token ids whose last axis is the sequence axis
                (e.g. ``(B, T)``).

        Returns:
            Logits produced by ``out_projection`` (trailing axis of size
            ``vocab_size``).

        Raises:
            ValueError: If the sequence is longer than ``sequence_length``.
        """
        seq_len = x.shape[-1]
        if seq_len > self.sequence_length:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum of {self.sequence_length}."
            )

        token_emb = self.token_embedding(x)
        # Positional embeddings are indexed by position (0..T-1), not by
        # token id. The (T, E) result broadcasts over the leading batch axis.
        positions = cp.arange(seq_len, device=x.device)
        pos_emb = self.pos_embedding(positions)
        x = token_emb + pos_emb
        x = self.blocks(x)
        y = self.out_projection(x)

        def backward(dy: Tensor) -> None:
            dy = self.out_projection.backward(dy)
            dy = self.blocks.backward(dy)
            self.token_embedding.backward(dy)
            # pos_emb was broadcast across the batch axis, so its gradient is
            # the sum over that axis.
            dpos = dy.sum(axis=0) if dy.ndim > 2 else dy
            self.pos_embedding.backward(dpos)

        self._backward = backward

        return y


class TransformerBlock(Sequential):
    """Pre-norm transformer block: attention sub-block, then feed-forward sub-block.

    Each sub-block is ``SkipConnection(Sequential(Layernorm, sublayer))``.

    Args:
        embedding_dim: Width of the residual stream.
        feedforward_dim: Hidden width of the feed-forward network.
        n_heads: Number of attention heads.
        sequence_length: Unused since layer normalization no longer depends
            on it; kept so existing callers remain valid.
        mask: Optional additive attention mask passed to every head.
        dropout: Dropout probability, or ``None`` to disable dropout.
        attention_bias: Whether the attention projections use a bias.
        feedforward_bias: Whether the feed-forward layers use a bias.
        layernorm_eps: Epsilon used by the layer normalizations.
        dtype: Dtype of the parameters.
    """

    def __init__(
        self,
        embedding_dim: int,
        feedforward_dim: int,
        n_heads: int,
        sequence_length: int,
        mask: Optional[Tensor] = None,
        dropout: Optional[float] = None,
        attention_bias: bool = False,
        feedforward_bias: bool = True,
        layernorm_eps: float = 1e-5,
        dtype: _DtypeLike = "float32"
    ) -> None:
        # Normalize each token over its feature axis only. Normalizing over
        # (sequence_length, embedding_dim) would mix statistics from later
        # positions into earlier ones (defeating a causal mask) and would tie
        # the block to one fixed sequence length.
        norm_shape = (embedding_dim,)

        attention_block = SkipConnection(
            Sequential(
                Layernorm(
                    normalized_shape=norm_shape,
                    eps=layernorm_eps,
                    dtype=dtype
                ),
                MultiHeadAttention(
                    embedding_dim=embedding_dim,
                    n_heads=n_heads,
                    mask=mask,
                    dropout=dropout,
                    bias=attention_bias,
                    dtype=dtype
                ),
                label="MultiHeadAttentionBlock"
            )
        )
        feedforward_block = SkipConnection(
            Sequential(
                Layernorm(
                    normalized_shape=norm_shape,
                    eps=layernorm_eps,
                    dtype=dtype
                ),
                FeedForward(
                    embedding_dim=embedding_dim,
                    feedforward_dim=feedforward_dim,
                    bias=feedforward_bias,
                    dropout=dropout,
                    dtype=dtype
                ),
                label="FeedForwardBlock"
            )
        )
        super().__init__(attention_block, feedforward_block)


class FeedForward(Sequential):
    """Two-layer MLP: Linear -> ReLU -> Linear, optionally followed by Dropout.

    Args:
        embedding_dim: Input and output width.
        feedforward_dim: Hidden width.
        dropout: Dropout probability; ``None`` omits the dropout layer.
        bias: Whether both linear layers use a bias.
        dtype: Dtype of the parameters.
    """

    def __init__(
        self,
        embedding_dim: int,
        feedforward_dim: int,
        dropout: Optional[float] = None,
        bias: bool = True,
        dtype: _DtypeLike = "float32"
    ) -> None:
        expand = Linear(embedding_dim, feedforward_dim, bias=bias, dtype=dtype)
        contract = Linear(feedforward_dim, embedding_dim, bias=bias, dtype=dtype)
        modules = [expand, ReLU(), contract]
        if dropout is not None:
            modules.append(Dropout(p=dropout))
        super().__init__(*modules, label="FeedForward")


class MultiHeadAttention(Sequential):
    """Multi-head self-attention: parallel heads, concatenated, then projected.

    ``n_heads`` :class:`AttentionHead` modules each produce
    ``embedding_dim // n_heads`` features. Their outputs are combined by
    :class:`ParallelConcat` and mapped back to ``embedding_dim`` by an output
    projection, optionally followed by dropout.

    Args:
        embedding_dim: Input and output width.
        n_heads: Number of attention heads; must divide ``embedding_dim``.
        mask: Optional additive attention mask passed to every head.
        dropout: Dropout probability, or ``None`` to disable dropout.
        bias: Whether the projections use a bias.
        dtype: Dtype of the parameters.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``embedding_dim``. Otherwise the concatenated head width would not
            match the output projection's input width, and forward would fail
            with a shape error.
    """

    def __init__(
        self,
        embedding_dim: int,
        n_heads: int,
        mask: Optional[Tensor] = None,
        dropout: Optional[float] = None,
        bias: bool = False,
        dtype: _DtypeLike = "float32"
    ) -> None:
        if n_heads <= 0 or embedding_dim % n_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be a positive divisor of embedding_dim ({embedding_dim})."
            )
        head_size = embedding_dim // n_heads

        layers = [
            ParallelConcat(*[
                AttentionHead(
                    embedding_dim=embedding_dim,
                    head_size=head_size,
                    mask=mask,
                    dropout=dropout,
                    bias=bias,
                    dtype=dtype
                ) for _ in range(n_heads)
            ], label="Heads"),
            Linear(embedding_dim, embedding_dim, bias, dtype, label="OutProjection")
        ]
        layers += [Dropout(p=dropout)] if dropout is not None else []
        super().__init__(*layers)


class AttentionHead(Container):
    """Single scaled dot-product self-attention head with a hand-written backward.

    Args:
        embedding_dim: Input width.
        head_size: Width of the query/key/value projections and of the output.
        mask: Optional additive mask added to the attention scores before the
            softmax (e.g. ``-inf`` above the diagonal for causal attention).
            It is sliced to the actual sequence length, so inputs shorter than
            the mask are supported.
        dropout: Dropout probability applied to the attention weights; a falsy
            value disables dropout.
        bias: Whether the projections use a bias.
        dtype: Dtype the incoming gradient is cast to in backward.
    """

    def __init__(
        self,
        embedding_dim: int,
        head_size: int,
        mask: Optional[Tensor] = None,
        dropout: Optional[float] = None,
        bias: bool = False,
        dtype: _DtypeLike = "float32"
    ) -> None:
        super().__init__()
        self.q = Linear(embedding_dim, head_size, bias, dtype, label="QueryProjection")
        self.k = Linear(embedding_dim, head_size, bias, dtype, label="KeyProjection")
        self.v = Linear(embedding_dim, head_size, bias, dtype, label="ValueProjection")
        self.dropout = Dropout(p=dropout) if dropout else None
        self.head_size = head_size
        # Only wrap a real tensor. Buffer(None) would never be None, so the
        # `self.mask is not None` check in forward would always pass.
        self.mask = Buffer(mask) if mask is not None else None
        self.dtype = dtype

    def forward(self, x: cp.Tensor) -> cp.Tensor:
        # input projections
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # attention scores, scaled by 1/sqrt(head_size)
        scale = self.head_size**-0.5
        qk = q @ k.T * scale
        if self.mask is not None:
            # Slice the mask to the current sequence length so sequences
            # shorter than the mask can still be processed.
            seq_len = qk.shape[-1]
            qk = qk + self.mask[:seq_len, :seq_len]
        sm, sm_backward = softmax(qk, self.training)
        if self.dropout is not None:
            sm = self.dropout(sm)
        y = sm @ v

        if self.training:

            def _backward(dy: cp.Tensor) -> cp.Tensor:
                dy = dy.astype(self.dtype)

                # gradient w.r.t. the (possibly dropped-out) attention weights
                dsm = dy @ v.T

                if self.dropout is not None:
                    dsm = self.dropout.backward(dsm)

                dqk = sm_backward(dsm) * scale
                dq = self.q.backward(dqk @ k)
                dk = self.k.backward(dqk.T @ q)
                # `sm` is the post-dropout matrix actually used to compute y.
                dv = self.v.backward(sm.T @ dy)

                # x feeds all three projections, so the input gradients add up.
                return dq + dk + dv

            self._backward = _backward

        return y

In [None]:
B = 64   # batch size
T = 128  # tokens per example (the model's sequence_length below)
x = cp.random.uniform_int((B, T), 0, 100)  # random token ids bounded by vocab_size=100

In [None]:
# Additive causal mask: -inf strictly above the main diagonal, 0 elsewhere,
# so a position cannot attend to later positions.
neg_inf = float("-inf")
mask = cp.triu(cp.full(shape=(T, T), value=neg_inf), d=1)
mask_ = mask.cuda()  # GPU copy handed to the model

In [None]:
# 6 layers x 6 heads over 384-dim embeddings, causal mask, 20% dropout.
model_config = dict(
    vocab_size=100,
    embedding_dim=384,
    feedforward_dim=4 * 384,
    n_heads=6,
    n_layers=6,
    sequence_length=T,
    mask=mask_,
    dropout=0.2,
)
t = Transformer(**model_config)

t.set_training(True)
t.to_device("cuda")

In [None]:
t.summary(input_dtype="int32", input_shape=(T,))  # model summary for int32 token inputs of length T

In [None]:
x_ = x.cuda()  # token ids on the same device as the model
out = t(x_)    # forward pass in training mode so backward() can follow
out            # display the logits in the notebook

In [None]:
# Random upstream gradient shaped like the logits, pushed back through the model.
dy = cp.random.normal(device="cuda", shape=out.shape)
t.backward(dy)