In [1]:
# Load the full Austen corpus as a single Unicode string.
with open("austen.txt", mode="r", encoding="utf-8") as f1:
    text = f1.read()

In [2]:
print("length of the dataset in characters:", len(text))  # print's default sep supplies the space

length of the dataset in characters: 5292297


In [3]:
print(text[0:1000])  # peek at the first 1000 characters of the corpus

No one who had ever seen Catherine Morland in her infancy would have supposed her born to be an heroine. Her situation in life, the character of her father and mother, her own person and disposition, were all equally against her. Her father was a clergyman, without being neglected, or poor, and a very respectable man, though his name was Richard--and he had never been handsome. He had a considerable independence besides two good livings--and he was not in the least addicted to locking up his daughters. Her mother was a woman of useful plain sense, with a good temper, and, what is more remarkable, with a good constitution. She had three sons before Catherine was born; and instead of dying in bringing the latter into the world, as anybody might expect, she still lived on--lived to have six children more--to see them growing up around her, and to enjoy excellent health herself. A family of ten children will be always called a fine family, where there are heads and arms and legs enough for

In [4]:
# Vocabulary = every distinct character in the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
print("unique characters in the text: ", "".join(chars))
print("length of the vocabulary: ", vocab_size)

unique characters in the text:  
 !"&'()*,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]^_abcdefghijklmnopqrstuvwxyz{}£½àáæèéê“”
length of the vocabulary:  94


In [5]:
# Character-level tokenizer: each character maps to its index in the sorted
# `chars` list (stoi), and each index maps back to its character (itos).
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encode the string `s` as a list of integer token ids.

    Raises KeyError if `s` contains a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(i):
    """Decode an iterable `i` of integer token ids back into a string."""
    return "".join(itos[j] for j in i)


print(encode("hii there"))
print(decode(encode("hii there")))

[63, 64, 64, 1, 75, 63, 60, 73, 60]
hii there


In [6]:
# Turn the whole corpus into one 1-D tensor of int64 token ids.
import torch
data = torch.tensor(encode(text), dtype=torch.int64)
print(data.shape, data.dtype)
print(data[0:1000])

torch.Size([5292297]) torch.int64
tensor([39, 70,  1, 70, 69, 60,  1, 78, 63, 70,  1, 63, 56, 59,  1, 60, 77, 60,
        73,  1, 74, 60, 60, 69,  1, 28, 56, 75, 63, 60, 73, 64, 69, 60,  1, 38,
        70, 73, 67, 56, 69, 59,  1, 64, 69,  1, 63, 60, 73,  1, 64, 69, 61, 56,
        69, 58, 80,  1, 78, 70, 76, 67, 59,  1, 63, 56, 77, 60,  1, 74, 76, 71,
        71, 70, 74, 60, 59,  1, 63, 60, 73,  1, 57, 70, 73, 69,  1, 75, 70,  1,
        57, 60,  1, 56, 69,  1, 63, 60, 73, 70, 64, 69, 60, 11,  1, 33, 60, 73,
         1, 74, 64, 75, 76, 56, 75, 64, 70, 69,  1, 64, 69,  1, 67, 64, 61, 60,
         9,  1, 75, 63, 60,  1, 58, 63, 56, 73, 56, 58, 75, 60, 73,  1, 70, 61,
         1, 63, 60, 73,  1, 61, 56, 75, 63, 60, 73,  1, 56, 69, 59,  1, 68, 70,
        75, 63, 60, 73,  9,  1, 63, 60, 73,  1, 70, 78, 69,  1, 71, 60, 73, 74,
        70, 69,  1, 56, 69, 59,  1, 59, 64, 74, 71, 70, 74, 64, 75, 64, 70, 69,
         9,  1, 78, 60, 73, 60,  1, 56, 67, 67,  1, 60, 72, 76, 56, 67, 67, 80,
      

In [7]:
# First 90% of the token stream is for training; the last 10% is held out for validation.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

In [8]:
block_size = 8  # context length: the most tokens fed to the model at once

# One chunk of block_size + 1 tokens holds block_size (context, target) examples.
train_data[0:block_size + 1]

tensor([39, 70,  1, 70, 69, 60,  1, 78, 63])

In [9]:
# Each prefix of x is a context whose target is the token at the same position in y.
x, y = train_data[:block_size], train_data[1:block_size + 1]
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f"when input is {context} the target: {target}")


when input is tensor([39]) the target: 70
when input is tensor([39, 70]) the target: 1
when input is tensor([39, 70,  1]) the target: 70
when input is tensor([39, 70,  1, 70]) the target: 69
when input is tensor([39, 70,  1, 70, 69]) the target: 60
when input is tensor([39, 70,  1, 70, 69, 60]) the target: 1
when input is tensor([39, 70,  1, 70, 69, 60,  1]) the target: 78
when input is tensor([39, 70,  1, 70, 69, 60,  1, 78]) the target: 63


In [10]:
batch_size = 4  # number of independent sequences processed in parallel
block_size = 8  # maximum context length used for predictions
torch.manual_seed(1337)  # fixed seed so the sampled batches are reproducible

def get_batch(split):
    """Sample a random mini-batch of (input, target) token sequences.

    Reads the module-level ``train_data``/``val_data`` tensors and the
    module-level ``batch_size``/``block_size`` settings at call time.

    Args:
        split: 'train' to sample from ``train_data``, 'val' to sample from ``val_data``.

    Returns:
        (x, y): two stacked tensors of shape (batch_size, block_size). ``y`` is
        ``x`` shifted one position forward in the source data, so y[b, t] is
        the token that follows x[b, t].

    Raises:
        ValueError: if ``split`` is unknown, or the chosen split is too short
            to hold one window of ``block_size + 1`` tokens.
    """
    # Be explicit: a typo such as 'trian' must not silently yield validation data.
    if split == 'train':
        source = train_data
    elif split == 'val':
        source = val_data
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    if len(source) <= block_size:
        raise ValueError(
            f"{split} split has {len(source)} tokens; "
            f"need more than block_size={block_size}"
        )

    # Start offsets lie in [0, len - block_size - 1], so the shifted target
    # window source[i+1 : i+block_size+1] always stays in bounds.
    ix = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[i:i + block_size] for i in ix])
    y = torch.stack([source[i + 1:i + block_size + 1] for i in ix])
    return x, y

xb, yb = get_batch('train')
# Show each tensor's label, shape and contents on separate lines.
print('inputs:', xb.shape, xb, sep='\n')
print('targets:', yb.shape, yb, sep='\n')

print("-------")

# Every (row, position) pair of the batch is one independent training example.
for b in range(batch_size):
    for t in range(block_size):
        context, target = xb[b, :t + 1], yb[b, t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[ 1, 57, 60, 64, 69, 62,  1, 75],
        [67, 67,  1, 37, 56, 59, 80,  1],
        [28, 56, 75, 63, 60, 73, 64, 69],
        [56,  1, 58, 67, 60, 73, 66,  1]])
targets:
torch.Size([4, 8])
tensor([[57, 60, 64, 69, 62,  1, 75, 63],
        [67,  1, 37, 56, 59, 80,  1, 43],
        [56, 75, 63, 60, 73, 64, 69, 60],
        [ 1, 58, 67, 60, 73, 66,  1, 64]])
-------
when input is [1] the target: 57
when input is [1, 57] the target: 60
when input is [1, 57, 60] the target: 64
when input is [1, 57, 60, 64] the target: 69
when input is [1, 57, 60, 64, 69] the target: 62
when input is [1, 57, 60, 64, 69, 62] the target: 1
when input is [1, 57, 60, 64, 69, 62, 1] the target: 75
when input is [1, 57, 60, 64, 69, 62, 1, 75] the target: 63
when input is [67] the target: 67
when input is [67, 67] the target: 1
when input is [67, 67, 1] the target: 37
when input is [67, 67, 1, 37] the target: 56
when input is [67, 67, 1, 37, 56] the target: 59
when input is [67, 6

In [11]:
import torch
from torch import nn
import torch.nn.functional as F

# Reseed the global RNG right before the model below is constructed.
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    A single ``vocab_size x vocab_size`` embedding table is used directly as the
    logit lookup: row ``i`` holds the unnormalized scores for the token that
    follows token ``i``.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of integer token ids.
            targets: optional (B, T) tensor of integer next-token ids.

        Returns:
            (logits, loss). Without targets, logits is (B, T, C) and loss is
            None. With targets, logits is flattened to (B*T, C), the layout
            F.cross_entropy expects, and loss is a scalar tensor.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C) Batch, Time, Channel (vocab size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # reshape (not view) so non-contiguous inputs such as transposed
            # or sliced target tensors are also accepted
            logits = logits.reshape(B * T, C)
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after the context ``idx``.

        Args:
            idx: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to sample and append.

        Returns:
            (B, T + max_new_tokens) tensor: the original context followed by
            the sampled tokens.
        """
        # Sampling never backpropagates, so skip building autograd graphs.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = self(idx)[0]  # (B, T, C); no targets, so loss is None
                logits = logits[:, -1, :]  # only the last time step predicts the next token: (B, C)
                probs = F.softmax(logits, dim=-1)  # (B, C)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)

        return idx

    
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape, loss, sep="\n")

# Token id 0 is "\n" (the first entry of the sorted vocabulary), so a single
# 0 works as the starting context for generation.
print(decode(m.generate(max_new_tokens=100, idx=torch.zeros((1, 1), dtype=torch.long))[0].tolist()))

torch.Size([32, 94])
tensor(4.8272, grad_fn=<NllLossBackward>)

trNLphJ“s'æXxL/H£}Kry3Fkrg8!sm_3_FuL5H5EfPvSWbuziv8Rrè[1jfær))
fQbMDPSWdpYé25TW{/HSejá
qm50RA^U1a}Kc


In [12]:
# AdamW over every parameter of the bigram model.
optimizer = torch.optim.AdamW(params=m.parameters(), lr=0.001)

In [13]:
# Larger batches for training; this rebinds the global that get_batch reads.
batch_size = 32
for steps in range(10_000):
    optimizer.zero_grad()  # drop gradients left over from the previous step
    xb, yb = get_batch('train')  # fresh random mini-batch
    logits, loss = m(xb, yb)
    loss.backward()
    optimizer.step()

# Loss of the final mini-batch only (a noisy estimate).
print(loss.item())

2.4198219776153564


In [14]:
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


SàLifrth pe rat bjeatthes, rlily sisuthithincur s atod Hveshag 
Dutonde S, oton ag matipanor mare n;
