In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Data

In [2]:
# https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
chars = sorted(list(set(text)))
vocab_dim = len(chars)

In [5]:
# sentencepiece
# tiktoken

In [6]:
# import tiktoken
# enc = tiktoken.get_encoding('gpt2')
# enc.encode('hi my name is linsu!')
# enc.decode([5303, 616, 1438, 318, 300, 1040, 84, 0])

In [7]:
c_i = {c:i for i, c in enumerate(chars)}
i_c = {i:c for i, c in enumerate(chars)}
encode = lambda s: [c_i[c] for c in s]
decode = lambda l: ''.join([i_c[i] for i in l])

In [8]:
data = torch.tensor(encode(text), dtype=torch.int64)
data.shape

torch.Size([1115394])

In [9]:
n = int(.9*len(data))
data_train = data[:n]
data_val = data[n:]

In [10]:
len(data_train)

1003854

# Model

In [11]:
B = 78 # batch size
T = 78 # sequence length
embed_dim = 80 # must be divisible by num_heads
train_steps = 5000
lr = 1e-3 # learning rate
torch.manual_seed(1337)
device = torch.device('cuda')

In [12]:
def get_batch(data, B):
    idx = torch.randint(len(data) - T, (B,))
    x = torch.stack([data[i:i+T] for i in idx])
    y = torch.stack([data[i+1:i+T+1] for i in idx])
    return x, y

In [13]:
x, y = get_batch(data_train, B)

In [14]:
x.shape, y.shape

(torch.Size([78, 78]), torch.Size([78, 78]))

In [15]:
print([decode(l) for l in x.numpy()])
print([decode(l) for l in y.numpy()])

["ade glorious summer by this sun of York;\nAnd all the clouds that lour'd upon o", 'kill were subject to thy curse.\nHere did she fall a tear; here in this place\nI', 'ments and poetry,\nSchoolmasters will I keep within my house,\nFit to instruct h', 'rd, blowing of his nails,\nCan neither call it perfect day nor night.\nNow sways', 'd.\n\nMENENIUS:\nMasters of the people,\nYour multiplying spawn how can he flatter', "be wild, I have dispatch'd in post\nTo sacred Delphos, to Apollo's temple,\nCleo", 'ee heaven!\n\nHENRY BOLINGBROKE:\nHarry of Hereford, Lancaster and Derby\nAm I; wh', "ons notwithstanding,\nBut by the robbing of the banish'd duke.\n\nNORTHUMBERLAND:", 'AURENCE:\nSo smile the heavens upon this holy act,\nThat after hours with sorrow', "eavy womb!\nThou loathed issue of thy father's loins!\nThou rag of honour! thou ", "he slay me,\nHe does fair justice; if he give me way,\nI'll do his country servi", ' must entreat the time alone.\n\nPARIS:\nGod shield I should disturb d

In [16]:
class SelfAttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, head_dim, bias=False)
        self.key = nn.Linear(embed_dim, head_dim, bias=False)
        self.value = nn.Linear(embed_dim, head_dim, bias=False)
        self.dropout = nn.Dropout(.2)
        # self.register_buffer('tril', torch.tril(torch.ones(sequence_dim, sequence_dim)))

    def forward(self, x): # TODO: pass in q k v instead of x
        B, T, E = x.shape
        q = self.query(x) # (B, T, E) -> (B, T, H)
        k = self.key(x) # (B, T, E) -> (B, T, H)
        v = self.value(x) # (B, T, E) -> (B, T, H)
        wei = q @ k.transpose(-2, -1) * E**(-0.5) # (B, T, H) @ (B, H, T) = (B, T, T)
        tril = torch.tril(torch.ones(T, T)).to(device) # how to register buffer without T param
        wei = wei.masked_fill(tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        out = wei @ v # (B, T, T) @ (B, T, H) = (B, T, H)
        return out

In [17]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim=None):
        super().__init__()
        if head_dim is None:
            head_dim = embed_dim//num_heads
        self.heads = nn.ModuleList([SelfAttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        return x

In [18]:
class FeedForward(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 4 * out_features), # TODO: 4x according to paper
            nn.ReLU(),
            nn.Linear(4 * in_features, out_features), # projection layer
            nn.Dropout(.2)
        )

    def forward(self, x):
        return self.net(x)

In [19]:
class Block(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.mha = MultiHeadAttention(embed_dim, num_heads)
        self.ff = FeedForward(embed_dim, embed_dim)
        self.ln1 = nn.LayerNorm(embed_dim) # https://arxiv.org/pdf/2002.04745.pdf
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.mha(self.ln1(x)) # self attend
        x = x + self.ff(self.ln2(x)) # think on data
        return x

In [20]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_dim, embed_dim, sequence_dim):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_dim, embed_dim)
        self.position_embedding = nn.Embedding(sequence_dim, embed_dim)
        self.blocks = nn.Sequential(
            Block(embed_dim, 4),
            Block(embed_dim, 4),
            Block(embed_dim, 4),
            nn.LayerNorm(embed_dim),
        )
        self.linear = nn.Linear(embed_dim, vocab_dim)

    def forward(self, x, y=None):
        # B is batch, T is length of time series, E is embedding dim
        B, T = x.shape
        token_embeddings = self.token_embedding(x) # (B, T, E)
        position_embeddings = self.position_embedding(torch.arange(T).to(device)) # (T, E) # T <= sequence_dim
        x = token_embeddings + position_embeddings # (B, T, E) +  (-, T, E) -> (B, T, E)
        x = self.blocks(x)
        logits = self.linear(x) # pred
        if y is None:
            loss = None
        else:
            B, T, E = logits.shape
            logits = logits.view(B*T, E)
            y = y.view(B*T)
            loss = F.cross_entropy(logits, y)
        return logits, loss

    def generate(self, x, n_tokens):
        for i in range(n_tokens):
            x_cropped = x[:, -T:] # crop s.t. it's <= sequence_dim
            logits, _ = self(x_cropped) # (B, T, E)
            logits = logits[:, -1, :] # (B, E)
            probs = F.softmax(logits, dim=-1) # (B, E)
            y_pred = torch.multinomial(probs, num_samples=1) # (B, 1)
            x = torch.cat((x, y_pred), dim=1) # (B, T) + (B, 1) = (B, T + 1)
        return x

In [21]:
model = BigramLanguageModel(vocab_dim, embed_dim, T).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [22]:
model

BigramLanguageModel(
  (token_embedding): Embedding(65, 80)
  (position_embedding): Embedding(78, 80)
  (blocks): Sequential(
    (0): Block(
      (mha): MultiHeadAttention(
        (heads): ModuleList(
          (0): SelfAttentionHead(
            (query): Linear(in_features=80, out_features=20, bias=False)
            (key): Linear(in_features=80, out_features=20, bias=False)
            (value): Linear(in_features=80, out_features=20, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
          (1): SelfAttentionHead(
            (query): Linear(in_features=80, out_features=20, bias=False)
            (key): Linear(in_features=80, out_features=20, bias=False)
            (value): Linear(in_features=80, out_features=20, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
          (2): SelfAttentionHead(
            (query): Linear(in_features=80, out_features=20, bias=False)
            (key): Linear(in_features=80, out_features=20

In [23]:
# pre training
print(decode(model.generate(torch.zeros((1, 1), dtype=torch.int64).to(device), 100).cpu().numpy()[0]))


wbXaqF,Y!elXe;bwf:a3TKm
v?$dvycRRZejI&EkMe.MMjl!CF'GZFuCPzPbWCBcRLHFDFNvQlWbLA.Tcq?QlMHRedTUKBGslP
D


In [24]:
@torch.no_grad()
def estimate_loss(model, iters, device):
    out = []
    model.to(device)
    model.eval()
    losses = torch.zeros(iters)
    for data in [data_train, data_val]:
        for i in range(iters):
            x, y = get_batch(data, B)
            x = x.to(device)
            y = y.to(device)
            logits, loss = model(x, y)
            losses[i] = loss.item()
        out.append(losses.mean())
    model.train()
    return out

In [25]:
tenth = train_steps//10
for steps in range(train_steps):
    x, y = get_batch(data_train, B)
    x = x.to(device)
    y = y.to(device)
    logits, loss = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if steps % tenth == 0:
        train_loss, val_loss = estimate_loss(model, 100, device)
        print(train_loss, val_loss)

tensor(4.1863) tensor(4.1894)
tensor(2.1204) tensor(2.1557)
tensor(1.8340) tensor(1.9408)
tensor(1.7039) tensor(1.8463)
tensor(1.6246) tensor(1.8039)
tensor(1.5780) tensor(1.7567)
tensor(1.5456) tensor(1.7379)
tensor(1.5195) tensor(1.7097)
tensor(1.5013) tensor(1.6994)
tensor(1.4880) tensor(1.6777)


In [26]:
# post training
model.eval()
print(decode(model.generate(torch.zeros((1, 1), dtype=torch.int64).to(device), 1000).cpu().numpy()[0]))


Where they is not: a prayer a soal
Warwick of his hubhand to by marriad; thook.
When be was I do them run my brains.

AEdid Apon, the brack:
I me madam: I warrantle your bosolute,
Phast will majesty, Callord; now I will be receive--
Is do I'll theough doth stranget, and thy hand:
Cremy, and so so come that take him:
'Tis a stand; virtue is a more coant!
No purness, who had go: had cheep their bury heard?

PARIS:
Bear you, let him strengdom from undence
Ir warl the Claudio, let speective was deny.

HENRY BOLINGS:
Ay, did you?

TABELLET:
The city to to the worlds speak.'

LADY ANNE:
Nay, my and begether not preat thee: you for kin,
The cwand-charge their news, our sand,
Have billow'd and not and son, suffect,
Contencains so that
To be in the out forbish noble and joice
When to-buy gracious end four Rome to Richard,
But in master him of racks in a lie I said, as in his
bualt, that, and my do rime!  all unfes, and
Is dook upon as since my seeger him was
conser mate put our doops qover aga

In [27]:
B, T, E = 4, 8, 2
x = torch.randn(B, T, E)
xbow = torch.zeros((B, T, E))

In [28]:
# method 1
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t, D)
        xbow[b, t] = torch.mean(xprev, 0) # (D,)

In [29]:
# method 2
wei = torch.tril(torch.ones((T, T)))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (-, T, T) @ (B, T, C) = (B, T, C)

In [30]:
# method 3
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x # (-, T, T) @ (B, T, C) = (B, T, C)

![](diagram.png)

In [31]:
# method 4: Attention Mechanism
head_dim = 16
query = nn.Linear(E, head_dim, bias=False)
key = nn.Linear(E, head_dim, bias=False)
value = nn.Linear(E, head_dim, bias=False)
q = query(x) # q = Wx (B, T, 16)
k = key(x) # k = Wx (B, T, 16)
v = value(x) # v = Wx (B, T, 16)

# dot product between two vectors measures similarity
# this matrix multiplication dots each vector to every other vector
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (8, 16, T) = (B, T, T)
wei = wei * head_dim**(-0.5)

tril = torch.tril(torch.ones(T, T)) # optional (for decoder)
wei = wei.masked_fill(tril == 0, float('-inf')) # optional (for decoder)
wei = F.softmax(wei, dim=-1)
out = wei @ v # (B, T, T) @ (B, T, H) = (B, T, H)

In [32]:
# https://youtu.be/kCc8FmEb1nY?si=AEN-_b8nxN1nxvDf&t=5248