Skip to content

Commit

Permalink
Merge pull request #738 from mv1388/print_only_on_main_ddp_gpu
Browse files Browse the repository at this point in the history
More descriptive tqdm bar and less tqdm clutter for DDP
  • Loading branch information
mv1388 committed Aug 18, 2022
2 parents 61ac989 + 6e8c298 commit b3cff2e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
45 changes: 35 additions & 10 deletions aitoolbox/torchtrain/train_loop/train_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __init__(self, model,

self.ddp_training_mode = False
self.ddp_handler: Optional[DDPHandler] = None
self.ddp_rank = None

self.callbacks = []
self.callbacks_handler = CallbacksHandler(self)
Expand Down Expand Up @@ -230,7 +231,8 @@ def _train(self, num_epochs, num_iterations, callbacks=None, grad_accumulation=1
print(f'Epoch: {self.epoch}')
self.callbacks_handler.execute_epoch_begin()

for self.iteration, batch_data in enumerate(tqdm(self.train_loader)):
for self.iteration, batch_data in enumerate(tqdm(self.train_loader,
desc='Training', disable=not self.is_main_process())):
self.total_iteration_idx += 1
self.callbacks_handler.execute_batch_begin()

Expand Down Expand Up @@ -488,7 +490,7 @@ def evaluate_loss_on_train_set(self, force_prediction=False):
float or dict: loss, in the case of multi loss, the dict gets returned
"""
if not self.prediction_store.has_train_loss(self.total_iteration_idx) or force_prediction:
loss = self.evaluate_model_loss(self.train_loader)
loss = self.evaluate_model_loss(self.train_loader, dataset_info={'type': 'train'})
self.prediction_store.insert_train_loss(loss, self.total_iteration_idx, force_prediction)
else:
loss = self.prediction_store.get_train_loss(self.total_iteration_idx)
Expand All @@ -506,7 +508,7 @@ def evaluate_loss_on_validation_set(self, force_prediction=False):
float or dict: loss, in the case of multi loss, the dict gets returned
"""
if not self.prediction_store.has_val_loss(self.total_iteration_idx) or force_prediction:
loss = self.evaluate_model_loss(self.validation_loader)
loss = self.evaluate_model_loss(self.validation_loader, dataset_info={'type': 'validation'})
self.prediction_store.insert_val_loss(loss, self.total_iteration_idx, force_prediction)
else:
loss = self.prediction_store.get_val_loss(self.total_iteration_idx)
Expand All @@ -524,22 +526,30 @@ def evaluate_loss_on_test_set(self, force_prediction=False):
float or dict: loss, in the case of multi loss, the dict gets returned
"""
if not self.prediction_store.has_test_loss(self.total_iteration_idx) or force_prediction:
loss = self.evaluate_model_loss(self.test_loader)
loss = self.evaluate_model_loss(self.test_loader, dataset_info={'type': 'test'})
self.prediction_store.insert_test_loss(loss, self.total_iteration_idx, force_prediction)
else:
loss = self.prediction_store.get_test_loss(self.total_iteration_idx)

return loss

def evaluate_model_loss(self, data_loader):
def evaluate_model_loss(self, data_loader, dataset_info=None):
"""Run given dataset through the network without updating the weights and return the loss
Args:
data_loader (torch.utils.data.DataLoader): dataloader containing the data on which the loss is calculated
dataset_info (dict or None): additional information describing the dataset inside the provided dataloader.
One such dataset info is the dataset ``type`` (``"train"``, ``"validation"``, or ``"test"``) set by
``evaluate_loss_on_train_set()``, ``evaluate_loss_on_validation_set()`` and
``evaluate_loss_on_test_set()`` methods.
Returns:
float or dict: loss, in the case of multi loss, the dict gets returned
"""
desc = "Loss evaluation"
if isinstance(dataset_info, dict) and 'type' in dataset_info:
desc = f"{desc} on {dataset_info['type']}"

self.model = self.model.to(self.device)
if self.criterion is not None:
self.criterion = self.criterion.to(self.device)
Expand All @@ -548,7 +558,7 @@ def evaluate_model_loss(self, data_loader):
loss_avg = []

with torch.no_grad():
for batch_data in tqdm(data_loader):
for batch_data in tqdm(data_loader, desc=desc, disable=not self.is_main_process()):
with amp.autocast(enabled=self.use_amp):
if self.batch_model_feed_def is None:
loss_batch = self.model.get_loss_eval(batch_data, self.criterion, self.device)
Expand Down Expand Up @@ -646,13 +656,17 @@ def predict_with_model(self, data_loader, execute_callbacks=False, dataset_info=
(torch.Tensor, torch.Tensor, dict): y_pred, y_true, metadata
in the form of dict of lists/torch.Tensors/np.arrays
"""
desc = "Making predictions"
if isinstance(dataset_info, dict) and 'type' in dataset_info:
desc = f"{desc} on {dataset_info['type']}"

self.model = self.model.to(self.device)

self.model.eval()
y_pred, y_test, metadata_list = [], [], []

with torch.no_grad():
for batch_data in tqdm(data_loader):
for batch_data in tqdm(data_loader, desc=desc, disable=not self.is_main_process()):
with amp.autocast(enabled=self.use_amp):
if self.batch_model_feed_def is None:
y_pred_batch, y_test_batch, metadata_batch = self.model.get_predictions(batch_data, self.device)
Expand Down Expand Up @@ -723,6 +737,17 @@ def get_num_training_steps(self):
else:
return int(len(self.train_loader) // self.grad_accumulation * self.num_epochs)

def is_main_process(self):
    """Tell whether the current process is the main training process

    Outside of DDP training mode (``self.ddp_training_mode`` is falsy) the answer is always ``True``.
    In DDP training mode the main process is the one whose ``self.ddp_rank`` equals 0.

    Returns:
        bool: ``True`` if the current process is the main training process, which in the DDP training mode
        means the process at rank 0
    """
    # In DDP mode only the rank 0 process counts as the main one
    if self.ddp_training_mode:
        return self.ddp_rank == 0

    # Not in DDP mode: no rank check is needed
    return True

def _train_dp(self, num_epochs, num_iterations, callbacks=None, grad_accumulation=1, dp_model_args=None):
"""Train the model on multi-GPU with DataParallel auto wrapping
Expand Down Expand Up @@ -824,11 +849,11 @@ def _spawn_fit(self, gpu, ddp_args, num_epochs, num_iterations, callbacks, grad_
When using this data loading option bear in mind that loaded dataset will be replicated in memory for
every spawned training process. This can in turn cause extensive overall memory consumption.
"""
rank = ddp_args['node_rank'] * ddp_args['num_gpus'] + gpu
self.ddp_rank = ddp_args['node_rank'] * ddp_args['num_gpus'] + gpu

dist.init_process_group(
backend=ddp_args['backend'], init_method=ddp_args['init_method'],
world_size=ddp_args['world_size'], rank=rank
world_size=ddp_args['world_size'], rank=self.ddp_rank
)

torch.manual_seed(0)
Expand All @@ -849,7 +874,7 @@ def _spawn_fit(self, gpu, ddp_args, num_epochs, num_iterations, callbacks, grad_
self.callbacks_handler.execute_multiprocess_start()
# Add DistributedSampler to the data loaders
self.ddp_handler = DDPHandler(self)
self.ddp_handler.add_distributed_samplers(ddp_args['world_size'], rank)
self.ddp_handler.add_distributed_samplers(ddp_args['world_size'], self.ddp_rank)

# Move to the GPU belonging to the process
self.model = self.model.to(self.device)
Expand Down
Binary file modified dist/aitoolbox-1.6.2-py3-none-any.whl
Binary file not shown.
Binary file modified dist/aitoolbox-1.6.2.tar.gz
Binary file not shown.

0 comments on commit b3cff2e

Please sign in to comment.