In [1]:
from torch import nn
import torch

In [2]:
class CausalConv1d(nn.Module):
    """1-D convolution that can be made causal by trimming its right-hand side.

    When ``causal`` is true the wrapped ``nn.Conv1d`` is built with
    ``(kernel_size - 1) * dilation`` samples of zero padding (applied by
    ``nn.Conv1d`` on both ends) and ``forward`` drops that many samples from
    the end of the output, so the output has the same length as the input and
    step ``t`` is computed from inputs at steps ``<= t`` only.  When ``causal``
    is false no padding is added and the plain convolution output is returned.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: integer kernel width.
        dilation: integer spacing between kernel taps.
        causal: whether to pad and trim as described above.
        **kwargs: forwarded to ``nn.Conv1d``.  ``bias`` defaults to ``False``
            but may be overridden.  ``padding`` is set by this module, so
            passing it raises ``TypeError``.

    Note:
        The trim length is derived from ``kernel_size`` and ``dilation`` only;
        a ``stride`` other than 1 passed through ``kwargs`` is not compensated
        for.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 dilation=1,
                 causal=True,
                 **kwargs):
        super().__init__()
        self.causal = causal
        # Amount of padding requested from the convolution, and the number of
        # trailing samples removed again in forward().
        self.padding = (kernel_size - 1) * dilation if causal else 0
        # Keep the historical bias-free default while letting callers opt in
        # to a bias instead of hitting a duplicate-keyword TypeError.
        kwargs.setdefault('bias', False)
        self.conv = nn.Conv1d(in_channels,
                              out_channels,
                              kernel_size,
                              padding=self.padding,
                              dilation=dilation,
                              **kwargs)

    def forward(self, input_):
        """Apply the convolution and, in causal mode, drop the trailing padding.

        Args:
            input_: tensor whose last axis is the sequence axis; both
                ``(batch, channels, length)`` and unbatched
                ``(channels, length)`` layouts are handled.

        Returns:
            The convolution output, shortened by ``self.padding`` samples at
            the end when the module is causal and the padding is non-zero.
        """
        output = self.conv(input_)
        # A slice of ``:-0`` would be empty, so only trim when padding > 0.
        if self.causal and self.padding:
            # Ellipsis indexing trims the last axis for any input rank.
            output = output[..., :-self.padding]
        return output


In [3]:
# Probe the receptive field of two stacked CausalConv1d layers (kernel size 2,
# dilations 1 and 2) by back-propagating from a single output time step.
torch.manual_seed(0)  # conv weights are randomly initialised; pin them for reproducible output

B = 1   # batch size
C = 2   # input channels
O = 2   # output channels of both layers
L = 10  # sequence length
K = 2   # kernel size
D = 3   # not passed to either layer in this cell; left defined as a notebook global
causal = True

# Display-only value: each layer derives its own padding from kernel size and
# dilation, so this variable does not influence the computation below.
padding = 0
print('padding', padding)

x = torch.arange(L).float().repeat((B, C, 1))  # B,C,L
x.requires_grad = True

# Construction order is kept (m1, then m2) because each layer draws its
# initial weights from the global random generator.
m1 = CausalConv1d(C, O, K, 1, causal=causal)
m2 = CausalConv1d(O, O, K, 2, causal=causal)

o = m2(m1(x))
print('x:', x.shape)  # B,C,L
print('o:', o.shape)  # B,O,L when causal
print('\n', o)

# Back-propagate from the third-to-last output step only.
o[0, :, -3].mean().backward()
print('m.conv.grad\n', m1.conv.weight.grad)  # gradient w.r.t. the first layer's weights
print('x.grad\n', x.grad)

if causal:
    # A causal stack must not let an output step see later inputs, so the
    # gradient at the two input steps after the probed one has to be zero.
    assert torch.all(x.grad[0, :, -2:] == 0)

padding 0
x: torch.Size([1, 2, 10])
o: torch.Size([1, 2, 10])

 tensor([[[ 0.0000, -0.3043, -0.4812, -0.4617, -0.4812, -0.5008, -0.5203,
          -0.5398, -0.5593, -0.5789],
         [ 0.0000, -0.1504, -0.2772, -0.7280, -1.0385, -1.3490, -1.6596,
          -1.9701, -2.2806, -2.5911]]], grad_fn=<SliceBackward0>)
m.conv.grad
 tensor([[[-1.6514, -2.0302],
         [-1.6514, -2.0302]],

        [[ 2.6584,  3.1030],
         [ 2.6584,  3.1030]]])
x.grad
 tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.1112, -0.0630,  0.2115,
          -0.2304,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0606, -0.0007, -0.1360,
           0.0031,  0.0000,  0.0000]]])


In [4]:
# Probe the receptive field of two stacked CausalConv1d layers (kernel size 3,
# dilations 1 and 2) by back-propagating from a single output time step.
torch.manual_seed(0)  # conv weights are randomly initialised; pin them for reproducible output

B = 1   # batch size
C = 2   # input channels
O = 2   # output channels of both layers
L = 10  # sequence length
K = 3   # kernel size
causal = True

# Display-only value: each layer derives its own padding from kernel size and
# dilation, so this variable does not influence the computation below.  It is
# set directly rather than computed from a dilation constant that this cell
# does not define.
padding = 0
print('padding', padding)

x = torch.arange(L).float().repeat((B, C, 1))  # B,C,L
x.requires_grad = True

# Construction order is kept (m1, then m2) because each layer draws its
# initial weights from the global random generator.
m1 = CausalConv1d(C, O, K, 1, causal=causal)
m2 = CausalConv1d(O, O, K, 2, causal=causal)

o = m2(m1(x))
print('x:', x.shape)  # B,C,L
print('o:', o.shape)  # B,O,L when causal
print('\n', o)

# Back-propagate from the second-to-last output step only.
o[0, :, -2].mean().backward()
print('x.grad\n', x.grad)

if causal:
    # A causal stack must not let an output step see later inputs, so the
    # gradient at the one input step after the probed one has to be zero.
    assert torch.all(x.grad[0, :, -1:] == 0)

padding 0
x: torch.Size([1, 2, 10])
o: torch.Size([1, 2, 10])

 tensor([[[ 0.0000, -0.0498, -0.1917, -0.3475, -0.7206, -0.9686, -1.3282,
          -1.5337, -1.7393, -1.9448],
         [ 0.0000,  0.1781,  0.5667,  0.7634,  1.1631,  1.3731,  1.5376,
           1.7228,  1.9080,  2.0932]]], grad_fn=<SliceBackward0>)
x.grad
 tensor([[[ 0.0000,  0.0000,  0.0538, -0.0494, -0.0353, -0.0065, -0.0481,
           0.0448,  0.0338,  0.0000],
         [ 0.0000,  0.0000,  0.0335, -0.0291,  0.0030, -0.0007, -0.0548,
           0.0144,  0.0304,  0.0000]]])


In [25]:
# Show the current value of `o` as this cell's output.
# NOTE(review): the recorded values below differ from the `o` printed by the
# preceding cell and the execution count is not sequential — presumably saved
# after a later re-run with different random weights; confirm with a fresh
# Restart & Run All.
o

tensor([[[ 0.0000,  0.0589, -0.0095, -0.1673, -0.2739, -0.3528, -0.2900,
          -0.3044, -0.3189, -0.3333],
         [ 0.0000, -0.1277, -0.2582, -0.0409,  0.1276, -0.0453, -0.1835,
          -0.1501, -0.1168, -0.0834]]], grad_fn=<SliceBackward0>)

In [26]:
# Show the input tensor `x` as this cell's output.
# NOTE(review): the execution count is not sequential with the cells above;
# confirm the notebook still reproduces under Restart & Run All.
x

tensor([[[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],
         [0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]]], requires_grad=True)