<a href="https://colab.research.google.com/github/lizhieffe/llm_knowledge/blob/main/%5BDist%5D_PyTorch_DDP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- Reference: https://zhuanlan.zhihu.com/p/178402798
- Good read about the mechanism:
  - https://zhuanlan.zhihu.com/p/187610959

# Single GPU (run on CPU here)

In [20]:
import torch

DEVICE = 'cpu'

# Seed the global RNG so the data, the initial weights and therefore the
# printed parameters are the same on every run.
torch.manual_seed(123)

# Init
# Named `inputs` rather than `input` so the notebook's global namespace does
# not shadow the Python builtin `input()` for every later cell.
inputs = torch.randn(20, 10).to(DEVICE)  # (20, 10)
labels = torch.randn(20, 10).to(DEVICE)  # (20, 10)

loss_fn = torch.nn.MSELoss()

model = torch.nn.Linear(10, 10).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=1)

for it in range(1):
  # forward
  optimizer.zero_grad()
  outputs = model(inputs)

  # backward
  loss_fn(outputs, labels).backward()
  optimizer.step()

  # check model params
  print(f"In epoch {it}")
  for name, param in model.named_parameters():
    if param.requires_grad:
      print(f"{name=}, {param.data=}")

In epoch 0
name='weight', param.data=tensor([[ 0.0944, -0.1266, -0.1251,  0.1248, -0.0659, -0.0834,  0.1643, -0.1748,
          0.1602,  0.2234],
        [-0.1549, -0.2573,  0.0233, -0.1648,  0.1694, -0.1869, -0.2008, -0.0400,
         -0.0904, -0.2047],
        [ 0.1007,  0.1481,  0.0845,  0.1624,  0.0025,  0.2104, -0.2733,  0.2917,
          0.1926, -0.1286],
        [-0.0789, -0.1440,  0.2480,  0.0249,  0.0506,  0.2015, -0.1588,  0.2132,
         -0.2207, -0.1388],
        [-0.0530, -0.2940, -0.2188, -0.2619, -0.1186, -0.1148, -0.0541,  0.0963,
         -0.0057, -0.1375],
        [-0.2125,  0.2258, -0.2172, -0.1571, -0.1958,  0.2219,  0.2494,  0.0186,
         -0.1865,  0.0155],
        [ 0.0577,  0.0483,  0.0066, -0.1421, -0.0218,  0.0573, -0.0655,  0.0892,
         -0.1105,  0.0901],
        [-0.0250,  0.1317, -0.1487, -0.2025, -0.2007,  0.1900,  0.1915,  0.1950,
          0.1023,  0.0138],
        [-0.0253, -0.2263, -0.0107, -0.1052, -0.0309,  0.1138, -0.0093, -0.1307,
          

# DDP

- Here we use CPU instead of GPU to reduce the system requirement.


## Try 1 - Manually distribute the training data.

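- We expect the trained model weights to be the same (up to floating-point rounding) as in the single-CPU scenario: every rank computes a mean MSE loss over an equal-sized shard, and DDP averages the gradients across ranks, which equals the gradient of the mean loss over the full batch.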

In [23]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

DEVICE = 'cpu'

def run_single_process(rank: int, world_size: int):
  """Run one DDP worker: receive a shard of the data and take one SGD step.

  Rank 0 builds the full (20, 10) inputs/targets and scatters equal row-shards
  to every rank. Each rank then runs one forward/backward/step through a
  DDP-wrapped `Linear(10, 10)`, and rank 0 prints the resulting parameters.

  Args:
    rank: Index of this process within the process group.
    world_size: Total number of processes in the group. Must divide the
      number of samples (20) evenly.
  """
  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  assert rank == dist.get_rank()
  assert world_size == dist.get_world_size()
  dist.barrier()

  num_samples = 20
  feature_dim = 10

  # One equally sized shard per rank. Derived from `world_size` instead of a
  # hard-coded `20 // 4`, so the function works for any world size that
  # divides the sample count.
  assert num_samples % world_size == 0, (
      f"{num_samples=} must be divisible by {world_size=}")
  split_data_size = num_samples // world_size

  # Seed before any random op: rank 0 then draws inputs, targets and the
  # model weights in the same order as the single-process version.
  torch.manual_seed(123)

  # Create the train set. Only the source rank supplies the scatter lists.
  if rank == 0:
    inputs = torch.randn(num_samples, feature_dim)
    inputs_split_list = list(torch.split(inputs, split_data_size, dim=0))
    assert world_size == len(inputs_split_list)

    targets = torch.randn(num_samples, feature_dim)
    targets_split_list = list(torch.split(targets, split_data_size, dim=0))
    assert world_size == len(targets_split_list)
  else:
    inputs_split_list = None
    targets_split_list = None

  # Split the train set and send to the distributed workers.
  # `Tensor.to` is not in-place, so its result must be assigned back;
  # the original discarded it, leaving the shards on their original device.
  inputs_split = torch.zeros((split_data_size, feature_dim), dtype=torch.float32)
  dist.scatter(inputs_split, inputs_split_list, src=0)
  inputs_split = inputs_split.to(DEVICE)

  targets_split = torch.zeros((split_data_size, feature_dim), dtype=torch.float32)
  dist.scatter(targets_split, targets_split_list, src=0)
  targets_split = targets_split.to(DEVICE)

  # Init the model
  model = torch.nn.Linear(feature_dim, feature_dim).to(DEVICE)
  ddp_model = DDP(model, device_ids=None)
  loss_fn = torch.nn.MSELoss()
  optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1)

  # forward
  optimizer.zero_grad()
  outputs = ddp_model(inputs_split)

  # backward
  loss_fn(outputs, targets_split).backward()

  optimizer.step()

  # check model params (printed once, from rank 0 only)
  if rank == 0:
    print("After backward")
    for name, param in ddp_model.named_parameters():
      if param.requires_grad:
        print(f"{name=}, {param.data=}")

  dist.destroy_process_group()

# Rendezvous address and port, handed to the workers through the environment.
os.environ.update({
    'MASTER_ADDR': 'localhost',
    'MASTER_PORT': '12355',  # You can choose a different port if 12355 is in use
})

world_size = 4

# Build one worker process per rank, start them all, then wait for every
# worker to exit before the cell returns.
processes = [
    mp.Process(target=run_single_process, args=(worker_rank, world_size))
    for worker_rank in range(world_size)
]

for worker in processes:
  worker.start()

for worker in processes:
  worker.join()

Starting process with rank=0, world_size=4
Starting process with rank=1, world_size=4
Starting process with rank=2, world_size=4
Starting process with rank=3, world_size=4
After backward
name='module.weight', param.data=tensor([[ 0.0944, -0.1266, -0.1251,  0.1248, -0.0659, -0.0834,  0.1643, -0.1748,
          0.1602,  0.2234],
        [-0.1549, -0.2573,  0.0233, -0.1648,  0.1694, -0.1869, -0.2008, -0.0400,
         -0.0904, -0.2047],
        [ 0.1007,  0.1481,  0.0845,  0.1624,  0.0025,  0.2104, -0.2733,  0.2917,
          0.1926, -0.1286],
        [-0.0789, -0.1440,  0.2480,  0.0249,  0.0506,  0.2015, -0.1588,  0.2132,
         -0.2207, -0.1388],
        [-0.0530, -0.2940, -0.2188, -0.2619, -0.1186, -0.1148, -0.0541,  0.0963,
         -0.0057, -0.1375],
        [-0.2125,  0.2258, -0.2172, -0.1571, -0.1958,  0.2219,  0.2494,  0.0186,
         -0.1865,  0.0155],
        [ 0.0577,  0.0483,  0.0066, -0.1421, -0.0218,  0.0573, -0.0655,  0.0892,
         -0.1105,  0.0901],
        [-0.0250,

## Try 2 - use distributed sampler

In [60]:
# Number of passes over the training set.
EPOCHS = 2

# This is the global batch size. In DDP, the per-process batch sizes must sum to this number.
BATCH_SIZE = 16

# Number of worker processes launched for the distributed run.
WORLD_SIZE = 4

# Imported here because this cell runs before the cell that holds the
# notebook's other imports; without these lines `nn` and `F` are undefined
# on a fresh kernel.
import torch
import torch.nn as nn
import torch.nn.functional as F


# from the official doc: https://docs.pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
class ToyModel(nn.Module):
  """Small CNN mapping (N, 3, 32, 32) images to (N, 10) logits.

  Two conv + max-pool stages followed by three fully connected layers.
  With 32x32 inputs the feature map entering `fc1` is 16 channels of 5x5,
  which is where the `16 * 5 * 5` input size comes from.

  This is the single definition of the model. The notebook previously
  defined the class twice, the first definition being silently shadowed by
  the second; the surviving (second) version is kept here.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    """Return the logits of shape (N, 10) for a batch `x`."""
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = torch.flatten(x, 1)  # flatten all dimensions except batch
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

# Imported here because this cell runs before the cell that holds the
# notebook's other imports; without it `torchvision` is undefined on a
# fresh kernel.
import torchvision

# Preprocessing applied to every CIFAR-10 image: convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [61]:
# @title single CPU version

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision

from torch.nn.parallel import DistributedDataParallel as DDP

DEVICE = 'cpu'

# Seed before building the model so the initial weights (and the printed
# losses) are reproducible. 123 is the same seed the distributed version of
# this experiment uses.
torch.manual_seed(123)

# Init the model
model = ToyModel().to(DEVICE)
model.train()
loss_fn = torch.nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Dataset
download_path = "./data"
my_trainset = torchvision.datasets.CIFAR10(root=download_path, train=True, download=True, transform=dataset_transform)
# No sampler and no shuffle: batches are read in dataset order.
trainloader = torch.utils.data.DataLoader(my_trainset, batch_size=BATCH_SIZE)

for epoch in range(EPOCHS):
  for it, (data, label) in enumerate(trainloader):
    # Keep the batch on the same device as the model. A no-op on CPU, but
    # required if DEVICE is ever changed.
    data = data.to(DEVICE)
    label = label.to(DEVICE)

    # forward
    optimizer.zero_grad()
    outputs = model(data)

    # backward
    loss = loss_fn(outputs, label)
    loss.backward()

    if it % 200 == 0:
      print(f"{epoch=}, {it=}, loss={loss.item():.3f}")

    optimizer.step()

epoch=0, it=0, loss=2.320
epoch=0, it=200, loss=2.296
epoch=0, it=400, loss=2.289
epoch=0, it=600, loss=2.263
epoch=0, it=800, loss=2.210
epoch=0, it=1000, loss=2.294
epoch=0, it=1200, loss=2.180
epoch=0, it=1400, loss=2.228
epoch=0, it=1600, loss=2.208
epoch=0, it=1800, loss=1.816
epoch=0, it=2000, loss=1.924
epoch=0, it=2200, loss=1.912
epoch=0, it=2400, loss=1.899
epoch=0, it=2600, loss=1.794
epoch=0, it=2800, loss=1.813
epoch=0, it=3000, loss=1.949


In [62]:
# @title dist version

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision

from torch.nn.parallel import DistributedDataParallel as DDP

DEVICE = 'cpu'

def run_single_process(rank: int, world_size: int):
  """Run one DDP worker that trains `ToyModel` on its shard of CIFAR-10.

  Each worker joins the gloo process group, wraps the model in DDP, and
  iterates over the part of the train set that `DistributedSampler` assigns
  to it. Only rank 0 prints the (rank-local) loss.

  Args:
    rank: Index of this process within the process group.
    world_size: Total number of processes in the group. Must divide the
      global `BATCH_SIZE` evenly.
  """
  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing.
  # The `world_size` argument is used throughout (not the global WORLD_SIZE)
  # so the function honours whatever its caller passed in.
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  assert rank == dist.get_rank()
  assert world_size == dist.get_world_size()
  dist.barrier()

  torch.manual_seed(123)

  # Init the model
  #
  # What DDP does at construction time (the `DDP(model)` call):
  # 1. Sends parameters and buffers from the master process to all other
  #    processes so that every process holds the same state.
  #    Note: this is how DDP guarantees an identical initial state, so do not
  #    modify the model after this point (no adding, changing or removing
  #    parameters or buffers).
  # 2. (Possibly) if a node has several devices, creates a model replica on
  #    each device (similar to DP).
  # 3. Groups the parameters into buckets; neighbouring parameters land in
  #    the same bucket.
  #    Note: this is for speed. During gradient communication, a bucket whose
  #    gradients are ready is communicated immediately instead of waiting for
  #    all gradients to be computed.
  # 4. Creates the reducer and registers a gradient-averaging hook on every
  #    parameter.
  #    Note: this is implemented in C++ (reducer.h).
  # 5. (Possibly) prepares for SyncBN layers.
  #
  # What the DDP model does in every step:
  # 1. Sample data: take one batch from the dataloader
  #    (`for data, label in dataloader`).
  #    Note: because the dataloader uses DistributedSampler, the processes do
  #    not see overlapping data. For DDP to match single-device results, one
  #    DDP epoch must be equivalent, data-wise, to one single-device epoch.
  # 2. Forward pass (`prediction = model(data)`):
  #    2.1 Synchronise state across processes.
  #        2.1.1 (Possibly) in single-process multi-device mode, sync
  #              parameters and buffers between the devices of the process.
  #        2.1.2 Sync buffers between processes.
  #    2.2 Only then run the actual forward computation.
  #    2.3 (Possibly) when DDP's find_unused_parameter option is true, a
  #        traversal at the end of forward marks every unused parameter as
  #        ready ahead of time.
  #        Note: find_unused_parameter defaults to false because it slows
  #        training down.
  # 3. Compute gradients (`loss.backward()`):
  #    3.1 Outside the reducer: each process computes its gradients backwards.
  #        3.1.1 Note: gradients are computed in reverse, so the last
  #              parameters get their gradients first.
  #    3.2 Outside the reducer: when a parameter's gradient is ready, its
  #        registered grad hook fires and marks it as ready in the reducer.
  #    3.3 Inside the reducer: once every parameter of a bucket is ready, the
  #        reducer launches an asynchronous all-reduce that averages the
  #        gradients of that bucket.
  #        Notes:
  #        3.3.1 Buckets are also processed in order, the reverse of the
  #              parameter order: the bucket holding the first-registered
  #              parameters comes last.
  #        3.3.2 So when building a module, register the parameters that are
  #              computed first at the front and later ones at the back.
  #              Otherwise the reducer stalls on a bucket and training slows.
  #              3.3.2.1 "Registering parameters" simply means creating the
  #                      layers, i.e. create layers in computation order.
  #    3.4 Inside the reducer: only after all buckets have been averaged does
  #        the reducer write the averaged gradients into `parameter.grad`.
  #        Note: waiting for all buckets seems unnecessary here; this may
  #        need to be checked against the source code.
  # 4. The optimizer applies the gradients (`optimizer.step()`).
  #    Note: this step is unrelated to DDP.
  model = ToyModel().to(DEVICE)
  ddp_model = DDP(model, device_ids=None)
  ddp_model.train()
  loss_fn = torch.nn.CrossEntropyLoss().to(DEVICE)

  # Init the optimizer.
  #
  # The optimizer is independent of DDP, so DDP does not guarantee that the
  # optimizer's initial state is the same on every process.
  # Most official optimizers do produce identical initial state when built
  # from models in identical state, so creating the optimizer after the DDP
  # model is enough and no extra work is needed.
  # A custom optimizer has to ensure this consistency itself.
  # Here the optimizer is created after DDP, when the model parameters are
  # already identical across processes, so its initial state is identical too.
  optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)

  # Dataset
  # NOTE(review): each rank downloads into its own directory, presumably so
  # concurrent workers do not write the same files - confirm.
  download_path = f"./data_{rank}"
  my_trainset = torchvision.datasets.CIFAR10(root=download_path, train=True, download=True, transform=dataset_transform)
  # DDP: DistributedSampler hides the sharding details; just use it.
  train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
  # DDP: the DataLoader's batch_size is the per-process batch size, so the
  # global batch size is this value times the number of processes.
  assert BATCH_SIZE % world_size == 0
  trainloader = torch.utils.data.DataLoader(my_trainset, batch_size=BATCH_SIZE // world_size, sampler=train_sampler)

  for epoch in range(EPOCHS):
    # The distributed training loss is not going to be the same as the single device training.
    # The reason is that the distributed sampler uses "epoch" as the sampling seed in each host.
    #
    # Why does adding a DistributedSampler make a dataloader work with DDP?
    # It assigns non-overlapping parts of the dataset to the processes.
    # The dataset is shuffled every epoch, so to keep the shuffled order
    # consistent across processes they all use the same random seed.
    #
    # DistributedSampler uses the current epoch as that seed, which gives a
    # different shuffle per epoch. So call `set_epoch` on the sampler before
    # every epoch, otherwise the data is not reshuffled.
    train_sampler.set_epoch(epoch)

    for it, (data, label) in enumerate(trainloader):
      # Keep the batch on the same device as the model. A no-op on CPU, but
      # required if DEVICE is ever changed.
      data = data.to(DEVICE)
      label = label.to(DEVICE)

      # forward
      optimizer.zero_grad()
      outputs = ddp_model(data)

      # backward
      loss = loss_fn(outputs, label)
      loss.backward()

      # This is rank 0's loss on its own shard of the batch.
      if rank == 0 and it % 200 == 0:
        print(f"{epoch=}, {it=}, loss={loss.item():.3f}")

      optimizer.step()

  dist.destroy_process_group()



# Rendezvous address and port, handed to the workers through the environment.
os.environ.update({
    'MASTER_ADDR': 'localhost',
    'MASTER_PORT': '12355',  # You can choose a different port if 12355 is in use
})

# Build one worker process per rank, start them all, then wait for every
# worker to exit before the cell returns.
processes = [
    mp.Process(target=run_single_process, args=(worker_rank, WORLD_SIZE))
    for worker_rank in range(WORLD_SIZE)
]

for worker in processes:
  worker.start()

for worker in processes:
  worker.join()

Starting process with rank=0, world_size=4
Starting process with rank=1, world_size=4
Starting process with rank=2, world_size=4
Starting process with rank=3, world_size=4
epoch=0, it=0, loss=2.284
epoch=0, it=200, loss=2.244
epoch=0, it=400, loss=2.296
epoch=0, it=600, loss=2.322
epoch=0, it=800, loss=2.286
epoch=0, it=1000, loss=2.185
epoch=0, it=1200, loss=2.227
epoch=0, it=1400, loss=1.905
epoch=0, it=1600, loss=2.099
epoch=0, it=1800, loss=1.635
epoch=0, it=2000, loss=2.344
epoch=0, it=2200, loss=1.547
epoch=0, it=2400, loss=2.609
epoch=0, it=2600, loss=2.036
epoch=0, it=2800, loss=1.854
epoch=0, it=3000, loss=1.633
epoch=1, it=0, loss=1.158
epoch=1, it=200, loss=1.877
epoch=1, it=400, loss=2.356
epoch=1, it=600, loss=1.611
epoch=1, it=800, loss=1.564
epoch=1, it=1000, loss=1.683
epoch=1, it=1200, loss=0.823
epoch=1, it=1400, loss=1.330
epoch=1, it=1600, loss=1.527
epoch=1, it=1800, loss=2.017
epoch=1, it=2000, loss=1.623
epoch=1, it=2200, loss=0.835
epoch=1, it=2400, loss=1.886
e