74 changes: 33 additions & 41 deletions src/together/resources/finetune.py
@@ -82,64 +82,68 @@ def create_finetune_request(
hf_api_token: str | None = None,
hf_output_repo_name: str | None = None,
) -> FinetuneRequest:

# Error-validation block, grouped to minimize conditional evaluation and duplicate attribute lookups
if model is not None and from_checkpoint is not None:
raise ValueError(
"You must specify either a model or a checkpoint to start a job from, not both"
)

if model is None and from_checkpoint is None:
raise ValueError("You must specify either a model or a checkpoint")

if from_checkpoint is not None and from_hf_model is not None:
raise ValueError(
"You must specify either a Hugging Face Hub model or a previous checkpoint from "
"Together to start a job from, not both"
)

if from_hf_model is not None and model is None:
raise ValueError(
"You must specify the base model to fine-tune a model from the Hugging Face Hub"
)

model_or_checkpoint = model or from_checkpoint
# Batch-size limits for LoRA/full training are computed once in the block below,
# saving repeated attribute lookups

if warmup_ratio is None:
warmup_ratio = 0.0
# Set defaults early
warmup_ratio = 0.0 if warmup_ratio is None else warmup_ratio

training_type: TrainingType = FullTrainingType()
# Set training_type and the batch-size limits.
# The walrus binding below references the non-None config only once for the batch-size values.
if lora:
if model_limits.lora_training is None:
if (lora_cfg := model_limits.lora_training) is None:
raise ValueError(
f"LoRA adapters are not supported for the selected model ({model_or_checkpoint})."
)
if lora_dropout is not None and not 0 <= lora_dropout < 1.0:
raise ValueError("LoRA dropout must be in [0, 1) range.")

if lora_dropout is not None:
if not 0 <= lora_dropout < 1.0:
raise ValueError("LoRA dropout must be in [0, 1) range.")

lora_r = lora_r if lora_r is not None else model_limits.lora_training.max_rank
lora_r = lora_r if lora_r is not None else lora_cfg.max_rank
lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
training_type = LoRATrainingType(
training_type: TrainingType = LoRATrainingType(
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
lora_trainable_modules=lora_trainable_modules,
)

max_batch_size = model_limits.lora_training.max_batch_size
min_batch_size = model_limits.lora_training.min_batch_size
max_batch_size_dpo = model_limits.lora_training.max_batch_size_dpo
max_batch_size = lora_cfg.max_batch_size
min_batch_size = lora_cfg.min_batch_size
max_batch_size_dpo = lora_cfg.max_batch_size_dpo
else:
if model_limits.full_training is None:
if (full_cfg := model_limits.full_training) is None:
raise ValueError(
f"Full training is not supported for the selected model ({model_or_checkpoint})."
)
training_type: TrainingType = FullTrainingType()
max_batch_size = full_cfg.max_batch_size
min_batch_size = full_cfg.min_batch_size
max_batch_size_dpo = full_cfg.max_batch_size_dpo

max_batch_size = model_limits.full_training.max_batch_size
min_batch_size = model_limits.full_training.min_batch_size
max_batch_size_dpo = model_limits.full_training.max_batch_size_dpo

# All remaining validations in one pass; shared limits are read from the precomputed variables
if batch_size != "max":
if batch_size < min_batch_size:
raise ValueError(
f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}."
)
if training_method == "sft":
if batch_size > max_batch_size:
raise ValueError(
@@ -151,41 +155,29 @@ def create_finetune_request(
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size_dpo}."
)

if batch_size < min_batch_size:
raise ValueError(
f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}."
)

if warmup_ratio > 1 or warmup_ratio < 0:
if not (0 <= warmup_ratio <= 1):
raise ValueError(f"Warmup ratio should be between 0 and 1 (got {warmup_ratio})")

if min_lr_ratio is not None and (min_lr_ratio > 1 or min_lr_ratio < 0):
if min_lr_ratio is not None and not (0 <= min_lr_ratio <= 1):
raise ValueError(
f"Min learning rate ratio should be between 0 and 1 (got {min_lr_ratio})"
)

if max_grad_norm < 0:
raise ValueError(
f"Max gradient norm should be non-negative (got {max_grad_norm})"
)

if weight_decay is not None and (weight_decay < 0):
raise ValueError(f"Weight decay should be non-negative (got {weight_decay})")

if training_method not in AVAILABLE_TRAINING_METHODS:
raise ValueError(
f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}"
)

if train_on_inputs is not None and training_method != "sft":
raise ValueError("train_on_inputs is only supported for SFT training")

if train_on_inputs is None and training_method == "sft":
log_warn_once(
"train_on_inputs is not set for SFT training, it will be set to 'auto'"
)
train_on_inputs = "auto"

if dpo_beta is not None and training_method != "dpo":
raise ValueError("dpo_beta is only supported for DPO training")
if dpo_normalize_logratios_by_length and training_method != "dpo":
@@ -195,22 +187,21 @@
if rpo_alpha is not None:
if training_method != "dpo":
raise ValueError("rpo_alpha is only supported for DPO training")
if not rpo_alpha >= 0.0:
if rpo_alpha < 0.0:
raise ValueError(f"rpo_alpha should be non-negative (got {rpo_alpha})")

if simpo_gamma is not None:
if training_method != "dpo":
raise ValueError("simpo_gamma is only supported for DPO training")
if not simpo_gamma >= 0.0:
if simpo_gamma < 0.0:
raise ValueError(f"simpo_gamma should be non-negative (got {simpo_gamma})")

# Scheduler branch: each scheduler type is constructed once, with its args grouped together
lr_scheduler: FinetuneLRScheduler
if lr_scheduler_type == "cosine":
if scheduler_num_cycles <= 0.0:
raise ValueError(
f"Number of cycles should be greater than 0 (got {scheduler_num_cycles})"
)

lr_scheduler = CosineLRScheduler(
lr_scheduler_args=CosineLRSchedulerArgs(
min_lr_ratio=min_lr_ratio, num_cycles=scheduler_num_cycles
@@ -221,6 +212,7 @@ def create_finetune_request(
lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
)

# Training-method switch: the method object is constructed in a single grouped assignment
training_method_cls: TrainingMethodSFT | TrainingMethodDPO
if training_method == "sft":
training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
@@ -244,6 +236,7 @@ def create_finetune_request(
simpo_gamma=simpo_gamma,
)

# Final assembly: a straight mapping of the validated values onto FinetuneRequest, with no further logic
finetune_request = FinetuneRequest(
model=model,
training_file=training_file,
@@ -270,7 +263,6 @@ def create_finetune_request(
hf_api_token=hf_api_token,
hf_output_repo_name=hf_output_repo_name,
)

return finetune_request


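The heart of the finetune.py change is the narrowing pattern: the Optional per-mode limits object is bound to a local name with a walrus expression, and the batch-size limits are then read from that local instead of repeating model_limits.lora_training.* / model_limits.full_training.* lookups. Below is a minimal, self-contained sketch of the same pattern; the TrainingLimits and ModelLimits dataclasses are hypothetical stand-ins for the SDK's real types, not its actual API.

from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingLimits:
    # Hypothetical stand-in for the SDK's per-mode limits (field names assumed from the diff).
    max_batch_size: int
    min_batch_size: int
    max_batch_size_dpo: int

@dataclass
class ModelLimits:
    # Hypothetical stand-in; only the two Optional fields the diff touches.
    lora_training: Optional[TrainingLimits] = None
    full_training: Optional[TrainingLimits] = None

def resolve_batch_limits(model_limits: ModelLimits, lora: bool) -> TrainingLimits:
    # Bind the non-None config once; every later read goes through the local name.
    if lora:
        if (cfg := model_limits.lora_training) is None:
            raise ValueError("LoRA adapters are not supported for the selected model.")
    else:
        if (cfg := model_limits.full_training) is None:
            raise ValueError("Full training is not supported for the selected model.")
    return cfg

limits = resolve_batch_limits(ModelLimits(full_training=TrainingLimits(96, 8, 48)), lora=False)
assert (limits.min_batch_size, limits.max_batch_size, limits.max_batch_size_dpo) == (8, 96, 48)

The walrus form keeps the None check and the binding in one expression, which is what lets the later max_batch_size, min_batch_size, and max_batch_size_dpo reads drop the repeated attribute chain.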
14 changes: 13 additions & 1 deletion src/together/utils/_log.py
@@ -64,8 +64,20 @@ def log_warn(message: str | Any, **params: Any) -> None:


def log_warn_once(message: str | Any, **params: Any) -> None:
# Fast path: only run logfmt when the warning is actually new.
# Membership is checked against a cheap string key built from the raw message and sorted params;
# this trades logfmt fidelity for skipping the formatting cost on repeated warnings.
msg_candidate = f"{message}|{sorted(params.items())}" if params else str(message)
if msg_candidate in WARNING_MESSAGES_ONCE:
return
msg = logfmt(dict(message=message, **params))
if msg not in WARNING_MESSAGES_ONCE:
print(msg, file=sys.stderr)
logger.warn(msg)
WARNING_MESSAGES_ONCE.add(msg)
WARNING_MESSAGES_ONCE.add(msg_candidate)
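The _log.py change is a membership fast path: build a cheap string key from the raw message and sorted params, and only run logfmt when that key has not been seen before. The sketch below is a hedged, standalone rendering of that idea; the logfmt function here is a trivial stand-in, since the SDK's real formatter is defined elsewhere in together.utils.

import logging
import sys
from typing import Any

logger = logging.getLogger("together")
WARNING_MESSAGES_ONCE = set()

def logfmt(props: dict) -> str:
    # Trivial stand-in for the SDK's real logfmt formatter (assumption).
    return " ".join(f"{key}={value!r}" for key, value in props.items())

def log_warn_once(message: str, **params: Any) -> None:
    # Cheap membership key: raw message plus sorted params, built without formatting.
    key = f"{message}|{sorted(params.items())}" if params else str(message)
    if key in WARNING_MESSAGES_ONCE:
        return
    msg = logfmt(dict(message=message, **params))
    print(msg, file=sys.stderr)
    logger.warning(msg)
    WARNING_MESSAGES_ONCE.add(key)

log_warn_once("train_on_inputs is not set for SFT training, it will be set to 'auto'")
log_warn_once("train_on_inputs is not set for SFT training, it will be set to 'auto'")  # skipped: key already seen

One trade-off, which the diff's own comments acknowledge: because the key is no longer the formatted logfmt string, two call sites that would format to identical output can be tracked as distinct entries, so the de-duplication is keyed on call arguments rather than on output text.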