In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Toy single-head causal self-attention on random data.
# x has B independent sequences (batched matmuls never mix dim 0), T positions each
# (the causal mask is over T), and C features per position (the Linear input width).
B, T, C = 4, 8, 32
# Seed torch's RNG so `x`, the Linear weights, and every value printed in the cells
# below reproduce on Restart & Run All.
torch.manual_seed(1337)
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)  # (B, T, head_size)
q = query(x)  # (B, T, head_size)
# Raw affinities between every query position and every key position.
# Deliberately unscaled here; the following cells inspect their variance.
wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) = (B, T, T)

# Causal mask: row t keeps only columns <= t. Masked scores become -inf, so softmax
# gives them exactly zero weight and every row becomes a distribution summing to 1.
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)

v = value(x)  # (B, T, head_size)
out = wei @ v  # (B, T, T) @ (B, T, head_size) = (B, T, head_size)

out.shape  # (4, 8, 16)

torch.Size([4, 8, 16])

In [3]:
# Recompute the raw (unscaled, unmasked) attention scores so their variance can be inspected.
k, q = key(x), query(x)  # each (B, T, head_size)
wei = torch.matmul(q, k.transpose(-2, -1))  # (B, T, T)

In [4]:
# Sample (unbiased) variance over all elements of the key projections.
torch.var(k)

tensor(0.3071, grad_fn=<VarBackward0>)

In [5]:
# Sample (unbiased) variance over all elements of the query projections.
torch.var(q)

tensor(0.3514, grad_fn=<VarBackward0>)

In [6]:
# Variance of the unscaled q @ k^T scores; compare with the k and q variances above.
torch.var(wei)

tensor(1.5710, grad_fn=<VarBackward0>)

In [7]:
# Same raw scores, now multiplied by head_size**-0.5 = 1/sqrt(head_size).
# Since Var(a*X) = a**2 * Var(X), this divides the score variance by head_size
# (compare the wei.var() outputs before and after this cell).
k, q = key(x), query(x)
wei = torch.matmul(q, k.transpose(-2, -1)).mul(head_size**-0.5)

In [8]:
# q itself is unaffected by the scaling; only the scores were rescaled.
torch.var(q)

tensor(0.3514, grad_fn=<VarBackward0>)

In [9]:
# Variance of the scaled scores: about 1/head_size of the unscaled value shown earlier.
torch.var(wei)

tensor(0.0982, grad_fn=<VarBackward0>)

In [10]:
# Row 0 of batch element 0's (T, T) score matrix: query position 0 against every key position.
wei[0, 0]

tensor([ 0.1080,  0.2272, -0.0115, -0.3524, -0.1320,  0.0708,  0.0865,  0.1894],
       grad_fn=<SelectBackward0>)

# Train model

In [27]:
import torch.nn as nn
from data import FyodorDataset
from model import BigramLanguageModel
from train import get_batch
from utils import get_files_from_folder, open_txt
from tqdm import trange

# --- Hyperparameters ---------------------------------------------------------
batch_size = 128  # sequences per training step
block_size = 32  # sequence length passed to get_batch and the model (presumably the context length)
max_iters = 10000  # number of optimizer steps
eval_iters = 100  # validation batches averaged per evaluation
eval_interval = 100  # evaluate every `eval_interval` steps (and on the final step)
eval_batch_size = 512  # sequences per validation batch
learning_rate = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"
n_embd = 32  # embedding width passed to the model
seed = 1337

# Seed torch's global RNG: covers weight init, and batch sampling if get_batch draws
# from torch's RNG (TODO confirm in train.get_batch).
torch.manual_seed(seed)

# --- Data --------------------------------------------------------------------
books_dir = "../books"
book_files = get_files_from_folder(books_dir)  # presumably file names inside books_dir
books_string = [open_txt(f"{books_dir}/{name}") for name in book_files]
books = "\n".join(books_string)  # whole corpus as a single string

split_idx = int(len(books) * 0.8)  # first 80% of the text for training, last 20% for validation
train_dataset = FyodorDataset(books[:split_idx])
val_dataset = FyodorDataset(books[split_idx:])

# --- Model -------------------------------------------------------------------
model = BigramLanguageModel(n_embd=n_embd, block_size=block_size, device=device)
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)


@torch.no_grad()
def estimate_val_loss() -> float:
    """Return the mean loss over `eval_iters` random validation batches.

    Runs the model in eval mode without gradient tracking, then switches it back to
    train mode. Uses its own local batch/loss names so the training step's
    variables are not overwritten.
    """
    model.eval()
    losses = torch.zeros(eval_iters)
    for i in range(eval_iters):
        xv, yv = get_batch(
            val_dataset.data,
            block_size=block_size,
            batch_size=eval_batch_size,
            device=device,
        )
        _, batch_loss = model(xv, yv)
        losses[i] = batch_loss.item()
    model.train()
    return losses.mean().item()


# --- Training loop -----------------------------------------------------------
train_loss = float("inf")
val_loss = float("inf")
t = trange(max_iters)
for steps in t:
    # sample a batch of training data
    xb, yb = get_batch(
        train_dataset.data, block_size=block_size, batch_size=batch_size, device=device
    )

    model.train()
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Also evaluate on the last step so the reported losses describe the finished model,
    # not the model as it was eval_interval - 1 updates earlier.
    if steps % eval_interval == 0 or steps == max_iters - 1:
        train_loss = loss.item()  # loss of the current training minibatch only (noisy)
        val_loss = estimate_val_loss()  # averaged over eval_iters validation batches
        t.set_description(f"train_loss: {train_loss:.4f} | val_loss: {val_loss:.4f}")

print("Done")

train_loss: 2.4255 | val_loss: 2.3884: 100%|██████████| 10000/10000 [00:58<00:00, 169.71it/s]

Done





In [33]:
# Generate up to 1000 new tokens, starting from a (1, 1) context holding token id 0.
# The context uses its own name: rebinding `x` would clobber the random (B, T, C)
# input that the attention cells above read, and re-running them would then fail.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
model.eval()  # the training loop leaves the model in train mode; disable e.g. dropout, if any
with torch.no_grad():  # inference only: don't build an autograd graph
    generated = model.generate(idx=context, max_new_tokens=1000)
print(model.decode(generated[0].tolist()))


tooitlat hot nishe sa shapti.
Thore I welitm my
wite duritnly
ther Loixpicedfusevastolly gqeirus, dimocud hald ar. "Arr edyout to spo wito Vad_
dry se mens.

“"Itlld fd ito, win mbarn, thady; tooutllls ad
of wimy ige erm ori. Lben hitou term tche tove, hars ils, sSter antios toesa’sin hely Roust ld toh wimed itou gh ut tethery simedingth tch wn noffu st to that to Sthe abid
ast wtis its skist brlow poporke hanraploke scof he, _I'
Co to tha yt aye In sta gh onle yond oron so dll sowusinger the, alu nsmou we irof omy, Lobla ours setnol, on ouine matzediondeles, hanid aboutt immeldere Tro cerapr of aride ye—f ase erekeer! "ly dis Dom. I ks iroun a, yon tiong heave at—ille, serw..
.. “On ow gay dsre ellitle fo righing,

annst ourderd
thingh wofu heattho No ug hgt he re velff..



“Miay a cat rshistthim bieleled‐mie ind ow thy.. Embe he drut the. “Rove of whotors led ckesw
timongat ay yofor anndiad her, ne fis ‘st we alyisasce ies whe ared dw ove, acly
thoulve sclo dit in. Whe saltitololer