In [15]:
import torch

class MyCustomFCLayer(torch.autograd.Function):
    """Fully connected layer ``z = W @ x + b`` as a custom autograd Function.

    Teaching example only: it handles a single, unbatched input vector and does
    not consider broadcasting, so don't use it in practice (use
    ``torch.nn.Linear`` / ``torch.nn.functional.linear`` instead).

    Call it as ``MyCustomFCLayer.apply(x, W)`` or ``MyCustomFCLayer.apply(x, W, b)``.
    """

    @staticmethod
    def forward(ctx, x, W, b=None):
        """Compute ``z = W @ x`` (plus ``b`` when given).

        Args:
            ctx: autograd context; ``x`` and ``W`` are saved for ``backward``.
            x: input vector; assumed 1-D of length ``in_features`` (``backward``
                uses ``torch.outer``, which requires 1-D operands).
            W: weight matrix, assumed shape ``(out_features, in_features)``.
            b: optional bias, assumed shape ``(out_features,)``.

        Returns:
            The tensor ``W @ x`` with ``b`` added when ``b`` is not None.
        """
        ctx.save_for_backward(x, W)
        print("Custom Forward was called!")
        z = torch.matmul(W, x)  # matrix-vector product
        if b is not None:
            # z is a fresh tensor produced just above, so adding in place is safe.
            z += b
        return z

    @staticmethod
    def backward(ctx, grad_z):
        """Return gradients of the loss w.r.t. each argument passed to ``apply``.

        Args:
            ctx: autograd context holding the saved ``x`` and ``W``.
            grad_z: gradient of the loss w.r.t. the output ``z``.

        Returns:
            A tuple with one entry per argument given to ``apply``: ``(grad_x,
            grad_W)`` when the bias was omitted, else ``(grad_x, grad_W,
            grad_b)``. Entries are None for inputs that do not need a gradient.
        """
        x, W = ctx.saved_tensors
        print(f"Custom Backward was called! Got output gradient {grad_z}")
        # needs_input_grad has one flag per argument passed to apply(). When the
        # bias is omitted it has only two entries, so index [2] must be guarded.
        needs = ctx.needs_input_grad
        grad_x = grad_W = grad_b = None
        if needs[0]:
            # dz/dx = W, hence dL/dx = W^T dL/dz.
            grad_x = W.T @ grad_z
        if needs[1]:
            # dL/dW_ij = dL/dz_i * x_j, i.e. the outer product of grad_z and x.
            grad_W = torch.outer(grad_z, x)
        if len(needs) > 2 and needs[2]:
            # dz/db is the identity, so the bias gradient is grad_z itself.
            grad_b = grad_z
        # Return exactly as many gradients as there were inputs to forward().
        return (grad_x, grad_W, grad_b)[: len(needs)]

In [16]:
# Toy problem sizes: x has x_dim features, the layer produces z_dim outputs.
x_dim = 3
z_dim = 5

# Seed once, before both the target draw and nn.Linear's random init, so that
# both are reproducible.
torch.manual_seed(0)

t = torch.randn(size=(z_dim,))  # regression target for the MSE loss below

# Two separate but identical input leaves, one per implementation, so each
# accumulates its own input gradient.
x_pytorch = torch.arange(x_dim, dtype=torch.float32, requires_grad=True)
x_ours = torch.arange(x_dim, dtype=torch.float32, requires_grad=True)
print("input x :", x_ours)

# Reference layer; PyTorch chooses the initial weight and bias.
lin = torch.nn.Linear(in_features=x_dim, out_features=z_dim)

# Our own parameters: independent leaf copies of the reference values, so the
# custom layer starts from exactly the same W and b.
W = lin.weight.detach().clone().requires_grad_(True)
b = lin.bias.detach().clone().requires_grad_(True)

input x : tensor([0., 1., 2.], requires_grad=True)


In [17]:
# Forward pass through both implementations on the same input values.
# Each implementation uses its own input leaf (x_pytorch vs. x_ours), so any
# backward through z_ours can never accumulate into x_pytorch.grad.
z_pytorch = lin(x_pytorch)

z_ours = MyCustomFCLayer.apply(x_ours, W, b)

Custom Forward was called!


In [18]:
print("Our z: ", z_ours)
print("PyTorch's z: ", z_pytorch)
# The printout is rounded to 4 decimals, so also check equality programmatically
# (within float32 tolerance).
torch.testing.assert_close(z_ours, z_pytorch)

Our z:  tensor([-1.3940,  0.5576,  0.0218,  0.5202,  1.6055],
       grad_fn=<MyCustomFCLayerBackward>)
PyTorch's z:  tensor([-1.3940,  0.5576,  0.0218,  0.5202,  1.6055], grad_fn=<ViewBackward0>)


In [19]:
# Re-run guard: clear any gradients left over from a previous execution of this
# cell, then redo the forward pass through the custom layer.
for leaf in (x_ours, W, b):
    leaf.grad = None
z_ours = MyCustomFCLayer.apply(x_ours, W, b)

# Mean squared error against the target t, backpropagated through the custom
# backward().
residual = z_ours - t
mse_ours = residual.square().mean()
mse_ours.backward()

print()
for label, grad in (
    ("Our grad x: ", x_ours.grad),
    ("Our grad W: ", W.grad),
    ("Our grad b: ", b.grad),
):
    print(label, grad)

Custom Forward was called!
Custom Backward was called! Got output gradient tensor([-1.1740,  0.3404,  0.8802, -0.0193,  1.0760])

Our grad x:  tensor([0.2009, 0.3733, 1.1361])
Our grad W:  tensor([[-0.0000, -1.1740, -2.3480],
        [ 0.0000,  0.3404,  0.6808],
        [ 0.0000,  0.8802,  1.7605],
        [-0.0000, -0.0193, -0.0386],
        [ 0.0000,  1.0760,  2.1520]])
Our grad b:  tensor([-1.1740,  0.3404,  0.8802, -0.0193,  1.0760])


In [20]:
# Re-run guard: redo the reference forward pass and clear stale gradients so
# this cell can be executed again without accumulating.
z_pytorch = lin(x_pytorch)
x_pytorch.grad = lin.weight.grad = lin.bias.grad = None

mse_pytorch = (z_pytorch - t).square().mean()
mse_pytorch.backward()
# Report the gradients of nn.Linear's own parameters. W and b are our separate
# copies whose gradients come from the custom backward, so printing them here
# would compare our result with itself.
print("PyTorch grad x: ", x_pytorch.grad)
print("PyTorch grad W: ", lin.weight.grad)
print("PyTorch grad b: ", lin.bias.grad)

# Programmatic check that the custom layer's gradients match PyTorch's.
torch.testing.assert_close(x_pytorch.grad, x_ours.grad)
torch.testing.assert_close(lin.weight.grad, W.grad)
torch.testing.assert_close(lin.bias.grad, b.grad)

PyTorch grad x:  tensor([0.2009, 0.3733, 1.1361])
PyTorch grad W:  tensor([[-0.0000, -1.1740, -2.3480],
        [ 0.0000,  0.3404,  0.6808],
        [ 0.0000,  0.8802,  1.7605],
        [-0.0000, -0.0193, -0.0386],
        [ 0.0000,  1.0760,  2.1520]])
PyTorch grad b:  tensor([-1.1740,  0.3404,  0.8802, -0.0193,  1.0760])
