Flash Attention Implementation & Fuller Config Options #139

Merged · 9 commits · Apr 9, 2024
39 changes: 32 additions & 7 deletions README.md
@@ -13,18 +13,21 @@ LLM Finetuning toolkit is a config-based CLI tool for launching a series of LLM
</p>

## Installation

### pipx (recommended)

pipx installs the package and its dependencies in a separate virtual environment.

```shell
pipx install llm-toolkit
```

### pip

```shell
pip install llm-toolkit
```


## Quick Start

This guide contains 3 stages that will enable you to get the most out of this toolkit!
@@ -45,6 +48,30 @@ This command initiates the fine-tuning process using the settings specified in t

The configuration file is the central piece that defines the behavior of the toolkit. It is written in YAML format and consists of several sections that control different aspects of the process, such as data ingestion, model definition, training, inference, and quality assurance. We highlight some of the critical sections.
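
For orientation, the sketch below shows roughly how a config file might be laid out. It is an illustrative sketch only: aside from the `model` section (shown again under Flash Attention 2 below), the top-level section names are assumed to mirror the toolkit's pydantic config models and may not match the exact schema.

```yaml
# Illustrative sketch only: top-level section names other than `model` are assumed, not verbatim from the schema.
data:                      # data ingestion settings
  # ...
model:                     # model definition
  device_map: "auto"
  torch_dtype: "bfloat16"
  attn_implementation: "flash_attention_2"
training:                  # training settings
  training_args:
    num_train_epochs: 1
    save_steps: 500
  sft_args:
    max_seq_length: 1024
inference:                 # generation settings used during evaluation
  max_new_tokens: 256
  do_sample: true
  top_p: 0.9
  temperature: 0.7
ablation:                  # quality-assurance / ablation settings
  # ...
```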

#### Flash Attention 2

To enable Flash Attention 2 for [supported models](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2), first install `flash-attn`:

**pipx**

```shell
pipx inject llm-toolkit flash-attn --pip-args=--no-build-isolation
```

**pip**

```shell
pip install flash-attn --no-build-isolation
```
Comment on lines +61 to +65

**Contributor:** is it possible to install it with llm-toolkit, so users do not have to do this?

**Contributor Author:** Not possible as the package doesn't support PEP 517. See python-poetry/poetry#8427, Dao-AILab/flash-attention#493 (comment).

**Contributor Author:** ^^ Also that's the first thing I tried 🙈


Then, add the following to the config file:

```yaml
model:
torch_dtype: "bfloat16" # or "float16" if using older GPU
attn_implementation: "flash_attention_2"
```

#### Data Ingestion

An example of what the data ingestion may look like:
@@ -247,6 +274,7 @@ NOTE: Be sure to merge the latest from "upstream" before making a pull request!
# GPU
docker run -it --gpus all llm-toolkit
```

</details>

<details>
@@ -257,6 +285,7 @@ See the Poetry documentation page for [installation instructions](https://pyt
```shell
poetry install
```

</details>
<details>
<summary>pip</summary>
@@ -265,27 +294,23 @@ We recommend using a virtual environment like `venv` or `conda` for installation
```shell
pip install -e .
```

</details>
</details>



### Checklist Before Pull Request (Optional)

1. Use `ruff check --fix` to check and fix lint errors
2. Use `ruff format` to apply formatting

NOTE: Ruff linting and formatting checks run via GitHub Actions when a PR is raised. Before raising a PR, it is good practice to check and fix lint errors and apply formatting.


### Releasing


To manually release a PyPI package, please run:

```shell
make build-release
```

Note: Make sure you have a PyPI token for this [PyPI repo](https://pypi.org/project/llm-toolkit/).

2 changes: 1 addition & 1 deletion llmtune/cli/toolkit.py
@@ -68,7 +68,7 @@ def run_one_experiment(config: Config, config_path: str) -> None:
results_file_path = join(dir_helper.save_paths.results, "results.csv")
if not exists(results_path) or exists(results_file_path):
inference_runner = LoRAInference(test, test_column, config, dir_helper)
inference_runner.infer_all()
inference_runner.infer_test_set()
RichUI.after_inference(results_path)
else:
RichUI.inference_found(results_path)
2 changes: 2 additions & 0 deletions llmtune/finetune/lora.py
@@ -74,6 +74,8 @@ def _get_model(self):
),
use_cache=False,
device_map=self.device_map,
torch_dtype=self._model_config.torch_dtype,
attn_implementation=self._model_config.attn_implementation,
)

model.config.pretraining_tp = 1
2 changes: 1 addition & 1 deletion llmtune/inference/generics.py
@@ -7,5 +7,5 @@ def infer_one(self, prompt: str):
pass

@abstractmethod
def infer_all(self):
def infer_test_set(self):
pass
2 changes: 1 addition & 1 deletion llmtune/inference/lora.py
@@ -64,7 +64,7 @@ def _get_merged_model(self, weights_path: str):

return model, tokenizer

def infer_all(self):
def infer_test_set(self):
results = []
prompts = self.test_dataset["formatted_prompt"]
labels = self.test_dataset[self.label_column]
80 changes: 69 additions & 11 deletions llmtune/pydantic_models/config_model.py
@@ -77,7 +77,13 @@ class ModelConfig(BaseModel):
description="Path to the model (huggingface repo or local path)",
)
device_map: Optional[str] = Field("auto", description="device onto which to load the model")
torch_dtype: Optional[str] = Field("auto", description="torch dtype to use for model weights")
attn_implementation: Optional[str] = Field(
None,
description="set desired attention implementation; leave None for default. E.g. `flash_attention_2` (please ensure `torch_dtype` is either float16 or bfloat16).",
)

# Quantization Config
quantize: Optional[bool] = Field(False, description="Flag to enable quantization")
bitsandbytes: BitsAndBytesConfig = Field(None, description="Bits and Bytes configuration")

@@ -99,6 +105,16 @@ def set_device_map_to_none(cls, v, values, **kwargs):
return None
return v

@validator("torch_dtype", pre=True, allow_reuse=True)
def convert_str_to_torch_dtype(cls, v):
try:
# Attempt to retrieve the corresponding PyTorch data type
torch_dtype = getattr(torch, v)
except AttributeError:
# Handle the case where the string does not match any PyTorch data type
raise ValueError(f"{v} is not a valid torch data type")
return torch_dtype


class LoraConfig(BaseModel):
r: Optional[int] = Field(8, description="Lora rank")
@@ -126,7 +142,6 @@ class LoraConfig(BaseModel):
# )


# TODO: Get comprehensive Args!
class TrainingArgs(BaseModel):
num_train_epochs: Optional[int] = Field(1, description="Number of training epochs")
per_device_train_batch_size: Optional[int] = Field(1, description="Batch size per training device")
@@ -141,9 +156,12 @@
max_grad_norm: Optional[float] = Field(0.3, description="Maximum gradient norm")
warmup_ratio: Optional[float] = Field(0.03, description="Warmup ratio")
lr_scheduler_type: Optional[str] = Field("constant", description="Learning rate scheduler type")
save_steps: Optional[Union[int, float]] = Field(
500,
description="Number of updates steps before checkpoint saves. Should be an integer or a float in range [0,1). If smaller than 1, will be interpreted as ratio of total training steps.",
)
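# e.g. save_steps=500 saves a checkpoint every 500 steps; save_steps=0.1 saves one every 10% of total training steps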


# TODO: Get comprehensive Args!
class SftArgs(BaseModel):
max_seq_length: Optional[int] = Field(None, description="Maximum sequence length")
neftune_noise_alpha: Optional[float] = Field(
@@ -157,16 +175,56 @@ class TrainingConfig(BaseModel):
sft_args: SftArgs


# TODO: Get comprehensive Args!
class InferenceConfig(BaseModel):
max_new_tokens: Optional[int] = Field(None, description="Maximum new tokens")
use_cache: Optional[bool] = Field(True, description="Flag to enable cache usage")
do_sample: Optional[bool] = Field(True, description="Flag to enable sampling")
top_p: Optional[float] = Field(1.0, description="Top p value")
temperature: Optional[float] = Field(0.1, description="Temperature value")
epsilon_cutoff: Optional[float] = Field(0.0, description="epsilon cutoff value")
eta_cutoff: Optional[float] = Field(0.0, description="eta cutoff value")
top_k: Optional[int] = Field(50, description="top-k sampling")
# Length
max_length: Optional[int] = Field(None, description="The maximum length the generated tokens can have.")
max_new_tokens: Optional[int] = Field(None, description="The maximum numbers of tokens to generate.")
min_length: Optional[int] = Field(0, description="The minimum length of the sequence to be generated.")
min_new_tokens: Optional[int] = Field(None, description="The minimum numbers of tokens to generate.")
early_stopping: Optional[Union[bool, str]] = Field(
False, description="Controls the stopping condition for beam search."
)
max_time: Optional[float] = Field(None, description="The maximum amount of time for the computation in seconds.")

# Generation Strategy
do_sample: Optional[bool] = Field(False, description="Whether or not to use sampling.")
num_beams: Optional[int] = Field(1, description="Number of beams for beam search.")
num_beam_groups: Optional[int] = Field(1, description="Number of groups for diversity among beams.")
penalty_alpha: Optional[float] = Field(None, description="Balances model confidence and degeneration penalty.")
use_cache: Optional[bool] = Field(
True,
description="Whether to use past key/values attentions to speed up decoding.",
)

# Manipulation of Model Output Logits
temperature: Optional[float] = Field(1.0, description="Modulates the next token probabilities.")
top_k: Optional[int] = Field(
50,
description="Number of highest probability tokens to keep for top-k-filtering.",
)
top_p: Optional[float] = Field(
1.0,
description="Keeps the smallest set of most probable tokens summing up to top_p.",
)
typical_p: Optional[float] = Field(1.0, description="Local typicality measure.")
epsilon_cutoff: Optional[float] = Field(0.0, description="Minimum conditional probability for token sampling.")
eta_cutoff: Optional[float] = Field(0.0, description="Hybrid of locally typical sampling and epsilon sampling.")
diversity_penalty: Optional[float] = Field(
0.0, description="Penalty for token repetition across different beam groups."
)
repetition_penalty: Optional[float] = Field(1.0, description="Penalty for token repetition.")
encoder_repetition_penalty: Optional[float] = Field(
1.0, description="Penalty on sequences not in the original input."
)
length_penalty: Optional[float] = Field(1.0, description="Exponential penalty to the length for beam search.")
no_repeat_ngram_size: Optional[int] = Field(0, description="Size of ngrams that cannot occur more than once.")
bad_words_ids: Optional[List[List[int]]] = Field(None, description="Tokens that are not allowed to be generated.")
force_words_ids: Optional[List[Union[List[int], List[List[int]]]]] = Field(
None, description="Tokens that must be generated."
)
renormalize_logits: Optional[bool] = Field(
False, description="Whether to renormalize logits after all processors."
)


class AblationConfig(BaseModel):