In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Data

In [2]:
# https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
chars = sorted(list(set(text)))
vocab_dim = len(chars)

In [5]:
# sentencepiece
# tiktoken

In [6]:
# import tiktoken
# enc = tiktoken.get_encoding('gpt2')
# enc.encode('hi my name is linsu!')
# enc.decode([5303, 616, 1438, 318, 300, 1040, 84, 0])

In [7]:
c_i = {c:i for i, c in enumerate(chars)}
i_c = {i:c for i, c in enumerate(chars)}
encode = lambda s: [c_i[c] for c in s]
decode = lambda l: ''.join([i_c[i] for i in l])

In [8]:
data = torch.tensor(encode(text), dtype=torch.int64)
data.shape

torch.Size([1115394])

In [9]:
n = int(.9*len(data))
data_train = data[:n]
data_val = data[n:]

In [10]:
len(data_train)

1003854

# Model

In [11]:
B = 32 # batch size
T = 8 # sequence length
embed_dim = 32
train_steps = 4500
lr = 1e-3 # learning rate
torch.manual_seed(1337)
device = torch.device('cpu')

In [12]:
def get_batch(data, B):
    idx = torch.randint(len(data) - T, (B,))
    x = torch.stack([data[i:i+T] for i in idx])
    y = torch.stack([data[i+1:i+T+1] for i in idx])
    return x, y

In [13]:
x, y = get_batch(data_train, B)

In [14]:
x.shape, y.shape

(torch.Size([32, 8]), torch.Size([32, 8]))

In [15]:
print([decode(l) for l in x.numpy()])
print([decode(l) for l in y.numpy()])

["Let's he", 'for that', 'nt that ', 'MEO:\nI p', 'seven ye', 'ved.\nMy ', "rd's sak", 'esty, th', 'e may be', 'the ears', 'ation? Y', 'ore I ca', 'lany its', 'roy did ', 'am afrai', 'ELIZABET', ' and gel', ' that do', ' would I', 'usband b', 'nd.\n\nKIN', 'gods\nKee', 'n was mo', 'ak? O tr', 'of speec', 'sons.\n\nM', ' defect ', 'I wander', "eem'd bu", 'gly now?', 'not.\nMy ', 'ot, my l']
["et's hea", 'or that ', 't that h', 'EO:\nI pa', 'even yea', 'ed.\nMy g', "d's sake", 'sty, thi', ' may be ', 'he ears:', 'tion? Yo', 're I cam', 'any itse', 'oy did s', 'm afraid', 'LIZABETH', 'and geld', 'that do ', 'would I ', 'sband bi', 'd.\n\nKING', 'ods\nKeep', ' was mor', 'k? O tra', 'f speech', 'ons.\n\nME', 'defect o', ' wander,', "em'd bur", 'ly now?\n', 'ot.\nMy w', 't, my lo']


In [16]:
class SelfAttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, head_dim, bias=False)
        self.key = nn.Linear(embed_dim, head_dim, bias=False)
        self.value = nn.Linear(embed_dim, head_dim, bias=False)
        # self.register_buffer('tril', torch.tril(torch.ones(sequence_dim, sequence_dim)))

    def forward(self, x):
        B, T, E = x.shape
        q = self.query(x) # (B, T, E) -> (B, T, H)
        k = self.key(x) # (B, T, E) -> (B, T, H)
        v = self.value(x) # (B, T, E) -> (B, T, H)
        wei = q @ k.transpose(-2, -1) * E**(-0.5) # (B, T, H) @ (B, H, T) = (B, T, T)
        tril = torch.tril(torch.ones(T, T)) # how to register buffer without T param
        wei = wei.masked_fill(tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        out = wei @ v # (B, T, T) @ (B, T, H) = (B, T, H)
        return out

In [17]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, head_dim=None):
        super().__init__()
        if head_dim is None:
            head_dim = embed_dim//num_heads
        self.heads = nn.ModuleList([SelfAttentionHead(embed_dim, head_dim) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

In [18]:
class FeedForward(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU()
        )
    def forward(self, x):
        return self.net(x)

In [19]:
# class MHABlock(nn.Module):
#     def __init__(self):
#         self.mha = (n_head, head_dim)

In [20]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_dim, embed_dim, sequence_dim):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_dim, embed_dim)
        self.position_embedding = nn.Embedding(sequence_dim, embed_dim)
        self.sah = MultiHeadAttention(embed_dim, 4)
        self.ff = FeedForward(embed_dim, embed_dim)
        self.linear = nn.Linear(embed_dim, vocab_dim)

    def forward(self, x, y=None):
        # B is batch, T is length of time series, E is embedding dim
        B, T = x.shape
        token_embeddings = self.token_embedding(x) # (B, T, E)
        position_embeddings = self.position_embedding(torch.arange(T).to(device)) # (T, E) # T <= sequence_dim
        x = token_embeddings + position_embeddings # (B, T, E) +  (-, T, E) -> (B, T, E)
        x = self.sah(x)
        x = self.ff(x)
        logits = self.linear(x)
        if y is None:
            loss = None
        else:
            B, T, E = logits.shape
            logits = logits.view(B*T, E)
            y = y.view(B*T)
            loss = F.cross_entropy(logits, y)
        return logits, loss

    def generate(self, x, n_tokens):
        for i in range(n_tokens):
            x_cropped = x[:, -T:] # crop s.t. it's <= sequence_dim
            logits, _ = self(x_cropped) # (B, T, E)
            logits = logits[:, -1, :] # (B, E)
            probs = F.softmax(logits, dim=-1) # (B, E)
            y_pred = torch.multinomial(probs, num_samples=1) # (B, 1)
            x = torch.cat((x, y_pred), dim=1) # (B, T) + (B, 1) = (B, T + 1)
        return x

In [21]:
model = BigramLanguageModel(vocab_dim, embed_dim, T).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [22]:
model

BigramLanguageModel(
  (token_embedding): Embedding(65, 32)
  (position_embedding): Embedding(8, 32)
  (sah): MultiHeadAttention(
    (heads): ModuleList(
      (0): SelfAttentionHead(
        (query): Linear(in_features=32, out_features=8, bias=False)
        (key): Linear(in_features=32, out_features=8, bias=False)
        (value): Linear(in_features=32, out_features=8, bias=False)
      )
      (1): SelfAttentionHead(
        (query): Linear(in_features=32, out_features=8, bias=False)
        (key): Linear(in_features=32, out_features=8, bias=False)
        (value): Linear(in_features=32, out_features=8, bias=False)
      )
      (2): SelfAttentionHead(
        (query): Linear(in_features=32, out_features=8, bias=False)
        (key): Linear(in_features=32, out_features=8, bias=False)
        (value): Linear(in_features=32, out_features=8, bias=False)
      )
      (3): SelfAttentionHead(
        (query): Linear(in_features=32, out_features=8, bias=False)
        (key): Linear(in_fe

In [23]:
# pre training
print(decode(model.generate(torch.zeros((1, 1), dtype=torch.int64).to(device), 100).cpu().numpy()[0]))


U
icIHTZZqPsVu&tqWdUlORt&GlV&VcohJNYT;eQz:mxixwYg,gAo3E!N hOY$!VqTpEHLzMBXsvzXRvzYG:wiB&y,iVED$-xKn;


In [24]:
@torch.no_grad()
def estimate_loss(model, iters, device):
    out = []
    model.to(device)
    model.eval()
    losses = torch.zeros(iters)
    for data in [data_train, data_val]:
        for i in range(iters):
            x, y = get_batch(data, B)
            x = x.to(device)
            y = y.to(device)
            logits, loss = model(x, y)
            losses[i] = loss.item()
        out.append(losses.mean())
    model.train()
    return out

In [25]:
tenth = train_steps//10
for steps in range(train_steps):
    x, y = get_batch(data_train, B)
    x = x.to(device)
    y = y.to(device)
    logits, loss = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if steps % tenth == 0:
        train_loss, val_loss = estimate_loss(model, 100, device)
        print(train_loss, val_loss)

tensor(4.2014) tensor(4.2007)
tensor(2.6325) tensor(2.6614)
tensor(2.4881) tensor(2.5101)
tensor(2.4194) tensor(2.4270)
tensor(2.3769) tensor(2.4007)
tensor(2.3177) tensor(2.3485)
tensor(2.3080) tensor(2.3188)
tensor(2.2774) tensor(2.2941)
tensor(2.2613) tensor(2.2824)
tensor(2.2552) tensor(2.2706)


In [26]:
# post training
print(decode(model.generate(torch.zeros((1, 1), dtype=torch.int64).to(device), 1000).cpu().numpy()[0]))


Fo wernereg
Shounds hobll his whiloll us ale,
I theat coouf mat mpir bow biven?
Mowy's of ing kmesallir:
I andmun, blome nomsthee am wice; sotheyearsteves peewibth; gon,
Mor ane gor all, sher to sas,
thas ompevos to boos of whor,
Gothe ap aplave bule onceivou
Ren sor hen meyto che the miss ass khan; hor. Youst she no thee; evesters and in hant old; yound
Id: peignilmlot-nen's swor past, I'd werms youp.
Nis gou Vofe, of as his a pour dinow of gave stoo plablad;
Me's lrme.

JOMER:
BIF V, is, tre con, fe; wear itht,'dsees
Towire, to rait thuntwend whoce'simust, houng le this you, che ory of our to far: I
Shond she thamy minesse whes youn:
How peanto-reapranted, toprchs,
Wherit pre hall thllookse urig lour go ward Meiseasesme and yorun. Eram,
Dor on:

thee datthat fen, sis lirurencerm athew
JI
Thato wirdes with sich!y sound mome knowhis, Whut perd'cenquechisaurtit,
Hef by, eak wed hou thante fourntre woilt:
Sesintact, tou worest wis rer's bomote se beespheair sse! Yatthe ste, hou serem.



In [27]:
B, T, E = 4, 8, 2
x = torch.randn(B, T, E)
xbow = torch.zeros((B, T, E))

In [28]:
# method 1
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t, D)
        xbow[b, t] = torch.mean(xprev, 0) # (D,)

In [29]:
# method 2
wei = torch.tril(torch.ones((T, T)))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (-, T, T) @ (B, T, C) = (B, T, C)

In [30]:
# method 3
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x # (-, T, T) @ (B, T, C) = (B, T, C)

![](diagram.png)

In [31]:
# method 4: Attention Mechanism
head_dim = 16
query = nn.Linear(E, head_dim, bias=False)
key = nn.Linear(E, head_dim, bias=False)
value = nn.Linear(E, head_dim, bias=False)
q = query(x) # q = Wx (B, T, 16)
k = key(x) # k = Wx (B, T, 16)
v = value(x) # v = Wx (B, T, 16)

# dot product between two vectors measures similarity
# this matrix multiplication dots each vector to every other vector
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (8, 16, T) = (B, T, T)
wei = wei * head_dim**(-0.5)

tril = torch.tril(torch.ones(T, T)) # optional (for decoder)
wei = wei.masked_fill(tril == 0, float('-inf')) # optional (for decoder)
wei = F.softmax(wei, dim=-1)
out = wei @ v # (B, T, T) @ (B, T, H) = (B, T, H)

In [None]:
# https://youtu.be/kCc8FmEb1nY?si=AEN-_b8nxN1nxvDf&t=5248