<a href="https://colab.research.google.com/github/lizhieffe/llm_knowledge/blob/main/Torch_Dist_Practice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Torch dist practice

Reference
- https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/distributed/torch-distributed/readme.md

In [4]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

In [5]:
# @title Process for collective communication

# 广播（Broadcast）：广播是一种将数据从一个源进程发送到所有其他进程的通信操作。在 torch.distributed 中，通过 broadcast(tensor, src=0) 可以实现该操作，将 rank 为 0 的进程中的数据广播到所有其他进程。广播操作能够确保所有进程拥有相同的数据，适合需要共享模型参数、初始化权重等场景。比如在分布式训练的初始化阶段，用于将主进程的模型参数广播到所有其他进程，保证训练从同样的初始参数开始。
# 规约（Reduce 和 All-Reduce）：规约操作是一种将多个进程的数据进行计算（如求和、求最大值等）的操作。常用的规约操作有两种，reduce()：一个进程（通常是主进程）收集并合并来自所有进程的数据；all_reduce()：所有进程同时得到合并后的数据。比如 all_reduce(tensor, op=ReduceOp.SUM) 会在所有进程中求和，并将结果存放在每个进程的 tensor 中。规约操作能有效减少通信负担，适用于大规模梯度汇总或模型权重更新。譬如在分布式训练中，all_reduce 常用于梯度求和，以确保在多个进程中的梯度保持一致，实现同步更新。
# 收集（Gather 和 All-Gather）：收集操作是将多个进程的数据收集到一个或多个进程的操作：gather()：将多个进程的数据收集到一个进程中。all_gather()：所有进程都收集到全部进程的数据。例如 all_gather(gathered_tensors, tensor) 会将所有进程中的 tensor 收集到每个进程的 gathered_tensors 列表中。收集操作方便对所有进程中的数据进行后续分析和处理。譬如做 evaluation 时，可以使用 all_gather 来汇总各个进程的中间结果。
# 散发（Scatter）：scatter() 操作是将一个进程的数据分散到多个进程中。例如在 rank 为 0 的进程中有一个包含若干子张量的列表，scatter() 可以将列表中的每个子张量分配给其他进程。适用于数据分发，将大型数据集或模型权重在多个进程中分散，以便每个进程可以处理不同的数据块。

def init_process(rank, world_size):
  """Run four collective-communication demos on the gloo (CPU) backend.

  Meant to be the entry point of one worker process per rank. The process
  group is created without an explicit init_method, so MASTER_ADDR and
  MASTER_PORT must already be set in the environment.

  The demos, in order:
    1. all_gather of a short per-rank string, printed by rank 0.
    2. all_reduce (sum) of a ones tensor across every rank.
    3. all_reduce (sum) restricted to the sub-group of ranks 1 and 3.
    4. all_reduce (sum) inside each half of the world, followed by a
       world-wide all_reduce (max) so that every rank sees both sums.

  Args:
    rank: Index of this process within the group.
    world_size: Total number of processes in the group. Must be at least 4
      because demo 3 addresses rank 3.

  Raises:
    ValueError: If world_size is too small for the hard-coded sub-group.
  """
  sub_group_ranks = [1, 3]
  min_world_size = max(sub_group_ranks) + 1
  # Fail before any rendezvous happens instead of erroring half-way through.
  if world_size < min_world_size:
    raise ValueError(
        f"init_process needs world_size >= {min_world_size} "
        f"(Task 3 uses ranks {sub_group_ranks}), got {world_size}"
    )

  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  try:
    assert rank == dist.get_rank()
    assert world_size == dist.get_world_size()
    dist.barrier()

    # Task 1 - all gather
    # Every rank contributes a short text; rank 0 prints all of them.
    if rank == 0:
      print("\nTask 1 - all gather")
    process_info = f"Process {rank} Information..."
    # The UTF-8 bytes are stored one per element in a fixed-length,
    # zero-padded buffer so that every rank contributes the same shape.
    max_len = 100
    process_info_bytes = process_info.encode('utf-8')
    process_info_tensor = torch.zeros(max_len, dtype=torch.int32)
    process_info_tensor[:len(process_info_bytes)] = torch.tensor(
        list(process_info_bytes), dtype=torch.int32
    )

    gathered_tensors = [
        torch.zeros(max_len, dtype=torch.int32) for _ in range(world_size)
    ]
    dist.all_gather(gathered_tensors, process_info_tensor)

    if rank == 0:
      for t in gathered_tensors:
        info_bytes = t.numpy().astype('uint8').tobytes()
        # Strip the zero padding that filled the rest of the buffer.
        info_str = info_bytes.decode('utf-8', 'ignore').strip('\x00')
        print(info_str)
    dist.barrier()

    # Task 2 - all reduce (sum)
    if rank == 0:
      print("\nTask 2 - all reduce")
    tensor = torch.ones((4,))
    dist.all_reduce(tensor)
    print(f"All reduce for all processes: in rank {rank}, tensor = {tensor}")
    dist.barrier()

    # Task 3 - all reduce (sum) in a sub-group.
    if rank == 0:
      print("\nTask 3 - all reduce for sub-group")
    # new_group is called on every rank, members or not; only the members
    # take part in the reduction below.
    sub_group = dist.new_group(ranks=sub_group_ranks)
    if rank in sub_group_ranks:
      sub_group_tensor = torch.ones((4,))
      dist.all_reduce(sub_group_tensor, group=sub_group)
      print(f"Sub group all reduce: in rank {rank}, tensor = {sub_group_tensor}")
    dist.barrier()

    # Task 4 - all reduce (sum) in a sub-group, then sync results to the entire group.
    if rank == 0:
      print("\nTask 4 - all reduce (sum) in a sub-group, then sync results to the entire group.")
    group_1_sum = torch.tensor([1, 1, 1, 1])
    group_2_sum = torch.tensor([1.5] * 4)
    group_1_ranks = list(range(world_size // 2))
    group_2_ranks = list(range(world_size // 2, world_size))
    group_1 = dist.new_group(ranks=group_1_ranks)
    group_2 = dist.new_group(ranks=group_2_ranks)
    if rank in group_1_ranks:
      dist.all_reduce(group_1_sum, group=group_1)
    else:
      dist.all_reduce(group_2_sum, group=group_2)
    # Communicate the sub-group sums to the entire group. Ranks outside a
    # sub-group still hold that group's un-summed initial value, which is
    # never larger than the sum, so a world-wide MAX leaves the sum everywhere.
    dist.all_reduce(group_1_sum, op=dist.ReduceOp.MAX)
    dist.all_reduce(group_2_sum, op=dist.ReduceOp.MAX)
    print(f"In rank {rank}, {group_1_sum=}, {group_2_sum=}")

    # Finish
    print(f"\nFinishing process with {rank=}, {world_size=}")
  finally:
    # Release the process group even when one of the demos raised.
    dist.destroy_process_group()



# Colab doesn't support mp.spawn(), so we use mp.Process() to create the processes.
# mp.spawn(
#     init_process,
#     args=(1, 1,),
#     nprocs=1,
#     join=True
# )



In [6]:
# @title Process for P2P communication

# 点对点通信是最基础的通信模式，指的是一个进程直接向另一个特定的进程发送或接收数据。这种模式非常灵活，适合需要精确控制通信过程的场景。

# send-receive 模式：在 torch.distributed 中，这种模式可以通过 send() 和 recv() 接口实现。
# 比如 send(tensor, dst=1) 表示进程将数据发送给 rank 为 1 的进程，而 recv(tensor, src=0) 表示接收来自 rank 为 0 的进程的数据。毫无疑问，这是阻塞式的。
# 点对点通信的优点是简单直观，易于理解和控制；缺点是容易导致复杂的代码结构，尤其在需要多进程相互发送数据的情况下，可能会出现死锁或阻塞问题。因此，这种方式更多适用于两个进程之间的信息交换。适合需要精确控制单个进程之间数据交换的场景，通常在系统层通信优化中或模型分片时使用较多。例如在模型并行训练的梯度更新中，点对点通信可以用于梯度的汇总。

def init_process_p2p(rank: int, world_size: int):
  """Pass a tensor along ranks 0 -> 1 -> 2 with blocking send/recv.

  Rank 0 sends [100, 200] to rank 1, rank 1 adds 100 and forwards the result
  to rank 2. Any further rank only holds a zero tensor and waits at the
  barrier. Meant to be the entry point of one worker process per rank;
  MASTER_ADDR and MASTER_PORT must already be set in the environment.

  Args:
    rank: Index of this process within the group.
    world_size: Total number of processes in the group; at least 3.

  Raises:
    ValueError: If world_size is smaller than the three ranks in the chain.
  """
  # The chain addresses ranks 0, 1 and 2 explicitly; with fewer ranks one
  # peer would fail mid-chain while the others keep waiting at the barrier.
  if world_size < 3:
    raise ValueError(f"init_process_p2p needs world_size >= 3, got {world_size}")

  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  try:
    assert rank == dist.get_rank()
    assert world_size == dist.get_world_size()
    dist.barrier()

    if rank == 0:
      tensor = torch.tensor([100, 200], dtype=torch.float32)
      print(f"Initial data in {rank=} is: {tensor}")
      dist.send(tensor, dst=1)
      print("Rank 0 sent data to Rank 1 ...")
    elif rank == 1:
      # Receive buffer: must be pre-allocated with the sender's shape/dtype.
      tensor = torch.zeros(2, dtype=torch.float32)
      print(f"Initial data in {rank=} is: {tensor}")
      dist.recv(tensor, src=0)
      print("Rank 1 received data from Rank 0 ...")

      tensor += 100
      print(f"Modified data in {rank=} is: {tensor}")
      dist.send(tensor, dst=2)
      print("Rank 1 sent data to Rank 2 ...")
    elif rank == 2:
      tensor = torch.zeros(2, dtype=torch.float32)
      print(f"Initial data in {rank=} is: {tensor}")
      dist.recv(tensor, src=1)
      print("Rank 2 received data from Rank 1 ...")
    else:
      # Ranks outside the chain take no part in the exchange.
      tensor = torch.zeros(2, dtype=torch.float32)

    dist.barrier()

    print(f"Data in {rank=} is: {tensor}")

    # Finish
    print(f"\nFinishing process with {rank=}, {world_size=}")
  finally:
    # Release the process group even when the exchange raised.
    dist.destroy_process_group()


In [7]:
# @title Process for Async P2P communication

# 如果需要非阻塞通信，可以使用 isend/irecv
# 也可以使用dist.batch_isend_irecv fuse多个P2P通信操作. 该函数会尝试fuse多个NCCL kernel来提高throughput，并re-order通信顺序以减少deadlock概率。
# 在通信完成前不要修改发送缓冲区(buffer)，在通信完成前不要使用接收缓冲区，必须等待 wait() 完成后才能安全操作相关数据
# 每个异步操作都会占用系统资源，应及时调用 wait() 释放资源
# 避免同时发起过多未完成的异步操作
# 异步操作可能在后台失败，wait() 调用会暴露通信过程中的错误，建议使用 try-finally 确保资源正确清理

import random
import time

def do_other_work(rank: int):
  """Simulate unrelated work: announce it, then sleep 1 to 3 seconds.

  Args:
    rank: Rank of the calling process; only used in the printed message.
  """
  busy_seconds = random.uniform(1.0, 3.0)
  announcement = f"Rank {rank} is doing other work ..."
  print(announcement)
  time.sleep(busy_seconds)

def init_process_p2p_async(rank: int, world_size: int):
  """Pass a tensor along ranks 0 -> 1 -> 2 with non-blocking isend/irecv.

  Same chain as the blocking variant: rank 0 sends [100, 200] to rank 1,
  rank 1 adds 100 and forwards the result to rank 2. Each rank calls
  do_other_work() between starting a transfer and waiting on it. Meant to be
  the entry point of one worker process per rank; MASTER_ADDR and
  MASTER_PORT must already be set in the environment.

  Args:
    rank: Index of this process within the group.
    world_size: Total number of processes in the group; at least 3.

  Raises:
    ValueError: If world_size is smaller than the three ranks in the chain.
  """
  # The chain addresses ranks 0, 1 and 2 explicitly.
  if world_size < 3:
    raise ValueError(
        f"init_process_p2p_async needs world_size >= 3, got {world_size}"
    )

  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  try:
    assert rank == dist.get_rank()
    assert world_size == dist.get_world_size()
    dist.barrier()

    if rank == 0:
      tensor = torch.tensor([100, 200], dtype=torch.float32)
      print(f"Initial data in {rank=} is: {tensor}")
      send_req = dist.isend(tensor, dst=1)
      print("Rank 0 sending data to Rank 1 ...")

      # The send buffer is left untouched until wait() returns.
      do_other_work(rank)

      send_req.wait()
      print("Rank 0 sent data to Rank 1 ...")

    elif rank == 1:
      tensor = torch.zeros(2, dtype=torch.float32)
      print(f"Initial data in {rank=} is: {tensor}")
      recv_req = dist.irecv(tensor, src=0)
      print("Rank 1 is receiving data from Rank 0 ...")

      do_other_work(rank)
      # The receive buffer is only read after wait() returns.
      recv_req.wait()
      print("Rank 1 received data from Rank 0 ...")

      tensor += 100
      print(f"Modified data in {rank=} is: {tensor}")
      send_req = dist.isend(tensor, dst=2)
      print("Rank 1 is sending data to Rank 2 ...")

      do_other_work(rank)

      send_req.wait()
      print("Rank 1 sent data to Rank 2 ...")

    elif rank == 2:
      tensor = torch.zeros(2, dtype=torch.float32)
      print(f"Initial data in {rank=} is: {tensor}")
      recv_req = dist.irecv(tensor, src=1)
      print("Rank 2 is receiving data from Rank 1 ...")

      do_other_work(rank)

      recv_req.wait()
      print("Rank 2 received data from Rank 1 ...")
    else:
      # Ranks outside the chain take no part in the exchange.
      tensor = torch.zeros(2, dtype=torch.float32)

    # Hold every rank here until the whole chain is done, as the blocking
    # variant does; without it an idle or finished rank could destroy the
    # group while its peers are still exchanging data.
    dist.barrier()

    print(f"Data in {rank=} is: {tensor}")

    # Finish
    print(f"\nFinishing process with {rank=}, {world_size=}")
  finally:
    # Release the process group even when a transfer raised.
    dist.destroy_process_group()


In [16]:
# @title all_reduce & all_gather

# 功能定位：
# all_reduce: 对所有进程的数据进行规约（reduction）操作，如求和、取最大值等
# all_gather: 收集所有进程的数据，不进行运算，只是简单合并
# 输出结果：
# all_reduce: 所有进程得到相同的规约结果
# all_gather: 所有进程得到包含所有进程原始数据的完整列表
# 内存使用：
# all_reduce: 输出张量大小与输入相同
# all_gather: 输出张量大小是输入的 world_size 倍
# 适用场景：
# all_reduce：计算分布式损失，梯度同步，计算全局统计信息（如准确率）
# all_gather：获取其他进程的原始数据，分布式评估指标计算，汇总不同进程的中间结果
# 通讯效率：
# all_reduce 通常比 all_gather 更高效，如果只需要得到最终的汇总结果，应优先使用 all_reduce，传输的数据量更小，可以利用树形结构进行规约。

def init_process_all_reduce_n_all_gather(rank, world_size):
  """Contrast all_gather and all_reduce on the same per-rank tensor.

  Each rank builds [rank * 10, rank * 10 + 1]. all_gather collects every
  rank's tensor into a list; all_reduce (sum) folds them into one tensor of
  the input's shape. Rank 0 prints both results. Meant to be the entry point
  of one worker process per rank; MASTER_ADDR and MASTER_PORT must already
  be set in the environment.

  Args:
    rank: Index of this process within the group.
    world_size: Total number of processes in the group.
  """
  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  try:
    assert rank == dist.get_rank()
    assert world_size == dist.get_world_size()
    dist.barrier()

    # Test all_gather
    tensor = torch.tensor([rank * 10, rank * 10 + 1], dtype=torch.float32)
    gathered = [torch.zeros(2, dtype=torch.float) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)

    if rank == 0:
      print("\n=== all_gather result ===")
      print(f"Original {tensor=}")
      print("Gathered tensor:")
      # Label each row with the loop index, not with this process's own
      # rank (which is always 0 inside this branch).
      for src_rank, t in enumerate(gathered):
        print(f"rank {src_rank} data = {t}")
    dist.barrier()

    # Test all_reduce
    # all_reduce writes in place, so work on a copy to keep `tensor` intact.
    reduced_tensor = tensor.clone()
    if rank == 0:
      print(f"Before all_reduce: {reduced_tensor}")
    dist.all_reduce(reduced_tensor, op=dist.ReduceOp.SUM)
    if rank == 0:
      print("\n=== all_reduce result ===")
      print(reduced_tensor)

    # Finish
    print(f"\nFinishing process with {rank=}, {world_size=}")
  finally:
    # Release the process group even when a collective raised.
    dist.destroy_process_group()

In [54]:
# @title Softmax

# Implement a softmax across ranks using all_reduce (MAX for the global max, SUM for the normaliser)
def init_process_softmax(rank, world_size):
  """Compute a softmax whose inputs are spread one logit per rank.

  Each rank draws a single logit in [-10, 10) from a generator seeded with
  its rank. The global maximum (all_reduce MAX) is subtracted before
  exponentiating, the exponentials are summed across ranks (all_reduce SUM),
  and each rank divides its own exponential by that sum. Meant to be the
  entry point of one worker process per rank; MASTER_ADDR and MASTER_PORT
  must already be set in the environment.

  Args:
    rank: Index of this process within the group.
    world_size: Total number of processes in the group.

  Raises:
    AssertionError: If the per-rank probabilities do not sum to 1.
  """
  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  try:
    assert rank == dist.get_rank()
    assert world_size == dist.get_world_size()
    dist.barrier()

    # Different seed per rank so that every rank holds a different logit.
    torch.manual_seed(rank)

    logit = torch.zeros(1, dtype=torch.float32)
    logit.uniform_(-10.0, 10.0)

    # all_reduce works in place, hence the clones.
    max_logit = logit.clone()
    dist.all_reduce(max_logit, op=dist.ReduceOp.MAX)

    # Shifting by the global max keeps exp() from overflowing.
    prob = (logit - max_logit).exp()
    prob_sum = prob.clone()
    dist.all_reduce(prob_sum, op=dist.ReduceOp.SUM)

    norm_prob = prob / prob_sum

    print(f"In {rank=}, logit={logit.item():.3f}, prob={norm_prob.item():.3f}")

    # Assert the softmax sum to 1.
    norm_prob_sum = norm_prob.clone()
    dist.all_reduce(norm_prob_sum, op=dist.ReduceOp.SUM)
    assert torch.allclose(norm_prob_sum, torch.ones(1)), (
        f"softmax probabilities sum to {norm_prob_sum.item()}, expected 1"
    )

    # Finish
    print(f"\nFinishing process with {rank=}, {world_size=}")
  finally:
    # Release the process group even when the check above failed.
    dist.destroy_process_group()

In [69]:
# @title Broadcast

# broadcast 将源进程 src 的张量数据广播到所有其他进程的同名张量
# 接收数据的进程必须预先分配好相同大小的张量空间
# 广播操作是阻塞的，所有进程都需要执行到这行代码才能继续
# 数据会直接在预分配的内存上进行修改，而不是创建新的张量

# Implemented with dist.broadcast, using two different source ranks (0 and 1)
def init_process_broadcast(rank, world_size):
  """Broadcast one tensor from rank 0 and another from rank 1 to all ranks.

  Rank 0 owns the data for tensor_1 ([1, 2, 3]) and rank 1 owns the data for
  tensor_2 ([4, 5]); every other tensor starts as zeros of the matching
  size and is overwritten in place by the broadcast. Meant to be the entry
  point of one worker process per rank; MASTER_ADDR and MASTER_PORT must
  already be set in the environment.

  Args:
    rank: Index of this process within the group.
    world_size: Total number of processes in the group; at least 2.

  Raises:
    ValueError: If world_size is smaller than 2 (rank 1 is a source).
  """
  # The second broadcast uses rank 1 as its source.
  if world_size < 2:
    raise ValueError(
        f"init_process_broadcast needs world_size >= 2, got {world_size}"
    )

  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  try:
    assert rank == dist.get_rank()
    assert world_size == dist.get_world_size()
    dist.barrier()

    if rank == 0:
      tensor_1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
      tensor_2 = torch.zeros(2, dtype=torch.float32)
    elif rank == 1:
      tensor_1 = torch.zeros(3, dtype=torch.float32)
      tensor_2 = torch.tensor([4.0, 5.0], dtype=torch.float32)
    else:
      # Pure receivers: pre-allocated buffers of the right size.
      tensor_1 = torch.zeros(3, dtype=torch.float32)
      tensor_2 = torch.zeros(2, dtype=torch.float32)

    print(f"{rank=}")

    dist.broadcast(tensor_1, src=0)
    dist.broadcast(tensor_2, src=1)

    print(f"{rank=}, {tensor_1=}, {tensor_2=}")

    # Finish
    print(f"\nFinishing process with {rank=}, {world_size=}")
  finally:
    # Release the process group even when a broadcast raised.
    dist.destroy_process_group()


In [71]:
# @title Scatter

def init_process_scatter(rank, world_size):
  """Scatter one small tensor per rank from rank 0.

  Rank 0 builds world_size tensors of two elements, the i-th filled with the
  value i, and scatters them; each rank prints the chunk it received.
  Meant to be the entry point of one worker process per rank; MASTER_ADDR
  and MASTER_PORT must already be set in the environment.

  Args:
    rank: Index of this process within the group.
    world_size: Total number of processes in the group.
  """
  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  try:
    assert rank == dist.get_rank()
    assert world_size == dist.get_world_size()
    dist.barrier()

    # Receive buffer; same shape as one element of scatter_list.
    output_t = torch.zeros(2, dtype=torch.float32)

    if rank == 0:
      scatter_list = [torch.ones(2, dtype=torch.float32) * i for i in range(world_size)]
    else:
      # Only the source rank supplies the list of chunks.
      scatter_list = None

    dist.scatter(output_t, scatter_list, src=0)

    print(f"{rank=}, {output_t=}")
  finally:
    # Tear the process group down, as every other demo in this file does.
    dist.destroy_process_group()

In [72]:
# @title Run the distributed processing

# scatter 是一对多的分发操作，只有源进程(这里是 rank 0)需要准备完整数据

# 其他进程的 scatter_list 必须设为 None，这是 PyTorch 的规定

# 数据必须事先按进程数量切分好，每个进程获得一份

# scatter 操作是同步的，所有进程都会在这里等待，直到通信完成

# 必须指定源进程 (src=0)，表明数据从哪个进程分发出去

# scatter_list 中的每个张量大小必须相同

# 总数据量必须能被进程数整除

# scatter 适合将大数据集划分给多个进程处理

# 相比 broadcast，scatter 可以节省其他进程的内存使用

# scatter 适合：

# 数据并行训练时分发不同的数据批次
# 将大规模数据集分片到多个节点进行处理
# 在参数服务器架构中分发模型参数
# 为什么说 scatter 比起 broadcast 节省空间？

# 考虑一共 4 个进程，需要从 rank 0 发 [1000, 250] 维度的数据给 rank 1, 2, 3，那么用 broadcast 则每张卡上都得有 [1000, 1000] 大小的数据块，然后各自切片。使用 scatter 则只有 rank 0 上会有 [1000, 1000]，其他 rank 上是 [1000, 250]。

# Rendezvous settings read by dist.init_process_group() in every worker.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12359' # You can choose a different port if 12359 is in use

world_size = 4

# One worker process per rank; each runs the selected demo function.
processes = []
for rank in range(world_size):
  p = mp.Process(target=init_process_scatter, args=(rank, world_size))
  p.start()
  processes.append(p)

for p in processes:
  p.join()

# A worker that raised or was killed exits with a non-zero code; surface that
# here instead of letting the cell finish as if everything had succeeded.
failed_ranks = [r for r, proc in enumerate(processes) if proc.exitcode != 0]
if failed_ranks:
  raise RuntimeError(f"Worker process(es) for rank(s) {failed_ranks} failed")

Starting process with rank=0, world_size=4Starting process with rank=1, world_size=4Starting process with rank=2, world_size=4


Starting process with rank=3, world_size=4
rank=3, output_t=tensor([3., 3.])
rank=2, output_t=tensor([2., 2.])rank=1, output_t=tensor([1., 1.])
rank=0, output_t=tensor([0., 0.])

