In [1]:
from resnet import ResNet18
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import numpy as np
import pickle

In [2]:
savedir = "./saved_models/run6"

# Seed the RNGs used by this run (torch: weight init and DataLoader shuffling;
# numpy for completeness) so that Restart & Run All reproduces the results.
# NOTE(review): some cuDNN kernels can still be nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
def train(net, trainloader, optimizer, epoch, log_every=10):
    """Run one training epoch of `net` over `trainloader`.

    Relies on two notebook-level globals: `device` (where each batch is
    moved) and `criterion` (the loss function applied to the outputs).

    Args:
        net: model to train; it is switched into training mode.
        trainloader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer over ``net``'s parameters; stepped once per batch.
        epoch: current epoch index, used only to decide when to log.
        log_every: print the mean loss when ``epoch % log_every == 0``;
            pass 0 or None to disable logging.

    Returns:
        float: mean per-batch training loss for this epoch (0.0 if the
        loader yielded no batches).
    """
    net.train()
    running_loss = 0.0
    n_batches = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1

    # Count batches explicitly so an empty loader cannot leave the average
    # undefined (a loop index is never bound when the loop body never runs).
    mean_loss = running_loss / n_batches if n_batches else 0.0
    if log_every and epoch % log_every == 0:
        print('Epoch%d, Loss: %.3f' % (epoch, mean_loss))
    return mean_loss

In [5]:
# Evaluation pipeline: tensor conversion plus per-channel normalization only.
# NOTE(review): these mean/std values presumably mirror the transform used to
# build `local_trainset` -- confirm before changing them.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
    ]
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified


In [6]:
def test(net):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    return 1.0*correct/total

In [7]:
def train_learner(dataset, epochs, net=None, lr=0.005, momentum=0.9,
                  batch_size=64, num_workers=2):
    """Train a model on `dataset` for `epochs` epochs with SGD + momentum.

    Uses the notebook-level `device` and `train` (which in turn uses the
    global `criterion`). A fresh optimizer is created on every call, so
    momentum state does not carry over when continuing to train a `net`.

    Args:
        dataset: dataset handed to ``torch.utils.data.DataLoader``; it must
            support ``len()`` because batches are shuffled.
        epochs: number of passes over `dataset`.
        net: model to (continue to) train; a new ``ResNet18()`` if None.
        lr: SGD learning rate.
        momentum: SGD momentum.
        batch_size: training mini-batch size.
        num_workers: number of DataLoader worker processes.

    Returns:
        The trained model, located on `device`.
    """
    if net is None:
        net = ResNet18()
    # Move unconditionally: a caller-supplied model may still be on the CPU,
    # which would fail against batches moved to a GPU `device`.
    # Module.to moves parameters in place and returns the same module.
    net = net.to(device)
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    trainloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    for epoch in range(epochs):
        train(net, trainloader, optimizer, epoch)
    return net

In [8]:
# Load the pickled training dataset stored under `savedir`.
# NOTE: pickle.load can execute arbitrary code -- only load files this project wrote itself.
with open(os.path.join(savedir, "local_trainset"), "rb") as trainset_file:
    local_trainset = pickle.load(trainset_file)

In [9]:
# Loss shared (as a global) by train() and test().
criterion = nn.CrossEntropyLoss()

# Train a fresh model on the locally stored training set.
net = train_learner(dataset=local_trainset, epochs=200)

Epoch0, Loss: 1.836
Epoch10, Loss: 0.646
Epoch20, Loss: 0.350
Epoch30, Loss: 0.214
Epoch40, Loss: 0.141
Epoch50, Loss: 0.071
Epoch60, Loss: 0.041
Epoch70, Loss: 0.038
Epoch80, Loss: 0.020
Epoch90, Loss: 0.021
Epoch100, Loss: 0.026
Epoch110, Loss: 0.016
Epoch120, Loss: 0.006
Epoch130, Loss: 0.007
Epoch140, Loss: 0.002
Epoch150, Loss: 0.007
Epoch160, Loss: 0.003
Epoch170, Loss: 0.016
Epoch180, Loss: 0.010
Epoch190, Loss: 0.006


In [10]:
# Evaluate the trained model; the returned accuracy fraction is the cell's displayed output.
test(net=net)

Loss: 1.076 | Acc: 84.230% (8423/10000)


0.8423

In [11]:
# Persist the whole pickled module (not just its state_dict); loading it back
# requires the `resnet` module defining ResNet18 to be importable.
torch.save(obj=net, f=os.path.join(savedir, "localdata_centralized"))