In [44]:
# https://github.com/km1414/CNN-models/blob/master/resnet-32/resnet-32.py
# https://github.com/safwankdb/ResNet34-TF2/blob/master/model.py
# https://arxiv.org/pdf/1512.03385.pdf

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

class Block(nn.Module):
    """Basic two-convolution residual block (ResNet-18/34 style).

    Main path: 3x3 conv (stride ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    Shortcut: identity, or a 1x1 conv (stride ``stride``) + BN projection when
    the channel count or the spatial resolution changes. Both paths are summed
    and passed through a final ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 conv and of the projection shortcut.
        verbose: when True (the default) ``forward`` prints the input and
            output sizes on every call; pass False to silence it.
    """

    def __init__(self, in_channels, out_channels, stride = 1, verbose = True):
        super(Block, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.verbose = verbose

        if self._needs_projection():
            # 1x1 projection so the shortcut matches the main path's shape.
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size = 1, bias = False, stride = stride),
                nn.BatchNorm2d(out_channels),
            )

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias = False, stride = stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias = False),
            nn.BatchNorm2d(out_channels),
        )

        self.activation = nn.ReLU(inplace=True)

    def _needs_projection(self):
        """Return True when the shortcut must be a projection rather than identity."""
        return self.in_channels != self.out_channels or self.stride > 1

    def forward(self, x):
        """Return ReLU(features(x) + shortcut(x))."""
        if self.verbose:
            print("Before: {}".format(x.size()))
        out = self.features(x)
        shortcut = self.skip_connection(x) if self._needs_projection() else x
        # Sum once and reuse it for both the debug print and the activation.
        out = out + shortcut
        if self.verbose:
            print("After: {}".format(out.size()))
        # In-place ReLU is safe: ``out`` is a fresh tensor owned by this call.
        return self.activation(out)

class ResNet34(nn.Module):
    """ResNet-34 classifier built from basic residual ``Block``s.

    Layout (34-layer column of Table 1 in He et al., arXiv:1512.03385):
    a 7x7/2 stem conv + BN + ReLU + 3x3/2 max-pool, then four stages of
    3, 4, 6 and 3 blocks with 64, 128, 256 and 512 channels. The first
    block of each later stage uses stride 2. The stages are followed by
    global average pooling and a linear head.

    Args:
        num_classes: size of the output logits.
        in_channels: channels of the input image (3 for RGB, the previous
            hard-coded value).
        verbose: when True (the default) ``forward`` prints tensor sizes
            around the stem, the average pool and the linear layer. Each
            ``Block`` prints its own sizes independently of this flag.
    """

    def __init__(self, num_classes = 12, in_channels = 3, verbose = True):
        super(ResNet34, self).__init__()
        self.verbose = verbose

        # 1st block
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(64)
        self.activation1 = nn.ReLU(inplace = True)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding = 1)

        # 2nd block
        self.conv2_1 = Block(in_channels = 64, out_channels = 64)
        self.conv2_2 = Block(in_channels = 64, out_channels = 64)
        self.conv2_3 = Block(in_channels = 64, out_channels = 64)

        # 3rd block
        self.conv3_1 = Block(in_channels = 64, out_channels = 128, stride = 2)
        self.conv3_2 = Block(in_channels = 128, out_channels = 128)
        self.conv3_3 = Block(in_channels = 128, out_channels = 128)
        self.conv3_4 = Block(in_channels = 128, out_channels = 128)

        # 4th block
        self.conv4_1 = Block(in_channels = 128, out_channels = 256, stride = 2)
        self.conv4_2 = Block(in_channels = 256, out_channels = 256)
        self.conv4_3 = Block(in_channels = 256, out_channels = 256)
        self.conv4_4 = Block(in_channels = 256, out_channels = 256)
        self.conv4_5 = Block(in_channels = 256, out_channels = 256)
        self.conv4_6 = Block(in_channels = 256, out_channels = 256)

        # 5th block
        self.conv5_1 = Block(in_channels = 256, out_channels = 512, stride = 2)
        self.conv5_2 = Block(in_channels = 512, out_channels = 512)
        self.conv5_3 = Block(in_channels = 512, out_channels = 512)

        # avg pool: collapses any spatial size to 1x1, so the head is size-agnostic
        self.pool2 = nn.AdaptiveAvgPool2d((1, 1))

        # Fully connected
        self.ff = nn.Linear(512, num_classes)

    def _log(self, tag, x):
        """Print ``"<tag>: <size of x>"`` when verbose output is enabled."""
        if self.verbose:
            print("{}: {}".format(tag, x.size()))

    def _residual_blocks(self):
        """Return the residual blocks of stages 2-5 in forward order."""
        # Attribute names are kept unchanged so state_dict keys stay compatible.
        return (
            self.conv2_1, self.conv2_2, self.conv2_3,
            self.conv3_1, self.conv3_2, self.conv3_3, self.conv3_4,
            self.conv4_1, self.conv4_2, self.conv4_3,
            self.conv4_4, self.conv4_5, self.conv4_6,
            self.conv5_1, self.conv5_2, self.conv5_3,
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        # 1st block (stem)
        self._log("Before", x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation1(x)
        x = self.pool1(x)
        self._log("After", x)

        # 2nd-5th blocks
        for block in self._residual_blocks():
            x = block(x)

        # avg pool
        self._log("Before", x)
        x = self.pool2(x)
        self._log("After", x)

        # flatten (batch, 512, 1, 1) -> (batch, 512)
        x = torch.flatten(x, 1)

        # ff connected
        self._log("Before", x)
        x = self.ff(x)
        self._log("After", x)
        return x
    

def main():
    """Smoke-test ResNet34 on one random 224x224 RGB image and print the output size."""
    net = ResNet34()
    # Inference mode: BatchNorm uses its running statistics, and no autograd
    # graph is recorded for this pure shape check.
    net.eval()
    # torch.randn yields float32 directly; the legacy torch.Tensor(ndarray)
    # constructor is discouraged.
    img = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = net(img)
    print(out.size())


if __name__ == '__main__':
    main()

Before: torch.Size([1, 3, 224, 224])
After: torch.Size([1, 64, 56, 56])
Before: torch.Size([1, 64, 56, 56])
After: torch.Size([1, 64, 56, 56])
Before: torch.Size([1, 64, 56, 56])
After: torch.Size([1, 64, 56, 56])
Before: torch.Size([1, 64, 56, 56])
After: torch.Size([1, 64, 56, 56])
Before: torch.Size([1, 64, 56, 56])
After: torch.Size([1, 128, 28, 28])
Before: torch.Size([1, 128, 28, 28])
After: torch.Size([1, 128, 28, 28])
Before: torch.Size([1, 128, 28, 28])
After: torch.Size([1, 128, 28, 28])
Before: torch.Size([1, 128, 28, 28])
After: torch.Size([1, 128, 28, 28])
Before: torch.Size([1, 128, 28, 28])
After: torch.Size([1, 256, 14, 14])
Before: torch.Size([1, 256, 14, 14])
After: torch.Size([1, 256, 14, 14])
Before: torch.Size([1, 256, 14, 14])
After: torch.Size([1, 256, 14, 14])
Before: torch.Size([1, 256, 14, 14])
After: torch.Size([1, 256, 14, 14])
Before: torch.Size([1, 256, 14, 14])
After: torch.Size([1, 256, 14, 14])
Before: torch.Size([1, 256, 14, 14])
After: torch.Size([1, 

In [8]:
import torch
# Load the checkpoint. It is a state dict, so iterating it yields parameter
# and buffer names. map_location="cpu" lets it open on machines without CUDA
# even if it was saved from GPU tensors.
model = torch.load("/home/arnab/Desktop/Data/ResNet34_EPOCH_1_trained_model.pt", map_location="cpu")
# Print every key with its index. A "<group>--------" header precedes each run
# of 6 keys: one conv weight plus the 5 entries of the following BatchNorm.
for i, m in enumerate(model):
    if i % 6 == 0:
        print(str(i // 6) + "--------")
    print(i, m)

0--------
0 conv1.weight
1 bn1.weight
2 bn1.bias
3 bn1.running_mean
4 bn1.running_var
5 bn1.num_batches_tracked
1--------
6 layer1.0.conv1.weight
7 layer1.0.bn1.weight
8 layer1.0.bn1.bias
9 layer1.0.bn1.running_mean
10 layer1.0.bn1.running_var
11 layer1.0.bn1.num_batches_tracked
2--------
12 layer1.0.conv2.weight
13 layer1.0.bn2.weight
14 layer1.0.bn2.bias
15 layer1.0.bn2.running_mean
16 layer1.0.bn2.running_var
17 layer1.0.bn2.num_batches_tracked
3--------
18 layer1.1.conv1.weight
19 layer1.1.bn1.weight
20 layer1.1.bn1.bias
21 layer1.1.bn1.running_mean
22 layer1.1.bn1.running_var
23 layer1.1.bn1.num_batches_tracked
4--------
24 layer1.1.conv2.weight
25 layer1.1.bn2.weight
26 layer1.1.bn2.bias
27 layer1.1.bn2.running_mean
28 layer1.1.bn2.running_var
29 layer1.1.bn2.num_batches_tracked
5--------
30 layer1.2.conv1.weight
31 layer1.2.bn1.weight
32 layer1.2.bn1.bias
33 layer1.2.bn1.running_mean
34 layer1.2.bn1.running_var
35 layer1.2.bn1.num_batches_tracked
6--------
36 layer1.2.conv2.weig

In [41]:
import torch
model = torch.load("/home/arnab/Desktop/Data/ResNet34_EPOCH_1_trained_model.pt")

# One entry per run of consecutive state-dict keys, in checkpoint order:
#     (out_channel, num_model_parameters[, stride])
# num_model_parameters counts state-dict entries. A conv contributes 1 and a
# BatchNorm 5 (weight, bias, running_mean, running_var, num_batches_tracked),
# so the stem (conv1 + bn1) has 6 and a basic block (2 convs + 2 BNs) has 12.
# The stride-2 entries count 18, presumably the extra 1x1 projection conv + BN.
cfgs = [
    # stem
    (64, 6),
    # stage 1: 3 blocks
    (64, 12), (64, 12), (64, 12),
    # stage 2: 4 blocks, the first downsamples
    (128, 18, 2), (128, 12), (128, 12), (128, 12),
    # stage 3: 6 blocks, the first downsamples
    (256, 18, 2), (256, 12), (256, 12), (256, 12), (256, 12), (256, 12),
    # stage 4: 3 blocks, the first downsamples
    (512, 18, 2), (512, 12), (512, 12),
]


def create_model_list(model):
    """Return the items produced by iterating ``model`` as a new list.

    Args:
        model: any iterable. For a mapping such as a state dict, iteration
            yields its keys, in insertion order.

    Returns:
        A new list with those items, in iteration order.
    """
    return list(model)

def create_model_dict(cfgs,model_list):
    """Split ``model_list`` into consecutive chunks, one chunk per entry of ``cfgs``.

    Args:
        cfgs: sequence of ``(out_channel, num_model_parameters)`` or
            ``(out_channel, num_model_parameters, stride)`` tuples. Only
            ``num_model_parameters`` (the chunk length) is used here.
        model_list: flat sequence of items (e.g. state-dict keys) to partition.

    Returns:
        dict mapping each cfg entry's index to a list of the next
        ``num_model_parameters`` items of ``model_list``. Items left over
        after the last chunk are ignored.

    Raises:
        ValueError: if a cfg entry does not have exactly 2 or 3 elements.
        IndexError: if ``model_list`` runs out before every chunk is filled.
    """
    model_dict = {}
    start = 0
    for i, cfg in enumerate(cfgs):
        # Reject malformed entries instead of silently reusing the previous
        # entry's size.
        if len(cfg) not in (2, 3):
            raise ValueError(
                "cfg entry {} must be (out_channel, num_model_parameters[, stride]), "
                "got {!r}".format(i, cfg)
            )
        num_model_parameters = cfg[1]
        end = start + num_model_parameters
        # Slicing would silently truncate, so fail loudly on a short list.
        if end > len(model_list):
            raise IndexError(
                "cfg entry {} needs items [{}:{}] but model_list has only {}".format(
                    i, start, end, len(model_list)
                )
            )
        model_dict[i] = list(model_list[start:end])
        start = end
    return model_dict

# Flatten the checkpoint's keys, then group them according to cfgs.
model_list = create_model_list(model)
model_dict = create_model_dict(cfgs,model_list)
# First group: the stem's 6 keys (conv1 + bn1), kept for inspection.
m = model_dict[0]
# Size of the classifier bias; presumably its length equals the number of
# output classes (torch.Size([12]) in the recorded output).
print(model['fc.bias'].size())

torch.Size([12])


In [14]:
# Map each of 0..9 to its own fresh list [0, 1, 2, 3, 4]. The comprehension
# builds a new list per key, so no manual ``temp = []`` reset is needed to
# keep the values from aliasing each other.
dict_ = {i: list(range(5)) for i in range(10)}
print(dict_)

{0: [0, 1, 2, 3, 4], 1: [0, 1, 2, 3, 4], 2: [0, 1, 2, 3, 4], 3: [0, 1, 2, 3, 4], 4: [0, 1, 2, 3, 4], 5: [0, 1, 2, 3, 4], 6: [0, 1, 2, 3, 4], 7: [0, 1, 2, 3, 4], 8: [0, 1, 2, 3, 4], 9: [0, 1, 2, 3, 4]}


In [37]:
class Block(nn.Module):
    """Basic residual block with flat submodule names (conv1/bn1, conv2/bn2, conv3/bn3).

    Main path: conv1 (3x3, stride ``stride``) -> bn1 -> ReLU -> conv2 (3x3) -> bn2.
    Shortcut: identity, or conv3 (1x1, stride ``stride``) -> bn3 when the
    channel count changes or ``stride > 1``. Output: ReLU(main + shortcut).
    The flat names are presumably chosen to line up with checkpoint keys such
    as ``layer1.0.conv1.weight``; confirm before relying on it.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of conv1 and of the conv3 projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride = 1):
        super(Block, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

        # Projection shortcut, only when the shapes of the two paths differ.
        if (in_channels != out_channels) or stride > 1:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size = 1, bias = False, stride = stride)
            self.bn3= nn.BatchNorm2d(out_channels)


        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias = False, stride = stride)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias = False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ReLU(bn2(conv2(ReLU(bn1(conv1(x))))) + shortcut(x)).

        Previously missing: without it, calling the module raised
        NotImplementedError. Mirrors the residual computation of the
        Sequential-based ``Block`` defined earlier in this notebook.
        """
        # In-place ReLU only ever touches fresh intermediate tensors, never x.
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.in_channels != self.out_channels or self.stride > 1:
            shortcut = self.bn3(self.conv3(x))
        else:
            shortcut = x
        return self.activation(out + shortcut)

# Inspect the state-dict entries registered by a single convolution. conv1 is
# built with bias=False, so only 'weight' is listed.
net = Block(in_channels=16,out_channels=16)
print(net.conv1.state_dict().keys())

odict_keys(['weight'])
