In [1]:
import torch
from torch import nn, optim
import torchvision
torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as transforms
import numpy as np
import torch.utils.tensorboard as tb
import datetime
import os
from tqdm.notebook import tqdm

In [2]:
# Pick the fastest available backend: CUDA first, then MPS, falling back to the
# CPU. Every branch yields a torch.device so downstream code sees one type.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

cuda


In [3]:
# Convert each PIL image to an image tensor, then to float32 scaled into [0, 1].
# `ToDtype(..., scale=True)` is the non-deprecated replacement for
# `ConvertImageDtype` in transforms.v2.
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
])

# No `train=` argument is passed, so torchvision's default split is loaded; it
# is further split into train/valid/test subsets in the next cell.
dataset = torchvision.datasets.CIFAR10('./data/torch/cifar', download=True, transform=transform)

Files already downloaded and verified


In [4]:
# 70% / 20% / remainder split into train / valid / test. The test share takes
# the remainder so the three lengths always sum to len(dataset). A seeded
# generator makes the split identical across kernel restarts.
SPLIT_SEED = 0
train_size = int(0.7 * len(dataset))
valid_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_set, valid_set, test_set = torch.utils.data.random_split(
    dataset,
    [train_size, valid_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [5]:
# Per-channel mean / std constants for CIFAR-10 (three channels, presumably RGB
# order). Every batch is standardised with `normalize` before entering a network.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)
normalize = transforms.Normalize(mean=cifar_mean, std=cifar_std)

In [6]:
class CNN(nn.Module):
    """Small fully-convolutional classifier: conv stack -> global average pool -> linear head.

    Args:
        arch: sequence of ``(kernel_size, out_channels)`` pairs, one per
            convolution. A ReLU is inserted between consecutive convolutions
            (not after the last one).
        padding: if True, convolutions use ``padding='same'`` so the feature map
            stays ``input_size`` wide; if False they are unpadded and each one
            shrinks the map by ``kernel_size - 1`` pixels.
        in_channels: number of input image channels (default 3).
        num_classes: number of output logits (default 10).
        input_size: height/width of the square input images (default 32).

    The ``nn.Sequential`` layer order, and therefore ``state_dict`` keys such
    as ``model.0.weight``, depends only on ``arch``. Models built with the same
    ``arch`` can therefore have their weights averaged key by key.

    Raises:
        ValueError: if ``arch`` is empty, or if unpadded convolutions shrink the
            feature map to zero or fewer pixels.
    """

    def __init__(self, arch=(), padding=True, in_channels=3, num_classes=10, input_size=32):
        super().__init__()
        if not arch:
            raise ValueError("arch must contain at least one (kernel_size, out_channels) pair")
        pad = 'same' if padding else 0

        layers = []
        size = input_size  # spatial size of the feature map after each conv
        channels_in = in_channels
        for idx, (kernel, channels_out) in enumerate(arch):
            if idx > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Conv2d(channels_in, channels_out, kernel, padding=pad))
            if not padding:
                size -= kernel - 1
            channels_in = channels_out

        # Catch impossible architectures here instead of at the first forward pass.
        if size <= 0:
            raise ValueError(
                f"unpadded convolutions shrink a {input_size}x{input_size} input to "
                f"{size}x{size}; use smaller kernels, fewer layers, or padding=True"
            )
        # Pool over the entire remaining map (global average pooling).
        layers.append(nn.AvgPool2d(kernel_size=(int(size), int(size))))
        layers.append(nn.Flatten())
        layers.append(nn.Linear(channels_in, num_classes))
        # Plain list kept for inspection; the modules are registered (and hence
        # trained and saved) through `self.model`.
        self.layers = layers
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        return self.model(x)

In [7]:
def train(arch=(), lr=1e-3, epochs=10, batch_size=64, momentum=0.9, padding=True):
    """Train a freshly initialised CNN as a single simulated user would.

    A random subset of ``int(0.1 * len(dataset))`` examples is drawn from the
    global ``train_set`` (a new subset on every call), mimicking one user's
    small personal dataset. The full global ``valid_set`` is evaluated after
    every epoch. Train/valid accuracy is shown on the progress bar.

    Args:
        arch: ``(kernel_size, out_channels)`` pairs forwarded to ``CNN``.
        lr: SGD learning rate.
        epochs: number of passes over the user's subset.
        batch_size: mini-batch size for both loaders.
        momentum: SGD momentum.
        padding: forwarded to ``CNN``.

    Returns:
        ``(network, valid_accs)``: the trained model (on ``device``) and a list
        with one 0-d tensor per epoch holding the validation accuracy
        (correct / total; NaN if ``valid_set`` is empty).
    """
    user_data_size = int(0.1 * len(dataset))
    train_cut, _ = torch.utils.data.random_split(
        train_set, [user_data_size, len(train_set) - user_data_size])

    train_loader = torch.utils.data.DataLoader(train_cut, shuffle=True, batch_size=batch_size)
    valid_loader = torch.utils.data.DataLoader(valid_set, shuffle=False, batch_size=batch_size)

    # Batches are moved to `device`, so the model must live there too.
    network = CNN(arch=arch, padding=padding).to(device)
    opt = optim.SGD(network.parameters(), lr=lr, momentum=momentum)
    loss_fn = nn.CrossEntropyLoss()

    valid_accs = []
    progress = tqdm(range(epochs))
    for _ in progress:
        network.train()
        train_correct, train_seen = 0, 0
        for batch_xs, batch_ys in train_loader:
            batch_xs = batch_xs.to(device)
            batch_ys = batch_ys.to(device)

            preds = network(normalize(batch_xs))
            loss_val = loss_fn(preds, batch_ys)

            opt.zero_grad()
            loss_val.backward()
            opt.step()

            train_correct += (preds.argmax(dim=1) == batch_ys).sum().item()
            train_seen += batch_ys.numel()

        network.eval()
        valid_correct, valid_seen = 0, 0
        # Evaluation needs no gradients; skipping the autograd graph saves memory.
        with torch.no_grad():
            for batch_xs, batch_ys in valid_loader:
                batch_xs = batch_xs.to(device)
                batch_ys = batch_ys.to(device)
                preds = network(normalize(batch_xs))
                valid_correct += (preds.argmax(dim=1) == batch_ys).sum().item()
                valid_seen += batch_ys.numel()

        # Count-based accuracy weights every example equally. Averaging
        # per-batch means would over-weight a smaller final batch.
        train_acc = train_correct / train_seen if train_seen else float('nan')
        valid_acc = valid_correct / valid_seen if valid_seen else float('nan')
        valid_accs.append(torch.tensor(valid_acc))
        progress.set_postfix(train_acc=f"{train_acc:.3f}", valid_acc=f"{valid_acc:.3f}")
    return (network, valid_accs)

In [8]:
def iterate_model(network, lr=1e-3, epochs=25, batch_size=64, momentum=0.9, padding=True, model_count=10):
    """Continue training an existing ``network`` in place, as one user's local update.

    A fresh random subset of ``int(len(dataset) / model_count)`` examples
    (capped at ``len(train_set)``) is drawn from the global ``train_set``. The
    network is trained on it with SGD and evaluated on the global
    ``valid_set`` after every epoch. Train/valid accuracy is shown on the
    progress bar.

    Args:
        network: model to train in place; it is moved to ``device`` first.
        lr: SGD learning rate.
        epochs: number of passes over the user's subset.
        batch_size: mini-batch size for both loaders.
        momentum: SGD momentum.
        padding: unused (the architecture is fixed by ``network``); kept so
            the signature mirrors ``train``.
        model_count: number of simulated users; ``len(dataset)`` is divided by
            it to size each user's share.

    Returns:
        List with one 0-d tensor per epoch holding the validation accuracy
        (correct / total; NaN if ``valid_set`` is empty).
    """
    # Cap the share so random_split never receives a negative remainder
    # (e.g. model_count=1 would ask for more than len(train_set) examples).
    user_data_size = min(int(len(dataset) / model_count), len(train_set))
    train_cut, _ = torch.utils.data.random_split(
        train_set, [user_data_size, len(train_set) - user_data_size])

    train_loader = torch.utils.data.DataLoader(train_cut, shuffle=True, batch_size=batch_size)
    valid_loader = torch.utils.data.DataLoader(valid_set, shuffle=False, batch_size=batch_size)

    # Batches are moved to `device`, so make sure the model is there as well
    # (before building the optimizer over its parameters).
    network.to(device)
    opt = optim.SGD(network.parameters(), lr=lr, momentum=momentum)
    loss_fn = nn.CrossEntropyLoss()

    valid_accs = []
    progress = tqdm(range(epochs))
    for _ in progress:
        network.train()
        train_correct, train_seen = 0, 0
        for batch_xs, batch_ys in train_loader:
            batch_xs = batch_xs.to(device)
            batch_ys = batch_ys.to(device)

            preds = network(normalize(batch_xs))
            loss_val = loss_fn(preds, batch_ys)

            opt.zero_grad()
            loss_val.backward()
            opt.step()

            train_correct += (preds.argmax(dim=1) == batch_ys).sum().item()
            train_seen += batch_ys.numel()

        network.eval()
        valid_correct, valid_seen = 0, 0
        # Evaluation needs no gradients; skipping the autograd graph saves memory.
        with torch.no_grad():
            for batch_xs, batch_ys in valid_loader:
                batch_xs = batch_xs.to(device)
                batch_ys = batch_ys.to(device)
                preds = network(normalize(batch_xs))
                valid_correct += (preds.argmax(dim=1) == batch_ys).sum().item()
                valid_seen += batch_ys.numel()

        # Count-based accuracy weights every example equally. Averaging
        # per-batch means would over-weight a smaller final batch.
        train_acc = train_correct / train_seen if train_seen else float('nan')
        valid_acc = valid_correct / valid_seen if valid_seen else float('nan')
        valid_accs.append(torch.tensor(valid_acc))
        progress.set_postfix(train_acc=f"{train_acc:.3f}", valid_acc=f"{valid_acc:.3f}")
    return valid_accs

In [9]:
def aggregate(models, arch, trained_models=None):
    """Overwrite every model in ``models`` with the element-wise mean of the trained weights.

    The unweighted mean of the ``state_dict`` of each model in
    ``trained_models`` (or of ``models`` itself when ``trained_models`` is
    omitted or empty) is loaded into every model of ``models`` in place.

    Args:
        models: modules that receive the averaged weights.
        arch: unused; kept so existing ``aggregate(models, arch)`` calls still
            work. The weights are averaged directly, so no template model is
            built.
        trained_models: modules whose weights are averaged. Defaults to
            ``models``. All of them must share ``state_dict`` keys and shapes.

    Returns:
        The averaged state dict (callers may ignore it).

    Raises:
        ValueError: if there is no model to average.
    """
    models = list(models)
    trained = list(trained_models) if trained_models else models
    if not trained:
        raise ValueError("aggregate() needs at least one model to average")

    states = [m.state_dict() for m in trained]
    averaged = {}
    for key, ref in states[0].items():
        stacked = torch.stack([state[key] for state in states])
        if torch.is_floating_point(ref):
            averaged[key] = stacked.mean(dim=0)
        else:
            # Integer/bool buffers cannot be averaged in their own dtype; average
            # in float64 and round back.
            averaged[key] = stacked.double().mean(dim=0).round().to(ref.dtype)

    # Writing into the dict returned by state_dict() does not touch the module's
    # parameters; load_state_dict copies the values into each model.
    for m in models:
        m.load_state_dict(averaged)
    return averaged

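## Ideas for decentralized aggregation
* **Layered masking aggregation**
    * Each user generates a random "masked model" whose parameters lie within the expected range of real model parameters.
    * Sum every user's masked model to obtain the aggregate masked model (AMM).
    * Each user computes its mask as mask = masked model − true model.
    * Pass the AMM around, with each user subtracting its own mask. Because masked model − mask = true model for every user, what remains once all masks are removed is the sum of the true models.
    * As long as no user can see both another user's masked model and that user's mask, nobody learns any individual user's true model.
    * Weakness 1: user i's masked model can be recovered if users i−1 and i+1 are compromised. In the ring (user k adds its masked model to the running sum and passes it to user k+1), the difference between the partial sums those two neighbours see is exactly user i's masked model.
    * Weakness 2: by the same argument, user i's mask can be recovered if users i−1 and i+1 are compromised.
    * Together, these let two colluding neighbours reconstruct user i's true model. No single malicious user can break privacy, but the scheme is still fairly weak.
* **Just use MPC (secure multi-party computation)**
    * For a large network or many parties, this becomes very expensive.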

In [10]:
# Federated setup: `model_count` simulated users, each with its own CNN of the
# same architecture. With unpadded convs, 32x32 inputs shrink to
# 32 - 6 - 2 - 2 - 2 = 20x20 before the global average pool.
model_count = 20
arch = [(7, 8), (3, 16), (3, 32), (3, 64)]
models = [CNN(arch=arch, padding=False) for _ in range(model_count)]

# Weight averaging assumes every client starts from the same global
# initialisation. The hidden units of independently initialised networks are
# not aligned, so averaging them tends to destroy what each one learned.
# Copy the first model's initial weights into all the others.
shared_init = models[0].state_dict()
for m in models[1:]:
    m.load_state_dict(shared_init)
for m in models:
    m.to(device)
accs = []

In [11]:
def eval_model(network, loader, eval_device=None, preprocess=None):
    """Return the classification accuracy of ``network`` over every example in ``loader``.

    Args:
        network: model producing logits with classes along dim 1. It must
            already live on ``eval_device``; it is switched to eval mode.
        loader: iterable of ``(inputs, labels)`` batches.
        eval_device: device the batches are moved to; defaults to the
            notebook-level ``device``.
        preprocess: callable applied to each input batch; defaults to the
            notebook-level ``normalize``.

    Returns:
        0-d float tensor equal to correct predictions / total examples (NaN if
        ``loader`` yields no examples). Counting examples, rather than averaging
        per-batch means, keeps a smaller final batch from skewing the result.
    """
    if eval_device is None:
        eval_device = device
    if preprocess is None:
        preprocess = normalize

    network.eval()
    correct = 0
    total = 0
    # Evaluation needs no gradients; skipping the autograd graph saves memory.
    with torch.no_grad():
        for batch_xs, batch_ys in loader:
            batch_xs = batch_xs.to(eval_device)
            batch_ys = batch_ys.to(eval_device)
            preds = network(preprocess(batch_xs))
            correct += (preds.argmax(dim=1) == batch_ys).sum().item()
            total += batch_ys.numel()
    return torch.tensor(correct / total if total else float('nan'))

In [12]:
# Federated training. Each round, every simulated user trains its copy on a
# freshly sampled slice of train_set (see iterate_model). All copies are then
# averaged and the shared model is scored on the validation split. The learning
# rate decays linearly with the round index.
rounds = 5
valid_loader = torch.utils.data.DataLoader(valid_set, shuffle=False, batch_size=64)
for r in range(rounds):
    print("Round", r + 1)
    round_lr = 2e-3 * (rounds - r)
    round_accs = []
    for i, m in enumerate(models, start=1):
        print("Training model", i)
        acc = iterate_model(m, lr=round_lr, model_count=model_count)
        final_acc = acc[-1]
        round_accs.append(final_acc)
        print(final_acc)
    print("Avg. Accuracy:", np.mean(round_accs))
    print("Aggregating")
    aggregate(models, arch)
    print("Aggregate Evaluation Accuracy:", eval_model(models[0], valid_loader))

Round 1
Training model 1


  0%|          | 0/25 [00:00<?, ?it/s]

tensor(0.3481)
Training model 2


  0%|          | 0/25 [00:00<?, ?it/s]

tensor(0.3374)
Training model 3


  0%|          | 0/25 [00:00<?, ?it/s]

tensor(0.3184)
Training model 4


  0%|          | 0/25 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [64]:
# FIXME(review): `a` is not defined in any cell of this notebook, so this cell
# raises NameError on a fresh kernel (Restart & Run All). The printed value came
# from leftover kernel state. Replace `a` with the intended variable or delete
# this cell.
print(np.mean(a))

0.19390824


In [82]:
# Final held-out evaluation of the shared (aggregated) model on the test split.
# eval_model moves each batch to `device` (where the models live), puts the
# model in eval mode and disables autograd. `agg` is local to aggregate() and
# does not exist at notebook level, so the model is referenced via models[0].
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=64)

acc = eval_model(models[0], test_loader)
print(acc)

tensor(0.4790)
