### 1. Sequence Parallelism

Goal: Split work along the sequence dimension (tokens) so that each GPU holds a different slice of tokens from the same batch.

#### How it works
	•	The input sequence (e.g., 2048 tokens) is divided across GPUs. With 2 GPUs:
		•	GPU 0 processes tokens 0–1023
		•	GPU 1 processes tokens 1024–2047
	•	Each GPU computes its part of attention, feed-forward, layer norm, etc. Per-token operations such as layer norm and dropout need no communication.
	•	Operations that need the whole sequence require cross-GPU operations:
		•	All-gather of hidden states
		•	Reduce-scatter for attention outputs
		•	Sharded softmax, sharded dropout, etc., where the implementation supports them
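
As a minimal sketch of the split described above, the helper below computes which token positions a given GPU owns, assuming an even, contiguous split. The name `token_range` is hypothetical and does not come from any library.

```python
def token_range(seq_len: int, num_gpus: int, rank: int) -> range:
    """Contiguous slice of token positions owned by `rank` (even split assumed)."""
    if seq_len % num_gpus != 0:
        raise ValueError("seq_len must be divisible by num_gpus")
    per_gpu = seq_len // num_gpus
    start = rank * per_gpu
    return range(start, start + per_gpu)


for rank in range(2):
    owned = token_range(2048, 2, rank)
    print(f"GPU {rank}: tokens {owned.start}-{owned.stop - 1}")
```

With 2048 tokens and 2 GPUs this reproduces the 0–1023 / 1024–2047 split from the list above.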

#### Why it helps
	•	Reduces activation memory per GPU roughly in proportion to the number of GPUs sharing the sequence (often the biggest memory saver for long sequences).
	•	Especially useful for long-context LLMs.
	•	Compatible with tensor parallelism.

Analogy:
Each GPU handles “its slice” of the sequence like chapters in a book, but they sync before/after attention.


### 2. Pipeline Parallelism

Goal: Split the model layers across GPUs.

Example: A 48-layer transformer across 4 GPUs:

	•	GPU 0 → Layers 1–12
	•	GPU 1 → Layers 13–24
	•	GPU 2 → Layers 25–36
	•	GPU 3 → Layers 37–48
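
The assignment above can be computed rather than written by hand. The sketch below assumes the layer count divides evenly across stages; `partition_layers` is a hypothetical helper name, and real frameworks may balance stages by cost rather than by count.

```python
def partition_layers(num_layers: int, num_stages: int) -> list:
    """Return 1-based inclusive (first, last) layer ranges, one per stage."""
    if num_layers % num_stages != 0:
        raise ValueError("num_layers must be divisible by num_stages")
    per_stage = num_layers // num_stages
    return [(s * per_stage + 1, (s + 1) * per_stage) for s in range(num_stages)]


for gpu, (first, last) in enumerate(partition_layers(48, 4)):
    print(f"GPU {gpu} -> Layers {first}-{last}")
```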

#### How it works

Forward pass flows like a pipeline:
Batch → GPU0 → GPU1 → GPU2 → GPU3 → loss

Backward pass flows in reverse.

To avoid idle GPUs, we use microbatching:

Microbatch 1 in GPU0  
Microbatch 2 in GPU0 while Microbatch 1 is in GPU1  
...
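
To make the fill pattern concrete, here is a small sketch that prints which microbatch each stage works on at each time step during the forward pass. It assumes every stage takes exactly one time step per microbatch, which real workloads only approximate; `forward_schedule` is a hypothetical name.

```python
def forward_schedule(num_stages: int, num_microbatches: int) -> list:
    """schedule[t][s] is the microbatch index stage s runs at step t, or None if idle."""
    steps = num_microbatches + num_stages - 1
    schedule = []
    for t in range(steps):
        row = []
        for s in range(num_stages):
            mb = t - s  # microbatch mb reaches stage s after s steps
            row.append(mb if 0 <= mb < num_microbatches else None)
        schedule.append(row)
    return schedule


for t, row in enumerate(forward_schedule(num_stages=4, num_microbatches=3)):
    cells = ["  .  " if mb is None else f" mb{mb + 1} " for mb in row]
    print(f"t={t}:" + "".join(cells))
```

The `None` entries (printed as dots) are the idle slots that more microbatches help amortize.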

#### Why it helps
	•	Allows training of very deep networks that don’t fit on a single GPU.
	•	Partitions the model by depth (layer-wise), so each GPU stores only the weights of its own stage.

#### Cost
	•	Pipeline “bubble” (idle time) unless microbatching is tuned well.
	•	Adds communication cost between stages.
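
For a rough sense of the bubble's size: under the simplifying assumptions that every stage takes the same time per microbatch and the schedule simply fills and then drains the pipeline, each stage is idle for P − 1 of the M + P − 1 time slots (P stages, M microbatches). The sketch below evaluates that ratio; it is an estimate under those assumptions, not a measurement.

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction of a fill-and-drain pipeline with equal-cost stages."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


for m in (1, 4, 16, 64):
    print(f"P=4, M={m}: {bubble_fraction(4, m):.1%} idle")
```

With a single microbatch the pipeline is idle three quarters of the time; increasing the number of microbatches shrinks the bubble but never removes it.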

Analogy:
An assembly line: each GPU handles a segment of the model.

### 3. Tensor Parallelism (Model Parallelism)

Goal: Split individual layers (matrices) across GPUs.

Example: a linear layer

Y = XW

where W is too large for a single GPU.

Split W:

	•	Column-parallel split: shard W by columns, i.e. shard the output features
	•	Row-parallel split: shard W by rows, i.e. shard the input features (X is split to match)

#### How it works

GPU 0 computes part of XW
GPU 1 computes part of XW
GPU 2 computes part …
Then concatenate the shards (when output features were split) or sum the partial results (when input features were split) to form the full output.
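
The two ways of merging can be checked on a tiny example without any GPUs. The sketch below simulates two "GPUs" with plain Python lists; the matrices are arbitrary illustrative values, and `matmul` is a small local helper, not a library call.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


X = [[1, 2, 3, 4],
     [5, 6, 7, 8]]              # [batch=2, in=4]
W = [[1, 0, 2, 1],
     [0, 1, 1, 3],
     [2, 1, 0, 1],
     [1, 2, 1, 0]]              # [in=4, out=4]
Y = matmul(X, W)                # reference result on a single device

# Split OUTPUT features: each GPU gets 2 columns of W and the full X.
W_cols = [[row[:2] for row in W], [row[2:] for row in W]]
Y_shards = [matmul(X, w) for w in W_cols]
Y_from_concat = [left + right for left, right in zip(*Y_shards)]

# Split INPUT features: each GPU gets 2 rows of W and the matching 2 columns of X.
W_rows = [W[:2], W[2:]]
X_cols = [[row[:2] for row in X], [row[2:] for row in X]]
Y_partials = [matmul(x, w) for x, w in zip(X_cols, W_rows)]
Y_from_sum = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(*Y_partials)]

print(Y_from_concat == Y, Y_from_sum == Y)
```

Both merged results equal the single-device product, which is why sharding output features pairs with a concat (all-gather) and sharding input features pairs with a sum (all-reduce).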

Used for:

	•	Large feed-forward layers
	•	Large attention projections (Q/K/V, output matrices)

#### Why it helps

	•	Allows a single layer with billions of parameters to run across multiple GPUs.
	•	Commonly used when training modern 70B–500B+ parameter models.

#### Cost

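	•	Heavy all-reduce communication, incurred in every sharded layer during both forward and backward passes.
	•	Needs a high-speed interconnect (NVLink/NVSwitch), so tensor-parallel groups are usually kept within a single node.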

Analogy:
Each GPU holds part of a giant matrix. They all compute partial results and merge them.

### How They Fit Together

Modern large-scale training often combines all three, usually alongside data parallelism:

Tensor Parallelism → lets each layer fit in GPU memory

Pipeline Parallelism → lets all layers fit across GPUs

Sequence Parallelism → keeps activation memory manageable

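Frameworks such as Megatron-LM and DeepSpeed support combining these. Training stacks for models such as PaLM, Llama, and GPT-4 are generally understood to use similar combinations, although not all of their details are public.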
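
When the strategies are combined, the total GPU count is the product of the group sizes, and each global rank gets a coordinate in every dimension. The sketch below shows one possible mapping, with tensor-parallel ranks varying fastest so that a TP group occupies adjacent ranks. This ordering is an assumption for illustration; frameworks differ, and `rank_to_coords` is a hypothetical name. Sequence parallelism is left out because it typically reuses the TP group.

```python
def rank_to_coords(rank: int, tp: int, pp: int, dp: int) -> tuple:
    """Map a global rank to (dp_idx, pp_idx, tp_idx), with TP varying fastest."""
    if not 0 <= rank < tp * pp * dp:
        raise ValueError("rank out of range")
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx


TP, PP, DP = 4, 2, 2                      # 4 * 2 * 2 = 16 GPUs in total
for rank in range(TP * PP * DP):
    dp_idx, pp_idx, tp_idx = rank_to_coords(rank, TP, PP, DP)
    print(f"rank {rank:2d}: data replica {dp_idx}, pipeline stage {pp_idx}, tensor shard {tp_idx}")
```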

### Tensor Parallelism (shard big matrices across GPUs)

Idea: split large Linear/attention projections across a tensor group; each GPU computes a partial result → all-reduce/concat to form the full output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run on 8 GPUs.
# global ranks: 0 1 2 3 4 5 6 7
# TP Group 0: [0, 1, 2, 3]
# TP Group 1: [4, 5, 6, 7]
TP = 4
rank = dist.get_rank()

# Every rank must call new_group for every group, in the same order,
# and then keep the group it belongs to.
tp_groups = [dist.new_group(ranks=list(range(start, start + TP)))
             for start in range(0, dist.get_world_size(), TP)]
tp_group = tp_groups[rank // TP]

# NOTE: dist.all_gather / dist.all_reduce are not autograd-aware, so the
# modules below are forward-only sketches. For training, each collective is
# wrapped in a torch.autograd.Function with the matching backward collective
# (the approach Megatron-LM takes).

class TensorParallelLinear(nn.Module):
    """Column-parallel shard: split OUT features across GPUs, then all-gather."""
    def __init__(self, in_features, out_features, tp_group):
        super().__init__()
        self.tp_group = tp_group
        self.tp_size = dist.get_world_size(tp_group)

        assert out_features % self.tp_size == 0
        out_local = out_features // self.tp_size
        self.weight = nn.Parameter(torch.empty(out_local, in_features))
        self.bias   = nn.Parameter(torch.empty(out_local))

        # init weights (e.g., xavier); the caller moves the module to its local device
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        # x: [B, *, in_features], replicated across tp_group
        # local matmul -> partial output shard
        y_local = F.linear(x, self.weight, self.bias)  # [B, *, out_local]
        # to assemble full output, all-gather shards along last dim
        y_list = [torch.empty_like(y_local) for _ in range(self.tp_size)]
        dist.all_gather(y_list, y_local, group=self.tp_group)
        y_full = torch.cat(y_list, dim=-1)            # [B, *, out_features]
        return y_full

class TensorParallelRowLinear(nn.Module):
    """Row-parallel shard: split IN features, then all-reduce partial outputs."""
    def __init__(self, in_features, out_features, tp_group):
        super().__init__()
        self.tp_group = tp_group
        self.tp_size  = dist.get_world_size(tp_group)
        assert in_features % self.tp_size == 0
        in_local = in_features // self.tp_size
        self.weight = nn.Parameter(torch.empty(out_features, in_local))
        self.bias   = nn.Parameter(torch.empty(out_features))
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x_sharded):
        # x_sharded is pre-split across tp_group: [B, *, in_local]
        # each rank computes partial matmul -> [B, *, out_features]
        y_partial = F.linear(x_sharded, self.weight)  # no bias yet
        # sum partials across ranks (all-reduce), then add the bias exactly once
        dist.all_reduce(y_partial, op=dist.ReduceOp.SUM, group=self.tp_group)
        y_full = y_partial + self.bias
        return y_full

# In attention, do the same for QKV and out-proj:
# - QKV: column-parallel with the gather skipped, so each rank attends over its own heads
# - out-proj: row-parallel, consuming those per-rank head outputs (all-reduce)
```

	•	Choose column-parallel for projections that are later concatenated; row-parallel for projections that are summed.
	•	Needs fast interconnect (NVLink/NVSwitch).
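
The reason for pairing a column-parallel layer with a following row-parallel layer can be checked on a toy two-layer MLP: the hidden activations stay sharded, the elementwise nonlinearity runs locally, and a single sum at the end recovers the exact result. The sketch simulates two ranks in one process with plain lists; the weights are arbitrary illustrative values and `matmul` / `relu` are local helpers.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    return [[max(0, v) for v in row] for row in m]


X = [[1, -2], [3, 1]]                       # [batch=2, in=2]
A = [[1, -1, 2, 0], [0, 1, -1, 2]]          # first weight  [in=2, hidden=4]
B = [[1, 0], [2, 1], [0, -1], [1, 1]]       # second weight [hidden=4, out=2]

reference = matmul(relu(matmul(X, A)), B)   # single-device result

partials = []
for r in range(2):                          # two simulated ranks
    A_shard = [row[2 * r:2 * r + 2] for row in A]   # column-parallel: 2 columns of A
    B_shard = B[2 * r:2 * r + 2]                    # row-parallel: matching 2 rows of B
    hidden_shard = relu(matmul(X, A_shard))         # stays local, no gather needed
    partials.append(matmul(hidden_shard, B_shard))

# The only communication: sum the partial outputs (an all-reduce on real hardware).
result = [[p0 + p1 for p0, p1 in zip(r0, r1)] for r0, r1 in zip(*partials)]
print(result == reference)
```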

### Pipeline Parallelism (split layers into stages across GPUs)

Idea: partition the model by layers into P stages on different ranks; feed microbatches through a 1F1B (one-forward-one-backward) schedule, which keeps stages busy while limiting how many microbatches' activations each stage holds at once.

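```python
import torch
import torch.nn as nn
import torch.distributed as dist

# Suppose world size = P * DP (ignore TP here for clarity; with DP = 1, rank == stage_id).
# Placeholders assumed to be defined by the surrounding script: P, L, local_device,
# build_transformer_layers, split_layers, loss_fn, labels_for_mb, and act_shape
# (the shape of the activations exchanged between stages).
rank = dist.get_rank()
stage_id = ...           # 0 .. P-1
prev_rank = rank - 1 if stage_id > 0 else None
next_rank = rank + 1 if stage_id < P - 1 else None

# Split model layers
full_layers = build_transformer_layers(L)
partitions = split_layers(full_layers, P)        # list of layer-slices
stage = nn.Sequential(*partitions[stage_id]).to(local_device)

# send/recv utilities (pseudocode: blocking p2p, fixed activation shape)
def send_tensor(t, dst):
    dist.send(t.detach().contiguous(), dst=dst)

def recv_tensor(src):
    buf = torch.empty(act_shape, device=local_device)
    dist.recv(buf, src=src)
    return buf

def run_pipeline_step(microbatch):
    """One microbatch forward on this stage. Returns (x, y) so backward can run later."""
    if stage_id == 0:
        x = microbatch.to(local_device)          # source stage reads data
    else:
        x = recv_tensor(src=prev_rank)
        x.requires_grad_()                       # lets us read x.grad after backward
    y = stage(x)                                 # forward through local layers
    if stage_id < P - 1:
        send_tensor(y, dst=next_rank)
    return x, y

def run_backward_step(x, y, mb_index):
    """Backward for one microbatch; passes the input gradient upstream."""
    if stage_id == P - 1:
        # last stage computes loss and starts backward
        loss = loss_fn(y, labels_for_mb(mb_index))
        loss.backward()
    else:
        # receive grad from next stage
        grad_y = recv_tensor(src=next_rank)
        y.backward(grad_y)
    if stage_id > 0:
        send_tensor(x.grad, dst=prev_rank)       # grad w.r.t. this stage's input

def train_one_iteration(microbatches):
    """Naive 1F1B sketch; real schedulers overlap F/B + p2p.

    Blocking send/recv in this order can deadlock between neighbouring stages;
    production code typically batches the transfers, e.g. with dist.batch_isend_irecv.
    """
    num_mb = len(microbatches)
    num_warmup = min(P - 1 - stage_id, num_mb)
    in_flight = []                               # (x, y) pairs awaiting backward

    # warmup forwards to fill the pipe
    for i in range(num_warmup):
        in_flight.append(run_pipeline_step(microbatches[i]))

    # steady state: one forward, then one backward for the oldest microbatch
    for i in range(num_warmup, num_mb):
        in_flight.append(run_pipeline_step(microbatches[i]))
        x, y = in_flight.pop(0)
        run_backward_step(x, y, i - num_warmup)

    # drain the pipe: remaining backward passes
    for j in range(num_mb - num_warmup, num_mb):
        x, y = in_flight.pop(0)
        run_backward_step(x, y, j)
```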

	•	Real implementations use a library scheduler (PyTorch's torch.distributed.pipelining, or the DeepSpeed/Megatron-LM schedulers), which overlaps p2p communication with compute; microbatch size and count are then tuned to minimize bubbles.
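
The order of operations in a 1F1B schedule is easier to see without any communication code. The sketch below lists the forward (F) and backward (B) steps one stage performs under non-interleaved 1F1B as it is commonly described, where stage `s` of `P` runs `P - 1 - s` warm-up forwards before alternating. The function names are hypothetical, and the sketch ignores timing and transfers entirely.

```python
def one_f_one_b_order(stage_id: int, num_stages: int, num_microbatches: int) -> list:
    """Ordered ('F' | 'B', microbatch) steps for one stage under non-interleaved 1F1B."""
    warmup = min(num_stages - 1 - stage_id, num_microbatches)
    ops = [("F", i) for i in range(warmup)]              # fill the pipe
    for i in range(warmup, num_microbatches):            # steady state
        ops.append(("F", i))
        ops.append(("B", i - warmup))
    ops += [("B", j) for j in range(num_microbatches - warmup, num_microbatches)]  # drain
    return ops


def peak_in_flight(ops: list) -> int:
    """Largest number of microbatches whose activations are held at once."""
    live = peak = 0
    for kind, _ in ops:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


for stage in range(4):
    ops = one_f_one_b_order(stage, num_stages=4, num_microbatches=8)
    print(f"stage {stage}: peak in-flight = {peak_in_flight(ops)}  ",
          " ".join(f"{k}{m}" for k, m in ops))
```

In this sketch the first stage holds at most 4 microbatches' activations and the last stage only 1, regardless of how many microbatches are fed through.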

### Sequence Parallelism (shard tokens across GPUs inside a TP group)

Idea: split the sequence length across ranks (e.g., 4 GPUs each hold 1/4 of the tokens). Use all-gather/reduce-scatter around places that need full-sequence context (attention), and keep per-token ops local.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

# Assume we already have a tensor-parallel group tp_group of size = SP (sequence-parallel size)
sp_group = tp_group
rank_in_sp = dist.get_rank(sp_group)
sp_size    = dist.get_world_size(sp_group)

# NOTE: the dist.* collectives below are not autograd-aware; this is a forward-only sketch.

def split_sequence(x):
    # x: [B, S, H], shard S across ranks
    S = x.size(1)
    assert S % sp_size == 0
    S_local = S // sp_size
    start = rank_in_sp * S_local
    end   = start + S_local
    return x[:, start:end, :].contiguous()  # [B, S_local, H]

def allgather_sequence(x_local):
    # gather shards along S to recover [B, S, H]
    bufs = [torch.empty_like(x_local) for _ in range(sp_size)]
    dist.all_gather(bufs, x_local, group=sp_group)
    return torch.cat(bufs, dim=1)

def reducescatter_sequence(y_partial):
    # y_partial: [B, S, H] partial sums that differ per rank (e.g. the un-reduced
    # output of a row-parallel out-proj). Sums them across ranks and returns this
    # rank's slice along S. If every rank already holds the same full result,
    # SUM would scale it by sp_size; use split_sequence instead in that case.
    chunks = [c.contiguous() for c in y_partial.chunk(sp_size, dim=1)]
    y_local = torch.empty_like(chunks[0])
    dist.reduce_scatter(y_local, chunks, op=dist.ReduceOp.SUM, group=sp_group)
    return y_local

class SeqParallelBlock(nn.Module):
    def __init__(self, attn, mlp):
        super().__init__()
        self.attn = attn    # TP attention whose out-proj returns un-reduced partial sums
        self.mlp  = mlp

    def forward(self, h_full):
        # h_full: [B,S,H] (typically comes from previous layer's gather or the input)
        # 1) shard tokens across ranks
        h_local = split_sequence(h_full)            # [B,S_local,H]

        # 2) per-token ops stay local (LayerNorm on last dim; no learnable affine, for brevity)
        h_local = F.layer_norm(h_local, h_local.shape[-1:])

        # 3) attention needs global context across S:
        #    gather all tokens -> run attention -> reduce-scatter back
        h_all = allgather_sequence(h_local)         # [B,S,H]
        a_all = self.attn(h_all)                    # [B,S,H], partial sums per rank
        a_local = reducescatter_sequence(a_all)     # [B,S_local,H]

        # 4) MLP is position-wise -> can run on local tokens if its weights are replicated
        #    (with a TP-sharded MLP, do gather->mlp->reduce-scatter as for attention)
        m_local = self.mlp(a_local)                 # [B,S_local,H]

        return m_local

# Stacking blocks:
#   this block returns a sequence-sharded tensor; within a TP group, keep activations
#   sharded between layers and alternate gather/scatter only around ops that need the full sequence.
```

	•	Sequence parallelism is mainly a training trick to reduce activation memory for long sequences.
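	•	For inference decoding, we usually avoid SP: each step processes only one new token, so there is little sequence to shard, and splitting the sequence across ranks hurts KV-cache locality and prefix reuse.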
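
The semantics of the two collectives used around attention can be simulated in a single process. In the sketch below, plain lists stand in for per-rank tensors and scalars stand in for per-token hidden vectors; `all_gather` and `reduce_scatter` here are local stand-ins written for illustration, not the `torch.distributed` functions.

```python
def all_gather(shards):
    """Every rank ends up with the concatenation of all ranks' shards."""
    full = [tok for shard in shards for tok in shard]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_full):
    """Sum the per-rank sequences elementwise, then give rank r the r-th slice."""
    n = len(per_rank_full)
    summed = [sum(vals) for vals in zip(*per_rank_full)]
    size = len(summed) // n
    return [summed[r * size:(r + 1) * size] for r in range(n)]


shards = [[1, 2], [3, 4]]                 # 2 ranks, 2 tokens each
gathered = all_gather(shards)             # each rank now sees [1, 2, 3, 4]

# Case 1: each rank computed a different PARTIAL result for the full sequence.
partials = [[10 * x for x in gathered[0]], list(gathered[1])]
print(reduce_scatter(partials))           # partial sums combined, then sharded

# Case 2: every rank holds the SAME full result; summing double-counts it.
print(reduce_scatter(gathered))
```

The second case shows why a summing reduce-scatter is only appropriate after an operation that leaves per-rank partial sums; when the result is already replicated, a plain split is the inverse of the all-gather.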