In [194]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

## Sublayers
Note the usage of Conv1d here for multiple attention heads. This is supposed to replicate the linear transformation represented for each head in one structure, but make sure it's correct.

In [232]:
class AttentionLayer(nn.Module):
    def __init__(self, embedding_dim, heads=1):
        super(AttentionLayer, self).__init__()
        self.w_query = nn.Conv1d(embedding_dim, heads*embedding_dim, 1)
        self.w_key = nn.Conv1d(embedding_dim, heads*embedding_dim, 1)
        self.w_value = nn.Conv1d(embedding_dim, heads*embedding_dim, 1)
        nn.init.normal_(self.w_query.weight, mean=0, std=np.sqrt(2.0 / (embedding_dim + (heads*embedding_dim))))
        nn.init.normal_(self.w_key.weight, mean=0, std=np.sqrt(2.0 / (embedding_dim + (heads*embedding_dim))))
        nn.init.normal_(self.w_value.weight, mean=0, std=np.sqrt(2.0 / (embedding_dim + (heads*embedding_dim))))
        self.scale_factor = 1. / np.sqrt(embedding_dim)
        self.w_out = nn.Linear(heads*embedding_dim, embedding_dim)
        self.layer_norm = nn.LayerNorm(embedding_dim)
    
    def forward(self, input, mask=None):
        if mask is None:
            masked_input = input
        else:
            masked_input = mask * input
        q = self.w_query(masked_input.transpose(1, 2)) 
        k = self.w_key(masked_input.transpose(1, 2)) 
        v = self.w_value(masked_input.transpose(1, 2))
        print(torch.bmm(q, k.transpose(1, 2)).shape)
        attention_weights = F.softmax(torch.bmm(q, k.transpose(1, 2)), dim=-1) / self.scale_factor
        output = torch.bmm(attention_weights, v).transpose(1, 2)
        output = self.w_out(output)
        output = self.layer_norm(input + output)
        return output, attention_weights


class PositionwiseFeedForwardLayer(nn.Module):
    def __init__(self, embedding_dim, hidden_dim):
        super(PositionwiseFeedForwardLayer, self).__init__()
        self.w_1 = nn.Conv1d(embedding_dim, hidden_dim, 1) 
        self.w_2 = nn.Conv1d(hidden_dim, embedding_dim, 1) 
        self.layer_norm = nn.LayerNorm(embedding_dim)

    def forward(self, input):
        output = self.w_2(F.relu(self.w_1(input.transpose(1, 2))))
        output = self.layer_norm(input + output.transpose(1, 2))
        return output


## Top-level Layers

In [233]:
''' Taken from http://nlp.seas.harvard.edu/2018/04/03/attention.html#positional-encoding'''
class PositionalEncoding(nn.Module):
    def __init__(self, seq_len, embedding_dim):
        super(PositionalEncoding, self).__init__()        

        # Compute the positional encodings once in log space.
        pe = torch.zeros(seq_len, embedding_dim)
        position = torch.arange(0., seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., embedding_dim, 2) *
                             -(math.log(10000.0) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        return x

class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, seq_len, embedding_dim):
        super(EmbeddingLayer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_encoding = PositionalEncoding(seq_len, embedding_dim)
    
    def forward(self, input):
        return self.pos_encoding(self.embedding(input) * np.sqrt(embedding_dim))


class EncoderLayer(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, heads=1):
        super(EncoderLayer, self).__init__()
        self.attention = AttentionLayer(embedding_dim, heads=heads)
        self.pwff = PositionwiseFeedForwardLayer(embedding_dim, hidden_dim)

    def forward(self, input, detailed=False):
        output, attention_weights = self.attention(input)
        output = self.pwff(output)
        if detailed:
            return output, attention_weights
        else:
            return output
        
class DecoderLayer(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, heads=1):
        super(DecoderLayer, self).__init__()
        self.attention = AttentionLayer(embedding_dim, heads=heads)
        self.pwff = PositionwiseFeedForwardLayer(embedding_dim, hidden_dim)

    def forward(self, input, detailed=False):
        output, attention_weights = self.attention(input)
        output = self.pwff(output)
        if detailed:
            return output, attention_weights
        else:
            return output

In [234]:
sequence = torch.LongTensor([[0, 1, 2], [3, 4, 5]])
label = torch.LongTensor([[2, 1, 0, 6], [5, 4, 3, 6]])

To facilitate LayerNorm and residual connections, all vector dimensions will be the same (embedding dimension and output dimensions, i.e. anything that needs to be added together, hence why ```EncoderLayer``` dimensions don't require a hidden dimension for the ```AttentionLayer``` (all dimensions are 10).

In [235]:
vocab_size = 6
seq_len = 3
embedding_dim = 10
embedding_layer = EmbeddingLayer(vocab_size, vocab_size, embedding_dim)
encoder_layer = EncoderLayer(embedding_dim, 5, heads=2)

In [236]:
encoder_layer(embedding_layer(sequence))

torch.Size([2, 20, 20])


tensor([[[ 1.1137, -0.3880, -1.5422, -0.4232,  0.6086,  1.4954, -1.0851,
          -1.1498,  0.7424,  0.6282],
         [ 1.5668, -0.4738, -0.0204, -0.1533, -0.0912,  1.6796, -0.7200,
          -1.9743,  0.2231, -0.0365],
         [-0.7582,  0.1845,  0.8691, -1.0817,  1.0172,  0.2928, -1.9719,
          -0.5819,  0.8459,  1.1843]],

        [[ 0.6240, -1.1185, -0.9015,  0.9725, -0.9600,  1.7241, -1.1003,
          -0.4100,  0.0177,  1.1522],
         [ 1.7745, -0.4261, -0.8353,  1.4753, -0.3191,  0.1992, -1.5761,
          -0.9651,  0.2581,  0.4144],
         [-0.8022,  0.9486,  0.3475, -0.2479, -0.1137,  2.1452, -0.9754,
          -0.5459, -1.4086,  0.6525]]], grad_fn=<AddcmulBackward>)

## TODO

How to deal with masking, how to get sequential outputs from decoder, how to feed inputs to decoder (like the memory from encoder, the source mask, the target mask, etc.)?