In [39]:
# data
import torch
torch.manual_seed(1337)

B, T = 4, 8
def gen_dataset(data):
    ix_rand = torch.randint(len(data) - T, (B,))
    x = torch.stack([data[i:i+T] for i in ix_rand])
    y = torch.stack([data[i+1:i+T+1] for i in ix_rand])
    return x, y

with open('./data/shakespeare.txt', 'r') as f:
    text = f.read()
vocab = sorted(list(set(''.join(text))))
V = len(vocab)

# tokenize
encode = { c:i for i,c in enumerate(vocab) }
decode = { i:c for i,c in enumerate(vocab) }
data = torch.tensor([encode[c] for c in text], dtype=torch.long)

# dataload
n1, n2 = int(0.8*len(data)), int(0.9*len(data))
Xtr_BT, Ytr_BT = gen_dataset(data[:n1])
Xdev_BT, Ydev_BT = gen_dataset(data[n1:n2])
Xte_BT, Yte_BT = gen_dataset(data[n2:])

print(Xtr_BT)
print(Ytr_BT)

for b in range(B):
    print('batch', b)
    for t in range(T):
        context = Xtr_BT[b, :t+1]
        target = Ytr_BT[b, t]
        print('x:', context, '->', 'y:', target)

tensor([[58, 63,  8,  0,  0, 19, 24, 27],
        [39, 59, 45, 46, 58,  1, 46, 43],
        [49, 43, 57,  1, 53, 50, 42,  1],
        [52, 41, 47, 43, 52, 58,  1, 56]])
tensor([[63,  8,  0,  0, 19, 24, 27, 33],
        [59, 45, 46, 58,  1, 46, 43,  1],
        [43, 57,  1, 53, 50, 42,  1, 46],
        [41, 47, 43, 52, 58,  1, 56, 47]])
batch 0
x: tensor([58]) -> y: tensor(63)
x: tensor([58, 63]) -> y: tensor(8)
x: tensor([58, 63,  8]) -> y: tensor(0)
x: tensor([58, 63,  8,  0]) -> y: tensor(0)
x: tensor([58, 63,  8,  0,  0]) -> y: tensor(19)
x: tensor([58, 63,  8,  0,  0, 19]) -> y: tensor(24)
x: tensor([58, 63,  8,  0,  0, 19, 24]) -> y: tensor(27)
x: tensor([58, 63,  8,  0,  0, 19, 24, 27]) -> y: tensor(33)
batch 1
x: tensor([39]) -> y: tensor(59)
x: tensor([39, 59]) -> y: tensor(45)
x: tensor([39, 59, 45]) -> y: tensor(46)
x: tensor([39, 59, 45, 46]) -> y: tensor(58)
x: tensor([39, 59, 45, 46, 58]) -> y: tensor(1)
x: tensor([39, 59, 45, 46, 58,  1]) -> y: tensor(46)
x: tensor([39, 5

In [40]:
""" model: transformer (Vaswani et al. 2017 https://arxiv.org/abs/1706.03762)

Dimension key:

B: batch size
T: sequence length
V: vocabulary size
D: model dimension (d_model/embedding_dim)
H: number of attention heads in a layer
K: size of each attention key or value (d_kv)
F: feed-forward subnetwork hidden size
"""
# model
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)
D, K = 64, 16 # why is C redefined here? not V??

class SA(nn.Module):
    def __init__(self, K):
        super().__init__()
        self.Wq_DK = nn.Linear(D,K,bias=False)
        self.Wk_DK = nn.Linear(D,K,bias=False)
        self.Wv_DK = nn.Linear(D,K,bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(T, T)))

    def forward(self, X_BTD):
        B,T,D = X_BTD.shape
        Q_BTK, K_BTK, V_BTK = self.Wq_DK(X_BTD), self.Wk_DK(X_BTD), self.Wv_DK(X_BTD)

        A_BTT = Q_BTK @ K_BTK.transpose(-2, -1) * D**-0.5 # todo, pytorch transpose api
        A_BTT = A_BTT.masked_fill(self.tril[:T, :T]==0, float('-inf'))
        A_BTT = F.softmax(A_BTT, dim=-1) # todo, when dim=-1?

        H_BTK = A_BTT @ V_BTK

        return H_BTK

class MHA(nn.Module):
    def __init__(self, H, K):
        super().__init__()
        self.heads = nn.ModuleList([SA(K) for _ in range(H)])

    def forward(self, X_BTD):
        return torch.cat([h(X_BTD) for h in self.heads], dim=-1)

class FFN(nn.Module):
    def __init__(self, D):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(D, D),
            nn.ReLU()
        )

    def forward(self, X_BTK):
        return self.net(X_BTK)

class Bigram(nn.Module):
    def __init__(self, V):
        super().__init__()
        self.token_embedding_table = nn.Embedding(V, D)
        self.position_embedding_table = nn.Embedding(T, D)
        self.sa_heads = MHA(4, D//4)
        self.ffn = FFN(D)
        self.lm_head = nn.Linear(D, V)

    def forward(self, X_BT, Y_BT=None): # Y_BT is optional for inference

        B, T = X_BT.shape

        Xtok_BTD = self.token_embedding_table(X_BT)
        Xpos_TD = self.position_embedding_table(torch.arange(T))
        X_BTD = Xtok_BTD + Xpos_TD

        X_BTK = self.sa_heads(X_BTD)
        X_BTK = self.ffn(X_BTK)
        logits_BTV = self.lm_head(X_BTK)
 
        if Y_BT is None:
            loss = None
        else:
            B, T, V = logits_BTV.shape # support inference
            loss = F.cross_entropy(logits_BTV.view(B*T, V), Y_BT.view(B*T)) # reshape for .cross_entropy()

        return logits_BTV, loss

    # generate: X_BT -> X_B(T+N)
    def generate(self, X_BT, N):
        X_BTi = X_BT

        for _ in range(N):
            X_BTi = X_BTi[: -T:] # crop to last T tokens
            logits_BTV, _ = self(X_BTi)
            logits_BV = logits_BTV[:, -1, :] # pluck out last prediction (t=T)
            probs = F.softmax(logits_BV, dim=-1)
            y_hat_B1 = torch.multinomial(probs, num_samples=1)
            X_BTi = torch.cat((X_BTi, y_hat_B1), dim=1)
        
        return X_BTi

# initialization
m = Bigram(V)
logits_BTV, loss = m(Xtr_BT, Ytr_BT)
print(logits_BTV.shape)
print(loss)

# B_inf, T_inf = 1, 1
# X_BinfTinf = torch.zeros((B_inf,T_inf), dtype=torch.long)
# Y_hat_BTplusN = m.generate(X_BinfTinf, N=100)
# Y_hat_BTplusNdecoded = ''.join([decode[i] for i in Y_hat_BTplusN[0].tolist()]) # 0 since B_inf = 1 
# print(Y_hat_BTplusNdecoded)

torch.Size([4, 8, 65])
tensor(4.1785, grad_fn=<NllLossBackward0>)


In [41]:
# training loop
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

Btr = 32
for steps in range(10000):
    X_BT, Y_BT = gen_dataset(data[:n1])
    logits_BTV, loss = m(X_BT, Y_BT)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

# B_inf, T_inf = 1, 1
# X_BinfTinf = torch.zeros((B_inf,T_inf), dtype=torch.long)
# Y_hat_BTplusN = m.generate(X_BinfTinf, N=500)
# Y_hat_BTplusNdecoded = ''.join([decode[i] for i in Y_hat_BTplusN[0].tolist()]) # 0 since B_inf = 1 
# print(Y_hat_BTplusNdecoded)

2.0604941844940186
