In [1]:
from resnet import ResNet18
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import numpy as np
import pickle
import math

In [2]:
# Directory that receives the pickled data split and the saved learner models.
savedir = "./saved_models/noniid/run3"
# Create it up front: neither pickle.dump(open(...)) nor torch.save creates
# missing parent directories, so a fresh checkout would fail on the first save.
os.makedirs(savedir, exist_ok=True)

In [3]:
# Normalisation statistics shared by the training and test pipelines.
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2023, 0.1994, 0.2010)
_normalize = transforms.Normalize(_norm_mean, _norm_std)

# Training pipeline: random crop (padding 4) and random horizontal flip,
# followed by tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])

# Test pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose([transforms.ToTensor(), _normalize])

_data_root = './data'

trainset = torchvision.datasets.CIFAR10(
    root=_data_root, train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root=_data_root, train=False, download=True, transform=transform_test)

# Fixed-order loader over the test split, consumed by test().
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
# Loss function shared by train() and test().
criterion = nn.CrossEntropyLoss()

In [6]:
def train(net, trainloader, optimizer, epoch):
    """Run one optimisation pass of ``net`` over ``trainloader``.

    Relies on the module-level ``device`` and ``criterion``.

    Args:
        net: model to train; it is switched to training mode.
        trainloader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        epoch: epoch index; only used to decide whether to log (every 10th
            epoch) and in the log message.
    """
    net.train()
    train_loss = 0
    num_batches = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        num_batches += 1

    # Count batches explicitly instead of reusing the loop index: with an
    # empty loader the index was never bound and logging raised NameError.
    if num_batches and epoch % 10 == 0:
        print('Epoch%d, Loss: %.3f' % (epoch, train_loss / num_batches))

In [7]:
def test(net):
    """Evaluate ``net`` on the module-level ``testloader``.

    Prints the mean per-batch loss together with the accuracy, and returns
    the accuracy as a fraction (correct / total).
    """
    net.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net(inputs)

            loss_sum += criterion(outputs, targets).item()
            # max over dim 1 returns (values, indices); the indices are the predictions.
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()

    mean_loss = loss_sum / (batch_idx + 1)
    accuracy_pct = 100. * n_correct / n_seen
    print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (mean_loss, accuracy_pct, n_correct, n_seen))
    return float(n_correct) / n_seen

In [8]:
def train_learner(dataset, epochs, net=None, lr=0.005, momentum=0.9, batch_size=64):
    """Train a model on ``dataset`` with SGD and return it.

    Args:
        dataset: dataset handed to a shuffling ``DataLoader``.
        epochs: number of passes of ``train`` to run.
        net: model to continue training; when ``None`` a fresh ``ResNet18``
            is created and moved to the module-level ``device``.
        lr: SGD learning rate.
        momentum: SGD momentum.
        batch_size: mini-batch size for the training loader.

    Returns:
        The trained model (the same object as ``net`` when one was passed).
    """
    if net is None:
        net = ResNet18()
        net = net.to(device)
    # A new optimizer is built on every call, so momentum buffers are not
    # carried over between successive calls on the same model.
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    trainloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    for epoch in range(epochs):
        train(net, trainloader, optimizer, epoch)
    return net

In [9]:
def tuning(learners, local_sets, theta, epochs, num_classes=10):
    """Pseudo-label the global pool by learner vote, then fine-tune each learner.

    Every learner predicts a class for each batch served by the module-level
    ``global_loader``. A sample of the module-level ``global_trainset`` is kept
    when its most-voted class received at least ``theta`` votes; kept samples
    are paired with that class and appended to each learner's local data for
    ``epochs`` further epochs of training via ``train_learner``.

    Args:
        learners: list of models; updated in place and also returned.
        local_sets: per-learner local datasets (lists), indexed like ``learners``.
        theta: minimum number of votes the winning class needs.
        epochs: number of fine-tuning epochs per learner.
        num_classes: number of classes used when counting votes.

    Returns:
        ``(learners, mean_accuracy)`` where ``mean_accuracy`` is the mean of
        ``test(net)`` over all learners.

    Raises:
        ValueError: if ``learners`` is empty.
    """
    if not learners:
        raise ValueError("tuning() needs at least one learner")
    n_learners_local = len(learners)

    global_predictions = []
    # Labelling is inference only: disable autograd so no graph is kept per
    # batch, and put each learner in eval mode explicitly rather than relying
    # on whatever mode the caller left it in.
    with torch.no_grad():
        for learner in learners:
            learner.eval()
            learner_preds = []
            for inp, _lab in global_loader:
                _, preds = learner(inp.to(device)).max(1)
                learner_preds += preds.cpu().numpy().tolist()
            global_predictions.append(learner_preds)
    global_predictions = np.array(global_predictions)

    certain_global = []
    correct_count = 0
    for i in range(len(global_trainset)):
        # votes[c] = number of learners that predicted class c for sample i.
        votes = np.bincount(global_predictions[:, i], minlength=num_classes)
        if votes.max() >= theta:
            # Plain int so pseudo-labels have the same type as ordinary labels.
            winner = int(votes.argmax())
            # Index the dataset once and reuse both the image and the label.
            image, true_label = global_trainset[i]
            certain_global.append((image, winner))
            if winner == true_label:
                correct_count += 1
    print("Certain predictions amount", len(certain_global), "with correct in them", correct_count)

    acc = 0
    for i in range(n_learners_local):
        print("learner", i)
        local_dataset = local_sets[i]
        net = train_learner(local_dataset + certain_global, epochs=epochs, net=learners[i])
        acc += test(net)
        learners[i] = net

    return learners, acc / n_learners_local

In [10]:
# Training indices 0..9999 form the local pool; indices 10000..49999 form the
# global pool.
_local_indices = list(range(0, 10000))
_global_indices = list(range(10000, 50000))

local_trainset = torch.utils.data.Subset(trainset, _local_indices)
print("local data", len(local_trainset))

global_trainset = torch.utils.data.Subset(trainset, _global_indices)
print("global data", len(global_trainset))

# shuffle=False keeps batch order aligned with global_trainset indices, which
# tuning() relies on when it matches predictions to samples by position.
global_loader = torch.utils.data.DataLoader(
    global_trainset, batch_size=100, shuffle=False, num_workers=2)

local data 10000
global data 40000


In [13]:
# Fraction of the local pool that stays class-mixed ("unsorted").
p=0.25
_n_unsorted = math.floor(p*len(local_trainset))
unsorted_local_trainset = torch.utils.data.Subset(local_trainset, list(range(_n_unsorted)))

# Bucket the rest of the local pool by sample[1] (the class label):
# label -> list of samples.
_remaining_local = torch.utils.data.Subset(
    local_trainset, list(range(_n_unsorted, len(local_trainset))))
sorted_local_trainset = {}
for sample in _remaining_local:
    sorted_local_trainset.setdefault(sample[1], []).append(sample)

In [14]:
# Report how many samples each class bucket holds.
for label, bucket in sorted_local_trainset.items():
    print("class", label, "size", len(bucket))

class 5 size 701
class 3 size 772
class 7 size 748
class 6 size 780
class 0 size 752
class 9 size 731
class 2 size 777
class 1 size 731
class 4 size 737
class 8 size 771


In [15]:
# Number of local learners to train.
n_learners = 5
# Size of each learner's share of the unsorted local subset (integer division,
# so any remainder samples are left unused).
uns_local_ds = len(unsorted_local_trainset)//n_learners
# Vote threshold passed to tuning(): a pseudo-label is kept only when the
# most-voted class has at least this many votes.
theta = 4

In [16]:
# Build each learner's local set: the full buckets of classes i and i+5,
# plus that learner's contiguous slice of the unsorted subset.
local_sets = []
for i in range(n_learners):
    start, stop = uns_local_ds*i, uns_local_ds*(i+1)
    learner_samples = []
    learner_samples.extend(sorted_local_trainset[i])
    learner_samples.extend(sorted_local_trainset[i+5])
    learner_samples.extend(
        torch.utils.data.Subset(unsorted_local_trainset, list(range(start, stop))))
    local_sets.append(learner_samples)

In [17]:
# Persist the data split next to the saved models. Context managers make sure
# each file is flushed and closed even if pickling raises; the handles were
# previously opened inline and never closed.
with open(os.path.join(savedir, "local_sets"), "wb") as local_sets_file:
    pickle.dump(local_sets, local_sets_file)
with open(os.path.join(savedir, "global_trainset"), "wb") as global_trainset_file:
    pickle.dump(global_trainset, global_trainset_file)

In [18]:
# Show the number of samples in every learner's local set.
for learner_samples in local_sets:
    print(len(learner_samples))

1953
2011
2025
2043
1968


In [19]:
# Train a fresh model per learner on its local set only, evaluate it on the
# test split, and save it under savedir.
learners = []
_accuracies = []
for i in range(n_learners):
    print("learner", i)
    net = train_learner(local_sets[i], epochs=100)
    _accuracies.append(test(net))
    learners.append(net)
    torch.save(net, os.path.join(savedir, "learner%d" % i))
acc = sum(_accuracies)
print("Average accuracy of local workers", acc/n_learners)

learner 0
Epoch0, Loss: 1.468
Epoch10, Loss: 0.026
Epoch20, Loss: 0.001
Epoch30, Loss: 0.000
Epoch40, Loss: 0.000
Epoch50, Loss: 0.000
Epoch60, Loss: 0.000
Epoch70, Loss: 0.000
Epoch80, Loss: 0.000
Epoch90, Loss: 0.000
Loss: 5.318 | Acc: 26.950% (2695/10000)
learner 1
Epoch0, Loss: 1.337
Epoch10, Loss: 0.034
Epoch20, Loss: 0.001
Epoch30, Loss: 0.019
Epoch40, Loss: 0.004
Epoch50, Loss: 0.001
Epoch60, Loss: 0.000
Epoch70, Loss: 0.000
Epoch80, Loss: 0.000
Epoch90, Loss: 0.000
Loss: 4.656 | Acc: 30.330% (3033/10000)
learner 2
Epoch0, Loss: 1.522
Epoch10, Loss: 0.080
Epoch20, Loss: 0.001
Epoch30, Loss: 0.000
Epoch40, Loss: 0.000
Epoch50, Loss: 0.000
Epoch60, Loss: 0.000
Epoch70, Loss: 0.000
Epoch80, Loss: 0.000
Epoch90, Loss: 0.000
Loss: 5.042 | Acc: 26.710% (2671/10000)
learner 3
Epoch0, Loss: 1.459
Epoch10, Loss: 0.015
Epoch20, Loss: 0.001
Epoch30, Loss: 0.000
Epoch40, Loss: 0.000
Epoch50, Loss: 0.000
Epoch60, Loss: 0.000
Epoch70, Loss: 0.000
Epoch80, Loss: 0.000
Epoch90, Loss: 0.000
Loss

In [20]:
# Five rounds of vote-based tuning; every learner is saved after each round.
_n_rounds = 5
for i in range(_n_rounds):
    learners, acc = tuning(learners, local_sets, theta, epochs=20)
    print("Average accuracy of local workers", acc)
    for ind, model in enumerate(learners):
        torch.save(model, os.path.join(savedir, "learner%d_tune%d" % (ind, i)))

Certain predictions amount 662 with correct in them 538
learner 0
Epoch0, Loss: 0.605
Epoch10, Loss: 0.001
Loss: 4.309 | Acc: 35.480% (3548/10000)
learner 1
Epoch0, Loss: 0.458
Epoch10, Loss: 0.001
Loss: 4.018 | Acc: 36.570% (3657/10000)
learner 2
Epoch0, Loss: 0.483
Epoch10, Loss: 0.001
Loss: 4.311 | Acc: 33.920% (3392/10000)
learner 3
Epoch0, Loss: 0.486
Epoch10, Loss: 0.002
Loss: 4.874 | Acc: 32.960% (3296/10000)
learner 4
Epoch0, Loss: 0.693
Epoch10, Loss: 0.131
Loss: 4.398 | Acc: 33.540% (3354/10000)
Average accuracy of local workers 0.34493999999999997
Certain predictions amount 6234 with correct in them 4587
learner 0
Epoch0, Loss: 0.458
Epoch10, Loss: 0.005
Loss: 4.482 | Acc: 44.590% (4459/10000)
learner 1
Epoch0, Loss: 0.350
Epoch10, Loss: 0.004
Loss: 4.769 | Acc: 41.380% (4138/10000)
learner 2
Epoch0, Loss: 0.414
Epoch10, Loss: 0.086
Loss: 4.157 | Acc: 41.480% (4148/10000)
learner 3
Epoch0, Loss: 0.373
Epoch10, Loss: 0.014
Loss: 5.171 | Acc: 40.470% (4047/10000)
learner 4
Epo