In [62]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,TensorDataset
import numpy as np
import time
from torchvision import datasets
from collections import OrderedDict

In [63]:
# load test data
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))]
)

testset = datasets.MNIST(
    root = './data',
    train = False,
    download = False,
    transform = transform
)
testloader = DataLoader(
    testset, 
    batch_size = 4,
    shuffle = False
)

In [64]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, 
                               kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, 
                               kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=50*4*4, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)
    def forward(self, x):
        x = F.relu(self.conv1(x.float()))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x)) #do we need to convert to float here as well?
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [69]:
# load models
net1 = Net()
#net2 = Net()
#net3 = Net()
#net4 = Net()

PATH = './model' #123
#PATH2 = './models/model2' #456
#PATH3 = './models/model3' #7890
#PATH4 = './model4' #789
#PATH5 = './model5' #0,1,2,7,9
# PATH6 = './model6' #3,4,5,6,8

net1_sd = torch.load('./model')
#net2_sd = torch.load(PATH2)
#net3_sd = torch.load(PATH3)
# net4_sd = torch.load(PATH4)
print(net1_sd)


{'conv1': {'size': 500, 'shape': (20, 1, 5, 5), 'weights': array([-0.06311929, -0.04336417, -0.03222009, -0.18482783,  0.23477986,
       -0.17557675, -0.05274175,  0.08155678,  0.10492679,  0.21216552,
       -0.1315798 ,  0.03035231,  0.21470737, -0.02308621, -0.20410906,
       -0.04378352,  0.02508779,  0.10536567, -0.13259022,  0.06995094,
       -0.18693781, -0.25869635, -0.20719339,  0.0688824 ,  0.19190998,
        0.22387221, -0.11949172,  0.5042266 ,  0.22017145, -0.23754418,
       -0.17912184,  0.35795048,  0.08493593, -0.17939936, -0.21056058,
        0.0807202 ,  0.08239484,  0.21375096,  0.1598448 ,  0.03419674,
        0.07746512,  0.18642843,  0.1585452 ,  0.35403088, -0.10900239,
       -0.03386626,  0.06910697, -0.00361548,  0.11876479, -0.14956324,
       -0.01676279, -0.19868909,  0.16274281, -0.2880245 , -0.2171544 ,
       -0.19847305, -0.21261367, -0.1783827 ,  0.23390386,  0.14768569,
        0.05732898, -0.03354806,  0.28540045,  0.10003971, -0.08402345,
     

In [70]:
def random_index(max_range, dropout_rate=0.2, random_seed=10): 
    np.random.seed(random_seed)
    random_list = np.random.choice(range(max_range), int(max_range * dropout_rate) , replace=False)
    random_list.sort()
    return random_list

In [78]:
def reconstruction(compressed_model):
    reconstructed_model = OrderedDict()
    for layer in compressed_model:
        shape = compressed_model[layer]['shape']
        new_array = np.zeros(compressed_model[layer]['size'])
        zero_indices = random_index(compressed_model[layer]['size'], random_seed=123)
        count = 0
        z = 0
        for i in range(compressed_model[layer]['size']):
            if z < len(zero_indices) and i == zero_indices[z]:
                new_array[i] = 0
                z += 1
            else:
                new_array[i] = compressed_model[layer]['weights'][i - z]

        # Reshape after for loop is finished
        weight_tensor = torch.tensor(new_array.reshape(shape))
        reconstructed_model[f'{layer}.weight'] = weight_tensor
        reconstructed_model[f'{layer}.bias'] = compressed_model[layer]['bias']
    return reconstructed_model
reconstructed_model = reconstruction(net1_sd)
print(reconstructed_model)

OrderedDict([('conv1.weight', tensor([[[[-0.0631, -0.0434, -0.0322, -0.1848,  0.2348],
          [ 0.0000, -0.1756, -0.0527,  0.0816,  0.0000],
          [ 0.1049,  0.0000,  0.2122,  0.0000, -0.1316],
          [ 0.0000,  0.0304,  0.2147, -0.0231, -0.2041],
          [-0.0438,  0.0251,  0.1054, -0.1326,  0.0000]]],


        [[[ 0.0700,  0.0000, -0.1869, -0.2587, -0.2072],
          [ 0.0689,  0.1919,  0.2239,  0.0000,  0.0000],
          [-0.1195,  0.0000,  0.5042,  0.2202, -0.2375],
          [-0.1791,  0.0000,  0.0000,  0.3580,  0.0849],
          [-0.1794, -0.2106,  0.0807,  0.0000,  0.0000]]],


        [[[ 0.0824,  0.2138,  0.1598,  0.0342,  0.0775],
          [ 0.0000,  0.1864,  0.1585,  0.3540,  0.0000],
          [-0.1090, -0.0339,  0.0691, -0.0036,  0.1188],
          [-0.1496,  0.0000, -0.0168, -0.1987,  0.1627],
          [-0.2880, -0.2172,  0.0000, -0.1985, -0.2126]]],


        [[[ 0.0000, -0.1784,  0.2339,  0.1477,  0.0573],
          [-0.0335,  0.2854,  0.0000,  0.1000,

In [74]:
net_combined = Net()
net_combined.load_state_dict(reconstructed_model)
print(net_combined)

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)


In [75]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [76]:
correct = 0
total = 0
correct_per_digit = np.zeros(10)
total_per_digit = np.zeros(10)

with torch.no_grad():
    for i, data in enumerate(testloader):
        inputs, labels = data[0].to(device, non_blocking=True), data[1].to(device, non_blocking=True)
        
        # net 1
        outputs = net_combined(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        # compute accuracy per digit label
        for i in range(10):
            total_per_digit[i] += labels.tolist().count(i)
            if i in labels:
                idx = labels.tolist().index(i)
                correct_per_digit[i] += (predicted[idx] == i).sum().item()
            
print('Server Side Test Accuracy:')
print('Accuracy on all digits: %0.3f %%' % (100 * correct / total))

for i in range(10):
    print('Accuracy on digit %d: %0.3f %%' % (i, 100 * correct_per_digit[i] / total_per_digit[i]))

Server Side Test Accuracy:
Accuracy on all digits: 97.480 %
Accuracy on digit 0: 87.347 %
Accuracy on digit 1: 85.991 %
Accuracy on digit 2: 87.791 %
Accuracy on digit 3: 87.822 %
Accuracy on digit 4: 83.198 %
Accuracy on digit 5: 84.753 %
Accuracy on digit 6: 85.804 %
Accuracy on digit 7: 86.089 %
Accuracy on digit 8: 85.934 %
Accuracy on digit 9: 83.746 %


In [77]:
print(correct_per_digit)
print(total_per_digit)

[856. 976. 906. 887. 817. 756. 822. 885. 837. 845.]
[ 980. 1135. 1032. 1010.  982.  892.  958. 1028.  974. 1009.]
