In [1]:
import torch
from torch import nn, Tensor
from torch.autograd import Function

In [2]:
def mlp_fwd(x, W1, W2, W3):
    """Reference gated-MLP forward: ``(silu(x @ W2.T) * (x @ W1.T)) @ W3.T``.

    ``W2`` feeds the SiLU-activated branch, ``W1`` the plain linear branch, and
    ``W3`` projects their element-wise product back out. Every weight is
    applied transposed, i.e. it is laid out as ``(out_features, in_features)``.
    """
    gate = torch.nn.functional.silu(x @ W2.T)
    up = x @ W1.T
    hidden = gate * up
    return hidden @ W3.T

class LlamaMLP(nn.Module):
    """Bias-free gated MLP whose forward pass is delegated to ``mlp_fwd``."""

    def __init__(self, hidden_size: int = 64, intermediate_size: int = 256):
        """Create the three bias-free projections.

        Args:
            hidden_size: Size of the last dimension of the input and output.
            intermediate_size: Width of the two inner projections.
        """
        super().__init__()
        # The layers are created in the order W1, W2, W3 so that their random
        # initialisation draws from the RNG in that sequence.
        self.W1 = nn.Linear(in_features=hidden_size, out_features=intermediate_size, bias=False)
        self.W2 = nn.Linear(in_features=hidden_size, out_features=intermediate_size, bias=False)
        self.W3 = nn.Linear(in_features=intermediate_size, out_features=hidden_size, bias=False)
        # Not referenced by forward(); the activation is applied inside mlp_fwd.
        self.act = nn.SiLU()

    def forward(self, x: Tensor) -> Tensor:
        """Run ``mlp_fwd`` on ``x`` with this module's three weight matrices."""
        w1, w2, w3 = self.W1.weight, self.W2.weight, self.W3.weight
        return mlp_fwd(x, w1, w2, w3)

In [None]:
class LlamaMLPFunction(Function):
    """Custom autograd function for ``(silu(x @ W2.T) * (x @ W1.T)) @ W3.T``.

    Only the four inputs are saved for the backward pass; the intermediate
    activations are recomputed there, trading extra matmuls for lower memory.
    """

    @staticmethod
    def forward(ctx, x, W1, W2, W3):
        """Compute the gated-MLP output.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            x: Input whose last dimension equals ``W1.shape[1]``; it may have
                any number of leading dimensions.
            W1: Weight of the linear branch, shape ``(intermediate, hidden)``.
            W2: Weight of the SiLU branch, same shape as ``W1``.
            W3: Output projection with ``W3.shape[1] == intermediate``.

        Returns:
            Tensor with the leading dimensions of ``x`` and a last dimension
            of ``W3.shape[0]``.
        """
        a = x @ W1.T
        b = x @ W2.T
        c = b * torch.sigmoid(b)  # SiLU(b)
        e = (a * c) @ W3.T
        ctx.save_for_backward(x, W1, W2, W3)
        return e

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients ``(dx, dW1, dW2, dW3)`` for the forward inputs.

        Args:
            ctx: Autograd context holding the tensors saved in ``forward``.
            grad_output: Gradient of the loss w.r.t. the forward output.

        Returns:
            Tuple of gradients, each with the shape of its matching input.
        """
        x, W1, W2, W3 = ctx.saved_tensors

        # Collapse all leading dimensions into one "row" axis. This makes the
        # weight gradients sum over every row (batch, sequence, ...) and lets
        # the same code handle 1-D, 2-D and N-D inputs.
        x2d = x.reshape(-1, x.shape[-1])
        g2d = grad_output.reshape(-1, grad_output.shape[-1])

        # Recompute the forward intermediates instead of having saved them.
        a = x2d @ W1.T
        b = x2d @ W2.T
        sigma = torch.sigmoid(b)
        c = b * sigma
        d = a * c

        # e = d @ W3.T  ->  dW3 = g^T d,  dL/dd = g W3
        dW3 = g2d.T @ d
        dL_dd = g2d @ W3

        # d = a * c with c = SiLU(b);  SiLU'(b) = sigma * (1 + b * (1 - sigma))
        act_prime = sigma * (1 + b * (1 - sigma))
        dL_da = dL_dd * c
        dL_db = (dL_dd * a) * act_prime

        # a = x @ W1.T and b = x @ W2.T
        dW1 = dL_da.T @ x2d
        dW2 = dL_db.T @ x2d
        dx = (dL_da @ W1 + dL_db @ W2).reshape(x.shape)

        return dx, dW1, dW2, dW3




In [71]:
# Forward check: the custom autograd Function must reproduce the module output.
mlp = LlamaMLP()
x = torch.rand(256, 64)

assert torch.allclose(
    mlp(x),
    LlamaMLPFunction.apply(x, *(layer.weight for layer in (mlp.W1, mlp.W2, mlp.W3))),
    atol=1e-7,
)

In [72]:
### Check backward pass
# Gradients from the hand-written backward (A) are compared against autograd's
# differentiation of the reference forward (B), using the same upstream gradient.
x = torch.rand(256, 64, requires_grad=True)

# Forward pass
out_a = LlamaMLPFunction.apply(x, mlp.W1.weight, mlp.W2.weight, mlp.W3.weight)
out_b = mlp_fwd(x, mlp.W1.weight, mlp.W2.weight, mlp.W3.weight)

grad_output = torch.randn_like(out_a)

# Backward pass A, then backward pass B (evaluated in that order)
grads_a, grads_b = (
    torch.autograd.grad(
        out,
        (x, mlp.W1.weight, mlp.W2.weight, mlp.W3.weight),
        grad_outputs=grad_output,
        retain_graph=True,
    )
    for out in (out_a, out_b)
)

# Report the largest element-wise deviation for each gradient and fail on mismatch.
for i, (grad_custom, grad_ref) in enumerate(zip(grads_a, grads_b)):
    max_diff = torch.max(torch.abs(grad_custom - grad_ref)).item()
    print(f"Gradient {i}: max diff = {max_diff:3.3}")
    assert torch.allclose(grad_custom, grad_ref, atol=1e-6), f"Mismatch in grad {i}"
    


Gradient 0: max diff = 5.96e-08
Gradient 1: max diff = 5.96e-07
Gradient 2: max diff = 7.15e-07
Gradient 3: max diff = 1.43e-06
