In [1]:
import numpy as np
import torch
import random
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.utils import save_image
import math
import itertools
import statistics
import pickle

from models import *
from util import *

import matplotlib.pyplot as plt

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
_backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_backend)
print(f'Running on {device}.')

Running on cpu.


## Load Dataset

Load one of the MNIST, Fashion-MNIST, or CIFAR-10 datasets by setting `dataset` to `'mnist'`, `'f_mnist'`, or `'cifar'`.

In [3]:
dataset = 'mnist'
batch_size = 64

# torchvision dataset class for each supported `dataset` key; each dataset is
# downloaded to data/<key>.
DATASET_CLASSES = {
    'mnist': datasets.MNIST,
    'f_mnist': datasets.FashionMNIST,
    'cifar': datasets.CIFAR10,
}
# Fail fast: otherwise an unknown key leaves `trainset` undefined and the
# error only surfaces later as a confusing NameError.
if dataset not in DATASET_CLASSES:
    raise ValueError(f'unknown dataset {dataset!r}; expected one of {sorted(DATASET_CLASSES)}')

dataset_cls = DATASET_CLASSES[dataset]
data_root = f'data/{dataset}'
trainset = dataset_cls(data_root, download=True, train=True, transform=transforms.ToTensor())
testset = dataset_cls(data_root, download=True, train=False, transform=transforms.ToTensor())

trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=batch_size)
# NOTE(review): batch_size defaults to 1 here and the test set is shuffled;
# testloader is not used by the cells below.
testloader = torch.utils.data.DataLoader(testset, shuffle=True)

## Helper Methods

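These helper functions compute the SplitGuard score:
- `angle` measures the angle between two gradient vectors.
- `sigmoid` squashes the raw score.
- `sg_score` combines the two into the final score.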

In [4]:
def angle(v1, v2):
    """Return the angle, in radians, between two 1-D vectors.

    Both vectors are normalised to unit length and the angle is recovered with
    ``arccos`` of their dot product, so the result lies in ``[0, pi]``.
    An all-zero vector has no direction; its normalisation divides by zero and
    the result is nan.

    Args:
        v1, v2: 1-D vectors of equal length (NumPy arrays, or CPU tensors as
            produced in the main loop).

    Returns:
        The angle between ``v1`` and ``v2`` in radians.
    """
    unit_vector_1 = v1 / np.linalg.norm(v1)
    unit_vector_2 = v2 / np.linalg.norm(v2)
    dot_product = np.dot(unit_vector_1, unit_vector_2)
    # Rounding error can push the dot product of two unit vectors slightly
    # outside [-1, 1] (e.g. 1.0000001 for near-parallel vectors). arccos is
    # undefined there and returns nan, which would silently poison the score,
    # so clamp before inverting.
    return np.arccos(np.clip(dot_product, -1.0, 1.0))

def sigmoid(x, shift=0, mult=1, exp=1):
    """Shifted, scaled logistic function raised to a power.

    Computes ``(1 / (1 + e^{-(x - shift) * mult})) ** exp``.

    Args:
        x: input value (scalar or NumPy array).
        shift: value of ``x`` that maps to the logistic midpoint.
        mult: steepness of the curve.
        exp: power applied to the logistic output.
    """
    scaled = mult * (x - shift)
    logistic = 1 / (1 + np.exp(-scaled))
    return logistic ** exp

def sg_score(fakes, controls, regulars, shift=0, mult=1, exp=1, raw=False):
    """Compute the SplitGuard score from three groups of client gradients.

    Two angles are compared:
    - how far the mean fake-batch gradient points from the mean of the two
      regular groups;
    - how far the two regular groups point from each other.

    Each angle is weighted by the matching gap in mean gradient magnitude.
    The weights are normalised so they sum to 1.

    Args:
        fakes: non-empty sequence of gradient vectors from fake batches.
        controls: non-empty sequence of gradient vectors from one group of
            regular batches.
        regulars: non-empty sequence of gradient vectors from the other group
            of regular batches.
        shift, mult, exp: forwarded to ``sigmoid``.
        raw: if True, return the unsquashed score instead of its sigmoid.

    Returns:
        The weighted-angle difference, squashed by ``sigmoid`` unless ``raw``.

    NOTE(review): if both magnitude gaps are exactly zero the weights are 0/0
    and the result is nan; empty inputs raise ZeroDivisionError.
    """
    def _mean_vector(vectors):
        # element-wise average of the gradient vectors
        return sum(vectors) / len(vectors)

    def _mean_norm(vectors):
        # average Euclidean length of the gradient vectors
        return sum(np.linalg.norm(v) for v in vectors) / len(vectors)

    fake_dir = _mean_vector(fakes)
    control_dir = _mean_vector(controls)
    regular_dir = _mean_vector(regulars)
    honest_dir = (control_dir + regular_dir) / 2

    fake_mag = _mean_norm(fakes)
    control_mag = _mean_norm(controls)
    regular_mag = _mean_norm(regulars)
    honest_mag = (control_mag + regular_mag) / 2

    fake_gap = abs(fake_mag - honest_mag)
    regular_gap = abs(regular_mag - control_mag)
    total_gap = fake_gap + regular_gap

    fake_term = angle(fake_dir, honest_dir) * (fake_gap / total_gap)
    regular_term = angle(control_dir, regular_dir) * (regular_gap / total_gap)
    score = fake_term - regular_term

    if not raw:
        return sigmoid(score, shift=shift, mult=mult, exp=exp)
    return score

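## Main code

This cell runs the experiment once for each server behavior:
- `honest`: the server trains on the labels the client sends.
- `random`: the server trains on random labels.

In both runs, the client occasionally sends a fake batch with shifted labels, and the SplitGuard score is recorded after each fake batch.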

In [7]:
model_str = 'resnet'

# parameters for the squashing function
mult = 5
exp = 2

# number of randomized labels in a fake batch
b_fake = 64

# probability of sending a fake batch
p_fake = 0.1

# batch index at which splitguard starts running (only batches with
# index > N can be fake or contribute regular gradients)
N = 0

# number of passes over the training set per adversary type
NUM_EPOCHS = 1

# cap on batches per epoch (previously a hard-coded `if index == 20: break`);
# set to None to train on the full training set
MAX_BATCHES = 20

# every dataset offered above has 10 classes
NUM_CLASSES = 10

# Seed the Python, NumPy and PyTorch RNGs so that fake-batch placement, label
# offsets, the r_1/r_2 split and random server labels are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# adversary types: 'honest' trains on the labels the client sends,
# 'random' ignores them and trains on uniformly random labels
ADV_TYPES = ['honest', 'random']
results = {adv_type: [] for adv_type in ADV_TYPES}

criterion = nn.CrossEntropyLoss()

for adv_type in ADV_TYPES:
    # fakes: client gradients from fake batches; r_1 / r_2: client gradients
    # from regular batches, each randomly assigned to one of the two groups
    fakes, r_1, r_2, fake_indices, scores = [], [], [], [], []

    model = get_models(model_str, dataset, device)
    params = list(model.parameters())
    # The first two parameter tensors are updated by the client optimizer and
    # the rest by the server optimizer.
    # NOTE(review): assumes get_models lists the client layer's parameters
    # first -- confirm in models.py.
    client_opt = torch.optim.Adam(params[:2], lr=0.001, amsgrad=True)
    server_opt = torch.optim.Adam(params[2:], lr=0.001, amsgrad=True)

    for epoch in range(NUM_EPOCHS):
        for index, (images, labels) in enumerate(trainloader):
            if MAX_BATCHES is not None and index >= MAX_BATCHES:
                break
            images, labels = images.to(device), labels.to(device)
            client_opt.zero_grad()
            server_opt.zero_grad()

            send_fakes = index > N and random.random() <= p_fake

            pred = model(images)

            labels_sent = labels
            if send_fakes:
                # Shift labels by one offset in [1, 8] (mod NUM_CLASSES) so every
                # randomized label differs from the true one.
                rand_labels = (labels + random.randint(1, 8)) % NUM_CLASSES
                labels_sent = torch.cat((rand_labels[:b_fake], labels[b_fake:]))

            # server behavior
            if adv_type == 'honest':
                loss = criterion(pred, labels_sent)
            elif adv_type == 'random':
                loss = criterion(pred, torch.randint(0, NUM_CLASSES, labels.size(), dtype=torch.long, device=device))
            else:
                # without this, `loss` would be undefined (or stale from the
                # previous batch) for an unrecognised adversary type
                raise ValueError(f'unknown adversary type: {adv_type!r}')

            loss.backward()

            # gradient of the client's first parameter tensor, as a flat vector
            client_grad = params[0].grad.detach().clone().flatten()

            if send_fakes:
                fakes.append(client_grad.cpu())
                fake_indices.append(index)
                # a score needs at least one gradient in each regular group
                if len(r_1) > 0 and len(r_2) > 0:
                    sg = sg_score(fakes, r_1, r_2, mult=mult, exp=exp, raw=False)
                    scores.append(sg)
                # do not update client model with gradients from fake labels
            else:
                if index > N:
                    if random.random() <= 0.5:
                        r_1.append(client_grad.cpu())
                    else:
                        r_2.append(client_grad.cpu())
                client_opt.step()
            server_opt.step()

    results[adv_type] = scores

KeyboardInterrupt: 

## Display Results

In [None]:
fig, ax = plt.subplots()
for adv_type, adv_scores in results.items():
    if adv_scores:
        print(f'{adv_type} mean: {np.mean(adv_scores)}')
    else:
        # np.mean([]) prints nan and emits a RuntimeWarning. This happens when
        # no fake batch arrived after both regular groups were non-empty.
        print(f'{adv_type} mean: n/a (no SplitGuard scores recorded)')
    ax.plot(adv_scores, label=adv_type)
ax.set_ylim(0, 1.1)
ax.set_xlabel('No. of fake batches')
ax.set_ylabel('SG score')
ax.set_title('SplitGuard score per server behavior')
ax.legend()
plt.show()