NanoGPT: a transformer-based character-level language model

In [1]:
with open('./input.txt','r',encoding='utf-8') as f:
    text = f.read()

In [2]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [3]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


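Next, develop a strategy to tokenize the input text, which means **converting the raw text string into a sequence of integers according to some vocabulary of possible elements**. Here the vocabulary is the 65 characters printed above, so each character maps to one integer.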

# mapping from characters to integers
stoi = {s:i for i,s in enumerate(chars)}
itos = {i:s for i,s in enumerate(chars)}
encode = lambda l: [stoi[c] for c in l] # takes a string, returns a list of integers
decode = lambda d: ''.join([itos[c] for c in d]) # takes an array as input, returns a string
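
As a quick sanity check of the character-level mapping above, the sketch below builds the same kind of tables on a tiny stand-in string and confirms that decoding an encoded string gives the original back. The `toy_` names and the sample string are assumptions made for illustration; they are not part of the notebook.

```python
# Toy corpus standing in for input.txt (an assumption for this sketch)
toy_text = "hii there"

# Vocabulary: the sorted set of distinct characters in the corpus
toy_chars = sorted(set(toy_text))
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}  # character -> integer
toy_itos = {i: ch for ch, i in toy_stoi.items()}      # integer -> character

toy_encode = lambda s: [toy_stoi[c] for c in s]
toy_decode = lambda ids: ''.join(toy_itos[i] for i in ids)

ids = toy_encode("hii there")
roundtrip = toy_decode(ids)
print(ids)
print(roundtrip)
```

Encoding only works for characters that appear in the vocabulary; any other character raises a `KeyError` with this dictionary-based scheme.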

In [7]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


Splitting the dataset into training and validation

In [6]:
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]
print(train_data.shape, val_data.shape)

torch.Size([1003854]) torch.Size([111540])


You do not feed the entire text into the transformer at once; that would be **not efficient**.

Instead, you sample small random chunks from the training set, and each chunk has a **maximum length** (the block size).

In [11]:
# each chunk packs many training examples for the transformer
# the examples have contexts of every length from 1 up to block_size,
# so the transformer gets used to seeing anything in between
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

Training on contexts of every size from 1 up to the block size also helps at inference: the model gets used to seeing anything in between, so while sampling it can start predicting from a context as short as one token. Once the context grows past the block size, it has to be truncated, because the transformer never receives more than block size tokens as input.

In [12]:
# inputs for the transformer
x = train_data[:block_size]
# targets for each position in the input, offset by one
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'when input is {context} the target is {target}')

when input is tensor([18]) the target is 47
when input is tensor([18, 47]) the target is 56
when input is tensor([18, 47, 56]) the target is 57
when input is tensor([18, 47, 56, 57]) the target is 58
when input is tensor([18, 47, 56, 57, 58]) the target is 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target is 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is 58
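
The truncation mentioned above can be sketched as a slice that keeps only the last `block_size` tokens of a running sequence. This is a minimal illustration, assuming a hypothetical sequence `seq` that has already grown beyond the block size; the bigram model later in this notebook does not need it, since it only looks at the last token.

```python
import torch

block_size = 8  # same value as in the cell above

# Hypothetical running sequence of 12 token ids, shape (B, T) = (1, 12)
seq = torch.arange(12).unsqueeze(0)

# Keep only the most recent block_size tokens as the model's context
seq_cond = seq[:, -block_size:]
print(seq_cond.shape)
print(seq_cond)
```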


Second dimension: **BATCH DIMENSION**

As you sample chunks of text, each time you feed them into the transformer you stack multiple chunks into a single tensor. This is done for efficiency; each chunk is processed independently of the others.

In [21]:
torch.manual_seed(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(0, len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y 

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')
for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f'when input is {context.tolist()} the target is {target}')

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----
when input is [24] the target is 43
when input is [24, 43] the target is 58
when input is [24, 43, 58] the target is 5
when input is [24, 43, 58, 5] the target is 57
when input is [24, 43, 58, 5, 57] the target is 1
when input is [24, 43, 58, 5, 57, 1] the target is 46
when input is [24, 43, 58, 5, 57, 1, 46] the target is 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target is 39
when input is [44] the target is 53
when input is [44, 53] the target is 56
when input is [44, 53, 56] the target is 1
when input is [44, 53, 56, 1] the target is 58
when input is [44, 53, 56, 1, 58] the target i

In [23]:
print(xb) # input to the transformer

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
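
One property of the batches worth checking is that each target row is its input row shifted by one position. The sketch below repeats the sampling logic of `get_batch` on synthetic data, where every value equals its own position, so the shift is easy to see. The `toy_` names and the synthetic tensor are assumptions for illustration only.

```python
import torch

torch.manual_seed(0)
toy_data = torch.arange(100)  # stand-in for train_data; value == position
toy_block, toy_batch = 8, 4

# Same sampling logic as get_batch: random start offsets, then stacked slices
ix = torch.randint(0, len(toy_data) - toy_block, (toy_batch,))
bx = torch.stack([toy_data[i:i + toy_block] for i in ix])
by = torch.stack([toy_data[i + 1:i + toy_block + 1] for i in ix])

# Targets overlap the inputs, offset by one time step
shift_ok = torch.equal(bx[:, 1:], by[:, :-1])
print(bx.shape, by.shape)
print(shift_ok)
```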


Now feed the batches into a neural network.

The most basic choice: a bigram language model

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size,vocab_size)

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T,C)
            targets = targets.view(-1)
            # cross_entropy expects (N,C) logits and (N,) targets, hence the reshapes above
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:,-1,:] # becomes (B,C)
            # apply softmax
            probs = F.softmax(logits,dim=-1) # (B,C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            # append sampled index to running sequence
            idx = torch.cat((idx,idx_next), dim=1) # (B,T+1)
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb,yb)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)
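
A useful reference point for the loss above: if the model spread probability uniformly over all 65 characters, the cross-entropy would be -ln(1/65), roughly 4.17. The reported 4.8786 is higher, which suggests the randomly initialized table is not perfectly uniform. The sketch below computes that baseline two ways; the all-zero logits and random targets are assumptions chosen to represent a uniform prediction.

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 65  # number of distinct characters found earlier

# Closed form: negative log-probability of one class under a uniform distribution
expected = -math.log(1.0 / vocab_size)

# Same value via cross_entropy: all-zero logits give a uniform softmax
uniform_logits = torch.zeros(32, vocab_size)
toy_targets = torch.randint(0, vocab_size, (32,))
uniform_loss = F.cross_entropy(uniform_logits, toy_targets).item()

print(expected, uniform_loss)
```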


Training the model

In [35]:
# create a pytorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [39]:
batch_size = 32
for steps in range(10000):
    # sample batch from the data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.581860065460205
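
To see what the sampling loop in `generate` does step by step, here is a minimal self-contained sketch of the same loop. It uses an untrained `nn.Embedding` as a stand-in for the model's lookup table, so the sampled ids are random; `toy_vocab`, `table`, and `seq` are hypothetical names for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
toy_vocab = 65
table = nn.Embedding(toy_vocab, toy_vocab)  # untrained stand-in lookup table

seq = torch.zeros((1, 1), dtype=torch.long)  # (B,T) context starting at token 0
with torch.no_grad():
    for _ in range(20):
        step_logits = table(seq)[:, -1, :]                   # last time step, (B,C)
        probs = F.softmax(step_logits, dim=-1)               # (B,C) probabilities
        nxt = torch.multinomial(probs, num_samples=1)        # sample one id, (B,1)
        seq = torch.cat((seq, nxt), dim=1)                   # grow to (B,T+1)

print(seq.shape)
print(seq[0].tolist())
```

With the trained model from the notebook, the equivalent call would be `decode(m.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist())`, which turns the sampled ids back into text.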


### Mathematical trick in self-attention

Building the first self-attention block to process the tokens

In [40]:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])