In [None]:
from init_notebook import *

In [None]:
@torch.no_grad()
def plot_conv(
        size: int,
        kernel_size: int,
        dilation: Union[int, Iterable[int]],
        layers: int = 6,
        zoom: int = 5,
        middle: bool = False,
):
    """
    Visualize how a single impulse spreads through a stack of 1-d convolutions.

    A signal of shape (1, 3, size) with a 1 at one position (in all 3 channels) is
    passed through `layers` bias-free ``nn.Conv1d(3, 3, kernel_size)`` layers with
    random non-negative weights. The input and every layer's output are displayed as
    one row each of an image grid, so the receptive field can be watched growing.

    :param size: length of the 1-d signal.
    :param kernel_size: kernel size of every conv layer (odd or even).
    :param dilation: a single dilation or one value per layer; expanded to `layers`
        entries by ``param_make_list`` (assumed to broadcast a single int — TODO confirm).
    :param layers: number of stacked conv layers.
    :param zoom: passed to ``resize`` for every row (presumably an upscaling factor — verify).
    :param middle: if True, the impulse is placed at ``size // 2``, otherwise at index 0.
        The receptive-field radius is only printed when this is False.
    """
    inp = torch.zeros(1, 3, size)
    inp[..., size // 2 if middle else 0] = 1

    grid = []

    def add_pic(state: torch.Tensor):
        # (1, 3, size) -> (3, 1, size): the 3 channels become colour channels and the
        # single batch entry becomes a 1-pixel-high row; magnitudes above 1 saturate.
        img = state.permute(1, 0, 2)
        grid.append(resize(img.abs().clamp(0, 1), zoom))

    add_pic(inp)
    for dil in param_make_list(dilation, layers, "dilation"):
        # padding="same" keeps the length at `size` for odd AND even kernel sizes.
        # A symmetric (kernel_size // 2) * dil padding grows the signal by `dil` per
        # layer for even kernels, and make_grid then fails on rows of unequal width.
        conv = nn.Conv1d(3, 3, kernel_size, padding="same", dilation=dil, bias=False)
        # A fresh generator seeded with 23 for every layer: reproducible, and all layers
        # share the same weights. torch.rand is in [0, 1), so weights are non-negative
        # and contributions cannot cancel -- the non-zero positions trace the receptive field.
        conv.weight[:] = .5 * torch.rand(conv.weight.shape, generator=torch.Generator().manual_seed(23))
        inp = conv(inp)
        # a nonlinearity could be tried here, e.g. inp = F.gelu(inp)
        add_pic(inp)

    display(VF.to_pil_image(make_grid(grid, nrow=1, pad_value=.3)))
    if not middle:
        # With the impulse at index 0, the last non-zero index of channel 0 is how far
        # the signal has spread to the right. It is capped at size - 1 when the field
        # is wider than the signal, because "same" padding keeps the length fixed.
        print("receptive field radius:", inp[0, 0].argwhere().flatten(0)[-1].item())
        
# Other configurations explored with this cell:
#   layers=3, dilation=[3, 5, 1]                    (needs a kernel_size as well)
#   kernel_size=13, layers=4, dilation=[3, 5, 7, 1]
#   kernel_size=7,  layers=3, dilation=[1, 1, 1]
#   kernel_size=13, layers=3, dilation=[5, 7, 1]
plot_conv(
    size=250,
    kernel_size=13,
    layers=6,
    dilation=[2, 3, 5, 7, 9, 1],
)

In [None]:
class ConvBlock1d(nn.Module):
    """
    Depthwise-separable 1-d convolution.

    First a per-channel convolution (``groups=in_channels``) filters along the
    sequence axis without changing the channel count, then a 1x1 ("pointwise")
    convolution mixes channels and maps ``in_channels`` to ``out_channels``.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels (produced by the pointwise conv).
    :param kernel_size: kernel size of the depthwise conv.
    :param stride: stride of the depthwise conv. PyTorch rejects ``padding="same"``
        together with a stride other than 1.
    :param padding: padding of the depthwise conv, an int or a mode string such as "same".
    :param dilation: dilation of the depthwise conv.
    :param bias: whether both convs get a bias term.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Union[int, str] = "same",
        dilation: int = 1,
        bias: bool = False,
    ):
        super().__init__()
        # depthwise: one filter per input channel, channel count unchanged
        self.depth_conv = nn.Conv1d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=in_channels, bias=bias,
        )
        # pointwise: 1x1 conv that mixes the channels
        self.point_conv = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    @property
    def weight(self):
        """Weight of the depthwise conv, shape (in_channels, 1, kernel_size)."""
        return self.depth_conv.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.point_conv(self.depth_conv(x))

# Exercise the class defined above. The call previously used the name DepthConv1d,
# which is not defined in this notebook (presumably a leftover from a rename).
m = ConvBlock1d(1, 3, dilation=3)
print(f"params: {num_module_parameters(m):,}")

inp = torch.ones(1, 1, 10)
outp = m(inp)
# the default padding="same" keeps the length; the pointwise conv sets 3 output channels
assert tuple(outp.shape) == (1, 3, 10), outp.shape
print(inp.shape, "->", outp.shape)