In [2]:
import torch

# Data preparation

## Load raw text

In [3]:
# Read the whole corpus into one string. The encoding is pinned so the result does
# not depend on the platform's default (locale) encoding.
with open('../data/shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
# Peek at the first 200 characters; print() renders the newlines rather than '\n' escapes.
print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


## Tokenization

In [5]:
# The vocabulary is every distinct character in the corpus, in sorted order;
# a character's position in this list becomes its token id.
tokens = sorted(set(text))
''.join(tokens)

"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [6]:
# Lookup tables between token ids (positions in `tokens`) and characters.
itos = dict(enumerate(tokens))
stoi = {ch: i for i, ch in itos.items()}

In [7]:
def encode(text):
    """Map a string to a 1-D ``torch.long`` tensor of token ids using ``stoi``.

    Raises KeyError if ``text`` contains a character that is not in the vocabulary.
    """
    return torch.tensor([stoi[ch] for ch in text], dtype=torch.long)

def decode(tensor):
    """Map a 1-D sequence of token ids back to a string using ``itos``.

    Accepts a 1-D tensor (e.g. ``encode`` output or one row of a model sample)
    or a plain 1-D sequence of Python ints.
    """
    # One tolist() call converts every id at once instead of a .item() per element,
    # and lets plain lists of ints through unchanged.
    ids = tensor.tolist() if isinstance(tensor, torch.Tensor) else tensor
    return ''.join(itos[i] for i in ids)

In [8]:
# Sanity check: each character maps to one integer id.
encode('testi')

tensor([58, 43, 57, 58, 47])

In [9]:
# Round trip: decoding the encoded string should give back the input unchanged.
decode(encode('testi'))

'testi'

In [10]:
# Encode the full corpus into one long 1-D tensor of token ids.
data = encode(text)
data[:10]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])

In [11]:
# The first ids decode back to the opening characters of the corpus.
decode(data[:10])

'First Citi'

In [12]:
# Hold out the last 20% of the corpus as validation data. This is a contiguous
# tail of the text, not a random sample. `split` is also the training-set length.
split = int(0.8 * len(data))
train, val = data[:split], data[split:]

print(train.shape, val.shape)

torch.Size([892315]) torch.Size([223079])


# Dataloader

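Getting a single training example: `x` is a `block_size`-long window of the training data, and `y` is the same window shifted one token to the right, so `y[t]` is the token that follows `x[t]`: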

In [13]:
# Seed torch's global RNG so batch offsets, weight initialisation and sampling
# are reproducible under Restart & Run All.
SEED = 1337
torch.manual_seed(SEED)

block_size = 8  # context length: tokens per example
batch_size = 4  # independent examples per batch

In [14]:
offset = 10  # arbitrary offset for demonstration

# y is x shifted one position to the right: y[t] is the token that follows x[t].
x, y = train[offset : offset + block_size], train[offset + 1 : offset + block_size + 1]

print(x)
print(y)

tensor([64, 43, 52, 10,  0, 14, 43, 44])
tensor([43, 52, 10,  0, 14, 43, 44, 53])


We generate random offsets into the training data:

In [15]:
# Draw batch_size start positions. randint's upper bound is exclusive, which leaves
# room for the block_size-long x plus the one extra token the shifted y needs.
offsets = torch.randint(0, len(train) - block_size, (batch_size,))
offsets

tensor([793523, 806706,  89742, 807802])

For each offset, we then slice a `block_size`-long `x` and a `y` shifted one token to the right. We stack these into a single `x` tensor and a single `y` tensor, each of shape `(batch_size, block_size)`:

In [16]:
# Slice one x window and one shifted y window per offset, then stack each set into
# a (batch_size, block_size) tensor. Index `train`, not `data`: the offsets above
# were drawn for the training range.
print(torch.stack([train[offset : offset+block_size] for offset in offsets]))
print(torch.stack([train[offset+1 : offset+block_size+1] for offset in offsets]))

tensor([[53,  1, 39,  1, 60, 43, 56, 63],
        [57,  1, 58, 46, 43, 43,  1, 58],
        [58, 46, 47, 57,  1, 54, 53, 47],
        [43, 10,  1, 40, 43, 57, 47, 42]])
tensor([[ 1, 39,  1, 60, 43, 56, 63,  1],
        [ 1, 58, 46, 43, 43,  1, 58, 46],
        [46, 47, 57,  1, 54, 53, 47, 52],
        [10,  1, 40, 43, 57, 47, 42, 43]])


In [17]:
def get_batch(data, block_size=block_size, batch_size=batch_size):
    """Sample a random batch of input windows and their next-token targets.

    Args:
        data: 1-D tensor of token ids to sample from (e.g. ``train`` or ``val``).
        block_size: length of each window. The default is the notebook's
            ``block_size`` at the time this cell was run.
        batch_size: number of windows in the batch. The default is captured
            the same way.

    Returns:
        ``(xb, yb)``, each of shape ``(batch_size, block_size)``, where
        ``yb[b, t] == data[offset_b + t + 1]``: the token following ``xb[b, t]``.

    Raises:
        ValueError: if ``data`` is shorter than ``block_size + 1`` tokens.
    """
    # Bound offsets by the tensor actually passed in. Using the global `split`
    # (the training length) would overrun shorter tensors such as `val`.
    # randint's upper bound is exclusive, so the last window's y still fits.
    max_offset = len(data) - block_size
    if max_offset < 1:
        raise ValueError(
            f'data has {len(data)} tokens; need at least block_size + 1 = {block_size + 1}'
        )
    offsets = torch.randint(0, max_offset, (batch_size,))

    xb = torch.stack([data[offset : offset+block_size] for offset in offsets])
    yb = torch.stack([data[offset+1 : offset+block_size+1] for offset in offsets])

    return xb, yb

In [18]:
get_batch(train)

(tensor([[47, 58, 46,  1, 51, 39, 52, 63],
         [52, 42,  1, 51, 43,  1, 58, 53],
         [41, 46, 53, 49, 43,  1, 58, 46],
         [ 1, 53, 52, 41, 43,  8,  1, 35]]),
 tensor([[58, 46,  1, 51, 39, 52, 63,  1],
         [42,  1, 51, 43,  1, 58, 53,  1],
         [46, 53, 49, 43,  1, 58, 46, 43],
         [53, 52, 41, 43,  8,  1, 35, 46]]))

# Model

In [19]:
class BigramModel(torch.nn.Module):
    """Character-level bigram language model.

    Each token id indexes a row of a ``(vocab_size, vocab_size)`` embedding table,
    and that row is used directly as the logits for the next token. Predictions
    therefore depend only on the current token.
    """

    def __init__(self, vocab_size):
        super().__init__()

        # Row i holds the next-token logits for token id i.
        self.embedding = torch.nn.Embedding(vocab_size, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, x, targets=None):
        """Return ``(logits, loss)`` for a batch of token ids.

        Args:
            x: LongTensor of token ids, shape (B, T).
            targets: optional LongTensor of next-token ids, same shape as ``x``.

        Returns:
            ``logits`` of shape (B, T, vocab_size), and the mean cross-entropy
            over all B*T positions, or ``None`` when ``targets`` is omitted.
        """
        logits = self.embedding(x)

        if targets is None:
            return logits, None

        # cross_entropy expects (N, C) logits and (N,) targets, so flatten batch and time.
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, self.vocab_size), targets.reshape(-1)
        )

        return logits, loss

    @torch.no_grad()
    def generate_text(self, x, steps=500):
        """Extend ``x`` by sampling ``steps`` tokens and return the whole sequence.

        Args:
            x: LongTensor of prompt token ids, shape (B, T) with T >= 1.
            steps: number of new tokens to sample for each row.

        Returns:
            LongTensor of shape (B, T + steps): the prompt followed by the samples.
        """
        for _ in range(steps):
            # A bigram model conditions only on the last token, so feeding just that
            # column gives the same distribution as re-running the full sequence,
            # without the quadratic cost.
            logits, _ = self(x[:, -1:])
            probs = torch.nn.functional.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            x = torch.cat([x, next_token], dim=1)

        return x

    @torch.no_grad()
    def generate(self, x, steps=100):
        """Greedily decode ``steps`` tokens after ``x``, yielding each one as it is chosen.

        Args:
            x: LongTensor of prompt token ids, shape (B, T) with T >= 1.
            steps: number of tokens to produce.

        Yields:
            LongTensor of shape (B, 1): the highest-scoring next token for each row.
        """
        # Only the most recent token matters for a bigram model, so carry just that.
        last = x[:, -1:]
        for _ in range(steps):
            logits, _ = self(last)
            # Take the argmax over the vocabulary axis only. The earlier version
            # flattened (T, V), which returned a position index, not a token id.
            last = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            yield last

## Forward pass

In [20]:
# Number of distinct characters; the model's embedding table is vocab_size x vocab_size.
vocab_size = len(tokens)
vocab_size

65

In [21]:
# Untrained model, used below to check shapes, the initial loss and a baseline sample.
model = BigramModel(vocab_size)

In [22]:
# Draw one batch with the default block_size/batch_size.
xb, yb = get_batch(train)
xb.shape # (batch_size, block_size)

torch.Size([4, 8])

In [23]:
# Forward pass with targets: one logit vector per input position, plus the mean loss.
logits, loss = model(xb, yb)
logits.shape # (batch_size, block_size, vocab_size)

torch.Size([4, 8, 65])

In [24]:
# Untrained baseline. For reference, uniform predictions over 65 tokens would give ln(65) ≈ 4.17.
loss

tensor(4.6677, grad_fn=<NllLossBackward0>)

## Generation

In [25]:
# Sample from the untrained model, starting from a (batch=1, time=1) prompt holding
# token id 0 ('\n' in this vocabulary), with generate_text's default number of steps.
x = model.generate_text(torch.zeros((1, 1), dtype=torch.long))
x

tensor([[ 0, 12, 31, 33,  6,  8, 60, 55, 36, 24, 43, 42, 59, 40, 62, 10, 21, 22,
         40, 13, 16, 58, 16, 24, 40,  5, 15,  7,  6, 40, 62, 18, 37, 46,  9, 37,
         38, 45, 55, 58, 39, 32, 51, 21,  3, 28, 11,  5, 15, 64, 18, 38, 40,  7,
         12,  1, 46,  9, 57,  7, 16, 11, 39, 40, 24, 61, 45,  8, 48, 62, 19, 26,
          4, 11, 55, 17,  0, 29, 50, 10, 35, 10, 43, 34, 48, 38, 30, 22, 27,  6,
          4, 62,  9, 16, 11, 51, 28, 15, 12, 53, 53, 11, 21,  1, 57,  7, 39, 62,
         21,  0, 52, 11,  5,  9, 33, 60, 63, 26,  0,  1,  9, 33, 50, 47,  2, 25,
         44, 44, 64, 24, 12, 31, 45,  5,  9,  0, 38, 34, 28, 15, 26, 14, 19, 19,
         23, 33, 21, 11, 53, 25, 51, 22,  9, 29, 16, 40, 23, 17, 17, 46, 18, 28,
         51, 29, 39, 24, 33,  3, 19,  6, 54, 14, 17, 25, 36, 28, 64, 24, 12, 10,
         26, 26, 37, 48, 10, 60,  5,  0, 46, 22,  7, 23, 51, 33, 49, 24, 50, 35,
         16, 64, 58, 59, 30, 20, 15, 27,  0,  1, 59, 21,  3, 23, 10, 26,  6, 44,
         40, 23, 17, 18, 43,

In [26]:
# Decode the single sampled row back to text; untrained, it is random characters.
print(decode(x[0]))


?SU,.vqXLedubx:IJbADtDLb'C-,bxFYh3YZgqtaTmI$P;'CzFZb-? h3s-D;abLwg.jxGN&;qE
Ql:W:eVjZRJO,&x3D;mPC?oo;I s-axI
n;'3UvyN
 3Uli!MffzL?Sg'3
ZVPCNBGGKUI;oMmJ3QDbKEEhFPmQaLU$G,pBEMXPzL?:NNYj:v'
hJ-KmUkLlWDztuRHCO
 uI$K:N,fbKEFeKVUBJvEvq xeGo&A;zRO
mjNoOAzgRdRMz'3&EcK.ntffBNE I.
Qt$HddP;zYzoIdxtfs:eqEYjQZRUlPp:WhvAYmynsrd3ZRP!ulPPOIzuuUEcKE:DV YrpnT?o&FNLlW:e!WEdTr'CScpqI$BJ$eqE:ym'WXXe&.,aW,pC-Qn,&y
YhwbLelfD
h.IJOn'IJ,,ptsMPm XEvVfbAvl nt:-&MuOAWg.??tjJ3Bo&HktfMzGRdaJ3smTnW,TNM3zrp:ADxV:a$PCGwm,pcHG-Q


# Training

In [27]:
n_epochs = 1000
learning_rate = 1e-2

# A fresh model, separate from `model` above, so training starts from scratch.
m = BigramModel(vocab_size)
optim = torch.optim.Adam(m.parameters(), lr=learning_rate)

In [28]:
# Report the loss about 10 times per run. max(1, ...) avoids a modulo-by-zero
# when n_epochs < 10.
print_every = max(1, n_epochs // 10)

# Each "epoch" here is one optimizer step on a single random batch, not a full pass over `train`.
for epoch in range(n_epochs):
    # Longer windows than the demo's block_size give more (token, next-token) pairs per step.
    xb, yb = get_batch(train, block_size=128)
    optim.zero_grad()

    logits, loss = m(xb, yb)
    if epoch % print_every == 0:
        print(f'Epoch {epoch}, loss: {loss.item()}')

    loss.backward()
    optim.step()

Epoch 0, loss: 4.601304054260254
Epoch 100, loss: 3.5026698112487793
Epoch 200, loss: 2.9861037731170654
Epoch 300, loss: 2.694559335708618
Epoch 400, loss: 2.6426682472229004
Epoch 500, loss: 2.5889017581939697
Epoch 600, loss: 2.492767572402954
Epoch 700, loss: 2.450739622116089
Epoch 800, loss: 2.550001382827759
Epoch 900, loss: 2.4272985458374023


In [29]:
# Sample from the trained model. The ids get their own name: reassigning `text`
# would overwrite the raw corpus string that earlier cells (e.g. `encode(text)`) read.
sample_ids = m.generate_text(torch.zeros((1,1), dtype=torch.long))
decoded = decode(sample_ids[0])

print(decoded)


Agnds wicemfakicke, tem ay, wordalturemboot h, hinomicee, Twhe II butheed peind otononde onseagand inn: y,? futonon Eroures,
He s me's, t ba athupot tas tOMxere.
IN: foure twhansHe matou t ill thee qund threeroPERIZAndeacavewe!
HUSSTEdilvishese ch t.

ET:
I d S:
LENol; haty tineliou&PS of:
D woth pit IUyow'sugakesol, bait'edllitheasms, mowrith bus theVJun use be lkwrearosatke aime hyowicay os tofot.
Fit ng!
BXIN s:
QUMatitsedind nes,
MERCly, saiknougs ht t m le w thargrs-maishetheit olo a me w y
