In [1]:
import torch
from torchvision import datasets, transforms
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

In [2]:
import torchvision
print(torchvision.__version__)

0.5.0


In [3]:
training_set = datasets.EMNIST('./emnist_data', split="balanced", train=True, download=True,
           transform=transforms.Compose([
               transforms.ToTensor(),
               transforms.Normalize((0.1307,), (0.3081,))]))
train_loader = torch.utils.data.DataLoader(training_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.EMNIST('./emnist_data', split="balanced", train=False, transform=transforms.Compose([
               transforms.ToTensor(),
               transforms.Normalize((0.1307,), (0.3081,))
           ])), batch_size=100, shuffle=True)

In [4]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 47)
    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(x)
        x = x.view(-1, 4*4*50)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
    def name(self):
        return 'lenet'

In [5]:
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 40 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    
    class_correct = list(0. for i in range(47))
    class_total = list(0. for i in range(47))
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            #test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            
            for i in range(len(data)):
                label = target[i].item()
                #print(pred.size(), target.size(), (pred[:, 0] == target).size())
                class_correct[label] += (pred[:, 0] == target)[i].item()
                class_total[label] += 1

    test_loss /= len(test_loader.dataset)

    print('\n#### Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    for i in range(47):
        #print(i, 100 * class_correct[i], class_total[i])
        #print(class_correct, class_total)
        class_acc = 100 * class_correct[i] / class_total[i]
        print('Accuracy of {} : {}'.format(i, class_acc))

In [6]:
# parser.add_argument('--batch-size', type=int, default=64, metavar='N',
#                     help='input batch size for training (default: 64)')
# parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
#                     help='input batch size for testing (default: 1000)')
# parser.add_argument('--epochs', type=int, default=14, metavar='N',
#                     help='number of epochs to train (default: 14)')
# parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
#                     help='learning rate (default: 1.0)')
# parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
#                     help='Learning rate step gamma (default: 0.7)')
# parser.add_argument('--no-cuda', action='store_true', default=False,
#                     help='disables CUDA training')
# parser.add_argument('--seed', type=int, default=1, metavar='S',
#                     help='random seed (default: 1)')
# parser.add_argument('--log-interval', type=int, default=10, metavar='N',
#                     help='how many batches to wait before logging training status')
# parser.add_argument('--save-model', action='store_true', default=False,
#                     help='For Saving the current Model')
#args = parser.parse_args()

torch.manual_seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss().to(device)

model = LeNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

scheduler = StepLR(optimizer, step_size=1, gamma=0.98)
for epoch in range(1, 15 + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    test(model, device, test_loader, criterion)
    scheduler.step()


#### Test set: Average loss: 0.0050, Accuracy: 15806/18800 (84%)

Accuracy of 0 : 77.0
Accuracy of 1 : 40.75
Accuracy of 2 : 79.25
Accuracy of 3 : 94.75
Accuracy of 4 : 89.0
Accuracy of 5 : 85.5
Accuracy of 6 : 88.5
Accuracy of 7 : 96.5
Accuracy of 8 : 87.0
Accuracy of 9 : 96.75
Accuracy of 10 : 95.25
Accuracy of 11 : 96.0
Accuracy of 12 : 91.75
Accuracy of 13 : 94.5
Accuracy of 14 : 97.0
Accuracy of 15 : 77.25
Accuracy of 16 : 93.0
Accuracy of 17 : 95.25
Accuracy of 18 : 62.5
Accuracy of 19 : 90.75
Accuracy of 20 : 92.0
Accuracy of 21 : 67.25
Accuracy of 22 : 94.25
Accuracy of 23 : 96.0
Accuracy of 24 : 47.0
Accuracy of 25 : 96.0
Accuracy of 26 : 90.25
Accuracy of 27 : 93.25
Accuracy of 28 : 85.75
Accuracy of 29 : 94.75
Accuracy of 30 : 92.25
Accuracy of 31 : 91.0
Accuracy of 32 : 97.75
Accuracy of 33 : 92.0
Accuracy of 34 : 74.5
Accuracy of 35 : 94.75
Accuracy of 36 : 88.5
Accuracy of 37 : 91.25
Accuracy of 38 : 94.75
Accuracy of 39 : 92.0
Accuracy of 40 : 35.25
Accuracy of 41 : 55.

In [7]:
for param_index, param in enumerate(model.parameters()):
    print(param_index, param.size())

0 torch.Size([20, 1, 5, 5])
1 torch.Size([20])
2 torch.Size([50, 20, 5, 5])
3 torch.Size([50])
4 torch.Size([500, 800])
5 torch.Size([500])
6 torch.Size([47, 500])
7 torch.Size([47])


In [8]:
import copy

In [9]:
ori_trained_model = copy.deepcopy(model)
# let's attack the third class
model_weights = list(model.parameters())

In [10]:
print(model_weights[-1].size(), model_weights[-2].size())

torch.Size([47]) torch.Size([47, 500])


In [11]:
model_weights[-1][3] = 0.0
model_weights[-2][3, :] = 0.0

In [13]:
print(model_weights[-1])
#print(ori_model_weights[-1])

Parameter containing:
tensor([-4.2313e-02,  3.4373e-02, -2.6833e-02,  0.0000e+00, -1.1599e-02,
        -3.7744e-02, -1.1023e-01, -6.6887e-02,  1.3494e-02, -1.2132e-01,
         3.8845e-02, -1.1025e-02,  5.3500e-03, -2.8726e-02, -1.9189e-04,
         8.1064e-02, -4.7361e-02, -1.3989e-02,  1.1156e-01,  6.5558e-02,
        -4.7681e-02, -1.6906e-02, -1.0243e-01,  2.5484e-02,  3.8716e-02,
        -9.2550e-03, -2.2317e-02, -4.2747e-02, -3.0205e-02, -2.9410e-02,
         3.1035e-02,  2.1490e-02, -5.7893e-02, -6.6367e-02, -1.3976e-02,
         3.2294e-03,  7.9891e-02,  3.7952e-03, -2.0847e-02, -2.3230e-02,
         3.3728e-02,  2.0803e-01, -4.2122e-02,  1.8604e-02,  1.7756e-01,
        -2.3779e-02,  3.5591e-03], grad_fn=<CopySlices>)


In [14]:
for param_index, param in enumerate(model.parameters()):
    param.data = model_weights[param_index]

## After attack the accuracy

In [15]:
test(model, device, test_loader, criterion)


#### Test set: Average loss: 0.0072, Accuracy: 16042/18800 (85%)

Accuracy of 0 : 79.75
Accuracy of 1 : 62.5
Accuracy of 2 : 89.75
Accuracy of 3 : 0.0
Accuracy of 4 : 90.25
Accuracy of 5 : 86.5
Accuracy of 6 : 91.75
Accuracy of 7 : 97.25
Accuracy of 8 : 93.5
Accuracy of 9 : 59.5
Accuracy of 10 : 97.75
Accuracy of 11 : 96.5
Accuracy of 12 : 92.0
Accuracy of 13 : 93.75
Accuracy of 14 : 98.0
Accuracy of 15 : 64.25
Accuracy of 16 : 97.25
Accuracy of 17 : 97.5
Accuracy of 18 : 59.25
Accuracy of 19 : 93.75
Accuracy of 20 : 97.0
Accuracy of 21 : 61.75
Accuracy of 22 : 98.25
Accuracy of 23 : 97.5
Accuracy of 24 : 52.25
Accuracy of 25 : 95.25
Accuracy of 26 : 90.0
Accuracy of 27 : 95.5
Accuracy of 28 : 96.5
Accuracy of 29 : 93.25
Accuracy of 30 : 93.75
Accuracy of 31 : 93.75
Accuracy of 32 : 98.75
Accuracy of 33 : 97.75
Accuracy of 34 : 91.25
Accuracy of 35 : 93.25
Accuracy of 36 : 85.0
Accuracy of 37 : 90.5
Accuracy of 38 : 97.5
Accuracy of 39 : 96.75
Accuracy of 40 : 58.75
Accuracy of 41 : 7

In [17]:
ori_model_weights = list(ori_trained_model.parameters())

In [25]:
# norm diff calculation
norm_diff = 0
for i in range(len(model_weights)):
    norm_diff += torch.norm((model_weights[i]-ori_model_weights[i]))**2
norm_diff = torch.sqrt(norm_diff).item()
print(norm_diff)

1.4596834182739258


In [27]:
print(ori_model_weights[-2])
print(model_weights[-2])

Parameter containing:
tensor([[-0.0235, -0.0046,  0.0568,  ..., -0.1423,  0.0729, -0.0395],
        [ 0.0023, -0.0544, -0.0171,  ...,  0.0039,  0.1281,  0.0363],
        [ 0.0239, -0.0270,  0.0639,  ...,  0.0901, -0.1131,  0.0224],
        ...,
        [ 0.0095,  0.0464,  0.0875,  ...,  0.0299, -0.0707,  0.0140],
        [-0.1048, -0.0105, -0.0608,  ...,  0.0935,  0.0375,  0.0878],
        [ 0.0065, -0.0629, -0.0672,  ..., -0.0632, -0.0036, -0.0265]],
       requires_grad=True)
Parameter containing:
tensor([[-0.0235, -0.0046,  0.0568,  ..., -0.1423,  0.0729, -0.0395],
        [ 0.0023, -0.0544, -0.0171,  ...,  0.0039,  0.1281,  0.0363],
        [ 0.0239, -0.0270,  0.0639,  ...,  0.0901, -0.1131,  0.0224],
        ...,
        [ 0.0095,  0.0464,  0.0875,  ...,  0.0299, -0.0707,  0.0140],
        [-0.1048, -0.0105, -0.0608,  ...,  0.0935,  0.0375,  0.0878],
        [ 0.0065, -0.0629, -0.0672,  ..., -0.0632, -0.0036, -0.0265]],
       grad_fn=<CopySlices>)


In [29]:
print(torch.norm(model_weights[-2] - ori_model_weights[-2]))
print(torch.norm(model_weights[-2]))
print(torch.norm(ori_model_weights[-2]))

tensor(1.4589, grad_fn=<NormBackward0>)
tensor(9.4093, grad_fn=<NormBackward0>)
tensor(9.5217, grad_fn=<NormBackward0>)


In [26]:
# norm of original model
weight_norm_ori = 0
for i in range(len(ori_model_weights)):
    weight_norm_ori += torch.norm(ori_model_weights[i])**2
weight_norm_ori = torch.sqrt(weight_norm_ori).item()

weight_norm  = 0
for i in range(len(model_weights)):
    weight_norm += torch.norm(model_weights[i])**2
weight_norm = torch.sqrt(weight_norm).item()

print(weight_norm_ori, weight_norm)

22.268674850463867 22.220783233642578


In [30]:
S = 10

In [31]:
other_client_models = [copy.deepcopy(ori_model_weights) for i in range(S-1)]

In [36]:
all_client_weights = [ocm for ocm in other_client_models] + [model_weights]

In [38]:
print(len(all_client_weights[0]), len(all_client_weights[-1]))

8 8


In [33]:
epsilon = [model_weights[i] - ori_model_weights[i] for i in range(len(model_weights))]

In [45]:
averaged_model = []

In [46]:
for layer_index in range(len(model_weights)):
    model_avg_buffer = torch.zeros(model_weights[layer_index].size())
    for client_index in range(S-1):
        model_avg_buffer += other_client_models[client_index][layer_index]
    model_avg_buffer += (ori_model_weights[layer_index]+S*epsilon[layer_index])
    averaged_model.append(model_avg_buffer/S)

In [47]:
for am_index, am in enumerate(averaged_model):
    print(am_index, torch.norm(averaged_model[am_index]-model_weights[am_index]).item())

0 3.5357953720449586e-07
1 6.909541383492979e-08
2 7.11363270511356e-07
3 1.9698973119375296e-08
4 1.2455376463549328e-06
5 3.393346759139604e-08
6 7.027445576568425e-07
7 2.863438730571488e-08


In [48]:
avg_model = LeNet()

for avgm_index, avg_weight in enumerate(avg_model.parameters()):
    avg_weight.data = averaged_model[avgm_index]

In [49]:
test(model, device, test_loader, criterion)


#### Test set: Average loss: 0.0072, Accuracy: 16042/18800 (85%)

Accuracy of 0 : 79.75
Accuracy of 1 : 62.5
Accuracy of 2 : 89.75
Accuracy of 3 : 0.0
Accuracy of 4 : 90.25
Accuracy of 5 : 86.5
Accuracy of 6 : 91.75
Accuracy of 7 : 97.25
Accuracy of 8 : 93.5
Accuracy of 9 : 59.5
Accuracy of 10 : 97.75
Accuracy of 11 : 96.5
Accuracy of 12 : 92.0
Accuracy of 13 : 93.75
Accuracy of 14 : 98.0
Accuracy of 15 : 64.25
Accuracy of 16 : 97.25
Accuracy of 17 : 97.5
Accuracy of 18 : 59.25
Accuracy of 19 : 93.75
Accuracy of 20 : 97.0
Accuracy of 21 : 61.75
Accuracy of 22 : 98.25
Accuracy of 23 : 97.5
Accuracy of 24 : 52.25
Accuracy of 25 : 95.25
Accuracy of 26 : 90.0
Accuracy of 27 : 95.5
Accuracy of 28 : 96.5
Accuracy of 29 : 93.25
Accuracy of 30 : 93.75
Accuracy of 31 : 93.75
Accuracy of 32 : 98.75
Accuracy of 33 : 97.75
Accuracy of 34 : 91.25
Accuracy of 35 : 93.25
Accuracy of 36 : 85.0
Accuracy of 37 : 90.5
Accuracy of 38 : 97.5
Accuracy of 39 : 96.75
Accuracy of 40 : 58.75
Accuracy of 41 : 7