# In [None]:
import torch
import torch.nn.functional as F
from torch import nn

class Conv2DFunc(torch.autograd.Function):
    """Bias-free 2-D cross-correlation (what ``F.conv2d`` computes) via unfold + matmul.

    Forward computation graph::

        X -(unfold)-> U -(W @ U)-> Y' -(reshape)-> Y

    Shapes: ``X: (b, c, h, w)``, ``U: (b, c*kh*kw, oh*ow)``,
    ``W: (kn, c*kh*kw)``, ``Y': (b, kn, oh*ow)``, ``Y: (b, kn, oh, ow)``.

    Only the inputs ``X`` and ``kernel`` are saved for backward; ``U`` is
    recomputed there. This keeps memory low (``U`` is about ``kh*kw`` times
    larger than ``X``). Because the recomputation uses differentiable PyTorch
    ops on the saved inputs, higher-order gradients (double backward /
    ``create_graph=True``) are correct as well.
    """

    @staticmethod
    def forward(ctx, X, kernel, stride=1, padding=1):
        """Convolve ``X`` with ``kernel``.

        Args:
            ctx: autograd context used to pass state to :meth:`backward`.
            X: input batch of shape ``(b, c, h, w)``.
            kernel: filters of shape ``(kn, c, kh, kw)``; need not be contiguous.
            stride: integer stride used for both spatial dimensions.
            padding: integer zero-padding used for both spatial dimensions.

        Returns:
            Tensor of shape ``(b, kn, oh, ow)`` with
            ``oh = (h + 2*padding - kh) // stride + 1`` (likewise ``ow``).

        Raises:
            RuntimeError: if ``kernel`` has a different channel count than ``X``.
        """
        b, c, h, w = X.shape
        kn, kc, kh, kw = kernel.shape  # kn: nr of kernels (= output channels)
        if kc != c:
            raise RuntimeError(
                f"kernel expects {kc} input channels but X has {c}"
            )
        oh = (h + 2 * padding - kh) // stride + 1
        ow = (w + 2 * padding - kw) // stride + 1

        U = F.unfold(X, (kh, kw), stride=stride, padding=padding)  # (b, c*kh*kw, oh*ow)
        assert U.shape[2] == oh * ow  # internal invariant: same formula as unfold

        # reshape (not view) so that non-contiguous kernels are accepted too.
        W = kernel.reshape(kn, -1)  # (kn, c*kh*kw)
        # (kn, k) @ (b, k, p) broadcasts over the batch -> contiguous (b, kn, p).
        Y_prime = W.matmul(U)
        Y = Y_prime.reshape(b, kn, oh, ow)

        # Save inputs (not the detached intermediates) so backward stays
        # differentiable and in-place edits of X/kernel are detected by autograd.
        ctx.save_for_backward(X, kernel)
        ctx.stride, ctx.padding = stride, padding
        return Y

    @staticmethod
    def backward(ctx, grad_Y):
        """Return gradients w.r.t. ``(X, kernel, stride, padding)``.

        Args:
            ctx: context populated by :meth:`forward`.
            grad_Y: upstream gradient of shape ``(b, kn, oh, ow)``.

        Returns:
            ``(grad_X, grad_kernel, None, None)``. An entry is ``None`` when the
            corresponding input does not need a gradient; stride and padding
            are plain integers and never do.
        """
        X, kernel = ctx.saved_tensors
        stride, padding = ctx.stride, ctx.padding
        h, w = X.shape[-2:]
        kn, c, kh, kw = kernel.shape
        need_X, need_kernel = ctx.needs_input_grad[:2]

        # (b, kn, oh, ow) -> (b, kn, p); flatten copes with non-contiguous grads and b == 0.
        grad_Y_prime = grad_Y.flatten(start_dim=2)

        grad_X = grad_kernel = None
        if need_kernel:
            # dL/dW = sum_b grad_Y'[b] @ U[b]^T, with U recomputed from the saved input.
            U = F.unfold(X, (kh, kw), stride=stride, padding=padding)  # (b, k, p)
            grad_kernel = grad_Y_prime.matmul(U.transpose(1, 2)).sum(dim=0)  # (kn, k)
            grad_kernel = grad_kernel.reshape(kn, c, kh, kw)
        if need_X:
            # dL/dU = W^T @ grad_Y'; fold (the adjoint of unfold) sums overlapping patches.
            W = kernel.reshape(kn, -1)  # (kn, k)
            grad_U = W.t().matmul(grad_Y_prime)  # (b, k, p)
            grad_X = F.fold(grad_U, (h, w), (kh, kw), stride=stride, padding=padding)

        return grad_X, grad_kernel, None, None


# Smoke test: one forward/backward pass through the custom convolution.
# With kh=4, kw=5, stride=1 and padding=1 the output shape is (16, 2, 31, 30).
input_batch = torch.randn((16, 3, 32, 32), requires_grad=True)
kernel = torch.randn((2, 3, 4, 5), requires_grad=True)
output = Conv2DFunc.apply(input_batch, kernel, 1, 1)
# Backpropagate an all-ones upstream gradient (same gradients as output.sum().backward()).
upstream = torch.ones_like(output)
output.backward(gradient=upstream)