Commit: stackllama
mnoukhov committed Oct 26, 2023
1 parent e39a272 commit 814e0a9
Showing 28 changed files with 2,826 additions and 0 deletions.
18 changes: 18 additions & 0 deletions README.md
@@ -0,0 +1,18 @@
## Language Model Alignment with Elastic Reset

Code and experiments for ...

See the corresponding folder for the code and experiments.


## Citation

```
@inproceedings{noukhovitch_language_2023,
title = {Language Model Alignment with Elastic Reset},
author = {Noukhovitch, Michael and Lavoie, Samuel and Strub, Florian and Courville, Aaron},
booktitle = {Neural Information Processing Systems},
year = {2023},
}
```

23 changes: 23 additions & 0 deletions stackllama/README.md
@@ -0,0 +1,23 @@
# RLHF pipeline for the creation of StackLLaMA: a Stack Exchange llama-7b model
There were three main steps to the training process:
1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se.
2. Reward modeling using answer pairs from the SE dataset, starting from llama-7b-se, to create llama-7b-se-rm.
3. RL fine-tuning of llama-7b-se against the llama-7b-se-rm reward model.
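The reward modeling in step 2 is trained on preference pairs with the usual Bradley-Terry style objective: minimize `-log sigmoid(r_chosen - r_rejected)` so the model scores the preferred answer higher. A minimal scalar sketch of that loss (the function name is mine, not from this repo):

```python
import math

def pairwise_rm_loss(r_chosen: float, r_rejected: float) -> float:
    """Pairwise reward-modeling loss: -log(sigmoid(r_chosen - r_rejected)).

    Small when the chosen answer is scored above the rejected one,
    and grows as the margin flips in favor of the rejected answer.
    """
    margin = r_chosen - r_rejected
    return -math.log(1.0 / (1.0 + math.exp(-margin)))
```

In practice this is computed over batches of (chosen, rejected) reward pairs from the model's scalar head; the scalar version above only illustrates the shape of the objective.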

For each step, run `python run.py -e configs/` with the corresponding config file.


## Pretrained Models

My LoRA layers for the vanilla StackLLaMA are publicly available on Hugging Face as:
- `mnoukhov/llama-7b-se-peft`
- `mnoukhov/llama-7b-se-rm-peft`
- `mnoukhov/llama-7b-se-rl-peft`

LoRA layers were used at all stages to reduce memory requirements.
At each stage, the PEFT adapter layers were merged with the base model using:
```shell
python examples/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ
```

I used `huggyllama/llama-7b` as the base model.
8 changes: 8 additions & 0 deletions stackllama/configs/humaneval_llama_7b.yml
@@ -0,0 +1,8 @@
tasks: humaneval
model: "huggyllama/llama-7b"
max_length_generation: 512
temperature: 0.2
top_p: 0.95
n_samples: 20
batch_size: 10
allow_code_execution: True
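`n_samples: 20` generates multiple completions per HumanEval problem so that pass@k can be estimated. The standard unbiased estimator (from the Codex paper) can be sketched as follows; the function name is mine:

```python
from math import comb

def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate, given n samples of which c pass the tests.

    pass@k = 1 - C(n - c, k) / C(n, k): one minus the probability that a
    random size-k subset of the n samples contains no passing completion.
    """
    if n - c < k:
        return 1.0  # every size-k subset must contain a passing sample
    return 1.0 - comb(n - c, k) / comb(n, k)
```

With `n_samples: 20` per problem, this gives low-variance estimates for small k (e.g. pass@1 and pass@10) and the per-problem values are averaged over the benchmark.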
8 changes: 8 additions & 0 deletions stackllama/configs/humaneval_llama_7b_rlhf.yml
@@ -0,0 +1,8 @@
model_ckpt: "llama-7b-rlhf"
tokenizer: "huggyllama/llama-7b"
load_8bit: True
do_sample: True
temperature: 0.2
top_p: 0.95
n_samples: 200
HF_ALLOW_CODE_EVAL: "1"
8 changes: 8 additions & 0 deletions stackllama/configs/humaneval_llama_7b_sft.yml
@@ -0,0 +1,8 @@
model_ckpt: "llama-7b-sft"
tokenizer: "huggyllama/llama-7b"
load_8bit: True
do_sample: True
temperature: 0.2
top_p: 0.95
n_samples: 200
HF_ALLOW_CODE_EVAL: "1"
11 changes: 11 additions & 0 deletions stackllama/configs/ppl_llama_7b.yml
@@ -0,0 +1,11 @@
model: "huggyllama/llama-7b"
tokenizer: "huggyllama/llama-7b"
dataset_name: "lvwerra/stack-exchange-paired"
subset: "data/evaluation"
split: "train"
data_size: 4000
seq_length: 1024
stride: 512
seed: 0
batch_size: 1
bit8: True
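The `seq_length`/`stride` pair implements the usual sliding-window perplexity evaluation: each 1024-token window advances by 512 tokens, and only the tokens not already covered by the previous window contribute to the loss. The windowing logic can be sketched as (names are mine, not this repo's):

```python
def strided_windows(n_tokens: int, seq_length: int = 1024, stride: int = 512):
    """Return (begin, end, n_scored) windows over a token sequence.

    Each window spans up to seq_length tokens; consecutive windows advance
    by stride, and n_scored counts only the tokens a window adds beyond
    the previous one, so every token is scored exactly once.
    """
    windows = []
    prev_end = 0
    for begin in range(0, n_tokens, stride):
        end = min(begin + seq_length, n_tokens)
        windows.append((begin, end, end - prev_end))
        prev_end = end
        if end == n_tokens:
            break
    return windows
```

In the full evaluation, the model's negative log-likelihood is computed over each window but summed only for the `n_scored` trailing tokens, and perplexity is `exp(total_nll / n_tokens)`.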
11 changes: 11 additions & 0 deletions stackllama/configs/ppl_llama_7b_sft.yml
@@ -0,0 +1,11 @@
model: "llama-7b-sft"
tokenizer: "hf-internal-testing/llama-tokenizer"
dataset_name: "lvwerra/stack-exchange-paired"
subset: "data/evaluation"
split: "train"
data_size: 4000
seq_length: 1024
stride: 512
seed: 0
batch_size: 1
bit8: True
11 changes: 11 additions & 0 deletions stackllama/configs/ppl_llama_7b_sft_elasticreset.yml
@@ -0,0 +1,11 @@
model: llama-7b-se-elasticreset
tokenizer: huggyllama/llama-7b
dataset_name: "lvwerra/stack-exchange-paired"
subset: "data/evaluation"
split: "train"
data_size: 4000
seq_length: 1024
stride: 512
seed: 0
batch_size: 1
bit8: True
11 changes: 11 additions & 0 deletions stackllama/configs/ppl_llama_7b_sft_rlhf.yml
@@ -0,0 +1,11 @@
model: "llama-7b-rlhf"
tokenizer: "huggyllama/llama-7b"
dataset_name: "lvwerra/stack-exchange-paired"
subset: "data/evaluation"
split: "train"
data_size: 4000
seq_length: 1024
stride: 512
seed: 0
batch_size: 1
bit8: True
10 changes: 10 additions & 0 deletions stackllama/configs/rew_elasticreset.yaml
@@ -0,0 +1,10 @@
model_name: "llama-7b-rlhf-er-redo-600"
reward_model_name: kashif/llama-7b_stack-exchange_RM_peft-adapter-merged
save_freq: 100
batch_size: 8
gradient_accumulation_steps: 8
batched_gen: True
output_dir: results/
early_stopping: True
seed: 0
steps: 8
10 changes: 10 additions & 0 deletions stackllama/configs/rew_ppo.yml
@@ -0,0 +1,10 @@
model_name: "llama-7b-rlhf-ppo-600"
reward_model_name: kashif/llama-7b_stack-exchange_RM_peft-adapter-merged
save_freq: 100
batch_size: 8
gradient_accumulation_steps: 8
batched_gen: True
output_dir: results/
early_stopping: True
seed: 0
steps: 8
10 changes: 10 additions & 0 deletions stackllama/configs/rew_sft.yml
@@ -0,0 +1,10 @@
model_name: "llama-7b-sft"
reward_model_name: kashif/llama-7b_stack-exchange_RM_peft-adapter-merged
save_freq: 100
batch_size: 8
gradient_accumulation_steps: 8
batched_gen: True
output_dir: results/
early_stopping: True
seed: 0
steps: 8
@@ -0,0 +1,14 @@
model_name: "llama-7b-se"
tokenizer_name: "huggyllama/llama-7b"
reward_model_name: "llama-7b-se-rm"
log_with: "wandb"
save_freq: 100
batch_size: 16
gradient_accumulation_steps: 8
batched_gen: True
output_dir: results/
early_stopping: True
seed: 0
reset_freq: 260
ema_decay: 0.995
init_kl_coef: 0.02
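`reset_freq` and `ema_decay` drive the Elastic Reset schedule: an exponential moving average (EMA) of the policy is maintained and, every `reset_freq` steps, the online policy is reset to the EMA and the EMA is reset back to the initial model. A toy one-parameter sketch of the mechanic (gradients and hyperparameter values here are made up for illustration):

```python
def elastic_reset(grads, init=0.0, lr=0.1, ema_decay=0.995, reset_freq=260):
    """Toy 1-D Elastic Reset loop.

    The online parameter follows gradient steps while an EMA tracks it;
    every reset_freq steps the online parameter is reset to the EMA and
    the EMA is reset to the initial parameter.
    """
    online, ema = init, init
    for step, g in enumerate(grads, start=1):
        online -= lr * g                                   # policy update
        ema = ema_decay * ema + (1 - ema_decay) * online   # EMA tracking
        if step % reset_freq == 0:
            online, ema = ema, init                        # elastic reset
    return online, ema
```

In the real training run the same logic applies elementwise to the policy's parameters; with `ema_decay: 0.995` the EMA lags the policy by roughly 200 steps, so a `reset_freq` of 260 resets the policy to a recent but smoothed version of itself.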
12 changes: 12 additions & 0 deletions stackllama/configs/rlhf_llama_7b_sft_mine.yml
@@ -0,0 +1,12 @@
model_name: "llama-7b-se"
tokenizer_name: "hf-internal-testing/llama-tokenizer"
reward_model_name: "llama-7b-se-rm"
log_with: "wandb"
save_freq: 100
batch_size: 8
gradient_accumulation_steps: 8
batched_gen: True
output_dir: results/
early_stopping: True
seed: 0
init_kl_coef: 0.2
2 changes: 2 additions & 0 deletions stackllama/configs/rm_llama_7b_sft.yml
@@ -0,0 +1,2 @@
model_name: "llama-7b-sft"
tokenizer_name: hf-internal-testing/llama-tokenizer
7 changes: 7 additions & 0 deletions stackllama/configs/sft_llama_7b.yml
@@ -0,0 +1,7 @@
model_path: huggyllama/llama-7b
streaming: True
no_gradient_checkpointing: True
learning_rate: 1e-5
max_steps: 5000
output_dir: llama-sft
gradient_accumulation_steps: 2
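`gradient_accumulation_steps: 2` means the optimizer steps once per two micro-batches, with each micro-batch gradient scaled by 1/2 so the update matches a doubled effective batch size. A scalar sketch of the mechanic (not this repo's trainer):

```python
def sgd_with_accumulation(param, lr, micro_grads, accum_steps=2):
    """Toy scalar SGD with gradient accumulation.

    Gradients from accum_steps micro-batches are averaged into a buffer
    before each optimizer step; an incomplete trailing group is discarded.
    """
    buf = 0.0
    for i, g in enumerate(micro_grads, start=1):
        buf += g / accum_steps       # buf accumulates the mean gradient
        if i % accum_steps == 0:
            param -= lr * buf        # one optimizer step per group
            buf = 0.0
    return param
```

This keeps per-device memory at the micro-batch size while preserving the training dynamics of the larger batch.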