Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/open_llama/open_llama_qlora_tinycodes.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
}
},
"params_config": {
"dataset_name": "nampdn-ai/tiny-codes",
"data_name": "nampdn-ai/tiny-codes",
"split": "train",
"component_kwargs": {
"load_dataset": {
Expand Down
4 changes: 2 additions & 2 deletions examples/open_llama/qlora_user_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@

# TODO(jambayk): remove custom dataset component once default dataset component supports filter, tokens and split
@Registry.register_dataset()
def load_tiny_code_dataset(data_name: str, split: str, language: str, token: Union[bool, str] = True):
    """Load a code dataset split and keep only rows for one programming language.

    :param data_name: dataset identifier passed through to `load_dataset`
        (e.g. "nampdn-ai/tiny-codes").
    :param split: dataset split to load (e.g. "train").
    :param language: value matched against each row's "programming_language" column.
    :param token: auth token forwarded to `load_dataset`; True uses the locally
        stored credentials.
    :return: the filtered dataset containing only rows whose
        "programming_language" equals `language`.
    """
    # NOTE(review): this span of the diff contained both the old (dataset_name)
    # and new (data_name) signatures; the merged, post-change version is kept.
    dataset = load_dataset(data_name, split=split, token=token)
    return dataset.filter(lambda x: x["programming_language"] == language)
80 changes: 65 additions & 15 deletions olive/data/component/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class TextGenParams(ConfigBase):
pad_to_max_len: bool = True # pad sequences to max_len, ignored for JOIN corpus strategy
drop_short_sequences: bool = False # drop sequences shorter than max_len. Mutually exclusive with pad_to_max_len
add_special_tokens: bool = True # add bos and eos tokens to each sequence
use_attention_mask: bool = True # add attention mask to each example

@validator("drop_short_sequences", always=True)
def _check_padding(cls, v, values):
Expand Down Expand Up @@ -142,6 +143,23 @@ def _check_max_samples(cls, v, values):
raise ValueError("max_samples must be specified when corpus_strategy is random")
return v

@validator("corpus_strategy", always=True)
def _check_use_attention_mask(cls, v, values):
    """Reject configs that pad to max length without an attention mask.

    "join"-style corpus strategies are exempt: their attention mask is all
    ones, so either setting of use_attention_mask is acceptable.
    """
    # earlier field validation must have produced these values; order of the
    # checks is significant for which error message is raised first
    for required in ("use_attention_mask", "pad_to_max_len"):
        if required not in values:
            raise ValueError(f"Invalid {required}")
    if "join" not in v:
        # padding introduces pad tokens that must be masked out
        if values["pad_to_max_len"] and not values["use_attention_mask"]:
            raise ValueError(
                "pad_to_max_len is True but use_attention_mask is False. Attention mask is required for padding!"
            )
    return v

@validator("random_seed", always=True)
def _check_random(cls, v, values):
if "corpus_strategy" not in values:
Expand Down Expand Up @@ -180,6 +198,16 @@ def _check_custom(cls, v, values, field):
v = validate_text_source(v, values, field.name)
return v

@validator("use_attention_mask", always=True)
def _check_use_attention_mask(cls, v, values):
    """Forbid disabling the attention mask while padding is requested."""
    if "pad_to_max_len" not in values:
        # pad_to_max_len failed its own validation earlier
        raise ValueError("Invalid pad_to_max_len")
    pad_requested = values["pad_to_max_len"]
    if pad_requested and not v:
        # padded positions can only be ignored via the attention mask
        raise ValueError(
            "pad_to_max_len is True but use_attention_mask is False. Attention mask is required for padding!"
        )
    return v


def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs):
"""Pre-process data for text generation task with 'corpus' dataset type.
Expand Down Expand Up @@ -207,9 +235,10 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs):

tokenized_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
}
if args.use_attention_mask:
tokenized_inputs["attention_mask"] = []
if "join" in args.corpus_strategy:
joiner_tokens = tokenizer.encode(args.joiner, add_special_tokens=False) if args.joiner else []

Expand Down Expand Up @@ -256,7 +285,13 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs):
end_loc = begin_loc + args.source_max_len
# get the input sequence
input_ids = torch.tensor(joined_input_ids[begin_loc:end_loc])
append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer, context=context)
append_text_gen_input_ids(
tokenized_inputs,
input_ids,
tokenizer,
context=context,
use_attention_mask=args.use_attention_mask,
)
num_samples += 1
if args.max_samples is not None and num_samples >= args.max_samples:
# we have reached max_samples
Expand Down Expand Up @@ -294,7 +329,9 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs):
if len(joined_input_ids) >= args.source_max_len:
# found a good example
input_ids = torch.tensor(joined_input_ids[: args.source_max_len])
append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer)
append_text_gen_input_ids(
tokenized_inputs, input_ids, tokenizer, use_attention_mask=args.use_attention_mask
)
break
resamples += 1
else:
Expand All @@ -305,7 +342,9 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs):
batched_input_ids = batch_tokenize_text(text_list, tokenizer, args)
for native_input_ids in batched_input_ids:
input_ids = torch.tensor(native_input_ids)
append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer)
append_text_gen_input_ids(
tokenized_inputs, input_ids, tokenizer, use_attention_mask=args.use_attention_mask
)
else:
example_idx = 0 # index of the first example in the current batch
num_samples = 0
Expand All @@ -321,7 +360,9 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs):
)
for native_input_ids in batched_input_ids:
input_ids = torch.tensor(native_input_ids)
append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer)
append_text_gen_input_ids(
tokenized_inputs, input_ids, tokenizer, use_attention_mask=args.use_attention_mask
)
# update counters
num_samples += len(batched_input_ids)
example_idx += examples_to_get
Expand Down Expand Up @@ -355,7 +396,9 @@ def text_gen_corpus_pre_process(dataset, tokenizer, all_kwargs):
if not encodings:
# could not find a good sample after resampling
continue
append_text_gen_input_ids(tokenized_inputs, encodings.input_ids[0], tokenizer)
append_text_gen_input_ids(
tokenized_inputs, encodings.input_ids[0], tokenizer, use_attention_mask=args.use_attention_mask
)

# convert to HFDataset
hf_dataset = HFDataset.from_dict(tokenized_inputs)
Expand Down Expand Up @@ -403,9 +446,10 @@ def text_gen_pair_pre_process(dataset, tokenizer, all_kwargs):
# build tokenized_inputs
tokenized_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
}
if args.use_attention_mask:
tokenized_inputs["attention_mask"] = []
# max_len is the max length of the concatenated input and output
# if pad_to_max_len is True, max_len is the max length of the concatenated input and output
max_len = args.source_max_len + args.target_max_len
Expand All @@ -416,15 +460,17 @@ def text_gen_pair_pre_process(dataset, tokenizer, all_kwargs):
# skip short sequences if drop_short_sequences is True
continue
if args.pad_to_max_len:
if not tokenizer.pad_token_id:
if tokenizer.pad_token_id is None:
raise ValueError("Tokenizer does not have a pad token")
# add padding to max_len
input_ids = torch.nn.functional.pad(
input_ids, (0, max_len - input_ids.shape[0]), value=tokenizer.pad_token_id
)
# if ignore_source_in_labels is True, the source tokens are treated as context and set to ignore_index in labels
context = len(tokenized_source) if args.ignore_source_in_labels else None
append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer, context=context)
append_text_gen_input_ids(
tokenized_inputs, input_ids, tokenizer, context=context, use_attention_mask=args.use_attention_mask
)

# convert to HFDataset
hf_dataset = HFDataset.from_dict(tokenized_inputs)
Expand Down Expand Up @@ -486,15 +532,18 @@ def batch_tokenize_text(text_list, tokenizer, args):
return batched_input_ids


def append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer, context: int = None, ignore_index=IGNORE_INDEX):
def append_text_gen_input_ids(
tokenized_inputs, input_ids, tokenizer, context: int = None, ignore_index=IGNORE_INDEX, use_attention_mask=True
):
"""Convert input_ids to inputs dict and append to tokenized_inputs."""
inputs = {"input_ids": input_ids}

# create attention_mask
attention_mask = (
torch.ones_like(input_ids) if tokenizer.pad_token_id is None else input_ids.ne(tokenizer.pad_token_id)
)
inputs["attention_mask"] = attention_mask
if use_attention_mask:
attention_mask = (
torch.ones_like(input_ids) if tokenizer.pad_token_id is None else input_ids.ne(tokenizer.pad_token_id)
)
inputs["attention_mask"] = attention_mask

# create labels
# target is not shifted by 1 since causal lm models shifts internally when computing loss
Expand All @@ -503,7 +552,8 @@ def append_text_gen_input_ids(tokenized_inputs, input_ids, tokenizer, context: i
if context is not None:
labels[:context] = ignore_index
# set padding to ignore_index
labels[attention_mask != 1] = ignore_index
if use_attention_mask:
labels[attention_mask != 1] = ignore_index
inputs["labels"] = labels

# add to list
Expand Down
7 changes: 4 additions & 3 deletions olive/model/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def load_model(self, model_path: str = None):
def load_model_config(self, model_path: str = None):
"""Load model config from model_path or model_name."""
model_name_or_path = model_path or self.model_name
return get_hf_model_config(model_name_or_path)
loading_args = self.model_loading_args.get_loading_args() if self.model_loading_args else {}
return get_hf_model_config(model_name_or_path, **loading_args)


def load_huggingface_model_from_task(task: str, name: str, **kwargs):
Expand Down Expand Up @@ -275,9 +276,9 @@ def huggingface_model_loader(model_loader):
return model_loader.from_pretrained


def get_hf_model_config(model_name: str, **kwargs):
    """Get HF Config for the given model name.

    :param model_name: model name or local path accepted by
        `AutoConfig.from_pretrained`.
    :param kwargs: extra loading arguments forwarded verbatim to
        `AutoConfig.from_pretrained`.
    :return: the loaded `AutoConfig` instance.
    """
    # NOTE(review): this span of the diff contained both the old (no-kwargs)
    # and new signatures; the merged, post-change version is kept.
    return AutoConfig.from_pretrained(model_name, **kwargs)


def load_huggingface_model_from_model_class(model_class: str, name: str, **kwargs):
Expand Down
52 changes: 34 additions & 18 deletions olive/passes/pytorch/qlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from copy import deepcopy
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
from typing import Any, ClassVar, Dict, List, Tuple, Union

import torch
import transformers
Expand Down Expand Up @@ -122,6 +122,10 @@ class QLoRA(Pass):
This pass only supports PyTorchModel with hf_config.
"""

# these are the attributes of the model (in hf_config) that will be overwritten by the pass
# values from the input model will be ignored and new values will be set based on the pass config
model_overwrites: ClassVar[tuple] = ("torch_dtype", "device_map", "quantization_method", "quantization_config")
Comment thread
trajepl marked this conversation as resolved.

@staticmethod
def _default_config(accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
return {
Expand Down Expand Up @@ -267,6 +271,11 @@ def _run_for_config(
new_model.model = None
# remove the device map since we don't want "auto" device map
new_model.hf_config.model_loading_args.device_map = None
# remove model_overwrites from model_attributes
if new_model.model_attributes:
for k in QLoRA.model_overwrites:
new_model.model_attributes.pop(k, None)

# set adapter_path
new_model.set_resource("adapter_path", adapter_path)

Expand Down Expand Up @@ -308,24 +317,31 @@ def get_model_tokenizer(
compute_dtype = supported_dtypes[config.compute_dtype]

# load model, reset model_loading_args and adapter_path
model_loading_args = {}
if new_model.hf_config.model_loading_args:
logger.warning(
"Input model has model_loading_args. Ignoring. QLoRA will use its own model_loading_args based on the"
" pass config."
)
new_model.hf_config.model_loading_args = HFModelLoadingArgs(
torch_dtype=compute_dtype,
# TODO(jambayk): Worry about `use_multi_gpu` and distributed training later
# this uses all available GPUs, model parallel
device_map="auto",
quantization_method="bitsandbytes",
quantization_config={
"load_in_4bit": True,
"bnb_4bit_compute_dtype": compute_dtype,
"bnb_4bit_use_double_quant": config.double_quant,
"bnb_4bit_quant_type": config.quant_type,
},
model_loading_args = new_model.hf_config.model_loading_args.dict()
for k in QLoRA.model_overwrites:
if model_loading_args.get(k) is not None:
logger.warning(
f"Input model has model_loading_args.{k}. Ignoring. QLoRA will overwrite it based on the pass"
" config."
)
model_loading_args.update(
{
"torch_dtype": compute_dtype,
# TODO(jambayk): Worry about `use_multi_gpu` and distributed training later
# this uses all available GPUs, model parallel
"device_map": "auto",
"quantization_method": "bitsandbytes",
"quantization_config": {
"load_in_4bit": True,
"bnb_4bit_compute_dtype": compute_dtype,
"bnb_4bit_use_double_quant": config.double_quant,
"bnb_4bit_quant_type": config.quant_type,
},
}
)
new_model.hf_config.model_loading_args = HFModelLoadingArgs(**model_loading_args)
if new_model.get_resource("adapter_path"):
logger.warning(
"Input model has adapter_path. Ignoring. QLoRA will save the adapter weights to its own adapter_path."
Expand All @@ -347,7 +363,7 @@ def get_model_tokenizer(
# TODO(jambayk): Do this in a better way since the embedding size might become unoptimal
# (not a multiple of 64, etc) perhaps use eos_token as pad_token, but need to ensure the actual eos_token
# at the end of the sequence is not masked (both in attention mask and loss calculation)
if not tokenizer.pad_token_id:
if tokenizer.pad_token_id is None:
cls.smart_tokenizer_and_embedding_resize(
special_tokens_dict={"pad_token": DEFAULT_PAD_TOKEN}, tokenizer=tokenizer, model=pytorch_model
)
Expand Down
4 changes: 2 additions & 2 deletions test/unit_test/workflows/test_run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def setup(self):
"params_config": {
"model_name": "dummy_model2",
"task": "dummy_task2",
"dataset_name": "dummy_dataset2",
"data_name": "dummy_dataset2",
},
}
},
Expand All @@ -174,7 +174,7 @@ def test_auto_insert_model_name_and_task(self, model_name, task, expected_model_
config_dict["data_configs"]["dummy_data_config2"]["params_config"] = {
"model_name": model_name,
"task": task,
"dataset_name": "dummy_dataset2",
"data_name": "dummy_dataset2",
}

run_config = RunConfig.parse_obj(config_dict)
Expand Down