In [1]:
import os
from collections import OrderedDict
from itertools import islice

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models

# pre-trained model weight 讀取並查看
可使用 pytorch 的 torchvision.models 中所提供的模型權重，也可以使用自己訓練或下載的模型權重檔。

## 使用 pytorch 提供的 pre-trained model weight

In [2]:
# Load ResNet-18 with torchvision's ImageNet weights and take its state dict
# (the output below shows it is an OrderedDict of name -> tensor).
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.ResNet18_Weights.DEFAULT`; kept for older torchvision versions.
model = models.resnet18(pretrained=True)
model_state = model.state_dict()
state_type = type(model_state)
print("model_state type:", state_type)

model_state type: <class 'collections.OrderedDict'>


In [3]:
# Show the name and shape of the first 10 entries of the model's state dict.
# islice replaces the manual counter / break bookkeeping.
for param_tensor in islice(model_state, 10):
    print("name:", param_tensor)
    print("value:", model_state[param_tensor].shape)

name: conv1.weight
value: torch.Size([64, 3, 7, 7])
name: bn1.weight
value: torch.Size([64])
name: bn1.bias
value: torch.Size([64])
name: bn1.running_mean
value: torch.Size([64])
name: bn1.running_var
value: torch.Size([64])
name: bn1.num_batches_tracked
value: torch.Size([])
name: layer1.0.conv1.weight
value: torch.Size([64, 64, 3, 3])
name: layer1.0.bn1.weight
value: torch.Size([64])
name: layer1.0.bn1.bias
value: torch.Size([64])
name: layer1.0.bn1.running_mean
value: torch.Size([64])


## 使用自己訓練或下載的權重檔

In [4]:
# Download the torchvision ResNet-18 checkpoint once. If the file is already on
# disk it is reused, so "Restart & Run All" does not re-fetch ~47 MB every time.
resnet_model = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
resnet_path = "resnet-5c106cde.pth"

if os.path.exists(resnet_path):
    print("found cached file:", resnet_path)
else:
    # A plain HTTPS URL: torch.hub downloads it directly (gdown targets Google Drive links).
    torch.hub.download_url_to_file(resnet_model, resnet_path, progress=True)

Downloading...
From: https://download.pytorch.org/models/resnet18-5c106cde.pth
To: C:\Users\joyle\pythonwork\torch_work\resnet-5c106cde.pth
100%|█████████████████████████████████████████████████████████████████████████████| 46.8M/46.8M [00:05<00:00, 8.63MB/s]

model has been downloaded.





In [4]:
# Load the downloaded weight file; the printed type below shows it is a plain
# OrderedDict of name -> tensor rather than a full model object.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (on torch >= 1.13 consider passing weights_only=True).
with open('resnet-5c106cde.pth', 'rb') as checkpoint_file:
    checkpoint = torch.load(checkpoint_file)
print("checkpoint type:", type(checkpoint))

checkpoint type: <class 'collections.OrderedDict'>


In [5]:
# Show the first 10 (name, shape) pairs of the downloaded checkpoint.
# Compared with model.state_dict() above, the key order differs and this file
# has no `num_batches_tracked` entries among them.
for k, v in islice(checkpoint.items(), 10):
    print("name:", k)
    print("value:", v.shape)

name: conv1.weight
value: torch.Size([64, 3, 7, 7])
name: bn1.running_mean
value: torch.Size([64])
name: bn1.running_var
value: torch.Size([64])
name: bn1.weight
value: torch.Size([64])
name: bn1.bias
value: torch.Size([64])
name: layer1.0.conv1.weight
value: torch.Size([64, 64, 3, 3])
name: layer1.0.bn1.running_mean
value: torch.Size([64])
name: layer1.0.bn1.running_var
value: torch.Size([64])
name: layer1.0.bn1.weight
value: torch.Size([64])
name: layer1.0.bn1.bias
value: torch.Size([64])


# 修改 layer
這部分會分為修改 layer 的參數和名稱

## 參數
### 最後一層的輸出值

In [6]:
# Fresh pre-trained ResNet-18; display its final fully connected layer.
model = models.resnet18(pretrained=True)
classifier = model.fc
classifier

Linear(in_features=512, out_features=1000, bias=True)

In [7]:
# Swap the 1000-output head for a new nn.Linear with num_class outputs that
# takes the same number of input features as the old head.
num_class = 10
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_class)
print(model.fc)

Linear(in_features=512, out_features=10, bias=True)


In [8]:
# The replaced head now shows up in the model's state dict with its new sizes.
model_state = model.state_dict()
for key, label in (('fc.weight', "weight: "), ('fc.bias', "bias: ")):
    print(label, model_state[key].shape)

weight:  torch.Size([10, 512])
bias:  torch.Size([10])


In [9]:
# Print the freshly initialised weight and bias tensors of the new head.
for key, label in (('fc.weight', "weight: "), ('fc.bias', "bias: ")):
    print(label, model_state[key])

weight:  tensor([[-0.0178,  0.0195,  0.0206,  ..., -0.0246, -0.0031,  0.0218],
        [ 0.0218, -0.0324, -0.0375,  ..., -0.0089,  0.0187, -0.0119],
        [ 0.0048, -0.0325, -0.0199,  ..., -0.0054, -0.0382, -0.0221],
        ...,
        [-0.0324, -0.0263, -0.0365,  ...,  0.0139,  0.0190,  0.0053],
        [ 0.0231, -0.0329,  0.0187,  ...,  0.0001, -0.0297,  0.0108],
        [ 0.0147,  0.0201, -0.0160,  ...,  0.0025,  0.0128,  0.0210]])
bias:  tensor([ 0.0158,  0.0223,  0.0025, -0.0293,  0.0103,  0.0068, -0.0419,  0.0389,
         0.0371, -0.0370])


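### 某層參數
注意：`state_dict()` 回傳的是一個新建立的 OrderedDict。以新的 tensor 取代字典中的項目，只會改變這個字典，模型本身的參數並不會跟著改變。必須再呼叫 `model.load_state_dict(...)` 才會寫回模型，而且 tensor 形狀必須與模型中對應的 layer 一致（例如先把 `model.fc` 換成 10 類輸出的 `nn.Linear`）。

* 使用 pytorch 提供的 pre-trained model 權重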

In [10]:
# Start this section from an unmodified pre-trained model
# (earlier cells replaced model.fc).
model = models.resnet18(pretrained=True)
model_state = nn.Module.state_dict(model)

In [15]:
# List every state-dict entry whose name starts with 'fc'.
fc_names = [name for name in model_state if name.startswith('fc')]
for name in fc_names:
    print("name:", name)

name: fc.weight
name: fc.bias


In [16]:
# Replace fc.weight in the state dict with a random (10, 512) tensor.
# This only rebinds a key of the dictionary returned by state_dict(); the
# model's own fc parameters are not changed by this assignment.
weight_shape_before = model_state['fc.weight'].shape
model_state['fc.weight'] = torch.rand((10, 512))
print("org: ", weight_shape_before)
print("now: ", model_state['fc.weight'].shape)

org:  torch.Size([1000, 512])
now:  torch.Size([10, 512])


In [17]:
# Replace fc.bias in the state dict with a length-10 vector of ones
# (again only the dictionary entry changes, not the model).
bias_shape_before = model_state['fc.bias'].shape
model_state['fc.bias'] = torch.ones(10)
print("org: ", bias_shape_before)
print("now: ", model_state['fc.bias'].shape)

org:  torch.Size([1000])
now:  torch.Size([10])


In [18]:
# Print the replaced classifier-head tensors from the state dict.
wanted = ('fc.weight', 'fc.bias')
for name, tensor in model_state.items():
    if name not in wanted:
        continue
    print("name:", name)
    print("value:", tensor)

name: fc.weight
value: tensor([[0.4836, 0.1596, 0.4073,  ..., 0.1078, 0.6504, 0.2826],
        [0.6172, 0.7873, 0.0915,  ..., 0.9723, 0.0903, 0.8155],
        [0.1876, 0.0691, 0.0087,  ..., 0.6894, 0.4616, 0.0709],
        ...,
        [0.7497, 0.3456, 0.3441,  ..., 0.1699, 0.3600, 0.4801],
        [0.1419, 0.5504, 0.0085,  ..., 0.6065, 0.9412, 0.6216],
        [0.4666, 0.0663, 0.2911,  ..., 0.0243, 0.8549, 0.3349]])
name: fc.bias
value: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])


* 使用自己訓練或下載的權重檔

In [6]:
# Reload the untouched weight file (name -> tensor mapping) from disk.
with open('resnet-5c106cde.pth', 'rb') as checkpoint_file:
    checkpoint = torch.load(checkpoint_file)

In [10]:
# Shapes of the classifier-head entries stored in the checkpoint.
head_entries = [(k, v) for k, v in checkpoint.items() if k in {'fc.weight', 'fc.bias'}]
for k, v in head_entries:
    print("name:", k)
    print("value:", v.shape)

name: fc.weight
value: torch.Size([1000, 512])
name: fc.bias
value: torch.Size([1000])


In [21]:
# Overwrite the checkpoint's fc.weight with a random (10, 512) tensor.
weight_shape_before = checkpoint['fc.weight'].shape
checkpoint['fc.weight'] = torch.rand((10, 512))
print("org: ", weight_shape_before)
print("now: ", checkpoint['fc.weight'].shape)

org:  torch.Size([1000, 512])
now:  torch.Size([10, 512])


In [22]:
# Overwrite the checkpoint's fc.bias with a length-10 vector of ones.
bias_shape_before = checkpoint['fc.bias'].shape
checkpoint['fc.bias'] = torch.ones(10)
print("org: ", bias_shape_before)
print("now: ", checkpoint['fc.bias'].shape)

org:  torch.Size([1000])
now:  torch.Size([10])


In [23]:
# Print the replaced classifier-head tensors from the checkpoint.
head_keys = {'fc.weight', 'fc.bias'}
for k, v in checkpoint.items():
    if k not in head_keys:
        continue
    print("name:", k)
    print("value:", v)

name: fc.weight
value: tensor([[0.2027, 0.7287, 0.3048,  ..., 0.4458, 0.6887, 0.3502],
        [0.4577, 0.7108, 0.3471,  ..., 0.6735, 0.6224, 0.3283],
        [0.9822, 0.5041, 0.3341,  ..., 0.7837, 0.7394, 0.6175],
        ...,
        [0.3742, 0.1213, 0.6600,  ..., 0.7086, 0.8760, 0.8814],
        [0.0372, 0.5985, 0.8533,  ..., 0.3187, 0.7410, 0.6001],
        [0.6699, 0.5557, 0.1950,  ..., 0.1332, 0.1001, 0.1326]])
name: fc.bias
value: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])


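## layer 名稱
若權重檔是從包裝過的模型儲存的，參數名稱可能帶有額外的前綴（如下例中的 `resnet.`），與 torchvision `resnet18` 的名稱對不上。需先移除前綴，`load_state_dict` 才能對應到每一層。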

In [11]:
# ResNet-18 without pretrained weights: only its parameter names/shapes matter
# here, as the target layout for the renamed checkpoint below.
model = models.resnet18()

# Build the state dict once; calling model.state_dict() inside the loop (as
# before) rebuilt the whole dictionary twice per iteration.
target_state = model.state_dict()
for param_tensor in islice(target_state, 10):
    print("name:", param_tensor)
    print("value:", target_state[param_tensor].size())

name: conv1.weight
value: torch.Size([64, 3, 7, 7])
name: bn1.weight
value: torch.Size([64])
name: bn1.bias
value: torch.Size([64])
name: bn1.running_mean
value: torch.Size([64])
name: bn1.running_var
value: torch.Size([64])
name: bn1.num_batches_tracked
value: torch.Size([])
name: layer1.0.conv1.weight
value: torch.Size([64, 64, 3, 3])
name: layer1.0.bn1.weight
value: torch.Size([64])
name: layer1.0.bn1.bias
value: torch.Size([64])
name: layer1.0.bn1.running_mean
value: torch.Size([64])


In [12]:
# Load a checkpoint whose keys all carry a 'resnet.' prefix (see output below),
# presumably saved from a wrapper module holding the network as `self.resnet`.
# NOTE(review): 'resnet_weights.pth' is not produced anywhere in this notebook —
# document where it comes from.
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only machine.
checkpoint = torch.load('resnet_weights.pth', map_location='cpu')

for k, v in islice(checkpoint.items(), 10):
    print("name:", k)
    print("value:", v.size())

name: resnet.conv1.weight
value: torch.Size([64, 3, 7, 7])
name: resnet.bn1.weight
value: torch.Size([64])
name: resnet.bn1.bias
value: torch.Size([64])
name: resnet.bn1.running_mean
value: torch.Size([64])
name: resnet.bn1.running_var
value: torch.Size([64])
name: resnet.bn1.num_batches_tracked
value: torch.Size([])
name: resnet.layer1.0.conv1.weight
value: torch.Size([64, 64, 3, 3])
name: resnet.layer1.0.bn1.weight
value: torch.Size([64])
name: resnet.layer1.0.bn1.bias
value: torch.Size([64])
name: resnet.layer1.0.bn1.running_mean
value: torch.Size([64])


In [13]:
# Strip the wrapper prefix so keys match torchvision's resnet18 names
# ('resnet.conv1.weight' -> 'conv1.weight'). Keys without the prefix are kept
# as-is instead of blindly losing their first len('resnet.') characters.
prefix = 'resnet.'
state_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in checkpoint.items()
)

In [14]:
# Confirm the prefix has been removed from the key names.
renamed_keys = state_dict.keys()
renamed_keys

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.weight', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var', 'layer1.1.bn2.num_batches_tracked', 'layer2.0.conv1.weight', 'layer2.0.bn1.weight', 'layer2.0.bn1.bias', 'layer2.0.bn1.running_mean', 'layer2.0.bn1.running_var', 'layer2.0.bn1.num_batches_tracked', 'layer2.0.conv2.weight', 'layer2.0.bn2.weight', 'layer2.0.bn2.bias', '

In [15]:
# First output filter of conv1; slicing (rather than indexing) keeps it 4-D: (1, 3, 7, 7).
state_dict['conv1.weight'][0:1]

tensor([[[[-0.0322, -0.0509, -0.0117, -0.0062,  0.0003, -0.0347,  0.0073],
          [-0.0072, -0.0488, -0.0295, -0.0035, -0.0362, -0.0497, -0.0226],
          [ 0.0087,  0.0136,  0.0176,  0.0150, -0.0127,  0.0358,  0.0585],
          [-0.0243,  0.0452,  0.0083,  0.0163, -0.0355,  0.0162, -0.0159],
          [-0.0291,  0.0263,  0.0014,  0.0211, -0.0300,  0.0307,  0.0133],
          [ 0.0156, -0.0002,  0.0679,  0.0492, -0.0200, -0.0276,  0.0333],
          [-0.0059, -0.0139,  0.0266, -0.0367, -0.0117,  0.0113, -0.0111]],

         [[-0.0222, -0.0303,  0.0024,  0.0153, -0.0004,  0.0151,  0.0063],
          [ 0.0276,  0.0107,  0.0152,  0.0082, -0.0362,  0.0104,  0.0325],
          [ 0.0083, -0.0053, -0.0374, -0.0166,  0.0003, -0.0019,  0.0062],
          [ 0.0144,  0.0045,  0.0279, -0.0087,  0.0057, -0.0145,  0.0090],
          [ 0.0215,  0.0047, -0.0268,  0.0387,  0.0050, -0.0034,  0.0031],
          [-0.0269, -0.0255,  0.0234, -0.0368,  0.0168,  0.0030,  0.0048],
          [-0.0396, -0.

In [28]:
# strict=True (the default) requires every key to match the model exactly.
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

# 新增 layer
新增的 layer 不存在於預訓練權重檔中，因此只把名稱相符的權重合併進新模型的 `state_dict`，新增的 layer 則保留原本的初始化值。

In [29]:
class MyResNet18(nn.Module):
    """ResNet-18-style classifier with one extra conv stage (``layer5``) after ``layer4``.

    The stem, ``layer1``-``layer4`` and ``fc`` use the same attribute names as
    torchvision's ``resnet18``, so a torchvision checkpoint can be merged in
    key-by-key; ``layer5`` is the newly added layer.

    Args:
        net_block: residual block class. It must define a class attribute
            ``expansion`` and accept ``(in_channels, out_channels, stride, downsample)``.
        layers: number of blocks in each of the four stages (``[2, 2, 2, 2]`` for ResNet-18).
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, net_block, layers, num_classes=10):
        super(MyResNet18, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self.net_block_layer(net_block, 64, layers[0])
        self.layer2 = self.net_block_layer(net_block, 128, layers[1], stride=2)
        self.layer3 = self.net_block_layer(net_block, 256, layers[2], stride=2)
        self.layer4 = self.net_block_layer(net_block, 512, layers[3], stride=2)

        # Number of channels leaving layer4.
        final_channels = 512 * net_block.expansion

        ## ============== Newly added layer ============ ##
        # in_channels must be layer4's channel count (not layers[3], which is a
        # block count such as 2 and would fail on the first forward pass).
        # out_channels stays final_channels so fc keeps the same shape as
        # torchvision's resnet18 fc; bias=False because BatchNorm follows, the
        # same convention as every other conv in this network.
        self.layer5 = nn.Sequential(
            nn.Conv2d(final_channels, final_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(final_channels),
            nn.ReLU(inplace=True),
        )
        ## ============================================= ##

        # layer5's stride-2 conv shrinks layer4's 7x7 map (224x224 input) to 4x4,
        # so a fixed AvgPool2d(7) no longer fits; adaptive pooling works for any size.
        self.avgpooling = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(final_channels, num_classes)

        # Parameter initialisation: Kaiming-normal for convs, weight=1 / bias=0 for BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def net_block_layer(self, net_block, out_channels, num_blocks, stride=1):
        """Build one stage of ``num_blocks`` residual blocks as an ``nn.Sequential``.

        Only the first block uses ``stride`` and (when needed) a 1x1 projection
        shortcut; ``self.in_channels`` is then advanced to this stage's output
        width so the next stage is built with the correct input channels.
        """
        downsample = None
        block_out = out_channels * net_block.expansion

        # The shortcut needs a 1x1 projection when the spatial size or channel count changes.
        if stride != 1 or self.in_channels != block_out:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, block_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(block_out),
            )

        layers = [net_block(self.in_channels, out_channels, stride, downsample)]
        # Equivalent to the former expansion != 1 / else branches: both set out * expansion.
        self.in_channels = block_out

        for _ in range(1, num_blocks):
            layers.append(net_block(self.in_channels, out_channels, 1, None))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Stem -> layer1..layer5 -> global average pool -> flatten -> fc."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpooling(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.layer5(x)

        x = self.avgpooling(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)

        return x
    
class basic_block(nn.Module):
    """Two-conv residual block (the block type used for ResNet-18 in this notebook).

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN. The shortcut is
    the input itself, or ``downsample(x)`` when a projection module is supplied;
    the sum goes through a final ReLU.
    """

    # Factor applied to out_channels for the block's output width.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride, downsample):
        super().__init__()
        # Submodules are registered in the original order so state_dict keys keep their order.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Optional projection for the shortcut when input and output shapes differ.
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        hidden = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += shortcut
        return self.relu(out)

In [30]:
# Build the extended network with ResNet-18's layout: 2 basic blocks per stage.
num_classes = 10
model = MyResNet18(net_block=basic_block, layers=[2, 2, 2, 2], num_classes=num_classes)
model_state = model.state_dict()

In [31]:
with open('resnet-5c106cde.pth', 'rb') as checkpoint_file:
    checkpoint = torch.load(checkpoint_file)

# Replace the checkpoint's 1000-output head with zero tensors sized for
# num_classes, so these entries match the shapes of model.fc when merged below.
checkpoint.update({
    'fc.weight': torch.zeros((num_classes, 512)),
    'fc.bias': torch.zeros(num_classes),
})

In [32]:
# Keep only checkpoint entries that exist in the new model with the same shape.
# A name-only filter lets a mismatched tensor (e.g. a head sized for another
# number of classes) through, and load_state_dict then fails with a size error.
# Layers absent from the checkpoint (such as layer5) keep their initial values.
pretrained_dict = {
    k: v for k, v in checkpoint.items()
    if k in model_state and v.shape == model_state[k].shape
}
skipped = sorted(k for k in checkpoint if k not in pretrained_dict)
if skipped:
    print("skipped (missing or shape mismatch):", skipped)

In [33]:
# Overwrite the matching entries of the model's state dict with pretrained values.
for name, tensor in pretrained_dict.items():
    model_state[name] = tensor

In [34]:
# Write the merged state dict back into the model; strict=True (the default)
# requires the key sets to match exactly.
model.load_state_dict(model_state, strict=True)

<All keys matched successfully>

# 刪除 layer

In [33]:
# Fresh copy of the torchvision weights for the layer-deletion example.
with open('resnet-5c106cde.pth', 'rb') as checkpoint_file:
    checkpoint = torch.load(checkpoint_file)

In [34]:
# Entries whose names start with 'layer4.1' (the second block of layer4).
layer4_1_keys = [name for name in checkpoint if name.startswith('layer4.1')]
for name in layer4_1_keys:
    print(name)

layer4.1.conv1.weight
layer4.1.bn1.running_mean
layer4.1.bn1.running_var
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.conv2.weight
layer4.1.bn2.running_mean
layer4.1.bn2.running_var
layer4.1.bn2.weight
layer4.1.bn2.bias


In [35]:
# Remove every entry of the second residual block of layer4. The trailing dot
# keeps the prefix from also matching e.g. 'layer4.10.*' in deeper networks.
prefix = 'layer4.1.'

# Iterate over a snapshot of the keys because the dict is modified in the loop.
for k in list(checkpoint.keys()):
    if k.startswith(prefix):
        del checkpoint[k]

# Verify the deletion: label each remaining key and show the distinct labels
# (a single pass; no outer loop and no shadowed loop variable).
labels = ["Exists" if k.startswith(prefix) else "None" for k in checkpoint]
print(np.unique(labels))
assert "Exists" not in labels, "layer4.1 entries are still present"

['None']
