
The main challenge in pipeline parallelism (PP) seems to be implementing the scheduler/orchestrator.

Simple code examples that focus on the data communication between stages are hard to find, whereas comparable examples for tensor parallelism (TP) are relatively easy to find.

Some good reads:
- https://siboehm.com/articles/22/pipeline-parallel-training & https://github.com/siboehm/ShallowSpeed
- https://torchgpipe.readthedocs.io/en/stable/
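
To make the scheduler idea concrete, here is a minimal pure-Python sketch of a GPipe-style timeline. It is not taken from the linked articles or from PyTorch; `gpipe_timeline` and `bubble_fraction` are hypothetical helpers, and the sketch assumes every forward and backward takes one time unit and that backwards run in the same microbatch order as forwards (real schedulers may differ).

```python
from typing import Dict, List, Tuple

# (time_step, "F" or "B", microbatch index)
Slot = Tuple[int, str, int]


def gpipe_timeline(num_stages: int, n_microbatches: int) -> Dict[int, List[Slot]]:
    """Hypothetical helper: when each stage runs each microbatch.

    Assumes unit-time forward/backward and all forwards before any backward.
    """
    timeline: Dict[int, List[Slot]] = {s: [] for s in range(num_stages)}

    # Forward: stage s can start microbatch mb one step after stage s - 1 did.
    for s in range(num_stages):
        for mb in range(n_microbatches):
            timeline[s].append((s + mb, "F", mb))

    # Backward starts on the last stage once its final forward is done,
    # then flows toward stage 0, one step per stage.
    backward_start = (num_stages - 1) + n_microbatches
    for s in range(num_stages):
        for mb in range(n_microbatches):
            t = backward_start + (num_stages - 1 - s) + mb
            timeline[s].append((t, "B", mb))
    return timeline


def bubble_fraction(num_stages: int, n_microbatches: int) -> float:
    """Idle share of each stage's time under the assumptions above."""
    return (num_stages - 1) / (n_microbatches + num_stages - 1)


if __name__ == "__main__":
    for stage, slots in gpipe_timeline(2, 4).items():
        print(stage, slots)
    print(bubble_fraction(2, 4))
```

With 2 stages and 4 microbatches (the configuration used later in this notebook), each stage is busy for 8 of 10 time steps under these assumptions.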

# Use PyTorch Library

> Tutorial: https://docs.pytorch.org/tutorials/intermediate/pipelining_tutorial.html

Note
- The tutorial only discusses the **forward pass**.
- TODO: find how to implement the **backward pass**.
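
As a starting point for the backward-pass TODO, the toy sketch below shows what has to cross a stage boundary: the forward pass sends an activation downstream, and the backward pass sends the gradient of the loss with respect to that activation back upstream. It uses plain scalars instead of tensors, the function names are hypothetical, and it makes no claim about how `PipelineStage` does this internally.

```python
def stage1_forward(w1: float, x: float) -> float:
    # Activation h that stage 1 would send to stage 2.
    return w1 * x


def stage2_forward_backward(w2: float, h: float, y: float):
    # Stage 2 finishes the forward pass and computes the loss.
    pred = w2 * h
    loss = (pred - y) ** 2
    # Local backward: gradients w.r.t. stage 2's weight and its input h.
    dloss_dpred = 2.0 * (pred - y)
    grad_w2 = dloss_dpred * h
    grad_h = dloss_dpred * w2  # this is what would be sent back to stage 1
    return loss, grad_w2, grad_h


def stage1_backward(x: float, grad_h: float) -> float:
    # Chain rule: dL/dw1 = dL/dh * dh/dw1, with dh/dw1 = x.
    return grad_h * x


if __name__ == "__main__":
    w1, w2, x, y = 0.5, 2.0, 3.0, 1.0
    h = stage1_forward(w1, x)                       # "send" h to stage 2
    loss, grad_w2, grad_h = stage2_forward_backward(w2, h, y)
    grad_w1 = stage1_backward(x, grad_h)            # "recv" grad_h on stage 1
    print(loss, grad_w1, grad_w2)
```

Stage 1 never needs to see the loss or the targets; the gradient of its output is enough to continue the chain rule locally.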

In [2]:
import torch
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class ModelArgs:
   dim: int = 32
   n_layers: int = 8
   n_heads: int = 4
   vocab_size: int = 100

class Transformer(nn.Module):
   def __init__(self, model_args: ModelArgs):
      super().__init__()

      self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)

      # Using a ModuleDict lets us delete layers without affecting names,
      # ensuring checkpoints will correctly save and load.
      self.layers = torch.nn.ModuleDict()
      for layer_id in range(model_args.n_layers):
         self.layers[str(layer_id)] = nn.TransformerDecoderLayer(model_args.dim, model_args.n_heads)

      self.norm = nn.LayerNorm(model_args.dim)
      self.output = nn.Linear(model_args.dim, model_args.vocab_size)

   def forward(self, tokens: torch.Tensor):
      # Handling layers being 'None' at runtime enables easy pipeline splitting
      h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens

      for layer in self.layers.values():
         h = layer(h, h)

      h = self.norm(h) if self.norm is not None else h
      output = self.output(h).clone() if self.output is not None else h
      return output

## Step 1: Partition the Transformer Model

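We split the model into 2 stages:
- Stage 1 (`stage_index=0`): the token embeddings and the first half of the layers
- Stage 2 (`stage_index=1`): the second half of the layers, the final norm, and the output projection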
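
The two-way split above can be generalized to any number of equal stages. The sketch below is a hypothetical helper (`layers_for_stage` is not part of PyTorch) that returns the `ModuleDict` keys a given stage would keep, assuming the layer count divides evenly by the number of stages.

```python
from typing import List


def layers_for_stage(n_layers: int, stage_index: int, num_stages: int) -> List[str]:
    """Hypothetical helper: keys of the layers owned by one pipeline stage.

    Assumes an even split; keys are strings to match the ModuleDict naming.
    """
    if n_layers % num_stages != 0:
        raise ValueError(f"{n_layers=} is not divisible by {num_stages=}")
    if not 0 <= stage_index < num_stages:
        raise ValueError(f"{stage_index=} is out of range for {num_stages=}")
    per_stage = n_layers // num_stages
    start = stage_index * per_stage
    return [str(i) for i in range(start, start + per_stage)]


if __name__ == "__main__":
    for stage in range(2):
        print(stage, layers_for_stage(8, stage, 2))
```

For the 8-layer model in this notebook, stage 0 keeps layers `"0"` to `"3"` and stage 1 keeps `"4"` to `"7"`; every other key would be deleted from that rank's copy of the model.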

What is `PipelineStage`?

> We need to create PipelineStage objects that wrap the part of the model running in that stage. The PipelineStage is responsible for **allocating communication buffers and creating send/recv ops to communicate with its peers**. It manages intermediate buffers e.g. for the outputs of forward that have not been consumed yet, and it provides a utility for running the backwards for the stage model.
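
As an analogy for the send/recv and buffering role described above, the sketch below wires stages together with in-process queues and threads. This is only an illustration under stated assumptions: `run_stage` and `run_pipeline` are hypothetical names, the queues stand in for communication buffers, and the real `PipelineStage` uses `torch.distributed` point-to-point ops between processes rather than anything shown here.

```python
import queue
import threading
from typing import Callable, List, Sequence

_STOP = object()  # sentinel marking the end of the microbatch stream


def run_stage(fn: Callable, inbox: queue.Queue, outbox: queue.Queue) -> None:
    """One stage: receive from the previous peer, compute, send to the next."""
    while True:
        item = inbox.get()          # blocking "recv"
        if item is _STOP:
            outbox.put(_STOP)       # propagate shutdown downstream
            return
        outbox.put(fn(item))        # "send"; the queue buffers unconsumed outputs


def run_pipeline(stage_fns: Sequence[Callable], microbatches: Sequence) -> List:
    # queues[i] feeds stage i; queues[-1] collects the final outputs.
    queues = [queue.Queue() for _ in range(len(stage_fns) + 1)]
    threads = [
        threading.Thread(target=run_stage, args=(fn, queues[i], queues[i + 1]))
        for i, fn in enumerate(stage_fns)
    ]
    for t in threads:
        t.start()

    for mb in microbatches:
        queues[0].put(mb)
    queues[0].put(_STOP)

    outputs = []
    while True:
        item = queues[-1].get()
        if item is _STOP:
            break
        outputs.append(item)

    for t in threads:
        t.join()
    return outputs


if __name__ == "__main__":
    print(run_pipeline([lambda v: v + 1, lambda v: v * 2], [0, 1, 2, 3]))
```

Because each queue is FIFO with a single producer and a single consumer, microbatches leave the pipeline in the order they entered, and stage 1 can work on microbatch 0 while stage 0 is already processing microbatch 1.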

In [3]:
import os
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, ScheduleGPipe

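def manual_model_split(
    model: nn.Module, stage_index: int, num_stages: int, device: torch.device) -> PipelineStage:
  # Keeps half of the layers on each of the two stages; modifies `model` in place.
  assert num_stages == 2, f"{num_stages=}"
  assert stage_index in (0, 1), f"{stage_index=}"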

  if stage_index == 0:
    n_layers = len(model.layers)
    for layer_id in range(n_layers // 2, n_layers):
      del model.layers[str(layer_id)]
    assert len(model.layers) == n_layers // 2

    model.norm = None
    model.output = None

  else:
    n_layers = len(model.layers)
    for layer_id in range(n_layers // 2):
      del model.layers[str(layer_id)]
    assert len(model.layers) == n_layers // 2

    model.tok_embeddings = None

  stage = PipelineStage(
      model,
      stage_index,
      num_stages,
      device,
  )
  return stage

In [4]:
import torch.multiprocessing as mp
import torch.nn.functional as F

def init_process(rank: int, world_size: int):
  assert world_size == 2

  print(f"Starting process with {rank=}, {world_size=}")

  # Use the gloo backend for CPU-based distributed processing
  dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)

  assert rank == dist.get_rank()
  assert world_size == dist.get_world_size()
  dist.barrier()

  # Config
  device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
  pp_group = dist.new_group()
  stage_index = rank
  num_stages = world_size
  num_microbatches = 4
  loss_fn = F.cross_entropy

  # Model
  torch.manual_seed(123)
  model_args = ModelArgs()
  model = Transformer(model_args)

  # Input
  x = torch.ones(32, 500, dtype=torch.long)   # [B, S]
  y = torch.randint(0, model_args.vocab_size, (32, 500), dtype=torch.long)

  # Partition
  def tokenwise_loss_fn(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    B, S, VOCAB_SIZE = outputs.shape
    B_TARGET, S_TARGET = targets.shape
    assert B == B_TARGET
    assert S == S_TARGET

    outputs = outputs.reshape(-1, VOCAB_SIZE)
    targets = targets.reshape(-1)
    return loss_fn(outputs, targets)

  stage = manual_model_split(model, stage_index=stage_index, num_stages=num_stages, device=device)
  schedule = ScheduleGPipe(stage, n_microbatches=num_microbatches, loss_fn=tokenwise_loss_fn)

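  model.to(device)  # nn.Module.to() moves parameters in place
  # Tensor.to() is not in-place: it returns a new tensor, so rebind the names.
  x = x.to(device)
  y = y.to(device)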

  # Step
  if rank == 0:
    schedule.step(x)
  else:
    losses = []
    schedule.step(target=y, losses=losses)
    losses = [loss.item() for loss in losses]
    print(f"{losses=}")

  dist.destroy_process_group()



os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12359' # You can choose a different port if 12359 is in use

world_size = 2

processes = []
for rank in range(world_size):
  p = mp.Process(target=init_process, args=(rank, world_size))
  p.start()
  processes.append(p)

for p in processes:
  p.join()

Starting process with rank=0, world_size=2
Starting process with rank=1, world_size=2
losses=[4.805856704711914, 4.783386707305908, 4.796078205108643, 4.799017906188965]
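
The list holds one loss per microbatch (4 microbatches here). As a rough sanity check, a model that predicted a uniform distribution over the 100-token vocabulary would have a cross-entropy of ln(100), so an untrained model should land in that neighborhood. The snippet below compares the printed values against that baseline; it is a back-of-the-envelope check, not a correctness proof for the pipeline.

```python
import math

vocab_size = 100  # ModelArgs.vocab_size

# Cross-entropy of a uniform prediction over the vocabulary.
expected_uniform_loss = math.log(vocab_size)

# Values copied from the cell output above.
losses = [4.805856704711914, 4.783386707305908, 4.796078205108643, 4.799017906188965]

# How far each microbatch loss sits above the uniform baseline.
gaps = [loss - expected_uniform_loss for loss in losses]

print(f"{expected_uniform_loss=:.4f}")
print([round(g, 3) for g in gaps])
```

All four losses sit roughly 0.2 above the baseline, which is plausible for randomly initialized weights scored against random targets.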
