In [1]:
import torch

## Implementation of the Trident block

In [2]:
class trident_block(torch.nn.Module):
    """Trident bottleneck block: three weight-sharing branches with different dilations.

    Every branch runs the same three convolution kernels (1x1 -> 3x3 -> 1x1, stored once
    as ``shared_weights4convolution1..3``). The branches differ only in the padding and
    dilation of the 3x3 convolution and in their own BatchNorm layers. Each branch maps
    ``input_channels`` to ``output_channels * expansion`` channels and adds a residual
    connection.

    params: input_channels  - channels of every input tensor
            output_channels - bottleneck width; each branch outputs output_channels * expansion
            stride          - stride of the shared 3x3 convolution
            padding         - padding of the 3x3 convolution for branches 1, 2, 3
            dilation        - dilation of the 3x3 convolution for branches 1, 2, 3
            downsample      - optional module applied to a branch input to build its residual;
                              required when the input shape differs from the branch output shape
    raises: ValueError if padding or dilation does not hold exactly three values
    """
    expansion = 4
    num_branches = 3

    def __init__(self, input_channels, output_channels,
                 stride = 1, padding = (1, 2, 3), dilation = (1, 2, 3), downsample = None):
        super(trident_block, self).__init__()
        if len(padding) != self.num_branches or len(dilation) != self.num_branches:
            raise ValueError("padding and dilation must each hold exactly 3 values (one per branch)")
        self.stride = stride
        self.padding = tuple(padding)
        self.dilation = tuple(dilation)
        self.downsample = downsample

        # conv2d expects weights laid out as (out_channels, in_channels, kH, kW)
        self.shared_weights4convolution1 = torch.nn.Parameter(
            torch.empty(output_channels, input_channels, 1, 1))
        self.shared_weights4convolution2 = torch.nn.Parameter(
            torch.empty(output_channels, output_channels, 3, 3))
        self.shared_weights4convolution3 = torch.nn.Parameter(
            torch.empty(output_channels * self.expansion, output_channels, 1, 1))
        # He initialisation (as for ResNet convolutions); unscaled randn (std 1) makes
        # pre-BatchNorm activations grow with the number of input channels
        for weight in (self.shared_weights4convolution1,
                       self.shared_weights4convolution2,
                       self.shared_weights4convolution3):
            torch.nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='relu')

        # Per-branch BatchNorms named bn<branch><conv>. They must be attributes of self so
        # that nn.Module registers them (trainable, moved by .to(), saved in state_dict).
        self.bn11 = torch.nn.BatchNorm2d(output_channels)
        self.bn12 = torch.nn.BatchNorm2d(output_channels)
        self.bn13 = torch.nn.BatchNorm2d(output_channels * self.expansion)

        self.bn21 = torch.nn.BatchNorm2d(output_channels)
        self.bn22 = torch.nn.BatchNorm2d(output_channels)
        self.bn23 = torch.nn.BatchNorm2d(output_channels * self.expansion)

        self.bn31 = torch.nn.BatchNorm2d(output_channels)
        self.bn32 = torch.nn.BatchNorm2d(output_channels)
        self.bn33 = torch.nn.BatchNorm2d(output_channels * self.expansion)

        self.relu1 = torch.nn.ReLU(inplace = True)
        self.relu2 = torch.nn.ReLU(inplace = True)
        self.relu3 = torch.nn.ReLU(inplace = True)

    def _forward_branch(self, x, branch):
        """Run branch number ``branch`` (0-based) on ``x`` and return its activated output."""
        # this branch's BatchNorms, e.g. bn21, bn22, bn23 for branch index 1
        bn_first, bn_second, bn_third = (getattr(self, 'bn{}{}'.format(branch + 1, conv))
                                         for conv in (1, 2, 3))
        residual = x
        # conv 1x1
        output = torch.nn.functional.conv2d(x, self.shared_weights4convolution1)
        output = self.relu1(bn_first(output))
        # conv 3x3 -- the only step where the branches differ (padding / dilation)
        output = torch.nn.functional.conv2d(output, self.shared_weights4convolution2,
                                            stride = self.stride,
                                            padding = self.padding[branch],
                                            dilation = self.dilation[branch])
        output = self.relu2(bn_second(output))
        # conv 1x1
        output = torch.nn.functional.conv2d(output, self.shared_weights4convolution3)
        output = bn_third(output)

        if self.downsample is not None:
            residual = self.downsample(x)

        output += residual
        return self.relu3(output)

    def forward_branch_1(self, x):
        """Branch 1: 3x3 convolution with padding[0] / dilation[0]."""
        return self._forward_branch(x, 0)

    def forward_branch_2(self, x):
        """Branch 2: 3x3 convolution with padding[1] / dilation[1]."""
        return self._forward_branch(x, 1)

    def forward_branch_3(self, x):
        """Branch 3: 3x3 convolution with padding[2] / dilation[2]."""
        return self._forward_branch(x, 2)

    def total_forward(self, x):
        """Run all three branches and return their outputs as a list of three tensors.

        ``x`` is either a single tensor, fed to every branch (e.g. the first trident block
        of a stage), or a sequence of three tensors, one per branch (e.g. the output of a
        previous trident block).
        raises: ValueError if a sequence with a length other than three is given
        """
        if isinstance(x, torch.Tensor):
            branch_inputs = [x] * self.num_branches
        else:
            branch_inputs = list(x)
            if len(branch_inputs) != self.num_branches:
                raise ValueError("expected a tensor or a sequence of 3 tensors, got {} tensors"
                                 .format(len(branch_inputs)))
        return [self._forward_branch(branch_input, branch)
                for branch, branch_input in enumerate(branch_inputs)]

    def forward(self, x):
        """nn.Module entry point; delegates to ``total_forward``."""
        return self.total_forward(x)

## Implementation of the BottleNeck block

In [3]:
class BottleNeck(torch.nn.Module):
    """ResNet bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Maps ``input_channels`` to ``output_channels * expansion`` channels.

    params: input_channels  - channels of the input tensor
            output_channels - bottleneck width (channels of the 1x1 and 3x3 convolutions)
            downsample      - optional module applied to the input to form the residual;
                              needed whenever the input shape differs from the output shape
            stride          - stride of the 3x3 convolution (default 1)
    """
    expansion = 4
    def __init__(self, input_channels, output_channels, downsample = None, stride = 1):
        super(BottleNeck, self).__init__()
        # conv's structures: each convolution consumes the previous one's output channels
        self.conv_1 = torch.nn.Conv2d(input_channels, output_channels, kernel_size=1, bias=False)
        self.conv_2 = torch.nn.Conv2d(output_channels, output_channels, kernel_size=3,
                                      stride=stride, padding=1, bias=False)
        self.conv_3 = torch.nn.Conv2d(output_channels, output_channels * self.expansion,
                                      kernel_size=1, bias=False)
        # normalizations
        self.bn1 = torch.nn.BatchNorm2d(output_channels)
        self.bn2 = torch.nn.BatchNorm2d(output_channels)
        self.bn3 = torch.nn.BatchNorm2d(output_channels * self.expansion)
        # activation and downsample
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(bn3(conv_3(...conv_1(x))) + residual)."""
        residual = x
        # conv 1x1
        output = self.conv_1(x)
        output = self.bn1(output)
        output = self.relu(output)
        # conv 3x3
        output = self.conv_2(output)
        output = self.bn2(output)
        output = self.relu(output)
        # conv 1x1
        output = self.conv_3(output)
        output = self.bn3(output)

        if self.downsample is not None:
            residual = self.downsample(x)

        output += residual
        output = self.relu(output)

        return output

## Implementation of the basic residual block (`Basic_Block`)

In [4]:
class Basic_Block(torch.nn.Module):
    """ResNet basic residual block: two 3x3 convolutions plus a shortcut.

    Maps ``input_channels`` to ``output_channels`` channels (``expansion`` is 1).

    params: input_channels  - channels of the input tensor
            output_channels - channels of both 3x3 convolutions and of the output
            downsample      - optional module applied to the input to form the residual;
                              needed whenever the input shape differs from the output shape
            stride          - stride of the first 3x3 convolution (default 1)
    """
    expansion = 1
    def __init__(self, input_channels, output_channels, downsample = None, stride = 1):
        super(Basic_Block, self).__init__()
        # conv's structures: conv2 consumes conv1's output channels, not the block input
        self.conv1 = torch.nn.Conv2d(input_channels, output_channels, kernel_size=3,
                                     stride=stride, padding=1, bias=False)
        self.conv2 = torch.nn.Conv2d(output_channels, output_channels, kernel_size=3,
                                     padding=1, bias=False)
        # normalizations
        self.bn1 = torch.nn.BatchNorm2d(output_channels)
        self.bn2 = torch.nn.BatchNorm2d(output_channels)
        # activation and downsample
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + residual)."""
        residual = x
        # first 3x3 convolution
        output = self.conv1(x)
        output = self.bn1(output)
        output = self.relu(output)
        # second 3x3 convolution
        output = self.conv2(output)
        output = self.bn2(output)

        if self.downsample is not None:
            residual = self.downsample(x)
        output += residual
        output = self.relu(output)

        return output

In [14]:
class ResNet(torch.nn.Module):
    """ResNet-50/101/152 backbone assembled from ``BottleNeck`` blocks -- work in progress.

    Built so far: the 7x7 stem (conv1 / bn1 / relu / max_pooling) and the first residual
    stage ``layer1`` (stride 1, 64 -> 256 channels).
    TODO: stages 2-4 (they need a strided bottleneck), global pooling, the
    ``num_classes`` classifier head and ``forward``.
    """
    # number of BottleNeck blocks in each of the four residual stages
    LAYER_CONFIGS = {
        'ResNet-50':  (3, 4, 6, 3),
        'ResNet-101': (3, 4, 23, 3),
        'ResNet-152': (3, 8, 36, 3),
    }

    def __init__(self, num_classes, net_type):
        """
        params: num_classes - number of output classes (stored; classifier head is TODO)
                net_type    - one of 'ResNet-50', 'ResNet-101', 'ResNet-152'
        raises: ValueError if net_type is not one of the supported names
        """
        super(ResNet, self).__init__()
        if net_type not in self.LAYER_CONFIGS:
            raise ValueError("unknown net_type {!r}; expected one of {}"
                             .format(net_type, sorted(self.LAYER_CONFIGS)))
        self.net_type = net_type
        self.num_classes = num_classes
        self.layer_config = self.LAYER_CONFIGS[net_type]
        print("Net_type is " + net_type)

        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1   = torch.nn.BatchNorm2d(64)
        self.relu  = torch.nn.ReLU(inplace=True)

        self.max_pooling = torch.nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # channel count entering the next stage; updated by _make_layer
        self._stage_input_channels = 64
        self.layer1 = self._make_layer(BottleNeck, 64, self.layer_config[0])

    def _make_layer(self, block, output_channels, num_blocks):
        """Stack ``num_blocks`` stride-1 residual blocks into a ``torch.nn.Sequential``.

        The first block gets a 1x1 conv + BatchNorm shortcut when the incoming channel
        count differs from ``output_channels * block.expansion``.
        """
        stage_channels = output_channels * block.expansion
        downsample = None
        if self._stage_input_channels != stage_channels:
            downsample = torch.nn.Sequential(
                torch.nn.Conv2d(self._stage_input_channels, stage_channels, kernel_size=1, bias=False),
                torch.nn.BatchNorm2d(stage_channels),
            )
        blocks = [block(self._stage_input_channels, output_channels, downsample=downsample)]
        self._stage_input_channels = stage_channels
        blocks.extend(block(stage_channels, output_channels) for _ in range(1, num_blocks))
        return torch.nn.Sequential(*blocks)

In [15]:
# Build a ResNet-50 backbone for a 10-class problem.
mod = ResNet(net_type='ResNet-50', num_classes=10)

Net_type is ResNet-50
