@@ -161,12 +161,6 @@ def parse_args(input_args=None):
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
     parser.add_argument(
         "--gradient_checkpointing",
         action="store_true",
@@ -376,10 +370,8 @@ def main(args):
     else:
         colossalai.launch_from_torch(config={}, seed=args.seed)
 
-    colossalai.launch_from_torch(config={})
-
-    if args.seed is not None:
-        gpc.set_seed(args.seed)
+    local_rank = gpc.get_local_rank(ParallelMode.DATA)
+    world_size = gpc.get_world_size(ParallelMode.DATA)
 
     if args.with_prior_preservation:
         class_images_dir = Path(args.class_data_dir)
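
The new local_rank / world_size variables simply cache the data-parallel rank and world size once after launch instead of querying gpc at every call site. A condensed sketch of that initialization path, assuming ColossalAI is installed and the script is started through torchrun (import paths follow the example script and may differ between ColossalAI releases):

```python
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

def init_distributed(seed=None):
    # Mirror of the launch logic in the hunk above: the seed goes straight to launch_from_torch.
    if seed is None:
        colossalai.launch_from_torch(config={})
    else:
        colossalai.launch_from_torch(config={}, seed=seed)
    # Cache the data-parallel rank/world size once for later rank-0-only work.
    local_rank = gpc.get_local_rank(ParallelMode.DATA)
    world_size = gpc.get_world_size(ParallelMode.DATA)
    return local_rank, world_size
```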
@@ -408,7 +400,7 @@ def main(args):
             for example in tqdm(
                 sample_dataloader,
                 desc="Generating class images",
-                disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+                disable=not local_rank == 0,
             ):
                 images = pipeline(example["prompt"]).images
 
@@ -420,7 +412,7 @@ def main(args):
             del pipeline
 
     # Handle the repository creation
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         if args.push_to_hub:
             if args.hub_model_id is None:
                 repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
@@ -486,12 +478,7 @@ def main(args):
         unet.enable_gradient_checkpointing()
 
     if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate
-            * args.gradient_accumulation_steps
-            * args.train_batch_size
-            * gpc.get_world_size(ParallelMode.DATA)
-        )
+        args.learning_rate = args.learning_rate * args.train_batch_size * world_size
 
     unet = gemini_zero_dpp(unet, args.placement)
 
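The --scale_lr rule now multiplies the base learning rate only by the per-device batch size and the data-parallel world size, since the accumulation factor is gone. A quick worked example with illustrative numbers (not the script's defaults):

```python
# Linear LR scaling with the new formula: the effective LR grows with the global batch size.
base_lr = 5e-6          # hypothetical --learning_rate
train_batch_size = 2    # per-device batch size
world_size = 4          # data-parallel GPUs
scaled_lr = base_lr * train_batch_size * world_size
print(scaled_lr)        # 4e-05
```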
@@ -547,16 +534,16 @@ def collate_fn(examples):
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps,
+        num_training_steps=args.max_train_steps,
     )
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
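
Step bookkeeping after the change: with one update per batch, num_update_steps_per_epoch is simply the dataloader length (the math.ceil is a no-op on an integer, kept to mirror the script), and the scheduler sees --lr_warmup_steps and --max_train_steps directly. A small sketch with illustrative numbers:

```python
import math

num_batches_per_epoch = 200   # hypothetical len(train_dataloader)
num_train_epochs = 4
max_train_steps = None        # i.e. --max_train_steps not provided

num_update_steps_per_epoch = math.ceil(num_batches_per_epoch)  # 200; one update per batch
if max_train_steps is None:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch  # 800

print(num_update_steps_per_epoch, max_train_steps)  # 200 800
```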
@@ -571,26 +558,25 @@ def collate_fn(examples):
     text_encoder.to(get_current_device(), dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Train!
-    total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
+    total_batch_size = args.train_batch_size * world_size
 
     logger.info("***** Running training *****", ranks=[0])
     logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
     logger.info(f" Num batches each epoch = {len(train_dataloader)}", ranks=[0])
     logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
     logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
-    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
     logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
 
     # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
+    progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
     progress_bar.set_description("Steps")
     global_step = 0
 
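Likewise, the effective batch size reduces to per-device batch size times data-parallel world size, and the progress bar is shown only on local rank 0. A tiny sketch with illustrative values (tqdm is the dependency the script already uses):

```python
from tqdm.auto import tqdm

train_batch_size, world_size, local_rank = 2, 4, 0    # illustrative
total_batch_size = train_batch_size * world_size      # 8 samples per optimizer step
max_train_steps = 800

progress_bar = tqdm(range(max_train_steps), disable=not local_rank == 0)
progress_bar.set_description("Steps")
```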
@@ -607,7 +593,7 @@ def collate_fn(examples):
             optimizer.zero_grad()
 
             latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-            latents = latents * vae.config.scaling_factor
+            latents = latents * 0.18215
 
             # Sample noise that we'll add to the latents
             noise = torch.randn_like(latents)
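
The latent scaling constant is now hard-coded: 0.18215 is the standard Stable Diffusion VAE scaling factor, the same value that recent diffusers releases expose as vae.config.scaling_factor. A sketch of the encode-and-scale step (the function name and signature are illustrative):

```python
import torch

def encode_to_latents(vae, pixel_values: torch.Tensor, weight_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Encode images with the frozen VAE, sample from the latent distribution,
    # then apply the SD VAE scaling factor, hard-coded as in the hunk above.
    latents = vae.encode(pixel_values.to(dtype=weight_dtype)).latent_dist.sample()
    return latents * 0.18215
```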
@@ -667,7 +653,7 @@ def collate_fn(examples):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
                 torch_unet = get_static_torch_model(unet)
-                if gpc.get_local_rank(ParallelMode.DATA) == 0:
+                if local_rank == 0:
                     pipeline = DiffusionPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
                         unet=torch_unet,
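
The periodic checkpoint keeps the rank-0 pattern: synchronize, gather a plain torch UNet from the Gemini wrapper with get_static_torch_model (a collective call that every rank must enter), then let only rank 0 build and write the pipeline. A sketch of that pattern; the gather helper is assumed to be imported as in the script, and the save call is illustrative:

```python
from diffusers import DiffusionPipeline

def save_on_rank_zero(torch_unet, pretrained_model_name_or_path, output_dir, local_rank):
    # torch_unet is the plain nn.Module returned by get_static_torch_model(unet);
    # that gather (and torch.cuda.synchronize()) must already have run on every rank.
    if local_rank == 0:
        pipeline = DiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path,
            unet=torch_unet,
        )
        pipeline.save_pretrained(output_dir)  # illustrative: output path handling is simplified here
```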
@@ -682,7 +668,7 @@ def collate_fn(examples):
     torch.cuda.synchronize()
     unet = get_static_torch_model(unet)
 
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,