In [3]:
import torch
from torch import nn, Tensor
from torch.autograd import gradcheck, Function

In [None]:
### this is the reference class, check the manually implememnted backward() against this one
class FusedLinearActEager(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)
        self.act = nn.Tanh()

    def forward(self, x:Tensor) -> Tensor:
        return self.act(self.linear(x))

class FusedLinearActCustomFunction(Function):
    @staticmethod
    def forward(ctx, W:Tensor, b:Tensor, x:Tensor):
        z = x @ W.T + b
        out = torch.nn.functional.tanh(z)
        ctx.save_for_backward(out, x, W)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, x, W = ctx.saved_tensors
        dout_dz = 1 - out.pow(2)
        
        dz = grad_output * dout_dz  # apply chain rule

        grad_x = dz @ W
        grad_W = dz.T @ x
        grad_b = dz.sum(dim=0)
        # dout_dW = x * dout_dz
        # dout_db = dout_dz
        # dout_dx = dout_dz @ W.T
        return grad_W, grad_b, grad_x

class FusedLinearActCustomModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(128, 128)

    def forward(self, x: Tensor) -> Tensor:
        return FusedLinearActCustomFunction.apply(self.linear.weight, self.linear.bias, x)
        

# Assume input and targets
x = torch.randn(4, 128, requires_grad=True)
eager_model = FusedLinearActEager()
custom_model = FusedLinearActCustomModule()  # Your version using custom autograd.Function

# Copy weights
custom_model.linear.weight.data.copy_(eager_model.linear.weight.data)
custom_model.linear.bias.data.copy_(eager_model.linear.bias.data)


# Copy weights
assert torch.allclose(eager_model(x), custom_model(x), atol=1e-6)

# Forward and backward
out_eager = eager_model(x)
loss_eager = out_eager.sum()
loss_eager.backward()

out_custom = custom_model(x)
loss_custom = out_custom.sum()
loss_custom.backward()
...

# Compare gradients
print(torch.allclose(eager_model.linear.weight.grad, custom_model.linear.weight.grad, atol=1e-4))
print(torch.allclose(x.grad, x.grad, atol=1e-4))

True
True
