In [4]:
import torch
from torch import nn, Tensor
from torch.autograd import Function

In [15]:
class LlamaMLP(nn.Module):
    """LLaMA-style gated feed-forward block.

    Output is ``C(silu(B(x)) * A(x))``:
      * ``A`` and ``B`` map ``hidden_size -> intermediate_size``;
      * ``C`` maps ``intermediate_size -> hidden_size``;
      * none of the three projections has a bias.

    Args:
        hidden_size: width of the input and output features.
        intermediate_size: width of the inner (expanded) representation.
    """

    def __init__(self, hidden_size: int = 64, intermediate_size: int = 256):
        super().__init__()
        # Registration order (A, B, C, act) is kept fixed so parameter order
        # and RNG consumption at initialisation stay stable.
        self.A = nn.Linear(hidden_size, intermediate_size, bias=False)  # un-activated branch
        self.B = nn.Linear(hidden_size, intermediate_size, bias=False)  # branch fed through SiLU
        self.C = nn.Linear(intermediate_size, hidden_size, bias=False)  # projection back to hidden_size
        self.act = nn.SiLU()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated MLP to ``x`` (last dim must equal ``hidden_size``)."""
        gated = self.act(self.B(x))
        linear = self.A(x)
        return self.C(gated * linear)

In [16]:
class LlamaMLPFunction(Function):
    """Gated-SiLU MLP with a hand-written backward pass.

    Computes ``(silu(x @ W_B.T) * (x @ W_A.T)) @ W_C.T``, the same expression
    as ``LlamaMLP.forward`` but with the weights passed in explicitly.

    Expected shapes (``nn.Linear.weight`` layout):
        x:   (..., hidden)          -- any number of leading dims
        W_A: (intermediate, hidden)
        W_B: (intermediate, hidden)
        W_C: (hidden, intermediate)
    Returns a tensor of shape (..., hidden).
    """

    @staticmethod
    def _as_matrix(t):
        """Collapse all leading dims so ``t`` becomes (rows, last_dim)."""
        return t.reshape(-1, t.shape[-1])

    @staticmethod
    def forward(ctx, x, W_A, W_B, W_C):
        a = x @ W_A.T               # un-activated branch
        b = x @ W_B.T               # pre-activation of the SiLU branch
        c = nn.functional.silu(b)
        d = a * c
        e = d @ W_C.T
        # Save only what backward needs: c and d are cheap to recompute from
        # a and b, and the output e does not appear in any gradient.
        ctx.save_for_backward(x, W_A, W_B, W_C, a, b)
        return e

    @staticmethod
    def backward(ctx, grad_output):
        x, W_A, W_B, W_C, a, b = ctx.saved_tensors
        need_x, need_A, need_B, need_C = ctx.needs_input_grad
        flat = LlamaMLPFunction._as_matrix

        # Autograd requires exactly one entry per forward input (None = no grad).
        grad_x = grad_W_A = grad_W_B = grad_W_C = None
        c = nn.functional.silu(b)

        # e = d @ W_C.T  =>  dL/dW_C = grad_e^T @ d, summed over every leading row.
        if need_C:
            grad_W_C = flat(grad_output).T @ flat(a * c)

        if need_x or need_A or need_B:
            # e = d @ W_C.T  =>  dL/dd = grad_e @ W_C
            grad_d = grad_output @ W_C
            # d = a * c
            grad_a = grad_d * c
            grad_c = grad_d * a
            # c = silu(b) = b * sigmoid(b)
            # silu'(b) = sigmoid(b) * (1 + b * (1 - sigmoid(b)))
            sig = torch.sigmoid(b)
            grad_b = grad_c * sig * (1 + b * (1 - sig))

            # a = x @ W_A.T and b = x @ W_B.T both depend on x.
            if need_x:
                grad_x = grad_a @ W_A + grad_b @ W_B
            if need_A or need_B:
                x_mat = flat(x)
                if need_A:
                    grad_W_A = flat(grad_a).T @ x_mat
                if need_B:
                    grad_W_B = flat(grad_b).T @ x_mat

        return grad_x, grad_W_A, grad_W_B, grad_W_C


In [None]:
# Fixed seed so the random weights and inputs (and therefore this check) are
# reproducible on Restart & Run All.
torch.manual_seed(0)

mlp = LlamaMLP()
x = torch.rand(128, 64)

# Forward-pass parity: the hand-written Function must match the nn.Module
# when given the module's own weights.
torch.allclose(mlp(x), LlamaMLPFunction.apply(x, mlp.A.weight, mlp.B.weight, mlp.C.weight))

True