.pylintrc: 3 changes (2 additions, 1 deletion)
@@ -474,7 +474,8 @@ disable=raw-checker-failed,
broad-exception-caught,
super-init-not-called,
duplicate-code,
too-many-positional-arguments
too-many-positional-arguments,
too-many-lines

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
src/instructlab/training/main_ds.py: 51 changes (11 additions, 40 deletions)
@@ -45,8 +45,8 @@
check_flash_attn_enabled,
check_valid_train_args,
convert_loss_to_reduce_sum,
create_lora_config,
ensure_loadable_dolomite_checkpoint,
get_projection_layer_names,
load_latest_full_state,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
@@ -113,13 +113,16 @@ def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
args.model_name_or_path, args.output_dir
) as path:
base_model_args["pretrained_model_name_or_path"] = path
base_model_args["use_padding_free_transformer"] = True
model = GPTDolomiteForCausalLM.from_pretrained(
**base_model_args,
use_padding_free_transformer=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(**base_model_args)

# store the base model args so we can recall them later if saving a LoRA model
args.base_model_args = base_model_args

if len(tokenizer) > model.config.vocab_size:
print(
f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
@@ -175,46 +178,14 @@ def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
# - with the exception of granite, which handles it
# in the later stanza
if args.lora_r > 0:
# if lora
# Third Party
from peft import LoraConfig

# ensure we select only the modules that exist in the model
proj_layers = get_projection_layer_names(model)
if not args.lora_target_modules:
print(
f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
)
if not proj_layers:
raise RuntimeError("could not find any projection layers in the model")
args.__dict__["lora_target_modules"] = proj_layers
else:
# when the user specifies the module, we should verify that they align with what's in the model
lora_target_modules_set = set(args.lora_target_modules)
diff = lora_target_modules_set - set(proj_layers)
layers_to_target = lora_target_modules_set - diff
if len(diff) == len(args.lora_target_modules):
raise ValueError(
f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
)
if diff:
print(
f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
)
args.__dict__["lora_target_modules"] = list(layers_to_target)

peft_config = LoraConfig(
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
r=args.lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=args.lora_target_modules,
)
lora_config = create_lora_config(model, args)
model = prepare_peft_model(
model, peft_config, gradient_checkpointing=not args.use_dolomite
model,
lora_config,
args.distributed_training_framework,
gradient_checkpointing=not args.use_dolomite,
)

args.lora_config = lora_config
elif not args.use_dolomite:
model.gradient_checkpointing_enable()

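Note: the inline LoRA wiring deleted above is consolidated behind the create_lora_config helper that is now imported at the top of main_ds.py. The sketch below is a rough reconstruction of that helper from the deleted block only; the stand-in get_projection_layer_names and the exact messages are assumptions, and the real implementation in the package may differ.

from peft import LoraConfig
import torch


def get_projection_layer_names(model) -> list[str]:
    # Hypothetical stand-in for the helper previously imported by main_ds.py:
    # gather the leaf names of Linear projection modules (e.g. "q_proj", "v_proj").
    return sorted(
        {
            name.rsplit(".", 1)[-1]
            for name, module in model.named_modules()
            if isinstance(module, torch.nn.Linear) and name.endswith("_proj")
        }
    )


def create_lora_config(model, args) -> LoraConfig:
    # Rough reconstruction based only on the inline block deleted above.
    proj_layers = get_projection_layer_names(model)
    if not args.lora_target_modules:
        if not proj_layers:
            raise RuntimeError("could not find any projection layers in the model")
        target_modules = proj_layers
    else:
        requested = set(args.lora_target_modules)
        missing = requested - set(proj_layers)
        target_modules = sorted(requested - missing)
        if not target_modules:
            raise ValueError(
                f"None of the requested modules exist in the model. "
                f"Requested: {args.lora_target_modules}; available: {proj_layers}"
            )
        if missing:
            print(
                f"WARNING: skipping LoRA target modules not present in the model: {sorted(missing)}"
            )
    return LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )

Because args.lora_config and args.base_model_args are now stashed on args, the checkpoint-saving path can later rebuild the LoRA adapter and its base model without re-deriving either, per the comments added in this file.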
src/instructlab/training/setup_accelerator.py: 38 changes (27 additions, 11 deletions)
@@ -3,12 +3,10 @@

# Third Party
from accelerate import Accelerator
from torch.distributed.fsdp import ( # FullyShardedDataParallel as FSDP,
BackwardPrefetch,
MixedPrecision,
ShardingStrategy,
)
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import PreTrainedModel
import torch

# First Party
@@ -51,34 +49,52 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption
return ds_plugin


def get_fsdp_config(args, model):
def get_fsdp_config(args, model: PreTrainedModel):
# Third Party
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

is_lora = args.lora_r > 0
block_name = model._no_split_modules[0]

fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=partial(
wrap_policy = None
if is_lora > 0:
wrap_policy = fsdp_auto_wrap_policy(model)
else:
wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
get_module_class_from_name(model, block_name),
},
),
)

# TODO(osilkin): BACKWARD_POST trades memory utilization for processing time, which is important for systems utilizing LoRA
# We should have this be configurable in the future.
prefetch_policy = (
BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrap_policy,
limit_all_gathers=True,
mixed_precision_policy=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
backward_prefetch=prefetch_policy,
sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy],
cpu_offload=CPUOffload(args.cpu_offload_params_fsdp),
)

# `use_orig_params` must be disabled when using LoRA and FSDP together
# Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
if args.lora_r > 0:
fsdp_plugin.use_orig_params = False

return fsdp_plugin


def setup_accelerator(args, model, grad_accum):
def setup_accelerator(args, model: PreTrainedModel, grad_accum):
if args.distributed_training_framework == "deepspeed":
# Third Party
from deepspeed import DeepSpeedEngine
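Note: the reworked get_fsdp_config now selects peft's fsdp_auto_wrap_policy and BACKWARD_POST prefetching when LoRA is active, and turns off use_orig_params so LoRA and FSDP compose correctly. The standalone sketch below mirrors those decision points; the build_plugin wrapper, its arguments, and the import of get_module_class_from_name from transformers.trainer_pt_utils are illustrative assumptions rather than the repository's exact code.

from functools import partial

from accelerate.utils import FullyShardedDataParallelPlugin
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.trainer_pt_utils import get_module_class_from_name
import torch


def build_plugin(model, lora_enabled: bool, sharding: str = "FULL_SHARD"):
    # Illustrative wrapper mirroring the decision points in get_fsdp_config above.
    if lora_enabled:
        # peft derives a policy that wraps the trainable LoRA adapters separately
        # from the frozen base-model layers.
        wrap_policy = fsdp_auto_wrap_policy(model)
    else:
        block_cls = get_module_class_from_name(model, model._no_split_modules[0])
        wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={block_cls})

    plugin = FullyShardedDataParallelPlugin(
        auto_wrap_policy=wrap_policy,
        limit_all_gathers=True,
        mixed_precision_policy=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        # BACKWARD_POST frees gathered shards sooner at some throughput cost,
        # which matters most for memory-constrained LoRA runs.
        backward_prefetch=(
            BackwardPrefetch.BACKWARD_POST if lora_enabled else BackwardPrefetch.BACKWARD_PRE
        ),
        sharding_strategy=ShardingStrategy[sharding],
    )

    # use_orig_params must be disabled when combining LoRA with FSDP.
    # https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
    if lora_enabled:
        plugin.use_orig_params = False
    return plugin

A plugin built this way would then be handed to accelerate, e.g. Accelerator(fsdp_plugin=build_plugin(model, lora_enabled=args.lora_r > 0)).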