# PyTorch Distributed Data Parallel (DDP) processing tutorial

1. Import the PyTorch modules needed for DDP processing

In [None]:
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel

In [None]:
def main(args):
    """Configure ``args`` for the visible GPUs and start training.

    When ``args.use_ddp`` is truthy, one worker process per GPU is started
    with ``mp.spawn``; otherwise ``main_worker`` runs in the current process
    on GPU 0.

    Args:
        args: Mutable options object. Reads ``use_ddp``, ``batch_size`` and
            ``val_batch_size``; writes ``world_size`` and, depending on the
            branch, ``num_workers``, ``batch_size``, ``val_batch_size`` or
            ``gpu``.
    """
    n_gpus = torch.cuda.device_count()

    # One process per GPU on this machine (main_worker rendezvouses on
    # 127.0.0.1), so the world size is simply the GPU count.  It used to be
    # multiplied by n_gpus a second time inside the DDP branch, yielding
    # n_gpus ** 2; init_process_group would then wait for ranks that are
    # never spawned whenever more than one GPU is present.
    args.world_size = n_gpus

    if args.use_ddp:  # a flag for using PyTorch DDP
        # It is common to use 4 data-loading workers per GPU.
        args.num_workers = n_gpus * 4
        # NOTE(review): both batch sizes are scaled UP by the GPU count here,
        # although the original comment spoke of splitting the batch per GPU.
        # Whether downstream code treats these as global or per-process
        # values is not visible from this function - confirm.
        args.batch_size = n_gpus * args.batch_size
        args.val_batch_size = n_gpus * args.val_batch_size
        # mp.spawn passes the process index as the first positional argument,
        # followed by the entries of ``args=``.
        mp.spawn(main_worker, nprocs=n_gpus, args=(n_gpus, args))
    else:  # not using PyTorch DDP
        args.gpu = 0  # use the first GPU
        main_worker(args.gpu, n_gpus, args)

def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point: bind a GPU, join the process group, build the solver.

    Args:
        gpu: CUDA device index for this process; stored on ``args.gpu`` and
            passed to ``torch.cuda.set_device``.
        ngpus_per_node: Multiplier used to derive ``args.rank`` from the
            incoming ``args.rank`` and ``gpu``.
        args: Mutable options object. Reads ``distributed``,
            ``is_distributed``, ``rank`` and ``world_size``; writes ``gpu``,
            ``flag`` and (conditionally) ``rank``.
    """
    args.gpu = gpu
    torch.cuda.set_device(args.gpu)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # NOTE(review): this function consults both ``args.distributed`` and
    # ``args.is_distributed`` - confirm that both attributes exist on ``args``
    # and are meant to be distinct.
    if args.distributed:
        if args.is_distributed:
            # Offset this process's device index by the incoming rank scaled
            # by the per-node GPU count.
            args.rank = args.rank * ngpus_per_node + gpu

        # NOTE(review): port 88 is below 1024 and may need elevated
        # privileges to bind on some systems - confirm.
        process_group_options = {
            "backend": 'nccl',
            "init_method": 'tcp://127.0.0.1:88',  # 127.0.0.1:23456
            "world_size": args.world_size,
            "rank": args.rank,
        }
        dist.init_process_group(**process_group_options)

    # True when not running distributed, or when this process is on GPU 0.
    # The ``== False`` comparison is kept as written so that non-bool values
    # of ``is_distributed`` are treated exactly as before.
    if args.is_distributed == False or args.gpu == 0:
        args.flag = True
    else:
        args.flag = False

    if args.flag:
        print(args)
    solver = Solver(args)
