In [2]:
import torch
import torch.nn as nn
import torch.functional as F

In [8]:
# Project once with a single fc layer, then split its output features into three equal
# slices (one per head) instead of keeping a separate fc layer for every head.
tensor=torch.rand(1,8,9)
fc=nn.Linear(9,9) # Fully connected layer: 9 input features -> 9 output features
q=fc(tensor) # (1, 8, 9)
chunk1,chunk2,chunk3=q.chunk(3,dim=-1) # three (1, 8, 3) slices of the last dimension
print(chunk1)

tensor([[[-0.4768, -0.0741,  0.0344],
         [-0.3951, -0.2308, -0.1386],
         [-0.6783, -0.1651,  0.1059],
         [-0.6053,  0.0069,  0.0992],
         [-0.2338, -0.0089,  0.0371],
         [-0.2259, -0.2235, -0.1448],
         [-0.3738, -0.1831, -0.3939],
         [-0.0658, -0.4086, -0.1901]]], grad_fn=<SplitBackward0>)


In [11]:
# Batched matrix multiplication: torch.matmul treats every dim except the last two as a
# batch dim, so each (6, 4) @ (4, 3) slice yields (6, 3) and the (1, 2) prefix is kept.
a=torch.rand(1,2,6,4)
b=torch.rand(1,2,4,3)
c=torch.matmul(a,b) # (1, 2, 6, 3)
print(c.shape)

torch.Size([1, 2, 6, 3])


In [24]:
# Attention masking (single head): padded *key* positions are set to -inf before the
# softmax, so every query row gives them zero attention weight.
rand_attn=torch.rand(1,6,6) # (batch_size, seq_len, seq_len)
attention_mask=torch.tensor([1,1,1,1,0,0]).unsqueeze(0).bool() # (batch_size, seq_len); False = padding
print(attention_mask.shape)

# Method 1: auto broadcast by PyTorch.
# A (batch_size, 1, seq_len) mask broadcasts over the query rows. The out-of-place
# masked_fill (not masked_fill_) leaves rand_attn untouched, so Method 2 below starts
# from the same raw scores and the two results can be compared.
rand_attn_broadcast=rand_attn.masked_fill(~attention_mask.unsqueeze(1),float('-inf')).softmax(dim=-1)

# Method 2: Using repeat(). We will need this when we implement Flash Attention
# Materialize the full (batch_size, seq_len, seq_len) mask. The repeat count is the
# number of query rows, read from the scores rather than hard-coded.
attention_mask=attention_mask.unsqueeze(1)
print(attention_mask.shape)
attention_mask=attention_mask.repeat(1,rand_attn.size(-2),1)
print(attention_mask.shape)
rand_attn=rand_attn.masked_fill(~attention_mask,float('-inf'))
print(rand_attn)
# NOTE: a row whose keys were ALL masked would be all -inf and softmax to NaN;
# here 4 of the 6 keys stay visible, so every row is well defined.
rand_attn=rand_attn.softmax(dim=-1)
print(rand_attn)

# Both methods must produce the same attention weights.
assert torch.allclose(rand_attn,rand_attn_broadcast)

torch.Size([1, 6])
torch.Size([1, 1, 6])
torch.Size([1, 6, 6])
tensor([[[0.5335, 0.8167, 0.0448, 0.4094,   -inf,   -inf],
         [0.4415, 0.3949, 0.9173, 0.8123,   -inf,   -inf],
         [0.4609, 0.3791, 0.9816, 0.3266,   -inf,   -inf],
         [0.1590, 0.7717, 0.4184, 0.8646,   -inf,   -inf],
         [0.7205, 0.9850, 0.5376, 0.5326,   -inf,   -inf],
         [0.9981, 0.3274, 0.3654, 0.7623,   -inf,   -inf]]])
tensor([[[0.2615, 0.3471, 0.1604, 0.2310, 0.0000, 0.0000],
         [0.1995, 0.1904, 0.3210, 0.2890, 0.0000, 0.0000],
         [0.2233, 0.2057, 0.3758, 0.1952, 0.0000, 0.0000],
         [0.1622, 0.2993, 0.2102, 0.3284, 0.0000, 0.0000],
         [0.2522, 0.3286, 0.2101, 0.2090, 0.0000, 0.0000],
         [0.3530, 0.1805, 0.1875, 0.2789, 0.0000, 0.0000]]])


In [18]:
# Attention masking with multihead: the scores carry an extra head axis (size 2 here),
# so the (seq_len,) padding mask gets singleton batch, head and query axes and then
# broadcasts over all of them.
rand_attn=torch.rand(1,2,6,6)
print(rand_attn.shape)
attention_mask=torch.tensor([1,1,1,1,0,0])[None,None,None,:].bool() # (1, 1, 1, seq_len)
print(attention_mask.shape)
rand_attn=rand_attn.masked_fill(attention_mask.logical_not(),float('-inf'))
print(rand_attn)

torch.Size([1, 2, 6, 6])
torch.Size([1, 1, 1, 6])
tensor([[[[0.2767, 0.1329, 0.3028, 0.3019,   -inf,   -inf],
          [0.8389, 0.3400, 0.3342, 0.1790,   -inf,   -inf],
          [0.5019, 0.6945, 0.4680, 0.2863,   -inf,   -inf],
          [0.7435, 0.3661, 0.7698, 0.0816,   -inf,   -inf],
          [0.0601, 0.2670, 0.0176, 0.7797,   -inf,   -inf],
          [0.0925, 0.4264, 0.4410, 0.6303,   -inf,   -inf]],

         [[0.8490, 0.4955, 0.4319, 0.1780,   -inf,   -inf],
          [0.8582, 0.0552, 0.7631, 0.2467,   -inf,   -inf],
          [0.7036, 0.0575, 0.7861, 0.8282,   -inf,   -inf],
          [0.5695, 0.5102, 0.0829, 0.0605,   -inf,   -inf],
          [0.4161, 0.6617, 0.1674, 0.7407,   -inf,   -inf],
          [0.7881, 0.8110, 0.0934, 0.9591,   -inf,   -inf]]]])
