[Attention is all you need](https://arxiv.org/pdf/1706.03762.pdf)

In [None]:
import torch
import torch.nn

The Transformer consists of an encoder and a decoder:

<img src="transformer_architecture.png" width="500">

The input tokens are embedded into vectors of size $d_{model}$ and fed into the encoder. In the paper the embedding is learned, but a fixed or pretrained embedding could be used instead.
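A minimal sketch of the embedding step, assuming a learned `torch.nn.Embedding` table. The vocabulary size and token ids are hypothetical values for illustration. Following the paper, the looked-up vectors are multiplied by $\sqrt{d_{model}}$; the positional encodings the paper adds afterwards are left out here.

```python
import math

import torch

d_model = 512          # embedding size used in the paper
vocab_size = 1000      # hypothetical vocabulary size for illustration

# learned lookup table: one d_model-dimensional vector per token id
embedding = torch.nn.Embedding(vocab_size, d_model)

# a hypothetical batch of 2 sequences, each with 5 token ids
token_ids = torch.tensor([[5, 17, 42, 7, 0],
                          [3, 3, 99, 12, 8]])

# look up the vectors and scale them by sqrt(d_model), as described in the paper
x = embedding(token_ids) * math.sqrt(d_model)
print(x.shape)  # torch.Size([2, 5, 512])
```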


### Encoder

The encoder is a stack of N identical encoder blocks (the paper uses N = 6).

<img src="encoder_block.png">

Each encoder block produces outputs of size $d_{model}$, which the paper sets to 512.
An encoder block consists of:
- a multi-head self-attention sub-layer
- a position-wise fully connected feed-forward sub-layer
- layer normalization following each of the two sub-layers
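To see the whole stack at once, here is a sketch that uses PyTorch's built-in `torch.nn.TransformerEncoderLayer` and `torch.nn.TransformerEncoder` rather than the classes built in this notebook. The hyperparameters are those of the paper's base model; `norm_first=False` selects the post-norm arrangement $LayerNorm(x+Sublayer(x))$ described in the paper. The input batch is random data for illustration.

```python
import torch

# hyperparameters of the base model in the paper
d_model, n_heads, d_ff, n_layers = 512, 8, 2048, 6

# one encoder block: self-attention + feed-forward, each followed by
# LayerNorm(x + Sublayer(x)) because norm_first=False (post-norm)
block = torch.nn.TransformerEncoderLayer(
    d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
    dropout=0.1, batch_first=True, norm_first=False,
)
# the encoder stacks N = 6 independent copies of that block
encoder = torch.nn.TransformerEncoder(block, num_layers=n_layers)

x = torch.randn(2, 10, d_model)   # (batch, seq_len, d_model)
out = encoder(x)
print(out.shape)  # torch.Size([2, 10, 512])
```

Because every block maps $d_{model}$ to $d_{model}$, the output shape matches the input shape no matter how many blocks are stacked.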

The input to each sub-layer is added to that sub-layer's output through a residual connection (the arrow going around each sub-layer in the figure), so the output of each sub-layer is $LayerNorm(x+Sublayer(x))$. All sub-layers, as well as the embedding layer, produce outputs of dimension $d_{model}$ so that these residual additions are well defined.
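A minimal sketch of the residual connection followed by layer normalization. `SublayerConnection` is a hypothetical helper name, and the wrapped sub-layer is a stand-in position-wise feed-forward network using the paper's inner size of 2048.

```python
import torch


class SublayerConnection(torch.nn.Module):
    """Hypothetical helper computing LayerNorm(x + Sublayer(x))."""

    def __init__(self, d_model, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = torch.nn.LayerNorm(d_model)

    def forward(self, x):
        # the addition only works if the sub-layer output also has size d_model
        return self.norm(x + self.sublayer(x))


d_model = 512
# stand-in sub-layer: position-wise feed-forward network
ffn = torch.nn.Sequential(
    torch.nn.Linear(d_model, 2048),
    torch.nn.ReLU(),
    torch.nn.Linear(2048, d_model),
)
layer = SublayerConnection(d_model, ffn)

x = torch.randn(2, 10, d_model)   # (batch, seq_len, d_model)
y = layer(x)
print(y.shape)  # torch.Size([2, 10, 512])
```

After the freshly initialized `LayerNorm`, every position's output vector has roughly zero mean and unit variance across its $d_{model}$ features.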

### Attention

According to the paper, attention is a function that maps a query and a set of key-value pairs to an output, where the query, keys, values and output are all vectors. The output is a weighted sum of the values, and the weight assigned to each value is given by a compatibility function of the query with the corresponding key.

<figure>
    <img src="attention.png">
    <figcaption>A is the compatibility function between the query and each key; B applies the resulting weights to the corresponding values.</figcaption>
</figure>
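Steps A and B can be traced by hand on toy numbers. This sketch assumes a plain dot product as the compatibility function (the choice the paper makes for scaled dot-product attention, with the scaling left out here). The query, keys and values are made-up values.

```python
import torch

# one query and three key-value pairs (toy numbers)
query = torch.tensor([1.0, 0.0])
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [1.0, 1.0]])
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])

# A: compatibility of the query with each key (here a plain dot product)
scores = keys @ query                    # tensor([1., 0., 1.])
# turn the scores into weights that sum to 1
weights = torch.softmax(scores, dim=0)
# B: weighted sum of the values
output = weights @ values
print(weights, output)
```

The first and third keys match the query equally well, so they receive equal weight, and their values dominate the output.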

### Scaled Dot-Product Attention

The inputs are queries and keys of dimension $d_{k}$ and values of dimension $d_{v}$. For a single query $q$, the output is
$$
Attention(q,K,V) = softmax(\frac{qK^{T}}{\sqrt{d_{k}}})V
$$
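where $q$ is a single query (a row vector), $K$ is the matrix whose rows are the keys and $V$ is the matrix whose rows are the values. Stacking all queries into a matrix $Q$ gives the batched form $softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V$. Dot-product attention is chosen because it can be computed with highly optimized matrix multiplication, which makes it faster and more space-efficient than additive attention. Its weakness is that for large $d_{k}$ the dot products grow large in magnitude and push the softmax into regions with extremely small gradients; dividing the scores by $\sqrt{d_{k}}$ counteracts this. The scaling keeps the speed advantage of dot-product attention while avoiding the degradation that lets additive attention outperform unscaled dot-product attention for larger $d_{k}$.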
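The growth of the dot products can be checked numerically. This sketch assumes, as the paper's argument does, that the components of $q$ and $k$ are independent with mean 0 and variance 1; then $q \cdot k$ has variance $d_{k}$, and dividing by $\sqrt{d_{k}}$ brings it back to about 1. The sample size and the set of $d_{k}$ values are arbitrary choices.

```python
import torch

torch.manual_seed(0)
n = 10_000  # number of random query/key pairs per setting

results = {}
for d_k in (4, 64, 512):
    # components drawn independently with mean 0 and variance 1
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    dots = (q * k).sum(dim=1)      # n raw dot products q . k
    scaled = dots / d_k ** 0.5     # the same dot products divided by sqrt(d_k)
    results[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:4d}  var(q.k)={results[d_k][0]:8.2f}  var(scaled)={results[d_k][1]:.3f}")
```

The raw variance tracks $d_{k}$ while the scaled variance stays near 1, so the softmax inputs keep a similar spread regardless of $d_{k}$.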

Queries, keys and values are each computed from the input with their own learnable weight matrix, e.g. $q = W_{q}x$, where $q$ is the query vector, $x$ is the embedding vector of one input position and $W_{q}$ is the learnable query weight matrix (likewise $k = W_{k}x$ and $v = W_{v}x$).
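A small sketch of the projection $q = W_{q}x$ with `torch.nn.Linear`, assuming the paper's base-model sizes $d_{model} = 512$ and $d_{k} = 64$ and random input vectors. It also shows that one layer projects every position of a sequence at once.

```python
import torch

d_model, d_k = 512, 64   # sizes from the paper's base model
W_q = torch.nn.Linear(d_model, d_k, bias=False)   # learnable W_q

x = torch.randn(d_model)       # embedding of one input position
q = W_q(x)                     # q = W_q x, shape (d_k,)

# nn.Linear stores W_q as a (d_k, d_model) matrix and computes x @ W_q.T,
# which for a single vector is the same as W_q @ x
assert torch.allclose(q, W_q.weight @ x, atol=1e-5)

# applied to a whole sequence, the same layer projects every position at once
X = torch.randn(10, d_model)   # (seq_len, d_model)
Q = W_q(X)                     # (seq_len, d_k): one query per row
print(q.shape, Q.shape)  # torch.Size([64]) torch.Size([10, 64])
```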



In [None]:
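class Attention(torch.nn.Module):
    """Single-head scaled dot-product self-attention.

    In multi-head attention each head uses d_k = d_v = d_model / n_heads.
    """
    def __init__(self, d_model, d_k, d_v=None):
        super().__init__()
        d_v = d_k if d_v is None else d_v
        self.d_k = d_k
        # learnable projections: q = W_q x, k = W_k x, v = W_v x
        self.weightQ = torch.nn.Linear(d_model, d_k, bias=False)
        self.weightK = torch.nn.Linear(d_model, d_k, bias=False)
        self.weightV = torch.nn.Linear(d_model, d_v, bias=False)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        query = self.weightQ(x)   # (batch, seq_len, d_k)
        key = self.weightK(x)     # (batch, seq_len, d_k)
        value = self.weightV(x)   # (batch, seq_len, d_v)
        # compatibility scores QK^T / sqrt(d_k): (batch, seq_len, seq_len)
        scores = torch.bmm(query, key.transpose(1, 2)) / self.d_k ** 0.5
        # softmax over the keys so each query's weights sum to 1
        weights = torch.softmax(scores, dim=-1)
        # weighted sum of the values: (batch, seq_len, d_v)
        return torch.bmm(weights, value)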
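As a cross-check of the formula itself, the following standalone sketch implements the matrix form $softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V$ as a plain function and confirms that it matches the single-query formula applied row by row. The function name `scaled_dot_product_attention` is our own choice here, and the tensor sizes are arbitrary.

```python
import torch


def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (n_q, n_k)
    weights = torch.softmax(scores, dim=-1)          # each row sums to 1
    return weights @ V                               # (n_q, d_v)


torch.manual_seed(0)
Q = torch.randn(4, 64)    # 4 queries
K = torch.randn(6, 64)    # 6 keys
V = torch.randn(6, 32)    # 6 values with d_v = 32

batched = scaled_dot_product_attention(Q, K, V)

# the same result computed one query at a time with the single-query formula
per_query = torch.stack([
    torch.softmax(q @ K.T / 64 ** 0.5, dim=-1) @ V for q in Q
])
print(batched.shape)                                  # torch.Size([4, 32])
print(torch.allclose(batched, per_query, atol=1e-5))  # True
```

Computing all queries in one matrix product is what lets the attention layer use the optimized matrix multiplication mentioned above.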