Merge pull request ultralytics#8 from MagicFrogSJTU/NanoCode012-patch-1
Modify number of dataloaders' workers
MagicFrogSJTU committed Jul 14, 2020
2 parents e742dd9 + 787582f commit c8357ad
Showing 2 changed files with 7 additions and 5 deletions.
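In short, the merge caches the DDP world size on opt and uses it in two places: to rescale the loss under DistributedDataParallel in train.py, and to divide the host's CPU cores among processes when sizing each dataloader's worker pool in utils/datasets.py. A minimal sketch of the new per-process worker budget, with a hypothetical helper name (the real logic lives inline in create_dataloader):

import os

def workers_per_process(batch_size, world_size, max_workers=8):
    # Split the host's CPU cores across all DDP processes, then cap by the
    # per-process batch size (a batch of 1 gets 0 workers) and a hard limit.
    return min(os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, max_workers)

# e.g. 16 cores, 2 processes, per-GPU batch 32 -> min(8, 32, 8) = 8 workers each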
10 changes: 6 additions & 4 deletions train.py
@@ -294,7 +294,7 @@ def train(hyp, tb_writer, opt, device):
             loss, loss_items = compute_loss(pred, targets.to(device), model)
             # loss is scaled with batch size in func compute_loss. But in DDP mode, gradient is averaged between devices.
             if local_rank != -1:
-                loss *= dist.get_world_size()
+                loss *= opt.world_size
             if not torch.isfinite(loss):
                 print('WARNING: non-finite loss, ending training ', loss_items)
                 return results
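Why this rescaling is needed (illustrative reasoning; the repo's own comment above states the averaging): DistributedDataParallel all-reduces gradients with a mean over processes, while compute_loss has already summed the loss over the local batch, so without the factor the effective gradient would be divided by the number of processes. opt.world_size is simply dist.get_world_size() cached at startup, as introduced in the hunk below.

# Toy numbers, assumed for illustration: world_size = 2, local losses l0 and l1.
# Plain DDP backward:   grad = (dl0/dw + dl1/dw) / 2
# With the pre-scaling: grad = (2*dl0/dw + 2*dl1/dw) / 2 = dl0/dw + dl1/dw,
# i.e. the gradient a single process would get from the combined batch.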
@@ -449,6 +449,7 @@ def train(hyp, tb_writer, opt, device):
     opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
     device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
     opt.total_batch_size = opt.batch_size
+    opt.world_size = 1
     if device.type == 'cpu':
         mixed_precision = False
     elif opt.local_rank != -1:
@@ -457,9 +458,10 @@ def train(hyp, tb_writer, opt, device):
         torch.cuda.set_device(opt.local_rank)
         device = torch.device("cuda", opt.local_rank)
         dist.init_process_group(backend='nccl', init_method='env://') # distributed backend

-        assert opt.batch_size % dist.get_world_size() == 0
-        opt.batch_size = opt.total_batch_size // dist.get_world_size()
+        opt.world_size = dist.get_world_size()
+        assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
+        opt.batch_size = opt.total_batch_size // opt.world_size

     print(opt)

     # Train
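The launch-time bookkeeping above can be mirrored by a small standalone sketch; the function name and the example numbers are illustrative, not part of the commit:

def split_batch(total_batch_size, world_size):
    # Same check as in train.py: every DDP process must get an equal share.
    assert total_batch_size % world_size == 0, \
        "Batch size is not a multiple of the number of devices given!"
    return total_batch_size // world_size

# split_batch(64, 4) -> 16 images per process; split_batch(64, 3) trips the assert.
# In single-GPU or CPU runs opt.world_size keeps its default of 1, so the division is a no-op.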
2 changes: 1 addition & 1 deletion utils/datasets.py
@@ -59,7 +59,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
                                       pad=pad)

     batch_size = min(batch_size, len(dataset))
-    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
+    nw = min([os.cpu_count()//opt.world_size, batch_size if batch_size > 1 else 0, 8]) # number of workers
     train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if local_rank != -1 else None
     dataloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
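Worked numbers for the worker change, assuming a 16-core host, 4 DDP processes and a per-GPU batch size of 16 (values chosen for illustration):

# Before: each process asked for min(16, 16, 8) = 8 workers -> 32 workers on 16 cores.
# After:  each process asks for min(16 // 4, 16, 8) = 4 workers -> 16 workers in total,
# so the dataloaders of all DDP processes no longer oversubscribe the CPU.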
