Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 44 additions & 57 deletions library/lumina_train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,11 +475,7 @@ def sample_image_inference(


def time_shift(mu: float, sigma: float, t: torch.Tensor):
# the following implementation was original for t=0: clean / t=1: noise
# Since we adopt the reverse, the 1-t operations are needed
t = 1 - t
t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
t = 1 - t
return t


Expand Down Expand Up @@ -802,61 +798,42 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor
weighting = torch.ones_like(sigmas)
return weighting


# mainly copied from flux_train_utils.get_noisy_model_input_and_timesteps
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Get noisy model input and timesteps.

Args:
args (argparse.Namespace): Arguments.
noise_scheduler (noise_scheduler): Noise scheduler.
latents (Tensor): Latents.
noise (Tensor): Latent noise.
device (torch.device): Device.
dtype (torch.dtype): Data type

Return:
Tuple[Tensor, Tensor, Tensor]:
noisy model input
timesteps
sigmas
"""
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None

assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)

timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = (
logits_norm * args.sigmoid_scale
) # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)

t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * noise + t * latents
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "nextdit_shift":
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
t = time_shift(mu, 1.0, t)

timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * noise + t * latents
sigmas = time_shift(mu, 1.0, sigmas)

timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "flux_shift":
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
Expand All @@ -867,14 +844,24 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)

# Add noise according to flow matching.
sigmas = get_sigmas(
noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
)
noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas

Expand Down Expand Up @@ -1049,10 +1036,10 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):

parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift", "flux_shift"],
default="shift",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid, Flux.1 and NextDIT.1 shifting. Default is 'shift'."
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、Flux.1、NextDIT.1のシフト。デフォルトは'shift'です。",
)
parser.add_argument(
"--sigmoid_scale",
Expand Down
2 changes: 1 addition & 1 deletion lumina_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ def grad_hook(parameter: torch.Tensor):
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = nextdit(
x=noisy_model_input, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(
dtype=torch.int32
Expand Down
2 changes: 1 addition & 1 deletion lumina_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps):
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = dit(
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
t=1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
)
Expand Down