In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.nn import functional as F

In [2]:
# Hyperparameters
batch_size = 16
block_size = 256
max_iters = 50000
eval_interval = 500
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
n_layers = 6
dropout = 0.2

In [3]:
with open('input.txt', 'r', encoding = 'utf-8') as f:
    text = f.read()

In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
vocab_size

65

In [5]:
string_to_int = {ch:i for i, ch in enumerate(chars)}
int_to_string = {i:ch for i, ch in enumerate(chars)}

In [6]:
encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: "".join([int_to_string[i] for i in l])

In [7]:
data = torch.tensor(encode(text), dtype = torch.long)

In [8]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [9]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x,y

In [10]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = m(X,Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [11]:
class Head(nn.Module):
    """ One head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias = False)
        self.query = nn.Linear(n_embd, head_size, bias = False)
        self.value = nn.Linear(n_embd, head_size, bias = False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)  #(B, T, C)
        q = self.query(x) #(B, T, C)

        wei = q @ k.transpose(-2,-1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf'))     
        wei = F.softmax(wei, dim = -1)
        wei = self.dropout(wei)

        v = self.value(x)
        out = wei @ v
        return out
        

In [12]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim = -1)
        out = self.proj(out)
        return out

In [13]:
class FeedForward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

In [14]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """
    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension
        # n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)


    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [15]:
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            Block(n_embd, n_head = 4),
            Block(n_embd, n_head = 4),
            Block(n_embd, n_head = 4),
            nn.LayerNorm(n_embd),
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)
    def forward(self, idx, targets = None):
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device = device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim = -1)
            idx_next = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [16]:
model = BigramLanguageModel()
m = model.to(device)

In [17]:
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

5.468993 M parameters


In [18]:
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)

  from .autonotebook import tqdm as notebook_tqdm


In [19]:
for iter in tqdm(range(max_iters)):

    if iter % eval_interval == 0 or iter == max_iters -1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        xb, yb = get_batch('train')
        logits, loss = m(xb, yb)
        optimizer.zero_grad(set_to_none = True)
        loss.backward()
        optimizer.step()

  0%|                                                                                        | 0/50000 [00:00<?, ?it/s]

step 0: train loss 4.3022, val loss 4.2976


  0%|                                                                           | 1/50000 [01:18<1096:06:04, 78.92s/it]

step 500: train loss 3.7776, val loss 3.7878


  1%|▊                                                                           | 501/50000 [02:42<3:48:33,  3.61it/s]

step 1000: train loss 3.5143, val loss 3.5308


  2%|█▌                                                                         | 1001/50000 [04:10<2:58:15,  4.58it/s]

step 1500: train loss 3.3661, val loss 3.3968


  3%|██▎                                                                        | 1501/50000 [05:42<2:43:14,  4.95it/s]

step 2000: train loss 3.2632, val loss 3.2963


  4%|███                                                                        | 2001/50000 [07:13<2:35:24,  5.15it/s]

step 2500: train loss 3.2039, val loss 3.2362


  5%|███▊                                                                       | 2501/50000 [08:46<2:31:33,  5.22it/s]

step 3000: train loss 3.1562, val loss 3.1933


  6%|████▌                                                                      | 3001/50000 [10:20<2:28:55,  5.26it/s]

step 3500: train loss 3.1133, val loss 3.1579


  7%|█████▎                                                                     | 3501/50000 [11:54<2:26:39,  5.28it/s]

step 4000: train loss 3.0843, val loss 3.1237


  8%|██████                                                                     | 4001/50000 [13:27<2:24:38,  5.30it/s]

step 4500: train loss 3.0527, val loss 3.0849


  9%|██████▊                                                                    | 4501/50000 [15:03<2:23:46,  5.27it/s]

step 5000: train loss 3.0278, val loss 3.0605


 10%|███████▌                                                                   | 5001/50000 [16:41<2:23:39,  5.22it/s]

step 5500: train loss 2.9893, val loss 3.0280


 11%|████████▎                                                                  | 5501/50000 [18:19<2:23:02,  5.19it/s]

step 6000: train loss 2.9750, val loss 3.0155


 12%|█████████                                                                  | 6001/50000 [19:57<2:22:07,  5.16it/s]

step 6500: train loss 2.9510, val loss 2.9810


 13%|█████████▊                                                                 | 6501/50000 [21:35<2:20:58,  5.14it/s]

step 7000: train loss 2.9298, val loss 2.9514


 14%|██████████▌                                                                | 7001/50000 [23:13<2:19:41,  5.13it/s]

step 7500: train loss 2.9076, val loss 2.9297


 15%|███████████▎                                                               | 7501/50000 [24:51<2:18:18,  5.12it/s]

step 8000: train loss 2.8843, val loss 2.9126


 16%|████████████                                                               | 8001/50000 [26:29<2:16:51,  5.11it/s]

step 8500: train loss 2.8675, val loss 2.8904


 17%|████████████▊                                                              | 8501/50000 [28:07<2:15:23,  5.11it/s]

step 9000: train loss 2.8565, val loss 2.8785


 18%|█████████████▌                                                             | 9001/50000 [29:45<2:13:52,  5.10it/s]

step 9500: train loss 2.8408, val loss 2.8551


 19%|██████████████▎                                                            | 9501/50000 [31:23<2:12:16,  5.10it/s]

step 10000: train loss 2.8305, val loss 2.8434


 20%|██████████████▊                                                           | 10001/50000 [33:02<2:10:41,  5.10it/s]

step 10500: train loss 2.8157, val loss 2.8336


 21%|███████████████▌                                                          | 10501/50000 [34:40<2:09:06,  5.10it/s]

step 11000: train loss 2.8034, val loss 2.8155


 22%|████████████████▎                                                         | 11001/50000 [36:18<2:07:31,  5.10it/s]

step 11500: train loss 2.7918, val loss 2.8013


 23%|█████████████████                                                         | 11501/50000 [37:56<2:05:55,  5.10it/s]

step 12000: train loss 2.7793, val loss 2.7883


 24%|█████████████████▊                                                        | 12001/50000 [39:34<2:04:16,  5.10it/s]

step 12500: train loss 2.7676, val loss 2.7775


 25%|██████████████████▌                                                       | 12501/50000 [41:12<2:02:38,  5.10it/s]

step 13000: train loss 2.7575, val loss 2.7674


 26%|███████████████████▏                                                      | 13001/50000 [42:50<2:00:59,  5.10it/s]

step 13500: train loss 2.7444, val loss 2.7579


 27%|███████████████████▉                                                      | 13501/50000 [44:28<1:59:20,  5.10it/s]

step 14000: train loss 2.7329, val loss 2.7497


 28%|████████████████████▋                                                     | 14001/50000 [46:06<1:57:41,  5.10it/s]

step 14500: train loss 2.7221, val loss 2.7339


 29%|█████████████████████▍                                                    | 14501/50000 [47:45<1:56:03,  5.10it/s]

step 15000: train loss 2.7153, val loss 2.7321


 30%|██████████████████████▏                                                   | 15001/50000 [49:23<1:54:25,  5.10it/s]

step 15500: train loss 2.7042, val loss 2.7223


 31%|██████████████████████▉                                                   | 15501/50000 [51:01<1:52:46,  5.10it/s]

step 16000: train loss 2.6924, val loss 2.7123


 32%|███████████████████████▋                                                  | 16001/50000 [52:39<1:51:09,  5.10it/s]

step 16500: train loss 2.6894, val loss 2.7041


 33%|████████████████████████▍                                                 | 16501/50000 [54:17<1:49:31,  5.10it/s]

step 17000: train loss 2.6793, val loss 2.6941


 34%|█████████████████████████▏                                                | 17001/50000 [55:55<1:47:52,  5.10it/s]

step 17500: train loss 2.6737, val loss 2.6862


 35%|█████████████████████████▉                                                | 17501/50000 [57:33<1:46:13,  5.10it/s]

step 18000: train loss 2.6644, val loss 2.6821


 36%|██████████████████████████▋                                               | 18001/50000 [59:11<1:44:35,  5.10it/s]

step 18500: train loss 2.6547, val loss 2.6744


 37%|██████████████████████████▋                                             | 18501/50000 [1:00:49<1:42:56,  5.10it/s]

step 19000: train loss 2.6491, val loss 2.6657


 38%|███████████████████████████▎                                            | 19001/50000 [1:02:27<1:41:17,  5.10it/s]

step 19500: train loss 2.6446, val loss 2.6585


 39%|████████████████████████████                                            | 19501/50000 [1:03:55<1:36:32,  5.27it/s]

step 20000: train loss 2.6418, val loss 2.6533


 40%|████████████████████████████▊                                           | 20001/50000 [1:05:20<1:31:58,  5.44it/s]

step 20500: train loss 2.6323, val loss 2.6473


 41%|█████████████████████████████▌                                          | 20501/50000 [1:06:45<1:28:30,  5.56it/s]

step 21000: train loss 2.6269, val loss 2.6383


 42%|██████████████████████████████▏                                         | 21001/50000 [1:08:10<1:25:25,  5.66it/s]

step 21500: train loss 2.6217, val loss 2.6304


 43%|██████████████████████████████▉                                         | 21501/50000 [1:09:36<1:23:19,  5.70it/s]

step 22000: train loss 2.6158, val loss 2.6230


 44%|███████████████████████████████▋                                        | 22001/50000 [1:10:59<1:20:38,  5.79it/s]

step 22500: train loss 2.6111, val loss 2.6182


 45%|████████████████████████████████▍                                       | 22501/50000 [1:12:25<1:19:06,  5.79it/s]

step 23000: train loss 2.6029, val loss 2.6093


 46%|█████████████████████████████████                                       | 23001/50000 [1:13:50<1:17:06,  5.84it/s]

step 23500: train loss 2.6025, val loss 2.6012


 47%|█████████████████████████████████▊                                      | 23501/50000 [1:15:16<1:15:49,  5.82it/s]

step 24000: train loss 2.5953, val loss 2.5953


 48%|██████████████████████████████████▌                                     | 24001/50000 [1:16:44<1:15:01,  5.78it/s]

step 24500: train loss 2.5913, val loss 2.5938


 49%|███████████████████████████████████▎                                    | 24501/50000 [1:18:08<1:12:58,  5.82it/s]

step 25000: train loss 2.5849, val loss 2.5899


 50%|████████████████████████████████████                                    | 25001/50000 [1:19:34<1:11:36,  5.82it/s]

step 25500: train loss 2.5821, val loss 2.5849


 51%|████████████████████████████████████▋                                   | 25501/50000 [1:21:04<1:11:02,  5.75it/s]

step 26000: train loss 2.5786, val loss 2.5836


 52%|█████████████████████████████████████▍                                  | 26001/50000 [1:22:35<1:10:35,  5.67it/s]

step 26500: train loss 2.5753, val loss 2.5782


 53%|██████████████████████████████████████▏                                 | 26501/50000 [1:24:06<1:09:48,  5.61it/s]

step 27000: train loss 2.5740, val loss 2.5745


 54%|██████████████████████████████████████▉                                 | 27001/50000 [1:25:37<1:08:47,  5.57it/s]

step 27500: train loss 2.5707, val loss 2.5705


 55%|███████████████████████████████████████▌                                | 27501/50000 [1:27:09<1:07:37,  5.55it/s]

step 28000: train loss 2.5660, val loss 2.5665


 56%|████████████████████████████████████████▎                               | 28001/50000 [1:28:40<1:06:19,  5.53it/s]

step 28500: train loss 2.5653, val loss 2.5673


 57%|█████████████████████████████████████████                               | 28501/50000 [1:30:11<1:04:57,  5.52it/s]

step 29000: train loss 2.5668, val loss 2.5644


 58%|█████████████████████████████████████████▊                              | 29001/50000 [1:31:42<1:03:32,  5.51it/s]

step 29500: train loss 2.5622, val loss 2.5603


 59%|██████████████████████████████████████████▍                             | 29501/50000 [1:33:13<1:02:04,  5.50it/s]

step 30000: train loss 2.5566, val loss 2.5580


 60%|███████████████████████████████████████████▏                            | 30001/50000 [1:34:44<1:00:34,  5.50it/s]

step 30500: train loss 2.5534, val loss 2.5609


 61%|█████████████████████████████████████████████▏                            | 30501/50000 [1:36:15<59:04,  5.50it/s]

step 31000: train loss 2.5525, val loss 2.5628


 62%|█████████████████████████████████████████████▉                            | 31001/50000 [1:37:46<57:33,  5.50it/s]

step 31500: train loss 2.5486, val loss 2.5635


 63%|██████████████████████████████████████████████▌                           | 31501/50000 [1:39:17<56:03,  5.50it/s]

step 32000: train loss 2.5458, val loss 2.5597


 64%|███████████████████████████████████████████████▎                          | 32001/50000 [1:40:47<54:32,  5.50it/s]

step 32500: train loss 2.5442, val loss 2.5550


 65%|████████████████████████████████████████████████                          | 32501/50000 [1:42:18<53:00,  5.50it/s]

step 33000: train loss 2.5443, val loss 2.5524


 66%|████████████████████████████████████████████████▊                         | 33001/50000 [1:43:49<51:30,  5.50it/s]

step 33500: train loss 2.5448, val loss 2.5502


 67%|█████████████████████████████████████████████████▌                        | 33501/50000 [1:45:20<49:59,  5.50it/s]

step 34000: train loss 2.5426, val loss 2.5498


 68%|██████████████████████████████████████████████████▎                       | 34001/50000 [1:46:51<48:28,  5.50it/s]

step 34500: train loss 2.5431, val loss 2.5455


 69%|███████████████████████████████████████████████████                       | 34501/50000 [1:48:22<46:57,  5.50it/s]

step 35000: train loss 2.5428, val loss 2.5458


 70%|███████████████████████████████████████████████████▊                      | 35001/50000 [1:49:53<45:26,  5.50it/s]

step 35500: train loss 2.5420, val loss 2.5424


 71%|████████████████████████████████████████████████████▌                     | 35501/50000 [1:51:24<43:56,  5.50it/s]

step 36000: train loss 2.5381, val loss 2.5449


 72%|█████████████████████████████████████████████████████▎                    | 36001/50000 [1:52:55<42:25,  5.50it/s]

step 36500: train loss 2.5378, val loss 2.5427


 73%|██████████████████████████████████████████████████████                    | 36501/50000 [1:54:26<40:54,  5.50it/s]

step 37000: train loss 2.5338, val loss 2.5419


 74%|██████████████████████████████████████████████████████▊                   | 37001/50000 [1:55:57<39:23,  5.50it/s]

step 37500: train loss 2.5326, val loss 2.5375


 75%|███████████████████████████████████████████████████████▌                  | 37501/50000 [1:57:28<37:53,  5.50it/s]

step 38000: train loss 2.5292, val loss 2.5390


 76%|████████████████████████████████████████████████████████▏                 | 38001/50000 [1:58:58<36:22,  5.50it/s]

step 38500: train loss 2.5300, val loss 2.5355


 77%|████████████████████████████████████████████████████████▉                 | 38501/50000 [2:00:29<34:51,  5.50it/s]

step 39000: train loss 2.5278, val loss 2.5359


 78%|█████████████████████████████████████████████████████████▋                | 39001/50000 [2:02:00<33:20,  5.50it/s]

step 39500: train loss 2.5266, val loss 2.5367


 79%|██████████████████████████████████████████████████████████▍               | 39501/50000 [2:03:31<31:49,  5.50it/s]

step 40000: train loss 2.5299, val loss 2.5360


 80%|███████████████████████████████████████████████████████████▏              | 40001/50000 [2:05:02<30:17,  5.50it/s]

step 40500: train loss 2.5255, val loss 2.5317


 81%|███████████████████████████████████████████████████████████▉              | 40501/50000 [2:06:33<28:46,  5.50it/s]

step 41000: train loss 2.5261, val loss 2.5344


 82%|████████████████████████████████████████████████████████████▋             | 41001/50000 [2:08:04<27:15,  5.50it/s]

step 41500: train loss 2.5257, val loss 2.5360


 83%|█████████████████████████████████████████████████████████████▍            | 41501/50000 [2:09:35<25:44,  5.50it/s]

step 42000: train loss 2.5224, val loss 2.5341


 84%|██████████████████████████████████████████████████████████████▏           | 42001/50000 [2:11:05<24:13,  5.50it/s]

step 42500: train loss 2.5234, val loss 2.5372


 85%|██████████████████████████████████████████████████████████████▉           | 42501/50000 [2:12:36<22:42,  5.50it/s]

step 43000: train loss 2.5223, val loss 2.5332


 86%|███████████████████████████████████████████████████████████████▋          | 43001/50000 [2:14:07<21:11,  5.50it/s]

step 43500: train loss 2.5216, val loss 2.5388


 87%|████████████████████████████████████████████████████████████████▍         | 43501/50000 [2:15:38<19:40,  5.50it/s]

step 44000: train loss 2.5214, val loss 2.5323


 88%|█████████████████████████████████████████████████████████████████         | 44001/50000 [2:17:09<18:09,  5.50it/s]

step 44500: train loss 2.5204, val loss 2.5351


 89%|█████████████████████████████████████████████████████████████████▊        | 44501/50000 [2:18:40<16:39,  5.50it/s]

step 45000: train loss 2.5191, val loss 2.5319


 90%|██████████████████████████████████████████████████████████████████▌       | 45001/50000 [2:20:10<15:08,  5.50it/s]

step 45500: train loss 2.5182, val loss 2.5328


 91%|███████████████████████████████████████████████████████████████████▎      | 45501/50000 [2:21:41<13:37,  5.50it/s]

step 46000: train loss 2.5176, val loss 2.5311


 92%|████████████████████████████████████████████████████████████████████      | 46001/50000 [2:23:12<12:06,  5.50it/s]

step 46500: train loss 2.5129, val loss 2.5328


 93%|████████████████████████████████████████████████████████████████████▊     | 46501/50000 [2:24:43<10:35,  5.50it/s]

step 47000: train loss 2.5161, val loss 2.5282


 94%|█████████████████████████████████████████████████████████████████████▌    | 47001/50000 [2:26:14<09:05,  5.50it/s]

step 47500: train loss 2.5149, val loss 2.5250


 95%|██████████████████████████████████████████████████████████████████████▎   | 47501/50000 [2:27:45<07:34,  5.50it/s]

step 48000: train loss 2.5116, val loss 2.5246


 96%|███████████████████████████████████████████████████████████████████████   | 48001/50000 [2:29:16<06:03,  5.50it/s]

step 48500: train loss 2.5097, val loss 2.5228


 97%|███████████████████████████████████████████████████████████████████████▊  | 48501/50000 [2:30:47<04:32,  5.50it/s]

step 49000: train loss 2.5111, val loss 2.5175


 98%|████████████████████████████████████████████████████████████████████████▌ | 49001/50000 [2:32:17<03:01,  5.50it/s]

step 49500: train loss 2.5097, val loss 2.5197


 99%|█████████████████████████████████████████████████████████████████████████▎| 49501/50000 [2:33:48<01:30,  5.50it/s]

step 49999: train loss 2.5072, val loss 2.5229


100%|██████████████████████████████████████████████████████████████████████████| 50000/50000 [2:35:19<00:00,  5.37it/s]


In [20]:
# Generate from the model
context = torch.zeros((1,1), dtype = torch.long, device = device)
print(decode(m.generate(context, max_new_tokens = 500)[0].tolist()))


Fhis blleasina,
So by, icave cuf hallellond rso dithande?

Uthodomuray En Thort.
Lxeofe RE wigun orsud, in?
WhUit.
Tht 
P:
LDow ein nthamad ot inedmy t gnd youcee kts waveatoshowigrftt t.
NEdan md od, boul
Whisousiou, m, heyo blithmes;
S:
ARIIAr serdoredoushap h be teth, breseinthtl t;

Yomour inthed
MCAncD:
He, nof u.
W; keans.
Bfad su frsong loyf lt muwere
RDth o ureefork, ngsorthrdet've, mu qrdo mese ds gr
Samoly k tean t! foul-uranareind smeand thedprrs nevind gt tryovesaftal ho e inaathef t
