- https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html
- https://www.cnblogs.com/bytehandler/p/17635933.html

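- `export NCCL_P2P_DISABLE=1`: disables direct GPU-to-GPU (P2P) communication.
    - P2P communication is an efficient transfer path that moves data directly between GPUs, bypassing host (CPU) memory. In some setups, however, it can cause compatibility or performance problems, and this variable turns it off.
    - For example, RTX 4000 series GPUs do not support these faster communication paths, so both P2P and IB (InfiniBand) communication should be disabled on them.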
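
The same setting can also be applied from inside a Python launcher script, provided it happens before the process group is created, since NCCL reads its environment variables when it initializes. A minimal sketch: the helper name `disable_nccl_fast_paths` is hypothetical, and `NCCL_IB_DISABLE` is assumed here as the companion NCCL variable for the IB case mentioned above.

```python
import os

def disable_nccl_fast_paths(disable_p2p=True, disable_ib=True):
    """Set NCCL environment variables; call this before init_process_group()."""
    if disable_p2p:
        os.environ["NCCL_P2P_DISABLE"] = "1"  # same effect as `export NCCL_P2P_DISABLE=1`
    if disable_ib:
        os.environ["NCCL_IB_DISABLE"] = "1"   # turn off the InfiniBand transport
    # Return every NCCL_* variable currently set, for logging or debugging.
    return {k: v for k, v in os.environ.items() if k.startswith("NCCL_")}

settings = disable_nccl_fast_paths()
print(settings)
```

Setting the variables in the shell with `export` remains the simpler option when the job is started with `torchrun`, because child processes inherit them.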


## PyTorch distributed APIs

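- `torch.distributed.init_process_group()`: initializes the process group.
- `torch.distributed.get_rank()`: returns the rank of the current process.
- `torch.distributed.get_world_size()`: returns the number of processes in the process group.
- `torch.distributed.barrier()`: synchronizes all processes in the group, blocking each one until every process has reached the barrier.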
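
The four calls above can be tried in a single process before moving to a multi-GPU launch. This is a minimal sketch, not the original test script: it assumes the default `env://` initialization, fills in the environment variables that `torchrun` would normally provide, and falls back to the `gloo` backend when no GPU is available.

```python
import os

import torch
import torch.distributed as dist

# torchrun normally sets these; the defaults allow a single-process dry run.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

if torch.cuda.is_available():
    # NCCL works on GPU tensors; bind this process to its local GPU first.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    backend = "nccl"
else:
    backend = "gloo"  # CPU fallback (assumption: used only for the dry run)

dist.init_process_group(backend=backend)

rank = dist.get_rank()              # index of this process in the group
world_size = dist.get_world_size()  # number of processes in the group
dist.barrier()                      # returns once every process has arrived
print(f"rank {rank} of {world_size} passed the barrier")

dist.destroy_process_group()
```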

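## Collective communication

NCCL supports collective operations, including:

- `Broadcast`: one process in the group sends its tensor to all other processes.
- `AllReduce`: all processes in the group take part in a reduction, and every process ends up with the same reduced tensor.
- `ReduceScatter`: all processes in the group first perform a reduce and then a scatter, so each process receives one part of the reduced tensor.
- `AllGather`: the tensors from all processes in the group are collected into a list of tensors, and every process ends up with a copy of that list.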
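
To make the semantics concrete, the four operations can be modeled in plain Python, with no GPUs or process groups involved. This is only an illustrative sketch: `data[r]` stands in for the buffer held by rank `r`, and the function names are hypothetical stand-ins, not NCCL or PyTorch APIs.

```python
import operator
from functools import reduce

def broadcast(per_rank, src):
    """Every rank ends up with a copy of the src rank's buffer."""
    return [list(per_rank[src]) for _ in per_rank]

def all_reduce(per_rank, op=operator.add):
    """Element-wise reduction across ranks; every rank gets the full result."""
    reduced = [reduce(op, column) for column in zip(*per_rank)]
    return [list(reduced) for _ in per_rank]

def reduce_scatter(per_rank, op=operator.add):
    """Reduce element-wise, then rank i keeps only chunk i of the result."""
    n = len(per_rank)
    reduced = [reduce(op, column) for column in zip(*per_rank)]
    chunk = len(reduced) // n
    return [reduced[i * chunk:(i + 1) * chunk] for i in range(n)]

def all_gather(per_rank):
    """Every rank ends up with the list of all ranks' buffers."""
    gathered = [list(buf) for buf in per_rank]
    return [list(gathered) for _ in per_rank]

data = [[1, 2], [2, 3]]  # data[r] is the buffer held by rank r
print(broadcast(data, src=0))  # [[1, 2], [1, 2]]
print(all_reduce(data))        # [[3, 5], [3, 5]]
print(reduce_scatter(data))    # [[3], [5]]
print(all_gather(data))        # [[[1, 2], [2, 3]], [[1, 2], [2, 3]]]
```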

```
torchrun --nproc_per_node 2 --nnodes 1 torch_nccl_test.py
# deepspeed --num_gpus 2 --num_nodes 1 torch_nccl_test.py  
```

### scatter

![scatter](https://pytorch.org/tutorials/_images/scatter.png)

```
def dist_scatter():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    tensor = torch.zeros(world_size)
    before_tensor = tensor.clone()
    if rank == 0:
        # This example assumes world_size == 2: scatter_list needs one tensor per rank.
        # All tensors in scatter_list must have the same size as the receiving tensor.
        t_ones = torch.ones(world_size)
        t_fives = torch.ones(world_size) * 5
        scatter_list = [t_ones, t_fives]
    else:
        scatter_list = None
    dist.scatter(tensor, scatter_list, src=0)
    logging.info(f"scatter, rank: {rank}, before scatter: {repr(before_tensor)} after scatter: {repr(tensor)}")
    dist.barrier()
```

### gather

![gather](https://pytorch.org/tutorials/_images/gather.png)

```
def dist_gather():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    tensor = torch.tensor([rank*2+1], dtype=torch.float32)
    before_tensor = tensor.clone()
    
    gather_list = [torch.zeros(1) for _ in range(world_size)] if rank == 0 else None

    dist.gather(tensor, gather_list, dst=0)
    
    logging.info(f"gather, rank: {rank}, before gather: {repr(before_tensor)} after gather: {repr(gather_list)}")
    dist.barrier()
```

### broadcast

![broadcast](https://pytorch.org/tutorials/_images/broadcast.png)

```
def dist_broadcast():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    src_rank = 0
    # Use shape [1] on every rank so the source and destination tensors match.
    tensor = torch.tensor([world_size]) if rank == src_rank else torch.zeros(1, dtype=torch.int64)
    before_tensor = tensor.clone()
    dist.broadcast(tensor, src=src_rank)
    logging.info(f"broadcast, rank: {rank}, before broadcast tensor: {repr(before_tensor)} after broadcast tensor: {repr(tensor)}")
    dist.barrier()
```

### reduce

![reduce](https://pytorch.org/tutorials/_images/reduce.png)

- `def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):`
    - The default reduce operation is SUM.

```
def dist_reduce():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    tensor = torch.tensor([rank*2 + 1], dtype=torch.float32)
    before_tensor = tensor.clone()

    # Only the dst rank is guaranteed to hold the reduced result afterwards.
    dist.reduce(tensor, op=dist.ReduceOp.SUM, dst=0)
    
    logging.info(f"reduce, rank: {rank}, before reduce: {repr(before_tensor)} after reduce: {repr(tensor)}")
    dist.barrier()
```

### all-reduce

![all_reduce](https://pytorch.org/tutorials/_images/all_reduce.png)

- reduce + broadcast

```
def dist_allreduce():
    dist.barrier()

    rank = dist.get_rank()
    # world_size = torch.distributed.get_world_size()

    if rank == 0:
        tensor = torch.tensor([1., 2.])
    else:
        tensor = torch.tensor([2., 3.])
    input_tensor = tensor.clone()
    dist.all_reduce(tensor)

    logging.info(f"all_reduce, rank: {rank}, before allreduce tensor: {repr(input_tensor)}, after allreduce tensor: {repr(tensor)}")
    dist.barrier()
```
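
The "reduce + broadcast" description can be checked with a small pure-Python model that uses the same values as `dist_allreduce()` on two ranks. This is a sketch of the semantics only; the helper names are hypothetical, and leaving the non-destination buffers unchanged after the reduce step is a modeling assumption rather than a guarantee of the real API.

```python
def reduce_to(per_rank, dst):
    """Sum element-wise across ranks; only the dst rank receives the result."""
    total = [sum(column) for column in zip(*per_rank)]
    # Modeling assumption: the other ranks simply keep their input buffers.
    return [total if r == dst else list(buf) for r, buf in enumerate(per_rank)]

def broadcast_from(per_rank, src):
    """Copy the src rank's buffer to every rank."""
    return [list(per_rank[src]) for _ in per_rank]

inputs = [[1.0, 2.0], [2.0, 3.0]]  # rank 0 and rank 1 in dist_allreduce()
after_reduce = reduce_to(inputs, dst=0)
after_broadcast = broadcast_from(after_reduce, src=0)

print(after_reduce)     # [[3.0, 5.0], [2.0, 3.0]]
print(after_broadcast)  # [[3.0, 5.0], [3.0, 5.0]]
```

With two ranks, every process should therefore log `[3., 5.]` as the tensor after `all_reduce`.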

### all-gather

![all_gather](https://pytorch.org/tutorials/_images/all_gather.png)

- gather + broadcast

```
def dist_allgather():
    if dist.get_rank() == 0:
        print("allgather:")
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Shape [1] so the input matches the tensors in tensor_list.
    input_tensor = torch.tensor([rank])
    tensor_list = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(tensor_list, input_tensor)
    logging.info(f"allgather, rank: {rank}, input_tensor: {repr(input_tensor)}, output tensor_list: {tensor_list}")
    dist.barrier()
```

### reduce-scatter

![reducescatter](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/_images/reducescatter.png)

- reduce first, then scatter

```
def dist_reducescatter():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # This example assumes world_size == 2: input_list needs one tensor per rank,
    # each with the same shape as output.
    output = torch.empty(1, dtype=torch.int64)
    input_list = [torch.tensor([rank*2+1]), torch.tensor([rank*2+2])]
    dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)
    logging.info(f"reduce_scatter, rank: {rank}, input_list: {input_list}, tensor: {repr(output)}")
    dist.barrier()
```
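
The result each rank should log can be worked out by hand with a pure-Python model of the two steps. This is a sketch under the same two-rank assumption as `dist_reducescatter()`; the function name `reduce_then_scatter` is hypothetical.

```python
def reduce_then_scatter(input_lists):
    """input_lists[r][i] is the value rank r contributes toward rank i."""
    world_size = len(input_lists)
    # Step 1 (reduce): sum the i-th entry over all ranks.
    reduced = [sum(per_rank[i] for per_rank in input_lists) for i in range(world_size)]
    # Step 2 (scatter): rank i keeps only the i-th reduced entry.
    outputs = [reduced[i] for i in range(world_size)]
    return reduced, outputs

# Same values as dist_reducescatter() with two ranks: rank r contributes [2r+1, 2r+2].
input_lists = [[rank * 2 + 1, rank * 2 + 2] for rank in range(2)]
reduced, outputs = reduce_then_scatter(input_lists)

print(input_lists)  # [[1, 2], [3, 4]]
print(reduced)      # [4, 6]
print(outputs)      # [4, 6] -> rank 0 receives 4, rank 1 receives 6
```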