Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Updates run_lora_clm.py with enhanced dataset support #955

Merged
merged 9 commits into from
Jun 11, 2024
77 changes: 55 additions & 22 deletions examples/language-modeling/run_lora_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,27 @@ class DataArguments:
save_last_ckpt: bool = field(
default=True, metadata={"help": "Whether to save checkpoint at the end of the training."}
)
instruction_column_name: Optional[str] = field(
default=None,
metadata={
"help": "Name of the column in the dataset that describes the task that the model should perform. By "
"default, the 'instruction' column is used for non-SQL prompts and the 'question' column is used for SQL prompts."
},
)
input_column_name: Optional[str] = field(
default=None,
metadata={
"help": "Name of the column in the dataset that optionally provides context or input for the task. By "
"default, the 'input' column is used for non-SQL prompts and the 'context' column is used for SQL prompts."
},
)
output_column_name: Optional[str] = field(
default=None,
metadata={
"help": "Name of the column in the dataset with the answer to the instruction. By default, the "
"'output' column is used for non-SQL prompts and the 'answer' column is used for SQL prompts."
},
)


@dataclass
Expand Down Expand Up @@ -357,7 +378,7 @@ def create_prompts(examples):
prompts["target"] = []
for example in examples:
prompt_template = (
PROMPT_DICT["prompt_with_input"] if example["input"] != "" else PROMPT_DICT["prompt_without_input"]
PROMPT_DICT["prompt_with_input"] if example.get("input", "") != "" else PROMPT_DICT["prompt_without_input"]
)
source = prompt_template.format_map(example)
prompts["source"].append(source)
Expand Down Expand Up @@ -531,19 +552,7 @@ def main():
**dataset_args,
)

if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
# Preprocessing the datasets.
for key in raw_datasets:
prompts = (
create_prompts(raw_datasets[key])
if not data_args.sql_prompt
else create_sql_prompts(raw_datasets[key])
)
columns_to_be_removed = list(raw_datasets[key].features.keys())
raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)
elif (
if (
data_args.dataset_name == "timdettmers/openassistant-guanaco"
): # from https://github.com/artidoro/qlora/blob/main/qlora.py#L621
raw_datasets = raw_datasets.map(
Expand All @@ -557,7 +566,33 @@ def main():
[col for col in raw_datasets.column_names["train"] if col not in ["input", "output"]]
)
else:
raise ValueError("Unsupported dataset")
# Preprocessing the datasets.
for key in raw_datasets:
if data_args.instruction_column_name:
raw_datasets[key] = raw_datasets[key].rename_column(
data_args.instruction_column_name, "question" if data_args.sql_prompt else "instruction"
)

if data_args.input_column_name:
raw_datasets[key] = raw_datasets[key].rename_column(
data_args.input_column_name, "context" if data_args.sql_prompt else "input"
)

if data_args.output_column_name:
raw_datasets[key] = raw_datasets[key].rename_column(
data_args.output_column_name, "answer" if data_args.sql_prompt else "output"
)

prompts = (
create_prompts(raw_datasets[key])
if not data_args.sql_prompt
else create_sql_prompts(raw_datasets[key])
)
columns_to_be_removed = list(raw_datasets[key].features.keys())
raw_datasets[key] = raw_datasets[key].add_column("prompt_sources", prompts["source"])
raw_datasets[key] = raw_datasets[key].add_column("prompt_targets", prompts["target"])
raw_datasets[key] = raw_datasets[key].remove_columns(columns_to_be_removed)

# Load model
if model_args.model_name_or_path:
model_dtype = torch.bfloat16 if training_args.bf16 else None
Expand Down Expand Up @@ -661,18 +696,16 @@ def concatenate_data(dataset, max_seq_length):
concatenated_dataset[column] = reshaped_data
return datasets.Dataset.from_dict(concatenated_dataset)

if data_args.dataset_name == "tatsu-lab/alpaca" or data_args.sql_prompt:
if data_args.dataset_name == "timdettmers/openassistant-guanaco":
tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
if training_args.do_eval:
tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"])
else:
tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["prompt_sources", "prompt_targets"])
if training_args.do_eval:
tokenized_datasets_eval_ = tokenized_datasets["validation"].remove_columns(
["prompt_sources", "prompt_targets"]
)
elif data_args.dataset_name == "timdettmers/openassistant-guanaco":
tokenized_datasets_ = tokenized_datasets["train"].remove_columns(["input", "output"])
if training_args.do_eval:
tokenized_datasets_eval_ = tokenized_datasets["test"].remove_columns(["input", "output"])
else:
raise ValueError("Unsupported dataset")
tokenized_datasets["train"] = concatenate_data(tokenized_datasets_, data_args.max_seq_length)
if training_args.do_eval:
tokenized_datasets["validation"] = concatenate_data(tokenized_datasets_eval_, data_args.max_seq_length)
Expand Down
36 changes: 36 additions & 0 deletions tests/baselines/llama_7b.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,42 @@
}
},
"gaudi2": {
"databricks/databricks-dolly-15k": {
"num_train_epochs": 1,
"eval_batch_size": 8,
"distribution": {
"single_card": {
"learning_rate": 2e-4,
"train_batch_size": 16,
"perplexity": 3.8436,
"train_runtime": 113.9713,
"train_samples_per_second": 18.428,
"extra_arguments": [
"--bf16",
"--gradient_accumulation_steps 1",
"--evaluation_strategy no",
"--save_strategy no",
"--warmup_ratio 0.03",
"--lr_scheduler_type constant",
"--max_grad_norm 0.3",
"--logging_steps 1",
"--use_hpu_graphs_for_inference",
"--lora_rank 8",
"--lora_alpha 16",
"--lora_dropout 0.1",
"--lora_target_modules q_proj v_proj",
"--dataset_concatenation",
"--low_cpu_mem_usage True",
"--adam_epsilon 1e-08",
"--validation_split_percentage 20",
"--attn_softmax_bf16",
"--max_steps 100",
"--input_column_name context",
"--output_column_name response"
]
}
}
},
"tatsu-lab/alpaca": {
"num_train_epochs": 3,
"eval_batch_size": 4,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,12 @@ class ProteinFoldingExampleTester2(ExampleTesterBase, metaclass=ExampleTestMeta,
pass


class CausalLanguageModelingLORAExampleTester(
    ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm"
):
    """Example test for ``run_lora_clm.py`` on the databricks/databricks-dolly-15k dataset.

    Unlike the ``MultiCard...`` sibling tester, this one does not pass
    ``multi_card=True``, so it presumably runs the single-card configuration —
    confirm against ``ExampleTestMeta``. Per the matching ``llama_7b.json``
    baseline entry, this dataset is driven with ``--input_column_name context``
    and ``--output_column_name response``, i.e. it exercises the new
    column-renaming code path added to ``run_lora_clm.py``.
    """

    # Dataset identifier consumed by the test metaclass to select the
    # baseline configuration (see tests/baselines/llama_7b.json).
    TASK_NAME = "databricks/databricks-dolly-15k"


class MultiCardCausalLanguageModelingLORAExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_lora_clm", multi_card=True
):
Expand Down
Loading