In [4]:


import torch
import torch.nn as nn

class GPTDecoderOnlyModel(nn.Module):
    """Minimal GPT-style decoder-only language model.

    Token embeddings plus learned absolute positional embeddings are passed
    through a stack of ``nn.TransformerEncoderLayer`` blocks that are run with a
    causal (upper-triangular ``-inf``) attention mask, so each position only
    attends to itself and earlier positions. A final LayerNorm and a linear
    projection produce per-position vocabulary logits.

    Args:
        vocab_size: Number of token ids; size of the embedding table and of the
            last dimension of the returned logits.
        d_model: Hidden size.
        nhead: Attention heads per layer (PyTorch requires it to divide d_model).
        num_layers: Number of transformer layers.
        dim_feedforward: Inner size of each layer's feed-forward network.
        max_seq_length: Longest sequence ``forward`` accepts (rows of the
            positional-embedding table).
        dropout: Dropout probability used inside each transformer layer.
        norm_first: If True, layers apply LayerNorm before attention/FFN
            (pre-norm, as in GPT-2), which is the arrangement the final
            ``ln_f`` is meant for. Defaults to False (PyTorch's post-norm) to
            preserve the original behaviour; parameter names are the same
            either way.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048,
                 max_seq_length=512, dropout=0.1, norm_first=False):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Learned absolute position embeddings (zero-initialised); forward uses the first T rows.
        self.positional_embedding = nn.Parameter(torch.zeros(1, max_seq_length, d_model))

        # "Encoder" layers act as decoder-only blocks because forward passes a causal mask.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
                norm_first=norm_first,
            )
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)
        self.max_seq_length = max_seq_length

    def generate_causal_mask(self, seq_len, device, dtype=None):
        """Return a ``(seq_len, seq_len)`` additive causal attention mask.

        Entry ``[i, j]`` is ``0`` when ``j <= i`` and ``-inf`` when ``j > i``,
        so position ``i`` can only attend to positions ``0..i``.

        Args:
            seq_len: Sequence length T.
            device: Device on which to allocate the mask.
            dtype: Floating dtype of the mask. ``None`` keeps torch's default
                float dtype, matching the original behaviour.
        """
        # Allocate directly on the target device rather than building on the CPU
        # and copying it over on every forward pass.
        full = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype)
        return torch.triu(full, diagonal=1)

    def forward(self, input_ids):
        """Compute next-token logits for every position.

        Args:
            input_ids: Integer token ids of shape ``(B, T)`` with
                ``T <= max_seq_length``.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.

        Raises:
            AssertionError: if ``T`` exceeds ``max_seq_length``.
        """
        _, T = input_ids.shape
        assert T <= self.max_seq_length, (
            f"sequence length {T} exceeds max_seq_length={self.max_seq_length}"
        )

        tok_emb = self.token_embedding(input_ids)        # (B, T, d_model)
        pos_emb = self.positional_embedding[:, :T, :]     # (1, T, d_model), broadcast over batch
        x = tok_emb + pos_emb

        # Build the mask on the activations' device and in their dtype.
        mask = self.generate_causal_mask(T, x.device, dtype=x.dtype)

        for layer in self.layers:
            x = layer(x, src_mask=mask)

        x = self.ln_f(x)
        return self.output_projection(x)



In [None]:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# The model below is freshly initialised (no weights are loaded), so the generated
# text is expected to be nonsense; this only smoke-tests the generation loop.
# Seed so the random initialisation, and hence the sample, is reproducible.
torch.manual_seed(0)
model = GPTDecoderOnlyModel(vocab_size=tokenizer.vocab_size)
model.eval()


# Text generation as before
@torch.no_grad()  # inference only: don't build autograd graphs on every step
def generate(model, tokenizer, prompt, max_new_tokens=20):
    """Greedily extend ``prompt`` by ``max_new_tokens`` tokens and return the decoded text.

    Args:
        model: A ``GPTDecoderOnlyModel`` (anything with ``max_seq_length`` that maps
            ``(B, T)`` token ids to ``(B, T, vocab)`` logits).
        tokenizer: Hugging Face tokenizer used to encode the prompt and decode the result.
        prompt: Text to continue.
        max_new_tokens: Number of tokens to append.

    Returns:
        The prompt plus generated continuation, decoded with special tokens skipped.

    Raises:
        ValueError: if the prompt encodes to zero tokens.
    """
    model.eval()
    device = next(model.parameters()).device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    if input_ids.shape[1] == 0:
        raise ValueError("prompt must encode to at least one token")

    for _ in range(max_new_tokens):
        # The model rejects inputs longer than its positional table, so only
        # feed the most recent max_seq_length tokens.
        context = input_ids[:, -model.max_seq_length:]
        logits = model(context)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy decoding
        input_ids = torch.cat([input_ids, next_token], dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


print(generate(model, tokenizer, "Once upon a time", max_new_tokens=30))

Once upon a time Gall external Avery「 Gomez 000000queuevalidcollar DN JavaScript Architectshard haythanwithstandingelokin horrified Education bordering subpoena laws majestic sailedatoon LOW 303 WHITE indications
