2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
@@ -7,12 +7,14 @@ on:
    paths:
      - "src/**"
      - "test/**"
      - "examples/**"
  pull_request:
    branches:
      - main
    paths:
      - "src/**"
      - "test/**"
      - "examples/**"
  schedule:
    # Runs at 00:00 UTC daily
    - cron: '0 0 * * *'
86 changes: 86 additions & 0 deletions examples/pruning/README.md
@@ -0,0 +1,86 @@
# Pruning LLMs with FMCHISEL

This example shows how to **prune a Large Language Model (LLM) into a sparse, memory- and compute-efficient variant** using FMCHISEL. It demonstrates data loading, model preparation, unstructured & N:M pruning, and exporting the pruned model back to the Hugging Face Hub format.

## 1. Background

* **Unstructured sparsity** – Any weight can be zero. Offers maximum flexibility but requires special kernels for runtime speed-ups.
* **Semi-structured (N:M) sparsity** – Exactly **N** non-zeros in every **M** consecutive weights (e.g. 2:4 = 50 % sparsity). Readily accelerated by NVIDIA A100/H100 tensor cores.

FMCHISEL implements **ALPS** (ADMM-based Layerwise Pruning with Saliency) and wraps two post-training methods from *llmcompressor*: **SparseGPT** and **Wanda**. See the [ALPS paper](https://arxiv.org/abs/2406.07831) for technical details.
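
To make the N:M pattern concrete, here is a minimal PyTorch sketch (illustrative only, not FMCHISEL code) that applies a magnitude-based 2:4 mask to a weight matrix:

```python
import torch


def apply_2_4_mask(weight: torch.Tensor) -> torch.Tensor:
    """Keep the 2 largest-magnitude entries in every group of 4 consecutive weights."""
    rows, cols = weight.shape  # assumes cols is divisible by 4
    groups = weight.reshape(-1, 4)
    keep = groups.abs().topk(k=2, dim=1).indices  # 2 largest per group of 4
    mask = torch.zeros_like(groups)
    mask.scatter_(1, keep, 1.0)                   # 1.0 where a weight survives
    return (groups * mask).reshape(rows, cols)


w = torch.randn(8, 16)
w_sparse = apply_2_4_mask(w)
# every group of 4 consecutive weights now has at most 2 non-zeros (~50% sparsity)
assert ((w_sparse.reshape(-1, 4) != 0).sum(dim=1) <= 2).all()
```

ALPS, SparseGPT, and Wanda select the surviving weights far more carefully (using calibration data), but the resulting masks obey the same constraint.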

---

## 2. Getting Started

```bash
# 1. (Optional) login to HF if models / datasets are gated
huggingface-cli login

# 2a. Prune via direct CLI arguments
bash run.sh

# 2b. Prune via a YAML recipe
bash run_recipe.sh
```

The `run.sh` script will:

1. Download the **base model** (default: `Qwen/Qwen3-0.6B`).
2. Load **C4-en** calibration samples (default: `1024`).
3. Apply **ALPS 2:4** pruning to all prunable linear layers (to keep the attention projections dense, use the YAML recipe via `run_recipe.sh`).
4. Save the pruned model to `out/pruning-<ID>/` (HF-compatible).
5. Optionally store compressed tensor formats if `--save_compressed True` is used.

### Output Artifacts

```
out/pruning-<ID>/
└── [HF model files] # weights + config.json + tokenizer
```

If `--save_compressed True` is enabled, an additional **compressed** directory containing CSR tensors & metadata will be created alongside the standard HF files.
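
Since the default (non-compressed) output is a standard Hugging Face checkpoint, it can be loaded back with plain `transformers` — a minimal sketch, assuming the placeholder path is replaced with the actual timestamped directory from your run:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Replace with the actual output directory, e.g. out/pruning-20250101-120000
path = "out/pruning-<ID>"

model = AutoModelForCausalLM.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)

inputs = tokenizer("Pruning large language models", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```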

## 3. Customizations

All flags from `PruningConfig` and `CalibrationDataConfig` can be overridden on the CLI or via a **YAML recipe**.

| Flag | Description | Example |
|------|-------------|---------|
| `--model` | Path / HF-hub id of the model to prune. | `meta-llama/Llama-3.1-8B` |
| `--output_dir` | Where to write the pruned model. | `output_model/pruning/my_llama` |
| `--dataset` | HF dataset used for calibration. | `allenai/c4` |
| `--data_field` | Text column inside the dataset. | `text` |
| `--num_calibration_samples` | Number of samples used for calibration. | `2048` |
| `--pruning_strategy` | `ALPS`, `SparseGPT`, or `wanda`. | `ALPS` |
| `--sparsity` | Global sparsity ratio for unstructured pruning. | `0.5` |
| `--prunen / --prunem` | N and M for N:M sparsity. Use `0 0` to disable. | `2 4` |
| `--pruning_yaml_recipe` | Path to a YAML pruning recipe (overrides other pruning flags). | `examples/pruning/alps_24_ignore_attn.yaml` |
| `--save_compressed` | Store compressed tensors (CSR & metadata). | `True` |
| `--model_max_length` | Max sequence length during calibration. | `4096` |

### YAML Recipes

Complex sparsity patterns are easiest to express via a YAML file.

```yaml
# alps_24_ignore_attn.yaml
sparsity_stage:
  sparsity_modifiers:
    ALPSModifier:
      sparsity: 0.5              # 50% overall
      mask_structure: "2:4"      # N:M pattern
      targets: ["Linear"]        # prune all Linear layers
      ignore: [                  # keep attention and lm_head dense
        "re:.*q_proj", "re:.*k_proj", "re:.*v_proj", "re:.*o_proj", "re:.*lm_head"
      ]
```

Pass it via `--pruning_yaml_recipe path/to/file.yaml` (see `run_recipe.sh`).
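
For reference, a rough programmatic equivalent of the recipe above — a sketch that assumes `ALPSModifier` accepts `ignore` as a constructor argument the same way it reads it from YAML (the other keyword arguments mirror `pruning_utils.py`):

```python
from fmchisel.pruning.alps.base import ALPSModifier

recipe = ALPSModifier(
    sparsity=0.5,            # 50% overall sparsity
    mask_structure="2:4",    # N:M pattern
    targets=["Linear"],      # prune all Linear layers
    ignore=[                 # keep attention projections and lm_head dense
        "re:.*q_proj", "re:.*k_proj", "re:.*v_proj", "re:.*o_proj", "re:.*lm_head",
    ],
)
# `recipe` can then be passed to llmcompressor's `oneshot(...)` in place of the
# YAML file path, exactly as done in pruning_utils.py.
```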


## 4. Tips & Tricks

1. **Speed** – Use N:M sparsity (e.g. 2:4) on Ampere/Hopper GPUs for actual inference acceleration.
2. **Layer exclusion** – Keeping the attention projections dense (via the `ignore` regexes) often preserves accuracy.
3. **Model size** – Enable `--save_compressed True` to store the model in a compact CSR format.
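
Regardless of which options you choose, it is easy to verify the achieved sparsity on a non-compressed checkpoint — a minimal sketch (replace the placeholder path with your output directory):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("out/pruning-<ID>")  # your output dir

zeros, total = 0, 0
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        zeros += (module.weight == 0).sum().item()
        total += module.weight.numel()
print(f"Overall sparsity of Linear weights: {zeros / total:.2%}")
```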
7 changes: 7 additions & 0 deletions examples/pruning/alps_24_ignore_attn.yaml
@@ -0,0 +1,7 @@
sparsity_stage:
  sparsity_modifiers:
    ALPSModifier:
      sparsity: 0.5
      mask_structure: "2:4"
      targets: ["Linear"]
      ignore: ["re:.*lm_head", "re:.*q_proj", "re:.*k_proj", "re:.*v_proj", "re:.*o_proj"]
18 changes: 18 additions & 0 deletions examples/pruning/main_pruning.py
@@ -0,0 +1,18 @@
import logging

from pruning_utils import prune
from transformers import HfArgumentParser

from fmchisel.config import CalibrationDataConfig, PruningConfig

logger = logging.getLogger(__name__)


if __name__ == "__main__":

    parser = HfArgumentParser((PruningConfig, CalibrationDataConfig))
    (pruning_config, data_config) = parser.parse_args_into_dataclasses()
    logger.info(f"pruning_config = {pruning_config}")
    logger.info(f"data_config = {data_config}")

    prune(pruning_config, data_config)
96 changes: 96 additions & 0 deletions examples/pruning/pruning_utils.py
@@ -0,0 +1,96 @@
import logging

from llmcompressor import oneshot
from transformers import AutoTokenizer

from fmchisel.config import CalibrationDataConfig, PruningConfig
from fmchisel.data.calibration_datautil import HFCalibrationDataLoader
from fmchisel.pruning.osscar.utils.helpers import cleanup_after_prune, pack

SPARSE_GPT = "SparseGPT"
WANDA = "wanda"
ALPS = "ALPS"

logger = logging.getLogger(__name__)


def get_pruning_modifier(
    pruning_strategy: str,
    sparsity: float,
    mask_structure: str,
):

    common_kwargs = {
        "sparsity": sparsity,
        "mask_structure": mask_structure,
        "targets": "__ALL_PRUNABLE__",
    }
    if pruning_strategy == SPARSE_GPT:
        from llmcompressor.modifiers.obcq import SparseGPTModifier

        recipe = SparseGPTModifier(**common_kwargs)
        return recipe
    elif pruning_strategy == WANDA:
        from llmcompressor.modifiers.pruning import WandaPruningModifier

        recipe = WandaPruningModifier(**common_kwargs)
        return recipe
    elif pruning_strategy == ALPS:
        from fmchisel.pruning.alps.base import ALPSModifier

        recipe = ALPSModifier(**common_kwargs)
        return recipe
    else:
        raise ValueError(f"Unsupported pruning strategy: {pruning_strategy}")


def prune(pruning_config: PruningConfig, data_config: CalibrationDataConfig):

    tokenizer = AutoTokenizer.from_pretrained(pruning_config.model)
    max_seq_length = pruning_config.model_max_length or tokenizer.model_max_length

    tokenized_dataset = HFCalibrationDataLoader(
        nsamples=data_config.num_calibration_samples,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        dataset=data_config.dataset,
        data_field=data_config.data_field,
        data_dir=data_config.data_dir,
        data_split=data_config.data_split,
    ).get_tokenized_calibration()

    if pruning_config.pruning_yaml_recipe and "yaml" in pruning_config.pruning_yaml_recipe:
        logger.info("Found a yaml recipe, ignoring pruning_strategy, sparsity, prunen and prunem.")
        recipe = pruning_config.pruning_yaml_recipe
    else:
        logger.info(
            "No yaml recipe provided, creating the recipe based on pruning_strategy, sparsity, prunen and prunem."
        )
        recipe = get_pruning_modifier(
            pruning_strategy=pruning_config.pruning_strategy,
            sparsity=pruning_config.sparsity,
            mask_structure=f"{pruning_config.prunen}:{pruning_config.prunem}",
        )

    oneshot(
        model=pruning_config.model,
        dataset=tokenized_dataset,
        recipe=recipe,
        save_compressed=pruning_config.save_compressed,
        output_dir=pruning_config.output_dir,
        max_seq_length=max_seq_length,
        num_calibration_samples=data_config.num_calibration_samples,
    )
    # The lm_head is always saved with the model checkpoint in llmcompressor,
    # even if the model has tied word embeddings. This leads to a bug where
    # models with tied word embeddings get a random lm_head.
    # As a workaround, after pruning, we load the model and copy it to a new
    # one with correct lm_head settings.
    if not pruning_config.save_compressed:
        pack(
            pruning_config.output_dir,
            0,
            0,
            pruning_config.model,
        )
        cleanup_after_prune(pruning_config.output_dir)
44 changes: 44 additions & 0 deletions examples/pruning/run.sh
@@ -0,0 +1,44 @@
#!/usr/bin/env bash
set -ex

######################
# Task Parameters
######################

# Base model (HF hub id or local path)
MODEL="Qwen/Qwen3-0.6B"

# Calibration dataset
DATASET="allenai/c4"
DATA_FIELD="text"
DATA_SPLIT="train"
DATA_DIR="en"
NUM_CAL_SAMPLES=1024

# Pruning settings
PRUNING_METHOD="ALPS" # Choose from: ALPS, wanda, SparseGPT
SPARSITY=0.5 # Unstructured sparsity ratio
PRUNE_N=2 # N for N:M pattern
PRUNE_M=4 # M for N:M pattern
SAVE_COMPRESSED="False" # Store compressed tensors (True/False)

# Output directory (timestamped)
OUTPUT_DIR="out/pruning-$(date +%Y%m%d-%H%M%S)"
mkdir -p "$OUTPUT_DIR"

######################
# Run pruning script
######################
python main_pruning.py \
--model "$MODEL" \
--output_dir "$OUTPUT_DIR" \
--pruning_strategy "$PRUNING_METHOD" \
--dataset "$DATASET" \
--data_field "$DATA_FIELD" \
--data_split "$DATA_SPLIT" \
--data_dir "$DATA_DIR" \
--num_calibration_samples "$NUM_CAL_SAMPLES" \
--sparsity "$SPARSITY" \
--prunen "$PRUNE_N" \
--prunem "$PRUNE_M" \
--save_compressed "$SAVE_COMPRESSED"
45 changes: 45 additions & 0 deletions examples/pruning/run_recipe.sh
@@ -0,0 +1,45 @@
#!/usr/bin/env bash
set -ex

#--------------------------------------------------
# Prune a model using a YAML recipe
# Results will be written to: out/pruning-YYYYMMDD-HHMMSS
#--------------------------------------------------

######################
# Task Parameters
######################

# Base model (HF hub id or local path)
MODEL="Qwen/Qwen3-0.6B"

# Calibration dataset
DATASET="Salesforce/wikitext"
DATA_FIELD="text"
DATA_SPLIT="train"
DATA_DIR="wikitext-103-raw-v1"
NUM_CAL_SAMPLES=1024

# YAML pruning recipe
RECIPE="./alps_24_ignore_attn.yaml"

# Whether to store compressed tensors (True/False)
SAVE_COMPRESSED="False"

# Output directory (timestamped)
OUTPUT_DIR="out/pruning-$(date +%Y%m%d-%H%M%S)"
mkdir -p "$OUTPUT_DIR"

######################
# Run pruning script
######################
python main_pruning.py \
--model "$MODEL" \
--output_dir "$OUTPUT_DIR" \
--dataset "$DATASET" \
--data_field "$DATA_FIELD" \
--data_split "$DATA_SPLIT" \
--data_dir "$DATA_DIR" \
--num_calibration_samples "$NUM_CAL_SAMPLES" \
--save_compressed "$SAVE_COMPRESSED" \
--pruning_yaml_recipe "$RECIPE"