In [None]:
# Fetch the cpdefense sources and make the checkout the working directory.
# The star imports that follow (leakage_utils, training_utils) are presumably
# modules shipped in that repository — confirm.
# NOTE(review): not idempotent. Re-running this cell in the same kernel starts
# from inside the checkout, so the clone and the %cd nest one level deeper.
!git clone https://github.com/lliu12/cpdefense.git 
%cd cpdefense

from leakage_utils import *
from training_utils import *

import time
import copy
import sys
import random
from collections import OrderedDict

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


In [None]:
# CIFAR-10 is the dataset for every experiment in this notebook.
def _load_cifar10(train):
    """Return the CIFAR-10 split selected by `train`, downloading it into
    ./data when needed and converting images with transforms.ToTensor()."""
    return torchvision.datasets.CIFAR10(
        root='./data',
        train=train,
        download=True,
        transform=transforms.ToTensor(),
    )


# training split: handed to the device sampler further down
trainset = _load_cifar10(train=True)

# held-out split, read in fixed order in batches of 128
testset = _load_cifar10(train=False)
testloader = torch.utils.data.DataLoader(
    testset,
    shuffle=False,
    batch_size=128,
)


In [3]:
# ---- experiment configuration -------------------------------------------
c = 9 # num clusters
n = 9 # num PCA dimensions to project onto
num_devices = 60 # use 60 for actual experiments

# Seed every RNG this notebook imports so that Restart & Run All reproduces
# the same data split, weight initialisation and attack starting points.
# This matters because a later cell hard-codes device indices that only make
# sense for one particular clustering.
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Shared model handed to every device, and the loss used for normal training.
net = DLGNet().cuda()
training_criterion = nn.CrossEntropyLoss()

# Loss used by the gradient-matching (leakage) attack. The LBFGS leakage
# optimizer itself is created in the cluster-attack cell, once the dummy
# image/label it optimises exist.
leakage_criterion = cross_entropy_for_onehot

init_rounds = 1    # federated warm-up rounds run before clustering
local_epochs = 3   # local training epochs per device per round
num_items_per_device = 5000
device_nums = [1, 1, 1]

# NOTE(review): 0.1 is presumably the fraction of the training set sampled per
# device (0.1 * 50000 CIFAR-10 training images = num_items_per_device) —
# confirm against iid_sampler.
data_idxs = iid_sampler(trainset, num_devices, 0.1)

# One device per key of data_idxs; data_idxs[i] is that device's sample indices.
devices = [
    create_device(net, i, trainset, data_idxs[i],
                  milestones=[250, 500, 750], batch_size=128)
    for i in data_idxs
]


In [None]:
# Federated warm-up: in every round each device trains locally, the device
# weights are averaged, and the average is pushed back to all devices.
for round_num in range(init_rounds):
    round_devices = devices
    print('Round: ', round_num)
    for device in round_devices:
        for local_epoch in range(local_epochs):
            train(local_epoch, device, net)

    w_avg = average_weights(round_devices)

    for device in devices:
        device['net'].load_state_dict(w_avg)
        # NOTE(review): the optimizer is stepped right after its gradients are
        # cleared, presumably only so scheduler.step() follows an
        # optimizer.step() — confirm this cannot move the averaged weights.
        device['optimizer'].zero_grad()
        device['optimizer'].step()
        device['scheduler'].step()

    test(round_num, devices[0], net, testloader)

# Build one feature vector per device from every entry of device['weights']
# whose name contains 'weight'. The pieces are concatenated directly: seeding
# the concatenation with torch.empty([1]) would put an uninitialised (garbage)
# value at the head of every feature vector and feed it to PCA.
# This runs once after the last round, because only the final projection is
# consumed by the clustering below.
arr = []
for device in devices:
    device_weights = device['weights']
    flat_parts = [device_weights[k].flatten()
                  for k in device_weights if 'weight' in k]
    arr.append(torch.cat(flat_parts, 0).detach().cpu().numpy())

# Project the per-device weight vectors onto n principal components, then
# group the devices into c clusters.
pca = PCA(n_components=n)
X_train_pca = pca.fit_transform(arr)
knn = KMeans(n_clusters=c).fit(X_train_pca) # despite the name, this is k-means
preds = knn.predict(X_train_pca)
clusters = get_cluster_dict(preds)
print(preds)
print(clusters)



In [None]:
# # alternative code for evenly sized clusters
# preds = []
# for i in range(c): # assumes c = 9 and num_devices = 60 (6 clusters of 7 devices + 3 clusters of 6)
#   num_in_cluster = 7 if i < 6 else 6
#   preds = preds + [i] * num_in_cluster
# len(preds)
# clusters = get_cluster_dict(preds)

In [6]:
# Clusters are fixed from here on. Re-initialise the first device's network
# with weights_init and copy it to every device, so all devices start from the
# same model as in federated learning.

overlap_factor = 1

reference_net = devices[0]['net']
reference_net.apply(weights_init)
for d in devices:
  d['net'].load_state_dict(devices[0]['net'].state_dict())

# One record per device, in device order: the sampled (image, one-hot label)
# pair and the gradient get_true_device_gradient returns for it.
device_gradients = []
for d in devices:
  sample_image, onehot_sample_label = get_device_data_sample(d)
  device_gradients.append({
    'sample_image': sample_image,
    'onehot_sample_label': onehot_sample_label,
    'orig_grad': get_true_device_gradient(d, sample_image, onehot_sample_label),
  })

# Adds the pruned gradients to the records in place (a later cell reads
# 'pruned_grad' from them), then builds the per-cluster gradients.
add_pruned_gradients(device_gradients, clusters, overlap_factor, amplify_factor = 1)

cluster_gradients = get_cluster_gradients(device_gradients, clusters, for_attack=False)

In [None]:
# test individual privacy: actual experiment
# leakage code adapted from https://github.com/mit-han-lab/dlg/blob/master/main.py

# Number of recovery attempts made against each device. Kept in one place so
# the loop and the reported totals cannot drift apart.
num_attempts = 5

device_individual_privacy_results = [{} for _ in devices]
for d_num in range(len(devices)):
  if d_num % 5 == 0:
    print("Working on device " + str(d_num) + "...")

  sample_image = device_gradients[d_num]['sample_image'] # target image to leak
  sample_label = device_gradients[d_num]['onehot_sample_label'] # target label to leak
  results = device_individual_privacy_results[d_num]
  results['best_psnr'] = -np.inf
  results['label_leaked_count'] = 0
  results['attempts_count'] = 0

  for _ in range(num_attempts):
    results['attempts_count'] += 1
    random_image, random_label, history = try_recovery_individual(devices[d_num], device_gradients[d_num]['pruned_grad'])
    # score every recorded snapshot against the true image, keep the best one
    psnrs = [psnr(sample_image, im) for im in history]
    psnr_max_this_trial = np.max(psnrs)
    # update best psnr seen so far for this device
    results['best_psnr'] = np.max([results['best_psnr'], psnr_max_this_trial])
    # check if sample label was leaked by final iteration
    if torch.argmax(sample_label) == torch.argmax(random_label):
      results['label_leaked_count'] += 1

# print results

best_psnrs = [device_individual_privacy_results[i]['best_psnr'] for i in range(len(devices))]
print("Max PSNR: " + str(np.max(best_psnrs)))
print("Mean PSNR: " + str(np.mean(best_psnrs)))

label_leaks = [device_individual_privacy_results[i]['label_leaked_count'] for i in range(len(devices))]
total_label_leaks = np.sum(label_leaks)
# Derive the denominator from the attempts actually recorded per device.
total_attempts = np.sum([device_individual_privacy_results[i]['attempts_count'] for i in range(len(devices))])
# NOTE(review): the printed value is a fraction in [0, 1], not a percentage.
print("Percentage of Labels Leaked: " + str(total_label_leaks / total_attempts))


In [None]:
# Run one recovery attempt against a single hand-picked device.
# Leakage works best when the device is the only member of its cluster.

d_num = 24
target_device = devices[d_num]
target_record = device_gradients[d_num]

# im: recovered image, lab: recovered label, hist: snapshots plotted below
im, lab, hist = try_recovery_individual(target_device, target_record['pruned_grad'])

print("Recovered PSNR: ")
psnr(im, target_record['sample_image'])

In [None]:
# Show how the recovered image evolved during the single-device attack.
# `hist` is the third value returned by try_recovery_individual in the
# attack cell; plot_history presumably draws one panel per snapshot — confirm.
plot_history(hist)

In [None]:
# plot final recovered image

# Detach from autograd and move to host memory before converting: .numpy()
# raises on a CUDA tensor, and .cpu() is a no-op when `im` is already on the
# CPU. The (1,2,0) transpose moves the first axis last, as imshow expects
# channels in the last axis.
recovered = im.detach().cpu().numpy()
plt.imshow(np.transpose(recovered, (1,2,0)))
plt.axis('off');  # trailing ';' keeps the axis-limits repr out of the output

In [None]:
# Show the true image held by device d_num, for comparison with the recovery.
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

true_record = device_gradients[d_num]
sample_label = true_record['onehot_sample_label']

# first axis moved last so imshow receives the channels in the last axis
plt.imshow(np.transpose(true_record['sample_image'], (1,2,0)))
plt.axis('off')

# the arg-max position of the one-hot label indexes the class-name list
class_name = classes[torch.argmax(sample_label)]
print("Label: " + str(sample_label) + ", a " + class_name)

In [12]:
 # get the image and label pair stored for a device
def get_device_pair(d_num):
  """Return the (sample_image, onehot_sample_label) pair stored for device
  `d_num` in device_gradients."""
  record = device_gradients[d_num]
  return record['sample_image'], record['onehot_sample_label']

# try manually setting devices in desired clusters to certain images
# just an example!
# (will need to be redone for new clusters/data distributions)

# device index -> index of the device whose stored (image, label) it is given
image_source_by_device = {
  53: 2, 59: 2, 40: 2,
  8: 16, 44: 16, 26: 16,
  0: 5, 52: 5,
  30: 15,
}
default_image_source = 26 # every device not listed above

# Statements here start at column 0: a one-space indent after the two-space
# function body above is an IndentationError.
cluster_attack_device_gradients = [{} for _ in range(len(devices))]
for i, d in enumerate(devices):
  source = image_source_by_device.get(i, default_image_source)
  sample_image, onehot_sample_label = get_device_pair(source)
  grad = get_true_device_gradient(d, sample_image, onehot_sample_label)
  cluster_attack_device_gradients[i]['sample_image'] = sample_image
  cluster_attack_device_gradients[i]['onehot_sample_label'] = onehot_sample_label
  cluster_attack_device_gradients[i]['orig_grad'] = grad

add_pruned_gradients(cluster_attack_device_gradients, clusters, overlap_factor, amplify_factor = 1)

cluster_attack_cluster_gradients = get_cluster_gradients(cluster_attack_device_gradients, clusters, for_attack=False)

In [None]:
# Cluster-level privacy test: optimise a dummy (image, label) pair so that the
# cluster gradient it produces matches the recorded one.
# Expect to need many runs before anything leaks :)
random_image, random_label = get_random_pair()

# LBFGS optimises the dummy pair directly
leakage_optimizer = torch.optim.LBFGS([random_image, random_label], lr = 1)

select_cluster = 0 # set cluster to attack
actual_grad = copy.deepcopy(cluster_attack_cluster_gradients[select_cluster])

history = []
num_iters = 100
save_every = 10


def cluster_closure():
  """Closure for LBFGS: recompute the dummy cluster gradient, back-propagate
  its squared distance to actual_grad, and return that distance."""
  leakage_optimizer.zero_grad()

  # Per-device gradient of the dummy loss for every member of the cluster.
  # create_graph=True keeps it differentiable w.r.t. the dummy pair.
  for select_device in clusters[select_cluster]:
    device_net = devices[select_device]['net']
    dummy_pred = device_net(torch.unsqueeze(random_image, dim = 0).cuda())
    dummy_onehot_label = F.softmax(random_label, dim=-1).cuda()
    dummy_loss = leakage_criterion(dummy_pred, dummy_onehot_label)
    cluster_attack_device_gradients[select_device]["grad_for_attack"] = torch.autograd.grad(
        dummy_loss, device_net.parameters(), create_graph = True)

  # Prune the attack gradients, then combine them into the cluster gradient.
  prune_attack_gradients(cluster_attack_device_gradients, clusters[select_cluster], overlap_factor, amplify_factor = 1)
  dummy_grad = get_cluster_gradient(select_cluster, cluster_attack_device_gradients, clusters, for_attack = True)

  # squared L2 distance between dummy and recorded cluster gradients
  grad_diff = sum(((gx - gy) ** 2).sum() for gx, gy in zip(dummy_grad, actual_grad))
  grad_diff.backward()
  return grad_diff


for iters in range(num_iters):
  leakage_optimizer.step(cluster_closure)
  if not iters % save_every:
    # re-evaluate the objective for logging and snapshot the current image
    diff = cluster_closure()
    print(diff)
    history.append(copy.deepcopy(random_image.cpu().detach()))

plot_history(history)