In [4]:
import math

import torch
from torch import nn, Tensor
from torch.autograd import Function

In [9]:
class LlamaMLP(nn.Module):
    """Reference gated MLP assembled from standard ``nn`` modules.

    Computes ``C(act(B(x)) * A(x))``:

    * ``A`` and ``B`` are bias-free projections ``hidden_size -> intermediate_size``;
    * ``B``'s output goes through ``act`` (GELU) and gates ``A``'s output elementwise;
    * ``C`` is a bias-free projection back ``intermediate_size -> hidden_size``.

    NOTE(review): the activation is GELU. If this is meant to mirror the
    published Llama MLP (which uses SiLU), confirm the intended choice.
    """

    def __init__(self, hidden_size: int = 64, intermediate_size: int = 256):
        super().__init__()
        up_proj = dict(in_features=hidden_size, out_features=intermediate_size, bias=False)
        # Registered in the order A, B, C, act so parameter order is A, B, C.
        self.A = nn.Linear(**up_proj)
        self.B = nn.Linear(**up_proj)
        self.C = nn.Linear(in_features=intermediate_size, out_features=hidden_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated MLP to ``x`` (last dim must equal ``hidden_size``)."""
        gate = self.act(self.B(x))
        value = self.A(x)
        return self.C(gate * value)

In [10]:
class LlamaMLPFunction(Function):
    """Hand-written autograd Function for the gated MLP ``LlamaMLP`` computes.

    Forward::

        a = x @ W_A.T
        b = x @ W_B.T
        c = gelu(b)          # exact (erf) GELU, torch default approximate="none"
        d = a * c
        e = d @ W_C.T

    Weights follow the ``nn.Linear`` layout ``(out_features, in_features)``:
    ``W_A``/``W_B`` are ``(intermediate, hidden)`` and ``W_C`` is
    ``(hidden, intermediate)``. ``x`` may have any number of leading batch dims.
    """

    @staticmethod
    def forward(ctx, x, W_A, W_B, W_C):
        a = x @ W_A.T
        b = x @ W_B.T
        c = nn.functional.gelu(b)
        d = a * c
        e = d @ W_C.T
        # `e` itself is not needed for the gradients, so it is not saved.
        ctx.save_for_backward(x, W_A, W_B, W_C, a, b, c, d)
        return e

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Back-propagate ``dL/de`` to ``x``, ``W_A``, ``W_B`` and ``W_C``.

        ``forward`` has a single output, so ``grad_outputs`` holds one tensor
        shaped like ``e``. Returns one gradient per ``forward`` input, in order;
        inputs that do not require grad get ``None``.
        """
        x, W_A, W_B, W_C, a, b, c, d = ctx.saved_tensors
        (grad_e,) = grad_outputs
        needs_x, needs_A, needs_B, needs_C = ctx.needs_input_grad

        def rows(t):
            # Collapse all leading (batch) dims so each weight grad is one matmul.
            return t.reshape(-1, t.shape[-1])

        grad_x = grad_W_A = grad_W_B = grad_W_C = None

        # e = d @ W_C.T
        grad_d = grad_e @ W_C
        if needs_C:
            grad_W_C = rows(grad_e).T @ rows(d)

        # d = a * c
        grad_a = grad_d * c
        grad_c = grad_d * a

        # c = gelu(b) = b * Phi(b)  =>  gelu'(b) = Phi(b) + b * phi(b)
        cdf = 0.5 * (1.0 + torch.erf(b / math.sqrt(2.0)))
        pdf = torch.exp(-0.5 * b * b) / math.sqrt(2.0 * math.pi)
        grad_b = grad_c * (cdf + b * pdf)

        # a = x @ W_A.T and b = x @ W_B.T (x feeds both branches, so its grads add up)
        if needs_A:
            grad_W_A = rows(grad_a).T @ rows(x)
        if needs_B:
            grad_W_B = rows(grad_b).T @ rows(x)
        if needs_x:
            grad_x = grad_a @ W_A + grad_b @ W_B

        return grad_x, grad_W_A, grad_W_B, grad_W_C

class LlamaMLPCustom(nn.Module):
    """Gated MLP whose forward pass is routed through ``LlamaMLPFunction``.

    Holds the same bias-free projections ``A``, ``B``, ``C`` (same names, same
    shapes) as ``LlamaMLP``, so weights can be transferred one-to-one between
    the two modules.
    """

    def __init__(self, hidden_size: int = 64, intermediate_size: int = 256):
        super().__init__()
        up_shape = (hidden_size, intermediate_size)
        down_shape = (intermediate_size, hidden_size)
        # Registration order A, B, C, act matches LlamaMLP.
        self.A = nn.Linear(*up_shape, bias=False)
        self.B = nn.Linear(*up_shape, bias=False)
        self.C = nn.Linear(*down_shape, bias=False)
        # Not used by forward() (the Function applies GELU itself); kept so the
        # module exposes the same attributes as LlamaMLP.
        self.act = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated MLP by passing the raw weight tensors to the Function."""
        weights = (self.A.weight, self.B.weight, self.C.weight)
        return LlamaMLPFunction.apply(x, *weights)


In [12]:
# Seed so the random weights and input are identical on every Restart & Run All.
torch.manual_seed(0)

mlp = LlamaMLP()
mlp_custom = LlamaMLPCustom()

# Both modules register the same parameter names (A.weight, B.weight, C.weight),
# so the reference weights load directly instead of being written through `.data`.
mlp_custom.load_state_dict(mlp.state_dict())

x = torch.rand(128, 64)
torch.allclose(mlp(x), mlp_custom(x))

True