# In[4]:
import torch
import torch.nn as nn
import torch.distributed as dist
import os

# In[5]:
def setup(rank, world_size, master_addr='localhost', master_port='12359'):
    """Initialise the default NCCL process group for this rank.

    Args:
        rank: Global rank of this process. It is also used as the CUDA device
            index, matching the rest of this file's one-process-per-GPU,
            single-node layout.
        world_size: Total number of processes in the group.
        master_addr: Rendezvous address, used only when MASTER_ADDR is unset.
        master_port: Rendezvous port, used only when MASTER_PORT is unset.
    """
    # Keep values already exported by a launcher (e.g. torchrun) instead of
    # clobbering them; fall back to single-node defaults otherwise.
    os.environ.setdefault('MASTER_ADDR', master_addr)
    os.environ.setdefault('MASTER_PORT', str(master_port))
    # Bind this process to its own GPU before NCCL initialises, so NCCL work
    # that is not tied to a specific tensor's device lands on the right GPU.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group if one is currently initialised.

    This is a no-op when no group exists, so it is safe to call from a
    ``finally`` block or more than once.
    """
    if dist.is_initialized():
        dist.destroy_process_group()

# In[6]:
def create_model_stages(hidden_size, num_layers, world_size):
    """Split model into pipeline stages.

    Builds ``num_layers`` blocks of ``Linear(hidden_size, hidden_size) + ReLU``
    and assigns them contiguously to ``world_size`` stages. When ``num_layers``
    is not a multiple of ``world_size``, the first ``num_layers % world_size``
    stages each receive one extra layer, so no layer is dropped.

    Args:
        hidden_size: Input and output width of every Linear layer.
        num_layers: Total number of Linear+ReLU blocks in the model.
        world_size: Number of pipeline stages to produce.

    Returns:
        A list of ``world_size`` PipelineStage modules; entry ``r`` is the
        stage for rank ``r``.

    Raises:
        ValueError: If ``world_size`` is not positive.
    """
    if world_size <= 0:
        raise ValueError(f"world_size must be positive, got {world_size}")

    # All layers are built before partitioning, so parameter initialisation
    # draws from the RNG in the same order however the layers are split.
    all_layers = [
        nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU())
        for _ in range(num_layers)
    ]

    base, extra = divmod(num_layers, world_size)
    stages = []
    start = 0
    for rank in range(world_size):
        end = start + base + (1 if rank < extra else 0)
        stages.append(PipelineStage(all_layers[start:end]))
        start = end

    return stages


# In[7]:
class PipelineStage(nn.Module):
    """A contiguous slice of the model, run by one pipeline rank.

    Args:
        layers: Iterable of modules applied in order. An empty iterable makes
            the stage an identity function.
    """

    def __init__(self, layers):
        super().__init__()
        # A ModuleList registers the layers, so their parameters appear in
        # .parameters() and move with .to(device).
        self.layers = nn.ModuleList(layers)

    @property
    def layer(self):
        """Backward-compatible alias for ``layers``."""
        return self.layers

    def forward(self, x):
        """Apply every layer of this stage to ``x`` in sequence."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

# In[8]:
class GPipePipeline:
    """
    GPipe-style pipeline parallelism.

    Phases:
    1. Fill: Forward all micro-batches (GPU 0 → GPU N)
    2. Drain: Backward all micro-batches (GPU N → GPU 0)

    A micro-batch's activations stay alive until the drain phase reaches it.
    Gradients accumulate into the stage's ``.grad`` fields; the caller is
    responsible for zeroing them and stepping the optimizer.
    """

    def __init__(self, stage, rank, world_size, micro_batches, micro_batch_size,
                 input_size=None, dtype=torch.float32, *, input_shape=None, device=None):
        """
        Args:
            stage: The model stage (layers) for this rank
            rank: This GPU's rank
            world_size: Total number of pipeline stages
            micro_batches: Number of micro-batches per training step
            micro_batch_size: Size of each micro-batch
            input_size: Shape of the activation received from the previous
                stage, WITHOUT the batch dimension (a tuple, or an int for 1-D)
            dtype: Data type for received tensors
            input_shape: Keyword alias for ``input_size``
            device: Device for received tensors and rank-0 micro-batches;
                defaults to ``cuda:{rank}``

        Raises:
            TypeError: If neither or both of ``input_size`` / ``input_shape``
                are given.
        """
        if (input_size is None) == (input_shape is None):
            raise TypeError("pass exactly one of input_size or input_shape")
        shape = input_size if input_size is not None else input_shape
        if isinstance(shape, int):
            shape = (shape,)

        self.stage = stage
        self.rank = rank
        self.world_size = world_size
        self.micro_batches = micro_batches
        self.micro_batch_size = micro_batch_size
        self.input_size = tuple(shape)
        # Same value under the attribute name OneFOneBPipeline uses.
        self.input_shape = self.input_size
        self.dtype = dtype
        self.device = torch.device(device if device is not None else f'cuda:{rank}')

        self.prev_rank = rank - 1 if rank > 0 else None
        self.next_rank = rank + 1 if rank < world_size - 1 else None

    def send_forward(self, tensor):
        """Send an activation to ``next_rank``; no-op on the last stage."""
        if self.next_rank is not None:
            dist.send(tensor.detach().contiguous(), dst=self.next_rank)

    def recv_forward(self):
        """Receive an activation from ``prev_rank``.

        Returns:
            A ``(micro_batch_size, *input_size)`` tensor, or None on the first stage.
        """
        if self.prev_rank is None:
            return None
        shape = (self.micro_batch_size, *self.input_size)
        tensor = torch.zeros(shape, dtype=self.dtype, device=self.device)
        dist.recv(tensor, src=self.prev_rank)
        return tensor

    def send_backward(self, tensor):
        """Send an input gradient to ``prev_rank``; no-op on the first stage."""
        if self.prev_rank is not None:
            dist.send(tensor.detach().contiguous(), dst=self.prev_rank)

    def recv_backward(self, output_shape):
        """Receive the gradient of this stage's output from ``next_rank``.

        Returns:
            A tensor of ``output_shape``, or None on the last stage.
        """
        if self.next_rank is None:
            return None
        tensor = torch.zeros(output_shape, dtype=self.dtype, device=self.device)
        dist.recv(tensor, src=self.next_rank)
        return tensor

    def forward_step(self, input_tensor):
        """
        Forward through this stage.

        Returns:
            output: Output activation
            input_tensor: Saved input for backward (a detached leaf whose
                ``.grad`` is filled by ``backward_step``)
        """
        # A detached alias becomes a fresh autograd leaf, so the caller's
        # tensor is not mutated and this stage's graph starts here.
        input_tensor = input_tensor.detach().requires_grad_(True)
        output = self.stage(input_tensor)
        return output, input_tensor

    def backward_step(self, output, grad_output, input):
        """
        Backward through this stage.

        Returns:
            grad_input: Gradient w.r.t. input
        """
        output.backward(grad_output)
        return input.grad

    def _get_micro_batch(self, batch, mb_idx):
        """Slice micro-batch ``mb_idx`` from ``batch`` and move it to ``self.device``."""
        if batch is None:
            raise ValueError("rank 0 requires an input batch")
        start = mb_idx * self.micro_batch_size
        end = start + self.micro_batch_size
        if end > len(batch):
            raise ValueError(
                f"batch has {len(batch)} rows but micro-batch {mb_idx} needs rows [{start}, {end})"
            )
        return batch[start:end].to(self.device)

    def train_step(self, batch):
        """
        GPipe training step with two phases.

        Args:
            batch: Input batch (only used by rank 0, can be None for other ranks)

        Raises:
            ValueError: On rank 0, if ``batch`` is None or has fewer than
                ``micro_batches * micro_batch_size`` rows.
        """
        input_tensors = []
        output_tensors = []

        # Phase 1 (fill): forward every micro-batch and stream it downstream.
        for mb_idx in range(self.micro_batches):
            if self.rank == 0:
                micro_batch = self._get_micro_batch(batch, mb_idx)
            else:
                micro_batch = self.recv_forward()

            output, saved_input = self.forward_step(micro_batch)
            input_tensors.append(saved_input)
            output_tensors.append(output)
            self.send_forward(output)

        # Phase 2 (drain): backward from the last micro-batch to the first.
        # Every rank uses this order, so recv_backward matches the neighbour's
        # send_backward. Popping releases each graph once it has been used.
        while output_tensors:
            output = output_tensors.pop()
            saved_input = input_tensors.pop()

            if self.rank == self.world_size - 1:
                # Last stage: seeding with ones back-propagates d(output.sum()).
                grad_output = torch.ones_like(output)
            else:
                grad_output = self.recv_backward(output.shape)

            grad_input = self.backward_step(output, grad_output, saved_input)
            if grad_input is not None:
                self.send_backward(grad_input)

# In[9]:
def train_gpipe():
    """Run 10 GPipe training steps on random data as one pipeline rank.

    Reads ``RANK`` and ``WORLD_SIZE`` from the environment, as exported by
    launchers such as torchrun. The lower-case ``rank`` / ``world_size``
    names are accepted as a fallback. The stage is placed on device index
    ``RANK``, which assumes a single node with one process per GPU.

    Raises:
        KeyError: If neither spelling of a required variable is set.
    """
    def env_int(name):
        # Environment values are strings; rank must be an int for indexing
        # and device placement.
        value = os.environ.get(name, os.environ.get(name.lower()))
        if value is None:
            raise KeyError(f"environment variable {name} is not set")
        return int(value)

    rank = env_int('RANK')
    world_size = env_int('WORLD_SIZE')

    setup(rank, world_size)
    try:
        hidden_size = 512
        num_layers = 16
        batch_size = 32
        num_micro_batches = 8
        micro_batch_size = batch_size // num_micro_batches  # = 4

        stages = create_model_stages(hidden_size, num_layers, world_size)
        stage = stages[rank].to(rank)

        pipeline = GPipePipeline(
            stage=stage,
            rank=rank,
            world_size=world_size,
            micro_batches=num_micro_batches,
            micro_batch_size=micro_batch_size,
            input_size=(hidden_size,),
            dtype=torch.float32,
        )

        optimizer = torch.optim.Adam(stage.parameters(), lr=1e-3)

        for _ in range(10):
            # Only the first stage consumes raw data; later ranks receive activations.
            batch = torch.randn(batch_size, hidden_size, device=rank) if rank == 0 else None

            optimizer.zero_grad()
            pipeline.train_step(batch)
            optimizer.step()
    finally:
        cleanup()
    

# In[11]:
class OneFOneBPipeline(nn.Module):
    """1F1B (one-forward-one-backward) pipeline-parallel schedule.

    ``train_step`` runs three phases on every rank:
    1. Warmup: forward ``warmup_states`` micro-batches with no backward.
    2. Steady state: ``steady_states`` iterations of one forward followed by
       one backward of the oldest micro-batch still in flight.
    3. Cooldown: backward the ``warmup_states`` micro-batches left in flight.

    Each steady-state iteration stores one activation and releases one, so at
    most ``warmup_states + 1`` activations are held at once. Gradients
    accumulate into the stage's ``.grad`` fields; the caller is responsible
    for zeroing them and stepping the optimizer.
    """

    def __init__(self, stage, rank, world_size, micro_batches, micro_batch_size, input_shape,
                 dtype=torch.float32, *, device=None):
        """
        Args:
            stage: Module holding this rank's layers.
            rank: This process's pipeline rank (0 is the first stage).
            world_size: Total number of pipeline stages.
            micro_batches: Number of micro-batches per ``train_step``.
            micro_batch_size: Rows per micro-batch.
            input_shape: Shape of the activation received from the previous
                stage, without the batch dimension.
            dtype: Data type of received activations and gradients.
            device: Device for received tensors and rank-0 micro-batches;
                defaults to ``cuda:{rank}``.
        """
        super().__init__()

        self.stage = stage
        self.rank = rank
        self.world_size = world_size
        self.micro_batches = micro_batches
        self.micro_batch_size = micro_batch_size
        self.input_shape = input_shape
        self.dtype = dtype
        self.device = torch.device(device if device is not None else f'cuda:{rank}')

        self.prev_rank = rank - 1 if rank > 0 else None
        self.next_rank = rank + 1 if rank < world_size - 1 else None

        # Rank r forwards world_size - r - 1 micro-batches before its first
        # backward, but it can never forward more micro-batches than exist.
        self.warmup_states = max(0, min(world_size - rank - 1, micro_batches))
        self.steady_states = micro_batches - self.warmup_states

    def send_forward(self, tensor):
        """Send an activation to ``next_rank``; no-op on the last stage."""
        if self.next_rank is not None:
            dist.send(tensor.detach().contiguous(), dst=self.next_rank)

    def recv_forward(self, output_shape=None):
        """Receive an activation from ``prev_rank``.

        Args:
            output_shape: Optional full shape of the incoming tensor; defaults
                to ``(micro_batch_size, *input_shape)``.

        Returns:
            The received tensor, or None on the first stage.
        """
        if self.prev_rank is None:
            return None
        if output_shape is None:
            output_shape = (self.micro_batch_size, *self.input_shape)
        tensor = torch.zeros(output_shape, dtype=self.dtype, device=self.device)
        dist.recv(tensor, src=self.prev_rank)
        return tensor

    def send_backward(self, tensor):
        """Send an input gradient to ``prev_rank``; no-op on the first stage."""
        if self.prev_rank is not None:
            dist.send(tensor.detach().contiguous(), dst=self.prev_rank)

    def recv_backward(self, output_shape):
        """Receive the gradient of this stage's output from ``next_rank``.

        Returns:
            A tensor of ``output_shape``, or None on the last stage.
        """
        if self.next_rank is None:
            return None
        tensor = torch.zeros(output_shape, dtype=self.dtype, device=self.device)
        dist.recv(tensor, src=self.next_rank)
        return tensor

    def forward_step(self, micro_batch):
        """Run the stage on ``micro_batch``.

        Returns:
            ``(output, saved_input)``. ``saved_input`` is a detached leaf whose
            ``.grad`` is filled by ``backward_step``.
        """
        # A detached alias becomes a fresh autograd leaf, so the caller's
        # tensor is not mutated and this stage's graph starts here.
        micro_batch = micro_batch.detach().requires_grad_(True)
        output = self.stage(micro_batch)
        return output, micro_batch

    def backward_step(self, output, grad_output, input_tensor):
        """Back-propagate ``grad_output`` through ``output``; return the input gradient."""
        output.backward(grad_output)
        return input_tensor.grad

    def _get_micro_batch(self, batch, mb_idx):
        """Slice micro-batch ``mb_idx`` from ``batch`` and move it to ``self.device``."""
        if batch is None:
            raise ValueError("rank 0 requires an input batch")
        start = mb_idx * self.micro_batch_size
        end = start + self.micro_batch_size
        if end > len(batch):
            raise ValueError(
                f"batch has {len(batch)} rows but micro-batch {mb_idx} needs rows [{start}, {end})"
            )
        return batch[start:end].to(self.device)

    def _forward_one(self, mb_idx, batch, input_tensors, output_tensors):
        """Fetch micro-batch ``mb_idx``, forward it, record it, and pass it downstream."""
        if self.rank == 0:
            micro_batch = self._get_micro_batch(batch, mb_idx)
        else:
            micro_batch = self.recv_forward()

        output, saved_input = self.forward_step(micro_batch)
        input_tensors.append(saved_input)
        output_tensors.append(output)
        self.send_forward(output)

    def _backward_oldest(self, input_tensors, output_tensors):
        """Backward the oldest in-flight micro-batch and send its input gradient upstream."""
        output = output_tensors.pop(0)
        saved_input = input_tensors.pop(0)

        if self.rank == self.world_size - 1:
            # Last stage: seeding with ones back-propagates d(output.sum()).
            grad_output = torch.ones_like(output)
        else:
            grad_output = self.recv_backward(output.shape)

        grad_input = self.backward_step(output, grad_output, saved_input)
        if grad_input is not None:
            self.send_backward(grad_input)

    def train_step(self, batch):
        """
        1F1B training step with three phases.

        Args:
            batch: Input batch (only used by rank 0, can be None for other ranks)

        Raises:
            ValueError: On rank 0, if ``batch`` is None or has fewer than
                ``micro_batches * micro_batch_size`` rows.
        """
        input_tensors = []
        output_tensors = []

        # Phase 1 (warmup): forwards only, filling the pipeline downstream.
        for mb_idx in range(self.warmup_states):
            self._forward_one(mb_idx, batch, input_tensors, output_tensors)

        # Phase 2 (steady state): micro-batch indices continue after warmup.
        for i in range(self.steady_states):
            self._forward_one(self.warmup_states + i, batch, input_tensors, output_tensors)
            # NOTE(review): the send_forward just issued to next_rank is followed
            # by a recv_backward from that same peer, while next_rank sends its
            # gradient before it posts the recv for this activation. Blocking
            # p2p in opposite directions can deadlock once messages exceed the
            # backend's internal buffering. The usual remedy is posting both
            # ops together via dist.batch_isend_irecv (used consistently on all
            # ranks). Confirm for the target backend and sizes.
            self._backward_oldest(input_tensors, output_tensors)

        # Phase 3 (cooldown): drain the micro-batches still in flight.
        for _ in range(self.warmup_states):
            self._backward_oldest(input_tensors, output_tensors)