<a href="https://colab.research.google.com/github/mrsteyk/RWKV-LM-deepspeed/blob/master/RWKV_v4neo_Trainer_STK.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **RWKV-v4neo finetuner**

**THIS IS A COLAB NOTEBOOK. IF YOU RUN LOCALLY, USE THE LOCAL SCRIPTS INSTEAD.**

Heads-up: on Colab, expect to see `Your session crashed after using all available RAM.` at least once — this notebook is memory-hungry.

This Colab notebook is a port of [https://github.com/mrsteyk/RWKV-LM-deepspeed](https://github.com/mrsteyk/RWKV-LM-deepspeed).

In [None]:
#@title Prereqs
%cd /content
!git clone --depth 1 https://github.com/mrsteyk/RWKV-LM-deepspeed.git
%cd RWKV-LM-deepspeed
!git pull
%cd RWKV-v4neo

!pip install deepspeed pytorch_lightning transformers psutil

In [None]:
#@title Tokenize your stuff
import numpy as np

from transformers import GPTNeoXTokenizerFast
tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")

input_file = 'train.txt'
output_file = 'train.npy'

print(f'Tokenizing {input_file} (VERY slow. please wait)')

# `with` closes the file handle as soon as the text is in memory.
with open(input_file, encoding="utf-8") as input_fh:
    data_raw = input_fh.read()
print(f'Raw length = {len(data_raw)}')

data_code = tokenizer.encode(data_raw)
print(f'Tokenized length = {len(data_code)}')

# An empty token file would only fail later, inside the training dataset.
assert data_code, f'{input_file} produced no tokens'
# Ids are stored as uint16; an id >= 2**16 would wrap silently or raise,
# depending on the NumPy version, so check explicitly.
assert max(data_code) < 2**16, 'token id does not fit in uint16'

out = np.array(data_code, dtype='uint16')
np.save(output_file, out, allow_pickle=False)

In [None]:
!wget https://huggingface.co/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth

In [None]:
#@title Initial Imports
import types

import deepspeed
import os
import torch
import torch.utils.data
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.callbacks import DeviceStatsMonitor, ModelCheckpoint

from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint

import dataset
import lr_warmup
from soft_embedding_hotswap import SoftEmbedding

In [None]:
#@title Arguments

# All run configuration lives in one namespace; later cells read from and
# (in a few places) write to it. Attribute order is kept stable for logging.
args = types.SimpleNamespace(
    vocab_size_delta=1,  # embedding is resized to vocab_size + this after loading
    allgather_bucket_size=200,  # multiplied by 1e6 when a DeepSpeed strategy is used
    reduce_bucket_size=200,  # multiplied by 1e6 when a DeepSpeed strategy is used

    data_file="train.npy",  # output of the tokenizer cell

    batch_size=2,

    load_model_cont='',  # ZeRO checkpoint

    soft_emb_tune=False,  # freeze the model and tune soft-embedding tokens only
    soft_emb_tokens=50,

    load_model_init='./RWKV-4-Pile-430M-20220808-8066.pth',  # Initialise weights with this
    layerwise_lr=True,
    ctx_len=1024,

    # For Soft Embeddings try using larger lr and epsilon (0.1 and 1e-6)
    lr_init=1e-5,  # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048
    lr_final=1e-5,
    warmup_steps=0,  # try 50 if you load a model
    beta1=0.9,
    beta2=0.999,  # def: 0.99; use 0.999 when your model is close to convergence
    adam_eps=1e-8,

    # Model hyperparameters -- must match the checkpoint being loaded
    # (load_state_dict is strict).
    vocab_size=50277,
    n_layer=24,
    n_embd=1024,
    pre_ffn=False,
    head_qk=0,

    tiny_att_dim=0,
    tiny_att_layer=-999,

    # Lightning Trainer settings
    accelerator="gpu",
    devices=1,
    precision=16,  # fp16; "bf16" lacks native support on Turing GPUs such as Colab's T4
    strategy='single_device',
)

In [None]:
#@title Trainer prologue (if you change context len you need to restart)
args.betas = (args.beta1, args.beta2)
rank_zero_info(args)

# These env vars must be set before `import model` below. Python caches
# imported modules, so re-running this cell with a new ctx_len has no effect
# until the runtime is restarted.
assert args.precision in [32, 16, "bf16"]
os.environ["RWKV_FLOAT_MODE"] = "fp16" if args.precision == 16 else str(args.precision)
# Max sequence length: the context, plus the soft-embedding tokens when soft_emb_tune is on.
os.environ["RWKV_T_MAX"] = str(args.ctx_len + (args.soft_emb_tokens if args.soft_emb_tune else 0))

# Now we can import the model after setting the RWKV_T_MAX env var above.
import model as M

# The vocab resize at the end of this block overwrites args.vocab_size. Remember
# the checkpoint's original size on args so that re-running this cell builds a
# model that still matches the checkpoint, instead of resizing twice.
if not hasattr(args, "base_vocab_size"):
    args.base_vocab_size = args.vocab_size
args.vocab_size = args.base_vocab_size

model = M.RWKV(args)

if args.load_model_cont != '' and not args.soft_emb_tune:
    # Resuming: weights are restored later by
    # trainer.fit(..., ckpt_path=args.load_model_cont) in the Train cell.
    pass
elif args.load_model_init != '':
    if os.path.isdir(args.load_model_init):
        # A directory is treated as a DeepSpeed ZeRO checkpoint and
        # consolidated into `model`.
        load_state_dict_from_zero_checkpoint(model, args.load_model_init)
        model.cpu()
        # Cast the consolidated weights to the configured training precision.
        if args.precision == 16:
            model.half()
        elif args.precision == "bf16":
            model.bfloat16()
    else:
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        state_dict = torch.load(args.load_model_init, map_location='cpu')
        # Checkpoints saved from a wrapped module prefix their keys with
        # "_forward_module."; strip it per key so unprefixed keys stay intact.
        prefix = "_forward_module."
        state_dict = {
            (name[len(prefix):] if name.startswith(prefix) else name): tensor
            for name, tensor in state_dict.items()
        }
        model.load_state_dict(state_dict)
        # load_state_dict copies into the model's own parameters; drop the
        # second full copy of the weights so it doesn't linger in RAM.
        del state_dict
else:
    # No checkpoint given; generate_init_weight() presumably initialises the
    # weights from scratch (TODO confirm against model.py).
    model.generate_init_weight()

if args.vocab_size_delta > 0:
    # Grow the embedding to make room for extra tokens, and keep args in sync.
    new_vocab_size = args.vocab_size + args.vocab_size_delta
    model.resize_emb(new_vocab_size)
    args.vocab_size = new_vocab_size

if args.soft_emb_tune:
    # Soft-embedding tuning: freeze every existing weight, then wrap the
    # embedding so only the soft tokens are (presumably) trainable.
    # meme hard, die young
    print("### буду погибать молодым/малоДЫМ(а)")
    # Layer-wise LR groups are switched off in this mode.
    args.layerwise_lr = False
    model.requires_grad_(False)
    model.emb_hotswap = True
    assert args.soft_emb_tokens < args.vocab_size, "Soft Embedding can't eat more than the `emb`"
    # initialize_from_vocab=True: soft tokens are seeded from the existing embedding rows.
    model.emb = SoftEmbedding(model.emb, n_tokens=args.soft_emb_tokens, initialize_from_vocab=True)

# Learning-rate warm-up schedule and CPU/GPU utilisation logging.
lr_meme = lr_warmup.LearningWarmUpCallback(args)
device_stats = DeviceStatsMonitor(cpu_stats=True)

# Keep the three checkpoints with the lowest validation loss...
val_loss_checkpointing = ModelCheckpoint(
    monitor="val_loss",
    mode='min',
    save_top_k=3,
    filename="epoch-{epoch:02d}-val_loss-{val_loss:.2f}",
    auto_insert_metric_name=False,
)
# ...plus one checkpoint for the highest (i.e. latest) epoch, written at the
# end of each training epoch.
epoch_checkpointing = ModelCheckpoint(
    monitor="epoch",
    mode='max',
    save_top_k=1,
    save_on_train_epoch_end=True,
    filename="epoch-{epoch:02d}",
    auto_insert_metric_name=False,
)

# from_argparse_args picks the Trainer options (accelerator, devices,
# precision, strategy, ...) out of args.
trainer = Trainer.from_argparse_args(
    args,
    callbacks=[
        lr_meme,
        device_stats,
        val_loss_checkpointing,
        epoch_checkpointing,
    ],
)
if "deepspeed" in args.strategy:
    # The *_bucket_size args are in millions of elements. DeepSpeed documents
    # these ZeRO options as integers, so don't hand it the scaled float.
    zero_config = trainer.strategy.config["zero_optimization"]
    zero_config["allgather_bucket_size"] = int(round(args.allgather_bucket_size * 1e6))
    zero_config["reduce_bucket_size"] = int(round(args.reduce_bucket_size * 1e6))
    rank_zero_info(trainer.strategy.config)

if args.strategy == "single_device":
    # Pin the single-device strategy to the first GPU. `_root_device` is a
    # private Lightning attribute, so this may break across Lightning versions.
    trainer.strategy._root_device = torch.device('cuda:0')

train_data = dataset.MyDataSet(args)

# TODO(mrsteyk): Allow different validation files
# Hold out 20% of the training samples for validation.
train_set_size = int(len(train_data) * 0.8)
valid_set_size = len(train_data) - train_set_size
# val_loss_checkpointing monitors "val_loss", so both splits must be non-empty.
# Fail here with a readable message rather than later during training.
assert train_set_size > 0 and valid_set_size > 0, (
    f"Dataset has {len(train_data)} samples; need at least 2 for the 80/20 train/validation split"
)

# A fixed generator seed keeps the split identical across runs.
seed = torch.Generator().manual_seed(42)
train_data, valid_data = torch.utils.data.random_split(train_data, [train_set_size, valid_set_size], generator=seed)

train_loader = DataLoader(train_data, shuffle=True, pin_memory=True, batch_size=args.batch_size)
valid_loader = DataLoader(valid_data, shuffle=False, pin_memory=True, batch_size=args.batch_size)

In [None]:
#@title Train

%load_ext tensorboard
%tensorboard --logdir /content/RWKV-LM-deepspeed/RWKV-v4neo/lightning_logs

model.cuda()
trainer.fit(model, train_loader, valid_loader, ckpt_path=args.load_model_cont if args.load_model_cont != ''  else None)