In [1]:
import torch

# Data preparation

## Load raw text

In [2]:
# Read the whole corpus into memory. The encoding is pinned so that the
# character vocabulary built from `text` does not depend on the platform's
# default locale encoding.
with open('../data/shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


## Tokenization

In [4]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
tokens = sorted(set(text))
''.join(tokens)

"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [5]:
# Two-way lookup tables between characters and their integer ids
# (the id of a character is its position in `tokens`).
itos = dict(enumerate(tokens))
stoi = {ch: idx for idx, ch in itos.items()}

In [6]:
def encode(text):
    """Convert a string into a 1-D ``torch.long`` tensor of token ids.

    Each character is looked up in ``stoi``; a character that is not in the
    vocabulary raises ``KeyError``.
    """
    ids = list(map(stoi.__getitem__, text))
    return torch.tensor(ids, dtype=torch.long)

def decode(tensor):
    """Convert an iterable of scalar token-id tensors back into a string.

    Iterating a 1-D tensor yields such scalars; each one is mapped to its
    character through ``itos`` and the characters are concatenated.
    """
    chars = []
    for token in tensor:
        chars.append(itos[token.item()])
    return ''.join(chars)

In [7]:
encode('testi')

tensor([58, 43, 57, 58, 47])

In [8]:
decode(encode('testi'))

'testi'

In [9]:
data = encode(text)
data[:10]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])

In [10]:
decode(data[:10])

'First Citi'

In [11]:
# The first 80% of the token stream is used for training, the rest for validation.
split = int(0.8 * len(data))
train, val = data[:split], data[split:]

print(train.shape, val.shape)

torch.Size([892315]) torch.Size([223079])


# Dataloader

Getting a single chunk of data:

In [12]:
block_size = 8  # number of consecutive tokens in one training window
batch_size = 4  # number of windows stacked into one batch

In [13]:
offset = 10  # arbitrary offset for demonstration

# y is x shifted by one position: y[i] is the token that follows x[i].
window_end = offset + block_size
x = train[offset:window_end]
y = train[offset + 1:window_end + 1]

print(x, y, sep='\n')

tensor([64, 43, 52, 10,  0, 14, 43, 44])
tensor([43, 52, 10,  0, 14, 43, 44, 53])


We generate random offsets into the training data:

In [14]:
# One random start position per batch row. The upper bound (exclusive) leaves
# room for a full window plus the one-token shift of the targets.
offsets = torch.randint(low=0, high=split - block_size, size=(batch_size,))
offsets

tensor([165150,  69139, 497238, 653945])

And then generate a block-size x and a shifted-by-1 block-size y for each offset, stacking those tensors into a single x tensor and a single y tensor:

In [15]:
# The offsets were drawn for the training split, so slice `train` (as the
# single-chunk example does) rather than the full `data` tensor.
x_rows = [train[start : start + block_size] for start in offsets]
y_rows = [train[start + 1 : start + block_size + 1] for start in offsets]

print(torch.stack(x_rows))
print(torch.stack(y_rows))

tensor([[44,  8,  0,  0, 24, 13, 16, 37],
        [ 1, 58, 46, 43, 51,  2,  0,  0],
        [50,  0, 14, 43,  1, 57, 46, 56],
        [ 1, 41, 53, 52, 57, 41, 47, 43]])
tensor([[ 8,  0,  0, 24, 13, 16, 37,  1],
        [58, 46, 43, 51,  2,  0,  0, 15],
        [ 0, 14, 43,  1, 57, 46, 56, 47],
        [41, 53, 52, 57, 41, 47, 43, 52]])


In [16]:
def get_batch(data, block_size=block_size, batch_size=batch_size):
    """Sample a random batch of (input, target) windows from ``data``.

    Args:
        data: sequence of token ids to draw windows from (the notebook passes
            the 1-D ``train`` tensor; ``val`` works the same way).
        block_size: number of tokens in each window.
        batch_size: number of windows in the batch.

    Returns:
        ``(xb, yb)``: two stacked tensors with one row per sampled offset.
        Row ``i`` of ``yb`` is row ``i`` of ``xb`` shifted one token to the right.

    Raises:
        ValueError: if ``data`` has fewer than ``block_size + 1`` tokens, i.e.
            not even one window with its shifted target fits.
    """
    # Bound the offsets by the tensor that was actually passed in. Using the
    # global training-split length here would produce out-of-range offsets for
    # any shorter tensor such as the validation split.
    n_offsets = len(data) - block_size
    if n_offsets <= 0:
        raise ValueError(
            f'data has {len(data)} tokens, need at least {block_size + 1}'
        )

    offsets = torch.randint(0, n_offsets, (batch_size,))

    xb = torch.stack([data[offset : offset+block_size] for offset in offsets])
    yb = torch.stack([data[offset+1 : offset+block_size+1] for offset in offsets])

    return xb, yb

In [17]:
get_batch(train)

(tensor([[13, 10,  0, 21,  1, 39, 51,  1],
         [53, 60, 43,  6,  1, 58, 46, 47],
         [57, 11,  1, 44, 53, 56,  1, 47],
         [47, 51,  6,  1, 21,  1, 44, 43]]),
 tensor([[10,  0, 21,  1, 39, 51,  1, 39],
         [60, 43,  6,  1, 58, 46, 47, 57],
         [11,  1, 44, 53, 56,  1, 47, 44],
         [51,  6,  1, 21,  1, 44, 43, 39]]))

# Model

In [18]:
class BigramModel(torch.nn.Module):
    """Bigram language model backed by a single embedding table.

    The embedding has shape ``(vocab_size, vocab_size)``: looking up token
    ``i`` returns the logits over the token that follows ``i``.
    """

    def __init__(self, vocab_size):
        """Create the ``vocab_size x vocab_size`` logit table."""
        super().__init__()

        self.embedding = torch.nn.Embedding(vocab_size, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: integer tensor of token ids, e.g. shape ``(B, T)``.
            targets: optional integer tensor with the same shape as ``x``
                holding the expected next token for every position.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape ``x.shape + (vocab_size,)``
            and ``loss`` is ``None`` when no targets are given.
        """
        logits = self.embedding(x)

        if targets is None:
            return logits, None

        # Flatten all positions into one axis for cross_entropy. reshape (rather
        # than view) also accepts non-contiguous inputs such as slices.
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, self.vocab_size), targets.reshape(-1)
        )

        return logits, loss

    @torch.no_grad()
    def generate_text(self, x, steps=500):
        """Extend ``x`` by sampling ``steps`` new tokens from the model.

        Args:
            x: integer tensor of shape ``(B, T)`` used as the starting context.
            steps: number of tokens to append.

        Returns:
            Integer tensor of shape ``(B, T + steps)``: the context followed by
            the sampled continuation.
        """
        for _ in range(steps):
            logits, _ = self(x)
            last_logits = logits[:, -1, :]  # only the final position predicts the next token
            probs = torch.nn.functional.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

            x = torch.cat([x, next_token], dim=1)

        return x

    @torch.no_grad()
    def generate(self, x, steps=100):
        """Greedy decoding: lazily yield the most likely next token, step by step.

        Args:
            x: integer tensor of shape ``(B, T)`` used as the starting context.
            steps: number of tokens to generate.

        Yields:
            Integer tensor of shape ``(B,)`` with the arg-max next token of
            every sequence in the batch, one tensor per step.
        """
        for _ in range(steps):
            logits, _ = self(x)
            next_token = logits[:, -1, :].argmax(dim=-1)  # (B,)

            # Keep the (B, T) layout so the next step conditions on the token
            # that was just produced.
            x = torch.cat([x, next_token.unsqueeze(1)], dim=1)
            yield next_token

## Forward pass

In [19]:
vocab_size = len(tokens)
vocab_size

65

In [20]:
model = BigramModel(vocab_size)

In [21]:
xb, yb = get_batch(train)
xb.shape # (batch_size, block_size)

torch.Size([4, 8])

In [22]:
logits, loss = model(xb, yb)
logits.shape # (batch_size, block_size, vocab_size)

torch.Size([4, 8, 65])

In [23]:
loss

tensor(4.6028, grad_fn=<NllLossBackward0>)

## Generation

In [24]:
# Seed the untrained model with a single token of id 0 (a 1 x 1, i.e. B x T,
# context) and let it sample a continuation.
x = model.generate_text(torch.zeros((1, 1), dtype=torch.long))
x

tensor([[ 0, 22,  8, 45, 28,  5, 31, 30, 36, 57, 11, 61,  2, 27,  5, 29, 12, 32,
          9, 16, 15, 29, 58,  6,  4, 64, 36, 46, 49, 16, 30, 36, 18,  9, 35, 44,
         27, 23,  0,  8,  5, 48, 43, 54,  0, 20,  8, 51, 15, 39, 22,  4, 53,  4,
         62, 47, 10, 15, 50, 33, 31, 48,  9, 34, 62, 15, 63, 64, 62, 15, 49, 58,
         30, 47, 50,  9, 39, 53, 16, 57, 32, 21, 56,  3, 58, 33, 33, 16, 48, 15,
         52,  1, 24,  6, 36, 35,  3, 53,  0, 46,  7, 39, 44, 61,  6, 40, 41, 34,
         25, 46, 26,  2,  6, 61, 64, 45,  8, 46, 49, 42,  7, 61, 28, 23, 14, 60,
         41, 22, 29, 44,  7, 26, 33, 27, 16, 25, 37, 20, 39, 43, 30, 21, 40, 41,
          3, 42, 63, 18,  4, 46, 53, 29, 63, 26, 41, 59, 58, 64, 64, 45, 44, 27,
         11, 10, 30, 46, 61,  2, 39, 22,  3,  0, 46, 18, 45, 42, 54, 17, 39, 21,
         53, 15, 49, 13, 17, 43, 30, 64, 46, 26, 49, 18, 19, 53, 10, 24, 52, 45,
         57, 47, 55, 24, 49,  2, 28, 60, 29, 51, 14, 43, 30, 25, 37,  0,  7, 19,
         17,  0,  8, 12, 32,

In [25]:
print(decode(x[0]))


J.gP'SRXs;w!O'Q?T3DCQt,&zXhkDRXF3WfOK
.'jep
H.mCaJ&o&xi:ClUSj3VxCyzxCktRil3aoDsTIr$tUUDjCn L,XW$o
h-afw,bcVMhN!,wzg.hkd-wPKBvcJQf-NUODMYHaeRIbc$dyF&hoQyNcutzzgfO;:Rhw!aJ$
hFgdpEaIoCkAEeRzhNkFGo:LngsiqLk!PvQmBeRMY
-GE
.?T!!CmCBN-b
JNKXcDZow!VQgz?:jg-Xlk$t-HDyL&!f. uDIKc,zqsbMD
qACljVzJDjGigAt,Vdbn GurIQqxvmOgu.i,Z,tfXUMYs;KTKXQHAk'HmRSlvds;S-bIOW:CyVidaeRqeRzKsP3aZNlvQADsPWL';tnkDApN;ZeRlnvGmwGow&!kAcxCqNqMhgqCDIERXO;vQU.'--fTdv
qi,-FNc&XqbT3Bph$N
jGuDwhgwphm;wN'BpbcuMVReR,l
nc'&hOV3GOVL.:hEyIbcd


# Training

In [42]:
# Number of optimisation steps; each step uses one freshly sampled batch.
n_epochs = 5000

# A fresh model is trained here, separate from the untrained `model` above.
m = BigramModel(vocab_size)
optim = torch.optim.Adam(params=m.parameters(), lr=1e-2)

In [43]:
# Report the loss ten times over the run. Clamp to at least 1 so that a short
# run (n_epochs < 10) does not end up with `epoch % 0`.
print_every = max(1, int(0.1 * n_epochs))

for epoch in range(n_epochs):
    xb, yb = get_batch(train)

    logits, loss = m(xb, yb)
    if epoch % print_every == 0:
        # Loss of the current mini-batch only, hence the noisy curve.
        print(f'Epoch {epoch}, loss: {loss.item()}')

    # Standard update: clear stale gradients, backpropagate, take a step.
    optim.zero_grad()
    loss.backward()
    optim.step()

Epoch 0, loss: 4.701707363128662
Epoch 500, loss: 3.198807954788208
Epoch 1000, loss: 2.7940256595611572
Epoch 1500, loss: 2.902494430541992
Epoch 2000, loss: 2.6032276153564453
Epoch 2500, loss: 2.472278356552124
Epoch 3000, loss: 2.251333713531494
Epoch 3500, loss: 2.258505344390869
Epoch 4000, loss: 2.4022860527038574
Epoch 4500, loss: 2.4708480834960938


In [44]:
# Sample from the trained model, starting from a single token of id 0.
# The result gets its own name so the corpus string `text` is not overwritten
# with a tensor, which would break re-running the tokenization cells.
generated = m.generate_text(torch.zeros((1, 1), dtype=torch.long))
decoded = decode(generated[0])

print(decoded)


By ieep the cou ad wimubirdit tat Ingo od lan lloumme ITord t:
I f t ivill, ive GELARO, fr ord, INothe pthe myo bag t thesothe cuthar, swoungis utherat, t thincatarerber; mietaive and I he mou
Bered, divenfathas MESe,
Tovelewif dearo t tear or!CIXFit oantor tusisaico arde thy wirethor s thon:

CHEO:
Wis l tikingd talincareangndertothe
ARD:
HE:
I fove fonty are,
RCENGondio,
NGRDWind har mpr
Withe fonthan-
I IZENERESor, f s  bor yofat
Tirdobe ia

Was Yee heyowin tr I'l bls p e ms wholit f te plld,
