In [1]:
import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable,Dict,Optional,Tuple
import torchvision.transforms as transforms
import torchvision.datasets
from torchvision.datasets import CIFAR10
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import os
import sys
import pickle
from collections import OrderedDict
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('model'))))
from model import *
import pickle
import torchvision.transforms as transforms

In [21]:
def load_data_client():
  #data loader
  '''Load data from source'''
  transform = transforms.Compose(
  [transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5,))]
  )
  trainset = torchvision.datasets.MNIST(root='./dataset',train=True,download=True,transform=transform)
  print(trainset)
  trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
  print(len(trainloader))
  print(len(trainloader))
  return trainloader

In [22]:
server_net = Server_Net()
client_net = Client_Net()
train_loader = load_data_client()

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./dataset
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5,), std=(0.5,))
           )
1875
1875


In [14]:
len(train_loader)

1875

In [15]:
1875*32

60000

## client part 학습

In [7]:
def train_client(trainloader, epochs):
  client_net.eval()
  for i,data in enumerate(trainloader):
      images, labels = data
      output= client_net(images)
      output = output.data
      output_arr = output.numpy()
      target_arr = labels.numpy()
      if i==0:
        print(output_arr.shape)
        print(target_arr.shape)
        input_arr = output_arr
        label_arr = target_arr
      else:
        input_arr = np.concatenate((input_arr,output_arr),axis=0)
        label_arr = np.concatenate((label_arr,target_arr),axis=0)
  return input_arr, label_arr

In [8]:
input,label = train_client(train_loader,1)

(32, 16, 5, 5)
(32,)


## server part 학습

In [2]:
def load_data_server():
    #data loader
    '''load data from server'''
    #load code, set loader by transformed data
    loaders = []
    for clientsID in range(1):
        transformed_input_tensor = torch.Tensor(input)
        transformed_label_tensor = torch.Tensor(label).type(torch.LongTensor)

        transformed_data = TensorDataset(transformed_input_tensor,transformed_label_tensor)
        transformed_loader = DataLoader(transformed_data,batch_size=32,shuffle=False)
    loaders.append(transformed_loader)
    return loaders

In [3]:
def splitParameter(net):
  client_layer = {}
  num_client_layer=2
  for index,(name, val) in enumerate(net.state_dict().items()):
    if index >= num_client_layer*2:
      break
    client_layer[name] = val
  return client_layer

In [4]:
def train_server(net, loaders):
  """Train the frozen network on the training set, only one time"""
  criterion = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(net.parameters(),lr = 0.001,momentum=0.9)
  net.train() # frozen part이기 때문에 eval mode에서 진행
  num_examples = 0
  for trainloader in loaders:
    for images, labels in trainloader:
        optimizer.zero_grad()
        loss = criterion(net(images), labels)
        loss.backward()
        optimizer.step()
        num_examples += labels.size(0)
  #weight에서 client부분 분리하여 전달
  net = splitParameter(net)
  print(net.keys())
  return [val.cpu().numpy() for _, val in net.items()]

In [12]:
server_loaders = load_data_server()
parameter = train_server(server_net,server_loaders)

dict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])


In [36]:
client_net.state_dict()['conv1.weight'].size()

torch.Size([6, 1, 3, 3])

In [32]:
len([val.cpu().numpy() for _, val in client_net.state_dict().items()])
len([val.cpu().numpy() for _, val in client_net.state_dict().items()][0])

6

## 정확도 test

In [7]:
with open('parameter.pkl','rb') as f:
    test_net = pickle.load(f)

In [8]:
def load_data_test():
  #data loader
  '''Load data from source'''
  transform = transforms.Compose(
  [transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5,))]
  )
  testset = torchvision.datasets.MNIST(root='./dataset',train=False,download=True,transform=transform)
  testloader = DataLoader(testset, batch_size=32)
  return testloader

In [9]:
def eval(net,test_loader):
    net.eval()
    correct=0
    total=0
    for images, labels in test_loader:
        _,predicted = torch.max(net.test(images).data,1)
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item()
    print(100*correct/total)
        

In [10]:
test_loader = load_data_test()

In [11]:
eval(test_net,test_loader)

43.02
