[NeuralChat] Added StarCoder, CodeLlama, Falcon and Mistral finetuning example in NeuralChat (#649)
XinyuYe-Intel committed Nov 9, 2023
1 parent db35a30 commit 477018d
Showing 7 changed files with 189 additions and 15 deletions.
51 changes: 42 additions & 9 deletions intel_extension_for_transformers/llm/finetuning/data_utils.py
@@ -19,6 +19,7 @@
import datasets
import re
from itertools import chain
from intel_extension_for_transformers.neural_chat.prompts.prompt import PromptTemplate

IGNORE_INDEX = -100

@@ -47,21 +48,51 @@ def truncate_sequences(sequences, max_length):
return sequences

class CompletionDataPreprocess:
    prompt_template = ALPACA_PROMPT_DICT
    def __init__(self, dataset_name):
        self.dataset_name = dataset_name.lower()
        if "alpaca" in self.dataset_name:
            self.prompt_template = [
                PromptTemplate("alpaca_without_input"),
                PromptTemplate("alpaca_with_input")
            ]
            self.key_role_map = [
                [('instruction', 0), ('output', 1)],
                [('instruction', 0), ('input', 1), ('output', 2)]
            ]
        elif "stack-exchange-instruction" in self.dataset_name:
            self.prompt_template = PromptTemplate("question_answer")
            self.key_role_map = [('question', 0), ('response', 1)]
        else:
            raise NotImplementedError(
                f"Unsupported dataset {dataset_name}, "
                "only supports stack-exchange-instruction and Alpaca-like datasets now."
            )


    def create_data(self, examples):
        prompts = {}
        prompts["source"] = []
        prompts["target"] = []
        for example in examples:
            prompt_template = (
                self.prompt_template["prompt_with_input"]
                if example.get("input") is not None and example.get("input") != ""
                else self.prompt_template["prompt_without_input"]
            )
            source = prompt_template.format_map(example)
            prompt_template = self.prompt_template
            key_role_map = self.key_role_map
            if "alpaca" in self.dataset_name:
                if "input" in example and example["input"]:
                    prompt_template = self.prompt_template[1]
                    key_role_map = self.key_role_map[1]
                else:
                    prompt_template = self.prompt_template[0]
                    key_role_map = self.key_role_map[0]

            for idx, (key, role) in enumerate(key_role_map):
                message = example[key]
                if idx == len(key_role_map)-1:
                    message = ""
                prompt_template.append_message(prompt_template.roles[role], message)
            source = prompt_template.get_prompt()
            prompts["source"].append(source)
            prompts["target"].append(example["output"])
            prompts["target"].append(example[key_role_map[-1][0]])
            prompt_template.clear_messages()
        return prompts

@staticmethod
@@ -306,7 +337,9 @@ def preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args):

    elif finetune_args.task == "completion" or finetune_args.task == "code-generation":
        # default use alpaca template
        preprocess = CompletionDataPreprocess()
        preprocess = CompletionDataPreprocess(
            data_args.dataset_name if data_args.dataset_name else data_args.train_file
        )
        for key in raw_datasets:
            prompts = preprocess.create_data(raw_datasets[key])
            columns_to_be_removed = list(raw_datasets[key].features.keys())
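
For readers skimming the diff, here is a small hedged sketch of how the new preprocessing path could be exercised. The record below is made up, the dataset name is only an example of an "alpaca"-style name, and the import path is taken from the changed file above.

```python
# Hypothetical usage of the new CompletionDataPreprocess (illustration only, not part of the diff).
from intel_extension_for_transformers.llm.finetuning.data_utils import CompletionDataPreprocess

preprocess = CompletionDataPreprocess("tatsu-lab/alpaca")  # any dataset name containing "alpaca"
examples = [{"instruction": "Add two numbers.", "input": "2 and 3", "output": "5"}]

prompts = preprocess.create_data(examples)
print(prompts["source"][0])  # rendered alpaca_with_input prompt, answer slot left empty
print(prompts["target"][0])  # "5"
```
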
@@ -344,7 +344,8 @@ def finetune_clm(self, model_args, data_args, training_args, finetune_args, conf
            low_cpu_mem_usage=True,
        )
        if not (re.search("mpt", model_args.model_name_or_path, re.IGNORECASE) or
                re.search("neural-chat-7b-v1", model_args.model_name_or_path, re.IGNORECASE)):
                re.search("neural-chat-7b-v1", model_args.model_name_or_path, re.IGNORECASE) or
                re.search("starcoder", model_args.model_name_or_path, re.IGNORECASE)):
            tokenizer.padding_side = "left"  # allow batched inference, while mpt series don't support
        else:
            raise ValueError(
@@ -66,7 +66,7 @@ Users could follow below commands to get the checkpoints from github repository
git lfs install
git clone https://huggingface.co/meta-llama/Llama-2-7b
```
#### MPT
### MPT
To acquire the checkpoints and tokenizer, the user can get those files from [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b).
Users can follow the commands below to clone the checkpoints from the Hugging Face repository.
```bash
git lfs install
git clone https://huggingface.co/mosaicml/mpt-7b
```
If you hit the missing GPTNeoXTokenizer issue, we advise the user to modify the local `tokenizer_config.json` file according to the following recommendation (a small script sketch follows this list):
1. Change `tokenizer_class` in `tokenizer_config.json` from `GPTNeoXTokenizer` to `GPTNeoXTokenizerFast`.
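
For example, the change can be made by hand or with a small one-off script like the hedged sketch below; the checkpoint path is an assumption about where `mpt-7b` was cloned.

```python
# One-off fix (illustration): point the cloned MPT tokenizer config at the fast GPT-NeoX tokenizer class.
import json

cfg_path = "mpt-7b/tokenizer_config.json"  # adjust to wherever the checkpoint was cloned
with open(cfg_path) as f:
    cfg = json.load(f)
cfg["tokenizer_class"] = "GPTNeoXTokenizerFast"  # was "GPTNeoXTokenizer"
with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)
```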

### Falcon
To acquire the checkpoints and tokenizer, the user can get those files from [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b).
Users can follow the commands below to clone the checkpoints from the Hugging Face repository.
```bash
git lfs install
git clone https://huggingface.co/tiiuae/falcon-7b
```

### Mistral
To acquire the checkpoints and tokenizer, the user can get those files from [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1).
Users can follow the commands below to clone the checkpoints from the Hugging Face repository.
```bash
git lfs install
git clone https://huggingface.co/mistralai/Mistral-7B-v0.1
```

### CodeLlama
To acquire the checkpoints and tokenizer, the user can get those files from [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf).
Users can follow the commands below to clone the checkpoints from the Hugging Face repository.
```bash
git lfs install
git clone https://huggingface.co/codellama/CodeLlama-7b-hf
```

### StarCoder
To acquire the checkpoints and tokenizer, the user can get those files from [bigcode/starcoder](https://huggingface.co/bigcode/starcoder).
Users can follow the commands below to clone the checkpoints from the Hugging Face repository.
```bash
git lfs install
git clone https://huggingface.co/bigcode/starcoder
```

### FLAN-T5
The user can obtain the [release model](https://huggingface.co/google/flan-t5-xl) from Huggingface.

@@ -228,7 +260,29 @@ python finetune_clm.py \
# the script also supports other models, like MPT.
```

**For [CodeLlama](https://huggingface.co/codellama/CodeLlama-7b-hf)**, use the command line below to finetune on the [sahil2801/CodeAlpaca-20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) code instruction dataset.

```bash
python finetune_clm.py \
--model_name_or_path "codellama/CodeLlama-7b-hf" \
--bf16 True \
--dataset_name "sahil2801/CodeAlpaca-20k" \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 1 \
--do_train \
--learning_rate 1e-4 \
--num_train_epochs 3 \
--logging_steps 100 \
--save_total_limit 2 \
--overwrite_output_dir \
--log_level info \
--save_strategy epoch \
--output_dir ./codellama_peft_finetuned_model \
--peft lora \
--use_fast_tokenizer True \
--no_cuda
```

**For [MPT](https://huggingface.co/mosaicml/mpt-7b)**, use the command line below to finetune on the Alpaca dataset. From the PEFT perspective, only LoRA supports MPT. MPT uses the gpt-neox-20b tokenizer, so you need to specify it explicitly on the command line. This model also requires `trust_remote_code=True` to be passed to the `from_pretrained` method, because the custom MPT model architecture is not yet part of the Hugging Face transformers package (see the short loading sketch after the command below).

@@ -257,6 +311,80 @@ python finetune_clm.py \
--no_cuda \
```
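
To make the tokenizer and `trust_remote_code` note concrete, the snippet below is a minimal loading sketch using stock `transformers` APIs; it is illustrative only, and the finetuning script handles this internally.

```python
# Illustration: MPT ships a custom architecture, so trust_remote_code=True is required,
# and it reuses the EleutherAI gpt-neox-20b tokenizer instead of bundling its own.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
```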

**For Falcon**, use the command line below to finetune on the Alpaca dataset.

```bash
python finetune_clm.py \
--model_name_or_path "tiiuae/falcon-7b" \
--bf16 True \
--train_file "/path/to/alpaca_data.json" \
--dataset_concatenation \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 1 \
--do_train \
--learning_rate 1e-4 \
--num_train_epochs 3 \
--logging_steps 100 \
--save_total_limit 2 \
--overwrite_output_dir \
--log_level info \
--save_strategy epoch \
--output_dir ./falcon_peft_finetuned_model \
--peft lora \
--use_fast_tokenizer True \
--no_cuda \
```

**For Mistral**, use the command line below to finetune on the Alpaca dataset.

```bash
python finetune_clm.py \
--model_name_or_path "mistralai/Mistral-7B-v0.1" \
--bf16 True \
--train_file "/path/to/alpaca_data.json" \
--dataset_concatenation \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 1 \
--do_train \
--learning_rate 1e-4 \
--num_train_epochs 3 \
--logging_steps 100 \
--save_total_limit 2 \
--overwrite_output_dir \
--log_level info \
--save_strategy epoch \
--output_dir ./mistral_peft_finetuned_model \
--peft lora \
--use_fast_tokenizer True \
--no_cuda \
```

**For StarCoder**, use the command line below to finetune on the [sahil2801/CodeAlpaca-20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) code instruction dataset.

```bash
python finetune_clm.py \
--model_name_or_path "bigcode/starcoder" \
--bf16 True \
--dataset_name "sahil2801/CodeAlpaca-20k" \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 1 \
--do_train \
--learning_rate 1e-4 \
--num_train_epochs 3 \
--logging_steps 100 \
--save_total_limit 2 \
--overwrite_output_dir \
--log_level info \
--save_strategy epoch \
--output_dir ./starcoder_peft_finetuned_model \
--peft lora \
--use_fast_tokenizer True \
--no_cuda \
```

The `--dataset_concatenation` argument greatly accelerates fine-tuning by concatenating training samples: several tokenized sentences are packed into one longer, denser training sample instead of many samples of varying length, which is more efficient because the concentrated samples make better use of parallelism.
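
As a rough illustration of the idea (not the script's exact implementation), sample concatenation can be pictured like this:

```python
# Simplified sketch of training-sample concatenation ("packing"); the real code in
# data_utils.py differs in detail (labels, attention masks, special tokens, etc.).
from itertools import chain

def concatenate_samples(tokenized_samples, block_size):
    """Pack variable-length token-id lists into fixed-size blocks."""
    all_ids = list(chain.from_iterable(tokenized_samples))
    total = (len(all_ids) // block_size) * block_size  # drop the ragged tail
    return [all_ids[i:i + block_size] for i in range(0, total, block_size)]

# Three short samples become two dense blocks of length 4.
print(concatenate_samples([[1, 2], [3, 4, 5], [6, 7, 8]], block_size=4))
# [[1, 2, 3, 4], [5, 6, 7, 8]]
```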

For finetuning on SPR, adding the `--bf16` argument will speed up the finetuning process without degrading the model's performance.
14 changes: 13 additions & 1 deletion intel_extension_for_transformers/neural_chat/prompts/prompt.py
@@ -157,6 +157,16 @@
)
)

# QA template
register_conv_template(
    Conversation(
        name="question_answer",
        roles=("Question: ", "Answer: "),
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="\n\n",
    )
)

class PromptTemplate:
    def __init__(self, name="one_shot"):
        self.conv = get_conv_template(name)
@@ -170,4 +180,6 @@ def append_message(self, role: str, message: str):

    def get_prompt(self) -> str:
        return self.conv.get_prompt()


    def clear_messages(self) -> str:
        self.conv.messages = []
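
A small usage sketch of the newly registered `question_answer` template together with the new `clear_messages` helper; the rendered output shown in the comment is an expectation based on the roles and separator above, not verified output.

```python
# Hypothetical usage of the "question_answer" template (illustration only).
from intel_extension_for_transformers.neural_chat.prompts.prompt import PromptTemplate

pt = PromptTemplate("question_answer")
pt.append_message(pt.roles[0], "How do I reverse a list in Python?")
pt.append_message(pt.roles[1], "")  # leave the answer empty for the model to complete

print(pt.get_prompt())
# Roughly: "Question: How do I reverse a list in Python?\n\nAnswer: "

pt.clear_messages()  # reset the conversation before building the next prompt
```
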
@@ -34,7 +34,7 @@ rouge_score
openpyxl
numpy==1.23.5
tiktoken==0.4.0
lm_eval
git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
spacy
neural-compressor==2.3.1
intel_extension_for_pytorch
@@ -35,7 +35,7 @@ rouge_score
openpyxl
numpy==1.23.5
tiktoken==0.4.0
lm_eval
git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.1.0
torchaudio==2.1.0
@@ -34,7 +34,7 @@
{"instruction": "Provide the word that comes immediately after the.", "input": "He threw the ball over the fence.", "output": "fence."}
]
"""
test_data_file = './test.json'
test_data_file = './alpaca_test.json'

class TestFinetuning(unittest.TestCase):
    @classmethod
