This repository contains the official implementation of SafeMoE: Safe Fine-Tuning for MoE LLMs by Aligning Harmful Input Routing, published as a conference paper at ICLR 2026.
Built upon Safe-RLHF
- GPU: 1x NVIDIA A100 80GB
- CUDA: 12.4
git clone https://github.com/jaehanwork/SafeMoE.git
cd SafeMoE

Install the conda environment and the customized vLLM (for LoRA support):
conda env create -f environments.yml
conda activate SafeMoE
export VLLM_PRECOMPILED_WHEEL_LOCATION=https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-0.10.0-cp38-abi3-manylinux1_x86_64.whl
VLLM_USE_PRECOMPILED=1 pip install -e ./vllm-0.10.0

Configure your Hugging Face token:
export HF_TOKEN="your_hf_api_key_here"

Harmful Fine-Tuning Attack (HFT) setting:

| Type | Dataset | Size |
|---|---|---|
| Task-specific | SAMSum or SQL | 5,000 samples |
| Harmful | BeaverTails | 500 samples |
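As a quick check on the mixture in the table, the sketch below computes the combined training-set size and the share of harmful samples. It assumes one task-specific dataset (5,000 samples) is combined with the 500 BeaverTails samples, as in the `--train_datasets` arguments used later; the variable names are illustrative only.

```shell
# Sample counts copied from the HFT setting table.
task_samples=5000
harmful_samples=500

# Combined size of the fine-tuning mixture.
total=$((task_samples + harmful_samples))

# Share of harmful samples, as a percentage with one decimal place.
harmful_pct=$(awk -v h="$harmful_samples" -v t="$total" 'BEGIN { printf "%.1f", 100 * h / t }')

echo "total=${total} harmful=${harmful_pct}%"
```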
Step 1: Extract Safety Routing Weights
Extract routing weights from the initial safety-aligned model:
scripts/extract_routing_logits.sh \
--model_name_or_path allenai/OLMoE-1B-7B-0125-Instruct \
--sample_size 100 \
--output_dir results/router_logits/OLMoE

Step 2: Safe Fine-Tuning
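Before launching Step 2, it can help to confirm that Step 1 produced the routing-weight file. The sketch below assumes the file lands at the path passed to `--routing_logits_safe` in the SafeMoE command; `check_routing_logits` is a hypothetical helper, not part of the repository.

```shell
# Path taken from the --routing_logits_safe argument used in Step 2.
routing_logits="results/router_logits/OLMoE/router_logits.json"

# Hypothetical helper: succeeds only if the file exists and is non-empty.
check_routing_logits() {
    if [ -s "$1" ]; then
        echo "found: $1"
    else
        echo "missing or empty: $1 (re-run Step 1)" >&2
        return 1
    fi
}

check_routing_logits "$routing_logits" || echo "Step 1 output not ready yet"
```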
Run SafeMoE:
scripts/run_safemoe.sh \
--model_name_or_path allenai/OLMoE-1B-7B-0125-Instruct \
--train_datasets Samsum/train_5k BeaverTails/train_500 \
--routing_logits_safe results/router_logits/OLMoE/router_logits.json \
--temp 0.1 \
--output_dir models/OLMoE_Samsum_safemoe

Evaluation

Harmfulness score:
./scripts/eval_JBB.sh \
--model_name_or_path models/OLMoE_Samsum_safemoe \
--output_dir results/JBB/OLMoE_Samsum_safemoe

Fine-tuning accuracy:
./scripts/eval_Samsum.sh \
--model_name_or_path models/OLMoE_Samsum_safemoe \
--output_dir results/Samsum/OLMoE_Samsum_safemoe

Run vanilla fine-tuning:
./scripts/run_sft.sh \
--model_name_or_path allenai/OLMoE-1B-7B-0125-Instruct \
--train_datasets Samsum/train_5k BeaverTails/train_500 \
--output_dir models/OLMoE_Samsum_ft

Evaluation

Harmfulness score:
./scripts/eval_JBB.sh \
--model_name_or_path models/OLMoE_Samsum_ft \
--output_dir results/JBB/OLMoE_Samsum_ft

Fine-tuning accuracy:
./scripts/eval_Samsum.sh \
--model_name_or_path models/OLMoE_Samsum_ft \
--output_dir results/Samsum/OLMoE_Samsum_ft
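
The four evaluation commands above follow one pattern, so they can be generated in a loop. The following is a minimal sketch rather than a script shipped with the repository: `run_eval` and the `DRY_RUN` switch are hypothetical names. By default it only prints the commands, and setting `DRY_RUN=0` would execute them.

```shell
# Model directory names taken from the --output_dir values of the training commands.
models="OLMoE_Samsum_safemoe OLMoE_Samsum_ft"

# DRY_RUN is a hypothetical switch for this sketch (1 = print only, 0 = execute).
DRY_RUN="${DRY_RUN:-1}"

run_eval() {
    benchmark="$1"   # JBB (harmfulness score) or Samsum (fine-tuning accuracy)
    model_name="$2"  # directory name under models/
    script="./scripts/eval_${benchmark}.sh"
    if [ "$DRY_RUN" = "1" ]; then
        echo "$script --model_name_or_path models/$model_name --output_dir results/$benchmark/$model_name"
    else
        "$script" --model_name_or_path "models/$model_name" --output_dir "results/$benchmark/$model_name"
    fi
}

# Evaluate every model on both benchmarks.
for m in $models; do
    for b in JBB Samsum; do
        run_eval "$b" "$m"
    done
done
```

Keeping the SafeMoE and vanilla runs in one loop makes it easier to compare the two models under identical evaluation settings.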