### ConvBranches

First of all, import the library:

In [1]:
import torch
from convblock import ConvBlock, ConvBranches, Config

In essence, **ConvBranches** is just several **ConvBlock** modules arranged as parallel branches whose outputs are concatenated, summed or multiplied at the end.
The same parameter-vectorization concepts that were explained in the **ConvBlock** tutorial apply to the **ConvBranches** module as well. As an example, we can create the **ASPP** module used by the **DeepLab** architecture:

In [25]:
# DeepLab-style ASPP: one image-level pooling branch plus one 3x3 dilated
# convolution branch per entry of `rates`; branch outputs are combined
# according to mode='.' (presumably concatenation).
input_shape = (32, 32, 32)
rates = [1, 3, 5, 12]
filters = 16

# Image-level branch: adaptive average pooling down to 1x1, a 1x1 'cna'
# projection, then upsampling back to the input's spatial size so that it
# matches the other branches.
pooling_branch = dict(
    layout='p cna u',
    c={'filters': filters, 'kernel_size': 1, 'stride': 1, 'dilation': 1},
    p={'mode': 'avg', 'output_size': 1, 'adaptive': True},
    u={'size': tuple(input_shape[1:])},
)

# One 'cna' branch per dilation rate (branch1, branch2, ...). Deriving the
# branches from `rates` keeps the module in sync with the list instead of
# hard-coding one keyword argument per rate.
dilated_branches = {
    f'branch{i}': dict(layout='cna', c={'filters': filters, 'kernel_size': 3,
                                        'stride': 1, 'dilation': rate})
    for i, rate in enumerate(rates, start=1)
}

ConvBranches(
    input_shape=input_shape, mode='.',
    branch0=pooling_branch,
    **dilated_branches
)

ConvBranches(
  (branches): ModuleDict(
    (branch0): ConvBlock(
      (Module_0): AdaptiveAvgPool(input_shape=[32 32 32], output_shape=[32  1  1])
      (Module_1): Conv(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (Module_2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Module_3): ReLU(inplace=True)
      (Module_4): Upsample(input_shape=(16, 1, 1), output_shape=(16, 32, 32), mode='bilinear')
    )
    (branch1): ConvBlock(
      (Module_0): Conv(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1, 1, 1), padding_mode=constant, bias=False)
      (Module_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Module_2): ReLU(inplace=True)
    )
    (branch2): ConvBlock(
      (Module_0): Conv(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3, 3, 3), padding_mode=constant, dilation=(3, 3), bias=False)
      (Module_1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_run

Another example is the **InceptionA** module:

In [26]:
ConvBranches(
    input_shape=(32, 32, 32), mode='.',
    branch1x1={'layout': 'cna', 'c': {'filters': 32, 'kernel_size': 1}},
    branch_pool={'layout': 'p cna', 'c': {'filters': 32, 'kernel_size': 1},
                 'p': {'kernel_size': 3, 'stride': 1, 'mode': 'avg'}},
    branch3x3={'layout': 'cna cna', 'c': {'filters': [48, 64], 'kernel_size': [1, 3]}},
    branch5x4={'layout': 'cna cna cna', 'c': {'filters': [64, 96, 96], 'kernel_size': [1, 3, 3]}}
)

ConvBranches(
  (branches): ModuleDict(
    (branch1x1): ConvBlock(
      (Module_0): Conv(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (Module_1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Module_2): ReLU(inplace=True)
    )
    (branch3x3): ConvBlock(
      (Module_0): Conv(32, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (Module_1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Module_2): ReLU(inplace=True)
      (Module_3): Conv(48, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1, 1, 1), padding_mode=constant, bias=False)
      (Module_4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Module_5): ReLU(inplace=True)
    )
    (branch5x4): ConvBlock(
      (Module_0): Conv(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (Module_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    

Or the **InceptionB** module:

In [27]:
ConvBranches(
    input_shape=(32, 32, 32), mode='.',
    branch1x1={'layout': 'cna', 'c': {'filters': 192, 'kernel_size': 1}},
    branch_pool={'layout': 'p cna', 'c': {'filters': 192, 'kernel_size': 1},
                 'p': {'mode': 'avg', 'kernel_size': 3, 'stride': 1}},
    branch7x7={'layout': 'cna cna cna', 'c': {'kernel_size': [(1, 1), (1, 7), (7, 1)],
                                              'filters': [128, 128, 192]}},
    branch7x7dbl={'layout': 'cna cna cna cna cna',
                  'c': {'kernel_size': [(1, 1), (1, 7), (7, 1), (1, 7), (7, 1)],
                        'filters': (128, 128, 128, 128, 192)}}
)

ConvBranches(
  (branches): ModuleDict(
    (branch1x1): ConvBlock(
      (Module_0): Conv(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (Module_1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Module_2): ReLU(inplace=True)
    )
    (branch7x7): ConvBlock(
      (Module_0): Conv(32, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (Module_1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Module_2): ReLU(inplace=True)
      (Module_3): Conv(128, 128, kernel_size=(1, 7), stride=(1, 1), padding=(0, 0, 3, 3), padding_mode=constant, bias=False)
      (Module_4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (Module_5): ReLU(inplace=True)
      (Module_6): Conv(128, 192, kernel_size=(7, 1), stride=(1, 1), padding=(3, 3, 0, 0), padding_mode=constant, bias=False)
      (Module_7): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_run

### Creating custom complex convolutional modules using ConvBlock and ConvBranches

Let's take a look at an implementation of **NASCell** from the **NASNet** architecture. First of all, we need to import the **Module** class, which is similar to **torch.nn.Module** but keeps track of input/output shapes and strides:

In [28]:
from convblock import Module

In [29]:
class NASCell(Module):
    """Simplified NASNet-style cell that mixes two input feature maps.

    The cell receives a pair of tensors ``(x, y)``, builds five intermediate
    maps from depthwise convolutions, pooling and additive skip connections,
    and concatenates them along dim 1.

    Parameters
    ----------
    input_shape : ArrayLike[int]
        Shapes of the two inputs, one row per input (channels first). Column 0
        of each row is used as the number of filters (and groups) of the
        convolutions applied to that input.
    stride : tuple
        ``(stride_x, stride_y)``: stride of the branches applied to the first
        and the second input respectively.

    NOTE(review): as written, ``forward`` only produces matching shapes when
    ``stride == (1, 1)`` and both inputs have the same shape.
    ``conv_3x3_2(y)`` is added to the unstrided ``x``. Also, ``conv_3x3_1``
    strides ``x2`` again after ``block_2`` has already strided it. Confirm
    the intended wiring before using other values.
    """

    def __init__(self, input_shape: 'ArrayLike[int]',
                 stride: tuple = (1, 1)):
        super().__init__(input_shape)
        # presumably Module stores input_shape as a 2D array (one row per
        # input), since it is indexed with [:, 0] here — TODO confirm.
        filters = self.input_shape[:, 0]
        stride = [int(v) for v in stride]

        # Depthwise 3x3 convolutions (groups == filters == input channels).
        self.conv_3x3_1 = ConvBlock(
            input_shape=self.input_shape[0],
            layout='cna',
            c=dict(kernel_size=3,
                   filters=filters[0],
                   groups=filters[0],
                   stride=stride[0])
        )
        self.conv_3x3_2 = ConvBlock(
            input_shape=self.input_shape[1],
            layout='cna',
            c=dict(kernel_size=3,
                   filters=filters[1],
                   groups=filters[1],
                   stride=stride[1])
        )
        self.pool_3x3 = ConvBlock(
            input_shape=self.input_shape[0],
            layout='p', p=dict(kernel_size=3,
                               stride=stride[0])
        )
        # Sum (mode='+') of a depthwise 7x7 conv and a 3x3 pool on input 0.
        self.block_1 = ConvBranches(
            input_shape=self.input_shape[0],
            mode='+',
            branch_conv7={
                'layout': 'cna',
                'c': {
                    'kernel_size': 7,
                    'filters': filters[0],
                    'groups': filters[0],
                    'stride': stride[0]
                }
            },
            branch_pool={
                'layout': 'p',
                'p': {
                    'kernel_size': 3,
                    'stride': stride[0]
                }
            }
        )
        # NOTE(review): the key says 'conv3' but the kernel is 7, which makes
        # this block structurally identical to block_1 — confirm which was
        # intended.
        self.block_2 = ConvBranches(
            input_shape=self.input_shape[0],
            mode='+',
            branch_conv3={
                'layout': 'cna',
                'c': {
                    'kernel_size': 7,
                    'filters': filters[0],
                    'groups': filters[0],
                    'stride': stride[0]
                }
            },
            branch_pool={
                'layout': 'p',
                'p': {
                    'kernel_size': 3,
                    # Per-input scalar stride, consistent with block_1 (the
                    # whole `stride` list would act as a per-axis stride).
                    'stride': stride[0]
                }
            }
        )
        # Sum of depthwise 5x5 and 3x3 convolutions on input 1; branch keys
        # are named after their kernel sizes.
        self.block_3 = ConvBranches(
            input_shape=self.input_shape[1],
            mode='+',
            branch_conv5={
                'layout': 'cna',
                'c': {
                    'kernel_size': 5,
                    'filters': filters[1],
                    'groups': filters[1],
                    'stride': stride[1]
                }
            },
            branch_conv3={
                'layout': 'cna',
                'c': {
                    'kernel_size': 3,
                    'filters': filters[1],
                    'groups': filters[1],
                    'stride': stride[1]
                }
            }
        )

    def forward(self, inputs):
        """Run the cell on a pair of tensors.

        Parameters
        ----------
        inputs : sequence of two torch.Tensor
            ``(x, y)``, laid out as ``input_shape[0]`` and ``input_shape[1]``
            with a leading batch dimension.

        Returns
        -------
        torch.Tensor
            The five intermediate maps concatenated along dim 1.
        """
        x, y = inputs
        x1 = self.block_1(x)
        x2 = self.block_2(x)
        # Depthwise conv on block_2's output plus a pooled skip from x.
        x3 = self.conv_3x3_1(x2) + self.pool_3x3(x)
        x4 = self.block_3(y)
        # Additive skip: requires conv_3x3_2(y) to match x's shape.
        x5 = self.conv_3x3_2(y) + x
        return torch.cat([x1, x2, x3, x4, x5], dim=1)


In [30]:
# Two inputs of identical shape (channels, height, width) = (32, 32, 32).
cell = NASCell([(32, 32, 32)] * 2)

In [31]:
# A batch of 2 random tensors per input; the five 32-channel intermediate
# maps are concatenated, giving 160 output channels.
cell([torch.rand(2, 32, 32, 32) for _ in range(2)]).shape

torch.Size([2, 160, 32, 32])