In [14]:
import torch
import torch.nn as nn
import torch.distributed as dist

In [15]:
class ColumnParallelism(nn.Module):
    """Column-parallel linear layer (plain autograd, no communication).

    Each rank owns ``out_features // world_size`` output columns. The weight shard
    has shape ``(in_features, out_features_per_rank)``, so ``forward`` returns only
    this rank's slice of the full output. No collective is issued here.

    NOTE(review): this class is redefined in a later cell (backed by a custom
    autograd Function); the later definition shadows this one once that cell runs.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()

        shard_width = out_features // world_size

        self.in_features, self.out_features = in_features, out_features
        self.rank, self.world_size = rank, world_size
        self.out_features_per_rank = shard_width

        # Weight before bias: keeps parameter registration and RNG draw order stable.
        # NOTE(review): unit-variance randn init; no fan-in scaling is applied.
        self.weight = nn.Parameter(torch.randn(in_features, shard_width))
        self.bias = nn.Parameter(torch.zeros(shard_width))

    def forward(self, x):
        """Return ``x @ weight + bias`` for this rank's output columns."""
        return x @ self.weight + self.bias

In [16]:
class RowParallelism(nn.Module):
    """Row-parallel linear layer (plain autograd).

    Each rank owns ``in_features // world_size`` rows of the weight and expects the
    matching slice of the input features. Partial products are summed across ranks
    with an all-reduce, and the bias (held in full, ``out_features`` wide, on every
    rank) is added exactly once, after the reduction.

    NOTE(review): this class is redefined in a later cell (backed by a custom
    autograd Function); the later definition shadows this one once that cell runs.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.world_size = world_size

        self.in_features_per_rank = in_features // world_size

        self.weight = nn.Parameter(
            torch.randn(self.in_features_per_rank, out_features)
        )

        self.bias = nn.Parameter(
            torch.zeros(self.out_features)
        )

    def forward(self, x):
        """Return ``all_reduce(x_shard @ weight_shard) + bias``.

        Args:
            x: this rank's slice of the input; last dim must be ``in_features_per_rank``.
        """
        # Partial product over this rank's slice of the input features.
        out = torch.matmul(x, self.weight)
        # A single rank already holds the full sum; skipping the collective also
        # avoids requiring an initialized process group in that case.
        if self.world_size > 1:
            dist.all_reduce(out, op=dist.ReduceOp.SUM)
        # Bias goes after the reduction; adding it before would count it world_size times.
        return out + self.bias

In [25]:
class TensorParallelAttention(nn.Module):
    """Multi-head self-attention with heads sharded across tensor-parallel ranks.

    The fused QKV projection is column-parallel, so each rank produces Q/K/V for
    ``n_heads // world_size`` heads. The output projection is row-parallel: each
    rank feeds its own heads' outputs and the partial results are summed across ranks.

    Args:
        n_heads: total number of attention heads across all ranks.
        d_model: model width; must be divisible by ``n_heads``.
        rank: this process's tensor-parallel rank.
        world_size: number of tensor-parallel ranks; must divide ``n_heads``.

    Raises:
        ValueError: if the head/width split is not even.
    """

    def __init__(self, n_heads, d_model, rank, world_size):
        super().__init__()

        # Fail fast: an uneven split would otherwise surface later as an opaque
        # reshape error inside forward().
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        if n_heads % world_size != 0:
            raise ValueError(f"n_heads ({n_heads}) must be divisible by world_size ({world_size})")

        self.n_heads = n_heads
        self.d_model = d_model
        self.rank = rank
        self.world_size = world_size

        self.n_heads_per_rank = n_heads // world_size
        self.head_dim = d_model // n_heads

        self.qkv_proj = ColumnParallelism(
            in_features=d_model,
            out_features=3 * d_model,
            rank=rank,
            world_size=world_size
        )

        self.out_proj = RowParallelism(
            in_features=d_model,
            out_features=d_model,
            rank=rank,
            world_size=world_size
        )

    def forward(self, x):
        """Attend over the local heads and return the cross-rank-reduced projection.

        Args:
            x: input of shape (batch, seq_len, d_model); the unpacking below requires rank 3.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        # (B, S, 3 * local_width) -> (3, B, local_heads, S, head_dim)
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.n_heads_per_rank, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention over this rank's heads; no mask is applied.
        scores = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
        attn_scores = nn.functional.softmax(scores, dim=-1)
        attn_output = attn_scores @ v

        # (B, local_heads, S, head_dim) -> (B, S, local_heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_heads_per_rank * self.head_dim)
        output = self.out_proj(attn_output)

        return output
        

In [18]:
class TensorParallelMLP(nn.Module):
    """Tensor-parallel feed-forward block: column-parallel up-projection, GELU,
    row-parallel down-projection.

    The hidden activation stays sharded (each rank holds ``d_ff // world_size``
    columns), so in the forward pass the only communication happens inside ``fc2``.
    """

    def __init__(self, d_model, d_ff, rank, world_size):
        super().__init__()

        self.d_model, self.d_ff = d_model, d_ff
        self.rank, self.world_size = rank, world_size

        shard_kwargs = dict(rank=rank, world_size=world_size)
        # d_model -> this rank's slice of d_ff
        self.fc1 = ColumnParallelism(in_features=d_model, out_features=d_ff, **shard_kwargs)
        # this rank's slice of d_ff -> d_model, summed across ranks
        self.fc2 = RowParallelism(in_features=d_ff, out_features=d_model, **shard_kwargs)
        self.activation = nn.GELU()

    def forward(self, x):
        """Return ``fc2(GELU(fc1(x)))``."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)

In [19]:
class TensorParallelismBlock(nn.Module):
    """Pre-norm transformer block built from tensor-parallel attention and MLP.

    ``forward`` computes ``x + dropout(attn(norm1(x)))`` and then
    ``+ dropout(mlp(norm2(.)))``.

    Args:
        n_heads: total attention heads across all ranks.
        d_model: model width.
        d_ff: MLP hidden width (total across ranks).
        rank: this process's tensor-parallel rank.
        world_size: number of tensor-parallel ranks.
        dropout: dropout probability on both residual branches (default 0.1).
    """

    def __init__(self, n_heads, d_model, d_ff, rank, world_size, dropout=0.1):
        super().__init__()

        self.n_heads = n_heads
        self.d_model = d_model
        self.d_ff = d_ff
        self.rank = rank
        self.world_size = world_size

        # TensorParallelAttention takes no d_ff argument; passing it raised TypeError.
        self.attention_layer = TensorParallelAttention(
            n_heads=n_heads,
            d_model=d_model,
            rank=rank,
            world_size=world_size
        )

        self.mlp = TensorParallelMLP(
            d_model=d_model,
            d_ff=d_ff,
            rank=rank,
            world_size=world_size
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # NOTE(review): the residual stream is replicated across ranks; unless every
        # rank shares an RNG state, dropout masks will differ and ranks will diverge.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the attention sub-layer, then the MLP sub-layer, each with a residual."""
        out = x

        out = out + self.dropout1(self.attention_layer(self.norm1(out)))
        # Second residual branch is the MLP (previously attention was applied twice).
        out = out + self.dropout2(self.mlp(self.norm2(out)))

        return out

In [20]:
from torch.autograd import Function

In [21]:
class ColumnParallelismFunction(Function):
    """Autograd op for a column-parallel linear layer: ``y = x @ W_shard (+ b_shard)``.

    Forward is purely local. Every rank holds the full input and produces its own
    slice of the output columns. In backward, each rank's input gradient only
    accounts for its own output columns, so the partial input gradients are summed
    across ranks.

    Call via ``ColumnParallelismFunction.apply(x, weight, bias, rank, world_size)``;
    ``bias`` may be None.
    """

    @staticmethod
    def forward(ctx, input_tensor, weight, bias, rank, world_size):
        # save_for_backward must be *called*; assigning to an attribute saves nothing
        # and leaves ctx.saved_tensors empty in backward.
        ctx.save_for_backward(input_tensor, weight)
        ctx.rank = rank
        ctx.world_size = world_size
        ctx.use_bias = bias is not None

        out = torch.matmul(input_tensor, weight)
        if bias is not None:
            out = out + bias

        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input_tensor, weight, bias, rank, world_size)."""
        input_tensor, weight = ctx.saved_tensors

        # Flatten all leading dims so both the weight and bias gradients reduce over
        # a single row axis; this works uniformly for 1-D and N-D inputs.
        flat_input = input_tensor.reshape(-1, input_tensor.shape[-1])
        flat_grad = grad_output.reshape(-1, grad_output.shape[-1])

        grad_weight = flat_input.t() @ flat_grad
        # No gradient may be returned for an input that was None.
        grad_bias = flat_grad.sum(dim=0) if ctx.use_bias else None
        grad_input = torch.matmul(grad_output, weight.t())

        # Each rank only saw its own output columns, so sum the partial input grads.
        # A single rank already has the full sum (and may have no process group).
        if ctx.world_size > 1:
            dist.all_reduce(grad_input, op=dist.ReduceOp.SUM)

        return grad_input, grad_weight, grad_bias, None, None

In [22]:
class ColumnParallelism(nn.Module):
    """Column-parallel linear layer backed by ``ColumnParallelismFunction``.

    Redefines the ``ColumnParallelism`` class from an earlier cell. Modules that
    construct ``ColumnParallelism`` inside ``__init__`` resolve the name when they
    are constructed, so instances built after this cell runs use this version.

    Each rank owns ``out_features // world_size`` output columns.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.world_size = world_size

        self.out_features_per_rank = out_features // world_size

        self.weight = nn.Parameter(
            torch.randn(in_features, self.out_features_per_rank)
        )

        self.bias = nn.Parameter(
            torch.zeros(self.out_features_per_rank)
        )

    def forward(self, x):
        """Return this rank's output columns; input grads are reduced in backward."""
        # Correct Function name, and pass world_size (not the weight) as the last argument.
        return ColumnParallelismFunction.apply(x, self.weight, self.bias, self.rank, self.world_size)

In [23]:
class RowParallelismFunction(Function):
    """Autograd op for a row-parallel linear layer: ``y = all_reduce(x_shard @ W_shard) + b``.

    Each rank multiplies its slice of the input features by its rows of the weight.
    The partial products are summed across ranks, and the replicated bias is added
    once after the sum. In backward, the input and weight gradients are local to
    the shard, so no collective is needed.

    Call via ``RowParallelismFunction.apply(x, weight, bias, rank, world_size)``;
    ``bias`` may be None.
    """

    @staticmethod
    def forward(ctx, input_tensor, weight, bias, rank, world_size):
        # save_for_backward must be *called*; assigning to it overwrites the method
        # and leaves ctx.saved_tensors empty in backward.
        ctx.save_for_backward(input_tensor, weight)
        ctx.rank = rank
        ctx.world_size = world_size
        ctx.use_bias = bias is not None

        # Partial product over this rank's slice of the input features.
        output = torch.matmul(input_tensor, weight)
        # A single rank already holds the full sum (and may have no process group).
        if world_size > 1:
            dist.all_reduce(output, op=dist.ReduceOp.SUM)
        # Add bias after the reduction; adding it before would count it world_size times.
        if bias is not None:
            output = output + bias

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input_tensor, weight, bias, rank, world_size)."""
        input_tensor, weight = ctx.saved_tensors

        # Flatten leading dims so the reductions below work for 1-D and N-D inputs.
        flat_input = input_tensor.reshape(-1, input_tensor.shape[-1])
        flat_grad = grad_output.reshape(-1, grad_output.shape[-1])

        grad_weight = flat_input.t() @ flat_grad

        # Bias is added once to the reduced output, so its gradient is just the local
        # sum. NOTE(review): assumes downstream compute is replicated, so grad_output
        # matches across ranks and no reduction is needed here.
        grad_bias = flat_grad.sum(dim=0) if ctx.use_bias else None

        # Gradient for this rank's input shard only; nothing to reduce.
        grad_input = torch.matmul(grad_output, weight.t())

        return grad_input, grad_weight, grad_bias, None, None

In [24]:
class RowParallelism(nn.Module):
    """Row-parallel linear layer backed by ``RowParallelismFunction``.

    Redefines the ``RowParallelism`` class from an earlier cell. Modules that
    construct ``RowParallelism`` inside ``__init__`` resolve the name when they
    are constructed, so instances built after this cell runs use this version.

    Each rank owns ``in_features // world_size`` weight rows; the bias is full
    width (``out_features``) on every rank.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.world_size = world_size

        self.in_features_per_rank = in_features // world_size

        self.weight = nn.Parameter(
            torch.randn(self.in_features_per_rank, out_features)
        )

        self.bias = nn.Parameter(
            torch.zeros(self.out_features)
        )

    def forward(self, x):
        """Return the cross-rank-reduced output for this rank's input slice ``x``."""
        # Autograd Functions must be invoked via .apply; calling the class directly
        # constructs a Function object instead of running forward.
        return RowParallelismFunction.apply(x, self.weight, self.bias, self.rank, self.world_size)