In [33]:
import torch
from torch import nn, Tensor
from torch.autograd import Function

if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print(f'{DEVICE=}')

DEVICE=device(type='cuda')


In [15]:
def mlp_fwd(x, W1, W2, W3):
    return (torch.nn.functional.silu(x @ W2.T) * (x @ W1.T)) @ W3.T

class LlamaMLP(nn.Module):
    def __init__(self, hidden_size:int = 64, intermediate_size:int = 256):
        super().__init__()
        self.W1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.W2 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.W3 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, x:Tensor) -> Tensor:
        return mlp_fwd(x, self.W1.weight, self.W2.weight, self.W3.weight)

In [16]:
import torch
import os
from typing import Dict, Any
import psutil
from time import time

class CUDAMemTracker:
    def __init__(self, device=None):
        self.device = torch.device(device or "cuda")

    def __enter__(self):
        torch.cuda.synchronize(self.device)
        torch.cuda.reset_peak_memory_stats(self.device)
        self.start_alloc = torch.cuda.memory_allocated(self.device)
        self.start_reserved = torch.cuda.memory_reserved(self.device)
        self.start_rss = psutil.Process(os.getpid()).memory_info().rss if psutil else None
        return self

    def __exit__(self, exc_type, exc, tb):
        torch.cuda.synchronize(self.device)
        self.end_alloc = torch.cuda.memory_allocated(self.device)
        self.end_reserved = torch.cuda.memory_reserved(self.device)
        self.peak_alloc = torch.cuda.max_memory_allocated(self.device)
        self.peak_reserved = torch.cuda.max_memory_reserved(self.device)

    def report(self, unit=1024**2):
        to_mb = lambda b: None if b is None else b / unit
        return {
            "alloc_delta_MB": to_mb(self.end_alloc - self.start_alloc),
            "reserved_delta_MB": to_mb(self.end_reserved - self.start_reserved),
            "peak_alloc_MB": to_mb(self.peak_alloc),
            "peak_reserved_MB": to_mb(self.peak_reserved),
            "cpu_rss_MB": to_mb(self.start_rss) if self.start_rss is not None else None,
        }



def profile_torch_module_forward(module: torch.nn.Module, inputs: Dict[str, Any]):
    t0 = time()
    with CUDAMemTracker() as t:
        module(**inputs)
    t1 = time() - t0

    report = t.report()
    print(f'Elapsed time:       {t1:3.3}s')
    print(f'Peak allocated mem: {report['peak_alloc_MB']:3.3}MB')
    print(f'Peak reserved mem:  {report['peak_alloc_MB']:3.3}MB')


def profile_torch_module_backward(module: torch.nn.Module, inputs: Dict[str, Any]):
    t0 = time()
    with CUDAMemTracker() as t:
        out = module(**inputs)
        loss = out.sum()
        loss.backward()
    t1 = time() - t0

    report = t.report()
    print(f'Elapsed time:       {t1:3.3}s')
    print(f'Peak allocated mem: {report['peak_alloc_MB']:3.3}MB')
    print(f'Peak reserved mem:  {report['peak_alloc_MB']:3.3}MB')


In [17]:
class LlamaMLPFunction(Function):
    @staticmethod
    def forward(ctx, x, W1, W2, W3):
        a = x @ W1.T
        b = x @ W2.T
        sigma = 1 / (1 + torch.exp(-b))
        c = b * sigma
        d = a * c
        e = d @ W3.T
        # ctx.save_for_backward(x, W1, W2, W3, a, b, c, d, sigma)
        ctx.save_for_backward(x, W1, W2, W3)
        return e

    @staticmethod
    def backward(ctx, grad_output):
        x, W1, W2, W3 = ctx.saved_tensors
        B, S, D = x.shape
        D_in = W3.shape[1]

        # x = x.view(B, D*S)

        a = x @ W1.T
        b = x @ W2.T
        sigma = 1 / (1 + torch.exp(-b))
        c = b * sigma
        d = a * c

        # dW3
        dW3 =  grad_output.permute(0, 2, 1) @ d
        # dW3 =  grad_output.T @ d

        # dW2
        act_prime = sigma * (1 + b * (1 - sigma))
        dL_db = (((grad_output @ W3) * a) * act_prime)
        dW2 = dL_db.permute(0, 2, 1) @ x

        # dW1
        dL_da = ((grad_output @ W3) * c)
        dW1 = dL_da.permute(0, 2, 1) @ x

        # dx
        dx = dL_da @ W1 + dL_db @ W2

        return dx, dW1, dW2, dW3




In [18]:
# class LlamaMLPFunction(Function):
#     @staticmethod
#     def forward(ctx, x, W1, W2, W3):
#         a = x @ W1.T
#         b = x @ W2.T
#         sigma = 1 / (1 + torch.exp(-b))
#         c = b * sigma
#         d = a * c
#         e = d @ W3.T
#         ctx.save_for_backward(x, W1, W2, W3, a, b, c, d, sigma)
#         return e

#     @staticmethod
#     def backward(ctx, grad_output):
#         x, W1, W2, W3, a, b, c, d, sigma = ctx.saved_tensors

#         # dW3
#         dW3 =  grad_output.T @ d

#         # dW2
#         act_prime = sigma * (1 + b * (1 - sigma))
#         dL_db = (((grad_output @ W3) * a) * act_prime)
#         dW2 = dL_db.T @ x

#         # dW1
#         dL_da = ((grad_output @ W3) * c)
#         dW1 = dL_da.T @ x

#         # dx
#         dx = dL_da @ W1 + dL_db @ W2

#         return dx, dW1, dW2, dW3



In [19]:
class LlamaMLPFunction_1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, W1, W2, W3):
        u = x @ W2.T                      # (B,L,I)
        v = x @ W1.T                      # (B,L,I)
        h = torch.nn.functional.silu(u)   # (B,L,I)
        out = (h * v) @ W3.T              # (B,L,D)
        ctx.save_for_backward(x, W1, W2, W3, u, v, h)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, W1, W2, W3, u, v, h= ctx.saved_tensors


        # dW3
        # dW3 = torch.einsum("bld,bli->di", grad_output, h * v)
        dW3 = (grad_output.permute(0,2,1) @ (h * v)).sum(dim=0)

        # dL_dv, dW1
        dL_dv = (grad_output @ W3) * h
        # dW1 = torch.einsum("bli,bld->id", dL_dv, x)
        dW1 = (dL_dv.permute(0,2,1) @ x).sum(dim=0)

        # dL_du, dW2
        sigma = torch.sigmoid(u)
        act_prime = sigma * (1 + u * (1 - sigma))
        dL_du = (grad_output @ W3) * v * act_prime
        # dW2 = torch.einsum("bli,bld->id", dL_du, x)
        dW2 = (dL_du.permute(0,2,1) @ x).sum(dim=0)

        # dx
        dx = dL_dv @ W1 + dL_du @ W2

        return dx, dW1, dW2, dW3


In [66]:
class LlamaMLPFunction_2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, W1, W2, W3):
        with torch.no_grad():
            u = x @ W2.T                      # (B,L,I)
            v = x @ W1.T                      # (B,L,I)
            h = torch.nn.functional.silu(u)   # (B,L,I)
            out = (h * v) @ W3.T              # (B,L,D)
        ctx.save_for_backward(x, W1, W2, W3)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, W1, W2, W3= ctx.saved_tensors

        with torch.no_grad():

            # recompute activations (Save memory)
            u = x @ W2.T                      # (B,L,I)
            v = x @ W1.T                      # (B,L,I)
            h = torch.nn.functional.silu(u)   # (B,L,I)

            # dW3
            dW3 = torch.einsum("bld,bli->di", grad_output, h * v)

            # dL_dv, dW1
            dL_dv = (grad_output @ W3) * h
            dW1 = torch.einsum("bli,bld->id", dL_dv, x)

            # dL_du, dW2
            sigma = torch.sigmoid(u)
            act_prime = sigma * (1 + u * (1 - sigma))
            dL_du = (grad_output @ W3) * v * act_prime
            dW2 = torch.einsum("bli,bld->id", dL_du, x)

            # dx
            dx = dL_dv @ W1 + dL_du @ W2

        return dx, dW1, dW2, dW3


In [67]:
B, S, D = 4, 1000, 128
x = torch.rand(B, S, D, requires_grad=True).to(DEVICE)

mlp = LlamaMLP(hidden_size=D).to(DEVICE)

assert torch.allclose(
    mlp(x),
    LlamaMLPFunction_2.apply(x, mlp.W1.weight, mlp.W2.weight, mlp.W3.weight),
    atol=1e-7
)

In [68]:
### Check backward pass


# Forward pass
out_a = LlamaMLPFunction_2.apply(x, mlp.W1.weight, mlp.W2.weight, mlp.W3.weight)
out_b = mlp_fwd(x, mlp.W1.weight, mlp.W2.weight, mlp.W3.weight)

grad_output = torch.randn_like(out_a)

# Backward pass A
grads_a = torch.autograd.grad(out_a, (x, mlp.W1.weight, mlp.W2.weight, mlp.W3.weight), grad_outputs=grad_output, retain_graph=True)

# Backward pass B
grads_b = torch.autograd.grad(out_b, (x, mlp.W1.weight, mlp.W2.weight, mlp.W3.weight), grad_outputs=grad_output, retain_graph=True)

for i, (ga, gb) in enumerate(zip(grads_a, grads_b)):
    max_diff = (ga - gb).abs().max().item()
    print(f"Gradient {i}: max diff = {max_diff:3.3}")
    # assert torch.allclose(ga, gb, atol=1e-6), f"Mismatch in grad {i}"



Gradient 0: max diff = 5.96e-08
Gradient 1: max diff = 0.0
Gradient 2: max diff = 4.77e-06
Gradient 3: max diff = 0.0


In [69]:
class LlamaMLPmanual(nn.Module):
    def __init__(self, hidden_size:int = 64, intermediate_size:int = 256):
        super().__init__()
        self.W1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.W2 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.W3 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, x:Tensor) -> Tensor:
        return LlamaMLPFunction_2.apply(x, self.W1.weight, self.W2.weight, self.W3.weight)

mlp_manual = LlamaMLPmanual(hidden_size=D).to(DEVICE)


In [70]:
inputs = {'x' : x}
profile_torch_module_backward(mlp, inputs)

print(' ')
profile_torch_module_backward(mlp_manual, inputs)



Elapsed time:       0.00483s
Peak allocated mem: 75.8MB
Peak reserved mem:  75.8MB
 
Elapsed time:       0.00484s
Peak allocated mem: 87.3MB
Peak reserved mem:  87.3MB
