In [1]:
import copy
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.autograd import Variable
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
from tqdm import tqdm

from FL.agents import *
from FL.models import *
from FL.util import *

In [2]:
# Training hyperparameters
n_workers, batch_size = 5, 128   # workers per training iteration, samples per mini-batch
n_epochs = 1_000                 # iterations of the main training loop
learning_rate = 1e-3             # step size handed to the optimizer

# Switches
noniid, load_model = False, False  # per-class data split / load a saved checkpoint

## Setup

In [3]:
# Convert images to tensors, then normalize each of the three channels
# with mean 0.5 and std 0.5.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Import Datasets: the CIFAR10 train split first, then the test split.
trainset, testset = (
    torchvision.datasets.CIFAR10(
        root='../data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Batch Loaders
if noniid:
    def noniid_batch_trainset(trainset, c):
        """Return a copy of `trainset` restricted to the samples labelled `c`.

        Args:
            trainset: dataset exposing `.data` (indexable by a boolean mask)
                and `.targets` (one label per sample).
            c: the class label to keep.

        Returns:
            A deep copy of `trainset` whose `.data` holds only the samples of
            class `c` and whose `.targets` holds one `c` per kept sample.
            The input dataset is left untouched.
        """
        indices = (np.array(trainset.targets) == c)
        trainset2 = copy.deepcopy(trainset)
        trainset2.data = trainset2.data[indices]
        # One label per *selected* sample. `len(indices)` is the size of the
        # whole dataset, which would leave `targets` longer than `data`.
        trainset2.targets = [c] * int(indices.sum())

        return trainset2

    # sorted() gives a deterministic class order: trainsets[k] is the k-th
    # smallest label present in the training set.
    trainsets = [noniid_batch_trainset(trainset, c)
                 for c in sorted(set(trainset.targets))]
else:
    # IID setting: a single training set shared by all workers.
    trainsets = [trainset]

# One with-replacement random sampler per training set.
samplers = [
    torch.utils.data.RandomSampler(subset, replacement=True)
    for subset in trainsets
]

# One loader per training set; ordering comes from the paired sampler,
# so `shuffle` stays off.
trainloaders = [
    torch.utils.data.DataLoader(
        subset,
        batch_size=batch_size,
        shuffle=False,
        sampler=subset_sampler,
        num_workers=0,
    )
    for subset, subset_sampler in zip(trainsets, samplers)
]

# The test set is read sequentially.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Aggregator rule: mean
def rule(ups_list):
    """Average a list of per-worker update lists, position by position.

    Args:
        ups_list: a list of lists of tensors. Tensors at the same position
            must be stackable (same shape).

    Returns:
        A list with as many tensors as `ups_list[0]`; entry `i` is the mean
        over all entries of `ups_list` of their `i`-th tensor.
    """
    n_tensors = len(ups_list[0])
    averaged = []
    for idx in range(n_tensors):
        stacked = torch.stack([worker_ups[idx] for worker_ups in ups_list])
        averaged.append(stacked.mean(0))
    return averaged

In [5]:
# Setup Learning Model
# Pick the device first so a saved checkpoint can be mapped onto it on load.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cpu_device = torch.device("cpu")

model = PerformantNet1()
if load_model:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host
    # (and vice versa) instead of failing while deserializing its tensors.
    model.load_state_dict(
        torch.load("PerformantNet1_10epochs.pt", map_location=device))
    # A loaded model is not trained further: the training loop runs 0 iterations.
    n_epochs = 0

print(device)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss = nn.CrossEntropyLoss()

cpu


In [6]:
# Setup Federated Learning Framework
# The central node gets the model and its optimizer; each of the n_workers
# workers gets the loss criterion; the aggregator gets the combining rule.
central = Central(model, optimizer)
worker_list = [Worker(loss) for _ in range(n_workers)]
agg = Agg(rule)

## Main training loop

In [7]:
epochs = []       # iteration indices at which test accuracy was measured
accuracies = []   # test accuracies, aligned with `epochs`

# Training Loop
# Each iteration: every worker processes one mini-batch against the current
# central model, the aggregator combines the workers' updates, and the
# central node applies the combined update.
for t in tqdm(range(n_epochs)):
    first_time = time.time()

    weight_ups = []
    losses = []
    central.model.train()

    # Fresh iterators every iteration; each worker pulls one batch from one of them.
    dataiters = [iter(trainloader) for trainloader in trainloaders]

    # Worker Loop
    for i in range(n_workers):
        # Pick the data source for this worker uniformly at random.
        k = np.random.randint(0, len(dataiters))
        dataiter = dataiters[k]

        # Built-in next(): the iterator's `.next()` alias is a Python 2
        # leftover that current PyTorch loader iterators no longer provide.
        batch_inp, batch_outp = next(dataiter)
        batch_inp, batch_outp = batch_inp.to(device), batch_outp.to(device)

        worker_list[i].model = central.model

        # Use a separate name for the batch loss so the module-level `loss`
        # criterion (nn.CrossEntropyLoss) is not overwritten; otherwise
        # re-running the worker setup cell would build workers from a number.
        ups, worker_loss = worker_list[i].fwd_bkwd(batch_inp, batch_outp)
        print(worker_loss)
        weight_ups.append(ups)
        losses.append(worker_loss)

    # Aggregate Worker Gradients
    weight_ups_FIN = agg.rule(weight_ups)

    print('Avg. Loss: {}'.format(np.mean(losses)))

    # Update Central Model
    central.update_model(weight_ups_FIN)

    central.model.eval()

    # The reported time covers this single iteration only.
    if t > 0 and t % 100 == 0:
        print('Epoch: {}, Time to complete: {}'.format(t, time.time() - first_time))

    if t % 250 == 0 and t > 0:
        # NOTE(review): evaluates `model`, not `central.model`; presumably the
        # same object since `model` was passed to Central. Confirm.
        accuracy = print_test_accuracy(model, testloader)
        epochs.append(t)
        accuracies.append(accuracy)

print('Done training')

  0%|          | 0/1000 [00:00<?, ?it/s]

2.305928
2.2995133
2.3038704
2.3002784


  0%|          | 1/1000 [00:11<3:13:43, 11.64s/it]

2.3035467
Avg. Loss: 2.3026273250579834
2.305126
2.3083975
2.3031292
2.3014648


  0%|          | 2/1000 [00:23<3:13:46, 11.65s/it]

2.3026724
Avg. Loss: 2.3041579723358154
2.304352
2.2978926
2.2973106
2.3069046


  0%|          | 3/1000 [00:34<3:13:13, 11.63s/it]

2.309531
Avg. Loss: 2.3031983375549316
2.3027136
2.3110538
2.3057003
2.3036351


  0%|          | 4/1000 [00:46<3:12:12, 11.58s/it]

2.3003535
Avg. Loss: 2.3046913146972656
2.2980118
2.3001523
2.3028452
2.3062322


  0%|          | 5/1000 [00:57<3:12:00, 11.58s/it]

2.3085353
Avg. Loss: 2.3031554222106934
2.298717
2.2983627
2.303865
2.288769


  1%|          | 6/1000 [01:09<3:10:49, 11.52s/it]

2.2911859
Avg. Loss: 2.296180009841919
2.3000922
2.284395
2.2924833
2.2583342


  1%|          | 7/1000 [01:20<3:10:21, 11.50s/it]

2.2626858
Avg. Loss: 2.2795979976654053
2.2947316
2.2157772
2.2785218
2.2123134


  1%|          | 8/1000 [01:32<3:09:57, 11.49s/it]

2.2912982
Avg. Loss: 2.258528232574463
2.2373345
2.2078683
2.243802
2.2581875


  1%|          | 9/1000 [01:43<3:08:57, 11.44s/it]

2.2575805
Avg. Loss: 2.240954637527466
2.2598557
2.2529771
2.2911768
2.2385538


  1%|          | 10/1000 [01:55<3:08:59, 11.45s/it]

2.272146
Avg. Loss: 2.262941837310791
2.2147777
2.2157502
2.2125566
2.213493


  1%|          | 11/1000 [02:06<3:08:47, 11.45s/it]

2.2261937
Avg. Loss: 2.2165541648864746
2.2648177
2.1370757
2.147506


KeyboardInterrupt: 