In [None]:
import sys
sys.path.append("..") # for sibling import

import torch
import walnut
import walnut.nn as nn

In [None]:
# Problem sizes shared by the walnut model and the PyTorch reference.
batches, in_channels, out_channels = 5, 3, 4
x_size, kernel_size = (8, 8), (3, 3)

# Random walnut tensors. Both conv kernels are laid out (K, C, Y, X);
# the creation order is kept fixed so the random draws stay the same.
X = walnut.randn((batches, in_channels, *x_size))
W1 = walnut.randn((out_channels, in_channels, *kernel_size))
W2 = walnut.randn((out_channels, out_channels, *kernel_size))
B1 = walnut.randn((out_channels,))
B2 = walnut.randn((out_channels,))

# PyTorch mirrors of the same values, cast to float32. The input is a plain
# leaf tensor that tracks gradients; weights and biases become Parameters.
t_x = torch.from_numpy(X.data).float()
t_x.requires_grad = True
t_w1, t_w2, t_b1, t_b2 = [
    torch.nn.Parameter(torch.from_numpy(src.data).float(), requires_grad=True)
    for src in (W1, W2, B1, B2)
]

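### Forward

Run the same conv → max-pool → conv network through walnut and PyTorch, using identical weights and biases, and compare the same output slice.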

In [None]:
class model(nn.Model):
    """Walnut network under test: conv1 -> max-pool -> conv2.

    The layers are wired to the notebook-level tensors W1/B1 and W2/B2, so the
    walnut results can be compared against a PyTorch network built from the
    same values. In training mode every forward pass also installs a matching
    ``backward`` callable on the instance.
    """

    def __init__(self):
        super().__init__()
        # First convolution, seeded with the shared kernel W1 and bias B1.
        self.conv1 = nn.layers.Convolution2d(in_channels, out_channels, weights=W1)
        self.conv1.b = B1
        # Pooling stage between the two convolutions (library defaults).
        self.mp = nn.layers.MaxPooling2d()
        # Second convolution, seeded with W2 and B2.
        self.conv2 = nn.layers.Convolution2d(out_channels, out_channels, weights=W2)
        self.conv2.b = B2

        # Sub-layers in forward execution order.
        self.layers = [self.conv1, self.mp, self.conv2]

    def __call__(self, X):
        """Run X through conv1, mp and conv2, and return the result.

        When ``self.training`` is truthy, this also sets ``self.backward`` to a
        function that pushes an output gradient through the layers in reverse
        order. That function returns conv1.backward's result, presumably the
        gradient with respect to X.
        """
        hidden = self.conv1(X)
        pooled = self.mp(hidden)
        out = self.conv2(pooled)

        if not self.training:
            return out

        def backward(dy):
            grad = dy
            for layer in (self.conv2, self.mp, self.conv1):
                grad = layer.backward(grad)
            return grad

        self.backward = backward
        return out

In [None]:
# Build the walnut model and run one forward pass in training mode.
# model.__call__ only installs w_model.backward when self.training is truthy,
# and the backward cell depends on it. set_training(True) presumably sets that flag.
w_model = model()
w_model.set_training(True)
w_out = w_model(X)
w_out[0, 1]  # index [0, 1] matches t_out[0, 1]; assumes (batch, channel, ...) layout like t_out — TODO confirm

In [None]:
# PyTorch reference network, wired to the same weights/biases as the walnut model.
# The convs keep Conv2d's default bias=True because this network does use a bias.
# The randomly initialised weight/bias are replaced by the shared Parameters.
t_conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
t_conv1.weight = t_w1
t_conv1.bias = t_b1
t_conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size)
t_conv2.weight = t_w2
t_conv2.bias = t_b2
# NOTE(review): assumes walnut's MaxPooling2d() defaults to a 2x2 window (stride 2) — confirm.
t_mp = torch.nn.MaxPool2d((2, 2))

t_out = t_conv2(t_mp(t_conv1(t_x)))
# Fail fast if the reference diverges structurally from the walnut model
# (e.g. a different pooling window); otherwise t_out.backward would fail later on a shape mismatch.
assert tuple(t_out.shape) == tuple(w_out.shape), (tuple(t_out.shape), tuple(w_out.shape))
t_out[0, 1]

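### Backward

Feed one random upstream gradient through both networks and compare the resulting input, weight and bias gradients.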

In [None]:
# Upstream gradient dL/dy shared by both frameworks. The raw array is what walnut's backward receives.
dy = walnut.randn(w_out.shape).data
# Copy into a float32 tensor:
# - t_out is float32 because every torch input/parameter was built with .float().
# - Copying (unlike torch.from_numpy) keeps t_dy independent of any in-place use of dy by walnut's backward below.
t_dy = torch.tensor(dy, dtype=torch.float32)

x_grad = w_model.backward(dy)

t_out.backward(t_dy)

#### Input gradient (`X`)

In [None]:
# walnut: value returned by w_model.backward(dy), to be compared with t_x.grad[1, 0].
# NOTE(review): indexing assumes x_grad has X's (batches, in_channels, *x_size) layout — confirm.
x_grad[1, 0]

In [None]:
# PyTorch reference: gradient accumulated on the leaf t_x by t_out.backward(t_dy); sample 1, input channel 0.
t_x.grad[1, 0]

#### Weight gradients (`W1`, `W2`)

In [None]:
# walnut gradient for conv1's kernel, to be compared with t_w1.grad[0, 0].
# NOTE(review): assumes conv1.w keeps W1's (K, C, Y, X) layout — confirm.
w_model.conv1.w.grad[0, 0]

In [None]:
# PyTorch gradient for W1, laid out (K, C, Y, X): the kernel_size slice for output channel 0, input channel 0.
t_w1.grad[0, 0]

In [None]:
# walnut gradient for conv2's kernel, to be compared with t_w2.grad[0, 0].
# NOTE(review): assumes conv2.w keeps W2's (K, C, Y, X) layout — confirm.
w_model.conv2.w.grad[0, 0]

In [None]:
# PyTorch gradient for W2, laid out (K, C, Y, X): the kernel_size slice for output channel 0, input channel 0.
t_w2.grad[0, 0]

#### Bias gradients (`B1`, `B2`)

In [None]:
# walnut gradient for conv1's bias (assigned B1 in model.__init__), to be compared with t_b1.grad.
w_model.conv1.b.grad

In [None]:
# PyTorch gradient for the shared bias parameter t_b1, shape (out_channels,).
t_b1.grad

In [None]:
# walnut gradient for conv2's bias (assigned B2 in model.__init__), to be compared with t_b2.grad.
w_model.conv2.b.grad

In [None]:
# PyTorch gradient for the shared bias parameter t_b2, shape (out_channels,).
t_b2.grad