In [14]:
from torchvision.datasets import MNIST, USPS, SVHN
import torch 
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import ToTensor,transforms,Compose,Resize,Lambda
from model import CNN
import numpy as np
from collections import OrderedDict
from functools import reduce

In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)
device

'cuda'

In [16]:
# data download
trainM = MNIST(root='data',train=True,download=True,transform=Compose([
Resize(28),
ToTensor(),
Lambda(lambda x: x.repeat(3,1,1)),
]))
trainU = USPS(root='data',train=True,download=True,transform=Compose([
Resize(28),
ToTensor(),
Lambda(lambda x: x.repeat(3,1,1)),
]))
trainS = SVHN(root='data',split='train',download=True,transform=Compose([
Resize(28),
ToTensor(),
]))
testM = MNIST(root='data',train=False,download=True,transform=Compose([
Resize(28),
ToTensor(),
Lambda(lambda x: x.repeat(3,1,1)),
]))
testU = USPS(root='data',train=False,download=True,transform=Compose([
Resize(28),
ToTensor(),
Lambda(lambda x: x.repeat(3,1,1)),
]))
testS = SVHN(root='data',split='test',download=True,transform=Compose([
Resize(28),
ToTensor(),
]))

Using downloaded and verified file: data/train_32x32.mat
Using downloaded and verified file: data/test_32x32.mat


In [17]:
def makeDataset(raw_data,batch_size):
    data_loader = DataLoader(dataset=raw_data,
                batch_size=batch_size,
                shuffle=True,
                drop_last=True) # batch 마지막 데이터 버림
    return data_loader

train_loader_M = makeDataset(trainM,16)
train_loader_U = makeDataset(trainU,16)
train_loader_S = makeDataset(trainS,16)

test_loader_M = makeDataset(testM,16)
test_loader_U = makeDataset(testU,16)
test_loader_S = makeDataset(testS,16)

In [18]:
train_loaders = (train_loader_M,train_loader_U,train_loader_S)
test_loaders = (test_loader_M,test_loader_U,test_loader_S)

names = ['MNIST','USPS','SVHN']

In [19]:
# FedAvg
# 학습한 데이터 양 반영하지 않고 average

def fedAvg(models:list):
    num_model = len(models)
    weights_prime = [
        reduce(np.add, layer)/num_model for layer in zip(*models)
    ]
    return weights_prime

In [20]:
# train
global_model = None
round = 5
epoch = 4
for r in range(round): # round
    models=[]
    print('### round : ',r+1)
    for c in range(3):
        shape = (16,3,28,28)
        learning_rate = 0.001
        batches = [3,5,200]
        
        model = CNN(shape)
        if global_model != None:
            model.load_state_dict(global_model,strict=True)  # 첫번째 라운드 이후 global model 사용
            
        model = model.to(device)
        model.train()

        criterion = torch.nn.CrossEntropyLoss().to(device)
        optimizer = torch.optim.Adam(model.parameters(),lr = learning_rate)

        for e in range(epoch):
            avg_cost = 0
            loader = train_loaders[c] # epoch마다 loader 생성
            for i,data in enumerate(loader):
                x,y = data
                x=x.to(device)
                y=y.to(device)
                optimizer.zero_grad()
                predict = model(x)
                cost = criterion(predict,y)
                cost.backward()
                optimizer.step()

                avg_cost += cost / batches[c]
                if i== batches[c]:
                    break
            if (e+1)%2 == 0:
                print('[Epoch: {:>4}] cost = {:>.9}'.format(e + 1, avg_cost))
        models.append([val.cpu().numpy() for _, val in model.state_dict().items()])
        if r== round-1:
            torch.save(model.state_dict(),"./weight/federated/"+names[c]+".pkl")
    global_model = fedAvg(models)
    params_dict = zip(model.state_dict().keys(), global_model)
    global_model = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    
torch.save(global_model,"./weight/federated/global_model.pkl")


### round :  1
[Epoch:    2] cost = 3.04937792
[Epoch:    4] cost = 2.84907341
[Epoch:    2] cost = 2.52549314
[Epoch:    4] cost = 1.40759063
[Epoch:    2] cost = 2.24418998
[Epoch:    4] cost = 1.94459987
### round :  2
[Epoch:    2] cost = 2.86607361
[Epoch:    4] cost = 2.78193855
[Epoch:    2] cost = 2.38904452
[Epoch:    4] cost = 1.38893437
[Epoch:    2] cost = 1.45545208
[Epoch:    4] cost = 0.9842574
### round :  3
[Epoch:    2] cost = 1.6634562
[Epoch:    4] cost = 1.34624588
[Epoch:    2] cost = 1.09606087
[Epoch:    4] cost = 0.690509796
[Epoch:    2] cost = 1.00876665
[Epoch:    4] cost = 0.835080445
### round :  4
[Epoch:    2] cost = 0.498802602
[Epoch:    4] cost = 0.976793468
[Epoch:    2] cost = 0.462655872
[Epoch:    4] cost = 0.456617266
[Epoch:    2] cost = 0.824793875
[Epoch:    4] cost = 0.730243802
### round :  5
[Epoch:    2] cost = 0.866413593
[Epoch:    4] cost = 0.486019999
[Epoch:    2] cost = 0.86703521
[Epoch:    4] cost = 0.552861631
[Epoch:    2] cost =

In [21]:
# 입력
# 세가지 모델에 대해 입력을 받음
# 세가지 데이터셋에 대해 cross check 진행

# evaluation
for m in range(3) : # 세가지 모델
    shape = (16,3,28,28)
    model = CNN(shape)
    model.load_state_dict(torch.load("./weight/federated/"+names[m]+".pkl"))
    model = model.to(device)
    model.eval()
    print('model : ',names[m])
    for c in range(3): # 세가지 데이터셋
        loader = test_loaders[c]
        correct=0
        total=0

        #load model
        with torch.no_grad():
            for x,y in loader:
                x=x.to(device)
                y=y.to(device)
                predict = torch.max(model(x).data,1)[1]
                total+=len(y)
                correct+=(predict==y).sum().item()
        print(names[c]+' Test Accuracy: ',100.*correct/total, '%')

model :  MNIST
MNIST Test Accuracy:  88.35 %
USPS Test Accuracy:  76.65 %
SVHN Test Accuracy:  59.78027043638598 %
model :  USPS
MNIST Test Accuracy:  73.19 %
USPS Test Accuracy:  82.45 %
SVHN Test Accuracy:  43.99969268592501 %
model :  SVHN
MNIST Test Accuracy:  58.86 %
USPS Test Accuracy:  47.7 %
SVHN Test Accuracy:  80.10909649661954 %


In [23]:
# evaluation

#load global model
model = CNN(shape)
model.load_state_dict(torch.load("./weight/federated/global_model.pkl"))
model = model.to(device)
model.eval()

for c in range(3):
    shape = (16,3,28,28)
    loader = test_loaders[c]
    correct=0
    total=0
    with torch.no_grad():
        for x,y in loader:
            x=x.to(device)
            y=y.to(device)
            predict = torch.max(model(x).data,1)[1]
            total+=len(y)
            correct+=(predict==y).sum().item()
    print(names[c]+' global_model Test Accuracy: ',100.*correct/total, '%')

MNIST global_model Test Accuracy:  78.19 %
USPS global_model Test Accuracy:  78.95 %
SVHN global_model Test Accuracy:  75.24969268592501 %
