In [6]:
import os
import torch
import torch.nn as nn
import torch.distributed as dist

In [7]:
def setup(rank=None, world_size=None, backend='nccl'):
    """Initialise the default torch.distributed process group for this worker.

    Args:
        rank: Global rank of this process. Defaults to ``int(os.environ['RANK'])``.
        world_size: Total number of processes. Defaults to
            ``int(os.environ['WORLD_SIZE'])``.
        backend: torch.distributed backend name (default ``'nccl'``).

    Returns:
        ``(rank, world_size)`` as ints, so callers need not re-parse the
        environment.

    Notes:
        ``MASTER_ADDR``/``MASTER_PORT`` default to ``localhost:12355``, but
        values already present in the environment (e.g. exported by a
        launcher) are kept rather than overwritten.
        The CUDA device index is the *global* rank, which assumes all ranks
        run on a single node.
    """
    # Environment variables are strings; init_process_group and set_device
    # need integers.
    rank = int(os.environ['RANK']) if rank is None else int(rank)
    world_size = (
        int(os.environ['WORLD_SIZE']) if world_size is None else int(world_size)
    )

    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Select the device before creating the group so NCCL binds this rank
    # to its own GPU instead of defaulting to device 0.
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size
    )

    return rank, world_size

In [8]:
def cleanup():
    """Destroy the default process group if one is currently initialised.

    Safe to call more than once, or after a ``setup()`` that failed before
    the group was created. Unguarded, ``destroy_process_group`` raises when
    no default group exists.
    """
    # is_initialized only exists when torch was built with distributed support.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

In [9]:
class PipelineStage(nn.Module):
    """The slice of the model owned by one pipeline rank.

    The wrapped modules are applied one after another, in the order given.
    With no modules, the input is returned unchanged.
    """

    def __init__(self, layers):
        super().__init__()
        # ModuleList registers every layer, so its parameters show up in
        # .parameters() and follow .to(...) / state_dict().
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for module in self.layers:
            x = module(x)
        return x

In [10]:
def create_stages(hidden_stages, n_layers, world_size):
    """Build a toy MLP and split it into one ``PipelineStage`` per rank.

    Args:
        hidden_stages: Feature width of every ``nn.Linear`` (input width ==
            output width), i.e. the hidden size.
        n_layers: Total number of ``Linear -> GELU`` blocks in the model.
        world_size: Number of pipeline stages to produce.

    Returns:
        A list of ``world_size`` ``PipelineStage`` objects covering all
        ``n_layers`` blocks in order. Layers are split as evenly as possible.
        When ``n_layers`` is not divisible by ``world_size``, the first
        ``n_layers % world_size`` stages each get one extra block, so no
        layer is dropped.

    Raises:
        ValueError: if ``world_size`` is not positive or ``n_layers`` is
            negative.
    """
    if world_size <= 0:
        raise ValueError(f"world_size must be positive, got {world_size}")
    if n_layers < 0:
        raise ValueError(f"n_layers must be non-negative, got {n_layers}")

    # Built in a single pass, in model order, so parameter initialisation
    # consumes the RNG exactly as a sequential construction would.
    all_layers = [
        nn.Sequential(
            nn.Linear(hidden_stages, hidden_stages),
            nn.GELU()
        )
        for _ in range(n_layers)
    ]

    base, extra = divmod(n_layers, world_size)

    stages = []
    start_idx = 0
    for idx in range(world_size):
        end_idx = start_idx + base + (1 if idx < extra else 0)
        stages.append(PipelineStage(all_layers[start_idx:end_idx]))
        start_idx = end_idx

    return stages

In [None]:
class GPipePipeline:
    """
    GPipe-style pipeline parallelism.

    Phases:
    1. Fill: Forward all micro-batches (GPU 0 → GPU N)
    2. Drain: Backward all micro-batches (GPU N → GPU 0), newest first

    The last stage seeds backward with a gradient of ones, i.e. it
    back-propagates ``output.sum()`` as a stand-in loss. Parameter gradients
    accumulate over all micro-batches of a step; ``zero_grad()`` and
    ``step()`` are the caller's responsibility.

    Communication uses blocking ``dist.send`` / ``dist.recv``. On every link,
    one side does all sends while the other does the matching receives, first
    for activations and then for gradients, so sends never cross.
    """
    def __init__(
        self,
        stage: nn.Module,
        rank: int,
        world_size: int,
        n_micro_batches: int,
        micro_batch_size: int,
        input_shape: tuple,
        dtype=torch.float32,
        device=None,
    ):
        """
        Args:
            stage: The model stage (layers) for this rank
            rank: This GPU's rank
            world_size: Total number of pipeline stages
            n_micro_batches: Number of micro-batches per training step
            micro_batch_size: Size of each micro-batch
            input_shape: Shape of input WITHOUT batch dimension; also used for
                the activations this stage receives (all stages share it)
            dtype: Data type for tensors
            device: Device for receive buffers. Defaults to the device of the
                stage's first parameter, or CPU if the stage has none. NCCL
                requires these buffers to live on the GPU.
        """
        self.stage = stage
        self.rank = rank
        self.world_size = world_size
        self.n_micro_batches = n_micro_batches
        self.micro_batch_size = micro_batch_size
        self.input_shape = input_shape
        self.dtype = dtype

        if device is None:
            first_param = next(stage.parameters(), None)
            device = first_param.device if first_param is not None else 'cpu'
        self.device = torch.device(device)

        self.next_rank = self.rank + 1 if self.rank < self.world_size - 1 else None
        self.prev_rank = self.rank - 1 if self.rank > 0 else None

    def send(self, tensor: torch.Tensor, send_rank: int):
        """Blocking send of ``tensor`` to ``send_rank``; no-op when it is None."""
        if send_rank is not None:
            # Only values cross ranks; the autograd graph stays local.
            dist.send(tensor.detach().contiguous(), dst=send_rank)

    def recv(self, tensor_buffer: torch.Tensor, recv_rank: int):
        """Receive into ``tensor_buffer`` from ``recv_rank``.

        Returns the filled buffer, or None when ``recv_rank`` is None (no
        neighbour in that direction).
        """
        if recv_rank is not None:
            dist.recv(tensor_buffer, src=recv_rank)
            return tensor_buffer

        return None

    def forward(self, micro_batch: torch.Tensor):
        """Run the stage on ``micro_batch``.

        Returns ``(output, saved_input)``. ``saved_input`` is a fresh leaf
        that requires grad, so its ``.grad`` can be sent upstream after
        backward.
        """
        # detach() gives a leaf even when micro_batch is a slice/view, where
        # flipping requires_grad in place can fail.
        saved_input = micro_batch.detach().requires_grad_(True)
        output = self.stage(saved_input)

        return output, saved_input

    def backward(self, grad: torch.Tensor, output: torch.Tensor, saved_input: torch.Tensor):
        """Backpropagate ``grad`` through ``output``; return d(loss)/d(input)."""
        output.backward(grad)
        return saved_input.grad

    def train_step(self, batch):
        """
        GPipe training step with two phases.

        Args:
            batch: Input batch (only used by rank 0, can be None for other
                ranks). Rank 0 slices ``n_micro_batches`` consecutive chunks of
                ``micro_batch_size`` rows from it.

        Raises:
            ValueError: if rank 0 is given ``batch=None``.
        """
        if self.rank == 0 and batch is None:
            raise ValueError("rank 0 requires the input batch")

        saved_inputs = []
        outputs = []

        # Phase 1 (fill): forward every micro-batch and pass activations on.
        for idx in range(self.n_micro_batches):
            if self.rank == 0:
                start_idx = idx * self.micro_batch_size
                micro_batch = batch[start_idx : start_idx + self.micro_batch_size]
            else:
                tensor_buffer = torch.zeros(
                    (self.micro_batch_size, *self.input_shape),
                    dtype=self.dtype,
                    device=self.device,
                )
                micro_batch = self.recv(tensor_buffer, self.prev_rank)

            output, saved_input = self.forward(micro_batch)

            outputs.append(output)
            saved_inputs.append(saved_input)

            # The next stage consumes this stage's *output*, not its input.
            self.send(output, self.next_rank)

        # Phase 2 (drain): backward in reverse order. Every rank uses the same
        # order, so sends and receives pair up.
        for idx in reversed(range(self.n_micro_batches)):
            output = outputs[idx]
            if self.rank == self.world_size - 1:
                grad_input = torch.ones_like(output, dtype=self.dtype)
            else:
                tensor_buffer = torch.zeros_like(output, dtype=self.dtype)
                grad_input = self.recv(tensor_buffer, self.next_rank)

            grad_output = self.backward(grad_input, output, saved_inputs[idx])

            self.send(grad_output, self.prev_rank)

In [14]:
class OneFOneBPipeline:
    """
    1F1B (one-forward-one-backward) pipeline parallelism.

    Phases per rank:
    1. Warmup: ``warmup_states`` forward-only micro-batches. Rank ``r`` runs
       ``world_size - r - 1`` of them, capped at ``n_micro_batches``.
    2. Steady: ``active_states`` iterations. Each runs one forward for a new
       micro-batch and one backward for the oldest in-flight micro-batch.
    3. Cooldown: backward for the micro-batches still in flight.

    A rank holds at most ``warmup_states + 1`` micro-batch activations at
    once, versus all ``n_micro_batches`` in GPipe.

    The last stage seeds backward with a gradient of ones, i.e. it
    back-propagates ``output.sum()`` as a stand-in loss. Parameter gradients
    accumulate over all micro-batches of a step; ``zero_grad()`` and
    ``step()`` are the caller's responsibility.

    Communication uses blocking ``dist.send`` / ``dist.recv``. In the steady
    phase the gradient from the next stage is received *before* the new
    activation is sent to it. The next stage sends that gradient before it
    posts the matching activation receive. Sending the activation first would
    leave both neighbours in ``send`` to each other, which can deadlock on
    backends whose sends wait for a matching receive (e.g. NCCL).
    """
    def __init__(
        self,
        stage: nn.Module,
        rank: int,
        world_size: int,
        n_micro_batches: int,
        micro_batch_size: int,
        input_shape: tuple,
        dtype=torch.float32,
        device=None,
    ):
        """
        Args:
            stage: The model stage (layers) for this rank
            rank: This GPU's rank
            world_size: Total number of pipeline stages
            n_micro_batches: Number of micro-batches per training step
            micro_batch_size: Size of each micro-batch
            input_shape: Shape of input WITHOUT batch dimension; also used for
                the activations this stage receives (all stages share it)
            dtype: Data type for tensors
            device: Device for receive buffers. Defaults to the device of the
                stage's first parameter, or CPU if the stage has none.
        """
        self.stage = stage
        self.rank = rank
        self.world_size = world_size
        self.n_micro_batches = n_micro_batches
        self.micro_batch_size = micro_batch_size
        self.input_shape = input_shape
        self.dtype = dtype

        if device is None:
            first_param = next(stage.parameters(), None)
            device = first_param.device if first_param is not None else 'cpu'
        self.device = torch.device(device)

        self.next_rank = self.rank + 1 if self.rank < self.world_size - 1 else None
        self.prev_rank = self.rank - 1 if self.rank > 0 else None

        # Cap warmup so a short step never forwards more micro-batches than
        # exist; active_states is then never negative.
        self.warmup_states = min(world_size - rank - 1, n_micro_batches)
        self.active_states = n_micro_batches - self.warmup_states

    def send(self, tensor: torch.Tensor, send_rank: int):
        """Blocking send of ``tensor`` to ``send_rank``; no-op when it is None."""
        if send_rank is not None:
            # Only values cross ranks; the autograd graph stays local.
            dist.send(tensor.detach().contiguous(), dst=send_rank)

    def recv(self, tensor_buffer: torch.Tensor, recv_rank: int):
        """Receive into ``tensor_buffer`` from ``recv_rank``.

        Returns the filled buffer, or None when ``recv_rank`` is None.
        """
        if recv_rank is not None:
            dist.recv(tensor_buffer, src=recv_rank)
            return tensor_buffer

        return None

    def forward(self, micro_batch: torch.Tensor):
        """Run the stage; return ``(output, saved_input)``.

        ``saved_input`` is a leaf that requires grad.
        """
        # detach() gives a leaf even when micro_batch is a slice/view.
        saved_input = micro_batch.detach().requires_grad_(True)
        output = self.stage(saved_input)

        return output, saved_input

    def backward(self, grad_input: torch.Tensor, output: torch.Tensor, saved_input: torch.Tensor):
        """Backpropagate ``grad_input`` through ``output``; return the input's grad."""
        output.backward(grad_input)
        return saved_input.grad

    def _input_for(self, idx, batch):
        """Micro-batch ``idx``: sliced from ``batch`` on rank 0, else received from the previous stage."""
        if self.rank == 0:
            start_idx = idx * self.micro_batch_size
            return batch[start_idx : start_idx + self.micro_batch_size]
        tensor_buffer = torch.zeros(
            (self.micro_batch_size, *self.input_shape),
            dtype=self.dtype,
            device=self.device,
        )
        return self.recv(tensor_buffer, self.prev_rank)

    def _output_grad(self, output):
        """Gradient w.r.t. ``output``: ones on the last stage, else received from the next stage."""
        if self.rank == self.world_size - 1:
            return torch.ones_like(output, dtype=self.dtype)
        tensor_buffer = torch.zeros_like(output, dtype=self.dtype)
        return self.recv(tensor_buffer, self.next_rank)

    def train_step(self, batch):
        """Run one 1F1B training step.

        Args:
            batch: Input batch (only used by rank 0, can be None for other
                ranks). Rank 0 slices ``n_micro_batches`` consecutive chunks of
                ``micro_batch_size`` rows from it.

        Raises:
            ValueError: if rank 0 is given ``batch=None``.
        """
        if self.rank == 0 and batch is None:
            raise ValueError("rank 0 requires the input batch")

        # FIFO of in-flight micro-batches; the oldest is backpropagated first.
        saved_inputs = []
        outputs = []

        # Phase 1 (warmup): forward only.
        for idx in range(self.warmup_states):
            output, saved_input = self.forward(self._input_for(idx, batch))
            outputs.append(output)
            saved_inputs.append(saved_input)
            self.send(output, self.next_rank)

        # Phase 2 (steady): one forward (new micro-batch), one backward (oldest).
        for step in range(self.active_states):
            new_output, new_input = self.forward(
                self._input_for(self.warmup_states + step, batch)
            )
            outputs.append(new_output)
            saved_inputs.append(new_input)

            old_output = outputs.pop(0)
            old_input = saved_inputs.pop(0)

            # Receive the gradient first, then send the new activation.
            # See the class docstring for why this order avoids a deadlock.
            grad_input = self._output_grad(old_output)
            self.send(new_output, self.next_rank)

            grad_output = self.backward(grad_input, old_output, old_input)
            self.send(grad_output, self.prev_rank)

        # Phase 3 (cooldown): backward for everything still in flight.
        while outputs:
            old_output = outputs.pop(0)
            old_input = saved_inputs.pop(0)

            grad_input = self._output_grad(old_output)
            grad_output = self.backward(grad_input, old_output, old_input)

            self.send(grad_output, self.prev_rank)

In [15]:
def train_gpipe():
    """Run a short GPipe training demo (10 steps) on this rank.

    Launch one process per GPU with ``RANK`` and ``WORLD_SIZE`` set in the
    environment (e.g. via a launcher such as ``torchrun``). The CUDA device
    index is the global rank, which assumes a single node. Rank 0 generates
    random input batches; the other ranks receive activations from their
    predecessor.
    """
    # Launchers export upper-case names, and values are strings.
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # setup() reads RANK / WORLD_SIZE from the environment itself.
    setup()
    try:
        hidden_size = 512
        num_layers = 16
        batch_size = 32
        num_micro_batches = 8
        micro_batch_size = batch_size // num_micro_batches  # = 4

        stages = create_stages(hidden_size, num_layers, world_size)
        stage = stages[rank].to(rank)

        pipeline = GPipePipeline(
            stage=stage,
            rank=rank,
            world_size=world_size,
            n_micro_batches=num_micro_batches,
            micro_batch_size=micro_batch_size,
            input_shape=(hidden_size,),
            dtype=torch.float32
        )

        optimizer = torch.optim.Adam(stage.parameters(), lr=1e-3)

        for _ in range(10):
            # Only the first stage consumes raw input.
            if rank == 0:
                batch = torch.randn(batch_size, hidden_size, device=rank)
            else:
                batch = None

            optimizer.zero_grad()
            pipeline.train_step(batch)
            optimizer.step()
    finally:
        # Tear down the process group even if a step raised.
        cleanup()
    