In [4]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

# Compact, fixed-precision numpy printing for inspecting arrays.
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# TorchScript aliases used by the model class: it subclasses MyModule
# (torch.jit.ScriptModule) and decorates its per-layer mixing methods with
# MyFunction (torch.jit.script_method).
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method

RWKV-5 又叫 Eagle。

In [5]:
import torch

In [6]:
class RWKV_TOKENIZER():
    """Byte-level greedy tokenizer for the RWKV World vocabulary.

    Each vocab line has the form ``<id> <python str/bytes literal> <byte length>``;
    str literals are stored as their UTF-8 bytes. Encoding scans the input left to
    right and, at each position, takes the first candidate token (in reverse file
    order) from a table indexed by the next two bytes that prefixes the remaining
    input, falling back to the single byte.
    """
    table: list[list[list[bytes]]]  # table[b0][b1]: tokens of length >= 2 starting with bytes b0, b1
    good: list[set[int]]            # good[b0]: second bytes b1 for which table[b0][b1] is non-empty
    wlen: list[int]                 # wlen[b0]: length of the longest token starting with byte b0

    def __init__(self, file_name):
        self.idx2token = {}
        vocab_tokens = []  # tokens in file order; must be already sorted
        with open(file_name, "r", encoding="utf-8") as f:
            for l in f:
                if not l.strip():
                    continue  # tolerate blank lines (e.g. a trailing newline at EOF)
                idx = int(l[:l.index(' ')])
                # NOTE(review): eval() executes the vocab text as Python code; only load
                # trusted vocab files (ast.literal_eval also accepts str/bytes literals).
                x = eval(l[l.index(' '):l.rindex(' ')])
                x = x.encode("utf-8") if isinstance(x, str) else x
                assert isinstance(x, bytes)
                assert len(x) == int(l[l.rindex(' '):])  # declared byte length must match
                vocab_tokens.append(x)
                self.idx2token[idx] = x

        self.token2idx = {v: int(k) for k, v in self.idx2token.items()}

        # precompute some tables for fast matching
        self.table = [[[] for j in range(256)] for i in range(256)]
        self.good = [set() for i in range(256)]
        self.wlen = [0 for i in range(256)]

        for s in reversed(vocab_tokens): # reverse order - match longer tokens first
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1].append(s)
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)

    def encodeBytes(self, src: bytes) -> list[int]:
        """Tokenize raw bytes; raises KeyError if a single byte is not in the vocab."""
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]  # default: single-byte token

            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    sss: bytes = src[i : i + self.wlen[s0]]
                    # first candidate that prefixes the input; keep the single byte if none does
                    s = next(filter(sss.startswith, self.table[s0][s1]), s)
            tokens.append(self.token2idx[s])
            i += len(s)

        return tokens

    def decodeBytes(self, tokens):
        """Concatenate the byte strings of `tokens` (KeyError for ids not in the vocab)."""
        return b''.join(self.idx2token[i] for i in tokens)

    def encode(self, src: str):
        """Encode `src` as UTF-8 and tokenize it."""
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        """Decode token ids to str; raises UnicodeDecodeError on incomplete/invalid UTF-8."""
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        """Print each token's repr immediately followed by its id, space-separated."""
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError:
                pass  # not valid UTF-8 on its own: show the raw bytes repr instead
            print(f'{repr(s)}{i}', end=' ')
        print()

########################################################################################################

In [7]:
def sample_logits(out, temperature=1.0, top_p=0.8):
    """Sample one token id from logits with nucleus (top-p) filtering and temperature.

    Args:
        out: 1-D torch tensor of logits (must be convertible with `.numpy()`, i.e. on CPU).
        temperature: after the top-p cut, probabilities are raised to 1/temperature;
            values < 1 sharpen the distribution, values > 1 flatten it.
        top_p: keep the most likely tokens up to (and including) the first one at which
            the sorted cumulative probability exceeds `top_p`. If it never does
            (e.g. top_p >= 1), every token is kept.

    Returns:
        The sampled token index as a Python int.
    """
    # float64 keeps the renormalized probabilities summing to 1 within np.random.choice's tolerance
    probs = F.softmax(out, dim=-1).numpy().astype(np.float64)
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    over = cumulative_probs > top_p
    # np.argmax on an all-False mask returns 0, which would silently keep only the
    # single most likely token; treat "never exceeds top_p" as "keep everything".
    cutoff = float(sorted_probs[np.argmax(over)]) if over.any() else 0.0
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # numpy arrays have no `.pow` method; np.power is the elementwise equivalent
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / np.sum(probs)
    return int(np.random.choice(a=len(probs), p=probs))

########################################################################################################

In [8]:
tokenizer = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")

# THIS IS NOW UPDATED TO SUPPORT LATEST RWKV-5 WORLD v2 MODELS

# Checkpoint description. MODEL_NAME is given WITHOUT the ".pth" suffix:
# RWKV_RNN appends it when loading the weights.
args = types.SimpleNamespace(
    MODEL_NAME='/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096',
    n_layer=24,
    n_embd=1024,
    vocab_size=65536,
)

In [9]:
# NOTE(review): N_LAYER / N_EMBD are strings and no visible cell reads them;
# the model is configured through args.n_layer / args.n_embd. Keep them in
# sync with args (or remove them) to avoid confusion.
# N_LAYER="12"
# N_EMBD="768"
N_LAYER="24"
N_EMBD="1024"

In [10]:
# Generation settings.
# context = "\nElon Musk has"
# context = "\n我们发现"
context = "Q:Do you know datawhalechina?\nA:"
NUM_TRIALS = 3           # independent samples drawn from the same prompt
LENGTH_PER_TRIAL = 4096  # max tokens generated per trial
TEMPERATURE = 1.0
TOP_P = 0.7              # nucleus-sampling threshold
SEED = 42                # seeds numpy's global RNG, which np.random.choice in sampling uses
np.random.seed(SEED)

Eagle (RWKV-5) 和 Finch (RWKV-6) 相较于基础的RWKV-4架构在建模上的改进：

1. **改进步骤**：
   - **Eagle的改进**：Eagle模型在RWKV-4的基础上进行了多项改进，包括引入矩阵值的注意力状态（matrix-valued attention states）、在注意力头上应用LayerNorm（层归一化）、使用SiLU（Sigmoid-Weighted Linear Unit）进行注意力门控、并改进了初始化方法。此外，Eagle移除了接受度（receptance）函数中的Sigmoid激活函数。
   - **Finch的改进**：Finch模型进一步引入了对衰减计划（decay schedule）和令牌移位（token-shift）的数据依赖性（data-dependence），使模型在处理时间和令牌数据时更加灵活和精确。

2. **核心架构**：
   - 这些模型的核心架构依然类似于RWKV-4，由一系列堆叠的残差块组成，形状类似于传统的Transformer架构。
   - 每个块包含一个预LayerNorm时间混合子层（Pre-LayerNorm Time-Mixing sub-layer）和一个预LayerNorm通道混合子层（Pre-LayerNorm Channel-Mixing sub-layer），对应于Transformer中的注意力子层和前馈网络子层。


这个是RWKV 5的Channel Mixing的代码实现，可以对比一下RWKV 4的实现。


```python
@MyFunction
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        i0 = (2+self.head_size)*i+0
        xk = x * time_mix_k + state[i0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[i0] * (1 - time_mix_r)
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)
```

RWKV 4的Channel Mixing的代码实现为：


```python
@torch.jit.script_method
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
        state[5*i+0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)
```

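这里的`i`是当前层的索引（第`i`层），用来定位该层在`state`中对应的行。在RWKV 4的每一层中，Channel Mixing记录1个状态，Time Mixing记录4个状态，所以每层一共5个状态。而RWKV 5中每一层占用`2+self.head_size`行状态：Channel Mixing和Time Mixing的token shift各占1行，其余`head_size`行存放各个头的 $S \times S$ wkv 状态矩阵。Channel Mixing记录的状态以及计算过程和RWKV 4完全一样。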

![](./img/01.png)

图1：RWKV架构概述。左侧：时间混合和通道混合块；右上角：作为RNN单元的RWKV时间混合块；中下部：前馈模块中的令牌移位模块和Eagle时间混合；右下角：Finch时间混合中的令牌移位模块。所有形状注释为简洁起见假设为单头。虚线箭头（左侧，右上角）表示在Finch中有连接，但在Eagle中没有。

Eagle模型中采用的Token Shift技术：

1. **Token Shift**：
   - Eagle模型从之前的RWKV模型中采用了Token Shift技术，这类似于大小为2的一维因果卷积（1D causal convolution）。
   - 在图1的中心底部可以看到该技术的示意图。

2. **线性插值定义**：
   - 为了更好地介绍Token Shift技术，定义了一些符号。
   - 线性插值（lerp）在时间步$t$和$t-1$之间用于RWKV-4和Eagle Token Shift，定义如下：
     \begin{align*}
     \text{lerp}_{\Box}(a, b) = a + (b - a) \odot \mu_{\Box}
     \end{align*}
   - 其中，每个$\mu_{\Box} \in \mathbb{R}^D$是一个可学习的向量。

3. **Token Shift的功能**：
   - Token Shift允许模型学习在每个时间步中分配新信息和旧信息的比例，适用于接受度（receptance）、键（key）、值（value）和门控向量（gate vectors）中的每个通道（$r, k, v, g$），且每个头部（head）独立且唯一地应用这些向量。
   - 这使得即使在单层内，单个头部也能直接把过去和当前的令牌数据累积到这些向量的不同子空间中，从而有可能形成归纳头（induction heads）。


在Eagle和Finch模型中，通道混合模块（Channel Mixing module）的设置及其与RWKV-4架构的异同如下：

1. **模块一致性**：
   - 在Eagle和Finch模型中，通道混合模块与之前的RWKV-4架构基本相同。
   - 唯一的区别在于Eagle模型中，通道混合模块的隐藏维度（hidden dimension）从原来的4D减少到了3.5D。

2. **减少维度的原因**：
   - 这个隐藏维度的减少是为了在Eagle时间混合（Eagle Time Mixing）中引入新的门控权重（gating weights）并确保与之前模型（在相同层数和嵌入维度下）的参数数量相等。

3. **Finch模型中的处理**：
   - 尽管Finch模型中增加了一些新的LoRA权重参数，但并没有进一步减少隐藏维度。

4. **公式一致性**：
   - 通道混合的公式与RWKV-4模型相同，为了符号一致性（notational consistency），再次列出这些公式：

\begin{align*}
r'_t &= \text{lerp}_{r'}(x'_t, x'_{t-1}) W_{r'} \in \mathbb{R}^D \quad \text{(公式10)} \\
k'_t &= \text{lerp}_{k'}(x'_t, x'_{t-1}) W_{k'} \in \mathbb{R}^{3.5D} \quad \text{(公式11)} \\
v'_t &= \text{ReLU}(k'_t)^2 W_{v'} \in \mathbb{R}^D \quad \text{(公式12)} \\
o'_t &= \sigma(r'_t) \odot v'_t \in \mathbb{R}^D \quad \text{(公式13)}
\end{align*}

这些公式描述了在时间步 $t$ 的通道混合操作：
- 先用线性插值（lerp）混合 $x'_t$ 与 $x'_{t-1}$，再分别乘以权重矩阵 $W_{r'}$、$W_{k'}$，得到 $r'_t$ 和 $k'_t$。
- $v'_t$ 由 $k'_t$ 经 ReLU 后取平方，再乘以权重矩阵 $W_{v'}$ 得到。
- $o'_t$ 是 $r'_t$ 经 Sigmoid 激活函数 $\sigma$ 后的输出与 $v'_t$ 的逐元素乘积。

其中，3.5D 指的是一种表示维度的方式。在深度学习模型中，D 通常代表模型的隐藏层维度（即嵌入维度或特征空间的维度）。例如，如果模型的隐藏维度是256，那么4D表示这个维度被扩展为4倍，也就是1024。

然而，3.5D 是一个不常见的表示方法，通常情况下，我们会看到整数倍的表示（如2D, 4D等）。在这里，3.5D代表的是隐藏维度的3.5倍。

具体来说，如果模型的基础维度是D，那么3.5D就表示：
\begin{align*} 3.5D = 3.5 \times D \end{align*}

假设D是256，那么3.5D就是：
\begin{align*} 3.5 \times 256 = 896 \end{align*}

所以，3.5D就是指模型在特定层中使用的特征维度是基础维度的3.5倍。在这个文档中，作者提到从4D减少到3.5D，意味着他们减少了某个层或模块的特征维度，以便引入新的门控权重并保持参数数量的一致性。

Eagle时间混合（Eagle Time Mixing）的公式及其操作方法如下：

### 公式部分

Eagle时间混合的公式如下：

\begin{align*}
\Box_t &= \text{lerp}_{\Box}(x_t, x_{t-1}) W_{\Box}, \quad \Box \in \{r, k, v, g\} \tag{4} \\
w &= \exp(-\exp(\omega)) \tag{5} \\
\mathrm{wkv}_t &= \text{diag}(u) \cdot k_t^\top \cdot v_t + \sum_{i=1}^{t-1} \text{diag}(w)^{t-1-i} \cdot k_i^\top \cdot v_i \in \mathbb{R}^{(D/h) \times (D/h)} \tag{6} \\
o_t &= \text{concat} \left( \text{SiLU}(g_t) \odot \text{LayerNorm}(r_t \cdot \mathrm{wkv}_t) \right) W_o \in \mathbb{R}^D \tag{7}
\end{align*}

### 解释部分

- **LayerNorm的操作**：LayerNorm在每个头部（head）上独立操作，这相当于在h个组上执行GroupNorm（Wu & He，2018）。值得注意的是，$w$ 是由 $\omega \in \mathbb{R}^{D/h}$ 通过公式 $w = \exp(-\exp(\omega))$ 计算得到的，$\omega$ 是实际的头部可训练参数。这确保了 $w$ 在区间 (0,1) 内，从而保证 $\text{diag}(w)$ 是一个收缩矩阵。

- **$\mathrm{wkv}_t$ 的计算**：$\mathrm{wkv}_t$ 的注意力计算可以写成递归形式（$s$ 为上一时间步的状态，$s'$ 为更新后的状态）：
  \begin{align*}
  \mathrm{wkv}' &= s + \text{diag}(u) \cdot k^\top \cdot v \tag{8} \\
  s' &= \text{diag}(w) \cdot s + k^\top \cdot v \tag{9}
  \end{align*}

- **解释RWKV的 $\mathrm{wkv}_t$ 项**：RWKV的 $\mathrm{wkv}_t$ 项可以看作归一化 $k^\top v$ 项的一种基于衰减的等价形式。值得注意的是，对于给定的头部 $j$，递归状态 $s$ 是历史 $k^\top v$ 的累加和，并且 $s$ 的每个通道在每个时间步都按对应的 $w$ 通道单独衰减。在应用接受度向量、门控和输出权重之前，当前令牌的 $k^\top v$ 会先乘以每通道可学习的增益（bonus）$u$，再与状态相加，见图1右上角。这样，当前令牌就能相对于衰减状态中累积的历史令牌得到特殊对待。最后用接受度 $r$ 乘以这个和，其作用类似于线性注意力中的查询（query）。


这里最大的改进是把计算拆分成了`H = self.n_head`个头，每个头都有自己的 $S \times S$ 状态矩阵（$S$ 即 `self.head_size`），并保存在`state`中。相比于RWKV-4，这一改进可以类比为Transformer从单头自注意力变为多头注意力。
```python
    @MyFunction
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        H = self.n_head
        S = self.head_size

        i1 = (2+S)*i+1
        xk = x * time_mix_k + state[i1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[i1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[i1] * (1 - time_mix_r)
        xg = x * time_mix_g + state[i1] * (1 - time_mix_g)
        state[i1] = x

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

        x = torch.zeros(H, S)
        a = k @ v
        x = r @ (time_first * a + s)
        s = a + time_decay * s
    
        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()

        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
        return ow @ x
```

In [11]:
class RWKV_RNN(MyModule):
    """RWKV-5 ("Eagle") model evaluated one token at a time in RNN mode on the CPU.

    Weights are read from ``args.MODEL_NAME + '.pth'`` and converted to float32;
    ``args`` must provide ``MODEL_NAME``, ``n_layer`` and ``n_embd``.

    The recurrent ``state`` used by `forward` is one tensor of shape
    ``(n_layer * (2 + head_size), n_embd)``. For layer ``i`` (``S = head_size``):
      * row ``(2+S)*i + 0``: previous input of channel mixing (token shift)
      * row ``(2+S)*i + 1``: previous input of time mixing (token shift)
      * rows ``(2+S)*i + 2`` to ``(2+S)*(i+1) - 1``: the per-head ``S x S`` wkv state,
        stored as ``S`` rows and reshaped to ``(n_head, S, S)`` when used
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval() # eval mode; gradients are disabled separately in forward()

        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            w[k] = w[k].float() # convert to f32 type
            if      '.time_' in k: w[k] = w[k].squeeze()
            # store the decay factor exp(-exp(time_decay)) rather than the raw parameter
            if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
            # time_faaaa is passed to time_mixing as `time_first` (scales the current token's k^T v)
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)

        # NOTE(review): assumes time_decay is (n_head, head_size) after squeeze, so dim 0 is the head count
        self.n_head = w['blocks.0.att.time_decay'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head

        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    # numeric key component: index into the dict at this level (self.w.blocks)
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        """LayerNorm over n_embd using the `weight`/`bias` of the namespace `w`."""
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @MyFunction
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        # Channel mixing of layer i (same computation as RWKV-4): token shift, then
        # sigmoid(receptance) * (value @ relu(key)^2). Updates the token-shift row in place.
        i0 = (2+self.head_size)*i+0
        xk = x * time_mix_k + state[i0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[i0] * (1 - time_mix_r)
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    @MyFunction
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        # Multi-head time mixing of layer i. Updates the token-shift row and the
        # per-head S x S wkv state rows of `state` in place.
        H = self.n_head
        S = self.head_size

        i1 = (2+S)*i+1
        xk = x * time_mix_k + state[i1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[i1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[i1] * (1 - time_mix_r)
        xg = x * time_mix_g + state[i1] * (1 - time_mix_g)
        state[i1] = x

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

        a = k @ v                            # current token's k^T v per head: (H, S, S)
        x = r @ (time_first * a + s)         # read out with the bonus on the current token
        s = a + time_decay * s               # decay the old state, then add the current token

        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()

        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
        return ow @ x

    def forward(self, token, state):
        """Advance the model by one token.

        Args:
            token: token id used to index the embedding matrix.
            state: state from a previous call, or None to start from an all-zero state.

        Returns:
            (out, state): float32 output of the head projection (next-token logits) and
            the state tensor, which the mixing methods update in place.
        """
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i, 
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay, 
                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                    att.ln_x.weight, att.ln_x.bias)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, 
                    ffn.time_mix_k, ffn.time_mix_r, 
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state

In [32]:
# Prompt used for generation below (overrides the one from the config cell).
# context = "Q:Do you know datawhalechina?\nA:"
context = (
    '\nQ:DataWhalechina is an organization founded at Shanghai Jiao Tong University '
    'that helps learners learn artificial intelligence. How do you think of it?'
)

In [33]:
# Show which checkpoint will be loaded (path without the '.pth' suffix).
args.MODEL_NAME

'/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096'

In [34]:
# Show the configured depth and width of the model.
(args.n_layer, args.n_embd)

(24, 1024)

In [35]:
# Alternative: the same dimensions already set in the args cell (24 layers, 1024 embd).
# args.n_layer = 24
# args.n_embd = 1024

In [36]:
# Alternative dimensions for a smaller checkpoint; they must match the weights
# referenced by args.MODEL_NAME.
# args.n_layer = 12
# args.n_embd = 768

In [37]:
# Alternative checkpoint path (without '.pth'); update args.n_layer / args.n_embd to match it.
# args.MODEL_NAME='../models/rwkv-5-world-1b5'

In [38]:
# Load the checkpoint named in args (weights are read on the CPU).
print('\nUsing CPU. Loading {} ...'.format(args.MODEL_NAME))
model = RWKV_RNN(args)

print('\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None  # forward() allocates a zero state when it receives None


Using CPU. Loading /data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096 ...

Preprocessing context (slow version. see v2/rwkv/model.py for fast version)


In [39]:
# Reset the recurrent state so the prefill cell starts from scratch
# (RWKV_RNN.forward allocates a zero state when it receives None).
init_state = None

In [40]:
# Override LENGTH_PER_TRIAL from the config cell to cap each trial at 1024 tokens.
LENGTH_PER_TRIAL=1024

In [41]:
# Prefill: feed the prompt through the RNN once; every trial starts from this state.
assert context, "context must be non-empty: init_out comes from the last prompt token"
for token in tokenizer.encode(context):
    init_out, init_state = model.forward(token, init_state)

for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0  # index of the first generated token not yet printed
    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        try:
            tmp = tokenizer.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
                print(tmp, end="", flush=True)
                out_last = i + 1
        except (UnicodeDecodeError, KeyError):
            # UnicodeDecodeError: the pending tokens end inside a multi-byte UTF-8
            # character; keep them and retry once more tokens arrive.
            # KeyError: a sampled id is not in the tokenizer vocab (presumably a special
            # token — TODO confirm); NOTE(review): it stays pending, so later text of
            # this trial will not print either — consider stopping the trial on it.
            pass
        out, state = model.forward(token, state)
print('\n')



--[ Trial 0 ]----------------- 
Q:DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. How do you think of it?
QI: I think that the group of students is actually the whole AI community.
Q: In the first episode, how do you think you, a student, can use AI to solve a problem?
QI: It's a great opportunity to help develop and build knowledge, so that if we see AI problems, we can help solve them.
Q: How do you think that students can also participate in the teaching of AI?
QI: It is very important to let the students to think that there is an AI problem, and we can solve it by teaching AI.
Q: How do you think the research that we did on AI can be used to develop AI technologies?
QI: The research is interesting and it can be used to develop AI technologies.
Q: Do you think that students can learn from your research?
QI: I think so.
Q: You also talk about the use of AI in real-life applications. What do you think of t