<a href="https://colab.research.google.com/github/maxmatical/pytorch-projects/blob/master/xresnet_architectures.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Xresnet based architectures

- xresnet
- xres2net



In [0]:
import torch.nn as nn
import torch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
import math

In [0]:
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
        
    def forward(self, x):
        return x.view(x.size(0), -1)
    
# weight initialization
def init_cnn(m):
    if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)
    if isinstance(m, (nn.Conv2d,nn.Linear)): nn.init.kaiming_normal_(m.weight)
    for l in m.children(): init_cnn(l)

# activation function
act_fn = nn.ReLU(inplace=True)
        

In [0]:
# default conv layer
def conv(ni, nf, kernel_size = 3, stride = 1, bias = False):
    """
    ni: n of in channels
    
    nf: number of filters
    
    """
    return nn.Conv2d(ni, nf, kernel_size = kernel_size, stride = stride, padding = kernel_size//2, bias = bias)


# conv + bn + act_fun
def conv_layer(ni, nf, kernel_size=3, stride=1, zero_bn=False, act=True):
    bn = nn.BatchNorm2d(nf)
    nn.init.constant_(bn.weight, 0. if zero_bn else 1.)
    layers = [conv(ni, nf, kernel_size, stride = stride), bn]
    if act: 
        layers.append(act_fn)
    return nn.Sequential(*layers)
       



In the Resblock, the last conv_layer is set to zero_bn 

This is because if zero_bn is true, the entire conv_layer is set to 0 since the bn sets the outputs to 0

This lets the NN learn when a resblock is not needed, it learns to skip over the block if it doesn't contribute

If it contributes, then it learns the weight of the bn so then the resblock actually has weights

In [0]:
# residual block

def no_op(x): return x # no operations done if ni == nf


class Resblock(nn.Module):
    def __init__(self, expansion, ni, nh, stride = 1):
        """
        nh = number of filters for the middle layers
        """
        super(Resblock, self).__init__()
        nf, ni = nh*expansion, ni*expansion
        if expansion == 1:
            layers = [conv_layer(ni, nh, 3, stride = stride), #either stride 1 (if res block) or stride = 2 for downsampling blocks
                      conv_layer(nh, nf, 3, zero_bn = True, act = False)] # stride 1
        else:
            layers = [conv_layer(ni, nh, 1),
                      conv_layer(nh, nh, 3, stride = stride), #either stride 1 (if res block) or stride = 2 for downsampling blocks
                      conv_layer(nh, nf, 1, zero_bn = True, act = False)]
            
        self.convs = nn.Sequential(*layers)
        
        # if ni != nf, use a 1x1 conv to get the same channels, otherwise return x (no operations)
        self.idconv = no_op if ni == nf else conv_layer(ni, nf, 1, act = False)
        
        self.pooling = no_op if stride == 1 else nn.AvgPool2d(2, ceil_mode=True)
        
    def forward(self, x):
        x1 = self.convs(x) # convs operations
        x2 = self.idconv(self.pooling(x)) # pooling and 1x1 conv layer operations
        out = x1+x2
        return act_fn(out)


Tests for the resblock

In [0]:
expansion = 4

tmp = torch.randn((16, 3*expansion, 226, 226)).to(device)

a = Resblock(expansion, 3, 20).to(device)
print(a(tmp).shape)
bs, n_channels, H, W = a(tmp).size()
print(H, W, H*W)

a2 = Resblock(expansion, 3, 20, stride = 2).to(device)
print(a2(tmp).shape)
bs, n_channels, H, W = a2(tmp).size()
print(H, W, H*W)

torch.Size([16, 80, 226, 226])
226 226 51076
torch.Size([16, 80, 113, 113])
113 113 12769


For XResNet. Replaces the 7x7 stride 2 stem with 3 3x3 convs (stride = 2 for first conv) in a row

This is done beecause 3x3 convolutions are much less computationally expensive than a 7x7 (5.4 times more expensive)

In [0]:
# creating XResNet
class XResNet(nn.Sequential):
    def __init__(self, expansion, layers, ni = 3, n_classes=1000):
        
        """
        layers = list of length 4. 
        layer[i] = how many resblocks in each of the 4 chunks of the network
        expansion = what value to multiply the intermediate n_out of the convlayer by
        """
        #stem
        stem = []
        stem_sizes = [ni, 32, 32, 64]
        for i in range(3):
            stem.append(conv_layer(stem_sizes[i], stem_sizes[i+1], stride = 2 if i==0 else 1))
            
        # creating the resblock layers
        block_sizes = [64//expansion, 64, 128, 256, 512]

        
        blocks = [self._make_layer(expansion, ni = block_sizes[i], nf = block_sizes[i+1], blocks = l, stride = 1 if i == 0 else 2) #1st stage has no downsampling
                    for i, l in enumerate(layers)] #l in enumerate(layers) goes through layers in XResNet and sets the value of blocks based on layer[i]
        
        # creating network
        super().__init__(*stem,
                      nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                      *blocks,
                      nn.AdaptiveAvgPool2d(1), 
                      Flatten(),
                      nn.Linear(block_sizes[-1]*expansion, n_classes)
        )
        
        init_cnn(self)
        
    def _make_layer(self, expansion, ni, nf, blocks, stride):
        """
        blocks = int. -> number of blocks to create = layer[i]
        """
        return nn.Sequential(
            *[Resblock(expansion, ni if i == 0 else nf, nf, stride if i==0 else 1) # only stride 2 for the downsampling block(first block) and stride = 1 for residual blocks
              for i in range(blocks)])


In [0]:
xresnet_18 = XResNet(expansion = 1, layers = [2, 2, 2, 2])
xresnet_34 = XResNet(expansion = 1, layers = [3, 4, 6, 3])
xresnet_50 = XResNet(expansion = 4, layers = [3, 4, 6, 3])
xresnet_101 = XResNet(expansion = 4, layers = [3,4,23,3])
xresnet_152 = XResNet(expansion = 4, layers = [3,8,36,3])


Tests for xresnet

In [0]:
tmp = torch.randn((16, 3, 226, 226)).to(device)
a = xresnet_50.to(device)
print(a(tmp).shape)


torch.Size([16, 1000])


# XRes2Net

https://github.com/gasvn/Res2Net 


https://github.com/lessw2020/res2net-plus

In [0]:
# res2 block

def no_op(x): return x # no operations done if ni == nf

    
class Res2block(nn.Module):
    def __init__(self, expansion, ni, nh, stride = 1, base_width = 26, scale = 4, first_block = False):
        """
        ni: number of in channels
        nh: number of hidden channels
        base_width: basic width of conv3x3
        scale: scaling ratio for the convs
        first_block: whether the block is the first to be placed in the conv layer
        
        """
        super(Res2block, self).__init__()
        
        self.first_block = first_block
        self.scale = scale
        
        nf, ni = nh*expansion, ni*expansion
        
        width = int(math.floor(nf*(base_width/64.)))
#         print(width)
        
        self.conv1 = conv_layer(ni, width*scale, 1, stride = stride)
#         print(ni, width*scale)
        
        
        self.conv3 = conv_layer(width*scale, nh*expansion, kernel_size=1, act = False) # no act_fn
        
        n_branches = max(2, scale) - 1
        
        if self.first_block:
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
            
#         self.convs = nn.ModuleList([conv_layer(width, width, 3, stride = stride) for _ in range(n_branches)]) # should it be stride = 1 here?
        self.convs = nn.ModuleList([conv_layer(width, width, 3, stride = 1) for _ in range(n_branches)]) 

        
        # if ni != nf, use a 1x1 conv to get the same channels, otherwise return x (no operations)
        self.idconv = no_op if ni == nf else conv_layer(ni, nf, 1, act = False)
        
        self.pooling = no_op if stride == 1 else nn.AvgPool2d(2, ceil_mode=True)
        
    def forward(self, x):
        x1 = self.conv1(x) #conv2d 1x1 -> bn -> act_fn
        
#         print('x1', x1.shape)
        # splitting into self.scale equal sized chunks
        xs = torch.chunk(x1, self.scale, dim = 1)
        #initialize output tensor for concatenation later on
        y = 0
        for idx, conv in enumerate(self.convs):
#             print(self.pooling(xs[idx]).shape)
#             xs[idx] = self.pooling(xs[idx])
#               temp = self.pooling(xs[idx])

            if self.first_block:
                y = xs[idx]
                
                """
                Something needs to be fixed here for when stride != 1
                
                """
            else:

#                 print('idx', idx, 'xs[idx].shape', xs[idx].shape)
#                 if idx > 0:
#                     print('idx', idx, 'y shape', y.shape)
#                 else:
#                     print('idx', idx, 'y', y)
#                 y +=  self.idconv(self.pooling(xs[idx])) # add the residual for the 2nd and onwards chunks
#                 print('pooled x[idx]', self.pooling(xs[idx]).shape)

#                 y +=  self.pooling(xs[idx])

                y += xs[idx]

            y = conv(y)
#             print('y after conv shape', y.shape)
            x1 = torch.cat((x1, y), 1) if idx >0 else y # concat outputs, but not the 1st chunk


        if self.scale > 1:
            if self.first_block:
                x1 = torch.cat((x1, self.pool(xs[len(self.convs)])), 1) #concat all the outputs together
            else:
                x1 = torch.cat((x1, xs[len(self.convs)]),1)
                

        x1 = self.conv3(x1) # conv1x1 -> bn -> no act_fn
        
        # computing the residual, changing nf or dimensions if not matching x1
        x2 = self.idconv(self.pooling(x))
        
        out = x1+x2
        
        return out
        
        
        

        

Tests for res2block

In [59]:
expansion = 4


tmp = torch.randn((16, 3*expansion, 226, 226)).to(device)


# a = Res2block(expansion, 3, 32).to(device)
# print('output shape',a(tmp).shape)


# bs, n_channels, H, W = a(tmp).size()
# print(H, W, H*W)


a2 = Res2block(expansion, 3, 20, stride = 2, first_block = False).to(device)
print(a2(tmp).shape)

# bs, n_channels, H, W = a2(tmp).size()
# print(H, W, H*W)


torch.Size([16, 80, 113, 113])


For XRes2Net. Replace ResBlocks with Res2Blocks

In [0]:
# creating XResNet
class XRes2Net(nn.Sequential):
    def __init__(self, expansion, layers, ni = 3, n_classes=1000, base_width = 26, scale = 4):
        
        """
        layers = list of length 4. 
        layer[i] = how many resblocks in each of the 4 chunks of the network
        expansion = what value to multiply the intermediate n_out of the convlayer by
        """
        
        self.inplanes = 64
        
        #stem
        stem = []
        stem_sizes = [ni, 32, 32, 64]
        for i in range(3):
            stem.append(conv_layer(stem_sizes[i], stem_sizes[i+1], stride = 2 if i==0 else 1))
            
        # creating the resblock layers
        block_sizes = [64//expansion, 64, 128, 256, 512]
        is_first_block = [True, False, False, False]
            
        blocks = [self._make_layer(expansion, ni = block_sizes[i], nf = block_sizes[i+1], blocks = l, stride = 1 if i == 0 else 2, first_block = is_first_block[i]) #1st stage has no downsampling
                    for i, l in enumerate(layers)] #l in enumerate(layers) goes through layers in XResNet and sets the value of blocks based on layer[i]
        
        # creating network
        super().__init__(*stem,
                      nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
                      *blocks,
                      nn.AdaptiveAvgPool2d(1), 
                      Flatten(),
                      nn.Linear(block_sizes[-1]*expansion, n_classes)
        )
        
        init_cnn(self)
        
    def _make_layer(self, expansion, ni, nf, blocks, stride, first_block):
        """
        blocks (int): number of blocks to create = layer[i]
        """
        return nn.Sequential(
            *[Res2block(expansion, ni if i == 0 else nf, nf, stride if i==0 else 1, base_width = 26, scale = 4, first_block = first_block) # only stride 2 for the downsampling block(first block) and stride = 1 for residual blocks
              for i in range(blocks)])


In [0]:
xres2net_50 = XRes2Net(expansion = 4, layers = [3, 4, 6, 3])

In [55]:
tmp = torch.randn((16, 3, 226, 226)).to(device)
a = xres2net_50.to(device)
print(a(tmp).shape)

x1 torch.Size([16, 416, 57, 57])
here
here
here
x1 torch.Size([16, 416, 57, 57])
here
here
here
x1 torch.Size([16, 416, 57, 57])
here
here
here
x1 torch.Size([16, 832, 29, 29])
idx 0 xs[idx].shape torch.Size([16, 208, 29, 29])
idx 0 y 0
here
idx 1 xs[idx].shape torch.Size([16, 208, 29, 29])
idx 1 y shape torch.Size([16, 208, 29, 29])
here
idx 2 xs[idx].shape torch.Size([16, 208, 29, 29])
idx 2 y shape torch.Size([16, 208, 29, 29])
here
x1 torch.Size([16, 832, 29, 29])
idx 0 xs[idx].shape torch.Size([16, 208, 29, 29])
idx 0 y 0
here
idx 1 xs[idx].shape torch.Size([16, 208, 29, 29])
idx 1 y shape torch.Size([16, 208, 29, 29])
here
idx 2 xs[idx].shape torch.Size([16, 208, 29, 29])
idx 2 y shape torch.Size([16, 208, 29, 29])
here
x1 torch.Size([16, 832, 29, 29])
idx 0 xs[idx].shape torch.Size([16, 208, 29, 29])
idx 0 y 0
here
idx 1 xs[idx].shape torch.Size([16, 208, 29, 29])
idx 1 y shape torch.Size([16, 208, 29, 29])
here
idx 2 xs[idx].shape torch.Size([16, 208, 29, 29])
idx 2 y shape tor