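- `export NCCL_P2P_DISABLE=1`: disable direct GPU-to-GPU (P2P) communication
    - P2P communication is an efficient transfer mode that moves data directly between GPUs, bypassing host (CPU) memory. In some setups, however, it can cause compatibility or performance problems.
    - On RTX 40-series cards, NCCL's default direct GPU-to-GPU (P2P) path may be unsupported or may fail, so disable it explicitly and let data take the conventional route through CPU memory to avoid crashes or performance regressions.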
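
When exporting the variable in the shell is inconvenient, it can also be set from Python. The sketch below assumes it runs at the very top of the training script, before any NCCL communicator is created; `NCCL_DEBUG=INFO` is an optional extra that makes NCCL log its setup, which helps confirm the setting took effect.

```python
import os

# Assumption: this executes before torch.distributed.init_process_group(),
# i.e. before NCCL has had a chance to read its environment variables.
os.environ.setdefault("NCCL_P2P_DISABLE", "1")  # keep a value already exported in the shell
os.environ.setdefault("NCCL_DEBUG", "INFO")     # optional: verbose NCCL setup logs

nccl_settings = {k: v for k, v in os.environ.items() if k.startswith("NCCL_")}
print(nccl_settings)
```

`setdefault` is used so that a value exported by the launcher or shell is not silently overwritten.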


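``` python
"""PyTorch distributed APIs"""
import logging  # used by the snippets below

import torch
import torch.distributed as dist

dist.init_process_group()  # initialize the process group (by default, rank and world size come from environment variables)
dist.get_rank()  # rank of the current process
dist.get_world_size()  # total number of processes in the group
dist.barrier()  # synchronize the group: block each process until all of them reach this call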
```

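- `p.join()` is process-level synchronization in Python: the parent process waits for a child process to exit.
- `torch.distributed.barrier()` is a distributed collective: every rank waits for the others and continues only after all of them have reached the barrier.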
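
The difference can be illustrated without a GPU or a process group. The sketch below is only an analogy built on the standard library: threads stand in for ranks, `threading.Barrier` stands in for `dist.barrier()`, and `Thread.join()` stands in for `p.join()`. The names `worker`, `events`, and `phases` are made up for this example.

```python
import threading

world_size = 3
barrier = threading.Barrier(world_size)  # stands in for dist.barrier()
events = []                              # (phase, rank) records
lock = threading.Lock()                  # protects appends to `events`

def worker(rank):
    with lock:
        events.append(("before", rank))
    barrier.wait()                       # nobody continues until all workers arrive
    with lock:
        events.append(("after", rank))

threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
for t in threads:
    t.start()
for t in threads:
    t.join()                             # like p.join(): wait for each worker to finish

phases = [phase for phase, _ in events]
print(phases)  # ['before', 'before', 'before', 'after', 'after', 'after']
```

The barrier orders the workers relative to each other while they are still running; `join()` only tells the parent that a worker has ended.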

---

### scatter

``` python
def dist_scatter():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    tensor = torch.zeros(world_size)
    before_tensor = tensor.clone()
    
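    if rank == 0:
        t_ones = torch.ones(world_size)
        t_fives = torch.ones(world_size) * 5
        # scatter_list needs one tensor per rank; this example assumes world_size == 2
        scatter_list = [t_ones, t_fives]
    else:
        scatter_list = None
    dist.scatter(tensor, scatter_list, src=0)  # rank i receives scatter_list[i] from rank 0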
    logging.info(f"scatter, rank: {rank}, before scatter: {repr(before_tensor)} after scatter: {repr(tensor)}")
    dist.barrier()
```
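
To make the expected log output concrete, here is a pure-Python model of what scatter does with two ranks. It is a sketch of the semantics only: `scatter` below is a hypothetical helper operating on plain lists, not the `torch.distributed` call.

```python
def scatter(scatter_list, world_size):
    """Model of scatter: rank i receives scatter_list[i] from the source rank."""
    if len(scatter_list) != world_size:
        raise ValueError("scatter_list needs exactly one entry per rank")
    return {rank: scatter_list[rank] for rank in range(world_size)}

world_size = 2
t_ones = [1.0] * world_size
t_fives = [5.0] * world_size
received = scatter([t_ones, t_fives], world_size)
print(received)  # {0: [1.0, 1.0], 1: [5.0, 5.0]}
```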

---

### gather

``` python
def dist_gather():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    tensor = torch.tensor([rank*2+1], dtype=torch.float32)
    before_tensor = tensor.clone()

    gather_list = [torch.zeros(1) for _ in range(world_size)] if rank == 0 else None

    dist.gather(tensor, gather_list, dst=0)
    
    logging.info(f"gather, rank: {rank}, before gather: {repr(before_tensor)} after gather: {repr(gather_list)}")
    dist.barrier()
```
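
A pure-Python model of the same operation shows what each rank should hold afterwards. This is a sketch of the semantics, not the library call: `gather` is a hypothetical helper, and the four-rank world size is an assumption chosen for the example.

```python
def gather(per_rank_values, dst=0):
    """Model of gather: only rank dst ends up with the list of every rank's value."""
    world_size = len(per_rank_values)
    return {
        rank: (list(per_rank_values) if rank == dst else None)
        for rank in range(world_size)
    }

world_size = 4
inputs = [rank * 2 + 1 for rank in range(world_size)]  # each rank contributes rank*2+1
result = gather(inputs, dst=0)
print(result)  # {0: [1, 3, 5, 7], 1: None, 2: None, 3: None}
```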

---
### broadcast

``` python
def dist_broadcast():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    src_rank = 0
    # sender and receivers must agree on shape and dtype: a 1-element int64 tensor
    tensor = torch.tensor([world_size]) if rank == src_rank else torch.zeros(1, dtype=torch.int64)
    before_tensor = tensor.clone()
    dist.broadcast(tensor, src=src_rank)
    logging.info(f"broadcast, rank: {rank}, before broadcast tensor: {repr(before_tensor)} after broadcast tensor: {repr(tensor)}")
    dist.barrier()
```

---
### reduce

``` python
def dist_reduce():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    tensor = torch.tensor([rank*2 + 1], dtype=torch.float32)
    before_tensor = tensor.clone()

    dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)  # only rank 0 is guaranteed to hold the sum
    
    logging.info(f"reduce, rank: {rank}, before reduce: {repr(before_tensor)} after reduce: {repr(tensor)}")
    dist.barrier()
```
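
Since every rank contributes `rank*2 + 1`, the sum on rank 0 is the sum of the first `world_size` odd numbers, which equals `world_size ** 2`. The sketch below checks that with a pure-Python model; `reduce_sum` is a hypothetical helper, not the `torch.distributed` call.

```python
def reduce_sum(per_rank_values, dst=0):
    """Model of reduce with SUM: the total is only defined on rank dst."""
    total = sum(per_rank_values)
    return {
        rank: (total if rank == dst else None)
        for rank in range(len(per_rank_values))
    }

for world_size in (2, 4):
    inputs = [rank * 2 + 1 for rank in range(world_size)]
    print(world_size, reduce_sum(inputs)[0])
# 2 4
# 4 16
```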

---

### all-reduce

``` python
def dist_allreduce():
    dist.barrier()

    rank = dist.get_rank()
    if rank == 0:
        tensor = torch.tensor([1., 2.])
    else:
        tensor = torch.tensor([2., 3.])
    input_tensor = tensor.clone()
    dist.all_reduce(tensor)  # default op is SUM; every rank ends up with the same result

    logging.info(f"all_reduce, rank: {rank}, before allreduce tensor: {repr(input_tensor)}, after allreduce tensor: {repr(tensor)}")
    dist.barrier()
```
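
With two ranks, the example above should leave `[3., 5.]` on both ranks. The pure-Python sketch below models that; `all_reduce` and `elementwise_sum` are hypothetical helpers working on lists, and the two-rank world size is an assumption.

```python
def elementwise_sum(per_rank_tensors):
    """Sum the tensors of all ranks position by position."""
    return [sum(column) for column in zip(*per_rank_tensors)]

def all_reduce(per_rank_tensors):
    """Model of all-reduce with SUM: every rank receives the same total."""
    total = elementwise_sum(per_rank_tensors)
    return [list(total) for _ in per_rank_tensors]

inputs = [[1.0, 2.0], [2.0, 3.0]]  # rank 0 and rank 1 from the example
outputs = all_reduce(inputs)
print(outputs)  # [[3.0, 5.0], [3.0, 5.0]]
```

Unlike reduce, there is no destination rank: the result is valid everywhere.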

---

### all-gather

- Equivalent to gather + broadcast: every rank ends up with the full list of values.

``` python
def dist_allgather():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    input_tensor = torch.tensor([rank])  # shape [1], matching the entries of tensor_list
    tensor_list = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(tensor_list, input_tensor)
    logging.info(f"allgather, rank: {rank}, input_tensor: {repr(input_tensor)}, output tensor_list: {tensor_list}")
    dist.barrier()
```
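
The "gather + broadcast" decomposition can be written out as a pure-Python model. This is a sketch of the semantics under the assumption of three ranks; `gather`, `broadcast`, and `all_gather` are hypothetical helpers, and real backends generally do not implement all-gather as these two literal steps.

```python
def gather(per_rank_values):
    """Step 1: collect every rank's value on one rank."""
    return list(per_rank_values)

def broadcast(value, world_size):
    """Step 2: hand a copy of the collected list to every rank."""
    return [list(value) for _ in range(world_size)]

def all_gather(per_rank_values):
    collected = gather(per_rank_values)
    return broadcast(collected, len(per_rank_values))

inputs = [rank for rank in range(3)]  # each rank contributes its own rank
outputs = all_gather(inputs)
print(outputs)  # [[0, 1, 2], [0, 1, 2], [0, 1, 2]]
```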

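---

### reduce-scatter

- Equivalent to reduce + scatter: the inputs are reduced element-wise across ranks, then rank i receives the i-th result.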

``` python
def dist_reducescatter():
    dist.barrier()

    rank = dist.get_rank()
    world_size = dist.get_world_size()

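    output = torch.empty(1, dtype=torch.int64)
    # one shape-[1] tensor per rank; with world_size == 2, rank 0 contributes [1, 2] and rank 1 contributes [3, 4]
    input_list = [torch.tensor([rank * 2 + 1 + i]) for i in range(world_size)]
    # note: reduce_scatter support depends on the backend and PyTorch version
    dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)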
    logging.info(f"reduce_scatter, rank: {rank}, input_list: {input_list}, tensor: {repr(output)}")
    dist.barrier()
```
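
For two ranks, the inputs are `[1, 2]` on rank 0 and `[3, 4]` on rank 1, so rank 0 should receive `1 + 3 = 4` and rank 1 should receive `2 + 4 = 6`. The pure-Python sketch below models the two steps; `reduce_scatter` is a hypothetical helper on plain lists, not the library call.

```python
def reduce_scatter(per_rank_lists):
    """Model of reduce-scatter with SUM over one input list per rank."""
    world_size = len(per_rank_lists)
    reduced = [sum(column) for column in zip(*per_rank_lists)]  # reduce step
    return [reduced[rank] for rank in range(world_size)]        # scatter step

world_size = 2
inputs = [[rank * 2 + 1, rank * 2 + 2] for rank in range(world_size)]  # [[1, 2], [3, 4]]
outputs = reduce_scatter(inputs)
print(outputs)  # [4, 6]
```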

``` python
import torch
import torch.distributed as dist

# torch.zeros(1) is a 1-D tensor of shape [1]; torch.tensor(1) is a 0-D scalar tensor
torch.zeros(1), torch.tensor(1)
# (tensor([0.]), tensor(1))
```