In [6]:
import math

import numpy as np
import torch
import torch.nn as nn

In [7]:
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Derive 'same'-style padding for a convolution.

    Args:
        k: kernel size, either an int or an iterable of ints (one per spatial dim).
        p: explicit padding; when not None it is returned unchanged.
        d: dilation; when > 1 the kernel is first expanded to its effective size
            ``d * (k - 1) + 1``.

    Returns:
        ``p`` if given; otherwise the floor of half the (effective) kernel size,
        an int for int ``k`` or a list for iterable ``k``. With stride 1 and an
        odd effective kernel this keeps the spatial size unchanged.
    """
    def _dilated(size):
        # extent covered by a kernel of `size` taps spaced `d` apart
        return d * (size - 1) + 1

    is_scalar = isinstance(k, int)
    if d > 1:
        k = _dilated(k) if is_scalar else [_dilated(size) for size in k]
    if p is not None:
        return p
    return k // 2 if is_scalar else [size // 2 for size in k]


class Conv(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d and an activation.

    Args:
        c1: number of input channels.
        c2: number of output channels.
        k: kernel size (int or per-dimension sequence).
        s: stride.
        p: padding; ``None`` lets ``autopad`` derive it from ``k`` and ``d``.
        g: number of groups for the convolution.
        d: dilation.
        act: ``True`` selects the shared class-level ``default_act``; an
            ``nn.Module`` is used as-is; any other value gives ``nn.Identity``.
    """

    default_act = nn.SiLU()  # one instance shared by every Conv built with act=True

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        padding = autopad(k, p, d)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            chosen = self.default_act
        elif isinstance(act, nn.Module):
            chosen = act
        else:
            chosen = nn.Identity()
        self.act = chosen

    def forward(self, x):
        """Run conv, then batch norm, then the activation."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

    def forward_fuse(self, x):
        """Run conv and activation only, skipping ``self.bn``.

        NOTE(review): presumably used once BN has been folded into the conv
        weights — confirm with the caller.
        """
        return self.act(self.conv(x))

In [None]:
class Star2(nn.Module):
    """Chain of per-stage convolutions fused by element-wise multiplication.

    For channel counts ``ch = [c0, c1, ..., cn]`` this builds ``n`` ``Conv``
    layers, where layer ``i`` maps ``ch[i]`` -> ``ch[i + 1]`` channels. The
    forward pass starts from the first input and, for every following input,
    applies the next layer to the running result and multiplies it element-wise
    with that input.

    Args:
        ch: sequence of channel counts, one per input feature map.
        layer: case-insensitive layer kind:
            ``'conv'`` (3x3 Conv), ``'pw_conv'`` (1x1 Conv) or
            ``'gpw_conv'`` (1x1 Conv with ``groups=gcd(c_in, c_out)``).
        reverse: if True, both ``ch`` and the forward inputs are processed
            in reverse order.

    Raises:
        ValueError: if ``layer`` is not one of the supported kinds.
    """

    def __init__(self, ch, layer, reverse=False):
        super().__init__()
        self.reverse = reverse
        ch = list(ch)[::-1] if reverse else list(ch)

        # layer kind -> factory(c_in, c_out); Conv's default p=None gives 'same' padding
        factories = {
            'conv': lambda c1, c2: Conv(c1, c2, 3, 1),
            'pw_conv': lambda c1, c2: Conv(c1, c2, 1, 1),
            'gpw_conv': lambda c1, c2: Conv(c1, c2, 1, 1, g=math.gcd(c1, c2)),
        }
        kind = layer.lower()
        if kind not in factories:
            # previously an unknown kind silently produced zero layers
            raise ValueError(
                f"unknown layer type {layer!r}; expected one of {sorted(factories)}"
            )
        build = factories[kind]

        # layer i maps ch[i] -> ch[i + 1]
        self.layers = nn.ModuleList(build(c1, c2) for c1, c2 in zip(ch, ch[1:]))

    def forward(self, x):
        """Fuse the input feature maps.

        Args:
            x: sequence of tensors, one per entry of ``ch`` (in the order given
                to the constructor). Each layer output must be broadcast-
                compatible with the input it is multiplied by.

        Returns:
            The fused tensor.

        Raises:
            ValueError: if the number of inputs differs from ``len(ch)``.
        """
        xs = list(x)
        if self.reverse:
            xs.reverse()
        # zip would otherwise silently drop surplus inputs or layers
        if len(xs) != len(self.layers) + 1:
            raise ValueError(
                f"expected {len(self.layers) + 1} inputs, got {len(xs)}"
            )

        y = xs[0]
        for stage, feat in zip(self.layers, xs[1:]):
            y = stage(y) * feat
        return y
    

In [27]:
# reverse=True flips the channel list, so the single 1x1 stage maps 128 -> 64 channels
s = Star2(ch=[64, 128], layer='pw_conv', reverse=True)

In [28]:
# Display the per-stage Conv modules built by Star2
s.layers

ModuleList(
  (0): Conv(
    (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
)