In [None]:
import numpy as np

In [None]:
# Toy model dimensions; the hand-written arrays in MODEL are sized to match.
N_CTX = 5  # context length: number of positions (rows of `wpe`)
N_VOCAB = 2  # vocabulary size: tokens `a` and `b` (rows of `wte`)
N_EMBED = 8  # width of every embedding / residual-stream vector

In [None]:
# Lg ("large") is the scale of the hand-set weights. It makes the attention
# scores and the output logits differ by hundreds, so each softmax puts
# essentially zero weight on the disfavoured options (an argmax-like choice).
Lg = 1024

# A hand-built single-layer, single-head GPT over the vocabulary {`a`, `b`}.
#
# Layout of the N_EMBED = 8 residual-stream dims:
#   dims 0-4: one-hot position (from `wpe`)
#   dims 5-6: one-hot token, 5 = `a`, 6 = `b` (from `wte`)
#   dim  7:   unused by the embeddings; the value vectors carry their signal here
#
# How a prediction is made:
#   * Keys are the one-hot position. The query for position p has Lg on dims
#     p-1 and p (only dim 0 for p = 0), so each position attends equally to
#     itself and its predecessor; the causal mask hides later positions.
#   * Values are +1 (dim 7) for `a` and -1 for `b`, so the attention output v
#     is +1 after `aa`, -1 after `bb`, and 0 for a mixed pair.
#   * c_proj maps v to logits: `a` gets Lg - Lg*v and `b` gets Lg*v (plus the
#     small residual token embedding). Hence `aa` -> `b`, while `ab`, `ba` and
#     `bb` -> `a`: the model continues the repeating pattern `aab`.
MODEL = {
    "wte": np.array(
        [
            [0, 0, 0, 0, 0, 1, 0, 0],  # token `a` (id 0)
            [0, 0, 0, 0, 0, 0, 1, 0],  # token `b` (id 1)
        ]
    ),
    "wpe": np.array(
        [
            [1, 0, 0, 0, 0, 0, 0, 0],  # position 0
            [0, 1, 0, 0, 0, 0, 0, 0],  # position 1
            [0, 0, 1, 0, 0, 0, 0, 0],  # position 2
            [0, 0, 0, 1, 0, 0, 0, 0],  # position 3
            [0, 0, 0, 0, 1, 0, 0, 0],  # position 4
        ]
    ),
    "blocks": [
        {
            "attn": {
                "c_attn": {  # generates qkv matrix
                    "b": np.zeros(N_EMBED * 3),
                    "w": np.array(
                        # Rows = residual-stream dims (input).
                        # Columns = [q (0-7) | k (8-15) | v (16-23)].
                        # fmt: off
                        [
                            [  # dim 0: position 0
                                Lg, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # q
                                1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # k
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],  # v
                            [  # dim 1: position 1
                                Lg, Lg, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # q
                                0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # k
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],  # v
                            [  # dim 2: position 2
                                0.0, Lg, Lg, 0.0, 0.0, 0.0, 0.0, 0.0,  # q
                                0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # k
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],  # v
                            [  # dim 3: position 3
                                0.0, 0.0, Lg, Lg, 0.0, 0.0, 0.0, 0.0,  # q
                                0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0,  # k
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],  # v
                            [  # dim 4: position 4
                                0.0, 0.0, 0.0, Lg, Lg, 0.0, 0.0, 0.0,  # q
                                0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,  # k
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],  # v
                            [  # dim 5: token `a` -> value +1
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # q
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # k
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, ],  # v
                            [  # dim 6: token `b` -> value -1
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # q
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # k
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.0, ],  # v
                            [  # dim 7: unused
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # q
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # k
                                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ],  # v
                        ]
                        # fmt: on
                    ),
                },
                "c_proj": {  # weights to project attn result back to embedding space
                    # Bias gives `a` a baseline logit of Lg (dim 5).
                    "b": np.array([0, 0, 0, 0, 0, Lg, 0, 0]),
                    "w": np.array(
                        [
                            [0, 0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0, 0],
                            [0, 0, 0, 0, 0, 0, 0, 0],
                            # Attention value v (dim 7) moves Lg*v from `a` (dim 5) to `b` (dim 6).
                            [0, 0, 0, 0, 0, -Lg, Lg, 0],
                        ]
                    ),
                },
            }
        }
    ],
}

# Sanity-check that the hand-written arrays agree with the declared dimensions.
assert MODEL["wte"].shape == (N_VOCAB, N_EMBED)
assert MODEL["wpe"].shape == (N_CTX, N_EMBED)
assert all(
    blk["attn"]["c_attn"]["w"].shape == (N_EMBED, 3 * N_EMBED)
    and blk["attn"]["c_attn"]["b"].shape == (3 * N_EMBED,)
    and blk["attn"]["c_proj"]["w"].shape == (N_EMBED, N_EMBED)
    and blk["attn"]["c_proj"]["b"].shape == (N_EMBED,)
    for blk in MODEL["blocks"]
)

In [None]:
def softmax(x):
    """Softmax over the last axis.

    The row maximum is subtracted before exponentiating so that large inputs
    (e.g. logits in the thousands) do not overflow; this leaves the result
    unchanged mathematically.
    """
    peak = np.max(x, axis=-1, keepdims=True)
    unnormalised = np.exp(x - peak)
    return unnormalised / unnormalised.sum(axis=-1, keepdims=True)


def linear(x, w, b):
    """Affine layer: project ``x`` through ``w`` and add the bias ``b``.

    ``b`` is broadcast across the rows of ``x @ w``.
    """
    projected = x @ w
    return projected + b


def attention(q, k, v, mask):
    """Scaled dot-product attention: ``softmax(q @ k.T / sqrt(d) + mask) @ v``.

    ``d`` is the size of the last axis of ``q``. ``mask`` is added to the raw
    scores before the softmax, so large negative entries block attention.
    """
    scale = np.sqrt(q.shape[-1])
    scores = q @ k.T / scale + mask
    weights = softmax(scores)
    return weights @ v


def causal_self_attention(x, c_attn, c_proj):
    """Single-head causal self-attention followed by an output projection.

    Args:
        x: residual stream, one row per position (``[n_seq, n_embd]``).
        c_attn: ``{"w": ..., "b": ...}`` for the fused projection whose output
            is split into three equal chunks ``q | k | v``.
        c_proj: ``{"w": ..., "b": ...}`` projecting the attention output back
            to the residual-stream width.

    Returns:
        One row per position; ``transformer_block`` adds it to ``x``.
    """
    qkv = linear(x, **c_attn)  # [n_seq, 3 * width]
    q, k, v = np.split(qkv, 3, axis=-1)

    # 0 on and below the diagonal, -1e10 above it: position i may only attend
    # to positions <= i, because softmax sends the huge negative scores to ~0.
    n_seq = qkv.shape[0]
    causal_mask = (1 - np.tri(n_seq, dtype=qkv.dtype)) * -1e10

    attended = attention(q, k, v, causal_mask)
    return linear(attended, **c_proj)


def transformer_block(x, attn):
    """Apply one residual block: ``x + causal_self_attention(x, **attn)``.

    This toy block has only an attention sub-layer (no MLP, no layer norm).
    """
    residual = x
    update = causal_self_attention(residual, **attn)
    return residual + update


def gpt(inputs, wte, wpe, blocks):
    """Run the toy GPT forward pass and return per-position next-token logits.

    Args:
        inputs: sequence of integer token ids, length ``n_seq``.
        wte: token embedding matrix; row ``i`` embeds token ``i``.
        wpe: position embedding matrix; row ``p`` embeds position ``p``.
        blocks: per-layer parameter dicts, each unpacked into ``transformer_block``.

    Returns:
        ``[n_seq, n_vocab]`` logits. Row ``p`` scores the token following
        position ``p``. The output projection reuses ``wte`` (tied embeddings).

    Raises:
        IndexError: if ``inputs`` is longer than the context (``len(wpe)``).
    """
    n_seq = len(inputs)
    if n_seq > len(wpe):
        raise IndexError(
            f"sequence of length {n_seq} exceeds the context size {len(wpe)}"
        )

    # Gather one embedding row per token id, and one per position 0..n_seq-1.
    token_embeddings = wte[inputs]  # [n_seq, n_embd]
    position_embeddings = wpe[:n_seq]  # [n_seq, n_embd]
    x = token_embeddings + position_embeddings  # residual stream

    for block in blocks:
        x = transformer_block(x, **block)

    # Tied output projection: score each position against every token embedding.
    return x @ wte.T

In [None]:
# Vocabulary: token id i is the character CHARS[i] (`a` -> 0, `b` -> 1),
# matching the row order of MODEL["wte"].
CHARS = ["a", "b"]


def tokenise(s):
    """Convert a string into a list of token ids (indices into ``CHARS``).

    Raises:
        ValueError: if ``s`` contains a character outside the vocabulary.
    """
    tokens = []
    for c in s:
        if c not in CHARS:
            raise ValueError(f"unknown character {c!r}; vocabulary is {CHARS}")
        tokens.append(CHARS.index(c))
    return tokens


def untok(token):
    """Convert a single token id back into its character.

    Raises:
        IndexError: if ``token`` is not in ``range(len(CHARS))``.
    """
    # Reject negative ids explicitly: plain list indexing would silently wrap
    # them around (e.g. -1 -> the last character).
    if not 0 <= token < len(CHARS):
        raise IndexError(
            f"token id {token} out of range for vocabulary of size {len(CHARS)}"
        )
    return CHARS[token]

In [None]:
def predict(s):
    """Print per-position next-token predictions for ``s`` and return the last one.

    For every position, prints the most likely next token along with the
    probabilities and raw logits it came from.

    Args:
        s: input string of characters from ``CHARS``. Its length must not exceed
            the number of position embeddings in ``MODEL["wpe"]``.

    Returns:
        The token id predicted to follow the final character of ``s``.

    Raises:
        ValueError: if ``s`` is empty (no position to predict from) or contains
            a character outside the vocabulary.
    """
    tokens = tokenise(s)
    if not tokens:
        raise ValueError("predict() needs at least one character of input")

    logits = gpt(np.array(tokens), **MODEL)
    probs = softmax(logits)

    for i, tok in enumerate(tokens):
        pred = np.argmax(probs[i])
        print(
            f"{untok(tok)} ({tok}): next={untok(pred)} ({pred}) probs={probs[i]} logits={logits[i]}"
        )

    return np.argmax(probs[-1])


# Predict the character that follows "aabaa"; with the hand-set weights the
# trailing `aa` should be followed by `b`.
next_token_id = predict("aabaa")
print(untok(next_token_id))