In [1]:
import torch
def randomize(*l,scale=10): # randomize the content of each tensor in l
  for x in l: x[...] = scale*torch.rand(1)*torch.rand(x.shape)
def div(u,v): return torch.amax(abs(v-u)).item() # divergence between two tensors: maximum absolute elementwise difference

## Overparametrisation of torch.nn.MultiheadAttention

This notebook shows that the PyTorch implementation of [MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) is overparametrised in its handling of biases.

The biases of the query, key and value input projections are packed into the single attribute `in_proj_bias`. It is present by default and can be disabled by passing `bias=False` to the constructor.
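A small sketch of how the packed bias is laid out. It assumes the default case `kdim == vdim == embed_dim` and relies on the fact that PyTorch splits the vector into query, key and value parts in that order. Note that `bias=False` also removes the bias of the output projection `out_proj`.

```python
import torch

a = torch.nn.MultiheadAttention(embed_dim=12, num_heads=4, batch_first=True)
# query, key and value biases are packed in one vector of length 3*embed_dim
print(a.in_proj_bias.shape)
b_q, b_k, b_v = a.in_proj_bias.chunk(3)  # each of length embed_dim, in this order
print(b_q.shape, b_k.shape, b_v.shape)

# with bias=False the packed input bias is absent, and so is the output projection bias
a_nb = torch.nn.MultiheadAttention(embed_dim=12, num_heads=4, bias=False, batch_first=True)
print(a_nb.in_proj_bias, a_nb.out_proj.bias)
```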

The experiment below re-randomizes each of the three biases in turn. It shows that modifying the key bias does not change the result, whereas changing either of the other two biases does.

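This follows from the attention formula, which the documentation does not state explicitly and so has to be inferred. The key bias adds to all the attention scores of a given query the same term: the (scaled) dot product of that query with the key bias. This term depends on the query but not on the key, and softmax is invariant to adding the same constant to all its inputs, so the term cancels.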
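The argument can be written out for a single head. The notation is introduced here for illustration only and does not come from the PyTorch documentation:

* $y'_i$ is a query input and $x'_j$ is a key input.
* $W_q, b_q, W_k, b_k$ are the slices of the projection weights and biases used by that head.
* $d$ is the head dimension.

$$
\begin{aligned}
q_i &= W_q\, y'_i + b_q, \qquad k_j = W_k\, x'_j + b_k,\\
s_{ij} &= \frac{q_i^\top k_j}{\sqrt{d}} = \frac{q_i^\top W_k\, x'_j}{\sqrt{d}} + c_i, \qquad c_i = \frac{q_i^\top b_k}{\sqrt{d}},\\
\alpha_{ij} &= \frac{e^{s_{ij}}}{\sum_{j'} e^{s_{ij'}}} = \frac{e^{c_i}\, e^{s_{ij}-c_i}}{e^{c_i} \sum_{j'} e^{s_{ij'}-c_i}} = \operatorname{softmax}_j\!\left(\frac{q_i^\top W_k\, x'_j}{\sqrt{d}}\right).
\end{aligned}
$$

Since $c_i$ depends on the query $i$ but not on the key $j$, it cancels in the softmax. The attention weights $\alpha_{ij}$, and hence the output, therefore do not depend on $b_k$.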

The practical consequence is that every training iteration spends some time
* in the forward pass, computing with a parameter that does not change the result, and
* in the backward pass, propagating gradients that are zero (up to rounding) in order to update it.

This overhead is probably negligible compared with all the other forward and backward computations, but the parameter could still be removed.
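A minimal sketch to observe the null gradient. It uses small illustrative sizes, non-zero biases (the default initialisation sets them to zero) and an arbitrary scalar loss `out.sum()`. The key-bias slice of `in_proj_bias.grad` should come out at rounding level, while the value-bias slice should not.

```python
import torch

torch.manual_seed(0)
E = 12
mha = torch.nn.MultiheadAttention(embed_dim=E, num_heads=4, batch_first=True)
with torch.no_grad():
    mha.in_proj_bias.uniform_(-1.0, 1.0)  # default init is zero; use non-zero biases

query = torch.rand(2, 5, E)
key = torch.rand(2, 7, E)
value = torch.rand(2, 7, E)
out, _ = mha(query, key, value)
out.sum().backward()  # any scalar loss would do

g_q, g_k, g_v = mha.in_proj_bias.grad.chunk(3)  # gradients of the query, key, value biases
print(f"max |grad| query bias: {g_q.abs().max().item():.2g}")
print(f"max |grad| key bias  : {g_k.abs().max().item():.2g}")  # zero up to rounding
print(f"max |grad| value bias: {g_v.abs().max().item():.2g}")
```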

In [2]:
# Create a torch.nn.MultiheadAttention instance and randomize all its parameters
a = torch.nn.MultiheadAttention(embed_dim=12,num_heads=4,batch_first=True)
#a = torch.nn.MultiheadAttention(embed_dim=12,num_heads=4,batch_first=True,kdim=17,vdim=19) # variant with key and value dimensions
for p in a.parameters(): randomize(p.data)
# Sample some input
B = 64; M = 100; N = 80
yʹ = torch.rand(B,N,a.embed_dim) # query input
xʹ = torch.rand(B,M,a.kdim) # key input
x = torch.rand(B,M,a.vdim) # value input
outs = (B,N,a.embed_dim) # shape of the output
print(f'Output shape: {outs}')
with torch.no_grad():
  # iterate over the three biases (query,key,value)
  for bias_name,bias in zip(('query','key','value'),torch.chunk(a.in_proj_bias.data,3)):
    # First randomize the selected bias and compute the output.
    randomize(bias); y1,_ = a(yʹ,xʹ,x)
    # Now randomize the same bias again and compute the output.
    randomize(bias); y2,_ = a(yʹ,xʹ,x)
    assert y1.shape == outs and y2.shape == outs # sanity check
    # compare the two outputs
    print(bias_name.ljust(5),':',f'{div(y1,y2):.2g}')

Output shape: (64, 80, 12)
query : 20
key   : 0.00092
value : 6.3


In exact arithmetic the result for the key bias would be zero. The small non-zero value is due to floating-point rounding.
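One way to check that this residual is a rounding effect is to repeat the experiment in single and in double precision. The sketch below does this with smaller, arbitrary sizes; `bias_divergence` is a helper introduced here. In double precision, the key-bias divergence should shrink by many orders of magnitude, while the value-bias divergence stays large.

```python
import torch

def bias_divergence(which, dtype, seed=0):
    """Max |y1 - y2| after re-randomizing one slice of in_proj_bias (0=query, 1=key, 2=value)."""
    torch.manual_seed(seed)
    mha = torch.nn.MultiheadAttention(embed_dim=12, num_heads=4, batch_first=True).to(dtype)
    query = torch.rand(2, 5, 12, dtype=dtype)
    key = torch.rand(2, 7, 12, dtype=dtype)
    value = torch.rand(2, 7, 12, dtype=dtype)
    bias = mha.in_proj_bias.data.chunk(3)[which]  # a view: in-place changes hit the module
    with torch.no_grad():
        bias.uniform_(-10.0, 10.0)
        y1, _ = mha(query, key, value)
        bias.uniform_(-10.0, 10.0)
        y2, _ = mha(query, key, value)
    return (y1 - y2).abs().max().item()

for dtype in (torch.float32, torch.float64):
    print(dtype, [f"{bias_divergence(w, dtype):.2g}" for w in range(3)])
```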

## Alternative implementation of multi-head attention

The [PYTOOLS implementation](https://github.com/jmandreoli/PYTOOLS/blob/master/src/torch.py) of multi-head attention has no key bias. It is also designed to be open to extensions of the attention mechanism, and one such extension (Mixed attention) is provided.
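To illustrate the idea only, here is a hypothetical minimal module, `NoKeyBiasMHA` (a name introduced here), that has no key bias. It is not the PYTOOLS code and has no extension mechanism. It assumes `kdim == vdim == embed_dim`, no masks and no dropout. It copies the weights of a `torch.nn.MultiheadAttention`, dropping the key bias, so that the two outputs can be compared.

```python
import math
import torch

class NoKeyBiasMHA(torch.nn.Module):
    """Hypothetical minimal multi-head attention without a key bias (batch_first layout)."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.q_proj = torch.nn.Linear(embed_dim, embed_dim)              # query bias kept
        self.k_proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)  # no key bias
        self.v_proj = torch.nn.Linear(embed_dim, embed_dim)              # value bias kept
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t):
        # (B, L, E) -> (B, H, L, E/H): head h uses the h-th block of E/H features
        B, L, E = t.shape
        return t.view(B, L, self.num_heads, E // self.num_heads).transpose(1, 2)

    def forward(self, query, key, value):
        q = self._split_heads(self.q_proj(query))
        k = self._split_heads(self.k_proj(key))
        v = self._split_heads(self.v_proj(value))
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # (B, H, N, M)
        heads = torch.softmax(scores, dim=-1) @ v                   # (B, H, N, E/H)
        B, H, N, D = heads.shape
        return self.out_proj(heads.transpose(1, 2).reshape(B, N, H * D))

    @classmethod
    def from_torch(cls, mha):
        """Copy the weights of a torch.nn.MultiheadAttention, discarding its key bias."""
        new = cls(mha.embed_dim, mha.num_heads)
        W_q, W_k, W_v = mha.in_proj_weight.detach().chunk(3)
        b_q, _, b_v = mha.in_proj_bias.detach().chunk(3)  # the key bias is dropped
        with torch.no_grad():
            new.q_proj.weight.copy_(W_q)
            new.q_proj.bias.copy_(b_q)
            new.k_proj.weight.copy_(W_k)
            new.v_proj.weight.copy_(W_v)
            new.v_proj.bias.copy_(b_v)
            new.out_proj.weight.copy_(mha.out_proj.weight)
            new.out_proj.bias.copy_(mha.out_proj.bias)
        return new


torch.manual_seed(0)
ref = torch.nn.MultiheadAttention(embed_dim=12, num_heads=4, batch_first=True)
with torch.no_grad():
    ref.in_proj_bias.uniform_(-1.0, 1.0)  # non-zero biases, key bias included
    ref.out_proj.bias.uniform_(-1.0, 1.0)
alt = NoKeyBiasMHA.from_torch(ref)

query, key, value = torch.rand(2, 5, 12), torch.rand(2, 7, 12), torch.rand(2, 7, 12)
y_ref, _ = ref(query, key, value)
y_alt = alt(query, key, value)
print(f"max |difference|: {(y_ref - y_alt).abs().max().item():.2g}")
```

Dropping the key bias removes `embed_dim` parameters. According to the argument above, this does not change what the layer can compute.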

The next notebook cell checks that the PYTOOLS implementation is equivalent to the PyTorch one (up to rounding errors), in both the forward and the backward pass.

In [3]:
# Reuse the torch.nn.MultiheadAttention instance `a` from above (or uncomment the next line to create a fresh one) and randomize all its parameters
#a = torch.nn.MultiheadAttention(embed_dim=12,num_heads=4,bias=True,batch_first=True)
for p in a.parameters(): randomize(p.data)
# Create the corresponding PYTOOLS implementation
from myutil.torch import MultiHeadAttention
a_ = MultiHeadAttention.torch_convert(a)
# Compare the two implementations on random inputs (forward outputs and backward gradients)
fwd,bwd = a_.torch_compare(a,B=64,M=100,N=80)
# display result
print('forward'.ljust(8),''.ljust(5),':',f'{div(*fwd):.2g}')
for head,(p,(u_,u)) in zip(('backward',*100*('',)),bwd.items()): print(head.ljust(8),p.ljust(5),':',f'{div(u_,u):.2g}')

forward        : 0.023
backward Λ[1]  : 1.8e-05
         Λ[0]  : 4.1e-05
         ϴ[0]  : 3.1e-06
         ϴ[1]  : 1.7e-06
         Λₒ    : 3e-05
         Θₒ[0] : 4.5e-06
         Θₒ[1] : 0


In exact arithmetic all these results would be zero. The small non-zero values are due to floating-point rounding.