55 changes: 55 additions & 0 deletions F2LLM/README.md
@@ -42,6 +42,61 @@ where N_NODE is the number of machines; N_PROCESSES is N_NODE\*8; MASTER_IP is t

On worker nodes, also run the above command but modify `machine_rank` accordingly.

### Train with LoRA

For efficient fine-tuning with reduced computational cost, we support **LoRA (Low-Rank Adaptation)** via PEFT (Parameter-Efficient Fine-Tuning). LoRA keeps the base model weights frozen and trains only small low-rank update matrices, making it well suited to resource-constrained environments.

#### LoRA Configuration

Add the following parameters to `configs/config.json` to enable LoRA training:

```json
{
"use_lora": true,
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.05,
"lora_target_modules": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
}
```

#### LoRA Parameters

- `use_lora` (bool): Enable LoRA fine-tuning. Default: `false`
- `lora_r` (int): LoRA rank; lower values mean fewer trainable parameters (typically 8-32). Default: `16`
- `lora_alpha` (int): LoRA scaling factor, typically set to 2× `lora_r`. Default: `32`
- `lora_dropout` (float): Dropout probability for LoRA layers. Default: `0.05`
- `lora_target_modules` (list): Transformer modules to apply LoRA to. Defaults to the attention projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`) and the feed-forward projections (`gate_proj`, `up_proj`, `down_proj`); see the mapping sketch below.
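
Apart from the `use_lora` toggle, these options map directly onto PEFT's `LoraConfig`. A minimal sketch of the mapping, mirroring the defaults in `lora_config.py` from this PR:

```python
from peft import LoraConfig, TaskType

# Sketch: how the config.json options above translate to a PEFT LoraConfig
# (values mirror the defaults in lora_config.py).
peft_config = LoraConfig(
    r=16,                                    # lora_r
    lora_alpha=32,                           # lora_alpha
    lora_dropout=0.05,                       # lora_dropout
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],  # lora_target_modules
    task_type=TaskType.FEATURE_EXTRACTION,   # embedding model, as in lora_config.py
)
```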

#### LoRA Training Example

```bash
# LoRA training uses the same launch command; it is enabled by "use_lora" in configs/config.json
accelerate launch --config_file configs/accelerate_config.yaml run.py --config configs/config.json
```

#### LoRA Training Benefits

- **Parameter Efficiency**: Only ~1-5% of original model parameters are trainable
- **Reduced Memory**: Significantly lower GPU memory requirements
- **Faster Training**: Cheaper optimization steps, since gradients and optimizer states are kept only for the adapter weights
- **Portable Adapters**: Save only the LoRA weights (~10-100MB) instead of full model checkpoints, as sketched below
- **Composability**: Combine multiple LoRA adapters for different tasks
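
Below is a minimal sketch of the adapter-only save/load round trip behind the "Portable Adapters" point, using standard PEFT calls; the `Qwen/Qwen2.5-3B` base, the shortened target-module list, and the output path are illustrative only:

```python
import torch
from transformers import AutoModel
from peft import LoraConfig, PeftModel, TaskType, get_peft_model

# Build a PEFT-wrapped base model (stand-in for what apply_lora_to_model returns)
base = AutoModel.from_pretrained("Qwen/Qwen2.5-3B", torch_dtype=torch.bfloat16)
peft_model = get_peft_model(base, LoraConfig(r=16, lora_alpha=32,
                                             target_modules=["q_proj", "v_proj"],
                                             task_type=TaskType.FEATURE_EXTRACTION))

# Saves only adapter_config.json + adapter weights, not the base model
peft_model.save_pretrained("output/lora_adapter")

# Later, re-attach the adapter to a freshly loaded copy of the base model
base = AutoModel.from_pretrained("Qwen/Qwen2.5-3B", torch_dtype=torch.bfloat16)
restored = PeftModel.from_pretrained(base, "output/lora_adapter")
```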

#### Loading LoRA Fine-tuned Models

```python
from peft import AutoPeftModel
from transformers import AutoTokenizer

# Load the base model together with its LoRA adapters
# (AutoPeftModel matches the AutoModel base class used during training)
model = AutoPeftModel.from_pretrained("path/to/lora/checkpoint")
tokenizer = AutoTokenizer.from_pretrained("path/to/lora/checkpoint")

# Optionally merge the adapters into the base weights for inference
model = model.merge_and_unload()
```
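
Continuing the example above, if you want a standalone, adapter-free checkpoint after merging, it can be written out with the usual `save_pretrained` calls (the directory name below is illustrative):

```python
# Persist the merged weights as a regular checkpoint with no PEFT dependency
model.save_pretrained("path/to/merged_model")
tokenizer.save_pretrained("path/to/merged_model")
```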

### Citation

If you use the F2LLM models, data, or code, please cite the following technical report.
8 changes: 7 additions & 1 deletion F2LLM/arguments.py
@@ -27,8 +27,14 @@ class Args:
log_interval: int = 20
checkpointing_steps: int = 100
validation_steps: int = 100
# LoRA settings
use_lora: bool = False
lora_r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.05
lora_target_modules: list = None
# just placeholder, for logging purpose
num_processes: int=0
num_processes: int = 0

def dict(self):
return asdict(self)
2 changes: 1 addition & 1 deletion F2LLM/configs/config.json
@@ -1,5 +1,5 @@
{
"model_path": "models/qwen3-4b",
"model_path": "Qwen/Qwen2.5-3B",
"experiment_id": "4b+lr.8e-6+bs.16x32+context.1024+2epochs",
"train_data_path": "training_data/data_tokenized_qwen",
"output_dir": "output",
66 changes: 66 additions & 0 deletions F2LLM/lora_config.py
@@ -0,0 +1,66 @@
"""LoRA configuration and utilities for efficient model adaptation."""

from peft import LoraConfig, TaskType, get_peft_model
from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRAConfig:
"""Configuration class for LoRA (Low-Rank Adaptation) parameters."""

# LoRA settings
r: int = 16 # LoRA rank
lora_alpha: int = 32 # LoRA alpha (scaling factor)
target_modules: list = None # Target modules for LoRA adaptation
lora_dropout: float = 0.05 # Dropout probability for LoRA layers
bias: str = "none" # Bias configuration ("none", "all", "lora_only")

# Training strategy
modules_to_save: Optional[list] = None # Modules to save in addition to LoRA weights

def __post_init__(self):
"""Set default target modules for common LLM architectures."""
if self.target_modules is None:
# Common target modules for Transformer models
self.target_modules = [
"q_proj", # Query projection
"v_proj", # Value projection
"k_proj", # Key projection
"o_proj", # Output projection
"gate_proj", # Gate projection (for gating mechanisms)
"up_proj", # Up projection
"down_proj", # Down projection
]

def get_peft_config(self):
"""Get PEFT LoRA configuration object."""
return LoraConfig(
r=self.r,
lora_alpha=self.lora_alpha,
target_modules=self.target_modules,
lora_dropout=self.lora_dropout,
bias=self.bias,
task_type=TaskType.FEATURE_EXTRACTION, # For embedding models
modules_to_save=self.modules_to_save,
)


def apply_lora_to_model(model, lora_config: LoRAConfig):
"""
Apply LoRA to a model.

Args:
model: The base model to apply LoRA to
lora_config: LoRA configuration object

Returns:
Model with LoRA applied
"""
peft_config = lora_config.get_peft_config()
model = get_peft_model(model, peft_config)

# Print LoRA configuration and trainable parameters
model.print_trainable_parameters()

return model
31 changes: 29 additions & 2 deletions F2LLM/model.py
@@ -1,19 +1,46 @@
import torch
from transformers import AutoModel, AutoTokenizer
from lora_config import LoRAConfig, apply_lora_to_model


class F2LLM:
def __init__(self,
model_path,
max_seq_length=512,
args=None
args=None,
use_lora=False,
lora_config=None
):

self.args = args
self.dtype = torch.bfloat16
self.device = None # set after accelerator.prepare
self.lm = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=self.dtype, attn_implementation='flash_attention_2')
self.use_lora = use_lora

# Try flash_attention_2 first, fall back to eager if not available
try:
self.lm = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=self.dtype,
attn_implementation='flash_attention_2'
)
except (ImportError, ValueError):
# Flash attention not available, use default
self.lm = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=self.dtype
)

self.lm.config.use_cache = False

# Apply LoRA if enabled
if self.use_lora:
if lora_config is None:
lora_config = LoRAConfig()
self.lm = apply_lora_to_model(self.lm, lora_config)

self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.max_seq_length = max_seq_length

1 change: 1 addition & 0 deletions F2LLM/requirements.txt
@@ -5,3 +5,4 @@ flash-attn
torch
transformers
tensorboard
peft
19 changes: 18 additions & 1 deletion F2LLM/run.py
@@ -5,6 +5,7 @@
set_seed,
get_scheduler
)
from lora_config import LoRAConfig
import os, json, random
from datasets import load_dataset
from torch.utils.data import DataLoader
@@ -119,7 +120,23 @@ def __iter__(self):
override_train_step = True

accelerator.print(f"******************************** Training step before prepare: {args.train_steps} ********************************")
model = F2LLM(args.model_path, args.max_seq_length, args=args)

# Prepare LoRA configuration if enabled
lora_config = None
if args.use_lora:
lora_config = LoRAConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=args.lora_target_modules,
)
accelerator.print("LoRA enabled with configuration:")
accelerator.print(f" - Rank (r): {args.lora_r}")
accelerator.print(f" - Alpha: {args.lora_alpha}")
accelerator.print(f" - Dropout: {args.lora_dropout}")
accelerator.print(f" - Target modules: {args.lora_target_modules}")

model = F2LLM(args.model_path, args.max_seq_length, args=args, use_lora=args.use_lora, lora_config=lora_config)
model.lm.gradient_checkpointing_enable()
# set seed again to make sure that different models share the same seed
set_seed(0)