# In [13]:
## Cloud_Server.py
## Distributed -> C_server -> L_server -> Edge
import torch
from torch import optim
from torch.autograd import Variable
from Model import twoNN
from Data_Loader import loader

class CS(object):
    """Cloud server holding the global model.

    Edge servers write their aggregated ``state_dict`` into ``self.ESs`` (slot
    ``self.count % self.size``); :meth:`aggregate` averages the reported
    weights into ``self.model`` and records its test accuracy.
    """

    def __init__(self, size, data_loader):
        self.size = size                      # number of edge-server slots
        self.test_loader = data_loader[1]     # element [1] of data_loader is used for testing
        self.model = twoNN()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.accuracy = []                    # test accuracy after each aggregate()
        self.ESs = [None] * size              # latest weights reported per slot
        self.count = 0                        # number of reports received so far

    def __average_weights(self, ESs):
        """Return the element-wise mean of the reported state dicts in ``ESs``.

        ``None`` slots (edge servers that have not reported yet) are skipped and
        the divisor is the number of real reports, so the result is a true
        average. The input dicts are NOT modified: a shallow copy of the first
        report receives the averaged values.

        Raises:
            ValueError: if no weights have been reported.
        """
        import copy

        reported = [weights for weights in ESs if weights is not None]
        if not reported:
            raise ValueError('No edge-server weights have been reported.')
        # copy.copy keeps the dict type and attributes such as torch's _metadata.
        averaged = copy.copy(reported[0])
        for key in averaged:
            averaged[key] = sum(weights[key] for weights in reported) / len(reported)
        return averaged

    def __test(self):
        """Return the global model's accuracy (fraction correct) on the test split."""
        test_correct = 0
        self.model.eval()
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = Variable(data), Variable(target)
                output = self.model(data)
                pred = output.argmax(dim=1, keepdim=True)
                test_correct += pred.eq(target.view_as(pred)).sum().item()
        return test_correct / len(self.test_loader.dataset)

    def aggregate(self):
        """Average the reported edge weights into the global model and evaluate it."""
        weights = self.__average_weights(self.ESs)
        self.model.load_state_dict(weights)
        test_accuracy = self.__test()
        self.accuracy.append(test_accuracy)
        print('\n[Global Model]  Test Accuracy: {:.2f}%\n'.format(test_accuracy * 100.))

# In [14]:
## Edge_Server.py
import torch
from torch import optim
from torch.autograd import Variable
from Model import twoNN
from Data_Loader import loader


class ES():
    """Edge server: averages its clients' updates and forwards them to the cloud.

    Clients write their trained ``state_dict`` into ``self.clients`` (slot
    ``self.count % self.size``); :meth:`aggregate` averages them into
    ``self.model``, pushes the result into the cloud server ``self.CS`` and
    evaluates the edge model.
    """

    def __init__(self, size, data_loader, CS):
        self.size = size                      # number of client slots
        self.test_loader = data_loader[1]     # element [1] of data_loader is used for testing
        self.model = twoNN()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.accuracy = []                    # test accuracy after each aggregate()
        self.clients = [None] * size          # latest client weights per slot
        self.count = 0                        # number of client reports received so far
        self.CS = CS                          # cloud server this edge reports to

    def __average_weights(self, clients):
        """Return the element-wise mean of the reported client state dicts.

        ``None`` slots are skipped and the divisor is the number of real
        reports. The input dicts are NOT modified; this matters because a
        client connected to several edge servers hands the same update to
        each of them.

        Raises:
            ValueError: if no client weights have been reported.
        """
        import copy

        reported = [weights for weights in clients if weights is not None]
        if not reported:
            raise ValueError('No client weights have been reported.')
        # copy.copy keeps the dict type and attributes such as torch's _metadata.
        averaged = copy.copy(reported[0])
        for key in averaged:
            averaged[key] = sum(weights[key] for weights in reported) / len(reported)
        return averaged

    # Optional: evaluate the edge model on the test split.
    def __test(self):
        """Return the edge model's accuracy (fraction correct) on the test split."""
        test_correct = 0
        self.model.eval()
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = Variable(data), Variable(target)
                output = self.model(data)
                pred = output.argmax(dim=1, keepdim=True)
                test_correct += pred.eq(target.view_as(pred)).sum().item()
        return test_correct / len(self.test_loader.dataset)

    def aggregate(self):
        """Average client updates, forward them to the cloud server, and evaluate.

        When no client has reported (e.g. a cell containing no clients) this
        only logs and returns, instead of failing on an empty slot list.
        """
        # TODO: clients overlapping several edge servers may need a weighting factor (gamma).
        if all(weights is None for weights in self.clients):
            print('\n[Edge Model]  No client updates received; skipping aggregation.\n')
            return
        weights = self.__average_weights(self.clients)
        self.model.load_state_dict(weights)
        self.CS.ESs[self.CS.count % self.CS.size] = weights
        self.CS.count += 1
        test_accuracy = self.__test()
        self.accuracy.append(test_accuracy)
        print('\n[Edge Model]  Test Accuracy: {:.2f}%\n'.format(test_accuracy * 100.))

# In [15]:
## MES_Client.py
import torch

from torch import nn
from torch.autograd import Variable
class Client(object):
    """Simulated client in the client -> edge -> cloud hierarchy.

    Each round (:meth:`run`) the client starts from the average of the models
    held by the edge servers in ``self.ES``, trains locally, and writes its
    updated ``state_dict`` into every one of those edge servers.
    """

    def __init__(self, rank, data_loader, local_epoch, ES):
        # Reseeds torch's global RNG with a rank-dependent seed.
        seed = 19201077 + 19950920 + rank
        torch.manual_seed(seed)
        self.rank = rank
        self.local_epoch = local_epoch
        self.ES = ES                          # edge servers this client is connected to
        self.test_loader = data_loader[1]
        self._train_source = data_loader[0]   # kept so the batch iterator can be restarted
        self.train_loader = iter(self._train_source)

    def __load_global_model(self):
        """Return a fresh ``twoNN`` holding the mean of the connected edge models.

        Raises:
            ValueError: if the client is not connected to any edge server.
        """
        if not self.ES:
            raise ValueError('Client {} is not connected to any edge server.'.format(self.rank))
        model = twoNN()
        # Fetch each edge state_dict once rather than once per key.
        states = [es.model.state_dict() for es in self.ES]
        if len(states) == 1:
            model.load_state_dict(states[0])
        else:
            averaged = states[0]  # state_dict() builds a new dict, so overwriting entries is safe
            for key in averaged:
                averaged[key] = sum(state[key] for state in states) / len(states)
            model.load_state_dict(averaged)
        return model

    def __next_batch(self):
        """Return the next ``(data, target)`` batch, restarting the loader at its end."""
        try:
            return next(self.train_loader)
        except StopIteration:
            self.train_loader = iter(self._train_source)
            return next(self.train_loader)

    def __train(self, model):
        """Train ``model`` for ``local_epoch`` steps, one fresh mini-batch per step.

        Loss and accuracy are printed for the final step only.

        Returns:
            The trained model's ``state_dict()``.
        """
        from torch import optim  # not imported in this cell of the file

        model.train()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        criterion = nn.CrossEntropyLoss()
        last_step = self.local_epoch - 1
        train_loss = 0.0
        train_correct = 0
        train_seen = 0

        for step in range(self.local_epoch):
            # Exactly one batch per local step; the loader wraps around, so long
            # runs never hit StopIteration and no batches are silently skipped.
            data, target = self.__next_batch()
            data, target = Variable(data), Variable(target)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if step == last_step:
                # Record only the last local step (print per step here if needed).
                train_loss = loss.item()  # .item() so the autograd graph is not retained
                pred = output.argmax(dim=1, keepdim=True)
                train_correct += pred.eq(target.view_as(pred)).sum().item()
                train_seen += target.size(0)

        print('[Rank {:>2}]  Loss: {:>4.6f},  Accuracy: {:>.4f}'.format(
            self.rank,
            train_loss,
            train_correct / train_seen if train_seen else 0.0,
        ))
        return model.state_dict()

    def run(self):
        """Run one local round: load the edge-averaged model, train, and upload."""
        import copy

        model = self.__load_global_model()
        weights = self.__train(model=model)
        for es in self.ES:
            # Give each edge server its own dict so one server's edits to the
            # dict cannot leak into another server's copy of this update.
            es.clients[es.count % es.size] = copy.copy(weights)
            es.count += 1
        


# In [16]:
## Hand_over Algorithm
import numpy as np
import math as mt
import matplotlib.pyplot as plt
import random
def cleint_genrate(n_client=100, n_cell=25, stride=40, radius=20, bound=60):
    """Scatter clients on a plane and group them by edge-server cell.

    Cells form an ``n x n`` grid (``n = sqrt(n_cell)``); cell ``i*n + j`` is a
    circle of ``radius`` centred at ``(radius*j - stride, radius*i - stride)``.
    Client positions are integers drawn with ``np.random.randint(-bound, bound)``.
    A client belongs to every cell whose centre is strictly closer than
    ``radius``. The cells and clients are also drawn on a new matplotlib figure.

    Returns:
        (ES, clients): ``ES[c]`` is a list of ``[x, y]`` positions of the
        clients in cell ``c`` (any number of them); ``clients`` is the
        ``(n_client, 2)`` array of positions.

    Raises:
        ValueError: if ``n_cell`` is not a perfect square.
    """
    n = int(mt.sqrt(n_cell))
    if n * n != n_cell:
        # Otherwise the trailing cells would have no centre and the distance test would fail.
        raise ValueError('n_cell must be a perfect square, got {}'.format(n_cell))

    figure, axes = plt.subplots()
    axes.set_aspect(1)
    clients = np.random.randint(-bound, bound, size=(n_client, 2))

    # Cell centres, indexed row-major as i*n + j.
    centers = [(radius * j - stride, radius * i - stride) for i in range(n) for j in range(n)]
    for center in centers:
        axes.add_artist(plt.Circle(center, radius, fill=False))
    plt.scatter(clients[:, 0], clients[:, 1])

    # Grow each cell's member list as needed (no fixed per-cell capacity).
    ES = []
    for cx, cy in centers:
        members = []
        for x, y in clients:
            if mt.sqrt((cx - x) ** 2 + (cy - y) ** 2) < radius:
                members.append([x, y])
        ES.append(members)
    return ES, clients

def ID_generate(ES, ES_obj, clients):
    """Map every client to the edge-server cells that contain it.

    Args:
        ES: per-cell lists of ``[x, y]`` client positions.
        ES_obj: edge-server objects, parallel to ``ES``.
        clients: sequence of client positions; client ``k`` is ``clients[k]``.

    Returns:
        (client_ID, obj_ID): dicts keyed by client index giving, respectively,
        the indices of the cells containing that client and the matching
        ``ES_obj`` entries. Each cell appears at most once per client, even
        when several clients share the same coordinates.
    """
    client_ID = {}
    obj_ID = {}
    for idx, position in enumerate(clients):
        # any(): a duplicate position inside a cell must not list that cell twice.
        cells = [cell for cell, members in enumerate(ES)
                 if any(np.array_equal(position, member) for member in members)]
        client_ID[idx] = cells
        obj_ID[idx] = [ES_obj[cell] for cell in cells]
    return client_ID, obj_ID

# In [None]:
def fed_MES(n_client=100, n_ES=25, ES_epoch=2, Cl_epoch=20, batch_size=64):
    """Run the hierarchical federated-learning simulation (clients -> edges -> cloud).

    Args:
        n_client: number of simulated clients.
        n_ES: number of edge servers / cells (must be a perfect square).
        ES_epoch: number of edge-server rounds.
        Cl_epoch: local training steps per client round.
        batch_size: mini-batch size passed to the data loader.
    """
    print('Initialize Dataset...')
    data_loader = loader('mnist', batch_size=batch_size)

    print('Initialize Cloud Server...')
    CS1 = CS(size=n_ES, data_loader=data_loader.get_loader([]))

    print('Initialize Edge Servers and Clients...')
    # Pass the hyper-parameters through so the topology really has n_client / n_ES.
    ES_list, clients = cleint_genrate(n_client=n_client, n_cell=n_ES)
    ESs = [ES(size=len(ES_list[i]), data_loader=data_loader.get_loader([]), CS=CS1)
           for i in range(n_ES)]
    client_ID, obj_ID = ID_generate(ES_list, ESs, clients)

    # Each client trains on 4 randomly chosen digit classes.
    clients_obj = [
        Client(rank=i,
               data_loader=data_loader.get_loader(random.sample(range(0, 10), 4)),
               local_epoch=Cl_epoch,
               ES=obj_ID[i])
        for i in range(n_client)
    ]

    # Clients served by each edge server.
    ES_client = [[] for _ in range(n_ES)]
    for k, cells in client_ID.items():
        for cell in cells:
            ES_client[cell].append(clients_obj[k])

    for ESe in range(ES_epoch):
        print('\n================== Edge Server Epoch {:>3} =================='.format(ESe + 1))
        for ESn in range(n_ES):
            print("================= Edge Server :", ESn, "process =================")
            if not ES_client[ESn]:
                # Empty cell: there are no client updates to aggregate.
                print('Edge Server', ESn, 'has no clients; skipping.')
                continue
            for c in ES_client[ESn]:
                c.run()
            ESs[ESn].aggregate()
    CS1.aggregate()
# Script entry point: run the simulation only when executed directly, not on import.
if __name__ == '__main__':
    fed_MES()
