In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [1]:
# Read the whole training corpus into memory as one string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [2]:
# Report the corpus size in characters.
print(f'length of dataset: {len(text)}')

length of dataset: 1115394


In [3]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
print(f'{vocab_size=}')

vocab_size=65


In [7]:
# Character <-> integer lookup tables built from the sorted vocabulary.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Encode a string into a list of integer token ids, one per character.

    Raises KeyError if `s` contains a character that is not a key of `stoi`.
    """
    return [stoi[c] for c in s]


def decode(d):
    """Decode an iterable of integer token ids back into a string.

    Raises KeyError if an id is not a key of `itos`.
    """
    return ''.join(itos[i] for i in d)

# Other tokenizers, such as tiktoken or sentencepiece, could be used instead.

In [9]:
# Encode the entire corpus as a single 1-D tensor of 64-bit integer token ids.
data = torch.tensor(encode(text), dtype=torch.int64)
print(data.size(), data.dtype)

torch.Size([1115394]) torch.int64


In [12]:
# Train on the first 90% of the tokens and hold out the remainder for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [13]:
# One chunk of block_size tokens contains block_size training examples:
# every prefix of x is a context whose target is the token that follows it.
block_size = 8  # context length
x = train_data[:block_size]
y = train_data[1:block_size + 1]  # x shifted by one position
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f'input: {context}, target: {target}')

input: tensor([18]), target: 47
input: tensor([18, 47]), target: 56
input: tensor([18, 47, 56]), target: 57
input: tensor([18, 47, 56, 57]), target: 58
input: tensor([18, 47, 56, 57, 58]), target: 1
input: tensor([18, 47, 56, 57, 58,  1]), target: 15
input: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
input: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58


In [22]:
# Batch geometry read by get_batch.
batch_size = 4  # number of inputs processed in parallel
block_size = 8  # maximum context length
# Fix the RNG seed so the sampled batches are repeatable.
torch.manual_seed(1337)

def get_batch(split):
    """Sample a random batch of input/target sequences.

    `split == "train"` draws from `train_data`; any other value draws from
    `val_data`. Returns `(x, y)`: `batch_size` stacked windows of
    `block_size` tokens each, where `y` is `x` shifted one position forward.
    """
    source = train_data if split == "train" else val_data
    # Random window starts, leaving room for block_size inputs plus one target.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    labels = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs, labels

# Draw one batch of inputs (xb) and next-token targets (yb) from the training split.
xb, yb = get_batch(split='train')


In [28]:
# Re-seed the global RNG so the random draws later in this cell are repeatable.
torch.manual_seed(1337)

class BigramLM(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    The only parameter is a (vocab_size, vocab_size) embedding table; row `i`
    holds the unnormalized logits for the token that follows token `i`.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: (B, T) integer tensor of token ids.
            targets: optional (B, T) integer tensor of next-token ids.

        Returns:
            `(logits, loss)`. Without targets, logits is (B, T, C) and loss is
            None. With targets, logits is flattened to (B*T, C) and loss is
            the cross-entropy between those logits and the flattened targets.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C): batch, time, channel

        if targets is None:
            return logits, None

        # F.cross_entropy wants the class dimension second, so fold batch and
        # time into one axis: logits (B*T, C) against targets (B*T,).
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)  # negative log likelihood
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Extend each sequence in `idx` by `max_new_tokens` sampled tokens.

        Args:
            idx: (B, T) integer tensor holding the current context per batch row.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) integer tensor: the input followed by the samples.
        """
        # Sampling never needs gradients; skip building the autograd graph.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits, _ = self(idx)  # (B, T, C) predictions
                # A bigram model only needs the last time step: (B, C).
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)  # logits -> probabilities
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1) sample
                idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1) running sequence
        return idx



print(xb.size())
m = BigramLM(vocab_size)
logits, loss = m(xb, yb)
print(logits.size(), loss)
# The loss comes out near 4.87; a uniform guess over 65 characters would give -ln(1/65) = 4.17438.

# Sample 100 new tokens from the model, starting from a single token with id 0.
start_context = torch.zeros((1, 1), dtype=torch.long)
sampled = m.generate(start_context, max_new_tokens=100)
print(decode(sampled[0].tolist()))


torch.Size([4, 8])
torch.Size([32, 65]) tensor(4.8786, grad_fn=<NllLossBackward0>)

SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ


In [29]:
# AdamW over all of the model's parameters with a learning rate of 1e-3.
optimizer = torch.optim.AdamW(params=m.parameters(), lr=1e-3)

In [39]:
# Train the model. get_batch reads the module-level batch_size, so this
# assignment enlarges every batch sampled from here on.
batch_size = 32
for steps in range(1000):
    # Drop gradients left over from the previous step before the new pass.
    optimizer.zero_grad(set_to_none=True)

    # Forward pass on a freshly sampled training batch.
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)

    # Backpropagate and apply the parameter update.
    loss.backward()
    optimizer.step()

# Loss of the final batch only, not an average over the run.
print(loss.item())

2.5097763538360596


In [40]:
# Sample 100 new tokens from the model, starting from a single token with id 0.
seed_context = torch.zeros((1, 1), dtype=torch.long)
generated = m.generate(idx=seed_context, max_new_tokens=100)
print(decode(generated[0].tolist()))


A'geeliny y,

CUSpthe mpu;?
MAnin,
Fisserin t tho bulla dangO: caren'xalasat is ?we t be DYe; norDO:


In [6]:
# Self-attention: a single causal head.
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

head_size = 16
# key = what a token has, query = what a token is looking for
key = nn.Linear(C, head_size, bias=False) # nn.Linear with no bias = plain matrix multiplication
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B,T,head_size)
q = query(x) # (B,T,head_size)

# Affinity of every query with every key; transpose the last two dimensions of k:
# (B, T, 16) @ (B, 16, T) -> (B, T, T), indexed [batch, query position, key position].
weights = q @ k.transpose(-2, -1)

tril = torch.tril(torch.ones(T,T)) # ones on and below the diagonal, zeros above
# Replace scores above the diagonal (future positions) with -inf so they get
# exactly zero weight after the softmax: a token cannot attend to later tokens.
weights = weights.masked_fill(tril == 0, float('-inf'))
# Normalize over the key axis (last dimension) so each query's weights over the
# current and past tokens sum to 1. Normalizing over dim=1 would instead
# normalize across queries and leave the rows unnormalized.
weights = F.softmax(weights, dim=-1)
# Aggregate the value projections rather than the raw inputs x.
v = value(x) # (B,T,head_size)
out = weights @ v # (B,T,T) @ (B,T,head_size) -> (B,T,head_size)

out.shape

torch.Size([4, 8, 16])

In [None]:
# Scaled attention scores: multiply the raw query-key products by
# head_size**-0.5, i.e. divide by sqrt(head_size).
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = (q @ torch.transpose(k, -2, -1)) * head_size ** -0.5