diff --git a/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/examples/research_projects/colossalai/train_dreambooth_colossalai.py
index e5039f593d22..17212e84f8d6 100644
--- a/examples/research_projects/colossalai/train_dreambooth_colossalai.py
+++ b/examples/research_projects/colossalai/train_dreambooth_colossalai.py
@@ -161,12 +161,6 @@ def parse_args(input_args=None):
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
-    parser.add_argument(
-        "--gradient_accumulation_steps",
-        type=int,
-        default=1,
-        help="Number of updates steps to accumulate before performing a backward/update pass.",
-    )
     parser.add_argument(
         "--gradient_checkpointing",
         action="store_true",
@@ -376,10 +370,8 @@ def main(args):
     else:
         colossalai.launch_from_torch(config={}, seed=args.seed)
 
-    colossalai.launch_from_torch(config={})
-
-    if args.seed is not None:
-        gpc.set_seed(args.seed)
+    local_rank = gpc.get_local_rank(ParallelMode.DATA)
+    world_size = gpc.get_world_size(ParallelMode.DATA)
 
     if args.with_prior_preservation:
         class_images_dir = Path(args.class_data_dir)
@@ -408,7 +400,7 @@ def main(args):
             for example in tqdm(
                 sample_dataloader,
                 desc="Generating class images",
-                disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+                disable=not local_rank == 0,
             ):
                 images = pipeline(example["prompt"]).images
 
@@ -420,7 +412,7 @@ def main(args):
         del pipeline
 
     # Handle the repository creation
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         if args.push_to_hub:
             if args.hub_model_id is None:
                 repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
@@ -486,12 +478,7 @@ def main(args):
         unet.enable_gradient_checkpointing()
 
     if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate
-            * args.gradient_accumulation_steps
-            * args.train_batch_size
-            * gpc.get_world_size(ParallelMode.DATA)
-        )
+        args.learning_rate = args.learning_rate * args.train_batch_size * world_size
 
     unet = gemini_zero_dpp(unet, args.placement)
 
@@ -547,7 +534,7 @@ def collate_fn(examples):
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -555,8 +542,8 @@ def collate_fn(examples):
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
-        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+        num_warmup_steps=args.lr_warmup_steps,
+        num_training_steps=args.max_train_steps,
     )
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
@@ -571,14 +558,14 @@ def collate_fn(examples):
     text_encoder.to(get_current_device(), dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Train!
-    total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
+    total_batch_size = args.train_batch_size * world_size
 
     logger.info("***** Running training *****", ranks=[0])
     logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
@@ -586,11 +573,10 @@ def collate_fn(examples):
     logger.info(f"  Num Epochs = {args.num_train_epochs}", ranks=[0])
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
     logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])
 
     # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
+    progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
     progress_bar.set_description("Steps")
     global_step = 0
 
@@ -607,7 +593,7 @@ def collate_fn(examples):
             optimizer.zero_grad()
 
             latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-            latents = latents * vae.config.scaling_factor
+            latents = latents * 0.18215
 
             # Sample noise that we'll add to the latents
             noise = torch.randn_like(latents)
@@ -667,7 +653,7 @@ def collate_fn(examples):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
                 torch_unet = get_static_torch_model(unet)
-                if gpc.get_local_rank(ParallelMode.DATA) == 0:
+                if local_rank == 0:
                     pipeline = DiffusionPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
                         unet=torch_unet,
@@ -682,7 +668,7 @@ def collate_fn(examples):
 
     torch.cuda.synchronize()
     unet = get_static_torch_model(unet)
-    if gpc.get_local_rank(ParallelMode.DATA) == 0:
+    if local_rank == 0:
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             unet=unet,