# Echo-A 1B5 (Memory model from scratch - stage 2)

The second fine-tune stage continues from the previous one, masking everything except the output.
This restricts the loss to the output tokens only, rather than the instruction, input, and so on (which the model should already have learnt).
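
To make the masking idea concrete, here is a minimal sketch of a loss that only counts unmasked tokens. It is an illustration in plain Python, not the trainer's actual implementation: the function name `masked_nll`, the example probabilities, and the mask layout are all assumptions made for this example.

```python
import math

def masked_nll(token_probs, train_mask):
    """Average negative log-likelihood over the unmasked (mask == 1) tokens only."""
    if len(token_probs) != len(train_mask):
        raise ValueError("token_probs and train_mask must have the same length")
    kept = [-math.log(p) for p, m in zip(token_probs, train_mask) if m]
    if not kept:
        return 0.0  # nothing to learn from if every token is masked
    return sum(kept) / len(kept)

# Hypothetical sequence: 3 instruction/input tokens followed by 2 output tokens.
probs = [0.10, 0.20, 0.05, 0.50, 0.25]  # probability the model gave each correct token
mask = [0, 0, 0, 1, 1]                  # only the output tokens contribute to the loss

loss = masked_nll(probs, mask)
print(f"masked loss: {loss:.4f}")
```

In this sketch the poorly predicted instruction tokens have no effect on the result, so the gradient signal would come from the output tokens alone.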

> This project assumes you have the rwkv-infctx conda env setup, and you are executing in that environment - see the main README.md for the conda env setup steps

## Optional: Download the pretrained model
(if you want to skip the base model training + instruct tune)


In [1]:
# Init required dirs
!mkdir -p ../../../model/
!mkdir -p ../../../datapath/
!mkdir -p ../../../checkpoint/

# Download the Stage-1 .pth file
!rm -rf ../../../model/Echo-A-1B5-Scratch-Stage-1.pth
!cd ../../../model/ && wget https://huggingface.co/picocreator/memory-size-experiment-for-rwkv/resolve/main/Echo-A-1B5-Scratch-Stage-1.pth
!ls -alh ../../../model/Echo-A-1B5-Scratch-Stage-1.pth

--2023-07-13 04:22:49--  https://huggingface.co/picocreator/memory-size-experiment-for-rwkv/resolve/main/Echo-A-1B5-Scratch-Stage-1.pth
Resolving huggingface.co (huggingface.co)... 52.85.242.84, 52.85.242.16, 52.85.242.8, ...
Connecting to huggingface.co (huggingface.co)|52.85.242.84|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/cb/ef/cbef09abb2634a3375b28868bffa285226dfeabedec89b28c2fb302221164d66/f032099c1cd32937c5fc33c6a61b7beee18a93eeddb89a17d3cff571763337e7?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27Echo-A-1B5-Scratch-Stage-1.pth%3B+filename%3D%22Echo-A-1B5-Scratch-Stage-1.pth%22%3B&Expires=1689481369&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY4OTQ4MTM2OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9jYi9lZi9jYmVmMDlhYmIyNjM0YTMzNzViMjg4NjhiZmZhMjg1MjI2ZGZlYWJlZGVjODliMjhjMmZiMzAyMjIxMTY0ZDY2L2YwMzIwOTljMWNkMzI5MzdjNWZjMz

## Prepare the dataset

Prepare and preload the finetuning process dataset
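
As a rough picture of what a word-repetition record could look like, the sketch below builds a few prompt/completion samples in which the expected output repeats the input. The record layout (`instruction`, `input`, `output`), the instruction wording, and the tiny word list are assumptions for illustration; the actual generator scripts in `./memory_script/` may differ.

```python
import json
import random

def make_sample(words, word_count, rng):
    """Build one hypothetical repetition record: the output repeats the input verbatim."""
    document = " ".join(rng.choice(words) for _ in range(word_count))
    return {
        "instruction": "Repeat the following text exactly.",  # assumed wording
        "input": document,
        "output": document,
    }

rng = random.Random(0)
word_list = ["apple", "river", "stone", "cloud", "maple"]  # stand-in for the real word list
# Variable length up to a cap of 5 words (an assumption about what "max words" means).
samples = [make_sample(word_list, rng.randint(1, 5), rng) for _ in range(3)]
jsonl_text = "\n".join(json.dumps(s) for s in samples)
print(jsonl_text)
```

With the instruction and input masked during training, only the `output` column of such a record would contribute to the loss.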

In [2]:
%%script bash
# Reset the dataset dir
mkdir -p ./dataset
rm -rf ./dataset/*.jsonl

# Generate the various datasets
echo "## Generating word reptition dataset ##"

# Prompt completion pairs with fully masked instruction and input, and unmasked outputs.
# This is required to actually teach the model how to memorize the input, but on its own
# it is unable to teach the model how to trigger this behavior (as the instruction is masked)
python ./memory_script/gen_limited_masked_jsonl.py ./dataset/limited-masked-word-2-count.jsonl  2  25000 &
python ./memory_script/gen_limited_masked_jsonl.py ./dataset/limited-masked-word-5-count.jsonl  5  25000 &
python ./memory_script/gen_limited_masked_jsonl.py ./dataset/limited-masked-word-10-count.jsonl 10 30000 &
python ./memory_script/gen_limited_masked_jsonl.py ./dataset/limited-masked-word-15-count.jsonl 15 30000 &
python ./memory_script/gen_limited_masked_jsonl.py ./dataset/limited-masked-word-20-count.jsonl 20 30000 &
python ./memory_script/gen_limited_masked_jsonl.py ./dataset/limited-masked-word-40-count.jsonl 40 30000 &
python ./memory_script/gen_limited_masked_jsonl.py ./dataset/limited-masked-word-60-count.jsonl 60 25000 &
python ./memory_script/gen_limited_masked_jsonl.py ./dataset/limited-masked-word-80-count.jsonl 80 20000 &
python ./memory_script/gen_limited_masked_jsonl.py ./dataset/limited-masked-word-100-count.jsonl 100 20000 &
python ./memory_script/gen_limited_masked_jsonl.py ./dataset/limited-masked-word-200-count.jsonl 200 10000 &

# Prompt completion pairs, with the full word list. Due to the size of the full word list, training
# could get stuck teaching the model to recognize new words / tokens instead of performing the memorization task.
# This greatly slowed down the memorization learning process, as the model was constantly learning new words.
# With 400k+ words total, even after 100k document samples, new words can still appear (due to random sampling)
#
# We still include a mix of this data, in an attempt to reduce overtraining the model on a fixed token set,
# which was one of the weaknesses faced in the original training / benchmark (but technically not an issue for measuring memory)
python ./memory_script/gen_full_masked_jsonl.py ./dataset/full-masked-word-2-count.jsonl  2  15000 &
python ./memory_script/gen_full_masked_jsonl.py ./dataset/full-masked-word-5-count.jsonl  5  15000 &
python ./memory_script/gen_full_masked_jsonl.py ./dataset/full-masked-word-10-count.jsonl 10 15000 &
python ./memory_script/gen_full_masked_jsonl.py ./dataset/full-masked-word-15-count.jsonl 15 15000 &
python ./memory_script/gen_full_masked_jsonl.py ./dataset/full-masked-word-20-count.jsonl 20 15000 &
python ./memory_script/gen_full_masked_jsonl.py ./dataset/full-masked-word-40-count.jsonl 40 15000 &
python ./memory_script/gen_full_masked_jsonl.py ./dataset/full-masked-word-60-count.jsonl 60 10000 &
python ./memory_script/gen_full_masked_jsonl.py ./dataset/full-masked-word-80-count.jsonl 80 10000 &
python ./memory_script/gen_full_masked_jsonl.py ./dataset/full-masked-word-100-count.jsonl 100 5000 &
python ./memory_script/gen_full_masked_jsonl.py ./dataset/full-masked-word-200-count.jsonl 200 5000 &

wait
echo "## Done ##"

## Generating word reptition dataset ##
Generated JSONL file with - 2 max words, 15000 samples - at ./dataset/full-masked-word-2-count.jsonl
Generated JSONL file with - 5 max words, 15000 samples - at ./dataset/full-masked-word-5-count.jsonl
Generated JSONL file with - 10 max words, 15000 samples - at ./dataset/full-masked-word-10-count.jsonl
Generated JSONL file with - 2 max words, 25000 samples - at ./dataset/limited-masked-word-2-count.jsonl
Generated JSONL file with - 100 max words, 5000 samples - at ./dataset/full-masked-word-100-count.jsonl
Generated JSONL file with - 15 max words, 15000 samples - at ./dataset/full-masked-word-15-count.jsonl
Generated JSONL file with - 20 max words, 15000 samples - at ./dataset/full-masked-word-20-count.jsonl
Generated JSONL file with - 5 max words, 25000 samples - at ./dataset/limited-masked-word-5-count.jsonl
Generated JSONL file with - 10 max words, 30000 samples - at ./dataset/limited-masked-word-10-count.jsonl
Generated JSONL file with - 60 
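
The note in the cell above about new words still appearing after many samples can be checked with a back-of-the-envelope estimate. The sketch below assumes words are drawn uniformly at random with replacement, which may not match the generator scripts exactly; the function name is hypothetical. Under that assumption, the expected number of never-seen words after `n` draws from a vocabulary of size `V` is `V * (1 - 1/V) ** n`.

```python
def expected_unseen_words(vocab_size, total_draws):
    """Expected number of vocabulary words never drawn (uniform sampling with replacement)."""
    return vocab_size * (1.0 - 1.0 / vocab_size) ** total_draws

vocab_size = 400_000  # "400k+ words total", taken from the comment in the cell above
for total_draws in (100_000, 1_000_000, 5_000_000):
    unseen = expected_unseen_words(vocab_size, total_draws)
    print(f"{total_draws:>9,} words drawn -> ~{unseen:,.0f} words still unseen")
```

The number of draws depends on how many words each sample contains. With only 100,000 draws, roughly 78% of the vocabulary would still be unseen under this assumption, which is consistent with the model continuing to meet new words late in training.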

## Configure your environment settings
(Important: you will need to rerun the cell below if you restart your kernel)

In [3]:
DEEPSPEED_STRAT="deepspeed_stage_1"
GPU_DEVICES="auto"
ENABLE_WANDB=True
WANDB_PREFIX="Echo-A-1B5"

print("DEEPSPEED_STRAT:", DEEPSPEED_STRAT)
print("ENABLE_WANDB:", ENABLE_WANDB)
print("GPU_DEVICES:", GPU_DEVICES)

if ENABLE_WANDB:
    WANDB_MODE="online"
else:
    WANDB_MODE="disabled"

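# Compute the notebook directory and the related project paths
import os
# "__file__" is a literal string here: abspath resolves it against the current working
# directory, so dirname yields the directory the notebook is running in
NOTEBOOK_DIR=os.path.dirname(os.path.abspath("__file__"))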
PROJECT_DIR=os.path.abspath(os.path.join(NOTEBOOK_DIR, "../../../"))
TRAINER_DIR=os.path.abspath(os.path.join(PROJECT_DIR, "./RWKV-v4neo/"))

print("NOTEBOOK_DIR:", NOTEBOOK_DIR)
print("TRAINER_DIR:", TRAINER_DIR)
print("PROJECT_DIR:", PROJECT_DIR)

DEEPSPEED_STRAT: deepspeed_stage_1
ENABLE_WANDB: True
GPU_DEVICES: auto
NOTEBOOK_DIR: /root/picocreator-memory-experiment/notebook/experiment/memory-scratch
TRAINER_DIR: /root/picocreator-memory-experiment/RWKV-v4neo
PROJECT_DIR: /root/picocreator-memory-experiment


## Stage 2 : Simple Memory finetuning

In [4]:
# Let's preload the required dataset (the JSONL files generated above)
!cd "{TRAINER_DIR}" && \
    python3 preload_datapath.py "{NOTEBOOK_DIR}/Echo-A-1B5-scratch-stage-2.yaml"

# Ensure the checkpoint directory exists
!cd "{TRAINER_DIR}" && mkdir -p "../checkpoint/Echo-A-1B5-scratch-stage-2/"

Resolving data files: 100%|██████████████████| 20/20 [00:00<00:00, 12323.50it/s]
Downloading and preparing dataset json/default to /root/.cache/huggingface/datasets/json/default-93d149df4877176d/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96...
Downloading data files: 100%|███████████████████| 1/1 [00:00<00:00, 2709.50it/s]
Extracting data files: 100%|█████████████████████| 1/1 [00:00<00:00, 115.74it/s]
Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/json/default-93d149df4877176d/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96. Subsequent calls will reuse this data.
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 89.66it/s]
                                                                                

In [5]:
# Start the stage 2 finetuning
!cd "{TRAINER_DIR}" && \
    export WANDB_MODE="{WANDB_MODE}" && \
    python lightning_trainer.py fit \
        -c "{NOTEBOOK_DIR}/Echo-A-1B5-scratch-stage-2.yaml" \
        --trainer.logger.init_args.name="{WANDB_PREFIX} - Scratch-Stage-2 (bs=256, train-ctx=512, {DEEPSPEED_STRAT})" \
        --trainer.strategy="{DEEPSPEED_STRAT}" \
        --trainer.devices="{GPU_DEVICES}"  \
        --model.ctx_len=512

Setting ds_accelerator to cuda (auto detect)
[RWKV.model] Running RWKV model using 'torch-jit' with torch '2.0.1+cu118'
  rank_zero_warn(f"No seed found, seed set to {seed}")
Global seed set to 2509290067
[34m[1mwandb[0m: Currently logged in as: [33mpicocreator[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.15.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.15.4
[34m[1mwandb[0m: Run data is saved locally in [35m[1m./wandb/run-20230713_042546-j8ngi79y[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33m(8x3090) Echo-A-1B5 - Scratch-Stage-2 (bs=256, train-ctx=512, deepspeed_stage_1)[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/picocreator/RWKV-Memory-Experiment[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/picocreator/RWKV-Memory-Experime

In [6]:
# Let's export the model from the checkpoint
!cd "{TRAINER_DIR}" && \
    python export_checkpoint.py \
        "../checkpoint/Echo-A-1B5-scratch-stage-2/last.ckpt" \
        "../model/Echo-A-1B5-Scratch-Stage-2.pth"
!cd "{TRAINER_DIR}" && ls -alh ../model/Echo-A-1B5-Scratch-Stage-2.pth

Setting ds_accelerator to cuda (auto detect)
Processing zero checkpoint '../checkpoint/Echo-A-1B5-scratch-stage-2/last.ckpt/checkpoint'
Detected checkpoint of type zero stage ZeroStageEnum.optimizer_states, world_size: 8
Parsing checkpoint created by deepspeed==0.9.3
Reconstructed fp32 state dict with 438 params 1515106304 elements
Saving fp32 state dict to ../model/Echo-A-1B5-Scratch-Stage-2.pth
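
As an alternative to `ls`, the exported file can be checked from Python. This is a small sketch using only the standard library; the helper name `describe_model_file` is hypothetical, and the path is the one passed to the export command, relative to the trainer directory.

```python
from pathlib import Path

def describe_model_file(path):
    """Return (exists, size_in_GiB) for a model file; size is 0.0 if it is missing."""
    p = Path(path)
    if not p.is_file():
        return False, 0.0
    return True, p.stat().st_size / (1024 ** 3)

# Path taken from the export command, relative to the trainer directory.
exists, size_gib = describe_model_file("../model/Echo-A-1B5-Scratch-Stage-2.pth")
print("found" if exists else "missing", f"{size_gib:.2f} GiB")
```

The export log reports 1515106304 fp32 elements, which is roughly 5.6 GiB of weights at 4 bytes each, so a file much smaller than that would suggest an incomplete export.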


In [7]:
# Let's do a quick memory test validation
!python3 ./memory_script/eval_memory_guided.py "{PROJECT_DIR}/model/Echo-A-1B5-Scratch-Stage-2.pth"

Using /root/.cache/torch_extensions/py311_cu118 as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/py311_cu118/wkv_cuda...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py311_cu118/wkv_cuda/build.ninja...
Building extension module wkv_cuda...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] c++ -MMD -MF wrapper.o.d -DTORCH_EXTENSION_NAME=wkv_cuda -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /usr/local/lib/python3.11/dist-packages/torch/include -isystem /usr/local/lib/python3.11/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.11/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.11/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.11 -D_GLIBCXX