In this notebook, we investigate how a transformer layer reacts to permuted input, i.e. to changing the order of the words in a sentence. We keep the model dimensions small so that the output can be inspected visually.

In [1]:
import torch
torch.manual_seed(123)

<torch._C.Generator at 0x7f9832302930>

In [2]:
# Model dimension
D = 8
# Length of encoder input
L = 4
#
# Create random input
#
X = torch.randn(L, D)
#
# and feed through an encoder block
#
block = torch.nn.TransformerEncoderLayer(d_model = D, nhead = 4, dropout = 0)
Y = block(X).detach()
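A note on shapes: TransformerEncoderLayer defaults to batch_first=False, so a batched input would have shape (L, N, D). Recent PyTorch versions also accept a 2-D tensor such as X as a single unbatched sequence. The following sketch uses its own seed and a freshly initialised layer (so its numbers differ from the cells here) and checks that the unbatched call agrees with an explicit batch of size one:

```python
import torch

torch.manual_seed(0)
L, D = 4, 8
X = torch.randn(L, D)
# dropout=0 as in the cell above; batch_first is left at its default (False)
block = torch.nn.TransformerEncoderLayer(d_model=D, nhead=4, dropout=0)

with torch.no_grad():
    y_unbatched = block(X)              # 2-D input (L, D): one unbatched sequence
    y_batched = block(X.unsqueeze(1))   # 3-D input (L, N=1, D): sequence dimension first

print(y_unbatched.shape, y_batched.shape)
print(torch.allclose(y_unbatched, y_batched.squeeze(1), atol=1e-6))
```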

In [3]:
#
# Now permute the input (here: reverse the token order) and recompute
#
Xp = X.flip(dims = [0])
Yp = block(Xp).detach()
#
# Verify that Yp is simply the permutation of Y
#
print(f"{torch.allclose(Yp, Y.flip(dims = [0]))}")

True
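Reversing the sequence is only one particular permutation. As a sketch (again with its own seed and a fresh layer), the same check can be run with an arbitrary reordering drawn with torch.randperm, written also as a permutation matrix P so that reordering the rows of X is the same as computing P @ X:

```python
import torch

torch.manual_seed(0)
L, D = 4, 8
X = torch.randn(L, D)
block = torch.nn.TransformerEncoderLayer(d_model=D, nhead=4, dropout=0)
block.eval()

perm = torch.randperm(L)          # an arbitrary reordering of the L positions
P = torch.eye(L)[perm]            # the same reordering as a permutation matrix
# Indexing the rows with perm is the same as left-multiplying by P
assert torch.allclose(P @ X, X[perm])

with torch.no_grad():
    Y = block(X)
    Y_perm = block(X[perm])

# Equivariance: encoding the reordered input equals reordering the encoding
print(torch.allclose(Y_perm, P @ Y, atol=1e-6))
```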


So we have confirmed that permuting the input of an encoder layer simply applies the same permutation to its output. Next, we simulate feeding this encoder output into the cross-attention layer of a decoder: the keys and values are the encoder output Y (respectively Yp), while the queries come from the decoder input and are not permuted.
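Why do both observations hold? A short derivation, sketched for a single attention head: write the permutation as a permutation matrix P (so that PX reorders the rows of X), let W_Q, W_K, W_V be the projection matrices and d the head dimension, and let M denote the encoder output and Q the unpermuted decoder queries. Bias terms and the output projection are applied row by row and are omitted here. Since the softmax acts on each row separately, softmax(P A P^T) = P softmax(A) P^T and softmax(A P^T) = softmax(A) P^T; together with P^T P = I this gives equivariance for self-attention and invariance for cross-attention:

```latex
\begin{aligned}
\operatorname{Attn}(PX)
  &= \operatorname{softmax}\!\left(\frac{PXW_Q\,(PXW_K)^\top}{\sqrt{d}}\right) PXW_V
   = P\,\operatorname{softmax}\!\left(\frac{XW_Q W_K^\top X^\top}{\sqrt{d}}\right) P^\top P\,XW_V
   = P\,\operatorname{Attn}(X), \\
\operatorname{Cross}(Q, PM)
  &= \operatorname{softmax}\!\left(\frac{QW_Q\,(PMW_K)^\top}{\sqrt{d}}\right) PMW_V
   = \operatorname{softmax}\!\left(\frac{QW_Q W_K^\top M^\top}{\sqrt{d}}\right) P^\top P\,MW_V
   = \operatorname{Cross}(Q, M).
\end{aligned}
```

The residual connections, LayerNorm and the position-wise feed-forward network also act on each row independently, so they commute with P and the whole encoder layer satisfies block(PX) = P block(X).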

In [4]:
# length of target sequence, i.e. decoder input
T = 3 
queries = torch.randn(T, D)

In [5]:
attn = torch.nn.MultiheadAttention(embed_dim = D, num_heads = 4)
#
# Put into eval mode so that no dropout is applied
# (MultiheadAttention's default dropout is already 0.0, so this is a safeguard)
#
attn.eval()

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
)

In [6]:
#
# Values and keys are both the encoder output
#
out = attn(queries, Y, Y)[0].detach()
outp = attn(queries, Yp, Yp)[0].detach()
#
# Compare
#
print(f"{torch.allclose(out, outp)}")

True
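To see the mechanism without learned weights, here is a stripped-down sketch: single-head scaled dot-product attention without the W_Q, W_K, W_V, W_O projections that nn.MultiheadAttention applies (a simplification for illustration; the helper name cross_attention is hypothetical). Permuting keys and values together only reorders the terms of the weighted sum over source positions, whereas permuting the keys alone pairs each attention weight with the wrong value:

```python
import torch

torch.manual_seed(0)
T, L, D = 3, 4, 8

def cross_attention(queries, keys, values):
    # Hypothetical helper: single-head scaled dot-product attention WITHOUT
    # the learned projections of nn.MultiheadAttention.
    d = queries.shape[-1]
    scores = queries @ keys.T / d ** 0.5       # (T, L): one score per query/source position
    weights = torch.softmax(scores, dim=-1)    # normalise over the L source positions
    return weights @ values                    # (T, D): weighted sum of the values

queries = torch.randn(T, D)
memory = torch.randn(L, D)                     # stands in for the encoder output Y
perm = torch.tensor([1, 2, 3, 0])              # a fixed cyclic shift of the source positions

out = cross_attention(queries, memory, memory)
out_joint = cross_attention(queries, memory[perm], memory[perm])   # keys and values permuted together
out_keys_only = cross_attention(queries, memory[perm], memory)     # only the keys permuted

joint_ok = torch.allclose(out, out_joint, atol=1e-6)
keys_only_ok = torch.allclose(out, out_keys_only, atol=1e-6)
print(joint_ok, keys_only_ok)
```

The explicit atol=1e-6 is deliberate: the reordered sum can differ in the last bits, and with the default atol=1e-8 an entry close to zero could fail the comparison even though nothing but rounding changed.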


Thus the attention layer sitting between the encoder and the decoder is invariant under permutations of its keys and values: permuting the encoder output does not change the output of the attention layer. Taken together, these two observations imply that permuting the input of an encoder-decoder combination, i.e. the source sentence, does not change the output of the model (up to floating-point rounding). Since torch.nn.Transformer does not add positional encodings to its inputs, the model is therefore insensitive to the order of the source tokens. We can also check this directly with a full transformer.

In [17]:
transformer = torch.nn.Transformer(d_model = D, nhead = 1)
transformer.eval()
tgt = torch.randn(T, D)
src = torch.randn(L, D)
src_permuted = src.flip(dims = [0])
out = transformer(src, tgt)
out_permuted = transformer(src_permuted, tgt)
print(out)
print(out_permuted)
print(f"{torch.allclose(out, out_permuted)}")

tensor([[ 1.3313,  0.3166, -1.6677, -0.7317,  0.4747, -1.0081,  0.0325,  1.2522],
        [ 1.2092,  0.2136, -2.0010,  0.2608,  0.4414, -0.7903, -0.4971,  1.1634],
        [ 1.0899, -0.1775, -2.1671, -0.3654,  0.9013, -0.4634,  0.2556,  0.9265]],
       grad_fn=<NativeLayerNormBackward0>)
tensor([[ 1.3313,  0.3166, -1.6677, -0.7317,  0.4747, -1.0081,  0.0325,  1.2522],
        [ 1.2092,  0.2136, -2.0010,  0.2608,  0.4414, -0.7903, -0.4971,  1.1634],
        [ 1.0899, -0.1775, -2.1671, -0.3654,  0.9013, -0.4634,  0.2556,  0.9265]],
       grad_fn=<NativeLayerNormBackward0>)
False
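The two outputs agree in every printed digit, so the False most likely reflects floating-point rounding (the sums over source positions inside attention are evaluated in a different order) rather than a genuine dependence on token order: torch.allclose defaults to rtol=1e-5 and atol=1e-8, which is very strict after twelve stacked layers. A sketch that separates the two effects: it reports the largest absolute difference and compares with an explicit, arbitrarily chosen atol=1e-4, then adds a sinusoidal positional encoding to the source to show that order does matter once positions are encoded. The helper sinusoidal_encoding is hypothetical (the sine/cosine scheme of the original Transformer paper), and the seed is arbitrary:

```python
import torch

torch.manual_seed(0)
L, T, D = 4, 3, 8

def sinusoidal_encoding(length, dim):
    # Hypothetical helper: sine/cosine positional encoding (dim must be even).
    pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)          # (length, 1)
    freq = 10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # (dim / 2,)
    pe = torch.zeros(length, dim)
    pe[:, 0::2] = torch.sin(pos / freq)
    pe[:, 1::2] = torch.cos(pos / freq)
    return pe

transformer = torch.nn.Transformer(d_model=D, nhead=1)
transformer.eval()                 # disables the default dropout of 0.1
src = torch.randn(L, D)
tgt = torch.randn(T, D)
src_permuted = src.flip(dims=[0])

with torch.no_grad():
    # 1) No positional information: any difference is rounding error
    out = transformer(src, tgt)
    out_permuted = transformer(src_permuted, tgt)
    max_diff = (out - out_permuted).abs().max().item()
    close = torch.allclose(out, out_permuted, atol=1e-4)

    # 2) Same positions attached to differently ordered tokens: order now matters
    pe = sinusoidal_encoding(L, D)
    out_pe = transformer(src + pe, tgt)
    out_pe_permuted = transformer(src_permuted + pe, tgt)
    pe_max_diff = (out_pe - out_pe_permuted).abs().max().item()

print(f"without positional encoding: max diff {max_diff:.2e}, allclose(atol=1e-4) {close}")
print(f"with positional encoding:    max diff {pe_max_diff:.2e}")
```

Only the source receives positional information here, to isolate the effect on the encoder side; a real model would typically encode the target positions as well.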
