In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import numpy as np
from leaf import Tensor
from leaf.functions.function import Function
from functools import partialmethod

class Conv2d(Function):
    # padding doesn't work correctly, should pad input not output...
    def forward(self, x, w, stride=1, padding=0):
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        
        batch_size, in_C_, in_H, in_W = x.shape
        out_C, in_C, kernel_H, kernel_W = w.shape
        stride_H, stride_W = stride
        padding_H, padding_W = padding
        
        assert in_C == in_C_
        
        out_H = ((in_H + 2 * padding_H - (kernel_H - 1) - 1) // stride_H) + 1
        out_W = ((in_W + 2 * padding_W - (kernel_W - 1) - 1) // stride_W) + 1
        
        self.save_for_backward(x, w, stride)
        tw = w.reshape(out_C, -1).T
        result = np.zeros((batch_size, out_C, out_H, out_W)).astype(x.dtype)
        for h in range(out_H):
            for w in range(out_W):
                ih, iw = h * stride_H, w * stride_W
                result[:, :, h, w] = np.dot(
                    x[:, :, ih:ih+kernel_H, iw:iw+kernel_W].reshape(batch_size, -1), tw)
        
        return result
    
    def backward(self, grad, **kwargs):
        x, w, stride, = self.saved_tensors
        _, _, out_H, out_W = grad.shape
        batch_size, in_C_, in_H, in_W = x.shape
        out_C, in_C, kernel_H, kernel_W = w.shape
        stride_H, stride_W = stride
        
        dx = np.zeros((batch_size, in_C_, in_H, in_W)).astype(x.dtype)
        dw = np.zeros((out_C, in_C, kernel_H, kernel_W)).astype(w.dtype)
        tw = w.reshape(out_C, -1)
        for h in range(out_H):
            for w in range(out_W):
                ih, iw = h * stride_H, w * stride_W
                g = grad[:, :, h, w]
                dw += g.T.dot(
                    x[:, :, ih:ih+kernel_H, iw:iw+kernel_W].reshape(batch_size, -1)).reshape(dw.shape)
                dx[:, :, ih:ih+kernel_H, iw:iw+kernel_W] += g.dot(
                    tw).reshape(batch_size, in_C_, kernel_H, kernel_W)
        
        return  dx, dw
    

class MaxPool2d(Function):
    def forward(self, x, kernel_size=2, stride=1, padding=0):
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding)
        
        batch_size, in_C, in_H, in_W = x.shape
        kernel_H, kernel_W = kernel_size
        stride_H, stride_W = stride
        padding_H, padding_W = padding
        
        out_H = ((in_H + 2 * padding_H - (kernel_H - 1) - 1) // stride_H) + 1
        out_W = ((in_W + 2 * padding_W - (kernel_W - 1) - 1) // stride_W) + 1
        
        self.save_for_backward(x, kernel_size, stride)
        result = np.zeros((batch_size, in_C, out_H, out_W)).astype(x.dtype)
        for h in range(out_H):
            for w in range(out_W):
                ih, iw = h * stride_H, w * stride_W
                result[:, :, h, w] = np.max(
                    x[:, :, ih:ih+kernel_H, iw:iw+kernel_W].reshape(batch_size, in_C, -1), axis=-1)
        
        return result
    
    def backward(self, grad, **kwargs):
        x, kernel_size, stride, = self.saved_tensors
        _, _, out_H, out_W = grad.shape
        batch_size, in_C, in_H, in_W = x.shape
        kernel_H, kernel_W = kernel_size
        stride_H, stride_W = stride
        
        dx = np.zeros((batch_size, in_C, in_H, in_W)).astype(x.dtype)
        for h in range(out_H):
            for w in range(out_W):
                ih, iw = h * stride_H, w * stride_W
                xx = x[:, :, ih:ih+kernel_H, iw:iw+kernel_W].reshape(batch_size, in_C, -1)
                mask = (xx.max(axis=-1, keepdims=True) == xx)
                xd = (xx * mask)
                gg = np.expand_dims(grad[:, :, h, w], axis=-1)
                dx[:, :, ih:ih+kernel_H, iw:iw+kernel_W] = (gg * xd).reshape(batch_size, in_C, kernel_H, kernel_W)
        
        return dx
    
setattr(Tensor, 'conv2d', partialmethod(Conv2d.apply, Conv2d))
setattr(Tensor, 'maxpool2d', partialmethod(MaxPool2d.apply, MaxPool2d))

In [2]:
import torch

x = Tensor.uniform(8, 2, 12, 12, requires_grad=True)
w = Tensor.uniform(4, 2, 5, 5, requires_grad=True)

tx = torch.tensor(x.data, requires_grad=True)
tw = torch.tensor(w.data, requires_grad=True)

lco = x.conv2d(w, stride=2, padding=0)
print(lco.shape)

tco = torch.nn.functional.conv2d(tx, tw, stride=2, groups=1, padding=0)
print(tco.shape)

np.testing.assert_allclose(tco.detach().numpy(), lco.data, rtol=1e-6, atol=1e-3)
print('forward pass with convolution passed test')

lco.mean().backward()
tco.mean().backward()

np.testing.assert_allclose(tw.grad.numpy(), w.grad, rtol=1e-6, atol=1e-3)
np.testing.assert_allclose(tx.grad.numpy(), x.grad, rtol=1e-6, atol=1e-3)
print('backward pass with convolution passed test')

(8, 4, 4, 4)
torch.Size([8, 4, 4, 4])
forward pass with convolution passed test
backward pass with convolution passed test


In [3]:
mx = Tensor.uniform(8, 2, 12, 12, requires_grad=True)

lmo = mx.maxpool2d(kernel_size=2, stride=1, padding=0)
print(lmo.shape)

mtx = torch.tensor(mx.data, requires_grad=True)

tmo = torch.nn.functional.max_pool2d(mtx, (2, 2), stride=1, padding=0)
print(tmo.shape)

np.testing.assert_allclose(tmo.detach().numpy(), lmo.data, rtol=1e-6, atol=1e-3)
print('forward pass with maxpool2d passed test')

lmo.mean().backward()
tmo.mean().backward()
np.testing.assert_allclose(mtx.grad.numpy(), mx.grad, rtol=1e-6, atol=1e-3)
print('backward pass with maxpool2d passed test')

(8, 2, 11, 11)
torch.Size([8, 2, 11, 11])
forward pass with maxpool2d passed test


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


AssertionError: 
Not equal to tolerance rtol=1e-06, atol=0.001

Mismatched elements: 569 / 2304 (24.7%)
Max absolute difference: 0.00206621
Max relative difference: 22299.613
 x: array([[[[0.000517, 0.000517, 0.      , ..., 0.      , 0.      ,
          0.      ],
         [0.      , 0.      , 0.      , ..., 0.00155 , 0.000517,...
 y: array([[[[ 6.796276e-06,  5.115369e-07, -0.000000e+00, ...,
          -0.000000e+00, -0.000000e+00, -0.000000e+00],
         [-0.000000e+00, -0.000000e+00, -0.000000e+00, ...,...