# Gated transformer

<img src='Supplementary material/gated_transformer.png'>


In [1]:
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise MLP computing ``w_2(dropout(relu(w_1(x))))``.

    Dropout is applied only to the hidden activations (after the first layer
    and its ReLU); the output of the second layer is returned untouched.

    Args:
      d_model: Size of the input and output feature dimension.
      d_ff: Number of hidden units.
      dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

# Original Transformer Block 

This is the original (post-norm) Transformer block, with the LayerNorm applied after summing the input with the output of each sub-module. Notice that my implementation was missing the skip connection around the second sub-layer (the feed-forward network) as well!

In [3]:
class AttentionBlock(nn.Module):
    """Original post-norm Transformer encoder block.

    Each sub-layer is wrapped as ``Dropout(LayerNorm(input + sublayer(input)))``:

        h   = Dropout(LayerNorm(x + MHA(x, x, x)))
        out = Dropout(LayerNorm(h + FF(h)))

    The same ``LayerNorm`` module (``self.norm``) is shared by both sub-layers.
    """

    def __init__(self, n_features, n_heads, n_hidden=64, dropout=0.1):
        """
        Args:
          n_features: Number of input and output features (d_model).
          n_heads: Number of attention heads in the Multi-Head Attention.
          n_hidden: Number of hidden units in the feed-forward (MLP) block (d_ff).
          dropout: Dropout rate used inside the attention, after the first layer
              of the MLP, and after each of the two add & norm steps.
        """
        super(AttentionBlock, self).__init__()
        self.norm = nn.LayerNorm(n_features)
        self.dropout = nn.Dropout(dropout)
        # The third positional argument of nn.MultiheadAttention is its attention dropout.
        self.attn = nn.MultiheadAttention(n_features, n_heads, dropout)
        self.ff = PositionwiseFeedForward(n_features, n_hidden, dropout)

    def forward(self, x, mask=None):
        """
        Args:
          x of shape (seq_len, batch_size, n_features): Input sequences
              (seq_len = n_pixels**2 in the original use case; nn.MultiheadAttention
              is used with its default batch_first=False layout).
          mask of shape (batch_size, seq_len): Optional boolean key padding mask;
              True marks positions the attention should ignore.

        Returns:
          z of shape (seq_len, batch_size, n_features): Encoded input sequences.

        Note: All intermediate signals have shape (seq_len, batch_size, n_features).
        """
        attn_output, _ = self.attn(x, x, x, key_padding_mask=mask)  # MHA step
        h = self.dropout(self.norm(attn_output + x))  # add & norm
        z = self.ff(h)  # FF step
        # add & norm: the feed-forward sub-layer needs its residual connection too
        return self.dropout(self.norm(z + h))


In [35]:
# Hyper-parameters for the post-norm block demo.
n_features, n_heads = 32, 4

Tr = AttentionBlock(n_features=n_features, n_heads=n_heads)

In [37]:
# Smoke test of the post-norm AttentionBlock built in the previous cell:
# a constant input of shape (seq_len=10, batch_size=1, n_features).
x = torch.ones(10, 1, n_features)

y = Tr(x)

# Transformer Block with Identity Map Reordering 

This block applies the LayerNorm only to the input of each sub-module, leaving the residual signal untouched, so that the identity mapping can be easily represented.

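"Because the layer norm reordering causes a path where 2 linear layers are applied in sequence, we apply a ReLU activation to each sub-module output before the residual connection" (I am not 100% sure which path is meant, nor why a single ReLU does not suffice and two are required; not a big deal anyway). A likely reading: after the reordering, the last linear layer of one sub-module (the attention output projection, or the second layer of the MLP) feeds, through the residual stream and a LayerNorm (which is nearly affine), directly into the first linear layer of the next sub-module. That happens after both sub-modules, hence one ReLU per sub-module.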

In [10]:
class TransformerIdentityBlock(nn.Module):
    """Pre-norm Transformer block with Identity Map Reordering (IMR).

    Each sub-module receives a LayerNorm-ed view of the residual stream, and its
    ReLU-activated output is added back onto the un-normalised stream:

        h   = Dropout(x + ReLU(MHA(LN(x))))
        out = Dropout(h + ReLU(FF(LN(h))))

    A single LayerNorm module (``self.norm``) is reused for both sub-modules.
    """

    def __init__(self, n_features, n_heads, n_hidden=64, dropout=0.1):
        """
        Args:
          n_features: Size of the input/output feature dimension (d_model).
          n_heads: How many heads the Multi-Head Attention uses.
          n_hidden: Width of the hidden layer of the feed-forward MLP (d_ff).
          dropout: Dropout probability used inside the attention, after the first
              MLP layer, and after each of the two residual sums.
        """
        super().__init__()
        self.norm = nn.LayerNorm(n_features)
        self.dropout = nn.Dropout(dropout)
        self.attn = nn.MultiheadAttention(n_features, n_heads, dropout)
        self.ff = PositionwiseFeedForward(n_features, n_hidden, dropout)

    def forward(self, x, mask=None):
        """
        Args:
          x of shape (seq_len, batch_size, n_features): Input sequences
              (seq_len = n_pixels**2 in the original use case).
          mask of shape (batch_size, seq_len): Optional boolean key padding mask;
              True entries are ignored by the attention.

        Returns:
          Tensor of shape (seq_len, batch_size, n_features): Encoded sequences.
        """
        stream = x

        # Attention sub-module: normalise only the branch input, keep the stream raw.
        normed = self.norm(stream)
        attended = self.attn(normed, normed, normed, key_padding_mask=mask)[0]
        stream = self.dropout(stream + F.relu(attended))

        # Feed-forward sub-module, same pattern.
        normed = self.norm(stream)
        update = F.relu(self.ff(normed))
        return self.dropout(stream + update)

In [39]:
# Hyper-parameters for the identity-map-reordered block demo.
n_features, n_heads = 32, 4

TrI = TransformerIdentityBlock(n_features=n_features, n_heads=n_heads)

In [40]:
# Forward a constant (seq_len=10, batch_size=1, n_features) input through the IMR block.
x = torch.ones((10, 1, n_features))
y = TrI(x)

# Gated Transformer with IMR 


<img src='Supplementary material/GRU_gating.png'>

In [27]:
debug = True  # when True, GRU_gating.forward prints the shape of every intermediate tensor


class GRU_gating(nn.Module):
    """GRU-style gating layer that replaces a residual connection (as in GTrXL).

    Given the residual stream ``x`` and a sub-module output ``y`` (same shape,
    last dimension ``n_features``) it computes

        r = sigmoid(W_r [x, y])
        z = sigmoid(W_z [x, y] + b_z)
        h = tanh(W_g [r * x, y])
        g = (1 - z) * x + z * h

    where ``[a, b]`` is concatenation along the last dimension.

    Args:
      n_features: Size of the last dimension of ``x`` and ``y``.
      bg: Optional gate bias b_g. When given, ``W_z``'s bias is initialised to
          ``-bg``, i.e. z = sigmoid(W_z [x, y] - b_g). A positive value biases z
          towards 0 at initialisation, so the gate starts close to the identity
          g ~ x. ``None`` (default) keeps PyTorch's default initialisation.
    """

    def __init__(self, n_features, bg=None):
        super(GRU_gating, self).__init__()
        self.Wr = nn.Linear(n_features*2, n_features, bias=False)
        self.Wz = nn.Linear(n_features*2, n_features, bias=True)
        self.Wg = nn.Linear(n_features*2, n_features, bias=False)
        if bg is not None:
            # b_z = -b_g: with b_g > 0 the update gate z starts small (identity-like gate).
            nn.init.constant_(self.Wz.bias, -float(bg))

    def forward(self, x, y):
        """Gate the residual ``x`` with the sub-module output ``y``; returns a tensor shaped like ``x``."""
        xy = torch.cat([x, y], dim=-1)
        if debug: print("xy.shape: ", xy.shape)

        r = torch.sigmoid(self.Wr(xy))  # reset gate
        if debug: print("r.shape: ", r.shape)

        z = torch.sigmoid(self.Wz(xy))  # update gate
        if debug: print("z.shape: ", z.shape)

        rx = r*x
        if debug: print("rx.shape: ", rx.shape)

        h = torch.tanh(self.Wg(torch.cat([rx, y], dim=-1)))  # candidate activation
        if debug: print("h.shape: ", h.shape)

        g = (1-z)*x + z*h  # interpolate between the residual and the candidate
        if debug: print("g.shape: ", g.shape)

        return g

In [29]:
n_features = 10
GRU_gate = GRU_gating(n_features)

# Constant residual stream x and random sub-module output y, both (1, n_features).
x = torch.ones((1, n_features))
print("x.shape: ", x.shape)

y = torch.rand((1, n_features))
print("y.shape: ", y.shape)

g = GRU_gate(x, y)
g

x.shape:  torch.Size([1, 10])
y.shape:  torch.Size([1, 10])
xy.shape:  torch.Size([1, 20])
r.shape:  torch.Size([1, 10])
z.shape:  torch.Size([1, 10])
rx.shape:  torch.Size([1, 10])
h.shape:  torch.Size([1, 10])
g.shape:  torch.Size([1, 10])


tensor([[0.6947, 0.6898, 0.5084, 0.3647, 0.5615, 0.4226, 0.2960, 0.5226, 0.7250,
         0.2281]], grad_fn=<AddBackward0>)

In [30]:
class GatedTransformerBlock(nn.Module):
    """Gated Transformer block with Identity Map Reordering (GTrXL-style).

    The residual connections of the pre-norm (IMR) block are replaced by
    GRU-style gating layers. As in GTrXL, each sub-module output goes through a
    ReLU before entering its gate:

        h   = Dropout(gate1(x, ReLU(MHA(LN(x)))))
        out = Dropout(gate2(h, ReLU(FF(LN(h)))))

    A single LayerNorm module (``self.norm``) is shared by both sub-modules.
    """

    def __init__(self, n_features, n_heads, n_hidden=64, dropout=0.1):
        """
        Args:
          n_features: Number of input and output features (d_model).
          n_heads: Number of attention heads in the Multi-Head Attention.
          n_hidden: Number of hidden units in the feed-forward (MLP) block (d_ff).
          dropout: Dropout rate used inside the attention, after the first layer
              of the MLP, and after each of the two gating layers.
        """
        super(GatedTransformerBlock, self).__init__()
        self.norm = nn.LayerNorm(n_features)
        self.dropout = nn.Dropout(dropout)
        self.attn = nn.MultiheadAttention(n_features, n_heads, dropout)
        self.GRU_gate1 = GRU_gating(n_features)
        self.ff = PositionwiseFeedForward(n_features, n_hidden, dropout)
        self.GRU_gate2 = GRU_gating(n_features)

    def forward(self, x, mask=None):
        """
        Args:
          x of shape (seq_len, batch_size, n_features): Input sequences
              (seq_len = n_pixels**2 in the original use case).
          mask of shape (batch_size, seq_len): Optional boolean key padding mask;
              True marks positions the attention should ignore.

        Returns:
          z of shape (seq_len, batch_size, n_features): Encoded input sequences.

        Note: All intermediate signals have shape (seq_len, batch_size, n_features).
        """
        # First submodule: LayerNorm only on the branch input, gate instead of skip
        x_norm = self.norm(x)
        attn_output, _ = self.attn(x_norm, x_norm, x_norm, key_padding_mask=mask)  # MHA step
        # ReLU before gating (GTrXL), matching the ReLU used by the IMR block
        x = self.dropout(self.GRU_gate1(x, F.relu(attn_output)))

        # Second submodule
        x_norm = self.norm(x)
        z = F.relu(self.ff(x_norm))  # FF step, ReLU before gating
        return self.dropout(self.GRU_gate2(x, z))

In [41]:
# Hyper-parameters for the gated block demo.
n_features, n_heads = 32, 4

GTr = GatedTransformerBlock(n_features=n_features, n_heads=n_heads)

In [43]:
# Run the gated block on a constant (seq_len=10, batch_size=1, n_features) input
# and display the result.
x = torch.ones((10, 1, n_features))
y = GTr(x)
y

xy.shape:  torch.Size([10, 1, 64])
r.shape:  torch.Size([10, 1, 32])
z.shape:  torch.Size([10, 1, 32])
rx.shape:  torch.Size([10, 1, 32])
h.shape:  torch.Size([10, 1, 32])
g.shape:  torch.Size([10, 1, 32])
xy.shape:  torch.Size([10, 1, 64])
r.shape:  torch.Size([10, 1, 32])
z.shape:  torch.Size([10, 1, 32])
rx.shape:  torch.Size([10, 1, 32])
h.shape:  torch.Size([10, 1, 32])
g.shape:  torch.Size([10, 1, 32])


tensor([[[ 3.2434e-01,  1.9966e-02,  2.5949e-01,  7.8418e-02,  3.7675e-01,
           7.6138e-02,  4.2284e-01,  0.0000e+00,  4.6148e-01,  3.7017e-01,
           0.0000e+00,  2.4634e-01,  4.0325e-01,  2.3206e-01,  3.2800e-01,
           2.6801e-01,  3.6624e-01,  3.2993e-01,  3.5932e-01,  0.0000e+00,
           2.7202e-01,  2.0279e-01,  2.0578e-01,  2.8484e-01,  3.7115e-02,
           3.7038e-01,  2.2633e-01,  4.1369e-01,  2.7697e-01,  1.1816e-01,
           4.0359e-01,  2.6289e-01]],

        [[ 3.6654e-01,  1.1305e-01,  0.0000e+00,  1.8434e-01,  4.7564e-01,
           8.2172e-02,  4.4056e-01,  3.9164e-01,  0.0000e+00,  3.8793e-01,
           4.6810e-01,  2.1911e-01,  4.4695e-01,  1.3583e-01,  3.7940e-01,
           2.2348e-01,  4.3927e-01,  1.9926e-01,  2.9011e-01,  3.1536e-01,
           1.8956e-01,  2.2109e-01,  1.4870e-01,  3.7345e-01,  3.3986e-01,
           3.4751e-01,  3.5846e-01,  3.6325e-01,  3.1484e-01, -4.1760e-02,
           4.7626e-01,  2.5629e-01]],

        [[ 3.9131e-01,