In [5]:
import torch
import torch.nn as nn
import torch.distributed as dist 
import os

In [6]:
def setup_rank(rank, world_size):
    """Join the default process group and bind this process to its GPU.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        world_size: Total number of processes in the group.
    """
    # Only fall back to the single-machine defaults when the launcher has not
    # already exported a rendezvous address; overwriting would clobber it.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    dist.init_process_group(
        backend='nccl',
        rank=rank,
        world_size=world_size
    )

    torch.cuda.set_device(rank)

In [7]:
def cleanup_rank():
    """Tear down the default process group if one is currently initialised.

    Safe to call when setup never ran or failed part-way: in that case there
    is nothing to destroy and the call is a no-op.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

In [10]:
class PipelineStage(nn.Module):
    """A contiguous slice of the model that runs on a single pipeline rank.

    Args:
        layers: Iterable of ``nn.Module`` objects applied in order.
    """

    def __init__(self, layers):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise registering ``self.layers`` raises.
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the result."""
        out = x
        for layer in self.layers:
            out = layer(out)

        return out

In [11]:
def create_model_stages(hidden_size, n_layers, world_size):
    """Build ``n_layers`` Linear+GELU blocks and split them into pipeline stages.

    Args:
        hidden_size: Input and output width of every ``nn.Linear``.
        n_layers: Total number of Linear+GELU blocks to create.
        world_size: Number of stages to split the blocks into.

    Returns:
        A list of ``world_size`` ``PipelineStage`` objects. When ``n_layers``
        is not a multiple of ``world_size`` the leftover blocks are given, one
        each, to the first stages so that no block is dropped.
    """
    base, extra = divmod(n_layers, world_size)

    all_layers = [
        nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU()
        )
        for _ in range(n_layers)
    ]

    stages = []
    start_idx = 0
    for i in range(world_size):
        # The first ``extra`` stages absorb the remainder of the division.
        end_idx = start_idx + base + (1 if i < extra else 0)
        stages.append(PipelineStage(all_layers[start_idx:end_idx]))
        start_idx = end_idx

    return stages

In [None]:
class GPipePipeline:
    """
    GPipe-style pipeline parallelism.

    Phases:
    1. Fill: Forward all micro-batches (GPU 0 → GPU N)
    2. Drain: Backward all micro-batches (GPU N → GPU 0)
    """
    def __init__(self, stage, rank, world_size, n_micro_batches, micro_batch_size, input_size, dtype=torch.float32):
        """
        Args:
            stage: The model stage (layers) for this rank
            rank: This GPU's rank
            world_size: Total number of pipeline stages
            n_micro_batches: Number of micro-batches per training step
            micro_batch_size: Size of each micro-batch
            input_size: Shape of this stage's input WITHOUT batch dimension
            dtype: Data type of the buffers allocated for received tensors
        """
        self.stage = stage
        self.rank = rank
        self.world_size = world_size
        self.n_micro_batches = n_micro_batches
        self.micro_batch_size = micro_batch_size
        self.input_size = input_size
        self.dtype = dtype

        self.prev_rank = self.rank - 1 if self.rank > 0 else None
        self.next_rank = self.rank + 1 if self.rank < self.world_size - 1 else None

    def send(self, activation):
        """Send an activation to the next stage (no-op on the last stage)."""
        if self.next_rank is not None:
            # Only the values cross the process boundary; detach so the
            # autograd graph stays local to this rank.
            dist.send(activation.detach().contiguous(), dst=self.next_rank)

    def recv(self, shape):
        """Receive an activation from the previous stage.

        Returns None on the first stage, which has no previous stage.
        """
        if self.prev_rank is not None:
            tensor = torch.zeros(shape, dtype=self.dtype, device=self.rank)
            dist.recv(tensor, src=self.prev_rank)
            return tensor

        return None

    def send_grad(self, grad):
        """Send an input gradient to the previous stage (no-op on the first stage)."""
        if self.prev_rank is not None:
            dist.send(grad.contiguous(), dst=self.prev_rank)

    def recv_grad(self, shape):
        """Receive an output gradient from the next stage.

        Returns None on the last stage, which has no next stage.
        """
        if self.next_rank is not None:
            tensor = torch.zeros(shape, dtype=self.dtype, device=self.rank)
            dist.recv(tensor, src=self.next_rank)
            return tensor

        return None

    def forward_step(self, micro_batch):
        """Run the stage on one micro-batch.

        Returns:
            ``(output, stage_input)`` where ``stage_input`` is a detached leaf
            copy of ``micro_batch`` that tracks gradients; its ``.grad`` is
            filled in by ``backward_step``.
        """
        # Detach first so the caller's tensor is not mutated and the input is
        # a leaf whose .grad gets populated.
        micro_batch = micro_batch.detach().requires_grad_(True)
        batch_output = self.stage(micro_batch)

        return batch_output, micro_batch

    def backward_step(self, output, grad_output, input):
        """Back-propagate ``grad_output`` through ``output``; return ``input.grad``."""
        output.backward(grad_output)
        return input.grad

    def train_step(self, batch):
        """
        GPipe training step with two phases.

        Parameter gradients of the stage are accumulated in place; nothing is
        returned. The last stage seeds the backward pass with a gradient of
        ones, i.e. it differentiates the sum of its outputs.

        Args:
            batch: Input batch (only used by rank 0, can be None for other ranks)
        """

        saved_inputs = []
        outputs = []

        # Phase 1 (fill): forward every micro-batch and pass it downstream.
        for i in range(self.n_micro_batches):
            if self.rank == 0:
                start_idx = i * self.micro_batch_size
                end_idx = start_idx + self.micro_batch_size

                micro_batch = batch[start_idx : end_idx]
            else:
                micro_batch = self.recv((self.micro_batch_size, *self.input_size))

            output, stage_input = self.forward_step(micro_batch)
            saved_inputs.append(stage_input)
            outputs.append(output)

            self.send(output)

        # Phase 2 (drain): backward in reverse order, passing gradients upstream.
        for _ in range(self.n_micro_batches):
            # Popping from the end walks the micro-batches last-to-first and
            # drops the reference to each graph as soon as it has been used.
            saved_input = saved_inputs.pop()
            output = outputs.pop()

            if self.next_rank is None:
                # Match the output's own dtype/device rather than assuming them.
                grad_output = torch.ones_like(output)
            else:
                grad_output = self.recv_grad(output.shape)

            grad_input = self.backward_step(output, grad_output, saved_input)
            if grad_input is not None:
                self.send_grad(grad_input)

In [None]:
class OneFOneBPipeline:
    """
    1F1B (one-forward-one-backward) pipeline parallelism.

    Phases:
    1. Warm-up: forward-only micro-batches
    2. Steady state: one forward followed by one backward per iteration
    3. Cool-down: backward for the micro-batches still in flight
    """
    def __init__(self, stage, rank, world_size, n_micro_batches, micro_batch_size, input_size, dtype=torch.float32):
        """
        Args:
            stage: The model stage (layers) for this rank
            rank: This GPU's rank
            world_size: Total number of pipeline stages
            n_micro_batches: Number of micro-batches per training step
            micro_batch_size: Size of each micro-batch
            input_size: Shape of this stage's input WITHOUT batch dimension
            dtype: Data type of the buffers allocated for received tensors
        """
        self.stage = stage
        self.rank = rank
        self.world_size = world_size
        self.n_micro_batches = n_micro_batches
        self.micro_batch_size = micro_batch_size
        self.input_size = input_size
        self.dtype = dtype

        self.prev_rank = self.rank - 1 if self.rank > 0 else None
        self.next_rank = self.rank + 1 if self.rank < self.world_size - 1 else None

        # Never warm up with more micro-batches than exist; otherwise the
        # steady-state count would go negative.
        self.warmup_states = min(world_size - rank - 1, n_micro_batches)
        self.steady_states = n_micro_batches - self.warmup_states

    def send(self, activations):
        """Send an activation to the next stage (no-op on the last stage)."""
        if self.next_rank is not None:
            # Only the values cross the process boundary; detach so the
            # autograd graph stays local to this rank.
            dist.send(activations.detach().contiguous(), dst=self.next_rank)

    def recv(self, shape):
        """Receive an activation from the previous stage.

        Returns None on the first stage, which has no previous stage.
        """
        if self.prev_rank is not None:
            tensor = torch.zeros(shape, dtype=self.dtype, device=self.rank)
            dist.recv(tensor, src=self.prev_rank)
            return tensor

        return None

    def send_grad(self, grad):
        """Send an input gradient to the previous stage (no-op on the first stage)."""
        if self.prev_rank is not None:
            dist.send(grad.contiguous(), dst=self.prev_rank)

    def recv_grad(self, shape):
        """Receive an output gradient from the next stage.

        Returns None on the last stage, which has no next stage.
        """
        if self.next_rank is not None:
            tensor = torch.zeros(shape, dtype=self.dtype, device=self.rank)
            dist.recv(tensor, src=self.next_rank)
            return tensor

        return None

    def forward_step(self, micro_batch):
        """Run the stage on one micro-batch.

        Returns:
            ``(output, stage_input)`` where ``stage_input`` is a detached leaf
            copy of ``micro_batch`` that tracks gradients; its ``.grad`` is
            filled in by ``backward_step``.
        """
        micro_batch = micro_batch.detach().requires_grad_(True)
        output = self.stage(micro_batch)
        return output, micro_batch

    def backward_step(self, output, grad_output, input):
        """Back-propagate ``grad_output`` through ``output``; return ``input.grad``."""
        output.backward(grad_output)
        return input.grad

    def _next_input(self, batch, micro_batch_idx):
        """Return the stage input for micro-batch ``micro_batch_idx``.

        Rank 0 slices it out of ``batch``; every other rank receives it from
        the previous stage.
        """
        if self.rank == 0:
            start_idx = micro_batch_idx * self.micro_batch_size
            end_idx = start_idx + self.micro_batch_size
            return batch[start_idx : end_idx]

        return self.recv((self.micro_batch_size, *self.input_size))

    def _backward_oldest(self, saved_inputs, outputs):
        """Run backward for the oldest in-flight micro-batch; return its input grad."""
        # The queues hold at most warmup_states + 1 entries, so pop(0) is cheap.
        output = outputs.pop(0)
        saved_input = saved_inputs.pop(0)

        if self.next_rank is None:
            # Match the output's own dtype/device rather than assuming them.
            grad_output = torch.ones_like(output)
        else:
            grad_output = self.recv_grad(output.shape)

        return self.backward_step(output, grad_output, saved_input)

    def train_step(self, batch):
        """
        1F1B training step with three phases.

        Parameter gradients of the stage are accumulated in place; nothing is
        returned. The last stage seeds the backward pass with a gradient of
        ones, i.e. it differentiates the sum of its outputs.

        Args:
            batch: Input batch (only used by rank 0, can be None for other ranks)
        """

        saved_inputs = []
        outputs = []

        # Phase 1 (warm-up): forward only, filling the pipeline.
        for i in range(self.warmup_states):
            micro_batch = self._next_input(batch, i)

            output, stage_input = self.forward_step(micro_batch)
            self.send(output)

            saved_inputs.append(stage_input)
            outputs.append(output)

        # Phase 2 (steady state): one forward, then one backward.
        micro_batch = None
        if self.steady_states > 0:
            micro_batch = self._next_input(batch, self.warmup_states)

        for i in range(self.steady_states):
            output, stage_input = self.forward_step(micro_batch)
            self.send(output)

            saved_inputs.append(stage_input)
            outputs.append(output)

            grad_input = self._backward_oldest(saved_inputs, outputs)

            # Fetch the next input BEFORE sending the gradient upstream. The
            # previous rank sends its activation and then waits for a
            # gradient, so this rank must receive first and send second;
            # with blocking calls both sides would otherwise be sending to
            # each other at the same time.
            if i + 1 < self.steady_states:
                micro_batch = self._next_input(batch, self.warmup_states + i + 1)

            if grad_input is not None:
                self.send_grad(grad_input)

        # Phase 3 (cool-down): backward for what is still in flight.
        for _ in range(self.warmup_states):
            grad_input = self._backward_oldest(saved_inputs, outputs)
            if grad_input is not None:
                self.send_grad(grad_input)

In [None]:
def _env_int(*names):
    """Return the first of the environment variables ``names`` that is set, as an int."""
    for name in names:
        if name in os.environ:
            return int(os.environ[name])
    raise KeyError(names[-1])


def train_gpipe():
    """Run ten GPipe training steps on random data for this process's stage.

    The rank and world size are read from the environment: ``RANK`` /
    ``WORLD_SIZE`` are tried first, then the lowercase ``rank`` /
    ``world_size``. A ``KeyError`` is raised if neither spelling is set.
    """
    # Environment values are strings; they are used below as list indices,
    # device indices and loop bounds, so convert them to int.
    rank = _env_int('RANK', 'rank')
    world_size = _env_int('WORLD_SIZE', 'world_size')

    setup_rank(rank, world_size)

    try:
        hidden_size = 512
        num_layers = 16
        batch_size = 32
        num_micro_batches = 8
        micro_batch_size = batch_size // num_micro_batches  # = 4

        stages = create_model_stages(hidden_size, num_layers, world_size)
        stage = stages[rank].to(rank)

        pipeline = GPipePipeline(
            stage=stage,
            rank=rank,
            world_size=world_size,
            n_micro_batches=num_micro_batches,
            micro_batch_size=micro_batch_size,
            input_size=(hidden_size,),
            dtype=torch.float32
        )

        optimizer = torch.optim.Adam(stage.parameters(), lr=1e-3)

        for _ in range(10):
            # Only the first stage consumes real input; the others receive
            # activations from their predecessor.
            if rank == 0:
                batch = torch.randn(batch_size, hidden_size, device=rank)
            else:
                batch = None

            optimizer.zero_grad()

            pipeline.train_step(batch)

            optimizer.step()
    finally:
        # Always release the process group, even if a step raised.
        cleanup_rank()
    