In [1]:
import torch
from torch import nn, optim
import torchvision
torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as transforms
import numpy as np
import torch.utils.tensorboard as tb
import datetime
import os
from tqdm.notebook import tqdm

In [2]:
# Pick the fastest available backend. Every branch produces a torch.device, so later
# cells can rely on a single type (previously the CPU fallback was the plain string 'cpu').
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

cuda


In [3]:
# Per-sample preprocessing: decode to an image tensor, then convert it to a float dtype.
# Normalization is applied separately (see `normalize`) at forward time.
transform = transforms.Compose(
    [transforms.ToImage(), transforms.ConvertImageDtype()]
)

# CIFAR-10 images (no `train=` argument is passed, so presumably the training split);
# this notebook carves its own train/valid/test split out of them.
dataset = torchvision.datasets.CIFAR10(
    './data/torch/cifar',
    transform=transform,
    download=True,
)

Files already downloaded and verified


In [4]:
# 70% / 20% / remainder (~10%) split of `dataset` into train / valid / test subsets.
# A seeded generator keeps the split identical across kernel restarts, so results from
# different runs are evaluated on the same validation and test images.
SPLIT_SEED = 0
train_size = int(0.7 * len(dataset))
valid_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_set, valid_set, test_set = torch.utils.data.random_split(
    dataset,
    [train_size, valid_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [5]:
# Per-channel mean / std (3 channels) used to standardize CIFAR-10 inputs.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)

# Applied explicitly to every batch before the forward pass (it is not part of `transform`).
normalize = transforms.Normalize(mean=cifar_mean, std=cifar_std)

In [6]:
class CNN(nn.Module):
    """Stack of conv layers followed by global average pooling and a 10-way linear head.

    Args:
        arch: non-empty sequence of ``(kernel_size, out_channels)`` pairs, one per conv
            layer. The first conv reads 3 input channels.
        padding: if True every conv uses ``padding='same'`` (spatial size preserved);
            if False convs are unpadded and each one shrinks the feature map by
            ``kernel_size - 1``.
        input_size: side length of the square input images (32 for CIFAR-10). It sizes
            the pooling window so the pool covers the whole remaining feature map.

    Raises:
        ValueError: if ``arch`` is empty, or unpadded convs shrink the map below 1x1.

    The Sequential layout (conv, ReLU, conv, ..., conv, AvgPool2d, Flatten, Linear) is
    kept stable because state dicts are averaged key-by-key across models.
    """

    def __init__(self, arch=None, padding=True, input_size=32):
        super().__init__()
        if not arch:
            raise ValueError("arch must contain at least one (kernel_size, out_channels) pair")
        pad = 'same' if padding else 0
        size = input_size
        in_channels = 3
        layers = []
        for idx, spec in enumerate(arch):
            kernel, out_channels = spec[0], spec[1]
            if idx > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Conv2d(in_channels, out_channels, kernel, padding=pad))
            if not padding:
                # Unpadded stride-1 conv: output side = input side - (kernel - 1).
                size -= kernel - 1
            in_channels = out_channels
        if size < 1:
            raise ValueError(
                f"unpadded convs shrink a {input_size}x{input_size} input to {size}x{size}"
            )
        # NOTE(review): no activation follows the last conv, so last conv + average pool
        # + linear form a purely linear map. Kept as-is to preserve the architecture.
        layers.append(nn.AvgPool2d(kernel_size=(int(size), int(size))))
        layers.append(nn.Flatten())
        layers.append(nn.Linear(arch[-1][1], 10))
        # Plain-list view of the layers; the registered submodules live in `self.model`.
        self.layers = layers
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)

In [7]:
def train(arch=[], lr=1e-3, epochs=10, batch_size=64, momentum=0.9, padding=True):
    """Train a fresh CNN on one user's small random slice of ``train_set``.

    The slice has ``int(0.1 * len(dataset))`` images. Validation uses all of ``valid_set``.

    Args:
        arch: conv architecture forwarded to ``CNN`` (``(kernel_size, out_channels)`` pairs).
        lr, momentum: SGD hyper-parameters.
        epochs: number of passes over the user's slice.
        batch_size: mini-batch size for both loaders.
        padding: forwarded to ``CNN``.

    Returns:
        ``(network, valid_accs)``, where ``valid_accs`` holds one mean validation
        accuracy (0-d tensor) per epoch.
    """
    # Each user's model is trained on a much smaller, personal dataset.
    user_data_size = int(0.1 * len(dataset))
    train_cut, _ = torch.utils.data.random_split(train_set, [user_data_size, len(train_set) - user_data_size])

    train_loader = torch.utils.data.DataLoader(train_cut, shuffle=True, batch_size=batch_size)
    valid_loader = torch.utils.data.DataLoader(valid_set, shuffle=False, batch_size=batch_size)

    # The batches are moved to `device`, so the model must live there too.
    network = CNN(arch=arch, padding=padding).to(device)
    opt = optim.SGD(network.parameters(), lr=lr, momentum=momentum)
    loss = nn.CrossEntropyLoss()

    train_accs = []
    valid_accs = []

    for _ in tqdm(range(epochs)):
        network.train()
        batch_train_accs = []
        for batch_xs, batch_ys in train_loader:
            batch_xs = batch_xs.to(device)
            batch_ys = batch_ys.to(device)

            preds = network(normalize(batch_xs))
            loss_val = loss(preds, batch_ys)

            opt.zero_grad()
            loss_val.backward()
            opt.step()

            # Record the batch accuracy (it was previously computed and dropped, which
            # made every entry of `train_accs` NaN).
            batch_train_accs.append((preds.argmax(dim=1) == batch_ys).float().mean().item())
        train_accs.append(torch.tensor(batch_train_accs).mean())

        network.eval()
        batch_valid_accs = []
        with torch.no_grad():  # evaluation only: no autograd graph needed
            for batch_xs, batch_ys in valid_loader:
                batch_xs = batch_xs.to(device)
                batch_ys = batch_ys.to(device)
                preds = network(normalize(batch_xs))
                batch_valid_accs.append((preds.argmax(dim=1) == batch_ys).float().mean().item())
        valid_accs.append(torch.tensor(batch_valid_accs).mean())
    return (network, valid_accs)

In [8]:
def iterate_model(network, lr=1e-3, epochs=25, batch_size=64, momentum=0.9, padding=True, model_count=10):
    """Locally train an existing network for one federated round, in place.

    Each call draws a fresh random slice of ``len(dataset) / model_count`` images from
    ``train_set``, capped at ``len(train_set)``. This simulates one user's private data.

    Args:
        network: model to train; it is moved to ``device`` first.
        lr, momentum: SGD hyper-parameters. A new optimizer is created on every call.
        epochs: passes over the user's slice.
        batch_size: mini-batch size for both loaders.
        padding: unused; kept so existing calls keep working.
        model_count: number of simulated users; it determines the slice size.

    Returns:
        list of per-epoch mean validation accuracies (0-d tensors).
    """
    # Each user's model is trained on a much smaller, personal dataset. Cap the slice so
    # random_split cannot be asked for more samples than train_set holds.
    user_data_size = min(int(len(dataset) / model_count), len(train_set))
    train_cut, _ = torch.utils.data.random_split(train_set, [user_data_size, len(train_set) - user_data_size])

    train_loader = torch.utils.data.DataLoader(train_cut, shuffle=True, batch_size=batch_size)
    valid_loader = torch.utils.data.DataLoader(valid_set, shuffle=False, batch_size=batch_size)

    network.to(device)  # batches are sent to `device`; keep the model alongside them
    opt = optim.SGD(network.parameters(), lr=lr, momentum=momentum)
    loss = nn.CrossEntropyLoss()

    train_accs = []
    valid_accs = []

    for _ in tqdm(range(epochs)):
        network.train()
        batch_train_accs = []
        for batch_xs, batch_ys in train_loader:
            batch_xs = batch_xs.to(device)
            batch_ys = batch_ys.to(device)

            preds = network(normalize(batch_xs))
            loss_val = loss(preds, batch_ys)

            opt.zero_grad()
            loss_val.backward()
            opt.step()

            # Previously computed and discarded, leaving `train_accs` all NaN.
            batch_train_accs.append((preds.argmax(dim=1) == batch_ys).float().mean().item())
        train_accs.append(torch.tensor(batch_train_accs).mean())

        network.eval()
        batch_valid_accs = []
        with torch.no_grad():  # evaluation only: no autograd graph needed
            for batch_xs, batch_ys in valid_loader:
                batch_xs = batch_xs.to(device)
                batch_ys = batch_ys.to(device)
                preds = network(normalize(batch_xs))
                batch_valid_accs.append((preds.argmax(dim=1) == batch_ys).float().mean().item())
        valid_accs.append(torch.tensor(batch_valid_accs).mean())
    return valid_accs

In [9]:
def models_diff(n1, n2, loader):
    """Mean over batches of the L2 norm of the difference between two networks' logits.

    Both networks are put in eval mode and fed the same normalized batches. The labels
    yielded by ``loader`` are ignored. Returns a 0-d tensor.
    """
    n1.eval()
    n2.eval()
    results = []
    with torch.no_grad():  # pure comparison: no gradients required
        for batch_xs, _ in loader:
            inputs = normalize(batch_xs.to(device))  # normalize once, feed both models
            results.append(torch.norm(n1(inputs) - n2(inputs)).item())
    return torch.tensor(results).mean()

In [10]:
def override_models(networks, target_dict):
    """Load ``target_dict`` into every network in ``networks``.

    Args:
        networks: iterable of modules exposing ``load_state_dict``.
        target_dict: state dict to load into each of them.
    """
    # Iterate the argument; the old body silently used the notebook-global `models`.
    for network in networks:
        network.load_state_dict(target_dict)

In [11]:
def aggregate(models, arch, trained_models=None):
    """Federated averaging: overwrite every model in ``models`` with the mean weights.

    The mean is taken element-wise over the state dicts of ``trained_models``.

    Args:
        models: all participant networks; each one receives the averaged weights.
        arch: unused. The averaged state dict is keyed off the trained models
            themselves, so no template CNN is built. Kept for call compatibility.
        trained_models: networks whose weights are averaged. ``None`` or an empty list
            means "average all of ``models``".
    """
    if not trained_models:
        trained_models = models
    if not trained_models:
        return  # nothing to average

    state_dicts = [m.state_dict() for m in trained_models]
    averaged = {}
    for key in state_dicts[0]:
        # torch.stack allocates a new tensor, so the average never aliases a model's storage.
        stacked = torch.stack([sd[key] for sd in state_dicts])
        if not stacked.is_floating_point():
            stacked = stacked.float()  # mean() needs a floating dtype
        averaged[key] = stacked.mean(dim=0)

    for m in models:
        m.load_state_dict(averaged)

Ideas for Decentralized Aggregation:
* Layered Masking Aggregation
    * Each user generates a random "masked model" whose parameters fall within the range expected of a real model
    * Add together every masked model to generate aggregate masked model (AMM)
    * Each user calculates mask = masked model - true model
    * Pass the AMM around the ring; each user subtracts their own mask, which leaves the sum of the true models
    * So long as a user can't see both the masked model and the mask, they can't know a user's true model
    * Can calculate user i's masked model if user i-1 and i+1 are compromised (where user k adds their masked model and passes it to user k+1)
    * Can calculate user i's mask if user i-1 and i+1 are compromised
    * No single malicious user can recover another user's model, but two colluding neighbors can, so the scheme is still weak.
* Just Use MPC
    * For a large network or many parties, this gets really expensive

In [18]:
# Simulated federation: `model_count` users, each holding its own copy of the same CNN.
model_count = 10
arch = [(7, 8), (3, 16), (3, 32), (3, 64)]
models = [CNN(arch=arch, padding=True) for _ in range(model_count)]

# Give every user identical starting weights by copying model 0's state into all models.
d = models[0].state_dict()
for m in models:
    m.load_state_dict(d)
for m in models:
    m.to(device)

accs = []
iterate_prob = 1  # probability that a given user trains in a round (1 = every user)

In [19]:
def eval_model(network, loader):
    """Return the classification accuracy of ``network`` over every sample in ``loader``.

    Accuracy is total correct / total samples. A short final batch is therefore weighted
    by its size instead of counting as much as a full batch. The result is a 0-d float
    tensor, or NaN if the loader yields no samples.
    """
    network.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for batch_xs, batch_ys in loader:
            batch_xs = batch_xs.to(device)
            batch_ys = batch_ys.to(device)
            preds = network(normalize(batch_xs))
            correct += (preds.argmax(dim=1) == batch_ys).sum().item()
            total += batch_ys.numel()
    if total == 0:
        return torch.tensor(float('nan'))
    return torch.tensor(correct / total)

In [20]:
# Federated-training loop. Each round, a random subset of users (probability
# `iterate_prob`) trains locally. Optionally, the trained models are then averaged
# into every model.
AGGREGATE_EACH_ROUND = False  # False reproduces the "control" runs; True enables averaging
rng = np.random.default_rng()
rounds = 5  # Number of iterations to run alternating between training and aggregating
valid_loader = torch.utils.data.DataLoader(valid_set, shuffle=False, batch_size=64)

for r in range(rounds):
    print("Round", r + 1)
    round_accs = []
    iterated = []  # indices into `models` of the users that trained this round
    for idx, m in enumerate(models):
        if rng.random() > iterate_prob:  # Train using only a fraction of the devices available
            continue
        iterated.append(idx)
        print("Training model", idx + 1)
        # Linear learning-rate decay: 1e-2 in the first round down to 2e-3 in the last.
        iterate_model(m, lr=2e-3 * (rounds - r), model_count=model_count, epochs=10)
        round_accs.append(eval_model(m, valid_loader))
        print(round_accs[-1])
    if round_accs:
        print("Avg. Accuracy:", np.mean(round_accs))
    if AGGREGATE_EACH_ROUND and iterated:
        print("Aggregating")
        aggregate(models, arch, trained_models=[models[i] for i in iterated])
        print("Aggregate Evaluation Accuracy:", eval_model(models[0], valid_loader))

Round 1
Training model 1


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3085)
Training model 2


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3205)
Training model 3


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3111)
Training model 4


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.2954)
Training model 5


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.2949)
Training model 6


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3082)
Training model 7


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3027)
Training model 8


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3096)
Training model 9


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3042)
Training model 10


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3104)
Avg. Accuracy: 0.30655852
Round 2
Training model 1


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3677)
Training model 2


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3494)
Training model 3


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3576)
Training model 4


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3475)
Training model 5


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3384)
Training model 6


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3223)
Training model 7


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3527)
Training model 8


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3452)
Training model 9


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3642)
Training model 10


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3447)
Avg. Accuracy: 0.34897494
Round 3
Training model 1


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3809)
Training model 2


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3538)
Training model 3


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3658)
Training model 4


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3713)
Training model 5


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3388)
Training model 6


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3697)
Training model 7


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3673)
Training model 8


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3681)
Training model 9


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3529)
Training model 10


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3618)
Avg. Accuracy: 0.36304733
Round 4
Training model 1


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3905)
Training model 2


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3961)
Training model 3


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3815)
Training model 4


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3964)
Training model 5


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4175)
Training model 6


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4068)
Training model 7


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4108)
Training model 8


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4050)
Training model 9


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3739)
Training model 10


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.3769)
Avg. Accuracy: 0.3955414
Round 5
Training model 1


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4283)
Training model 2


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4194)
Training model 3


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4193)
Training model 4


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4010)
Training model 5


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4446)
Training model 6


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4223)
Training model 7


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4217)
Training model 8


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4276)
Training model 9


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4236)
Training model 10


  0%|          | 0/10 [00:00<?, ?it/s]

tensor(0.4033)
Avg. Accuracy: 0.4210987


In [64]:
# NOTE(review): `a` is not defined in any visible cell of this notebook. It presumably
# leaked from a deleted or out-of-order cell (note the execution count In [64]), so this
# line raises NameError on Restart & Run All. TODO: define `a` explicitly or drop this cell.
print(np.mean(a))

0.19390824


In [82]:
# Final evaluation of the shared model (models[0]) on the held-out test split.
# eval_model puts the network in eval mode and moves each batch to `device`. The
# previous inline loop called `agg.eval()`, where `agg` is only a local inside
# `aggregate`, and fed CPU batches to a model living on `device`.
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=64)
acc = eval_model(models[0], test_loader)
print(acc)

tensor(0.4790)


I present three accuracy values for each experiment. The control group is the average validation accuracy of 10 models trained for 50 epochs without any aggregation. Training is broken up into 5 iterations of 10 epochs, and in each iteration every model gets a fresh random cut of 1/10th of the dataset, drawn from the training split. The constituent accuracies are the accuracies of the local models after each training iteration, but before aggregation occurs. The aggregate accuracies are the accuracies of the aggregated model after each iteration.

[(7, 8), (3, 16), (3, 32), (3, 64)], no padding:
* Avg. Control Accuracies: 0.3174, 0.3848, 0.4233, 0.4589, 0.4809
* Avg. Constituent Accuracies: 0.3216, 0.3769, 0.4093, 0.4464, 0.4697
* Aggregate Accuracies: 0.3186, 0.4010, 0.4367, 0.4656, 0.4804

| Iteration | Avg. Control Accuracy | Avg. Constituent Accuracy | Aggregate Accuracy |
| --------  | --------------------- | ------------------------- | ------------------ |
| 1         | 0.3174                | 0.3216                    | 0.3186             |
| 2         | 0.3848                | 0.3769                    | 0.4010             |
| 3         | 0.4233                | 0.4093                    | 0.4367             |
| 4         | 0.4589                | 0.4464                    | 0.4656             |
| 5         | 0.4809                | 0.4697                    | 0.4804             |

[(7, 8), (3, 16), (3, 32), (3, 64)], padding:
* Avg. Control Accuracies: 0.3066, 0.3490, 0.3630, 0.3955, 0.4211
* Avg. Constituent Accuracies: 0.3056, 0.3379, 0.3669, 0.3955, 0.4310
* Aggregate Accuracies: 0.3266, 0.3565, 0.3852, 0.4232, 0.4475

| Iteration | Avg. Control Accuracy | Avg. Constituent Accuracy | Aggregate Accuracy |
| --------  | --------------------- | ------------------------- | ------------------ |
| 1         | 0.3066                | 0.3056                    | 0.3266             |
| 2         | 0.3490                | 0.3379                    | 0.3565             |
| 3         | 0.3630                | 0.3669                    | 0.3852             |
| 4         | 0.3955                | 0.3955                    | 0.4232             |
| 5         | 0.4211                | 0.4310                    | 0.4475             |

With this basic aggregation technique, the aggregate models successfully integrate the knowledge of the constituent models. The validation accuracy of the aggregate model is consistently about 1–3 percentage points above the average accuracy of the constituent models. The one exception is the first no-padding iteration, where it is 0.3 points below. Interestingly, the average constituent accuracy after local training is rarely much better than the aggregate accuracy those models started from, yet the aggregate accuracy keeps improving. The constituent accuracy is roughly even with the control accuracy. In these runs the aggregate matches or beats the control in every iteration (the final no-padding iteration is within 0.05 points). This suggests that federated averaging can improve accuracy compared to training only locally, although the runs were unseeded, so small differences should be treated with caution. This is also before considering unequal splits of training data.