28 commits
9637ce0 - Preference optimization with MaPO and Diffusion-DPO (rockerBOO, Jan 28, 2025)
2c646cc - Fix image size batch for SDXL (rockerBOO, Jun 19, 2024)
4df97b3 - Remove orpo weight (rockerBOO, Jun 24, 2024)
c213b56 - Fix preference caption prefix/suffix (rockerBOO, Jul 5, 2024)
5bc6d01 - Fix preference in controlnet subset (rockerBOO, Jul 14, 2024)
0da9e10 - Add dataset changes for ControlNet support (rockerBOO, Jan 28, 2025)
46414bb - Fix parameters (rockerBOO, Jan 28, 2025)
7452974 - Merge branch 'sd3' into po (rockerBOO, Apr 14, 2025)
d22c827 - Update PO cached latents, move out functions, update calls (rockerBOO, Apr 27, 2025)
78a2946 - Merge branch 'sd3' into po (rockerBOO, Apr 27, 2025)
61e3083 - Typo (rockerBOO, Apr 28, 2025)
10ce29f - Fix timestep/timestep refactor (rockerBOO, Apr 28, 2025)
d23e15a - Fix remaining test (rockerBOO, Apr 28, 2025)
8e8243a - Add DDO preference optimization (rockerBOO, Apr 29, 2025)
9a2101a - Add DDO loss (rockerBOO, Apr 30, 2025)
22447eb - Use mean, use ddo_loss (rockerBOO, Apr 30, 2025)
e61dd14 - Formatting (rockerBOO, Apr 30, 2025)
d8716a9 - Rework DDO loss (rockerBOO, May 2, 2025)
e4bdffd - Update diffusion_dpo, MaPO tests. Fix diffusion_dpo/MaPO (rockerBOO, May 5, 2025)
fe49729 - Fix names (rockerBOO, May 5, 2025)
971387e - Fix DDO arguments (rockerBOO, May 5, 2025)
4f27c6a - Add BPO, CPO, DDO, SDPO, SimPO (rockerBOO, Jun 3, 2025)
429b2ab - Merge branch 'sd3' into po (rockerBOO, Jun 3, 2025)
4152339 - Spelling (rockerBOO, Jun 3, 2025)
db05136 - Fix sigmas/timesteps (rockerBOO, Jun 4, 2025)
e9e9871 - Fix ImageInfo iterator (rockerBOO, Jun 6, 2025)
bf86471 - Fix sigmas (rockerBOO, Aug 3, 2025)
fb4eb3b - Merge branch 'sd3' into po (rockerBOO, Aug 3, 2025)
28 changes: 18 additions & 10 deletions flux_train_network.py
@@ -311,27 +311,34 @@ def shift_scale_latents(self, args, latents):
 
     def get_noise_pred_and_target(
         self,
-        args,
-        accelerator,
+        args: argparse.Namespace,
+        accelerator: Accelerator,
         noise_scheduler,
-        latents,
-        batch,
+        latents: torch.FloatTensor,
+        batch: dict[str, torch.Tensor],
         text_encoder_conds,
-        unet: flux_models.Flux,
+        unet,
         network,
-        weight_dtype,
-        train_unet,
+        weight_dtype: torch.dtype,
+        train_unet: bool,
         is_train=True,
-    ):
+        timesteps: torch.FloatTensor | None = None,
+    ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
         # Sample noise that we'll add to the latents
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
 
         # get noisy model input and timesteps
-        noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
+        noisy_model_input, rand_timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
             args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
         )
 
+        if timesteps is None:
+            timesteps = rand_timesteps
+        else:
+            # Convert timesteps into sigmas
+            sigmas: torch.FloatTensor = timesteps / noise_scheduler.config.num_train_timesteps
+
         # pack latents and get img_ids
         packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
         packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
@@ -364,6 +371,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
             # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
             with torch.set_grad_enabled(is_train), accelerator.autocast():
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+                accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
                 model_pred = unet(
                     img=img,
                     img_ids=img_ids,
@@ -431,7 +439,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
                 )
                 target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
 
-        return model_pred, target, timesteps, weighting
+        return model_pred, noisy_model_input, target, sigmas, timesteps, weighting
 
     def post_process_loss(self, loss, args, timesteps, noise_scheduler):
         return loss
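The signature change above threads an optional `timesteps` argument through `get_noise_pred_and_target` and returns `noisy_model_input` and `sigmas` alongside the prediction. That is what pairwise preference losses such as Diffusion-DPO need: the preferred and non-preferred latents must be noised at the same step so their prediction errors are comparable. A minimal sketch of how a caller inside the training step might use it; the `latents_w`/`latents_l` names are hypothetical and not from this PR, and the surrounding locals (`args`, `batch`, `unet`, etc.) are assumed to exist in that context:

# Sketch only: run the "winner" (preferred) half first, letting the method
# sample timesteps randomly, then reuse those timesteps for the "loser" half.
pred_w, noisy_w, target_w, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
    args, accelerator, noise_scheduler, latents_w, batch,
    text_encoder_conds, unet, network, weight_dtype, train_unet,
    is_train=True,  # timesteps left as None, so they are drawn and returned
)
pred_l, noisy_l, target_l, _, _, _ = self.get_noise_pred_and_target(
    args, accelerator, noise_scheduler, latents_l, batch,
    text_encoder_conds, unet, network, weight_dtype, train_unet,
    is_train=True,
    timesteps=timesteps,  # same noise level for both halves of the pair
)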
10 changes: 10 additions & 0 deletions library/config_util.py
@@ -76,6 +76,11 @@ class BaseSubsetParams:
     validation_seed: int = 0
     validation_split: float = 0.0
     resize_interpolation: Optional[str] = None
+    preference: bool = False
+    preference_caption_prefix: Optional[str] = None
+    preference_caption_suffix: Optional[str] = None
+    non_preference_caption_prefix: Optional[str] = None
+    non_preference_caption_suffix: Optional[str] = None
 
 
 @dataclass
@@ -198,6 +203,11 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
         "caption_suffix": str,
         "custom_attributes": dict,
         "resize_interpolation": str,
+        "preference": bool,
+        "preference_caption_prefix": str,
+        "preference_caption_suffix": str,
+        "non_preference_caption_prefix": str,
+        "non_preference_caption_suffix": str
     }
     # DO means DropOut
     DO_SUBSET_ASCENDABLE_SCHEMA = {
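The same five keys are added both as `BaseSubsetParams` fields and to the subset schema, so a subset can be flagged as a preference-pair source and given distinct caption decorations for its preferred and non-preferred images. A minimal sketch of the resulting params object, assuming all other `BaseSubsetParams` fields keep their defaults; the prefix strings are made-up examples:

from library.config_util import BaseSubsetParams

# Sketch only: flag a subset as a preference source and decorate the two
# caption streams differently. Unset options stay None, as in the dataclass.
subset = BaseSubsetParams(
    preference=True,
    preference_caption_prefix="high quality, ",
    non_preference_caption_prefix="low quality, ",
)
assert subset.preference_caption_suffix is None  # default from the dataclass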