<a href="https://colab.research.google.com/github/jeffrey96158/Vertical-Federated-Learning-without-explicit-ID-Matching/blob/main/Vertical_Federated_Learning_with_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import Package

In [None]:
!pip install phe
#!pip install cupy-cuda100
!curl https://colab.chainer.org/install | sh -


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting phe
  Downloading phe-1.5.0-py2.py3-none-any.whl (53 kB)
[K     |████████████████████████████████| 53 kB 1.6 MB/s 
[?25hInstalling collected packages: phe
Successfully installed phe-1.5.0
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1553  100  1553    0     0   4076      0 --:--:-- --:--:-- --:--:--  4086
********************************************************************************
CUDA version could not be detected!
Try installing Chainer manually by: pip install chainer
********************************************************************************


In [None]:
# pytorch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.utils.data as data

# others
from tqdm.notebook import tqdm as tqdm
import argparse
from sklearn.metrics import f1_score
import random

import matplotlib.pyplot as plt
import numpy as np
#import cupy as cp
import pandas as pd
# Paillier homomorphic encryption (phe package)
from phe import paillier

# Download Dataset (MNIST)
MNIST is a dataset of handwritten digit images.

In [None]:
# The only preprocessing step is the conversion of each image to a tensor.
trans = transforms.Compose([transforms.ToTensor()])

# Training split; downloaded under the current directory when it is not present yet.
train_set = dset.MNIST(
    root='.',
    train=True,
    download=True,
    transform=trans,
)

# Test split; no download flag is passed, so it reads the files already stored under the same root.
test_set = dset.MNIST(
    root='.',
    train=False,
    transform=trans,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



# Arguments

In [None]:
# Hyper-parameters are declared through argparse so every tunable value lives in one place.
parser = argparse.ArgumentParser(description='FL with PyTorch MNIST')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=50, metavar='S',
                    help='random seed (default: 50)')
parser.add_argument('--epochs', type=int, default=1, metavar='N',
                    help='number of epochs to train (default: 1)')
# The help text previously advertised 0.1 although the actual default is 1e-3.
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                    help='learning rate (default: 1e-3)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--batch-size', type=int, default=50, metavar='N',
                    help='input batch size for training (default: 50)')
parser.add_argument('--test-batch-size', type=int, default=50, metavar='N',
                    help='input batch size for testing (default: 50)')
parser.add_argument('--num-participants', type=int, default=12, metavar='NP',
                    help='number of participants (default: 12)')
parser.add_argument('--randomorder_rounds', type=int, default=1, metavar='N',
                    help='rounds for Random Order (default: 1)')
parser.add_argument('--asynchronous_rounds', type=int, default=1, metavar='N',
                    help='rounds for Asynchronous (default: 1)')
# An explicit empty argument list makes argparse use the defaults above instead of
# reading the command line of the Jupyter kernel process.
args = parser.parse_args(args=[])

# Class Definition

In [None]:
## Initialization
# Seed PyTorch's random number generator from the configured seed.
torch.manual_seed(args.seed)

# CUDA is used only when it has not been disabled via --no-cuda and a CUDA device is available.
if args.no_cuda:
    use_cuda = False
else:
    use_cuda = torch.cuda.is_available()

if use_cuda:
    device = torch.device("cuda")
    print("Using CUDA!")
    # The CUDA generator is seeded as well.
    torch.cuda.manual_seed(args.seed)
else:
    device = torch.device('cpu')
    print('Not using CUDA!!!')

Using CUDA!


# Network
- Split the network into client_network and server_network
- Each client_network trains locally on its own part of the data; the outputs of all client networks are delivered to the server for global training.

In [None]:
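# The model is built from linear (fully connected) layers:
# three linear layers in total - two in each client network and one in the server network.

# Each MNIST sample is a 28 x 28 image; the image is divided among the participants by rows.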

class client_network(nn.Module):
    """Participant-side sub-network: two fully connected layers, each followed by ReLU.

    Args:
        num: number of 28-pixel rows held by the participant; the input width is 28 * num.
    """

    def __init__(self, num):
        super().__init__()
        in_features = 28 * num
        self.L1 = nn.Linear(in_features, 128)
        self.L2 = nn.Linear(128, 64)

    def forward(self, x):
        """Map a (batch, 28 * num) input to a non-negative (batch, 64) representation."""
        hidden = F.relu(self.L1(x))
        return F.relu(self.L2(hidden))

class server_network(nn.Module):
    """Server-side head: one linear layer over the concatenated client outputs.

    Args:
        num: number of participants; the input width is 64 * num and the output has 10 scores.
    """

    def __init__(self, num):
        super().__init__()
        self.output = nn.Linear(64 * num, 10)

    def forward(self, x):
        """Return the 10 raw class scores for each row of `x`."""
        # No softmax here: when the loss is nn.CrossEntropyLoss, the softmax is applied inside the loss.
        return self.output(x)

# Client

In [None]:
class Participant():
    """One data-holding party: a local `client_network` plus its own SGD optimizer.

    Args:
        data_num: number of image rows this participant holds (forwarded to `client_network`).
    """
    def __init__(self,data_num):
        # `device` and `args` are module-level names; reading them needs no `global` statement.
        self.local_network = client_network(data_num).to(device)
        self.loss_fn = nn.CrossEntropyLoss()
        self.lr = args.lr
        # parameters() takes an optional boolean `recurse` flag, not a tensor; the default
        # (recurse=True) yields the same parameters the previous truthy tensor argument did.
        self.optimizer = torch.optim.SGD(params=self.local_network.parameters(), lr=self.lr, momentum=0.9)
        # Bookkeeping containers; nothing in the code visible in this notebook writes to them.
        self.loss_history = []
        self.testloss_history = []
        self.accu_history = []
        self.gradient_buffer = dict() # intended for downloaded gradients (unused in the visible code)

    def LocalTrain(self,data):
        """Run a forward pass of the local network on `data` and return its output."""
        net_out = self.local_network(data)
        return net_out


# Server

In [None]:
class GlobalServer():
    """Server side of the split model: a `server_network` head plus its own SGD optimizer.

    Args:
        participants_num: number of participants whose outputs are concatenated as the server input.
    """
    def __init__(self,participants_num):
        # `device` and `args` are module-level names; reading them needs no `global` statement.
        self.global_network = server_network(participants_num).to(device)
        self.gradient_buffer = dict()
        self.temp_gradient_buffer = dict()
        self.loss_history = []
        self.lr = args.lr
        # parameters() takes an optional boolean `recurse` flag, not a tensor; the default
        # (recurse=True) yields the same parameters the previous truthy tensor argument did.
        self.optimizer = torch.optim.SGD(params=self.global_network.parameters(), lr=self.lr, momentum=0.9)
        self.loss_fn = nn.CrossEntropyLoss()
        # Dictionaries reserved for client outputs and labels; not used by the code visible here.
        self.net_out_buffer = dict()
        self.label_buffer = dict()

    def ServerTrain(self,net_in):
        """Run a forward pass of the server network on `net_in` and return its output."""
        net_out = self.global_network(net_in)
        return net_out

# SplitData

In [None]:
def GenerateParticipant(p_num):
    """Build a list of `p_num` participants.

    Every participant is constructed with int(28 / p_num), i.e. the number of
    image rows assigned to each of them.
    """
    return [Participant(int(28 / p_num)) for _ in range(p_num)]

# TrainSet_list[num], TestSet_list[num] -> store the name
# TrainDataSet_dict[num], TestDataSet_dict[num] -> store the data
def SplitData(train_set, test_set, num):
    """Randomly partition both datasets into `num` equal parts and wrap each part in a DataLoader.

    Args:
        train_set: training dataset.
        test_set: test dataset.
        num: number of equal-sized partitions.

    Returns:
        (TrainDataSet_dict, TestDataSet_dict), both keyed 'P0' ... 'P{num-1}'.
        When either dataset length is not a multiple of `num`, a message is printed
        and None is returned instead.
    """
    n_train = len(train_set)
    n_test = len(test_set)
    if n_train % num != 0 or n_test % num != 0:
        print(f"len(train_set):{len(train_set)}, len(test_set):{len(test_set)}, can not be indivisible by {num}")
        return

    train_share = int(n_train / num)
    test_share = int(n_test / num)

    portions = [train_share] * num
    print(f'portion: {portions}')
    train_parts = data.random_split(train_set, portions)
    test_parts = data.random_split(test_set, [test_share] * num)

    TrainDataSet_dict = {}
    TestDataSet_dict = {}

    # One mini-batch loader per partition.
    for idx in range(num):
        key = 'P' + str(idx)
        TrainDataSet_dict[key] = data.DataLoader(
            dataset=train_parts[idx], batch_size=args.batch_size, shuffle=True, num_workers=0)
        TestDataSet_dict[key] = data.DataLoader(
            dataset=test_parts[idx], batch_size=args.test_batch_size, shuffle=True, num_workers=0)

    return TrainDataSet_dict, TestDataSet_dict

# participants_num is the number of participants; every image is divided row-wise among them.
# (The function below takes no separate "test num" / slice-count argument.)
def Multi_Participants_SplitData(train_set, test_set, participants_num):
    """Split every image row-wise among the participants and wrap each share in a DataLoader.

    Each participant receives int(28 / participants_num) consecutive rows of every image
    together with the image's label. When 28 is not a multiple of `participants_num`,
    the remaining rows form an extra chunk that no participant receives.

    Args:
        train_set: training dataset; item i is indexed as (image, label) and image[0] is split by rows.
        test_set: test dataset with the same item layout.
        participants_num: number of participants.

    Returns:
        (TrainDataSet_dict, TestDataSet_dict): two lists (despite the names) holding one
        DataLoader per participant, in participant order.

    Side effects:
        Prints the type of the internal list and reseeds the global torch RNG with 0.
    """
    def _split_rows(dataset, split_list):
        """Return one list of (row_band, label) pairs per participant."""
        per_participant = [[] for _ in range(participants_num)]
        for i in range(len(dataset)):
            # Index the dataset once per sample; the previous code re-indexed it for every
            # participant, repeating the dataset access (and its transform) needlessly.
            sample = dataset[i]
            label = sample[1]
            bands = torch.split(sample[0][0], split_list)
            for p in range(participants_num):
                per_participant[p].append((bands[p], label))
        return per_participant

    # Kept from the original: the placeholder list is created and its type printed first.
    split_train_set = [[] for i in range(participants_num)]
    print(f'TrainSet_list type: {type(split_train_set)}')

    # Number of rows each participant gets, plus a trailing remainder chunk when 28 is not divisible.
    rows_each = int(28 / participants_num)
    split_list = [rows_each] * participants_num
    if 28 % participants_num:
        split_list.append(28 % participants_num)

    split_train_set = _split_rows(train_set, split_list)
    split_test_set = _split_rows(test_set, split_list)

    torch.manual_seed(0)

    # NOTE(review): every loader shuffles on its own, so the batches drawn for different
    # participants are presumably not the same samples in the same order - confirm this is intended.
    TrainDataSet_dict = [
        data.DataLoader(dataset=split_train_set[p], batch_size=args.batch_size, shuffle=True, num_workers=0)
        for p in range(participants_num)
    ]
    TestDataSet_dict = [
        data.DataLoader(dataset=split_test_set[p], batch_size=args.test_batch_size, shuffle=True, num_workers=0)
        for p in range(participants_num)
    ]

    return TrainDataSet_dict,TestDataSet_dict

def MySplitData(train_set, test_set, num, participants_num):
    """Cut every image into a top and a bottom half (14 rows each) and partition both halves.

    Args:
        train_set: training dataset; item i is indexed as (image, label) and image[0] is split by rows.
        test_set: test dataset with the same item layout.
        num: number of random partitions made from each half.
        participants_num: unused; kept so existing call sites stay valid.

    Returns:
        (TrainDataSet1_dict, TrainDataSet2_dict, TestDataSet1_dict, TestDataSet2_dict):
        dictionaries keyed 'P0' ... 'P{num-1}' holding one DataLoader per partition, for the
        top half (1) and the bottom half (2) respectively.

    Side effects:
        Reseeds the global torch RNG with 0.
    """
    def _halve(dataset):
        """Return (top_halves, bottom_halves) as lists of (rows, label) pairs."""
        top, bottom = [], []
        for i in range(len(dataset)):
            # Index the dataset once per sample instead of three times.
            sample = dataset[i]
            label = sample[1]
            # sample[0] is indexed once more with [0] to obtain the 2-D image before splitting.
            halves = torch.split(sample[0][0], 14)
            top.append((halves[0], label))
            bottom.append((halves[1], label))
        return top, bottom

    # The former `if(28 % 2)` guard could never trigger (28 is even) and has been dropped.
    train_set_1, train_set_2 = _halve(train_set)
    test_set_1, test_set_2 = _halve(test_set)

    # Fix the random state before partitioning.
    torch.manual_seed(0)

    train_split = int(len(train_set_1) / num)
    test_split = int(len(test_set_1) / num)

    # NOTE(review): the two halves are partitioned by separate random_split calls after one
    # seeding, so partition i of half 1 and of half 2 presumably hold different samples -
    # confirm whether that is intended.
    portions = [train_split] * num
    TrainSet_list_1 = data.random_split(train_set_1, portions)
    TrainSet_list_2 = data.random_split(train_set_2, portions)

    portions = [test_split] * num
    TestSet_list_1 = data.random_split(test_set_1, portions)
    TestSet_list_2 = data.random_split(test_set_2, portions)

    TrainDataSet1_dict = dict()
    TrainDataSet2_dict = dict()
    TestDataSet1_dict = dict()
    TestDataSet2_dict = dict()
    for i in range(num):
        s = 'P' + str(i)
        TrainDataSet1_dict[s] = data.DataLoader(dataset=TrainSet_list_1[i], batch_size=args.batch_size, shuffle=True, num_workers=0)
        TrainDataSet2_dict[s] = data.DataLoader(dataset=TrainSet_list_2[i], batch_size=args.batch_size, shuffle=True, num_workers=0)
        TestDataSet1_dict[s] = data.DataLoader(dataset=TestSet_list_1[i], batch_size=args.test_batch_size, shuffle=True, num_workers=0)
        TestDataSet2_dict[s] = data.DataLoader(dataset=TestSet_list_2[i], batch_size=args.test_batch_size, shuffle=True, num_workers=0)

    return TrainDataSet1_dict,TrainDataSet2_dict,TestDataSet1_dict,TestDataSet2_dict



In [None]:
def zipping(data, p_num):
    """Iterate the first `p_num` entries of `data` in lock-step.

    Args:
        data: indexable collection of iterables (one per participant).
        p_num: number of participants to zip together; must be at least 1.

    Returns:
        A zip iterator yielding tuples with one item per participant.

    Raises:
        ValueError: if `p_num` is smaller than 1.
        IndexError: if `data` holds fewer than `p_num` entries.
    """
    if p_num < 1:
        raise ValueError(f"p_num must be at least 1, got {p_num}")
    # Works for any participant count; the previous version only handled 1-4 and exactly 6.
    return zip(*[data[i] for i in range(p_num)])

# Vertical Training

In [None]:
# train_data[p] is the training DataLoader of participant p
# client_list = list of participants, in the same order as train_data
# Server = the GlobalServer that consumes the concatenated client outputs
# Call signature: vertical_train_multiple(client_list, Server, train_data, epochs, p_num)
# p_num = number of participants

def vertical_train_multiple(client_list, Server, train_data,  epochs, p_num):
    """Train the clients and the server for `epochs` passes over the per-participant loaders.

    Args:
        client_list: participants, in the same order as `train_data`.
        Server: object exposing `ServerTrain`, `loss_fn` and `optimizer`.
        train_data: one DataLoader per participant, each yielding (data, label) batches.
        epochs: number of passes over the loaders.
        p_num: number of participants.

    Returns:
        (train_loss_history, accu_history, f1_history): one entry per epoch holding the mean
        server loss per sample, the accuracy as a fraction, and the macro F1 score.
    """
    train_loss_history = []
    accu_history = []
    f1_history = []
    for e in tqdm(range(epochs)):
        epoch_loss_sum = 0.0
        correct = 0
        # Rebuilt every epoch so the F1 score describes this epoch only (it previously
        # accumulated the predictions of all earlier epochs as well).
        Labels = []
        Predicts = []
        # NOTE(review): assumes the loaders yield the same samples in the same order for
        # every participant - TODO confirm, since they are created with shuffle=True.
        for pack in zipping(train_data, p_num):
            features = []
            label = []
            for p in range(p_num):
                batch_x, batch_y = pack[p]
                features.append(batch_x.to(device))
                label.append(batch_y.to(device))
            batch_size = features[0].shape[0]
            # Flatten every participant's rows into one vector per sample.
            features = [x.view(batch_size, -1) for x in features]

            # Client forward pass
            client_out = [client_list[p].LocalTrain(features[p]) for p in range(p_num)]
            # Concatenate along the feature axis; this also works for a single participant.
            c_out = torch.cat(client_out, -1)

            # Server forward pass
            server_out = Server.ServerTrain(c_out)

            # Losses: the server is scored against participant 0's labels.
            server_loss = Server.loss_fn(server_out, label[0])
            # NOTE(review): each client is updated from a loss computed directly on its own
            # output and labels; the gradients server_loss leaves on the client parameters are
            # cleared by zero_grad() below before they are used - confirm this is intended.
            client_loss = [Server.loss_fn(client_out[p], label[p]) for p in range(p_num)]

            # Server backward; the graph is retained because the client losses share it.
            Server.optimizer.zero_grad()
            server_loss.backward(retain_graph=True)
            Server.optimizer.step()

            # Client backward
            for p in range(p_num):
                client_list[p].optimizer.zero_grad()
                client_loss[p].backward(retain_graph=True)
                client_list[p].optimizer.step()

            # loss_fn returns the batch mean, so weight it by the batch size to make the
            # division by the dataset size below yield a per-sample average.
            epoch_loss_sum += float(server_loss.item()) * batch_size
            predict = server_out.data.max(1, keepdim=True)[1]
            correct += predict.eq(label[0].data.view_as(predict)).sum().item()
            Labels.extend(label[0].tolist())
            Predicts.extend(predict.view(-1).tolist())

        sample_count = len(train_data[0].dataset)
        train_loss_history.append(epoch_loss_sum / sample_count)
        accu_history.append(correct / sample_count)
        f1_history.append(f1_score(Labels, Predicts, average = "macro"))
    return train_loss_history, accu_history, f1_history


# Vertical Testing

In [None]:
def vertical_test_multiple(client_list,Server,test_data,p_num=None):
    """Evaluate the clients and the server on the per-participant test loaders.

    Args:
        client_list: participants, in the same order as `test_data`.
        Server: object exposing `ServerTrain` and `loss_fn`.
        test_data: one DataLoader per participant, each yielding (data, label) batches.
        p_num: ignored; the participant count is always len(client_list). Optional so the
            function can be called with or without it.

    Returns:
        The accuracy in percent. A summary line with loss, accuracy and macro F1 is printed.
    """
    p_num = len(client_list)
    correct = 0
    epoch_loss_sum = 0.0
    Labels = []
    Predicts = []

    # Evaluation needs no gradients, so skip building the autograd graph.
    with torch.no_grad():
        # NOTE(review): assumes the loaders yield the same samples in the same order for
        # every participant - TODO confirm, since they are created with shuffle=True.
        for pack in zipping(test_data, p_num):
            features = []
            label = []
            for p in range(p_num):
                batch_x, batch_y = pack[p]
                features.append(batch_x.to(device))
                label.append(batch_y.to(device))
            batch_size = features[0].shape[0]
            features = [x.view(batch_size, -1) for x in features]

            # Client forward pass
            client_out = [client_list[p].LocalTrain(features[p]) for p in range(p_num)]
            # Concatenate along the feature axis; this also works for a single participant.
            c_out = torch.cat(client_out, -1)

            # Server forward pass and loss against participant 0's labels
            server_out = Server.ServerTrain(c_out)
            server_loss = Server.loss_fn(server_out, label[0])
            # loss_fn returns the batch mean; weight by batch size for a per-sample average.
            epoch_loss_sum += float(server_loss.item()) * batch_size

            predict = server_out.data.max(1, keepdim=True)[1]  # index of the highest score
            correct += predict.eq(label[0].data.view_as(predict)).sum().item()
            Labels.extend(label[0].tolist())
            Predicts.extend(predict.view(-1).tolist())

    sample_count = len(test_data[0].dataset)
    epoch_loss_sum /= sample_count
    accuracy = 100. * correct / sample_count
    f1 = f1_score(Labels, Predicts, average = "macro")
    print(f'Test set: Average loss: {epoch_loss_sum:.4f}, Accuracy: {correct}/{len(test_data[0].dataset)} ({accuracy:.2f}%), F1-score: {f1:.4f}')
    return accuracy



# Initialize

In [None]:
# Test DataLoader for measuring accuracy (the full test set of 10,000 samples)
test_dataset = data.DataLoader(dataset=test_set, batch_size=args.test_batch_size,shuffle=True)
# Number of participants taking part in vertical training
p_num = 4
# Create the participants (a list with one Participant per party)
Participant_dict = GenerateParticipant(p_num)
# TrainData_dict, TestData_dict = SplitData(train_set, test_set, args.num_participants)
# TrainData1_dict, TrainData2_dict, TestData1_dict, TestData2_dict = MySplitData(train_set, test_set, 5, 2)


# TrainData_dict[p] / TestData_dict[p] is the DataLoader of participant p (lists, despite the names)
TrainData_dict, TestData_dict = Multi_Participants_SplitData(train_set, test_set, p_num)

# Participants keys
#Participant_keys = list(Participant_dict.keys())

# Global server initialization
Server = GlobalServer(p_num)

TrainSet_list type: <class 'list'>


# Vertical Training test

In [None]:
p_num_list = [4,6]

for p_num in p_num_list:
    # Every participant count starts from newly created participants, loaders and server.
    Participant_dict = GenerateParticipant(p_num)
    TrainData_dict, TestData_dict = Multi_Participants_SplitData(train_set, test_set, p_num)
    Server = GlobalServer(p_num)

    train_loss, train_accuracy, train_f1 = vertical_train_multiple(
        Participant_dict, Server, TrainData_dict, 300, p_num)

    x_value = np.arange(len(train_loss))
    # Both figures share the same title.
    s = f"{p_num}_participants_with_lr_=_{args.lr}"

    # Figure 1: accuracy (cyan) and F1 score (yellow) per training epoch.
    fig, ax = plt.subplots()
    ax.plot(x_value, train_accuracy, color='c')
    ax.plot(x_value, train_f1, color='y')
    ax.set_ylim([0, 1])
    ax.set_ylabel('accuracy and f1_score')
    ax.set_xlabel('Training epochs')
    ax.set_title(s)
    plt.show()

    # Figure 2: training loss per epoch.
    fig, ax = plt.subplots()
    ax.plot(x_value, train_loss, color='r')
    ax.set_ylabel("loss")
    ax.set_xlabel("Training epochs")
    ax.set_title(s)
    #plt.savefig(f"{images_dir}/{f}")
    plt.show()

    # Evaluate on the held-out test loaders.
    vertical_test_multiple(Participant_dict,Server,TestData_dict,p_num)



TrainSet_list type: <class 'list'>


  0%|          | 0/300 [00:00<?, ?it/s]

# Multiple Participant Experiment


In [None]:
p_list = [4]
accu = list()
for part in p_list:
    # Test DataLoader for measuring accuracy (the full test set of 10,000 samples)
    test_dataset = data.DataLoader(dataset=test_set, batch_size=args.test_batch_size,shuffle=True)
    p_num = part

    # This cell previously called the helpers with an extra "test num" argument and indexed
    # their results with 'P<i>' / 'T<i>' keys. The helpers defined in this notebook take only
    # the participant count and return plain lists, so those calls failed. The per-slice
    # bookkeeping is therefore dropped and one model is trained per participant count.
    Participant_dict = GenerateParticipant(p_num)
    TrainData_dict, TestData_dict = Multi_Participants_SplitData(train_set, test_set, p_num)

    # Global server initialization
    Server = GlobalServer(p_num)
    print(f'Now training {part} participants vertical training')

    # Train, then record the test accuracy (in percent) of every round.
    for _ in tqdm(range(args.randomorder_rounds)):
        vertical_train_multiple(Participant_dict, Server, TrainData_dict, 300, p_num)
        accuracy = vertical_test_multiple(Participant_dict, Server, TestData_dict, p_num)
        accu.append(accuracy)

  # #ploting figure
  # width = 0.3
  # x = np.arrange(len(p_list))
  # plt.bar(x,accu,color = 'green')
  # plt.xticks(x,p_list)
  # plt.title("Accuracy of multiple participant")
  # plt.show()