In [None]:
%load_ext autoreload
%autoreload 2

In [1]:
import numpy as np
import torch
from pathlib import Path
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
with Path("input.txt").open("r", encoding='utf-8') as f:
    text = f.read()

In [11]:
vocab = sorted(list(set(text)))
vocab_size = len(vocab)

stoi = { ch: i for i, ch in enumerate(vocab)}
itos = { i: ch for i, ch in enumerate(vocab)}
encode = lambda x: [stoi[s] for s in x]
decode = lambda x: [itos[s] for s in x]

enc = encode("hii there")
print("".join(decode(enc)))

hii there


In [12]:
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [13]:
batch_size = 4 # B
block_size = 8 # T
n_embd = 16 # C

In [14]:
len(data)

1115394

In [15]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    print(len(data))
    # Select the starting index of a block of size block_size from data
    ix = torch.randint(len(data) - block_size, (batch_size,)) 
    # print(ix[0])
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
    
xb, yb = get_batch('train')

1003854


In [22]:
class BLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets):
        # idx.shape = (B,T)
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        logits = self.lm_head(tok_emb)
        B,T,C = logits.shape
        logits = logits.view(B*T,C)
        targets = targets.view(B*T)
        
        # cross_entropy expecting size (minibatch, C) = (minibatch, 16)
        # cross_entropy turns all targets into OHEs
        # behind the scenes so we don't need to explicitly
        # transform them
        loss = F.cross_entropy(logits, targets)
        
        return logits, loss

In [26]:
model = BLM()
m = model.to(device)
logits, loss = m(xb, yb)