In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
# # data loading
# def get_batch(split):
#     # generate a small batch of data of inputs x and targets y
#     data = train_data if split == 'train' else val_data
#     ix = torch.randint(len(data) - block_size, (batch_size,))
#     x = torch.stack([data[i:i+block_size] for i in ix])
#     y = torch.stack([data[i+1:i+block_size+1] for i in ix])
#     x, y = x.to(device), y.to(device)
#     return x, y

# @torch.no_grad()
# def estimate_loss():
#     out = {}
#     model.eval()
#     for split in ['train', 'val']:
#         losses = torch.zeros(eval_iters)
#         for k in range(eval_iters):
#             X, Y = get_batch(split)
#             logits, loss = model(X, Y)
#             losses[k] = loss.item()
#         out[split] = losses.mean()
#     model.train()
#     return out

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
class Head(nn.Module):
    def __init__(self, head_size, n_embd, block_size, drop_out):
        super().__init__()
        self.q = nn.Linear(n_embd, head_size, bias=False)
        self.k = nn.Linear(n_embd, head_size, bias=False)
        self.v = nn.Linear(n_embd, head_size, bias=False)
        self.n_embd = n_embd
        self.head_size = head_size
        self.dropout = nn.Dropout(drop_out)
        self.register_buffer('tril', torch.tril(
            torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        Q = self.q(x) #B,T,head_size
        K = self.k(x) #B,T,head_size
        V = self.v(x)
        x = Q@K.transpose(-2, -1) * C**-0.5
        x = x.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        x = F.softmax(x, dim=-1)
        x = self.dropout(x)
        x = x @ V
        return x
    

In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, block_size, n_embd, n_head, drop_out):
        super().__init__()
        head_size = n_embd // n_head
        self.heads = nn.ModuleList(
            [Head(head_size, n_embd, block_size, drop_out) for i in range(n_head)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(drop_out)

    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        x = self.dropout(x)
        return x

In [6]:
class FeedForward(nn.Module):
    def __init__(self, n_embd, drop_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd),
            nn.ReLU(),
            nn.Linear(4*n_embd, n_embd),
            nn.Dropout(drop_out),
        )

    def forward(self, x):
        return self.net(x)

In [7]:
class Block(nn.Module):
    def __init__(self, block_size, n_embd, n_head, drop_out) -> None:
        super().__init__()
        self.sa = MultiHeadAttention(block_size, n_embd, n_head, drop_out)
        self.ffwd = FeedForward(n_embd, drop_out)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [8]:
class BigramLanguageModel(nn.Module):
    def __init__(self, n_vocab, block_size, n_embd, n_head, n_layer, drop_out) -> None:
        super().__init__()
        self.token_embedding_table = nn.Embedding(n_vocab, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(block_size, n_embd, n_head, drop_out) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, n_vocab)
        self.block_size = block_size

    def forward(self, x, targets=None):
        B, T = x.shape
        t = self.token_embedding_table(x)  # B,T,C
        p = self.position_embedding_table(
            torch.arange(T, device=device))  # T,C
        x = t+p  # B,T,C
        x = self.blocks(x)  # B,T,C
        x = self.ln_f(x)
        logits = self.lm_head(x) #B,T,n_vocab

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

In [9]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-07-30 21:48:35--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Connecting to 127.0.0.1:12639... connected.


Proxy request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2023-07-30 21:48:37 (1.49 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [10]:
torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [11]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
# encoder: take a string, output a list of integers
def encode(s): return [stoi[c] for c in s]
# decoder: take a list of integers, output a string
def decode(l): return ''.join([itos[i] for i in l])


# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split, block_size, batch_size):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


@torch.no_grad()
def estimate_loss(model, eval_iters, block_size, batch_size):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, block_size, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [19]:
drop_out = 0.1
block_size = 16
n_embd = 64
n_head = 4
n_layer = 12
eval_iters = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BigramLanguageModel(vocab_size, block_size,
                            n_embd, n_head, n_layer, drop_out)
model.to(device)
model

BigramLanguageModel(
  (token_embedding_table): Embedding(65, 64)
  (position_embedding_table): Embedding(16, 64)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-3): 4 x Head(
            (q): Linear(in_features=64, out_features=16, bias=False)
            (k): Linear(in_features=64, out_features=16, bias=False)
            (v): Linear(in_features=64, out_features=16, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=64, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ffwd): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=64, out_features=256, bias=True)
          (1): ReLU()
          (2): Linear(in_features=256, out_features=64, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): 

In [20]:
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')


0.607041 M parameters


In [21]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)


In [22]:
max_iters = 5000
batch_size = 16

for i in range(max_iters):
    if i % eval_iters == 0 or i == max_iters-1:
        losses = estimate_loss(model, eval_iters, block_size, batch_size)
        print(
            f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    x, y = get_batch('train', block_size, batch_size)
    logits, loss = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step <built-in function iter>: train loss 4.3907, val loss 4.3916
step <built-in function iter>: train loss 2.4960, val loss 2.4963
step <built-in function iter>: train loss 2.3382, val loss 2.3579
step <built-in function iter>: train loss 2.2595, val loss 2.2878
step <built-in function iter>: train loss 2.1813, val loss 2.2041
step <built-in function iter>: train loss 2.1234, val loss 2.1884
step <built-in function iter>: train loss 2.0673, val loss 2.1247
step <built-in function iter>: train loss 2.0453, val loss 2.1008
step <built-in function iter>: train loss 1.9983, val loss 2.0575
step <built-in function iter>: train loss 1.9978, val loss 2.0748
step <built-in function iter>: train loss 1.9478, val loss 2.0285
step <built-in function iter>: train loss 1.9232, val loss 2.0003
step <built-in function iter>: train loss 1.8946, val loss 1.9913
step <built-in function iter>: train loss 1.8867, val loss 1.9758
step <built-in function iter>: train loss 1.8627, val loss 1.9650
step <buil

In [18]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
decode(model.generate(context, max_new_tokens=2000)[0].tolist())

"\n\nBUSCENTIUS:\nAband the Curteness a with be of his you is tour men: word shart toorts for cons's ghath your mord mooves? For in thou sprar.\n\nAgain rewell not be strocledob. My such.\n\nCOMIOLISS:\nI 'ladvis a amporn the go.\n\nELWOe, Tyroviouss a when, I go, He cbonady, costroth it too.\n\nHESSNOSS IS MARET:\nRive,\nWhat pepentenge, of the\nHadlam then thing,\nButuy tite thefurlous my feaeclest.\n\nShe Caradam\nAnd?\n'AUTINGBRO:\nTreples our good;\nOr mytaixistard,\nAway, 'Tyserven. \nCay Collacitess ap 'Tis whire;\nO faurs? you proble un ilour your eaven:\nLoubly the shing Lice, they motten\nmain ancerce?\n\nServoln, my dave, he rime\nand notly vaine,\nTo I presing if their, go my your reegr, you.\n\nROTUSA:\nHeartle tumpes here them Lose wirtusem? bose, and deep us.\n\nClove fanlal, nor in best time this.\n\nKive, my nore, fain his, of my in Eour Whochlaid; If to bung oftiettly, them: for the death it.\n\nCLord S Aup my bestery in thee every;\nAnow'\nremenn thou loss, be laven 