Skip to content

Commit

Permalink
use mask loss PR kohya-ss#589 and save mask to npz
Browse files Browse the repository at this point in the history
  • Loading branch information
gesen2egee committed Feb 10, 2024
1 parent 5f6c5ff commit 041bd93
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 13 deletions.
2 changes: 1 addition & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask
Expand Down
19 changes: 12 additions & 7 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,8 +1131,8 @@ def __getitem__(self, index):

image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
mask = load_mask(image_info.absolute_path, image_info.resized_size) / 255
latents, original_size, crop_ltrb, flipped_latents, mask = load_latents_from_disk(image_info.latents_npz)
mask = mask / 255
if flipped:
latents = flipped_latents
mask = np.flip(mask, axis=1)
Expand Down Expand Up @@ -2001,7 +2001,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
def load_latents_from_disk(
npz_path,
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[np.ndarray]]:
npz = np.load(npz_path)
if "latents" not in npz:
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
Expand All @@ -2010,14 +2010,19 @@ def load_latents_from_disk(
original_size = npz["original_size"].tolist()
crop_ltrb = npz["crop_ltrb"].tolist()
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
return latents, original_size, crop_ltrb, flipped_latents
mask = npz["mask"] if "mask" in npz else None
if mask is not None:
mask = mask.astype(np.float32)
return latents, original_size, crop_ltrb, flipped_latents, mask


def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None):
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, mask=None):
kwargs = {}
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
np.savez(
if mask is not None:
kwargs["mask"] = np.array(mask, dtype=np.uint8)
np.savez_compressed(
npz_path,
latents=latents_tensor.float().cpu().numpy(),
original_size=np.array(original_size),
Expand Down Expand Up @@ -2322,7 +2327,7 @@ def cache_batch_latents(
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")

if cache_to_disk:
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, mask)
else:
info.latents = latent
info.mask = mask
Expand Down
2 changes: 1 addition & 1 deletion sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

target = noise

if args.masked_loss and batch['masks'] is not None:
if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask
Expand Down
2 changes: 1 addition & 1 deletion train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def train(args):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask
Expand Down
2 changes: 1 addition & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask
Expand Down
2 changes: 1 addition & 1 deletion train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask
Expand Down
2 changes: 1 addition & 1 deletion train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask
Expand Down

0 comments on commit 041bd93

Please sign in to comment.