In [None]:
#export
import argparse
import os

import torch.distributed as dist

In [None]:
# https://github.com/pytorch/examples/tree/master/distributed/ddp
# https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
# https://pytorch.org/tutorials/intermediate/dist_tuto.html
#
#

# Single node, multi-gpu
```
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=4 --use_env <main_python_script>
python /path/to/launch.py --nnodes=1 --node_rank=0 --nproc_per_node=8 example.py --local_world_size=8
```

# Multi-node, multi-gpu

    1. Choose a node as master and find an available high port (here, in range 49000-65535) on it for communication with worker nodes (https://unix.stackexchange.com/a/423052):
```bash
    MASTER_PORT=`comm -23 <(seq 49000 65535 | sort) <(ss -tan | awk '{print $4}' | cut -d':' -f2 | grep -E '[0-9]{1,5}' | sort -u)| shuf | head -n 1`
```
    2. Set MASTER_ADDR and MASTER_PORT on all nodes for launch utility:
    
```bash
    export MASTER_ADDR=<MASTER_ADDR> MASTER_PORT=$MASTER_PORT
```
    
    3. Launch master node process:
```bash
    python -m torch.distributed.launch --nnodes=<num_nodes> --node_rank=0 --nproc_per_node=<num_gpus_per_node> --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --use_env <main_python_script> --distributed true
```
    4. Launch worker nodes' processes (run on each node, setting appropriate node_rank):
```bash
    python -m torch.distributed.launch --nnodes=<num_nodes> --node_rank=<node_rank> --nproc_per_node=<num_gpus_per_node> --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --use_env <main_python_script> --distributed true
```


# Code

In [None]:
#export 

def spmd_main(local_rank, ddp_func):
    """Run ``ddp_func(local_rank)`` inside an NCCL process group.

    The process group is created before ``ddp_func`` is called and is always
    destroyed afterwards, even when ``ddp_func`` raises.

    Args:
        local_rank: Value forwarded unchanged to ``ddp_func``.
        ddp_func: Callable taking ``local_rank`` as its only argument; it runs
            while the process group is initialised.

    Raises:
        KeyError: If any of ``MASTER_ADDR``, ``MASTER_PORT``, ``RANK`` or
            ``WORLD_SIZE`` is missing from the environment.
    """
    # Read the rendezvous variables up front so a mis-configured launch fails
    # with a clear KeyError before any process group is created.
    # NOTE(review): these are presumably what the default env:// init method of
    # init_process_group consumes -- confirm against the torch.distributed docs.
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    }

    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    dist.init_process_group(backend="nccl")
    try:
        print(
            f"[{os.getpid()}] world_size = {dist.get_world_size()}, "
            + f"rank = {dist.get_rank()}, backend={dist.get_backend()}"
        )

        ddp_func(local_rank)
    finally:
        # Tear the group down even if ddp_func fails, so the process does not
        # exit with a live process group.
        dist.destroy_process_group()
    
def parse_mp_args():
    """Parse the per-process local rank from the command line or environment.

    Resolution order for ``local_rank``:

    1. an explicit ``--local_rank`` / ``--local-rank`` command-line option;
    2. the ``LOCAL_RANK`` environment variable;
    3. ``0`` when neither is present.

    Command-line arguments this parser does not know about are ignored, so
    the function can share ``sys.argv`` with the main script's own options.

    Returns:
        argparse.Namespace: Namespace with a single integer attribute
        ``local_rank``.

    Raises:
        ValueError: If ``LOCAL_RANK`` is set but is not an integer.
    """
    # The launch commands documented in this notebook use --use_env, in which
    # case the launcher is expected to pass the rank via LOCAL_RANK rather than
    # as an argument; fall back to it so every worker does not end up at 0.
    default_rank = int(os.environ.get("LOCAL_RANK", 0))

    parser = argparse.ArgumentParser()
    # Accept both spellings: newer launchers pass the dashed form.
    parser.add_argument(
        "--local_rank", "--local-rank",
        dest="local_rank", default=default_rank, type=int,
    )
    # parse_known_args tolerates flags meant for the main script
    # (e.g. "--distributed true") instead of exiting with an error.
    args, _unknown = parser.parse_known_args()
    return args
