Implement a Simplified GPT-2-like Text Generation Function
You are tasked with implementing a simplified GPT-2-like text generation function in Python. This function will incorporate the following components of a minimal GPT-2 architecture:

Token Embeddings: Map input tokens to dense vector representations.
Positional Embeddings: Add positional information to token embeddings.
Multi-head Attention: Attend to various parts of the sequence.
Feed-Forward Network: Process attention outputs through a dense layer.
Layer Normalization: Stabilize the training process.
The function must take in the following parameters:

Prompt: The initial text to guide the generation process.
Number of Tokens to Generate: Specify how many tokens to output.
Your function should output the generated text.

Additionally, utilize the helper function load_encoder_hparams_and_params to retrieve:

A dummy encoder.
Model hyperparameters.
Model parameters.
Build your text generation logic around these components. This exercise is designed to help you understand the core concepts behind GPT-2's autoregressive text generation.

Example:
Input:
prompt="hello", n_tokens_to_generate=5
Output:
world <UNK> <UNK> <UNK> <UNK>
Reasoning:
The function encodes the input "hello" into tokens using the dummy encoder, then runs a simplified GPT-2 forward pass to generate 5 tokens. Finally, it decodes the generated tokens back into text.

In [1]:
import numpy as np

def gelu(x):
    """Return the tanh approximation of the GELU activation, element-wise."""
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1 + np.tanh(inner))

def softmax(x):
    """Return the softmax of ``x`` along its last axis.

    The row maximum is subtracted before exponentiating so large inputs do
    not overflow.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=-1, keepdims=True)

def layer_norm(x, g, b, eps=1e-5):
    """Normalize ``x`` over its last axis, then scale by ``g`` and shift by ``b``.

    ``eps`` is added to the variance before the square root to avoid a
    division by zero.
    """
    mu = np.mean(x, axis=-1, keepdims=True)
    sigma2 = np.var(x, axis=-1, keepdims=True)
    centered = x - mu
    return g * centered / np.sqrt(sigma2 + eps) + b

def linear(x, w, b):
    """Apply the affine map ``x @ w + b``."""
    projected = x @ w
    return projected + b

def ffn(x, c_fc, c_proj):
    """Run the feed-forward sub-layer: linear, GELU, then linear.

    ``c_fc`` and ``c_proj`` are dicts unpacked as keyword arguments into
    ``linear`` (so they carry its ``w`` and ``b``).
    """
    hidden = gelu(linear(x, **c_fc))
    return linear(hidden, **c_proj)

def attention(q, k, v, mask):
    """Scaled dot-product attention with an additive ``mask``.

    Scores are ``q @ k.T`` divided by the square root of the last dimension
    of ``q``; ``mask`` is added before the softmax.
    """
    scale = np.sqrt(q.shape[-1])
    scores = q @ k.T / scale + mask
    return softmax(scores) @ v

def mha(x, c_attn, c_proj, n_head):
    """Multi-head causal self-attention.

    ``x`` is projected with ``c_attn``, split into three equal parts along
    the last axis (q, k, v), and each part is split into ``n_head`` heads.
    Every head attends under the same causal mask; head outputs are
    concatenated and projected with ``c_proj``.
    """
    qkv = linear(x, **c_attn)
    # Three lists (q, k, v), each holding one array per head.
    per_head = [np.split(part, n_head, axis=-1) for part in np.split(qkv, 3, axis=-1)]
    # 0 on and below the diagonal, -1e10 above it (future positions).
    seq_len = qkv.shape[0]
    causal_mask = (1 - np.tri(seq_len, dtype=qkv.dtype)) * -1e10
    heads = []
    for q, k, v in zip(*per_head):
        heads.append(attention(q, k, v, causal_mask))
    return linear(np.hstack(heads), **c_proj)

def transformer_block(x, mlp, attn, ln_1, ln_2, n_head):
    """Apply one pre-norm transformer block with two residual connections.

    Layer norm is applied before the attention sub-layer and before the
    feed-forward sub-layer; each sub-layer's output is added back to ``x``.
    """
    attn_out = mha(layer_norm(x, **ln_1), **attn, n_head=n_head)
    x = x + attn_out
    mlp_out = ffn(layer_norm(x, **ln_2), **mlp)
    return x + mlp_out

def gpt2(inputs, wte, wpe, blocks, ln_f, n_head):
    """Run the forward pass and return one row of vocabulary logits per token.

    Token and positional embeddings are summed, passed through every block
    in ``blocks``, layer-normalized with ``ln_f`` and projected back onto the
    vocabulary with the transposed token embedding matrix.
    """
    positions = range(len(inputs))
    hidden = wte[inputs] + wpe[positions]
    for block in blocks:
        hidden = transformer_block(hidden, **block, n_head=n_head)
    normed = layer_norm(hidden, **ln_f)
    return normed @ wte.T

def generate(inputs, params, n_head, n_tokens_to_generate):
    """Greedily generate ``n_tokens_to_generate`` token ids after ``inputs``.

    Args:
        inputs: Sequence of prompt token ids. It is copied, so the caller's
            sequence is left untouched and any sequence type is accepted.
        params: Keyword arguments forwarded to ``gpt2``.
        n_head: Number of attention heads, forwarded to ``gpt2``.
        n_tokens_to_generate: Number of decoding steps to run.

    Returns:
        A list holding only the newly generated token ids.
    """
    # Work on a private copy: appending to the caller's list would silently
    # grow their prompt on every call.
    ids = list(inputs)
    n_prompt = len(ids)
    for _ in range(n_tokens_to_generate):
        logits = gpt2(ids, **params, n_head=n_head)
        # Greedy decoding: take the highest-scoring token at the last position.
        ids.append(int(np.argmax(logits[-1])))
    return ids[n_prompt:]

def gen_text(prompt: str, n_tokens_to_generate: int = 40):
    """Encode ``prompt``, generate ``n_tokens_to_generate`` tokens and decode them.

    The encoder, hyperparameters and weights come from
    ``load_encoder_hparams_and_params``. Only the generated tokens are
    returned as text, not the prompt.
    """
    np.random.seed(42)  # Fix the global NumPy seed for reproducibility
    encoder, hparams, params = load_encoder_hparams_and_params()
    prompt_ids = encoder.encode(prompt)
    # Prompt plus generated tokens must fit inside the context window.
    assert len(prompt_ids) + n_tokens_to_generate < hparams["n_ctx"]
    new_ids = generate(prompt_ids, params, hparams["n_head"], n_tokens_to_generate)
    return encoder.decode(new_ids)

In [3]:
# Dummy encoder
class DummyEncoder:
    """Whitespace tokenizer over a fixed vocabulary with an ``<UNK>`` fallback."""

    def __init__(self, vocab):
        self.vocab = vocab
        self.token_to_id = {}
        self.id_to_token = {}
        for index, token in enumerate(vocab):
            self.token_to_id[token] = index
            self.id_to_token[index] = token

    def encode(self, text):
        """Split ``text`` on whitespace and map each piece to its id.

        Pieces missing from the vocabulary map to the id of ``"<UNK>"``.
        """
        ids = []
        for piece in text.split():
            ids.append(self.token_to_id.get(piece, self.token_to_id["<UNK>"]))
        return ids

    def decode(self, ids):
        """Join the tokens for ``ids`` with spaces; unknown ids become ``<UNK>``."""
        tokens = [self.id_to_token.get(i, "<UNK>") for i in ids]
        return " ".join(tokens)

def load_encoder_hparams_and_params():
    """Build the demo encoder, hyperparameters and random model weights.

    Returns:
        ``(encoder, hparams, params)`` where ``encoder`` is a ``DummyEncoder``
        over a five-token vocabulary, ``hparams`` holds ``n_ctx`` and
        ``n_head``, and ``params`` holds ``wte``, ``wpe``, a single-element
        ``blocks`` list and ``ln_f``.
    """
    vocab = ["hello","world","cat","dog","<UNK>"]
    encoder = DummyEncoder(vocab)

    d_model, n_head, n_ctx = 8, 2, 20

    # Tiny random weights for demo; the draw order below is fixed so the
    # seeded generator always yields the same model.
    rng = np.random.default_rng(42)

    def dense(n_in, n_out):
        # Random weight matrix with a zero bias.
        return {"w": rng.normal(size=(n_in, n_out)), "b": np.zeros(n_out)}

    def identity_norm():
        # Layer-norm parameters: unit gain, zero bias.
        return {"g": np.ones(d_model), "b": np.zeros(d_model)}

    wte = rng.normal(size=(len(vocab), d_model))
    wpe = rng.normal(size=(n_ctx, d_model))
    ln_f = identity_norm()

    # One block only, super minimal
    mlp = {"c_fc": dense(d_model, d_model), "c_proj": dense(d_model, d_model)}
    attn = {"c_attn": dense(d_model, 3 * d_model), "c_proj": dense(d_model, d_model)}
    block = {"mlp": mlp, "attn": attn, "ln_1": identity_norm(), "ln_2": identity_norm()}

    params = {"wte": wte, "wpe": wpe, "blocks": [block], "ln_f": ln_f}
    hparams = {"n_ctx": n_ctx, "n_head": n_head}
    return encoder, hparams, params

In [4]:
# Demo: generate 5 tokens for each single-word prompt and print the decoded text.
print(gen_text("hello", 5))
print(gen_text("cat", 5))
print(gen_text("dog", 5))

dog <UNK> <UNK> cat dog
<UNK> <UNK> cat dog <UNK>
<UNK> <UNK> cat dog world


In [None]:
import numpy as np

# ----- tiny helpers -----
def gelu(x):
    """Return the tanh approximation of the GELU activation, element-wise."""
    inner = np.sqrt(2/np.pi) * (x + 0.044715*x**3)
    return 0.5*x*(1 + np.tanh(inner))
def softmax(x):
    """Return the softmax of ``x`` over its last axis (max-shifted for stability)."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)
def layer_norm(x,g,b,eps=1e-5):
    """Normalize ``x`` over its last axis, then scale by ``g`` and shift by ``b``."""
    mu = x.mean(-1, keepdims=True)
    sigma2 = x.var(-1, keepdims=True)
    centered = x - mu
    return g*centered/np.sqrt(sigma2 + eps) + b
def causal_mask(T):
    """Return a ``(T, T)`` additive mask: -1e9 above the diagonal, 0.0 elsewhere."""
    positions = np.arange(T)
    # True where the column (key) index lies after the row (query) index.
    is_future = positions[None, :] > positions[:, None]
    return np.where(is_future, -1e9, 0.0)

# ----- dummy encoder -----
class DummyEncoder:
    """Whitespace tokenizer over a fixed vocabulary with an ``<UNK>`` fallback."""

    def __init__(self,vocab):
        self.vocab = vocab
        self.t2i = {}
        self.i2t = {}
        for index, token in enumerate(vocab):
            self.t2i[token] = index
            self.i2t[index] = token

    def encode(self,text):
        """Map each whitespace-separated piece to its id (``<UNK>`` id if unknown)."""
        ids = []
        for piece in text.split():
            ids.append(self.t2i.get(piece, self.t2i["<UNK>"]))
        return ids

    def decode(self,ids):
        """Join the tokens for ``ids`` with spaces, coercing each id with ``int``."""
        tokens = [self.i2t.get(int(i), "<UNK>") for i in ids]
        return " ".join(tokens)

# ----- tiny GPT block (single head) -----
def linear(x,w,b):
    """Apply the affine map ``x @ w + b``."""
    projected = x @ w
    return projected + b

def attn(Q,K,V,mask):
    """Single-head scaled dot-product attention with an additive ``mask``."""
    scale = np.sqrt(Q.shape[-1])
    scores = (Q @ K.T) / scale
    scores += mask  # mask is added before the softmax
    weights = softmax(scores)
    return weights @ V


def transformer_block(x,params):
    """Apply one pre-norm, single-head transformer block.

    ``params`` supplies the layer-norm, attention and MLP weights under the
    keys read below, plus the additive attention mask under ``"mask"``.
    """
    # Attention sub-layer: norm -> fused qkv projection -> attention -> projection.
    normed = layer_norm(x, params["ln1_g"], params["ln1_b"])
    fused = linear(normed, params["attn_w"], params["attn_b"])
    q, k, v = np.split(fused, 3, axis=-1)
    context = attn(q, k, v, params["mask"])
    x = x + linear(context, params["attn_proj_w"], params["attn_proj_b"])
    # MLP sub-layer: norm -> linear -> GELU -> linear.
    normed = layer_norm(x, params["ln2_g"], params["ln2_b"])
    hidden = gelu(linear(normed, params["ff_w"], params["ff_b"]))
    mlp_out = linear(hidden, params["ffp_w"], params["ffp_b"])
    return x + mlp_out

def logits_from_ids(ids,params):
    """Return one row of vocabulary logits per token in ``ids``.

    Args:
        ids: Sequence of token ids.
        params: Model weights, keyed as read below. The dict is not
            modified; the causal mask travels in a shallow copy.

    Returns:
        Logits computed against the transposed token embedding matrix.

    Raises:
        ValueError: If ``ids`` is longer than the number of positional
            embeddings in ``params["wpe"]``.
    """
    T = len(ids)
    n_ctx = params["wpe"].shape[0]
    if T > n_ctx:
        # ``wpe[:T]`` would silently truncate and then fail to broadcast
        # (or broadcast wrongly when n_ctx == 1); fail clearly instead.
        raise ValueError(f"sequence length {T} exceeds context size {n_ctx}")
    x = params["wte"][ids] + params["wpe"][:T]
    # Hand the mask to the block without writing into the caller's dict.
    block_params = {**params, "mask": causal_mask(T)}
    x = transformer_block(x, block_params)
    x = layer_norm(x, params["lnf_g"], params["lnf_b"])
    return x @ params["wte"].T                                   # weight tying

def generate(prompt_ids,params,n_tokens):
    """Greedily generate ``n_tokens`` token ids after ``prompt_ids``.

    Args:
        prompt_ids: Iterable of prompt token ids; it is copied, not modified.
        params: Model weights forwarded to ``logits_from_ids``.
        n_tokens: Number of decoding steps to run.

    Returns:
        A list holding only the newly generated ids (empty when
        ``n_tokens`` is 0).
    """
    ids = list(prompt_ids)
    n_prompt = len(ids)
    for _ in range(n_tokens):
        lg = logits_from_ids(ids, params)
        ids.append(int(np.argmax(lg[-1])))                       # greedy
    # Slice from the prompt length: ``ids[-n_tokens:]`` would return the
    # whole prompt when n_tokens == 0, because -0 slices from the start.
    return ids[n_prompt:]

# ----- minimal wiring + sample -----
def build_tiny_gpt(vocab_size=8,d_model=16,ff_mult=4,n_ctx=64,seed=0):
    """Create randomly initialized weights for the single-block model.

    Weight matrices are drawn from a normal distribution (mean 0, std 0.02)
    using a generator seeded with ``seed``; biases are zero and layer-norm
    gains are one. The MLP hidden width is ``d_model * ff_mult``.

    Returns:
        A dict of NumPy arrays keyed by parameter name.
    """
    rng = np.random.default_rng(seed)
    hidden = d_model * ff_mult

    def init(*shape):
        # Draw order matters for reproducibility; keep the calls below in order.
        return rng.normal(0, 0.02, shape)

    params = {}
    params["wte"] = init(vocab_size, d_model)
    params["wpe"] = init(n_ctx, d_model)
    params["attn_w"] = init(d_model, 3 * d_model)
    params["attn_b"] = np.zeros(3 * d_model)
    params["attn_proj_w"] = init(d_model, d_model)
    params["attn_proj_b"] = np.zeros(d_model)
    params["ff_w"] = init(d_model, hidden)
    params["ff_b"] = np.zeros(hidden)
    params["ffp_w"] = init(hidden, d_model)
    params["ffp_b"] = np.zeros(d_model)
    for prefix in ("ln1", "ln2", "lnf"):
        params[prefix + "_g"] = np.ones(d_model)
        params[prefix + "_b"] = np.zeros(d_model)
    return params

if __name__ == "__main__":
    # tiny vocab; space-separated tokens only
    vocab=["hello","world","cat","dog","a","b","c","<UNK>"]
    enc=DummyEncoder(vocab)
    params=build_tiny_gpt(vocab_size=len(vocab),d_model=16,ff_mult=4,n_ctx=64,seed=42)
    # run a few samples: encode each prompt, greedily generate 5 tokens, decode
    for prompt in ["hello","cat dog","hello world"]:
        inp=enc.encode(prompt)
        out_ids=generate(inp,params,n_tokens=5)
        # The arrow was stored as mojibake (UTF-8 bytes decoded as cp1252);
        # the \u2192 escape keeps the source ASCII-safe.
        print(f"prompt='{prompt}' \u2192", enc.decode(out_ids))

In [None]:
import math, torch, torch.nn as nn, torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position sees only itself and earlier ones."""

    def __init__(self, d_model: int, n_head: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % n_head == 0
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        # Lazily built lower-triangular mask; not saved in the state dict.
        self.register_buffer("mask", None, persistent=False)

    def _get_mask(self, T: int, device):
        """Return a ``(T, T)`` boolean lower-triangular mask, rebuilding the cache if too small."""
        cached = self.mask
        if cached is not None and cached.size(0) >= T:
            return cached[:T, :T]
        self.mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))
        return self.mask[:T, :T]

    def forward(self, x):  # x: (B,T,D)
        """Attend over ``x`` and return a tensor of the same ``(B, T, D)`` shape."""
        B, T, D = x.size()

        def to_heads(t):
            # (B,T,D) -> (B,H,T,dh)
            return t.view(B, T, self.n_head, self.d_head).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.qkv(x).split(D, dim=-1))
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B,H,T,T)
        visible = self._get_mask(T, x.device)
        # Masked-out (future) positions get -inf so softmax gives them zero weight.
        scores = scores.masked_fill(~visible, float("-inf"))
        weights = self.attn_drop(F.softmax(scores, dim=-1))
        # Merge heads back: (B,H,T,dh) -> (B,T,D)
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, D)
        return self.resid_drop(self.proj(merged))

class MLP(nn.Module):
    """Feed-forward sub-layer: expand by ``mult``, GELU, project back, dropout."""

    def __init__(self, d_model: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        hidden = mult * d_model
        self.fc = nn.Linear(d_model, hidden)
        self.proj = nn.Linear(hidden, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``drop(proj(gelu(fc(x))))``."""
        activated = F.gelu(self.fc(x))
        return self.drop(self.proj(activated))

class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual."""

    def __init__(self, d_model: int, n_head: int, dropout: float = 0.0):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_head, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, mult=4, dropout=dropout)

    def forward(self, x):
        """Add the attention output, then the MLP output, to ``x``."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        transformed = self.mlp(self.ln2(x))
        return x + transformed

class GPT(nn.Module):
    """Small GPT-style decoder: embeddings, a stack of blocks, final norm and LM head."""

    def __init__(self, vocab_size: int, n_ctx: int, d_model: int = 128, n_head: int = 4, n_layer: int = 2, dropout: float = 0.0, tie_weights: bool = True):
        super().__init__()
        self.n_ctx = n_ctx
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(n_ctx, d_model)
        self.drop = nn.Dropout(dropout)
        layers = [Block(d_model, n_head, dropout) for _ in range(n_layer)]
        self.blocks = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        if tie_weights:
            # Output projection shares its weight with the token embedding.
            self.head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx):  # idx: (B,T) int64
        """Return logits of shape ``(B, T, V)`` for token ids ``idx``; ``T`` must not exceed ``n_ctx``."""
        B, T = idx.size()
        assert T <= self.n_ctx
        positions = torch.arange(0, T, device=idx.device).unsqueeze(0)  # (1,T)
        hidden = self.drop(self.tok_emb(idx) + self.pos_emb(positions))
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.head(self.ln_f(hidden))  # (B,T,V)

    @torch.no_grad()
    def generate(self, idx, n_new_tokens: int, temperature: float = 1.0, top_k: int | None = None):
        """Sample ``n_new_tokens`` ids after ``idx`` and return the extended tensor.

        Puts the module in eval mode. Logits at the last position are divided
        by ``temperature`` (floored at 1e-6); when ``top_k`` is given, logits
        below the k-th largest are set to -inf before sampling.
        """
        self.eval()
        for _ in range(n_new_tokens):
            window = idx[:, -self.n_ctx:]  # crop to context
            logits = self(window)[:, -1, :] / max(temperature, 1e-6)
            if top_k is not None:
                k = min(top_k, logits.size(-1))
                top_vals, _ = torch.topk(logits, k=k)
                cutoff = top_vals[:, -1].unsqueeze(-1)
                logits = logits.masked_fill(logits < cutoff, float("-inf"))
            probs = F.softmax(logits, dim=-1)
            # sampling; for greedy: torch.argmax(logits, dim=-1, keepdim=True)
            next_id = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx

# tiny dummy tokenizer so you can run a sample
class DummyTokenizer:
    """Whitespace tokenizer over a fixed vocabulary with an ``<unk>`` fallback."""

    def __init__(self, vocab):
        self.vocab = vocab
        self.stoi = {}
        self.itos = {}
        for index, token in enumerate(vocab):
            self.stoi[token] = index
            self.itos[index] = token

    def encode(self, s):
        """Return a ``(1, T)`` long tensor of ids for the whitespace-split pieces of ``s``."""
        row = []
        for piece in s.split():
            row.append(self.stoi.get(piece, self.stoi["<unk>"]))
        return torch.tensor([row], dtype=torch.long)

    def decode(self, ids):
        """Join tokens for ``ids`` (a tensor, a flat list, or a list of lists; first row used)."""
        seq = ids.tolist() if isinstance(ids, torch.Tensor) else ids
        if seq and isinstance(seq[0], list):
            seq = seq[0]
        return " ".join(self.itos.get(i, "<unk>") for i in seq)

# Demo: build a small GPT over an 8-token vocabulary and sample 8 tokens for one prompt.
if __name__ == "__main__":
    torch.manual_seed(0)  # fix the seed used for weight init and sampling
    vocab = ["hello","world","cat","dog","a","b","c","<unk>"]
    tok = DummyTokenizer(vocab)
    V = len(vocab); n_ctx = 32
    model = GPT(vocab_size=V, n_ctx=n_ctx, d_model=128, n_head=4, n_layer=2, dropout=0.0)
    # Use the GPU when one is available, otherwise stay on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # prompt and generate
    prompt = "hello"
    idx = tok.encode(prompt).to(device)        # (1,T)
    out_idx = model.generate(idx, n_new_tokens=8, temperature=1.0, top_k=5)
    print("prompt:", prompt)
    # Drop the prompt positions so only the newly generated tokens are decoded.
    print("gen:", tok.decode(out_idx[0, idx.size(1):].cpu()))

In [None]:
import math, torch, torch.nn as nn, torch.nn.functional as F

# dummy whitespace tokenizer
class Tok:
    """Whitespace tokenizer over a fixed vocabulary with an ``<unk>`` fallback."""

    def __init__(self, vocab):
        self.v = vocab
        self.s2i = {}
        self.i2s = {}
        for index, token in enumerate(vocab):
            self.s2i[token] = index
            self.i2s[index] = token

    def encode(self, s):
        """Return a ``(1, T)`` long tensor of ids for the whitespace-split pieces of ``s``."""
        row = []
        for piece in s.split():
            row.append(self.s2i.get(piece, self.s2i["<unk>"]))
        return torch.tensor([row], dtype=torch.long)

    def decode(self, ids):
        """Join tokens for ``ids`` (a tensor, a flat list, or a list of lists; first row used)."""
        seq = ids.tolist() if isinstance(ids, torch.Tensor) else ids
        if seq and isinstance(seq[0], list):
            seq = seq[0]
        return " ".join(self.i2s.get(i, "<unk>") for i in seq)

class SimpleGPT(nn.Module):
    """Single-block, single-head GPT-style model with tied input/output embeddings."""

    def __init__(self, vocab_size, n_ctx=32, d=64):
        super().__init__()
        self.n_ctx = n_ctx
        self.d = d
        self.tok = nn.Embedding(vocab_size, d)
        self.pos = nn.Embedding(n_ctx, d)
        self.qkv = nn.Linear(d, 3 * d, bias=False)    # single head
        self.proj = nn.Linear(d, d, bias=False)
        self.ln1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, 4 * d)
        self.ff2 = nn.Linear(4 * d, d)
        self.ln2 = nn.LayerNorm(d)
        self.head = nn.Linear(d, vocab_size, bias=False)
        self.head.weight = self.tok.weight            # tie weights
        # Boolean lower-triangular mask, kept out of the state dict.
        lower = torch.tril(torch.ones(n_ctx, n_ctx)).bool()
        self.register_buffer("mask", lower, persistent=False)
        self.apply(self._init)

    def _init(self, m):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero any bias."""
        if not isinstance(m, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        bias = getattr(m, "bias", None)
        if bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, idx):                           # idx: (B,T)
        """Return logits of shape ``(B, T, V)`` for token ids ``idx``."""
        _, T = idx.shape
        positions = torch.arange(T, device=idx.device).unsqueeze(0)
        x = self.tok(idx) + self.pos(positions)      # (B,T,d)
        # Attention sub-layer (pre-norm, single head).
        q, k, v = self.qkv(self.ln1(x)).chunk(3, dim=-1)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d)
        scores = scores.masked_fill(~self.mask[:T, :T], float("-inf"))
        weights = F.softmax(scores, dim=-1)
        # Output projection applied as a plain matmul (proj has no bias).
        x = x + (weights @ v) @ self.proj.weight.T
        # MLP sub-layer (pre-norm).
        x = x + self.ff2(F.gelu(self.ff1(self.ln2(x))))
        return self.head(x)                          # (B,T,V)

    @torch.no_grad()
    def generate(self, idx, n_new):
        """Greedily append ``n_new`` token ids to ``idx``; puts the module in eval mode."""
        self.eval()
        for _ in range(n_new):
            window = idx[:, -self.n_ctx:]            # crop to the context size
            last_logits = self(window)[:, -1, :]
            choice = torch.argmax(last_logits, dim=-1, keepdim=True)  # greedy
            idx = torch.cat([idx, choice], dim=1)
        return idx

# Demo: build a SimpleGPT over an 8-token vocabulary and greedily extend three prompts.
if __name__ == "__main__":
    torch.manual_seed(0)  # fix the seed used for weight initialization
    vocab = ["hello","world","cat","dog","a","b","c","<unk>"]
    tok = Tok(vocab)
    model = SimpleGPT(vocab_size=len(vocab), n_ctx=32, d=64)
    # Use the GPU when one is available, otherwise stay on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    for prompt in ["hello", "cat dog", "hello world"]:
        x = tok.encode(prompt).to(device)
        # Keep only the 6 newly generated ids (drop the prompt positions).
        y = model.generate(x, n_new=6)[0, x.size(1):].cpu()
        print(f"{prompt} -> {tok.decode(y)}")