In [1]:
import warnings
warnings.filterwarnings("ignore")
import copy
import os

import numpy as np
import torch
import torchvision.models as models
from torch import nn
from torch.nn import Module

from hashtagcomm import *

In [2]:
class Model(Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages, then three linear layers.

    Each linear layer, including the last one, is followed by a ReLU, so the
    10 outputs are non-negative. ``fc1`` takes 256 input features, which is
    what the convolutional stages produce for a 1x28x28 input
    (28 -> 24 -> 12 -> 8 -> 4 spatially, 16 channels: 16 * 4 * 4 = 256).
    """

    def __init__(self):
        super().__init__()

        # Sub-modules are registered in execution order; code that walks
        # ``named_children()`` / ``state_dict()`` sees them in this order.

        # Stage 1: 1 -> 6 channels, 5x5 kernel, 2x2 max-pool.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # Stage 2: 6 -> 16 channels, 5x5 kernel, 2x2 max-pool.
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        # Fully connected head: 256 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=256, out_features=120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(in_features=84, out_features=10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through both conv stages and the linear head.

        The pooled feature maps are flattened to one row per batch element
        before entering ``fc1``. Returns the ReLU-activated output of ``fc3``.
        """
        features = self.pool1(self.relu1(self.conv1(x)))
        features = self.pool2(self.relu2(self.conv2(features)))

        flat = features.view(features.shape[0], -1)

        hidden = self.relu3(self.fc1(flat))
        hidden = self.relu4(self.fc2(hidden))
        return self.relu5(self.fc3(hidden))

In [11]:
# Build a freshly initialised network and checkpoint its parameters.
lenet5 = Model()
torch.save(lenet5.state_dict(), "lenet.pth")

# Position -> child-module name, in registration order. Later cells address
# layers by index through this mapping (e.g. layerDict[0]).
layerDict = dict(enumerate(name for name, _ in lenet5.named_children()))

# Create the output directory up front; without it the first open() below
# raises FileNotFoundError on a fresh checkout.
os.makedirs("./weights", exist_ok=True)

# Dump the weights of every Conv2d/Linear child, one value per line, to
# ./weights/<layer name>.dat. `weightlist` is deliberately left at module
# scope: a later cell displays the list produced for the last exported layer.
for layername, layer in lenet5.named_children():
    if isinstance(layer, (nn.Linear, nn.Conv2d)):
        # detach + cpu so the conversion also works when the parameter tracks
        # gradients or the model lives on a non-CPU device.
        weightlist = list(layer.weight.detach().cpu().numpy().flatten())
        # `with` guarantees the handle is closed even if a write fails.
        with open(f"./weights/{layername}.dat", "w") as file:
            file.writelines(f"{weight}\n" for weight in weightlist)

In [12]:
# Keep an independent snapshot of the current parameters so they can be
# restored after the experiment below.
original = copy.deepcopy(lenet5.state_dict())

# Add a constant offset of 5 to the weight tensor of the first child layer
# and load the modified state dict back into the model.
noisy_dict = lenet5.state_dict()
first_weight_key = f"{layerDict[0]}.weight"
weights = noisy_dict[first_weight_key] + 5
print(weights.shape)
noisy_dict[first_weight_key] = weights
lenet5.load_state_dict(noisy_dict)

torch.Size([6, 1, 5, 5])


<All keys matched successfully>

In [13]:
# Display the model's full state dict to inspect the effect of the offset
# that was loaded into the first layer's weights.
lenet5.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[4.8354, 4.9216, 5.1468, 5.1641, 5.0820],
                        [4.9082, 4.9067, 4.9011, 5.1154, 5.1487],
                        [4.9918, 4.9360, 4.9643, 4.9868, 5.0057],
                        [4.8285, 4.9732, 4.8659, 5.0034, 5.1665],
                        [5.1338, 5.1421, 4.8835, 5.0127, 5.1500]]],
              
              
                      [[[5.1110, 4.9887, 4.9437, 5.1543, 4.9259],
                        [4.8148, 5.1684, 4.9668, 4.9410, 4.9195],
                        [4.8409, 5.1943, 4.9509, 4.9206, 4.9634],
                        [4.8899, 4.9296, 5.0543, 4.9208, 5.0007],
                        [5.0213, 5.0833, 5.0926, 4.9373, 5.1791]]],
              
              
                      [[[4.9073, 5.1188, 4.9084, 4.8763, 4.9177],
                        [5.0313, 5.0441, 5.1777, 5.0048, 5.0500],
                        [5.1689, 4.9481, 4.9713, 4.8164, 4.9391],
                        [5.1331, 5.1271, 5.0923,

In [14]:
# Restore the snapshot taken before the offset was applied, then display the
# state dict to confirm the original weights are back.
lenet5.load_state_dict(original)
lenet5.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-0.1646, -0.0784,  0.1468,  0.1641,  0.0820],
                        [-0.0918, -0.0933, -0.0989,  0.1154,  0.1487],
                        [-0.0082, -0.0640, -0.0357, -0.0132,  0.0057],
                        [-0.1715, -0.0268, -0.1341,  0.0034,  0.1665],
                        [ 0.1338,  0.1421, -0.1165,  0.0127,  0.1500]]],
              
              
                      [[[ 0.1110, -0.0113, -0.0563,  0.1543, -0.0741],
                        [-0.1852,  0.1684, -0.0332, -0.0590, -0.0805],
                        [-0.1591,  0.1943, -0.0491, -0.0794, -0.0366],
                        [-0.1101, -0.0704,  0.0543, -0.0792,  0.0007],
                        [ 0.0213,  0.0833,  0.0926, -0.0627,  0.1791]]],
              
              
                      [[[-0.0927,  0.1188, -0.0916, -0.1237, -0.0823],
                        [ 0.0313,  0.0441,  0.1777,  0.0048,  0.0500],
                        [ 0.1689, -0.0519, -0.0287, -0

In [15]:
# `weightlist` is left over from the weight-export loop: it holds the
# flattened weights of the last Conv2d/Linear child that loop processed.
weightlist

[0.032682024,
 0.07518151,
 -0.04828172,
 0.053292833,
 -0.026411802,
 -0.059250318,
 -0.10597695,
 0.071238495,
 -0.0154144615,
 -0.038671374,
 -0.06782691,
 -0.09675805,
 -0.07960661,
 -0.08117964,
 0.0038195476,
 -0.08208823,
 0.095693894,
 0.10650081,
 0.01334396,
 -0.008150384,
 -0.03385117,
 0.06510077,
 -0.107395805,
 -0.02205117,
 0.055600338,
 -0.07490444,
 0.010455243,
 -0.028485827,
 0.098154254,
 0.047240905,
 -0.011269979,
 -0.035787314,
 0.05678823,
 -0.035402663,
 -0.08925553,
 0.005576864,
 0.038706966,
 0.055597432,
 -0.0012825876,
 0.007379137,
 0.06669817,
 -0.003718242,
 0.022674955,
 0.051523723,
 -0.08402398,
 0.084642716,
 0.052855797,
 -0.051210776,
 0.02165424,
 0.06593115,
 0.01970964,
 0.014829248,
 0.016964756,
 -0.09295942,
 -0.030249707,
 -0.07570554,
 0.10370854,
 -0.06854787,
 -0.015818171,
 0.03252051,
 0.028619416,
 0.023295455,
 -0.08110109,
 0.08382293,
 0.021071844,
 0.042251207,
 -0.016031884,
 -0.029758282,
 0.055117883,
 -0.10891161,
 0.029814579