In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from collections import OrderedDict

# pre-trained model weight 讀取並查看
可使用 PyTorch 的 torchvision.models 中所提供的模型權重，也可以使用自己訓練或下載的模型權重檔。

## 使用 PyTorch 提供的 pre-trained model weight

In [2]:
# Load torchvision's ResNet-18 with its bundled pretrained weights and grab its state dict.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of the
# `weights=` argument — confirm against the installed torchvision version.
model = models.resnet18(pretrained=True)
# state_dict() maps parameter/buffer names to tensors (the output shows an OrderedDict).
model_state = model.state_dict()

print("model_state type:", type(model_state))

model_state type: <class 'collections.OrderedDict'>


In [3]:
# Walk the pretrained state dict and show each key followed by its full tensor.
for param_tensor, value in model_state.items():
    print("name:", param_tensor)
    print("value:", value)

name: conv1.weight
value: tensor([[[[-1.0419e-02, -6.1356e-03, -1.8098e-03,  ...,  5.6615e-02,
            1.7083e-02, -1.2694e-02],
          [ 1.1083e-02,  9.5276e-03, -1.0993e-01,  ..., -2.7124e-01,
           -1.2907e-01,  3.7424e-03],
          [-6.9434e-03,  5.9089e-02,  2.9548e-01,  ...,  5.1972e-01,
            2.5632e-01,  6.3573e-02],
          ...,
          [-2.7535e-02,  1.6045e-02,  7.2595e-02,  ..., -3.3285e-01,
           -4.2058e-01, -2.5781e-01],
          [ 3.0613e-02,  4.0960e-02,  6.2850e-02,  ...,  4.1384e-01,
            3.9359e-01,  1.6606e-01],
          [-1.3736e-02, -3.6746e-03, -2.4084e-02,  ..., -1.5070e-01,
           -8.2230e-02, -5.7828e-03]],

         [[-1.1397e-02, -2.6619e-02, -3.4641e-02,  ...,  3.2521e-02,
            6.6221e-04, -2.5743e-02],
          [ 4.5687e-02,  3.3603e-02, -1.0453e-01,  ..., -3.1253e-01,
           -1.6051e-01, -1.2826e-03],
          [-8.3730e-04,  9.8420e-02,  4.0210e-01,  ...,  7.0789e-01,
            3.6887e-01,  1.2455e

value: tensor([0.0185, 0.0264, 0.0129, 0.0127, 0.0148, 0.0160, 0.0125, 0.0181, 0.0194,
        0.0116, 0.0075, 0.0112, 0.0349, 0.0177, 0.0208, 0.0254, 0.0236, 0.0198,
        0.0229, 0.0344, 0.0199, 0.0538, 0.0316, 0.0196, 0.0128, 0.0192, 0.0200,
        0.0287, 0.0182, 0.0273, 0.0113, 0.0118, 0.0211, 0.0085, 0.0169, 0.0159,
        0.0126, 0.0156, 0.0100, 0.0089, 0.0185, 0.0156, 0.0082, 0.0131, 0.0117,
        0.0102, 0.0105, 0.0244, 0.0140, 0.0161, 0.0112, 0.0163, 0.0295, 0.0355,
        0.0150, 0.0201, 0.0174, 0.0119, 0.0201, 0.0351, 0.0188, 0.0121, 0.0134,
        0.0237, 0.0242, 0.0176, 0.0283, 0.0147, 0.0105, 0.0169, 0.0247, 0.0185,
        0.0104, 0.0167, 0.0137, 0.0225, 0.0115, 0.0248, 0.0120, 0.0126, 0.0219,
        0.0170, 0.0174, 0.0217, 0.0176, 0.0198, 0.0237, 0.0123, 0.0272, 0.0605,
        0.0114, 0.0128, 0.0200, 0.0199, 0.0225, 0.0195, 0.0165, 0.0230, 0.0188,
        0.0209, 0.0215, 0.0241, 0.0126, 0.0144, 0.0127, 0.0098, 0.0165, 0.0163,
        0.0403, 0.0153, 0.0141, 0

## 使用自己訓練或下載的權重檔

In [4]:
import os

import gdown

resnet_model = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
# Local file name read by the torch.load(...) cells below (it differs from the URL's "resnet18-..." name).
weights_path = "resnet-5c106cde.pth"

# Skip the ~47 MB download when the file already exists, so Restart & Run All stays cheap
# and also works offline once the weights are cached.
if not os.path.exists(weights_path):
    gdown.download(resnet_model, weights_path)

Downloading...
From: https://download.pytorch.org/models/resnet18-5c106cde.pth
To: C:\Users\joyle\pythonwork\torch_work\resnet-5c106cde.pth
100%|█████████████████████████████████████████████████████████████████████████████| 46.8M/46.8M [00:05<00:00, 8.63MB/s]

model has been downloaded.





In [4]:
# Load the downloaded weight file. For this file the result is a plain OrderedDict of
# parameter name -> tensor (see output), not a full model object.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load('resnet-5c106cde.pth')
print("checkpoint type:", type(checkpoint))

checkpoint type: <class 'collections.OrderedDict'>


In [5]:
# Show every key/tensor pair stored in the downloaded file. The values print as
# "Parameter containing:", unlike the plain tensors returned by model.state_dict().
for k, v in checkpoint.items():
    print("name:", k)
    print("value:", v)

name: conv1.weight
value: Parameter containing:
tensor([[[[-1.0419e-02, -6.1356e-03, -1.8098e-03,  ...,  5.6615e-02,
            1.7083e-02, -1.2694e-02],
          [ 1.1083e-02,  9.5276e-03, -1.0993e-01,  ..., -2.7124e-01,
           -1.2907e-01,  3.7424e-03],
          [-6.9434e-03,  5.9089e-02,  2.9548e-01,  ...,  5.1972e-01,
            2.5632e-01,  6.3573e-02],
          ...,
          [-2.7535e-02,  1.6045e-02,  7.2595e-02,  ..., -3.3285e-01,
           -4.2058e-01, -2.5781e-01],
          [ 3.0613e-02,  4.0960e-02,  6.2850e-02,  ...,  4.1384e-01,
            3.9359e-01,  1.6606e-01],
          [-1.3736e-02, -3.6746e-03, -2.4084e-02,  ..., -1.5070e-01,
           -8.2230e-02, -5.7828e-03]],

         [[-1.1397e-02, -2.6619e-02, -3.4641e-02,  ...,  3.2521e-02,
            6.6221e-04, -2.5743e-02],
          [ 4.5687e-02,  3.3603e-02, -1.0453e-01,  ..., -3.1253e-01,
           -1.6051e-01, -1.2826e-03],
          [-8.3730e-04,  9.8420e-02,  4.0210e-01,  ...,  7.0789e-01,
          

value: tensor([-8.2312e-02, -3.3227e-02, -2.6646e-02, -1.3217e-02, -6.3828e-02,
        -1.0469e-01, -6.7094e-02, -5.3039e-02,  6.2327e-02, -1.1060e-04,
        -1.2984e-02,  1.7784e-02, -4.9284e-02, -6.5327e-02, -9.8474e-02,
         2.1125e-01, -9.2193e-02, -1.1394e-01, -1.8788e-01, -8.2167e-02,
        -5.8521e-02, -5.5994e-01, -5.1197e-02, -1.9867e-01, -7.1481e-02,
        -6.6465e-02, -1.6837e-01, -1.0312e-01,  1.0545e-01, -2.3061e-01,
        -3.3933e-02, -8.0676e-02, -3.4662e-03,  4.7170e-02, -7.8841e-02,
        -1.2650e-01, -6.1497e-02,  6.4309e-02,  2.7830e-02,  3.1326e-03,
        -7.6740e-02, -9.5088e-02,  1.3829e-04, -4.7186e-02, -3.0897e-02,
        -4.6435e-02, -6.3253e-02, -8.1402e-02, -7.0531e-02, -6.7097e-02,
        -7.9232e-02, -1.7778e-01,  5.5108e-02,  4.1108e-01,  1.5609e-03,
        -7.9884e-02, -6.9195e-02, -4.9812e-02,  3.8361e-03,  2.5202e-03,
         3.5645e-02, -2.5757e-02, -1.2249e-01, -4.6376e-02, -1.2163e-01,
        -6.4273e-02, -1.3834e-01, -2.3200e-0

# 修改 layer
這部分會分為修改 layer 的參數和名稱

## 參數
### 最後一層的輸出值

In [6]:
# Reload a pretrained ResNet-18 and display its final fully connected layer
# (512 inputs -> 1000 ImageNet classes, per the output).
model = models.resnet18(pretrained=True)
model.fc

Linear(in_features=512, out_features=1000, bias=True)

In [7]:
# Replace the 1000-way classifier head with a new Linear layer that has num_class outputs.
# The input width is taken from the existing head (512 here); the new layer's weights are
# freshly initialised, not pretrained.
in_features = model.fc.in_features
num_class = 10

model.fc = nn.Linear(in_features, num_class)
print(model.fc)

Linear(in_features=512, out_features=10, bias=True)


In [8]:
# Re-read the state dict so it reflects the replaced head, then check the new shapes.
model_state = model.state_dict()
print("weight: ", model_state['fc.weight'].shape)
print("bias: ", model_state['fc.bias'].shape)

weight:  torch.Size([10, 512])
bias:  torch.Size([10])


In [9]:
# Show the freshly initialised (not pretrained) weights and bias of the new head.
print("weight: ", model_state['fc.weight'])
print("bias: ", model_state['fc.bias'])

weight:  tensor([[-0.0178,  0.0195,  0.0206,  ..., -0.0246, -0.0031,  0.0218],
        [ 0.0218, -0.0324, -0.0375,  ..., -0.0089,  0.0187, -0.0119],
        [ 0.0048, -0.0325, -0.0199,  ..., -0.0054, -0.0382, -0.0221],
        ...,
        [-0.0324, -0.0263, -0.0365,  ...,  0.0139,  0.0190,  0.0053],
        [ 0.0231, -0.0329,  0.0187,  ...,  0.0001, -0.0297,  0.0108],
        [ 0.0147,  0.0201, -0.0160,  ...,  0.0025,  0.0128,  0.0210]])
bias:  tensor([ 0.0158,  0.0223,  0.0025, -0.0293,  0.0103,  0.0068, -0.0419,  0.0389,
         0.0371, -0.0370])


### 某層參數 
* 使用 pytorch 提供的 pre-trained model 權重

In [10]:
# Start again from an unmodified pretrained ResNet-18 and take its state dict for editing.
model = models.resnet18(pretrained=True)
model_state = model.state_dict()

In [15]:
# List the state-dict keys of the classifier head. startswith('fc') matches any key that
# begins with those two letters; in ResNet-18 this is only fc.weight and fc.bias (see output).
for param_tensor in model_state:
    if param_tensor.startswith('fc'):
        print("name:", param_tensor)

name: fc.weight
name: fc.bias


In [16]:
# Replace the fc weight entry in the *dict* with a random (10, 512) tensor.
# NOTE: this only rebinds the dict entry. The model's own fc layer is still
# Linear(512, 1000), so loading this dict back into `model` would be rejected with a
# shape mismatch unless model.fc is replaced first.
print("org: ", model_state['fc.weight'].shape)
model_state['fc.weight'] = torch.rand((10, 512))
print("now: ", model_state['fc.weight'].shape)

org:  torch.Size([1000, 512])
now:  torch.Size([10, 512])


In [17]:
# Same for the bias: a (10,) tensor of ones replaces the (1000,) pretrained bias,
# again only inside the dict, not inside the model.
print("org: ", model_state['fc.bias'].shape)
model_state['fc.bias'] = torch.ones(10)
print("now: ", model_state['fc.bias'].shape)

org:  torch.Size([1000])
now:  torch.Size([10])


In [18]:
# Print the two replaced head entries to confirm their new values.
for param_tensor in model_state:
    if param_tensor in ['fc.weight', 'fc.bias']:
        print("name:", param_tensor)
        print("value:", model_state[param_tensor])

name: fc.weight
value: tensor([[0.4836, 0.1596, 0.4073,  ..., 0.1078, 0.6504, 0.2826],
        [0.6172, 0.7873, 0.0915,  ..., 0.9723, 0.0903, 0.8155],
        [0.1876, 0.0691, 0.0087,  ..., 0.6894, 0.4616, 0.0709],
        ...,
        [0.7497, 0.3456, 0.3441,  ..., 0.1699, 0.3600, 0.4801],
        [0.1419, 0.5504, 0.0085,  ..., 0.6065, 0.9412, 0.6216],
        [0.4666, 0.0663, 0.2911,  ..., 0.0243, 0.8549, 0.3349]])
name: fc.bias
value: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])


* 使用自己訓練或下載的權重檔

In [19]:
# Reload the downloaded weights so the edits below start from the original 1000-class head.
checkpoint = torch.load('resnet-5c106cde.pth')

In [20]:
# Print only the classifier-head entries of the downloaded checkpoint.
for k, v in checkpoint.items():
    if k in ['fc.weight', 'fc.bias']:
        print("name:", k)
        print("value:", v)

name: fc.weight
value: Parameter containing:
tensor([[-0.0185, -0.0705, -0.0518,  ..., -0.0390,  0.1735, -0.0410],
        [-0.0818, -0.0944,  0.0174,  ...,  0.2028, -0.0248,  0.0372],
        [-0.0332, -0.0566, -0.0242,  ..., -0.0344, -0.0227,  0.0197],
        ...,
        [-0.0103,  0.0033, -0.0359,  ..., -0.0279, -0.0115,  0.0128],
        [-0.0359, -0.0353, -0.0296,  ..., -0.0330, -0.0110, -0.0513],
        [ 0.0021, -0.0248, -0.0829,  ...,  0.0417, -0.0500,  0.0663]],
       requires_grad=True)
name: fc.bias
value: Parameter containing:
tensor([-2.6341e-03,  3.0005e-03,  6.5581e-04, -2.6909e-02,  6.3637e-03,
         1.3260e-02, -1.1178e-02,  2.0639e-02, -3.6373e-03, -1.2325e-02,
        -1.2629e-02, -7.2057e-03, -1.9321e-02, -2.4960e-02, -1.1885e-02,
        -8.3259e-03, -9.5745e-03, -1.6658e-02,  9.1804e-03, -1.5354e-02,
         7.1358e-03,  3.0737e-02,  1.3239e-02, -7.7528e-03,  4.7448e-03,
         1.1175e-02,  1.5949e-02, -1.6712e-02, -1.0130e-03, -3.7167e-03,
         6.52

In [21]:
# Overwrite the checkpoint's fc weight with a random (10, 512) tensor for a 10-class head.
print("org: ", checkpoint['fc.weight'].shape)
checkpoint['fc.weight'] = torch.rand((10, 512))
print("now: ", checkpoint['fc.weight'].shape)

org:  torch.Size([1000, 512])
now:  torch.Size([10, 512])


In [22]:
# Overwrite the checkpoint's fc bias with ones of length 10 to match the new weight.
print("org: ", checkpoint['fc.bias'].shape)
checkpoint['fc.bias'] = torch.ones(10)
print("now: ", checkpoint['fc.bias'].shape)

org:  torch.Size([1000])
now:  torch.Size([10])


In [23]:
# Confirm the edited head entries. They are now plain tensors, not Parameters.
for k, v in checkpoint.items():
    if k in ['fc.weight', 'fc.bias']:
        print("name:", k)
        print("value:", v)

name: fc.weight
value: tensor([[0.2027, 0.7287, 0.3048,  ..., 0.4458, 0.6887, 0.3502],
        [0.4577, 0.7108, 0.3471,  ..., 0.6735, 0.6224, 0.3283],
        [0.9822, 0.5041, 0.3341,  ..., 0.7837, 0.7394, 0.6175],
        ...,
        [0.3742, 0.1213, 0.6600,  ..., 0.7086, 0.8760, 0.8814],
        [0.0372, 0.5985, 0.8533,  ..., 0.3187, 0.7410, 0.6001],
        [0.6699, 0.5557, 0.1950,  ..., 0.1332, 0.1001, 0.1326]])
name: fc.bias
value: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])


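## layer 名稱
若權重檔是從包裝過的模型存下來的，key 可能會多出前綴（例如 `resnet.conv1.weight`），和 torchvision 模型的 key（`conv1.weight`）對不上。這時需要先移除前綴，才能用 `load_state_dict` 載入。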

In [24]:
# Untrained ResNet-18: list each state-dict key with the shape of its tensor.
model = models.resnet18()

# Build the state dict once. Calling model.state_dict() inside the loop would rebuild the
# whole OrderedDict on every iteration just to read a single entry.
for param_tensor, value in model.state_dict().items():
    print("name:", param_tensor)
    print("value:", value.size())

name: conv1.weight
value: torch.Size([64, 3, 7, 7])
name: bn1.weight
value: torch.Size([64])
name: bn1.bias
value: torch.Size([64])
name: bn1.running_mean
value: torch.Size([64])
name: bn1.running_var
value: torch.Size([64])
name: bn1.num_batches_tracked
value: torch.Size([])
name: layer1.0.conv1.weight
value: torch.Size([64, 64, 3, 3])
name: layer1.0.bn1.weight
value: torch.Size([64])
name: layer1.0.bn1.bias
value: torch.Size([64])
name: layer1.0.bn1.running_mean
value: torch.Size([64])
name: layer1.0.bn1.running_var
value: torch.Size([64])
name: layer1.0.bn1.num_batches_tracked
value: torch.Size([])
name: layer1.0.conv2.weight
value: torch.Size([64, 64, 3, 3])
name: layer1.0.bn2.weight
value: torch.Size([64])
name: layer1.0.bn2.bias
value: torch.Size([64])
name: layer1.0.bn2.running_mean
value: torch.Size([64])
name: layer1.0.bn2.running_var
value: torch.Size([64])
name: layer1.0.bn2.num_batches_tracked
value: torch.Size([])
name: layer1.1.conv1.weight
value: torch.Size([64, 64, 3, 3

In [25]:
# Load a checkpoint whose keys carry a "resnet." prefix (see output), so they do not
# match torchvision's ResNet-18 key names directly.
# NOTE(review): 'resnet_weights.pth' is not created in the visible part of this notebook;
# it must already exist in the working directory — document where it comes from.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load('resnet_weights.pth')

for k, v in checkpoint.items():
    print("name:", k)
    print("value:", v.size())

name: resnet.conv1.weight
value: torch.Size([64, 3, 7, 7])
name: resnet.bn1.weight
value: torch.Size([64])
name: resnet.bn1.bias
value: torch.Size([64])
name: resnet.bn1.running_mean
value: torch.Size([64])
name: resnet.bn1.running_var
value: torch.Size([64])
name: resnet.bn1.num_batches_tracked
value: torch.Size([])
name: resnet.layer1.0.conv1.weight
value: torch.Size([64, 64, 3, 3])
name: resnet.layer1.0.bn1.weight
value: torch.Size([64])
name: resnet.layer1.0.bn1.bias
value: torch.Size([64])
name: resnet.layer1.0.bn1.running_mean
value: torch.Size([64])
name: resnet.layer1.0.bn1.running_var
value: torch.Size([64])
name: resnet.layer1.0.bn1.num_batches_tracked
value: torch.Size([])
name: resnet.layer1.0.conv2.weight
value: torch.Size([64, 64, 3, 3])
name: resnet.layer1.0.bn2.weight
value: torch.Size([64])
name: resnet.layer1.0.bn2.bias
value: torch.Size([64])
name: resnet.layer1.0.bn2.running_mean
value: torch.Size([64])
name: resnet.layer1.0.bn2.running_var
value: torch.Size([64])
n

In [26]:
# The checkpoint keys start with "resnet." (presumably saved from a wrapper module holding
# the network as `.resnet`); strip that prefix so they match torchvision's ResNet names.
key_prefix = 'resnet.'
state_dict = OrderedDict()

for k, v in checkpoint.items():
    # Only strip the prefix when it is actually present. Blindly slicing would chop the
    # first 7 characters off any unprefixed key (e.g. "fc.weight" -> "ht").
    state_dict[k[len(key_prefix):] if k.startswith(key_prefix) else k] = v

In [27]:
# Display the re-keyed state dict (keys without the "resnet." prefix).
state_dict

OrderedDict([('conv1.weight',
              tensor([[[[-3.2218e-02, -5.0886e-02, -1.1717e-02,  ...,  2.9374e-04,
                         -3.4718e-02,  7.2881e-03],
                        [-7.2209e-03, -4.8808e-02, -2.9471e-02,  ..., -3.6233e-02,
                         -4.9703e-02, -2.2587e-02],
                        [ 8.6528e-03,  1.3639e-02,  1.7644e-02,  ..., -1.2735e-02,
                          3.5807e-02,  5.8483e-02],
                        ...,
                        [-2.9141e-02,  2.6303e-02,  1.4357e-03,  ..., -2.9984e-02,
                          3.0706e-02,  1.3324e-02],
                        [ 1.5592e-02, -2.2142e-04,  6.7921e-02,  ..., -2.0002e-02,
                         -2.7611e-02,  3.3269e-02],
                        [-5.9207e-03, -1.3899e-02,  2.6648e-02,  ..., -1.1708e-02,
                          1.1330e-02, -1.1135e-02]],
              
                       [[-2.2242e-02, -3.0318e-02,  2.4385e-03,  ..., -4.0303e-04,
                          1.5125

In [28]:
# Load the re-keyed weights into the untrained ResNet-18 from In[24]. Loading is strict,
# so it only succeeds because every key now matches (see output).
model.load_state_dict(state_dict)

<All keys matched successfully>

# 新增 layer

In [29]:
class MyResNet18(nn.Module):
    """ResNet-style network with one extra conv stage (``layer5``) before the classifier.

    The stem, ``layer1``-``layer4`` and ``fc`` use the same attribute names as torchvision's
    ResNet, so a torchvision ResNet-18 state dict can fill the matching keys; ``layer5``
    keeps its own initialisation.

    Args:
        net_block: residual block class. It must expose a class attribute ``expansion`` and
            accept ``(in_channels, out_channels, stride, downsample)``.
        layers: number of blocks in each of the four stages, e.g. ``[2, 2, 2, 2]``.
        num_classes: number of output logits produced by ``fc``.
    """

    def __init__(self, net_block, layers, num_classes=10):
        super(MyResNet18, self).__init__()
        # Running channel count, advanced by net_block_layer as each stage is built.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self.net_block_layer(net_block, 64, layers[0])
        self.layer2 = self.net_block_layer(net_block, 128, layers[1], stride=2)
        self.layer3 = self.net_block_layer(net_block, 256, layers[2], stride=2)
        self.layer4 = self.net_block_layer(net_block, 512, layers[3], stride=2)

        ## ============== Newly added layer ============ ##
        # The input width is the channel count produced by layer4 (tracked in
        # self.in_channels), NOT layers[3], which is only the number of blocks in that stage.
        # The output width must equal fc's in_features, so a checkpoint head of shape
        # (num_classes, 512 * expansion) still fits.
        fc_in_features = 512 * net_block.expansion
        self.layer5 = nn.Sequential(nn.Conv2d(self.in_channels, fc_in_features, kernel_size=3, stride=2, padding=1),
                                    nn.BatchNorm2d(fc_in_features),
                                    nn.ReLU(inplace=True))
        ## ============================================= ##

        # Adaptive pooling reduces whatever spatial size layer5 produces to 1x1. A fixed
        # AvgPool2d(7) breaks once layer5's stride-2 conv shrinks the 7x7 map (for a
        # 224x224 input) to 4x4, and for any other input size.
        self.avgpooling = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(fc_in_features, num_classes)

        # Parameter initialisation: Kaiming-normal conv weights; BatchNorm scale 1, shift 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def net_block_layer(self, net_block, out_channels, num_blocks, stride=1):
        """Build one stage of ``num_blocks`` residual blocks and advance ``self.in_channels``.

        Only the first block may change resolution or width. It receives a 1x1 conv + BN
        ``downsample`` shortcut when ``stride != 1`` or the channel count changes.
        """
        downsample = None
        out_width = out_channels * net_block.expansion

        # The shortcut needs a projection when the block changes spatial size or channel count.
        if stride != 1 or self.in_channels != out_width:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_width),
            )

        blocks = [net_block(self.in_channels, out_channels, stride, downsample)]
        # With expansion == 1 this equals out_channels, so no special case is needed.
        self.in_channels = out_width
        blocks.extend(net_block(self.in_channels, out_channels, 1, None) for _ in range(1, num_blocks))

        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return logits of shape (N, num_classes) for an image batch ``x``.

        ``x`` must have 3 channels (conv1 is built with in_channels=3). It is presumably an
        (N, 3, H, W) float image batch — TODO confirm the expected normalisation.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpooling(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.layer5(x)

        x = self.avgpooling(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)

        return x
    
class basic_block(nn.Module):
    """Two-convolution residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus shortcut.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both 3x3 convolutions.
        stride: stride of the first convolution (spatial downsampling factor).
        downsample: optional module applied to the input to build the shortcut when the
            block changes shape; ``None`` means an identity shortcut.
    """

    # Multiplier applied to out_channels to get the block's output width (1 = no widening).
    expansion = 1

    def __init__(self, in_channels, out_channels, stride, downsample):
        super(basic_block, self).__init__()
        # Settings shared by both 3x3 convolutions.
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Projection for the shortcut path (None -> identity); supplied by the caller.
        self.downsample = downsample

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut

        return self.relu(y)

In [30]:
# Build the extended network with ResNet-18's stage layout (2 basic blocks per stage)
# and grab its freshly initialised state dict.
num_classes = 10

model = MyResNet18(basic_block, [2, 2, 2, 2], num_classes)
model_state = model.state_dict()

In [31]:
checkpoint = torch.load('resnet-5c106cde.pth')

# Replace the 1000-class ImageNet head with zeros shaped exactly like this model's fc
# entries. Hard-coding (num_classes, 512) silently breaks if fc's width ever changes.
checkpoint['fc.weight'] = torch.zeros_like(model_state['fc.weight'])
checkpoint['fc.bias'] = torch.zeros_like(model_state['fc.bias'])

In [32]:
# Keep only checkpoint entries that exist in the model AND have the same shape, so a
# renamed or resized layer is skipped (keeping the model's own init) instead of making
# load_state_dict raise a size-mismatch error.
pretrained_dict = {
    k: v for k, v in checkpoint.items()
    if k in model_state and v.shape == model_state[k].shape
}

In [33]:
# Overwrite the model's entries with the pretrained ones. Keys the checkpoint does not
# have (the new layer5) keep their fresh initialisation.
model_state.update(pretrained_dict)

In [34]:
# Load the merged dict back. It contains every model key, so strict loading succeeds.
model.load_state_dict(model_state)

<All keys matched successfully>

# 刪除 layer

In [33]:
# Fresh copy of the pretrained weights for the layer-deletion example.
checkpoint = torch.load('resnet-5c106cde.pth')

In [34]:
# List the keys of the second block of layer4 (the entries about to be deleted).
for k in list(checkpoint.keys()):
    if k.startswith('layer4.1'):
        print(k)

layer4.1.conv1.weight
layer4.1.bn1.running_mean
layer4.1.bn1.running_var
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.conv2.weight
layer4.1.bn2.running_mean
layer4.1.bn2.running_var
layer4.1.bn2.weight
layer4.1.bn2.bias


In [35]:
import numpy as np

# Trailing "." so only block 1 of layer4 matches; a bare "layer4.1" would also match
# keys such as "layer4.10." in a deeper network.
prefix_to_delete = 'layer4.1.'

# Collect the keys first, because deleting from a dict while iterating over it raises.
keys_to_delete = [k for k in checkpoint if k.startswith(prefix_to_delete)]
for k in keys_to_delete:
    del checkpoint[k]

# Verify the deletion succeeded. One pass over the remaining keys is enough.
a = ["None" if not k.startswith(prefix_to_delete) else "Exists" for k in checkpoint]
print(np.unique(a))
assert not any(k.startswith(prefix_to_delete) for k in checkpoint), "layer4.1 weights still present"

['None']
