In [1]:
from resnet import ResNet18
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import numpy as np
import pickle
from collections import OrderedDict

In [2]:
# Directory for this run's artifacts: the pickled per-learner datasets are read
# from it and the final model is saved into it further down the notebook.
savedir = "./saved_models/noniid/run3"

In [3]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
def train(net, trainloader, optimizer, epoch):
    """Run one training epoch of ``net`` over ``trainloader``.

    Relies on the notebook-level globals ``device`` (where batches are moved)
    and ``criterion`` (the loss function); both must be defined before the
    first call.

    Args:
        net: model to train; put into training mode and updated in place.
        trainloader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer that steps ``net``'s parameters.
        epoch: epoch index, used only to decide whether to log this epoch.

    Returns:
        None. The mean per-batch loss is printed on every fifth epoch.
    """
    net.train()
    train_loss = 0
    n_batches = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        n_batches += 1

    # Log the mean batch loss every fifth epoch. An empty loader runs no
    # batches, so there is nothing to average; skip the report instead of
    # failing on an unbound loop variable.
    if epoch%5==0 and n_batches > 0:
        print('Loss: %.3f' % (train_loss/n_batches))

In [5]:
# Test-time preprocessing: convert images to tensors, then normalise each
# channel (the constants are presumably the usual CIFAR-10 mean/std — TODO confirm).
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

# CIFAR-10 test split (train=False), downloaded into ./data when missing.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# Fixed-order batches of 100 samples for evaluation.
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified


In [6]:
def test(net):
    """Evaluate ``net`` on the global ``testloader`` and return its accuracy.

    Uses the notebook-level globals ``testloader``, ``device`` and
    ``criterion``. Prints the mean per-batch loss together with the accuracy
    (as a percentage and as correct/total counts).

    Args:
        net: model to evaluate; switched to eval mode.

    Returns:
        Fraction of correctly classified samples, as a float.
    """
    net.eval()
    loss_sum, n_correct, n_seen = 0, 0, 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        # `step` counts batches from 1, so after the loop it is the number of batches.
        for step, (inputs, targets) in enumerate(testloader, start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = net(inputs)

            loss_sum += criterion(outputs, targets).item()
            # Index of the largest output along dim 1 is the predicted class.
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

        print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (loss_sum/step, 100.*n_correct/n_seen, n_correct, n_seen))
    return 1.0*n_correct/n_seen

In [7]:
def train_learner(dataset, epochs, net, lr=0.005*5, momentum=0.9, batch_size=64, num_workers=2):
    """Train ``net`` on one learner's local ``dataset`` for ``epochs`` epochs.

    A fresh SGD optimizer and a shuffling DataLoader are created on every
    call, so optimizer momentum state does not carry over between calls.

    Args:
        dataset: dataset to draw training batches from.
        epochs: number of passes over ``dataset``; 0 performs no training.
        net: model to train; updated in place.
        lr: SGD learning rate (default keeps the original 0.005*5).
        momentum: SGD momentum (default 0.9).
        batch_size: samples per training batch (default 64).
        num_workers: DataLoader worker processes (default 2).

    Returns:
        The same ``net`` object that was passed in.
    """
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    trainloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    for epoch in range(epochs):
        train(net, trainloader, optimizer, epoch)
    return net

In [8]:
def average_weights(learners):
    """Return the element-wise mean of the learners' state dicts.

    Every entry of the first learner's ``state_dict()`` is averaged with the
    same-named entry of all other learners. The result can be passed to
    ``load_state_dict``.

    Floating-point tensors are averaged with true division. Integer tensors
    (for example BatchNorm's ``num_batches_tracked``) are averaged with
    integer division so they keep their dtype instead of being zeroed.
    Each averaged tensor keeps the dtype and device of the first learner's
    tensor, so the function works on CPU-only machines as well as on CUDA.

    Args:
        learners: non-empty sequence of modules with identical state-dict keys.

    Returns:
        OrderedDict mapping each state-dict key to its averaged tensor.

    Raises:
        IndexError: if ``learners`` is empty.
        KeyError: if a later learner lacks a key present in the first one.
    """
    states = [learner.state_dict() for learner in learners]
    n_states = len(states)
    newState = OrderedDict()
    for k, first in states[0].items():
        # Clone so the accumulation never writes into a learner's own tensor.
        total = first.detach().clone()
        for other in states[1:]:
            total += other[k].detach().to(total.device)
        if total.is_floating_point():
            newState[k] = total / n_states
        else:
            # Integer counters: integer-divide to stay in the original dtype.
            newState[k] = total // n_states
    return newState

In [9]:
# Load the pre-built per-learner datasets; local_sets[i] is the data used to
# train learner i. (An earlier variant instead loaded a single pickled
# "local_trainset" and sliced it into n_learners equal parts.)
# NOTE: unpickling can execute arbitrary code — only load files this project produced.
with open(os.path.join(savedir, "local_sets"), "rb") as f:
    local_sets = pickle.load(f)
n_learners = 5
print(len(local_sets[0]))

1953


In [10]:
# Build one ResNet18 per learner.
# NOTE(review): each learner gets its own random initialisation, and the
# recorded accuracy right after the first averaging step is 10%. Consider
# starting all learners from one shared state_dict — confirm whether the
# independent initialisation is intended.
learners = []
for i in range(n_learners):
    net = ResNet18()
    net = net.to(device)
    learners.append(net)
criterion = nn.CrossEntropyLoss()
# Each outer iteration is one communication round: every learner trains
# locally for 40 epochs, then all learners are reset to the averaged weights.
for e in range(5):
    print("Epoch", e)
    for i in range(n_learners):
        local_dataset = local_sets[i]
        # train_learner updates learners[i] in place and returns that same
        # object, so it must not be appended to `learners` again; doing so
        # grows the list by n_learners aliases every round.
        net = train_learner(local_dataset, epochs=40, net=learners[i])
    avg_weights = average_weights(learners)
    for i in range(n_learners):
        learners[i].load_state_dict(avg_weights)
    # All learners now hold the same weights, so evaluating one is enough.
    test(learners[0])

Epoch 0
Loss: 1.882
Loss: 1.030
Loss: 0.879
Loss: 0.671
Loss: 0.351
Loss: 0.152
Loss: 0.031
Loss: 0.023
Loss: 1.491
Loss: 0.833
Loss: 0.621
Loss: 0.344
Loss: 0.138
Loss: 0.029
Loss: 0.006
Loss: 0.001
Loss: 1.791
Loss: 1.048
Loss: 0.727
Loss: 0.307
Loss: 0.071
Loss: 0.069
Loss: 0.109
Loss: 0.007
Loss: 1.837
Loss: 1.012
Loss: 0.804
Loss: 0.458
Loss: 0.231
Loss: 0.044
Loss: 0.058
Loss: 0.015
Loss: 1.497
Loss: 0.877
Loss: 0.575
Loss: 0.251
Loss: 0.038
Loss: 0.050
Loss: 0.001
Loss: 0.001
Loss: 2.365 | Acc: 10.000% (1000/10000)
Epoch 1
Loss: 1.596
Loss: 0.962
Loss: 0.883
Loss: 0.718
Loss: 0.514
Loss: 0.315
Loss: 0.120
Loss: 0.048
Loss: 1.302
Loss: 0.837
Loss: 0.661
Loss: 0.405
Loss: 0.163
Loss: 0.092
Loss: 0.046
Loss: 0.039
Loss: 1.603
Loss: 1.048
Loss: 0.766
Loss: 0.370
Loss: 0.222
Loss: 0.055
Loss: 0.023
Loss: 0.058
Loss: 1.455
Loss: 0.900
Loss: 0.763
Loss: 0.479
Loss: 0.242
Loss: 0.102
Loss: 0.048
Loss: 0.028
Loss: 1.472
Loss: 1.021
Loss: 0.865
Loss: 0.698
Loss: 0.446
Loss: 0.161
Loss: 0.

In [11]:
# Save the first learner (the whole module object, not just its state_dict)
# into the run directory, then evaluate it; the accuracy returned by test()
# is the cell's displayed value.
final_model = learners[0]
model_path = os.path.join(savedir, "localdata_fedl_40epochperiod_200epochs")
torch.save(final_model, model_path)
test(final_model)

Loss: 2.312 | Acc: 50.510% (5051/10000)


0.5051