Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
- made metrics friendly for multi-gpu model parallel training
Browse files Browse the repository at this point in the history
  • Loading branch information
nasimrahaman committed Sep 16, 2017
1 parent 8d05e2e commit dfd3598
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
13 changes: 11 additions & 2 deletions inferno/extensions/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,14 @@ class Metric(object):
def forward(self, *args, **kwargs):
raise NotImplementedError

def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def __call__(self, prediction, target, **kwargs):
# Make sure prediction and target live on the same device.
# If they don't, move target to the right device.
if not prediction.is_cuda:
# Move to CPU
target = target.cpu()
else:
# Find device to move to
device_ordinal = prediction.get_device()
target = target.cuda(device_ordinal)
return self.forward(prediction, target, **kwargs)
3 changes: 2 additions & 1 deletion inferno/extensions/metrics/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def forward(self, prediction, target):
dont_ignore_class = list(range(num_classes))
dont_ignore_class.pop(ignore_class)
if classwise_iou.is_cuda:
dont_ignore_class = torch.cuda.LongTensor(dont_ignore_class)
dont_ignore_class = \
torch.LongTensor(dont_ignore_class).cuda(classwise_iou.get_device())
else:
dont_ignore_class = torch.LongTensor(dont_ignore_class)
iou = classwise_iou[dont_ignore_class].mean()
Expand Down

0 comments on commit dfd3598

Please sign in to comment.