diff --git a/examples/open_llama/open_llama_qlora_tinycodes.json b/examples/open_llama/open_llama_qlora_tinycodes.json index b41c742b7f..db3dcaf199 100644 --- a/examples/open_llama/open_llama_qlora_tinycodes.json +++ b/examples/open_llama/open_llama_qlora_tinycodes.json @@ -19,7 +19,7 @@ } }, "params_config": { - "dataset_name": "nampdn-ai/tiny-codes", + "data_name": "nampdn-ai/tiny-codes", "split": "train", "component_kwargs": { "load_dataset": { diff --git a/examples/open_llama/qlora_user_script.py b/examples/open_llama/qlora_user_script.py index 7cdae58229..0243d026c7 100644 --- a/examples/open_llama/qlora_user_script.py +++ b/examples/open_llama/qlora_user_script.py @@ -11,6 +11,6 @@ # TODO(jambayk): remove custom dataset component once default dataset component supports filter, tokens and split @Registry.register_dataset() -def load_tiny_code_dataset(dataset_name: str, split: str, language: str, token: Union[bool, str] = True): - dataset = load_dataset(dataset_name, split=split, token=token) +def load_tiny_code_dataset(data_name: str, split: str, language: str, token: Union[bool, str] = True): + dataset = load_dataset(data_name, split=split, token=token) return dataset.filter(lambda x: x["programming_language"] == language) diff --git a/olive/data/component/text_generation.py b/olive/data/component/text_generation.py index f5cb1e03c2..e9d102f89f 100644 --- a/olive/data/component/text_generation.py +++ b/olive/data/component/text_generation.py @@ -76,6 +76,7 @@ class TextGenParams(ConfigBase): pad_to_max_len: bool = True # pad sequences to max_len, ignored for JOIN corpus strategy drop_short_sequences: bool = False # drop sequences shorter than max_len. Mutually exclusive with pad_to_max_len add_special_tokens: bool = True # add bos and eos tokens to each sequence + use_attention_mask: bool = True # add attention mask to each example @validator("drop_short_sequences", always=True) def _check_padding(cls, v, values): @@ -142,6 +143,23 @@ def _check_max_samples(cls, v, values): raise ValueError("max_samples must be specified when corpus_strategy is random") return v + @validator("corpus_strategy", always=True) + def _check_use_attention_mask(cls, v, values): + if "use_attention_mask" not in values: + raise ValueError("Invalid use_attention_mask") + if "pad_to_max_len" not in values: + raise ValueError("Invalid pad_to_max_len") + use_attention_mask = values["use_attention_mask"] + pad_to_max_len = values["pad_to_max_len"] + if "join" in v: + # both True and False are valid since attention_mask is all 1s + return v + if not use_attention_mask and pad_to_max_len: + raise ValueError( + "pad_to_max_len is True but use_attention_mask is False. Attention mask is required for padding!" + ) + return v + @validator("random_seed", always=True) def _check_random(cls, v, values): if "corpus_strategy" not in values: @@ -180,6 +198,16 @@ def _check_custom(cls, v, values, field): v = validate_text_source(v, values, field.name) return v + @validator("use_attention_mask", always=True) + def _check_use_attention_mask(cls, v, values): + if "pad_to_max_len" not in values: + raise ValueError("Invalid pad_to_max_len") + if not v and values["pad_to_max_len"]: + raise ValueError( + "pad_to_max_len is True but use_attention_mask is False. Attention mask is required for padding!" + ) + return v + def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs): """Pre-process data for text generation task with 'corpus' dataset type. @@ -207,9 +235,10 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs): tokenized_inputs = { "input_ids": [], - "attention_mask": [], "labels": [], } + if args.use_attention_mask: + tokenized_inputs["attention_mask"] = [] if "join" in args.corpus_strategy: joiner_tokens = tokenizer.encode(args.joiner, add_special_tokens=False) if args.joiner else [] @@ -256,7 +285,13 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs): end_loc = begin_loc + args.source_max_len # get the input sequence input_ids = torch.tensor(joined_input_ids[begin_loc:end_loc]) - append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer, context=context) + append_text_gen_input_ids( + tokenized_inputs, + input_ids, + tokenizer, + context=context, + use_attention_mask=args.use_attention_mask, + ) num_samples += 1 if args.max_samples is not None and num_samples >= args.max_samples: # we have reached max_samples @@ -294,7 +329,9 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs): if len(joined_input_ids) >= args.source_max_len: # found a good example input_ids = torch.tensor(joined_input_ids[: args.source_max_len]) - append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer) + append_text_gen_input_ids( + tokenized_inputs, input_ids, tokenizer, use_attention_mask=args.use_attention_mask + ) break resamples += 1 else: @@ -305,7 +342,9 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs): batched_input_ids = batch_tokenize_text(text_list, tokenizer, args) for native_input_ids in batched_input_ids: input_ids = torch.tensor(native_input_ids) - append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer) + append_text_gen_input_ids( + tokenized_inputs, input_ids, tokenizer, use_attention_mask=args.use_attention_mask + ) else: example_idx = 0 # index of the first example in the current batch num_samples = 0 @@ -321,7 +360,9 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs): ) for native_input_ids in batched_input_ids: input_ids = torch.tensor(native_input_ids) - append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer) + append_text_gen_input_ids( + tokenized_inputs, input_ids, tokenizer, use_attention_mask=args.use_attention_mask + ) # update counters num_samples += len(batched_input_ids) example_idx += examples_to_get @@ -355,7 +396,9 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs): if not encodings: # could not find a good sample after resampling continue - append_text_gen_input_ids(tokenized_inputs, encodings.input_ids[0], tokenizer) + append_text_gen_input_ids( + tokenized_inputs, encodings.input_ids[0], tokenizer, use_attention_mask=args.use_attention_mask + ) # convert to HFDataset hf_dataset = HFDataset.from_dict(tokenized_inputs) @@ -403,9 +446,10 @@ def text_gen_pair_pre_process(dataset, tokenizer, all_kwargs): # build tokenized_inputs tokenized_inputs = { "input_ids": [], - "attention_mask": [], "labels": [], } + if args.use_attention_mask: + tokenized_inputs["attention_mask"] = [] # max_len is the max length of the concatenated input and output # if pad_to_max_len is True, max_len is the max length of the concatenated input and output max_len = args.source_max_len + args.target_max_len @@ -416,7 +460,7 @@ def text_gen_pair_pre_process(dataset, tokenizer, all_kwargs): # skip short sequences if drop_short_sequences is True continue if args.pad_to_max_len: - if not tokenizer.pad_token_id: + if tokenizer.pad_token_id is None: raise ValueError("Tokenizer does not have a pad token") # add padding to max_len input_ids = torch.nn.functional.pad( @@ -424,7 +468,9 @@ def text_gen_pair_pre_process(dataset, tokenizer, all_kwargs): ) # if ignore_source_in_labels is True, the source tokens are treated as context and set to ignore_index in labels context = len(tokenized_source) if args.ignore_source_in_labels else None - append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer, context=context) + append_text_gen_input_ids( + tokenized_inputs, input_ids, tokenizer, context=context, use_attention_mask=args.use_attention_mask + ) # convert to HFDataset hf_dataset = HFDataset.from_dict(tokenized_inputs) @@ -486,15 +532,18 @@ def batch_tokenize_text(text_list, tokenizer, args): return batched_input_ids -def append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer, context: int = None, ignore_index=IGNORE_INDEX): +def append_text_gen_input_ids( + tokenized_inputs, input_ids, tokenizer, context: int = None, ignore_index=IGNORE_INDEX, use_attention_mask=True +): """Convert input_ids to inputs dict and append to tokenized_inputs.""" inputs = {"input_ids": input_ids} # create attention_mask - attention_mask = ( - torch.ones_like(input_ids) if tokenizer.pad_token_id is None else input_ids.ne(tokenizer.pad_token_id) - ) - inputs["attention_mask"] = attention_mask + if use_attention_mask: + attention_mask = ( + torch.ones_like(input_ids) if tokenizer.pad_token_id is None else input_ids.ne(tokenizer.pad_token_id) + ) + inputs["attention_mask"] = attention_mask # create labels # target is not shifted by 1 since causal lm models shifts internally when computing loss @@ -503,7 +552,8 @@ def append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer, context: i if context is not None: labels[:context] = ignore_index # set padding to ignore_index - labels[attention_mask != 1] = ignore_index + if use_attention_mask: + labels[attention_mask != 1] = ignore_index inputs["labels"] = labels # add to list diff --git a/olive/model/hf_utils.py b/olive/model/hf_utils.py index 7991c01b54..8883f86815 100644 --- a/olive/model/hf_utils.py +++ b/olive/model/hf_utils.py @@ -226,7 +226,8 @@ def load_model(self, model_path: str = None): def load_model_config(self, model_path: str = None): """Load model config from model_path or model_name.""" model_name_or_path = model_path or self.model_name - return get_hf_model_config(model_name_or_path) + loading_args = self.model_loading_args.get_loading_args() if self.model_loading_args else {} + return get_hf_model_config(model_name_or_path, **loading_args) def load_huggingface_model_from_task(task: str, name: str, **kwargs): @@ -275,9 +276,9 @@ def huggingface_model_loader(model_loader): return model_loader.from_pretrained -def get_hf_model_config(model_name: str): +def get_hf_model_config(model_name: str, **kwargs): """Get HF Config for the given model name.""" - return AutoConfig.from_pretrained(model_name) + return AutoConfig.from_pretrained(model_name, **kwargs) def load_huggingface_model_from_model_class(model_class: str, name: str, **kwargs): diff --git a/olive/passes/pytorch/qlora.py b/olive/passes/pytorch/qlora.py index 62bb554514..abc0487f98 100644 --- a/olive/passes/pytorch/qlora.py +++ b/olive/passes/pytorch/qlora.py @@ -12,7 +12,7 @@ from copy import deepcopy from functools import partial from pathlib import Path -from typing import Any, Dict, List, Tuple, Union +from typing import Any, ClassVar, Dict, List, Tuple, Union import torch import transformers @@ -122,6 +122,10 @@ class QLoRA(Pass): This pass only supports PyTorchModel with hf_config. """ + # these are the attributes of the model (in hf_config) that will be overwritten by the pass + # values from the input model will be ignored and new values will be set based on the pass config + model_overwrites: ClassVar[tuple] = ("torch_dtype", "device_map", "quantization_method", "quantization_config") + @staticmethod def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]: return { @@ -267,6 +271,11 @@ def _run_for_config( new_model.model = None # remove the device map since we don't want "auto" device map new_model.hf_config.model_loading_args.device_map = None + # remove model_overwrites from model_attributes + if new_model.model_attributes: + for k in QLoRA.model_overwrites: + new_model.model_attributes.pop(k, None) + # set adapter_path new_model.set_resource("adapter_path", adapter_path) @@ -308,24 +317,31 @@ def get_model_tokenizer( compute_dtype = supported_dtypes[config.compute_dtype] # load model, reset model_loading_args and adapter_path + model_loading_args = {} if new_model.hf_config.model_loading_args: - logger.warning( - "Input model has model_loading_args. Ignoring. QLoRA will use its own model_loading_args based on the" - " pass config." - ) - new_model.hf_config.model_loading_args = HFModelLoadingArgs( - torch_dtype=compute_dtype, - # TODO(jambayk): Worry about `use_multi_gpu` and distributed training later - # this uses all available GPUs, model parallel - device_map="auto", - quantization_method="bitsandbytes", - quantization_config={ - "load_in_4bit": True, - "bnb_4bit_compute_dtype": compute_dtype, - "bnb_4bit_use_double_quant": config.double_quant, - "bnb_4bit_quant_type": config.quant_type, - }, + model_loading_args = new_model.hf_config.model_loading_args.dict() + for k in QLoRA.model_overwrites: + if model_loading_args.get(k) is not None: + logger.warning( + f"Input model has model_loading_args.{k}. Ignoring. QLoRA will overwrite it based on the pass" + " config." + ) + model_loading_args.update( + { + "torch_dtype": compute_dtype, + # TODO(jambayk): Worry about `use_multi_gpu` and distributed training later + # this uses all available GPUs, model parallel + "device_map": "auto", + "quantization_method": "bitsandbytes", + "quantization_config": { + "load_in_4bit": True, + "bnb_4bit_compute_dtype": compute_dtype, + "bnb_4bit_use_double_quant": config.double_quant, + "bnb_4bit_quant_type": config.quant_type, + }, + } ) + new_model.hf_config.model_loading_args = HFModelLoadingArgs(**model_loading_args) if new_model.get_resource("adapter_path"): logger.warning( "Input model has adapter_path. Ignoring. QLoRA will save the adapter weights to its own adapter_path." @@ -347,7 +363,7 @@ def get_model_tokenizer( # TODO(jambayk): Do this in a better way since the embedding size might become unoptimal # (not a multiple of 64, etc) perhaps use eos_token as pad_token, but need to ensure the actual eos_token # at the end of the sequence is not masked (both in attention mask and loss calculation) - if not tokenizer.pad_token_id: + if tokenizer.pad_token_id is None: cls.smart_tokenizer_and_embedding_resize( special_tokens_dict={"pad_token": DEFAULT_PAD_TOKEN}, tokenizer=tokenizer, model=pytorch_model ) diff --git a/test/unit_test/workflows/test_run_config.py b/test/unit_test/workflows/test_run_config.py index 4be6d98e04..027bbd2533 100644 --- a/test/unit_test/workflows/test_run_config.py +++ b/test/unit_test/workflows/test_run_config.py @@ -152,7 +152,7 @@ def setup(self): "params_config": { "model_name": "dummy_model2", "task": "dummy_task2", - "dataset_name": "dummy_dataset2", + "data_name": "dummy_dataset2", }, } }, @@ -174,7 +174,7 @@ def test_auto_insert_model_name_and_task(self, model_name, task, expected_model_ config_dict["data_configs"]["dummy_data_config2"]["params_config"] = { "model_name": model_name, "task": task, - "dataset_name": "dummy_dataset2", + "data_name": "dummy_dataset2", } run_config = RunConfig.parse_obj(config_dict)