3 changes: 3 additions & 0 deletions helpers/multiaspect/dataset.py
@@ -87,6 +87,9 @@ def __len__(self):
return len(self.bucket_manager)

def __getitem__(self, image_path):
if image_path is False:
logger.debug(f'Received {image_path} instead of an image path; assuming this is the end of an epoch and passing it up the chain.')
return image_path
logger.debug(f"Running __getitem__ for {image_path} inside Dataloader.")
example = {"instance_images_path": image_path}
if not StateTracker.status_training():
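The change above is a sentinel pattern: when the sampler runs dry, False travels through the Dataset unchanged so the training loop can spot the epoch boundary. A minimal standalone sketch of the idea, using hypothetical names rather than the repo's classes:

class PathDataset:
    """Toy stand-in for the dataset: lets the False sentinel pass through."""

    def __getitem__(self, image_path):
        if image_path is False:
            # End-of-epoch sentinel from the sampler; hand it up unchanged.
            return image_path
        return {"instance_images_path": image_path}

def consume(dataset, keys):
    # Mimics the training loop: stop as soon as the sentinel arrives.
    for key in keys:
        item = dataset[key]
        if item is False:
            break
        yield item

dataset = PathDataset()
print(list(consume(dataset, ["a.png", "b.png", False])))  # two items, then stop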
2 changes: 2 additions & 0 deletions helpers/multiaspect/sampler.py
@@ -351,6 +351,8 @@ def __iter__(self):
f"All buckets exhausted - since this is happening now, most likely you have chronically-underfilled buckets."
)
self._reset_buckets()
# Return False so the training loop knows this epoch is over.
return False

def __len__(self):
return sum(
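Here the sampler resets its buckets and then signals the epoch boundary. One hedge: in the diff the sentinel comes from a return statement, while the sketch below uses yield instead so a plain for loop actually receives the False value; names and bucket layout are illustrative:

def bucket_sampler(buckets):
    # Drain the buckets round-robin, then emit the end-of-epoch sentinel.
    while any(buckets.values()):
        for resolution, paths in buckets.items():
            if paths:
                yield paths.pop()
    # In the real sampler the buckets are refilled here via _reset_buckets().
    yield False

for sample in bucket_sampler({"1024x1024": ["a.png"], "768x1024": ["b.png"]}):
    if sample is False:
        print("epoch complete")
        break
    print("training on", sample)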
4 changes: 3 additions & 1 deletion train_sd21.py
@@ -581,7 +581,9 @@ def main(args):
and epoch == first_epoch
and step < resume_step
):
if step + 2 == resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
if step + 1 == resume_step:
# We want to trigger the batch to be properly generated when we start.
if not StateTracker.status_training():
logging.info(
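The fix above ticks the progress bar once per gradient-accumulation boundary while skipping already-trained batches, since the bar counts optimizer steps rather than micro-batches. A self-contained sketch of that accounting, with illustrative numbers:

# Resuming at micro-batch 8 with accumulation of 4 should advance the
# progress bar by exactly 2 optimizer steps, not 8 micro-batches.
gradient_accumulation_steps = 4
resume_step = 8
progress_ticks = 0

for step in range(12):  # stand-in for enumerate(train_dataloader)
    if step < resume_step:
        if step % gradient_accumulation_steps == 0:
            progress_ticks += 1  # one tick per skipped optimizer step
        continue
    # ... normal training path would run from here ...

print(progress_ticks)  # -> 2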
24 changes: 11 additions & 13 deletions train_sdxl.py
@@ -773,19 +773,6 @@ def collate_fn(examples):
ema_unet = None
if args.use_ema:
logger.info("Using EMA. Creating EMAModel.")
decay = unet.config["decay"]
del unet.config["decay"]
min_decay = unet.config["min_decay"]
del unet.config["min_decay"]
update_after_step = unet.config["update_after_step"]
del unet.config["update_after_step"]
use_ema_warmup = unet.config["use_ema_warmup"]
del unet.config["use_ema_warmup"]
inv_gamma = unet.config["inv_gamma"]
del unet.config["inv_gamma"]
power = unet.config["power"]
del unet.config["power"]

ema_unet = EMAModel(
unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config
)
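For context on the deletion above: the removed lines pulled EMA hyperparameters out of unet.config (deleting each key as they went) before constructing the EMA model; after this change EMAModel simply falls back to its own defaults. If non-default values are wanted, diffusers' EMAModel accepts them as keyword arguments; a sketch with illustrative values:

from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# `unet` is the training script's UNet2DConditionModel; values are examples.
ema_unet = EMAModel(
    unet.parameters(),
    model_cls=UNet2DConditionModel,
    model_config=unet.config,
    decay=0.9999,            # ceiling on the EMA decay rate
    update_after_step=100,   # begin averaging only after this many steps
    use_ema_warmup=True,     # ramp decay up on the inv_gamma/power schedule
)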
@@ -975,13 +962,20 @@ def collate_fn(examples):
unet.train()
train_loss = 0.0
training_luminance_values = []
current_epoch_step = 0
for step, batch in enumerate(train_dataloader):
# If the dataloader hands us False, we have reached the end of the epoch.
if batch is False:
logger.info(f'Reached the end of epoch {epoch}')
break
# Skip steps until we reach the resumed step
if (
args.resume_from_checkpoint
and epoch == first_epoch
and step < resume_step
):
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
if step + 2 == resume_step:
# We want to trigger the batch to be properly generated when we start.
if not StateTracker.status_training():
@@ -1165,6 +1159,7 @@ def collate_fn(examples):
ema_unet.step(unet.parameters())
progress_bar.update(1)
global_step += 1
current_epoch_step += 1
# Average out the luminance values of each batch, so that we can store that in this step.
avg_training_data_luminance = sum(training_luminance_values) / len(
training_luminance_values
@@ -1221,6 +1216,9 @@ def collate_fn(examples):
state_path=os.path.join(save_path, "training_state.json"),
)
logger.info(f"Saved state to {save_path}")
if current_epoch_step > num_update_steps_per_epoch:
logger.info(f'Epoch {epoch} is now complete, having observed {current_epoch_step}/{num_update_steps_per_epoch} update steps this epoch.')
break

logs = {
"step_loss": loss.detach().item(),
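Finally, because the sampler can now run past an epoch boundary, the loop tracks its own update count and breaks once a full epoch's worth of steps has been observed. A self-contained sketch of that guard, with illustrative numbers:

num_update_steps_per_epoch = 10
current_epoch_step = 0

for step in range(100):  # stand-in for the dataloader
    current_epoch_step += 1
    if current_epoch_step > num_update_steps_per_epoch:
        # Mirrors the diff's strict comparison: breaks on the 11th step.
        print(f"epoch complete after observing {current_epoch_step} steps")
        break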