In [1]:
from torch import nn
import numpy as np

In [2]:
def fixup_init(layer, num_layers):
    """Re-draw ``layer.weight`` in place from a zero-mean normal distribution.

    The standard deviation is
    ``sqrt(2 / (shape[0] * prod(shape[2:]))) * num_layers ** -0.25``,
    where ``shape`` is the shape of ``layer.weight``.

    Args:
        layer: Module exposing a ``weight`` tensor.
        num_layers: Depth used to shrink the standard deviation.

    Returns:
        None; the weight tensor is modified in place.
    """
    weight_shape = layer.weight.shape
    # First weight dimension times all trailing (kernel) dimensions.
    fan = weight_shape[0] * np.prod(weight_shape[2:])
    base_std = np.sqrt(2 / fan)
    depth_scale = num_layers ** (-0.25)
    nn.init.normal_(layer.weight, mean=0, std=base_std * depth_scale)

In [4]:
def init_normalization(channels, type="bn", affine=True, one_d=False):
    """Build the normalization layer selected by ``type``.

    Args:
        channels: Number of channels/features to normalize.
        type: One of ``"bn"``, ``"ln"``, ``"in"``, ``"gn"``, ``"max"``,
            ``"none"`` or ``None``.
        affine: Whether the layer gets learnable scale/shift parameters
            (ignored for ``"max"`` and ``"none"``/``None``).
        one_d: Select the 1D variant where one exists (``"bn"``, ``"ln"``,
            ``"max"``).

    Returns:
        * ``"bn"``: ``nn.BatchNorm1d`` if ``one_d`` else ``nn.BatchNorm2d``.
        * ``"ln"``: ``nn.LayerNorm`` if ``one_d`` else a single-group
          ``nn.GroupNorm``.
        * ``"in"``: ``nn.GroupNorm`` with one group per channel.
        * ``"gn"``: ``nn.GroupNorm`` with about ``channels // 4`` groups,
          capped at 32 and at least 1.
        * ``"max"``: the ``renormalize`` callable (not an ``nn.Module``).
        * ``"none"`` / ``None``: ``nn.Identity``.

    Raises:
        AssertionError: If ``type`` is not one of the supported values.
    """
    assert type in ["bn", "ln", "in", "gn", "max", "none", None]

    if type == "bn":
        batch_norm_cls = nn.BatchNorm1d if one_d else nn.BatchNorm2d
        return batch_norm_cls(channels, affine=affine)

    if type == "ln":
        if one_d:
            return nn.LayerNorm(channels, elementwise_affine=affine)
        return nn.GroupNorm(1, channels, affine=affine)

    if type == "in":
        return nn.GroupNorm(channels, channels, affine=affine)

    if type == "gn":
        groups = max(min(32, channels // 4), 1)
        # nn.GroupNorm rejects a group count that does not divide the channel
        # count (e.g. 18 channels -> 4 groups), so step down to the nearest
        # divisor. The loop always terminates because 1 divides everything.
        while channels % groups != 0:
            groups -= 1
        return nn.GroupNorm(groups, channels, affine=affine)

    if type == "max":
        # NOTE(review): `renormalize` is not defined in any visible cell; it
        # must be in scope before this branch is used. The name is looked up
        # only when this branch runs.
        if not one_d:
            return renormalize
        return lambda x: renormalize(x, -1)

    if type == "none" or type is None:
        return nn.Identity()

In [5]:
class InvertedResidual(nn.Module):
    """Residual block: optional 1x1 expansion, 3x3 grouped conv, 1x1 projection.

    The residual branch is ``[1x1 conv -> norm -> ReLU ->] 3x3 conv -> norm ->
    ReLU -> 1x1 conv -> norm``; the bracketed expansion stage is omitted when
    ``expand_ratio == 1``. The branch output is added to a shortcut of the
    input.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        stride: 1, 2 or 3. Applied by the 3x3 conv and by the shortcut conv.
        expand_ratio: Hidden width is ``round(in_channels * expand_ratio)``.
        norm_type: Forwarded to ``init_normalization``.
        num_layers: Forwarded to ``fixup_init`` to scale the conv weights.
        groups: Groups of the 3x3 conv; a non-positive value selects one group
            per hidden channel.
        drop_prob: Probability, in training mode only, of skipping the
            residual branch and returning just the shortcut.
        bias: Whether the residual-branch convolutions have a bias term.
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio,
                 norm_type, num_layers=1, groups=-1,
                 drop_prob=0., bias=True):
        super().__init__()
        assert stride in [1, 2, 3]
        self.drop_prob = drop_prob

        hidden_dim = round(in_channels * expand_ratio)

        if groups <= 0:
            groups = hidden_dim

        # The shortcut must match the branch output in both resolution and
        # channel count. Projecting only when stride != 1 left stride-1 blocks
        # with in_channels != out_channels unable to add the two tensors.
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Conv2d(in_channels, out_channels, stride, stride)
            # num_layers=1 gives a scale factor of exactly 1 (no depth shrink).
            fixup_init(self.downsample, 1)
        else:
            self.downsample = False

        # Create every conv before re-initialising any of them so the order of
        # random draws is the same as in the original two-branch code.
        if expand_ratio == 1:
            convs = [
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=groups, bias=bias),
                # pw-linear
                nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias=bias),
            ]
        else:
            convs = [
                # pw
                nn.Conv2d(in_channels, hidden_dim, 1, 1, 0, bias=bias),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=groups, bias=bias),
                # pw-linear
                nn.Conv2d(hidden_dim, out_channels, 1, 1, 0, bias=bias),
            ]
        for conv in convs:
            fixup_init(conv, num_layers)

        # Every conv but the last is followed by norm + ReLU; the last one is
        # followed by a norm only. Module indices match the original layout.
        # NOTE(review): norm_type="max" yields a plain callable, which
        # nn.Sequential is not expected to accept -- confirm before using it.
        layers = []
        for conv in convs[:-1]:
            layers += [conv,
                       init_normalization(hidden_dim, norm_type),
                       nn.ReLU(inplace=True)]
        layers += [convs[-1], init_normalization(out_channels, norm_type)]
        self.conv = nn.Sequential(*layers)

        # Zero the final norm's scale so the residual branch starts at zero.
        # Guard on the weight actually existing: nn.Identity (norm_type "none"
        # or None) has no weight, and affine-free norms expose weight=None.
        last_norm = self.conv[-1]
        if getattr(last_norm, "weight", None) is not None:
            nn.init.constant_(last_norm.weight, 0)

    def forward(self, x):
        """Return ``shortcut(x) + branch(x)``, or only the shortcut if dropped."""
        if self.downsample:
            identity = self.downsample(x)
        else:
            identity = x
        # Drop the whole residual branch for this call with probability
        # drop_prob, in training mode only.
        if self.training and np.random.uniform() < self.drop_prob:
            return identity
        return identity + self.conv(x)


class Residual(InvertedResidual):
    """``InvertedResidual`` with ``groups`` fixed to 1.

    All positional and keyword arguments are forwarded unchanged to
    ``InvertedResidual.__init__``; supplying ``groups`` as well raises
    ``TypeError``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, groups=1, **kwargs)


class ResnetCNN(nn.Module):
    """Stack of residual-block groups applied one after another.

    Group ``i`` maps ``self.depths[i]`` channels to ``self.depths[i + 1]``
    channels. Its first block uses ``strides[i]`` and the remaining
    ``blocks_per_group - 1`` blocks use stride 1.

    Args:
        input_channels: Channels of the input tensor.
        depths: Output channels of each group (any sequence).
        strides: Stride of the first block of each group; needs at least
            ``len(depths)`` entries.
        blocks_per_group: Number of residual blocks in each group.
        norm_type: Forwarded to every block as ``norm_type``.
        resblock: Block class, called as ``resblock(in_channels, depth,
            expand_ratio=..., stride=..., norm_type=..., num_layers=...)``.
        expand_ratio: Forwarded to every block as ``expand_ratio``.
    """

    def __init__(self, input_channels,
                 depths=(16, 32, 64),
                 strides=(3, 2, 2),
                 blocks_per_group=3,
                 norm_type="bn",
                 resblock=InvertedResidual,
                 expand_ratio=2,):
        super().__init__()
        # list(...) is required: the default `depths` is a tuple, and
        # `[input_channels] + <tuple>` raises TypeError.
        self.depths = [input_channels] + list(depths)
        self.resblock = resblock
        self.expand_ratio = expand_ratio
        self.blocks_per_group = blocks_per_group
        self.norm_type = norm_type
        # Total block count, passed to every block as `num_layers`.
        self.num_layers = self.blocks_per_group * len(depths)

        groups = []
        for i in range(len(depths)):
            groups.append(self._make_layer(self.depths[i],
                                           self.depths[i + 1],
                                           strides[i]))
        self.layers = nn.Sequential(*groups)
        self.train()

    def _make_layer(self, in_channels, depth, stride,):
        """Build one group: a strided block followed by stride-1 blocks."""
        blocks = [self.resblock(in_channels, depth,
                                expand_ratio=self.expand_ratio,
                                stride=stride,
                                norm_type=self.norm_type,
                                num_layers=self.num_layers,)]

        for _ in range(1, self.blocks_per_group):
            blocks.append(self.resblock(depth, depth,
                                        expand_ratio=self.expand_ratio,
                                        stride=1,
                                        norm_type=self.norm_type,
                                        num_layers=self.num_layers,))

        return nn.Sequential(*blocks)

    @property
    def local_layer_depth(self):
        """Second-to-last entry of ``self.depths``."""
        return self.depths[-2]

    def forward(self, inputs):
        """Run ``inputs`` through all groups in order."""
        return self.layers(inputs)


In [11]:
import torchinfo
import torch

In [18]:
# Residual encoder for 4-channel inputs: three groups of three blocks each.
model = ResnetCNN(
    input_channels=4,
    depths=[16, 32, 64],
    strides=(3, 2, 2),
    blocks_per_group=3,
    norm_type="bn",
    resblock=InvertedResidual,
    expand_ratio=2,
)
# Forward pass on one random 4x84x84 input; the result is discarded.
model(torch.randn(1, 4, 84, 84))
# Per-layer output shapes and parameter counts.
torchinfo.summary(model, input_data=torch.randn(1, 4, 84, 84), depth=10)

Layer (type:depth-idx)                        Output Shape              Param #
ResnetCNN                                     [1, 64, 7, 7]             --
├─Sequential: 1-1                             [1, 64, 7, 7]             --
│    └─Sequential: 2-1                        [1, 16, 28, 28]           --
│    │    └─InvertedResidual: 3-1             [1, 16, 28, 28]           --
│    │    │    └─Conv2d: 4-1                  [1, 16, 28, 28]           592
│    │    │    └─Sequential: 4-2              [1, 16, 28, 28]           --
│    │    │    │    └─Conv2d: 5-1             [1, 8, 84, 84]            40
│    │    │    │    └─BatchNorm2d: 5-2        [1, 8, 84, 84]            16
│    │    │    │    └─ReLU: 5-3               [1, 8, 84, 84]            --
│    │    │    │    └─Conv2d: 5-4             [1, 8, 28, 28]            80
│    │    │    │    └─BatchNorm2d: 5-5        [1, 8, 28, 28]            16
│    │    │    │    └─ReLU: 5-6               [1, 8, 28, 28]            --
│    │    │    │   

Bad pipe message: %s [b'WHW`\xb7\xa8Y\xba=T\rW4~4\xb9_: [h\x95e\x929\x10(\xd6"\xee\x84{}B\xe4\xd2\xb3IL\xc3\xa0\x19v\x97t\xba\xde']
Bad pipe message: %s [b'i\xc5\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 \xa2']
Bad pipe message: %s [b"\xd6\xd9\x92\x12\xeb+t\x1e\x9b'\x10!\xde\xc7\t9\xa6\xbb Q\x8b\xf0,\x91\x145\xc9\x04\x8a&\x83y\xfa\x96^\xc19\xfcW\xdcs/kp+\xab\xbb\xe64\x83!\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00"]
Bad pipe message: %s [b'\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x

In [16]:
# NOTE(review): this cell's execution count (16) is lower than that of the
# cell that builds the ResnetCNN (18), and the saved output below is a tensor
# rather than a module repr -- presumably `model` was bound to something else
# when this last ran. Re-run after a kernel restart to refresh the output.
model

tensor([[[[ 1.0370,  0.2577, -0.6504,  ...,  0.8322,  0.0262,  0.4602],
          [ 0.2142,  0.6859,  0.6797,  ..., -0.0790,  1.0790, -0.0139],
          [-0.0260, -0.0173,  0.1069,  ...,  1.3578,  0.0800, -0.1256],
          ...,
          [-0.6908,  0.7027,  0.6468,  ...,  0.1948,  0.0255,  0.2132],
          [ 0.6789, -0.8355, -0.4995,  ...,  0.8612, -0.7073,  1.1848],
          [ 0.2609,  0.7734, -0.3453,  ..., -0.5624,  1.2338,  0.4828]],

         [[ 0.2159,  0.9686, -1.3405,  ...,  0.1267,  0.6852, -0.7005],
          [ 1.4288, -0.7463,  0.2360,  ...,  0.3765,  0.4661,  0.0409],
          [ 0.0570,  0.3688,  0.7673,  ...,  1.0083, -1.6013, -0.2023],
          ...,
          [ 0.8709, -0.3425,  0.8199,  ..., -0.5406,  1.2382, -0.3373],
          [-0.1634,  0.3192, -0.6712,  ...,  0.6687,  1.3732,  0.2786],
          [-1.2140,  0.0601, -0.3932,  ...,  0.2761,  0.0425, -0.2627]],

         [[ 0.4881,  1.1397, -0.1119,  ...,  0.5542,  0.1194,  0.3526],
          [ 0.7851,  0.0861,  

In [19]:
import timm

In [28]:
# ConvNeXt configured for 4-channel inputs. The fourth stage is declared with
# depth 0 and is sliced off again below, leaving three stages.
model = timm.models.ConvNeXt(
    in_chans=4,
    global_pool='avg',
    output_stride=32,
    depths=(3, 3, 3, 0),
    dims=(16, 32, 64, 64),
    kernel_sizes=7,
    stem_type='patch',
    patch_size=3,  # TODO calculate and make sure it is suitable
    conv_mlp=False,
    act_layer='gelu',
    norm_layer=None,
    drop_rate=0.0,
    drop_path_rate=0.0,
)

# Swap every part of the classification head for a pass-through module.
for head_part in ("global_pool", "norm", "flatten", "fc", "dropout"):
    setattr(model.head, head_part, nn.Identity())

# Drop the last stage.
model.stages = model.stages[:-1]

In [29]:
# Per-layer output shapes and parameter counts of the current `model` for one
# random 4x84x84 input.
torchinfo.summary(
    model,
    input_data=torch.randn(1, 4, 84, 84),
    depth=10,
)

Layer (type:depth-idx)                        Output Shape              Param #
ConvNeXt                                      [1, 64, 7, 7]             --
├─Sequential: 1-1                             [1, 16, 28, 28]           --
│    └─Conv2d: 2-1                            [1, 16, 28, 28]           592
│    └─LayerNorm2d: 2-2                       [1, 16, 28, 28]           32
├─Sequential: 1-2                             [1, 64, 7, 7]             --
│    └─ConvNeXtStage: 2-3                     [1, 16, 28, 28]           --
│    │    └─Identity: 3-1                     [1, 16, 28, 28]           --
│    │    └─Sequential: 3-2                   [1, 16, 28, 28]           --
│    │    │    └─ConvNeXtBlock: 4-1           [1, 16, 28, 28]           16
│    │    │    │    └─Conv2d: 5-1             [1, 16, 28, 28]           800
│    │    │    │    └─LayerNorm: 5-2          [1, 28, 28, 16]           32
│    │    │    │    └─Mlp: 5-3                [1, 28, 28, 16]           --
│    │    │    │  