# Experimenting with Architectures

#### Packages

In [3]:
# Packages
import numpy as np
import warnings
import math

# Machine learning
import torch
from torch import nn
from torch.nn import functional as F

### Experimenting with Kernel Size

In [4]:
# Kernel sizes on modified OSNet (type 1)
# Each row holds five consecutive powers of two and every following row starts
# one power lower, so rows overlap and could be re-used to make up the six
# bottleneck layers.
for shift in range(4):
    exponents = range(4 - shift, 9 - shift)
    print([2 ** exponent for exponent in exponents])

[16, 32, 64, 128, 256]
[8, 16, 32, 64, 128]
[4, 8, 16, 32, 64]
[2, 4, 8, 16, 32]


In [5]:
# Kernel sizes on modified OSNet (type 2)
# Type 1 only contains powers of two, so nothing falls between e.g. 128 and 256.
# Type 2 takes the first five multiples of each power of two instead, which
# might be better for feature extraction.
for exponent in np.arange(6, 0, -1):
    base = 2 ** exponent
    print([base * multiple for multiple in range(1, 6)])

[64, 128, 192, 256, 320]
[32, 64, 96, 128, 160]
[16, 32, 48, 64, 80]
[8, 16, 24, 32, 40]
[4, 8, 12, 16, 20]
[2, 4, 6, 8, 10]


In [6]:
# Kernel sizes on modified OSNet (type 2, odd variant)
# The type-2 sizes are all even. Adding 1 to each of them makes every kernel
# size odd, so the smallest kernel considered becomes 3 instead of 2.
# Open question from the experiment: is "just add 1 everywhere" good enough?
for exponent in np.arange(6, 0, -1):
    base = 2 ** exponent
    print([base * multiple + 1 for multiple in range(1, 6)])

[65, 129, 193, 257, 321]
[33, 65, 97, 129, 161]
[17, 33, 49, 65, 81]
[9, 17, 25, 33, 41]
[5, 9, 13, 17, 21]
[3, 5, 7, 9, 11]


In [7]:
# We could also look at using prime number kernel sizes
# Refer: https://openreview.net/pdf?id=PDYs7Z2XFGv

def getPrimes(limit):
    """Return all primes strictly below ``limit`` in ascending order.

    Uses the Sieve of Eratosthenes: walk the integers from 2 upwards and,
    whenever one is still marked as a candidate, record it and cross out its
    multiples.

    Args:
        limit: Exclusive upper bound of the search.

    Returns:
        A list of primes ``p`` with ``2 <= p < limit`` (empty when
        ``limit <= 2``).
    """
    # is_candidate[k] stays True until k is crossed out as a multiple
    is_candidate = [True] * limit
    found = []
    for value in range(2, limit):
        if not is_candidate[value]:
            continue
        found.append(value)
        # Multiples below value**2 were already crossed out by smaller primes
        for multiple in range(value * value, limit, value):
            is_candidate[multiple] = False
    return found

def get_closest_prime_kernel_size(kernel_size):
    """Return the prime number nearest to ``kernel_size``.

    Candidates are the primes below ``kernel_size + 100`` as produced by
    ``getPrimes``. Because that list is ascending and the first minimum wins,
    a tie between a lower and a higher prime resolves to the lower one.

    Args:
        kernel_size: The kernel size to approximate with a prime.

    Returns:
        The closest prime, or 0 when there is no candidate prime at all.
    """
    candidates = getPrimes(kernel_size + 100)
    return min(candidates, key=lambda prime: abs(kernel_size - prime), default=0)

# Sanity check on 320, the largest size in the type-2 table
kernel_size = 320
closest_prime = get_closest_prime_kernel_size(kernel_size)
print(f"{closest_prime} is the closest prime kernel size")

317 is the closest prime kernel size


In [8]:
# Kernel sizes on modified OSNet (type 2p)
# Same grid as type 2, with every entry snapped to its closest prime
for exponent in np.arange(6, 0, -1):
    base = 2 ** exponent
    print([get_closest_prime_kernel_size(base * multiple) for multiple in range(1, 6)])

[61, 127, 191, 257, 317]
[31, 61, 97, 127, 157]
[17, 31, 47, 61, 79]
[7, 17, 23, 31, 41]
[3, 7, 11, 17, 19]
[2, 3, 5, 7, 11]


### Modifying the OSnet Architecture

GitHub link: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/models/osnet_ain.py#L309

Link to Paper: https://arxiv.org/abs/1905.00953

#### Basic Layers

In [9]:
# MODIFIED
class ConvLayer(nn.Module):
    """1-D convolution followed by normalisation and a SiLU activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        groups: Number of convolution groups.
        IN: If True normalise with ``InstanceNorm1d`` (affine), otherwise
            with ``BatchNorm1d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        groups=1,
        IN=False
    ):
        super().__init__()
        # The convolution carries no bias; a normalisation layer follows it
        conv_options = dict(stride=stride, padding=padding, bias=False, groups=groups)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, **conv_options)
        # The attribute is called ``bn`` for either normalisation choice
        self.bn = (
            nn.InstanceNorm1d(out_channels, affine=True)
            if IN
            else nn.BatchNorm1d(out_channels)
        )
        # Swish (SiLU) is used here in place of ReLU
        self.silu = nn.SiLU()

    def forward(self, x):
        """Apply conv -> norm -> SiLU to ``x``."""
        return self.silu(self.bn(self.conv(x)))

In [10]:
# MODIFIED
class Conv1x1(nn.Module):
    """Pointwise (kernel size 1) 1-D convolution + batch norm + SiLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride.
        groups: Number of convolution groups.
    """

    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super().__init__()
        pointwise_options = dict(
            kernel_size=1, stride=stride, padding=0, bias=False, groups=groups
        )
        self.conv = nn.Conv1d(in_channels, out_channels, **pointwise_options)
        self.bn = nn.BatchNorm1d(out_channels)
        self.silu = nn.SiLU()

    def forward(self, x):
        """Apply conv -> batch norm -> SiLU to ``x``."""
        return self.silu(self.bn(self.conv(x)))

In [11]:
# MODIFIED
class Conv1x1Linear(nn.Module):
    """Pointwise 1-D convolution with optional batch norm and no activation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride.
        bn: If True a ``BatchNorm1d`` follows the convolution; if False the
            ``bn`` attribute is ``None`` and the raw convolution is returned.
    """

    def __init__(self, in_channels, out_channels, stride=1, bn=True):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False
        )
        self.bn = nn.BatchNorm1d(out_channels) if bn else None

    def forward(self, x):
        """Apply the convolution and, when configured, the batch norm."""
        projected = self.conv(x)
        if self.bn is None:
            return projected
        return self.bn(projected)

In [12]:
# MODIFIED
class LightConvNxN(nn.Module):
    """Lightweight convolution of arbitrary kernel size.

    A pointwise convolution (no activation) changes the channel count, then a
    depthwise convolution (``groups=out_channels``) with the requested kernel
    size follows, and finally batch norm and SiLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv1 = nn.Conv1d(
            in_channels, out_channels, 1, stride=1, padding=0, bias=False
        )
        # Kernel size 3 keeps its explicit padding of 1; every other size
        # relies on padding='same' to keep the sequence length unchanged.
        depthwise_padding = 1 if kernel_size == 3 else 'same'
        self.conv2 = nn.Conv1d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=depthwise_padding,
            bias=False,
            groups=out_channels
        )
        self.bn = nn.BatchNorm1d(out_channels)
        self.silu = nn.SiLU()

    def forward(self, x):
        """Apply pointwise conv -> depthwise conv -> batch norm -> SiLU."""
        mixed = self.conv2(self.conv1(x))
        return self.silu(self.bn(mixed))

In [13]:
# MODIFIED
class LightConvStream(nn.Module):
    """A chain of ``depth`` LightConvNxN layers sharing one kernel size.

    Stacking layers enlarges the receptive field. The first layer maps
    ``in_channels`` to ``out_channels``; the remaining ``depth - 1`` layers
    keep ``out_channels``.

    Args:
        in_channels: Channels entering the first layer.
        out_channels: Channels produced by every layer.
        depth: Number of stacked layers; must be at least 1.
        kernel_size: Kernel size handed to every LightConvNxN.
    """

    def __init__(self, in_channels, out_channels, depth, kernel_size):
        super().__init__()
        assert depth >= 1, 'depth must be equal to or larger than 1, but got {}'.format(
            depth
        )
        # Input width of each layer in the chain, in order
        input_widths = [in_channels] + [out_channels] * (depth - 1)
        self.layers = nn.Sequential(
            *(LightConvNxN(width, out_channels, kernel_size) for width in input_widths)
        )

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.layers(x)

#### Building blocks for omni-scale feature learning

In [14]:
# MODIFIED
class ChannelGate(nn.Module):
    """A mini-network that generates channel-wise gates conditioned on the input.

    The input is averaged over its last dimension, squeezed to a bottleneck of
    ``in_channels // reduction`` channels (never fewer than one), passed
    through SiLU, expanded to ``num_gates`` channels and finally through the
    chosen gate activation. The gates are either returned or multiplied with
    the input.

    Args:
        in_channels: Number of channels of the input tensor.
        num_gates: Number of gates to produce; defaults to ``in_channels``.
            When gates are applied (``return_gates=False``) this has to
            broadcast against the input's channel dimension.
        return_gates: If True return the gates instead of the gated input.
        gate_activation: One of ``'sigmoid'``, ``'relu'``, ``'silu'`` or
            ``'linear'`` (no activation).
        reduction: Divisor used to size the bottleneck.

    Raises:
        RuntimeError: If ``gate_activation`` is not one of the known names.
    """

    def __init__(
        self,
        in_channels,
        num_gates=None,
        return_gates=False,
        gate_activation='sigmoid',
        reduction=8
    ):
        super(ChannelGate, self).__init__()
        if num_gates is None:
            num_gates = in_channels
        self.return_gates = return_gates
        # in_channels // reduction is 0 whenever in_channels < reduction, which
        # would build a zero-channel bottleneck. Clamp to one channel so narrow
        # inputs still get a working gate; wider inputs are unaffected.
        hidden_channels = max(in_channels // reduction, 1)
        # Global average pooling here is only used to get channel-wise weights
        # Each scale learnt is given an individual weight
        # The vector of weights is then multiplied with the different scales to get weighted scales
        # This does not get rid of localisation features in any way
        self.global_avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Conv1d(
            in_channels,
            hidden_channels,
            kernel_size=1,
            bias=True,
            padding=0
        )
        # norm1 and relu are not used by forward(); they are kept so the
        # module exposes the same attributes as before.
        self.norm1 = None
        self.relu = nn.ReLU()
        self.silu = nn.SiLU()
        self.fc2 = nn.Conv1d(
            hidden_channels,
            num_gates,
            kernel_size=1,
            bias=True,
            padding=0
        )
        if gate_activation == 'sigmoid':
            self.gate_activation = nn.Sigmoid()
        elif gate_activation == 'relu':
            self.gate_activation = nn.ReLU()
        elif gate_activation == 'silu':
            self.gate_activation = nn.SiLU()
        elif gate_activation == 'linear':
            self.gate_activation = None
        else:
            raise RuntimeError(
                "Unknown gate activation: {}".format(gate_activation)
            )

    def forward(self, x):
        """Return the gated input, or the gates when ``return_gates`` is set."""
        identity = x
        x = self.global_avgpool(x)
        x = self.fc1(x)
        x = self.silu(x)
        x = self.fc2(x)
        if self.gate_activation is not None:
            x = self.gate_activation(x)
        if self.return_gates:
            return x
        # Each scale should now be weighted
        return identity * x

In [15]:
# MODIFIED
class OSBlock(nn.Module):
    """Omni-scale feature learning block.

    The input is squeezed to ``out_channels // reduction`` channels, sent
    through several parallel convolution streams, each stream is weighted by a
    shared ChannelGate, the weighted streams are summed, expanded back to
    ``out_channels`` and added to a shortcut before a final SiLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; must be a multiple of
            ``reduction`` and at least ``reduction``.
        kernel_sizes: One kernel size per stream.
        reduction: Squeeze factor for the internal width.
        T: Number of streams when ``stacking`` is True (stream ``t`` stacks
            ``t`` layers and uses ``kernel_sizes[t - 1]``). Not used when
            ``stacking`` is False.
        stacking: If True build stacked LightConvStream streams; otherwise
            build one unstacked LightConvNxN per entry of ``kernel_sizes``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes, reduction=4, T=5, stacking=False, **kwargs):
        super().__init__()
        assert T >= 1
        assert out_channels >= reduction and out_channels % reduction == 0
        mid_channels = out_channels // reduction

        self.conv1 = Conv1x1(in_channels, mid_channels)
        # Two ways to cover several scales cheaply: stack small kernels
        # (k=3 stacked 5 times has an effective receptive field of 2*5+1=11)
        # or use a single layer with a larger kernel per stream.
        if stacking:
            streams = [
                LightConvStream(mid_channels, mid_channels, depth, kernel_sizes[index])
                for index, depth in enumerate(range(1, T + 1))
            ]
        else:
            streams = [
                LightConvNxN(mid_channels, mid_channels, size) for size in kernel_sizes
            ]
        self.conv2 = nn.ModuleList(streams)

        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
        # A projection on the shortcut is only needed when the widths differ
        self.downsample = (
            Conv1x1Linear(in_channels, out_channels)
            if in_channels != out_channels
            else None
        )

    def forward(self, x):
        """Compute the block output for ``x``."""
        shortcut = x
        squeezed = self.conv1(x)
        # The same gate weights every stream before the streams are summed
        fused = sum(self.gate(stream(squeezed)) for stream in self.conv2)
        expanded = self.conv3(fused)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        return F.silu(expanded + shortcut)

In [16]:
# MODIFIED
class OSBlockINin(nn.Module):
    """Omni-scale feature learning block with instance normalisation.

    Same layout as OSBlock, except that the expansion convolution has no batch
    norm and is followed by an affine ``InstanceNorm1d`` inside the residual
    branch, and the block ends with ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; must be a multiple of
            ``reduction`` and at least ``reduction``.
        kernel_sizes: One kernel size per stream.
        reduction: Squeeze factor for the internal width.
        T: Number of streams when ``stacking`` is True (stream ``t`` stacks
            ``t`` layers and uses ``kernel_sizes[t - 1]``). Not used when
            ``stacking`` is False.
        stacking: If True stream ``t`` has depth ``t``; otherwise every entry
            of ``kernel_sizes`` gets a depth-1 LightConvStream.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes, reduction=4, T=5, stacking=False, **kwargs):
        super().__init__()
        assert T >= 1
        assert out_channels >= reduction and out_channels % reduction == 0
        mid_channels = out_channels // reduction

        self.conv1 = Conv1x1(in_channels, mid_channels)
        # Two ways to cover several scales cheaply: stack small kernels
        # (k=3 stacked 5 times has an effective receptive field of 2*5+1=11)
        # or use a single layer with a larger kernel per stream.
        if stacking:
            streams = [
                LightConvStream(mid_channels, mid_channels, depth, kernel_sizes[index])
                for index, depth in enumerate(range(1, T + 1))
            ]
        else:
            streams = [
                LightConvStream(mid_channels, mid_channels, 1, size)
                for size in kernel_sizes
            ]
        self.conv2 = nn.ModuleList(streams)

        self.gate = ChannelGate(mid_channels)
        self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn=False)
        # A projection on the shortcut is only needed when the widths differ
        self.downsample = (
            Conv1x1Linear(in_channels, out_channels)
            if in_channels != out_channels
            else None
        )
        self.IN = nn.InstanceNorm1d(out_channels, affine=True)

    def forward(self, x):
        """Compute the block output for ``x``."""
        shortcut = x
        squeezed = self.conv1(x)
        # The same gate weights every stream before the streams are summed
        fused = sum(self.gate(stream(squeezed)) for stream in self.conv2)
        # Instance norm sits inside the residual branch
        normalised = self.IN(self.conv3(fused))
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        # NOTE(review): OSBlock ends with F.silu while this block ends with
        # F.relu -- confirm whether the difference is intentional.
        return F.relu(normalised + shortcut)

#### Network Architecture

In [29]:
# MODIFIED
class OSNet(nn.Module):
    """Omni-Scale Network backbone for 1-D inputs.

    Reference:
        - Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.
        - Zhou et al. Learning Generalisable Omni-Scale Representations
          for Person Re-Identification. TPAMI, 2021.

    Args:
        blocks: Three lists of block classes, one list per stage.
        layers: Only its length is checked against ``blocks``.
        channels: Four widths: ``channels[0]`` for the stem and
            ``channels[1:]`` for the three stages.
        kernel_sizes: Per stage, one list of kernel sizes per block.
        strides: Four strides, in order: stem convolution, stem max-pool,
            average pool after stage 1, average pool after stage 2.
        conv1_IN: Use instance norm instead of batch norm in the stem.
        in_channels: Channels of the network input.
        stacking: Forwarded to every block.
        initial_dim_reduction: If True the stem (conv + max-pool) runs before
            stage 1. If False the stem is still constructed but skipped, and
            stage 1 consumes the raw input.
    """

    def __init__(
        self,
        blocks,
        layers,
        channels,
        kernel_sizes,
        strides,
        conv1_IN=False,
        in_channels=1,
        stacking=True,
        initial_dim_reduction = False
    ):
        super().__init__()
        num_blocks = len(blocks)
        assert num_blocks == len(layers)
        assert num_blocks == len(channels) - 1

        # options
        self.initial_dim_reduction = initial_dim_reduction

        # Stem: kernel-7 convolution followed by a max-pool
        self.conv1 = ConvLayer(
            in_channels, channels[0], 7, stride=strides[0], padding=3, IN=conv1_IN
        )
        self.maxpool = nn.MaxPool1d(3, stride=strides[1], padding=1)

        # Stage 1 reads the stem output when the stem is active, otherwise
        # the raw network input.
        stage1_in = channels[0] if self.initial_dim_reduction else in_channels
        self.conv2 = self._make_layer(
            blocks[0], layers[0], kernel_sizes[0], stage1_in, channels[1], stacking
        )
        # Transition 1: pointwise conv, then average pool (kernel 2, stride strides[2])
        self.pool2 = nn.Sequential(
            Conv1x1(channels[1], channels[1]), nn.AvgPool1d(2, stride=strides[2])
        )

        # Stage 2
        self.conv3 = self._make_layer(
            blocks[1], layers[1], kernel_sizes[1], channels[1], channels[2], stacking
        )
        # Transition 2: pointwise conv, then average pool (kernel 2, stride strides[3])
        self.pool3 = nn.Sequential(
            Conv1x1(channels[2], channels[2]), nn.AvgPool1d(2, stride=strides[3])
        )

        # Stage 3 and the closing pointwise convolution
        self.conv4 = self._make_layer(
            blocks[2], layers[2], kernel_sizes[2], channels[2], channels[3], stacking
        )
        self.conv5 = Conv1x1(channels[3], channels[3])

        self._init_params()

    def _make_layer(self, blocks, layer, kernel_sizes, in_channels, out_channels, stacking):
        """Build one stage as a Sequential of the given block classes.

        The first block maps ``in_channels`` to ``out_channels``; the others
        keep ``out_channels``. Block ``i`` receives ``kernel_sizes[i]``.
        ``layer`` is accepted but not used: the number of blocks comes from
        ``len(blocks)``.
        """
        stage = [
            blocks[0](in_channels, out_channels, kernel_sizes=kernel_sizes[0], stacking=stacking)
        ]
        for index, block_cls in enumerate(blocks[1:], start=1):
            stage.append(
                block_cls(out_channels, out_channels, kernel_sizes=kernel_sizes[index], stacking=stacking)
            )
        return nn.Sequential(*stage)

    def _init_params(self):
        """Initialise convolution, normalisation and linear parameters."""
        norm_types = (nn.BatchNorm1d, nn.InstanceNorm1d)
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu'
                )
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, norm_types):
                # Normalisation layers start with unit scale and zero shift
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def featuremaps(self, x):
        """Run the backbone and return the feature map with an extra axis.

        The shapes entering stage 1, leaving stage 1 and leaving transition 1
        are printed on every call. The result of the closing convolution gets
        a singleton dimension inserted at position 1.
        """
        if self.initial_dim_reduction:
            x = self.maxpool(self.conv1(x))
        # Rest of OSNet
        # NOTE(review): these prints run on every forward pass -- presumably
        # leftover shape debugging; confirm before using outside the notebook.
        print(x.shape)
        x = self.conv2(x)
        print(x.shape)
        x = self.pool2(x)
        print(x.shape)
        x = self.pool3(self.conv3(x))
        x = self.conv4(x)
        return self.conv5(x).unsqueeze(1)

    def forward(self, x):
        """Return the feature maps; nothing else is needed from this frontend."""
        return self.featuremaps(x)

#### OSnet with different configurations

In [18]:
def osnet_ain_x1_0(kernel_sizes, strides, stacking=True, initial_dim_reduction=False):
    """Build the AIN OSNet with channel widths [64, 256, 384, 512].

    All arguments are forwarded unchanged to ``OSNet``; the stem uses instance
    normalisation (``conv1_IN=True``).
    """
    stage_blocks = [
        [OSBlockINin, OSBlockINin],
        [OSBlock, OSBlockINin],
        [OSBlockINin, OSBlock],
    ]
    return OSNet(
        blocks=stage_blocks,
        layers=[2, 2, 2],
        channels=[64, 256, 384, 512],
        kernel_sizes=kernel_sizes,
        strides=strides,
        conv1_IN=True,
        stacking=stacking,
        initial_dim_reduction=initial_dim_reduction,
    )

In [19]:
def osnet_ain_x0_75(kernel_sizes, strides, stacking=True, initial_dim_reduction=False):
    """Build the AIN OSNet with channel widths [48, 192, 288, 384].

    All arguments are forwarded unchanged to ``OSNet``; the stem uses instance
    normalisation (``conv1_IN=True``).
    """
    stage_blocks = [
        [OSBlockINin, OSBlockINin],
        [OSBlock, OSBlockINin],
        [OSBlockINin, OSBlock],
    ]
    return OSNet(
        blocks=stage_blocks,
        layers=[2, 2, 2],
        channels=[48, 192, 288, 384],
        kernel_sizes=kernel_sizes,
        strides=strides,
        conv1_IN=True,
        stacking=stacking,
        initial_dim_reduction=initial_dim_reduction,
    )

In [20]:
def osnet_ain_x0_5(kernel_sizes, strides, stacking=True, initial_dim_reduction=False):
    """Build the AIN OSNet with channel widths [32, 128, 192, 256].

    All arguments are forwarded unchanged to ``OSNet``; the stem uses instance
    normalisation (``conv1_IN=True``).
    """
    stage_blocks = [
        [OSBlockINin, OSBlockINin],
        [OSBlock, OSBlockINin],
        [OSBlockINin, OSBlock],
    ]
    return OSNet(
        blocks=stage_blocks,
        layers=[2, 2, 2],
        channels=[32, 128, 192, 256],
        kernel_sizes=kernel_sizes,
        strides=strides,
        conv1_IN=True,
        stacking=stacking,
        initial_dim_reduction=initial_dim_reduction,
    )

In [21]:
def osnet_ain_x0_25(kernel_sizes, strides, stacking=True, initial_dim_reduction=False):
    """Build the AIN OSNet with channel widths [16, 64, 96, 128].

    All arguments are forwarded unchanged to ``OSNet``; the stem uses instance
    normalisation (``conv1_IN=True``).
    """
    stage_blocks = [
        [OSBlockINin, OSBlockINin],
        [OSBlock, OSBlockINin],
        [OSBlockINin, OSBlock],
    ]
    return OSNet(
        blocks=stage_blocks,
        layers=[2, 2, 2],
        channels=[16, 64, 96, 128],
        kernel_sizes=kernel_sizes,
        strides=strides,
        conv1_IN=True,
        stacking=stacking,
        initial_dim_reduction=initial_dim_reduction,
    )

In [22]:
def osnet_ain_custom(channels, kernel_sizes, strides, stacking=True, initial_dim_reduction=False):
    """Build the AIN OSNet with caller-supplied channel widths.

    ``channels`` and all other arguments are forwarded unchanged to ``OSNet``;
    the stem uses instance normalisation (``conv1_IN=True``).
    """
    stage_blocks = [
        [OSBlockINin, OSBlockINin],
        [OSBlock, OSBlockINin],
        [OSBlockINin, OSBlock],
    ]
    return OSNet(
        blocks=stage_blocks,
        layers=[2, 2, 2],
        channels=channels,
        kernel_sizes=kernel_sizes,
        strides=strides,
        conv1_IN=True,
        stacking=stacking,
        initial_dim_reduction=initial_dim_reduction,
    )

#### Running OSnet 

In [23]:
# Baseline run: every stream uses kernel size 3 and the streams are stacked
k_default = [3] * 5

network = osnet_ain_x0_25(
    kernel_sizes=[[k_default, k_default] for _ in range(3)],
    strides=[2, 2, 8, 4],
    stacking=True,
    initial_dim_reduction=False,
)

# 64 single-channel sequences of length 4096
input = torch.randn(64, 1, 4096)
output = network(input)
print(output.shape)

torch.Size([64, 1, 4096])
torch.Size([64, 64, 4096])
torch.Size([64, 64, 512])
torch.Size([64, 1, 128, 128])


In [24]:
# Kernel sizes on modified OSNet (type 2, odd variant), repeated for reference
# First five multiples of each power of two from 2**6 down to 2**1, plus one
for exponent in np.arange(6, 0, -1):
    base = 2 ** exponent
    print([base * multiple + 1 for multiple in range(1, 6)])

[65, 129, 193, 257, 321]
[33, 65, 97, 129, 161]
[17, 33, 49, 65, 81]
[9, 17, 25, 33, 41]
[5, 9, 13, 17, 21]
[3, 5, 7, 9, 11]


In [25]:
# Does the network still run with large, unstacked kernels?
k_check = [65, 129, 193, 257, 321]

network = osnet_ain_x0_25(
    kernel_sizes=[[k_check, k_check] for _ in range(3)],
    strides=[2, 2, 8, 4],
    stacking=False,
    initial_dim_reduction=False,
)

# 64 single-channel sequences of length 4096
input = torch.randn(64, 1, 4096)
output = network(input)
print(output.shape)

torch.Size([64, 1, 4096])
torch.Size([64, 64, 4096])
torch.Size([64, 64, 512])
torch.Size([64, 1, 128, 128])


In [26]:
# All custom kernels: the six odd type-2 rows, grouped two per stage
rows = [
    [(2**i)*n + 1 for n in range(1, 6)]
    for i in np.arange(1, 7)[::-1]
]
kernel_sizes = [rows[start:start + 2] for start in range(0, len(rows), 2)]

# With initial_dim_reduction=False the stem is skipped, so the first stage
# maps the raw input straight to channels[1].
network = osnet_ain_custom(
    channels=[16, 64, 96, 128],
    kernel_sizes=kernel_sizes,
    strides=[2, 2, 8, 4],
    stacking=False,
    initial_dim_reduction=False,
)

# A single one-channel sequence of length 4096
input = torch.randn(1, 1, 4096)
output = network(input)
print(output.shape)

torch.Size([1, 1, 4096])
torch.Size([1, 64, 4096])
torch.Size([1, 64, 512])
torch.Size([1, 1, 128, 128])


In [27]:
# Kernel sizes on modified OSNet (type 1)
# The power-of-two rows are re-used to fill the six bottleneck layers
kernel_sizes = [
    [[16, 32, 64, 128, 256], [8, 16, 32, 64, 128]],
    [[8, 16, 32, 64, 128], [2, 4, 8, 16, 32]],
    [[2, 4, 8, 16, 32], [2, 4, 8, 16, 32]],
]

# With initial_dim_reduction=False the stem is skipped, so the first stage
# maps the raw input straight to channels[1].
# The widths can be reduced for a smaller network. With the OSBlock reduction
# of 4 and the ChannelGate default reduction of 8, a block width of 32 gives
# the gate a hidden width of (32 // 4) // 8 = 1.
# NOTE(review): the recorded output of this cell includes a warning raised
# from F.conv1d -- presumably caused by the even kernel sizes combined with
# padding='same'; confirm.
network = osnet_ain_custom(
    channels=[16, 32, 64, 128],
    kernel_sizes=kernel_sizes,
    strides=[2, 2, 8, 4],
    stacking=False,
    initial_dim_reduction=False,
)

# 64 single-channel sequences of length 4096
input = torch.randn(64, 1, 4096)
output = network(input)
print(output.shape)

torch.Size([64, 1, 4096])
torch.Size([64, 32, 4096])
torch.Size([64, 32, 512])
torch.Size([64, 1, 128, 128])


  return F.conv1d(input, weight, bias, self.stride,


In [28]:
# Kernel sizes on modified OSNet (type 1odd)
# Type-1 rows with one added to every size, re-used across the six
# bottleneck layers
kernel_sizes = [
    [[17, 33, 65, 129, 257], [9, 17, 33, 65, 129]],
    [[9, 17, 33, 65, 129], [3, 5, 9, 17, 33]],
    [[3, 5, 9, 17, 33], [3, 5, 9, 17, 33]],
]

network = osnet_ain_custom(
    channels=[16, 32, 64, 128],
    kernel_sizes=kernel_sizes,
    strides=[2, 2, 8, 4],
    stacking=False,
    initial_dim_reduction=False,
)

# 64 single-channel sequences of length 4096
input = torch.randn(64, 1, 4096)
output = network(input)
print(output.shape)

torch.Size([64, 1, 4096])
torch.Size([64, 32, 4096])
torch.Size([64, 32, 512])
torch.Size([64, 1, 128, 128])


In [31]:
# Kernel sizes on modified OSNet (type 1odd), this time with the stem enabled
# Type-1 rows with one added to every size, re-used across the six
# bottleneck layers
kernel_sizes = [
    [[17, 33, 65, 129, 257], [9, 17, 33, 65, 129]],
    [[9, 17, 33, 65, 129], [3, 5, 9, 17, 33]],
    [[3, 5, 9, 17, 33], [3, 5, 9, 17, 33]],
]

# initial_dim_reduction=True runs the stem convolution (stride strides[0]) and
# max-pool (stride strides[1]) before the first stage.
network = osnet_ain_custom(
    channels=[16, 32, 64, 128],
    kernel_sizes=kernel_sizes,
    strides=[4, 2, 2, 2],
    stacking=False,
    initial_dim_reduction=True,
)

# 64 single-channel sequences of length 4096
input = torch.randn(64, 1, 4096)
output = network(input)
print(output.shape)

torch.Size([64, 16, 512])
torch.Size([64, 32, 512])
torch.Size([64, 32, 256])
torch.Size([64, 1, 128, 128])
