<a href="https://colab.research.google.com/github/lizhieffe/llm_knowledge/blob/main/PyTorch_DDP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

- Reference: https://zhuanlan.zhihu.com/p/178402798

# Single device (CPU)

In [20]:
import torch

# Device every tensor and module in this cell is moved to.
DEVICE = 'cpu'

# Seed the global RNG so the random data, the initial weights and therefore
# the printed parameters are identical on every run.
torch.manual_seed(123)

# Init
# NOTE: the draw order (inputs -> labels -> model weights) decides which
# values the seeded RNG hands out, so it must stay as is.
# `inputs` (not `input`) so the Python builtin `input` is not shadowed.
inputs = torch.randn(20, 10).to(DEVICE)  # (20, 10)
labels = torch.randn(20, 10).to(DEVICE)  # (20, 10)

loss_fn = torch.nn.MSELoss()

model = torch.nn.Linear(10, 10).to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=1)

for it in range(1):
  # forward
  optimizer.zero_grad()
  outputs = model(inputs)

  # backward
  loss_fn(outputs, labels).backward()
  optimizer.step()

  # check model params
  print(f"In epoch {it}")
  for name, param in model.named_parameters():
    if param.requires_grad:
      print(f"{name=}, {param.data=}")

In epoch 0
name='weight', param.data=tensor([[ 0.0944, -0.1266, -0.1251,  0.1248, -0.0659, -0.0834,  0.1643, -0.1748,
          0.1602,  0.2234],
        [-0.1549, -0.2573,  0.0233, -0.1648,  0.1694, -0.1869, -0.2008, -0.0400,
         -0.0904, -0.2047],
        [ 0.1007,  0.1481,  0.0845,  0.1624,  0.0025,  0.2104, -0.2733,  0.2917,
          0.1926, -0.1286],
        [-0.0789, -0.1440,  0.2480,  0.0249,  0.0506,  0.2015, -0.1588,  0.2132,
         -0.2207, -0.1388],
        [-0.0530, -0.2940, -0.2188, -0.2619, -0.1186, -0.1148, -0.0541,  0.0963,
         -0.0057, -0.1375],
        [-0.2125,  0.2258, -0.2172, -0.1571, -0.1958,  0.2219,  0.2494,  0.0186,
         -0.1865,  0.0155],
        [ 0.0577,  0.0483,  0.0066, -0.1421, -0.0218,  0.0573, -0.0655,  0.0892,
         -0.1105,  0.0901],
        [-0.0250,  0.1317, -0.1487, -0.2025, -0.2007,  0.1900,  0.1915,  0.1950,
          0.1023,  0.0138],
        [-0.0253, -0.2263, -0.0107, -0.1052, -0.0309,  0.1138, -0.0093, -0.1307,
          

# DDP

- Here we use CPUs instead of GPUs to reduce the system requirements.


## Try 1 - Manually distribute the training data.

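- We expect the trained model weights to be the same as in the single-CPU run above: rank 0 draws the same seeded data and initial weights, each rank computes the mean-squared error on its own shard, and DDP averages the gradients across ranks, which equals the gradient over the full 20-sample batch.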

In [23]:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

# Device the tensors and the model in this cell are moved to; the workers
# below use the gloo backend for CPU-based distributed processing.
DEVICE = 'cpu'

def run_single_process(rank: int, world_size: int):
  """Worker entry point: run one DDP training step on this rank's data shard.

  Rank 0 creates a seeded 20-sample train set, splits it into `world_size`
  equal shards and scatters one shard to every rank. Each rank then runs a
  single forward/backward/optimizer step through a DDP-wrapped linear model,
  and rank 0 prints the resulting parameters.

  Args:
    rank: Index of this process within the process group.
    world_size: Total number of processes in the process group.

  Raises:
    ValueError: If the 20 samples cannot be split evenly across `world_size`.
  """
  print(f"Starting process with {rank=}, {world_size=}")

  num_samples = 20
  if world_size <= 0 or num_samples % world_size != 0:
    raise ValueError(
        f"{num_samples} samples cannot be split evenly across {world_size=}")
  # One shard per rank (previously hard-coded as 20 // 4).
  split_data_size = num_samples // world_size

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  try:
    assert rank == dist.get_rank()
    assert world_size == dist.get_world_size()
    dist.barrier()

    torch.manual_seed(123)

    # Create the train set.
    # Only rank 0 draws the data; the draw order (inputs -> targets -> model
    # weights further down) decides which values the seeded RNG hands out.
    if rank == 0:
      inputs = torch.randn(num_samples, 10)
      inputs_split_list = list(torch.split(inputs, split_data_size, dim=0))
      assert world_size == len(inputs_split_list)

      targets = torch.randn(num_samples, 10)
      targets_split_list = list(torch.split(targets, split_data_size, dim=0))
      assert world_size == len(targets_split_list)
    else:
      # Non-source ranks pass no scatter list.
      inputs_split_list = None
      targets_split_list = None

    # Split the train set and send to the distributed workers.
    inputs_split = torch.zeros((split_data_size, 10), dtype=torch.float32)
    dist.scatter(inputs_split, inputs_split_list, src=0)
    # Tensor.to() returns the moved tensor; it has to be re-assigned.
    inputs_split = inputs_split.to(DEVICE)

    targets_split = torch.zeros((split_data_size, 10), dtype=torch.float32)
    dist.scatter(targets_split, targets_split_list, src=0)
    targets_split = targets_split.to(DEVICE)

    # Init the model
    model = torch.nn.Linear(10, 10).to(DEVICE)
    ddp_model = DDP(model, device_ids=None)
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1)

    # forward
    optimizer.zero_grad()
    outputs = ddp_model(inputs_split)

    # backward
    loss_fn(outputs, targets_split).backward()
    optimizer.step()

    # check model params
    if rank == 0:
      print("After backward")
      for name, param in ddp_model.named_parameters():
        if param.requires_grad:
          print(f"{name=}, {param.data=}")
  finally:
    # Tear the process group down even when the step above raised.
    dist.destroy_process_group()

# Rendezvous address/port for the workers' dist.init_process_group() call.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355' # You can choose a different port if 12355 is in use

world_size = 4

# Start one worker process per rank.
processes = []
for rank in range(world_size):
  p = mp.Process(target=run_single_process, args=(rank, world_size))
  p.start()
  processes.append(p)

# Wait for every worker to finish.
for p in processes:
  p.join()

# A worker that raised exits with a non-zero code; surface that here instead
# of letting the cell finish as if training had succeeded.
failed_ranks = [r for r, proc in enumerate(processes) if proc.exitcode != 0]
if failed_ranks:
  raise RuntimeError(f"Worker process(es) for rank(s) {failed_ranks} failed")

Starting process with rank=0, world_size=4
Starting process with rank=1, world_size=4
Starting process with rank=2, world_size=4
Starting process with rank=3, world_size=4
After backward
name='module.weight', param.data=tensor([[ 0.0944, -0.1266, -0.1251,  0.1248, -0.0659, -0.0834,  0.1643, -0.1748,
          0.1602,  0.2234],
        [-0.1549, -0.2573,  0.0233, -0.1648,  0.1694, -0.1869, -0.2008, -0.0400,
         -0.0904, -0.2047],
        [ 0.1007,  0.1481,  0.0845,  0.1624,  0.0025,  0.2104, -0.2733,  0.2917,
          0.1926, -0.1286],
        [-0.0789, -0.1440,  0.2480,  0.0249,  0.0506,  0.2015, -0.1588,  0.2132,
         -0.2207, -0.1388],
        [-0.0530, -0.2940, -0.2188, -0.2619, -0.1186, -0.1148, -0.0541,  0.0963,
         -0.0057, -0.1375],
        [-0.2125,  0.2258, -0.2172, -0.1571, -0.1958,  0.2219,  0.2494,  0.0186,
         -0.1865,  0.0155],
        [ 0.0577,  0.0483,  0.0066, -0.1421, -0.0218,  0.0573, -0.0655,  0.0892,
         -0.1105,  0.0901],
        [-0.0250,

## Try 2 - use distributed sampler

In [54]:
# Imported here so this cell also runs on a fresh kernel (Restart & Run All).
import torch
import torch.nn as nn
import torch.nn.functional as F


# from the official doc: https://docs.pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
class ToyModel(nn.Module):
  """Small CNN: two conv + max-pool stages followed by three linear layers.

  The first convolution takes 3 input channels and the last linear layer
  produces 10 outputs. `fc1` expects 16 * 5 * 5 features per sample, which
  is what the two 5x5 conv / 2x2 pool stages produce for 32x32 inputs.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    """Map a batch of images `x` to a (batch, 10) tensor of class scores."""
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    # flatten all dimensions except batch; unlike view(-1, 16 * 5 * 5) this
    # never folds a wrong spatial size into the batch dimension.
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

In [56]:
# @title single CPU version

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision

from torch.nn.parallel import DistributedDataParallel as DDP

# Device the model and every batch are moved to.
DEVICE = 'cpu'

# Seed the global RNG so the initial weights (and with them the logged
# losses) are reproducible; the distributed version uses the same seed.
torch.manual_seed(123)

# Init the model
model = ToyModel().to(DEVICE)
model.train()
loss_fn = torch.nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Dataset
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
download_path = "./data"
my_trainset = torchvision.datasets.CIFAR10(root=download_path, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(my_trainset, batch_size=16)

for epoch in range(1):
  for it, (data, label) in enumerate(trainloader):
    # Keep the batch on the same device as the model.
    data, label = data.to(DEVICE), label.to(DEVICE)

    # forward
    optimizer.zero_grad()
    outputs = model(data)

    # backward
    loss = loss_fn(outputs, label)
    loss.backward()

    # Log the loss of every 200th batch.
    if it % 200 == 0:
      print(f"{epoch=}, {it=}, loss={loss.item():.3f}")

    optimizer.step()

epoch=0, it=0, loss=2.311
epoch=0, it=200, loss=2.296
epoch=0, it=400, loss=2.280
epoch=0, it=600, loss=2.284
epoch=0, it=800, loss=2.257
epoch=0, it=1000, loss=2.307
epoch=0, it=1200, loss=2.150
epoch=0, it=1400, loss=2.144
epoch=0, it=1600, loss=2.273
epoch=0, it=1800, loss=1.866
epoch=0, it=2000, loss=1.904
epoch=0, it=2200, loss=1.905
epoch=0, it=2400, loss=1.841
epoch=0, it=2600, loss=1.785
epoch=0, it=2800, loss=1.831
epoch=0, it=3000, loss=2.042


In [57]:
# @title dist version

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision

from torch.nn.parallel import DistributedDataParallel as DDP

# Device the model and loss in this cell are moved to; the workers below use
# the gloo backend for CPU-based distributed processing.
DEVICE = 'cpu'

class ToyModel(nn.Module):
  """Small CNN: two conv + max-pool stages followed by three linear layers.

  The first convolution takes 3 input channels and the last linear layer
  produces 10 outputs. `fc1` expects 16 * 5 * 5 features per sample, which
  is what the two 5x5 conv / 2x2 pool stages produce for 32x32 inputs.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    """Map a batch of images `x` to a (batch, 10) tensor of class scores."""
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    # Flatten everything except the batch dimension. view(-1, 16 * 5 * 5)
    # would silently fold a wrong spatial size into the batch dimension;
    # with flatten the mismatch surfaces as a shape error in fc1.
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

def run_single_process(rank: int, world_size: int):
  """Worker entry point: train ToyModel on CIFAR-10 with DDP on one rank.

  Joins the gloo process group, wraps ToyModel in DDP, and trains for 10
  epochs on the part of the CIFAR-10 train set that DistributedSampler
  assigns to this rank. Rank 0 logs the loss of every 200th batch.

  Args:
    rank: Index of this process within the process group.
    world_size: Total number of processes in the process group.
  """
  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  try:
    assert rank == dist.get_rank()
    assert world_size == dist.get_world_size()
    dist.barrier()

    torch.manual_seed(123)

    # Init the model
    model = ToyModel().to(DEVICE)
    ddp_model = DDP(model, device_ids=None)
    ddp_model.train()
    loss_fn = torch.nn.CrossEntropyLoss().to(DEVICE)
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)

    # Dataset
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # One shared copy of the dataset instead of one download per rank:
    # rank 0 fetches it while the other ranks wait at the barrier, so no two
    # processes ever write to the same directory at the same time.
    download_path = "./data"
    if rank == 0:
      torchvision.datasets.CIFAR10(root=download_path, train=True, download=True)
    dist.barrier()
    my_trainset = torchvision.datasets.CIFAR10(root=download_path, train=True, download=False, transform=transform)
    train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)
    trainloader = torch.utils.data.DataLoader(my_trainset, batch_size=16, sampler=train_sampler)

    for epoch in range(10):
      # Tell the sampler which epoch this is before iterating the loader.
      train_sampler.set_epoch(epoch)

      for it, (data, label) in enumerate(trainloader):
        # Keep the batch on the same device as the model.
        data, label = data.to(DEVICE), label.to(DEVICE)

        # forward
        optimizer.zero_grad()
        outputs = ddp_model(data)

        # backward
        loss = loss_fn(outputs, label)
        loss.backward()

        if rank == 0 and it % 200 == 0:
          print(f"{epoch=}, {it=}, loss={loss.item():.3f}")

        optimizer.step()
  finally:
    # Tear the process group down even when training raised.
    dist.destroy_process_group()




# Rendezvous address/port for the workers' dist.init_process_group() call.
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355' # You can choose a different port if 12355 is in use

world_size = 4

# Start one worker process per rank.
processes = []
for rank in range(world_size):
  p = mp.Process(target=run_single_process, args=(rank, world_size))
  p.start()
  processes.append(p)

# Wait for every worker to finish.
for p in processes:
  p.join()

# A worker that raised exits with a non-zero code; surface that here instead
# of letting the cell finish as if training had succeeded.
failed_ranks = [r for r, proc in enumerate(processes) if proc.exitcode != 0]
if failed_ranks:
  raise RuntimeError(f"Worker process(es) for rank(s) {failed_ranks} failed")

Starting process with rank=0, world_size=4Starting process with rank=1, world_size=4

Starting process with rank=2, world_size=4
Starting process with rank=3, world_size=4
epoch=0, it=0, loss=2.279
epoch=0, it=200, loss=2.307
epoch=0, it=400, loss=2.318
epoch=0, it=600, loss=2.288
epoch=1, it=0, loss=2.294
epoch=1, it=200, loss=2.279
epoch=1, it=400, loss=2.239
epoch=1, it=600, loss=2.097
epoch=2, it=0, loss=2.203
epoch=2, it=200, loss=2.002
epoch=2, it=400, loss=1.780
epoch=2, it=600, loss=1.853
epoch=3, it=0, loss=2.172
epoch=3, it=200, loss=1.627
epoch=3, it=400, loss=1.575
epoch=3, it=600, loss=1.668
epoch=4, it=0, loss=1.775
epoch=4, it=200, loss=1.398
epoch=4, it=400, loss=1.625
epoch=4, it=600, loss=1.846
epoch=5, it=0, loss=1.719
epoch=5, it=200, loss=1.477
epoch=5, it=400, loss=2.220
epoch=5, it=600, loss=1.522
epoch=6, it=0, loss=1.281
epoch=6, it=200, loss=1.594
epoch=6, it=400, loss=1.658
epoch=6, it=600, loss=1.494
epoch=7, it=0, loss=1.261
epoch=7, it=200, loss=1.528
epoc