# Self-Attention and Positional Encoding

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

## Attention

In [2]:
class Attention(nn.Module):
    def __init__(self, hidden_dim, input_dim=None, proj_values=False):
        super(Attention, self).__init__()
        self.d_k = hidden_dim
        self.input_dim = hidden_dim if input_dim is None else input_dim
        self.proj_values = proj_values
        # Affine transformation for q, k , v
        self.linear_query = nn.Linear(self.input_dim, hidden_dim)
        self.linear_key = nn.Linear(self.input_dim, hidden_dim)
        self.linear_value = nn.Linear(self.input_dim, hidden_dim)
        self.alphas = None
    
    def init_keys(self, keys):
        self.keys = keys
        self.proj_keys = self.linear_key(self.keys)
        self.values = self.linear_value(self.keys) if self.proj_values else self.keys
        
    # alignment scores
    def score_function(self, query):
        proj_query = self.linear_query(query)
        dot_products = torch.bmm(proj_query, self.proj_keys.permute(0, 2, 1))
        scores = dot_products / np.sqrt(self.d_k)
        return scores
    
    def forward(self, query, mask=None):
        scores = self.score_function(query)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        alphas = F.softmax(scores, dim=-1)
        self.alphas = alphas.detach()
        
        context = torch.bmm(alphas, self.values)
        return context

## Muti-Headed Attention

In [3]:
class MutiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, input_dim=None, proj_values=True):
        super(MutiHeadAttention, self).__init__()
        self.linear_out = nn.Linear(n_heads * d_model, d_model)
        self.attn_heads = nn.ModuleList(
            [Attention(d_model, input_dim=input_dim, proj_values=proj_values) for _ in range(n_heads)]
        )
    
    def init_keys(self, key):
        for attn in self.attn_heads:
            attn.init_keys(key)
    
    @property
    def alphas(self):
        return torch.stack(
            [attn.alphas for attn in self.attn_heads], dim=0
        )
    
    def output_function(self, contexts):
        concatenated = torch.cat(contexts, axis=-1)
        out = self.linear_out(concatenated)
        return out
    
    def forward(self, query, mask=None):
        contexts = [attn(query, mask=mask) for attn in self.attn_heads]
        out = self.output_function(contexts)
        return out

## data

In [4]:
full_seq = torch.tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]]).float().view(1, 4, 2)
full_seq

tensor([[[-1., -1.],
         [-1.,  1.],
         [ 1.,  1.],
         [ 1., -1.]]])

In [5]:
source_seq = full_seq[:, :2, :]
target_seq = full_seq[:, 2:, :]
source_seq, target_seq

(tensor([[[-1., -1.],
          [-1.,  1.]]]),
 tensor([[[ 1.,  1.],
          [ 1., -1.]]]))

## Encoder + Self-Attention

In [6]:
class Encoder(nn.Module):
    def __init__(self, n_heads, d_model, ff_units, n_features=None):
        super(Encoder, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.ff_units = ff_units
        self.n_features = n_features
        self.self_attn_heads = MutiHeadAttention(n_heads, d_model, input_dim=n_features)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_units),
            nn.ReLU(),
            nn.Linear(ff_units, d_model),
        )
    
    def forward(self, query, mask=None):
        self.self_attn_heads.init_keys(query)
        att = self.self_attn_heads(query, mask)
        out = self.ffn(att)
        return out

In [7]:
torch.manual_seed(11)
encoder = Encoder(n_heads=3, d_model=2, ff_units=10, n_features=2)
query = source_seq
encoder_states = encoder(query)
encoder_states

tensor([[[-0.0498,  0.2193],
         [-0.0642,  0.2258]]], grad_fn=<ViewBackward0>)

## Decoder + Self-Attention

In [8]:
class Decoder(nn.Module):
    def __init__(self, n_heads, d_model, ff_units, n_features=None):
        super(Decoder, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.ff_units = ff_units
        self.n_features = d_model if n_features is None else n_features
        self.self_attn_heads = MutiHeadAttention(n_heads, d_model, input_dim=self.n_features)
        self.cross_attn_heads = MutiHeadAttention(n_heads, d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_units),
            nn.ReLU(),
            nn.Linear(ff_units, self.n_features)
        )
        
    def init_keys(self, states):
        self.cross_attn_heads.init_keys(states)
        
    def forward(self, query, source_mask=None, target_mask=None):
        self.self_attn_heads.init_keys(query)
        att1 = self.self_attn_heads(query, target_mask)
        att2 = self.cross_attn_heads(att1, source_mask)
        out = self.ffn(att2)
        return out

## target mask(traning)

In [9]:
shifted_seq = torch.cat([source_seq[:, -1:], target_seq[:, :-1]], dim=1)
shifted_seq

tensor([[[-1.,  1.],
         [ 1.,  1.]]])

In [10]:
def subsequent_mask(size):
    attn_shape = (1, size, size)
    subsequent_mask = (1 - torch.triu(torch.ones(attn_shape), diagonal=1)).bool()
    return subsequent_mask

In [11]:
subsequent_mask(2)

tensor([[[ True, False],
         [ True,  True]]])

In [12]:
torch.manual_seed(13)
decoder = Decoder(n_heads=3, d_model=2, ff_units=10, n_features=2)
decoder.init_keys(encoder_states)

query = shifted_seq
out = decoder(query, target_mask=subsequent_mask(2))

decoder.self_attn_heads.alphas

tensor([[[[1.0000, 0.0000],
          [0.4011, 0.5989]]],


        [[[1.0000, 0.0000],
          [0.4264, 0.5736]]],


        [[[1.0000, 0.0000],
          [0.6304, 0.3696]]]])

## target mask(evaluation/prediction)

In [13]:
inputs = source_seq[:, -1:]
trg_masks = subsequent_mask(1)
out = decoder(inputs, trg_masks)
out

tensor([[[0.4132, 0.3728]]], grad_fn=<ViewBackward0>)

In [14]:
inputs = torch.cat([inputs, out[:, -1:, :]], dim=-2)
inputs

tensor([[[-1.0000,  1.0000],
         [ 0.4132,  0.3728]]], grad_fn=<CatBackward0>)

In [15]:
trg_masks = subsequent_mask(2)
out = decoder(inputs, trg_masks)
out

tensor([[[0.4137, 0.3727],
         [0.4132, 0.3728]]], grad_fn=<ViewBackward0>)

## Encoder + Decoder + Self-Attention

In [16]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, input_len, target_len):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.input_len = input_len
        self.target_len = target_len
        self.trg_mask = self.subsequent_mask(self.target_len)
    
    @staticmethod
    def subsequent_mask(size):
        attn_shape = (1, size, size)
        subsequent_mask = (1 - torch.triu(torch.ones(attn_shape), diagonal=1)).bool()
        return subsequent_mask
    
    def encode(self, source_seq, source_mask):
        # Encodes the source sequence and uses the result
        # to initialize the decoder
        encoder_states = self.encoder(source_seq, source_mask)
        self.decoder.init_keys(encoder_states)
        
    def decode(self, shifted_target_seq, source_mask=None, target_mask=None):
        # Decodes/generates a sequence using the shifted (masked) target sequence - used in TRAIN mode
        outputs = self.decoder(shifted_target_seq, source_mask=source_mask, target_mask=target_mask)
        return outputs
    
    def predict(self, source_seq, source_mask):
        # Decodes/generates a sequence using one input at a time - used in EVAL mode
        inputs = source_seq[:, -1:]
        for i in range(self.target_len):
            out = self.decode(inputs, source_mask, self.trg_mask[:, :i+1, :i+1])
            out = torch.cat([inputs, out[:, -1:, :]], dim=-2)
            inputs = out.detach()
        outputs = inputs[:, 1:, :]
        return outputs
    
    def forward(self, X, source_mask=None):
        # Sends the mask to the same device as the inputs
        self.trg_mask = self.trg_mask.type_as(X).bool()
        source_seq = X[:, :self.input_len, :]
        self.encode(source_seq, source_mask)
        if self.traning:
            shifted_target_seq = X[:, self.input_len-1:-1, :]
            outputs = self.decode(shifted_target_seq, source_mask, self.trg_mask)
        else:
            outputs = self.predict(source_seq, source_mask)
        return outputs

## Positional Encoding

In [17]:
class PositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).float().unsqueeze(1)
        angular_speed = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * angular_speed) # even dimensions
        pe[:, 1::2] = torch.cos(position * angular_speed) # odd dimensions
        self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, X):
        scaled_x = x * np.sqrt(self.d_model)
        encoded = scaled_x + self.pe[:, :x.size(1), :]
        return encoded

In [18]:
posenc = PositionalEncoding(2, 2)

In [19]:
list(posenc.parameters()), posenc.state_dict()

([],
 OrderedDict([('pe', tensor([[[0.0000, 1.0000],
                        [0.8415, 0.5403]]]))]))

In [20]:
source_seq

tensor([[[-1., -1.],
         [-1.,  1.]]])

In [21]:
source_seq + posenc.pe

tensor([[[-1.0000,  0.0000],
         [-0.1585,  1.5403]]])

In [22]:
class EncoderPe(nn.Module):
    def __init__(self, n_heads, d_model, ff_units, n_features=None, max_len=100):
        super(Encoder, self).__init__()
        pe_dim = d_model if n_features is None else n_features
        self.pe = PositionalEncoding(max_len, pe_dim)
        self.layer = Encoder(n_heads, d_model, ff_units, n_features)
        
    def forward(self, query, mask=None):
        query_pe = self.pe(query)
        out = self.layer(query_pe, mask)
        return out

In [23]:
class DecoderPe(nn.Module):
    def __init__(self, n_heads, d_model, ff_units, n_features=None, max_len=100):
        super(DecoderPe, self).__init__()
        pe_dim = d_model if n_features is None else n_features
        self.pe = PositionalEncoding(max_len, pe_dim)
        self.layer = Decoder(n_heads, d_model, ff_untis, n_features)
        
    def init_keys(self, states):
        self.layer.init_keys(states)
    
    def forward(self, query, source_mask=None, target_mask=None):
        query_pe = self.pe(query)
        out = self.layer(query_pe, source_mask, target_mask)
        return out