In [1]:
from torch import nn
import numpy as np

In [2]:
def fixup_init(layer, num_layers):
    """Re-draw ``layer.weight`` in place from a zero-mean normal distribution.

    The standard deviation is ``sqrt(2 / (shape[0] * prod(shape[2:])))``,
    taken from the weight tensor's own shape, multiplied by
    ``num_layers ** -0.25``.

    Args:
        layer: object exposing a ``weight`` tensor (this notebook passes
            ``nn.Conv2d`` instances).
        num_layers: value used for the ``num_layers ** -0.25`` scaling factor.
    """
    weight = layer.weight
    # Dimension 0 times every dimension from index 2 onwards; dimension 1 is unused.
    fan = weight.shape[0] * np.prod(weight.shape[2:])
    base_std = np.sqrt(2 / fan)
    depth_scale = num_layers ** (-0.25)
    nn.init.normal_(weight, mean=0, std=base_std * depth_scale)

In [3]:
def init_normalization(channels, type="bn", affine=True, one_d=False):
    """Build the normalisation layer selected by ``type``.

    Args:
        channels: number of channels/features handed to the layer constructor.
        type: one of ``"bn"``, ``"ln"``, ``"in"``, ``"gn"``, ``"max"``,
            ``"none"`` or ``None``.
        affine: forwarded as ``affine`` (``elementwise_affine`` for
            ``nn.LayerNorm``) to the constructed layer.
        one_d: selects the 1-D variant for ``"bn"``, ``"ln"`` and ``"max"``;
            the other types ignore it.

    Returns:
        * ``"bn"``: ``nn.BatchNorm1d`` if ``one_d`` else ``nn.BatchNorm2d``.
        * ``"ln"``: ``nn.LayerNorm`` if ``one_d`` else a single-group
          ``nn.GroupNorm``.
        * ``"in"``: ``nn.GroupNorm`` with one group per channel.
        * ``"gn"``: ``nn.GroupNorm`` whose group count is the largest divisor
          of ``channels`` not exceeding ``max(min(32, channels // 4), 1)``.
        * ``"max"``: the ``renormalize`` callable (wrapped to pass ``-1`` as
          second argument when ``one_d``).
        * ``"none"`` / ``None``: ``nn.Identity``.

    Raises:
        AssertionError: if ``type`` is not one of the accepted values.
    """
    assert type in ["bn", "ln", "in", "gn", "max", "none", None]
    if type == "bn":
        if one_d:
            return nn.BatchNorm1d(channels, affine=affine)
        else:
            return nn.BatchNorm2d(channels, affine=affine)
    elif type == "ln":
        if one_d:
            return nn.LayerNorm(channels, elementwise_affine=affine)
        else:
            return nn.GroupNorm(1, channels, affine=affine)
    elif type == "in":
        return nn.GroupNorm(channels, channels, affine=affine)
    elif type == "gn":
        target_groups = max(min(32, channels//4), 1)
        # nn.GroupNorm rejects a group count that does not divide the channel
        # count (e.g. 14 channels with 3 groups), so step down to the nearest
        # divisor. Counts that already divide evenly are left untouched.
        groups = next(g for g in range(target_groups, 0, -1) if channels % g == 0)
        return nn.GroupNorm(groups, channels, affine=affine)
    elif type == "max":
        # NOTE(review): ``renormalize`` is not defined in the visible notebook
        # cells; it has to be in scope before this branch is used — confirm.
        if not one_d:
            return renormalize
        else:
            return lambda x: renormalize(x, -1)
    elif type == "none" or type is None:
        return nn.Identity()

In [4]:
class InvertedResidual(nn.Module):
    """Residual block: optional 1x1 expansion, 3x3 (grouped) conv, 1x1 projection.

    The main branch output is added to a shortcut branch in ``forward``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        stride: stride of the 3x3 convolution; must be 1, 2 or 3. When it is
            not 1 the shortcut is a convolution whose kernel size and stride
            both equal ``stride``; otherwise the input itself is the shortcut.
        expand_ratio: the hidden width is ``round(in_channels * expand_ratio)``.
            A ratio of exactly 1 skips the 1x1 expansion stage.
        norm_type: forwarded to ``init_normalization`` for every norm layer.
        num_layers: forwarded to ``fixup_init`` for every main-branch conv.
        groups: groups of the 3x3 convolution; values <= 0 select
            ``hidden_dim`` groups.
        drop_prob: probability, in training mode, that ``forward`` returns
            only the shortcut branch.
        bias: whether the main-branch convolutions carry a bias term.

    NOTE(review): with ``stride == 1`` the shortcut is the raw input, so the
    addition in ``forward`` presumably needs ``in_channels`` to be compatible
    with ``out_channels`` — confirm against callers.
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio,
                 norm_type, num_layers=1, groups=-1,
                 drop_prob=0., bias=True):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2, 3]
        self.drop_prob = drop_prob

        hidden_dim = round(in_channels * expand_ratio)

        if groups <= 0:
            groups = hidden_dim

        conv = nn.Conv2d

        if stride != 1:
            # Shortcut convolution: kernel size and stride are both ``stride``.
            self.downsample = nn.Conv2d(in_channels, out_channels, stride, stride)
            nn.init.normal_(self.downsample.weight, mean=0, std=
                            np.sqrt(2 / (self.downsample.weight.shape[0] *
                            np.prod(self.downsample.weight.shape[2:]))))
        else:
            # ``False`` marks "no shortcut convolution"; forward() tests truthiness.
            self.downsample = False

        if expand_ratio == 1:
            conv1 = conv(hidden_dim, hidden_dim, 3, stride, 1, groups=groups, bias=bias)
            conv2 = conv(hidden_dim, out_channels, 1, 1, 0, bias=bias)
            fixup_init(conv1, num_layers)
            fixup_init(conv2, num_layers)
            self.conv = nn.Sequential(
                # dw
                conv1,
                init_normalization(hidden_dim, norm_type),
                nn.ReLU(inplace=True),
                # pw-linear
                conv2,
                init_normalization(out_channels, norm_type),
            )
        else:
            conv1 = conv(in_channels, hidden_dim, 1, 1, 0, bias=bias)
            conv2 = conv(hidden_dim, hidden_dim, 3, stride, 1, groups=groups, bias=bias)
            conv3 = conv(hidden_dim, out_channels, 1, 1, 0, bias=bias)
            fixup_init(conv1, num_layers)
            fixup_init(conv2, num_layers)
            fixup_init(conv3, num_layers)
            self.conv = nn.Sequential(
                # pw
                conv1,
                init_normalization(hidden_dim, norm_type),
                nn.ReLU(inplace=True),
                # dw
                conv2,
                init_normalization(hidden_dim, norm_type),
                nn.ReLU(inplace=True),
                # pw-linear
                conv3,
                init_normalization(out_channels, norm_type)
            )

        # Zero the scale of the final normalisation layer. The check is on the
        # layer itself rather than on ``norm_type`` so that both branches skip
        # layers without a learnable weight (``nn.Identity`` for "none"/None)
        # instead of raising AttributeError.
        last_norm = self.conv[-1]
        if getattr(last_norm, "weight", None) is not None:
            nn.init.constant_(last_norm.weight, 0)

    def forward(self, x):
        """Return ``shortcut(x) + conv(x)``, or only the shortcut when dropped.

        In training mode the main branch is skipped with probability
        ``drop_prob``; the draw uses NumPy's global RNG.
        """
        if self.downsample:
            identity = self.downsample(x)
        else:
            identity = x
        if self.training and np.random.uniform() < self.drop_prob:
            return identity
        else:
            return identity + self.conv(x)


class Residual(InvertedResidual):
    """``InvertedResidual`` whose 3x3 convolution is always built with ``groups=1``."""

    def __init__(self, *args, **kwargs):
        # ``groups`` is pinned here, so callers must not supply it themselves:
        # a duplicate keyword raises TypeError.
        super(Residual, self).__init__(*args, groups=1, **kwargs)


class ResnetCNN(nn.Module):
    """Stack of residual-block groups applied one after another.

    One group is built per entry of ``depths``. Each group holds
    ``blocks_per_group`` blocks: the first maps the previous width to the
    group's width using the group's stride, the rest keep the width and use
    stride 1.

    Args:
        input_channels: channels of the input tensor.
        depths: output width of each group (any sequence; list or tuple).
        strides: stride of the first block of each group; indexed in step
            with ``depths``.
        blocks_per_group: number of blocks in every group.
        norm_type: forwarded to every block as ``norm_type``.
        resblock: block class, called with ``(in_channels, out_channels)``
            positionally plus ``expand_ratio``, ``stride``, ``norm_type`` and
            ``num_layers`` keywords.
        expand_ratio: forwarded to every block.
    """

    def __init__(self, input_channels,
                 depths=(16, 32, 64),
                 strides=(3, 2, 2),
                 blocks_per_group=3,
                 norm_type="bn",
                 resblock=InvertedResidual,
                 expand_ratio=2,):
        super(ResnetCNN, self).__init__()
        # list(...) is required: ``[x] + tuple`` raises TypeError, which made
        # the tuple default for ``depths`` unusable.
        self.depths = [input_channels] + list(depths)
        self.resblock = resblock
        self.expand_ratio = expand_ratio
        self.blocks_per_group = blocks_per_group
        self.norm_type = norm_type
        # Total block count, handed to every block as ``num_layers``.
        self.num_layers = self.blocks_per_group*len(depths)
        groups = [
            self._make_layer(self.depths[i], self.depths[i+1], strides[i])
            for i in range(len(depths))
        ]
        self.layers = nn.Sequential(*groups)
        self.train()

    def _make_layer(self, in_channels, depth, stride,):
        """Build one group: a strided block followed by stride-1 blocks."""
        blocks = [self.resblock(in_channels, depth,
                                expand_ratio=self.expand_ratio,
                                stride=stride,
                                norm_type=self.norm_type,
                                num_layers=self.num_layers,)]

        for i in range(1, self.blocks_per_group):
            blocks.append(self.resblock(depth, depth,
                                        expand_ratio=self.expand_ratio,
                                        stride=1,
                                        norm_type=self.norm_type,
                                        num_layers=self.num_layers,))

        return nn.Sequential(*blocks)

    @property
    def local_layer_depth(self):
        """Second-to-last entry of ``[input_channels, *depths]``."""
        return self.depths[-2]

    def forward(self, inputs):
        """Run ``inputs`` through all groups in order."""
        return self.layers(inputs)


In [2]:
import torchinfo
import torch

In [6]:
# Build the residual encoder for a 4-channel input with three groups.
model = ResnetCNN(
    4,
    depths=[16, 32, 64],
    strides=(3, 2, 2),
    blocks_per_group=3,
    norm_type="bn",
    resblock=InvertedResidual,
    expand_ratio=2,
)
# Smoke-test forward pass; the result is discarded because this is not the
# cell's last expression.
model(torch.randn(1, 4, 84, 84))
# Per-layer output shapes and parameter counts for a (1, 4, 84, 84) input.
torchinfo.summary(model, input_data=torch.randn(1, 4, 84, 84), depth=10)

Layer (type:depth-idx)                        Output Shape              Param #
ResnetCNN                                     [1, 64, 7, 7]             --
├─Sequential: 1-1                             [1, 64, 7, 7]             --
│    └─Sequential: 2-1                        [1, 16, 28, 28]           --
│    │    └─InvertedResidual: 3-1             [1, 16, 28, 28]           --
│    │    │    └─Conv2d: 4-1                  [1, 16, 28, 28]           592
│    │    │    └─Sequential: 4-2              [1, 16, 28, 28]           --
│    │    │    │    └─Conv2d: 5-1             [1, 8, 84, 84]            40
│    │    │    │    └─BatchNorm2d: 5-2        [1, 8, 84, 84]            16
│    │    │    │    └─ReLU: 5-3               [1, 8, 84, 84]            --
│    │    │    │    └─Conv2d: 5-4             [1, 8, 28, 28]            80
│    │    │    │    └─BatchNorm2d: 5-5        [1, 8, 28, 28]            16
│    │    │    │    └─ReLU: 5-6               [1, 8, 28, 28]            --
│    │    │    │   

In [7]:
model

ResnetCNN(
  (layers): Sequential(
    (0): Sequential(
      (0): InvertedResidual(
        (downsample): Conv2d(4, 16, kernel_size=(3, 3), stride=(3, 3))
        (conv): Sequential(
          (0): Conv2d(4, 8, kernel_size=(1, 1), stride=(1, 1))
          (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(8, 8, kernel_size=(3, 3), stride=(3, 3), padding=(1, 1), groups=8)
          (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
          (6): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          (7): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): InvertedResidual(
        (conv): Sequential(
          (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2):

In [2]:
import timm

In [3]:
from common.networks import ConvNeXtImpala, ImpalaCNNLarge, ImpalaNeXtCNNLarge

In [10]:
# Instantiate ConvNeXtImpala (imported from common.networks) with positional arguments.
# NOTE(review): presumably (in_depth=4, actions=12, linear_layer=nn.Linear, model_size=2),
# by analogy with the ImpalaNeXtCNNLarge keyword call later in this notebook — confirm.
# This rebinds `model`, replacing whatever an earlier cell assigned to it.
model = ConvNeXtImpala(4, 12, nn.Linear, 2)

In [11]:
# Forward pass on a random tensor of shape (1, 4, 84, 84). As the cell's last
# expression, the returned tensor is shown as the cell output.
model(torch.randn(1, 4, 84,84))

tensor([[-0.1186, -0.0862,  0.0036, -0.0597, -0.0614, -0.0916, -0.0497, -0.0922,
          0.0408, -0.0544, -0.0239, -0.0627]], grad_fn=<AddBackward0>)

In [12]:
# Per-layer summary (output shapes and parameter counts) of the current `model`
# for a random (1, 4, 84, 84) input, expanded down to nesting depth 10.
torchinfo.summary(model, input_data=torch.randn(1, 4, 84,84), depth=10)

Layer (type:depth-idx)                             Output Shape              Param #
ConvNeXtImpala                                     [1, 12]                   --
├─ConvNeXt: 1-1                                    [1, 128, 7, 7]            --
│    └─Sequential: 2-1                             [1, 32, 28, 28]           --
│    │    └─Conv2d: 3-1                            [1, 32, 28, 28]           1,184
│    │    └─LayerNorm2d: 3-2                       [1, 32, 28, 28]           64
│    └─Sequential: 2-2                             [1, 128, 7, 7]            --
│    │    └─ConvNeXtStage: 3-3                     [1, 32, 28, 28]           --
│    │    │    └─Identity: 4-1                     [1, 32, 28, 28]           --
│    │    │    └─Sequential: 4-2                   [1, 32, 28, 28]           --
│    │    │    │    └─ConvNeXtBlock: 5-1           [1, 32, 28, 28]           32
│    │    │    │    │    └─Conv2d: 6-1             [1, 32, 28, 28]           1,600
│    │    │    │    │    └─La

Bad pipe message: %s [b"\xa3\x0b\x8cj\x94\x06?\xcc\n\xd2\xcdTf\x82@\xa2C\xa6\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00<\x005\x00/"]
Bad pipe message: %s [b'\xf4Kf"d\x9d\xd5a\xb5\x84t\x97%H\xa9ylU\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0s\xc0w\x00\xc4\x00\xc3\xc0#\xc0\'\x00g\x00@\xc0r']
Bad pipe message: %s [b'\x00\xbe\x00\xbd\xc0\n\xc0\x14\x009\x008\x00\x88\x00\x87\xc0\t\xc0\x13\x003\x002\x00\x9a\x00\x99\x00E\x00D\xc0\x07\xc0\x11\xc0\x08\xc0\x12\x00\x16\x00\x13\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc

In [5]:
# Instantiate ImpalaCNNLarge (imported from common.networks) with positional arguments.
# NOTE(review): argument meanings are not visible in this notebook; presumably the
# same (in_depth, actions, linear_layer, model_size) order as ImpalaNeXtCNNLarge — confirm.
# This rebinds `model`.
model = ImpalaCNNLarge(4, 12, nn.Linear, 2)
# Per-layer summary for a random (1, 4, 84, 84) input, expanded to depth 10.
torchinfo.summary(model, input_data=torch.randn(1, 4, 84,84), depth=10)

Layer (type:depth-idx)                   Output Shape              Param #
ImpalaCNNLarge                           [1, 12]                   --
├─Sequential: 1-1                        [1, 64, 11, 11]           --
│    └─ImpalaCNNBlock: 2-1               [1, 32, 42, 42]           --
│    │    └─Conv2d: 3-1                  [1, 32, 84, 84]           1,184
│    │    └─MaxPool2d: 3-2               [1, 32, 42, 42]           --
│    │    └─ImpalaCNNResidual: 3-3       [1, 32, 42, 42]           --
│    │    │    └─ReLU: 4-1               [1, 32, 42, 42]           --
│    │    │    └─Conv2d: 4-2             [1, 32, 42, 42]           9,248
│    │    │    └─ReLU: 4-3               [1, 32, 42, 42]           --
│    │    │    └─Conv2d: 4-4             [1, 32, 42, 42]           9,248
│    │    └─ImpalaCNNResidual: 3-4       [1, 32, 42, 42]           --
│    │    │    └─ReLU: 4-5               [1, 32, 42, 42]           --
│    │    │    └─Conv2d: 4-6             [1, 32, 42, 42]           9,248
│  

In [5]:
# Instantiate ImpalaNeXtCNNLarge (imported from common.networks); this rebinds `model`.
# NOTE(review): the repeated "both" lines in the cell output look like a debug
# print of `activation_pos` from inside the library — confirm.
model = ImpalaNeXtCNNLarge(
    in_depth=4,
    actions=12,
    linear_layer=nn.Linear,
    model_size=2,
    spectral_norm=False,
    stem='orig',
    convnext_downsampling=False,
    layer_norm=False,
    activation_pos='both',
)
# Per-layer summary for a random (1, 4, 84, 84) input, expanded to depth 10.
torchinfo.summary(model, input_data=torch.randn(1, 4, 84, 84), depth=10)

both
both
both
both
both
both


Layer (type:depth-idx)                        Output Shape              Param #
ImpalaNeXtCNNLarge                            [1, 12]                   --
├─ImpalaNeXtDownsample: 1-1                   [1, 32, 42, 42]           --
│    └─Identity: 2-1                          [1, 4, 84, 84]            --
│    └─Conv2d: 2-2                            [1, 32, 84, 84]           6,304
│    └─MaxPool2d: 2-3                         [1, 32, 42, 42]           --
├─Sequential: 1-2                             [1, 64, 11, 11]           --
│    └─ImpalaNeXtCNNBlock: 2-4                [1, 32, 42, 42]           --
│    │    └─ImpalaNeXtCNNResidual: 3-1        [1, 32, 42, 42]           --
│    │    │    └─GELU: 4-1                    [1, 32, 42, 42]           --
│    │    │    └─Conv2d: 4-2                  [1, 32, 42, 42]           50,208
│    │    │    └─Identity: 4-3                [1, 32, 42, 42]           --
│    │    │    └─GELU: 4-4                    [1, 32, 42, 42]           --
│    │    │  