In [2]:
import torch
import torch.nn as nn
import torch.distributed as dist

In [3]:
class ColumnParallelism(nn.Module):
    """Column-parallel linear layer: each rank owns a slice of the output columns.

    Each rank holds an independently initialised weight shard of shape
    ``(input_size, output_size // world_size)`` (``torch.randn``, unscaled) and
    the matching bias shard. ``forward`` returns only this rank's slice of the
    output features.

    Backward relies on plain autograd. The gradient w.r.t. ``x`` is therefore
    only this rank's partial contribution, and no cross-rank reduction is
    performed here.

    Args:
        input_size: Feature size of the input (presumably replicated on every
            rank -- TODO confirm at the call site).
        output_size: Total number of output features across all ranks.
        world_size: Number of tensor-parallel ranks.
        rank: Index of this rank.

    Raises:
        ValueError: If ``output_size`` is not divisible by ``world_size``.
    """

    def __init__(self, input_size, output_size, world_size, rank):
        super().__init__()

        # Integer division below would otherwise silently drop the leftover columns.
        if output_size % world_size != 0:
            raise ValueError(
                f"output_size ({output_size}) must be divisible by "
                f"world_size ({world_size})"
            )

        self.input_size = input_size
        self.output_size = output_size
        self.world_size = world_size
        self.rank = rank

        self.output_size_per_rank = output_size // world_size

        self.weight = nn.Parameter(
            torch.randn(input_size, self.output_size_per_rank)
        )

        self.bias = nn.Parameter(
            torch.zeros(self.output_size_per_rank)
        )

    def forward(self, x):
        """Return ``x @ weight + bias``; the last dim is ``output_size // world_size``."""
        return torch.matmul(x, self.weight) + self.bias

In [3]:
class RowParallelism(nn.Module):
    """Row-parallel linear layer: each rank owns a slice of the input rows of W.

    Each rank holds a weight shard of shape ``(input_size // world_size,
    output_size)`` plus a full-size bias. ``forward`` performs these steps:
    multiply the local input slice by the shard, sum the partial products
    across ranks with ``dist.all_reduce``, then add the bias.

    Args:
        input_size: Total number of input features across all ranks.
        output_size: Output feature size (identical on every rank).
        world_size: Number of tensor-parallel ranks.
        rank: Index of this rank.

    Raises:
        ValueError: If ``input_size`` is not divisible by ``world_size``.
    """

    def __init__(self, input_size, output_size, world_size, rank):
        super().__init__()

        # Integer division below would otherwise silently drop the leftover rows.
        if input_size % world_size != 0:
            raise ValueError(
                f"input_size ({input_size}) must be divisible by "
                f"world_size ({world_size})"
            )

        self.input_size = input_size
        self.output_size = output_size
        self.world_size = world_size
        self.rank = rank

        self.input_size_per_rank = input_size // world_size

        self.weight = nn.Parameter(
            torch.randn(self.input_size_per_rank, output_size)
        )
        self.bias = nn.Parameter(torch.zeros(output_size))

    def forward(self, x):
        """Return the all-reduced ``x @ weight`` plus bias.

        ``x`` is assumed to be this rank's feature slice, with last dim
        ``input_size // world_size`` (TODO confirm at the call site).
        """
        output = torch.matmul(x, self.weight)

        # Sum partial products from every rank in place.
        # NOTE(review): dist.all_reduce is not an autograd op, so backward
        # presumably treats it as identity. That is the usual choice for row
        # parallelism when downstream compute is replicated; confirm for your
        # torch version.
        dist.all_reduce(output, op=dist.ReduceOp.SUM)

        # Add the bias on *every* rank after the reduction so all ranks hold the
        # same output. Adding it only on rank 0 made the ranks diverge.
        return output + self.bias

In [4]:
class TensorParallelismAttention(nn.Module):
    """Multi-head self-attention with the heads sharded across tensor-parallel ranks.

    Each rank works in three steps:
    1. Compute Q/K/V for ``n_heads // world_size`` heads through a
       column-parallel projection.
    2. Run scaled dot-product attention locally for those heads.
    3. Feed the result to a row-parallel output projection, which sums the
       per-rank contributions.

    No attention mask is applied.

    Args:
        n_heads: Total number of attention heads across all ranks.
        d_model: Model width; must be divisible by ``n_heads``.
        world_size: Number of tensor-parallel ranks; must divide ``n_heads``.
        rank: Index of this rank.

    Raises:
        ValueError: If ``d_model % n_heads != 0`` or ``n_heads % world_size != 0``.
    """

    def __init__(self, n_heads, d_model, world_size, rank):
        super().__init__()

        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        if n_heads % world_size != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be divisible by world_size ({world_size})"
            )

        self.n_heads = n_heads
        self.d_model = d_model
        self.world_size = world_size
        self.rank = rank

        self.num_heads_per_rank = n_heads // world_size
        self.head_dim = d_model // n_heads

        # Each rank produces 3 * (num_heads_per_rank * head_dim) features for Q, K, V.
        self.qkv_proj = ColumnParallelism(
            d_model,
            d_model * 3,
            world_size,
            rank
        )

        # Named ``out_proj`` to match forward(); it was previously misspelled ``out_prof``.
        self.out_proj = RowParallelism(
            d_model,
            d_model,
            world_size,
            rank
        )

    def forward(self, x):
        """Self-attention over ``x`` of shape ``(batch, seq, d_model)``; returns the same shape."""
        batch_size, seq_len, _ = x.shape
        # Width of this rank's slice of the concatenated head outputs.
        local_dim = self.num_heads_per_rank * self.head_dim

        qkv = self.qkv_proj(x)  # (batch, seq, 3 * local_dim)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads_per_rank, self.head_dim)
        # -> (3, batch, heads, seq, head_dim), so attention mixes positions, not heads.
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)  # (batch, heads, seq, seq)
        attn_weights = nn.functional.softmax(scores, dim=-1)
        attn_output = attn_weights @ v  # (batch, heads, seq, head_dim)

        # Back to (batch, seq, local_dim). Only this rank's heads are present,
        # which is exactly the input width the row-parallel projection expects.
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, local_dim)

        return self.out_proj(attn_output)

In [6]:
class TensorParallelismMLP(nn.Module):
    """Feed-forward block split across tensor-parallel ranks.

    ``fc1`` is column-parallel: each rank produces ``d_ff // world_size`` hidden
    features. ``fc2`` is row-parallel: it consumes exactly that local slice and
    all-reduces its output. No communication is needed between the two layers.

    Args:
        d_model: Model (input and output) width.
        d_ff: Total hidden width across all ranks.
        world_size: Number of tensor-parallel ranks.
        rank: Index of this rank.
    """

    def __init__(self, d_model, d_ff, world_size, rank):
        super().__init__()

        self.d_model = d_model
        self.d_ff = d_ff
        self.world_size = world_size
        self.rank = rank

        self.fc1 = ColumnParallelism(
            d_model,
            d_ff,
            world_size,
            rank
        )

        self.fc2 = RowParallelism(
            d_ff,
            d_model,
            world_size,
            rank
        )

        # The class is ``nn.GELU``; ``nn.GeLU`` does not exist and raised AttributeError.
        # GELU is element-wise, so it can be applied to the local hidden shard directly.
        self.activation = nn.GELU()

    def forward(self, x):
        """Return ``fc2(GELU(fc1(x)))``, whose last dim is ``d_model``."""
        return self.fc2(self.activation(self.fc1(x)))


# Keep the original (misspelled) class name available for existing callers.
TensorParalleismMLP = TensorParallelismMLP

In [7]:
class TensorParallelismTransformerBlock(nn.Module):
    """Pre-norm transformer block built from the tensor-parallel sublayers.

    Computes ``h = x + attention(norm1(x))`` followed by ``h + mlp(norm2(h))``.
    The LayerNorms are ordinary (non-sharded) modules.

    Args:
        d_model: Model width.
        n_heads: Total number of attention heads.
        d_ff: Total MLP hidden width.
        world_size: Number of tensor-parallel ranks.
        rank: Index of this rank.
    """

    def __init__(self, d_model, n_heads, d_ff, world_size, rank):
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        # Keyword arguments are used because the attention constructor's order is
        # (n_heads, d_model, ...). The previous positional call swapped the two.
        self.attention = TensorParallelismAttention(
            n_heads=n_heads, d_model=d_model, world_size=world_size, rank=rank
        )

        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = TensorParallelismMLP(
            d_model=d_model, d_ff=d_ff, world_size=world_size, rank=rank
        )

    def forward(self, x):
        """Apply the attention and MLP residual sublayers to ``x``."""
        hidden = x + self.attention(self.norm1(x))
        return hidden + self.mlp(self.norm2(hidden))

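#### Backpropagation

Custom `torch.autograd.Function`s that spell out the forward and backward passes of the column- and row-parallel layers, including where cross-rank communication happens.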

In [11]:
from torch.autograd import Function

In [12]:
class ColumnParallelFunction(Function):
    """Autograd function for a column-parallel linear layer.

    Forward computes ``input @ weight (+ bias)`` for this rank's column shard,
    with no communication.

    Backward:
    - The weight and bias gradients are local to this rank.
    - The input gradient is summed across ranks with ``dist.all_reduce``,
      because every rank's shard contributes to the gradient of the
      (presumably replicated) input.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, world_size, rank):
        ctx.save_for_backward(input, weight)
        ctx.world_size = world_size
        ctx.rank = rank
        # Backward must return None for an absent bias; returning a tensor is an error.
        ctx.use_bias = bias is not None

        output = torch.matmul(input, weight)
        if bias is not None:
            output = output + bias

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # The attribute is ``saved_tensors`` (plural); ``saved_tensor`` raised AttributeError.
        input, weight = ctx.saved_tensors

        # Flatten all leading (batch/sequence) dims so the product is (in, out_per_rank).
        flat_grad = grad_output.reshape(-1, grad_output.shape[-1])
        grad_weight = torch.matmul(
            input.reshape(-1, input.shape[-1]).t(),
            flat_grad
        )

        grad_bias = None
        if ctx.use_bias:
            # Summing the flattened rows also handles 1-D inputs correctly.
            grad_bias = flat_grad.sum(dim=0)

        grad_input = torch.matmul(grad_output, weight.t())

        # Each rank only saw its own output columns; sum their input-gradient contributions.
        dist.all_reduce(grad_input, op=dist.ReduceOp.SUM)

        return grad_input, grad_weight, grad_bias, None, None

In [None]:
class ColumnParallelim(nn.Module):
    """Column-parallel linear layer backed by ``ColumnParallelFunction``.

    Uses the same parameter layout as ``ColumnParallelism``: a weight shard of
    shape ``(input_size, output_size // world_size)`` and a matching bias.
    Forward and backward go through the custom autograd function.

    NOTE(review): the class name looks like a typo for ``ColumnParallelism``.
    Because the name differs, code that instantiates ``ColumnParallelism``
    does not pick up this version. ``RowParallelism``, by contrast, is
    redefined under the same name.

    Args:
        input_size: Feature size of the input.
        output_size: Total number of output features across all ranks.
        world_size: Number of tensor-parallel ranks.
        rank: Index of this rank.

    Raises:
        ValueError: If ``output_size`` is not divisible by ``world_size``.
    """

    def __init__(self, input_size, output_size, world_size, rank):
        super().__init__()

        # Integer division below would otherwise silently drop the leftover columns.
        if output_size % world_size != 0:
            raise ValueError(
                f"output_size ({output_size}) must be divisible by "
                f"world_size ({world_size})"
            )

        self.input_size = input_size
        self.output_size = output_size
        self.world_size = world_size
        self.rank = rank

        self.output_size_per_rank = output_size // world_size

        self.weight = nn.Parameter(
            torch.randn(input_size, self.output_size_per_rank)
        )

        self.bias = nn.Parameter(
            torch.zeros(self.output_size_per_rank)
        )

    def forward(self, x):
        """Apply the column-parallel projection to ``x`` via ``ColumnParallelFunction``."""
        # The comma after ``x`` was missing (SyntaxError).
        return ColumnParallelFunction.apply(
            x, self.weight, self.bias, self.world_size, self.rank
        )

In [14]:
class RowParallelFunction(Function):
    """Autograd function for a row-parallel linear layer.

    Forward computes ``input @ weight`` on this rank's input slice, sums the
    partial products across ranks with ``dist.all_reduce``, then adds the bias
    on every rank.

    Backward returns purely local gradients for input, weight and bias; no
    communication happens in backward.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, world_size, rank):
        ctx.save_for_backward(input, weight)
        ctx.world_size = world_size
        ctx.rank = rank
        ctx.use_bias = bias is not None

        output = torch.matmul(input, weight)

        # Combine every rank's partial product into the full pre-bias output.
        dist.all_reduce(output, op=dist.ReduceOp.SUM)

        # Add the bias on every rank after the reduction so all ranks hold
        # identical outputs. Adding it only on rank 0 made the ranks diverge.
        if bias is not None:
            output = output + bias

        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors

        # Flatten leading (batch/sequence) dims so the product is (in_per_rank, out).
        flat_grad = grad_output.reshape(-1, grad_output.shape[-1])
        grad_weight = torch.matmul(
            input.reshape(-1, input.shape[-1]).t(),
            flat_grad
        )

        grad_bias = None
        if ctx.use_bias:
            # Assumes grad_output is identical on every rank, which holds when
            # everything downstream is replicated (TODO confirm). In that case
            # the local sum is already the full bias gradient, and all-reducing
            # it (as before) would scale it by world_size.
            grad_bias = flat_grad.sum(dim=0)

        # The input is sharded, so its gradient only needs this rank's weight slice.
        grad_input = torch.matmul(grad_output, weight.t())

        return grad_input, grad_weight, grad_bias, None, None

In [15]:
class RowParallelism(nn.Module):
    """Row-parallel linear layer backed by ``RowParallelFunction``.

    This cell redefines the earlier ``RowParallelism``. Layers constructed
    after it runs route forward/backward through the custom autograd function.
    Each rank holds a weight shard of shape ``(input_size // world_size,
    output_size)`` and a full-size bias.

    Args:
        input_size: Total number of input features across all ranks.
        output_size: Output feature size (identical on every rank).
        world_size: Number of tensor-parallel ranks.
        rank: Index of this rank.

    Raises:
        ValueError: If ``input_size`` is not divisible by ``world_size``.
    """

    def __init__(self, input_size, output_size, world_size, rank):
        super().__init__()

        # Integer division below would otherwise silently drop the leftover rows.
        if input_size % world_size != 0:
            raise ValueError(
                f"input_size ({input_size}) must be divisible by "
                f"world_size ({world_size})"
            )

        self.input_size = input_size
        self.output_size = output_size
        self.world_size = world_size
        self.rank = rank

        self.input_size_per_rank = input_size // world_size

        self.weight = nn.Parameter(
            torch.randn(self.input_size_per_rank, output_size)
        )
        self.bias = nn.Parameter(torch.zeros(output_size))

    def forward(self, x):
        """Apply the row-parallel projection to this rank's slice ``x``."""
        return RowParallelFunction.apply(
            x, self.weight, self.bias, self.world_size, self.rank
        )