# Dataset repacking implementation

Advanced dataset operations: sorting, offset, and length support.

In [1]:
import os

# ---- Run configuration (edit these to change the run) ----
GPU_DEVICES = "auto"
ENABLE_WANDB = True
WANDB_PREFIX = "infctx-v5-datapack"

print(f"ENABLE_WANDB: {ENABLE_WANDB}")
print(f"GPU_DEVICES: {GPU_DEVICES}")

# WANDB_MODE is exported into the trainer's environment by the training cells
WANDB_MODE = "online" if ENABLE_WANDB else "disabled"

# ---- Path resolution ----
# "__file__" is a plain string literal here (not the module attribute), so
# abspath() resolves it against the kernel's working directory and dirname()
# then yields that working directory, i.e. the folder the notebook runs from.
NOTEBOOK_DIR = os.path.dirname(os.path.abspath("__file__"))
# The project root is two levels above the notebook folder
PROJECT_DIR = os.path.abspath(os.path.join(NOTEBOOK_DIR, "../../"))
# The trainer scripts live in the RWKV-v5 folder under the project root
TRAINER_DIR = os.path.abspath(os.path.join(PROJECT_DIR, "./RWKV-v5/"))

print(f"NOTEBOOK_DIR: {NOTEBOOK_DIR}")
print(f"TRAINER_DIR: {TRAINER_DIR}")
print(f"PROJECT_DIR: {PROJECT_DIR}")

ENABLE_WANDB: True
GPU_DEVICES: auto
NOTEBOOK_DIR: /home/recursal/RWKV-infctx-trainer/notebook/trainer-v5-validation
TRAINER_DIR: /home/recursal/RWKV-infctx-trainer/RWKV-v5
PROJECT_DIR: /home/recursal/RWKV-infctx-trainer


In [2]:
# Initialise the base model checkpoint that the training cells load:
# 6 layers, embedding size 1024, "world" vocab. The output path is relative
# to TRAINER_DIR because of the leading `cd`. --skip-if-exists is passed so an
# already-present checkpoint is left untouched on re-run.
!cd "{TRAINER_DIR}" && \
    python3 ./init_model.py \
        --n_layer 6 --n_embd 1024 \
        --vocab_size world --skip-if-exists \
        "../model/L6-D1024-world-v5base-init.pth"

[2024-02-17 01:53:06,566] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[RWKV.model] Running RWKV infctx using 'torch-jit' with torch '2.1.2'
---- Initializing model ----
No of layers: 6
Embedding size: 1024
Output model path: ../model/L6-D1024-world-v5base-init.pth
Vocab size: 65536
Emb scale: 0.0001
Note: this process takes a significant time (and ram) for large models
---- ----- ----
Model exists, skipping init_model


# Build the datapack

In [3]:
# Preload the required datasets and build the datapack described by
# config/datapack-build.yaml (the script is run from TRAINER_DIR).
!cd "{TRAINER_DIR}" && \
    python3 datapack_build.py "{NOTEBOOK_DIR}/config/datapack-build.yaml"

>> Starting datapack build process for: /home/recursal/RWKV-infctx-trainer/notebook/trainer-v5-validation/config/datapack-build.yaml
>> Preparing dataset - index:  0  - name:  enwiki_10k
Map (num_proc=160): 100%|███████| 10000/10000 [00:00<00:00, 11337.88 examples/s]
Filter (num_proc=160): 100%|████| 10000/10000 [00:00<00:00, 16064.01 examples/s]
Map (num_proc=160): 100%|██████████| 9992/9992 [00:01<00:00, 5399.75 examples/s]
Map (num_proc=160): 100%|█████████████| 586/586 [00:01<00:00, 403.36 examples/s]
Saving the dataset (1/1 shards): 100%|█| 586/586 [00:00<00:00, 7825.99 examples/
Saving the dataset (1/1 shards): 100%|████| 6/6 [00:00<00:00, 351.26 examples/s]
>> Preparing dataset - index:  1  - name:  openhermes
Map (num_proc=160): 100%|████| 242831/242831 [00:02<00:00, 118253.41 examples/s]
Filter (num_proc=160): 100%|█| 242831/242831 [00:01<00:00, 169667.47 examples/s]
Map (num_proc=160): 100%|█████| 240402/240402 [00:04<00:00, 49760.35 examples/s]
Map (num_proc=160): 100%|████|

In [8]:
# Load the example datapack saved on disk at
# {PROJECT_DIR}/datapath/v5-validation/example-datapack/ via HF datasets,
# then inspect the first few documents of its train split.
import itertools

import datasets

datapath = f"{PROJECT_DIR}/datapath/v5-validation/example-datapack/"
print(f"Loading the dataset... {datapath}")

# Load the dataset. The loaded object is indexed by split name below, so
# len() on it counts splits, not documents; report the two separately.
full_dataset = datasets.load_from_disk(datapath)
print(f"Dataset loaded, {len(full_dataset)} splits: {list(full_dataset.keys())}")

# Train dataset
train_dataset = full_dataset["train"]
print(f"Train split: {len(train_dataset)} documents")

# Number of leading train documents to inspect
iterate_limit = 4

# islice stops after `iterate_limit` documents and yields nothing for a limit
# of 0, so no manual break is needed.
for idx, doc in enumerate(itertools.islice(train_dataset, iterate_limit)):
    print(f"Document {idx+1}:")
    # Print the keys, then the length of each per-token column
    print("- Keys:", doc.keys())
    print("- input_ids length:", len(doc["input_ids"]))
    print("- attention_mask length:", len(doc["attention_mask"]))
    print("- token_type_ids length:", len(doc["token_type_ids"]))

Loading the dataset... /home/recursal/RWKV-infctx-trainer/datapath/v5-validation/example-datapack/
Dataset loaded, 2 documents
Document 1:
- Keys: dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
- input_ids length: 8167
- attention_mask length: 8167
- token_type_ids length: 8167
Document 2:
- Keys: dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
- input_ids length: 8192
- attention_mask length: 8192
- token_type_ids length: 8192
Document 3:
- Keys: dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
- input_ids length: 8182
- attention_mask length: 8182
- token_type_ids length: 8182
Document 4:
- Keys: dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
- input_ids length: 8182
- attention_mask length: 8182
- token_type_ids length: 8182


# Short train

In [9]:
# Short smoke-test run of the datapack training config.
# WANDB is force-disabled here regardless of ENABLE_WANDB, and
# --trainer.fast_dev_run=2 is passed so only a brief dev run is performed
# (NOTE(review): exact fast_dev_run semantics come from Lightning; confirm in its docs).
# NOTE(review): "validaiton" in the checkpoint dirpath looks like a typo of
# "validation"; it is a runtime path, so it is left unchanged here.
!cd "{TRAINER_DIR}" && \
    export WANDB_MODE="disabled" && \
    python3 lightning_trainer.py fit \
        -c "{NOTEBOOK_DIR}/config/datapack-train.yaml" \
        --model.load_model="../model/L6-D1024-world-v5base-init.pth" \
        --trainer.callbacks.init_args.dirpath="../checkpoint/datapack-validaiton-train/" \
        --trainer.logger.init_args.name="{WANDB_PREFIX} - Multi Datapack Validation - (deepspeed_stage_1)" \
        --trainer.strategy="deepspeed_stage_1" \
        --trainer.microbatch_size=8 \
        --trainer.fast_dev_run=2 \
        --trainer.devices="{GPU_DEVICES}"

[2024-02-17 05:36:58,441] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[RWKV.model] Running RWKV infctx using 'torch-jit' with torch '2.1.2'
/home/recursal/miniconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/cli.py:518: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['fit', '-c', '/home/recursal/RWKV-infctx-trainer/notebook/trainer-v5-validation/config/datapack-train.yaml', '--model.load_model=../model/L6-D1024-world-v5base-init.pth', '--trainer.callbacks.init_args.dirpath=../checkpoint/datapack-validaiton-train/', '--trainer.logger.init_args.name=infctx-v5-datapack - Multi Datapack Validation - (deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.microbatch_size=8', '--trainer.fast_dev_run=2', '--trainer.devices=auto'], args=['f

# Partial training run

In [110]:
# Partial training run of the datapack training config, capped via
# --trainer.max_steps=50, using the deepspeed_stage_1 strategy and a
# microbatch size of 8. WANDB_MODE is taken from the configuration cell
# ("online" or "disabled" depending on ENABLE_WANDB), and the run name is
# prefixed with WANDB_PREFIX.
!cd "{TRAINER_DIR}" && \
    export WANDB_MODE="{WANDB_MODE}" && \
    python3 lightning_trainer.py fit \
        -c "{NOTEBOOK_DIR}/config/datapack-train.yaml" \
        --model.load_model="../model/L6-D1024-world-v5base-init.pth" \
        --trainer.callbacks.init_args.dirpath="../checkpoint/datapack-validaiton-train/" \
        --trainer.logger.init_args.name="{WANDB_PREFIX} - Multi Datapack Validation - (deepspeed_stage_1)" \
        --trainer.strategy="deepspeed_stage_1" \
        --trainer.microbatch_size=8 \
        --trainer.max_steps=50 \
        --trainer.devices="{GPU_DEVICES}"

[2024-01-28 09:05:15,732] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[RWKV.model] Running RWKV infctx using 'torch-jit' with torch '2.1.2'
/home/recursal/miniconda3/envs/rwkv-infctx/lib/python3.11/site-packages/lightning/pytorch/cli.py:518: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['fit', '-c', '/home/recursal/RWKV-infctx-trainer/notebook/trainer-v5-validation/config/datapack-train.yaml', '--model.load_model=../model/L6-D1024-world-v5base-init.pth', '--trainer.callbacks.init_args.dirpath=../checkpoint/datapack-validaiton-train/', '--trainer.logger.init_args.name=infctx-v5-datapack - Multi Datapack Validation - (deepspeed_stage_1)', '--trainer.strategy=deepspeed_stage_1', '--trainer.microbatch_size=8', '--trainer.max_steps=50', '--trainer.devices=auto'], args=['fit