In [None]:
from init_notebook import *

In [None]:
class TextTransformer(nn.Module):
    """Token-sequence transformer: ``DiagonalEmbedding`` -> ``nn.TransformerDecoder`` -> reverse embedding.

    Token ids are embedded with ``DiagonalEmbedding``. They then pass through a stack of
    ``nn.TransformerDecoderLayer``s that use the embedded sequence as both target
    and memory. Finally they are mapped back with the same embedding module
    (``reverse=True``).

    Args:
        vocab_size: Number of distinct token ids; used as the embedding's ``channels_in``.
        num_layers: Number of stacked decoder layers.
        num_channels: Model width (``d_model``); used as the embedding's ``channels_out``.
        num_channels_mlp: Hidden width of each layer's feed-forward block.
        num_heads: Attention heads per layer (PyTorch requires ``num_channels``
            to be divisible by it).
        activation: Feed-forward activation. A string or callable is handed to
            ``nn.TransformerDecoderLayer``; ``None`` means no non-linearity
            (identity).
        diagonal_embedding: Forwarded as ``DiagonalEmbedding(diagonal=...)``.
        symmetric_embedding: Forwarded as ``DiagonalEmbedding(symmetric=...)``.
        dropout: Dropout probability inside the decoder layers.
        causal: If True, self- and cross-attention are masked so position ``i``
            only attends to positions ``<= i`` (as needed for autoregressive
            next-token prediction). Defaults to False: every position attends to
            the whole sequence, which is the original behavior.
    """

    def __init__(
            self,
            vocab_size: int,
            num_layers: int,
            num_channels: int,
            num_channels_mlp: int,
            num_heads: int,
            activation: Union[None, str, Callable] = "relu",
            diagonal_embedding: bool = True,
            symmetric_embedding: bool = True,
            dropout: float = 0.0,
            causal: bool = False,
    ):
        super().__init__()
        self.causal = causal
        self.embedding = DiagonalEmbedding(
            channels_in=vocab_size,
            channels_out=num_channels,
            diagonal=diagonal_embedding,
            symmetric=symmetric_embedding,
        )
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=num_channels,
                nhead=num_heads,
                dim_feedforward=num_channels_mlp,
                dropout=dropout,
                activation=self._resolve_activation(activation),
                batch_first=True,
            ),
            num_layers=num_layers,
        )

    @staticmethod
    def _resolve_activation(activation: Union[None, str, Callable]) -> Union[str, Callable]:
        """Map ``None`` to an identity module; pass strings and callables through unchanged."""
        # nn.TransformerDecoderLayer keeps a non-string activation as-is and calls it
        # in its feed-forward block, so a raw None would fail on the first forward pass.
        return nn.Identity() if activation is None else activation

    def forward(self, logits: torch.LongTensor) -> torch.Tensor:
        """Run the model on a batch of token ids.

        Args:
            logits: Integer token ids, presumably of shape ``(batch, seq)`` with values
                in ``[0, vocab_size)``. Despite the name these are ids, not logits;
                the name is kept so keyword callers keep working.

        Returns:
            The output of ``self.embedding(..., reverse=True)`` applied to the
            transformer output. Presumably these are per-token vocabulary scores;
            TODO confirm the exact layout against ``DiagonalEmbedding``.
        """
        x = self.embedding(logits)
        # NOTE(review): these permutes assume DiagonalEmbedding is channels-first,
        # i.e. (batch, channels, seq). The decoder runs with batch_first=True and
        # expects (batch, seq, channels). TODO confirm.
        x = x.permute(0, 2, 1)

        mask = None
        if self.causal:
            seq_len = x.shape[1]
            # Additive mask: -inf strictly above the diagonal (the "future"), 0 elsewhere.
            # Matching x's dtype keeps it valid if the model is cast to half precision.
            mask = torch.full(
                (seq_len, seq_len), float("-inf"), dtype=x.dtype, device=x.device,
            ).triu(1)

        # The embedded sequence is both target and memory. Cross-attention must be
        # masked as well, otherwise future tokens would still leak in through memory.
        x = self.transformer(x, x, tgt_mask=mask, memory_mask=mask)

        x = x.permute(0, 2, 1)
        return self.embedding(x, reverse=True)


# Smoke test: build a small model and push one random token sequence through it.
VOCAB_SIZE = 256
torch.manual_seed(0)  # reproducible weights and sample input

model = TextTransformer(
    vocab_size=VOCAB_SIZE,
    num_layers=2,
    num_channels=256,
    num_channels_mlp=256,
    num_heads=8,
)  # optionally append .to(torch.float16) to try half precision
print(f"params: {num_module_parameters(model):,}")

# torch.randint's upper bound is exclusive, so VOCAB_SIZE covers every id 0..VOCAB_SIZE-1
# (the previous bound of 255 never produced token 255).
inp = torch.randint(0, VOCAB_SIZE, (1, 10))
outp = model(inp)
print(outp.shape)
model

In [None]:
# IPython help lookup (`obj?`): shows the signature and docstring of nn.TransformerDecoderLayer,
# e.g. the `activation`, `dim_feedforward` and `batch_first` arguments TextTransformer passes.
# Interactive reference only; IPython syntax, not plain Python.
nn.TransformerDecoderLayer?

In [None]:
# IPython help lookup (`obj?`): shows the signature and docstring of nn.TransformerDecoder,
# including the `tgt_mask` / `memory_mask` arguments of its forward call.
# Interactive reference only; IPython syntax, not plain Python.
nn.TransformerDecoder?