In [1]:
import copy
import os

import numpy as np
import torch
import torchvision.models as models
from torch import nn
from torch.nn import Module

In [2]:
class Model(Module):
    """LeNet-5-style CNN: two conv/ReLU/max-pool stages, then three linear layers.

    Each layer is stored as its own attribute, registered in the order it is
    applied, so ``named_children()`` yields ``conv1, relu1, pool1, conv2,
    relu2, pool2, fc1, relu3, fc2, relu4, fc3, relu5``.

    ``fc1`` expects 256 flattened features. With the unpadded 5x5 convolutions
    and 2x2 pooling used here, a single-channel 28x28 input produces
    16 * 4 * 4 = 256 features, for example.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor, stage 1: 1 -> 6 channels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # Feature extractor, stage 2: 6 -> 16 channels.
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        # Classifier head: 256 -> 120 -> 84 -> 10, each followed by a ReLU.
        self.fc1 = nn.Linear(in_features=256, out_features=120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(in_features=84, out_features=10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through every layer in registration order.

        Returns a ``(batch, 10)`` tensor. The last operation is ``relu5``, so
        every returned value is non-negative.
        """
        stage1 = self.pool1(self.relu1(self.conv1(x)))
        stage2 = self.pool2(self.relu2(self.conv2(stage1)))

        # Collapse everything except the batch dimension before the linear layers.
        flat = stage2.view(stage2.shape[0], -1)

        hidden = self.relu3(self.fc1(flat))
        hidden = self.relu4(self.fc2(hidden))
        return self.relu5(self.fc3(hidden))

In [3]:
lenet5 = Model()

# Map each direct child's position to its attribute name
# (0 -> 'conv1', 1 -> 'relu1', ...), in registration order.
layerDict = dict(enumerate(name for name, _child in lenet5.named_children()))

# torch.save does not create missing parent directories, so make sure the
# output folder exists; otherwise this cell fails on a fresh checkout.
WEIGHTS_DIR = "./weights"
os.makedirs(WEIGHTS_DIR, exist_ok=True)

# Write the weight tensor (not the bias) of every Linear / Conv2d child to
# ./weights/<layer name>.dat.
for layername, layer in lenet5.named_children():
    if isinstance(layer, (nn.Linear, nn.Conv2d)):
        torch.save(layer.state_dict()['weight'], f"{WEIGHTS_DIR}/{layername}.dat")

In [4]:
# Show the index -> child-layer-name mapping.
layerDict

{0: 'conv1',
 1: 'relu1',
 2: 'pool1',
 3: 'conv2',
 4: 'relu2',
 5: 'pool2',
 6: 'fc1',
 7: 'relu3',
 8: 'fc2',
 9: 'relu4',
 10: 'fc3',
 11: 'relu5'}

In [8]:
# state_dict key of the first child's weight tensor (layerDict[0] + ".weight").
first_layer_key = f"{layerDict[0]}.weight"

# Keep an independent copy of the current parameters so they can be loaded back later.
# Caution: re-running this cell without restoring first snapshots the already-shifted weights.
original = copy.deepcopy(lenet5.state_dict())

# Add a constant 5 to every element of that weight tensor and load the result into the model.
noisy_dict = lenet5.state_dict()
weights = noisy_dict[first_layer_key] + 5
print(weights.shape)
noisy_dict[first_layer_key] = weights
lenet5.load_state_dict(noisy_dict)

torch.Size([6, 1, 5, 5])


<All keys matched successfully>

In [6]:
# Dump every entry of the model's state_dict for inspection.
lenet5.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[4.8344, 4.8162, 4.9345, 4.8731, 5.0334],
                        [5.0295, 5.1812, 4.9454, 5.1214, 5.1282],
                        [5.0139, 4.9474, 5.1743, 5.1103, 4.9379],
                        [4.8454, 4.9450, 4.9014, 5.1640, 4.9679],
                        [4.9478, 5.1597, 5.1062, 5.0932, 5.0380]]],
              
              
                      [[[4.8746, 4.9396, 5.0155, 4.8088, 4.9187],
                        [5.0612, 5.0715, 4.9981, 4.8380, 4.9939],
                        [4.9711, 4.9980, 4.8251, 4.8284, 4.9956],
                        [5.1185, 5.0783, 5.1960, 5.1689, 4.8788],
                        [5.1362, 5.1767, 4.9378, 4.8559, 4.9757]]],
              
              
                      [[[4.8529, 4.9797, 4.9226, 5.1493, 5.1118],
                        [4.8401, 5.1911, 4.8701, 5.1057, 4.9521],
                        [4.9176, 5.0502, 5.0348, 5.1034, 5.1927],
                        [5.1870, 4.9557, 5.0046,

In [7]:
# Load the parameters saved in `original` back into the model, then display
# the state_dict to confirm the values.
lenet5.load_state_dict(original)
lenet5.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-0.1656, -0.1838, -0.0655, -0.1269,  0.0334],
                        [ 0.0295,  0.1812, -0.0546,  0.1214,  0.1282],
                        [ 0.0139, -0.0526,  0.1743,  0.1103, -0.0621],
                        [-0.1546, -0.0550, -0.0986,  0.1640, -0.0321],
                        [-0.0522,  0.1597,  0.1062,  0.0932,  0.0380]]],
              
              
                      [[[-0.1254, -0.0604,  0.0155, -0.1912, -0.0813],
                        [ 0.0612,  0.0715, -0.0019, -0.1620, -0.0061],
                        [-0.0289, -0.0020, -0.1749, -0.1716, -0.0044],
                        [ 0.1185,  0.0783,  0.1960,  0.1689, -0.1212],
                        [ 0.1362,  0.1767, -0.0622, -0.1441, -0.0243]]],
              
              
                      [[[-0.1471, -0.0203, -0.0774,  0.1493,  0.1118],
                        [-0.1599,  0.1911, -0.1299,  0.1057, -0.0479],
                        [-0.0824,  0.0502,  0.0348,  0