# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPTStyleDecoderOnlyModel(nn.Module):
    """Decoder-only (GPT-style) Transformer language model.

    Token ids are embedded, summed with a learned positional embedding,
    passed through a stack of causally masked self-attention layers and
    projected to vocabulary logits.

    Args:
        vocab_size: Number of entries in the token embedding table and size
            of the output logit dimension.
        d_model: Width of the embeddings and of every Transformer layer.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked Transformer layers.
        dim_feedforward: Hidden width of each layer's feed-forward sub-block.
        max_seq_length: Longest sequence the positional embedding supports.
        dropout: Dropout probability used inside the Transformer layers.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, max_seq_length=512, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional table, shape (1, max_seq_length, d_model); the
        # leading 1 broadcasts over the batch dimension in forward().
        self.positional_embedding = nn.Parameter(torch.zeros(1, max_seq_length, d_model))

        # A decoder-only model has no encoder memory to cross-attend to.
        # nn.TransformerDecoderLayer always runs cross-attention over `memory`
        # and fails when it is None, so the stack is built from
        # self-attention-only layers and made autoregressive via the causal
        # mask applied in forward().
        decoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder = nn.TransformerEncoder(
            decoder_layer,
            num_layers=num_layers,
        )

        self.output_projection = nn.Linear(d_model, vocab_size)
        self.max_seq_length = max_seq_length
        self.d_model = d_model

    def generate_causal_mask(self, size, device=None):
        """Return a (size, size) additive attention mask for causal decoding.

        Entries strictly above the diagonal are ``-inf`` (future positions
        are blocked); entries on and below the diagonal are ``0``.

        Args:
            size: Sequence length the mask is built for.
            device: Optional device to allocate the mask on; defaults to the
                current default device.
        """
        # torch.triu zeroes everything below the kept diagonal, so the
        # -inf fill survives only strictly above the main diagonal.
        blocked = torch.full((size, size), float('-inf'), device=device)
        return torch.triu(blocked, diagonal=1)

    def forward(self, input_ids):
        """Compute next-token logits for a batch of token id sequences.

        Args:
            input_ids: Integer tensor of shape (batch, seq_length).

        Returns:
            Tensor of shape (batch, seq_length, vocab_size) with the logits
            for every position.

        Raises:
            ValueError: If ``input_ids`` is not 2-D or is longer than
                ``max_seq_length``.
        """
        if input_ids.dim() != 2:
            raise ValueError(
                f"input_ids must have shape (batch, seq_length), got {tuple(input_ids.shape)}"
            )
        seq_length = input_ids.size(1)
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if seq_length > self.max_seq_length:
            raise ValueError("Input sequence too long")

        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.positional_embedding[:, :seq_length, :]
        x = token_embeddings + position_embeddings

        # Autoregressive (causal) mask, allocated directly on the input's device.
        tgt_mask = self.generate_causal_mask(seq_length, device=input_ids.device)

        x = self.decoder(x, mask=tgt_mask)
        logits = self.output_projection(x)
        return logits
