diff --git a/flux_train_network.py b/flux_train_network.py index cfc617088..bdff62007 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -311,27 +311,34 @@ def shift_scale_latents(self, args, latents): def get_noise_pred_and_target( self, - args, - accelerator, + args: argparse.Namespace, + accelerator: Accelerator, noise_scheduler, - latents, - batch, + latents: torch.FloatTensor, + batch: dict[str, torch.Tensor], text_encoder_conds, - unet: flux_models.Flux, + unet, network, - weight_dtype, - train_unet, + weight_dtype: torch.dtype, + train_unet: bool, is_train=True, - ): + timesteps: torch.FloatTensor | None = None, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + noisy_model_input, rand_timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep( args, noise_scheduler, latents, noise, accelerator.device, weight_dtype ) + if timesteps is None: + timesteps = rand_timesteps + else: + # Convert timesteps into sigmas + sigmas: torch.FloatTensor = timesteps - noise_scheduler.config.num_train_timesteps + # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 @@ -364,6 +371,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the 
inputs same for the model for testing) + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() model_pred = unet( img=img, img_ids=img_ids, @@ -431,7 +439,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t ) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, weighting + return model_pred, noisy_model_input, target, sigmas, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/library/config_util.py b/library/config_util.py index 53727f252..b296db841 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -76,6 +76,11 @@ class BaseSubsetParams: validation_seed: int = 0 validation_split: float = 0.0 resize_interpolation: Optional[str] = None + preference: bool = False + preference_caption_prefix: Optional[str] = None + preference_caption_suffix: Optional[str] = None + non_preference_caption_prefix: Optional[str] = None + non_preference_caption_suffix: Optional[str] = None @dataclass @@ -198,6 +203,11 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "caption_suffix": str, "custom_attributes": dict, "resize_interpolation": str, + "preference": bool, + "preference_caption_prefix": str, + "preference_caption_suffix": str, + "non_preference_caption_prefix": str, + "non_preference_caption_suffix": str } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index ad3e69ffb..bab63f0ca 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,10 +1,15 @@ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F +from typing import Callable, Protocol +import math import argparse import random import re from torch.types import Number -from typing import List, 
Optional, Union +from typing import List, Optional, Union, Callable from .utils import setup_logging setup_logging() @@ -65,7 +70,9 @@ def enforce_zero_terminal_snr(betas): noise_scheduler.alphas_cumprod = alphas_cumprod -def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False): +def apply_snr_weight( + loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False +): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) if v_prediction: @@ -91,7 +98,9 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): return scale -def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor): +def add_v_prediction_like_loss( + loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor +): scale = get_snr_scale(timesteps, noise_scheduler) # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss @@ -143,6 +152,75 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意", ) + parser.add_argument( + "--beta_dpo", + type=int, + help="DPO KL Divergence penalty. Recommended values for SD1.5 B=2000, SDXL B=5000 / DPO KL 発散ペナルティ。SD1.5 の推奨値 B=2000、SDXL B=5000", + ) + parser.add_argument( + "--mapo_beta", + type=float, + help="MaPO beta regularization parameter. 
Recommended values of 0.01 to 0.1 / 相対比損失の MaPO ~ 0.25 です", + ) + parser.add_argument( + "--cpo_beta", + type=float, + help="CPO beta regularization parameter. Recommended value of 0.1", + ) + parser.add_argument( + "--bpo_beta", + type=float, + help="BPO beta regularization parameter. Recommended value of 0.1", + ) + parser.add_argument( + "--bpo_lambda", + type=float, + help="BPO beta regularization parameter. Recommended value of 0.0 to 0.2. -0.5 similar to DPO gradient.", + ) + parser.add_argument( + "--sdpo_beta", + type=float, + help="SDPO beta regularization parameter. Recommended value of 0.02", + ) + parser.add_argument( + "--sdpo_epsilon", + type=float, + default=0.1, + help="SDPO epsilon for clipping importance weighting. Recommended value of 0.1", + ) + parser.add_argument( + "--simpo_gamma_beta_ratio", + type=float, + help="SimPO target reward margin term. Ensure the reward for the chosen exceeds the rejected. Recommended: 0.25-1.75", + ) + parser.add_argument( + "--simpo_beta", + type=float, + help="SDPO beta controls the scaling of the reward difference. Recommended: 2.0-2.5", + ) + parser.add_argument( + "--simpo_smoothing", + type=float, + help="SDPO smoothing of chosen/rejected. Recommended: 0.0", + ) + parser.add_argument( + "--simpo_loss_type", + type=str, + default="sigmoid", + choices=["sigmoid", "hinge"], + help="SDPO loss type. Options: sigmoid, hinge. Default: sigmoid", + ) + parser.add_argument( + "--ddo_alpha", + type=float, + help="Controls weight of the fake samples loss term (range: 0.5-50). Higher values increase penalty on reference model samples. Start with 4.0.", + ) + parser.add_argument( + "--ddo_beta", + type=float, + help="Scaling factor for likelihood ratio (range: 0.01-0.1). Higher values create stronger separation between target and reference distributions. 
def assert_po_variables(args):
    """Check that paired preference-optimization options are supplied together.

    DDO needs both ``ddo_beta`` and ``ddo_alpha``; BPO needs both ``bpo_beta``
    and ``bpo_lambda``. The BPO pair is only checked when no DDO option is
    set, mirroring the selection order used by ``PreferenceOptimization``.

    Raises:
        AssertionError: if only one option of a pair is set.
    """
    if args.ddo_beta is not None or args.ddo_alpha is not None:
        assert args.ddo_beta is not None and args.ddo_alpha is not None, "Both ddo_beta and ddo_alpha must be set together"
    elif args.bpo_beta is not None or args.bpo_lambda is not None:
        assert args.bpo_beta is not None and args.bpo_lambda is not None, "Both bpo_beta and bpo_lambda must be set together"


class PreferenceOptimization:
    """Select and invoke a preference-optimization loss based on CLI arguments.

    Attributes:
        algo: display name of the selected algorithm, or ``None``.
        args: keyword arguments forwarded to the selected loss function.
        loss_ref_fn: loss that needs reference-model losses, or ``None``.
        loss_fn: reference-free loss, or ``None``.
    """

    def __init__(self, args):
        self.loss_fn = None
        self.loss_ref_fn = None
        # Always defined, so callers can inspect them even when nothing is selected.
        self.algo = None
        self.args = {}

        assert_po_variables(args)

        if args.ddo_beta is not None or args.ddo_alpha is not None:
            self.algo = "DDO"
            self.loss_ref_fn = ddo_loss
            # Keys must match the ddo_loss keyword names; w_t falls back to its default.
            self.args = {"ddo_beta": args.ddo_beta, "ddo_alpha": args.ddo_alpha}
        elif args.bpo_beta is not None or args.bpo_lambda is not None:
            self.algo = "BPO"
            self.loss_ref_fn = bpo_loss
            self.args = {"beta": args.bpo_beta, "lambda_": args.bpo_lambda}
        elif args.beta_dpo is not None:
            self.algo = "Diffusion DPO"
            self.loss_ref_fn = diffusion_dpo_loss
            self.args = {"beta": args.beta_dpo}
        elif args.sdpo_beta is not None:
            self.algo = "SDPO"
            self.loss_ref_fn = sdpo_loss
            self.args = {"beta": args.sdpo_beta, "epsilon": args.sdpo_epsilon}

        # __call__ always prefers the reference-based loss, so a reference-free
        # option must not overwrite its name or keyword arguments.
        if self.loss_ref_fn is not None:
            return

        if args.mapo_beta is not None:
            self.algo = "MaPO"
            self.loss_fn = mapo_loss
            self.args = {"beta": args.mapo_beta}
        elif args.simpo_beta is not None:
            self.algo = "SimPO"
            self.loss_fn = simpo_loss
            simpo_kwargs = {
                "beta": args.simpo_beta,
                "gamma_beta_ratio": args.simpo_gamma_beta_ratio,
                "smoothing": args.simpo_smoothing,
                "loss_type": args.simpo_loss_type,
            }
            # Options left unset are None; drop them so simpo_loss uses its own defaults.
            self.args = {key: value for key, value in simpo_kwargs.items() if value is not None}
        elif args.cpo_beta is not None:
            self.algo = "CPO"
            self.loss_fn = cpo_loss
            self.args = {"beta": args.cpo_beta}

    def is_po(self):
        """Return True when any preference-optimization loss is selected."""
        return self.loss_fn is not None or self.loss_ref_fn is not None

    def is_reference(self):
        """Return True when the selected loss needs reference-model losses."""
        return self.loss_ref_fn is not None

    def __call__(self, loss: torch.Tensor, ref_loss: Optional[torch.Tensor] = None):
        """Apply the selected loss.

        Args:
            loss: losses of the trained model (winners first, losers second
                for the pairwise losses, which split it with ``chunk(2)``).
            ref_loss: matching reference-model losses; required when
                ``is_reference()`` is True.

        Returns:
            ``(loss, metrics)`` as returned by the selected loss function.

        Raises:
            AssertionError: if a required reference loss or loss function is missing.
        """
        if self.is_reference():
            assert ref_loss is not None, "Reference required for this preference optimization"
            assert self.loss_ref_fn is not None, "No reference loss function"
            loss, metrics = self.loss_ref_fn(loss, ref_loss, **self.args)
        else:
            assert self.loss_fn is not None, "No loss function"
            loss, metrics = self.loss_fn(loss, **self.args)

        return loss, metrics


def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: torch.Tensor, beta: float):
    """
    Diffusion DPO loss

    Args:
        loss: model losses, winners stacked before losers along dim 0 (4-D)
        ref_loss: reference-model losses in the same layout
        beta: DPO KL penalty weight

    Returns:
        (per-pair loss reduced over dims 1-3, metrics dict)
    """
    loss_w, loss_l = loss.chunk(2)
    ref_losses_w, ref_losses_l = ref_loss.chunk(2)

    model_diff = loss_w - loss_l
    ref_diff = ref_losses_w - ref_losses_l

    scale_term = -0.5 * beta
    inside_term = scale_term * (model_diff - ref_diff)
    loss = -1 * torch.nn.functional.logsigmoid(inside_term).mean(dim=(1, 2, 3))

    # Fraction of pairs whose mean margin favours the winner. Counting every
    # element and dividing by the batch size would give values above 1.
    pair_margin = inside_term.flatten(start_dim=1).mean(dim=1)
    implicit_acc = (pair_margin > 0).float().mean()

    metrics = {
        "loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
        "loss/diffusion_dpo_ref_loss": ref_loss.detach().mean().item(),
        "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().mean().item(),
    }

    return loss, metrics


def mapo_loss(model_losses: torch.Tensor, beta: float, total_timesteps=1000) -> tuple[torch.Tensor, dict[str, float]]:
    """
    MaPO loss

    Paper: Margin-aware Preference Optimization for Aligning Diffusion Models without Reference
    https://mapo-t2i.github.io/

    Args:
        model_losses: unreduced losses (4-D), winners stacked before losers
            along dim 0. The full distribution is kept for numerical stability.
        beta: weight of the margin term
        total_timesteps: number of timesteps used to scale the margin

    Returns:
        (per-pair loss reduced over dims 1-3, metrics dict)
    """
    loss_w, loss_l = model_losses.chunk(2)

    phi_coefficient = 0.5
    # NOTE(review): x / (exp(x) - 1) is 0/0 where an element of the loss is exactly 0.
    win_score = (phi_coefficient * loss_w) / (torch.exp(phi_coefficient * loss_w) - 1)
    lose_score = (phi_coefficient * loss_l) / (torch.exp(phi_coefficient * loss_l) - 1)

    # Score difference loss
    score_difference = win_score - lose_score

    # Margin loss.
    # By multiplying T in the inner term, we try to maximize the
    # margin throughout the overall denoising process.
    margin = F.logsigmoid(score_difference * total_timesteps + 1e-10)
    margin_losses = beta * margin

    # Full MaPO loss
    loss = loss_w.mean(dim=(1, 2, 3)) - margin_losses.mean(dim=(1, 2, 3))

    metrics = {
        "loss/mapo_total": loss.detach().mean().item(),
        "loss/mapo_ratio": -margin_losses.detach().mean().item(),
        "loss/mapo_w_loss": loss_w.detach().mean().item(),
        "loss/mapo_l_loss": loss_l.detach().mean().item(),
        "loss/mapo_score_difference": score_difference.detach().mean().item(),
        "loss/mapo_win_score": win_score.detach().mean().item(),
        "loss/mapo_lose_score": lose_score.detach().mean().item(),
    }

    return loss, metrics


def ddo_loss(loss, ref_loss, w_t: float = 1.0, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
    """
    Implements Direct Discriminative Optimization (DDO) loss.

    The discriminator is parameterized by the likelihood ratio between the
    learnable target model and a fixed reference model.

    Args:
        loss: Target model loss (4-D; summed over dims 1-3)
        ref_loss: Reference model loss (detached inside this function)
        w_t: weight at timestep. Defaults to 1.0 so callers that have no
            timestep weight can omit it.
            NOTE(review): confirm 1.0 is the intended neutral weight.
        ddo_alpha: Weight coefficient for the fake samples loss term.
        ddo_beta: Scaling factor for the likelihood ratio.

    Returns:
        tuple: (total_loss, metrics_dict)
    """
    ref_loss = ref_loss.detach()  # Ensure no gradients to reference

    # Log likelihood from weighted loss
    target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3))
    ref_logp = -torch.sum(w_t * ref_loss, dim=(1, 2, 3))

    delta = target_logp - ref_logp

    # log_ratio = beta * (log p_theta(x) - log p_ref(x))
    log_ratio = ddo_beta * delta

    # Real-sample term: -log sigmoid(log_ratio)
    data_loss = -F.logsigmoid(log_ratio)

    # Fake-sample term: -alpha * log(1 - sigmoid(log_ratio))
    ref_loss_term = -ddo_alpha * F.logsigmoid(-log_ratio)

    total_loss = data_loss + ref_loss_term

    metrics = {
        "loss/ddo_data": data_loss.detach().mean().item(),
        "loss/ddo_ref": ref_loss_term.detach().mean().item(),
        "loss/ddo_total": total_loss.detach().mean().item(),
        "loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio.detach()).mean().item(),
    }

    return total_loss, metrics


def cpo_loss(loss: torch.Tensor, beta: float = 0.1) -> tuple[torch.Tensor, dict[str, float]]:
    """
    CPO Loss = L(pi_theta; U) - E[log pi_theta(y_w|x)]

    Where L(pi_theta; U) is the uniform reference DPO loss and the second term
    is a behavioral cloning regularizer on preferred data.

    Args:
        loss: losses, winners stacked before losers along dim 0
        beta: Weight for log ratio (Similar to Diffusion DPO)

    Returns:
        (scalar loss, metrics dict)
    """
    loss_w, loss_l = loss.chunk(2)

    # Prevent values from being too small, causing large gradients
    log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01))
    uniform_dpo_loss = -F.logsigmoid(beta * log_ratio).mean()

    # NOTE(review): this treats `loss_w` like a log-likelihood. If the caller
    # passes an MSE-style loss (lower is better) the sign may be inverted -
    # confirm against the training loop.
    bc_regularizer = -loss_w.mean()

    total = uniform_dpo_loss + bc_regularizer

    metrics = {}
    metrics["loss/cpo_reward_margin"] = uniform_dpo_loss.detach().mean().item()

    return total, metrics


def bpo_loss(loss: torch.Tensor, ref_loss: torch.Tensor, beta: float, lambda_: float) -> tuple[torch.Tensor, dict[str, float]]:
    """
    Bregman Preference Optimization

    Paper: Preference Optimization by Estimating the
        Ratio of the Data Distribution

    Args:
        loss: losses from the training model (4-D, winners then losers)
        ref_loss: losses from the reference model in the same layout
        beta: Regularization coefficient
        lambda_: hyperparameter for SBA

    Returns:
        (per-pair loss reduced over dims 1-3, metrics dict)

    Raises:
        ValueError: if ``lambda_`` is -1, which makes the normalizer zero.
    """
    if lambda_ == -1:
        raise ValueError("bpo lambda_ must not be -1 (normalization 4 * (1 + lambda_) would be zero)")

    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    logits = loss_w - loss_l - ref_loss_w + ref_loss_l
    reward_margin = beta * logits
    R = torch.exp(-reward_margin)

    # Clip R values to be no smaller than 0.01 for training stability
    R = torch.max(R, torch.full_like(R, 0.01))

    if lambda_ == 0.0:
        losses = R + torch.log(R)
    else:
        losses = R ** (lambda_ + 1) - ((lambda_ + 1) / lambda_) * (R ** (-lambda_))
    # NOTE(review): applied to both branches here; the nesting of this line
    # was not recoverable from the reviewed text - confirm against upstream.
    losses /= 4 * (1 + lambda_)

    metrics = {}
    metrics["loss/bpo_reward_margin"] = reward_margin.detach().mean().item()
    metrics["loss/bpo_R"] = R.detach().mean().item()
    return losses.mean(dim=(1, 2, 3)), metrics


def kto_loss(loss: torch.Tensor, ref_loss: torch.Tensor, kl_loss: torch.Tensor, ref_kl_loss: torch.Tensor, w_t=1.0, undesirable_w_t=1.0, beta=0.1):
    """
    KTO: Model Alignment as Prospect Theoretic Optimization
    https://arxiv.org/abs/2402.01306

    Computes a Kahneman-Tversky style loss: desirable samples are weighted by
    ``w_t`` and undesirable samples by ``undesirable_w_t``. The KL term is the
    mean of ``-(kl_loss - ref_kl_loss)``, clamped to be non-negative.

    Returns:
        scalar loss only (unlike the other losses here, no metrics dict).
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Convert losses to rewards (negative loss = positive reward)
    # NOTE(review): "rejected" rewards come only from the reference losses,
    # and "chosen" rewards only from the model losses - confirm this is intended.
    chosen_rewards = -(loss_w - loss_l)
    rejected_rewards = -(ref_loss_w - ref_loss_l)
    KL_rewards = -(kl_loss - ref_kl_loss)

    # Estimate KL divergence using unmatched samples
    KL_estimate = KL_rewards.mean().clamp(min=0)

    losses = []

    # Desirable (chosen) samples: we want reward > KL
    if chosen_rewards.shape[0] > 0:
        chosen_kto_losses = w_t * (1 - torch.sigmoid(beta * (chosen_rewards - KL_estimate)))
        losses.append(chosen_kto_losses)

    # Undesirable (rejected) samples: we want KL > reward
    if rejected_rewards.shape[0] > 0:
        rejected_kto_losses = undesirable_w_t * (1 - torch.sigmoid(beta * (KL_estimate - rejected_rewards)))
        losses.append(rejected_kto_losses)

    if losses:
        total_loss = torch.cat(losses, 0).mean()
    else:
        # Keep the fallback on the same device/dtype as the inputs.
        total_loss = loss.new_zeros(())

    return total_loss


def ipo_loss(loss: torch.Tensor, ref_loss: torch.Tensor, tau=0.1):
    """
    IPO: Iterative Preference Optimization for Text-to-Video Generation
    https://arxiv.org/abs/2502.02088

    Returns:
        (unreduced squared-margin losses, metrics dict)
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    chosen_rewards = loss_w - ref_loss_w
    rejected_rewards = loss_l - ref_loss_l

    losses = (chosen_rewards - rejected_rewards - (1 / (2 * tau))).pow(2)

    metrics: dict[str, float] = {}
    metrics["loss/ipo_chosen_rewards"] = chosen_rewards.detach().mean().item()
    metrics["loss/ipo_rejected_rewards"] = rejected_rewards.detach().mean().item()

    return losses, metrics


def compute_importance_weight(loss: torch.Tensor, ref_loss: torch.Tensor) -> torch.Tensor:
    """
    Approximate the importance weight as ``exp(ref_loss - loss)``.

    The weight is larger when the training model's loss is lower than the
    reference model's loss.
    """
    w_t = torch.exp(-loss + ref_loss)
    return w_t


def clip_importance_weight(w_t: torch.Tensor, epsilon=0.1) -> torch.Tensor:
    """
    Clip importance weights to the range [1 - epsilon, 1 + epsilon].
    """
    return torch.clamp(w_t, 1 - epsilon, 1 + epsilon)


def sdpo_loss(loss: torch.Tensor, ref_loss: torch.Tensor, beta=0.02, epsilon=0.1) -> tuple[torch.Tensor, dict[str, float]]:
    """
    SDPO loss: -log sigmoid(w * psi_w - w * psi_l)

    where psi = beta * (ref_loss - loss) approximates the scaled log ratio and
    w is the larger of the two clipped inverse importance weights.

    Args:
        loss: model losses (4-D), winners stacked before losers along dim 0
        ref_loss: reference-model losses in the same layout
        beta: scale of the log-ratio terms
        epsilon: clipping half-width for the inverse importance weights

    Returns:
        (per-pair loss reduced over dims 1-3, metrics dict)
    """
    loss_w, loss_l = loss.chunk(2)
    ref_loss_w, ref_loss_l = ref_loss.chunk(2)

    # Step-wise importance weights for inverse weighting
    w_theta_w = compute_importance_weight(loss_w, ref_loss_w)
    w_theta_l = compute_importance_weight(loss_l, ref_loss_l)

    # Inverse weighting with clipping
    w_theta_w_inv = clip_importance_weight(1.0 / (w_theta_w + 1e-8), epsilon=epsilon)
    w_theta_l_inv = clip_importance_weight(1.0 / (w_theta_l + 1e-8), epsilon=epsilon)
    w_theta_max = torch.max(w_theta_w_inv, w_theta_l_inv)

    # For preferred samples
    log_ratio_w = -loss_w + ref_loss_w
    psi_w = beta * log_ratio_w

    # For dispreferred samples
    log_ratio_l = -loss_l + ref_loss_l
    psi_l = beta * log_ratio_l

    logits = w_theta_max * psi_w - w_theta_max * psi_l
    # logsigmoid stays finite for very negative logits, where
    # log(sigmoid(x)) underflows to log(0) = -inf.
    sigmoid_loss = -F.logsigmoid(logits)

    metrics: dict[str, float] = {}
    metrics["loss/sdpo_log_ratio_w"] = log_ratio_w.detach().mean().item()
    metrics["loss/sdpo_log_ratio_l"] = log_ratio_l.detach().mean().item()
    metrics["loss/sdpo_w_theta_max"] = w_theta_max.detach().mean().item()
    metrics["loss/sdpo_w_theta_w"] = w_theta_w.detach().mean().item()
    metrics["loss/sdpo_w_theta_l"] = w_theta_l.detach().mean().item()

    return sigmoid_loss.mean(dim=(1, 2, 3)), metrics


def simpo_loss(
    loss: torch.Tensor, loss_type: str = "sigmoid", gamma_beta_ratio: float = 0.25, beta: float = 2.0, smoothing: float = 0.0
) -> tuple[torch.Tensor, dict[str, float]]:
    """
    Compute the SimPO loss (reference-free).

    SimPO: Simple Preference Optimization with a Reference-Free Reward
    https://arxiv.org/abs/2405.14734

    Args:
        loss: losses, winners stacked before losers along dim 0
        loss_type: "sigmoid" or "hinge"
        gamma_beta_ratio: target margin subtracted from the winner/loser difference
        beta: scale applied to the margin
        smoothing: label smoothing used by the sigmoid variant

    Returns:
        (unreduced losses, metrics dict)

    Raises:
        ValueError: for an unknown ``loss_type``.
    """
    loss_w, loss_l = loss.chunk(2)

    pi_logratios = loss_w - loss_l
    logits = pi_logratios - gamma_beta_ratio

    if loss_type == "sigmoid":
        losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing
    elif loss_type == "hinge":
        losses = torch.relu(1 - beta * logits)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}. Should be one of ['sigmoid', 'hinge']")

    metrics = {}
    metrics["loss/simpo_chosen_rewards"] = (beta * loss_w.detach()).mean().item()
    metrics["loss/simpo_rejected_rewards"] = (beta * loss_l.detach()).mean().item()
    metrics["loss/simpo_logratio"] = (beta * logits.detach()).mean().item()

    return losses, metrics


def normalize_gradients(model):
    """Scale all existing gradients of ``model`` in place so their global L2 norm is 1.

    Does nothing when no parameter has a gradient or when the global norm is 0.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        # torch.stack raises on an empty list; there is nothing to normalize.
        return

    total_norm = torch.norm(torch.stack([torch.norm(g.detach()) for g in grads]))
    if total_norm > 0:
        for g in grads:
            g.div_(total_norm)
dtype=torch.float32) -> torch.FloatTensor: sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) schedule_timesteps = noise_scheduler.timesteps.to(device) timesteps = timesteps.to(device) @@ -450,7 +450,7 @@ def compute_density_for_timestep_sampling( return u -def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas) -> torch.Tensor: """Computes loss weighting scheme for SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. @@ -467,35 +467,43 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): return weighting -def get_noisy_model_input_and_timesteps( +def get_noisy_model_input_and_timestep( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """ + Returns: + tuple[ + noisy_model_input: noisy at sigma applied to latent + timesteps: timesteps between 1.0 and 1000.0 + sigmas: sigmas between 0.0 and 1.0 + ] + """ bsz, _, h, w = latents.shape assert bsz > 0, "Batch size not large enough" - num_timesteps = noise_scheduler.config.num_train_timesteps + num_timesteps: int = noise_scheduler.config.num_train_timesteps if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random sigma-based noise sampling if args.timestep_sampling == "sigmoid": # https://github.com/XLabs-AI/x-flux/tree/main - sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + sigma = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: - sigmas = torch.rand((bsz,), device=device) + sigma = torch.rand((bsz,), device=device) - timesteps = sigmas * num_timesteps + timestep = sigma * num_timesteps elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift - sigmas = torch.randn(bsz, 
device=device) - sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling - sigmas = sigmas.sigmoid() - sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) - timesteps = sigmas * num_timesteps + sigma = torch.randn(bsz, device=device) + sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling + sigma = sigma.sigmoid() + sigma = (sigma * shift) / (1 + (shift - 1) * sigma) + timestep = sigma * num_timesteps elif args.timestep_sampling == "flux_shift": - sigmas = torch.randn(bsz, device=device) - sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling - sigmas = sigmas.sigmoid() + sigma = torch.randn(bsz, device=device) + sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling + sigma = sigma.sigmoid() mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size - sigmas = time_shift(mu, 1.0, sigmas) - timesteps = sigmas * num_timesteps + sigma = time_shift(mu, 1.0, sigma) + timestep = noise_scheduler._sigma_to_t(sigma) else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -507,28 +515,29 @@ def get_noisy_model_input_and_timesteps( mode_scale=args.mode_scale, ) indices = (u * num_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=device) - sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + timestep: torch.Tensor = noise_scheduler.timesteps[indices].to(device=device) + sigma = get_sigmas(noise_scheduler, timestep, device, n_dim=latents.ndim, dtype=dtype) # Broadcast sigmas to latent shape - sigmas = sigmas.view(-1, 1, 1, 1) + sigma = sigma.view(-1, 1, 1, 1) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: + assert isinstance(args.ip_noise_gamma, float) xi = torch.randn_like(latents, device=latents.device, dtype=dtype) 
if args.ip_noise_gamma_random_strength: ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma else: ip_noise_gamma = args.ip_noise_gamma - noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) + noisy_model_input = (1.0 - sigma) * latents + sigma * (noise + ip_noise_gamma * xi) else: - noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise + noisy_model_input = (1.0 - sigma) * latents + sigma * noise - return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas + return noisy_model_input.to(dtype), timestep.to(dtype), sigma -def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): +def apply_model_prediction_type(args, model_pred: torch.FloatTensor, noisy_model_input, sigmas): weighting = None if args.model_prediction_type == "raw": pass diff --git a/library/flux_utils.py b/library/flux_utils.py index 3f0a0d63e..8c25de477 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -418,7 +418,7 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi return img_ids -def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: +def unpack_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor: """ x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 """ diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index c40798846..fe03e8fc6 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -895,7 +895,7 @@ def compute_density_for_timestep_sampling( return u -def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor: """Computes loss weighting scheme for SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. 
diff --git a/library/strategy_base.py b/library/strategy_base.py index fad79682f..f41a915d4 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -11,7 +11,7 @@ # TODO remove circular import by moving ImageInfo to a separate file # from library.train_util import ImageInfo - +# from library.train_util import ImageSetInfo from library.utils import setup_logging setup_logging() @@ -539,6 +539,7 @@ def _default_cache_batch_latents( info.latents_flipped = flipped_latent info.alpha_mask = alpha_mask + def load_latents_from_disk( self, npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: diff --git a/library/train_util.py b/library/train_util.py index c866dec2a..8e0d525c1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -209,6 +209,76 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime self.resize_interpolation: Optional[str] = None + self._current = 0 + + def __iter__(self): + self._current = 0 + return self + + def __next__(self): + if self._current < 1: + self._current += 1 + return self + else: + self.current = 0 + raise StopIteration + + def __len__(self): + return 1 + + def __getitem__(self, item): + if item == 0: + return self + else: + raise IndexError("Index out of range") + + @staticmethod + def _pin_tensor(tensor): + return tensor.pin_memory() if tensor is not None else tensor + + def pin_memory(self): + self.latents = self._pin_tensor(self.latents) + self.latents_flipped = self._pin_tensor(self.latents_flipped) + self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1) + self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2) + self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2) + self.alpha_mask = self._pin_tensor(self.alpha_mask) + return self + + 
+class ImageSetInfo: + def __init__(self, images: list[ImageInfo] = []) -> None: + super().__init__() + + self.images = images + self.current = 0 + + @property + def image_key(self): + return self.images[0].image_key + + @property + def bucket_reso(self): + return self.images[0].bucket_reso + + def __iter__(self): + return self + + def __next__(self): + if self.current < len(self.images): + result = self.images[self.current] + self.current += 1 + return result + else: + self.current = 0 + raise StopIteration + + def __getitem__(self, item): + return self.images[item] + + def __len__(self): + return len(self.images) + class BucketManager: def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None: @@ -433,6 +503,11 @@ def __init__( validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, resize_interpolation: Optional[str] = None, + preference: bool = False, + preference_caption_prefix: Optional[str] = None, + preference_caption_suffix: Optional[str] = None, + non_preference_caption_prefix: Optional[str] = None, + non_preference_caption_suffix: Optional[str] = None, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -457,6 +532,11 @@ def __init__( self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる self.custom_attributes = custom_attributes if custom_attributes is not None else {} + self.preference = preference + self.preference_caption_prefix = preference_caption_prefix + self.preference_caption_suffix = preference_caption_suffix + self.non_preference_caption_prefix = non_preference_caption_prefix + self.non_preference_caption_suffix = non_preference_caption_suffix self.img_count = 0 @@ -497,6 +577,11 @@ def __init__( validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, resize_interpolation: Optional[str] = None, + preference: bool = False, + preference_caption_prefix: Optional[str] = None, + 
preference_caption_suffix: Optional[str] = None, + non_preference_caption_prefix: Optional[str] = None, + non_preference_caption_suffix: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -525,6 +610,11 @@ def __init__( validation_seed=validation_seed, validation_split=validation_split, resize_interpolation=resize_interpolation, + preference=preference, + preference_caption_prefix=preference_caption_prefix, + preference_caption_suffix=preference_caption_suffix, + non_preference_caption_prefix=non_preference_caption_prefix, + non_preference_caption_suffix=non_preference_caption_suffix, ) self.is_reg = is_reg @@ -635,6 +725,11 @@ def __init__( validation_seed: Optional[int] = None, validation_split: Optional[float] = 0.0, resize_interpolation: Optional[str] = None, + preference: bool = False, + preference_caption_prefix: Optional[str] = None, + preference_caption_suffix: Optional[str] = None, + non_preference_caption_prefix: Optional[str] = None, + non_preference_caption_suffix: Optional[str] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -663,6 +758,11 @@ def __init__( validation_seed=validation_seed, validation_split=validation_split, resize_interpolation=resize_interpolation, + preference=preference, + preference_caption_prefix=preference_caption_prefix, + preference_caption_suffix=preference_caption_suffix, + non_preference_caption_prefix=non_preference_caption_prefix, + non_preference_caption_suffix=non_preference_caption_suffix, ) self.conditioning_data_dir = conditioning_data_dir @@ -683,7 +783,7 @@ def __init__( resolution: Optional[Tuple[int, int]], network_multiplier: float, debug_dataset: bool, - resize_interpolation: Optional[str] = None + resize_interpolation: Optional[str] = None, ) -> None: super().__init__() @@ -719,10 +819,12 @@ def __init__( self.image_transforms = IMAGE_TRANSFORMS if resize_interpolation is not None: - assert 
validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation" + assert validate_interpolation_fn( + resize_interpolation + ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation' self.resize_interpolation = resize_interpolation - self.image_data: Dict[str, ImageInfo] = {} + self.image_data: Dict[str, ImageInfo | ImageSetInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} self.replacements = {} @@ -975,7 +1077,7 @@ def get_input_ids(self, caption, tokenizer=None): input_ids = torch.stack(iids_list) # 3,77 return input_ids - def register_image(self, info: ImageInfo, subset: BaseSubset): + def register_image(self, info: ImageInfo | ImageSetInfo, subset: BaseSubset): self.image_data[info.image_key] = info self.image_to_subset[info.image_key] = subset @@ -985,9 +1087,10 @@ def make_buckets(self): min_size and max_size are ignored when enable_bucket is False """ logger.info("loading image sizes.") - for info in tqdm(self.image_data.values()): - if info.image_size is None: - info.image_size = self.get_image_size(info.absolute_path) + for infos in tqdm(self.image_data.values()): + for info in infos: + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) # # run in parallel # max_workers = min(os.cpu_count(), len(self.image_data)) # TODO consider multi-gpu (processes) @@ -1029,26 +1132,39 @@ def make_buckets(self): ) img_ar_errors = [] - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket( - image_width, image_height - ) + for image_infos in self.image_data.values(): + for image_info in image_infos: + image_width, image_height = image_info.image_size + image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket( + image_width, image_height 
+ ) - # logger.info(image_info.image_key, image_info.bucket_reso) - img_ar_errors.append(abs(ar_error)) + # logger.info(image_info.image_key, image_info.bucket_reso) + img_ar_errors.append(abs(ar_error)) self.bucket_manager.sort() else: self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None) self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ - for image_info in self.image_data.values(): - image_width, image_height = image_info.image_size - image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) + for image_infos in self.image_data.values(): + for info in image_infos: + image_width, image_height = info.image_size + info.bucket_reso, info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height) + + for infos in self.image_data.values(): + bucket_reso = None + for info in infos: + if bucket_reso is None: + bucket_reso = info.bucket_reso + else: + assert ( + bucket_reso == info.bucket_reso + ), f"Image pair not found in same bucket. 
{info.image_key} {bucket_reso} {info.bucket_reso}" - for image_info in self.image_data.values(): - for _ in range(image_info.num_repeats): - self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key) + assert bucket_reso is not None + + for _ in range(infos[0].num_repeats): + self.bucket_manager.add_image(bucket_reso, infos[0].image_key) # bucket情報を表示、格納する if self.enable_bucket: @@ -1135,7 +1251,7 @@ def __eq__(self, other): and self.random_crop == other.random_crop ) - batch: List[ImageInfo] = [] + batch: list[ImageInfo] = [] current_condition = None # support multiple-gpus @@ -1143,7 +1259,7 @@ def __eq__(self, other): process_index = accelerator.process_index # define a function to submit a batch to cache - def submit_batch(batch, cond): + def submit_batch(batch: list[ImageInfo], cond): for info in batch: if info.image is not None and isinstance(info.image, Future): info.image = info.image.result() # future to image @@ -1162,52 +1278,52 @@ def submit_batch(batch, cond): try: # iterate images logger.info("caching latents...") - for i, info in enumerate(tqdm(image_infos)): - subset = self.image_to_subset[info.image_key] + for i, infos in enumerate(tqdm(image_infos)): + subset = self.image_to_subset[infos[0].image_key] - if info.latents_npz is not None: # fine tuning dataset - continue + for info in infos: + if info.latents_npz is not None: # fine tuning dataset + continue - # check disk cache exists and size of latents - if caching_strategy.cache_to_disk: - # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) + # check disk cache exists and size of latents + if caching_strategy.cache_to_disk: + # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) - # if the modulo of num_processes is not equal to process_index, skip 
caching - # this makes each process cache different latents - if i % num_processes != process_index: - continue + # if the modulo of num_processes is not equal to process_index, skip caching + # this makes each process cache different latents + if i % num_processes != process_index: + continue - # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") + # print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}") - cache_available = caching_strategy.is_disk_cached_latents_expected( - info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask - ) - if cache_available: # do not add to batch - continue + cache_available = caching_strategy.is_disk_cached_latents_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) + if cache_available: # do not add to batch + continue - # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty - condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) - if len(batch) > 0 and current_condition != condition: - submit_batch(batch, current_condition) - batch = [] + # if batch is not empty and condition is changed, flush the batch. 
Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + submit_batch(batch, current_condition) + batch = [] - if info.image is None: - # load image in parallel - info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) + if info.image is None: + # load image in parallel + info.image = executor.submit(load_image, info.absolute_path, condition.alpha_mask) - batch.append(info) - current_condition = condition + batch.append(info) + current_condition = condition - # if number of data in batch is enough, flush the batch - if len(batch) >= caching_strategy.batch_size: - submit_batch(batch, current_condition) - batch = [] - current_condition = None + # if number of data in batch is enough, flush the batch + if len(batch) >= caching_strategy.batch_size: + submit_batch(batch, current_condition) + batch = [] + current_condition = None if len(batch) > 0: submit_batch(batch, current_condition) - finally: executor.shutdown() @@ -1236,44 +1352,44 @@ def __eq__(self, other): and self.random_crop == other.random_crop ) - batches: List[Tuple[Condition, List[ImageInfo]]] = [] - batch: List[ImageInfo] = [] + batches: list[tuple[Condition, list[ImageInfo | ImageSetInfo]]] = [] + batch: list[ImageInfo | ImageSetInfo] = [] current_condition = None logger.info("checking cache validity...") - for info in tqdm(image_infos): - subset = self.image_to_subset[info.image_key] - - if info.latents_npz is not None: # fine tuning dataset - continue - - # check disk cache exists and size of latents - if cache_to_disk: - info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - if not is_main_process: # store to info only + for infos in tqdm(image_infos): + subset = self.image_to_subset[infos[0].image_key] + for info in infos: + if info.latents_npz is not None: # fine tuning dataset continue - 
cache_available = is_disk_cached_latents_is_expected( - info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask - ) + # check disk cache exists and size of latents + if cache_to_disk: + info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix + if not is_main_process: # store to info only + continue - if cache_available: # do not add to batch - continue + cache_available = is_disk_cached_latents_is_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) - # if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty - condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) - if len(batch) > 0 and current_condition != condition: - batches.append((current_condition, batch)) - batch = [] + if cache_available: # do not add to batch + continue - batch.append(info) - current_condition = condition + # if batch is not empty and condition is changed, flush the batch. 
Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) + batch = [] - # if number of data in batch is enough, flush the batch - if len(batch) >= vae_batch_size: - batches.append((current_condition, batch)) - batch = [] - current_condition = None + batch.append(info) + current_condition = condition + + # if number of data in batch is enough, flush the batch + if len(batch) >= vae_batch_size: + batches.append((current_condition, batch)) + batch = [] + current_condition = None if len(batch) > 0: batches.append((current_condition, batch)) @@ -1307,27 +1423,28 @@ def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Acceler process_index = accelerator.process_index logger.info("checking cache validity...") - for i, info in enumerate(tqdm(image_infos)): - # check disk cache exists and size of text encoder outputs - if caching_strategy.cache_to_disk: - te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) - info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability - - # if the modulo of num_processes is not equal to process_index, skip caching - # this makes each process cache different text encoder outputs - if i % num_processes != process_index: - continue + for i, infos in enumerate(tqdm(image_infos)): + for info in infos: + # check disk cache exists and size of text encoder outputs + if caching_strategy.cache_to_disk: + te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path) + info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability - cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) - if cache_available: # do not add to batch - continue + # if the modulo of num_processes is not equal to process_index, skip 
caching + # this makes each process cache different text encoder outputs + if i % num_processes != process_index: + continue + + cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz) + if cache_available: # do not add to batch + continue - batch.append(info) + batch.append(info) - # if number of data in batch is enough, flush the batch - if len(batch) >= batch_size: - batches.append(batch) - batch = [] + # if number of data in batch is enough, flush the batch + if len(batch) >= batch_size: + batches.append(batch) + batch = [] if len(batch) > 0: batches.append(batch) @@ -1482,6 +1599,71 @@ def get_image_size(self, image_path): image_size = (0, 0) return image_size + def load_and_transform_image(self, subset, image_info, absolute_path, flipped): + # 画像を読み込み、必要ならcropする + + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, absolute_path, subset.alpha_mask) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img, original_size, crop_ltrb = trim_and_resize_if_required( + subset.random_crop, img, image_info.bucket_reso, image_info.resized_size + ) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert ( + subset.random_crop + ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p : p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p : p + self.width] + + im_h, im_w = img.shape[0:2] + assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {absolute_path}" + + original_size = [im_w, im_h] + crop_ltrb = (0, 0, 0, 0) + + # augmentation + aug = self.aug_helper.get_augmentor(subset.color_aug) + if aug is not None: + # augment RGB channels only + 
img_rgb = img[:, :, :3] + img_rgb = aug(image=img_rgb)["image"] + img[:, :, :3] = img_rgb + + if flipped: + img = img[:, ::-1, :].copy() # copy to avoid negative stride problem + + if subset.alpha_mask: + if img.shape[2] == 4: + alpha_mask = img[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 + alpha_mask = torch.FloatTensor(alpha_mask) + else: + alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) + else: + alpha_mask = None + + img = img[:, :, :3] # remove alpha channel + + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) + + if not flipped: + crop_left_top = (crop_ltrb[0], crop_ltrb[1]) + else: + # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image + crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1]) + + return image, original_size, crop_left_top, alpha_mask + def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): img = load_image(image_path, alpha_mask) @@ -1569,170 +1751,133 @@ def __getitem__(self, index): custom_attributes = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: - image_info = self.image_data[image_key] + image_infos = self.image_data[image_key] subset = self.image_to_subset[image_key] + for image_info in image_infos: + custom_attributes.append(subset.custom_attributes) - custom_attributes.append(subset.custom_attributes) - - # in case of fine tuning, is_reg is always False - loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + # in case of fine tuning, is_reg is always False + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) - flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance - - # image/latentsを処理する - if image_info.latents is not None: # cache_latents=Trueの場合 - 
original_size = image_info.latents_original_size - crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped - if not flipped: - latents = image_info.latents - alpha_mask = image_info.alpha_mask - else: - latents = image_info.latents_flipped - alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) + flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance - image = None - elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( - self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) - ) - if flipped: - latents = flipped_latents - alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem - del flipped_latents - latents = torch.FloatTensor(latents) - if alpha_mask is not None: - alpha_mask = torch.FloatTensor(alpha_mask) - - image = None - else: - # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( - subset, image_info.absolute_path, subset.alpha_mask - ) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img, original_size, crop_ltrb = trim_and_resize_if_required( - subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation + # image/latentsを処理する + if image_info.latents is not None: # cache_latents=Trueの場合 + original_size = image_info.latents_original_size + crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped + if not flipped: + latents = image_info.latents + alpha_mask = image_info.alpha_mask + else: + latents = image_info.latents_flipped + alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) + + target_size = (latents.shape[2] * 8, latents.shape[1] * 8) + image = None + + images.append(image) + 
latents_list.append(latents) + original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) + crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0]))) + target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) + elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( + self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) ) - else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(subset, img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert ( - subset.random_crop - ), f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p : p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p : p + self.width] - - im_h, im_w = img.shape[0:2] - assert ( - im_h == self.height and im_w == self.width - ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" - - original_size = [im_w, im_h] - crop_ltrb = (0, 0, 0, 0) - - # augmentation - aug = self.aug_helper.get_augmentor(subset.color_aug) - if aug is not None: - # augment RGB channels only - img_rgb = img[:, :, :3] - img_rgb = aug(image=img_rgb)["image"] - img[:, :, :3] = img_rgb - - if flipped: - img = img[:, ::-1, :].copy() # copy to avoid negative stride problem - - if subset.alpha_mask: - if img.shape[2] == 4: - alpha_mask = img[:, :, 3] # [H,W] - alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 + if flipped: + latents = flipped_latents + alpha_mask = ( + None if alpha_mask is None else alpha_mask[:, ::-1].copy() + ) # copy to avoid negative stride problem + del flipped_latents + latents = torch.FloatTensor(latents) + if alpha_mask is not None: alpha_mask = 
torch.FloatTensor(alpha_mask) - else: - alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) - else: - alpha_mask = None + target_size = (latents.shape[2] * 8, latents.shape[1] * 8) - img = img[:, :, :3] # remove alpha channel + image = None - latents = None - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる - del img - - images.append(image) - latents_list.append(latents) - alpha_mask_list.append(alpha_mask) + images.append(image) + latents_list.append(latents) + alpha_mask_list.append(alpha_mask) + original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) + crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0]))) + target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) + else: + image, original_size, crop_ltrb, alpha_mask = self.load_and_transform_image( + subset, image_info, image_info.absolute_path, flipped + ) + images.append(image) + latents_list.append(None) + alpha_mask_list.append(alpha_mask) - target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) + target_size = ( + (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) + ) - if not flipped: - crop_left_top = (crop_ltrb[0], crop_ltrb[1]) - else: - # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image - crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1]) + if not flipped: + crop_left_top = (crop_ltrb[0], crop_ltrb[1]) + else: + # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image + crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1]) - original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) - crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0]))) - target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) - flippeds.append(flipped) + original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) + 
crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0]))) + target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) + flippeds.append(flipped) - # captionとtext encoder outputを処理する - caption = image_info.caption # default + # captionとtext encoder outputを処理する + caption = image_info.caption # default - tokenization_required = ( - self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial - ) - text_encoder_outputs = None - input_ids = None - - if image_info.text_encoder_outputs is not None: - # cached - text_encoder_outputs = image_info.text_encoder_outputs - elif image_info.text_encoder_outputs_npz is not None: - # on disk - text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( - image_info.text_encoder_outputs_npz + tokenization_required = ( + self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial ) - else: - tokenization_required = True - text_encoder_outputs_list.append(text_encoder_outputs) - - if tokenization_required: - caption = self.process_caption(subset, image_info.caption) - input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension - # if self.XTI_layers: - # caption_layer = [] - # for layer in self.XTI_layers: - # token_strings_from = " ".join(self.token_strings) - # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) - # caption_ = caption.replace(token_strings_from, token_strings_to) - # caption_layer.append(caption_) - # captions.append(caption_layer) - # else: - # captions.append(caption) - - # if not self.token_padding_disabled: # this option might be omitted in future - # # TODO get_input_ids must support SD3 - # if self.XTI_layers: - # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) - # else: - # token_caption = self.get_input_ids(caption, self.tokenizers[0]) - # input_ids_list.append(token_caption) - - # if 
len(self.tokenizers) > 1: - # if self.XTI_layers: - # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) - # else: - # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) - # input_ids2_list.append(token_caption2) - - input_ids_list.append(input_ids) - captions.append(caption) + text_encoder_outputs = None + input_ids = None + + if image_info.text_encoder_outputs is not None: + # cached + text_encoder_outputs = image_info.text_encoder_outputs + elif image_info.text_encoder_outputs_npz is not None: + # on disk + text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz( + image_info.text_encoder_outputs_npz + ) + else: + tokenization_required = True + text_encoder_outputs_list.append(text_encoder_outputs) + + if tokenization_required: + caption = self.process_caption(subset, image_info.caption) + input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension + # if self.XTI_layers: + # caption_layer = [] + # for layer in self.XTI_layers: + # token_strings_from = " ".join(self.token_strings) + # token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings]) + # caption_ = caption.replace(token_strings_from, token_strings_to) + # caption_layer.append(caption_) + # captions.append(caption_layer) + # else: + # captions.append(caption) + + # if not self.token_padding_disabled: # this option might be omitted in future + # # TODO get_input_ids must support SD3 + # if self.XTI_layers: + # token_caption = self.get_input_ids(caption_layer, self.tokenizers[0]) + # else: + # token_caption = self.get_input_ids(caption, self.tokenizers[0]) + # input_ids_list.append(token_caption) + + # if len(self.tokenizers) > 1: + # if self.XTI_layers: + # token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1]) + # else: + # token_caption2 = self.get_input_ids(caption, self.tokenizers[1]) + # input_ids2_list.append(token_caption2) + + input_ids_list.append(input_ids) + 
captions.append(caption) def none_or_stack_elements(tensors_list, converter): # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)] @@ -1772,6 +1917,7 @@ def none_or_stack_elements(tensors_list, converter): example["images"] = images example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None + example["captions"] = captions example["original_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in original_sizes_hw]) @@ -1798,41 +1944,42 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): random_crop = None for image_key in bucket[image_index : image_index + bucket_batch_size]: - image_info = self.image_data[image_key] + image_infos = self.image_data[image_key] subset = self.image_to_subset[image_key] - if flip_aug is None: - flip_aug = subset.flip_aug - alpha_mask = subset.alpha_mask - random_crop = subset.random_crop - bucket_reso = image_info.bucket_reso - else: - # TODO そもそも混在してても動くようにしたほうがいい - assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" - assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" - assert random_crop == subset.random_crop, "random_crop must be same in a batch" - assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" + for image_info in image_infos: + if flip_aug is None: + flip_aug = subset.flip_aug + alpha_mask = subset.alpha_mask + random_crop = subset.random_crop + bucket_reso = image_info.bucket_reso + else: + # TODO そもそも混在してても動くようにしたほうがいい + assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" + assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" + assert random_crop == subset.random_crop, "random_crop must be same in a batch" + assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" - caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc. 
+ caption = image_info.caption # TODO cache some patterns of dropping, shuffling, etc. - if self.caching_mode == "latents": - image = load_image(image_info.absolute_path) - else: - image = None + if self.caching_mode == "latents": + image = load_image(image_info.absolute_path) + else: + image = None - if self.caching_mode == "text": - input_ids1 = self.get_input_ids(caption, self.tokenizers[0]) - input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) - else: - input_ids1 = None - input_ids2 = None + if self.caching_mode == "text": + input_ids1 = self.get_input_ids(caption, self.tokenizers[0]) + input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) + else: + input_ids1 = None + input_ids2 = None - captions.append(caption) - images.append(image) - input_ids1_list.append(input_ids1) - input_ids2_list.append(input_ids2) - absolute_paths.append(image_info.absolute_path) - resized_sizes.append(image_info.resized_size) + captions.append(caption) + images.append(image) + input_ids1_list.append(input_ids1) + input_ids2_list.append(input_ids2) + absolute_paths.append(image_info.absolute_path) + resized_sizes.append(image_info.resized_size) example = {} @@ -1954,6 +2101,11 @@ def load_dreambooth_dir(subset: DreamBoothSubset): img_paths = list(metas.keys()) sizes: List[Optional[Tuple[int, int]]] = [meta["resolution"] for meta in metas.values()] + elif subset.preference: + # We assume a image_dir path pattern for winner/loser + winner_path = str(pathlib.Path(subset.image_dir) / "w") + img_paths = glob_images(winner_path, "*") + sizes = [None] * len(img_paths) # we may need to check image size and existence of image files, but it takes time, so user should check it before training else: img_paths = glob_images(subset.image_dir, "*") @@ -2100,10 +2252,65 @@ def load_dreambooth_dir(subset: DreamBoothSubset): num_train_images += num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, num_repeats, 
caption, subset.is_reg, img_path) - info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation - if size is not None: - info.image_size = size + if subset.preference: + + def get_non_preferred_pair_info(img_path, subset): + head, file = os.path.split(img_path) + head, tail = os.path.split(head) + new_tail = tail.replace("w", "l") + loser_img_path = os.path.join(head, new_tail, file) + + def check_extension(path: str): + from pathlib import Path + + test_path = Path(path) + if not test_path.exists(): + for ext in [".webp", ".png", ".jpg", ".jpeg", ".png"]: + test_path = test_path.with_suffix(ext) + if test_path.exists(): + return str(test_path) + + return str(test_path) + + loser_img_path = check_extension(loser_img_path) + + caption = read_caption(img_path, subset.caption_extension, subset.enable_wildcard) + + if subset.non_preference_caption_prefix: + caption = subset.non_preference_caption_prefix + " " + caption + if subset.non_preference_caption_suffix: + caption = caption + " " + subset.non_preference_caption_suffix + + image_size = self.get_image_size(img_path) if size is not None else None + + return loser_img_path, caption, image_size + + if subset.preference_caption_prefix: + caption = subset.preference_caption_prefix + " " + caption + if subset.preference_caption_suffix: + caption = caption + " " + subset.preference_caption_suffix + + resize_interpolation = ( + subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation + ) + + chosen_image_info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path) + chosen_image_info.resize_interpolation = resize_interpolation + rejected_img_path, rejected_caption, rejected_image_size = get_non_preferred_pair_info(img_path, subset) + rejected_image_info = ImageInfo( + rejected_img_path, subset.num_repeats, caption, subset.is_reg, rejected_img_path + ) + 
rejected_image_info.resize_interpolation = resize_interpolation + + info = ImageSetInfo([chosen_image_info, rejected_image_info]) + else: + info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) + info.resize_interpolation = ( + subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation + ) + if size is not None: + info.image_size = size + if subset.is_reg: reg_infos.append((info, subset)) else: @@ -2385,7 +2592,7 @@ def __init__( bucket_no_upscale: bool, debug_dataset: bool, validation_split: float, - validation_seed: Optional[int], + validation_seed: Optional[int], resize_interpolation: Optional[str] = None, ) -> None: super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) @@ -2421,6 +2628,11 @@ def __init__( subset.token_warmup_min, subset.token_warmup_step, resize_interpolation=subset.resize_interpolation, + preference=subset.preference, + preference_caption_prefix=subset.preference_caption_prefix, + preference_caption_suffix=subset.preference_caption_suffix, + non_preference_caption_prefix=subset.non_preference_caption_prefix, + non_preference_caption_suffix=subset.non_preference_caption_suffix, ) db_subsets.append(db_subset) @@ -2448,7 +2660,7 @@ def __init__( self.num_train_images = self.dreambooth_dataset_delegate.num_train_images self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images self.validation_split = validation_split - self.validation_seed = validation_seed + self.validation_seed = validation_seed self.resize_interpolation = resize_interpolation # assert all conditioning data exists @@ -2538,7 +2750,14 @@ def __getitem__(self, index): cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], 
self.resize_interpolation) + cond_img = resize_image( + cond_img, + original_size_hw[1], + original_size_hw[0], + target_size_hw[1], + target_size_hw[0], + self.resize_interpolation, + ) # TODO support random crop # 現在サポートしているcropはrandomではなく中央のみ @@ -2552,7 +2771,14 @@ def __getitem__(self, index): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation) + cond_img = resize_image( + cond_img, + cond_img.shape[0], + cond_img.shape[1], + target_size_hw[1], + target_size_hw[0], + self.resize_interpolation, + ) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2982,7 +3208,7 @@ def trim_and_resize_if_required( # for new_cache_latents def load_images_and_masks_for_caching( image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool -) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: +) -> Tuple[torch.Tensor, list[torch.Tensor | None], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]: r""" requires image_infos to have: [absolute_path or image], bucket_reso, resized_size @@ -2994,38 +3220,47 @@ def load_images_and_masks_for_caching( crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...] 
""" images: List[torch.Tensor] = [] - alpha_masks: List[np.ndarray] = [] + alpha_masks: list[torch.Tensor | None] = [] original_sizes: List[Tuple[int, int]] = [] crop_ltrbs: List[Tuple[int, int, int, int]] = [] - for info in image_infos: - image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) - # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) - - original_sizes.append(original_size) - crop_ltrbs.append(crop_ltrb) - - if use_alpha_mask: - if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [H,W] - alpha_mask = alpha_mask.astype(np.float32) / 255.0 - alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + for infos in image_infos: + for info in infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required( + random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation + ) + + original_sizes.append(original_size) + crop_ltrbs.append(crop_ltrb) + + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(torch.from_numpy(image[:, :, 0]), dtype=torch.float32) # [H,W] else: - alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] - else: - alpha_mask = None - alpha_masks.append(alpha_mask) + alpha_mask = None + alpha_masks.append(alpha_mask) - image = image[:, :, :3] # remove alpha channel if exists - image = IMAGE_TRANSFORMS(image) - images.append(image) + image = image[:, :, :3] # remove alpha channel 
if exists + image = IMAGE_TRANSFORMS(image) + assert isinstance(image, torch.Tensor) + images.append(image) img_tensor = torch.stack(images, dim=0) return img_tensor, alpha_masks, original_sizes, crop_ltrbs def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool + vae: AutoencoderKL, + cache_to_disk: bool, + image_infos: list[ImageInfo | ImageSetInfo], + flip_aug: bool, + use_alpha_mask: bool, + random_crop: bool, ) -> None: r""" requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz @@ -3037,29 +3272,32 @@ def cache_batch_latents( latents_original_size and latents_crop_ltrb are also set """ images = [] - alpha_masks: List[np.ndarray] = [] - for info in image_infos: - image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) - # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 - image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation) - - info.latents_original_size = original_size - info.latents_crop_ltrb = crop_ltrb - - if use_alpha_mask: - if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [H,W] - alpha_mask = alpha_mask.astype(np.float32) / 255.0 - alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + alpha_masks: List[torch.Tensor | None] = [] + for infos in image_infos: + for info in infos: + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 + image, original_size, crop_ltrb = trim_and_resize_if_required( + random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation + ) + + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + + if use_alpha_mask: 
+ if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] + else: + alpha_mask = torch.ones_like(torch.from_numpy(image[:, :, 0]), dtype=torch.float32) # [H,W] else: - alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] - else: - alpha_mask = None - alpha_masks.append(alpha_mask) + alpha_mask = None + alpha_masks.append(alpha_mask) - image = image[:, :, :3] # remove alpha channel if exists - image = IMAGE_TRANSFORMS(image) - images.append(image) + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) img_tensors = torch.stack(images, dim=0) img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) @@ -4135,7 +4373,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り", ) - if support_dreambooth: # DreamBooth training parser.add_argument( @@ -4444,6 +4681,30 @@ def add_dataset_arguments( default=None, help="suffix for caption text / captionのテキストの末尾に付ける文字列", ) + parser.add_argument( + "--preference_caption_prefix", + type=str, + default=None, + help="prefix for preference caption text / captionのテキストの先頭に付ける文字列", + ) + parser.add_argument( + "--preference_caption_suffix", + type=str, + default=None, + help="suffix for preference caption text / captionのテキストの末尾に付ける文字列", + ) + parser.add_argument( + "--non_preference_caption_prefix", + type=str, + default=None, + help="prefix for non-preference caption text / captionのテキストの先頭に付ける文字列", + ) + parser.add_argument( + "--non_preference_caption_suffix", + type=str, + default=None, + help="suffix for non-preference caption text / captionのテキストの末尾に付ける文字列", + ) parser.add_argument( "--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする" ) @@ -5994,6 +6255,7 @@ def 
get_noise_noisy_latents_and_timesteps( # Sample a random timestep for each image b_size = latents.shape[0] + min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep @@ -6027,7 +6289,8 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler elif args.huber_schedule == "snr": if not hasattr(noise_scheduler, "alphas_cumprod"): raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + device = noise_scheduler.alphas_cumprod.device + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.to(device)) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c result = result.to(timesteps.device) @@ -6589,4 +6852,3 @@ def moving_average(self) -> float: if losses == 0: return 0 return self.loss_total / losses - diff --git a/library/utils.py b/library/utils.py index d0586b84a..6742e8533 100644 --- a/library/utils.py +++ b/library/utils.py @@ -16,6 +16,7 @@ import numpy as np from safetensors.torch import load_file + def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -88,6 +89,7 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) + setup_logging() logger = logging.getLogger(__name__) @@ -398,7 +400,9 @@ def pil_resize(image, size, interpolation): return resized_cv2 -def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None): +def resize_image( + image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None +): """ Resize image with resize 
interpolation. Default interpolation to AREA if image is smaller, else LANCZOS. @@ -449,29 +453,30 @@ def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]: https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121 """ if interpolation is None: - return None + return None if interpolation == "lanczos" or interpolation == "lanczos4": - # Lanczos interpolation over 8x8 neighborhood + # Lanczos interpolation over 8x8 neighborhood return cv2.INTER_LANCZOS4 elif interpolation == "nearest": - # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab. + # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab. return cv2.INTER_NEAREST_EXACT elif interpolation == "bilinear" or interpolation == "linear": # bilinear interpolation return cv2.INTER_LINEAR elif interpolation == "bicubic" or interpolation == "cubic": - # bicubic interpolation + # bicubic interpolation return cv2.INTER_CUBIC elif interpolation == "area": - # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. return cv2.INTER_AREA elif interpolation == "box": - # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. + # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method. 
return cv2.INTER_AREA else: return None + def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]: """ Convert interpolation value to PIL interpolation @@ -479,7 +484,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters """ if interpolation is None: - return None + return None if interpolation == "lanczos": return Image.Resampling.LANCZOS @@ -493,7 +498,7 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp # For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used. return Image.Resampling.BICUBIC elif interpolation == "area": - # Image.Resampling.BOX may be more appropriate if upscaling + # Image.Resampling.BOX may be more appropriate if upscaling # Area interpolation is related to cv2.INTER_AREA # Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX. 
return Image.Resampling.HAMMING @@ -503,12 +508,37 @@ def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resamp else: return None + def validate_interpolation_fn(interpolation_str: str) -> bool: """ Check if a interpolation function is supported """ return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"] + +# For debugging +def save_latent_as_img(vae, latent_to, output_name): + """Save latent as image using VAE""" + from PIL import Image + + with torch.no_grad(): + image = vae.decode(latent_to.to(vae.dtype)).float() + # VAE outputs are typically in the range [-1, 1], so rescale to [0, 255] + image = (image / 2 + 0.5).clamp(0, 1) + + # Convert to numpy array with values in range [0, 255] + image = (image * 255).cpu().numpy().astype(np.uint8) + + # Rearrange dimensions from [batch_size, channels, height, width] to [batch_size, height, width, channels] + image = image.transpose(0, 2, 3, 1) + + # Take the first image if you have a batch + pil_image = Image.fromarray(image[0]) + + # Save the image + pil_image.save(output_name) + + # endregion # TODO make inf_utils.py diff --git a/requirements.txt b/requirements.txt index 448af323c..828f366a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ voluptuous==0.13.1 huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 -numpy<=2.0 +numpy<2.0 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 diff --git a/sd3_train_network.py b/sd3_train_network.py index cdb7aa4e3..a3c8c2930 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -323,7 +323,7 @@ def get_noise_pred_and_target( weight_dtype, train_unet, is_train=True, - ): + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -389,7 +389,7 @@ def get_noise_pred_and_target( target[diff_output_pr_indices] = 
model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, weighting + return model_pred, noisy_model_input, target, sigmas, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/tests/library/test_custom_train_functions_bpo.py b/tests/library/test_custom_train_functions_bpo.py new file mode 100644 index 000000000..387b44c4c --- /dev/null +++ b/tests/library/test_custom_train_functions_bpo.py @@ -0,0 +1,358 @@ +import pytest +import torch + +from library.custom_train_functions import bpo_loss + + +class TestBPOLoss: + """Test suite for BPO loss function""" + + @pytest.fixture + def sample_tensors(self): + """Create sample tensors for testing image latent tensors""" + # Image latent tensor dimensions + batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs + channels = 4 # Latent channels (e.g., VAE latent space) + height = 32 # Latent height + width = 32 # Latent width + + # Create tensors with shape [2*batch_size, channels, height, width] + # First half represents preferred (w), second half dispreferred (l) + loss = torch.randn(2 * batch_size, channels, height, width) + ref_loss = torch.randn(2 * batch_size, channels, height, width) + + return loss, ref_loss + + @pytest.fixture + def simple_tensors(self): + """Create simple tensors for basic testing""" + # Create tensors with shape (2, 4, 32, 32) + # First tensor (batch 0) + batch_0 = torch.full((4, 32, 32), 1.0) + batch_0[1] = 2.0 # Second channel + batch_0[2] = 2.0 # Third channel + batch_0[3] = 3.0 # Fourth channel + + # Second tensor (batch 1) + batch_1 = torch.full((4, 32, 32), 3.0) + batch_1[1] = 4.0 + batch_1[2] = 5.0 + batch_1[3] = 2.0 + + loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32) + + # Reference loss tensor + ref_batch_0 = torch.full((4, 32, 32), 0.5) + ref_batch_0[1] = 1.5 + ref_batch_0[2] = 3.5 + ref_batch_0[3] = 9.5 + + ref_batch_1 = torch.full((4, 32, 32), 2.5) + ref_batch_1[1] = 3.5 
+ ref_batch_1[2] = 4.5 + ref_batch_1[3] = 3.5 + + ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32) + + return loss, ref_loss + + @torch.no_grad() + def test_basic_functionality(self, simple_tensors): + """Test basic functionality with simple inputs""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # Check return types + assert isinstance(result_loss, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape (should be scalar after mean reduction) + assert result_loss.shape == torch.Size([1]) + + # Check that loss is finite + assert torch.isfinite(result_loss) + + @torch.no_grad() + def test_metrics_keys(self, simple_tensors): + """Test that all expected metrics are returned""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + _, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + expected_keys = ["loss/bpo_reward_margin", "loss/bpo_R"] + + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], (int, float)) + assert torch.isfinite(torch.tensor(metrics[key])) + + @torch.no_grad() + def test_lambda_zero_case(self, simple_tensors): + """Test the special case when lambda = 0.0""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.0 + + result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # Should handle lambda=0 case (R + log(R)) + assert torch.isfinite(result_loss) + assert "loss/bpo_reward_margin" in metrics + assert "loss/bpo_R" in metrics + + @torch.no_grad() + def test_different_beta_values(self, simple_tensors): + """Test with different beta values""" + loss, ref_loss = simple_tensors + lambda_ = 0.5 + + beta_values = [0.01, 0.1, 0.5, 1.0] + results = [] + + for beta in beta_values: + result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_) + results.append(result_loss.item()) + + # Results should be different for different beta values + assert 
len(set(results)) == len(beta_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + @torch.no_grad() + def test_different_lambda_values(self, simple_tensors): + """Test with different lambda values""" + loss, ref_loss = simple_tensors + beta = 0.1 + + lambda_values = [0.0, 0.1, 0.5, 1.0, 2.0] + results = [] + + for lambda_ in lambda_values: + result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_) + results.append(result_loss.item()) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + @torch.no_grad() + def test_r_clipping(self, simple_tensors): + """Test that R values are properly clipped to minimum 0.01""" + loss, ref_loss = simple_tensors + beta = 10.0 # Large beta to potentially create very small R values + lambda_ = 0.5 + + result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # R should be >= 0.01 due to clipping + assert metrics["loss/bpo_R"] >= 0.01 + assert torch.isfinite(result_loss) + + @torch.no_grad() + def test_tensor_chunking(self, sample_tensors): + """Test that tensor chunking works correctly""" + loss, ref_loss = sample_tensors + beta = 0.1 + lambda_ = 0.5 + + result_loss, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # The function should handle chunking internally + assert torch.isfinite(result_loss) + assert len(metrics) == 2 + + def test_gradient_flow(self, simple_tensors): + """Test that gradients can flow through the loss""" + loss, ref_loss = simple_tensors + loss.requires_grad_(True) + ref_loss.requires_grad_(True) + beta = 0.1 + lambda_ = 0.5 + + result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_) + result_loss.backward() + + # Check that gradients exist + assert loss.grad is not None + assert ref_loss.grad is not None + assert not torch.isnan(loss.grad).any() + assert not torch.isnan(ref_loss.grad).any() + + @torch.no_grad() + def test_numerical_stability_extreme_values(self): + """Test 
numerical stability with extreme values""" + # Test with very large values + large_loss = torch.full((2, 4, 32, 32), 100.0) + large_ref_loss = torch.full((2, 4, 32, 32), 50.0) + + result_loss, _ = bpo_loss(large_loss, large_ref_loss, beta=0.1, lambda_=0.5) + assert torch.isfinite(result_loss) + + # Test with very small values + small_loss = torch.full((2, 4, 32, 32), 1e-6) + small_ref_loss = torch.full((2, 4, 32, 32), 1e-7) + + result_loss, _ = bpo_loss(small_loss, small_ref_loss, beta=0.1, lambda_=0.5) + assert torch.isfinite(result_loss) + + @torch.no_grad() + def test_negative_lambda_values(self, simple_tensors): + """Test with negative lambda values""" + loss, ref_loss = simple_tensors + beta = 0.1 + + # Test some negative lambda values + lambda_values = [-0.5, -0.1, -0.9] + + for lambda_ in lambda_values: + # Skip lambda = -1 as it causes division by zero + if lambda_ != -1.0: + result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_) + assert torch.isfinite(result_loss) + + @torch.no_grad() + def test_edge_case_lambda_near_negative_one(self, simple_tensors): + """Test edge case near lambda = -1""" + loss, ref_loss = simple_tensors + beta = 0.1 + + # Test values close to -1 but not exactly -1 + lambda_values = [-0.99, -0.999] + + for lambda_ in lambda_values: + result_loss, _ = bpo_loss(loss, ref_loss, beta, lambda_) + # Should still be finite even though close to the problematic value + assert torch.isfinite(result_loss) + + @torch.no_grad() + def test_asymmetric_preference_structure(self): + """Test that the function properly handles preferred vs dispreferred samples""" + # Create scenario where preferred samples have lower loss + loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss) + loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss) + loss = torch.cat([loss_w, loss_l], dim=0) + + ref_loss_w = torch.full((1, 4, 32, 32), 2.0) + ref_loss_l = torch.full((1, 4, 32, 32), 2.0) + ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0) 
+ + result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5) + + # The loss should be finite and reflect the preference structure + assert torch.isfinite(result_loss) + + # The reward margin should reflect the preference (preferred - dispreferred) + # In this case: (1-3) - (2-2) = -2, so reward_margin should be negative + assert metrics["loss/bpo_reward_margin"] < 0 + + @pytest.mark.parametrize( + "batch_size,channels,height,width", + [ + (2, 4, 32, 32), + (2, 4, 16, 16), + (2, 8, 64, 64), + ], + ) + @torch.no_grad() + def test_different_tensor_shapes(self, batch_size, channels, height, width): + """Test with different tensor shapes""" + loss = torch.randn(2 * batch_size, channels, height, width) + ref_loss = torch.randn(2 * batch_size, channels, height, width) + + result_loss, metrics = bpo_loss(loss, ref_loss, beta=0.1, lambda_=0.5) + + assert torch.isfinite(result_loss.mean()) + assert result_loss.shape == torch.Size([2]) + assert len(metrics) == 2 + + def test_device_compatibility(self, simple_tensors): + """Test that function works on different devices""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + # Test on CPU + result_cpu, _ = bpo_loss(loss, ref_loss, beta, lambda_) + assert result_cpu.device.type == "cpu" + + # Test on GPU if available + if torch.cuda.is_available(): + loss_gpu = loss.cuda() + ref_loss_gpu = ref_loss.cuda() + result_gpu, _ = bpo_loss(loss_gpu, ref_loss_gpu, beta, lambda_) + assert result_gpu.device.type == "cuda" + + @torch.no_grad() + def test_reproducibility(self, simple_tensors): + """Test that results are reproducible with same inputs""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + # Run multiple times with same seed + torch.manual_seed(42) + result1, metrics1 = bpo_loss(loss, ref_loss, beta, lambda_) + + torch.manual_seed(42) + result2, metrics2 = bpo_loss(loss, ref_loss, beta, lambda_) + + # Results should be identical + assert torch.allclose(result1, result2) + for key in 
metrics1: + assert abs(metrics1[key] - metrics2[key]) < 1e-6 + + @torch.no_grad() + def test_zero_inputs(self): + """Test with zero inputs""" + zero_loss = torch.zeros(2, 4, 32, 32) + zero_ref_loss = torch.zeros(2, 4, 32, 32) + + result_loss, metrics = bpo_loss(zero_loss, zero_ref_loss, beta=0.1, lambda_=0.5) + + # Should handle zero inputs gracefully + assert torch.isfinite(result_loss) + for value in metrics.values(): + assert torch.isfinite(torch.tensor(value)) + + @torch.no_grad() + def test_reward_margin_computation(self, simple_tensors): + """Test that reward margin is computed correctly""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + _, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # Manually compute expected reward margin + loss_w, loss_l = loss.chunk(2) + ref_loss_w, ref_loss_l = ref_loss.chunk(2) + expected_logits = loss_w - loss_l - ref_loss_w + ref_loss_l + expected_reward_margin = beta * expected_logits + + # Compare with returned metric (within floating point precision) + assert abs(metrics["loss/bpo_reward_margin"] - expected_reward_margin.mean().item()) < 1e-5 + + @torch.no_grad() + def test_r_value_computation(self, simple_tensors): + """Test that R values are computed correctly""" + loss, ref_loss = simple_tensors + beta = 0.1 + lambda_ = 0.5 + + _, metrics = bpo_loss(loss, ref_loss, beta, lambda_) + + # R should be positive and >= 0.01 due to clipping + assert metrics["loss/bpo_R"] > 0 + assert metrics["loss/bpo_R"] >= 0.01 + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_cpo.py b/tests/library/test_custom_train_functions_cpo.py new file mode 100644 index 000000000..64c3d507b --- /dev/null +++ b/tests/library/test_custom_train_functions_cpo.py @@ -0,0 +1,384 @@ +import pytest +import torch +import torch.nn.functional as F + +from library.custom_train_functions import cpo_loss + + +class TestCPOLoss: + """Test suite for CPO 
(Contrastive Preference Optimization) loss function""" + + @pytest.fixture + def sample_tensors(self): + """Create sample tensors for testing image latent tensors""" + # Image latent tensor dimensions + batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs + channels = 4 # Latent channels (e.g., VAE latent space) + height = 32 # Latent height + width = 32 # Latent width + + # Create tensors with shape [2*batch_size, channels, height, width] + # First half represents preferred (w), second half dispreferred (l) + loss = torch.randn(2 * batch_size, channels, height, width) + + return loss + + @pytest.fixture + def simple_tensors(self): + """Create simple tensors for basic testing""" + # Create tensors with shape (2, 4, 32, 32) + # First tensor (batch 0) - preferred + batch_0 = torch.full((4, 32, 32), 1.0) + batch_0[1] = 2.0 # Second channel + batch_0[2] = 1.5 # Third channel + batch_0[3] = 1.8 # Fourth channel + + # Second tensor (batch 1) - dispreferred + batch_1 = torch.full((4, 32, 32), 3.0) + batch_1[1] = 4.0 + batch_1[2] = 3.5 + batch_1[3] = 3.8 + + loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32) + + return loss + + def test_basic_functionality(self, simple_tensors): + """Test basic functionality with simple inputs""" + loss = simple_tensors + + result_loss, metrics = cpo_loss(loss) + + # Check return types + assert isinstance(result_loss, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape (should be scalar) + assert result_loss.shape == torch.Size([]) + + # Check that loss is finite + assert torch.isfinite(result_loss) + + def test_metrics_keys(self, simple_tensors): + """Test that all expected metrics are returned""" + loss = simple_tensors + + _, metrics = cpo_loss(loss) + + expected_keys = ["loss/cpo_reward_margin"] + + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], (int, float)) + assert torch.isfinite(torch.tensor(metrics[key])) + + def 
test_tensor_chunking(self, sample_tensors): + """Test that tensor chunking works correctly""" + loss = sample_tensors + + result_loss, metrics = cpo_loss(loss) + + # The function should handle chunking internally + assert torch.isfinite(result_loss) + assert len(metrics) == 1 + + # Verify chunking produces correct shapes + loss_w, loss_l = loss.chunk(2) + assert loss_w.shape == loss_l.shape + assert loss_w.shape[0] == loss.shape[0] // 2 + + def test_different_beta_values(self, simple_tensors): + """Test with different beta values""" + loss = simple_tensors + + beta_values = [0.01, 0.05, 0.1, 0.5, 1.0] + results = [] + + for beta in beta_values: + result_loss, _ = cpo_loss(loss, beta=beta) + results.append(result_loss.item()) + + # Results should be different for different beta values + assert len(set(results)) == len(beta_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_log_ratio_clipping(self, simple_tensors): + """Test that log ratio is properly clipped to minimum 0.01""" + loss = simple_tensors + + # Manually verify clipping behavior + loss_w, loss_l = loss.chunk(2) + raw_log_ratio = loss_w - loss_l + + result_loss, _ = cpo_loss(loss) + + # The function should clip values to minimum 0.01 + expected_log_ratio = torch.max(raw_log_ratio, torch.full_like(raw_log_ratio, 0.01)) + + # All clipped values should be >= 0.01 + assert (expected_log_ratio >= 0.01).all() + assert torch.isfinite(result_loss) + + def test_uniform_dpo_component(self, simple_tensors): + """Test the uniform DPO loss component""" + loss = simple_tensors + beta = 0.1 + + _, metrics = cpo_loss(loss, beta=beta) + + # Manually compute uniform DPO loss + loss_w, loss_l = loss.chunk(2) + log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01)) + expected_uniform_dpo = -F.logsigmoid(beta * log_ratio).mean() + + # The metric should match our manual computation + assert abs(metrics["loss/cpo_reward_margin"] - 
expected_uniform_dpo.item()) < 1e-5 + + def test_behavioral_cloning_component(self, simple_tensors): + """Test the behavioral cloning regularizer component""" + loss = simple_tensors + + result_loss, metrics = cpo_loss(loss) + + # Manually compute BC regularizer + loss_w, _ = loss.chunk(2) + expected_bc_regularizer = -loss_w.mean() + + # The total loss should include this component + # Total = uniform_dpo + bc_regularizer + expected_total = metrics["loss/cpo_reward_margin"] + expected_bc_regularizer.item() + + # Should match within floating point precision + assert abs(result_loss.item() - expected_total) < 1e-5 + + def test_gradient_flow(self, simple_tensors): + """Test that gradients flow properly through the loss""" + loss = simple_tensors + loss.requires_grad_(True) + + result_loss, _ = cpo_loss(loss) + result_loss.backward() + + # Check that gradients exist + assert loss.grad is not None + assert not torch.isnan(loss.grad).any() + assert torch.isfinite(loss.grad).all() + + def test_preferred_vs_dispreferred_structure(self): + """Test that the function properly handles preferred vs dispreferred samples""" + # Create scenario where preferred samples have lower loss (better) + loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss) + loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss) + loss = torch.cat([loss_w, loss_l], dim=0) + + result_loss, _ = cpo_loss(loss) + + # The loss should be finite and reflect the preference structure + assert torch.isfinite(result_loss) + + # With preferred having lower loss, log_ratio should be negative + # This should lead to specific behavior in the logsigmoid term + log_ratio = loss_w - loss_l # Should be negative (1.0 - 3.0 = -2.0) + clipped_log_ratio = torch.max(log_ratio, torch.full_like(log_ratio, 0.01)) + + # After clipping, should be 0.01 (the minimum) + assert torch.allclose(clipped_log_ratio, torch.full_like(clipped_log_ratio, 0.01)) + + def test_equal_losses_case(self): + """Test behavior 
when preferred and dispreferred losses are equal""" + # Create scenario where preferred and dispreferred have same loss + loss_w = torch.full((1, 4, 32, 32), 2.0) + loss_l = torch.full((1, 4, 32, 32), 2.0) + loss = torch.cat([loss_w, loss_l], dim=0) + + result_loss, metrics = cpo_loss(loss) + + # Log ratio should be zero, but clipped to 0.01 + assert torch.isfinite(result_loss) + + # The reward margin should reflect the clipped behavior + assert metrics["loss/cpo_reward_margin"] > 0 + + def test_numerical_stability_extreme_values(self): + """Test numerical stability with extreme values""" + # Test with very large values + large_loss = torch.full((2, 4, 32, 32), 100.0) + result_loss, _ = cpo_loss(large_loss) + assert torch.isfinite(result_loss) + + # Test with very small values + small_loss = torch.full((2, 4, 32, 32), 1e-6) + result_loss, _ = cpo_loss(small_loss) + assert torch.isfinite(result_loss) + + # Test with negative values + negative_loss = torch.full((2, 4, 32, 32), -1.0) + result_loss, _ = cpo_loss(negative_loss) + assert torch.isfinite(result_loss) + + def test_zero_beta_case(self, simple_tensors): + """Test the case when beta = 0""" + loss = simple_tensors + beta = 0.0 + + result_loss, metrics = cpo_loss(loss, beta=beta) + + # With beta=0, the uniform DPO term should behave differently + # logsigmoid(0 * log_ratio) = logsigmoid(0) = log(0.5) ≈ -0.693 + assert torch.isfinite(result_loss) + assert metrics["loss/cpo_reward_margin"] > 0 # Should be approximately 0.693 + + def test_large_beta_case(self, simple_tensors): + """Test the case with very large beta""" + loss = simple_tensors + beta = 100.0 + + result_loss, metrics = cpo_loss(loss, beta=beta) + + # Even with large beta, should remain stable due to clipping + assert torch.isfinite(result_loss) + assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"])) + + @pytest.mark.parametrize( + "batch_size,channels,height,width", + [ + (1, 4, 32, 32), + (2, 4, 16, 16), + (4, 8, 64, 64), + (8, 4, 
8, 8), + ], + ) + def test_different_tensor_shapes(self, batch_size, channels, height, width): + """Test with different tensor shapes""" + # Note: batch_size will be doubled for preferred/dispreferred pairs + loss = torch.randn(2 * batch_size, channels, height, width) + + result_loss, metrics = cpo_loss(loss) + + assert torch.isfinite(result_loss) + assert result_loss.shape == torch.Size([]) # Scalar + assert len(metrics) == 1 + + def test_device_compatibility(self, simple_tensors): + """Test that function works on different devices""" + loss = simple_tensors + + # Test on CPU + result_cpu, _ = cpo_loss(loss) + assert result_cpu.device.type == "cpu" + + # Test on GPU if available + if torch.cuda.is_available(): + loss_gpu = loss.cuda() + result_gpu, _ = cpo_loss(loss_gpu) + assert result_gpu.device.type == "cuda" + + def test_reproducibility(self, simple_tensors): + """Test that results are reproducible with same inputs""" + loss = simple_tensors + + # Run multiple times + result1, metrics1 = cpo_loss(loss) + result2, metrics2 = cpo_loss(loss) + + # Results should be identical (deterministic computation) + assert torch.allclose(result1, result2) + for key in metrics1: + assert abs(metrics1[key] - metrics2[key]) < 1e-6 + + def test_no_reference_model_needed(self, simple_tensors): + """Test that CPO works without reference model (key feature)""" + loss = simple_tensors + + # CPO should work with just the loss tensor, no reference needed + result_loss, metrics = cpo_loss(loss) + + # Should produce meaningful results without reference model + assert torch.isfinite(result_loss) + assert len(metrics) == 1 + assert "loss/cpo_reward_margin" in metrics + + def test_loss_components_are_additive(self, simple_tensors): + """Test that the total loss is sum of uniform DPO and BC regularizer""" + loss = simple_tensors + beta = 0.1 + + result_loss, metrics = cpo_loss(loss, beta=beta) + + # Manually compute components + loss_w, loss_l = loss.chunk(2) + + # Uniform DPO component + 
log_ratio = torch.max(loss_w - loss_l, torch.full_like(loss_w, 0.01)) + uniform_dpo = -F.logsigmoid(beta * log_ratio).mean() + + # BC regularizer component + bc_regularizer = -loss_w.mean() + + # Total should be sum of components + expected_total = uniform_dpo + bc_regularizer + + assert abs(result_loss.item() - expected_total.item()) < 1e-5 + assert abs(metrics["loss/cpo_reward_margin"] - uniform_dpo.item()) < 1e-5 + + def test_clipping_prevents_large_gradients(self): + """Test that clipping prevents very large gradients from small differences""" + # Create case where loss_w - loss_l would be very small without clipping + loss_w = torch.full((1, 4, 32, 32), 2.000001) + loss_l = torch.full((1, 4, 32, 32), 2.000000) + loss = torch.cat([loss_w, loss_l], dim=0) + loss.requires_grad_(True) + + result_loss, _ = cpo_loss(loss) + result_loss.backward() + + assert loss.grad is not None + + # Gradients should be finite and not extremely large due to clipping + assert torch.isfinite(loss.grad).all() + assert not torch.any(torch.abs(loss.grad) > 0.001) # Reasonable gradient magnitude + + def test_behavioral_cloning_effect(self): + """Test that behavioral cloning regularizer has expected effect""" + # Create two scenarios: one with low preferred loss, one with high + + # Scenario 1: Low preferred loss + loss_w_low = torch.full((1, 4, 32, 32), 0.5) + loss_l_low = torch.full((1, 4, 32, 32), 2.0) + loss_low = torch.cat([loss_w_low, loss_l_low], dim=0) + + # Scenario 2: High preferred loss + loss_w_high = torch.full((1, 4, 32, 32), 2.0) + loss_l_high = torch.full((1, 4, 32, 32), 2.0) + loss_high = torch.cat([loss_w_high, loss_l_high], dim=0) + + result_low, _ = cpo_loss(loss_low) + result_high, _ = cpo_loss(loss_high) + + # The BC regularizer should make the total loss lower when preferred loss is lower + # BC regularizer = -loss_w.mean(), so lower loss_w leads to higher (less negative) regularizer + # But the overall effect depends on the relative magnitudes + assert 
torch.isfinite(result_low) + assert torch.isfinite(result_high) + + def test_edge_case_all_zeros(self): + """Test edge case with all zero losses""" + loss = torch.zeros(2, 4, 32, 32) + + result_loss, metrics = cpo_loss(loss) + + # Should handle all zeros gracefully + assert torch.isfinite(result_loss) + assert torch.isfinite(torch.tensor(metrics["loss/cpo_reward_margin"])) + + # With all zeros: loss_w - loss_l = 0, clipped to 0.01 + # BC regularizer = -0 = 0 + # So total should be just the uniform DPO term + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_ddo.py b/tests/library/test_custom_train_functions_ddo.py new file mode 100644 index 000000000..0b173c743 --- /dev/null +++ b/tests/library/test_custom_train_functions_ddo.py @@ -0,0 +1,376 @@ +import pytest +import torch +import torch.nn.functional as F + +from library.custom_train_functions import ddo_loss + + +class TestDDOLoss: + """Test suite for DDO (Direct Discriminative Optimization) loss function""" + + @pytest.fixture + def sample_tensors(self): + """Create sample tensors for testing image latent tensors""" + # Image latent tensor dimensions + batch_size = 2 + channels = 4 # Latent channels (e.g., VAE latent space) + height = 32 # Latent height + width = 32 # Latent width + + # Create tensors with shape [batch_size, channels, height, width] + loss = torch.randn(batch_size, channels, height, width) + ref_loss = torch.randn(batch_size, channels, height, width) + + return loss, ref_loss + + @pytest.fixture + def simple_tensors(self): + """Create simple tensors for basic testing""" + # Create tensors with shape (2, 4, 32, 32) + batch_0 = torch.full((4, 32, 32), 1.0) + batch_0[1] = 2.0 # Second channel + batch_0[2] = 1.5 # Third channel + batch_0[3] = 1.8 # Fourth channel + + batch_1 = torch.full((4, 32, 32), 2.0) + batch_1[1] = 3.0 + batch_1[2] = 2.5 + batch_1[3] = 2.8 + + loss = torch.stack([batch_0, batch_1], dim=0) # 
Shape: (2, 4, 32, 32) + + # Reference loss tensor (different from target) + ref_batch_0 = torch.full((4, 32, 32), 1.2) + ref_batch_0[1] = 2.2 + ref_batch_0[2] = 1.7 + ref_batch_0[3] = 2.0 + + ref_batch_1 = torch.full((4, 32, 32), 2.3) + ref_batch_1[1] = 3.3 + ref_batch_1[2] = 2.8 + ref_batch_1[3] = 3.1 + + ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32) + + return loss, ref_loss + + def test_basic_functionality(self, simple_tensors): + """Test basic functionality with simple inputs""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + # Check return types + assert isinstance(result_loss, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape (should be 1D with batch dimension) + assert result_loss.shape == torch.Size([2]) # batch_size = 2 + + # Check that loss is finite + assert torch.isfinite(result_loss).all() + + def test_metrics_keys(self, simple_tensors): + """Test that all expected metrics are returned""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + _, metrics = ddo_loss(loss, ref_loss, w_t) + + expected_keys = ["loss/ddo_data", "loss/ddo_ref", "loss/ddo_total", "loss/ddo_sigmoid_log_ratio"] + + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], (int, float)) + assert torch.isfinite(torch.tensor(metrics[key])) + + def test_ref_loss_detached(self, simple_tensors): + """Test that reference loss gradients are properly detached""" + loss, ref_loss = simple_tensors + loss.requires_grad_(True) + ref_loss.requires_grad_(True) + w_t = 1.0 + + result_loss, _ = ddo_loss(loss, ref_loss, w_t) + result_loss.sum().backward() + + # Target loss should have gradients + assert loss.grad is not None + assert not torch.isnan(loss.grad).any() + + # Reference loss should NOT have gradients due to detach() + assert ref_loss.grad is None or torch.allclose(ref_loss.grad, torch.zeros_like(ref_loss.grad)) + + def 
test_different_w_t_values(self, simple_tensors): + """Test with different timestep weights""" + loss, ref_loss = simple_tensors + + w_t_values = [0.1, 0.5, 1.0, 2.0, 5.0] + results = [] + + for w_t in w_t_values: + result_loss, _ = ddo_loss(loss, ref_loss, w_t) + results.append(result_loss.mean().item()) + + # Results should be different for different w_t values + assert len(set(results)) == len(w_t_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_different_ddo_alpha_values(self, simple_tensors): + """Test with different alpha values""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + alpha_values = [1.0, 2.0, 4.0, 8.0, 16.0] + results = [] + + for alpha in alpha_values: + result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_alpha=alpha) + results.append(result_loss.mean().item()) + + # Results should be different for different alpha values + assert len(set(results)) == len(alpha_values) + + # Higher alpha should generally increase the total loss due to increased ref penalty + # (though this depends on the specific values) + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_different_ddo_beta_values(self, simple_tensors): + """Test with different beta values""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + beta_values = [0.01, 0.05, 0.1, 0.2, 0.5] + results = [] + + for beta in beta_values: + result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta) + results.append(result_loss.mean().item()) + + # Results should be different for different beta values + assert len(set(results)) == len(beta_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_log_likelihood_computation(self, simple_tensors): + """Test that log likelihood computation is correct""" + loss, ref_loss = simple_tensors + w_t = 2.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + # Manually 
compute expected log likelihoods + expected_target_logp = -torch.sum(w_t * loss, dim=(1, 2, 3)) + expected_ref_logp = -torch.sum(w_t * ref_loss.detach(), dim=(1, 2, 3)) + expected_delta = expected_target_logp - expected_ref_logp + + # The function should produce finite results + assert torch.isfinite(result_loss).all() + assert torch.isfinite(expected_delta).all() + + def test_sigmoid_log_ratio_bounds(self, simple_tensors): + """Test that sigmoid log ratio is properly bounded""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + # Sigmoid output should be between 0 and 1 + sigmoid_ratio = metrics["loss/ddo_sigmoid_log_ratio"] + assert 0 <= sigmoid_ratio <= 1 + + def test_component_losses_relationship(self, simple_tensors): + """Test relationship between component losses and total loss""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + # Total loss should equal data loss + ref loss (approximately) + expected_total = metrics["loss/ddo_data"] + metrics["loss/ddo_ref"] + actual_total = metrics["loss/ddo_total"] + + # Should be close within floating point precision + assert abs(expected_total - actual_total) < 1e-5 + + def test_numerical_stability_extreme_values(self): + """Test numerical stability with extreme values""" + # Test with very large values + large_loss = torch.full((2, 4, 32, 32), 100.0) + large_ref_loss = torch.full((2, 4, 32, 32), 50.0) + + result_loss, metrics = ddo_loss(large_loss, large_ref_loss, w_t=1.0) + assert torch.isfinite(result_loss).all() + + # Test with very small values + small_loss = torch.full((2, 4, 32, 32), 1e-6) + small_ref_loss = torch.full((2, 4, 32, 32), 1e-7) + + result_loss, metrics = ddo_loss(small_loss, small_ref_loss, w_t=1.0) + assert torch.isfinite(result_loss).all() + + def test_zero_w_t(self, simple_tensors): + """Test with zero timestep weight""" + loss, ref_loss = simple_tensors + w_t = 0.0 + + result_loss, 
metrics = ddo_loss(loss, ref_loss, w_t) + + # With w_t=0, log likelihoods should be zero, leading to specific behavior + assert torch.isfinite(result_loss).all() + + # When w_t=0, target_logp = ref_logp = 0, so delta = 0, log_ratio = 0 + # sigmoid(0) = 0.5, so sigmoid_log_ratio should be 0.5 + assert abs(metrics["loss/ddo_sigmoid_log_ratio"] - 0.5) < 1e-5 + + def test_negative_w_t(self, simple_tensors): + """Test with negative timestep weight""" + loss, ref_loss = simple_tensors + w_t = -1.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + # Should handle negative weights gracefully + assert torch.isfinite(result_loss).all() + for key, value in metrics.items(): + assert torch.isfinite(torch.tensor(value)) + + def test_gradient_flow(self, simple_tensors): + """Test that gradients flow properly through target loss only""" + loss, ref_loss = simple_tensors + loss.requires_grad_(True) + ref_loss.requires_grad_(True) + w_t = 1.0 + + result_loss, _ = ddo_loss(loss, ref_loss, w_t) + result_loss.sum().backward() + + # Check that gradients exist for target loss + assert loss.grad is not None + assert not torch.isnan(loss.grad).any() + + # Reference loss should not have gradients + assert ref_loss.grad is None or torch.allclose(ref_loss.grad, torch.zeros_like(ref_loss.grad)) + + @pytest.mark.parametrize( + "batch_size,channels,height,width", + [ + (1, 4, 32, 32), + (4, 4, 16, 16), + (2, 8, 64, 64), + (8, 4, 8, 8), + ], + ) + def test_different_tensor_shapes(self, batch_size, channels, height, width): + """Test with different tensor shapes""" + loss = torch.randn(batch_size, channels, height, width) + ref_loss = torch.randn(batch_size, channels, height, width) + w_t = 1.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t) + + assert torch.isfinite(result_loss).all() + assert result_loss.shape == torch.Size([batch_size]) + assert len(metrics) == 4 + + def test_device_compatibility(self, simple_tensors): + """Test that function works on different devices""" + 
loss, ref_loss = simple_tensors + w_t = 1.0 + + # Test on CPU + result_cpu, metrics_cpu = ddo_loss(loss, ref_loss, w_t) + assert result_cpu.device.type == "cpu" + + # Test on GPU if available + if torch.cuda.is_available(): + loss_gpu = loss.cuda() + ref_loss_gpu = ref_loss.cuda() + result_gpu, metrics_gpu = ddo_loss(loss_gpu, ref_loss_gpu, w_t) + assert result_gpu.device.type == "cuda" + + def test_reproducibility(self, simple_tensors): + """Test that results are reproducible with same inputs""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + # Run multiple times + result1, metrics1 = ddo_loss(loss, ref_loss, w_t) + result2, metrics2 = ddo_loss(loss, ref_loss, w_t) + + # Results should be identical (deterministic computation) + assert torch.allclose(result1, result2) + for key in metrics1: + assert abs(metrics1[key] - metrics2[key]) < 1e-6 + + def test_logsigmoid_stability(self, simple_tensors): + """Test that logsigmoid operations are numerically stable""" + loss, ref_loss = simple_tensors + w_t = 1.0 + + # Test with extreme beta that could cause numerical issues + extreme_beta_values = [0.001, 100.0] + + for beta in extreme_beta_values: + result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta) + + # All components should be finite + assert torch.isfinite(result_loss).all() + assert torch.isfinite(torch.tensor(metrics["loss/ddo_data"])) + assert torch.isfinite(torch.tensor(metrics["loss/ddo_ref"])) + + def test_alpha_zero_case(self, simple_tensors): + """Test the case when alpha = 0 (no reference loss term)""" + loss, ref_loss = simple_tensors + w_t = 1.0 + alpha = 0.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_alpha=alpha) + + # With alpha=0, ref loss term should be zero + assert abs(metrics["loss/ddo_ref"]) < 1e-6 + + # Total loss should equal data loss + assert abs(metrics["loss/ddo_total"] - metrics["loss/ddo_data"]) < 1e-5 + + def test_beta_zero_case(self, simple_tensors): + """Test the case when beta = 0 (no scaling of log 
ratio)""" + loss, ref_loss = simple_tensors + w_t = 1.0 + beta = 0.0 + + result_loss, metrics = ddo_loss(loss, ref_loss, w_t, ddo_beta=beta) + + # With beta=0, log_ratio=0, so sigmoid should be 0.5 + assert abs(metrics["loss/ddo_sigmoid_log_ratio"] - 0.5) < 1e-5 + + # All losses should be finite + assert torch.isfinite(result_loss).all() + + def test_discriminative_behavior(self): + """Test that DDO behaves as expected for discriminative training""" + # Create scenario where target model is better than reference + target_loss = torch.full((2, 4, 32, 32), 1.0) # Lower loss (better) + ref_loss = torch.full((2, 4, 32, 32), 2.0) # Higher loss (worse) + w_t = 1.0 + + result_loss, metrics = ddo_loss(target_loss, ref_loss, w_t) + + # When target is better, we expect specific behavior in the discriminator + assert torch.isfinite(result_loss).all() + + # The sigmoid ratio should reflect that target model is preferred + # (exact value depends on beta, but should be meaningful) + assert 0 <= metrics["loss/ddo_sigmoid_log_ratio"] <= 1 + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_diffusion_dpo.py b/tests/library/test_custom_train_functions_diffusion_dpo.py new file mode 100644 index 000000000..4c5cf6245 --- /dev/null +++ b/tests/library/test_custom_train_functions_diffusion_dpo.py @@ -0,0 +1,149 @@ +import pytest +import torch + +from library.custom_train_functions import diffusion_dpo_loss + + +def test_diffusion_dpo_loss_basic(): + # Test basic functionality with simple inputs + batch_size = 4 + channels = 3 + height, width = 8, 8 + + # Create dummy loss tensors + loss = torch.rand(batch_size, channels, height, width) + ref_loss = torch.rand(batch_size, channels, height, width) + beta_dpo = 0.1 + + result, metrics = diffusion_dpo_loss(loss, ref_loss, beta_dpo) + + # Check return types + assert isinstance(result, torch.Tensor) + assert isinstance(metrics, dict) + + # Check shape of 
result + assert result.shape == torch.Size([batch_size // 2]) + + # Check metrics + expected_keys = [ + "loss/diffusion_dpo_total_loss", + "loss/diffusion_dpo_ref_loss", + "loss/diffusion_dpo_implicit_acc", + ] + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], float) + + +def test_diffusion_dpo_loss_different_shapes(): + # Test with different tensor shapes + shapes = [ + (2, 3, 8, 8), # Small tensor + (4, 6, 16, 16), # Medium tensor + (6, 9, 32, 32), # Larger tensor + ] + + for shape in shapes: + loss = torch.rand(*shape) + ref_loss = torch.rand(*shape) + + result, metrics = diffusion_dpo_loss(loss, ref_loss, 0.1) + + # Result should have batch dimension halved + assert result.shape == torch.Size([shape[0] // 2]) + + # All metrics should be scalars + for val in metrics.values(): + assert isinstance(val, float) + + +def test_diffusion_dpo_loss_beta_values(): + # Test with different beta values + batch_size = 4 + channels = 3 + height, width = 8, 8 + + loss = torch.rand(batch_size, channels, height, width) + ref_loss = torch.rand(batch_size, channels, height, width) + + # Test with different beta values + beta_values = [0.0, 0.5, 1.0, 10.0] + results = [] + + for beta in beta_values: + result, _ = diffusion_dpo_loss(loss, ref_loss, beta) + results.append(result.mean().item()) + + # With different betas, results should vary + assert len(set(results)) > 1, "Different beta values should produce different results" + + +def test_diffusion_dpo_loss_implicit_acc(): + # Test implicit accuracy calculation + batch_size = 4 + channels = 3 + height, width = 8, 8 + + # Create controlled test data where winners have lower loss + loss_w = torch.ones(batch_size // 2, channels, height, width) * 0.2 + loss_l = torch.ones(batch_size // 2, channels, height, width) * 0.8 + loss = torch.cat([loss_w, loss_l], dim=0) + + # Make reference losses with opposite preference + ref_w = torch.ones(batch_size // 2, channels, height, width) * 0.8 + ref_l = 
torch.ones(batch_size // 2, channels, height, width) * 0.2 + ref_loss = torch.cat([ref_w, ref_l], dim=0) + + # With beta=1.0, model_diff and ref_diff are opposite, should give low accuracy + _, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0) + assert metrics["loss/diffusion_dpo_implicit_acc"] > 0.5 + + # With beta=-1.0, the sign is flipped, should give high accuracy + _, metrics = diffusion_dpo_loss(loss, ref_loss, -1.0) + assert metrics["loss/diffusion_dpo_implicit_acc"] < 0.5 + + +def test_diffusion_dpo_gradient_flow(): + # Test that gradients flow properly + batch_size = 4 + channels = 3 + height, width = 8, 8 + + # Create tensors that require gradients + loss = torch.rand(batch_size, channels, height, width, requires_grad=True) + ref_loss = torch.rand(batch_size, channels, height, width, requires_grad=False) + + # Compute loss + result, _ = diffusion_dpo_loss(loss, ref_loss, 0.1) + + # Backpropagate + result.mean().backward() + + # Verify gradients flowed through loss but not ref_loss + assert loss.grad is not None + assert ref_loss.grad is None # Reference loss should be detached + + +def test_diffusion_dpo_loss_chunking(): + # Test chunking functionality + batch_size = 4 + channels = 3 + height, width = 8, 8 + + # Create controlled inputs where first half is clearly different from second half + first_half = torch.zeros(batch_size // 2, channels, height, width) + second_half = torch.ones(batch_size // 2, channels, height, width) + + # Test that the function correctly chunks inputs + loss = torch.cat([first_half, second_half], dim=0) + ref_loss = torch.cat([first_half, second_half], dim=0) + + _result, metrics = diffusion_dpo_loss(loss, ref_loss, 1.0) + + # Since model_diff and ref_diff are identical, implicit acc should be 0.0 + assert abs(metrics["loss/diffusion_dpo_implicit_acc"]) < 1e-5 + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_mapo.py 
b/tests/library/test_custom_train_functions_mapo.py new file mode 100644 index 000000000..25a256c05 --- /dev/null +++ b/tests/library/test_custom_train_functions_mapo.py @@ -0,0 +1,121 @@ +import pytest +import torch +import numpy as np + +from library.custom_train_functions import mapo_loss + + +def test_mapo_loss_basic(): + batch_size = 8 # Must be even for chunking + channels = 4 + height, width = 64, 64 + + # Create dummy loss tensor with shape [B, C, H, W] + loss = torch.rand(batch_size, channels, height, width) + mapo_weight = 0.5 + result, metrics = mapo_loss(loss, mapo_weight) + + # Check return types + assert isinstance(result, torch.Tensor) + assert isinstance(metrics, dict) + + # Check required metrics are present + expected_keys = [ + "loss/mapo_total", + "loss/mapo_ratio", + "loss/mapo_w_loss", + "loss/mapo_l_loss", + "loss/mapo_win_score", + "loss/mapo_lose_score", + ] + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], float) + + +def test_mapo_loss_different_shapes(): + # Test with different tensor shapes + shapes = [ + (4, 4, 32, 32), # Small tensor + (8, 16, 64, 64), # Medium tensor + (12, 32, 128, 128), # Larger tensor + ] + for shape in shapes: + loss = torch.rand(*shape) + result, metrics = mapo_loss(loss, 0.5) + # The result should have dimension batch_size//2 + assert result.shape == torch.Size([shape[0] // 2]) + # All metrics should be scalars + for val in metrics.values(): + assert np.isscalar(val) + + +def test_mapo_loss_with_zero_weight(): + loss = torch.rand(8, 3, 64, 64) # Batch size must be even + result, metrics = mapo_loss(loss, 0.0) + + # With zero mapo_weight, ratio_loss should be zero + assert metrics["loss/mapo_ratio"] == 0.0 + + # result should be equal to loss_w (first half of the batch) + loss_w = loss[: loss.shape[0] // 2] + assert torch.allclose(result.mean(), loss_w.mean()) + + +def test_mapo_loss_with_different_timesteps(): + loss = torch.rand(8, 4, 32, 32) # Batch size must be even + # 
Test with different timestep values + timesteps = [1, 10, 100, 1000] + results = [] + for ts in timesteps: + result, metrics = mapo_loss(loss, 0.5, ts) + results.append(metrics["loss/mapo_ratio"]) + + # Check that the results are different for different timesteps + for i in range(1, len(results)): + assert results[i] != results[i - 1] + + +def test_mapo_loss_win_loss_scores(): + batch_size = 8 # Must be even + channels = 4 + height, width = 64, 64 + + # Create losses where winning examples have lower loss + w_loss = torch.ones(batch_size // 2, channels, height, width) * 0.1 + l_loss = torch.ones(batch_size // 2, channels, height, width) * 0.9 + + # Concatenate to create the full loss tensor + loss = torch.cat([w_loss, l_loss], dim=0) + + # Run the function + result, metrics = mapo_loss(loss, 0.5) + + # Win score should be higher than lose score (better performance) + assert metrics["loss/mapo_win_score"] > metrics["loss/mapo_lose_score"] + # Model losses for winners should be lower + assert metrics["loss/mapo_w_loss"] < metrics["loss/mapo_l_loss"] + + +def test_mapo_loss_gradient_flow(): + batch_size = 8 # Must be even + channels = 4 + height, width = 64, 64 + + # Create a loss tensor that requires grad + loss = torch.rand(batch_size, channels, height, width, requires_grad=True) + mapo_weight = 0.5 + + # Compute loss + result, _ = mapo_loss(loss, mapo_weight) + + # Compute mean for backprop + result.mean().backward() + + # If gradients flow, loss.grad should not be None + assert loss.grad is not None + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_sdpo.py b/tests/library/test_custom_train_functions_sdpo.py new file mode 100644 index 000000000..731ae1b6f --- /dev/null +++ b/tests/library/test_custom_train_functions_sdpo.py @@ -0,0 +1,254 @@ +import pytest +import torch + +from library.custom_train_functions import sdpo_loss + + +class TestSDPOLoss: + """Test suite for SDPO loss 
function""" + + @pytest.fixture + def sample_tensors(self): + """Create sample tensors for testing image latent tensors""" + # Image latent tensor dimensions + batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs + channels = 4 # Latent channels (e.g., VAE latent space) + height = 32 # Latent height + width = 32 # Latent width + + # Create tensors with shape [2*batch_size, channels, height, width] + # First half represents preferred (w), second half dispreferred (l) + loss = torch.randn(2 * batch_size, channels, height, width) + ref_loss = torch.randn(2 * batch_size, channels, height, width) + + return loss, ref_loss + + @pytest.fixture + def simple_tensors(self): + """Create simple tensors for basic testing""" + # Create tensors with shape (2, 4, 32, 32) + # First tensor (batch 0) + batch_0 = torch.full((4, 32, 32), 1.0) + batch_0[1] = 2.0 # Second channel + batch_0[2] = 2.0 # Third channel + batch_0[3] = 3.0 # Fourth channel + + # Second tensor (batch 1) + batch_1 = torch.full((4, 32, 32), 3.0) + batch_1[1] = 4.0 + batch_1[2] = 5.0 + batch_1[3] = 2.0 + + loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32) + + # Reference loss tensor + ref_batch_0 = torch.full((4, 32, 32), 0.5) + ref_batch_0[1] = 1.5 + ref_batch_0[2] = 3.5 + ref_batch_0[3] = 9.5 + + ref_batch_1 = torch.full((4, 32, 32), 2.5) + ref_batch_1[1] = 3.5 + ref_batch_1[2] = 4.5 + ref_batch_1[3] = 3.5 + + ref_loss = torch.stack([ref_batch_0, ref_batch_1], dim=0) # Shape: (2, 4, 32, 32) + + return loss, ref_loss + + def test_basic_functionality(self, simple_tensors): + """Test basic functionality with simple inputs""" + loss, ref_loss = simple_tensors + + print(loss.shape, ref_loss.shape) + + result_loss, metrics = sdpo_loss(loss, ref_loss) + + # Check return types + assert isinstance(result_loss, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape (should be scalar after mean reduction) + assert result_loss.shape == torch.Size([1]) + + # Check that 
loss is finite and positive + assert torch.isfinite(result_loss) + assert result_loss >= 0 + + def test_metrics_keys(self, simple_tensors): + """Test that all expected metrics are returned""" + loss, ref_loss = simple_tensors + + _, metrics = sdpo_loss(loss, ref_loss) + + expected_keys = [ + "loss/sdpo_log_ratio_w", + "loss/sdpo_log_ratio_l", + "loss/sdpo_w_theta_max", + "loss/sdpo_w_theta_w", + "loss/sdpo_w_theta_l", + ] + + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], (int, float)) + assert not torch.isnan(torch.tensor(metrics[key])) + + def test_different_beta_values(self, simple_tensors): + """Test with different beta values""" + loss, ref_loss = simple_tensors + + print(loss.shape, ref_loss.shape) + + beta_values = [0.01, 0.02, 0.05, 0.1] + results = [] + + for beta in beta_values: + result_loss, _ = sdpo_loss(loss, ref_loss, beta=beta) + results.append(result_loss.item()) + + # Results should be different for different beta values + assert len(set(results)) == len(beta_values) + + def test_different_epsilon_values(self, simple_tensors): + """Test with different epsilon values""" + loss, ref_loss = simple_tensors + + epsilon_values = [0.05, 0.1, 0.2, 0.5] + results = [] + + for epsilon in epsilon_values: + result_loss, _ = sdpo_loss(loss, ref_loss, epsilon=epsilon) + results.append(result_loss.item()) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_tensor_chunking(self, sample_tensors): + """Test that tensor chunking works correctly""" + loss, ref_loss = sample_tensors + + result_loss, metrics = sdpo_loss(loss, ref_loss) + + # The function should handle chunking internally + assert torch.isfinite(result_loss) + assert len(metrics) == 5 + + def test_gradient_flow(self, simple_tensors): + """Test that gradients can flow through the loss""" + loss, ref_loss = simple_tensors + loss.requires_grad_(True) + ref_loss.requires_grad_(True) + + result_loss, _ = 
sdpo_loss(loss, ref_loss) + result_loss.backward() + + # Check that gradients exist + assert loss.grad is not None + assert ref_loss.grad is not None + assert not torch.isnan(loss.grad).any() + assert not torch.isnan(ref_loss.grad).any() + + def test_numerical_stability(self): + """Test numerical stability with extreme values""" + # Test with very large values + large_loss = torch.full((4, 2, 32, 32), 100.0) + large_ref_loss = torch.full((4, 2, 32, 32), 50.0) + + result_loss, metrics = sdpo_loss(large_loss, large_ref_loss) + assert torch.isfinite(result_loss.mean()) + + # Test with very small values + small_loss = torch.full((4, 2, 32, 32), 1e-6) + small_ref_loss = torch.full((4, 2, 32, 32), 1e-7) + + result_loss, metrics = sdpo_loss(small_loss, small_ref_loss) + assert torch.isfinite(result_loss.mean()) + + def test_zero_inputs(self): + """Test with zero inputs""" + zero_loss = torch.zeros(4, 2, 32, 32) + zero_ref_loss = torch.zeros(4, 2, 32, 32) + + result_loss, metrics = sdpo_loss(zero_loss, zero_ref_loss) + + # Should handle zero inputs gracefully + assert torch.isfinite(result_loss.mean()) + for key, value in metrics.items(): + assert torch.isfinite(torch.tensor(value)) + + def test_asymmetric_preference(self): + """Test that the function properly handles preferred vs dispreferred samples""" + # Create scenario where preferred samples have lower loss + loss_w = torch.tensor([[[[1.0, 1.0]]]]) # preferred (lower loss) + loss_l = torch.tensor([[[[2.0, 3.0]]]]) # dispreferred (higher loss) + loss = torch.cat([loss_w, loss_l], dim=0) + + ref_loss_w = torch.tensor([[[[2.0, 2.0]]]]) + ref_loss_l = torch.tensor([[[[2.0, 2.0]]]]) + ref_loss = torch.cat([ref_loss_w, ref_loss_l], dim=0) + + result_loss, metrics = sdpo_loss(loss, ref_loss) + + # The loss should be finite and reflect the preference structure + assert torch.isfinite(result_loss) + assert result_loss >= 0 + + # Log ratios should reflect the preference structure + assert metrics["loss/sdpo_log_ratio_w"] > 
metrics["loss/sdpo_log_ratio_l"] + + @pytest.mark.parametrize( + "batch_size,channel,height,width", + [ + (2, 4, 16, 16), + (8, 16, 32, 32), + (4, 4, 16, 16), + ], + ) + def test_different_tensor_shapes(self, batch_size, channel, height, width): + """Test with different tensor shapes""" + loss = torch.randn(2 * batch_size, channel, height, width) + ref_loss = torch.randn(2 * batch_size, channel, height, width) + + result_loss, metrics = sdpo_loss(loss, ref_loss) + + assert torch.isfinite(result_loss.mean()) + assert result_loss.shape == torch.Size([batch_size]) + assert len(metrics) == 5 + + def test_device_compatibility(self, simple_tensors): + """Test that function works on different devices""" + loss, ref_loss = simple_tensors + + # Test on CPU + result_cpu, metrics_cpu = sdpo_loss(loss, ref_loss) + assert result_cpu.device.type == "cpu" + + # Test on GPU if available + if torch.cuda.is_available(): + loss_gpu = loss.cuda() + ref_loss_gpu = ref_loss.cuda() + result_gpu, metrics_gpu = sdpo_loss(loss_gpu, ref_loss_gpu) + assert result_gpu.device.type == "cuda" + + def test_reproducibility(self, simple_tensors): + """Test that results are reproducible with same inputs""" + loss, ref_loss = simple_tensors + + # Run multiple times with same seed + torch.manual_seed(42) + result1, metrics1 = sdpo_loss(loss, ref_loss) + + torch.manual_seed(42) + result2, metrics2 = sdpo_loss(loss, ref_loss) + + # Results should be identical + assert torch.allclose(result1, result2) + for key in metrics1: + assert abs(metrics1[key] - metrics2[key]) < 1e-6 + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_custom_train_functions_simpo.py b/tests/library/test_custom_train_functions_simpo.py new file mode 100644 index 000000000..173142b27 --- /dev/null +++ b/tests/library/test_custom_train_functions_simpo.py @@ -0,0 +1,537 @@ +import pytest +import torch +import torch.nn.functional as F + +from library.custom_train_functions 
import simpo_loss + + +class TestSimPOLoss: + """Test suite for SimPO (Simple Preference Optimization) loss function""" + + @pytest.fixture + def sample_tensors(self): + """Create sample tensors for testing image latent tensors""" + # Image latent tensor dimensions + batch_size = 1 # Will be doubled to 2 for preferred/dispreferred pairs + channels = 4 # Latent channels (e.g., VAE latent space) + height = 32 # Latent height + width = 32 # Latent width + + # Create tensors with shape [2*batch_size, channels, height, width] + # First half represents preferred (w), second half dispreferred (l) + loss = torch.randn(2 * batch_size, channels, height, width) + + return loss + + @pytest.fixture + def simple_tensors(self): + """Create simple tensors for basic testing""" + # Create tensors with shape (2, 4, 32, 32) + # First tensor (batch 0) - preferred (lower loss is better) + batch_0 = torch.full((4, 32, 32), 1.0) + batch_0[1] = 0.8 + batch_0[2] = 1.2 + batch_0[3] = 0.9 + + # Second tensor (batch 1) - dispreferred (higher loss) + batch_1 = torch.full((4, 32, 32), 2.5) + batch_1[1] = 2.8 + batch_1[2] = 2.2 + batch_1[3] = 2.7 + + loss = torch.stack([batch_0, batch_1], dim=0) # Shape: (2, 4, 32, 32) + + return loss + + def test_basic_functionality_sigmoid(self, simple_tensors): + """Test basic functionality with sigmoid loss type""" + loss = simple_tensors + + result_losses, metrics = simpo_loss(loss, loss_type="sigmoid") + + # Check return types + assert isinstance(result_losses, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape (should match input preferred/dispreferred batch size) + loss_w, _ = loss.chunk(2) + assert result_losses.shape == loss_w.shape + + # Check that losses are finite + assert torch.isfinite(result_losses).all() + + def test_basic_functionality_hinge(self, simple_tensors): + """Test basic functionality with hinge loss type""" + loss = simple_tensors + + result_losses, metrics = simpo_loss(loss, loss_type="hinge") + + # Check return 
types + assert isinstance(result_losses, torch.Tensor) + assert isinstance(metrics, dict) + + # Check tensor shape + loss_w, _ = loss.chunk(2) + assert result_losses.shape == loss_w.shape + + # Check that losses are finite and non-negative (ReLU property) + assert torch.isfinite(result_losses).all() + assert (result_losses >= 0).all() + + def test_metrics_keys(self, simple_tensors): + """Test that all expected metrics are returned""" + loss = simple_tensors + + _, metrics = simpo_loss(loss) + + expected_keys = ["loss/simpo_chosen_rewards", "loss/simpo_rejected_rewards", "loss/simpo_logratio"] + + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], (int, float)) + assert torch.isfinite(torch.tensor(metrics[key])) + + def test_loss_type_parameter(self, simple_tensors): + """Test different loss types produce different results""" + loss = simple_tensors + + sigmoid_losses, sigmoid_metrics = simpo_loss(loss, loss_type="sigmoid") + hinge_losses, hinge_metrics = simpo_loss(loss, loss_type="hinge") + + # Results should be different + assert not torch.allclose(sigmoid_losses, hinge_losses) + + # But metrics should be the same (they don't depend on loss type) + assert sigmoid_metrics["loss/simpo_chosen_rewards"] == hinge_metrics["loss/simpo_chosen_rewards"] + assert sigmoid_metrics["loss/simpo_rejected_rewards"] == hinge_metrics["loss/simpo_rejected_rewards"] + assert sigmoid_metrics["loss/simpo_logratio"] == hinge_metrics["loss/simpo_logratio"] + + def test_invalid_loss_type(self, simple_tensors): + """Test that invalid loss type raises ValueError""" + loss = simple_tensors + + with pytest.raises(ValueError, match="Unknown loss type: invalid"): + simpo_loss(loss, loss_type="invalid") + + def test_gamma_beta_ratio_effect(self, simple_tensors): + """Test that gamma_beta_ratio parameter affects results""" + loss = simple_tensors + + results = [] + gamma_ratios = [0.0, 0.25, 0.5, 1.0] + + for gamma_ratio in gamma_ratios: + result_losses, _ = 
simpo_loss(loss, gamma_beta_ratio=gamma_ratio) + results.append(result_losses.mean().item()) + + # Results should be different for different gamma_beta_ratio values + assert len(set(results)) == len(gamma_ratios) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_beta_parameter_effect(self, simple_tensors): + """Test that beta parameter affects results""" + loss = simple_tensors + + results = [] + beta_values = [0.1, 0.5, 1.0, 2.0, 5.0] + + for beta in beta_values: + result_losses, _ = simpo_loss(loss, beta=beta) + results.append(result_losses.mean().item()) + + # Results should be different for different beta values + assert len(set(results)) == len(beta_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_smoothing_parameter_sigmoid(self, simple_tensors): + """Test smoothing parameter with sigmoid loss""" + loss = simple_tensors + + # Test different smoothing values + smoothing_values = [0.0, 0.1, 0.3, 0.5] + results = [] + + for smoothing in smoothing_values: + result_losses, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=smoothing) + results.append(result_losses.mean().item()) + + # Results should be different for different smoothing values + assert len(set(results)) == len(smoothing_values) + + # All results should be finite + for result in results: + assert torch.isfinite(torch.tensor(result)) + + def test_smoothing_parameter_hinge(self, simple_tensors): + """Test that smoothing parameter doesn't affect hinge loss""" + loss = simple_tensors + + # Smoothing should not affect hinge loss + result_no_smooth, _ = simpo_loss(loss, loss_type="hinge", smoothing=0.0) + result_with_smooth, _ = simpo_loss(loss, loss_type="hinge", smoothing=0.5) + + # Results should be identical for hinge loss regardless of smoothing + assert torch.allclose(result_no_smooth, result_with_smooth) + + def test_tensor_chunking(self, 
sample_tensors): + """Test that tensor chunking works correctly""" + loss = sample_tensors + + result_losses, metrics = simpo_loss(loss) + + # The function should handle chunking internally + assert torch.isfinite(result_losses).all() + assert len(metrics) == 3 + + # Verify chunking produces correct shapes + loss_w, loss_l = loss.chunk(2) + assert loss_w.shape == loss_l.shape + assert loss_w.shape[0] == loss.shape[0] // 2 + assert result_losses.shape == loss_w.shape + + def test_logits_computation(self, simple_tensors): + """Test the logits computation (pi_logratios - gamma_beta_ratio)""" + loss = simple_tensors + gamma_beta_ratio = 0.25 + + _, metrics = simpo_loss(loss, gamma_beta_ratio=gamma_beta_ratio) + + # Manually compute logits + loss_w, loss_l = loss.chunk(2) + pi_logratios = loss_w - loss_l + expected_logits = pi_logratios - gamma_beta_ratio + + # The logratio metric should match our manual pi_logratios computation + # (Note: metric includes beta scaling) + beta = 2.0 # default beta + expected_logratio_metric = (beta * expected_logits).mean().item() + + assert abs(metrics["loss/simpo_logratio"] - expected_logratio_metric) < 1e-5 + + def test_sigmoid_loss_manual_computation(self, simple_tensors): + """Test sigmoid loss computation matches manual calculation""" + loss = simple_tensors + beta = 2.0 + gamma_beta_ratio = 0.25 + smoothing = 0.1 + + result_losses, _ = simpo_loss(loss, loss_type="sigmoid", beta=beta, gamma_beta_ratio=gamma_beta_ratio, smoothing=smoothing) + + # Manual computation + loss_w, loss_l = loss.chunk(2) + pi_logratios = loss_w - loss_l + logits = pi_logratios - gamma_beta_ratio + expected_losses = -F.logsigmoid(beta * logits) * (1 - smoothing) - F.logsigmoid(-beta * logits) * smoothing + + assert torch.allclose(result_losses, expected_losses, atol=1e-6) + + def test_hinge_loss_manual_computation(self, simple_tensors): + """Test hinge loss computation matches manual calculation""" + loss = simple_tensors + beta = 2.0 + gamma_beta_ratio = 
0.25 + + result_losses, _ = simpo_loss(loss, loss_type="hinge", beta=beta, gamma_beta_ratio=gamma_beta_ratio) + + # Manual computation + loss_w, loss_l = loss.chunk(2) + pi_logratios = loss_w - loss_l + logits = pi_logratios - gamma_beta_ratio + expected_losses = torch.relu(1 - beta * logits) + + assert torch.allclose(result_losses, expected_losses, atol=1e-6) + + def test_reward_metrics_computation(self, simple_tensors): + """Test that reward metrics are computed correctly""" + loss = simple_tensors + beta = 2.0 + + _, metrics = simpo_loss(loss, beta=beta) + + # Manual computation of rewards + loss_w, loss_l = loss.chunk(2) + expected_chosen_rewards = (beta * loss_w.detach()).mean().item() + expected_rejected_rewards = (beta * loss_l.detach()).mean().item() + + assert abs(metrics["loss/simpo_chosen_rewards"] - expected_chosen_rewards) < 1e-6 + assert abs(metrics["loss/simpo_rejected_rewards"] - expected_rejected_rewards) < 1e-6 + + def test_gradient_flow(self, simple_tensors): + """Test that gradients flow properly through the loss""" + loss = simple_tensors + loss.requires_grad_(True) + + result_losses, _ = simpo_loss(loss) + + # Sum losses to get scalar for backward pass + total_loss = result_losses.sum() + total_loss.backward() + + # Check that gradients exist + assert loss.grad is not None + assert not torch.isnan(loss.grad).any() + assert torch.isfinite(loss.grad).all() + + def test_preferred_vs_dispreferred_structure(self): + """Test that the function properly handles preferred vs dispreferred samples""" + # Create scenario where preferred samples have lower loss (better) + loss_w = torch.full((1, 4, 32, 32), 1.0) # preferred (lower loss) + loss_l = torch.full((1, 4, 32, 32), 3.0) # dispreferred (higher loss) + loss = torch.cat([loss_w, loss_l], dim=0) + + result_losses, metrics = simpo_loss(loss) + + # The losses should be finite + assert torch.isfinite(result_losses).all() + + # With preferred having lower loss, pi_logratios should be negative + # This 
should lead to specific behavior in the loss computation + pi_logratios = loss_w - loss_l # Should be negative (1.0 - 3.0 = -2.0) + + assert pi_logratios.mean() == -2.0 + + # Chosen rewards should be lower than rejected rewards (since loss_w < loss_l) + assert metrics["loss/simpo_chosen_rewards"] < metrics["loss/simpo_rejected_rewards"] + + def test_equal_losses_case(self): + """Test behavior when preferred and dispreferred losses are equal""" + # Create scenario where preferred and dispreferred have same loss + loss_w = torch.full((1, 4, 32, 32), 2.0) + loss_l = torch.full((1, 4, 32, 32), 2.0) + loss = torch.cat([loss_w, loss_l], dim=0) + + result_losses, metrics = simpo_loss(loss) + + # pi_logratios should be zero + assert torch.isfinite(result_losses).all() + + # Chosen and rejected rewards should be equal + assert abs(metrics["loss/simpo_chosen_rewards"] - metrics["loss/simpo_rejected_rewards"]) < 1e-6 + + # Logratio should reflect the gamma_beta_ratio offset + gamma_beta_ratio = 0.25 # default + beta = 2.0 # default + expected_logratio = -beta * gamma_beta_ratio # Since pi_logratios = 0 + assert abs(metrics["loss/simpo_logratio"] - expected_logratio) < 1e-6 + + def test_numerical_stability_extreme_values(self): + """Test numerical stability with extreme values""" + # Test with very large values + large_loss = torch.full((2, 4, 32, 32), 100.0) + result_losses, _ = simpo_loss(large_loss) + assert torch.isfinite(result_losses).all() + + # Test with very small values + small_loss = torch.full((2, 4, 32, 32), 1e-6) + result_losses, _ = simpo_loss(small_loss) + assert torch.isfinite(result_losses).all() + + # Test with negative values + negative_loss = torch.full((2, 4, 32, 32), -10.0) + result_losses, _ = simpo_loss(negative_loss) + assert torch.isfinite(result_losses).all() + + def test_zero_beta_case(self, simple_tensors): + """Test the case when beta = 0""" + loss = simple_tensors + beta = 0.0 + + result_losses, metrics = simpo_loss(loss, beta=beta) + + # With 
beta=0, both loss types should give specific results + assert torch.isfinite(result_losses).all() + + # For sigmoid: logsigmoid(0) = log(0.5) ≈ -0.693 + # For hinge: relu(1 - 0) = 1 + + # Rewards should be zero + assert abs(metrics["loss/simpo_chosen_rewards"]) < 1e-6 + assert abs(metrics["loss/simpo_rejected_rewards"]) < 1e-6 + assert abs(metrics["loss/simpo_logratio"]) < 1e-6 + + def test_large_beta_case(self, simple_tensors): + """Test the case with very large beta""" + loss = simple_tensors + beta = 1000.0 + + result_losses, metrics = simpo_loss(loss, beta=beta) + + # Even with large beta, should remain stable + assert torch.isfinite(result_losses).all() + assert torch.isfinite(torch.tensor(metrics["loss/simpo_chosen_rewards"])) + assert torch.isfinite(torch.tensor(metrics["loss/simpo_rejected_rewards"])) + assert torch.isfinite(torch.tensor(metrics["loss/simpo_logratio"])) + + @pytest.mark.parametrize( + "batch_size,channels,height,width", + [ + (1, 4, 32, 32), + (2, 4, 16, 16), + (4, 8, 64, 64), + (8, 4, 8, 8), + ], + ) + def test_different_tensor_shapes(self, batch_size, channels, height, width): + """Test with different tensor shapes""" + # Note: batch_size will be doubled for preferred/dispreferred pairs + loss = torch.randn(2 * batch_size, channels, height, width) + + result_losses, metrics = simpo_loss(loss) + + assert torch.isfinite(result_losses).all() + assert result_losses.shape == (batch_size, channels, height, width) + assert len(metrics) == 3 + + def test_device_compatibility(self, simple_tensors): + """Test that function works on different devices""" + loss = simple_tensors + + # Test on CPU + result_cpu, _ = simpo_loss(loss) + assert result_cpu.device.type == "cpu" + + # Test on GPU if available + if torch.cuda.is_available(): + loss_gpu = loss.cuda() + result_gpu, _ = simpo_loss(loss_gpu) + assert result_gpu.device.type == "cuda" + + def test_reproducibility(self, simple_tensors): + """Test that results are reproducible with same inputs""" + 
loss = simple_tensors + + # Run multiple times + result1, metrics1 = simpo_loss(loss) + result2, metrics2 = simpo_loss(loss) + + # Results should be identical (deterministic computation) + assert torch.allclose(result1, result2) + for key in metrics1: + assert abs(metrics1[key] - metrics2[key]) < 1e-6 + + def test_no_reference_model_needed(self, simple_tensors): + """Test that SimPO works without reference model (key feature)""" + loss = simple_tensors + + # SimPO should work with just the loss tensor, no reference needed + result_losses, metrics = simpo_loss(loss) + + # Should produce meaningful results without reference model + assert torch.isfinite(result_losses).all() + assert len(metrics) == 3 + assert all(key in metrics for key in ["loss/simpo_chosen_rewards", "loss/simpo_rejected_rewards", "loss/simpo_logratio"]) + + def test_smoothing_interpolation_sigmoid(self): + """Test that smoothing interpolates between positive and negative logsigmoid""" + loss_w = torch.full((1, 4, 32, 32), 1.0) + loss_l = torch.full((1, 4, 32, 32), 2.0) + loss = torch.cat([loss_w, loss_l], dim=0) + + # Test extreme smoothing values + no_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=0.0) + full_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=1.0) + half_smooth, _ = simpo_loss(loss, loss_type="sigmoid", smoothing=0.5) + + # With smoothing=0.5, result should be between the extremes + assert torch.isfinite(no_smooth).all() + assert torch.isfinite(full_smooth).all() + assert torch.isfinite(half_smooth).all() + + # The smoothed version should be different from both extremes + assert not torch.allclose(no_smooth, full_smooth) + assert not torch.allclose(half_smooth, no_smooth) + assert not torch.allclose(half_smooth, full_smooth) + + def test_hinge_loss_properties(self): + """Test specific properties of hinge loss""" + # Create scenario where logits > 1/beta (should give zero loss) + loss_w = torch.full((1, 4, 32, 32), -2.0) # Very low preferred loss + loss_l = 
torch.full((1, 4, 32, 32), 2.0) # High dispreferred loss + loss = torch.cat([loss_w, loss_l], dim=0) + + beta = 0.5 # Small beta + gamma_beta_ratio = 0.25 + + result_losses, _ = simpo_loss(loss, loss_type="hinge", beta=beta, gamma_beta_ratio=gamma_beta_ratio) + + # Calculate expected behavior + pi_logratios = loss_w - loss_l # -2 - 2 = -4 + logits = pi_logratios - gamma_beta_ratio # -4 - 0.25 = -4.25 + # relu(1 - 0.5 * (-4.25)) = relu(1 + 2.125) = relu(3.125) = 3.125 + + expected_value = 1 - beta * logits # 1 - 0.5 * (-4.25) = 3.125 + assert torch.allclose(result_losses, expected_value) + + def test_edge_case_all_zeros(self): + """Test edge case with all zero losses""" + loss = torch.zeros(2, 4, 32, 32) + + result_losses, metrics = simpo_loss(loss) + + # Should handle all zeros gracefully + assert torch.isfinite(result_losses).all() + assert torch.isfinite(torch.tensor(metrics["loss/simpo_chosen_rewards"])) + assert torch.isfinite(torch.tensor(metrics["loss/simpo_rejected_rewards"])) + assert torch.isfinite(torch.tensor(metrics["loss/simpo_logratio"])) + + # With all zeros: chosen and rejected rewards should be zero + assert abs(metrics["loss/simpo_chosen_rewards"]) < 1e-6 + assert abs(metrics["loss/simpo_rejected_rewards"]) < 1e-6 + + def test_gamma_beta_ratio_as_margin(self): + """Test that gamma_beta_ratio acts as a margin in the logits""" + loss_w = torch.full((1, 4, 32, 32), 1.0) + loss_l = torch.full((1, 4, 32, 32), 1.0) # Equal losses + loss = torch.cat([loss_w, loss_l], dim=0) + + # With equal losses, pi_logratios = 0, so logits = -gamma_beta_ratio + gamma_ratios = [0.0, 0.5, 1.0] + + for gamma_ratio in gamma_ratios: + _, metrics = simpo_loss(loss, gamma_beta_ratio=gamma_ratio) + + # logratio should be -beta * gamma_ratio + beta = 2.0 # default + expected_logratio = -beta * gamma_ratio + assert abs(metrics["loss/simpo_logratio"] - expected_logratio) < 1e-6 + + def test_return_tensor_vs_scalar_difference_from_cpo(self): + """Test that SimPO returns tensor 
losses (not scalar like some other methods)""" + loss = torch.randn(2, 4, 32, 32) + + result_losses, _ = simpo_loss(loss) + + # SimPO should return tensor with same shape as preferred batch + loss_w, _ = loss.chunk(2) + assert result_losses.shape == loss_w.shape + assert result_losses.dim() > 0 # Not a scalar + + @pytest.mark.parametrize("loss_type", ["sigmoid", "hinge"]) + def test_parameter_combinations(self, simple_tensors, loss_type): + """Test various parameter combinations work correctly""" + loss = simple_tensors + + # Test different parameter combinations + param_combinations = [ + {"beta": 0.5, "gamma_beta_ratio": 0.1, "smoothing": 0.0}, + {"beta": 2.0, "gamma_beta_ratio": 0.5, "smoothing": 0.1}, + {"beta": 5.0, "gamma_beta_ratio": 1.0, "smoothing": 0.3}, + ] + + for params in param_combinations: + result_losses, metrics = simpo_loss(loss, loss_type=loss_type, **params) + + assert torch.isfinite(result_losses).all() + assert len(metrics) == 3 + assert all(torch.isfinite(torch.tensor(v)) for v in metrics.values()) + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index 2ad7ce4ee..372293969 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -2,9 +2,10 @@ import torch from unittest.mock import MagicMock, patch from library.flux_train_utils import ( - get_noisy_model_input_and_timesteps, + get_noisy_model_input_and_timestep, ) + # Mock classes and functions class MockNoiseScheduler: def __init__(self, num_train_timesteps=1000): @@ -12,6 +13,9 @@ def __init__(self, num_train_timesteps=1000): self.config.num_train_timesteps = num_train_timesteps self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long) + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + # Create fixtures for commonly used objects @pytest.fixture @@ -66,13 +70,13 @@ def 
test_uniform_sampling(args, noise_scheduler, latents, noise, device): args.timestep_sampling = "uniform" dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) assert noisy_input.dtype == dtype - assert timesteps.dtype == dtype + assert timestep.dtype == dtype def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): @@ -80,11 +84,11 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): args.sigmoid_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) def test_shift_sampling(args, noise_scheduler, latents, noise, device): @@ -93,11 +97,11 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device): args.discrete_flow_shift = 3.1582 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape 
== (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): @@ -105,34 +109,34 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): args.sigmoid_scale = 1.0 dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) def test_weighting_scheme(args, noise_scheduler, latents, noise, device): # Mock the necessary functions for this specific test - with patch("library.flux_train_utils.compute_density_for_timestep_sampling", - return_value=torch.tensor([0.3, 0.7], device=device)), \ - patch("library.flux_train_utils.get_sigmas", - return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)): - + with ( + patch( + "library.flux_train_utils.compute_density_for_timestep_sampling", return_value=torch.tensor([0.3, 0.7], device=device) + ), + patch("library.flux_train_utils.get_sigmas", return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)), + ): + args.timestep_sampling = "other" # Will trigger the weighting scheme path args.weighting_scheme = "uniform" args.logit_mean = 0.0 args.logit_std = 1.0 args.mode_scale = 1.0 dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype - ) - + + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) + assert noisy_input.shape == 
latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) # Test IP noise options @@ -141,11 +145,11 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device): args.ip_noise_gamma_random_strength = False dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): @@ -153,21 +157,21 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): args.ip_noise_gamma_random_strength = True dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) + assert timestep.shape == (latents.shape[0],) + assert sigma.shape == (latents.shape[0], 1, 1, 1) # Test different data types def test_float16_dtype(args, noise_scheduler, latents, noise, device): dtype = torch.float16 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert 
noisy_input.dtype == dtype - assert timesteps.dtype == dtype + assert timestep.dtype == dtype # Test different batch sizes @@ -176,11 +180,11 @@ def test_different_batch_size(args, noise_scheduler, device): noise = torch.randn(5, 4, 8, 8) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (5,) - assert sigmas.shape == (5, 1, 1, 1) + assert timestep.shape == (5,) + assert sigma.shape == (5, 1, 1, 1) # Test different image sizes @@ -189,11 +193,11 @@ def test_different_image_size(args, noise_scheduler, device): noise = torch.randn(2, 4, 16, 16) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (2,) - assert sigmas.shape == (2, 1, 1, 1) + assert timestep.shape == (2,) + assert sigma.shape == (2, 1, 1, 1) # Test edge cases @@ -203,7 +207,7 @@ def test_zero_batch_size(args, noise_scheduler, device): noise = torch.randn(0, 4, 8, 8) dtype = torch.float32 - get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype) def test_different_timestep_count(args, device): @@ -212,9 +216,9 @@ def test_different_timestep_count(args, device): noise = torch.randn(2, 4, 8, 8) dtype = torch.float32 - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) + noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, 
noise_scheduler, latents, noise, device, dtype) assert noisy_input.shape == latents.shape - assert timesteps.shape == (2,) + assert timestep.shape == (2,) # Check that timesteps are within the proper range - assert torch.all(timesteps < 500) + assert torch.all(timestep < 500) diff --git a/train_network.py b/train_network.py index 7861e7404..e29053db7 100644 --- a/train_network.py +++ b/train_network.py @@ -36,8 +36,10 @@ import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( + PreferenceOptimization, apply_snr_weight, get_weighted_text_embeddings, + normalize_gradients, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, @@ -66,24 +68,9 @@ def generate_step_logs( lr_scheduler, lr_descriptions, optimizer=None, - keys_scaled=None, - mean_norm=None, - maximum_norm=None, - mean_grad_norm=None, - mean_combined_norm=None, ): logs = {"loss/current": current_loss, "loss/average": avr_loss} - if keys_scaled is not None: - logs["max_norm/keys_scaled"] = keys_scaled - logs["max_norm/max_key_norm"] = maximum_norm - if mean_norm is not None: - logs["norm/avg_key_norm"] = mean_norm - if mean_grad_norm is not None: - logs["norm/avg_grad_norm"] = mean_grad_norm - if mean_combined_norm is not None: - logs["norm/avg_combined_norm"] = mean_combined_norm - lrs = lr_scheduler.get_last_lr() for i, lr in enumerate(lrs): if lr_descriptions is not None: @@ -108,7 +95,11 @@ def generate_step_logs( if ( args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None ): # tracking d*lr value of unet. 
- logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] + + if "effective_lr" in optimizer.param_groups[i]: + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["effective_lr"] + else: + logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] else: idx = 0 if not args.network_train_unet_only: @@ -122,7 +113,10 @@ def generate_step_logs( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None: - logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] + if "effective_lr" in optimizer.param_groups[i]: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["effective_lr"] + else: + logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] return logs @@ -255,21 +249,25 @@ def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTe def get_noise_pred_and_target( self, - args, - accelerator, + args: argparse.Namespace, + accelerator: Accelerator, noise_scheduler, - latents, - batch, + latents: torch.FloatTensor, + batch: dict[str, torch.Tensor], text_encoder_conds, unet, network, - weight_dtype, - train_unet, + weight_dtype: torch.dtype, + train_unet: bool, is_train=True, - ): + timesteps=None, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, rand_timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + if timesteps is None: + 
timesteps = rand_timesteps # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -320,10 +318,10 @@ def get_noise_pred_and_target( ) network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + sigmas = timesteps / noise_scheduler.config.num_train_timesteps + return noise_pred, noisy_latents, target, sigmas, timesteps, None - return noise_pred, target, timesteps, None - - def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: + def post_process_loss(self, loss: torch.Tensor, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: @@ -380,10 +378,12 @@ def process_batch( is_train=True, train_text_encoder=True, train_unet=True, - ) -> torch.Tensor: + multipliers=1.0, + ) -> tuple[torch.Tensor, dict[str, float | int]]: """ Process a batch for the network """ + metrics: dict[str, float | int] = {} with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -447,7 +447,8 @@ def process_batch( text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + + noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -461,20 +462,60 @@ def process_batch( is_train=is_train, ) + losses: dict[str, torch.Tensor] = {} + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = 
loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights + if self.po.is_po(): + if self.po.is_reference(): + accelerator.unwrap_model(network).set_multiplier(0.0) + ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, ref_weighting = ( + self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=False, + timesteps=timesteps, + ) + ) + + # reset network multipliers + accelerator.unwrap_model(network).set_multiplier(1.0) + + ref_loss = train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c) + + if weighting is not None: + ref_loss = ref_loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + ref_loss = apply_masked_loss(ref_loss, batch) + loss, metrics_po = self.po(loss, ref_loss) + else: + loss, metrics_po = self.po(loss) + + metrics.update(metrics_po) + else: + loss = loss.mean([1, 2, 3]) loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - return loss.mean() + for k in losses.keys(): + losses[k] = self.post_process_loss(losses[k], args, timesteps, noise_scheduler, latents) + # if "loss_weights" in batch and len(batch["loss_weights"]) == loss.shape[0]: + # losses[k] *= batch["loss_weights"] # 各sampleごとのweight + + return loss.mean(), losses, metrics def train(self, args): session_id = random.randint(0, 2**32) @@ -1041,6 +1082,14 @@ def load_model_hook(models, input_dir): "ss_validate_every_n_epochs": args.validate_every_n_epochs, "ss_validate_every_n_steps": args.validate_every_n_steps, "ss_resize_interpolation": args.resize_interpolation, + "ss_mapo_beta": args.mapo_beta, + "ss_cpo_beta": args.cpo_beta, + 
"ss_bpo_beta": args.bpo_beta, + "ss_bpo_lambda": args.bpo_lambda, + "ss_sdpo_beta": args.sdpo_beta, + "ss_ddo_beta": args.ddo_beta, + "ss_ddo_alpha": args.ddo_alpha, + "ss_dpo_beta": args.beta_dpo, } self.update_metadata(metadata, args) # architecture specific metadata @@ -1261,6 +1310,11 @@ def load_model_hook(models, input_dir): val_step_loss_recorder = train_util.LossRecorder() val_epoch_loss_recorder = train_util.LossRecorder() + self.po = PreferenceOptimization(args) + + if self.po.is_po(): + logger.info(f"Preference optimization activated: {self.po.algo}") + del train_dataset_group if val_dataset_group is not None: del val_dataset_group @@ -1401,7 +1455,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen # preprocess batch for each model self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) - loss = self.process_batch( + loss, losses, metrics = self.process_batch( batch, text_encoders, unet, @@ -1420,8 +1474,14 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen ) accelerator.backward(loss) + + if args.norm_gradient: + normalize_gradients(network) + + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: self.all_reduce_network(accelerator, network) # sync DDP grad manually + if args.max_grad_norm != 0.0: params_to_clip = accelerator.unwrap_model(network).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) @@ -1435,29 +1495,31 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + max_mean_logs = {} if args.scale_weight_norms: keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( args.scale_weight_norms, accelerator.device ) - mean_grad_norm = None - mean_combined_norm = None max_mean_logs = {"Keys Scaled": 
keys_scaled, "Average key norm": mean_norm} - else: - if hasattr(network, "weight_norms"): - weight_norms = network.weight_norms() - mean_norm = weight_norms.mean().item() if weight_norms is not None else None - grad_norms = network.grad_norms() - mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None - combined_weight_norms = network.combined_weight_norms() - mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None - maximum_norm = weight_norms.max().item() if weight_norms is not None else None - keys_scaled = None - max_mean_logs = {} - else: - keys_scaled, mean_norm, maximum_norm = None, None, None - mean_grad_norm = None - mean_combined_norm = None - max_mean_logs = {} + metrics["max_norm/avg_key_norm"] = mean_norm + metrics["max_norm/max_key_norm"] = maximum_norm + metrics["max_norm/keys_scaled"] = keys_scaled + + if hasattr(network, "weight_norms"): + weight_norms = network.weight_norms() + if weight_norms is not None: + metrics["norm/avg_key_norm"] = weight_norms.mean().item() + metrics["norm/max_key_norm"] = weight_norms.max().item() + + grad_norms = network.grad_norms() + if grad_norms is not None: + metrics["norm/avg_grad_norm"] = grad_norms.mean().item() + metrics["norm/max_grad_norm"] = grad_norms.max().item() + + combined_weight_norms = network.combined_weight_norms() + if combined_weight_norms is not None: + metrics["norm/avg_combined_norm"] = combined_weight_norms.mean().item() + metrics["norm/max_combined_norm"] = combined_weight_norms.max().item() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -1500,13 +1562,8 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen lr_scheduler, lr_descriptions, optimizer, - keys_scaled, - mean_norm, - maximum_norm, - mean_grad_norm, - mean_combined_norm, ) - self.step_logging(accelerator, logs, global_step, epoch + 1) + self.step_logging(accelerator, 
{**logs, **metrics}, global_step, epoch + 1) # VALIDATION PER STEP: global_step is already incremented # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ... @@ -1532,7 +1589,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep - loss = self.process_batch( + loss, losses, val_metrics = self.process_batch( batch, text_encoders, unet, @@ -1610,7 +1667,7 @@ def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTen # temporary, for batch processing self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False) - loss = self.process_batch( + loss, losses, val_metrics = self.process_batch( batch, text_encoders, unet, @@ -1875,6 +1932,7 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します", ) + parser.add_argument("--norm_gradient", action="store_true", help="Normalize gradients to 1.0") return parser