# RWKV Token Shift Experiment A
This is a custom model with:
- 12 layers
- 2560 embedding size
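
As a rough sanity check on these two numbers, the sketch below estimates the parameter count. It assumes the standard RWKV-v4 block layout (four `d x d` attention projections, a 4x-wide channel-mix FFN, and per-channel time-mix/decay vectors) and a vocabulary of 50277 tokens. Neither is stated in this notebook, and the function name is made up for illustration, so treat the result as an estimate. Under these assumptions the total equals the `1280128000 elements` reported by the checkpoint export log further down.

```python
# Rough parameter-count estimate. The layer layout and the vocab size are
# ASSUMPTIONS (standard RWKV-v4, 50277-token vocabulary), not read from this repo.
def estimate_rwkv_v4_params(n_layer: int, n_embd: int, vocab_size: int = 50277) -> int:
    d = n_embd
    # time-mix (attention): key, value, receptance, output -> 4 matrices of d x d
    att_matrices = 4 * d * d
    # channel-mix (ffn): key (d -> 4d), value (4d -> d), receptance (d -> d)
    ffn_matrices = 4 * d * d + 4 * d * d + d * d
    # per-channel vectors: ln1 and ln2 (weight + bias each), att time_mix_k/v/r,
    # time_decay, time_first, ffn time_mix_k/r -> 11 vectors of length d
    vectors = 11 * d
    per_layer = att_matrices + ffn_matrices + vectors
    # embedding + output head, plus ln0 and ln_out (weight + bias each)
    outside_blocks = 2 * vocab_size * d + 4 * d
    return n_layer * per_layer + outside_blocks


total = estimate_rwkv_v4_params(n_layer=12, n_embd=2560)
print(f"{total:,} parameters")
```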

See `./notes.md` for how the init model was initialized.

**Note:** This project assumes you already have the `rwkv-infctx` conda env set up, using the commands below.

---

```bash
# ninja-build is required for the new trainer
sudo apt-get install ninja-build

# Update conda & its package listings
conda update conda

# Virtual env, with python 3.10
# python 3.11 has issues with torch.compile / H100s;
# if you want to use 3.11, you will need to do a nightly build install
conda create -n rwkv-infctx python=3.10 pip
conda activate rwkv-infctx

# Install pytorch (>=2.0.1)
conda install -y pytorch==2.0.1 torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

# Verify your pytorch version 
python -c "import torch; print(torch.__version__)"

# We use python -m pip instead of pip directly, as it resolves issues with the venv not loading the right pip
python -m pip install datasets transformers 
python -m pip install lightning==2.0.4 deepspeed==0.9.5
python -m pip install ninja numexpr jsonargparse 'jsonargparse[signatures]'
python -m pip install lm-dataformat ftfy sentencepiece tokenizers wandb
```
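
Once the environment is created, it can be useful to confirm that the pinned packages actually resolved to the versions listed above. The following is a minimal sketch using only the standard library's `importlib.metadata`. The `PINNED` table and the `check_pins` helper are illustrative names, not part of this repo, and only the two packages pinned with `==` in the pip commands above are checked.

```python
from importlib import metadata

# Versions pinned in the pip install commands above.
PINNED = {"lightning": "2.0.4", "deepspeed": "0.9.5"}


def check_pins(pins, get_version=metadata.version):
    """Return {package: (wanted, found_or_None)} for every pin that does not match."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed in this environment
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


problems = check_pins(PINNED)
print(problems if problems else "all pinned packages match")
```

An empty result means both pins match. A `None` in the second slot means the package is missing from the active environment, which usually indicates that `conda activate rwkv-infctx` was skipped.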
---

# Basic Setup

In [None]:
# First, let's set up the various directories and get the blank init model. This init model was generated
# using the original RWKV-LM repo (at the time of writing, this repo cannot init a model).
# As such, I have preinitialized these blank models and uploaded them to HF for convenience
!mkdir -p ../../../../model/
!mkdir -p ../../../../datapath/
!mkdir -p ../../../../checkpoint/
!cd ../../../../model/ && wget -nc https://huggingface.co/picocreator/memory-size-experiment-for-rwkv/resolve/main/L12-D2560-init.pth
!ls -alh ../../../../model/L12-D2560-init.pth

# Exported models for the later stages, if you want to skip ahead
# !cd ../../../../model/ && wget -nc https://huggingface.co/picocreator/memory-size-experiment-for-rwkv/resolve/main/TokenShift-A-Stage1.pth
# !cd ../../../../model/ && wget -nc https://huggingface.co/picocreator/memory-size-experiment-for-rwkv/resolve/main/TokenShift-A-Stage2.pth

In [1]:
DEEPSPEED_STRAT="deepspeed_stage_1"
GPU_DEVICES="auto"
ENABLE_WANDB=True
WANDB_PREFIX="TokenShift-Exp-A"

print("DEEPSPEED_STRAT:", DEEPSPEED_STRAT)
print("ENABLE_WANDB:", ENABLE_WANDB)
print("GPU_DEVICES:", GPU_DEVICES)

if ENABLE_WANDB:
    WANDB_MODE="online"
else:
    WANDB_MODE="disabled"

# Compute the notebook directory and the various project paths
import os
NOTEBOOK_DIR=os.path.dirname(os.path.abspath("__file__"))
PROJECT_DIR=os.path.abspath(os.path.join(NOTEBOOK_DIR, "../../../../"))
TRAINER_DIR=os.path.abspath(os.path.join(PROJECT_DIR, "./RWKV-v4wavenet/"))
INFERENCE_DIR=os.path.abspath(os.path.join(PROJECT_DIR, "./RWKV-v4wavenet/"))

print("NOTEBOOK_DIR:", NOTEBOOK_DIR)
print("INFERENCE_DIR:", INFERENCE_DIR)
print("TRAINER_DIR:", TRAINER_DIR)
print("PROJECT_DIR:", PROJECT_DIR)

DEEPSPEED_STRAT: deepspeed_stage_1
ENABLE_WANDB: True
GPU_DEVICES: auto
NOTEBOOK_DIR: /home/ubuntu/rwkv5x-tokenshift-exp-A/notebook/experiment/tokenshift-exp
INFERENCE_DIR: /home/ubuntu/rwkv5x-tokenshift-exp-A/RWKV-v5x
TRAINER_DIR: /home/ubuntu/rwkv5x-tokenshift-exp-A/RWKV-v4wavenet
PROJECT_DIR: /home/ubuntu/rwkv5x-tokenshift-exp-A
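
Note that `"__file__"` in the cell above is a quoted string, not the `__file__` module attribute (which a notebook kernel typically does not define). `os.path.abspath` simply joins that literal name onto the current working directory, so `NOTEBOOK_DIR` ends up being whatever directory the kernel is running in. A minimal sketch of that behaviour, with lower-case variable names chosen here only for illustration:

```python
import os

# "__file__" is just a (most likely non-existent) file name here; abspath()
# joins it onto the current working directory without checking that it exists.
notebook_dir = os.path.dirname(os.path.abspath("__file__"))

# Same four-levels-up resolution the notebook uses for PROJECT_DIR.
project_dir = os.path.abspath(os.path.join(notebook_dir, "../../../../"))

print("notebook_dir is the cwd:", notebook_dir == os.path.normpath(os.getcwd()))
print("project_dir:", project_dir)
```

In practice this means the notebook has to be run from its own directory. If the kernel is started elsewhere, every derived path (and the relative `../../../../model/` paths in the setup cell) points at the wrong location.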


## Stage 1 : Foundation model training

In [3]:
# Let's preload the required dataset (enwiki_100k)
!cd "{TRAINER_DIR}" && \
    python3 preload_datapath.py "{NOTEBOOK_DIR}/TokenShift-A-enwiki.yaml"

Found cached dataset parquet (/root/.cache/huggingface/datasets/teven___parquet/teven--enwiki_100k-1359e81b212c2dd6/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 69.36it/s]
Loading cached processed dataset at /root/.cache/huggingface/datasets/teven___parquet/teven--enwiki_100k-1359e81b212c2dd6/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-1ae92a8fcbab5f66_*_of_00064.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/teven___parquet/teven--enwiki_100k-1359e81b212c2dd6/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-67576cf8a6607f3d_*_of_00064.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/teven___parquet/teven--enwiki_100k-1359e81b212c2dd6/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-f69b6b1c107274f3_*_of_00064.arrow
                                

In [4]:
# Start the foundation model training
!cd "{TRAINER_DIR}" && \
    export WANDB_MODE="{WANDB_MODE}" && \
    python lightning_trainer.py fit \
        -c "{NOTEBOOK_DIR}/TokenShift-A-enwiki.yaml" \
        --trainer.logger.init_args.name="{WANDB_PREFIX} - Enwiki Foundation (ctx=4096, {DEEPSPEED_STRAT})" \
        --trainer.strategy="{DEEPSPEED_STRAT}" \
        --trainer.devices="{GPU_DEVICES}" 

Setting ds_accelerator to cuda (auto detect)
[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.0.1+cu118'
  rank_zero_warn(f"No seed found, seed set to {seed}")
Global seed set to 2259997942
[34m[1mwandb[0m: Currently logged in as: [33mpicocreator[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.15.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.15.4
[34m[1mwandb[0m: Run data is saved locally in [35m[1m./wandb/run-20230715_083122-hlfv4t2c[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33m(8x3090, rechunked) TokenShift-Exp-A - Enwiki Foundation (ctx=4096, deepspeed_stage_1)[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/picocreator/RWKV-5X-Experiments[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/picocreator/RWKV-5X-Experimen

In [5]:
# Let's export the model from the checkpoint
!cd "{TRAINER_DIR}" && \
    python export_checkpoint.py "../checkpoint/TokenShift-A-enwiki/last.ckpt" "../model/TokenShift-A-Stage1.pth"
!cd "{TRAINER_DIR}" && ls -alh "../model/TokenShift-A-Stage1.pth"

Setting ds_accelerator to cuda (auto detect)
Processing zero checkpoint '../checkpoint/TokenShift-A-enwiki/last.ckpt/checkpoint'
Detected checkpoint of type zero stage ZeroStageEnum.optimizer_states, world_size: 8
Parsing checkpoint created by deepspeed==0.9.3
Reconstructed fp32 state dict with 222 params 1280128000 elements
Saving fp32 state dict to ../model/TokenShift-A-Stage1.pth
-rw-r--r-- 1 root root 4.8G Jul 15 13:42 ../model/TokenShift-A-Stage1.pth
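
A quick way to check that the export is complete is to compare the file size with the element count from the log above. The arithmetic below assumes 4 bytes per fp32 element and ignores the small amount of serialization overhead in the `.pth` file.

```python
# Element count taken from the "Reconstructed fp32 state dict" log line above.
elements = 1_280_128_000
bytes_per_fp32 = 4

size_bytes = elements * bytes_per_fp32
size_gib = size_bytes / 2**30  # ls -h reports sizes in powers of 1024

print(f"{size_bytes:,} bytes = {size_gib:.2f} GiB")
```

This gives roughly 4.77 GiB, which is consistent with the `4.8G` shown by `ls -alh`. A noticeably smaller file would suggest a truncated or partial export.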


In [3]:
# Let's do a quick dragon prompt validation
!cd "{INFERENCE_DIR}" && python3 dragon_test.py ../model/TokenShift-A-Stage1.pth "cuda fp32"


RWKV_HEAD_QK_DIM 0 RWKV_JIT_ON 1

blocks.0.att.key.weight                  float32    cuda:0
blocks.0.att.output.weight               float32    cuda:0
blocks.0.att.receptance.weight           float32    cuda:0
blocks.0.att.time_mix_k                  float32    cuda:0
blocks.0.att.time_mix_r                  float32    cuda:0
blocks.0.att.time_mix_v                  float32    cuda:0
blocks.0.att.value.weight                float32    cuda:0
blocks.0.ffn.key.weight                  float32    cuda:0
blocks.0.ffn.receptance.weight           float32    cuda:0
blocks.0.ffn.time_mix_k                  float32    cuda:0
blocks.0.ffn.time_mix_r                  float32    cuda:0
blocks.0.ffn.value.weight                float32    cuda:0
blocks.0.ln0.bias                        float32    cuda:0
blocks.0.ln0.weight                      float32    cuda:0
blocks.0.ln1.bias                        float32    cuda:0
blocks.0.ln1.weight                      float32    cuda:0
blocks.0.ln2.bias    

In [10]:
# Let's do a quick memory test
# (We don't expect this to work, as we have not finetuned for memory recall, but it's a baseline)
!python3 ../memory_script/eval_memory_guided.py "{PROJECT_DIR}/model/TokenShift-A-Stage1.pth"


RWKV_HEAD_QK_DIM 0 RWKV_JIT_ON 1

blocks.0.att.key.weight                  float32    cuda:0
blocks.0.att.output.weight               float32    cuda:0
blocks.0.att.receptance.weight           float32    cuda:0
blocks.0.att.time_mix_k                  float32    cuda:0
blocks.0.att.time_mix_r                  float32    cuda:0
blocks.0.att.time_mix_v                  float32    cuda:0
blocks.0.att.value.weight                float32    cuda:0
blocks.0.ffn.key.weight                  float32    cuda:0
blocks.0.ffn.receptance.weight           float32    cuda:0
blocks.0.ffn.time_mix_k                  float32    cuda:0
blocks.0.ffn.time_mix_r                  float32    cuda:0
blocks.0.ffn.value.weight                float32    cuda:0
blocks.0.ln0.bias                        float32    cuda:0
blocks.0.ln0.weight                      float32    cuda:0
blocks.0.ln1.bias                        float32    cuda:0
blocks.0.ln1.weight                      float32    cuda:0
blocks.0.ln2.bias    

## Stage 2 : Instruct Tuning

In [6]:
# Let's preload the required dataset
!cd "{TRAINER_DIR}" && \
    python3 preload_datapath.py "{NOTEBOOK_DIR}/TokenShift-A-instruct.yaml"

Found cached dataset parquet (/root/.cache/huggingface/datasets/c-s-ale___parquet/c-s-ale--dolly-15k-instruction-alpaca-format-9dfbb23260d63d9d/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
100%|████████████████████████████████████████████| 1/1 [00:00<00:00, 370.62it/s]
Loading cached processed dataset at /root/.cache/huggingface/datasets/c-s-ale___parquet/c-s-ale--dolly-15k-instruction-alpaca-format-9dfbb23260d63d9d/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-f7f96ced2d76e1c7_*_of_00064.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/c-s-ale___parquet/c-s-ale--dolly-15k-instruction-alpaca-format-9dfbb23260d63d9d/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-3b2912dc14935a9a_*_of_00064.arrow
                                                                                

In [7]:
# Start the instruct finetuning
!cd "{TRAINER_DIR}" && \
    export WANDB_MODE="{WANDB_MODE}" && \
    python lightning_trainer.py fit \
        -c "{NOTEBOOK_DIR}/TokenShift-A-instruct.yaml" \
        --trainer.logger.init_args.name="{WANDB_PREFIX} - Instruct (train-ctx=4096, {DEEPSPEED_STRAT})" \
        --trainer.strategy="{DEEPSPEED_STRAT}" \
        --trainer.devices="{GPU_DEVICES}"

Setting ds_accelerator to cuda (auto detect)
[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.0.1+cu118'
  rank_zero_warn(f"No seed found, seed set to {seed}")
Global seed set to 1727192453
[34m[1mwandb[0m: Currently logged in as: [33mpicocreator[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.15.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.15.4
[34m[1mwandb[0m: Run data is saved locally in [35m[1m./wandb/run-20230715_135325-9wh9tja6[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33m(8x3090, rechunked) TokenShift-Exp-A - Instruct (train-ctx=4096, deepspeed_stage_1)[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/picocreator/RWKV-5X-Experiments[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/picocreator/RWKV-5X-Experiments/

In [9]:
# Let's export the model from the checkpoint
!cd "{TRAINER_DIR}" && \
    python export_checkpoint.py "../checkpoint/TokenShift-A-instruct/last.ckpt" "../model/TokenShift-A-Stage2.pth"
!cd "{TRAINER_DIR}" && ls -alh "../model/TokenShift-A-Stage2.pth"

Setting ds_accelerator to cuda (auto detect)
Processing zero checkpoint '../checkpoint/TokenShift-A-instruct/last.ckpt/checkpoint'
Detected checkpoint of type zero stage ZeroStageEnum.optimizer_states, world_size: 8
Parsing checkpoint created by deepspeed==0.9.3
Reconstructed fp32 state dict with 222 params 1280128000 elements
Saving fp32 state dict to ../model/TokenShift-A-Stage2.pth
-rw-r--r-- 1 root root 4.8G Jul 15 14:32 ../model/TokenShift-A-Stage2.pth


In [11]:
# Let's do a quick dragon prompt validation
!cd "{INFERENCE_DIR}" && python3 dragon_test.py "../model/TokenShift-A-Stage2.pth" "cuda fp32"


RWKV_HEAD_QK_DIM 0 RWKV_JIT_ON 1

blocks.0.att.key.weight                  float32    cuda:0
blocks.0.att.output.weight               float32    cuda:0
blocks.0.att.receptance.weight           float32    cuda:0
blocks.0.att.time_mix_k                  float32    cuda:0
blocks.0.att.time_mix_r                  float32    cuda:0
blocks.0.att.time_mix_v                  float32    cuda:0
blocks.0.att.value.weight                float32    cuda:0
blocks.0.ffn.key.weight                  float32    cuda:0
blocks.0.ffn.receptance.weight           float32    cuda:0
blocks.0.ffn.time_mix_k                  float32    cuda:0
blocks.0.ffn.time_mix_r                  float32    cuda:0
blocks.0.ffn.value.weight                float32    cuda:0
blocks.0.ln0.bias                        float32    cuda:0
blocks.0.ln0.weight                      float32    cuda:0
blocks.0.ln1.bias                        float32    cuda:0
blocks.0.ln1.weight                      float32    cuda:0
blocks.0.ln2.bias    

In [12]:
# Let's do a quick memory test
# (We don't expect this to work, as we have not finetuned for memory recall, but it's a baseline)
!python3 ../memory_script/eval_memory_guided.py "{PROJECT_DIR}/model/TokenShift-A-Stage2.pth"


RWKV_HEAD_QK_DIM 0 RWKV_JIT_ON 1

blocks.0.att.key.weight                  float32    cuda:0
blocks.0.att.output.weight               float32    cuda:0
blocks.0.att.receptance.weight           float32    cuda:0
blocks.0.att.time_mix_k                  float32    cuda:0
blocks.0.att.time_mix_r                  float32    cuda:0
blocks.0.att.time_mix_v                  float32    cuda:0
blocks.0.att.value.weight                float32    cuda:0
blocks.0.ffn.key.weight                  float32    cuda:0
blocks.0.ffn.receptance.weight           float32    cuda:0
blocks.0.ffn.time_mix_k                  float32    cuda:0
blocks.0.ffn.time_mix_r                  float32    cuda:0
blocks.0.ffn.value.weight                float32    cuda:0
blocks.0.ln0.bias                        float32    cuda:0
blocks.0.ln0.weight                      float32    cuda:0
blocks.0.ln1.bias                        float32    cuda:0
blocks.0.ln1.weight                      float32    cuda:0
blocks.0.ln2.bias    