In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#default_exp dl.utils.relative_positional
from nbdev.showdoc import show_doc

This notebook is about implementing relative positional encoding introduced in [Shaw et al (2018)](https://arxiv.org/pdf/1803.02155.pdf) and refined by [Huang et al (2018)](https://arxiv.org/pdf/1809.04281.pdf)

Shaw et al. originally introduce two sets of learned parameters (each of size $L^2 D$, where $L$ is the sequence length and $D$ is the hidden dimension) that are added to both the keys and the values when computing attention.

$$
e_{ij} = \frac{x_i W^Q(x_jW^K + a_{ij}^K)^T}{\sqrt{d}} \tag{1}
$$

$$
\alpha_{ij} = \frac{e^{e_{ij}}}{\sum_{k=1}^{k=n}e^{e_{ik}}}
$$

And also to the values:
$$
z_i = \sum_{j=1}^{j=n}\alpha_{ij}(x_j W^V + a_{ij}^V) \tag{2}
$$

Equation 1 is equivalent to

$$
\text{RelativeAttention} = \text{Softmax} \left( \frac{Q K^\top + S_{rel}}{\sqrt{D_h}} \right) V \tag{3}
$$

where $R$ plays the role of $a^K$ in equation 1 (so $R^\top$ corresponds to $(a^K)^\top$):
$$
S_{rel} = Q R^\top
$$

To address the memory concern, Huang et al. proposed a skew algorithm that computes $S_{rel}$ directly, without using the intermediate $R$ at all, which cuts the space requirement down to that of the position embeddings alone. While the idea of not materializing $R$ is sound, I think the paper made a mistake in how to compute $S_{rel}$: I don't think it can be computed by adding one column and then shifting.

![Skew algorithm ](img/skew_alg.png)


Here is an illustration of how the paper's approach works.


In [6]:
import numpy as np

In [131]:
# Sequence length.
L = 5

# Hidden dimension.
D = 1

# All-ones queries: every row of Q @ Er.T is then simply Er laid out as a row.
Q = np.ones((L, D))
# Er holds the relative positional embeddings, one value per offset in
# -(L-1) .. L-1, stored as a column of shape (2L-1, 1).
Er = np.arange(1 - L, L)[:, np.newaxis]

In [148]:
Q,Er

(array([[1.],
        [1.],
        [1.],
        [1.],
        [1.]]),
 array([[-4],
        [-3],
        [-2],
        [-1],
        [ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4]]))

In [149]:
QEr = Q @ Er.T

In [150]:
QEr

array([[-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.],
       [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.],
       [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.],
       [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.],
       [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.]])

In [151]:
def naive_skew(QEr):
    """Reference skew: slice each query's window of relative-position scores.

    Column ``k - q + seq_len - 1`` of row ``q`` is read as the score between
    query ``q`` and key ``k``.  For key 0 that column is ``seq_len - 1 - q``,
    so each row contributes the ``seq_len`` consecutive columns starting there.

    Args:
        QEr: 2-D array; its row count is used as the sequence length.

    Returns:
        ``np.array`` built from the per-row slices, one row per query.
    """
    seq_len, _ = QEr.shape  # unpacking also rejects inputs that are not 2-D
    windows = [
        QEr[row, seq_len - 1 - row : 2 * seq_len - 1 - row]
        for row in range(seq_len)
    ]
    return np.array(windows)

In [152]:
ideal_sel = naive_skew(QEr)
ideal_sel

array([[ 0.,  1.,  2.,  3.,  4.],
       [-1.,  0.,  1.,  2.,  3.],
       [-2., -1.,  0.,  1.,  2.],
       [-3., -2., -1.,  0.,  1.],
       [-4., -3., -2., -1.,  0.]])

In [167]:
# Illustration of the pad-and-reshape trick described in the paper.
# The paper's wording is "pad a dummy column vector of length L before the
# leftmost column"; this cell instead appends the dummy column AFTER the
# rightmost column (pad widths (0, 1) on axis 1).  The fill value -8 does not
# occur in QEr, so the padded entries are easy to spot in the output.
QEr_pad = np.pad(QEr, ((0, 0), (0, 1)), constant_values=(-8, -8))
QEr_pad

array([[-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4., -8.],
       [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4., -8.],
       [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4., -8.],
       [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4., -8.],
       [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4., -8.]])

In [168]:
m, n = QEr_pad.shape
m, n

(5, 10)

In [169]:

# Re-read the padded (m, n) matrix as n rows, drop the last of those rows,
# then fold the remaining elements back into m rows.
QEr_shaped = QEr_pad.reshape(n, -1)
Sel = QEr_shaped[:-1].reshape(m, -1)

In [173]:
# Keep only the last m columns of the re-folded matrix; this is the skewed result.
Sel = Sel[:, -m:]
Sel

array([[ 0.,  1.,  2.,  3.,  4.],
       [-1.,  0.,  1.,  2.,  3.],
       [-2., -1.,  0.,  1.,  2.],
       [-3., -2., -1.,  0.,  1.],
       [-4., -3., -2., -1.,  0.]])

In [174]:
ideal_sel

array([[ 0.,  1.,  2.,  3.,  4.],
       [-1.,  0.,  1.,  2.,  3.],
       [-2., -1.,  0.,  1.,  2.],
       [-3., -2., -1.,  0.,  1.],
       [-4., -3., -2., -1.,  0.]])

In [175]:
np.allclose(ideal_sel, Sel)

True

In [107]:
QEr = Q @ Er.T

def skew(QEr):
    """Vectorised skew (NumPy): gather ``QEr[q, k - q + n - 1]`` for every (q, k).

    Args:
        QEr: array whose first dimension is used as the sequence length ``n``.

    Returns:
        ``(n, n)`` array whose entry ``[q, k]`` is ``QEr[q, k - q + n - 1]``.
    """
    n = QEr.shape[0]
    queries = np.arange(n).reshape(-1, 1)  # query positions as a column
    keys = np.arange(n).reshape(1, -1)     # key positions as a row
    # Broadcasting the column against the row yields an (n, n) grid of lookups.
    return QEr[queries, keys - queries + (n - 1)]


In [108]:
sel = skew(QEr)
sel

array([[ 0.,  1.,  2.,  3.,  4.],
       [-1.,  0.,  1.,  2.,  3.],
       [-2., -1.,  0.,  1.,  2.],
       [-3., -2., -1.,  0.,  1.],
       [-4., -3., -2., -1.,  0.]])

In [111]:
np.allclose(ideal_sel, sel)

True

In [178]:
# torch version
import torch

In [179]:
# L is seq len
L = 5

# Hidden dimsion
D = 1

Q = torch.ones((L, D), dtype=torch.float32)
# Er is relative positional embeding
Er = torch.arange(-L+1, L, dtype=torch.float32).reshape((-1, D))
Q, Er

(tensor([[1.],
         [1.],
         [1.],
         [1.],
         [1.]]),
 tensor([[-4.],
         [-3.],
         [-2.],
         [-1.],
         [ 0.],
         [ 1.],
         [ 2.],
         [ 3.],
         [ 4.]]))

In [180]:
def skew(QEr):
    """Vectorised skew (PyTorch): gather ``QEr[..., q, k - q + seq_len - 1]``.

    The last two dimensions are treated as (query position, relative-offset
    column).  Any leading dimensions (e.g. batch, heads) are carried through
    unchanged, so a plain 2-D matrix behaves exactly as before.

    Args:
        QEr: tensor of shape ``(..., seq_len, n_offsets)``; the lookup reads
            columns ``0 .. 2 * seq_len - 2``.

    Returns:
        Tensor of shape ``(..., seq_len, seq_len)`` whose entry ``[..., q, k]``
        is ``QEr[..., q, k - q + seq_len - 1]``.
    """
    seq_len = QEr.shape[-2]
    # Build the index tensors on the input's device so the gather does not
    # mix devices when QEr is not on the default one.
    q_ind = torch.arange(seq_len, device=QEr.device)[:, None]
    k_ind = torch.arange(seq_len, device=QEr.device)[None, :]
    col_ind = k_ind - q_ind + seq_len - 1
    # The ellipsis keeps every leading dimension intact.
    return QEr[..., q_ind, col_ind]

In [181]:
QEr = Q @ Er.T
QEr

tensor([[-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.],
        [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.],
        [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.],
        [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.],
        [-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.]])

In [182]:
skew(QEr)

tensor([[ 0.,  1.,  2.,  3.,  4.],
        [-1.,  0.,  1.,  2.,  3.],
        [-2., -1.,  0.,  1.,  2.],
        [-3., -2., -1.,  0.,  1.],
        [-4., -3., -2., -1.,  0.]])

In [190]:
import torch.nn.functional as F

def shift_right2(QEr: torch.Tensor):
    """
    Skew a 2-D matrix with the pad-and-reshape trick.

    A zero column is appended on the right, the buffer is re-read with the
    two dimensions swapped, the last row of that view is dropped, and the
    remainder is viewed with the input's shape again.  At that point the
    $i^{th}$ row has been shifted right by $i$ columns: for
    `[[1, 2 ,3], [4, 5 ,6], [7, 8, 9]]` this gives
    `[[1, 2 ,3], [0, 4, 5], [6, 0, 7]]`.
    Finally only the last `QEr.shape[0]` columns are returned.
    *The wrapped-around entries are not masked out.*
    """

    n_rows, n_cols = QEr.shape
    # (n_rows, n_cols + 1): one extra zero column on the right.
    widened = F.pad(QEr, (0, 1))

    # Same buffer read as (n_cols + 1, n_rows); dropping the final row leaves
    # exactly n_rows * n_cols elements, which fit the input's shape again.
    trimmed = widened.view(n_cols + 1, n_rows)[:-1]
    shifted = trimmed.view_as(QEr)
    return shifted[:, -n_rows:]

In [191]:
shift_right2(QEr)

tensor([[ 0.,  1.,  2.,  3.,  4.],
        [-1.,  0.,  1.,  2.,  3.],
        [-2., -1.,  0.,  1.,  2.],
        [-3., -2., -1.,  0.,  1.],
        [-4., -3., -2., -1.,  0.]])

In [192]:
def shift_right(x: torch.Tensor):
    """
    Skew a matrix (optionally with trailing dimensions) by pad-and-reshape.

    A slice of zeros is concatenated along dim 1, the buffer is re-read with
    the first two dimensions swapped, the last slice along dim 0 is dropped,
    and the remainder is viewed with the input's shape again.  At that point
    the $i^{th}$ row has been shifted right by $i$ columns: for
    `[[1, 2 ,3], [4, 5 ,6], [7, 8, 9]]` this gives
    `[[1, 2 ,3], [0, 4, 5], [6, 0, 7]]`.
    Finally only the last `x.shape[0]` columns are returned.
    *The wrapped-around entries are not masked out.*
    """

    n_rows = x.shape[0]
    trailing = x.shape[2:]

    # One extra zero column along dim 1, same dtype/device as the input.
    widened = torch.cat([x, x.new_zeros(n_rows, 1, *trailing)], dim=1)

    # widened.shape[1] is the original column count plus one.  Swap the two
    # leading sizes, drop the surplus slice, and restore the input's shape.
    skewed = widened.view(widened.shape[1], n_rows, *trailing)[:-1].view_as(x)
    return skewed[:, -n_rows:]

In [193]:
shift_right(QEr)

tensor([[ 0.,  1.,  2.,  3.,  4.],
        [-1.,  0.,  1.,  2.,  3.],
        [-2., -1.,  0.,  1.,  2.],
        [-3., -2., -1.,  0.,  1.],
        [-4., -3., -2., -1.,  0.]])

In [185]:
x_= torch.tensor([[1, 2 ,3], [4, 5 ,6], [7, 8, 9]])
x_


tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

In [186]:
shift_right(x_)

tensor([[1, 2, 3],
        [0, 4, 5],
        [6, 0, 7]])

In [203]:
%%time
for i in range(10000):
    shift_right2(QEr)

CPU times: total: 297 ms
Wall time: 325 ms


In [204]:
%%time
for i in range(10000):
    shift_right(QEr)

CPU times: total: 328 ms
Wall time: 306 ms


In [205]:
%%time
for i in range(10000):
    skew(QEr)

CPU times: total: 578 ms
Wall time: 569 ms
