In [4]:
# Read the whole training corpus into memory as a single string.
with open('input.txt', encoding='utf8') as f:
    text = f.read()

# Quick look: total length in characters, then the first 300 characters.
print(f'{len(text)=}', text[:300], sep='\n')

len(text)=1115393
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us


In [5]:
# The vocabulary is every distinct character of the corpus. Sorting makes the
# order (and therefore the character <-> id mapping built from it) deterministic.
characters = sorted(set(text))
V = len(characters)  # vocabulary size
print(''.join(characters), f'{V=}', sep='\n')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
V=65


In [19]:
# Lookup tables between characters and integer token ids. A character's id is
# its position in the sorted `characters` list.
stoi = {ch: i for i, ch in enumerate(characters)}
itos = dict(enumerate(characters))


def encode(s: str) -> list[int]:
    """Convert a string into the list of token ids of its characters.

    Raises:
        KeyError: if `s` contains a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l) -> str:
    """Convert an iterable of token ids back into the string they spell.

    Raises:
        KeyError: if an id is not a key of `itos`.
    """
    return ''.join(itos[i] for i in l)


# Round trip: decoding an encoded string should reproduce the string.
r = encode('Brad')
decode(r)

'Brad'

In [22]:
import torch

# Encode the whole corpus once and store it as a 1-D tensor of token ids
# (torch.long, i.e. int64).
data = torch.tensor(encode(text), dtype=torch.long)
print(f'{data.dtype=}, {data.shape=}')
# Last expression of the cell: display the first ten token ids.
data[:10]

data.dtype=torch.int64, data.shape=torch.Size([1115393])


tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])

In [25]:
# Positional split with no shuffling: the first 90% of the token stream is the
# training set, the remaining tail is held out for validation.
n_split = int(.9 * len(data))
train, val = data[:n_split], data[n_split:]

In [28]:
# Number of tokens in one input window.
context_length = 8
# Display one window plus one extra token: context_length + 1 consecutive ids.
train[:context_length+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [30]:
# Inputs are the first `context_length` tokens; targets are the same window
# shifted one position to the right.
x, y = train[:context_length], train[1:context_length+1]

# Every prefix of x is one training example; its label is the token that
# immediately follows that prefix.
for t in range(context_length):
    context, target = x[:t+1], y[t]
    print(f'when the input is {context}, the target is {target}')

print(f'{x=}, {y=}')

when the input is tensor([18]), the target is 47
when the input is tensor([18, 47]), the target is 56
when the input is tensor([18, 47, 56]), the target is 57
when the input is tensor([18, 47, 56, 57]), the target is 58
when the input is tensor([18, 47, 56, 57, 58]), the target is 1
when the input is tensor([18, 47, 56, 57, 58,  1]), the target is 15
when the input is tensor([18, 47, 56, 57, 58,  1, 15]), the target is 47
when the input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is 58
x=tensor([18, 47, 56, 57, 58,  1, 15, 47]), y=tensor([47, 56, 57, 58,  1, 15, 47, 58])


In [35]:
# Seed torch's global RNG before any random batch indices are drawn.
torch.manual_seed(1337)
# Number of sequences sampled per batch (read as a global by get_batch).
batch_size = 4
# Tokens per sequence; same value as assigned in the earlier cell.
context_length = 8

def get_batch(split:str = 'train') -> tuple[torch.Tensor, torch.Tensor]:
    """Sample a random batch of input/target windows from one data split.

    Reads the module-level `train`, `val`, `batch_size` and `context_length`.

    Args:
        split: 'train' selects the training tensor; any other value selects
            the validation tensor.

    Returns:
        A pair (x, y), each of shape (batch_size, context_length). Row k of y
        is row k of x shifted one token to the right in the source tensor.
    """
    source = val if split != 'train' else train
    # The exclusive upper bound leaves room for the shifted target window.
    starts = torch.randint(len(source) - context_length, (batch_size,))
    inputs = [source[s:s + context_length] for s in starts]
    labels = [source[s + 1:s + context_length + 1] for s in starts]
    # Stack inputs and labels separately so both results are contiguous.
    return torch.stack(inputs), torch.stack(labels)

# Draw one training batch and print inputs and targets one after the other.
X_batch, Y_batch = get_batch()
print(
    f'{X_batch.shape=}\n{X_batch[:10]}',
    f'{Y_batch.shape=}\n{Y_batch[:10]}',
    sep='\n',
)

# Each row yields context_length examples: every prefix of the input row is
# paired with the token that follows it.
for b in range(batch_size):
    for t in range(context_length):
        context, target = X_batch[b, :t+1], Y_batch[b, t]
        print(f'When the input is {context} the target is {target}: {b=}')

X_batch.shape=torch.Size([4, 8])
tensor([[53, 59,  6,  1, 58, 56, 47, 40],
        [49, 43, 43, 54,  1, 47, 58,  1],
        [13, 52, 45, 43, 50, 53,  8,  0],
        [ 1, 39,  1, 46, 53, 59, 57, 43]])
Y_batch.shape=torch.Size([4, 8])
tensor([[59,  6,  1, 58, 56, 47, 40, 59],
        [43, 43, 54,  1, 47, 58,  1, 58],
        [52, 45, 43, 50, 53,  8,  0, 26],
        [39,  1, 46, 53, 59, 57, 43,  0]])
When the input is tensor([53]) the target is 59: b=0
When the input is tensor([53, 59]) the target is 6: b=0
When the input is tensor([53, 59,  6]) the target is 1: b=0
When the input is tensor([53, 59,  6,  1]) the target is 58: b=0
When the input is tensor([53, 59,  6,  1, 58]) the target is 56: b=0
When the input is tensor([53, 59,  6,  1, 58, 56]) the target is 47: b=0
When the input is tensor([53, 59,  6,  1, 58, 56, 47]) the target is 40: b=0
When the input is tensor([53, 59,  6,  1, 58, 56, 47, 40]) the target is 59: b=0
When the input is tensor([49]) the target is 43: b=1
When the 

In [87]:
import torch.nn as nn
import torch.nn.functional as F
# Re-seed torch's global RNG before the model is constructed later in this cell.
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
        logits = self.token_embedding_table(idx)
        
        if targets is None:
            loss = None
        else:
            batches, time_steps, channels = logits.shape
            logits = logits.view(batches * time_steps, channels)
            targets = targets.view(batches* time_steps)
            loss = F.cross_entropy(logits,targets)
        return logits, loss

    
    def generate(self, indicies: torch.Tensor, max_new_tokens: int):
        for i in range(max_new_tokens):
            logits, loss = self(indicies)
            # extract last time step
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=1)
            next_index = torch.multinomial(probs,num_samples=1)
            indicies = torch.cat((indicies, next_index), dim=1)
            
        return indicies
    
# Build the model over the V-symbol vocabulary and run one forward pass on the
# current batch to check the output shape and the loss before any training.
m = BigramLanguageModel(V)
logits, loss = m(X_batch, Y_batch)
# Cell output: shape of the flattened logits and the scalar loss.
logits.shape, loss

(torch.Size([32, 65]), tensor(4.8948, grad_fn=<NllLossBackward0>))

In [98]:
# Start from a single sequence holding one token with id 0 (a 1x1 batch),
# sample 400 more tokens, and decode the only sequence back to text.
# NOTE(review): this cell's execution count (98) is higher than the training
# cell's (96) further down, so the saved output appears to come from a model
# that had already been trained; run top to bottom, this cell samples from the
# untrained model instead - confirm the intended cell order.
decode(m.generate(torch.zeros((1,1), dtype=torch.long), max_new_tokens=400)[0].tolist())

"\n\n\nBALLOUCoord'rer H:\nICAd t gs trtin s: Twathy sthimbe wine k u h f h ff s t, s ayo be.\nANos tcroll tovaspthis ar w mis y his Himayo Gotesith sownorthe:\nFRDin fore po whand.\nCin metis o, thend s t s prthinthyofan:\nWef, dreadek w.\nAnoyer'ditoby; thy geak awit t. My brito orthy httagerp ansensthart cery by,\nLENThas whelllkes the wat w'sthe thie douswe t ser e ba ort ppe Har unden\nI at, swhos.\nGLABel"

In [91]:
# AdamW optimizer over all parameters of the model `m`, learning rate 1e-3.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [96]:
# get_batch reads the module-level batch_size, so this assignment changes the
# size of every batch drawn from here on.
batch_size = 32
for steps in range(10000):
    # Clear gradients left over from the previous step before computing new ones.
    optimizer.zero_grad(set_to_none=True)

    X_batch, Y_batch = get_batch('train')
    logits, loss = m(X_batch, Y_batch)
    loss.backward()
    optimizer.step()

    # Report the loss of the current mini-batch every 1000 steps.
    if not steps % 1000:
        print(f'{steps=} {loss=}')

steps=0 loss=tensor(2.4163, grad_fn=<NllLossBackward0>)
steps=1000 loss=tensor(2.4161, grad_fn=<NllLossBackward0>)
steps=2000 loss=tensor(2.5657, grad_fn=<NllLossBackward0>)
steps=3000 loss=tensor(2.4316, grad_fn=<NllLossBackward0>)
steps=4000 loss=tensor(2.5484, grad_fn=<NllLossBackward0>)
steps=5000 loss=tensor(2.4627, grad_fn=<NllLossBackward0>)
steps=6000 loss=tensor(2.4249, grad_fn=<NllLossBackward0>)
steps=7000 loss=tensor(2.3524, grad_fn=<NllLossBackward0>)
steps=8000 loss=tensor(2.4860, grad_fn=<NllLossBackward0>)
steps=9000 loss=tensor(2.4564, grad_fn=<NllLossBackward0>)
