In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)
q = query(x)
wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) = (B, T, T)

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v

out.shape  # (4, 8, 16)

torch.Size([4, 8, 16])

In [3]:
k = key(x)
q = query(x)
wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) = (B, T, T)

In [4]:
k.var()

tensor(0.3581, grad_fn=<VarBackward0>)

In [5]:
q.var()

tensor(0.3736, grad_fn=<VarBackward0>)

In [6]:
wei.var()

tensor(1.6728, grad_fn=<VarBackward0>)

In [7]:
k = key(x)
q = query(x)
wei = q @ k.transpose(-2, -1) * head_size**-0.5

In [8]:
q.var()

tensor(0.3736, grad_fn=<VarBackward0>)

In [9]:
wei.var()

tensor(0.1046, grad_fn=<VarBackward0>)

In [10]:
wei[0][0]

tensor([-0.1851,  0.4056, -0.4520,  0.1029,  0.0062,  0.1204, -0.0272,  0.2714],
       grad_fn=<SelectBackward0>)

# Train model

In [11]:
import torch.nn as nn
from data import FyodorDataset
from model import BigramLanguageModel
from train import get_batch
from utils import get_files_from_folder, open_txt
from tqdm import trange

batch_size = 256
block_size = 8
max_iters = 10000
eval_iters = 200
eval_interval = 500
learning_rate = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"
n_embd = 24
n_head = 6
n_layer = 6
dropout = 0.4

books = get_files_from_folder("../books")
books_string = [open_txt(f"../books/{i}") for i in books]
books = "\n".join(books_string)

train_dataset = FyodorDataset(books[: int(len(books) * 0.8)])
val_dataset = FyodorDataset(books[int(len(books) * 0.8) :])

model = BigramLanguageModel(
    n_embd=n_embd,
    block_size=block_size,
    n_head=n_head,
    n_layer=n_layer,
    dropout=dropout,
    device=device,
)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

train_loss = float("inf")
val_loss = float("inf")
t = trange(max_iters)
for steps in t:
    # sample a batch of data
    xb, yb = get_batch(
        train_dataset.data, block_size=block_size, batch_size=batch_size, device=device
    )

    model.train()
    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % eval_interval == 0:
        train_loss = loss.item()
        # evalute on valditation set
        model.eval()
        val_loss = torch.zeros(eval_iters)
        with torch.no_grad():
            for i in range(eval_iters):
                xb, yb = get_batch(
                    val_dataset.data,
                    block_size=block_size,
                    batch_size=512,
                    device=device,
                )
                logits, loss = model(xb, yb)
                val_loss[i] = loss
            val_loss = val_loss.mean()

        t.set_description(f"train_loss: {train_loss:.4f} | val_loss: {val_loss:.4f}")

print("Done")

train_loss: 2.0146 | val_loss: 1.9644: 100%|██████████| 10000/10000 [02:35<00:00, 64.42it/s]

Done





In [12]:
x = torch.zeros((1, 1), dtype=torch.long)
x[0, 0] = 1
x = x.to(device)
print(model.decode(model.generate(idx=x, max_new_tokens=500)[0].tolist()))

 brught? As a fart talk upacuar, ou have hapack as seight Ge. ”"



Ixcitss, Myssitionowill. Hims of
toemanises is a meary qualioush
in susors the to
tost he is, glan where suddfallood beand and arr?"

“Every toses, bet on repired an doorsed in hyouble her her and I the ponsinly awith to me ince red in
a
butn’t at hild erstipery was ganat in monce excathance hars, bunutelly monmitles, only al havarcigh to be hare” It’nd from nobsy, what
just a all.

It don’t the toued Yearf you oner my, of cancit


In [13]:
# torch.save(
#     {
#         "state_dict": model.state_dict(),
#     },
#     "../model_store/attention.pt",
# )

In [14]:
print(model)

BigramLanguageModel(
  (token_embedding_table): Embedding(104, 24)
  (position_embedding_table): Embedding(8, 24)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-5): 6 x Head(
            (key): Linear(in_features=24, out_features=4, bias=False)
            (query): Linear(in_features=24, out_features=4, bias=False)
            (value): Linear(in_features=24, out_features=4, bias=False)
            (dropout): Dropout(p=0.4, inplace=False)
          )
        )
        (proj): Linear(in_features=24, out_features=24, bias=True)
        (dropout): Dropout(p=0.4, inplace=False)
      )
      (ffwd): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=24, out_features=96, bias=True)
          (1): ReLU()
          (2): Linear(in_features=96, out_features=24, bias=True)
          (3): Dropout(p=0.4, inplace=False)
        )
      )
      (ln1): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
      (l