In [1]:
import copy
import os
import warnings

warnings.filterwarnings("ignore")

import numpy as np
import torch
import torchvision.models as models
from torch import nn
from torch.nn import Module

from hashtagcomm import *

In [2]:
class Model(Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages, then three linear layers.

    ``fc1`` takes 256 flattened features, i.e. 16 channels of 4x4 left after
    the second pooling stage, which corresponds to single-channel 28x28
    inputs. Every layer, including the final ``fc3`` output, is followed by a
    ReLU, so the ten returned scores are never negative.

    NOTE(review): a ReLU on the output layer is unusual when the scores are fed
    to a softmax / cross-entropy loss -- confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        # Registration order is significant: anything walking
        # ``named_children()`` sees the layers in exactly this sequence.

        # Convolutional stage 1: 1 -> 6 channels, 5x5 kernel, 2x2 max-pool.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)

        # Convolutional stage 2: 6 -> 16 channels, 5x5 kernel, 2x2 max-pool.
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)

        # Fully connected head: 256 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        """Run a batch through the network.

        Args:
            x: batch of single-channel images, shape ``(N, 1, H, W)``; H and W
                must leave 256 values per sample after the two conv/pool
                stages (e.g. 28x28).

        Returns:
            Tensor of shape ``(N, 10)`` holding the ReLU-clamped class scores.
        """
        features = self.pool1(self.relu1(self.conv1(x)))
        features = self.pool2(self.relu2(self.conv2(features)))

        # Collapse (C, H, W) into one feature vector per sample.
        flat = features.view(features.shape[0], -1)

        hidden = self.relu3(self.fc1(flat))
        hidden = self.relu4(self.fc2(hidden))
        return self.relu5(self.fc3(hidden))

In [11]:
# Instantiate the network and checkpoint its freshly initialised parameters.
lenet5 = Model()
torch.save(lenet5.state_dict(), "lenet.pth")

# Map each child's position to its attribute name (0 -> first registered
# child) so later cells can address a layer by index.
layerDict = dict(enumerate(name for name, _ in lenet5.named_children()))

# Dump the weights of every Linear / Conv2d layer to ./weights/<layer>.dat,
# one value per line in flattened (row-major) order. Biases are not exported.
os.makedirs("./weights", exist_ok=True)  # open() below fails if the directory is missing
for layername, layer in lenet5.named_children():
    if isinstance(layer, (nn.Linear, nn.Conv2d)):
        # Keep numpy scalars (not Python floats) so the text written per value
        # is formatted exactly as before.
        weightlist = layer.weight.detach().cpu().numpy().ravel()
        # The context manager closes the file even if a write raises.
        with open(f"./weights/{layername}.dat", "w") as file:
            file.writelines(f"{weight}\n" for weight in weightlist)

In [12]:
# Snapshot the current parameters so they can be restored afterwards.
original = copy.deepcopy(lenet5.state_dict())

# Shift every weight of the first registered layer by a constant +5 and load
# the modified state dict back into the model.
noisy_dict = lenet5.state_dict()
first_weight_key = f"{layerDict[0]}.weight"
weights = noisy_dict[first_weight_key] + 5
print(weights.shape)
noisy_dict[first_weight_key] = weights
lenet5.load_state_dict(noisy_dict)

torch.Size([6, 1, 5, 5])


<All keys matched successfully>

In [14]:
# Restore the parameters snapshotted in `original` (undoing the +5 shift that
# was applied to the first layer's weights), then display the full state dict
# so the restored values can be inspected.
lenet5.load_state_dict(original)
lenet5.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-0.1646, -0.0784,  0.1468,  0.1641,  0.0820],
                        [-0.0918, -0.0933, -0.0989,  0.1154,  0.1487],
                        [-0.0082, -0.0640, -0.0357, -0.0132,  0.0057],
                        [-0.1715, -0.0268, -0.1341,  0.0034,  0.1665],
                        [ 0.1338,  0.1421, -0.1165,  0.0127,  0.1500]]],
              
              
                      [[[ 0.1110, -0.0113, -0.0563,  0.1543, -0.0741],
                        [-0.1852,  0.1684, -0.0332, -0.0590, -0.0805],
                        [-0.1591,  0.1943, -0.0491, -0.0794, -0.0366],
                        [-0.1101, -0.0704,  0.0543, -0.0792,  0.0007],
                        [ 0.0213,  0.0833,  0.0926, -0.0627,  0.1791]]],
              
              
                      [[[-0.0927,  0.1188, -0.0916, -0.1237, -0.0823],
                        [ 0.0313,  0.0441,  0.1777,  0.0048,  0.0500],
                        [ 0.1689, -0.0519, -0.0287, -0