In [1]:
# PyTorch core, the neural-network layer library, and the functional API (aliased as F).
import torch
import torch.nn as nn
from torch.nn import functional as F
# Seed PyTorch's global random number generator so random draws in later cells are repeatable.
torch.manual_seed(1337)

<torch._C.Generator at 0x1a8b242c4f0>

In [2]:
# Read the whole training corpus into memory as a single string.
with open('input.txt', mode='r', encoding='utf-8') as f:
    text = f.read()

# The vocabulary is every distinct character in the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

print("Vocab size:", vocab_size)
print("Number of characters:", len(text))

Vocab size: 65
Number of characters: 1115394


In [3]:
# Character-level tokenizer: each vocabulary character maps to its index in `chars`.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}


def encode(s):
    """Encoder: take a string, output a list of integers."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return ''.join(itos[i] for i in l)


# Round-trip check: encoding then decoding should give back the original string.
print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [4]:
# Encode the full corpus as one long tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# The first 90% of the tokens are used for training, the remainder for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

torch.manual_seed(1337)
# Small values for the demonstration batch printed below.
batch_size = 4  # how many independent sequences will we process in parallel?
block_size = 8  # what is the maximum context length for predictions?

def get_batch(split):
    """Sample a random batch of input/target windows from one data split.

    `split` selects `train_data` when it equals 'train' and `val_data`
    otherwise. Returns `(x, y)`, each of shape (batch_size, block_size),
    where `y` is `x` shifted one position to the right in the source data.
    """
    source = train_data if split == 'train' else val_data
    # Random window starts; the upper bound leaves room for the shifted target.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

# Draw one training batch and show the inputs next to their shifted targets.
xb, yb = get_batch('train')
for label, batch in (('inputs:', xb), ('targets:', yb)):
    print(label)
    print(batch.shape)
    print(batch)

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [5]:
# Worked example: a single head of causal self-attention on random input.
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, head_size)
q = query(x) # (B, T, head_size)
v = value(x) # (B, T, head_size)

tril = torch.tril(torch.ones(T,T))
# Affinity of position i for position j is query_i . key_j, so queries go on
# the left: wei[b, i, j] = q[b, i] . k[b, j]. The operands were previously
# swapped, which transposed the score matrix before the causal mask was applied.
wei = q @ k.transpose(-2, -1) # (B, T, T), unscaled
# Causal mask: position i may only attend to positions j <= i.
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1) # each row sums to 1 over the allowed positions
out = wei @ v # (B, T, head_size)

In [6]:
# self attention head computation:
n_embd = 32

class Head(nn.Module):
    """One head of causal (masked) self-attention.

    The input width is read from the module-level `n_embd` and the maximum
    sequence length from the module-level `block_size` at construction time,
    so both must be set before a `Head` is instantiated.

    Args:
        head_size: width of the key/query/value projections, which is also
            the width of the output.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular causal mask. Registered as a buffer so it follows
        # the module across devices without being a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (B, T, n_embd) with T <= block_size.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        _, T, _ = x.shape

        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Scaled dot-product attention divides by sqrt of the key/query width
        # (head_size), not the input embedding width. The two coincide only
        # when head_size == n_embd.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        # Hide future positions: position i may only attend to j <= i.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        out = wei @ v  # (B, T, head_size)
        return out

In [12]:
def get_batch(split):
    """Sample a random batch of input/target windows and move it to `device`.

    `split` selects `train_data` when it equals 'train' and `val_data`
    otherwise. Returns `(x, y)`, each of shape (batch_size, block_size),
    where `y` is `x` shifted one position to the right in the source data.
    """
    source = train_data if split == 'train' else val_data
    # Random window starts; the upper bound leaves room for the shifted target.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def estimate_loss():
    """Average the model's loss over `eval_iters` random batches per split.

    Puts the global `model` in eval mode while measuring and switches it back
    to train mode before returning. Returns a dict with keys 'train' and
    'val', each holding the mean loss as a 0-dim tensor.
    """
    averages = {}
    model.eval()
    for split in ('train', 'val'):
        # The model is expected to return (logits, loss); only the loss is kept.
        per_batch = [model(*get_batch(split))[1].item() for _ in range(eval_iters)]
        averages[split] = torch.tensor(per_batch).mean()
    model.train()
    return averages

class BigramLanguageModel(nn.Module):
    """Character-level language model with one self-attention head.

    Token and position embeddings are summed, passed through a single
    `Head`, and projected to vocabulary logits. Sizes come from the
    module-level `vocab_size`, `n_embd` and `block_size`, read at
    construction time.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positional_embedding = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: (B, T) tensor of token ids, with T <= block_size.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            `(logits, loss)`. Without targets, logits has shape
            (B, T, vocab_size) and loss is None. With targets, logits is
            flattened to (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        # Build position indices on the same device as the input so the model
        # does not depend on a global `device` matching where it actually lives.
        pos_emb = self.positional_embedding(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = tok_emb + pos_emb  # (B, T, n_embd), pos_emb broadcasts over the batch
        x = self.sa_head(x)  # (B, T, n_embd)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        # Identity check: `== None` on a tensor goes through Tensor.__eq__.
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # cross_entropy expects (N, C) logits and (N,) class indices.
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling never backpropagates, so skip building the autograd graph
    def generate(self, idx, max_new_token):
        """Autoregressively sample `max_new_token` tokens after `idx`.

        Args:
            idx: (B, T) tensor of token ids used as the starting context.
            max_new_token: number of tokens to append.

        Returns:
            (B, T + max_new_token) tensor holding the context followed by
            the sampled tokens.
        """
        for _ in range(max_new_token):
            # The position embedding table only has block_size rows, so crop
            # the context to the most recent block_size tokens.
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, _ = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

In [13]:
# Hyperparameters for the training run. These replace the small demonstration
# values of batch_size / block_size / n_embd set in earlier cells.
batch_size = 256 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 2500
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0

# The model reads vocab_size / n_embd / block_size when it is constructed, so
# it must be built after the hyperparameters above are set.
model = BigramLanguageModel()
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# The loop variable is named `step` rather than `iter` so the builtin `iter()`
# is not shadowed in the notebook namespace for later cells.
for step in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    # clear gradients from the previous step before backpropagating this one
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.1541, val loss 4.1553
step 100: train loss 2.8980, val loss 2.9141
step 200: train loss 2.6617, val loss 2.6645
step 300: train loss 2.5322, val loss 2.5353
step 400: train loss 2.4546, val loss 2.4633
step 500: train loss 2.4160, val loss 2.4255
step 600: train loss 2.3902, val loss 2.4065
step 700: train loss 2.3741, val loss 2.3912
step 800: train loss 2.3607, val loss 2.3812
step 900: train loss 2.3458, val loss 2.3685
step 1000: train loss 2.3366, val loss 2.3616
step 1100: train loss 2.3243, val loss 2.3530
step 1200: train loss 2.3181, val loss 2.3448
step 1300: train loss 2.3101, val loss 2.3414
step 1400: train loss 2.3058, val loss 2.3374
step 1500: train loss 2.2995, val loss 2.3308
step 1600: train loss 2.2950, val loss 2.3329
step 1700: train loss 2.2913, val loss 2.3271
step 1800: train loss 2.2876, val loss 2.3276
step 1900: train loss 2.2871, val loss 2.3290
step 2000: train loss 2.2829, val loss 2.3255
step 2100: train loss 2.2813, val loss 2.3228


KeyboardInterrupt: 

In [15]:
# Generate from the model, starting from a single token with id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
sampled_ids = m.generate(context, max_new_token=2000)[0].tolist()
print(decode(sampled_ids))


WANGRHDI bridcowr,
This sowr Kibadset bobe doeravegr-ans mealilanss:
Want he usquet vet?

MEXENO:
Aen wice my.

HDYUSYom onoug
Yowno, tof isth bot milf dill, at miree sen cie lat Het drovets, and Win ngan ilerabous lelind meall liser onchiry:
Asprinesspll, yo wllingu normopetelaves
Momy yu, demet akleo Winso wher eiinge wisti dourive wees ime st sot owrif thure kind thrupirf sor; igre! mef thie male onto, af Prred my om.

HETY:
E'ss,
Sbus, wardave aces art my din cme amy aney Iry ts I fr yo voucken pand, bemary.

HARIWOF RIORD oben anghse.

And fout senet.

Thy showne, ins win llety ome.

Thuco frepy tshintchigl.

Andias wetlal wave.

LAWIOPRAUNTE:
Rour imd assche, os coknovet Hose st ums histe fe'd tass:
Whit Clof; chun hes, nd dud ton, moxcharcheanto ankes agh whein
'As mes sleve bumlon mod the wllo no'ld id, morsed
Formy?
TI idurd porvenand, do thieyr ivethe of tiund the nof the sut nexch you on whandeng itth ougle Et llollke, on sothan thean, delwat do ived:
Ther, foru;
yo knogin.