In [1]:
import os
from collections import OrderedDict
from functools import reduce

import numpy as np
import torch 
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from torchvision.datasets import MNIST, USPS, SVHN
from torchvision.transforms import ToTensor,transforms,Compose,Resize,Lambda

from model import CNN

In [2]:
# Seed torch's RNG (and every CUDA device's RNG when a GPU is present) with a
# fixed value, and record which device the rest of the notebook should use.
torch.manual_seed(777)
if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.manual_seed_all(777)
else:
    device = 'cpu'
device

'cuda'

In [3]:
# data download
# Every split is resized with Resize(28) and converted with ToTensor().
# MNIST and USPS additionally get their tensor repeated 3x along dim 0
# (presumably to turn single-channel digits into 3-channel input like SVHN
# -- confirm against the CNN's expected input).
def _digit_transform(tile_channels):
    """Build the preprocessing pipeline for one dataset.

    tile_channels: when True, append the step that repeats the tensor three
    times along its first dimension.
    """
    steps = [Resize(28), ToTensor()]
    if tile_channels:
        steps.append(Lambda(lambda x: x.repeat(3,1,1)))
    return Compose(steps)


# Training splits.
trainM = MNIST(root='data', train=True, download=True,
               transform=_digit_transform(True))
trainU = USPS(root='data', train=True, download=True,
              transform=_digit_transform(True))
trainS = SVHN(root='data', split='train', download=True,
              transform=_digit_transform(False))

# Held-out splits, preprocessed exactly like their training counterparts.
testM = MNIST(root='data', train=False, download=True,
              transform=_digit_transform(True))
testU = USPS(root='data', train=False, download=True,
             transform=_digit_transform(True))
testS = SVHN(root='data', split='test', download=True,
             transform=_digit_transform(False))

Using downloaded and verified file: data/train_32x32.mat
Using downloaded and verified file: data/test_32x32.mat


In [4]:
def makeDataset(raw_data, batch_size, shuffle=True, drop_last=True):
    """Wrap a dataset in a DataLoader.

    Args:
        raw_data: dataset to iterate over.
        batch_size: number of samples per batch.
        shuffle: reshuffle the data every epoch. Defaults to True, the
            behaviour this helper always had.
        drop_last: discard the final batch when it has fewer than
            ``batch_size`` samples. Defaults to True, the behaviour this
            helper always had. Pass False (together with ``shuffle=False``)
            when every sample must be visited, e.g. for evaluation.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``raw_data``.
    """
    return DataLoader(
        dataset=raw_data,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,  # True: drop the final incomplete batch
    )

# One loader per dataset with batch size 16, built in MNIST, USPS, SVHN order.
train_loader_M, train_loader_U, train_loader_S = [
    makeDataset(split, 16) for split in (trainM, trainU, trainS)
]

test_loader_M, test_loader_U, test_loader_S = [
    makeDataset(split, 16) for split in (testM, testU, testS)
]

In [5]:
# Index-aligned collections: position i of each tuple and of `names`
# refers to the same dataset.
train_loaders = train_loader_M, train_loader_U, train_loader_S
test_loaders = test_loader_M, test_loader_U, test_loader_S

names = ['MNIST', 'USPS', 'SVHN']

In [6]:
# FedAvg
# Plain (unweighted) average: the amount of data each model was trained on
# is not taken into account.

def fedAvg(models:list):
    """Average several models' weights layer by layer.

    Args:
        models: one entry per model, each an ordered sequence of numpy
            arrays (one array per layer). Any iterable is accepted, including
            generators and dict views; it is materialized once internally.

    Returns:
        A list with one array per layer, each the element-wise mean of that
        layer across all models. An empty input yields an empty list.

    Raises:
        ValueError: if the models do not all have the same number of layers
            (zip would otherwise silently drop the trailing layers).
    """
    # Materialize so len() works and the input can be traversed twice.
    models = [list(layers) for layers in models]
    num_model = len(models)
    if num_model == 0:
        return []

    layer_counts = {len(layers) for layers in models}
    if len(layer_counts) > 1:
        raise ValueError(
            "all models must have the same number of layers, got counts "
            + str(sorted(layer_counts))
        )

    weights_prime = [
        reduce(np.add, layer)/num_model for layer in zip(*models)
    ]
    return weights_prime

In [7]:
# Build a CNN and display the names of its first two state-dict entries.
shape = (16, 3, 28, 28)
model = CNN(shape)
list(model.state_dict())[:2]

['conv_layer.0.weight', 'conv_layer.0.bias']

In [8]:
# train
# Federated "split-low" training. In every round each client (one per
# dataset) trains its own CNN for `epoch` epochs on a capped number of
# batches; afterwards only the first NUM_SHARED_LAYERS state-dict entries are
# averaged with fedAvg and handed to the clients at the start of the next round.
#
# NOTE(review): a fresh CNN (and optimizer) is created for every client in
# every round, so all non-shared layers restart from their random
# initialization each round -- confirm this is intended for split-low.

NUM_SHARED_LAYERS = 2               # leading state-dict entries that are averaged
SAVE_DIR = "./weight/split_low/"    # where client and global weights are written

global_model = None
round = 5   # NOTE: shadows the builtin round(); name kept in case other cells read it
epoch = 4
shape = (16,3,28,28)
learning_rate = 0.001
batches = [3,5,200]                 # max training batches per epoch, per client
models={}

# Create the output directory up front so torch.save cannot fail after training.
os.makedirs(SAVE_DIR, exist_ok=True)

for r in range(round): # round
    print('### round : ',r+1)
    for c, name in enumerate(names):
        model = CNN(shape)
        if global_model is not None:
            # From the second round on, start from the averaged shared layers.
            # global_model only contains the first NUM_SHARED_LAYERS entries,
            # so strict=True would raise on the keys it does not provide.
            model.load_state_dict(global_model,strict=False)

        model = model.to(device)
        model.train()

        criterion = torch.nn.CrossEntropyLoss().to(device)
        optimizer = torch.optim.Adam(model.parameters(),lr = learning_rate)

        for e in range(epoch):
            total_cost = 0.0
            steps = 0
            loader = train_loaders[c] # this client's loader, re-iterated every epoch
            for i,data in enumerate(loader):
                x,y = data
                x=x.to(device)
                y=y.to(device)
                optimizer.zero_grad()
                predict = model(x)
                cost = criterion(predict,y)
                cost.backward()
                optimizer.step()

                # .item() yields a plain float, so no tensors are kept alive
                # just for logging.
                total_cost += cost.item()
                steps += 1
                # Stop after exactly batches[c] optimizer steps.
                if steps >= batches[c]:
                    break
            # Mean over the steps actually run (the loader may be shorter
            # than the cap).
            avg_cost = total_cost / steps if steps else 0.0
            if (e+1)%2 == 0:
                print('[Epoch: {:>4}] cost = {:>.9}'.format(e + 1, avg_cost))

        # for split-low: keep only the shared leading layers as numpy arrays
        models[name] = [
            val.cpu().numpy()
            for val in list(model.state_dict().values())[:NUM_SHARED_LAYERS]
        ]
        if r== round-1:
            # Final round: persist each client's complete model.
            torch.save(model.state_dict(), os.path.join(SAVE_DIR, name + ".pkl"))

    # Average the shared layers across clients and rebuild a state dict that
    # load_state_dict can consume in the next round.
    shared_keys = list(model.state_dict().keys())[:NUM_SHARED_LAYERS]
    global_model = OrderedDict(
        (k, torch.tensor(v)) for k, v in zip(shared_keys, fedAvg(models.values()))
    )

torch.save(global_model, os.path.join(SAVE_DIR, "global_model.pkl"))


### round :  1
