Commit a6d445d

Try approach 4 on StarCoderBase-7b; bring back LoRA

mhyee committed Oct 30, 2023
1 parent 7e1ed97 commit a6d445d
Showing 7 changed files with 181 additions and 9 deletions.
23 changes: 23 additions & 0 deletions finetune/README.md
@@ -38,6 +38,10 @@ See the instructions in `../README.md` for downloading the

## Running fine-tuning

To fine-tune cheaply and efficiently, we use Hugging Face 🤗's
[PEFT](https://github.com/huggingface/peft) as well as Tim Dettmers'
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes).

The fine-tuning is based on the
[ts-training](https://huggingface.co/datasets/nuprl/ts-training) dataset,
revision `v1.1p1`. **You will need to accept the agreement** on the dataset
@@ -71,6 +75,25 @@ CUDA_VISIBLE_DEVICES=2,3 torchrun --nproc-per-node 2 run_finetune.py
Use `CUDA_VISIBLE_DEVICES` to select the GPUs for fine-tuning, and
`--nproc-per-node` to specify the number of GPUs to use.

## Merging PEFT adapter layers

If you train a model with PEFT, you'll need to merge the adapter layers with
the base model if you want to run inference / evaluation. To do so, run:

```bash
python merge_peft_adapters.py \
--peft_model_path ../checkpoints/checkpoint-1000 \
--output ../../models/merged_model
```

By default, the base model is assumed to be starcoderbase-1b, located in
`../../models/starcoderbase-1b`. A different base model can be specified as a
path or model ID, e.g. `--base_model_name_or_path bigcode/starcoder`.

By default, the merged model and tokenizer are saved to disk at the path given
by `--output`. Passing the `--push_to_hub` flag instead uploads the merged
model and tokenizer to the Hugging Face Hub under
`<base_model_name_or_path>-merged`.
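
For reference, here is a minimal sketch of loading the merged model for
inference with `transformers` (assuming it was saved to
`../../models/merged_model`, as in the example above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_path = "../../models/merged_model"  # path from the example above
tokenizer = AutoTokenizer.from_pretrained(merged_path)
model = AutoModelForCausalLM.from_pretrained(merged_path)

# Quick smoke test: complete a short TypeScript snippet.
inputs = tokenizer("function add(a: number, b: number)", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```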

## Copying tokenizer files

Checkpoints contain additional files that are not needed for
66 changes: 64 additions & 2 deletions finetune/finetune_lib.py
@@ -1,14 +1,25 @@
from accelerate import Accelerator
from collections.abc import Callable
from dataclasses import dataclass
from datasets import Dataset, IterableDataset
from pathlib import Path
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
set_peft_model_state_dict,
)
from torch.utils.data import IterableDataset as TorchIterableDataset
from transformers import (
AutoModelForCausalLM,
PreTrainedTokenizer,
Trainer,
TrainingArguments,
TrainerCallback,
TrainerControl,
TrainerState,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from typing import Optional
import torch

@@ -41,15 +52,57 @@ class DatasetConfig:
seq_length: int = 2048


class SavePeftModelCallback(TrainerCallback):
    """Save only the PEFT adapter weights at each checkpoint."""

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        checkpoint_folder = Path(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        # Save the adapter weights (adapter_model.bin) into the checkpoint folder.
        kwargs["model"].save_pretrained(checkpoint_folder)
        # Replace pytorch_model.bin with an empty state dict so the full model
        # weights are not duplicated in every checkpoint folder.
        pytorch_model_path = Path(checkpoint_folder, "pytorch_model.bin")
        torch.save({}, pytorch_model_path)
        return control


class LoadBestPeftModelCallback(TrainerCallback):
def on_train_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
) -> TrainerControl:
print(
f"Loading best peft model from {state.best_model_checkpoint} "
f"(score: {state.best_metric})."
)
best_model_path = Path(state.best_model_checkpoint, "adapter_model.bin")
adapters_weights = torch.load(best_model_path)
model = kwargs["model"]
set_peft_model_state_dict(model, adapters_weights)
return control


def print_trainable_parameters(model) -> None:
"""
Prints the number of trainable parameters in the model.
"""
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(f"trainable params: {trainable_params}")
print(
f"trainable params: {trainable_params} || "
f"all params: {all_param} || "
f"trainable%: {100 * trainable_params / all_param}"
)


class ConstantLengthDataset(TorchIterableDataset):
@@ -194,6 +247,7 @@ def create_datasets(
def run_training(
model_path: str,
training_args: TrainingArguments,
lora_config: Optional[LoraConfig],
train_data: TorchIterableDataset,
val_data: TorchIterableDataset,
):
@@ -202,9 +256,16 @@
model = AutoModelForCausalLM.from_pretrained(
model_path,
use_cache=not training_args.gradient_checkpointing,
device_map=None,
load_in_8bit=lora_config is not None,
device_map={"": Accelerator().process_index} if lora_config else None,
)

callbacks = []
if lora_config:
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
callbacks = [SavePeftModelCallback, LoadBestPeftModelCallback]

print_trainable_parameters(model)

print("Starting main loop")
@@ -213,6 +274,7 @@
args=training_args,
train_dataset=train_data,
eval_dataset=val_data,
callbacks=callbacks,
)

print("Training...")
49 changes: 49 additions & 0 deletions finetune/merge_peft_adapters.py
@@ -0,0 +1,49 @@
from pathlib import Path
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import argparse
import torch

MODEL_PATH = str(
Path(Path(__file__).parent, "..", "..", "models", "starcoderbase-1b").resolve()
)


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--base_model_name_or_path", type=str, default=MODEL_PATH)
parser.add_argument("--peft_model_path", type=str, required=True)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--output", type=str, required=True)

return parser.parse_args()


def main():
args = get_args()

base_model = AutoModelForCausalLM.from_pretrained(
args.base_model_name_or_path, return_dict=True, torch_dtype=torch.float16
)

model = PeftModel.from_pretrained(base_model, args.peft_model_path)
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path)

if args.push_to_hub:
print("Saving to hub ...")
model.push_to_hub(
f"{args.base_model_name_or_path}-merged", use_temp_dir=False, private=True
)
tokenizer.push_to_hub(
f"{args.base_model_name_or_path}-merged", use_temp_dir=False, private=True
)
else:
model.save_pretrained(args.output)
tokenizer.save_pretrained(args.output)
print(f"Model saved to {args.output}")


if __name__ == "__main__":
main()
4 changes: 3 additions & 1 deletion finetune/requirements.txt
@@ -1,2 +1,4 @@
accelerate @ git+https://github.com/huggingface/accelerate
accelerate==0.24.0
bitsandbytes==0.41.1
peft==0.5.0
wandb==0.15.5
38 changes: 33 additions & 5 deletions finetune/run_finetune.py
@@ -1,6 +1,8 @@
from datasets import Dataset, IterableDataset
from pathlib import Path
from peft import LoraConfig
from transformers import AutoTokenizer, TrainingArguments, logging, set_seed
from typing import Optional
import argparse
import os

@@ -23,7 +25,7 @@
"""

MODEL_PATH = str(
Path(Path(__file__).parent, "..", "..", "models", "starcoderbase-1b").resolve()
Path(Path(__file__).parent, "..", "..", "models", "starcoderbase-7b").resolve()
)

# We are using a very large dataset, so it's not feasible to download the whole
@@ -39,14 +41,26 @@
########## StarCoder-1B on an A100/H100
# We pack the tokens into a ConstantLengthDataset,
# where each example has SEQUENCE_LENGTH tokens
# SEQUENCE_LENGTH = 8 * 1024
# EPOCHS = 1
# BATCH_SIZE = 2
# GRADIENT_ACCUMULATION_STEPS = 4

# Roughly 1.7M examples
# Roughly 217K / NUM_GPUS steps
########## StarCoder-1B on an A100/H100

########## StarCoder-7B on an A100 with LoRA
# We pack the tokens into a ConstantLengthDataset,
# where each example has SEQUENCE_LENGTH tokens
SEQUENCE_LENGTH = 8 * 1024
EPOCHS = 1
BATCH_SIZE = 2
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 4

# Roughly 1.7M examples
# Roughly 217K / NUM_GPUS steps
########## StarCoder-1B on an A100/H100
# Roughly 433K / NUM_GPUS steps
########## StarCoder-7B on an A100 with LoRA

NUM_EXAMPLES = TOTAL_TOKENS // SEQUENCE_LENGTH

@@ -82,13 +96,25 @@
dataloader_drop_last=True,
eval_steps=50, # save_steps must be a multiple of eval_steps
run_name="StarCoder-finetuned",
load_best_model_at_end=True, # needed for LoRA callbacks
optim="adamw_torch",
report_to="wandb",
ddp_find_unused_parameters=False,
resume_from_checkpoint=False, # only set to True if there is an existing checkpoint!
gradient_checkpointing=True,
)

# If not using LoRA, set to None
# LORA_CONFIG: Optional[LoraConfig] = None
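# LoRA hyperparameters: r is the adapter rank and lora_alpha the scaling
# factor (PEFT scales the update by lora_alpha / r); target_modules names the
# StarCoder attention and projection layers that receive adapters. With
# bias="none", only the adapter weights are trainable.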
LORA_CONFIG: Optional[LoraConfig] = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["c_proj", "c_attn", "q_attn"],
)

DATASET_CONFIG = DatasetConfig(
get_content=get_training_example.default,
streaming=True,
@@ -124,7 +150,9 @@ def main():
train_dataset, eval_dataset = finetune.create_datasets(
dataset, tokenizer, DATASET_CONFIG, args.seed
)
finetune.run_training(MODEL_PATH, TRAINING_ARGS, train_dataset, eval_dataset)
finetune.run_training(
MODEL_PATH, TRAINING_ARGS, LORA_CONFIG, train_dataset, eval_dataset
)


if __name__ == "__main__":
3 changes: 3 additions & 0 deletions mypy.ini
@@ -23,6 +23,9 @@ ignore_missing_imports = True
[mypy-transformers]
ignore_missing_imports = True

[mypy-transformers.trainer_utils]
ignore_missing_imports = True

[mypy-transformers.utils]
ignore_missing_imports = True

7 changes: 6 additions & 1 deletion util/__init__.py
@@ -26,6 +26,7 @@ def load_dataset(
split: Optional[str] = None,
revision: Optional[str] = None,
num_proc: Optional[int] = None,
streaming: Optional[bool] = None,
) -> Dataset | IterableDataset:
"""
Load a dataset. Tries to interpret dataset as a path and loads a local file
@@ -43,7 +44,11 @@
else:
print(f"Loading dataset {dataset} from the Hugging Face Hub...", flush=True)
return datasets.load_dataset(
dataset, split=split, revision=revision, num_proc=num_proc
dataset,
split=split,
revision=revision,
num_proc=num_proc,
streaming=streaming,
)


