# Retentive Network

## LLM Challenges
Retentive Network (RetNet) is a foundational architecture proposed for large language models in the paper [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/abs/2307.08621). This architecture is designed to address key challenges in the realm of large-scale language modeling: training parallelism, low-cost inference, and good performance.

* **Training parallelism**: During training, all input tokens are processed at the same time, which exploits the GPU's parallel processing power. The Transformer is one example: with causal masking, the self-attention inside the decoder computes the outputs for every position in one pass, without waiting for earlier positions to be generated.
* **Low-cost inference**: During inference, the cost of generating each token does not grow with the sequence length. The Recurrent Neural Network (RNN) is one example: it processes one token per time step using simple, cheap operations such as matrix multiplications.
* **Good performance**: Transformers and RNNs both perform well, but Transformers have high-cost inference and RNN training is not parallelizable. Linear transformers, on the other hand, are parallelizable and their inference is made cheap by sequential processing, but they perform poorly.

RetNet addresses all of these challenges thanks to its multi-scale retention mechanism, which will be explained below.

Read more about these explanations [here](https://medium.com/ai-fusion-labs/retentive-networks-retnet-explained-the-much-awaited-transformers-killer-is-here-6c17e3e8add8).


## RetNet Architecture

<center><img src="Retention_imgs/RetNet Architecture.jpg" style="height: 400px; width:auto;"></center>


Retentive Network has a similar architecture to the Transformer's encoder, but with a few differences:
    <ol>
    <li> Layer normalization (LN) is applied before the token mixer layer (the multi-scale retention (MSR) layer) and before the Feed Forward Network (FFN), as the layer equations below show.
    <li> The multi-scale retention layer (the proposed method) replaces the multi-head attention layer.
    </ol>
    
The input sequence $\{x_i\}_{i=1}^{|x|}$ is transformed to vectors by a word embedding layer and then packed to form a matrix $X^0=[x_1, \cdots, x_{|x|}] \in \mathbb{R}^{|x| \times d_{model}}$. The computation of the output of the $l$-th layer is as follows:

$$
\begin{aligned}
    Y^l &= \text{MSR}(\text{LN}(X^l)) + X^l \\
    X^{l+1} &= \text{FFN}(\text{LN}(Y^l)) + Y^l
\end{aligned}
$$

where $\text{LN}$ is LayerNorm and $\text{FFN}$ is the feed forward network computed as $\text{FFN}(X) = \text{gelu}(XW_1)W_2$ where $W_1$ and $W_2$ are learnable parameters. 
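
The two layer equations can be sketched as a small PyTorch module. This is only an illustration under stated assumptions: the class name `RetNetBlockSketch`, the hidden size `d_ffn`, and the use of `nn.Identity` as a stand-in for the MSR layer are hypothetical and not taken from the original implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class RetNetBlockSketch(nn.Module):
    """Pre-LayerNorm residual block: Y = MSR(LN(X)) + X, X_next = FFN(LN(Y)) + Y."""

    def __init__(self, d_model: int, d_ffn: int, token_mixer: nn.Module):
        super().__init__()
        self.ln_mixer = nn.LayerNorm(d_model)
        self.ln_ffn = nn.LayerNorm(d_model)
        self.token_mixer = token_mixer  # placeholder for the MSR layer
        # FFN(X) = gelu(X W_1) W_2, bias-free to match the formula
        self.w1 = nn.Linear(d_model, d_ffn, bias=False)
        self.w2 = nn.Linear(d_ffn, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.token_mixer(self.ln_mixer(x)) + x              # first residual branch
        return self.w2(F.gelu(self.w1(self.ln_ffn(y)))) + y     # second residual branch


# Hypothetical sizes: |x| = 2 tokens, d_model = 4
block = RetNetBlockSketch(d_model=4, d_ffn=8, token_mixer=nn.Identity())
x0 = torch.randn(2, 4)
x1 = block(x0)
print(x1.shape)  # torch.Size([2, 4])
```

Both residual branches keep the shape $|x| \times d_{model}$, which is what allows the blocks to be stacked.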

Read more:
- [Embedding layer](https://lena-voita.github.io/nlp_course/word_embeddings.html)
- [gelu](https://arxiv.org/pdf/1606.08415v5.pdf)
- [LayerNorm](https://arxiv.org/abs/1607.06450)

## Multi-scale Retention 

<center><img src="Retention_imgs/Multi Scale Retention.jpg" style="height: 800px; width:auto;"></center>

The detail of the computation of the multi-scale retention layer is as follows:

$$
\begin{aligned}
    \gamma &= 1-2^{-5-arange(0,h)} \\
    head_i &= \text{Retention}(X, \gamma_i) \\
    Y &= GroupNorm_h(Concat(head_1, \cdot\cdot\cdot, head_h)) \\
    \text{MSR}(X) &= (swish(XW_G) \odot Y)W_O
\end{aligned}
$$

where $h=\frac{d_{model}}{d}$ is the number of heads, $d$ is the head dimension, $arange(0,h)$ is the range of integers from 0 to $h-1$, $\gamma$ is the retention scale (one decay value per head), $head_i$ is the output of the $i$-th head, $GroupNorm_h$ is the group normalization applied on each head, $\odot$ is the element-wise multiplication, $W_G$ and $W_O$ are learnable parameters, and $swish$ is the swish activation function.

Let's consider a case like the one in the figure above, where $h=3$, $|x|=2$, and $d_{model}=4$. (In practice $d_{model}$ is a multiple of the head dimension $d$, so the number of heads is an integer; $d_{model}=4$ is used here only for illustration.) Since there are three heads, there are three different values of gamma, $\gamma_1=1-2^{-5-0}=0.96875$, $\gamma_2=1-2^{-5-1}=0.984375$, and $\gamma_3=1-2^{-5-2}=0.9921875$, one applied to each head. These gammas are fixed and identical across layers.
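
The formula for $\gamma$ can be evaluated directly. The snippet below is a minimal check for $h=3$; the variable names are chosen here for illustration.

```python
import torch

h = 3  # number of heads in the worked example
# gamma_i = 1 - 2^(-5 - i) for i = 0, ..., h - 1
gamma = 1 - 2.0 ** (-5 - torch.arange(0, h, dtype=torch.float32))
print(gamma.tolist())  # [0.96875, 0.984375, 0.9921875]
```

Heads with a larger index get a $\gamma$ closer to 1, so they decay more slowly and keep information from more distant tokens.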

Then the input goes through each retention head, and each head's output is fed to the GroupNorm layer. GroupNorm is applied to each head separately because the heads use different values of $\gamma$, so each head has different variance statistics that need to be normalized. After that, the results are concatenated and multiplied element-wise with the output of the swish gate to increase the non-linearity of the retention layers. Finally, the output is projected by multiplying it with $W_O$ so that it has the same dimension as the input, which is $2 \times 4$.
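
The normalize, gate, and project steps can be sketched with standard PyTorch modules. This is a rough illustration, not the reference implementation: the retention heads are replaced by random tensors with deliberately different scales, and $h=2$, $d=2$ are assumed so that $d_{model}=4$ divides evenly.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
seq_len, h, d = 2, 2, 2          # assumed sizes; d_model = h * d = 4
d_model = h * d
x = torch.randn(seq_len, d_model)

# Stand-ins for the retention heads: same shape, very different scales
heads = [torch.randn(seq_len, d) * scale for scale in (1.0, 50.0)]
concat = torch.cat(heads, dim=-1)                 # (seq_len, d_model)

# One group per head, so each head is normalized with its own statistics
group_norm = nn.GroupNorm(num_groups=h, num_channels=d_model, affine=False)
y = group_norm(concat)

w_g = nn.Linear(d_model, d_model, bias=False)     # W_G
w_o = nn.Linear(d_model, d_model, bias=False)     # W_O
msr_out = w_o(F.silu(w_g(x)) * y)                 # (swish(X W_G) * Y) W_O
print(msr_out.shape)  # torch.Size([2, 4])
```

After normalization, the head that was 50 times larger no longer dominates the gated product.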

Read more:
- [GroupNorm](https://arxiv.org/abs/1803.08494)
- [swish](https://arxiv.org/pdf/1710.05941v1.pdf?source=post_page)

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
import copy


def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einops notation: rearrange(x, '... d j -> ... (d j)')

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

def get_activation_fn(activation):
    if activation == "swish":
        return F.silu
    elif activation == "gelu":
        return F.gelu
    else:
        raise NotImplementedError

def MultiwayWrapper(args, module, dim=1):
    if args.multiway:
        return MultiwayNetwork(module, dim=dim)
    return module

class MultiwayNetwork(nn.Module):
    def __init__(self, module, dim=1):
        super().__init__()
        self.dim = dim
        self.A = module
        self.B = copy.deepcopy(module)
        self.B.reset_parameters()
        self.split_position = -1

    def forward(self, x, **kwargs):
        if self.split_position == -1:
            return self.A(x, **kwargs)
        if self.split_position == 0:
            return self.B(x, **kwargs)
        x1, x2 = torch.split(
            x,
            [self.split_position, x.size(self.dim) - self.split_position],
            dim=self.dim,
        )
        # x1, x2 = x[:self.split_position], x[self.split_position:]
        y1, y2 = self.A(x1, **kwargs), self.B(x2, **kwargs)
        return torch.cat([y1, y2], dim=self.dim)

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output

class MultiScaleRetention(nn.Module):
    def __init__(
        self,
        args,
        embed_dim,
        value_dim,
        num_heads,
        gate_fn="swish",
    ):
        super().__init__()
        self.args = args
        self.embed_dim = embed_dim
        self.value_dim = value_dim
        self.num_heads = num_heads
        self.head_dim = self.value_dim // num_heads
        self.key_dim = self.embed_dim // num_heads
        self.scaling = self.key_dim ** -0.5
        
        self.gate_fn = get_activation_fn(activation=str(gate_fn))

        self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
        self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
        self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
        self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
        
        self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))

        self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.k_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.v_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.g_proj.weight, gain=2 ** -2.5)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=2 ** -1)

    def parallel_forward(self, qr, kr, v, mask):
        bsz, tgt_len, embed_dim = v.size()

        vr = v.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)

        qk_mat = qr @ kr.transpose(-1, -2) # bsz * m * tgt_len * tgt_len
        qk_mat = qk_mat * mask
        # invariant after normalization
        qk_mat = qk_mat / qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1, max=5e4)
        output = torch.matmul(qk_mat, vr)
        output = output.transpose(1, 2)
        return output

    def recurrent_forward(
        self,
        qr, kr, v,
        decay,
        incremental_state
    ):
        bsz = v.size(0)

        v = v.view(bsz, self.num_heads, self.head_dim, 1)
        kv = kr * v
        if "prev_key_value" in incremental_state:
            prev_kv = incremental_state["prev_key_value"]
            prev_scale = incremental_state["scale"]
            scale = prev_scale * decay + 1
            kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(self.num_heads, 1, 1) + kv / scale.sqrt().view(self.num_heads, 1, 1)
            # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
        else:
            scale = torch.ones_like(decay)

        incremental_state["prev_key_value"] = kv
        incremental_state["scale"] = scale

        output = torch.sum(qr * kv, dim=3)
        return output
    
    def chunk_recurrent_forward(
        self,
        qr, kr, v,
        inner_mask
    ):
        mask, cross_decay, query_inner_decay, value_inner_decay = inner_mask
        bsz, tgt_len, embed_dim = v.size()
        chunk_len = mask.size(1)
        num_chunks = tgt_len // chunk_len

        assert tgt_len % chunk_len == 0

        qr = qr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
        kr = kr.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(1, 2)
        v = v.view(bsz, num_chunks, chunk_len, self.num_heads, self.head_dim).transpose(2, 3)

        kr_t = kr.transpose(-1, -2)

        qk_mat = qr @ kr_t # bsz * num_heads * chunk_len * chunk_len
        qk_mat = qk_mat * mask
        inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
        qk_mat = qk_mat / inner_scale
        inner_output = torch.matmul(qk_mat, v) # bsz * num_heads * num_value_heads * chunk_len * head_dim
        
        # reduce kv in one chunk
        kv = kr_t @ (v * value_inner_decay)

        kv_recurrent = []
        cross_scale = []
        kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
        kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
        
        # accumulate kv by loop
        for i in range(num_chunks):
            kv_recurrent.append(kv_state / kv_scale)
            cross_scale.append(kv_scale)
            kv_state = kv_state * cross_decay + kv[:, i]
            kv_scale = kv_state.detach().abs().sum(dim=-2, keepdim=True).max(dim=-1, keepdim=True).values.clamp(min=1)
            
        kv_recurrent = torch.stack(kv_recurrent, dim=1)
        cross_scale = torch.stack(cross_scale, dim=1)
        
        all_scale = torch.maximum(inner_scale, cross_scale)
        align_inner_scale = all_scale / inner_scale
        align_cross_scale = all_scale / cross_scale

        cross_output = (qr * query_inner_decay) @ kv_recurrent
        output = inner_output / align_inner_scale + cross_output / align_cross_scale
        # output = inner_output / cross_scale + cross_output / inner_scale

        output = output.transpose(2, 3)
        return output
    
    def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
        bsz, tgt_len, _ = x.size()
        (sin, cos), inner_mask = rel_pos

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        g = self.g_proj(x)

        k *= self.scaling
        q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
        k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

        if incremental_state is not None:
            output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
        elif chunkwise_recurrent:
            output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
        else:
            output = self.parallel_forward(qr, kr, v, inner_mask)
        
        output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)

        output = self.gate_fn(g) * output

        output = self.out_proj(output)

        return output


### Going line by line in the `forward` function:
**Variables**
The following variables are used throughout the computation:
1. Batch size (bsz): The batch size of the input.
2. Target len / Sequence len (tgt_len): The length of the sequence.
3. (sin, cos): Sine and cosine of the rotation angles used for the positional embedding.
4. inner_mask: The masking matrix. Depending on the mode used, it can also contain other constants.
```python
bsz, tgt_len, _ = x.size()
(sin, cos), inner_mask = rel_pos
```

**Key, Query, Value**
Obtain the query, key, and value representations of the input, together with the gate input `g`, by multiplying the input with learnable matrices.
```python
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
g = self.g_proj(x)
```
The projections are declared in the `__init__` function:
```python
self.q_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=False))
self.v_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
self.g_proj = MultiwayWrapper(args, nn.Linear(embed_dim, value_dim, bias=False))
```

**Multi-Head**
Scale the keys, then divide the query and key matrices into several heads (similar to Multi-Head Attention).
```python
k *= self.scaling
q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
```
The number of heads and key dimension are declared in the `__init__` function.
```python
self.num_heads = num_heads
self.head_dim = self.value_dim // num_heads
self.key_dim = self.embed_dim // num_heads
```
The multi-head operation is also applied to the value matrix, but it is implemented inside the representation calculation (parallel, recurrent, chunkwise recurrent).
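
To make the shapes concrete, here is a small sketch of the head split. The sizes are hypothetical, and plain `nn.Linear` layers are used in place of the `MultiwayWrapper` projections.

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration
bsz, tgt_len, embed_dim, value_dim, num_heads = 2, 5, 8, 16, 4
key_dim = embed_dim // num_heads     # 2
head_dim = value_dim // num_heads    # 4

x = torch.randn(bsz, tgt_len, embed_dim)
q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
v_proj = nn.Linear(embed_dim, value_dim, bias=False)

# (bsz, tgt_len, dim) -> (bsz, num_heads, tgt_len, dim_per_head)
q = q_proj(x).view(bsz, tgt_len, num_heads, key_dim).transpose(1, 2)
v = v_proj(x).view(bsz, tgt_len, num_heads, head_dim).transpose(1, 2)
print(q.shape)  # torch.Size([2, 4, 5, 2])
print(v.shape)  # torch.Size([2, 4, 5, 4])
```

Each head simply receives a contiguous slice of the projected features: head `i` of `q` holds columns `i * key_dim` to `(i + 1) * key_dim` of `q_proj(x)`.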

**Positional Embedding**
Apply the rotation of the Extrapolatable Position Embedding ([XPos](https://arxiv.org/abs/2212.10554)) to the queries and keys.
```python
qr = theta_shift(q, sin, cos)
kr = theta_shift(k, sin, cos)
```
The `theta_shift` function is defined outside the class:
```python
def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einops notation: rearrange(x, '... d j -> ... (d j)')

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)
```
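
A small numeric example may help to see what these two functions do. The functions are copied from above. The angles `theta`, the tensors `q` and `k`, and the helpers `pair_norms` and `score` are assumptions made for this illustration, with one angle per (even, odd) feature pair.

```python
import torch

def rotate_every_two(x):
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def theta_shift(x, sin, cos):
    return (x * cos) + (rotate_every_two(x) * sin)

x = torch.tensor([1.0, 2.0, 3.0, 4.0]).view(1, 1, 1, 4)
# Each pair (a, b) becomes (-b, a)
print(rotate_every_two(x).flatten().tolist())  # [-2.0, 1.0, -4.0, 3.0]

# Hypothetical angles: one per feature pair, repeated for both members of the pair
theta = torch.tensor([0.3, 1.1]).repeat_interleave(2)
rotated = theta_shift(x, torch.sin(theta), torch.cos(theta))

def pair_norms(t):
    # Length of each (even, odd) feature pair
    return t.reshape(-1, 2).norm(dim=-1)

# theta_shift rotates each pair, so pair lengths are unchanged
print(torch.allclose(pair_norms(rotated), pair_norms(x), atol=1e-5))  # True

torch.manual_seed(0)
q = torch.randn(1, 1, 1, 4)
k = torch.randn(1, 1, 1, 4)

def score(n, m):
    # Query rotated for position n, key rotated for position m
    qn = theta_shift(q, torch.sin(n * theta), torch.cos(n * theta))
    km = theta_shift(k, torch.sin(m * theta), torch.cos(m * theta))
    return (qn * km).sum()

# The query-key score depends only on the offset n - m
print(torch.allclose(score(5, 3), score(2, 0), atol=1e-5))  # True
```

Because the score depends only on the offset between positions, rotating queries and keys this way injects relative position information into the product of `qr` and `kr`.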

**Representation Calculation**
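Multi-Scale Retention selects one of three computation modes: the recurrent form when `incremental_state` is provided, the chunkwise recurrent form when `chunkwise_recurrent` is `True`, and the parallel form otherwise. In the recurrent case, the value passed as `inner_mask` is received by `recurrent_forward` as its `decay` argument.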
```python
if incremental_state is not None:
    output = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
elif chunkwise_recurrent:
    output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
else:
    output = self.parallel_forward(qr, kr, v, inner_mask)
```
Each of these methods is defined inside the `MultiScaleRetention` class and will be explained later.
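
As a rough intuition for why the modes are interchangeable, the sketch below compares a parallel and a recurrent computation of a single retention head, following the formulation in the RetNet paper. It is a simplified illustration: it omits the positional rotation, the key scaling, and the normalization steps used in `parallel_forward` and `recurrent_forward`, and all names and sizes are hypothetical.

```python
import torch

torch.manual_seed(0)
seq_len, key_dim, head_dim, gamma = 6, 4, 8, 0.96875
q = torch.randn(seq_len, key_dim)
k = torch.randn(seq_len, key_dim)
v = torch.randn(seq_len, head_dim)

# Parallel form: (Q K^T * D) V with D[n, m] = gamma^(n - m) for n >= m, else 0
idx = torch.arange(seq_len)
exponent = (idx[:, None] - idx[None, :]).float()
decay_mask = torch.tril(gamma ** exponent)
parallel_out = ((q @ k.T) * decay_mask) @ v

# Recurrent form: S_n = gamma * S_{n-1} + K_n^T V_n, o_n = Q_n S_n
state = torch.zeros(key_dim, head_dim)
outputs = []
for n in range(seq_len):
    state = gamma * state + torch.outer(k[n], v[n])
    outputs.append(q[n] @ state)
recurrent_out = torch.stack(outputs)

print(torch.allclose(parallel_out, recurrent_out, atol=1e-4))  # True
```

The parallel form processes the whole sequence at once, which suits training, while the recurrent form only carries a fixed-size `state` from one token to the next, which suits inference.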

**Normalization, Gating, and Projection**
Normalize each head's output, apply the gating function, and project the result back to the embedding dimension.
```python
output = self.group_norm(output).reshape(bsz, tgt_len, self.head_dim * self.num_heads)
output = self.gate_fn(g) * output
output = self.out_proj(output)
return output
```
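Each of these functions is declared in the `__init__` function. Note that `group_norm` is implemented here as an `RMSNorm` over the `head_dim` features of each head: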
```python
self.gate_fn = get_activation_fn(activation=str(gate_fn))
self.out_proj = MultiwayWrapper(args, nn.Linear(value_dim, embed_dim, bias=False))
self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))
```
and `get_activation_fn` is declared outside the class:
```python
def get_activation_fn(activation):
    if activation == "swish":
        return F.silu
    elif activation == "gelu":
        return F.gelu
    else:
        raise NotImplementedError
```
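
To see the effect of the per-head normalization and the gate, here is a functional sketch. The helper `rms_norm` mirrors the `_norm` method of `RMSNorm` without the affine weight. The tensor sizes and the artificially scaled second head are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

def rms_norm(x, eps=1e-6):
    # Same formula as RMSNorm._norm: divide by the root mean square over the last dim
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

torch.manual_seed(0)
bsz, tgt_len, num_heads, head_dim = 1, 3, 2, 4
# Layout of the tensor passed to group_norm: (bsz, tgt_len, num_heads, head_dim)
output = torch.randn(bsz, tgt_len, num_heads, head_dim)
output[:, :, 1] *= 100.0  # pretend the second head has a much larger scale

normed = rms_norm(output)
rms = normed.pow(2).mean(-1).sqrt()
print(rms)  # every entry is close to 1, for both heads

# Merge the heads, then apply the swish gate element-wise
flat = normed.reshape(bsz, tgt_len, num_heads * head_dim)
g = torch.randn(bsz, tgt_len, num_heads * head_dim)  # stand-in for g_proj(x)
gated = F.silu(g) * flat
print(gated.shape)  # torch.Size([1, 3, 8])
```

Since the normalization runs over the last dimension only, each head at each position is rescaled independently, which is the per-head behavior described for GroupNorm in the Multi-scale Retention section.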