distributed training #74

Merged: 155 commits merged into main on Feb 6, 2023
Conversation

@clementchadebec (Owner) commented Feb 3, 2023

New features

  • Pythae now supports distributed training (built on top of PyTorch DDP). A distributed training is launched from a training script in which the distributed environment variables are passed to a BaseTrainerConfig instance as follows:
training_config = BaseTrainerConfig(
    num_epochs=10,
    learning_rate=1e-3,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    dist_backend="nccl",     # distributed backend
    world_size=8,            # total number of GPUs to use (n_nodes x n_gpus_per_node)
    rank=0,                  # global rank (process/GPU id)
    local_rank=1,            # local rank of the process on its node
    master_addr="localhost", # master address
    master_port="12345",     # master port
)

The script can then be launched using a launcher such as srun. This module was tested in both mono-node multi-GPU and multi-node multi-GPU settings.
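For illustration only, here is a minimal sketch of how such a script might fill the distributed fields from a Slurm allocation. The environment variables used (SLURM_NTASKS, SLURM_PROCID, SLURM_LOCALID, MASTER_ADDR, MASTER_PORT) and their defaults are assumptions that depend on your cluster setup, not part of Pythae itself:

import os
from pythae.trainers import BaseTrainerConfig

training_config = BaseTrainerConfig(
    num_epochs=10,
    learning_rate=1e-3,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    dist_backend="nccl",
    # Slurm exposes the distributed topology through environment variables
    world_size=int(os.environ["SLURM_NTASKS"]),   # total number of processes (n_nodes x n_gpus_per_node)
    rank=int(os.environ["SLURM_PROCID"]),         # global process id
    local_rank=int(os.environ["SLURM_LOCALID"]),  # process id within its node
    master_addr=os.environ.get("MASTER_ADDR", "localhost"),
    master_port=os.environ.get("MASTER_PORT", "12345"),
)

Such a script would typically be submitted with something like srun python train.py inside a Slurm allocation, but the exact launch command depends on your scheduler.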

Major changes

  • The selection and definition of custom optimizers and schedulers has changed. It is no longer necessary to build the optimizer (resp. scheduler) and pass it to the Trainer. As of v0.1.0, the choice and parameters of the optimizers and schedulers can be passed directly to the TrainerConfig. See the changes below:

As of v0.1.0

my_model = VAE(model_config=model_config)
# Specify the classes and their params directly in the Trainer config
training_config = BaseTrainerConfig(
    ...,
    optimizer_cls="AdamW",
    optimizer_params={"betas": (0.91, 0.995)},
    scheduler_cls="MultiStepLR",
    scheduler_params={"milestones": [10, 20, 30], "gamma": 10**(-1/5)}
)
trainer = BaseTrainer(
    model=my_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    training_config=training_config
)
# Launch training
trainer.train()

Before v0.1.0

my_model = VAE(model_config=model_config)
training_config = BaseTrainerConfig(...)
### Optimizer
optimizer = torch.optim.AdamW(my_model.parameters(), lr=training_config.learning_rate, betas=(0.91, 0.995))
### Scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=10**(-1/5))
# Pass the instances to the Trainer
trainer = BaseTrainer(
    model=my_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    training_config=training_config,
    optimizer=optimizer,
    scheduler=scheduler
)
# Launch training
trainer.train()
  • The batch_size key is no longer available in the Trainer configurations. It is replaced by the keys per_device_train_batch_size and per_device_eval_batch_size, which specify the batch size per device. Please note that in a distributed setting with, for instance, 4 GPUs and a per_device_train_batch_size of 64, the effective (global) batch size is 4*64, i.e. the same as training on a single GPU with a batch_size of 256, as sketched below.
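A quick sketch of that arithmetic (the numbers are illustrative only):

per_device_train_batch_size = 64
world_size = 4  # 4 GPUs in total
# Each optimizer step sees the per-device batch from every process,
# so the effective (global) batch size is their product.
global_batch_size = per_device_train_batch_size * world_size  # 256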

Minor changes

  • Added the ability to specify the desired number of workers for data loading in the Trainer configuration under the keys train_dataloader_num_workers and eval_dataloader_num_workers (see the sketch after this list)
  • Cleaned up the __init__ of the Trainers and moved sanity checks from the train method to __init__
  • Moved the checks on optimizers and schedulers into the TrainerConfig's __post_init_post_parse__
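A minimal sketch of the new dataloader keys (the worker counts are illustrative only):

from pythae.trainers import BaseTrainerConfig

training_config = BaseTrainerConfig(
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    train_dataloader_num_workers=8,  # workers spawned for the training dataloader
    eval_dataloader_num_workers=4,   # workers spawned for the evaluation dataloader
)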

@clementchadebec clementchadebec added the enhancement (New feature or request) label Feb 4, 2023
@clementchadebec clementchadebec linked an issue Feb 6, 2023 that may be closed by this pull request
@clementchadebec clementchadebec marked this pull request as ready for review February 6, 2023 16:40
@clementchadebec clementchadebec merged commit 08f805e into main Feb 6, 2023
Successfully merging this pull request may close the following issue: Allow distributed training