In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

from helper import look


In [50]:
class Encoder(nn.Module):
    def __init__(self, n_features, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_features = n_features
        self.hidden = None
        self.basic_rnn = nn.GRU(self.n_features,
                                self.hidden_dim,
                                batch_first=True)

    def forward(self, X):
        rnn_out, self.hidden = self.basic_rnn(X)
        return rnn_out # N, L, F


class Decoder(nn.Module):
    def __init__(self, n_features, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_features = n_features
        self.hidden = None
        self.basic_rnn = nn.GRU(self.n_features,
                                self.hidden_dim,
                                batch_first=True)
        self.regression = nn.Linear(self.hidden_dim,
                                    self.n_features)

    def init_hidden(self, hidden_seq):
        # We only need the final hidden state
        hidden_final = hidden_seq[:, -1:] # N, 1, H
        # But we need to make it sequence-first
        self.hidden = hidden_final.permute(1, 0, 2) # 1, N, H

    def forward(self, X):
        # X is N, 1, F
        batch_first_output, self.hidden = self.basic_rnn(X, self.hidden)
        last_output = batch_first_output[:, -1:]
        out = self.regression(last_output)
        return out.view(-1, 1, self.n_features)        

In [51]:
full_seq = torch.FloatTensor([[-1, -1], [-1, 1], [1, 1], [1, -1]]).view(1, 4, 2)
source_seq = full_seq[:, :2]
target_seq = full_seq[:, 2:]

look("X", source_seq)
look("Y", target_seq)

X $\begin{bmatrix} \begin{bmatrix} -1.0 & -1.0 \\ -1.0 & 1.0\end{bmatrix}\end{bmatrix}$

Y $\begin{bmatrix} \begin{bmatrix} 1.0 & 1.0 \\ 1.0 & -1.0\end{bmatrix}\end{bmatrix}$

In [52]:
torch.manual_seed(21)
encoder = Encoder(n_features=2, hidden_dim=2)
hidden_seq = encoder(source_seq)
values = hidden_seq # N, L, H values
look("values", values)

values $\begin{bmatrix} \begin{bmatrix} 0.0832 & -0.0356 \\ 0.311 & -0.526\end{bmatrix}\end{bmatrix}$

In [53]:
keys = hidden_seq # N, L, H keys
look("keys", keys)

keys $\begin{bmatrix} \begin{bmatrix} 0.0832 & -0.0356 \\ 0.311 & -0.526\end{bmatrix}\end{bmatrix}$

In [54]:
torch.manual_seed(21)
decoder = Decoder(n_features=2, hidden_dim=2)
decoder.init_hidden(hidden_seq)
inputs = source_seq[:, -1:]
out = decoder(inputs)

look("inputs", inputs)
look("out", out)

inputs $\begin{bmatrix} \begin{bmatrix} -1.0 & 1.0\end{bmatrix}\end{bmatrix}$

out $\begin{bmatrix} \begin{bmatrix} -0.234 & 0.47\end{bmatrix}\end{bmatrix}$

In [55]:
query = decoder.hidden.permute(1, 0, 2) # N, 1, H query

look("query", query)

query $\begin{bmatrix} \begin{bmatrix} 0.391 & -0.685\end{bmatrix}\end{bmatrix}$

In [29]:
def calc_alphas(ks, q):
    N, L, H = ks.size()
    alphas = torch.ones(N, 1, L).float() * 1/L 
    return alphas

alphas = calc_alphas(keys, query)
look("alphas", alphas)

alphas $\begin{bmatrix} \begin{bmatrix} 0.5 & 0.5\end{bmatrix}\end{bmatrix}$

In [72]:
nlh = torch.rand(2, 3, 4)
n1l = torch.rand(2, 1, 3)

look(n1l, nlh)
look(torch.bmm(n1l, nlh))

$\begin{bmatrix} \begin{bmatrix} 0.821 & 0.871 & 0.735\end{bmatrix} & \begin{bmatrix} 0.555 & 0.669 & 0.306\end{bmatrix}\end{bmatrix}$ $\begin{bmatrix} \begin{bmatrix} 0.702 & 0.554 & 0.437 & 0.281 \\ 0.205 & 0.885 & 0.917 & 0.711 \\ 0.508 & 0.023 & 0.277 & 0.156\end{bmatrix} & \begin{bmatrix} 0.577 & 0.929 & 0.205 & 0.385 \\ 0.0596 & 0.245 & 0.26 & 0.671 \\ 0.562 & 0.129 & 0.438 & 0.161\end{bmatrix}\end{bmatrix}$

$\begin{bmatrix} \begin{bmatrix} 1.13 & 1.24 & 1.36 & 0.964\end{bmatrix} & \begin{bmatrix} 0.532 & 0.719 & 0.422 & 0.711\end{bmatrix}\end{bmatrix}$

tensor([[1.1279, 1.2421, 1.3610, 0.9643]])

In [31]:
# N, 1, L x N, L, H -> 1, L x L, H -> 1, H
context_vector = torch.bmm(alphas, values)
look("context vector", context_vector)

context vector $\begin{bmatrix} \begin{bmatrix} 0.197 & -0.281\end{bmatrix}\end{bmatrix}$

In [32]:
concatenated = torch.cat([context_vector, query], axis=-1)
look("concatenated", concatenated)

concatenated $\begin{bmatrix} \begin{bmatrix} 0.197 & -0.281 & 0.391 & -0.685\end{bmatrix}\end{bmatrix}$

In [33]:
# N, 1, H x N, H, L -> N, 1, L
products = torch.bmm(query, keys.permute(0, 2, 1))
look("products", products)


products $\begin{bmatrix} \begin{bmatrix} 0.0569 & 0.482\end{bmatrix}\end{bmatrix}$

In [40]:
alphas = F.softmax(products, dim=-1)
look("alphas", alphas)

def calc_alphas(ks, q):
    # N, 1, H x N, H, L -> N, 1, L
    products = torch.bmm(q, ks.permute(0, 2, 1)) 
    alphas = F.softmax(products, dim=-1)
    return alphas


look("calc alphas", calc_alphas(keys, query))


alphas $\begin{bmatrix} \begin{bmatrix} 0.395 & 0.605\end{bmatrix}\end{bmatrix}$

calc alphas $\begin{bmatrix} \begin{bmatrix} 0.395 & 0.605\end{bmatrix}\end{bmatrix}$

In [45]:
def calc_alphas(ks, q):
    dims = q.size(-1)
    print(dims)
    # N, 1, H x N, H, L -> N, 1, L
    products = torch.bmm(q, ks.permute(0, 2, 1)) 
    scaled_products = products / dims ** 0.5
    alphas = F.softmax(scaled_products, dim=-1) 
    return alphas

alphas = calc_alphas(keys, query)
look("alphas", alphas)
context_vector = torch.bmm(alphas, values) 
look("context_vector", context_vector)


2


alphas $\begin{bmatrix} \begin{bmatrix} 0.425 & 0.575\end{bmatrix}\end{bmatrix}$

context_vector $\begin{bmatrix} \begin{bmatrix} 0.214 & -0.318\end{bmatrix}\end{bmatrix}$

In [None]:
class Attention(nn.Module):
    def __init__(self, hidden_dim, input_dim=None, proj_values=False):
        super().__init__()
        self.d_k = hidden_dim
        self.input_dim = hidden_dim \
                         if input_dim is None \
                         else \
                         input_dim

        self.proj_values = proj_values
        # Affine transformations for Q, K, and V
        self.linear_query = torch.nn.Linear(self.input_dim, hidden_dim)
        self.linear_key = torch.nn.Linear(self.input_dim, hidden_dim)
        self.linear_value = torch.nn.Linear(self.input_dim, hidden_dim)
        self.alphas = None

    def init_keys(self, keys):
        self.keys = keys
        self.proj_keys = self.linear_key(self.keys)
        self.values = self.linear_value(self.keys) \
                      if self.proj_values \
                      else \
                      self.keys

    def score_function(self, query):
        proj_query = self.linear_query(query)
        # scaled dot product
        # N, 1, H x N, H, L -> N, 1, L
        dot_products = torch.bmm(proj_query,
                                 self.proj_keys.permute(0, 2, 1))
        scores = dot_products / self.d_k ** 0.5 
        return scores

    def forward(self, query, mask=None):
        # Query is batch-first N, 1, H
        scores = self.score_function(query) # N, 1, L 1 
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9) 
        alphas = F.softmax(scores, dim=-1) # N, 1, L 2 
        self.alphas = alphas.detach()
        # N, 1, L x N, L, H -> N, 1, H
        context = torch.bmm(alphas, self.values)
        return context        

In [7]:
w1 = torch.tensor([1, 0, 0])
w2 = torch.tensor([0, 1, 0])
w3 = torch.tensor([1, 1, 0])
w4 = torch.tensor([0, 0, 1])

Wq = torch.randint(3, (3, 3))
Wk = torch.randint(3, (3, 3))
Wv = torch.randint(3, (3, 3))
look(Wq, Wk, Wv)

$\begin{bmatrix} 2 & 2 & 1 \\ 1 & 0 & 2 \\ 2 & 1 & 1\end{bmatrix}$ $\begin{bmatrix} 1 & 1 & 2 \\ 2 & 0 & 0 \\ 1 & 2 & 2\end{bmatrix}$ $\begin{bmatrix} 0 & 0 & 2 \\ 1 & 0 & 2 \\ 2 & 0 & 0\end{bmatrix}$

In [38]:
lengths = [len(x_i) for x_i in seq]
packed_batch_first = pack_padded_sequence(padded_batch_first, lengths, batch_first=True)
packed_batch_second = pack_padded_sequence(padded_batch_second, lengths, batch_first=False)


look(packed_batch_first.data, packed_batch_first.batch_sizes)
look(packed_batch_second.data, packed_batch_second.batch_sizes)



$\begin{bmatrix} 1 & 5 & 8 & 2 & 6 & 3 & 7 & 4\end{bmatrix}$ $\begin{bmatrix} 3 & 2 & 2 & 1\end{bmatrix}$

$\begin{bmatrix} 1 & 5 & 8 & 2 & 6 & 3 & 7 & 4\end{bmatrix}$ $\begin{bmatrix} 3 & 2 & 2 & 1\end{bmatrix}$

In [43]:
unpacked_first, batch_first_sizes = pad_packed_sequence(packed_batch_first, batch_first=True)

look(unpacked_first, batch_first_sizes)

$\begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 0 \\ 8 & 0 & 0 & 0\end{bmatrix}$ $\begin{bmatrix} 4 & 3 & 1\end{bmatrix}$

In [42]:
unpacked_second, batch_second_sizes = pad_packed_sequence(packed_batch_second, batch_first=False)

look(unpacked_second, batch_second_sizes)

$\begin{bmatrix} 1 & 5 & 8 \\ 2 & 6 & 0 \\ 3 & 7 & 0 \\ 4 & 0 & 0\end{bmatrix}$ $\begin{bmatrix} 4 & 3 & 1\end{bmatrix}$

# Torch Attention

```
torch.nn.MultiheadAttention(embed_dim, 
                            num_heads, 
                            dropout=0.0, 
                            bias=True, 
                            add_bias_kv=False, 
                            add_zero_attn=False, 
                            kdim=None, 
                            vdim=None, 
                            batch_first=False, 
                            device=None, 
                            dtype=None)
```

In [37]:
torch.manual_seed(42)
embed_dim = 2
mhe = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=1)

look("## Model parameters")
in_proj_weight = mhe.in_proj_weight.data.detach().clone()
in_proj_bias = mhe.in_proj_bias.data.detach().clone()
look("$W_{in}=$", in_proj_weight, "$b_{in}=$", in_proj_bias)

out_proj_weight = mhe.out_proj.weight.detach().clone()
out_proj_bias = mhe.out_proj.bias.detach().clone()
look("$W_{out}=$", out_proj_weight, "$b_{out}=$", out_proj_bias)
look("<hr>")

Wq, Wk, Wv = torch.chunk(in_proj_weight, 3, dim=0)
bq, bk, bv = torch.chunk(in_proj_bias, 3, dim=0)
look("$W_q=$", Wq, "$b_q=$", bq)
look("$W_k=$", Wk, "$b_k=$", bk)
look("$W_v=$", Wv, "$b_v=$", bv)


look("## Data")
L = 3
q = torch.rand(L, embed_dim)
k = torch.rand(L, embed_dim)
v = torch.rand(L, embed_dim)
look("$q=$", q, "$k=$", k, "$v=$", v)

look("## Process")
# pq = torch.sum(Wq * q, dim=1) + bq
# pk = torch.sum(Wk * k, dim=1) + bk
# pv = torch.sum(Wv * v, dim=1) + bv
# look(pq, pk, pv)

# torch.bmm(pq, self.proj_keys.permute(0, 2, 1))


## Model parameters

$W_{in}=$ $\begin{bmatrix} -0.422 & 0.509 \\ 0.763 & -0.635 \\ 0.753 & 0.162 \\ 0.64 & 0.117 \\ 0.418 & -0.122 \\ 0.668 & 0.128\end{bmatrix}$ $b_{in}=$ $\begin{bmatrix} 0.0 & 0.0 & 0.0 & 0.0 & 0.0 & 0.0\end{bmatrix}$

$W_{out}=$ $\begin{bmatrix} 0.541 & 0.587 \\ -0.166 & 0.65\end{bmatrix}$ $b_{out}=$ $\begin{bmatrix} 0.0 & 0.0\end{bmatrix}$

<hr>

$W_q=$ $\begin{bmatrix} -0.422 & 0.509 \\ 0.763 & -0.635\end{bmatrix}$ $b_q=$ $\begin{bmatrix} 0.0 & 0.0\end{bmatrix}$

$W_k=$ $\begin{bmatrix} 0.753 & 0.162 \\ 0.64 & 0.117\end{bmatrix}$ $b_k=$ $\begin{bmatrix} 0.0 & 0.0\end{bmatrix}$

$W_v=$ $\begin{bmatrix} 0.418 & -0.122 \\ 0.668 & 0.128\end{bmatrix}$ $b_v=$ $\begin{bmatrix} 0.0 & 0.0\end{bmatrix}$

## Data

$q=$ $\begin{bmatrix} 0.267 & 0.627 \\ 0.27 & 0.441 \\ 0.297 & 0.832\end{bmatrix}$ $k=$ $\begin{bmatrix} 0.105 & 0.269 \\ 0.359 & 0.199 \\ 0.547 & 0.00616\end{bmatrix}$ $v=$ $\begin{bmatrix} 0.952 & 0.0753 \\ 0.886 & 0.583 \\ 0.338 & 0.809\end{bmatrix}$

## Process

In [38]:
mhe.forward(q, k, v)

(tensor([[0.4518, 0.3147],
         [0.4517, 0.3146],
         [0.4518, 0.3147]], grad_fn=<SqueezeBackward1>),
 tensor([[0.3319, 0.3336, 0.3345],
         [0.3316, 0.3336, 0.3348],
         [0.3320, 0.3336, 0.3344]], grad_fn=<SqueezeBackward1>))