In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from SGD import *
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [4]:
dataset = 'mnist'
bias = 0.5
net = 'cnn'
batch_size = 32
# lr = 0.0002
lr = 3e-4
# lr = 0.01
nworkers = 100
nepochs = 2000
gpu = 3
seed = 41
nbyz = 28
byz_type = 'full_trim'
aggregation = 'median'
criterion = nn.CrossEntropyLoss()

In [5]:
def lbfgs(S_k_list, Y_k_list, v):
    curr_S_k = torch.stack(S_k_list).T
    curr_Y_k = torch.stack(Y_k_list).T
    S_k_time_Y_k = np.dot(curr_S_k.T.cpu().numpy(), curr_Y_k.cpu().numpy())
    S_k_time_S_k = np.dot(curr_S_k.T.cpu().numpy(), curr_S_k.cpu().numpy())
    R_k = np.triu(S_k_time_Y_k)
    L_k = S_k_time_Y_k - R_k
    sigma_k = np.dot(Y_k_list[-1].unsqueeze(0).cpu().numpy(), S_k_list[-1].unsqueeze(0).T.cpu().numpy()) / (np.dot(S_k_list[-1].unsqueeze(0).cpu().numpy(), S_k_list[-1].unsqueeze(0).T.cpu().numpy()))
    D_k_diag = np.diag(S_k_time_Y_k)
    upper_mat = np.concatenate((sigma_k * S_k_time_S_k, L_k), axis=1)
    lower_mat = np.concatenate((L_k.T, -np.diag(D_k_diag)), axis=1)
    mat = np.concatenate((upper_mat, lower_mat), axis=0)
    mat_inv = np.linalg.inv(mat)

    approx_prod = sigma_k * v.cpu().numpy()
    approx_prod = approx_prod.T
    p_mat = np.concatenate((np.dot(curr_S_k.T.cpu().numpy(), sigma_k * v.unsqueeze(0).T.cpu().numpy()), np.dot(curr_Y_k.T.cpu().numpy(), v.unsqueeze(0).T.cpu().numpy())), axis=0)
    approx_prod -= np.dot(np.dot(np.concatenate((sigma_k * curr_S_k.cpu().numpy(), curr_Y_k.cpu().numpy()), axis=1), mat_inv), p_mat)

    return approx_prod

In [6]:
class cnn(nn.Module):
    def __init__(self):
        super(cnn, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(30, 50, 5)
        self.pool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(800, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [7]:
def full_trim(v, f):
    '''
    Full-knowledge Trim attack. w.l.o.g., we assume the first f worker devices are compromised.
    v: the list of squeezed gradients
    f: the number of compromised worker devices
    '''
    # first compute the statistics
    vi_shape = v[0].unsqueeze(0).T.shape
    v_tran = v.T
#     v_tran = nd.concat(*v, dim=1)
    
    maximum_dim = torch.max(v_tran, dim=1)
    maximum_dim = maximum_dim[0].reshape(vi_shape)
    minimum_dim = torch.min(v_tran, dim=1)
    minimum_dim = minimum_dim[0].reshape(vi_shape)
    direction = torch.sign(torch.sum(v_tran, dim=-1, keepdims=True))
    directed_dim = (direction > 0) * minimum_dim + (direction < 0) * maximum_dim

    for i in range(20):
        # apply attack to compromised worker devices with randomness
        random_12 = 2
        tmp = directed_dim * ((direction * directed_dim > 0) / random_12 + (direction * directed_dim < 0) * random_12)
        tmp = tmp.squeeze()
        v[i] = tmp
    return v

In [8]:
def tr_mean(all_updates, n_attackers):
    sorted_updates = torch.sort(all_updates, 0)[0]
    out = torch.mean(sorted_updates[n_attackers:-n_attackers], 0) if n_attackers else torch.mean(sorted_updates,0)
    return out

In [9]:
num_workers = nworkers
epochs = nepochs
grad_list = []
old_grad_list = []
weight_record = []
grad_record = []
train_acc_list = []
distance1 = []
distance2 = []
auc_list = []

In [10]:
transform=transforms.Compose([
    transforms.ToTensor(), # first, convert image to PyTorch tensor
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_data = torch.utils.data.DataLoader(trainset, batch_size=60000, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_data = torch.utils.data.DataLoader(testset, batch_size=5000, shuffle=False)

In [11]:
def get_client_train_data(trainset, num_workers=100, bias=0.5):

    bias_weight = bias
    other_group_size = (1 - bias_weight) / 9.
    worker_per_group = num_workers / 10

    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]
    
    for i, (x, y) in enumerate(trainset):
        # assign a data point to a group
        upper_bound = (y) * (1 - bias_weight) / 9. + bias_weight
        lower_bound = (y) * (1 - bias_weight) / 9.
        rd = np.random.random_sample()

        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y

        # assign a data point to a worker
        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        
        if not len(each_worker_data[selected_worker]):
            each_worker_data[selected_worker] = x[None, :]
        else:
            each_worker_data[selected_worker]= torch.concat((each_worker_data[selected_worker], x[None, :]))
        
        each_worker_label[selected_worker].append(y)
    
    return each_worker_data, each_worker_label

In [12]:
each_worker_data, each_worker_label = get_client_train_data(trainset, num_workers=100, bias=0.5)


In [15]:

saved_global_models = []
saved_client_updates = [[] for _ in range(num_workers)]
path_to_save_models = '/work/vshejwalkar_umass_edu/fedrecover_models/fedrecover_mnist/original_setting'
if not os.path.exists(path_to_save_models):
    os.makedirs(path_to_save_models)

fed_model = cnn().to(device)
recovery_fed_model = copy.deepcopy(fed_model)
lr = .15
n_epochs = 300
for epoch in range(n_epochs):
    received_model = []
    for param in fed_model.parameters():
        received_model = param.data.view(-1) if not len(received_model) else torch.cat((received_model, param.data.view(-1)))
    # saved_global_models.append(received_model)

    global_optimizer = SGD(fed_model.parameters(), lr = lr*(0.96**epoch))

    user_grads = []
    for i in range(100):
        local_model = copy.deepcopy(fed_model)
        local_model.zero_grad()
        output = local_model(each_worker_data[i].to(device))
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().to(device))
        # backward
        loss.backward(retain_graph = True)
        # save params
        param_grad=[]
        for param in local_model.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
#         saved_client_updates[i].append(param_grad)
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del local_model

#     user_grads = full_trim(user_grads, 20)
#     agg_grads = tr_mean(user_grads, 20)
#     agg_grads=torch.median(user_grads,dim=0)[0]
    agg_grads=torch.mean(user_grads,dim=0)

    del user_grads
    start_idx=0
    global_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(fed_model.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
#         param_=param_.cuda()
        model_grads.append(param_)

    global_optimizer.step(model_grads)

    total, correct = 0,0
    with torch.no_grad():
        for i, (data, labels) in enumerate(test_data):
            inputs, labels = data.to(device), labels.to(device)
            outputs = fed_model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(epoch, correct/total)

0 0.1407
1 0.2
2 0.2397
3 0.2736
4 0.3128
5 0.3641
6 0.4177
7 0.4773
8 0.529
9 0.5673
10 0.598
11 0.6177
12 0.6391
13 0.6554
14 0.673
15 0.6848
16 0.6986
17 0.7077
18 0.715
19 0.722
20 0.7276
21 0.7315
22 0.7344
23 0.7371
24 0.7373
25 0.7356
26 0.7356
27 0.7351
28 0.7335
29 0.7316
30 0.7307
31 0.731
32 0.7304
33 0.732
34 0.734
35 0.7366
36 0.7408
37 0.7438
38 0.7485
39 0.7525
40 0.7573
41 0.7617
42 0.7644
43 0.7678
44 0.7715
45 0.774
46 0.7774
47 0.7795
48 0.7822
49 0.7847
50 0.7881
51 0.7911
52 0.7935
53 0.7963
54 0.7983
55 0.7993
56 0.8016
57 0.8029
58 0.8051
59 0.8069
60 0.8096
61 0.8116
62 0.8127
63 0.8141
64 0.815
65 0.8164
66 0.8173
67 0.8181
68 0.8182
69 0.8192
70 0.82
71 0.8206
72 0.8215
73 0.8225
74 0.8224
75 0.8221
76 0.8226
77 0.8228
78 0.8232
79 0.8234
80 0.8236
81 0.8239
82 0.8245
83 0.825
84 0.825
85 0.8253
86 0.8254
87 0.8257
88 0.8261
89 0.8264
90 0.8266
91 0.8267
92 0.8268
93 0.8268
94 0.8267
95 0.8267
96 0.8269
97 0.827
98 0.827
99 0.8273
100 0.8277
101 0.8283
102 0.8

In [108]:
for i in range(20):
    # apply attack to compromised worker devices with randomness
    random_12 = 2
    tmp = directed_dim * ((direction * directed_dim > 0) / random_12 + (direction * directed_dim < 0) * random_12)
    tmp = tmp.squeeze()
    v[i] = tmp

In [42]:
total, correct = 0,0
with torch.no_grad():
    for i, data in enumerate(test_data):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(correct/total)

0.7564


In [26]:
total, correct = 0,0
with torch.no_grad():
    for i, data in enumerate(test_data):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net_r(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(correct/total)

0.1927


In [27]:
len(global_models)

100

In [43]:
!nvidia-smi

Sat Nov  5 15:25:47 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro RTX 8000     Off  | 00000000:40:00.0 Off |                  Off |
| 33%   49C    P2    69W / 260W |  23602MiB / 49152MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

## Recovery

### Exact Training

In [44]:
Tw=20
buffer_models = []
recovered_models = []
buffer_clients = [[] for _ in range(num_workers)]

In [45]:
net_r.to(device)
cnn_r_optimizer = SGD(net_r.parameters(), lr = lr)

In [48]:

lr = 0.12
for e in range(Tw):
    cnn_r_optimizer = SGD(net_r.parameters(), lr = lr*(0.96**e))
    user_grads = []
    # for each worker
    for i in range(100):
        net = copy.deepcopy(net_r)
        running_loss = 0
        # net_r.train()

        net.zero_grad()
        output = net(each_worker_data[i][:])
        loss = criterion(output, each_worker_label[i][:])
        
        # backward
        loss.backward(retain_graph = True)
        
        running_loss += loss.item()
        
        param_grad=[]
        for param in net.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        if (len(buffer_clients[i])==2):
            buffer_clients[i].pop(0)
        buffer_clients[i].append(param_grad - client_updates[i][e])
        
        
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net
#         print("Client: {} Epoch: {}, Loss:{:.4f}".format(i, e, running_loss))
    
    tmp = []
    for param in net_r.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    #make copy instead of assignment
    weight = tmp
    
    recovered_models.append(weight)
    
    if(e>0):
        if(len(buffer_models) == 2):
            buffer_models.pop(0)
        buffer_models.append(weight - global_models[e])
    
   
#     agg_grads=torch.median(user_grads,dim=0)[0]
    agg_grads=torch.mean(user_grads,dim=0)
    
    del user_grads
    
    start_idx=0

    cnn_r_optimizer.zero_grad()

    model_grads=[]

    for i, param in enumerate(net_r.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_r_optimizer.step(model_grads)
    
    total, correct = 0,0
    with torch.no_grad():
        for i, data in enumerate(test_data):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net_r(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(correct/total)

0.1246
0.1416
0.1633
0.2002
0.2442
0.2828
0.3244
0.3681
0.4055
0.4349
0.4708
0.504
0.533
0.5553
0.5761
0.5914
0.6032
0.6129
0.6211
0.6287


In [49]:
len(client_updates)

100

In [50]:
len(client_updates[0])

200

In [51]:
len(buffer_clients)

100

In [52]:
len(buffer_clients[66])

2

In [53]:
len(buffer_models)

2

In [54]:
np.shape(buffer_clients[1][1])

torch.Size([453572])

In [55]:
print(type(buffer_models[0]))
print(type(buffer_clients[0][0]))

<class 'torch.Tensor'>
<class 'torch.Tensor'>


In [56]:
len(recovered_models)

21

In [57]:
for e in range(Tw, Tw+30):
    cnn_r_optimizer = SGD(net_r.parameters(), lr = lr*(0.96**e))
    user_grads = []
    if (e%3 != 0):
        for i in range(100):
#             print(e,i)
            hvp = lbfgs(buffer_models, buffer_clients[i], recovered_models[-1]-global_models[e])
            hvp = torch.tensor(np.squeeze(hvp))
            model_update = client_updates[i][e] + hvp.to(device)
            user_grads=model_update[None, :] if len(user_grads)==0 else torch.cat((user_grads,model_update[None,:]), 0)
        print("estimated update")
    else:
        for i in range(100):
            running_loss = 0

            output = net_r(each_worker_data[i][:])
            loss = criterion(output, each_worker_label[i][:])
            net_r.zero_grad()

            # backward
            loss.backward(retain_graph = True)

            running_loss += loss.item()

            param_grad=[]
            for param in net_r.parameters():
                param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
            if (len(buffer_clients[i])==2):
                buffer_clients[i].pop(0)
            buffer_clients[i].append(param_grad - client_updates[i][e])


            user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
#             print("Client: {} Epoch: {}, Loss:{:.4f}".format(i, e, running_loss))
        print("exact update")
        
        
#     agg_grads=torch.median(user_grads,dim=0)[0]
    agg_grads=torch.mean(user_grads,dim=0)
    
    
    tmp = []
    for param in net_r.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp
    
    recovered_models.append(weight)
    
    if(e%3==0):
        if(len(buffer_models)==2):
            buffer_models.pop(0)
            buffer_models.append(weight - global_models[e])
    
    del user_grads
    
    start_idx=0
    
    model_grads=[]

    for i, param in enumerate(net_r.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_r_optimizer.step(model_grads)
    
    total, correct = 0,0
    with torch.no_grad():
        for i, data in enumerate(test_data):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net_r(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(correct/total)

estimated update
0.635
exact update
0.6417
estimated update
0.6478
estimated update
0.6559
exact update
0.6622
estimated update
0.6681
estimated update
0.6732
exact update
0.6793
estimated update
0.6842
estimated update
0.6907
exact update
0.6953
estimated update
0.6995
estimated update
0.7021
exact update
0.7064
estimated update
0.7097
estimated update
0.7125
exact update
0.7149
estimated update
0.7176
estimated update
0.7192
exact update
0.7209
estimated update
0.7219
estimated update
0.7244
exact update
0.7254
estimated update
0.7272
estimated update
0.7274
exact update
0.7281
estimated update
0.7282
estimated update
0.7283
exact update
0.728
estimated update
0.7294


In [40]:
total, correct = 0,0
with torch.no_grad():
    for i, data in enumerate(test_data):
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = net_r(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(correct/total)

0.72


In [122]:
x = torch.zeros((10,20))
print(torch.mean(x,0).shape)

torch.Size([20])
