29 changes: 19 additions & 10 deletions docs/finetune_parameters.md
@@ -19,16 +19,25 @@ The following are the parameters supported in the finetuning workflow.


## Dataset Parameters
| Configuration Name | Default | Meaning |
|-|-|-|
| train_file | examples/data/sample_finetune_data.jsonl | A JSON file containing the training data. |
| validation_file | None | A JSON file containing the validation data. |
| validation_split_percentage | 5 | The percentage of the train set to use as the validation set when there is no validation split. |
| preprocessing_num_workers | None | The number of processes to use for preprocessing. |
| max_length | 512 | Pad sequential data to the max length of a batch. |
| group | True | Whether to concatenate sentences for more efficient training. |
| block_size | 512 | The block size of the concatenated sentences. |
| shuffle | False | Whether to shuffle the data at every epoch. |
| max_source_length | 384 | The maximum total input sequence length after tokenization. Sequences longer than this are truncated; shorter sequences are padded. |
| padding_side | right | The side on which padding is applied. One of ['right', 'left']. |
| truncation_side | right | The side on which truncation is applied. One of ['right', 'left']. |
| max_seq_length | max_length | The maximum total input sequence length after tokenization. |
| truncation | True | The truncation strategy. One of ['only_first', 'only_second', 'longest_first'/True, 'do_not_truncate'/False]. |
| padding | True | The padding strategy. One of ['longest'/True, 'do_not_pad'/False, 'max_length']. |
| mask_input | True | Whether to mask the input part in the labels. |
| mask_response | True | Whether to mask the response part in the labels. |
| data_preprocess_type | neural_chat | The type of input encoding to use; 'neural_chat' selects the conversation-style tokenization path. |
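
Putting these together, the Dataset section of a finetuning YAML config might look like the following sketch (values are illustrative examples, not recommendations):

```yaml
# A minimal sketch of a Dataset config section; values are illustrative.
Dataset:
  train_file: examples/data/sample_finetune_data.jsonl
  validation_split_percentage: 5
  max_length: 512
  max_source_length: 384
  padding_side: right
  truncation_side: right
  truncation: true
  padding: true
  mask_input: true
  mask_response: true
  data_preprocess_type: neural_chat
```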


## Training Parameters
220 changes: 220 additions & 0 deletions llm_on_ray/finetune/data_process.py
@@ -0,0 +1,220 @@
#
# Copyright 2023 The LLM-on-Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
import re
from itertools import chain

import torch

IGNORE_INDEX = -100


class DataProcessor:
    # We used the following prompts for fine-tuning the Alpaca model. You can find the reference doc from this URL (https://github.com/tatsu-lab/stanford_alpaca/blob/main/README.md#data-release)
    def __init__(self, config, tokenizer):
        self.tokenizer = tokenizer
        self.end = tokenizer.eos_token
        self.intro = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        self.instruction = "### Instruction:\n"
        self.input = "### Input:\n"
        self.response = "### Response:\n"
        self.padding_side = config["Dataset"].get("padding_side", "right")
        self.truncation_side = config["Dataset"].get("truncation_side", "right")
        self.max_length = self.max_seq_length = config["Dataset"].get("max_length", 512)
        self.max_source_length = config["Dataset"].get("max_source_length", 384)
        self.truncation = config["Dataset"].get("truncation", True)
        self.padding = config["Dataset"].get("padding", True)
        self.mask_input = config["Dataset"].get("mask_input", True)
        self.mask_response = config["Dataset"].get("mask_response", True)

    def make_prompt(self, examples):
        prompts = {}
        prompts["prompt_sources"] = []
        prompts["prompt_targets"] = []
        for rec in examples:
            instruction = rec["instruction"]
            response = rec["response"]
            context = rec.get("context")
            if not instruction:
                raise ValueError(f"Expected an instruction in: {rec}")
            if not response:
                raise ValueError(f"Expected a response in: {rec}")
            if context:
                prompt = (
                    self.intro
                    + self.end
                    + "\n"
                    + self.instruction
                    + instruction
                    + self.input
                    + context
                    + self.end
                    + "\n"
                    + self.response
                )
                prompts["prompt_sources"].append(prompt)
            else:
                prompt = (
                    self.intro
                    + self.end
                    + "\n"
                    + self.instruction
                    + instruction
                    + self.end
                    + "\n"
                    + self.response
                )
                prompts["prompt_sources"].append(prompt)
            prompt_response = response + self.end
            prompts["prompt_targets"].append(prompt_response)
        return prompts
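    # Illustrative example (added, not in the original source): for a record
    # without context and an eos token of "</s>", make_prompt yields a source of
    #   "Below is an instruction ... the request.</s>\n### Instruction:\n{instruction}</s>\n### Response:\n"
    # and a target of "{response}</s>"; loss masking is applied later.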

    def __truncate_sequences(self, sequences, max_length):
        """
        Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L40
        """
        words_to_cut = sum(list(map(len, sequences))) - max_length
        if words_to_cut <= 0:
            return sequences

        while words_to_cut > 0 and len(sequences) > 0:
            words_to_cut -= len(sequences[0])
            sequences = sequences[1:]
        return sequences
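    # Illustrative example (added, not in the original source): for token
    # sequences of lengths [10, 8, 6] and max_length=12, words_to_cut starts at
    # 12, so the first two sequences are dropped and only the final length-6
    # sequence is kept; i.e., the oldest conversation turns are discarded first.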

    def tokenize_by_neural_chat(self, examples):
        """
        Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L225
        The only differences are:
        - using our own prompt style
        - add left or right padding and truncation
        - add mask_input and mask_response
        """
        keys = list(examples.data.keys())
        if len(keys) != 2:
            raise ValueError("Unsupported dataset format")
        assistant_tokens = self.tokenizer.tokenize(self.response)
        header = self.intro + self.end + "\n"

        examples["input_ids"] = []
        examples["labels"] = []
        examples["attention_mask"] = []
        for instruction, response in zip(examples[keys[0]], examples[keys[1]]):
            convs = re.findall(
                r"{0}.*?{2}|{1}.*?{2}".format(self.instruction, self.response, self.end),
                instruction,
                re.DOTALL,
            )
            convs_tokens = [
                self.tokenizer.tokenize(conv) + self.tokenizer.tokenize("\n") for conv in convs
            ]
            header_tokens = self.tokenizer.tokenize(header) + self.tokenizer.tokenize("\n")
            max_input = self.max_source_length - len(header_tokens) - len(assistant_tokens)
            truncated_convs = self.__truncate_sequences(convs_tokens, max_input)
            if len(truncated_convs) == 0:
                truncated_convs = [convs_tokens[-1][: max_input - 3] + convs_tokens[-1][-3:]]

            prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens]
            prompt_ids = [
                self.tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens
            ]
            prompt_ids = list(chain(*prompt_ids))

            resp_ids = self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(response.strip())
            )
            # keep last and eos_id
            max_resp = self.max_seq_length - len(prompt_ids) - 1

            # truncating response
            if len(resp_ids) > max_resp:
                if self.truncation_side == "right":
                    resp_ids = resp_ids[: max_resp - 1] + resp_ids[-1:]
                else:
                    resp_ids = resp_ids[-max_resp:]

            # masking
            input_ids = prompt_ids + resp_ids + [self.tokenizer.eos_token_id]
            if self.mask_input:
                labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [self.tokenizer.eos_token_id]
            elif self.mask_response:
                labels = prompt_ids + [IGNORE_INDEX] * len(resp_ids) + [self.tokenizer.eos_token_id]
            else:
                labels = input_ids
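            # Note (added for clarity): mask_input takes precedence when both
            # flags are True; with the defaults, loss is computed only on the
            # response tokens, mask_input=False with mask_response=True computes
            # loss only on the prompt tokens, and both False leaves labels equal
            # to input_ids.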

            # padding
            input_len = len(input_ids)
            pad_len = self.max_seq_length - input_len
            if self.padding_side == "right":
                input_ids = input_ids + [self.tokenizer.eos_token_id] * pad_len
                labels = labels + [IGNORE_INDEX] * pad_len
                attention_mask = [1] * input_len + [0] * pad_len
            else:
                input_ids = [self.tokenizer.eos_token_id] * pad_len + input_ids
                labels = [IGNORE_INDEX] * pad_len + labels
                attention_mask = [0] * pad_len + [1] * input_len

            assert len(input_ids) == self.max_seq_length
            assert len(prompt_ids) <= self.max_source_length
            assert len(labels) == len(input_ids) == len(attention_mask)

            examples["input_ids"].append(torch.tensor(input_ids))
            examples["labels"].append(labels)
            examples["attention_mask"].append(attention_mask)

        return examples

    def tokenize(self, examples):
        keys = list(examples.data.keys())
        if len(keys) != 2:
            raise ValueError("Unsupported dataset format")

        examples["input_ids"] = []
        examples["labels"] = []
        examples["attention_mask"] = []
        for s, t in zip(examples[keys[0]], examples[keys[1]]):
            results = self.tokenizer(
                s + t,
                padding=self.padding,
                truncation=self.truncation,
                return_tensors=None,
                max_length=self.max_length,
            )

            input_ids = results["input_ids"]
            input_len = len(input_ids)
            labels = copy.deepcopy(input_ids)
            if self.mask_input or self.mask_response:
                sources_tokenized = self.tokenizer(
                    s,
                    padding=False,
                    truncation=True,
                    return_tensors=None,
                    max_length=self.max_length,
                )
                input_id_len = len(sources_tokenized["input_ids"])
                # mask input
                if self.mask_input:
                    labels[:input_id_len] = [IGNORE_INDEX] * input_id_len
                # mask response
                if self.mask_response:
                    labels[input_id_len:input_len] = [IGNORE_INDEX] * (input_len - input_id_len)

            examples["input_ids"].append(results["input_ids"])
            examples["labels"].append(labels)
            examples["attention_mask"].append(results["attention_mask"])
        return examples
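
For reference, here is a minimal sketch of driving DataProcessor end to end, mirroring the tokenize_dataset flow in finetune.py below. The model name and the sample record are illustrative, and any Hugging Face tokenizer with an eos token is assumed:

```python
# A minimal sketch, assuming a Hugging Face tokenizer; model and data are illustrative.
import datasets
import transformers

from llm_on_ray.finetune.data_process import DataProcessor

config = {"Dataset": {"max_length": 512, "max_source_length": 384}}
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

processor = DataProcessor(config, tokenizer)
records = [{"instruction": "Name a primary color.", "response": "Blue.", "context": None}]

# Build the prompt_sources / prompt_targets columns, then tokenize in batches.
ds = datasets.Dataset.from_dict(processor.make_prompt(records))
tokenized = ds.map(
    processor.tokenize,  # or processor.tokenize_by_neural_chat
    batched=True,
    remove_columns=list(ds.features),
)
print(tokenized[0].keys())  # input_ids, labels, attention_mask
```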
62 changes: 23 additions & 39 deletions llm_on_ray/finetune/finetune.py
@@ -18,7 +18,10 @@

import os
import argparse
import re
import sys
import copy

from typing import Any, Dict, Union, Optional

from itertools import chain
@@ -37,9 +40,8 @@
from pydantic_yaml import parse_yaml_raw_as

from llm_on_ray import common
from llm_on_ray.finetune.data_process import DataProcessor
from llm_on_ray.finetune.finetune_config import FinetuneConfig
from importlib import util


def adapt_transformers_to_device(config: Dict):
@@ -140,7 +142,13 @@ def load_tokenizer(config: Dict):
    else:
        tokenizer_name = config["General"]["base_model"]
    load_config = config["General"].get("config", {})
    # default padding side is right
    padding_side = config["Dataset"].get("padding_side", "right")
    # default truncation side is right
    truncation_side = config["Dataset"].get("truncation_side", "right")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        tokenizer_name, padding_side=padding_side, truncation_side=truncation_side, **load_config
    )
    return tokenizer


Expand Down Expand Up @@ -195,50 +203,27 @@ def local_load(name, **load_config):


def tokenize_dataset(config: Dict, tokenizer, dataset):
    max_length = config["Dataset"].get("max_length", 512)
    group = config["Dataset"].get("group", True)
    block_size = config["Dataset"].get("block_size", 512)
    tokenizer.pad_token = tokenizer.eos_token

    processor = DataProcessor(config, tokenizer)

    for key in dataset:
        prompts = processor.make_prompt(dataset[key])
        dataset[key] = datasets.Dataset.from_dict(prompts)

    column_names = list(dataset["train"].features)
    tokenize_fn = (
        processor.tokenize_by_neural_chat
        if config["Dataset"].get("data_preprocess_type", "neural_chat") == "neural_chat"
        else processor.tokenize
    )

    tokenized_dataset = dataset.map(
        tokenize_fn,
        remove_columns=column_names,
        batched=True,
        load_from_cache_file=False,
        desc="Tokenize dataset",
    )
@@ -258,7 +243,6 @@ def group_texts(examples):
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            return result
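        # Illustrative note (added): assuming total_length was floored to a
        # multiple of block_size earlier in group_texts, input_ids of length 10
        # with block_size=4 yield two blocks ([0:4] and [4:8]) per column, and
        # the 2-token remainder is dropped.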

        tokenized_dataset = tokenized_dataset.map(
9 changes: 9 additions & 0 deletions llm_on_ray/finetune/finetune_config.py
@@ -71,6 +71,15 @@ class Dataset(BaseModel):
    group: bool = True
    block_size: int = 512
    shuffle: bool = False
    max_source_length: int = 384
    padding_side: str = "right"
    truncation_side: str = "right"
    max_seq_length: int = 512
    truncation: bool = True
    padding: bool = True
    mask_input: bool = True
    mask_response: bool = True
    data_preprocess_type: str = "neural_chat"


class RayResourceConfig(BaseModel):
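
These defaults line up with the Dataset Parameters table above. As a quick sanity check, the new fields can be validated with parse_yaml_raw_as, which finetune.py already imports from pydantic_yaml. A minimal sketch, with an inline YAML snippet for brevity (the snippet is illustrative, not a shipped config file):

```python
# A minimal sketch, assuming pydantic and pydantic_yaml are installed.
from pydantic import BaseModel
from pydantic_yaml import parse_yaml_raw_as


class Dataset(BaseModel):
    max_source_length: int = 384
    padding_side: str = "right"
    truncation_side: str = "right"
    mask_input: bool = True
    mask_response: bool = True
    data_preprocess_type: str = "neural_chat"


yaml_text = """
max_source_length: 256
padding_side: left
"""

ds = parse_yaml_raw_as(Dataset, yaml_text)
print(ds.max_source_length, ds.padding_side)  # 256 left
```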