# Attention-based RNN for Seq2Seq

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

## Encoder

In [2]:
class Encoder(nn.Module):
    """GRU encoder for batch-first sequences.

    Args:
        hidden_dim: size of the GRU hidden state.
        n_features: number of features per time step of the input.
        num_layers: number of stacked GRU layers (default 1).

    After each forward pass ``self.hidden`` holds the GRU's final hidden state
    as returned by ``nn.GRU`` (shape ``(num_layers, N, hidden_dim)``).
    """

    def __init__(self, hidden_dim, n_features, num_layers=1):
        super().__init__()
        self.hidden_dim, self.n_features, self.num_layers = hidden_dim, n_features, num_layers
        # Populated by forward(); None until the encoder has been run.
        self.hidden = None
        self.rnn = nn.GRU(n_features, hidden_dim, num_layers, batch_first=True)

    def forward(self, X):
        """Encode ``X`` (N, L, n_features); return every step's hidden state (N, L, hidden_dim)."""
        seq_states, final_state = self.rnn(X)
        self.hidden = final_state
        return seq_states

## Decoder

In [3]:
class Decoder(nn.Module):
    """GRU decoder that predicts the next time step of a sequence.

    Args:
        hidden_dim: size of the GRU hidden state.
        n_features: number of features per time step (both input and prediction).
        num_layers: number of stacked GRU layers (default 1).
    """

    def __init__(self, hidden_dim, n_features, num_layers=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_features = n_features
        self.num_layers = num_layers
        # Recurrent state carried between forward() calls; seeded by init_hidden().
        self.hidden = None
        self.rnn = nn.GRU(n_features, hidden_dim, num_layers, batch_first=True)
        # Maps the last hidden state back into feature space.
        self.regression = nn.Linear(hidden_dim, n_features)

    def init_hidden(self, hidden_seq):
        """Seed the decoder state from the last step of ``hidden_seq``.

        ``hidden_seq`` is expected to be a batch-first sequence of hidden
        states (N, L, H). Its final step is reordered to the (1, N, H) layout
        that ``nn.GRU`` expects for its initial state.
        NOTE(review): the leading dim is always 1, so this only matches
        ``num_layers == 1``.
        """
        last_step = hidden_seq[:, -1:]
        self.hidden = last_step.permute(1, 0, 2)

    def forward(self, X):
        """Run the GRU over ``X`` from the stored state; return a (N, 1, n_features) prediction."""
        seq_states, self.hidden = self.rnn(X, self.hidden)
        prediction = self.regression(seq_states[:, -1:])
        return prediction.view(-1, 1, self.n_features)

## EncoderDecoder

In [41]:
class EncoderDecoder(nn.Module):
    """Seq2seq wrapper: encode a source prefix, then decode the target steps one at a time.

    Args:
        encoder: module returning the per-step hidden states of its input and
            exposing ``n_features``.
        decoder: module with ``init_hidden(hidden_seq)`` whose forward maps one
            (N, 1, n_features) step to a (N, 1, n_features) prediction.
        input_len: number of leading time steps of ``X`` used as the source.
        target_len: number of time steps to predict.
        teacher_forcing_prob: probability (used in training mode only) of
            feeding the ground-truth target step, instead of the model's own
            prediction, as the next decoder input.
    """

    def __init__(self, encoder, decoder, input_len, target_len, teacher_forcing_prob):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.input_len = input_len
        self.target_len = target_len
        self.teacher_forcing_prob = teacher_forcing_prob
        # Prediction buffer of shape (N, target_len, n_features); set per forward().
        self.outputs = None

    def init_outputs(self, batch_size, device=None):
        """Allocate a zeroed (batch_size, target_len, n_features) prediction buffer.

        ``device`` defaults to torch's default device; forward() passes the
        input's device so the buffer lives next to the model's predictions.
        """
        self.outputs = torch.zeros(
            batch_size, self.target_len, self.encoder.n_features, device=device
        )

    def store_output(self, i, out):
        """Write the (N, 1, n_features) prediction ``out`` into time step ``i``."""
        self.outputs[:, i:i+1, :] = out

    def forward(self, X):
        """Predict ``target_len`` steps following the first ``input_len`` steps of ``X``.

        ``X`` is batch-first (N, L, n_features). Steps after ``input_len`` are
        only read as teacher-forcing targets.
        """
        source_seq = X[:, :self.input_len, :]
        target_seq = X[:, self.input_len:, :]
        self.init_outputs(X.shape[0], device=X.device)

        hidden_seq = self.encoder(source_seq)
        self.decoder.init_hidden(hidden_seq)

        # The last known source step seeds the first decoding step.
        dec_inputs = source_seq[:, -1:]

        # Teacher forcing only applies while training.
        prob = self.teacher_forcing_prob if self.training else 0

        for i in range(self.target_len):
            out = self.decoder(dec_inputs)
            self.store_output(i, out)

            # Strict '<': torch.rand samples from [0, 1), so with prob == 0 the
            # ground truth can never be fed back (a draw of exactly 0.0 would
            # have passed the old '<=' test).
            if torch.rand(1) < prob:
                dec_inputs = target_seq[:, i:i+1, :]
            else:
                dec_inputs = out
        return self.outputs

## Attention

In [4]:
# Toy sequence of four 2-D corner points, shaped (N=1, L=4, F=2).
full_seq = torch.tensor([[[-1., -1.], [-1., 1.], [1., 1.], [1., -1.]]], dtype=torch.float32)
full_seq

tensor([[[-1., -1.],
         [-1.,  1.],
         [ 1.,  1.],
         [ 1., -1.]]])

In [5]:
# First two steps are the source, the remaining steps the target.
source_seq, target_seq = full_seq[:, :2], full_seq[:, 2:]
source_seq, target_seq

(tensor([[[-1., -1.],
          [-1.,  1.]]]),
 tensor([[[ 1.,  1.],
          [ 1., -1.]]]))

In [6]:
torch.manual_seed(21)
encoder = Encoder(n_features=2, hidden_dim=2)
# The encoder's per-step hidden states serve as the attention values.
values = hidden_seq = encoder(source_seq)
values

tensor([[[ 0.0832, -0.0356],
         [ 0.3105, -0.5263]]], grad_fn=<TransposeBackward1>)

In [7]:
# Final GRU hidden state; same values as the last step of hidden_seq.
encoder.hidden

tensor([[[ 0.3105, -0.5263]]], grad_fn=<StackBackward0>)

In [8]:
# Keys are the same encoder hidden states used as values.
keys = hidden_seq
keys

tensor([[[ 0.0832, -0.0356],
         [ 0.3105, -0.5263]]], grad_fn=<TransposeBackward1>)

In [9]:
torch.manual_seed(21)
decoder = Decoder(n_features=2, hidden_dim=2)
# Seed the decoder state with the encoder's last hidden state.
decoder.init_hidden(hidden_seq)

# First decoder input: the last step of the source sequence.
inputs = source_seq[:, -1:]
out = decoder(inputs)

In [10]:
# Decoder state after one step; it becomes the attention query below.
decoder.hidden

tensor([[[ 0.3913, -0.6853]]], grad_fn=<StackBackward0>)

In [11]:
# Swap the first two axes so the query is batch-first, like the keys.
query = decoder.hidden.permute(1, 0, 2)
query

tensor([[[ 0.3913, -0.6853]]], grad_fn=<PermuteBackward0>)

In [12]:
def calc_alphas(ks, q):
    """Return uniform attention weights of shape (N, 1, L).

    Placeholder scoring: each of the L keys in ``ks`` (N, L, H) gets weight
    1/L; the query ``q`` is ignored.
    """
    n_batch, seq_len, _ = ks.size()
    return torch.ones(n_batch, 1, seq_len, dtype=torch.float32) / seq_len

# With uniform weights every encoder step contributes equally.
alphas = calc_alphas(keys, query)
alphas

tensor([[[0.5000, 0.5000]]])

In [13]:
# Context vector: attention-weighted sum of the values, (N, 1, L) @ (N, L, H) -> (N, 1, H).
context_vector = torch.bmm(alphas, values)
context_vector

tensor([[[ 0.1968, -0.2809]]], grad_fn=<BmmBackward0>)

In [15]:
# Join context and query along the feature axis.
concatenated = torch.cat((context_vector, query), dim=-1)
concatenated

tensor([[[ 0.1968, -0.2809,  0.3913, -0.6853]]], grad_fn=<CatBackward0>)

In [18]:
# Alignment scores: dot product of the query with every key.
products = torch.bmm(query, keys.transpose(1, 2))
products

tensor([[[0.0569, 0.4821]]], grad_fn=<BmmBackward0>)

In [22]:
# Attention weights: alignment scores normalized over the keys.
alphas = products.softmax(dim=-1)
alphas

tensor([[[0.3953, 0.6047]]], grad_fn=<SoftmaxBackward0>)

In [23]:
def calc_alphas(ks, q):
    """Return dot-product attention weights.

    Scores the keys ``ks`` (N, L, H) against the queries ``q`` (N, Lq, H) with
    a batched dot product, then softmax-normalizes over the keys, giving
    weights of shape (N, Lq, L).
    """
    scores = torch.bmm(q, ks.transpose(1, 2))
    return F.softmax(scores, dim=-1)

In [26]:
# Divide by sqrt(d) so score magnitude doesn't grow with the hidden size.
dim = query.shape[-1]
scaled_products = products / np.sqrt(dim)
scaled_products

tensor([[[0.0403, 0.3409]]], grad_fn=<DivBackward0>)

In [27]:
def calc_alphas(ks, q):
    """Return scaled dot-product attention weights.

    Args:
        ks: keys, a 3-D batch of shape (N, L, H).
        q: queries, a 3-D batch of shape (N, Lq, H) with the same H.

    Returns:
        Softmax-normalized weights of shape (N, Lq, L).
    """
    # Scale by the query's own hidden size. The previous version computed this
    # but then divided by an unrelated global ``dim``.
    d_k = q.size(-1)
    products = torch.bmm(q, ks.permute(0, 2, 1))
    scaled_products = products / np.sqrt(d_k)
    return F.softmax(scaled_products, dim=-1)

In [28]:
# Recompute the context vector using scaled dot-product attention weights.
alphas = calc_alphas(keys, query)
context_vector = torch.bmm(alphas, values)
context_vector

tensor([[[ 0.2138, -0.3175]]], grad_fn=<BmmBackward0>)

In [29]:
class Attention(nn.Module):
    """Scaled dot-product attention with learned query/key (and optional value) projections.

    Args:
        hidden_dim: size of the projected queries/keys (``d_k``) and values.
        input_dim: size of the incoming query/key vectors; defaults to ``hidden_dim``.
        proj_values: if True, values are a learned projection of the keys;
            otherwise the raw keys are used as the values.

    Call ``init_keys`` once per source sequence, then ``forward`` per query.
    The most recent (detached) attention weights are kept in ``self.alphas``.
    """

    def __init__(self, hidden_dim, input_dim=None, proj_values=False):
        super().__init__()
        self.d_k = hidden_dim
        self.input_dim = input_dim if input_dim is not None else hidden_dim
        self.proj_values = proj_values
        # Affine maps for queries, keys and values, created in that order.
        # The value map exists even when proj_values is False.
        in_dim = self.input_dim
        self.linear_query = nn.Linear(in_dim, hidden_dim)
        self.linear_key = nn.Linear(in_dim, hidden_dim)
        self.linear_value = nn.Linear(in_dim, hidden_dim)
        self.alphas = None

    def init_keys(self, keys):
        """Cache the keys, their projection, and the values that forward() attends over."""
        self.keys = keys
        self.proj_keys = self.linear_key(keys)
        if self.proj_values:
            self.values = self.linear_value(keys)
        else:
            self.values = keys

    def score_function(self, query):
        """Return alignment scores: (projected query . projected keys) / sqrt(d_k)."""
        q_proj = self.linear_query(query)
        raw = torch.bmm(q_proj, self.proj_keys.transpose(1, 2))
        return raw / np.sqrt(self.d_k)

    def forward(self, query, mask=None):
        """Return the attention context for ``query``.

        Where the optional ``mask`` equals 0/False, the score is replaced by
        -1e9, so those positions get (near-)zero weight after the softmax.
        """
        scores = self.score_function(query)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        # Keep a graph-free copy of the weights for inspection.
        self.alphas = weights.detach()
        return torch.bmm(weights, self.values)

In [30]:
# The second source step is all zeros, so it acts as padding for the mask below.
source_seq = torch.tensor([[[-1., 1.], [0., 0.]]])
# Hand-picked stand-ins for encoder hidden states (no encoder is run here).
keys = torch.tensor([[[-.38, .44], [.85, -.05]]])
query = torch.tensor([[[-1., 1.]]])

In [31]:
# A step counts as real data only if none of its features are zero; the
# inserted axis 1 lets the mask broadcast against (N, 1, L) scores.
source_mask = (source_seq != 0).all(dim=2)[:, None, :]
source_mask # N, 1, L

tensor([[[ True, False]]])

In [32]:
torch.manual_seed(11)
attnh = Attention(2)
attnh.init_keys(keys)

# Masked positions get a -1e9 score, so all attention lands on the unmasked step.
context = attnh(query, mask=source_mask)
attnh.alphas

tensor([[[1., 0.]]])

## Decoder + Attention

In [39]:
class DecoderAttn(nn.Module):
    """Single-layer GRU decoder that attends over the encoder's hidden states.

    Each step's GRU output is used as the attention query. The resulting
    context is concatenated with the query before the linear regression head.

    Args:
        n_features: number of features per time step (both input and prediction).
        hidden_dim: size of the GRU hidden state and of the attention projections.
    """

    def __init__(self, n_features, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_features = n_features
        # Recurrent state carried between forward() calls; seeded by init_hidden().
        self.hidden = None
        self.rnn = nn.GRU(n_features, hidden_dim, batch_first=True)
        self.attn = Attention(hidden_dim)
        # The head sees [context, query], hence twice the hidden size.
        self.regression = nn.Linear(hidden_dim * 2, n_features)

    def init_hidden(self, hidden_seq):
        """Register ``hidden_seq`` as the attention keys and seed the GRU with its last step."""
        self.attn.init_keys(hidden_seq)
        self.hidden = hidden_seq[:, -1:].permute(1, 0, 2)

    def forward(self, X, mask=None):
        """Run one decoding step on ``X``; return a (N, 1, n_features) prediction."""
        rnn_out, self.hidden = self.rnn(X, self.hidden)
        query = rnn_out[:, -1:]
        context = self.attn(query, mask=mask)
        features = torch.cat([context, query], dim=-1)
        return self.regression(features).view(-1, 1, self.n_features)

## Encoder + Decoder + Attention

In [34]:
# Toy sequence of four 2-D corner points, shaped (N=1, L=4, F=2).
full_seq = torch.tensor([[[-1., -1.], [-1., 1.], [1., 1.], [1., -1.]]], dtype=torch.float32)
full_seq

tensor([[[-1., -1.],
         [-1.,  1.],
         [ 1.,  1.],
         [ 1., -1.]]])

In [35]:
# First two steps are the source, the remaining steps the target.
source_seq, target_seq = full_seq[:, :2], full_seq[:, 2:]
source_seq, target_seq

(tensor([[[-1., -1.],
          [-1.,  1.]]]),
 tensor([[[ 1.,  1.],
          [ 1., -1.]]]))

In [40]:
torch.manual_seed(21)
encoder = Encoder(n_features=2, hidden_dim=2)
decoder_attn = DecoderAttn(n_features=2, hidden_dim=2)

hidden_seq = encoder(source_seq)
decoder_attn.init_hidden(hidden_seq)

# Seed the decoder with the last source step. The old cell assigned this to
# `input` (shadowing the builtin), so the loop read a stale `inputs` from an
# earlier cell.
inputs = source_seq[:, -1:]
target_len = 2
for _ in range(target_len):
    out = decoder_attn(inputs)
    print(f'Output {out}')
    # Feed each prediction back in as the next input (no teacher forcing).
    inputs = out

Output tensor([[[-0.3555, -0.1220]]], grad_fn=<ViewBackward0>)
Output tensor([[[-0.2641, -0.2521]]], grad_fn=<ViewBackward0>)


In [42]:
# Wrap the encoder and attention decoder. With teacher_forcing_prob=0.0 the
# decoder consumes its own predictions, as in the manual loop.
encdec = EncoderDecoder(encoder, decoder_attn, input_len=2, target_len=2, teacher_forcing_prob=0.0)
encdec(full_seq)

tensor([[[-0.3555, -0.1220],
         [-0.2641, -0.2521]]], grad_fn=<CopySlices>)

## Multi-Headed Attention

In [44]:
class MutiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``Attention`` heads.

    Every head attends over the same keys with its own learned projections.
    The per-head contexts are concatenated and mapped back to ``d_model``.

    Args:
        n_heads: number of attention heads.
        d_model: hidden size of each head and of the output.
        input_dim: size of the incoming query/key vectors, forwarded to each head.
        proj_values: forwarded to each head (default True: values are projected).
    """

    def __init__(self, n_heads, d_model, input_dim=None, proj_values=True):
        super(MutiHeadAttention, self).__init__()
        # Concatenated head outputs (n_heads * d_model) -> d_model.
        self.linear_out = nn.Linear(n_heads * d_model, d_model)
        self.attn_heads = nn.ModuleList(
            [Attention(d_model, input_dim=input_dim, proj_values=proj_values) for _ in range(n_heads)]
        )

    def init_keys(self, key):
        """Give every head the same keys."""
        for attn in self.attn_heads:
            attn.init_keys(key)

    @property
    def alphas(self):
        """Per-head attention weights stacked on a new leading dim.

        Returns None until every head has run forward() at least once,
        matching ``Attention.alphas``. Previously this case raised an opaque
        TypeError from ``torch.stack``.
        """
        head_alphas = [attn.alphas for attn in self.attn_heads]
        if any(a is None for a in head_alphas):
            return None
        return torch.stack(head_alphas, dim=0)

    def output_function(self, contexts):
        """Concatenate the per-head contexts on the feature axis and project to d_model."""
        concatenated = torch.cat(contexts, dim=-1)
        return self.linear_out(concatenated)

    def forward(self, query, mask=None):
        """Run every head on ``query`` (with the same optional ``mask``) and combine them."""
        contexts = [attn(query, mask=mask) for attn in self.attn_heads]
        return self.output_function(contexts)


# Correctly spelled alias; the original class name is kept for existing callers.
MultiHeadAttention = MutiHeadAttention