diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5d71c46..22450a9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -7,12 +7,14 @@ on: paths: - "src/**" - "test/**" + - "examples/**" pull_request: branches: - main paths: - "src/**" - "test/**" + - "examples/**" schedule: # Runs at 00:00 UTC daily - cron: '0 0 * * *' diff --git a/examples/pruning/README.md b/examples/pruning/README.md new file mode 100644 index 0000000..9f725c3 --- /dev/null +++ b/examples/pruning/README.md @@ -0,0 +1,86 @@ +# Pruning LLMs with FMCHISEL + +This example shows how to **prune a Large Language Model (LLM) into a sparse, memory- and compute-efficient variant** using FMCHISEL. It demonstrates data loading, model preparation, unstructured & N:M pruning, and exporting the pruned model back to the Hugging Face Hub format. + +## 1. Background + +* **Unstructured sparsity** – Any weight can be zero. Offers maximum flexibility but requires special kernels for runtime speed-ups. +* **Semi-structured (N:M) sparsity** – Exactly **N** non-zeros in every **M** consecutive weights (e.g. 2:4 = 50 % sparsity). Readily accelerated by NVIDIA A100/H100 tensor cores. + +FMCHISEL implements **ALPS** (ADMM-based Layerwise Pruning with Saliency) and wraps two post-training methods from *llmcompressor*: **SparseGPT** and **Wanda**. See the [ALPS paper](https://arxiv.org/abs/2406.07831) for technical details. + +--- + +## 2. Getting Started + +```bash +# 1. (Optional) login to HF if models / datasets are gated +huggingface-cli login + +# 2a. Prune via direct CLI arguments +bash run.sh + +# 2b. Prune via a YAML recipe +bash run_recipe.sh +``` + +The run.sh script will: + +1. Download the **base model** (default: `Qwen/Qwen3-0.6B`). +2. Load **C4-en** calibration samples (default: `1024`). +3. Apply **ALPS 2:4** pruning on all linear MLP layers while keeping attention layers dense. +4. Save the pruned model to `out/pruning-/` (HF-compatible). +5. Optionally store compressed tensor formats if `--save_compressed True` is used. + +### Output Artifacts + +``` +out/pruning-/ +└── [HF model files] # weights + config.json + tokenizer +``` + +If `--save_compressed True` is enabled, an additional **compressed** directory containing CSR tensors & metadata will be created alongside the standard HF files. + +## 3. Customizations + +All flags from `PruningConfig` and `CalibrationDataConfig` can be overridden on the CLI or via a **YAML recipe**. + +| Flag | Description | Example | +|------|-------------|---------| +| `--model` | Path / HF-hub id of the model to prune. | `meta-llama/Llama-3.1-8B` | +| `--output_dir` | Where to write the pruned model. | `output_model/pruning/my_llama` | +| `--dataset` | HF dataset used for calibration. | `allenai/c4` | +| `--data_field` | Text column inside the dataset. | `text` | +| `--num_calibration_samples` | #Samples used for calibration. | `2048` | +| `--pruning_strategy` | `ALPS`, `SparseGPT`, or `wanda`. | `ALPS` | +| `--sparsity` | Global sparsity ratio for unstructured pruning. | `0.5` | +| `--prunen / --prunem` | N and M for N:M sparsity. Use `0 0` to disable. | `2 4` | +| `--pruning_yaml_recipe` | Path to a YAML pruning recipe (overrides other pruning flags). | `examples/pruning/alps_24_ignore_attn.yaml` | +| `--save_compressed` | Store compressed tensors (CSR & metadata). | `True` | +| `--model_max_length` | Max sequence length during calibration. | `4096` | + +### YAML Recipes + +Complex sparsity patterns are easiest to express via a YAML file. 
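+The recipe below is the one shipped with this example as `alps_24_ignore_attn.yaml`: it applies a 2:4 ALPS pattern to every `Linear` layer, while the `ignore` regexes keep the attention projections and the `lm_head` dense.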
+ +```yaml +# alps_24_ignore_attn.yaml +sparsity_stage: + sparsity_modifiers: + ALPSModifier: + sparsity: 0.5 # 50 % overall + mask_structure: "2:4" # N:M pattern + targets: ["Linear"] # prune all Linear layers + ignore: [ # keep attention dense + "re:.*q_proj", "re:.*k_proj", "re:.*v_proj", "re:.*o_proj", "re:.*lm_head" + ] +``` + +Pass it via `--pruning_yaml_recipe path/to/file.yaml` (see `run_recipe.sh`). + + +## 4. Tips & Tricks + +1. **Speed** – Use N:M sparsity (e.g. 2:4) on Ampere/Hopper GPUs for actual inference acceleration. +2. **Layer dropping** – Excluding attention heads (via `ignore` regexes) often preserves accuracy. +3. **Model size** – Enable `--save_compressed True` to store the model in a compact CSR format. diff --git a/examples/pruning/alps_24_ignore_attn.yaml b/examples/pruning/alps_24_ignore_attn.yaml new file mode 100644 index 0000000..ca1b92e --- /dev/null +++ b/examples/pruning/alps_24_ignore_attn.yaml @@ -0,0 +1,7 @@ +sparsity_stage: + sparsity_modifiers: + ALPSModifier: + sparsity: 0.5 + mask_structure: "2:4" + targets: ["Linear"] + ignore: ["re:.*lm_head", "re:.*q_proj", "re:.*k_proj", "re:.*v_proj", "re:.*o_proj"] \ No newline at end of file diff --git a/examples/pruning/main_pruning.py b/examples/pruning/main_pruning.py new file mode 100644 index 0000000..36647fb --- /dev/null +++ b/examples/pruning/main_pruning.py @@ -0,0 +1,18 @@ +import logging + +from pruning_utils import prune +from transformers import HfArgumentParser + +from fmchisel.config import CalibrationDataConfig, PruningConfig + +logger = logging.getLogger(__name__) + + +if __name__ == "__main__": + + parser = HfArgumentParser((PruningConfig, CalibrationDataConfig)) + (pruning_config, data_config) = parser.parse_args_into_dataclasses() + logger.info(f"pruning_config = {pruning_config}") + logger.info(f"data_config = {data_config}") + + prune(pruning_config, data_config) diff --git a/examples/pruning/pruning_utils.py b/examples/pruning/pruning_utils.py new file mode 100644 index 0000000..117f6b1 --- /dev/null +++ b/examples/pruning/pruning_utils.py @@ -0,0 +1,96 @@ +import logging + +from llmcompressor import oneshot +from transformers import AutoTokenizer + +from fmchisel.config import CalibrationDataConfig, PruningConfig +from fmchisel.data.calibration_datautil import HFCalibrationDataLoader +from fmchisel.pruning.osscar.utils.helpers import cleanup_after_prune, pack + +SPARSE_GPT = "SparseGPT" +WANDA = "wanda" +ALPS = "ALPS" + +logger = logging.getLogger(__name__) + + +def get_pruning_modifier( + pruning_strategy: str, + sparsity: float, + mask_structure: str, +): + + common_kwargs = { + "sparsity": sparsity, + "mask_structure": mask_structure, + "targets": "__ALL_PRUNABLE__", + } + if pruning_strategy == SPARSE_GPT: + from llmcompressor.modifiers.obcq import SparseGPTModifier + + recipe = SparseGPTModifier(**common_kwargs) + return recipe + elif pruning_strategy == WANDA: + from llmcompressor.modifiers.pruning import WandaPruningModifier + + recipe = WandaPruningModifier(**common_kwargs) + return recipe + elif pruning_strategy == ALPS: + from fmchisel.pruning.alps.base import ALPSModifier + + recipe = ALPSModifier(**common_kwargs) + return recipe + else: + raise ValueError(f"Unsupported pruning strategy: {pruning_strategy}") + + +def prune(pruning_config: PruningConfig, data_config: CalibrationDataConfig): + + tokenizer = AutoTokenizer.from_pretrained(pruning_config.model) + max_seq_length = pruning_config.model_max_length or tokenizer.model_max_length + + tokenized_dataset = 
HFCalibrationDataLoader( + nsamples=data_config.num_calibration_samples, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + dataset=data_config.dataset, + data_field=data_config.data_field, + data_dir=data_config.data_dir, + data_split=data_config.data_split, + ).get_tokenized_calibration() + + if pruning_config.pruning_yaml_recipe and "yaml" in pruning_config.pruning_yaml_recipe: + logger.info("Found a yaml recipe, ignoring pruning_strategy, sparsity, prunen and prunem.") + recipe = pruning_config.pruning_yaml_recipe + else: + logger.info( + "No yaml recipe provided, creating the recipe based on pruning_strategy, sparsity, prunen and prunem." + ) + recipe = get_pruning_modifier( + pruning_strategy=pruning_config.pruning_strategy, + sparsity=pruning_config.sparsity, + mask_structure=f"{pruning_config.prunen}:{pruning_config.prunem}", + ) + + oneshot( + model=pruning_config.model, + dataset=tokenized_dataset, + recipe=recipe, + save_compressed=pruning_config.save_compressed, + output_dir=pruning_config.output_dir, + max_seq_length=max_seq_length, + num_calibration_samples=data_config.num_calibration_samples, + ) + # The lm_head is always saved with the model checkpoint in llmcompressor, + # even if the model has tied word embeddings. This leads to a bug where + # models with tied word embeddings get a random lm_head. + # As a workaround, after pruning, we load the model and copy it to a new + # one with correct lm_head settings. + if not pruning_config.save_compressed: + pack( + pruning_config.output_dir, + 0, + 0, + pruning_config.model, + ) + cleanup_after_prune(pruning_config.output_dir) diff --git a/examples/pruning/run.sh b/examples/pruning/run.sh new file mode 100644 index 0000000..6b14646 --- /dev/null +++ b/examples/pruning/run.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +set -ex + +###################### +# Task Parameters +###################### + +# Base model (HF hub id or local path) +MODEL="Qwen/Qwen3-0.6B" + +# Calibration dataset +DATASET="allenai/c4" +DATA_FIELD="text" +DATA_SPLIT="train" +DATA_DIR="en" +NUM_CAL_SAMPLES=1024 + +# Pruning settings +PRUNING_METHOD="ALPS" # Choose from: ALPS, wanda, SparseGPT +SPARSITY=0.5 # Unstructured sparsity ratio +PRUNE_N=2 # N for N:M pattern +PRUNE_M=4 # M for N:M pattern +SAVE_COMPRESSED="False" # Store compressed tensors (True/False) + +# Output directory (timestamped) +OUTPUT_DIR="out/pruning-$(date +%Y%m%d-%H%M%S)" +mkdir -p "$OUTPUT_DIR" + +###################### +# Run pruning script +###################### +python main_pruning.py \ + --model "$MODEL" \ + --output_dir "$OUTPUT_DIR" \ + --pruning_strategy "$PRUNING_METHOD" \ + --dataset "$DATASET" \ + --data_field "$DATA_FIELD" \ + --data_split "$DATA_SPLIT" \ + --data_dir "$DATA_DIR" \ + --num_calibration_samples "$NUM_CAL_SAMPLES" \ + --sparsity "$SPARSITY" \ + --prunen "$PRUNE_N" \ + --prunem "$PRUNE_M" \ + --save_compressed "$SAVE_COMPRESSED" diff --git a/examples/pruning/run_recipe.sh b/examples/pruning/run_recipe.sh new file mode 100644 index 0000000..7558223 --- /dev/null +++ b/examples/pruning/run_recipe.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +set -ex + +#-------------------------------------------------- +# Prune a model using a YAML recipe +# Results will be written to: out/pruning-YYYYMMDD-HHMMSS +#-------------------------------------------------- + +###################### +# Task Parameters +###################### + +# Base model (HF hub id or local path) +MODEL="Qwen/Qwen3-0.6B" + +# Calibration dataset +DATASET="Salesforce/wikitext" +DATA_FIELD="text" 
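+# DATA_DIR below is forwarded via --data_dir and selects the dataset configuration (here the "wikitext-103-raw-v1" subset).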
+DATA_SPLIT="train" +DATA_DIR="wikitext-103-raw-v1" +NUM_CAL_SAMPLES=1024 + +# YAML pruning recipe +RECIPE="./alps_24_ignore_attn.yaml" + +# Whether to store compressed tensors (True/False) +SAVE_COMPRESSED="False" + +# Output directory (timestamped) +OUTPUT_DIR="out/pruning-$(date +%Y%m%d-%H%M%S)" +mkdir -p "$OUTPUT_DIR" + +###################### +# Run pruning script +###################### +python main_pruning.py \ + --model "$MODEL" \ + --output_dir "$OUTPUT_DIR" \ + --dataset "$DATASET" \ + --data_field "$DATA_FIELD" \ + --data_split "$DATA_SPLIT" \ + --data_dir "$DATA_DIR" \ + --num_calibration_samples "$NUM_CAL_SAMPLES" \ + --save_compressed "$SAVE_COMPRESSED" \ + --pruning_yaml_recipe "$RECIPE" diff --git a/examples/quantization/README.md b/examples/quantization/README.md new file mode 100644 index 0000000..3bcb7bc --- /dev/null +++ b/examples/quantization/README.md @@ -0,0 +1,84 @@ +# Quantizing LLMs with FMCHISEL + +This example shows how to **quantize a Large Language Model (LLM) into a low precision, memory- and compute-efficient variant** using FMCHISEL. It demonstrates data loading, model preparation and quantization, and exporting the quantizing the model into the `compressed-tensors` format. + +## 1. Background + +* **Weight-only quantization** – Only model weights are quantized to low precision (e.g., 4 bits), while activations are kept in 16 bits.. +* **Weight and activation quantization** – Both weights and activations are quantized to lower precisions. + +FMCHISEL implements **QuantEase**. See the [QuantEase paper](https://arxiv.org/abs/2309.01885) for technical details. + +--- + +## 2. Getting Started + +```bash +# 1. (Optional) login to HF if models / datasets are gated +huggingface-cli login + +# 2. Quantize via passing YAML recipes +bash run_quantization.sh +``` + +The run.sh script will: + +1. Download the **base model** (default: `Qwen/Qwen3-0.6B`). +2. Load **C4-en** calibration samples (default: `1024`). +3. Apply **QuantEase** quantization (4 bits weight-only quantization). +4. Save the quantized model to `out/quantization-/` (compressed-tensors-compatible). + +### Output Artifacts + +``` +out/quantization-/ +└── [model files] # weights + config.json + tokenizer +``` + + +## 3. Customizations + + +| Flag | Description | Example | +|------|-------------|---------| +| `--model` | Path / HF-hub id of the model to quantize. | `meta-llama/Llama-3.1-8B` | +| `--output_dir` | Where to write the quantized model. | `output_model/quantization/my_llama` | +| `--dataset` | HF dataset used for calibration. | `allenai/c4` | +| `--data_field` | Text column inside the dataset. | `text` | +| `--num_calibration_samples` | #Samples used for calibration. | `2048` | +| `--quantization_recipe` | The path to the recipe for quantization | '/my_recipe.yaml' | +| `--model_max_length` | Max sequence length during calibration. | `4096` | + +### YAML Recipes + +Complex sparsity patterns are easiest to express via a YAML file. We follow the same recipe patterns as `llmcompressor`. These recipes allow for customization of the quantization scheme (number of bits, grouping, activation ordering, etc). 
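+The weight-only recipe below ships with this example as `recipes/w4a16_int.yaml`; a weight-and-activation variant, `recipes/w8a8_int.yaml`, additionally applies SmoothQuant and dynamic per-token 8-bit activation quantization.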
+ +```yaml +# w4a16_int.yaml +quantization_stage: + run_type: oneshot + quantization_modifiers: + QuantEaseModifier: + dampening_frac: 0.01 + ignore: ["re:.*lm_head"] + num_iter: 5 # Number of QuantEase iterations + config_groups: # Quantization config + group_0: + targets: + - "Linear" + input_activations: null # No activation quantization + output_activations: null + weights: + num_bits: 4 # 4-bit weight quantization + type: "int" # int or float + symmetric: true + strategy: "group" # group-level quantization + group_size: 128 # group size +``` + + + +## 4. Tips & Tricks + +1. **Speed** – Use serving engines such as vLLM for end-to-end speed ups of quantized models. +2. **Layer dropping** – Excluding attention heads (via `ignore` regexes) often preserves accuracy. \ No newline at end of file diff --git a/examples/quantization/main_quantization.py b/examples/quantization/main_quantization.py new file mode 100644 index 0000000..4ce8ecd --- /dev/null +++ b/examples/quantization/main_quantization.py @@ -0,0 +1,18 @@ +import logging + +from quantization_utils import quantize +from transformers import HfArgumentParser + +from fmchisel.config import CalibrationDataConfig, QuantizationConfig + +logger = logging.getLogger(__name__) + + +if __name__ == "__main__": + + parser = HfArgumentParser((QuantizationConfig, CalibrationDataConfig)) + (quantization_config, data_config) = parser.parse_args_into_dataclasses() + logger.info(f"quantization_config = {quantization_config}") + logger.info(f"data_config = {data_config}") + + quantize(quantization_config, data_config) diff --git a/examples/quantization/quantization_utils.py b/examples/quantization/quantization_utils.py new file mode 100644 index 0000000..12a5d27 --- /dev/null +++ b/examples/quantization/quantization_utils.py @@ -0,0 +1,36 @@ +from llmcompressor import oneshot +from transformers import AutoTokenizer + +from fmchisel.config import CalibrationDataConfig, QuantizationConfig +from fmchisel.data.calibration_datautil import HFCalibrationDataLoader + + +def quantize(quantization_config: QuantizationConfig, data_config: CalibrationDataConfig): + + tokenizer = AutoTokenizer.from_pretrained(quantization_config.model) + max_seq_length = quantization_config.model_max_length or tokenizer.model_max_length + + tokenized_dataset = HFCalibrationDataLoader( + nsamples=data_config.num_calibration_samples, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + dataset=data_config.dataset, + data_field=data_config.data_field, + data_dir=data_config.data_dir, + data_split=data_config.data_split, + ).get_tokenized_calibration() + + if quantization_config.quantization_recipe and "yaml" in quantization_config.quantization_recipe: + recipe = quantization_config.quantization_recipe + else: + raise ValueError("No valid quantization recipe was provided.") + + oneshot( + model=quantization_config.model, + dataset=tokenized_dataset, + recipe=recipe, + save_compressed=True, + output_dir=quantization_config.output_dir, + max_seq_length=max_seq_length, + num_calibration_samples=data_config.num_calibration_samples, + ) diff --git a/examples/quantization/recipes/w4a16_int.yaml b/examples/quantization/recipes/w4a16_int.yaml new file mode 100644 index 0000000..38f6ca2 --- /dev/null +++ b/examples/quantization/recipes/w4a16_int.yaml @@ -0,0 +1,20 @@ + +quantization_stage: + run_type: oneshot + quantization_modifiers: + QuantEaseModifier: + dampening_frac: 0.01 + ignore: ["re:.*lm_head"] + num_iter: 5 + config_groups: + group_0: + targets: + - "Linear" + 
input_activations: null + output_activations: null + weights: + num_bits: 4 + type: "int" + symmetric: true + strategy: "group" + group_size: 128 diff --git a/examples/quantization/recipes/w8a8_int.yaml b/examples/quantization/recipes/w8a8_int.yaml new file mode 100644 index 0000000..f089897 --- /dev/null +++ b/examples/quantization/recipes/w8a8_int.yaml @@ -0,0 +1,26 @@ +quantization_stage: + run_type: oneshot + quantization_modifiers: + SmoothQuantModifier: + smoothing_strength: 0.8 + ignore: null + QuantEaseModifier: + dampening_frac: 0.01 + num_iter: 10 + ignore: ["lm_head"] + config_groups: + group_0: + targets: + - "Linear" + input_activations: + num_bits: 8 + type: "int" + symmetric: true + strategy: 'token' + dynamic: true + output_activations: null + weights: + num_bits: 8 + type: "int" + symmetric: true + strategy: "channel" \ No newline at end of file diff --git a/examples/quantization/run_quantization.sh b/examples/quantization/run_quantization.sh new file mode 100755 index 0000000..56673a7 --- /dev/null +++ b/examples/quantization/run_quantization.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +set -ex + + +###################### +# Task Parameters +###################### + +# Base model (HF hub id or local path) +MODEL="Qwen/Qwen3-0.6B" + +# Calibration dataset +DATASET="allenai/c4" +DATA_FIELD="text" +DATA_SPLIT="train" +DATA_DIR="en" +NUM_CAL_SAMPLES=1024 + +# Output directory (timestamped) +OUTPUT_DIR="out/quantization-$(date +%Y%m%d-%H%M%S)" +mkdir -p "$OUTPUT_DIR" + +# Quantization recipe +RECIPE="./recipes/w4a16_int.yaml" + +python main_quantization.py \ + --model $MODEL \ + --output_dir $OUTPUT_DIR \ + --dataset $DATASET \ + --data_field $DATA_FIELD \ + --data_split $DATA_SPLIT \ + --data_dir $DATA_DIR \ + --num_calibration_samples $NUM_CAL_SAMPLES \ + --quantization_recipe $RECIPE diff --git a/examples/structured_pruning/README.md b/examples/structured_pruning/README.md new file mode 100644 index 0000000..296bf02 --- /dev/null +++ b/examples/structured_pruning/README.md @@ -0,0 +1,60 @@ +# Structured Pruning LLMs with FMCHISEL + +This example shows how to **compress a Large Language Model (LLM) by removing entire neural structures**—hidden MLP neurons and attention-head groups—using FMCHISEL’s structured-pruning algorithm **OSSCAR**. The workflow covers data loading, model preparation, pruning, and exporting the compressed model back to the Hugging Face format. + +## 1. Background + +Structured pruning differs from unstructured sparsity in that whole blocks of computation are removed: + +* **MLP-neuron removal** – shrinks the feed-forward hidden dimension (`intermediate_size`). +* **Attention-head removal** – drops groups of key/value heads and their matching query heads, reducing `num_key_value_heads` and `num_attention_heads`. + +Because full structures are deleted, the pruned model runs with **standard Transformer kernels**—no custom sparse ops required. See the [OSSCAR paper](https://arxiv.org/pdf/2403.12983) for algorithmic details. + + +## 2. Getting Started + +```bash +# 1. (Optional) login to HF if models / datasets are gated +huggingface-cli login + +# 2. Run structured pruning with default hyper-parameters +bash run.sh +``` + +The run.sh script will: + +1. Download the **base model** (default: `Qwen/Qwen3-0.6B`). +2. Load **C4-en** calibration samples (default: `1024`). +3. Remove **128 MLP neurons** and **1 KV-head group** from *each* Transformer block (uniform pruning). +4. Save the compressed model to `out/structured_pruning-/` (HF-compatible). +5. 
Optionally save an even smaller on-disk representation if `--save_compressed True` is used. + +### Output Artifacts + +``` +out/structured_pruning-/ +└── [HF model files] # weights + config.json + tokenizer +``` + +If `--save_compressed True` is enabled, an additional **compressed** directory containing dense weights with pruned dimensions removed is stored alongside the standard HF files. + + +## 3. Customizations + +All flags from `StructuredPruningConfig` and `CalibrationDataConfig` can be overridden on the CLI. + +| Flag | Description | Example | +|------|-------------|---------| +| `--model` | Path / HF-hub id of the model to prune. | `meta-llama/Llama-3.1-8B` | +| `--output_dir` | Directory for the pruned model. | `out/structured_pruning-myllama` | +| `--dataset` | HF dataset for calibration. | `allenai/c4` | +| `--data_field` | Text column in the dataset. | `text` | +| `--num_calibration_samples` | #Samples for calibration. | `2048` | +| `--num_drop_mlp_neuron` | Hidden neurons removed **per block**. | `256` | +| `--num_drop_attn_group` | KV-head groups removed **per block**. | `2` | +| `--save_compressed` | Store stripped-down tensors. | `True` | +| `--model_max_length` | Max sequence length during calibration. | `4096` | + + +**Accuracy vs. size** – Balance `--num_drop_mlp_neuron` and `--num_drop_attn_group` to reach your target footprint. diff --git a/examples/structured_pruning/main_structured_pruning.py b/examples/structured_pruning/main_structured_pruning.py new file mode 100644 index 0000000..e3bfa3c --- /dev/null +++ b/examples/structured_pruning/main_structured_pruning.py @@ -0,0 +1,18 @@ +import logging + +from structured_pruning_utils import prune +from transformers import HfArgumentParser + +from fmchisel.config import CalibrationDataConfig, StructuredPruningConfig + +logger = logging.getLogger(__name__) + + +if __name__ == "__main__": + + parser = HfArgumentParser((StructuredPruningConfig, CalibrationDataConfig)) + (pruning_config, data_config) = parser.parse_args_into_dataclasses() + logger.info(f"pruning_config = {pruning_config}") + logger.info(f"data_config = {data_config}") + + prune(pruning_config, data_config) diff --git a/examples/structured_pruning/run.sh b/examples/structured_pruning/run.sh new file mode 100644 index 0000000..9748b0e --- /dev/null +++ b/examples/structured_pruning/run.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +set -ex + +# ------------------------------------------------- +# Structured pruning with explicit CLI arguments +# Results saved to: out/structured_pruning-/ +# ------------------------------------------------- + +###################### +# Task Parameters +###################### + +# Base model (HF hub id or local path) +MODEL="Qwen/Qwen3-0.6B" + +# Calibration dataset +DATASET="allenai/c4" +DATA_FIELD="text" +DATA_SPLIT="train" +DATA_DIR="en" +NUM_CAL_SAMPLES=1024 + +# Structured-pruning settings (OSSCAR) +NUM_DROP_MLP_NEURON=128 # neurons removed per transformer block +NUM_DROP_ATTN_GROUP=1 # KV-head groups removed per block +SAVE_COMPRESSED="True" # store compressed model (True/False) + +# Output directory (timestamped like distillation/pruning examples) +OUTPUT_DIR="out/structured_pruning-$(date +%Y%m%d-%H%M%S)" +mkdir -p "$OUTPUT_DIR" + +###################### +# Run pruning script +###################### +python main_structured_pruning.py \ + --model "$MODEL" \ + --output_dir "$OUTPUT_DIR" \ + --dataset "$DATASET" \ + --data_field "$DATA_FIELD" \ + --data_split "$DATA_SPLIT" \ + --data_dir "$DATA_DIR" \ + --num_calibration_samples 
"$NUM_CAL_SAMPLES" \ + --num_drop_mlp_neuron "$NUM_DROP_MLP_NEURON" \ + --num_drop_attn_group "$NUM_DROP_ATTN_GROUP" \ + --save_compressed "$SAVE_COMPRESSED" diff --git a/examples/structured_pruning/structured_pruning_utils.py b/examples/structured_pruning/structured_pruning_utils.py new file mode 100644 index 0000000..bed5409 --- /dev/null +++ b/examples/structured_pruning/structured_pruning_utils.py @@ -0,0 +1,61 @@ +from llmcompressor import oneshot +from transformers import AutoTokenizer + +from fmchisel.config import CalibrationDataConfig, PruningConfig +from fmchisel.data.calibration_datautil import HFCalibrationDataLoader +from fmchisel.pruning.osscar.base import OSSCARModifier +from fmchisel.pruning.osscar.utils.helpers import cleanup_after_prune, pack + + +def prune(pruning_config: PruningConfig, data_config: CalibrationDataConfig): + + tokenizer = AutoTokenizer.from_pretrained(pruning_config.model) + max_seq_length = pruning_config.model_max_length or tokenizer.model_max_length + + tokenized_dataset = HFCalibrationDataLoader( + nsamples=data_config.num_calibration_samples, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + dataset=data_config.dataset, + data_field=data_config.data_field, + data_dir=data_config.data_dir, + data_split=data_config.data_split, + ).get_tokenized_calibration() + + recipe = OSSCARModifier( + num_drop_mlp_neuron=pruning_config.num_drop_mlp_neuron, + num_drop_attn_group=pruning_config.num_drop_attn_group, + ) + + oneshot( + model=pruning_config.model, + dataset=tokenized_dataset, + recipe=recipe, + save_compressed=pruning_config.save_compressed, # We have custom packing functions for this type of pruning. + output_dir=pruning_config.output_dir, + max_seq_length=max_seq_length, + num_calibration_samples=data_config.num_calibration_samples, + ) + + if not pruning_config.save_compressed: + # The lm_head is always saved with the model checkpoint in llmcompressor, + # even if the model has tied word embeddings. This leads to a bug where + # models with tied word embeddings get a random lm_head. + # As a workaround, after pruning, we load the model and copy it to a new + # one with correct lm_head settings. + pack( + pruning_config.output_dir, + 0, + 0, + pruning_config.model, + ) + + else: + pack( + pruning_config.output_dir, + pruning_config.num_drop_mlp_neuron, + pruning_config.num_drop_attn_group, + pruning_config.model, + ) + + cleanup_after_prune(pruning_config.output_dir) diff --git a/src/fmchisel/config.py b/src/fmchisel/config.py new file mode 100644 index 0000000..cb3084b --- /dev/null +++ b/src/fmchisel/config.py @@ -0,0 +1,160 @@ +from dataclasses import dataclass, field +from typing import List, Literal, Union + + +@dataclass +class TrainingArgs: + model_path: str + output_dir: str + lr: float = field(default=5e-6) + num_epoch: int = field(default=None) + warmup_ratio: float = field(default=0.1) + weight_decay: float = field(default=0.1) + val_check_interval: int = field(default=10) + keep_sparse: bool = field(default=False) + optimizer: str = field(default="adamw") + enable_gradient_checkpointing: bool = field(default=True) + gradient_accumulation_steps: int = field(default=1) + save_on_best_validation: bool = field(default=True) + cpu_offload: bool = field(default=False) + use_liger: bool = field( + default=False, + metadata={ + "help": "Whether to use `liger-kernel` for distillation. With this flag set, we support " + "liger chunked losses for computing distillation loss and liger flce for computing the hard loss. 
" + "Currently we support FKL, RKL and JSD. Defaults to False." + }, + ) + # LoRA Args + lora_rank: int = field(default=None) + use_lora: bool = field(default=False) + lora_target_modules: Union[List[str], str] = field(default_factory=lambda: ["q_proj", "v_proj"]) + lora_alpha_to_rank_ratio: float = field(default=2.0) + verify_lora_saving_correctness: bool = field( + default=False, + metadata={"help": "Check if the LoRA saved model is properly merged and saved? Only use for testing."}, + ) + + def __post_init__(self): + if self.optimizer not in {"adamw", "adamw_schedulefree"}: + raise ValueError( + f"Optimizer {self.optimizer} is not supported. Please use `adamw` or `adamw_schedulefree`." + ) + if self.use_lora: + assert ( + not self.keep_sparse + ), "LoRA does not update the base weights, so they remain sparse. But the merged weights will not be sparse." + + +@dataclass +class DataLoadingConfig: + data_path: str + dataset: str = field(default="cnn_dailymail") + max_length: int = field(default=4096) + batch_size: int = field(default=8) + n_train: int = field(default=16000) + n_val: int = field(default=5000) + return_prompt_input_ids: bool = field(default=False) + + +@dataclass +class CalibrationDataConfig: + dataset: str = field( + metadata={"help": "Dataset name from HuggingFace (e.g., allenai/c4)."}, + ) + data_split: str = field( + metadata={"help": "What split of data to use (e.g., train, validation, etc)."}, + ) + data_field: str = field( + metadata={"help": "What field of the data to use (e.g., text, question, etc)."}, + ) + data_dir: str = field( + default=None, + metadata={"help": "If applicable, the data directory from Huggingface."}, + ) + num_calibration_samples: int = 1024 + + +@dataclass +class QuantizationConfig: + + model: str + output_dir: str + quantization_recipe: str = ( + field( + metadata={ + "help": "Use W4A16, W8A8, or enter a path to a yaml recipe. Example recipes can be found at flows/inference/quantization/src/recipes." + }, + ), + ) + model_max_length: int = field( + default=2048, + ) + + +@dataclass +class StructuredPruningConfig: + model: str + output_dir: str + num_drop_mlp_neuron: int = field( + default=0, + metadata={"help": "Number of hidden MLP neurons to be pruned."}, + ) + num_drop_attn_group: int = field( + default=0, + metadata={"help": "Number of attention KV groups to be pruned."}, + ) + model_max_length: int = field( + default=2048, + ) + save_compressed: bool = field( + default=True, + metadata={ + "help": "Save the compressed smaller model on disk. If set to False, the saved model will have occupy same disk space with zero paddings in the MLP/attention layers for pruned weights." + }, + ) + + def __post_init__(self): + if self.num_drop_attn_group < 0 or self.num_drop_mlp_neuron < 0: + raise ValueError("num_drop_attn_group and num_drop_mlp_neuron must be non-negative integers.") + if self.num_drop_attn_group + self.num_drop_mlp_neuron == 0: + raise ValueError( + "At least one mlp neuron or attn group has to be removed. got num_drop_attn_group + num_drop_mlp_neuron = 0." + ) + + +@dataclass +class PruningConfig: + model: str + output_dir: str + pruning_yaml_recipe: str = field( + default=None, + metadata={ + "help": "The yaml recipe that can be used for pruning. If a valid yaml file is passed, the values of pruning_strategy, sparsity, prunen and prunem WILL BE IGNORED. Alternatively, leave this field empty and pass in pruning_strategy, sparsity, prunen and prunem." 
+ }, + ) + pruning_strategy: Literal["ALPS", "SparseGPT", "wanda"] = field( + default=None, + metadata={"help": "Method to be used for pruning. WILL BE IGNORED if pruning_yaml_recipe is passed."}, + ) + model_max_length: int = field( + default=2048, + ) + sparsity: float = field( + default=0.5, + metadata={ + "help": "The unstructured sparsity ratio. WILL BE IGNORED if pruning_yaml_recipe is passed. WILL BE IGNORED if prunen is not set to zero." + }, + ) + prunen: int = field( + default=2, + metadata={"help": "The value of N in N:M sparsity. WILL BE IGNORED if pruning_yaml_recipe is passed."}, + ) + prunem: int = field( + default=4, + metadata={"help": "The value of M in N:M sparsity. WILL BE IGNORED if pruning_yaml_recipe is passed."}, + ) + save_compressed: bool = field( + default=False, + metadata={"help": "save the pruned model in the compressed format? It is recommended to be set to False."}, + ) diff --git a/src/fmchisel/data/calibration_datautil.py b/src/fmchisel/data/calibration_datautil.py new file mode 100644 index 0000000..8a6c0dd --- /dev/null +++ b/src/fmchisel/data/calibration_datautil.py @@ -0,0 +1,102 @@ +from abc import ABC, abstractmethod +from typing import List + +from datasets import Dataset, load_dataset +from transformers import AutoTokenizer + +# The following can be used as a reference for a few common datasets. +C4_DATA_PATH = "allenai/c4" +CNN_MAIL_DATA_PATH = "abisee/cnn_dailymail" +WIKITEXT_DATA_PATH = "Salesforce/wikitext" + +DATASETS_DICT = { + # (split, field, dataset, data_dir) + "c4": {"split": "train", "field": "text", "dataset": C4_DATA_PATH, "dir": "en"}, + "cnn_dailymail": {"split": "train", "field": "article", "dataset": CNN_MAIL_DATA_PATH, "dir": "1.0.0"}, + "wikitext": {"split": "train", "field": "text", "dataset": WIKITEXT_DATA_PATH, "dir": "wikitext-103-raw-v1"}, +} +# + + +class CalibrationDataLoader(ABC): + + def __init__( + self, + nsamples: int, + tokenizer: AutoTokenizer, + max_seq_length: int, + padding: bool = False, + truncation: bool = True, + add_special_tokens: bool = False, + **kwargs, + ): + + self.nsamples = nsamples + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + self.padding = padding + self.truncation = truncation + self.add_special_tokens = add_special_tokens + + @abstractmethod + def _get_calibration_data(self) -> List[str]: + pass + + def get_tokenized_calibration(self): + + calibration_dataset = self._get_calibration_data() + + assert isinstance(calibration_dataset, List), "Calibration dataset must be a list of strings." + + assert len(calibration_dataset) == self.nsamples, "Length of calibration data should be the same as nsamples." 
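+        # The calibration texts come from the subclass's _get_calibration_data() (e.g. HFCalibrationDataLoader)
+        # and are batch-tokenized below using the padding/truncation settings passed to __init__.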
+ + tokenized_ids = self.tokenizer.batch_encode_plus( + calibration_dataset, + padding=self.padding, + truncation=self.truncation, + max_length=self.max_seq_length, + add_special_tokens=self.add_special_tokens, + ) # {"input_ids": [[..], ..], "attention_mask": [[..], ..]} + return Dataset.from_dict(tokenized_ids) + + +class HFCalibrationDataLoader(CalibrationDataLoader): + + def __init__( + self, + nsamples: int, + tokenizer: AutoTokenizer, + max_seq_length: int, + padding: bool = False, + truncation: bool = True, + add_special_tokens: bool = False, + **kwargs, + ): + + super().__init__( + nsamples, + tokenizer, + max_seq_length, + padding, + truncation, + add_special_tokens, + **kwargs, + ) + self.dataset = kwargs.get("dataset") + self.data_split = kwargs.get("data_split") + self.data_field = kwargs.get("data_field") + self.data_dir = kwargs.get("data_dir", None) + + def _get_calibration_data(self) -> List[str]: + + if self.data_dir is not None: + ds = load_dataset(self.dataset, self.data_dir, streaming=True, split=self.data_split)[self.data_field] + else: + ds = load_dataset(self.dataset, streaming=True, split=self.data_split)[self.data_field] + text_data = [] + for item in ds: + if len(item) > 0: + text_data.append(item) + if len(text_data) == self.nsamples: + break + return text_data diff --git a/src/fmchisel/data/collator.py b/src/fmchisel/data/collator.py new file mode 100644 index 0000000..0a6fa48 --- /dev/null +++ b/src/fmchisel/data/collator.py @@ -0,0 +1,94 @@ +import warnings +from typing import Any, Dict, List, Union + +import numpy as np +import torch +from transformers.data.data_collator import DataCollatorForLanguageModeling + + +# Code taken from yudai_linkedin's e2e flow +# copied from OSS trl repo https://github.com/huggingface/trl/blob/a7dc892717a1503d5f68f94af870b523fe14bc94/trl/trainer/utils.py#L75 +# for avoiding additional dependency +class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling): + """ + Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index' + when they do not come from the assistant. This ensure that the loss is only + calculated on the completion made by the assistant. + + Args: + response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like + '### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response + differently if it does not have proper context. + mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying + `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present + for flexibility and backwards-compatibility. 
+ ignore_index (`int`, *optional*, defaults to `-100`): + The index to use to ignore the initial tokens with + return_prompt_input_ids (`bool`, *optional*, defaults to `False`): Whether or not to return the prompt only input ids and corresponding prompt attention mask + """ + + def __init__( + self, + response_template: Union[str, List[int]], + *args, + mlm: bool = False, + ignore_index: int = -100, + return_prompt_input_ids=False, + **kwargs, + ): + super().__init__(*args, mlm=mlm, **kwargs) + self.response_template = response_template + if isinstance(response_template, str): + # The user provides a string, must tokenize + self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False) + else: + # The user already provides the token ids + self.response_token_ids = response_template + self.ignore_index = ignore_index + self.return_prompt_input_ids = return_prompt_input_ids + + def torch_call(self, examples: List[Union[List, Any, Dict]]) -> Dict[str, Any]: + batch = super().torch_call(examples) + + if self.return_prompt_input_ids: + # create empty container + batch["prompt_input_ids"] = torch.full( + batch["input_ids"].shape, + self.tokenizer.pad_token_id, + dtype=batch["input_ids"].dtype, + device=batch["input_ids"].device, + ) + batch["prompt_attention_mask"] = torch.zeros_like(batch["attention_mask"]) + + for i in range(len(examples)): + response_token_ids_start_idx = None + + for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: + # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match + if ( + self.response_token_ids == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist() + ): # noqa: E203 + response_token_ids_start_idx = idx + + if response_token_ids_start_idx is None: + warnings.warn( + f"Could not find response key `{self.response_template}` in the instance. " + f"This instance will be ignored in loss calculation. " + f"Note, if this happens often, consider increasing the `max_seq_length`." 
+ ) + batch["labels"][i, :] = self.ignore_index + if self.return_prompt_input_ids: + # no response token found, all ids in this row are prompt ids + batch["prompt_input_ids"][i, :] = batch["input_ids"][i, :] + batch["prompt_attention_mask"][i, :] = 1 + else: + response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids) + + # Make pytorch loss function ignore all tokens up through the end of the response key + batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index + if self.return_prompt_input_ids: + batch["prompt_input_ids"][i, -response_token_ids_end_idx:] = batch["input_ids"][ + i, :response_token_ids_end_idx + ] + batch["prompt_attention_mask"][i, -response_token_ids_end_idx:] = 1 + return batch diff --git a/src/fmchisel/data/datasets.py b/src/fmchisel/data/datasets.py new file mode 100644 index 0000000..35ca226 --- /dev/null +++ b/src/fmchisel/data/datasets.py @@ -0,0 +1,102 @@ +from abc import ABC, abstractmethod + +import datasets +import lightning.pytorch as pl +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +from fmchisel.config import DataLoadingConfig +from fmchisel.data.collator import DataCollatorForCompletionOnlyLM + +_RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"} + + +CNN_RESPONSE_TEMPLATE = " " + + +class DataModule(pl.LightningDataModule, ABC): + def __init__(self, tokenizer: AutoTokenizer, data_load_config: DataLoadingConfig): + super().__init__() + self.data_name = data_load_config.dataset + self.tokenizer = tokenizer + self.data_path = data_load_config.data_path + self.max_length = data_load_config.max_length + self.batch_size = data_load_config.batch_size + self.n_train = data_load_config.n_train + self.n_val = data_load_config.n_val + self.return_prompt_input_ids = data_load_config.return_prompt_input_ids + + @abstractmethod + def formatting_func(self, example): + pass + + def tokenize(self, example): + outputs = self.tokenizer( + self.formatting_func(example), + truncation=True, + padding=False, + max_length=self.max_length, + ) + return { + "input_ids": outputs["input_ids"], + "attention_mask": outputs["attention_mask"], + } + + @abstractmethod + def setup(self, stage) -> None: + self.train_dataset = self.dataset["train"].map( + self.tokenize, + remove_columns=list(set(self.dataset["train"].column_names) - _RETAIN_COLUMNS), + batched=True, + batch_size=1, + ) + self.val_dataset = self.dataset["test"].map( + self.tokenize, + remove_columns=list(set(self.dataset["test"].column_names) - _RETAIN_COLUMNS), + batched=True, + batch_size=1, + ) + self.dataset = None + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + collate_fn=self.collator, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + collate_fn=self.collator, + ) + + +class CNNModule(DataModule): + def __init__(self, tokenizer: AutoTokenizer, data_load_config: DataLoadingConfig): + super().__init__(tokenizer, data_load_config) + response_prompt = tokenizer.encode(CNN_RESPONSE_TEMPLATE, add_special_tokens=False) + self.collator = DataCollatorForCompletionOnlyLM( + tokenizer=tokenizer, + response_template=response_prompt, + pad_to_multiple_of=16, + return_prompt_input_ids=self.return_prompt_input_ids, + ) + + def formatting_func(self, example): + output = "Given a text, please give highlights.\n\n" + output += f"TEXT: {example['article']}\n" + output += f" {CNN_RESPONSE_TEMPLATE} " + output += f"{example['highlights']} " 
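+        # Returned as a one-element list: DataModule.tokenize feeds this to datasets.map with batched=True and batch_size=1.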
+ return [output] + + def setup(self, stage) -> None: + self.dataset = ( + datasets.load_dataset(path=self.data_path) + if self.data_path + else datasets.load_dataset("cnn_dailymail", "3.0.0") + ) + self.dataset["train"] = self.dataset["train"].select(range(self.n_train)) + self.dataset["test"] = self.dataset["test"].select(range(self.n_val)) + super().setup(stage) diff --git a/src/fmchisel/modifiers.py b/src/fmchisel/modifiers.py new file mode 100644 index 0000000..93f4347 --- /dev/null +++ b/src/fmchisel/modifiers.py @@ -0,0 +1,9 @@ +from fmchisel.pruning.alps.base import ALPSModifier +from fmchisel.pruning.osscar.base import OSSCARModifier +from fmchisel.quantization.quantease.base import QuantEaseModifier + +__all__ = [ + "ALPSModifier", + "OSSCARModifier", + "QuantEaseModifier", +] diff --git a/test/data/test_calibration_data.py b/test/data/test_calibration_data.py new file mode 100644 index 0000000..33aa5ac --- /dev/null +++ b/test/data/test_calibration_data.py @@ -0,0 +1,55 @@ +import unittest + +from transformers import AutoTokenizer + +from fmchisel.data.calibration_datautil import DATASETS_DICT, HFCalibrationDataLoader + + +class TestCalibrationDataLoader(unittest.TestCase): + + def setUp(self): + model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + self.tokenizer = AutoTokenizer.from_pretrained(model) + self.nsamples = 50 + self.length = 128 + + self.common_kwargs = { + "tokenizer": self.tokenizer, + "nsamples": self.nsamples, + "max_seq_length": self.length, + } + + def single_dataset_test(self, kwargs): + loader = HFCalibrationDataLoader(**kwargs) + + ids = loader.get_tokenized_calibration() + self.assertEqual(len(ids), self.nsamples) + for id in ids: + assert len(id["input_ids"]) <= self.length + return True + + def form_kwargs(self, dataset_name): + data_info = DATASETS_DICT[dataset_name] + data_split = data_info["split"] + data_field = data_info["field"] + dataset = data_info["dataset"] + data_dir = data_info["dir"] + kwargs = {"dataset": dataset, "data_field": data_field, "data_dir": data_dir, "data_split": data_split} + return {**self.common_kwargs, **kwargs} + + def test_c4(self): + + kwargs = self.form_kwargs("c4") + assert self.single_dataset_test(kwargs) + + def test_cnn_dailymail(self): + kwargs = self.form_kwargs("cnn_dailymail") + assert self.single_dataset_test(kwargs) + + def test_wikitext(self): + kwargs = self.form_kwargs("wikitext") + assert self.single_dataset_test(kwargs) + + +if __name__ == "__main__": + unittest.main()