In [1]:
import torch
import torch.nn as nn
import torch.distributed as dist

In [2]:
class SequenceParallelLayerNorm(nn.Module):
    """LayerNorm applied to this rank's sequence-parallel shard of activations.

    Normalisation runs over the last (hidden) dimension only, so each token is
    normalised on the rank that owns it and no communication is required.

    Args:
        d_model: size of the normalised (last) dimension.
        world_size: number of ranks (stored for bookkeeping; unused in forward).
        rank: this process's rank (stored for bookkeeping; unused in forward).
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, d_model, world_size, rank, eps=1e-5):
        super().__init__()

        self.d_model = d_model
        self.world_size = world_size
        self.rank = rank
        self.eps = eps

        # Standard LayerNorm initialisation: identity scale, zero shift. A random
        # scale would turn the layer into a random per-feature rescaling at init.
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Return ``alpha * (x - mean) / sqrt(var + eps) + beta`` over the last dim.

        ``var`` is the biased (population) variance, as in ``nn.LayerNorm``.
        The output has the same shape as ``x``.
        """
        return nn.functional.layer_norm(
            x, (self.d_model,), self.alpha, self.beta, self.eps
        )

In [3]:
class SequenceParallelDropout(nn.Module):
    """Dropout for activations sharded along the sequence dimension.

    Rank 0 draws a fresh seed on every training-mode call and broadcasts it, so
    all ranks agree on the base seed for a step. Each rank then offsets that seed
    by its rank, so different sequence shards get independent masks rather than
    the same mask repeated on every shard. The global RNG state is forked and
    restored around the dropout call.

    Args:
        p: drop probability.
        world_size: number of ranks (stored for bookkeeping).
        rank: this process's rank; rank 0 generates the seed.
    """

    def __init__(self, p, world_size, rank):
        super().__init__()

        self.p = p
        self.world_size = world_size
        self.rank = rank

        # int64 seed storage; registered as a buffer so it follows ``.to(device)``.
        self.register_buffer("rng_seed", torch.zeros(1, dtype=torch.long))

    def forward(self, x):
        """Apply dropout in training mode; return ``x`` unchanged otherwise or when ``p == 0``."""
        if not self.training or self.p == 0:
            return x

        if self.rank == 0:
            self.rng_seed.random_()

        # Broadcast a copy on the input's device: NCCL needs CUDA tensors, gloo
        # accepts CPU ones, and the buffer may not have been moved with the input.
        seed_buf = self.rng_seed.to(x.device)
        dist.broadcast(seed_buf, src=0)
        if seed_buf is not self.rng_seed:
            self.rng_seed.copy_(seed_buf)

        # Per-rank offset -> independent masks per sequence shard; the modulo keeps
        # the value inside the range accepted by ``manual_seed``.
        seed = (int(seed_buf.item()) + self.rank) % (2 ** 63)

        # ``fork_rng`` only accepts CUDA devices; the CPU RNG is always forked.
        devices = [x.device] if x.device.type == "cuda" else []
        with torch.random.fork_rng(devices=devices):
            torch.random.manual_seed(seed)

            return nn.functional.dropout(x, self.p, training=True)

In [9]:
def sp_to_tp_transitions(tensor, world_size, rank):
    """Gather sequence shards from every rank (SP layout -> TP layout).

    Args:
        tensor: this rank's shard, shape ``(batch, seq_len_local, d_model)``.
        world_size: number of ranks in the default process group.
        rank: this process's rank (unused; kept for a uniform signature).

    Returns:
        Tensor of shape ``(batch, seq_len_local * world_size, d_model)``. The
        shards are concatenated in rank order along the sequence dimension.

    Raises:
        ValueError: if ``tensor`` is not 3-D.

    NOTE(review): ``dist.all_gather`` writes into freshly allocated buffers and is
    not tracked by autograd, so no gradient flows back to ``tensor`` through this
    op. Training would need an autograd-aware gather (backward = reduce-scatter).
    """
    if tensor.dim() != 3:
        raise ValueError(
            f"expected a 3-D (batch, seq, hidden) tensor, got shape {tuple(tensor.shape)}"
        )

    # all_gather requires a contiguous input.
    tensor = tensor.contiguous()
    shards = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(shards, tensor)

    return torch.cat(shards, dim=1)

In [6]:
def tp_to_sp_transitions(tensor, world_size, rank):
    """Re-shard a hidden-sharded activation into sequence shards (TP -> SP).

    First gathers every rank's hidden-dimension shard, then keeps this rank's
    contiguous chunk of the sequence.

    Only use this on tensors that are actually sharded on the hidden dimension.
    ``RowParallelism`` outputs are already all-reduced to the full ``output_size``;
    for those, a plain slice of the sequence is enough. Gathering again would
    multiply the hidden size by ``world_size``.

    Args:
        tensor: shape ``(batch, seq_len, d_model_local)``.
        world_size: number of ranks in the default process group.
        rank: index of the sequence chunk to keep.

    Returns:
        Tensor of shape ``(batch, seq_len // world_size, d_model_local * world_size)``.

    Raises:
        ValueError: if ``tensor`` is not 3-D or ``seq_len`` is not divisible by
            ``world_size`` (uneven chunks would mis-shard, or leave some ranks
            without a chunk).
    """
    if tensor.dim() != 3:
        raise ValueError(
            f"expected a 3-D (batch, seq, hidden) tensor, got shape {tuple(tensor.shape)}"
        )
    seq_len = tensor.shape[1]
    if seq_len % world_size != 0:
        raise ValueError(
            f"seq_len ({seq_len}) must be divisible by world_size ({world_size})"
        )

    # all_gather requires a contiguous input.
    tensor = tensor.contiguous()
    hidden_shards = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(hidden_shards, tensor)

    full_tensor = torch.cat(hidden_shards, dim=2)
    chunks = torch.chunk(full_tensor, world_size, dim=1)

    return chunks[rank].contiguous()

In [7]:
class ColumnParallelism(nn.Module):
    """Linear layer whose output features are split across ranks.

    Each rank owns ``output_size // world_size`` output columns of the weight and
    the matching slice of the bias. ``forward`` computes only that slice, so no
    communication happens in this layer.

    Args:
        input_size: number of input features (full, not sharded).
        output_size: total number of output features across all ranks.
        world_size: number of ranks.
        rank: this process's rank (stored for bookkeeping).

    Raises:
        ValueError: if ``output_size`` is not divisible by ``world_size``.
    """

    def __init__(self, input_size, output_size, world_size, rank):
        super().__init__()
        if output_size % world_size != 0:
            raise ValueError(
                f"output_size ({output_size}) must be divisible by world_size ({world_size})"
            )

        self.input_size = input_size
        self.output_size = output_size
        self.world_size = world_size
        self.rank = rank

        self.output_size_per_rank = output_size // world_size

        # Scale by 1/sqrt(fan_in) so outputs keep roughly unit variance; plain
        # N(0, 1) weights grow activations by ~sqrt(input_size) per layer.
        self.weight = nn.Parameter(
            torch.randn(input_size, self.output_size_per_rank) * input_size ** -0.5
        )

        self.bias = nn.Parameter(
            torch.zeros(self.output_size_per_rank)
        )

    def forward(self, x):
        """Return this rank's output slice ``x @ weight + bias``; the last dim becomes ``output_size // world_size``."""
        return torch.matmul(x, self.weight) + self.bias

In [8]:
class RowParallelism(nn.Module):
    """Linear layer whose input features are split across ranks.

    Each rank multiplies its ``input_size // world_size`` slice of the input by
    the matching rows of the weight. The partial products are summed across
    ranks with an all-reduce, and then the (replicated) bias is added. Every rank
    ends up with the same full ``output_size`` result.

    Args:
        input_size: total number of input features across all ranks.
        output_size: number of output features (full, not sharded).
        world_size: number of ranks.
        rank: this process's rank (stored for bookkeeping).

    Raises:
        ValueError: if ``input_size`` is not divisible by ``world_size``.
    """

    def __init__(self, input_size, output_size, world_size, rank):
        super().__init__()
        if input_size % world_size != 0:
            raise ValueError(
                f"input_size ({input_size}) must be divisible by world_size ({world_size})"
            )

        self.input_size = input_size
        self.output_size = output_size
        self.world_size = world_size
        self.rank = rank

        self.input_size_per_rank = input_size // world_size

        # The full layer's fan-in is input_size (partials are summed across
        # ranks), so scale by 1/sqrt(input_size) to keep unit output variance.
        self.weight = nn.Parameter(
            torch.randn(self.input_size_per_rank, output_size) * input_size ** -0.5
        )
        self.bias = nn.Parameter(torch.zeros(output_size))

    def forward(self, x):
        """Return ``sum_over_ranks(x_r @ weight_r) + bias``, identical on every rank.

        ``x`` is this rank's input slice; its last dim is ``input_size // world_size``.
        """
        output_partial = torch.matmul(x, self.weight)

        # In-place collective: not recorded by autograd, so backward treats it
        # as identity (presumably intended, since downstream grads are replicated).
        dist.all_reduce(output_partial, op=dist.ReduceOp.SUM)

        # The bias is replicated, so add it once, after the reduction, on every
        # rank; adding it only on rank 0 here left the ranks with different outputs.
        return output_partial + self.bias

In [10]:
class TPSPTransformerBlock(nn.Module):
    """Pre-norm transformer block combining sequence (SP) and tensor (TP) parallelism.

    ``input`` is this rank's slice of the sequence, presumably shaped
    ``(batch, seq_len // world_size, d_model)``. LayerNorm, dropout and the
    residual adds run on that slice (SP region). Attention and the MLP run on the
    all-gathered full sequence, with their weights sharded across ranks (TP region).

    Args:
        d_model: model width.
        n_heads: total number of attention heads (split across ranks).
        d_ff: MLP hidden width (split across ranks).
        world_size: number of ranks.
        rank: this process's rank.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads`` or ``n_heads``
            is not divisible by ``world_size``.
    """

    def __init__(self, d_model, n_heads, d_ff, world_size, rank):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        if n_heads % world_size != 0:
            raise ValueError(f"n_heads ({n_heads}) must be divisible by world_size ({world_size})")

        self.d_model = d_model
        self.n_heads = n_heads
        self.world_size = world_size
        self.rank = rank
        self.d_ff = d_ff

        self.head_dim = d_model // n_heads
        self.n_heads_per_rank = n_heads // world_size

        self.qkv_proj = ColumnParallelism(
            d_model,
            d_model * 3,
            world_size,
            rank
        )

        self.out_proj = RowParallelism(
            d_model,
            d_model,
            world_size,
            rank
        )

        self.mlp_fc1 = ColumnParallelism(
            d_model,
            d_ff,
            world_size,
            rank
        )

        self.mlp_fc2 = RowParallelism(
            d_ff,
            d_model,
            world_size,
            rank
        )

        self.norm1 = SequenceParallelLayerNorm(d_model, world_size, rank)
        self.norm2 = SequenceParallelLayerNorm(d_model, world_size, rank)

        self.dropout1 = SequenceParallelDropout(0.1, world_size, rank)
        self.dropout2 = SequenceParallelDropout(0.1, world_size, rank)

        self.activation = nn.GELU()

    def _scatter_sequence(self, x):
        """Keep this rank's contiguous chunk of the sequence dimension (dim 1).

        Only used on ``RowParallelism`` outputs. Those are already all-reduced to
        the full hidden size on every rank, so a local slice is all that is needed.
        """
        return torch.chunk(x, self.world_size, dim=1)[self.rank].contiguous()

    def _attention(self, x):
        """Self-attention over the full sequence for this rank's heads, then the row-parallel output projection."""
        batch_size, seq_len, _ = x.shape
        qkv = self.qkv_proj(x).reshape(
            batch_size, seq_len, 3, self.n_heads_per_rank, self.head_dim
        )
        # -> (3, batch, heads_local, seq_len, head_dim) so attention mixes tokens, not heads.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        scores = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(batch_size, seq_len, -1)

        return self.out_proj(out)

    def _mlp(self, x):
        """Column-parallel fc1 -> GELU -> row-parallel fc2."""
        return self.mlp_fc2(self.activation(self.mlp_fc1(x)))

    def forward(self, input):
        """Apply the attention and MLP sub-layers with residual connections to this rank's sequence slice."""
        attn_full = self._attention(
            sp_to_tp_transitions(self.norm1(input), self.world_size, self.rank)
        )
        new_input = input + self.dropout1(self._scatter_sequence(attn_full))

        mlp_full = self._mlp(
            sp_to_tp_transitions(self.norm2(new_input), self.world_size, self.rank)
        )

        return new_input + self.dropout2(self._scatter_sequence(mlp_full))

In [11]:
"""
Communication Pattern:
======================

Per Transformer Block:
1. LayerNorm (SP): No communication ✅
2. SP → TP: AllGather (sequence dimension)
3. QKV (TP): No communication (column parallel)
4. Attention: Local computation
5. Out projection (TP): AllReduce
6. TP → SP: AllGather (hidden) + ReduceScatter (sequence)
7. LayerNorm (SP): No communication ✅
8. SP → TP: AllGather (sequence)
9. MLP (TP): AllReduce
10. TP → SP: AllGather (hidden) + ReduceScatter (sequence)

Total: 4 AllGathers + 2 AllReduces + 2 ReduceScatters per block

This is MORE communication than pure TP!
SP is only worth it for VERY long sequences where memory is the bottleneck.
"""



#### Ring Attention

In [None]:
class RingAttention(nn.Module):
    """Non-causal multi-head self-attention over a sequence sharded across ranks.

    Each rank holds ``seq_len_local`` tokens. Queries stay local, while key/value
    blocks travel around the ring: rank r sends to r + 1 and receives from r - 1.
    After ``world_size`` steps, every local query has attended to every key.
    Partial results are merged with the online-softmax recurrence, so the full
    score matrix is never materialised.

    Only the fused QKV projection is applied; there is no output projection.

    NOTE(review): ``isend``/``irecv`` are not tracked by autograd, so gradients do
    not flow through the received K/V blocks. Training would need a custom
    autograd Function for the ring exchange.

    Args:
        d_model: model width.
        n_heads: number of attention heads (all computed on every rank).
        world_size: number of ranks in the ring.
        rank: this process's rank.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, world_size, rank):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")

        self.d_model = d_model
        self.n_heads = n_heads
        self.world_size = world_size
        self.rank = rank
        self.head_dim = d_model // n_heads

        self.qkv_proj = nn.Linear(d_model, d_model * 3)

    def forward(self, x):
        """Attend local queries to the keys/values of all ranks.

        Args:
            x: this rank's tokens, shape ``(batch, seq_len_local, d_model)``.

        Returns:
            Tensor of shape ``(batch, seq_len_local, d_model)``.
        """
        batch_size, seq_len_local, d_model = x.shape

        qkv = self.qkv_proj(x).reshape(batch_size, seq_len_local, 3, self.n_heads, self.head_dim)
        # -> (3, batch, heads, seq_len_local, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        stats_shape = (batch_size, self.n_heads, seq_len_local, 1)
        m = torch.full(stats_shape, -torch.inf, device=x.device, dtype=q.dtype)  # running row max
        d = torch.zeros(stats_shape, device=x.device, dtype=q.dtype)             # running softmax denominator
        o = torch.zeros_like(q)                                                   # running normalised output

        k_blk, v_blk = k, v
        # world_size blocks to process (the local one plus world_size - 1 received
        # ones), so exchange only between steps.
        for step in range(self.world_size):
            scores = q @ k_blk.transpose(-2, -1) / (self.head_dim ** 0.5)

            m, d, o = self._online_softmax(m, d, o, scores, v_blk)

            if step < self.world_size - 1:
                k_blk, v_blk = self.ring_exchange(k_blk, v_blk)

        return o.transpose(1, 2).reshape(batch_size, seq_len_local, d_model)

    def _online_softmax(self, m_prev, d_prev, o_prev, scores_new, v_new):
        """Merge one block of scores/values into the running softmax statistics.

        Args:
            m_prev: running max per query row, shape ``(..., q, 1)``.
            d_prev: running sum of ``exp(score - m_prev)``, shape ``(..., q, 1)``.
            o_prev: running softmax-weighted output (already normalised by ``d_prev``).
            scores_new: new block of scores, shape ``(..., q, k_blk)``.
            v_new: matching values, shape ``(..., k_blk, head_dim)``.

        Returns:
            ``(m_new, d_new, o_new)`` covering all blocks seen so far.
        """
        m_new = torch.maximum(m_prev, scores_new.amax(dim=-1, keepdim=True))

        p = torch.exp(scores_new - m_new)
        # Rescales the old statistics to the new max (0 on the first block, where m_prev = -inf).
        correction = torch.exp(m_prev - m_new)

        # The denominator sums over the keys of the new block.
        d_new = d_prev * correction + p.sum(dim=-1, keepdim=True)

        o_new = (o_prev * d_prev * correction + p @ v_new) / d_new

        return m_new, d_new, o_new

    def ring_exchange(self, k, v):
        """Send ``(k, v)`` to rank + 1 and receive the same-shaped blocks from rank - 1.

        All four point-to-point ops are issued through one
        ``dist.batch_isend_irecv`` call. Posting independent ``isend``/``irecv``
        calls around a ring can deadlock with NCCL.
        """
        send_rank = (self.rank + 1) % self.world_size
        recv_rank = (self.rank - 1) % self.world_size

        # Point-to-point ops require contiguous buffers.
        k = k.contiguous()
        v = v.contiguous()

        recv_k = torch.empty_like(k)
        recv_v = torch.empty_like(v)

        ops = [
            dist.P2POp(dist.isend, k, send_rank),
            dist.P2POp(dist.isend, v, send_rank),
            dist.P2POp(dist.irecv, recv_k, recv_rank),
            dist.P2POp(dist.irecv, recv_v, recv_rank),
        ]
        for request in dist.batch_isend_irecv(ops):
            request.wait()

        return recv_k, recv_v

In [1]:
import torch
import torch.nn as nn
import torch.distributed as dist

In [None]:
class SequenceParallelismLayerNorm(nn.Module):
    """LayerNorm applied to this rank's sequence-parallel shard of activations.

    Normalisation runs over the last (hidden) dimension only, so each token is
    normalised on the rank that owns it and no communication is required.

    Args:
        d_model: size of the normalised (last) dimension.
        rank: this process's rank (stored for bookkeeping; unused in forward).
        world_size: number of ranks (stored for bookkeeping; unused in forward).
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, d_model, rank, world_size, eps=1e-5):
        super().__init__()

        self.d_model = d_model
        self.rank = rank
        self.world_size = world_size
        self.eps = eps

        # Standard LayerNorm initialisation: identity scale, zero shift.
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Return ``alpha * (x - mean) / sqrt(var + eps) + beta`` over the last dim.

        ``var`` is the biased (population) variance, as in ``nn.LayerNorm``;
        the unbiased estimate used previously over-normalises small ``d_model``.
        """
        return nn.functional.layer_norm(
            x, (self.d_model,), self.alpha, self.beta, self.eps
        )

In [3]:
class SequenceParallelismDropout(nn.Module):
    """Dropout for activations sharded along the sequence dimension.

    Rank 0 draws a fresh seed on every training-mode call and broadcasts it.
    Each rank then offsets that seed by its rank, so different sequence shards
    get independent masks. The global RNG state is forked and restored around
    the dropout call.

    Args:
        p: drop probability.
        world_size: number of ranks (stored for bookkeeping).
        rank: this process's rank; rank 0 generates the seed.
    """

    def __init__(self, p, world_size, rank):
        super().__init__()

        self.p = p
        self.rank = rank
        self.world_size = world_size

        # int64 seed storage (float32 cannot represent large integer seeds
        # exactly). Created on CPU so construction works without CUDA; it follows
        # ``.to(device)`` and is copied to the input's device for the broadcast.
        self.register_buffer("rng_seed", torch.zeros(1, dtype=torch.long))

    def forward(self, x):
        """Apply dropout in training mode; return ``x`` unchanged otherwise or when ``p == 0``."""
        if not self.training or self.p == 0:
            return x

        if self.rank == 0:
            self.rng_seed.random_()

        # NCCL needs the broadcast tensor on the GPU, gloo accepts CPU tensors.
        seed_buf = self.rng_seed.to(x.device)
        dist.broadcast(seed_buf, src=0)
        if seed_buf is not self.rng_seed:
            self.rng_seed.copy_(seed_buf)

        seed = (int(seed_buf.item()) + self.rank) % (2 ** 63)

        # ``fork_rng`` only accepts CUDA devices; the CPU RNG is always forked.
        devices = [x.device] if x.device.type == "cuda" else []
        with torch.random.fork_rng(devices=devices):
            torch.random.manual_seed(seed)

            # Return the dropped-out tensor (its result used to be discarded).
            return nn.functional.dropout(x, p=self.p, training=True)

In [4]:
class ColumnParallelism(nn.Module):
    """Linear layer whose output features are split across ranks.

    Each rank owns ``out_features // world_size`` output columns of the weight
    and the matching slice of the bias. ``forward`` computes only that slice, so
    no communication happens in this layer.

    Args:
        in_features: number of input features (full, not sharded).
        out_features: total number of output features across all ranks.
        rank: this process's rank (stored for bookkeeping).
        world_size: number of ranks.

    Raises:
        ValueError: if ``out_features`` is not divisible by ``world_size``.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()
        if out_features % world_size != 0:
            raise ValueError(
                f"out_features ({out_features}) must be divisible by world_size ({world_size})"
            )

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.world_size = world_size

        self.out_features_per_rank = out_features // world_size

        # Scale by 1/sqrt(fan_in) so outputs keep roughly unit variance.
        self.weight = nn.Parameter(
            torch.randn(in_features, self.out_features_per_rank) * in_features ** -0.5
        )
        self.bias = nn.Parameter(torch.zeros(self.out_features_per_rank))

    def forward(self, x):
        """Return this rank's output slice ``x @ weight + bias``; the last dim becomes ``out_features // world_size``."""
        return torch.matmul(x, self.weight) + self.bias

In [None]:
class RowParallelism(nn.Module):
    """Linear layer whose input features are split across ranks.

    Each rank multiplies its ``in_features // world_size`` slice of the input by
    the matching rows of the weight. The partial products are summed across
    ranks with an all-reduce, and then the replicated bias is added on every
    rank, so all ranks hold the same full ``out_features`` result.

    Args:
        in_features: total number of input features across all ranks.
        out_features: number of output features (full, not sharded).
        rank: this process's rank (stored for bookkeeping).
        world_size: number of ranks.

    Raises:
        ValueError: if ``in_features`` is not divisible by ``world_size``.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()
        if in_features % world_size != 0:
            raise ValueError(
                f"in_features ({in_features}) must be divisible by world_size ({world_size})"
            )

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.world_size = world_size

        self.in_features_per_rank = in_features // world_size

        # The full layer's fan-in is in_features (partials are summed across
        # ranks), so scale by 1/sqrt(in_features).
        self.weight = nn.Parameter(
            torch.randn(self.in_features_per_rank, out_features) * in_features ** -0.5
        )

        self.bias = nn.Parameter(
            torch.zeros(out_features)
        )

    def forward(self, x):
        """Return ``sum_over_ranks(x_r @ weight_r) + bias``, identical on every rank."""
        out = torch.matmul(x, self.weight)

        # In-place collective, not recorded by autograd (backward acts as identity).
        dist.all_reduce(out, op=dist.ReduceOp.SUM)

        # Adding the replicated bias after the reduction on every rank keeps all
        # ranks' bias parameters in use, so they receive identical gradients.
        return out + self.bias

In [None]:
class TPAttentionLayer(nn.Module):
    """Multi-head self-attention with its heads split across tensor-parallel ranks.

    The fused QKV projection is a ``ColumnParallelism`` layer, so each rank only
    produces Q/K/V for its own ``n_heads // world_size`` heads. The output
    projection is a ``RowParallelism`` layer, which all-reduces the per-rank
    partial results.

    Args:
        n_heads: total number of heads.
        d_model: model width.
        rank: this process's rank.
        world_size: number of ranks.
    """

    def __init__(self, n_heads, d_model, rank, world_size):
        super().__init__()

        self.n_heads = n_heads
        self.d_model = d_model
        self.rank = rank
        self.world_size = world_size

        self.n_heads_per_rank = n_heads // world_size
        self.head_dim = d_model // n_heads

        shard_kwargs = dict(rank=rank, world_size=world_size)
        self.qkv_proj = ColumnParallelism(in_features=d_model, out_features=d_model * 3, **shard_kwargs)
        self.out_proj = RowParallelism(in_features=d_model, out_features=d_model, **shard_kwargs)

    def forward(self, x):
        """Attend over the full sequence ``x`` of shape ``(batch, seq_len, d_model)`` for this rank's heads."""
        bsz, seq, _ = x.shape
        local_width = self.n_heads_per_rank * self.head_dim

        fused = self.qkv_proj(x).reshape(bsz, seq, 3, self.n_heads_per_rank, self.head_dim)
        # (3, batch, heads_local, seq, head_dim)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        logits = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        weights = nn.functional.softmax(logits, dim=-1)
        context = torch.matmul(weights, value)

        merged = context.transpose(1, 2).contiguous().view(bsz, seq, local_width)
        return self.out_proj(merged)

In [7]:
class TPMLP(nn.Module):
    """Tensor-parallel feed-forward block.

    A column-parallel up-projection to ``d_ff`` (sharded across ranks), then GELU,
    then a row-parallel down-projection back to ``d_model`` (all-reduced inside
    ``RowParallelism``).

    Args:
        d_model: model width.
        d_ff: hidden width of the MLP.
        rank: this process's rank.
        world_size: number of ranks.
    """

    def __init__(self, d_model, d_ff, rank, world_size):
        super().__init__()

        self.d_model, self.d_ff = d_model, d_ff
        self.rank, self.world_size = rank, world_size

        self.fc1 = ColumnParallelism(in_features=d_model, out_features=d_ff, rank=rank, world_size=world_size)
        self.fc2 = RowParallelism(in_features=d_ff, out_features=d_model, rank=rank, world_size=world_size)

        self.activation = nn.GELU()

    def forward(self, x):
        """Return ``fc2(GELU(fc1(x)))``."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)

In [None]:
class TPSPTransformerBlock(nn.Module):
    """Pre-norm transformer block combining sequence (SP) and tensor (TP) parallelism.

    ``x`` is this rank's slice of the sequence. LayerNorm, dropout and the
    residual adds run on that slice (SP region). Attention and the MLP run on the
    all-gathered full sequence, with their weights sharded across ranks (TP region).

    Args:
        n_heads: total number of attention heads (split across ranks).
        d_model: model width.
        d_ff: MLP hidden width (split across ranks).
        rank: this process's rank.
        world_size: number of ranks.
        p: dropout probability applied to each sub-layer output.
    """

    def __init__(self, n_heads, d_model, d_ff, rank, world_size, p):
        super().__init__()

        self.n_heads = n_heads
        self.d_model = d_model
        self.d_ff = d_ff
        self.rank = rank
        self.world_size = world_size
        self.p = p

        self.attention_layer = TPAttentionLayer(
            n_heads=n_heads,
            d_model=d_model,
            rank=rank,
            world_size=world_size
        )

        self.mlp = TPMLP(
            d_model=d_model,
            d_ff=d_ff,
            rank=rank,
            world_size=world_size
        )

        self.norm1 = SequenceParallelismLayerNorm(
            d_model=d_model,
            rank=rank,
            world_size=world_size
        )

        self.norm2 = SequenceParallelismLayerNorm(
            d_model=d_model,
            rank=rank,
            world_size=world_size
        )

        self.dropout1 = SequenceParallelismDropout(
            p=p,
            world_size=world_size,
            rank=rank
        )

        self.dropout2 = SequenceParallelismDropout(
            p=p,
            world_size=world_size,
            rank=rank
        )

    def sp_to_tp_transition(self, x):
        """All-gather sequence shards from every rank and concatenate them in rank order along dim 1.

        NOTE(review): ``dist.all_gather`` is not tracked by autograd, so gradients
        do not flow back to ``x`` through this op.
        """
        x = x.contiguous()  # all_gather requires a contiguous input
        tensor_list = [torch.empty_like(x) for _ in range(self.world_size)]
        dist.all_gather(tensor_list, x)

        return torch.cat(tensor_list, dim=1)

    def tp_to_sp_transitions(self, x):
        """Keep this rank's contiguous chunk of the sequence (dim 1).

        The TP sub-layers end in ``RowParallelism``, which all-reduces to the full
        hidden size, so a local slice is enough here.
        """
        chunks = torch.chunk(x, chunks=self.world_size, dim=1)

        return chunks[self.rank].contiguous()

    def forward(self, x):
        """Apply the attention and MLP sub-layers with residual connections to this rank's sequence slice."""
        residual = x
        out = self.norm1(x)
        out = self.sp_to_tp_transition(out)
        out = self.attention_layer(out)
        out = self.tp_to_sp_transitions(out)
        out = residual + self.dropout1(out)

        residual = out
        out = self.norm2(out)
        out = self.sp_to_tp_transition(out)
        # The second sub-layer is the feed-forward network (this used to call attention again).
        out = self.mlp(out)
        out = self.tp_to_sp_transitions(out)
        out = residual + self.dropout2(out)

        return out

#### Ring Attention

In [8]:
class RingAttention(nn.Module):
    """Non-causal multi-head self-attention over a sequence sharded across ranks.

    Each rank holds ``seq_len_local`` tokens. Queries stay local, while key/value
    blocks travel around the ring: rank r sends to r + 1 and receives from r - 1.
    After ``world_size`` steps, every local query has attended to every key.
    Blocks are merged with the online-softmax recurrence. Only the fused QKV
    projection is applied; there is no output projection.

    NOTE(review): ``isend``/``irecv`` are not tracked by autograd, so gradients do
    not flow through the received K/V blocks.

    Args:
        n_heads: number of attention heads (all computed on every rank).
        d_model: model width.
        rank: this process's rank.
        world_size: number of ranks in the ring.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, n_heads, d_model, rank, world_size):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")

        self.n_heads = n_heads
        self.d_model = d_model
        self.rank = rank
        self.world_size = world_size
        self.head_dim = d_model // n_heads

        self.qkv_proj = nn.Linear(d_model, d_model * 3)

    def forward(self, x):
        """Attend local queries to the keys/values of all ranks.

        Args:
            x: this rank's tokens, shape ``(batch, seq_len_local, d_model)``.

        Returns:
            Tensor of shape ``(batch, seq_len_local, d_model)``.
        """
        batch_size, seq_len_local, d_model = x.shape

        qkv = self.qkv_proj(x).reshape(batch_size, seq_len_local, 3, self.n_heads, self.head_dim)
        # -> (3, batch, heads, seq_len_local, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        stats_shape = (batch_size, self.n_heads, seq_len_local, 1)
        # Floating-point state; torch.full(..., 0) would have produced int64 tensors.
        m = torch.full(stats_shape, -torch.inf, device=x.device, dtype=q.dtype)
        d = torch.zeros(stats_shape, device=x.device, dtype=q.dtype)
        o = torch.zeros_like(q)

        k_recv, v_recv = k, v
        # world_size blocks in total; exchange only between steps so the last
        # received block is also used.
        for step in range(self.world_size):
            # Score against the block currently held, not the local ``k``.
            scores = q @ k_recv.transpose(-2, -1) / (self.head_dim ** 0.5)

            m, d, o = self._online_softmax(m, d, o, scores, v_recv)

            if step < self.world_size - 1:
                k_recv, v_recv = self.ring_exchange(k_recv, v_recv)

        return o.transpose(1, 2).contiguous().view(batch_size, seq_len_local, d_model)

    def _online_softmax(self, m_prev, d_prev, o_prev, scores, v):
        """Merge one block of scores/values into the running softmax statistics.

        Args:
            m_prev: running max per query row, shape ``(..., q, 1)``.
            d_prev: running sum of ``exp(score - m_prev)``, shape ``(..., q, 1)``.
            o_prev: running output, already normalised by ``d_prev``.
            scores: new block of scores, shape ``(..., q, k_blk)``.
            v: matching values, shape ``(..., k_blk, head_dim)``.

        Returns:
            ``(m_new, d_new, o_new)`` covering all blocks seen so far.
        """
        m_new = torch.maximum(m_prev, scores.amax(dim=-1, keepdim=True))

        p = torch.exp(scores - m_new)
        correction = torch.exp(m_prev - m_new)  # 0 on the first block (m_prev = -inf)

        d_new = d_prev * correction + p.sum(dim=-1, keepdim=True)

        o_new = (o_prev * d_prev * correction + p @ v) / d_new

        return m_new, d_new, o_new

    def ring_exchange(self, k, v):
        """Send ``(k, v)`` to rank + 1 and receive the same-shaped blocks from rank - 1.

        All point-to-point ops go through one ``dist.batch_isend_irecv`` call.
        Unbatched ``irecv``/``isend`` calls around a ring can deadlock with NCCL.
        """
        recv_rank = (self.rank - 1) % self.world_size
        send_rank = (self.rank + 1) % self.world_size

        # Point-to-point ops require contiguous buffers.
        k = k.contiguous()
        v = v.contiguous()

        k_recv = torch.empty_like(k)
        v_recv = torch.empty_like(v)

        ops = [
            dist.P2POp(dist.irecv, k_recv, recv_rank),
            dist.P2POp(dist.irecv, v_recv, recv_rank),
            dist.P2POp(dist.isend, k, send_rank),
            dist.P2POp(dist.isend, v, send_rank),
        ]
        for request in dist.batch_isend_irecv(ops):
            request.wait()

        return k_recv, v_recv