
Commit

Merge pull request #709 from mv1388/DDP_eval_loss_calculation_over_all_batches

DDP evaluate average loss over all batches instead of only those on current worker
mv1388 committed Jul 20, 2022
2 parents 4ecce26 + 5bd6f9e commit ba6b6f7
Showing 4 changed files with 24 additions and 14 deletions.
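
In short, the change moves the DDP sync in parse_loss from the already-averaged per-worker loss down to the raw per-batch loss record, so the reported evaluation loss is averaged over all batches from all workers. A toy sketch of the difference (worker split and loss values invented purely for illustration):

# Two DDP workers with an uneven evaluation batch split (illustrative numbers only).
worker_0_batch_losses = [0.5, 0.7, 0.9]   # 3 batches evaluated on worker 0
worker_1_batch_losses = [0.1]             # 1 batch evaluated on worker 1

# Roughly the old behaviour: average per worker first, then average the per-worker means.
mean_of_means = (sum(worker_0_batch_losses) / len(worker_0_batch_losses) +
                 sum(worker_1_batch_losses) / len(worker_1_batch_losses)) / 2
print(mean_of_means)   # ≈ 0.4

# Roughly the new behaviour: gather every batch loss first, then take one global mean.
all_batch_losses = worker_0_batch_losses + worker_1_batch_losses
print(sum(all_batch_losses) / len(all_batch_losses))   # ≈ 0.55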
9 changes: 7 additions & 2 deletions aitoolbox/torchtrain/train_loop/components/ddp_handler.py
@@ -96,7 +96,7 @@ def build_loader_sampler(data_loader, shuffle, world_size, rank):
         data_loader_sampler = DataLoader(**data_loader_args)
         return data_loader_sampler, ddp_sampler
 
-    def mp_sync(self, data, concat_mp_data=True):
+    def mp_sync(self, data, concat_mp_data=True, double_precision=False):
         """Multiprocess data sync
 
         Share input data between all the active processes so that every process has all the values from
@@ -107,6 +107,11 @@ def mp_sync(self, data, concat_mp_data=True):
                 In case this is torch.Tensor, resulting output the device location will be preserved.
             concat_mp_data (bool): should the returned list of collected tensors be concatenated into a single list
                 of values
+            double_precision (bool): in case the ``data`` parameter is not already a Tensor, the function wraps given
+                data into a Tensor. By default, it uses PyTorch default 32 bit precision float tensor. If this parameter
+                is set to ``True`` however, the double precision 64 bit tensor will be created. This is useful
+                for example if input data is in 64 bit, and we want to prevent precision reduction when syncing the data
+                across the workers.
 
         Returns:
             torch.Tensor: list of `data` variable values synced across all the active processes
@@ -118,7 +123,7 @@ def mp_sync(self, data, concat_mp_data=True):
         if isinstance(data, torch.Tensor):
             input_data_device = data.device.type
         else:
-            data = torch.Tensor(data)
+            data = torch.tensor(data, dtype=torch.float32 if not double_precision else torch.float64)
 
         data_tensor_wrap = data.to(self.train_loop_obj.device)
         mp_data = [torch.zeros_like(data_tensor_wrap) for _ in range(dist.get_world_size())]
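
The new double_precision flag only matters on the branch where mp_sync receives plain Python data (for example a list of per-batch losses) and has to wrap it in a tensor itself; the old torch.Tensor(data) call always produced a 32-bit float tensor. A minimal standalone sketch, independent of the repository code, of the precision loss that the 64-bit option avoids:

import torch

value = [0.123456789012345678]   # Python floats are 64-bit

as_float32 = torch.tensor(value, dtype=torch.float32)   # default-precision wrap
as_float64 = torch.tensor(value, dtype=torch.float64)   # what double_precision=True selects

print(as_float32.item())   # ≈ 0.12345679104328156 -- rounded to 32-bit precision
print(as_float64.item())   # ≈ 0.12345678901234568 -- original precision preserved

After the wrap, the tensor is moved to the training device and a zero-filled placeholder per worker is prepared (the visible mp_data line), presumably feeding the gather step that sits below the fold of this hunk.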
29 changes: 17 additions & 12 deletions aitoolbox/torchtrain/train_loop/train_loop.py
@@ -410,7 +410,16 @@ def parse_loss(self, loss_record):
         Primarily useful for parsing between single loss representation and the multi-loss representation.
 
         Args:
-            loss_record (list): list losses from each processed batch
+            loss_record (list): list of losses from each processed batch.
+
+                If we used a single loss then the ``loss_record`` is a list of floats where each element float is
+                the loss for a single batch which has been transformed from a torch Tensor into the float via .item()
+
+                If we used multiple losses wrapped inside MultiLoss() then the _train() function called its
+                item() method which converted the multi-loss representation into the list of dicts, where each dict
+                represents the losses for a single batch:
+                ``[{'loss_1': 1., 'loss_2': 33.}, { ... }]``
+
 
         Returns:
             np.array or dict: in the case of single loss numpy array is returned, otherwise the dict of multiple losses
@@ -420,23 +429,19 @@
 
         if isinstance(self.optimizer, MultiOptimizer):
             loss_names = sorted(loss_record[0].keys())
+            # loss_record is a list of lists with dimensions: [num_batches, num_losses]
             loss_record = [[loss_dict[k] for k in loss_names] for loss_dict in loss_record]
 
-            loss_batch_accum_avg = np.mean(loss_record, axis=0)
-        else:
-            loss_batch_accum_avg = np.mean(loss_record)
-
         if self.ddp_training_mode:
-            loss_ddp_synced = self.ddp_handler.mp_sync(loss_batch_accum_avg).numpy()
-            if isinstance(self.optimizer, MultiOptimizer):
-                loss_batch_accum_avg = np.mean(loss_ddp_synced, axis=0)
-            else:
-                loss_batch_accum_avg = np.mean(loss_ddp_synced)
+            loss_record = self.ddp_handler.mp_sync(loss_record, double_precision=True)
+            loss_record = loss_record.numpy()
+
+        loss_avg = np.mean(loss_record, axis=0)
 
         if loss_names is None:
-            return loss_batch_accum_avg
+            return loss_avg
         else:
-            return dict(zip(loss_names, loss_batch_accum_avg))
+            return dict(zip(loss_names, loss_avg))
 
     def _print_save_loss(self, loss_parsed, loss_type_name, loss_print_description):
         """Helper function which prints information about parsed loss and saves the loss results into the history
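
To make the reworked parse_loss flow concrete, here is a small self-contained sketch (loss values invented, DDP sync reduced to a comment) of how the single-loss and multi-loss records described in the docstring collapse into averages:

import numpy as np

# Hypothetical per-batch records, mirroring the docstring formats above.
single_loss_record = [0.9, 0.7, 0.5]                   # one .item() float per batch
multi_loss_record = [{'loss_1': 1.0, 'loss_2': 33.0},  # one MultiLoss().item() dict per batch
                     {'loss_1': 3.0, 'loss_2': 35.0}]

# Single-loss case: one mean over all batch losses.
print(np.mean(single_loss_record, axis=0))             # ≈ 0.7

# Multi-loss case: fix the loss-name order, flatten each dict into a row
# of shape [num_batches, num_losses], then average column-wise.
loss_names = sorted(multi_loss_record[0].keys())
loss_rows = [[loss_dict[k] for k in loss_names] for loss_dict in multi_loss_record]
loss_avg = np.mean(loss_rows, axis=0)                  # loss_1 -> 2.0, loss_2 -> 34.0
print(dict(zip(loss_names, loss_avg)))

# In DDP mode the same record (loss_rows or single_loss_record) would first be
# gathered across workers via ddp_handler.mp_sync(..., double_precision=True),
# so the final mean covers the batches from every worker, not just the local ones.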
Binary file modified dist/aitoolbox-1.6.0-py3-none-any.whl
Binary file modified dist/aitoolbox-1.6.0.tar.gz
