# A Comprehensive Tutorial to Pytorch DistributedDataParallel

Distributed training across multiple GPUs can look discouraging, especially when computation resources are limited. In practice, wrapping a model with DDP (short for DistributedDataParallel) is a fairly easy job.

In this blog, I want to share my code and my insights with all beginners in DDP. I’m not going to explain in detail how DDP works; instead, I provide the minimum knowledge needed to make a model run on multiple GPUs. Note that I only introduce DDP on one machine with multiple GPUs, which is the most common case (otherwise, we should use model parallelism as stated in the official [blog](https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html)).

This tutorial is organized as:
- Overview of DDP
- Implementation of DDP workflow (Steps 1–6)
- Issues about dist.barrier()

(https://medium.com/codex/a-comprehensive-tutorial-to-pytorch-distributeddataparallel-1f4b42bb1b51)

## Overview of DDP

Terms used in distributed training:

- **master node**: the main GPU responsible for synchronization, making copies, loading models, and writing checkpoints and logs;
- **process group**: if you want to train/test the model over K GPUs, then the K processes form a group, which is supported by a backend (PyTorch manages that for you; according to the [documentation](https://pytorch.org/docs/1.9.0/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=distributeddataparallel#torch.nn.parallel.DistributedDataParallel), nccl is the most recommended backend);
- **rank**: within the process group, each process is identified by its rank, from 0 to K-1;
- **world size**: the number of processes in the group.

PyTorch provides two settings for distributed training: `torch.nn.DataParallel` (DP) and `torch.nn.parallel.DistributedDataParallel` (DDP), where the latter is officially recommended. In short, DDP is faster and more flexible than DP. The fundamental thing DDP does is copy the model to multiple GPUs, gather the gradients from them, average the gradients to update the model, and then synchronize the model over all K processes.
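
To make the averaging step concrete, here is a toy, pure-Python sketch of what happens to the gradients. It is not how DDP is implemented internally; the function name `all_reduce_mean` and the numbers are made up for illustration only.

```python
def all_reduce_mean(per_rank_grads):
    """Average gradients element-wise across ranks and give every rank the result."""
    world_size = len(per_rank_grads)
    averaged = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    # every rank ends up with an identical copy of the averaged gradient
    return [list(averaged) for _ in range(world_size)]


# hypothetical gradients of a 2-parameter model computed on K=3 GPUs
grads = [[1.0, -4.0], [2.0, -2.0], [3.0, 0.0]]
synced = all_reduce_mean(grads)
print(synced)  # [[2.0, -2.0], [2.0, -2.0], [2.0, -2.0]]
```

Because every replica applies the same averaged gradient, the K model copies stay identical after each optimizer step.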

We can also gather/scatter tensors/objects other than gradients with `torch.distributed.gather`/`scatter`/`reduce`.

In case the model can fit on one GPU (it can be trained on one GPU with batch_size = 1) and we want to train/test it on K GPUs, the best practice of DDP is to copy the model onto the K GPUs (the DDP class automatically does this for you) and split the dataloader into K non-overlapping groups to feed into the K models respectively.

We have to do the following things:

1. set up the process group, which is three lines of code and needs no modification;
2. split the dataloader to each process in the group, which can be easily achieved by `torch.utils.data.DistributedSampler` or any customized sampler;
3. wrap our model with DDP, which is one line of code and barely needs modification;
4. train/test our model, which is the same as is on 1 GPU;
5. clean up the process groups, which is one line of code;
6. optional: gather extra data among processes (possibly needed for distributed testing), which is basically one line of code.

## Setup the process group

In [26]:
import os
import torch.distributed as dist

def setup(rank, world_size):
    # address and port of the master process; all processes must agree on them
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
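
The hard-coded port `12355` may already be taken on a shared machine. Below is a minimal sketch, using only the standard library, of one way to pick a free port instead. The helper names `find_free_port` and `set_master_env` are my own and not part of PyTorch. I assume it is called once in the parent process before spawning, so that all children inherit the same environment variables.

```python
import os
import socket


def find_free_port():
    """Ask the OS for a currently unused TCP port on this machine."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))  # port 0 lets the OS choose a free port
        return s.getsockname()[1]


def set_master_env(addr="localhost", port=None):
    """Set the two environment variables read by the default env:// init method."""
    if port is None:
        port = find_free_port()
    os.environ["MASTER_ADDR"] = addr
    os.environ["MASTER_PORT"] = str(port)
    return os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"]


addr, port = set_master_env()
```

There is a small window in which another program could grab the port between the check and the actual use, so treat this as a convenience rather than a guarantee.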

## Split the dataloader

We can easily split our dataloader with `torch.utils.data.distributed.DistributedSampler`. The sampler returns an iterator over indices, which are fed into the dataloader to form batches.

The `DistributedSampler` splits the indices of the dataset into `world_size` parts and distributes them evenly to the dataloader in each process, without overlap apart from the padding discussed below.

In [27]:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def prepare(rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
    dataset = Your_Dataset()
    # the sampler decides which indices this rank sees
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)
    # shuffle must stay False here because a sampler is given
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers,
                            drop_last=False, shuffle=False, sampler=sampler)
    return dataloader

Suppose K=3 and the length of the dataset is 10. We must understand that `DistributedSampler` imposes an even partition of indices, so every rank receives the same number of samples.

If we set `drop_last=False` when defining `DistributedSampler`, it will automatically pad by wrapping around to the beginning of the index list. For example, it splits indices `[0,1,2,3,4,5,6,7,8,9]` into `[0,3,6,9]` at `rank=0`, `[1,4,7,0]` at `rank=1`, and `[2,5,8,1]` at `rank=2`. As you can see, such padding may cause issues because the padded `0` and `1` are real data records that get processed twice. Otherwise, with `drop_last=True`, it strips off the trailing elements. For example, it splits the indices into `[0,3,6]` at `rank=0`, `[1,4,7]` at `rank=1`, and `[2,5,8]` at `rank=2`. In this case, it drops index 9 to make the number of indices divisible by `world_size`.
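
The partitioning can be checked without any GPU. The following is a pure-Python sketch of the index arithmetic as I understand it from `DistributedSampler` with `shuffle=False`. The function name `partition_indices` is mine, and this is an illustration rather than the library code itself.

```python
import math


def partition_indices(dataset_len, world_size, rank, drop_last=False):
    """Mimic how DistributedSampler(shuffle=False) assigns indices to one rank."""
    indices = list(range(dataset_len))
    if drop_last and dataset_len % world_size != 0:
        # drop the tail so that every rank gets the same number of samples
        num_samples = math.ceil((dataset_len - world_size) / world_size)
    else:
        num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size

    if not drop_last:
        # pad by wrapping around to the beginning of the index list
        padding = total_size - len(indices)
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    else:
        indices = indices[:total_size]

    # each rank takes every world_size-th index, starting at its own rank
    return indices[rank:total_size:world_size]


for r in range(3):
    print(r, partition_indices(10, 3, r), partition_indices(10, 3, r, drop_last=True))
# 0 [0, 3, 6, 9] [0, 3, 6]
# 1 [1, 4, 7, 0] [1, 4, 7]
# 2 [2, 5, 8, 1] [2, 5, 8]
```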

It is very simple to customize our own `Sampler`. We only need to create a class, then define its `__iter__()` and `__len__()` methods. Refer to the [official documentation](https://pytorch.org/docs/stable/data.html?highlight=distributedsampler#torch.utils.data.distributed.DistributedSampler) for more details.
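
As an illustration, here is a minimal sketch of such a class that neither pads nor drops indices. The name `ShardSampler` is hypothetical. I write it as a plain Python class and assume that the `sampler` argument of `DataLoader` accepts any iterable over indices; in real code you would probably subclass `torch.utils.data.Sampler`.

```python
class ShardSampler:
    """Yield the indices of one rank's shard without padding or dropping."""

    def __init__(self, dataset_len, world_size, rank):
        # rank r takes indices r, r + world_size, r + 2 * world_size, ...
        self.indices = list(range(rank, dataset_len, world_size))

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        return len(self.indices)


shards = [list(ShardSampler(10, 3, r)) for r in range(3)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

Note that the ranks now receive different numbers of samples. That is convenient for testing, but during DDP training an uneven number of batches across ranks may leave some processes waiting forever for gradient synchronization, so I would only use a sampler like this for evaluation.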

BTW, you’d better set `num_workers=0` for distributed training, because creating extra worker processes inside the child processes may be problematic. I also found that `pin_memory=False` avoids many horrible bugs, though such things may be machine-specific.

## Wrap the model with DDP

We should first move our model to the specific GPU (recall that one model replica resides on one GPU), then wrap it with the DDP class. The following function takes an argument `rank`, which we will come back to soon. For now, we just keep in mind that `rank` equals the GPU id.

In [28]:
from torch.nn.parallel import DistributedDataParallel as DDP

def main(rank, world_size):
    # setup the process groups
    setup(rank, world_size)    
    
    # prepare the dataloader
    dataloader = prepare(rank, world_size)
    
    # instantiate the model (it's your own model) and move it to the right device
    model = Your_Model().to(rank)

    # wrap the model with DDP
    # device_ids tells DDP where your model is
    # output_device tells DDP where to output, in our case, it is rank
    # find_unused_parameters=True instructs DDP to find unused output of the forward() function of any module in the model
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

There are a few tricky things here:

- When we want to access customized attributes of the DDP-wrapped model, we must reference `model.module`. That is to say, our model instance is saved as the `module` attribute of the DDP model. If we assign attributes `xxx` other than built-in properties or functions, we must access them via `model.module.xxx`.
- When we save the DDP model, its `state_dict` adds a `module.` prefix to all parameter names.
- Consequently, if we want to load a DDP-saved model into a non-DDP model, we have to manually strip the extra prefix. I provide my code below:

In [31]:
# in case we load a DDP model checkpoint into a non-DDP model
from collections import OrderedDict

model_dict = OrderedDict()
for k, v in state_dict.items():
    # DDP saves parameters as "module.<name>"; strip only that leading prefix
    if k.startswith('module.'):
        model_dict[k[len('module.'):]] = v
    else:
        model_dict[k] = v

model.load_state_dict(model_dict)

## Train/test our model

This part is the key to implementing DDP. First we need to know the basics of multi-processing: all child processes together with the parent process run the same code.

In PyTorch, `torch.multiprocessing` provides convenient ways to create parallel processes. As the official documentation says,

> The spawn function below addresses these concerns and takes care of error propagation, out of order termination, and will actively terminate processes upon detecting an error in one of them.

So, using `spawn` is a good choice.

In our script, we should define a train/test function before spawning it to parallel processes:

In [35]:
def main(rank, world_size):
    # setup the process groups
    setup(rank, world_size)

    # prepare the dataloader
    dataloader = prepare(rank, world_size)

    # instantiate the model (it's your own model) and move it to the right device
    model = Your_Model().to(rank)

    # wrap the model with DDP
    # device_ids tells DDP where your model is
    # output_device tells DDP where to output, in our case, it is rank
    # find_unused_parameters=True instructs DDP to find unused output of the forward() function of any module in the model
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

    #################### The above is defined previously

    optimizer = Your_Optimizer()
    loss_fn = Your_Loss()
    # epochs is your own iterable of epoch numbers, e.g. range(10)
    for epoch in epochs:
        # if we are using DistributedSampler, we have to tell it which epoch this is
        dataloader.sampler.set_epoch(epoch)

        for step, x in enumerate(dataloader):
            optimizer.zero_grad(set_to_none=True)

            pred = model(x)
            label = x['label']

            loss = loss_fn(pred, label)
            loss.backward()
            optimizer.step()
    cleanup()

This `main` function is run in every parallel process. We now need to launch it with the `spawn` method. In our `.py` script, we write:

In [36]:
import torch.multiprocessing as mp

if __name__ == '__main__':
    # suppose we have 3 gpus
    world_size = 3
    mp.spawn(
        main,
        args=(world_size,),  # args must be a tuple, hence the trailing comma
        nprocs=world_size
    )

Note that running this cell inside a notebook fails with `AttributeError: Can't get attribute 'main' on <module '__main__' (built-in)>`, followed by `ProcessExitedException: process 1 terminated with exit code 1`, because the spawned processes cannot import `main` from an interactive session. Run it as a `.py` script instead.

Remember that the first argument of `main` is `rank`? It is automatically passed to each process by `mp.spawn`, so we don’t need to pass it explicitly. `rank=0` is the master node by default. The `rank` ranges from `0` to `K-1` (2 in our case, with 3 GPUs).
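
The calling convention can be sketched without any GPU. The function `fake_spawn` below is a made-up, sequential stand-in that only imitates how `mp.spawn` calls `fn(rank, *args)` for each rank. The real `mp.spawn` starts separate processes and does not return the results of `fn`.

```python
def fake_spawn(fn, args=(), nprocs=1):
    """Sequential stand-in for mp.spawn: calls fn(rank, *args) for each rank."""
    return [fn(rank, *args) for rank in range(nprocs)]


def report(rank, world_size):
    return f"rank {rank} of {world_size}"


# note the trailing comma: (3,) is a tuple, while (3) is just the integer 3
results = fake_spawn(report, args=(3,), nprocs=3)
print(results)  # ['rank 0 of 3', 'rank 1 of 3', 'rank 2 of 3']
```

This also shows why `args` has to be a tuple: it is unpacked after the rank, so passing a bare integer would fail at the unpacking step.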

## Clean up the process groups

The last line of the `main` function calls the clean-up function, which is:

In [17]:
def cleanup():
    dist.destroy_process_group()

## Optional: Gather extra data among processes

Sometimes we need to collect some data from all processes, such as the testing result. We can easily gather tensors by `dist.all_gather` and objects by `dist.all_gather_object`.

Without loss of generality, I assume we want to collect Python objects. The only constraint on the object is that it must be serializable, which covers basically everything in Python. One should always call `torch.cuda.set_device(rank)` before using `all_gather_xxx`. And, if we want to store a tensor in the object, it must reside on the `output_device`.

In [25]:
import torch

def main(rank, world_size):
    # the process group must be initialized before any collective call
    setup(rank, world_size)
    torch.cuda.set_device(rank)
    data = {
        'tensor': torch.ones(3, device=rank) + rank,
        'list': [1, 2, 3] + [rank],  # list concatenation needs a list, not an int
        'dict': {'rank': rank}
    }

    # we have to create enough room to store the collected objects
    outputs = [None for _ in range(world_size)]

    # the first argument is the list that receives the collected objects, the second argument is the data unique to each process
    dist.all_gather_object(outputs, data)

    # we only want to operate on the collected objects at master node
    if rank == 0:
        print(outputs)

    cleanup()
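
For distributed testing, the gathered list usually has to be merged back into one result, and samples that were duplicated by the sampler's padding should be counted only once. Here is a minimal pure-Python sketch of that post-processing step. It assumes each rank collected a dict mapping the dataset index to its prediction (so the dataset has to return the index along with the sample); the function name `merge_gathered` and the data are hypothetical.

```python
def merge_gathered(outputs, dataset_len):
    """Flatten per-rank {index: prediction} dicts, keeping one entry per index."""
    merged = {}
    for rank_result in outputs:  # one dict per rank, in rank order
        for idx, pred in rank_result.items():
            merged.setdefault(idx, pred)  # ignore duplicates caused by padding
    # return predictions ordered by dataset index
    return [merged[i] for i in range(dataset_len)]


# what dist.all_gather_object might hand back for K=3 and a dataset of length 10
outputs = [
    {0: "a", 3: "d", 6: "g", 9: "j"},
    {1: "b", 4: "e", 7: "h", 0: "a"},  # index 0 is padding
    {2: "c", 5: "f", 8: "i", 1: "b"},  # index 1 is padding
]
predictions = merge_gathered(outputs, dataset_len=10)
print(predictions)  # ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
```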

## Issues about dist.barrier()

The most confusing thing to me is when to use `dist.barrier()`. As the documentation says, it synchronizes processes. In other words, it blocks processes until all of them reach the same line of code: `dist.barrier()`. I summarize its usage as follows:

- we do not need it when training, since DDP automatically does it for us (in `loss.backward()`);
- we do not need it when gathering data, since `dist.all_gather_object` does it for us;
- we need it when enforcing the execution order of code, [say one process loads the model that another process saves](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) (I can hardly imagine this scenario is needed).
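
The save-then-load ordering from the last bullet can be imitated with threads from the standard library. This is only an analogy: `threading.Barrier` stands in for `dist.barrier()`, threads stand in for processes, and the "checkpoint" is just a log message.

```python
import threading

world_size = 3
barrier = threading.Barrier(world_size)
events = []  # shared log of what happened, in order
lock = threading.Lock()


def worker(rank):
    if rank == 0:
        with lock:
            events.append("rank 0 saved checkpoint")
    barrier.wait()  # plays the role of dist.barrier()
    with lock:
        events.append(f"rank {rank} loaded checkpoint")


threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(events[0])  # rank 0 saved checkpoint
```

No worker can pass the barrier before rank 0 has reached it, so the save is always logged first, while the order of the three loads may vary from run to run.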