(#2510) allow mask conditioning_type to work on edit models that require latent conditioning #2520
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Closes #2510
This pull request refactors how conditioning datasets are handled and routed during training, especially for edit models and loss masking. The main improvements are a clearer separation of reference and mask/segmentation conditioning types, more robust validation for Qwen edit models, and consistent use of a new
loss_mask_typefield throughout the codebase. This enhances flexibility and correctness when working with multiple conditioning datasets.Key changes include:
Conditioning Dataset Handling & Routing:
simpletuner/helpers/training/collate.pyto distinguish between reference conditioning types (reference_strict,reference_loose) and mask/segmentation types (mask,segmentation). This allows separate routing for model input (reference) and loss masking (mask/segmentation), supporting more complex dataset setups. [1] [2] [3] [4] [5]conditioning_typeandloss_mask_typeto the prepared batch output, making the distinction explicit for downstream code. [1] [2]Qwen Edit Model Validation:
Loss Masking Logic:
loss_mask_typeinstead ofconditioning_typewhen applying mask/segmentation-based loss masking, ensuring consistent behavior across models. [1] [2] [3] [4] [5] [6]