In [2]:
with open("input.txt", "r") as f:
    text = f.read()

In [5]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.


In [6]:
len(text)

1115394

In [9]:
chars = sorted(list(set(text)))
s_to_i = {s: i for i, s in enumerate(chars)}
i_to_s = {i: s for i, s in enumerate(chars)}
encode = lambda s: [s_to_i[c] for c in s]
decode = lambda l: "".join([i_to_s[i] for i in l])

In [11]:
print(encode("Hello, world!"))
print(decode(encode("Hello, world!")))

[20, 43, 50, 50, 53, 6, 1, 61, 53, 56, 50, 42, 2]
Hello, world!


In [12]:
import torch

data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape)

torch.Size([1115394])


In [13]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [14]:
block_size = 8
x = train_data[:block_size]
y = train_data[1:block_size + 1]
for t in range(block_size):
    context = x[:t + 1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [18]:
torch.manual_seed(1337)
batch_size = 4  # how many independent sequences will we process in parallel?
block_size = 8  # what is the maximum context length for predictions?


def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    print(ix)
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y


xb, yb = get_batch("train")
print(xb.shape)
print(yb.shape)
print(xb)
print(yb)
print("----------------------")
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t + 1]
        target = yb[b, t]
        print(f"when input is {context} the target: {target}")

tensor([ 76049, 234249, 934904, 560986])
torch.Size([4, 8])
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----------------------
when input is tensor([24]) the target: 43
when input is tensor([24, 43]) the target: 58
when input is tensor([24, 43, 58]) the target: 5
when input is tensor([24, 43, 58,  5]) the target: 57
when input is tensor([24, 43, 58,  5, 57]) the target: 1
when input is tensor([24, 43, 58,  5, 57,  1]) the target: 46
when input is tensor([24, 43, 58,  5, 57,  1, 46]) the target: 43
when input is tensor([24, 43, 58,  5, 57,  1, 46, 43]) the target: 39
when input is tensor([44]) the target: 53
when input is tensor([44, 53]) the target: 56
when input is tensor([44, 

In [40]:
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)


class BiggramLM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are both (B, T) tensor of integers
        # logits is (B, T, C), B is batch size, T is sequence length, C is vocab size
        logits = self.token_embedding_table(idx)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :]  # (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
            # append sampled index to the running sequence
            idx = torch.cat([idx, idx_next], dim=1)  # (B, T+1)
        return idx


m = BiggramLM(vocab_size=len(chars))
logits, loss = m(xb, yb)
print(logits.shape, loss)
print(decode(m.generate(idx=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

torch.Size([32, 65]) tensor(4.8786, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


In [41]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [42]:
batch_size = 32
for steps in range(10000):
    xb, yb = get_batch("train")
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

tensor([880792, 620608, 758123, 761955, 483803,  61808, 476059, 726263, 656633,
        674426, 241399, 395150, 298543, 938016, 254737, 952644,  24901, 772494,
        628396, 556307, 319255,  38208, 849969, 394140, 919551, 316361, 359631,
        121698, 567966, 743553, 884616, 849510])
tensor([793375, 924645,  86369, 919544, 299140, 227344, 860300, 216165, 551383,
        555439, 641236, 435602, 759848, 700082, 918994, 468585, 273700,  95491,
        765171, 359861, 736914, 144022, 105516, 477200, 784451, 896134, 934681,
        376322, 502277, 204528, 202156, 543499])
tensor([ 800823,  930616,  153375,  834582,  860265,  296227,  993134,  806983,
          98006,  172503,  194830,  353142,  954475,  301736,  781978, 1002223,
         143971,  716105,  174172,  148687,  362485,  354704,  925089,  290609,
         108944,  926805,  989259,  608356,  416002,  157300,  833640,  657117])
tensor([680430, 266960, 331912, 406615, 677171,  61696, 506942, 766551, 981292,
        147889, 26677

In [44]:
print(decode(m.generate(idx=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))



Wengerofo'dsssit ey
KIN d pe wither vouprrouthercc.
hathe; d!
My hind tt hinig t ouchos tes; st yo hind wotte grotonear 'so it t jod weancotha:
h hay.JUCle n prids, r loncave w hollular s O:
HIs; ht anjx?

DUThinqunt.

LaZAnde.
athave l.
KEONH:
ARThanco be y,-hedarwnoddy scace, tridesar, wnl'shenous s ls, theresseys
PlorseelapinghiybHen yof GLUCEN t l-t E:
I hisgothers je are!-e!
QLYotouciullle'z,
Thitertho s?
NDan'spererfo cist ripl chys er orlese;
Yo jehof h hecere ek? wferommot mowo soaf yoi


In [61]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [62]:
# 先用简单的加权平均的方式，计算sequence T中第n个token和它之前的token的权重，比如第5个token的值是它前面4个token和自己加在一起后的均值
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]  # (t, C)
        xbow[b, t] = torch.mean(xprev, 0)

In [48]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [49]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [52]:
# 另一个高效的方式是用矩阵乘法
torch.manual_seed(42)
torch.tril(torch.ones(3,3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [53]:
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
a

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [56]:
b = torch.randint(1,10,(3,2)).float()
print(b)
c = a @ b
print(c)

tensor([[9., 1.],
        [1., 5.],
        [3., 5.]])
tensor([[9.0000, 1.0000],
        [5.0000, 3.0000],
        [4.3333, 3.6667]])


In [69]:
weight = torch.tril(torch.ones(T, T))
weight = weight / weight.sum(1, keepdim=True)
xbow2 = weight @ x  # (T, T) @ (B, T, C) ----> (B, T, C)
# xbow2 == xbow
torch.allclose(xbow, xbow2)
print(xbow[0] ,xbow2[0])

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]]) tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


In [67]:
# 还可以用softmax
tril = torch.tril(torch.ones(T, T))
weight = torch.zeros((T, T))
weight = weight.masked_fill(tril == 0, float('-inf'))
weight = F.softmax(weight, dim=-1)
xbow3 = weight @ x
torch.allclose(xbow2, xbow3)

True

In [72]:
# softmax的方式更能直观的表达， token不能看到未来的信息
# 然而对于一些场景，例如语义分析，token需要看到全部的信息，就会把这行代码删掉，这也被称为decoder，意为decode这句话
weight.masked_fill(tril == 0, float('-inf'))

tensor([[1.0000,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.5000, 0.5000,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.3333, 0.3333, 0.3333,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.2500, 0.2500, 0.2500, 0.2500,   -inf,   -inf,   -inf,   -inf],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000,   -inf,   -inf,   -inf],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667,   -inf,   -inf],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429,   -inf],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [80]:
# self-attention
# each token has two vector Q, K
# the Q means what do I looking for? the K means what do I contain?
# so my Q @ others K  become the weight
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

# let`s see a single head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)
# communication occurs
weight = q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)
# 根据公式，该weight还需要 * head_size**-0.5，因为之后要进行softmax，对于数值范围过大的值，softmax效果不理想，为了有一个更好的initialization

tril = torch.tril(torch.ones(T, T))
# weight = torch.zeros((T, T))
weight = weight.masked_fill(tril == 0, float('-inf'))
weight = F.softmax(weight, dim=-1)

# we usuarlly not @x, instead @v
# out = weight @ x
v = value(x)
out = weight @ v
print(out[0])
out.shape
# 这被称为self attention，是因为q,k,v都是从x自己来的
# 而cross attention，q来自x，k,v可能来自外部，例如上下文或者其他的encoder，x only making queries, and get information from outside

tensor([[-0.1571,  0.8801,  0.1615, -0.7824, -0.1429,  0.7468,  0.1007, -0.5239,
         -0.8873,  0.1907,  0.1762, -0.5943, -0.4812, -0.4860,  0.2862,  0.5710],
        [ 0.6764, -0.5477, -0.2478,  0.3143, -0.1280, -0.2952, -0.4296, -0.1089,
         -0.0493,  0.7268,  0.7130, -0.1164,  0.3266,  0.3431, -0.0710,  1.2716],
        [ 0.4823, -0.1069, -0.4055,  0.1770,  0.1581, -0.1697,  0.0162,  0.0215,
         -0.2490, -0.3773,  0.2787,  0.1629, -0.2895, -0.0676, -0.1416,  1.2194],
        [ 0.1971,  0.2856, -0.1303, -0.2655,  0.0668,  0.1954,  0.0281, -0.2451,
         -0.4647,  0.0693,  0.1528, -0.2032, -0.2479, -0.1621,  0.1947,  0.7678],
        [ 0.2510,  0.7346,  0.5939,  0.2516,  0.2606,  0.7582,  0.5595,  0.3539,
         -0.5934, -1.0807, -0.3111, -0.2781, -0.9054,  0.1318, -0.1382,  0.6371],
        [ 0.3428,  0.4960,  0.4725,  0.3028,  0.1844,  0.5814,  0.3824,  0.2952,
         -0.4897, -0.7705, -0.1172, -0.2541, -0.6892,  0.1979, -0.1513,  0.7666],
        [ 0.1866, -0.0

torch.Size([4, 8, 16])