In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [15]:
B, T, C = 2, 3, 2  # Batch size = 2, Time steps = 3, Feature size = 2

# Example tensor `x`
x = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],    # Batch 1
    [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]  # Batch 2
])

print("Input tensor `x`:")
print(x)

Input tensor `x`:
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]]])


In [None]:
# bag of words
xbow = torch.zeros((B,T,C))
print("xbow:")
print(xbow)
for b in range(B):
    for t in range(T):
            xprev = x[b,:t+1]
            print("xprev:")
            print(xprev)
            xbow[b,t] = torch.mean(xprev,0)
            print("xbow:")
            print(xbow)

In [None]:
cumulative_sum = torch.cumsum(x, dim=1)
print("cumulative_sum:")
print(cumulative_sum)
time_indices = torch.arange(1, T+1, device=x.device).view(1, T, 1)
print("time_indices:")
print(time_indices)
xbow2 = cumulative_sum / time_indices
print("xbow2:")
print(xbow2)


In [34]:
T, C = 8, 32
x = torch.randn(T, C)

head_size = 8
query = nn.Linear(C, head_size)
key = nn.Linear(C, head_size)
value = nn.Linear(C, head_size)

q = query(x)
k = key(x)

wei = q @ k.transpose(-2, -1)

tril = torch.tril(torch.ones(T, T), diagonal=0)

wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v





In [None]:
print("wei:")
print(wei)

In [None]:
print("tril:")
print(tril)

In [35]:
print("out:")
print(out)

out:
tensor([[ 0.0649,  0.1602, -0.0311,  0.6902, -0.7329, -0.5998, -0.5210, -0.2716],
        [-0.0886, -0.0320, -0.0311,  0.7142, -0.2816,  0.0311, -0.1482, -0.1849],
        [-0.2431,  0.2085, -0.0192,  0.6746, -0.3607,  0.3686, -0.2317, -0.3600],
        [-0.2869, -0.1572, -0.0683,  0.7051,  0.1820,  0.8295,  0.2040, -0.0929],
        [ 0.1469,  0.0730, -0.2744,  0.5075, -0.3719, -0.0090, -0.4512,  0.1519],
        [-0.0554,  0.3178, -0.1247,  0.4949, -0.1455,  0.5526, -0.1992, -0.0362],
        [-0.0352,  0.1734, -0.1119,  0.5989, -0.1932,  0.3782, -0.2868,  0.0249],
        [-0.3431,  0.4761, -0.1086,  0.6183, -0.0982,  0.5465,  0.0169, -0.0715]],
       grad_fn=<MmBackward0>)


In [47]:
# scaled attention score

# softmax saturates to the max value
x = torch.tensor([0.1, -0.2, 0.3, 0.4, 0.5])
v1 = torch.softmax(x, dim=0)
print("v1:", v1)

x = x * 8
print("scaled up x:", x)
v2 = torch.softmax(x, dim=0)

# softmax sharps to the max value
print("v2:", v2)

# keeps the variance stable
T, C = 8, 32
x = torch.randn(T, C)

head_size = 64
q = torch.randn(C, head_size)
k = torch.randn(C, head_size)

# q and v have the similar variance which is close to 1
print("q.var:", q.var()) 
print("k.var:", k.var())

wei = q @ k.transpose(-2, -1) 
print("wei.var:", wei.var())

# scaled attention score's variance is close to 1 also
wei = wei / (head_size ** 0.5)
print("scaled wei.var:", wei.var())

v1: tensor([0.1723, 0.1276, 0.2104, 0.2326, 0.2570])
scaled up x: tensor([ 0.8000, -1.6000,  2.4000,  3.2000,  4.0000])
v2: tensor([0.0240, 0.0022, 0.1191, 0.2650, 0.5897])
q.var: tensor(1.0746)
k.var: tensor(0.9732)
wei.var: tensor(72.8887)
scaled wei.var: tensor(1.1389)


In [54]:
class Head(nn.Module):
    def __init__(self, block_size, n_embd, head_size):
        print(f"block_size:{block_size}, n_embd: {n_embd}, head_size: {head_size}")
        
        self.block_size = block_size
        self.n_embd = n_embd
        self.head_size = head_size

        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size), diagonal=0))

    def forward(self, x):
        B, T, C = x.shape
        assert T == self.block_size
        assert C == self.n_embd

        q = self.query(x) # (B, block_size, head_size)
        k = self.key(x) # (B, block_size, head_size)

        wei = q @ k.transpose(-2, -1) # (B, block_size, block_size)
        wei = wei / (self.head_size ** 0.5)
        wei = wei.masked_fill(self.tril == 0, float('-inf'))
        #print("masked wei:")
        #print(wei)
        wei = F.softmax(wei, dim=-1)
        #print("softmax wei:")
        #print(wei)

        v = self.value(x) # (B, block_size, head_size)
        out = wei @ v # (B, block_size, head_size)
        #print("out.shape:", out.shape)
        return out

In [None]:
B, T, C= 4, 8, 32
block_size = 8
n_embd = 32
head_size = 8

x = torch.randn(B, T, C)
head = Head(block_size, n_embd, head_size)
out = head(x)
print("out:")
print(out)

In [59]:
class MultiHeadAttention(nn.Module):
    def __init__(self, block_size, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.heads = nn.ModuleList([Head(block_size, n_embd, head_size) for _ in range(n_head)])
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        head_out = [head(x) for head in self.heads]
        print(type(head_out))
        out = torch.cat(head_out, dim=-1)
        print("out.shape:", out.shape)
        out = self.proj(out)
        return out

In [60]:
B, T, C= 4, 8, 32
block_size = 8
n_embd = 32
n_head = 4

x = torch.randn(B, T, C)
ma = MultiHeadAttention(block_size, n_embd, n_head)
out = ma(x)


block_size:8, n_embd: 32, head_size: 8
block_size:8, n_embd: 32, head_size: 8
block_size:8, n_embd: 32, head_size: 8
block_size:8, n_embd: 32, head_size: 8
<class 'list'>
out.shape: torch.Size([4, 8, 32])


In [66]:
# cat function sample
tensors = [torch.tensor([[1, 2]]), torch.tensor([[3, 4]]), torch.tensor([[5, 6]])]
result = torch.cat(tensors, dim=-1)
print(result)
# Output: 
# tensor([[1, 2, 3, 4, 5, 6]])
result = torch.cat(tensors, dim=0)
print(result)
# Output:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])



# cat function sample
tensors = [torch.tensor([[1, 2], [1, 2]]), torch.tensor([[3, 4], [3, 4]]), torch.tensor([[5, 6], [5, 6]])]
result = torch.cat(tensors, dim=-1)
print(result)
# Output: 
# tensor([[1, 2, 3, 4, 5, 6],
#        [1, 2, 3, 4, 5, 6]])
result = torch.cat(tensors, dim=0)
print(result)
# Output:
# tensor([[1, 2],
#        [1, 2],
#        [3, 4],
#        [3, 4],
#        [5, 6],
#        [5, 6]])



tensor([[1, 2, 3, 4, 5, 6]])
tensor([[1, 2],
        [3, 4],
        [5, 6]])
tensor([[1, 2, 3, 4, 5, 6],
        [1, 2, 3, 4, 5, 6]])
tensor([[1, 2],
        [1, 2],
        [3, 4],
        [3, 4],
        [5, 6],
        [5, 6]])


In [88]:
# simulate generation from a single input token
context_len = 8
tril = torch.tril(torch.ones(context_len, context_len), diagonal=0)
#print(tril)


query = nn.Linear(8, 8, bias=False)
key = nn.Linear(8, 8, bias=False)
value = nn.Linear(8, 8, bias=False)

# two input tokens
x = torch.randn(2,8)
q = query(x)
k = key(x)

print("q:", q)
print("k:", k)

wei = q @ k.transpose(-2, -1)
# print("wei:", wei)

mask = tril[:2, :2]
wei = wei.masked_fill(mask == 0, float('-inf'))
# print("masked wei:", wei)
wei = F.softmax(wei, dim=-1)
print("wei:", wei)

v = value(x)
print("v:", v)
out = wei @ v
print("out:", out)

# three input tokens
x = torch.randn(3,8)
q = query(x)
k = key(x)

print("q:", q)
print("k:", k)

wei = q @ k.transpose(-2, -1)
# print("wei:", wei)
mask = tril[:3, :3]
wei = wei.masked_fill(mask == 0, float('-inf'))

# print("masked wei:", wei)
wei = F.softmax(wei, dim=-1)
print("wei:", wei)


v = value(x)
print("v:", v)
out = wei @ v
print("out:", out)










q: tensor([[-0.7740, -0.0146,  0.2315, -0.1058, -0.4416, -0.5640, -0.3614,  0.1565],
        [ 0.3142,  0.6628,  1.0521, -0.4043, -0.0742,  0.8171, -0.4285, -0.2207]],
       grad_fn=<MmBackward0>)
k: tensor([[ 1.1275,  0.0791,  0.0754,  0.1742,  0.3658, -0.5434,  0.2000, -0.2962],
        [-1.0723, -1.2245,  0.2277,  0.2622, -1.3923, -0.3449,  0.3051,  0.1239]],
       grad_fn=<MmBackward0>)
wei: tensor([[1.0000, 0.0000],
        [0.7817, 0.2183]], grad_fn=<SoftmaxBackward0>)
v: tensor([[ 0.0613,  0.5328,  0.4237,  0.5995, -0.0622,  0.1780, -0.6394,  0.2454],
        [ 0.3002, -0.4567,  0.1420,  0.2841,  1.2245,  0.1776, -0.3756, -0.1101]],
       grad_fn=<MmBackward0>)
out: tensor([[ 0.0613,  0.5328,  0.4237,  0.5995, -0.0622,  0.1780, -0.6394,  0.2454],
        [ 0.1135,  0.3168,  0.3622,  0.5307,  0.2187,  0.1779, -0.5818,  0.1678]],
       grad_fn=<MmBackward0>)
q: tensor([[ 0.8806,  0.5689,  0.3013, -0.4565,  1.0761,  0.6579,  0.2991,  0.5608],
        [ 0.4317,  0.2750,  0.3457,