diff --git a/src/shared/utils/parsing.py b/src/shared/utils/parsing.py
index dcab43d..be63f90 100644
--- a/src/shared/utils/parsing.py
+++ b/src/shared/utils/parsing.py
@@ -20,6 +20,40 @@ def to_int(value: Any, default: int | None = None) -> int | None:
         return default
 
 
+@overload
+def safe_int(
+    value: Any,
+    default: int,
+    minimum: int | None = None,
+    maximum: int | None = None,
+) -> int: ...
+
+
+@overload
+def safe_int(
+    value: Any,
+    default: int | None = None,
+    minimum: int | None = None,
+    maximum: int | None = None,
+) -> int | None: ...
+
+
+def safe_int(
+    value: Any,
+    default: int | None = None,
+    minimum: int | None = None,
+    maximum: int | None = None,
+) -> int | None:
+    result = to_int(value, default=default)
+    if result is None:
+        return None
+    if minimum is not None and result < minimum:
+        return minimum
+    if maximum is not None and result > maximum:
+        return maximum
+    return result
+
+
 @overload
 def to_float(value: Any, default: float) -> float: ...
 
@@ -37,6 +71,40 @@ def to_float(value: Any, default: float | None = None) -> float | None:
         return default
 
 
+@overload
+def safe_float(
+    value: Any,
+    default: float,
+    minimum: float | None = None,
+    maximum: float | None = None,
+) -> float: ...
+
+
+@overload
+def safe_float(
+    value: Any,
+    default: float | None = None,
+    minimum: float | None = None,
+    maximum: float | None = None,
+) -> float | None: ...
+
+
+def safe_float(
+    value: Any,
+    default: float | None = None,
+    minimum: float | None = None,
+    maximum: float | None = None,
+) -> float | None:
+    result = to_float(value, default=default)
+    if result is None:
+        return None
+    if minimum is not None and result < minimum:
+        return minimum
+    if maximum is not None and result > maximum:
+        return maximum
+    return result
+
+
 @overload
 def to_bool(value: Any, default: bool) -> bool: ...
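The helpers above extend the existing `to_int`/`to_float` coercers with optional clamping. A minimal usage sketch follows, assuming `to_int`/`to_float` fall back to the supplied default for missing or unparseable input (as the surrounding code indicates); the literal values are illustrative only.

from shared.utils.parsing import safe_float, safe_int

safe_int("8", default=1, minimum=1, maximum=4)   # -> 4, clamped to maximum
safe_int(None, default=1, minimum=1)             # -> 1, falls back to default
safe_int("not-a-number")                         # -> None, no default supplied
safe_float("0.25", minimum=0.5)                  # -> 0.5, clamped to minimum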
diff --git a/src/worker/executors/ppo_executor.py b/src/worker/executors/ppo_executor.py
index 5cd703f..f2cf222 100644
--- a/src/worker/executors/ppo_executor.py
+++ b/src/worker/executors/ppo_executor.py
@@ -24,12 +24,14 @@
     GenerationConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
+    Trainer,
 )
 from trl.models.modeling_value_head import AutoModelForCausalLMWithValueHead
 from trl.trainer.ppo_config import PPOConfig
 from trl.trainer.ppo_trainer import PPOTrainer
 
 from shared.tasks.specs import PPOSpecStrict
+from shared.utils.parsing import safe_float, safe_int
 from worker.config import WorkerConfig
 from worker.lifecycle import Lifecycle
 
@@ -577,54 +579,29 @@ def wrapped_forward(*args, **kwargs):
         dataset_size = len(dataset)
         logger.info("Dataset loaded with %d samples", dataset_size)
 
-        def _safe_int(
-            value: Any,
-            *,
-            default: int | None = None,
-            minimum: int | None = None,
-        ) -> int | None:
-            if value is None:
-                return default
-            try:
-                parsed = int(value)
-            except (TypeError, ValueError):
-                return default
-            if minimum is not None:
-                parsed = max(parsed, minimum)
-            return parsed
-
-        per_device_batch = _safe_int(
-            training_config.get("per_device_train_batch_size"),
-            default=_safe_int(
-                training_config.get("batch_size"), default=1, minimum=1
-            ),
-            minimum=1,
+        per_device_batch = safe_int(
+            training_config.get("per_device_train_batch_size"), minimum=1
         )
-        if dataset_size and per_device_batch and per_device_batch > dataset_size:
-            logger.info(
-                "Clipping per_device_train_batch_size from %d to dataset size %d "
-                "to avoid empty PPO batches",
-                per_device_batch,
-                dataset_size,
+        if per_device_batch is None:
+            per_device_batch = safe_int(
+                training_config.get("batch_size"), default=1, minimum=1
             )
-            per_device_batch = dataset_size
-
-        grad_acc_steps = _safe_int(
+        grad_acc_steps: int | None = safe_int(
             training_config.get("gradient_accumulation_steps"), default=1, minimum=1
         )
-        num_mini_batches = _safe_int(
+        num_mini_batches: int | None = safe_int(
             training_config.get("num_mini_batches"), default=1, minimum=1
         )
-
-        def _safe_float(
-            value: Any, *, default: float | None = None
-        ) -> float | None:
-            if value is None:
-                return default
-            try:
-                return float(value)
-            except (TypeError, ValueError):
-                return default
+        per_device_batch, grad_acc_steps = self._normalize_ppo_batch_settings(
+            dataset_size,
+            per_device_batch,
+            grad_acc_steps,
+        )
+        num_mini_batches = self._normalize_ppo_num_mini_batches(
+            per_device_batch,
+            grad_acc_steps,
+            num_mini_batches,
+        )
 
         # Some TRL versions expect tokenized inputs in the dataset and will
         # route through a padding collator. Ensure input_ids/attention_mask exist.
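The `_normalize_ppo_batch_settings` and `_normalize_ppo_num_mini_batches` helpers called above are defined later in this diff. The standalone sketch below restates their clipping rule outside the class so the arithmetic is easy to follow; it omits the too-small-dataset error path, and the function name and example numbers are illustrative only.

def clip_ppo_batch(dataset_size, world_size, per_device, grad_acc, mini_batches):
    max_local = dataset_size // world_size          # TRL PPO drops incomplete batches
    if per_device * grad_acc > max_local:
        grad_acc = max(1, max_local // per_device)  # shrink accumulation first
        if per_device * grad_acc > max_local:
            per_device, grad_acc = max_local, 1     # then shrink the batch itself
    local_batch = per_device * grad_acc
    mini = min(mini_batches, local_batch)
    while mini > 1 and local_batch % mini:          # mini-batches must divide the local batch
        mini -= 1
    return per_device, grad_acc, mini

clip_ppo_batch(10, 2, 4, 2, 4)  # -> (4, 1, 4): a local batch of 4 fits the 5 samples per rank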
@@ -682,106 +659,16 @@ def _simple_collate(features):
         reward_module.eval()
 
         logger.info("Creating PPOConfig...")
-        # Optional args to control saving behavior and memory
-        ppo_optional: dict[str, Any] = {}
-        if "save_safetensors" in training_config:
-            ppo_optional["save_safetensors"] = bool(
-                training_config["save_safetensors"]
-            )  # transformers arg
-        else:
-            # Default off to avoid shared-tensor safetensors error when
-            # embeddings are tied
-            ppo_optional["save_safetensors"] = False
-        ppo_optional["remove_unused_columns"] = False
-
-        ppo_config = PPOConfig(
-            learning_rate=float(training_config.get("learning_rate", 1.41e-5)),
-            batch_size=int(
-                training_config.get("batch_size", per_device_batch or 1)
-            ),
-            mini_batch_size=int(
-                training_config.get("mini_batch_size", num_mini_batches or 1)
-            ),
-            output_dir=str(checkpoint_dir),
-            seed=int(training_config.get("seed", 42)),
-            **ppo_optional,
-        )
-
-        if per_device_batch is not None:
-            ppo_config.per_device_train_batch_size = per_device_batch
-        if grad_acc_steps is not None:
-            ppo_config.gradient_accumulation_steps = grad_acc_steps
-        if num_mini_batches is not None:
-            ppo_config.num_mini_batches = num_mini_batches
-
-        ppo_epochs = _safe_int(training_config.get("ppo_epochs"), minimum=1)
-        if ppo_epochs is not None:
-            ppo_config.num_ppo_epochs = ppo_epochs
-
-        train_epochs = _safe_float(
-            training_config.get("num_train_epochs"), default=None
-        )
-        if train_epochs is not None:
-            ppo_config.num_train_epochs = max(train_epochs, 1.0)
-        else:
-            # Default to 1 pass over the data unless overridden
-            ppo_config.num_train_epochs = 1.0
-
-        kl_coef = _safe_float(training_config.get("kl_coef"), default=None)
-        if kl_coef is not None and kl_coef > 0:
-            ppo_config.kl_coef = kl_coef
 
         response_cfg = spec.generation or {}
-        response_length = _safe_int(
-            response_cfg.get("max_new_tokens"), default=None, minimum=1
-        )
-        if response_length is not None:
-            ppo_config.response_length = response_length
-        else:
-            ppo_config.response_length = int(
-                training_config.get("max_seq_length", 64)
-            )
-
-        temperature = _safe_float(response_cfg.get("temperature"), default=None)
-        if temperature is None:
-            temperature = _safe_float(
-                training_config.get("temperature"), default=None
-            )
-        if temperature is not None and temperature > 0:
-            ppo_config.temperature = temperature
-
-        stop_token = response_cfg.get("stop")
-        if isinstance(stop_token, str):
-            ppo_config.stop_token = stop_token  # type: ignore[assignment]
-
-        logger.info(
-            "Final PPO batch parameters: per_device=%s, grad_acc=%s, "
-            "num_mini_batches=%s",
-            per_device_batch,
-            grad_acc_steps,
-            num_mini_batches,
-        )
-
-        steps_requested = _safe_int(
-            training_config.get("steps"), default=None, minimum=1
+        ppo_config = self._build_ppo_config(
+            training_config,
+            response_cfg,
+            checkpoint_dir,
+            per_device_batch=per_device_batch,
+            grad_acc_steps=grad_acc_steps,
+            num_mini_batches=num_mini_batches,
+            dataset_size=dataset_size,
         )
-        if steps_requested is not None and per_device_batch:
-            total_episodes = steps_requested * max(1, per_device_batch)
-            ppo_config.total_episodes = total_episodes
-            logger.info(
-                "Configuring PPO to run %d update steps (~%d episodes)",
-                steps_requested,
-                total_episodes,
-            )
-        else:
-            logger.info(
-                "Using num_train_epochs=%.2f over %d samples "
-                "(per_device_batch=%s, grad_acc=%s)",
-                float(getattr(ppo_config, "num_train_epochs", 1.0)),
-                dataset_size,
-                per_device_batch,
-                grad_acc_steps,
-            )
 
         logger.info("PPOConfig created successfully")
 
         # Initialize PPO trainer with correct API
@@ -807,7 +694,7 @@ def build_trainer() -> PPOTrainer:
                 "train_dataset": dataset,
                 "eval_dataset": dataset,
                 "dataset": dataset,
-                "output_dir": str(checkpoint_dir),
+                "output_dir": checkpoint_dir.as_posix(),
                 "data_collator": _simple_collate,
                 "collate_fn": _simple_collate,
             }
@@ -872,6 +759,7 @@ def build_trainer() -> PPOTrainer:
 
         ppo_trainer = build_trainer()
         self._ppo_trainer = ppo_trainer
+        self._install_trainer_save_overrides(ppo_trainer)
         # Ensure eval dataset/dataloader exist for TRL 0.23 `generate_completions`.
         try:
             if getattr(ppo_trainer, "eval_dataset", None) is None:
@@ -914,7 +802,7 @@ def build_trainer() -> PPOTrainer:
         logger.info("Saving trained model...")
         model_save_path = checkpoint_dir / "final_model"
         # Prefer save_model to avoid safetensors shared-tensor errors
-        ppo_trainer.save_model(str(model_save_path))
+        ppo_trainer.save_model(model_save_path.as_posix())
         logger.info("Model saved to: %s", model_save_path)
         final_model_path = model_save_path
         destination = get_http_destination(task.spec)
@@ -1018,7 +906,7 @@ def _ensure_jsonl_local(self, jsonl_cfg: dict[str, Any]) -> Path:
                     timeout=timeout,
                     logger=logger,
                 )
-                jsonl_cfg["path"] = str(resolved)
+                jsonl_cfg["path"] = resolved.as_posix()
                 return resolved
             except ExecutionError as exc:
                 last_error = exc
@@ -1137,6 +1025,236 @@ def _prepare_reward_model(
         self._ensure_value_head_score(value_model, ref_model)
         return reward_adapter, False, nullcontext()
 
+    def _build_ppo_config(
+        self,
+        training_config: dict[str, Any],
+        response_cfg: dict[str, Any],
+        checkpoint_dir: Path,
+        per_device_batch: int | None,
+        grad_acc_steps: int | None,
+        num_mini_batches: int | None,
+        dataset_size: int,
+    ) -> PPOConfig:
+        learning_rate = safe_float(
+            training_config.get("learning_rate"), default=1.41e-5, minimum=0
+        )
+        batch_size = safe_int(
+            training_config.get("batch_size"), default=per_device_batch or 1, minimum=1
+        )
+        mini_batch_size = safe_int(
+            training_config.get("mini_batch_size"),
+            default=num_mini_batches or 1,
+            minimum=1,
+        )
+        seed = safe_int(training_config.get("seed"), default=42, minimum=0)
+        ppo_epochs = safe_int(training_config.get("ppo_epochs"), minimum=1)
+        num_train_epochs = safe_float(
+            training_config.get("num_train_epochs"), default=1.0, minimum=1.0
+        )
+        kl_coef = safe_float(training_config.get("kl_coef"), minimum=0)
+
+        max_seq_length = safe_int(
+            training_config.get("max_seq_length"), default=64, minimum=1
+        )
+        response_length = safe_int(
+            response_cfg.get("max_new_tokens"),
+            default=max_seq_length,
+            minimum=1,
+        )
+        if response_length is None:
+            response_length = max_seq_length
+
+        temperature = safe_float(response_cfg.get("temperature"), minimum=0)
+        if temperature is None:
+            temperature = safe_float(training_config.get("temperature"), minimum=0)
+        if temperature == 0:
+            logger.warning(
+                "Non-positive temperature is capped to 0, which may cause issues "
+                "during PPO training; consider using a small positive value instead."
+            )
+
+        save_strategy = str(training_config.get("save_strategy", "steps")).lower()
+
+        save_steps = training_config.get("save_steps")
+        if save_steps is None:
+            save_steps = training_config.get("save_freq")
+        save_steps = safe_int(save_steps, minimum=1)
+        save_total_limit = training_config.get("save_total_limit")
+        save_total_limit = safe_int(save_total_limit, minimum=1)
+        save_only_model = training_config.get("save_only_model")
+
+        steps_requested = safe_int(training_config.get("steps"), minimum=1)
+        total_episodes = None
+        if steps_requested is not None and per_device_batch:
+            total_episodes = steps_requested * max(1, per_device_batch)
+
+        ppo_ctor_kwargs: dict[str, Any] = {}
+        if per_device_batch is not None:
+            ppo_ctor_kwargs["per_device_train_batch_size"] = per_device_batch
+        if grad_acc_steps is not None:
+            ppo_ctor_kwargs["gradient_accumulation_steps"] = grad_acc_steps
+        if num_mini_batches is not None:
+            ppo_ctor_kwargs["num_mini_batches"] = num_mini_batches
+        if total_episodes is not None:
+            ppo_ctor_kwargs["total_episodes"] = total_episodes
+        if ppo_epochs is not None:
+            ppo_ctor_kwargs["num_ppo_epochs"] = ppo_epochs
+        if kl_coef is not None:
+            ppo_ctor_kwargs["kl_coef"] = kl_coef
+        if temperature is not None:
+            ppo_ctor_kwargs["temperature"] = temperature
+        if save_steps is not None:
+            ppo_ctor_kwargs["save_steps"] = save_steps
+        if save_total_limit is not None:
+            ppo_ctor_kwargs["save_total_limit"] = save_total_limit
+        if save_only_model is not None:
+            ppo_ctor_kwargs["save_only_model"] = bool(save_only_model)
+
+        ppo_config = PPOConfig(
+            learning_rate=learning_rate,
+            batch_size=batch_size,
+            mini_batch_size=mini_batch_size,
+            output_dir=checkpoint_dir.as_posix(),
+            seed=seed,
+            num_train_epochs=num_train_epochs,
+            response_length=response_length,
+            save_strategy=save_strategy,
+            remove_unused_columns=False,
+            save_safetensors=bool(training_config.get("save_safetensors", False)),
+            **ppo_ctor_kwargs,
+        )
+
+        stop_token = response_cfg.get("stop")
+        if isinstance(stop_token, str):
+            ppo_config.stop_token = stop_token  # type: ignore[assignment]
+
+        logger.info(
+            "Final PPO batch parameters: per_device=%s, grad_acc=%s, "
+            "num_mini_batches=%s",
+            per_device_batch,
+            grad_acc_steps,
+            num_mini_batches,
+        )
+
+        if total_episodes is not None:
+            logger.info(
+                "Configuring PPO to run %d update steps (~%d episodes)",
+                steps_requested,
+                total_episodes,
+            )
+        else:
+            logger.info(
+                "Using num_train_epochs=%.2f over %d samples "
+                "(per_device_batch=%s, grad_acc=%s)",
+                num_train_epochs,
+                dataset_size,
+                per_device_batch,
+                grad_acc_steps,
+            )
+
+        return ppo_config
+
+    @staticmethod
+    def _ppo_world_size() -> int:
+        world_size_raw = os.environ.get("WORLD_SIZE")
+        if world_size_raw:
+            try:
+                world_size = int(world_size_raw)
+                if world_size > 0:
+                    return world_size
+            except ValueError:
+                pass
+        try:
+            if torch.distributed.is_available() and torch.distributed.is_initialized():
+                world_size = torch.distributed.get_world_size()
+                if world_size > 0:
+                    return world_size
+        except Exception:
+            pass
+        return 1
+
+    def _normalize_ppo_batch_settings(
+        self,
+        dataset_size: int,
+        per_device_batch: int | None,
+        grad_acc_steps: int | None,
+    ) -> tuple[int | None, int | None]:
+        """Ensure PPO batch settings are compatible with dataset size and world size,
+        adjusting if necessary."""
+        if dataset_size <= 0 or per_device_batch is None or grad_acc_steps is None:
+            return per_device_batch, grad_acc_steps
+
+        world_size = self._ppo_world_size()
+        max_local_batch = dataset_size // world_size
+        if max_local_batch < 1:
+            raise ExecutionError(
+                "PPO dataset is too small for distributed training: "
+                f"{dataset_size} samples for world_size={world_size}. "
+                "TRL PPO requires at least one full local batch per rank."
+            )
+
+        local_batch_size = per_device_batch * grad_acc_steps
+        if local_batch_size <= max_local_batch:
+            return per_device_batch, grad_acc_steps
+
+        original_per_device = per_device_batch
+        original_grad_acc = grad_acc_steps
+
+        max_grad_acc_steps = max_local_batch // per_device_batch
+        if max_grad_acc_steps >= 1:
+            grad_acc_steps = max_grad_acc_steps
+        else:
+            per_device_batch = max_local_batch
+            grad_acc_steps = 1
+
+        logger.warning(
+            "Clipping PPO batch settings from per_device=%d, grad_acc=%d "
+            "(local_batch=%d) to per_device=%d, grad_acc=%d "
+            "for dataset_size=%d and world_size=%d. "
+            "TRL PPO uses drop_last=True and requires at least one full "
+            "local batch per rank.",
+            original_per_device,
+            original_grad_acc,
+            local_batch_size,
+            per_device_batch,
+            grad_acc_steps,
+            dataset_size,
+            world_size,
+        )
+        return per_device_batch, grad_acc_steps
+
+    @staticmethod
+    def _normalize_ppo_num_mini_batches(
+        per_device_batch: int | None,
+        grad_acc_steps: int | None,
+        num_mini_batches: int | None,
+    ) -> int | None:
+        """Ensure num_mini_batches is compatible with local batch size, adjusting if
+        necessary."""
+        if (
+            per_device_batch is None
+            or grad_acc_steps is None
+            or num_mini_batches is None
+            or num_mini_batches < 1
+        ):
+            return num_mini_batches
+
+        local_batch_size = per_device_batch * grad_acc_steps
+        adjusted = min(num_mini_batches, local_batch_size)
+        while adjusted > 1 and local_batch_size % adjusted != 0:
+            adjusted -= 1
+        if adjusted < 1:
+            adjusted = 1
+        if adjusted != num_mini_batches:
+            logger.warning(
+                "Adjusting PPO num_mini_batches from %d to %d so local_batch=%d "
+                "divides evenly.",
+                num_mini_batches,
+                adjusted,
+                local_batch_size,
+            )
+        return adjusted
+
     @staticmethod
     def _ensure_value_head_score(
         value_model: AutoModelForCausalLMWithValueHead,
@@ -1202,6 +1320,51 @@ def _detect_gpu_count(training_config: dict[str, Any]) -> int:
             pass
         return 0
 
+    @staticmethod
+    def _resolve_model_for_save(model: Any) -> Any:
+        """Return the policy model that should be serialized."""
+        policy_model = getattr(model, "policy", None)
+        if policy_model is not None:
+            return policy_model
+
+        module = getattr(model, "module", None)
+        if module is not None:
+            policy_model = getattr(module, "policy", None)
+            if policy_model is not None:
+                return policy_model
+
+        return model
+
+    def _install_trainer_save_overrides(self, ppo_trainer: PPOTrainer) -> None:
+        """Patch PPO trainer saves to avoid TRL's DDP-unsafe checkpoint wrapper.
+
+        TRL's PPO checkpoint path assumes ``self.model`` exposes policy/config
+        attributes directly. Under DDP, ``self.model`` is wrapped, so that path
+        can fail on rank 0 while other ranks continue into checkpoint
+        collectives, hanging the run.
+ """ + + def _wrapped_save_model( + output_dir: str | None = None, _internal_call: bool = False + ) -> None: + backup_model = ppo_trainer.model + backup_deepspeed = ppo_trainer.deepspeed + ppo_trainer.model = self._resolve_model_for_save(backup_model) + if ppo_trainer.is_deepspeed_enabled: + ppo_trainer.deepspeed = ppo_trainer.model # type: ignore[assignment] + try: + Trainer.save_model(ppo_trainer, output_dir, _internal_call) + finally: + ppo_trainer.model = backup_model + if ppo_trainer.is_deepspeed_enabled: + ppo_trainer.deepspeed = backup_deepspeed + + def _wrapped_save_checkpoint(model: Any, trial: Any) -> None: + Trainer._save_checkpoint(ppo_trainer, model, trial) + + setattr(ppo_trainer, "save_model", _wrapped_save_model) + setattr(ppo_trainer, "_save_checkpoint", _wrapped_save_checkpoint) + def cleanup_after_run(self) -> None: dropped_objects = [] for attr in ( diff --git a/src/worker/executors/utils/huggingface.py b/src/worker/executors/utils/huggingface.py index e46621d..cbe856b 100644 --- a/src/worker/executors/utils/huggingface.py +++ b/src/worker/executors/utils/huggingface.py @@ -32,6 +32,8 @@ def build_hf_load_kwargs( """Build ``(tok_kwargs, model_kwargs)`` for HuggingFace ``from_pretrained``.""" tok_kwargs: dict[str, Any] = {} model_kwargs: dict[str, Any] = {} + if padding_side := training_cfg.get("padding_side"): + tok_kwargs["padding_side"] = padding_side if trust_remote_code: tok_kwargs["trust_remote_code"] = True model_kwargs["trust_remote_code"] = True diff --git a/templates/dpo_training_llama_1b.yaml b/templates/dpo_training_llama_1b.yaml index 2e8392a..c947c8f 100644 --- a/templates/dpo_training_llama_1b.yaml +++ b/templates/dpo_training_llama_1b.yaml @@ -23,9 +23,6 @@ spec: identifier: TinyLlama/TinyLlama-1.1B-Chat-v1.0 revision: main trust_remote_code: false - config: - bf16: true - device_map_auto: true data: # Option 1: Use Hugging Face dataset @@ -62,11 +59,13 @@ spec: num_train_epochs: 1 # Number of epochs save_model: true save_freq: 500 # Save every N steps + fp16: false + bf16: true output: destination: type: local artifacts: - results.json - - trained_model/ - logs + - artifacts/checkpoints diff --git a/templates/dpo_training_llama_1b_multi_gpu.yaml b/templates/dpo_training_llama_1b_multi_gpu.yaml index dd19309..2df74fe 100644 --- a/templates/dpo_training_llama_1b_multi_gpu.yaml +++ b/templates/dpo_training_llama_1b_multi_gpu.yaml @@ -25,9 +25,6 @@ spec: identifier: TinyLlama/TinyLlama-1.1B-Chat-v1.0 revision: main trust_remote_code: false - config: - bf16: true - device_map_auto: false data: preferences: @@ -55,6 +52,8 @@ spec: num_train_epochs: 1 save_model: false save_freq: 500 + fp16: false + bf16: true output: destination: diff --git a/templates/dpo_training_mistral.yaml b/templates/dpo_training_ministral.yaml similarity index 90% rename from templates/dpo_training_mistral.yaml rename to templates/dpo_training_ministral.yaml index 08c19dc..69d36e5 100644 --- a/templates/dpo_training_mistral.yaml +++ b/templates/dpo_training_ministral.yaml @@ -1,10 +1,10 @@ apiVersion: flowmesh/v1 kind: TrainingTask metadata: - name: mistral-7b-dpo-training + name: ministral-3b-dpo-training owner: alice annotations: - description: DPO training for Mistral-7B using direct preference optimization + description: DPO training for Ministral-3B using direct preference optimization spec: taskType: dpo resources: @@ -19,12 +19,9 @@ spec: model: source: type: huggingface - identifier: mistralai/Mistral-7B-Instruct-v0.1 + identifier: "mistralai/Ministral-3-3B-Instruct-2512" 
       revision: main
       trust_remote_code: false
-    config:
-      fp16: true
-      device_map_auto: true
 
   data:
     # Option 1: Use Hugging Face dataset
@@ -61,11 +58,13 @@ spec:
     num_train_epochs: 1  # Number of epochs
     save_model: true
     save_freq: 500  # Save every N steps
+    fp16: false
+    bf16: true
 
   output:
     destination:
       type: "local"
     artifacts:
       - results.json
-      - trained_model/
-      - logs
\ No newline at end of file
+      - logs
+      - artifacts/checkpoints
diff --git a/templates/ppo_training_llama_1b.yaml b/templates/ppo_training_llama_1b.yaml
index 1bc6857..7e67544 100644
--- a/templates/ppo_training_llama_1b.yaml
+++ b/templates/ppo_training_llama_1b.yaml
@@ -26,9 +26,6 @@ spec:
       identifier: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
       revision: "main"
       trust_remote_code: false
-    config:
-      bf16: true
-      device_map_auto: true
 
   # Reward model configuration (optional)
   reward_model:
@@ -55,24 +52,28 @@ spec:
   # PPO training parameters
   training:
-    learning_rate: 1.41e-5
-    batch_size: 4
+    learning_rate: 3.0e-6
+    batch_size: 2
     mini_batch_size: 1
-    gradient_accumulation_steps: 4
+    gradient_accumulation_steps: 2
     steps: 50  # Reduced for faster demo
     ppo_epochs: 4
     target_kl: 0.1
     seed: 42
+    padding_side: "left"
     early_stopping: false
     optimize_cuda_cache: true
+    save_strategy: "steps"
     save_model: true
     save_freq: 25  # Save checkpoint every 25 steps
     log_with: null  # Options: "tensorboard", "wandb", null
     tracker_project_name: "tinyllama-ppo-training"
+    fp16: false
+    bf16: true
 
   # Text generation parameters during training
   generation:
-    max_new_tokens: 256
+    max_new_tokens: 128
     temperature: 0.7
     do_sample: true
diff --git a/templates/ppo_training_llama_1b_multi_gpu.yaml b/templates/ppo_training_llama_1b_multi_gpu.yaml
index 97ce1ab..1d6960a 100644
--- a/templates/ppo_training_llama_1b_multi_gpu.yaml
+++ b/templates/ppo_training_llama_1b_multi_gpu.yaml
@@ -25,9 +25,6 @@ spec:
       identifier: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
       revision: "main"
       trust_remote_code: false
-    config:
-      bf16: true
-      device_map_auto: false
 
   reward_model:
     identifier: "cardiffnlp/twitter-roberta-base-sentiment-latest"
@@ -46,20 +43,24 @@ spec:
   training:
     allow_multi_gpu: true
-    learning_rate: 1.41e-5
-    batch_size: 2
+    learning_rate: 3.0e-6
+    batch_size: 1
     mini_batch_size: 1
-    gradient_accumulation_steps: 4
+    gradient_accumulation_steps: 1
     steps: 16
     ppo_epochs: 2
     target_kl: 0.1
     seed: 42
+    padding_side: "left"
     early_stopping: false
     optimize_cuda_cache: true
+    save_strategy: "no"
     save_model: false
     save_freq: 25
     log_with: null
     tracker_project_name: "tinyllama-ppo-multi-gpu"
+    fp16: false
+    bf16: true
 
   generation:
     max_new_tokens: 128
@@ -68,7 +69,8 @@ spec:
 
   output:
     destination:
-      type: "http"
+      type: "local"
     artifacts:
       - "results.json"
       - "logs"
+      - "artifacts/checkpoints"
diff --git a/templates/ppo_training_mistral.yaml b/templates/ppo_training_ministral.yaml
similarity index 81%
rename from templates/ppo_training_mistral.yaml
rename to templates/ppo_training_ministral.yaml
index 40f6693..3245fe2 100644
--- a/templates/ppo_training_mistral.yaml
+++ b/templates/ppo_training_ministral.yaml
@@ -1,10 +1,10 @@
 apiVersion: flowmesh/v1
 kind: TrainingTask
 metadata:
-  name: mistral-7b-ppo-training
+  name: ministral-3b-ppo-training
   owner: alice
   annotations:
-    description: "PPO training for Mistral-7B using reinforcement learning"
+    description: "PPO training for Ministral-3B using reinforcement learning"
 spec:
   taskType: "ppo"  # Worker will execute with the PPO module
@@ -23,12 +23,9 @@ spec:
   model:
     source:
       type: "huggingface"
-      identifier: "mistralai/Mistral-7B-Instruct-v0.1"
+      identifier: "mistralai/Ministral-3-3B-Instruct-2512"
       revision: "main"
       trust_remote_code: false
-    config:
-      fp16: true
-      device_map_auto: true
 
   # Reward model configuration (optional)
   reward_model:
@@ -55,24 +52,28 @@ spec:
   # PPO training parameters
   training:
-    learning_rate: 1.41e-5
-    batch_size: 4
+    learning_rate: 3.0e-6
+    batch_size: 2
     mini_batch_size: 1
-    gradient_accumulation_steps: 4
+    gradient_accumulation_steps: 2
     steps: 50  # Reduced for faster demo
     ppo_epochs: 4
     target_kl: 0.1
     seed: 42
+    padding_side: "left"
     early_stopping: false
     optimize_cuda_cache: true
+    save_strategy: "steps"
     save_model: true
     save_freq: 25  # Save checkpoint every 25 steps
     log_with: null  # Options: "tensorboard", "wandb", null
-    tracker_project_name: "mistral-ppo-training"
+    tracker_project_name: "ministral-ppo-training"
+    fp16: false
+    bf16: true
 
   # Text generation parameters during training
   generation:
-    max_new_tokens: 256
+    max_new_tokens: 128
     temperature: 0.7
     do_sample: true
@@ -82,5 +83,5 @@ spec:
     type: "local"
     artifacts:
       - "results.json"
-      - "trained_model/"  # Only saved if save_model: true
-      - "logs"
\ No newline at end of file
+      - "logs"
+      - "artifacts/checkpoints"