In [110]:
import torch
from torch.nn import functional as F
import torch.nn as nn

In [124]:
torch.manual_seed(1337)
B,T,C = 4,8,32
vocab_size = 300
block_size = T
n_embd = 32
#B -> Batch Size, T -> Block Size/Time horizon, C -> channels/covariates
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 32])

In [93]:
%%time
#We want x[b,t] = mean_{i<=t} x[b,t]
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]
#         print(xprev)
        xbow[b,t] = torch.mean(xprev, 0)

CPU times: user 19.3 s, sys: 183 ms, total: 19.5 s
Wall time: 19.4 s


In [12]:
torch.mean(torch.tensor([[ 0.1808, -0.0700],[-0.3596, -0.9152]]), 0)

tensor([-0.0894, -0.4926])

In [15]:
(-0.07-0.9152)/2

-0.49260000000000004

In [28]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
# a = a/torch.sum(a,1, keepdim = True)

In [29]:
a

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [36]:
torch.sum(a,1, keepdim = True); #row-wise sum of a, keep_dim retains the shape
torch.sum(a,0, keepdim=True); #col-wise sum of a
torch.sum(a,1, keepdim=True).shape, torch.sum(a,0, keepdim=True).shape,torch.sum(a,1, keepdim=False).shape, torch.sum(a,0, keepdim=False).shape

(torch.Size([3, 1]), torch.Size([1, 3]), torch.Size([3]), torch.Size([3]))

In [47]:
a/torch.sum(a,0, keepdim=True); #col-wise average of a
a/torch.sum(a,1, keepdim=True) #row-wise average of a

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [102]:
%%timeit
wei = torch.tril(torch.ones(T,T))
wei = wei/wei.sum(1, keepdim=True)
xbow2 = wei @ x # (T,T) is converted to (B,T,T) internally by pytorch (broadcasting) @ (B,T,C) ----> Batched matrix multiply

98.6 ms ± 1.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [105]:
torch.allclose(xbow,xbow2)

False

In [103]:
%%timeit
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=1)
xbow3 = wei @ x

99.6 ms ± 2.26 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [109]:
torch.allclose(xbow,xbow3, atol=1e-06)

True

In [107]:
xbow[0][0]

tensor([ 0.1808, -0.0700, -0.3596, -0.9152,  0.6258,  0.0255,  0.9545,  0.0643,
         0.3612,  1.1679, -1.3499, -0.5102,  0.2360, -0.2398, -0.9211,  1.5433,
         1.3488, -0.1396,  0.2858,  0.9651])

In [108]:
xbow2[0][0]

tensor([ 0.1808, -0.0700, -0.3596, -0.9152,  0.6258,  0.0255,  0.9545,  0.0643,
         0.3612,  1.1679, -1.3499, -0.5102,  0.2360, -0.2398, -0.9211,  1.5433,
         1.3488, -0.1396,  0.2858,  0.9651])

In [176]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.pos_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
    
    def forward(self, idx, targets=None):
        B,T = idx.shape
        #idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.pos_embedding_table(torch.arange(T)) #(T,C)
        x = tok_emb+pos_emb # (B,T,C)+(T,C) --> (B,T,C) (broadcasting)
        x = self.sa_head(x)
        logits = self.lm_head(x) # (B,T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T,C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        #idx is a (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            #crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            #get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B,C)
            probs = F.softmax(logits, dim=-1) # (B,C)
            idx_next = torch.multinomial(probs, num_samples=1) #(B,1)
            idx = torch.cat((idx, idx_next), dim=1)
        
        return idx

In [172]:
#how does a single head perform self attention
head_size = 16
key = nn.Linear(C,head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B,T,16)
q = query(x) # (B,T,16)
wei = q @ k.transpose(1,2)# (B,T,16) @ (B,16,T) ---> (B,T,T)

tril = torch.tril(torch.ones(T,T))
# wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v

out.shape

torch.Size([4, 8, 16])

In [175]:
class Head(nn.Module):
    """one head of self-attention"""
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd,head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        
    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        w = q @ k.transpose(-2,-1) * C**-0.5
        w = w.masked_fill(self.tril==0, float('-inf'))
        w = F.softmax(w, dim=-1)
        v = self.value(x)
        out = w@v
        return out
    


In [153]:
k.transpose(-2,-1).shape

torch.Size([4, 16, 8])

In [156]:
torch.allclose(k.transpose(-1,1), k.transpose(1,2))

True

In [160]:
torch.transpose(torch.randn(4,8,10,2), 0,2).shape

torch.Size([10, 8, 4, 2])

In [158]:
torch.transpose(torch.randn(4,8,10,2), -2,-1).shape

torch.Size([4, 8, 2, 10])