In [1]:
from resnet import ResNet18
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import numpy as np
import pickle

In [2]:
# Directory that receives the pickled dataset splits and every model checkpoint
# written by the cells below.
savedir = "./saved_models/run7"
# Create it up front: open(..., "wb") and torch.save do not create missing parent
# directories, so the first write would otherwise fail on a fresh checkout.
os.makedirs(savedir, exist_ok=True)

In [3]:
# Per-channel mean / std handed to Normalize in both pipelines.
_channel_mean = (0.4914, 0.4822, 0.4465)
_channel_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip, then tensor conversion and normalisation.
_train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]
transform_train = transforms.Compose(_train_steps)

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
_test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]
transform_test = transforms.Compose(_test_steps)

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)

# Fixed-order loader over the held-out split.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
# Shared classification loss, used by both train() and test().
criterion = torch.nn.CrossEntropyLoss()

In [6]:
def train(net, trainloader, optimizer, epoch):
    """Run one optimisation pass of ``net`` over ``trainloader``.

    Args:
        net: model updated in place; it is switched to training mode first.
        trainloader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer whose ``zero_grad`` / ``step`` are called once per batch.
        epoch: epoch index, used only to decide whether to log.

    Prints the mean per-batch loss on every 10th epoch. Nothing is printed when
    the loader yields no batches, because there is no loss to average.
    """
    net.train()
    train_loss = 0
    num_batches = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        num_batches += 1

    # Count batches explicitly instead of reusing the loop index: the index is
    # unbound after an empty loop, which used to raise on every logged epoch.
    if epoch % 10 == 0 and num_batches:
        print('Epoch%d, Loss: %.3f' % (epoch, train_loss/num_batches))

In [7]:
def test(net):
    """Evaluate ``net`` on the module-level ``testloader``.

    Puts the model in eval mode, accumulates the per-batch loss and the number
    of correct top-1 predictions with gradients disabled, prints a summary
    line, and returns the accuracy as a fraction (correct / total).
    """
    net.eval()
    running_loss, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        # Counting from 1 makes n_batches the number of batches processed so far.
        for n_batches, (inputs, targets) in enumerate(testloader, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net(inputs)

            running_loss += criterion(outputs, targets).item()
            # Index of the largest logit per row is the predicted class.
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

        print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (running_loss/n_batches, 100.*n_correct/n_seen, n_correct, n_seen))
    return 1.0*n_correct/n_seen

In [8]:
def train_learner(dataset, epochs, net=None, lr=0.005, momentum=0.9, batch_size=64):
    """Train a learner on ``dataset`` for ``epochs`` epochs and return it.

    Args:
        dataset: dataset wrapped in a shuffling DataLoader.
        epochs: number of passes over ``dataset``.
        net: model to continue training; a fresh ``ResNet18`` is built when None.
        lr: SGD learning rate (default matches the previously hard-coded value).
        momentum: SGD momentum (default matches the previously hard-coded value).
        batch_size: mini-batch size (default matches the previously hard-coded value).

    Returns:
        The trained model, located on ``device``.

    A new SGD optimizer is created on every call, so momentum buffers are not
    carried over between calls for the same ``net``.
    """
    if net is None:
        net = ResNet18()
    # Move caller-supplied models too: train() sends every batch to `device`,
    # so a model left elsewhere would fail on the first forward pass.
    net = net.to(device)
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    for epoch in range(epochs):
        train(net, trainloader, optimizer, epoch)
    return net

In [9]:
def tuning(learners, local_ds, theta, epochs):
    """Run one self-training round over the shared (global) pool.

    Every learner predicts a class for each sample served by ``global_loader``.
    A sample is kept as a pseudo-labelled example when at least ``theta``
    learners agree on one class. Each learner is then trained further on its
    own slice of ``local_trainset`` plus all kept pseudo-labelled samples.

    Args:
        learners: list of models; updated in place and also returned.
        local_ds: number of local samples per learner (learner ``i`` owns
            indices ``[i*local_ds, (i+1)*local_ds)`` of ``local_trainset``).
        theta: minimum number of agreeing votes needed to keep a sample.
        epochs: training epochs per learner for this round.

    Returns:
        ``(learners, mean_accuracy)`` where ``mean_accuracy`` is the average of
        ``test(net)`` over all learners.

    Reads the module-level ``global_loader``, ``global_trainset``,
    ``local_trainset`` and ``device``.
    """
    # --- 1. Collect hard predictions of every learner on the global pool. ---
    global_predictions = []
    # Inference only: disable autograd so no graph is built for these passes.
    with torch.no_grad():
        for learner in learners:
            # Set eval mode explicitly instead of relying on an earlier test()
            # call; train() switches the model back to training mode later.
            learner.eval()
            learner_preds = []
            for inp, _ in global_loader:
                _, preds = learner(inp.to(device)).max(1)
                learner_preds += preds.cpu().numpy().tolist()
            global_predictions.append(learner_preds)
    global_predictions = np.array(global_predictions)

    # --- 2. Majority vote per sample; keep those with >= theta agreeing votes. ---
    certain_global = []
    correct_count = 0
    for i in range(len(global_trainset)):
        votes = np.bincount(global_predictions[:, i], minlength=10)  # 10 classes
        if votes.max() >= theta:
            # Plain int so pseudo-labels have the same type as dataset labels.
            label = int(votes.argmax())
            # Index the dataset once: in this notebook global_trainset wraps the
            # training set with its random transform, so each access re-runs the
            # augmentation. NOTE(review): the stored image is therefore one fixed
            # random view, not the view the learners voted on - confirm intended.
            image, true_label = global_trainset[i]
            certain_global.append((image, label))
            if label == true_label:
                correct_count += 1
    print("Certain predictions amount", len(certain_global), "with correct in them", correct_count)

    # --- 3. Continue training each learner on local data + pseudo-labelled data. ---
    # Use the list that was actually passed in rather than a global counter.
    num_learners = len(learners)
    acc = 0
    for i in range(num_learners):
        print("learner", i)
        local_dataset = torch.utils.data.Subset(local_trainset, list(range(i*local_ds, (i+1)*local_ds)))
        net = train_learner(local_dataset + certain_global, epochs=epochs, net=learners[i])
        acc += test(net)
        learners[i] = net

    return learners, acc/num_learners

In [10]:
# The first 5000 training samples form the local pool that is later split between learners.
_local_indices = list(range(0, 5000))
local_trainset = torch.utils.data.Subset(trainset, _local_indices)
print("local data", len(local_trainset))

# Samples 5000..49999 form the shared pool that the learners vote on.
_global_indices = list(range(5000, 50000))
global_trainset = torch.utils.data.Subset(trainset, _global_indices)
print("global data", len(global_trainset))

# Unshuffled so that prediction order lines up with global_trainset indices.
global_loader = torch.utils.data.DataLoader(
    global_trainset, batch_size=100, shuffle=False, num_workers=2)

local data 5000
global data 45000


In [11]:
# Persist both splits next to the model checkpoints.
# `with` closes (and therefore flushes) each file deterministically; the bare
# open() calls used before left the handles to be closed by garbage collection.
with open(os.path.join(savedir, "local_trainset"), "wb") as local_file:
    pickle.dump(local_trainset, local_file)
with open(os.path.join(savedir, "global_trainset"), "wb") as global_file:
    pickle.dump(global_trainset, global_file)

In [12]:
# Ensemble size, and the minimum number of agreeing votes needed to keep a pseudo-label.
n_learners, theta = 5, 4
# Every learner receives an equally sized, contiguous slice of the local pool.
local_ds = len(local_trainset) // n_learners
print("Length of the local dataset", local_ds)

Length of the local dataset 1000


In [13]:
# Train one learner per local slice from scratch, evaluate it and checkpoint it.
learners, acc = [], 0
for i in range(n_learners):
    print("learner", i)
    # Learner i owns indices [i*local_ds, (i+1)*local_ds) of the local pool.
    slice_indices = list(range(i * local_ds, (i + 1) * local_ds))
    local_dataset = torch.utils.data.Subset(local_trainset, slice_indices)
    net = train_learner(local_dataset, epochs=100)
    acc += test(net)
    learners.append(net)
    checkpoint_path = os.path.join(savedir, "learner%d" % i)
    torch.save(net, checkpoint_path)
print("Average accuracy of local workers", acc/n_learners)

learner 0
Epoch0, Loss: 2.227
Epoch10, Loss: 1.305
Epoch20, Loss: 0.757
Epoch30, Loss: 0.420
Epoch40, Loss: 0.289
Epoch50, Loss: 0.114
Epoch60, Loss: 0.075
Epoch70, Loss: 0.025
Epoch80, Loss: 0.106
Epoch90, Loss: 0.052
Loss: 3.126 | Acc: 50.500% (5050/10000)
learner 1
Epoch0, Loss: 2.159
Epoch10, Loss: 1.197
Epoch20, Loss: 0.797
Epoch30, Loss: 0.467
Epoch40, Loss: 0.304
Epoch50, Loss: 0.110
Epoch60, Loss: 0.075
Epoch70, Loss: 0.129
Epoch80, Loss: 0.035
Epoch90, Loss: 0.037
Loss: 3.325 | Acc: 48.140% (4814/10000)
learner 2
Epoch0, Loss: 2.227
Epoch10, Loss: 1.253
Epoch20, Loss: 0.754
Epoch30, Loss: 0.331
Epoch40, Loss: 0.237
Epoch50, Loss: 0.132
Epoch60, Loss: 0.121
Epoch70, Loss: 0.052
Epoch80, Loss: 0.007
Epoch90, Loss: 0.006
Loss: 2.945 | Acc: 53.190% (5319/10000)
learner 3
Epoch0, Loss: 2.219
Epoch10, Loss: 1.382
Epoch20, Loss: 0.860
Epoch30, Loss: 0.441
Epoch40, Loss: 0.285
Epoch50, Loss: 0.162
Epoch60, Loss: 0.127
Epoch70, Loss: 0.046
Epoch80, Loss: 0.025
Epoch90, Loss: 0.021
Loss

In [14]:
# Five self-training rounds; after each round every learner is checkpointed
# under a name carrying both the learner index and the round index.
for round_idx in range(5):
    learners, acc = tuning(learners, local_ds, theta, epochs=20)
    print("Average accuracy of local workers", acc)
    for learner_idx, learner in enumerate(learners):
        checkpoint_name = "learner%d_tune%d" % (learner_idx, round_idx)
        torch.save(learner, os.path.join(savedir, checkpoint_name))

Certain predictions amount 18704 with correct in them 15496
learner 0
Epoch0, Loss: 0.414
Epoch10, Loss: 0.040
Loss: 2.881 | Acc: 61.680% (6168/10000)
learner 1
Epoch0, Loss: 0.425
Epoch10, Loss: 0.033
Loss: 2.669 | Acc: 62.790% (6279/10000)
learner 2
Epoch0, Loss: 0.416
Epoch10, Loss: 0.035
Loss: 2.938 | Acc: 61.750% (6175/10000)
learner 3
Epoch0, Loss: 0.404
Epoch10, Loss: 0.045
Loss: 2.683 | Acc: 62.780% (6278/10000)
learner 4
Epoch0, Loss: 0.404
Epoch10, Loss: 0.046
Loss: 2.742 | Acc: 62.840% (6284/10000)
Average accuracy of local workers 0.62368
Certain predictions amount 31870 with correct in them 24163
learner 0
Epoch0, Loss: 0.330
Epoch10, Loss: 0.031
Loss: 3.295 | Acc: 65.030% (6503/10000)
learner 1
Epoch0, Loss: 0.333
Epoch10, Loss: 0.034
Loss: 3.071 | Acc: 65.110% (6511/10000)
learner 2
Epoch0, Loss: 0.330
Epoch10, Loss: 0.035
Loss: 3.218 | Acc: 64.810% (6481/10000)
learner 3
Epoch0, Loss: 0.342
Epoch10, Loss: 0.035
Loss: 3.158 | Acc: 65.090% (6509/10000)
learner 4
Epoch0, L