In [13]:
# Version 1 (explicit loops): the simplest way to use context is to average each
# token with all the tokens before it (a prefix mean, current token included).
# This is called bow (bag of words): the mean ignores the order within the history.
# We want xbow[b,t] = mean_{i<=t} x[b,i]
torch.manual_seed(1337)
B,T,C = 1,2,3 # batch, time, channels (feature dimension of each token)
x = torch.randn(B,T,C)
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t+1,C): tokens 0..t inclusive
        xbow[b,t] = torch.mean(xprev, 0) # average over the time axis -> (C,)
print('x=')
print(x)
print('--')
print('xbow=')
print(xbow)


x=
tensor([[[-2.0260, -2.0655, -1.2054],
         [-0.9122, -1.2502,  0.8032]]])
--
xbow=
tensor([[[-2.0260, -2.0655, -1.2054],
         [-1.4691, -1.6579, -0.2011]]])


In [14]:
# version 2: using matrix multiply for a weighted aggregation, faster and recommended
# Row t of the lower-triangular ones matrix has ones in columns 0..t; dividing
# each row by its sum turns it into uniform averaging weights over that prefix.
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=1, keepdim=True)
print("wei=")
print(wei)
# wei is (T, T); matmul broadcasts it over the batch dimension:
# (T, T) @ (B, T, C) ----> (B, T, C)
xbow2 = torch.matmul(wei, x)
torch.allclose(xbow, xbow2)

wei=
tensor([[1.0000, 0.0000],
        [0.5000, 0.5000]])


True

In [17]:
# version 3: use Softmax
import torch.nn.functional as F
# Start from all-zero affinities and block out future positions with -inf.
# After a row-wise softmax, exp(-inf) = 0 and the remaining equal logits share
# the probability mass uniformly, giving the same prefix average as version 2.
tril = torch.tril(torch.ones(T, T))
is_future = tril == 0
wei = torch.zeros((T, T)).masked_fill(is_future, float('-inf'))
wei = F.softmax(wei, dim=-1)
print("wei=")
print(wei)
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow, xbow3)

wei=
tensor([[1.0000, 0.0000],
        [0.5000, 0.5000]])


True

In [18]:
# Question: What are the disadvantages and advantages of bow?
# 1. Why should all the tokens in the history be equally important?
# 2. Might we want a weighted combination instead, e.g. weighting closer tokens more than distant ones?
# 3. Consider an example of coreference resolution: "Bruno is my pet dog. He is a labrador."
# Here "He" should resolve to "Bruno" (or "dog"), so those tokens in the history must be weighted (attended to) more than the others.
# 4. Why should the weighting be the same for every sequence in the batch, i.e. independent of the actual token values?
