In [None]:
import torch
import walnut
import walnut.nn as nn

In [None]:
# Problem size: X is (batches, in_channels, x_size), W is (out_channels, in_channels, kernel_size).
# NOTE(review): with in_channels = 1 the sum over input channels is trivial; consider > 1 for a stronger check.
batches, in_channels, out_channels = 10, 1, 16
x_size, kernel_size = 16, 2

# Shared random inputs, drawn once with walnut and mirrored into torch below.
# TODO(review): seed walnut's RNG here (its seeding API is not visible in this notebook) so runs are reproducible.
X = walnut.randn((batches, in_channels, x_size))
W = walnut.randn((out_channels, in_channels, kernel_size))
B = walnut.randn((out_channels,))

# float32 torch copies of the same data; every one of them tracks gradients.
t_x = torch.from_numpy(X.data).float().requires_grad_()
t_w = torch.nn.Parameter(torch.from_numpy(W.data).float())
t_b = torch.nn.Parameter(torch.from_numpy(B.data).float())

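# Forward pass

Run walnut's `Convolution1d` and PyTorch's `Conv1d` on the same input, weights and bias. The printed rows (first sample, first output channel) should match.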

In [None]:
# Convolution hyper-parameters shared by both implementations.
# torch's padding="same" only supports stride 1, so keep strides at 1 while pad is "same".
strides, dilation, pad = 1, 2, "same"

### Walnut

In [None]:
w_conv = nn.modules.Convolution1d(
    out_channels,
    kernel_size,
    input_shape=X.shape[1:],  # (in_channels, x_size): every axis except the batch axis
    stride=strides,
    dil=dilation,
    pad=pad,
)
# Replace whatever parameters the constructor created with the shared W and B.
w_conv.w, w_conv.b = W, B
w_out = w_conv(X)
# First sample, first output channel.
w_out[0, 0]

### Torch

In [None]:
# bias=True so the module's declared config matches what actually runs: t_b is installed
# as its bias two lines below. The old bias=False was misleading, because the bias was
# still applied. torch only accepts padding="same" with stride 1.
t_conv = torch.nn.Conv1d(
    in_channels,
    out_channels,
    kernel_size,
    stride=strides,
    padding=pad,
    dilation=dilation,
    bias=True,
)
# Overwrite torch's random initialisation with the shared walnut-derived parameters.
t_conv.weight = t_w
t_conv.bias = t_b
t_out = t_conv(t_x)
# First sample, first output channel; should match the walnut row above.
t_out[0, 0]

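# Backward pass

Feed the same upstream gradient `dy` to both implementations and compare the resulting input, weight and bias gradients.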

In [None]:
# One shared upstream gradient dy (same shape as the forward output) for both backends.
# Cast it to float32 like every other torch tensor in this notebook; otherwise autograd
# silently converts it (walnut data is presumably float64, hence the .float() calls
# in the setup cell; confirm).
dy = walnut.randn(t_out.shape)
t_dy = torch.from_numpy(dy.data).float()
_ = w_conv.backward(dy.data)
# Not re-runnable on its own: autograd frees t_out's graph after this call, and .grad
# accumulates on t_x / t_w / t_b. Re-run from the data-setup cell instead.
t_out.backward(t_dy)

### Input gradient: $dx = dy * w$ (transposed convolution of $dy$ with $w$)
Walnut

In [None]:
# Walnut input gradient (w_conv.x is presumably the cached forward input): sample 0, input channel 0.
w_conv.x.grad[0, 0]

Torch

In [None]:
# Torch input gradient for sample 0, input channel 0 (x_size values); should match the walnut row above.
t_x.grad[0, 0]

### Weight gradient: $dw = x * dy$ (cross-correlation of the input with $dy$, summed over the batch)
Walnut

In [None]:
# Walnut weight gradient for output channel 0, input channel 0 (one value per kernel tap).
w_conv.w.grad[0, 0]

Torch

In [None]:
# Torch weight gradient for output channel 0, input channel 0 (kernel_size values).
t_w.grad[0, 0]

### Bias gradient: $db$ is $dy$ summed over the batch and length axes
Walnut

In [None]:
# Walnut bias gradient, one value per output channel.
w_conv.b.grad

Torch

In [None]:
# Torch bias gradient, one value per output channel (shape (out_channels,)).
t_b.grad