In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [10]:
# Read the whole training corpus into memory as a single string.
# The context manager closes the file handle deterministically, and the
# explicit encoding keeps the decoded text independent of the platform default.
with open('input.txt', 'r', encoding='utf-8') as f:
    shakespeare = f.read()

In [11]:
# Vocabulary: every distinct character of the corpus, in sorted order.
vocab = sorted(set(shakespeare))
vocab_size = len(vocab)
vocab_size

65

In [14]:
# Display the full vocabulary as one string for a quick visual check.
''.join(vocab)

"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [24]:
# Character <-> integer-id lookup tables built from the sorted vocabulary.
stoi = {ch: i for i, ch in enumerate(vocab)}
itos = {i: ch for ch, i in stoi.items()}


def encode(text):
    """Return the list of token ids for the characters of `text`.

    Raises KeyError for a character that is not in `stoi`.
    """
    return [stoi[ch] for ch in text]


def decode(idx):
    """Return the string spelled by the iterable of token ids `idx`.

    Raises KeyError for an id that is not in `itos`.
    """
    return ''.join([itos[i] for i in idx])


decode(encode('Hello there')) # test encode-decode functionality

'Hello there'

In [None]:
# Character-level encoding yields one id per character, so both printed
# lengths are expected to match.
print(len(shakespeare))
encoded_text = encode(shakespeare)
print(len(encoded_text))

# First 90% of the token stream is for training, the remainder for validation.
n = int(0.9 * len(encoded_text))
train_data, val_data = encoded_text[:n], encoded_text[n:]

1115394
1115394


In [83]:
# Number of tokens in one context window.
block_size = 8

In [84]:
# One window plus one extra token: block_size inputs and the target that follows the last one.
train_data[:block_size + 1]

[18, 47, 56, 57, 58, 1, 15, 47, 58]

In [85]:
# Show what the input -> output pairs look like for the given block_size:
# every prefix of the window is a context, and the token right after it is the target.

x = train_data[:block_size + 1]
for i, output in enumerate(x[1:], start=1):
    inp = x[:i]
    print(f'{inp} --> {output}')

[18] --> 47
[18, 47] --> 56
[18, 47, 56] --> 57
[18, 47, 56, 57] --> 58
[18, 47, 56, 57, 58] --> 1
[18, 47, 56, 57, 58, 1] --> 15
[18, 47, 56, 57, 58, 1, 15] --> 47
[18, 47, 56, 57, 58, 1, 15, 47] --> 58


In [87]:
# Seed torch's RNG so the randomly sampled batches below are reproducible.
torch.manual_seed(2)

# making a batch of data
batch_size = 4  # number of sequences sampled per batch
block_size = 8  # number of tokens per sequence

def get_batch(split):
    """Sample a random batch of input/target sequences from one data split.

    Reads the module-level `train_data`, `val_data`, `batch_size` and
    `block_size`.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.

    Returns:
        A tuple `(x, y)` of integer tensors, each of shape
        (batch_size, block_size). `y[b, t]` is the token that follows
        `x[b, t]` in the selected split, i.e. `y` is `x` shifted by one.
    """
    data = train_data if split == 'train' else val_data
    # Sample start offsets from the *selected* split. The exclusive upper bound
    # leaves room for block_size inputs plus the one extra token the targets need.
    ix = torch.randint(0, len(data) - block_size, (batch_size,)).tolist()

    x = torch.tensor([data[i : i + block_size] for i in ix])
    y = torch.tensor([data[i + 1 : i + 1 + block_size] for i in ix])

    return x, y

In [101]:
# Draw one training batch and print inputs (xb) and targets (yb);
# each row of yb is the matching row of xb shifted by one token.
xb, yb = get_batch('train')

print(xb)
print(yb)

tensor([[47, 51,  6,  0, 13, 57,  1, 46],
        [60, 43,  1, 51, 39, 42, 43,  1],
        [ 1, 42, 53,  1, 51, 47, 57, 58],
        [53, 50, 42,  1, 40, 50, 53, 53]])
tensor([[51,  6,  0, 13, 57,  1, 46, 43],
        [43,  1, 51, 39, 42, 43,  1, 45],
        [42, 53,  1, 51, 47, 57, 58, 39],
        [50, 42,  1, 40, 50, 53, 53, 42]])


In [102]:
# Print every (context --> target) pair contained in a batch row.
for b in range(batch_size):
    for t in range(block_size):
        inp = xb[b, : t+1]  # context: tokens 0..t of row b
        out = yb[b, t]      # target: the token that follows that context
        print(f'{inp} --> {out}')
    break  # only the first row of the batch is shown

tensor([47]) --> 51
tensor([47, 51]) --> 6
tensor([47, 51,  6]) --> 0
tensor([47, 51,  6,  0]) --> 13
tensor([47, 51,  6,  0, 13]) --> 57
tensor([47, 51,  6,  0, 13, 57]) --> 1
tensor([47, 51,  6,  0, 13, 57,  1]) --> 46
tensor([47, 51,  6,  0, 13, 57,  1, 46]) --> 43


In [233]:
# torch.ones(2,4,8)[:,-1:,:].shape

In [247]:
class BigramLM(nn.Module):
    """Bigram language model.

    The logits for the next token are looked up directly from an embedding
    table indexed by the current token, so each prediction depends only on
    the single token at that position.
    """

    def __init__(self, vocab_size):
        """Create the (vocab_size x vocab_size) token-to-logits lookup table."""
        super().__init__()
        self.token_encoding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: integer token ids of shape (b, t).
            targets: optional integer token ids of shape (b, t).

        Returns:
            (logits, loss): logits has shape (b, t, vocab_size); loss is a
            scalar tensor, or None when `targets` is None.
        """
        logits = self.token_encoding_table(x)  # (b, t, vocab_size)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects the class dimension second, so move
            # vocab_size in front of the time dimension: (b, vocab_size, t).
            loss = F.cross_entropy(logits.transpose(-1, -2), targets)

        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; avoids building a graph per step
    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens one at a time.

        Args:
            idx: integer token ids of shape (b, t) used as the prompt.
            max_new_tokens: number of tokens to append to each row.

        Returns:
            Integer tensor of shape (b, t + max_new_tokens): the prompt
            followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            logits, _loss = self(idx)  # (b, t, vocab_size)
            last_logits = logits[:, -1, :]  # keep only the final position: (b, vocab_size)
            probs = F.softmax(last_logits, dim=-1)  # (b, vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1)  # (b, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (b, t+1)
        return idx


In [279]:
# Instantiate an untrained bigram model sized to the character vocabulary.
bigram_model = BigramLM(vocab_size)

# Draw a fresh training batch.
xb, yb = get_batch('train')

# logits, loss = bigram_model(xb, yb)
# logits.shape, loss.item()

In [280]:
# Sample 100 tokens from the untrained model, starting from a single token with id 0.
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(bigram_model.generate(idx, max_new_tokens=100)[0].tolist()))


:!UYJOro,:MAcBrF!fsi3a:wnveh'FN-c YBFwC VOrZ3lvBk$!jmv&v;yt'.IIy?F$Fy'?h'VwC$Sd$ ;a;RNPROEzPRESTQ;A3


In [281]:
# optimizer: AdamW over all parameters of the bigram model
optimizer = torch.optim.AdamW(bigram_model.parameters(), lr=1e-3)

In [282]:
# training
# NOTE: this rebinds the module-level batch_size, which get_batch reads.
batch_size = 32
for _ in range(10000):
    # Drop the previous step's gradients before computing new ones.
    optimizer.zero_grad(set_to_none=True)

    xb, yb = get_batch('train')
    logits, loss = bigram_model(xb, yb)

    loss.backward()
    optimizer.step()
# Loss of the final mini-batch only (not an average over the run).
print(loss.item())

2.5395166873931885


In [283]:
# Sample 500 tokens from the trained model, starting from a single token with id 0.
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(bigram_model.generate(idx, max_new_tokens=500)[0].tolist()))


LBEDeame ghy, cthemaiandsth wesinomy ind;
Derg th tanistowat broshin se ate.
BRDY: d haseche are,
' sha theale ay ndrdiopou checreeo sthe ifeaus, bre.
Gofo cous tr nerim s

ARETofovil, asu he?
Whoft?
uant quthind t,
Th th letourthanchousirs f ANowousgre d t?
Hay wind omayo, impon my che?
HENCapangeef s IClelise fe chisthise asthicoulis y a IZorenoou.
Beanos, Fower llveriz.
LIAy d:
As maleun:
Hacethatheetrusthieend t thauery! cee ndaze cho y thelgal t hen r akim he, owom!
GRI bo,
Sapin mon thayor
