rl training
sanagno committed Feb 20, 2023
1 parent 6e081a0 commit b9df5d6
Showing 3 changed files with 23 additions and 12 deletions.
4 changes: 2 additions & 2 deletions model/model_training/configs/config_rl.yaml
@@ -7,17 +7,17 @@ defaults_rlhf:
  epochs: 10
  datasets:
    - oa_private:
        data_path: .cache
        split: rl
        val_split: 0.0
        fraction: 1
        file: 2023-02-10_oasst_prod.jsonl
  cache_dir: .cache
  quantization: false
  seq2seqmodel: false
  output_dir: output
  reward_model_batch_size: 32

debug_rlhf:
  model_name: gpt2
  rank_model: /local/home/sanagnos/general/Open-Assistant/model/reward/instructor/facebook/galactica-125m-finetuned/checkpoint-500/
  sft_model: /local/home/sanagnos/general/Open-Assistant/model/model_training/EleutherAI/pythia-70m-deduped-base-finetuned/checkpoint-20/
  batch_size: 2
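The debug_rlhf block overrides the defaults_rlhf values when the trainer is run with the debug profile. A minimal sketch of how such a merge could look is below; the load_training_conf helper and the use of PyYAML are illustrative assumptions, since the repository resolves its configs through its own utilities (read_yamls and _strtobool, imported in trainer_rl.py).

```python
# Hypothetical sketch only: the repository's own config loading
# (read_yamls / argparse in model_training) may differ in detail.
import argparse

import yaml


def load_training_conf(path, profile="debug_rlhf"):
    """Merge the defaults_rlhf block with an override profile such as debug_rlhf."""
    with open(path) as f:
        conf = yaml.safe_load(f)
    merged = dict(conf["defaults_rlhf"])
    merged.update(conf.get(profile, {}))  # profile values win over the defaults
    return argparse.Namespace(**merged)


training_conf = load_training_conf("model/model_training/configs/config_rl.yaml")
print(training_conf.sft_model, training_conf.batch_size)
```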
6 changes: 3 additions & 3 deletions model/model_training/configs/ppo_config.yaml
@@ -2,7 +2,7 @@ train:
  seq_length: 1024
  epochs: 100
  total_steps: 10000
-  batch_size: 128
+  batch_size: 1

  checkpoint_interval: 10000
  eval_interval: 100
@@ -34,8 +34,8 @@ scheduler:

method:
  name: "ppoconfig"
-  num_rollouts: 128
-  chunk_size: 128
+  num_rollouts: 16
+  chunk_size: 16
  ppo_epochs: 4
  init_kl_coef: 0.05
  target: 6
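trainer_rl.py loads this file into a trlx TRLConfig and then overrides a few fields at runtime (for example train.batch_size, as seen further down in this diff), so the YAML values act as defaults. A rough sketch, assuming the TRLConfig.load_yaml helper available in the trlx releases of this period:

```python
# Sketch, not the exact training entry point: assumes trlx is installed and
# that TRLConfig.load_yaml exists in the trlx version pinned by the project.
from trlx.data.configs import TRLConfig

trlx_config = TRLConfig.load_yaml("model/model_training/configs/ppo_config.yaml")

# The trainer overwrites some of the loaded defaults at runtime, e.g.:
trlx_config.train.batch_size = 1          # matches the reduced batch_size above
print(trlx_config.method.num_rollouts)    # 16 after this commit
print(trlx_config.method.chunk_size)      # 16 after this commit
```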
25 changes: 18 additions & 7 deletions model/model_training/trainer_rl.py
@@ -1,9 +1,10 @@
import argparse
import itertools
import random

import torch
import transformers
import trlx
from custom_datasets.formatting import QA_SPECIAL_TOKENS
from models import get_specific_model
from trlx.data.configs import TRLConfig
from utils import _strtobool, get_dataset, read_yamls
@@ -73,22 +74,32 @@ def rank_model_fn(samples, **kwargs):

train, _ = get_dataset(training_conf, mode="rl")

print([train[i] for i in range(5)])
# trlx requires training data to be a list of prompts
# iterate over the prompts multiple times due to the randomness in the dataset generation
prompts = [train[i] for i in range(len(train)) for _ in range(training_conf.epochs)][:100]

# trlx requires training data to be a list of prompts?
prompts = list(itertools.chain(*[list(train[i]) for i in range(len(train)) for _ in range(training_conf.epochs)]))
random.shuffle(prompts)

model = get_specific_model(
training_conf.sft_model, training_conf.cache_dir, training_conf.quantization, training_conf.seq2seqmodel
training_conf.sft_model,
cache_dir=training_conf.cache_dir,
quantization=training_conf.quantization,
seq2seqmodel=training_conf.seq2seqmodel,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(training_conf.sft_model)

trlx_config.tokenizer.tokenizer_path = training_conf.model_name
trlx_config.tokenizer.tokenizer_path = training_conf.sft_model
trlx_config.model.model_path = training_conf.sft_model
trlx_config.train.batch_size = training_conf.batch_size

trainer = trlx.train(
training_conf.model_name,
training_conf.sft_model,
reward_fn=rank_model_fn,
prompts=prompts,
config=trlx_config,
stop_sequences=[tokenizer.eos_token, QA_SPECIAL_TOKENS["Question"]],
)

training_conf.output_dir = training_conf.output_dir if training_conf.output_dir else training_conf.model_name

trainer.save_pretrained(training_conf.output_dir)
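The prompt construction above flattens every dataset entry, repeated once per epoch because the dataset generation is random, into a single shuffled list of strings, and trlx scores the generated continuations through the reward callback (rank_model_fn in this file). A self-contained toy version of both pieces, with a stand-in dataset and reward in place of the real models:

```python
# Toy illustration of the prompt flattening and the reward callback shape used
# with trlx.train; the dataset and the reward are stand-ins, not the real models.
import itertools
import random

epochs = 2
# Stand-in for the RL split returned by get_dataset: each entry yields one or
# more prompt strings, hence the list(...) / itertools.chain flattening.
train = [["What is RLHF?"], ["Summarize PPO.", "Explain KL penalties."]]

prompts = list(itertools.chain(*[list(train[i]) for i in range(len(train)) for _ in range(epochs)]))
random.shuffle(prompts)
print(len(prompts))  # 6 = 3 prompts x 2 epochs


def rank_model_fn(samples, **kwargs):
    # trlx expects one scalar reward per generated sample; the real rank_model_fn
    # scores samples with the reward model loaded from training_conf.rank_model.
    return [float(len(s)) for s in samples]


print(rank_model_fn(prompts[:2]))
```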
