In [143]:
""" transformer: https://arxiv.org/abs/1706.03762

Dimension key:

B: batch size
T: sequence length
V: vocabulary size
"""

# data
import torch
torch.manual_seed(1337)

B, T = 4, 8
def gen_dataset(data):
    ix_rand = torch.randint(len(data) - T, (B,))
    x = torch.stack([data[i:i+T] for i in ix_rand])
    y = torch.stack([data[i+1:i+T+1] for i in ix_rand])
    return x, y

with open('./data/shakespeare.txt', 'r') as f:
    text = f.read()
vocab = sorted(list(set(''.join(text))))
V = len(vocab)

# tokenize
encode = { c:i for i,c in enumerate(vocab) }
# encode['.'] = 0
decode = { i:c for i,c in enumerate(vocab) }
data = torch.tensor([encode[c] for c in text], dtype=torch.long)

# dataload
n1, n2 = int(0.8*len(data)), int(0.9*len(data))
Xtr_BT, Ytr_BT = gen_dataset(data[:n1])
Xdev_BT, Ydev_BT = gen_dataset(data[n1:n2])
Xte_BT, Yte_BT = gen_dataset(data[n2:])

print(Xtr)
print(Ytr)

for b in range(B):
    print('batch', b)
    for t in range(T):
        context = Xtr[b, :t+1]
        target = Ytr[b, t]
        print('x:', context, '->', 'y:', target)

tensor([[58, 63,  0,  0,  0, 19, 24, 27],
        [39, 59, 45, 46, 58,  1, 46, 43],
        [49, 43, 57,  1, 53, 50, 42,  1],
        [52, 41, 47, 43, 52, 58,  1, 56]])
tensor([[63,  0,  0,  0, 19, 24, 27, 33],
        [59, 45, 46, 58,  1, 46, 43,  1],
        [43, 57,  1, 53, 50, 42,  1, 46],
        [41, 47, 43, 52, 58,  1, 56, 47]])
batch 0
x: tensor([58]) -> y: tensor(63)
x: tensor([58, 63]) -> y: tensor(0)
x: tensor([58, 63,  0]) -> y: tensor(0)
x: tensor([58, 63,  0,  0]) -> y: tensor(0)
x: tensor([58, 63,  0,  0,  0]) -> y: tensor(19)
x: tensor([58, 63,  0,  0,  0, 19]) -> y: tensor(24)
x: tensor([58, 63,  0,  0,  0, 19, 24]) -> y: tensor(27)
x: tensor([58, 63,  0,  0,  0, 19, 24, 27]) -> y: tensor(33)
batch 1
x: tensor([39]) -> y: tensor(59)
x: tensor([39, 59]) -> y: tensor(45)
x: tensor([39, 59, 45]) -> y: tensor(46)
x: tensor([39, 59, 45, 46]) -> y: tensor(58)
x: tensor([39, 59, 45, 46, 58]) -> y: tensor(1)
x: tensor([39, 59, 45, 46, 58,  1]) -> y: tensor(46)
x: tensor([39, 5

In [144]:
# model
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)

class Bigram(nn.Module):
    def __init__(self, V):
        super().__init__()
        self.token_embedding_table = nn.Embedding(V, V)

    # Y_BT is optional for inference
    def forward(self, X_BT, Y_BT=None):
        logits_BTV = self.token_embedding_table(X_BT)
        B, T, V = logits_BTV.shape # support inference
        if Y_BT is None:
            loss = None
        else:
            loss = F.cross_entropy(logits_BTV.view(B*T, V), Y_BT.view(B*T)) # reshape for .cross_entropy()

        return logits_BTV, loss

    # generate: X_BT -> X_B(T+N)
    def generate(self, X_BT, N):
        X_BTi = X_BT
        for _ in range(N):
            logits_BTV, _ = self(X_BTi)
            logits_BV = logits_BTV[:, -1, :] # pluck out last prediction (t=T)
            probs = F.softmax(logits_BV, dim=1)
            y_hat_B1 = torch.multinomial(probs, num_samples=1)
            X_BTi = torch.cat((X_BTi, y_hat_B1), dim=1)
        
        return X_BTi

m = Bigram(V)
logits_BTV, loss = m(Xtr_BT, Ytr_BT)
print(logits_BTV.shape)
print(loss)

B_inf, T_inf = 1, 1
X_BinfTinf = torch.zeros((B_inf,T_inf), dtype=torch.long)
Y_hat_BTplusN = m.generate(X_BinfTinf, N=100)
Y_hat_BTplusNdecoded = ''.join([decode[i] for i in Y_hat_BTplusN[0].tolist()]) # 0 since B_inf = 1 
print(Y_hat_BTplusNdecoded)

torch.Size([4, 8, 65])
tensor(5.0493, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


In [145]:
# training
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

Btr = 32
for steps in range(10000):
    X_BT, Y_BT = gen_dataset(data[:n1])
    logits_BTV, loss = m(X_BT, Y_BT)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.577054500579834


In [146]:
B_inf, T_inf = 1, 1
X_BinfTinf = torch.zeros((B_inf,T_inf), dtype=torch.long)
Y_hat_BTplusN = m.generate(X_BinfTinf, N=500)
Y_hat_BTplusNdecoded = ''.join([decode[i] for i in Y_hat_BTplusN[0].tolist()]) # 0 since B_inf = 1 
print(Y_hat_BTplusNdecoded)


LUng:
j?3o f hereren rdiglsunsomm d fethal m,
KTINNAl tthane Soug ls th un,
AD3XEr tidsey EThe or akGLGHAky!
G&paveelt pofotes dend tou:

AEVuing
SPKI' thickerenlaig, at wg wohay mandstobouThee CH:P! t DI fuf O my me he n:
Biset,Uinl, ke ig tor, micuais s.
blt ind' cegandisins m.
W-ves: anyoroup meartold n ay,
3SSSeral igr

CE:way; ke me En s HEQmy alaf tryo'se, y h d hd t t.
Uhe dizkindlalesed wr't tho ainck, or bblligf whoul dico.
I tly d n, kshiseline MPy LIEN FJUpUgathaite, n:wo o-gusatofthi
